Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

authorization for submission access #2276

Merged
merged 5 commits into from
Nov 25, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
55 changes: 30 additions & 25 deletions taipy/gui_core/_context.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,6 +70,7 @@
_GuiCoreScenarioProperties,
_invoke_action,
)
from ._utils import _ClientStatus
from .filters import CustomScenarioFilter


Expand All @@ -92,7 +93,7 @@ def __init__(self, gui: Gui) -> None:
self.data_nodes_by_owner: t.Optional[t.Dict[t.Optional[str], t.List[DataNode]]] = None
self.scenario_configs: t.Optional[t.List[t.Tuple[str, str]]] = None
self.jobs_list: t.Optional[t.List[Job]] = None
self.client_submission: t.Dict[str, SubmissionStatus] = {}
self.client_submission: t.Dict[str, _ClientStatus] = {}
# register to taipy core notification
reg_id, reg_queue = Notifier.register()
# locks
Expand Down Expand Up @@ -162,28 +163,32 @@ def scenario_refresh(self, scenario_id: t.Optional[str]):
self.broadcast_core_changed({"scenario": scenario_id or True})

def submission_status_callback(self, submission_id: t.Optional[str] = None, event: t.Optional[Event] = None):
if not submission_id or not is_readable(t.cast(SubmissionId, submission_id)):
if not submission_id:
return
submission = None
new_status = None
payload: t.Optional[t.Dict[str, t.Any]] = None
client_id: t.Optional[str] = None
try:
last_status = self.client_submission.get(submission_id)
if not last_status:
last_client_status = self.client_submission.get(submission_id)
if not last_client_status:
return

submission = t.cast(Submission, core_get(submission_id))
if not submission or not submission.entity_id:
return
client_id = last_client_status.client_id

with self.gui._get_authorization(client_id):
if not is_readable(t.cast(SubmissionId, submission_id)):
return
submission = t.cast(Submission, core_get(submission_id))
if not submission or not submission.entity_id:
return

payload = {}
new_status = t.cast(SubmissionStatus, submission.submission_status)
payload = {}
new_status = t.cast(SubmissionStatus, submission.submission_status)

client_id = submission.properties.get("client_id")
if client_id:
running_tasks = {}
with self.gui._get_authorization(client_id):
if client_id:
running_tasks = {}
# with self.gui._get_authorization(client_id):
for job in submission.jobs:
job = job if isinstance(job, Job) else t.cast(Job, core_get(job))
running_tasks[job.task.id] = (
Expand All @@ -195,7 +200,7 @@ def submission_status_callback(self, submission_id: t.Optional[str] = None, even
)
payload.update(tasks=running_tasks)

if last_status is not new_status:
if last_client_status.submission_status is not new_status:
# callback
submission_name = submission.properties.get("on_submission")
if submission_name:
Expand All @@ -213,15 +218,15 @@ def submission_status_callback(self, submission_id: t.Optional[str] = None, even
submission.properties.get("module_context"),
)

with self.submissions_lock:
if new_status in (
SubmissionStatus.COMPLETED,
SubmissionStatus.FAILED,
SubmissionStatus.CANCELED,
):
if new_status in (
SubmissionStatus.COMPLETED,
SubmissionStatus.FAILED,
SubmissionStatus.CANCELED,
):
with self.submissions_lock:
self.client_submission.pop(submission_id, None)
else:
self.client_submission[submission_id] = new_status
else:
last_client_status.submission_status = new_status

except Exception as e:
_warn(f"Submission ({submission_id}) is not available", e)
Expand Down Expand Up @@ -634,11 +639,11 @@ def submit_entity(self, state: State, id: str, payload: t.Dict[str, str]):
client_id=self.gui._get_client_id(),
module_context=self.gui._get_locals_context(),
)
client_status = _ClientStatus(self.gui._get_client_id(), submission_entity.submission_status)
with self.submissions_lock:
self.client_submission[submission_entity.id] = submission_entity.submission_status
self.client_submission[submission_entity.id] = client_status
if Config.core.mode == "development":
with self.submissions_lock:
self.client_submission[submission_entity.id] = SubmissionStatus.SUBMITTED
client_status.submission_status = SubmissionStatus.SUBMITTED
self.submission_status_callback(submission_entity.id)
_GuiCoreContext.__assign_var(state, error_var, "")
except Exception as e:
Expand Down
20 changes: 20 additions & 0 deletions taipy/gui_core/_utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
# Copyright 2021-2024 Avaiga Private Limited
#
# Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
# the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
# an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
# specific language governing permissions and limitations under the License.
import typing as t
from dataclasses import dataclass

