Skip to content

Commit

Permalink
close figures in plot testing
Browse files Browse the repository at this point in the history
  • Loading branch information
katosh committed Apr 4, 2024
1 parent 4b87cfa commit 60623e4
Showing 1 changed file with 40 additions and 1 deletion.
41 changes: 40 additions & 1 deletion tests/plot.py
Original file line number Diff line number Diff line change
Expand Up @@ -179,6 +179,7 @@ def test_plot_molecules_per_cell_and_gene():
assert ax.get_xlabel() == "Molecules per gene (log10 scale)"

assert ax.get_ylabel() == "Frequency"
plt.close()


def test_cell_types_default_colors(mock_tsne, mock_clusters):
Expand Down Expand Up @@ -274,30 +275,35 @@ def test_plot_tsne_by_cell_sizes(mock_data, mock_tsne):
0.2,
0.8,
), "Color limits should be set to vmin and vmax"
plt.close()


def test_plot_gene_expression(mock_gene_data, mock_tsne):
genes = ["gene_0", "gene_1"]
fig, axs = plot_gene_expression(mock_gene_data, mock_tsne, genes, plot_scale=True)
assert isinstance(fig, plt.Figure)
plt.close()


def test_plot_gene_expression_missing_genes(mock_gene_data, mock_tsne):
genes = ["gene_0", "nonexistent_gene"]
fig, axs = plot_gene_expression(mock_gene_data, mock_tsne, genes)
assert isinstance(fig, plt.Figure) # Expect a warning but still a plot
plt.close()


def test_plot_gene_expression_no_genes(mock_gene_data, mock_tsne):
with pytest.raises(ValueError):
plot_gene_expression(mock_gene_data, mock_tsne, ["nonexistent_gene"])
plt.close()


def test_plot_diffusion_components_with_anndata(mock_anndata, mock_dm_res):
fig, axs = plot_diffusion_components(mock_anndata)
assert isinstance(fig, plt.Figure)
for ax in axs.values():
assert isinstance(ax, plt.Axes)
plt.close()


def test_plot_diffusion_components_with_dataframe(mock_tsne, mock_dm_res):
Expand All @@ -306,16 +312,19 @@ def test_plot_diffusion_components_with_dataframe(mock_tsne, mock_dm_res):
assert isinstance(fig, plt.Figure)
for ax in axs.values():
assert isinstance(ax, plt.Axes)
plt.close()


def test_plot_diffusion_components_key_error_embedding(mock_anndata):
with pytest.raises(KeyError):
plot_diffusion_components(mock_anndata, embedding_basis="NonexistentKey")
plt.close()


def test_plot_diffusion_components_key_error_dm_res(mock_anndata):
with pytest.raises(KeyError):
plot_diffusion_components(mock_anndata, dm_res="NonexistentKey")
plt.close()


def test_plot_diffusion_components_default_args(mock_anndata):
Expand All @@ -324,38 +333,44 @@ def test_plot_diffusion_components_default_args(mock_anndata):
assert (
ax.collections[0].get_array().data.shape[0] == 100
) # Checking data points
plt.close()


def test_plot_diffusion_components_custom_args(mock_anndata):
fig, axs = plot_diffusion_components(mock_anndata, s=10, edgecolors="r")
for ax in axs.values():
assert ax.collections[0].get_edgecolors().all() == np.array([1, 0, 0, 1]).all()
assert ax.collections[0].get_sizes()[0] == 10
plt.close()


# Test with AnnData and all keys available
def test_plot_palantir_results_anndata(mock_anndata):
fig = plot_palantir_results(mock_anndata)
assert isinstance(fig, plt.Figure)
plt.close()


# Test with DataFrame and PResults
def test_plot_palantir_results_dataframe(mock_tsne, mock_presults):
fig = plot_palantir_results(mock_tsne, pr_res=mock_presults)
assert isinstance(fig, plt.Figure)
plt.close()


# Test KeyError for missing embedding_basis
def test_plot_palantir_results_key_error_embedding(mock_anndata):
with pytest.raises(KeyError):
plot_palantir_results(mock_anndata, embedding_basis="NonexistentKey")
plt.close()


# Test KeyError for missing Palantir results in AnnData
def test_plot_palantir_results_key_error_palantir(mock_anndata):
mock_anndata.obs = pd.DataFrame(index=mock_anndata.obs_names) # Clearing obs
with pytest.raises(KeyError):
plot_palantir_results(mock_anndata)
plt.close()


# Test plotting with custom arguments
Expand All @@ -364,44 +379,52 @@ def test_plot_palantir_results_custom_args(mock_anndata):
ax = fig.axes[0] # Assuming first subplot holds the first scatter plot
assert np.all(ax.collections[0].get_edgecolors() == [1, 0, 0, 1])
assert ax.collections[0].get_sizes()[0] == 10
plt.close()


# Test with AnnData and all keys available
def test_plot_terminal_state_probs_anndata(mock_anndata, mock_cells):
fig = plot_terminal_state_probs(mock_anndata, mock_cells)
assert isinstance(fig, plt.Figure)
plt.close()


# Test with DataFrame and PResults
def test_plot_terminal_state_probs_dataframe(mock_data, mock_presults, mock_cells):
fig = plot_terminal_state_probs(mock_data, mock_cells, pr_res=mock_presults)
assert isinstance(fig, plt.Figure)
plt.close()


# Test ValueError for missing pr_res in DataFrame input
def test_plot_terminal_state_probs_value_error(mock_data, mock_cells):
with pytest.raises(ValueError):
plot_terminal_state_probs(mock_data, mock_cells)
plt.close()


