Skip to content

Commit

Permalink
- update readme
Browse files Browse the repository at this point in the history
- new dream-shape model
- fix gen synthetic image
  • Loading branch information
koraypoyraz committed Feb 22, 2024
1 parent 00a3b30 commit 5975e02
Show file tree
Hide file tree
Showing 6 changed files with 58 additions and 30 deletions.
3 changes: 3 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,9 @@ In a world of compression without storing original images, latent space represen
- VQVAE: pretraining (see notebooks)
- VAE Tiny: madebyollin/taesd
- model size: 2.4M params
- Stable Diffusion model: Lykon/dreamshaper-8
- for generating synthetic data
- model size: > 1B params
- Flavour
- 8bit latent space
- Similarity
Expand Down
3 changes: 2 additions & 1 deletion app/core/neural_service.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,8 @@ def load_models(self):
def prompt_inference(self, prompt: str) -> Tuple[np.ndarray, torch.Size, bool]:
image, is_nsfw = self.generator.inference(prompt)
latents, shape = self.compressor.compress(image)
vectorized = self.compressor.vector_ndarray(latents[0])
scaled_latents = self.compressor.scale_latents(latents)
vectorized = self.compressor.vector_ndarray(scaled_latents[0])
return vectorized, shape, is_nsfw

async def prompt_inference_async(self, prompt: str) -> Tuple[np.ndarray, torch.Size, bool]:
Expand Down
13 changes: 6 additions & 7 deletions app/models/generator.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ class ImageGenerator(BaseModel):

def __init__(self):
super().__init__()
self.model_id = f'{MODEL_WEIGHTS_DIR}/dreamshaper-7'
self.model_id = f'{MODEL_WEIGHTS_DIR}/dreamshaper-8'

