From b32a3d3aa11bfedb001285fa308ad5d7d0bf04fc Mon Sep 17 00:00:00 2001 From: Iaroslav Ciupin Date: Mon, 11 Sep 2023 20:44:23 +0300 Subject: [PATCH] Expire flyte_idt cookie at logout (#610) * Expire flyte_idt cookie at logout --- auth/cookie_manager.go | 20 ++--- auth/cookie_manager_test.go | 59 ++++++++++---- boilerplate/flyte/end2end/Makefile | 2 + boilerplate/flyte/end2end/end2end.sh | 9 +-- .../flyte/end2end/functional-test-config.yaml | 2 +- boilerplate/flyte/end2end/run-tests.py | 78 +++++++++---------- 6 files changed, 89 insertions(+), 81 deletions(-) diff --git a/auth/cookie_manager.go b/auth/cookie_manager.go index 8b6eca36b..77914195e 100644 --- a/auth/cookie_manager.go +++ b/auth/cookie_manager.go @@ -175,20 +175,9 @@ func (c CookieManager) SetTokenCookies(ctx context.Context, writer http.Response return nil } -func (c *CookieManager) getLogoutAccessCookie() *http.Cookie { +func (c *CookieManager) getLogoutCookie(name string) *http.Cookie { return &http.Cookie{ - Name: accessTokenCookieName, - Value: "", - Domain: c.domain, - MaxAge: 0, - HttpOnly: true, - Expires: time.Now().Add(-1 * time.Hour), - } -} - -func (c *CookieManager) getLogoutRefreshCookie() *http.Cookie { - return &http.Cookie{ - Name: refreshTokenCookieName, + Name: name, Value: "", Domain: c.domain, MaxAge: 0, @@ -198,8 +187,9 @@ func (c *CookieManager) getLogoutRefreshCookie() *http.Cookie { } func (c CookieManager) DeleteCookies(_ context.Context, writer http.ResponseWriter) { - http.SetCookie(writer, c.getLogoutAccessCookie()) - http.SetCookie(writer, c.getLogoutRefreshCookie()) + http.SetCookie(writer, c.getLogoutCookie(accessTokenCookieName)) + http.SetCookie(writer, c.getLogoutCookie(refreshTokenCookieName)) + http.SetCookie(writer, c.getLogoutCookie(idTokenCookieName)) } func (c CookieManager) getHTTPSameSitePolicy() http.SameSite { diff --git a/auth/cookie_manager_test.go b/auth/cookie_manager_test.go index 5bb11f5c7..a6f8c5631 100644 --- a/auth/cookie_manager_test.go +++ b/auth/cookie_manager_test.go @@ -10,6 +10,8 @@ import ( "testing" "time" + "github.com/flyteorg/flyteidl/gen/pb-go/flyteidl/service" + "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" "golang.org/x/oauth2" @@ -17,7 +19,7 @@ import ( "github.com/flyteorg/flyteadmin/auth/config" ) -func TestCookieManager_SetTokenCookies(t *testing.T) { +func TestCookieManager(t *testing.T) { ctx := context.Background() // These were generated for unit testing only. hashKeyEncoded := "wG4pE1ccdw/pHZ2ml8wrD5VJkOtLPmBpWbKHmezWXktGaFbRoAhXidWs8OpbA3y7N8vyZhz1B1E37+tShWC7gA" //nolint:goconst @@ -61,6 +63,14 @@ func TestCookieManager_SetTokenCookies(t *testing.T) { assert.Equal(t, "flyte_rt", c[2].Name) }) + t.Run("set_token_nil", func(t *testing.T) { + w := httptest.NewRecorder() + + err = manager.SetTokenCookies(ctx, w, nil) + + assert.EqualError(t, err, "[EMPTY_OAUTH_TOKEN] Attempting to set cookies with nil token") + }) + t.Run("set_token_cookies_wrong_key", func(t *testing.T) { wrongKey := base64.RawStdEncoding.EncodeToString(bytes.Repeat([]byte("X"), 75)) wrongManager, err := NewCookieManager(ctx, wrongKey, wrongKey, cookieSetting) @@ -115,31 +125,22 @@ func TestCookieManager_SetTokenCookies(t *testing.T) { assert.EqualError(t, err, "[EMPTY_OAUTH_TOKEN] Error reading existing secure cookie [flyte_idt]. Error: [SECURE_COOKIE_ERROR] Error reading secure cookie flyte_idt, caused by: securecookie: error - caused by: crypto/aes: invalid key size 75") }) - t.Run("logout_access_cookie", func(t *testing.T) { - cookie := manager.getLogoutAccessCookie() - - assert.True(t, time.Now().After(cookie.Expires)) - assert.Equal(t, cookieSetting.Domain, cookie.Domain) - }) - - t.Run("logout_refresh_cookie", func(t *testing.T) { - cookie := manager.getLogoutRefreshCookie() - - assert.True(t, time.Now().After(cookie.Expires)) - assert.Equal(t, cookieSetting.Domain, cookie.Domain) - }) - t.Run("delete_cookies", func(t *testing.T) { w := httptest.NewRecorder() manager.DeleteCookies(ctx, w) cookies := w.Result().Cookies() - require.Equal(t, 2, len(cookies)) + require.Equal(t, 3, len(cookies)) assert.True(t, time.Now().After(cookies[0].Expires)) assert.Equal(t, cookieSetting.Domain, cookies[0].Domain) + assert.Equal(t, accessTokenCookieName, cookies[0].Name) assert.True(t, time.Now().After(cookies[1].Expires)) assert.Equal(t, cookieSetting.Domain, cookies[1].Domain) + assert.Equal(t, refreshTokenCookieName, cookies[1].Name) + assert.True(t, time.Now().After(cookies[1].Expires)) + assert.Equal(t, cookieSetting.Domain, cookies[1].Domain) + assert.Equal(t, idTokenCookieName, cookies[2].Name) }) t.Run("get_http_same_site_policy", func(t *testing.T) { @@ -152,4 +153,30 @@ func TestCookieManager_SetTokenCookies(t *testing.T) { manager.sameSitePolicy = config.SameSiteNoneMode assert.Equal(t, http.SameSiteNoneMode, manager.getHTTPSameSitePolicy()) }) + + t.Run("set_user_info", func(t *testing.T) { + w := httptest.NewRecorder() + info := &service.UserInfoResponse{ + Subject: "sub", + Name: "foo", + } + + err := manager.SetUserInfoCookie(ctx, w, info) + + assert.NoError(t, err) + cookies := w.Result().Cookies() + require.Len(t, cookies, 1) + assert.Equal(t, "flyte_user_info", cookies[0].Name) + }) + + t.Run("set_auth_code", func(t *testing.T) { + w := httptest.NewRecorder() + + err := manager.SetAuthCodeCookie(ctx, w, "foo.com") + + assert.NoError(t, err) + cookies := w.Result().Cookies() + require.Len(t, cookies, 1) + assert.Equal(t, "flyte_auth_code", cookies[0].Name) + }) } diff --git a/boilerplate/flyte/end2end/Makefile b/boilerplate/flyte/end2end/Makefile index 61ee99bc7..98ee63ae7 100644 --- a/boilerplate/flyte/end2end/Makefile +++ b/boilerplate/flyte/end2end/Makefile @@ -4,6 +4,8 @@ # TO OPT OUT OF UPDATES, SEE https://github.com/flyteorg/boilerplate/blob/master/Readme.rst .PHONY: end2end_execute +end2end_execute: export FLYTESNACKS_PRIORITIES ?= P0 +end2end_execute: export FLYTESNACKS_VERSION ?= $(shell curl --silent "https://api.github.com/repos/flyteorg/flytesnacks/releases/latest" | jq -r .tag_name) end2end_execute: ./boilerplate/flyte/end2end/end2end.sh ./boilerplate/flyte/end2end/functional-test-config.yaml --return_non_zero_on_failure diff --git a/boilerplate/flyte/end2end/end2end.sh b/boilerplate/flyte/end2end/end2end.sh index acc9d012e..5dd825c1a 100755 --- a/boilerplate/flyte/end2end/end2end.sh +++ b/boilerplate/flyte/end2end/end2end.sh @@ -4,14 +4,9 @@ # ONLY EDIT THIS FILE FROM WITHIN THE 'FLYTEORG/BOILERPLATE' REPOSITORY: # # TO OPT OUT OF UPDATES, SEE https://github.com/flyteorg/boilerplate/blob/master/Readme.rst -set -e +set -eu CONFIG_FILE=$1; shift EXTRA_FLAGS=( "$@" ) -# By default only execute `core` tests -PRIORITIES="${PRIORITIES:-P0}" - -LATEST_VERSION=$(curl --silent "https://api.github.com/repos/flyteorg/flytesnacks/releases/latest" | jq -r .tag_name) - -python ./boilerplate/flyte/end2end/run-tests.py $LATEST_VERSION $PRIORITIES $CONFIG_FILE ${EXTRA_FLAGS[@]} +python ./boilerplate/flyte/end2end/run-tests.py $FLYTESNACKS_VERSION $FLYTESNACKS_PRIORITIES $CONFIG_FILE ${EXTRA_FLAGS[@]} diff --git a/boilerplate/flyte/end2end/functional-test-config.yaml b/boilerplate/flyte/end2end/functional-test-config.yaml index 6d06b7075..13fc44567 100644 --- a/boilerplate/flyte/end2end/functional-test-config.yaml +++ b/boilerplate/flyte/end2end/functional-test-config.yaml @@ -1,5 +1,5 @@ admin: # For GRPC endpoints you might want to use dns:///flyte.myexample.com - endpoint: localhost:30081 + endpoint: dns:///localhost:30080 authType: Pkce insecure: true diff --git a/boilerplate/flyte/end2end/run-tests.py b/boilerplate/flyte/end2end/run-tests.py index 66c678fd4..eb2b28d8d 100644 --- a/boilerplate/flyte/end2end/run-tests.py +++ b/boilerplate/flyte/end2end/run-tests.py @@ -1,19 +1,19 @@ #!/usr/bin/env python3 -import click import datetime import json import sys import time import traceback +from typing import Dict, List, Mapping, Tuple + +import click import requests -from typing import List, Mapping, Tuple, Dict -from flytekit.remote import FlyteRemote +from flytekit.configuration import Config from flytekit.models.core.execution import WorkflowExecutionPhase -from flytekit.configuration import Config, ImageConfig, SerializationSettings +from flytekit.remote import FlyteRemote from flytekit.remote.executions import FlyteWorkflowExecution - WAIT_TIME = 10 MAX_ATTEMPTS = 200 @@ -22,15 +22,14 @@ # starting with "core". FLYTESNACKS_WORKFLOW_GROUPS: Mapping[str, List[Tuple[str, dict]]] = { "lite": [ - ("basics.hello_world.my_wf", {}), - ("basics.lp.go_greet", {"day_of_week": "5", "number": 3, "am": True}), + ("basics.hello_world.hello_world_wf", {}), ], "core": [ - ("basics.deck.wf", {}), + # ("development_lifecycle.decks.image_renderer_wf", {}), # The chain_workflows example in flytesnacks expects to be running in a sandbox. - # ("control_flow.chain_entities.chain_workflows_wf", {}), - ("control_flow.dynamics.wf", {"s1": "Pear", "s2": "Earth"}), - ("control_flow.map_task.my_map_workflow", {"a": [1, 2, 3, 4, 5]}), + ("advanced_composition.chain_entities.chain_workflows_wf", {}), + ("advanced_composition.dynamics.wf", {"s1": "Pear", "s2": "Earth"}), + ("advanced_composition.map_task.my_map_workflow", {"a": [1, 2, 3, 4, 5]}), # Workflows that use nested executions cannot be launched via flyteremote. # This issue is being tracked in https://github.com/flyteorg/flyte/issues/1482. # ("control_flow.run_conditions.multiplier", {"my_input": 0.5}), @@ -41,24 +40,22 @@ # ("control_flow.run_conditions.nested_conditions", {"my_input": 0.4}), # ("control_flow.run_conditions.consume_outputs", {"my_input": 0.4, "seed": 7}), # ("control_flow.run_merge_sort.merge_sort", {"numbers": [5, 4, 3, 2, 1], "count": 5}), - ("control_flow.subworkflows.parent_wf", {"a": 3}), - ("control_flow.subworkflows.nested_parent_wf", {"a": 3}), - ("basics.basic_workflow.my_wf", {"a": 50, "b": "hello"}), + ("advanced_composition.subworkflows.parent_workflow", {"my_input1": "hello"}), + ("advanced_composition.subworkflows.nested_parent_wf", {"a": 3}), + ("basics.workflow.simple_wf", {"x": [1, 2, 3], "y": [1, 2, 3]}), # TODO: enable new files and folders workflows # ("basics.files.rotate_one_workflow", {"in_image": "https://upload.wikimedia.org/wikipedia/commons/d/d2/Julia_set_%28C_%3D_0.285%2C_0.01%29.jpg"}), # ("basics.folders.download_and_rotate", {}), - ("basics.hello_world.my_wf", {}), - ("basics.lp.my_wf", {"val": 4}), - ("basics.lp.go_greet", {"day_of_week": "5", "number": 3, "am": True}), - ("basics.named_outputs.my_wf", {}), + ("basics.hello_world.hello_world_wf", {}), + ("basics.named_outputs.simple_wf_with_named_outputs", {}), # # Getting a 403 for the wikipedia image # # ("basics.reference_task.wf", {}), - ("type_system.custom_objects.wf", {"x": 10, "y": 20}), + ("data_types_and_io.custom_objects.wf", {"x": 10, "y": 20}), # Enums are not supported in flyteremote # ("type_system.enums.enum_wf", {"c": "red"}), - ("type_system.schema.df_wf", {"a": 42}), - ("type_system.typed_schema.wf", {}), - #("my.imperative.workflow.example", {"in1": "hello", "in2": "foo"}), + ("data_types_and_io.schema.df_wf", {"a": 42}), + ("data_types_and_io.typed_schema.wf", {}), + # ("my.imperative.workflow.example", {"in1": "hello", "in2": "foo"}), ], "integrations-k8s-spark": [ ("k8s_spark_plugin.pyspark_pi.my_spark", {"triggered_date": datetime.datetime.now()}), @@ -97,19 +94,22 @@ def execute_workflow(remote, version, workflow_name, inputs): wf = remote.fetch_workflow(name=workflow_name, version=version) return remote.execute(wf, inputs=inputs, wait=False) + def executions_finished(executions_by_wfgroup: Dict[str, List[FlyteWorkflowExecution]]) -> bool: for executions in executions_by_wfgroup.values(): if not all([execution.is_done for execution in executions]): return False return True + def sync_executions(remote: FlyteRemote, executions_by_wfgroup: Dict[str, List[FlyteWorkflowExecution]]): try: for executions in executions_by_wfgroup.values(): for execution in executions: print(f"About to sync execution_id={execution.id.name}") remote.sync(execution) - except: + except Exception: + print(traceback.format_exc()) print("GOT TO THE EXCEPT") print("COUNT THIS!") @@ -119,6 +119,7 @@ def report_executions(executions_by_wfgroup: Dict[str, List[FlyteWorkflowExecuti for execution in executions: print(execution) + def schedule_workflow_groups( tag: str, workflow_groups: List[str], @@ -139,17 +140,12 @@ def schedule_workflow_groups( # Wait for all executions to finish attempt = 0 - while attempt == 0 or ( - not executions_finished(executions_by_wfgroup) and attempt < MAX_ATTEMPTS - ): + while attempt == 0 or (not executions_finished(executions_by_wfgroup) and attempt < MAX_ATTEMPTS): attempt += 1 - print( - f"Not all executions finished yet. Sleeping for some time, will check again in {WAIT_TIME}s" - ) + print(f"Not all executions finished yet. Sleeping for some time, will check again in {WAIT_TIME}s") time.sleep(WAIT_TIME) sync_executions(remote, executions_by_wfgroup) - report_executions(executions_by_wfgroup) results = {} @@ -192,14 +188,17 @@ def run( # For a given release tag and priority, this function filters the workflow groups from the flytesnacks # manifest file. For example, for the release tag "v0.2.224" and the priority "P0" it returns [ "core" ]. - manifest_url = "https://raw.githubusercontent.com/flyteorg/flytesnacks/" \ - f"{flytesnacks_release_tag}/flyte_tests_manifest.json" + manifest_url = ( + "https://raw.githubusercontent.com/flyteorg/flytesnacks/" f"{flytesnacks_release_tag}/flyte_tests_manifest.json" + ) r = requests.get(manifest_url) parsed_manifest = r.json() workflow_groups = [] - workflow_groups = ["lite"] if "lite" in priorities else [ - group["name"] for group in parsed_manifest if group["priority"] in priorities - ] + workflow_groups = ( + ["lite"] + if "lite" in priorities + else [group["name"] for group in parsed_manifest if group["priority"] in priorities] + ) results = [] valid_workgroups = [] @@ -216,10 +215,7 @@ def run( valid_workgroups.append(workflow_group) results_by_wfgroup = schedule_workflow_groups( - flytesnacks_release_tag, - valid_workgroups, - remote, - terminate_workflow_on_failure + flytesnacks_release_tag, valid_workgroups, remote, terminate_workflow_on_failure ) for workflow_group, succeeded in results_by_wfgroup.items(): @@ -273,9 +269,7 @@ def cli( terminate_workflow_on_failure, ): print(f"return_non_zero_on_failure={return_non_zero_on_failure}") - results = run( - flytesnacks_release_tag, priorities, config_file, terminate_workflow_on_failure - ) + results = run(flytesnacks_release_tag, priorities, config_file, terminate_workflow_on_failure) # Write a json object in its own line describing the result of this run to stdout print(f"Result of run:\n{json.dumps(results)}")