Skip to content

Commit

Permalink
more tests in cassandra
Browse files Browse the repository at this point in the history
  • Loading branch information
LeoQuote committed Aug 15, 2023
1 parent 073b189 commit 2edfdc3
Show file tree
Hide file tree
Showing 4 changed files with 165 additions and 38 deletions.
49 changes: 49 additions & 0 deletions sql/engines/Readme.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,49 @@
# Engine 说明

## Cassandra
当前建立连接时使用的参数基本是写死的, 具体可以参照 `sql/engines/cassandra.py` 中的 `get_connection`.

如果需要覆盖这些参数, 可以自行继承 `CassandraEngine` 并重写 `get_connection`.

具体方法为:
1. 在根目录新增一个文件夹 `extras`, 和 `sql`, `sql_api` 等文件夹平级. 可以在 docker 打包时加入, 也可以使用卷挂载的方式.
2. 在 `extras` 文件夹中新增一个文件 `mycassandra.py`
```python
from cassandra.auth import PlainTextAuthProvider
from cassandra.cluster import Cluster
from cassandra.policies import RoundRobinPolicy
from cassandra.query import tuple_factory

from sql.engines.cassandra import CassandraEngine


class MyCassandraEngine(CassandraEngine):
    """Example engine that overrides how the Cassandra session is created."""

    def get_connection(self, db_name=None):
        """Return the cached session, or build a new one with custom settings.

        :param db_name: keyspace to use; falls back to ``self.db_name``.
        :return: the session stored on ``self.conn``.
        """
        db_name = db_name or self.db_name
        if self.conn:
            # Reuse the existing session and only switch keyspace if one is given.
            if db_name:
                self.conn.execute(f"use {db_name}")
            return self.conn
        hosts = self.host.split(",")
        # Change the way you obtain the session here.
        auth_provider = PlainTextAuthProvider(
            username=self.user, password=self.password
        )
        cluster = Cluster(
            hosts,
            port=self.port,
            auth_provider=auth_provider,
            load_balancing_policy=RoundRobinPolicy(),
            protocol_version=5,
        )
        self.conn = cluster.connect(keyspace=db_name)
        # It is best to leave the following line untouched.
        self.conn.row_factory = tuple_factory
        return self.conn
```
3. 修改 settings 中的 `AVAILABLE_ENGINES`, 加载你刚写的 engine
```python
AVAILABLE_ENGINES = {
"mysql": {"path": "sql.engines.mysql:MysqlEngine"},
# 这里改成你的 engine
"cassandra": {"path": "extras.mycassandra:MyCassandraEngine"},
"clickhouse": {"path": "sql.engines.clickhouse:ClickHouseEngine"},
"goinception": {"path": "sql.engines.goinception:GoInceptionEngine"},
"mssql": {"path": "sql.engines.mssql:MssqlEngine"},
"redis": {"path": "sql.engines.redis:RedisEngine"},
"pqsql": {"path": "sql.engines.pgsql:PgSQLEngine"},
"oracle": {"path": "sql.engines.oracle:OracleEngine"},
"mongo": {"path": "sql.engines.mongo:MongoEngine"},
"phoenix": {"path": "sql.engines.phoenix:PhoenixEngine"},
"odps": {"path": "sql.engines.odps:ODPSEngine"},
}
```
79 changes: 52 additions & 27 deletions sql/engines/cassandra.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
from cassandra.cluster import Cluster
from cassandra.auth import PlainTextAuthProvider
from cassandra.query import tuple_factory
from cassandra.policies import RoundRobinPolicy

import sqlparse

Expand All @@ -19,15 +20,33 @@

def split_sql(db_name=None, sql=""):
"""切分语句,追加到检测结果中,默认全部检测通过"""
sql = sql.split("\n")
sql = filter(None, sql)
sql = sqlparse.format(sql, strip_comments=True)
sql_result = []
if db_name:
sql_result += [f"""USE {db_name}"""]
sql_result += sql
sql_result += sqlparse.split(sql)
return sql_result


def dummy_audit(full_sql: str, sql_list) -> ReviewSet:
    """Build a review set in which every statement is recorded as audited.

    No real checking happens: each statement in ``sql_list`` gets one row with
    ``errlevel=0`` and the "Audit completed" stage status.

    :param full_sql: the complete SQL text, stored on the returned review set.
    :param sql_list: iterable of individual statements, one row per statement.
    :return: the populated ``ReviewSet``; row ids start at 1.
    """
    review = ReviewSet(full_sql=full_sql)
    for position, stmt in enumerate(sql_list, start=1):
        audited = ReviewResult(
            id=position,
            errlevel=0,
            stagestatus="Audit completed",
            errormessage="None",
            sql=stmt,
            affected_rows=0,
            execute_time=0,
        )
        review.rows.append(audited)
    return review


