Skip to content

Commit

Permalink
SNOW-862760: Add additionalHeader support (#1490)
Browse files Browse the repository at this point in the history
* SNOW-862760: Add additionalHeader support for Snowsight
  • Loading branch information
sfc-gh-vkatardjiev authored Aug 14, 2023
1 parent 709b4ab commit dd30019
Show file tree
Hide file tree
Showing 6 changed files with 346 additions and 0 deletions.
17 changes: 17 additions & 0 deletions src/main/java/net/snowflake/client/core/HttpUtil.java
Original file line number Diff line number Diff line change
Expand Up @@ -899,4 +899,21 @@ static int convertSystemPropertyToIntValue(String systemProperty, int defaultVal
}
return returnVal;
}

/**
* Helper function to attach additional headers to a request if present. This takes a (nullable)
* map of headers in <name,value> format and adds them to the incoming request using addHeader.
*
* <p>Snowsight uses this to attach headers with additional telemetry information, see
* https://snowflakecomputing.atlassian.net/wiki/spaces/EN/pages/2960557006/GS+Communication
*
* @param request The request to add headers to. Must not be null.
* @param additionalHeaders The headers to add. May be null.
*/
static void applyAdditionalHeadersForSnowsight(
HttpRequestBase request, Map<String, String> additionalHeaders) {
if (additionalHeaders != null && !additionalHeaders.isEmpty()) {
additionalHeaders.forEach(request::addHeader);
}
}
}
38 changes: 38 additions & 0 deletions src/main/java/net/snowflake/client/core/SFLoginInput.java
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,10 @@ public class SFLoginInput {
private HttpClientSettingsKey httpClientKey;
private String privateKeyFile;
private String privateKeyFilePwd;
private String inFlightCtx; // Opaque string sent for Snowsight account activation

// Additional headers to add for Snowsight.
Map<String, String> additionalHttpHeadersForSnowsight;

SFLoginInput() {}

Expand Down Expand Up @@ -338,6 +342,40 @@ SFLoginInput setHttpClientSettingsKey(HttpClientSettingsKey key) {
return this;
}

// Opaque string sent for Snowsight account activation
String getInFlightCtx() {
return inFlightCtx;
}

// Opaque string sent for Snowsight account activation
SFLoginInput setInFlightCtx(String inFlightCtx) {
this.inFlightCtx = inFlightCtx;
return this;
}

Map<String, String> getAdditionalHttpHeadersForSnowsight() {
return additionalHttpHeadersForSnowsight;
}

/**
* Set additional http headers to apply to the outgoing request. The additional headers cannot be
* used to replace or overwrite a header in use by the driver. These will be applied to the
* outgoing request. Primarily used by Snowsight, as described in {@link
* HttpUtil#applyAdditionalHeadersForSnowsight(org.apache.http.client.methods.HttpRequestBase,
* Map)}
*
* @param additionalHttpHeaders The new headers to add
* @return The input object, for chaining
* @see
* HttpUtil#applyAdditionalHeadersForSnowsight(org.apache.http.client.methods.HttpRequestBase,
* Map)
*/
public SFLoginInput setAdditionalHttpHeadersForSnowsight(
Map<String, String> additionalHttpHeaders) {
this.additionalHttpHeadersForSnowsight = additionalHttpHeaders;
return this;
}

static boolean getBooleanValue(Object v) {
if (v instanceof Boolean) {
return (Boolean) v;
Expand Down
14 changes: 14 additions & 0 deletions src/main/java/net/snowflake/client/core/SessionUtil.java
Original file line number Diff line number Diff line change
Expand Up @@ -418,6 +418,8 @@ private static SFLoginOutput newSession(

try {
ClientAuthnDTO authnData = new ClientAuthnDTO();
authnData.setInFlightCtx(loginInput.getInFlightCtx());

Map<String, Object> data = new HashMap<>();
data.put(ClientAuthnParameter.CLIENT_APP_ID.name(), loginInput.getAppId());

Expand Down Expand Up @@ -585,6 +587,10 @@ private static SFLoginOutput newSession(

postRequest = new HttpPost(loginURI);

// Add custom headers before adding common headers
HttpUtil.applyAdditionalHeadersForSnowsight(
postRequest, loginInput.getAdditionalHttpHeadersForSnowsight());

// attach the login info json body to the post request
StringEntity input = new StringEntity(json, StandardCharsets.UTF_8);
input.setContentType("application/json");
Expand Down Expand Up @@ -894,6 +900,10 @@ private static SFLoginOutput tokenRequest(SFLoginInput loginInput, TokenRequestT
uriBuilder.addParameter(SFSession.SF_QUERY_REQUEST_ID, UUIDUtils.getUUID().toString());

postRequest = new HttpPost(uriBuilder.build());

// Add custom headers before adding common headers
HttpUtil.applyAdditionalHeadersForSnowsight(
postRequest, loginInput.getAdditionalHttpHeadersForSnowsight());
} catch (URISyntaxException ex) {
logger.error("Exception when creating http request", ex);

Expand Down Expand Up @@ -1006,6 +1016,10 @@ static void closeSession(SFLoginInput loginInput) throws SFException, SnowflakeS

postRequest = new HttpPost(uriBuilder.build());

// Add custom headers before adding common headers
HttpUtil.applyAdditionalHeadersForSnowsight(
postRequest, loginInput.getAdditionalHttpHeadersForSnowsight());

postRequest.setHeader(
SF_HEADER_AUTHORIZATION,
SF_HEADER_SNOWFLAKE_AUTHTYPE
Expand Down
32 changes: 32 additions & 0 deletions src/main/java/net/snowflake/client/core/StmtUtil.java
Original file line number Diff line number Diff line change
Expand Up @@ -105,6 +105,8 @@ static class StmtInput {

QueryContextDTO queryContextDTO;

Map<String, String> additionalHttpHeadersForSnowsight;

StmtInput() {}

public StmtInput setSql(String sql) {
Expand Down Expand Up @@ -236,6 +238,26 @@ public StmtInput setMaxRetries(int maxRetries) {
this.maxRetries = maxRetries;
return this;
}

/**
* Set additional http headers to apply to the outgoing request. The additional headers cannot
* be used to replace or overwrite a header in use by the driver. These will be applied to the
* outgoing request. Primarily used by Snowsight, as described in {@link
* HttpUtil#applyAdditionalHeadersForSnowsight(org.apache.http.client.methods.HttpRequestBase,
* Map)}
*
* @param additionalHttpHeaders The new headers to add
* @return The input object, for chaining
* @see
* HttpUtil#applyAdditionalHeadersForSnowsight(org.apache.http.client.methods.HttpRequestBase,
* Map)
*/
@SuppressWarnings("unchecked")
public StmtInput setAdditionalHttpHeadersForSnowsight(
Map<String, String> additionalHttpHeaders) {
this.additionalHttpHeadersForSnowsight = additionalHttpHeaders;
return this;
}
}

/** Output for running a statement on server */
Expand Down Expand Up @@ -301,6 +323,10 @@ public static StmtOutput execute(StmtInput stmtInput, ExecTimeTelemetryData exec

httpRequest = new HttpPost(uriBuilder.build());

// Add custom headers before adding common headers
HttpUtil.applyAdditionalHeadersForSnowsight(
httpRequest, stmtInput.additionalHttpHeadersForSnowsight);

/*
* sequence id is only needed for old query API, when old query API
* is deprecated, we can remove sequence id.
Expand Down Expand Up @@ -590,6 +616,9 @@ protected static String getQueryResult(String getResultPath, StmtInput stmtInput
uriBuilder.addParameter(SF_QUERY_REQUEST_ID, UUIDUtils.getUUID().toString());

httpRequest = new HttpGet(uriBuilder.build());
// Add custom headers before adding common headers
HttpUtil.applyAdditionalHeadersForSnowsight(
httpRequest, stmtInput.additionalHttpHeadersForSnowsight);

httpRequest.addHeader("accept", stmtInput.mediaType);

Expand Down Expand Up @@ -691,6 +720,9 @@ public static void cancel(StmtInput stmtInput) throws SFException, SnowflakeSQLE
uriBuilder.addParameter(SF_QUERY_REQUEST_ID, UUIDUtils.getUUID().toString());

httpRequest = new HttpPost(uriBuilder.build());
// Add custom headers before adding common headers
HttpUtil.applyAdditionalHeadersForSnowsight(
httpRequest, stmtInput.additionalHttpHeadersForSnowsight);

/*
* The JSON input has two fields: sqlText and requestId
Expand Down
136 changes: 136 additions & 0 deletions src/test/java/net/snowflake/client/core/SessionUtilLatestIT.java
Original file line number Diff line number Diff line change
Expand Up @@ -6,20 +6,31 @@

import static net.snowflake.client.TestUtil.systemGetEnv;
import static org.junit.Assert.assertEquals;
import static org.mockito.ArgumentMatchers.any;
import static org.mockito.Mockito.*;

import java.io.ByteArrayOutputStream;
import java.io.IOException;
import java.io.InputStream;
import java.util.HashMap;
import java.util.Map;
import java.util.Map.Entry;
import java.util.UUID;
import net.snowflake.client.category.TestCategoryCore;
import net.snowflake.client.jdbc.BaseJDBCTest;
import net.snowflake.client.jdbc.ErrorCode;
import net.snowflake.client.jdbc.SnowflakeSQLException;
import net.snowflake.common.core.ClientAuthnDTO;
import org.apache.commons.io.IOUtils;
import org.apache.http.Header;
import org.apache.http.HttpEntity;
import org.apache.http.client.methods.HttpPost;
import org.apache.http.client.methods.HttpRequestBase;
import org.junit.Ignore;
import org.junit.Test;
import org.junit.experimental.categories.Category;
import org.mockito.MockedStatic;
import org.mockito.MockedStatic.Verification;
import org.mockito.Mockito;

@Category(TestCategoryCore.class)
Expand Down Expand Up @@ -104,4 +115,129 @@ public void testConvertSystemPropertyToIntValue() {
assertEquals(
-1, HttpUtil.convertSystemPropertyToIntValue(HttpUtil.JDBC_TTL, HttpUtil.DEFAULT_TTL));
}

/**
* SNOW-862760 Tests that when additional headers are set on a login request, they are forwarded
* to the recipient.
*/
@Test
public void testForwardedHeaders() throws Throwable {
SFLoginInput input = createLoginInput();
Map<String, String> additionalHeaders = new HashMap<>();
additionalHeaders.put("Extra-Snowflake-Header", "present");

input.setAdditionalHttpHeadersForSnowsight(additionalHeaders);

Map<SFSessionProperty, Object> connectionPropertiesMap = initConnectionPropertiesMap();
try (MockedStatic<HttpUtil> mockedHttpUtil = mockStatic(HttpUtil.class)) {
// Both mocks the call _and_ verifies that the headers are forwarded.
Verification httpCalledWithHeaders =
() ->
HttpUtil.executeGeneralRequest(
Mockito.argThat(
arg -> {
for (Entry<String, String> definedHeader : additionalHeaders.entrySet()) {
Header actualHeader = arg.getLastHeader(definedHeader.getKey());
if (actualHeader == null) {
return false;
}

if (!definedHeader.getValue().equals(actualHeader.getValue())) {
return false;
}
}

return true;
}),
Mockito.anyInt(),
Mockito.anyInt(),
Mockito.anyInt(),
Mockito.anyInt(),
Mockito.nullable(HttpClientSettingsKey.class));
mockedHttpUtil
.when(httpCalledWithHeaders)
.thenReturn("{\"data\":null,\"code\":null,\"message\":null,\"success\":true}");

mockedHttpUtil
.when(() -> HttpUtil.applyAdditionalHeadersForSnowsight(any(), any()))
.thenCallRealMethod();

SessionUtil.openSession(input, connectionPropertiesMap, "ALL");

// After login, the only invocation to http should have been with the new
// headers.
// No calls should have happened without additional headers.
mockedHttpUtil.verify(times(1), httpCalledWithHeaders);
}
}

/**
* SNOW-862760 Verifies that, if inFlightCtx is provided to the login input, it's forwarded as
* part of the message body.
*/
@Test
public void testForwardInflightCtx() throws Throwable {
SFLoginInput input = createLoginInput();
String inflightCtx = UUID.randomUUID().toString();
input.setInFlightCtx(inflightCtx);

Map<SFSessionProperty, Object> connectionPropertiesMap = initConnectionPropertiesMap();

try (MockedStatic<HttpUtil> mockedHttpUtil = mockStatic(HttpUtil.class)) {
// Both mocks the call _and_ verifies that the headers are forwarded.
Verification httpCalledWithHeaders =
() ->
HttpUtil.executeGeneralRequest(
Mockito.argThat(
arg -> {
try {
// This gets tricky because the entity is a string.
// To not fail on JSON parsing changes, we'll verify that the key
// inFlightCtx is present and the random UUID body
HttpEntity entity = ((HttpPost) arg).getEntity();
InputStream is = entity.getContent();
ByteArrayOutputStream out = new ByteArrayOutputStream();
IOUtils.copy(is, out);
String body = new String(out.toByteArray());
return body.contains("inFlightCtx") && body.contains(inflightCtx);
} catch (UnsupportedOperationException | IOException e) {
}
return false;
}),
Mockito.anyInt(),
Mockito.anyInt(),
Mockito.anyInt(),
Mockito.anyInt(),
Mockito.nullable(HttpClientSettingsKey.class));
mockedHttpUtil
.when(httpCalledWithHeaders)
.thenReturn("{\"data\":null,\"code\":null,\"message\":null,\"success\":true}");

mockedHttpUtil
.when(() -> HttpUtil.applyAdditionalHeadersForSnowsight(any(), any()))
.thenCallRealMethod();

SessionUtil.openSession(input, connectionPropertiesMap, "ALL");

// After login, the only invocation to http should have been with the new
// headers.
// No calls should have happened without additional headers.
mockedHttpUtil.verify(times(1), httpCalledWithHeaders);
}
}

private SFLoginInput createLoginInput() {
SFLoginInput input = new SFLoginInput();
input.setServerUrl("MOCK_TEST_HOST");
input.setUserName("MOCK_USERNAME");
input.setPassword("MOCK_PASSWORD");
input.setAccountName("MOCK_ACCOUNT_NAME");
input.setAppId("MOCK_APP_ID");
input.setOCSPMode(OCSPMode.FAIL_OPEN);
input.setHttpClientSettingsKey(new HttpClientSettingsKey(OCSPMode.FAIL_OPEN));
input.setLoginTimeout(1000);
input.setSessionParameters(new HashMap<>());

return input;
}
}
Loading

0 comments on commit dd30019

Please sign in to comment.