Skip to content

Commit

Permalink
Switch to DiscordRestClient and refactor caching logic
Browse files Browse the repository at this point in the history
Migrated from DiscordSocketClient to DiscordRestClient in all relevant controllers. Updated methods to improve caching strategy for guilds and channels. Removed Discord.Addons.Hosting package and revised constructors and dependency injections accordingly.
  • Loading branch information
EpicOfficer committed Oct 7, 2024
1 parent 2c406d2 commit 121c10f
Show file tree
Hide file tree
Showing 6 changed files with 85 additions and 59 deletions.
1 change: 0 additions & 1 deletion Blink3.API/Blink3.API.csproj
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,6 @@

<ItemGroup>
<PackageReference Include="AspNet.Security.OAuth.Discord" Version="8.2.0" />
<PackageReference Include="Discord.Addons.Hosting" Version="6.1.0" />
<PackageReference Include="Discord.Net" Version="3.16.0" />
<PackageReference Include="Discord.Net.Rest" Version="3.16.0" />
<PackageReference Include="Microsoft.AspNetCore.JsonPatch" Version="8.0.8" />
Expand Down
33 changes: 15 additions & 18 deletions Blink3.API/Controllers/ApiControllerBase.cs
Original file line number Diff line number Diff line change
Expand Up @@ -4,9 +4,7 @@
using Blink3.Core.DiscordAuth.Extensions;
using Blink3.Core.Models;
using Discord;
using Discord.Addons.Hosting.Util;
using Discord.Rest;
using Discord.WebSocket;
using Microsoft.AspNetCore.Authorization;
using Microsoft.AspNetCore.Mvc;
using Swashbuckle.AspNetCore.Annotations;
Expand All @@ -25,10 +23,12 @@ namespace Blink3.API.Controllers;
[Route("api/[controller]")]
[ApiController]
[Authorize]
public abstract class ApiControllerBase(DiscordSocketClient discordSocketClient, ICachingService cachingService, IEncryptionService encryptionService) : ControllerBase
public abstract class ApiControllerBase(DiscordRestClient botClient,
Func<DiscordRestClient> userClientFactory,
ICachingService cachingService,
IEncryptionService encryptionService) : ControllerBase
{
private DiscordRestClient? _client;
protected readonly DiscordSocketClient DiscordBotClient = discordSocketClient;
private readonly DiscordRestClient _userClient = userClientFactory();

/// <summary>
/// Represents an Unauthorized Access message.
Expand Down Expand Up @@ -90,32 +90,26 @@ private ObjectResult ProblemForUnauthorizedAccess()
return userId != UserId ? ProblemForUnauthorizedAccess() : null;
}

private async Task InitDiscordClientAsync()
private async Task AuthenticateUserClientAsync()
{
await DiscordBotClient.WaitForReadyAsync(CancellationToken.None);
if (_client is not null) return;

string? encryptedToken = await cachingService.GetAsync<string>($"token:{UserId}");
string? iv = await cachingService.GetAsync<string>($"token:{UserId}:iv");
if (encryptedToken is null || iv is null) return;

string accessToken = encryptionService.Decrypt(encryptedToken, iv);

_client = new DiscordRestClient();
await _client.LoginAsync(TokenType.Bearer, accessToken);
await _userClient.LoginAsync(TokenType.Bearer, accessToken);
}

