From dfa3c0756dbdff1f0d07eebbb9ff64f63be2a184 Mon Sep 17 00:00:00 2001 From: Bago Amirbekian Date: Thu, 24 Jan 2019 14:31:30 -1000 Subject: [PATCH] [ML-6097] Update Horovod Runner to include return value and multi-gpu info. (#183) * [ML-6012] add return value in oss, update document and add test (#138) In this PR, we: change the horovodrunner.run api to return values in horovod/runner_base.py update the documentation to reflect that change add a test to runner_base_test.py to test the api change * Update HorovodRunner docs to include multi-gpu support. * Revert "Update HorovodRunner docs to include multi-gpu support." This reverts commit 0172bbeb5147c6d895322d5dfb22e2f54cd12e90. * update np doc (#136) --- python/sparkdl/horovod/runner_base.py | 8 ++++++-- python/tests/horovod/runner_base_test.py | 7 +++++++ 2 files changed, 13 insertions(+), 2 deletions(-) diff --git a/python/sparkdl/horovod/runner_base.py b/python/sparkdl/horovod/runner_base.py index d18e8a5b..18e0532e 100644 --- a/python/sparkdl/horovod/runner_base.py +++ b/python/sparkdl/horovod/runner_base.py @@ -43,6 +43,8 @@ def __init__(self, np): :param np: number of parallel processes to use for the Horovod job. This argument only takes effect on Databricks Runtime 5.0 ML and above. It is ignored in the open-source version. + On Databricks, each process will take an available task slot, + which maps to a GPU on a GPU cluster or a CPU core on a CPU cluster. Accepted values are: - If -1, this will spawn a subprocess on the driver node to run the Horovod job locally. @@ -79,7 +81,9 @@ def run(self, main, **kwargs): Avoid referencing large objects in the function, which might result large pickled data, making the job slow to start. :param kwargs: keyword arguments passed to the main function at invocation time. - :return: None + :return: return value of the main function. + With `np>=0`, this returns the value from the rank 0 process. Note that the returned + value should be serializable using cloudpickle. """ logger = logging.getLogger("HorovodRunner") logger.warning( @@ -87,4 +91,4 @@ def run(self, main, **kwargs): "It only does basic checks and invokes the main function, " "which is for local development only. " "Please use Databricks Runtime ML 5.0+ to distribute the job.") - main(**kwargs) + return main(**kwargs) diff --git a/python/tests/horovod/runner_base_test.py b/python/tests/horovod/runner_base_test.py index 82a7fb01..7e42ad06 100644 --- a/python/tests/horovod/runner_base_test.py +++ b/python/tests/horovod/runner_base_test.py @@ -51,3 +51,10 @@ def append(value): hr.run(append, value=1) self.assertEquals(data[0], 1) + + def test_return_value(self): + """Test that the return value is returned to the user.""" + hr = HorovodRunner(np=-1) + return_value = hr.run(lambda: 42) + self.assertEquals(return_value, 42) +