-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathcreate_index.py
30 lines (25 loc) · 899 Bytes
/
create_index.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
import faiss
from pathlib import Path
import open_clip
from PIL import Image
import torch
import torch.nn.functional as F
import json
# Defaults for the index build; each can be overridden via main()'s arguments
# (or, for the model, _load_model()'s arguments).
MODEL_NAME = "ViT-B-16-SigLIP-512"
CHECKPOINT_PATH = "model/open_clip_pytorch_model.bin"
IMAGES_DIR = "images"
OUTPUT_DIR = "docker"
BATCH_SIZE = 32


def _load_model(model_name=MODEL_NAME, checkpoint=CHECKPOINT_PATH):
    """Create the open_clip model (in eval mode) and its preprocessing transform."""
    model, _, preprocess = open_clip.create_model_and_transforms(
        model_name,
        pretrained=checkpoint,
    )
    # open_clip hands back the model in train mode; switch to eval so layers
    # such as dropout behave deterministically while we embed.
    model.eval()
    return model, preprocess


def _collect_image_paths(images_dir):
    """Return every ``*.jpg`` file under *images_dir* (recursive), sorted."""
    # Sorting makes the index row order (and therefore mapping.json)
    # reproducible across runs and filesystems.
    return sorted(p for p in Path(images_dir).rglob("*.jpg") if p.is_file())


def _iter_embedding_batches(model, preprocess, image_paths, batch_size):
    """Yield L2-normalised float32 numpy embeddings, one batch of images at a time."""
    for start in range(0, len(image_paths), batch_size):
        batch = []
        for image_path in image_paths[start:start + batch_size]:
            # Close each file right after preprocessing instead of holding
            # every handle and decoded image in memory at once.
            with Image.open(image_path) as image:
                batch.append(preprocess(image))
        # Keep no_grad scoped to the forward pass: a grad-mode context left
        # open across `yield` would leak into the caller between batches.
        with torch.no_grad():
            embeddings = model.encode_image(torch.stack(batch))
            # Unit-norm rows make inner product equal cosine similarity.
            embeddings = F.normalize(embeddings, p=2, dim=-1)
        # faiss consumes contiguous float32 numpy arrays.
        yield embeddings.float().cpu().numpy()


def main(images_dir=IMAGES_DIR, output_dir=OUTPUT_DIR, batch_size=BATCH_SIZE):
    """Embed every ``*.jpg`` under *images_dir* and write a FAISS index plus mapping.

    Writes ``index.faiss`` (an inner-product ``IndexFlatIP`` over normalised
    image embeddings) and ``mapping.json`` (a JSON list whose entry ``i`` is
    the path of the image stored in index row ``i``) into *output_dir*.

    Args:
        images_dir: Directory searched recursively for ``*.jpg`` files.
        output_dir: Directory for the output files; created if missing.
        batch_size: Number of images encoded per forward pass.

    Raises:
        ValueError: If *batch_size* is less than 1.
        SystemExit: If no ``*.jpg`` images are found under *images_dir*.
    """
    if batch_size < 1:
        raise ValueError(f"batch_size must be >= 1, got {batch_size}")

    # Collect inputs before loading the model so an empty dataset fails fast
    # with a readable message instead of an error from torch.stack([]).
    image_paths = _collect_image_paths(images_dir)
    if not image_paths:
        raise SystemExit(f"No *.jpg images found under {str(images_dir)!r}")

    model, preprocess = _load_model()
    index = faiss.IndexFlatIP(model.text.output_dim)
    for embeddings in _iter_embedding_batches(
        model, preprocess, image_paths, batch_size
    ):
        index.add(embeddings)

    out_dir = Path(output_dir)
    out_dir.mkdir(parents=True, exist_ok=True)
    faiss.write_index(index, str(out_dir / "index.faiss"))
    # The mapping uses the same sorted order the embeddings were added in.
    with (out_dir / "mapping.json").open("w", encoding="utf-8") as f:
        json.dump([str(image_path) for image_path in image_paths], f)


if __name__ == "__main__":
    main()