From bd3cfe7f94d8c651ab495feaa6f2c610f4b2d2ac Mon Sep 17 00:00:00 2001
From: alekseyvdovenko
Date: Thu, 12 Dec 2024 17:08:24 +0100
Subject: [PATCH] feat: add week/month settings for tokens rate-limiting

Add `week` and `month` token limits next to the existing `minute` and
`day` ones:

- Limit: new `week` and `month` fields (default Long.MAX_VALUE), both
  checked by isPositive().
- LimitStats / RateLimiter: report week and month token stats; when a
  user has several roles, getLimitByUser() keeps the largest week/month
  value, as it already does for the other limits.
- RateWindow: WEEK is a fixed 7-day window and MONTH a fixed 30-day
  window (not a calendar month).
- TokenRateLimit: track and enforce both windows, include them in the
  error message and take them into account when computing retryAfter.
- Tests: RateBucketTest covers the WEEK and MONTH buckets, LimitApiTest
  and RateLimiterTest cover the new stats.
---
 .../com/epam/aidial/core/config/Limit.java    |  6 ++-
 .../aidial/core/server/data/LimitStats.java   |  2 +
 .../core/server/limiter/RateLimiter.java      | 12 +++++
 .../core/server/limiter/RateWindow.java       |  4 +-
 .../core/server/limiter/TokenRateLimit.java   | 24 +++++++++-
 .../epam/aidial/core/server/LimitApiTest.java | 10 ++++-
 .../core/server/limiter/RateBucketTest.java   | 44 +++++++++++++++++++
 .../core/server/limiter/RateLimiterTest.java  | 10 +++++
 8 files changed, 105 insertions(+), 7 deletions(-)

diff --git a/config/src/main/java/com/epam/aidial/core/config/Limit.java b/config/src/main/java/com/epam/aidial/core/config/Limit.java
index d930d128a..22626c5b4 100644
--- a/config/src/main/java/com/epam/aidial/core/config/Limit.java
+++ b/config/src/main/java/com/epam/aidial/core/config/Limit.java
@@ -6,10 +6,12 @@
 public class Limit {
     private long minute = Long.MAX_VALUE;
     private long day = Long.MAX_VALUE;
+    private long week = Long.MAX_VALUE;
+    private long month = Long.MAX_VALUE;
     private long requestHour = Long.MAX_VALUE;
     private long requestDay = Long.MAX_VALUE;
 
     public boolean isPositive() {
-        return minute > 0 && day > 0 && requestDay > 0 && requestHour > 0;
+        return minute > 0 && day > 0 && week > 0 && month > 0 && requestDay > 0 && requestHour > 0;
     }
-}
\ No newline at end of file
+}
diff --git a/server/src/main/java/com/epam/aidial/core/server/data/LimitStats.java b/server/src/main/java/com/epam/aidial/core/server/data/LimitStats.java
index a23bae467..b545b043b 100644
--- a/server/src/main/java/com/epam/aidial/core/server/data/LimitStats.java
+++ b/server/src/main/java/com/epam/aidial/core/server/data/LimitStats.java
@@ -6,6 +6,8 @@
 public class LimitStats {
     private ItemLimitStats minuteTokenStats;
     private ItemLimitStats dayTokenStats;
+    private ItemLimitStats weekTokenStats;
+    private ItemLimitStats monthTokenStats;
     private ItemLimitStats hourRequestStats;
     private ItemLimitStats dayRequestStats;
 }
diff --git a/server/src/main/java/com/epam/aidial/core/server/limiter/RateLimiter.java b/server/src/main/java/com/epam/aidial/core/server/limiter/RateLimiter.java
index aa6145d00..80e0a49d1 100644
--- a/server/src/main/java/com/epam/aidial/core/server/limiter/RateLimiter.java
+++ b/server/src/main/java/com/epam/aidial/core/server/limiter/RateLimiter.java
@@ -141,6 +141,14 @@ private LimitStats create(Limit limit) {
         dayRequestStats.setTotal(limit.getRequestDay());
         limitStats.setDayRequestStats(dayRequestStats);
 
+        ItemLimitStats weekTokenStats = new ItemLimitStats();
+        weekTokenStats.setTotal(limit.getWeek());
+        limitStats.setWeekTokenStats(weekTokenStats);
+
+        ItemLimitStats monthTokenStats = new ItemLimitStats();
+        monthTokenStats.setTotal(limit.getMonth());
+        limitStats.setMonthTokenStats(monthTokenStats);
+
         return limitStats;
     }
 
@@ -230,11 +238,15 @@ private Limit getLimitByUser(ProxyContext context, RoleBasedEntity roleBasedEnti
                     limit.setRequestHour(candidate.getRequestHour());
                     limit.setRequestDay(candidate.getRequestDay());
                     limit.setDay(candidate.getDay());
+                    limit.setWeek(candidate.getWeek());
+                    limit.setMonth(candidate.getMonth());
                 } else {
                     limit.setMinute(Math.max(candidate.getMinute(), limit.getMinute()));
                     limit.setDay(Math.max(candidate.getDay(), limit.getDay()));
                     limit.setRequestDay(Math.max(candidate.getRequestDay(), limit.getRequestDay()));
                     limit.setRequestHour(Math.max(candidate.getRequestHour(), limit.getRequestHour()));
+                    limit.setWeek(Math.max(candidate.getWeek(), limit.getWeek()));
+                    limit.setMonth(Math.max(candidate.getMonth(), limit.getMonth()));
                 }
             }
         }
