Skip to content

Commit

Permalink
gpu dtype float16
Browse files Browse the repository at this point in the history
  • Loading branch information
koraypoyraz committed Dec 1, 2023
1 parent fdb885b commit c5950c6
Show file tree
Hide file tree
Showing 7 changed files with 19 additions and 18 deletions.
2 changes: 1 addition & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@

[EXPERIMENTAL]

This repo consists of an Image Compressor system using a pretrained Vector Quantized Variational Autoencoder (VQVAE) developed with TensorFlow (see notebooks) and hosted within the FastAPI framework.
This repo consists of an Image Compressor system using a pretrained Vector Quantized Variational Autoencoder (VQVAE) developed with TensorFlow (see notebooks).

**Stack**
- Models
Expand Down
2 changes: 1 addition & 1 deletion app/core/populate_qdrant.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,7 @@ def create_collection(client: QdrantClient, compressor: ImageCompressor):
client.recreate_collection(
collection_name=COLLECTION_NAME,
vectors_config=models.VectorParams(
size=compressor.get_latent_flat_size(),
size=compressor.get_latent_mu_size(),
distance=models.Distance.COSINE,
on_disk=True
)
Expand Down
4 changes: 1 addition & 3 deletions app/models/base_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,4 @@ def __init__(self):
self._model = None
self.model_id = None
self._device = 'cuda' if torch.cuda.is_available() else 'mps' if torch.backends.mps.is_available() else 'cpu'

if self._device in ['cuda', 'cpu']:
self.d_type = torch.float16
self.d_type = torch.float16
22 changes: 12 additions & 10 deletions app/models/compressor.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,17 +12,15 @@ class ImageCompressor(BaseModel):
def __init__(self):
super().__init__()

if self._device in ['mps']:
self.d_type = torch.float32

self._latent_dims = (4, 64, 64)
self._max_dim = 512
self._np_d_type = np.float16
self.model_id = 'madebyollin/taesd'

def get_latent_dims(self):
return self._latent_dims

def get_latent_flat_size(self):
def get_latent_mu_size(self):
return math.prod(self._latent_dims[:2])

def load_model(self):
Expand All @@ -34,29 +32,33 @@ def get_model(self):
return self._model

def compress(self, raw_images: List[Image.Image]) -> Tuple[torch.Tensor, torch.Size]:
tensor_block = torch.stack([to_tensor(self.preprocess(img)) for img in raw_images]).to(self._device)
tensor_block = (torch.stack([to_tensor(self.preprocess(img)) for img in raw_images])
.to(self._device, dtype=self.d_type))

latent_space = self._model.encoder(tensor_block)
return latent_space, latent_space.shape

def preprocess(self, raw_image: Image.Image) -> Image.Image:
return resize(raw_image, self._max_dim)

def vector_ndarray(self, latent_tensor: torch.Tensor) -> np.ndarray:
return latent_tensor.flatten().numpy(force=True).astype(np.float32)
return latent_tensor.flatten().numpy(force=True).astype(self._np_d_type)

def dimensionalize(self, latent_vector: List) -> torch.Tensor:
reshaped = np.array(latent_vector, dtype=np.float32).reshape(self._latent_dims)
reshaped = np.array(latent_vector, dtype=self._np_d_type).reshape(self._latent_dims)
dim_shift = to_tensor(reshaped).movedim(0, -1).unsqueeze(0)
return dim_shift

def decompress(self, latent_vector: List) -> Image.Image:
dim_shift_gpu = self.dimensionalize(latent_vector).to(self._device)
dim_shift_gpu = self.dimensionalize(latent_vector).to(self._device, dtype=self.d_type)
reconstructed = self._model.decoder(dim_shift_gpu).clamp(0, 1)
return to_pil_image(reconstructed[0])

def decompress_batch(self, latent_space_block: List) -> torch.Tensor:
dimensionalized_block = torch.stack(
[self.dimensionalize(latent_vector)[0] for latent_vector in latent_space_block]).to(self._device)
dimensionalized_block = (torch.stack(
[self.dimensionalize(latent_vector)[0] for latent_vector in latent_space_block])
.to(self._device, dtype=self.d_type))

reconstructed_block = self._model.decoder(dimensionalized_block).clamp(0, 1)
return reconstructed_block

Expand Down
Binary file modified asset/similar_latents.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
4 changes: 2 additions & 2 deletions requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -9,9 +9,9 @@ requests==2.31.0
datasets==2.14.6
tqdm==4.66.1
transformers==4.35.1
setuptools==68.2.2
setuptools==69.0.2
diffusers==0.23
pillow==10.1.0
accelerate
qdrant-client
streamlit
streamlit==1.29.0
3 changes: 2 additions & 1 deletion webapp.py
Original file line number Diff line number Diff line change
Expand Up @@ -102,7 +102,8 @@ def web_app_similar():

st.write(f"Elapsed time: {duration.microseconds / 1e3}ms")
st.write(f"Reconstruction time: {s_r_s - r_s}s")
st.write(f"Reconstruction dimensions: 512x512")


if __name__ == '__main__':
web_app()
web_app_similar()

0 comments on commit c5950c6

Please sign in to comment.