Skip to content
This repository has been archived by the owner on Aug 30, 2023. It is now read-only.

Commit

Permalink
Enable WebSocket communication
Browse files Browse the repository at this point in the history
  • Loading branch information
hhvrc committed Jun 19, 2023
1 parent d506369 commit 4ba940a
Show file tree
Hide file tree
Showing 13 changed files with 127 additions and 77 deletions.
3 changes: 2 additions & 1 deletion Common/Common.csproj
Original file line number Diff line number Diff line change
Expand Up @@ -24,11 +24,12 @@
<PackageReference Include="OneOf" Version="3.0.243" />
<PackageReference Include="Quartz.AspNetCore" Version="3.6.2" />
<PackageReference Include="SixLabors.ImageSharp" Version="3.0.1" />
<PackageReference Include="System.IdentityModel.Tokens.Jwt" Version="6.31.0" />
<PackageReference Include="UAParser" Version="3.1.47" />
</ItemGroup>

<ItemGroup>
<FlatSharpSchema Include="Schemas\**\*.fbs"/>
<FlatSharpSchema Include="Schemas\**\*.fbs" />
</ItemGroup>

<ItemGroup>
Expand Down
8 changes: 4 additions & 4 deletions Common/Helpers/HttpErrors.cs
Original file line number Diff line number Diff line change
Expand Up @@ -91,15 +91,15 @@ public static ErrorDetails UnsupportedSSOProvider(string providerName) =>
public static ErrorDetails DeviceNotFound => Generic(StatusCodes.Status404NotFound, "device_not_found", "Device not found", null, null, new UserNotification(NotificationSeverityLevel.Error, "Device not found"));
public static IActionResult DeviceNotFoundActionResult => DeviceNotFound.ToActionResult();

public static ErrorDetails UnverifiedEmail => HttpErrors.Generic(StatusCodes.Status400BadRequest, "Unverified Email", "Email has not been verified", NotificationSeverityLevel.Warning, "Please verify your email address, and then try again");
public static ErrorDetails UnverifiedEmail => Generic(StatusCodes.Status401Unauthorized, "Unverified Email", "Email has not been verified", NotificationSeverityLevel.Warning, "Please verify your email address, and then try again");
public static IActionResult UnverifiedEmailActionResult => UnverifiedEmail.ToActionResult();

public static ErrorDetails ReviewPrivacyPolicy => HttpErrors.Generic(StatusCodes.Status400BadRequest, "review_privpol", "User needs to accept new Privacy Policy", NotificationSeverityLevel.Error, "Please read and accept the new Privacy Policy, and then try again");
public static ErrorDetails ReviewPrivacyPolicy => Generic(StatusCodes.Status400BadRequest, "review_privpol", "User needs to accept new Privacy Policy", NotificationSeverityLevel.Error, "Please read and accept the new Privacy Policy, and then try again");
public static IActionResult ReviewPrivacyPolicyActionResult => ReviewPrivacyPolicy.ToActionResult();

public static ErrorDetails ReviewTermsOfService => HttpErrors.Generic(StatusCodes.Status400BadRequest, "review_tos", "User needs to accept new Terms of Service", NotificationSeverityLevel.Error, "Please read and accept the new Terms of Service, and then try again");
public static ErrorDetails ReviewTermsOfService => Generic(StatusCodes.Status400BadRequest, "review_tos", "User needs to accept new Terms of Service", NotificationSeverityLevel.Error, "Please read and accept the new Terms of Service, and then try again");
public static IActionResult ReviewTermsOfServiceActionResult => ReviewTermsOfService.ToActionResult();

public static ErrorDetails UserAgentTooLong => HttpErrors.Generic(StatusCodes.Status413RequestEntityTooLarge, "User-Agent too long", $"User-Agent header has a hard limit on {UserAgentLimits.MaxUploadLength} characters", NotificationSeverityLevel.Error, "Unexpected behaviour, please contact developers");
public static ErrorDetails UserAgentTooLong => Generic(StatusCodes.Status413RequestEntityTooLarge, "User-Agent too long", $"User-Agent header has a hard limit on {UserAgentLimits.MaxUploadLength} characters", NotificationSeverityLevel.Error, "Unexpected behaviour, please contact developers");
public static IActionResult UserAgentTooLongActionResult => UserAgentTooLong.ToActionResult();
}
14 changes: 14 additions & 0 deletions Common/Services/Interfaces/IJwtAuthenticationManager.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
using Microsoft.IdentityModel.Tokens;
using OneOf;
using System.Security.Claims;
using ZapMe.Database.Models;
using ZapMe.DTOs;