protected async Task<List<DiscordPartialGuild>> GetUserGuilds()
{
await InitDiscordClientAsync();
await AuthenticateUserClientAsync();

List<DiscordPartialGuild> managedGuilds = await cachingService.GetOrAddAsync($"discord:guilds:{UserId}",
async () =>
{
List<DiscordPartialGuild> manageable = [];
if (_client is null) return manageable;

IAsyncEnumerable<IReadOnlyCollection<RestUserGuild>> guilds = _client.GetGuildSummariesAsync();
IAsyncEnumerable<IReadOnlyCollection<RestUserGuild>> guilds = _userClient.GetGuildSummariesAsync();
await foreach (IReadOnlyCollection<RestUserGuild> guildCollection in guilds)
{
manageable.AddRange(guildCollection.Where(g => g.Permissions.ManageGuild).Select(g =>
Expand All @@ -130,7 +124,10 @@ protected async Task<List<DiscordPartialGuild>> GetUserGuilds()
return manageable;
}, TimeSpan.FromMinutes(5));

List<ulong> discordGuildIds = DiscordBotClient.Guilds.Select(b => b.Id).ToList();
List<ulong> discordGuildIds = await botClient.GetGuildSummariesAsync()
.SelectMany(guildCollection => guildCollection.ToAsyncEnumerable())
.Select(guild => guild.Id)
.ToListAsync();
return managedGuilds.Where(g => discordGuildIds.Contains(g.Id)).ToList();
}

Expand All @@ -149,7 +146,7 @@ protected async Task<List<DiscordPartialGuild>> GetUserGuilds()

~ApiControllerBase()
{
_client?.Dispose();
_client = null;
_userClient.LogoutAsync().Wait();
_userClient.Dispose();
}
}
12 changes: 8 additions & 4 deletions Blink3.API/Controllers/BlinkGuildsController.cs
Original file line number Diff line number Diff line change
@@ -1,10 +1,9 @@
using Blink3.API.Interfaces;
using Blink3.Core.Caching;
using Blink3.Core.DTOs;
using Blink3.Core.Entities;
using Blink3.Core.Models;
using Blink3.Core.Repositories.Interfaces;
using Discord.WebSocket;
using Discord.Rest;
using Microsoft.AspNetCore.JsonPatch;
using Microsoft.AspNetCore.Mvc;
using Swashbuckle.AspNetCore.Annotations;
Expand All @@ -15,8 +14,12 @@ namespace Blink3.API.Controllers;
/// Controller for performing CRUD operations on BlinkGuild items.
/// </summary>
[SwaggerTag("All CRUD operations for BlinkGuild items")]
public class BlinkGuildsController(DiscordSocketClient discordSocketClient, ICachingService cachingService, IEncryptionService encryptionService, IBlinkGuildRepository blinkGuildRepository)
: ApiControllerBase(discordSocketClient, cachingService, encryptionService)
public class BlinkGuildsController(DiscordRestClient botClient,
Func<DiscordRestClient> userClientFactory,
ICachingService cachingService,
IEncryptionService encryptionService,
IBlinkGuildRepository blinkGuildRepository)
: ApiControllerBase(botClient, userClientFactory, cachingService, encryptionService)
{
/// <summary>
/// Retrieves all BlinkGuild items that are manageable by the logged in user.
Expand Down Expand Up @@ -64,6 +67,7 @@ public async Task<ActionResult<UserTodo>> GetBlinkGuild(ulong id)
/// </summary>
/// <param name="id">The ID of the BlinkGuild item to update.</param>
/// <param name="blinkGuild">The updated BlinkGuild item data.</param>
/// <param name="cancellationToken"></param>
/// <returns>
/// No content.
/// </returns>
Expand Down
64 changes: 43 additions & 21 deletions Blink3.API/Controllers/GuildsController.cs
Original file line number Diff line number Diff line change
@@ -1,16 +1,22 @@
using Blink3.API.Interfaces;
using Blink3.Core.Caching;
using Blink3.Core.Models;
using Discord.WebSocket;
using Discord.Rest;
using Microsoft.AspNetCore.Mvc;
using Swashbuckle.AspNetCore.Annotations;

namespace Blink3.API.Controllers;

[SwaggerTag("Endpoints for getting information on discord guilds")]
public class GuildsController(DiscordSocketClient discordSocketClient, ICachingService cachingService, IEncryptionService encryptionService)
: ApiControllerBase(discordSocketClient, cachingService, encryptionService)
public class GuildsController(DiscordRestClient botClient,
Func<DiscordRestClient> userClientFactory,
ICachingService cachingService,
IEncryptionService encryptionService)
: ApiControllerBase(botClient, userClientFactory, cachingService, encryptionService)
{
private readonly DiscordRestClient _botClient = botClient;
private readonly ICachingService _cachingService = cachingService;

[HttpGet]
[SwaggerOperation(
Summary = "Returns all Discord guilds",
Expand Down Expand Up @@ -39,15 +45,23 @@ public async Task<ActionResult<IReadOnlyCollection<DiscordPartialChannel>>> GetC
ObjectResult? accessCheckResult = await CheckGuildAccessAsync(id);
if (accessCheckResult is not null) return accessCheckResult;

return DiscordBotClient.GetGuild(id).CategoryChannels
.OrderBy(c => c.Position)
.Select(c =>
new DiscordPartialChannel
{
Id = c.Id,
Name = c.Name
})
.ToList();
string cacheKey = $"guild_{id}_categories";
IReadOnlyCollection<DiscordPartialChannel> categories = await _cachingService.GetOrAddAsync(cacheKey, async () =>
{
RestGuild? guild = await _botClient.GetGuildAsync(id);
IReadOnlyCollection<RestCategoryChannel>? categories = await guild.GetCategoryChannelsAsync();
return categories
.OrderBy(c => c.Position)
.Select(c =>
new DiscordPartialChannel
{
Id = c.Id,
Name = c.Name
})
.ToList();
}, TimeSpan.FromMinutes(5));

return Ok(categories);
}

[HttpGet("{id}/channels")]
Expand All @@ -63,14 +77,22 @@ public async Task<ActionResult<IReadOnlyCollection<DiscordPartialChannel>>> GetC
ObjectResult? accessCheckResult = await CheckGuildAccessAsync(id);
if (accessCheckResult is not null) return accessCheckResult;

return DiscordBotClient.GetGuild(id).TextChannels
.OrderBy(c => c.Position)
.Select(c =>
new DiscordPartialChannel
{
Id = c.Id,
Name = c.Name
})
.ToList();
string cacheKey = $"guild_{id}_channels";
IReadOnlyCollection<DiscordPartialChannel> channels = await _cachingService.GetOrAddAsync(cacheKey, async () =>
{
RestGuild? guild = await _botClient.GetGuildAsync(id);
IReadOnlyCollection<RestTextChannel>? channels = await guild.GetTextChannelsAsync();
return channels
.OrderBy(c => c.Position)
.Select(c =>
new DiscordPartialChannel
{
Id = c.Id,
Name = c.Name
})
.ToList();
}, TimeSpan.FromMinutes(5));

return Ok(channels);
}
}
10 changes: 7 additions & 3 deletions Blink3.API/Controllers/TodoController.cs
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
using Blink3.Core.DTOs;
using Blink3.Core.Entities;
using Blink3.Core.Repositories.Interfaces;
using Discord.WebSocket;
using Discord.Rest;
using Microsoft.AspNetCore.Mvc;
using Swashbuckle.AspNetCore.Annotations;

Expand All @@ -13,8 +13,12 @@ namespace Blink3.API.Controllers;
/// Controller for performing CRUD operations on userTodo items.
/// </summary>
[SwaggerTag("All CRUD operations for todo items")]
public class TodoController(DiscordSocketClient discordSocketClient, ICachingService cachingService, IEncryptionService encryptionService, IUserTodoRepository todoRepository)
: ApiControllerBase(discordSocketClient, cachingService, encryptionService)
public class TodoController(DiscordRestClient botClient,
Func<DiscordRestClient> userClientFactory,
ICachingService cachingService,
IEncryptionService encryptionService,
IUserTodoRepository todoRepository)
: ApiControllerBase(botClient, userClientFactory, cachingService, encryptionService)
{
/// <summary>
/// Retrieves all userTodo items for the current user.
Expand Down
24 changes: 12 additions & 12 deletions Blink3.API/Program.cs
Original file line number Diff line number Diff line change
Expand Up @@ -7,8 +7,7 @@
using Blink3.Core.Helpers;
using Blink3.DataAccess.Extensions;
using Discord;
using Discord.Addons.Hosting;
using Discord.WebSocket;
using Discord.Rest;
using Microsoft.AspNetCore.HttpOverrides;
using Serilog;
using Serilog.Events;
Expand Down Expand Up @@ -44,17 +43,18 @@
builder.Services.AddAppConfiguration(builder.Configuration);
BlinkConfiguration appConfig = builder.Services.GetAppConfiguration();

// Discord socket client
builder.Services.AddDiscordHost((config, _) =>
// Discord bot client
builder.Services.AddSingleton<DiscordRestClient>(_ =>
{
config.SocketConfig = new DiscordSocketConfig
{
LogLevel = LogSeverity.Verbose,
MessageCacheSize = 0,
GatewayIntents = GatewayIntents.Guilds
};

config.Token = appConfig.Discord.BotToken;
DiscordRestClient client = new();
client.LoginAsync(TokenType.Bot, appConfig.Discord.BotToken).Wait();
return client;
});

// Discord user client
builder.Services.AddScoped<Func<DiscordRestClient>>(_ =>
{
return () => new DiscordRestClient();
});

// Add forwarded headers
Expand Down

0 comments on commit 121c10f

Please sign in to comment.