-
Notifications
You must be signed in to change notification settings - Fork 11
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Showing
5 changed files
with
307 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,79 @@ | ||
from typing import List | ||
|
||
import json | ||
|
||
|
||
class DummyGwHandle: | ||
session_prefix = "session_" | ||
operation_prefix = "operation_" | ||
session_cnt = 1 | ||
operation_cnt = 1 | ||
session_handles = set() | ||
operation_handles = {} | ||
statements: List[str] = [] | ||
|
||
@staticmethod | ||
def next_session_handle(): | ||
return f"{DummyGwHandle.session_prefix}{DummyGwHandle.session_cnt}" | ||
|
||
@staticmethod | ||
def next_operation_handle(): | ||
return f"{DummyGwHandle.operation_prefix}{DummyGwHandle.operation_cnt}" | ||
|
||
def __init__(self, config): | ||
pass | ||
|
||
@staticmethod | ||
def all_statements() -> List[str]: | ||
return DummyGwHandle.statements | ||
|
||
@staticmethod | ||
def clear_statements(): | ||
DummyGwHandle.statements = [] | ||
|
||
@staticmethod | ||
def start(): | ||
pass | ||
|
||
@staticmethod | ||
def stop(): | ||
pass | ||
|
||
@staticmethod | ||
def session_create(request, uri, response_headers): | ||
session_handle = DummyGwHandle.next_session_handle() | ||
DummyGwHandle.session_handles.add(session_handle) | ||
body = {"sessionHandle": session_handle} | ||
return [200, response_headers, json.dumps(body)] | ||
|
||
@staticmethod | ||
def session_delete(request, uri, response_headers): | ||
raise RuntimeError("not_support") | ||
|
||
@staticmethod | ||
def statement_create(request, uri, response_headers): | ||
request_body = json.loads(request.body) | ||
statement = request_body["statement"].strip() | ||
operation_handle = DummyGwHandle.next_operation_handle() | ||
DummyGwHandle.statements.append(statement) | ||
body = {"operationHandle": operation_handle} | ||
return [200, response_headers, json.dumps(body)] | ||
|
||
@staticmethod | ||
def operation_status(request, uri, response_headers): | ||
"""always return FINISHED""" | ||
# raise RuntimeError("not_support") | ||
body = {"status": "FINISHED"} | ||
return [200, response_headers, json.dumps(body)] | ||
|
||
@staticmethod | ||
def operation_cancel(request, uri, response_headers): | ||
raise RuntimeError("not_support") | ||
|
||
@staticmethod | ||
def operation_close(request, uri, response_headers): | ||
raise RuntimeError("not_support") | ||
|
||
@staticmethod | ||
def operation_result(request, uri, response_headers): | ||
raise RuntimeError("not_support") |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,88 @@ | ||
import re | ||
|
||
import httpretty | ||
import json | ||
|
||
from tests.sqlgateway.mock.gw_handler import DummyGwHandle | ||
|
||
mock_session_info = { | ||
"properties": { | ||
"state.checkpoints.num-retained": "5", | ||
"sql-gateway.worker.threads.max": "10", | ||
"jobmanager.execution.failover-strategy": "region", | ||
"other_infos": "you_can_get /session/:session_handle" | ||
} | ||
} | ||
|
||
|
||
class GwRouter: | ||
""" simple route use httpretty """ | ||
config = None | ||
handle: DummyGwHandle | ||
|
||
url_api_version: str | ||
url_info: str | ||
url_op_status: str | ||
url_op_cancel: str | ||
url_op_close: str | ||
url_op_result: str | ||
url_op_create: str | ||
url_s_create: str | ||
url_s_get: str | ||
url_s_delete: str | ||
url_s_hb: str | ||
|
||
def __init__(self, config): | ||
GwRouter.config = config | ||
GwRouter.handle = DummyGwHandle(config) | ||
# url | ||
v1_endpoint = f"{config.get('host_port')}/v1" | ||
GwRouter.url_api_version = f"{v1_endpoint}/api_version" | ||
GwRouter.url_info = f"{v1_endpoint}/info" | ||
# operation | ||
GwRouter.url_op_status = re.compile(f"{v1_endpoint}/sessions/\\w*/operations/\\w*/status") | ||
GwRouter.url_op_cancel = re.compile(f"{v1_endpoint}/sessions/\\w*/operations/\\w*/cancel") | ||
GwRouter.url_op_close = re.compile(f"{v1_endpoint}/sessions/\\w*/operations/\\w*/close") | ||
GwRouter.url_op_result = re.compile(f"{v1_endpoint}/sessions/\\w*/operations/\\w*/result/\\w*") | ||
GwRouter.url_op_create = re.compile(f"{v1_endpoint}/sessions/\\w*/statements") | ||
# session | ||
GwRouter.url_s_create = f"{v1_endpoint}/sessions" | ||
GwRouter.url_s_get = re.compile(f"{v1_endpoint}/sessions/\\w*") | ||
GwRouter.url_s_delete = re.compile(f"{v1_endpoint}/sessions/\\w*") | ||
GwRouter.url_s_hb = re.compile(f"{v1_endpoint}/sessions/\\w*/heartbeat") | ||
|
||
@staticmethod | ||
def start(): | ||
httpretty.enable() | ||
# DO NOT change register order, httpretty regex match = first match | ||
# DO NOT change register order, httpretty regex match = first match | ||
# DO NOT change register order, httpretty regex match = first match | ||
httpretty.register_uri('GET', GwRouter.url_api_version, body=json.dumps({"versions": ["V1"]})) | ||
httpretty.register_uri('GET', GwRouter.url_info, | ||
body=json.dumps({"productName": "Apache Flink", "version": "1.16.1"})) | ||
|
||
# operation | ||
httpretty.register_uri('GET', GwRouter.url_op_status, body=GwRouter.handle.operation_status) | ||
httpretty.register_uri('POST', GwRouter.url_op_cancel, body=GwRouter.handle.operation_cancel) | ||
httpretty.register_uri('DELETE', GwRouter.url_op_close, body=GwRouter.handle.operation_close) | ||
httpretty.register_uri('GET', GwRouter.url_op_result, body=GwRouter.handle.operation_result) | ||
httpretty.register_uri('POST', GwRouter.url_op_create, body=GwRouter.handle.statement_create) | ||
|
||
# session | ||
httpretty.register_uri('POST', GwRouter.url_s_create, body=GwRouter.handle.session_create) | ||
httpretty.register_uri('GET', GwRouter.url_s_get, body=json.dumps(mock_session_info)) | ||
httpretty.register_uri('DELETE', GwRouter.url_s_delete, body=GwRouter.handle.session_delete) | ||
httpretty.register_uri('POST', GwRouter.url_s_hb, body=json.dumps({})) | ||
|
||
@staticmethod | ||
def all_statements(): | ||
return GwRouter.handle.all_statements() | ||
|
||
@staticmethod | ||
def clear_statements(): | ||
return GwRouter.handle.clear_statements() | ||
|
||
@staticmethod | ||
def stop(): | ||
httpretty.disable() | ||
httpretty.reset() |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,54 @@ | ||
from typing import List | ||
|
||
from flink.sqlgateway.client import FlinkSqlGatewayClient | ||
from flink.sqlgateway.operation import SqlGatewayOperation | ||
from flink.sqlgateway.session import SqlGatewaySession | ||
from flink.sqlgateway.config import SqlGatewayConfig | ||
from tests.sqlgateway.mock.gw_router import GwRouter | ||
import requests | ||
import json | ||
|
||
|
||
class MockFlinkSqlGatewayClient(FlinkSqlGatewayClient): | ||
router: GwRouter | ||
|
||
@staticmethod | ||
def create_session(host: str, port: int, session_name: str) -> SqlGatewaySession: | ||
host_port = f"http://{host}:{port}" | ||
test_config = { | ||
"host_port": host_port, | ||
"schemas": [ | ||
{"catalog": "default_catalog", "database": "default_database", "tables": [], "views": []} | ||
], | ||
"current_catalog": "default_catalog", | ||
"current_database": "default_database", | ||
} | ||
MockFlinkSqlGatewayClient.router = GwRouter(test_config) | ||
MockFlinkSqlGatewayClient.router.start() | ||
# create session | ||
r = requests.post(f"{host_port}/v1/sessions", '{"sessionName" : "hehe"}') | ||
session_handle = r.json()['sessionHandle'] | ||
import threading | ||
print(f"thread_id_AA={threading.get_native_id()}") | ||
return SqlGatewaySession(SqlGatewayConfig(host, port, session_name), session_handle) | ||
|
||
@staticmethod | ||
def execute_statement(session: SqlGatewaySession, sql: str) -> SqlGatewayOperation: | ||
if session.session_handle is None: | ||
raise Exception( | ||
f"Session '{session.config.session_name}' is not created. Call create() method first" | ||
) | ||
host_port = f"http://{session.config.host}:{session.config.port}" | ||
session_handle = session.session_handle | ||
data = {"statement": sql} | ||
r = requests.post(f"{host_port}/v1/sessions/{session_handle}/statements", json.dumps(data)) | ||
operation_handle = r.json()['operationHandle'] | ||
return SqlGatewayOperation(session=session, operation_handle=operation_handle) | ||
|
||
@staticmethod | ||
def clear_statements(session: SqlGatewaySession): | ||
return MockFlinkSqlGatewayClient.router.clear_statements() | ||
|
||
@staticmethod | ||
def all_statements(session: SqlGatewaySession) -> List[str]: | ||
return MockFlinkSqlGatewayClient.router.all_statements() |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,65 @@ | ||
import json | ||
import unittest | ||
|
||
import requests | ||
|
||
from tests.sqlgateway.mock.gw_router import GwRouter | ||
|
||
test_config = { | ||
"host_port": "http://127.0.0.1:8083", | ||
"schemas": [ | ||
{"catalog": "default_catalog", "database": "default_database", "tables": ["a"], "views": []}, | ||
{"catalog": "default_catalog", "database": "default_database", "tables": [], "views": ["b"]}, | ||
{"catalog": "my_hive", "database": "flink01", "tables": ["t01", "t02", "v01", "v02"], "views": ["v01", "v02"]}, | ||
{"catalog": "hive_catalog", "database": "flink04", "tables": ["t11", "t12", ], "views": []}, | ||
# {"catalog": "", "database": "", "tables": [], "views": []}, | ||
], | ||
"current_catalog": "my_hive", | ||
"current_database": "flink01", | ||
} | ||
|
||
|
||
class TestTmp(unittest.TestCase): | ||
def test1(self): | ||
router = GwRouter(test_config) | ||
v1_ep = f"{test_config.get('host_port')}/v1" | ||
router.start() | ||
|
||
print(f"======= get {v1_ep}/api_version") | ||
r = requests.get(f"{v1_ep}/api_version") | ||
# print(r.status_code) | ||
self.assertEqual(200, r.status_code) | ||
print(r.text) | ||
|
||
# session 测试 ========= | ||
# session 测试 ========= | ||
# session 测试 ========= | ||
print(f"======= post {v1_ep}/sessions") | ||
r = requests.post(f"{v1_ep}/sessions", '{"sessionName" : "hehe"}') | ||
# print(r.status_code) | ||
print(r.text) | ||
self.assertEqual(200, r.status_code) | ||
session_handle = r.json()['sessionHandle'] | ||
|
||
print(f"======= get {v1_ep}/sessions/{session_handle}") | ||
r = requests.get(f"{v1_ep}/sessions/{session_handle}") | ||
# print(r.status_code) | ||
self.assertEqual(200, r.status_code) | ||
print(r.text) | ||
|
||
print(f"======= post {v1_ep}/sessions/{session_handle}/statements") | ||
data = {"statement": "select 1 as id", "executionTimeout": 100} | ||
# r = requests.post(url=f"{v1_ep}/sessions/{session_handle}/statements", data=json.dumps(data)) | ||
r = requests.post( | ||
url=f"{v1_ep}/sessions/{session_handle}/statements", | ||
data=json.dumps(data), | ||
headers={ | ||
"Content-Type": "application/json", | ||
}, | ||
) | ||
# print(r.status_code) | ||
print(r.text) | ||
operation_handle = r.json()['operationHandle'] | ||
print(router.all_statements()) | ||
|
||
router.stop() |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,21 @@ | ||
import unittest | ||
|
||
from dbt.adapters.flink.handler import FlinkCursor | ||
from tests.sqlgateway.mock.mock_client import MockFlinkSqlGatewayClient | ||
|
||
|
||
class TestTmp(unittest.TestCase): | ||
|
||
def test_tmp(self): | ||
session = MockFlinkSqlGatewayClient.create_session( | ||
host="127.0.0.1", | ||
port=8083, | ||
session_name="some_session", | ||
) | ||
cursor = FlinkCursor(session) | ||
sql = "select * /** fetch_max(10) fetch_mode('streaming') fetch_timeout_ms(5000) */ from input2" | ||
cursor.execute(sql) | ||
# check sql received | ||
stats = MockFlinkSqlGatewayClient.all_statements(session) | ||
self.assertTrue("SET 'execution.runtime-mode' = 'batch'", stats[0]) | ||
self.assertTrue(sql, stats[1]) |