Skip to content

Commit

Permalink
Merge pull request #3340 from uktrade/fix/codeql-sql-query-built-from-user-submitted-code
Browse files Browse the repository at this point in the history

Fix/codeql sql query built from user submitted code
  • Loading branch information
JamesPRobinson authored Oct 31, 2024
2 parents 53cb7a8 + 43004d4 commit 47c96de
Show file tree
Hide file tree
Showing 3 changed files with 117 additions and 32 deletions.
2 changes: 1 addition & 1 deletion .pylintrc
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
# A comma-separated list of package or module names from where C extensions may
# be loaded. Extensions are loading into the active Python interpreter and may
# run arbitrary code.
extension-pkg-whitelist=
extension-pkg-whitelist=pglast

# Specify a score threshold to be exceeded before program exits with error.
fail-under=10
Expand Down
54 changes: 45 additions & 9 deletions dataworkspace/dataworkspace/apps/explorer/tasks.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,17 +68,33 @@ def cleanup_temporary_query_tables():
server_db_user = DATABASES_DATA[query_log.connection]["USER"]
db_role = f"{USER_SCHEMA_STEM}{db_role_schema_suffix_for_user(query_log.run_by_user)}"
table_schema_and_name = tempory_query_table_name(query_log.run_by_user, query_log.id)

with cache.lock(
f'database-grant--{DATABASES_DATA[query_log.connection]["NAME"]}--{db_role}--v4',
blocking_timeout=3,
timeout=180,
):
with connections[query_log.connection].cursor() as cursor:
logger.info("Dropping temporary query table %s", table_schema_and_name)
cursor.execute(f"GRANT {db_role} TO {server_db_user}")
cursor.execute(f"DROP TABLE IF EXISTS {table_schema_and_name}")
cursor.execute(f"REVOKE {db_role} FROM {server_db_user}")
output_table_schema, output_table_name = table_schema_and_name.split(".")
cursor.execute(
psycopg2.sql.SQL("GRANT {role} TO {user}").format(
role=psycopg2.sql.Identifier(db_role),
user=psycopg2.sql.Identifier(server_db_user),
),
)
cursor.execute(
psycopg2.sql.SQL("DROP TABLE IF EXISTS {table_schema_name}").format(
table_schema_name=psycopg2.sql.Identifier(
output_table_schema, output_table_name
)
)
)
cursor.execute(
psycopg2.sql.SQL("REVOKE {role} FROM {user}").format(
role=psycopg2.sql.Identifier(db_role),
user=psycopg2.sql.Identifier(server_db_user),
)
)


def _prefix_column(index, column):
Expand Down Expand Up @@ -143,7 +159,7 @@ def _run_query(conn, query_log, page, limit, timeout, output_table):
start_time = time.time()
sql = query_log.sql.rstrip().rstrip(";")
try:
cursor.execute(f"SET statement_timeout = {timeout}")
cursor.execute("SET statement_timeout = %s", (timeout,))

if sql.strip().upper().startswith("EXPLAIN"):
cursor.execute(
Expand All @@ -164,7 +180,11 @@ def _run_query(conn, query_log, page, limit, timeout, output_table):
# It adds a prefix of col_x_ to duplicated column returned from the query and
# these prefixed column names are used to create a table containing the
# query results. The prefixes are removed when the results are returned.
cursor.execute(f"SELECT * FROM ({sql}) sq LIMIT 0")
cursor.execute(
psycopg2.sql.SQL("SELECT * FROM ({user_query}) sq LIMIT 0").format(
user_query=psycopg2.sql.SQL(sql)
)
)
column_names = list(zip(*cursor.description))[0]
duplicated_column_names = set(c for c in column_names if column_names.count(c) > 1)
prefixed_sql_columns = [
Expand All @@ -174,7 +194,14 @@ def _run_query(conn, query_log, page, limit, timeout, output_table):
)
for i, col in enumerate(cursor.description, 1)
]
cursor.execute(f'CREATE TABLE {output_table} ({", ".join(prefixed_sql_columns)})')
cols_formatted = ", ".join(prefixed_sql_columns)
output_table_schema, output_table_name = output_table.split(".")
cursor.execute(
psycopg2.sql.SQL("CREATE TABLE {output_table} ({cols_formatted})").format(
output_table=psycopg2.sql.Identifier(output_table_schema, output_table_name),
cols_formatted=psycopg2.sql.SQL(cols_formatted),
)
)
limit_clause = ""
if limit is not None:
limit_clause = f"LIMIT {limit}"
Expand All @@ -184,9 +211,18 @@ def _run_query(conn, query_log, page, limit, timeout, output_table):
offset = f" OFFSET {(page - 1) * limit}"

