Skip to content

Commit

Permalink
bug fix
Browse files Browse the repository at this point in the history
  • Loading branch information
Jiri Otoupal committed Mar 14, 2024
1 parent 9668811 commit 8eeedda
Show file tree
Hide file tree
Showing 6 changed files with 65 additions and 94 deletions.
2 changes: 1 addition & 1 deletion abst/__version__.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
"CLI Command making OCI Bastion and kubernetes usage simple and fast"
)

__version__ = "2.3.48"
__version__ = "2.3.49"
__author__ = "Jiri Otoupal"
__author_email__ = "jiri-otoupal@ips-database.eu"
__license__ = "MIT"
Expand Down
6 changes: 3 additions & 3 deletions abst/bastion_support/oci_bastion.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,7 @@ def __init__(self, context_name=None, region=None, direct_json_path=None):
self.__mark_used__(direct_json_path)

self.lb = LocalBroadcast(broadcast_shm_name)
self.lb.store_json({context_name: {"region": self.region}})
self.lb.store_json(context_name, {"region": self.region})

def __mark_used__(self, path: Path | None = None):
if path is None:
Expand All @@ -67,7 +67,7 @@ def current_status(self):
@current_status.setter
def current_status(self, value):
self._current_status = value
self.lb.store_json({self.context_name: {"status": value}})
self.lb.store_json(self.context_name, {"status": value})

