Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Update MF to support breaking changes in DSI for custom calendar #1522

Open
wants to merge 5 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
from __future__ import annotations

from typing import Optional

from dbt_semantic_interfaces.type_enums.time_granularity import TimeGranularity


def error_if_not_standard_grain(input_granularity: str, context: Optional[str] = None) -> TimeGranularity:
"""Cast input grainularity string to TimeGranularity, otherwise error.

TODO: Not needed once, custom grain is supported for most things.
"""
try:
time_grain = TimeGranularity(input_granularity)
except ValueError:
error_msg = f"Received a non-standard time granularity, which is not supported at the moment, received: {input_granularity}."
if context:
error_msg += f"\nContext: {context}"
raise ValueError(error_msg)
return time_grain
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@
class DunderNamingScheme(QueryItemNamingScheme):
"""A naming scheme using the dundered name syntax.

TODO: Consolidate with StructuredLinkableSpecName / DunderedNameFormatter.
TODO: Consolidate with StructuredLinkableSpecName.
"""

_INPUT_REGEX = re.compile(r"\A[a-z]([a-z0-9_])*[a-z0-9]\Z")
Expand Down Expand Up @@ -52,7 +52,7 @@ def input_str(self, instance_spec: InstanceSpec) -> Optional[str]:

@override
def spec_pattern(self, input_str: str, semantic_manifest_lookup: SemanticManifestLookup) -> EntityLinkPattern:
if not self.input_str_follows_scheme(input_str):
if not self.input_str_follows_scheme(input_str, semantic_manifest_lookup=semantic_manifest_lookup):
raise ValueError(f"{repr(input_str)} does not follow this scheme.")

input_str = input_str.lower()
Expand Down Expand Up @@ -119,7 +119,7 @@ def spec_pattern(self, input_str: str, semantic_manifest_lookup: SemanticManifes
)

@override
def input_str_follows_scheme(self, input_str: str) -> bool:
def input_str_follows_scheme(self, input_str: str, semantic_manifest_lookup: SemanticManifestLookup) -> bool:
# This naming scheme is case-insensitive.
input_str = input_str.lower()
if DunderNamingScheme._INPUT_REGEX.match(input_str) is None:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -28,12 +28,12 @@ def input_str(self, instance_spec: InstanceSpec) -> Optional[str]:
@override
def spec_pattern(self, input_str: str, semantic_manifest_lookup: SemanticManifestLookup) -> MetricSpecPattern:
input_str = input_str.lower()
if not self.input_str_follows_scheme(input_str):
if not self.input_str_follows_scheme(input_str, semantic_manifest_lookup=semantic_manifest_lookup):
raise RuntimeError(f"{repr(input_str)} does not follow this scheme.")
return MetricSpecPattern(metric_reference=MetricReference(element_name=input_str))

