Skip to content

Commit

Permalink
authorization for submission access (#2276)
Browse files Browse the repository at this point in the history
* authorization for submission access
resolves taipy-enterprise#562

* lint

* fix test

* fix test

---------

Co-authored-by: Fred Lefévère-Laoide <Fred.Lefevere-Laoide@Taipy.io>
  • Loading branch information
FredLL-Avaiga and Fred Lefévère-Laoide authored Nov 25, 2024
1 parent e97ed37 commit b1894c8
Show file tree
Hide file tree
Showing 3 changed files with 90 additions and 62 deletions.
55 changes: 30 additions & 25 deletions taipy/gui_core/_context.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,6 +70,7 @@
_GuiCoreScenarioProperties,
_invoke_action,
)
from ._utils import _ClientStatus
from .filters import CustomScenarioFilter


Expand All @@ -92,7 +93,7 @@ def __init__(self, gui: Gui) -> None:
self.data_nodes_by_owner: t.Optional[t.Dict[t.Optional[str], t.List[DataNode]]] = None
self.scenario_configs: t.Optional[t.List[t.Tuple[str, str]]] = None
self.jobs_list: t.Optional[t.List[Job]] = None
self.client_submission: t.Dict[str, SubmissionStatus] = {}
self.client_submission: t.Dict[str, _ClientStatus] = {}
# register to taipy core notification
reg_id, reg_queue = Notifier.register()
# locks
Expand Down Expand Up @@ -162,28 +163,32 @@ def scenario_refresh(self, scenario_id: t.Optional[str]):
self.broadcast_core_changed({"scenario": scenario_id or True})

def submission_status_callback(self, submission_id: t.Optional[str] = None, event: t.Optional[Event] = None):
if not submission_id or not is_readable(t.cast(SubmissionId, submission_id)):
if not submission_id:
return
submission = None
new_status = None
payload: t.Optional[t.Dict[str, t.Any]] = None
client_id: t.Optional[str] = None
try:
last_status = self.client_submission.get(submission_id)
if not last_status:
last_client_status = self.client_submission.get(submission_id)
if not last_client_status:
return

submission = t.cast(Submission, core_get(submission_id))
if not submission or not submission.entity_id:
return
client_id = last_client_status.client_id

with self.gui._get_authorization(client_id):
if not is_readable(t.cast(SubmissionId, submission_id)):
return
submission = t.cast(Submission, core_get(submission_id))
if not submission or not submission.entity_id:
return

payload = {}
new_status = t.cast(SubmissionStatus, submission.submission_status)
payload = {}
new_status = t.cast(SubmissionStatus, submission.submission_status)

client_id = submission.properties.get("client_id")
if client_id:
running_tasks = {}
with self.gui._get_authorization(client_id):
if client_id:
running_tasks = {}
# with self.gui._get_authorization(client_id):
for job in submission.jobs:
job = job if isinstance(job, Job) else t.cast(Job, core_get(job))
running_tasks[job.task.id] = (
Expand All @@ -195,7 +200,7 @@ def submission_status_callback(self, submission_id: t.Optional[str] = None, even
)
payload.update(tasks=running_tasks)

if last_status is not new_status:
if last_client_status.submission_status is not new_status:
# callback
submission_name = submission.properties.get("on_submission")
if submission_name:
Expand All @@ -213,15 +218,15 @@ def submission_status_callback(self, submission_id: t.Optional[str] = None, even
submission.properties.get("module_context"),
)

with self.submissions_lock:
if new_status in (
SubmissionStatus.COMPLETED,
SubmissionStatus.FAILED,
SubmissionStatus.CANCELED,
):
if new_status in (
SubmissionStatus.COMPLETED,
SubmissionStatus.FAILED,
SubmissionStatus.CANCELED,
):
with self.submissions_lock:
self.client_submission.pop(submission_id, None)
else:
self.client_submission[submission_id] = new_status
else:
last_client_status.submission_status = new_status

except Exception as e:
_warn(f"Submission ({submission_id}) is not available", e)
Expand Down Expand Up @@ -634,11 +639,11 @@ def submit_entity(self, state: State, id: str, payload: t.Dict[str, str]):
client_id=self.gui._get_client_id(),
module_context=self.gui._get_locals_context(),
)
client_status = _ClientStatus(self.gui._get_client_id(), submission_entity.submission_status)
with self.submissions_lock:
self.client_submission[submission_entity.id] = submission_entity.submission_status
self.client_submission[submission_entity.id] = client_status
if Config.core.mode == "development":
with self.submissions_lock:
self.client_submission[submission_entity.id] = SubmissionStatus.SUBMITTED
client_status.submission_status = SubmissionStatus.SUBMITTED
self.submission_status_callback(submission_entity.id)
_GuiCoreContext.__assign_var(state, error_var, "")
except Exception as e:
Expand Down
20 changes: 20 additions & 0 deletions taipy/gui_core/_utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
# Copyright 2021-2024 Avaiga Private Limited
#
# Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
# the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
# an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
# specific language governing permissions and limitations under the License.
import typing as t
from dataclasses import dataclass

