Skip to content

Commit

Permalink
minor
Browse files Browse the repository at this point in the history
  • Loading branch information
xingzhongyu committed Dec 18, 2024
1 parent 8e1d33f commit 1087e3c
Showing 1 changed file with 74 additions and 16 deletions.
90 changes: 74 additions & 16 deletions examples/tuning/get_important_pattern.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,20 +10,21 @@
import seaborn as sns
from mlxtend.frequent_patterns import apriori
from mlxtend.preprocessing import TransactionEncoder
from mlxtend.frequent_patterns import association_rules
from networkx import parse_adjlist
from scipy import stats

metric_name = "acc"
ascending = False


def get_important_pattern(test_accs, vis=True, alpha=0.8, title=""):
def get_important_pattern(test_accs, vis=True, alpha=0.8, title="",test_acc_names=None):
medians = [np.median(group) for group in test_accs]
_, p_value = stats.kruskal(*test_accs)
if vis:
fig = plt.figure(figsize=(12, 4))
sns.boxplot(data=test_accs)
plt.xticks(list(range(len(test_accs))), [f"{i}" for i in range(len(test_accs))])
plt.xticks(list(range(len(test_accs))), ([f"{i}" for i in range(len(test_accs))] if test_acc_names is None else test_acc_names),rotation=45, fontsize=10)
plt.title(title)
plt.show()
if p_value < alpha:
Expand Down Expand Up @@ -71,7 +72,7 @@ def get_com(step2_data, r=2, alpha=0.8, columns=None, vis=True):
test_accs_arrays.append({"name": g[0], metric_name: list(g[1][metric_name])})
test_accs = [i[metric_name] for i in test_accs_arrays]
test_acc_names = [i["name"] for i in test_accs_arrays]
final_ranks = get_important_pattern(test_accs, alpha=alpha, title=" ".join(list(com)), vis=vis)
final_ranks = get_important_pattern(test_accs, alpha=alpha, title=" ".join(list(com)), vis=vis,test_acc_names=[" ".join(test_acc_name) for test_acc_name in test_acc_names])
if len(final_ranks) > 0:
max_rank = max(final_ranks)
max_rank_count = final_ranks.count(max_rank)
Expand All @@ -82,9 +83,57 @@ def get_com(step2_data, r=2, alpha=0.8, columns=None, vis=True):
print(f"index={index},name={test_acc_name},rank={rank}")
ans.append(test_acc_name if isinstance(test_acc_name, tuple) else (test_acc_name, ))
return ans


def get_frequent_itemsets(step2_data, threshold_per=0.1):
def draw_graph(rules, rules_to_show):
    """Visualize the first ``rules_to_show`` association rules as a digraph.

    Each rule ``i`` becomes a hub node ``"R<i>"`` (drawn yellow); every
    antecedent item points into the hub and the hub points out to every
    consequent item (items drawn green). All edges of one rule share a
    single random color so the rule's edges can be traced visually.

    Parameters
    ----------
    rules : pandas.DataFrame
        Output of mlxtend's ``association_rules`` — rows must provide
        iterable ``antecedents`` and ``consequents`` columns.
    rules_to_show : int
        Number of leading rows of ``rules`` to draw.
    """
    import networkx as nx

    graph = nx.DiGraph()
    # One random color per rule (sized by rules_to_show, not a hard-coded
    # constant, so any number of rules works without an IndexError).
    edge_colors = np.random.rand(rules_to_show)
    rule_nodes = set()

    for i in range(rules_to_show):
        hub = "R" + str(i)
        rule_nodes.add(hub)
        graph.add_node(hub)
        for antecedent in rules.iloc[i]["antecedents"]:
            graph.add_node(antecedent)
            graph.add_edge(antecedent, hub, color=edge_colors[i], weight=2)
        for consequent in rules.iloc[i]["consequents"]:
            graph.add_node(consequent)
            graph.add_edge(hub, consequent, color=edge_colors[i], weight=2)

    # Rule hubs in yellow, item nodes in green. Membership is checked
    # against the actual hub names, so rules beyond R11 are highlighted
    # too (the old hard-coded R0..R11 list missed them).
    color_map = ["yellow" if node in rule_nodes else "green" for node in graph]

    edges = graph.edges()
    colors = [graph[u][v]["color"] for u, v in edges]
    weights = [graph[u][v]["weight"] for u, v in edges]

    pos = nx.spring_layout(graph, k=16, scale=1)
    nx.draw(graph, pos, node_color=color_map, edge_color=colors,
            width=weights, font_size=16, with_labels=False)

    # Raise each label slightly above its node so text does not overlap it.
    for p in pos:
        pos[p][1] += 0.07
    nx.draw_networkx_labels(graph, pos)
    plt.show()

