Skip to content

Commit

Permalink
refactor into RunTask.handle_microbatch_model
Browse files — browse the repository at this point in the history
  • Loading branch information
MichelleArk committed Nov 15, 2024
1 parent 69c088d commit 109fbd4
Showing 1 changed file with 49 additions and 31 deletions.
80 changes: 49 additions & 31 deletions core/dbt/task/run.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,19 @@
from copy import deepcopy
from dataclasses import asdict, field
from datetime import datetime
from typing import AbstractSet, Any, Dict, Iterable, List, Optional, Set, Tuple, Type
from multiprocessing.pool import ThreadPool
from typing import (
AbstractSet,
Any,
Callable,
Dict,
Iterable,
List,
Optional,
Set,
Tuple,
Type,
)

from dbt import tracking, utils
from dbt.adapters.base import BaseAdapter, BaseRelation
Expand Down Expand Up @@ -674,40 +686,46 @@ def handle_job_queue(self, pool, callback):
runner.do_skip(cause=cause)

if isinstance(runner, MicrobatchModelRunner):
# Initial run computes batch metadata
result = self.call_runner(runner)
batch_results: List[RunResult] = []

# execute batches serially until a relation exists
relation_exists = runner.relation_exists
batch_idx = 0
while batch_idx < len(runner.batches):
batch_runner = MicrobatchModelRunner(
self.config, runner.adapter, deepcopy(node), self.run_count, self.num_nodes
)
batch_runner.set_batch_idx(batch_idx)
batch_runner.set_relation_exists(relation_exists)
batch_runner.set_batches(runner.batches)
self.handle_microbatch_model(runner, pool, callback)
else:
args = [runner]
self._submit(pool, args, callback)

if relation_exists:
self._submit(pool, [batch_runner], batch_results.append)
else:
batch_results.append(self.call_runner(batch_runner))
relation_exists = batch_runner.relation_exists
def handle_microbatch_model(
    self, runner: MicrobatchModelRunner, pool: ThreadPool, callback: Callable
) -> None:
    """Run a microbatch model: execute each batch, then report the merged result.

    The initial ``call_runner`` invocation computes the batch metadata
    (``runner.batches``) without executing the batches themselves. Batches
    are then executed serially until ``relation_exists`` becomes true
    (presumably once the first batch has materialized the target relation —
    TODO confirm against MicrobatchModelRunner); after that, the remaining
    batches are submitted to the thread pool and run in parallel.

    :param runner: microbatch runner holding the node and batch metadata
    :param pool: thread pool used for parallel batch execution
    :param callback: invoked with the merged result once all batches finish
    """
    import time  # local import: only needed for the completion wait below

    # Initial run computes batch metadata; no batch is executed yet.
    result = self.call_runner(runner)
    batch_results: List[RunResult] = []

    # Execute batches serially until a relation exists, at which point
    # future batches are run in parallel.
    relation_exists = runner.relation_exists
    batch_idx = 0
    while batch_idx < len(runner.batches):
        batch_runner = MicrobatchModelRunner(
            self.config, runner.adapter, deepcopy(runner.node), self.run_count, self.num_nodes
        )
        batch_runner.set_batch_idx(batch_idx)
        batch_runner.set_relation_exists(relation_exists)
        batch_runner.set_batches(runner.batches)

        if relation_exists:
            # Parallel path: the pool thread appends its result when done.
            # list.append is atomic under the GIL, so no extra locking is needed.
            self._submit(pool, [batch_runner], batch_results.append)
        else:
            # Serial path: run inline, then re-check whether this batch made
            # the relation exist so subsequent batches may go parallel.
            batch_results.append(self.call_runner(batch_runner))
            relation_exists = batch_runner.relation_exists

        batch_idx += 1

    # Wait until all batches have completed. Sleep briefly between checks
    # instead of busy-spinning (`pass`), which would pin a CPU core while
    # pool threads finish their batches.
    while len(batch_results) != len(runner.batches):
        time.sleep(0.01)

    runner.merge_batch_results(result, batch_results)
    track_model_run(runner.node_index, runner.num_nodes, result, adapter=runner.adapter)
    runner.print_result_line(result)

    callback(result)

def _hook_keyfunc(self, hook: HookNode) -> Tuple[str, Optional[int]]:
package_name = hook.package_name
Expand Down

0 comments on commit 109fbd4

Please sign in to comment.