Skip to content

Commit

Permalink
Merge pull request #1642 from Avaiga/feature/#1572-Protect-scenario-c…
Browse files Browse the repository at this point in the history
…ustom-properties

feature/#1572 attempt to Protect scenario custom properties
  • Loading branch information
toan-quach authored Aug 27, 2024
2 parents d0eaff6 + 1a0db5c commit 892e082
Show file tree
Hide file tree
Showing 24 changed files with 172 additions and 94 deletions.
2 changes: 1 addition & 1 deletion taipy/core/common/warn_if_inputs_not_ready.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ def _warn_if_inputs_not_ready(inputs: Iterable[DataNode]):
]:
logger.warning(
f"{dn.id} cannot be read because it has never been written. "
f"Hint: The data node may refer to a wrong path : {dn.path} "
f"Hint: The data node may refer to a wrong path : {dn.properties['path']} "
)
else:
logger.warning(f"{dn.id} cannot be read because it has never been written.")
6 changes: 0 additions & 6 deletions taipy/core/cycle/cycle.py
Original file line number Diff line number Diff line change
Expand Up @@ -211,12 +211,6 @@ def _get_valid_filename(name: str) -> str:

return CycleId(_get_valid_filename(Cycle.__SEPARATOR.join([Cycle._ID_PREFIX, name, str(uuid.uuid4())])))

def __getattr__(self, attribute_name):
protected_attribute_name = attribute_name
if protected_attribute_name in self._properties:
return self._properties[protected_attribute_name]
raise AttributeError(f"{attribute_name} is not an attribute of cycle {self.id}")

def __eq__(self, other):
return isinstance(other, Cycle) and self.id == other.id

Expand Down
10 changes: 2 additions & 8 deletions taipy/core/data/data_node.py
Original file line number Diff line number Diff line change
Expand Up @@ -134,7 +134,7 @@ class DataNode(_Entity, _Labeled):

_ID_PREFIX = "DATANODE"
__ID_SEPARATOR = "_"
__logger = _TaipyLogger._get_logger()
_logger = _TaipyLogger._get_logger()
_REQUIRED_PROPERTIES: List[str] = []
_MANAGER_NAME: str = "data"
_PATH_KEY = "path"
Expand Down Expand Up @@ -347,12 +347,6 @@ def __getstate__(self):
def __setstate__(self, state):
vars(self).update(state)

def __getattr__(self, attribute_name):
protected_attribute_name = _validate_id(attribute_name)
if protected_attribute_name in self._properties:
return self._properties[protected_attribute_name]
raise AttributeError(f"{attribute_name} is not an attribute of data node {self.id}")

