From dd300190911efd3eda910d15528abcc37e0d2601 Mon Sep 17 00:00:00 2001 From: Vladimir Katardjiev Date: Mon, 14 Aug 2023 12:04:18 -0700 Subject: [PATCH] SNOW-862760: Add additionalHeader support (#1490) * SNOW-862760: Add additionalHeader support for Snowsight --- .../net/snowflake/client/core/HttpUtil.java | 17 +++ .../snowflake/client/core/SFLoginInput.java | 38 +++++ .../snowflake/client/core/SessionUtil.java | 14 ++ .../net/snowflake/client/core/StmtUtil.java | 32 +++++ .../client/core/SessionUtilLatestIT.java | 136 ++++++++++++++++++ .../snowflake/client/core/StmtUtilTest.java | 109 ++++++++++++++ 6 files changed, 346 insertions(+) create mode 100644 src/test/java/net/snowflake/client/core/StmtUtilTest.java diff --git a/src/main/java/net/snowflake/client/core/HttpUtil.java b/src/main/java/net/snowflake/client/core/HttpUtil.java index ac11b08f7..628d0e28d 100644 --- a/src/main/java/net/snowflake/client/core/HttpUtil.java +++ b/src/main/java/net/snowflake/client/core/HttpUtil.java @@ -899,4 +899,21 @@ static int convertSystemPropertyToIntValue(String systemProperty, int defaultVal } return returnVal; } + + /** + * Helper function to attach additional headers to a request if present. This takes a (nullable) + * map of headers in format and adds them to the incoming request using addHeader. + * + *

Snowsight uses this to attach headers with additional telemetry information, see + * https://snowflakecomputing.atlassian.net/wiki/spaces/EN/pages/2960557006/GS+Communication + * + * @param request The request to add headers to. Must not be null. + * @param additionalHeaders The headers to add. May be null. + */ + static void applyAdditionalHeadersForSnowsight( + HttpRequestBase request, Map additionalHeaders) { + if (additionalHeaders != null && !additionalHeaders.isEmpty()) { + additionalHeaders.forEach(request::addHeader); + } + } } diff --git a/src/main/java/net/snowflake/client/core/SFLoginInput.java b/src/main/java/net/snowflake/client/core/SFLoginInput.java index f448d086a..35df50a7c 100644 --- a/src/main/java/net/snowflake/client/core/SFLoginInput.java +++ b/src/main/java/net/snowflake/client/core/SFLoginInput.java @@ -47,6 +47,10 @@ public class SFLoginInput { private HttpClientSettingsKey httpClientKey; private String privateKeyFile; private String privateKeyFilePwd; + private String inFlightCtx; // Opaque string sent for Snowsight account activation + + // Additional headers to add for Snowsight. + Map additionalHttpHeadersForSnowsight; SFLoginInput() {} @@ -338,6 +342,40 @@ SFLoginInput setHttpClientSettingsKey(HttpClientSettingsKey key) { return this; } + // Opaque string sent for Snowsight account activation + String getInFlightCtx() { + return inFlightCtx; + } + + // Opaque string sent for Snowsight account activation + SFLoginInput setInFlightCtx(String inFlightCtx) { + this.inFlightCtx = inFlightCtx; + return this; + } + + Map getAdditionalHttpHeadersForSnowsight() { + return additionalHttpHeadersForSnowsight; + } + + /** + * Set additional http headers to apply to the outgoing request. The additional headers cannot be + * used to replace or overwrite a header in use by the driver. These will be applied to the + * outgoing request. Primarily used by Snowsight, as described in {@link + * HttpUtil#applyAdditionalHeadersForSnowsight(org.apache.http.client.methods.HttpRequestBase, + * Map)} + * + * @param additionalHttpHeaders The new headers to add + * @return The input object, for chaining + * @see + * HttpUtil#applyAdditionalHeadersForSnowsight(org.apache.http.client.methods.HttpRequestBase, + * Map) + */ + public SFLoginInput setAdditionalHttpHeadersForSnowsight( + Map additionalHttpHeaders) { + this.additionalHttpHeadersForSnowsight = additionalHttpHeaders; + return this; + } + static boolean getBooleanValue(Object v) { if (v instanceof Boolean) { return (Boolean) v; diff --git a/src/main/java/net/snowflake/client/core/SessionUtil.java b/src/main/java/net/snowflake/client/core/SessionUtil.java index b37bb1659..c4a3ead7a 100644 --- a/src/main/java/net/snowflake/client/core/SessionUtil.java +++ b/src/main/java/net/snowflake/client/core/SessionUtil.java @@ -418,6 +418,8 @@ private static SFLoginOutput newSession( try { ClientAuthnDTO authnData = new ClientAuthnDTO(); + authnData.setInFlightCtx(loginInput.getInFlightCtx()); + Map data = new HashMap<>(); data.put(ClientAuthnParameter.CLIENT_APP_ID.name(), loginInput.getAppId()); @@ -585,6 +587,10 @@ private static SFLoginOutput newSession( postRequest = new HttpPost(loginURI); + // Add custom headers before adding common headers + HttpUtil.applyAdditionalHeadersForSnowsight( + postRequest, loginInput.getAdditionalHttpHeadersForSnowsight()); + // attach the login info json body to the post request StringEntity input = new StringEntity(json, StandardCharsets.UTF_8); input.setContentType("application/json"); @@ -894,6 +900,10 @@ private static SFLoginOutput tokenRequest(SFLoginInput loginInput, TokenRequestT uriBuilder.addParameter(SFSession.SF_QUERY_REQUEST_ID, UUIDUtils.getUUID().toString()); postRequest = new HttpPost(uriBuilder.build()); + + // Add custom headers before adding common headers + HttpUtil.applyAdditionalHeadersForSnowsight( + postRequest, loginInput.getAdditionalHttpHeadersForSnowsight()); } catch (URISyntaxException ex) { logger.error("Exception when creating http request", ex); @@ -1006,6 +1016,10 @@ static void closeSession(SFLoginInput loginInput) throws SFException, SnowflakeS postRequest = new HttpPost(uriBuilder.build()); + // Add custom headers before adding common headers + HttpUtil.applyAdditionalHeadersForSnowsight( + postRequest, loginInput.getAdditionalHttpHeadersForSnowsight()); + postRequest.setHeader( SF_HEADER_AUTHORIZATION, SF_HEADER_SNOWFLAKE_AUTHTYPE diff --git a/src/main/java/net/snowflake/client/core/StmtUtil.java b/src/main/java/net/snowflake/client/core/StmtUtil.java index dc4961180..a02fb4d7b 100644 --- a/src/main/java/net/snowflake/client/core/StmtUtil.java +++ b/src/main/java/net/snowflake/client/core/StmtUtil.java @@ -105,6 +105,8 @@ static class StmtInput { QueryContextDTO queryContextDTO; + Map additionalHttpHeadersForSnowsight; + StmtInput() {} public StmtInput setSql(String sql) { @@ -236,6 +238,26 @@ public StmtInput setMaxRetries(int maxRetries) { this.maxRetries = maxRetries; return this; } + + /** + * Set additional http headers to apply to the outgoing request. The additional headers cannot + * be used to replace or overwrite a header in use by the driver. These will be applied to the + * outgoing request. Primarily used by Snowsight, as described in {@link + * HttpUtil#applyAdditionalHeadersForSnowsight(org.apache.http.client.methods.HttpRequestBase, + * Map)} + * + * @param additionalHttpHeaders The new headers to add + * @return The input object, for chaining + * @see + * HttpUtil#applyAdditionalHeadersForSnowsight(org.apache.http.client.methods.HttpRequestBase, + * Map) + */ + @SuppressWarnings("unchecked") + public StmtInput setAdditionalHttpHeadersForSnowsight( + Map additionalHttpHeaders) { + this.additionalHttpHeadersForSnowsight = additionalHttpHeaders; + return this; + } } /** Output for running a statement on server */ @@ -301,6 +323,10 @@ public static StmtOutput execute(StmtInput stmtInput, ExecTimeTelemetryData exec httpRequest = new HttpPost(uriBuilder.build()); + // Add custom headers before adding common headers + HttpUtil.applyAdditionalHeadersForSnowsight( + httpRequest, stmtInput.additionalHttpHeadersForSnowsight); + /* * sequence id is only needed for old query API, when old query API * is deprecated, we can remove sequence id. @@ -590,6 +616,9 @@ protected static String getQueryResult(String getResultPath, StmtInput stmtInput uriBuilder.addParameter(SF_QUERY_REQUEST_ID, UUIDUtils.getUUID().toString()); httpRequest = new HttpGet(uriBuilder.build()); + // Add custom headers before adding common headers + HttpUtil.applyAdditionalHeadersForSnowsight( + httpRequest, stmtInput.additionalHttpHeadersForSnowsight); httpRequest.addHeader("accept", stmtInput.mediaType); @@ -691,6 +720,9 @@ public static void cancel(StmtInput stmtInput) throws SFException, SnowflakeSQLE uriBuilder.addParameter(SF_QUERY_REQUEST_ID, UUIDUtils.getUUID().toString()); httpRequest = new HttpPost(uriBuilder.build()); + // Add custom headers before adding common headers + HttpUtil.applyAdditionalHeadersForSnowsight( + httpRequest, stmtInput.additionalHttpHeadersForSnowsight); /* * The JSON input has two fields: sqlText and requestId diff --git a/src/test/java/net/snowflake/client/core/SessionUtilLatestIT.java b/src/test/java/net/snowflake/client/core/SessionUtilLatestIT.java index 6849eb6c5..17e458fb3 100644 --- a/src/test/java/net/snowflake/client/core/SessionUtilLatestIT.java +++ b/src/test/java/net/snowflake/client/core/SessionUtilLatestIT.java @@ -6,20 +6,31 @@ import static net.snowflake.client.TestUtil.systemGetEnv; import static org.junit.Assert.assertEquals; +import static org.mockito.ArgumentMatchers.any; import static org.mockito.Mockito.*; +import java.io.ByteArrayOutputStream; +import java.io.IOException; +import java.io.InputStream; import java.util.HashMap; import java.util.Map; +import java.util.Map.Entry; +import java.util.UUID; import net.snowflake.client.category.TestCategoryCore; import net.snowflake.client.jdbc.BaseJDBCTest; import net.snowflake.client.jdbc.ErrorCode; import net.snowflake.client.jdbc.SnowflakeSQLException; import net.snowflake.common.core.ClientAuthnDTO; +import org.apache.commons.io.IOUtils; +import org.apache.http.Header; +import org.apache.http.HttpEntity; +import org.apache.http.client.methods.HttpPost; import org.apache.http.client.methods.HttpRequestBase; import org.junit.Ignore; import org.junit.Test; import org.junit.experimental.categories.Category; import org.mockito.MockedStatic; +import org.mockito.MockedStatic.Verification; import org.mockito.Mockito; @Category(TestCategoryCore.class) @@ -104,4 +115,129 @@ public void testConvertSystemPropertyToIntValue() { assertEquals( -1, HttpUtil.convertSystemPropertyToIntValue(HttpUtil.JDBC_TTL, HttpUtil.DEFAULT_TTL)); } + + /** + * SNOW-862760 Tests that when additional headers are set on a login request, they are forwarded + * to the recipient. + */ + @Test + public void testForwardedHeaders() throws Throwable { + SFLoginInput input = createLoginInput(); + Map additionalHeaders = new HashMap<>(); + additionalHeaders.put("Extra-Snowflake-Header", "present"); + + input.setAdditionalHttpHeadersForSnowsight(additionalHeaders); + + Map connectionPropertiesMap = initConnectionPropertiesMap(); + try (MockedStatic mockedHttpUtil = mockStatic(HttpUtil.class)) { + // Both mocks the call _and_ verifies that the headers are forwarded. + Verification httpCalledWithHeaders = + () -> + HttpUtil.executeGeneralRequest( + Mockito.argThat( + arg -> { + for (Entry definedHeader : additionalHeaders.entrySet()) { + Header actualHeader = arg.getLastHeader(definedHeader.getKey()); + if (actualHeader == null) { + return false; + } + + if (!definedHeader.getValue().equals(actualHeader.getValue())) { + return false; + } + } + + return true; + }), + Mockito.anyInt(), + Mockito.anyInt(), + Mockito.anyInt(), + Mockito.anyInt(), + Mockito.nullable(HttpClientSettingsKey.class)); + mockedHttpUtil + .when(httpCalledWithHeaders) + .thenReturn("{\"data\":null,\"code\":null,\"message\":null,\"success\":true}"); + + mockedHttpUtil + .when(() -> HttpUtil.applyAdditionalHeadersForSnowsight(any(), any())) + .thenCallRealMethod(); + + SessionUtil.openSession(input, connectionPropertiesMap, "ALL"); + + // After login, the only invocation to http should have been with the new + // headers. + // No calls should have happened without additional headers. + mockedHttpUtil.verify(times(1), httpCalledWithHeaders); + } + } + + /** + * SNOW-862760 Verifies that, if inFlightCtx is provided to the login input, it's forwarded as + * part of the message body. + */ + @Test + public void testForwardInflightCtx() throws Throwable { + SFLoginInput input = createLoginInput(); + String inflightCtx = UUID.randomUUID().toString(); + input.setInFlightCtx(inflightCtx); + + Map connectionPropertiesMap = initConnectionPropertiesMap(); + + try (MockedStatic mockedHttpUtil = mockStatic(HttpUtil.class)) { + // Both mocks the call _and_ verifies that the headers are forwarded. + Verification httpCalledWithHeaders = + () -> + HttpUtil.executeGeneralRequest( + Mockito.argThat( + arg -> { + try { + // This gets tricky because the entity is a string. + // To not fail on JSON parsing changes, we'll verify that the key + // inFlightCtx is present and the random UUID body + HttpEntity entity = ((HttpPost) arg).getEntity(); + InputStream is = entity.getContent(); + ByteArrayOutputStream out = new ByteArrayOutputStream(); + IOUtils.copy(is, out); + String body = new String(out.toByteArray()); + return body.contains("inFlightCtx") && body.contains(inflightCtx); + } catch (UnsupportedOperationException | IOException e) { + } + return false; + }), + Mockito.anyInt(), + Mockito.anyInt(), + Mockito.anyInt(), + Mockito.anyInt(), + Mockito.nullable(HttpClientSettingsKey.class)); + mockedHttpUtil + .when(httpCalledWithHeaders) + .thenReturn("{\"data\":null,\"code\":null,\"message\":null,\"success\":true}"); + + mockedHttpUtil + .when(() -> HttpUtil.applyAdditionalHeadersForSnowsight(any(), any())) + .thenCallRealMethod(); + + SessionUtil.openSession(input, connectionPropertiesMap, "ALL"); + + // After login, the only invocation to http should have been with the new + // headers. + // No calls should have happened without additional headers. + mockedHttpUtil.verify(times(1), httpCalledWithHeaders); + } + } + + private SFLoginInput createLoginInput() { + SFLoginInput input = new SFLoginInput(); + input.setServerUrl("MOCK_TEST_HOST"); + input.setUserName("MOCK_USERNAME"); + input.setPassword("MOCK_PASSWORD"); + input.setAccountName("MOCK_ACCOUNT_NAME"); + input.setAppId("MOCK_APP_ID"); + input.setOCSPMode(OCSPMode.FAIL_OPEN); + input.setHttpClientSettingsKey(new HttpClientSettingsKey(OCSPMode.FAIL_OPEN)); + input.setLoginTimeout(1000); + input.setSessionParameters(new HashMap<>()); + + return input; + } } diff --git a/src/test/java/net/snowflake/client/core/StmtUtilTest.java b/src/test/java/net/snowflake/client/core/StmtUtilTest.java new file mode 100644 index 000000000..e59920206 --- /dev/null +++ b/src/test/java/net/snowflake/client/core/StmtUtilTest.java @@ -0,0 +1,109 @@ +/* + * Copyright (c) 2012-2022 Snowflake Computing Inc. All rights reserved. + */ + +package net.snowflake.client.core; + +import static org.mockito.Mockito.*; + +import java.util.HashMap; +import java.util.Map; +import java.util.Map.Entry; +import java.util.UUID; +import java.util.concurrent.atomic.AtomicBoolean; +import net.snowflake.client.category.TestCategoryCore; +import net.snowflake.client.core.StmtUtil.StmtInput; +import net.snowflake.client.jdbc.BaseJDBCTest; +import org.apache.http.Header; +import org.junit.Test; +import org.junit.experimental.categories.Category; +import org.mockito.MockedStatic; +import org.mockito.MockedStatic.Verification; +import org.mockito.Mockito; + +@Category(TestCategoryCore.class) +public class StmtUtilTest extends BaseJDBCTest { + + /** SNOW-862760 Verify that additional headers are added to request */ + @Test + public void testForwardedHeaders() throws Throwable { + SFLoginInput input = createLoginInput(); + Map additionalHeaders = new HashMap<>(); + additionalHeaders.put("Extra-Snowflake-Header", "present"); + input.setAdditionalHttpHeadersForSnowsight(additionalHeaders); + + try (MockedStatic mockedHttpUtil = mockStatic(HttpUtil.class)) { + // Both mocks the call _and_ verifies that the headers are forwarded. + Verification httpCalledWithHeaders = + () -> + HttpUtil.executeRequest( + Mockito.argThat( + arg -> { + for (Entry definedHeader : additionalHeaders.entrySet()) { + Header actualHeader = arg.getLastHeader(definedHeader.getKey()); + if (actualHeader == null) { + return false; + } + + if (!definedHeader.getValue().equals(actualHeader.getValue())) { + return false; + } + } + + return true; + }), + Mockito.anyInt(), + Mockito.anyInt(), + Mockito.anyInt(), + Mockito.anyInt(), + Mockito.anyInt(), + Mockito.nullable(AtomicBoolean.class), + Mockito.anyBoolean(), + Mockito.anyBoolean(), + Mockito.nullable(HttpClientSettingsKey.class), + Mockito.nullable(ExecTimeTelemetryData.class)); + mockedHttpUtil + .when(httpCalledWithHeaders) + .thenReturn("{\"data\":null,\"code\":333334,\"message\":null,\"success\":true}"); + + mockedHttpUtil + .when(() -> HttpUtil.applyAdditionalHeadersForSnowsight(any(), any())) + .thenCallRealMethod(); + + StmtInput stmtInput = new StmtInput(); + stmtInput.setAdditionalHttpHeadersForSnowsight(additionalHeaders); + // Async mode skips result post-processing so we don't need to mock an advanced + // response + stmtInput.setAsync(true); + stmtInput.setHttpClientSettingsKey(new HttpClientSettingsKey(OCSPMode.FAIL_OPEN)); + stmtInput.setRequestId(UUID.randomUUID().toString()); + stmtInput.setServiceName("MOCK_SERVICE_NAME"); + stmtInput.setServerUrl("MOCK_SERVER_URL"); + stmtInput.setSessionToken("MOCK_SESSION_TOKEN"); + stmtInput.setSequenceId(1); + stmtInput.setSql("SELECT * FROM MOCK_TABLE"); + + StmtUtil.execute(stmtInput, new ExecTimeTelemetryData()); + + // After login, the only invocation to http should have been with the new + // headers. + // No calls should have happened without additional headers. + mockedHttpUtil.verify(times(1), httpCalledWithHeaders); + } + } + + private SFLoginInput createLoginInput() { + SFLoginInput input = new SFLoginInput(); + input.setServerUrl("MOCK_TEST_HOST"); + input.setUserName("MOCK_USERNAME"); + input.setPassword("MOCK_PASSWORD"); + input.setAccountName("MOCK_ACCOUNT_NAME"); + input.setAppId("MOCK_APP_ID"); + input.setOCSPMode(OCSPMode.FAIL_OPEN); + input.setHttpClientSettingsKey(new HttpClientSettingsKey(OCSPMode.FAIL_OPEN)); + input.setLoginTimeout(1000); + input.setSessionParameters(new HashMap<>()); + + return input; + } +}