Skip to content

Commit

Permalink
fix(registry): push to remote automatically only on cloned repos (#417)
Browse files Browse the repository at this point in the history
Fixes #405 

We've changed the semantic of the operation during the recent migration
to `scmrepo`. The key thing is the `has_remote(reg.scm)` calls and
implementation. I think the original intention was to push automatically
when we run GTO operations on a remote repo (means we are cloning it
into a temp dir, do some op, and push the result back).

## TODO

- [x] Tests
- [x] Review CLI option description
- [x] Review docs

## Docs

Relevant docs update is here
iterative/dvc.org#4879
  • Loading branch information
shcheklein authored Sep 23, 2023
1 parent 1b5c727 commit 7da10ee
Show file tree
Hide file tree
Showing 9 changed files with 124 additions and 29 deletions.
11 changes: 5 additions & 6 deletions gto/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,6 @@
parse_shortcut,
)
from gto.exceptions import NoRepo, NotImplementedInGTO, RefNotFound, WrongArgs
from gto.git_utils import has_remote
from gto.index import Artifact, RepoIndexManager
from gto.registry import GitRegistry
from gto.tag import parse_name as parse_tag_name
Expand Down Expand Up @@ -97,7 +96,7 @@ def register(
bump_major=bump_major,
bump_minor=bump_minor,
bump_patch=bump_patch,
push=push or has_remote(reg.scm),
push=push,
stdout=stdout,
author=author,
author_email=author_email,
Expand Down Expand Up @@ -131,7 +130,7 @@ def assign(
message=message,
simple=simple,
force=force,
push=push or has_remote(reg.scm),
push=push,
skip_registration=skip_registration,
stdout=stdout,
author=author,
Expand Down Expand Up @@ -165,7 +164,7 @@ def unassign(
simple=simple if simple is not None else False,
force=force,
delete=delete,
push=push or has_remote(reg.scm),
push=push,
author=author,
author_email=author_email,
)
Expand Down Expand Up @@ -195,7 +194,7 @@ def deregister(
simple=simple if simple is not None else True,
force=force,
delete=delete,
push=push or has_remote(reg.scm),
push=push,
author=author,
author_email=author_email,
)
Expand Down Expand Up @@ -223,7 +222,7 @@ def deprecate(
simple=simple,
force=force,
delete=delete,
push=push or has_remote(reg.scm),
push=push,
author=author,
author_email=author_email,
)
Expand Down
2 changes: 1 addition & 1 deletion gto/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -375,7 +375,7 @@ def callback_sort( # pylint: disable=inconsistent-return-statements
False,
"--push",
is_flag=True,
help="Push created git tag to `origin` (done automatically for remote repo)",
help="Push created git tag to `origin` (ignored if `repo` option is a remote URL)",
)
option_commit = Option(
False,
Expand Down
25 changes: 13 additions & 12 deletions gto/git_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
from tempfile import TemporaryDirectory
from typing import Optional, Union

from scmrepo.exceptions import InvalidRemote, SCMError
from scmrepo.exceptions import SCMError
from scmrepo.git import Git, SyncStatus

from gto.config import RegistryConfig
Expand All @@ -17,7 +17,17 @@
class RemoteRepoMixin:
@classmethod
@contextmanager
def from_scm(cls, scm: Git, config: Optional[RegistryConfig] = None):
def from_scm(
cls,
scm: Git,
cloned: bool = False,
config: Optional[RegistryConfig] = None,
):
"""
`cloned` - scm is a remote repo that was cloned locally into a tmp
directory to be used for the duration of the context manager.
Means that we push tags and changes back to the remote repo.
"""
raise NotImplementedError()

@classmethod
Expand Down Expand Up @@ -51,7 +61,7 @@ def from_url(
with cloned_git_repo(url_or_scm) as scm:
if branch:
scm.checkout(branch)
with cls.from_scm(scm=scm, config=config) as obj:
with cls.from_scm(scm=scm, config=config, cloned=True) as obj:
yield obj

def _call_commit_push(
Expand Down Expand Up @@ -152,12 +162,3 @@ def git_add_and_commit_all_changes(scm: Git, message: str) -> None:
def _reset_repo_to_head(scm: Git) -> None:
if scm.stash.push(include_untracked=True):
scm.stash.drop()


def has_remote(scm: Git, remote: str = "origin") -> bool:
try:
scm.validate_git_remote(remote)
return True
except InvalidRemote:
pass
return False
23 changes: 17 additions & 6 deletions gto/index.py
Original file line number Diff line number Diff line change
Expand Up @@ -319,16 +319,22 @@ def artifact_centric_representation(self):

class RepoIndexManager(FileIndexManager, RemoteRepoMixin):
scm: Git
cloned: bool

def __init__(self, scm: Git, config):
super().__init__(scm=scm, config=config) # type: ignore[call-arg]
def __init__(self, scm: Git, cloned: bool, config):
super().__init__(scm=scm, cloned=cloned, config=config) # type: ignore[call-arg]

@classmethod
@contextmanager
def from_scm(cls, scm: Git, config: Optional[RegistryConfig] = None):
def from_scm(
cls,
scm: Git,
cloned: bool = False,
config: Optional[RegistryConfig] = None,
):
if config is None:
config = read_registry_config(os.path.join(scm.root_dir, CONFIG_FILE_NAME))
yield cls(scm=scm, config=config)
yield cls(scm=scm, cloned=cloned, config=config)

def add(
self,
Expand All @@ -351,7 +357,7 @@ def add(
commit=commit,
commit_message=commit_message
or generate_annotate_commit_message(name=name, type=type, path=path),
push=push,
push=push or self.cloned,
stdout=stdout,
name=name,
type=type,
Expand Down Expand Up @@ -458,7 +464,12 @@ class EnrichmentManager(BaseManager, RemoteRepoMixin):

@classmethod
@contextmanager
def from_scm(cls, scm: Git, config: Optional[RegistryConfig] = None):
def from_scm(
cls,
scm: Git,
cloned: Optional[bool] = False,
config: Optional[RegistryConfig] = None,
):
if config is None:
config = read_registry_config(os.path.join(scm.root_dir, CONFIG_FILE_NAME))
yield cls(scm=scm, config=config)
Expand Down
11 changes: 9 additions & 2 deletions gto/registry.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,7 @@

class GitRegistry(BaseModel, RemoteRepoMixin):
scm: Git
cloned: bool
artifact_manager: TagArtifactManager
version_manager: TagVersionManager
stage_manager: TagStageManager
Expand All @@ -54,12 +55,18 @@ class Config:

@classmethod
@contextmanager
def from_scm(cls, scm: Git, config: Optional[RegistryConfig] = None):
def from_scm(
cls,
scm: Git,
cloned: bool = False,
config: Optional[RegistryConfig] = None,
):
if config is None:
config = read_registry_config(os.path.join(scm.root_dir, CONFIG_FILE_NAME))

yield cls(
scm=scm,
cloned=cloned,
config=config,
artifact_manager=TagArtifactManager(scm=scm, config=config),
version_manager=TagVersionManager(scm=scm, config=config),
Expand Down Expand Up @@ -572,7 +579,7 @@ def get_stages(self, allowed: bool = False, used: bool = False):
def _push_tag_or_echo_reminder(
self, tag_name: str, push: bool, stdout: bool, delete: bool = False
) -> None:
if push:
if push or self.cloned:
if stdout:
echo(
f"Running `git push{' --delete ' if delete else ' '}origin {tag_name}`"
Expand Down
2 changes: 0 additions & 2 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,8 +24,6 @@
"pytest-mock",
"pytest-test-utils",
"pylint==2.17.5",
# we use this to suppress pytest-related false positives in our tests.
"pylint-pytest",
# we use this to suppress some messages in tests, eg: foo/bar naming,
# and, protected method calls in our tests
"pylint-plugin-utils",
Expand Down
15 changes: 15 additions & 0 deletions tests/test_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,13 +8,15 @@

import pytest
from freezegun import freeze_time
from pytest_mock import MockFixture
from pytest_test_utils import TmpDir
from scmrepo.git import Git

import gto
import tests.resources
from gto.api import show
from gto.exceptions import RefNotFound, WrongArgs
from gto.git_utils import cloned_git_repo
from gto.index import RepoIndexManager
from gto.tag import find
from gto.versions import SemVer
Expand Down Expand Up @@ -590,3 +592,16 @@ def test_if_unassign_with_remote_repo_then_invoke_git_push_tag(tmp_dir: TmpDir):
tag_name="churn#staging!#3",
delete=False,
)


def test_action_doesnt_push_even_if_repo_has_remotes_set(mocker: MockFixture):
# test for https://github.com/iterative/gto/issues/405
with cloned_git_repo(tests.resources.SAMPLE_REMOTE_REPO_URL) as scm:
mocked_git_push_tag = mocker.patch("gto.registry.git_push_tag")
gto.api.unassign(
repo=scm,
name="churn",
stage="staging",
version="v3.1.0",
)
mocked_git_push_tag.assert_not_called()
15 changes: 15 additions & 0 deletions tests/test_index.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from typing import Sequence

import pytest
from pytest_mock import MockFixture
from pytest_test_utils import TmpDir
from scmrepo.git import Git

Expand Down Expand Up @@ -111,3 +112,17 @@ def test_check_existence_no_repo(tmp_dir: TmpDir):
tmp_dir.gen("m1.txt", "some content")
assert check_if_path_exists(tmp_dir / "m1.txt")
assert not check_if_path_exists(tmp_dir / "not" / "exists")


def test_from_url_sets_cloned_property(tmp_dir: TmpDir, scm: Git, mocker: MockFixture):
with RepoIndexManager.from_url(tmp_dir) as idx:
assert idx.cloned is False

with RepoIndexManager.from_url(scm) as idx:
assert idx.cloned is False

cloned_git_repo_mock = mocker.patch("gto.git_utils.cloned_git_repo")
cloned_git_repo_mock.return_value.__enter__.return_value = scm

with RepoIndexManager.from_url("https://github.com/iterative/gto") as idx:
assert idx.cloned is True
49 changes: 49 additions & 0 deletions tests/test_registry.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,9 @@
from typing import Dict, List

import pytest
from pytest_mock import MockFixture
from pytest_test_utils import TmpDir
from scmrepo.git import Git

from gto.registry import GitRegistry

Expand Down Expand Up @@ -443,3 +445,50 @@ def test_registry_state_tag_tag(tmp_dir: TmpDir):
check_obj(
appeared["stages"][key], expected["stages"][key], exclude["stages"]
)


def test_from_url_sets_cloned_property(tmp_dir: TmpDir, scm: Git, mocker: MockFixture):
with GitRegistry.from_url(tmp_dir) as reg:
assert reg.cloned is False

with GitRegistry.from_url(scm) as reg:
assert reg.cloned is False

cloned_git_repo_mock = mocker.patch("gto.git_utils.cloned_git_repo")
cloned_git_repo_mock.return_value.__enter__.return_value = scm

with GitRegistry.from_url("https://github.com/iterative/gto") as reg:
assert reg.cloned is True


# Some method parameters (model names, versions, revs, etc) depend and set by
# the `showcase` fixture setup in the conftest.py.
@pytest.mark.parametrize(
"method,args,kwargs",
[
("register", ["new_model", "HEAD"], {}),
("deregister", ["nn"], {"version": "v0.0.1"}),
("assign", ["nn", "new_stage"], {"version": "v0.0.1"}),
("unassign", ["nn", "staging"], {"version": "v0.0.1"}),
("deprecate", ["nn"], {}),
],
)
@pytest.mark.usefixtures("showcase")
def test_tag_is_pushed_if_cloned_is_set(
tmp_dir: TmpDir,
mocker: MockFixture,
method,
args,
kwargs,
):
with GitRegistry.from_url(tmp_dir) as reg:
# imitate that we are doing actions on the remote repo
assert reg.cloned is False
reg.cloned = True

# check that it attempts to push tag to a remote repo, even if
# push=False is set in a call. `cloned` overrides it in this case
git_push_tag_mock = mocker.patch("gto.registry.git_push_tag")
kwargs["push"] = False
getattr(reg, method)(*args, **kwargs)
git_push_tag_mock.assert_called_once()

0 comments on commit 7da10ee

Please sign in to comment.