Skip to content

Commit

Permalink
fix: load bu number of clusters
Browse files Browse the repository at this point in the history
  • Loading branch information
nathan-vm committed Aug 29, 2024
1 parent b6a4f7b commit 6f14df0
Showing 1 changed file with 49 additions and 36 deletions.
85 changes: 49 additions & 36 deletions src/services/vameApi/app/services/project_service.py
Original file line number Diff line number Diff line change
Expand Up @@ -120,66 +120,79 @@ def load_project(project_path: Path):
symlink.symlink_to(project_path)

states_path = Path(project_path) / "states" / "states.json"

states = json.load(open(states_path)) if os.path.exists(states_path) else None

config = yaml.safe_load(open(config_path, "r")) if config_path.exists() else None

# Extract parameters from the config
parametrizations = config["parametrizations"] # e.g., ["hmm", "kmeans"]
n_cluster = config["n_cluster"] # e.g., 28

# Create the visualization dictionary dynamically
visualization = {
param: get_visualization_images(project_path, f"{param}-{n_cluster}")
for param in parametrizations
}

images = dict(
evaluation=get_evaluation_images(project_path),
visualization=dict(
hmm=get_visualization_images(project_path, 'hmm-15'),
kmeans=get_visualization_images(project_path, 'kmeans-15')
)
visualization=visualization
)

videos = dict(
motif=dict(
hmm=get_motif_videos(project_path, 'hmm-15'),
kmeans=get_motif_videos(project_path, 'kmeans-15')
),
community=dict(
hmm=get_community_videos(project_path, 'hmm-15'),
kmeans=get_community_videos(project_path, 'kmeans-15')
)
)
# Create the videos dictionary dynamically
videos = {
category: {
param: get_motif_videos(project_path, f"{param}-{n_cluster}")
for param in parametrizations
}
for category in ['motif', 'community']
}

has_latent_vector_files = False
if config:
has_latent_vector_files = all(map(lambda video: (get_video_results_path(video, project_path) / f"latent_vector_{video}.npy").exists(), config["video_sets"]))
has_latent_vector_files = all(
map(lambda video: (get_video_results_path(video, project_path) / f"latent_vector_{video}.npy").exists(), config["video_sets"])
)

has_communities = (project_path / 'cohort_community_label.npy').exists()

original_videos_location = project_path / 'videos'
original_csvs_location = original_videos_location / 'pose_estimation'

# Get all files in the original data directory
original_videos = list(map(lambda file: str(file), get_files(original_videos_location)))
original_csvs = list(map(lambda file: str(file), get_files(original_csvs_location)))

motif_videos_created_hmm=all(map(lambda videos: len(videos) > 0, videos["motif"]["hmm"].values()))
motif_videos_created_kmeans=all(map(lambda videos: len(videos) > 0, videos["motif"]["kmeans"].values()))

community_videos_created_hmm = all(map(lambda videos: len(videos) > 0, videos["community"]["kmeans"].values()))
community_videos_created_kmeans = all(map(lambda videos: len(videos) > 0, videos["community"]["hmm"].values()))

umaps_created_hmm = any(map(lambda videos: len(videos) > 0, images["visualization"]["hmm"].values()))

umaps_created_kmeans = any(map(lambda videos: len(videos) > 0, images["visualization"]["kmeans"].values()))
original_videos = list(map(str, get_files(original_videos_location)))
original_csvs = list(map(str, get_files(original_csvs_location)))

# Check if motif videos were created for each parametrization
motif_videos_created = {
param: all(map(lambda videos: len(videos) > 0, videos["motif"][param].values()))
for param in parametrizations
}

# Check if community videos were created for each parametrization
community_videos_created = {
param: all(map(lambda videos: len(videos) > 0, videos["community"][param].values()))
for param in parametrizations
}

# Check if UMAPs were created for each parametrization
umaps_created = {
param: any(map(lambda videos: len(videos) > 0, images["visualization"][param].values()))
for param in parametrizations
}

pose_ref_index_description, ref_index_len = get_pose_ref_index_description(original_csvs[0])

# Provide project workflow status
workflow = dict(
organized = (project_path / 'data' / 'train').exists(),
organized=(project_path / 'data' / 'train').exists(),
pose_ref_index_description=pose_ref_index_description,
ref_index_len=ref_index_len,
modeled = len(images["evaluation"]) > 0,
segmented = has_latent_vector_files,
motif_videos_created=(lambda: motif_videos_created_hmm or motif_videos_created_kmeans)(),
modeled=len(images["evaluation"]) > 0,
segmented=has_latent_vector_files,
motif_videos_created=any(motif_videos_created.values()),
communities_created=has_communities,
community_videos_created = (lambda: community_videos_created_hmm or community_videos_created_kmeans)(),
umaps_created = (lambda: umaps_created_hmm or umaps_created_kmeans)(),
community_videos_created=any(community_videos_created.values()),
umaps_created=any(umaps_created.values()),
)

return dict(
Expand All @@ -193,4 +206,4 @@ def load_project(project_path: Path):
csvs=original_csvs,
workflow=workflow,
states=states
)
)

0 comments on commit 6f14df0

Please sign in to comment.