namespace ZapMe.Services.Interfaces;

public interface IJwtAuthenticationManager
{
Task<OneOf<SessionEntity, ErrorDetails>> AuthenticateJwtTokenAsync(string jwtToken, CancellationToken cancellationToken = default);
bool ValidateJwtToken(string jwtToken, out ClaimsPrincipal claimsPrincipal, out SecurityToken validatedToken);
string GenerateJwtToken(ClaimsIdentity claimsIdentity, DateTime issuedAt, DateTime expiresAt);
}
Original file line number Diff line number Diff line change
@@ -1,50 +1,58 @@
using Microsoft.IdentityModel.Tokens;
using Microsoft.Extensions.Options;
using Microsoft.IdentityModel.Tokens;
using OneOf;
using System.IdentityModel.Tokens.Jwt;
using System.Security.Claims;
using System.Text;
using ZapMe.Constants;
using ZapMe.Database.Models;
using ZapMe.DTOs;
using ZapMe.Helpers;
using ZapMe.Options;
using ZapMe.Services.Interfaces;

namespace ZapMe.Utils;
namespace ZapMe.Services;

public static class JwtTokenUtils
public sealed class JwtAuthenticationManager : IJwtAuthenticationManager
{
public static string GenerateJwtToken(ClaimsIdentity claimsIdentity, DateTime issuedAt, DateTime expiresAt, string jwtSecret)
private readonly ISessionStore _sessionStore;
private readonly JwtOptions _jwtOptions;
private readonly ILogger<JwtAuthenticationManager> _logger;

public JwtAuthenticationManager(ISessionStore sessionStore, IOptions<JwtOptions> jwtOptions, ILogger<JwtAuthenticationManager> logger)
{
ArgumentNullException.ThrowIfNull(claimsIdentity);
ArgumentException.ThrowIfNullOrEmpty(jwtSecret);
_sessionStore = sessionStore;
_jwtOptions = jwtOptions.Value;
_logger = logger;
}

var tokenHandler = new JwtSecurityTokenHandler();
var securityKey = new SymmetricSecurityKey(Encoding.UTF8.GetBytes(jwtSecret));
var credentials = new SigningCredentials(securityKey, SecurityAlgorithms.HmacSha256);
public async Task<OneOf<SessionEntity, ErrorDetails>> AuthenticateJwtTokenAsync(string jwtToken, CancellationToken cancellationToken)
{
if (!ValidateJwtToken(jwtToken, out ClaimsPrincipal claimsPrincipal, out SecurityToken _))
{
return HttpErrors.Unauthorized;
}

var descriptor = new SecurityTokenDescriptor
if (!claimsPrincipal.GetUserEmailVerified())
{
Subject = claimsIdentity,
IssuedAt = issuedAt,
Expires = expiresAt,
SigningCredentials = credentials,
Issuer = AuthenticationConstants.JwtIssuer,
Audience = AuthenticationConstants.JwtAudience,
};
return HttpErrors.UnverifiedEmail;
}

var securityToken = tokenHandler.CreateToken(descriptor);
return tokenHandler.WriteToken(securityToken);
}
public static string GenerateJwtToken(SessionEntity session, string jwtSecret)
{
ArgumentNullException.ThrowIfNull(session);
return GenerateJwtToken(session.ToClaimsIdentity(), session.CreatedAt, session.ExpiresAt, jwtSecret);
SessionEntity? session = await _sessionStore.TryGetAsync(claimsPrincipal.GetSessionId(), cancellationToken);
if (session is null)
{
return HttpErrors.Unauthorized;
}

return session;
}

