Skip to content

Commit

Permalink
refactor: add ProjectProxy in batch searches to limit access to proje…
Browse files Browse the repository at this point in the history
…ct name only #1221

Co-authored-by: Bruno Thomas <bthomas@icij.org>
  • Loading branch information
caro3801 and bamthomas committed Oct 20, 2023
1 parent 3211437 commit ca1fdfe
Show file tree
Hide file tree
Showing 12 changed files with 112 additions and 106 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -7,12 +7,11 @@
import org.icij.datashare.Entity;
import org.icij.datashare.PropertiesProvider;
import org.icij.datashare.batch.BatchSearch;

import org.icij.datashare.batch.SearchException;
import org.icij.datashare.function.TerFunction;
import org.icij.datashare.monitoring.Monitorable;
import org.icij.datashare.text.Document;
import org.icij.datashare.text.Project;
import org.icij.datashare.text.ProjectProxy;
import org.icij.datashare.text.indexing.Indexer;
import org.icij.datashare.time.DatashareTime;
import org.icij.datashare.user.User;
Expand All @@ -21,18 +20,17 @@
import org.slf4j.LoggerFactory;

import java.io.IOException;
import java.util.Collections;
import java.util.List;
import java.util.concurrent.Callable;
import java.util.concurrent.CountDownLatch;
import java.util.concurrent.TimeoutException;
import java.util.stream.Collectors;

import static java.lang.Integer.min;
import static java.lang.Integer.parseInt;
import static java.util.Arrays.stream;
import static java.util.stream.Collectors.toList;
import static org.icij.datashare.cli.DatashareCliOptions.*;
import static org.icij.datashare.text.ProjectProxy.asCommaConcatNames;

