Skip to content

Commit

Permalink
Merge pull request #45 from datarootsio/add-diff-command
Browse files Browse the repository at this point in the history
Add diff command
  • Loading branch information
murilo-cunha committed Nov 7, 2022
2 parents 020958b + f762832 commit 01b9a54
Show file tree
Hide file tree
Showing 16 changed files with 496 additions and 89 deletions.
12 changes: 12 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -92,6 +92,18 @@ $ databooks show [OPTIONS] PATHS...

![databooks show demo](https://raw.githubusercontent.com/datarootsio/databooks/main/docs/images/databooks-show.gif)

### Show rich notebook diffs

Similar to git diff, but for notebooks! Show a rich diff of the notebooks in the
terminal. Works for comparing git index with the current working directory, comparing
branches or blobs.

```console
$ databooks diff [OPTIONS] [REF_BASE] [REF_REMOTE] [PATHS]...
```

![databooks diff demo](https://raw.githubusercontent.com/datarootsio/databooks/main/docs/images/databooks-diff.gif)

## License

This project is licensed under the terms of the MIT license.
84 changes: 76 additions & 8 deletions databooks/cli.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
"""Main CLI application."""
from itertools import compress
from pathlib import Path
from typing import List, Optional
from typing import List, Optional, Tuple

import tomli
from rich.progress import (
Expand All @@ -18,10 +18,11 @@
from databooks.common import expand_paths
from databooks.config import TOML_CONFIG_FILE, get_config
from databooks.conflicts import conflicts2nbs, path2conflicts
from databooks.git_utils import get_nb_diffs
from databooks.logging import get_logger
from databooks.metadata import clear_all
from databooks.recipes import Recipe
from databooks.tui import print_nbs
from databooks.tui import print_diffs, print_nbs
from databooks.version import __version__

logger = get_logger(__file__)
Expand All @@ -47,7 +48,7 @@ def _config_callback(ctx: Context, config_path: Optional[Path]) -> Optional[Path
"""Get config file and inject values into context to override default args."""
target_paths = expand_paths(
paths=[Path(p).resolve() for p in ctx.params.get("paths", ())]
)
) or [Path.cwd()]
config_path = (
get_config(
target_paths=target_paths,
Expand All @@ -57,7 +58,6 @@ def _config_callback(ctx: Context, config_path: Optional[Path]) -> Optional[Path
else config_path
)
logger.debug(f"Loading config file from: {config_path}")

if config_path is not None: # config may not be specified
with config_path.open("rb") as f:
conf = (
Expand All @@ -82,11 +82,27 @@ def _check_paths(paths: List[Path], ignore: List[str]) -> List[Path]:
)
nb_paths = expand_paths(paths=paths, ignore=ignore)
if not nb_paths:
logger.info(f"No notebooks found in {paths}. Nothing to do.")
logger.info(
f"No notebooks found in {[p.resolve() for p in paths]}. Nothing to do."
)
raise Exit()
return nb_paths


def _parse_paths(
*refs: Optional[str], paths: List[Path]
) -> Tuple[Tuple[Optional[str], ...], List[Path]]:
"""Detect paths from `refs` and add to `paths`."""
first, *rest = refs
if first is not None and Path(first).exists():
paths += [Path(first)]
first = None
if rest:
_refs, _paths = _parse_paths(*rest, paths=paths)
return (first, *_refs), _paths
return (first,), paths


@app.callback()
def callback( # noqa: D103
version: Optional[bool] = Option(
Expand Down Expand Up @@ -396,6 +412,58 @@ def show(


@app.command()
def diff() -> None:
"""Show differences between notebooks (not implemented)."""
raise NotImplementedError
def diff(
ref_base: Optional[str] = Argument(
None, help="Base reference (hash, branch, etc.), defaults to index"
),
ref_remote: Optional[str] = Argument(
None, help="Remote reference (hash, branch, etc.), defaults to working tree"
),
paths: List[Path] = Argument(
None, is_eager=True, help="Path(s) of notebook files to compare"
),
ignore: List[str] = Option(["!*"], help="Glob expression(s) of files to ignore"),
pager: bool = Option(
False, "--pager", "-p", help="Use pager instead of printing to terminal"
),
verbose: bool = Option(
False, "--verbose", "-v", help="Increase verbosity for debugging"
),
multiple: bool = Option(False, "--yes", "-y", help="Show multiple files"),
config: Optional[Path] = Option(
None,
"--config",
"-c",
is_eager=True,
callback=_config_callback,
resolve_path=True,
exists=True,
help="Get CLI options from configuration file",
),
help: Optional[bool] = Option(
None,
"--help",
is_eager=True,
callback=_help_callback,
help="Show this message and exit",
),
) -> None:
"""
Show differences between notebooks.
This is similar to `git-diff`, but in practice it is a subset of `git-diff`
features - only exception is that we cannot compare diffs between local files. That
means we can compare files that are staged with other branches, hashes, etc., or
compare the current directory with the current index.
"""
(ref_base, ref_remote), paths = _parse_paths(ref_base, ref_remote, paths=paths)
diffs = get_nb_diffs(
ref_base=ref_base, ref_remote=ref_remote, paths=paths, verbose=verbose
)
if not diffs:
logger.info("No notebook diffs found. Nothing to do.")
raise Exit()
if len(diffs) > 1 and not multiple:
if not Confirm.ask(f"Show {len(diffs)} notebook diffs?"):
raise Exit()
print_diffs(diffs=diffs, use_pager=pager)
2 changes: 1 addition & 1 deletion databooks/data_models/cell.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,7 @@ def remove_fields(
fields for cell type.
"""
# Ignore required `BaseCell` fields
cell_fields = self.__fields__ # required fields especified in class definition
cell_fields = BaseCell.__fields__ # required fields
if any(field in fields for field in cell_fields):
logger.debug(
"Ignoring removal of required fields "
Expand Down
3 changes: 3 additions & 0 deletions databooks/data_models/notebook.py
Original file line number Diff line number Diff line change
Expand Up @@ -218,6 +218,9 @@ def _rich(kernel: str) -> Text:

kernelspec = self.metadata.dict().get("kernelspec", {})
if isinstance(kernelspec, tuple): # check if this is a `DiffCells`
kernelspec = tuple(
ks or {"language": "text", "display_name": "null"} for ks in kernelspec
)
lang_first, lang_last = (ks.get("language", "text") for ks in kernelspec)
nb_lang = lang_first if lang_first == lang_last else "text"
if any("display_name" in ks.keys() for ks in kernelspec):
Expand Down
154 changes: 140 additions & 14 deletions databooks/git_utils.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,24 @@
"""Git helper functions."""
from dataclasses import dataclass
from enum import Enum
from pathlib import Path
from typing import Dict, List
from typing import Dict, List, Optional, Sequence, Union, cast, overload

from git import Blob, Git, Repo # type: ignore
from git import Git
from git.diff import DiffIndex
from git.objects.blob import Blob
from git.objects.commit import Commit
from git.objects.tree import Tree
from git.repo import Repo

from databooks.common import find_obj
from databooks.logging import get_logger
from databooks.common import find_common_parent, find_obj
from databooks.logging import get_logger, set_verbose

logger = get_logger(name=__file__)

# https://github.com/python/mypy/issues/5317
ChangeType = Enum("ChangeType", [*DiffIndex.change_type, "U"]) # type: ignore[misc]


@dataclass
class UnmergedBlob:
Expand All @@ -30,14 +39,36 @@ class ConflictFile:
last_contents: str


def get_repo(path: Path = Path.cwd()) -> Repo:
"""Find git repo in current or parent directories."""
repo_dir = find_obj(
obj_name=".git", start=Path(path.anchor), finish=path, is_dir=True
)
repo = Repo(path=repo_dir)
logger.debug(f"Repo found at: {repo.working_dir}")
return repo
@dataclass
class Contents:
"""Container for path of file versions."""

path: Optional[Path]
contents: Optional[str]


@dataclass
class DiffContents:
"""Container for path and different versions of conflicted notebooks."""

a: Contents
b: Contents
change_type: ChangeType


@overload
def blob2str(blob: None) -> None:
...


@overload
def blob2str(blob: Blob) -> str:
...


def blob2str(blob: Optional[Blob]) -> Optional[str]:
"""Get the blob contents if they exist (otherwise return `None`)."""
return blob.data_stream.read() if blob is not None else None


def blob2commit(blob: Blob, repo: Repo) -> str:
Expand All @@ -51,6 +82,43 @@ def blob2commit(blob: Blob, repo: Repo) -> str:
)


def diff2contents(
blob: Blob,
ref: Optional[Union[Tree, Commit, str]],
path: Path,
not_exists: bool = False,
) -> Optional[str]:
"""
Get the blob contents from the diff.
Depends on whether we are diffing against current working tree and if object exists
at diff time (added or deleted objects only exist at one side). If comparing
against working tree (`ref=None`) we return the current file contents.
:param blob: git diff blob
:param ref: git reference
:param path: path to object
:param not_exists: whether object exists at 'diff time' (added or removed objects
do not exist)
:return: blob contents as a string (if exists)
"""
if not_exists:
return None
elif ref is None:
return path.read_text()
else:
return blob2str(blob)


def get_repo(path: Path = Path.cwd()) -> Repo:
"""Find git repo in current or parent directories."""
repo_dir = find_obj(
obj_name=".git", start=Path(path.anchor), finish=path, is_dir=True
)
repo = Repo(path=repo_dir)
logger.debug(f"Repo found at: {repo.working_dir}")
return repo


def get_conflict_blobs(repo: Repo) -> List[ConflictFile]:
"""Get the source files for conflicts."""
unmerged_blobs = repo.index.unmerged_blobs()
Expand All @@ -70,8 +138,66 @@ def get_conflict_blobs(repo: Repo) -> List[ConflictFile]:
filename=repo.working_dir / blob.filename,
first_log=blob2commit(blob=blob.stage[2], repo=repo),
last_log=blob2commit(blob=blob.stage[3], repo=repo),
first_contents=repo.git.show(blob.stage[2]),
last_contents=repo.git.show(blob.stage[3]),
first_contents=blob2str(blob.stage[2]),
last_contents=blob2str(blob.stage[3]),
)
for blob in blobs
]


def get_nb_diffs(
ref_base: Optional[str] = None,
ref_remote: Optional[str] = None,
paths: Sequence[Path] = (),
*,
repo: Optional[Repo] = None,
verbose: bool = False,
) -> List[DiffContents]:
"""
Get the noteebook(s) git diff(s).
By default, diffs are compared with the current working direcotory. That is, staged
files will still show up in the diffs. Only return the diffs for notebook files.
"""
if verbose:
set_verbose(logger)

common_path = find_common_parent(paths or [Path.cwd()])
repo = get_repo(path=common_path) if repo is None else repo
if repo.working_dir is None:
raise ValueError("No repo found - cannot compute diffs.")

ref_base = repo.index if ref_base is None else repo.tree(ref_base)
ref_remote = ref_remote if ref_remote is None else repo.tree(ref_remote)

logger.debug(
f"Looking for diffs on path(s) {[p.resolve() for p in paths]}.\n"
f"Comparing `{ref_base}` and `{ref_remote}`."
)
repo_root_dir = Path(repo.working_dir)
return [
DiffContents(
a=Contents(
path=Path(d.a_path),
contents=diff2contents(
blob=cast(Blob, d.a_blob),
ref=ref_base,
path=repo_root_dir / d.a_path,
not_exists=d.change_type is ChangeType.A, # type: ignore
),
),
b=Contents(
path=Path(d.b_path),
contents=diff2contents(
blob=cast(Blob, d.b_blob),
ref=ref_remote,
path=repo_root_dir / d.b_path,
not_exists=d.change_type is ChangeType.D, # type: ignore
),
),
change_type=ChangeType[d.change_type],
)
for d in ref_base.diff(
other=ref_remote, paths=list(paths) or list(repo_root_dir.rglob("*.ipynb"))
)
]
8 changes: 4 additions & 4 deletions databooks/metadata.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,11 +62,11 @@ def clear(

if nb_equals or check:
msg = (
"only check (unwanted metadata found)."
if not nb_equals
else "no metadata to remove."
"no metadata to remove."
if nb_equals
else "only check (unwanted metadata found)."
)
logger.debug(f"No action taken for {read_path} - " + msg)
logger.debug(f"No action taken for {read_path} - {msg}")
else:
notebook.write(path=write_path, overwrite=overwrite)
logger.debug(f"Removed metadata from {read_path}, saved as {write_path}")
Expand Down
2 changes: 1 addition & 1 deletion databooks/recipes.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,4 +61,4 @@ def _recipes(cls) -> Dict[str, str]:


# https://github.com/python/mypy/issues/5317
Recipe = Enum("Recipe", CookBook._recipes()) # type: ignore
Recipe = Enum("Recipe", CookBook._recipes()) # type: ignore[misc]
Loading

0 comments on commit 01b9a54

Please sign in to comment.