diff --git a/nebula3/Exception.py b/nebula3/Exception.py index cec770dd..e28dc3d1 100644 --- a/nebula3/Exception.py +++ b/nebula3/Exception.py @@ -71,6 +71,23 @@ def __init__(self, message): Exception.__init__(self, message) self.message = 'Invalid hostname: {}'.format(message) +class SessionException(Exception): + E_SESSION_INVALID = -1002 + E_SESSION_TIMEOUT = -1003 + + def __init__(self, code=E_SESSION_INVALID, message=None): + Exception.__init__(self, message) + self.type = code + self.message = message + +class ExecutionErrorException(Exception): + E_EXECUTION_ERROR = -1005 + + def __init__(self, message=None): + Exception.__init__(self, message) + self.type = self.E_EXECUTION_ERROR + self.message = message + class IOErrorException(Exception): E_UNKNOWN = 0 diff --git a/nebula3/gclient/net/Connection.py b/nebula3/gclient/net/Connection.py index fa68cc4a..b160f7fb 100644 --- a/nebula3/gclient/net/Connection.py +++ b/nebula3/gclient/net/Connection.py @@ -25,6 +25,8 @@ AuthFailedException, IOErrorException, ClientServerIncompatibleException, + SessionException, + ExecutionErrorException, ) from nebula3.gclient.net.AuthResult import AuthResult @@ -146,6 +148,12 @@ def execute_parameter(self, session_id, stmt, params): """ try: resp = self._connection.executeWithParameter(session_id, stmt, params) + if resp.error_code == ErrorCode.E_SESSION_INVALID: + raise SessionException(resp.error_code, resp.error_msg) + if resp.error_code == ErrorCode.E_SESSION_TIMEOUT: + raise SessionException(resp.error_code, resp.error_msg) + if resp.error_code == ErrorCode.E_EXECUTION_ERROR: + raise ExecutionErrorException(resp.error_msg) return resp except Exception as te: if isinstance(te, TTransportException): @@ -179,6 +187,12 @@ def execute_json_with_parameter(self, session_id, stmt, params): """ try: resp = self._connection.executeJsonWithParameter(session_id, stmt, params) + if resp.error_code == ErrorCode.E_SESSION_INVALID: + raise SessionException(resp.error_code, resp.error_msg) + if resp.error_code == ErrorCode.E_SESSION_TIMEOUT: + raise SessionException(resp.error_code, resp.error_msg) + if resp.error_code == ErrorCode.E_EXECUTION_ERROR: + raise ExecutionErrorException(resp.error_msg) return resp except Exception as te: if isinstance(te, TTransportException): diff --git a/nebula3/gclient/net/Session.py b/nebula3/gclient/net/Session.py index 30659b05..15bba9f8 100644 --- a/nebula3/gclient/net/Session.py +++ b/nebula3/gclient/net/Session.py @@ -10,6 +10,7 @@ from nebula3.Exception import ( IOErrorException, NotValidConnectionException, + ExecutionErrorException, ) from nebula3.data.ResultSet import ResultSet @@ -18,7 +19,7 @@ class Session(object): - def __init__(self, connection, auth_result: AuthResult, pool, retry_connect=True): + def __init__(self, connection, auth_result: AuthResult, pool, retry_connect=True, retry_times=3): self._session_id = auth_result.get_session_id() self._timezone_offset = auth_result.get_timezone_offset() self._connection = connection @@ -26,6 +27,7 @@ def __init__(self, connection, auth_result: AuthResult, pool, retry_connect=True # connection the where the session was created, if session pool was used self._pool = pool self._retry_connect = retry_connect + self._retry_times = retry_times # the time stamp when the session was added to the idle list of the session pool self._idle_time_start = 0 @@ -65,6 +67,23 @@ def execute_parameter(self, stmt, params): timezone_offset=self._timezone_offset, ) raise + except ExecutionErrorException as eee: + retry_count = 0 + while retry_count < self._retry_times: + try: + resp = self._connection.execute_parameter(self._session_id, stmt, params) + end_time = time.time() + return ResultSet( + resp, + all_latency=int((end_time - start_time) * 1000000), + timezone_offset=self._timezone_offset, + ) + except ExecutionErrorException: + if retry_count >= self._retry_times - 1: + raise eee + else: + retry_count += 1 + continue except Exception: raise @@ -222,6 +241,18 @@ def execute_json_with_parameter(self, stmt, params): ) return resp_json raise + except ExecutionErrorException as eee: + retry_count = 0 + while retry_count < self._retry_times: + try: + resp = self._connection.execute_json_with_parameter(self._session_id, stmt, params) + return resp + except ExecutionErrorException: + if retry_count >= self._retry_times - 1: + raise eee + else: + retry_count += 1 + continue except Exception: raise diff --git a/nebula3/gclient/net/SessionPool.py b/nebula3/gclient/net/SessionPool.py index 658b2599..0fd6ebd8 100644 --- a/nebula3/gclient/net/SessionPool.py +++ b/nebula3/gclient/net/SessionPool.py @@ -15,6 +15,7 @@ AuthFailedException, NoValidSessionException, InValidHostname, + SessionException, ) from nebula3.gclient.net.Session import Session @@ -170,6 +171,15 @@ def execute_parameter(self, stmt, params): self._return_session(session) return resp + except SessionException as se: + if se.type in [SessionException.E_SESSION_INVALID, SessionException.E_SESSION_TIMEOUT]: + self._active_sessions.remove(session) + session = self._get_idle_session() + if session is None: + raise RuntimeError('Get session failed') + self._add_session_to_active(session) + raise se + except Exception as e: logger.error('Execute failed: {}'.format(e)) # remove the session from the pool if it is invalid @@ -257,6 +267,14 @@ def execute_json_with_parameter(self, stmt, params): self._return_session(session) return resp + except SessionException as se: + if se.type in [SessionException.E_SESSION_INVALID, SessionException.E_SESSION_TIMEOUT]: + self._active_sessions.remove(session) + session = self._get_idle_session() + if session is None: + raise RuntimeError('Get session failed') + self._add_session_to_active(session) + raise se except Exception as e: logger.error('Execute failed: {}'.format(e)) # remove the session from the pool if it is invalid diff --git a/tests/test_session.py b/tests/test_session.py index 4bf0942a..fea3a628 100644 --- a/tests/test_session.py +++ b/tests/test_session.py @@ -12,6 +12,10 @@ from nebula3.Config import Config from nebula3.gclient.net import ConnectionPool +from nebula3.Exception import ( + SessionException, + ExecutionErrorException, +) class TestSession(TestCase): @@ -89,3 +93,39 @@ def test_4_timeout(self): except Exception as ex: assert str(ex).find("timed out") > 0 assert True, ex + + def test_5_session_exception(self): + # test SessionException will be raised when session is invalid + try: + session = self.pool.get_session(self.user_name, self.password) + another_session = self.pool.get_session(self.user_name, self.password) + session_id = session.session_id + another_session.execute(f"KILL SESSION {session_id}") + session.execute("SHOW HOSTS") + except Exception as ex: + assert isinstance(ex, SessionException), "expect to get SessionException" + + def test_6_execute_exception(self): + # test ExecutionErrorException will be raised when execute error + # we need to mock a query's response code to -1005 + import unittest + from unittest.mock import call + + try: + session = self.pool.get_session(self.user_name, self.password) + # Mocking a remote call that will trigger an ExecutionErrorException + with unittest.mock.patch( + "nebula3.gclient.net.Connection._connection.executeWithParameter" + ) as mock_execute: + mock_execute.return_value.error_code = ( + ExecutionErrorException.E_EXECUTION_ERROR + ) + session.execute("SHOW HOSTS") + # Assert that executeWithParameter was called 3 times (retry mechanism) + assert ( + mock_execute.call_count == 3 + ), "executeWithParameter was not retried 3 times" + except ExecutionErrorException as ex: + assert True, "ExecutionErrorException triggered as expected" + except Exception as ex: + assert False, "Unexpected exception: " + str(ex) diff --git a/tests/test_session_pool.py b/tests/test_session_pool.py index f508dc3a..9c32aae4 100644 --- a/tests/test_session_pool.py +++ b/tests/test_session_pool.py @@ -15,6 +15,7 @@ from nebula3.Config import SessionPoolConfig from nebula3.Exception import ( InValidHostname, + SessionException, ) from nebula3.gclient.net import Connection from nebula3.gclient.net.SessionPool import SessionPool @@ -143,6 +144,33 @@ def test_switch_space(self): assert resp.is_succeeded() assert resp.space_name() == "session_pool_test" + def test_session_renew_when_invalid(self): + # This test is used to test if the session will be renewed when the session is invalid. + session_pool = SessionPool( + "root", "nebula", "session_pool_test", self.addresses + ) + configs = SessionPoolConfig() + configs.min_size = 1 + configs.max_size = 1 + assert session_pool.init(configs) + + # kill all sessions of the pool, size 1 here though + for session in session_pool._idle_sessions: + session_id = session.session_id + session.execute(f"KILL SESSION {session_id}") + try: + session_pool.execute("SHOW HOSTS;") + except Exception as ex: + assert isinstance(ex, SessionException), "expect to get SessionException" + # The only session(size=1) should be renewed and usable + # - session_id is not in the pool + # - session_pool is usable + assert ( + session_id not in session_pool._idle_sessions + ), "session should be renewed" + resp = session_pool.execute("SHOW HOSTS;") + assert resp.is_succeeded(), "session_pool should be usable after renewing" + def test_session_pool_multi_thread(): # prepare space