from taipy.core.submission.submission_status import SubmissionStatus


@dataclass
class _ClientStatus:
client_id: t.Optional[str]
submission_status: SubmissionStatus
77 changes: 40 additions & 37 deletions tests/gui_core/test_context_is_readable.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,8 +27,9 @@
from taipy.core.submission._submission_manager_factory import _SubmissionManagerFactory
from taipy.core.submission.submission import Submission, SubmissionStatus
from taipy.core.task._task_manager_factory import _TaskManagerFactory
from taipy.gui import Gui
from taipy.gui import Gui, State
from taipy.gui_core._context import _GuiCoreContext
from taipy.gui_core._utils import _ClientStatus

a_cycle = Cycle(Frequency.DAILY, {}, datetime.now(), datetime.now(), datetime.now(), id=CycleId("CYCLE_id"))
a_scenario = Scenario("scenario_config_id", None, {}, sequences={"sequence": {}})
Expand Down Expand Up @@ -66,9 +67,14 @@ def mock_core_get(entity_id):
return a_task


class MockState:
class MockState(State):
def __init__(self, **kwargs) -> None:
self.assign = kwargs.get("assign")
self.assign = t.cast(t.Callable, kwargs.get("assign")) # type: ignore[method-assign]
self.gui = t.cast(Gui, kwargs.get("gui"))
def get_gui(self):
return self.gui
def broadcast(self, name: str, value: t.Any):
pass


class TestGuiCoreContext_is_readable:
Expand Down Expand Up @@ -96,7 +102,7 @@ def test_scenario_adapter(self):
def test_cycle_adapter(self):
with patch("taipy.gui_core._context.core_get", side_effect=mock_core_get):
gui_core_context = _GuiCoreContext(Mock())
gui_core_context.scenario_by_cycle = {"a": 1}
gui_core_context.scenario_by_cycle = t.cast(dict, {"a": 1})
outcome = gui_core_context.cycle_adapter(a_cycle)
assert isinstance(outcome, list)
assert outcome[0] == a_cycle.id
Expand All @@ -120,9 +126,9 @@ def test_crud_scenario(self):
gui_core_context = _GuiCoreContext(Mock())
assign = Mock()
gui_core_context.crud_scenario(
MockState(assign=assign),
MockState(assign=assign, gui=gui_core_context.gui),
"",
{
t.cast(dict, {
"args": [
"",
"",
Expand All @@ -132,7 +138,7 @@ def test_crud_scenario(self):
{"name": "name", "id": a_scenario.id},
],
"error_id": "error_var",
},
}),
)
assign.assert_not_called()

Expand All @@ -141,7 +147,7 @@ def test_crud_scenario(self):
gui_core_context.crud_scenario(
MockState(assign=assign),
"",
{
t.cast(dict, {
"args": [
"",
"",
Expand All @@ -151,7 +157,7 @@ def test_crud_scenario(self):
{"name": "name", "id": a_scenario.id},
],
"error_id": "error_var",
},
}),
)
assign.assert_called_once()
assert assign.call_args.args[0] == "error_var"
Expand All @@ -164,12 +170,12 @@ def test_edit_entity(self):
gui_core_context.edit_entity(
MockState(assign=assign),
"",
{
t.cast(dict, {
"args": [
{"name": "name", "id": a_scenario.id},
],
"error_id": "error_var",
},
}),
)
assign.assert_called_once()
assert assign.call_args.args[0] == "error_var"
Expand All @@ -180,12 +186,12 @@ def test_edit_entity(self):
gui_core_context.edit_entity(
MockState(assign=assign),
"",
{
t.cast(dict, {
"args": [
{"name": "name", "id": a_scenario.id},
],
"error_id": "error_var",
},
}),
)
assign.assert_called_once()
assert assign.call_args.args[0] == "error_var"
Expand All @@ -198,10 +204,7 @@ def test_submission_status_callback(self):
mockGui._get_authorization = lambda s: contextlib.nullcontext()
gui_core_context = _GuiCoreContext(mockGui)

