Skip to content

Commit

Permalink
Updated function for testing.
Browse files Browse the repository at this point in the history
  • Loading branch information
barbacbd committed Mar 21, 2024
1 parent 322b627 commit c46a5ab
Show file tree
Hide file tree
Showing 2 changed files with 86 additions and 3 deletions.
13 changes: 13 additions & 0 deletions src/nhl_model/poisson.py
Original file line number Diff line number Diff line change
Expand Up @@ -225,6 +225,19 @@ def parseSeasonEvents(year): # pylint: disable=too-many-branches
logger.error(f"Failed to find a schedule for previous year {year-1}")
return None, None

return getSeasonEventsFromSchedules(schedule, previousSchedule)


def getSeasonEventsFromSchedules(schedule, previousSchedule): # pylint: disable=too-many-branches
"""Parse the events for the season given the json formatted season and
previous season data. This will include predicting which team will win
each game.
:return: home team events, away team events
"""
if None in (schedule, previousSchedule):
return None, None

homeTeamEventsPrev, awayTeamEventsPrev = parseSchedule(previousSchedule)
totalTeamIdsPrevSeason = set(list(homeTeamEventsPrev.keys()))
totalTeamIdsPrevSeason.update(list(awayTeamEventsPrev.keys()))
Expand Down
76 changes: 73 additions & 3 deletions tests/test_poisson.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,18 +3,34 @@
# pylint: disable=redefined-builtin
# pylint: disable=deprecated-method
from unittest import TestCase
from os.path import dirname, abspath, join
from os.path import dirname, abspath, join, exists
from json import loads
from collections import defaultdict
from shutil import copytree, rmtree
from nhl_core.endpoints import MAX_GAME_NUMBER
from nhl_model.poisson import (
parseSchedule,
calculateAvgGoals,
calculateScores
calculateScores,
getSeasonEventsFromSchedules,
)

BADSEASON = 2004
GOODSEASON = 2005
TOPSEASON = 2006

def mocked_schedules_get(year):
if year is None or year == BADSEASON:
return None

currDir = dirname(abspath(__file__))
splitDir = currDir.split("/")
splitDir = splitDir[:-1] + ["src", "nhl_model", "support", "schedules", f"{year}", "schedule.json"]
filename = "/" + join(*splitDir)

with open(filename, "rb") as jsonData:
return loads(jsonData.read())


class PoissonTests(TestCase):
'''Test cases for the Poisson functionality to the module.
Expand All @@ -25,15 +41,34 @@ def setUpClass(cls):
'''Set up the class for testing the dataset'''
currDir = dirname(abspath(__file__))

# Grab and save the data in the GOOD SEASON
splitDir = currDir.split("/")
splitDir = splitDir[:-1] + ["src", "nhl_model", "support", "schedules", "2005"]
splitDir = splitDir[:-1] + ["src", "nhl_model", "support", "schedules", f"{GOODSEASON}"]
copyDir = "/" + join(*splitDir)
cls.destDir = "/".join([currDir, f"{GOODSEASON}"])

if exists(cls.destDir):
rmtree(cls.destDir, ignore_errors=True)

copytree(copyDir, cls.destDir)

with open("/".join([cls.destDir, "schedule.json"])) as jsonData:
cls.jsonSchedule = loads(jsonData.read())

# Grab and save the data in the TOP SEASON
splitDir = currDir.split("/")
splitDir = splitDir[:-1] + ["src", "nhl_model", "support", "schedules", f"{TOPSEASON}"]
copyDir = "/" + join(*splitDir)
cls.destDirTop = "/".join([currDir, f"{TOPSEASON}"])

if exists(cls.destDirTop):
rmtree(cls.destDirTop, ignore_errors=True)

copytree(copyDir, cls.destDirTop)

with open("/".join([cls.destDirTop, "schedule.json"])) as jsonData:
cls.jsonScheduleTop = loads(jsonData.read())

return super().setUpClass()

@classmethod
Expand Down Expand Up @@ -164,3 +199,38 @@ def test_calculate_scores_multi_with_invalid(self):
value=value
):
self.assertIsNone(value)

def test_parse_season_events_no_schedule(self):
'''Test when the schedule is None'''
parsedHomeTeamEvents, parsedAwayTeamEvents = getSeasonEventsFromSchedules(
None, self.jsonSchedule
)

self.assertIsNone(parsedHomeTeamEvents)
self.assertIsNone(parsedAwayTeamEvents)

def test_parse_season_events_no_previous(self):
'''Test when the previous schedule is None'''
parsedHomeTeamEvents, parsedAwayTeamEvents = getSeasonEventsFromSchedules(
self.jsonScheduleTop, None
)

self.assertIsNone(parsedHomeTeamEvents)
self.assertIsNone(parsedAwayTeamEvents)

def test_parse_season_events(self):
parsedHomeTeamEvents, parsedAwayTeamEvents = getSeasonEventsFromSchedules(
self.jsonScheduleTop, self.jsonSchedule
)

# each game is both an away and home event so there should be entries for
# all teams.
self.assertEqual(len(parsedHomeTeamEvents), len(parsedAwayTeamEvents))

total = 0
for key in parsedHomeTeamEvents:
total += len(parsedHomeTeamEvents[key])

# These schedules do NOT include the full expansion teams so the number
# of games should be less than max
self.assertLess(total, MAX_GAME_NUMBER)

0 comments on commit c46a5ab

Please sign in to comment.