from taipy.core.submission.submission_status import SubmissionStatus


@dataclass
class _ClientStatus:
client_id: t.Optional[str]
submission_status: SubmissionStatus
77 changes: 40 additions & 37 deletions tests/gui_core/test_context_is_readable.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,8 +27,9 @@
from taipy.core.submission._submission_manager_factory import _SubmissionManagerFactory
from taipy.core.submission.submission import Submission, SubmissionStatus
from taipy.core.task._task_manager_factory import _TaskManagerFactory
from taipy.gui import Gui
from taipy.gui import Gui, State
from taipy.gui_core._context import _GuiCoreContext
from taipy.gui_core._utils import _ClientStatus

a_cycle = Cycle(Frequency.DAILY, {}, datetime.now(), datetime.now(), datetime.now(), id=CycleId("CYCLE_id"))
a_scenario = Scenario("scenario_config_id", None, {}, sequences={"sequence": {}})
Expand Down Expand Up @@ -66,9 +67,14 @@ def mock_core_get(entity_id):
return a_task


class MockState:
class MockState(State):
def __init__(self, **kwargs) -> None:
self.assign = kwargs.get("assign")
self.assign = t.cast(t.Callable, kwargs.get("assign")) # type: ignore[method-assign]
self.gui = t.cast(Gui, kwargs.get("gui"))
def get_gui(self):
return self.gui
def broadcast(self, name: str, value: t.Any):
pass


class TestGuiCoreContext_is_readable:
Expand Down Expand Up @@ -96,7 +102,7 @@ def test_scenario_adapter(self):
def test_cycle_adapter(self):
with patch("taipy.gui_core._context.core_get", side_effect=mock_core_get):
gui_core_context = _GuiCoreContext(Mock())
gui_core_context.scenario_by_cycle = {"a": 1}
gui_core_context.scenario_by_cycle = t.cast(dict, {"a": 1})
outcome = gui_core_context.cycle_adapter(a_cycle)
assert isinstance(outcome, list)
assert outcome[0] == a_cycle.id
Expand All @@ -120,9 +126,9 @@ def test_crud_scenario(self):
gui_core_context = _GuiCoreContext(Mock())
assign = Mock()
gui_core_context.crud_scenario(
MockState(assign=assign),
MockState(assign=assign, gui=gui_core_context.gui),
"",
{
t.cast(dict, {
"args": [
"",
"",
Expand All @@ -132,7 +138,7 @@ def test_crud_scenario(self):
{"name": "name", "id": a_scenario.id},
],
"error_id": "error_var",
},
}),
)
assign.assert_not_called()

Expand All @@ -141,7 +147,7 @@ def test_crud_scenario(self):
gui_core_context.crud_scenario(
MockState(assign=assign),
"",
{
t.cast(dict, {
"args": [
"",
"",
Expand All @@ -151,7 +157,7 @@ def test_crud_scenario(self):
{"name": "name", "id": a_scenario.id},
],
"error_id": "error_var",
},
}),
)
assign.assert_called_once()
assert assign.call_args.args[0] == "error_var"
Expand All @@ -164,12 +170,12 @@ def test_edit_entity(self):
gui_core_context.edit_entity(
MockState(assign=assign),
"",
{
t.cast(dict, {
"args": [
{"name": "name", "id": a_scenario.id},
],
"error_id": "error_var",
},
}),
)
assign.assert_called_once()
assert assign.call_args.args[0] == "error_var"
Expand All @@ -180,12 +186,12 @@ def test_edit_entity(self):
gui_core_context.edit_entity(
MockState(assign=assign),
"",
{
t.cast(dict, {
"args": [
{"name": "name", "id": a_scenario.id},
],
"error_id": "error_var",
},
}),
)
assign.assert_called_once()
assert assign.call_args.args[0] == "error_var"
Expand All @@ -198,10 +204,7 @@ def test_submission_status_callback(self):
mockGui._get_authorization = lambda s: contextlib.nullcontext()
gui_core_context = _GuiCoreContext(mockGui)

