diff --git a/src/nhl_model/poisson.py b/src/nhl_model/poisson.py index 27332b1..8548585 100644 --- a/src/nhl_model/poisson.py +++ b/src/nhl_model/poisson.py @@ -225,6 +225,19 @@ def parseSeasonEvents(year): # pylint: disable=too-many-branches logger.error(f"Failed to find a schedule for previous year {year-1}") return None, None + return getSeasonEventsFromSchedules(schedule, previousSchedule) + + +def getSeasonEventsFromSchedules(schedule, previousSchedule): # pylint: disable=too-many-branches + """Parse the events for the season given the json formatted season and + previous season data. This will include predicting which team will win + each game. + + :return: home team events, away team events + """ + if None in (schedule, previousSchedule): + return None, None + homeTeamEventsPrev, awayTeamEventsPrev = parseSchedule(previousSchedule) totalTeamIdsPrevSeason = set(list(homeTeamEventsPrev.keys())) totalTeamIdsPrevSeason.update(list(awayTeamEventsPrev.keys())) diff --git a/tests/test_poisson.py b/tests/test_poisson.py index c2f6823..bb83bc8 100644 --- a/tests/test_poisson.py +++ b/tests/test_poisson.py @@ -3,18 +3,34 @@ # pylint: disable=redefined-builtin # pylint: disable=deprecated-method from unittest import TestCase -from os.path import dirname, abspath, join +from os.path import dirname, abspath, join, exists from json import loads from collections import defaultdict from shutil import copytree, rmtree +from nhl_core.endpoints import MAX_GAME_NUMBER from nhl_model.poisson import ( parseSchedule, calculateAvgGoals, - calculateScores + calculateScores, + getSeasonEventsFromSchedules, ) BADSEASON = 2004 GOODSEASON = 2005 +TOPSEASON = 2006 + +def mocked_schedules_get(year): + if year is None or year == BADSEASON: + return None + + currDir = dirname(abspath(__file__)) + splitDir = currDir.split("/") + splitDir = splitDir[:-1] + ["src", "nhl_model", "support", "schedules", f"{year}", "schedule.json"] + filename = "/" + join(*splitDir) + + with open(filename, "rb") as jsonData: + return loads(jsonData.read()) + class PoissonTests(TestCase): '''Test cases for the Poisson functionality to the module. @@ -25,15 +41,34 @@ def setUpClass(cls): '''Set up the class for testing the dataset''' currDir = dirname(abspath(__file__)) + # Grab and save the data in the GOOD SEASON splitDir = currDir.split("/") - splitDir = splitDir[:-1] + ["src", "nhl_model", "support", "schedules", "2005"] + splitDir = splitDir[:-1] + ["src", "nhl_model", "support", "schedules", f"{GOODSEASON}"] copyDir = "/" + join(*splitDir) cls.destDir = "/".join([currDir, f"{GOODSEASON}"]) + + if exists(cls.destDir): + rmtree(cls.destDir, ignore_errors=True) + copytree(copyDir, cls.destDir) with open("/".join([cls.destDir, "schedule.json"])) as jsonData: cls.jsonSchedule = loads(jsonData.read()) + # Grab and save the data in the TOP SEASON + splitDir = currDir.split("/") + splitDir = splitDir[:-1] + ["src", "nhl_model", "support", "schedules", f"{TOPSEASON}"] + copyDir = "/" + join(*splitDir) + cls.destDirTop = "/".join([currDir, f"{TOPSEASON}"]) + + if exists(cls.destDirTop): + rmtree(cls.destDirTop, ignore_errors=True) + + copytree(copyDir, cls.destDirTop) + + with open("/".join([cls.destDirTop, "schedule.json"])) as jsonData: + cls.jsonScheduleTop = loads(jsonData.read()) + return super().setUpClass() @classmethod @@ -164,3 +199,38 @@ def test_calculate_scores_multi_with_invalid(self): value=value ): self.assertIsNone(value) + + def test_parse_season_events_no_schedule(self): + '''Test when the schedule is None''' + parsedHomeTeamEvents, parsedAwayTeamEvents = getSeasonEventsFromSchedules( + None, self.jsonSchedule + ) + + self.assertIsNone(parsedHomeTeamEvents) + self.assertIsNone(parsedAwayTeamEvents) + + def test_parse_season_events_no_previous(self): + '''Test when the previous schedule is None''' + parsedHomeTeamEvents, parsedAwayTeamEvents = getSeasonEventsFromSchedules( + self.jsonScheduleTop, None + ) + + self.assertIsNone(parsedHomeTeamEvents) + self.assertIsNone(parsedAwayTeamEvents) + + def test_parse_season_events(self): + parsedHomeTeamEvents, parsedAwayTeamEvents = getSeasonEventsFromSchedules( + self.jsonScheduleTop, self.jsonSchedule + ) + + # each game is both an away and home event so there should be entries for + # all teams. + self.assertEqual(len(parsedHomeTeamEvents), len(parsedAwayTeamEvents)) + + total = 0 + for key in parsedHomeTeamEvents: + total += len(parsedHomeTeamEvents[key]) + + # These schedules do NOT include the full expansion teams so the number + # of games should be less than max + self.assertLess(total, MAX_GAME_NUMBER)