Skip to content

Commit

Permalink
feat: add week/month settings for tokens rate-limiting
Browse files Browse the repository at this point in the history
  • Loading branch information
alekseyvdovenko committed Dec 16, 2024
1 parent caeba5c commit bd3cfe7
Show file tree
Hide file tree
Showing 8 changed files with 100 additions and 7 deletions.
6 changes: 4 additions & 2 deletions config/src/main/java/com/epam/aidial/core/config/Limit.java
Original file line number Diff line number Diff line change
Expand Up @@ -6,10 +6,12 @@
public class Limit {
private long minute = Long.MAX_VALUE;
private long day = Long.MAX_VALUE;
private long week = Long.MAX_VALUE;
private long month = Long.MAX_VALUE;
private long requestHour = Long.MAX_VALUE;
private long requestDay = Long.MAX_VALUE;

public boolean isPositive() {
return minute > 0 && day > 0 && requestDay > 0 && requestHour > 0;
return minute > 0 && day > 0 && week > 0 && month > 0 && requestDay > 0 && requestHour > 0;
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,8 @@
public class LimitStats {
private ItemLimitStats minuteTokenStats;
private ItemLimitStats dayTokenStats;
private ItemLimitStats weekTokenStats;
private ItemLimitStats monthTokenStats;
private ItemLimitStats hourRequestStats;
private ItemLimitStats dayRequestStats;
}
Original file line number Diff line number Diff line change
Expand Up @@ -141,6 +141,14 @@ private LimitStats create(Limit limit) {
dayRequestStats.setTotal(limit.getRequestDay());
limitStats.setDayRequestStats(dayRequestStats);

ItemLimitStats weekTokenStats = new ItemLimitStats();
weekTokenStats.setTotal(limit.getWeek());
limitStats.setWeekTokenStats(weekTokenStats);

ItemLimitStats monthTokenStats = new ItemLimitStats();
monthTokenStats.setTotal(limit.getMonth());
limitStats.setMonthTokenStats(monthTokenStats);

return limitStats;
}

Expand Down Expand Up @@ -230,11 +238,15 @@ private Limit getLimitByUser(ProxyContext context, RoleBasedEntity roleBasedEnti
limit.setRequestHour(candidate.getRequestHour());
limit.setRequestDay(candidate.getRequestDay());
limit.setDay(candidate.getDay());
limit.setWeek(candidate.getWeek());
limit.setMonth(candidate.getMonth());
} else {
limit.setMinute(Math.max(candidate.getMinute(), limit.getMinute()));
limit.setDay(Math.max(candidate.getDay(), limit.getDay()));
limit.setRequestDay(Math.max(candidate.getRequestDay(), limit.getRequestDay()));
limit.setRequestHour(Math.max(candidate.getRequestHour(), limit.getRequestHour()));
limit.setWeek(Math.max(candidate.getWeek(), limit.getWeek()));
limit.setMonth(Math.max(candidate.getMonth(), limit.getMonth()));
}
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,9 @@
public enum RateWindow {
MINUTE(60L * 1000, 60),
HOUR(60 * 60 * 1000, 60),
DAY(24L * 60 * 60 * 1000, 24);
DAY(24L * 60 * 60 * 1000, 24),
WEEK(7L * 24 * 60 * 60 * 1000, 7),
MONTH(30L * 24 * 60 * 60 * 1000, 30);

private final long window;
private final long interval;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -10,20 +10,29 @@ public class TokenRateLimit {

private final RateBucket minute = new RateBucket(RateWindow.MINUTE);
private final RateBucket day = new RateBucket(RateWindow.DAY);
private final RateBucket week = new RateBucket(RateWindow.WEEK);
private final RateBucket month = new RateBucket(RateWindow.MONTH);

public void add(long timestamp, long count) {
minute.add(timestamp, count);
day.add(timestamp, count);
week.add(timestamp, count);
month.add(timestamp, count);
}

public RateLimitResult update(long timestamp, Limit limit) {
long minuteTotal = minute.update(timestamp);
long dayTotal = day.update(timestamp);
long weekTotal = week.update(timestamp);
long monthTotal = month.update(timestamp);

boolean result = minuteTotal >= limit.getMinute() || dayTotal >= limit.getDay();
boolean result = minuteTotal >= limit.getMinute() || dayTotal >= limit.getDay()
|| weekTotal >= limit.getWeek() || monthTotal >= limit.getMonth();
if (result) {
String errorMsg = String.format("Hit token rate limit. Minute limit: %d / %d tokens. Day limit: %d / %d tokens.",
minuteTotal, limit.getMinute(), dayTotal, limit.getDay());

String errorMsg = String.format(
"Hit token rate limit. Minute limit: %d / %d tokens. Day limit: %d / %d tokens. Week limit: %d / %d tokens. Month limit: %d / %d tokens.",
minuteTotal, limit.getMinute(), dayTotal, limit.getDay(), weekTotal, limit.getWeek(), monthTotal, limit.getMonth());
long minuteRetryAfter = minute.retryAfter(limit.getMinute());
long dayRetryAfter = day.retryAfter(limit.getDay());
long retryAfter = Math.max(minuteRetryAfter, dayRetryAfter);
Expand All @@ -36,7 +45,11 @@ public RateLimitResult update(long timestamp, Limit limit) {
public void update(long timestamp, LimitStats limitStats) {
long minuteTotal = minute.update(timestamp);
long dayTotal = day.update(timestamp);
long weekTotal = week.update(timestamp);
long monthTotal = month.update(timestamp);
limitStats.getDayTokenStats().setUsed(dayTotal);
limitStats.getMinuteTokenStats().setUsed(minuteTotal);
limitStats.getWeekTokenStats().setUsed(weekTotal);
limitStats.getMonthTokenStats().setUsed(monthTotal);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,14 @@ public void testGetLimitStats_Success() {
"total": %d,
"used": %d
},
"weekTokenStats": {
"total": %d,
"used": %d
},
"monthTokenStats": {
"total": %d,
"used": %d
},
"hourRequestStats": {
"total": %d,
"used": %d
Expand All @@ -27,7 +35,7 @@ public void testGetLimitStats_Success() {
"used": %d
}
}
""".formatted(Long.MAX_VALUE, 0, Long.MAX_VALUE, 0, Long.MAX_VALUE, 0, Long.MAX_VALUE, 0));
""".formatted(Long.MAX_VALUE, 0, Long.MAX_VALUE, 0, Long.MAX_VALUE, 0, Long.MAX_VALUE, 0, Long.MAX_VALUE, 0, Long.MAX_VALUE, 0));
}

@Test
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,7 @@ void testDayBucket() {
}

@Test

public void testRetryAfterMinute() {
bucket = new RateBucket(RateWindow.MINUTE);

Expand All @@ -77,6 +78,49 @@ public void testRetryAfterMinute() {
assertEquals(15, bucket.retryAfter(30));
}

void testWeekBucket() {
bucket = new RateBucket(RateWindow.WEEK);

update(0, 0);
add(0, 10, 10);
add(0, 20, 30);
update(0, 30);

add(1, 30, 60);
add(6, 40, 100);
update(6, 100);

add(7, 10, 80);
update(7, 80);

add(8, 5, 55);
update(8, 55);

update(15, 0);
}

@Test
void testMonthBucket() {
bucket = new RateBucket(RateWindow.MONTH);

update(0, 0);
add(0, 10, 10);
add(0, 20, 30);
update(0, 30);

add(1, 30, 60);
add(29, 40, 100);
update(29, 100);

add(30, 10, 80);
update(30, 80);

add(31, 5, 55);
update(31, 55);

update(61, 0);
}

@Test
public void testRetryAfterDay() {
bucket = new RateBucket(RateWindow.DAY);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -231,6 +231,8 @@ public void testGetLimitStats_ApiKey() {
limit.setMinute(100);
limit.setRequestDay(10);
limit.setRequestHour(2);
limit.setWeek(1000000);
limit.setMonth(10000000);
role.setLimits(Map.of("model", limit));
config.setRoles(Map.of("role", role));
ApiKeyData apiKeyData = new ApiKeyData();
Expand Down Expand Up @@ -271,6 +273,10 @@ public void testGetLimitStats_ApiKey() {
assertEquals(1, limitStats.getDayRequestStats().getUsed());
assertEquals(2, limitStats.getHourRequestStats().getTotal());
assertEquals(1, limitStats.getHourRequestStats().getUsed());
assertEquals(1000000, limitStats.getWeekTokenStats().getTotal());
assertEquals(90, limitStats.getWeekTokenStats().getUsed());
assertEquals(10000000, limitStats.getMonthTokenStats().getTotal());
assertEquals(90, limitStats.getMonthTokenStats().getUsed());

increaseLimitFuture = rateLimiter.increase(proxyContext, model);
assertNotNull(increaseLimitFuture);
Expand All @@ -285,6 +291,10 @@ public void testGetLimitStats_ApiKey() {
assertEquals(180, limitStats.getDayTokenStats().getUsed());
assertEquals(100, limitStats.getMinuteTokenStats().getTotal());
assertEquals(180, limitStats.getMinuteTokenStats().getUsed());
assertEquals(1000000, limitStats.getWeekTokenStats().getTotal());
assertEquals(180, limitStats.getWeekTokenStats().getUsed());
assertEquals(10000000, limitStats.getMonthTokenStats().getTotal());
assertEquals(180, limitStats.getMonthTokenStats().getUsed());

}

Expand Down

0 comments on commit bd3cfe7

Please sign in to comment.