def sub_cb():
return True

gui_core_context.client_submission[a_submission.id] = SubmissionStatus.UNDEFINED
gui_core_context.client_submission[a_submission.id] = _ClientStatus("client_id", SubmissionStatus.UNDEFINED)
gui_core_context.submission_status_callback(a_submission.id)
mockget.assert_called()
found = False
Expand Down Expand Up @@ -248,12 +251,12 @@ def test_act_on_jobs(self):
gui_core_context.act_on_jobs(
MockState(assign=assign),
"",
{
t.cast(dict, {
"args": [
{"id": [a_job.id], "action": "delete"},
],
"error_id": "error_var",
},
}),
)
assign.assert_called_once()
assert assign.call_args.args[0] == "error_var"
Expand All @@ -263,12 +266,12 @@ def test_act_on_jobs(self):
gui_core_context.act_on_jobs(
MockState(assign=assign),
"",
{
t.cast(dict, {
"args": [
{"id": [a_job.id], "action": "cancel"},
],
"error_id": "error_var",
},
}),
)
assign.assert_called_once()
assert assign.call_args.args[0] == "error_var"
Expand All @@ -279,12 +282,12 @@ def test_act_on_jobs(self):
gui_core_context.act_on_jobs(
MockState(assign=assign),
"",
{
t.cast(dict, {
"args": [
{"id": [a_job.id], "action": "delete"},
],
"error_id": "error_var",
},
}),
)
assign.assert_called_once()
assert assign.call_args.args[0] == "error_var"
Expand All @@ -294,12 +297,12 @@ def test_act_on_jobs(self):
gui_core_context.act_on_jobs(
MockState(assign=assign),
"",
{
t.cast(dict, {
"args": [
{"id": [a_job.id], "action": "cancel"},
],
"error_id": "error_var",
},
}),
)
assign.assert_called_once()
assert assign.call_args.args[0] == "error_var"
Expand All @@ -312,12 +315,12 @@ def test_edit_data_node(self):
gui_core_context.edit_data_node(
MockState(assign=assign),
"",
{
t.cast(dict, {
"args": [
{"id": a_datanode.id},
],
"error_id": "error_var",
},
}),
)
assign.assert_called_once()
assert assign.call_args.args[0] == "error_var"
Expand All @@ -328,12 +331,12 @@ def test_edit_data_node(self):
gui_core_context.edit_data_node(
MockState(assign=assign),
"",
{
t.cast(dict, {
"args": [
{"id": a_datanode.id},
],
"error_id": "error_var",
},
}),
)
assign.assert_called_once()
assert assign.call_args.args[0] == "error_var"
Expand All @@ -348,12 +351,12 @@ def test_lock_datanode_for_edit(self):
gui_core_context.lock_datanode_for_edit(
MockState(assign=assign),
"",
{
t.cast(dict, {
"args": [
{"id": a_datanode.id},
],
"error_id": "error_var",
},
}),
)
assign.assert_called_once()
assert assign.call_args.args[0] == "error_var"
Expand All @@ -364,12 +367,12 @@ def test_lock_datanode_for_edit(self):
gui_core_context.lock_datanode_for_edit(
MockState(assign=assign),
"",
{
t.cast(dict, {
"args": [
{"id": a_datanode.id},
],
"error_id": "error_var",
},
}),
)
assign.assert_called_once()
assert assign.call_args.args[0] == "error_var"
Expand All @@ -395,12 +398,12 @@ def test_update_data(self):
gui_core_context.update_data(
MockState(assign=assign),
"",
{
t.cast(dict, {
"args": [
{"id": a_datanode.id},
],
"error_id": "error_var",
},
}),
)
assign.assert_called()
assert assign.call_args_list[0].args[0] == "error_var"
Expand All @@ -411,12 +414,12 @@ def test_update_data(self):
gui_core_context.update_data(
MockState(assign=assign),
"",
{
t.cast(dict, {
"args": [
{"id": a_datanode.id},
],
"error_id": "error_var",
},
}),
)
assign.assert_called_once()
assert assign.call_args.args[0] == "error_var"
Expand Down
Loading