Skip to content

Commit

Permalink
[FSTORE-1382] fix for async concurrency issues for sql client (logicalclocks#1343) (logicalclocks#1372)
Browse files Browse the repository at this point in the history

* init-working

* add lock changes

* move lock to execute method

* update lock

* pass pool as arg

* loop changes and refactoring

* make new connection pool for each call

* make exec_prep async and add lock

* use context manager and refactor init connection

* use create task

* add semaphore

* remove nest_async and use manual loop

* make nest asyncio only if jupyter

* refactoring and move hostname retrieving to init serving

* minor clean

* minor cleanup

* add unit test

* remove return of pool

* revert locust changes and refactoring

* revert locust changes and refactoring

* fix review comments

* fix review comments

* remove loop method argument

---------

Co-authored-by: Victor Jouffrey <37411285+vatj@users.noreply.github.com>
  • Loading branch information
dhananjay-mk and vatj authored Aug 7, 2024
1 parent 41ce3dc commit fa9acf4
Show file tree
Hide file tree
Showing 8 changed files with 124 additions and 51 deletions.
1 change: 1 addition & 0 deletions locust_benchmark/common/hopsworks_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@ def __init__(self, environment=None):
port=self.hopsworks_config.get("port", 443),
api_key_file=".api_key",
secrets_store="local",
engine="python",
)
self.fs = self.connection.get_feature_store()

Expand Down
1 change: 1 addition & 0 deletions locust_benchmark/common/stop_watch.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ def wrapper(*args, **kwargs):
name=task_name,
response_time=total,
exception=e,
response_length=0,
)
else:
total = int((time.time() - start) * 1000)
Expand Down
5 changes: 3 additions & 2 deletions locust_benchmark/locustfile.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
from locust import HttpUser, User, task, constant, events
from locust.runners import MasterRunner, LocalRunner
from urllib3 import PoolManager
import nest_asyncio


@events.init.add_listener
Expand Down Expand Up @@ -73,10 +74,10 @@ def __init__(self, environment):
def on_start(self):
print("Init user")
self.fv.init_serving(external=self.client.external)
nest_asyncio.apply()

def on_stop(self):
print("Closing user")
self.client.close()

@task
def get_feature_vector(self):
Expand All @@ -101,10 +102,10 @@ def __init__(self, environment):
def on_start(self):
print("Init user")
self.fv.init_serving(external=self.client.external)
nest_asyncio.apply()

def on_stop(self):
print("Closing user")
self.client.close()

@task
def get_feature_vector_batch(self):
Expand Down
2 changes: 1 addition & 1 deletion locust_benchmark/requirements.txt
Original file line number Diff line number Diff line change
@@ -1,3 +1,3 @@
markupsafe==2.0.1
locust==2.17.0
git+https://github.com/logicalclocks/feature-store-api@master#egg=hsfs[python]&subdirectory=python
git+https://github.com/logicalclocks/feature-store-api@master#egg=hsfs[python]&subdirectory=python
5 changes: 0 additions & 5 deletions python/hsfs/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,8 +18,6 @@
import os
import warnings

import nest_asyncio


# Setting polars skip cpu flag to suppress CPU false positive warning messages printed while importing hsfs
os.environ["POLARS_SKIP_CPU_CHECK"] = "1"
Expand Down Expand Up @@ -59,6 +57,3 @@ def get_sdk_info():


__all__ = ["connection", "disable_usage_logging", "get_sdk_info"]
# running async code in jupyter throws "RuntimeError: This event loop is already running"
# with tornado 6. This fixes the issue without downgrade to tornado==4.5.3
nest_asyncio.apply()
119 changes: 82 additions & 37 deletions python/hsfs/core/online_store_sql_engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,9 @@
import re
from typing import Any, Dict, List, Optional, Set, Tuple, Union

