diff --git a/test_database.py b/test_database.py index e1cad83..0cb2b4e 100644 --- a/test_database.py +++ b/test_database.py @@ -2,46 +2,8 @@ # To run the tests, use: python3 -m pytest --capture=sys -from collections import Counter, namedtuple from database import Database -import os - -SCRIPT_DIR = os.path.dirname(os.path.realpath(__file__)) -MAX_ID = 493 - -region_info = namedtuple('region_info', 'start end') -region_dict = { - 'kanto': region_info(1, 151), - 'johto': region_info(152, 251), - 'hoenn': region_info(252, 386), - 'sinnoh': region_info(387, 493), -} - -extra_counts = None # Set below after make_extra_counts() is defined - - -def _pokemon_id_to_region(pokemon_id): - """Rewrite of Database.__determine_region() avoids sharing implementations - between production code and test code.""" - for region_name, region_info in region_dict.items(): - if region_info.start <= pokemon_id <= region_info.end: - return region_name - assert False, '{} is an invalid region'.format(pokemon_id) - - -def make_extra_counts(filename='pokemon.txt'): - """Test that correct regions are used in load_all_pokemon.load_extras(). - Currently generates the dict: {'sinnoh': 14, 'hoenn': 9, 'johto': 1}""" - with open(os.path.join(SCRIPT_DIR, 'Data', filename)) as in_file: - pokemon_names = tuple([line.split()[0] for line in in_file]) - filenames = os.listdir(os.path.join(SCRIPT_DIR, 'Images', 'Extra')) - father_names = (filename.split('-')[0] for filename in filenames) - father_ids = (pokemon_names.index(name) for name in father_names) - father_regions = (_pokemon_id_to_region(id) for id in father_ids) - return dict(Counter(father_regions)) - - -extra_counts = make_extra_counts() +from test_utils import region_dict, get_region, make_extra_counts, MAX_ID def test_first_database(): @@ -58,54 +20,24 @@ def test_len(): def test_extra_counts(): - assert len(Database()) == MAX_ID + sum(extra_counts.values()) + assert len(Database()) == MAX_ID + sum(make_extra_counts().values()) def test_get_extras(): db = Database() assert db.get_extra(), 'db.get_extra() returns no pokemon' - assert db.get_extra() == sum(extra_counts.values()) - - -def test_region_dict(): - # test if region_dict counts match wikipedia - print('From https://en.wikipedia.org/wiki/Pok%C3%A9mon#Generation_1 ...') - counts = { - 'kanto': 151, - 'johto': 100, - 'hoenn': 135, - 'sinnoh': 107, - 'all': 493 - } - region_counts = (counts[r] for r in 'kanto johto hoenn sinnoh'.split()) - assert counts['all'] == sum(region_counts) == MAX_ID - for name, info in region_dict.items(): - assert counts[name] == info.end - info.start + 1 - print('{}: {}'.format(name, counts[name])) - - -def get_region(db, region_name): - """Database unfortunately makes db.__get_region() private :-(""" - func = { - 'kanto': db.get_kanto, - 'johto': db.get_johto, - 'hoenn': db.get_hoenn, - 'sinnoh': db.get_sinnoh, - 'extra': db.get_extra - }[region_name] - return func() + assert len(db.get_extra()) == sum(make_extra_counts().values()) def region_length_test(region_name): db = Database() # test db.get_region() - pokemon = db.get_region(region_name) if tuple_store else get_region( - db, region_name) + pokemon = get_region(db, region_name) assert pokemon, 'No pokemon found in region: ' + region_name # test that region_name is in region_dict region_info = region_dict[region_name] - extra_count = extra_counts.get(region_name, 0) - expected_len = region_info.end - region_info.start + 1 + extra_count + # extra_count = extra_counts.get(region_name, 0) + expected_len = region_info.end - region_info.start + 1 # + extra_count fmt = 'Testing {}({} vs. {}): {}' print(fmt.format(region_name, len(pokemon), expected_len, region_info)) # test the number of pokemon returned by db.get_region() @@ -131,8 +63,7 @@ def test_sinnoh_length(): def region_test(region_name): db = Database() # test db.get_region() - pokemon = db.get_region(region_name) if tuple_store else get_region( - db, region_name) + pokemon = get_region(db, region_name) assert pokemon, 'No pokemon found in region: ' + region_name # test that region_name is in region_dict region_info = region_dict[region_name] @@ -143,7 +74,7 @@ def region_test(region_name): middle_pokemon = db.get_pokemon(region_info.start + (delta // 2)) assert middle_pokemon in pokemon # test db.get_pokemon(name) - name = middle_pokemon.name if tuple_store else middle_pokemon.get_name() + name = middle_pokemon.get_name() assert db.get_pokemon(name) in pokemon # test the case insensivity of db.get_pokemon(name) # assert db.get_pokemon(name.upper()) in pokemon # !!! FixMe !!! @@ -184,8 +115,8 @@ def _test_region(region_name): # make sure there are no missing pokemon start = region_record.start end = region_record.end - extra_count = extra_counts.get(region_name, 0) - assert len(pokemon_list) == end - start + 1 + extra_count + # extra_count = extra_counts.get(region_name, 0) + assert len(pokemon_list) == end - start + 1 # + extra_count # make sure that all pokemon.id == '---' or are in the ID range assert all([start <= int(p.get_id()) <= end for p in pokemon_list if p.get_id() != '---'])