Skip to content

Commit

Permalink
mock sqlclient
Browse files Browse the repository at this point in the history
  • Loading branch information
zqWu committed Mar 15, 2023
1 parent c038f84 commit 859a779
Show file tree
Hide file tree
Showing 5 changed files with 307 additions and 0 deletions.
79 changes: 79 additions & 0 deletions tests/sqlgateway/mock/gw_handler.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,79 @@
from typing import List

import json


class DummyGwHandle:
session_prefix = "session_"
operation_prefix = "operation_"
session_cnt = 1
operation_cnt = 1
session_handles = set()
operation_handles = {}
statements: List[str] = []

@staticmethod
def next_session_handle():
return f"{DummyGwHandle.session_prefix}{DummyGwHandle.session_cnt}"

@staticmethod
def next_operation_handle():
return f"{DummyGwHandle.operation_prefix}{DummyGwHandle.operation_cnt}"

def __init__(self, config):
pass

@staticmethod
def all_statements() -> List[str]:
return DummyGwHandle.statements

@staticmethod
def clear_statements():
DummyGwHandle.statements = []

@staticmethod
def start():
pass

@staticmethod
def stop():
pass

@staticmethod
def session_create(request, uri, response_headers):
session_handle = DummyGwHandle.next_session_handle()
DummyGwHandle.session_handles.add(session_handle)
body = {"sessionHandle": session_handle}
return [200, response_headers, json.dumps(body)]

@staticmethod
def session_delete(request, uri, response_headers):
raise RuntimeError("not_support")

@staticmethod
def statement_create(request, uri, response_headers):
request_body = json.loads(request.body)
statement = request_body["statement"].strip()
operation_handle = DummyGwHandle.next_operation_handle()
DummyGwHandle.statements.append(statement)
body = {"operationHandle": operation_handle}
return [200, response_headers, json.dumps(body)]

@staticmethod
def operation_status(request, uri, response_headers):
"""always return FINISHED"""
# raise RuntimeError("not_support")
body = {"status": "FINISHED"}
return [200, response_headers, json.dumps(body)]

@staticmethod
def operation_cancel(request, uri, response_headers):
raise RuntimeError("not_support")

@staticmethod
def operation_close(request, uri, response_headers):
raise RuntimeError("not_support")

@staticmethod
def operation_result(request, uri, response_headers):
raise RuntimeError("not_support")
88 changes: 88 additions & 0 deletions tests/sqlgateway/mock/gw_router.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,88 @@
import re

import httpretty
import json

from tests.sqlgateway.mock.gw_handler import DummyGwHandle

mock_session_info = {
"properties": {
"state.checkpoints.num-retained": "5",
"sql-gateway.worker.threads.max": "10",
"jobmanager.execution.failover-strategy": "region",
"other_infos": "you_can_get /session/:session_handle"
}
}


class GwRouter:
""" simple route use httpretty """
config = None
handle: DummyGwHandle

url_api_version: str
url_info: str
url_op_status: str
url_op_cancel: str
url_op_close: str
url_op_result: str
url_op_create: str
url_s_create: str
url_s_get: str
url_s_delete: str
url_s_hb: str

def __init__(self, config):
GwRouter.config = config
GwRouter.handle = DummyGwHandle(config)
# url
v1_endpoint = f"{config.get('host_port')}/v1"
GwRouter.url_api_version = f"{v1_endpoint}/api_version"
GwRouter.url_info = f"{v1_endpoint}/info"
# operation
GwRouter.url_op_status = re.compile(f"{v1_endpoint}/sessions/\\w*/operations/\\w*/status")
GwRouter.url_op_cancel = re.compile(f"{v1_endpoint}/sessions/\\w*/operations/\\w*/cancel")
GwRouter.url_op_close = re.compile(f"{v1_endpoint}/sessions/\\w*/operations/\\w*/close")
GwRouter.url_op_result = re.compile(f"{v1_endpoint}/sessions/\\w*/operations/\\w*/result/\\w*")
GwRouter.url_op_create = re.compile(f"{v1_endpoint}/sessions/\\w*/statements")
# session
GwRouter.url_s_create = f"{v1_endpoint}/sessions"
GwRouter.url_s_get = re.compile(f"{v1_endpoint}/sessions/\\w*")
GwRouter.url_s_delete = re.compile(f"{v1_endpoint}/sessions/\\w*")
GwRouter.url_s_hb = re.compile(f"{v1_endpoint}/sessions/\\w*/heartbeat")

@staticmethod
def start():
httpretty.enable()
# DO NOT change register order, httpretty regex match = first match
# DO NOT change register order, httpretty regex match = first match
# DO NOT change register order, httpretty regex match = first match
httpretty.register_uri('GET', GwRouter.url_api_version, body=json.dumps({"versions": ["V1"]}))
httpretty.register_uri('GET', GwRouter.url_info,
body=json.dumps({"productName": "Apache Flink", "version": "1.16.1"}))

# operation
httpretty.register_uri('GET', GwRouter.url_op_status, body=GwRouter.handle.operation_status)
httpretty.register_uri('POST', GwRouter.url_op_cancel, body=GwRouter.handle.operation_cancel)
httpretty.register_uri('DELETE', GwRouter.url_op_close, body=GwRouter.handle.operation_close)
httpretty.register_uri('GET', GwRouter.url_op_result, body=GwRouter.handle.operation_result)
httpretty.register_uri('POST', GwRouter.url_op_create, body=GwRouter.handle.statement_create)

# session
httpretty.register_uri('POST', GwRouter.url_s_create, body=GwRouter.handle.session_create)
httpretty.register_uri('GET', GwRouter.url_s_get, body=json.dumps(mock_session_info))
httpretty.register_uri('DELETE', GwRouter.url_s_delete, body=GwRouter.handle.session_delete)
httpretty.register_uri('POST', GwRouter.url_s_hb, body=json.dumps({}))

