Skip to content

Commit

Permalink
DEV: Support multiple sources and sinks to fix #4
Browse files Browse the repository at this point in the history
  • Loading branch information
Vini2 committed Jun 18, 2024
1 parent 5f9fe52 commit d2146e4
Showing 1 changed file with 73 additions and 41 deletions.
114 changes: 73 additions & 41 deletions reneo/workflow/scripts/reneo.py
Original file line number Diff line number Diff line change
Expand Up @@ -882,21 +882,10 @@ def worker_resolve_components(component_queue, results_queue, **kwargs):
)
kwargs["logger"].debug(f"Identified candidate sinks: {sink_candidates}")

if len(source_candidates) == 1 and len(sink_candidates) == 1:
kwargs["logger"].debug(f"Found source: {source_candidates[0]}")
kwargs["logger"].debug(f"Found sink: {sink_candidates[0]}")
if len(source_candidates) > 0 and len(sink_candidates) > 0:

source_node = kwargs["unitig_names_rev"][source_candidates[0][:-1]]
sink_node = kwargs["unitig_names_rev"][sink_candidates[0][:-1]]

candidate_nodes.remove(source_node)
candidate_nodes.insert(0, source_node)
candidate_nodes.remove(sink_node)
candidate_nodes.append(sink_node)

kwargs["logger"].debug(
f"Ordered candidate_nodes: {candidate_nodes}"
)
source_node_indices = [kwargs["unitig_names_rev"][x[:-1]] for x in source_candidates]
sink_node_indices = [kwargs["unitig_names_rev"][x[:-1]] for x in sink_candidates]

# Create refined directed graph for flow network
# ----------------------------------------------------------------------
Expand Down Expand Up @@ -974,8 +963,8 @@ def worker_resolve_components(component_queue, results_queue, **kwargs):
u_name = kwargs["unitig_names_rev"][u[:-1]]
v_name = kwargs["unitig_names_rev"][v[:-1]]

u_index = candidate_nodes.index(u_name)
v_index = candidate_nodes.index(v_name)
u_index = candidate_nodes.index(u_name) + 1
v_index = candidate_nodes.index(v_name) + 1

edge_list_indices[u_index] = u
edge_list_indices[v_index] = v
Expand All @@ -991,7 +980,7 @@ def worker_resolve_components(component_queue, results_queue, **kwargs):
cov_upper_bound = int(max_comp_cov * kwargs["alpha"])

kwargs["logger"].debug(
f"({v}, {u}), {juction_cov}, {cov_lower_bound}, {cov_upper_bound}"
f"({v}, {u}), ({u_index}, {v_index}), {juction_cov}, {cov_lower_bound}, {cov_upper_bound}"
)

if juction_cov == 0:
Expand Down Expand Up @@ -1027,13 +1016,13 @@ def worker_resolve_components(component_queue, results_queue, **kwargs):
u_pred_name = kwargs["unitig_names_rev"][
u_pred[:-1]
]
u_pred_index = candidate_nodes.index(u_pred_name)
u_pred_index = candidate_nodes.index(u_pred_name) + 1
u_pred_cov = kwargs["unitig_coverages"][u_pred[:-1]]
u_cov = kwargs["unitig_coverages"][u[:-1]]