diff --git a/server/src/main/java/com/epam/aidial/core/server/limiter/RateWindow.java b/server/src/main/java/com/epam/aidial/core/server/limiter/RateWindow.java
index 95556f8a1..95927a05c 100644
--- a/server/src/main/java/com/epam/aidial/core/server/limiter/RateWindow.java
+++ b/server/src/main/java/com/epam/aidial/core/server/limiter/RateWindow.java
@@ -8,7 +8,9 @@
 public enum RateWindow {
     MINUTE(60L * 1000, 60),
     HOUR(60 * 60 * 1000, 60),
-    DAY(24L * 60 * 60 * 1000, 24);
+    DAY(24L * 60 * 60 * 1000, 24),
+    WEEK(7L * 24 * 60 * 60 * 1000, 7),
+    MONTH(30L * 24 * 60 * 60 * 1000, 30);
 
     private final long window;
     private final long interval;
diff --git a/server/src/main/java/com/epam/aidial/core/server/limiter/TokenRateLimit.java b/server/src/main/java/com/epam/aidial/core/server/limiter/TokenRateLimit.java
index 688ba7664..4d3ae8d59 100644
--- a/server/src/main/java/com/epam/aidial/core/server/limiter/TokenRateLimit.java
+++ b/server/src/main/java/com/epam/aidial/core/server/limiter/TokenRateLimit.java
@@ -10,20 +10,34 @@
 public class TokenRateLimit {
     private final RateBucket minute = new RateBucket(RateWindow.MINUTE);
     private final RateBucket day = new RateBucket(RateWindow.DAY);
+    private final RateBucket week = new RateBucket(RateWindow.WEEK);
+    private final RateBucket month = new RateBucket(RateWindow.MONTH);
 
     public void add(long timestamp, long count) {
         minute.add(timestamp, count);
         day.add(timestamp, count);
+        week.add(timestamp, count);
+        month.add(timestamp, count);
     }
 
     public RateLimitResult update(long timestamp, Limit limit) {
         long minuteTotal = minute.update(timestamp);
         long dayTotal = day.update(timestamp);
+        long weekTotal = week.update(timestamp);
+        long monthTotal = month.update(timestamp);
 
-        boolean result = minuteTotal >= limit.getMinute() || dayTotal >= limit.getDay();
+        boolean result = minuteTotal >= limit.getMinute() || dayTotal >= limit.getDay()
+                || weekTotal >= limit.getWeek() || monthTotal >= limit.getMonth();
         if (result) {
-            String errorMsg = String.format("Hit token rate limit. Minute limit: %d / %d tokens. Day limit: %d / %d tokens.",
-                    minuteTotal, limit.getMinute(), dayTotal, limit.getDay());
+
+            String errorMsg = String.format(
+                    "Hit token rate limit. Minute limit: %d / %d tokens. Day limit: %d / %d tokens. Week limit: %d / %d tokens. Month limit: %d / %d tokens.",
+                    minuteTotal, limit.getMinute(), dayTotal, limit.getDay(), weekTotal, limit.getWeek(), monthTotal, limit.getMonth());
             long minuteRetryAfter = minute.retryAfter(limit.getMinute());
             long dayRetryAfter = day.retryAfter(limit.getDay());
+            // The week or month window may be the one that tripped the limit, so the retry hint
+            // has to cover them too; fold them into the longer-window value used below.
+            long weekRetryAfter = week.retryAfter(limit.getWeek());
+            long monthRetryAfter = month.retryAfter(limit.getMonth());
+            dayRetryAfter = Math.max(dayRetryAfter, Math.max(weekRetryAfter, monthRetryAfter));
             long retryAfter = Math.max(minuteRetryAfter, dayRetryAfter);
@@ -36,7 +50,11 @@ public RateLimitResult update(long timestamp, Limit limit) {
     public void update(long timestamp, LimitStats limitStats) {
         long minuteTotal = minute.update(timestamp);
         long dayTotal = day.update(timestamp);
+        long weekTotal = week.update(timestamp);
+        long monthTotal = month.update(timestamp);
         limitStats.getDayTokenStats().setUsed(dayTotal);
         limitStats.getMinuteTokenStats().setUsed(minuteTotal);
+        limitStats.getWeekTokenStats().setUsed(weekTotal);
+        limitStats.getMonthTokenStats().setUsed(monthTotal);
     }
 }
diff --git a/server/src/test/java/com/epam/aidial/core/server/LimitApiTest.java b/server/src/test/java/com/epam/aidial/core/server/LimitApiTest.java
index 8364ec076..5403fc661 100644
--- a/server/src/test/java/com/epam/aidial/core/server/LimitApiTest.java
+++ b/server/src/test/java/com/epam/aidial/core/server/LimitApiTest.java
@@ -18,6 +18,14 @@ public void testGetLimitStats_Success() {
                     "total": %d,
                     "used": %d
                   },
+                  "weekTokenStats": {
+                    "total": %d,
+                    "used": %d
+                  },
+                  "monthTokenStats": {
+                    "total": %d,
+                    "used": %d
+                  },
                   "hourRequestStats": {
                     "total": %d,
                     "used": %d
@@ -27,7 +35,7 @@ public void testGetLimitStats_Success() {
                     "used": %d
                   }
                 }
-                """.formatted(Long.MAX_VALUE, 0, Long.MAX_VALUE, 0, Long.MAX_VALUE, 0, Long.MAX_VALUE, 0));
+                """.formatted(Long.MAX_VALUE, 0, Long.MAX_VALUE, 0, Long.MAX_VALUE, 0, Long.MAX_VALUE, 0, Long.MAX_VALUE, 0, Long.MAX_VALUE, 0));
     }
 
     @Test
diff --git a/server/src/test/java/com/epam/aidial/core/server/limiter/RateBucketTest.java b/server/src/test/java/com/epam/aidial/core/server/limiter/RateBucketTest.java
index b4b21c4ff..c10304313 100644
--- a/server/src/test/java/com/epam/aidial/core/server/limiter/RateBucketTest.java
+++ b/server/src/test/java/com/epam/aidial/core/server/limiter/RateBucketTest.java
@@ -77,6 +77,50 @@ public void testRetryAfterMinute() {
         assertEquals(15, bucket.retryAfter(30));
     }
 
+    @Test
+    void testWeekBucket() {
+        bucket = new RateBucket(RateWindow.WEEK);
+
+        update(0, 0);
+        add(0, 10, 10);
+        add(0, 20, 30);
+        update(0, 30);
+
+        add(1, 30, 60);
+        add(6, 40, 100);
+        update(6, 100);
+
+        add(7, 10, 80);
+        update(7, 80);
+
+        add(8, 5, 55);
+        update(8, 55);
+
+        update(15, 0);
+    }
+
+    @Test
+    void testMonthBucket() {
+        bucket = new RateBucket(RateWindow.MONTH);
+
+        update(0, 0);
+        add(0, 10, 10);
+        add(0, 20, 30);
+        update(0, 30);
+
+        add(1, 30, 60);
+        add(29, 40, 100);
+        update(29, 100);
+
+        add(30, 10, 80);
+        update(30, 80);
+
+        add(31, 5, 55);
+        update(31, 55);
+
+        update(61, 0);
+    }
+
     @Test
     public void testRetryAfterDay() {
         bucket = new RateBucket(RateWindow.DAY);
diff --git a/server/src/test/java/com/epam/aidial/core/server/limiter/RateLimiterTest.java b/server/src/test/java/com/epam/aidial/core/server/limiter/RateLimiterTest.java
index f6f4cce90..85ce887c1 100644
--- a/server/src/test/java/com/epam/aidial/core/server/limiter/RateLimiterTest.java
+++ b/server/src/test/java/com/epam/aidial/core/server/limiter/RateLimiterTest.java
@@ -231,6 +231,8 @@ public void testGetLimitStats_ApiKey() {
         limit.setMinute(100);
         limit.setRequestDay(10);
         limit.setRequestHour(2);
+        limit.setWeek(1000000);
+        limit.setMonth(10000000);
         role.setLimits(Map.of("model", limit));
         config.setRoles(Map.of("role", role));
         ApiKeyData apiKeyData = new ApiKeyData();
@@ -271,6 +273,10 @@ public void testGetLimitStats_ApiKey() {
         assertEquals(1, limitStats.getDayRequestStats().getUsed());
         assertEquals(2, limitStats.getHourRequestStats().getTotal());
         assertEquals(1, limitStats.getHourRequestStats().getUsed());
+        assertEquals(1000000, limitStats.getWeekTokenStats().getTotal());
+        assertEquals(90, limitStats.getWeekTokenStats().getUsed());
+        assertEquals(10000000, limitStats.getMonthTokenStats().getTotal());
+        assertEquals(90, limitStats.getMonthTokenStats().getUsed());
 
         increaseLimitFuture = rateLimiter.increase(proxyContext, model);
         assertNotNull(increaseLimitFuture);
@@ -285,6 +291,10 @@ public void testGetLimitStats_ApiKey() {
         assertEquals(180, limitStats.getDayTokenStats().getUsed());
         assertEquals(100, limitStats.getMinuteTokenStats().getTotal());
         assertEquals(180, limitStats.getMinuteTokenStats().getUsed());
+        assertEquals(1000000, limitStats.getWeekTokenStats().getTotal());
+        assertEquals(180, limitStats.getWeekTokenStats().getUsed());
+        assertEquals(10000000, limitStats.getMonthTokenStats().getTotal());
+        assertEquals(180, limitStats.getMonthTokenStats().getUsed());
     }