diff --git a/components/ai_chat/core/browser/BUILD.gn b/components/ai_chat/core/browser/BUILD.gn index 00f196b2bac2..845f3a5f8bf5 100644 --- a/components/ai_chat/core/browser/BUILD.gn +++ b/components/ai_chat/core/browser/BUILD.gn @@ -13,8 +13,12 @@ static_library("browser") { sources = [ "ai_chat_credential_manager.cc", "ai_chat_credential_manager.h", + "ai_chat_database.cc", + "ai_chat_database.h", "ai_chat_feedback_api.cc", "ai_chat_feedback_api.h", + "ai_chat_storage_service.cc", + "ai_chat_storage_service.h", "ai_chat_metrics.cc", "ai_chat_metrics.h", "ai_chat_service.cc", @@ -92,6 +96,7 @@ static_library("browser") { "//services/data_decoder/public/cpp", "//services/network/public/cpp", "//services/service_manager/public/cpp", + "//sql:sql", "//third_party/abseil-cpp:absl", "//third_party/re2", "//ui/base", @@ -124,6 +129,7 @@ if (!is_ios) { testonly = true sources = [ "ai_chat_credential_manager_unittest.cc", + "ai_chat_database_unittest.cc", "ai_chat_metrics_unittest.cc", "ai_chat_service_unittest.cc", "associated_content_driver_unittest.cc", diff --git a/components/ai_chat/core/browser/ai_chat_database.cc b/components/ai_chat/core/browser/ai_chat_database.cc new file mode 100644 index 000000000000..1445293f4f2b --- /dev/null +++ b/components/ai_chat/core/browser/ai_chat_database.cc @@ -0,0 +1,413 @@ +/* Copyright (c) 2023 The Brave Authors. All rights reserved. + * This Source Code Form is subject to the terms of the Mozilla Public + * License, v. 2.0. If a copy of the MPL was not distributed with this file, + * You can obtain one at https://mozilla.org/MPL/2.0/. */ + +#include "brave/components/ai_chat/core/browser/ai_chat_database.h" + +#include + +#include "base/strings/string_split.h" +#include "base/time/time.h" +#include "sql/statement.h" +#include "sql/transaction.h" + +namespace { +// Do not use To/FromInternalValues in base::Time +// https://bugs.chromium.org/p/chromium/issues/detail?id=634507#c23 + +base::TimeDelta SerializeTimeToDelta(const base::Time& time) { + return time.ToDeltaSinceWindowsEpoch(); +} + +base::Time DeserializeTime(const int64_t& serialized_time) { + return base::Time() + base::Microseconds(serialized_time); +} + +} // namespace + +namespace ai_chat { +AIChatDatabase::AIChatDatabase() + : db_({.page_size = 4096, .cache_size = 1000}) {} + +AIChatDatabase::~AIChatDatabase() = default; + +bool AIChatDatabase::Init(const base::FilePath& db_file_path) { + if (!GetDB().Open(db_file_path)) { + return false; + } + + if (!CreateConversationTable() || !CreateConversationEntryTable() || + !CreateConversationEntryTextTable() || !CreateSearchQueriesTable()) { + DVLOG(0) << "Failure to create tables\n"; + return false; + } + + return true; +} + +std::vector AIChatDatabase::GetAllConversations() { + sql::Statement statement(GetDB().GetUniqueStatement( + "SELECT conversation.*, entries.date " + "FROM conversation" + " LEFT JOIN (" + "SELECT conversation_id, date" + " FROM conversation_entry" + " GROUP BY conversation_id" + " ORDER BY date DESC" + " LIMIT 1" + ") AS entries" + " ON conversation.id = entries.conversation_id")); + + std::vector conversation_list; + + while (statement.Step()) { + mojom::ConversationPtr conversation = mojom::Conversation::New(); + int index = 0; + conversation->id = statement.ColumnInt64(index++); + conversation->title = statement.ColumnString(index++); + conversation->page_url = GURL(statement.ColumnString(index++)); + conversation->page_contents = statement.ColumnString(index++); + conversation->date = DeserializeTime(statement.ColumnInt64(index++)); + conversation_list.emplace_back(std::move(conversation)); + } + + return conversation_list; +} + +std::vector AIChatDatabase::GetConversationEntries( + int64_t conversation_id) { + sql::Statement statement(GetDB().GetUniqueStatement( + "SELECT conversation_entry.id, conversation_entry.date, " + "conversation_entry.character_type, conversation_entry.action_type, " + "conversation_entry.selected_text, text.id, text.date, text.text, " + "text.conversation_entry_id, " + "GROUP_CONCAT(DISTINCT search_queries.query) AS search_queries" + " FROM conversation_entry" + " JOIN conversation_entry_text as text" + " JOIN search_queries" + " ON conversation_entry.id = text.conversation_entry_id" + " WHERE conversation_entry.conversation_id=?" + " GROUP BY conversation_entry.id, text.id" + " ORDER BY conversation_entry.date ASC")); + + statement.BindInt64(0, conversation_id); + + std::vector history; + + int64_t last_conversation_entry_id = -1; + + while (statement.Step()) { + mojom::ConversationEntryTextPtr entry_text = + mojom::ConversationEntryText::New(); + entry_text->id = statement.ColumnInt64(5); + entry_text->date = DeserializeTime(statement.ColumnInt64(6)); + entry_text->text = statement.ColumnString(7); + + int64_t current_conversation_id = statement.ColumnInt64(0); + + if (last_conversation_entry_id == current_conversation_id) { + // Update the last entry if it is the same conversation entry + history.back()->texts.emplace_back(std::move(entry_text)); + } else { + mojom::ConversationEntryPtr entry = mojom::ConversationEntry::New(); + entry->id = statement.ColumnInt64(0); + entry->date = DeserializeTime(statement.ColumnInt64(1)); + entry->character_type = + static_cast(statement.ColumnInt(2)); + entry->action_type = + static_cast(statement.ColumnInt(3)); + entry->selected_text = statement.ColumnString(4); + + // Parse search queries + std::string search_queries_all = statement.ColumnString(9); + if (!search_queries_all.empty()) { + entry->events = std::vector(); + + std::vector search_queries = + base::SplitString(search_queries_all, ",", base::TRIM_WHITESPACE, + base::SPLIT_WANT_NONEMPTY); + entry->events->emplace_back( + mojom::ConversationEntryEvent::NewSearchQueriesEvent( + mojom::SearchQueriesEvent::New(std::move(search_queries)))); + } + + // Add the text to the new entry + entry->texts = std::vector(); + entry->texts.emplace_back(std::move(entry_text)); + + last_conversation_entry_id = entry->id; + + // Add the new entry to the history + history.emplace_back(std::move(entry)); + } + } + + return history; +} + +int64_t AIChatDatabase::AddConversation(mojom::ConversationPtr conversation) { + sql::Statement statement(GetDB().GetUniqueStatement( + "INSERT INTO conversation(id, title, " + "page_url, page_contents) VALUES(NULL, ?, ?, ?)")); + CHECK(statement.is_valid()); + + statement.BindString(0, conversation->title); + if (conversation->page_url.has_value()) { + statement.BindString(1, conversation->page_url->spec()); + } else { + statement.BindNull(1); + } + + if (conversation->page_contents.has_value()) { + statement.BindString(2, conversation->page_contents.value()); + } else { + statement.BindNull(2); + } + + if (!statement.Run()) { + DVLOG(0) << "Failed to execute 'conversation' insert statement" + << db_.GetErrorMessage(); + return INT64_C(-1); + } + + return GetDB().GetLastInsertRowId(); +} + +int64_t AIChatDatabase::AddConversationEntry( + int64_t conversation_id, + mojom::ConversationEntryPtr entry) { + sql::Transaction transaction(&GetDB()); + + CHECK(GetDB().is_open()); + + if (!transaction.Begin()) { + DVLOG(0) << "Transaction cannot begin\n"; + return INT64_C(-1); + } + + sql::Statement get_conversation_id_statement( + GetDB().GetUniqueStatement("SELECT id FROM conversation" + " WHERE id=?")); + CHECK(get_conversation_id_statement.is_valid()); + + get_conversation_id_statement.BindInt64(0, conversation_id); + + if (!get_conversation_id_statement.Step()) { + DVLOG(0) << "ID not found in 'conversation' table"; + return INT64_C(-1); + } + + sql::Statement insert_conversation_entry_statement(GetDB().GetUniqueStatement( + "INSERT INTO conversation_entry(id, date, " + "character_type, action_type, selected_text, " + "conversation_id) VALUES(NULL, ?, ?, ?, ?, ?)")); + CHECK(insert_conversation_entry_statement.is_valid()); + + // ConversationEntry's date should always match first text's date + int index = 0; + insert_conversation_entry_statement.BindTimeDelta( + index++, SerializeTimeToDelta(entry->texts[0]->date)); + insert_conversation_entry_statement.BindInt( + index++, static_cast(entry->character_type)); + insert_conversation_entry_statement.BindInt( + index++, static_cast(entry->action_type)); + + if (entry->selected_text.has_value()) { + insert_conversation_entry_statement.BindString( + index++, entry->selected_text.value()); + } else { + insert_conversation_entry_statement.BindNull(index++); + } + + insert_conversation_entry_statement.BindInt64(index++, conversation_id); + + if (!insert_conversation_entry_statement.Run()) { + DVLOG(0) << "Failed to execute 'conversation_entry' insert statement: " + << db_.GetErrorMessage(); + return INT64_C(-1); + } + + int64_t conversation_entry_row_id = GetDB().GetLastInsertRowId(); + + for (mojom::ConversationEntryTextPtr& text : entry->texts) { + // Add texts + AddConversationEntryText(conversation_entry_row_id, std::move(text)); + } + + if (entry->events.has_value()) { + for (const mojom::ConversationEntryEventPtr& event : *entry->events) { + if (event->is_search_queries_event()) { + std::vector queries = + event->get_search_queries_event()->search_queries; + for (const std::string& query : queries) { + // Add search queries + AddSearchQuery(conversation_id, conversation_entry_row_id, query); + } + } + } + } + + if (!transaction.Commit()) { + DVLOG(0) << "Transaction commit failed with reason: " + << db_.GetErrorMessage(); + return INT64_C(-1); + } + + return conversation_entry_row_id; +} + +int64_t AIChatDatabase::AddConversationEntryText( + int64_t conversation_entry_id, + mojom::ConversationEntryTextPtr entry_text) { + sql::Statement get_conversation_entry_statement( + GetDB().GetUniqueStatement("SELECT id FROM conversation_entry" + " WHERE id=?")); + CHECK(get_conversation_entry_statement.is_valid()); + get_conversation_entry_statement.BindInt64(0, conversation_entry_id); + + if (!get_conversation_entry_statement.Step()) { + DVLOG(0) << "ID not found in 'conversation entry' table" + << db_.GetErrorMessage(); + return INT64_C(-1); + } + + sql::Statement insert_text_statement( + GetDB().GetUniqueStatement("INSERT INTO " + "conversation_entry_text(" + "id, date, text, conversation_entry_id" + ") VALUES(NULL, ?, ?, ?)")); + CHECK(insert_text_statement.is_valid()); + + insert_text_statement.BindTimeDelta(0, + SerializeTimeToDelta(entry_text->date)); + insert_text_statement.BindString(1, entry_text->text); + insert_text_statement.BindInt64(2, conversation_entry_id); + + if (!insert_text_statement.Run()) { + DVLOG(0) << "Failed to execute 'conversation_entry_text' insert statement: " + << db_.GetErrorMessage(); + return INT64_C(-1); + } + + return GetDB().GetLastInsertRowId(); +} + +int64_t AIChatDatabase::AddSearchQuery(int64_t conversation_id, + int64_t conversation_entry_id, + const std::string& query) { + sql::Statement insert_search_query_statement(GetDB().GetUniqueStatement( + "INSERT INTO search_queries" + "(id, query, conversation_id, conversation_entry_id)" + "VALUES(NULL, ?, ?, ?)")); + CHECK(insert_search_query_statement.is_valid()); + + int index = 0; + insert_search_query_statement.BindString(index++, query); + insert_search_query_statement.BindInt64(index++, conversation_id); + insert_search_query_statement.BindInt64(index++, conversation_entry_id); + + if (!insert_search_query_statement.Run()) { + DVLOG(0) << "Failed to execute 'search_queries' insert statement: " + << db_.GetErrorMessage(); + return INT64_C(-1); + } + + return GetDB().GetLastInsertRowId(); +} + +bool AIChatDatabase::DeleteConversation(int64_t conversation_id) { + sql::Transaction transaction(&db_); + if (!transaction.Begin()) { + DVLOG(0) << "Transaction cannot begin\n"; + return false; + } + + sql::Statement delete_conversation_statement( + GetDB().GetUniqueStatement("DELETE FROM conversation" + "WHERE id=?")); + delete_conversation_statement.BindInt64(0, conversation_id); + + if (!delete_conversation_statement.Run()) { + return false; + } + + sql::Statement select_conversation_entry_statement( + GetDB().GetUniqueStatement("SELECT id FROM conversation_entry" + "WHERE conversation_id=?")); + + // We remove all conversation text associated to |conversation_id| + while (select_conversation_entry_statement.Step()) { + sql::Statement delete_conversation_text_statement( + GetDB().GetUniqueStatement("DELETE FROM conversation_entry_text" + "WHERE conversation_entry_id=?")); + delete_conversation_text_statement.BindInt64( + 0, select_conversation_entry_statement.ColumnInt64(0)); + + if (!delete_conversation_text_statement.Run()) { + return false; + } + } + + sql::Statement delete_conversation_entry_statement( + GetDB().GetUniqueStatement("DELETE FROM conversation_entry" + "WHERE conversation_entry_id=?")); + delete_conversation_entry_statement.BindInt64(0, conversation_id); + + // At last we remove all conversation entries associated to |conversation_id| + if (!delete_conversation_entry_statement.Run()) { + return false; + } + + transaction.Commit(); + return true; +} + +bool AIChatDatabase::DropAllTables() { + return GetDB().Execute("DROP TABLE conversation") && + GetDB().Execute("DROP TABLE conversation_entry") && + GetDB().Execute("DROP TABLE conversation_entry_text"); +} + +sql::Database& AIChatDatabase::GetDB() { + return db_; +} + +bool AIChatDatabase::CreateConversationTable() { + return GetDB().Execute( + "CREATE TABLE IF NOT EXISTS conversation(" + "id INTEGER PRIMARY KEY," + "title TEXT," + "page_url TEXT," + "page_contents TEXT)"); +} + +bool AIChatDatabase::CreateConversationEntryTable() { + return GetDB().Execute( + "CREATE TABLE IF NOT EXISTS conversation_entry(" + "id INTEGER PRIMARY KEY," + "date INTEGER NOT NULL," + "character_type INTEGER NOT NULL," + "action_type INTEGER NOT NULL," + "selected_text TEXT," + "conversation_id INTEGER NOT NULL)"); +} + +bool AIChatDatabase::CreateConversationEntryTextTable() { + return GetDB().Execute( + "CREATE TABLE IF NOT EXISTS conversation_entry_text(" + "id INTEGER PRIMARY KEY," + "date INTEGER NOT NULL," + "text TEXT NOT NULL," + "conversation_entry_id INTEGER NOT NULL)"); +} + +bool AIChatDatabase::CreateSearchQueriesTable() { + return GetDB().Execute( + "CREATE TABLE IF NOT EXISTS search_queries(" + "id INTEGER PRIMARY KEY," + "query TEXT NOT NULL," + "conversation_id INTEGER NOT NULL," + "conversation_entry_id INTEGER NOT NULL)"); +} +} // namespace ai_chat diff --git a/components/ai_chat/core/browser/ai_chat_database.h b/components/ai_chat/core/browser/ai_chat_database.h new file mode 100644 index 000000000000..38c9931ac597 --- /dev/null +++ b/components/ai_chat/core/browser/ai_chat_database.h @@ -0,0 +1,57 @@ +/* Copyright (c) 2023 The Brave Authors. All rights reserved. + * This Source Code Form is subject to the terms of the Mozilla Public + * License, v. 2.0. If a copy of the MPL was not distributed with this file, + * You can obtain one at https://mozilla.org/MPL/2.0/. */ + +#ifndef BRAVE_COMPONENTS_AI_CHAT_CORE_BROWSER_AI_CHAT_DATABASE_H_ +#define BRAVE_COMPONENTS_AI_CHAT_CORE_BROWSER_AI_CHAT_DATABASE_H_ + +#include + +#include "brave/components/ai_chat/core/common/mojom/ai_chat.mojom.h" +#include "sql/database.h" +#include "sql/init_status.h" + +namespace sql { +class Database; +} // namespace sql + +namespace ai_chat { + +class AIChatDatabase { + public: + AIChatDatabase(); + AIChatDatabase(const AIChatDatabase&) = delete; + AIChatDatabase& operator=(const AIChatDatabase&) = delete; + ~AIChatDatabase(); + + bool Init(const base::FilePath& db_file_path); + + std::vector GetAllConversations(); + std::vector GetConversationEntries( + int64_t conversation_id); + int64_t AddConversation(mojom::ConversationPtr conversation); + int64_t AddConversationEntry(int64_t conversation_id, + mojom::ConversationEntryPtr entry); + int64_t AddConversationEntryText(int64_t conversation_entry_id, + mojom::ConversationEntryTextPtr entry_text); + int64_t AddSearchQuery(int64_t conversation_id, + int64_t conversation_entry_id, + const std::string& query); + bool DeleteConversation(int64_t conversation_id); + bool DropAllTables(); + + private: + sql::Database& GetDB(); + + bool CreateConversationTable(); + bool CreateConversationEntryTable(); + bool CreateConversationEntryTextTable(); + bool CreateSearchQueriesTable(); + + sql::Database db_; +}; + +} // namespace ai_chat + +#endif // BRAVE_COMPONENTS_AI_CHAT_CORE_BROWSER_AI_CHAT_DATABASE_H_ diff --git a/components/ai_chat/core/browser/ai_chat_database_unittest.cc b/components/ai_chat/core/browser/ai_chat_database_unittest.cc new file mode 100644 index 000000000000..b16e3b6f2f42 --- /dev/null +++ b/components/ai_chat/core/browser/ai_chat_database_unittest.cc @@ -0,0 +1,188 @@ +/* Copyright (c) 2023 The Brave Authors. All rights reserved. + * This Source Code Form is subject to the terms of the Mozilla Public + * License, v. 2.0. If a copy of the MPL was not distributed with this file, + * You can obtain one at https://mozilla.org/MPL/2.0/. */ + +#include "brave/components/ai_chat/core/browser/ai_chat_database.h" + +#include + +#include +#include +#include + +#include "base/files/file_path.h" +#include "base/files/file_util.h" +#include "base/files/scoped_temp_dir.h" +#include "base/path_service.h" +#include "sql/test/test_helpers.h" +#include "testing/gtest/include/gtest/gtest.h" + +namespace { +int64_t GetInternalValue(const base::Time& time) { + return time.ToDeltaSinceWindowsEpoch().InMicroseconds(); +} +} // namespace + +namespace ai_chat { + +class AIChatDatabaseTest : public testing::Test { + public: + AIChatDatabaseTest() = default; + + void SetUp() override { + ASSERT_TRUE(temp_directory_.CreateUniqueTempDir()); + // We take ownership of the path for adhoc inspection of the spitted + // db. This is useful to test sql queries. If you do this, make sure to + // comment out the base::DeletePathRecursively line. + path_ = temp_directory_.Take(); + db_ = std::make_unique(); + ASSERT_TRUE(db_->Init(db_file_path())); + } + + void TearDown() override { + db_.reset(); + ASSERT_TRUE(base::DeletePathRecursively(path_)); + } + + base::FilePath db_file_path() { return path_.AppendASCII("ai_chat"); } + + protected: + base::ScopedTempDir temp_directory_; + std::unique_ptr db_; + base::FilePath path_; +}; + +TEST_F(AIChatDatabaseTest, AddConversations) { + int64_t first_id = db_->AddConversation(mojom::Conversation::New( + INT64_C(1), base::Time::Now(), "Initial conversation", GURL(), + "Sample page content")); + EXPECT_EQ(first_id, INT64_C(1)); + std::vector conversations = + db_->GetAllConversations(); + EXPECT_EQ(conversations.size(), UINT64_C(1)); + int64_t second_id = db_->AddConversation(mojom::Conversation::New( + INT64_C(2), base::Time::Now(), "Another conversation", GURL(), + "Another sample page content")); + EXPECT_EQ(second_id, INT64_C(2)); + std::vector conversations_updated = + db_->GetAllConversations(); + EXPECT_EQ(conversations_updated.size(), UINT64_C(2)); +} + +TEST_F(AIChatDatabaseTest, AddConversationEntries) { + static const int64_t kConversationId = INT64_C(1); + static const int64_t kConversationEntryId = INT64_C(1); + + static const GURL kURL = GURL("https://brave.com/"); + static const char kConversationTitle[] = "Summarizing Brave"; + static const char kPageContents[] = "Brave is a web browser."; + static const char kFirstResponse[] = "This is a generated response"; + static const char kSecondResponse[] = "This is a re-generated response"; + + static const base::Time kFirstTextCreatedAt(base::Time::Now()); + static const base::Time kSecondTextCreatedAt(base::Time::Now() + + base::Minutes(5)); + + // Add conversation + int64_t conversation_id = db_->AddConversation( + mojom::Conversation::New(kConversationId, kFirstTextCreatedAt, + kConversationTitle, kURL, kPageContents)); + EXPECT_EQ(conversation_id, kConversationId); + + // Add first conversation entry + { + std::vector texts; + texts.emplace_back(mojom::ConversationEntryText::New(1, kFirstTextCreatedAt, + kFirstResponse)); + + std::vector search_queries{"brave", "browser", "web"}; + std::vector events; + events.emplace_back(mojom::ConversationEntryEvent::NewSearchQueriesEvent( + mojom::SearchQueriesEvent::New(std::move(search_queries)))); + + int64_t entry_id = db_->AddConversationEntry( + conversation_id, + mojom::ConversationEntry::New( + INT64_C(-1), kFirstTextCreatedAt, mojom::CharacterType::ASSISTANT, + mojom::ActionType::UNSPECIFIED, std::nullopt /* selected_text */, + std::move(texts), std::move(events))); + EXPECT_EQ(entry_id, kConversationEntryId); + } + + // Get and verify conversations + { + std::vector conversations = + db_->GetAllConversations(); + EXPECT_FALSE(conversations.empty()); + EXPECT_EQ(conversations[0]->id, INT64_C(1)); + EXPECT_EQ(conversations[0]->title, kConversationTitle); + EXPECT_EQ(conversations[0]->page_url->spec(), kURL.spec()); + EXPECT_EQ(conversations[0]->page_contents, kPageContents); + // A conversation is created with the first entry's date + EXPECT_EQ(GetInternalValue(conversations[0]->date), + GetInternalValue(kFirstTextCreatedAt)); + } + + // Get and verify conversation entries + { + std::vector entries = + db_->GetConversationEntries(conversation_id); + EXPECT_EQ(UINT64_C(1), entries.size()); + EXPECT_EQ(entries[0]->character_type, mojom::CharacterType::ASSISTANT); + EXPECT_EQ(GetInternalValue(entries[0]->date), + GetInternalValue(kFirstTextCreatedAt)); + EXPECT_EQ(UINT64_C(1), entries[0]->texts.size()); + EXPECT_EQ(entries[0]->texts[0]->text, kFirstResponse); + EXPECT_EQ(entries[0]->action_type, mojom::ActionType::UNSPECIFIED); + EXPECT_TRUE(entries[0]->selected_text.value().empty()); + + // Verify search queries + EXPECT_TRUE(entries.front()->events.has_value()); + EXPECT_EQ(UINT64_C(3), entries[0] + ->events.value()[0] + ->get_search_queries_event() + ->search_queries.size()); + } + + // Add another text to the first entry + db_->AddConversationEntryText( + UINT64_C(1), mojom::ConversationEntryText::New( + UINT64_C(1), kSecondTextCreatedAt, kSecondResponse)); + + { + auto entries = db_->GetConversationEntries(conversation_id); + EXPECT_EQ(UINT64_C(2), entries[0]->texts.size()); + EXPECT_EQ(entries[0]->texts[1]->text, kSecondResponse); + } + + // Add another entry + { + base::Time created_at(base::Time::Now()); + std::string text("Is this a question?"); + std::string selected_text("The brown fox jumps over the lazy dog"); + + std::vector texts; + texts.emplace_back(mojom::ConversationEntryText::New(1, created_at, text)); + + int64_t entry_id = db_->AddConversationEntry( + conversation_id, + mojom::ConversationEntry::New( + INT64_C(-1), created_at, mojom::CharacterType::HUMAN, + mojom::ActionType::CASUALIZE, selected_text, std::move(texts), + std::nullopt /* search_query */)); + + EXPECT_EQ(entry_id, 2); + } + + // Get and verify conversation entries + { + const std::vector& entries = + db_->GetConversationEntries(conversation_id); + EXPECT_EQ(UINT64_C(2), entries.size()); + EXPECT_EQ(entries[1]->selected_text.value(), + "The brown fox jumps over the lazy dog"); + } +} + +} // namespace ai_chat diff --git a/components/ai_chat/core/browser/ai_chat_storage_service.cc b/components/ai_chat/core/browser/ai_chat_storage_service.cc new file mode 100644 index 000000000000..c213e1e2a37c --- /dev/null +++ b/components/ai_chat/core/browser/ai_chat_storage_service.cc @@ -0,0 +1,87 @@ +/* Copyright (c) 2023 The Brave Authors. All rights reserved. + * This Source Code Form is subject to the terms of the Mozilla Public + * License, v. 2.0. If a copy of the MPL was not distributed with this file, + * You can obtain one at https://mozilla.org/MPL/2.0/. */ + +#include "brave/components/ai_chat/core/browser/ai_chat_storage_service.h" + +#include +#include + +#include "base/task/thread_pool.h" +#include "content/public/browser/browser_context.h" + +namespace ai_chat { +namespace { +constexpr base::FilePath::StringPieceType kBaseDirName = + FILE_PATH_LITERAL("AIChat"); +} // namespace + +AIChatStorageService::AIChatStorageService(content::BrowserContext* context) + : base_dir_(context->GetPath().Append(kBaseDirName)) { + ai_chat_db_ = base::SequenceBound(GetTaskRunner()); + + auto on_response = base::BindOnce( + [](bool success) { DVLOG(1) << "AIChatDB Init: " << success; }); + + ai_chat_db_.AsyncCall(&AIChatDatabase::Init) + .WithArgs(base_dir_) + .Then(std::move(on_response)); +} + +AIChatStorageService::~AIChatStorageService() = default; + +base::SequencedTaskRunner* AIChatStorageService::GetTaskRunner() { + if (!task_runner_) { + task_runner_ = base::ThreadPool::CreateSequencedTaskRunner( + {base::MayBlock(), base::WithBaseSyncPrimitives(), + base::TaskPriority::BEST_EFFORT, + base::TaskShutdownBehavior::BLOCK_SHUTDOWN}); + } + + return task_runner_.get(); +} + +void AIChatStorageService::SyncConversation(mojom::ConversationPtr conversation, + ConversationCallback callback) { + auto on_added = [](mojom::ConversationPtr conversation, + ConversationCallback callback, int64_t id) { + conversation->id = id; + std::move(callback).Run(std::move(conversation)); + }; + + ai_chat_db_.AsyncCall(&AIChatDatabase::AddConversation) + .WithArgs(conversation.Clone()) + .Then(base::BindOnce(on_added, std::move(conversation), + std::move(callback))); +} + +void AIChatStorageService::SyncConversationTurn(int64_t conversation_id, + mojom::ConversationTurnPtr turn) { + NOTIMPLEMENTED(); +} + +void AIChatStorageService::GetConversationForGURL(const GURL& gurl, + ConversationCallback callback) { + auto on_get = [](const GURL& gurl, ConversationCallback callback, + std::vector conversations) { + auto it = std::find_if(conversations.begin(), conversations.end(), + [&gurl](const auto& conversation) { + return conversation->page_url == gurl.spec(); + }); + if (it != conversations.end()) { + std::move(callback).Run(std::move(*it)); + } else { + std::move(callback).Run(std::nullopt); + } + }; + + ai_chat_db_.AsyncCall(&AIChatDatabase::GetAllConversations) + .Then(base::BindOnce(on_get, gurl, std::move(callback))); +} + +void AIChatStorageService::Shutdown() { + task_runner_.reset(); +} + +} // namespace ai_chat diff --git a/components/ai_chat/core/browser/ai_chat_storage_service.h b/components/ai_chat/core/browser/ai_chat_storage_service.h new file mode 100644 index 000000000000..28e6ca464665 --- /dev/null +++ b/components/ai_chat/core/browser/ai_chat_storage_service.h @@ -0,0 +1,57 @@ +/* Copyright (c) 2023 The Brave Authors. All rights reserved. + * This Source Code Form is subject to the terms of the Mozilla Public + * License, v. 2.0. If a copy of the MPL was not distributed with this file, + * You can obtain one at https://mozilla.org/MPL/2.0/. */ + +#ifndef BRAVE_COMPONENTS_AI_CHAT_CORE_BROWSER_AI_CHAT_STORAGE_SERVICE_H_ +#define BRAVE_COMPONENTS_AI_CHAT_CORE_BROWSER_AI_CHAT_STORAGE_SERVICE_H_ + +#include "base/files/file_path.h" +#include "base/memory/scoped_refptr.h" +#include "base/memory/weak_ptr.h" +#include "base/threading/sequence_bound.h" +#include "brave/components/ai_chat/core/browser/ai_chat_database.h" +#include "components/keyed_service/core/keyed_service.h" + +namespace base { +class SequencedTaskRunner; +} // namespace base + +namespace content { +class BrowserContext; +} // namespace content + +namespace ai_chat { +using ConversationCallback = base::OnceCallback conversation)>; +class AIChatStorageService : public KeyedService { + public: + explicit AIChatStorageService(content::BrowserContext* context); + ~AIChatStorageService() override; + AIChatStorageService(const AIChatStorageService&) = delete; + AIChatStorageService& operator=(const AIChatStorageService&) = delete; + + void SyncConversation(mojom::ConversationPtr conversation, + ConversationCallback callback); + void SyncConversationTurn(int64_t conversation_id, + mojom::ConversationTurnPtr turn); + void GetConversationForGURL(const GURL& gurl, ConversationCallback callback); + + private: + base::SequencedTaskRunner* GetTaskRunner(); + + // KeyedService overrides: + void Shutdown() override; + + const base::FilePath base_dir_; + + scoped_refptr task_runner_; + + base::SequenceBound ai_chat_db_; + + base::WeakPtrFactory weak_ptr_factory_{this}; +}; + +} // namespace ai_chat + +#endif // BRAVE_COMPONENTS_AI_CHAT_CORE_BROWSER_AI_CHAT_STORAGE_SERVICE_H_ diff --git a/components/ai_chat/core/common/mojom/ai_chat.mojom b/components/ai_chat/core/common/mojom/ai_chat.mojom index 814340753d47..9126e17d0a7c 100644 --- a/components/ai_chat/core/common/mojom/ai_chat.mojom +++ b/components/ai_chat/core/common/mojom/ai_chat.mojom @@ -7,6 +7,7 @@ module ai_chat.mojom; import "mojo/public/mojom/base/time.mojom"; import "url/mojom/url.mojom"; +import "mojo/public/mojom/base/time.mojom"; enum CharacterType { HUMAN, @@ -436,3 +437,28 @@ interface ConversationUI { interface IAPSubscription { GetPurchaseTokenOrderId() => (string token, string order_id); }; + +// ids below are generated by sqlite +struct Conversation { + int64 id; + mojo_base.mojom.Time date; + string title; + url.mojom.Url? page_url; + string? page_contents; +}; + +struct ConversationEntry { + int64 id; + mojo_base.mojom.Time date; + CharacterType character_type; + ActionType action_type; + string? selected_text; + array texts; + array? events; +}; + +struct ConversationEntryText { + int64 id; + mojo_base.mojom.Time date; + string text; +};