public class BatchSearchRunner implements Callable<Integer>, Monitorable, UserTask {
private final Logger logger = LoggerFactory.getLogger(getClass());
Expand Down Expand Up @@ -80,14 +78,14 @@ public Integer call() throws SearchException {
callThread = Thread.currentThread();
callWaiterLatch.countDown(); // for tests
logger.info("running {} queries for batch search {} on projects {} with throttle {}ms and scroll size of {}",
batchSearch.queries.size(), batchSearch.uuid, batchSearch.projects.stream().map(Project::getId).collect(Collectors.joining(", "))
batchSearch.queries.size(), batchSearch.uuid, asCommaConcatNames(batchSearch.projects)
, throttleMs, scrollSize);

String query = null;
try {
for (String s : batchSearch.queries.keySet()) {
query = s;
Indexer.Searcher searcher = indexer.search(batchSearch.projects.stream().map(Project::getId).collect(toList()), Document.class).
Indexer.Searcher searcher = indexer.search(batchSearch.projects.stream().map(ProjectProxy::getId).collect(toList()), Document.class).
with(query, batchSearch.fuzziness, batchSearch.phraseMatches).
withFieldValues("contentType", batchSearch.fileTypes.toArray(new String[]{})).
withFieldValues("tags", batchSearch.tags.toArray(new String[]{})).
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,10 +23,10 @@
import org.icij.datashare.db.JooqBatchSearchRepository;
import org.icij.datashare.session.DatashareUser;
import org.icij.datashare.text.Project;
import org.icij.datashare.text.ProjectProxy;
import org.icij.datashare.user.User;
import org.icij.datashare.utils.PayloadFormatter;

import java.io.File;
import java.io.IOException;
import java.nio.file.Path;
import java.util.*;
Expand Down Expand Up @@ -324,8 +324,8 @@ public Payload deleteSearches(Context context) {
return new Payload(204);
}

private String docUrl(String uri, List<Project> projects, String documentId, String rootId) {
return format("%s/#/d/%s/%s/%s", uri, projects.stream().map(Project::getId).collect(Collectors.joining(",")), documentId, rootId);
/**
 * Builds a document URL of the form {@code <uri>/#/d/<projectIds>/<documentId>/<rootId>}.
 * The {@code <projectIds>} segment is the ids of all given projects joined with ","
 * (no surrounding spaces), in list order.
 *
 * @param uri        prefix placed at the start of the URL
 * @param projects   project proxies whose ids are joined into the URL; must not be null
 * @param documentId value inserted as the segment after the project ids
 * @param rootId     value inserted as the last segment
 * @return the formatted URL string
 */
private String docUrl(String uri, List<ProjectProxy> projects, String documentId, String rootId) {
return format("%s/#/d/%s/%s/%s", uri, projects.stream().map(ProjectProxy::getId).collect(Collectors.joining(",")), documentId, rootId);
}

private String dirname(Path path) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,8 +20,8 @@
import org.icij.datashare.tasks.DocumentCollectionFactory;
import org.icij.datashare.text.Project;
import org.icij.datashare.text.indexing.Indexer;
import org.icij.datashare.utils.IndexAccessVerifier;
import org.icij.datashare.utils.DataDirVerifier;
import org.icij.datashare.utils.IndexAccessVerifier;
import org.icij.datashare.utils.ModeVerifier;
import org.icij.datashare.utils.PayloadFormatter;
import org.icij.extract.queue.DocumentQueue;
Expand All @@ -30,7 +30,8 @@
import org.slf4j.LoggerFactory;

import java.io.IOException;
import java.util.*;
import java.util.List;
import java.util.Objects;

import static net.codestory.http.errors.NotFoundException.notFoundIfNull;
import static net.codestory.http.payload.Payload.ok;
Expand Down Expand Up @@ -58,12 +59,12 @@ public ProjectResource(Repository repository, Indexer indexer, PropertiesProvide
this.documentCollectionFactory = documentCollectionFactory;
}

String[] getUserProjectIds(DatashareUser user) {
return user.getProjectNames().toArray(String[]::new);
/**
 * Returns the list obtained from {@link DatashareUser#getProjectNames()} for the given user.
 * Despite the method name, the values returned are the user's project names, passed through
 * unchanged (no copy is made here).
 *
 * @param user the user whose project names are read; must not be null
 * @return the user's project names
 */
List<String> getUserProjectIds(DatashareUser user) {
return user.getProjectNames();
}

List<Project> getUserProjects(DatashareUser user) {
String[] projectIds = this.getUserProjectIds(user);
List<String> projectIds = this.getUserProjectIds(user);
return repository.getProjects(projectIds);
}

Expand Down Expand Up @@ -193,12 +194,16 @@ public Project getProject(String id, Context context) {
@ApiResponse(responseCode = "403", description = "if project download is not allowed")
@Get("/isDownloadAllowed/:id")
public Payload isDownloadAllowed(String id, Context context) {
List<String> projects = ((DatashareUser) context.currentUser()).getProjectNames();
String projectId = projects.stream()
List<String> projectIds = ((DatashareUser) context.currentUser()).getProjectNames();
String retrievedProjectId = projectIds.stream()
.filter(i -> i.equals(id))
.findAny()
.orElse(null);
Project project = repository.getProject(projectId);

if (retrievedProjectId == null){
return ok(); // unknown is allowed
}
Project project = repository.getProject(retrievedProjectId);

if (project != null && !isAllowed(project, context.request().clientAddress())) {
return PayloadFormatter.error("Download not allowed", HttpStatus.FORBIDDEN);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -51,11 +51,11 @@ public String getRoot(Context context) throws IOException {
} else {
content = new String(Files.readAllBytes(index), Charset.defaultCharset());
}
List<String> projects = context.currentUser() == null ? new LinkedList<>() : ((DatashareUser)context.currentUser()).getProjectNames();
List<String> projectNames = context.currentUser() == null ? new LinkedList<>() : ((DatashareUser)context.currentUser()).getProjectNames();
if (propertiesProvider.get(PLUGINS_DIR).isPresent()) {
ExtensionService extensionService = propertiesProvider.get(EXTENSIONS_DIR).isPresent() ? new ExtensionService(propertiesProvider): null;
return new PluginService(propertiesProvider, extensionService)
.addPlugins(content, projects);
.addPlugins(content, projectNames);
}
return content;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -39,13 +39,11 @@ public class UserResource {
private List<Project> getDatashareUserProjects (DatashareUser datashareUser) {
List<String> projectNames = datashareUser.getProjectNames();
List<Project> projects = datashareUser.getProjects();
List<Project> repositoryProjects = repository.getProjects(projectNames.toArray(new String[0]));
return projects.stream().map(project -> {
return repositoryProjects.stream()
.filter(p -> p.name.equals(project.name))
.findFirst()
.orElse(project);
}).collect(Collectors.toList());
List<Project> repositoryProjects = repository.getProjects(projectNames);
return projects.stream().map(project -> repositoryProjects.stream()
.filter(p -> p.name.equals(project.name))
.findFirst()
.orElse(project)).collect(Collectors.toList());
}

@Inject
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,6 @@
import org.mockito.Mock;

import java.util.ArrayList;
import java.util.Collections;
import java.util.List;
import java.util.Map;

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@
import static org.fest.assertions.Assertions.assertThat;
import static org.icij.datashare.CollectionUtils.asSet;
import static org.icij.datashare.text.Project.project;
import static org.icij.datashare.text.ProjectProxy.proxy;
import static org.mockito.Matchers.any;
import static org.mockito.Matchers.eq;
import static org.mockito.Mockito.verify;
Expand Down Expand Up @@ -223,7 +224,7 @@ public void test_get_batch_searches_json() {

@Test
public void test_get_batch_searches_records_json_paginated() {
List<BatchSearchRecord> batchSearches = IntStream.range(0, 10).mapToObj(i -> new BatchSearchRecord(singletonList(project("local-datashare")), "name" + i, "description" + i, 2, new Date())).collect(toList());
List<BatchSearchRecord> batchSearches = IntStream.range(0, 10).mapToObj(i -> new BatchSearchRecord(singletonList(proxy("local-datashare")), "name" + i, "description" + i, 2, new Date())).collect(toList());
when(batchSearchRepository.getRecords(User.local(), singletonList("local-datashare"), WebQueryBuilder.createWebQuery().queryAll().withRange(0,2).build())).thenReturn(batchSearches.subList(0, 2));
when(batchSearchRepository.getTotal(User.local(), singletonList("local-datashare"), WebQueryBuilder.createWebQuery().queryAll().withRange(0,2).build())).thenReturn(batchSearches.size());
when(batchSearchRepository.getRecords(User.local(), singletonList("local-datashare"),WebQueryBuilder.createWebQuery().queryAll().withRange(4,3).build())).thenReturn(batchSearches.subList(5, 8));
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
import org.icij.datashare.db.tables.records.BatchSearchQueryRecord;
import org.icij.datashare.db.tables.records.BatchSearchResultRecord;
import org.icij.datashare.text.Document;
import org.icij.datashare.text.Project;
import org.icij.datashare.text.ProjectProxy;
import org.icij.datashare.user.User;
import org.jooq.*;
import org.jooq.impl.DSL;
Expand All @@ -30,7 +30,8 @@
import static org.icij.datashare.db.tables.BatchSearch.BATCH_SEARCH;
import static org.icij.datashare.db.tables.BatchSearchQuery.BATCH_SEARCH_QUERY;
import static org.icij.datashare.db.tables.BatchSearchResult.BATCH_SEARCH_RESULT;
import static org.icij.datashare.text.Project.project;
import static org.icij.datashare.text.ProjectProxy.asNameList;
import static org.icij.datashare.text.ProjectProxy.proxy;
import static org.jooq.impl.DSL.*;

public class JooqBatchSearchRepository implements BatchSearchRepository {
Expand Down Expand Up @@ -185,12 +186,12 @@ public List<BatchSearchRecord> getRecords(User user, List<String> projectsIds) {

@Override
public List<BatchSearchRecord> getRecords(User user, List<String> projectsIds, WebQuery webQuery) {
try(DSLContext context = DSL.using(dataSource, dialect)) {
try(DSLContext context = using(dataSource, dialect)) {
cacheNbQueries(webQuery, context);

SelectConditionStep<Record12<String, String, String, String, Timestamp, String, Integer, Integer, String, String, String, Integer>> query = createBatchSearchRecordWithQueriesSelectStatement(context)
.where(BATCH_SEARCH.USER_ID.eq(user.id).or(BATCH_SEARCH.PUBLISHED.greaterThan(0)));
List<String> filteredProjects = webQuery.hasFilteredProjects() ? webQuery.project : projectsIds;
List<String> filteredProjects = webQuery.hasFilteredProjects() ? asNameList(webQuery.project) : projectsIds;
addFilterToSelectCondition(webQuery, query);
if (webQuery.isSorted()) {
query.orderBy(field(webQuery.sort + " " + webQuery.order));
Expand Down Expand Up @@ -435,7 +436,7 @@ private BatchSearch createBatchSearchFrom(final Record record) {
Integer nb_queries = query_results == null ? 0: query_results;
boolean phraseMatches= record.get(BATCH_SEARCH.PHRASE_MATCHES) != 0;
return new BatchSearch(record.get(BATCH_SEARCH.UUID).trim(),
singletonList(project(record.get(BATCH_SEARCH_PROJECT.PRJ_ID))),
singletonList(proxy(record.get(BATCH_SEARCH_PROJECT.PRJ_ID))),
record.getValue(BATCH_SEARCH.NAME),
record.getValue(BATCH_SEARCH.DESCRIPTION),
new LinkedHashMap<>() {{
Expand All @@ -456,14 +457,13 @@ private BatchSearch createBatchSearchFrom(final Record record) {
}

private BatchSearch createBatchSearchWithoutQueries(final Record record) {
Integer nb_queries = record.get("nb_queries", Integer.class);
String projects = (String) record.get("projects");
boolean phraseMatches= record.get(BATCH_SEARCH.PHRASE_MATCHES) != 0;
return new BatchSearch(record.get(BATCH_SEARCH.UUID).trim(),
getProjects(projects),
record.getValue(BATCH_SEARCH.NAME),
record.getValue(BATCH_SEARCH.DESCRIPTION),
nb_queries,
record.get(BATCH_SEARCH.NB_QUERIES),
Date.from(record.get(BATCH_SEARCH.BATCH_DATE).toInstant()),
State.valueOf(record.get(BATCH_SEARCH.STATE)),
new User(record.get(BATCH_SEARCH.USER_ID)),
Expand Down Expand Up @@ -500,8 +500,8 @@ private BatchSearchRecord createBatchSearchRecordFrom(final Record record) {
batchSearch.getErrorQuery());
}

private static List<Project> getProjects(String prj) {
return prj == null || prj.isEmpty() ? null : stream(prj.split(LIST_SEPARATOR)).sorted().map(Project::project).collect(toList());
/**
 * Converts a delimited string of project identifiers into a list of {@link ProjectProxy}.
 * The string is split on {@code LIST_SEPARATOR} (interpreted as a regular expression by
 * {@link String#split(String)}), the pieces are sorted in natural String order, and each
 * piece is wrapped in a new {@code ProjectProxy}.
 *
 * @param prj delimited project identifiers; may be null or empty
 * @return {@code null} (not an empty list) when {@code prj} is null or empty,
 *         otherwise the sorted list of proxies
 */
private static List<ProjectProxy> getProjects(String prj) {
return prj == null || prj.isEmpty() ? null : stream(prj.split(LIST_SEPARATOR)).sorted().map(ProjectProxy::new).collect(toList());
}

private SearchResult createSearchResult(final User actualUser, final Record record) {
Expand All @@ -513,7 +513,7 @@ private SearchResult createSearchResult(final User actualUser, final Record reco
}
Timestamp creationDate = record.get(BATCH_SEARCH_RESULT.CREATION_DATE);
return new SearchResult(record.get(BATCH_SEARCH_RESULT.QUERY),
prj == null || prj.isEmpty() ? null : project(prj),
prj == null || prj.isEmpty() ? null : proxy(prj),
record.get(BATCH_SEARCH_RESULT.DOC_ID),
record.getValue(BATCH_SEARCH_RESULT.ROOT_ID),
Paths.get(record.getValue(BATCH_SEARCH_RESULT.DOC_PATH)),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -374,7 +374,7 @@ public List<Project> getProjects() {
}

@Override
public List<Project> getProjects(String[] projectIds) {
public List<Project> getProjects(List<String> projectIds) {
return DSL.using(connectionProvider, dialect).selectFrom(PROJECT).
where(PROJECT.ID.in(projectIds)).
stream().map(this::createProjectFrom).collect(toList());
Expand Down Expand Up @@ -402,6 +402,8 @@ public boolean save(Note note) {

@Override
public boolean save(Project project) {
Timestamp projectCreationDate = project.creationDate == null ? null: new Timestamp(project.creationDate.getTime());
Timestamp projectUpdateDate = project.updateDate == null ? null: new Timestamp(project.updateDate.getTime());
return DSL.using(connectionProvider, dialect).
insertInto(
PROJECT, PROJECT.ID, PROJECT.LABEL, PROJECT.DESCRIPTION, PROJECT.PATH, PROJECT.SOURCE_URL,
Expand All @@ -412,7 +414,7 @@ public boolean save(Project project) {
project.name, project.label, project.description, project.sourcePath.toString(), project.sourceUrl,
project.maintainerName, project.publisherName, project.logoUrl,
project.allowFromMask,
new Timestamp(project.creationDate.getTime()), new Timestamp(project.updateDate.getTime())).
projectCreationDate, projectUpdateDate).
onConflict(PROJECT.ID).
doUpdate().
set(PROJECT.LABEL, project.label).
Expand All @@ -422,7 +424,7 @@ public boolean save(Project project) {
set(PROJECT.PUBLISHER_NAME, project.publisherName).
set(PROJECT.LOGO_URL, project.logoUrl).
set(PROJECT.ALLOW_FROM_MASK, project.allowFromMask).
set(PROJECT.UPDATE_DATE, new Timestamp(project.updateDate.getTime())).
set(PROJECT.UPDATE_DATE, projectUpdateDate).
execute() > 0;
}

Expand Down Expand Up @@ -499,6 +501,9 @@ private Project createProjectFrom(ProjectRecord record) {
if (record == null) {
return null;
}
// Null-safe conversion of the record dates: each date is copied into a Timestamp only when present.
Timestamp projectCreationDate = record.getCreationDate() == null ? null: new Timestamp(record.getCreationDate().getTime());
// The update timestamp must be built from getUpdateDate(): reading getCreationDate() here would
// overwrite the update date with the creation date and throw a NullPointerException when only the
// creation date is missing.
Timestamp projectUpdateDate = record.getUpdateDate() == null ? null: new Timestamp(record.getUpdateDate().getTime());

return new Project(record.getId(),
record.getLabel(),
record.getDescription(),
Expand All @@ -508,8 +513,8 @@ private Project createProjectFrom(ProjectRecord record) {
record.getPublisherName(),
record.getLogoUrl(),
record.getAllowFromMask(),
new Timestamp(record.getCreationDate().getTime()),
new Timestamp(record.getUpdateDate().getTime()));
projectCreationDate,
projectUpdateDate);
}

private NamedEntity createFrom(NamedEntityRecord record) {
Expand Down
Loading

0 comments on commit ca1fdfe

Please sign in to comment.