diff --git a/UpgradingACA-Py.md b/UpgradingACA-Py.md index a829bc4f7f..c1ef822527 100644 --- a/UpgradingACA-Py.md +++ b/UpgradingACA-Py.md @@ -76,7 +76,19 @@ In case, running multiple tags [say test1 & test2]: ./scripts/run_docker upgrade --force-upgrade --named-tag test1 --named-tag test2 ``` +## Subwallet upgrades +With multitenant enabled, there is a subwallet associated with each tenant profile, so there is a need to upgrade those sub wallets in addition to the base wallet associated with root profile. +There are 2 options to perform such upgrades: + - `--upgrade-all-subwallets` + + This will apply the upgrade steps to all sub wallets [tenant profiles] and the base wallet [root profiles]. + + - `--upgrade-subwallet` + + This will apply the upgrade steps to specified sub wallets [identified by wallet id] and the base wallet. + + Note: multiple specification allowed ## Exceptions diff --git a/aries_cloudagent/commands/tests/test_upgrade.py b/aries_cloudagent/commands/tests/test_upgrade.py index d6c5c66e08..55451387c0 100644 --- a/aries_cloudagent/commands/tests/test_upgrade.py +++ b/aries_cloudagent/commands/tests/test_upgrade.py @@ -4,9 +4,11 @@ from ...core.in_memory import InMemoryProfile from ...connections.models.conn_record import ConnRecord -from ...storage.base import BaseStorage +from ...storage.base import BaseStorage, BaseStorageSearch +from ...storage.in_memory import InMemoryStorage from ...storage.record import StorageRecord from ...version import __version__ +from ...wallet.models.wallet_record import WalletRecord from .. import upgrade as test_module from ..upgrade import UpgradeError @@ -16,12 +18,33 @@ class TestUpgrade(AsyncTestCase): async def setUp(self): self.session = InMemoryProfile.test_session() self.profile = self.session.profile + self.profile.context.injector.bind_instance( + BaseStorageSearch, InMemoryStorage(self.profile) + ) self.storage = self.session.inject(BaseStorage) record = StorageRecord( "acapy_version", "v0.7.2", ) await self.storage.add_record(record) + recs = [ + WalletRecord( + key_management_mode=[ + WalletRecord.MODE_UNMANAGED, + WalletRecord.MODE_MANAGED, + ][i], + settings={ + "wallet.name": f"my-wallet-{i}", + "wallet.type": "indy", + "wallet.key": f"dummy-wallet-key-{i}", + }, + wallet_name=f"my-wallet-{i}", + ) + for i in range(2) + ] + async with self.profile.session() as session: + for rec in recs: + await rec.save(session) def test_bad_calls(self): with self.assertRaises(SystemExit): @@ -85,6 +108,63 @@ async def test_upgrade_from_version(self): profile=self.profile, ) + async def test_upgrade_all_subwallets(self): + self.profile.settings.extend( + { + "upgrade.from_version": "v0.7.2", + "upgrade.upgrade_all_subwallets": True, + "upgrade.force_upgrade": True, + "upgrade.page_size": 1, + } + ) + with async_mock.patch.object( + ConnRecord, + "query", + async_mock.CoroutineMock(return_value=[ConnRecord()]), + ), async_mock.patch.object(ConnRecord, "save", async_mock.CoroutineMock()): + await test_module.upgrade( + profile=self.profile, + ) + + async def test_upgrade_specified_subwallets(self): + wallet_ids = [] + async with self.profile.session() as session: + wallet_recs = await WalletRecord.query(session, tag_filter={}) + for wallet_rec in wallet_recs: + wallet_ids.append(wallet_rec.wallet_id) + self.profile.settings.extend( + { + "upgrade.named_tags": "fix_issue_rev_reg", + "upgrade.upgrade_subwallets": [wallet_ids[0]], + "upgrade.force_upgrade": True, + } + ) + with async_mock.patch.object( + ConnRecord, + "query", + async_mock.CoroutineMock(return_value=[ConnRecord()]), + ), async_mock.patch.object(ConnRecord, "save", async_mock.CoroutineMock()): + await test_module.upgrade( + profile=self.profile, + ) + + self.profile.settings.extend( + { + "upgrade.named_tags": "fix_issue_rev_reg", + "upgrade.upgrade_subwallets": wallet_ids, + "upgrade.force_upgrade": True, + "upgrade.page_size": 1, + } + ) + with async_mock.patch.object( + ConnRecord, + "query", + async_mock.CoroutineMock(return_value=[ConnRecord()]), + ), async_mock.patch.object(ConnRecord, "save", async_mock.CoroutineMock()): + await test_module.upgrade( + profile=self.profile, + ) + async def test_upgrade_callable(self): version_storage_record = await self.storage.find_record( type_filter="acapy_version", tag_query={} @@ -412,7 +492,7 @@ async def test_upgrade_x_invalid_config(self): async_mock.MagicMock(return_value={}), ): with self.assertRaises(UpgradeError) as ctx: - await test_module.upgrade(settings={}) + await test_module.upgrade(profile=self.profile) assert "No version configs found in" in str(ctx.exception) async def test_upgrade_x_params(self): diff --git a/aries_cloudagent/commands/upgrade.py b/aries_cloudagent/commands/upgrade.py index bf2d0dc95a..5d81db13fe 100644 --- a/aries_cloudagent/commands/upgrade.py +++ b/aries_cloudagent/commands/upgrade.py @@ -22,23 +22,26 @@ from ..core.profile import Profile, ProfileSession from ..config import argparse as arg +from ..config.injection_context import InjectionContext from ..config.default_context import DefaultContextBuilder from ..config.base import BaseError, BaseSettings from ..config.util import common_config from ..config.wallet import wallet_config from ..messaging.models.base import BaseModelError from ..messaging.models.base_record import BaseRecord, RecordType -from ..storage.base import BaseStorage +from ..storage.base import BaseStorage, BaseStorageSearch from ..storage.error import StorageNotFoundError from ..storage.record import StorageRecord from ..revocation.models.issuer_rev_reg_record import IssuerRevRegRecord from ..utils.classloader import ClassLoader, ClassNotFoundError from ..version import __version__, RECORD_TYPE_ACAPY_VERSION +from ..wallet.models.wallet_record import WalletRecord from . import PROG DEFAULT_UPGRADE_CONFIG_FILE_NAME = "default_version_upgrade_config.yml" LOGGER = logging.getLogger(__name__) +BATCH_SIZE = 25 class ExplicitUpgradeOption(Enum): @@ -239,21 +242,129 @@ def _perform_upgrade( return resave_record_path_sets, executables_call_set +def get_webhook_urls( + base_context: InjectionContext, + wallet_record: WalletRecord, +) -> list: + """Get the webhook urls according to dispatch_type.""" + wallet_id = wallet_record.wallet_id + dispatch_type = wallet_record.wallet_dispatch_type + subwallet_webhook_urls = wallet_record.wallet_webhook_urls or [] + base_webhook_urls = base_context.settings.get("admin.webhook_urls", []) + + if dispatch_type == "both": + webhook_urls = list(set(base_webhook_urls) | set(subwallet_webhook_urls)) + if not webhook_urls: + LOGGER.warning( + "No webhook URLs in context configuration " + f"nor wallet record {wallet_id}, but wallet record " + f"configures dispatch type {dispatch_type}" + ) + elif dispatch_type == "default": + webhook_urls = subwallet_webhook_urls + if not webhook_urls: + LOGGER.warning( + f"No webhook URLs in nor wallet record {wallet_id}, but " + f"wallet record configures dispatch type {dispatch_type}" + ) + else: + webhook_urls = base_webhook_urls + return webhook_urls + + +async def get_wallet_profile( + base_context: InjectionContext, + wallet_record: WalletRecord, + extra_settings: dict = {}, +) -> Profile: + """Get profile for a wallet record.""" + context = base_context.copy() + reset_settings = { + "wallet.recreate": False, + "wallet.seed": None, + "wallet.rekey": None, + "wallet.name": None, + "wallet.type": None, + "mediation.open": None, + "mediation.invite": None, + "mediation.default_id": None, + "mediation.clear": None, + } + extra_settings["admin.webhook_urls"] = get_webhook_urls(base_context, wallet_record) + + context.settings = ( + context.settings.extend(reset_settings) + .extend(wallet_record.settings) + .extend(extra_settings) + ) + + profile, _ = await wallet_config(context, provision=False) + return profile + + async def upgrade( settings: Optional[Union[Mapping[str, Any], BaseSettings]] = None, profile: Optional[Profile] = None, +): + """Invoke upgradation process for each applicable profile.""" + profiles_to_upgrade = [] + if settings: + batch_size = settings.get("upgrade.page_size", BATCH_SIZE) + else: + batch_size = BATCH_SIZE + if profile and (settings or settings == {}): + raise UpgradeError("upgrade requires either profile or settings, not both.") + if profile: + root_profile = profile + settings = profile.settings + else: + context_builder = DefaultContextBuilder(settings) + context = await context_builder.build_context() + root_profile, _ = await wallet_config(context) + profiles_to_upgrade.append(root_profile) + base_storage_search_inst = root_profile.inject(BaseStorageSearch) + if "upgrade.upgrade_all_subwallets" in settings and settings.get( + "upgrade.upgrade_all_subwallets" + ): + search_session = base_storage_search_inst.search_records( + type_filter=WalletRecord.RECORD_TYPE, page_size=batch_size + ) + while search_session._done is False: + wallet_storage_records = await search_session.fetch() + for wallet_storage_record in wallet_storage_records: + wallet_record = WalletRecord.from_storage( + wallet_storage_record.id, + json.loads(wallet_storage_record.value), + ) + wallet_profile = await get_wallet_profile( + base_context=root_profile.context, wallet_record=wallet_record + ) + profiles_to_upgrade.append(wallet_profile) + del settings["upgrade.upgrade_all_subwallets"] + if ( + "upgrade.upgrade_subwallets" in settings + and len(settings.get("upgrade.upgrade_subwallets")) >= 1 + ): + for _wallet_id in settings.get("upgrade.upgrade_subwallets"): + async with root_profile.session() as session: + wallet_record = await WalletRecord.retrieve_by_id( + session, record_id=_wallet_id + ) + wallet_profile = await get_wallet_profile( + base_context=root_profile.context, wallet_record=wallet_record + ) + profiles_to_upgrade.append(wallet_profile) + del settings["upgrade.upgrade_subwallets"] + for _profile in profiles_to_upgrade: + await upgrade_per_profile(profile=_profile, settings=settings) + + +async def upgrade_per_profile( + profile: Profile, + settings: Optional[Union[Mapping[str, Any], BaseSettings]] = None, ): """Perform upgradation steps.""" try: - if profile and (settings or settings == {}): - raise UpgradeError("upgrade requires either profile or settings, not both.") - if profile: - root_profile = profile - settings = profile.settings - else: - context_builder = DefaultContextBuilder(settings) - context = await context_builder.build_context() - root_profile, _ = await wallet_config(context) version_upgrade_config_inst = VersionUpgradeConfig( settings.get("upgrade.config_path") ) @@ -273,7 +384,7 @@ async def upgrade( upgrade_from_version_storage = None upgrade_from_version_config = None upgrade_from_version = None - async with root_profile.session() as session: + async with profile.session() as session: storage = session.inject(BaseStorage) try: version_storage_record = await storage.find_record( @@ -391,8 +502,24 @@ async def upgrade( raise UpgradeError( f"Only BaseRecord can be resaved, found: {str(rec_type)}" ) - async with root_profile.session() as session: - all_records = await rec_type.query(session) + all_records = [] + if settings: + batch_size = settings.get("upgrade.page_size", BATCH_SIZE) + else: + batch_size = BATCH_SIZE + base_storage_search_inst = profile.inject(BaseStorageSearch) + search_session = base_storage_search_inst.search_records( + type_filter=rec_type.RECORD_TYPE, page_size=batch_size + ) + while search_session._done is False: + storage_records = await search_session.fetch() + for storage_record in storage_records: + _record = rec_type.from_storage( + storage_record.id, + json.loads(storage_record.value), + ) + all_records.append(_record) + async with profile.session() as session: for record in all_records: await record.save( session, @@ -406,11 +533,11 @@ async def upgrade( _callable = version_upgrade_config_inst.get_callable(callable_name) if not _callable: raise UpgradeError(f"No function specified for {callable_name}") - await _callable(root_profile) + await _callable(profile) # Update storage version if to_update_flag: - async with root_profile.session() as session: + async with profile.session() as session: storage = session.inject(BaseStorage) if not version_storage_record: await storage.add_record( @@ -428,7 +555,7 @@ async def upgrade( f"set to {upgrade_to_version}" ) if not profile: - await root_profile.close() + await profile.close() except BaseError as e: raise UpgradeError(f"Error during upgrade: {e}") diff --git a/aries_cloudagent/config/argparse.py b/aries_cloudagent/config/argparse.py index 9be7794a77..4247ffa725 100644 --- a/aries_cloudagent/config/argparse.py +++ b/aries_cloudagent/config/argparse.py @@ -2148,6 +2148,33 @@ def add_arguments(self, parser: ArgumentParser): help=("Runs upgrade steps associated with tags provided in the config"), ) + parser.add_argument( + "--upgrade-all-subwallets", + action="store_true", + env_var="ACAPY_UPGRADE_ALL_SUBWALLETS", + help="Apply upgrade to all subwallets and the base wallet", + ) + + parser.add_argument( + "--upgrade-subwallet", + action="append", + env_var="ACAPY_UPGRADE_SUBWALLETS", + help=( + "Apply upgrade to specified subwallets (identified by wallet id)" + " and the base wallet" + ), + ) + + parser.add_argument( + "--upgrade-page-size", + type=str, + env_var="ACAPY_UPGRADE_PAGE_SIZE", + help=( + "Specify page/batch size to process BaseRecords, " + "this provides a way to prevent out-of-memory issues." + ), + ) + def get_settings(self, args: Namespace) -> dict: """Extract ACA-Py upgrade process settings.""" settings = {} @@ -2161,4 +2188,15 @@ def get_settings(self, args: Namespace) -> dict: settings["upgrade.named_tags"] = ( list(args.named_tag) if args.named_tag else [] ) + if args.upgrade_all_subwallets: + settings["upgrade.upgrade_all_subwallets"] = args.upgrade_all_subwallets + if args.upgrade_subwallet: + settings["upgrade.upgrade_subwallets"] = ( + list(args.upgrade_subwallet) if args.upgrade_subwallet else [] + ) + if args.upgrade_page_size: + try: + settings["upgrade.page_size"] = int(args.upgrade_page_size) + except ValueError: + raise ArgsParseError("Parameter --upgrade-page-size must be an integer") return settings diff --git a/aries_cloudagent/config/tests/test_argparse.py b/aries_cloudagent/config/tests/test_argparse.py index 32f9de4ebf..0ddc0eec81 100644 --- a/aries_cloudagent/config/tests/test_argparse.py +++ b/aries_cloudagent/config/tests/test_argparse.py @@ -155,6 +155,75 @@ async def test_upgrade_config(self): == "./aries_cloudagent/config/tests/test-acapy-upgrade-config.yml" ) + result = parser.parse_args( + [ + "--named-tag", + "test_tag_1", + "--named-tag", + "test_tag_2", + "--force-upgrade", + ] + ) + + assert result.named_tag == ["test_tag_1", "test_tag_2"] + assert result.force_upgrade is True + + settings = group.get_settings(result) + + assert settings.get("upgrade.named_tags") == ["test_tag_1", "test_tag_2"] + assert settings.get("upgrade.force_upgrade") is True + + result = parser.parse_args( + [ + "--upgrade-config-path", + "./aries_cloudagent/config/tests/test-acapy-upgrade-config.yml", + "--from-version", + "v0.7.2", + "--upgrade-all-subwallets", + "--force-upgrade", + ] + ) + + assert ( + result.upgrade_config_path + == "./aries_cloudagent/config/tests/test-acapy-upgrade-config.yml" + ) + assert result.force_upgrade is True + assert result.upgrade_all_subwallets is True + + settings = group.get_settings(result) + + assert ( + settings.get("upgrade.config_path") + == "./aries_cloudagent/config/tests/test-acapy-upgrade-config.yml" + ) + assert settings.get("upgrade.force_upgrade") is True + assert settings.get("upgrade.upgrade_all_subwallets") is True + + result = parser.parse_args( + [ + "--named-tag", + "fix_issue_rev_reg", + "--upgrade-subwallet", + "test_wallet_id_1", + "--upgrade-subwallet", + "test_wallet_id_2", + "--force-upgrade", + ] + ) + + assert result.named_tag == ["fix_issue_rev_reg"] + assert result.force_upgrade is True + assert result.upgrade_subwallet == ["test_wallet_id_1", "test_wallet_id_2"] + + settings = group.get_settings(result) + assert settings.get("upgrade.named_tags") == ["fix_issue_rev_reg"] + assert settings.get("upgrade.force_upgrade") is True + assert settings.get("upgrade.upgrade_subwallets") == [ + "test_wallet_id_1", + "test_wallet_id_2", + ] + async def test_outbound_is_required(self): """Test that either -ot or -oq are required""" parser = argparse.create_argument_parser() diff --git a/aries_cloudagent/storage/in_memory.py b/aries_cloudagent/storage/in_memory.py index 84a9ed241f..d2b0671f4f 100644 --- a/aries_cloudagent/storage/in_memory.py +++ b/aries_cloudagent/storage/in_memory.py @@ -255,6 +255,7 @@ def __init__( self.page_size = page_size or DEFAULT_PAGE_SIZE self.tag_query = tag_query self.type_filter = type_filter + self._done = False async def fetch(self, max_count: int = None) -> Sequence[StorageRecord]: """Fetch the next list of results from the store. @@ -270,7 +271,7 @@ async def fetch(self, max_count: int = None) -> Sequence[StorageRecord]: StorageSearchError: If the search query has not been opened """ - if self._cache is None: + if self._cache is None and self._done: raise StorageSearchError("Search query is complete") ret = [] @@ -291,9 +292,11 @@ async def fetch(self, max_count: int = None) -> Sequence[StorageRecord]: if not ret: self._cache = None + self._done = True return ret async def close(self): """Dispose of the search query.""" self._cache = None + self._done = True