Skip to content

Commit

Permalink
Adds a way to work with SQLite in a thread safe manner
Browse files Browse the repository at this point in the history
  • Loading branch information
crsib committed Feb 16, 2024
1 parent 5772219 commit deb96d6
Show file tree
Hide file tree
Showing 5 changed files with 218 additions and 18 deletions.
2 changes: 2 additions & 0 deletions libraries/lib-sqlite-helpers/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,8 @@ set( SOURCES
sqlite/Function.h
sqlite/Result.cpp
sqlite/Result.h
sqlite/SafeConnection.cpp
sqlite/SafeConnection.h
sqlite/SQLiteUtils.cpp
sqlite/SQLiteUtils.h
sqlite/Statement.cpp
Expand Down
8 changes: 0 additions & 8 deletions libraries/lib-sqlite-helpers/sqlite/Connection.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -231,8 +231,6 @@ Error Connection::TransactionHandler(
{
transaction.mCommitted = true;

std::lock_guard<std::mutex> lock(connection.mPendingTransactionsMutex);

connection.mPendingTransactions.erase(
std::remove_if(
connection.mPendingTransactions.begin(),
Expand All @@ -246,7 +244,6 @@ Error Connection::TransactionHandler(

if (operation == Transaction::TransactionOperation::BeginOp)
{
std::lock_guard<std::mutex> lock(connection.mPendingTransactionsMutex);
connection.mPendingTransactions.push_back(&transaction);
}

Expand Down Expand Up @@ -304,11 +301,6 @@ Error Connection::Close(bool force) noexcept
// rollback it.
std::vector<Transaction*> pendingTransactions;

{
std::lock_guard<std::mutex> lock(mPendingTransactionsMutex);
std::swap(mPendingTransactions, pendingTransactions);
}

for (auto* transaction : pendingTransactions)
if (auto err = transaction->Abort(); !force && err.IsError())
return err;
Expand Down
21 changes: 11 additions & 10 deletions libraries/lib-sqlite-helpers/sqlite/Connection.h
Original file line number Diff line number Diff line change
Expand Up @@ -11,11 +11,11 @@
#include <string>
#include <vector>

#include "Result.h"
#include "Transaction.h"
#include "Statement.h"
#include "Blob.h"
#include "Function.h"
#include "Result.h"
#include "Statement.h"
#include "Transaction.h"

struct sqlite3;

Expand Down Expand Up @@ -50,7 +50,7 @@ class SQLITE_HELPERS_API Connection final
const Connection& connection, OpenMode mode = OpenMode::ReadWriteCreate,
ThreadMode threadMode = ThreadMode::Serialized);

static Result<Connection> Reopen (
static Result<Connection> Reopen(
sqlite3* connection, OpenMode mode = OpenMode::ReadWriteCreate,
ThreadMode threadMode = ThreadMode::Serialized);

Expand All @@ -75,28 +75,29 @@ class SQLITE_HELPERS_API Connection final
Result<Statement> CreateStatement(std::string_view sql) const;

template<typename ScalarFunctionType>
ScalarFunction CreateScalarFunction(std::string name, ScalarFunctionType function)
ScalarFunction
CreateScalarFunction(std::string name, ScalarFunctionType function)
{
return ScalarFunction { mConnection, std::move(name),
std::move(function) };
}

template<typename StepFunctionType, typename FinalFunctionType>
AggregateFunction CreateAggregateFunction (
AggregateFunction CreateAggregateFunction(
std::string name, StepFunctionType stepFunction,
FinalFunctionType finalFunction)
{
return AggregateFunction { mConnection, std::move(name),
std::move(stepFunction),
std::move(finalFunction) };
std::move(stepFunction),
std::move(finalFunction) };
}

Result<Blob> OpenBlob(
const std::string& tableName, const std::string& columnName,
int64_t rowId, bool readOnly,
const std::string& databaseName = "main") const;

explicit operator sqlite3* () const noexcept;
explicit operator sqlite3*() const noexcept;
explicit operator bool() const noexcept;

Error Close(bool force) noexcept;
Expand All @@ -112,10 +113,10 @@ class SQLITE_HELPERS_API Connection final

sqlite3* mConnection {};

std::mutex mPendingTransactionsMutex {};
std::vector<Transaction*> mPendingTransactions {};

bool mInDestructor {};
bool mIsOwned {};
};

} // namespace sqlite
127 changes: 127 additions & 0 deletions libraries/lib-sqlite-helpers/sqlite/SafeConnection.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,127 @@
/*
* SPDX-License-Identifier: GPL-2.0-or-later
* SPDX-FileName: SafeConnection.cpp
* SPDX-FileContributor: Dmitry Vedenko
*/

#include "SafeConnection.h"

namespace sqlite
{
SafeConnection::SafeConnection(Tag, Connection connection)
: mConnection(std::move(connection))
{
}

std::shared_ptr<SafeConnection> SafeConnection::Open(
std::string_view path, OpenMode mode, ThreadMode threadMode,
Error* openError)
{
auto connection = Connection::Open(path, mode, threadMode);

if (!connection)
{
if (openError)
*openError = connection.GetError();

return {};
}

return std::make_shared<SafeConnection>(Tag {}, std::move(*connection));
}

std::shared_ptr<SafeConnection> SafeConnection::Reopen(
const Connection& prevConnection, OpenMode mode, ThreadMode threadMode,
Error* openError)
{
auto connection = Connection::Reopen(prevConnection, mode, threadMode);

if (!connection)
{
if (openError)
*openError = connection.GetError();

return {};
}

return std::make_shared<SafeConnection>(Tag {}, std::move(*connection));
}

std::shared_ptr<SafeConnection> SafeConnection::Reopen(
sqlite3* prevConnection, OpenMode mode, ThreadMode threadMode,
Error* openError)
{
auto connection = Connection::Reopen(prevConnection, mode, threadMode);

if (!connection)
{
if (openError)
*openError = connection.GetError();

return {};
}

return std::make_shared<SafeConnection>(Tag {}, std::move(*connection));
}

std::shared_ptr<SafeConnection> SafeConnection::Reopen(
SafeConnection& prevConnection, OpenMode mode, ThreadMode threadMode,
Error* openError)
{
auto connection =
Connection::Reopen(prevConnection.mConnection, mode, threadMode);

if (!connection)
{
if (openError)
*openError = connection.GetError();

return {};
}

return std::make_shared<SafeConnection>(Tag {}, std::move(*connection));
}

SafeConnection::Lock SafeConnection::Acquire() noexcept
{
return Lock { shared_from_this() };
}

SafeConnection::Lock::Lock(std::shared_ptr<SafeConnection> connection)
: mSafeConnection(std::move(connection))
{
if (mSafeConnection)
mLock = std::unique_lock { mSafeConnection->mConnectionMutex };
}

Connection* SafeConnection::Lock::operator->() noexcept
{
return &mSafeConnection->mConnection;
}

const Connection* SafeConnection::Lock::operator->() const noexcept
{
return &mSafeConnection->mConnection;
}

Connection& SafeConnection::Lock::operator*() noexcept
{
return mSafeConnection->mConnection;
}

const Connection& SafeConnection::Lock::operator*() const noexcept
{
return mSafeConnection->mConnection;
}

SafeConnection::Lock::operator bool() const noexcept
{
return IsValid();
}

bool SafeConnection::Lock::IsValid() const noexcept
{
return mSafeConnection != nullptr;
}

} // namespace sqlite
78 changes: 78 additions & 0 deletions libraries/lib-sqlite-helpers/sqlite/SafeConnection.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,78 @@
/*
* SPDX-License-Identifier: GPL-2.0-or-later
* SPDX-FileName: SafeConnection.h
* SPDX-FileContributor: Dmitry Vedenko
*/

#pragma once

#include <memory>
#include <mutex>

#include "Connection.h"

namespace sqlite
{
class SQLITE_HELPERS_API SafeConnection final :
public std::enable_shared_from_this<SafeConnection>
{
struct Tag final
{
};

using MutexType = std::recursive_mutex;

public:
SafeConnection(Tag, Connection connection);

static std::shared_ptr<SafeConnection> Open(
std::string_view path, OpenMode mode = OpenMode::ReadWriteCreate,
ThreadMode threadMode = ThreadMode::Serialized,
Error* openError = nullptr);

static std::shared_ptr<SafeConnection> Reopen(
const Connection& connection, OpenMode mode = OpenMode::ReadWriteCreate,
ThreadMode threadMode = ThreadMode::Serialized,
Error* openError = nullptr);

static std::shared_ptr<SafeConnection> Reopen(
sqlite3* connection, OpenMode mode = OpenMode::ReadWriteCreate,
ThreadMode threadMode = ThreadMode::Serialized,
Error* openError = nullptr);

static std::shared_ptr<SafeConnection> Reopen(
SafeConnection& connection, OpenMode mode = OpenMode::ReadWriteCreate,
ThreadMode threadMode = ThreadMode::Serialized,
Error* openError = nullptr);

struct SQLITE_HELPERS_API Lock final
{
explicit Lock(std::shared_ptr<SafeConnection> connection);
~Lock() = default;

Lock(const Lock&) = delete;
Lock& operator=(const Lock&) = delete;
Lock(Lock&&) = default;
Lock& operator=(Lock&&) = default;

Connection* operator->() noexcept;
const Connection* operator->() const noexcept;

Connection& operator*() noexcept;
const Connection& operator*() const noexcept;

explicit operator bool() const noexcept;
bool IsValid() const noexcept;

private:
std::shared_ptr<SafeConnection> mSafeConnection;
std::unique_lock<MutexType> mLock;
};

Lock Acquire() noexcept;

private:
Connection mConnection;
MutexType mConnectionMutex;
};
} // namespace sqlite

0 comments on commit deb96d6

Please sign in to comment.