if (
v_index != 0
and u_index != 0
(v_index - 1) not in source_node_indices
and (u_index - 1) not in source_node_indices
and u_pred_index != v_index
):
if (
Expand All @@ -1055,15 +1044,15 @@ def worker_resolve_components(component_queue, results_queue, **kwargs):
v_succ_name = kwargs["unitig_names_rev"][
v_succ[:-1]
]
v_succ_index = candidate_nodes.index(v_succ_name)
v_succ_index = candidate_nodes.index(v_succ_name) + 1
v_succ_cov = kwargs["unitig_coverages"][v_succ[:-1]]
v_cov = kwargs["unitig_coverages"][v[:-1]]

if (
v_succ_index != 0
and u_index != 0
and v_index != 0
and v_index != len(candidate_nodes)
(v_succ_index - 1) not in source_node_indices
and (u_index - 1) not in source_node_indices
and (v_index - 1) not in source_node_indices
and (v_index - 1) not in sink_node_indices
and v_succ_index != u_index
):
if (
Expand Down Expand Up @@ -1096,12 +1085,10 @@ def worker_resolve_components(component_queue, results_queue, **kwargs):
u_pred_name = kwargs["unitig_names_rev"][
u_pred[:-1]
]
u_pred_index = candidate_nodes.index(
u_pred_name
)
u_pred_index = candidate_nodes.index(u_pred_name) + 1
if (
v_index != 0
and u_index != 0
(v_index - 1) not in source_node_indices
and (u_index - 1) not in source_node_indices
and u_pred_index != v_index
):
subpaths[subpath_count] = [
Expand All @@ -1124,14 +1111,12 @@ def worker_resolve_components(component_queue, results_queue, **kwargs):
v_succ_name = kwargs["unitig_names_rev"][
v_succ[:-1]
]
v_succ_index = candidate_nodes.index(
v_succ_name
)
v_succ_index = candidate_nodes.index(v_succ_name) + 1
if (
v_succ_index != 0
and u_index != 0
and v_index != 0
and v_index != len(candidate_nodes)
(v_succ_index - 1) not in source_node_indices
and (u_index - 1) not in source_node_indices
and (v_index - 1) not in source_node_indices
and (v_index - 1) not in sink_node_indices
and v_succ_index != u_index
):
subpaths[subpath_count] = [
Expand All @@ -1144,6 +1129,49 @@ def worker_resolve_components(component_queue, results_queue, **kwargs):
)
subpath_count += 1

# Add common start to source links
for source_v in source_candidates:
source_node_index = (
candidate_nodes.index(kwargs["unitig_names_rev"][source_v[:-1]]) + 1
)
source_node_cov = kwargs["unitig_coverages"][source_v[:-1]]
cov_upper_bound = int(max_comp_cov * kwargs["alpha"])

network_edges.append(
(
0,
source_node_index,
source_node_cov,
cov_upper_bound,
)
)

subpaths[subpath_count] = [0, source_node_index]
subpath_count += 1

# Add common sink to end links
for sink_v in sink_candidates:
sink_node_index = (
candidate_nodes.index(kwargs["unitig_names_rev"][sink_v[:-1]]) + 1
)
sink_node_cov = kwargs["unitig_coverages"][sink_v[:-1]]
cov_upper_bound = int(max_comp_cov * kwargs["alpha"])

network_edges.append(
(
sink_node_index,
len(candidate_nodes) + 1,
sink_node_cov,
cov_upper_bound,
)
)

subpaths[subpath_count] = [
sink_node_index,
len(candidate_nodes) + 1,
]
subpath_count += 1

kwargs["logger"].debug(f"edge_list_indices: {edge_list_indices}")
kwargs["logger"].debug(f"subpaths: {subpaths}")

Expand Down Expand Up @@ -1198,7 +1226,7 @@ def worker_resolve_components(component_queue, results_queue, **kwargs):
try:
candidate_paths = list(
nx.all_simple_paths(
G_path, 0, len(candidate_nodes) - 1
G_path, 0, len(candidate_nodes) + 1
)
)

Expand All @@ -1210,9 +1238,13 @@ def worker_resolve_components(component_queue, results_queue, **kwargs):
# Get mapped unitigs in order from the flow network
path_order = []
for path_edge in candidate_paths[0]:
path_order.append(
edge_list_indices[path_edge]
)
if not (
path_edge == 0
or path_edge == len(candidate_nodes) + 1
):
path_order.append(
edge_list_indices[path_edge]
)

kwargs["logger"].debug(
f"path_order: {path_order}"
Expand Down

0 comments on commit d2146e4

Please sign in to comment.