Skip to content

Commit

Permalink
Merge remote-tracking branch 'origin/development' into cleanup-config
Browse files Browse the repository at this point in the history
  • Loading branch information
artsiomkorzun committed Feb 1, 2024
2 parents dddaee9 + 79ed1b2 commit 8b8df83
Show file tree
Hide file tree
Showing 10 changed files with 324 additions and 88 deletions.
3 changes: 2 additions & 1 deletion src/main/java/com/epam/aidial/core/AiDial.java
Original file line number Diff line number Diff line change
Expand Up @@ -82,7 +82,6 @@ void start() throws Exception {
ApiKeyStore apiKeyStore = new ApiKeyStore();
ConfigStore configStore = new FileConfigStore(vertx, settings("config"), apiKeyStore);
LogStore logStore = new GfLogStore(vertx);
RateLimiter rateLimiter = new RateLimiter();
UpstreamBalancer upstreamBalancer = new UpstreamBalancer();
AccessTokenValidator accessTokenValidator = new AccessTokenValidator(settings("identityProviders"), vertx);
if (storage == null) {
Expand All @@ -99,6 +98,8 @@ void start() throws Exception {
resourceService = new ResourceService(vertx, redis, storage, lockService, settings("resources"));
}

RateLimiter rateLimiter = new RateLimiter(vertx, resourceService);

proxy = new Proxy(vertx, client, configStore, logStore,
rateLimiter, upstreamBalancer, accessTokenValidator,
storage, encryptionService, apiKeyStore, tokenStatsTracker, resourceService);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -85,22 +85,29 @@ public Future<?> handle(String deploymentId, String deploymentApi) {

context.setDeployment(deployment);

RateLimitResult rateLimitResult;
if (deployment instanceof Model && (rateLimitResult = proxy.getRateLimiter().limit(context)).status() != HttpStatus.OK) {
// Returning an error similar to the Azure format.
ErrorData rateLimitError = new ErrorData();
rateLimitError.getError().setCode(String.valueOf(rateLimitResult.status().getCode()));
rateLimitError.getError().setMessage(rateLimitResult.errorMessage());
log.error("Rate limit error {}. Key: {}. User sub: {}", rateLimitResult.errorMessage(), context.getProject(), context.getUserSub());
return respond(rateLimitResult.status(), rateLimitError);
Future<RateLimitResult> rateLimitResultFuture;
if (deployment instanceof Model) {
rateLimitResultFuture = proxy.getRateLimiter().limit(context);
} else {
rateLimitResultFuture = Future.succeededFuture(RateLimitResult.SUCCESS);
}

return rateLimitResultFuture.onSuccess(result -> {
if (result.status() == HttpStatus.OK) {
handleRateLimitSuccess(deploymentId);
} else {
handleRateLimitHit(result);
}
}).onFailure(this::handleRateLimitFailure);
}

private void handleRateLimitSuccess(String deploymentId) {
log.info("Received request from client. Trace: {}. Span: {}. Key: {}. Deployment: {}. Headers: {}",
context.getTraceId(), context.getSpanId(),
context.getProject(), context.getDeployment().getName(),
context.getRequest().headers().size());

UpstreamProvider endpointProvider = new DeploymentUpstreamProvider(deployment);
UpstreamProvider endpointProvider = new DeploymentUpstreamProvider(context.getDeployment());
UpstreamRoute endpointRoute = proxy.getUpstreamBalancer().balance(endpointProvider);
context.setUpstreamRoute(endpointRoute);

Expand All @@ -109,7 +116,8 @@ public Future<?> handle(String deploymentId, String deploymentApi) {
context.getTraceId(), context.getSpanId(),
context.getProject(), deploymentId, context.getUserSub());

return respond(HttpStatus.BAD_GATEWAY, "No route");
respond(HttpStatus.BAD_GATEWAY, "No route");
return;
}

ApiKeyData proxyApiKeyData = new ApiKeyData();
Expand All @@ -119,11 +127,27 @@ public Future<?> handle(String deploymentId, String deploymentApi) {

proxy.getTokenStatsTracker().startSpan(context);

return context.getRequest().body()
context.getRequest().body()
.onSuccess(this::handleRequestBody)
.onFailure(this::handleRequestBodyError);
}

/**
 * Rejects the current request because the rate limiter returned a non-OK result.
 *
 * <p>Fills an {@link ErrorData} body with the limiter's status code and error message,
 * logs the rejection together with the project key, user sub and trace/span ids taken
 * from the request context, and responds with the limiter's status.
 *
 * @param result the rate limiter verdict to report to the client
 */
private void handleRateLimitHit(RateLimitResult result) {
    String message = result.errorMessage();

    // Returning an error similar to the Azure format.
    ErrorData body = new ErrorData();
    body.getError().setCode(String.valueOf(result.status().getCode()));
    body.getError().setMessage(message);

    log.error("Rate limit error {}. Key: {}. User sub: {}. Trace: {}. Span: {}", message,
            context.getProject(), context.getUserSub(), context.getTraceId(), context.getSpanId());

    respond(result.status(), body);
}

/**
 * Handles a failure of the rate-limit check itself (as opposed to a limit being hit).
 *
 * <p>Logs a warning with the project key, user sub and trace/span ids from the request
 * context and responds with {@code HttpStatus.INTERNAL_SERVER_ERROR}.
 *
 * @param error the failure reported by the rate-limit future
 */
private void handleRateLimitFailure(Throwable error) {
    // The throwable is passed as one extra trailing argument, after the values for every
    // placeholder, so that the logger can record the stack trace and not only the message.
    // NOTE(review): relies on SLF4J-style handling of a trailing Throwable - confirm the logger in use.
    log.warn("Failed to check limits. Key: {}. User sub: {}. Trace: {}. Span: {}. Error: {}",
            context.getProject(), context.getUserSub(), context.getTraceId(), context.getSpanId(),
            error.getMessage(), error);
    respond(HttpStatus.INTERNAL_SERVER_ERROR, "Failed to check limits");
}

@SneakyThrows
private Future<?> sendRequest() {
UpstreamRoute route = context.getUpstreamRoute();
Expand Down Expand Up @@ -312,15 +336,18 @@ void handleResponse() {
tokenUsage = new TokenUsage();
}
context.setTokenUsage(tokenUsage);
proxy.getRateLimiter().increase(context);
proxy.getRateLimiter().increase(context).onFailure(error -> {
log.warn("Failed to increase limit. Trace: {}. Span: {}",
context.getTraceId(), context.getSpanId(), error);
});
tokenUsageFuture = Future.succeededFuture(tokenUsage);
try {
BigDecimal cost = ModelCostCalculator.calculate(context);
tokenUsage.setCost(cost);
tokenUsage.setAggCost(cost);
} catch (Throwable e) {
log.warn("Failed to calculate cost for model={}. Trace: {}. Span: {}",
context.getDeployment().getName(), context.getTraceId(), context.getSpanId());
context.getDeployment().getName(), context.getTraceId(), context.getSpanId(), e);
}
}
} else {
Expand Down
2 changes: 1 addition & 1 deletion src/main/java/com/epam/aidial/core/data/ResourceType.java
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@

@Getter
public enum ResourceType {
FILE("files"), CONVERSATION("conversations"), PROMPT("prompts");
FILE("files"), CONVERSATION("conversations"), PROMPT("prompts"), LIMIT("limits");

private final String group;

Expand Down
12 changes: 6 additions & 6 deletions src/main/java/com/epam/aidial/core/limiter/RateBucket.java
Original file line number Diff line number Diff line change
@@ -1,14 +1,14 @@
package com.epam.aidial.core.limiter;

import lombok.Getter;
import lombok.experimental.Accessors;
import lombok.Data;
import lombok.NoArgsConstructor;

@Data
@NoArgsConstructor
public class RateBucket {

@Getter
@Accessors(fluent = true)
private final RateWindow window;
private final long[] sums;
private RateWindow window;
private long[] sums;
private long sum;
private long start = Long.MIN_VALUE;
private long end = Long.MIN_VALUE;
Expand Down
6 changes: 4 additions & 2 deletions src/main/java/com/epam/aidial/core/limiter/RateLimit.java
Original file line number Diff line number Diff line change
Expand Up @@ -2,18 +2,20 @@

import com.epam.aidial.core.config.Limit;
import com.epam.aidial.core.util.HttpStatus;
import lombok.Data;

@Data
public class RateLimit {

private final RateBucket minute = new RateBucket(RateWindow.MINUTE);
private final RateBucket day = new RateBucket(RateWindow.DAY);

public synchronized void add(long timestamp, long count) {
public void add(long timestamp, long count) {
minute.add(timestamp, count);
day.add(timestamp, count);
}

public synchronized RateLimitResult update(long timestamp, Limit limit) {
public RateLimitResult update(long timestamp, Limit limit) {
long minuteTotal = minute.update(timestamp);
long dayTotal = day.update(timestamp);

Expand Down
134 changes: 91 additions & 43 deletions src/main/java/com/epam/aidial/core/limiter/RateLimiter.java
Original file line number Diff line number Diff line change
Expand Up @@ -5,65 +5,116 @@
import com.epam.aidial.core.config.Key;
import com.epam.aidial.core.config.Limit;
import com.epam.aidial.core.config.Role;
import com.epam.aidial.core.data.ResourceType;
import com.epam.aidial.core.service.ResourceService;
import com.epam.aidial.core.storage.BlobStorageUtil;
import com.epam.aidial.core.storage.ResourceDescription;
import com.epam.aidial.core.token.TokenUsage;
import com.epam.aidial.core.util.HttpStatus;
import com.epam.aidial.core.util.ProxyUtil;
import io.vertx.core.Future;
import io.vertx.core.Vertx;
import lombok.RequiredArgsConstructor;
import lombok.SneakyThrows;
import lombok.extern.slf4j.Slf4j;

import java.util.concurrent.ConcurrentHashMap;

@Slf4j
@RequiredArgsConstructor
public class RateLimiter {
private final ConcurrentHashMap<Id, RateLimit> rates = new ConcurrentHashMap<>();

public void increase(ProxyContext context) {
Key key = context.getKey();
if (key == null) {
return;
}
Deployment deployment = context.getDeployment();
TokenUsage usage = context.getTokenUsage();
private final Vertx vertx;

if (usage == null || usage.getTotalTokens() <= 0) {
return;
}
private final ResourceService resourceService;

Id id = new Id(key.getKey(), deployment.getName());
RateLimit rate = rates.computeIfAbsent(id, k -> new RateLimit());
public Future<Void> increase(ProxyContext context) {
try {
// skip checking limits if redis is not available
if (resourceService == null) {
return Future.succeededFuture();
}
Key key = context.getKey();
if (key == null) {
return Future.succeededFuture();
}
Deployment deployment = context.getDeployment();
TokenUsage usage = context.getTokenUsage();

long timestamp = System.currentTimeMillis();
rate.add(timestamp, usage.getTotalTokens());
}
if (usage == null || usage.getTotalTokens() <= 0) {
return Future.succeededFuture();
}

public RateLimitResult limit(ProxyContext context) {
Key key = context.getKey();
Limit limit;
if (key == null) {
// don't support user limits yet
return RateLimitResult.SUCCESS;
} else {
limit = getLimitByApiKey(context);
String path = getPath(deployment.getName());
return vertx.executeBlocking(() -> updateLimit(path, context, usage.getTotalTokens()));
} catch (Throwable e) {
return Future.failedFuture(e);
}
}

Deployment deployment = context.getDeployment();

if (limit == null || !limit.isPositive()) {
if (limit == null) {
log.warn("Limit is not found for deployment: {}", deployment.getName());
public Future<RateLimitResult> limit(ProxyContext context) {
try {
// skip checking limits if redis is not available
if (resourceService == null) {
return Future.succeededFuture(RateLimitResult.SUCCESS);
}
Key key = context.getKey();
Limit limit;
if (key == null) {
// don't support user limits yet
return Future.succeededFuture(RateLimitResult.SUCCESS);
} else {
log.warn("Limit must be positive for deployment: {}", deployment.getName());
limit = getLimitByApiKey(context);
}
return new RateLimitResult(HttpStatus.FORBIDDEN, "Access denied");
}

Id id = new Id(key.getKey(), deployment.getName());
RateLimit rate = rates.get(id);
Deployment deployment = context.getDeployment();

if (limit == null || !limit.isPositive()) {
if (limit == null) {
log.warn("Limit is not found for deployment: {}", deployment.getName());
} else {
log.warn("Limit must be positive for deployment: {}", deployment.getName());
}
return Future.succeededFuture(new RateLimitResult(HttpStatus.FORBIDDEN, "Access denied"));
}

if (rate == null) {
String path = getPath(deployment.getName());
return vertx.executeBlocking(() -> checkLimit(path, limit, context));
} catch (Throwable e) {
return Future.failedFuture(e);
}
}

/**
 * Loads the stored rate state for the given path and evaluates it against the limit.
 *
 * <p>The resource is looked up under {@link ResourceType#LIMIT} in the bucket built from
 * the request context. When nothing is stored at that location the request is allowed.
 * Otherwise the stored JSON is deserialized into a {@link RateLimit} and checked at the
 * current wall-clock time.
 *
 * @param path    resource path of the limit state (see {@code getPath})
 * @param limit   limit to evaluate the stored usage against
 * @param context request context used to resolve the bucket
 * @return {@link RateLimitResult#SUCCESS} when no state is stored, otherwise the result of
 *         {@code RateLimit.update}
 * @throws Exception if the stored value cannot be read or deserialized
 */
private RateLimitResult checkLimit(String path, Limit limit, ProxyContext context) throws Exception {
    String bucket = BlobStorageUtil.buildUserBucket(context);
    ResourceDescription description = ResourceDescription.fromEncoded(ResourceType.LIMIT, bucket, bucket, path);

    String storedJson = resourceService.getResource(description);
    if (storedJson == null) {
        return RateLimitResult.SUCCESS;
    }

    RateLimit storedLimit = ProxyUtil.MAPPER.readValue(storedJson, RateLimit.class);
    return storedLimit.update(System.currentTimeMillis(), limit);
}

/**
 * Adds the consumed tokens to the rate state stored for the given path.
 *
 * <p>Delegates the read-modify-write to {@code resourceService.computeResource}, passing
 * the JSON-level {@code updateLimit(String, long)} overload as the transformation.
 *
 * @param path            resource path of the limit state (see {@code getPath})
 * @param context         request context used to resolve the bucket
 * @param totalUsedTokens number of tokens to add to the stored usage
 * @return always {@code null}; the {@code Void} return type lets the method be used as a
 *         value-returning task
 */
private Void updateLimit(String path, ProxyContext context, long totalUsedTokens) {
    String bucket = BlobStorageUtil.buildUserBucket(context);
    ResourceDescription description = ResourceDescription.fromEncoded(ResourceType.LIMIT, bucket, bucket, path);

    resourceService.computeResource(description, storedJson -> updateLimit(storedJson, totalUsedTokens));
    return null;
}

@SneakyThrows
private String updateLimit(String json, long totalUsedTokens) {
RateLimit rateLimit;
if (json == null) {
rateLimit = new RateLimit();
} else {
rateLimit = ProxyUtil.MAPPER.readValue(json, RateLimit.class);
}
long timestamp = System.currentTimeMillis();
return rate.update(timestamp, limit);
rateLimit.add(timestamp, totalUsedTokens);
return ProxyUtil.MAPPER.writeValueAsString(rateLimit);
}

private Limit getLimitByApiKey(ProxyContext context) {
Expand All @@ -79,11 +130,8 @@ private Limit getLimitByApiKey(ProxyContext context) {
return role.getLimits().get(deployment.getName());
}

private record Id(String key, String resource) {
@Override
public String toString() {
return String.format("key: %s, resource: %s", key, resource);
}
/**
 * Builds the resource path under which token usage for a deployment is stored.
 *
 * @param deploymentName name of the deployment
 * @return the deployment name followed by {@code "/tokens"}
 */
private static String getPath(String deploymentName) {
    return deploymentName + "/tokens";
}

}
Loading

0 comments on commit 8b8df83

Please sign in to comment.