@override
def input_str_follows_scheme(self, input_str: str) -> bool:
def input_str_follows_scheme(self, input_str: str, semantic_manifest_lookup: SemanticManifestLookup) -> bool:
# TODO: Use regex.
return True

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@ def spec_pattern(self, input_str: str, semantic_manifest_lookup: SemanticManifes
pass

@abstractmethod
def input_str_follows_scheme(self, input_str: str) -> bool:
def input_str_follows_scheme(self, input_str: str, semantic_manifest_lookup: SemanticManifestLookup) -> bool:
"""Returns true if the given input string follows this naming scheme.

Consider adding a structured result that indicates why it does not match the scheme.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -38,14 +38,16 @@ def input_str(self, instance_spec: InstanceSpec) -> Optional[str]:

@override
def spec_pattern(self, input_str: str, semantic_manifest_lookup: SemanticManifestLookup) -> SpecPattern:
if not self.input_str_follows_scheme(input_str):
if not self.input_str_follows_scheme(input_str, semantic_manifest_lookup=semantic_manifest_lookup):
raise ValueError(
f"The specified input {repr(input_str)} does not match the input described by the object builder "
f"pattern."
)
try:
# TODO: Update when more appropriate parsing libraries are available.
call_parameter_sets = PydanticWhereFilter(where_sql_template="{{ " + input_str + " }}").call_parameter_sets
call_parameter_sets = PydanticWhereFilter(where_sql_template="{{ " + input_str + " }}").call_parameter_sets(
custom_granularity_names=semantic_manifest_lookup.semantic_model_lookup.custom_granularity_names
)
except ParseWhereFilterException as e:
raise ValueError(f"A spec pattern can't be generated from the input string {repr(input_str)}") from e

Expand Down Expand Up @@ -121,11 +123,14 @@ def spec_pattern(self, input_str: str, semantic_manifest_lookup: SemanticManifes
raise RuntimeError("There should have been a return associated with one of the CallParameterSets.")

@override
def input_str_follows_scheme(self, input_str: str) -> bool:
def input_str_follows_scheme(self, input_str: str, semantic_manifest_lookup: SemanticManifestLookup) -> bool:
if ObjectBuilderNamingScheme._NAME_REGEX.match(input_str) is None:
return False
try:
call_parameter_sets = WhereFilterParser.parse_call_parameter_sets("{{ " + input_str + " }}")
call_parameter_sets = WhereFilterParser.parse_call_parameter_sets(
where_sql_template="{{ " + input_str + " }}",
custom_granularity_names=semantic_manifest_lookup.semantic_model_lookup.custom_granularity_names,
)
return_value = (
len(call_parameter_sets.dimension_call_parameter_sets)
+ len(call_parameter_sets.time_dimension_call_parameter_sets)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
from dbt_semantic_interfaces.type_enums.time_granularity import TimeGranularity
from typing_extensions import override

from metricflow_semantics.errors.custom_grain_not_supported import error_if_not_standard_grain
from metricflow_semantics.mf_logging.formatting import indent
from metricflow_semantics.mf_logging.lazy_formattable import LazyFormat
from metricflow_semantics.mf_logging.pretty_print import mf_pformat, mf_pformat_many
Expand Down Expand Up @@ -401,7 +402,13 @@ def visit_metric_node(self, node: MetricGroupByItemResolutionNode) -> PushDownRe

# If time granularity is not set for the metric, defaults to DAY if available, else the smallest available granularity.
# Note: ignores any granularity set on input metrics.
metric_default_time_granularity = metric_to_use_for_time_granularity_resolution.time_granularity or max(
metric_time_granularity: Optional[TimeGranularity] = None
if metric_to_use_for_time_granularity_resolution.time_granularity is not None:
metric_time_granularity = error_if_not_standard_grain(
context=f"Metric({metric_to_use_for_time_granularity_resolution}).time_granularity",
input_granularity=metric_to_use_for_time_granularity_resolution.time_granularity,
)
metric_default_time_granularity = metric_time_granularity or max(
TimeGranularity.DAY,
self._semantic_manifest_lookup.metric_lookup.get_min_queryable_time_granularity(
MetricReference(metric_to_use_for_time_granularity_resolution.name)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -331,7 +331,9 @@ def _resolve_specs_for_where_filters(
for location, where_filters in where_filters_and_locations.items():
for where_filter in where_filters:
try:
filter_call_parameter_sets = where_filter.call_parameter_sets
filter_call_parameter_sets = where_filter.call_parameter_sets(
custom_granularity_names=self._manifest_lookup.semantic_model_lookup.custom_granularity_names
)
except Exception as e:
non_parsable_resolutions.append(
NonParsableFilterResolution(
Expand Down
16 changes: 12 additions & 4 deletions metricflow-semantics/metricflow_semantics/query/query_parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -210,7 +210,9 @@ def _parse_order_by_names(
order_by_name_without_prefix = order_by_name

for group_by_item_naming_scheme in self._group_by_item_naming_schemes:
if group_by_item_naming_scheme.input_str_follows_scheme(order_by_name_without_prefix):
if group_by_item_naming_scheme.input_str_follows_scheme(
order_by_name_without_prefix, semantic_manifest_lookup=self._manifest_lookup
):
possible_inputs.append(
ResolverInputForGroupByItem(
input_obj=order_by_name,
Expand All @@ -223,7 +225,9 @@ def _parse_order_by_names(
break

for metric_naming_scheme in self._metric_naming_schemes:
if metric_naming_scheme.input_str_follows_scheme(order_by_name_without_prefix):
if metric_naming_scheme.input_str_follows_scheme(
order_by_name_without_prefix, semantic_manifest_lookup=self._manifest_lookup
):
possible_inputs.append(
ResolverInputForMetric(
input_obj=order_by_name,
Expand Down Expand Up @@ -373,7 +377,9 @@ def _parse_and_validate_query(
for metric_name in metric_names:
resolver_input_for_metric: Optional[MetricFlowQueryResolverInput] = None
for metric_naming_scheme in self._metric_naming_schemes:
if metric_naming_scheme.input_str_follows_scheme(metric_name):
if metric_naming_scheme.input_str_follows_scheme(
metric_name, semantic_manifest_lookup=self._manifest_lookup
):
resolver_input_for_metric = ResolverInputForMetric(
input_obj=metric_name,
naming_scheme=metric_naming_scheme,
Expand Down Expand Up @@ -405,7 +411,9 @@ def _parse_and_validate_query(
for group_by_name in group_by_names:
resolver_input_for_group_by_item: Optional[MetricFlowQueryResolverInput] = None
for group_by_item_naming_scheme in self._group_by_item_naming_schemes:
if group_by_item_naming_scheme.input_str_follows_scheme(group_by_name):
if group_by_item_naming_scheme.input_str_follows_scheme(
group_by_name, semantic_manifest_lookup=self._manifest_lookup
):
spec_pattern = group_by_item_naming_scheme.spec_pattern(
group_by_name, semantic_manifest_lookup=self._manifest_lookup
)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,6 @@
DimensionCallParameterSet,
TimeDimensionCallParameterSet,
)
from dbt_semantic_interfaces.naming.dundered import DunderedNameFormatter
from dbt_semantic_interfaces.protocols.protocol_hint import ProtocolHint
from dbt_semantic_interfaces.protocols.query_interface import (
QueryInterfaceDimension,
Expand All @@ -18,6 +17,7 @@
from typing_extensions import override

from metricflow_semantics.errors.error_classes import InvalidQuerySyntax
from metricflow_semantics.naming.linkable_spec_name import StructuredLinkableSpecName
from metricflow_semantics.query.group_by_item.filter_spec_resolution.filter_location import WhereFilterLocation
from metricflow_semantics.query.group_by_item.filter_spec_resolution.filter_spec_lookup import (
FilterSpecResolutionLookUp,
Expand Down Expand Up @@ -134,15 +134,19 @@ def __init__( # noqa
spec_resolution_lookup: FilterSpecResolutionLookUp,
where_filter_location: WhereFilterLocation,
rendered_spec_tracker: RenderedSpecTracker,
custom_granularity_names: Sequence[str],
):
self._column_association_resolver = column_association_resolver
self._resolved_spec_lookup = spec_resolution_lookup
self._where_filter_location = where_filter_location
self._rendered_spec_tracker = rendered_spec_tracker
self._custom_granularity_names = custom_granularity_names

def create(self, name: str, entity_path: Sequence[str] = ()) -> WhereFilterDimension:
"""Create a WhereFilterDimension."""
structured_name = DunderedNameFormatter.parse_name(name.lower())
structured_name = StructuredLinkableSpecName.from_name(
name.lower(), custom_granularity_names=self._custom_granularity_names
)

return WhereFilterDimension(
column_association_resolver=self._column_association_resolver,
Expand All @@ -151,5 +155,7 @@ def create(self, name: str, entity_path: Sequence[str] = ()) -> WhereFilterDimen
rendered_spec_tracker=self._rendered_spec_tracker,
element_name=structured_name.element_name,
entity_links=tuple(EntityReference(entity_link_name.lower()) for entity_link_name in entity_path)
+ structured_name.entity_links,
+ tuple(
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

this change is just to support case insensitivity right?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nope, we switched from using DunderedNameFormatter to StructuredLinkableSpecName and that didn't have entity_links: Tuple[EntityReference, ...], but only entity_link_names: Tuple[str, ...]. Now that I think about it, i'll just add a property to make entity_links work in StructuredLinkableSpecName so that we don't need this change

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ooo ok sounds good!

EntityReference(entity_link_name.lower()) for entity_link_name in structured_name.entity_link_names
),
)
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@
from dbt_semantic_interfaces.call_parameter_sets import (
EntityCallParameterSet,
)
from dbt_semantic_interfaces.naming.dundered import DunderedNameFormatter
from dbt_semantic_interfaces.protocols.protocol_hint import ProtocolHint
from dbt_semantic_interfaces.protocols.query_interface import QueryInterfaceEntity, QueryInterfaceEntityFactory
from dbt_semantic_interfaces.references import EntityReference
Expand All @@ -14,6 +13,7 @@
from typing_extensions import override

from metricflow_semantics.errors.error_classes import InvalidQuerySyntax
from metricflow_semantics.naming.linkable_spec_name import StructuredLinkableSpecName
from metricflow_semantics.query.group_by_item.filter_spec_resolution.filter_location import WhereFilterLocation
from metricflow_semantics.query.group_by_item.filter_spec_resolution.filter_spec_lookup import (
FilterSpecResolutionLookUp,
Expand Down Expand Up @@ -103,7 +103,7 @@ def __init__( # noqa

def create(self, entity_name: str, entity_path: Sequence[str] = ()) -> WhereFilterEntity:
"""Create a WhereFilterEntity."""
structured_name = DunderedNameFormatter.parse_name(entity_name.lower())
structured_name = StructuredLinkableSpecName.from_name(entity_name.lower(), custom_granularity_names=())

return WhereFilterEntity(
column_association_resolver=self._column_association_resolver,
Expand All @@ -112,5 +112,7 @@ def create(self, entity_name: str, entity_path: Sequence[str] = ()) -> WhereFilt
rendered_spec_tracker=self._rendered_spec_tracker,
element_name=structured_name.element_name,
entity_links=tuple(EntityReference(entity_link_name.lower()) for entity_link_name in entity_path)
+ structured_name.entity_links,
+ tuple(
EntityReference(entity_link_name.lower()) for entity_link_name in structured_name.entity_link_names
),
)
Original file line number Diff line number Diff line change
Expand Up @@ -70,6 +70,7 @@ def create_from_where_filter_intersection( # noqa: D102
spec_resolution_lookup=self._spec_resolution_lookup,
where_filter_location=filter_location,
rendered_spec_tracker=rendered_spec_tracker,
custom_granularity_names=self._semantic_model_lookup.custom_granularity_names,
)
time_dimension_factory = WhereFilterTimeDimensionFactory(
column_association_resolver=self._column_association_resolver,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -75,13 +75,28 @@ def test_input_str(dunder_naming_scheme: DunderNamingScheme) -> None: # noqa: D
)


def test_input_follows_scheme(dunder_naming_scheme: DunderNamingScheme) -> None: # noqa: D103
assert dunder_naming_scheme.input_str_follows_scheme("listing__country")
assert dunder_naming_scheme.input_str_follows_scheme("listing__creation_time__month")
assert dunder_naming_scheme.input_str_follows_scheme("booking__listing")
assert not dunder_naming_scheme.input_str_follows_scheme("listing__creation_time__extract_month")
assert not dunder_naming_scheme.input_str_follows_scheme("123")
assert not dunder_naming_scheme.input_str_follows_scheme("TimeDimension('metric_time')")
def test_input_follows_scheme( # noqa: D103
dunder_naming_scheme: DunderNamingScheme,
simple_semantic_manifest_lookup: SemanticManifestLookup,
) -> None:
assert dunder_naming_scheme.input_str_follows_scheme(
"listing__country", semantic_manifest_lookup=simple_semantic_manifest_lookup
)
assert dunder_naming_scheme.input_str_follows_scheme(
"listing__creation_time__month", semantic_manifest_lookup=simple_semantic_manifest_lookup
)
assert dunder_naming_scheme.input_str_follows_scheme(
"booking__listing", semantic_manifest_lookup=simple_semantic_manifest_lookup
)
assert not dunder_naming_scheme.input_str_follows_scheme(
"listing__creation_time__extract_month", semantic_manifest_lookup=simple_semantic_manifest_lookup
)
assert not dunder_naming_scheme.input_str_follows_scheme(
"123", semantic_manifest_lookup=simple_semantic_manifest_lookup
)
assert not dunder_naming_scheme.input_str_follows_scheme(
"TimeDimension('metric_time')", semantic_manifest_lookup=simple_semantic_manifest_lookup
)


def test_spec_pattern( # noqa: D103
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,8 +19,12 @@ def test_input_str(metric_naming_scheme: MetricNamingScheme) -> None: # noqa: D
assert metric_naming_scheme.input_str(MetricSpec(element_name="bookings")) == "bookings"


def test_input_follows_scheme(metric_naming_scheme: MetricNamingScheme) -> None: # noqa: D103
assert metric_naming_scheme.input_str_follows_scheme("listings")
def test_input_follows_scheme( # noqa: D103
metric_naming_scheme: MetricNamingScheme, simple_semantic_manifest_lookup: SemanticManifestLookup
) -> None:
assert metric_naming_scheme.input_str_follows_scheme(
"listings", semantic_manifest_lookup=simple_semantic_manifest_lookup
)


def test_spec_pattern( # noqa: D103
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -64,20 +64,30 @@ def test_input_str(object_builder_naming_scheme: ObjectBuilderNamingScheme) -> N
)


def test_input_follows_scheme(object_builder_naming_scheme: ObjectBuilderNamingScheme) -> None: # noqa: D103
def test_input_follows_scheme( # noqa: D103
object_builder_naming_scheme: ObjectBuilderNamingScheme, simple_semantic_manifest_lookup: SemanticManifestLookup
) -> None:
assert object_builder_naming_scheme.input_str_follows_scheme(
"Dimension('listing__country', entity_path=['booking'])"
"Dimension('listing__country', entity_path=['booking'])",
semantic_manifest_lookup=simple_semantic_manifest_lookup,
)
assert object_builder_naming_scheme.input_str_follows_scheme(
"TimeDimension('listing__creation_time', time_granularity_name='month', date_part_name='day', "
"entity_path=['booking'])"
"entity_path=['booking'])",
semantic_manifest_lookup=simple_semantic_manifest_lookup,
)
assert object_builder_naming_scheme.input_str_follows_scheme(
"Entity('user', entity_path=['booking', 'listing'])",
"Entity('user', entity_path=['booking', 'listing'])", semantic_manifest_lookup=simple_semantic_manifest_lookup
)
assert not object_builder_naming_scheme.input_str_follows_scheme(
"listing__creation_time__extract_month", semantic_manifest_lookup=simple_semantic_manifest_lookup
)
assert not object_builder_naming_scheme.input_str_follows_scheme(
"123", semantic_manifest_lookup=simple_semantic_manifest_lookup
)
assert not object_builder_naming_scheme.input_str_follows_scheme(
"NotADimension('listing__country')", semantic_manifest_lookup=simple_semantic_manifest_lookup
)
assert not object_builder_naming_scheme.input_str_follows_scheme("listing__creation_time__extract_month")
assert not object_builder_naming_scheme.input_str_follows_scheme("123")
assert not object_builder_naming_scheme.input_str_follows_scheme("NotADimension('listing__country')")


def test_spec_pattern( # noqa: D103
Expand Down
Loading
Loading