From aa6d387687fc737f74a9e500a53ff3cfc7297cb9 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Fran=C3=A7ois=20Hernandez?= Date: Fri, 28 Jun 2024 15:59:35 +0200 Subject: [PATCH] minimal order in dumped config, for better readability (#44) --- eole/config/__init__.py | 48 +++++++++++++++++++++++++++++++++++++++-- 1 file changed, 46 insertions(+), 2 deletions(-) diff --git a/eole/config/__init__.py b/eole/config/__init__.py index b481a876..920110eb 100644 --- a/eole/config/__init__.py +++ b/eole/config/__init__.py @@ -1,4 +1,5 @@ import os +from collections import OrderedDict from pydantic import Field from eole.config.config import Config from eole.utils.logging import logger @@ -9,6 +10,47 @@ os.environ["EOLE_MODEL_DIR"] = os.getcwd() +def calculate_depth(value, current_depth=0): + if isinstance(value, (dict, Config)): + if isinstance(value, Config): + value = value.__dict__ + return ( + max(calculate_depth(v, current_depth + 1) for v in value.values()) + if value + else current_depth + ) + return current_depth + + +def reorder_fields(fields): + """ + Put non nested fields before nested ones in config json dump, + for better readability. + """ + ordered_fields = OrderedDict() + non_nested_fields = [] + nested_fields = [] + + for field, value in fields.items(): + if isinstance(value, dict): + nested_fields.append((field, value)) + else: + non_nested_fields.append((field, value)) + + # Add non-nested fields first + for field, value in non_nested_fields: + ordered_fields[field] = value + + # Calculate depths and sort nested fields + nested_fields.sort(key=lambda x: calculate_depth(x[1])) + + # Add nested fields + for field, value in nested_fields: + ordered_fields[field] = reorder_fields(value) + + return ordered_fields + + def recursive_model_fields_set(model): fields = {} if isinstance(model, dict): @@ -31,10 +73,12 @@ def recursive_model_fields_set(model): else: field_value = getattr(model, field, None) if isinstance(field_value, Config) or isinstance(field_value, dict): - fields[field] = recursive_model_fields_set(field_value) + _fields = recursive_model_fields_set(field_value) + if _fields != {}: + fields[field] = _fields else: fields[field] = field_value - return fields + return reorder_fields(fields) def recursive_update_dict(_dict, new_dict, defaults):