cursor.execute(
f"INSERT INTO {output_table} SELECT * FROM ({sql}) sq {limit_clause}{offset}"
psycopg2.sql.SQL(
"INSERT INTO {output_table} SELECT * FROM ({sql}) sq {limit_clause}{offset}"
).format(
output_table=psycopg2.sql.Identifier(output_table_schema, output_table_name),
sql=psycopg2.sql.SQL(sql),
limit_clause=psycopg2.sql.SQL(limit_clause),
offset=psycopg2.sql.SQL(offset),
),
)
cursor.execute(
psycopg2.sql.SQL("SELECT COUNT(*) FROM ({sql}) sq").format(sql=psycopg2.sql.SQL(sql))
)
cursor.execute(f"SELECT COUNT(*) FROM ({sql}) sq")
except psycopg2.errors.QueryCanceled as e: # pylint: disable=no-member
logger.info("Query cancelled: %s", e)
return
Expand Down
93 changes: 71 additions & 22 deletions dataworkspace/dataworkspace/tests/explorer/test_tasks.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
from django.core.serializers.json import DjangoJSONEncoder
from django.test import TestCase
from freezegun import freeze_time
from psycopg2.sql import Identifier, SQL
import pytest
import six

Expand Down Expand Up @@ -80,9 +81,25 @@ def test_cleanup_temporary_query_tables(self, mock_connections, mock_databases_d
cleanup_temporary_query_tables()

expected_calls = [
call("GRANT _user_12b9377c TO postgres"),
call(f"DROP TABLE IF EXISTS _user_12b9377c._data_explorer_tmp_query_{query_log_1.id}"),
call("REVOKE _user_12b9377c FROM postgres"),
call(
SQL("GRANT {user} TO {role}").format(
role=Identifier("postgres"),
user=Identifier("_user_12b9377c"),
),
),
call(
SQL("DROP TABLE IF EXISTS {schema_table}").format(
schema_table=Identifier(
"_user_12b9377c", f"_data_explorer_tmp_query_{query_log_1.id}"
)
)
),
call(
SQL("REVOKE {user} FROM {role}").format(
role=Identifier("postgres"),
user=Identifier("_user_12b9377c"),
),
),
]
mock_cursor.execute.assert_has_calls(expected_calls)

Expand Down Expand Up @@ -123,16 +140,27 @@ def test_submit_query_for_execution(self, mock_connection_settings, mock_schema_
query_log_id = QueryLog.objects.first().id

expected_calls = [
call("SET statement_timeout = 10000"),
call("SELECT * FROM (select * from foo) sq LIMIT 0"),
call("SET statement_timeout = %s", (10000,)),
call(SQL("SELECT * FROM ({query}) sq LIMIT 0").format(query=SQL("select * from foo"))),
call(
f'CREATE TABLE _user_12b9377c._data_explorer_tmp_query_{query_log_id} ("foo" integer, "bar" text)'
SQL("CREATE TABLE {schema_table} ({cols})").format(
schema_table=Identifier(
"_user_12b9377c", f"_data_explorer_tmp_query_{query_log_id}"
),
cols=SQL('"foo" integer, "bar" text'),
),
),
call(
f"INSERT INTO _user_12b9377c._data_explorer_tmp_query_{query_log_id}"
" SELECT * FROM (select * from foo) sq LIMIT 100"
SQL("INSERT INTO {schema_table} SELECT * FROM ({sql}) sq {limit}{offset}").format(
schema_table=Identifier(
"_user_12b9377c", f"_data_explorer_tmp_query_{query_log_id}"
),
sql=SQL("select * from foo"),
limit=SQL("LIMIT 100"),
offset=SQL(""),
),
),
call("SELECT COUNT(*) FROM (select * from foo) sq"),
call(SQL("SELECT COUNT(*) FROM ({sql}) sq").format(sql=SQL("select * from foo"))),
]
self.mock_cursor.execute.assert_has_calls(expected_calls)

Expand All @@ -152,16 +180,29 @@ def test_submit_query_for_execution_with_pagination(
query_log_id = QueryLog.objects.first().id

expected_calls = [
call("SET statement_timeout = 10000"),
call("SELECT * FROM (select * from foo) sq LIMIT 0"),
call("SET statement_timeout = %s", (10000,)),
call(SQL("SELECT * FROM ({query}) sq LIMIT 0").format(query=SQL("select * from foo"))),
call(
f'CREATE TABLE _user_12b9377c._data_explorer_tmp_query_{query_log_id} ("foo" integer, "bar" text)'
SQL("CREATE TABLE {schema_table} ({cols})").format(
schema_table=Identifier(
"_user_12b9377c", f"_data_explorer_tmp_query_{query_log_id}"
),
cols=SQL('"foo" integer, "bar" text'),
)
),
call(
f"INSERT INTO _user_12b9377c._data_explorer_tmp_query_{query_log_id}"
" SELECT * FROM (select * from foo) sq LIMIT 100 OFFSET 100"
SQL(
"INSERT INTO {schema_table} SELECT * FROM ({query}) sq {limit}{offset}"
).format(
schema_table=Identifier(
"_user_12b9377c", f"_data_explorer_tmp_query_{query_log_id}"
),
query=SQL("select * from foo"),
limit=SQL("LIMIT 100"),
offset=SQL(" OFFSET 100"),
)
),
call("SELECT COUNT(*) FROM (select * from foo) sq"),
call(SQL("SELECT COUNT(*) FROM ({query}) sq").format(query=SQL("select * from foo"))),
]
self.mock_cursor.execute.assert_has_calls(expected_calls)

Expand All @@ -181,17 +222,25 @@ def test_submit_query_for_execution_with_duplicated_column_names(
query_log_id = QueryLog.objects.first().id

expected_calls = [
call("SET statement_timeout = 10000"),
call("SELECT * FROM (select * from foo) sq LIMIT 0"),
call("SET statement_timeout = %s", (10000,)),
call(SQL("SELECT * FROM ({query}) sq LIMIT 0").format(query=SQL("select * from foo"))),
call(
f"CREATE TABLE _user_12b9377c._data_explorer_tmp_query_{query_log_id}"
' ("col_1_bar" integer, "col_2_bar" text)'
SQL("CREATE TABLE {table} ({cols})").format(
table=Identifier("_user_12b9377c", f"_data_explorer_tmp_query_{query_log_id}"),
cols=SQL('"col_1_bar" integer, "col_2_bar" text'),
)
),
call(
f"INSERT INTO _user_12b9377c._data_explorer_tmp_query_{query_log_id}"
" SELECT * FROM (select * from foo) sq LIMIT 100"
SQL("INSERT INTO {schema_table} SELECT * FROM ({sql}) sq {limit}{offset}").format(
schema_table=Identifier(
"_user_12b9377c", f"_data_explorer_tmp_query_{query_log_id}"
),
sql=SQL("select * from foo"),
limit=SQL("LIMIT 100"),
offset=SQL(""),
)
),
call("SELECT COUNT(*) FROM (select * from foo) sq"),
call(SQL("SELECT COUNT(*) FROM ({sql}) sq").format(sql=SQL("select * from foo"))),
]
self.mock_cursor.execute.assert_has_calls(expected_calls)

Expand Down

0 comments on commit 47c96de

Please sign in to comment.