class CassandraEngine(EngineBase):
name = "Cassandra"
info = "Cassandra engine"
Expand All @@ -41,7 +60,14 @@ def get_connection(self, db_name=None):
auth_provider = PlainTextAuthProvider(
username=self.user, password=self.password
)
cluster = Cluster([self.host], port=self.port, auth_provider=auth_provider)
hosts = self.host.split(",")
cluster = Cluster(
hosts,
port=self.port,
auth_provider=auth_provider,
load_balancing_policy=RoundRobinPolicy(),
protocol_version=5,
)
self.conn = cluster.connect(keyspace=db_name)
self.conn.row_factory = tuple_factory
return self.conn
Expand Down Expand Up @@ -86,7 +112,7 @@ def describe_table(self, db_name, tb_name, **kwargs):
result.rows = filtered_rows
return result

def query_check(self, db_name=None, sql="", limit_num=0):
def query_check(self, db_name=None, sql="", limit_num: int = 100):
"""提交查询前的检查"""
# 查询语句的检查、注释去除、切分
result = {"msg": "", "bad_query": False, "filtered_sql": sql, "has_star": False}
Expand All @@ -106,6 +132,22 @@ def query_check(self, db_name=None, sql="", limit_num=0):
result["msg"] = "SQL语句中含有 * "
return result

def filter_sql(self, sql="", limit_num=0) -> str:
    """Enforce a row limit on SELECT statements and normalise the terminator.

    Non-SELECT statements are only stripped and re-terminated with ``;``.
    For SELECT statements a trailing ``LIMIT n`` is capped at ``limit_num``;
    when no trailing ``LIMIT n`` exists, ``limit {limit_num}`` is appended.
    A trailing ``ALLOW FILTERING`` clause is kept at the very end, because
    the CQL grammar places LIMIT before ALLOW FILTERING.

    :param sql: the statement to rewrite.
    :param limit_num: maximum number of rows the query may request.
    :return: the rewritten statement, always ending with ``;``.
    """
    sql = sql.rstrip(";").strip()
    if not re.match(r"^select", sql, re.I):
        return f"{sql};"

    # Peel off a trailing ALLOW FILTERING so the limit logic sees the real
    # end of the query; it is re-attached after the LIMIT clause below.
    suffix = ""
    allow_filtering = re.search(r"\s+allow\s+filtering\s*$", sql, re.I)
    if allow_filtering:
        suffix = f" {allow_filtering.group().strip()}"
        sql = sql[: allow_filtering.start()]

    limit_n = re.compile(r"limit\s+(\d+)\s*$", re.I)
    existing = limit_n.search(sql)
    if existing:
        # Never let the statement ask for more rows than limit_num.
        limit_num = min(int(limit_num), int(existing.group(1)))
        sql = limit_n.sub(f"limit {limit_num}", sql)
    else:
        sql = f"{sql} limit {limit_num}"
    return f"{sql}{suffix};"

