Skip to content

Commit

Permalink
feat(external_data_files): load data from other YAML files (a.k.a. in…
Browse files Browse the repository at this point in the history
…herit answers from other templates)

When composing templates, it's often needed to be able to load answers from other templates that you know are usually combined with yours. Or any other kind of external data.

@moduon MT-8282
  • Loading branch information
yajo committed Dec 10, 2024
1 parent 507cab3 commit 0651768
Show file tree
Hide file tree
Showing 11 changed files with 265 additions and 43 deletions.
4 changes: 4 additions & 0 deletions copier/errors.py
Original file line number Diff line number Diff line change
Expand Up @@ -131,3 +131,7 @@ class DirtyLocalWarning(UserWarning, CopierWarning):

class ShallowCloneWarning(UserWarning, CopierWarning):
"""The template repository is a shallow clone."""


class MissingFileWarning(UserWarning, CopierWarning):
"""I still couldn't find what I'm looking for."""
78 changes: 54 additions & 24 deletions copier/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@

from jinja2.loaders import FileSystemLoader
from jinja2.sandbox import SandboxedEnvironment
from lazystuff.lazydict import lazydict
from pathspec import PathSpec
from plumbum import ProcessExecutionError, colors
from plumbum.cli.terminal import ask
Expand Down Expand Up @@ -60,11 +61,12 @@
from .types import (
MISSING,
AnyByStrDict,
AnyByStrMutableMapping,
JSONSerializable,
RelativePath,
StrOrPath,
)
from .user_data import DEFAULT_DATA, AnswersMap, Question
from .user_data import AnswersMap, Question, load_answersfile_data
from .vcs import get_git

_T = TypeVar("_T")
Expand Down Expand Up @@ -263,7 +265,27 @@ def _check_unsafe(self, mode: Literal["copy", "update"]) -> None:
if features:
raise UnsafeTemplateError(sorted(features))

def _external_data(self) -> Mapping[str, Any]:
"""Load external data lazily.
Result keys are used for rendering, and values are the parsed contents
of the YAML files specified in [external_data_files][].
Files will only be parsed lazily on 1st access. This helps avoiding
circular dependencies when the file name also comes from a variable.
"""
return lazydict(
{
name: lambda path=path: load_answersfile_data(
self.dst_path, self._render_string(path)
)
for name, path in self.template.external_data_files.items()
}
)

def _print_message(self, message: str) -> None:
# On first use, at least we need the system render context
self.answers.system = self._system_render_context()
if message and not self.quiet:
print(self._render_string(message), file=sys.stderr)

Expand Down Expand Up @@ -330,12 +352,18 @@ def _execute_tasks(self, tasks: Sequence[Task]) -> None:
with local.cwd(working_directory), local.env(**extra_env):
subprocess.run(task_cmd, shell=use_shell, check=True, env=local.env)

def _render_context(self) -> Mapping[str, Any]:
"""Produce render context for Jinja."""
def _system_render_context(self) -> AnyByStrMutableMapping:
"""System reserved render context.
Most keys start with `_` because they're reserved.
Resolution of computed values is deferred until used for the 1st time.
"""
# Backwards compatibility
# FIXME Remove it?
conf = asdict(self)
conf.pop("_cleanup_hooks")
conf.pop("answers")
conf.update(
{
"answers_file": self.answers_relpath,
Expand All @@ -345,12 +373,10 @@ def _render_context(self) -> Mapping[str, Any]:
"os": OS,
}
)