def sub_cb():
return True

gui_core_context.client_submission[a_submission.id] = SubmissionStatus.UNDEFINED
gui_core_context.client_submission[a_submission.id] = _ClientStatus("client_id", SubmissionStatus.UNDEFINED)
gui_core_context.submission_status_callback(a_submission.id)
mockget.assert_called()
found = False
Expand Down Expand Up @@ -248,12 +251,12 @@ def test_act_on_jobs(self):
gui_core_context.act_on_jobs(
MockState(assign=assign),
"",
{
t.cast(dict, {
"args": [
{"id": [a_job.id], "action": "delete"},
],
"error_id": "error_var",
},
}),
)
assign.assert_called_once()
assert assign.call_args.args[0] == "error_var"
Expand All @@ -263,12 +266,12 @@ def test_act_on_jobs(self):
gui_core_context.act_on_jobs(
MockState(assign=assign),
"",
{
t.cast(dict, {
"args": [
{"id": [a_job.id], "action": "cancel"},
],
"error_id": "error_var",
},
}),
)
assign.assert_called_once()
assert assign.call_args.args[0] == "error_var"
Expand All @@ -279,12 +282,12 @@ def test_act_on_jobs(self):
gui_core_context.act_on_jobs(
MockState(assign=assign),
"",
{
t.cast(dict, {
"args": [
{"id": [a_job.id], "action": "delete"},
],
"error_id": "error_var",
},
}),
)
assign.assert_called_once()
assert assign.call_args.args[0] == "error_var"
Expand All @@ -294,12 +297,12 @@ def test_act_on_jobs(self):
gui_core_context.act_on_jobs(
MockState(assign=assign),
"",
{
t.cast(dict, {
"args": [
{"id": [a_job.id], "action": "cancel"},
],
"error_id": "error_var",
},
}),
)
assign.assert_called_once()
assert assign.call_args.args[0] == "error_var"
Expand All @@ -312,12 +315,12 @@ def test_edit_data_node(self):
gui_core_context.edit_data_node(
MockState(assign=assign),
"",
{
t.cast(dict, {
"args": [
{"id": a_datanode.id},
],
"error_id": "error_var",
},
}),
)
assign.assert_called_once()
assert assign.call_args.args[0] == "error_var"
Expand All @@ -328,12 +331,12 @@ def test_edit_data_node(self):
gui_core_context.edit_data_node(
MockState(assign=assign),
"",
{
t.cast(dict, {
"args": [
{"id": a_datanode.id},
],
"error_id": "error_var",
},
}),
)
assign.assert_called_once()
assert assign.call_args.args[0] == "error_var"
Expand All @@ -348,12 +351,12 @@ def test_lock_datanode_for_edit(self):
gui_core_context.lock_datanode_for_edit(
MockState(assign=assign),
"",
{
t.cast(dict, {
"args": [
{"id": a_datanode.id},
],
"error_id": "error_var",
},
}),
)
assign.assert_called_once()
assert assign.call_args.args[0] == "error_var"
Expand All @@ -364,12 +367,12 @@ def test_lock_datanode_for_edit(self):
gui_core_context.lock_datanode_for_edit(
MockState(assign=assign),
"",
{
t.cast(dict, {
"args": [
{"id": a_datanode.id},
],
"error_id": "error_var",
},
}),
)
assign.assert_called_once()
assert assign.call_args.args[0] == "error_var"
Expand All @@ -395,12 +398,12 @@ def test_update_data(self):
gui_core_context.update_data(
MockState(assign=assign),
"",
{
t.cast(dict, {
"args": [
{"id": a_datanode.id},
],
"error_id": "error_var",
},
}),
)
assign.assert_called()
assert assign.call_args_list[0].args[0] == "error_var"
Expand All @@ -411,12 +414,12 @@ def test_update_data(self):
gui_core_context.update_data(
MockState(assign=assign),
"",
{
t.cast(dict, {
"args": [
{"id": a_datanode.id},
],
"error_id": "error_var",
},
}),
)
assign.assert_called_once()
assert assign.call_args.args[0] == "error_var"
Expand Down

0 comments on commit b1894c8

Please sign in to comment.