diff --git a/src/main/java/dev/openfga/sdk/api/auth/AccessToken.java b/src/main/java/dev/openfga/sdk/api/auth/AccessToken.java index 0e99738..cddab41 100644 --- a/src/main/java/dev/openfga/sdk/api/auth/AccessToken.java +++ b/src/main/java/dev/openfga/sdk/api/auth/AccessToken.java @@ -15,6 +15,7 @@ import static dev.openfga.sdk.util.StringUtil.isNullOrWhitespace; import java.time.Instant; +import java.time.temporal.ChronoUnit; import java.util.Random; class AccessToken { @@ -22,17 +23,29 @@ class AccessToken { private static final int TOKEN_EXPIRY_JITTER_IN_SEC = 300; // We add some jitter so that token refreshes are less likely to collide - private final Random random = new Random(); private Instant expiresAt; + private final Random random = new Random(); private String token; public boolean isValid() { - return !isNullOrWhitespace(token) - && (expiresAt == null - || expiresAt.isAfter(Instant.now() - .minusSeconds(TOKEN_EXPIRY_BUFFER_THRESHOLD_IN_SEC) - .minusSeconds(random.nextLong() % TOKEN_EXPIRY_JITTER_IN_SEC))); + if (isNullOrWhitespace(token)) { + return false; + } + + // Is expiry is null then the token will not expire so should be considered always valid + if (expiresAt == null) { + return true; + } + + // A token should be considered valid until 5 minutes before the expiry with some jitter + // to account for multiple calls to `isValid` at the same time and prevent multiple refresh calls + Instant expiresWithLeeway = expiresAt + .minusSeconds(TOKEN_EXPIRY_BUFFER_THRESHOLD_IN_SEC) + .minusSeconds(random.nextInt(TOKEN_EXPIRY_JITTER_IN_SEC)) + .truncatedTo(ChronoUnit.SECONDS); + + return Instant.now().truncatedTo(ChronoUnit.SECONDS).isBefore(expiresWithLeeway); } public String getToken() { @@ -40,7 +53,10 @@ public String getToken() { } public void setExpiresAt(Instant expiresAt) { - this.expiresAt = expiresAt; + if (expiresAt != null) { + // Truncate to seconds to zero out the milliseconds to keep comparison simpler + this.expiresAt = expiresAt.truncatedTo(ChronoUnit.SECONDS); + } } public void setToken(String token) { diff --git a/src/test/java/dev/openfga/sdk/api/auth/AccessTokenTest.java b/src/test/java/dev/openfga/sdk/api/auth/AccessTokenTest.java index 3d34bde..3a435f0 100644 --- a/src/test/java/dev/openfga/sdk/api/auth/AccessTokenTest.java +++ b/src/test/java/dev/openfga/sdk/api/auth/AccessTokenTest.java @@ -25,16 +25,32 @@ class AccessTokenTest { private static Stream expTimeAndResults() { return Stream.of( - Arguments.of(Instant.now().plus(1, ChronoUnit.HOURS), true), - Arguments.of(Instant.now().minus(1, ChronoUnit.HOURS), false), - Arguments.of(Instant.now().minus(10, ChronoUnit.MINUTES), false), - Arguments.of(Instant.now().plus(10, ChronoUnit.MINUTES), true), - Arguments.of(Instant.now(), true)); + Arguments.of("Expires in 1 hour should be valid", Instant.now().plus(1, ChronoUnit.HOURS), true), + Arguments.of( + "Expires in 15 minutes should be valid", Instant.now().plus(15, ChronoUnit.MINUTES), true), + Arguments.of("No expiry value should be valid", null, true), + Arguments.of( + "Expired 1 hour ago should not be valid", Instant.now().minus(1, ChronoUnit.HOURS), false), + Arguments.of( + "Expired 10 minutes ago should not be valid", + Instant.now().minus(10, ChronoUnit.MINUTES), + false), + Arguments.of( + "Expired 5 minutes ago should not be valid", + Instant.now().minus(5, ChronoUnit.MINUTES), + false), + Arguments.of( + "Expires in 5 minutes should not be valid", + Instant.now().plus(5, ChronoUnit.MINUTES), + false), + Arguments.of( + "Expires in 1 minute should not be valid", Instant.now().plus(1, ChronoUnit.MINUTES), false), + Arguments.of("Expires now should not be valid", Instant.now(), false)); } @MethodSource("expTimeAndResults") - @ParameterizedTest - public void testTokenValid(Instant exp, boolean valid) { + @ParameterizedTest(name = "{0}") + public void testTokenValid(String name, Instant exp, boolean valid) { AccessToken accessToken = new AccessToken(); accessToken.setToken("token"); accessToken.setExpiresAt(exp); diff --git a/src/test/java/dev/openfga/sdk/api/client/OpenFgaClientTest.java b/src/test/java/dev/openfga/sdk/api/client/OpenFgaClientTest.java index 2cfecf9..ca966df 100644 --- a/src/test/java/dev/openfga/sdk/api/client/OpenFgaClientTest.java +++ b/src/test/java/dev/openfga/sdk/api/client/OpenFgaClientTest.java @@ -15,6 +15,7 @@ import static org.hamcrest.Matchers.*; import static org.hamcrest.core.StringContains.containsString; import static org.junit.jupiter.api.Assertions.*; +import static org.mockito.ArgumentMatchers.any; import static org.mockito.Mockito.*; import com.fasterxml.jackson.databind.ObjectMapper; @@ -148,7 +149,7 @@ public void createStore_withClientCredentials() throws Exception { containsString(String.format("client_secret=%s", clientSecret)), containsString(String.format("audience=%s", apiAudience)), containsString(String.format("grant_type=%s", "client_credentials")))) - .doReturn(200, String.format("{\"access_token\":\"%s\"}", apiToken)); + .doReturn(200, String.format("{\"access_token\":\"%s\",\"expires_in\":\"%s\"}", apiToken, 3600)); mockHttpClient .onPost("https://localhost/stores") .withBody(is(expectedBody)) @@ -180,6 +181,62 @@ public void createStore_withClientCredentials() throws Exception { assertEquals(DEFAULT_STORE_NAME, response2.getName()); } + @Test + public void createStore_withClientCredentialsWithRefresh() throws Exception { + // Given + String apiTokenIssuer = "oauth2.server"; + String clientId = "some-client-id"; + String clientSecret = "some-client-secret"; + String apiToken = "some-generated-token"; + String apiAudience = "some-audience"; + clientConfiguration.credentials(new Credentials(new ClientCredentials() + .clientId(clientId) + .clientSecret(clientSecret) + .apiTokenIssuer(apiTokenIssuer) + .apiAudience(apiAudience))); + fga.setConfiguration(clientConfiguration); + + String expectedBody = String.format("{\"name\":\"%s\"}", DEFAULT_STORE_NAME); + String requestBody = String.format("{\"id\":\"%s\",\"name\":\"%s\"}", DEFAULT_STORE_ID, DEFAULT_STORE_NAME); + mockHttpClient + .onPost(String.format("https://%s/oauth/token", apiTokenIssuer)) + .withBody(allOf( + containsString(String.format("client_id=%s", clientId)), + containsString(String.format("client_secret=%s", clientSecret)), + containsString(String.format("audience=%s", apiAudience)), + containsString(String.format("grant_type=%s", "client_credentials")))) + .doReturn(200, String.format("{\"access_token\":\"%s\",\"expires_in\":\"%s\"}", apiToken, 1)); + mockHttpClient + .onPost("https://localhost/stores") + .withBody(is(expectedBody)) + .withHeader("Authorization", String.format("Bearer %s", apiToken)) + .doReturn(201, requestBody); + CreateStoreRequest request = new CreateStoreRequest().name(DEFAULT_STORE_NAME); + + // When + // We call two times to ensure the token is cached after the first request. + CreateStoreResponse response1 = fga.createStore(request).get(); + CreateStoreResponse response2 = fga.createStore(request).get(); + + // Then + // OAuth2 server should be called 1 time. + mockHttpClient + .verify() + .post(String.format("https://%s/oauth/token", apiTokenIssuer)) + .called(2); + // OpenFGA server should be called 2 times. + mockHttpClient + .verify() + .post("https://localhost/stores") + .withBody(is(expectedBody)) + .withHeader("Authorization", String.format("Bearer %s", apiToken)) + .called(2); + assertEquals(DEFAULT_STORE_ID, response1.getId()); + assertEquals(DEFAULT_STORE_NAME, response1.getName()); + assertEquals(DEFAULT_STORE_ID, response2.getId()); + assertEquals(DEFAULT_STORE_NAME, response2.getName()); + } + /** * List all stores. */