@classmethod
def _get_last_modified_datetime(cls, path: Optional[str] = None) -> Optional[datetime]:
if path and os.path.isfile(path):
Expand Down Expand Up @@ -397,7 +391,7 @@ def read(self) -> Any:
try:
return self.read_or_raise()
except NoData:
self.__logger.warning(
self._logger.warning(
f"Data node {self.id} from config {self.config_id} is being read but has never been written."
)
return None
Expand Down
2 changes: 1 addition & 1 deletion taipy/core/data/parquet.py
Original file line number Diff line number Diff line change
Expand Up @@ -189,7 +189,7 @@ def _read_from_path(self, path: Optional[str] = None, **read_kwargs) -> Any:

# return None if data was never written
if not self.last_edit_date:
self._DataNode__logger.warning(
self._logger.warning(
f"Data node {self.id} from config {self.config_id} is being read but has never been written."
)
return None
Expand Down
7 changes: 7 additions & 0 deletions taipy/core/exceptions/exceptions.py
Original file line number Diff line number Diff line change
Expand Up @@ -383,3 +383,10 @@ class SQLQueryCannotBeExecuted(Exception):

class _SuspiciousFileOperation(Exception):
pass


class AttributeKeyAlreadyExisted(Exception):
"""Raised when an attribute key already existed."""

def __init__(self, key: str):
self.message = f"Attribute key '{key}' already existed."
33 changes: 23 additions & 10 deletions taipy/core/scenario/scenario.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,6 @@

import networkx as nx

from taipy.config.common._template_handler import _TemplateHandler as _tpl
from taipy.config.common._validate_id import _validate_id

from .._entity._entity import _Entity
Expand All @@ -31,6 +30,7 @@
from ..data.data_node import DataNode
from ..data.data_node_id import DataNodeId
from ..exceptions.exceptions import (
AttributeKeyAlreadyExisted,
InvalidSequence,
NonExistingDataNode,
NonExistingSequence,
Expand Down Expand Up @@ -117,6 +117,7 @@ def by_two(x: int):
_SEQUENCE_TASKS_KEY = "tasks"
_SEQUENCE_PROPERTIES_KEY = "properties"
_SEQUENCE_SUBSCRIBERS_KEY = "subscribers"
__CHECK_INIT_DONE_ATTR_NAME = "_init_done"

def __init__(
self,
Expand Down Expand Up @@ -155,6 +156,7 @@ def __init__(
)

self._version = version or _VersionManagerFactory._build_manager()._get_latest_version()
self._init_done = True

@staticmethod
def _new_id(config_id: str) -> ScenarioId:
Expand All @@ -176,20 +178,28 @@ def __hash__(self):
def __eq__(self, other):
return isinstance(other, Scenario) and self.id == other.id

def __getattr__(self, attribute_name):
def __setattr__(self, name: str, value: Any) -> None:
if self.__CHECK_INIT_DONE_ATTR_NAME not in dir(self) or name in dir(self):
return super().__setattr__(name, value)
else:
try:
self.__getattr__(name)
raise AttributeKeyAlreadyExisted(name)
except AttributeError:
return super().__setattr__(name, value)

def __getattr__(self, attribute_name) -> Union[Sequence, Task, DataNode]:
protected_attribute_name = _validate_id(attribute_name)
if protected_attribute_name in self._properties:
return _tpl._replace_templates(self._properties[protected_attribute_name])

sequences = self._get_sequences()
if protected_attribute_name in sequences:
return sequences[protected_attribute_name]
tasks = self.tasks
tasks = self.__get_tasks()
if protected_attribute_name in tasks:
return tasks[protected_attribute_name]
data_nodes = self.data_nodes
data_nodes = self.__get_data_nodes()
if protected_attribute_name in data_nodes:
return data_nodes[protected_attribute_name]

raise AttributeError(f"{attribute_name} is not an attribute of scenario {self.id}")

@property
Expand Down Expand Up @@ -458,14 +468,17 @@ def additional_data_nodes(self, val: Union[Set[TaskId], Set[DataNode]]):
def _get_set_of_tasks(self) -> Set[Task]:
return set(self.tasks.values())

@property # type: ignore
@_self_reload(_MANAGER_NAME)
def data_nodes(self) -> Dict[str, DataNode]:
def __get_data_nodes(self) -> Dict[str, DataNode]:
data_nodes_dict = self.__get_additional_data_nodes()
for _, task in self.__get_tasks().items():
data_nodes_dict.update(task.data_nodes)
return data_nodes_dict

@property # type: ignore
@_self_reload(_MANAGER_NAME)
def data_nodes(self) -> Dict[str, DataNode]:
return self.__get_data_nodes()

@property # type: ignore
@_self_reload(_MANAGER_NAME)
def creation_date(self):
Expand Down
17 changes: 13 additions & 4 deletions taipy/core/sequence/sequence.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,6 @@

import networkx as nx

from taipy.config.common._template_handler import _TemplateHandler as _tpl
from taipy.config.common._validate_id import _validate_id

from .._entity._entity import _Entity
Expand All @@ -27,7 +26,7 @@
from ..common._listattributes import _ListAttributes
from ..common._utils import _Subscriber
from ..data.data_node import DataNode
from ..exceptions.exceptions import NonExistingTask
from ..exceptions.exceptions import AttributeKeyAlreadyExisted, NonExistingTask
from ..job.job import Job
from ..notification.event import Event, EventEntityType, EventOperation, _make_event
from ..submission.submission import Submission
Expand Down Expand Up @@ -126,6 +125,7 @@ def planning(forecast, capacity):
_ID_PREFIX = "SEQUENCE"
_SEPARATOR = "_"
_MANAGER_NAME = "sequence"
__CHECK_INIT_DONE_ATTR_NAME = "_init_done"

def __init__(
self,
Expand All @@ -144,6 +144,7 @@ def __init__(
self._parent_ids = parent_ids or set()
self._properties = _Properties(self, **properties)
self._version = version or _VersionManagerFactory._build_manager()._get_latest_version()
self._init_done = True

@staticmethod
def _new_id(sequence_name: str, scenario_id) -> SequenceId:
Expand All @@ -156,10 +157,18 @@ def __hash__(self):
def __eq__(self, other):
return isinstance(other, Sequence) and self.id == other.id

def __setattr__(self, name: str, value: Any) -> None:
if self.__CHECK_INIT_DONE_ATTR_NAME not in dir(self) or name in dir(self):
return super().__setattr__(name, value)
else:
try:
self.__getattr__(name)
raise AttributeKeyAlreadyExisted(name)
except AttributeError:
return super().__setattr__(name, value)

def __getattr__(self, attribute_name):
protected_attribute_name = _validate_id(attribute_name)
if protected_attribute_name in self._properties:
return _tpl._replace_templates(self._properties[protected_attribute_name])
tasks = self._get_tasks()
if protected_attribute_name in tasks:
return tasks[protected_attribute_name]
Expand Down
16 changes: 13 additions & 3 deletions taipy/core/task/task.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,6 @@
import uuid
from typing import Any, Callable, Dict, Iterable, List, Optional, Set, Union

from taipy.config.common._template_handler import _TemplateHandler as _tpl
from taipy.config.common._validate_id import _validate_id
from taipy.config.common.scope import Scope

Expand All @@ -22,6 +21,7 @@
from .._entity._reload import _Reloader, _self_reload, _self_setter
from .._version._version_manager_factory import _VersionManagerFactory
from ..data.data_node import DataNode
from ..exceptions import AttributeKeyAlreadyExisted
from ..notification.event import Event, EventEntityType, EventOperation, _make_event
from ..submission.submission import Submission
from .task_id import TaskId
Expand Down Expand Up @@ -97,6 +97,7 @@ def by_two(x: int):
_ID_PREFIX = "TASK"
__ID_SEPARATOR = "_"
_MANAGER_NAME = "task"
__CHECK_INIT_DONE_ATTR_NAME = "_init_done"

def __init__(
self,
Expand All @@ -121,6 +122,7 @@ def __init__(
self._version = version or _VersionManagerFactory._build_manager()._get_latest_version()
self._skippable = skippable
self._properties = _Properties(self, **properties)
self._init_done = True

def __hash__(self):
return hash(self.id)
Expand All @@ -134,10 +136,18 @@ def __getstate__(self):
def __setstate__(self, state):
vars(self).update(state)

def __setattr__(self, name: str, value: Any) -> None:
if self.__CHECK_INIT_DONE_ATTR_NAME not in dir(self) or name in dir(self):
return super().__setattr__(name, value)
else:
try:
self.__getattr__(name)
raise AttributeKeyAlreadyExisted(name)
except AttributeError:
return super().__setattr__(name, value)

def __getattr__(self, attribute_name):
protected_attribute_name = _validate_id(attribute_name)
if protected_attribute_name in self._properties:
return _tpl._replace_templates(self._properties[protected_attribute_name])
if protected_attribute_name in self.input:
return self.input[protected_attribute_name]
if protected_attribute_name in self.output:
Expand Down
4 changes: 2 additions & 2 deletions taipy/gui_core/_context.py
Original file line number Diff line number Diff line change
Expand Up @@ -126,8 +126,8 @@ def process_event(self, event: Event):
and is_readable(t.cast(SequenceId, event.entity_id))
else None
)
if sequence and hasattr(sequence, "parent_ids") and sequence.parent_ids:
self.broadcast_core_changed({"scenario": list(sequence.parent_ids)})
if sequence and hasattr(sequence, "parent_ids") and sequence.parent_ids: # type: ignore
self.broadcast_core_changed({"scenario": list(sequence.parent_ids)}) # type: ignore
except Exception as e:
_warn(f"Access to sequence {event.entity_id} failed", e)
elif event.entity_type == EventEntityType.JOB:
Expand Down
9 changes: 5 additions & 4 deletions tests/core/cycle/test_cycle.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
# Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
# an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
# specific language governing permissions and limitations under the License.

import datetime
from datetime import timedelta

Expand Down Expand Up @@ -49,7 +50,7 @@ def test_create_cycle_entity(current_datetime):
assert cycle_1.creation_date == current_datetime
assert cycle_1.start_date == current_datetime
assert cycle_1.end_date == current_datetime
assert cycle_1.key == "value"
assert cycle_1.properties["key"] == "value"
assert cycle_1.frequency == Frequency.DAILY

cycle_2 = Cycle(Frequency.YEARLY, {}, current_datetime, current_datetime, current_datetime)
Expand Down Expand Up @@ -111,13 +112,13 @@ def test_add_property_to_scenario(current_datetime):
name="foo",
)
assert cycle.properties == {"key": "value"}
assert cycle.key == "value"
assert cycle.properties["key"] == "value"

cycle.properties["new_key"] = "new_value"

assert cycle.properties == {"key": "value", "new_key": "new_value"}
assert cycle.key == "value"
assert cycle.new_key == "new_value"
assert cycle.properties["key"] == "value"
assert cycle.properties["new_key"] == "new_value"


def test_auto_set_and_reload(current_datetime):
Expand Down
2 changes: 1 addition & 1 deletion tests/core/cycle/test_cycle_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -91,7 +91,7 @@ def test_create_and_delete_cycle_entity(tmpdir):
assert cycle_1.start_date is not None
assert cycle_1.end_date is not None
assert cycle_1.start_date < cycle_1.creation_date < cycle_1.end_date
assert cycle_1.key == "value"
assert cycle_1.properties["key"] == "value"
assert cycle_1.frequency == Frequency.DAILY

cycle_1_id = cycle_1.id
Expand Down
8 changes: 4 additions & 4 deletions tests/core/data/test_csv_data_node.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,17 +65,17 @@ def test_create(self):
assert dn.job_ids == []
assert not dn.is_ready_for_reading
assert dn.path == default_path
assert dn.has_header is False
assert dn.exposed_type == "pandas"
assert dn.properties["has_header"] is False
assert dn.properties["exposed_type"] == "pandas"

csv_dn_config = Config.configure_csv_data_node(
id="foo", default_path=default_path, has_header=True, exposed_type=MyCustomObject
)
dn = _DataManagerFactory._build_manager()._create_and_set(csv_dn_config, None, None)
assert dn.storage_type() == "csv"
assert dn.config_id == "foo"
assert dn.has_header is True
assert dn.exposed_type == MyCustomObject
assert dn.properties["has_header"] is True
assert dn.properties["exposed_type"] == MyCustomObject

with pytest.raises(InvalidConfigurationId):
CSVDataNode(
Expand Down
4 changes: 2 additions & 2 deletions tests/core/data/test_data_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -303,14 +303,14 @@ def test_create_uses_overridden_attributes_in_config_file(self):
assert csv_dn.config_id == "foo"
assert isinstance(csv_dn, CSVDataNode)
assert csv_dn._path == "path_from_config_file"
assert csv_dn.has_header
assert csv_dn.properties["has_header"]

csv_dn_cfg = Config.configure_data_node(id="baz", storage_type="csv", path="bar", has_header=True)
csv_dn = _DataManager._create_and_set(csv_dn_cfg, None, None)
assert csv_dn.config_id == "baz"
assert isinstance(csv_dn, CSVDataNode)
assert csv_dn._path == "bar"
assert csv_dn.has_header
assert csv_dn.properties["has_header"]

def test_get_if_not_exists(self):
with pytest.raises(ModelNotFound):
Expand Down
1 change: 0 additions & 1 deletion tests/core/data/test_data_node.py
Original file line number Diff line number Diff line change
Expand Up @@ -670,7 +670,6 @@ def test_data_node_with_env_variable_value_not_stored(self):
dn = _DataManager._bulk_get_or_create([dn_config])[dn_config]
assert dn._properties.data["prop"] == "ENV[FOO]"
assert dn.properties["prop"] == "bar"
assert dn.prop == "bar"

def test_path_populated_with_config_default_path(self):
dn_config = Config.configure_data_node("data_node", "pickle", default_path="foo.p")
Expand Down
Loading

0 comments on commit 892e082

Please sign in to comment.