diff --git a/core/dbt/task/run.py b/core/dbt/task/run.py
index 4c431ec227..45c5a227f2 100644
--- a/core/dbt/task/run.py
+++ b/core/dbt/task/run.py
@@ -4,7 +4,19 @@
 from copy import deepcopy
 from dataclasses import asdict, field
 from datetime import datetime
-from typing import AbstractSet, Any, Dict, Iterable, List, Optional, Set, Tuple, Type
+from multiprocessing.pool import ThreadPool
+from typing import (
+    AbstractSet,
+    Any,
+    Callable,
+    Dict,
+    Iterable,
+    List,
+    Optional,
+    Set,
+    Tuple,
+    Type,
+)
 
 from dbt import tracking, utils
 from dbt.adapters.base import BaseAdapter, BaseRelation
@@ -674,40 +686,46 @@ def handle_job_queue(self, pool, callback):
             runner.do_skip(cause=cause)
 
         if isinstance(runner, MicrobatchModelRunner):
-            # Initial run computes batch metadata
-            result = self.call_runner(runner)
-            batch_results: List[RunResult] = []
-
-            # execute batches serially until a relation exists
-            relation_exists = runner.relation_exists
-            batch_idx = 0
-            while batch_idx < len(runner.batches):
-                batch_runner = MicrobatchModelRunner(
-                    self.config, runner.adapter, deepcopy(node), self.run_count, self.num_nodes
-                )
-                batch_runner.set_batch_idx(batch_idx)
-                batch_runner.set_relation_exists(relation_exists)
-                batch_runner.set_batches(runner.batches)
+            self.handle_microbatch_model(runner, pool, callback)
+        else:
+            args = [runner]
+            self._submit(pool, args, callback)
 
-                if relation_exists:
-                    self._submit(pool, [batch_runner], batch_results.append)
-                else:
-                    batch_results.append(self.call_runner(batch_runner))
-                    relation_exists = batch_runner.relation_exists
+    def handle_microbatch_model(
+        self, runner: MicrobatchModelRunner, pool: ThreadPool, callback: Callable
+    ):
+        # Initial run computes batch metadata
+        result = self.call_runner(runner)
+        batch_results: List[RunResult] = []
+
+        # Execute batches serially until a relation exists, at which point future batches are run in parallel
+        relation_exists = runner.relation_exists
+        batch_idx = 0
+        while batch_idx < len(runner.batches):
+            batch_runner = MicrobatchModelRunner(
+                self.config, runner.adapter, deepcopy(runner.node), self.run_count, self.num_nodes
+            )
+            batch_runner.set_batch_idx(batch_idx)
+            batch_runner.set_relation_exists(relation_exists)
+            batch_runner.set_batches(runner.batches)
 
-                batch_idx += 1
+            if relation_exists:
+                self._submit(pool, [batch_runner], batch_results.append)
+            else:
+                batch_results.append(self.call_runner(batch_runner))
+                relation_exists = batch_runner.relation_exists
 
-            # wait until all batches have completed
-            while len(batch_results) != len(runner.batches):
-                pass
+            batch_idx += 1
 
-            runner.merge_batch_results(result, batch_results)
-            track_model_run(runner.node_index, runner.num_nodes, result, adapter=runner.adapter)
-            runner.print_result_line(result)
-            callback(result)
-        else:
-            args = [runner]
-            self._submit(pool, args, callback)
+        # Wait until all batches have completed
+        while len(batch_results) != len(runner.batches):
+            pass
+
+        runner.merge_batch_results(result, batch_results)
+        track_model_run(runner.node_index, runner.num_nodes, result, adapter=runner.adapter)
+        runner.print_result_line(result)
+
+        callback(result)
 
     def _hook_keyfunc(self, hook: HookNode) -> Tuple[str, Optional[int]]:
         package_name = hook.package_name
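
For readers following the control flow above, here is a standalone sketch (not dbt code) of the dispatch pattern `handle_microbatch_model` uses: batches run serially until the target relation exists, after which the remaining batches are submitted to a `ThreadPool` and collected through a callback. `Batch`, `run_batch`, `handle_batches`, and `collect` are hypothetical stand-ins for the runner machinery; the real code threads `relation_exists` through each `MicrobatchModelRunner` and busy-waits on the result count, whereas this sketch signals completion with a `threading.Event`.

```python
import threading
from dataclasses import dataclass
from multiprocessing.pool import ThreadPool
from typing import List


@dataclass
class Batch:
    """Hypothetical stand-in for one microbatch of a model."""

    idx: int


def run_batch(batch: Batch) -> str:
    # Stand-in for MicrobatchModelRunner executing a single batch.
    return f"batch {batch.idx} done"


def handle_batches(batches: List[Batch], pool: ThreadPool) -> List[str]:
    results: List[str] = []
    lock = threading.Lock()
    all_done = threading.Event()

    def collect(result: str) -> None:
        # Callback fired from pool worker threads; guard the shared list.
        with lock:
            results.append(result)
            if len(results) == len(batches):
                all_done.set()

    relation_exists = False
    for batch in batches:
        if relation_exists:
            # Relation already materialized: later batches can insert concurrently.
            pool.apply_async(run_batch, (batch,), callback=collect)
        else:
            # No relation yet: run serially so the first batch creates it.
            collect(run_batch(batch))
            relation_exists = True  # assume the first completed batch creates the relation

    # Block until every async callback has fired. The diff busy-waits on
    # len(batch_results) instead; an Event achieves the same without spinning.
    all_done.wait()
    return results


if __name__ == "__main__":
    with ThreadPool(4) as pool:
        print(handle_batches([Batch(i) for i in range(5)], pool))
```

The serial-first phase matters because the first batch must create the target table before concurrent batches can safely insert into it; once `relation_exists` flips, ordering between the remaining batches no longer affects correctness, so they can fan out across the pool. Note that with parallel submission the completion order of `results` is nondeterministic, which is why the diff merges batch results into a single `RunResult` only after all batches have reported back.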