Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

autoinstall: Don't use snap env when invoking early and late commands #1811

Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 11 additions & 1 deletion subiquity/server/controllers/cmdlist.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
import asyncio
import os
import shlex
import shutil
from typing import List, Sequence, Union

import attr
Expand All @@ -25,7 +26,7 @@
from subiquity.server.controller import NonInteractiveController
from subiquitycore.async_helpers import run_bg_task
from subiquitycore.context import with_context
from subiquitycore.utils import arun_command
from subiquitycore.utils import arun_command, orig_environ


@attr.s(auto_attribs=True)
Expand Down Expand Up @@ -78,6 +79,15 @@ async def run(self, context):
env = self.env()
for i, cmd in enumerate(tuple(self.builtin_cmds) + tuple(self.cmds)):
desc = cmd.desc()

# If the path to the command isn't found on the snap we should
# drop the snap specific environment variables.
command = shlex.split(desc)[0]
path = shutil.which(command)
if path is not None:
if not path.startswith("/snap"):
env = orig_environ(env)

with context.child("command_{}".format(i), desc):
args = cmd.as_args_list()
if self.syslog_id:
Expand Down
84 changes: 84 additions & 0 deletions subiquity/server/controllers/tests/test_cmdlist.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,84 @@
# Copyright 2023 Canonical, Ltd.
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU Affero General Public License as
# published by the Free Software Foundation, either version 3 of the
# License, or (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU Affero General Public License for more details.
#
# You should have received a copy of the GNU Affero General Public License
# along with this program. If not, see <http://www.gnu.org/licenses/>.

from unittest import IsolatedAsyncioTestCase, mock

from subiquity.server.controllers import cmdlist
from subiquity.server.controllers.cmdlist import (
CmdListController,
Command,
EarlyController,
LateController,
)
from subiquitycore.tests.mocks import make_app
from subiquitycore.utils import orig_environ


@mock.patch.object(cmdlist, "orig_environ", side_effect=orig_environ)
@mock.patch.object(cmdlist, "arun_command")
class TestCmdListController(IsolatedAsyncioTestCase):
controller_type = CmdListController

def setUp(self):
self.controller = self.controller_type(make_app())
self.controller.cmds = [Command(args="some-command", check=False)]
snap_env = {
"LD_LIBRARY_PATH": "/var/lib/snapd/lib/gl",
}
self.mocked_os_environ = mock.patch.dict("os.environ", snap_env)

@mock.patch("shutil.which", return_value="/usr/bin/path/to/bin")
async def test_no_snap_env_on_call(
self,
mocked_shutil,
mocked_arun,
mocked_orig_environ,
):
with self.mocked_os_environ:
await self.controller.run()
args, kwargs = mocked_arun.call_args
call_env = kwargs["env"]

mocked_orig_environ.assert_called()
self.assertNotIn("LD_LIBRARY_PATH", call_env)

@mock.patch("shutil.which", return_value="/snap/path/to/bin")
async def test_with_snap_env_on_call(
self,
mocked_shutil,
mocked_arun,
mocked_orig_environ,
):
with self.mocked_os_environ:
await self.controller.run()
args, kwargs = mocked_arun.call_args
call_env = kwargs["env"]

mocked_orig_environ.assert_not_called()
self.assertIn("LD_LIBRARY_PATH", call_env)


class TestEarlyController(TestCmdListController):
controller_type = EarlyController

def setUp(self):
super().setUp()


class TestLateController(TestCmdListController):
controller_type = LateController

