Skip to content

Commit

Permalink
Add hierarchical sampling for news cards feed composition (#21111)
Browse files Browse the repository at this point in the history
Co-authored-by: Jay Harris <jay.harris@outlook.co.nz>
  • Loading branch information
LorenzoMinto and fallaciousreasoning committed Dec 20, 2023
1 parent b93b21d commit b649d63
Show file tree
Hide file tree
Showing 3 changed files with 243 additions and 40 deletions.
265 changes: 230 additions & 35 deletions components/brave_news/browser/feed_v2_builder.cc
Original file line number Diff line number Diff line change
Expand Up @@ -69,6 +69,11 @@ struct ArticleWeight {
using ArticleInfo = std::tuple<mojom::FeedItemMetadataPtr, ArticleWeight>;
using ArticleInfos = std::vector<ArticleInfo>;

/* publisher_or_channel_id, is_channel */
using ContentGroup = std::pair<std::string, bool>;
constexpr char kAllContentGroup[] = "all";
constexpr float kSampleContentGroupAllRatio = 0.2f;

std::string GetFeedHash(const Channels& channels,
const Publishers& publishers,
const ETags& etags) {
Expand Down Expand Up @@ -140,9 +145,8 @@ double GetPopRecency(const mojom::FeedItemMetadataPtr& article) {

auto& publish_time = article->publish_time;

double popularity = article->pop_score == 0
? features::kBraveNewsPopScoreFallback.Get()
: article->pop_score;
double popularity = std::min(article->pop_score, 100.0) / 100.0 +
features::kBraveNewsPopScoreMin.Get();
double multiplier = publish_time > base::Time::Now() - base::Hours(5) ? 2 : 1;
auto dt = base::Time::Now() - publish_time;

Expand All @@ -168,17 +172,19 @@ ArticleWeight GetArticleWeight(const mojom::FeedItemMetadataPtr& article,
const double source_visits_projected =
source_visits_min + signals.at(0)->visit_weight * (1 - source_visits_min);
const auto pop_recency = GetPopRecency(article);

return {
.pop_recency = pop_recency,
.weighting = source_visits_projected * subscribed_weight * pop_recency,
.weighting = source_visits_projected + subscribed_weight + pop_recency,
// Note: GetArticleWeight returns the Signal for the Publisher first, and
// we use that to determine whether this Publisher has ever been visited.
.visited = signals.at(0)->visit_weight != 0,
.subscribed = subscribed_weight != 0,
};
}

std::string PickRandom(const std::vector<std::string>& items) {
template <typename T>
T PickRandom(const std::vector<T>& items) {
CHECK(!items.empty());
// Note: RandInt is inclusive, hence the minus 1
return items[base::RandInt(0, items.size() - 1)];
Expand All @@ -190,6 +196,7 @@ ArticleInfos GetArticleInfos(const std::string& locale,
const Signals& signals) {
ArticleInfos articles;
base::flat_set<GURL> seen_articles;

for (const auto& item : feed_items) {
if (item.is_null()) {
continue;
Expand Down Expand Up @@ -219,12 +226,29 @@ ArticleInfos GetArticleInfos(const std::string& locale,
ArticleInfo pair =
std::tuple(article->data->Clone(),
GetArticleWeight(article->data, article_signals));

articles.push_back(std::move(pair));
}
}

return articles;
}

std::vector<std::string> GetChannelsForPublisher(
const std::string& locale,
const mojom::PublisherPtr& publisher) {
std::vector<std::string> result;
for (const auto& locale_info : publisher->locales) {
if (locale_info->locale != locale) {
continue;
}
for (const auto& channel : locale_info->channels) {
result.push_back(channel);
}
}
return result;
}

// Randomly true/false with equal probability.
bool TossCoin() {
return base::RandDouble() < 0.5;
Expand Down Expand Up @@ -261,32 +285,76 @@ int GetNormal(int min, int max) {
return min + floor((max - min) * GetNormal());
}

using GetWeighting = double(const mojom::FeedItemMetadataPtr& article,
const ArticleWeight& weight);
using GetWeighting =
base::RepeatingCallback<double(const mojom::FeedItemMetadataPtr& metadata,
const ArticleWeight& weight)>;

// Returns a probability distribution (sum to 1) of the weights. Temperature
// controls how "smooth" the distribution is. High temperature brings the
// distribution closer to a uniform distribution (more randomness).
// Low temperature brings the distribution closer to a delta function (less
// randomness).
void SoftmaxWithTemperature(
std::vector<double>& weights,
double temperature = features::kBraveNewsTemperature.Get()) {
if (temperature == 0) {
return;
}

double max = *base::ranges::max_element(weights.begin(), weights.end());
base::ranges::transform(weights.begin(), weights.end(), weights.begin(),
[temperature, max](double weight) {
return std::exp((weight - max) / temperature);
});
double sum = std::accumulate(weights.begin(), weights.end(), 0.0);
base::ranges::transform(weights.begin(), weights.end(), weights.begin(),
[sum](double weight) { return weight / sum; });
}

// Sample across subscribed channels (direct and native) and publishers.
ContentGroup SampleContentGroup(
const std::vector<ContentGroup>& eligible_content_groups) {
ContentGroup sampled_content_group;
if (eligible_content_groups.empty()) {
return sampled_content_group;
}

if (base::RandDouble() < kSampleContentGroupAllRatio) {
return std::make_pair(kAllContentGroup, true);
}
return PickRandom<ContentGroup>(eligible_content_groups);
}

// Picks an article with a probability article_weight/sum(article_weights).
mojom::FeedItemMetadataPtr PickRouletteAndRemove(
ArticleInfos& articles,
GetWeighting get_weighting = [](const auto& article, const auto& weight) {
return weight.weighting;
}) {
double total_weight = 0;
for (const auto& [article, weight] : articles) {
total_weight += get_weighting(article, weight);
}
GetWeighting get_weighting = base::BindRepeating(
[](const mojom::FeedItemMetadataPtr& metadata,
const ArticleWeight& weight) { return weight.weighting; }),
bool use_softmax = false) {
std::vector<double> weights;
base::ranges::transform(articles, std::back_inserter(weights),
[&get_weighting](const auto& article_info) {
return get_weighting.Run(std::get<0>(article_info),
std::get<1>(article_info));
});

// None of the items are eligible to be picked.
if (total_weight == 0) {
if (std::accumulate(weights.begin(), weights.end(), 0.0) == 0) {
return nullptr;
}

if (use_softmax) {
SoftmaxWithTemperature(weights);
}

double total_weight = std::accumulate(weights.begin(), weights.end(), 0.0);
double picked_value = base::RandDouble() * total_weight;
double current_weight = 0;

uint64_t i;
for (i = 0; i < articles.size(); ++i) {
auto& [article, weight] = articles[i];
current_weight += get_weighting(article, weight);
for (i = 0; i < weights.size(); ++i) {
current_weight += weights[i];
if (current_weight > picked_value) {
break;
}
Expand All @@ -304,13 +372,15 @@ mojom::FeedItemMetadataPtr PickRouletteAndRemove(
// 2. **AND** The user hasn't visited.
mojom::FeedItemMetadataPtr PickDiscoveryArticleAndRemove(
ArticleInfos& articles) {
return PickRouletteAndRemove(articles,
[](const auto& article, const auto& weight) {
if (weight.subscribed || weight.visited) {
return 0.;
}
return weight.pop_recency;
});
return PickRouletteAndRemove(
articles,
base::BindRepeating([](const mojom::FeedItemMetadataPtr& metadata,
const ArticleWeight& weight) {
if (weight.subscribed) {
return 0.0;
}
return weight.pop_recency;
}));
}

// Generates a standard block:
Expand All @@ -331,12 +401,14 @@ std::vector<mojom::FeedItemV2Ptr> GenerateBlock(
}

auto hero_article = PickRouletteAndRemove(
articles, [](const auto& article, const auto& weight) {
auto image_url = article->image->is_padded_image_url()
? article->image->get_padded_image_url()
: article->image->get_image_url();
articles,
base::BindRepeating([](const mojom::FeedItemMetadataPtr& metadata,
const ArticleWeight& weight) {
auto image_url = metadata->image->is_padded_image_url()
? metadata->image->get_padded_image_url()
: metadata->image->get_image_url();
return image_url.is_valid() ? weight.weighting : 0;
});
}));

// We might not be able to generate a hero card, if none of the articles in
// this feed have an image.
Expand All @@ -350,9 +422,115 @@ std::vector<mojom::FeedItemV2Ptr> GenerateBlock(
auto follow_count = GetNormal(block_min_inline, block_max_inline + 1);
for (auto i = 0; i < follow_count; ++i) {
bool is_discover = base::RandDouble() < inline_discovery_ratio;
auto generated = is_discover ? PickDiscoveryArticleAndRemove(articles)
: PickRouletteAndRemove(articles);
mojom::FeedItemMetadataPtr generated;

if (is_discover) {
generated = PickDiscoveryArticleAndRemove(articles);
} else {
generated = PickRouletteAndRemove(articles);
}

if (!generated) {
DVLOG(1) << "Failed to generate article (is_discover=" << is_discover
<< ")";
continue;
}
result.push_back(mojom::FeedItemV2::NewArticle(
mojom::Article::New(std::move(generated), is_discover)));
}

return result;
}

// Generates a block from sampled content groups:
// 1. Hero Article
// 2. 1 - 5 Inline Articles (a percentage of which might be discover cards).
std::vector<mojom::FeedItemV2Ptr> GenerateBlockFromContentGroups(
ArticleInfos& articles,
const std::string& locale,
const Publishers& publishers,
const std::vector<ContentGroup>& eligible_content_groups,
// Ratio of inline articles to discovery articles.
// discover ratio % of the time, we should do a discover card here instead
// of a roulette card.
// https://docs.google.com/document/d/1bSVHunwmcHwyQTpa3ab4KRbGbgNQ3ym_GHvONnrBypg/edit#heading=h.4rkb0vecgekl
double inline_discovery_ratio =
features::kBraveNewsInlineDiscoveryRatio.Get()) {
DVLOG(1) << __FUNCTION__;
std::vector<mojom::FeedItemV2Ptr> result;
if (articles.empty() || eligible_content_groups.empty()) {
return result;
}

base::flat_map<std::string, std::vector<std::string>>
publisher_id_to_channels;
for (const auto& [publisher_id, publisher] : publishers) {
publisher_id_to_channels[publisher_id] =
GetChannelsForPublisher(locale, publisher);
}

// Generates a GetWeighting function tied to a specific content group. Each
// invocation of |get_weighting| will generate a new |GetWeighting| tied to a
// (freshly sampled) content_group.
auto get_weighting = [&eligible_content_groups, &publisher_id_to_channels,
&locale](bool is_hero = false) {
return base::BindRepeating(
[](const bool is_hero, const ContentGroup& content_group,
const base::flat_map<std::string, std::vector<std::string>>&
publisher_id_to_channels,
const std::string& locale,
const mojom::FeedItemMetadataPtr& metadata,
const ArticleWeight& weight) {
if (is_hero) {
auto image_url = metadata->image->is_padded_image_url()
? metadata->image->get_padded_image_url()
: metadata->image->get_image_url();
if (!image_url.is_valid()) {
return 0.0;
}
}

if (/*is_channel*/ content_group.second &&
content_group.first != kAllContentGroup) {
auto channels =
publisher_id_to_channels.find(metadata->publisher_id);
if (base::Contains(channels->second, content_group.first)) {
return weight.weighting;
}

return 0.0;
} else if (/*is_channel*/ !content_group.second) {
return metadata->publisher_id == content_group.first
? weight.weighting
: 0.0;
}

return weight.weighting;
},
is_hero, SampleContentGroup(eligible_content_groups),
publisher_id_to_channels, locale);
};

auto hero_article =
PickRouletteAndRemove(articles, get_weighting(/*is_hero*/ true));
if (!hero_article) {
DVLOG(1) << "Failed to generate hero";
return result;
}

result.push_back(mojom::FeedItemV2::NewHero(
mojom::HeroArticle::New(std::move(hero_article))));

const int block_min_inline = features::kBraveNewsMinBlockCards.Get();
const int block_max_inline = features::kBraveNewsMaxBlockCards.Get();
auto follow_count = GetNormal(block_min_inline, block_max_inline + 1);
for (auto i = 0; i < follow_count; ++i) {
bool is_discover = base::RandDouble() < inline_discovery_ratio;
auto generated = is_discover
? PickDiscoveryArticleAndRemove(articles)
: PickRouletteAndRemove(articles, get_weighting());
if (!generated) {
DVLOG(1) << "Failed to generate article";
continue;
}
result.push_back(mojom::FeedItemV2::NewArticle(
Expand Down Expand Up @@ -989,10 +1167,12 @@ mojom::FeedV2Ptr FeedV2Builder::GenerateAllFeed() {
// what channel cards to show.
Channels channels =
channels_controller_->GetChannelsFromPublishers(publishers, &*prefs_);

std::vector<std::string> subscribed_channels;
for (const auto& [id, channel] : channels) {
if (base::Contains(channel->subscribed_locales, locale)) {
subscribed_channels.push_back(id);
DVLOG(1) << "Subscribed to channel: " << id;
}
}

Expand All @@ -1008,9 +1188,22 @@ mojom::FeedV2Ptr FeedV2Builder::GenerateAllFeed() {
base::ranges::move(items, std::back_inserter(feed->items));
};

std::vector<ContentGroup> eligible_content_groups;
for (const auto& channel_id : subscribed_channels) {
eligible_content_groups.push_back(std::make_pair(channel_id, true));
}
for (const auto& [publisher_id, publisher] : publishers) {
if (publisher->user_enabled_status == mojom::UserEnabled::ENABLED) {
eligible_content_groups.push_back(std::make_pair(publisher_id, false));
DVLOG(1) << "Subscribed to publisher: " << publisher->publisher_name;
}
}

// Step 1: Generate a block
// https://docs.google.com/document/d/1bSVHunwmcHwyQTpa3ab4KRbGbgNQ3ym_GHvONnrBypg/edit#heading=h.rkq699fwps0
auto initial_block = GenerateBlock(articles);
std::vector<mojom::FeedItemV2Ptr> initial_block =
GenerateBlockFromContentGroups(articles, locale, publishers,
eligible_content_groups);
DVLOG(1) << "Step 1: Standard Block (" << initial_block.size()
<< " articles)";
add_items(initial_block);
Expand Down Expand Up @@ -1040,14 +1233,16 @@ mojom::FeedV2Ptr FeedV2Builder::GenerateAllFeed() {
// https://docs.google.com/document/d/1bSVHunwmcHwyQTpa3ab4KRbGbgNQ3ym_GHvONnrBypg/edit#heading=h.os2ze8cesd8v
if (iteration_type == 0) {
DVLOG(1) << "Step 4: Standard Block";
items = GenerateBlock(articles);
items = GenerateBlockFromContentGroups(articles, locale, publishers,
eligible_content_groups);
} else if (iteration_type == 1) {
// Step 5: Block or Cluster Generation
// https://docs.google.com/document/d/1bSVHunwmcHwyQTpa3ab4KRbGbgNQ3ym_GHvONnrBypg/edit#heading=h.tpvsjkq0lzmy
// Half the time, a normal block
if (TossCoin()) {
DVLOG(1) << "Step 5: Standard Block";
items = GenerateBlock(articles);
items = GenerateBlockFromContentGroups(articles, locale, publishers,
eligible_content_groups);
} else {
items = GenerateClusterBlock(locale, publishers, subscribed_channels,
topics, articles);
Expand Down
Loading

0 comments on commit b649d63

Please sign in to comment.