From 6f14df07061dc2f86921bf80f88d314907ad5461 Mon Sep 17 00:00:00 2001 From: Nathan Vieira Marcelino Date: Thu, 29 Aug 2024 18:24:43 -0300 Subject: [PATCH] fix: load bu number of clusters --- .../vameApi/app/services/project_service.py | 85 +++++++++++-------- 1 file changed, 49 insertions(+), 36 deletions(-) diff --git a/src/services/vameApi/app/services/project_service.py b/src/services/vameApi/app/services/project_service.py index 9db1c9f..c48b23b 100644 --- a/src/services/vameApi/app/services/project_service.py +++ b/src/services/vameApi/app/services/project_service.py @@ -120,33 +120,38 @@ def load_project(project_path: Path): symlink.symlink_to(project_path) states_path = Path(project_path) / "states" / "states.json" - states = json.load(open(states_path)) if os.path.exists(states_path) else None - config = yaml.safe_load(open(config_path, "r")) if config_path.exists() else None + # Extract parameters from the config + parametrizations = config["parametrizations"] # e.g., ["hmm", "kmeans"] + n_cluster = config["n_cluster"] # e.g., 28 + + # Create the visualization dictionary dynamically + visualization = { + param: get_visualization_images(project_path, f"{param}-{n_cluster}") + for param in parametrizations + } + images = dict( evaluation=get_evaluation_images(project_path), - visualization=dict( - hmm=get_visualization_images(project_path, 'hmm-15'), - kmeans=get_visualization_images(project_path, 'kmeans-15') - ) + visualization=visualization ) - videos = dict( - motif=dict( - hmm=get_motif_videos(project_path, 'hmm-15'), - kmeans=get_motif_videos(project_path, 'kmeans-15') - ), - community=dict( - hmm=get_community_videos(project_path, 'hmm-15'), - kmeans=get_community_videos(project_path, 'kmeans-15') - ) - ) + # Create the videos dictionary dynamically + videos = { + category: { + param: get_motif_videos(project_path, f"{param}-{n_cluster}") + for param in parametrizations + } + for category in ['motif', 'community'] + } has_latent_vector_files = False if config: - has_latent_vector_files = all(map(lambda video: (get_video_results_path(video, project_path) / f"latent_vector_{video}.npy").exists(), config["video_sets"])) + has_latent_vector_files = all( + map(lambda video: (get_video_results_path(video, project_path) / f"latent_vector_{video}.npy").exists(), config["video_sets"]) + ) has_communities = (project_path / 'cohort_community_label.npy').exists() @@ -154,32 +159,40 @@ def load_project(project_path: Path): original_csvs_location = original_videos_location / 'pose_estimation' # Get all files in the original data directory - original_videos = list(map(lambda file: str(file), get_files(original_videos_location))) - original_csvs = list(map(lambda file: str(file), get_files(original_csvs_location))) - - motif_videos_created_hmm=all(map(lambda videos: len(videos) > 0, videos["motif"]["hmm"].values())) - motif_videos_created_kmeans=all(map(lambda videos: len(videos) > 0, videos["motif"]["kmeans"].values())) - - community_videos_created_hmm = all(map(lambda videos: len(videos) > 0, videos["community"]["kmeans"].values())) - community_videos_created_kmeans = all(map(lambda videos: len(videos) > 0, videos["community"]["hmm"].values())) - - umaps_created_hmm = any(map(lambda videos: len(videos) > 0, images["visualization"]["hmm"].values())) - - umaps_created_kmeans = any(map(lambda videos: len(videos) > 0, images["visualization"]["kmeans"].values())) + original_videos = list(map(str, get_files(original_videos_location))) + original_csvs = list(map(str, get_files(original_csvs_location))) + + # Check if motif videos were created for each parametrization + motif_videos_created = { + param: all(map(lambda videos: len(videos) > 0, videos["motif"][param].values())) + for param in parametrizations + } + + # Check if community videos were created for each parametrization + community_videos_created = { + param: all(map(lambda videos: len(videos) > 0, videos["community"][param].values())) + for param in parametrizations + } + + # Check if UMAPs were created for each parametrization + umaps_created = { + param: any(map(lambda videos: len(videos) > 0, images["visualization"][param].values())) + for param in parametrizations + } pose_ref_index_description, ref_index_len = get_pose_ref_index_description(original_csvs[0]) # Provide project workflow status workflow = dict( - organized = (project_path / 'data' / 'train').exists(), + organized=(project_path / 'data' / 'train').exists(), pose_ref_index_description=pose_ref_index_description, ref_index_len=ref_index_len, - modeled = len(images["evaluation"]) > 0, - segmented = has_latent_vector_files, - motif_videos_created=(lambda: motif_videos_created_hmm or motif_videos_created_kmeans)(), + modeled=len(images["evaluation"]) > 0, + segmented=has_latent_vector_files, + motif_videos_created=any(motif_videos_created.values()), communities_created=has_communities, - community_videos_created = (lambda: community_videos_created_hmm or community_videos_created_kmeans)(), - umaps_created = (lambda: umaps_created_hmm or umaps_created_kmeans)(), + community_videos_created=any(community_videos_created.values()), + umaps_created=any(umaps_created.values()), ) return dict( @@ -193,4 +206,4 @@ def load_project(project_path: Path): csvs=original_csvs, workflow=workflow, states=states - ) \ No newline at end of file + )