return dict(
DEFAULT_DATA,
**self.answers.combined,
_copier_answers=self._answers_to_remember(),
_copier_conf=conf,
_ext=self._external_data(),
_folder_name=self.subproject.local_abspath.name,
_copier_python=sys.executable,
)
Expand Down Expand Up @@ -455,41 +481,42 @@ def _render_allowed(

def _ask(self) -> None: # noqa: C901
"""Ask the questions of the questionnaire and record their answers."""
result = AnswersMap(
self.answers = AnswersMap(
user_defaults=self.user_defaults,
init=self.data,
last=self.subproject.last_answers,
metadata=self.template.metadata,
system=self._system_render_context(),
)

for var_name, details in self.template.questions_data.items():
question = Question(
answers=result,
answers=self.answers,
jinja_env=self.jinja_env,
var_name=var_name,
**details,
)
# Delete last answer if it cannot be parsed or validated, so a new
# valid answer can be provided.
if var_name in result.last:
if var_name in self.answers.last:
try:
answer = question.parse_answer(result.last[var_name])
answer = question.parse_answer(self.answers.last[var_name])
except Exception:
del result.last[var_name]
del self.answers.last[var_name]
else:
if question.validate_answer(answer):
del result.last[var_name]
del self.answers.last[var_name]

Check warning on line 508 in copier/main.py

View check run for this annotation

Codecov / codecov/patch

copier/main.py#L508

Added line #L508 was not covered by tests
# Skip a question when the skip condition is met.
if not question.get_when():
# Omit its answer from the answers file.
result.hide(var_name)
self.answers.hide(var_name)
# Skip immediately to the next question when it has no default
# value.
if question.default is MISSING:
continue
if var_name in result.init:
if var_name in self.answers.init:
# Try to parse the answer value.
answer = question.parse_answer(result.init[var_name])
answer = question.parse_answer(self.answers.init[var_name])
# Try to validate the answer value if the question has a
# validator.
if err_msg := question.validate_answer(answer):
Expand All @@ -498,10 +525,10 @@ def _ask(self) -> None: # noqa: C901
)
# At this point, the answer value is valid. Do not ask the
# question again, but set answer as the user's answer instead.
result.user[var_name] = answer
self.answers.user[var_name] = answer
continue
# Skip a question when the user already answered it.
if self.skip_answered and var_name in result.last:
if self.skip_answered and var_name in self.answers.last:
continue

# Display TUI and ask user interactively only without --defaults
Expand All @@ -516,10 +543,12 @@ def _ask(self) -> None: # noqa: C901
answers={question.var_name: question.get_default()},
)[question.var_name]
except KeyboardInterrupt as err:
raise CopierAnswersInterrupt(result, question, self.template) from err
result.user[var_name] = new_answer

self.answers = result
raise CopierAnswersInterrupt(
self.answers, question, self.template
) from err
self.answers.user[var_name] = new_answer
# Update system render context, which may depend on answers
self.answers.system = self._system_render_context()

