diff --git a/cve_bin_tool/cvedb.py b/cve_bin_tool/cvedb.py index 0c529b27df..e4488a6822 100644 --- a/cve_bin_tool/cvedb.py +++ b/cve_bin_tool/cvedb.py @@ -43,6 +43,15 @@ DBNAME = "cve.db" OLD_CACHE_DIR = Path("~") / ".cache" / "cvedb" +# Define a list of valid table names +VALID_TABLE_NAMES = [ + "cve_severity", + "cve_range", + "cve_exploited", + "cve_metrics", + "metrics", +] + class CVEDB: """ @@ -253,8 +262,17 @@ def latest_schema( self.LOGGER.debug("Check database is using latest schema") cursor = self.db_open_and_get_cursor() - schema_check = f"SELECT * FROM {table_name} WHERE 1=0" # nosec - result = cursor.execute(schema_check) + + # Use the validation function to check the table name + if table_name not in VALID_TABLE_NAMES: + raise ValueError(f"Invalid table name: {table_name}") + + # Construct the query using the table name directly + # Since table_name is predefined and not from user input, there's no risk of SQL injection + query = f"SELECT * FROM {table_name} WHERE 1=0" # nosec + cursor.execute(query) + result = cursor.fetchall() + schema_latest = False if not cursor: @@ -373,6 +391,9 @@ def table_schemas(self): metrics_table, ) + def is_valid_table_name(self, table_name: str) -> bool: + return table_name in VALID_TABLE_NAMES + def init_database(self) -> None: """Initialize db tables used for storing cve/version data.""" @@ -862,8 +883,12 @@ def dict_factory(self, cursor, row): def get_all_records_in_table(self, table_name): """Return JSON of all records in a database table.""" + # Use the validation function to check the table name + if not self.is_valid_table_name(table_name): + raise ValueError(f"Invalid table name: {table_name}") cursor = self.db_open_and_get_cursor() cursor.row_factory = self.dict_factory + # Can't use parameterized query because they don't work on table names cursor.execute(f"SELECT * FROM '{table_name}' ") # nosec # fetchall as result results = cursor.fetchall()