Skip to content

Commit

Permalink
No public description
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 634060537
  • Loading branch information
genehwung authored and ml metrics authors committed May 15, 2024
1 parent 5a07a0e commit 33e90cc
Show file tree
Hide file tree
Showing 4 changed files with 51 additions and 28 deletions.
4 changes: 2 additions & 2 deletions ml_metrics/_src/chainables/courier_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,8 +64,8 @@ def next_batch_from_generator():
' previously.'
)
result = [next(self._generator) for _ in range(self._generator.data_size)]
if not result and self._generator.exhausted:
result = lazy_fns.STOP_ITERATION
if self._generator.exhausted:
result.append(lazy_fns.STOP_ITERATION)
return pickler.dumps(result)

# TODO: b/318463291 - Consider deprecating in favor of
Expand Down
11 changes: 7 additions & 4 deletions ml_metrics/_src/chainables/courier_server_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,10 +76,13 @@ def test_generator(n):

client.maybe_make(pickler.dumps(lazy_fns.trace(test_generator)(10)))
actual = []
while not lazy_fns.is_stop_iteration(
t := pickler.loads(client.next_batch_from_generator())
):
actual.extend(t)
while True:
states = pickler.loads(client.next_batch_from_generator())
if states and lazy_fns.is_stop_iteration(states[-1]):
actual.extend(states[:-1])
break
else:
actual.extend(states)
self.assertEqual(list(range(10)), actual)

def test_courier_server_shutdown(self):
Expand Down
40 changes: 27 additions & 13 deletions ml_metrics/_src/chainables/courier_worker.py
Original file line number Diff line number Diff line change
Expand Up @@ -221,7 +221,7 @@ def _check_heartbeat(self) -> bool:
if not self._heartbeat:
self._heartbeat = Worker(
self.server_name, call_timeout=self.call_timeout
).call('p')
).call(None)
try:
if self._heartbeat.done() and self._heartbeat.result():
self._heartbeat = None
Expand Down Expand Up @@ -441,8 +441,9 @@ def run_and_iterate(
running_tasks = []
total_tasks = len(tasks)
total_failures_cnt = 0
finished_tasks_cnt = 0
ticker = time.time()
while tasks or running_tasks or pending_tasks:
while tasks or pending_tasks or running_tasks:
if not self.workers:
raise ValueError(
'No workers are alive, remaining'
Expand Down Expand Up @@ -471,24 +472,29 @@ def run_and_iterate(
for task in running_tasks:
if task.done:
try:
if not lazy_fns.is_stop_iteration(result := task.result):
yield from result
still_running.append(task.iterate(self))
else:
states = task.result
if states and lazy_fns.is_stop_iteration(states[-1]):
logging.info(
'chainables: worker %s generator exhausted.', task.server_name
'chainables: worker %s generator exhausted.',
task.server_name,
)
yield from states[:-1]
finished_tasks_cnt += 1
else:
yield from states
still_running.append(task.iterate(self))
except Exception as e: # pylint: disable=broad-exception-caught
logging.warning(
'chainables: exception when iterating, reappending task %s, \n'
logging.exception(
'chainables: exception when iterating, re-appending task %s, \n'
' exception: %s',
dataclasses.replace(task, parent_task=None),
e,
)
total_failures_cnt += 1
if total_failures_cnt > num_total_failures_threshold:
raise ValueError(
'chainables: too many failures, stopping the iteration.'
f'chainables: too many failures: {total_failures_cnt} >'
f' {num_total_failures_threshold}, stopping the iteration.'
) from e
tasks.append(task)
elif not self._workers[task.server_name].is_alive:
Expand All @@ -501,14 +507,22 @@ def run_and_iterate(
else:
still_running.append(task)
running_tasks = still_running
assert (
finished_tasks_cnt
+ len(running_tasks)
+ len(pending_tasks)
+ len(tasks)
) == total_tasks, 'Total tasks mismatch.'
if time.time() - ticker > _LOGGING_INTERVAL_SEC:
logging.info(
'chainables: iterate progress: %d/%d/%d/%d'
' (pending/running/remaining/total).',
len(pending_tasks),
'chainables: iterate progress: %d/%d/%d/%d/%d in %.2f secs.'
' (running/pending/remaining/finished/total).',
len(running_tasks),
len(pending_tasks),
len(tasks),
finished_tasks_cnt,
total_tasks,
time.time() - ticker,
)
ticker = time.time()
time.sleep(sleep_interval)
24 changes: 15 additions & 9 deletions ml_metrics/_src/chainables/transform.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,7 +71,7 @@
TreeTransformT = TypeVar('TreeTransformT', bound='TreeTransform')
TreeFn = tree_fns.TreeFn

_LOGGING_INTERVAL_SECS = 30
_LOGGING_INTERVAL_SECS = 60


class PrefetchableIterator:
Expand Down Expand Up @@ -118,9 +118,6 @@ def prefetch(self, num_items: int = 0):
try:
self._data.append(next(self._generator))
self._error_cnt = 0
logging.info(
'chainables: Prefetching %d from %s', self._cnt, self._generator
)
self._cnt += 1
except StopIteration:
self._exhausted = True
Expand Down Expand Up @@ -209,11 +206,13 @@ def _extract_states(states: Iterable[Any | AggregateResult]) -> Iterator[Any]:
prev_batch_cnt = batch_cnt
prev_agg_cnt = agg_cnt
logging.info(
'chainables: finished iterating states, total: %d batches, %d agg_states,'
' took %.2f secs.',
'chainables: finished iterating, total: %d batches, %d agg_states,'
' took %.2f secs (%.2f batches/sec, %.2f agg_states/sec).',
batch_cnt,
agg_cnt,
time.time() - start_ticker,
batch_cnt / (time.time() - start_ticker),
agg_cnt / (time.time() - start_ticker),
)


Expand Down Expand Up @@ -273,6 +272,7 @@ def from_transform(
if mode == RunnerMode.AGGREGATE:
iterator_node = None
input_nodes = []
assert agg_node, 'No aggregation is required for "Aggregate" mode.'
elif mode == RunnerMode.SAMPLE:
agg_node = None
output_nodes = []
Expand Down Expand Up @@ -465,8 +465,12 @@ def iterate(
input_iterator = self._actual_inputs(None, input_iterator)
state = state or self.create_state()
with_agg_state = with_agg_state or with_agg_result
for i, batch in enumerate(input_iterator):
logging.info('chainables: calculating for batch %d.', i)
prev_ticker = time.time()
batch_index = -1
for batch_index, batch in enumerate(input_iterator):
if (ticker := time.time()) - prev_ticker > _LOGGING_INTERVAL_SECS:
logging.info('chainables: calculating for batch %d.', batch_index)
prev_ticker = ticker
batch_output = _call_fns(self.input_fns, batch)
yield batch_output if with_result else None
if with_agg_state:
Expand All @@ -476,7 +480,9 @@ def iterate(
# aggregation is enabled.
agg_result = self.get_result(state) if with_agg_result else None
if with_agg_state:
logging.info('chainables: yields aggregation.')
logging.info(
'chainables: yields aggregation after %d batches.', batch_index + 1
)
yield AggregateResult(agg_state=state, agg_result=agg_result)

def update_state(
Expand Down

0 comments on commit 33e90cc

Please sign in to comment.