Skip to content

Commit

Permalink
latents batch reconstruction
Browse files Browse the repository at this point in the history
  • Loading branch information
koraypoyraz committed Dec 1, 2023
1 parent 5849efa commit fdb885b
Show file tree
Hide file tree
Showing 6 changed files with 40 additions and 24 deletions.
4 changes: 4 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,10 @@ This repo consist of an Image Compressor system using pretrained Vector Quantize
<img width='50%' src="/asset/astronaut.png">
</p>

<p align='center'>
<img width='50%' src="/asset/similar_latents.png">
</p>

## 🚀 Prerequisite
- install [miniforge](https://github.com/conda-forge/miniforge)
- create virtual env || conda
Expand Down
13 changes: 9 additions & 4 deletions app/main.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
from fastapi import FastAPI
from pydantic import BaseModel
from fastapi.middleware.cors import CORSMiddleware
from core.qdrant_service import QdrantService
from core.neural_service import NeuralService
Expand All @@ -17,9 +18,13 @@
neural_service = NeuralService()


@app.post("/latents/search/")
async def search(mu):
return await qdrant_service.search(mu)
class Latents(BaseModel):
mu: list


@app.post("/latents/search")
async def search(latents: Latents):
return await qdrant_service.search(latents.mu)


@app.get("/latents")
Expand Down Expand Up @@ -50,4 +55,4 @@ async def health_check():
if __name__ == "__main__":
import uvicorn

uvicorn.run(app, host="0.0.0.0", port=8000, log_level="info")
uvicorn.run(f"{__name__}:app", host="0.0.0.0", port=8000, log_level="info")
7 changes: 4 additions & 3 deletions app/models/compressor.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,15 +47,16 @@ def vector_ndarray(self, latent_tensor: torch.Tensor) -> np.ndarray:
def dimensionalize(self, latent_vector: List) -> torch.Tensor:
reshaped = np.array(latent_vector, dtype=np.float32).reshape(self._latent_dims)
dim_shift = to_tensor(reshaped).movedim(0, -1).unsqueeze(0)
return dim_shift.to(self._device)
return dim_shift

def decompress(self, latent_vector: List) -> Image.Image:
dim_shift_gpu = self.dimensionalize(latent_vector)
dim_shift_gpu = self.dimensionalize(latent_vector).to(self._device)
reconstructed = self._model.decoder(dim_shift_gpu).clamp(0, 1)
return to_pil_image(reconstructed[0])

def decompress_batch(self, latent_space_block: List) -> torch.Tensor:
dimensionalized_block = [self.dimensionalize(latent_vector) for latent_vector in latent_space_block]
dimensionalized_block = torch.stack(
[self.dimensionalize(latent_vector)[0] for latent_vector in latent_space_block]).to(self._device)
reconstructed_block = self._model.decoder(dimensionalized_block).clamp(0, 1)
return reconstructed_block

Expand Down
Binary file modified asset/astronaut.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added asset/similar_latents.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
40 changes: 23 additions & 17 deletions webapp.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import streamlit as st
from app.models import ImageCompressor
from torchvision.transforms.functional import to_pil_image
import requests
from timeit import default_timer as timer
from typing import List
Expand Down Expand Up @@ -37,12 +38,12 @@ def get_reconstructed_latents(idx: int):


def search(latents: List):
# TODO: fix type issue
response: requests.Response = requests.get(f'{BASE_URL}/latents/search', json={'mu': latents})
st.write(response.json())
# latent_space_block = [vector["vector"] for vector in response.json()]
# reconstructed_latents = image_compressor.decompress_batch(latent_space_block)
# st.write(reconstructed_latents.shape)
response: requests.Response = requests.post(f'{BASE_URL}/latents/search', json={'mu': latents})
latent_space_block = [vector["payload"]['latents'] for vector in response.json()]
scores = [vector["score"] for vector in response.json()]
reconstructed_latents = image_compressor.decompress_batch(latent_space_block)
reconstructed_images = [to_pil_image(tensor) for tensor in reconstructed_latents]
return reconstructed_images, scores


def web_app_prompting():
Expand Down Expand Up @@ -71,7 +72,6 @@ def web_app():
clamp=True)
s_o_s = timer()

st.write(f"Elapsed time: {duration.microseconds}μs")
st.write(f"Elapsed time: {duration.microseconds / 1e3}ms")
st.write(f"Reconstruction time: {s_r_s - r_s}s")
st.write(f'Overall time: {s_o_s - o_s}s')
Expand All @@ -81,21 +81,27 @@ def web_app_similar():
latent_idx = st.text_input('Latent idx')

if latent_idx and latent_idx.isdigit():
o_s = timer()
r_s = timer()
vector, image, depiction, duration = get_reconstructed_latents(latent_idx)

st.image([image, depiction], caption=[f'Reconstruction: {image.size}', f'Latents: {depiction.size}'],
clamp=True)

st.write(f"Elapsed time: {duration.microseconds}μs")
st.write(f"Elapsed time: {duration.microseconds / 1e3}ms")
st.write(f"Reconstruction time: {timer() - r_s}s")
st.write(f'Overall time: {timer() - o_s}s')

state = st.button('Search similarities')
state = st.button('Similar latents')
if state:
# top_k = search(vector)
search(vector)
r_s = timer()
top_k, scores = search(vector)
s_r_s = timer()

grid_size = len(top_k)
grid = st.columns(grid_size)
for idx in range(grid_size):
with grid[idx]:
reconstructed = top_k[idx]
score = scores[idx]
st.image(reconstructed, caption=f'Score: {score}')

st.write(f"Elapsed time: {duration.microseconds / 1e3}ms")
st.write(f"Reconstruction time: {s_r_s - r_s}s")


if __name__ == '__main__':
Expand Down

0 comments on commit fdb885b

Please sign in to comment.