From 8eeedda62f524e845f48e884402091a4c0c0978c Mon Sep 17 00:00:00 2001 From: Jiri Otoupal Date: Thu, 14 Mar 2024 13:31:10 +0100 Subject: [PATCH] bug fix --- abst/__version__.py | 2 +- abst/bastion_support/oci_bastion.py | 6 +- abst/cli_commands/ssh_cli/commands.py | 14 ++- abst/config.py | 1 + abst/sharing/local_broadcast.py | 132 +++++++++-------------- abst/sharing/tests/test_functionality.py | 4 +- 6 files changed, 65 insertions(+), 94 deletions(-) diff --git a/abst/__version__.py b/abst/__version__.py index 537f879..29f7866 100644 --- a/abst/__version__.py +++ b/abst/__version__.py @@ -10,7 +10,7 @@ "CLI Command making OCI Bastion and kubernetes usage simple and fast" ) -__version__ = "2.3.48" +__version__ = "2.3.49" __author__ = "Jiri Otoupal" __author_email__ = "jiri-otoupal@ips-database.eu" __license__ = "MIT" diff --git a/abst/bastion_support/oci_bastion.py b/abst/bastion_support/oci_bastion.py index add07dd..31b41c2 100644 --- a/abst/bastion_support/oci_bastion.py +++ b/abst/bastion_support/oci_bastion.py @@ -48,7 +48,7 @@ def __init__(self, context_name=None, region=None, direct_json_path=None): self.__mark_used__(direct_json_path) self.lb = LocalBroadcast(broadcast_shm_name) - self.lb.store_json({context_name: {"region": self.region}}) + self.lb.store_json(context_name, {"region": self.region}) def __mark_used__(self, path: Path | None = None): if path is None: @@ -67,7 +67,7 @@ def current_status(self): @current_status.setter def current_status(self, value): self._current_status = value - self.lb.store_json({self.context_name: {"status": value}}) + self.lb.store_json(self.context_name, {"status": value}) def get_bastion_state(self) -> dict: session_id = self.response["id"] @@ -209,7 +209,7 @@ def create_forward_loop(self, shell: bool = False, force: bool = False): local_port = creds.get("local-port", 22) username = creds.get("resource-os-username", None) if username: - self.lb.store_json({self.context_name: {"port": local_port, "username": username}}) + self.lb.store_json(self.context_name, {"port": local_port, "username": username}) else: rich.print( "[yellow]No username in context json, please " diff --git a/abst/cli_commands/ssh_cli/commands.py b/abst/cli_commands/ssh_cli/commands.py index 391f619..befdc49 100644 --- a/abst/cli_commands/ssh_cli/commands.py +++ b/abst/cli_commands/ssh_cli/commands.py @@ -16,9 +16,7 @@ def ssh_lin(port, name, debug): setup_calls(debug) lb = LocalBroadcast(broadcast_shm_name) - data = lb.retrieve_json() - # Refresh memory because of a system - lb._write_json(data) + data = lb.list_contents() if len(data.keys()) == 0: rich.print("[yellow]No connected sessions[/yellow]") @@ -37,8 +35,14 @@ def ssh_lin(port, name, debug): rich.print(f"[yellow]No alive contexts found[/yellow]") return - questions = [{"name": f"{key.ljust(longest_key)} |= status: {data['status']}", "value": (key, data)} for key, data - in data.items()] + for key, value in data.items(): + value["compatible"] = len({"username", "port"} - set(value.keys())) == 0 + + questions = [ + {"name": f"{key.ljust(longest_key)} |= status: {data['status']} | Context compatible: {data['compatible']}", + "value": (key, data)} for + key, data + in data.items()] context_name, context = inquirer.select("Select context to ssh to:", questions).execute() diff --git a/abst/config.py b/abst/config.py index 56136d8..10ea54f 100644 --- a/abst/config.py +++ b/abst/config.py @@ -1,6 +1,7 @@ from pathlib import Path # Default config +default_shared_mem_path: Path = (Path().home().resolve() / ".abst" / "shared_mem") default_stack_location: Path = (Path().home().resolve() / ".abst" / "stack.json") default_conf_path: Path = (Path().home().resolve() / ".abst" / "config.json") default_creds_path: Path = (Path().home().resolve() / ".abst" / "creds.json") diff --git a/abst/sharing/local_broadcast.py b/abst/sharing/local_broadcast.py index 68ce8f7..415fda4 100644 --- a/abst/sharing/local_broadcast.py +++ b/abst/sharing/local_broadcast.py @@ -1,114 +1,80 @@ import json -import logging -import struct -from json import JSONDecodeError -from multiprocessing import shared_memory from deepmerge import always_merger -from abst.config import max_json_shared +from abst.config import default_shared_mem_path class LocalBroadcast: _instance = None + _base_dir = default_shared_mem_path - def __new__(cls, name: str, size: int = max_json_shared): + def __new__(cls, name: str): if cls._instance is None: cls._instance = super(LocalBroadcast, cls).__new__(cls) - cls._instance.__init_shared_memory(name[:14], size) + cls._instance.__init_files(name) return cls._instance - def __init_shared_memory(self, name: str, size: int): - self._data_name = name - self._len_name = f"{name}_len" - self._size = size + def __init_files(self, name: str): + self._base_name = name + # Ensure base directory exists + self._base_dir.mkdir(parents=True, exist_ok=True) - try: - # Attempt to create the main shared memory block - self._data_shm = shared_memory.SharedMemory(name=self._data_name, create=False) - self._data_is_owner = False - except FileNotFoundError: - self._data_shm = shared_memory.SharedMemory(name=self._data_name, create=True, size=size) - self._data_is_owner = True + def _get_file_path(self, context): + # Generates a file path for a given context using a Path object + return self._base_dir / f"{self._base_name}_{context}.json" - try: - self._len_shm = shared_memory.SharedMemory(name=self._len_name, create=False) - self._len_shm.buf[:8] = struct.pack('Q', 0) - self._len_is_owner = False - except FileNotFoundError: - self._len_shm = shared_memory.SharedMemory(name=self._len_name, create=True, size=8) - self._len_is_owner = True - - def store_json(self, data: dict) -> int: + def store_json(self, context: str, data: dict) -> int: """ - Serialize and store JSON data in shared memory. + Serialize and store JSON data in a file named by context. @return: Size of the serialized data in bytes """ + file_path = self._get_file_path(context) + data_before = self.retrieve_json(context) - data_before = self.retrieve_json() - for key, value in data.items(): - for s_key in value.keys(): - if data_before.get(key, None) is not None and data_before.get(key, None).get(s_key, - None) is not None and type( - data_before[key][s_key]) == type(data[key][s_key]): - data_before[key].pop(s_key) - - data_copy = always_merger.merge(data, data_before) + # Merge new data with existing data + for s_key in list(data_before.keys()): + if data_before.get(s_key, None) is not None and type( + data_before[s_key]) == type(data.get(s_key, None)): + data_before.pop(s_key) + data_merged = always_merger.merge(data_before, data) - serialized_data = self._write_json(data_copy) - return len(serialized_data) + with file_path.open('w', encoding='utf-8') as file: + json.dump(data_merged, file) - def _write_json(self, data: dict): - serialized_data = json.dumps(data).encode('utf-8') - if len(serialized_data) > self._size: - raise ValueError("Data exceeds allocated shared memory size.") - # Write the data length to the length shared memory - self._len_shm.buf[:8] = struct.pack('Q', len(serialized_data)) - # Write data to the main shared memory - self._data_shm.buf[:len(serialized_data)] = serialized_data - return serialized_data + return file_path.stat().st_size def delete_context(self, context: str): - data_before = self.retrieve_json() - data_before.pop(context, None) - self._write_json(data_before) + file_path = self._get_file_path(context) + try: + file_path.unlink() + except FileNotFoundError: + pass # Context file already deleted or never existed - def retrieve_json(self) -> dict: + def retrieve_json(self, context: str) -> dict: """ - Retrieve and deserialize JSON data from shared memory. + Retrieve and deserialize JSON data from a file named by context. """ - # Read the data length from the length shared memory - data_length = self.get_used_space() - - if data_length == -1: + file_path = self._get_file_path(context) + if not file_path.exists(): return {} - # Read data from the main shared memory - logging.debug(f"Stored data length: {data_length}") - data = bytes(self._data_shm.buf[:data_length]).decode('utf-8') - try: - return json.loads(data) - except JSONDecodeError: - logging.debug(f"Failed to load {data}") - return {} + with file_path.open('r', encoding='utf-8') as file: + try: + return json.load(file) + except json.JSONDecodeError as e: + print(f"Failed to load JSON from {file_path}: {e}") + return {} - def get_used_space(self) -> int: + def list_contents(self) -> dict: """ - Get the size of the shared memory - @return: Number of bytes used + List files in the base directory and return a dictionary where file names + are keys and the values are the contents of the files. """ - if self._len_shm.buf is None: - return -1 - return struct.unpack('Q', self._len_shm.buf[:8])[0] - - def close(self): - """Close and unlink the shared memory blocks.""" - try: - self._data_shm.close() - self._len_shm.close() - if self._data_is_owner: - self._data_shm.unlink() - if self._len_is_owner: - self._len_shm.unlink() - except FileNotFoundError: - pass + content_dict = {} + for file_path in self._base_dir.iterdir(): + if file_path.is_file() and file_path.suffix == '.json': + # Extract the context from the file name + context = file_path.stem.replace(f"{self._base_name}_", "") + content_dict[context] = self.retrieve_json(context) + return content_dict diff --git a/abst/sharing/tests/test_functionality.py b/abst/sharing/tests/test_functionality.py index a351ef5..334b05d 100644 --- a/abst/sharing/tests/test_functionality.py +++ b/abst/sharing/tests/test_functionality.py @@ -22,8 +22,8 @@ def tearDownClass(cls): def test_store_and_retrieve_json(self): test_data = {'key': 'value', 'integer': 42, 'float': 3.14, 'list': [1, 2, 3]} - self.lb.store_json(test_data) - retrieved_data = self.lb.retrieve_json() + self.lb.store_json("test", test_data) + retrieved_data = self.lb.retrieve_json("test") self.assertEqual(test_data, retrieved_data) def test_data_length_persistence(self):