Skip to content

Commit

Permalink
Refactor oidc logic into UserHandlers
Browse files Browse the repository at this point in the history
  • Loading branch information
Jeidnx committed Oct 26, 2023
1 parent e7f2187 commit c1fde37
Show file tree
Hide file tree
Showing 2 changed files with 159 additions and 171 deletions.
159 changes: 6 additions & 153 deletions src/main/java/me/kavin/piped/server/ServerLauncher.java
Original file line number Diff line number Diff line change
Expand Up @@ -2,13 +2,6 @@

import com.fasterxml.jackson.core.JsonProcessingException;
import com.fasterxml.jackson.databind.JsonNode;
import com.nimbusds.jwt.JWTClaimsSet;
import com.nimbusds.oauth2.sdk.*;
import com.nimbusds.oauth2.sdk.auth.ClientAuthentication;
import com.nimbusds.oauth2.sdk.auth.ClientSecretBasic;
import com.nimbusds.oauth2.sdk.id.State;
import com.nimbusds.openid.connect.sdk.*;
import com.nimbusds.openid.connect.sdk.claims.UserInfo;
import com.rometools.rome.feed.synd.SyndFeed;
import com.rometools.rome.io.SyndFeedInput;
import io.activej.config.Config;
Expand Down Expand Up @@ -44,11 +37,8 @@
import java.io.ByteArrayInputStream;
import java.net.InetSocketAddress;
import java.net.URI;
import java.util.HashMap;
import java.util.Map;
import java.util.Objects;
import java.util.concurrent.Executor;
import java.util.concurrent.TimeUnit;

