From 95022730a4492251c5e94fedd10f1c5876ba192a Mon Sep 17 00:00:00 2001 From: Quigley Malcolm Date: Mon, 30 Sep 2024 15:06:18 -0500 Subject: [PATCH] When retrying microbatch models, propagate prior successful state --- core/dbt/contracts/graph/nodes.py | 4 ++-- core/dbt/task/retry.py | 12 +++++++++--- core/dbt/task/run.py | 18 +++++++++++++----- 3 files changed, 24 insertions(+), 10 deletions(-) diff --git a/core/dbt/contracts/graph/nodes.py b/core/dbt/contracts/graph/nodes.py index 2cdd25506a6..d5ef3d51174 100644 --- a/core/dbt/contracts/graph/nodes.py +++ b/core/dbt/contracts/graph/nodes.py @@ -60,7 +60,7 @@ from dbt.artifacts.resources import SqlOperation as SqlOperationResource from dbt.artifacts.resources import TimeSpine from dbt.artifacts.resources import UnitTestDefinition as UnitTestDefinitionResource -from dbt.artifacts.schemas.batch_results import BatchType +from dbt.artifacts.schemas.batch_results import BatchResults from dbt.contracts.graph.model_config import UnitTestNodeConfig from dbt.contracts.graph.node_args import ModelNodeArgs from dbt.contracts.graph.unparsed import ( @@ -454,7 +454,7 @@ def resource_class(cls) -> Type[HookNodeResource]: @dataclass class ModelNode(ModelResource, CompiledNode): - batches: Optional[List[BatchType]] = None + batch_info: Optional[BatchResults] = None @classmethod def resource_class(cls) -> Type[ModelResource]: diff --git a/core/dbt/task/retry.py b/core/dbt/task/retry.py index 4f08804d191..9b3c3874718 100644 --- a/core/dbt/task/retry.py +++ b/core/dbt/task/retry.py @@ -131,11 +131,17 @@ def run(self): ) } + # We need this so that re-running of a microbatch model will only rerun + # batches that previously failed. Note we _explicitly_ do not pass the + # batch info if there were _no_ successful batches previously. 
This is + # because passing the batch info _forces_ the microbatch process into + # _incremental_ mode, and it may be that we need to be in full refresh + # mode which is only handled if batch_info _isn't_ passed for a node batch_map = { - result.unique_id: result.batch_results.failed + result.unique_id: result.batch_results for result in self.previous_results.results - if result.status == NodeStatus.PartialSuccess - and result.batch_results is not None + if result.batch_results is not None + and len(result.batch_results.successful) != 0 and len(result.batch_results.failed) > 0 and not ( self.previous_command_name != "run-operation" diff --git a/core/dbt/task/run.py b/core/dbt/task/run.py index 673a400ec02..aabed0e2fb9 100644 --- a/core/dbt/task/run.py +++ b/core/dbt/task/run.py @@ -1,6 +1,7 @@ import functools import os import threading +from copy import deepcopy from datetime import datetime from typing import AbstractSet, Any, Dict, Iterable, List, Optional, Set, Tuple, Type @@ -327,6 +328,13 @@ def _build_run_microbatch_model_result( status = RunStatus.PartialSuccess msg = f"PARTIAL SUCCESS ({num_successes}/{num_successes + num_failures})" + if model.batch_info is not None: + new_batch_results = deepcopy(model.batch_info) + new_batch_results.failed = [] + new_batch_results = new_batch_results + batch_results + else: + new_batch_results = batch_results + return RunResult( node=model, status=status, @@ -470,7 +478,7 @@ def _execute_microbatch_materialization( ) -> List[RunResult]: batch_results: List[RunResult] = [] - if model.batches is None: + if model.batch_info is None: microbatch_builder = MicrobatchBuilder( model=model, is_incremental=self._is_incremental(model), @@ -481,8 +489,8 @@ def _execute_microbatch_materialization( start = microbatch_builder.build_start_time(end) batches = microbatch_builder.build_batches(start, end) else: - batches = model.batches - # if there are batches, then don't run as full_refresh and do force is_incremental + batches = 
model.batch_info.failed + # if there is batch info, then don't run as full_refresh and do force is_incremental # not doing this risks blowing away the work that has already been done if self._has_relation(model=model): context["is_incremental"] = lambda: True @@ -567,7 +575,7 @@ def __init__( args: Flags, config: RuntimeConfig, manifest: Manifest, - batch_map: Optional[Dict[str, List[BatchType]]] = None, + batch_map: Optional[Dict[str, BatchResults]] = None, ) -> None: super().__init__(args, config, manifest) self.batch_map = batch_map @@ -709,7 +717,7 @@ def populate_microbatch_batches(self, selected_uids: AbstractSet[str]): if uid in self.batch_map: node = self.manifest.ref_lookup.perform_lookup(uid, self.manifest) if isinstance(node, ModelNode): - node.batches = self.batch_map[uid] + node.batch_info = self.batch_map[uid] def before_run(self, adapter: BaseAdapter, selected_uids: AbstractSet[str]) -> RunStatus: with adapter.connection_named("master"):