diff --git a/.github/workflows/codeql-analysis.yml b/.github/workflows/codeql-analysis.yml index 80ed2464..a229930f 100644 --- a/.github/workflows/codeql-analysis.yml +++ b/.github/workflows/codeql-analysis.yml @@ -41,7 +41,16 @@ jobs: - name: Checkout repository uses: actions/checkout@v3 - # Initializes the CodeQL tools for scanning. + - uses: actions/setup-java@v3 + if: matrix.language == 'java' + with: + java-version: 17 + distribution: temurin + + - uses: gradle/gradle-build-action@v2 + if: matrix.language == 'java' + + # Initializes the CodeQL tools for scanning. - name: Initialize CodeQL uses: github/codeql-action/init@v2 with: diff --git a/.gitignore b/.gitignore index 1ce04ed0..2b4d22c2 100644 --- a/.gitignore +++ b/.gitignore @@ -45,4 +45,5 @@ bin/ docker/prod -*__pycache__ \ No newline at end of file +*__pycache__ +google-credentials.json diff --git a/src/integrationTest/resources/application.properties b/src/integrationTest/resources/application.properties index 639f9b9e..bcd3bef2 100644 --- a/src/integrationTest/resources/application.properties +++ b/src/integrationTest/resources/application.properties @@ -96,3 +96,9 @@ security.radar.managementportal.url=http://localhost:8081 # Github Authentication security.github.client.token= +security.github.client.timeout=10 +# max content size 1 MB +security.github.client.maxContentLength=1000000 +security.github.cache.size=10000 +security.github.cache.duration=3600 +security.github.cache.retryDuration=60 diff --git a/src/integrationTest/resources/docker/docker-compose.yml b/src/integrationTest/resources/docker/docker-compose.yml index 916e89b4..dfb5159e 100644 --- a/src/integrationTest/resources/docker/docker-compose.yml +++ b/src/integrationTest/resources/docker/docker-compose.yml @@ -64,7 +64,7 @@ services: # Management Portal # #---------------------------------------------------------------------------# managementportal: - image: radarbase/management-portal:0.5.6 + image: radarbase/management-portal:2.0.0 ports: - "8081:8081" environment: diff --git a/src/main/java/org/radarbase/appserver/controller/GithubEndpoint.java b/src/main/java/org/radarbase/appserver/controller/GithubEndpoint.java index fc164158..33ff4b9e 100644 --- a/src/main/java/org/radarbase/appserver/controller/GithubEndpoint.java +++ b/src/main/java/org/radarbase/appserver/controller/GithubEndpoint.java @@ -22,7 +22,7 @@ package org.radarbase.appserver.controller; import org.radarbase.appserver.config.AuthConfig; -import org.radarbase.appserver.service.GithubClient; +import org.radarbase.appserver.service.GithubService; import org.springframework.beans.factory.annotation.Autowired; import org.springframework.http.HttpStatus; import org.springframework.http.ResponseEntity; @@ -31,17 +31,16 @@ import org.springframework.web.bind.annotation.RestController; import radar.spring.auth.common.Authorized; -import java.io.IOException; import java.net.MalformedURLException; @RestController public class GithubEndpoint { - private transient GithubClient githubClient; + private final transient GithubService githubService; @Autowired - public GithubEndpoint(GithubClient githubClient) { - this.githubClient = githubClient; + public GithubEndpoint(GithubService githubService) { + this.githubService = githubService; } @Authorized( @@ -51,13 +50,13 @@ public GithubEndpoint(GithubClient githubClient) { PathsUtil.GITHUB_PATH + "/" + PathsUtil.GITHUB_CONTENT_PATH) - public ResponseEntity getGithubContent(@RequestParam() String url + public ResponseEntity getGithubContent(@RequestParam() String url ) { try { - return ResponseEntity.ok().body(this.githubClient.getGithubContent(url)); + return ResponseEntity.ok().body(this.githubService.getGithubContent(url)); } catch (MalformedURLException e) { return ResponseEntity.status(HttpStatus.BAD_REQUEST).body(e.getMessage()); - } catch (IOException | InterruptedException e) { + } catch (Exception e) { return ResponseEntity.status(HttpStatus.BAD_GATEWAY).body("Error getting content from Github."); } } diff --git a/src/main/java/org/radarbase/appserver/entity/User.java b/src/main/java/org/radarbase/appserver/entity/User.java index 226b3dea..f59f95a5 100644 --- a/src/main/java/org/radarbase/appserver/entity/User.java +++ b/src/main/java/org/radarbase/appserver/entity/User.java @@ -22,16 +22,23 @@ package org.radarbase.appserver.entity; import com.fasterxml.jackson.annotation.JsonIgnore; - -import java.io.Serializable; -import java.time.Instant; -import java.util.HashMap; -import java.util.Map; -import java.util.Objects; -import jakarta.persistence.*; +import jakarta.persistence.CascadeType; +import jakarta.persistence.CollectionTable; +import jakarta.persistence.Column; +import jakarta.persistence.ElementCollection; +import jakarta.persistence.Entity; +import jakarta.persistence.FetchType; +import jakarta.persistence.GeneratedValue; +import jakarta.persistence.GenerationType; +import jakarta.persistence.Id; +import jakarta.persistence.JoinColumn; +import jakarta.persistence.ManyToOne; +import jakarta.persistence.MapKeyColumn; +import jakarta.persistence.OneToOne; +import jakarta.persistence.Table; +import jakarta.persistence.UniqueConstraint; import jakarta.validation.constraints.NotEmpty; import jakarta.validation.constraints.NotNull; - import lombok.Getter; import lombok.ToString; import org.hibernate.annotations.OnDelete; @@ -39,6 +46,12 @@ import org.radarbase.appserver.dto.fcm.FcmUserDto; import org.springframework.lang.Nullable; +import java.io.Serializable; +import java.time.Instant; +import java.util.HashMap; +import java.util.Map; +import java.util.Objects; + /** * {@link Entity} for persisting users. The corresponding DTO is {@link FcmUserDto}. A {@link * Project} can have multiple {@link User} (Many-to-One). @@ -97,7 +110,7 @@ public class User extends AuditModel implements Serializable { @CollectionTable(name = "attributes_map") @MapKeyColumn(name = "key", nullable = true) @Column(name = "value") - private Map attributes = new HashMap(); + private Map attributes = new HashMap<>(); public User setSubjectId(String subjectId) { this.subjectId = subjectId; diff --git a/src/main/java/org/radarbase/appserver/service/GithubClient.java b/src/main/java/org/radarbase/appserver/service/GithubClient.java index 2c49a3a1..925151fa 100644 --- a/src/main/java/org/radarbase/appserver/service/GithubClient.java +++ b/src/main/java/org/radarbase/appserver/service/GithubClient.java @@ -21,7 +21,7 @@ package org.radarbase.appserver.service; -import com.fasterxml.jackson.databind.ObjectMapper; +import jakarta.annotation.Nonnull; import lombok.SneakyThrows; import lombok.extern.slf4j.Slf4j; import org.springframework.beans.factory.annotation.Autowired; @@ -33,6 +33,7 @@ import org.springframework.web.server.ResponseStatusException; import java.io.IOException; +import java.io.InputStream; import java.net.MalformedURLException; import java.net.URI; import java.net.http.HttpClient; @@ -44,25 +45,29 @@ @Component @Scope(value = ConfigurableBeanFactory.SCOPE_SINGLETON) public class GithubClient { - private static final String GITHUB_API_URI = "api.github.com"; private static final String GITHUB_API_ACCEPT_HEADER = "application/vnd.github.v3+json"; - private static final String LOCATION_HEADER = "location"; - private final transient ObjectMapper objectMapper; private final transient HttpClient client; - @Value("${security.github.client.token}") - private transient String githubToken; + @Nonnull + private final transient String authorizationHeader; + + private transient final Duration httpTimeout; + + @Value("${security.github.client.maxContentLength:1000000}") + private transient int maxContentLength; @SneakyThrows @Autowired - public GithubClient(ObjectMapper objectMapper) { - this.objectMapper = objectMapper; - client = HttpClient.newBuilder().followRedirects(HttpClient.Redirect.NORMAL).connectTimeout(Duration.ofSeconds(10)).build(); - } - - private static boolean isSuccessfulResponse(HttpResponse response) { - return response.statusCode() >= 200 && response.statusCode() < 300; + public GithubClient( + @Value("${security.github.client.timeout:10}") int httpTimeout, + @Value("${security.github.client.token:}") String githubToken) { + this.authorizationHeader = githubToken != null ? "Bearer " + githubToken.trim() : ""; + this.httpTimeout = Duration.ofSeconds(httpTimeout); + this.client = HttpClient.newBuilder() + .followRedirects(HttpClient.Redirect.NORMAL) + .connectTimeout(this.httpTimeout) + .build(); } public String getGithubContent(String url) throws IOException, InterruptedException { @@ -70,9 +75,16 @@ public String getGithubContent(String url) throws IOException, InterruptedExcept if (!this.isValidGithubUri(uri)) { throw new MalformedURLException("Invalid Github url."); } - HttpResponse response = client.send(getRequest(uri), HttpResponse.BodyHandlers.ofString()); - if (isSuccessfulResponse(response)) { - return response.body().toString(); + HttpResponse response = makeRequest(uri); + + if (response.statusCode() >= 200 && response.statusCode() < 300) { + checkContentLengthHeader(response); + + try (InputStream inputStream = response.body()) { + byte[] bytes = inputStream.readNBytes(maxContentLength + 1); + checkContentLength(bytes.length); + return new String(bytes); + } } else { log.error("Error getting Github content from URL {} : {}", url, response); throw new ResponseStatusException( @@ -80,17 +92,47 @@ public String getGithubContent(String url) throws IOException, InterruptedExcept } } + private HttpResponse makeRequest(URI uri) throws InterruptedException { + try { + return client.send(getRequest(uri), HttpResponse.BodyHandlers.ofInputStream()); + } catch (IOException ex) { + log.error("Failed to retrieve data from github: {}", ex.toString()); + throw new ResponseStatusException(HttpStatus.BAD_GATEWAY, "Github responded with an error."); + } + } + + private void checkContentLengthHeader(HttpResponse response) { + response.headers().firstValue("Content-Length") + .map((l) -> { + try { + return Integer.valueOf(l); + } catch (NumberFormatException ex) { + return null; + } + }) + .ifPresent(this::checkContentLength); + } + + private void checkContentLength(int contentLength) { + if (contentLength > maxContentLength) { + throw new ResponseStatusException( + HttpStatus.BAD_REQUEST, "Github content is too large"); + } + } + public boolean isValidGithubUri(URI uri) { - return uri.getHost().contains(GITHUB_API_URI); + return uri.getHost().equalsIgnoreCase(GITHUB_API_URI) + && uri.getScheme().equalsIgnoreCase("https") + && (uri.getPort() == -1 || uri.getPort() == 443); } private HttpRequest getRequest(URI uri) { HttpRequest.Builder request = HttpRequest.newBuilder(uri) .header("Accept", GITHUB_API_ACCEPT_HEADER) .GET() - .timeout(Duration.ofSeconds(10)); - if (githubToken != null && !githubToken.isEmpty()) { - request.header("Authorization", "Bearer " + githubToken); + .timeout(httpTimeout); + if (!authorizationHeader.isEmpty()) { + request.header("Authorization", authorizationHeader); } return request.build(); } diff --git a/src/main/java/org/radarbase/appserver/service/GithubService.java b/src/main/java/org/radarbase/appserver/service/GithubService.java new file mode 100644 index 00000000..644e9496 --- /dev/null +++ b/src/main/java/org/radarbase/appserver/service/GithubService.java @@ -0,0 +1,36 @@ +package org.radarbase.appserver.service; + +import org.radarbase.appserver.util.CachedFunction; +import org.springframework.beans.factory.annotation.Autowired; +import org.springframework.beans.factory.annotation.Value; +import org.springframework.beans.factory.config.ConfigurableBeanFactory; +import org.springframework.context.annotation.Scope; +import org.springframework.stereotype.Component; + +import java.time.Duration; + +@Component +@Scope(value = ConfigurableBeanFactory.SCOPE_SINGLETON) +public class GithubService { + + private final transient CachedFunction cachedGetContent; + + @Autowired + public GithubService( + GithubClient githubClient, + @Value("${security.github.cache.duration:3600}") + int cacheTime, + @Value("${security.github.cache.retryDuration:60}") + int retryTime, + @Value("${security.github.cache.size:10000}") + int maxSize) { + this.cachedGetContent = new CachedFunction<>(githubClient::getGithubContent, + Duration.ofSeconds(cacheTime), + Duration.ofSeconds(retryTime), + maxSize); + } + + public String getGithubContent(String url) throws Exception { + return this.cachedGetContent.applyWithException(url); + } +} diff --git a/src/main/java/org/radarbase/appserver/service/questionnaire/protocol/DefaultProtocolGenerator.java b/src/main/java/org/radarbase/appserver/service/questionnaire/protocol/DefaultProtocolGenerator.java index 4aefaae8..57584cb0 100644 --- a/src/main/java/org/radarbase/appserver/service/questionnaire/protocol/DefaultProtocolGenerator.java +++ b/src/main/java/org/radarbase/appserver/service/questionnaire/protocol/DefaultProtocolGenerator.java @@ -21,22 +21,20 @@ package org.radarbase.appserver.service.questionnaire.protocol; -import java.io.IOException; -import java.time.Duration; -import java.util.List; -import java.util.Map; -import java.util.NoSuchElementException; - import lombok.NonNull; import lombok.extern.slf4j.Slf4j; import org.radarbase.appserver.dto.protocol.Protocol; -import org.radarbase.appserver.entity.User; import org.radarbase.appserver.util.CachedMap; import org.springframework.beans.factory.annotation.Autowired; import org.springframework.beans.factory.config.ConfigurableBeanFactory; import org.springframework.context.annotation.Scope; import org.springframework.stereotype.Service; +import java.io.IOException; +import java.time.Duration; +import java.util.Map; +import java.util.NoSuchElementException; + /** * @author yatharthranjan * @see aRMT Protocols @@ -90,7 +88,7 @@ public Protocol getProtocol(String projectId) throws IOException { return cachedProjectProtocolMap.get(projectId); } catch (IOException ex) { log.warn( - "Cannot retrieve Protocols for project {} : {}, Using cached values.", projectId, ex); + "Cannot retrieve Protocols for project {} : {}, Using cached values.", projectId, ex.toString()); return cachedProjectProtocolMap.get(true).get(projectId); } } @@ -115,7 +113,7 @@ public Protocol getProtocolForSubject(String subjectId) { return protocol; } catch (IOException ex) { log.warn( - "Cannot retrieve Protocols for subject {} : {}, Using cached values.", subjectId, ex); + "Cannot retrieve Protocols for subject {} : {}, Using cached values.", subjectId, ex.toString()); return cachedProtocolMap.getCache().get(subjectId); } catch(NoSuchElementException ex) { log.warn("Subject does not exist in map. Fetching.."); diff --git a/src/main/java/org/radarbase/appserver/service/questionnaire/protocol/GithubProtocolFetcherStrategy.java b/src/main/java/org/radarbase/appserver/service/questionnaire/protocol/GithubProtocolFetcherStrategy.java index 8779728e..42030b58 100644 --- a/src/main/java/org/radarbase/appserver/service/questionnaire/protocol/GithubProtocolFetcherStrategy.java +++ b/src/main/java/org/radarbase/appserver/service/questionnaire/protocol/GithubProtocolFetcherStrategy.java @@ -25,16 +25,6 @@ import com.fasterxml.jackson.databind.DeserializationFeature; import com.fasterxml.jackson.databind.JsonNode; import com.fasterxml.jackson.databind.ObjectMapper; - -import java.io.IOException; -import java.net.URI; -import java.net.http.HttpClient; -import java.net.http.HttpRequest; -import java.net.http.HttpResponse; -import java.time.Duration; -import java.util.*; -import java.util.stream.Collectors; - import com.fasterxml.jackson.databind.node.ObjectNode; import com.google.common.collect.Maps; import lombok.SneakyThrows; @@ -43,7 +33,6 @@ import org.radarbase.appserver.dto.protocol.GithubContent; import org.radarbase.appserver.dto.protocol.Protocol; import org.radarbase.appserver.dto.protocol.ProtocolCacheEntry; -import org.radarbase.appserver.entity.Project; import org.radarbase.appserver.entity.User; import org.radarbase.appserver.repository.ProjectRepository; import org.radarbase.appserver.repository.UserRepository; @@ -53,10 +42,20 @@ import org.springframework.beans.factory.annotation.Value; import org.springframework.beans.factory.config.ConfigurableBeanFactory; import org.springframework.context.annotation.Scope; -import org.springframework.http.HttpStatus; import org.springframework.stereotype.Component; import org.springframework.web.server.ResponseStatusException; +import java.io.IOException; +import java.net.URI; +import java.time.Duration; +import java.util.Collections; +import java.util.Comparator; +import java.util.HashMap; +import java.util.List; +import java.util.Map; +import java.util.Set; +import java.util.stream.Collectors; + @Slf4j @Component @Scope(value = ConfigurableBeanFactory.SCOPE_SINGLETON) @@ -66,7 +65,6 @@ public class GithubProtocolFetcherStrategy implements ProtocolFetcherStrategy { private final transient ProjectRepository projectRepository; private static final String GITHUB_API_URI = "https://api.github.com/repos/"; - private static final String GITHUB_API_ACCEPT_HEADER = "application/vnd.github.v3+json"; private final transient String protocolRepo; private final transient String protocolFileName; private final transient String protocolBranch; @@ -74,13 +72,8 @@ public class GithubProtocolFetcherStrategy implements ProtocolFetcherStrategy { private final transient ObjectMapper localMapper; // Keeps a cache of github URI's associated with protocol for each project private final transient CachedMap projectProtocolUriMap; - private final transient HttpClient client; - private final transient GithubClient githubClient; - @Value("${security.github.client.token}") - private transient String githubToken; - @SneakyThrows @Autowired public GithubProtocolFetcherStrategy( @@ -102,104 +95,100 @@ public GithubProtocolFetcherStrategy( this.protocolFileName = protocolFileName; this.protocolBranch = protocolBranch; projectProtocolUriMap = - new CachedMap<>(this::getProtocolDirectories, Duration.ofHours(3), Duration.ofHours(4)); + new CachedMap<>(this::getProtocolDirectories, Duration.ofHours(3), Duration.ofMinutes(4)); this.objectMapper = objectMapper; this.localMapper = this.objectMapper.copy(); this.localMapper.configure(DeserializationFeature.FAIL_ON_UNKNOWN_PROPERTIES, false); - client = HttpClient.newBuilder().connectTimeout(Duration.ofSeconds(10)).build(); this.userRepository = userRepository; this.projectRepository = projectRepository; this.githubClient = githubClient; } - private static boolean isSuccessfulResponse(HttpResponse response) { - return response.statusCode() >= 200 && response.statusCode() < 300; - } - @Override - public synchronized Map fetchProtocols() throws IOException { - Map subjectProtocolMap = new HashMap<>(); + public synchronized Map fetchProtocols() { List users = this.userRepository.findAll(); - Map protocolUriMap; - try { - protocolUriMap = projectProtocolUriMap.get(); - } catch (IOException e) { - // Failed to get the Uri Map. try using the cached value - protocolUriMap = projectProtocolUriMap.getCache(); - } - - if (protocolUriMap == null) { - return subjectProtocolMap; + Set protocolPaths = getProtocolPaths(); + if (protocolPaths == null) { + return Map.of(); } - Set protocolPaths = protocolUriMap.keySet(); - subjectProtocolMap = users.parallelStream() - .map(u -> { - ProtocolCacheEntry entry = this.fetchProtocolForSingleUser(u, u.getProject().getProjectId(), protocolPaths); - return entry; - }) + Map subjectProtocolMap = users.parallelStream() + .map(u -> this.fetchProtocolForSingleUser(u, u.getProject().getProjectId(), protocolPaths)) .filter(c -> c.getProtocol() != null) - .collect(Collectors.toMap(p -> p.getId(), p -> p.getProtocol())); + .collect(Collectors.toMap(ProtocolCacheEntry::getId, ProtocolCacheEntry::getProtocol)); log.info("Refreshed Protocols from Github"); return subjectProtocolMap; } private ProtocolCacheEntry fetchProtocolForSingleUser(User u, String projectId, Set protocolPaths) { - Map attributes = u.getAttributes(); - Map pathMap = protocolPaths.stream().filter(k -> k.contains(projectId)) + Map attributes = u.getAttributes() != null ? u.getAttributes() : Map.of(); + Map pathMap = protocolPaths.stream() + .filter(k -> k.contains(projectId)) .map(p -> { Map path = this.convertPathToAttributeMap(p, projectId); return Maps.difference(attributes, path).entriesInCommon(); - }).max(Comparator.comparingInt(Map::size)).orElse(Collections.emptyMap()); + }) + .max(Comparator.comparingInt(Map::size)) + .orElse(Collections.emptyMap()); + try { String attributePath = this.convertAttributeMapToPath(pathMap, projectId); if (projectProtocolUriMap.get().containsKey(attributePath)) { URI uri = projectProtocolUriMap.get(attributePath); return new ProtocolCacheEntry(u.getSubjectId(), getProtocolFromUrl(uri)); + } else { + return new ProtocolCacheEntry(u.getSubjectId(), null); } - return new ProtocolCacheEntry(u.getSubjectId(), null); } catch (IOException | InterruptedException | ResponseStatusException e) { return new ProtocolCacheEntry(u.getSubjectId(), null); } } @Override - public synchronized Map fetchProtocolsPerProject() throws IOException { - Map projectProtocolMap = new HashMap<>(); - List projects = this.projectRepository.findAll(); + public synchronized Map fetchProtocolsPerProject() { + Set protocolPaths = getProtocolPaths(); - Map protocolUriMap; - try { - protocolUriMap = projectProtocolUriMap.get(); - } catch (IOException e) { - // Failed to get the Uri Map. try using the cached value - protocolUriMap = projectProtocolUriMap.getCache(); + if (protocolPaths == null) { + return Map.of(); } - if (protocolUriMap == null) { - return projectProtocolMap; - } - - Set protocolPaths = protocolUriMap.keySet(); - projectProtocolMap = projects.parallelStream() + Map projectProtocolMap = projectRepository.findAll() + .parallelStream() .map(project -> { String projectId = project.getProjectId(); - String path = protocolPaths.stream().filter(k -> k.contains(projectId)).findFirst().get(); - try { - URI uri = projectProtocolUriMap.get(path); - Protocol protocol = getProtocolFromUrl(uri); - return new ProtocolCacheEntry(projectId, protocol); - } catch (IOException | InterruptedException | ResponseStatusException e) { - return new ProtocolCacheEntry(projectId, null); - } - }).collect(Collectors.toMap(p -> p.getId(), p -> p.getProtocol())); + Protocol protocol = protocolPaths.stream() + .filter(k -> k.contains(projectId)) + .findFirst() + .map(path -> { + try { + URI uri = projectProtocolUriMap.get(path); + return getProtocolFromUrl(uri); + } catch (IOException | InterruptedException + | ResponseStatusException e) { + return null; + } + }).orElse(null); + return new ProtocolCacheEntry(projectId, protocol); + }) + .collect(Collectors.toMap(ProtocolCacheEntry::getId, ProtocolCacheEntry::getProtocol)); log.info("Refreshed Protocols from Github"); return projectProtocolMap; } + private Set getProtocolPaths() { + Map uriMap; + try { + uriMap = projectProtocolUriMap.get(); + } catch (IOException e) { + // Failed to get the Uri Map. try using the cached value + uriMap = projectProtocolUriMap.getCache(); + } + return uriMap != null ? uriMap.keySet() : null; + } + public Map convertPathToAttributeMap(String path, String projectId) { String[] parts = path.split("/"); String key = ""; @@ -232,36 +221,20 @@ private Map getProtocolDirectories() throws IOException { Map protocolUriMap = new HashMap<>(); try { - HttpResponse response = - client.send( - getRequest( - URI.create(GITHUB_API_URI + protocolRepo + "/branches/" + protocolBranch)), - HttpResponse.BodyHandlers.ofString()); - if (isSuccessfulResponse(response)) { - ObjectNode result = getArrayNode(response.body().toString()); - String treeSha = result.findValue("tree").findValue("sha").asText(); - URI treeUri = URI.create(GITHUB_API_URI + protocolRepo + "/git/trees/" + treeSha + "?recursive=true"); - HttpResponse treeResponse = client.send(getRequest(treeUri), HttpResponse.BodyHandlers.ofString()); - - if (isSuccessfulResponse(treeResponse)) { - JsonNode tree = getArrayNode(treeResponse.body().toString()).get("tree"); - for (JsonNode jsonNode : tree) { - String path = jsonNode.get("path").asText(); - if (path.contains(this.protocolFileName)) { - protocolUriMap.put( - path, - URI.create(jsonNode.get("url").asText())); - } - } - } - } - else { - log.warn("Failed to retrieve protocols URIs from github: {}.", response); - throw new ResponseStatusException( - HttpStatus.valueOf(response.statusCode()), - "Failed to retrieve protocols URIs from github."); + String content = githubClient.getGithubContent(GITHUB_API_URI + protocolRepo + "/branches/" + protocolBranch); + ObjectNode result = getArrayNode(content); + String treeSha = result.findValue("tree").findValue("sha").asText(); + String treeContent = githubClient.getGithubContent(GITHUB_API_URI + protocolRepo + "/git/trees/" + treeSha + "?recursive=true"); + + JsonNode tree = getArrayNode(treeContent).get("tree"); + for (JsonNode jsonNode : tree) { + String path = jsonNode.get("path").asText(); + if (path.contains(this.protocolFileName)) { + protocolUriMap.put( + path, + URI.create(jsonNode.get("url").asText())); + } } - } catch (InterruptedException | ResponseStatusException e) { throw new IOException("Failed to retrieve protocols URIs from github", e); } @@ -269,25 +242,15 @@ private Map getProtocolDirectories() throws IOException { } private Protocol getProtocolFromUrl(URI uri) throws IOException, InterruptedException { - String contentString = this.githubClient.getGithubContent(uri.toString()); + String contentString = githubClient.getGithubContent(uri.toString()); GithubContent content = localMapper.readValue(contentString, GithubContent.class); return localMapper.readValue(content.getContent(), Protocol.class); } - private HttpRequest getRequest(URI uri) { - HttpRequest.Builder request = HttpRequest.newBuilder(uri) - .header("Accept", GITHUB_API_ACCEPT_HEADER) - .header("Authorization", "Bearer " + this.githubToken) - .GET() - .timeout(Duration.ofSeconds(10)); - - return request.build(); - } - @SneakyThrows private ObjectNode getArrayNode(String json) { try (JsonParser parserProtocol = objectMapper.getFactory().createParser(json)) { return objectMapper.readTree(parserProtocol); - } + } } } diff --git a/src/main/java/org/radarbase/appserver/util/CachedFunction.java b/src/main/java/org/radarbase/appserver/util/CachedFunction.java new file mode 100644 index 00000000..943b9cf1 --- /dev/null +++ b/src/main/java/org/radarbase/appserver/util/CachedFunction.java @@ -0,0 +1,106 @@ +package org.radarbase.appserver.util; + +import org.jetbrains.annotations.NotNull; +import org.springframework.util.function.ThrowingFunction; + +import java.lang.ref.SoftReference; +import java.time.Duration; +import java.time.Instant; +import java.util.Iterator; +import java.util.LinkedHashMap; +import java.util.Map; + +public class CachedFunction implements ThrowingFunction { + private transient final Duration cacheTime; + + private transient final Duration retryTime; + + private transient final int maxSize; + + private transient final Map>> cachedMap; + private transient final ThrowingFunction function; + + public CachedFunction(ThrowingFunction function, + Duration cacheTime, + Duration retryTime, + int maxSize) { + this.cacheTime = cacheTime; + this.retryTime = retryTime; + this.maxSize = maxSize; + this.cachedMap = new LinkedHashMap<>(16, 0.75f, false); + this.function = function; + } + + @NotNull + public V applyWithException(@NotNull K input) throws Exception { + SoftReference> localRef; + synchronized (cachedMap) { + localRef = cachedMap.get(input); + } + Result result = localRef != null ? localRef.get() : null; + if (result != null && !result.isExpired()) { + return result.getOrThrow(); + } + + try { + V content = function.applyWithException(input); + putCache(input, new Result<>(cacheTime, content, null)); + return content; + } catch (Exception ex) { + synchronized (cachedMap) { + SoftReference> exRef = cachedMap.get(input); + Result exResult = exRef != null ? exRef.get() : null; + if (exResult == null || exResult.isBadResult()) { + putCache(input, new Result<>(retryTime, null, ex)); + throw ex; + } else { + return exResult.getOrThrow(); + } + } + } + } + + @SuppressWarnings("PMD.DataflowAnomalyAnalysis") + private void putCache(K input, Result result) { + synchronized (cachedMap) { + cachedMap.put(input, new SoftReference<>(result)); + int toRemove = cachedMap.size() - maxSize; + if (toRemove > 0) { + Iterator iter = cachedMap.entrySet().iterator(); + for (int i = 0; i < toRemove; i++) { + iter.next(); + iter.remove(); + } + } + } + } + + private static class Result { + private transient final Instant expiration; + private transient final T value; + + private transient final Exception exception; + + Result(Duration expiryDuration, T value, Exception exception) { + expiration = Instant.now().plus(expiryDuration); + this.value = value; + this.exception = exception; + } + + T getOrThrow() throws Exception { + if (exception != null) { + throw exception; + } else { + return value; + } + } + + boolean isBadResult() { + return exception != null || isExpired(); + } + + boolean isExpired() { + return Instant.now().isAfter(expiration); + } + } +} diff --git a/src/main/resources/application-dev.properties b/src/main/resources/application-dev.properties index 4264489c..694c1546 100644 --- a/src/main/resources/application-dev.properties +++ b/src/main/resources/application-dev.properties @@ -100,3 +100,9 @@ security.radar.managementportal.url=http://localhost:8081 #security.oauth2.client.userAuthorizationUri= # Github Authentication security.github.client.token= +security.github.client.timeout=PT10s +# max content size 1 MB +security.github.client.maxContentLength=1000000 +security.github.cache.size=10000 +security.github.cache.duration=3600 +security.github.cache.retryDuration=60 diff --git a/src/main/resources/application-prod.properties b/src/main/resources/application-prod.properties index ec57b82f..01d20eee 100644 --- a/src/main/resources/application-prod.properties +++ b/src/main/resources/application-prod.properties @@ -70,3 +70,9 @@ radar.admin.user=radar radar.admin.password=radar # Github Authentication security.github.client.token= +security.github.client.timeout=10 +# max content size 1 MB +security.github.client.maxContentLength=1000000 +security.github.cache.size=10000 +security.github.cache.duration=3600 +security.github.cache.retryDuration=60