forked from oshadura/dask-custom-docker
-
Notifications
You must be signed in to change notification settings - Fork 0
/
DistributedEnvironmentPlugin.py
117 lines (100 loc) · 3.96 KB
/
DistributedEnvironmentPlugin.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
import os
import sys
import zipfile
import logging
from distributed.diagnostics.plugin import NannyPlugin
from dask.utils import tmpfile
# Module-level logger, named after this module so its records can be
# filtered/configured through the standard logging hierarchy.
logger = logging.getLogger(name=__name__)
class DistributedEnvironmentPlugin(NannyPlugin):
    """A NannyPlugin that uploads a local pip-installable package to workers.

    The package directory (plus any ``extra_inputs`` files) is zipped once,
    when the plugin is constructed, and the bytes are kept in ``self.data``.
    On each nanny, ``setup`` unpacks the archive into the nanny's local
    directory, runs ``pip install`` on the package and moves the extra input
    files into the worker directory.

    Parameters
    ----------
    path : str
        Path to the pip-installable directory to upload. ``~`` is expanded
        and a trailing path separator is tolerated.
    pip_options : sequence of str, optional
        Extra arguments placed between ``pip install`` and the package path.
    restart : bool, default False
        Stored as ``self.restart``; presumably read by ``distributed`` to
        decide whether to restart the worker after setup (confirm against
        the installed ``distributed`` version).
    update_path : bool, default False
        If true, ``setup`` prepends the unpacked package directory to
        ``sys.path`` of the process running ``setup``.
    skip_words : iterable of str
        Directory or file names (relative to ``path``) that are excluded
        from the upload, e.g. ``.git`` or ``tests``.
    skip : iterable of callables
        Predicates called with each file's full path; the file is excluded
        if any of them returns true.
    extra_inputs : iterable of str, optional
        Additional files stored at the root of the archive (by base name
        only, so base names should be distinct) and moved into
        ``nanny.worker_dir`` during ``setup``.

    Examples
    --------
    >>> from DistributedEnvironmentPlugin import DistributedEnvironmentPlugin  # doctest: +SKIP
    >>> client.register_worker_plugin(DistributedEnvironmentPlugin("/path/to/directory"), nanny=True)  # doctest: +SKIP
    """

    def __init__(
        self,
        path,
        pip_options=None,
        restart=False,
        update_path=False,
        skip_words=(".git", ".github", ".pytest_cache", "tests", "docs"),
        skip=(lambda fn: os.path.splitext(fn)[1] == ".pyc",),
        extra_inputs=None,
    ):
        """Zip the package directory and extra inputs into ``self.data``."""
        # normpath drops a trailing separator, so basename() yields the
        # package directory name instead of "".
        path = os.path.normpath(os.path.expanduser(path))
        self.package = os.path.basename(path)
        self.restart = restart
        self.update_path = update_path
        # Copy into a list so tuples work and the caller's object is not shared.
        self.pip_options = list(pip_options) if pip_options is not None else []
        extra_inputs = list(extra_inputs) if extra_inputs is not None else []
        self.extra_inputs = [os.path.basename(x) for x in extra_inputs]
        self.name = "upload-directory-" + self.package

        skipped_names = set(skip_words)
        # Archive names are relative to the parent of ``path`` so that every
        # entry lives under a top-level ``<package>/`` directory.
        archive_root = os.path.join(path, "..")

        with tmpfile(extension="zip") as fn:
            with zipfile.ZipFile(fn, "w", zipfile.ZIP_DEFLATED) as z:
                for root, dirs, files in os.walk(path):
                    # Prune skipped directories in place so os.walk never
                    # descends into them. Only names below ``path`` are
                    # checked: matching absolute path components would drop
                    # every file when the package itself lives under a
                    # directory called e.g. "docs" or "tests".
                    dirs[:] = [d for d in dirs if d not in skipped_names]
                    for file in files:
                        if file in skipped_names:
                            continue
                        filename = os.path.join(root, file)
                        if any(predicate(filename) for predicate in skip):
                            continue
                        archive_name = os.path.relpath(filename, archive_root)
                        z.write(filename, archive_name)
                for fpath in extra_inputs:
                    z.write(fpath, os.path.basename(fpath))
            # Read only after the ZipFile is closed so the archive is complete.
            with open(fn, "rb") as f:
                self.data = f.read()

    def setup(self, nanny):
        """Unpack the archive on ``nanny``, pip install it, and place extras.

        If ``pip install`` fails, the error is logged and setup returns
        without moving the extra inputs. The temporary zip file is always
        removed.

        Raises
        ------
        RuntimeError
            If extra inputs were given but ``nanny.worker_dir`` is None.
        """
        import contextlib
        import subprocess
        import uuid

        logger.info("Entering plugin setup")
        # Copy the package archive to the worker machine under a unique name.
        fn = os.path.join(nanny.local_directory, f"tmp-{uuid.uuid4()}.zip")
        try:
            with open(fn, "wb") as f:
                f.write(self.data)
            with zipfile.ZipFile(fn) as z:
                z.extractall(path=nanny.local_directory)

            package_path = os.path.join(nanny.local_directory, self.package)
            if self.update_path and package_path not in sys.path:
                sys.path.insert(0, package_path)

            logger.info("Installing the package: %s", self.package)
            proc = subprocess.run(
                [sys.executable, "-m", "pip", "install", *self.pip_options, package_path],
                stdout=subprocess.PIPE,
                stderr=subprocess.PIPE,
            )
            if proc.returncode:
                logger.error("Pip install failed with '%s'", proc.stderr.decode().strip())
                return

            if self.extra_inputs and nanny.worker_dir is None:
                raise RuntimeError(
                    "nanny.worker_dir is None; cannot move extra inputs "
                    "into the worker directory"
                )
            for fname in self.extra_inputs:
                src = os.path.join(nanny.local_directory, fname)
                dst = os.path.join(nanny.worker_dir, fname)
                os.rename(src, dst)
        finally:
            # Always remove the temporary zip, including on pip failure or
            # error; it may not exist if writing it failed.
            logger.info("Cleaning up temporary directory: %s", fn)
            with contextlib.suppress(FileNotFoundError):
                os.remove(fn)