Skip to content

Commit

Permalink
[asset selection] Introduce AssetSelection.to_selection_str (#26047)
Browse files Browse the repository at this point in the history
## Summary

On the backend, `AssetSelection` is a nice serializable representation
of an asset selection, which we might generate through a few methods:

- From an Antlr asset selection string
- From a set of frontend filters, saved in a catalog view
- Potentially from snapshotted user code, though we don't do this right
now (for example, examining an asset job or sensor/schedule target and
being able to copy an Antlr-ready string)

We don't have a good way to convert this selection back into a
user-readable Antlr string. There are a few possible solutions here:

- Save the user-specified Antlr string alongside a selection. This seems
like a promising approach when the initial source of truth is an Antlr
string.
- Generate an Antlr string. This is the approach we'll need if the
source of truth is user code or frontend filters where no Antlr string
previously existed.

This PR introduces a `to_selection_str()` method on `AssetSelection`
subclasses, which output a valid Antlr string representing that
selection:

- The string must be parsed into an identical AssetSelection
- The string may not identically match the original Antlr string, if one
exists

## Test Plan

New tests which
1. Ensure that Antlr strings can be converted to selections, back into
equivalent strings (e.g. produce the same selection, but may not be
equal)
2. Ensure that we have all Antlr literals under test, a rough proxy for
ensuring we have full coverage of the grammar

Updates the existing `__str__` tests, which were introduced in #19059
for similar purposes but never fully fleshed out. For now, drops check
selection stuff since the Antlr syntax doesn't include it.
  • Loading branch information
benpankow authored Nov 22, 2024
1 parent c67d4ea commit 41f3fab
Show file tree
Hide file tree
Showing 6 changed files with 160 additions and 123 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -417,7 +417,10 @@ def test_jobless_asset_selection(graphql_context):

assert result.data
assert result.data["scheduleOrError"]["__typename"] == "Schedule"
assert result.data["scheduleOrError"]["assetSelection"]["assetSelectionString"] == "asset_one"
assert (
result.data["scheduleOrError"]["assetSelection"]["assetSelectionString"]
== 'key:"asset_one"'
)
assert result.data["scheduleOrError"]["assetSelection"]["assetKeys"] == [
{"path": ["asset_one"]}
]
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1655,7 +1655,7 @@ def test_asset_selection(graphql_context):

assert (
result.data["sensorOrError"]["assetSelection"]["assetSelectionString"]
== "fresh_diamond_bottom or asset_with_automation_condition or asset_with_custom_automation_condition"
== 'key:"fresh_diamond_bottom" or key:"asset_with_automation_condition" or key:"asset_with_custom_automation_condition"'
)
assert result.data["sensorOrError"]["assetSelection"]["assetKeys"] == [
{"path": ["asset_with_automation_condition"]},
Expand Down Expand Up @@ -1699,7 +1699,9 @@ def test_jobless_asset_selection(graphql_context):

assert result.data
assert result.data["sensorOrError"]["__typename"] == "Sensor"
assert result.data["sensorOrError"]["assetSelection"]["assetSelectionString"] == "asset_one"
assert (
result.data["sensorOrError"]["assetSelection"]["assetSelectionString"] == 'key:"asset_one"'
)
assert result.data["sensorOrError"]["assetSelection"]["assetKeys"] == [{"path": ["asset_one"]}]
assert result.data["sensorOrError"]["assetSelection"]["assets"] == [
{
Expand Down Expand Up @@ -1727,7 +1729,8 @@ def test_invalid_sensor_asset_selection(graphql_context):
assert result.data
assert result.data["sensorOrError"]["__typename"] == "Sensor"
assert (
result.data["sensorOrError"]["assetSelection"]["assetSelectionString"] == "does_not_exist"
result.data["sensorOrError"]["assetSelection"]["assetSelectionString"]
== 'key:"does_not_exist"'
)
assert (
result.data["sensorOrError"]["assetSelection"]["assetsOrError"]["__typename"]
Expand Down
126 changes: 75 additions & 51 deletions python_modules/dagster/dagster/_core/definitions/asset_selection.py
Original file line number Diff line number Diff line change
Expand Up @@ -569,8 +569,31 @@ def needs_parentheses_when_operand(self) -> bool:
"""
return False

def operand__str__(self) -> str:
return f"({self})" if self.needs_parentheses_when_operand() else str(self)
def operand_to_selection_str(self) -> str:
"""Returns a string representation of the selection when it is a child of a boolean expression,
for example, in an `AndAssetSelection` or `OrAssetSelection`. The main difference from `to_selection_str`
is that this method may include additional parentheses around the selection to ensure that the
expression is parsed correctly.
"""
return (
f"({self.to_selection_str()})"
if self.needs_parentheses_when_operand()
else self.to_selection_str()
)

def to_selection_str(self) -> str:
"""Returns an Antlr string representation of the selection that can be parsed by `from_string`."""
raise NotImplementedError(
f"{self.__class__.__name__} does not support conversion to a string."
)

def __str__(self) -> str:
# Attempt to use the to-Antlr-selection-string method if it's implemented,
# otherwise fall back to the default Python string representation
try:
return self.to_selection_str()
except NotImplementedError:
return super().__str__()


@whitelist_for_serdes
Expand All @@ -590,8 +613,8 @@ def resolve_inner(
def to_serializable_asset_selection(self, asset_graph: BaseAssetGraph) -> "AssetSelection":
return self

def __str__(self) -> str:
return "all materializable assets" + (" and source assets" if self.include_sources else "")
def to_selection_str(self) -> str:
return "*"


@whitelist_for_serdes
Expand All @@ -610,9 +633,6 @@ def resolve_checks_inner(
def to_serializable_asset_selection(self, asset_graph: BaseAssetGraph) -> "AssetSelection":
return self

def __str__(self) -> str:
return "all asset checks"


@whitelist_for_serdes
@record
Expand All @@ -636,11 +656,6 @@ def resolve_checks_inner(
def to_serializable_asset_selection(self, asset_graph: BaseAssetGraph) -> "AssetSelection":
return self

def __str__(self) -> str:
if len(self.selected_asset_keys) == 1:
return f"asset_check:{self.selected_asset_keys[0].to_user_string()}"
return f"asset_check:({' or '.join(k.to_user_string() for k in self.selected_asset_keys)})"


@whitelist_for_serdes
@record
Expand Down Expand Up @@ -670,11 +685,6 @@ def resolve_checks_inner(
def to_serializable_asset_selection(self, asset_graph: BaseAssetGraph) -> "AssetSelection":
return self

def __str__(self) -> str:
if len(self.selected_asset_check_keys) == 1:
return f"asset_check:{self.selected_asset_check_keys[0].to_user_string()}"
return f"asset_check:({' or '.join(k.to_user_string() for k in self.selected_asset_check_keys)})"


@record
class OperandListAssetSelection(AssetSelection):
Expand Down Expand Up @@ -729,8 +739,8 @@ def resolve_checks_inner(
),
)

def __str__(self) -> str:
return " and ".join(operand.operand__str__() for operand in self.operands)
def to_selection_str(self) -> str:
return " and ".join(f"{operand.operand_to_selection_str()}" for operand in self.operands)


@whitelist_for_serdes
Expand All @@ -757,8 +767,8 @@ def resolve_checks_inner(
),
)

def __str__(self) -> str:
return " or ".join(operand.operand__str__() for operand in self.operands)
def to_selection_str(self) -> str:
return " or ".join(f"{operand.operand_to_selection_str()}" for operand in self.operands)


@whitelist_for_serdes
Expand Down Expand Up @@ -791,8 +801,10 @@ def to_serializable_asset_selection(self, asset_graph: BaseAssetGraph) -> "Asset
def needs_parentheses_when_operand(self) -> bool:
return True

def __str__(self) -> str:
return f"{self.left.operand__str__()} - {self.right.operand__str__()}"
def to_selection_str(self) -> str:
if isinstance(self.left, AllSelection):
return f"not {self.right.to_selection_str()}"
return f"{self.left.operand_to_selection_str()} and not {self.right.operand_to_selection_str()}"


@record
Expand All @@ -815,6 +827,9 @@ def resolve_inner(
selection = self.child.resolve_inner(asset_graph, allow_missing=allow_missing)
return fetch_sinks(asset_graph.asset_dep_graph, selection)

def to_selection_str(self) -> str:
return f"sinks({self.child.to_selection_str()})"


@whitelist_for_serdes
class RequiredNeighborsAssetSelection(ChainedAssetSelection):
Expand All @@ -836,6 +851,9 @@ def resolve_inner(
selection = self.child.resolve_inner(asset_graph, allow_missing=allow_missing)
return fetch_sources(asset_graph, selection)

def to_selection_str(self) -> str:
return f"roots({self.child.to_selection_str()})"


@whitelist_for_serdes
class MaterializableAssetSelection(ChainedAssetSelection):
Expand Down Expand Up @@ -876,6 +894,19 @@ def resolve_inner(
selection if not self.include_self else set(),
)

def to_selection_str(self) -> str:
if self.depth is None:
base = f"{self.child.operand_to_selection_str()}*"
elif self.depth == 0:
base = self.child.operand_to_selection_str()
else:
base = f"{self.child.operand_to_selection_str()}{'+' * self.depth}"

if self.include_self:
return base
else:
return f"{base} and not {self.child.operand_to_selection_str()}"


@whitelist_for_serdes
@record
Expand All @@ -901,11 +932,14 @@ def resolve_inner(
def to_serializable_asset_selection(self, asset_graph: BaseAssetGraph) -> "AssetSelection":
return self

def __str__(self) -> str:
def needs_parentheses_when_operand(self) -> bool:
return len(self.selected_groups) > 1

def to_selection_str(self) -> str:
if len(self.selected_groups) == 1:
return f"group:{self.selected_groups[0]}"
return f'group:"{self.selected_groups[0]}"'
else:
return f"group:({' or '.join(self.selected_groups)})"
return " or ".join(f'group:"{group}"' for group in self.selected_groups)


@whitelist_for_serdes
Expand All @@ -926,8 +960,8 @@ def resolve_inner(

return {key for key in base_set if asset_graph.get(key).tags.get(self.key) == self.value}

def __str__(self) -> str:
return f"tag:{self.key}={self.value}"
def to_selection_str(self) -> str:
return f'tag:"{self.key}"="{self.value}"'


@whitelist_for_serdes
Expand All @@ -944,8 +978,8 @@ def resolve_inner(
if self.selected_owner in asset_graph.get(key).owners
}

def __str__(self) -> str:
return f"owner:{self.selected_owner}"
def to_selection_str(self) -> str:
return f'owner:"{self.selected_owner}"'


@whitelist_for_serdes
Expand All @@ -963,8 +997,8 @@ def resolve_inner(
"""This should not be invoked in user code."""
raise NotImplementedError

def __str__(self) -> str:
return f"code_location:{self.selected_code_location}"
def to_selection_str(self) -> str:
return f'code_location:"{self.selected_code_location}"'


@whitelist_for_serdes
Expand Down Expand Up @@ -1013,11 +1047,8 @@ def to_serializable_asset_selection(self, asset_graph: BaseAssetGraph) -> "Asset
def needs_parentheses_when_operand(self) -> bool:
return len(self.selected_keys) > 1

def __str__(self) -> str:
if len(self.selected_keys) <= 3:
return f"{' or '.join(k.to_user_string() for k in self.selected_keys)}"
else:
return f"{len(self.selected_keys)} assets"
def to_selection_str(self) -> str:
return " or ".join(f'key:"{x.to_user_string()}"' for x in self.selected_keys)


@whitelist_for_serdes
Expand All @@ -1043,13 +1074,6 @@ def resolve_inner(
def to_serializable_asset_selection(self, asset_graph: BaseAssetGraph) -> "AssetSelection":
return self

def __str__(self) -> str:
key_prefix_strs = ["/".join(key_prefix) for key_prefix in self.selected_key_prefixes]
if len(self.selected_key_prefixes) == 1:
return f"key_prefix:{key_prefix_strs[0]}"
else:
return f"key_prefix:({' or '.join(key_prefix_strs)})"


@whitelist_for_serdes
@record
Expand All @@ -1070,8 +1094,8 @@ def resolve_inner(
def to_serializable_asset_selection(self, asset_graph: BaseAssetGraph) -> "AssetSelection":
return self

def __str__(self) -> str:
return f"key_substring:{self.selected_key_substring}"
def to_selection_str(self) -> str:
return f'key_substring:"{self.selected_key_substring}"'


def _fetch_all_upstream(
Expand Down Expand Up @@ -1114,18 +1138,18 @@ def resolve_inner(
all_upstream = _fetch_all_upstream(selection, asset_graph, self.depth, self.include_self)
return {key for key in all_upstream if key in asset_graph.materializable_asset_keys}

def __str__(self) -> str:
def to_selection_str(self) -> str:
if self.depth is None:
base = f"*({self.child})"
base = f"*{self.child.operand_to_selection_str()}"
elif self.depth == 0:
base = str(self.child)
base = self.child.operand_to_selection_str()
else:
base = f"{'+' * self.depth}({self.child})"
base = f"{'+' * self.depth}{self.child.operand_to_selection_str()}"

if self.include_self:
return base
else:
return f"{base} - ({self.child})"
return f"{base} and not {self.child.operand_to_selection_str()}"


@whitelist_for_serdes
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,9 @@
from dagster._core.definitions.antlr_asset_selection.antlr_asset_selection import (
AntlrAssetSelectionParser,
)
from dagster._core.definitions.antlr_asset_selection.generated.AssetSelectionParser import (
AssetSelectionParser,
)
from dagster._core.definitions.asset_selection import AssetSelection, CodeLocationAssetSelection
from dagster._core.definitions.decorators.asset_decorator import asset
from dagster._core.storage.tags import KIND_PREFIX
Expand Down Expand Up @@ -70,10 +73,19 @@
),
],
)
def test_antlr_tree(selection_str, expected_tree_str):
def test_antlr_tree(selection_str, expected_tree_str) -> None:
asset_selection = AntlrAssetSelectionParser(selection_str, include_sources=True)
assert asset_selection.tree_str == expected_tree_str

generated_selection = asset_selection.asset_selection

# Ensure the generated selection can be converted back to a selection string, and then back to the same selection
regenerated_selection = AntlrAssetSelectionParser(
generated_selection.to_selection_str(),
include_sources=True,
).asset_selection
assert regenerated_selection == generated_selection


@pytest.mark.parametrize(
"selection_str",
Expand Down Expand Up @@ -143,7 +155,7 @@ def test_antlr_tree_invalid(selection_str):
),
],
)
def test_antlr_visit_basic(selection_str, expected_assets):
def test_antlr_visit_basic(selection_str, expected_assets) -> None:
# a -> b -> c
@asset(tags={"foo": "bar"}, owners=["team:billing"])
def a(): ...
Expand All @@ -157,7 +169,34 @@ def b(): ...
)
def c(): ...

assert (
AntlrAssetSelectionParser(selection_str, include_sources=True).asset_selection
== expected_assets
)
generated_selection = AntlrAssetSelectionParser(
selection_str, include_sources=True
).asset_selection
assert generated_selection == expected_assets

# Ensure the generated selection can be converted back to a selection string, and then back to the same selection
regenerated_selection = AntlrAssetSelectionParser(
generated_selection.to_selection_str(),
include_sources=True,
).asset_selection
assert regenerated_selection == expected_assets


def test_full_test_coverage() -> None:
# Ensures that every Antlr literal is tested in test_antlr_visit_basic
# by extension, also ensures that the to_selection_str method is tested
# for all Antlr literals
names = AssetSelectionParser.literalNames

all_selection_strings_we_are_testing = [
selection_str for selection_str, _ in test_antlr_visit_basic.pytestmark[0].args[1]
]

for name in names:
if name in ("<INVALID>", "','"):
continue

name_substr = name.strip("'")
assert any(
name_substr in selection_str for selection_str in all_selection_strings_we_are_testing
), f"Antlr literal {name_substr} is not under test in test_antlr_asset_selection.py:test_antlr_visit_basic"
Loading

0 comments on commit 41f3fab

Please sign in to comment.