Skip to content

Commit

Permalink
feat: renew session for expiration, re-execute when error
Browse files Browse the repository at this point in the history
  • Loading branch information
wey-gu committed Mar 4, 2024
1 parent 251c253 commit ba27e4c
Show file tree
Hide file tree
Showing 6 changed files with 149 additions and 1 deletion.
17 changes: 17 additions & 0 deletions nebula3/Exception.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,6 +71,23 @@ def __init__(self, message):
Exception.__init__(self, message)
self.message = 'Invalid hostname: {}'.format(message)

class SessionException(Exception):
E_SESSION_INVALID = -1002
E_SESSION_TIMEOUT = -1003

def __init__(self, code=E_SESSION_INVALID, message=None):
Exception.__init__(self, message)
self.type = code
self.message = message

class ExecutionErrorException(Exception):
E_EXECUTION_ERROR = -1005

def __init__(self, message=None):
Exception.__init__(self, message)
self.type = self.E_EXECUTION_ERROR
self.message = message


class IOErrorException(Exception):
E_UNKNOWN = 0
Expand Down
14 changes: 14 additions & 0 deletions nebula3/gclient/net/Connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,8 @@
AuthFailedException,
IOErrorException,
ClientServerIncompatibleException,
SessionException,
ExecutionErrorException,
)

from nebula3.gclient.net.AuthResult import AuthResult
Expand Down Expand Up @@ -146,6 +148,12 @@ def execute_parameter(self, session_id, stmt, params):
"""
try:
resp = self._connection.executeWithParameter(session_id, stmt, params)
if resp.error_code == ErrorCode.E_SESSION_INVALID:
raise SessionException(resp.error_code, resp.error_msg)
if resp.error_code == ErrorCode.E_SESSION_TIMEOUT:
raise SessionException(resp.error_code, resp.error_msg)
if resp.error_code == ErrorCode.E_EXECUTION_ERROR:
raise ExecutionErrorException(resp.error_msg)
return resp
except Exception as te:
if isinstance(te, TTransportException):
Expand Down Expand Up @@ -179,6 +187,12 @@ def execute_json_with_parameter(self, session_id, stmt, params):
"""
try:
resp = self._connection.executeJsonWithParameter(session_id, stmt, params)
if resp.error_code == ErrorCode.E_SESSION_INVALID:
raise SessionException(resp.error_code, resp.error_msg)
if resp.error_code == ErrorCode.E_SESSION_TIMEOUT:
raise SessionException(resp.error_code, resp.error_msg)
if resp.error_code == ErrorCode.E_EXECUTION_ERROR:
raise ExecutionErrorException(resp.error_msg)
return resp
except Exception as te:
if isinstance(te, TTransportException):
Expand Down
33 changes: 32 additions & 1 deletion nebula3/gclient/net/Session.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
from nebula3.Exception import (
IOErrorException,
NotValidConnectionException,
ExecutionErrorException,
)

from nebula3.data.ResultSet import ResultSet
Expand All @@ -18,14 +19,15 @@


class Session(object):
def __init__(self, connection, auth_result: AuthResult, pool, retry_connect=True):
def __init__(self, connection, auth_result: AuthResult, pool, retry_connect=True, retry_times=3):
self._session_id = auth_result.get_session_id()
self._timezone_offset = auth_result.get_timezone_offset()
self._connection = connection
self._timezone = 0
# connection the where the session was created, if session pool was used
self._pool = pool
self._retry_connect = retry_connect
self._retry_times = retry_times
# the time stamp when the session was added to the idle list of the session pool
self._idle_time_start = 0

Expand Down Expand Up @@ -65,6 +67,23 @@ def execute_parameter(self, stmt, params):
timezone_offset=self._timezone_offset,
)
raise
except ExecutionErrorException as eee:
retry_count = 0
while retry_count < self._retry_times:
try:
resp = self._connection.execute_parameter(self._session_id, stmt, params)
end_time = time.time()
return ResultSet(
resp,
all_latency=int((end_time - start_time) * 1000000),
timezone_offset=self._timezone_offset,
)
except ExecutionErrorException:
if retry_count >= self._retry_times - 1:
raise eee
else:
retry_count += 1
continue
except Exception:
raise

Expand Down Expand Up @@ -222,6 +241,18 @@ def execute_json_with_parameter(self, stmt, params):
)
return resp_json
raise
except ExecutionErrorException as eee:
retry_count = 0
while retry_count < self._retry_times:
try:
resp = self._connection.execute_json_with_parameter(self._session_id, stmt, params)
return resp
except ExecutionErrorException:
if retry_count >= self._retry_times - 1:
raise eee
else:
retry_count += 1
continue
except Exception:
raise