public static bool ValidateJwtToken(string jwtToken, string jwtSecret, out ClaimsPrincipal claimsPrincipal, out SecurityToken validatedToken)
public bool ValidateJwtToken(string jwtToken, out ClaimsPrincipal claimsPrincipal, out SecurityToken validatedToken)
{
ArgumentNullException.ThrowIfNull(jwtToken);
ArgumentException.ThrowIfNullOrEmpty(jwtSecret);

var tokenHandler = new JwtSecurityTokenHandler();
var key = Encoding.ASCII.GetBytes(jwtSecret);
var key = Encoding.ASCII.GetBytes(_jwtOptions.SigningKey);

try
{
Expand All @@ -69,4 +77,26 @@ public static bool ValidateJwtToken(string jwtToken, string jwtSecret, out Claim

return true;
}

public string GenerateJwtToken(ClaimsIdentity claimsIdentity, DateTime issuedAt, DateTime expiresAt)
{
ArgumentNullException.ThrowIfNull(claimsIdentity);

var tokenHandler = new JwtSecurityTokenHandler();
var securityKey = new SymmetricSecurityKey(Encoding.UTF8.GetBytes(_jwtOptions.SigningKey));
var credentials = new SigningCredentials(securityKey, SecurityAlgorithms.HmacSha256);

var descriptor = new SecurityTokenDescriptor
{
Subject = claimsIdentity,
IssuedAt = issuedAt,
Expires = expiresAt,
SigningCredentials = credentials,
Issuer = AuthenticationConstants.JwtIssuer,
Audience = AuthenticationConstants.JwtAudience,
};

var securityToken = tokenHandler.CreateToken(descriptor);
return tokenHandler.WriteToken(securityToken);
}
}
6 changes: 6 additions & 0 deletions Common/Services/SessionStore.cs
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,12 @@ public async Task<SessionEntity> CreateAsync(Guid userId, string ipAddress, stri
await _dbContext.Sessions.AddAsync(session, cancellationToken);
await _dbContext.SaveChangesAsync(cancellationToken);

session = await _dbContext
.Sessions
.Include(s => s.User)
.Include(s => s.UserAgent)
.FirstAsync(s => s.Id == session.Id, cancellationToken);

//string sessionKey = RedisCachePrefixes.Session + session.Id.ToString();
//await SetCacheAsync(sessionKey, session, cancellationToken);

Expand Down
6 changes: 3 additions & 3 deletions Common/Websocket/WebSocketInstance.cs
Original file line number Diff line number Diff line change
Expand Up @@ -2,18 +2,18 @@
using FlatSharp;
using System.Buffers;
using System.Net.WebSockets;
using System.Security.Claims;
using ZapMe.Database.Models;

namespace ZapMe.Websocket;

public sealed class WebSocketInstance : IDisposable
{
public static async Task<WebSocketInstance?> CreateAsync(WebSocketManager wsManager, ClaimsPrincipal user, ILogger<WebSocketInstance> logger)
public static async Task<WebSocketInstance?> CreateAsync(WebSocketManager wsManager, SessionEntity session, ILogger<WebSocketInstance> logger)
{
WebSocket? ws = await wsManager.AcceptWebSocketAsync();
if (ws is null) return null;

return new WebSocketInstance(user.GetUserId(), user.GetSessionId(), ws, logger);
return new WebSocketInstance(session.UserId, session.Id, ws, logger);
}

public Guid UserId { get; init; }
Expand Down
28 changes: 7 additions & 21 deletions RestAPI/Authentication/ZapMeAuthenticationHandler.cs
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@
using Microsoft.EntityFrameworkCore;
using Microsoft.Extensions.Options;
using Microsoft.Extensions.Primitives;
using Microsoft.IdentityModel.Tokens;
using System.Net.Http.Headers;
using System.Security.Claims;
using System.Text.Json;
Expand All @@ -13,9 +12,7 @@
using ZapMe.Database.Models;
using ZapMe.DTOs;
using ZapMe.Helpers;
using ZapMe.Options;
using ZapMe.Services.Interfaces;
using ZapMe.Utils;

namespace ZapMe.Authentication;

Expand All @@ -25,17 +22,15 @@ public sealed class ZapMeAuthenticationHandler : IAuthenticationSignInHandler
private HttpContext _context = default!;
private Task<AuthenticateResult>? _authenticateTask = null;
private readonly DatabaseContext _dbContext;
private readonly ISessionStore _sessionStore;
private readonly IJwtAuthenticationManager _authenticationManager;
private readonly JsonSerializerOptions _jsonSerializerOptions;
private readonly JwtOptions _jwtOptions;
private readonly ILogger<ZapMeAuthenticationHandler> _logger;

public ZapMeAuthenticationHandler(DatabaseContext dbContext, ISessionStore sessionStore, IOptions<JsonOptions> jsonOptions, IOptions<JwtOptions> jwtOptions, ILogger<ZapMeAuthenticationHandler> logger)
public ZapMeAuthenticationHandler(DatabaseContext dbContext, IJwtAuthenticationManager authenticationManager, IOptions<JsonOptions> jsonOptions, ILogger<ZapMeAuthenticationHandler> logger)
{
_dbContext = dbContext;
_sessionStore = sessionStore;
_authenticationManager = authenticationManager;
_jsonSerializerOptions = jsonOptions.Value.JsonSerializerOptions;
_jwtOptions = jwtOptions.Value;
_logger = logger;
}

Expand Down Expand Up @@ -129,7 +124,7 @@ public async Task SignInAsync(ClaimsPrincipal claimsPrincipal, AuthenticationPro
}

Response.StatusCode = StatusCodes.Status200OK;
await Response.WriteAsJsonAsync(new AuthenticationResponse(JwtToken: JwtTokenUtils.GenerateJwtToken(claimsIdentity, issuedAt, expiresAt, _jwtOptions.SigningKey)));
await Response.WriteAsJsonAsync(new AuthenticationResponse(JwtToken: _authenticationManager.GenerateJwtToken(claimsIdentity, issuedAt, expiresAt)));
}

public Task SignOutAsync(AuthenticationProperties? properties)
Expand All @@ -156,22 +151,13 @@ private async Task<AuthenticateResult> HandleAuthenticateAsync()
return AuthenticateResult.Fail("Invalid Authorization header.");
}

if (!JwtTokenUtils.ValidateJwtToken(authHeaderValue.Parameter, _jwtOptions.SigningKey, out ClaimsPrincipal claimsPrincipal, out SecurityToken securityToken))
var authenticationResult = await _authenticationManager.AuthenticateJwtTokenAsync(authHeaderValue.Parameter, CancellationToken);
if (authenticationResult.TryPickT1(out ErrorDetails errorDetails, out SessionEntity session))
{
await errorDetails.Write(Response, _jsonSerializerOptions);
return AuthenticateResult.Fail("Invalid JWT token.");
}

if (!claimsPrincipal.GetUserEmailVerified())
{
return AuthenticateResult.Fail("Email is not verified.");
}

var session = await _sessionStore.TryGetAsync(claimsPrincipal.GetSessionId(), CancellationToken);
if (session is null)
{
return AuthenticateResult.Fail("Invalid or Expired Session");
}

return AuthenticateResult.Success(new AuthenticationTicket(new ClaimsPrincipal(session.ToClaimsIdentity()), AuthenticationConstants.ZapMeScheme));
}

