From aa500cf538aeeb15919dc46c2ef225f0fa87df6e Mon Sep 17 00:00:00 2001 From: Aliaksandr Stsiapanay Date: Tue, 20 Feb 2024 18:28:59 +0300 Subject: [PATCH] fix: verify caller has access to attached files --- .../java/com/epam/aidial/core/AiDial.java | 4 +- src/main/java/com/epam/aidial/core/Proxy.java | 2 + .../AccessControlBaseController.java | 2 +- .../controller/DeploymentPostController.java | 7 +- .../aidial/core/security/AccessService.java | 39 ++++++++ .../com/epam/aidial/core/util/ProxyUtil.java | 12 ++- .../epam/aidial/core/util/ProxyUtilTest.java | 95 ++++++++++++++++++- 7 files changed, 151 insertions(+), 10 deletions(-) create mode 100644 src/main/java/com/epam/aidial/core/security/AccessService.java diff --git a/src/main/java/com/epam/aidial/core/AiDial.java b/src/main/java/com/epam/aidial/core/AiDial.java index 807c34d0..cd1afe0d 100644 --- a/src/main/java/com/epam/aidial/core/AiDial.java +++ b/src/main/java/com/epam/aidial/core/AiDial.java @@ -7,6 +7,7 @@ import com.epam.aidial.core.limiter.RateLimiter; import com.epam.aidial.core.log.GfLogStore; import com.epam.aidial.core.log.LogStore; +import com.epam.aidial.core.security.AccessService; import com.epam.aidial.core.security.AccessTokenValidator; import com.epam.aidial.core.security.ApiKeyStore; import com.epam.aidial.core.security.EncryptionService; @@ -79,7 +80,8 @@ void start() throws Exception { storage = new BlobStorage(storageConfig); } EncryptionService encryptionService = new EncryptionService(Json.decodeValue(settings("encryption").toBuffer(), Encryption.class)); - proxy = new Proxy(vertx, client, configStore, logStore, rateLimiter, upstreamBalancer, accessTokenValidator, storage, encryptionService, apiKeyStore); + AccessService accessService = new AccessService(encryptionService); + proxy = new Proxy(vertx, client, configStore, logStore, rateLimiter, upstreamBalancer, accessTokenValidator, storage, encryptionService, apiKeyStore, accessService); server = vertx.createHttpServer(new HttpServerOptions(settings("server"))).requestHandler(proxy); open(server, HttpServer::listen); diff --git a/src/main/java/com/epam/aidial/core/Proxy.java b/src/main/java/com/epam/aidial/core/Proxy.java index b023686a..96fabcbb 100644 --- a/src/main/java/com/epam/aidial/core/Proxy.java +++ b/src/main/java/com/epam/aidial/core/Proxy.java @@ -7,6 +7,7 @@ import com.epam.aidial.core.controller.ControllerSelector; import com.epam.aidial.core.limiter.RateLimiter; import com.epam.aidial.core.log.LogStore; +import com.epam.aidial.core.security.AccessService; import com.epam.aidial.core.security.AccessTokenValidator; import com.epam.aidial.core.security.ApiKeyStore; import com.epam.aidial.core.security.EncryptionService; @@ -65,6 +66,7 @@ public class Proxy implements Handler { private final BlobStorage storage; private final EncryptionService encryptionService; private final ApiKeyStore apiKeyStore; + private final AccessService accessService; @Override public void handle(HttpServerRequest request) { diff --git a/src/main/java/com/epam/aidial/core/controller/AccessControlBaseController.java b/src/main/java/com/epam/aidial/core/controller/AccessControlBaseController.java index b4fb0811..b0d0546b 100644 --- a/src/main/java/com/epam/aidial/core/controller/AccessControlBaseController.java +++ b/src/main/java/com/epam/aidial/core/controller/AccessControlBaseController.java @@ -28,7 +28,7 @@ public Future handle(String bucket, String filePath) { String urlDecodedBucket = UrlUtil.decodePath(bucket); String decryptedBucket = proxy.getEncryptionService().decrypt(urlDecodedBucket); boolean hasReadAccess = isSharedWithMe(bucket, filePath); - boolean hasWriteAccess = hasWriteAccess(filePath, decryptedBucket); + boolean hasWriteAccess = proxy.getAccessService().hasWriteAccess(bucket, filePath, context); boolean hasAccess = checkFullAccess ? hasWriteAccess : hasReadAccess || hasWriteAccess; if (!hasAccess) { diff --git a/src/main/java/com/epam/aidial/core/controller/DeploymentPostController.java b/src/main/java/com/epam/aidial/core/controller/DeploymentPostController.java index 3c775268..99106769 100644 --- a/src/main/java/com/epam/aidial/core/controller/DeploymentPostController.java +++ b/src/main/java/com/epam/aidial/core/controller/DeploymentPostController.java @@ -155,7 +155,12 @@ void handleRequestBody(Buffer requestBody) { ObjectNode tree = (ObjectNode) ProxyUtil.MAPPER.readTree(stream); try { - ProxyUtil.collectAttachedFiles(tree, proxyApiKeyData); + ProxyUtil.collectAttachedFiles(tree, proxyApiKeyData, link -> proxy.getAccessService().hasWriteAccess(link, context)); + } catch (HttpException e) { + respond(e.getStatus(), e.getMessage()); + log.warn("Can't collect attached files. Trace: {}. Span: {}. Error: {}", + context.getTraceId(), context.getSpanId(), e.getMessage()); + return; } catch (Throwable e) { context.respond(HttpStatus.BAD_REQUEST); log.warn("Can't collect attached files. Trace: {}. Span: {}. Error: {}", diff --git a/src/main/java/com/epam/aidial/core/security/AccessService.java b/src/main/java/com/epam/aidial/core/security/AccessService.java new file mode 100644 index 00000000..d30e37c1 --- /dev/null +++ b/src/main/java/com/epam/aidial/core/security/AccessService.java @@ -0,0 +1,39 @@ +package com.epam.aidial.core.security; + +import com.epam.aidial.core.ProxyContext; +import com.epam.aidial.core.storage.BlobStorageUtil; +import com.epam.aidial.core.util.UrlUtil; +import lombok.AllArgsConstructor; + +@AllArgsConstructor +public class AccessService { + + private final EncryptionService encryptionService; + + public boolean hasWriteAccess(String bucket, String filePath, ProxyContext context) { + String urlDecodedBucket = UrlUtil.decodePath(bucket); + String decryptedBucket = encryptionService.decrypt(urlDecodedBucket); + String expectedUserBucket = BlobStorageUtil.buildUserBucket(context); + if (expectedUserBucket.equals(decryptedBucket)) { + return true; + } + String expectedAppDataBucket = BlobStorageUtil.buildAppDataBucket(context); + if (expectedAppDataBucket != null && expectedAppDataBucket.equals(decryptedBucket)) { + return filePath.startsWith(BlobStorageUtil.APPDATA_PATTERN.formatted(UrlUtil.encodePath(context.getSourceDeployment()))); + } + return false; + } + + public boolean hasWriteAccess(String url, ProxyContext context) { + if (url == null) { + return false; + } + int index = url.indexOf(BlobStorageUtil.PATH_SEPARATOR); + if (index < 0) { + return false; + } + String bucket = url.substring(0, index); + String path = url.substring(index + 1); + return hasWriteAccess(bucket, path, context); + } +} diff --git a/src/main/java/com/epam/aidial/core/util/ProxyUtil.java b/src/main/java/com/epam/aidial/core/util/ProxyUtil.java index 03c758db..329f2b21 100644 --- a/src/main/java/com/epam/aidial/core/util/ProxyUtil.java +++ b/src/main/java/com/epam/aidial/core/util/ProxyUtil.java @@ -2,6 +2,7 @@ import com.epam.aidial.core.Proxy; import com.epam.aidial.core.config.ApiKeyData; +import com.epam.aidial.core.security.AccessService; import com.fasterxml.jackson.core.JsonProcessingException; import com.fasterxml.jackson.databind.JsonNode; import com.fasterxml.jackson.databind.MapperFeature; @@ -15,8 +16,11 @@ import lombok.experimental.UtilityClass; import java.util.Map; +import java.util.function.Function; import javax.annotation.Nullable; +import static com.epam.aidial.core.util.HttpStatus.FORBIDDEN; + @UtilityClass public class ProxyUtil { @@ -83,7 +87,7 @@ private static int contentLength(MultiMap header, int defaultValue) { return defaultValue; } - public static void collectAttachedFiles(ObjectNode tree, ApiKeyData apiKeyData) { + public static void collectAttachedFiles(ObjectNode tree, ApiKeyData apiKeyData, Function checkAccessFn) throws Exception { ArrayNode messages = (ArrayNode) tree.get("messages"); if (messages == null) { return; @@ -102,7 +106,11 @@ public static void collectAttachedFiles(ObjectNode tree, ApiKeyData apiKeyData) JsonNode attachment = attachments.get(j); JsonNode url = attachment.get("url"); if (url != null) { - apiKeyData.getAttachedFiles().add(url.textValue()); + if (checkAccessFn.apply(url.textValue())) { + apiKeyData.getAttachedFiles().add(url.textValue()); + } else { + throw new HttpException(FORBIDDEN, "Access denied to the file %s"); + } } } } diff --git a/src/test/java/com/epam/aidial/core/util/ProxyUtilTest.java b/src/test/java/com/epam/aidial/core/util/ProxyUtilTest.java index 18b5bad1..b9d400b3 100644 --- a/src/test/java/com/epam/aidial/core/util/ProxyUtilTest.java +++ b/src/test/java/com/epam/aidial/core/util/ProxyUtilTest.java @@ -4,16 +4,16 @@ import com.fasterxml.jackson.databind.node.ObjectNode; import org.junit.jupiter.api.Test; -import java.io.IOException; import java.util.Set; import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertThrows; import static org.junit.jupiter.api.Assertions.assertTrue; public class ProxyUtilTest { @Test - public void testCollectAttachedFiles_ChatRequest() throws IOException { + public void testCollectAttachedFiles_ChatRequest() throws Exception { String content = """ { "modelId": "model", @@ -94,13 +94,13 @@ public void testCollectAttachedFiles_ChatRequest() throws IOException { """; ObjectNode tree = (ObjectNode) ProxyUtil.MAPPER.readTree(content.getBytes()); ApiKeyData apiKeyData = new ApiKeyData(); - ProxyUtil.collectAttachedFiles(tree, apiKeyData); + ProxyUtil.collectAttachedFiles(tree, apiKeyData, link -> true); assertEquals(Set.of("b1/Dockerfile", "b1/LICENSE"), apiKeyData.getAttachedFiles()); } @Test - public void testCollectAttachedFiles_EmbeddingRequest() throws IOException { + public void testCollectAttachedFiles_EmbeddingRequest() throws Exception { String content = """ { "input": "some input", @@ -109,8 +109,93 @@ public void testCollectAttachedFiles_EmbeddingRequest() throws IOException { """; ObjectNode tree = (ObjectNode) ProxyUtil.MAPPER.readTree(content.getBytes()); ApiKeyData apiKeyData = new ApiKeyData(); - ProxyUtil.collectAttachedFiles(tree, apiKeyData); + ProxyUtil.collectAttachedFiles(tree, apiKeyData, link -> true); assertTrue(apiKeyData.getAttachedFiles().isEmpty()); } + + @Test + public void testCollectAttachedFiles_ChatRequest_NoPermissions() throws Exception { + String content = """ + { + "modelId": "model", + "messages": [ + { + "content": "test", + "role": "user", + "custom_content": { + } + }, + { + "content": "I'm sorry, but your message is unclear. Could you please provide more details or context?", + "role": "assistant" + }, + { + "content": "what file is?", + "role": "user", + "custom_content": { + "attachments": [ + { + "type": "application/octet-stream", + "title": "Dockerfile", + "url": "b1/Dockerfile" + } + ] + } + }, + { + "content": "The file you provided is a Dockerfile.", + "role": "assistant", + "custom_content": { + "attachments": [ + { + "index": 0, + "type": "text/markdown", + "title": "[1] 'Dockerfile'", + "data": "FROM gradle:8.2.0", + "reference_url": "b1/Dockerfile" + }, + { + "index": 1, + "type": "text/markdown", + "title": "[2] 'Dockerfile'", + "data": "* /app/config/ RUN mkdir /app/log && chown -R appuser:appuser /app", + "reference_url": "b1/Dockerfile" + }, + { + "index": 2, + "type": "text/markdown", + "title": "[3] 'Dockerfile'", + "data": "USER appuser", + "reference_url": "b1/Dockerfile" + } + ] + } + }, + { + "content": "Compare these files?", + "role": "user", + "custom_content": { + "attachments": [ + { + "type": "application/octet-stream", + "title": "LICENSE", + "url": "b1/LICENSE" + }, + { + "type": "binary/octet-stream", + "title": "Dockerfile", + "url": "b1/Dockerfile" + } + ] + } + } + ], + "id": "id" + } + """; + ObjectNode tree = (ObjectNode) ProxyUtil.MAPPER.readTree(content.getBytes()); + ApiKeyData apiKeyData = new ApiKeyData(); + assertThrows(HttpException.class, () -> ProxyUtil.collectAttachedFiles(tree, apiKeyData, link -> false)); + } }