diff --git a/src/handler_util.rs b/src/discord_util.rs similarity index 100% rename from src/handler_util.rs rename to src/discord_util.rs diff --git a/src/fixed_deque.rs b/src/fixed_deque.rs new file mode 100644 index 0000000..0b8970c --- /dev/null +++ b/src/fixed_deque.rs @@ -0,0 +1,37 @@ +use std::{array::from_fn, collections::HashMap, hash::Hash}; + + +pub struct FixedDeque { + pub data: [T; N], + pos: usize +} + +impl FixedDeque { + pub fn new() -> Self { + FixedDeque { + data: from_fn(|_| T::default()), + pos: 0 + } + } +} + +impl FixedDeque { + pub fn push(&mut self, elem: T) { + self.data[self.pos] = elem; + self.pos += 1; + if self.pos >= N { + self.pos = 0; + } + } +} + +impl FixedDeque { + pub fn counts(&self) -> HashMap { + let mut res = HashMap::new(); + self.data.iter().cloned().for_each(|elem| { + let count = res.entry(elem).or_default(); + *count += 1; + }); + res + } +} \ No newline at end of file diff --git a/src/handler.rs b/src/handler.rs deleted file mode 100644 index ec06c4e..0000000 --- a/src/handler.rs +++ /dev/null @@ -1,207 +0,0 @@ -use itertools::Itertools; -use log::{warn, info, trace}; -use image::{write_buffer_with_format, ColorType, ImageOutputFormat}; -use regex::Regex; -use std::{io::{Cursor, Seek, SeekFrom}, sync::Arc}; -use palette::rgb::Rgb; -use dashmap::DashMap; -use serenity::{ - http::Http, - model:: { - application::interaction::{ - application_command::ApplicationCommandInteraction, - InteractionResponseType::ChannelMessageWithSource, - }, - id::GuildId, prelude::{UserId, ChannelId, Guild, Channel}, Timestamp - }, - prelude::*, utils::Color -}; -use futures::future::join_all; -use lazy_static::lazy_static; -use wordcloud_rs::{Token, WordCloud, Colors}; -use crate::{idiom::{Idioms, tokenize}, discord_emojis::DiscordEmojis, handler_util::read_past}; -const READ_PAST: u64 = 1000; -const DAYS: i64 = 100; - -lazy_static! { - static ref RE_EMO: Regex = Regex::new(r"").unwrap(); - static ref RE_TAG: Regex = Regex::new(r"<@(\d*)>").unwrap(); - static ref RE_CHAN: Regex = Regex::new(r"<#(\d*)>").unwrap(); - static ref RE_ROLE: Regex = Regex::new(r"<@&(\d*)>").unwrap(); -} - -fn convert_color(color: Color) -> Rgb { - Rgb::new( - color.r() as f32/255., - color.g() as f32/255., - color.b() as f32/255. - ) -} - -pub struct Handler { - idioms: Arc>>, - emojis: DiscordEmojis -} - -impl Handler { - pub fn new() -> Self { - Self { - idioms: Arc::new(DashMap::new()), - emojis: DiscordEmojis::new(1000) - } - } - - pub fn message(&self, guild_id: GuildId, channel_id: ChannelId, member_id: UserId, message: String) { - if let Some(mut idiom) = self.idioms.get_mut(&guild_id) { - idiom.update(channel_id, member_id, tokenize(message)); - } else { - warn!(target: "wordy", "Guild {} isn't registered yet.", guild_id); - } - } - - async fn to_wc_tokens( - &self, tokens: Vec<(String, f32)>, http: &Arc - ) -> Vec<(Token, f32)> { - join_all(tokens.into_iter().map(|(str, v)| async move { - if let Some(capts) = RE_EMO.captures(&str) { - let emo_id = capts.get(2).unwrap().as_str(); - if let Ok(img) = self.emojis.get(emo_id).await { - (Token::Img(img), v) - } else { - let name = capts.get(1).unwrap().as_str(); - (Token::Text(name.to_string()), v) - } - } else if let Some(capts) = RE_TAG.captures(&str) { - let user_id = capts.get(1).unwrap().as_str().parse().unwrap(); - if let Ok(member) = http.get_user(user_id).await { - (Token::Text(format!("@{}", member.name)), v) - } else { - (Token::Text("@deleted_user".to_string()), v) - } - } else if let Some(capts) = RE_CHAN.captures(&str) { - let chan_id = capts.get(1).unwrap().as_str().parse().unwrap(); - match http.get_channel(chan_id).await { - Ok(Channel::Guild(channel)) => (Token::Text(format!("#{}", channel.name)), v), - Ok(Channel::Category(channel)) => (Token::Text(format!("#{}", channel.name)), v), - _ => (Token::Text("#deleted_channel".to_string()), v) - } - } else { - (Token::Text(str), v) - } - }).collect_vec()).await - } - - pub async fn cloud(&self, ctx: Context, command: ApplicationCommandInteraction) { - if let Some(member) = &command.member { - let color = member.colour(&ctx.cache).unwrap_or(Color::from_rgb(255, 255, 255)); - if let Some(guild_id) = command.guild_id { - let member_id = member.user.id; - let tokens = self.idioms.get(&guild_id).unwrap().idiom(member_id); - trace!(target: "wordy", "/cloud: retrieved {} tokens for {}", tokens.len(), member.user.name); - let wc_tokens = self.to_wc_tokens(tokens, &ctx.http).await; - let image = WordCloud::new() - .colors(Colors::BiaisedRainbow { - anchor: convert_color(color), - variance: 50. - }).generate(wc_tokens); - let mut img_file = Cursor::new(Vec::new()); - write_buffer_with_format( - &mut img_file, - image.as_raw(), - image.width(), - image.height(), - ColorType::Rgba8, - ImageOutputFormat::Png, - ) - .unwrap(); - img_file.seek(SeekFrom::Start(0)).unwrap(); - let img_vec = img_file.into_inner(); - - if let Err(why) = command - .create_interaction_response(&ctx.http, |response| { - response - .kind(ChannelMessageWithSource) - .interaction_response_data( - |message| message.add_file(( - img_vec.as_slice(), - format!("WordCloud_{}.png", member.display_name()).as_str() - )) - ) - }) - .await - { - warn!(target: "wordy", "/cloud: Response failed with `{}`", why); - }; - } else { - warn!(target: "wordy", "/cloud: Couldn't get guild"); - } - } else { - warn!(target: "wordy", "/cloud: Couldn't get member"); - } - } - - pub async fn info(&self, ctx: Context, command: ApplicationCommandInteraction) { - if let Err(why) = command - .create_interaction_response(&ctx.http, |response| { - response - .kind(ChannelMessageWithSource) - .interaction_response_data( - |message| message.content( - "Made with ❤️ by Inspi#8989\n - Repository: " - ) - ) - }).await { - warn!(target: "wordy", "/info: Response failed with `{}`", why); - }; - } - - pub async fn register_guild(&self, http: Arc, guild: Guild) { - // only read messages that are less than 100 days old - let cutoff_date = Timestamp::from_unix_timestamp( - Timestamp::now().unix_timestamp() - 3600*24*DAYS - ).unwrap(); - if let Ok(channels) = guild.channels(&http).await { - if !self.idioms.contains_key(&guild.id) { - info!(target: "wordy", "Registering {} (id {})", guild.name, guild.id); - self.idioms.insert(guild.id, Idioms::new()); - let http = Arc::clone(&http); - let idioms = Arc::clone(&self.idioms); - tokio::spawn(async move { - for (channel_id, channel) in channels { - let messages = read_past(&http, &channel, READ_PAST, cutoff_date).await; - let len = messages.len(); - for message in messages { - idioms.get_mut(&guild.id).unwrap().update( - channel_id, message.author.id, tokenize(message.content) - ); - } - if len > 0 { - info!(target: "wordy", "Read {} past messages in {}/{}", len, guild.name, channel.name()) - } - } - }); - } - } - } - - pub async fn register_commands(&self, http: Arc, guild_id: GuildId) { - trace!("Registering slash commands for Guild {}", guild_id); - if let Err(why) = - GuildId::set_application_commands(&guild_id, http, |commands| { - commands - .create_application_command(|command| { - command.name("cloud").description( - "Discover the word cloud that defines you !", - ) - }) - .create_application_command(|command| { - command - .name("info") - .description("Information about this bot.") - }) - }).await { - println!("Couldn't register slash commmands: {}", why); - }; - } -} diff --git a/src/main.rs b/src/main.rs index a33c5e6..d3e297a 100644 --- a/src/main.rs +++ b/src/main.rs @@ -1,9 +1,11 @@ mod idiom; mod discord_emojis; -mod handler; -mod handler_util; -mod handle_events; -use handler::Handler; +mod discord_util; +mod wordy; +mod wordy_events; +mod wordy_commands; +mod fixed_deque; +use wordy::Wordy; use env_logger; use std::fs::read_to_string; use log::{warn, error, LevelFilter}; @@ -50,7 +52,7 @@ async fn main() { | GatewayIntents::GUILD_MEMBERS | GatewayIntents::GUILD_PRESENCES ) - .event_handler(Handler::new()) + .event_handler(Wordy::new()) .application_id(bot_id.into()) .await .expect("Error creating client"); diff --git a/src/wordy.rs b/src/wordy.rs new file mode 100644 index 0000000..1ce98a4 --- /dev/null +++ b/src/wordy.rs @@ -0,0 +1,193 @@ +use itertools::Itertools; +use log::{warn, info, trace}; +use image::RgbaImage; +use regex::Regex; +use std::{sync::Arc, collections::HashMap}; +use anyhow::{Result, bail}; +use palette::rgb::Rgb; +use dashmap::DashMap; +use serenity::{ + http::Http, + model:: { + id::GuildId, prelude::{UserId, ChannelId, Guild, Channel, Message, EmojiId, Emoji, Member} + }, + prelude::*, utils::Color +}; +use futures::future::join_all; +use lazy_static::lazy_static; +use wordcloud_rs::{Token, WordCloud, Colors}; +use crate::{idiom::{Idioms, tokenize}, discord_emojis::DiscordEmojis, fixed_deque::FixedDeque}; + +lazy_static! { + static ref RE_EMO: Regex = Regex::new(r"^$").unwrap(); + static ref RE_TAG: Regex = Regex::new(r"^<@(\d*)>$").unwrap(); + static ref RE_CHAN: Regex = Regex::new(r"^<#(\d*)>$").unwrap(); + static ref RE_ROLE: Regex = Regex::new(r"^<@&(\d*)>$").unwrap(); +} + +pub struct EmojiRankings { + pub png_ranking: Vec<(EmojiId, usize)>, + pub gif_ranking: Vec<(EmojiId, usize)> +} + +fn convert_color(color: Color) -> Rgb { + Rgb::new( + color.r() as f32/255., + color.g() as f32/255., + color.b() as f32/255. + ) +} + +pub fn register_guild( + guild: &Guild, + idioms: Arc>>, + recents_emos: Arc>>, + servers_emos: Arc>>, +) -> bool { + if !idioms.contains_key(&guild.id) { + info!(target: "wordy", "Registering {} (id {})", guild.name, guild.id); + idioms.insert(guild.id, Idioms::new()); + recents_emos.insert(guild.id, FixedDeque::new()); + servers_emos.insert(guild.id, guild.emojis.clone()); + true + } else { + info!(target: "wordy", "Guild {} (id {}) was already registered", guild.name, guild.id); + false + } +} + +pub fn read_message( + guild_id: GuildId, + message: Message, + idioms: Arc>>, + recents_emos: Arc>>, + servers_emos: Arc>>, +) { + if let ( + Some(mut idiom), + Some(mut recent_emos), + Some(server_emos) + ) = ( + idioms.get_mut(&guild_id), + recents_emos.get_mut(&guild_id), + servers_emos.get(&guild_id) + ) { + let tokens = tokenize(message.content); + tokens + .iter() + .filter_map(|token| { + if let Some(caps) = RE_EMO.captures(token) { + let emoji_id = EmojiId( + caps.get(2).unwrap().as_str().parse::().unwrap() + ); + if server_emos.contains_key(&emoji_id) { + return Some(emoji_id); + } + } + None + }).unique() + .for_each(|emoji_id| recent_emos.push(emoji_id.clone())); + idiom.update(message.channel_id, message.author.id, tokens); + } else { + warn!(target: "wordy", "Guild {} isn't registered yet.", guild_id); + } +} + +pub struct Wordy { + pub idioms: Arc>>, + pub discord_emos: DiscordEmojis, + pub recents_emos: Arc>>, + pub servers_emos: Arc>>, +} + +impl Wordy { + pub fn new() -> Self { + Self { + idioms: Arc::new(DashMap::new()), + discord_emos: DiscordEmojis::new(1000), + recents_emos: Arc::new(DashMap::new()), + servers_emos: Arc::new(DashMap::new()), + } + } + + pub fn message(&self, message: Message) { + read_message( + message.guild_id.unwrap(), + message, + self.idioms.clone(), + self.recents_emos.clone(), + self.servers_emos.clone() + ); + } + + async fn to_wc_tokens( + &self, tokens: Vec<(String, f32)>, http: &Arc + ) -> Vec<(Token, f32)> { + join_all(tokens.into_iter().map(|(token, v)| async move { + if let Some(capts) = RE_EMO.captures(&token) { + let emo_id = capts.get(2).unwrap().as_str(); + if let Ok(img) = self.discord_emos.get(emo_id).await { + (Token::Img(img), v) + } else { + let name = capts.get(1).unwrap().as_str(); + (Token::Text(name.to_string()), v) + } + } else if let Some(capts) = RE_TAG.captures(&token) { + let user_id = capts.get(1).unwrap().as_str().parse().unwrap(); + if let Ok(member) = http.get_user(user_id).await { + (Token::Text(format!("@{}", member.name)), v) + } else { + (Token::Text("@deleted_user".to_string()), v) + } + } else if let Some(capts) = RE_CHAN.captures(&token) { + let chan_id = capts.get(1).unwrap().as_str().parse().unwrap(); + match http.get_channel(chan_id).await { + Ok(Channel::Guild(channel)) => (Token::Text(format!("#{}", channel.name)), v), + Ok(Channel::Category(channel)) => (Token::Text(format!("#{}", channel.name)), v), + _ => (Token::Text("#deleted_channel".to_string()), v) + } + } else { + (Token::Text(token), v) + } + }).collect_vec()).await + } + + pub async fn cloud(&self, ctx: &Context, member: &Member) -> RgbaImage { + let color = member.colour(&ctx.cache).unwrap_or(Color::from_rgb(255, 255, 255)); + let member_id = member.user.id; + let tokens = self.idioms.get(&member.guild_id).unwrap().idiom(member_id); + trace!(target: "wordy", "/cloud: retrieved {} tokens for {}", tokens.len(), member.user.name); + let wc_tokens = self.to_wc_tokens(tokens, &ctx.http).await; + WordCloud::new() + .colors(Colors::BiaisedRainbow { + anchor: convert_color(color), + variance: 50. + }).generate(wc_tokens) + } + + pub fn emojis(&self, guild_id: GuildId) -> Result { + if let ( + Some(recent_emos), + Some(server_emos) + ) = ( + self.recents_emos.get(&guild_id), + self.servers_emos.get(&guild_id) + ) { + let counts = recent_emos.counts(); + let mut png_ranking = Vec::new(); + let mut gif_ranking = Vec::new(); + for (emoji_id, emoji) in server_emos.iter() { + if emoji.animated { + gif_ranking.push((*emoji_id, *(counts.get(emoji_id).unwrap_or(&0)))); + } else { + png_ranking.push((*emoji_id, *(counts.get(emoji_id).unwrap_or(&0)))); + } + } + png_ranking.sort_by_key(|(_, count)| *count); + gif_ranking.sort_by_key(|(_, count)| *count); + Ok(EmojiRankings { png_ranking, gif_ranking}) + } else { + bail!("Guild is not yet registered") + } + } +} diff --git a/src/wordy_commands.rs b/src/wordy_commands.rs new file mode 100644 index 0000000..9117cee --- /dev/null +++ b/src/wordy_commands.rs @@ -0,0 +1,129 @@ +use std::{io::{Cursor, Seek, SeekFrom}, sync::Arc}; +use log::{trace, info}; +use image::{write_buffer_with_format, ColorType, ImageOutputFormat}; +use anyhow::{Result, bail, Context as ErrContext, anyhow}; +use serenity::{http::Http, model::{ + prelude::{GuildId, Guild}, + application::interaction::{ + application_command::ApplicationCommandInteraction, + InteractionResponseType::ChannelMessageWithSource, + }, Timestamp}, prelude::Context}; +use crate::{wordy::{Wordy, register_guild, read_message}, discord_util::read_past}; +const READ_PAST: u64 = 1000; +const DAYS: i64 = 100; + +impl Wordy { + pub async fn cloud_command(&self, ctx: Context, command: ApplicationCommandInteraction) -> Result<()> { + if command.guild_id.is_none() { + bail!("Command wasn't invoked in a Guild."); + } + let member = command.member.as_ref().ok_or(anyhow!("Couldn't get member."))?; + let image = self.cloud(&ctx, &member).await; + let mut img_file = Cursor::new(Vec::new()); + write_buffer_with_format( + &mut img_file, + image.as_raw(), + image.width(), + image.height(), + ColorType::Rgba8, + ImageOutputFormat::Png, + ) + .unwrap(); + img_file.seek(SeekFrom::Start(0)).unwrap(); + let img_vec = img_file.into_inner(); + ( + command + .create_interaction_response(&ctx.http, |response| { + response + .kind(ChannelMessageWithSource) + .interaction_response_data( + |message| message.add_file(( + img_vec.as_slice(), + format!("WordCloud_{}.png", member.display_name()).as_str() + )) + ) + }) + .await + ).context("Command create response failed") + } + + pub async fn emojis_command(&self, ctx: Context, command: ApplicationCommandInteraction) -> Result<()> { + let guild_id = command.guild_id.as_ref().ok_or(anyhow!("Couldn't get member."))?; + let emoji_rankings = self.emojis(*guild_id); + todo!() + } + + pub async fn info_command(&self, ctx: Context, command: ApplicationCommandInteraction) -> Result<()> { + ( + command + .create_interaction_response(&ctx.http, |response| { + response + .kind(ChannelMessageWithSource) + .interaction_response_data( + |message| message.content( + "Made with ❤️ by Inspi#8989\n + Repository: " + ) + ) + }).await + ).context("Command create response failed") + } + + pub async fn register_commands(&self, http: Arc, guild_id: GuildId) { + trace!("Registering slash commands for Guild {}", guild_id); + if let Err(why) = + GuildId::set_application_commands(&guild_id, http, |commands| { + commands + .create_application_command(|command| { + command.name("cloud").description( + "Discover the word cloud that defines you !", + ) + }) + .create_application_command(|command| { + command.name("emojis").description( + "Recent emoji usage stats.", + ) + }) + .create_application_command(|command| { + command + .name("info") + .description("Information about this bot.") + }) + }).await { + println!("Couldn't register slash commmands: {}", why); + }; + } + + pub async fn register_guild(&self, http: Arc, guild: Guild) { + // only read messages that are less than 100 days old + let cutoff_date = Timestamp::from_unix_timestamp( + Timestamp::now().unix_timestamp() - 3600*24*DAYS + ).unwrap(); + if let Ok(channels) = guild.channels(&http).await { + if !register_guild( + &guild, + self.idioms.clone(), + self.recents_emos.clone(), + self.servers_emos.clone() + ) { + return; + } + let http = Arc::clone(&http); + let idioms = Arc::clone(&self.idioms); + let recents_emos = Arc::clone(&self.recents_emos); + let servers_emos = Arc::clone(&self.servers_emos); + tokio::spawn(async move { + for (_channel_id, channel) in channels { + let messages = read_past(&http, &channel, READ_PAST, cutoff_date).await; + let len = messages.len(); + for message in messages { + read_message(guild.id, message, idioms.clone(), recents_emos.clone(), servers_emos.clone()); + } + if len > 0 { + info!(target: "wordy", "Read {} past messages in {}/{}", len, guild.name, channel.name()) + } + } + }); + } + } +} \ No newline at end of file diff --git a/src/handle_events.rs b/src/wordy_events.rs similarity index 53% rename from src/handle_events.rs rename to src/wordy_events.rs index 97f4abe..69a49e3 100644 --- a/src/handle_events.rs +++ b/src/wordy_events.rs @@ -1,31 +1,37 @@ +use std::collections::HashMap; +use anyhow::anyhow; use serenity::{ model:: { application::interaction::{ Interaction, }, gateway::Ready, - guild::Guild, prelude::Message, + guild::Guild, prelude::{Message, GuildId, EmojiId, Emoji}, }, async_trait, prelude::* }; -use log::{info, trace}; -use crate::handler_util::{response, is_writable}; -use crate::handler::Handler; +use log::{info, trace, warn}; +use crate::discord_util::{response, is_writable}; +use crate::wordy::Wordy; #[async_trait] -impl EventHandler for Handler { +impl EventHandler for Wordy { async fn interaction_create(&self, ctx: Context, interaction: Interaction) { match interaction { Interaction::ApplicationCommand(command) => { // only answer if the bot has access to the channel if is_writable(&ctx, command.channel_id).await { - match command.data.name.as_str() { - "cloud" => self.cloud(ctx, command).await, - "info" => self.info(ctx, command).await, - _ => {} - }; + let command_name = command.data.name.to_string(); + if let Err(why) = match command_name.as_str() { + "cloud" => self.cloud_command(ctx, command).await, + "emojis" => self.emojis_command(ctx, command).await, + "info" => self.info_command(ctx, command).await, + _ => Err(anyhow!("Unknown command")) + } { + warn!(target: "wordy", "\\{}: {}", command_name, why); + } } else { response( &ctx.http, @@ -33,7 +39,7 @@ impl EventHandler for Handler { "Sorry, I only answer to commands in the channels that I can write to.", ).await; } - } + }, _ => {} } } @@ -48,9 +54,18 @@ impl EventHandler for Handler { } async fn message(&self, _ctx: Context, message: Message) { - if let Some(guild_id) = message.guild_id { + if message.guild_id.is_some() { trace!(target: "wordy", "Read a new message from {}", message.author.name); - self.message(guild_id, message.channel_id, message.author.id, message.content); + self.message(message); } } + + async fn guild_emojis_update( + &self, + _ctx: Context, + guild_id: GuildId, + current_state: HashMap, + ) { + self.servers_emos.insert(guild_id, current_state.clone()); + } }