Skip to content

Commit

Permalink
feat: Implement per request API keys through stateful layer #183 (#312)
Browse files Browse the repository at this point in the history
Co-authored-by: Aliaksandr Stsiapanay <aliaksandr_stsiapanay@epam.com>
  • Loading branch information
astsiapanay and astsiapanay authored Apr 23, 2024
1 parent 3a8f503 commit 93ddf0f
Show file tree
Hide file tree
Showing 10 changed files with 297 additions and 87 deletions.
5 changes: 3 additions & 2 deletions src/main/java/com/epam/aidial/core/AiDial.java
Original file line number Diff line number Diff line change
Expand Up @@ -93,8 +93,6 @@ void start() throws Exception {
vertx = Vertx.vertx(vertxOptions);
client = vertx.createHttpClient(new HttpClientOptions(settings("client")));

ApiKeyStore apiKeyStore = new ApiKeyStore();
ConfigStore configStore = new FileConfigStore(vertx, settings("config"), apiKeyStore);
LogStore logStore = new GfLogStore(vertx);
UpstreamBalancer upstreamBalancer = new UpstreamBalancer();

Expand All @@ -121,6 +119,9 @@ void start() throws Exception {
AccessService accessService = new AccessService(encryptionService, shareService, publicationService, settings("access"));
RateLimiter rateLimiter = new RateLimiter(vertx, resourceService);

ApiKeyStore apiKeyStore = new ApiKeyStore(resourceService, lockService, vertx);
ConfigStore configStore = new FileConfigStore(vertx, settings("config"), apiKeyStore);

proxy = new Proxy(vertx, client, configStore, logStore,
rateLimiter, upstreamBalancer, accessTokenValidator,
storage, encryptionService, apiKeyStore, tokenStatsTracker, resourceService, invitationService,
Expand Down
46 changes: 28 additions & 18 deletions src/main/java/com/epam/aidial/core/Proxy.java
Original file line number Diff line number Diff line change
Expand Up @@ -147,43 +147,53 @@ private void handleRequest(HttpServerRequest request) {
String traceId = spanContext.getTraceId();
String spanId = spanContext.getSpanId();
log.debug("Authorization header: {}", authorization);
ApiKeyData apiKeyData;
Future<AuthorizationResult> authorizationResultFuture;

request.pause();
if (apiKey == null && authorization == null) {
respond(request, HttpStatus.UNAUTHORIZED, "At least API-KEY or Authorization header must be provided");
return;
} else if (apiKey != null && authorization != null && !apiKey.equals(extractTokenFromHeader(authorization))) {
respond(request, HttpStatus.BAD_REQUEST, "Either API-KEY or Authorization header must be provided but not both");
return;
} else if (apiKey != null) {
apiKeyData = apiKeyStore.getApiKeyData(apiKey);
// Special case handling. OpenAI client sends both API key and Auth headers even if a caller sets just API Key only
// Auth header is set to the same value as API Key header
// ignore auth header in this case
authorization = null;
if (apiKeyData == null) {
respond(request, HttpStatus.UNAUTHORIZED, "Unknown api key");
return;
}
authorizationResultFuture = apiKeyStore.getApiKeyData(apiKey)
.onFailure(error -> onGettingApiKeyDataFailure(error, request))
.compose(apiKeyData -> {
if (apiKeyData == null) {
String errorMessage = "Unknown api key";
respond(request, HttpStatus.UNAUTHORIZED, errorMessage);
return Future.failedFuture(errorMessage);
}
return Future.succeededFuture(new AuthorizationResult(apiKeyData, null));
});
} else {
apiKeyData = new ApiKeyData();
authorizationResultFuture = tokenValidator.extractClaims(authorization)
.onFailure(error -> onExtractClaimsFailure(error, request))
.map(extractedClaims -> new AuthorizationResult(new ApiKeyData(), extractedClaims));
}

request.pause();
Future<ExtractedClaims> extractedClaims = tokenValidator.extractClaims(authorization);

extractedClaims.onFailure(error -> onExtractClaimsFailure(error, request))
.compose(claims -> onExtractClaimsSuccess(claims, config, request, apiKeyData, traceId, spanId))
authorizationResultFuture.compose(result -> processAuthorizationResult(result.extractedClaims, config, request, result.apiKeyData, traceId, spanId))
.onComplete(ignore -> request.resume());
}

private record AuthorizationResult(ApiKeyData apiKeyData, ExtractedClaims extractedClaims) {

}

private void onExtractClaimsFailure(Throwable error, HttpServerRequest request) {
log.error("Can't extract claims from authorization header", error);
respond(request, HttpStatus.UNAUTHORIZED, "Bad Authorization header");
}

private void onGettingApiKeyDataFailure(Throwable error, HttpServerRequest request) {
log.error("Can't find data associated with API key", error);
respond(request, HttpStatus.UNAUTHORIZED, "Bad Authorization header");
}

@SneakyThrows
private Future<?> onExtractClaimsSuccess(ExtractedClaims extractedClaims, Config config,
HttpServerRequest request, ApiKeyData apiKeyData, String traceId, String spanId) {
private Future<?> processAuthorizationResult(ExtractedClaims extractedClaims, Config config,
HttpServerRequest request, ApiKeyData apiKeyData, String traceId, String spanId) {
Future<?> future;
try {
ProxyContext context = new ProxyContext(config, request, apiKeyData, extractedClaims, traceId, spanId);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -125,21 +125,28 @@ private void handleRateLimitSuccess(String deploymentId) {
return;
}

ApiKeyData proxyApiKeyData = new ApiKeyData();
context.setProxyApiKeyData(proxyApiKeyData);
ApiKeyData.initFromContext(proxyApiKeyData, context);
proxy.getApiKeyStore().assignApiKey(proxyApiKeyData);

proxy.getTokenStatsTracker().startSpan(context);

context.getRequest().body()
.onSuccess(body -> proxy.getVertx().executeBlocking(() -> {
// run setting up api key data in the worker thread
setupProxyApiKeyData();
handleRequestBody(body);
return null;
}))
}).onFailure(this::handleError))
.onFailure(this::handleRequestBodyError);
}

/**
* The method uses blocking calls and should not be used in the event loop thread.
*/
private void setupProxyApiKeyData() {
ApiKeyData proxyApiKeyData = new ApiKeyData();
context.setProxyApiKeyData(proxyApiKeyData);
ApiKeyData.initFromContext(proxyApiKeyData, context);
proxy.getApiKeyStore().assignPerRequestApiKey(proxyApiKeyData);
}

private void handleRateLimitHit(RateLimitResult result) {
// Returning an error similar to the Azure format.
ErrorData rateLimitError = new ErrorData();
Expand Down Expand Up @@ -203,6 +210,9 @@ void handleRequestBody(Buffer requestBody) {

try {
ProxyUtil.collectAttachedFiles(tree, this::processAttachedFile);
// update api key data after processing attachments
ApiKeyData destApiKeyData = context.getProxyApiKeyData();
proxy.getApiKeyStore().updatePerRequestApiKeyData(destApiKeyData);
} catch (HttpException e) {
respond(e.getStatus(), e.getMessage());
log.warn("Can't collect attached files. Trace: {}. Span: {}. Error: {}",
Expand Down Expand Up @@ -619,8 +629,14 @@ private Future<Void> respond(HttpStatus status, Object result) {

private void finalizeRequest() {
proxy.getTokenStatsTracker().endSpan(context);
if (context.getProxyApiKeyData() != null) {
proxy.getApiKeyStore().invalidateApiKey(context.getProxyApiKeyData());
ApiKeyData proxyApiKeyData = context.getProxyApiKeyData();
if (proxyApiKeyData != null) {
proxy.getApiKeyStore().invalidatePerRequestApiKey(proxyApiKeyData)
.onSuccess(invalidated -> {
if (!invalidated) {
log.warn("Per request is not removed: {}", proxyApiKeyData.getPerRequestKey());
}
}).onFailure(error -> log.error("error occurred on invalidating per-request key", error));
}
}
}
2 changes: 1 addition & 1 deletion src/main/java/com/epam/aidial/core/data/ResourceType.java
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
public enum ResourceType {
FILE("files"), CONVERSATION("conversations"), PROMPT("prompts"), LIMIT("limits"),
SHARED_WITH_ME("shared_with_me"), SHARED_BY_ME("shared_by_me"), INVITATION("invitations"),
PUBLICATION("publications"), RULES("rules");
PUBLICATION("publications"), RULES("rules"), API_KEY_DATA("api_key_data");

private final String group;

Expand Down
144 changes: 121 additions & 23 deletions src/main/java/com/epam/aidial/core/security/ApiKeyStore.java
Original file line number Diff line number Diff line change
Expand Up @@ -2,57 +2,155 @@

import com.epam.aidial.core.config.ApiKeyData;
import com.epam.aidial.core.config.Key;
import com.epam.aidial.core.data.ResourceType;
import com.epam.aidial.core.service.LockService;
import com.epam.aidial.core.service.ResourceService;
import com.epam.aidial.core.storage.ResourceDescription;
import com.epam.aidial.core.util.ProxyUtil;
import io.vertx.core.Future;
import io.vertx.core.Vertx;
import lombok.AllArgsConstructor;
import lombok.extern.slf4j.Slf4j;

import java.util.HashMap;
import java.util.Iterator;
import java.util.Map;

import static com.epam.aidial.core.security.ApiKeyGenerator.generateKey;
import static com.epam.aidial.core.storage.BlobStorageUtil.PATH_SEPARATOR;

/**
* The store keeps per request and project API key data.
* <p>
* Per request key is assigned during the request and terminated in the end of the request.
* Project keys are hosted by external secure storage and might be periodically updated by {@link com.epam.aidial.core.config.FileConfigStore}.
* </p>
*/
@Slf4j
@AllArgsConstructor
public class ApiKeyStore {

public static final String API_KEY_DATA_BUCKET = "api_key_data";
public static final String API_KEY_DATA_LOCATION = API_KEY_DATA_BUCKET + PATH_SEPARATOR;

private final ResourceService resourceService;

private final LockService lockService;

private final Vertx vertx;

/**
* Project API keys are hosted in the secure storage.
*/
private final Map<String, ApiKeyData> keys = new HashMap<>();

public synchronized void assignApiKey(ApiKeyData data) {
String apiKey = generateApiKey();
keys.put(apiKey, data);
data.setPerRequestKey(apiKey);
/**
* Assigns a new generated per request key to the {@link ApiKeyData}.
* <p>
* Note. The method is blocking and shouldn't be run in the event loop thread.
* </p>
*/
public synchronized void assignPerRequestApiKey(ApiKeyData data) {
lockService.underBucketLock(API_KEY_DATA_LOCATION, () -> {
ResourceDescription resource = generateApiKey();
String apiKey = resource.getName();
data.setPerRequestKey(apiKey);
String json = ProxyUtil.convertToString(data);
if (resourceService.putResource(resource, json, false, false) == null) {
throw new IllegalStateException(String.format("API key %s already exists in the storage", apiKey));
}
return apiKey;
});
}

public synchronized ApiKeyData getApiKeyData(String key) {
return keys.get(key);
/**
* Returns API key data for the given key.
*
* @param key API key could be either project or per request key.
* @return the future of data associated with the given key.
*/
public synchronized Future<ApiKeyData> getApiKeyData(String key) {
ApiKeyData apiKeyData = keys.get(key);
if (apiKeyData != null) {
return Future.succeededFuture(apiKeyData);
}
ResourceDescription resource = toResource(key);
return vertx.executeBlocking(() -> ProxyUtil.convertToObject(resourceService.getResource(resource), ApiKeyData.class));
}

public synchronized void invalidateApiKey(ApiKeyData apiKeyData) {
if (apiKeyData.getPerRequestKey() != null) {
keys.remove(apiKeyData.getPerRequestKey());
/**
* Invalidates per request API key.
* If api key belongs to a project the operation will not have affect.
*
* @param apiKeyData associated with the key to be invalidated.
* @return the future of the invalidation result: <code>true</code> means the key is successfully invalidated.
*/
public Future<Boolean> invalidatePerRequestApiKey(ApiKeyData apiKeyData) {
String apiKey = apiKeyData.getPerRequestKey();
if (apiKey != null) {
ResourceDescription resource = toResource(apiKey);
return vertx.executeBlocking(() -> resourceService.deleteResource(resource));
}
return Future.succeededFuture(true);
}

/**
* Adds new project keys from the secure storage and removes previous project keys if any.
* <p>
* Note. The method is blocking and shouldn't be run in the event loop thread.
* </p>
*
* @param projectKeys new projects to be added to the store.
*/
public synchronized void addProjectKeys(Map<String, Key> projectKeys) {
keys.values().removeIf(apiKeyData -> apiKeyData.getPerRequestKey() == null);
for (Map.Entry<String, Key> entry : projectKeys.entrySet()) {
String key = entry.getKey();
Key value = entry.getValue();
if (keys.containsKey(key)) {
key = generateApiKey();
keys.clear();
lockService.underBucketLock(API_KEY_DATA_LOCATION, () -> {
for (Map.Entry<String, Key> entry : projectKeys.entrySet()) {
String apiKey = entry.getKey();
Key value = entry.getValue();
ResourceDescription resource = toResource(apiKey);
if (resourceService.hasResource(resource)) {
resource = generateApiKey();
apiKey = resource.getName();
}
value.setKey(apiKey);
ApiKeyData apiKeyData = new ApiKeyData();
apiKeyData.setOriginalKey(value);
keys.put(apiKey, apiKeyData);
}
value.setKey(key);
ApiKeyData apiKeyData = new ApiKeyData();
apiKeyData.setOriginalKey(value);
keys.put(key, apiKeyData);
return null;
});
}

/**
* Updates data associated with per request key.
* If api key belongs to a project the operation will not have affect.
*
* @param apiKeyData per request key data.
*/
public void updatePerRequestApiKeyData(ApiKeyData apiKeyData) {
String apiKey = apiKeyData.getPerRequestKey();
if (apiKey == null) {
return;
}
String json = ProxyUtil.convertToString(apiKeyData);
ResourceDescription resource = toResource(apiKey);
resourceService.putResource(resource, json, true, false);
}

private String generateApiKey() {
private ResourceDescription generateApiKey() {
String apiKey = generateKey();
while (keys.containsKey(apiKey)) {
ResourceDescription resource = toResource(apiKey);
while (resourceService.hasResource(resource) || keys.containsKey(apiKey)) {
log.warn("duplicate API key is found. Trying to generate a new one");
apiKey = generateKey();
resource = toResource(apiKey);
}
return apiKey;
return resource;
}

private static ResourceDescription toResource(String apiKey) {
return ResourceDescription.fromDecoded(
ResourceType.API_KEY_DATA, API_KEY_DATA_BUCKET, API_KEY_DATA_LOCATION, apiKey);
}

}
Original file line number Diff line number Diff line change
Expand Up @@ -157,6 +157,7 @@ private ResourceFolderMetadata getFolderMetadata(ResourceDescription descriptor,
return new ResourceFolderMetadata(descriptor, resources).setNextToken(set.getNextMarker());
}

@Nullable
public ResourceItemMetadata getResourceMetadata(ResourceDescription descriptor) {
if (descriptor.isFolder()) {
throw new IllegalArgumentException("Resource folder: " + descriptor.getUrl());
Expand Down
Loading

0 comments on commit 93ddf0f

Please sign in to comment.