diff --git a/.github/workflows/django.yml b/.github/workflows/django.yml index b6bbd4d8c4..0722acc2d8 100644 --- a/.github/workflows/django.yml +++ b/.github/workflows/django.yml @@ -17,7 +17,7 @@ jobs: strategy: max-parallel: 4 matrix: - python-version: [3.8, 3.9, "3.10"] + python-version: ["3.8", "3.9", "3.10", "3.11"] # https://github.com/actions/example-services/tree/master/.github/workflows services: @@ -70,7 +70,7 @@ jobs: - name: Install Dependencies run: | - sudo apt-get update && sudo apt-get install libsasl2-dev libldap2-dev libssl-dev unixodbc unixodbc-dev + sudo apt-get update && sudo apt-get install libsasl2-dev libkrb5-dev libldap2-dev libssl-dev unixodbc unixodbc-dev python -m pip install --upgrade pip pip install codecov coverage flake8 -r requirements.txt diff --git a/README.md b/README.md index dff7b545a5..f67414bf0c 100644 --- a/README.md +++ b/README.md @@ -22,17 +22,18 @@ 功能清单 ==== -| 数据库 | 查询 | 审核 | 执行 | 备份 | 数据字典 | 慢日志 | 会话管理 | 账号管理 | 参数管理 | 数据归档 | -| --- | --- | --- | --- | --- | --- | --- | --- | --- | --- | --- | -| MySQL | √ | √ | √ | √ | √ | √ | √ | √ | √ | √ | -| MsSQL | √ | × | √ | × | √ | × | × | × | × | × | -| Redis | √ | × | √ | × | × | × | × | × | × | × | -| PgSQL | √ | × | √ | × | × | × | × | × | × | × | -| Oracle | √ | √ | √ | √ | √ | × | √ | × | × | × | -| MongoDB | √ | √ | √ | × | × | × | √ | √ | × | × | -| Phoenix | √ | × | √ | × | × | × | × | × | × | × | -| ODPS | √ | × | × | × | × | × | × | × | × | × | +| 数据库 | 查询 | 审核 | 执行 | 备份 | 数据字典 | 慢日志 | 会话管理 | 账号管理 | 参数管理 | 数据归档 | +|------------| --- | --- | --- | --- | --- | --- | --- | --- | --- | --- | +| MySQL | √ | √ | √ | √ | √ | √ | √ | √ | √ | √ | +| MsSQL | √ | × | √ | × | √ | × | × | × | × | × | +| Redis | √ | × | √ | × | × | × | × | × | × | × | +| PgSQL | √ | × | √ | × | × | × | × | × | × | × | +| Oracle | √ | √ | √ | √ | √ | × | √ | × | × | × | +| MongoDB | √ | √ | √ | × | × | × | √ | √ | × | × | +| Phoenix | √ | × | √ | × | × | × | × | × | × | × | +| ODPS | √ | × | × | × | × | × | × | × | × | × | | ClickHouse | √ | √ | √ | × | × | × | × | × | × | × | +| Cassandra | √ | × | √ | × | × | × | × | × | × | × | diff --git a/archery/settings.py b/archery/settings.py index 0a01726696..7e69957027 100644 --- a/archery/settings.py +++ b/archery/settings.py @@ -19,7 +19,7 @@ environ.Env.read_env(os.path.join(BASE_DIR, ".env")) env = environ.Env( DEBUG=(bool, False), - ALLOWED_HOSTS=(List[str], ["*"]), + ALLOWED_HOSTS=(list, ["*"]), SECRET_KEY=(str, "hfusaf2m4ot#7)fkw#di2bu6(cv0@opwmafx5n#6=3d%x^hpl6"), DATABASE_URL=(str, "mysql://root:@127.0.0.1:3306/archery"), CACHE_URL=(str, "redis://127.0.0.1:6379/0"), @@ -38,6 +38,21 @@ Q_CLUISTER_SYNC=(bool, False), # qcluster 同步模式, debug 时可以调整为 True # CSRF_TRUSTED_ORIGINS=subdomain.example.com,subdomain.example2.com subdomain.example.com CSRF_TRUSTED_ORIGINS=(list, []), + ENABLED_ENGINES=( + list, + [ + "mysql", + "clickhouse", + "goinception", + "mssql", + "redis", + "pqsql", + "oracle", + "mongo", + "phoenix", + "odps", + ], + ), ) # SECURITY WARNING: keep the secret key used in production secret! @@ -57,6 +72,21 @@ # 请求限制 DATA_UPLOAD_MAX_MEMORY_SIZE = 15728640 +AVAILABLE_ENGINES = { + "mysql": {"path": "sql.engines.mysql:MysqlEngine"}, + "cassandra": {"path": "sql.engines.cassandra:CassandraEngine"}, + "clickhouse": {"path": "sql.engines.clickhouse:ClickHouseEngine"}, + "goinception": {"path": "sql.engines.goinception:GoInceptionEngine"}, + "mssql": {"path": "sql.engines.mssql:MssqlEngine"}, + "redis": {"path": "sql.engines.redis:RedisEngine"}, + "pqsql": {"path": "sql.engines.pgsql:PgSQLEngine"}, + "oracle": {"path": "sql.engines.oracle:OracleEngine"}, + "mongo": {"path": "sql.engines.mongo:MongoEngine"}, + "phoenix": {"path": "sql.engines.phoenix:PhoenixEngine"}, + "odps": {"path": "sql.engines.odps:ODPSEngine"}, +} +ENABLED_ENGINES = env("ENABLED_ENGINES") + # Application definition INSTALLED_APPS = ( "django.contrib.admin", @@ -245,7 +275,6 @@ ENABLE_OIDC = env("ENABLE_OIDC", False) if ENABLE_OIDC: INSTALLED_APPS += ("mozilla_django_oidc",) - MIDDLEWARE += ("mozilla_django_oidc.middleware.SessionRefresh",) AUTHENTICATION_BACKENDS = ( "common.authenticate.oidc_auth.OIDCAuthenticationBackend", "django.contrib.auth.backends.ModelBackend", diff --git a/downloads/dictionary/.gitignore b/downloads/dictionary/.gitignore new file mode 100644 index 0000000000..0b84df0f02 --- /dev/null +++ b/downloads/dictionary/.gitignore @@ -0,0 +1 @@ +*.html \ No newline at end of file diff --git a/downloads/dictionary/.gitkeep b/downloads/dictionary/.gitkeep new file mode 100644 index 0000000000..e69de29bb2 diff --git a/downloads/dictionary/test_instance_test_archery.html b/downloads/dictionary/test_instance_test_archery.html deleted file mode 100644 index 900ff1f413..0000000000 --- a/downloads/dictionary/test_instance_test_archery.html +++ /dev/null @@ -1,18 +0,0 @@ - - -
生成时间:2023-01-31 14:41:33
- - - diff --git a/requirements.txt b/requirements.txt index cd54d81312..0857a3d03b 100644 --- a/requirements.txt +++ b/requirements.txt @@ -18,7 +18,7 @@ pyecharts==1.9.1 aliyun-python-sdk-rds==2.1.1 cx-Oracle==7.3.0 supervisor==4.1.0 -phoenixdb==0.7 +phoenixdb==1.2.1 django-mirage-field==1.4.0 schema-sync==0.9.7 parsedatetime==2.4 @@ -38,4 +38,5 @@ django-environ==0.8.1 alibabacloud_dysmsapi20170525==2.0.9 tencentcloud-sdk-python==3.0.656 mozilla-django-oidc==3.0.0 -django-auth-dingding==0.0.2 \ No newline at end of file +django-auth-dingding==0.0.2 +cassandra-driver diff --git a/sql/data_dictionary.py b/sql/data_dictionary.py index 4a702d8075..809d65df78 100644 --- a/sql/data_dictionary.py +++ b/sql/data_dictionary.py @@ -90,6 +90,16 @@ def table_info(request): ) +def get_export_full_path(base_dir: str, instance_name: str, db_name: str) -> str: + """validate if the instance_name and db_name provided is secure""" + fullpath = os.path.normpath( + os.path.join(base_dir, f"{instance_name}_{db_name}.html") + ) + if not fullpath.startswith(base_dir): + return "" + return fullpath + + @permission_required("sql.data_dictionary_export", raise_exception=True) def export(request): """导出数据字典""" @@ -111,10 +121,10 @@ def export(request): elif request.user.is_superuser: dbs = query_engine.get_all_databases().rows else: - return JsonResponse({"status": 1, "msg": f"仅管理员可以导出整个实例的字典信息!", "data": []}) + return JsonResponse({"status": 1, "msg": "仅管理员可以导出整个实例的字典信息!", "data": []}) # 获取数据,存入目录 - path = os.path.join(settings.BASE_DIR, "downloads/dictionary") + path = os.path.join(settings.BASE_DIR, "downloads", "dictionary") os.makedirs(path, exist_ok=True) for db in dbs: table_metas = query_engine.get_tables_metas_data(db_name=db) @@ -126,12 +136,18 @@ def export(request): data = loader.render_to_string( template_name="dictionaryexport.html", context=context, request=request ) - with open(f"{path}/{instance_name}_{db}.html", "w") as f: - f.write(data) + fullpath = get_export_full_path(path, instance_name, db) + if not fullpath: + return JsonResponse({"status": 1, "msg": "实例名或db名不合法", "data": []}) + with open(fullpath, "w", encoding="utf-8") as fp: + fp.write(data) # 关闭连接 query_engine.close() if db_name: - response = FileResponse(open(f"{path}/{instance_name}_{db_name}.html", "rb")) + fullpath = get_export_full_path(path, instance_name, db) + if not fullpath: + return JsonResponse({"status": 1, "msg": "实例名或db名不合法", "data": []}) + response = FileResponse(open(fullpath, "rb")) response["Content-Type"] = "application/octet-stream" response[ "Content-Disposition" diff --git a/sql/engines/Readme.md b/sql/engines/Readme.md new file mode 100644 index 0000000000..fb6a4ddfa0 --- /dev/null +++ b/sql/engines/Readme.md @@ -0,0 +1,49 @@ +# Engine 说明 + +## Cassandra +当前连接时, 使用参数基本为写死参数, 具体可以参照代码. + +如果需要覆盖, 可以自行继承 + +具体方法为: +1. 新增一个文件夹`extras`在根目录, 和`sql`, `sql_api`等文件夹平级 可以docker 打包时加入, 也可以使用卷挂载的方式 +2. 新增一个文件, `mycassandra.py` +```python +from sql.engines.cassandra import CassandraEngine + +class MyCassandraEngine(CassandraEngine): + def get_connection(self, db_name=None): + db_name = db_name or self.db_name + if self.conn: + if db_name: + self.conn.execute(f"use {db_name}") + return self.conn + hosts = self.host.split(",") + # 在这里更改你获取 session 的方式 + auth_provider = PlainTextAuthProvider( + username=self.user, password=self.password + ) + cluster = Cluster(hosts, port=self.port, auth_provider=auth_provider, + load_balancing_policy=RoundRobinPolicy(), protocol_version=5) + self.conn = cluster.connect(keyspace=db_name) + # 下面这一句最好是不要动. + self.conn.row_factory = tuple_factory + return self.conn +``` +3. 修改settings , 加载你刚写的 engine +```python +AVAILABLE_ENGINES = { + "mysql": {"path": "sql.engines.mysql:MysqlEngine"}, + # 这里改成你的 engine + "cassandra": {"path": "extras.mycassandra:MyCassandraEngine"}, + "clickhouse": {"path": "sql.engines.clickhouse:ClickHouseEngine"}, + "goinception": {"path": "sql.engines.goinception:GoInceptionEngine"}, + "mssql": {"path": "sql.engines.mssql:MssqlEngine"}, + "redis": {"path": "sql.engines.redis:RedisEngine"}, + "pqsql": {"path": "sql.engines.pgsql:PgSQLEngine"}, + "oracle": {"path": "sql.engines.oracle:OracleEngine"}, + "mongo": {"path": "sql.engines.mongo:MongoEngine"}, + "phoenix": {"path": "sql.engines.phoenix:PhoenixEngine"}, + "odps": {"path": "sql.engines.odps:ODPSEngine"}, +} +``` \ No newline at end of file diff --git a/sql/engines/__init__.py b/sql/engines/__init__.py index 508a8dd9b1..c3c6f7ddd9 100644 --- a/sql/engines/__init__.py +++ b/sql/engines/__init__.py @@ -1,6 +1,8 @@ """engine base库, 包含一个``EngineBase`` class和一个get_engine函数""" +import importlib from sql.engines.models import ResultSet, ReviewSet from sql.utils.ssh_tunnel import SSHConnection +from django.conf import settings class EngineBase: @@ -8,6 +10,9 @@ class EngineBase: test_query = None + name = "Base" + info = "base engine" + def __init__(self, instance=None): self.conn = None self.thread_id = None @@ -77,16 +82,6 @@ def test_connection(self): """测试实例链接是否正常""" return self.query(sql=self.test_query) - @property - def name(self): - """返回engine名称""" - return "base" - - @property - def info(self): - """返回引擎简介""" - return "Base engine" - def escape_string(self, value: str) -> str: """参数转义""" return value @@ -179,7 +174,7 @@ def query( limit_num=0, close_conn=True, parameters=None, - **kwargs + **kwargs, ): """实际查询 返回一个ResultSet""" return ResultSet() @@ -213,6 +208,22 @@ def set_variable(self, variable_name, variable_value): return ResultSet() +def get_engine_map(): + available_engines = settings.AVAILABLE_ENGINES + enabled_engines = {} + for e in settings.ENABLED_ENGINES: + config = available_engines.get(e) + if not config: + raise ValueError(f"invalid engine {e}, not found in engine map") + module, o = config["path"].split(":") + engine = getattr(importlib.import_module(module), o) + enabled_engines[e] = engine + return enabled_engines + + +engine_map = get_engine_map() + + def get_engine(instance=None): # pragma: no cover """获取数据库操作engine""" if instance.db_type == "mysql": @@ -222,44 +233,9 @@ def get_engine(instance=None): # pragma: no cover from .cloud.aliyun_rds import AliyunRDS return AliyunRDS(instance=instance) - from .mysql import MysqlEngine - - return MysqlEngine(instance=instance) - elif instance.db_type == "mssql": - from .mssql import MssqlEngine - - return MssqlEngine(instance=instance) - elif instance.db_type == "redis": - from .redis import RedisEngine - - return RedisEngine(instance=instance) - elif instance.db_type == "pgsql": - from .pgsql import PgSQLEngine - - return PgSQLEngine(instance=instance) - elif instance.db_type == "oracle": - from .oracle import OracleEngine - - return OracleEngine(instance=instance) - elif instance.db_type == "mongo": - from .mongo import MongoEngine - - return MongoEngine(instance=instance) - elif instance.db_type == "goinception": - from .goinception import GoInceptionEngine - - return GoInceptionEngine(instance=instance) - elif instance.db_type == "phoenix": - from .phoenix import PhoenixEngine - - return PhoenixEngine(instance=instance) - - elif instance.db_type == "odps": - from .odps import ODPSEngine - - return ODPSEngine(instance=instance) - - elif instance.db_type == "clickhouse": - from .clickhouse import ClickHouseEngine - - return ClickHouseEngine(instance=instance) + engine = engine_map.get(instance.db_type) + if not engine: + raise ValueError( + f"engine {instance.db_type} not enabled or not supported, please contact admin" + ) + return engine(instance=instance) diff --git a/sql/engines/cassandra.py b/sql/engines/cassandra.py new file mode 100644 index 0000000000..5c1ba05331 --- /dev/null +++ b/sql/engines/cassandra.py @@ -0,0 +1,257 @@ +import re + +import logging +import traceback + +from cassandra.cluster import Cluster +from cassandra.auth import PlainTextAuthProvider +from cassandra.query import tuple_factory +from cassandra.policies import RoundRobinPolicy + +import sqlparse + +from . import EngineBase +from .models import ResultSet, ReviewSet, ReviewResult + +from sql.models import SqlWorkflow + +logger = logging.getLogger("default") + + +def split_sql(db_name=None, sql=""): + """切分语句,追加到检测结果中,默认全部检测通过""" + sql = sqlparse.format(sql, strip_comments=True) + sql_result = [] + if db_name: + sql_result += [f"""USE {db_name}"""] + sql_result += sqlparse.split(sql) + return sql_result + + +def dummy_audit(full_sql: str, sql_list) -> ReviewSet: + check_result = ReviewSet(full_sql=full_sql) + rowid = 1 + for statement in sql_list: + check_result.rows.append( + ReviewResult( + id=rowid, + errlevel=0, + stagestatus="Audit completed", + errormessage="None", + sql=statement, + affected_rows=0, + execute_time=0, + ) + ) + rowid += 1 + return check_result + + +class CassandraEngine(EngineBase): + name = "Cassandra" + info = "Cassandra engine" + + def get_connection(self, db_name=None): + db_name = db_name or self.db_name + if self.conn: + if db_name: + self.conn.execute(f"use {db_name}") + return self.conn + auth_provider = PlainTextAuthProvider( + username=self.user, password=self.password + ) + hosts = self.host.split(",") + cluster = Cluster( + hosts, + port=self.port, + auth_provider=auth_provider, + load_balancing_policy=RoundRobinPolicy(), + protocol_version=5, + ) + self.conn = cluster.connect(keyspace=db_name) + self.conn.row_factory = tuple_factory + return self.conn + + def close(self): + if self.conn: + self.conn.shutdown() + self.conn = None + + def test_connection(self): + result = self.get_all_databases() + self.close() + return result + + def escape_string(self, value: str) -> str: + return re.sub(r"[; ]", "", value) + + def get_all_databases(self, **kwargs): + """ + 获取所有的 keyspace/database + :return: + """ + result = self.query(sql="SELECT keyspace_name FROM system_schema.keyspaces;") + result.rows = [x[0] for x in result.rows] + return result + + def get_all_columns_by_tb(self, db_name, tb_name, **kwargs): + """获取所有列, 返回一个ResultSet""" + sql = "select column_name, type from columns where keyspace_name=%s and table_name=%s" + result = self.query( + db_name="system_schema", sql=sql, parameters=(db_name, tb_name) + ) + return result + + def describe_table(self, db_name, tb_name, **kwargs): + sql = f"describe table {tb_name}" + result = self.query(db_name=db_name, sql=sql) + result.column_list = ["table", "create table"] + filtered_rows = [] + for r in result.rows: + filtered_rows.append((r[2], r[3])) + result.rows = filtered_rows + return result + + def query_check(self, db_name=None, sql="", limit_num: int = 100): + """提交查询前的检查""" + # 查询语句的检查、注释去除、切分 + result = {"msg": "", "bad_query": False, "filtered_sql": sql, "has_star": False} + # 删除注释语句,进行语法判断,执行第一条有效sql + try: + sql = sqlparse.format(sql, strip_comments=True) + sql = sqlparse.split(sql)[0] + result["filtered_sql"] = sql.strip() + except IndexError: + result["bad_query"] = True + result["msg"] = "没有有效的SQL语句" + if re.match(r"^select|^describe", sql, re.I) is None: + result["bad_query"] = True + result["msg"] = "不支持的查询语法类型!" + if "*" in sql: + result["has_star"] = True + result["msg"] = "SQL语句中含有 * " + return result + + def filter_sql(self, sql="", limit_num=0) -> str: + # 对查询sql增加limit限制,limit n 或 limit n,n 或 limit n offset n统一改写成limit n + sql = sql.rstrip(";").strip() + if re.match(r"^select", sql, re.I): + # LIMIT N + limit_n = re.compile(r"limit\s+(\d+)\s*$", re.I) + if limit_n.search(sql): + sql_limit = limit_n.search(sql).group(1) + limit_num = min(int(limit_num), int(sql_limit)) + sql = limit_n.sub(f"limit {limit_num};", sql) + else: + sql = f"{sql} limit {limit_num};" + else: + sql = f"{sql};" + return sql + + def query( + self, + db_name=None, + sql="", + limit_num=0, + close_conn=True, + parameters=None, + **kwargs, + ): + """返回 ResultSet""" + result_set = ResultSet(full_sql=sql) + try: + conn = self.get_connection(db_name=db_name) + rows = conn.execute(sql, parameters=parameters) + result_set.column_list = rows.column_names + result_set.rows = rows.all() + result_set.affected_rows = len(result_set.rows) + if limit_num > 0: + result_set.rows = result_set.rows[0:limit_num] + result_set.affected_rows = min(limit_num, result_set.affected_rows) + except Exception as e: + logger.warning( + f"{self.name} query 错误,语句:{sql}, 错误信息:{traceback.format_exc()}" + ) + result_set.error = str(e) + if close_conn: + self.close() + return result_set + + def get_all_tables(self, db_name, **kwargs): + sql = "SELECT table_name FROM system_schema.tables WHERE keyspace_name = %s;" + parameters = [db_name] + result = self.query(db_name=db_name, sql=sql, parameters=parameters) + tb_list = [row[0] for row in result.rows] + result.rows = tb_list + return result + + def query_masking(self, db_name=None, sql="", resultset=None): + """不做脱敏""" + return resultset + + def execute_check(self, db_name=None, sql=""): + """上线单执行前的检查, 返回Review set""" + sql_result = split_sql(db_name, sql) + return dummy_audit(sql, sql_result) + + def execute(self, db_name=None, sql="", close_conn=True, parameters=None): + """执行sql语句 返回 Review set""" + execute_result = ReviewSet(full_sql=sql) + conn = self.get_connection(db_name=db_name) + sql_result = split_sql(db_name, sql) + rowid = 1 + for statement in sql_result: + try: + conn.execute(statement) + execute_result.rows.append( + ReviewResult( + id=rowid, + errlevel=0, + stagestatus="Execute Successfully", + errormessage="None", + sql=statement, + affected_rows=0, + execute_time=0, + ) + ) + except Exception as e: + logger.warning( + f"{self.name} 命令执行报错,语句:{sql}, 错误信息:{traceback.format_exc()}" + ) + execute_result.error = str(e) + execute_result.rows.append( + ReviewResult( + id=rowid, + errlevel=2, + stagestatus="Execute Failed", + errormessage=f"异常信息:{e}", + sql=statement, + affected_rows=0, + execute_time=0, + ) + ) + break + rowid += 1 + if execute_result.error: + for statement in sql_result[rowid:]: + execute_result.rows.append( + ReviewResult( + id=rowid, + errlevel=2, + stagestatus="Execute Failed", + errormessage="前序语句失败, 未执行", + sql=statement, + affected_rows=0, + execute_time=0, + ) + ) + rowid += 1 + if close_conn: + self.close() + return execute_result + + def execute_workflow(self, workflow: SqlWorkflow): + """执行上线单,返回Review set""" + return self.execute( + db_name=workflow.db_name, sql=workflow.sqlworkflowcontent.sql_content + ) diff --git a/sql/engines/clickhouse.py b/sql/engines/clickhouse.py index 33dedfdd31..69dfd78772 100644 --- a/sql/engines/clickhouse.py +++ b/sql/engines/clickhouse.py @@ -42,13 +42,8 @@ def get_connection(self, db_name=None): ) return self.conn - @property - def name(self): - return "ClickHouse" - - @property - def info(self): - return "ClickHouse engine" + name = "ClickHouse" + info = "ClickHouse engine" def escape_string(self, value: str) -> str: """字符串参数转义""" diff --git a/sql/engines/goinception.py b/sql/engines/goinception.py index d291cccf28..551c0eea47 100644 --- a/sql/engines/goinception.py +++ b/sql/engines/goinception.py @@ -18,13 +18,9 @@ class GoInceptionEngine(EngineBase): test_query = "INCEPTION GET VARIABLES" - @property - def name(self): - return "GoInception" + name = "GoInception" - @property - def info(self): - return "GoInception engine" + info = "GoInception engine" def get_connection(self, db_name=None): if self.conn: diff --git a/sql/engines/mongo.py b/sql/engines/mongo.py index 7a7c2aa844..764d084a5a 100644 --- a/sql/engines/mongo.py +++ b/sql/engines/mongo.py @@ -3,7 +3,6 @@ import pymongo import logging import traceback -import json import subprocess import simplejson as json import datetime @@ -800,13 +799,9 @@ def close(self): self.conn.close() self.conn = None - @property - def name(self): # pragma: no cover - return "Mongo" + name = "Mongo" - @property - def info(self): # pragma: no cover - return "Mongo engine" + info = "Mongo engine" def get_roles(self): sql_get_roles = "db.system.roles.find({},{_id:1})" @@ -1266,18 +1261,20 @@ def kill_op(self, opids): result = ResultSet() try: conn = self.get_connection() - db = conn.admin - for opid in opids: - conn.admin.command({"killOp": 1, "op": opid}) except Exception as e: - try: - sql = {"killOp": 1, "op": _opid} - except: - sql = {"killOp": 1, "op": ""} - logger.warning( - f"mongodb语句执行killOp报错,语句:db.runCommand({sql}) ,错误信息{traceback.format_exc()}" - ) + logger.error(f"{self.name} 连接失败, error: {str(e)}") result.error = str(e) + return result + for opid in opids: + try: + conn.admin.command({"killOp": 1, "op": opid}) + except Exception as e: + sql = {"killOp": 1, "op": opid} + logger.warning( + f"{self.name}语句执行killOp报错,语句:db.runCommand({sql}) ,错误信息{traceback.format_exc()}" + ) + result.error = str(e) + return result def get_all_databases_summary(self): """实例数据库管理功能,获取实例所有的数据库描述信息""" diff --git a/sql/engines/mssql.py b/sql/engines/mssql.py index 9a56a0dadf..716a600917 100644 --- a/sql/engines/mssql.py +++ b/sql/engines/mssql.py @@ -31,13 +31,9 @@ def get_connection(self, db_name=None): self.conn = pyodbc.connect(connstr) return self.conn - @property - def name(self): - return "MsSQL" + name = "MsSQL" - @property - def info(self): - return "MsSQL engine" + info = "MsSQL engine" def get_all_databases(self): """获取数据库列表, 返回一个ResultSet""" diff --git a/sql/engines/mysql.py b/sql/engines/mysql.py index 05252cc81a..7ab5cfe18c 100644 --- a/sql/engines/mysql.py +++ b/sql/engines/mysql.py @@ -91,13 +91,9 @@ def get_connection(self, db_name=None): self.thread_id = self.conn.thread_id() return self.conn - @property - def name(self): - return "MySQL" + name = "MySQL" - @property - def info(self): - return "MySQL engine" + info = "MySQL engine" def escape_string(self, value: str) -> str: """字符串参数转义""" diff --git a/sql/engines/odps.py b/sql/engines/odps.py index b4e986b5e1..e6ef6a4cc8 100644 --- a/sql/engines/odps.py +++ b/sql/engines/odps.py @@ -29,13 +29,9 @@ def get_connection(self, db_name=None): return self.conn - @property - def name(self): - return "ODPS" + name = "ODPS" - @property - def info(self): - return "ODPS engine" + info = "ODPS engine" def get_all_databases(self): """获取数据库列表, 返回一个ResultSet diff --git a/sql/engines/oracle.py b/sql/engines/oracle.py index 31cabbbbd3..6f25bbe71b 100644 --- a/sql/engines/oracle.py +++ b/sql/engines/oracle.py @@ -28,8 +28,9 @@ class OracleEngine(EngineBase): def __init__(self, instance=None): super(OracleEngine, self).__init__(instance=instance) - self.service_name = instance.service_name - self.sid = instance.sid + if instance: + self.service_name = instance.service_name + self.sid = instance.sid def get_connection(self, db_name=None): if self.conn: @@ -50,13 +51,9 @@ def get_connection(self, db_name=None): raise ValueError("sid 和 dsn 均未填写, 请联系管理页补充该实例配置.") return self.conn - @property - def name(self): - return "Oracle" + name = "Oracle" - @property - def info(self): - return "Oracle engine" + info = "Oracle engine" @property def auto_backup(self): diff --git a/sql/engines/pgsql.py b/sql/engines/pgsql.py index 2c303a1644..7dcb1a3ae2 100644 --- a/sql/engines/pgsql.py +++ b/sql/engines/pgsql.py @@ -40,13 +40,9 @@ def get_connection(self, db_name=None): ) return self.conn - @property - def name(self): - return "PgSQL" + name = "PgSQL" - @property - def info(self): - return "PgSQL engine" + info = "PgSQL engine" def get_all_databases(self): """ diff --git a/sql/engines/phoenix.py b/sql/engines/phoenix.py index 6383c2537e..fc7ac592d1 100644 --- a/sql/engines/phoenix.py +++ b/sql/engines/phoenix.py @@ -14,6 +14,9 @@ class PhoenixEngine(EngineBase): test_query = "SELECT 1" + name = "phoenix" + info = "phoenix engine" + def get_connection(self, db_name=None): if self.conn: return self.conn diff --git a/sql/engines/redis.py b/sql/engines/redis.py index e8b8d6410e..168e982b56 100644 --- a/sql/engines/redis.py +++ b/sql/engines/redis.py @@ -47,13 +47,9 @@ def get_connection(self, db_name=None): ssl=self.is_ssl, ) - @property - def name(self): - return "Redis" + name = "Redis" - @property - def info(self): - return "Redis engine" + info = "Redis engine" def test_connection(self): return self.get_all_databases() diff --git a/sql/engines/test_cassandra.py b/sql/engines/test_cassandra.py new file mode 100644 index 0000000000..e0025868ee --- /dev/null +++ b/sql/engines/test_cassandra.py @@ -0,0 +1,207 @@ +import unittest +from unittest.mock import patch, Mock + +from django.test import TestCase +from sql.models import Instance +from sql.engines.cassandra import CassandraEngine, split_sql +from sql.engines.models import ResultSet + +# 启用后, 会运行全部测试, 包括一些集成测试 +integration_test_enabled = False +integration_test_host = "localhost" + + +class CassandraEngineTest(TestCase): + def setUp(self) -> None: + self.ins = Instance.objects.create( + instance_name="some_ins", + type="slave", + db_type="cassandra", + host="localhost", + port=9200, + user="cassandra", + password="cassandra", + db_name="some_db", + ) + self.engine = CassandraEngine(instance=self.ins) + + def tearDown(self) -> None: + self.ins.delete() + + @patch("sql.engines.cassandra.Cluster.connect") + def test_get_connection(self, mock_connect): + _ = self.engine.get_connection() + mock_connect.assert_called_once() + + @patch("sql.engines.cassandra.CassandraEngine.get_connection") + def test_query(self, mock_get_connection): + test_sql = """select 123""" + self.assertIsInstance(self.engine.query("some_db", test_sql), ResultSet) + + def test_query_check(self): + test_sql = """select 123; -- this is comment + select 456;""" + + result_sql = "select 123;" + + check_result = self.engine.query_check(sql=test_sql) + + self.assertIsInstance(check_result, dict) + self.assertEqual(False, check_result.get("bad_query")) + self.assertEqual(result_sql, check_result.get("filtered_sql")) + + def test_query_check_error(self): + test_sql = """drop table table_a""" + + check_result = self.engine.query_check(sql=test_sql) + + self.assertIsInstance(check_result, dict) + self.assertEqual(True, check_result.get("bad_query")) + + @patch("sql.engines.cassandra.CassandraEngine.query") + def test_get_all_databases(self, mock_query): + mock_query.return_value = ResultSet(rows=[("some_db",)]) + + result = self.engine.get_all_databases() + + self.assertIsInstance(result, ResultSet) + self.assertEqual(result.rows, ["some_db"]) + + @patch("sql.engines.cassandra.CassandraEngine.query") + def test_get_all_tables(self, mock_query): + # 下面是查表示例返回结果 + mock_query.return_value = ResultSet(rows=[("u",), ("v",), ("w",)]) + + table_list = self.engine.get_all_tables("some_db") + + self.assertEqual(table_list.rows, ["u", "v", "w"]) + + @patch("sql.engines.cassandra.CassandraEngine.query") + def test_describe_table(self, mock_query): + mock_query.return_value = ResultSet() + self.engine.describe_table("some_db", "some_table") + mock_query.assert_called_once_with( + db_name="some_db", sql="describe table some_table" + ) + + @patch("sql.engines.cassandra.CassandraEngine.query") + def test_get_all_columns_by_tb(self, mock_query): + mock_query.return_value = ResultSet( + rows=[("name", "text")], column_list=["column_name", "type"] + ) + + result = self.engine.get_all_columns_by_tb("some_db", "some_table") + self.assertEqual(result.rows, [("name", "text")]) + self.assertEqual(result.column_list, ["column_name", "type"]) + + def test_split(self): + sql = """CREATE TABLE emp( + emp_id int PRIMARY KEY, + emp_name text, + emp_city text, + emp_sal varint, + emp_phone varint + );""" + sql_result = split_sql(db_name="test_db", sql=sql) + self.assertEqual(sql_result[0], "USE test_db") + + def test_execute_check(self): + sql = """CREATE TABLE emp( + emp_id int PRIMARY KEY, + emp_name text, + emp_city text, + emp_sal varint, + emp_phone varint + );""" + check_result = self.engine.execute_check(db_name="test_db", sql=sql) + self.assertEqual(check_result.full_sql, sql) + self.assertEqual(check_result.rows[1].stagestatus, "Audit completed") + + @patch("sql.engines.cassandra.CassandraEngine.get_connection") + def test_execute(self, mock_connection): + mock_execute = Mock() + mock_connection.return_value.execute = mock_execute + sql = """CREATE TABLE emp( + emp_id int PRIMARY KEY, + emp_name text, + emp_city text, + emp_sal varint, + emp_phone varint + );""" + execute_result = self.engine.execute(db_name="test_db", sql=sql) + self.assertEqual(execute_result.rows[1].stagestatus, "Execute Successfully") + mock_execute.assert_called() + + # exception + mock_execute.side_effect = ValueError("foo") + mock_execute.reset_mock(return_value=True) + execute_result = self.engine.execute(db_name="test_db", sql=sql) + self.assertEqual(execute_result.rows[0].stagestatus, "Execute Failed") + self.assertEqual(execute_result.rows[1].stagestatus, "Execute Failed") + self.assertEqual(execute_result.rows[0].errormessage, "异常信息:foo") + self.assertEqual(execute_result.rows[1].errormessage, "前序语句失败, 未执行") + mock_execute.assert_called() + + def test_filter_sql(self): + sql_without_limit = "select name from user_info;" + self.assertEqual( + self.engine.filter_sql(sql_without_limit, limit_num=100), + "select name from user_info limit 100;", + ) + sql_with_normal_limit = "select name from user_info limit 1;" + self.assertEqual( + self.engine.filter_sql(sql_with_normal_limit, limit_num=100), + "select name from user_info limit 1;", + ) + sql_with_high_limit = "select name from user_info limit 1000;" + self.assertEqual( + self.engine.filter_sql(sql_with_high_limit, limit_num=100), + "select name from user_info limit 100;", + ) + + +@unittest.skipIf( + not integration_test_enabled, "cassandra integration test is not enabled" +) +class CassandraIntegrationTest(TestCase): + def setUp(self): + self.instance = Instance.objects.create( + instance_name="int_ins", + type="slave", + db_type="cassandra", + host=integration_test_host, + port=9042, + user="cassandra", + password="cassandra", + db_name="", + ) + self.engine = CassandraEngine(instance=self.instance) + + self.keyspace = "test" + self.table = "test_table" + # 新建 keyspace + self.engine.execute( + sql=f"create keyspace {self.keyspace} with replication = " + "{'class': 'org.apache.cassandra.locator.SimpleStrategy', " + "'replication_factor': '1'};" + ) + # 建表 + self.engine.execute( + db_name=self.keyspace, + sql=f"""create table if not exists {self.table}( name text primary key );""", + ) + + def tearDown(self): + self.engine.execute(sql="drop keyspace test;") + + def test_integrate_query(self): + self.engine.execute( + db_name=self.keyspace, + sql=f"insert into {self.table} (name) values ('test')", + ) + + result = self.engine.query( + db_name=self.keyspace, sql=f"select * from {self.table}" + ) + + self.assertEqual(result.rows[0][0], "test") diff --git a/sql/instance.py b/sql/instance.py index e7a3105f9a..3d2fa5785f 100644 --- a/sql/instance.py +++ b/sql/instance.py @@ -377,7 +377,7 @@ def describe(request): except Exception as msg: result["status"] = 1 result["msg"] = str(msg) - if result["data"]["error"]: + if result["data"].get("error"): result["status"] = 1 result["msg"] = result["data"]["error"] return HttpResponse(json.dumps(result), content_type="application/json") diff --git a/sql/models.py b/sql/models.py index 98f32af8a2..0547d0cb64 100755 --- a/sql/models.py +++ b/sql/models.py @@ -127,6 +127,7 @@ class Meta: ("odps", "ODPS"), ("clickhouse", "ClickHouse"), ("goinception", "goInception"), + ("cassandra", "Cassandra"), ) diff --git a/sql/templates/instance.html b/sql/templates/instance.html index 1a1940b6f0..c61f56d888 100644 --- a/sql/templates/instance.html +++ b/sql/templates/instance.html @@ -13,15 +13,9 @@