diff --git a/examples/tuning/get_important_pattern.py b/examples/tuning/get_important_pattern.py
index 39fcfb2f..c93d7114 100644
--- a/examples/tuning/get_important_pattern.py
+++ b/examples/tuning/get_important_pattern.py
@@ -10,6 +10,7 @@
 import seaborn as sns
 from mlxtend.frequent_patterns import apriori
 from mlxtend.preprocessing import TransactionEncoder
+from mlxtend.frequent_patterns import association_rules
 from networkx import parse_adjlist
 from scipy import stats
 
@@ -17,13 +18,13 @@
 ascending = False
 
 
-def get_important_pattern(test_accs, vis=True, alpha=0.8, title=""):
+def get_important_pattern(test_accs, vis=True, alpha=0.8, title="", test_acc_names=None):
     medians = [np.median(group) for group in test_accs]
     _, p_value = stats.kruskal(*test_accs)
     if vis:
         fig = plt.figure(figsize=(12, 4))
         sns.boxplot(data=test_accs)
-        plt.xticks(list(range(len(test_accs))), [f"{i}" for i in range(len(test_accs))])
+        plt.xticks(list(range(len(test_accs))), ([f"{i}" for i in range(len(test_accs))] if test_acc_names is None else test_acc_names), rotation=45, fontsize=10)
         plt.title(title)
         plt.show()
     if p_value < alpha:
@@ -71,7 +72,7 @@ def get_com(step2_data, r=2, alpha=0.8, columns=None, vis=True):
             test_accs_arrays.append({"name": g[0], metric_name: list(g[1][metric_name])})
         test_accs = [i[metric_name] for i in test_accs_arrays]
         test_acc_names = [i["name"] for i in test_accs_arrays]
-        final_ranks = get_important_pattern(test_accs, alpha=alpha, title=" ".join(list(com)), vis=vis)
+        final_ranks = get_important_pattern(test_accs, alpha=alpha, title=" ".join(list(com)), vis=vis, test_acc_names=[" ".join(test_acc_name) for test_acc_name in test_acc_names])
         if len(final_ranks) > 0:
             max_rank = max(final_ranks)
             max_rank_count = final_ranks.count(max_rank)
@@ -82,9 +83,57 @@ def get_com(step2_data, r=2, alpha=0.8, columns=None, vis=True):
                 print(f"index={index},name={test_acc_name},rank={rank}")
                 ans.append(test_acc_name if isinstance(test_acc_name, tuple) else (test_acc_name, ))
     return ans
-
-
-def get_frequent_itemsets(step2_data, threshold_per=0.1):
+def draw_graph(rules, rules_to_show):
+    """Draw the first ``rules_to_show`` association rules as a directed graph."""
+    import networkx as nx
+    G1 = nx.DiGraph()
+
+    color_map = []
+    N = 50
+    colors = np.random.rand(N)  # one random edge color per rule
+    # Labels that identify rule nodes (as opposed to item nodes).
+    strs = ['R0', 'R1', 'R2', 'R3', 'R4', 'R5', 'R6', 'R7', 'R8', 'R9', 'R10', 'R11']
+
+    # Add one node per rule, then connect each rule's antecedents and
+    # consequents to it with edges that share the rule's color.
+    for i in range(rules_to_show):
+        G1.add_nodes_from(["R" + str(i)])
+
+        for a in rules.iloc[i]['antecedents']:
+            G1.add_nodes_from([a])
+            G1.add_edge(a, "R" + str(i), color=colors[i], weight=2)
+
+        for c in rules.iloc[i]['consequents']:
+            G1.add_nodes_from([c])
+            G1.add_edge("R" + str(i), c, color=colors[i], weight=2)
+
+    # Color rule nodes yellow and item nodes green.
+    for node in G1:
+        found_a_string = False
+        for item in strs:
+            if node == item:
+                found_a_string = True
+        if found_a_string:
+            color_map.append('yellow')
+        else:
+            color_map.append('green')
+
+    # Collect per-edge colors and weights in edge iteration order.
+    edges = G1.edges()
+    edge_colors = [G1[u][v]['color'] for u, v in edges]
+    weights = [G1[u][v]['weight'] for u, v in edges]
+
+    # Spring layout; draw the graph, then add slightly raised labels.
+    pos = nx.spring_layout(G1, k=16, scale=1)
+    nx.draw(G1, pos, node_color=color_map, edge_color=edge_colors, width=weights, font_size=16, with_labels=False)
+
+    for p in pos:  # raise text positions
+        pos[p][1] += 0.07
+    nx.draw_networkx_labels(G1, pos)
+    plt.show()
+
+
+def get_frequent_itemsets(step2_data, threshold_per=0.1, vis=False):
     threshold = int(len(step2_data) * threshold_per)
     df_sorted = step2_data.sort_values(metric_name, ascending=ascending)
     top_10_percent = df_sorted.head(threshold)
@@ -95,7 +144,16 @@
     df = pd.DataFrame(te_ary, columns=te.columns_)
     frequent_itemsets = apriori(df, min_support=0.3, use_colnames=True)
     # print(frequent_itemsets)
-    # rules = association_rules(frequent_itemsets, metric="confidence", min_threshold=0.5)
+    rules = association_rules(frequent_itemsets, metric="confidence", min_threshold=0.5)
+    if vis:
+        # Plot the frequent itemsets sorted by support; alternatively,
+        # draw the rule graph:
+        # draw_graph(rules=rules, rules_to_show=10)
+        frequent_itemsets_copy = frequent_itemsets.copy()
+        frequent_itemsets_copy = frequent_itemsets_copy.sort_values(by="support")
+        frequent_itemsets_copy.plot(x="itemsets", y="support", kind="bar")
+        plt.xticks(rotation=30, fontsize=7)
+        plt.show()
     return [tuple(a) for a in frequent_itemsets["itemsets"]]
 
 
@@ -111,7 +169,7 @@ def summary_pattern(data_path, alpha=0.8, vis=False):
     step2_origin_data = pd.read_csv(data_path)
     step2_data = step2_origin_data.dropna()
     com_ans = get_com_all(step2_data, vis=vis, alpha=alpha)
-    apr_ans = get_frequent_itemsets(step2_data)
+    apr_ans = get_frequent_itemsets(step2_data, vis=vis)
     return list(set(com_ans) & set(apr_ans))
 
 
@@ -136,11 +194,11 @@
 
 
 if __name__ == "__main__":
-    directories = []
-    for path in Path('/home/zyxing/dance/examples/tuning').iterdir():
-        if path.is_dir():
-            if str(path.name).startswith("cluster"):
-                directories.append(path)
-    list_files(directories)
-
-    # print(summary_pattern("/home/zyxing/dance/examples/tuning/cta_scdeepsort/328_138/results/pipeline/best_test_acc.csv",alpha=0.3,vis=True))
+    # directories = []
+    # for path in Path('/home/zyxing/dance/examples/tuning').iterdir():
+    #     if path.is_dir():
+    #         if str(path.name).startswith("cluster"):
+    #             directories.append(path)
+    # list_files(directories)
+
+    print(summary_pattern("/home/zyxing/dance/examples/tuning/cluster_graphsc/mouse_ES_cell/results/pipeline/best_test_acc.csv", alpha=0.3, vis=False))
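
Note for reviewers: `get_important_pattern` gates its ranking on a Kruskal-Wallis H-test across the per-group accuracies before deciding whether any group stands out. A minimal sketch of that gate, with invented toy accuracy values (the three groups and their numbers are illustrative, not taken from this repo):

```python
import numpy as np
from scipy import stats

# Toy accuracy groups standing in for test_accs; the values are invented.
test_accs = [
    [0.71, 0.72, 0.70, 0.73],  # pipeline variant A
    [0.66, 0.65, 0.67, 0.64],  # pipeline variant B
    [0.71, 0.70, 0.72, 0.69],  # pipeline variant C
]

# Kruskal-Wallis H-test: a small p-value suggests at least one group's
# distribution differs, which is when ranking the groups is meaningful.
_, p_value = stats.kruskal(*test_accs)
medians = [np.median(group) for group in test_accs]
print(f"p={p_value:.4f}, medians={medians}")
```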
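And a self-contained sketch of the apriori → `association_rules` flow that the new `vis` branch in `get_frequent_itemsets` visualizes. The toy `transactions` below are hypothetical stand-ins for the top-scoring pipeline rows; the 0.3 support and 0.5 confidence thresholds match the patch:

```python
import pandas as pd
from mlxtend.frequent_patterns import apriori, association_rules
from mlxtend.preprocessing import TransactionEncoder

# Hypothetical transactions standing in for the top threshold_per of rows.
transactions = [
    ["filter_gene", "normalize", "log1p"],
    ["filter_gene", "normalize"],
    ["normalize", "log1p"],
    ["filter_gene", "normalize", "log1p"],
]

# One-hot encode the transactions into a boolean DataFrame.
te = TransactionEncoder()
te_ary = te.fit(transactions).transform(transactions)
df = pd.DataFrame(te_ary, columns=te.columns_)

# Same thresholds as the patch: support >= 0.3, confidence >= 0.5.
frequent_itemsets = apriori(df, min_support=0.3, use_colnames=True)
rules = association_rules(frequent_itemsets, metric="confidence", min_threshold=0.5)
print(frequent_itemsets.sort_values(by="support", ascending=False))
print(rules[["antecedents", "consequents", "support", "confidence"]])
```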