Expand Down
5 changes: 2 additions & 3 deletions RestAPI/Controllers/Api/V1/Account/Create.cs
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,6 @@
using ZapMe.DTOs.API.User;
using ZapMe.Enums;
using ZapMe.Helpers;
using ZapMe.Options;
using ZapMe.Services.Interfaces;
using ZapMe.Utils;

Expand Down Expand Up @@ -206,9 +205,9 @@ await _dbContext.Images
cancellationToken
);

var jwtOptions = HttpContext.RequestServices.GetRequiredService<JwtOptions>();
var authenticationManager = HttpContext.RequestServices.GetRequiredService<IJwtAuthenticationManager>();

jwtToken = JwtTokenUtils.GenerateJwtToken(session, jwtOptions.SigningKey);
jwtToken = authenticationManager.GenerateJwtToken(session.ToClaimsIdentity(), session.CreatedAt, session.ExpiresAt);
}

// Send email verification
Expand Down
Original file line number Diff line number Diff line change
@@ -1,9 +1,11 @@
using Microsoft.AspNetCore.Mvc;
using System.Security.Claims;
using ZapMe.Database.Models;
using ZapMe.DTOs;
using ZapMe.Helpers;
using ZapMe.Services.Interfaces;
using ZapMe.Websocket;

namespace ZapMe.Controllers.Ws;
namespace ZapMe.Controllers.Api.Ws;

