From b649d6331a3eced750026defdaaff63e18cf03dd Mon Sep 17 00:00:00 2001 From: lminto Date: Tue, 12 Dec 2023 22:23:05 +0000 Subject: [PATCH] Add hierarchical sampling for news cards feed composition (#21111) Co-authored-by: Jay Harris --- .../brave_news/browser/feed_v2_builder.cc | 265 +++++++++++++++--- components/brave_news/common/features.cc | 11 +- components/brave_news/common/features.h | 7 +- 3 files changed, 243 insertions(+), 40 deletions(-) diff --git a/components/brave_news/browser/feed_v2_builder.cc b/components/brave_news/browser/feed_v2_builder.cc index 0b29bcb5c917..912ce2027e44 100644 --- a/components/brave_news/browser/feed_v2_builder.cc +++ b/components/brave_news/browser/feed_v2_builder.cc @@ -69,6 +69,11 @@ struct ArticleWeight { using ArticleInfo = std::tuple; using ArticleInfos = std::vector; +/* publisher_or_channel_id, is_channel */ +using ContentGroup = std::pair; +constexpr char kAllContentGroup[] = "all"; +constexpr float kSampleContentGroupAllRatio = 0.2f; + std::string GetFeedHash(const Channels& channels, const Publishers& publishers, const ETags& etags) { @@ -140,9 +145,8 @@ double GetPopRecency(const mojom::FeedItemMetadataPtr& article) { auto& publish_time = article->publish_time; - double popularity = article->pop_score == 0 - ? features::kBraveNewsPopScoreFallback.Get() - : article->pop_score; + double popularity = std::min(article->pop_score, 100.0) / 100.0 + + features::kBraveNewsPopScoreMin.Get(); double multiplier = publish_time > base::Time::Now() - base::Hours(5) ? 2 : 1; auto dt = base::Time::Now() - publish_time; @@ -168,9 +172,10 @@ ArticleWeight GetArticleWeight(const mojom::FeedItemMetadataPtr& article, const double source_visits_projected = source_visits_min + signals.at(0)->visit_weight * (1 - source_visits_min); const auto pop_recency = GetPopRecency(article); + return { .pop_recency = pop_recency, - .weighting = source_visits_projected * subscribed_weight * pop_recency, + .weighting = source_visits_projected + subscribed_weight + pop_recency, // Note: GetArticleWeight returns the Signal for the Publisher first, and // we use that to determine whether this Publisher has ever been visited. .visited = signals.at(0)->visit_weight != 0, @@ -178,7 +183,8 @@ ArticleWeight GetArticleWeight(const mojom::FeedItemMetadataPtr& article, }; } -std::string PickRandom(const std::vector& items) { +template +T PickRandom(const std::vector& items) { CHECK(!items.empty()); // Note: RandInt is inclusive, hence the minus 1 return items[base::RandInt(0, items.size() - 1)]; @@ -190,6 +196,7 @@ ArticleInfos GetArticleInfos(const std::string& locale, const Signals& signals) { ArticleInfos articles; base::flat_set seen_articles; + for (const auto& item : feed_items) { if (item.is_null()) { continue; @@ -219,12 +226,29 @@ ArticleInfos GetArticleInfos(const std::string& locale, ArticleInfo pair = std::tuple(article->data->Clone(), GetArticleWeight(article->data, article_signals)); + articles.push_back(std::move(pair)); } } + return articles; } +std::vector GetChannelsForPublisher( + const std::string& locale, + const mojom::PublisherPtr& publisher) { + std::vector result; + for (const auto& locale_info : publisher->locales) { + if (locale_info->locale != locale) { + continue; + } + for (const auto& channel : locale_info->channels) { + result.push_back(channel); + } + } + return result; +} + // Randomly true/false with equal probability. bool TossCoin() { return base::RandDouble() < 0.5; @@ -261,32 +285,76 @@ int GetNormal(int min, int max) { return min + floor((max - min) * GetNormal()); } -using GetWeighting = double(const mojom::FeedItemMetadataPtr& article, - const ArticleWeight& weight); +using GetWeighting = + base::RepeatingCallback; + +// Returns a probability distribution (sum to 1) of the weights. Temperature +// controls how "smooth" the distribution is. High temperature brings the +// distribution closer to a uniform distribution (more randomness). +// Low temperature brings the distribution closer to a delta function (less +// randomness). +void SoftmaxWithTemperature( + std::vector& weights, + double temperature = features::kBraveNewsTemperature.Get()) { + if (temperature == 0) { + return; + } + + double max = *base::ranges::max_element(weights.begin(), weights.end()); + base::ranges::transform(weights.begin(), weights.end(), weights.begin(), + [temperature, max](double weight) { + return std::exp((weight - max) / temperature); + }); + double sum = std::accumulate(weights.begin(), weights.end(), 0.0); + base::ranges::transform(weights.begin(), weights.end(), weights.begin(), + [sum](double weight) { return weight / sum; }); +} + +// Sample across subscribed channels (direct and native) and publishers. +ContentGroup SampleContentGroup( + const std::vector& eligible_content_groups) { + ContentGroup sampled_content_group; + if (eligible_content_groups.empty()) { + return sampled_content_group; + } + + if (base::RandDouble() < kSampleContentGroupAllRatio) { + return std::make_pair(kAllContentGroup, true); + } + return PickRandom(eligible_content_groups); +} // Picks an article with a probability article_weight/sum(article_weights). mojom::FeedItemMetadataPtr PickRouletteAndRemove( ArticleInfos& articles, - GetWeighting get_weighting = [](const auto& article, const auto& weight) { - return weight.weighting; - }) { - double total_weight = 0; - for (const auto& [article, weight] : articles) { - total_weight += get_weighting(article, weight); - } + GetWeighting get_weighting = base::BindRepeating( + [](const mojom::FeedItemMetadataPtr& metadata, + const ArticleWeight& weight) { return weight.weighting; }), + bool use_softmax = false) { + std::vector weights; + base::ranges::transform(articles, std::back_inserter(weights), + [&get_weighting](const auto& article_info) { + return get_weighting.Run(std::get<0>(article_info), + std::get<1>(article_info)); + }); // None of the items are eligible to be picked. - if (total_weight == 0) { + if (std::accumulate(weights.begin(), weights.end(), 0.0) == 0) { return nullptr; } + if (use_softmax) { + SoftmaxWithTemperature(weights); + } + + double total_weight = std::accumulate(weights.begin(), weights.end(), 0.0); double picked_value = base::RandDouble() * total_weight; double current_weight = 0; uint64_t i; - for (i = 0; i < articles.size(); ++i) { - auto& [article, weight] = articles[i]; - current_weight += get_weighting(article, weight); + for (i = 0; i < weights.size(); ++i) { + current_weight += weights[i]; if (current_weight > picked_value) { break; } @@ -304,13 +372,15 @@ mojom::FeedItemMetadataPtr PickRouletteAndRemove( // 2. **AND** The user hasn't visited. mojom::FeedItemMetadataPtr PickDiscoveryArticleAndRemove( ArticleInfos& articles) { - return PickRouletteAndRemove(articles, - [](const auto& article, const auto& weight) { - if (weight.subscribed || weight.visited) { - return 0.; - } - return weight.pop_recency; - }); + return PickRouletteAndRemove( + articles, + base::BindRepeating([](const mojom::FeedItemMetadataPtr& metadata, + const ArticleWeight& weight) { + if (weight.subscribed) { + return 0.0; + } + return weight.pop_recency; + })); } // Generates a standard block: @@ -331,12 +401,14 @@ std::vector GenerateBlock( } auto hero_article = PickRouletteAndRemove( - articles, [](const auto& article, const auto& weight) { - auto image_url = article->image->is_padded_image_url() - ? article->image->get_padded_image_url() - : article->image->get_image_url(); + articles, + base::BindRepeating([](const mojom::FeedItemMetadataPtr& metadata, + const ArticleWeight& weight) { + auto image_url = metadata->image->is_padded_image_url() + ? metadata->image->get_padded_image_url() + : metadata->image->get_image_url(); return image_url.is_valid() ? weight.weighting : 0; - }); + })); // We might not be able to generate a hero card, if none of the articles in // this feed have an image. @@ -350,9 +422,115 @@ std::vector GenerateBlock( auto follow_count = GetNormal(block_min_inline, block_max_inline + 1); for (auto i = 0; i < follow_count; ++i) { bool is_discover = base::RandDouble() < inline_discovery_ratio; - auto generated = is_discover ? PickDiscoveryArticleAndRemove(articles) - : PickRouletteAndRemove(articles); + mojom::FeedItemMetadataPtr generated; + + if (is_discover) { + generated = PickDiscoveryArticleAndRemove(articles); + } else { + generated = PickRouletteAndRemove(articles); + } + + if (!generated) { + DVLOG(1) << "Failed to generate article (is_discover=" << is_discover + << ")"; + continue; + } + result.push_back(mojom::FeedItemV2::NewArticle( + mojom::Article::New(std::move(generated), is_discover))); + } + + return result; +} + +// Generates a block from sampled content groups: +// 1. Hero Article +// 2. 1 - 5 Inline Articles (a percentage of which might be discover cards). +std::vector GenerateBlockFromContentGroups( + ArticleInfos& articles, + const std::string& locale, + const Publishers& publishers, + const std::vector& eligible_content_groups, + // Ratio of inline articles to discovery articles. + // discover ratio % of the time, we should do a discover card here instead + // of a roulette card. + // https://docs.google.com/document/d/1bSVHunwmcHwyQTpa3ab4KRbGbgNQ3ym_GHvONnrBypg/edit#heading=h.4rkb0vecgekl + double inline_discovery_ratio = + features::kBraveNewsInlineDiscoveryRatio.Get()) { + DVLOG(1) << __FUNCTION__; + std::vector result; + if (articles.empty() || eligible_content_groups.empty()) { + return result; + } + + base::flat_map> + publisher_id_to_channels; + for (const auto& [publisher_id, publisher] : publishers) { + publisher_id_to_channels[publisher_id] = + GetChannelsForPublisher(locale, publisher); + } + + // Generates a GetWeighting function tied to a specific content group. Each + // invocation of |get_weighting| will generate a new |GetWeighting| tied to a + // (freshly sampled) content_group. + auto get_weighting = [&eligible_content_groups, &publisher_id_to_channels, + &locale](bool is_hero = false) { + return base::BindRepeating( + [](const bool is_hero, const ContentGroup& content_group, + const base::flat_map>& + publisher_id_to_channels, + const std::string& locale, + const mojom::FeedItemMetadataPtr& metadata, + const ArticleWeight& weight) { + if (is_hero) { + auto image_url = metadata->image->is_padded_image_url() + ? metadata->image->get_padded_image_url() + : metadata->image->get_image_url(); + if (!image_url.is_valid()) { + return 0.0; + } + } + + if (/*is_channel*/ content_group.second && + content_group.first != kAllContentGroup) { + auto channels = + publisher_id_to_channels.find(metadata->publisher_id); + if (base::Contains(channels->second, content_group.first)) { + return weight.weighting; + } + + return 0.0; + } else if (/*is_channel*/ !content_group.second) { + return metadata->publisher_id == content_group.first + ? weight.weighting + : 0.0; + } + + return weight.weighting; + }, + is_hero, SampleContentGroup(eligible_content_groups), + publisher_id_to_channels, locale); + }; + + auto hero_article = + PickRouletteAndRemove(articles, get_weighting(/*is_hero*/ true)); + if (!hero_article) { + DVLOG(1) << "Failed to generate hero"; + return result; + } + + result.push_back(mojom::FeedItemV2::NewHero( + mojom::HeroArticle::New(std::move(hero_article)))); + + const int block_min_inline = features::kBraveNewsMinBlockCards.Get(); + const int block_max_inline = features::kBraveNewsMaxBlockCards.Get(); + auto follow_count = GetNormal(block_min_inline, block_max_inline + 1); + for (auto i = 0; i < follow_count; ++i) { + bool is_discover = base::RandDouble() < inline_discovery_ratio; + auto generated = is_discover + ? PickDiscoveryArticleAndRemove(articles) + : PickRouletteAndRemove(articles, get_weighting()); if (!generated) { + DVLOG(1) << "Failed to generate article"; continue; } result.push_back(mojom::FeedItemV2::NewArticle( @@ -989,10 +1167,12 @@ mojom::FeedV2Ptr FeedV2Builder::GenerateAllFeed() { // what channel cards to show. Channels channels = channels_controller_->GetChannelsFromPublishers(publishers, &*prefs_); + std::vector subscribed_channels; for (const auto& [id, channel] : channels) { if (base::Contains(channel->subscribed_locales, locale)) { subscribed_channels.push_back(id); + DVLOG(1) << "Subscribed to channel: " << id; } } @@ -1008,9 +1188,22 @@ mojom::FeedV2Ptr FeedV2Builder::GenerateAllFeed() { base::ranges::move(items, std::back_inserter(feed->items)); }; + std::vector eligible_content_groups; + for (const auto& channel_id : subscribed_channels) { + eligible_content_groups.push_back(std::make_pair(channel_id, true)); + } + for (const auto& [publisher_id, publisher] : publishers) { + if (publisher->user_enabled_status == mojom::UserEnabled::ENABLED) { + eligible_content_groups.push_back(std::make_pair(publisher_id, false)); + DVLOG(1) << "Subscribed to publisher: " << publisher->publisher_name; + } + } + // Step 1: Generate a block // https://docs.google.com/document/d/1bSVHunwmcHwyQTpa3ab4KRbGbgNQ3ym_GHvONnrBypg/edit#heading=h.rkq699fwps0 - auto initial_block = GenerateBlock(articles); + std::vector initial_block = + GenerateBlockFromContentGroups(articles, locale, publishers, + eligible_content_groups); DVLOG(1) << "Step 1: Standard Block (" << initial_block.size() << " articles)"; add_items(initial_block); @@ -1040,14 +1233,16 @@ mojom::FeedV2Ptr FeedV2Builder::GenerateAllFeed() { // https://docs.google.com/document/d/1bSVHunwmcHwyQTpa3ab4KRbGbgNQ3ym_GHvONnrBypg/edit#heading=h.os2ze8cesd8v if (iteration_type == 0) { DVLOG(1) << "Step 4: Standard Block"; - items = GenerateBlock(articles); + items = GenerateBlockFromContentGroups(articles, locale, publishers, + eligible_content_groups); } else if (iteration_type == 1) { // Step 5: Block or Cluster Generation // https://docs.google.com/document/d/1bSVHunwmcHwyQTpa3ab4KRbGbgNQ3ym_GHvONnrBypg/edit#heading=h.tpvsjkq0lzmy // Half the time, a normal block if (TossCoin()) { DVLOG(1) << "Step 5: Standard Block"; - items = GenerateBlock(articles); + items = GenerateBlockFromContentGroups(articles, locale, publishers, + eligible_content_groups); } else { items = GenerateClusterBlock(locale, publishers, subscribed_channels, topics, articles); diff --git a/components/brave_news/common/features.cc b/components/brave_news/common/features.cc index 66fa740dbd1b..ccb9836b6798 100644 --- a/components/brave_news/common/features.cc +++ b/components/brave_news/common/features.cc @@ -25,16 +25,16 @@ const base::FeatureParam kBraveNewsMaxBlockCards{&kBraveNewsFeedUpdate, const base::FeatureParam kBraveNewsPopScoreHalfLife{ &kBraveNewsFeedUpdate, "pop-score-half-life", 18}; -const base::FeatureParam kBraveNewsPopScoreFallback{ - &kBraveNewsFeedUpdate, "pop-score-fallback", 50}; +const base::FeatureParam kBraveNewsPopScoreMin{ + &kBraveNewsFeedUpdate, "pop-score-fallback", 0.5}; const base::FeatureParam kBraveNewsInlineDiscoveryRatio{ &kBraveNewsFeedUpdate, "inline-discovery-ratio", 0.25}; const base::FeatureParam kBraveNewsSourceSubscribedBoost{ - &kBraveNewsFeedUpdate, "source-subscribed-boost", 1}; + &kBraveNewsFeedUpdate, "source-subscribed-boost", 1.0}; const base::FeatureParam kBraveNewsChannelSubscribedBoost{ - &kBraveNewsFeedUpdate, "channel-subscribed-boost", 0.2}; + &kBraveNewsFeedUpdate, "channel-subscribed-boost", 1.0}; const base::FeatureParam kBraveNewsSourceVisitsMin{ &kBraveNewsFeedUpdate, "source-visits-min", 0.2}; @@ -42,4 +42,7 @@ const base::FeatureParam kBraveNewsSourceVisitsMin{ const base::FeatureParam kBraveNewsCategoryTopicRatio{ &kBraveNewsFeedUpdate, "category-topic-ratio", 0.5}; +const base::FeatureParam kBraveNewsTemperature{&kBraveNewsFeedUpdate, + "temperature", 1}; + } // namespace brave_news::features diff --git a/components/brave_news/common/features.h b/components/brave_news/common/features.h index 161a9b308e59..24670134a718 100644 --- a/components/brave_news/common/features.h +++ b/components/brave_news/common/features.h @@ -27,7 +27,7 @@ extern const base::FeatureParam kBraveNewsPopScoreHalfLife; // Used as the fallback |pop_score| value for articles we // don't have a |pop_score| for, such as articles from a direct feed, or just // articles that Brave Search doesn't have enough information about. -extern const base::FeatureParam kBraveNewsPopScoreFallback; +extern const base::FeatureParam kBraveNewsPopScoreMin; // The ratio at which inline cards present discovery options (i.e. a source the // user has not visited before). @@ -61,6 +61,11 @@ extern const base::FeatureParam kBraveNewsSourceVisitsMin; // 80% of the clusters should be categories and 20% topics. extern const base::FeatureParam kBraveNewsCategoryTopicRatio; +// The temperature of the softmax function used to compute the sampling +// probabilities of the articles in the feed. High temperature means the +// distribution is more uniform. +extern const base::FeatureParam kBraveNewsTemperature; + } // namespace brave_news::features #endif // BRAVE_COMPONENTS_BRAVE_NEWS_COMMON_FEATURES_H_