Skip to content

Commit

Permalink
Add Parameter.cli_default() to supply default values for CLI only, th…
Browse files Browse the repository at this point in the history
…rough converters (#93)
  • Loading branch information
epsy authored Oct 8, 2022
1 parent ea00efa commit 4bad129
Show file tree
Hide file tree
Showing 2 changed files with 88 additions and 8 deletions.
38 changes: 33 additions & 5 deletions clize/parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -103,6 +103,16 @@ class Parameter(object):
Mostly only useful for ``*args`` parameters. In other cases, simply don't
provide a default value."""

@attrs.define
class cli_default:
value: typing.Any
convert: bool = attrs.field(default=True, kw_only=True)

def value_after_conversion(self, converter):
if self.convert:
return converter(self.value)
else:
return self.value

required = False
"""Is this parameter required?"""
Expand Down Expand Up @@ -315,6 +325,7 @@ class ParameterWithValue(Parameter):
"""

def __init__(self, conv=identity, default=util.UNSET,
cli_default=util.UNSET,
**kwargs):
super(ParameterWithValue, self).__init__(**kwargs)
self.conv = conv
Expand All @@ -323,11 +334,15 @@ def __init__(self, conv=identity, default=util.UNSET,
self.default = default
"""The default value used for the parameter, or `.util.UNSET` if there
is no default value. Usually only used for displaying the help."""
self.cli_default = cli_default
"""The default value used for the parameter in the CLI,
or `.util.UNSET` if there is no default value.
Converted by ``self.conv`` before insertion."""

@property
def required(self):
"""Tells if the parameter has no default value."""
return self.default is util.UNSET
return self.default is util.UNSET and self.cli_default is util.UNSET

def read_argument(self, ba, i):
"""Uses `.get_value`, `.coerce_value` and `.set_value` to process
Expand Down Expand Up @@ -367,7 +382,10 @@ def get_value(self, ba, i):

def help_parens(self):
"""Shows the default value in the parameter description."""
if self.default is not util.UNSET and self.default is not None:
if self.cli_default is not util.UNSET:
if self.cli_default.value is not None:
yield 'default: ' + str(self.cli_default.value)
elif self.default is not util.UNSET and self.default is not None:
yield 'default: ' + str(self.default)

def post_parse(self, ba):
Expand All @@ -377,8 +395,13 @@ def post_parse(self, ba):
except AttributeError:
pass
else:
if self.default != util.UNSET and info['convert_default']:
if self in ba.not_provided:
if self in ba.not_provided:
if self.cli_default is not util.UNSET:
self.set_value(
ba,
self.cli_default.value_after_conversion(partial(self.coerce_value, ba=ba))
)
elif self.default is not util.UNSET and info['convert_default']:
self.set_value(ba, self.coerce_value(self.default, ba))


Expand Down Expand Up @@ -594,7 +617,9 @@ def set_value(self, ba, val):
return
else:
if arg is util.UNSET:
if param.default != util.UNSET:
if param.cli_default != util.UNSET:
ba.args.append(param.cli_default.value_after_conversion(partial(self.coerce_value, ba=ba)))
elif param.default != util.UNSET:
ba.args.append(param.default)
else:
raise ValueError(
Expand Down Expand Up @@ -910,6 +935,9 @@ def _use_class(pos_cls, varargs_cls, named_cls, varkwargs_cls, kwargs,
continue
if isinstance(thing, ParameterFlag):
continue
if isinstance(thing, Parameter.cli_default):
kwargs['cli_default'] = thing
continue
raise ValueError(
"Unknown annotation {!r}\n"
"If you intended for it to be a value or parameter converter, "
Expand Down
58 changes: 55 additions & 3 deletions clize/tests/test_parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
from repeated_test import evaluated
from sigtools import support, modifiers, specifiers

from clize import parser, errors, util, Clize
from clize import parser, errors, util, Clize, Parameter
from clize.tests.util import Fixtures, SignatureFixtures


Expand Down Expand Up @@ -47,7 +47,7 @@ def s(sig_str, **inject):
def evaluate_sig(self, *, make_signature, **_):
pre_code = (
"import pathlib;"
"from clize import Clize, Parameter;"
"from clize import *;"
"import typing;"
"P = Parameter;"
)
Expand Down Expand Up @@ -338,6 +338,27 @@ def __call__(self, param, annotations):
parser.CliSignature.convert_parameter(param)


class ParamHelpTests(SignatureFixtures):
def _test(self, sig, expected, *, make_signature, desc="ds"):
param = list(sig.parameters.values())[0]
cparam = parser.CliSignature.convert_parameter(param)
f = util.Formatter()
with f.columns(indent=2) as cols:
actual = cparam.show_help(desc, (), util.Formatter(), cols)
self.assertEqual(actual, expected)

plain_param = s("param"), ("param", "ds")
param_default = s("param = 4"), ("param", "ds (type: INT, default: 4)")
param_cli_default = s("param: ann", ann=Parameter.cli_default('5')), ("param", "ds (default: 5)")
param_cli_default_overrides_default = s("param: ann = '2'", ann=Parameter.cli_default('5')), ("param", "ds (default: 5)")
param_cli_default_none = s("param: ann = '2'", ann=Parameter.cli_default(None, convert=False)), ("param", "ds")


@parser.value_converter()
def conv_not_default(arg):
return f'converted:{arg}'


class SigTests(SignatureFixtures):
def _test(self, sig, str_rep, args, posargs, kwargs, *, make_signature):
csig = parser.CliSignature.from_signature(sig)
Expand Down Expand Up @@ -513,6 +534,14 @@ def test_posparam_set_value_after_default(self):
param.set_value(ba, 'inserted')
self.assertEqual(ba.args, ['one', 'inserted'])

def test_posparam_set_value_after_cli_default(self):
param = parser.PositionalParameter(argument_name='two', display_name='two', default="two")
sig = support.s('one: ann1 = "src_default", two:ann2="two"', globals={'ann1': Parameter.cli_default("one"), 'ann2': param})
csig = parser.CliSignature.from_signature(sig)
ba = parser.CliBoundArguments(csig, [], 'func', args=[])
param.set_value(ba, 'inserted')
self.assertEqual(ba.args, ['one', 'inserted'])

def test_posparam_set_value_after_missing(self):
param = parser.PositionalParameter(argument_name='two', display_name='two')
sig = support.s('one, two:par', globals={'par': param})
Expand Down Expand Up @@ -561,6 +590,30 @@ def conv(arg):
sig = make_signature('first="otherdefault", par:conv="default"', globals={'conv': conv})
return (sig, '[first] [par]', (), ['otherdefault', 'converted'], {})

cli_default_no_src_default = (s(
'par: ann',
ann=(conv_not_default, Parameter.cli_default("cli_default"))
), "[par]", (), ["converted:cli_default"], {})

cli_default_src_default = (s(
'par: ann = "default"',
ann=(conv_not_default, Parameter.cli_default("cli_default"))
), "[par]", (), ["converted:cli_default"], {})

cli_default_dont_convert = (s(
'par: ann = "default"',
ann=(conv_not_default, Parameter.cli_default("cli_default", convert=False))
), "[par]", (), ["cli_default"], {})

cli_default_none = (s(
'par: ann = "default"',
ann=(conv_not_default, Parameter.cli_default(None, convert=False))
), "[par]", (), [None], {})

cli_default_after_pos = (s(
'first="otherdefault", par:conv="src_default"', conv=Parameter.cli_default("default")
), "[first] [par]", (), ["otherdefault", "default"], {})


class ExtraParamsTests(Fixtures):
def _test(self, sig_str, extra, args, posargs, kwargs, func):
Expand Down Expand Up @@ -641,7 +694,6 @@ def test_param_extras(self):
self.assertEqual('[-a] [-b] [-c] one', str(csig))



class SigErrorTests(Fixtures):
def _test(self, sig_str, args, exc_typ, message):
sig = support.s(sig_str)
Expand Down

0 comments on commit 4bad129

Please sign in to comment.