Skip to content

Commit

Permalink
fix: correct vits assignment
Browse files Browse the repository at this point in the history
  • Loading branch information
danellecline committed Nov 23, 2024
1 parent 3d3861d commit d35ae96
Showing 1 changed file with 3 additions and 14 deletions.
17 changes: 3 additions & 14 deletions sdcat/cluster/cluster.py
Original file line number Diff line number Diff line change
Expand Up @@ -385,17 +385,16 @@ def cluster_vits(
# Fetch the cached embeddings
debug('Fetching embeddings ...')
image_emb = []
image_predictions = []
image_scores = []
for filename in images:
emb, label, score = fetch_embedding(model, filename)
if len(emb) == 0:
# If the embeddings are zero, then the extraction failed; add a zero array
image_emb.append(np.zeros(384, dtype=np.float32))
else:
image_emb.append(emb)
image_predictions.append(label)
image_scores.append(score)
if use_vits:
df_dets.loc[df_dets['crop_path'] == filename, 'class'] = label[0]
df_dets.loc[df_dets['crop_path'] == filename, 'score'] = score[0]

# If the embeddings are zero, then the extraction failed
num_failed = [i for i, e in enumerate(image_emb) if np.all(e == 0)]
Expand Down Expand Up @@ -452,16 +451,6 @@ def cluster_vits(
debug(f'Adding {images[idx]} to cluster id {cluster_id} ')
df_dets.loc[df_dets['crop_path'] == images[idx], 'cluster'] = cluster_id

# If use_vits is true, then assign the class to each detection
if use_vits:
for idx, row in df_dets.iterrows():
# If the idx is our of range, then skip
if idx >= len(image_predictions):
continue
predictions, scores = image_predictions[idx], image_scores[idx]
df_dets.loc[idx, 'class'] = predictions[0] # Use the top prediction
df_dets.loc[idx, 'score'] = scores[0]

# For each cluster let's create a grid of the images to check the quality of the clustering results
num_processes = min(multiprocessing.cpu_count(), len(unique_clusters))
info(f'Using {num_processes} processes to visualize the {len(unique_clusters)} clusters')
Expand Down

0 comments on commit d35ae96

Please sign in to comment.