Skip to content

Commit

Permalink
Updated new GufeTokenizable models in statestore
Browse files Browse the repository at this point in the history
* Removed TaskRestartPolicy and TaskHistory
* Added Traceback
  • Loading branch information
ianmkenney committed Jul 17, 2024
1 parent dd8f0e9 commit da17e45
Show file tree
Hide file tree
Showing 2 changed files with 47 additions and 46 deletions.
77 changes: 32 additions & 45 deletions alchemiscale/storage/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -151,72 +151,59 @@ class TaskRestartPattern(GufeTokenizable):
----------
pattern: str
A regular expression pattern that can match to returned tracebacks of errored Tasks.
retry_count: int
max_retries: int
The number of times the pattern can trigger a restart for a Task.
"""

pattern: str
retry_count: int
max_retries: int

def __init__(self, pattern: str):
def __init__(self, pattern: str, max_retries: int):
self.pattern = pattern

if not isinstance(max_retries, int) or max_retries <= 0:
raise ValueError("`max_retries` must have a positive integer value.")
self.max_retries = max_retries

# TODO: these hashes can overlap across TaskHubs
def _gufe_tokenize(self):
return hashlib.md5(self.pattern).hexdigest()
return hashlib.md5(self.pattern.encode()).hexdigest()

@classmethod
def _defaults(cls):
raise NotImplementedError

@classmethod
def _from_dict(cls, dct):
return cls(**dct)

def _to_dict(self):
return {"pattern": self.pattern, "max_retries": self.max_retries}

def __eq__(self, other):
if not isinstance(other, self.__class__):
return False
return self.pattern == other.pattern


# TODO: fill in docstrings
class TaskRestartPolicy(GufeTokenizable):
"""Restart policy that enforces a TaskHub.
Attributes
----------
taskhub: str
ScopedKey of the TaskHub this TaskRestartPolicy enforces.
"""

taskhub: str
class Traceback(GufeTokenizable):

def __init__(self, taskhub: ScopedKey):
self.taskhub = taskhub
def __init__(self, tracebacks: List[str]):
self.tracebacks = tracebacks

def _gufe_tokenize(self):
return hashlib.md5(
self.__class__.__qualname__ + str(self.taskhub), usedforsecurity=False
).hexdigest()

return hashlib.md5(str(self.tracebacks).encode()).hexdigest()

# TODO: fill in docstrings
class TaskHistory(GufeTokenizable):
"""History attached to a `Task`.
Attributes
----------
task: str
ScopedKey of the Task this TaskHistory corresponds to.
tracebacks: List[str]
The history of tracebacks returned with the newest entries appearing at the end of the list.
times_restarted: int
The number of times the task has bee
"""

task: str
tracebacks: list
times_restarted: int
@classmethod
def _defaults(cls):
raise NotImplementedError

def __init__(self, task: ScopedKey, tracebacks: List[str]):
self.task = task
self.tracebacks = tracebacks
@classmethod
def _from_dict(cls, dct):
return Traceback(**dct)

def _gufe_tokenize(self):
return hashlib.md5(
self.__class__.__qualname__ + str(self.task), usedforsecurity=False
).hexdigest()
def _to_dict(self):
return {"tracebacks": self.tracebacks}


class TaskHub(GufeTokenizable):
Expand Down
16 changes: 15 additions & 1 deletion alchemiscale/storage/statestore.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,10 +24,10 @@
ComputeServiceRegistration,
NetworkMark,
NetworkStateEnum,
ProtocolDAGResultRef,
Task,
TaskHub,
TaskStatusEnum,
ProtocolDAGResultRef,
)
from ..strategies import Strategy
from ..models import Scope, ScopedKey
Expand Down Expand Up @@ -2703,6 +2703,20 @@ def err_msg(t, status):

return self._set_task_status(tasks, q, err_msg, raise_error=raise_error)

## task restart policy

# TODO: fill in docstring
def add_task_restart_policy_patterns(
self, taskhub: ScopedKey, patterns: List[str], number_of_retries: int
):
"""Add a list of restart policy patterns to a `TaskHub` along with the number of retries allowed.
Parameters
----------
"""
raise NotImplementedError

## authentication

def create_credentialed_entity(self, entity: CredentialedEntity):
Expand Down

0 comments on commit da17e45

Please sign in to comment.