Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Uplift 1.62] [Brave News]: Add hierarchical sampling for news cards feed composition (#21111) #21374

Merged
merged 1 commit into from
Dec 20, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
265 changes: 230 additions & 35 deletions components/brave_news/browser/feed_v2_builder.cc
Original file line number Diff line number Diff line change
Expand Up @@ -69,6 +69,11 @@ struct ArticleWeight {
using ArticleInfo = std::tuple<mojom::FeedItemMetadataPtr, ArticleWeight>;
using ArticleInfos = std::vector<ArticleInfo>;

/* publisher_or_channel_id, is_channel */
using ContentGroup = std::pair<std::string, bool>;
constexpr char kAllContentGroup[] = "all";
constexpr float kSampleContentGroupAllRatio = 0.2f;

std::string GetFeedHash(const Channels& channels,
const Publishers& publishers,
const ETags& etags) {
Expand Down Expand Up @@ -140,9 +145,8 @@ double GetPopRecency(const mojom::FeedItemMetadataPtr& article) {

auto& publish_time = article->publish_time;

double popularity = article->pop_score == 0
? features::kBraveNewsPopScoreFallback.Get()
: article->pop_score;
double popularity = std::min(article->pop_score, 100.0) / 100.0 +
features::kBraveNewsPopScoreMin.Get();
double multiplier = publish_time > base::Time::Now() - base::Hours(5) ? 2 : 1;
auto dt = base::Time::Now() - publish_time;

Expand All @@ -168,17 +172,19 @@ ArticleWeight GetArticleWeight(const mojom::FeedItemMetadataPtr& article,
const double source_visits_projected =
source_visits_min + signals.at(0)->visit_weight * (1 - source_visits_min);
const auto pop_recency = GetPopRecency(article);

return {
.pop_recency = pop_recency,
.weighting = source_visits_projected * subscribed_weight * pop_recency,
.weighting = source_visits_projected + subscribed_weight + pop_recency,
// Note: GetArticleWeight returns the Signal for the Publisher first, and
// we use that to determine whether this Publisher has ever been visited.
.visited = signals.at(0)->visit_weight != 0,
.subscribed = subscribed_weight != 0,
};
}

std::string PickRandom(const std::vector<std::string>& items) {
template <typename T>
T PickRandom(const std::vector<T>& items) {
CHECK(!items.empty());
// Note: RandInt is inclusive, hence the minus 1
return items[base::RandInt(0, items.size() - 1)];
Expand All @@ -190,6 +196,7 @@ ArticleInfos GetArticleInfos(const std::string& locale,
const Signals& signals) {
ArticleInfos articles;
base::flat_set<GURL> seen_articles;

for (const auto& item : feed_items) {
if (item.is_null()) {
continue;
Expand Down Expand Up @@ -219,12 +226,29 @@ ArticleInfos GetArticleInfos(const std::string& locale,
ArticleInfo pair =
std::tuple(article->data->Clone(),
GetArticleWeight(article->data, article_signals));

articles.push_back(std::move(pair));
}
}

return articles;
}

std::vector<std::string> GetChannelsForPublisher(
const std::string& locale,
const mojom::PublisherPtr& publisher) {
std::vector<std::string> result;
for (const auto& locale_info : publisher->locales) {
if (locale_info->locale != locale) {
continue;
}
for (const auto& channel : locale_info->channels) {
result.push_back(channel);
}
}
return result;
}

// Randomly true/false with equal probability.
bool TossCoin() {
return base::RandDouble() < 0.5;
Expand Down Expand Up @@ -261,32 +285,76 @@ int GetNormal(int min, int max) {
return min + floor((max - min) * GetNormal());
}

using GetWeighting = double(const mojom::FeedItemMetadataPtr& article,
const ArticleWeight& weight);
using GetWeighting =
base::RepeatingCallback<double(const mojom::FeedItemMetadataPtr& metadata,
const ArticleWeight& weight)>;

// Returns a probability distribution (sum to 1) of the weights. Temperature
// controls how "smooth" the distribution is. High temperature brings the
// distribution closer to a uniform distribution (more randomness).
// Low temperature brings the distribution closer to a delta function (less
// randomness).
void SoftmaxWithTemperature(
std::vector<double>& weights,
double temperature = features::kBraveNewsTemperature.Get()) {
if (temperature == 0) {
return;
}

double max = *base::ranges::max_element(weights.begin(), weights.end());
base::ranges::transform(weights.begin(), weights.end(), weights.begin(),
[temperature, max](double weight) {
return std::exp((weight - max) / temperature);
});
double sum = std::accumulate(weights.begin(), weights.end(), 0.0);
base::ranges::transform(weights.begin(), weights.end(), weights.begin(),
[sum](double weight) { return weight / sum; });
}

// Sample across subscribed channels (direct and native) and publishers.
ContentGroup SampleContentGroup(
const std::vector<ContentGroup>& eligible_content_groups) {
ContentGroup sampled_content_group;
if (eligible_content_groups.empty()) {
return sampled_content_group;
}

if (base::RandDouble() < kSampleContentGroupAllRatio) {
return std::make_pair(kAllContentGroup, true);
}
return PickRandom<ContentGroup>(eligible_content_groups);
}

// Picks an article with a probability article_weight/sum(article_weights).
mojom::FeedItemMetadataPtr PickRouletteAndRemove(
ArticleInfos& articles,
GetWeighting get_weighting = [](const auto& article, const auto& weight) {
return weight.weighting;
}) {
double total_weight = 0;
for (const auto& [article, weight] : articles) {
total_weight += get_weighting(article, weight);
}
GetWeighting get_weighting = base::BindRepeating(
[](const mojom::FeedItemMetadataPtr& metadata,
const ArticleWeight& weight) { return weight.weighting; }),
bool use_softmax = false) {
std::vector<double> weights;
base::ranges::transform(articles, std::back_inserter(weights),
[&get_weighting](const auto& article_info) {
return get_weighting.Run(std::get<0>(article_info),
std::get<1>(article_info));
});

// None of the items are eligible to be picked.
if (total_weight == 0) {
if (std::accumulate(weights.begin(), weights.end(), 0.0) == 0) {
return nullptr;
}

if (use_softmax) {
SoftmaxWithTemperature(weights);
}

double total_weight = std::accumulate(weights.begin(), weights.end(), 0.0);
double picked_value = base::RandDouble() * total_weight;
double current_weight = 0;

uint64_t i;
for (i = 0; i < articles.size(); ++i) {
auto& [article, weight] = articles[i];
current_weight += get_weighting(article, weight);
for (i = 0; i < weights.size(); ++i) {
current_weight += weights[i];
if (current_weight > picked_value) {
break;
}
Expand All @@ -304,13 +372,15 @@ mojom::FeedItemMetadataPtr PickRouletteAndRemove(
// 2. **AND** The user hasn't visited.
mojom::FeedItemMetadataPtr PickDiscoveryArticleAndRemove(
ArticleInfos& articles) {
return PickRouletteAndRemove(articles,
[](const auto& article, const auto& weight) {
if (weight.subscribed || weight.visited) {
return 0.;
}
return weight.pop_recency;
});
return PickRouletteAndRemove(
articles,
base::BindRepeating([](const mojom::FeedItemMetadataPtr& metadata,
const ArticleWeight& weight) {
if (weight.subscribed) {
return 0.0;
}
return weight.pop_recency;
}));
}

// Generates a standard block:
Expand All @@ -331,12 +401,14 @@ std::vector<mojom::FeedItemV2Ptr> GenerateBlock(
}

auto hero_article = PickRouletteAndRemove(
articles, [](const auto& article, const auto& weight) {
auto image_url = article->image->is_padded_image_url()
? article->image->get_padded_image_url()
: article->image->get_image_url();
articles,
base::BindRepeating([](const mojom::FeedItemMetadataPtr& metadata,
const ArticleWeight& weight) {
auto image_url = metadata->image->is_padded_image_url()
? metadata->image->get_padded_image_url()
: metadata->image->get_image_url();
return image_url.is_valid() ? weight.weighting : 0;
});
}));

// We might not be able to generate a hero card, if none of the articles in
// this feed have an image.
Expand All @@ -350,9 +422,115 @@ std::vector<mojom::FeedItemV2Ptr> GenerateBlock(
auto follow_count = GetNormal(block_min_inline, block_max_inline + 1);
for (auto i = 0; i < follow_count; ++i) {
bool is_discover = base::RandDouble() < inline_discovery_ratio;
auto generated = is_discover ? PickDiscoveryArticleAndRemove(articles)
: PickRouletteAndRemove(articles);
mojom::FeedItemMetadataPtr generated;

if (is_discover) {
generated = PickDiscoveryArticleAndRemove(articles);
} else {
generated = PickRouletteAndRemove(articles);
}

if (!generated) {
DVLOG(1) << "Failed to generate article (is_discover=" << is_discover
<< ")";
continue;
}
result.push_back(mojom::FeedItemV2::NewArticle(
mojom::Article::New(std::move(generated), is_discover)));
}

return result;
}

// Generates a block from sampled content groups:
// 1. Hero Article
// 2. 1 - 5 Inline Articles (a percentage of which might be discover cards).
std::vector<mojom::FeedItemV2Ptr> GenerateBlockFromContentGroups(
ArticleInfos& articles,
const std::string& locale,
const Publishers& publishers,
const std::vector<ContentGroup>& eligible_content_groups,
// Ratio of inline articles to discovery articles.
// discover ratio % of the time, we should do a discover card here instead
// of a roulette card.
// https://docs.google.com/document/d/1bSVHunwmcHwyQTpa3ab4KRbGbgNQ3ym_GHvONnrBypg/edit#heading=h.4rkb0vecgekl
double inline_discovery_ratio =
features::kBraveNewsInlineDiscoveryRatio.Get()) {
DVLOG(1) << __FUNCTION__;
std::vector<mojom::FeedItemV2Ptr> result;
if (articles.empty() || eligible_content_groups.empty()) {
return result;
}

base::flat_map<std::string, std::vector<std::string>>
publisher_id_to_channels;
for (const auto& [publisher_id, publisher] : publishers) {
publisher_id_to_channels[publisher_id] =
GetChannelsForPublisher(locale, publisher);
}

// Generates a GetWeighting function tied to a specific content group. Each
// invocation of |get_weighting| will generate a new |GetWeighting| tied to a
// (freshly sampled) content_group.
auto get_weighting = [&eligible_content_groups, &publisher_id_to_channels,
&locale](bool is_hero = false) {
return base::BindRepeating(
[](const bool is_hero, const ContentGroup& content_group,
const base::flat_map<std::string, std::vector<std::string>>&
publisher_id_to_channels,
const std::string& locale,
const mojom::FeedItemMetadataPtr& metadata,
const ArticleWeight& weight) {
if (is_hero) {
auto image_url = metadata->image->is_padded_image_url()
? metadata->image->get_padded_image_url()
: metadata->image->get_image_url();
if (!image_url.is_valid()) {
return 0.0;
}
}

if (/*is_channel*/ content_group.second &&
content_group.first != kAllContentGroup) {
auto channels =
publisher_id_to_channels.find(metadata->publisher_id);
if (base::Contains(channels->second, content_group.first)) {
return weight.weighting;
}

return 0.0;
} else if (/*is_channel*/ !content_group.second) {
return metadata->publisher_id == content_group.first
? weight.weighting
: 0.0;
}

return weight.weighting;
},
is_hero, SampleContentGroup(eligible_content_groups),
publisher_id_to_channels, locale);
};

auto hero_article =
PickRouletteAndRemove(articles, get_weighting(/*is_hero*/ true));
if (!hero_article) {
DVLOG(1) << "Failed to generate hero";
return result;
}

result.push_back(mojom::FeedItemV2::NewHero(
mojom::HeroArticle::New(std::move(hero_article))));

const int block_min_inline = features::kBraveNewsMinBlockCards.Get();
const int block_max_inline = features::kBraveNewsMaxBlockCards.Get();
auto follow_count = GetNormal(block_min_inline, block_max_inline + 1);
for (auto i = 0; i < follow_count; ++i) {
bool is_discover = base::RandDouble() < inline_discovery_ratio;
auto generated = is_discover
? PickDiscoveryArticleAndRemove(articles)
: PickRouletteAndRemove(articles, get_weighting());
if (!generated) {
DVLOG(1) << "Failed to generate article";
continue;
}
result.push_back(mojom::FeedItemV2::NewArticle(
Expand Down Expand Up @@ -989,10 +1167,12 @@ mojom::FeedV2Ptr FeedV2Builder::GenerateAllFeed() {
// what channel cards to show.
Channels channels =
channels_controller_->GetChannelsFromPublishers(publishers, &*prefs_);

std::vector<std::string> subscribed_channels;
for (const auto& [id, channel] : channels) {
if (base::Contains(channel->subscribed_locales, locale)) {
subscribed_channels.push_back(id);
DVLOG(1) << "Subscribed to channel: " << id;
}
}

Expand All @@ -1008,9 +1188,22 @@ mojom::FeedV2Ptr FeedV2Builder::GenerateAllFeed() {
base::ranges::move(items, std::back_inserter(feed->items));
};

std::vector<ContentGroup> eligible_content_groups;
for (const auto& channel_id : subscribed_channels) {
eligible_content_groups.push_back(std::make_pair(channel_id, true));
}
for (const auto& [publisher_id, publisher] : publishers) {
if (publisher->user_enabled_status == mojom::UserEnabled::ENABLED) {
eligible_content_groups.push_back(std::make_pair(publisher_id, false));
DVLOG(1) << "Subscribed to publisher: " << publisher->publisher_name;
}
}

// Step 1: Generate a block
// https://docs.google.com/document/d/1bSVHunwmcHwyQTpa3ab4KRbGbgNQ3ym_GHvONnrBypg/edit#heading=h.rkq699fwps0
auto initial_block = GenerateBlock(articles);
std::vector<mojom::FeedItemV2Ptr> initial_block =
GenerateBlockFromContentGroups(articles, locale, publishers,
eligible_content_groups);
DVLOG(1) << "Step 1: Standard Block (" << initial_block.size()
<< " articles)";
add_items(initial_block);
Expand Down Expand Up @@ -1040,14 +1233,16 @@ mojom::FeedV2Ptr FeedV2Builder::GenerateAllFeed() {
// https://docs.google.com/document/d/1bSVHunwmcHwyQTpa3ab4KRbGbgNQ3ym_GHvONnrBypg/edit#heading=h.os2ze8cesd8v
if (iteration_type == 0) {
DVLOG(1) << "Step 4: Standard Block";
items = GenerateBlock(articles);
items = GenerateBlockFromContentGroups(articles, locale, publishers,
eligible_content_groups);
} else if (iteration_type == 1) {
// Step 5: Block or Cluster Generation
// https://docs.google.com/document/d/1bSVHunwmcHwyQTpa3ab4KRbGbgNQ3ym_GHvONnrBypg/edit#heading=h.tpvsjkq0lzmy
// Half the time, a normal block
if (TossCoin()) {
DVLOG(1) << "Step 5: Standard Block";
items = GenerateBlock(articles);
items = GenerateBlockFromContentGroups(articles, locale, publishers,
eligible_content_groups);
} else {
items = GenerateClusterBlock(locale, publishers, subscribed_channels,
topics, articles);
Expand Down
Loading