Commit 5a07a0e

No public description
PiperOrigin-RevId: 633371752
genehwung authored and ml metrics authors committed May 13, 2024
1 parent e4cbc18 commit 5a07a0e
Showing 6 changed files with 337 additions and 98 deletions.
28 changes: 27 additions & 1 deletion ml_metrics/_src/chainables/courier_server.py
@@ -58,19 +58,32 @@ def pickled_maybe_make(maybe_lazy):
return pickler.dumps(repr(result))
return pickler.dumps(result)

def next_batch_from_generator():
assert self._generator is not None, (
'Generator is not set; the worker might have crashed unexpectedly'
' earlier.'
)
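# `self._generator` is expected to be a prefetching wrapper exposing
# `data_size` (the number of items currently buffered) and an `exhausted`
# flag; draining `data_size` items per call amortizes one courier RPC over
# a whole batch instead of one RPC per element.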
result = [next(self._generator) for _ in range(self._generator.data_size)]
if not result and self._generator.exhausted:
result = lazy_fns.STOP_ITERATION
return pickler.dumps(result)

# TODO: b/318463291 - Consider deprecating in favor of
# `next_batch_from_generator`.
def next_from_generator():
try:
result = next(self._generator)
except StopIteration:
result = lazy_fns.STOP_ITERATION
except Exception as e: # pylint: disable=broad-exception-caught
logging.warning('Chainables: Exception while iterating: %s', e)
logging.exception('Chainables: Exception while iterating: %s', e)
result = lazy_fns.STOP_ITERATION
return pickler.dumps(result)

server = courier.Server(self.server_name, port=self.port)
server.Bind('maybe_make', pickled_maybe_make)
server.Bind('next_from_generator', next_from_generator)
server.Bind('next_batch_from_generator', next_batch_from_generator)
server.Bind('shutdown', shutdown)
# TODO: b/318463291 - Add unit tests.
server.Bind('clear_cache', lazy_fns.clear_cache)
@@ -80,6 +93,7 @@ def next_from_generator():

def run_until_shutdown(self):
"""Run until shutdown requested."""
assert self._server is not None, 'Server is not built.'
if not self._server.has_started:
self._server.Start()
while not self._shutdown_requested[_DEFAULT]:
@@ -88,3 +102,15 @@ def run_until_shutdown(self):
time.sleep(0.01)
logging.info('Chainables: Shutdown requested, shutting down server.')
self._server.Stop()


def run_courier_server(name=None, port=None, prefetch_size: int = 128):
# TODO: b/318463291 - Preload a task so prefetching starts even before the
# master starts.
server_wrapper = CourierServerWrapper(
server_name=name,
port=port,
prefetch_size=prefetch_size,
)
server_wrapper.build_server()
server_wrapper.run_until_shutdown()
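
For reference, a minimal sketch of the kind of prefetching iterator that next_batch_from_generator assumes: something that buffers items ahead of the consumer and exposes data_size and exhausted. The class name and deque-based buffering below are illustrative assumptions, not the actual ml_metrics implementation.

import collections
from typing import Any, Iterator


class PrefetchedIterator:
  """Hypothetical buffering wrapper matching what the server expects."""

  def __init__(self, iterator: Iterator[Any], prefetch_size: int = 128):
    self._iterator = iterator
    self._buffer = collections.deque()
    self._prefetch_size = prefetch_size
    self.exhausted = False  # True once the underlying iterator is drained.
    self.prefetch()

  @property
  def data_size(self) -> int:
    # The server drains exactly this many items per batch call.
    return len(self._buffer)

  def prefetch(self):
    # Pull ahead so a full batch is ready before the next RPC arrives.
    while not self.exhausted and len(self._buffer) < self._prefetch_size:
      try:
        self._buffer.append(next(self._iterator))
      except StopIteration:
        self.exhausted = True

  def __next__(self):
    if not self._buffer:
      self.prefetch()
      if not self._buffer:
        raise StopIteration
    return self._buffer.popleft()

With a wrapper like this, one courier round trip returns everything currently buffered, and the empty-batch-plus-exhausted case maps onto lazy_fns.STOP_ITERATION, which is why the batched endpoint can replace the per-element next_from_generator.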
55 changes: 39 additions & 16 deletions ml_metrics/_src/chainables/courier_server_test.py
@@ -18,30 +18,44 @@

from absl.testing import absltest
import courier
from courier.python import testutil
from ml_metrics._src.chainables import courier_server
from ml_metrics._src.chainables import courier_worker
from ml_metrics._src.chainables import lazy_fns


pickler = lazy_fns.picklers.default


def setUpModule():
testutil.SetupMockBNS()


class CourierServerTest(absltest.TestCase):

def setUp(self):
super().setUp()
self.server_wrapper = courier_server.CourierServerWrapper()
self.server = self.server_wrapper.build_server()
self.t = threading.Thread(target=self.server_wrapper.run_until_shutdown)
self.t.start()
self.client = courier_worker.Worker(self.server.address)
self.client.wait_until_alive()

def tearDown(self):
self.client.shutdown()
self.t.join()
super().tearDown()

def test_courier_server_maybe_make(self):
server_wrapper = courier_server.CourierServerWrapper()
server = server_wrapper.build_server()
server.Start()
client = courier.Client(server.address, call_timeout=6)
client = courier.Client(self.server.address, call_timeout=1)
self.assertEqual('hello', pickler.loads(client.maybe_make('hello')))
self.assertEqual(
2, pickler.loads(client.maybe_make(lazy_fns.trace(len)([1, 2])))
)
server.Stop()

def test_courier_server_generator(self):
server_wrapper = courier_server.CourierServerWrapper()
server = server_wrapper.build_server()
server.Start()
client = courier.Client(server.address, call_timeout=6)
client = courier.Client(self.server.address, call_timeout=1)

def test_generator(n):
yield from range(n)
@@ -53,7 +67,20 @@ def test_generator(n):
):
actual.append(t)
self.assertEqual(list(range(10)), actual)
server.Stop()

def test_courier_server_batch_generator(self):
client = courier.Client(self.server.address, call_timeout=1)

def test_generator(n):
yield from range(n)

client.maybe_make(pickler.dumps(lazy_fns.trace(test_generator)(10)))
actual = []
while not lazy_fns.is_stop_iteration(
t := pickler.loads(client.next_batch_from_generator())
):
actual.extend(t)
self.assertEqual(list(range(10)), actual)

def test_courier_server_shutdown(self):
server_wrapper = courier_server.CourierServerWrapper()
@@ -66,13 +93,10 @@ def test_courier_server_shutdown(self):
client.shutdown()
time.sleep(7)
self.assertFalse(t.is_alive())
t.join()

def test_courier_exception_during_prefetch(self):
server_wrapper = courier_server.CourierServerWrapper()
server = server_wrapper.build_server()
t = threading.Thread(target=server_wrapper.run_until_shutdown)
t.start()
client = courier.Client(server.address, call_timeout=6)
client = courier.Client(self.server.address, call_timeout=1)

def test_generator(n):
for i in range(n):
@@ -87,7 +111,6 @@ def test_generator(n):
t := pickler.loads(client.next_from_generator())
):
actual.append(t)
client.shutdown()
self.assertEqual(list(range(6)), actual)
self.assertRegex(cm.output[0], '.*Traceback.*')

115 changes: 87 additions & 28 deletions ml_metrics/_src/chainables/courier_worker.py
@@ -23,6 +23,9 @@
import courier
from ml_metrics._src.chainables import lazy_fns


_LOGGING_INTERVAL_SEC = 60
_NUM_TOTAL_FAILURES_THRESHOLD = 60
picklers = lazy_fns.picklers


@@ -69,7 +72,7 @@ def new(

def iterate(self, worker_pool):
state = worker_pool.get_worker_by_name(self.server_name).call(
courier_method='next_from_generator'
courier_method='next_batch_from_generator'
)
return dataclasses.replace(self, state=state)

@@ -83,7 +86,7 @@ def from_list_of_tasks(cls, tasks: list['Task']) -> 'Task':
return task

@property
def done(self) -> bool:
def done(self) -> bool | None:
"""Checks whether the task is done."""
if self.state is not None:
return self.state.done()
@@ -186,7 +189,7 @@ def _normalize_args(args, kwargs):
return result_args, result_kwargs


# TODO(b/311207032): Adds unit test to cover logics for disconneted worker.
# TODO(b/311207032): Adds unit test to cover logic for disconnected worker.
@dataclasses.dataclass
class Worker:
"""Courier client wrapper that works as a chainable worker."""
@@ -216,20 +219,40 @@ def has_capacity(self) -> bool:
def _check_heartbeat(self) -> bool:
"""Ping the worker to check the heartbeat once."""
if not self._heartbeat:
self._heartbeat = Worker(self.server_name, call_timeout=60).call('echo')
self._heartbeat = Worker(
self.server_name, call_timeout=self.call_timeout
).call('p')
try:
if self._heartbeat.done():
if self._heartbeat.result():
self._heartbeat = None
self._last_heartbeat = time.time()
return True
if self._heartbeat.done() and self._heartbeat.result():
self._heartbeat = None
self._last_heartbeat = time.time()
return True
except Exception: # pylint: disable=broad-exception-caught
logging.warning(
'chainables: Worker %s missed a heartbeat.', self.server_name
)
self._heartbeat = None
return False

def wait_until_alive(
self,
num_attempts: int = 30,
sleep_interval: float = 6.0,
):
"""Waits for the workers to be alive with retries."""
for _ in range(num_attempts):
try:
if self.is_alive:
break
except Exception as e: # pylint: disable=broad-exception-caught
logging.warning('chainables: exception when connecting: %s', e)
time.sleep(sleep_interval)
else:
raise ValueError(
f'Failed to connect to worker {self.server_name} after'
f' {num_attempts} tries.'
)
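# Usage sketch (mirrors the test setUp above); assumes a server built via
# CourierServerWrapper and served on a background thread:
#
#   worker = courier_worker.Worker(server.address)
#   worker.wait_until_alive()  # raises ValueError after num_attempts misses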

@property
def is_alive(self) -> bool:
"""Checks whether the worker is alive."""
@@ -290,13 +313,8 @@ def shutdown(self):
return self.state


def _raise_if_return_not_iterator(task: Task):
# The return of the state has to be a generator for this call.
if (
not (result := task.state)
or not (result := picklers.default.loads(task.state.result()))
or 'generator' not in str(result)
):
def _raise_if_return_error(task: Task):
if not (result := task.state) or (result := task.exception):
raise TypeError(
f'Expected iterator, got {result} from'
f' task: {dataclasses.replace(task, parent_task=None)}'
@@ -342,7 +360,7 @@ def wait_until_alive(
if len(workers) >= minimum_num_workers:
break
except Exception as e: # pylint: disable=broad-exception-caught
logging.warning('Exception when connecting: %s', e)
logging.warning('chainables: exception when connecting: %s', e)
time.sleep(sleep_interval)
else:
raise ValueError(
@@ -412,36 +430,67 @@ def run_tasks(
time.sleep(sleep_interval)

def run_and_iterate(
self, tasks: list[Task], sleep_interval: float = 0.01
self,
tasks: list[Task],
sleep_interval: float = 0.0,
num_total_failures_threshold: int = _NUM_TOTAL_FAILURES_THRESHOLD,
) -> Iterator[Any]:
"""Iterates through the result of a generator if the iterator task."""
tasks = list(tasks)
pending_tasks = []
running_tasks = []
while tasks or running_tasks:
total_tasks = len(tasks)
total_failures_cnt = 0
ticker = time.time()
while tasks or running_tasks or pending_tasks:
if not self.workers:
raise ValueError(
'No workers are alive; remaining'
f' {len(tasks) + len(running_tasks) + len(pending_tasks)} tasks.'
)
# Assign tasks to the idle workers.
for worker in self.idle_workers():
# Only assign a task to a worker that is not already running one.
if worker.server_name not in {
task.server_name for task in running_tasks
}:
if tasks:
task = worker.run_task(tasks.pop())
_raise_if_return_not_iterator(task)
running_tasks.append(task.iterate(self))
# Fetching finished outputs.
pending_tasks.append(task)
# Check the instantiated iterator, then start iterating if successful.
new_pending_tasks: list[Task] = []
for task in pending_tasks:
if task.done:
_raise_if_return_error(task)
running_tasks.append(task.iterate(self))
else:
new_pending_tasks.append(task)
pending_tasks = new_pending_tasks
# Fetching finished outputs. Re-collect the task if it failed during
# iteration.
still_running: list[Task] = []
for task in running_tasks:
if task.done:
if not lazy_fns.is_stop_iteration(result := task.result):
yield result
still_running.append(task.iterate(self))
else:
logging.info(
'chainables: worker %s generator exhausted.', task.server_name
try:
if not lazy_fns.is_stop_iteration(result := task.result):
yield from result
still_running.append(task.iterate(self))
else:
logging.info(
'chainables: worker %s generator exhausted.', task.server_name
)
except Exception as e: # pylint: disable=broad-exception-caught
logging.warning(
'chainables: exception when iterating, reappending task %s, \n'
' exception: %s',
dataclasses.replace(task, parent_task=None),
e,
)
total_failures_cnt += 1
if total_failures_cnt > num_total_failures_threshold:
raise ValueError(
'chainables: too many failures, stopping the iteration.'
) from e
tasks.append(task)
elif not self._workers[task.server_name].is_alive:
logging.warning(
'chainables: Worker %s is not alive, re-appending task %s',
Expand All @@ -452,4 +501,14 @@ def run_and_iterate(
else:
still_running.append(task)
running_tasks = still_running
if time.time() - ticker > _LOGGING_INTERVAL_SEC:
logging.info(
'chainables: iterate progress: %d/%d/%d/%d'
' (pending/running/remaining/total).',
len(pending_tasks),
len(running_tasks),
len(tasks),
total_tasks,
)
ticker = time.time()
time.sleep(sleep_interval)
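
Putting the pool changes together, a hedged end-to-end sketch. The WorkerPool constructor arguments and the Task.new call signature are assumptions inferred from the method names in this diff, and the server addresses are placeholders.

from ml_metrics._src.chainables import courier_worker
from ml_metrics._src.chainables import lazy_fns


def data_source():
  yield from range(100)


# Hypothetical addresses; in practice these point at running courier servers.
pool = courier_worker.WorkerPool(['/bns/worker/0', '/bns/worker/1'])
pool.wait_until_alive(minimum_num_workers=1)

# Task.new wrapping a traced generator call is assumed; only the method
# name appears in this diff.
tasks = [courier_worker.Task.new(lazy_fns.trace(data_source)())]
for item in pool.run_and_iterate(tasks, num_total_failures_threshold=60):
  print(item)

Because next_batch_from_generator returns lists, run_and_iterate flattens each batch with yield from, and a task that fails mid-iteration is re-queued until num_total_failures_threshold failures accumulate, at which point the iteration aborts with a ValueError.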