Skip to content

Commit

Permalink
fix modifyconversation and some observations in tests
Browse files Browse the repository at this point in the history
  • Loading branch information
petemill committed Oct 10, 2024
1 parent 93443fa commit 7f62f91
Show file tree
Hide file tree
Showing 7 changed files with 182 additions and 79 deletions.
2 changes: 2 additions & 0 deletions components/ai_chat/core/browser/BUILD.gn
Original file line number Diff line number Diff line change
Expand Up @@ -183,6 +183,8 @@ source_set("test_support") {
"engine/mock_engine_consumer.h",
"engine/mock_remote_completion_client.cc",
"engine/mock_remote_completion_client.h",
"mock_conversation_handler_observer.cc",
"mock_conversation_handler_observer.h",
"test_utils.cc",
"test_utils.h",
]
Expand Down
21 changes: 1 addition & 20 deletions components/ai_chat/core/browser/ai_chat_service_unittest.cc
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@
#include "base/time/time.h"
#include "brave/components/ai_chat/core/browser/ai_chat_credential_manager.h"
#include "brave/components/ai_chat/core/browser/conversation_handler.h"
#include "brave/components/ai_chat/core/browser/mock_conversation_handler_observer.h"
#include "brave/components/ai_chat/core/browser/test_utils.h"
#include "brave/components/ai_chat/core/browser/utils.h"
#include "brave/components/ai_chat/core/common/features.h"
Expand All @@ -56,26 +57,6 @@ namespace ai_chat {

namespace {

class MockConversationHandlerObserver : public ConversationHandler::Observer {
public:
MockConversationHandlerObserver() = default;
~MockConversationHandlerObserver() override = default;

void Observe(ConversationHandler* conversation) {
conversation_observations_.AddObservation(conversation);
}

MOCK_METHOD(void,
OnClientConnectionChanged,
(ConversationHandler*),
(override));

private:
base::ScopedMultiSourceObservation<ConversationHandler,
ConversationHandler::Observer>
conversation_observations_{this};
};

class MockAIChatCredentialManager : public AIChatCredentialManager {
public:
using AIChatCredentialManager::AIChatCredentialManager;
Expand Down
26 changes: 8 additions & 18 deletions components/ai_chat/core/browser/conversation_handler.cc
Original file line number Diff line number Diff line change
Expand Up @@ -700,21 +700,23 @@ void ConversationHandler::ModifyConversation(uint32_t turn_index,
if (!turn->edits) {
turn->edits.emplace();
}
turn->edits->emplace_back(std::move(edited_turn));

auto new_turn = std::move(chat_history_.at(turn_index));
// Erase all turns after the edited turn and notify observers
std::vector<std::optional<int32_t>> erased_turn_ids;
base::ranges::transform(
chat_history_.begin() + turn_index, chat_history_.end(),
std::back_inserter(erased_turn_ids),
[](mojom::ConversationTurnPtr& turn) { return turn->id; });
turn->edits->emplace_back(std::move(edited_turn));
auto new_turn = std::move(chat_history_.at(turn_index));
chat_history_.erase(chat_history_.begin() + turn_index, chat_history_.end());

for (auto& id : erased_turn_ids) {
OnConversationEntryRemoved(id.value());
if (id.has_value()) {
OnConversationEntryRemoved(id.value());
}
}

new_turn->id = std::nullopt;
SubmitHumanConversationEntry(std::move(new_turn));
}

Expand Down Expand Up @@ -1330,20 +1332,6 @@ void ConversationHandler::OnHistoryUpdate() {
for (auto& client : conversation_ui_handlers_) {
client->OnConversationHistoryUpdate();
}
OnConversationEntriesChanged();
}

void ConversationHandler::OnConversationEntriesChanged() {
for (auto& observer : observers_) {
// TODO(petemill): only tell observers about complete turns. This is
// expensive to do for every event generated by in-progress turns,
// and consumers likely only need complete ones (e.g. database save).
std::vector<mojom::ConversationTurnPtr> history;
for (const auto& turn : chat_history_) {
history.emplace_back(turn->Clone());
}
observer.OnConversationEntriesChanged(this, std::move(history));
}
}

void ConversationHandler::OnConversationEntryRemoved(
Expand All @@ -1369,6 +1357,8 @@ void ConversationHandler::OnConversationEntryAdded(
associated_content_value =
associated_content_delegate_->GetCachedTextContent();
}
// If this is the first entry that isn't staged, notify about all previous
// staged entries
if (!entry->from_brave_search_SERP &&
base::ranges::all_of(chat_history_,
[&entry](mojom::ConversationTurnPtr& history_entry) {
Expand Down
4 changes: 0 additions & 4 deletions components/ai_chat/core/browser/conversation_handler.h
Original file line number Diff line number Diff line change
Expand Up @@ -131,9 +131,6 @@ class ConversationHandler : public mojom::ConversationHandler,
// Called when the conversation history changess
virtual void OnRequestInProgressChanged(ConversationHandler* handler,
bool in_progress) {}
virtual void OnConversationEntriesChanged(
ConversationHandler* handler,
std::vector<mojom::ConversationTurnPtr> entries) {}
virtual void OnConversationEntryAdded(
ConversationHandler* handler,
mojom::ConversationTurnPtr& entry,
Expand Down Expand Up @@ -335,7 +332,6 @@ class ConversationHandler : public mojom::ConversationHandler,
void OnConversationEntryRemoved(std::optional<int> turn_id);
void OnSuggestedQuestionsChanged();
void OnAssociatedContentInfoChanged();
void OnConversationEntriesChanged();
void OnClientConnectionChanged();
void OnConversationTitleChanged(std::string title);
void OnConversationUIConnectionChanged(mojo::RemoteSetElementId id);
Expand Down
Loading

0 comments on commit 7f62f91

Please sign in to comment.