Skip to content

Commit

Permalink
Add Language annotations to tests
Browse files Browse the repository at this point in the history
  • Loading branch information
vagaerg authored and dain committed Mar 6, 2024
1 parent f896e2b commit 3d3a0f0
Show file tree
Hide file tree
Showing 2 changed files with 15 additions and 14 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@ public static QueryRunnerHelper withOpaConfig(OpaConfig opaConfig)
builder -> builder.setSystemAccessControl(new OpaAccessControlFactory().create(opaConfigToDict(opaConfig)))));
}

public Set<String> querySetOfStrings(String user, String query)
public Set<String> querySetOfStrings(String user, @Language("SQL") String query)
{
return querySetOfStrings(userSession(user), query);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
import io.trino.spi.connector.ColumnMetadata;
import io.trino.spi.type.IntegerType;
import io.trino.spi.type.VarcharType;
import org.intellij.lang.annotations.Language;
import org.junit.jupiter.api.AfterEach;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.TestInstance;
Expand Down Expand Up @@ -115,8 +116,8 @@ public void testRowFilteringEnabled()
.setOpaUri(OPA_CONTAINER.getOpaUriForPolicyPath(OPA_ALLOW_POLICY_NAME))
.setOpaRowFiltersUri(OPA_CONTAINER.getOpaUriForPolicyPath(OPA_ROW_LEVEL_FILTERING_POLICY_NAME)));
OPA_CONTAINER.submitPolicy(SAMPLE_ROW_LEVEL_FILTERING_POLICY);
String restrictedTableQuery = "SELECT user_name FROM sample_catalog.sample_schema.restricted_table";
String unrestrictedTableQuery = "SELECT user_name FROM sample_catalog.sample_schema.unrestricted_table";
@Language("SQL") String restrictedTableQuery = "SELECT user_name FROM sample_catalog.sample_schema.restricted_table";
@Language("SQL") String unrestrictedTableQuery = "SELECT user_name FROM sample_catalog.sample_schema.unrestricted_table";
assertResultsForUser("admin", restrictedTableQuery, ALL_DUMMY_USERS_IN_TABLE);
assertResultsForUser("admin", unrestrictedTableQuery, ALL_DUMMY_USERS_IN_TABLE);

Expand All @@ -132,8 +133,8 @@ public void testRowFilteringDisabledDoesNothing()
new OpaConfig()
.setOpaUri(OPA_CONTAINER.getOpaUriForPolicyPath(OPA_ALLOW_POLICY_NAME)));
OPA_CONTAINER.submitPolicy(SAMPLE_ROW_LEVEL_FILTERING_POLICY);
String restrictedTableQuery = "SELECT user_name FROM sample_catalog.sample_schema.restricted_table";
String unrestrictedTableQuery = "SELECT user_name FROM sample_catalog.sample_schema.unrestricted_table";
@Language("SQL") String restrictedTableQuery = "SELECT user_name FROM sample_catalog.sample_schema.restricted_table";
@Language("SQL") String unrestrictedTableQuery = "SELECT user_name FROM sample_catalog.sample_schema.unrestricted_table";
assertResultsForUser("admin", restrictedTableQuery, ALL_DUMMY_USERS_IN_TABLE);
assertResultsForUser("admin", unrestrictedTableQuery, ALL_DUMMY_USERS_IN_TABLE);

Expand All @@ -151,8 +152,8 @@ public void testColumnMasking()
.setOpaColumnMaskingUri(OPA_CONTAINER.getOpaUriForPolicyPath(OPA_COLUMN_MASKING_POLICY_NAME)));
OPA_CONTAINER.submitPolicy(SAMPLE_COLUMN_MASKING_POLICY);

