diff --git a/extensions/s3/src/main/java/io/deephaven/extensions/s3/S3AsyncClientFactory.java b/extensions/s3/src/main/java/io/deephaven/extensions/s3/S3AsyncClientFactory.java index 2b2071fde08..69150aafa00 100644 --- a/extensions/s3/src/main/java/io/deephaven/extensions/s3/S3AsyncClientFactory.java +++ b/extensions/s3/src/main/java/io/deephaven/extensions/s3/S3AsyncClientFactory.java @@ -55,8 +55,8 @@ static S3AsyncClient getAsyncClient(@NotNull final S3Instructions instructions) // .addMetricPublisher(LoggingMetricPublisher.create(Level.INFO, Format.PRETTY)) .scheduledExecutorService(ensureScheduledExecutor()) .build()) - .region(Region.of(instructions.regionName())) .credentialsProvider(instructions.awsV2CredentialsProvider()); + instructions.regionName().map(Region::of).ifPresent(builder::region); instructions.endpointOverride().ifPresent(builder::endpointOverride); final S3AsyncClient ret = builder.build(); if (log.isDebugEnabled()) { diff --git a/extensions/s3/src/main/java/io/deephaven/extensions/s3/S3Instructions.java b/extensions/s3/src/main/java/io/deephaven/extensions/s3/S3Instructions.java index a74875032da..7560e9377df 100644 --- a/extensions/s3/src/main/java/io/deephaven/extensions/s3/S3Instructions.java +++ b/extensions/s3/src/main/java/io/deephaven/extensions/s3/S3Instructions.java @@ -32,14 +32,18 @@ public abstract class S3Instructions implements LogOutputAppendable { private final static Duration DEFAULT_CONNECTION_TIMEOUT = Duration.ofSeconds(2); private final static Duration DEFAULT_READ_TIMEOUT = Duration.ofSeconds(2); + static final S3Instructions DEFAULT = builder().build(); + public static Builder builder() { return ImmutableS3Instructions.builder(); } /** - * The region name to use when reading or writing to S3. + * The region name to use when reading or writing to S3. If not provided, the region name is picked by the AWS SDK + * from 'aws.region' system property, "AWS_REGION" environment variable, the {user.home}/.aws/credentials or + * {user.home}/.aws/config files, or from EC2 metadata service, if running in EC2. */ - public abstract String regionName(); + public abstract Optional regionName(); /** * The maximum number of concurrent requests to make to S3, defaults to {@value #DEFAULT_MAX_CONCURRENT_REQUESTS}. diff --git a/extensions/s3/src/main/java/io/deephaven/extensions/s3/S3SeekableChannelProviderPlugin.java b/extensions/s3/src/main/java/io/deephaven/extensions/s3/S3SeekableChannelProviderPlugin.java index 5728d79fe74..619e31b3824 100644 --- a/extensions/s3/src/main/java/io/deephaven/extensions/s3/S3SeekableChannelProviderPlugin.java +++ b/extensions/s3/src/main/java/io/deephaven/extensions/s3/S3SeekableChannelProviderPlugin.java @@ -29,9 +29,10 @@ public SeekableChannelsProvider createProvider(@NotNull final URI uri, @Nullable if (!isCompatible(uri, config)) { throw new IllegalArgumentException("Arguments not compatible, provided uri " + uri); } - if (!(config instanceof S3Instructions)) { - throw new IllegalArgumentException("Must provide S3Instructions to read files from S3"); + if (config != null && !(config instanceof S3Instructions)) { + throw new IllegalArgumentException("Only S3Instructions are valid when reading files from S3, provided " + + "config instance of class " + config.getClass().getName()); } - return new S3SeekableChannelProvider((S3Instructions) config); + return new S3SeekableChannelProvider(config == null ? S3Instructions.DEFAULT : (S3Instructions) config); } } diff --git a/extensions/s3/src/test/java/io/deephaven/extensions/s3/S3InstructionsTest.java b/extensions/s3/src/test/java/io/deephaven/extensions/s3/S3InstructionsTest.java index 71699749bdb..ef9e70400a0 100644 --- a/extensions/s3/src/test/java/io/deephaven/extensions/s3/S3InstructionsTest.java +++ b/extensions/s3/src/test/java/io/deephaven/extensions/s3/S3InstructionsTest.java @@ -6,6 +6,7 @@ import org.junit.jupiter.api.Test; import java.time.Duration; +import java.util.Optional; import static org.assertj.core.api.AssertionsForClassTypes.assertThat; @@ -13,8 +14,8 @@ public class S3InstructionsTest { @Test void defaults() { - final S3Instructions instructions = S3Instructions.builder().regionName("some-region").build(); - assertThat(instructions.regionName()).isEqualTo("some-region"); + final S3Instructions instructions = S3Instructions.builder().build(); + assertThat(instructions.regionName().isEmpty()).isTrue(); assertThat(instructions.maxConcurrentRequests()).isEqualTo(256); assertThat(instructions.readAheadCount()).isEqualTo(32); assertThat(instructions.fragmentSize()).isEqualTo(65536); @@ -26,12 +27,13 @@ void defaults() { } @Test - void missingRegion() { - try { - S3Instructions.builder().build(); - } catch (IllegalStateException e) { - assertThat(e).hasMessageContaining("regionName"); - } + void testSetRegion() { + final Optional region = S3Instructions.builder() + .regionName("some-region") + .build() + .regionName(); + assertThat(region.isPresent()).isTrue(); + assertThat(region.get()).isEqualTo("some-region"); } @Test diff --git a/py/server/deephaven/experimental/s3.py b/py/server/deephaven/experimental/s3.py index cf02101ecf1..7065d519f5c 100644 --- a/py/server/deephaven/experimental/s3.py +++ b/py/server/deephaven/experimental/s3.py @@ -34,7 +34,7 @@ class S3Instructions(JObjectWrapper): j_object_type = _JS3Instructions or type(None) def __init__(self, - region_name: str, + region_name: Optional[str] = None, max_concurrent_requests: Optional[int] = None, read_ahead_count: Optional[int] = None, fragment_size: Optional[int] = None, @@ -52,7 +52,10 @@ def __init__(self, Initializes the instructions. Args: - region_name (str): the region name for reading parquet files, mandatory parameter. + region_name (str): the region name for reading parquet files. If not provided, the default region will be + picked by the AWS SDK from 'aws.region' system property, "AWS_REGION" environment variable, the + {user.home}/.aws/credentials or {user.home}/.aws/config files, or from EC2 metadata service, if running in + EC2. max_concurrent_requests (int): the maximum number of concurrent requests for reading files, default is 256. read_ahead_count (int): the number of fragments to send asynchronous read requests for while reading the current fragment. Defaults to 32, which means fetch the next 32 fragments in advance when reading the current fragment. @@ -87,7 +90,9 @@ def __init__(self, try: builder = self.j_object_type.builder() - builder.regionName(region_name) + + if region_name is not None: + builder.regionName(region_name) if max_concurrent_requests is not None: builder.maxConcurrentRequests(max_concurrent_requests) diff --git a/py/server/tests/test_parquet.py b/py/server/tests/test_parquet.py index 5275821061c..f9182ad6cda 100644 --- a/py/server/tests/test_parquet.py +++ b/py/server/tests/test_parquet.py @@ -570,9 +570,7 @@ def test_read_parquet_from_s3(self): # Fails since we have a negative read_ahead_count with self.assertRaises(DHError): - s3.S3Instructions(region_name="us-east-1", - read_ahead_count=-1, - ) + s3.S3Instructions(read_ahead_count=-1) # Fails since we provide the key without the secret key with self.assertRaises(DHError): @@ -580,9 +578,8 @@ def test_read_parquet_from_s3(self): access_key_id="Some key without secret", ) - s3_instructions = s3.S3Instructions(region_name="us-east-1", - read_ahead_count=1, - ) + s3_instructions = s3.S3Instructions() + # Fails because we don't have the right credentials with self.assertRaises(Exception): read("s3://dh-s3-parquet-test1/multiColFile.parquet", special_instructions=s3_instructions).select()