diff --git a/CHANGES.rst b/CHANGES.rst
index 910e2b06..fc819a87 100644
--- a/CHANGES.rst
+++ b/CHANGES.rst
@@ -1,6 +1,10 @@
 Changelog
 =========
 
+0.2.3 (2021-12-22)
+------------------
+* Fix bug in multithreaded use of sqlite (`#210 `_ thanks `@rohits144 `_ for the report)
+
 0.2.2 (2021-11-05)
 ------------------
 * Fix readthedocs build
diff --git a/opentaxii/_version.py b/opentaxii/_version.py
index 8cdfc355..4af68473 100644
--- a/opentaxii/_version.py
+++ b/opentaxii/_version.py
@@ -3,4 +3,4 @@
 This module defines the package version for use in __init__.py and setup.py.
 """
 
-__version__ = '0.2.2'
+__version__ = '0.2.3'
diff --git a/opentaxii/sqldb_helper.py b/opentaxii/sqldb_helper.py
index ff6ff612..90d0c02a 100644
--- a/opentaxii/sqldb_helper.py
+++ b/opentaxii/sqldb_helper.py
@@ -24,13 +24,7 @@ class SQLAlchemyDB(object):
     '''
 
     def __init__(self, db_connection, base_model, session_options=None, **kwargs):
-
-        if isinstance(db_connection, str):
-            self.engine = engine.create_engine(db_connection, convert_unicode=True, **kwargs)
-            self.connection = self.engine.connect()
-        else:
-            self.connection = db_connection
-
+        self.engine = engine.create_engine(db_connection, convert_unicode=True, **kwargs)
         self.Query = orm.Query
         self.session = self.create_scoped_session(session_options)
         self.Model = self.extend_base_model(base_model)
@@ -57,10 +51,10 @@ def create_scoped_session(self, options=None):
             self.create_session(options), scopefunc=scopefunc)
 
     def create_session(self, options):
-        return orm.sessionmaker(bind=self.connection, **options)
+        return orm.sessionmaker(bind=self.engine, **options)
 
     def create_all_tables(self):
-        self.metadata.create_all(bind=self.connection)
+        self.metadata.create_all(bind=self.engine)
 
     def init_app(self, app):
         @app.teardown_appcontext
diff --git a/opentaxii/utils.py b/opentaxii/utils.py
index bbf2ae75..219a4fe6 100644
--- a/opentaxii/utils.py
+++ b/opentaxii/utils.py
@@ -1,17 +1,16 @@
-import sys
-import logging
-import structlog
-import importlib
 import base64
 import binascii
+import importlib
+import logging
+import sys
 
+import structlog
 from six.moves import urllib
 
 from .entities import Account
-from .taxii.entities import (
-    CollectionEntity, deserialize_content_bindings)
-from .taxii.converters import dict_to_service_entity
 from .exceptions import InvalidAuthHeader
+from .taxii.converters import dict_to_service_entity
+from .taxii.entities import CollectionEntity, deserialize_content_bindings
 
 log = structlog.getLogger(__name__)
 
@@ -254,7 +253,7 @@ def sync_collections(server, collections, force_deletion=False):
         else:
             collection = existing_by_name[name]
             collection.available = False
-            manager.update_collection(cobj)
+            manager.update_collection(collection)
             disabled_counter += 1
             log.info("sync_collections.disabled", name=name)
     log.info(
diff --git a/tests/conftest.py b/tests/conftest.py
index 30aaa8f6..e4a24d81 100644
--- a/tests/conftest.py
+++ b/tests/conftest.py
@@ -1,4 +1,5 @@
 import os
+from tempfile import mkstemp
 
 import pytest
 from opentaxii.config import ServerConfig
@@ -7,7 +8,6 @@
 from opentaxii.server import TAXIIServer
 from opentaxii.taxii.converters import dict_to_service_entity
 from opentaxii.utils import configure_logging
-from sqlalchemy import create_engine, event
 
 from fixtures import DOMAIN, SERVICES
 
@@ -16,44 +16,27 @@
 
     @pytest.fixture(scope="session")
    def dbconn():
-        yield "sqlite://"
-
+        filehandle, filename = mkstemp(suffix='.db')
+        os.close(filehandle)
+        try:
+            yield f"sqlite:///{filename}"
+        finally:
+            try:
+                os.remove(filename)
+            except FileNotFoundError:
+                pass
 
 elif DBTYPE in ("mysql", "mariadb"):
     import MySQLdb
-    from sqlalchemy.orm import sessionmaker
-
-    Session = sessionmaker()
 
     @pytest.fixture(scope="session")
     def dbconn():
         # drop db if exists to provide clean state at beginning
-        dbname = "test"
         if DBTYPE == "mysql":
             port = 3306
         elif DBTYPE == "mariadb":
             port = 3307
-        connection_kwargs = {
-            "host": "127.0.0.1",
-            "user": "root",
-            "passwd": "",
-            "port": port,
-        }
-        mysql_conn: MySQLdb.Connection = MySQLdb.connect(**connection_kwargs)
-        mysql_conn.query(f"DROP DATABASE IF EXISTS {dbname}")
-        mysql_conn.query(
-            f"CREATE DATABASE {dbname} "
-            f"DEFAULT CHARACTER SET utf8 "
-            f"DEFAULT COLLATE utf8_general_ci"
-        )
-        mysql_conn.close()
-        engine = create_engine(
-            f"mysql+mysqldb://root:@127.0.0.1:{port}/test?charset=utf8",
-            convert_unicode=True,
-        )
-        connection = engine.connect()
-        yield connection
-        connection.close()
+        yield f"mysql+mysqldb://root:@127.0.0.1:{port}/test?charset=utf8"
 
 
 elif DBTYPE == "postgres":
@@ -64,37 +47,10 @@ def dbconn():
         compat.register()
 
     import psycopg2
-    from sqlalchemy.orm import sessionmaker
-
-    Session = sessionmaker()
 
     @pytest.fixture(scope="session")
     def dbconn():
-        # drop public schema to provide clean state at beginning
-        dbname = "test"
-        pg_conn = psycopg2.connect(
-            dbname=dbname,
-            user="test",
-            password="test",
-            host="127.0.0.1",
-            port="5432",
-        )
-        pg_cur = pg_conn.cursor()
-        pg_cur.execute(
-            "DROP SCHEMA public CASCADE;"
-            "CREATE SCHEMA public;"
-            "GRANT ALL ON SCHEMA public TO test;"
-            "GRANT ALL ON SCHEMA public TO public;"
-        )
-        pg_cur.close()
-        pg_conn.close()
-        engine = create_engine(
-            f"postgresql+psycopg2://test:test@127.0.0.1:5432/{dbname}",
-            convert_unicode=True,
-        )
-        connection = engine.connect()
-        yield connection
-        connection.close()
+        yield "postgresql+psycopg2://test:test@127.0.0.1:5432/test"
 
 
 else:
@@ -137,32 +93,59 @@ def anonymous_user():
     release_context()
 
 
-@pytest.fixture()
-def app(dbconn):
-    if DBTYPE != "sqlite":
-        # run non-sqlite tests in nested transaction/savepoint setup to ensure atomic tests
-        transaction = dbconn.begin()
-        session = Session(bind=dbconn)
-        session.begin_nested()
-
-        @event.listens_for(session, "after_transaction_end")
-        def restart_savepoint(session, transaction):
-            if transaction.nested and not transaction._parent.nested:
-
-                # ensure that state is expired the way
-                # session.commit() at the top level normally does
-                # (optional step)
-                session.expire_all()
+def clean_db(dbconn):
+    # drop and recreate db to provide clean state at beginning
+    if DBTYPE == "sqlite":
+        filename = dbconn[len("sqlite:///"):]
+        os.remove(filename)
+    elif DBTYPE == "postgres":
+        with psycopg2.connect(
+            dbname="test",
+            user="test",
+            password="test",
+            host="127.0.0.1",
+            port="5432",
+        ) as pg_conn:
+            with pg_conn.cursor() as pg_cur:
+                pg_cur.execute(
+                    "DROP SCHEMA public CASCADE;"
+                    "CREATE SCHEMA public;"
+                    "GRANT ALL ON SCHEMA public TO test;"
+                    "GRANT ALL ON SCHEMA public TO public;"
+                )
+        pg_conn.close()  # pypy + psycopg2cffi needs this, doesn't close conn otherwise
+    elif DBTYPE in ("mysql", "mariadb"):
+        dbname = "test"
+        if DBTYPE == "mysql":
+            port = 3306
+        elif DBTYPE == "mariadb":
+            port = 3307
+        connection_kwargs = {
+            "host": "127.0.0.1",
+            "user": "root",
+            "passwd": "",
+            "port": port,
+        }
+        with MySQLdb.connect(**connection_kwargs) as mysql_conn:
+            mysql_conn.query(f"DROP DATABASE IF EXISTS {dbname}")
+            mysql_conn.query(
+                f"CREATE DATABASE {dbname} "
+                f"DEFAULT CHARACTER SET utf8 "
+                f"DEFAULT COLLATE utf8_general_ci"
+            )
 
-                session.begin_nested()
 
-    context.server = TAXIIServer(prepare_test_config(dbconn))
+@pytest.fixture()
+def app(dbconn):
+    clean_db(dbconn)
+    server = TAXIIServer(prepare_test_config(dbconn))
+    context.server = server
     app = create_app(context.server)
     app.config["TESTING"] = True
     yield app
-    if DBTYPE != "sqlite":
-        session.close()
-        transaction.rollback()
+    for part in [server.auth, server.persistence]:
+        part.api.db.session.commit()
+        part.api.db.engine.dispose()
 
 
 @pytest.fixture()
diff --git a/tests/test_server.py b/tests/test_server.py
index 2f787a44..22f5bf17 100644
--- a/tests/test_server.py
+++ b/tests/test_server.py
@@ -1,5 +1,6 @@
-import pytest
+import concurrent.futures
 
+import pytest
 from opentaxii.taxii.converters import dict_to_service_entity
 
 from fixtures import DOMAIN
@@ -53,3 +54,15 @@ def test_services_configured(server):
     assert len(with_paths) == len(INTERNAL_SERVICES)
     assert all([
         p.address.startswith(DOMAIN) for p in with_paths])
+
+
+def test_multithreaded_access(server):
+
+    def testfunc():
+        server.get_services()
+        server.persistence.api.db.session.commit()
+
+    with concurrent.futures.ThreadPoolExecutor(max_workers=2) as executor:
+        results = [executor.submit(testfunc) for _ in range(2)]
+        for result in concurrent.futures.as_completed(results):
+            assert not result.exception(timeout=5)
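
The core of the change is that SQLAlchemyDB no longer binds its sessions to a single shared Connection (which is not safe to share across threads, and fails with SQLite); it now always creates an Engine from the connection string and binds the scoped session factory to that Engine, so each thread checks its own connection out of the pool. Below is a minimal standalone sketch of that engine-bound, scoped-session pattern, using plain SQLAlchemy with an illustrative example.db URL and worker function rather than OpenTAXII's actual helper classes:

    import concurrent.futures

    from sqlalchemy import create_engine, text
    from sqlalchemy.orm import scoped_session, sessionmaker

    # Bind the session factory to the Engine (as the patched sqldb_helper does),
    # not to a single Connection object shared by every caller.
    engine = create_engine("sqlite:///example.db")  # illustrative URL
    Session = scoped_session(sessionmaker(bind=engine))


    def worker():
        # scoped_session hands each thread its own Session, which checks its
        # own connection out of the engine's pool instead of reusing one
        # shared SQLite connection across threads.
        session = Session()
        session.execute(text("SELECT 1"))
        session.commit()
        Session.remove()


    with concurrent.futures.ThreadPoolExecutor(max_workers=2) as executor:
        futures = [executor.submit(worker) for _ in range(2)]
        for future in concurrent.futures.as_completed(futures):
            assert future.exception() is None

The new test_multithreaded_access test exercises the same behaviour end-to-end through the server's persistence API.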