def load_model(self):
self._model = DiffusionPipeline.from_pretrained(self.model_id,
Expand All @@ -20,20 +20,19 @@ def load_model(self):
self._optimize()

def _optimize(self):
self._model.scheduler = LCMScheduler.from_config(self._model.scheduler.config)
self._model.enable_attention_slicing()
# self._model.scheduler = LCMScheduler.from_config(self._model.scheduler.config)
# self._model.enable_attention_slicing()

self._model = self._model.to(self.device)

self._model.load_lora_weights("latent-consistency/lcm-lora-sdv1-5")
self._model.fuse_lora()
# self._model.load_lora_weights("latent-consistency/lcm-lora-sdv1-5")
# self._model.fuse_lora()

@torch.inference_mode()
def inference(self, prompt: str) -> Tuple[List[Image.Image], bool]:
results = self._model(
prompt=prompt,
num_inference_steps=4,
guidance_scale=0.0,
num_inference_steps=4
)

is_nsfw: bool = results.nsfw_content_detected[0]
Expand Down
34 changes: 27 additions & 7 deletions notebooks/vqvae_tiny.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -12,13 +12,13 @@
},
{
"cell_type": "code",
"execution_count": 1,
"execution_count": 2,
"id": "initial_id",
"metadata": {
"collapsed": true,
"ExecuteTime": {
"end_time": "2023-12-08T22:48:31.465466Z",
"start_time": "2023-12-08T22:48:29.907234Z"
"end_time": "2024-02-22T12:01:14.058242Z",
"start_time": "2024-02-22T12:01:12.017842Z"
}
},
"outputs": [],
Expand Down Expand Up @@ -101,20 +101,40 @@
},
{
"cell_type": "code",
"execution_count": 2,
"execution_count": 3,
"outputs": [],
"source": [
"device = 'cuda' if torch.cuda.is_available() else 'mps' if torch.backends.mps.is_available() else 'cpu'"
],
"metadata": {
"collapsed": false,
"ExecuteTime": {
"end_time": "2023-12-08T22:48:34.428290Z",
"start_time": "2023-12-08T22:48:34.414839Z"
"end_time": "2024-02-22T12:01:16.476156Z",
"start_time": "2024-02-22T12:01:16.471331Z"
}
},
"id": "54623a740cde50e9"
},
{
"cell_type": "code",
"outputs": [],
"source": [
"def calc_params(pars:list):\n",
" to_return = 0\n",
" for par in pars:\n",
" to_return += sum(param.numel() for param in par.parameters())\n",
" return to_return"
],
"metadata": {
"collapsed": false,
"ExecuteTime": {
"end_time": "2024-02-22T12:21:19.684922Z",
"start_time": "2024-02-22T12:21:19.612462Z"
}
},
"id": "2e28b6a84623926e",
"execution_count": 41
},
{
"cell_type": "code",
"execution_count": 3,
Expand Down Expand Up @@ -146,7 +166,7 @@
}
],
"source": [
"total_params = sum(param.numel() for param in vae.parameters())\n",
"total_params = calc_params([vae])\n",
"total_params"
],
"metadata": {
Expand Down
3 changes: 2 additions & 1 deletion requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -14,4 +14,5 @@ diffusers==0.23
pillow==10.1.0
accelerate
qdrant-client
streamlit==1.29.0
streamlit==1.29.0
umap-learn[plot]
32 changes: 18 additions & 14 deletions webapp.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
def inference(prompt: str):
response: requests.Response = requests.get(f'{BASE_URL}/inference?prompt="{prompt}"')
body = response.json()
st.write(body.keys())
latents = body['latents']

reconstruction = image_compressor.decompress(latents)
Expand All @@ -31,6 +32,22 @@ def inference(prompt: str):
return reconstruction, depiction, response.elapsed


def web_app_prompting():
    """Render the prompt page: read a prompt, run inference, show the results.

    A text input is displayed with an example prompt as its placeholder, and
    the same example is also rendered below the input as markdown. Once the
    user has entered a non-empty prompt, ``inference`` is called and the two
    returned images are shown together with two timings:

    * "Elapsed time (ms)" -- ``elapsed_time.total_seconds()`` converted to
      milliseconds, where ``elapsed_time`` is the third value returned by
      ``inference``.
    * "Overall time (s)" -- the difference between two ``timer()`` readings
      taken immediately around the ``inference`` call.
    """
    example_prompt = 'portrait photo of muscular bearded guy in a worn mech suit, light bokeh, intricate, steel metal, elegant, sharp focus, soft lighting, vibrant colors'
    prompt = st.text_input('Prompt', placeholder=example_prompt)
    st.markdown(example_prompt)

    # Nothing to do until the user has actually typed a prompt.
    if not prompt:
        return

    start = timer()
    image, depiction, elapsed_time = inference(prompt)
    end = timer()

    st.image(
        [image, depiction],
        caption=[f'Reconstruction: {image.size}', f'Latents: {depiction.size}'],
        clamp=True,
    )

    # Plain string labels: these had an f-prefix but no placeholders.
    st.write("Elapsed time (ms)", elapsed_time.total_seconds() * 1e3)
    st.write('Overall time (s)', end - start)


def get_reconstructed_latents(idx: int):
response: requests.Response = requests.get(f'{BASE_URL}/latents/{idx}')
body = response.json()[0]
Expand Down Expand Up @@ -74,20 +91,6 @@ def search_reference(latents: List):
return reconstructed_images, scores, response.elapsed


def web_app_prompting():
    """Ask for a prompt and, when one is given, display the inference output.

    Shows the two images returned by ``inference`` with their sizes as
    captions, followed by the elapsed time reported by ``inference`` (in
    milliseconds) and the overall time measured around the call (in seconds).
    """
    prompt = st.text_input('Prompt')

    # Guard clause: an empty input means there is nothing to run yet.
    if not prompt:
        return

    started_at = timer()
    reconstruction, latent_view, request_elapsed = inference(prompt)
    finished_at = timer()

    captions = [
        f'Reconstruction: {reconstruction.size}',
        f'Latents: {latent_view.size}',
    ]
    st.image([reconstruction, latent_view], caption=captions, clamp=True)

    elapsed_ms = request_elapsed.total_seconds() * 1e3
    st.write("Elapsed time (ms)", elapsed_ms)
    st.write('Overall time (s)', finished_at - started_at)


def web_app():
latent_idx = st.text_input('Latent idx')

Expand Down Expand Up @@ -197,6 +200,7 @@ def web_app_file_store_similar():


if __name__ == '__main__':
# web_app_prompting()
# web_app()
# web_app_file_store()
# web_app_similar()
Expand Down

0 comments on commit 5975e02

Please sign in to comment.