public sealed partial class WebSocketController
{
Expand All @@ -13,6 +15,8 @@ public sealed partial class WebSocketController
/// Documentation:
/// Yes
/// </summary>
/// <param name="token"></param>
/// <param name="authenticationManager"></param>
/// <param name="logger"></param>
/// <param name="cancellationToken"></param>
/// <returns></returns>
Expand All @@ -21,24 +25,28 @@ public sealed partial class WebSocketController
[HttpGet(Name = "WebSocket")]
[ProducesResponseType(StatusCodes.Status200OK)]
[ProducesResponseType(StatusCodes.Status400BadRequest)]
public async Task<IActionResult> EntryPointAsync([FromServices] ILogger<WebSocketInstance> logger, CancellationToken cancellationToken)
public async Task<IActionResult> EntryPointAsync(
[FromQuery] string token,
[FromServices] IJwtAuthenticationManager authenticationManager,
[FromServices] ILogger<WebSocketInstance> logger,
CancellationToken cancellationToken
)
{
var authenticationResult = await authenticationManager.AuthenticateJwtTokenAsync(token, cancellationToken);
if (authenticationResult.TryPickT1(out ErrorDetails errorDetails, out SessionEntity session))
{
return errorDetails.ToActionResult();
}

WebSocketManager wsManager = HttpContext.WebSockets;

if (wsManager.IsWebSocketRequest)
{
Guid? userId = User.GetUserId();
if (!userId.HasValue)
{
return HttpErrors.UnauthorizedActionResult;
}


// The trace identifier is used to identify the websocket instance, it will be unique for each websocket connection
string instanceId = HttpContext.TraceIdentifier;

// Create the connection instance
using WebSocketInstance? instance = await WebSocketInstance.CreateAsync(wsManager, User, logger);
using WebSocketInstance? instance = await WebSocketInstance.CreateAsync(wsManager, session, logger);
if (instance is null)
{
_logger.LogError("Failed to create websocket instance");
Expand All @@ -47,7 +55,7 @@ public async Task<IActionResult> EntryPointAsync([FromServices] ILogger<WebSocke
}

// Register instance globally, the manager will have the ability to kill this connection
if (!await _webSocketInstanceManager.RegisterInstanceAsync(userId.Value, instanceId, instance, cancellationToken))
if (!await _webSocketInstanceManager.RegisterInstanceAsync(session.UserId, instanceId, instance, cancellationToken))
{
return HttpErrors.InternalServerErrorActionResult;
}
Expand Down
Original file line number Diff line number Diff line change
@@ -1,18 +1,17 @@
using Microsoft.AspNetCore.Authorization;
using Microsoft.AspNetCore.Mvc;
using Microsoft.AspNetCore.Mvc;
using ZapMe.DTOs;
using ZapMe.Services.Interfaces;
using static System.Net.Mime.MediaTypeNames;

namespace ZapMe.Controllers.Ws;
namespace ZapMe.Controllers.Api.Ws;

/// <summary>
///
/// </summary>
[Consumes(Application.Json)]
[Produces(Application.Json)]
[ProducesErrorResponseType(typeof(ErrorDetails))]
[ApiController, Authorize, Route("ws")]
[ApiController, Route("api/ws")]
public sealed partial class WebSocketController : ControllerBase
{
private readonly IWebSocketInstanceManager _webSocketInstanceManager;
Expand Down
2 changes: 2 additions & 0 deletions RestAPI/Program.cs
Original file line number Diff line number Diff line change
Expand Up @@ -119,6 +119,7 @@
//services.AddGoogleReCaptchaService(configuration);
services.AddMailGunService(configuration);

services.AddTransient<IJwtAuthenticationManager, JwtAuthenticationManager>();
services.AddTransient<IPasswordResetRequestStore, PasswordResetRequestStore>();
services.AddTransient<IPasswordResetManager, PasswordResetManager>();
services.AddTransient<IUserAgentStore, UserAgentStore>();
Expand Down Expand Up @@ -218,6 +219,7 @@
app.UseAuthorization();
app.UseRateLimiter();
app.UseMiddleware<ActivityTracker>();
app.UseWebSockets();
app.UseEndpoints(endpoints => endpoints.MapControllers());
});
app.Map("/swagger", true, app =>
Expand Down
Loading

0 comments on commit 4ba940a

Please sign in to comment.