def setUp(self):
super().setUp()
4 changes: 0 additions & 4 deletions subiquity/server/controllers/tests/test_filesystem.py
Original file line number Diff line number Diff line change
Expand Up @@ -91,8 +91,6 @@ def setUp(self):
self.app = make_app()
self.app.opts.bootloader = "UEFI"
self.app.command_runner = mock.AsyncMock()
self.app.report_start_event = mock.Mock()
self.app.report_finish_event = mock.Mock()
self.app.prober = mock.Mock()
self.app.prober.get_storage = mock.AsyncMock()
self.app.block_log_dir = "/inexistent"
Expand Down Expand Up @@ -1177,8 +1175,6 @@ def setUp(self):
self.app = make_app()
self.app.command_runner = mock.AsyncMock()
self.app.opts.bootloader = "UEFI"
self.app.report_start_event = mock.Mock()
self.app.report_finish_event = mock.Mock()
self.app.prober = mock.Mock()
self.app.prober.get_storage = mock.AsyncMock()
self.app.snapdapi = snapdapi.make_api_client(AsyncSnapd(get_fake_connection()))
Expand Down
4 changes: 0 additions & 4 deletions subiquity/server/controllers/tests/test_install.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,8 +34,6 @@ def setUp(self):
self.controller = InstallController(make_app())
self.controller.write_config = unittest.mock.Mock()
self.controller.app.note_file_for_apport = Mock()
self.controller.app.report_start_event = Mock()
self.controller.app.report_finish_event = Mock()

self.controller.model.target = "/target"

Expand Down Expand Up @@ -199,8 +197,6 @@ def test_generic_config(self):
class TestInstallController(unittest.IsolatedAsyncioTestCase):
def setUp(self):
self.controller = InstallController(make_app())
self.controller.app.report_start_event = Mock()
self.controller.app.report_finish_event = Mock()
self.controller.model.target = tempfile.mkdtemp()
os.makedirs(os.path.join(self.controller.model.target, "etc/grub.d"))
self.addCleanup(shutil.rmtree, self.controller.model.target)
Expand Down
2 changes: 0 additions & 2 deletions subiquity/server/controllers/tests/test_refresh.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,8 +26,6 @@
class TestRefreshController(SubiTestCase):
def setUp(self):
self.app = make_app()
self.app.report_start_event = mock.Mock()
self.app.report_finish_event = mock.Mock()
self.app.note_data_for_apport = mock.Mock()
self.app.prober = mock.Mock()
self.app.snapdapi = snapdapi.make_api_client(AsyncSnapd(get_fake_connection()))
Expand Down
4 changes: 1 addition & 3 deletions subiquity/server/controllers/tests/test_snaplist.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
# along with this program. If not, see <http://www.gnu.org/licenses/>.

import unittest
from unittest.mock import AsyncMock, Mock
from unittest.mock import AsyncMock

import requests

Expand All @@ -31,8 +31,6 @@ def setUp(self):
self.model = SnapListModel()
self.app = make_app()
self.app.snapd = AsyncMock()
self.app.report_start_event = Mock()
self.app.report_finish_event = Mock()

self.loader = SnapdSnapInfoLoader(
self.model, self.app.snapd, "server", self.app.context
Expand Down
5 changes: 5 additions & 0 deletions subiquitycore/tests/mocks.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,4 +43,9 @@ def make_app(model=None):
app.opts = mock.Mock()
app.opts.dry_run = True
app.scale_factor = 1000
app.echo_syslog_id = None
app.log_syslog_id = None
app.report_start_event = mock.Mock()
app.report_finish_event = mock.Mock()
Chris-Peterson444 marked this conversation as resolved.
Show resolved Hide resolved

return app
13 changes: 8 additions & 5 deletions subiquitycore/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@
import os
import random
import subprocess
from typing import Any, Dict, List, Sequence
from typing import Any, Dict, List, Optional, Sequence

log = logging.getLogger("subiquitycore.utils")

Expand All @@ -35,15 +35,18 @@ def _clean_env(env, *, locale=True):
return env


def orig_environ(env):
def orig_environ(env: Optional[Dict[str, str]]) -> Dict[str, str]:
"""Generate an environment dict that is suitable for use for running
programs that live outside the snap."""

if env is None:
env = os.environ
ret = env.copy()
env: Dict[str, str] = os.environ

ret: Dict[str, str] = env.copy()

for key, val in env.items():
if key.endswith("_ORIG"):
key_to_restore = key[: -len("_ORIG")]
key_to_restore: str = key[: -len("_ORIG")]
if val:
ret[key_to_restore] = val
else:
Expand Down