from hsfs import feature_view, training_dataset, util
import aiomysql
import aiomysql.utils
from hsfs import feature_view, storage_connector, training_dataset, util
from hsfs.constructor.serving_prepared_statement import ServingPreparedStatement
from hsfs.core import feature_view_api, storage_connector_api, training_dataset_api
from hsfs.serving_key import ServingKey
Expand All @@ -43,6 +45,7 @@ def __init__(
skip_fg_ids: Optional[Set[int]],
external: bool,
serving_keys: Optional[Set[ServingKey]] = None,
connection_options: Optional[Dict[str, Any]] = None,
):
_logger.debug("Initialising Online Store Sql Client")
self._feature_store_id = feature_store_id
Expand All @@ -52,7 +55,7 @@ def __init__(
self._prefix_by_serving_index = None
self._pkname_by_serving_index = None
self._serving_key_by_serving_index: Dict[str, ServingKey] = {}
self._async_pool = None
self._connection_pool = None
self._serving_keys: Set[ServingKey] = set(serving_keys or [])

self._prepared_statements: Dict[str, List[ServingPreparedStatement]] = {}
Expand All @@ -64,6 +67,9 @@ def __init__(
feature_store_id
)
self._storage_connector_api = storage_connector_api.StorageConnectorApi()
self._online_connector = None
self._hostname = None
self._connection_options = None

def fetch_prepared_statements(
self,
Expand Down Expand Up @@ -208,14 +214,22 @@ def init_async_mysql_connection(self, options=None):
"Prepared statements are not initialized. "
"Please call `init_prepared_statement` method first."
)
_logger.debug("Acquiring or starting event loop for async engine.")
loop = asyncio.get_event_loop()
_logger.debug(f"Setting up aiomysql connection, with options : {options}")
loop.run_until_complete(
self._set_aiomysql_connection(
len(self._prepared_statements[self.SINGLE_VECTOR_KEY]), options=options
)
_logger.debug(
"Fetching storage connector for sql connection to Online Feature Store."
)
self._online_connector = self._storage_connector_api.get_online_connector(
self._feature_store_id
)
self._connection_options = options
self._hostname = util.get_host_name() if self._external else None

if util.is_runtime_notebook():
_logger.debug("Running in Jupyter notebook, applying nest_asyncio")
import nest_asyncio

nest_asyncio.apply()
else:
_logger.debug("Running in python script. Not applying nest_asyncio")

def get_single_feature_vector(self, entry: Dict[str, Any]) -> Dict[str, Any]:
"""Retrieve single vector with parallel queries using aiomysql engine."""
Expand Down Expand Up @@ -289,12 +303,11 @@ def _single_vector_result(
_logger.debug(
f"Executing prepared statements for serving vector with entries: {bind_entries}"
)
loop = asyncio.get_event_loop()
loop = self._get_or_create_event_loop()
results_dict = loop.run_until_complete(
self._execute_prep_statements(prepared_statement_execution, bind_entries)
)
_logger.debug(f"Retrieved feature vectors: {results_dict}")

_logger.debug("Constructing serving vector from results")
for key in results_dict:
for row in results_dict[key]:
Expand Down Expand Up @@ -358,7 +371,7 @@ def _batch_vector_results(
f"Executing prepared statements for batch vector with entries: {entry_values}"
)
# run all the prepared statements in parallel using aiomysql engine
loop = asyncio.get_event_loop()
loop = self._get_or_create_event_loop()
parallel_results = loop.run_until_complete(
self._execute_prep_statements(prepared_stmts_to_execute, entry_values)
)
Expand Down Expand Up @@ -406,6 +419,20 @@ def _batch_vector_results(
)
return batch_results, serving_keys_all_fg

def _get_or_create_event_loop(self):
    """Return an event loop that can run the async SQL queries synchronously.

    Tries the thread's current event loop first; if none is available, a new
    loop is created and registered as the current one so later calls reuse it.

    Returns:
        asyncio.AbstractEventLoop: the current or newly created event loop.
    """
    try:
        _logger.debug("Acquiring or starting event loop for async engine.")
        loop = asyncio.get_event_loop()
    except RuntimeError as ex:
        # Any RuntimeError here means there is no usable current loop. The
        # exact message varies across Python versions ("There is no current
        # event loop in thread ...", "no running event loop", ...), so do not
        # match on the text: the previous message check left `loop` unbound
        # on a mismatch, turning the real error into an UnboundLocalError.
        _logger.debug(
            "No existing event loop (%s). Creating new event loop.", ex
        )
        loop = asyncio.new_event_loop()
        asyncio.set_event_loop(loop)
    return loop

def refresh_mysql_connection(self):
_logger.debug("Refreshing MySQL connection.")
try:
Expand Down Expand Up @@ -497,26 +524,22 @@ def get_prepared_statement_labels(
OnlineStoreSqlClient.BATCH_VECTOR_KEY,
]

async def _set_aiomysql_connection(
self, default_min_size: int, options: Optional[Dict[str, Any]] = None
) -> None:
_logger.debug(
"Fetching storage connector for sql connection to Online Feature Store."
)
online_connector = self._storage_connector_api.get_online_connector(
self._feature_store_id
)
_logger.debug(
f"Creating async engine with options: {options} and default min size: {default_min_size}"
)
self._async_pool = await util.create_async_engine(
online_connector, self._external, default_min_size, options=options
async def _get_connection_pool(self, default_min_size: int) -> None:
    """Create the aiomysql connection pool used to query the online store.

    Builds the pool via `util.create_async_engine` from state prepared in
    `init_async_mysql_connection` (online connector, external flag, options,
    hostname) and stores it on `self._connection_pool`; nothing is returned.

    # Arguments
        default_min_size: minimum pool size applied when the connection
            options do not specify one.
    """
    self._connection_pool = await util.create_async_engine(
        self._online_connector,
        self._external,
        default_min_size,
        options=self._connection_options,
        hostname=self._hostname,
    )

async def _query_async_sql(self, stmt, bind_params):
"""Query prepared statement together with bind params using aiomysql connection pool"""
# Get a connection from the pool
async with self._async_pool.acquire() as conn:
if self._connection_pool is None:
await self._get_connection_pool(
len(self._prepared_statements[self.SINGLE_VECTOR_KEY])
)
async with self._connection_pool.acquire() as conn:
# Execute the prepared statement
_logger.debug(
f"Executing prepared statement: {stmt} with bind params: {bind_params}"
Expand Down Expand Up @@ -546,16 +569,22 @@ async def _execute_prep_statements(
if key not in entries:
prepared_statements.pop(key)

tasks = [
asyncio.ensure_future(
self._query_async_sql(prepared_statements[key], entries[key])
)
for key in prepared_statements
]
# Run the queries in parallel using asyncio.gather
results = await asyncio.gather(*tasks)
results_dict = {}
try:
tasks = [
asyncio.create_task(
self._query_async_sql(prepared_statements[key], entries[key]),
name="query_prep_statement_key" + str(key),
)
for key in prepared_statements
]
# Run the queries in parallel using asyncio.gather
results = await asyncio.gather(*tasks)
except asyncio.CancelledError as e:
_logger.error(f"Failed executing prepared statements: {e}")
raise e

# Create a dict of results with the prepared statement index as key
results_dict = {}
for i, key in enumerate(prepared_statements):
results_dict[key] = results[i]

Expand Down Expand Up @@ -677,3 +706,19 @@ def feature_view_api(self) -> feature_view_api.FeatureViewApi:
@property
def storage_connector_api(self) -> storage_connector_api.StorageConnectorApi:
return self._storage_connector_api

@property
def hostname(self) -> str:
    """Hostname used for the SQL connection; set only for external clients, None otherwise."""
    return self._hostname

@property
def connection_options(self) -> Dict[str, Any]:
    """User-provided options forwarded to the aiomysql connection pool (may be None)."""
    return self._connection_options

@property
def online_connector(self) -> storage_connector.StorageConnector:
    """Storage connector describing the Online Feature Store SQL endpoint."""
    return self._online_connector

@property
def connection_pool(self) -> aiomysql.utils._ConnectionContextManager:
    """aiomysql connection pool; created lazily on the first query, None until then."""
    return self._connection_pool
29 changes: 23 additions & 6 deletions python/hsfs/util.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
import itertools
import json
import re
import sys
import threading
import time
from datetime import date, datetime, timezone
Expand Down Expand Up @@ -189,23 +190,32 @@ async def create_async_engine(
external: bool,
default_min_size: int,
options: Optional[Dict[str, Any]] = None,
hostname: Optional[str] = None,
) -> Any:
try:
loop = asyncio.get_running_loop()
except RuntimeError as er:
raise RuntimeError(
"Event loop is not running. Please invoke this co-routine from a running loop or provide an event loop."
) from er

online_options = online_conn.spark_options()
# create a aiomysql connection pool
# read the keys user, password from online_conn as use them while creating the connection pool
url = make_url(online_options["url"].replace("jdbc:", ""))
if external:
hostname = get_host_name()
else:
hostname = url.host
if hostname is None:
if external:
hostname = get_host_name()
else:
hostname = url.host

# create a aiomysql connection pool
pool = await async_create_engine(
host=hostname,
port=3306,
user=online_options["user"],
password=online_options["password"],
db=url.database,
loop=asyncio.get_running_loop(),
loop=loop,
minsize=(
options.get("minsize", default_min_size) if options else default_min_size
),
Expand Down Expand Up @@ -530,6 +540,13 @@ def build_serving_keys_from_prepared_statements(
return serving_keys


def is_runtime_notebook() -> bool:
    """Return True when running inside a Jupyter/IPython kernel.

    Detection relies on the `ipykernel` module having been imported by the
    kernel process; plain Python scripts do not import it.
    """
    # Membership test already yields a bool; the if/else returning
    # True/False explicitly was redundant.
    return "ipykernel" in sys.modules


class NpDatetimeEncoder(json.JSONEncoder):
def default(self, obj):
dtypes = (np.datetime64, np.complexfloating)
Expand Down
13 changes: 13 additions & 0 deletions python/tests/test_util.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
# limitations under the License.
#

import asyncio
from datetime import date, datetime

import pytest
Expand All @@ -22,6 +23,7 @@
from hsfs.client.exceptions import FeatureStoreException
from hsfs.embedding import EmbeddingFeature, EmbeddingIndex
from hsfs.feature import Feature
from mock import patch


class TestUtil:
Expand Down Expand Up @@ -206,3 +208,14 @@ def test_empty_schema(self):
# Call the method with an empty schema
util.validate_embedding_feature_type(embedding_index, schema)
# No exception should be raised

def test_create_async_engine(self, mocker):
    # Verify that util.create_async_engine fails fast with a descriptive
    # RuntimeError when invoked without a running event loop.
    # NOTE(review): the `mocker` fixture is requested but never used here.
    # Test when get_running_loop() raises a RuntimeError
    with patch("asyncio.get_running_loop", side_effect=RuntimeError):
        # mock storage connector
        # NOTE(review): patch.object(...) only constructs a patcher object —
        # it is never started, so this acts purely as an opaque stand-in
        # argument; the coroutine raises before touching the connector.
        online_connector = patch.object(util, "get_online_connector")
        with pytest.raises(
            RuntimeError,
            match="Event loop is not running. Please invoke this co-routine from a running loop or provide an event loop.",
        ):
            asyncio.run(util.create_async_engine(online_connector, True, 1))

0 comments on commit fa9acf4

Please sign in to comment.