Avoid opening too many SSH sessions (#229)
Sometimes, a race condition caused the `SSHConnector` to open more sessions than the maximum allowed (specified through the `maxConcurrentSessions` parameter). This led to broken connections and errors in workflow executions.

This commit prevents this error by refactoring the `SSHConnector` internals. It also adds a no-regression test that transfers 20 files simultaneously to and from a remote SSH location.
GlassOfWhiskey authored Sep 10, 2023
1 parent 7b451be commit 51faa3e
Showing 6 changed files with 429 additions and 224 deletions.
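
The race lived in session bookkeeping: several coroutines could check the session count at the same time, each see spare capacity, and all open a new session, pushing the connection past `maxConcurrentSessions`. The snippet below is a minimal sketch of the general remedy, not the actual refactored `SSHConnector` internals; the `BoundedSSHContext` name and its interface are illustrative assumptions. An `asyncio.Semaphore` is acquired before any session opens, and an `asyncio.Lock` serializes lazy connection setup so two coroutines cannot race to create the connection.

import asyncio
from typing import Optional

import asyncssh


class BoundedSSHContext:
    """Illustrative helper: never more than max_concurrent_sessions open sessions."""

    def __init__(self, host: str, max_concurrent_sessions: int = 10) -> None:
        self._host = host
        self._sem = asyncio.Semaphore(max_concurrent_sessions)
        self._conn_lock = asyncio.Lock()
        self._conn: Optional[asyncssh.SSHClientConnection] = None

    async def _get_connection(self) -> asyncssh.SSHClientConnection:
        # The lock prevents two coroutines from both seeing None and opening
        # two connections, the same class of race this commit removes.
        async with self._conn_lock:
            if self._conn is None:
                self._conn = await asyncssh.connect(self._host)
            return self._conn

    async def run(self, command: str) -> asyncssh.SSHCompletedProcess:
        # Acquire the semaphore before opening a session and release it when
        # the session closes, so bursts of tasks cannot exceed the limit.
        async with self._sem:
            conn = await self._get_connection()
            return await conn.run(command)

With such a helper, `await asyncio.gather(*(ctx.run(cmd) for cmd in commands))` stays within the session limit even when, as in the new no-regression test, 20 transfers run simultaneously.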
8 changes: 4 additions & 4 deletions streamflow/deployment/connector/kubernetes.py
@@ -26,7 +26,7 @@
import yaml
from cachetools import Cache, TTLCache
from kubernetes_asyncio import client
-from kubernetes_asyncio.client import ApiClient, Configuration, V1Container
+from kubernetes_asyncio.client import ApiClient, Configuration, V1Container, V1PodList
from kubernetes_asyncio.config import (
ConfigException,
load_incluster_config,
@@ -425,7 +425,7 @@ def _get_run_command(
)

@abstractmethod
-    async def _get_running_pods(self) -> MutableSequence[Any]:
+    async def _get_running_pods(self) -> V1PodList:
...

def _get_stream_reader(self, location: Location, src: str) -> StreamWrapperContext:
@@ -633,7 +633,7 @@ def _get_api(self, k8s_object: Any) -> Any:
fcn_to_call = f"{group}{version.capitalize()}Api"
return getattr(client, fcn_to_call)(self.client.api_client)

-    async def _get_running_pods(self) -> MutableSequence[Any]:
+    async def _get_running_pods(self) -> V1PodList:
return await self.client.list_namespaced_pod(
namespace=self.namespace or "default",
field_selector="status.phase=Running",
@@ -992,7 +992,7 @@ def _get_base_command(self) -> str:
f"{get_option('repository-config', self.repositoryConfig)}"
)

-    async def _get_running_pods(self) -> MutableSequence[Any]:
+    async def _get_running_pods(self) -> V1PodList:
return await self.client.list_namespaced_pod(
namespace=self.namespace or "default",
label_selector=f"app.kubernetes.io/instance={self.releaseName}",
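
As an aside on the annotation change above: `V1PodList` is the typed model returned by `list_namespaced_pod`, so callers receive a `.items` list of `V1Pod` objects rather than an opaque sequence. A brief illustrative sketch (the helper function is ours, not part of StreamFlow):

from kubernetes_asyncio.client import V1PodList


def running_pod_names(pod_list: V1PodList) -> list[str]:
    # V1PodList.items holds V1Pod models, each carrying metadata and status.
    return [pod.metadata.name for pod in pod_list.items]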
142 changes: 95 additions & 47 deletions streamflow/deployment/connector/occam.py
@@ -13,6 +13,7 @@

from streamflow.core import utils
from streamflow.core.deployment import Location
+from streamflow.core.exception import WorkflowExecutionException
from streamflow.core.scheduling import AvailableLocation
from streamflow.core.utils import get_option
from streamflow.deployment.connector.ssh import SSHConnector
@@ -98,8 +99,19 @@ async def _get_tmpdir(self, location: str):
temp_dir = posixpath.join(
scratch_home, "streamflow", "".join(utils.random_name())
)
-        async with self._get_ssh_client(location) as ssh_client:
-            await ssh_client.run(f"mkdir -p {temp_dir}")
+        async with self._get_ssh_client_process(
+            location=location,
+            command=f"mkdir -p {temp_dir}",
+            stdout=asyncio.subprocess.STDOUT,
+            stderr=asyncio.subprocess.STDOUT,
+        ) as proc:
+            result = await proc.wait()
+            if result.returncode == 0:
+                return temp_dir
+            else:
+                raise WorkflowExecutionException(
+                    f"Error while creating directory {temp_dir}: {result.stdout.strip()}"
+                )
return temp_dir

def _get_volumes(self, location: str) -> MutableSequence[str]:
Expand Down Expand Up @@ -148,8 +160,17 @@ async def _copy_remote_to_remote(
# If a temporary location was created, delete it
if temp_dir is not None:
for location in effective_locations:
-                async with self._get_ssh_client(location.name) as ssh_client:
-                    await ssh_client.run(f"rm -rf {temp_dir}")
+                async with self._get_ssh_client_process(
+                    location=location.name,
+                    command=f"rm -rf {temp_dir}",
+                    stdout=asyncio.subprocess.STDOUT,
+                    stderr=asyncio.subprocess.STDOUT,
+                ) as proc:
+                    result = await proc.wait()
+                    if result.returncode != 0:
+                        raise WorkflowExecutionException(
+                            f"Error while removing directory {temp_dir}: {result.stdout.strip()}"
+                        )

async def _copy_remote_to_remote_single(
self,
Expand Down Expand Up @@ -181,7 +202,7 @@ async def _copy_local_to_remote(
shared_path = self._get_shared_path(location.name, dst)
if shared_path is None:
temp_dir = await self._get_tmpdir(location.name)
-            async with self._get_ssh_client(location.name) as ssh_client:
+            async with self._get_ssh_client_process(location.name) as ssh_client:
await asyncssh.scp(
src, (ssh_client, temp_dir), preserve=True, recurse=True
)
@@ -203,9 +224,13 @@ async def _copy_local_to_remote(
await asyncio.gather(*copy_tasks)
# If a temporary location was created, delete it
if temp_dir is not None:
-            for location in effective_locations:
-                async with self._get_ssh_client(location.name) as ssh_client:
-                    await ssh_client.run(f"rm -rf {temp_dir}")
+            delete_command = ["rm", "-rf", temp_dir]
+            await asyncio.gather(
+                *(
+                    asyncio.create_task(self.run(location, delete_command))
+                    for location in effective_locations
+                )
+            )

async def _copy_local_to_remote_single(
self,
@@ -217,7 +242,7 @@ async def _copy_local_to_remote_single(
) -> None:
shared_path = self._get_shared_path(location.name, dst)
if shared_path is not None:
-            async with self._get_ssh_client(location.name) as ssh_client:
+            async with self._get_ssh_client_process(location.name) as ssh_client:
await asyncssh.scp(
src, (ssh_client, shared_path), preserve=True, recurse=True
)
@@ -230,7 +255,7 @@ async def _copy_remote_to_local(
) -> None:
shared_path = self._get_shared_path(location.name, src)
if shared_path is not None:
-            async with self._get_ssh_client(location.name) as ssh_client:
+            async with self._get_ssh_client_process(location.name) as ssh_client:
await asyncssh.scp(
(ssh_client, shared_path), dst, preserve=True, recurse=True
)
@@ -256,7 +281,7 @@ async def _copy_remote_to_local(
contents, _ = await self.run(location, copy_command, capture_output=True)
contents = contents.split()
scp_tasks = []
-        async with self._get_ssh_client(location.name) as ssh_client:
+        async with self._get_ssh_client_process(location.name) as ssh_client:
for content in contents:
scp_tasks.append(
asyncio.create_task(
@@ -291,27 +316,47 @@ async def _deploy_node(
)
if logger.isEnabledFor(logging.DEBUG):
logger.debug(f"EXECUTING {deploy_command}")
-        async with self._get_ssh_client(name) as ssh_client:
-            result = await ssh_client.run(deploy_command)
-            output = result.stdout
-            search_result = re.findall(f"({node}-[0-9]+).*", output, re.MULTILINE)
-            if search_result:
-                if name not in self.jobs_table:
-                    self.jobs_table[name] = []
-                self.jobs_table[name].append(search_result[0])
-                if logger.isEnabledFor(logging.INFO):
-                    logger.info(f"Deployed {name} on {search_result[0]}")
-            else:
-                raise Exception
+        async with self._get_ssh_client_process(
+            location=name,
+            command=deploy_command,
+            stdout=asyncio.subprocess.STDOUT,
+            stderr=asyncio.subprocess.STDOUT,
+        ) as proc:
+            result = await proc.wait()
+            output = result.stdout.strip()
+            if result.returncode == 0:
+                search_result = re.findall(f"({node}-[0-9]+).*", output, re.MULTILINE)
+                if search_result:
+                    if name not in self.jobs_table:
+                        self.jobs_table[name] = []
+                    self.jobs_table[name].append(search_result[0])
+                    if logger.isEnabledFor(logging.INFO):
+                        logger.info(f"Deployed {name} on {search_result[0]}")
+                else:
+                    raise WorkflowExecutionException(
+                        f"Failed to deploy {name}: {output}"
+                    )
+            else:
+                raise WorkflowExecutionException(f"Failed to deploy {name}: {output}")

async def _undeploy_node(self, name: str, job_id: str):
undeploy_command = f"occam-kill {job_id}"
if logger.isEnabledFor(logging.DEBUG):
logger.debug(f"EXECUTING {undeploy_command}")
-        async with self._get_ssh_client(name) as ssh_client:
-            await ssh_client.run(undeploy_command)
-        if logger.isEnabledFor(logging.INFO):
-            logger.info(f"Killed {job_id}")
+        async with self._get_ssh_client_process(
+            location=name,
+            command=undeploy_command,
+            stdout=asyncio.subprocess.STDOUT,
+            stderr=asyncio.subprocess.STDOUT,
+        ) as proc:
+            result = await proc.wait()
+            if result.returncode == 0:
+                if logger.isEnabledFor(logging.INFO):
+                    logger.info(f"Killed {job_id}")
+            else:
+                raise WorkflowExecutionException(
+                    f"Failed to undeploy {name}: {result.stdout.strip()}"
+                )

async def deploy(self, external: bool) -> None:
await super().deploy(external)
@@ -391,23 +436,26 @@ async def run(
job_name=job_name,
)
occam_command = f"occam-exec {location} sh -c '{command}'"
-        async with self._get_ssh_client(location.name) as ssh_client:
-            result = await asyncio.wait_for(
-                ssh_client.run(occam_command), timeout=timeout
-            )
-        if capture_output:
-            lines = (line for line in result.stdout.split("\n"))
-            out = ""
-            for line in lines:
-                if line.startswith("Trying to exec commands into container"):
-                    break
-            try:
-                line = next(lines)
-                out = line.strip(" \r\t")
-            except StopIteration:
-                return out
-            for line in lines:
-                out = "\n".join([out, line.strip(" \r\t")])
-            return out, result.returncode
-        else:
-            return None
+        async with self._get_ssh_client_process(
+            location=location.name,
+            command=occam_command,
+            stdout=asyncio.subprocess.STDOUT,
+            stderr=asyncio.subprocess.STDOUT,
+        ) as proc:
+            result = await proc.wait(timeout=timeout)
+        if capture_output:
+            lines = (line for line in result.stdout.split("\n"))
+            out = ""
+            for line in lines:
+                if line.startswith("Trying to exec commands into container"):
+                    break
+            try:
+                line = next(lines)
+                out = line.strip(" \r\t")
+            except StopIteration:
+                return out, result.returncode
+            for line in lines:
+                out = "\n".join([out, line.strip(" \r\t")])
+            return out, result.returncode
+        else:
+            return None
3 changes: 1 addition & 2 deletions streamflow/deployment/connector/schemas/ssh.json
@@ -155,8 +155,7 @@
}
},
"required": [
"nodes",
"username"
"nodes"
],
"additionalProperties": false
}
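
With `username` removed from the top-level `required` list, a configuration that specifies only `nodes` now validates against this schema; credentials can presumably be supplied per node or through SSH defaults instead. A hedged sketch using the third-party `jsonschema` package (an assumption on our part; StreamFlow's own validation path may differ) and a hypothetical host name:

import json

from jsonschema import validate  # pip install jsonschema

# Load the schema file modified by this commit.
with open("streamflow/deployment/connector/schemas/ssh.json") as f:
    schema = json.load(f)

# "username" is no longer required at the top level.
validate(instance={"nodes": ["node01.example.com"]}, schema=schema)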