Skip to content

Commit

Permalink
Adds command line help. Prettier print. Nested configs
Browse files Browse the repository at this point in the history
  • Loading branch information
OlofHarrysson committed Mar 13, 2020
1 parent 90768be commit be92ee2
Show file tree
Hide file tree
Showing 6 changed files with 166 additions and 167 deletions.
161 changes: 73 additions & 88 deletions anyfig/anyfig_setup.py
Original file line number Diff line number Diff line change
@@ -1,38 +1,59 @@
from io import StringIO
import fire
import inspect
import argparse
import sys
from . import figutils


def setup_config(default_config=None):
config_str = parse_args(default_config.__name__)
config = choose_config(config_str)
figutils.set_global(config)
return config
def setup_config(default_config):
assert default_config is not None
err_msg = "Expected default config to be a class definition but most likely got an object. E.g. expected ConfigClass but got ConfigClass() with parentheses."
assert default_config.__class__ == type.__class__, err_msg
err_msg = f"Expected default config to be an anyfig config class, was {default_config}"
assert figutils.is_config_class(default_config), err_msg

# Parse args and create config
args = parse_args()
config_str = default_config.__name__
if 'config' in args:
config_str = args.pop('config')
config = create_config(config_str)

# Print config help
if 'help' in args:
config_classes = list(figutils.get_registered_config_classes())
print(f"Available config classes {config_classes}",
"\nSet config with --config=OtherConfigClass\n")

print("Current config")
print(config)
sys.exit(0)

# Overwrite parameters via optional input flags
config = overwrite(config, args)

# Freezes config
config.frozen(freeze=True)

# Registers config with anyfig
figutils.register_globally(config)
return config

def parse_args(default_config):
p = argparse.ArgumentParser()

p.add_argument('--config',
type=str,
default=default_config,
help='What config class to choose')
def parse_args():
''' Parses input arguments '''
class NullIO(StringIO):
def write(self, txt):
pass

args, _ = p.parse_known_args()
return args.config
sys.stdout = NullIO()
args = fire.Fire(lambda **kwargs: kwargs)
sys.stdout = sys.__stdout__
return args


def choose_config(config_str):
# Create config object
err_msg = (
"Specify which config to use by either starting your python script with "
"the input argument --config=YourConfigClass or set "
"'default_config=YourConfigClass' in anyfigs 'setup_config' method")
if config_str is None:
raise RuntimeError(err_msg)
def create_config(config_str):
''' Instantiates a config class object '''

err_msg = ("There aren't any registered config classes. Decorate a class "
"with '@anyfig.config_class' and make sure that the class is "
Expand All @@ -51,30 +72,17 @@ def choose_config(config_str):
f"{list(registered_configs)}")
raise KeyError(err_msg) from e

# Overwrite parameters via optional input flags
config_obj = overwrite(config_obj)

# Freezes config
config_obj.frozen(freeze=True)
return config_obj


def overwrite(main_config_obj):
def overwrite(main_config_obj, args):
''' Overwrites parameters with input flags. Function is needed for the
convenience of specifying parameters via a combination of the config classes
and input flags. '''
class NullIO(StringIO):
def write(self, txt):
pass

sys.stdout = NullIO()
extra_arguments = fire.Fire(lambda **kwargs: kwargs)
sys.stdout = sys.__stdout__

for argument_key, val in extra_arguments.items():
if argument_key == 'config': # Argparse deals with this one
continue
config_classes = figutils.get_registered_config_classes()

for argument_key, val in args.items():
# Seperate nested keys into outer and inner
outer_keys = argument_key.split('.')
inner_key = outer_keys.pop(-1)
Expand All @@ -95,71 +103,48 @@ def write(self, txt):
"to overwrite attributes that exist in the active config class")
raise NotImplementedError(err_msg)

# Create object with new value
# Class definition
value_class = type(getattr(config_obj, inner_key))
try:
value_obj = value_class(val)
except Exception:
err_msg = f"Input argument '{argument_key}' with value {val} can't create an object of the expected type {value_class}"
raise RuntimeError(err_msg)

# Create new anyfig class object
if figutils.is_config_class(value_class):
value_class = config_classes[val]
value_obj = value_class()

# Create new object of previous value type with new value
else:
try:
value_obj = value_class(val)
except Exception:
err_msg = f"Input argument '{argument_key}' with value {val} can't create an object of the expected type {value_class}"
raise RuntimeError(err_msg)

# Overwrite old value
setattr(config_obj, inner_key, value_obj)

return main_config_obj


def config_class(func):
class_name = func.__name__
def config_class(class_def):
class_name = class_def.__name__

