Skip to content

Commit

Permalink
Minor: Enable Azure DatabaseTokenProvider if there is azure=true in D…
Browse files Browse the repository at this point in the history
…B_QUERYPARAMS (#18451)
  • Loading branch information
harshach authored Oct 29, 2024
1 parent e197c3f commit 96c3364
Show file tree
Hide file tree
Showing 3 changed files with 86 additions and 8 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -134,13 +134,15 @@ public static List<String> fieldToExtensionStrings(String field) throws IOExcept
preprocessedField = preprocessedField.replace("\n", "\\n").replace("\"", "\\\"");

CSVFormat format =
CSVFormat.DEFAULT.builder()
CSVFormat.DEFAULT
.builder()
.setDelimiter(';')
.setQuote('"')
.setRecordSeparator(null)
.setIgnoreSurroundingSpaces(true)
.setIgnoreEmptyLines(true)
.setEscape('\\').build(); // Use backslash for escaping special characters
.setEscape('\\')
.build(); // Use backslash for escaping special characters

try (CSVParser parser = CSVParser.parse(new StringReader(preprocessedField), format)) {
return parser.getRecords().stream()
Expand Down Expand Up @@ -180,7 +182,7 @@ public static List<String> fieldToColumns(String field) throws IOException {
preprocessedField = preprocessedField.replace("\n", "\\n").replace("\"", "\\\"");

CSVFormat format =
CSVFormat.DEFAULT.builder().setDelimiter(',').setQuote('"').setEscape('\\').build();
CSVFormat.DEFAULT.builder().setDelimiter(',').setQuote('"').setEscape('\\').build();

List<String> columns;
try (CSVParser parser = CSVParser.parse(new StringReader(preprocessedField), format)) {
Expand Down
Original file line number Diff line number Diff line change
@@ -1,5 +1,9 @@
package org.openmetadata.service.util.jdbi;

import java.net.URI;
import java.net.URLDecoder;
import java.util.LinkedHashMap;
import java.util.Map;
import java.util.Optional;

/** Factory class for {@link DatabaseAuthenticationProvider}. */
Expand All @@ -14,15 +18,35 @@ private DatabaseAuthenticationProviderFactory() {}
* @return instance of {@link DatabaseAuthenticationProvider}.
*/
public static Optional<DatabaseAuthenticationProvider> get(String jdbcURL) {
// Check
if (jdbcURL.contains(AwsRdsDatabaseAuthenticationProvider.AWS_REGION)
Map<String, String> queryParams = parseQueryParams(jdbcURL);

if ("true".equals(queryParams.get("azure"))) {
return Optional.of(new AzureDatabaseAuthenticationProvider());
} else if (jdbcURL.contains(AwsRdsDatabaseAuthenticationProvider.AWS_REGION)
&& jdbcURL.contains(AwsRdsDatabaseAuthenticationProvider.ALLOW_PUBLIC_KEY_RETRIEVAL)) {
return Optional.of(new AwsRdsDatabaseAuthenticationProvider());
} else if (jdbcURL.contains(AzureDatabaseAuthenticationProvider.AZURE)) {
return Optional.of(new AzureDatabaseAuthenticationProvider());
}

// Return empty
return Optional.empty();
}

private static Map<String, String> parseQueryParams(String jdbcURL) {
try {
URI uri = new URI(jdbcURL.substring(jdbcURL.indexOf(":") + 1));
Map<String, String> queryPairs = new LinkedHashMap<>();
String query = uri.getQuery();
if (query != null) {
String[] pairs = query.split("&");
for (String pair : pairs) {
int idx = pair.indexOf("=");
queryPairs.put(
URLDecoder.decode(pair.substring(0, idx), "UTF-8"),
URLDecoder.decode(pair.substring(idx + 1), "UTF-8"));
}
}
return queryPairs;
} catch (Exception e) {
throw new IllegalArgumentException("Failed to parse query parameters from JDBC URL", e);
}
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,52 @@
package org.openmetadata.service.util.jdbi;

import static org.junit.jupiter.api.Assertions.assertFalse;
import static org.junit.jupiter.api.Assertions.assertTrue;

import java.util.Optional;
import org.junit.jupiter.api.Test;

class DatabaseAuthenticationProviderFactoryTest {

@Test
void testGet_withAzureTrueInJdbcUrl_returnsAzureDatabaseAuthenticationProvider() {
String jdbcURL =
"jdbc:postgresql://your-database.postgres.database.azure.com:5432/testdb?azure=true&sslmode=require";

Optional<DatabaseAuthenticationProvider> provider =
DatabaseAuthenticationProviderFactory.get(jdbcURL);
assertTrue(provider.isPresent(), "Expected AzureDatabaseAuthenticationProvider to be present");
assertTrue(
provider.get() instanceof AzureDatabaseAuthenticationProvider,
"Expected instance of AzureDatabaseAuthenticationProvider");
}

@Test
void testGet_withoutAzureTrueInJdbcUrl_returnsEmptyOptional() {
String jdbcURL =
"jdbc:postgresql://your-database.postgres.database.azure.com:5432/testdb?sslmode=require";
Optional<DatabaseAuthenticationProvider> provider =
DatabaseAuthenticationProviderFactory.get(jdbcURL);
assertFalse(provider.isPresent(), "Expected no provider to be present");
}

@Test
void testGet_withAwsRdsParamsInJdbcUrl_returnsAwsRdsDatabaseAuthenticationProvider() {
String jdbcURL =
"jdbc:mysql://your-aws-db.rds.amazonaws.com:3306/testdb?awsRegion=us-west-2&allowPublicKeyRetrieval=true";
Optional<DatabaseAuthenticationProvider> provider =
DatabaseAuthenticationProviderFactory.get(jdbcURL);
assertTrue(provider.isPresent(), "Expected AwsRdsDatabaseAuthenticationProvider to be present");
assertTrue(
provider.get() instanceof AwsRdsDatabaseAuthenticationProvider,
"Expected instance of AwsRdsDatabaseAuthenticationProvider");
}

@Test
void testGet_withInvalidUrl_returnsEmptyOptional() {
String jdbcURL = "jdbc:invalidurl://your-db.test?someParam=true";
Optional<DatabaseAuthenticationProvider> provider =
DatabaseAuthenticationProviderFactory.get(jdbcURL);
assertFalse(provider.isPresent(), "Expected no provider to be present for invalid URL");
}
}

0 comments on commit 96c3364

Please sign in to comment.