Skip to content

Commit

Permalink
fix: verify caller has access to attached files
Browse files Browse the repository at this point in the history
  • Loading branch information
astsiapanay committed Feb 20, 2024
1 parent dd9faa7 commit aa500cf
Show file tree
Hide file tree
Showing 7 changed files with 151 additions and 10 deletions.
4 changes: 3 additions & 1 deletion src/main/java/com/epam/aidial/core/AiDial.java
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
import com.epam.aidial.core.limiter.RateLimiter;
import com.epam.aidial.core.log.GfLogStore;
import com.epam.aidial.core.log.LogStore;
import com.epam.aidial.core.security.AccessService;
import com.epam.aidial.core.security.AccessTokenValidator;
import com.epam.aidial.core.security.ApiKeyStore;
import com.epam.aidial.core.security.EncryptionService;
Expand Down Expand Up @@ -79,7 +80,8 @@ void start() throws Exception {
storage = new BlobStorage(storageConfig);
}
EncryptionService encryptionService = new EncryptionService(Json.decodeValue(settings("encryption").toBuffer(), Encryption.class));
proxy = new Proxy(vertx, client, configStore, logStore, rateLimiter, upstreamBalancer, accessTokenValidator, storage, encryptionService, apiKeyStore);
AccessService accessService = new AccessService(encryptionService);
proxy = new Proxy(vertx, client, configStore, logStore, rateLimiter, upstreamBalancer, accessTokenValidator, storage, encryptionService, apiKeyStore, accessService);

