diff --git a/browser/ai_chat/ai_chat_service_factory.cc b/browser/ai_chat/ai_chat_service_factory.cc index 03f7829b4d09..c7a904132aac 100644 --- a/browser/ai_chat/ai_chat_service_factory.cc +++ b/browser/ai_chat/ai_chat_service_factory.cc @@ -77,9 +77,10 @@ AIChatServiceFactory::BuildServiceInstanceForBrowserContext( (g_brave_browser_process->process_misc_metrics()) ? g_brave_browser_process->process_misc_metrics()->ai_chat_metrics() : nullptr, + g_browser_process->os_crypt_async(), context->GetDefaultStoragePartition() ->GetURLLoaderFactoryForBrowserProcess(), - version_info::GetChannelString(chrome::GetChannel())); + version_info::GetChannelString(chrome::GetChannel()), context->GetPath()); } } // namespace ai_chat diff --git a/browser/ai_chat/android/ai_chat_utils_android.cc b/browser/ai_chat/android/ai_chat_utils_android.cc index 6d6cc4f4e6b5..e7fd12bbe6fa 100644 --- a/browser/ai_chat/android/ai_chat_utils_android.cc +++ b/browser/ai_chat/android/ai_chat_utils_android.cc @@ -51,7 +51,7 @@ static void JNI_BraveLeoUtils_OpenLeoQuery( // Send the query conversation->MaybeUnlinkAssociatedContent(); mojom::ConversationTurnPtr turn = mojom::ConversationTurn::New( - mojom::CharacterType::HUMAN, mojom::ActionType::QUERY, + std::nullopt, mojom::CharacterType::HUMAN, mojom::ActionType::QUERY, mojom::ConversationTurnVisibility::VISIBLE, base::android::ConvertJavaStringToUTF8(query), std::nullopt, std::nullopt, base::Time::Now(), std::nullopt, false); diff --git a/browser/browsing_data/brave_browsing_data_remover_delegate.cc b/browser/browsing_data/brave_browsing_data_remover_delegate.cc index 456bd0bee7d4..4a048f98a418 100644 --- a/browser/browsing_data/brave_browsing_data_remover_delegate.cc +++ b/browser/browsing_data/brave_browsing_data_remover_delegate.cc @@ -20,6 +20,8 @@ #include "components/content_settings/core/browser/host_content_settings_map.h" #if BUILDFLAG(ENABLE_AI_CHAT) +#include "brave/browser/ai_chat/ai_chat_service_factory.h" +#include "brave/components/ai_chat/core/browser/ai_chat_service.h" #include "brave/components/ai_chat/core/browser/utils.h" #include "brave/components/ai_chat/core/common/features.h" #endif @@ -102,7 +104,8 @@ void BraveBrowsingDataRemoverDelegate::ClearShieldsSettings( #if BUILDFLAG(ENABLE_AI_CHAT) void BraveBrowsingDataRemoverDelegate::ClearAiChatHistory(base::Time begin_time, base::Time end_time) { - // Handler for the Brave Leo History clearing. - // It is prepared for future implementation. + // Delete all Leo data + ai_chat::AIChatServiceFactory::GetForBrowserContext(profile_) + ->ClearAllHistory(); } #endif // BUILDFLAG(ENABLE_AI_CHAT) diff --git a/browser/ui/webui/ai_chat/ai_chat_ui_page_handler.cc b/browser/ui/webui/ai_chat/ai_chat_ui_page_handler.cc index e4e43d3dd93c..33b433e567d0 100644 --- a/browser/ui/webui/ai_chat/ai_chat_ui_page_handler.cc +++ b/browser/ui/webui/ai_chat/ai_chat_ui_page_handler.cc @@ -6,7 +6,9 @@ #include "brave/browser/ui/webui/ai_chat/ai_chat_ui_page_handler.h" #include +#include #include +#include #include "brave/browser/ai_chat/ai_chat_service_factory.h" #include "brave/browser/ui/side_panel/ai_chat/ai_chat_side_panel_utils.h" diff --git a/browser/ui/webui/settings/brave_settings_leo_assistant_handler.cc b/browser/ui/webui/settings/brave_settings_leo_assistant_handler.cc index 0697b5a369a7..c104e7462910 100644 --- a/browser/ui/webui/settings/brave_settings_leo_assistant_handler.cc +++ b/browser/ui/webui/settings/brave_settings_leo_assistant_handler.cc @@ -10,10 +10,12 @@ #include #include "base/containers/contains.h" +#include "brave/browser/ai_chat/ai_chat_service_factory.h" #include "brave/browser/brave_browser_process.h" #include "brave/browser/misc_metrics/process_misc_metrics.h" #include "brave/browser/ui/sidebar/sidebar_service_factory.h" #include "brave/components/ai_chat/core/browser/ai_chat_metrics.h" +#include "brave/components/ai_chat/core/browser/ai_chat_service.h" #include "brave/components/ai_chat/core/common/pref_names.h" #include "brave/components/sidebar/browser/sidebar_item.h" #include "brave/components/sidebar/browser/sidebar_service.h" @@ -141,13 +143,13 @@ void BraveLeoAssistantHandler::HandleGetLeoIconVisibility( void BraveLeoAssistantHandler::HandleResetLeoData( const base::Value::List& args) { - auto* service = sidebar::SidebarServiceFactory::GetForProfile(profile_); + auto* sidebar_service = + sidebar::SidebarServiceFactory::GetForProfile(profile_); + + ShowLeoAssistantIconVisibleIfNot(sidebar_service); - ShowLeoAssistantIconVisibleIfNot(service); - profile_->GetPrefs()->ClearPref(ai_chat::prefs::kLastAcceptedDisclaimer); - g_brave_browser_process->process_misc_metrics() - ->ai_chat_metrics() - ->RecordReset(); + ai_chat::AIChatServiceFactory::GetForBrowserContext(profile_) + ->ClearAllHistory(); AllowJavascript(); } diff --git a/chromium_src/chrome/browser/autocomplete/chrome_autocomplete_provider_client.cc b/chromium_src/chrome/browser/autocomplete/chrome_autocomplete_provider_client.cc index 2397611db5a0..b9e174024b5a 100644 --- a/chromium_src/chrome/browser/autocomplete/chrome_autocomplete_provider_client.cc +++ b/chromium_src/chrome/browser/autocomplete/chrome_autocomplete_provider_client.cc @@ -84,7 +84,7 @@ void ChromeAutocompleteProviderClient::OpenLeo(const std::u16string& query) { // Send the query to the AIChat's backend. ai_chat::mojom::ConversationTurnPtr turn = ai_chat::mojom::ConversationTurn::New( - ai_chat::mojom::CharacterType::HUMAN, + std::nullopt, ai_chat::mojom::CharacterType::HUMAN, ai_chat::mojom::ActionType::QUERY, ai_chat::mojom::ConversationTurnVisibility::VISIBLE, base::UTF16ToUTF8(query) /* text */, std::nullopt /* selected_text */, diff --git a/components/ai_chat/core/browser/BUILD.gn b/components/ai_chat/core/browser/BUILD.gn index 845f3a5f8bf5..fc6a5960e025 100644 --- a/components/ai_chat/core/browser/BUILD.gn +++ b/components/ai_chat/core/browser/BUILD.gn @@ -17,8 +17,6 @@ static_library("browser") { "ai_chat_database.h", "ai_chat_feedback_api.cc", "ai_chat_feedback_api.h", - "ai_chat_storage_service.cc", - "ai_chat_storage_service.h", "ai_chat_metrics.cc", "ai_chat_metrics.h", "ai_chat_service.cc", @@ -89,6 +87,8 @@ static_library("browser") { "//components/component_updater", "//components/component_updater:component_updater_paths", "//components/keyed_service/core", + "//components/os_crypt/async/browser", + "//components/os_crypt/async/common", "//components/os_crypt/sync:os_crypt", "//components/prefs", "//components/user_prefs", @@ -164,6 +164,7 @@ if (!is_ios) { "//brave/components/skus/common:mojom", "//components/component_updater:component_updater_paths", "//components/component_updater:test_support", + "//components/os_crypt/async/browser:test_support", "//components/os_crypt/sync:test_support", "//components/prefs:test_support", "//components/sync_preferences:test_support", @@ -172,6 +173,7 @@ if (!is_ios) { "//services/data_decoder/public/cpp:test_support", "//services/network:test_support", "//services/network/public/cpp:cpp", + "//sql:test_support", "//testing/gtest:gtest", ] } @@ -184,12 +186,18 @@ source_set("test_support") { "engine/mock_engine_consumer.h", "engine/mock_remote_completion_client.cc", "engine/mock_remote_completion_client.h", + "mock_conversation_handler_observer.cc", + "mock_conversation_handler_observer.h", + "test_utils.cc", + "test_utils.h", ] deps = [ "//brave/components/ai_chat/core/browser", + "//brave/components/ai_chat/core/common/mojom", "//services/network/public/cpp", "//testing/gmock", + "//testing/gtest", ] } diff --git a/components/ai_chat/core/browser/DEPS b/components/ai_chat/core/browser/DEPS index 108a66fbb918..4473b9aafe45 100644 --- a/components/ai_chat/core/browser/DEPS +++ b/components/ai_chat/core/browser/DEPS @@ -1,7 +1,9 @@ include_rules = [ + "+components/os_crypt/async", "+services/data_decoder/public", "+services/network/public", "+services/network/test", + "+sql", "+third_party/skia/include", "+third_party/re2/src/re2", "+third_party/tflite", diff --git a/components/ai_chat/core/browser/ai_chat_database.cc b/components/ai_chat/core/browser/ai_chat_database.cc index 1445293f4f2b..114f495ce3fb 100644 --- a/components/ai_chat/core/browser/ai_chat_database.cc +++ b/components/ai_chat/core/browser/ai_chat_database.cc @@ -5,41 +5,72 @@ #include "brave/components/ai_chat/core/browser/ai_chat_database.h" +#include +#include +#include #include +#include "base/check.h" #include "base/strings/string_split.h" #include "base/time/time.h" +#include "brave/components/ai_chat/core/common/mojom/ai_chat.mojom-shared.h" +#include "brave/components/ai_chat/core/common/mojom/ai_chat.mojom.h" +#include "components/os_crypt/async/common/encryptor.h" #include "sql/statement.h" #include "sql/transaction.h" namespace { -// Do not use To/FromInternalValues in base::Time -// https://bugs.chromium.org/p/chromium/issues/detail?id=634507#c23 -base::TimeDelta SerializeTimeToDelta(const base::Time& time) { - return time.ToDeltaSinceWindowsEpoch(); -} +constexpr char kSearchQueriesSeparator[] = "|||"; -base::Time DeserializeTime(const int64_t& serialized_time) { - return base::Time() + base::Microseconds(serialized_time); +std::optional GetOptionalString(sql::Statement& statement, + int index) { + if (statement.GetColumnType(index) == sql::ColumnType::kNull) { + return std::nullopt; + } + return std::make_optional(statement.ColumnString(index)); } } // namespace namespace ai_chat { -AIChatDatabase::AIChatDatabase() - : db_({.page_size = 4096, .cache_size = 1000}) {} + +AIChatDatabase::AIChatDatabase(const base::FilePath& storage_dir, + os_crypt_async::Encryptor encryptor) + : db_({.page_size = 4096, .cache_size = 1000}), + encryptor_(std::move(encryptor)) { + is_initialized_ = Init(storage_dir); + if (!is_initialized_) { + LOG(ERROR) << "Failed to initialize AI Chat database"; + } +} AIChatDatabase::~AIChatDatabase() = default; +bool AIChatDatabase::IsInitialized() const { + DCHECK_CALLED_ON_VALID_SEQUENCE(sequence_checker_); + return is_initialized_; +} + bool AIChatDatabase::Init(const base::FilePath& db_file_path) { + DCHECK_CALLED_ON_VALID_SEQUENCE(sequence_checker_); if (!GetDB().Open(db_file_path)) { return false; } - if (!CreateConversationTable() || !CreateConversationEntryTable() || - !CreateConversationEntryTextTable() || !CreateSearchQueriesTable()) { - DVLOG(0) << "Failure to create tables\n"; + sql::Transaction transaction(&GetDB()); + if (!transaction.Begin()) { + return false; + } + + if (!CreateConversationTable() || !CreateAssociatedContentTable() || + !CreateConversationEntryTable() || !CreateConversationEntryTextTable() || + !CreateSearchQueriesTable()) { + DVLOG(0) << "Failure to create tables"; + return false; + } + + if (!transaction.Commit()) { return false; } @@ -47,203 +78,477 @@ bool AIChatDatabase::Init(const base::FilePath& db_file_path) { } std::vector AIChatDatabase::GetAllConversations() { - sql::Statement statement(GetDB().GetUniqueStatement( - "SELECT conversation.*, entries.date " - "FROM conversation" - " LEFT JOIN (" - "SELECT conversation_id, date" - " FROM conversation_entry" - " GROUP BY conversation_id" - " ORDER BY date DESC" - " LIMIT 1" - ") AS entries" - " ON conversation.id = entries.conversation_id")); + DCHECK_CALLED_ON_VALID_SEQUENCE(sequence_checker_); + static const std::string kQuery = + "SELECT conversation.uuid, conversation.title," + " MAX(conversation_entry.date) AS date" + " FROM conversation" + " LEFT JOIN conversation_entry" + " ON conversation_entry.conversation_uuid = conversation.uuid" + " GROUP BY conversation.uuid" + " ORDER BY date DESC"; + sql::Statement statement(GetDB().GetCachedStatement(SQL_FROM_HERE, kQuery)); + CHECK(statement.is_valid()); std::vector conversation_list; while (statement.Step()) { + DVLOG(1) << __func__ << " got a result"; mojom::ConversationPtr conversation = mojom::Conversation::New(); - int index = 0; - conversation->id = statement.ColumnInt64(index++); - conversation->title = statement.ColumnString(index++); - conversation->page_url = GURL(statement.ColumnString(index++)); - conversation->page_contents = statement.ColumnString(index++); - conversation->date = DeserializeTime(statement.ColumnInt64(index++)); + conversation->uuid = statement.ColumnString(0); + conversation->title = + DecryptOptionalColumnToString(statement, 1).value_or(""); + conversation->updated_time = statement.ColumnTime(2); + conversation->has_content = true; + + static const std::string kAssociatedContentQuery = + "SELECT id, title, url, content_type, content_used_percentage," + " is_content_refined" + " FROM associated_content" + " WHERE conversation_uuid=?"; + sql::Statement associated_content_statement( + GetDB().GetCachedStatement(SQL_FROM_HERE, kAssociatedContentQuery)); + CHECK(associated_content_statement.is_valid()); + associated_content_statement.BindString(0, conversation->uuid); + + conversation->associated_content = mojom::SiteInfo::New(); + // Only accepts a single row, for now + if (associated_content_statement.Step()) { + DVLOG(1) << __func__ << " got associated content"; + + conversation->associated_content->id = + associated_content_statement.ColumnInt64(0); + conversation->associated_content->title = + DecryptOptionalColumnToString(associated_content_statement, 1); + auto url_raw = + DecryptOptionalColumnToString(associated_content_statement, 2); + if (url_raw.has_value()) { + conversation->associated_content->url = GURL(url_raw.value()); + } + conversation->associated_content->content_type = + static_cast( + associated_content_statement.ColumnInt(3)); + conversation->associated_content->content_used_percentage = + associated_content_statement.ColumnInt(4); + conversation->associated_content->is_content_refined = + associated_content_statement.ColumnBool(5); + conversation->associated_content->is_content_association_possible = true; + } else { + conversation->associated_content->is_content_association_possible = false; + } + conversation_list.emplace_back(std::move(conversation)); } return conversation_list; } -std::vector AIChatDatabase::GetConversationEntries( - int64_t conversation_id) { - sql::Statement statement(GetDB().GetUniqueStatement( - "SELECT conversation_entry.id, conversation_entry.date, " - "conversation_entry.character_type, conversation_entry.action_type, " - "conversation_entry.selected_text, text.id, text.date, text.text, " - "text.conversation_entry_id, " - "GROUP_CONCAT(DISTINCT search_queries.query) AS search_queries" +mojom::ConversationArchivePtr AIChatDatabase::GetConversationData( + std::string_view conversation_uuid) { + DCHECK_CALLED_ON_VALID_SEQUENCE(sequence_checker_); + return mojom::ConversationArchive::New( + GetConversationEntries(conversation_uuid), + GetArchiveContentsForConversation(conversation_uuid)); +} + +std::vector AIChatDatabase::GetConversationEntries( + std::string_view conversation_uuid) { + static const std::string kEntriesQuery = + "SELECT uuid, date, entry_text, character_type, editing_entry_uuid, " + "action_type, selected_text" " FROM conversation_entry" - " JOIN conversation_entry_text as text" - " JOIN search_queries" - " ON conversation_entry.id = text.conversation_entry_id" - " WHERE conversation_entry.conversation_id=?" - " GROUP BY conversation_entry.id, text.id" - " ORDER BY conversation_entry.date ASC")); + " WHERE conversation_uuid=?" + " ORDER BY date ASC"; + sql::Statement statement( + GetDB().GetCachedStatement(SQL_FROM_HERE, kEntriesQuery)); + CHECK(statement.is_valid()); - statement.BindInt64(0, conversation_id); + statement.BindString(0, conversation_uuid); - std::vector history; + DVLOG(4) << __func__ << " for " << conversation_uuid; - int64_t last_conversation_entry_id = -1; + std::vector history; + // Map of editing entry id to the edit entry + std::map> edits; while (statement.Step()) { - mojom::ConversationEntryTextPtr entry_text = - mojom::ConversationEntryText::New(); - entry_text->id = statement.ColumnInt64(5); - entry_text->date = DeserializeTime(statement.ColumnInt64(6)); - entry_text->text = statement.ColumnString(7); - - int64_t current_conversation_id = statement.ColumnInt64(0); - - if (last_conversation_entry_id == current_conversation_id) { - // Update the last entry if it is the same conversation entry - history.back()->texts.emplace_back(std::move(entry_text)); - } else { - mojom::ConversationEntryPtr entry = mojom::ConversationEntry::New(); - entry->id = statement.ColumnInt64(0); - entry->date = DeserializeTime(statement.ColumnInt64(1)); - entry->character_type = - static_cast(statement.ColumnInt(2)); - entry->action_type = - static_cast(statement.ColumnInt(3)); - entry->selected_text = statement.ColumnString(4); - - // Parse search queries - std::string search_queries_all = statement.ColumnString(9); - if (!search_queries_all.empty()) { - entry->events = std::vector(); - - std::vector search_queries = - base::SplitString(search_queries_all, ",", base::TRIM_WHITESPACE, - base::SPLIT_WANT_NONEMPTY); - entry->events->emplace_back( - mojom::ConversationEntryEvent::NewSearchQueriesEvent( - mojom::SearchQueriesEvent::New(std::move(search_queries)))); + // basic metadata + std::string entry_uuid = statement.ColumnString(0); + DVLOG(4) << "Found entry row for conversation " << conversation_uuid + << " with id " << entry_uuid; + auto date = statement.ColumnTime(1); + auto text = DecryptOptionalColumnToString(statement, 2).value_or(""); + auto character_type = + static_cast(statement.ColumnInt(3)); + auto editing_entry_id = GetOptionalString(statement, 4); + auto action_type = static_cast(statement.ColumnInt(5)); + auto selected_text = DecryptOptionalColumnToString(statement, 6); + + auto entry = mojom::ConversationTurn::New( + entry_uuid, character_type, action_type, + mojom::ConversationTurnVisibility::VISIBLE, text, selected_text, + std::nullopt, date, std::nullopt, false); + + // events + struct Event { + int event_order; + mojom::ConversationEntryEventPtr event; + }; + std::vector events; + + // Completion events + { + sql::Statement event_statement( + GetDB().GetCachedStatement(SQL_FROM_HERE, + "SELECT event_order, text" + " FROM conversation_entry_event_completion" + " WHERE conversation_entry_uuid=?" + " ORDER BY event_order ASC")); + event_statement.BindString(0, entry_uuid); + + while (event_statement.Step()) { + int event_order = event_statement.ColumnInt(0); + std::string completion = DecryptColumnToString(event_statement, 1); + events.emplace_back(Event{ + event_order, mojom::ConversationEntryEvent::NewCompletionEvent( + mojom::CompletionEvent::New(completion))}); } + } - // Add the text to the new entry - entry->texts = std::vector(); - entry->texts.emplace_back(std::move(entry_text)); + // Search Query events + { + sql::Statement event_statement(GetDB().GetUniqueStatement( + "SELECT event_order, queries" + " FROM conversation_entry_event_search_queries" + " WHERE conversation_entry_uuid=?" + " ORDER BY event_order ASC")); + event_statement.BindString(0, entry_uuid); + + while (event_statement.Step()) { + int event_order = event_statement.ColumnInt(0); + auto queries_data = DecryptColumnToString(event_statement, 1); + std::vector queries = + base::SplitString(queries_data, kSearchQueriesSeparator, + base::WhitespaceHandling::TRIM_WHITESPACE, + base::SplitResult::SPLIT_WANT_NONEMPTY); + events.emplace_back(Event{ + event_order, mojom::ConversationEntryEvent::NewSearchQueriesEvent( + mojom::SearchQueriesEvent::New(queries))}); + } + } - last_conversation_entry_id = entry->id; + // insert events in order + if (!events.empty()) { + base::ranges::sort(events, [](const Event& a, const Event& b) { + return a.event_order < b.event_order; + }); + entry->events = std::vector{}; + for (auto& event : events) { + entry->events->emplace_back(std::move(event.event)); + } + } - // Add the new entry to the history + // root entry or edited entry + if (editing_entry_id.has_value()) { + DVLOG(4) << "Collected edit entry for " << editing_entry_id.value() + << " with id " << entry_uuid; + edits[editing_entry_id.value()].emplace_back(std::move(entry)); + } else { + DVLOG(4) << "Collected entry for " << entry_uuid; history.emplace_back(std::move(entry)); } } + // Reconstruct edits + for (auto& entry : history) { + CHECK(entry->uuid.has_value()); + auto id = entry->uuid.value(); + if (edits.count(id)) { + entry->edits = std::vector{}; + for (auto& edit : edits[id]) { + entry->edits->emplace_back(std::move(edit)); + } + } + } + return history; } -int64_t AIChatDatabase::AddConversation(mojom::ConversationPtr conversation) { - sql::Statement statement(GetDB().GetUniqueStatement( - "INSERT INTO conversation(id, title, " - "page_url, page_contents) VALUES(NULL, ?, ?, ?)")); +std::vector +AIChatDatabase::GetArchiveContentsForConversation( + std::string_view conversation_uuid) { + static const std::string kQuery = + "SELECT id, last_contents" + " FROM associated_content" + " WHERE conversation_uuid=?" + " AND last_contents IS NOT NULL" + " ORDER BY id ASC"; + sql::Statement statement(GetDB().GetCachedStatement(SQL_FROM_HERE, kQuery)); CHECK(statement.is_valid()); + statement.BindString(0, conversation_uuid); + std::vector archive_contents; + // We only support a single entry until ConversationHandler supports multiple + // associated contents. + if (statement.Step()) { + auto content = mojom::ContentArchive::New( + statement.ColumnInt(0), DecryptColumnToString(statement, 1)); + archive_contents.emplace_back(std::move(content)); + } + return archive_contents; +} - statement.BindString(0, conversation->title); - if (conversation->page_url.has_value()) { - statement.BindString(1, conversation->page_url->spec()); - } else { - statement.BindNull(1); +AIChatDatabase::AddConversationResult AIChatDatabase::AddConversation( + mojom::ConversationPtr conversation, + std::optional contents, + mojom::ConversationTurnPtr first_entry) { + DCHECK_CALLED_ON_VALID_SEQUENCE(sequence_checker_); + CHECK(!conversation->uuid.empty()); + CHECK(first_entry); + + sql::Transaction transaction(&GetDB()); + CHECK(GetDB().is_open()); + if (!transaction.Begin()) { + DVLOG(0) << "Transaction cannot begin"; + return AddConversationResult{}; } - if (conversation->page_contents.has_value()) { - statement.BindString(2, conversation->page_contents.value()); + static const std::string kInsertConversationQuery = + "INSERT INTO conversation(uuid, title) " + "VALUES(?, ?)"; + sql::Statement statement( + GetDB().GetUniqueStatement(kInsertConversationQuery)); + CHECK(statement.is_valid()); + + statement.BindString(0, conversation->uuid); + + if (conversation->title.empty()) { + statement.BindNull(1); } else { - statement.BindNull(2); + if (!BindAndEncryptString(statement, 1, conversation->title)) { + return AddConversationResult{}; + } } if (!statement.Run()) { - DVLOG(0) << "Failed to execute 'conversation' insert statement" + DVLOG(0) << "Failed to execute 'conversation' insert statement: " + << db_.GetErrorMessage(); + return AddConversationResult{}; + } + + AddConversationResult result; + + if (conversation->associated_content->is_content_association_possible) { + DVLOG(2) << "Adding associated content for conversation " + << conversation->uuid << " with url " + << conversation->associated_content->url->spec(); + result.associated_content_id = AddOrUpdateAssociatedContent( + conversation->uuid, std::move(conversation->associated_content), + contents); + if (result.associated_content_id == -1) { + return result; + } + } + + if (!AddConversationEntry(conversation->uuid, std::move(first_entry))) { + return AddConversationResult{}; + } + + if (!transaction.Commit()) { + DVLOG(0) << "Transaction commit failed with reason: " << db_.GetErrorMessage(); - return INT64_C(-1); + return AddConversationResult{}; } - return GetDB().GetLastInsertRowId(); + return result; } -int64_t AIChatDatabase::AddConversationEntry( - int64_t conversation_id, - mojom::ConversationEntryPtr entry) { - sql::Transaction transaction(&GetDB()); +int32_t AIChatDatabase::AddOrUpdateAssociatedContent( + std::string_view conversation_uuid, + mojom::SiteInfoPtr associated_content, + std::optional contents) { + DCHECK_CALLED_ON_VALID_SEQUENCE(sequence_checker_); + // TODO(petemill): handle multiple associated content per conversation + CHECK(!conversation_uuid.empty()); + + sql::Statement statement; + int index; + if (associated_content->id.value_or(-1) != -1) { + DVLOG(4) << "Updating associated content for conversation " + << conversation_uuid << " with id " + << associated_content->id.value(); + static const std::string kUpdateAssociatedContentQuery = + "UPDATE associated_content" + " SET title = ?," + " url = ?," + " content_type = ?," + " last_contents = ?," + " content_used_percentage = ?," + " is_content_refined = ?" + " WHERE id=?"; + statement.Assign(GetDB().GetUniqueStatement(kUpdateAssociatedContentQuery)); + statement.BindInt(6, associated_content->id.value()); + index = 0; + } else { + DVLOG(4) << "Inserting associated content for conversation " + << conversation_uuid; + static const std::string kInsertAssociatedContentQuery = + "INSERT INTO associated_content(conversation_uuid, title, url," + " content_type, last_contents, content_used_percentage," + " is_content_refined)" + " VALUES(?, ?, ?, ?, ?, ?, ?) "; + statement.Assign(GetDB().GetUniqueStatement(kInsertAssociatedContentQuery)); + statement.BindString(0, conversation_uuid); + index = 1; + } + CHECK(statement.is_valid()); + BindAndEncryptOptionalString(statement, index, associated_content->title); + index++; + BindAndEncryptOptionalString(statement, index, + associated_content->url->spec()); + index++; + statement.BindInt(index, + static_cast(associated_content->content_type)); + index++; + BindAndEncryptOptionalString(statement, index, contents); + index++; + statement.BindInt(index, associated_content->content_used_percentage); + index++; + statement.BindBool(index, associated_content->is_content_refined); - CHECK(GetDB().is_open()); + if (!statement.Run()) { + DVLOG(0) + << "Failed to execute 'associated_content' insert or update statement: " + << db_.GetErrorMessage(); + return INT32_C(-1); + } - if (!transaction.Begin()) { - DVLOG(0) << "Transaction cannot begin\n"; - return INT64_C(-1); + if (!associated_content->id.has_value()) { + return GetDB().GetLastInsertRowId(); } + return associated_content->id.value(); +} +bool AIChatDatabase::AddConversationEntry( + std::string_view conversation_uuid, + mojom::ConversationTurnPtr entry, + std::optional editing_id) { + DCHECK_CALLED_ON_VALID_SEQUENCE(sequence_checker_); + CHECK(!conversation_uuid.empty()); + CHECK(entry->uuid.has_value() && !entry->uuid->empty()); + // Verify the conversation exists + static const std::string kGetConversationIdQuery = + "SELECT uuid FROM conversation WHERE uuid=?"; sql::Statement get_conversation_id_statement( - GetDB().GetUniqueStatement("SELECT id FROM conversation" - " WHERE id=?")); + GetDB().GetCachedStatement(SQL_FROM_HERE, kGetConversationIdQuery)); CHECK(get_conversation_id_statement.is_valid()); - - get_conversation_id_statement.BindInt64(0, conversation_id); - + get_conversation_id_statement.BindString(0, conversation_uuid); if (!get_conversation_id_statement.Step()) { DVLOG(0) << "ID not found in 'conversation' table"; - return INT64_C(-1); + return false; + } + + sql::Transaction transaction(&GetDB()); + CHECK(GetDB().is_open()); + if (!transaction.Begin()) { + DVLOG(0) << "Transaction cannot begin"; + return false; } - sql::Statement insert_conversation_entry_statement(GetDB().GetUniqueStatement( - "INSERT INTO conversation_entry(id, date, " - "character_type, action_type, selected_text, " - "conversation_id) VALUES(NULL, ?, ?, ?, ?, ?)")); + sql::Statement insert_conversation_entry_statement; + + if (editing_id.has_value()) { + static const std::string kInsertEditingConversationEntryQuery = + "INSERT INTO conversation_entry(editing_entry_uuid, uuid," + " conversation_uuid, date, entry_text," + " character_type, action_type, selected_text)" + " VALUES(?, ?, ?, ?, ?, ?, ?, ?)"; + insert_conversation_entry_statement.Assign( + GetDB().GetUniqueStatement(kInsertEditingConversationEntryQuery)); + } else { + static const std::string kInsertConversationEntryQuery = + "INSERT INTO conversation_entry(uuid, conversation_uuid, date," + " entry_text, character_type, action_type, selected_text)" + " VALUES(?, ?, ?, ?, ?, ?, ?)"; + insert_conversation_entry_statement.Assign( + GetDB().GetUniqueStatement(kInsertConversationEntryQuery)); + } CHECK(insert_conversation_entry_statement.is_valid()); - // ConversationEntry's date should always match first text's date int index = 0; - insert_conversation_entry_statement.BindTimeDelta( - index++, SerializeTimeToDelta(entry->texts[0]->date)); + if (editing_id.has_value()) { + insert_conversation_entry_statement.BindString(index++, editing_id.value()); + } + insert_conversation_entry_statement.BindString(index++, entry->uuid.value()); + insert_conversation_entry_statement.BindString(index++, conversation_uuid); + insert_conversation_entry_statement.BindTime(index++, entry->created_time); + BindAndEncryptOptionalString(insert_conversation_entry_statement, index++, + entry->text); insert_conversation_entry_statement.BindInt( index++, static_cast(entry->character_type)); insert_conversation_entry_statement.BindInt( index++, static_cast(entry->action_type)); - - if (entry->selected_text.has_value()) { - insert_conversation_entry_statement.BindString( - index++, entry->selected_text.value()); - } else { - insert_conversation_entry_statement.BindNull(index++); - } - - insert_conversation_entry_statement.BindInt64(index++, conversation_id); + BindAndEncryptOptionalString(insert_conversation_entry_statement, index++, + entry->selected_text); if (!insert_conversation_entry_statement.Run()) { DVLOG(0) << "Failed to execute 'conversation_entry' insert statement: " << db_.GetErrorMessage(); - return INT64_C(-1); - } - - int64_t conversation_entry_row_id = GetDB().GetLastInsertRowId(); - - for (mojom::ConversationEntryTextPtr& text : entry->texts) { - // Add texts - AddConversationEntryText(conversation_entry_row_id, std::move(text)); + return false; } if (entry->events.has_value()) { - for (const mojom::ConversationEntryEventPtr& event : *entry->events) { - if (event->is_search_queries_event()) { - std::vector queries = - event->get_search_queries_event()->search_queries; - for (const std::string& query : queries) { - // Add search queries - AddSearchQuery(conversation_id, conversation_entry_row_id, query); + for (size_t i = 0; i < entry->events->size(); i++) { + const mojom::ConversationEntryEventPtr& event = entry->events->at(i); + switch (event->which()) { + case mojom::ConversationEntryEvent::Tag::kCompletionEvent: { + sql::Statement event_statement(GetDB().GetCachedStatement( + SQL_FROM_HERE, + "INSERT INTO conversation_entry_event_completion" + " (event_order, text, conversation_entry_uuid)" + " VALUES(?, ?, ?)")); + CHECK(event_statement.is_valid()); + event_statement.BindInt(0, static_cast(i)); + if (!BindAndEncryptString( + event_statement, 1, + event->get_completion_event()->completion)) { + return false; + } + event_statement.BindString(2, entry->uuid.value()); + event_statement.Run(); + break; + } + case mojom::ConversationEntryEvent::Tag::kSearchQueriesEvent: { + sql::Statement event_statement(GetDB().GetCachedStatement( + SQL_FROM_HERE, + "INSERT INTO conversation_entry_event_search_queries" + " (event_order, queries, conversation_entry_uuid)" + " VALUES(?, ?, ?)")); + CHECK(event_statement.is_valid()); + + std::string queries_data = base::JoinString( + event->get_search_queries_event()->search_queries, + kSearchQueriesSeparator); + + event_statement.BindInt(0, static_cast(i)); + if (!BindAndEncryptString(event_statement, 1, queries_data)) { + return false; + } + event_statement.BindString(2, entry->uuid.value()); + event_statement.Run(); + break; } + default: { + break; + } + } + } + } + + if (entry->edits.has_value()) { + for (auto& edit : entry->edits.value()) { + if (!AddConversationEntry(conversation_uuid, std::move(edit), + entry->uuid.value())) { + return false; } } } @@ -251,163 +556,359 @@ int64_t AIChatDatabase::AddConversationEntry( if (!transaction.Commit()) { DVLOG(0) << "Transaction commit failed with reason: " << db_.GetErrorMessage(); - return INT64_C(-1); + return false; } - return conversation_entry_row_id; + return true; } -int64_t AIChatDatabase::AddConversationEntryText( - int64_t conversation_entry_id, - mojom::ConversationEntryTextPtr entry_text) { - sql::Statement get_conversation_entry_statement( - GetDB().GetUniqueStatement("SELECT id FROM conversation_entry" - " WHERE id=?")); - CHECK(get_conversation_entry_statement.is_valid()); - get_conversation_entry_statement.BindInt64(0, conversation_entry_id); - - if (!get_conversation_entry_statement.Step()) { - DVLOG(0) << "ID not found in 'conversation entry' table" - << db_.GetErrorMessage(); - return INT64_C(-1); - } +bool AIChatDatabase::UpdateConversationTitle(std::string conversation_uuid, + std::string title) { + DCHECK_CALLED_ON_VALID_SEQUENCE(sequence_checker_); + static const std::string kUpdateConversationTitleQuery = + "UPDATE conversation SET title=? WHERE uuid=?"; + sql::Statement statement( + GetDB().GetCachedStatement(SQL_FROM_HERE, kUpdateConversationTitleQuery)); + CHECK(statement.is_valid()); - sql::Statement insert_text_statement( - GetDB().GetUniqueStatement("INSERT INTO " - "conversation_entry_text(" - "id, date, text, conversation_entry_id" - ") VALUES(NULL, ?, ?, ?)")); - CHECK(insert_text_statement.is_valid()); + if (!BindAndEncryptString(statement, 0, title)) { + return false; + } + statement.BindString(1, conversation_uuid); - insert_text_statement.BindTimeDelta(0, - SerializeTimeToDelta(entry_text->date)); - insert_text_statement.BindString(1, entry_text->text); - insert_text_statement.BindInt64(2, conversation_entry_id); + return statement.Run(); +} - if (!insert_text_statement.Run()) { - DVLOG(0) << "Failed to execute 'conversation_entry_text' insert statement: " - << db_.GetErrorMessage(); - return INT64_C(-1); +bool AIChatDatabase::DeleteConversation(std::string_view conversation_uuid) { + DCHECK_CALLED_ON_VALID_SEQUENCE(sequence_checker_); + sql::Transaction transaction(&db_); + if (!transaction.Begin()) { + DVLOG(0) << "Transaction cannot begin\n"; + return false; } - return GetDB().GetLastInsertRowId(); -} + // Delete all conversation entries + static const std::string kSelectConversationEntryQuery = + "SELECT uuid FROM conversation_entry WHERE conversation_uuid=?"; + sql::Statement select_conversation_entry_statement( + GetDB().GetUniqueStatement(kSelectConversationEntryQuery)); + CHECK(select_conversation_entry_statement.is_valid()); + select_conversation_entry_statement.BindString(0, conversation_uuid); -int64_t AIChatDatabase::AddSearchQuery(int64_t conversation_id, - int64_t conversation_entry_id, - const std::string& query) { - sql::Statement insert_search_query_statement(GetDB().GetUniqueStatement( - "INSERT INTO search_queries" - "(id, query, conversation_id, conversation_entry_id)" - "VALUES(NULL, ?, ?, ?)")); - CHECK(insert_search_query_statement.is_valid()); + // Delete all conversation entry events + while (select_conversation_entry_statement.Step()) { + std::string conversation_entry_uuid = + select_conversation_entry_statement.ColumnString(0); + static const std::string kDeleteCompletionEventQuery = + "DELETE FROM conversation_entry_event_completion" + " WHERE conversation_entry_uuid=?"; + sql::Statement delete_completion_event_statement( + GetDB().GetUniqueStatement(kDeleteCompletionEventQuery)); + CHECK(delete_completion_event_statement.is_valid()); + delete_completion_event_statement.BindString(0, conversation_entry_uuid); + if (!delete_completion_event_statement.Run()) { + return false; + } - int index = 0; - insert_search_query_statement.BindString(index++, query); - insert_search_query_statement.BindInt64(index++, conversation_id); - insert_search_query_statement.BindInt64(index++, conversation_entry_id); + static const std::string kDeleteSearchQueriesEventQuery = + "DELETE FROM conversation_entry_event_search_queries " + " WHERE conversation_entry_uuid=?"; + sql::Statement delete_queries_event_statement( + GetDB().GetUniqueStatement(kDeleteSearchQueriesEventQuery)); + CHECK(delete_queries_event_statement.is_valid()); + delete_queries_event_statement.BindString(0, conversation_entry_uuid); + if (!delete_queries_event_statement.Run()) { + return false; + } - if (!insert_search_query_statement.Run()) { - DVLOG(0) << "Failed to execute 'search_queries' insert statement: " - << db_.GetErrorMessage(); - return INT64_C(-1); + static const std::string kDeleteEntryQuery = + "DELETE FROM conversation_entry WHERE uuid=?"; + sql::Statement delete_conversation_entry_statement( + GetDB().GetUniqueStatement(kDeleteEntryQuery)); + CHECK(delete_conversation_entry_statement.is_valid()); + delete_conversation_entry_statement.BindString(0, conversation_entry_uuid); + if (!delete_conversation_entry_statement.Run()) { + return false; + } } - return GetDB().GetLastInsertRowId(); -} - -bool AIChatDatabase::DeleteConversation(int64_t conversation_id) { - sql::Transaction transaction(&db_); - if (!transaction.Begin()) { - DVLOG(0) << "Transaction cannot begin\n"; + // Delete the conversation metadata + static const std::string kDeleteAssociatedContentQuery = + "DELETE FROM associated_content WHERE conversation_uuid=?"; + sql::Statement delete_associated_content_statement( + GetDB().GetUniqueStatement(kDeleteAssociatedContentQuery)); + CHECK(delete_associated_content_statement.is_valid()); + delete_associated_content_statement.BindString(0, conversation_uuid); + if (!delete_associated_content_statement.Run()) { return false; } + static const std::string kDeleteConversationQuery = + "DELETE FROM conversation WHERE uuid=?"; sql::Statement delete_conversation_statement( - GetDB().GetUniqueStatement("DELETE FROM conversation" - "WHERE id=?")); - delete_conversation_statement.BindInt64(0, conversation_id); - + GetDB().GetUniqueStatement(kDeleteConversationQuery)); + CHECK(delete_conversation_statement.is_valid()); + delete_conversation_statement.BindString(0, conversation_uuid); if (!delete_conversation_statement.Run()) { return false; } - sql::Statement select_conversation_entry_statement( - GetDB().GetUniqueStatement("SELECT id FROM conversation_entry" - "WHERE conversation_id=?")); + if (!transaction.Commit()) { + DVLOG(0) << "Transaction commit failed with reason: " + << db_.GetErrorMessage(); + return false; + } + return true; +} - // We remove all conversation text associated to |conversation_id| - while (select_conversation_entry_statement.Step()) { - sql::Statement delete_conversation_text_statement( - GetDB().GetUniqueStatement("DELETE FROM conversation_entry_text" - "WHERE conversation_entry_id=?")); - delete_conversation_text_statement.BindInt64( - 0, select_conversation_entry_statement.ColumnInt64(0)); +bool AIChatDatabase::DeleteConversationEntry( + std::string conversation_entry_uuid) { + DCHECK_CALLED_ON_VALID_SEQUENCE(sequence_checker_); + sql::Transaction transaction(&db_); + CHECK(!conversation_entry_uuid.empty()); + if (!transaction.Begin()) { + DVLOG(0) << "Transaction cannot begin\n"; + return false; + } + // Delete from conversation_entry_event_completion + { + sql::Statement delete_statement(GetDB().GetUniqueStatement( + "DELETE FROM conversation_entry_event_completion WHERE " + "conversation_entry_uuid=?")); + delete_statement.BindString(0, conversation_entry_uuid); + if (!delete_statement.Run()) { + DLOG(ERROR) + << "Failed to delete from conversation_entry_event_completion " + "for id: " + << conversation_entry_uuid; + return false; + } + } - if (!delete_conversation_text_statement.Run()) { + // Delete from conversation_entry_event_search_queries + { + static const std::string kQuery = + "DELETE FROM conversation_entry_event_search_queries WHERE " + "conversation_entry_uuid=?"; + sql::Statement delete_statement(GetDB().GetUniqueStatement(kQuery)); + CHECK(delete_statement.is_valid()); + delete_statement.BindString(0, conversation_entry_uuid); + if (!delete_statement.Run()) { + DLOG(ERROR) << "Failed to delete from " + "conversation_entry_event_search_queries for conversation " + "entry uuid: " + << conversation_entry_uuid; return false; } } - sql::Statement delete_conversation_entry_statement( - GetDB().GetUniqueStatement("DELETE FROM conversation_entry" - "WHERE conversation_entry_id=?")); - delete_conversation_entry_statement.BindInt64(0, conversation_id); + // Delete edits + { + static const std::string kQuery = + "DELETE FROM conversation_entry WHERE editing_entry_uuid = ?"; + sql::Statement delete_statement(GetDB().GetUniqueStatement(kQuery)); + CHECK(delete_statement.is_valid()); + delete_statement.BindString(0, conversation_entry_uuid); + if (!delete_statement.Run()) { + DLOG(ERROR) << "Failed to delete from conversation_entry for " + "conversation entry uuid: " + << conversation_entry_uuid; + return false; + } + } - // At last we remove all conversation entries associated to |conversation_id| - if (!delete_conversation_entry_statement.Run()) { - return false; + // Delete from conversation_entry + { + static const std::string kQuery = + "DELETE FROM conversation_entry WHERE uuid=?"; + sql::Statement delete_statement(GetDB().GetUniqueStatement(kQuery)); + CHECK(delete_statement.is_valid()); + delete_statement.BindString(0, conversation_entry_uuid); + if (!delete_statement.Run()) { + LOG(ERROR) << "Failed to delete from conversation_entry for id: " + << conversation_entry_uuid; + return false; + } } - transaction.Commit(); + if (!transaction.Commit()) { + DVLOG(0) << "Transaction commit failed with reason: " + << db_.GetErrorMessage(); + return false; + } return true; } -bool AIChatDatabase::DropAllTables() { - return GetDB().Execute("DROP TABLE conversation") && - GetDB().Execute("DROP TABLE conversation_entry") && - GetDB().Execute("DROP TABLE conversation_entry_text"); +bool AIChatDatabase::DeleteAllData() { + DCHECK_CALLED_ON_VALID_SEQUENCE(sequence_checker_); + static const std::string kDropConversationTableQuery = + "DROP TABLE conversation"; + static const std::string kDropAssociatedContentTableQuery = + "DROP TABLE associated_content"; + static const std::string kDropConversationEntryTableQuery = + "DROP TABLE conversation_entry"; + static const std::string kDropCompletionEventTableQuery = + "DROP TABLE conversation_entry_event_completion"; + static const std::string kDropSearchQueriesEventTableQuery = + "DROP TABLE conversation_entry_event_search_queries"; + + return GetDB().Execute(kDropConversationTableQuery) && + GetDB().Execute(kDropAssociatedContentTableQuery) && + GetDB().Execute(kDropConversationEntryTableQuery) && + GetDB().Execute(kDropCompletionEventTableQuery) && + GetDB().Execute(kDropSearchQueriesEventTableQuery) && + CreateConversationTable() && CreateAssociatedContentTable() && + CreateConversationEntryTable() && CreateConversationEntryTextTable() && + CreateSearchQueriesTable(); } sql::Database& AIChatDatabase::GetDB() { + DCHECK_CALLED_ON_VALID_SEQUENCE(sequence_checker_); return db_; } +std::string AIChatDatabase::DecryptColumnToString(sql::Statement& statement, + int index) { + DCHECK_CALLED_ON_VALID_SEQUENCE(sequence_checker_); + auto decrypted_value = encryptor_.DecryptData(statement.ColumnBlob(index)); + if (!decrypted_value) { + DVLOG(0) << "Failed to decrypt value"; + return ""; + } + return *decrypted_value; +} + +std::optional AIChatDatabase::DecryptOptionalColumnToString( + sql::Statement& statement, + int index) { + DCHECK_CALLED_ON_VALID_SEQUENCE(sequence_checker_); + auto column_type = statement.GetColumnType(index); + // Don't allow non-BLOB types + if (column_type != sql::ColumnType::kBlob) { + return std::nullopt; + } + auto decrypted_value = encryptor_.DecryptData(statement.ColumnBlob(index)); + if (!decrypted_value) { + DVLOG(0) << "Failed to decrypt value"; + return std::nullopt; + } + return std::make_optional(*decrypted_value); +} + +void AIChatDatabase::BindAndEncryptOptionalString( + sql::Statement& statement, + int index, + const std::optional& value) { + DCHECK_CALLED_ON_VALID_SEQUENCE(sequence_checker_); + if (value.has_value() && !value.value().empty()) { + auto encrypted_value = encryptor_.EncryptString(value.value()); + if (!encrypted_value) { + DVLOG(0) << "Failed to encrypt value"; + statement.BindNull(index); + return; + } + statement.BindBlob(index, *encrypted_value); + } else { + statement.BindNull(index); + } +} + +bool AIChatDatabase::BindAndEncryptString(sql::Statement& statement, + int index, + const std::string& value) { + DCHECK_CALLED_ON_VALID_SEQUENCE(sequence_checker_); + auto encrypted_value = encryptor_.EncryptString(value); + if (!encrypted_value) { + DVLOG(0) << "Failed to encrypt value"; + return false; + } + statement.BindBlob(index, *encrypted_value); + return true; +} + bool AIChatDatabase::CreateConversationTable() { - return GetDB().Execute( + static const std::string kCreateConversationTableQuery = "CREATE TABLE IF NOT EXISTS conversation(" - "id INTEGER PRIMARY KEY," - "title TEXT," - "page_url TEXT," - "page_contents TEXT)"); + "uuid TEXT PRIMARY KEY NOT NULL," + "title BLOB)"; + CHECK(GetDB().IsSQLValid(kCreateConversationTableQuery)); + return GetDB().Execute(kCreateConversationTableQuery); } +// AssociatedContent is 1:many with Conversation for future-proofing when +// we support multiple associated contents per conversation. +bool AIChatDatabase::CreateAssociatedContentTable() { + static const std::string kCreateAssociatedContentTableQuery = + "CREATE TABLE IF NOT EXISTS associated_content(" + "id INTEGER PRIMARY KEY NOT NULL," + "conversation_uuid TEXT NOT NULL," + // Encrypted title string + "title BLOB NOT NULL," + // Encrypted url string + "url BLOB NOT NULL," + // Stores SiteInfo.IsVideo. Future-proofed for multiple content types + // 0 for regular content + // 1 for video. + "content_type INTEGER NOT NULL," + // Encrypted string value of the content, so that conversations can be + // continued. + "last_contents BLOB," + // Don't need REAL for content_used_percentage since + // we're never using decimal values. + "content_used_percentage INTEGER NOT NULL," + "is_content_refined INTEGER NOT NULL)"; + CHECK(GetDB().IsSQLValid(kCreateAssociatedContentTableQuery)); + return GetDB().Execute(kCreateAssociatedContentTableQuery); +} + +// AKA ConversationTurn in mojom bool AIChatDatabase::CreateConversationEntryTable() { - return GetDB().Execute( + static const std::string kCreateConversationEntryTableQuery = "CREATE TABLE IF NOT EXISTS conversation_entry(" - "id INTEGER PRIMARY KEY," + "uuid TEXT PRIMARY KEY NOT NULL," + "conversation_uuid STRING NOT NULL," "date INTEGER NOT NULL," + // Encrypted text string + // TODO(petemill): move to event only + "entry_text BLOB," "character_type INTEGER NOT NULL," - "action_type INTEGER NOT NULL," - "selected_text TEXT," - "conversation_id INTEGER NOT NULL)"); + // editing_entry points to the ConversationEntry row that is being edited. + // Edits can be sorted by date. + "editing_entry_uuid TEXT," + "action_type INTEGER," + // Encrypted selected text + "selected_text BLOB)"; + // TODO(petemill): Forking can be achieved by associating each + // ConversationEntry with a parent ConversationEntry. + // TODO(petemill): Store a model name with each entry to know when + // a model was changed for a conversation, or for forking-by-model features. + CHECK(GetDB().IsSQLValid(kCreateConversationEntryTableQuery)); + return GetDB().Execute(kCreateConversationEntryTableQuery); } bool AIChatDatabase::CreateConversationEntryTextTable() { - return GetDB().Execute( - "CREATE TABLE IF NOT EXISTS conversation_entry_text(" - "id INTEGER PRIMARY KEY," - "date INTEGER NOT NULL," - "text TEXT NOT NULL," - "conversation_entry_id INTEGER NOT NULL)"); + static const std::string kCreateConversationEntryTextTableQuery = + "CREATE TABLE IF NOT EXISTS conversation_entry_event_completion(" + "conversation_entry_uuid INTEGER NOT NULL," + "event_order INTEGER NOT NULL," + // encrypted event text string + "text BLOB NOT NULL," + "PRIMARY KEY(conversation_entry_uuid, event_order)" + ")"; + CHECK(GetDB().IsSQLValid(kCreateConversationEntryTextTableQuery)); + return GetDB().Execute(kCreateConversationEntryTextTableQuery); } bool AIChatDatabase::CreateSearchQueriesTable() { - return GetDB().Execute( - "CREATE TABLE IF NOT EXISTS search_queries(" - "id INTEGER PRIMARY KEY," - "query TEXT NOT NULL," - "conversation_id INTEGER NOT NULL," - "conversation_entry_id INTEGER NOT NULL)"); + static const std::string kCreateSearchQueriesTableQuery = + "CREATE TABLE IF NOT EXISTS conversation_entry_event_search_queries(" + "conversation_entry_uuid INTEGER NOT NULL," + "event_order INTEGER NOT NULL," + // encrypted delimited search query strings + "queries BLOB NOT NULL," + "PRIMARY KEY(conversation_entry_uuid, event_order)" + ")"; + CHECK(GetDB().IsSQLValid(kCreateSearchQueriesTableQuery)); + return GetDB().Execute(kCreateSearchQueriesTableQuery); } } // namespace ai_chat diff --git a/components/ai_chat/core/browser/ai_chat_database.h b/components/ai_chat/core/browser/ai_chat_database.h index 38c9931ac597..fa466f28f5f8 100644 --- a/components/ai_chat/core/browser/ai_chat_database.h +++ b/components/ai_chat/core/browser/ai_chat_database.h @@ -6,11 +6,15 @@ #ifndef BRAVE_COMPONENTS_AI_CHAT_CORE_BROWSER_AI_CHAT_DATABASE_H_ #define BRAVE_COMPONENTS_AI_CHAT_CORE_BROWSER_AI_CHAT_DATABASE_H_ +#include +#include #include -#include "brave/components/ai_chat/core/common/mojom/ai_chat.mojom.h" +#include "base/sequence_checker.h" +#include "base/thread_annotations.h" +#include "brave/components/ai_chat/core/common/mojom/ai_chat.mojom-forward.h" +#include "components/os_crypt/async/common/encryptor.h" #include "sql/database.h" -#include "sql/init_status.h" namespace sql { class Database; @@ -18,38 +22,100 @@ class Database; namespace ai_chat { +// Persists AI Chat conversations and associated content. Conversations are +// mainly formed of their conversation entries. Edits to conversation entries +// should be handled with removal and re-adding so that other classes can make +// decisions about how it affects the rest of history. +// All data should be stored encrypted. class AIChatDatabase { public: - AIChatDatabase(); + struct AddConversationResult { + int32_t associated_content_id = -1; + }; + + AIChatDatabase(const base::FilePath& storage_dir, + os_crypt_async::Encryptor encryptor); AIChatDatabase(const AIChatDatabase&) = delete; AIChatDatabase& operator=(const AIChatDatabase&) = delete; ~AIChatDatabase(); - bool Init(const base::FilePath& db_file_path); + bool IsInitialized() const; + // Gets lightweight metadata for all conversations. No high-memory-consuming + // data is returned. std::vector GetAllConversations(); - std::vector GetConversationEntries( - int64_t conversation_id); - int64_t AddConversation(mojom::ConversationPtr conversation); - int64_t AddConversationEntry(int64_t conversation_id, - mojom::ConversationEntryPtr entry); - int64_t AddConversationEntryText(int64_t conversation_entry_id, - mojom::ConversationEntryTextPtr entry_text); - int64_t AddSearchQuery(int64_t conversation_id, - int64_t conversation_entry_id, - const std::string& query); - bool DeleteConversation(int64_t conversation_id); - bool DropAllTables(); + + // Gets all data needed to rehydrate a conversation + mojom::ConversationArchivePtr GetConversationData( + std::string_view conversation_uuid); + + // Returns new ID for the provided entry and any provided associated content + AddConversationResult AddConversation(mojom::ConversationPtr conversation, + std::optional contents, + mojom::ConversationTurnPtr first_entry); + + // Update any properties of associated content metadata or full-text content + int32_t AddOrUpdateAssociatedContent(std::string_view conversation_uuid, + mojom::SiteInfoPtr associated_content, + std::optional content); + + // Adds a new conversation entry to the conversation with the provided UUID + bool AddConversationEntry( + std::string_view conversation_uuid, + mojom::ConversationTurnPtr entry, + std::optional editing_id = std::nullopt); + + // Updates the title of the conversation with the provided UUID + bool UpdateConversationTitle(std::string conversation_uuid, + std::string title); + + // Deletes the conversation with the provided UUID + bool DeleteConversation(std::string_view conversation_uuid); + + // Deletes the conversation entry with the provided ID and all associated + // edits and events. + bool DeleteConversationEntry(std::string conversation_entry_uuid); + + // Drops all data and tables in the database, and re-creates empty tables + bool DeleteAllData(); private: + friend class AIChatDatabaseTest; + sql::Database& GetDB(); + bool Init(const base::FilePath& db_file_path); + + std::vector GetConversationEntries( + std::string_view conversation_id); + std::vector GetArchiveContentsForConversation( + std::string_view conversation_uuid); + + std::string DecryptColumnToString(sql::Statement& statement, int index); + std::optional DecryptOptionalColumnToString( + sql::Statement& statement, + int index); + void BindAndEncryptOptionalString(sql::Statement& statement, + int index, + const std::optional& value); + bool BindAndEncryptString(sql::Statement& statement, + int index, + const std::string& value); + bool CreateConversationTable(); + bool CreateAssociatedContentTable(); bool CreateConversationEntryTable(); bool CreateConversationEntryTextTable(); bool CreateSearchQueriesTable(); - sql::Database db_; + bool is_initialized_ GUARDED_BY_CONTEXT(sequence_checker_) = false; + + // The underlying SQL database + sql::Database db_ GUARDED_BY_CONTEXT(sequence_checker_); + os_crypt_async::Encryptor encryptor_ GUARDED_BY_CONTEXT(sequence_checker_); + + // Verifies that all operations happen on the same sequence. + SEQUENCE_CHECKER(sequence_checker_); }; } // namespace ai_chat diff --git a/components/ai_chat/core/browser/ai_chat_database_unittest.cc b/components/ai_chat/core/browser/ai_chat_database_unittest.cc index b16e3b6f2f42..86c0537f0522 100644 --- a/components/ai_chat/core/browser/ai_chat_database_unittest.cc +++ b/components/ai_chat/core/browser/ai_chat_database_unittest.cc @@ -1,188 +1,330 @@ -/* Copyright (c) 2023 The Brave Authors. All rights reserved. - * This Source Code Form is subject to the terms of the Mozilla Public - * License, v. 2.0. If a copy of the MPL was not distributed with this file, - * You can obtain one at https://mozilla.org/MPL/2.0/. */ +// Copyright (c) 2024 The Brave Authors. All rights reserved. +// This Source Code Form is subject to the terms of the Mozilla Public +// License, v. 2.0. If a copy of the MPL was not distributed with this file, +// You can obtain one at https://mozilla.org/MPL/2.0/. #include "brave/components/ai_chat/core/browser/ai_chat_database.h" #include #include +#include +#include #include #include +#include "base/containers/flat_tree.h" #include "base/files/file_path.h" #include "base/files/file_util.h" #include "base/files/scoped_temp_dir.h" #include "base/path_service.h" +#include "base/run_loop.h" +#include "base/strings/strcat.h" +#include "base/strings/stringprintf.h" +#include "base/test/bind.h" +#include "base/test/task_environment.h" +#include "base/time/time.h" +#include "base/uuid.h" +#include "brave/components/ai_chat/core/browser/test_utils.h" +#include "brave/components/ai_chat/core/common/mojom/ai_chat.mojom-forward.h" +#include "brave/components/ai_chat/core/common/mojom/ai_chat.mojom-shared.h" +#include "brave/components/ai_chat/core/common/mojom/ai_chat.mojom.h" +#include "components/os_crypt/async/browser/test_utils.h" #include "sql/test/test_helpers.h" #include "testing/gtest/include/gtest/gtest.h" -namespace { -int64_t GetInternalValue(const base::Time& time) { - return time.ToDeltaSinceWindowsEpoch().InMicroseconds(); -} -} // namespace - namespace ai_chat { - -class AIChatDatabaseTest : public testing::Test { +class AIChatDatabaseTest : public testing::Test, + public testing::WithParamInterface { public: AIChatDatabaseTest() = default; void SetUp() override { - ASSERT_TRUE(temp_directory_.CreateUniqueTempDir()); - // We take ownership of the path for adhoc inspection of the spitted - // db. This is useful to test sql queries. If you do this, make sure to - // comment out the base::DeletePathRecursively line. - path_ = temp_directory_.Take(); - db_ = std::make_unique(); - ASSERT_TRUE(db_->Init(db_file_path())); + CHECK(temp_directory_.CreateUniqueTempDir()); + os_crypt_ = os_crypt_async::GetTestOSCryptAsyncForTesting( + /*is_sync_for_unittests=*/true); + + // Create database when os_crypt is ready + base::RunLoop run_loop; + encryptor_ready_subscription_ = + os_crypt_->GetInstance(base::BindLambdaForTesting( + [&](os_crypt_async::Encryptor encryptor, bool success) { + ASSERT_TRUE(success); + db_ = std::make_unique(db_file_path(), + std::move(encryptor)); + run_loop.Quit(); + })); + run_loop.Run(); + + ASSERT_TRUE(IsOpen()); + + if (GetParam()) { + db_->DeleteAllData(); + } } void TearDown() override { db_.reset(); - ASSERT_TRUE(base::DeletePathRecursively(path_)); + CHECK(temp_directory_.Delete()); } - base::FilePath db_file_path() { return path_.AppendASCII("ai_chat"); } + bool IsOpen() { return db_->GetDB().is_open(); } + + base::FilePath db_file_path() { + return temp_directory_.GetPath().AppendASCII("ai_chat"); + } protected: + base::test::TaskEnvironment task_environment_{ + base::test::TaskEnvironment::TimeSource::MOCK_TIME}; base::ScopedTempDir temp_directory_; + std::unique_ptr os_crypt_; + base::CallbackListSubscription encryptor_ready_subscription_; std::unique_ptr db_; base::FilePath path_; }; -TEST_F(AIChatDatabaseTest, AddConversations) { - int64_t first_id = db_->AddConversation(mojom::Conversation::New( - INT64_C(1), base::Time::Now(), "Initial conversation", GURL(), - "Sample page content")); - EXPECT_EQ(first_id, INT64_C(1)); +INSTANTIATE_TEST_SUITE_P( + , + AIChatDatabaseTest, + // Run all tests with the initial schema and the schema created + // after calling DeleteAllData, to verify the schemas are the same + // and no tables are missing or different. + ::testing::Bool(), + [](const testing::TestParamInfo& info) { + return base::StringPrintf("DropTablesFirst_%s", + info.param ? "Yes" : "No"); + }); + +// Functions tested: +// - AddConversation +// - GetAllConversations +// - GetConversationData +// - AddConversationEntry +// - DeleteConversationEntry +// - DeleteConversation +TEST_P(AIChatDatabaseTest, AddAndGetConversationAndEntries) { + auto now = base::Time::Now(); + + // Do this for both associated content and non-associated content + for (bool has_content : {true, false}) { + std::string uuid = has_content ? "first" : "second"; + // Create the conversation metadata which gets persisted + // when the first entry is asked to be persisted. + // Put an incorrect time value to show that the time from the + // mojom::Conversation is not persisted and instead is taken from the most + // recent entry. + GURL page_url = GURL("https://example.com/page"); + std::string expected_contents = "Page contents"; + mojom::SiteInfoPtr associated_content = + has_content + ? mojom::SiteInfo::New( + std::nullopt, mojom::ContentType::PageContent, "page title", + page_url.host(), page_url, 62, true, true) + : mojom::SiteInfo::New( + std::nullopt, mojom::ContentType::PageContent, std::nullopt, + std::nullopt, std::nullopt, 0, false, false); + mojom::ConversationPtr metadata = + mojom::Conversation::New(uuid, "title", now - base::Hours(2), true, + std::move(associated_content)); + + // Persist the first entry (and get the response ready) + auto history = CreateSampleChatHistory(1u); + + auto ids = db_->AddConversation( + metadata->Clone(), + has_content ? std::make_optional(expected_contents) : std::nullopt, + history[0]->Clone()); + + // Verify database thinks it was successful + if (has_content) { + EXPECT_NE(ids.associated_content_id, -1); + } else { + EXPECT_EQ(ids.associated_content_id, -1); + } + + // Test getting the conversation metadata + std::vector conversations = + db_->GetAllConversations(); + EXPECT_EQ(conversations.size(), has_content ? 1u : 2u); + auto& conversation = has_content ? conversations[0] : conversations[1]; + ExpectConversationEquals(FROM_HERE, conversation, metadata); + EXPECT_EQ(conversation->updated_time, history.front()->created_time); + + // Persist the response entry + EXPECT_TRUE(db_->AddConversationEntry(uuid, history[1]->Clone())); + + // Test getting the conversation entries + mojom::ConversationArchivePtr result = db_->GetConversationData(uuid); + ExpectConversationHistoryEquals(FROM_HERE, result->entries, history); + EXPECT_EQ(result->associated_content.size(), has_content ? 1u : 0u); + if (has_content) { + EXPECT_EQ(result->associated_content[0]->content_id, + ids.associated_content_id); + EXPECT_EQ(result->associated_content[0]->content, expected_contents); + } + + // Add another pair of entries + auto next_history = CreateSampleChatHistory(1u, 1); + EXPECT_TRUE(db_->AddConversationEntry(uuid, next_history[0]->Clone())); + EXPECT_TRUE(db_->AddConversationEntry(uuid, next_history[1]->Clone())); + + // Verify all entries are returned + mojom::ConversationArchivePtr result_2 = db_->GetConversationData(uuid); + for (auto& entry : next_history) { + history.push_back(std::move(entry)); + } + ExpectConversationHistoryEquals(FROM_HERE, result_2->entries, history); + + // Edits (delete, re-add and check edit re-construction) + + // Delete the last response + EXPECT_TRUE( + db_->DeleteConversationEntry(result_2->entries.back()->uuid.value())); + + // Verify the last entry is gone + history.erase(history.end() - 1); + mojom::ConversationArchivePtr result_3 = db_->GetConversationData(uuid); + ExpectConversationHistoryEquals(FROM_HERE, result_3->entries, history); + + // Add an edit to the last query + { + auto& last_query = result_3->entries.back(); + last_query->edits = std::vector{}; + last_query->edits->emplace_back(mojom::ConversationTurn::New( + base::Uuid::GenerateRandomV4().AsLowercaseString(), + mojom::CharacterType::HUMAN, mojom::ActionType::QUERY, + mojom::ConversationTurnVisibility::VISIBLE, "edited query 1", + std::nullopt, std::nullopt, base::Time::Now() + base::Minutes(121), + std::nullopt, false)); + EXPECT_TRUE(db_->DeleteConversationEntry(last_query->uuid.value())); + EXPECT_TRUE(db_->AddConversationEntry(uuid, last_query->Clone())); + } + // Verify the edit is persisted + mojom::ConversationArchivePtr result_4 = db_->GetConversationData(uuid); + ExpectConversationHistoryEquals(FROM_HERE, result_4->entries, + result_3->entries); + + // Add another edit to test multiple edits for the same turn + { + auto& last_query = result_4->entries.back(); + last_query->edits->emplace_back(mojom::ConversationTurn::New( + base::Uuid::GenerateRandomV4().AsLowercaseString(), + mojom::CharacterType::HUMAN, mojom::ActionType::QUERY, + mojom::ConversationTurnVisibility::VISIBLE, "edited query 2", + std::nullopt, std::nullopt, base::Time::Now() + base::Minutes(122), + std::nullopt, false)); + EXPECT_TRUE(db_->DeleteConversationEntry(last_query->uuid.value())); + EXPECT_TRUE(db_->AddConversationEntry(uuid, last_query->Clone())); + } + // Verify multiple edits are persisted + mojom::ConversationArchivePtr result_5 = db_->GetConversationData(uuid); + ExpectConversationHistoryEquals(FROM_HERE, result_5->entries, + result_4->entries); + } + + // Test deleting conversation (after loop so that we can test conversation + // entry selection with multiple conversations in the database). + EXPECT_TRUE(db_->DeleteConversation("second")); + // Verify no data for deleted conversation + mojom::ConversationArchivePtr conversation_data = + db_->GetConversationData("second"); + EXPECT_EQ(conversation_data->entries.size(), 0u); + EXPECT_EQ(conversation_data->associated_content.size(), 0u); + // Verify deleted conversation metadata not returned std::vector conversations = db_->GetAllConversations(); - EXPECT_EQ(conversations.size(), UINT64_C(1)); - int64_t second_id = db_->AddConversation(mojom::Conversation::New( - INT64_C(2), base::Time::Now(), "Another conversation", GURL(), - "Another sample page content")); - EXPECT_EQ(second_id, INT64_C(2)); - std::vector conversations_updated = - db_->GetAllConversations(); - EXPECT_EQ(conversations_updated.size(), UINT64_C(2)); + EXPECT_EQ(conversations.size(), 1u); + EXPECT_EQ(conversations[0]->uuid, "first"); + // Verify there's still data for other conversations + mojom::ConversationArchivePtr conversation_data_2 = + db_->GetConversationData("first"); + EXPECT_GT(conversation_data_2->entries.size(), 0u); + EXPECT_EQ(conversation_data_2->associated_content.size(), 1u); + // Delete last conversation + EXPECT_TRUE(db_->DeleteConversation("first")); + conversations = db_->GetAllConversations(); + EXPECT_EQ(conversations.size(), 0u); } -TEST_F(AIChatDatabaseTest, AddConversationEntries) { - static const int64_t kConversationId = INT64_C(1); - static const int64_t kConversationEntryId = INT64_C(1); - - static const GURL kURL = GURL("https://brave.com/"); - static const char kConversationTitle[] = "Summarizing Brave"; - static const char kPageContents[] = "Brave is a web browser."; - static const char kFirstResponse[] = "This is a generated response"; - static const char kSecondResponse[] = "This is a re-generated response"; - - static const base::Time kFirstTextCreatedAt(base::Time::Now()); - static const base::Time kSecondTextCreatedAt(base::Time::Now() + - base::Minutes(5)); - - // Add conversation - int64_t conversation_id = db_->AddConversation( - mojom::Conversation::New(kConversationId, kFirstTextCreatedAt, - kConversationTitle, kURL, kPageContents)); - EXPECT_EQ(conversation_id, kConversationId); - - // Add first conversation entry - { - std::vector texts; - texts.emplace_back(mojom::ConversationEntryText::New(1, kFirstTextCreatedAt, - kFirstResponse)); - - std::vector search_queries{"brave", "browser", "web"}; - std::vector events; - events.emplace_back(mojom::ConversationEntryEvent::NewSearchQueriesEvent( - mojom::SearchQueriesEvent::New(std::move(search_queries)))); - - int64_t entry_id = db_->AddConversationEntry( - conversation_id, - mojom::ConversationEntry::New( - INT64_C(-1), kFirstTextCreatedAt, mojom::CharacterType::ASSISTANT, - mojom::ActionType::UNSPECIFIED, std::nullopt /* selected_text */, - std::move(texts), std::move(events))); - EXPECT_EQ(entry_id, kConversationEntryId); - } +TEST_P(AIChatDatabaseTest, UpdateConversationTitle) { + std::vector initial_titles = {"first title", ""}; + for (const auto& initial_title : initial_titles) { + const std::string uuid = + base::StrCat({"for_conversation_title_", initial_title}); + const std::string updated_title = "updated title"; + mojom::ConversationPtr metadata = mojom::Conversation::New( + uuid, initial_title, base::Time::Now(), true, + mojom::SiteInfo::New(std::nullopt, mojom::ContentType::PageContent, + std::nullopt, std::nullopt, std::nullopt, 0, false, + false)); + + // Persist the first entry (and get the response ready) + auto history = CreateSampleChatHistory(1u); - // Get and verify conversations - { + db_->AddConversation(metadata->Clone(), std::nullopt, history[0]->Clone()); + + // Verify initial title std::vector conversations = db_->GetAllConversations(); - EXPECT_FALSE(conversations.empty()); - EXPECT_EQ(conversations[0]->id, INT64_C(1)); - EXPECT_EQ(conversations[0]->title, kConversationTitle); - EXPECT_EQ(conversations[0]->page_url->spec(), kURL.spec()); - EXPECT_EQ(conversations[0]->page_contents, kPageContents); - // A conversation is created with the first entry's date - EXPECT_EQ(GetInternalValue(conversations[0]->date), - GetInternalValue(kFirstTextCreatedAt)); - } + // get this conversation + auto* conversation = GetConversation(FROM_HERE, conversations, uuid); + EXPECT_EQ(conversation->title, initial_title); - // Get and verify conversation entries - { - std::vector entries = - db_->GetConversationEntries(conversation_id); - EXPECT_EQ(UINT64_C(1), entries.size()); - EXPECT_EQ(entries[0]->character_type, mojom::CharacterType::ASSISTANT); - EXPECT_EQ(GetInternalValue(entries[0]->date), - GetInternalValue(kFirstTextCreatedAt)); - EXPECT_EQ(UINT64_C(1), entries[0]->texts.size()); - EXPECT_EQ(entries[0]->texts[0]->text, kFirstResponse); - EXPECT_EQ(entries[0]->action_type, mojom::ActionType::UNSPECIFIED); - EXPECT_TRUE(entries[0]->selected_text.value().empty()); - - // Verify search queries - EXPECT_TRUE(entries.front()->events.has_value()); - EXPECT_EQ(UINT64_C(3), entries[0] - ->events.value()[0] - ->get_search_queries_event() - ->search_queries.size()); + // Update title + EXPECT_TRUE(db_->UpdateConversationTitle(uuid, updated_title)); + // Verify + conversations = db_->GetAllConversations(); + conversation = GetConversation(FROM_HERE, conversations, uuid); + EXPECT_EQ(conversation->title, updated_title); } +} - // Add another text to the first entry - db_->AddConversationEntryText( - UINT64_C(1), mojom::ConversationEntryText::New( - UINT64_C(1), kSecondTextCreatedAt, kSecondResponse)); +TEST_P(AIChatDatabaseTest, AddOrUpdateAssociatedContent) { + std::string uuid = "for_associated_content"; + GURL page_url = GURL("https://example.com/page"); + mojom::ConversationPtr metadata = mojom::Conversation::New( + uuid, "title", base::Time::Now() - base::Hours(2), true, + mojom::SiteInfo::New(std::nullopt, mojom::ContentType::PageContent, + "page title", page_url.host(), page_url, 62, true, + true)); - { - auto entries = db_->GetConversationEntries(conversation_id); - EXPECT_EQ(UINT64_C(2), entries[0]->texts.size()); - EXPECT_EQ(entries[0]->texts[1]->text, kSecondResponse); - } + auto history = CreateSampleChatHistory(1u); - // Add another entry - { - base::Time created_at(base::Time::Now()); - std::string text("Is this a question?"); - std::string selected_text("The brown fox jumps over the lazy dog"); + std::string expected_contents = "First contents"; + AIChatDatabase::AddConversationResult ids = db_->AddConversation( + metadata->Clone(), std::make_optional(expected_contents), + history[0]->Clone()); - std::vector texts; - texts.emplace_back(mojom::ConversationEntryText::New(1, created_at, text)); + EXPECT_GT(ids.associated_content_id, -1); - int64_t entry_id = db_->AddConversationEntry( - conversation_id, - mojom::ConversationEntry::New( - INT64_C(-1), created_at, mojom::CharacterType::HUMAN, - mojom::ActionType::CASUALIZE, selected_text, std::move(texts), - std::nullopt /* search_query */)); + metadata->associated_content->id = ids.associated_content_id; - EXPECT_EQ(entry_id, 2); - } + // Verify data is persisted + mojom::ConversationArchivePtr result = db_->GetConversationData(uuid); + EXPECT_EQ(result->associated_content.size(), 1u); + EXPECT_EQ(result->associated_content[0]->content_id, + ids.associated_content_id); + EXPECT_EQ(result->associated_content[0]->content, expected_contents); + auto conversations = db_->GetAllConversations(); + EXPECT_EQ(conversations.size(), 1u); + ExpectConversationEquals(FROM_HERE, conversations[0], metadata); - // Get and verify conversation entries - { - const std::vector& entries = - db_->GetConversationEntries(conversation_id); - EXPECT_EQ(UINT64_C(2), entries.size()); - EXPECT_EQ(entries[1]->selected_text.value(), - "The brown fox jumps over the lazy dog"); - } + // Change data and call AddOrUpdateAssociatedContent + expected_contents = "Second contents"; + metadata->associated_content->content_used_percentage = 50; + metadata->associated_content->is_content_refined = false; + db_->AddOrUpdateAssociatedContent(uuid, metadata->associated_content->Clone(), + std::make_optional(expected_contents)); + // Verify data is changed + result = db_->GetConversationData(uuid); + EXPECT_EQ(result->associated_content.size(), 1u); + EXPECT_EQ(result->associated_content[0]->content_id, + metadata->associated_content->id.value()); + EXPECT_EQ(result->associated_content[0]->content, expected_contents); + conversations = db_->GetAllConversations(); + EXPECT_EQ(conversations.size(), 1u); + ExpectConversationEquals(FROM_HERE, conversations[0], metadata); } } // namespace ai_chat diff --git a/components/ai_chat/core/browser/ai_chat_service.cc b/components/ai_chat/core/browser/ai_chat_service.cc index cba85d97a89b..73833cbe6e86 100644 --- a/components/ai_chat/core/browser/ai_chat_service.cc +++ b/components/ai_chat/core/browser/ai_chat_service.cc @@ -11,12 +11,17 @@ #include #include +#include "base/functional/callback_helpers.h" +#include "base/notreached.h" #include "base/task/sequenced_task_runner.h" #include "base/task/single_thread_task_runner.h" #include "base/task/task_runner.h" +#include "base/task/task_traits.h" +#include "base/task/thread_pool.h" #include "base/time/time.h" #include "base/uuid.h" #include "brave/components/ai_chat/core/browser/ai_chat_credential_manager.h" +#include "brave/components/ai_chat/core/browser/ai_chat_database.h" #include "brave/components/ai_chat/core/browser/constants.h" #include "brave/components/ai_chat/core/browser/conversation_handler.h" #include "brave/components/ai_chat/core/browser/model_service.h" @@ -25,11 +30,15 @@ #include "brave/components/ai_chat/core/common/mojom/ai_chat.mojom-shared.h" #include "brave/components/ai_chat/core/common/mojom/ai_chat.mojom.h" #include "brave/components/ai_chat/core/common/pref_names.h" +#include "components/os_crypt/async/browser/os_crypt_async.h" #include "components/prefs/pref_service.h" namespace ai_chat { namespace { +constexpr base::FilePath::StringPieceType kDBFileName = + FILE_PATH_LITERAL("AIChat"); + static const auto kAllowedSchemes = base::MakeFixedFlatSet( {url::kHttpsScheme, url::kHttpScheme, url::kFileScheme, url::kDataScheme}); @@ -40,8 +49,10 @@ AIChatService::AIChatService( std::unique_ptr ai_chat_credential_manager, PrefService* profile_prefs, AIChatMetrics* ai_chat_metrics, + os_crypt_async::OSCryptAsync* os_crypt_async, scoped_refptr url_loader_factory, - std::string_view channel_string) + std::string_view channel_string, + base::FilePath profile_path) : model_service_(model_service), profile_prefs_(profile_prefs), ai_chat_metrics_(ai_chat_metrics), @@ -56,6 +67,15 @@ AIChatService::AIChatService( prefs::kLastAcceptedDisclaimer, base::BindRepeating(&AIChatService::OnUserOptedIn, weak_ptr_factory_.GetWeakPtr())); + + if (features::IsAIChatHistoryEnabled()) { + DVLOG(0) << "Initializing OS Crypt Async"; + encryptor_ready_subscription_ = os_crypt_async->GetInstance( + base::BindOnce(&AIChatService::OnOsCryptAsyncReady, + weak_ptr_factory_.GetWeakPtr(), profile_path)); + // Don't init DB until oscrypt is ready - we don't want to use the DB + // if we can't use encryption. + } } AIChatService::~AIChatService() = default; @@ -73,6 +93,7 @@ void AIChatService::Bind(mojo::PendingReceiver receiver) { void AIChatService::Shutdown() { // Disconnect remotes receivers_.ClearWithReason(0, "Shutting down"); + db_task_runner_.reset(); } ConversationHandler* AIChatService::CreateConversation() { @@ -80,8 +101,12 @@ ConversationHandler* AIChatService::CreateConversation() { std::string conversation_uuid = uuid.AsLowercaseString(); // Create the conversation metadata { - mojom::ConversationPtr conversation = mojom::Conversation::New( - conversation_uuid, "", "", base::Time::Now(), false); + mojom::SiteInfoPtr non_content = mojom::SiteInfo::New( + std::nullopt, mojom::ContentType::PageContent, std::nullopt, + std::nullopt, std::nullopt, 0, false, false); + mojom::ConversationPtr conversation = + mojom::Conversation::New(conversation_uuid, "", base::Time::Now(), + false, std::move(non_content)); conversations_.insert_or_assign(conversation_uuid, std::move(conversation)); } mojom::Conversation* conversation = @@ -110,14 +135,62 @@ ConversationHandler* AIChatService::CreateConversation() { } ConversationHandler* AIChatService::GetConversation( - const std::string& conversation_uuid) { - auto conversation_handler_it = conversation_handlers_.find(conversation_uuid); + std::string_view conversation_uuid) { + auto conversation_handler_it = + conversation_handlers_.find(conversation_uuid.data()); if (conversation_handler_it == conversation_handlers_.end()) { return nullptr; } return conversation_handler_it->second.get(); } +void AIChatService::GetConversation( + std::string_view conversation_uuid, + base::OnceCallback callback) { + if (ConversationHandler* cached_conversation = + GetConversation(conversation_uuid)) { + DVLOG(4) << __func__ << " found cached conversation for " + << conversation_uuid; + std::move(callback).Run(cached_conversation); + return; + } + // Load from database + if (!ai_chat_db_ || !conversations_.contains(conversation_uuid.data())) { + LOG(ERROR) << "Asked for conversation that does not exist: " + << conversation_uuid; + std::move(callback).Run(nullptr); + return; + } + CHECK(ai_chat_db_); + mojom::ConversationPtr& metadata = + conversations_.at(conversation_uuid.data()); + // Get archive content and conversation entries + ai_chat_db_.AsyncCall(&AIChatDatabase::GetConversationData) + .WithArgs(metadata->uuid) + .Then(base::BindOnce(&AIChatService::OnConversationDataReceived, + weak_ptr_factory_.GetWeakPtr(), metadata->uuid, + std::move(callback))); +} + +void AIChatService::OnConversationDataReceived( + std::string conversation_uuid, + base::OnceCallback callback, + mojom::ConversationArchivePtr data) { + DVLOG(4) << __func__ << " for " << conversation_uuid + << " with data: " << data->entries.size() << " entries and " + << data->associated_content.size() << " contents"; + mojom::Conversation* conversation = + conversations_.at(conversation_uuid.data()).get(); + std::unique_ptr conversation_handler = + std::make_unique( + conversation, this, model_service_, credential_manager_.get(), + feedback_api_.get(), url_loader_factory_, std::move(data)); + conversation_observations_.AddObservation(conversation_handler.get()); + conversation_handlers_.insert_or_assign(conversation_uuid.data(), + std::move(conversation_handler)); + std::move(callback).Run(GetConversation(conversation_uuid)); +} + ConversationHandler* AIChatService::GetOrCreateConversationHandlerForContent( int associated_content_id, base::WeakPtr @@ -155,6 +228,88 @@ ConversationHandler* AIChatService::CreateConversationHandlerForContent( return conversation; } +void AIChatService::ClearAllHistory() { + // Delete opt-in state + profile_prefs_->ClearPref(prefs::kLastAcceptedDisclaimer); + + // Delete in-memory data + conversation_observations_.RemoveAllObservations(); + conversation_handlers_.clear(); + conversations_.clear(); + content_conversations_.clear(); + OnConversationListChanged(); + + // Delete database data + if (ai_chat_db_) { + ai_chat_db_.AsyncCall(&AIChatDatabase::DeleteAllData) + .Then(base::BindOnce([](bool success) { + if (!success) { + LOG(ERROR) << "Failed to delete all AIChatService database data"; + } + })); + } + if (ai_chat_metrics_ != nullptr) { + ai_chat_metrics_->RecordReset(); + } +} + +void AIChatService::OnOsCryptAsyncReady(const base::FilePath& storage_dir, + os_crypt_async::Encryptor encryptor, + bool success) { + CHECK(features::IsAIChatHistoryEnabled()); + if (!success) { + LOG(ERROR) << "Failed to initialize AIChat DB due to OSCrypt failure"; + return; + } + + ai_chat_db_ = base::SequenceBound( + GetDBTaskRunner(), storage_dir.Append(kDBFileName), std::move(encryptor)); + + ai_chat_db_.AsyncCall(&AIChatDatabase::IsInitialized) + .Then(base::BindOnce(&AIChatService::OnDBInit, + weak_ptr_factory_.GetWeakPtr())); +} + +void AIChatService::OnDBInit(bool success) { + CHECK(ai_chat_db_); + if (!success) { + LOG(ERROR) << "Failed to initialize AIChat DB"; + // Ensure any pending tasks are cancelled. Sometimes, especially during unit + // tests that are very quick to run, initilization might fail because of + // shutdown but the next task in the sequence still executes, causing a + // crash because of bad database state. + ai_chat_db_.Reset(); + return; + } + DVLOG(0) << "AIChat DB initialized"; + // Load all conversations + ai_chat_db_.AsyncCall(&AIChatDatabase::GetAllConversations) + .Then(base::BindOnce(&AIChatService::OnAllConversationsRetrieved, + weak_ptr_factory_.GetWeakPtr())); +} + +void AIChatService::OnAllConversationsRetrieved( + std::vector conversations) { + for (auto& conversation : conversations) { + DVLOG(2) << "Loaded conversation " << conversation->uuid + << " with details: " << "\n has content: " + << conversation->has_content + << "\n last updated: " << conversation->updated_time + << "\n title: " << conversation->title; + if (conversations_.contains(conversation->uuid)) { + // This can occur during testing when conversations are created and + // persisted before the database is initially queried. We don't want to + // overwrite the existing in-memory items as they are likely + // being referenced by ConversationHandler instances. + continue; + } + conversations_.insert_or_assign(conversation->uuid, + std::move(conversation)); + } + DVLOG(1) << "Loaded " << conversations.size() << " conversations."; + OnConversationListChanged(); +} + void AIChatService::MaybeAssociateContentWithConversation( ConversationHandler* conversation, int associated_content_id, @@ -224,12 +379,10 @@ void AIChatService::DismissPremiumPrompt() { void AIChatService::DeleteConversation(const std::string& id) { ConversationHandler* conversation_handler = conversation_handlers_.at(id).get(); - if (!conversation_handler) { - return; + if (conversation_handlers_.contains(id)) { + conversation_observations_.RemoveObservation(conversation_handler); + conversation_handlers_.erase(id); } - // conversation_handler->OnConversationDeleted(); - conversation_observations_.RemoveObservation(conversation_handler); - conversation_handlers_.erase(id); conversations_.erase(id); DVLOG(1) << "Erased conversation due to deletion request (" << id << "). Now have " << conversations_.size() @@ -237,6 +390,22 @@ void AIChatService::DeleteConversation(const std::string& id) { << conversation_handlers_.size() << " ConversationHandler instances."; OnConversationListChanged(); + // Update database + if (ai_chat_db_) { + ai_chat_db_.AsyncCall(&AIChatDatabase::DeleteConversation) + .WithArgs(id) + .Then(base::BindOnce( + [](std::string conversation_uuid, bool success) { + if (!success) { + LOG(ERROR) << "Failed to delete conversation from database " + << conversation_uuid; + } else { + DVLOG(1) << "Deleted conversation from database: " + << conversation_uuid; + } + }, + std::move(id))); + } } void AIChatService::RenameConversation(const std::string& id, @@ -274,66 +443,165 @@ void AIChatService::OnPremiumStatusReceived(GetPremiumStatusCallback callback, void AIChatService::MaybeEraseConversation( ConversationHandler* conversation_handler) { if (!conversation_handler->IsAnyClientConnected() && - (!features::IsAIChatHistoryEnabled() || - !conversation_handler->HasAnyHistory())) { - // Can erase because no active UI and no history, so it's - // not a real / persistable conversation + !conversation_handler->IsRequestInProgress()) { + // Can erase handler because no active UI + bool has_history = conversation_handler->HasAnyHistory(); auto uuid = conversation_handler->get_conversation_uuid(); conversation_observations_.RemoveObservation(conversation_handler); conversation_handlers_.erase(uuid); - conversations_.erase(uuid); - std::erase_if(content_conversations_, - [&uuid](const auto& kv) { return kv.second == uuid; }); - DVLOG(1) << "Erased conversation (" << uuid << "). Now have " + DVLOG(1) << "Unloaded conversation (" << uuid << ") from memory. Now have " << conversations_.size() << " Conversation metadata items and " << conversation_handlers_.size() << " ConversationHandler instances."; - OnConversationListChanged(); + if (!features::IsAIChatHistoryEnabled() || !has_history) { + // Can erase because no active UI and no history, so it's + // not a real / persistable conversation + conversations_.erase(uuid); + std::erase_if(content_conversations_, + [&uuid](const auto& kv) { return kv.second == uuid; }); + DVLOG(1) << "Erased conversation (" << uuid << "). Now have " + << conversations_.size() << " Conversation metadata items and " + << conversation_handlers_.size() + << " ConversationHandler instances."; + OnConversationListChanged(); + } + } else { + DVLOG(4) << "Not unloading conversation (" + << conversation_handler->get_conversation_uuid() + << ") from memory. Has active clients: " + << (conversation_handler->IsAnyClientConnected() ? "yes" : "no") + << " Request is in progress: " + << (conversation_handler->IsRequestInProgress() ? "yes" : "no"); + } +} + +base::SequencedTaskRunner* AIChatService::GetDBTaskRunner() { + if (!db_task_runner_) { + db_task_runner_ = base::ThreadPool::CreateSequencedTaskRunner( + {base::MayBlock(), base::WithBaseSyncPrimitives(), + base::TaskPriority::BEST_EFFORT, + base::TaskShutdownBehavior::BLOCK_SHUTDOWN}); + } + + return db_task_runner_.get(); +} + +void AIChatService::OnRequestInProgressChanged(ConversationHandler* handler, + bool in_progress) { + // We don't unload a conversation if it has a request in progress, so check + // again when that changes. + if (!in_progress) { + MaybeEraseConversation(handler); } } -void AIChatService::OnConversationEntriesChanged( +void AIChatService::OnConversationEntryAdded( ConversationHandler* handler, - std::vector entries) { + mojom::ConversationTurnPtr& entry, + std::optional associated_content_value) { auto conversation_it = conversations_.find(handler->get_conversation_uuid()); CHECK(conversation_it != conversations_.end()); - auto& conversation = conversation_it->second; - if (!entries.empty()) { - bool notify = false; - // First time a new entry has appeared - if (!conversation->has_content && entries.size() == 1) { - // This conversation is visible once the first response begins - conversation->has_content = true; - notify = true; - if (ai_chat_metrics_ != nullptr) { - if (entries.size() == 1) { - ai_chat_metrics_->RecordNewChat(); - } - // Each time the user submits an entry - // TODO(petemill): Verify this won't get called multiple times for the - // same entry. - if (entries.back()->character_type == mojom::CharacterType::HUMAN) { - ai_chat_metrics_->RecordNewPrompt(); - } - if (entries.size() >= 2) { - if (conversation->summary.size() < 70) { - for (const auto& entry : entries) { - if (entry->character_type == mojom::CharacterType::ASSISTANT && - !entry->text.empty()) { - conversation->summary = entry->text.substr(0, 70); - notify = true; - break; - } - } - } - } - if (notify) { - OnConversationListChanged(); + mojom::ConversationPtr& conversation = conversation_it->second; + + auto& history = handler->GetConversationHistory(); + + if (!conversation->has_content && history.size() > 0) { + HandleFirstEntry(handler, entry, associated_content_value, conversation); + if (ai_chat_metrics_ != nullptr) { + if (handler->GetConversationHistory().size() == 1) { + ai_chat_metrics_->RecordNewChat(); + } + if (entry->character_type == mojom::CharacterType::HUMAN) { + ai_chat_metrics_->RecordNewPrompt(); + } + } + } else { + HandleNewEntry(handler, entry, associated_content_value, conversation); + } + + conversation->has_content = true; + conversation->updated_time = entry->created_time; + OnConversationListChanged(); +} + +void AIChatService::HandleFirstEntry( + ConversationHandler* handler, + mojom::ConversationTurnPtr& entry, + std::optional associated_content_value, + mojom::ConversationPtr& conversation) { + DVLOG(1) << __func__ << " Conversation " << conversation->uuid + << " being persisted for first time."; + CHECK(entry->uuid.has_value()); + // We can persist the conversation metadata for the first time as well as the + // entry. + auto on_conversation_added = base::BindOnce( + [](mojom::Conversation* metadata, + AIChatDatabase::AddConversationResult db_result) { + if (metadata && db_result.associated_content_id != -1) { + metadata->associated_content->id = db_result.associated_content_id; } + }, + conversation.get()); + if (ai_chat_db_) { + ai_chat_db_.AsyncCall(&AIChatDatabase::AddConversation) + .WithArgs(conversation->Clone(), + std::optional(associated_content_value), + entry->Clone()) + .Then(std::move(on_conversation_added)); + } +} + +void AIChatService::HandleNewEntry( + ConversationHandler* handler, + mojom::ConversationTurnPtr& entry, + std::optional associated_content_value, + mojom::ConversationPtr& conversation) { + CHECK(entry->uuid.has_value()); + DVLOG(1) << __func__ << " Conversation " << conversation->uuid + << " persisting new entry. Count of entries: " + << handler->GetConversationHistory().size(); + + // Persist the new entry and update the associated content data, if present + if (ai_chat_db_) { + auto on_entry_added = base::BindOnce([](bool success) { + if (!success) { + DVLOG(0) << "Failed to add conversation entry"; } + }); + ai_chat_db_.AsyncCall(&AIChatDatabase::AddConversationEntry) + .WithArgs(handler->get_conversation_uuid(), entry.Clone(), std::nullopt) + .Then(std::move(on_entry_added)); + + if (associated_content_value.has_value() && + conversation->associated_content->is_content_association_possible) { + auto on_content_updated = base::BindOnce( + [](mojom::Conversation* metadata, int32_t content_id) { + if (metadata && content_id != -1) { + metadata->associated_content->id = content_id; + } + }, + conversation.get()); + + ai_chat_db_.AsyncCall(&AIChatDatabase::AddOrUpdateAssociatedContent) + .WithArgs(conversation->uuid, + conversation->associated_content->Clone(), + std::optional(associated_content_value)) + .Then(std::move(on_content_updated)); } - // TODO(petemill): Persist the entries, but consider receiving finer grained - // entry update events. + } +} + +void AIChatService::OnConversationEntryRemoved(ConversationHandler* handler, + std::string entry_uuid) { + // Persist the removal + if (ai_chat_db_) { + ai_chat_db_.AsyncCall(&AIChatDatabase::DeleteConversationEntry) + .WithArgs(entry_uuid) + .Then(base::BindOnce([](bool success) { + if (!success) { + DVLOG(0) << "Failed to delete conversation entry"; + } + })); } } @@ -349,7 +617,19 @@ void AIChatService::OnConversationTitleChanged(ConversationHandler* handler, CHECK(conversation_it != conversations_.end()); auto& conversation = conversation_it->second; conversation->title = title; + OnConversationListChanged(); + + // Persist the change + if (ai_chat_db_) { + ai_chat_db_.AsyncCall(&AIChatDatabase::UpdateConversationTitle) + .WithArgs(handler->get_conversation_uuid(), std::move(title)) + .Then(base::BindOnce([](bool success) { + if (!success) { + DVLOG(0) << "Failed to update conversation title"; + } + })); + } } void AIChatService::GetVisibleConversations( @@ -365,12 +645,20 @@ void AIChatService::BindConversation( const std::string& uuid, mojo::PendingReceiver receiver, mojo::PendingRemote conversation_ui_handler) { - ConversationHandler* conversation = GetConversation(uuid); - if (!conversation) { - return; - } - CHECK(conversation) << "Asked to bind a conversation which doesn't exist"; - conversation->Bind(std::move(receiver), std::move(conversation_ui_handler)); + GetConversation( + std::move(uuid), + base::BindOnce( + [](mojo::PendingReceiver receiver, + mojo::PendingRemote conversation_ui_handler, + ConversationHandler* handler) { + if (!handler) { + DVLOG(0) << "Failed to get conversation for binding"; + return; + } + handler->Bind(std::move(receiver), + std::move(conversation_ui_handler)); + }, + std::move(receiver), std::move(conversation_ui_handler))); } void AIChatService::BindObserver( @@ -423,7 +711,7 @@ std::vector AIChatService::FilterVisibleConversations() { conversations.push_back(conversation.get()); } base::ranges::sort(conversations, std::greater<>(), - &mojom::Conversation::created_time); + &mojom::Conversation::updated_time); return conversations; } diff --git a/components/ai_chat/core/browser/ai_chat_service.h b/components/ai_chat/core/browser/ai_chat_service.h index fb4d16e80d31..cd9b0c7ed155 100644 --- a/components/ai_chat/core/browser/ai_chat_service.h +++ b/components/ai_chat/core/browser/ai_chat_service.h @@ -11,10 +11,15 @@ #include #include +#include "base/callback_list.h" +#include "base/functional/callback_forward.h" #include "base/memory/scoped_refptr.h" #include "base/memory/weak_ptr.h" #include "base/scoped_multi_source_observation.h" +#include "base/task/sequenced_task_runner.h" +#include "base/threading/sequence_bound.h" #include "brave/components/ai_chat/core/browser/ai_chat_credential_manager.h" +#include "brave/components/ai_chat/core/browser/ai_chat_database.h" #include "brave/components/ai_chat/core/browser/ai_chat_feedback_api.h" #include "brave/components/ai_chat/core/browser/ai_chat_metrics.h" #include "brave/components/ai_chat/core/browser/conversation_handler.h" @@ -28,6 +33,11 @@ #include "mojo/public/cpp/bindings/receiver_set.h" #include "services/network/public/cpp/shared_url_loader_factory.h" +namespace os_crypt_async { +class Encryptor; +class OSCryptAsync; +} // namespace os_crypt_async + namespace ai_chat { class ModelService; @@ -45,8 +55,10 @@ class AIChatService : public KeyedService, std::unique_ptr ai_chat_credential_manager, PrefService* profile_prefs, AIChatMetrics* ai_chat_metrics, + os_crypt_async::OSCryptAsync* os_crypt_async, scoped_refptr url_loader_factory, - std::string_view channel_string); + std::string_view channel_string, + base::FilePath profile_path); ~AIChatService() override; AIChatService(const AIChatService&) = delete; @@ -59,9 +71,14 @@ class AIChatService : public KeyedService, void Shutdown() override; // ConversationHandler::Observer - void OnConversationEntriesChanged( + void OnRequestInProgressChanged(ConversationHandler* handler, + bool in_progress) override; + void OnConversationEntryAdded( ConversationHandler* handler, - std::vector entries) override; + mojom::ConversationTurnPtr& entry, + std::optional associated_content_value) override; + void OnConversationEntryRemoved(ConversationHandler* handler, + std::string entry_uuid) override; void OnClientConnectionChanged(ConversationHandler* handler) override; void OnConversationTitleChanged(ConversationHandler* handler, std::string title) override; @@ -69,7 +86,9 @@ class AIChatService : public KeyedService, // Adds new conversation and returns the handler ConversationHandler* CreateConversation(); - ConversationHandler* GetConversation(const std::string& uuid); + ConversationHandler* GetConversation(std::string_view uuid); + void GetConversation(std::string_view conversation_uuid, + base::OnceCallback); // Creates and owns a ConversationHandler if one hasn't been made for the // associated_content_id yet. |associated_content_id| should not be stored. It @@ -87,6 +106,9 @@ class AIChatService : public KeyedService, base::WeakPtr associated_content); + // Removes all in-memory and persisted data for all conversations + void ClearAllHistory(); + // mojom::Service void MarkAgreementAccepted() override; void GetVisibleConversations( @@ -120,11 +142,35 @@ class AIChatService : public KeyedService, size_t GetInMemoryConversationCountForTesting(); private: + // Called when the encryptor is ready. + void OnOsCryptAsyncReady(const base::FilePath& storage_dir, + os_crypt_async::Encryptor encryptor, + bool success); + + void OnDBInit(bool success); + void OnAllConversationsRetrieved( + std::vector conversations); + void OnConversationDataReceived( + std::string conversation_uuid, + base::OnceCallback callback, + mojom::ConversationArchivePtr data); + void MaybeAssociateContentWithConversation( ConversationHandler* conversation, int associated_content_id, base::WeakPtr associated_content); + void MaybeEraseConversation(ConversationHandler* conversation); + void HandleFirstEntry( + ConversationHandler* handler, + mojom::ConversationTurnPtr& entry, + std::optional associated_content_value, + mojom::ConversationPtr& conversation); + void HandleNewEntry(ConversationHandler* handler, + mojom::ConversationTurnPtr& entry, + std::optional associated_content_value, + mojom::ConversationPtr& conversation); + void OnUserOptedIn(); void OnSkusServiceReceived( SkusServiceGetter getter, @@ -134,7 +180,8 @@ class AIChatService : public KeyedService, void OnPremiumStatusReceived(GetPremiumStatusCallback callback, mojom::PremiumStatus status, mojom::PremiumInfoPtr info); - void MaybeEraseConversation(ConversationHandler* conversation); + + base::SequencedTaskRunner* GetDBTaskRunner(); raw_ptr model_service_; raw_ptr profile_prefs_; @@ -145,6 +192,9 @@ class AIChatService : public KeyedService, std::unique_ptr feedback_api_; std::unique_ptr credential_manager_; + base::SequenceBound ai_chat_db_; + scoped_refptr db_task_runner_; + // All conversation metadata. Mainly just titles and uuids. Key is uuid std::map conversations_; @@ -173,6 +223,9 @@ class AIChatService : public KeyedService, // often (whenever UI is focused). mojom::PremiumStatus last_premium_status_ = mojom::PremiumStatus::Unknown; + // Maintains the subscription for `OSCryptAsync` and cancels upon destruction. + base::CallbackListSubscription encryptor_ready_subscription_; + base::WeakPtrFactory weak_ptr_factory_{this}; }; diff --git a/components/ai_chat/core/browser/ai_chat_service_unittest.cc b/components/ai_chat/core/browser/ai_chat_service_unittest.cc index 0f6a9fc0d3bf..f958219e42e7 100644 --- a/components/ai_chat/core/browser/ai_chat_service_unittest.cc +++ b/components/ai_chat/core/browser/ai_chat_service_unittest.cc @@ -13,6 +13,7 @@ #include #include +#include "base/files/scoped_temp_dir.h" #include "base/functional/bind.h" #include "base/functional/callback.h" #include "base/functional/overloaded.h" @@ -30,10 +31,14 @@ #include "base/time/time.h" #include "brave/components/ai_chat/core/browser/ai_chat_credential_manager.h" #include "brave/components/ai_chat/core/browser/conversation_handler.h" +#include "brave/components/ai_chat/core/browser/mock_conversation_handler_observer.h" +#include "brave/components/ai_chat/core/browser/test_utils.h" #include "brave/components/ai_chat/core/browser/utils.h" #include "brave/components/ai_chat/core/common/features.h" #include "brave/components/ai_chat/core/common/mojom/ai_chat.mojom.h" #include "brave/components/ai_chat/core/common/pref_names.h" +#include "components/os_crypt/async/browser/os_crypt_async.h" +#include "components/os_crypt/async/browser/test_utils.h" #include "components/sync_preferences/testing_pref_service_syncable.h" #include "mojo/public/cpp/bindings/receiver.h" #include "services/data_decoder/public/cpp/test_support/in_process_data_decoder.h" @@ -54,40 +59,6 @@ namespace ai_chat { namespace { -std::vector CreateSampleHistory() { - auto created_time1 = base::Time::Now(); - std::vector history; - history.push_back(mojom::ConversationTurn::New( - mojom::CharacterType::HUMAN, mojom::ActionType::QUERY, - mojom::ConversationTurnVisibility::VISIBLE, "prompt1", std::nullopt, - std::nullopt, created_time1, std::nullopt, false)); - history.push_back(mojom::ConversationTurn::New( - mojom::CharacterType::ASSISTANT, mojom::ActionType::RESPONSE, - mojom::ConversationTurnVisibility::VISIBLE, "answer1", std::nullopt, - std::nullopt, base::Time::Now(), std::nullopt, false)); - return history; -} - -class MockConversationHandlerObserver : public ConversationHandler::Observer { - public: - MockConversationHandlerObserver() = default; - ~MockConversationHandlerObserver() override = default; - - void Observe(ConversationHandler* conversation) { - conversation_observations_.AddObservation(conversation); - } - - MOCK_METHOD(void, - OnClientConnectionChanged, - (ConversationHandler*), - (override)); - - private: - base::ScopedMultiSourceObservation - conversation_observations_{this}; -}; - class MockAIChatCredentialManager : public AIChatCredentialManager { public: using AIChatCredentialManager::AIChatCredentialManager; @@ -169,6 +140,8 @@ class MockConversationHandlerClient : public mojom::ConversationUI { MOCK_METHOD(void, OnFaviconImageDataChanged, (), (override)); + MOCK_METHOD(void, OnConversationDeleted, (), (override)); + private: mojo::Receiver conversation_ui_receiver_{this}; mojo::Remote conversation_handler_remote_; @@ -221,13 +194,39 @@ class AIChatServiceUnitTest : public testing::Test, } void SetUp() override { + CHECK(temp_directory_.CreateUniqueTempDir()); + DVLOG(0) << "Temp directory: " << temp_directory_.GetPath().value(); prefs::RegisterProfilePrefs(prefs_.registry()); prefs::RegisterLocalStatePrefs(local_state_.registry()); ModelService::RegisterProfilePrefs(prefs_.registry()); + os_crypt_ = os_crypt_async::GetTestOSCryptAsyncForTesting( + /*is_sync_for_unittests=*/true); + shared_url_loader_factory_ = base::MakeRefCounted( &url_loader_factory_); + + model_service_ = std::make_unique(&prefs_); + + CreateService(); + + if (is_opted_in_) { + EmulateUserOptedIn(); + } else { + EmulateUserOptedOut(); + } + } + + void TearDown() override { + ai_chat_service_.reset(); + // Allow handles on the db to be released, otherwise for very quick + // tests, we get crashes on temp_directory_.Delete(). + task_environment_.RunUntilIdle(); + CHECK(temp_directory_.Delete()); + } + + void CreateService() { std::unique_ptr credential_manager = std::make_unique(base::NullCallback(), &local_state_); @@ -239,23 +238,23 @@ class AIChatServiceUnitTest : public testing::Test, std::move(premium_info)); }); - model_service_ = std::make_unique(&prefs_); - ai_chat_service_ = std::make_unique( model_service_.get(), std::move(credential_manager), &prefs_, nullptr, - shared_url_loader_factory_, ""); + os_crypt_.get(), shared_url_loader_factory_, "", + temp_directory_.GetPath()); client_ = std::make_unique>(ai_chat_service_.get()); + } - if (is_opted_in_) { - EmulateUserOptedIn(); - } else { - EmulateUserOptedOut(); - } + void ResetService() { + ai_chat_service_.reset(); + task_environment_.RunUntilIdle(); + CreateService(); } - void ExpectVisibleConversationsSize(size_t size) { + void ExpectVisibleConversationsSize(base::Location location, size_t size) { + SCOPED_TRACE(testing::Message() << location.ToString()); base::RunLoop run_loop; client_->service_remote()->GetVisibleConversations( base::BindLambdaForTesting( @@ -302,8 +301,6 @@ class AIChatServiceUnitTest : public testing::Test, void EmulateUserOptedOut() { ::ai_chat::SetUserOptedIn(&prefs_, false); } - void TearDown() override { ai_chat_service_.reset(); } - protected: base::test::TaskEnvironment task_environment_; std::unique_ptr ai_chat_service_; @@ -311,6 +308,7 @@ class AIChatServiceUnitTest : public testing::Test, std::unique_ptr> client_; sync_preferences::TestingPrefServiceSyncable prefs_; sync_preferences::TestingPrefServiceSyncable local_state_; + std::unique_ptr os_crypt_; network::TestURLLoaderFactory url_loader_factory_; scoped_refptr shared_url_loader_factory_; data_decoder::test::InProcessDataDecoder in_process_data_decoder_; @@ -318,6 +316,7 @@ class AIChatServiceUnitTest : public testing::Test, private: base::test::ScopedFeatureList scoped_feature_list_; + base::ScopedTempDir temp_directory_; }; INSTANTIATE_TEST_SUITE_P( @@ -339,7 +338,7 @@ TEST_P(AIChatServiceUnitTest, ConversationLifecycle_NoMessages) { ConversationHandler* conversation_handler2 = CreateConversation(); // Shouldn't "display" any conversations without messages - ExpectVisibleConversationsSize(0); + ExpectVisibleConversationsSize(FROM_HERE, 0); // Before connecting any clients to the conversations, none should be deleted EXPECT_EQ(ai_chat_service_->GetInMemoryConversationCountForTesting(), 2u); @@ -369,12 +368,12 @@ TEST_P(AIChatServiceUnitTest, ConversationLifecycle_WithMessages) { .Times(testing::AtLeast(1)); ConversationHandler* conversation_handler1 = CreateConversation(); - conversation_handler1->SetChatHistoryForTesting(CreateSampleHistory()); + conversation_handler1->SetChatHistoryForTesting(CreateSampleChatHistory(1u)); ConversationHandler* conversation_handler2 = CreateConversation(); - conversation_handler2->SetChatHistoryForTesting(CreateSampleHistory()); + conversation_handler2->SetChatHistoryForTesting(CreateSampleChatHistory(1u)); - ExpectVisibleConversationsSize(2u); + ExpectVisibleConversationsSize(FROM_HERE, 2u); // Before connecting any clients to the conversations, none should be deleted EXPECT_EQ(ai_chat_service_->GetInMemoryConversationCountForTesting(), 2u); @@ -382,19 +381,18 @@ TEST_P(AIChatServiceUnitTest, ConversationLifecycle_WithMessages) { // Connect a client then disconnect auto client1 = CreateConversationClient(conversation_handler1); DisconnectConversationClient(client1.get()); - // Only 1 should be deleted, or none if we're preserving history - EXPECT_EQ(ai_chat_service_->GetInMemoryConversationCountForTesting(), - IsAIChatHistoryEnabled() ? 2u : 1u); + // Only 1 should be deleted, whether we preserve history or not (is preserved + // in the database). + EXPECT_EQ(ai_chat_service_->GetInMemoryConversationCountForTesting(), 1u); - ExpectVisibleConversationsSize(IsAIChatHistoryEnabled() ? 2u : 1u); + ExpectVisibleConversationsSize(FROM_HERE, IsAIChatHistoryEnabled() ? 2u : 1u); // Connect a client then disconnect auto client2 = CreateConversationClient(conversation_handler2); DisconnectConversationClient(client2.get()); - EXPECT_EQ(ai_chat_service_->GetInMemoryConversationCountForTesting(), - IsAIChatHistoryEnabled() ? 2u : 0u); + EXPECT_EQ(ai_chat_service_->GetInMemoryConversationCountForTesting(), 0u); - ExpectVisibleConversationsSize(IsAIChatHistoryEnabled() ? 2u : 0u); + ExpectVisibleConversationsSize(FROM_HERE, IsAIChatHistoryEnabled() ? 2u : 0u); testing::Mock::VerifyAndClearExpectations(client_.get()); task_environment_.RunUntilIdle(); @@ -492,4 +490,46 @@ TEST_P(AIChatServiceUnitTest, run_loop.Run(); } +TEST_P(AIChatServiceUnitTest, GetConversation_AfterRestart) { + auto history = CreateSampleChatHistory(1u); + std::string uuid; + { + ConversationHandler* conversation_handler = CreateConversation(); + uuid = conversation_handler->get_conversation_uuid(); + auto client = CreateConversationClient(conversation_handler); + conversation_handler->SetChatHistoryForTesting(CloneHistory(history)); + ExpectVisibleConversationsSize(FROM_HERE, 1); + DisconnectConversationClient(client.get()); + } + ExpectVisibleConversationsSize(FROM_HERE, IsAIChatHistoryEnabled() ? 1 : 0); + + // Allow entries to finish being persisted before restarting service + task_environment_.RunUntilIdle(); + DVLOG(0) << "Restarting service"; + ResetService(); + + if (IsAIChatHistoryEnabled()) { + EXPECT_CALL(*client_, OnConversationListChanged(testing::SizeIs(1))) + .Times(testing::AtLeast(1)); + } else { + EXPECT_CALL(*client_, OnConversationListChanged).Times(0); + } + task_environment_.RunUntilIdle(); + // Can get conversation data + if (IsAIChatHistoryEnabled()) { + base::RunLoop run_loop; + ai_chat_service_->GetConversation( + uuid, base::BindLambdaForTesting( + [&](ConversationHandler* conversation_handler) { + EXPECT_TRUE(conversation_handler); + ExpectConversationHistoryEquals( + FROM_HERE, + conversation_handler->GetConversationHistory(), + history); + run_loop.Quit(); + })); + run_loop.Run(); + } +} + } // namespace ai_chat diff --git a/components/ai_chat/core/browser/ai_chat_storage_service.cc b/components/ai_chat/core/browser/ai_chat_storage_service.cc deleted file mode 100644 index c213e1e2a37c..000000000000 --- a/components/ai_chat/core/browser/ai_chat_storage_service.cc +++ /dev/null @@ -1,87 +0,0 @@ -/* Copyright (c) 2023 The Brave Authors. All rights reserved. - * This Source Code Form is subject to the terms of the Mozilla Public - * License, v. 2.0. If a copy of the MPL was not distributed with this file, - * You can obtain one at https://mozilla.org/MPL/2.0/. */ - -#include "brave/components/ai_chat/core/browser/ai_chat_storage_service.h" - -#include -#include - -#include "base/task/thread_pool.h" -#include "content/public/browser/browser_context.h" - -namespace ai_chat { -namespace { -constexpr base::FilePath::StringPieceType kBaseDirName = - FILE_PATH_LITERAL("AIChat"); -} // namespace - -AIChatStorageService::AIChatStorageService(content::BrowserContext* context) - : base_dir_(context->GetPath().Append(kBaseDirName)) { - ai_chat_db_ = base::SequenceBound(GetTaskRunner()); - - auto on_response = base::BindOnce( - [](bool success) { DVLOG(1) << "AIChatDB Init: " << success; }); - - ai_chat_db_.AsyncCall(&AIChatDatabase::Init) - .WithArgs(base_dir_) - .Then(std::move(on_response)); -} - -AIChatStorageService::~AIChatStorageService() = default; - -base::SequencedTaskRunner* AIChatStorageService::GetTaskRunner() { - if (!task_runner_) { - task_runner_ = base::ThreadPool::CreateSequencedTaskRunner( - {base::MayBlock(), base::WithBaseSyncPrimitives(), - base::TaskPriority::BEST_EFFORT, - base::TaskShutdownBehavior::BLOCK_SHUTDOWN}); - } - - return task_runner_.get(); -} - -void AIChatStorageService::SyncConversation(mojom::ConversationPtr conversation, - ConversationCallback callback) { - auto on_added = [](mojom::ConversationPtr conversation, - ConversationCallback callback, int64_t id) { - conversation->id = id; - std::move(callback).Run(std::move(conversation)); - }; - - ai_chat_db_.AsyncCall(&AIChatDatabase::AddConversation) - .WithArgs(conversation.Clone()) - .Then(base::BindOnce(on_added, std::move(conversation), - std::move(callback))); -} - -void AIChatStorageService::SyncConversationTurn(int64_t conversation_id, - mojom::ConversationTurnPtr turn) { - NOTIMPLEMENTED(); -} - -void AIChatStorageService::GetConversationForGURL(const GURL& gurl, - ConversationCallback callback) { - auto on_get = [](const GURL& gurl, ConversationCallback callback, - std::vector conversations) { - auto it = std::find_if(conversations.begin(), conversations.end(), - [&gurl](const auto& conversation) { - return conversation->page_url == gurl.spec(); - }); - if (it != conversations.end()) { - std::move(callback).Run(std::move(*it)); - } else { - std::move(callback).Run(std::nullopt); - } - }; - - ai_chat_db_.AsyncCall(&AIChatDatabase::GetAllConversations) - .Then(base::BindOnce(on_get, gurl, std::move(callback))); -} - -void AIChatStorageService::Shutdown() { - task_runner_.reset(); -} - -} // namespace ai_chat diff --git a/components/ai_chat/core/browser/ai_chat_storage_service.h b/components/ai_chat/core/browser/ai_chat_storage_service.h deleted file mode 100644 index 28e6ca464665..000000000000 --- a/components/ai_chat/core/browser/ai_chat_storage_service.h +++ /dev/null @@ -1,57 +0,0 @@ -/* Copyright (c) 2023 The Brave Authors. All rights reserved. - * This Source Code Form is subject to the terms of the Mozilla Public - * License, v. 2.0. If a copy of the MPL was not distributed with this file, - * You can obtain one at https://mozilla.org/MPL/2.0/. */ - -#ifndef BRAVE_COMPONENTS_AI_CHAT_CORE_BROWSER_AI_CHAT_STORAGE_SERVICE_H_ -#define BRAVE_COMPONENTS_AI_CHAT_CORE_BROWSER_AI_CHAT_STORAGE_SERVICE_H_ - -#include "base/files/file_path.h" -#include "base/memory/scoped_refptr.h" -#include "base/memory/weak_ptr.h" -#include "base/threading/sequence_bound.h" -#include "brave/components/ai_chat/core/browser/ai_chat_database.h" -#include "components/keyed_service/core/keyed_service.h" - -namespace base { -class SequencedTaskRunner; -} // namespace base - -namespace content { -class BrowserContext; -} // namespace content - -namespace ai_chat { -using ConversationCallback = base::OnceCallback conversation)>; -class AIChatStorageService : public KeyedService { - public: - explicit AIChatStorageService(content::BrowserContext* context); - ~AIChatStorageService() override; - AIChatStorageService(const AIChatStorageService&) = delete; - AIChatStorageService& operator=(const AIChatStorageService&) = delete; - - void SyncConversation(mojom::ConversationPtr conversation, - ConversationCallback callback); - void SyncConversationTurn(int64_t conversation_id, - mojom::ConversationTurnPtr turn); - void GetConversationForGURL(const GURL& gurl, ConversationCallback callback); - - private: - base::SequencedTaskRunner* GetTaskRunner(); - - // KeyedService overrides: - void Shutdown() override; - - const base::FilePath base_dir_; - - scoped_refptr task_runner_; - - base::SequenceBound ai_chat_db_; - - base::WeakPtrFactory weak_ptr_factory_{this}; -}; - -} // namespace ai_chat - -#endif // BRAVE_COMPONENTS_AI_CHAT_CORE_BROWSER_AI_CHAT_STORAGE_SERVICE_H_ diff --git a/components/ai_chat/core/browser/conversation_handler.cc b/components/ai_chat/core/browser/conversation_handler.cc index 27e78a5bd7f5..bd3c3cf6b93f 100644 --- a/components/ai_chat/core/browser/conversation_handler.cc +++ b/components/ai_chat/core/browser/conversation_handler.cc @@ -10,12 +10,15 @@ #include #include +#include "base/containers/flat_tree.h" +#include "base/containers/span.h" #include "base/files/file_path.h" #include "base/memory/weak_ptr.h" #include "base/strings/utf_string_conversions.h" #include "base/task/sequenced_task_runner.h" #include "base/task/thread_pool.h" #include "base/types/expected.h" +#include "base/uuid.h" #include "brave/components/ai_chat/core/browser/ai_chat_credential_manager.h" #include "brave/components/ai_chat/core/browser/ai_chat_feedback_api.h" #include "brave/components/ai_chat/core/browser/ai_chat_service.h" @@ -134,12 +137,28 @@ void AssociatedContentDelegate::OnTextEmbedderInitialized(bool initialized) { } ConversationHandler::ConversationHandler( - const mojom::Conversation* conversation, + mojom::Conversation* conversation, AIChatService* ai_chat_service, ModelService* model_service, AIChatCredentialManager* credential_manager, AIChatFeedbackAPI* feedback_api, scoped_refptr url_loader_factory) + : ConversationHandler(conversation, + ai_chat_service, + model_service, + credential_manager, + feedback_api, + url_loader_factory, + std::nullopt) {} + +ConversationHandler::ConversationHandler( + mojom::Conversation* conversation, + AIChatService* ai_chat_service, + ModelService* model_service, + AIChatCredentialManager* credential_manager, + AIChatFeedbackAPI* feedback_api, + scoped_refptr url_loader_factory, + std::optional initial_state) : metadata_(conversation), ai_chat_service_(ai_chat_service), model_service_(model_service), @@ -156,6 +175,25 @@ ConversationHandler::ConversationHandler( models_observer_.Observe(model_service_.get()); // TODO(petemill): differ based on premium status, if different ChangeModel(model_service->GetDefaultModelKey()); + + if (initial_state.has_value() && !initial_state.value()->entries.empty()) { + // We only support single associated content for now + mojom::ConversationArchivePtr conversation_data = + std::move(initial_state.value()); + if (!conversation_data->associated_content.empty()) { + CHECK(metadata_->associated_content->id.has_value()); + CHECK_EQ(conversation_data->associated_content[0]->content_id, + metadata_->associated_content->id.value()); + bool is_video = (metadata_->associated_content->content_type == + mojom::ContentType::VideoTranscript); + SetArchiveContent(conversation_data->associated_content[0]->content, + is_video); + } + DVLOG(1) << "Restoring associated content for conversation " + << metadata_->uuid << " with " + << conversation_data->entries.size(); + chat_history_ = std::move(conversation_data->entries); + } } ConversationHandler::~ConversationHandler() { @@ -203,6 +241,10 @@ bool ConversationHandler::HasAnyHistory() { }); } +bool ConversationHandler::IsRequestInProgress() { + return is_request_in_progress_; +} + void ConversationHandler::OnConversationDeleted() { for (auto& client : conversation_ui_handlers_) { client->OnConversationDeleted(); @@ -240,9 +282,7 @@ void ConversationHandler::InitEngine() { if (is_request_in_progress_) { // Pending requests have been deleted along with the model engine is_request_in_progress_ = false; - for (auto& client : conversation_ui_handlers_) { - client->OnAPIRequestInProgress(is_request_in_progress_); - } + OnAPIRequestInProgressChanged(); } // When the model changes, the content truncation might be different, @@ -262,22 +302,31 @@ void ConversationHandler::OnAssociatedContentDestroyed( // using our current cached content. associated_content_delegate_ = nullptr; if (!chat_history_.empty() && should_send_page_contents_ && - associated_content_info_ && associated_content_info_->url.has_value()) { + metadata_->associated_content && + metadata_->associated_content->url.has_value()) { // Get the latest version of article text and // associated_content_info_ if this chat has history and was connected to - // the associated conversation, then construct a "content archive" - // implementation of AssociatedContentDelegate with a duplicate of the - // article text. - auto archive_content = std::make_unique( - associated_content_info_->url.value_or(GURL()), last_text_content, - base::UTF8ToUTF16(associated_content_info_->title.value_or("")), - is_video); - associated_content_delegate_ = archive_content->GetWeakPtr(); - archive_content_ = std::move(archive_content); + // the associated conversation, then store the content so the conversation + // can continue. + SetArchiveContent(std::move(last_text_content), is_video); } OnAssociatedContentInfoChanged(); } +void ConversationHandler::SetArchiveContent(std::string text_content, + bool is_video) { + // Construct a "content archive" implementation of AssociatedContentDelegate + // with a duplicate of the article text. + auto archive_content = std::make_unique( + metadata_->associated_content->url.value_or(GURL()), + std::move(text_content), + base::UTF8ToUTF16(metadata_->associated_content->title.value_or("")), + is_video); + associated_content_delegate_ = archive_content->GetWeakPtr(); + archive_content_ = std::move(archive_content); + should_send_page_contents_ = true; +} + void ConversationHandler::SetAssociatedContentDelegate( base::WeakPtr delegate) { // If this conversation is allowed to fetch content, this is the delegate @@ -355,7 +404,7 @@ void ConversationHandler::GetState(GetStateCallback callback) { mojom::ConversationStatePtr state = mojom::ConversationState::New( metadata_->uuid, is_request_in_progress_, std::move(models_copy), model_key, suggestions_, suggestion_generation_status_, - associated_content_info_->Clone(), should_send_page_contents_, + metadata_->associated_content->Clone(), should_send_page_contents_, current_error_); std::move(callback).Run(std::move(state)); @@ -468,7 +517,7 @@ void ConversationHandler::SubmitHumanConversationEntry( << "than a single human conversation turn at a time."; mojom::ConversationTurnPtr turn = mojom::ConversationTurn::New( - CharacterType::HUMAN, mojom::ActionType::UNSPECIFIED, + std::nullopt, CharacterType::HUMAN, mojom::ActionType::UNSPECIFIED, mojom::ConversationTurnVisibility::VISIBLE, input, std::nullopt, std::nullopt, base::Time::Now(), std::nullopt, false); SubmitHumanConversationEntry(std::move(turn)); @@ -612,6 +661,7 @@ void ConversationHandler::ModifyConversation(uint32_t turn_index, } auto edited_turn = mojom::ConversationTurn::New( + base::Uuid::GenerateRandomV4().AsLowercaseString(), turn->character_type, turn->action_type, turn->visibility, trimmed_input, std::nullopt /* selected_text */, std::move(events), base::Time::Now(), std::nullopt /* edits */, false); @@ -624,7 +674,9 @@ void ConversationHandler::ModifyConversation(uint32_t turn_index, } turn->edits->emplace_back(std::move(edited_turn)); - OnHistoryUpdate(); + OnConversationEntryRemoved(turn->uuid); + OnConversationEntryAdded(turn); + return; } @@ -643,18 +695,26 @@ void ConversationHandler::ModifyConversation(uint32_t turn_index, // editable human turns in our current implementation, just use std::nullopt // here directly to be more explicit and avoid confusion. auto edited_turn = mojom::ConversationTurn::New( - turn->character_type, turn->action_type, turn->visibility, - sanitized_input, std::nullopt /* selected_text */, - std::nullopt /* events */, base::Time::Now(), std::nullopt /* edits */, - false); + base::Uuid::GenerateRandomV4().AsLowercaseString(), turn->character_type, + turn->action_type, turn->visibility, sanitized_input, + std::nullopt /* selected_text */, std::nullopt /* events */, + base::Time::Now(), std::nullopt /* edits */, false); if (!turn->edits) { turn->edits.emplace(); } + // Erase all turns after the edited turn and notify observers + std::vector> erased_turn_ids; + base::ranges::transform( + chat_history_.begin() + turn_index, chat_history_.end(), + std::back_inserter(erased_turn_ids), + [](mojom::ConversationTurnPtr& turn) { return turn->uuid; }); turn->edits->emplace_back(std::move(edited_turn)); - auto new_turn = std::move(chat_history_.at(turn_index)); chat_history_.erase(chat_history_.begin() + turn_index, chat_history_.end()); - OnHistoryUpdate(); + + for (auto& uuid : erased_turn_ids) { + OnConversationEntryRemoved(uuid); + } SubmitHumanConversationEntry(std::move(new_turn)); } @@ -666,7 +726,7 @@ void ConversationHandler::SubmitSummarizationRequest() { << "This conversation request should send page contents"; mojom::ConversationTurnPtr turn = mojom::ConversationTurn::New( - CharacterType::HUMAN, mojom::ActionType::SUMMARIZE_PAGE, + std::nullopt, CharacterType::HUMAN, mojom::ActionType::SUMMARIZE_PAGE, mojom::ConversationTurnVisibility::VISIBLE, l10n_util::GetStringUTF8(IDS_CHAT_UI_SUMMARIZE_PAGE), std::nullopt, std::nullopt, base::Time::Now(), std::nullopt, false); @@ -744,7 +804,7 @@ void ConversationHandler::PerformQuestionGeneration( void ConversationHandler::GetAssociatedContentInfo( GetAssociatedContentInfoCallback callback) { BuildAssociatedContentInfo(); - std::move(callback).Run(associated_content_info_->Clone(), + std::move(callback).Run(metadata_->associated_content->Clone(), should_send_page_contents_); } @@ -793,7 +853,7 @@ void ConversationHandler::ClearErrorAndGetFailedMessage( mojom::ConversationTurnPtr turn = std::move(*chat_history_.end()); chat_history_.erase(chat_history_.end()); - OnHistoryUpdate(); + OnConversationEntryRemoved(turn->uuid); std::move(callback).Run(std::move(turn)); } @@ -809,7 +869,7 @@ void ConversationHandler::SubmitSelectedTextWithQuestion( const std::string& question, mojom::ActionType action_type) { mojom::ConversationTurnPtr turn = mojom::ConversationTurn::New( - CharacterType::HUMAN, action_type, + std::nullopt, CharacterType::HUMAN, action_type, mojom::ConversationTurnVisibility::VISIBLE, question, selected_text, std::nullopt, base::Time::Now(), std::nullopt, false); @@ -848,7 +908,7 @@ void ConversationHandler::AddSubmitSelectedTextError( } const std::string& question = GetActionTypeQuestion(action_type); mojom::ConversationTurnPtr turn = mojom::ConversationTurn::New( - CharacterType::HUMAN, action_type, + std::nullopt, CharacterType::HUMAN, action_type, mojom::ConversationTurnVisibility::VISIBLE, question, selected_text, std::nullopt, base::Time::Now(), std::nullopt, false); AddToConversationHistory(std::move(turn)); @@ -873,10 +933,13 @@ void ConversationHandler::AddToConversationHistory( return; } + if (!turn->uuid.has_value()) { + turn->uuid = base::Uuid::GenerateRandomV4().AsLowercaseString(); + } + chat_history_.push_back(std::move(turn)); - OnHistoryUpdate(); - OnConversationEntriesChanged(); + OnConversationEntryAdded(chat_history_.back()); } void ConversationHandler::PerformAssistantGeneration( @@ -937,6 +1000,7 @@ void ConversationHandler::UpdateOrCreateLastAssistantEntry( if (chat_history_.empty() || chat_history_.back()->character_type != CharacterType::ASSISTANT) { mojom::ConversationTurnPtr entry = mojom::ConversationTurn::New( + base::Uuid::GenerateRandomV4().AsLowercaseString(), CharacterType::ASSISTANT, mojom::ActionType::RESPONSE, mojom::ConversationTurnVisibility::VISIBLE, "", std::nullopt, std::vector{}, base::Time::Now(), @@ -956,7 +1020,7 @@ void ConversationHandler::UpdateOrCreateLastAssistantEntry( } // Optimize by merging with previous completion events if delta updates - // are supported or otherwise replacing the previous event. + // are supported or otherwise replacing faddedthe previous event. if (entry->events->size() > 0) { auto& last_event = entry->events->back(); if (last_event->is_completion_event()) { @@ -1070,22 +1134,26 @@ void ConversationHandler::OnGetStagedEntriesFromContent( return; } - // Add the query & summary pairs to the conversation history and call - // OnHistoryUpdate to update UI. + // Add the query & summary pairs to the conversation history and notify + // observers. for (const auto& entry : *entries) { chat_history_.push_back(mojom::ConversationTurn::New( + base::Uuid::GenerateRandomV4().AsLowercaseString(), CharacterType::HUMAN, mojom::ActionType::QUERY, mojom::ConversationTurnVisibility::VISIBLE, entry.query, std::nullopt, std::nullopt, base::Time::Now(), std::nullopt, true)); + OnConversationEntryAdded(chat_history_.back()); + std::vector events; events.push_back(mojom::ConversationEntryEvent::NewCompletionEvent( mojom::CompletionEvent::New(entry.summary))); chat_history_.push_back(mojom::ConversationTurn::New( + base::Uuid::GenerateRandomV4().AsLowercaseString(), CharacterType::ASSISTANT, mojom::ActionType::RESPONSE, mojom::ConversationTurnVisibility::VISIBLE, entry.summary, std::nullopt, std::move(events), base::Time::Now(), std::nullopt, true)); + OnConversationEntryAdded(chat_history_.back()); } - OnHistoryUpdate(); } void ConversationHandler::GeneratePageContent(GetPageContentCallback callback) { @@ -1100,21 +1168,34 @@ void ConversationHandler::GeneratePageContent(GetPageContentCallback callback) { DCHECK(ai_chat_service_->HasUserOptedIn()) << "UI shouldn't allow operations before user has accepted agreement"; + // Keep hold of the current content so we can check if it changed + is_content_different_ = false; + std::string current_content = + std::string(associated_content_delegate_->GetCachedTextContent()); associated_content_delegate_->GetContent( base::BindOnce(&ConversationHandler::OnGeneratePageContentComplete, - weak_ptr_factory_.GetWeakPtr(), std::move(callback))); + weak_ptr_factory_.GetWeakPtr(), std::move(callback), + std::move(current_content))); } void ConversationHandler::OnGeneratePageContentComplete( GetPageContentCallback callback, + std::string previous_content, std::string contents_text, bool is_video, std::string invalidation_token) { engine_->SanitizeInput(contents_text); + is_content_different_ = contents_text != previous_content; + + metadata_->associated_content->content_type = + is_video ? mojom::ContentType::VideoTranscript + : mojom::ContentType::PageContent; + std::move(callback).Run(contents_text, is_video, invalidation_token); - // Content-used percentage might have changed + // Content-used percentage and is_video might have changed in addition to + // content_type. OnAssociatedContentInfoChanged(); } @@ -1160,11 +1241,13 @@ void ConversationHandler::OnEngineCompletionComplete( UpdateOrCreateLastAssistantEntry( mojom::ConversationEntryEvent::NewCompletionEvent( mojom::CompletionEvent::New(*result))); - OnConversationEntriesChanged(); + OnConversationEntryAdded(chat_history_.back()); } else { auto& last_entry = chat_history_.back(); if (last_entry->character_type != mojom::CharacterType::ASSISTANT) { SetAPIError(mojom::APIError::ConnectionIssue); + } else { + OnConversationEntryAdded(chat_history_.back()); } } MaybePopPendingRequests(); @@ -1255,20 +1338,53 @@ void ConversationHandler::OnHistoryUpdate() { for (auto& client : conversation_ui_handlers_) { client->OnConversationHistoryUpdate(); } - OnConversationEntriesChanged(); } -void ConversationHandler::OnConversationEntriesChanged() { +void ConversationHandler::OnConversationEntryRemoved( + std::optional entry_uuid) { + OnHistoryUpdate(); + if (!entry_uuid.has_value()) { + return; + } for (auto& observer : observers_) { - // TODO(petemill): only tell observers about complete turns. This is - // expensive to do for every event generated by in-progress turns, - // and consumers likely only need complete ones (e.g. database save). - std::vector history; - for (const auto& turn : chat_history_) { - history.emplace_back(turn->Clone()); + observer.OnConversationEntryRemoved(this, entry_uuid.value()); + } +} + +void ConversationHandler::OnConversationEntryAdded( + mojom::ConversationTurnPtr& entry) { + // Only notify about staged entries once we have the first staged entry + if (entry->from_brave_search_SERP) { + OnHistoryUpdate(); + return; + } + std::optional associated_content_value; + if (is_content_different_ && associated_content_delegate_) { + associated_content_value = + associated_content_delegate_->GetCachedTextContent(); + } + // If this is the first entry that isn't staged, notify about all previous + // staged entries + if (!entry->from_brave_search_SERP && + base::ranges::all_of(chat_history_, + [&entry](mojom::ConversationTurnPtr& history_entry) { + return history_entry == entry || + history_entry->from_brave_search_SERP; + })) { + // Notify every item in chat history + for (auto& observer : observers_) { + for (auto& history_entry : chat_history_) { + observer.OnConversationEntryAdded(this, history_entry, + associated_content_value); + } } - observer.OnConversationEntriesChanged(this, std::move(history)); + OnHistoryUpdate(); + return; } + for (auto& observer : observers_) { + observer.OnConversationEntryAdded(this, entry, associated_content_value); + } + OnHistoryUpdate(); } int ConversationHandler::GetContentUsedPercentage() { @@ -1299,31 +1415,30 @@ bool ConversationHandler::IsContentAssociationPossible() { } void ConversationHandler::BuildAssociatedContentInfo() { - // Save in class instance so that we have a cache for when live - // AssociatedContentDelegate disconnects. Only modify in this function. - associated_content_info_ = mojom::SiteInfo::New(); + // Only modify associated content metadata here if (associated_content_delegate_) { - associated_content_info_->title = + metadata_->associated_content->title = base::UTF16ToUTF8(associated_content_delegate_->GetTitle()); const GURL url = associated_content_delegate_->GetURL(); - if (url.SchemeIsHTTPOrHTTPS()) { - associated_content_info_->hostname = url.host(); - associated_content_info_->url = url; - } - associated_content_info_->content_used_percentage = + metadata_->associated_content->hostname = url.host(); + metadata_->associated_content->url = url; + metadata_->associated_content->content_used_percentage = GetContentUsedPercentage(); - associated_content_info_->is_content_refined = is_content_refined_; - associated_content_info_->is_content_association_possible = true; + metadata_->associated_content->is_content_refined = is_content_refined_; + metadata_->associated_content->is_content_association_possible = true; } else { - associated_content_info_->is_content_association_possible = false; + metadata_->associated_content->title = std::nullopt; + metadata_->associated_content->hostname = std::nullopt; + metadata_->associated_content->url = std::nullopt; + metadata_->associated_content->is_content_association_possible = false; } } void ConversationHandler::OnAssociatedContentInfoChanged() { BuildAssociatedContentInfo(); for (auto& client : conversation_ui_handlers_) { - client->OnAssociatedContentInfoChanged(associated_content_info_->Clone(), - should_send_page_contents_); + client->OnAssociatedContentInfoChanged( + metadata_->associated_content->Clone(), should_send_page_contents_); } } @@ -1364,6 +1479,9 @@ void ConversationHandler::OnAPIRequestInProgressChanged() { for (auto& client : conversation_ui_handlers_) { client->OnAPIRequestInProgress(is_request_in_progress_); } + for (auto& observer : observers_) { + observer.OnRequestInProgressChanged(this, is_request_in_progress_); + } } } // namespace ai_chat diff --git a/components/ai_chat/core/browser/conversation_handler.h b/components/ai_chat/core/browser/conversation_handler.h index 1566f260028f..625788f11745 100644 --- a/components/ai_chat/core/browser/conversation_handler.h +++ b/components/ai_chat/core/browser/conversation_handler.h @@ -21,6 +21,7 @@ #include "brave/components/ai_chat/core/browser/text_embedder.h" #include "brave/components/ai_chat/core/browser/types.h" #include "brave/components/ai_chat/core/common/mojom/ai_chat.mojom-forward.h" +#include "brave/components/ai_chat/core/common/mojom/ai_chat.mojom-shared.h" #include "brave/components/ai_chat/core/common/mojom/ai_chat.mojom.h" #include "mojo/public/cpp/bindings/pending_receiver.h" #include "mojo/public/cpp/bindings/receiver_set.h" @@ -128,9 +129,16 @@ class ConversationHandler : public mojom::ConversationHandler, ~Observer() override {} // Called when the conversation history changess - virtual void OnConversationEntriesChanged( + virtual void OnRequestInProgressChanged(ConversationHandler* handler, + bool in_progress) {} + virtual void OnConversationEntryAdded( ConversationHandler* handler, - std::vector entries) {} + mojom::ConversationTurnPtr& entry, + std::optional associated_content_value) {} + virtual void OnConversationEntryRemoved(ConversationHandler* handler, + std::string turn_uuid) {} + virtual void OnConversationEntryUpdated(ConversationHandler* handler, + mojom::ConversationTurnPtr entry) {} // Called when a mojo client connects or disconnects virtual void OnClientConnectionChanged(ConversationHandler* handler) {} @@ -139,13 +147,22 @@ class ConversationHandler : public mojom::ConversationHandler, }; ConversationHandler( - const mojom::Conversation* conversation, + mojom::Conversation* conversation, AIChatService* ai_chat_service, ModelService* model_service, AIChatCredentialManager* credential_manager, AIChatFeedbackAPI* feedback_api, scoped_refptr url_loader_factory); + ConversationHandler( + mojom::Conversation* conversation, + AIChatService* ai_chat_service, + ModelService* model_service, + AIChatCredentialManager* credential_manager, + AIChatFeedbackAPI* feedback_api, + scoped_refptr url_loader_factory, + std::optional initial_state); + ~ConversationHandler() override; ConversationHandler(const ConversationHandler&) = delete; ConversationHandler& operator=(const ConversationHandler&) = delete; @@ -159,7 +176,7 @@ class ConversationHandler : public mojom::ConversationHandler, bool IsAnyClientConnected(); bool HasAnyHistory(); - void OnConversationDeleted(); + bool IsRequestInProgress(); // Called when the associated content is destroyed or navigated away. If // it's a navigation, the AssociatedContentDelegate will set itself to a new @@ -234,7 +251,9 @@ class ConversationHandler : public mojom::ConversationHandler, void SetChatHistoryForTesting( std::vector history) { chat_history_ = std::move(history); - OnHistoryUpdate(); + for (auto& entry : chat_history_) { + OnConversationEntryAdded(entry); + } } AssociatedContentDelegate* GetAssociatedContentDelegateForTesting() { @@ -288,11 +307,11 @@ class ConversationHandler : public mojom::ConversationHandler, const std::optional>& entries); void GeneratePageContent(GetPageContentCallback callback); - void SetPageContent(std::string contents_text, - bool is_video, - std::string invalidation_token); + + void SetArchiveContent(std::string text_content, bool is_video); void OnGeneratePageContentComplete(GetPageContentCallback callback, + std::string previous_content, std::string contents_text, bool is_video, std::string invalidation_token); @@ -309,10 +328,12 @@ class ConversationHandler : public mojom::ConversationHandler, EngineConsumer::SuggestedQuestionResult result); void OnModelDataChanged(); + void OnConversationDeleted(); void OnHistoryUpdate(); + void OnConversationEntryAdded(mojom::ConversationTurnPtr& entry); + void OnConversationEntryRemoved(std::optional turn_id); void OnSuggestedQuestionsChanged(); void OnAssociatedContentInfoChanged(); - void OnConversationEntriesChanged(); void OnClientConnectionChanged(); void OnConversationTitleChanged(std::string title); void OnConversationUIConnectionChanged(mojo::RemoteSetElementId id); @@ -331,7 +352,6 @@ class ConversationHandler : public mojom::ConversationHandler, // Is a conversation engine request in progress (does not include // non-conversation engine requests. bool is_request_in_progress_ = false; - mojom::SiteInfoPtr associated_content_info_ = nullptr; // TODO(petemill): Tracking whether the UI is open // for a conversation might not be neccessary anymore as there @@ -356,6 +376,9 @@ class ConversationHandler : public mojom::ConversationHandler, // different pages. bool should_send_page_contents_ = false; bool is_content_refined_ = false; + // When this is true, the most recent content retrieval was different to the + // previous one. + bool is_content_different_ = false; bool is_print_preview_fallback_requested_ = false; @@ -363,7 +386,7 @@ class ConversationHandler : public mojom::ConversationHandler, mojom::APIError current_error_ = mojom::APIError::None; // Data store UUID for conversation - raw_ptr metadata_; + raw_ptr metadata_; raw_ptr ai_chat_service_; raw_ptr model_service_; raw_ptr credential_manager_; diff --git a/components/ai_chat/core/browser/conversation_handler_unittest.cc b/components/ai_chat/core/browser/conversation_handler_unittest.cc index c35f6c64d572..d942ecd5037a 100644 --- a/components/ai_chat/core/browser/conversation_handler_unittest.cc +++ b/components/ai_chat/core/browser/conversation_handler_unittest.cc @@ -12,6 +12,7 @@ #include #include +#include "base/files/scoped_temp_dir.h" #include "base/functional/bind.h" #include "base/functional/callback.h" #include "base/functional/overloaded.h" @@ -29,14 +30,19 @@ #include "brave/components/ai_chat/core/browser/ai_chat_credential_manager.h" #include "brave/components/ai_chat/core/browser/ai_chat_service.h" #include "brave/components/ai_chat/core/browser/engine/mock_engine_consumer.h" +#include "brave/components/ai_chat/core/browser/mock_conversation_handler_observer.h" +#include "brave/components/ai_chat/core/browser/test_utils.h" #include "brave/components/ai_chat/core/browser/text_embedder.h" #include "brave/components/ai_chat/core/browser/types.h" #include "brave/components/ai_chat/core/browser/utils.h" #include "brave/components/ai_chat/core/common/features.h" +#include "brave/components/ai_chat/core/common/mojom/ai_chat.mojom-forward.h" #include "brave/components/ai_chat/core/common/mojom/ai_chat.mojom-shared.h" #include "brave/components/ai_chat/core/common/mojom/ai_chat.mojom.h" #include "brave/components/ai_chat/core/common/pref_names.h" #include "components/grit/brave_components_strings.h" +#include "components/os_crypt/async/browser/os_crypt_async.h" +#include "components/os_crypt/async/browser/test_utils.h" #include "components/sync_preferences/testing_pref_service_syncable.h" #include "mojo/public/cpp/bindings/receiver.h" #include "services/data_decoder/public/cpp/test_support/in_process_data_decoder.h" @@ -59,17 +65,6 @@ namespace ai_chat { namespace { -bool CompareConversationTurn(const mojom::ConversationTurnPtr& a, - const mojom::ConversationTurnPtr& b) { - if (!a || !b) { - return a == b; // Both should be null or neither - } - return a->action_type == b->action_type && - a->character_type == b->character_type && - a->selected_text == b->selected_text && a->text == b->text && - a->visibility == b->visibility; -} - class MockAIChatCredentialManager : public AIChatCredentialManager { public: using AIChatCredentialManager::AIChatCredentialManager; @@ -117,6 +112,8 @@ class MockConversationHandlerClient : public mojom::ConversationUI { MOCK_METHOD(void, OnFaviconImageDataChanged, (), (override)); + MOCK_METHOD(void, OnConversationDeleted, (), (override)); + private: mojo::Receiver conversation_ui_receiver_{this}; mojo::Remote conversation_handler_; @@ -174,10 +171,14 @@ class MockTextEmbedder : public TextEmbedder { class ConversationHandlerUnitTest : public testing::Test { public: void SetUp() override { + ASSERT_TRUE(temp_directory_.CreateUniqueTempDir()); prefs::RegisterProfilePrefs(prefs_.registry()); prefs::RegisterLocalStatePrefs(local_state_.registry()); ModelService::RegisterProfilePrefs(prefs_.registry()); + os_crypt_ = os_crypt_async::GetTestOSCryptAsyncForTesting( + /*is_sync_for_unittests=*/true); + shared_url_loader_factory_ = base::MakeRefCounted( &url_loader_factory_); @@ -196,9 +197,14 @@ class ConversationHandlerUnitTest : public testing::Test { ai_chat_service_ = std::make_unique( model_service_.get(), std::move(credential_manager), &prefs_, nullptr, - shared_url_loader_factory_, ""); + os_crypt_.get(), shared_url_loader_factory_, "", + temp_directory_.GetPath()); - conversation_ = mojom::Conversation::New("uuid", "title", false); + mojom::SiteInfoPtr non_content = mojom::SiteInfo::New( + std::nullopt, mojom::ContentType::PageContent, std::nullopt, + std::nullopt, std::nullopt, 0, false, false); + conversation_ = mojom::Conversation::New("uuid", "title", base::Time::Now(), + false, std::move(non_content)); conversation_handler_ = std::make_unique( conversation_.get(), ai_chat_service_.get(), model_service_.get(), @@ -264,6 +270,7 @@ class ConversationHandlerUnitTest : public testing::Test { std::unique_ptr model_service_; sync_preferences::TestingPrefServiceSyncable prefs_; sync_preferences::TestingPrefServiceSyncable local_state_; + std::unique_ptr os_crypt_; network::TestURLLoaderFactory url_loader_factory_; scoped_refptr shared_url_loader_factory_; data_decoder::test::InProcessDataDecoder in_process_data_decoder_; @@ -272,6 +279,9 @@ class ConversationHandlerUnitTest : public testing::Test { std::unique_ptr> associated_content_; bool is_opted_in_ = true; bool has_associated_content_ = true; + + private: + base::ScopedTempDir temp_directory_; }; class ConversationHandlerUnitTest_OptedOut @@ -342,6 +352,7 @@ TEST_F(ConversationHandlerUnitTest, SubmitSelectedText) { std::string selected_text = "I have spoken."; std::string expected_turn_text = l10n_util::GetStringUTF8(IDS_AI_CHAT_QUESTION_SUMMARIZE_SELECTED_TEXT); + const std::string expected_response = "This is the way."; // Expect the ConversationHandler to call the engine with the selected text // and the action's expanded text. @@ -353,7 +364,7 @@ TEST_F(ConversationHandlerUnitTest, SubmitSelectedText) { .WillOnce(::testing::DoAll( base::test::RunOnceCallback<4>( mojom::ConversationEntryEvent::NewCompletionEvent( - mojom::CompletionEvent::New("This is the way."))), + mojom::CompletionEvent::New(expected_response))), base::test::RunOnceCallback<5>(base::ok("")))); EXPECT_FALSE(conversation_handler_->HasAnyHistory()); @@ -374,8 +385,8 @@ TEST_F(ConversationHandlerUnitTest, SubmitSelectedText) { NiceMock client(conversation_handler_.get()); EXPECT_CALL(client, OnAPIRequestInProgress(true)).Times(1); - // Human and AI entries. - EXPECT_CALL(client, OnConversationHistoryUpdate()).Times(2); + // Human, AI entries and content event for AI response. + EXPECT_CALL(client, OnConversationHistoryUpdate()).Times(3); // Fired from OnEngineCompletionComplete. EXPECT_CALL(client, OnAPIRequestInProgress(false)).Times(1); // Ensure everything is sanitized @@ -383,7 +394,7 @@ TEST_F(ConversationHandlerUnitTest, SubmitSelectedText) { EXPECT_CALL(*engine, SanitizeInput(StrEq(expected_turn_text))); conversation_handler_->SubmitSelectedText( - "I have spoken.", mojom::ActionType::SUMMARIZE_SELECTED_TEXT); + selected_text, mojom::ActionType::SUMMARIZE_SELECTED_TEXT); task_environment_.RunUntilIdle(); testing::Mock::VerifyAndClearExpectations(&client); @@ -405,19 +416,22 @@ TEST_F(ConversationHandlerUnitTest, SubmitSelectedText) { EXPECT_TRUE(conversation_handler_->HasAnyHistory()); const auto& history = conversation_handler_->GetConversationHistory(); std::vector expected_history; + expected_history.push_back(mojom::ConversationTurn::New( - mojom::CharacterType::HUMAN, mojom::ActionType::SUMMARIZE_SELECTED_TEXT, - mojom::ConversationTurnVisibility::VISIBLE, - l10n_util::GetStringUTF8(IDS_AI_CHAT_QUESTION_SUMMARIZE_SELECTED_TEXT), - "I have spoken.", std::nullopt, base::Time::Now(), std::nullopt, false)); + std::nullopt, mojom::CharacterType::HUMAN, + mojom::ActionType::SUMMARIZE_SELECTED_TEXT, + mojom::ConversationTurnVisibility::VISIBLE, expected_turn_text, + selected_text, std::nullopt, base::Time::Now(), std::nullopt, false)); + + std::vector response_events; + response_events.push_back(mojom::ConversationEntryEvent::NewCompletionEvent( + mojom::CompletionEvent::New(expected_response))); expected_history.push_back(mojom::ConversationTurn::New( - mojom::CharacterType::ASSISTANT, mojom::ActionType::RESPONSE, - mojom::ConversationTurnVisibility::VISIBLE, "This is the way.", - std::nullopt, std::nullopt, base::Time::Now(), std::nullopt, false)); - EXPECT_EQ(history.size(), expected_history.size()); - for (size_t i = 0; i < history.size(); i++) { - EXPECT_TRUE(CompareConversationTurn(history[i], expected_history[i])); - } + std::nullopt, mojom::CharacterType::ASSISTANT, + mojom::ActionType::RESPONSE, mojom::ConversationTurnVisibility::VISIBLE, + expected_response, std::nullopt, std::move(response_events), + base::Time::Now(), std::nullopt, false)); + ExpectConversationHistoryEquals(FROM_HERE, history, expected_history, false); } TEST_F(ConversationHandlerUnitTest, SubmitSelectedText_WithAssociatedContent) { @@ -431,6 +445,7 @@ TEST_F(ConversationHandlerUnitTest, SubmitSelectedText_WithAssociatedContent) { std::string selected_text = "I have spoken again."; std::string expected_turn_text = l10n_util::GetStringUTF8(IDS_AI_CHAT_QUESTION_SUMMARIZE_SELECTED_TEXT); + std::string expected_response = "This is the way."; EXPECT_CALL(*engine, GenerateAssistantResponse(false, StrEq(page_content), LastTurnHasSelectedText(selected_text), @@ -439,7 +454,7 @@ TEST_F(ConversationHandlerUnitTest, SubmitSelectedText_WithAssociatedContent) { .WillOnce(::testing::DoAll( base::test::RunOnceCallback<4>( mojom::ConversationEntryEvent::NewCompletionEvent( - mojom::CompletionEvent::New("This is the way."))), + mojom::CompletionEvent::New(expected_response))), base::test::RunOnceCallback<5>(base::ok("")))); ON_CALL(*associated_content_, GetURL) @@ -456,8 +471,8 @@ TEST_F(ConversationHandlerUnitTest, SubmitSelectedText_WithAssociatedContent) { NiceMock client(conversation_handler_.get()); EXPECT_CALL(client, OnAPIRequestInProgress(true)).Times(1); - // Human and AI entries. - EXPECT_CALL(client, OnConversationHistoryUpdate()).Times(2); + // Human and AI entries, and content event for AI response. + EXPECT_CALL(client, OnConversationHistoryUpdate()).Times(3); // Fired from OnEngineCompletionComplete. EXPECT_CALL(client, OnAPIRequestInProgress(false)).Times(1); // Ensure everything is sanitized @@ -468,7 +483,7 @@ TEST_F(ConversationHandlerUnitTest, SubmitSelectedText_WithAssociatedContent) { EXPECT_CALL(*engine, GenerateQuestionSuggestions).Times(0); conversation_handler_->SubmitSelectedText( - "I have spoken again.", mojom::ActionType::SUMMARIZE_SELECTED_TEXT); + selected_text, mojom::ActionType::SUMMARIZE_SELECTED_TEXT); task_environment_.RunUntilIdle(); testing::Mock::VerifyAndClearExpectations(&client); @@ -493,18 +508,21 @@ TEST_F(ConversationHandlerUnitTest, SubmitSelectedText_WithAssociatedContent) { const auto& history2 = conversation_handler_->GetConversationHistory(); std::vector expected_history2; expected_history2.push_back(mojom::ConversationTurn::New( - mojom::CharacterType::HUMAN, mojom::ActionType::SUMMARIZE_SELECTED_TEXT, + std::nullopt, mojom::CharacterType::HUMAN, + mojom::ActionType::SUMMARIZE_SELECTED_TEXT, mojom::ConversationTurnVisibility::VISIBLE, expected_turn_text, - "I have spoken again.", std::nullopt, base::Time::Now(), std::nullopt, - false)); + selected_text, std::nullopt, base::Time::Now(), std::nullopt, false)); + + std::vector response_events; + response_events.push_back(mojom::ConversationEntryEvent::NewCompletionEvent( + mojom::CompletionEvent::New(expected_response))); expected_history2.push_back(mojom::ConversationTurn::New( - mojom::CharacterType::ASSISTANT, mojom::ActionType::RESPONSE, - mojom::ConversationTurnVisibility::VISIBLE, "This is the way.", - std::nullopt, std::nullopt, base::Time::Now(), std::nullopt, false)); - EXPECT_EQ(history2.size(), expected_history2.size()); - for (size_t i = 0; i < history2.size(); i++) { - EXPECT_TRUE(CompareConversationTurn(history2[i], expected_history2[i])); - } + std::nullopt, mojom::CharacterType::ASSISTANT, + mojom::ActionType::RESPONSE, mojom::ConversationTurnVisibility::VISIBLE, + expected_response, std::nullopt, std::move(response_events), + base::Time::Now(), std::nullopt, false)); + ExpectConversationHistoryEquals(FROM_HERE, history2, expected_history2, + false); } TEST_F(ConversationHandlerUnitTest, UpdateOrCreateLastAssistantEntry_Delta) { @@ -751,111 +769,119 @@ TEST_F(ConversationHandlerUnitTest, ModifyConversation) { MockEngineConsumer* engine = static_cast( conversation_handler_->GetEngineForTesting()); - // Setup history for testing. - auto created_time1 = base::Time::Now(); - std::vector history; - history.push_back(mojom::ConversationTurn::New( - mojom::CharacterType::HUMAN, mojom::ActionType::QUERY, - mojom::ConversationTurnVisibility::VISIBLE, "prompt1", std::nullopt, - std::nullopt, created_time1, std::nullopt, false)); - history.push_back(mojom::ConversationTurn::New( - mojom::CharacterType::ASSISTANT, mojom::ActionType::RESPONSE, - mojom::ConversationTurnVisibility::VISIBLE, "answer1", std::nullopt, - std::nullopt, base::Time::Now(), std::nullopt, false)); - conversation_handler_->SetChatHistoryForTesting(std::move(history)); - + // Setup history for testing. Items have IDs so we can test removal + // notifications to an observer. + std::vector history = CreateSampleChatHistory(1); + EXPECT_FALSE(history[0]->edits); + conversation_handler_->SetChatHistoryForTesting(CloneHistory(history)); + mojom::ConversationEntryEventPtr expected_new_completion_event = + mojom::ConversationEntryEvent::NewCompletionEvent( + mojom::CompletionEvent::New("new answer")); // Modify an entry for the first time. EXPECT_CALL(*engine, GenerateAssistantResponse(false, StrEq(""), LastTurnHasText("prompt2"), StrEq("prompt2"), _, _)) // Mock the response from the engine - .WillOnce(::testing::DoAll( - base::test::RunOnceCallback<4>( - mojom::ConversationEntryEvent::NewCompletionEvent( - mojom::CompletionEvent::New("new answer"))), - base::test::RunOnceCallback<5>(base::ok("")))); + .WillOnce(::testing::DoAll(base::test::RunOnceCallback<4>( + expected_new_completion_event->Clone()), + base::test::RunOnceCallback<5>(base::ok("")))); + testing::NiceMock observer; + // Verify both entries are removed + EXPECT_CALL(observer, OnConversationEntryRemoved(conversation_handler_.get(), + history[0]->uuid.value())) + .Times(1); + EXPECT_CALL(observer, OnConversationEntryRemoved(conversation_handler_.get(), + history[1]->uuid.value())) + .Times(1); + // Verify edited entry is added as well as the new response + EXPECT_CALL(observer, + OnConversationEntryAdded(conversation_handler_.get(), _, _)) + .Times(2); + observer.Observe(conversation_handler_.get()); + + // Make a first edit conversation_handler_->ModifyConversation(0, "prompt2"); + testing::Mock::VerifyAndClearExpectations(&observer); + + // Create the entries events in the way we're expecting to look + // post-modification. + auto first_edit_expected_history = CloneHistory(history); + auto first_edit = history[0]->Clone(); + first_edit->uuid = "ignore_me"; + first_edit->selected_text = std::nullopt; + first_edit->text = "prompt2"; + + first_edit_expected_history[0]->edits.emplace(); + first_edit_expected_history[0]->edits->push_back(first_edit->Clone()); + + first_edit_expected_history[1]->text = "new answer"; + first_edit_expected_history[1]->events.emplace(); + first_edit_expected_history[1]->events->emplace_back( + expected_new_completion_event->Clone()); + + // Verify the first entry still has original details const auto& conversation_history = conversation_handler_->GetConversationHistory(); - ASSERT_EQ(conversation_history.size(), 2u); - EXPECT_EQ(conversation_history[0]->text, "prompt1"); - EXPECT_EQ(conversation_history[0]->created_time, created_time1); - ASSERT_TRUE(conversation_history[0]->edits); - ASSERT_EQ(conversation_history[0]->edits->size(), 1u); - EXPECT_EQ(conversation_history[0]->edits->at(0)->text, "prompt2"); - EXPECT_NE(conversation_history[0]->edits->at(0)->created_time, created_time1); - EXPECT_FALSE(conversation_history[0]->edits->at(0)->edits); - - EXPECT_EQ(conversation_history[1]->text, "new answer"); + ExpectConversationHistoryEquals(FROM_HERE, conversation_history, + first_edit_expected_history, false); + // Create time shouldn't be changed + EXPECT_EQ(conversation_history[0]->created_time, history[0]->created_time); auto created_time2 = conversation_history[0]->edits->at(0)->created_time; + // New edit should have a different created time + EXPECT_NE(created_time2, history[0]->created_time); // Modify the same entry again. EXPECT_CALL(*engine, GenerateAssistantResponse(false, StrEq(""), LastTurnHasText("prompt3"), StrEq("prompt3"), _, _)) // Mock the response from the engine - .WillOnce(::testing::DoAll( - base::test::RunOnceCallback<4>( - mojom::ConversationEntryEvent::NewCompletionEvent( - mojom::CompletionEvent::New("new answer"))), - base::test::RunOnceCallback<5>(base::ok("")))); + .WillOnce(::testing::DoAll(base::test::RunOnceCallback<4>( + expected_new_completion_event->Clone()), + base::test::RunOnceCallback<5>(base::ok("")))); conversation_handler_->ModifyConversation(0, "prompt3"); - ASSERT_EQ(conversation_history.size(), 2u); - EXPECT_EQ(conversation_history[0]->text, "prompt1"); - EXPECT_EQ(conversation_history[0]->created_time, created_time1); - ASSERT_TRUE(conversation_history[0]->edits); - ASSERT_EQ(conversation_history[0]->edits->size(), 2u); - EXPECT_EQ(conversation_history[0]->edits->at(0)->text, "prompt2"); - EXPECT_EQ(conversation_history[0]->edits->at(0)->created_time, created_time2); - EXPECT_FALSE(conversation_history[0]->edits->at(0)->edits); + auto second_edit_expected_history = CloneHistory(first_edit_expected_history); + auto second_edit = first_edit->Clone(); + second_edit->text = "prompt3"; + second_edit_expected_history[0]->edits->emplace_back(second_edit->Clone()); - EXPECT_EQ(conversation_history[0]->edits->at(1)->text, "prompt3"); - EXPECT_NE(conversation_history[0]->edits->at(1)->created_time, created_time1); + ExpectConversationHistoryEquals(FROM_HERE, conversation_history, + second_edit_expected_history, false); + // Create time shouldn't be changed + EXPECT_EQ(conversation_history[0]->created_time, history[0]->created_time); + // New edit should have a different create time + EXPECT_EQ(conversation_history[0]->edits->at(0)->created_time, created_time2); + EXPECT_NE(conversation_history[0]->edits->at(1)->created_time, + conversation_history[0]->created_time); EXPECT_NE(conversation_history[0]->edits->at(1)->created_time, created_time2); - EXPECT_FALSE(conversation_history[0]->edits->at(1)->edits); - EXPECT_EQ(conversation_history[1]->text, "new answer"); - - // Modify server response should have text and completion event updated in + // Modifying server response should have text and completion event updated in // the entry of edits. // Engine should not be called for an assistant edit EXPECT_CALL(*engine, GenerateAssistantResponse(_, _, _, _, _, _)).Times(0); conversation_handler_->ModifyConversation(1, " answer2 "); - ASSERT_EQ(conversation_history.size(), 2u); - EXPECT_EQ(conversation_history[1]->text, "new answer"); - ASSERT_TRUE(conversation_history[1]->edits); - ASSERT_EQ(conversation_history[1]->edits->size(), 1u); - EXPECT_EQ(conversation_history[1]->edits->at(0)->text, "answer2"); + auto third_edit_expected_history = CloneHistory(second_edit_expected_history); + + auto response_edit = third_edit_expected_history[1]->Clone(); + response_edit->uuid = "ignore_me"; + response_edit->text = "answer2"; // trimmed + response_edit->events->at(0) = + mojom::ConversationEntryEvent::NewCompletionEvent( + mojom::CompletionEvent::New("answer2")); + + third_edit_expected_history[1]->edits.emplace(); + third_edit_expected_history[1]->edits->emplace_back(response_edit->Clone()); + + ExpectConversationHistoryEquals(FROM_HERE, conversation_history, + third_edit_expected_history, false); + + // Edit time should be set differently EXPECT_NE(conversation_history[1]->edits->at(0)->created_time, conversation_history[1]->created_time); - - ASSERT_TRUE(conversation_history[1]->events); - ASSERT_EQ(conversation_history[1]->events->size(), 1u); - // Verify the original is left unchanged - ASSERT_TRUE(conversation_history[1]->events->at(0)->is_completion_event()); - EXPECT_EQ(conversation_history[1] - ->events->at(0) - ->get_completion_event() - ->completion, - "new answer"); - - ASSERT_TRUE(conversation_history[1]->edits->at(0)->events); - ASSERT_EQ(conversation_history[1]->edits->at(0)->events->size(), 1u); - ASSERT_TRUE(conversation_history[1] - ->edits->at(0) - ->events->at(0) - ->is_completion_event()); - EXPECT_EQ(conversation_history[1] - ->edits->at(0) - ->events->at(0) - ->get_completion_event() - ->completion, - "answer2"); } TEST_F(ConversationHandlerUnitTest, @@ -863,11 +889,20 @@ TEST_F(ConversationHandlerUnitTest, // Fetch with result should update the conversation history and call // OnConversationHistoryUpdate on observers. SetAssociatedContentStagedEntries(/*empty=*/false); + + // Shouldn't get any notification of real entries added + NiceMock observer; + observer.Observe(conversation_handler_.get()); + EXPECT_CALL(observer, OnConversationEntryAdded).Times(0); + // Client connecting will trigger content staging EXPECT_CALL(*associated_content_, GetStagedEntriesFromContent).Times(1); NiceMock client(conversation_handler_.get()); - EXPECT_CALL(client, OnConversationHistoryUpdate()).Times(1); EXPECT_TRUE(conversation_handler_->IsAnyClientConnected()); + + // History update notification once for each entry + EXPECT_CALL(client, OnConversationHistoryUpdate()).Times(2); + conversation_handler_->GetAssociatedContentInfo(base::BindLambdaForTesting( [&](mojom::SiteInfoPtr site_info, bool should_send_page_contents) { EXPECT_TRUE(should_send_page_contents); @@ -875,25 +910,28 @@ TEST_F(ConversationHandlerUnitTest, task_environment_.RunUntilIdle(); testing::Mock::VerifyAndClearExpectations(associated_content_.get()); + testing::Mock::VerifyAndClearExpectations(&observer); testing::Mock::VerifyAndClearExpectations(&client); auto& history = conversation_handler_->GetConversationHistory(); std::vector expected_history; expected_history.push_back(mojom::ConversationTurn::New( - mojom::CharacterType::HUMAN, mojom::ActionType::QUERY, + std::nullopt, mojom::CharacterType::HUMAN, mojom::ActionType::QUERY, mojom::ConversationTurnVisibility::VISIBLE, "query", std::nullopt, std::nullopt, base::Time::Now(), std::nullopt, true)); std::vector events; events.push_back(mojom::ConversationEntryEvent::NewCompletionEvent( mojom::CompletionEvent::New("summary"))); expected_history.push_back(mojom::ConversationTurn::New( - mojom::CharacterType::ASSISTANT, mojom::ActionType::RESPONSE, - mojom::ConversationTurnVisibility::VISIBLE, "summary", std::nullopt, - std::move(events), base::Time::Now(), std::nullopt, true)); + std::nullopt, mojom::CharacterType::ASSISTANT, + mojom::ActionType::RESPONSE, mojom::ConversationTurnVisibility::VISIBLE, + "summary", std::nullopt, std::move(events), base::Time::Now(), + std::nullopt, true)); ASSERT_EQ(history.size(), expected_history.size()); for (size_t i = 0; i < history.size(); i++) { expected_history[i]->created_time = history[i]->created_time; - EXPECT_EQ(history[i], expected_history[i]); + ExpectConversationEntryEquals(FROM_HERE, history[i], expected_history[i], + false); } // HasAnyHistory should still return false since all entries are staged EXPECT_FALSE(conversation_handler_->HasAnyHistory()); @@ -920,9 +958,12 @@ TEST_F(ConversationHandlerUnitTest, // OnConversationHistoryUpdate on observers. SetAssociatedContentStagedEntries(/*empty=*/false, /*multi=*/true); // Client connecting will trigger content staging + testing::NiceMock observer; + observer.Observe(conversation_handler_.get()); + EXPECT_CALL(observer, OnConversationEntryAdded).Times(0); EXPECT_CALL(*associated_content_, GetStagedEntriesFromContent).Times(1); NiceMock client(conversation_handler_.get()); - EXPECT_CALL(client, OnConversationHistoryUpdate()).Times(1); + EXPECT_CALL(client, OnConversationHistoryUpdate()).Times(4); EXPECT_TRUE(conversation_handler_->IsAnyClientConnected()); conversation_handler_->GetAssociatedContentInfo(base::BindLambdaForTesting( [&](mojom::SiteInfoPtr site_info, bool should_send_page_contents) { @@ -931,56 +972,72 @@ TEST_F(ConversationHandlerUnitTest, task_environment_.RunUntilIdle(); testing::Mock::VerifyAndClearExpectations(associated_content_.get()); + testing::Mock::VerifyAndClearExpectations(&observer); testing::Mock::VerifyAndClearExpectations(&client); auto& history = conversation_handler_->GetConversationHistory(); std::vector expected_history; expected_history.push_back(mojom::ConversationTurn::New( - mojom::CharacterType::HUMAN, mojom::ActionType::QUERY, + std::nullopt, mojom::CharacterType::HUMAN, mojom::ActionType::QUERY, mojom::ConversationTurnVisibility::VISIBLE, "query", std::nullopt, std::nullopt, base::Time::Now(), std::nullopt, true)); std::vector events; events.push_back(mojom::ConversationEntryEvent::NewCompletionEvent( mojom::CompletionEvent::New("summary"))); expected_history.push_back(mojom::ConversationTurn::New( - mojom::CharacterType::ASSISTANT, mojom::ActionType::RESPONSE, - mojom::ConversationTurnVisibility::VISIBLE, "summary", std::nullopt, - std::move(events), base::Time::Now(), std::nullopt, true)); + std::nullopt, mojom::CharacterType::ASSISTANT, + mojom::ActionType::RESPONSE, mojom::ConversationTurnVisibility::VISIBLE, + "summary", std::nullopt, std::move(events), base::Time::Now(), + std::nullopt, true)); expected_history.push_back(mojom::ConversationTurn::New( - mojom::CharacterType::HUMAN, mojom::ActionType::QUERY, + std::nullopt, mojom::CharacterType::HUMAN, mojom::ActionType::QUERY, mojom::ConversationTurnVisibility::VISIBLE, "query2", std::nullopt, std::nullopt, base::Time::Now(), std::nullopt, true)); std::vector events2; events2.push_back(mojom::ConversationEntryEvent::NewCompletionEvent( mojom::CompletionEvent::New("summary2"))); expected_history.push_back(mojom::ConversationTurn::New( - mojom::CharacterType::ASSISTANT, mojom::ActionType::RESPONSE, - mojom::ConversationTurnVisibility::VISIBLE, "summary2", std::nullopt, - std::move(events2), base::Time::Now(), std::nullopt, true)); + std::nullopt, mojom::CharacterType::ASSISTANT, + mojom::ActionType::RESPONSE, mojom::ConversationTurnVisibility::VISIBLE, + "summary2", std::nullopt, std::move(events2), base::Time::Now(), + std::nullopt, true)); ASSERT_EQ(history.size(), expected_history.size()); for (size_t i = 0; i < history.size(); i++) { expected_history[i]->created_time = history[i]->created_time; - EXPECT_EQ(history[i], expected_history[i]); + ExpectConversationEntryEquals(FROM_HERE, history[i], expected_history[i], + false); } // HasAnyHistory should still return false since all entries are staged EXPECT_FALSE(conversation_handler_->HasAnyHistory()); - // Verify turning off content association clears the conversation history. - EXPECT_CALL(client, OnConversationHistoryUpdate()).Times(1); - // Shouldn't ask for staged entries if user doesn't want to be associated - // with content. This verifies that even with existing staged entries, - // MaybeFetchOrClearContentStagedConversation will always early return. - EXPECT_CALL(*associated_content_, GetStagedEntriesFromContent).Times(0); + // Verify adding an actual conversation entry causes all entries to be + // notified and HasAnyHistory to return true. + // Modify an entry for the first time. + MockEngineConsumer* engine = static_cast( + conversation_handler_->GetEngineForTesting()); + EXPECT_CALL(*associated_content_, GetContent) + .WillOnce(base::test::RunOnceCallback<0>("page content", false, "")); + EXPECT_CALL(*engine, GenerateAssistantResponse) + // Mock the response from the engine + .WillOnce(::testing::DoAll( + base::test::RunOnceCallback<4>( + mojom::ConversationEntryEvent::NewCompletionEvent( + mojom::CompletionEvent::New("new answer"))), + base::test::RunOnceCallback<5>(base::ok("")))); - conversation_handler_->SetShouldSendPageContents(false); + EXPECT_CALL(observer, OnConversationEntryAdded).Times(6); + EXPECT_CALL(client, OnConversationHistoryUpdate()).Times(3); + + conversation_handler_->SubmitHumanConversationEntry("query3"); task_environment_.RunUntilIdle(); testing::Mock::VerifyAndClearExpectations(&client); + testing::Mock::VerifyAndClearExpectations(&observer); testing::Mock::VerifyAndClearExpectations(associated_content_.get()); - EXPECT_TRUE(conversation_handler_->GetConversationHistory().empty()); + EXPECT_TRUE(conversation_handler_->HasAnyHistory()); } TEST_F(ConversationHandlerUnitTest, diff --git a/components/ai_chat/core/browser/engine/engine_consumer_claude_unittest.cc b/components/ai_chat/core/browser/engine/engine_consumer_claude_unittest.cc index 9c3f354d3b73..8146bc91aef0 100644 --- a/components/ai_chat/core/browser/engine/engine_consumer_claude_unittest.cc +++ b/components/ai_chat/core/browser/engine/engine_consumer_claude_unittest.cc @@ -71,14 +71,16 @@ class EngineConsumerClaudeUnitTest : public testing::Test { TEST_F(EngineConsumerClaudeUnitTest, TestGenerateAssistantResponse) { std::vector history; history.push_back(mojom::ConversationTurn::New( - mojom::CharacterType::HUMAN, mojom::ActionType::SUMMARIZE_SELECTED_TEXT, + std::nullopt, mojom::CharacterType::HUMAN, + mojom::ActionType::SUMMARIZE_SELECTED_TEXT, mojom::ConversationTurnVisibility::VISIBLE, "Which show is this catchphrase from?", "I have spoken.", std::nullopt, base::Time::Now(), std::nullopt, false)); history.push_back(mojom::ConversationTurn::New( - mojom::CharacterType::ASSISTANT, mojom::ActionType::RESPONSE, - mojom::ConversationTurnVisibility::VISIBLE, "The Mandalorian.", - std::nullopt, std::nullopt, base::Time::Now(), std::nullopt, false)); + std::nullopt, mojom::CharacterType::ASSISTANT, + mojom::ActionType::RESPONSE, mojom::ConversationTurnVisibility::VISIBLE, + "The Mandalorian.", std::nullopt, std::nullopt, base::Time::Now(), + std::nullopt, false)); auto* mock_remote_completion_client = GetMockRemoteCompletionClient(); std::string prompt_before_time_and_date = "\n\nHuman: Here is the text of a web page in tags:\n\nThis " @@ -236,10 +238,10 @@ TEST_F(EngineConsumerClaudeUnitTest, TestGenerateAssistantResponse) { // Test with page content refine event. { mojom::ConversationTurnPtr entry = mojom::ConversationTurn::New( - mojom::CharacterType::ASSISTANT, mojom::ActionType::RESPONSE, - mojom::ConversationTurnVisibility::VISIBLE, "", std::nullopt, - std::vector{}, base::Time::Now(), - std::nullopt, false); + std::nullopt, mojom::CharacterType::ASSISTANT, + mojom::ActionType::RESPONSE, mojom::ConversationTurnVisibility::VISIBLE, + "", std::nullopt, std::vector{}, + base::Time::Now(), std::nullopt, false); entry->events->push_back( mojom::ConversationEntryEvent::NewPageContentRefineEvent( mojom::PageContentRefineEvent::New())); @@ -276,10 +278,10 @@ TEST_F(EngineConsumerClaudeUnitTest, GenerateAssistantResponseEarlyReturn) { testing::Mock::VerifyAndClearExpectations(mock_remote_completion_client); mojom::ConversationTurnPtr entry = mojom::ConversationTurn::New( - mojom::CharacterType::ASSISTANT, mojom::ActionType::RESPONSE, - mojom::ConversationTurnVisibility::VISIBLE, "", std::nullopt, - std::vector{}, base::Time::Now(), - std::nullopt, false); + std::nullopt, mojom::CharacterType::ASSISTANT, + mojom::ActionType::RESPONSE, mojom::ConversationTurnVisibility::VISIBLE, + "", std::nullopt, std::vector{}, + base::Time::Now(), std::nullopt, false); entry->events->push_back(mojom::ConversationEntryEvent::NewCompletionEvent( mojom::CompletionEvent::New("Me"))); history.push_back(std::move(entry)); diff --git a/components/ai_chat/core/browser/engine/engine_consumer_conversation_api_unittest.cc b/components/ai_chat/core/browser/engine/engine_consumer_conversation_api_unittest.cc index fd74e5f5b108..61b20f86c7b1 100644 --- a/components/ai_chat/core/browser/engine/engine_consumer_conversation_api_unittest.cc +++ b/components/ai_chat/core/browser/engine/engine_consumer_conversation_api_unittest.cc @@ -211,16 +211,17 @@ TEST_F(EngineConsumerConversationAPIUnitTest, // selected text but with page association. EngineConsumer::ConversationHistory history; history.push_back(mojom::ConversationTurn::New( - mojom::CharacterType::HUMAN, mojom::ActionType::QUERY, + std::nullopt, mojom::CharacterType::HUMAN, mojom::ActionType::QUERY, mojom::ConversationTurnVisibility::VISIBLE, "Which show is this catchphrase from?", "I have spoken.", std::nullopt, base::Time::Now(), std::nullopt, false)); history.push_back(mojom::ConversationTurn::New( - mojom::CharacterType::ASSISTANT, mojom::ActionType::RESPONSE, - mojom::ConversationTurnVisibility::VISIBLE, "The Mandalorian.", - std::nullopt, std::nullopt, base::Time::Now(), std::nullopt, false)); + std::nullopt, mojom::CharacterType::ASSISTANT, + mojom::ActionType::RESPONSE, mojom::ConversationTurnVisibility::VISIBLE, + "The Mandalorian.", std::nullopt, std::nullopt, base::Time::Now(), + std::nullopt, false)); history.push_back(mojom::ConversationTurn::New( - mojom::CharacterType::HUMAN, mojom::ActionType::RESPONSE, + std::nullopt, mojom::CharacterType::HUMAN, mojom::ActionType::RESPONSE, mojom::ConversationTurnVisibility::VISIBLE, "Is it related to a broader series?", std::nullopt, std::nullopt, base::Time::Now(), std::nullopt, false)); @@ -288,7 +289,7 @@ TEST_F(EngineConsumerConversationAPIUnitTest, GenerateEvents_ModifyReply) { // Tests events building from history with modified agent reply. EngineConsumer::ConversationHistory history; history.push_back(mojom::ConversationTurn::New( - mojom::CharacterType::HUMAN, mojom::ActionType::QUERY, + std::nullopt, mojom::CharacterType::HUMAN, mojom::ActionType::QUERY, mojom::ConversationTurnVisibility::VISIBLE, "Which show is 'This is the way' from?", std::nullopt, std::nullopt, base::Time::Now(), std::nullopt, false)); @@ -309,18 +310,19 @@ TEST_F(EngineConsumerConversationAPIUnitTest, GenerateEvents_ModifyReply) { modified_events.push_back(modified_completion_event.Clone()); auto edit = mojom::ConversationTurn::New( - mojom::CharacterType::ASSISTANT, mojom::ActionType::RESPONSE, - mojom::ConversationTurnVisibility::VISIBLE, "The Mandalorian.", - std::nullopt, std::move(modified_events), base::Time::Now(), std::nullopt, - false); + std::nullopt, mojom::CharacterType::ASSISTANT, + mojom::ActionType::RESPONSE, mojom::ConversationTurnVisibility::VISIBLE, + "The Mandalorian.", std::nullopt, std::move(modified_events), + base::Time::Now(), std::nullopt, false); std::vector edits; edits.push_back(std::move(edit)); history.push_back(mojom::ConversationTurn::New( - mojom::CharacterType::ASSISTANT, mojom::ActionType::RESPONSE, - mojom::ConversationTurnVisibility::VISIBLE, "Mandalorian.", std::nullopt, - std::move(events), base::Time::Now(), std::move(edits), false)); + std::nullopt, mojom::CharacterType::ASSISTANT, + mojom::ActionType::RESPONSE, mojom::ConversationTurnVisibility::VISIBLE, + "Mandalorian.", std::nullopt, std::move(events), base::Time::Now(), + std::move(edits), false)); history.push_back(mojom::ConversationTurn::New( - mojom::CharacterType::HUMAN, mojom::ActionType::QUERY, + std::nullopt, mojom::CharacterType::HUMAN, mojom::ActionType::QUERY, mojom::ConversationTurnVisibility::VISIBLE, "Is it related to a broader series?", std::nullopt, std::nullopt, base::Time::Now(), std::nullopt, false)); @@ -377,10 +379,10 @@ TEST_F(EngineConsumerConversationAPIUnitTest, history.push_back(std::move(turn)); mojom::ConversationTurnPtr entry = mojom::ConversationTurn::New( - mojom::CharacterType::ASSISTANT, mojom::ActionType::RESPONSE, - mojom::ConversationTurnVisibility::VISIBLE, "", std::nullopt, - std::vector{}, base::Time::Now(), - std::nullopt, false); + std::nullopt, mojom::CharacterType::ASSISTANT, + mojom::ActionType::RESPONSE, mojom::ConversationTurnVisibility::VISIBLE, + "", std::nullopt, std::vector{}, + base::Time::Now(), std::nullopt, false); entry->events->push_back( mojom::ConversationEntryEvent::NewPageContentRefineEvent( mojom::PageContentRefineEvent::New())); @@ -409,10 +411,10 @@ TEST_F(EngineConsumerConversationAPIUnitTest, GenerateEvents_EarlyReturn) { testing::Mock::VerifyAndClearExpectations(mock_api_client); mojom::ConversationTurnPtr entry = mojom::ConversationTurn::New( - mojom::CharacterType::ASSISTANT, mojom::ActionType::RESPONSE, - mojom::ConversationTurnVisibility::VISIBLE, "", std::nullopt, - std::vector{}, base::Time::Now(), - std::nullopt, false); + std::nullopt, mojom::CharacterType::ASSISTANT, + mojom::ActionType::RESPONSE, mojom::ConversationTurnVisibility::VISIBLE, + "", std::nullopt, std::vector{}, + base::Time::Now(), std::nullopt, false); entry->events->push_back(mojom::ConversationEntryEvent::NewCompletionEvent( mojom::CompletionEvent::New("Me"))); history.push_back(std::move(entry)); diff --git a/components/ai_chat/core/browser/engine/engine_consumer_llama_unittest.cc b/components/ai_chat/core/browser/engine/engine_consumer_llama_unittest.cc index 180b40fe8f05..7ac1d2e1d17b 100644 --- a/components/ai_chat/core/browser/engine/engine_consumer_llama_unittest.cc +++ b/components/ai_chat/core/browser/engine/engine_consumer_llama_unittest.cc @@ -69,14 +69,16 @@ class EngineConsumerLlamaUnitTest : public testing::Test { TEST_F(EngineConsumerLlamaUnitTest, TestGenerateAssistantResponse) { EngineConsumer::ConversationHistory history; history.push_back(mojom::ConversationTurn::New( - mojom::CharacterType::HUMAN, mojom::ActionType::SUMMARIZE_SELECTED_TEXT, + std::nullopt, mojom::CharacterType::HUMAN, + mojom::ActionType::SUMMARIZE_SELECTED_TEXT, mojom::ConversationTurnVisibility::VISIBLE, "Which show is this catchphrase from?", "This is the way.", std::nullopt, base::Time::Now(), std::nullopt, false)); history.push_back(mojom::ConversationTurn::New( - mojom::CharacterType::ASSISTANT, mojom::ActionType::RESPONSE, - mojom::ConversationTurnVisibility::VISIBLE, "The Mandalorian.", - std::nullopt, std::nullopt, base::Time::Now(), std::nullopt, false)); + std::nullopt, mojom::CharacterType::ASSISTANT, + mojom::ActionType::RESPONSE, mojom::ConversationTurnVisibility::VISIBLE, + "The Mandalorian.", std::nullopt, std::nullopt, base::Time::Now(), + std::nullopt, false)); auto* mock_remote_completion_client = static_cast(engine_->GetAPIForTesting()); std::string prompt_before_time_and_date = @@ -239,10 +241,10 @@ TEST_F(EngineConsumerLlamaUnitTest, TestGenerateAssistantResponse) { // Test with page content refine event. { mojom::ConversationTurnPtr entry = mojom::ConversationTurn::New( - mojom::CharacterType::ASSISTANT, mojom::ActionType::RESPONSE, - mojom::ConversationTurnVisibility::VISIBLE, "", std::nullopt, - std::vector{}, base::Time::Now(), - std::nullopt, false); + std::nullopt, mojom::CharacterType::ASSISTANT, + mojom::ActionType::RESPONSE, mojom::ConversationTurnVisibility::VISIBLE, + "", std::nullopt, std::vector{}, + base::Time::Now(), std::nullopt, false); entry->events->push_back( mojom::ConversationEntryEvent::NewPageContentRefineEvent( mojom::PageContentRefineEvent::New())); @@ -279,10 +281,10 @@ TEST_F(EngineConsumerLlamaUnitTest, GenerateAssistantResponseEarlyReturn) { testing::Mock::VerifyAndClearExpectations(mock_remote_completion_client); mojom::ConversationTurnPtr entry = mojom::ConversationTurn::New( - mojom::CharacterType::ASSISTANT, mojom::ActionType::RESPONSE, - mojom::ConversationTurnVisibility::VISIBLE, "", std::nullopt, - std::vector{}, base::Time::Now(), - std::nullopt, false); + std::nullopt, mojom::CharacterType::ASSISTANT, + mojom::ActionType::RESPONSE, mojom::ConversationTurnVisibility::VISIBLE, + "", std::nullopt, std::vector{}, + base::Time::Now(), std::nullopt, false); entry->events->push_back(mojom::ConversationEntryEvent::NewCompletionEvent( mojom::CompletionEvent::New("Me"))); history.push_back(std::move(entry)); diff --git a/components/ai_chat/core/browser/engine/engine_consumer_oai_unittest.cc b/components/ai_chat/core/browser/engine/engine_consumer_oai_unittest.cc index 38a2466c8a36..7df2ca03f63c 100644 --- a/components/ai_chat/core/browser/engine/engine_consumer_oai_unittest.cc +++ b/components/ai_chat/core/browser/engine/engine_consumer_oai_unittest.cc @@ -203,14 +203,16 @@ TEST_F(EngineConsumerOAIUnitTest, TestGenerateAssistantResponse) { std::string assistant_input = "This is mandalorian."; history.push_back(mojom::ConversationTurn::New( - mojom::CharacterType::HUMAN, mojom::ActionType::SUMMARIZE_SELECTED_TEXT, + std::nullopt, mojom::CharacterType::HUMAN, + mojom::ActionType::SUMMARIZE_SELECTED_TEXT, mojom::ConversationTurnVisibility::VISIBLE, human_input, selected_text, std::nullopt, base::Time::Now(), std::nullopt, false)); history.push_back(mojom::ConversationTurn::New( - mojom::CharacterType::ASSISTANT, mojom::ActionType::RESPONSE, - mojom::ConversationTurnVisibility::VISIBLE, assistant_input, std::nullopt, - std::nullopt, base::Time::Now(), std::nullopt, false)); + std::nullopt, mojom::CharacterType::ASSISTANT, + mojom::ActionType::RESPONSE, mojom::ConversationTurnVisibility::VISIBLE, + assistant_input, std::nullopt, std::nullopt, base::Time::Now(), + std::nullopt, false)); std::string date_and_time_string = base::UTF16ToUTF8(TimeFormatFriendlyDateAndTime(base::Time::Now())); @@ -306,10 +308,10 @@ TEST_F(EngineConsumerOAIUnitTest, TestGenerateAssistantResponse) { // Test with page content refine event. { mojom::ConversationTurnPtr entry = mojom::ConversationTurn::New( - mojom::CharacterType::ASSISTANT, mojom::ActionType::RESPONSE, - mojom::ConversationTurnVisibility::VISIBLE, "", std::nullopt, - std::vector{}, base::Time::Now(), - std::nullopt, false); + std::nullopt, mojom::CharacterType::ASSISTANT, + mojom::ActionType::RESPONSE, mojom::ConversationTurnVisibility::VISIBLE, + "", std::nullopt, std::vector{}, + base::Time::Now(), std::nullopt, false); entry->events->push_back( mojom::ConversationEntryEvent::NewPageContentRefineEvent( mojom::PageContentRefineEvent::New())); @@ -349,10 +351,10 @@ TEST_F(EngineConsumerOAIUnitTest, GenerateAssistantResponseEarlyReturn) { testing::Mock::VerifyAndClearExpectations(client); mojom::ConversationTurnPtr entry = mojom::ConversationTurn::New( - mojom::CharacterType::ASSISTANT, mojom::ActionType::RESPONSE, - mojom::ConversationTurnVisibility::VISIBLE, "", std::nullopt, - std::vector{}, base::Time::Now(), - std::nullopt, false); + std::nullopt, mojom::CharacterType::ASSISTANT, + mojom::ActionType::RESPONSE, mojom::ConversationTurnVisibility::VISIBLE, + "", std::nullopt, std::vector{}, + base::Time::Now(), std::nullopt, false); entry->events->push_back(mojom::ConversationEntryEvent::NewCompletionEvent( mojom::CompletionEvent::New("Me"))); history.push_back(std::move(entry)); diff --git a/components/ai_chat/core/browser/engine/test_utils.cc b/components/ai_chat/core/browser/engine/test_utils.cc index a1ace67692b0..e9fb2472bcd9 100644 --- a/components/ai_chat/core/browser/engine/test_utils.cc +++ b/components/ai_chat/core/browser/engine/test_utils.cc @@ -16,7 +16,7 @@ namespace ai_chat { std::vector GetHistoryWithModifiedReply() { std::vector history; history.push_back(mojom::ConversationTurn::New( - mojom::CharacterType::HUMAN, mojom::ActionType::QUERY, + std::nullopt, mojom::CharacterType::HUMAN, mojom::ActionType::QUERY, mojom::ConversationTurnVisibility::VISIBLE, "Which show is 'This is the way' from?", std::nullopt, std::nullopt, base::Time::Now(), std::nullopt, false)); @@ -34,18 +34,19 @@ std::vector GetHistoryWithModifiedReply() { mojom::CompletionEvent::New("The Mandalorian"))); auto edit = mojom::ConversationTurn::New( - mojom::CharacterType::ASSISTANT, mojom::ActionType::RESPONSE, - mojom::ConversationTurnVisibility::VISIBLE, "The Mandalorian.", - std::nullopt, std::move(modified_events), base::Time::Now(), std::nullopt, - false); + std::nullopt, mojom::CharacterType::ASSISTANT, + mojom::ActionType::RESPONSE, mojom::ConversationTurnVisibility::VISIBLE, + "The Mandalorian.", std::nullopt, std::move(modified_events), + base::Time::Now(), std::nullopt, false); std::vector edits; edits.push_back(std::move(edit)); history.push_back(mojom::ConversationTurn::New( - mojom::CharacterType::ASSISTANT, mojom::ActionType::RESPONSE, - mojom::ConversationTurnVisibility::VISIBLE, "Mandalorian.", std::nullopt, - std::move(events), base::Time::Now(), std::move(edits), false)); + std::nullopt, mojom::CharacterType::ASSISTANT, + mojom::ActionType::RESPONSE, mojom::ConversationTurnVisibility::VISIBLE, + "Mandalorian.", std::nullopt, std::move(events), base::Time::Now(), + std::move(edits), false)); history.push_back(mojom::ConversationTurn::New( - mojom::CharacterType::HUMAN, mojom::ActionType::QUERY, + std::nullopt, mojom::CharacterType::HUMAN, mojom::ActionType::QUERY, mojom::ConversationTurnVisibility::VISIBLE, "Is it related to a broader series?", std::nullopt, std::nullopt, base::Time::Now(), std::nullopt, false)); diff --git a/components/ai_chat/core/browser/mock_conversation_handler_observer.cc b/components/ai_chat/core/browser/mock_conversation_handler_observer.cc new file mode 100644 index 000000000000..2b0577906eea --- /dev/null +++ b/components/ai_chat/core/browser/mock_conversation_handler_observer.cc @@ -0,0 +1,18 @@ +// Copyright (c) 2024 The Brave Authors. All rights reserved. +// This Source Code Form is subject to the terms of the Mozilla Public +// License, v. 2.0. If a copy of the MPL was not distributed with this file, +// You can obtain one at https://mozilla.org/MPL/2.0/. + +#include "brave/components/ai_chat/core/browser/mock_conversation_handler_observer.h" + +namespace ai_chat { + +MockConversationHandlerObserver::MockConversationHandlerObserver() = default; +MockConversationHandlerObserver::~MockConversationHandlerObserver() = default; + +void MockConversationHandlerObserver::Observe( + ConversationHandler* conversation) { + conversation_observations_.AddObservation(conversation); +} + +} // namespace ai_chat diff --git a/components/ai_chat/core/browser/mock_conversation_handler_observer.h b/components/ai_chat/core/browser/mock_conversation_handler_observer.h new file mode 100644 index 000000000000..b944d74382f3 --- /dev/null +++ b/components/ai_chat/core/browser/mock_conversation_handler_observer.h @@ -0,0 +1,66 @@ +// Copyright (c) 2024 The Brave Authors. All rights reserved. +// This Source Code Form is subject to the terms of the Mozilla Public +// License, v. 2.0. If a copy of the MPL was not distributed with this file, +// You can obtain one at https://mozilla.org/MPL/2.0/. + +#ifndef BRAVE_COMPONENTS_AI_CHAT_CORE_BROWSER_MOCK_CONVERSATION_HANDLER_OBSERVER_H_ +#define BRAVE_COMPONENTS_AI_CHAT_CORE_BROWSER_MOCK_CONVERSATION_HANDLER_OBSERVER_H_ + +#include +#include + +#include "base/scoped_multi_source_observation.h" +#include "brave/components/ai_chat/core/browser/conversation_handler.h" +#include "brave/components/ai_chat/core/common/mojom/ai_chat.mojom.h" +#include "testing/gmock/include/gmock/gmock.h" + +namespace ai_chat { + +class MockConversationHandlerObserver : public ConversationHandler::Observer { + public: + MockConversationHandlerObserver(); + ~MockConversationHandlerObserver() override; + + void Observe(ConversationHandler* conversation); + + MOCK_METHOD(void, + OnRequestInProgressChanged, + (ConversationHandler * handler, bool in_progress), + (override)); + + MOCK_METHOD(void, + OnConversationEntryAdded, + (ConversationHandler * handler, + mojom::ConversationTurnPtr& entry, + std::optional associated_content_value), + (override)); + + MOCK_METHOD(void, + OnConversationEntryRemoved, + (ConversationHandler * handler, std::string turn_uuid), + (override)); + + MOCK_METHOD(void, + OnConversationEntryUpdated, + (ConversationHandler * handler, mojom::ConversationTurnPtr entry), + (override)); + + MOCK_METHOD(void, + OnClientConnectionChanged, + (ConversationHandler*), + (override)); + + MOCK_METHOD(void, + OnConversationTitleChanged, + (ConversationHandler * handler, std::string title), + (override)); + + private: + base::ScopedMultiSourceObservation + conversation_observations_{this}; +}; + +} // namespace ai_chat + +#endif // BRAVE_COMPONENTS_AI_CHAT_CORE_BROWSER_MOCK_CONVERSATION_HANDLER_OBSERVER_H_ diff --git a/components/ai_chat/core/browser/test_utils.cc b/components/ai_chat/core/browser/test_utils.cc new file mode 100644 index 000000000000..5480bf3ec9a3 --- /dev/null +++ b/components/ai_chat/core/browser/test_utils.cc @@ -0,0 +1,227 @@ +// Copyright (c) 2024 The Brave Authors. All rights reserved. +// This Source Code Form is subject to the terms of the Mozilla Public +// License, v. 2.0. If a copy of the MPL was not distributed with this file, +// You can obtain one at https://mozilla.org/MPL/2.0/. + +#include "brave/components/ai_chat/core/browser/test_utils.h" + +#include +#include + +#include "base/strings/strcat.h" +#include "base/time/time.h" +#include "base/uuid.h" +#include "brave/components/ai_chat/core/common/mojom/ai_chat.mojom-forward.h" +#include "brave/components/ai_chat/core/common/mojom/ai_chat.mojom.h" +#include "testing/gtest/include/gtest/gtest.h" + +namespace ai_chat { + +namespace { + +std::string MessageConversationEntryEvents( + const mojom::ConversationTurnPtr& entry) { + std::string message = "Entry has the following events:"; + if (!entry->events.has_value()) { + message = base::StrCat({message, "\nNo events"}); + return message; + } + for (const auto& event : entry->events.value()) { + switch (event->which()) { + case mojom::ConversationEntryEvent::Tag::kCompletionEvent: { + message = base::StrCat({message, "\n - completion: ", + event->get_completion_event()->completion}); + break; + } + case mojom::ConversationEntryEvent::Tag::kSearchQueriesEvent: { + message = base::StrCat({message, "\n - search event"}); + break; + } + case mojom::ConversationEntryEvent::Tag::kConversationTitleEvent: { + message = base::StrCat({message, "\n - title: ", + event->get_conversation_title_event()->title}); + break; + } + case mojom::ConversationEntryEvent::Tag::kPageContentRefineEvent: { + message = base::StrCat({message, "\n - content refine event"}); + break; + } + default: + message = base::StrCat({message, "\n - unknown event"}); + } + } + return message; +} + +} // namespace + +void ExpectConversationEquals(base::Location location, + const mojom::ConversationPtr& a, + const mojom::ConversationPtr& b) { + SCOPED_TRACE(testing::Message() << location.ToString()); + if (!a || !b) { + EXPECT_EQ(a, b); // Both should be null or neither + return; + } + EXPECT_EQ(a->uuid, b->uuid); + EXPECT_EQ(a->title, b->title); + EXPECT_EQ(a->has_content, b->has_content); + + // associated content + ExpectAssociatedContentEquals(FROM_HERE, a->associated_content, + b->associated_content); +} + +void ExpectAssociatedContentEquals(base::Location location, + const mojom::SiteInfoPtr& a, + const mojom::SiteInfoPtr& b) { + SCOPED_TRACE(testing::Message() << location.ToString()); + if (!a || !b) { + EXPECT_EQ(a, b); // Both should be null or neither + return; + } + EXPECT_EQ(a->title, b->title); + EXPECT_EQ(a->url, b->url); + EXPECT_EQ(a->content_type, b->content_type); + EXPECT_EQ(a->content_used_percentage, b->content_used_percentage); + EXPECT_EQ(a->is_content_refined, b->is_content_refined); + EXPECT_EQ(a->is_content_association_possible, + b->is_content_association_possible); +} + +void ExpectConversationHistoryEquals( + base::Location location, + const std::vector& a, + const std::vector& b, + bool compare_uuid) { + SCOPED_TRACE(testing::Message() << location.ToString()); + EXPECT_EQ(a.size(), b.size()); + for (auto i = 0u; i < a.size(); i++) { + SCOPED_TRACE(testing::Message() << "Comparing entries at index " << i); + ExpectConversationEntryEquals(FROM_HERE, a.at(i), b.at(i), compare_uuid); + } +} + +void ExpectConversationEntryEquals(base::Location location, + const mojom::ConversationTurnPtr& a, + const mojom::ConversationTurnPtr& b, + bool compare_uuid) { + SCOPED_TRACE(testing::Message() << location.ToString()); + if (!a || !b) { + EXPECT_EQ(a, b); // Both should be null or neither + return; + } + + if (compare_uuid) { + EXPECT_EQ(a->uuid.value_or("default"), b->uuid.value_or("default")); + } + + EXPECT_EQ(a->action_type, b->action_type); + EXPECT_EQ(a->character_type, b->character_type); + EXPECT_EQ(a->selected_text, b->selected_text); + EXPECT_EQ(a->text, b->text); + EXPECT_EQ(a->visibility, b->visibility); + + // compare events + EXPECT_EQ(a->events.has_value(), b->events.has_value()); + if (a->events.has_value()) { + EXPECT_EQ(a->events->size(), b->events->size()) + << "\nEvents for a. " << MessageConversationEntryEvents(a) + << "\nEvents for b. " << MessageConversationEntryEvents(b); + for (auto i = 0u; i < a->events->size(); i++) { + SCOPED_TRACE(testing::Message() << "Comparing events at index " << i); + auto& a_event = a->events->at(i); + auto& b_event = b->events->at(i); + EXPECT_EQ(a_event->which(), b_event->which()); + switch (a_event->which()) { + case mojom::ConversationEntryEvent::Tag::kCompletionEvent: { + EXPECT_EQ(a_event->get_completion_event()->completion, + b_event->get_completion_event()->completion); + break; + } + case mojom::ConversationEntryEvent::Tag::kSearchQueriesEvent: { + EXPECT_EQ(a_event->get_search_queries_event()->search_queries, + b_event->get_search_queries_event()->search_queries); + break; + } + default: + NOTREACHED() + << "Unexpected event type for comparison. Only know about " + "event types which are not discarded."; + } + } + } + + // compare edits + EXPECT_EQ(a->edits.has_value(), b->edits.has_value()); + if (a->edits.has_value()) { + EXPECT_EQ(a->edits->size(), b->edits->size()); + for (auto i = 0u; i < a->edits->size(); i++) { + SCOPED_TRACE(testing::Message() << "Comparing edits at index " << i); + auto& a_edit = a->edits->at(i); + auto& b_edit = b->edits->at(i); + ExpectConversationEntryEquals(FROM_HERE, a_edit, b_edit, compare_uuid); + } + } +} + +mojom::Conversation* GetConversation( + base::Location location, + const std::vector& conversations, + std::string uuid) { + SCOPED_TRACE(testing::Message() << location.ToString()); + auto it = std::find_if(conversations.begin(), conversations.end(), + [&uuid](const mojom::ConversationPtr& conversation) { + return conversation->uuid == uuid; + }); + EXPECT_NE(it, conversations.end()); + return it->get(); +} + +std::vector CreateSampleChatHistory( + size_t num_query_pairs, + uint32_t future_hours) { + std::vector history; + base::Time now = base::Time::Now(); + for (size_t i = 0; i < num_query_pairs; i++) { + // query + history.push_back(mojom::ConversationTurn::New( + base::Uuid::GenerateRandomV4().AsLowercaseString(), + mojom::CharacterType::HUMAN, mojom::ActionType::QUERY, + mojom::ConversationTurnVisibility::VISIBLE, + base::StrCat({"query", base::NumberToString(i)}), std::nullopt, + std::nullopt, now + base::Seconds(i * 60) + base::Hours(future_hours), + std::nullopt, false)); + // response + std::vector events; + events.emplace_back(mojom::ConversationEntryEvent::NewCompletionEvent( + mojom::CompletionEvent::New(base::StrCat( + {"This is a generated response ", base::NumberToString(i)})))); + events.emplace_back(mojom::ConversationEntryEvent::NewCompletionEvent( + mojom::CompletionEvent::New(base::StrCat( + {"and this is more response", base::NumberToString(i)})))); + events.emplace_back(mojom::ConversationEntryEvent::NewSearchQueriesEvent( + mojom::SearchQueriesEvent::New(std::vector{ + base::StrCat({"Something to search for", base::NumberToString(i)}), + base::StrCat({"Another search query", base::NumberToString(i)})}))); + history.push_back(mojom::ConversationTurn::New( + base::Uuid::GenerateRandomV4().AsLowercaseString(), + mojom::CharacterType::ASSISTANT, mojom::ActionType::RESPONSE, + mojom::ConversationTurnVisibility::VISIBLE, "", std::nullopt, + std::move(events), + now + base::Seconds((i * 60) + 30) + base::Hours(future_hours), + std::nullopt, false)); + } + return history; +} + +std::vector CloneHistory( + std::vector& history) { + std::vector cloned_history; + for (const auto& turn : history) { + cloned_history.push_back(turn->Clone()); + } + return cloned_history; +} + +} // namespace ai_chat diff --git a/components/ai_chat/core/browser/test_utils.h b/components/ai_chat/core/browser/test_utils.h new file mode 100644 index 000000000000..3126def1fb78 --- /dev/null +++ b/components/ai_chat/core/browser/test_utils.h @@ -0,0 +1,50 @@ +// Copyright (c) 2024 The Brave Authors. All rights reserved. +// This Source Code Form is subject to the terms of the Mozilla Public +// License, v. 2.0. If a copy of the MPL was not distributed with this file, +// You can obtain one at https://mozilla.org/MPL/2.0/. + +#ifndef BRAVE_COMPONENTS_AI_CHAT_CORE_BROWSER_TEST_UTILS_H_ +#define BRAVE_COMPONENTS_AI_CHAT_CORE_BROWSER_TEST_UTILS_H_ + +#include +#include + +#include "base/location.h" +#include "brave/components/ai_chat/core/common/mojom/ai_chat.mojom-forward.h" + +namespace ai_chat { + +void ExpectConversationEquals(base::Location location, + const mojom::ConversationPtr& a, + const mojom::ConversationPtr& b); + +void ExpectAssociatedContentEquals(base::Location location, + const mojom::SiteInfoPtr& a, + const mojom::SiteInfoPtr& b); + +void ExpectConversationEntryEquals(base::Location location, + const mojom::ConversationTurnPtr& a, + const mojom::ConversationTurnPtr& b, + bool compare_uuid = true); + +void ExpectConversationHistoryEquals( + base::Location location, + const std::vector& a, + const std::vector& b, + bool compare_uuid = true); + +mojom::Conversation* GetConversation( + base::Location location, + const std::vector& conversations, + std::string uuid); + +std::vector CreateSampleChatHistory( + size_t num_query_pairs, + uint32_t future_hours = 0); + +std::vector CloneHistory( + std::vector& history); + +} // namespace ai_chat + +#endif // BRAVE_COMPONENTS_AI_CHAT_CORE_BROWSER_TEST_UTILS_H_ diff --git a/components/ai_chat/core/common/mojom/ai_chat.mojom b/components/ai_chat/core/common/mojom/ai_chat.mojom index 9126e17d0a7c..3ce48e5718a3 100644 --- a/components/ai_chat/core/common/mojom/ai_chat.mojom +++ b/components/ai_chat/core/common/mojom/ai_chat.mojom @@ -70,19 +70,57 @@ enum SuggestionGenerationStatus { HasGenerated, }; +// Type of content that is extracted +enum ContentType { + PageContent, + VideoTranscript, +}; + +// A piece of content associated with a conversation. +// TODO(petemill): Rename to AssociatedContent and have hostname/url/title be +// part of a detail property (which would have different structure possibilities +// depending on content type. struct SiteInfo { - // Title is present if the site has associated page content + int32? id; + ContentType content_type; + // Web page specific fields, if available and allowed string? title; - // Indicates if the URL scheme is of a type that allows for content association. - bool is_content_association_possible; - // Current tab's url and hostname, if http(s) string? hostname; url.mojom.Url? url; + // Percentage of the content that has been utilized by remote LLM (0-100) int32 content_used_percentage; // Indicates content has been refined through extracting most relevant parts // from long page content bool is_content_refined; + + // Indicates if the URL scheme is of a type that allows for content association. + // TODO(petemill): SiteInfo should just not exist if content-associated is + // not possible. + bool is_content_association_possible; +}; + +struct Conversation { + string uuid; + // Set by the LLM or the user + string title; + // First few characters of content, used if title isn't set + // or where more context is needed. + mojo_base.mojom.Time updated_time; + // If there are entries and the conversation should be selectable + bool has_content; + + SiteInfo associated_content; +}; + +struct ContentArchive { + int32 content_id; + string content; +}; + +struct ConversationArchive { + array entries; + array associated_content; }; enum ActionType { @@ -149,6 +187,9 @@ union ConversationEntryEvent { // The selected_text attribute contains what user selects in the page when // calling from the context menu. struct ConversationTurn { + // Populated if owned by a conversation + string? uuid; + CharacterType character_type; ActionType action_type; ConversationTurnVisibility visibility; @@ -240,18 +281,6 @@ struct ActionGroup { array entries; }; -struct Conversation { - string uuid; - // Set by the LLM or the user - string title; - // First few characters of content, used if title isn't set - // or where more context is needed. - string summary; - mojo_base.mojom.Time created_time; - // If there are entries and the conversation should be selectable - bool has_content; -}; - interface Service { // User opts-in to the feature at a profile level MarkAgreementAccepted(); @@ -437,28 +466,3 @@ interface ConversationUI { interface IAPSubscription { GetPurchaseTokenOrderId() => (string token, string order_id); }; - -// ids below are generated by sqlite -struct Conversation { - int64 id; - mojo_base.mojom.Time date; - string title; - url.mojom.Url? page_url; - string? page_contents; -}; - -struct ConversationEntry { - int64 id; - mojo_base.mojom.Time date; - CharacterType character_type; - ActionType action_type; - string? selected_text; - array texts; - array? events; -}; - -struct ConversationEntryText { - int64 id; - mojo_base.mojom.Time date; - string text; -}; diff --git a/components/ai_chat/resources/page/components/header/index.tsx b/components/ai_chat/resources/page/components/header/index.tsx index 6a97e5763251..778ad603b7f4 100644 --- a/components/ai_chat/resources/page/components/header/index.tsx +++ b/components/ai_chat/resources/page/components/header/index.tsx @@ -23,7 +23,6 @@ const Logo = ({ isPremium }: { isPremium: boolean }) =>
const getTitle = (activeConversation?: Conversation) => activeConversation?.title - || activeConversation?.summary || getLocale('conversationListUntitled') diff --git a/components/ai_chat/resources/page/components/sidebar_nav/index.tsx b/components/ai_chat/resources/page/components/sidebar_nav/index.tsx index 198b360325b6..63490436472d 100644 --- a/components/ai_chat/resources/page/components/sidebar_nav/index.tsx +++ b/components/ai_chat/resources/page/components/sidebar_nav/index.tsx @@ -168,7 +168,7 @@ export default function SidebarNav(props: SidebarNavProps) {
) : ( aiChatContext.setEditingConversationId(item.uuid)} onDelete={() => getAPI().Service.deleteConversation(item.uuid)} diff --git a/components/ai_chat/resources/page/stories/components_panel.tsx b/components/ai_chat/resources/page/stories/components_panel.tsx index 0e4bbd6eda08..ec97c1c84404 100644 --- a/components/ai_chat/resources/page/stories/components_panel.tsx +++ b/components/ai_chat/resources/page/stories/components_panel.tsx @@ -63,6 +63,7 @@ function getPageContentRefineEvent(): mojom.ConversationEntryEvent { const HISTORY: mojom.ConversationTurn[] = [ { + uuid: undefined, text: 'Summarize this page', characterType: mojom.CharacterType.HUMAN, visibility: mojom.ConversationTurnVisibility.VISIBLE, @@ -74,6 +75,7 @@ const HISTORY: mojom.ConversationTurn[] = [ fromBraveSearchSERP: false }, { + uuid: undefined, text: '', characterType: mojom.CharacterType.ASSISTANT, visibility: mojom.ConversationTurnVisibility.VISIBLE, @@ -85,6 +87,7 @@ const HISTORY: mojom.ConversationTurn[] = [ fromBraveSearchSERP: false }, { + uuid: undefined, text: 'What is pointer compression?\n...and how does it work?\n - tell me something interesting', characterType: mojom.CharacterType.HUMAN, actionType: mojom.ActionType.QUERY, @@ -96,6 +99,7 @@ const HISTORY: mojom.ConversationTurn[] = [ fromBraveSearchSERP: false }, { + uuid: undefined, text: '', characterType: mojom.CharacterType.ASSISTANT, actionType: mojom.ActionType.UNSPECIFIED, @@ -107,6 +111,7 @@ const HISTORY: mojom.ConversationTurn[] = [ fromBraveSearchSERP: false }, { + uuid: undefined, text: '', characterType: mojom.CharacterType.ASSISTANT, actionType: mojom.ActionType.UNSPECIFIED, @@ -118,6 +123,7 @@ const HISTORY: mojom.ConversationTurn[] = [ fromBraveSearchSERP: false }, { + uuid: undefined, text: 'What is taylor series?', characterType: mojom.CharacterType.HUMAN, actionType: mojom.ActionType.QUERY, @@ -129,6 +135,7 @@ const HISTORY: mojom.ConversationTurn[] = [ fromBraveSearchSERP: false }, { + uuid: undefined, text: '', characterType: mojom.CharacterType.ASSISTANT, actionType: mojom.ActionType.UNSPECIFIED, @@ -140,6 +147,7 @@ const HISTORY: mojom.ConversationTurn[] = [ fromBraveSearchSERP: false }, { + uuid: undefined, text: 'Write a hello world program in c++', characterType: mojom.CharacterType.HUMAN, actionType: mojom.ActionType.QUERY, @@ -151,6 +159,7 @@ const HISTORY: mojom.ConversationTurn[] = [ fromBraveSearchSERP: false }, { + uuid: undefined, text: '', characterType: mojom.CharacterType.ASSISTANT, actionType: mojom.ActionType.UNSPECIFIED, @@ -162,6 +171,7 @@ const HISTORY: mojom.ConversationTurn[] = [ fromBraveSearchSERP: false }, { + uuid: undefined, text: 'Summarize this excerpt', characterType: mojom.CharacterType.HUMAN, actionType: mojom.ActionType.SUMMARIZE_SELECTED_TEXT, @@ -173,6 +183,7 @@ const HISTORY: mojom.ConversationTurn[] = [ fromBraveSearchSERP: false }, { + uuid: undefined, text: '', characterType: mojom.CharacterType.ASSISTANT, actionType: mojom.ActionType.UNSPECIFIED, @@ -184,6 +195,7 @@ const HISTORY: mojom.ConversationTurn[] = [ fromBraveSearchSERP: false }, { + uuid: undefined, text: 'Shorten this selected text', characterType: mojom.CharacterType.HUMAN, actionType: mojom.ActionType.SHORTEN, @@ -195,6 +207,7 @@ const HISTORY: mojom.ConversationTurn[] = [ fromBraveSearchSERP: false }, { + uuid: undefined, text: '', characterType: mojom.CharacterType.ASSISTANT, actionType: mojom.ActionType.UNSPECIFIED, @@ -206,12 +219,14 @@ const HISTORY: mojom.ConversationTurn[] = [ fromBraveSearchSERP: false }, { + uuid: undefined, text: 'Will an LTT store backpack fit in a Tesla Model Y frunk?', characterType: mojom.CharacterType.HUMAN, actionType: mojom.ActionType.SHORTEN, visibility: mojom.ConversationTurnVisibility.VISIBLE, selectedText: '', edits: [{ + uuid: undefined, text: 'Will it fit in a Tesla Model Y frunk?', characterType: mojom.CharacterType.HUMAN, actionType: mojom.ActionType.SHORTEN, @@ -227,6 +242,7 @@ const HISTORY: mojom.ConversationTurn[] = [ fromBraveSearchSERP: false }, { + uuid: undefined, text: '', characterType: mojom.CharacterType.ASSISTANT, actionType: mojom.ActionType.UNSPECIFIED, @@ -310,6 +326,8 @@ const SAMPLE_QUESTIONS = [ ] const SITE_INFO: mojom.SiteInfo = { + id: undefined, + contentType: mojom.ContentType.PageContent, title: 'Tiny Tweaks to Neurons Can Rewire Animal Motion', contentUsedPercentage: 40, isContentAssociationPossible: true, diff --git a/ios/browser/api/ai_chat/ai_chat.mm b/ios/browser/api/ai_chat/ai_chat.mm index 20c931903b4d..69f39c2481a8 100644 --- a/ios/browser/api/ai_chat/ai_chat.mm +++ b/ios/browser/api/ai_chat/ai_chat.mm @@ -108,7 +108,7 @@ - (void)changeModel:(NSString*)modelKey { - (void)submitHumanConversationEntry:(NSString*)text { current_conversation_->SubmitHumanConversationEntry( ai_chat::mojom::ConversationTurn::New( - ai_chat::mojom::CharacterType::HUMAN, + std::nullopt, ai_chat::mojom::CharacterType::HUMAN, ai_chat::mojom::ActionType::UNSPECIFIED, ai_chat::mojom::ConversationTurnVisibility::VISIBLE, base::SysNSStringToUTF8(text), std::nullopt, std::nullopt, diff --git a/ios/browser/api/ai_chat/ai_chat_service_factory.mm b/ios/browser/api/ai_chat/ai_chat_service_factory.mm index 323e685507d2..9cb981924284 100644 --- a/ios/browser/api/ai_chat/ai_chat_service_factory.mm +++ b/ios/browser/api/ai_chat/ai_chat_service_factory.mm @@ -68,8 +68,10 @@ return std::make_unique( model_service, std::move(credential_manager), user_prefs::UserPrefs::Get(context), ai_chat_metrics_.get(), + GetApplicationContext()->GetOSCryptAsync(), context->GetSharedURLLoaderFactory(), - version_info::GetChannelString(::GetChannel())); + version_info::GetChannelString(::GetChannel()), + ProfileIOS::FromBrowserState(context)->GetStatePath()); } web::BrowserState* AIChatServiceFactory::GetBrowserStateToUse(