diff --git a/android/BUILD.gn b/android/BUILD.gn index 8fcd7824dfdb..e8c300338259 100644 --- a/android/BUILD.gn +++ b/android/BUILD.gn @@ -5,7 +5,7 @@ import("//brave/components/ai_chat/core/common/buildflags/buildflags.gni") import("//brave/components/p3a/buildflags.gni") -import("//brave/components/web_discovery/common/buildflags/buildflags.gni") +import("//brave/components/web_discovery/buildflags/buildflags.gni") import("//brave/components/webcompat_reporter/buildflags/buildflags.gni") import("//build/config/android/rules.gni") diff --git a/browser/brave_local_state_prefs.cc b/browser/brave_local_state_prefs.cc index c06b4a8e687a..2bd64b7a2d1c 100644 --- a/browser/brave_local_state_prefs.cc +++ b/browser/brave_local_state_prefs.cc @@ -37,7 +37,7 @@ #include "brave/components/p3a/star_randomness_meta.h" #include "brave/components/skus/browser/skus_utils.h" #include "brave/components/tor/buildflags/buildflags.h" -#include "brave/components/web_discovery/common/buildflags/buildflags.h" +#include "brave/components/web_discovery/buildflags/buildflags.h" #include "build/build_config.h" #include "chrome/common/pref_names.h" #include "components/metrics/metrics_pref_names.h" diff --git a/browser/brave_profile_prefs.cc b/browser/brave_profile_prefs.cc index a87561c03f13..8aa1659e45d4 100644 --- a/browser/brave_profile_prefs.cc +++ b/browser/brave_profile_prefs.cc @@ -50,7 +50,7 @@ #include "brave/components/search_engines/brave_prepopulated_engines.h" #include "brave/components/speedreader/common/buildflags/buildflags.h" #include "brave/components/tor/buildflags/buildflags.h" -#include "brave/components/web_discovery/common/buildflags/buildflags.h" +#include "brave/components/web_discovery/buildflags/buildflags.h" #include "build/build_config.h" #include "chrome/browser/prefetch/pref_names.h" #include "chrome/browser/prefs/session_startup_pref.h" diff --git a/browser/browser_context_keyed_service_factories.cc b/browser/browser_context_keyed_service_factories.cc index dd05c70346b2..b83eb6e3e30a 100644 --- a/browser/browser_context_keyed_service_factories.cc +++ b/browser/browser_context_keyed_service_factories.cc @@ -42,7 +42,7 @@ #include "brave/components/request_otr/common/buildflags/buildflags.h" #include "brave/components/speedreader/common/buildflags/buildflags.h" #include "brave/components/tor/buildflags/buildflags.h" -#include "brave/components/web_discovery/common/buildflags/buildflags.h" +#include "brave/components/web_discovery/buildflags/buildflags.h" #if BUILDFLAG(ENABLE_BRAVE_VPN) #include "brave/browser/brave_vpn/brave_vpn_service_factory.h" diff --git a/browser/extensions/api/settings_private/brave_prefs_util.cc b/browser/extensions/api/settings_private/brave_prefs_util.cc index a5cebc019e78..119639f19050 100644 --- a/browser/extensions/api/settings_private/brave_prefs_util.cc +++ b/browser/extensions/api/settings_private/brave_prefs_util.cc @@ -23,7 +23,7 @@ #include "brave/components/request_otr/common/pref_names.h" #include "brave/components/speedreader/common/buildflags/buildflags.h" #include "brave/components/tor/buildflags/buildflags.h" -#include "brave/components/web_discovery/common/buildflags/buildflags.h" +#include "brave/components/web_discovery/buildflags/buildflags.h" #include "chrome/browser/content_settings/cookie_settings_factory.h" #include "chrome/browser/content_settings/host_content_settings_map_factory.h" #include "chrome/browser/extensions/api/settings_private/prefs_util.h" diff --git a/browser/profiles/brave_profile_manager.cc b/browser/profiles/brave_profile_manager.cc index c545038c542a..fed51ca57853 100644 --- a/browser/profiles/brave_profile_manager.cc +++ b/browser/profiles/brave_profile_manager.cc @@ -26,7 +26,7 @@ #include "brave/components/ntp_background_images/common/pref_names.h" #include "brave/components/request_otr/common/buildflags/buildflags.h" #include "brave/components/tor/buildflags/buildflags.h" -#include "brave/components/web_discovery/common/buildflags/buildflags.h" +#include "brave/components/web_discovery/buildflags/buildflags.h" #include "chrome/browser/browser_process.h" #include "chrome/browser/content_settings/host_content_settings_map_factory.h" #include "chrome/browser/profiles/profile_attributes_entry.h" diff --git a/browser/resources/settings/BUILD.gn b/browser/resources/settings/BUILD.gn index 2905c82c9241..f04eebb60c44 100644 --- a/browser/resources/settings/BUILD.gn +++ b/browser/resources/settings/BUILD.gn @@ -8,7 +8,7 @@ import("//brave/build/config.gni") import("//brave/components/brave_vpn/common/buildflags/buildflags.gni") import("//brave/components/brave_wayback_machine/buildflags/buildflags.gni") import("//brave/components/tor/buildflags/buildflags.gni") -import("//brave/components/web_discovery/common/buildflags/buildflags.gni") +import("//brave/components/web_discovery/buildflags/buildflags.gni") import("//brave/resources/brave_grit.gni") import("//chrome/common/features.gni") import("//extensions/buildflags/buildflags.gni") diff --git a/browser/search_engines/search_engine_tracker.cc b/browser/search_engines/search_engine_tracker.cc index 95246343cb19..693db739a231 100644 --- a/browser/search_engines/search_engine_tracker.cc +++ b/browser/search_engines/search_engine_tracker.cc @@ -248,7 +248,7 @@ void SearchEngineTracker::RecordWebDiscoveryEnabledP3A() { #endif #if BUILDFLAG(ENABLE_WEB_DISCOVERY_NATIVE) if (base::FeatureList::IsEnabled( - web_discovery::features::kWebDiscoveryNative)) { + web_discovery::features::kBraveWebDiscoveryNative)) { enabled = profile_prefs_->GetBoolean(web_discovery::kWebDiscoveryNativeEnabled); } diff --git a/browser/search_engines/search_engine_tracker.h b/browser/search_engines/search_engine_tracker.h index 25cc168e4572..cce08e4f513c 100644 --- a/browser/search_engines/search_engine_tracker.h +++ b/browser/search_engines/search_engine_tracker.h @@ -11,7 +11,7 @@ #include "base/memory/raw_ptr.h" #include "base/scoped_observation.h" #include "brave/components/time_period_storage/weekly_event_storage.h" -#include "brave/components/web_discovery/common/buildflags/buildflags.h" +#include "brave/components/web_discovery/buildflags/buildflags.h" #include "components/keyed_service/content/browser_context_keyed_service_factory.h" #include "components/keyed_service/core/keyed_service.h" #include "components/prefs/pref_change_registrar.h" diff --git a/browser/sources.gni b/browser/sources.gni index ee58cb8034aa..b7fbdf0d1b77 100644 --- a/browser/sources.gni +++ b/browser/sources.gni @@ -52,7 +52,7 @@ import("//brave/components/brave_webtorrent/browser/buildflags/buildflags.gni") import("//brave/components/commander/common/buildflags/buildflags.gni") import("//brave/components/greaselion/browser/buildflags/buildflags.gni") import("//brave/components/tor/buildflags/buildflags.gni") -import("//brave/components/web_discovery/common/buildflags/buildflags.gni") +import("//brave/components/web_discovery/buildflags/buildflags.gni") import("//extensions/buildflags/buildflags.gni") brave_chrome_browser_visibility = [ @@ -215,8 +215,7 @@ brave_chrome_browser_deps = [ "//brave/components/speedreader/common/buildflags", "//brave/components/tor/buildflags", "//brave/components/version_info", - "//brave/components/web_discovery/common", - "//brave/components/web_discovery/common/buildflags", + "//brave/components/web_discovery/buildflags", "//brave/components/webcompat/content/browser", "//brave/components/webcompat/core/common", "//brave/services/network/public/cpp", @@ -365,6 +364,7 @@ if (enable_web_discovery_native) { brave_chrome_browser_deps += [ "//brave/browser/web_discovery", "//brave/components/web_discovery/browser", + "//brave/components/web_discovery/common", ] } diff --git a/browser/ui/BUILD.gn b/browser/ui/BUILD.gn index 36790fea8178..8d233f2a5a60 100644 --- a/browser/ui/BUILD.gn +++ b/browser/ui/BUILD.gn @@ -17,7 +17,7 @@ import("//brave/components/request_otr/common/buildflags/buildflags.gni") import("//brave/components/speedreader/common/buildflags/buildflags.gni") import("//brave/components/text_recognition/common/buildflags/buildflags.gni") import("//brave/components/tor/buildflags/buildflags.gni") -import("//brave/components/web_discovery/common/buildflags/buildflags.gni") +import("//brave/components/web_discovery/buildflags/buildflags.gni") import("//build/config/features.gni") import("//chrome/common/features.gni") import("//components/gcm_driver/config.gni") @@ -835,6 +835,7 @@ source_set("ui") { "//brave/components/tor/buildflags", "//brave/components/url_sanitizer/browser", "//brave/components/vector_icons", + "//brave/components/web_discovery/buildflags", "//brave/components/webui", "//chrome/app:command_ids", "//chrome/app/vector_icons:vector_icons", @@ -903,10 +904,7 @@ source_set("ui") { } if (enable_web_discovery_native) { - deps += [ - "//brave/components/web_discovery/common", - "//brave/components/web_discovery/common/buildflags", - ] + deps += [ "//brave/components/web_discovery/common" ] } if (is_linux) { diff --git a/browser/ui/webui/brave_settings_ui.cc b/browser/ui/webui/brave_settings_ui.cc index 88883cbb3718..de31ce02cfb4 100644 --- a/browser/ui/webui/brave_settings_ui.cc +++ b/browser/ui/webui/brave_settings_ui.cc @@ -45,8 +45,7 @@ #include "brave/components/speedreader/common/buildflags/buildflags.h" #include "brave/components/tor/buildflags/buildflags.h" #include "brave/components/version_info/version_info.h" -#include "brave/components/web_discovery/common/buildflags/buildflags.h" -#include "brave/components/web_discovery/common/features.h" +#include "brave/components/web_discovery/buildflags/buildflags.h" #include "build/build_config.h" #include "chrome/browser/profiles/profile.h" #include "chrome/browser/ui/webui/settings/metrics_reporting_handler.h" @@ -93,6 +92,10 @@ #include "brave/components/playlist/common/features.h" #endif +#if BUILDFLAG(ENABLE_WEB_DISCOVERY_NATIVE) +#include "brave/components/web_discovery/common/features.h" +#endif + using ntp_background_images::ViewCounterServiceFactory; BraveSettingsUI::BraveSettingsUI(content::WebUI* web_ui) : SettingsUI(web_ui) { @@ -195,9 +198,10 @@ void BraveSettingsUI::AddResources(content::WebUIDataSource* html_source, html_source->AddBoolean("enable_extensions", BUILDFLAG(ENABLE_EXTENSIONS)); #if BUILDFLAG(ENABLE_WEB_DISCOVERY_NATIVE) - html_source->AddBoolean("isWebDiscoveryNativeEnabled", - base::FeatureList::IsEnabled( - web_discovery::features::kWebDiscoveryNative)); + html_source->AddBoolean( + "isWebDiscoveryNativeEnabled", + base::FeatureList::IsEnabled( + web_discovery::features::kBraveWebDiscoveryNative)); #endif html_source->AddBoolean("extensionsManifestV2Feature", diff --git a/browser/web_discovery/BUILD.gn b/browser/web_discovery/BUILD.gn index aa506028751e..1ec447b9c281 100644 --- a/browser/web_discovery/BUILD.gn +++ b/browser/web_discovery/BUILD.gn @@ -3,11 +3,11 @@ # License, v. 2.0. If a copy of the MPL was not distributed with this file, # You can obtain one at http://mozilla.org/MPL/2.0/. -import("//brave/components/web_discovery/common/buildflags/buildflags.gni") +import("//brave/components/web_discovery/buildflags/buildflags.gni") import("//extensions/buildflags/buildflags.gni") if (enable_web_discovery_native) { - source_set("web_discovery") { + static_library("web_discovery") { sources = [ "web_discovery_service_factory.cc", "web_discovery_service_factory.h", diff --git a/browser/web_discovery/web_discovery_cta_util.cc b/browser/web_discovery/web_discovery_cta_util.cc index 0a053023faeb..4bb473a53294 100644 --- a/browser/web_discovery/web_discovery_cta_util.cc +++ b/browser/web_discovery/web_discovery_cta_util.cc @@ -14,7 +14,7 @@ #include "brave/components/constants/pref_names.h" #include "brave/components/constants/url_constants.h" #include "brave/components/search_engines/brave_prepopulated_engines.h" -#include "brave/components/web_discovery/common/buildflags/buildflags.h" +#include "brave/components/web_discovery/buildflags/buildflags.h" #include "components/prefs/pref_service.h" #include "components/prefs/scoped_user_pref_update.h" #include "components/search_engines/template_url.h" @@ -96,7 +96,7 @@ bool ShouldShowWebDiscoveryInfoBar(TemplateURLService* service, const char* enabled_pref_name = kWebDiscoveryExtensionEnabled; #if BUILDFLAG(ENABLE_WEB_DISCOVERY_NATIVE) if (base::FeatureList::IsEnabled( - web_discovery::features::kWebDiscoveryNative)) { + web_discovery::features::kBraveWebDiscoveryNative)) { enabled_pref_name = web_discovery::kWebDiscoveryNativeEnabled; } #endif diff --git a/browser/web_discovery/web_discovery_infobar_delegate.cc b/browser/web_discovery/web_discovery_infobar_delegate.cc index b604262e05b1..6b8f934e6d31 100644 --- a/browser/web_discovery/web_discovery_infobar_delegate.cc +++ b/browser/web_discovery/web_discovery_infobar_delegate.cc @@ -7,7 +7,7 @@ #include "brave/browser/web_discovery/web_discovery_cta_util.h" #include "brave/components/constants/pref_names.h" -#include "brave/components/web_discovery/common/buildflags/buildflags.h" +#include "brave/components/web_discovery/buildflags/buildflags.h" #include "components/infobars/core/infobar.h" #include "components/prefs/pref_service.h" @@ -51,7 +51,7 @@ void WebDiscoveryInfoBarDelegate::EnableWebDiscovery() { const char* pref_name = kWebDiscoveryExtensionEnabled; #if BUILDFLAG(ENABLE_WEB_DISCOVERY_NATIVE) if (base::FeatureList::IsEnabled( - web_discovery::features::kWebDiscoveryNative)) { + web_discovery::features::kBraveWebDiscoveryNative)) { pref_name = web_discovery::kWebDiscoveryNativeEnabled; } #endif diff --git a/browser/web_discovery/web_discovery_service_factory.cc b/browser/web_discovery/web_discovery_service_factory.cc index 093226b23b23..87522d736a7f 100644 --- a/browser/web_discovery/web_discovery_service_factory.cc +++ b/browser/web_discovery/web_discovery_service_factory.cc @@ -49,7 +49,7 @@ KeyedService* WebDiscoveryServiceFactory::BuildServiceInstanceFor( content::BrowserContext* WebDiscoveryServiceFactory::GetBrowserContextToUse( content::BrowserContext* context) const { - if (!base::FeatureList::IsEnabled(features::kWebDiscoveryNative)) { + if (!base::FeatureList::IsEnabled(features::kBraveWebDiscoveryNative)) { return nullptr; } // Prevents creation of service instance for incognito/OTR profiles diff --git a/browser/web_discovery/web_discovery_service_factory_unittest.cc b/browser/web_discovery/web_discovery_service_factory_unittest.cc index b4717b1749e7..cbda07dfa63b 100644 --- a/browser/web_discovery/web_discovery_service_factory_unittest.cc +++ b/browser/web_discovery/web_discovery_service_factory_unittest.cc @@ -15,8 +15,9 @@ namespace web_discovery { TEST(WebDiscoveryServiceFactoryTest, PrivateNotCreated) { + base::test::ScopedFeatureList scoped_features( + features::kBraveWebDiscoveryNative); content::BrowserTaskEnvironment task_environment; - base::test::ScopedFeatureList scoped_features(features::kWebDiscoveryNative); auto* browser_process = TestingBrowserProcess::GetGlobal(); TestingProfileManager profile_manager(browser_process); ASSERT_TRUE(profile_manager.SetUp()); diff --git a/chromium_src/chrome/browser/DEPS b/chromium_src/chrome/browser/DEPS index a76c71b81b32..a4f2b0061486 100644 --- a/chromium_src/chrome/browser/DEPS +++ b/chromium_src/chrome/browser/DEPS @@ -48,7 +48,7 @@ include_rules = [ "+brave/components/url_sanitizer", "+brave/components/vector_icons", "+brave/components/version_info", - "+brave/components/web_discovery/common", + "+brave/components/web_discovery", "+brave/components/webcompat", "+brave/net", "+brave/services/network/public", diff --git a/chromium_src/chrome/browser/flags/android/chrome_feature_list.cc b/chromium_src/chrome/browser/flags/android/chrome_feature_list.cc index c836c47732d8..8c79f36c76ef 100644 --- a/chromium_src/chrome/browser/flags/android/chrome_feature_list.cc +++ b/chromium_src/chrome/browser/flags/android/chrome_feature_list.cc @@ -18,7 +18,7 @@ #include "brave/components/playlist/common/features.h" #include "brave/components/request_otr/common/features.h" #include "brave/components/speedreader/common/features.h" -#include "brave/components/web_discovery/common/features.h" +#include "brave/components/web_discovery/buildflags/buildflags.h" #include "brave/components/webcompat/core/common/features.h" #include "net/base/features.h" #include "third_party/blink/public/common/features.h" @@ -30,9 +30,18 @@ #define BRAVE_AI_CHAT_FLAG #endif +#if BUILDFLAG(ENABLE_WEB_DISCOVERY_NATIVE) +#include "brave/components/web_discovery/common/features.h" +#define BRAVE_WEB_DISCOVERY_FLAG \ + &web_discovery::features::kBraveWebDiscoveryNative, +#else +#define BRAVE_WEB_DISCOVERY_FLAG +#endif + // clang-format off #define kForceWebContentsDarkMode kForceWebContentsDarkMode, \ BRAVE_AI_CHAT_FLAG \ + BRAVE_WEB_DISCOVERY_FLAG \ &brave_rewards::features::kBraveRewards, \ &brave_search_conversion::features::kOmniboxBanner, \ &brave_vpn::features::kBraveVPNLinkSubscriptionAndroidUI, \ @@ -50,14 +59,14 @@ &google_sign_in_permission::features::kBraveGoogleSignInPermission, \ &net::features::kBraveForgetFirstPartyStorage, \ &brave_shields::features::kBraveShowStrictFingerprintingMode, \ - &brave_shields::features::kBraveLocalhostAccessPermission, \ - &web_discovery::features::kWebDiscoveryNative + &brave_shields::features::kBraveLocalhostAccessPermission // clang-format on #include "src/chrome/browser/flags/android/chrome_feature_list.cc" #undef kForceWebContentsDarkMode #undef BRAVE_AI_CHAT_FLAG +#undef BRAVE_WEB_DISCOVERY_FLAG namespace chrome { namespace android { diff --git a/components/web_discovery/browser/BUILD.gn b/components/web_discovery/browser/BUILD.gn index aa3c8dc31878..dc19dce58ff2 100644 --- a/components/web_discovery/browser/BUILD.gn +++ b/components/web_discovery/browser/BUILD.gn @@ -3,7 +3,7 @@ # License, v. 2.0. If a copy of the MPL was not distributed with this file, # You can obtain one at https://mozilla.org/MPL/2.0/. -import("//brave/components/web_discovery/common/buildflags/buildflags.gni") +import("//brave/components/web_discovery/buildflags/buildflags.gni") assert(enable_web_discovery_native) @@ -11,7 +11,6 @@ static_library("browser") { sources = [ "credential_manager.cc", "credential_manager.h", - "credential_signer.cc", "credential_signer.h", "patterns.cc", "patterns.h", @@ -33,7 +32,6 @@ static_library("browser") { "//brave/components/web_discovery/common", "//components/keyed_service/core", "//components/prefs", - "//content/public/browser", "//crypto", "//extensions/buildflags", "//net", diff --git a/components/web_discovery/browser/DEPS b/components/web_discovery/browser/DEPS index 8a3c47daae89..290a8dab5740 100644 --- a/components/web_discovery/browser/DEPS +++ b/components/web_discovery/browser/DEPS @@ -1,4 +1,5 @@ include_rules = [ + "-content", "+services/network/public", "+extensions/buildflags/buildflags.h", "+services/service_manager/public/cpp", diff --git a/components/web_discovery/browser/credential_manager.cc b/components/web_discovery/browser/credential_manager.cc index 562d3d18a90f..0c5ccca25dc1 100644 --- a/components/web_discovery/browser/credential_manager.cc +++ b/components/web_discovery/browser/credential_manager.cc @@ -8,6 +8,8 @@ #include #include "base/base64.h" +#include "base/containers/span.h" +#include "base/containers/span_rust.h" #include "base/functional/bind.h" #include "base/json/json_reader.h" #include "base/json/json_writer.h" @@ -17,6 +19,7 @@ #include "base/threading/thread_restrictions.h" #include "brave/components/web_discovery/browser/anonymous_credentials/rs/cxx/src/lib.rs.h" #include "brave/components/web_discovery/browser/pref_names.h" +#include "brave/components/web_discovery/browser/rsa.h" #include "brave/components/web_discovery/browser/util.h" #include "components/prefs/pref_service.h" #include "components/prefs/scoped_user_pref_update.h" @@ -61,19 +64,43 @@ constexpr net::NetworkTrafficAnnotationTag kJoinNetworkTrafficAnnotation = "Users can opt-in or out via brave://settings/search" })"); -std::optional GenerateJoinRequest( - anonymous_credentials::CredentialManager* anonymous_credential_manager, - crypto::RSAPrivateKey* rsa_private_key, - std::string pre_challenge) { +} // namespace + +BackgroundCredentialHelper::BackgroundCredentialHelper() + : anonymous_credential_manager_( + anonymous_credentials::new_credential_manager()) {} + +BackgroundCredentialHelper::~BackgroundCredentialHelper() = default; + +void BackgroundCredentialHelper::UseFixedSeedForTesting() { + anonymous_credential_manager_ = + anonymous_credentials::new_credential_manager_with_fixed_seed(); +} + +std::unique_ptr BackgroundCredentialHelper::GenerateRSAKey() { + auto key_pair = GenerateRSAKeyPair(); + if (!key_pair) { + return nullptr; + } + rsa_private_key_ = std::move(key_pair->key_pair); + return key_pair; +} + +void BackgroundCredentialHelper::SetRSAKey( + std::unique_ptr rsa_private_key) { + rsa_private_key_ = std::move(rsa_private_key); +} + +std::optional +BackgroundCredentialHelper::GenerateJoinRequest(std::string pre_challenge) { base::AssertLongCPUWorkAllowed(); - base::span pre_challenge_span( - reinterpret_cast(pre_challenge.data()), pre_challenge.size()); - auto challenge = crypto::SHA256Hash(pre_challenge_span); + CHECK(rsa_private_key_); + auto challenge = crypto::SHA256Hash(base::as_byte_span(pre_challenge)); - auto join_request = anonymous_credential_manager->start_join( - rust::Slice(challenge.data(), challenge.size())); + auto join_request = anonymous_credential_manager_->start_join( + base::SpanToRustSlice(challenge)); - auto signature = RSASign(rsa_private_key, join_request.join_request); + auto signature = RSASign(rsa_private_key_.get(), join_request.join_request); if (!signature) { VLOG(1) << "RSA signature failed"; @@ -84,19 +111,18 @@ std::optional GenerateJoinRequest( .signature = *signature}; } -std::optional FinishJoin( - anonymous_credentials::CredentialManager* anonymous_credential_manager, +std::optional BackgroundCredentialHelper::FinishJoin( std::string date, std::vector group_pub_key, std::vector gsk, std::vector join_resp_bytes) { base::AssertLongCPUWorkAllowed(); auto pub_key_result = anonymous_credentials::load_group_public_key( - rust::Slice(group_pub_key.data(), group_pub_key.size())); - auto gsk_result = anonymous_credentials::load_credential_big( - rust::Slice(gsk.data(), gsk.size())); + base::SpanToRustSlice(group_pub_key)); + auto gsk_result = + anonymous_credentials::load_credential_big(base::SpanToRustSlice(gsk)); auto join_resp_result = anonymous_credentials::load_join_response( - rust::Slice(join_resp_bytes.data(), join_resp_bytes.size())); + base::SpanToRustSlice(join_resp_bytes)); if (!pub_key_result.error_message.empty() || !gsk_result.error_message.empty() || !join_resp_result.error_message.empty()) { @@ -107,7 +133,7 @@ std::optional FinishJoin( << join_resp_result.error_message.c_str(); return std::nullopt; } - auto finish_res = anonymous_credential_manager->finish_join( + auto finish_res = anonymous_credential_manager_->finish_join( *pub_key_result.value, *gsk_result.value, std::move(join_resp_result.value)); if (!finish_res.error_message.empty()) { @@ -118,8 +144,8 @@ std::optional FinishJoin( return base::Base64Encode(finish_res.data); } -std::optional> PerformSign( - anonymous_credentials::CredentialManager* anonymous_credential_manager, +std::optional> +BackgroundCredentialHelper::PerformSign( std::vector msg, std::vector basename, std::optional> gsk_bytes, @@ -127,11 +153,9 @@ std::optional> PerformSign( base::AssertLongCPUWorkAllowed(); if (gsk_bytes && credential_bytes) { auto gsk_result = anonymous_credentials::load_credential_big( - rust::Slice(reinterpret_cast(gsk_bytes->data()), - gsk_bytes->size())); + base::SpanToRustSlice(*gsk_bytes)); auto credential_result = anonymous_credentials::load_user_credentials( - rust::Slice(reinterpret_cast(credential_bytes->data()), - credential_bytes->size())); + base::SpanToRustSlice(*credential_bytes)); if (!gsk_result.error_message.empty() || !credential_result.error_message.empty()) { VLOG(1) << "Failed to sign due to deserialization error with gsk, or " @@ -140,12 +164,11 @@ std::optional> PerformSign( << credential_result.error_message.c_str(); return std::nullopt; } - anonymous_credential_manager->set_gsk_and_credentials( + anonymous_credential_manager_->set_gsk_and_credentials( std::move(gsk_result.value), std::move(credential_result.value)); } - auto sig_res = anonymous_credential_manager->sign( - rust::Slice(msg.data(), msg.size()), - rust::Slice(basename.data(), basename.size())); + auto sig_res = anonymous_credential_manager_->sign( + base::SpanToRustSlice(msg), base::SpanToRustSlice(basename)); if (!sig_res.error_message.empty()) { VLOG(1) << "Failed to sign: " << sig_res.error_message.c_str(); return std::nullopt; @@ -153,8 +176,6 @@ std::optional> PerformSign( return std::vector(sig_res.data.begin(), sig_res.data.end()); } -} // namespace - CredentialManager::CredentialManager( PrefService* profile_prefs, network::SharedURLLoaderFactory* shared_url_loader_factory, @@ -164,12 +185,8 @@ CredentialManager::CredentialManager( server_config_loader_(server_config_loader), join_url_(GetDirectHPNHost() + kJoinPath), backoff_entry_(&kBackoffPolicy), - sequenced_task_runner_(base::ThreadPool::CreateSequencedTaskRunner({})), - anonymous_credential_manager_( - new rust::Box(anonymous_credentials::new_credential_manager()), - base::OnTaskRunnerDeleter(sequenced_task_runner_)), - rsa_private_key_(nullptr, - base::OnTaskRunnerDeleter(sequenced_task_runner_)) {} + background_task_runner_(base::ThreadPool::CreateSequencedTaskRunner({})), + background_credential_helper_(background_task_runner_) {} CredentialManager::~CredentialManager() = default; @@ -183,11 +200,15 @@ bool CredentialManager::LoadRSAKey() { return true; } - rsa_private_key_.reset(ImportRSAKeyPair(private_key_b64).release()); - if (!rsa_private_key_) { + auto key_pair = ImportRSAKeyPair(private_key_b64); + if (!key_pair) { VLOG(1) << "Failed to import stored RSA key"; + rsa_public_key_b64_ = std::nullopt; return false; } + background_credential_helper_ + .AsyncCall(&BackgroundCredentialHelper::SetRSAKey) + .WithArgs(std::move(key_pair)); return true; } @@ -198,7 +219,6 @@ void CredentialManager::OnNewRSAKey(std::unique_ptr key_info) { return; } - rsa_private_key_.reset(key_info->key_pair.release()); rsa_public_key_b64_ = key_info->public_key_b64; profile_prefs_->SetString(kCredentialRSAPrivateKey, @@ -213,53 +233,48 @@ void CredentialManager::JoinGroups() { auto today_date = FormatServerDate(base::Time::Now().UTCMidnight()); const auto& anon_creds_dict = profile_prefs_->GetDict(kAnonymousCredentialsDict); - for (const auto& [date, group_pub_key_b64] : server_config.group_pub_keys) { + for (const auto& [date, group_pub_key] : server_config.group_pub_keys) { if (date < today_date || join_url_loaders_.contains(date) || anon_creds_dict.contains(date)) { continue; } - if (rsa_private_key_ == nullptr) { + if (!rsa_public_key_b64_) { if (!LoadRSAKey()) { return; } - if (rsa_private_key_ == nullptr) { - sequenced_task_runner_->PostTaskAndReplyWithResult( - FROM_HERE, base::BindOnce(&GenerateRSAKeyPair), - base::BindOnce(&CredentialManager::OnNewRSAKey, - weak_ptr_factory_.GetWeakPtr())); + if (!rsa_public_key_b64_) { + background_credential_helper_ + .AsyncCall(&BackgroundCredentialHelper::GenerateRSAKey) + .Then(base::BindOnce(&CredentialManager::OnNewRSAKey, + weak_ptr_factory_.GetWeakPtr())); return; } } - StartJoinGroup(date, group_pub_key_b64); + StartJoinGroup(date, group_pub_key); } } -void CredentialManager::StartJoinGroup(const std::string& date, - const std::string& group_pub_key_b64) { - auto group_pub_key = base::Base64Decode(group_pub_key_b64); - if (!group_pub_key) { - VLOG(1) << "Failed to decode group public key for " << date; - return; - } - std::vector group_pub_key_const(group_pub_key->begin(), - group_pub_key->end()); +void CredentialManager::StartJoinGroup( + const std::string& date, + const std::vector& group_pub_key) { + std::vector group_pub_key_const(group_pub_key.begin(), + group_pub_key.end()); auto challenge_elements = base::Value::List::with_capacity(2); challenge_elements.Append(*rsa_public_key_b64_); - challenge_elements.Append(group_pub_key_b64); + challenge_elements.Append(base::Base64Encode(group_pub_key)); std::string pre_challenge; base::JSONWriter::Write(challenge_elements, &pre_challenge); - sequenced_task_runner_->PostTaskAndReplyWithResult( - FROM_HERE, - base::BindOnce(&GenerateJoinRequest, &**anonymous_credential_manager_, - rsa_private_key_.get(), pre_challenge), - base::BindOnce(&CredentialManager::OnJoinRequestReady, - weak_ptr_factory_.GetWeakPtr(), date, - group_pub_key_const)); + background_credential_helper_ + .AsyncCall(&BackgroundCredentialHelper::GenerateJoinRequest) + .WithArgs(pre_challenge) + .Then(base::BindOnce(&CredentialManager::OnJoinRequestReady, + weak_ptr_factory_.GetWeakPtr(), date, + group_pub_key_const)); } void CredentialManager::OnJoinRequestReady( @@ -376,12 +391,11 @@ bool CredentialManager::ProcessJoinResponse( std::vector join_resp_bytes_const(join_resp_bytes->begin(), join_resp_bytes->end()); - sequenced_task_runner_->PostTaskAndReplyWithResult( - FROM_HERE, - base::BindOnce(&FinishJoin, &**anonymous_credential_manager_, date, - group_pub_key, gsk, join_resp_bytes_const), - base::BindOnce(&CredentialManager::OnCredentialsReady, - weak_ptr_factory_.GetWeakPtr(), date, gsk)); + background_credential_helper_ + .AsyncCall(&BackgroundCredentialHelper::FinishJoin) + .WithArgs(date, group_pub_key, gsk, join_resp_bytes_const) + .Then(base::BindOnce(&CredentialManager::OnCredentialsReady, + weak_ptr_factory_.GetWeakPtr(), date, gsk)); return true; } @@ -405,7 +419,7 @@ bool CredentialManager::CredentialExistsForToday() { .contains(FormatServerDate(base::Time::Now())); } -bool CredentialManager::Sign(std::vector msg, +void CredentialManager::Sign(std::vector msg, std::vector basename, SignCallback callback) { auto today_date = FormatServerDate(base::Time::Now().UTCMidnight()); @@ -417,30 +431,31 @@ bool CredentialManager::Sign(std::vector msg, auto* today_cred_dict = anon_creds_dict.FindDict(today_date); if (!today_cred_dict) { VLOG(1) << "Failed to sign due to unavailability of credentials"; - return false; + std::move(callback).Run(std::nullopt); + return; } auto* gsk_b64 = today_cred_dict->FindString(kGSKDictKey); auto* credential_b64 = today_cred_dict->FindString(kCredentialDictKey); if (!gsk_b64 || !credential_b64) { VLOG(1) << "Failed to sign due to unavailability of gsk/credential"; - return false; + std::move(callback).Run(std::nullopt); + return; } gsk_bytes = base::Base64Decode(*gsk_b64); credential_bytes = base::Base64Decode(*credential_b64); if (!gsk_bytes || !credential_bytes) { VLOG(1) << "Failed to sign due to bad gsk/credential base64"; - return false; + std::move(callback).Run(std::nullopt); + return; } } - sequenced_task_runner_->PostTaskAndReplyWithResult( - FROM_HERE, - base::BindOnce(&PerformSign, &**anonymous_credential_manager_, msg, - basename, gsk_bytes, credential_bytes), - base::BindOnce(&CredentialManager::OnSignResult, - weak_ptr_factory_.GetWeakPtr(), today_date, - std::move(callback))); - return true; + background_credential_helper_ + .AsyncCall(&BackgroundCredentialHelper::PerformSign) + .WithArgs(msg, basename, gsk_bytes, credential_bytes) + .Then(base::BindOnce(&CredentialManager::OnSignResult, + weak_ptr_factory_.GetWeakPtr(), today_date, + std::move(callback))); } void CredentialManager::OnSignResult( @@ -452,12 +467,8 @@ void CredentialManager::OnSignResult( } void CredentialManager::UseFixedSeedForTesting() { - anonymous_credential_manager_ = - std::unique_ptr, - base::OnTaskRunnerDeleter>( - new rust::Box( - anonymous_credentials::new_credential_manager_with_fixed_seed()), - base::OnTaskRunnerDeleter(sequenced_task_runner_)); + background_credential_helper_.AsyncCall( + &BackgroundCredentialHelper::UseFixedSeedForTesting); } } // namespace web_discovery diff --git a/components/web_discovery/browser/credential_manager.h b/components/web_discovery/browser/credential_manager.h index e5550a37ca0f..18841571de24 100644 --- a/components/web_discovery/browser/credential_manager.h +++ b/components/web_discovery/browser/credential_manager.h @@ -14,6 +14,7 @@ #include "base/functional/callback.h" #include "base/memory/raw_ptr.h" #include "base/task/sequenced_task_runner.h" +#include "base/threading/sequence_bound.h" #include "base/timer/wall_clock_timer.h" #include "brave/components/web_discovery/browser/anonymous_credentials/rs/cxx/src/lib.rs.h" #include "brave/components/web_discovery/browser/credential_signer.h" @@ -36,6 +37,38 @@ struct GenerateJoinRequestResult { std::string signature; }; +class BackgroundCredentialHelper { + public: + BackgroundCredentialHelper(); + ~BackgroundCredentialHelper(); + + BackgroundCredentialHelper(const BackgroundCredentialHelper&) = delete; + BackgroundCredentialHelper& operator=(const BackgroundCredentialHelper&) = + delete; + + void UseFixedSeedForTesting(); + + std::unique_ptr GenerateRSAKey(); + void SetRSAKey(std::unique_ptr rsa_private_key); + std::optional GenerateJoinRequest( + std::string pre_challenge); + std::optional FinishJoin( + std::string date, + std::vector group_pub_key, + std::vector gsk, + std::vector join_resp_bytes); + std::optional> PerformSign( + std::vector msg, + std::vector basename, + std::optional> gsk_bytes, + std::optional> credential_bytes); + + private: + rust::Box + anonymous_credential_manager_; + std::unique_ptr rsa_private_key_; +}; + // Manages and utilizes anonymous credentials used for communicating // with Web Discovery servers. These Direct Anonymous Attestation credentials // are used to prevent Sybil attacks on the servers. @@ -63,7 +96,7 @@ class CredentialManager : public CredentialSigner { // CredentialSigner: bool CredentialExistsForToday() override; - bool Sign(std::vector msg, + void Sign(std::vector msg, std::vector basename, SignCallback callback) override; @@ -78,7 +111,7 @@ class CredentialManager : public CredentialSigner { void OnNewRSAKey(std::unique_ptr key_info); void StartJoinGroup(const std::string& date, - const std::string& group_pub_key_b64); + const std::vector& group_pub_key); void OnJoinRequestReady( std::string date, @@ -102,9 +135,9 @@ class CredentialManager : public CredentialSigner { SignCallback callback, std::optional> signed_message); - raw_ptr profile_prefs_; - raw_ptr shared_url_loader_factory_; - raw_ptr server_config_loader_; + const raw_ptr profile_prefs_; + const raw_ptr shared_url_loader_factory_; + const raw_ptr server_config_loader_; GURL join_url_; base::flat_map> @@ -112,14 +145,9 @@ class CredentialManager : public CredentialSigner { net::BackoffEntry backoff_entry_; base::WallClockTimer retry_timer_; - scoped_refptr sequenced_task_runner_; - - std::unique_ptr, - base::OnTaskRunnerDeleter> - anonymous_credential_manager_; + scoped_refptr background_task_runner_; - std::unique_ptr - rsa_private_key_; + base::SequenceBound background_credential_helper_; std::optional rsa_public_key_b64_; std::optional loaded_credential_date_; diff --git a/components/web_discovery/browser/credential_manager_unittest.cc b/components/web_discovery/browser/credential_manager_unittest.cc index 98e12761a8b1..b7237ced323a 100644 --- a/components/web_discovery/browser/credential_manager_unittest.cc +++ b/components/web_discovery/browser/credential_manager_unittest.cc @@ -9,6 +9,7 @@ #include #include +#include "base/base64.h" #include "base/files/file_util.h" #include "base/json/json_reader.h" #include "base/json/json_writer.h" @@ -69,7 +70,9 @@ class WebDiscoveryCredentialManagerTest : public testing::Test { auto server_config = std::make_unique(); for (const auto [date, join_response] : *join_responses) { - server_config->group_pub_keys[date] = *group_pub_key; + auto decoded_group_pub_key = base::Base64Decode(*group_pub_key); + ASSERT_TRUE(decoded_group_pub_key); + server_config->group_pub_keys[date] = *decoded_group_pub_key; join_responses_[date] = join_response.GetString(); } server_config_loader_->SetLastServerConfigForTesting( @@ -90,12 +93,6 @@ class WebDiscoveryCredentialManagerTest : public testing::Test { credential_manager_->UseFixedSeedForTesting(); } - base::test::TaskEnvironment task_environment_; - std::unique_ptr credential_manager_; - TestingPrefServiceSimple profile_prefs_; - size_t join_requests_made_ = 0; - - private: void HandleRequest(const network::ResourceRequest& request) { url_loader_factory_.ClearResponses(); std::string response; @@ -120,10 +117,19 @@ class WebDiscoveryCredentialManagerTest : public testing::Test { join_requests_made_++; } - base::flat_map join_responses_; - std::unique_ptr server_config_loader_; + base::test::TaskEnvironment task_environment_; + network::TestURLLoaderFactory url_loader_factory_; scoped_refptr shared_url_loader_factory_; + + TestingPrefServiceSimple profile_prefs_; + std::unique_ptr server_config_loader_; + + base::flat_map join_responses_; + + std::unique_ptr credential_manager_; + + size_t join_requests_made_ = 0; }; TEST_F(WebDiscoveryCredentialManagerTest, JoinGroups) { diff --git a/components/web_discovery/browser/credential_signer.cc b/components/web_discovery/browser/credential_signer.cc deleted file mode 100644 index ca8fce0c3b51..000000000000 --- a/components/web_discovery/browser/credential_signer.cc +++ /dev/null @@ -1,12 +0,0 @@ -/* Copyright (c) 2024 The Brave Authors. All rights reserved. - * This Source Code Form is subject to the terms of the Mozilla Public - * License, v. 2.0. If a copy of the MPL was not distributed with this file, - * You can obtain one at https://mozilla.org/MPL/2.0/. */ - -#include "brave/components/web_discovery/browser/credential_signer.h" - -namespace web_discovery { - -CredentialSigner::~CredentialSigner() = default; - -} // namespace web_discovery diff --git a/components/web_discovery/browser/credential_signer.h b/components/web_discovery/browser/credential_signer.h index b7e167cf1e49..fd79c67a894c 100644 --- a/components/web_discovery/browser/credential_signer.h +++ b/components/web_discovery/browser/credential_signer.h @@ -17,7 +17,7 @@ class CredentialSigner { public: using SignCallback = base::OnceCallback>)>; - virtual ~CredentialSigner(); + virtual ~CredentialSigner() = default; // Returns true is a credential is available for the current date. // The caller can expect future calls to `Sign` to succeed, if made today. @@ -29,7 +29,7 @@ class CredentialSigner { // preventing Sybil attacks. // See signature_basename.h/cc for more information on how the basename // should be generated. - virtual bool Sign(std::vector msg, + virtual void Sign(std::vector msg, std::vector basename, SignCallback callback) = 0; }; diff --git a/components/web_discovery/browser/patterns.cc b/components/web_discovery/browser/patterns.cc index 34cce886417e..d0e7457dd5cd 100644 --- a/components/web_discovery/browser/patterns.cc +++ b/components/web_discovery/browser/patterns.cc @@ -7,10 +7,12 @@ #include +#include "base/containers/contains.h" #include "base/containers/fixed_flat_map.h" #include "base/json/json_reader.h" #include "base/logging.h" #include "base/strings/string_number_conversions.h" +#include "base/threading/thread_restrictions.h" #include "third_party/re2/src/re2/re2.h" namespace web_discovery { @@ -88,7 +90,6 @@ std::optional> ParsePayloadRules( return std::nullopt; } auto* action = rule_group_dict->FindString(kActionKey); - auto* fields = rule_group_dict->FindList(kFieldsKey); auto* rule_type_str = rule_group_dict->FindString(kRuleTypeKey); auto* result_type_str = rule_group_dict->FindString(kResultTypeKey); if (!action || !rule_type_str || !result_type_str) { @@ -106,7 +107,7 @@ std::optional> ParsePayloadRules( rule_group_it->key = key; rule_group_it->result_type = result_type_it->second; rule_group_it->rule_type = rule_type_it->second; - if (fields) { + if (auto* fields = rule_group_dict->FindList(kFieldsKey)) { rule_group_it->rules = std::vector(fields->size()); auto rule_it = rule_group_it->rules.begin(); @@ -124,15 +125,16 @@ std::optional> ParsePayloadRules( } RefineFunctionList ParseFunctionsApplied(const base::Value::List* list) { + CHECK(list); RefineFunctionList result; for (const auto& function_val : *list) { const auto* function_list = function_val.GetIfList(); if (!function_list || function_list->size() <= 1) { continue; } - std::vector function_vec; + base::Value::List function_vec; for (const auto& element : *function_list) { - function_vec.push_back(element.Clone()); + function_vec.Append(element.Clone()); } result.push_back(std::move(function_vec)); } @@ -212,6 +214,11 @@ std::optional> ParsePatternsURLDetails( auto& details = result[i]; details.url_regex = std::make_unique(*url_regex); + if (!details.url_regex->ok()) { + VLOG(1) << "URL pattern is not valid regex: " + << details.url_regex->error(); + return std::nullopt; + } std::string i_str = base::NumberToString(i); @@ -224,10 +231,7 @@ std::optional> ParsePatternsURLDetails( } details.id = *id; - details.is_search_engine = - base::ranges::find(search_engines_list->begin(), - search_engines_list->end(), - i_str) != search_engines_list->end(); + details.is_search_engine = base::Contains(*search_engines_list, i_str); auto scrape_rule_groups = ParseScrapeRules(scrape_url_dict); if (!scrape_rule_groups) { @@ -285,14 +289,22 @@ const PatternsURLDetails* PatternsGroup::GetMatchingURLPattern( return nullptr; } -std::unique_ptr ParsePatterns(const std::string& patterns_json) { - auto result = std::make_unique(); - auto patterns_value = base::JSONReader::Read(patterns_json); - if (!patterns_value || !patterns_value->is_dict()) { +std::unique_ptr ParsePatterns(std::string_view patterns_json) { + base::AssertLongCPUWorkAllowed(); + const auto patterns_parse_result = + base::JSONReader::ReadAndReturnValueWithError(patterns_json); + if (!patterns_parse_result.has_value()) { + VLOG(1) << "Failed to parse patterns JSON: " + << patterns_parse_result.error().ToString(); + return nullptr; + } + const auto& patterns_value = patterns_parse_result.value(); + if (!patterns_value.is_dict()) { VLOG(1) << "Patterns is not JSON or is not dict"; return nullptr; } - const auto& patterns_dict = patterns_value->GetDict(); + auto result = std::make_unique(); + const auto& patterns_dict = patterns_value.GetDict(); auto* normal_dict = patterns_dict.FindDict(kNormalPatternsKey); auto* strict_dict = patterns_dict.FindDict(kStrictPatternsKey); diff --git a/components/web_discovery/browser/patterns.h b/components/web_discovery/browser/patterns.h index b9e3b749e104..a87de2ceaa64 100644 --- a/components/web_discovery/browser/patterns.h +++ b/components/web_discovery/browser/patterns.h @@ -52,7 +52,7 @@ enum class PayloadResultType { // Contains functions for refining the scraped value. The inner vector // contains the function name and arguments for the function. -using RefineFunctionList = std::vector>; +using RefineFunctionList = std::vector; // Defines rule for scraping an attribute from a given selected element. struct ScrapeRule { @@ -167,7 +167,7 @@ struct PatternsGroup { }; // Returns nullptr if parsing fails. -std::unique_ptr ParsePatterns(const std::string& patterns_json); +std::unique_ptr ParsePatterns(std::string_view patterns_json); } // namespace web_discovery diff --git a/components/web_discovery/browser/server_config_loader.cc b/components/web_discovery/browser/server_config_loader.cc index 1f4d0c3a7fca..b9ecd4bbeac2 100644 --- a/components/web_discovery/browser/server_config_loader.cc +++ b/components/web_discovery/browser/server_config_loader.cc @@ -84,12 +84,11 @@ constexpr auto kAllowedReportLocations = KeyMap ParseKeys(const base::Value::Dict& encoded_keys) { KeyMap map; for (const auto [date, key_b64] : encoded_keys) { - std::vector decoded_data; - // Decode to check for valid base64 - if (!base::Base64Decode(key_b64.GetString())) { + auto decoded_data = base::Base64Decode(key_b64.GetString()); + if (!decoded_data) { continue; } - map[date] = key_b64.GetString(); + map[date] = *decoded_data; } return map; } @@ -145,6 +144,70 @@ std::optional ReadPatternsFile(base::FilePath patterns_path) { return contents; } +std::unique_ptr ProcessConfigResponses( + const std::string collector_response_body, + const std::string quorum_response_body) { + base::AssertLongCPUWorkAllowed(); + auto collector_parsed_json = base::JSONReader::ReadAndReturnValueWithError( + collector_response_body, base::JSON_PARSE_RFC); + auto quorum_parsed_json = base::JSONReader::ReadAndReturnValueWithError( + quorum_response_body, base::JSON_PARSE_RFC); + + if (!collector_parsed_json.has_value() || !quorum_parsed_json.has_value()) { + const auto& error = !collector_parsed_json.has_value() + ? collector_parsed_json.error() + : quorum_parsed_json.error(); + VLOG(1) << "Failed to parse server config json: " << error.ToString(); + return nullptr; + } + + const auto* collector_root = collector_parsed_json.value().GetIfDict(); + const auto* quorum_root = quorum_parsed_json.value().GetIfDict(); + if (!collector_root || !quorum_root) { + VLOG(1) << "Failed to parse server config: not a dict"; + return nullptr; + } + + const auto min_version = collector_root->FindInt(kMinVersionFieldName); + if (min_version && *min_version > kCurrentVersion) { + VLOG(1) << "Server minimum version is higher than current version, failing"; + return nullptr; + } + + auto config = std::make_unique(); + + const auto* group_pub_keys = collector_root->FindDict(kGroupPubKeysFieldName); + if (!group_pub_keys) { + VLOG(1) << "Failed to retrieve groupPubKeys from server config"; + return nullptr; + } + const auto* pub_keys = collector_root->FindDict(kPubKeysFieldName); + if (!pub_keys) { + VLOG(1) << "Failed to retrieve pubKeys from server config"; + return nullptr; + } + const auto* source_map = collector_root->FindDict(kSourceMapFieldName); + const auto* source_map_actions = + source_map ? source_map->FindDict(kSourceMapActionsFieldName) : nullptr; + if (!source_map_actions) { + VLOG(1) << "Failed to retrieve sourceMap from server config"; + return nullptr; + } + + const auto* location = quorum_root->FindString(kLocationFieldName); + if (location && kAllowedReportLocations.contains(*location)) { + config->location = *location; + } else { + config->location = kOmittedLocationValue; + } + + config->group_pub_keys = ParseKeys(*group_pub_keys); + config->pub_keys = ParseKeys(*pub_keys); + config->source_map_actions = ParseSourceMapActionConfigs(*source_map_actions); + + return config; +} + } // namespace SourceMapActionConfig::SourceMapActionConfig() = default; @@ -153,6 +216,14 @@ SourceMapActionConfig::~SourceMapActionConfig() = default; ServerConfig::ServerConfig() = default; ServerConfig::~ServerConfig() = default; +ServerConfigDownloadResult::ServerConfigDownloadResult( + bool is_collector_config, + std::optional response_body) + : is_collector_config(is_collector_config), response_body(response_body) {} +ServerConfigDownloadResult::~ServerConfigDownloadResult() = default; +ServerConfigDownloadResult::ServerConfigDownloadResult( + const ServerConfigDownloadResult&) = default; + ServerConfigLoader::ServerConfigLoader( PrefService* local_state, base::FilePath user_data_dir, @@ -160,7 +231,7 @@ ServerConfigLoader::ServerConfigLoader( base::RepeatingClosure config_callback, base::RepeatingClosure patterns_callback) : local_state_(local_state), - sequenced_task_runner_( + background_task_runner_( base::ThreadPool::CreateSequencedTaskRunner({base::MayBlock()})), shared_url_loader_factory_(shared_url_loader_factory), config_callback_(config_callback), @@ -211,27 +282,75 @@ void ServerConfigLoader::LoadConfigs() { quorum_config_url_loader_ = network::SimpleURLLoader::Create( std::move(quorum_resource_request), kNetworkTrafficAnnotation); - auto callback = base::BarrierCallback>( - 2, base::BindOnce(&ServerConfigLoader::OnConfigResponses, + auto callback = base::BarrierCallback( + 2, base::BindOnce(&ServerConfigLoader::OnConfigResponsesDownloaded, base::Unretained(this))); + auto make_download_result = [](bool is_collector_config, + std::optional response_body) { + return ServerConfigDownloadResult(is_collector_config, response_body); + }; + + auto collector_callback = + base::BindOnce(make_download_result, true).Then(callback); + auto quorum_callback = + base::BindOnce(make_download_result, false).Then(callback); + collector_config_url_loader_->DownloadToString( - shared_url_loader_factory_.get(), callback, kMaxResponseSize); + shared_url_loader_factory_.get(), std::move(collector_callback), + kMaxResponseSize); quorum_config_url_loader_->DownloadToString(shared_url_loader_factory_.get(), - callback, kMaxResponseSize); + std::move(quorum_callback), + kMaxResponseSize); } -void ServerConfigLoader::OnConfigResponses( - std::vector> response_bodies) { - CHECK_EQ(response_bodies.size(), 2u); - base::Time update_time = base::Time::Now(); - bool result = ProcessConfigResponses(response_bodies[0], response_bodies[1]); +void ServerConfigLoader::OnConfigResponsesDownloaded( + std::vector results) { + CHECK_EQ(results.size(), 2u); + const std::optional* collector_response_body = nullptr; + const std::optional* quorum_response_body = nullptr; + for (const auto& result : results) { + if (result.is_collector_config) { + collector_response_body = &result.response_body; + } else { + quorum_response_body = &result.response_body; + } + } + CHECK(collector_response_body && quorum_response_body); + + auto* collector_response_info = collector_config_url_loader_->ResponseInfo(); + auto* quorum_response_info = quorum_config_url_loader_->ResponseInfo(); + if (!*collector_response_body || !*quorum_response_body || + !collector_response_info || !quorum_response_info || + collector_response_info->headers->response_code() != 200 || + quorum_response_info->headers->response_code() != 200) { + VLOG(1) << "Failed to download one or more server configs"; + OnConfigResponsesProcessed(nullptr); + return; + } + + background_task_runner_->PostTaskAndReplyWithResult( + FROM_HERE, + base::BindOnce(&ProcessConfigResponses, **collector_response_body, + **quorum_response_body), + base::BindOnce(&ServerConfigLoader::OnConfigResponsesProcessed, + weak_ptr_factory_.GetWeakPtr())); +} + +void ServerConfigLoader::OnConfigResponsesProcessed( + std::unique_ptr config) { + bool result = config != nullptr; + if (result) { + last_loaded_server_config_ = std::move(config); + config_callback_.Run(); + } config_backoff_entry_.InformOfRequest(result); collector_config_url_loader_ = nullptr; quorum_config_url_loader_ = nullptr; + auto update_time = base::Time::Now(); if (!result) { update_time += config_backoff_entry_.GetTimeUntilRelease(); } else { @@ -245,80 +364,8 @@ void ServerConfigLoader::OnConfigResponses( base::BindOnce(&ServerConfigLoader::LoadConfigs, base::Unretained(this))); } -bool ServerConfigLoader::ProcessConfigResponses( - const std::optional& collector_response_body, - const std::optional& quorum_response_body) { - auto* collector_response_info = collector_config_url_loader_->ResponseInfo(); - auto* quorum_response_info = quorum_config_url_loader_->ResponseInfo(); - if (!collector_response_body || !collector_response_info || - !quorum_response_body || !collector_response_info || - collector_response_info->headers->response_code() != 200 || - quorum_response_info->headers->response_code() != 200) { - VLOG(1) << "Failed to fetch server config"; - return false; - } - - auto collector_parsed_json = base::JSONReader::ReadAndReturnValueWithError( - *collector_response_body, base::JSON_PARSE_RFC); - auto quorum_parsed_json = base::JSONReader::ReadAndReturnValueWithError( - *quorum_response_body, base::JSON_PARSE_RFC); - - if (!collector_parsed_json.has_value() || !quorum_parsed_json.has_value()) { - VLOG(1) << "Failed to parse server config json"; - return false; - } - - const auto* collector_root = collector_parsed_json.value().GetIfDict(); - const auto* quorum_root = quorum_parsed_json.value().GetIfDict(); - if (!collector_root || !quorum_root) { - VLOG(1) << "Failed to parse server config: not a dict"; - return false; - } - - const auto min_version = collector_root->FindInt(kMinVersionFieldName); - if (min_version && *min_version > kCurrentVersion) { - VLOG(1) << "Server minimum version is higher than current version, failing"; - return false; - } - - auto config = std::make_unique(); - - const auto* group_pub_keys = collector_root->FindDict(kGroupPubKeysFieldName); - if (!group_pub_keys) { - VLOG(1) << "Failed to retrieve groupPubKeys from server config"; - return false; - } - const auto* pub_keys = collector_root->FindDict(kPubKeysFieldName); - if (!pub_keys) { - VLOG(1) << "Failed to retrieve pubKeys from server config"; - return false; - } - const auto* source_map = collector_root->FindDict(kSourceMapFieldName); - const auto* source_map_actions = - source_map ? source_map->FindDict(kSourceMapActionsFieldName) : nullptr; - if (!source_map_actions) { - VLOG(1) << "Failed to retrieve sourceMap from server config"; - return false; - } - - const auto* location = quorum_root->FindString(kLocationFieldName); - if (location && kAllowedReportLocations.contains(*location)) { - config->location = *location; - } else { - config->location = kOmittedLocationValue; - } - - config->group_pub_keys = ParseKeys(*group_pub_keys); - config->pub_keys = ParseKeys(*pub_keys); - config->source_map_actions = ParseSourceMapActionConfigs(*source_map_actions); - - last_loaded_server_config_ = std::move(config); - config_callback_.Run(); - return true; -} - void ServerConfigLoader::LoadStoredPatterns() { - sequenced_task_runner_->PostTaskAndReplyWithResult( + background_task_runner_->PostTaskAndReplyWithResult( FROM_HERE, base::BindOnce(&ReadPatternsFile, patterns_path_), base::BindOnce(&ServerConfigLoader::OnPatternsFileLoaded, weak_ptr_factory_.GetWeakPtr())); @@ -332,14 +379,10 @@ void ServerConfigLoader::OnPatternsFileLoaded( SchedulePatternsRequest(); return; } - auto parsed_patterns = ParsePatterns(*patterns_json); - if (!parsed_patterns) { - local_state_->ClearPref(kPatternsRetrievalTime); - SchedulePatternsRequest(); - return; - } - last_loaded_patterns_ = std::move(parsed_patterns); - patterns_callback_.Run(); + background_task_runner_->PostTaskAndReplyWithResult( + FROM_HERE, base::BindOnce(&ParsePatterns, *patterns_json), + base::BindOnce(&ServerConfigLoader::OnStoredPatternsParsed, + weak_ptr_factory_.GetWeakPtr())); } void ServerConfigLoader::SchedulePatternsRequest() { @@ -403,7 +446,7 @@ void ServerConfigLoader::OnPatternsResponse( HandlePatternsStatus(false); return; } - sequenced_task_runner_->PostTaskAndReplyWithResult( + background_task_runner_->PostTaskAndReplyWithResult( FROM_HERE, base::BindOnce(&GunzipContents, *response_body), base::BindOnce(&ServerConfigLoader::OnPatternsGunzip, weak_ptr_factory_.GetWeakPtr())); @@ -416,14 +459,33 @@ void ServerConfigLoader::OnPatternsGunzip( HandlePatternsStatus(false); return; } - auto parsed_patterns = ParsePatterns(*patterns_json); + background_task_runner_->PostTaskAndReplyWithResult( + FROM_HERE, base::BindOnce(&ParsePatterns, *patterns_json), + base::BindOnce(&ServerConfigLoader::OnNewPatternsParsed, + weak_ptr_factory_.GetWeakPtr(), *patterns_json)); +} + +void ServerConfigLoader::OnStoredPatternsParsed( + std::unique_ptr parsed_patterns) { + if (!parsed_patterns) { + local_state_->ClearPref(kPatternsRetrievalTime); + SchedulePatternsRequest(); + return; + } + last_loaded_patterns_ = std::move(parsed_patterns); + patterns_callback_.Run(); +} + +void ServerConfigLoader::OnNewPatternsParsed( + std::string new_patterns_json, + std::unique_ptr parsed_patterns) { if (!parsed_patterns) { HandlePatternsStatus(false); return; } - sequenced_task_runner_->PostTaskAndReplyWithResult( + background_task_runner_->PostTaskAndReplyWithResult( FROM_HERE, - base::BindOnce(&WritePatternsFile, patterns_path_, *patterns_json), + base::BindOnce(&WritePatternsFile, patterns_path_, new_patterns_json), base::BindOnce(&ServerConfigLoader::OnPatternsWritten, weak_ptr_factory_.GetWeakPtr(), std::move(parsed_patterns))); diff --git a/components/web_discovery/browser/server_config_loader.h b/components/web_discovery/browser/server_config_loader.h index 10ad1f55fa4d..552735976493 100644 --- a/components/web_discovery/browser/server_config_loader.h +++ b/components/web_discovery/browser/server_config_loader.h @@ -28,7 +28,7 @@ class SimpleURLLoader; namespace web_discovery { -using KeyMap = base::flat_map; +using KeyMap = base::flat_map>; struct SourceMapActionConfig { SourceMapActionConfig(); @@ -58,6 +58,17 @@ struct ServerConfig { std::string location; }; +struct ServerConfigDownloadResult { + ServerConfigDownloadResult(bool is_collector_config, + std::optional response_body); + ~ServerConfigDownloadResult(); + + ServerConfigDownloadResult(const ServerConfigDownloadResult&); + + bool is_collector_config; + std::optional response_body; +}; + // Handles retrieval, updating and caching of the following server // configurations: // - HPN server config: contains public keys, and "source maps" used @@ -93,11 +104,9 @@ class ServerConfigLoader { void SetLastPatternsForTesting(std::unique_ptr patterns); private: - void OnConfigResponses( - std::vector> response_bodies); - bool ProcessConfigResponses( - const std::optional& collector_response_body, - const std::optional& quorum_response_body); + void OnConfigResponsesDownloaded( + std::vector results); + void OnConfigResponsesProcessed(std::unique_ptr config); void LoadStoredPatterns(); void OnPatternsFileLoaded(std::optional patterns_json); @@ -105,13 +114,16 @@ class ServerConfigLoader { void RequestPatterns(); void OnPatternsResponse(std::optional response_body); void OnPatternsGunzip(std::optional patterns_json); + void OnStoredPatternsParsed(std::unique_ptr parsed_patterns); + void OnNewPatternsParsed(std::string new_patterns_json, + std::unique_ptr parsed_patterns); void OnPatternsWritten(std::unique_ptr parsed_group, bool result); void HandlePatternsStatus(bool result); - raw_ptr local_state_; + const raw_ptr local_state_; - scoped_refptr sequenced_task_runner_; + scoped_refptr background_task_runner_; GURL collector_config_url_; GURL quorum_config_url_; diff --git a/components/web_discovery/browser/server_config_loader_unittest.cc b/components/web_discovery/browser/server_config_loader_unittest.cc index fe69cbf7520e..2c7d1689bf6d 100644 --- a/components/web_discovery/browser/server_config_loader_unittest.cc +++ b/components/web_discovery/browser/server_config_loader_unittest.cc @@ -78,22 +78,6 @@ class WebDiscoveryServerConfigLoaderTest : public testing::Test { install_dir_.GetPath().AppendASCII("wdp_patterns.json")); } - base::test::TaskEnvironment task_environment_; - std::unique_ptr server_config_loader_; - base::ScopedTempDir install_dir_; - - size_t hpn_config_requests_made_ = 0; - size_t quorum_config_requests_made_ = 0; - size_t patterns_requests_made_ = 0; - - size_t config_ready_calls_made_ = 0; - size_t patterns_ready_calls_made_ = 0; - - net::HttpStatusCode hpn_config_status_code_ = net::HTTP_OK; - net::HttpStatusCode quorum_config_status_code_ = net::HTTP_OK; - net::HttpStatusCode patterns_status_code_ = net::HTTP_OK; - - private: void HandleRequest(const network::ResourceRequest& request) { url_loader_factory_.ClearResponses(); @@ -120,13 +104,30 @@ class WebDiscoveryServerConfigLoaderTest : public testing::Test { void HandlePatternsReady() { patterns_ready_calls_made_++; } + base::test::TaskEnvironment task_environment_; + TestingPrefServiceSimple local_state_; + + network::TestURLLoaderFactory url_loader_factory_; + scoped_refptr shared_url_loader_factory_; + + base::ScopedTempDir install_dir_; + + size_t hpn_config_requests_made_ = 0; + size_t quorum_config_requests_made_ = 0; + size_t patterns_requests_made_ = 0; + + size_t config_ready_calls_made_ = 0; + size_t patterns_ready_calls_made_ = 0; + + net::HttpStatusCode hpn_config_status_code_ = net::HTTP_OK; + net::HttpStatusCode quorum_config_status_code_ = net::HTTP_OK; + net::HttpStatusCode patterns_status_code_ = net::HTTP_OK; + std::string patterns_gz_contents_; std::string quorum_config_contents_; std::string hpn_config_contents_; - TestingPrefServiceSimple local_state_; - network::TestURLLoaderFactory url_loader_factory_; - scoped_refptr shared_url_loader_factory_; + std::unique_ptr server_config_loader_; }; TEST_F(WebDiscoveryServerConfigLoaderTest, LoadConfigs) { diff --git a/components/web_discovery/browser/util.cc b/components/web_discovery/browser/util.cc index acb38bdf63e2..bd124961518e 100644 --- a/components/web_discovery/browser/util.cc +++ b/components/web_discovery/browser/util.cc @@ -5,6 +5,8 @@ #include "brave/components/web_discovery/browser/util.h" +#include + #include "base/strings/stringprintf.h" #include "base/strings/utf_string_conversions.h" #include "brave/brave_domains/service_domains.h" @@ -46,7 +48,7 @@ GURL GetPatternsEndpoint() { std::unique_ptr CreateResourceRequest(GURL url) { auto resource_request = std::make_unique(); - resource_request->url = url; + resource_request->url = std::move(url); resource_request->credentials_mode = network::mojom::CredentialsMode::kOmit; return resource_request; } diff --git a/components/web_discovery/browser/web_discovery_service.cc b/components/web_discovery/browser/web_discovery_service.cc index 6acb7c733565..848d89ec289b 100644 --- a/components/web_discovery/browser/web_discovery_service.cc +++ b/components/web_discovery/browser/web_discovery_service.cc @@ -18,7 +18,6 @@ #include "components/prefs/scoped_user_pref_update.h" #include "extensions/buildflags/buildflags.h" #include "services/network/public/cpp/shared_url_loader_factory.h" -#include "services/service_manager/public/cpp/interface_provider.h" namespace web_discovery { @@ -66,13 +65,18 @@ void WebDiscoveryService::RegisterProfilePrefs(PrefRegistrySimple* registry) { void WebDiscoveryService::SetExtensionPrefIfNativeDisabled( PrefService* profile_prefs) { #if BUILDFLAG(ENABLE_EXTENSIONS) - if (!base::FeatureList::IsEnabled(features::kWebDiscoveryNative) && + if (!base::FeatureList::IsEnabled(features::kBraveWebDiscoveryNative) && profile_prefs->GetBoolean(kWebDiscoveryNativeEnabled)) { profile_prefs->SetBoolean(kWebDiscoveryExtensionEnabled, true); } #endif } +void WebDiscoveryService::Shutdown() { + Stop(); + pref_change_registrar_.RemoveAll(); +} + void WebDiscoveryService::Start() { if (!server_config_loader_) { server_config_loader_ = std::make_unique( @@ -93,8 +97,9 @@ void WebDiscoveryService::Start() { void WebDiscoveryService::Stop() { server_config_loader_ = nullptr; credential_manager_ = nullptr; +} - profile_prefs_->ClearPref(kWebDiscoveryNativeEnabled); +void WebDiscoveryService::ClearPrefs() { profile_prefs_->ClearPref(kAnonymousCredentialsDict); profile_prefs_->ClearPref(kCredentialRSAPrivateKey); profile_prefs_->ClearPref(kCredentialRSAPublicKey); @@ -105,6 +110,7 @@ void WebDiscoveryService::OnEnabledChange() { Start(); } else { Stop(); + ClearPrefs(); } } diff --git a/components/web_discovery/browser/web_discovery_service.h b/components/web_discovery/browser/web_discovery_service.h index 52cc9b74193c..c30f03917ed0 100644 --- a/components/web_discovery/browser/web_discovery_service.h +++ b/components/web_discovery/browser/web_discovery_service.h @@ -52,9 +52,13 @@ class WebDiscoveryService : public KeyedService { // Relevant for a Griffin/variations rollback. static void SetExtensionPrefIfNativeDisabled(PrefService* profile_prefs); + // KeyedService: + void Shutdown() override; + private: void Start(); void Stop(); + void ClearPrefs(); void OnEnabledChange(); diff --git a/components/web_discovery/common/buildflags/BUILD.gn b/components/web_discovery/buildflags/BUILD.gn similarity index 85% rename from components/web_discovery/common/buildflags/BUILD.gn rename to components/web_discovery/buildflags/BUILD.gn index 5d0fc823c167..2c23f6ce7b16 100644 --- a/components/web_discovery/common/buildflags/BUILD.gn +++ b/components/web_discovery/buildflags/BUILD.gn @@ -3,7 +3,7 @@ # License, v. 2.0. If a copy of the MPL was not distributed with this file, # You can obtain one at https://mozilla.org/MPL/2.0/. -import("//brave/components/web_discovery/common/buildflags/buildflags.gni") +import("//brave/components/web_discovery/buildflags/buildflags.gni") import("//build/buildflag_header.gni") buildflag_header("buildflags") { diff --git a/components/web_discovery/common/buildflags/buildflags.gni b/components/web_discovery/buildflags/buildflags.gni similarity index 100% rename from components/web_discovery/common/buildflags/buildflags.gni rename to components/web_discovery/buildflags/buildflags.gni diff --git a/components/web_discovery/common/BUILD.gn b/components/web_discovery/common/BUILD.gn index a50ead3e6ff5..6ce00a413af9 100644 --- a/components/web_discovery/common/BUILD.gn +++ b/components/web_discovery/common/BUILD.gn @@ -3,7 +3,9 @@ # License, v. 2.0. If a copy of the MPL was not distributed with this file, # You can obtain one at https://mozilla.org/MPL/2.0/. -import("//brave/components/web_discovery/common/buildflags/buildflags.gni") +import("//brave/components/web_discovery/buildflags/buildflags.gni") + +assert(enable_web_discovery_native) static_library("common") { sources = [ @@ -13,6 +15,6 @@ static_library("common") { deps = [ "//base", - "//brave/components/web_discovery/common/buildflags", + "//brave/components/web_discovery/buildflags", ] } diff --git a/components/web_discovery/common/features.cc b/components/web_discovery/common/features.cc index 6eec658cfc8b..b22c1cf7a31c 100644 --- a/components/web_discovery/common/features.cc +++ b/components/web_discovery/common/features.cc @@ -7,7 +7,7 @@ namespace web_discovery::features { -BASE_FEATURE(kWebDiscoveryNative, +BASE_FEATURE(kBraveWebDiscoveryNative, "BraveWebDiscoveryNative", base::FEATURE_DISABLED_BY_DEFAULT); diff --git a/components/web_discovery/common/features.h b/components/web_discovery/common/features.h index 45db0afd2555..98349380cf5e 100644 --- a/components/web_discovery/common/features.h +++ b/components/web_discovery/common/features.h @@ -12,7 +12,7 @@ namespace web_discovery::features { // Enables the native re-implementation of the Web Discovery Project. // If enabled, the Web Discovery component of the extension should be disabled. -BASE_DECLARE_FEATURE(kWebDiscoveryNative); +BASE_DECLARE_FEATURE(kBraveWebDiscoveryNative); } // namespace web_discovery::features diff --git a/test/BUILD.gn b/test/BUILD.gn index 5c3f1fad0810..842b873e57e9 100644 --- a/test/BUILD.gn +++ b/test/BUILD.gn @@ -17,7 +17,7 @@ import("//brave/components/playlist/common/buildflags/buildflags.gni") import("//brave/components/request_otr/common/buildflags/buildflags.gni") import("//brave/components/speedreader/common/buildflags/buildflags.gni") import("//brave/components/tor/buildflags/buildflags.gni") -import("//brave/components/web_discovery/common/buildflags/buildflags.gni") +import("//brave/components/web_discovery/buildflags/buildflags.gni") import("//brave/test/testing.gni") import("//brave/updater/config.gni") import("//chrome/common/features.gni")