Skip to content

Commit

Permalink
Merge pull request #215 from eclecticiq/sqlite-multithreading
Browse files Browse the repository at this point in the history
Fix bug in multithreaded use of sqlite
  • Loading branch information
erwin-eiq authored Dec 27, 2021
2 parents 6282305 + f5f49a2 commit 13ddcbd
Show file tree
Hide file tree
Showing 6 changed files with 89 additions and 96 deletions.
4 changes: 4 additions & 0 deletions CHANGES.rst
Original file line number Diff line number Diff line change
@@ -1,6 +1,10 @@
Changelog
=========

0.2.3 (2021-12-22)
------------------
* Fix bug in multithreaded use of sqlite (`#210 <https://github.com/eclecticiq/OpenTAXII/issues/114>`_ thanks `@rohits144 <https://github.com/rohits144>`_ for the report)

0.2.2 (2021-11-05)
------------------
* Fix readthedocs build
Expand Down
2 changes: 1 addition & 1 deletion opentaxii/_version.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,4 +3,4 @@
This module defines the package version for use in __init__.py and setup.py.
"""

__version__ = '0.2.2'
__version__ = '0.2.3'
12 changes: 3 additions & 9 deletions opentaxii/sqldb_helper.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,13 +24,7 @@ class SQLAlchemyDB(object):
'''

def __init__(self, db_connection, base_model, session_options=None, **kwargs):

if isinstance(db_connection, str):
self.engine = engine.create_engine(db_connection, convert_unicode=True, **kwargs)
self.connection = self.engine.connect()
else:
self.connection = db_connection

self.engine = engine.create_engine(db_connection, convert_unicode=True, **kwargs)
self.Query = orm.Query
self.session = self.create_scoped_session(session_options)
self.Model = self.extend_base_model(base_model)
Expand All @@ -57,10 +51,10 @@ def create_scoped_session(self, options=None):
self.create_session(options), scopefunc=scopefunc)

def create_session(self, options):
return orm.sessionmaker(bind=self.connection, **options)
return orm.sessionmaker(bind=self.engine, **options)

def create_all_tables(self):
self.metadata.create_all(bind=self.connection)
self.metadata.create_all(bind=self.engine)

def init_app(self, app):
@app.teardown_appcontext
Expand Down
15 changes: 7 additions & 8 deletions opentaxii/utils.py
Original file line number Diff line number Diff line change
@@ -1,17 +1,16 @@
import sys
import logging
import structlog
import importlib
import base64
import binascii
import importlib
import logging
import sys

import structlog
from six.moves import urllib

from .entities import Account
from .taxii.entities import (
CollectionEntity, deserialize_content_bindings)
from .taxii.converters import dict_to_service_entity
from .exceptions import InvalidAuthHeader
from .taxii.converters import dict_to_service_entity
from .taxii.entities import CollectionEntity, deserialize_content_bindings

log = structlog.getLogger(__name__)

Expand Down Expand Up @@ -254,7 +253,7 @@ def sync_collections(server, collections, force_deletion=False):
else:
collection = existing_by_name[name]
collection.available = False
manager.update_collection(cobj)
manager.update_collection(collection)
disabled_counter += 1
log.info("sync_collections.disabled", name=name)
log.info(
Expand Down
137 changes: 60 additions & 77 deletions tests/conftest.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import os
from tempfile import mkstemp

import pytest
from opentaxii.config import ServerConfig
Expand All @@ -7,7 +8,6 @@
from opentaxii.server import TAXIIServer
from opentaxii.taxii.converters import dict_to_service_entity
from opentaxii.utils import configure_logging
from sqlalchemy import create_engine, event

from fixtures import DOMAIN, SERVICES

Expand All @@ -16,44 +16,27 @@

@pytest.fixture(scope="session")
def dbconn():
yield "sqlite://"

filehandle, filename = mkstemp(suffix='.db')
os.close(filehandle)
try:
yield f"sqlite:///{filename}"
finally:
try:
os.remove(filename)
except FileNotFoundError:
pass

elif DBTYPE in ("mysql", "mariadb"):
import MySQLdb
from sqlalchemy.orm import sessionmaker

Session = sessionmaker()

@pytest.fixture(scope="session")
def dbconn():
# drop db if exists to provide clean state at beginning
dbname = "test"
if DBTYPE == "mysql":
port = 3306
elif DBTYPE == "mariadb":
port = 3307
connection_kwargs = {
"host": "127.0.0.1",
"user": "root",
"passwd": "",
"port": port,
}
mysql_conn: MySQLdb.Connection = MySQLdb.connect(**connection_kwargs)
mysql_conn.query(f"DROP DATABASE IF EXISTS {dbname}")
mysql_conn.query(
f"CREATE DATABASE {dbname} "
f"DEFAULT CHARACTER SET utf8 "
f"DEFAULT COLLATE utf8_general_ci"
)
mysql_conn.close()
engine = create_engine(
f"mysql+mysqldb://root:@127.0.0.1:{port}/test?charset=utf8",
convert_unicode=True,
)
connection = engine.connect()
yield connection
connection.close()
yield f"mysql+mysqldb://root:@127.0.0.1:{port}/test?charset=utf8"


elif DBTYPE == "postgres":
Expand All @@ -64,37 +47,10 @@ def dbconn():

compat.register()
import psycopg2
from sqlalchemy.orm import sessionmaker

Session = sessionmaker()

@pytest.fixture(scope="session")
def dbconn():
# drop public schema to provide clean state at beginning
dbname = "test"
pg_conn = psycopg2.connect(
dbname=dbname,
user="test",
password="test",
host="127.0.0.1",
port="5432",
)
pg_cur = pg_conn.cursor()
pg_cur.execute(
"DROP SCHEMA public CASCADE;"
"CREATE SCHEMA public;"
"GRANT ALL ON SCHEMA public TO test;"
"GRANT ALL ON SCHEMA public TO public;"
)
pg_cur.close()
pg_conn.close()
engine = create_engine(
f"postgresql+psycopg2://test:test@127.0.0.1:5432/{dbname}",
convert_unicode=True,
)
connection = engine.connect()
yield connection
connection.close()
yield "postgresql+psycopg2://test:test@127.0.0.1:5432/test"


else:
Expand Down Expand Up @@ -137,32 +93,59 @@ def anonymous_user():
release_context()


@pytest.fixture()
def app(dbconn):
if DBTYPE != "sqlite":
# run non-sqlite tests in nested transaction/savepoint setup to ensure atomic tests
transaction = dbconn.begin()
session = Session(bind=dbconn)
session.begin_nested()

@event.listens_for(session, "after_transaction_end")
def restart_savepoint(session, transaction):
if transaction.nested and not transaction._parent.nested:

# ensure that state is expired the way
# session.commit() at the top level normally does
# (optional step)
session.expire_all()
def clean_db(dbconn):
# drop and recreate db to provide clean state at beginning
if DBTYPE == "sqlite":
filename = dbconn[len("sqlite:///"):]
os.remove(filename)
elif DBTYPE == "postgres":
with psycopg2.connect(
dbname="test",
user="test",
password="test",
host="127.0.0.1",
port="5432",
) as pg_conn:
with pg_conn.cursor() as pg_cur:
pg_cur.execute(
"DROP SCHEMA public CASCADE;"
"CREATE SCHEMA public;"
"GRANT ALL ON SCHEMA public TO test;"
"GRANT ALL ON SCHEMA public TO public;"
)
pg_conn.close() # pypy + psycopg2cffi needs this, doesn't close conn otherwise
elif DBTYPE in ("mysql", "mariadb"):
dbname = "test"
if DBTYPE == "mysql":
port = 3306
elif DBTYPE == "mariadb":
port = 3307
connection_kwargs = {
"host": "127.0.0.1",
"user": "root",
"passwd": "",
"port": port,
}
with MySQLdb.connect(**connection_kwargs) as mysql_conn:
mysql_conn.query(f"DROP DATABASE IF EXISTS {dbname}")
mysql_conn.query(
f"CREATE DATABASE {dbname} "
f"DEFAULT CHARACTER SET utf8 "
f"DEFAULT COLLATE utf8_general_ci"
)

session.begin_nested()

context.server = TAXIIServer(prepare_test_config(dbconn))
@pytest.fixture()
def app(dbconn):
clean_db(dbconn)
server = TAXIIServer(prepare_test_config(dbconn))
context.server = server
app = create_app(context.server)
app.config["TESTING"] = True
yield app
if DBTYPE != "sqlite":
session.close()
transaction.rollback()
for part in [server.auth, server.persistence]:
part.api.db.session.commit()
part.api.db.engine.dispose()


@pytest.fixture()
Expand Down
15 changes: 14 additions & 1 deletion tests/test_server.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import pytest
import concurrent.futures

import pytest
from opentaxii.taxii.converters import dict_to_service_entity

from fixtures import DOMAIN
Expand Down Expand Up @@ -53,3 +54,15 @@ def test_services_configured(server):
assert len(with_paths) == len(INTERNAL_SERVICES)
assert all([
p.address.startswith(DOMAIN) for p in with_paths])


def test_multithreaded_access(server):

def testfunc():
server.get_services()
server.persistence.api.db.session.commit()

with concurrent.futures.ThreadPoolExecutor(max_workers=2) as executor:
results = [executor.submit(testfunc) for _ in range(2)]
for result in concurrent.futures.as_completed(results):
assert not result.exception(timeout=5)

0 comments on commit 13ddcbd

Please sign in to comment.