import static io.activej.config.converter.ConfigConverters.ofInetSocketAddress;
import static io.activej.http.HttpHeaders.*;
Expand All @@ -61,7 +51,6 @@ public class ServerLauncher extends MultithreadedHttpServerLauncher {

private static final HttpHeader FILE_NAME = HttpHeaders.of("x-file-name");
private static final HttpHeader LAST_ETAG = HttpHeaders.of("x-last-etag");
private static final Map<String, OidcData> PENDING_OIDC = new HashMap<>();

@Provides
Executor executor() {
Expand Down Expand Up @@ -291,137 +280,12 @@ AsyncServlet mainServlet(Executor executor) {
if (provider == null)
return HttpResponse.ofCode(500).withHtml("Can't find the provider on the server");

URI callback = new URI(Constants.PUBLIC_URL + "/oidc/" + provider.name + "/callback");

switch (function) {
case "login" -> {
String redirectUri = request.getQueryParameter("redirect");

if (StringUtils.isBlank(redirectUri)) {
return HttpResponse.ofCode(400).withHtml("redirect is a required parameter");
}

OidcData data = new OidcData(redirectUri);
String state = data.getState();

PENDING_OIDC.put(state, data);

AuthenticationRequest oidcRequest = new AuthenticationRequest.Builder(
new ResponseType("code"),
new Scope("openid"),
provider.clientID, callback).endpointURI(provider.authUri)
.state(new State(state)).nonce(data.nonce).build();

if (redirectUri.equals(Constants.FRONTEND_URL + "/login")) {
return HttpResponse.redirect302(oidcRequest.toURI().toString());
}
return HttpResponse.ok200().withHtml(
"<!DOCTYPE html><html style=\"color-scheme: dark light;\"><body>" +
"<h3>Warning:</h3> You are trying to give <pre style=\"font-size: 1.2rem;\">" +
redirectUri +
"</pre> access to your Piped account. If you wish to continue click " +
"<a style=\"text-decoration: underline;color: inherit;\"href=\"" +
oidcRequest.toURI().toString() +
"\">here</a></body></html>");
}
case "callback" -> {
ClientAuthentication clientAuth = new ClientSecretBasic(provider.clientID, provider.clientSecret);

AuthenticationSuccessResponse sr = parseOidcUri(URI.create(request.getFullUrl()));

OidcData data = PENDING_OIDC.get(sr.getState().toString());
if (data == null) {
return HttpResponse.ofCode(400).withHtml(
"Your oidc provider sent invalid state data. Try again or contact your oidc admin"
);
}
AuthorizationCode code = sr.getAuthorizationCode();
AuthorizationGrant codeGrant = new AuthorizationCodeGrant(code, callback);


TokenRequest tr = new TokenRequest(provider.tokenUri, clientAuth, codeGrant);
OIDCTokenResponse tokenResponse = (OIDCTokenResponse) OIDCTokenResponseParser.parse(tr.toHTTPRequest().send());

if (!tokenResponse.indicatesSuccess()) {
TokenErrorResponse errorResponse = tokenResponse.toErrorResponse();
return HttpResponse.ofCode(500).withHtml("Failure while trying to request token:\n\n" + errorResponse.getErrorObject().getDescription());
}

OIDCTokenResponse successResponse = tokenResponse.toSuccessResponse();

if (data.isInvalidNonce((String) successResponse.getOIDCTokens().getIDToken().getJWTClaimsSet().getClaim("nonce"))) {
return HttpResponse.ofCode(400).withHtml(
"Your oidc provider sent an invalid nonce. Try again or contact your oidc admin"
);
}

UserInfoRequest ur = new UserInfoRequest(provider.userinfoUri, successResponse.getOIDCTokens().getBearerAccessToken());
UserInfoResponse userInfoResponse = UserInfoResponse.parse(ur.toHTTPRequest().send());

if (!userInfoResponse.indicatesSuccess()) {
System.out.println(userInfoResponse.toErrorResponse().getErrorObject().getCode());
System.out.println(userInfoResponse.toErrorResponse().getErrorObject().getDescription());
return HttpResponse.ofCode(500).withHtml("Failed to query userInfo:\n\n" + userInfoResponse.toErrorResponse().getErrorObject().getDescription());
}

UserInfo userInfo = userInfoResponse.toSuccessResponse().getUserInfo();

String sessionId = UserHandlers.oidcCallbackResponse(provider.name, userInfo.getSubject().toString());
return HttpResponse.redirect302(data.data + "?session=" + sessionId);
}
case "delete" -> {
ClientAuthentication clientAuth = new ClientSecretBasic(provider.clientID, provider.clientSecret);

AuthenticationSuccessResponse sr = parseOidcUri(URI.create(request.getFullUrl()));

OidcData data = UserHandlers.PENDING_OIDC.get(sr.getState().toString());
if (data == null) {
return HttpResponse.ofCode(400).withHtml(
"Your oidc provider sent invalid state data. Try again or contact your oidc admin"
);
}

long start = Long.parseLong(data.data.split("\\|")[1]);
String session = data.data.split("\\|")[0];

AuthorizationCode code = sr.getAuthorizationCode();
AuthorizationGrant codeGrant = new AuthorizationCodeGrant(code, new URI(Constants.PUBLIC_URL + request.getPath()));


TokenRequest tr = new TokenRequest(provider.tokenUri, clientAuth, codeGrant);
TokenResponse tokenResponse = OIDCTokenResponseParser.parse(tr.toHTTPRequest().send());

if (!tokenResponse.indicatesSuccess()) {
TokenErrorResponse errorResponse = tokenResponse.toErrorResponse();
return HttpResponse.ofCode(500).withHtml("Failure while trying to request token:\n\n" + errorResponse.getErrorObject().getDescription());
}

OIDCTokenResponse successResponse = (OIDCTokenResponse) tokenResponse.toSuccessResponse();

JWTClaimsSet claims = successResponse.getOIDCTokens().getIDToken().getJWTClaimsSet();

if (data.isInvalidNonce((String) claims.getClaim("nonce"))) {
return HttpResponse.ofCode(400).withHtml(
"Your oidc provider sent an invalid nonce. Please try again or contact your oidc admin."
);
}

long authTime = (long) claims.getClaim("auth_time");

if (authTime < start) {
return HttpResponse.ofCode(500).withHtml(
"Your oidc provider didn't verify your identity. Please try again or contact your oidc admin."
);
}

return HttpResponse.redirect302(Constants.FRONTEND_URL + "/preferences?deleted=" + UserHandlers.deleteOidcUserResponse(session));
}
default -> {
return HttpResponse.ofCode(500).withHtml("Invalid function `" + function + "`");
}
}


return switch (function) {
case "login" -> UserHandlers.oidcLoginResponse(provider, request.getQueryParameter("redirect"));
case "callback" -> UserHandlers.oidcCallbackResponse(provider, URI.create(request.getFullUrl()));
case "delete" -> UserHandlers.oidcDeleteResponse(provider, URI.create(request.getFullUrl()));
default -> HttpResponse.ofCode(500).withHtml("Invalid function `" + function + "`");
};
} catch (Exception e) {
return getErrorResponse(e, request.getPath());
}
Expand Down Expand Up @@ -680,17 +544,6 @@ private static OidcProvider getOidcProvider(String provider) {
return null;
}

private static AuthenticationSuccessResponse parseOidcUri(URI uri) throws Exception {
AuthenticationResponse response = AuthenticationResponseParser.parse(uri);

if (response instanceof AuthenticationErrorResponse) {
// The OpenID provider returned an error
System.err.println(response.toErrorResponse().getErrorObject());
throw new Exception(response.toErrorResponse().getErrorObject().toString());
}
return response.toSuccessResponse();
}

private static String[] getArray(String s) {

if (s == null) {
Expand Down
Loading

0 comments on commit c1fde37

Please sign in to comment.