def get_bastion_state(self) -> dict:
session_id = self.response["id"]
Expand Down Expand Up @@ -209,7 +209,7 @@ def create_forward_loop(self, shell: bool = False, force: bool = False):
local_port = creds.get("local-port", 22)
username = creds.get("resource-os-username", None)
if username:
self.lb.store_json({self.context_name: {"port": local_port, "username": username}})
self.lb.store_json(self.context_name, {"port": local_port, "username": username})
else:
rich.print(
"[yellow]No username in context json, please "
Expand Down
14 changes: 9 additions & 5 deletions abst/cli_commands/ssh_cli/commands.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,9 +16,7 @@ def ssh_lin(port, name, debug):
setup_calls(debug)

lb = LocalBroadcast(broadcast_shm_name)
data = lb.retrieve_json()
# Refresh memory because of a system
lb._write_json(data)
data = lb.list_contents()

if len(data.keys()) == 0:
rich.print("[yellow]No connected sessions[/yellow]")
Expand All @@ -37,8 +35,14 @@ def ssh_lin(port, name, debug):
rich.print(f"[yellow]No alive contexts found[/yellow]")
return

questions = [{"name": f"{key.ljust(longest_key)} |= status: {data['status']}", "value": (key, data)} for key, data
in data.items()]
for key, value in data.items():
value["compatible"] = len({"username", "port"} - set(value.keys())) == 0

questions = [
{"name": f"{key.ljust(longest_key)} |= status: {data['status']} | Context compatible: {data['compatible']}",
"value": (key, data)} for
key, data
in data.items()]

context_name, context = inquirer.select("Select context to ssh to:", questions).execute()

Expand Down
1 change: 1 addition & 0 deletions abst/config.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from pathlib import Path

# Default config
default_shared_mem_path: Path = (Path().home().resolve() / ".abst" / "shared_mem")
default_stack_location: Path = (Path().home().resolve() / ".abst" / "stack.json")
default_conf_path: Path = (Path().home().resolve() / ".abst" / "config.json")
default_creds_path: Path = (Path().home().resolve() / ".abst" / "creds.json")
Expand Down
132 changes: 49 additions & 83 deletions abst/sharing/local_broadcast.py
Original file line number Diff line number Diff line change
@@ -1,114 +1,80 @@
import json
import logging
import struct
from json import JSONDecodeError
from multiprocessing import shared_memory

from deepmerge import always_merger

from abst.config import max_json_shared
from abst.config import default_shared_mem_path


class LocalBroadcast:
_instance = None
_base_dir = default_shared_mem_path

def __new__(cls, name: str, size: int = max_json_shared):
def __new__(cls, name: str):
if cls._instance is None:
cls._instance = super(LocalBroadcast, cls).__new__(cls)
cls._instance.__init_shared_memory(name[:14], size)
cls._instance.__init_files(name)
return cls._instance

def __init_shared_memory(self, name: str, size: int):
self._data_name = name
self._len_name = f"{name}_len"
self._size = size
def __init_files(self, name: str):
self._base_name = name
# Ensure base directory exists
self._base_dir.mkdir(parents=True, exist_ok=True)

try:
# Attempt to create the main shared memory block
self._data_shm = shared_memory.SharedMemory(name=self._data_name, create=False)
self._data_is_owner = False
except FileNotFoundError:
self._data_shm = shared_memory.SharedMemory(name=self._data_name, create=True, size=size)
self._data_is_owner = True
def _get_file_path(self, context):
# Generates a file path for a given context using a Path object
return self._base_dir / f"{self._base_name}_{context}.json"

try:
self._len_shm = shared_memory.SharedMemory(name=self._len_name, create=False)
self._len_shm.buf[:8] = struct.pack('Q', 0)
self._len_is_owner = False
except FileNotFoundError:
self._len_shm = shared_memory.SharedMemory(name=self._len_name, create=True, size=8)
self._len_is_owner = True

def store_json(self, data: dict) -> int:
def store_json(self, context: str, data: dict) -> int:
"""
Serialize and store JSON data in shared memory.
Serialize and store JSON data in a file named by context.
@return: Size of the serialized data in bytes
"""
file_path = self._get_file_path(context)
data_before = self.retrieve_json(context)

data_before = self.retrieve_json()
for key, value in data.items():
for s_key in value.keys():
if data_before.get(key, None) is not None and data_before.get(key, None).get(s_key,
None) is not None and type(
data_before[key][s_key]) == type(data[key][s_key]):
data_before[key].pop(s_key)

data_copy = always_merger.merge(data, data_before)
# Merge new data with existing data
for s_key in list(data_before.keys()):
if data_before.get(s_key, None) is not None and type(
data_before[s_key]) == type(data.get(s_key, None)):
data_before.pop(s_key)
data_merged = always_merger.merge(data_before, data)

serialized_data = self._write_json(data_copy)
return len(serialized_data)
with file_path.open('w', encoding='utf-8') as file:
json.dump(data_merged, file)

def _write_json(self, data: dict):
serialized_data = json.dumps(data).encode('utf-8')
if len(serialized_data) > self._size:
raise ValueError("Data exceeds allocated shared memory size.")
# Write the data length to the length shared memory
self._len_shm.buf[:8] = struct.pack('Q', len(serialized_data))
# Write data to the main shared memory
self._data_shm.buf[:len(serialized_data)] = serialized_data
return serialized_data
return file_path.stat().st_size

def delete_context(self, context: str):
data_before = self.retrieve_json()
data_before.pop(context, None)
self._write_json(data_before)
file_path = self._get_file_path(context)
try:
file_path.unlink()
except FileNotFoundError:
pass # Context file already deleted or never existed

def retrieve_json(self) -> dict:
def retrieve_json(self, context: str) -> dict:
"""
Retrieve and deserialize JSON data from shared memory.
Retrieve and deserialize JSON data from a file named by context.
"""
# Read the data length from the length shared memory
data_length = self.get_used_space()

if data_length == -1:
file_path = self._get_file_path(context)
if not file_path.exists():
return {}

# Read data from the main shared memory
logging.debug(f"Stored data length: {data_length}")
data = bytes(self._data_shm.buf[:data_length]).decode('utf-8')
try:
return json.loads(data)
except JSONDecodeError:
logging.debug(f"Failed to load {data}")
return {}
with file_path.open('r', encoding='utf-8') as file:
try:
return json.load(file)
except json.JSONDecodeError as e:
print(f"Failed to load JSON from {file_path}: {e}")
return {}

def get_used_space(self) -> int:
def list_contents(self) -> dict:
"""
Get the size of the shared memory
@return: Number of bytes used
List files in the base directory and return a dictionary where file names
are keys and the values are the contents of the files.
"""
if self._len_shm.buf is None:
return -1
return struct.unpack('Q', self._len_shm.buf[:8])[0]

def close(self):
"""Close and unlink the shared memory blocks."""
try:
self._data_shm.close()
self._len_shm.close()
if self._data_is_owner:
self._data_shm.unlink()
if self._len_is_owner:
self._len_shm.unlink()
except FileNotFoundError:
pass
content_dict = {}
for file_path in self._base_dir.iterdir():
if file_path.is_file() and file_path.suffix == '.json':
# Extract the context from the file name
context = file_path.stem.replace(f"{self._base_name}_", "")
content_dict[context] = self.retrieve_json(context)
return content_dict
4 changes: 2 additions & 2 deletions abst/sharing/tests/test_functionality.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,8 +22,8 @@ def tearDownClass(cls):

def test_store_and_retrieve_json(self):
test_data = {'key': 'value', 'integer': 42, 'float': 3.14, 'list': [1, 2, 3]}
self.lb.store_json(test_data)
retrieved_data = self.lb.retrieve_json()
self.lb.store_json("test", test_data)
retrieved_data = self.lb.retrieve_json("test")
self.assertEqual(test_data, retrieved_data)

def test_data_length_persistence(self):
Expand Down

0 comments on commit 8eeedda

Please sign in to comment.