String userNamesInUnrestrictedTableQuery = "SELECT user_name FROM sample_catalog.sample_schema.unrestricted_table";
String userNamesInRestrictedTableQuery = "SELECT user_name FROM sample_catalog.sample_schema.restricted_table";
@Language("SQL") String userNamesInUnrestrictedTableQuery = "SELECT user_name FROM sample_catalog.sample_schema.unrestricted_table";
@Language("SQL") String userNamesInRestrictedTableQuery = "SELECT user_name FROM sample_catalog.sample_schema.restricted_table";
// No masking is applied to the unrestricted table
assertResultsForUser("admin", userNamesInUnrestrictedTableQuery, ALL_DUMMY_USERS_IN_TABLE);
assertResultsForUser("bob", userNamesInUnrestrictedTableQuery, ALL_DUMMY_USERS_IN_TABLE);
Expand All @@ -164,8 +165,8 @@ public void testColumnMasking()
Set<String> expectedMaskedUserNames = ALL_DUMMY_USERS_IN_TABLE.stream().map(userName -> "****" + userName.substring(userName.length() - 3)).collect(toImmutableSet());
assertResultsForUser("bob", userNamesInRestrictedTableQuery, expectedMaskedUserNames);

String phoneNumbersInUnrestrictedTableQuery = "SELECT user_phone FROM sample_catalog.sample_schema.unrestricted_table";
String phoneNumbersInRestrictedTableQuery = "SELECT user_phone FROM sample_catalog.sample_schema.restricted_table";
@Language("SQL") String phoneNumbersInUnrestrictedTableQuery = "SELECT user_phone FROM sample_catalog.sample_schema.unrestricted_table";
@Language("SQL") String phoneNumbersInRestrictedTableQuery = "SELECT user_phone FROM sample_catalog.sample_schema.restricted_table";

// Phone numbers are derived by hashing the name of the user
Set<String> allExpectedPhoneNumbers = ALL_DUMMY_USERS_IN_TABLE.stream().map(userName -> String.valueOf(userName.hashCode())).collect(toImmutableSet());
Expand All @@ -186,8 +187,8 @@ public void testColumnMaskingDisabledDoesNothing()
{
setupTrinoWithOpa(new OpaConfig().setOpaUri(OPA_CONTAINER.getOpaUriForPolicyPath(OPA_ALLOW_POLICY_NAME)));
OPA_CONTAINER.submitPolicy(SAMPLE_COLUMN_MASKING_POLICY);
String restrictedTableQuery = "SELECT user_name FROM sample_catalog.sample_schema.restricted_table";
String unrestrictedTableQuery = "SELECT user_name FROM sample_catalog.sample_schema.unrestricted_table";
@Language("SQL") String restrictedTableQuery = "SELECT user_name FROM sample_catalog.sample_schema.restricted_table";
@Language("SQL") String unrestrictedTableQuery = "SELECT user_name FROM sample_catalog.sample_schema.unrestricted_table";
assertResultsForUser("admin", restrictedTableQuery, ALL_DUMMY_USERS_IN_TABLE);
assertResultsForUser("admin", unrestrictedTableQuery, ALL_DUMMY_USERS_IN_TABLE);

Expand Down Expand Up @@ -232,8 +233,8 @@ public void testColumnMaskingAndRowFiltering()
}""";
OPA_CONTAINER.submitPolicy(policy);

String selectUserNameData = "SELECT user_name FROM sample_catalog.sample_schema.restricted_table";
String selectUserTypeData = "SELECT user_type FROM sample_catalog.sample_schema.restricted_table";
@Language("SQL") String selectUserNameData = "SELECT user_name FROM sample_catalog.sample_schema.restricted_table";
@Language("SQL") String selectUserTypeData = "SELECT user_type FROM sample_catalog.sample_schema.restricted_table";
Set<String> expectedUserTypes = ImmutableSet.of("internal_user", "customer");

assertResultsForUser("admin", selectUserNameData, ALL_DUMMY_USERS_IN_TABLE);
Expand All @@ -243,7 +244,7 @@ public void testColumnMaskingAndRowFiltering()
assertResultsForUser("bob", selectUserTypeData, ImmutableSet.of("internal_user"));
}

private void assertResultsForUser(String asUser, String query, Set<String> expectedResults)
private void assertResultsForUser(String asUser, @Language("SQL") String query, Set<String> expectedResults)
{
assertThat(runner.querySetOfStrings(asUser, query)).containsExactlyInAnyOrderElementsOf(expectedResults);
}
Expand Down

0 comments on commit 3d3a0f0

Please sign in to comment.