From f9082573da7ce2a50d5d75c8456dbd2d352d4a08 Mon Sep 17 00:00:00 2001 From: Dmitry Vedenko Date: Fri, 16 Feb 2024 12:48:32 +0300 Subject: [PATCH] Use SafeConnection instead of Connection --- .../sync/CloudProjectsDatabase.cpp | 49 ++-- .../sync/CloudProjectsDatabase.h | 8 +- .../sync/RemoteProjectSnapshot.cpp | 268 ++++++++---------- .../sync/RemoteProjectSnapshot.h | 2 - 4 files changed, 153 insertions(+), 174 deletions(-) diff --git a/libraries/lib-cloud-audiocom/sync/CloudProjectsDatabase.cpp b/libraries/lib-cloud-audiocom/sync/CloudProjectsDatabase.cpp index 16a4bb3616b9..3ec87873e3d3 100644 --- a/libraries/lib-cloud-audiocom/sync/CloudProjectsDatabase.cpp +++ b/libraries/lib-cloud-audiocom/sync/CloudProjectsDatabase.cpp @@ -73,17 +73,18 @@ CloudProjectsDatabase& CloudProjectsDatabase::Get() bool CloudProjectsDatabase::IsOpen() const { - return mConnection.IsOpen(); + return !!mConnection; } -sqlite::Connection& CloudProjectsDatabase::GetConnection() +sqlite::SafeConnection::Lock CloudProjectsDatabase::GetConnection() { if (!mConnection) OpenConnection(); - return mConnection; + // It is safe to call lock on a null connection + return sqlite::SafeConnection::Lock { mConnection }; } -const sqlite::Connection& CloudProjectsDatabase::GetConnection() const +const sqlite::SafeConnection::Lock CloudProjectsDatabase::GetConnection() const { return const_cast(this)->GetConnection(); } @@ -91,12 +92,12 @@ const sqlite::Connection& CloudProjectsDatabase::GetConnection() const std::optional CloudProjectsDatabase::GetProjectData(const std::string_view& projectId) const { - auto& connection = GetConnection(); + auto connection = GetConnection(); if (!connection) return {}; - auto statement = connection.CreateStatement( + auto statement = connection->CreateStatement( "SELECT project_id, snapshot_id, saves_count, last_audio_preview_save, local_path, last_modified, last_read, sync_status FROM projects WHERE project_id = ? LIMIT 1"); if (!statement) @@ -108,12 +109,12 @@ CloudProjectsDatabase::GetProjectData(const std::string_view& projectId) const std::optional CloudProjectsDatabase::GetProjectDataForPath( const std::string& projectFilePath) const { - auto& connection = GetConnection(); + auto connection = GetConnection(); if (!connection) return {}; - auto statement = connection.CreateStatement( + auto statement = connection->CreateStatement( "SELECT project_id, snapshot_id, saves_count, last_audio_preview_save, local_path, last_modified, last_read, sync_status FROM projects WHERE local_path = ? LIMIT 1"); if (!statement) @@ -125,12 +126,12 @@ std::optional CloudProjectsDatabase::GetProjectDataForPath( bool CloudProjectsDatabase::MarkProjectAsSynced( const std::string_view& projectId, const std::string_view& snapshotId) { - auto& connection = GetConnection(); + auto connection = GetConnection(); if (!connection) return false; - auto statement = connection.CreateStatement ("UPDATE projects SET sync_status = ? WHERE project_id = ? AND snapshot_id = ?"); + auto statement = connection->CreateStatement ("UPDATE projects SET sync_status = ? WHERE project_id = ? AND snapshot_id = ?"); if (!statement) return false; @@ -146,16 +147,16 @@ bool CloudProjectsDatabase::MarkProjectAsSynced( void CloudProjectsDatabase::UpdateProjectBlockList( const std::string_view& projectId, const SampleBlockIDSet& blockSet) { - auto& connection = GetConnection(); + auto connection = GetConnection(); if (!connection) return; - auto inProjectSet = connection.CreateScalarFunction( + auto inProjectSet = connection->CreateScalarFunction( "inProjectSet", [&blockSet](int64_t blockIndex) { return blockSet.find(blockIndex) != blockSet.end(); }); - auto statement = connection.CreateStatement( + auto statement = connection->CreateStatement( "DELETE FROM block_hashes WHERE project_id = ? AND NOT inProjectSet(block_id)"); auto result = statement->Prepare(projectId).Run(); @@ -169,12 +170,12 @@ void CloudProjectsDatabase::UpdateProjectBlockList( std::optional CloudProjectsDatabase::GetBlockHash( const std::string_view& projectId, int64_t blockId) const { - auto& connection = GetConnection(); + auto connection = GetConnection(); if (!connection) return {}; - auto statement = connection.CreateStatement ("SELECT hash FROM block_hashes WHERE project_id = ? AND block_id = ? LIMIT 1"); + auto statement = connection->CreateStatement ("SELECT hash FROM block_hashes WHERE project_id = ? AND block_id = ? LIMIT 1"); if (!statement) return {}; @@ -198,17 +199,17 @@ void CloudProjectsDatabase::UpdateBlockHashes( const std::string_view& projectId, const std::vector>& hashes) { - auto& connection = GetConnection(); + auto connection = GetConnection(); if (!connection) return; const int localVar {}; - auto transaction = connection.BeginTransaction( + auto transaction = connection->BeginTransaction( std::string("UpdateBlockHashes_") + std::to_string(reinterpret_cast(&localVar))); - auto statement = connection.CreateStatement ( + auto statement = connection->CreateStatement ( "INSERT OR REPLACE INTO block_hashes (project_id, block_id, hash) VALUES (?, ?, ?)"); for (const auto& [blockId, hash] : hashes) @@ -220,12 +221,12 @@ void CloudProjectsDatabase::UpdateBlockHashes( bool CloudProjectsDatabase::UpdateProjectData( const DBProjectData& projectData) { - auto& connection = GetConnection(); + auto connection = GetConnection(); if (!connection) return false; - auto statement = connection.CreateStatement ( + auto statement = connection->CreateStatement ( "INSERT OR REPLACE INTO projects (project_id, snapshot_id, saves_count, last_audio_preview_save, local_path, last_modified, last_read, sync_status) VALUES (?, ?, ?, ?, ?, ?, ?, ?)"); if (!statement) @@ -292,14 +293,12 @@ bool CloudProjectsDatabase::OpenConnection() const auto configDir = FileNames::ConfigDir(); const auto configPath = configDir + "/audiocom_sync.db"; - auto db = sqlite::Connection::Open(audacity::ToUTF8(configPath)); + mConnection = sqlite::SafeConnection::Open(audacity::ToUTF8(configPath)); - if (!db) + if (!mConnection) return false; - mConnection = std::move(*db); - - auto result = mConnection.Execute(createTableQuery); + auto result = mConnection->Acquire()->Execute(createTableQuery); if (!result) { diff --git a/libraries/lib-cloud-audiocom/sync/CloudProjectsDatabase.h b/libraries/lib-cloud-audiocom/sync/CloudProjectsDatabase.h index e615d49319ad..58d70a701ab0 100644 --- a/libraries/lib-cloud-audiocom/sync/CloudProjectsDatabase.h +++ b/libraries/lib-cloud-audiocom/sync/CloudProjectsDatabase.h @@ -13,7 +13,7 @@ #include #include -#include "sqlite/Connection.h" +#include "sqlite/SafeConnection.h" #include "CloudSyncUtils.h" @@ -47,8 +47,8 @@ class CloudProjectsDatabase final bool IsOpen() const; - sqlite::Connection& GetConnection(); - const sqlite::Connection& GetConnection() const; + sqlite::SafeConnection::Lock GetConnection(); + const sqlite::SafeConnection::Lock GetConnection() const; std::optional GetProjectData(const std::string_view& projectId) const; std::optional GetProjectDataForPath(const std::string& projectPath) const; @@ -67,7 +67,7 @@ class CloudProjectsDatabase final private: std::optional DoGetProjectData(sqlite::RunResult result) const; bool OpenConnection(); - sqlite::Connection mConnection; + std::shared_ptr mConnection; }; } // namespace cloud::audiocom::sync diff --git a/libraries/lib-cloud-audiocom/sync/RemoteProjectSnapshot.cpp b/libraries/lib-cloud-audiocom/sync/RemoteProjectSnapshot.cpp index 165cb5ecc971..4c5b5286888e 100644 --- a/libraries/lib-cloud-audiocom/sync/RemoteProjectSnapshot.cpp +++ b/libraries/lib-cloud-audiocom/sync/RemoteProjectSnapshot.cpp @@ -38,16 +38,13 @@ RemoteProjectSnapshot::RemoteProjectSnapshot( , mPath { std::move(path) } , mCallback { std::move(callback) } { - { - auto writeLock = std::lock_guard { mDbWriteMutex }; - auto& db = CloudProjectsDatabase::Get().GetConnection(); - auto attachStmt = db.CreateStatement("ATTACH DATABASE ? AS ?"); - auto result = attachStmt->Prepare(mPath, mSnapshotDBName).Run(); - mAttached = result.IsOk(); + auto db = CloudProjectsDatabase::Get().GetConnection(); + auto attachStmt = db->CreateStatement("ATTACH DATABASE ? AS ?"); + auto result = attachStmt->Prepare(mPath, mSnapshotDBName).Run(); + mAttached = result.IsOk(); - if (!mAttached) - return; - } + if (!mAttached) + return; auto knownBlocks = CalculateKnownBlocks(); @@ -66,10 +63,7 @@ RemoteProjectSnapshot::RemoteProjectSnapshot( } } - { - auto writeLock = std::lock_guard { mDbWriteMutex }; - MarkProjectInDB(false); - } + MarkProjectInDB(false); mMissingBlocks = mSnapshotInfo.Blocks.size() - knownBlocks.size(); mRequests.reserve(1 + mMissingBlocks); @@ -106,9 +100,8 @@ RemoteProjectSnapshot::~RemoteProjectSnapshot() if (mAttached) { - auto writeLock = std::lock_guard { mDbWriteMutex }; - auto& db = CloudProjectsDatabase::Get().GetConnection(); - auto detachStmt = db.CreateStatement("DETACH DATABASE ?"); + auto db = CloudProjectsDatabase::Get().GetConnection(); + auto detachStmt = db->CreateStatement("DETACH DATABASE ?"); detachStmt->Prepare(mSnapshotDBName).Run(); } } @@ -159,13 +152,13 @@ RemoteProjectSnapshot::CalculateKnownBlocks() const for (const auto& block : mSnapshotInfo.Blocks) remoteBlocks.insert(ToUpper(block.Hash)); - auto& db = CloudProjectsDatabase::Get().GetConnection(); + auto db = CloudProjectsDatabase::Get().GetConnection(); - auto fn = db.CreateScalarFunction( + auto fn = db->CreateScalarFunction( "inRemoteBlocks", [&remoteBlocks](const std::string& hash) { return remoteBlocks.find(hash) != remoteBlocks.end(); }); - auto statement = db.CreateStatement( + auto statement = db->CreateStatement( "SELECT hash FROM block_hashes WHERE project_id = ? AND inRemoteBlocks(hash) AND block_id IN (SELECT blockid FROM " + mSnapshotDBName + ".sampleblocks)"); @@ -274,76 +267,73 @@ void RemoteProjectSnapshot::OnProjectBlobDownloaded( OnFailure({ ResponseResultCode::UnexpectedResponse, {} }); return; } - { - auto writeLock = std::lock_guard { mDbWriteMutex }; - auto& db = CloudProjectsDatabase::Get().GetConnection(); - auto transaction = db.BeginTransaction("p_" + mProjectInfo.Id); + auto db = CloudProjectsDatabase::Get().GetConnection(); + auto transaction = db->BeginTransaction("p_" + mProjectInfo.Id); - auto updateProjectStatement = db.CreateStatement( - "INSERT INTO " + mSnapshotDBName + - ".project (id, dict, doc) VALUES (1, ?1, ?2) " - "ON CONFLICT(id) DO UPDATE SET dict = ?1, doc = ?2"); + auto updateProjectStatement = db->CreateStatement( + "INSERT INTO " + mSnapshotDBName + + ".project (id, dict, doc) VALUES (1, ?1, ?2) " + "ON CONFLICT(id) DO UPDATE SET dict = ?1, doc = ?2"); - if (!updateProjectStatement) - { - OnFailure({ ResponseResultCode::InternalClientError, - audacity::ToUTF8(updateProjectStatement.GetError() - .GetErrorString() - .Translation()) }); - return; - } + if (!updateProjectStatement) + { + OnFailure({ ResponseResultCode::InternalClientError, + audacity::ToUTF8(updateProjectStatement.GetError() + .GetErrorString() + .Translation()) }); + return; + } - auto& preparedUpdateProjectStatement = updateProjectStatement->Prepare(); + auto& preparedUpdateProjectStatement = updateProjectStatement->Prepare(); - preparedUpdateProjectStatement.Bind( - 1, data.data() + sizeof(uint64_t), dictSize, false); + preparedUpdateProjectStatement.Bind( + 1, data.data() + sizeof(uint64_t), dictSize, false); - preparedUpdateProjectStatement.Bind( - 2, data.data() + sizeof(uint64_t) + dictSize, - data.size() - sizeof(uint64_t) - dictSize, false); + preparedUpdateProjectStatement.Bind( + 2, data.data() + sizeof(uint64_t) + dictSize, + data.size() - sizeof(uint64_t) - dictSize, false); - auto result = preparedUpdateProjectStatement.Run(); + auto result = preparedUpdateProjectStatement.Run(); - if (!result.IsOk()) - { - OnFailure( - { ResponseResultCode::InternalClientError, - audacity::ToUTF8( - result.GetErrors().front().GetErrorString().Translation()) }); + if (!result.IsOk()) + { + OnFailure( + { ResponseResultCode::InternalClientError, + audacity::ToUTF8( + result.GetErrors().front().GetErrorString().Translation()) }); - return; - } + return; + } - auto deleteAutosaveStatement = db.CreateStatement( - "DELETE FROM " + mSnapshotDBName + ".autosave WHERE id = 1"); + auto deleteAutosaveStatement = db->CreateStatement( + "DELETE FROM " + mSnapshotDBName + ".autosave WHERE id = 1"); - if (!deleteAutosaveStatement) - { - OnFailure({ ResponseResultCode::InternalClientError, - audacity::ToUTF8(deleteAutosaveStatement.GetError() - .GetErrorString() - .Translation()) }); - return; - } + if (!deleteAutosaveStatement) + { + OnFailure({ ResponseResultCode::InternalClientError, + audacity::ToUTF8(deleteAutosaveStatement.GetError() + .GetErrorString() + .Translation()) }); + return; + } - result = deleteAutosaveStatement->Prepare().Run(); + result = deleteAutosaveStatement->Prepare().Run(); - if (!result.IsOk()) - { - OnFailure( - { ResponseResultCode::InternalClientError, - audacity::ToUTF8( - result.GetErrors().front().GetErrorString().Translation()) }); - return; - } + if (!result.IsOk()) + { + OnFailure( + { ResponseResultCode::InternalClientError, + audacity::ToUTF8( + result.GetErrors().front().GetErrorString().Translation()) }); + return; + } - if (auto error = transaction.Commit(); error.IsError()) - { - OnFailure({ ResponseResultCode::InternalClientError, - audacity::ToUTF8(error.GetErrorString().Translation()) }); - return; - } + if (auto error = transaction.Commit(); error.IsError()) + { + OnFailure({ ResponseResultCode::InternalClientError, + audacity::ToUTF8(error.GetErrorString().Translation()) }); + return; } mProjectDownloaded.store(true, std::memory_order_release); @@ -365,76 +355,73 @@ void RemoteProjectSnapshot::OnBlockDownloaded( .Translation()) }); return; } - { - auto writeLock = std::lock_guard { mDbWriteMutex }; - auto& db = CloudProjectsDatabase::Get().GetConnection(); - auto transaction = db.BeginTransaction("b_" + blockHash); + auto db = CloudProjectsDatabase::Get().GetConnection(); + auto transaction = db->BeginTransaction("b_" + blockHash); - auto hashesStatement = db.CreateStatement( - "INSERT INTO block_hashes (project_id, block_id, hash) VALUES (?1, ?2, ?3) " - "ON CONFLICT(project_id, block_id) DO UPDATE SET hash = ?3"); + auto hashesStatement = db->CreateStatement( + "INSERT INTO block_hashes (project_id, block_id, hash) VALUES (?1, ?2, ?3) " + "ON CONFLICT(project_id, block_id) DO UPDATE SET hash = ?3"); - auto result = hashesStatement - ->Prepare(mProjectInfo.Id, blockData->BlockId, blockHash) - .Run(); + auto result = + hashesStatement->Prepare(mProjectInfo.Id, blockData->BlockId, blockHash) + .Run(); - if (!result.IsOk()) - { - OnFailure( - { ResponseResultCode::InternalClientError, - audacity::ToUTF8( - result.GetErrors().front().GetErrorString().Translation()) }); - return; - } + if (!result.IsOk()) + { + OnFailure( + { ResponseResultCode::InternalClientError, + audacity::ToUTF8( + result.GetErrors().front().GetErrorString().Translation()) }); + return; + } - auto blockStatement = db.CreateStatement( - "INSERT INTO " + mSnapshotDBName + - ".sampleblocks (blockid, sampleformat, summin, summax, sumrms, summary256, summary64k, samples) VALUES (?1, ?2, ?3, ?4, ?5, ?6, ?7, ?8) " - "ON CONFLICT(blockid) DO UPDATE SET sampleformat = ?2, summin = ?3, summax = ?4, sumrms = ?5, summary256 = ?6, summary64k = ?7, samples = ?8"); + auto blockStatement = db->CreateStatement( + "INSERT INTO " + mSnapshotDBName + + ".sampleblocks (blockid, sampleformat, summin, summax, sumrms, summary256, summary64k, samples) VALUES (?1, ?2, ?3, ?4, ?5, ?6, ?7, ?8) " + "ON CONFLICT(blockid) DO UPDATE SET sampleformat = ?2, summin = ?3, summax = ?4, sumrms = ?5, summary256 = ?6, summary64k = ?7, samples = ?8"); - if (!blockStatement) - { - OnFailure( - { ResponseResultCode::InternalClientError, - audacity::ToUTF8( - blockStatement.GetError().GetErrorString().Translation()) }); - return; - } + if (!blockStatement) + { + OnFailure( + { ResponseResultCode::InternalClientError, + audacity::ToUTF8( + blockStatement.GetError().GetErrorString().Translation()) }); + return; + } - auto& preparedStatement = blockStatement->Prepare(); - - preparedStatement.Bind(1, blockData->BlockId); - preparedStatement.Bind(2, static_cast(blockData->Format)); - preparedStatement.Bind(3, blockData->BlockMinMaxRMS.Min); - preparedStatement.Bind(4, blockData->BlockMinMaxRMS.Max); - preparedStatement.Bind(5, blockData->BlockMinMaxRMS.RMS); - preparedStatement.Bind( - 6, blockData->Summary256.data(), - blockData->Summary256.size() * sizeof(MinMaxRMS), false); - preparedStatement.Bind( - 7, blockData->Summary64k.data(), - blockData->Summary64k.size() * sizeof(MinMaxRMS), false); - preparedStatement.Bind( - 8, blockData->Data.data(), blockData->Data.size(), false); - - result = preparedStatement.Run(); - - if (!result.IsOk()) - { - OnFailure( - { ResponseResultCode::InternalClientError, - audacity::ToUTF8( - result.GetErrors().front().GetErrorString().Translation()) }); - return; - } + auto& preparedStatement = blockStatement->Prepare(); + + preparedStatement.Bind(1, blockData->BlockId); + preparedStatement.Bind(2, static_cast(blockData->Format)); + preparedStatement.Bind(3, blockData->BlockMinMaxRMS.Min); + preparedStatement.Bind(4, blockData->BlockMinMaxRMS.Max); + preparedStatement.Bind(5, blockData->BlockMinMaxRMS.RMS); + preparedStatement.Bind( + 6, blockData->Summary256.data(), + blockData->Summary256.size() * sizeof(MinMaxRMS), false); + preparedStatement.Bind( + 7, blockData->Summary64k.data(), + blockData->Summary64k.size() * sizeof(MinMaxRMS), false); + preparedStatement.Bind( + 8, blockData->Data.data(), blockData->Data.size(), false); + + result = preparedStatement.Run(); + + if (!result.IsOk()) + { + OnFailure( + { ResponseResultCode::InternalClientError, + audacity::ToUTF8( + result.GetErrors().front().GetErrorString().Translation()) }); + return; + } - if (auto error = transaction.Commit(); error.IsError()) - { - OnFailure({ ResponseResultCode::InternalClientError, - audacity::ToUTF8(error.GetErrorString().Translation()) }); - return; - } + if (auto error = transaction.Commit(); error.IsError()) + { + OnFailure({ ResponseResultCode::InternalClientError, + audacity::ToUTF8(error.GetErrorString().Translation()) }); + return; } mDownloadedBlocks.fetch_add(1, std::memory_order_acq_rel); @@ -501,12 +488,7 @@ void RemoteProjectSnapshot::ReportProgress() const auto completed = blocksDownloaded == mMissingBlocks && projectDownloaded; - // This happens under the write lock - if (completed) - { - auto writeLock = std::lock_guard { mDbWriteMutex }; - MarkProjectInDB(true); - } + MarkProjectInDB(true); mCallback({ {}, blocksDownloaded, mMissingBlocks, projectDownloaded }); } diff --git a/libraries/lib-cloud-audiocom/sync/RemoteProjectSnapshot.h b/libraries/lib-cloud-audiocom/sync/RemoteProjectSnapshot.h index 136f5088b358..4eeb73193907 100644 --- a/libraries/lib-cloud-audiocom/sync/RemoteProjectSnapshot.h +++ b/libraries/lib-cloud-audiocom/sync/RemoteProjectSnapshot.h @@ -101,8 +101,6 @@ class RemoteProjectSnapshot final const std::string mPath; RemoteProjectSnapshotStateCallback mCallback; - std::mutex mDbWriteMutex; - std::thread mRequestsThread; std::mutex mRequestsMutex; std::condition_variable mRequestsCV;