Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Some improvements to config.json readability #44

Merged
merged 1 commit into from
Jun 28, 2024
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
48 changes: 46 additions & 2 deletions eole/config/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import os
from collections import OrderedDict
from pydantic import Field
from eole.config.config import Config
from eole.utils.logging import logger
Expand All @@ -9,6 +10,47 @@
os.environ["EOLE_MODEL_DIR"] = os.getcwd()


def calculate_depth(value, current_depth=0):
if isinstance(value, (dict, Config)):
if isinstance(value, Config):
value = value.__dict__
return (
max(calculate_depth(v, current_depth + 1) for v in value.values())
if value
else current_depth
)
return current_depth


def reorder_fields(fields):
"""
Put non nested fields before nested ones in config json dump,
for better readability.
"""
ordered_fields = OrderedDict()
non_nested_fields = []
nested_fields = []

for field, value in fields.items():
if isinstance(value, dict):
nested_fields.append((field, value))
else:
non_nested_fields.append((field, value))

# Add non-nested fields first
for field, value in non_nested_fields:
ordered_fields[field] = value

# Calculate depths and sort nested fields
nested_fields.sort(key=lambda x: calculate_depth(x[1]))

# Add nested fields
for field, value in nested_fields:
ordered_fields[field] = reorder_fields(value)

return ordered_fields


def recursive_model_fields_set(model):
fields = {}
if isinstance(model, dict):
Expand All @@ -31,10 +73,12 @@ def recursive_model_fields_set(model):
else:
field_value = getattr(model, field, None)
if isinstance(field_value, Config) or isinstance(field_value, dict):
fields[field] = recursive_model_fields_set(field_value)
_fields = recursive_model_fields_set(field_value)
if _fields != {}:
fields[field] = _fields
else:
fields[field] = field_value
return fields
return reorder_fields(fields)


def recursive_update_dict(_dict, new_dict, defaults):
Expand Down
Loading