From cba251ccfa02a1968c4e06d84db4f1353da843d8 Mon Sep 17 00:00:00 2001 From: Shane Krueger Date: Sun, 1 Dec 2024 16:08:57 -0500 Subject: [PATCH] Add JWT Bearer authentication package (#82) * Add JWT Bearer authentication package * update * update * update * fix bug * Update * Fix default scheme * Add tests * Update * Update * Update --- GraphQL.AspNetCore3.sln | 6 + README.md | 32 ++ migration.md | 16 + .../AspNetCore3JwtBearerExtensions.cs | 20 + .../GraphQL.AspNetCore3.JwtBearer.csproj | 27 ++ .../JwtWebSocketAuthenticationService.cs | 158 ++++++++ .../GraphQLHttpMiddleware.cs | 1 - .../GraphQLHttpMiddlewareOptions.cs | 2 + .../IAuthorizationOptions.cs | 6 + .../WebSockets/AuthenticationRequest.cs | 48 +++ .../WebSockets/BaseSubscriptionServer.cs | 2 +- .../GraphQLWs/SubscriptionServer.cs | 6 +- .../IWebSocketAuthenticationService.cs | 4 +- .../SubscriptionServer.cs | 6 +- .../GraphQL.AspNetCore3.approved.txt | 11 +- src/Tests/BuilderMethodTests.cs | 8 +- src/Tests/ChatTests.cs | 12 +- .../AspNetCore3JwtBearerExtensionsTests.cs | 44 +++ .../JwtWebSocketAuthenticationServiceTests.cs | 361 ++++++++++++++++++ src/Tests/Tests.csproj | 2 + 20 files changed, 756 insertions(+), 16 deletions(-) create mode 100644 src/GraphQL.AspNetCore3.JwtBearer/AspNetCore3JwtBearerExtensions.cs create mode 100644 src/GraphQL.AspNetCore3.JwtBearer/GraphQL.AspNetCore3.JwtBearer.csproj create mode 100644 src/GraphQL.AspNetCore3.JwtBearer/JwtWebSocketAuthenticationService.cs create mode 100644 src/GraphQL.AspNetCore3/WebSockets/AuthenticationRequest.cs create mode 100644 src/Tests/JwtBearer/AspNetCore3JwtBearerExtensionsTests.cs create mode 100644 src/Tests/JwtBearer/JwtWebSocketAuthenticationServiceTests.cs diff --git a/GraphQL.AspNetCore3.sln b/GraphQL.AspNetCore3.sln index bffdf2f..73e0bec 100644 --- a/GraphQL.AspNetCore3.sln +++ b/GraphQL.AspNetCore3.sln @@ -51,6 +51,8 @@ Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "CorsSample", "src\Samples\C EndProject Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "Net48Sample", "src\Samples\Net48Sample\Net48Sample.csproj", "{C325FFAC-F5D6-411A-B93F-2B04AC8356D4}" EndProject +Project("{FAE04EC0-301F-11D3-BF4B-00C04F79EFBC}") = "GraphQL.AspNetCore3.JwtBearer", "src\GraphQL.AspNetCore3.JwtBearer\GraphQL.AspNetCore3.JwtBearer.csproj", "{7FDCD730-A321-4147-998F-0F26549B0A39}" +EndProject Global GlobalSection(SolutionConfigurationPlatforms) = preSolution Debug|Any CPU = Debug|Any CPU @@ -105,6 +107,10 @@ Global {C325FFAC-F5D6-411A-B93F-2B04AC8356D4}.Debug|Any CPU.Build.0 = Debug|Any CPU {C325FFAC-F5D6-411A-B93F-2B04AC8356D4}.Release|Any CPU.ActiveCfg = Release|Any CPU {C325FFAC-F5D6-411A-B93F-2B04AC8356D4}.Release|Any CPU.Build.0 = Release|Any CPU + {7FDCD730-A321-4147-998F-0F26549B0A39}.Debug|Any CPU.ActiveCfg = Debug|Any CPU + {7FDCD730-A321-4147-998F-0F26549B0A39}.Debug|Any CPU.Build.0 = Debug|Any CPU + {7FDCD730-A321-4147-998F-0F26549B0A39}.Release|Any CPU.ActiveCfg = Release|Any CPU + {7FDCD730-A321-4147-998F-0F26549B0A39}.Release|Any CPU.Build.0 = Release|Any CPU EndGlobalSection GlobalSection(SolutionProperties) = preSolution HideSolutionNode = FALSE diff --git a/README.md b/README.md index 4ec3f36..3447e70 100644 --- a/README.md +++ b/README.md @@ -422,6 +422,38 @@ Note that `InvokeAsync` will execute even if the protocol is disabled in the opt disabling `HandleGet` or similar; `HandleAuthorizeAsync` and `HandleAuthorizeWebSocketConnectionAsync` will not. +JWT Bearer authentication is provided by the `GraphQL.AspNetCore3.JwtBearer` package. +Like the above sample, it will look for an "Authorization" entry that starts with "Bearer " +and validate the token using the configured ASP.Net Core JWT Bearer authentication handler. +Configure it using the `AddJwtBearerAuthentication` extension method as shown +in the example below: + +```csharp +builder.Services.AddAuthentication(JwtBearerDefaults.AuthenticationScheme) + .AddJwtBearer(); + +builder.Services.AddGraphQL(b => b + .AddAutoSchema() + .AddSystemTextJson() + .AddAuthorizationRule() + .AddJwtBearerAuthentication() +); + +app.UseGraphQL("/graphql", config => +{ + // require that the user be authenticated + config.AuthorizationRequired = true; +}); +``` + +Please note: + +- If JWT Bearer is not the default authentication scheme, you will need to specify + the authentication scheme to use for GraphQL requests. See 'Authentication schemes' + below for more information. + +- Events configured through `JwtBearerEvents` are not currently supported. + #### Authentication schemes By default the role and policy requirements are validated against the current user as defined by diff --git a/migration.md b/migration.md index 2c145ad..325fe0d 100644 --- a/migration.md +++ b/migration.md @@ -1,5 +1,21 @@ # Version history / migration notes +## 7.0.0 + +GraphQL.AspNetCore3 v7 requires GraphQL.NET v8 or newer. + +### New features + +- Supports JWT WebSocket Authentication using the separately-provided `GraphQL.AspNetCore3.JwtBearer` package. + - Inherits most options configured by the `Microsoft.AspNetCore.Authentication.JwtBearer` package. + - Supports multiple authentication schemes, configurable via the `GraphQLHttpMiddlewareOptions.AuthenticationSchemes` property. + - Defaults to attempting the `AuthenticationOptions.DefaultScheme` scheme if not specified. + +### Breaking changes + +- `AuthenticationSchemes` property added to `IAuthorizationOptions` interface. +- `IWebSocketAuthenticationService.AuthenticateAsync` parameters refactored into an `AuthenticationRequest` class. + ## 6.0.0 GraphQL.AspNetCore3 v6 requires GraphQL.NET v8 or newer. diff --git a/src/GraphQL.AspNetCore3.JwtBearer/AspNetCore3JwtBearerExtensions.cs b/src/GraphQL.AspNetCore3.JwtBearer/AspNetCore3JwtBearerExtensions.cs new file mode 100644 index 0000000..aa73d7a --- /dev/null +++ b/src/GraphQL.AspNetCore3.JwtBearer/AspNetCore3JwtBearerExtensions.cs @@ -0,0 +1,20 @@ +using GraphQL.AspNetCore3; +using GraphQL.AspNetCore3.JwtBearer; +using GraphQL.DI; + +namespace GraphQL; + +/// +/// Extension methods for adding JWT bearer authentication to a GraphQL server for WebSocket communications. +/// +public static class AspNetCore3JwtBearerExtensions +{ + /// + /// Adds JWT bearer authentication to a GraphQL server for WebSocket communications. + /// + public static IGraphQLBuilder AddJwtBearerAuthentication(this IGraphQLBuilder builder) + { + builder.AddWebSocketAuthentication(); + return builder; + } +} diff --git a/src/GraphQL.AspNetCore3.JwtBearer/GraphQL.AspNetCore3.JwtBearer.csproj b/src/GraphQL.AspNetCore3.JwtBearer/GraphQL.AspNetCore3.JwtBearer.csproj new file mode 100644 index 0000000..18c2f9f --- /dev/null +++ b/src/GraphQL.AspNetCore3.JwtBearer/GraphQL.AspNetCore3.JwtBearer.csproj @@ -0,0 +1,27 @@ + + + + netstandard2.0;netcoreapp2.1;netcoreapp3.1;net6.0;net8.0 + JWT Bearer authentication for GraphQL projects + true + + + + + + + + + + + + + + + + + + + + + diff --git a/src/GraphQL.AspNetCore3.JwtBearer/JwtWebSocketAuthenticationService.cs b/src/GraphQL.AspNetCore3.JwtBearer/JwtWebSocketAuthenticationService.cs new file mode 100644 index 0000000..6b30f71 --- /dev/null +++ b/src/GraphQL.AspNetCore3.JwtBearer/JwtWebSocketAuthenticationService.cs @@ -0,0 +1,158 @@ +// Parts of this code file are based on the JwtBearerHandler class in the Microsoft.AspNetCore.Authentication.JwtBearer package found at: +// https://github.com/dotnet/aspnetcore/blob/5493b413d1df3aaf00651bdf1cbd8135fa63f517/src/Security/Authentication/JwtBearer/src/JwtBearerHandler.cs +// +// Those sections of code may be subject to the MIT license found at: +// https://github.com/dotnet/aspnetcore/blob/5493b413d1df3aaf00651bdf1cbd8135fa63f517/LICENSE.txt + +using System.Security.Claims; +using GraphQL.AspNetCore3.WebSockets; +using Microsoft.AspNetCore.Authentication; +using Microsoft.AspNetCore.Authentication.JwtBearer; +using Microsoft.AspNetCore.Http; +using Microsoft.Extensions.DependencyInjection; +using Microsoft.Extensions.Options; +using Microsoft.IdentityModel.Tokens; + +namespace GraphQL.AspNetCore3.JwtBearer; + +/// +/// Authenticates WebSocket connections via the 'payload' of the initialization packet. +/// This is necessary because WebSocket connections initiated from the browser cannot +/// authenticate via HTTP headers. +///

+/// Notes: +/// +/// This class is not used when authenticating over GET/POST. +/// +/// This class pulls the instance registered by ASP.NET Core during the call to +/// AddJwtBearer +/// for the default or configured authentication scheme and authenticates the token +/// based on simplified logic used by . +/// +/// +/// The expected format of the payload is {"Authorization":"Bearer TOKEN"} where TOKEN is the JSON Web Token (JWT), +/// mirroring the format of the 'Authorization' HTTP header. +/// +/// +/// Events configured in are not raised by this implementation. +/// +/// +/// Implementation does not call to log authentication events. +/// +/// +///
+public class JwtWebSocketAuthenticationService : IWebSocketAuthenticationService +{ + private readonly IGraphQLSerializer _graphQLSerializer; + private readonly IOptionsMonitor _jwtBearerOptionsMonitor; + private readonly string[] _defaultAuthenticationSchemes; + + /// + /// Initializes a new instance of the class. + /// + public JwtWebSocketAuthenticationService(IGraphQLSerializer graphQLSerializer, IOptionsMonitor jwtBearerOptionsMonitor, IOptions authenticationOptions) + { + _graphQLSerializer = graphQLSerializer; + _jwtBearerOptionsMonitor = jwtBearerOptionsMonitor; + var defaultAuthenticationScheme = authenticationOptions.Value.DefaultScheme; + _defaultAuthenticationSchemes = defaultAuthenticationScheme != null ? [defaultAuthenticationScheme] : []; + } + + /// + public async Task AuthenticateAsync(AuthenticationRequest authenticationRequest) + { + var connection = authenticationRequest.Connection; + var operationMessage = authenticationRequest.OperationMessage; + var schemes = authenticationRequest.AuthenticationSchemes.Any() ? authenticationRequest.AuthenticationSchemes : _defaultAuthenticationSchemes; + try { + // for connections authenticated via HTTP headers, no need to reauthenticate + if (connection.HttpContext.User.Identity?.IsAuthenticated ?? false) + return; + + // attempt to read the 'Authorization' key from the payload object and verify it contains "Bearer XXXXXXXX" + var authPayload = _graphQLSerializer.ReadNode(operationMessage.Payload); + if (authPayload != null && authPayload.Authorization != null && authPayload.Authorization.StartsWith("Bearer ", StringComparison.Ordinal)) { + // pull the token from the value + var token = authPayload.Authorization.Substring(7); + + // try to authenticate with each of the configured authentication schemes + foreach (var scheme in schemes) { + var options = _jwtBearerOptionsMonitor.Get(scheme); + + // follow logic simplified from JwtBearerHandler.HandleAuthenticateAsync, as follows: + var tokenValidationParameters = await SetupTokenValidationParametersAsync(options, connection.HttpContext).ConfigureAwait(false); +#if NET8_0_OR_GREATER + if (!options.UseSecurityTokenValidators) { + foreach (var tokenHandler in options.TokenHandlers) { + try { + var tokenValidationResult = await tokenHandler.ValidateTokenAsync(token, tokenValidationParameters).ConfigureAwait(false); + if (tokenValidationResult.IsValid) { + var principal = new ClaimsPrincipal(tokenValidationResult.ClaimsIdentity); + // set the ClaimsPrincipal for the HttpContext; authentication will take place against this object + connection.HttpContext.User = principal; + return; + } + } catch { + // no errors during authentication should throw an exception + // specifically, attempting to validate an invalid JWT token may result in an exception + } + } + } else { +#else + { +#endif +#pragma warning disable CS0618 // Type or member is obsolete + foreach (var validator in options.SecurityTokenValidators) { + if (validator.CanReadToken(token)) { + try { + var principal = validator.ValidateToken(token, tokenValidationParameters, out _); + // set the ClaimsPrincipal for the HttpContext; authentication will take place against this object + connection.HttpContext.User = principal; + return; + } catch { + // no errors during authentication should throw an exception + // specifically, attempting to validate an invalid JWT token will result in an exception + } + } + } +#pragma warning restore CS0618 // Type or member is obsolete + } + } + } + } catch { + // no errors during authentication should throw an exception + // specifically, parsing invalid JSON will result in an exception + } + } + + private static async ValueTask SetupTokenValidationParametersAsync(JwtBearerOptions options, HttpContext httpContext) + { + // Clone to avoid cross request race conditions for updated configurations. + var tokenValidationParameters = options.TokenValidationParameters.Clone(); + +#if NET8_0_OR_GREATER + if (options.ConfigurationManager is BaseConfigurationManager baseConfigurationManager) { + tokenValidationParameters.ConfigurationManager = baseConfigurationManager; + } else { +#else + { +#endif + if (options.ConfigurationManager != null) { + // GetConfigurationAsync has a time interval that must pass before new http request will be issued. + var configuration = await options.ConfigurationManager.GetConfigurationAsync(httpContext.RequestAborted).ConfigureAwait(false); + var issuers = new[] { configuration.Issuer }; + tokenValidationParameters.ValidIssuers = (tokenValidationParameters.ValidIssuers == null ? issuers : tokenValidationParameters.ValidIssuers.Concat(issuers)); + tokenValidationParameters.IssuerSigningKeys = (tokenValidationParameters.IssuerSigningKeys == null ? configuration.SigningKeys : tokenValidationParameters.IssuerSigningKeys.Concat(configuration.SigningKeys)); + } + } + + return tokenValidationParameters; + } + +#pragma warning disable CS1591 // Missing XML comment for publicly visible type or member + public sealed class AuthPayload + { + public string? Authorization { get; set; } + } +#pragma warning restore CS1591 // Missing XML comment for publicly visible type or member +} diff --git a/src/GraphQL.AspNetCore3/GraphQLHttpMiddleware.cs b/src/GraphQL.AspNetCore3/GraphQLHttpMiddleware.cs index 35b6747..d1558d4 100644 --- a/src/GraphQL.AspNetCore3/GraphQLHttpMiddleware.cs +++ b/src/GraphQL.AspNetCore3/GraphQLHttpMiddleware.cs @@ -5,7 +5,6 @@ using System.Security.Claims; using Microsoft.AspNetCore.Authentication; using Microsoft.AspNetCore.Authorization; -using static System.Net.Mime.MediaTypeNames; namespace GraphQL.AspNetCore3; diff --git a/src/GraphQL.AspNetCore3/GraphQLHttpMiddlewareOptions.cs b/src/GraphQL.AspNetCore3/GraphQLHttpMiddlewareOptions.cs index 098f9d1..9b16e2c 100644 --- a/src/GraphQL.AspNetCore3/GraphQLHttpMiddlewareOptions.cs +++ b/src/GraphQL.AspNetCore3/GraphQLHttpMiddlewareOptions.cs @@ -99,6 +99,8 @@ public class GraphQLHttpMiddlewareOptions : IAuthorizationOptions /// public List AuthenticationSchemes { get; set; } = new(); + IEnumerable IAuthorizationOptions.AuthenticationSchemes => AuthenticationSchemes; + /// /// /// HTTP requests return 401 Forbidden when the request is not authenticated. diff --git a/src/GraphQL.AspNetCore3/IAuthorizationOptions.cs b/src/GraphQL.AspNetCore3/IAuthorizationOptions.cs index 91d5b1b..f74d50d 100644 --- a/src/GraphQL.AspNetCore3/IAuthorizationOptions.cs +++ b/src/GraphQL.AspNetCore3/IAuthorizationOptions.cs @@ -9,6 +9,12 @@ namespace GraphQL.AspNetCore3; /// public interface IAuthorizationOptions { + /// + /// Gets a list of the authentication schemes the authentication requirements are evaluated against. + /// When no schemes are specified, the default authentication scheme is used. + /// + IEnumerable AuthenticationSchemes { get; } + /// /// If set, requires that return /// for the user within diff --git a/src/GraphQL.AspNetCore3/WebSockets/AuthenticationRequest.cs b/src/GraphQL.AspNetCore3/WebSockets/AuthenticationRequest.cs new file mode 100644 index 0000000..0bf7cce --- /dev/null +++ b/src/GraphQL.AspNetCore3/WebSockets/AuthenticationRequest.cs @@ -0,0 +1,48 @@ +namespace GraphQL.AspNetCore3.WebSockets; + +/// +/// Represents an authentication request within the GraphQL ASP.NET Core WebSocket context. +/// +public class AuthenticationRequest +{ + /// + /// Gets the WebSocket connection associated with the authentication request. + /// + /// + /// An instance of representing the active WebSocket connection. + /// + public IWebSocketConnection Connection { get; } + + /// + /// Gets the subprotocol used for the WebSocket communication. + /// + /// + /// A specifying the subprotocol negotiated for the WebSocket connection. + /// + public string SubProtocol { get; } + + /// + /// Gets the operation message containing details of the authentication operation. + /// + /// + /// An instance of that encapsulates the specifics of the authentication request. + /// + public OperationMessage OperationMessage { get; } + + /// + /// Gets a list of the authentication schemes the authentication requirements are evaluated against. + /// When no schemes are specified, the default authentication scheme is used. + /// + public IEnumerable AuthenticationSchemes { get; } + + /// + /// Initializes a new instance of the class. + /// + public AuthenticationRequest(IWebSocketConnection connection, string subProtocol, OperationMessage operationMessage, IEnumerable authenticationSchemes) + { + Connection = connection; + SubProtocol = subProtocol; + OperationMessage = operationMessage; + AuthenticationSchemes = authenticationSchemes; + } +} diff --git a/src/GraphQL.AspNetCore3/WebSockets/BaseSubscriptionServer.cs b/src/GraphQL.AspNetCore3/WebSockets/BaseSubscriptionServer.cs index f6940cd..f6ce2f5 100644 --- a/src/GraphQL.AspNetCore3/WebSockets/BaseSubscriptionServer.cs +++ b/src/GraphQL.AspNetCore3/WebSockets/BaseSubscriptionServer.cs @@ -195,7 +195,7 @@ protected virtual Task ErrorIdAlreadyExistsAsync(OperationMessage message) /// OnNotAuthorizedRoleAsync /// or OnNotAuthorizedPolicyAsync. ///

- /// Derived implementations should call the + /// Derived implementations should call the /// method to authenticate the request, and then call this base method. ///

/// This method will return if authorization is successful, or diff --git a/src/GraphQL.AspNetCore3/WebSockets/GraphQLWs/SubscriptionServer.cs b/src/GraphQL.AspNetCore3/WebSockets/GraphQLWs/SubscriptionServer.cs index 47a6427..db1f1ee 100644 --- a/src/GraphQL.AspNetCore3/WebSockets/GraphQLWs/SubscriptionServer.cs +++ b/src/GraphQL.AspNetCore3/WebSockets/GraphQLWs/SubscriptionServer.cs @@ -6,6 +6,7 @@ namespace GraphQL.AspNetCore3.WebSockets.GraphQLWs; public class SubscriptionServer : BaseSubscriptionServer { private readonly IWebSocketAuthenticationService? _authenticationService; + private readonly IEnumerable _authenticationSchemes; private readonly IGraphQLSerializer _serializer; private readonly GraphQLWebSocketOptions _options; private DateTime _lastPongReceivedUtc; @@ -76,6 +77,7 @@ public SubscriptionServer( _authenticationService = authenticationService; _serializer = serializer; _options = options; + _authenticationSchemes = authorizationOptions.AuthenticationSchemes; } /// @@ -305,7 +307,7 @@ protected override async Task ExecuteRequestAsync(OperationMess /// /// Authorizes an incoming GraphQL over WebSockets request with the connection initialization message and initializes the . ///

- /// The default implementation calls the + /// The default implementation calls the /// method to authenticate the request (if was specified), /// checks the authorization rules set in , /// if any, against . If validation fails, control is passed @@ -323,7 +325,7 @@ protected override async Task ExecuteRequestAsync(OperationMess protected override async ValueTask AuthorizeAsync(OperationMessage message) { if (_authenticationService != null) - await _authenticationService.AuthenticateAsync(Connection, SubProtocol, message); + await _authenticationService.AuthenticateAsync(new(Connection, SubProtocol, message, _authenticationSchemes)); bool success = await base.AuthorizeAsync(message); diff --git a/src/GraphQL.AspNetCore3/WebSockets/IWebSocketAuthenticationService.cs b/src/GraphQL.AspNetCore3/WebSockets/IWebSocketAuthenticationService.cs index 9a651d1..5d22347 100644 --- a/src/GraphQL.AspNetCore3/WebSockets/IWebSocketAuthenticationService.cs +++ b/src/GraphQL.AspNetCore3/WebSockets/IWebSocketAuthenticationService.cs @@ -11,12 +11,12 @@ public interface IWebSocketAuthenticationService { /// /// Authenticates an incoming GraphQL over WebSockets request with the connection initialization message. The implementation should - /// set the .HttpContext.User + /// set the .Connection.HttpContext.User /// property after validating the provided credentials. ///

/// After calling this method to authenticate the request, the infrastructure will authorize the incoming request via the /// , and /// properties. ///
- Task AuthenticateAsync(IWebSocketConnection connection, string subProtocol, OperationMessage operationMessage); + Task AuthenticateAsync(AuthenticationRequest authenticationRequest); } diff --git a/src/GraphQL.AspNetCore3/WebSockets/SubscriptionsTransportWs/SubscriptionServer.cs b/src/GraphQL.AspNetCore3/WebSockets/SubscriptionsTransportWs/SubscriptionServer.cs index 0f39428..5357bb6 100644 --- a/src/GraphQL.AspNetCore3/WebSockets/SubscriptionsTransportWs/SubscriptionServer.cs +++ b/src/GraphQL.AspNetCore3/WebSockets/SubscriptionsTransportWs/SubscriptionServer.cs @@ -6,6 +6,7 @@ namespace GraphQL.AspNetCore3.WebSockets.SubscriptionsTransportWs; public class SubscriptionServer : BaseSubscriptionServer { private readonly IWebSocketAuthenticationService? _authenticationService; + private readonly IEnumerable _authenticationSchemes; /// /// The WebSocket sub-protocol used for this protocol. @@ -69,6 +70,7 @@ public SubscriptionServer( UserContextBuilder = userContextBuilder ?? throw new ArgumentNullException(nameof(userContextBuilder)); Serializer = serializer ?? throw new ArgumentNullException(nameof(serializer)); _authenticationService = authenticationService; + _authenticationSchemes = authorizationOptions.AuthenticationSchemes; } /// @@ -222,7 +224,7 @@ await Connection.SendMessageAsync(new OperationMessage { /// /// Authorizes an incoming GraphQL over WebSockets request with the connection initialization message and initializes the . ///

- /// The default implementation calls the + /// The default implementation calls the /// method to authenticate the request (if was specified), /// checks the authorization rules set in , /// if any, against . If validation fails, control is passed @@ -240,7 +242,7 @@ await Connection.SendMessageAsync(new OperationMessage { protected override async ValueTask AuthorizeAsync(OperationMessage message) { if (_authenticationService != null) - await _authenticationService.AuthenticateAsync(Connection, SubProtocol, message); + await _authenticationService.AuthenticateAsync(new(Connection, SubProtocol, message, _authenticationSchemes)); bool success = await base.AuthorizeAsync(message); diff --git a/src/Tests.ApiApprovals/GraphQL.AspNetCore3.approved.txt b/src/Tests.ApiApprovals/GraphQL.AspNetCore3.approved.txt index fbd14bf..c3bf427 100644 --- a/src/Tests.ApiApprovals/GraphQL.AspNetCore3.approved.txt +++ b/src/Tests.ApiApprovals/GraphQL.AspNetCore3.approved.txt @@ -195,6 +195,7 @@ namespace GraphQL.AspNetCore3 } public interface IAuthorizationOptions { + System.Collections.Generic.IEnumerable AuthenticationSchemes { get; } bool AuthorizationRequired { get; } string? AuthorizedPolicy { get; } System.Collections.Generic.IEnumerable AuthorizedRoles { get; } @@ -273,6 +274,14 @@ namespace GraphQL.AspNetCore3.Errors } namespace GraphQL.AspNetCore3.WebSockets { + public class AuthenticationRequest + { + public AuthenticationRequest(GraphQL.AspNetCore3.WebSockets.IWebSocketConnection connection, string subProtocol, GraphQL.Transport.OperationMessage operationMessage, System.Collections.Generic.IEnumerable authenticationSchemes) { } + public System.Collections.Generic.IEnumerable AuthenticationSchemes { get; } + public GraphQL.AspNetCore3.WebSockets.IWebSocketConnection Connection { get; } + public GraphQL.Transport.OperationMessage OperationMessage { get; } + public string SubProtocol { get; } + } public abstract class BaseSubscriptionServer : GraphQL.AspNetCore3.WebSockets.IOperationMessageProcessor, System.IDisposable { protected BaseSubscriptionServer(GraphQL.AspNetCore3.WebSockets.IWebSocketConnection connection, GraphQL.AspNetCore3.WebSockets.GraphQLWebSocketOptions options, GraphQL.AspNetCore3.IAuthorizationOptions authorizationOptions) { } @@ -338,7 +347,7 @@ namespace GraphQL.AspNetCore3.WebSockets } public interface IWebSocketAuthenticationService { - System.Threading.Tasks.Task AuthenticateAsync(GraphQL.AspNetCore3.WebSockets.IWebSocketConnection connection, string subProtocol, GraphQL.Transport.OperationMessage operationMessage); + System.Threading.Tasks.Task AuthenticateAsync(GraphQL.AspNetCore3.WebSockets.AuthenticationRequest authenticationRequest); } public interface IWebSocketConnection : System.IDisposable { diff --git a/src/Tests/BuilderMethodTests.cs b/src/Tests/BuilderMethodTests.cs index c0cd92b..010f2c9 100644 --- a/src/Tests/BuilderMethodTests.cs +++ b/src/Tests/BuilderMethodTests.cs @@ -46,7 +46,7 @@ public void WebSocketAuthenticationService_Typed() services.AddGraphQL(b => b.AddWebSocketAuthentication()); using var provider = services.BuildServiceProvider(); var service = provider.GetRequiredService(); - Should.Throw(() => service.AuthenticateAsync(null!, null!, null!)); + Should.Throw(() => service.AuthenticateAsync(null!)); } [Fact] @@ -56,7 +56,7 @@ public void WebSocketAuthenticationService_Factory() services.AddGraphQL(b => b.AddWebSocketAuthentication(_ => new MyWebSocketAuthenticationService())); using var provider = services.BuildServiceProvider(); var service = provider.GetRequiredService(); - Should.Throw(() => service.AuthenticateAsync(null!, null!, null!)); + Should.Throw(() => service.AuthenticateAsync(null!)); } [Fact] @@ -66,12 +66,12 @@ public void WebSocketAuthenticationService_Instance() services.AddGraphQL(b => b.AddWebSocketAuthentication(new MyWebSocketAuthenticationService())); using var provider = services.BuildServiceProvider(); var service = provider.GetRequiredService(); - Should.Throw(() => service.AuthenticateAsync(null!, null!, null!)); + Should.Throw(() => service.AuthenticateAsync(null!)); } private class MyWebSocketAuthenticationService : IWebSocketAuthenticationService { - public Task AuthenticateAsync(IWebSocketConnection connection, string subProtocol, OperationMessage operationMessage) => throw new NotImplementedException(); + public Task AuthenticateAsync(AuthenticationRequest authenticationRequest) => throw new NotImplementedException(); } [Theory] diff --git a/src/Tests/ChatTests.cs b/src/Tests/ChatTests.cs index 3cfcfc6..3ca2710 100644 --- a/src/Tests/ChatTests.cs +++ b/src/Tests/ChatTests.cs @@ -313,7 +313,10 @@ public async Task Subscription_AuthorizationFailed(string subProtocol) { var builder = ConfigureBuilder(); var mockAuthorizationService = new Mock(MockBehavior.Strict); - mockAuthorizationService.Setup(x => x.AuthenticateAsync(It.IsAny(), subProtocol, It.IsAny())).Returns(Task.CompletedTask).Verifiable(); + mockAuthorizationService.Setup(x => x.AuthenticateAsync(It.IsAny())).Returns(request => { + request.SubProtocol.ShouldBe(subProtocol); + return Task.CompletedTask; + }).Verifiable(); builder.ConfigureServices(s => s.AddSingleton(mockAuthorizationService.Object)); using var app = new TestServer(builder); _options.AuthorizationRequired = true; @@ -353,8 +356,11 @@ public async Task Subscription_Authentication(string subProtocol, bool successfu { var builder = ConfigureBuilder(); var mockAuthorizationService = new Mock(MockBehavior.Strict); - mockAuthorizationService.Setup(x => x.AuthenticateAsync(It.IsAny(), subProtocol, It.IsAny())) - .Returns((connection, _, message) => { + mockAuthorizationService.Setup(x => x.AuthenticateAsync(It.IsAny())) + .Returns((authenticationRequest) => { + var connection = authenticationRequest.Connection; + var message = authenticationRequest.OperationMessage; + authenticationRequest.SubProtocol.ShouldBe(subProtocol); connection.HttpContext.User.Identity!.IsAuthenticated.ShouldBeFalse(); var serializer = connection.HttpContext.RequestServices.GetRequiredService(); var payload = serializer.ReadNode(message.Payload); diff --git a/src/Tests/JwtBearer/AspNetCore3JwtBearerExtensionsTests.cs b/src/Tests/JwtBearer/AspNetCore3JwtBearerExtensionsTests.cs new file mode 100644 index 0000000..2686cc0 --- /dev/null +++ b/src/Tests/JwtBearer/AspNetCore3JwtBearerExtensionsTests.cs @@ -0,0 +1,44 @@ +using GraphQL.AspNetCore3.JwtBearer; +using ServiceLifetime = GraphQL.DI.ServiceLifetime; + +namespace Tests.JwtBearer; + +public class AspNetCore3JwtBearerExtensionsTests +{ + [Fact] + public void AddJwtBearerAuthentication_ShouldAddJwtWebSocketAuthenticationService() + { + // Arrange + var serviceRegisterMock = new Mock(MockBehavior.Strict); + var graphQLBuilderMock = new Mock(MockBehavior.Strict); + + // Setup the Services property to return the mocked IServiceRegister + graphQLBuilderMock + .SetupGet(x => x.Services) + .Returns(serviceRegisterMock.Object); + + // Setup the Register method to accept specific parameters + serviceRegisterMock + .Setup(x => x.Register( + typeof(IWebSocketAuthenticationService), + typeof(JwtWebSocketAuthenticationService), + ServiceLifetime.Singleton, + false)) + .Returns(serviceRegisterMock.Object); + + // Act + var result = graphQLBuilderMock.Object.AddJwtBearerAuthentication(); + + // Assert + result.ShouldBe(graphQLBuilderMock.Object); + + // Verify that Register was called with the correct parameters + serviceRegisterMock.Verify( + x => x.Register( + typeof(IWebSocketAuthenticationService), + typeof(JwtWebSocketAuthenticationService), + ServiceLifetime.Singleton, + false), + Times.Once); + } +} diff --git a/src/Tests/JwtBearer/JwtWebSocketAuthenticationServiceTests.cs b/src/Tests/JwtBearer/JwtWebSocketAuthenticationServiceTests.cs new file mode 100644 index 0000000..8832450 --- /dev/null +++ b/src/Tests/JwtBearer/JwtWebSocketAuthenticationServiceTests.cs @@ -0,0 +1,361 @@ +using System.IdentityModel.Tokens.Jwt; +using System.Net; +using System.Net.Http.Headers; +using System.Net.WebSockets; +using System.Security.Claims; +using System.Security.Cryptography; +using System.Text.Json; +using Microsoft.AspNetCore.Authentication.Cookies; +using Microsoft.AspNetCore.Authentication.JwtBearer; +using Microsoft.IdentityModel.Tokens; +using RichardSzalay.MockHttp; + +namespace Tests.JwtBearer; + +public class JwtWebSocketAuthenticationServiceTests +{ + private string _issuer = "https://demo.identityserver.io"; + private string _audience = "testAudience"; + private readonly string _subject = "user123"; + private RSAParameters _rsaParameters; + private string? _jwtAccessToken; + private readonly MockHttpMessageHandler _oidcHttpMessageHandler = new(); + private readonly ISchema _schema; + + public JwtWebSocketAuthenticationServiceTests() + { + var query = new ObjectGraphType() { Name = "Query" }; + query.Field("test").Resolve(ctx => ctx.User?.FindFirst(ClaimTypes.NameIdentifier)?.Value); + _schema = new Schema { Query = query }; + } + + [Fact] + public async Task SuccessfulAuthentication() + { + CreateSignedToken(); + SetupOidcDiscovery(); + using var testServer = CreateTestServer(); + await TestGetAsync(testServer, isAuthenticated: true); + await TestWebSocketAsync(testServer, isAuthenticated: true); + } + + [Fact] + public async Task WrongKeys() + { + CreateSignedToken(); + SetupOidcDiscovery(); + using var testServer = CreateTestServer(); + CreateSignedToken(); // create new token with different keys + await TestGetAsync(testServer, isAuthenticated: false); + await TestWebSocketAsync(testServer, isAuthenticated: false); + } + + [Fact] + public async Task WrongIssuer() + { + CreateSignedToken(); + _issuer = "https://wrong.issuer"; + SetupOidcDiscovery(); + using var testServer = CreateTestServer(); + await TestGetAsync(testServer, isAuthenticated: false); + await TestWebSocketAsync(testServer, isAuthenticated: false); + } + + [Fact] + public async Task WrongAudience() + { + CreateSignedToken(); + _audience = "wrongAudience"; + SetupOidcDiscovery(); + using var testServer = CreateTestServer(); + await TestGetAsync(testServer, isAuthenticated: false); + await TestWebSocketAsync(testServer, isAuthenticated: false); + } + + [Fact] + public async Task Expired() + { + CreateSignedToken(expired: true); + SetupOidcDiscovery(); + using var testServer = CreateTestServer(); + await TestGetAsync(testServer, isAuthenticated: false); + await TestWebSocketAsync(testServer, isAuthenticated: false); + } + + [Fact] + public async Task NoDefaultScheme() + { + CreateSignedToken(); + SetupOidcDiscovery(); + using var testServer = CreateTestServer(defaultScheme: false); + await TestGetAsync(testServer, isAuthenticated: false); + await TestWebSocketAsync(testServer, isAuthenticated: false); + } + + [Fact] + public async Task NoDefaultSchemeSpecified() + { + CreateSignedToken(); + SetupOidcDiscovery(); + using var testServer = CreateTestServer(defaultScheme: false, specifyScheme: true); + await TestGetAsync(testServer, isAuthenticated: true); + await TestWebSocketAsync(testServer, isAuthenticated: true); + } + + [Fact] + public async Task CustomScheme() + { + CreateSignedToken(); + SetupOidcDiscovery(); + using var testServer = CreateTestServer(customScheme: true); + await TestGetAsync(testServer, isAuthenticated: true); + await TestWebSocketAsync(testServer, isAuthenticated: true); + } + + [Fact] + public async Task CustomNoDefaultScheme() + { + CreateSignedToken(); + SetupOidcDiscovery(); + using var testServer = CreateTestServer(customScheme: true, defaultScheme: false); + await TestGetAsync(testServer, isAuthenticated: false); + await TestWebSocketAsync(testServer, isAuthenticated: false); + } + + [Fact] + public async Task CustomNoDefaultSchemeSpecified() + { + CreateSignedToken(); + SetupOidcDiscovery(); + using var testServer = CreateTestServer(customScheme: true, defaultScheme: false, specifyScheme: true); + await TestGetAsync(testServer, isAuthenticated: true); + await TestWebSocketAsync(testServer, isAuthenticated: true); + } + + [Fact] + public async Task WrongScheme() + { + CreateSignedToken(); + SetupOidcDiscovery(); + using var testServer = CreateTestServer(specifyInvalidScheme: true); + await TestGetAsync(testServer, isAuthenticated: false); + await TestWebSocketAsync(testServer, isAuthenticated: false); + } + + [Fact] + public async Task MultipleSchemes() + { + CreateSignedToken(); + SetupOidcDiscovery(); + using var testServer = CreateTestServer(specifyInvalidScheme: true, specifyScheme: true, defaultScheme: false); + await TestGetAsync(testServer, isAuthenticated: true); + await TestWebSocketAsync(testServer, isAuthenticated: true); + } + + [Fact] + public async Task NoToken() + { + CreateSignedToken(); + SetupOidcDiscovery(); + using var testServer = CreateTestServer(); + _jwtAccessToken = null; + await TestGetAsync(testServer, isAuthenticated: false); + await TestWebSocketAsync(testServer, isAuthenticated: false); + } + + private async Task TestGetAsync(TestServer testServer, bool isAuthenticated) + { + // test an authenticated request + using var client = testServer.CreateClient(); + var request = new HttpRequestMessage(HttpMethod.Get, "/graphql?query={test}"); + if (_jwtAccessToken != null) + request.Headers.Authorization = new AuthenticationHeaderValue(JwtBearerDefaults.AuthenticationScheme, _jwtAccessToken); + using var response = await client.SendAsync(request); + if (isAuthenticated) { + response.StatusCode.ShouldBe(HttpStatusCode.OK); + var content = await response.Content.ReadAsStringAsync(); + content.ShouldBe($$$""" + {"data":{"test":"{{{_subject}}}"}} + """); + } else { + response.StatusCode.ShouldBe(HttpStatusCode.Unauthorized); + } + } + + private async Task TestWebSocketAsync(TestServer testServer, bool isAuthenticated) + { + // test an authenticated request + var webSocketClient = testServer.CreateWebSocketClient(); + webSocketClient.ConfigureRequest = request => { + request.Headers["Sec-WebSocket-Protocol"] = "graphql-ws"; + }; + webSocketClient.SubProtocols.Add("graphql-ws"); + using var webSocket = await webSocketClient.ConnectAsync(new Uri(testServer.BaseAddress, "/graphql"), default); + + // send CONNECTION_INIT + await webSocket.SendMessageAsync(new OperationMessage { + Type = "connection_init", + Payload = _jwtAccessToken != null ? new { + Authorization = "Bearer " + _jwtAccessToken, + } : null, + }); + + if (!isAuthenticated) { + // wait for CONNECTION_ERROR + var message1 = await webSocket.ReceiveMessageAsync(); + message1.Type.ShouldBe("connection_error"); + message1.Payload.ShouldBeOfType().ShouldBe("\"Access denied\""); // for the purposes of testing, this contains the raw JSON received for this JSON element. + + // wait for websocket closure + (await webSocket.ReceiveCloseAsync()).ShouldBe((WebSocketCloseStatus)4401); + return; + } + + // wait for CONNECTION_ACK + var message = await webSocket.ReceiveMessageAsync(); + message.Type.ShouldBe("connection_ack"); + + // send start + await webSocket.SendMessageAsync(new OperationMessage { + Id = "1", + Type = "start", + Payload = new GraphQLRequest { + Query = "{test}", + }, + }); + + // wait for data + message = await webSocket.ReceiveMessageAsync(); + message.Type.ShouldBe("data"); + message.Id.ShouldBe("1"); + message.Payload.ShouldBe($$$""" + {"data":{"test":"{{{_subject}}}"}} + """); + } + + /// + /// Creates a test server with JWT bearer authentication. + /// Uses the currently configured and . + /// + private TestServer CreateTestServer(bool defaultScheme = true, bool customScheme = false, bool specifyScheme = false, bool specifyInvalidScheme = false) + { + return new TestServer(new WebHostBuilder() + .ConfigureServices(services => { + var authBuilder = services.AddAuthentication(defaultScheme ? customScheme ? "Custom" : JwtBearerDefaults.AuthenticationScheme : ""); + if (specifyInvalidScheme) { + authBuilder.AddCookie(); + } + authBuilder.AddJwtBearer(customScheme ? "Custom" : JwtBearerDefaults.AuthenticationScheme, o => { + o.Authority = _issuer; + o.Audience = _audience; + o.BackchannelHttpHandler = _oidcHttpMessageHandler; + }); + services.AddGraphQL(b => b + .AddSchema(_schema) + .AddSystemTextJson() + .AddJwtBearerAuthentication() + ); +#if NET48 || NETCOREAPP2_1 + services.AddHostApplicationLifetime(); +#endif + }) + .Configure(app => { + app.UseWebSockets(); + app.UseAuthentication(); + app.UseGraphQL(configureMiddleware: o => { + o.AuthorizationRequired = true; + o.CsrfProtectionEnabled = false; + if (specifyInvalidScheme) { + o.AuthenticationSchemes.Add(CookieAuthenticationDefaults.AuthenticationScheme); + } + if (specifyScheme) { + o.AuthenticationSchemes.Add(customScheme ? "Custom" : JwtBearerDefaults.AuthenticationScheme); + } + }); + })); + } + + /// + /// Configures the mock HTTP message handler to respond to OIDC discovery requests. + /// Uses the currently configured and . + /// + private void SetupOidcDiscovery() + { + // Comprehensive OIDC discovery document + var discoveryDocument = new { + issuer = _issuer, + authorization_endpoint = $"{_issuer}/connect/authorize", + token_endpoint = $"{_issuer}/connect/token", + userinfo_endpoint = $"{_issuer}/connect/userinfo", + jwks_uri = $"{_issuer}/.well-known/jwks.json", + response_types_supported = new[] { "code", "token", "id_token" }, + subject_types_supported = new[] { "public" }, + id_token_signing_alg_values_supported = new[] { "RS256" } + }; + + // mock the discovery endpoint + _oidcHttpMessageHandler + .When($"{_issuer}/.well-known/openid-configuration") + .Respond("application/json", JsonSerializer.Serialize(discoveryDocument)); + + // Create JWKS based on the RSA public key + var jwk = new { + keys = new[] + { + new + { + kty = "RSA", + use = "sig", + alg = "RS256", + n = Base64UrlEncoder.Encode(_rsaParameters.Modulus), + e = Base64UrlEncoder.Encode(_rsaParameters.Exponent) + } + } + }; + + // mock the JWKS endpoint + _oidcHttpMessageHandler + .When($"{_issuer}/.well-known/jwks.json") + .Respond("application/json", JsonSerializer.Serialize(jwk)); + + // throw for all other requests + _oidcHttpMessageHandler + .When("*") + .Respond(request => { + throw new NotImplementedException($"No handler configured for {request.RequestUri}"); + }); + } + + /// + /// Creates a new RSA key pair and a signed JWT token. + /// Uses the currently configured , , and . + /// Overwrites the and fields. + /// + private void CreateSignedToken(bool expired = false) + { + using var rsa = RSA.Create(2048); + var rsaParameters = rsa.ExportParameters(true); + var key = new RsaSecurityKey(rsaParameters); + var signingCredentials = new SigningCredentials(key, SecurityAlgorithms.RsaSha256); + + var now = DateTime.UtcNow; + if (expired) { + now = now.AddMinutes(-10); + } + var tokenDescriptor = new SecurityTokenDescriptor { + Issuer = _issuer, + Audience = _audience, + Subject = new ClaimsIdentity([new Claim(JwtRegisteredClaimNames.Sub, _subject)]), // Subject (user ID) + Expires = now.Add(TimeSpan.FromMinutes(5)), + IssuedAt = now, + NotBefore = now, + SigningCredentials = signingCredentials + }; + var tokenHandler = new JwtSecurityTokenHandler(); + var token = tokenHandler.CreateToken(tokenDescriptor); + var tokenStr = tokenHandler.WriteToken(token); + _rsaParameters = rsaParameters; + _jwtAccessToken = tokenStr; + } +} + diff --git a/src/Tests/Tests.csproj b/src/Tests/Tests.csproj index cd5c035..6da495c 100644 --- a/src/Tests/Tests.csproj +++ b/src/Tests/Tests.csproj @@ -14,6 +14,7 @@ + @@ -27,6 +28,7 @@ +