@property
def answers_relpath(self) -> Path:
Expand Down Expand Up @@ -644,7 +673,7 @@ def _render_file(self, src_relpath: Path, dst_relpath: Path) -> None:
# suffix is empty, fallback to copy
new_content = src_abspath.read_bytes()
else:
new_content = tpl.render(**self._render_context()).encode()
new_content = tpl.render(**self.answers.combined).encode()
else:
new_content = src_abspath.read_bytes()
dst_abspath = self.subproject.local_abspath / dst_relpath
Expand Down Expand Up @@ -766,7 +795,7 @@ def _render_string(
Additional variables to use for rendering the template.
"""
tpl = self.jinja_env.from_string(string)
return tpl.render(**self._render_context(), **(extra_context or {}))
return tpl.render(**self.answers.combined, **(extra_context or {}))

def _render_value(
self, value: _T, extra_context: AnyByStrDict | None = None
Expand Down Expand Up @@ -984,7 +1013,7 @@ def _apply_update(self) -> None: # noqa: C901
)
# Clear last answers cache to load possible answers migration, if skip_answered flag is not set
if self.skip_answered is False:
self.answers = AnswersMap()
self.answers = AnswersMap(system=self._system_render_context())
with suppress(AttributeError):
del self.subproject.last_answers
# Do a normal update in final destination
Expand All @@ -1000,6 +1029,7 @@ def _apply_update(self) -> None: # noqa: C901
) as current_worker:
current_worker.run_copy()
self.answers = current_worker.answers
self.answers.system = self._system_render_context()
# Render with the same answers in an empty dir to avoid pollution
with replace(
self,
Expand Down
10 changes: 9 additions & 1 deletion copier/template.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
from functools import cached_property
from pathlib import Path, PurePosixPath
from shutil import rmtree
from typing import Any, Literal, Mapping, Sequence
from typing import Any, Dict, Literal, Mapping, Sequence
from warnings import warn

import dunamai
Expand Down Expand Up @@ -329,6 +329,14 @@ def exclude(self) -> tuple[str, ...]:
)
)

@cached_property
def external_data_files(self) -> Dict[str, str]:
"""Get external data files specified in the template.
See [external_data_files][].
"""
return self.config_data.get("external_data_files", {})

@cached_property
def jinja_extensions(self) -> tuple[str, ...]:
"""Get Jinja2 extensions specified in the template, or `()`.
Expand Down
2 changes: 2 additions & 0 deletions copier/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
Dict,
Literal,
Mapping,
MutableMapping,
NewType,
Optional,
Sequence,
Expand All @@ -19,6 +20,7 @@
# simple types
StrOrPath = Union[str, Path]
AnyByStrDict = Dict[str, Any]
AnyByStrMutableMapping = MutableMapping[str, Any]

# sequences
IntSeq = Sequence[int]
Expand Down
42 changes: 28 additions & 14 deletions copier/user_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,9 +23,9 @@
from pygments.lexers.data import JsonLexer, YamlLexer
from questionary.prompts.common import Choice

from .errors import InvalidTypeError, UserMessageError
from .errors import InvalidTypeError, MissingFileWarning, UserMessageError
from .tools import cast_to_bool, cast_to_str, force_str_end
from .types import MISSING, AnyByStrDict, MissingType, OptStrOrPath, StrOrPath
from .types import MISSING, AnyByStrDict, AnyByStrMutableMapping, MissingType, StrOrPath


# TODO Remove these two functions as well as DEFAULT_DATA in a future release
Expand Down Expand Up @@ -83,17 +83,21 @@ class AnswersMap:
Default data from the user e.g. previously completed and restored data.
See [copier.main.Worker][].
system:
Automatic context generated by the [Worker][copier.main.Worker].
"""

# Private
hidden: set[str] = field(default_factory=set, init=False)

# Public
user: AnyByStrDict = field(default_factory=dict)
init: AnyByStrDict = field(default_factory=dict)
metadata: AnyByStrDict = field(default_factory=dict)
last: AnyByStrDict = field(default_factory=dict)
user_defaults: AnyByStrDict = field(default_factory=dict)
user: AnyByStrMutableMapping = field(default_factory=dict)
init: AnyByStrMutableMapping = field(default_factory=dict)
metadata: AnyByStrMutableMapping = field(default_factory=dict)
last: AnyByStrMutableMapping = field(default_factory=dict)
user_defaults: AnyByStrMutableMapping = field(default_factory=dict)
system: AnyByStrMutableMapping = field(default_factory=dict)

@property
def combined(self) -> Mapping[str, Any]:
Expand All @@ -105,6 +109,7 @@ def combined(self) -> Mapping[str, Any]:
self.metadata,
self.last,
self.user_defaults,
self.system,
DEFAULT_DATA,
)
)
Expand All @@ -125,6 +130,15 @@ class Question:
All attributes are init kwargs.
Attributes:
var_name:
Question name in the answers dict.
answers:
A map containing the answers provided by the user.
jinja_env:
The Jinja environment used to rendering answers.
choices:
Selections available for the user if the question requires them.
Can be templated.
Expand Down Expand Up @@ -155,13 +169,10 @@ class Question:
If the question type is str, it will hide user input on the screen
by displaying asterisks: `****`.
type_name:
type:
The type of question. Affects the rendering, validation and filtering.
Can be templated.
var_name:
Question name in the answers dict.
validator:
Jinja template with which to validate the user input. This template
will be rendered with the combined answers as variables; it should
Expand Down Expand Up @@ -487,13 +498,16 @@ def parse_yaml_string(string: str) -> Any:

def load_answersfile_data(
dst_path: StrOrPath,
answers_file: OptStrOrPath = None,
answers_file: StrOrPath = ".copier-answers.yml",
) -> AnyByStrDict:
"""Load answers data from a `$dst_path/$answers_file` file if it exists."""
try:
with open(Path(dst_path) / (answers_file or ".copier-answers.yml")) as fd:
with open(Path(dst_path) / answers_file) as fd:
return yaml.safe_load(fd)
except FileNotFoundError:
except (FileNotFoundError, IsADirectoryError):
warnings.warn(

Check warning on line 508 in copier/user_data.py

View check run for this annotation

Codecov / codecov/patch

copier/user_data.py#L507-L508

Added lines #L507 - L508 were not covered by tests
f"File not found; returning empty dict: {answers_file}", MissingFileWarning
)
return {}


Expand Down
Loading

0 comments on commit 0651768

Please sign in to comment.