Skip to content

Commit

Permalink
feat: Implement cooldowns for application commands (#431)
Browse files Browse the repository at this point in the history
* feat: make cooldowns work

feat: move cooldown to discatsharp base package
chore: some experimental stuff
fix: commandsnext cooldowns

* [ci skip] chore: resharper disable for sealed on CooldownBucket

* fix: try fixing translations (while i'm on it)

* fix: really fix translation export

* fix: fixed translations (while i'm on it)

* feat: add by discord added allowed locales

* [ci skip] chore: update release notes

* fix: fix registration of applicaiton commands when commands cleared or non existent in advance

* feat: custom cooldown responder

* docs: fix space -> tab

* Update RELEASENOTES.md

* fix: fix registration of translations for subcommands

* fix: nre

* chore: remove command grouping type

---------

Co-authored-by: Mira <56395159+TheXorog@users.noreply.github.com>
  • Loading branch information
Lulalaby and TheXorog authored Jan 17, 2024
1 parent 237d1f0 commit 8ab32b6
Show file tree
Hide file tree
Showing 42 changed files with 746 additions and 1,098 deletions.
74 changes: 48 additions & 26 deletions DisCatSharp.ApplicationCommands/ApplicationCommandsExtension.cs
Original file line number Diff line number Diff line change
Expand Up @@ -13,11 +13,11 @@
using DisCatSharp.ApplicationCommands.EventArgs;
using DisCatSharp.ApplicationCommands.Exceptions;
using DisCatSharp.ApplicationCommands.Workers;
using DisCatSharp.Attributes;
using DisCatSharp.Common;
using DisCatSharp.Common.Utilities;
using DisCatSharp.Entities;
using DisCatSharp.Enums;
using DisCatSharp.Enums.Core;
using DisCatSharp.EventArgs;
using DisCatSharp.Exceptions;

Expand Down Expand Up @@ -661,39 +661,37 @@ private async Task RegisterCommands(List<ApplicationCommandsModuleConfiguration>
if (Configuration.GenerateTranslationFilesOnly)
{
var cgwsgs = new List<CommandGroupWithSubGroups>();
var cgs2 = new List<CommandGroup>();
foreach (var cmd in slashGroupsTuple.applicationCommands)
if (cmd.Type is ApplicationCommandType.ChatInput)
{
var cgs = new List<CommandGroup>();
var cs2 = new List<Command>();
if (cmd.Options is not null)
{
foreach (var scg in cmd.Options.Where(x => x.Type is ApplicationCommandOptionType.SubCommandGroup))
{
var cs = new List<Command>();
if (scg.Options is not null)
foreach (var sc in scg.Options)
if (sc.Options is null || sc.Options.Count is 0)
cs.Add(new(sc.Name, sc.Description, null, null));
cs.Add(new(sc.Name, sc.Description, null, null, sc.RawNameLocalizations, sc.RawDescriptionLocalizations));
else
cs.Add(new(sc.Name, sc.Description, [.. sc.Options], null));
cgs.Add(new(scg.Name, scg.Description, cs, null));
cs.Add(new(sc.Name, sc.Description, [.. sc.Options], null, sc.RawNameLocalizations, sc.RawDescriptionLocalizations));
cgs.Add(new(scg.Name, scg.Description, cs, null, scg.RawNameLocalizations, scg.RawDescriptionLocalizations));
}

cgwsgs.Add(new(cmd.Name, cmd.Description, cgs, cmd.Type));
foreach (var sc2 in cmd.Options.Where(x => x.Type is ApplicationCommandOptionType.SubCommand))
if (sc2.Options == null || sc2.Options.Count == 0)
cs2.Add(new(sc2.Name, sc2.Description, null, null, sc2.RawNameLocalizations, sc2.RawDescriptionLocalizations));
else
cs2.Add(new(sc2.Name, sc2.Description, [.. sc2.Options], null, sc2.RawNameLocalizations, sc2.RawDescriptionLocalizations));
}

var cs2 = new List<Command>();
foreach (var sc2 in cmd.Options.Where(x => x.Type is ApplicationCommandOptionType.SubCommand))
if (sc2.Options == null || sc2.Options.Count == 0)
cs2.Add(new(sc2.Name, sc2.Description, null, null));
else
cs2.Add(new(sc2.Name, sc2.Description, [.. sc2.Options], null));
cgs2.Add(new(cmd.Name, cmd.Description, cs2, cmd.Type));
cgwsgs.Add(new(cmd.Name, cmd.Description, cgs, cs2, cmd.Type, cmd.RawNameLocalizations, cmd.RawDescriptionLocalizations));
}

if (cgwsgs.Count is not 0)
groupTranslation.AddRange(cgwsgs.Select(cgwsg => JsonConvert.DeserializeObject<GroupTranslator>(JsonConvert.SerializeObject(cgwsg))!));
if (cgs2.Count is not 0)
groupTranslation.AddRange(cgs2.Select(cg2 => JsonConvert.DeserializeObject<GroupTranslator>(JsonConvert.SerializeObject(cg2))!));
}
}

Expand Down Expand Up @@ -733,12 +731,20 @@ private async Task RegisterCommands(List<ApplicationCommandsModuleConfiguration>
var cs = new List<Command>();
foreach (var cmd in slashCommands.applicationCommands.Where(cmd => cmd.Type is ApplicationCommandType.ChatInput && (cmd.Options is null || !cmd.Options.Any(x => x.Type is ApplicationCommandOptionType.SubCommand or ApplicationCommandOptionType.SubCommandGroup))))
if (cmd.Options == null || cmd.Options.Count == 0)
cs.Add(new(cmd.Name, cmd.Description, null, ApplicationCommandType.ChatInput));
cs.Add(new(cmd.Name, cmd.Description, null, ApplicationCommandType.ChatInput, cmd.RawNameLocalizations, cmd.RawDescriptionLocalizations));
else
cs.Add(new(cmd.Name, cmd.Description, [.. cmd.Options], ApplicationCommandType.ChatInput));
cs.Add(new(cmd.Name, cmd.Description, [.. cmd.Options], ApplicationCommandType.ChatInput, cmd.RawNameLocalizations, cmd.RawDescriptionLocalizations));

if (cs.Count is not 0)
translation.AddRange(cs.Select(c => JsonConvert.DeserializeObject<CommandTranslator>(JsonConvert.SerializeObject(c))!));
//translation.AddRange(cs.Select(c => JsonConvert.DeserializeObject<CommandTranslator>(JsonConvert.SerializeObject(c))!));
{
foreach (var c in cs)
{
var json = JsonConvert.SerializeObject(c);
var obj = JsonConvert.DeserializeObject<CommandTranslator>(json);
translation.Add(obj!);
}
}
}
}

Expand Down Expand Up @@ -804,7 +810,7 @@ private async Task RegisterCommands(List<ApplicationCommandsModuleConfiguration>
{
updateList = updateList.DistinctBy(x => x.Name).ToList();
if (Configuration.GenerateTranslationFilesOnly)
await this.CheckRegistrationStartup(translation, groupTranslation);
await this.CheckRegistrationStartup(translation, groupTranslation, guildId);
else
try
{
Expand Down Expand Up @@ -911,7 +917,7 @@ private async Task RegisterCommands(List<ApplicationCommandsModuleConfiguration>
RegisteredCommands = GlobalCommandsInternal
}).ConfigureAwait(false);

await this.CheckRegistrationStartup(translation, groupTranslation);
await this.CheckRegistrationStartup(translation, groupTranslation, guildId);
}
catch (NullReferenceException ex)
{
Expand Down Expand Up @@ -965,15 +971,16 @@ private async Task RegisterCommands(List<ApplicationCommandsModuleConfiguration>
/// </summary>
/// <param name="translation">The optional translations.</param>
/// <param name="groupTranslation">The optional group translations.</param>
private async Task CheckRegistrationStartup(List<CommandTranslator>? translation = null, List<GroupTranslator>? groupTranslation = null)
/// <param name="guildId">The optional guild id.</param>
private async Task CheckRegistrationStartup(List<CommandTranslator>? translation = null, List<GroupTranslator>? groupTranslation = null, ulong? guildId = null)
{
if (Configuration.GenerateTranslationFilesOnly)
{
try
{
if (translation is not null && translation.Count is not 0)
{
var fileName = $"translation_generator_export-shard{this.Client.ShardId}-SINGLE.json";
var fileName = $"translation_generator_export-shard{this.Client.ShardId}-SINGLE-{(guildId.HasValue ? guildId.Value : "global")}.json";
var fs = File.Create(fileName);
var ms = new MemoryStream();
var writer = new StreamWriter(ms);
Expand All @@ -991,7 +998,7 @@ private async Task CheckRegistrationStartup(List<CommandTranslator>? translation

if (groupTranslation is not null && groupTranslation.Count is not 0)
{
var fileName = $"translation_generator_export-shard{this.Client.ShardId}-GROUP.json";
var fileName = $"translation_generator_export-shard{this.Client.ShardId}-GROUP-{(guildId.HasValue ? guildId.Value : "global")}.json";
var fs = File.Create(fileName);
var ms = new MemoryStream();
var writer = new StreamWriter(ms);
Expand Down Expand Up @@ -1030,6 +1037,8 @@ private async Task CheckStartupFinishAsync(ApplicationCommandsExtension sender,
GuildsWithoutScope = s_missingScopeGuildIdsGlobal
}).ConfigureAwait(false);
FinishFired = true;
if (Configuration.GenerateTranslationFilesOnly)
Environment.Exit(0);
}

args.Handled = false;
Expand Down Expand Up @@ -1081,7 +1090,11 @@ private Task InteractionHandler(DiscordClient client, InteractionCreateEventArgs
GuildLocale = e.Interaction.GuildLocale,
AppPermissions = e.Interaction.AppPermissions,
Entitlements = e.Interaction.Entitlements,
EntitlementSkuIds = e.Interaction.EntitlementSkuIds
EntitlementSkuIds = e.Interaction.EntitlementSkuIds,
UserId = e.Interaction.User.Id,
GuildId = e.Interaction.GuildId,
MemberId = e.Interaction.GuildId is not null ? e.Interaction.User.Id : null,
ChannelId = e.Interaction.ChannelId
};

try
Expand Down Expand Up @@ -1340,7 +1353,12 @@ private Task ContextMenuHandler(DiscordClient client, ContextMenuInteractionCrea
_ = Task.Run(async () =>
{
//Creates the context
var context = new ContextMenuContext
var context = new ContextMenuContext(e.Type switch
{
ApplicationCommandType.User => DisCatSharpCommandType.UserCommand,
ApplicationCommandType.Message => DisCatSharpCommandType.MessageCommand,
_ => throw new ArgumentOutOfRangeException(nameof(e.Type), "Unknown context menu type")
})
{
Interaction = e.Interaction,
Channel = e.Interaction.Channel,
Expand All @@ -1359,7 +1377,11 @@ private Task ContextMenuHandler(DiscordClient client, ContextMenuInteractionCrea
GuildLocale = e.Interaction.GuildLocale,
AppPermissions = e.Interaction.AppPermissions,
Entitlements = e.Interaction.Entitlements,
EntitlementSkuIds = e.Interaction.EntitlementSkuIds
EntitlementSkuIds = e.Interaction.EntitlementSkuIds,
UserId = e.Interaction.User.Id,
GuildId = e.Interaction.GuildId,
MemberId = e.Interaction.GuildId is not null ? e.Interaction.User.Id : null,
ChannelId = e.Interaction.ChannelId
};

try
Expand Down
Original file line number Diff line number Diff line change
@@ -1,63 +1,62 @@
using System;
using System.Collections.Concurrent;
using System.Collections.Generic;
using System.Globalization;
using System.Threading.Tasks;

using DisCatSharp.ApplicationCommands.Context;
using DisCatSharp.ApplicationCommands.Entities;
using DisCatSharp.ApplicationCommands.Enums;
using DisCatSharp.Entities;
using DisCatSharp.Entities.Core;
using DisCatSharp.Enums;
using DisCatSharp.Enums.Core;

using Sentry;

namespace DisCatSharp.ApplicationCommands.Attributes;

/// <summary>
/// Defines a cooldown for this command. This allows you to define how many times can users execute a specific command
/// </summary>
/// <remarks>
/// Defines a cooldown for this command. This means that users will be able to use the command a specific number of times before they have to wait to use it again.
/// </remarks>
/// <param name="maxUses">Number of times the command can be used before triggering a cooldown.</param>
/// <param name="resetAfter">Number of seconds after which the cooldown is reset.</param>
/// <param name="bucketType">Type of cooldown bucket. This allows controlling whether the bucket will be cooled down per user, guild, member, channel, and/or globally.</param>
/// <param name="cooldownResponderType">The responder type used to respond to cooldown ratelimit hits.</param>
[AttributeUsage(AttributeTargets.Method | AttributeTargets.Class, AllowMultiple = true, Inherited = false)]
public sealed class ContextMenuCooldownAttribute : ApplicationCommandCheckBaseAttribute, ICooldown<BaseContext, ContextMenuCooldownBucket>
public sealed class ContextMenuCooldownAttribute(int maxUses, double resetAfter, CooldownBucketType bucketType, Type? cooldownResponderType = null) : ApplicationCommandCheckBaseAttribute, ICooldown<BaseContext, CooldownBucket>
{
/// <summary>
/// Gets the maximum number of uses before this command triggers a cooldown for its bucket.
/// </summary>
public int MaxUses { get; }
public int MaxUses { get; } = maxUses;

/// <summary>
/// Gets the time after which the cooldown is reset.
/// </summary>
public TimeSpan Reset { get; }
public TimeSpan Reset { get; } = TimeSpan.FromSeconds(resetAfter);

/// <summary>
/// Gets the type of the cooldown bucket. This determines how cooldowns are applied.
/// </summary>
public CooldownBucketType BucketType { get; }

/// <summary>
/// Gets the cooldown buckets for this command.
/// </summary>
internal readonly ConcurrentDictionary<string, ContextMenuCooldownBucket> Buckets;
public CooldownBucketType BucketType { get; } = bucketType;

/// <summary>
/// Defines a cooldown for this command. This means that users will be able to use the command a specific number of times before they have to wait to use it again.
/// Gets the responder type.
/// </summary>
/// <param name="maxUses">Number of times the command can be used before triggering a cooldown.</param>
/// <param name="resetAfter">Number of seconds after which the cooldown is reset.</param>
/// <param name="bucketType">Type of cooldown bucket. This allows controlling whether the bucket will be cooled down per user, guild, channel, or globally.</param>
public ContextMenuCooldownAttribute(int maxUses, double resetAfter, CooldownBucketType bucketType)
{
this.MaxUses = maxUses;
this.Reset = TimeSpan.FromSeconds(resetAfter);
this.BucketType = bucketType;
this.Buckets = new();
}
public Type? ResponderType { get; } = cooldownResponderType;

/// <summary>
/// Gets a cooldown bucket for given command context.
/// </summary>
/// <param name="ctx">Command context to get cooldown bucket for.</param>
/// <returns>Requested cooldown bucket, or null if one wasn't present.</returns>
public ContextMenuCooldownBucket GetBucket(BaseContext ctx)
public CooldownBucket GetBucket(BaseContext ctx)
{
var bid = this.GetBucketId(ctx, out _, out _, out _);
this.Buckets.TryGetValue(bid, out var bucket);
return bucket;
var bid = this.GetBucketId(ctx, out _, out _, out _, out _);
ctx.Client.CommandCooldownBuckets.TryGetValue(bid, out var bucket);
return bucket!;
}

/// <summary>
Expand All @@ -68,7 +67,7 @@ public ContextMenuCooldownBucket GetBucket(BaseContext ctx)
public TimeSpan GetRemainingCooldown(BaseContext ctx)
{
var bucket = this.GetBucket(ctx);
return bucket == null
return bucket == null!
? TimeSpan.Zero
: bucket.RemainingUses > 0
? TimeSpan.Zero
Expand All @@ -82,8 +81,9 @@ public TimeSpan GetRemainingCooldown(BaseContext ctx)
/// <param name="userId">ID of the user with which this bucket is associated.</param>
/// <param name="channelId">ID of the channel with which this bucket is associated.</param>
/// <param name="guildId">ID of the guild with which this bucket is associated.</param>
/// <param name="memberId">ID of the member with which this bucket is associated.</param>
/// <returns>Calculated bucket ID.</returns>
private string GetBucketId(BaseContext ctx, out ulong userId, out ulong channelId, out ulong guildId)
private string GetBucketId(BaseContext ctx, out ulong userId, out ulong channelId, out ulong guildId, out ulong memberId)
{
userId = 0ul;
if ((this.BucketType & CooldownBucketType.User) != 0)
Expand All @@ -92,14 +92,16 @@ private string GetBucketId(BaseContext ctx, out ulong userId, out ulong channelI
channelId = 0ul;
if ((this.BucketType & CooldownBucketType.Channel) != 0)
channelId = ctx.Channel.Id;
if ((this.BucketType & CooldownBucketType.Guild) != 0 && ctx.Guild == null)
channelId = ctx.Channel.Id;

guildId = 0ul;
if (ctx.Guild != null && (this.BucketType & CooldownBucketType.Guild) != 0)
if (ctx.Guild is not null && (this.BucketType & CooldownBucketType.Guild) != 0)
guildId = ctx.Guild.Id;

var bid = CooldownBucket.MakeId(userId, channelId, guildId);
memberId = 0ul;
if (ctx.Guild is not null && ctx.Member is not null && (this.BucketType & CooldownBucketType.Member) != 0)
memberId = ctx.Member.Id;

var bid = CooldownBucket.MakeId(ctx.FullCommandName, ctx.Interaction.Data.Id.ToString(CultureInfo.InvariantCulture), userId, channelId, guildId, memberId);
return bid;
}

Expand All @@ -109,29 +111,36 @@ private string GetBucketId(BaseContext ctx, out ulong userId, out ulong channelI
/// <param name="ctx">The command context.</param>
public override async Task<bool> ExecuteChecksAsync(BaseContext ctx)
{
var bid = this.GetBucketId(ctx, out var usr, out var chn, out var gld);
if (!this.Buckets.TryGetValue(bid, out var bucket))
{
bucket = new(this.MaxUses, this.Reset, usr, chn, gld);
this.Buckets.AddOrUpdate(bid, bucket, (k, v) => bucket);
}
var bid = this.GetBucketId(ctx, out var usr, out var chn, out var gld, out var mem);
if (ctx.Client.CommandCooldownBuckets.TryGetValue(bid, out var bucket))
return await this.RespondRatelimitHitAsync(ctx, await bucket.DecrementUseAsync(ctx), bucket);

bucket = new(this.MaxUses, this.Reset, ctx.FullCommandName, ctx.Interaction.Data.Id.ToString(CultureInfo.InvariantCulture), usr, chn, gld, mem);
ctx.Client.CommandCooldownBuckets.AddOrUpdate(bid, bucket, (k, v) => bucket);

return await bucket.DecrementUseAsync().ConfigureAwait(false);
return await this.RespondRatelimitHitAsync(ctx, await bucket.DecrementUseAsync(ctx), bucket);
}
}

/// <summary>
/// Represents a cooldown bucket for commands.
/// </summary>
public sealed class ContextMenuCooldownBucket : CooldownBucket
{
internal ContextMenuCooldownBucket(int maxUses, TimeSpan resetAfter, ulong userId = 0, ulong channelId = 0, ulong guildId = 0)
: base(maxUses, resetAfter, userId, channelId, guildId)
{ }
/// <inheritdoc/>
public async Task<bool> RespondRatelimitHitAsync(BaseContext ctx, bool noHit, CooldownBucket bucket)
{
if (noHit)
return true;

/// <summary>
/// Returns a string representation of this command cooldown bucket.
/// </summary>
/// <returns>String representation of this command cooldown bucket.</returns>
public override string ToString() => $"Context Menu Command bucket {this.BucketId}";
if (this.ResponderType is null)
{
if (ApplicationCommandsExtension.Configuration.AutoDefer)
await ctx.EditResponseAsync(new DiscordWebhookBuilder().WithContent($"Error: Ratelimit hit\nTry again {bucket.ResetsAt.Timestamp()}"));
else
await ctx.CreateResponseAsync(InteractionResponseType.ChannelMessageWithSource, new DiscordInteractionResponseBuilder().WithContent($"Error: Ratelimit hit\nTry again {bucket.ResetsAt.Timestamp()}").AsEphemeral());

return false;
}

var providerMethod = this.ResponderType.GetMethod(nameof(ICooldownResponder.Responder));
var providerInstance = Activator.CreateInstance(this.ResponderType);
await ((Task)providerMethod.Invoke(providerInstance, [ctx])).ConfigureAwait(false);

return false;
}
}
Loading

0 comments on commit 8ab32b6

Please sign in to comment.