From b498572d254eeaad062f8317f6ac1b95a22609da Mon Sep 17 00:00:00 2001 From: qizhicheng Date: Tue, 8 Aug 2023 19:27:46 +0800 Subject: [PATCH 1/5] add cassandra draft MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit add cassandra integration test 重构 engine 的代码, 改为动态加载 --- .github/workflows/django.yml | 4 +- archery/settings.py | 33 ++- .../test_instance_test_archery.html | 8 +- requirements.txt | 5 +- sql/data_dictionary.py | 24 +- sql/engines/__init__.py | 80 +++--- sql/engines/cassandra.py | 236 ++++++++++++++++++ sql/engines/clickhouse.py | 9 +- sql/engines/goinception.py | 8 +- sql/engines/mongo.py | 31 ++- sql/engines/mssql.py | 8 +- sql/engines/mysql.py | 8 +- sql/engines/odps.py | 8 +- sql/engines/oracle.py | 13 +- sql/engines/pgsql.py | 8 +- sql/engines/phoenix.py | 3 + sql/engines/redis.py | 8 +- sql/engines/test_cassandra.py | 144 +++++++++++ sql/instance.py | 2 +- sql/models.py | 1 + sql/templates/instance.html | 12 +- sql/templates/queryapplylist.html | 14 +- sql/templates/sqlquery.html | 114 +++------ sql/templates/sqlsubmit.html | 13 +- sql/tests.py | 8 +- sql/views.py | 11 +- src/docker/Dockerfile | 2 +- 27 files changed, 562 insertions(+), 253 deletions(-) create mode 100644 sql/engines/cassandra.py create mode 100644 sql/engines/test_cassandra.py diff --git a/.github/workflows/django.yml b/.github/workflows/django.yml index b6bbd4d8c4..0722acc2d8 100644 --- a/.github/workflows/django.yml +++ b/.github/workflows/django.yml @@ -17,7 +17,7 @@ jobs: strategy: max-parallel: 4 matrix: - python-version: [3.8, 3.9, "3.10"] + python-version: ["3.8", "3.9", "3.10", "3.11"] # https://github.com/actions/example-services/tree/master/.github/workflows services: @@ -70,7 +70,7 @@ jobs: - name: Install Dependencies run: | - sudo apt-get update && sudo apt-get install libsasl2-dev libldap2-dev libssl-dev unixodbc unixodbc-dev + sudo apt-get update && sudo apt-get install libsasl2-dev libkrb5-dev libldap2-dev libssl-dev unixodbc unixodbc-dev python -m pip install --upgrade pip pip install codecov coverage flake8 -r requirements.txt diff --git a/archery/settings.py b/archery/settings.py index 0a01726696..7e69957027 100644 --- a/archery/settings.py +++ b/archery/settings.py @@ -19,7 +19,7 @@ environ.Env.read_env(os.path.join(BASE_DIR, ".env")) env = environ.Env( DEBUG=(bool, False), - ALLOWED_HOSTS=(List[str], ["*"]), + ALLOWED_HOSTS=(list, ["*"]), SECRET_KEY=(str, "hfusaf2m4ot#7)fkw#di2bu6(cv0@opwmafx5n#6=3d%x^hpl6"), DATABASE_URL=(str, "mysql://root:@127.0.0.1:3306/archery"), CACHE_URL=(str, "redis://127.0.0.1:6379/0"), @@ -38,6 +38,21 @@ Q_CLUISTER_SYNC=(bool, False), # qcluster 同步模式, debug 时可以调整为 True # CSRF_TRUSTED_ORIGINS=subdomain.example.com,subdomain.example2.com subdomain.example.com CSRF_TRUSTED_ORIGINS=(list, []), + ENABLED_ENGINES=( + list, + [ + "mysql", + "clickhouse", + "goinception", + "mssql", + "redis", + "pqsql", + "oracle", + "mongo", + "phoenix", + "odps", + ], + ), ) # SECURITY WARNING: keep the secret key used in production secret! 
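Note on the `ENABLED_ENGINES` default added above: it is read through django-environ, whose `list` cast splits a comma-separated environment value. A minimal sketch of overriding it from a `.env` file, assuming the file is loaded with `environ.Env.read_env()` as `archery/settings.py` already does (the `.env` value shown is illustrative, not part of this patch):

```python
import environ

# Hypothetical .env line, for illustration only:
#   ENABLED_ENGINES=mysql,cassandra,redis
env = environ.Env(ENABLED_ENGINES=(list, ["mysql"]))

# With that line loaded via environ.Env.read_env(), the comma-separated value
# is cast to a Python list; if the variable is unset, the default above is used.
enabled_engines = env("ENABLED_ENGINES")  # -> ["mysql", "cassandra", "redis"]
```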
@@ -57,6 +72,21 @@ # 请求限制 DATA_UPLOAD_MAX_MEMORY_SIZE = 15728640 +AVAILABLE_ENGINES = { + "mysql": {"path": "sql.engines.mysql:MysqlEngine"}, + "cassandra": {"path": "sql.engines.cassandra:CassandraEngine"}, + "clickhouse": {"path": "sql.engines.clickhouse:ClickHouseEngine"}, + "goinception": {"path": "sql.engines.goinception:GoInceptionEngine"}, + "mssql": {"path": "sql.engines.mssql:MssqlEngine"}, + "redis": {"path": "sql.engines.redis:RedisEngine"}, + "pqsql": {"path": "sql.engines.pgsql:PgSQLEngine"}, + "oracle": {"path": "sql.engines.oracle:OracleEngine"}, + "mongo": {"path": "sql.engines.mongo:MongoEngine"}, + "phoenix": {"path": "sql.engines.phoenix:PhoenixEngine"}, + "odps": {"path": "sql.engines.odps:ODPSEngine"}, +} +ENABLED_ENGINES = env("ENABLED_ENGINES") + # Application definition INSTALLED_APPS = ( "django.contrib.admin", @@ -245,7 +275,6 @@ ENABLE_OIDC = env("ENABLE_OIDC", False) if ENABLE_OIDC: INSTALLED_APPS += ("mozilla_django_oidc",) - MIDDLEWARE += ("mozilla_django_oidc.middleware.SessionRefresh",) AUTHENTICATION_BACKENDS = ( "common.authenticate.oidc_auth.OIDCAuthenticationBackend", "django.contrib.auth.backends.ModelBackend", diff --git a/downloads/dictionary/test_instance_test_archery.html b/downloads/dictionary/test_instance_test_archery.html index 900ff1f413..30a5966fb5 100644 --- a/downloads/dictionary/test_instance_test_archery.html +++ b/downloads/dictionary/test_instance_test_archery.html @@ -1,8 +1,8 @@ - 数据库表结构说明文档 + ݿṹ˵ĵ -

test_archery 数据字典 (共 0 个表)

-

生成时间:2023-01-31 14:41:33

+

test_archery ֵ ( 0 )

+

ʱ䣺2023-08-14 18:44:43

diff --git a/requirements.txt b/requirements.txt index cd54d81312..0857a3d03b 100644 --- a/requirements.txt +++ b/requirements.txt @@ -18,7 +18,7 @@ pyecharts==1.9.1 aliyun-python-sdk-rds==2.1.1 cx-Oracle==7.3.0 supervisor==4.1.0 -phoenixdb==0.7 +phoenixdb==1.2.1 django-mirage-field==1.4.0 schema-sync==0.9.7 parsedatetime==2.4 @@ -38,4 +38,5 @@ django-environ==0.8.1 alibabacloud_dysmsapi20170525==2.0.9 tencentcloud-sdk-python==3.0.656 mozilla-django-oidc==3.0.0 -django-auth-dingding==0.0.2 \ No newline at end of file +django-auth-dingding==0.0.2 +cassandra-driver diff --git a/sql/data_dictionary.py b/sql/data_dictionary.py index 4a702d8075..942023ac94 100644 --- a/sql/data_dictionary.py +++ b/sql/data_dictionary.py @@ -90,6 +90,16 @@ def table_info(request): ) +def get_export_full_path(base_dir: str, instance_name: str, db_name: str) -> str: + """validate if the instance_name and db_name provided is secure""" + fullpath = os.path.normpath( + os.path.join(base_dir, f"{instance_name}_{db_name}.html") + ) + if not fullpath.startswith(base_dir): + return "" + return fullpath + + @permission_required("sql.data_dictionary_export", raise_exception=True) def export(request): """导出数据字典""" @@ -111,7 +121,7 @@ def export(request): elif request.user.is_superuser: dbs = query_engine.get_all_databases().rows else: - return JsonResponse({"status": 1, "msg": f"仅管理员可以导出整个实例的字典信息!", "data": []}) + return JsonResponse({"status": 1, "msg": "仅管理员可以导出整个实例的字典信息!", "data": []}) # 获取数据,存入目录 path = os.path.join(settings.BASE_DIR, "downloads/dictionary") @@ -126,12 +136,18 @@ def export(request): data = loader.render_to_string( template_name="dictionaryexport.html", context=context, request=request ) - with open(f"{path}/{instance_name}_{db}.html", "w") as f: - f.write(data) + fullpath = get_export_full_path(path, instance_name, db) + if not fullpath: + return JsonResponse({"status": 1, "msg": "实例名或db名不合法", "data": []}) + with open(fullpath, "w", encoding="utf-8") as fp: + fp.write(data) # 关闭连接 query_engine.close() if db_name: - response = FileResponse(open(f"{path}/{instance_name}_{db_name}.html", "rb")) + fullpath = get_export_full_path(path, instance_name, db) + if not fullpath: + return JsonResponse({"status": 1, "msg": "实例名或db名不合法", "data": []}) + response = FileResponse(open(fullpath, "rb")) response["Content-Type"] = "application/octet-stream" response[ "Content-Disposition" diff --git a/sql/engines/__init__.py b/sql/engines/__init__.py index 508a8dd9b1..c3c6f7ddd9 100644 --- a/sql/engines/__init__.py +++ b/sql/engines/__init__.py @@ -1,6 +1,8 @@ """engine base库, 包含一个``EngineBase`` class和一个get_engine函数""" +import importlib from sql.engines.models import ResultSet, ReviewSet from sql.utils.ssh_tunnel import SSHConnection +from django.conf import settings class EngineBase: @@ -8,6 +10,9 @@ class EngineBase: test_query = None + name = "Base" + info = "base engine" + def __init__(self, instance=None): self.conn = None self.thread_id = None @@ -77,16 +82,6 @@ def test_connection(self): """测试实例链接是否正常""" return self.query(sql=self.test_query) - @property - def name(self): - """返回engine名称""" - return "base" - - @property - def info(self): - """返回引擎简介""" - return "Base engine" - def escape_string(self, value: str) -> str: """参数转义""" return value @@ -179,7 +174,7 @@ def query( limit_num=0, close_conn=True, parameters=None, - **kwargs + **kwargs, ): """实际查询 返回一个ResultSet""" return ResultSet() @@ -213,6 +208,22 @@ def set_variable(self, variable_name, variable_value): return ResultSet() +def get_engine_map(): + 
available_engines = settings.AVAILABLE_ENGINES + enabled_engines = {} + for e in settings.ENABLED_ENGINES: + config = available_engines.get(e) + if not config: + raise ValueError(f"invalid engine {e}, not found in engine map") + module, o = config["path"].split(":") + engine = getattr(importlib.import_module(module), o) + enabled_engines[e] = engine + return enabled_engines + + +engine_map = get_engine_map() + + def get_engine(instance=None): # pragma: no cover """获取数据库操作engine""" if instance.db_type == "mysql": @@ -222,44 +233,9 @@ def get_engine(instance=None): # pragma: no cover from .cloud.aliyun_rds import AliyunRDS return AliyunRDS(instance=instance) - from .mysql import MysqlEngine - - return MysqlEngine(instance=instance) - elif instance.db_type == "mssql": - from .mssql import MssqlEngine - - return MssqlEngine(instance=instance) - elif instance.db_type == "redis": - from .redis import RedisEngine - - return RedisEngine(instance=instance) - elif instance.db_type == "pgsql": - from .pgsql import PgSQLEngine - - return PgSQLEngine(instance=instance) - elif instance.db_type == "oracle": - from .oracle import OracleEngine - - return OracleEngine(instance=instance) - elif instance.db_type == "mongo": - from .mongo import MongoEngine - - return MongoEngine(instance=instance) - elif instance.db_type == "goinception": - from .goinception import GoInceptionEngine - - return GoInceptionEngine(instance=instance) - elif instance.db_type == "phoenix": - from .phoenix import PhoenixEngine - - return PhoenixEngine(instance=instance) - - elif instance.db_type == "odps": - from .odps import ODPSEngine - - return ODPSEngine(instance=instance) - - elif instance.db_type == "clickhouse": - from .clickhouse import ClickHouseEngine - - return ClickHouseEngine(instance=instance) + engine = engine_map.get(instance.db_type) + if not engine: + raise ValueError( + f"engine {instance.db_type} not enabled or not supported, please contact admin" + ) + return engine(instance=instance) diff --git a/sql/engines/cassandra.py b/sql/engines/cassandra.py new file mode 100644 index 0000000000..03b913bab9 --- /dev/null +++ b/sql/engines/cassandra.py @@ -0,0 +1,236 @@ +import re + +import logging +import traceback + +from cassandra.cluster import Cluster +from cassandra.auth import PlainTextAuthProvider +from cassandra.query import tuple_factory + +import sqlparse + +from . 
import EngineBase +from .models import ResultSet, ReviewSet, ReviewResult + +from sql.models import SqlWorkflow + +logger = logging.getLogger("default") + + +def row_to_set(row) -> set: + pass + + +def split_sql(db_name=None, sql=""): + # 切分语句,追加到检测结果中,默认全部检测通过 + sql = sql.split("\n") + sql = filter(None, sql) + sql_result = [] + if db_name: + sql_result += [f"""USE {db_name}"""] + sql_result += sql + return sql_result + + +class CassandraEngine(EngineBase): + name = "Cassandra" + info = "Cassandra engine" + + def get_connection(self, db_name=None): + db_name = db_name or self.db_name + if self.conn: + if db_name: + self.conn.execute(f"use {db_name}") + return self.conn + auth_provider = PlainTextAuthProvider( + username=self.user, password=self.password + ) + cluster = Cluster([self.host], port=self.port, auth_provider=auth_provider) + self.conn = cluster.connect(keyspace=db_name) + self.conn.row_factory = tuple_factory + return self.conn + + def close(self): + if self.conn: + self.conn.shutdown() + self.conn = None + + def test_connection(self): + result = self.get_all_databases() + self.close() + return result + + def escape_string(self, value: str) -> str: + return re.sub(r"[; ]", "", value) + + def get_all_databases(self, **kwargs): + """ + 获取所有的 keyspace/database + :return: + """ + result = self.query(sql="SELECT keyspace_name FROM system_schema.keyspaces;") + result.rows = [x[0] for x in result.rows] + return result + + def get_all_columns_by_tb(self, db_name, tb_name, **kwargs): + """获取所有列, 返回一个ResultSet""" + sql = "select column_name, type from columns where keyspace_name=%s and table_name=%s" + result = self.query( + db_name="system_schema", sql=sql, parameters=(db_name, tb_name) + ) + return result + + def describe_table(self, db_name, tb_name, **kwargs): + sql = f"describe table {tb_name}" + result = self.query(db_name=db_name, sql=sql) + result.column_list = ["table", "create table"] + filtered_rows = [] + for r in result.rows: + filtered_rows.append((r[2], r[3])) + result.rows = filtered_rows + return result + + def query_check(self, db_name=None, sql="", limit_num=0): + """提交查询前的检查""" + # 查询语句的检查、注释去除、切分 + result = {"msg": "", "bad_query": False, "filtered_sql": sql, "has_star": False} + # 删除注释语句,进行语法判断,执行第一条有效sql + try: + sql = sqlparse.format(sql, strip_comments=True) + sql = sqlparse.split(sql)[0] + result["filtered_sql"] = sql.strip() + except IndexError: + result["bad_query"] = True + result["msg"] = "没有有效的SQL语句" + if re.match(r"^select|^describe", sql, re.I) is None: + result["bad_query"] = True + result["msg"] = "不支持的查询语法类型!" 
+ if "*" in sql: + result["has_star"] = True + result["msg"] = "SQL语句中含有 * " + return result + + def query( + self, + db_name=None, + sql="", + limit_num=0, + close_conn=True, + parameters=None, + **kwargs, + ): + """返回 ResultSet""" + result_set = ResultSet(full_sql=sql) + try: + conn = self.get_connection(db_name=db_name) + rows = conn.execute(sql, parameters=parameters) + result_set.column_list = rows.column_names + result_set.rows = rows.all() + result_set.affected_rows = len(result_set.rows) + if limit_num > 0: + result_set.rows = result_set.rows[0:limit_num] + result_set.affected_rows = min(limit_num, result_set.affected_rows) + except Exception as e: + logger.warning( + f"{self.name} query 错误,语句:{sql}, 错误信息:{traceback.format_exc()}" + ) + result_set.error = str(e) + return result_set + + def get_all_tables(self, db_name, **kwargs): + sql = "SELECT table_name FROM system_schema.tables WHERE keyspace_name = %s;" + parameters = [db_name] + result = self.query(db_name=db_name, sql=sql, parameters=parameters) + tb_list = [row[0] for row in result.rows] + result.rows = tb_list + return result + + def filter_sql(self, sql="", limit_num=0): + return sql.strip() + + def query_masking(self, db_name=None, sql="", resultset=None): + """不做脱敏""" + return resultset + + def execute_check(self, db_name=None, sql=""): + """上线单执行前的检查, 返回Review set""" + check_result = ReviewSet(full_sql=sql) + # 切分语句,追加到检测结果中,默认全部检测通过 + sql_result = split_sql(db_name, sql) + rowid = 1 + for statement in sql_result: + check_result.rows.append( + ReviewResult( + id=rowid, + errlevel=0, + stagestatus="Audit completed", + errormessage="None", + sql=statement, + affected_rows=0, + execute_time=0, + ) + ) + rowid += 1 + return check_result + + def execute(self, db_name=None, sql="", close_conn=True, parameters=None): + """执行sql语句 返回 Review set""" + execute_result = ReviewSet(full_sql=sql) + conn = self.get_connection(db_name=db_name) + sql_result = split_sql(db_name, sql) + rowid = 1 + for statement in sql_result: + try: + conn.execute(statement) + execute_result.rows.append( + ReviewResult( + id=rowid, + errlevel=0, + stagestatus="Execute Successfully", + errormessage="None", + sql=statement, + affected_rows=0, + execute_time=0, + ) + ) + except Exception as e: + logger.warning( + f"{self.name} 命令执行报错,语句:{sql}, 错误信息:{traceback.format_exc()}" + ) + execute_result.error = str(e) + execute_result.rows.append( + ReviewResult( + id=rowid, + errlevel=2, + stagestatus="Execute Failed", + errormessage=f"异常信息:{e}", + sql=statement, + affected_rows=0, + execute_time=0, + ) + ) + break + rowid += 1 + if execute_result.error: + for statement in split_sql[rowid:]: + execute_result.rows.append( + ReviewResult( + id=rowid, + errlevel=2, + stagestatus="Execute Failed", + errormessage=f"前序语句失败, 未执行", + sql=statement, + affected_rows=0, + execute_time=0, + ) + ) + rowid += 1 + if close_conn: + self.close() + return execute_result + + def execute_workflow(self, workflow: SqlWorkflow): + """执行上线单,返回Review set""" + return self.execute( + db_name=workflow.db_name, sql=workflow.sqlworkflowcontent.sql_content + ) diff --git a/sql/engines/clickhouse.py b/sql/engines/clickhouse.py index 33dedfdd31..69dfd78772 100644 --- a/sql/engines/clickhouse.py +++ b/sql/engines/clickhouse.py @@ -42,13 +42,8 @@ def get_connection(self, db_name=None): ) return self.conn - @property - def name(self): - return "ClickHouse" - - @property - def info(self): - return "ClickHouse engine" + name = "ClickHouse" + info = "ClickHouse engine" def escape_string(self, value: 
str) -> str: """字符串参数转义""" diff --git a/sql/engines/goinception.py b/sql/engines/goinception.py index d291cccf28..551c0eea47 100644 --- a/sql/engines/goinception.py +++ b/sql/engines/goinception.py @@ -18,13 +18,9 @@ class GoInceptionEngine(EngineBase): test_query = "INCEPTION GET VARIABLES" - @property - def name(self): - return "GoInception" + name = "GoInception" - @property - def info(self): - return "GoInception engine" + info = "GoInception engine" def get_connection(self, db_name=None): if self.conn: diff --git a/sql/engines/mongo.py b/sql/engines/mongo.py index 7a7c2aa844..764d084a5a 100644 --- a/sql/engines/mongo.py +++ b/sql/engines/mongo.py @@ -3,7 +3,6 @@ import pymongo import logging import traceback -import json import subprocess import simplejson as json import datetime @@ -800,13 +799,9 @@ def close(self): self.conn.close() self.conn = None - @property - def name(self): # pragma: no cover - return "Mongo" + name = "Mongo" - @property - def info(self): # pragma: no cover - return "Mongo engine" + info = "Mongo engine" def get_roles(self): sql_get_roles = "db.system.roles.find({},{_id:1})" @@ -1266,18 +1261,20 @@ def kill_op(self, opids): result = ResultSet() try: conn = self.get_connection() - db = conn.admin - for opid in opids: - conn.admin.command({"killOp": 1, "op": opid}) except Exception as e: - try: - sql = {"killOp": 1, "op": _opid} - except: - sql = {"killOp": 1, "op": ""} - logger.warning( - f"mongodb语句执行killOp报错,语句:db.runCommand({sql}) ,错误信息{traceback.format_exc()}" - ) + logger.error(f"{self.name} 连接失败, error: {str(e)}") result.error = str(e) + return result + for opid in opids: + try: + conn.admin.command({"killOp": 1, "op": opid}) + except Exception as e: + sql = {"killOp": 1, "op": opid} + logger.warning( + f"{self.name}语句执行killOp报错,语句:db.runCommand({sql}) ,错误信息{traceback.format_exc()}" + ) + result.error = str(e) + return result def get_all_databases_summary(self): """实例数据库管理功能,获取实例所有的数据库描述信息""" diff --git a/sql/engines/mssql.py b/sql/engines/mssql.py index 9a56a0dadf..716a600917 100644 --- a/sql/engines/mssql.py +++ b/sql/engines/mssql.py @@ -31,13 +31,9 @@ def get_connection(self, db_name=None): self.conn = pyodbc.connect(connstr) return self.conn - @property - def name(self): - return "MsSQL" + name = "MsSQL" - @property - def info(self): - return "MsSQL engine" + info = "MsSQL engine" def get_all_databases(self): """获取数据库列表, 返回一个ResultSet""" diff --git a/sql/engines/mysql.py b/sql/engines/mysql.py index 05252cc81a..7ab5cfe18c 100644 --- a/sql/engines/mysql.py +++ b/sql/engines/mysql.py @@ -91,13 +91,9 @@ def get_connection(self, db_name=None): self.thread_id = self.conn.thread_id() return self.conn - @property - def name(self): - return "MySQL" + name = "MySQL" - @property - def info(self): - return "MySQL engine" + info = "MySQL engine" def escape_string(self, value: str) -> str: """字符串参数转义""" diff --git a/sql/engines/odps.py b/sql/engines/odps.py index b4e986b5e1..e6ef6a4cc8 100644 --- a/sql/engines/odps.py +++ b/sql/engines/odps.py @@ -29,13 +29,9 @@ def get_connection(self, db_name=None): return self.conn - @property - def name(self): - return "ODPS" + name = "ODPS" - @property - def info(self): - return "ODPS engine" + info = "ODPS engine" def get_all_databases(self): """获取数据库列表, 返回一个ResultSet diff --git a/sql/engines/oracle.py b/sql/engines/oracle.py index 31cabbbbd3..6f25bbe71b 100644 --- a/sql/engines/oracle.py +++ b/sql/engines/oracle.py @@ -28,8 +28,9 @@ class OracleEngine(EngineBase): def __init__(self, instance=None): super(OracleEngine, 
self).__init__(instance=instance) - self.service_name = instance.service_name - self.sid = instance.sid + if instance: + self.service_name = instance.service_name + self.sid = instance.sid def get_connection(self, db_name=None): if self.conn: @@ -50,13 +51,9 @@ def get_connection(self, db_name=None): raise ValueError("sid 和 dsn 均未填写, 请联系管理页补充该实例配置.") return self.conn - @property - def name(self): - return "Oracle" + name = "Oracle" - @property - def info(self): - return "Oracle engine" + info = "Oracle engine" @property def auto_backup(self): diff --git a/sql/engines/pgsql.py b/sql/engines/pgsql.py index 2c303a1644..7dcb1a3ae2 100644 --- a/sql/engines/pgsql.py +++ b/sql/engines/pgsql.py @@ -40,13 +40,9 @@ def get_connection(self, db_name=None): ) return self.conn - @property - def name(self): - return "PgSQL" + name = "PgSQL" - @property - def info(self): - return "PgSQL engine" + info = "PgSQL engine" def get_all_databases(self): """ diff --git a/sql/engines/phoenix.py b/sql/engines/phoenix.py index 6383c2537e..fc7ac592d1 100644 --- a/sql/engines/phoenix.py +++ b/sql/engines/phoenix.py @@ -14,6 +14,9 @@ class PhoenixEngine(EngineBase): test_query = "SELECT 1" + name = "phoenix" + info = "phoenix engine" + def get_connection(self, db_name=None): if self.conn: return self.conn diff --git a/sql/engines/redis.py b/sql/engines/redis.py index e8b8d6410e..168e982b56 100644 --- a/sql/engines/redis.py +++ b/sql/engines/redis.py @@ -47,13 +47,9 @@ def get_connection(self, db_name=None): ssl=self.is_ssl, ) - @property - def name(self): - return "Redis" + name = "Redis" - @property - def info(self): - return "Redis engine" + info = "Redis engine" def test_connection(self): return self.get_all_databases() diff --git a/sql/engines/test_cassandra.py b/sql/engines/test_cassandra.py new file mode 100644 index 0000000000..0c04ba6b03 --- /dev/null +++ b/sql/engines/test_cassandra.py @@ -0,0 +1,144 @@ +import unittest +from unittest.mock import patch, Mock + +from django.test import TestCase +from sql.models import Instance +from sql.engines.cassandra import CassandraEngine +from sql.engines.models import ResultSet + +# 启用后, 会运行全部测试, 包括一些集成测试 +integration_test_enabled = False +integration_test_host = "localhost" + + +class CassandraEngineTest(TestCase): + def setUp(self) -> None: + self.ins = Instance.objects.create( + instance_name="some_ins", + type="slave", + db_type="cassandra", + host="localhost", + port=9200, + user="cassandra", + password="cassandra", + db_name="some_db", + ) + self.engine = CassandraEngine(instance=self.ins) + + def tearDown(self) -> None: + self.ins.delete() + if integration_test_enabled: + self.engine.execute(sql="drop keyspace test;") + + @patch("sql.engines.cassandra.Cluster.connect") + def test_get_connection(self, mock_connect): + _ = self.engine.get_connection() + mock_connect.assert_called_once() + + @patch("sql.engines.cassandra.CassandraEngine.get_connection") + def test_query(self, mock_get_connection): + test_sql = """select 123""" + self.assertIsInstance(self.engine.query("some_db", test_sql), ResultSet) + + def test_query_check(self): + test_sql = """select 123; -- this is comment + select 456;""" + + result_sql = "select 123;" + + check_result = self.engine.query_check(sql=test_sql) + + self.assertIsInstance(check_result, dict) + self.assertEqual(False, check_result.get("bad_query")) + self.assertEqual(result_sql, check_result.get("filtered_sql")) + + def test_query_check_error(self): + test_sql = """drop table table_a""" + + check_result = 
self.engine.query_check(sql=test_sql) + + self.assertIsInstance(check_result, dict) + self.assertEqual(True, check_result.get("bad_query")) + + @patch("sql.engines.cassandra.CassandraEngine.query") + def test_get_all_databases(self, mock_query): + mock_query.return_value = ResultSet(rows=[("some_db",)]) + + result = self.engine.get_all_databases() + + self.assertIsInstance(result, ResultSet) + self.assertEqual(result.rows, ["some_db"]) + + @patch("sql.engines.cassandra.CassandraEngine.query") + def test_get_all_tables(self, mock_query): + # 下面是查表示例返回结果 + mock_query.return_value = ResultSet(rows=[("u",), ("v",), ("w",)]) + + table_list = self.engine.get_all_tables("some_db") + + self.assertEqual(table_list.rows, ["u", "v", "w"]) + + @patch("sql.engines.cassandra.CassandraEngine.query") + def test_describe_table(self, mock_query): + mock_query.return_value = ResultSet() + self.engine.describe_table("some_db", "some_table") + mock_query.assert_called_once_with( + db_name="some_db", sql="describe table some_table" + ) + + @patch("sql.engines.cassandra.CassandraEngine.query") + def test_get_all_columns_by_tb(self, mock_query): + mock_query.return_value = ResultSet( + rows=[("name", "text")], column_list=["column_name", "type"] + ) + + result = self.engine.get_all_columns_by_tb("some_db", "some_table") + self.assertEqual(result.rows, [("name", "text")]) + self.assertEqual(result.column_list, ["column_name", "type"]) + + +@unittest.skipIf( + not integration_test_enabled, "cassandra integration test is not enabled" +) +class CassandraIntegrationTest(TestCase): + def setUp(self): + self.instance = Instance.objects.create( + instance_name="int_ins", + type="slave", + db_type="cassandra", + host=integration_test_host, + port=9042, + user="cassandra", + password="cassandra", + db_name="", + ) + self.engine = CassandraEngine(instance=self.instance) + + self.keyspace = "test" + self.table = "test_table" + # 新建 keyspace + self.engine.execute( + sql=f"create keyspace {self.keyspace} with replication = " + "{'class': 'org.apache.cassandra.locator.SimpleStrategy', " + "'replication_factor': '1'};" + ) + # 建表 + self.engine.execute( + db_name=self.keyspace, + sql=f"""create table if not exists {self.table}( name text primary key );""", + ) + + def tearDown(self): + self.engine.execute(sql="drop keyspace test;") + + def test_integrate_query(self): + self.engine.execute( + db_name=self.keyspace, + sql=f"insert into {self.table} (name) values ('test')", + ) + + result = self.engine.query( + db_name=self.keyspace, sql=f"select * from {self.table}" + ) + + self.assertEqual(result.rows[0][0], "test") diff --git a/sql/instance.py b/sql/instance.py index e7a3105f9a..3d2fa5785f 100644 --- a/sql/instance.py +++ b/sql/instance.py @@ -377,7 +377,7 @@ def describe(request): except Exception as msg: result["status"] = 1 result["msg"] = str(msg) - if result["data"]["error"]: + if result["data"].get("error"): result["status"] = 1 result["msg"] = result["data"]["error"] return HttpResponse(json.dumps(result), content_type="application/json") diff --git a/sql/models.py b/sql/models.py index 98f32af8a2..0547d0cb64 100755 --- a/sql/models.py +++ b/sql/models.py @@ -127,6 +127,7 @@ class Meta: ("odps", "ODPS"), ("clickhouse", "ClickHouse"), ("goinception", "goInception"), + ("cassandra", "Cassandra"), ) diff --git a/sql/templates/instance.html b/sql/templates/instance.html index 1a1940b6f0..c61f56d888 100644 --- a/sql/templates/instance.html +++ b/sql/templates/instance.html @@ -13,15 +13,9 @@
diff --git a/sql/templates/queryapplylist.html b/sql/templates/queryapplylist.html index 6c8ac6e69e..979df8b22f 100644 --- a/sql/templates/queryapplylist.html +++ b/sql/templates/queryapplylist.html @@ -47,15 +47,9 @@ title="请选择实例:" data-live-search="true" required> // TODO 使用models中的choices 渲染 - - - - - - - - - + {% for name, engine in engines.items %} + + {% endfor %}
@@ -159,7 +153,7 @@ if (data.status === 0) { $("optgroup[id^='optgroup']").empty(); let result = data['data'] - const supportDb = ['mysql', 'mssql', 'redis', 'pgsql', 'oracle', 'mongo', 'phoenix','odps' ,'clickhouse'] + const supportDb = [ {% for name in engines.keys %}"{{ name }}", {% endfor %}] for (let i of result) { let instance = ""; if (supportDb.indexOf(i.db_type) !== -1) { diff --git a/sql/templates/sqlquery.html b/sql/templates/sqlquery.html index 8fa8a4370c..6fd519a930 100644 --- a/sql/templates/sqlquery.html +++ b/sql/templates/sqlquery.html @@ -71,15 +71,9 @@ data-live-search="true" data-live-search-placeholder="搜索您所在组的实例" title="请选择实例:" data-placeholder="请选择实例:" required> - - - - - - - - - + {% for name,engine in engines.items %} + + {% endfor %}
@@ -707,15 +701,19 @@ var query_str = target+'|'+result['full_sql']; //获取当前的标签页,如果当前不在执行结果页,则默认新增一个页面 var active_li_id = sessionStorage.getItem('active_li_id'); - var active_li_title = sessionStorage.getItem('active_li_title'); - + let active_li_title = sessionStorage.getItem('active_li_title'); + let tb_name if (data.status === 0) { // 查看表结构默认新增tab,相同表结构获取不新增 - if (data.is_describe || result['full_sql'].match(/^show\s+create\s+table\s+(.*)/)) { - if (data.is_describe) { - var tb_name = $("#table_name").val(); - } else { - var tb_name = result['full_sql'].match(/^show\s+create\s+table\s+(.*);/)[1]; + if (result['full_sql'].match(/^show\s+create\s+table\s+(.*)/)) { + if (data.is_describe===undefined || !data.is_describe) { + data.is_describe = true + tb_name = result['full_sql'].match(/^show\s+create\s+table\s+(.*);/)[1]; + } + } + if (data.is_describe) { + if (tb_name === undefined) { + tb_name = $("#table_name").val(); } if (tb_name !== active_li_title) { tab_add(tb_name); @@ -724,7 +722,7 @@ } // 执行结果页默认不新增 else if (active_li_title.match(/^执行结果\d$/)) { - var n = active_li_id.split("execute_result_tab")[1]; + let n = active_li_id.split("execute_result_tab")[1]; } else { tab_add(); n = sessionStorage.getItem('tab_num'); @@ -738,7 +736,7 @@ //显示查询结果 if (result['column_list']) { //异步获取要动态生成的列 - var columns = []; + let columns = []; $.each(result['column_list'], function (i, column) { var iswholeCol = true; if (column == "mongodballdata") { @@ -770,7 +768,7 @@ } }); }); - if (result['full_sql'].match(/^show\s+create\s+table/)) { + if (data.is_describe) { //初始化表结构显示 $("#" + ("query_result" + n)).bootstrapTable('destroy').bootstrapTable({ escape: false, @@ -1018,70 +1016,16 @@ //控制按钮和选择器显示 function optgroup_control(change) { var optgroup = $('#instance_name :selected').parent().attr('label'); - if (optgroup === "MySQL") { - if (change) { - $("#div-table_name").show(); - $("#div-schema_name").hide(); - redis_help_tab_remove(); - } - $("#btn-format").attr('disabled', false); - $("#btn-explain").attr('disabled', false); - } else if (optgroup === "MsSQL") { - if (change) { - $("#div-table_name").show(); - $("#div-schema_name").hide(); - redis_help_tab_remove(); - } - $("#btn-format").attr('disabled', false); - $("#btn-explain").attr('disabled', true); - } else if (optgroup === "Redis") { - if (change) { - $("#div-table_name").hide(); - $("#div-schema_name").hide(); - redis_help_tab_add(); - } - $("#btn-format").attr('disabled', true); - $("#btn-explain").attr('disabled', true); - } else if (optgroup === "PgSQL") { - if (change) { - $("#div-table_name").show(); - $("#div-schema_name").show(); - redis_help_tab_remove(); - } - $("#btn-format").attr('disabled', false); - $("#btn-explain").attr('disabled', true); - } else if (optgroup === "Oracle") { - if (change) { - $("#div-table_name").show(); - $("#div-schema_name").hide(); - redis_help_tab_remove(); - } - $("#btn-format").attr('disabled', false); - $("#btn-explain").attr('disabled', false); - } else if (optgroup === "Mongo") { - if (change) { - $("#div-table_name").show(); - $("#div-schema_name").hide(); - redis_help_tab_remove(); - } - $("#btn-format").attr('disabled', true); - $("#btn-explain").attr('disabled', false); - } else if (optgroup === "Phoenix") { - if (change) { - $("#div-table_name").show(); - $("#div-schema_name").hide(); - redis_help_tab_remove(); - } - $("#btn-format").attr('disabled', false); - $("#btn-explain").attr('disabled', true); - } else if (optgroup === "ClickHouse") { - if (change) { - $("#div-table_name").show(); - 
$("#div-schema_name").hide(); - redis_help_tab_remove(); - } - $("#btn-format").attr('disabled', false); - $("#btn-explain").attr('disabled', false); + if (change) { + $("#div-table_name").show(); + $("#div-schema_name").hide(); + } + $("#btn-format").attr('disabled', false); + $("#btn-explain").attr('disabled', false); + if (optgroup === "Redis") { + redis_help_tab_add(); + } else { + redis_help_tab_remove(); } } @@ -1334,10 +1278,12 @@ if (data.status === 0) { $("optgroup[id^='optgroup']").empty(); let result = data['data'] - const supportDb = ['mysql', 'mssql', 'redis', 'pgsql', 'oracle', 'mongo', 'phoenix', 'odps', 'clickhouse'] + const supportDb = [ {% for name in engines.keys %}"{{ name }}", {% endfor %}] for (let i of result) { let instance = ""; if (supportDb.indexOf(i.db_type) !== -1) { + console.log("get supported db") + console.log(i) $("#optgroup-" + i.db_type).append(instance); } } diff --git a/sql/templates/sqlsubmit.html b/sql/templates/sqlsubmit.html index 0462a7665e..71384044ad 100644 --- a/sql/templates/sqlsubmit.html +++ b/sql/templates/sqlsubmit.html @@ -55,14 +55,9 @@ data-name="实例" data-placeholder="请选择实例!" title="请选择实例" data-live-search="true" data-live-search-placeholder="搜索您所在组的实例" required> - - - - - - - - + {% for name, engine in engines.items %} + + {% endfor %}
@@ -634,7 +629,7 @@ if (data.status === 0) { $("optgroup[id^='optgroup']").empty(); let result = data['data'] - const supportDb = ['mysql', 'mssql', 'redis', 'pgsql', 'oracle', 'mongo', 'phoenix', 'clickhouse'] + const supportDb = [ {% for name in engines.keys %}"{{ name }}", {% endfor %}] for (let i of result) { let instance = ""; if (supportDb.indexOf(i.db_type) !== -1) { diff --git a/sql/tests.py b/sql/tests.py index 10b15839fc..dab9d41e3b 100644 --- a/sql/tests.py +++ b/sql/tests.py @@ -1,5 +1,5 @@ import json -import re +import unittest from datetime import timedelta, datetime, date from unittest.mock import MagicMock, patch, ANY from django.conf import settings @@ -31,6 +31,7 @@ WorkflowAuditSetting, ArchiveConfig, ) +from common.dashboard import ChartDao User = Users @@ -3180,6 +3181,11 @@ def test_export_db(self, _get_engine): 测试导出 :return: """ + + def dummy(s): + return s + + _get_engine.return_value.escape_string = dummy _get_engine.return_value.get_all_databases.return_value.rows.return_value = ( ResultSet(rows=(("test1",), ("test2",))) ) diff --git a/sql/views.py b/sql/views.py index 4c494e69f3..13174db039 100644 --- a/sql/views.py +++ b/sql/views.py @@ -12,7 +12,7 @@ from archery import settings from common.config import SysConfig -from sql.engines import get_engine +from sql.engines import get_engine, engine_map from common.utils.permission import superuser_required from common.utils.convert import Convert from sql.utils.tasks import task_info @@ -175,6 +175,7 @@ def submit_sql(request): context = { "group_list": group_list, "enable_backup_switch": archer_config.get("enable_backup_switch"), + "engines": engine_map, } return render(request, "sqlsubmit.html", context) @@ -328,7 +329,9 @@ def sqlquery(request): ) can_download = 1 if user.has_perm("sql.query_download") or user.is_superuser else 0 return render( - request, "sqlquery.html", {"favorites": favorites, "can_download": can_download} + request, + "sqlquery.html", + {"favorites": favorites, "can_download": can_download, "engines": engine_map}, ) @@ -339,7 +342,7 @@ def queryapplylist(request): # 获取资源组 group_list = user_groups(user) - context = {"group_list": group_list} + context = {"group_list": group_list, "engines": engine_map} return render(request, "queryapplylist.html", context) @@ -403,7 +406,7 @@ def instance(request): """实例管理页面""" # 获取实例标签 tags = InstanceTag.objects.filter(active=True) - return render(request, "instance.html", {"tags": tags}) + return render(request, "instance.html", {"tags": tags, "engines": engine_map}) @permission_required("sql.menu_instance_account", raise_exception=True) diff --git a/src/docker/Dockerfile b/src/docker/Dockerfile index f8848cbe72..5c5b9b5e14 100644 --- a/src/docker/Dockerfile +++ b/src/docker/Dockerfile @@ -7,7 +7,7 @@ COPY . 
/opt/archery/ #archery RUN cd /opt \ - && yum -y install nginx \ + && yum -y install nginx krb5-devel \ && source /opt/venv4archery/bin/activate \ && pip3 install -r /opt/archery/requirements.txt \ && cp -f /opt/archery/src/docker/nginx.conf /etc/nginx/ \ From bb30ee54c9cd947957fb9a483767dc02dac44b2a Mon Sep 17 00:00:00 2001 From: Leo Q Date: Mon, 14 Aug 2023 14:49:40 +0000 Subject: [PATCH 2/5] add more cassandra tests --- sql/engines/cassandra.py | 6 +----- sql/engines/test_cassandra.py | 13 ++++++++++++- 2 files changed, 13 insertions(+), 6 deletions(-) diff --git a/sql/engines/cassandra.py b/sql/engines/cassandra.py index 03b913bab9..9cc83160bd 100644 --- a/sql/engines/cassandra.py +++ b/sql/engines/cassandra.py @@ -17,12 +17,8 @@ logger = logging.getLogger("default") -def row_to_set(row) -> set: - pass - - def split_sql(db_name=None, sql=""): - # 切分语句,追加到检测结果中,默认全部检测通过 + """切分语句,追加到检测结果中,默认全部检测通过""" sql = sql.split("\n") sql = filter(None, sql) sql_result = [] diff --git a/sql/engines/test_cassandra.py b/sql/engines/test_cassandra.py index 0c04ba6b03..15abe01a6e 100644 --- a/sql/engines/test_cassandra.py +++ b/sql/engines/test_cassandra.py @@ -3,7 +3,7 @@ from django.test import TestCase from sql.models import Instance -from sql.engines.cassandra import CassandraEngine +from sql.engines.cassandra import CassandraEngine, split_sql from sql.engines.models import ResultSet # 启用后, 会运行全部测试, 包括一些集成测试 @@ -96,6 +96,17 @@ def test_get_all_columns_by_tb(self, mock_query): self.assertEqual(result.rows, [("name", "text")]) self.assertEqual(result.column_list, ["column_name", "type"]) + def test_split(self): + sql = """CREATE TABLE emp( + emp_id int PRIMARY KEY, + emp_name text, + emp_city text, + emp_sal varint, + emp_phone varint + );""" + sql_result = split_sql(db_name="test_db", sql=sql) + self.assertEqual(sql_result[0], "USE test_db") + @unittest.skipIf( not integration_test_enabled, "cassandra integration test is not enabled" From 16070211c4ac1c15566539049d789443e99126d3 Mon Sep 17 00:00:00 2001 From: qizhicheng Date: Tue, 15 Aug 2023 13:43:13 +0800 Subject: [PATCH 3/5] fix data dictionary tests --- downloads/dictionary/.gitignore | 1 + downloads/dictionary/.gitkeep | 0 .../test_instance_test_archery.html | 18 ---------------- sql/data_dictionary.py | 2 +- sql/tests.py | 21 +++++++++++++++++++ 5 files changed, 23 insertions(+), 19 deletions(-) create mode 100644 downloads/dictionary/.gitignore create mode 100644 downloads/dictionary/.gitkeep delete mode 100644 downloads/dictionary/test_instance_test_archery.html diff --git a/downloads/dictionary/.gitignore b/downloads/dictionary/.gitignore new file mode 100644 index 0000000000..0b84df0f02 --- /dev/null +++ b/downloads/dictionary/.gitignore @@ -0,0 +1 @@ +*.html \ No newline at end of file diff --git a/downloads/dictionary/.gitkeep b/downloads/dictionary/.gitkeep new file mode 100644 index 0000000000..e69de29bb2 diff --git a/downloads/dictionary/test_instance_test_archery.html b/downloads/dictionary/test_instance_test_archery.html deleted file mode 100644 index 30a5966fb5..0000000000 --- a/downloads/dictionary/test_instance_test_archery.html +++ /dev/null @@ -1,18 +0,0 @@ - - - ݿṹ˵ĵ - - - -

test_archery ֵ ( 0 )

-

ʱ䣺2023-08-14 18:44:43

- - - diff --git a/sql/data_dictionary.py b/sql/data_dictionary.py index 942023ac94..809d65df78 100644 --- a/sql/data_dictionary.py +++ b/sql/data_dictionary.py @@ -124,7 +124,7 @@ def export(request): return JsonResponse({"status": 1, "msg": "仅管理员可以导出整个实例的字典信息!", "data": []}) # 获取数据,存入目录 - path = os.path.join(settings.BASE_DIR, "downloads/dictionary") + path = os.path.join(settings.BASE_DIR, "downloads", "dictionary") os.makedirs(path, exist_ok=True) for db in dbs: table_metas = query_engine.get_tables_metas_data(db_name=db) diff --git a/sql/tests.py b/sql/tests.py index dab9d41e3b..a8a910a07b 100644 --- a/sql/tests.py +++ b/sql/tests.py @@ -3247,6 +3247,15 @@ def dummy(s): r = self.client.get(path="/data_dictionary/export/", data=data) self.assertEqual(r.status_code, 200) self.assertTrue(r.streaming) + + # 测试恶意请求 + data = { + "instance_name": self.ins.instance_name, + "db_name": "/../../../etc/passwd", + "db_type": "mysql", + } + r = self.client.get(path="/data_dictionary/export/", data=data) + self.assertEqual(r.json()['status'], 1) @patch("sql.data_dictionary.get_engine") def oracle_test_export_db(self, _get_engine): @@ -3297,6 +3306,10 @@ def test_export_instance(self, _get_engine): 测试导出 :return: """ + def dummy(s): + return s + + _get_engine.return_value.escape_string = dummy _get_engine.return_value.get_all_databases.return_value.rows.return_value = ( ResultSet(rows=(("test1",), ("test2",))) ) @@ -3361,6 +3374,14 @@ def test_export_instance(self, _get_engine): "status": 0, }, ) + # 测试恶意请求 + data = { + "instance_name": self.ins.instance_name, + "db_name": "/../../../etc/passwd", + "db_type": "mysql", + } + r = self.client.get(path="/data_dictionary/export/", data=data) + self.assertEqual(r.json()['status'], 1) @patch("sql.data_dictionary.get_engine") def oracle_test_export_instance(self, _get_engine): From 56c3a3dfa5b31632593110cab76f1f65e0a61ecc Mon Sep 17 00:00:00 2001 From: qizhicheng Date: Tue, 15 Aug 2023 16:13:19 +0800 Subject: [PATCH 4/5] more tests in cassandra --- sql/engines/Readme.md | 49 ++++++++++++++++++++++ sql/engines/cassandra.py | 79 +++++++++++++++++++++++------------ sql/engines/test_cassandra.py | 68 ++++++++++++++++++++++++++---- sql/tests.py | 7 ++-- 4 files changed, 165 insertions(+), 38 deletions(-) create mode 100644 sql/engines/Readme.md diff --git a/sql/engines/Readme.md b/sql/engines/Readme.md new file mode 100644 index 0000000000..fb6a4ddfa0 --- /dev/null +++ b/sql/engines/Readme.md @@ -0,0 +1,49 @@ +# Engine 说明 + +## Cassandra +当前连接时, 使用参数基本为写死参数, 具体可以参照代码. + +如果需要覆盖, 可以自行继承 + +具体方法为: +1. 新增一个文件夹`extras`在根目录, 和`sql`, `sql_api`等文件夹平级 可以docker 打包时加入, 也可以使用卷挂载的方式 +2. 新增一个文件, `mycassandra.py` +```python +from sql.engines.cassandra import CassandraEngine + +class MyCassandraEngine(CassandraEngine): + def get_connection(self, db_name=None): + db_name = db_name or self.db_name + if self.conn: + if db_name: + self.conn.execute(f"use {db_name}") + return self.conn + hosts = self.host.split(",") + # 在这里更改你获取 session 的方式 + auth_provider = PlainTextAuthProvider( + username=self.user, password=self.password + ) + cluster = Cluster(hosts, port=self.port, auth_provider=auth_provider, + load_balancing_policy=RoundRobinPolicy(), protocol_version=5) + self.conn = cluster.connect(keyspace=db_name) + # 下面这一句最好是不要动. + self.conn.row_factory = tuple_factory + return self.conn +``` +3. 
修改settings , 加载你刚写的 engine +```python +AVAILABLE_ENGINES = { + "mysql": {"path": "sql.engines.mysql:MysqlEngine"}, + # 这里改成你的 engine + "cassandra": {"path": "extras.mycassandra:MyCassandraEngine"}, + "clickhouse": {"path": "sql.engines.clickhouse:ClickHouseEngine"}, + "goinception": {"path": "sql.engines.goinception:GoInceptionEngine"}, + "mssql": {"path": "sql.engines.mssql:MssqlEngine"}, + "redis": {"path": "sql.engines.redis:RedisEngine"}, + "pqsql": {"path": "sql.engines.pgsql:PgSQLEngine"}, + "oracle": {"path": "sql.engines.oracle:OracleEngine"}, + "mongo": {"path": "sql.engines.mongo:MongoEngine"}, + "phoenix": {"path": "sql.engines.phoenix:PhoenixEngine"}, + "odps": {"path": "sql.engines.odps:ODPSEngine"}, +} +``` \ No newline at end of file diff --git a/sql/engines/cassandra.py b/sql/engines/cassandra.py index 9cc83160bd..5c1ba05331 100644 --- a/sql/engines/cassandra.py +++ b/sql/engines/cassandra.py @@ -6,6 +6,7 @@ from cassandra.cluster import Cluster from cassandra.auth import PlainTextAuthProvider from cassandra.query import tuple_factory +from cassandra.policies import RoundRobinPolicy import sqlparse @@ -19,15 +20,33 @@ def split_sql(db_name=None, sql=""): """切分语句,追加到检测结果中,默认全部检测通过""" - sql = sql.split("\n") - sql = filter(None, sql) + sql = sqlparse.format(sql, strip_comments=True) sql_result = [] if db_name: sql_result += [f"""USE {db_name}"""] - sql_result += sql + sql_result += sqlparse.split(sql) return sql_result +def dummy_audit(full_sql: str, sql_list) -> ReviewSet: + check_result = ReviewSet(full_sql=full_sql) + rowid = 1 + for statement in sql_list: + check_result.rows.append( + ReviewResult( + id=rowid, + errlevel=0, + stagestatus="Audit completed", + errormessage="None", + sql=statement, + affected_rows=0, + execute_time=0, + ) + ) + rowid += 1 + return check_result + + class CassandraEngine(EngineBase): name = "Cassandra" info = "Cassandra engine" @@ -41,7 +60,14 @@ def get_connection(self, db_name=None): auth_provider = PlainTextAuthProvider( username=self.user, password=self.password ) - cluster = Cluster([self.host], port=self.port, auth_provider=auth_provider) + hosts = self.host.split(",") + cluster = Cluster( + hosts, + port=self.port, + auth_provider=auth_provider, + load_balancing_policy=RoundRobinPolicy(), + protocol_version=5, + ) self.conn = cluster.connect(keyspace=db_name) self.conn.row_factory = tuple_factory return self.conn @@ -86,7 +112,7 @@ def describe_table(self, db_name, tb_name, **kwargs): result.rows = filtered_rows return result - def query_check(self, db_name=None, sql="", limit_num=0): + def query_check(self, db_name=None, sql="", limit_num: int = 100): """提交查询前的检查""" # 查询语句的检查、注释去除、切分 result = {"msg": "", "bad_query": False, "filtered_sql": sql, "has_star": False} @@ -106,6 +132,22 @@ def query_check(self, db_name=None, sql="", limit_num=0): result["msg"] = "SQL语句中含有 * " return result + def filter_sql(self, sql="", limit_num=0) -> str: + # 对查询sql增加limit限制,limit n 或 limit n,n 或 limit n offset n统一改写成limit n + sql = sql.rstrip(";").strip() + if re.match(r"^select", sql, re.I): + # LIMIT N + limit_n = re.compile(r"limit\s+(\d+)\s*$", re.I) + if limit_n.search(sql): + sql_limit = limit_n.search(sql).group(1) + limit_num = min(int(limit_num), int(sql_limit)) + sql = limit_n.sub(f"limit {limit_num};", sql) + else: + sql = f"{sql} limit {limit_num};" + else: + sql = f"{sql};" + return sql + def query( self, db_name=None, @@ -131,6 +173,8 @@ def query( f"{self.name} query 错误,语句:{sql}, 错误信息:{traceback.format_exc()}" ) result_set.error = str(e) + 
if close_conn: + self.close() return result_set def get_all_tables(self, db_name, **kwargs): @@ -141,33 +185,14 @@ def get_all_tables(self, db_name, **kwargs): result.rows = tb_list return result - def filter_sql(self, sql="", limit_num=0): - return sql.strip() - def query_masking(self, db_name=None, sql="", resultset=None): """不做脱敏""" return resultset def execute_check(self, db_name=None, sql=""): """上线单执行前的检查, 返回Review set""" - check_result = ReviewSet(full_sql=sql) - # 切分语句,追加到检测结果中,默认全部检测通过 sql_result = split_sql(db_name, sql) - rowid = 1 - for statement in sql_result: - check_result.rows.append( - ReviewResult( - id=rowid, - errlevel=0, - stagestatus="Audit completed", - errormessage="None", - sql=statement, - affected_rows=0, - execute_time=0, - ) - ) - rowid += 1 - return check_result + return dummy_audit(sql, sql_result) def execute(self, db_name=None, sql="", close_conn=True, parameters=None): """执行sql语句 返回 Review set""" @@ -208,13 +233,13 @@ def execute(self, db_name=None, sql="", close_conn=True, parameters=None): break rowid += 1 if execute_result.error: - for statement in split_sql[rowid:]: + for statement in sql_result[rowid:]: execute_result.rows.append( ReviewResult( id=rowid, errlevel=2, stagestatus="Execute Failed", - errormessage=f"前序语句失败, 未执行", + errormessage="前序语句失败, 未执行", sql=statement, affected_rows=0, execute_time=0, diff --git a/sql/engines/test_cassandra.py b/sql/engines/test_cassandra.py index 15abe01a6e..e0025868ee 100644 --- a/sql/engines/test_cassandra.py +++ b/sql/engines/test_cassandra.py @@ -27,8 +27,6 @@ def setUp(self) -> None: def tearDown(self) -> None: self.ins.delete() - if integration_test_enabled: - self.engine.execute(sql="drop keyspace test;") @patch("sql.engines.cassandra.Cluster.connect") def test_get_connection(self, mock_connect): @@ -98,15 +96,69 @@ def test_get_all_columns_by_tb(self, mock_query): def test_split(self): sql = """CREATE TABLE emp( - emp_id int PRIMARY KEY, - emp_name text, - emp_city text, - emp_sal varint, - emp_phone varint - );""" + emp_id int PRIMARY KEY, + emp_name text, + emp_city text, + emp_sal varint, + emp_phone varint + );""" sql_result = split_sql(db_name="test_db", sql=sql) self.assertEqual(sql_result[0], "USE test_db") + def test_execute_check(self): + sql = """CREATE TABLE emp( + emp_id int PRIMARY KEY, + emp_name text, + emp_city text, + emp_sal varint, + emp_phone varint + );""" + check_result = self.engine.execute_check(db_name="test_db", sql=sql) + self.assertEqual(check_result.full_sql, sql) + self.assertEqual(check_result.rows[1].stagestatus, "Audit completed") + + @patch("sql.engines.cassandra.CassandraEngine.get_connection") + def test_execute(self, mock_connection): + mock_execute = Mock() + mock_connection.return_value.execute = mock_execute + sql = """CREATE TABLE emp( + emp_id int PRIMARY KEY, + emp_name text, + emp_city text, + emp_sal varint, + emp_phone varint + );""" + execute_result = self.engine.execute(db_name="test_db", sql=sql) + self.assertEqual(execute_result.rows[1].stagestatus, "Execute Successfully") + mock_execute.assert_called() + + # exception + mock_execute.side_effect = ValueError("foo") + mock_execute.reset_mock(return_value=True) + execute_result = self.engine.execute(db_name="test_db", sql=sql) + self.assertEqual(execute_result.rows[0].stagestatus, "Execute Failed") + self.assertEqual(execute_result.rows[1].stagestatus, "Execute Failed") + self.assertEqual(execute_result.rows[0].errormessage, "异常信息:foo") + self.assertEqual(execute_result.rows[1].errormessage, "前序语句失败, 未执行") + 
mock_execute.assert_called() + + def test_filter_sql(self): + sql_without_limit = "select name from user_info;" + self.assertEqual( + self.engine.filter_sql(sql_without_limit, limit_num=100), + "select name from user_info limit 100;", + ) + sql_with_normal_limit = "select name from user_info limit 1;" + self.assertEqual( + self.engine.filter_sql(sql_with_normal_limit, limit_num=100), + "select name from user_info limit 1;", + ) + sql_with_high_limit = "select name from user_info limit 1000;" + self.assertEqual( + self.engine.filter_sql(sql_with_high_limit, limit_num=100), + "select name from user_info limit 100;", + ) + @unittest.skipIf( not integration_test_enabled, "cassandra integration test is not enabled" diff --git a/sql/tests.py b/sql/tests.py index a8a910a07b..a1ba6238af 100644 --- a/sql/tests.py +++ b/sql/tests.py @@ -3247,7 +3247,7 @@ def dummy(s): r = self.client.get(path="/data_dictionary/export/", data=data) self.assertEqual(r.status_code, 200) self.assertTrue(r.streaming) - + # 测试恶意请求 data = { "instance_name": self.ins.instance_name, @@ -3255,7 +3255,7 @@ def dummy(s): "db_type": "mysql", } r = self.client.get(path="/data_dictionary/export/", data=data) - self.assertEqual(r.json()['status'], 1) + self.assertEqual(r.json()["status"], 1) @patch("sql.data_dictionary.get_engine") def oracle_test_export_db(self, _get_engine): @@ -3306,6 +3306,7 @@ def test_export_instance(self, _get_engine): 测试导出 :return: """ + def dummy(s): return s @@ -3381,7 +3382,7 @@ def dummy(s): "db_type": "mysql", } r = self.client.get(path="/data_dictionary/export/", data=data) - self.assertEqual(r.json()['status'], 1) + self.assertEqual(r.json()["status"], 1) @patch("sql.data_dictionary.get_engine") def oracle_test_export_instance(self, _get_engine): From 620ff4a443096418d6fd6f2d9d5b609911e608c6 Mon Sep 17 00:00:00 2001 From: qizhicheng Date: Tue, 15 Aug 2023 17:44:13 +0800 Subject: [PATCH 5/5] add cassandra in readme --- README.md | 21 +++++++++++---------- 1 file changed, 11 insertions(+), 10 deletions(-) diff --git a/README.md b/README.md index dff7b545a5..f67414bf0c 100644 --- a/README.md +++ b/README.md @@ -22,17 +22,18 @@ 功能清单 ==== -| 数据库 | 查询 | 审核 | 执行 | 备份 | 数据字典 | 慢日志 | 会话管理 | 账号管理 | 参数管理 | 数据归档 | -| --- | --- | --- | --- | --- | --- | --- | --- | --- | --- | --- | -| MySQL | √ | √ | √ | √ | √ | √ | √ | √ | √ | √ | -| MsSQL | √ | × | √ | × | √ | × | × | × | × | × | -| Redis | √ | × | √ | × | × | × | × | × | × | × | -| PgSQL | √ | × | √ | × | × | × | × | × | × | × | -| Oracle | √ | √ | √ | √ | √ | × | √ | × | × | × | -| MongoDB | √ | √ | √ | × | × | × | √ | √ | × | × | -| Phoenix | √ | × | √ | × | × | × | × | × | × | × | -| ODPS | √ | × | × | × | × | × | × | × | × | × | +| 数据库 | 查询 | 审核 | 执行 | 备份 | 数据字典 | 慢日志 | 会话管理 | 账号管理 | 参数管理 | 数据归档 | +|------------| --- | --- | --- | --- | --- | --- | --- | --- | --- | --- | +| MySQL | √ | √ | √ | √ | √ | √ | √ | √ | √ | √ | +| MsSQL | √ | × | √ | × | √ | × | × | × | × | × | +| Redis | √ | × | √ | × | × | × | × | × | × | × | +| PgSQL | √ | × | √ | × | × | × | × | × | × | × | +| Oracle | √ | √ | √ | √ | √ | × | √ | × | × | × | +| MongoDB | √ | √ | √ | × | × | × | √ | √ | × | × | +| Phoenix | √ | × | √ | × | × | × | × | × | × | × | +| ODPS | √ | × | × | × | × | × | × | × | × | × | | ClickHouse | √ | √ | √ | × | × | × | × | × | × | × | +| Cassandra | √ | × | √ | × | × | × | × | × | × | × |
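For reference, the refactored `sql/engines/__init__.py` in this series resolves each enabled engine from its `"module:Class"` path with `importlib`, and `get_engine()` then instantiates the class registered for `instance.db_type`. A minimal sketch of that resolution step in isolation, assuming the patch is applied and `cassandra-driver` is installed (variable names here are illustrative):

```python
import importlib

# One entry in the "module:Class" format used by settings.AVAILABLE_ENGINES.
engine_path = "sql.engines.cassandra:CassandraEngine"

module_name, class_name = engine_path.split(":")
engine_class = getattr(importlib.import_module(module_name), class_name)

# get_engine() keeps a dict of these classes keyed by db_type and, for an
# instance whose db_type is "cassandra", would instantiate it as:
#   engine = engine_class(instance=instance)
print(engine_class.name)  # "Cassandra"
```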