Skip to content

Commit

Permalink
Address Web Discovery feedback
Browse files Browse the repository at this point in the history
  • Loading branch information
DJAndries committed Oct 25, 2024
1 parent a7d9315 commit 4aaff0b
Show file tree
Hide file tree
Showing 9 changed files with 58 additions and 67 deletions.
20 changes: 8 additions & 12 deletions browser/web_discovery/web_discovery_service_factory.cc
Original file line number Diff line number Diff line change
Expand Up @@ -17,17 +17,6 @@

namespace web_discovery {

namespace {

ProfileSelections GetProfileSelections() {
if (!base::FeatureList::IsEnabled(features::kBraveWebDiscoveryNative)) {
return ProfileSelections::BuildNoProfilesSelected();
}
return ProfileSelections::BuildForRegularProfile();
}

} // namespace

WebDiscoveryService* WebDiscoveryServiceFactory::GetForBrowserContext(
content::BrowserContext* context) {
return static_cast<WebDiscoveryService*>(
Expand All @@ -41,10 +30,17 @@ WebDiscoveryServiceFactory* WebDiscoveryServiceFactory::GetInstance() {

WebDiscoveryServiceFactory::WebDiscoveryServiceFactory()
: ProfileKeyedServiceFactory("WebDiscoveryService",
GetProfileSelections()) {}
CreateProfileSelections()) {}

WebDiscoveryServiceFactory::~WebDiscoveryServiceFactory() = default;

ProfileSelections WebDiscoveryServiceFactory::CreateProfileSelections() {
if (!base::FeatureList::IsEnabled(features::kBraveWebDiscoveryNative)) {
return ProfileSelections::BuildNoProfilesSelected();
}
return ProfileSelections::BuildForRegularProfile();
}

KeyedService* WebDiscoveryServiceFactory::BuildServiceInstanceFor(
content::BrowserContext* context) const {
auto* default_storage_partition = context->GetDefaultStoragePartition();
Expand Down
2 changes: 2 additions & 0 deletions browser/web_discovery/web_discovery_service_factory.h
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,8 @@ class WebDiscoveryServiceFactory : public ProfileKeyedServiceFactory {
WebDiscoveryServiceFactory& operator=(const WebDiscoveryServiceFactory&) =
delete;

static ProfileSelections CreateProfileSelections();

KeyedService* BuildServiceInstanceFor(
content::BrowserContext* context) const override;
bool ServiceIsCreatedWithBrowserContext() const override;
Expand Down
2 changes: 1 addition & 1 deletion components/web_discovery/browser/BUILD.gn
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ component("browser") {
"web_discovery_service.h",
]
deps = [
"anonymous_credentials/rs/cxx:rust_lib",
"anonymous_credentials/rust:rust_lib",
"//base",
"//brave/brave_domains",
"//brave/components/constants",
Expand Down
15 changes: 7 additions & 8 deletions components/web_discovery/browser/background_credential_helper.cc
Original file line number Diff line number Diff line change
Expand Up @@ -62,9 +62,9 @@ BackgroundCredentialHelper::GenerateJoinRequest(std::string pre_challenge) {

std::optional<std::string> BackgroundCredentialHelper::FinishJoin(
std::string date,
std::vector<const uint8_t> group_pub_key,
std::vector<const uint8_t> gsk,
std::vector<const uint8_t> join_resp_bytes) {
std::vector<uint8_t> group_pub_key,
std::vector<uint8_t> gsk,
std::vector<uint8_t> join_resp_bytes) {
base::AssertLongCPUWorkAllowed();
auto pub_key_result = anonymous_credentials::load_group_public_key(
base::SpanToRustSlice(group_pub_key));
Expand Down Expand Up @@ -93,10 +93,9 @@ std::optional<std::string> BackgroundCredentialHelper::FinishJoin(
return base::Base64Encode(finish_res.data);
}

std::optional<std::vector<const uint8_t>>
BackgroundCredentialHelper::PerformSign(
std::vector<const uint8_t> msg,
std::vector<const uint8_t> basename,
std::optional<std::vector<uint8_t>> BackgroundCredentialHelper::PerformSign(
std::vector<uint8_t> msg,
std::vector<uint8_t> basename,
std::optional<std::vector<uint8_t>> gsk_bytes,
std::optional<std::vector<uint8_t>> credential_bytes) {
base::AssertLongCPUWorkAllowed();
Expand All @@ -122,7 +121,7 @@ BackgroundCredentialHelper::PerformSign(
VLOG(1) << "Failed to sign: " << sig_res.error_message.c_str();
return std::nullopt;
}
return std::vector<const uint8_t>(sig_res.data.begin(), sig_res.data.end());
return std::vector<uint8_t>(sig_res.data.begin(), sig_res.data.end());
}

} // namespace web_discovery
17 changes: 8 additions & 9 deletions components/web_discovery/browser/background_credential_helper.h
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
#include <string>
#include <vector>

#include "brave/components/web_discovery/browser/anonymous_credentials/rs/cxx/src/lib.rs.h"
#include "brave/components/web_discovery/browser/anonymous_credentials/rust/src/lib.rs.h"
#include "brave/components/web_discovery/browser/rsa.h"
#include "crypto/rsa_private_key.h"

Expand All @@ -37,14 +37,13 @@ class BackgroundCredentialHelper {
void SetRSAKey(std::unique_ptr<crypto::RSAPrivateKey> rsa_private_key);
std::optional<GenerateJoinRequestResult> GenerateJoinRequest(
std::string pre_challenge);
std::optional<std::string> FinishJoin(
std::string date,
std::vector<const uint8_t> group_pub_key,
std::vector<const uint8_t> gsk,
std::vector<const uint8_t> join_resp_bytes);
std::optional<std::vector<const uint8_t>> PerformSign(
std::vector<const uint8_t> msg,
std::vector<const uint8_t> basename,
std::optional<std::string> FinishJoin(std::string date,
std::vector<uint8_t> group_pub_key,
std::vector<uint8_t> gsk,
std::vector<uint8_t> join_resp_bytes);
std::optional<std::vector<uint8_t>> PerformSign(
std::vector<uint8_t> msg,
std::vector<uint8_t> basename,
std::optional<std::vector<uint8_t>> gsk_bytes,
std::optional<std::vector<uint8_t>> credential_bytes);

Expand Down
33 changes: 14 additions & 19 deletions components/web_discovery/browser/credential_manager.cc
Original file line number Diff line number Diff line change
Expand Up @@ -143,9 +143,6 @@ void CredentialManager::JoinGroups() {
void CredentialManager::StartJoinGroup(
const std::string& date,
const std::vector<uint8_t>& group_pub_key) {
std::vector<const uint8_t> group_pub_key_const(group_pub_key.begin(),
group_pub_key.end());

auto challenge_elements = base::Value::List::with_capacity(2);
challenge_elements.Append(*rsa_public_key_b64_);
challenge_elements.Append(base::Base64Encode(group_pub_key));
Expand All @@ -158,12 +155,12 @@ void CredentialManager::StartJoinGroup(
.WithArgs(pre_challenge)
.Then(base::BindOnce(&CredentialManager::OnJoinRequestReady,
weak_ptr_factory_.GetWeakPtr(), date,
group_pub_key_const));
group_pub_key));
}

void CredentialManager::OnJoinRequestReady(
std::string date,
std::vector<const uint8_t> group_pub_key,
std::vector<uint8_t> group_pub_key,
std::optional<GenerateJoinRequestResult> generate_join_result) {
if (!generate_join_result) {
return;
Expand All @@ -183,9 +180,9 @@ void CredentialManager::OnJoinRequestReady(
return;
}

auto gsk = std::vector<const uint8_t>(
generate_join_result->start_join_result.gsk.begin(),
generate_join_result->start_join_result.gsk.end());
auto gsk =
std::vector<uint8_t>(generate_join_result->start_join_result.gsk.begin(),
generate_join_result->start_join_result.gsk.end());

auto resource_request = CreateResourceRequest(join_url_);
resource_request->headers.SetHeader(kVersionHeader,
Expand All @@ -207,8 +204,8 @@ void CredentialManager::OnJoinRequestReady(

void CredentialManager::OnJoinResponse(
std::string date,
std::vector<const uint8_t> group_pub_key,
std::vector<const uint8_t> gsk,
std::vector<uint8_t> group_pub_key,
std::vector<uint8_t> gsk,
std::optional<std::string> response_body) {
bool result = ProcessJoinResponse(date, group_pub_key, gsk, response_body);
if (!result) {
Expand All @@ -235,8 +232,8 @@ void CredentialManager::HandleJoinResponseStatus(const std::string& date,

bool CredentialManager::ProcessJoinResponse(
const std::string& date,
const std::vector<const uint8_t>& group_pub_key,
const std::vector<const uint8_t>& gsk,
const std::vector<uint8_t>& group_pub_key,
const std::vector<uint8_t>& gsk,
const std::optional<std::string>& response_body) {
CHECK(join_url_loaders_[date]);
auto& url_loader = join_url_loaders_[date];
Expand Down Expand Up @@ -272,20 +269,18 @@ bool CredentialManager::ProcessJoinResponse(
VLOG(1) << "Failed to decode join response base64";
return false;
}
std::vector<const uint8_t> join_resp_bytes_const(join_resp_bytes->begin(),
join_resp_bytes->end());

background_credential_helper_
.AsyncCall(&BackgroundCredentialHelper::FinishJoin)
.WithArgs(date, group_pub_key, gsk, join_resp_bytes_const)
.WithArgs(date, group_pub_key, gsk, *join_resp_bytes)
.Then(base::BindOnce(&CredentialManager::OnCredentialsReady,
weak_ptr_factory_.GetWeakPtr(), date, gsk));
return true;
}

void CredentialManager::OnCredentialsReady(
std::string date,
std::vector<const uint8_t> gsk,
std::vector<uint8_t> gsk,
std::optional<std::string> credentials) {
if (!credentials) {
HandleJoinResponseStatus(date, false);
Expand All @@ -303,8 +298,8 @@ bool CredentialManager::CredentialExistsForToday() {
.contains(FormatServerDate(base::Time::Now()));
}

void CredentialManager::Sign(std::vector<const uint8_t> msg,
std::vector<const uint8_t> basename,
void CredentialManager::Sign(std::vector<uint8_t> msg,
std::vector<uint8_t> basename,
SignCallback callback) {
auto today_date = FormatServerDate(base::Time::Now().UTCMidnight());
const auto& anon_creds_dict =
Expand Down Expand Up @@ -345,7 +340,7 @@ void CredentialManager::Sign(std::vector<const uint8_t> msg,
void CredentialManager::OnSignResult(
std::string credential_date,
SignCallback callback,
std::optional<std::vector<const uint8_t>> signed_message) {
std::optional<std::vector<uint8_t>> signed_message) {
loaded_credential_date_ = credential_date;
std::move(callback).Run(signed_message);
}
Expand Down
18 changes: 9 additions & 9 deletions components/web_discovery/browser/credential_manager.h
Original file line number Diff line number Diff line change
Expand Up @@ -56,8 +56,8 @@ class CredentialManager : public CredentialSigner {
// CredentialSigner:
bool CredentialExistsForToday() override;

void Sign(std::vector<const uint8_t> msg,
std::vector<const uint8_t> basename,
void Sign(std::vector<uint8_t> msg,
std::vector<uint8_t> basename,
SignCallback callback) override;

// Uses a fixed seed in the anonymous credential manager
Expand All @@ -75,25 +75,25 @@ class CredentialManager : public CredentialSigner {

void OnJoinRequestReady(
std::string date,
std::vector<const uint8_t> group_pub_key,
std::vector<uint8_t> group_pub_key,
std::optional<GenerateJoinRequestResult> generate_join_result);

void OnJoinResponse(std::string date,
std::vector<const uint8_t> group_pub_key,
std::vector<const uint8_t> gsk,
std::vector<uint8_t> group_pub_key,
std::vector<uint8_t> gsk,
std::optional<std::string> response_body);
void HandleJoinResponseStatus(const std::string& date, bool result);
bool ProcessJoinResponse(const std::string& date,
const std::vector<const uint8_t>& group_pub_key,
const std::vector<const uint8_t>& gsk,
const std::vector<uint8_t>& group_pub_key,
const std::vector<uint8_t>& gsk,
const std::optional<std::string>& response_body);
void OnCredentialsReady(std::string date,
std::vector<const uint8_t> gsk,
std::vector<uint8_t> gsk,
std::optional<std::string> credentials);

void OnSignResult(std::string credential_date,
SignCallback callback,
std::optional<std::vector<const uint8_t>> signed_message);
std::optional<std::vector<uint8_t>> signed_message);

const raw_ptr<PrefService> profile_prefs_;
const raw_ptr<network::SharedURLLoaderFactory> shared_url_loader_factory_;
Expand Down
12 changes: 6 additions & 6 deletions components/web_discovery/browser/credential_manager_unittest.cc
Original file line number Diff line number Diff line change
Expand Up @@ -162,25 +162,25 @@ TEST_F(WebDiscoveryCredentialManagerTest, LoadKeysFromStorage) {
}

TEST_F(WebDiscoveryCredentialManagerTest, Sign) {
std::vector<const uint8_t> message({0, 1, 2, 3, 4});
std::vector<const uint8_t> basename({5, 6, 7, 8, 9});
std::vector<uint8_t> message({0, 1, 2, 3, 4});
std::vector<uint8_t> basename({5, 6, 7, 8, 9});
credential_manager_->Sign(
message, basename,
base::BindLambdaForTesting(
[&](const std::optional<std::vector<const uint8_t>> signature) {
[&](const std::optional<std::vector<uint8_t>> signature) {
EXPECT_FALSE(signature);
}));
task_environment_.RunUntilIdle();

credential_manager_->JoinGroups();
task_environment_.RunUntilIdle();

base::flat_set<std::vector<const uint8_t>> signatures;
base::flat_set<std::vector<uint8_t>> signatures;
for (size_t i = 0; i < 3; i++) {
credential_manager_->Sign(
message, basename,
base::BindLambdaForTesting(
[&](const std::optional<std::vector<const uint8_t>> signature) {
[&](const std::optional<std::vector<uint8_t>> signature) {
ASSERT_TRUE(signature);
EXPECT_FALSE(signature->empty());
EXPECT_FALSE(signatures.contains(*signature));
Expand All @@ -191,7 +191,7 @@ TEST_F(WebDiscoveryCredentialManagerTest, Sign) {
credential_manager_->Sign(
message, basename,
base::BindLambdaForTesting(
[&](const std::optional<std::vector<const uint8_t>> signature) {
[&](const std::optional<std::vector<uint8_t>> signature) {
EXPECT_FALSE(signature);
}));
task_environment_.RunUntilIdle();
Expand Down
6 changes: 3 additions & 3 deletions components/web_discovery/browser/credential_signer.h
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ namespace web_discovery {
class CredentialSigner {
public:
using SignCallback =
base::OnceCallback<void(std::optional<std::vector<const uint8_t>>)>;
base::OnceCallback<void(std::optional<std::vector<uint8_t>>)>;
virtual ~CredentialSigner() = default;

// Returns true is a credential is available for the current date.
Expand All @@ -29,8 +29,8 @@ class CredentialSigner {
// preventing Sybil attacks.
// See signature_basename.h/cc for more information on how the basename
// should be generated.
virtual void Sign(std::vector<const uint8_t> msg,
std::vector<const uint8_t> basename,
virtual void Sign(std::vector<uint8_t> msg,
std::vector<uint8_t> basename,
SignCallback callback) = 0;
};

Expand Down

0 comments on commit 4aaff0b

Please sign in to comment.