diff --git a/metricflow-semantics/metricflow_semantics/query/query_resolver.py b/metricflow-semantics/metricflow_semantics/query/query_resolver.py index c12b95f1c1..87e69a24d2 100644 --- a/metricflow-semantics/metricflow_semantics/query/query_resolver.py +++ b/metricflow-semantics/metricflow_semantics/query/query_resolver.py @@ -492,8 +492,12 @@ def _resolve_query(self, resolver_input_for_query: ResolverInputForQuery) -> Met resolution_dag=resolution_dag, resolver_input_for_query=resolver_input_for_query, validation_rules=( - MetricTimeQueryValidationRule(self._manifest_lookup, resolver_input_for_query), - DuplicateMetricValidationRule(self._manifest_lookup, resolver_input_for_query), + MetricTimeQueryValidationRule( + self._manifest_lookup, resolver_input_for_query, resolve_group_by_item_result + ), + DuplicateMetricValidationRule( + self._manifest_lookup, resolver_input_for_query, resolve_group_by_item_result + ), ), ) diff --git a/metricflow-semantics/metricflow_semantics/query/validation_rules/base_validation_rule.py b/metricflow-semantics/metricflow_semantics/query/validation_rules/base_validation_rule.py index ec3587472f..58ee5f852e 100644 --- a/metricflow-semantics/metricflow_semantics/query/validation_rules/base_validation_rule.py +++ b/metricflow-semantics/metricflow_semantics/query/validation_rules/base_validation_rule.py @@ -1,5 +1,6 @@ from __future__ import annotations +import typing from abc import ABC, abstractmethod from typing import Sequence @@ -11,15 +12,22 @@ from metricflow_semantics.query.issues.issues_base import MetricFlowQueryResolutionIssueSet from metricflow_semantics.query.resolver_inputs.query_resolver_inputs import ResolverInputForQuery +if typing.TYPE_CHECKING: + from metricflow_semantics.query.query_resolver import ResolveGroupByItemsResult + class PostResolutionQueryValidationRule(ABC): """A validation rule that runs after all query inputs have been resolved to specs.""" def __init__( # noqa: D107 - self, manifest_lookup: SemanticManifestLookup, resolver_input_for_query: ResolverInputForQuery + self, + manifest_lookup: SemanticManifestLookup, + resolver_input_for_query: ResolverInputForQuery, + resolve_group_by_item_result: ResolveGroupByItemsResult, ) -> None: self._manifest_lookup = manifest_lookup self._resolver_input_for_query = resolver_input_for_query + self._resolve_group_by_item_result = resolve_group_by_item_result @abstractmethod def validate_metric_in_resolution_dag( diff --git a/metricflow-semantics/metricflow_semantics/query/validation_rules/metric_time_requirements.py b/metricflow-semantics/metricflow_semantics/query/validation_rules/metric_time_requirements.py index 95fbf4ecba..216b6e8bf1 100644 --- a/metricflow-semantics/metricflow_semantics/query/validation_rules/metric_time_requirements.py +++ b/metricflow-semantics/metricflow_semantics/query/validation_rules/metric_time_requirements.py @@ -1,5 +1,6 @@ from __future__ import annotations +import typing from dataclasses import dataclass from typing import List, Sequence, Tuple @@ -29,13 +30,14 @@ from metricflow_semantics.query.issues.parsing.scd_requires_metric_time import ( ScdRequiresMetricTimeIssue, ) -from metricflow_semantics.query.resolver_inputs.query_resolver_inputs import ( - ResolverInputForQuery, -) +from metricflow_semantics.query.resolver_inputs.query_resolver_inputs import ResolverInputForQuery from metricflow_semantics.query.validation_rules.base_validation_rule import PostResolutionQueryValidationRule from metricflow_semantics.specs.instance_spec import InstanceSpec from metricflow_semantics.specs.time_dimension_spec import TimeDimensionSpec +if typing.TYPE_CHECKING: + from metricflow_semantics.query.query_resolver import ResolveGroupByItemsResult + @dataclass(frozen=True) class QueryItemsAnalysis: @@ -57,9 +59,16 @@ class MetricTimeQueryValidationRule(PostResolutionQueryValidationRule): """ def __init__( # noqa: D107 - self, manifest_lookup: SemanticManifestLookup, resolver_input_for_query: ResolverInputForQuery + self, + manifest_lookup: SemanticManifestLookup, + resolver_input_for_query: ResolverInputForQuery, + resolve_group_by_item_result: ResolveGroupByItemsResult, ) -> None: - super().__init__(manifest_lookup=manifest_lookup, resolver_input_for_query=resolver_input_for_query) + super().__init__( + manifest_lookup=manifest_lookup, + resolver_input_for_query=resolver_input_for_query, + resolve_group_by_item_result=resolve_group_by_item_result, + ) self._metric_time_specs = tuple( TimeDimensionSpec.generate_possible_specs_for_time_dimension(