# Test plotting with custom arguments
def test_plot_terminal_state_probs_custom_args(mock_anndata, mock_cells):
fig = plot_terminal_state_probs(mock_anndata, mock_cells, linewidth=2.0)
ax = fig.axes[0] # Assuming first subplot holds the first bar plot
assert ax.patches[0].get_linewidth() == 2.0
plt.close()


# Test if the function uses the correct keys and raises appropriate errors
def test_plot_branch_selection_keys(mock_anndata):
# This will depend on how your mock_anndata is structured
with pytest.raises(KeyError):
plot_branch_selection(mock_anndata, pseudo_time_key="invalid_key")
plt.close()

with pytest.raises(KeyError):
plot_branch_selection(mock_anndata, fate_prob_key="invalid_key")
plt.close()

with pytest.raises(KeyError):
plot_branch_selection(mock_anndata, embedding_basis="invalid_basis")
plt.close()


# Test the scatter custom arguments
Expand All @@ -417,6 +440,7 @@ def test_plot_branch_selection_custom_args(mock_anndata):

alpha1 = scatter1.get_alpha()
assert alpha1 == 0.5
plt.close()


# Test 1: Basic functionality
Expand All @@ -425,7 +449,7 @@ def test_plot_gene_trends_legacy_basic(mock_gene_trends):
axes = fig.axes
# Check if the number of subplots matches the number of genes
assert len(axes) == 2
# Perform additional checks on axes content if needed
plt.close()


# Test 2: Custom gene list
Expand All @@ -436,6 +460,7 @@ def test_plot_gene_trends_legacy_custom_genes(mock_gene_trends):
assert len(axes) == 1
# Check if the title of the subplot matches the custom gene
assert axes[0].get_title() == "Gene1"
plt.close()


# Test 3: Color consistency
Expand All @@ -446,20 +471,23 @@ def test_plot_gene_trends_legacy_color_consistency(mock_gene_trends):
colors_2 = [line.get_color() for line in axes[1].lines]
# Check if the colors are consistent across different genes
assert colors_1 == colors_2
plt.close()


# Test 1: Basic Functionality with AnnData
def test_plot_gene_trends_basic_anndata(mock_anndata):
fig = plot_gene_trends(mock_anndata)
axes = fig.axes
assert len(axes) == mock_anndata.n_vars
plt.close()


# Test 2: Basic Functionality with Dictionary
def test_plot_gene_trends_basic_dict(mock_gene_trends):
fig = plot_gene_trends(mock_gene_trends)
axes = fig.axes
assert len(axes) == 2 # Mock data contains 2 genes
plt.close()


# Test 3: Custom Genes
Expand All @@ -468,19 +496,22 @@ def test_plot_gene_trends_custom_genes(mock_anndata):
axes = fig.axes
assert len(axes) == 1
assert axes[0].get_title() == "gene_1"
plt.close()


# Test 4: Custom Branch Names
def test_plot_gene_trends_custom_branch_names(mock_anndata):
fig = plot_gene_trends(mock_anndata, branch_names=["a", "b"])
axes = fig.axes
assert len(axes) == mock_anndata.n_vars
plt.close()


# Test 5: Error Handling - Invalid Data Type
def test_plot_gene_trends_invalid_data_type():
with pytest.raises(ValueError):
plot_gene_trends("invalid_data_type")
plt.close()


# Test 6: Error Handling - Missing Key
Expand All @@ -489,12 +520,14 @@ def test_plot_gene_trends_missing_key(mock_anndata):
plot_gene_trends(
mock_anndata, gene_trend_key="missing_key", branch_names="missing_branch"
)
plt.close()


@pytest.mark.parametrize("wrong_type", [123, True, 1.23, "unknown_key"])
def test_plot_stats_key_errors(mock_anndata, wrong_type):
with pytest.raises(KeyError):
plot_stats(mock_anndata, x=wrong_type, y="palantir_pseudotime")
plt.close()


def test_plot_stats_basic(mock_anndata):
Expand All @@ -510,6 +543,7 @@ def test_plot_stats_optional_parameters(mock_anndata):
y="palantir_entropy",
color="palantir_entropy",
)
plt.close()


def test_plot_stats_masking(mock_anndata):
Expand All @@ -522,6 +556,7 @@ def test_plot_stats_masking(mock_anndata):
y="palantir_entropy",
masks_key="branch_masks",
)
plt.close()


@pytest.mark.parametrize(
Expand Down Expand Up @@ -555,19 +590,23 @@ def test_plot_branch_functionality(mock_anndata):
def test_plot_trend_type_validation(mock_anndata):
with pytest.raises(TypeError):
plot_trend("string_instead_of_anndata", "a", "gene_1")
plt.close()
with pytest.raises(TypeError):
plot_trend(mock_anndata, 123, "gene_1")
plt.close()


def test_plot_trend_value_validation(mock_anndata):
with pytest.raises((ValueError, KeyError)):
plot_trend(mock_anndata, "nonexistent_branch", "gene_1")
plt.close()


def test_plot_trend_plotting(mock_anndata):
fig, ax = plot_trend(mock_anndata, "a", "gene_1")
assert isinstance(fig, plt.Figure)
assert isinstance(ax, plt.Axes)
plt.close()


def test_plot_gene_trend_heatmaps(mock_anndata):
Expand Down

0 comments on commit 60623e4

Please sign in to comment.