diff --git a/DESCRIPTION.md b/DESCRIPTION.md index 062784b52..183f29552 100644 --- a/DESCRIPTION.md +++ b/DESCRIPTION.md @@ -10,10 +10,11 @@ Source code is also available at: https://github.com/snowflakedb/snowflake-conne - v3.2.1(September 26,2023) - - Fixed a bug where url port and path were ignore in private link oscp retry. + - Fixed a bug where url port and path were ignored in private link oscp retry. - Added thread safety in telemetry when instantiating multiple connections concurrently. - Bumped platformdirs dependency from >=2.6.0,<3.9.0 to >=2.6.0,<4.0.0.0 and made necessary changes to allow this. - Removed the deprecation warning from the vendored urllib3 about urllib3.contrib.pyopenssl deprecation. + - Improved robustness in handling authentication response. - v3.2.0(September 06,2023) diff --git a/src/snowflake/connector/auth/_auth.py b/src/snowflake/connector/auth/_auth.py index 700cc30bc..70e78ed1a 100644 --- a/src/snowflake/connector/auth/_auth.py +++ b/src/snowflake/connector/auth/_auth.py @@ -282,11 +282,11 @@ def authenticate( ) # waiting for MFA authentication - if ret["data"].get("nextAction") in ( + if ret["data"] and ret["data"].get("nextAction") in ( "EXT_AUTHN_DUO_ALL", "EXT_AUTHN_DUO_PUSH_N_PASSCODE", ): - body["inFlightCtx"] = ret["data"]["inFlightCtx"] + body["inFlightCtx"] = ret["data"].get("inFlightCtx") body["data"]["EXT_AUTHN_DUO_METHOD"] = "push" self.ret = {"message": "Timeout", "data": {}} @@ -310,9 +310,13 @@ def post_request_wrapper(self, url, headers, body) -> None: t.join(timeout=timeout) ret = self.ret - if ret and ret["data"].get("nextAction") == "EXT_AUTHN_SUCCESS": + if ( + ret + and ret["data"] + and ret["data"].get("nextAction") == "EXT_AUTHN_SUCCESS" + ): body = copy.deepcopy(body_template) - body["inFlightCtx"] = ret["data"]["inFlightCtx"] + body["inFlightCtx"] = ret["data"].get("inFlightCtx") # final request to get tokens ret = self._rest._post_request( url, @@ -321,7 +325,7 @@ def post_request_wrapper(self, url, headers, body) -> None: timeout=self._rest._connection.login_timeout, socket_timeout=self._rest._connection.login_timeout, ) - elif not ret or not ret["data"].get("token"): + elif not ret or not ret["data"] or not ret["data"].get("token"): # not token is returned. Error.errorhandler_wrapper( self._rest._connection, @@ -343,10 +347,10 @@ def post_request_wrapper(self, url, headers, body) -> None: ) return session_parameters # required for unit test - elif ret["data"].get("nextAction") == "PWD_CHANGE": + elif ret["data"] and ret["data"].get("nextAction") == "PWD_CHANGE": if callable(password_callback): body = copy.deepcopy(body_template) - body["inFlightCtx"] = ret["data"]["inFlightCtx"] + body["inFlightCtx"] = ret["data"].get("inFlightCtx") body["data"]["LOGIN_NAME"] = user body["data"]["PASSWORD"] = ( auth_instance.password @@ -411,23 +415,41 @@ def post_request_wrapper(self, url, headers, body) -> None: ) else: logger.debug( - "token = %s", "******" if ret["data"]["token"] is not None else "NULL" + "token = %s", + "******" + if ret["data"] and ret["data"].get("token") is not None + else "NULL", ) logger.debug( "master_token = %s", - "******" if ret["data"]["masterToken"] is not None else "NULL", + "******" + if ret["data"] and ret["data"].get("masterToken") is not None + else "NULL", ) logger.debug( "id_token = %s", - "******" if ret["data"].get("idToken") is not None else "NULL", + "******" + if ret["data"] and ret["data"].get("idToken") is not None + else "NULL", ) logger.debug( "mfa_token = %s", - "******" if ret["data"].get("mfaToken") is not None else "NULL", + "******" + if ret["data"] and ret["data"].get("mfaToken") is not None + else "NULL", ) + if not ret["data"]: + Error.errorhandler_wrapper( + None, + None, + Error, + { + "msg": "There is no data in the returning response, please retry the operation." + }, + ) self._rest.update_tokens( - ret["data"]["token"], - ret["data"]["masterToken"], + ret["data"].get("token"), + ret["data"].get("masterToken"), master_validity_in_seconds=ret["data"].get("masterValidityInSeconds"), id_token=ret["data"].get("idToken"), mfa_token=ret["data"].get("mfaToken"), @@ -435,17 +457,17 @@ def post_request_wrapper(self, url, headers, body) -> None: self.write_temporary_credentials( self._rest._host, user, session_parameters, ret ) - if "sessionId" in ret["data"]: - self._rest._connection._session_id = ret["data"]["sessionId"] - if "sessionInfo" in ret["data"]: - session_info = ret["data"]["sessionInfo"] + if ret["data"] and "sessionId" in ret["data"]: + self._rest._connection._session_id = ret["data"].get("sessionId") + if ret["data"] and "sessionInfo" in ret["data"]: + session_info = ret["data"].get("sessionInfo") self._rest._connection._database = session_info.get("databaseName") self._rest._connection._schema = session_info.get("schemaName") self._rest._connection._warehouse = session_info.get("warehouseName") self._rest._connection._role = session_info.get("roleName") - if "parameters" in ret["data"]: + if ret["data"] and "parameters" in ret["data"]: session_parameters.update( - {p["name"]: p["value"] for p in ret["data"]["parameters"]} + {p["name"]: p["value"] for p in ret["data"].get("parameters")} ) self._rest._connection._update_parameters(session_parameters) return session_parameters diff --git a/test/unit/test_auth.py b/test/unit/test_auth.py index 49694e768..2d1832d1e 100644 --- a/test/unit/test_auth.py +++ b/test/unit/test_auth.py @@ -11,6 +11,7 @@ import pytest +import snowflake.connector.errors from snowflake.connector.constants import OCSPMode from snowflake.connector.description import CLIENT_NAME, CLIENT_VERSION from snowflake.connector.network import SnowflakeRestful @@ -102,7 +103,12 @@ def _mock_auth_mfa_rest_response_failure(url, headers, body, **kwargs): "inFlightCtx": "inFlightCtx", }, } - + elif mock_cnt == 2: + ret = { + "success": True, + "message": None, + "data": None, + } mock_cnt += 1 return ret @@ -126,6 +132,12 @@ def _mock_auth_mfa_rest_response_timeout(url, headers, body, **kwargs): elif mock_cnt == 1: time.sleep(10) # should timeout while here ret = {} + elif mock_cnt == 2: + ret = { + "success": True, + "message": None, + "data": None, + } mock_cnt += 1 return ret @@ -168,6 +180,14 @@ def test_auth_mfa(next_action: str): auth.authenticate(auth_instance, account, user, timeout=1) assert rest._connection.errorhandler.called # error + # ret["data"] is none + with pytest.raises(snowflake.connector.errors.Error): + mock_cnt = 2 + rest = _init_rest(application, _mock_auth_mfa_rest_response_timeout) + auth = Auth(rest) + auth_instance = AuthByDefault(password) + auth.authenticate(auth_instance, account, user) + def _mock_auth_password_change_rest_response(url, headers, body, **kwargs): """Test successful case."""