Skip to content

Commit

Permalink
feat: Per-request Key authorization #124 (#125)
Browse files Browse the repository at this point in the history
Co-authored-by: Aliaksandr Stsiapanay <aliaksandr_stsiapanay@epam.com>
  • Loading branch information
astsiapanay and astsiapanay authored Jan 12, 2024
1 parent 2e124ad commit 7075561
Show file tree
Hide file tree
Showing 18 changed files with 340 additions and 444 deletions.
6 changes: 4 additions & 2 deletions src/main/java/com/epam/aidial/core/AiDial.java
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
import com.epam.aidial.core.log.GfLogStore;
import com.epam.aidial.core.log.LogStore;
import com.epam.aidial.core.security.AccessTokenValidator;
import com.epam.aidial.core.security.ApiKeyStore;
import com.epam.aidial.core.security.EncryptionService;
import com.epam.aidial.core.storage.BlobStorage;
import com.epam.aidial.core.upstream.UpstreamBalancer;
Expand Down Expand Up @@ -65,7 +66,8 @@ void start() throws Exception {
vertx = Vertx.vertx(vertxOptions);
client = vertx.createHttpClient(new HttpClientOptions(settings("client")));

ConfigStore configStore = new FileConfigStore(vertx, settings("config"));
ApiKeyStore apiKeyStore = new ApiKeyStore();
ConfigStore configStore = new FileConfigStore(vertx, settings("config"), apiKeyStore);
LogStore logStore = new GfLogStore(vertx);
RateLimiter rateLimiter = new RateLimiter();
UpstreamBalancer upstreamBalancer = new UpstreamBalancer();
Expand All @@ -75,7 +77,7 @@ void start() throws Exception {
storage = new BlobStorage(storageConfig);
}
EncryptionService encryptionService = new EncryptionService(Json.decodeValue(settings("encryption").toBuffer(), Encryption.class));
Proxy proxy = new Proxy(vertx, client, configStore, logStore, rateLimiter, upstreamBalancer, accessTokenValidator, storage, encryptionService);
Proxy proxy = new Proxy(vertx, client, configStore, logStore, rateLimiter, upstreamBalancer, accessTokenValidator, storage, encryptionService, apiKeyStore);

server = vertx.createHttpServer(new HttpServerOptions(settings("server"))).requestHandler(proxy);
open(server, HttpServer::listen);
Expand Down
23 changes: 14 additions & 9 deletions src/main/java/com/epam/aidial/core/Proxy.java
Original file line number Diff line number Diff line change
@@ -1,20 +1,22 @@
package com.epam.aidial.core;

import com.epam.aidial.core.config.ApiKeyData;
import com.epam.aidial.core.config.Config;
import com.epam.aidial.core.config.ConfigStore;
import com.epam.aidial.core.config.Key;
import com.epam.aidial.core.controller.Controller;
import com.epam.aidial.core.controller.ControllerSelector;
import com.epam.aidial.core.limiter.RateLimiter;
import com.epam.aidial.core.log.LogStore;
import com.epam.aidial.core.security.AccessTokenValidator;
import com.epam.aidial.core.security.ApiKeyStore;
import com.epam.aidial.core.security.EncryptionService;
import com.epam.aidial.core.security.ExtractedClaims;
import com.epam.aidial.core.storage.BlobStorage;
import com.epam.aidial.core.upstream.UpstreamBalancer;
import com.epam.aidial.core.util.HttpStatus;
import com.epam.aidial.core.util.ProxyUtil;
import io.opentelemetry.api.trace.Span;
import io.opentelemetry.api.trace.SpanContext;
import io.vertx.core.Future;
import io.vertx.core.Handler;
import io.vertx.core.Vertx;
Expand Down Expand Up @@ -62,6 +64,7 @@ public class Proxy implements Handler<HttpServerRequest> {
private final AccessTokenValidator tokenValidator;
private final BlobStorage storage;
private final EncryptionService encryptionService;
private final ApiKeyStore apiKeyStore;

@Override
public void handle(HttpServerRequest request) {
Expand Down Expand Up @@ -116,27 +119,29 @@ private void handleRequest(HttpServerRequest request) {
Config config = configStore.load();
String apiKey = request.headers().get(HEADER_API_KEY);
String authorization = request.getHeader(HttpHeaders.AUTHORIZATION);
String traceId = Span.current().getSpanContext().getTraceId();
SpanContext spanContext = Span.current().getSpanContext();
String traceId = spanContext.getTraceId();
String spanId = spanContext.getSpanId();
log.debug("Authorization header: {}", authorization);
Key key;
ApiKeyData apiKeyData;
if (apiKey == null && authorization == null) {
respond(request, HttpStatus.UNAUTHORIZED, "At least API-KEY or Authorization header must be provided");
return;
} else if (apiKey != null && authorization != null && !apiKey.equals(extractTokenFromHeader(authorization))) {
respond(request, HttpStatus.BAD_REQUEST, "Either API-KEY or Authorization header must be provided but not both");
return;
} else if (apiKey != null) {
key = config.getKeys().get(apiKey);
apiKeyData = apiKeyStore.getApiKeyData(apiKey);
// Special case handling. OpenAI client sends both API key and Auth headers even if a caller sets just API Key only
// Auth header is set to the same value as API Key header
// ignore auth header in this case
authorization = null;
if (key == null) {
if (apiKeyData == null) {
respond(request, HttpStatus.UNAUTHORIZED, "Unknown api key");
return;
}
} else {
key = null;
apiKeyData = new ApiKeyData();
}

request.pause();
Expand All @@ -145,7 +150,7 @@ private void handleRequest(HttpServerRequest request) {
extractedClaims.onComplete(result -> {
try {
if (result.succeeded()) {
onExtractClaimsSuccess(result.result(), config, request, key, traceId);
onExtractClaimsSuccess(result.result(), config, request, apiKeyData, traceId, spanId);
} else {
onExtractClaimsFailure(result.cause(), request);
}
Expand All @@ -163,8 +168,8 @@ private void onExtractClaimsFailure(Throwable error, HttpServerRequest request)
}

private void onExtractClaimsSuccess(ExtractedClaims extractedClaims, Config config,
HttpServerRequest request, Key key, String traceId) throws Exception {
ProxyContext context = new ProxyContext(config, request, key, extractedClaims, traceId);
HttpServerRequest request, ApiKeyData apiKeyData, String traceId, String spanId) throws Exception {
ProxyContext context = new ProxyContext(config, request, apiKeyData, extractedClaims, traceId, spanId);
Controller controller = ControllerSelector.select(this, context);
controller.handle();
}
Expand Down
29 changes: 21 additions & 8 deletions src/main/java/com/epam/aidial/core/ProxyContext.java
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
package com.epam.aidial.core;

import com.epam.aidial.core.config.ApiKeyData;
import com.epam.aidial.core.config.Config;
import com.epam.aidial.core.config.Deployment;
import com.epam.aidial.core.config.Key;
Expand Down Expand Up @@ -32,6 +33,9 @@ public class ProxyContext {
private final HttpServerRequest request;
private final HttpServerResponse response;
private final String traceId;
private final ApiKeyData apiKeyData;
private final String spanId;
private final String parentSpanId;

private Deployment deployment;
private String userSub;
Expand All @@ -50,25 +54,34 @@ public class ProxyContext {
private long proxyConnectTimestamp;
private long proxyResponseTimestamp;
private long responseBodyTimestamp;
// the project belongs to API key which initiated request
private String originalProject;
private ExtractedClaims extractedClaims;

public ProxyContext(Config config, HttpServerRequest request, Key key, ExtractedClaims extractedClaims, String traceId) {
public ProxyContext(Config config, HttpServerRequest request, ApiKeyData apiKeyData, ExtractedClaims extractedClaims, String traceId, String spanId) {
this.config = config;
this.key = key;
if (key != null) {
originalProject = key.getProject();
}
this.apiKeyData = apiKeyData;
this.request = request;
this.response = request.response();
this.requestTimestamp = System.currentTimeMillis();
this.key = apiKeyData.getOriginalKey();
if (apiKeyData.getPerRequestKey() != null) {
initExtractedClaims(apiKeyData.getExtractedClaims());
this.traceId = apiKeyData.getTraceId();
this.parentSpanId = apiKeyData.getSpanId();
} else {
initExtractedClaims(extractedClaims);
this.traceId = traceId;
this.parentSpanId = null;
}
this.spanId = spanId;
}

private void initExtractedClaims(ExtractedClaims extractedClaims) {
this.extractedClaims = extractedClaims;
if (extractedClaims != null) {
this.userRoles = extractedClaims.userRoles();
this.userHash = extractedClaims.userHash();
this.userSub = extractedClaims.sub();
}
this.traceId = traceId;
}

public Future<Void> respond(HttpStatus status) {
Expand Down
28 changes: 28 additions & 0 deletions src/main/java/com/epam/aidial/core/config/ApiKeyData.java
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
package com.epam.aidial.core.config;

import com.epam.aidial.core.security.ExtractedClaims;
import lombok.Data;

import java.util.HashSet;
import java.util.Set;

@Data
public class ApiKeyData {
private String perRequestKey;
private Key originalKey;
private ExtractedClaims extractedClaims;
private String traceId;
private String spanId;
private Set<String> attachedFiles = new HashSet<>();

public ApiKeyData() {
}

public static ApiKeyData from(ApiKeyData another) {
ApiKeyData data = new ApiKeyData();
data.originalKey = another.originalKey;
data.extractedClaims = another.extractedClaims;
data.traceId = another.traceId;
return data;
}
}
6 changes: 0 additions & 6 deletions src/main/java/com/epam/aidial/core/config/Deployment.java
Original file line number Diff line number Diff line change
Expand Up @@ -18,13 +18,7 @@ public abstract class Deployment {
* Authorization token is NOT forwarded by default.
*/
private boolean forwardAuthToken = false;
/**
* Forward Http header with API key when request is sent to deployment.
* API key is forwarded by default.
*/
private boolean forwardApiKey = true;
private Features features;
private List<String> inputAttachmentTypes;
private Integer maxInputAttachments;
private String apiKey;
}
41 changes: 5 additions & 36 deletions src/main/java/com/epam/aidial/core/config/FileConfigStore.java
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
package com.epam.aidial.core.config;

import com.epam.aidial.core.security.ApiKeyStore;
import com.epam.aidial.core.util.ProxyUtil;
import com.fasterxml.jackson.databind.JsonNode;
import com.google.common.annotations.VisibleForTesting;
import io.vertx.core.Vertx;
import io.vertx.core.json.JsonObject;
import lombok.SneakyThrows;
Expand All @@ -12,21 +12,20 @@
import java.io.FileInputStream;
import java.io.FileNotFoundException;
import java.io.InputStream;
import java.util.HashMap;
import java.util.Map;

import static com.epam.aidial.core.config.Config.ASSISTANT;
import static com.epam.aidial.core.security.ApiKeyGenerator.generateKey;


@Slf4j
public final class FileConfigStore implements ConfigStore {

private final String[] paths;
private volatile Config config;
private final Map<String, String> deploymentKeys = new HashMap<>();
private final ApiKeyStore apiKeyStore;

public FileConfigStore(Vertx vertx, JsonObject settings) {
public FileConfigStore(Vertx vertx, JsonObject settings, ApiKeyStore apiKeyStore) {
this.apiKeyStore = apiKeyStore;
this.paths = settings.getJsonArray("files")
.stream().map(path -> (String) path).toArray(String[]::new);

Expand Down Expand Up @@ -55,7 +54,6 @@ private void load(boolean fail) {
String name = entry.getKey();
Model model = entry.getValue();
model.setName(name);
associateDeploymentWithApiKey(config, model);
}

for (Map.Entry<String, Addon> entry : config.getAddons().entrySet()) {
Expand All @@ -69,7 +67,6 @@ private void load(boolean fail) {
String name = entry.getKey();
Assistant assistant = entry.getValue();
assistant.setName(name);
associateDeploymentWithApiKey(config, assistant);

if (assistant.getEndpoint() == null) {
assistant.setEndpoint(assistants.getEndpoint());
Expand All @@ -83,22 +80,16 @@ private void load(boolean fail) {
baseAssistant.setName(ASSISTANT);
baseAssistant.setEndpoint(assistants.getEndpoint());
baseAssistant.setFeatures(assistants.getFeatures());
associateDeploymentWithApiKey(config, baseAssistant);
assistants.getAssistants().put(ASSISTANT, baseAssistant);
}

for (Map.Entry<String, Application> entry : config.getApplications().entrySet()) {
String name = entry.getKey();
Application application = entry.getValue();
application.setName(name);
associateDeploymentWithApiKey(config, application);
}

for (Map.Entry<String, Key> entry : config.getKeys().entrySet()) {
String key = entry.getKey();
Key value = entry.getValue();
value.setKey(key);
}
apiKeyStore.addProjectKeys(config.getKeys());

for (Map.Entry<String, Role> entry : config.getRoles().entrySet()) {
String name = entry.getKey();
Expand All @@ -116,28 +107,6 @@ private void load(boolean fail) {
}
}

@VisibleForTesting
void associateDeploymentWithApiKey(Config config, Deployment deployment) {
String apiKey = deployment.getApiKey();
String deploymentName = deployment.getName();
if (apiKey == null) {
apiKey = deploymentKeys.computeIfAbsent(deploymentName, k -> generateKey());
} else {
deploymentKeys.put(deploymentName, apiKey);
}
Map<String, Key> keys = config.getKeys();
while (keys.containsKey(apiKey) && !deploymentName.equals(keys.get(apiKey).getProject())) {
log.warn("duplicate API key is found for deployment {}. Trying to generate a new one", deployment.getName());
apiKey = generateKey();
}
deployment.setApiKey(apiKey);
Key key = new Key();
key.setKey(apiKey);
key.setProject(deployment.getName());
config.getKeys().put(apiKey, key);
deploymentKeys.put(deployment.getName(), apiKey);
}

private Config loadConfig() throws Exception {
JsonNode tree = null;

Expand Down
Loading

0 comments on commit 7075561

Please sign in to comment.