# Makes sure that nothing fishy is going on...
err_msg = (f"Can't decorate '{class_name}' of type {type(func)}. "
err_msg = (f"Can't decorate '{class_name}' of type {type(class_def)}. "
"Can only be used for classes")
assert inspect.isclass(func), err_msg
assert inspect.isclass(class_def), err_msg

# Config class functions
members = inspect.getmembers(func, inspect.isfunction)
members = {name: function for name, function in members}
functions = inspect.getmembers(class_def, inspect.isfunction)
functions = {name: function for name, function in functions}

# Transfers functions from MasterConfig to config class
for name, member in inspect.getmembers(figutils.MasterConfig,
inspect.isfunction):
if name not in members: # Only transfer not implemented functions
setattr(func, name, member)
for name, function in inspect.getmembers(figutils.MasterConfig,
inspect.isfunction):
if name not in functions: # Only transfer not implemented functions
setattr(class_def, name, function)

# Manually add attributes to config class
setattr(func, '_frozen', False)

# TODO: Is this needed? Maybe if we want to print the name of the print(config)
# setattr(func, '_config_class', class_name)

figutils.register_config_class(class_name, func)
return func


def print_source(func):
class_name = func.__name__
err_msg = (f"Can't decorate '{class_name}' of type {type(func)}. "
"Can only be used for classes")

assert inspect.isclass(func), err_msg

def __print_source__(self):
''' Get my source. Get my childrens sources '''

# Newline makes indention better
src = '\n' + inspect.getsource(self.__class__)

unique_classes = {v.__class__: v for k, v in vars(self).items()}
for key, val in unique_classes.items():
if hasattr(val, '__anyfig_print_source__'):
src += __print_source__(val)

# TODO: Source code can have different indention than \t
# Make it a config to anyfig? Check the first indention and add that
# Adds indention
src = src.replace('\n', '\n ')
return src
setattr(class_def, '_frozen', False)

setattr(func, '__anyfig_print_source__', __print_source__)
return func
figutils.register_config_class(class_name, class_def)
return class_def
82 changes: 36 additions & 46 deletions anyfig/figutils.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
from abc import ABC
import pprint
import inspect
from dataclasses import FrozenInstanceError
import pickle
Expand All @@ -10,7 +9,24 @@
global_configs = {}


def set_global(config):
def register_config_class(class_name, class_def):
''' Saves the config class name and definition '''
err_msg = (
f"The config class '{class_name}' has already been registered. "
"Duplicated names aren't allowed. Either change the name or avoid "
"importing the duplicated classes at the same time. "
f"The registered classes are '{registered_config_classes}'")
assert class_name not in registered_config_classes, err_msg

registered_config_classes[class_name] = class_def


def get_registered_config_classes():
return registered_config_classes


def register_globally(config):
''' Registers the config with anyfig to be accessible anywhere '''
assert is_config_class(config), "Can only register anyfig config object"
global_configs[type(config).__name__] = config

Expand All @@ -33,9 +49,10 @@ def cfg():


def is_config_class(obj):
''' Returns True if object is a config class that is registered with anyfig '''
return inspect.isclass(
type(obj)) and type(obj).__name__ in registered_config_classes
''' Returns True if config class definition registered with anyfig '''
if obj.__class__ != type.__class__:
obj = type(obj)
return inspect.isclass(obj) and obj.__name__ in registered_config_classes


def load_config(path):
Expand All @@ -58,54 +75,27 @@ def save_config(obj, path):
f.write(str(obj))


def register_config_class(class_name, class_def):
err_msg = (
f"The config class '{class_name}' has already been registered. "
"Duplicated names aren't allowed. Either change the name or avoid "
"importing the duplicated classes at the same time. "
f"The registered classes are '{registered_config_classes}'")
assert class_name not in registered_config_classes, err_msg

registered_config_classes[class_name] = class_def


def get_registered_config_classes():
return registered_config_classes


# Class which is used to define functions that goes into every config class
class MasterConfig(ABC):
def frozen(self, freeze=True):
self._frozen = freeze
''' Freeze/unfreeze config '''
self.__class__._frozen = freeze
return self

def get_parameters(self):
params = copy.deepcopy(self.__dict__)
params.pop('_frozen')
return params
return copy.deepcopy(self.__dict__)

def __str__(self):
str_ = ""
params = vars(self)
for key, val in params.items():
if key == '_frozen': # Dont print frozen
continue
if hasattr(val, '__anyfig_print_source__'):
cls_str = val.__anyfig_print_source__()
s = f"'{key}':\n{cls_str}"

# Recursively print anyfig classes
elif is_config_class(val):
inner_config_string = '\n' + str(val)
inner_config_string = inner_config_string.replace('\n', '\n\t')
s = f"'{key}':{inner_config_string}"

else:
s = pprint.pformat({key: str(val)})
# Prettyprint adds some extra wings that I dont like
s = s.lstrip('{').rstrip('}')
str_ += s + '\n'

return str_
''' Prettyprints the config '''
lines = [self.__class__.__name__ + ':']

for key, val in self.__dict__.items():
val_str = f'{val}'
if is_config_class(val): # Remove class name info
val_str = f'{val}'.replace(f'{val.__class__.__name__}:', '')
lines += f'{key} ({val.__class__.__name__}): {val_str}'.split('\n')

return '\n '.join(lines) + '\n'

def __setattr__(self, name, value):
# Raise error if frozen unless we're trying to unfreeze the config
Expand Down
25 changes: 0 additions & 25 deletions example/3_eval_input.py

This file was deleted.

Loading

0 comments on commit be92ee2

Please sign in to comment.