@staticmethod
def all_statements():
return GwRouter.handle.all_statements()

@staticmethod
def clear_statements():
return GwRouter.handle.clear_statements()

@staticmethod
def stop():
httpretty.disable()
httpretty.reset()
54 changes: 54 additions & 0 deletions tests/sqlgateway/mock/mock_client.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,54 @@
from typing import List

from flink.sqlgateway.client import FlinkSqlGatewayClient
from flink.sqlgateway.operation import SqlGatewayOperation
from flink.sqlgateway.session import SqlGatewaySession
from flink.sqlgateway.config import SqlGatewayConfig
from tests.sqlgateway.mock.gw_router import GwRouter
import requests
import json


class MockFlinkSqlGatewayClient(FlinkSqlGatewayClient):
router: GwRouter

@staticmethod
def create_session(host: str, port: int, session_name: str) -> SqlGatewaySession:
host_port = f"http://{host}:{port}"
test_config = {
"host_port": host_port,
"schemas": [
{"catalog": "default_catalog", "database": "default_database", "tables": [], "views": []}
],
"current_catalog": "default_catalog",
"current_database": "default_database",
}
MockFlinkSqlGatewayClient.router = GwRouter(test_config)
MockFlinkSqlGatewayClient.router.start()
# create session
r = requests.post(f"{host_port}/v1/sessions", '{"sessionName" : "hehe"}')
session_handle = r.json()['sessionHandle']
import threading
print(f"thread_id_AA={threading.get_native_id()}")
return SqlGatewaySession(SqlGatewayConfig(host, port, session_name), session_handle)

@staticmethod
def execute_statement(session: SqlGatewaySession, sql: str) -> SqlGatewayOperation:
if session.session_handle is None:
raise Exception(
f"Session '{session.config.session_name}' is not created. Call create() method first"
)
host_port = f"http://{session.config.host}:{session.config.port}"
session_handle = session.session_handle
data = {"statement": sql}
r = requests.post(f"{host_port}/v1/sessions/{session_handle}/statements", json.dumps(data))
operation_handle = r.json()['operationHandle']
return SqlGatewayOperation(session=session, operation_handle=operation_handle)

@staticmethod
def clear_statements(session: SqlGatewaySession):
return MockFlinkSqlGatewayClient.router.clear_statements()

@staticmethod
def all_statements(session: SqlGatewaySession) -> List[str]:
return MockFlinkSqlGatewayClient.router.all_statements()
65 changes: 65 additions & 0 deletions tests/tmp/test_mock_gw.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,65 @@
import json
import unittest

import requests

from tests.sqlgateway.mock.gw_router import GwRouter

test_config = {
"host_port": "http://127.0.0.1:8083",
"schemas": [
{"catalog": "default_catalog", "database": "default_database", "tables": ["a"], "views": []},
{"catalog": "default_catalog", "database": "default_database", "tables": [], "views": ["b"]},
{"catalog": "my_hive", "database": "flink01", "tables": ["t01", "t02", "v01", "v02"], "views": ["v01", "v02"]},
{"catalog": "hive_catalog", "database": "flink04", "tables": ["t11", "t12", ], "views": []},
# {"catalog": "", "database": "", "tables": [], "views": []},
],
"current_catalog": "my_hive",
"current_database": "flink01",
}


class TestTmp(unittest.TestCase):
def test1(self):
router = GwRouter(test_config)
v1_ep = f"{test_config.get('host_port')}/v1"
router.start()

print(f"======= get {v1_ep}/api_version")
r = requests.get(f"{v1_ep}/api_version")
# print(r.status_code)
self.assertEqual(200, r.status_code)
print(r.text)

# session 测试 =========
# session 测试 =========
# session 测试 =========
print(f"======= post {v1_ep}/sessions")
r = requests.post(f"{v1_ep}/sessions", '{"sessionName" : "hehe"}')
# print(r.status_code)
print(r.text)
self.assertEqual(200, r.status_code)
session_handle = r.json()['sessionHandle']

print(f"======= get {v1_ep}/sessions/{session_handle}")
r = requests.get(f"{v1_ep}/sessions/{session_handle}")
# print(r.status_code)
self.assertEqual(200, r.status_code)
print(r.text)

print(f"======= post {v1_ep}/sessions/{session_handle}/statements")
data = {"statement": "select 1 as id", "executionTimeout": 100}
# r = requests.post(url=f"{v1_ep}/sessions/{session_handle}/statements", data=json.dumps(data))
r = requests.post(
url=f"{v1_ep}/sessions/{session_handle}/statements",
data=json.dumps(data),
headers={
"Content-Type": "application/json",
},
)
# print(r.status_code)
print(r.text)
operation_handle = r.json()['operationHandle']
print(router.all_statements())

router.stop()
21 changes: 21 additions & 0 deletions tests/tmp/tmp2.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
import unittest

from dbt.adapters.flink.handler import FlinkCursor
from tests.sqlgateway.mock.mock_client import MockFlinkSqlGatewayClient


class TestTmp(unittest.TestCase):

def test_tmp(self):
session = MockFlinkSqlGatewayClient.create_session(
host="127.0.0.1",
port=8083,
session_name="some_session",
)
cursor = FlinkCursor(session)
sql = "select * /** fetch_max(10) fetch_mode('streaming') fetch_timeout_ms(5000) */ from input2"
cursor.execute(sql)
# check sql received
stats = MockFlinkSqlGatewayClient.all_statements(session)
self.assertTrue("SET 'execution.runtime-mode' = 'batch'", stats[0])
self.assertTrue(sql, stats[1])

0 comments on commit 859a779

Please sign in to comment.