def get_frequent_itemsets(step2_data, threshold_per=0.1,vis=False):
threshold = int(len(step2_data) * threshold_per)
df_sorted = step2_data.sort_values(metric_name, ascending=ascending)
top_10_percent = df_sorted.head(threshold)
Expand All @@ -95,7 +144,16 @@ def get_frequent_itemsets(step2_data, threshold_per=0.1):
df = pd.DataFrame(te_ary, columns=te.columns_)
frequent_itemsets = apriori(df, min_support=0.3, use_colnames=True)
# print(frequent_itemsets)
# rules = association_rules(frequent_itemsets, metric="confidence", min_threshold=0.5)
rules = association_rules(frequent_itemsets, metric="confidence", min_threshold=0.5)
if vis:
# print(frequent_itemsets)
# print(frequent_itemsets)
# draw_graph(rules=rules,rules_to_show=10)
frequent_itemsets_copy=frequent_itemsets.copy()
frequent_itemsets_copy=frequent_itemsets_copy.sort_values(by="support")
frequent_itemsets_copy.plot(x="itemsets",y="support",kind="bar")
plt.xticks(rotation=30, fontsize=7)
# print(type(rules))
return [tuple(a) for a in frequent_itemsets["itemsets"]]


Expand All @@ -111,7 +169,7 @@ def summary_pattern(data_path, alpha=0.8, vis=False):
step2_origin_data = pd.read_csv(data_path)
step2_data = step2_origin_data.dropna()
com_ans = get_com_all(step2_data, vis=vis, alpha=alpha)
apr_ans = get_frequent_itemsets(step2_data)
apr_ans = get_frequent_itemsets(step2_data,vis=vis)
return list(set(com_ans) & set(apr_ans))


Expand All @@ -136,11 +194,11 @@ def list_files(directories, file_name="best_test_acc.csv", alpha=0.8, vis=False)


if __name__ == "__main__":
    # Collect every tuning sub-directory whose name starts with "cluster"
    # (e.g. cluster_graphsc) and run the per-file pattern summary on each.
    # NOTE(review): the absolute paths below are user-specific — adjust
    # (or parameterize) before running on another machine.
    directories = []
    for path in Path('/home/zyxing/dance/examples/tuning').iterdir():
        if path.is_dir():
            if str(path.name).startswith("cluster"):
                directories.append(path)
    list_files(directories)

    # Earlier ad-hoc invocations, kept commented out for reference:
    # print(summary_pattern("/home/zyxing/dance/examples/tuning/cta_scdeepsort/328_138/results/pipeline/best_test_acc.csv",alpha=0.3,vis=True))
    # directories = []
    # for path in Path('/home/zyxing/dance/examples/tuning').iterdir():
    #     if path.is_dir():
    #         if str(path.name).startswith("cluster"):
    #             directories.append(path)
    # list_files(directories)

    # Print the patterns found for a single results file: the intersection
    # of the significance-ranked combinations and the frequent itemsets
    # (alpha=0.3, no plots).
    print(summary_pattern("/home/zyxing/dance/examples/tuning/cluster_graphsc/mouse_ES_cell/results/pipeline/best_test_acc.csv",alpha=0.3,vis=False))

0 comments on commit 1087e3c

Please sign in to comment.