This repository has been archived by the owner on Mar 11, 2021. It is now read-only.
-
Notifications
You must be signed in to change notification settings - Fork 561
/
mask_flags.py
138 lines (111 loc) · 4.68 KB
/
mask_flags.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
# Copyright 2018 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Filters flagfile to only pass in flags that are defined.
Having one big flagfile is great for seeing all the configuration at a glance.
However, absl.flags will throw an error if you pass an undefined flag.
To solve this problem, we filter the global flagfile by running
python3 some_module.py --helpfull
to generate a list of all flags that some_module.py accepts. Then, we pass in
only those flags that are accepted by some_module.py and run it as a subprocess.
Usage example:
import mask_flags
mask_flags.run(['python3', 'train.py', '--custom_flag', '--flagfile=flags'])
# will be transformed into
subprocess.run(['python3', 'train.py', '--custom_flag',
'--train_only_flag=...', '--more_train_only=...'])
Command line usage example:
python3 -m mask_flags train.py --custom_flag --flagfile=flags
"""
import re
import subprocess
import sys
import time
from absl import flags
# Matches flag declarations in Python (absl) --helpfull output, in both forms:
#   --some_flag: Flag description
#   --[no]bool_flag: Flag description
# Group 2 holds the optional "[no]" marker and group 3 the bare flag name.
FLAG_HELP_RE_PY = re.compile(r'--((\[no\])?)([\w_-]+):')
# Same group layout, but for help output where a flag is written with a single
# leading dash and followed by " (", e.g. "-some_flag (Flag description".
# extract_valid_flags uses it when the command is not a python invocation
# (presumably gflags-style binaries -- confirm against their help output).
FLAG_HELP_RE_CC = re.compile(r'-((\[no\])?)([\w_-]+) \(')
# Matches the leading "--flag_name" portion of a command-line argument.
FLAG_RE = re.compile(r'--[\w_-]+')
def extract_valid_flags(subprocess_cmd):
    """Extracts the valid flags from a command by running it with --helpfull.

    Args:
      subprocess_cmd: List[str], what would be passed into subprocess.call()
        i.e. ['python', 'train.py', '--flagfile=flags']

    Returns:
      A set of the flag names the command accepts, e.g.
      {'--foo', '--bool_flag', '--nobool_flag'}
    """
    help_cmd = subprocess_cmd + ['--helpfull']
    help_output = subprocess.run(help_cmd, stdout=subprocess.PIPE).stdout
    # Help text may contain non-ASCII characters (e.g. in flag descriptions);
    # decode leniently so a stray byte does not abort with UnicodeDecodeError.
    help_output = help_output.decode('utf-8', errors='replace')
    # Python commands and other binaries format their help output differently,
    # so pick the parsing regex based on the executable name.
    if 'python' in subprocess_cmd[0]:
        valid_flags = parse_helpfull_output(help_output)
    else:
        valid_flags = parse_helpfull_output(help_output, regex=FLAG_HELP_RE_CC)
    return valid_flags
def parse_helpfull_output(help_output, regex=FLAG_HELP_RE_PY):
    """Collects the flag names advertised in --helpfull output.

    Args:
      help_output: str, the full output of --helpfull.
      regex: compiled pattern with three groups; the second is non-empty for
        flags declared as "[no]name" and the third is the bare flag name.

    Returns:
      A set of '--name' strings, plus '--noname' for each "[no]" flag.
    """
    found = set()
    for match in regex.finditer(help_output):
        name = match.group(3)
        found.add('--' + name)
        # Flags listed as "[no]name" also accept the negated spelling.
        if match.group(2):
            found.add('--no' + name)
    return found
def filter_flags(parsed_flags, valid_flags):
    """Keeps the entries of `parsed_flags` that the target command accepts.

    Args:
      parsed_flags: iterable of str, the command-line arguments to filter.
      valid_flags: container of '--name' strings accepted by the command.

    Returns:
      A list, in the original order, of the arguments that either do not begin
      with a '--name' flag or whose flag name is in `valid_flags`.
    """
    kept = []
    for arg in parsed_flags:
        name_match = FLAG_RE.match(arg)
        # Arguments that do not look like a flag are always passed through.
        if name_match is None or name_match.group() in valid_flags:
            kept.append(arg)
    return kept
def prepare_subprocess_cmd(subprocess_cmd):
    """Builds the command to actually run, keeping only flags it accepts.

    The command is first run with --helpfull to learn which flags it defines;
    its arguments are then read through absl's flagfile handling and filtered
    against that set.

    Args:
      subprocess_cmd: List[str], what would be passed into subprocess.call()
        i.e. ['python', 'train.py', '--flagfile=flags']

    Returns:
      ['python', 'train.py', '--train_flag=blah', '--more_flags']
    """
    accepted = extract_valid_flags(subprocess_cmd)
    # Everything after the executable goes through absl, which replaces
    # --flagfile arguments with the flags read from those files.
    expanded = flags.FlagValues().read_flags_from_files(subprocess_cmd[1:])
    kept = filter_flags(expanded, accepted)
    return [subprocess_cmd[0], *kept]
def run(cmd):
    """Prepare and run a subprocess cmd, returning a CompletedProcess.

    Args:
      cmd: List[str], the command as it would be passed to subprocess.run();
        it may contain --flagfile arguments and flags the command does not
        define, which are masked out before running.

    Returns:
      The subprocess.CompletedProcess of the masked command.
    """
    print("Preparing the following cmd:")
    # Echo the command as given, before flagfile expansion and masking.
    print('\n'.join(cmd))
    cmd = prepare_subprocess_cmd(cmd)
    print("Running the following cmd:")
    # Flush so this buffered output is emitted before anything the child
    # writes to the same stream.
    print('\n'.join(cmd), flush=True)
    return subprocess.run(cmd, stdout=sys.stdout, stderr=sys.stderr)
def checked_run(cmd):
    """Prepare and run a subprocess cmd, checking for successful completion.

    Args:
      cmd: List[str], the command to mask and run; passed straight to run().

    Returns:
      The subprocess.CompletedProcess, when the command exits with status 0.

    Raises:
      RuntimeError: if the command exits with a non-zero status, including a
        negative one (on POSIX, -N means the child was killed by signal N).
        Raised only after the 300 second pause below.
    """
    completed_process = run(cmd)
    # Compare against 0 rather than "> 0": a child terminated by a signal
    # reports a negative returncode and must not be treated as success.
    if completed_process.returncode != 0:
        print("Command failed! Hanging around in case someone needs a "
              "docker connection. (Ctrl-C to quit now)")
        # Keep the process alive for a while so the environment can be
        # inspected before the failure propagates.
        time.sleep(300)
        raise RuntimeError(
            'Command exited with return code {}: {!r}'.format(
                completed_process.returncode, completed_process.args))
    return completed_process
if __name__ == '__main__':
    # Drop this script's own name; what remains is the command to mask and run.
    del sys.argv[0]
    checked_run(sys.argv)