Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Allow use of sources as unit testing inputs #9059

Merged
merged 5 commits into from
Nov 15, 2023
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions .changes/unreleased/Features-20231111-191150.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
kind: Features
body: Support source inputs in unit tests
time: 2023-11-11T19:11:50.870494-05:00
custom:
Author: gshank
Issue: "8507"
18 changes: 14 additions & 4 deletions core/dbt/adapters/base/relation.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,13 @@
from dataclasses import dataclass, field
from typing import Optional, TypeVar, Any, Type, Dict, Iterator, Tuple, Set, Union, FrozenSet

from dbt.contracts.graph.nodes import SourceDefinition, ManifestNode, ResultNode, ParsedNode
from dbt.contracts.graph.nodes import (
SourceDefinition,
ManifestNode,
ResultNode,
ParsedNode,
UnitTestSourceDefinition,
)
from dbt.contracts.relation import (
RelationType,
ComponentName,
Expand Down Expand Up @@ -201,7 +207,9 @@ def quoted(self, identifier):
)

@classmethod
def create_from_source(cls: Type[Self], source: SourceDefinition, **kwargs: Any) -> Self:
def create_from_source(
cls: Type[Self], source: Union[SourceDefinition, UnitTestSourceDefinition], **kwargs: Any
) -> Self:
source_quoting = source.quoting.to_dict(omit_none=True)
source_quoting.pop("column", None)
quote_policy = deep_merge(
Expand Down Expand Up @@ -263,8 +271,10 @@ def create_from(
node: ResultNode,
**kwargs: Any,
) -> Self:
if node.resource_type == NodeType.Source:
if not isinstance(node, SourceDefinition):
if node.resource_type == NodeType.Source or isinstance(node, UnitTestSourceDefinition):
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can we simplify the logic here?

Copy link
Contributor Author

@gshank gshank Nov 13, 2023

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

How? We can't set the resource_type to Source because that breaks execution.

if not (
isinstance(node, SourceDefinition) or isinstance(node, UnitTestSourceDefinition)
):
raise DbtInternalError(
"type mismatch, expected SourceDefinition but got {}".format(type(node))
)
Expand Down
25 changes: 24 additions & 1 deletion core/dbt/context/providers.py
Original file line number Diff line number Diff line change
Expand Up @@ -606,6 +606,29 @@
return self.Relation.create_from_source(target_source)


class RuntimeUnitTestSourceResolver(RuntimeSourceResolver):
gshank marked this conversation as resolved.
Show resolved Hide resolved
def resolve(self, source_name: str, table_name: str):
target_source = self.manifest.resolve_source(
source_name,
table_name,
self.current_project,
self.model.package_name,
)
if target_source is None or isinstance(target_source, Disabled):
raise TargetNotFoundError(

Check warning on line 618 in core/dbt/context/providers.py

View check run for this annotation

Codecov / codecov/patch

core/dbt/context/providers.py#L618

Added line #L618 was not covered by tests
node=self.model,
target_name=f"{source_name}.{table_name}",
target_kind="source",
disabled=(isinstance(target_source, Disabled)),
)
# For unit tests, this isn't a "real" source, it's a ModelNode taking
# the place of a source. We don't really need to return the relation here,
# we just need to set_cte, but skipping it confuses typing. We *do* need
# the relation in the "this" property.
self.model.set_cte(target_source.unique_id, None)
return self.Relation.create_ephemeral_from_node(self.config, target_source)


# metric` implementations
class ParseMetricResolver(BaseMetricResolver):
def resolve(self, name: str, package: Optional[str] = None) -> MetricReference:
Expand Down Expand Up @@ -746,7 +769,7 @@
DatabaseWrapper = RuntimeDatabaseWrapper
Var = UnitTestVar
ref = RuntimeUnitTestRefResolver
source = RuntimeSourceResolver # TODO: RuntimeUnitTestSourceResolver
source = RuntimeUnitTestSourceResolver
metric = RuntimeMetricResolver


Expand Down
12 changes: 11 additions & 1 deletion core/dbt/contracts/graph/nodes.py
Original file line number Diff line number Diff line change
Expand Up @@ -1062,9 +1062,19 @@ def test_node_type(self):
return "generic"


@dataclass
class UnitTestSourceDefinition(ModelNode):
source_name: str = "undefined"
quoting: Quoting = field(default_factory=Quoting)

@property
def search_name(self):
return f"{self.source_name}.{self.name}"


@dataclass
class UnitTestNode(CompiledNode):
resource_type: NodeType = field(metadata={"restrict": [NodeType.Unit]})
resource_type: Literal[NodeType.Unit]
tested_node_unique_id: Optional[str] = None
this_input_node_unique_id: Optional[str] = None
overrides: Optional[UnitTestOverrides] = None
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,32 +3,31 @@
{% set default_row = {} %}

{%- if not column_name_to_data_types -%}
{%- set columns_in_relation = adapter.get_columns_in_relation(this) -%}
{%- set column_name_to_data_types = {} -%}
{%- for column in columns_in_relation -%}
{%- do column_name_to_data_types.update({column.name: column.dtype}) -%}
{%- endfor -%}
{%- set columns_in_relation = adapter.get_columns_in_relation(this) -%}
{%- set column_name_to_data_types = {} -%}
{%- for column in columns_in_relation -%}
{%- do column_name_to_data_types.update({column.name: column.dtype}) -%}
{%- endfor -%}
{%- endif -%}

{%- if not column_name_to_data_types -%}
{{ exceptions.raise_compiler_error("columns not available for" ~ model.name) }}
{{ exceptions.raise_compiler_error("columns not available for " ~ model.name) }}
{%- endif -%}

{%- for column_name, column_type in column_name_to_data_types.items() -%}
{%- do default_row.update({column_name: (safe_cast("null", column_type) | trim )}) -%}
{%- endfor -%}

{%- for row in rows -%}
{%- do format_row(row, column_name_to_data_types) -%}

{%- set default_row_copy = default_row.copy() -%}
{%- do default_row_copy.update(row) -%}
{%- do format_row(row, column_name_to_data_types) -%}
{%- set default_row_copy = default_row.copy() -%}
{%- do default_row_copy.update(row) -%}
select
{%- for column_name, column_value in default_row_copy.items() %} {{ column_value }} AS {{ column_name }}{% if not loop.last -%}, {%- endif %}
{%- endfor %}
{%- if not loop.last %}
{%- for column_name, column_value in default_row_copy.items() %} {{ column_value }} AS {{ column_name }}{% if not loop.last -%}, {%- endif %}
{%- endfor %}
{%- if not loop.last %}
union all
{% endif %}
{% endif %}
{%- endfor -%}

{%- if (rows | length) == 0 -%}
Expand Down
98 changes: 56 additions & 42 deletions core/dbt/parser/unit_tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
UnitTestDefinition,
DependsOn,
UnitTestConfig,
UnitTestSourceDefinition,
)
from dbt.contracts.graph.unparsed import UnparsedUnitTest
from dbt.exceptions import ParsingError, InvalidUnitTestGivenInput
Expand Down Expand Up @@ -105,44 +106,52 @@ def parse_unit_test_case(self, test_case: UnitTestDefinition):
- {id: 2, b: 2}
"""
# Add the model "input" nodes, consisting of all referenced models in the unit test.
# This creates a model for every input in every test, so there may be multiple
# input models substituting for the same input ref'd model.
# This creates an ephemeral model for every input in every test, so there may be multiple
# input models substituting for the same input ref'd model. Note that since these are
# always "ephemeral" they just wrap the tested_node SQL in additional CTEs. No actual table
# or view is created.
for given in test_case.given:
# extract the original_input_node from the ref in the "input" key of the given list
original_input_node = self._get_original_input_node(given.input, tested_node)

original_input_node_columns = None
if (
original_input_node.resource_type == NodeType.Model
and original_input_node.config.contract.enforced
):
original_input_node_columns = {
column.name: column.data_type for column in original_input_node.columns
}
MichelleArk marked this conversation as resolved.
Show resolved Hide resolved

# TODO: include package_name?
input_name = f"{unit_test_node.name}__{original_input_node.name}"
input_unique_id = f"model.{package_name}.{input_name}"
input_node = ModelNode(
raw_code=self._build_fixture_raw_code(
given.get_rows(
self.root_project.project_root, self.root_project.fixture_paths
),
original_input_node_columns,
project_root = self.root_project.project_root
common_fields = {
"resource_type": NodeType.Model,
"package_name": package_name,
"original_file_path": original_input_node.original_file_path,
"unique_id": f"model.{package_name}.{input_name}",
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think we may need to include source_name in input_name to avoid clobbering sources with the same table_name but different source_names when they are inserted to manifest.nodes and manifest.sources below.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Good point.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I've changed this to include the source_name. This does make for pretty long unique_ids. Do we have any concerns about that? It's not like we're using that name to construct tables or anything...

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

So far I've noticed this issue creep up in #9015. I think we could shorten the node name for CTE generation (since it doesn't need to be unique) but keep the unique_id longer

"config": ModelConfig(materialized="ephemeral"),
"database": original_input_node.database,
"alias": original_input_node.identifier,
"schema": original_input_node.schema,
"fqn": original_input_node.fqn,
"checksum": FileHash.empty(),
"raw_code": self._build_fixture_raw_code(
given.get_rows(project_root, self.root_project.fixture_paths), None
),
resource_type=NodeType.Model,
package_name=package_name,
path=original_input_node.path,
original_file_path=original_input_node.original_file_path,
unique_id=input_unique_id,
name=input_name,
config=ModelConfig(materialized="ephemeral"),
database=original_input_node.database,
schema=original_input_node.schema,
alias=original_input_node.alias,
fqn=input_unique_id.split("."),
checksum=FileHash.empty(),
)
}

if original_input_node.resource_type == NodeType.Model:
input_node = ModelNode(
**common_fields,
name=input_name,
path=original_input_node.path,
)
elif original_input_node.resource_type == NodeType.Source:
# We are reusing the database/schema/identifier from the original source,
# but that shouldn't matter since this acts as an ephemeral model which just
# wraps a CTE around the unit test node.
input_node = UnitTestSourceDefinition(
**common_fields,
name=original_input_node.name, # must be the same name for source lookup to work
path=input_name + ".sql", # for writing out compiled_code
source_name=original_input_node.source_name, # needed for source lookup
)
# Sources need to go in the sources dictionary in order to create the right lookup
self.unit_test_manifest.sources[input_node.unique_id] = input_node # type: ignore
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

do we anticipate any issues by having the sources dictionary contain a unique_id key that is prefixed with model instead of source here?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It doesn't seem to care. We don't actually check the unique_id prefix that I can recall. If somebody starts parsing the unit_test_manifest, I suppose it might be confusing. But right now we're putting it in two places, so one of them will be wrong.

Copy link
Contributor

@MichelleArk MichelleArk Nov 14, 2023

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This probably isn't worth spending tons of time on.. but I think it could be possible to get around having to add the node to manifest.sources and do the lookup from the .nodes collection in UnitTestRuntimeSourceResolver since the unique_id will include source_name. kind of like what's done here: https://github.com/dbt-labs/dbt-core/blob/unit_testing_feature_branch/core/dbt/context/providers.py#L578

Not entirely sure what's more readable or less complex in this case. I can imagine having to maintain UnitTestSourceDefinitions across both dictionaries could be error-prone though..

But right now we're putting it in two places, so one of them will be wrong.

Given that UnitTestSourceDefinition is a ModelNode, I think having it in nodes is 'more' correct

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think the lookup behavior of sources and nodes is subtly different with regard to the meaning of package=None, so I don't think looking up sources as though they were nodes is worth it.


# Both ModelNode and UnitTestSourceDefinition need to go in nodes dictionary
Copy link
Contributor

@MichelleArk MichelleArk Nov 14, 2023

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

for my own understanding - is this to enable cte injection?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yeah. There's code in compilation.py that looks up the existence of the cte in the nodes dictionary: if cte.id not in manifest.nodes:.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

In theory we could also check for a UnitTestSourceDefinition and in sources, but that didn't feel like an improvement.

self.unit_test_manifest.nodes[input_node.unique_id] = input_node

# Populate this_input_node_unique_id if input fixture represents node being tested
Expand All @@ -153,6 +162,8 @@ def parse_unit_test_case(self, test_case: UnitTestDefinition):
unit_test_node.depends_on.nodes.append(input_node.unique_id)

def _build_fixture_raw_code(self, rows, column_name_to_data_types) -> str:
# We're not currently using column_name_to_data_types, but leaving here for
# possible future use.
return ("{{{{ get_fixture_sql({rows}, {column_name_to_data_types}) }}}}").format(
rows=rows, column_name_to_data_types=column_name_to_data_types
)
Expand All @@ -178,18 +189,21 @@ def _get_original_input_node(self, input: str, tested_node: ModelNode):
raise InvalidUnitTestGivenInput(input=input)

if statically_parsed["refs"]:
for ref in statically_parsed["refs"]:
name = ref.get("name")
package = ref.get("package")
version = ref.get("version")
# TODO: disabled lookup, versioned lookup, public models
original_input_node = self.manifest.ref_lookup.find(
name, package, version, self.manifest
)
ref = list(statically_parsed["refs"])[0]
name = ref.get("name")
package = ref.get("package")
version = ref.get("version")
# TODO: disabled lookup, versioned lookup, public models
original_input_node = self.manifest.ref_lookup.find(
name, package, version, self.manifest
)
elif statically_parsed["sources"]:
input_package_name, input_source_name = statically_parsed["sources"][0]
source = list(statically_parsed["sources"])[0]
input_source_name, input_name = source
original_input_node = self.manifest.source_lookup.find(
input_source_name, input_package_name, self.manifest
f"{input_source_name}.{input_name}",
self.root_project.project_name,
gshank marked this conversation as resolved.
Show resolved Hide resolved
self.manifest,
)
else:
raise InvalidUnitTestGivenInput(input=input)
Expand Down
69 changes: 69 additions & 0 deletions tests/functional/unit_testing/test_ut_sources.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,69 @@
import pytest
from dbt.tests.util import run_dbt

raw_customers_csv = """id,first_name,last_name,email
1,Michael,Perez,mperez0@chronoengine.com
2,Shawn,Mccoy,smccoy1@reddit.com
3,Kathleen,Payne,kpayne2@cargocollective.com
4,Jimmy,Cooper,jcooper3@cargocollective.com
5,Katherine,Rice,krice4@typepad.com
6,Sarah,Ryan,sryan5@gnu.org
7,Martin,Mcdonald,mmcdonald6@opera.com
8,Frank,Robinson,frobinson7@wunderground.com
9,Jennifer,Franklin,jfranklin8@mail.ru
10,Henry,Welch,hwelch9@list-manage.com
"""

schema_sources_yml = """
sources:
- name: seed_sources
schema: "{{ target.schema }}"
tables:
- name: raw_customers
columns:
- name: id
tests:
- not_null:
severity: "{{ 'error' if target.name == 'prod' else 'warn' }}"
- unique
- name: first_name
- name: last_name
- name: email
unit_tests:
- name: test_customers
model: customers
given:
- input: source('seed_sources', 'raw_customers')
rows:
- {id: 1, first_name: Emily}
expect:
rows:
- {id: 1, first_name: Emily}
"""

customers_sql = """
select * from {{ source('seed_sources', 'raw_customers') }}
"""


class TestUnitTestSourceInput:
@pytest.fixture(scope="class")
def seeds(self):
return {
"raw_customers.csv": raw_customers_csv,
}

@pytest.fixture(scope="class")
def models(self):
return {
"customers.sql": customers_sql,
"sources.yml": schema_sources_yml,
}

def test_source_input(self, project):
results = run_dbt(["seed"])
results = run_dbt(["run"])
len(results) == 1

results = run_dbt(["unit-test"])
assert len(results) == 1