def query(
self,
db_name=None,
Expand All @@ -131,6 +173,8 @@ def query(
f"{self.name} query 错误,语句:{sql}, 错误信息:{traceback.format_exc()}"
)
result_set.error = str(e)
if close_conn:
self.close()
return result_set

def get_all_tables(self, db_name, **kwargs):
Expand All @@ -141,33 +185,14 @@ def get_all_tables(self, db_name, **kwargs):
result.rows = tb_list
return result

def filter_sql(self, sql="", limit_num=0):
return sql.strip()

def query_masking(self, db_name=None, sql="", resultset=None):
"""不做脱敏"""
return resultset

def execute_check(self, db_name=None, sql=""):
"""上线单执行前的检查, 返回Review set"""
check_result = ReviewSet(full_sql=sql)
# 切分语句,追加到检测结果中,默认全部检测通过
sql_result = split_sql(db_name, sql)
rowid = 1
for statement in sql_result:
check_result.rows.append(
ReviewResult(
id=rowid,
errlevel=0,
stagestatus="Audit completed",
errormessage="None",
sql=statement,
affected_rows=0,
execute_time=0,
)
)
rowid += 1
return check_result
return dummy_audit(sql, sql_result)

def execute(self, db_name=None, sql="", close_conn=True, parameters=None):
"""执行sql语句 返回 Review set"""
Expand Down Expand Up @@ -208,13 +233,13 @@ def execute(self, db_name=None, sql="", close_conn=True, parameters=None):
break
rowid += 1
if execute_result.error:
for statement in split_sql[rowid:]:
for statement in sql_result[rowid:]:
execute_result.rows.append(
ReviewResult(
id=rowid,
errlevel=2,
stagestatus="Execute Failed",
errormessage=f"前序语句失败, 未执行",
errormessage="前序语句失败, 未执行",
sql=statement,
affected_rows=0,
execute_time=0,
Expand Down
68 changes: 60 additions & 8 deletions sql/engines/test_cassandra.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,8 +27,6 @@ def setUp(self) -> None:

def tearDown(self) -> None:
self.ins.delete()
if integration_test_enabled:
self.engine.execute(sql="drop keyspace test;")

@patch("sql.engines.cassandra.Cluster.connect")
def test_get_connection(self, mock_connect):
Expand Down Expand Up @@ -98,15 +96,69 @@ def test_get_all_columns_by_tb(self, mock_query):

def test_split(self):
sql = """CREATE TABLE emp(
emp_id int PRIMARY KEY,
emp_name text,
emp_city text,
emp_sal varint,
emp_phone varint
);"""
emp_id int PRIMARY KEY,
emp_name text,
emp_city text,
emp_sal varint,
emp_phone varint
);"""
sql_result = split_sql(db_name="test_db", sql=sql)
self.assertEqual(sql_result[0], "USE test_db")

def test_execute_check(self):
    """execute_check keeps the full SQL and reports the statement as audited."""
    ddl = """CREATE TABLE emp(
emp_id int PRIMARY KEY,
emp_name text,
emp_city text,
emp_sal varint,
emp_phone varint
);"""
    review = self.engine.execute_check(db_name="test_db", sql=ddl)
    self.assertEqual(review.full_sql, ddl)
    # Index 1 is the second row of the review set.
    second_row = review.rows[1]
    self.assertEqual(second_row.stagestatus, "Audit completed")

@patch("sql.engines.cassandra.CassandraEngine.get_connection")
def test_execute(self, mock_connection):
    """execute reports success normally and marks rows failed when the session raises."""
    session_execute = Mock()
    mock_connection.return_value.execute = session_execute
    ddl = """CREATE TABLE emp(
emp_id int PRIMARY KEY,
emp_name text,
emp_city text,
emp_sal varint,
emp_phone varint
);"""

    # Success path: the patched session executes without raising.
    outcome = self.engine.execute(db_name="test_db", sql=ddl)
    self.assertEqual(outcome.rows[1].stagestatus, "Execute Successfully")
    session_execute.assert_called()

    # Failure path: every call on the patched session now raises.
    session_execute.side_effect = ValueError("foo")
    session_execute.reset_mock(return_value=True)
    outcome = self.engine.execute(db_name="test_db", sql=ddl)
    # (row index, attribute, expected value), checked in this exact order.
    expected_failures = (
        (0, "stagestatus", "Execute Failed"),
        (1, "stagestatus", "Execute Failed"),
        (0, "errormessage", "异常信息:foo"),
        (1, "errormessage", "前序语句失败, 未执行"),
    )
    for row_index, field, expected in expected_failures:
        self.assertEqual(getattr(outcome.rows[row_index], field), expected)
    session_execute.assert_called()

def test_filter_sql(self):
    """filter_sql adds a missing limit and caps an existing one at limit_num."""
    # (input statement, expected rewrite) with limit_num=100.
    expectations = (
        # no LIMIT clause: the limit is appended
        ("select name from user_info;", "select name from user_info limit 100;"),
        # LIMIT below the cap: kept as is
        (
            "select name from user_info limit 1;",
            "select name from user_info limit 1;",
        ),
        # LIMIT above the cap: reduced to the cap
        (
            "select name from user_info limit 1000;",
            "select name from user_info limit 100;",
        ),
    )
    for statement, rewritten in expectations:
        self.assertEqual(
            self.engine.filter_sql(statement, limit_num=100),
            rewritten,
        )


@unittest.skipIf(
not integration_test_enabled, "cassandra integration test is not enabled"
Expand Down
7 changes: 4 additions & 3 deletions sql/tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -3247,15 +3247,15 @@ def dummy(s):
r = self.client.get(path="/data_dictionary/export/", data=data)
self.assertEqual(r.status_code, 200)
self.assertTrue(r.streaming)

# 测试恶意请求
data = {
"instance_name": self.ins.instance_name,
"db_name": "/../../../etc/passwd",
"db_type": "mysql",
}
r = self.client.get(path="/data_dictionary/export/", data=data)
self.assertEqual(r.json()['status'], 1)
self.assertEqual(r.json()["status"], 1)

@patch("sql.data_dictionary.get_engine")
def oracle_test_export_db(self, _get_engine):
Expand Down Expand Up @@ -3306,6 +3306,7 @@ def test_export_instance(self, _get_engine):
测试导出
:return:
"""

def dummy(s):
return s

Expand Down Expand Up @@ -3381,7 +3382,7 @@ def dummy(s):
"db_type": "mysql",
}
r = self.client.get(path="/data_dictionary/export/", data=data)
self.assertEqual(r.json()['status'], 1)
self.assertEqual(r.json()["status"], 1)

@patch("sql.data_dictionary.get_engine")
def oracle_test_export_instance(self, _get_engine):
Expand Down

0 comments on commit 2edfdc3

Please sign in to comment.