server = vertx.createHttpServer(new HttpServerOptions(settings("server"))).requestHandler(proxy);
open(server, HttpServer::listen);
Expand Down
2 changes: 2 additions & 0 deletions src/main/java/com/epam/aidial/core/Proxy.java
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
import com.epam.aidial.core.controller.ControllerSelector;
import com.epam.aidial.core.limiter.RateLimiter;
import com.epam.aidial.core.log.LogStore;
import com.epam.aidial.core.security.AccessService;
import com.epam.aidial.core.security.AccessTokenValidator;
import com.epam.aidial.core.security.ApiKeyStore;
import com.epam.aidial.core.security.EncryptionService;
Expand Down Expand Up @@ -65,6 +66,7 @@ public class Proxy implements Handler<HttpServerRequest> {
private final BlobStorage storage;
private final EncryptionService encryptionService;
private final ApiKeyStore apiKeyStore;
private final AccessService accessService;

@Override
public void handle(HttpServerRequest request) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ public Future<?> handle(String bucket, String filePath) {
String urlDecodedBucket = UrlUtil.decodePath(bucket);
String decryptedBucket = proxy.getEncryptionService().decrypt(urlDecodedBucket);
boolean hasReadAccess = isSharedWithMe(bucket, filePath);
boolean hasWriteAccess = hasWriteAccess(filePath, decryptedBucket);
boolean hasWriteAccess = proxy.getAccessService().hasWriteAccess(bucket, filePath, context);
boolean hasAccess = checkFullAccess ? hasWriteAccess : hasReadAccess || hasWriteAccess;

if (!hasAccess) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -155,7 +155,12 @@ void handleRequestBody(Buffer requestBody) {
ObjectNode tree = (ObjectNode) ProxyUtil.MAPPER.readTree(stream);

try {
ProxyUtil.collectAttachedFiles(tree, proxyApiKeyData);
ProxyUtil.collectAttachedFiles(tree, proxyApiKeyData, link -> proxy.getAccessService().hasWriteAccess(link, context));
} catch (HttpException e) {
respond(e.getStatus(), e.getMessage());
log.warn("Can't collect attached files. Trace: {}. Span: {}. Error: {}",
context.getTraceId(), context.getSpanId(), e.getMessage());
return;
} catch (Throwable e) {
context.respond(HttpStatus.BAD_REQUEST);
log.warn("Can't collect attached files. Trace: {}. Span: {}. Error: {}",
Expand Down
39 changes: 39 additions & 0 deletions src/main/java/com/epam/aidial/core/security/AccessService.java
Original file line number Diff line number Diff line change
@@ -0,0 +1,39 @@
package com.epam.aidial.core.security;

import com.epam.aidial.core.ProxyContext;
import com.epam.aidial.core.storage.BlobStorageUtil;
import com.epam.aidial.core.util.UrlUtil;
import lombok.AllArgsConstructor;

@AllArgsConstructor
public class AccessService {

private final EncryptionService encryptionService;

public boolean hasWriteAccess(String bucket, String filePath, ProxyContext context) {
String urlDecodedBucket = UrlUtil.decodePath(bucket);
String decryptedBucket = encryptionService.decrypt(urlDecodedBucket);
String expectedUserBucket = BlobStorageUtil.buildUserBucket(context);
if (expectedUserBucket.equals(decryptedBucket)) {
return true;
}
String expectedAppDataBucket = BlobStorageUtil.buildAppDataBucket(context);
if (expectedAppDataBucket != null && expectedAppDataBucket.equals(decryptedBucket)) {
return filePath.startsWith(BlobStorageUtil.APPDATA_PATTERN.formatted(UrlUtil.encodePath(context.getSourceDeployment())));
}
return false;
}

public boolean hasWriteAccess(String url, ProxyContext context) {
if (url == null) {
return false;
}
int index = url.indexOf(BlobStorageUtil.PATH_SEPARATOR);
if (index < 0) {
return false;
}
String bucket = url.substring(0, index);
String path = url.substring(index + 1);
return hasWriteAccess(bucket, path, context);
}
}
12 changes: 10 additions & 2 deletions src/main/java/com/epam/aidial/core/util/ProxyUtil.java
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

import com.epam.aidial.core.Proxy;
import com.epam.aidial.core.config.ApiKeyData;
import com.epam.aidial.core.security.AccessService;
import com.fasterxml.jackson.core.JsonProcessingException;
import com.fasterxml.jackson.databind.JsonNode;
import com.fasterxml.jackson.databind.MapperFeature;
Expand All @@ -15,8 +16,11 @@
import lombok.experimental.UtilityClass;

import java.util.Map;
import java.util.function.Function;
import javax.annotation.Nullable;

import static com.epam.aidial.core.util.HttpStatus.FORBIDDEN;

@UtilityClass
public class ProxyUtil {

Expand Down Expand Up @@ -83,7 +87,7 @@ private static int contentLength(MultiMap header, int defaultValue) {
return defaultValue;
}

public static void collectAttachedFiles(ObjectNode tree, ApiKeyData apiKeyData) {
public static void collectAttachedFiles(ObjectNode tree, ApiKeyData apiKeyData, Function<String, Boolean> checkAccessFn) throws Exception {
ArrayNode messages = (ArrayNode) tree.get("messages");
if (messages == null) {
return;
Expand All @@ -102,7 +106,11 @@ public static void collectAttachedFiles(ObjectNode tree, ApiKeyData apiKeyData)
JsonNode attachment = attachments.get(j);
JsonNode url = attachment.get("url");
if (url != null) {
apiKeyData.getAttachedFiles().add(url.textValue());
if (checkAccessFn.apply(url.textValue())) {
apiKeyData.getAttachedFiles().add(url.textValue());
} else {
throw new HttpException(FORBIDDEN, "Access denied to the file %s");
}
}
}
}
Expand Down
95 changes: 90 additions & 5 deletions src/test/java/com/epam/aidial/core/util/ProxyUtilTest.java
Original file line number Diff line number Diff line change
Expand Up @@ -4,16 +4,16 @@
import com.fasterxml.jackson.databind.node.ObjectNode;
import org.junit.jupiter.api.Test;

import java.io.IOException;
import java.util.Set;

import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.junit.jupiter.api.Assertions.assertThrows;
import static org.junit.jupiter.api.Assertions.assertTrue;

public class ProxyUtilTest {

@Test
public void testCollectAttachedFiles_ChatRequest() throws IOException {
public void testCollectAttachedFiles_ChatRequest() throws Exception {
String content = """
{
"modelId": "model",
Expand Down Expand Up @@ -94,13 +94,13 @@ public void testCollectAttachedFiles_ChatRequest() throws IOException {
""";
ObjectNode tree = (ObjectNode) ProxyUtil.MAPPER.readTree(content.getBytes());
ApiKeyData apiKeyData = new ApiKeyData();
ProxyUtil.collectAttachedFiles(tree, apiKeyData);
ProxyUtil.collectAttachedFiles(tree, apiKeyData, link -> true);

assertEquals(Set.of("b1/Dockerfile", "b1/LICENSE"), apiKeyData.getAttachedFiles());
}

@Test
public void testCollectAttachedFiles_EmbeddingRequest() throws IOException {
public void testCollectAttachedFiles_EmbeddingRequest() throws Exception {
String content = """
{
"input": "some input",
Expand All @@ -109,8 +109,93 @@ public void testCollectAttachedFiles_EmbeddingRequest() throws IOException {
""";
ObjectNode tree = (ObjectNode) ProxyUtil.MAPPER.readTree(content.getBytes());
ApiKeyData apiKeyData = new ApiKeyData();
ProxyUtil.collectAttachedFiles(tree, apiKeyData);
ProxyUtil.collectAttachedFiles(tree, apiKeyData, link -> true);

assertTrue(apiKeyData.getAttachedFiles().isEmpty());
}

@Test
public void testCollectAttachedFiles_ChatRequest_NoPermissions() throws Exception {
String content = """
{
"modelId": "model",
"messages": [
{
"content": "test",
"role": "user",
"custom_content": {
}
},
{
"content": "I'm sorry, but your message is unclear. Could you please provide more details or context?",
"role": "assistant"
},
{
"content": "what file is?",
"role": "user",
"custom_content": {
"attachments": [
{
"type": "application/octet-stream",
"title": "Dockerfile",
"url": "b1/Dockerfile"
}
]
}
},
{
"content": "The file you provided is a Dockerfile.",
"role": "assistant",
"custom_content": {
"attachments": [
{
"index": 0,
"type": "text/markdown",
"title": "[1] 'Dockerfile'",
"data": "FROM gradle:8.2.0",
"reference_url": "b1/Dockerfile"
},
{
"index": 1,
"type": "text/markdown",
"title": "[2] 'Dockerfile'",
"data": "* /app/config/ RUN mkdir /app/log && chown -R appuser:appuser /app",
"reference_url": "b1/Dockerfile"
},
{
"index": 2,
"type": "text/markdown",
"title": "[3] 'Dockerfile'",
"data": "USER appuser",
"reference_url": "b1/Dockerfile"
}
]
}
},
{
"content": "Compare these files?",
"role": "user",
"custom_content": {
"attachments": [
{
"type": "application/octet-stream",
"title": "LICENSE",
"url": "b1/LICENSE"
},
{
"type": "binary/octet-stream",
"title": "Dockerfile",
"url": "b1/Dockerfile"
}
]
}
}
],
"id": "id"
}
""";
ObjectNode tree = (ObjectNode) ProxyUtil.MAPPER.readTree(content.getBytes());
ApiKeyData apiKeyData = new ApiKeyData();
assertThrows(HttpException.class, () -> ProxyUtil.collectAttachedFiles(tree, apiKeyData, link -> false));
}
}

0 comments on commit aa500cf

Please sign in to comment.