Expand Down
18 changes: 18 additions & 0 deletions nebula3/gclient/net/SessionPool.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
AuthFailedException,
NoValidSessionException,
InValidHostname,
SessionException,
)

from nebula3.gclient.net.Session import Session
Expand Down Expand Up @@ -170,6 +171,15 @@ def execute_parameter(self, stmt, params):
self._return_session(session)

return resp
except SessionException as se:
if se.type in [SessionException.E_SESSION_INVALID, SessionException.E_SESSION_TIMEOUT]:
self._active_sessions.remove(session)
session = self._get_idle_session()
if session is None:
raise RuntimeError('Get session failed')
self._add_session_to_active(session)
raise se

except Exception as e:
logger.error('Execute failed: {}'.format(e))
# remove the session from the pool if it is invalid
Expand Down Expand Up @@ -257,6 +267,14 @@ def execute_json_with_parameter(self, stmt, params):
self._return_session(session)

return resp
except SessionException as se:
if se.type in [SessionException.E_SESSION_INVALID, SessionException.E_SESSION_TIMEOUT]:
self._active_sessions.remove(session)
session = self._get_idle_session()
if session is None:
raise RuntimeError('Get session failed')
self._add_session_to_active(session)
raise se
except Exception as e:
logger.error('Execute failed: {}'.format(e))
# remove the session from the pool if it is invalid
Expand Down
40 changes: 40 additions & 0 deletions tests/test_session.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,10 @@

from nebula3.Config import Config
from nebula3.gclient.net import ConnectionPool
from nebula3.Exception import (
SessionException,
ExecutionErrorException,
)


class TestSession(TestCase):
Expand Down Expand Up @@ -89,3 +93,39 @@ def test_4_timeout(self):
except Exception as ex:
assert str(ex).find("timed out") > 0
assert True, ex

def test_5_session_exception(self):
# test SessionException will be raised when session is invalid
try:
session = self.pool.get_session(self.user_name, self.password)
another_session = self.pool.get_session(self.user_name, self.password)
session_id = session.session_id
another_session.execute(f"KILL SESSION {session_id}")
session.execute("SHOW HOSTS")
except Exception as ex:
assert isinstance(ex, SessionException), "expect to get SessionException"

def test_6_execute_exception(self):
# test ExecutionErrorException will be raised when execute error
# we need to mock a query's response code to -1005
import unittest
from unittest.mock import call

try:
session = self.pool.get_session(self.user_name, self.password)
# Mocking a remote call that will trigger an ExecutionErrorException
with unittest.mock.patch(
"nebula3.gclient.net.Connection._connection.executeWithParameter"
) as mock_execute:
mock_execute.return_value.error_code = (
ExecutionErrorException.E_EXECUTION_ERROR
)
session.execute("SHOW HOSTS")
# Assert that executeWithParameter was called 3 times (retry mechanism)
assert (
mock_execute.call_count == 3
), "executeWithParameter was not retried 3 times"
except ExecutionErrorException as ex:
assert True, "ExecutionErrorException triggered as expected"
except Exception as ex:
assert False, "Unexpected exception: " + str(ex)
28 changes: 28 additions & 0 deletions tests/test_session_pool.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
from nebula3.Config import SessionPoolConfig
from nebula3.Exception import (
InValidHostname,
SessionException,
)
from nebula3.gclient.net import Connection
from nebula3.gclient.net.SessionPool import SessionPool
Expand Down Expand Up @@ -143,6 +144,33 @@ def test_switch_space(self):
assert resp.is_succeeded()
assert resp.space_name() == "session_pool_test"

def test_session_renew_when_invalid(self):
# This test is used to test if the session will be renewed when the session is invalid.
session_pool = SessionPool(
"root", "nebula", "session_pool_test", self.addresses
)
configs = SessionPoolConfig()
configs.min_size = 1
configs.max_size = 1
assert session_pool.init(configs)

# kill all sessions of the pool, size 1 here though
for session in session_pool._idle_sessions:
session_id = session.session_id
session.execute(f"KILL SESSION {session_id}")
try:
session_pool.execute("SHOW HOSTS;")
except Exception as ex:
assert isinstance(ex, SessionException), "expect to get SessionException"
# The only session(size=1) should be renewed and usable
# - session_id is not in the pool
# - session_pool is usable
assert (
session_id not in session_pool._idle_sessions
), "session should be renewed"
resp = session_pool.execute("SHOW HOSTS;")
assert resp.is_succeeded(), "session_pool should be usable after renewing"


def test_session_pool_multi_thread():
# prepare space
Expand Down

0 comments on commit ba27e4c

Please sign in to comment.