Merge pull request #180 from cloneofsimo/develop
v0.1.7
cloneofsimo authored Feb 13, 2023
2 parents 799c17a + e48cbbb commit bdd51b0
Showing 4 changed files with 135 additions and 90 deletions.
175 changes: 125 additions & 50 deletions lora_diffusion/cli_lora_pti.py
@@ -128,17 +128,31 @@ def get_models(
)


def text2img_dataloader(train_dataset, train_batch_size, tokenizer, vae, text_encoder):
@torch.no_grad()
def text2img_dataloader(
train_dataset,
train_batch_size,
tokenizer,
vae,
text_encoder,
cached_latents: bool = False,
):

if cached_latents:
cached_latents_dataset = []
for idx in tqdm(range(len(train_dataset))):
batch = train_dataset[idx]
# print(batch)
latents = vae.encode(
batch["instance_images"].unsqueeze(0).to(dtype=vae.dtype).to(vae.device)
).latent_dist.sample()
latents = latents * 0.18215
batch["instance_images"] = latents.squeeze(0)
cached_latents_dataset.append(batch)

def collate_fn(examples):
input_ids = [example["instance_prompt_ids"] for example in examples]
pixel_values = [example["instance_images"] for example in examples]

# Concat class and instance examples for prior preservation.
# We do this to avoid doing two forward passes.
if examples[0].get("class_prompt_ids", None) is not None:
input_ids += [example["class_prompt_ids"] for example in examples]
pixel_values += [example["class_images"] for example in examples]

pixel_values = torch.stack(pixel_values)
pixel_values = pixel_values.to(memory_format=torch.contiguous_format).float()

@@ -159,33 +173,60 @@ def collate_fn(examples):

return batch

train_dataloader = torch.utils.data.DataLoader(
train_dataset,
batch_size=train_batch_size,
shuffle=True,
collate_fn=collate_fn,
)
if cached_latents:

train_dataloader = torch.utils.data.DataLoader(
cached_latents_dataset,
batch_size=train_batch_size,
shuffle=True,
collate_fn=collate_fn,
)

print("PTI : Using cached latent.")

else:
train_dataloader = torch.utils.data.DataLoader(
train_dataset,
batch_size=train_batch_size,
shuffle=True,
collate_fn=collate_fn,
)

return train_dataloader
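
The `cached_latents` path above runs every training image through the VAE exactly once (hence the new `@torch.no_grad()` decorator), stores the scaled latents in place of the pixel tensors, and lets the rest of training skip the encoder entirely. A minimal standalone sketch of the same idea, assuming a diffusers `AutoencoderKL` and dataset items that carry an `instance_images` tensor:

```python
import torch
from tqdm import tqdm

@torch.no_grad()
def cache_latents(dataset, vae):
    # Encode each image once so the VAE can be dropped for the rest of training.
    cached = []
    for idx in tqdm(range(len(dataset))):
        item = dataset[idx]
        posterior = vae.encode(
            item["instance_images"].unsqueeze(0).to(dtype=vae.dtype, device=vae.device)
        ).latent_dist
        # 0.18215 is the Stable Diffusion latent scaling factor.
        item["instance_images"] = (posterior.sample() * 0.18215).squeeze(0)
        cached.append(item)
    return cached
```

One trade-off worth noting: sampling the posterior once fixes each image's latent for all epochs, in exchange for the VRAM freed by dropping the VAE.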

def inpainting_dataloader(train_dataset, train_batch_size, tokenizer, vae, text_encoder):

def inpainting_dataloader(
train_dataset, train_batch_size, tokenizer, vae, text_encoder
):
def collate_fn(examples):
input_ids = [example["instance_prompt_ids"] for example in examples]
pixel_values = [example["instance_images"] for example in examples]
mask_values = [example["instance_masks"] for example in examples]
masked_image_values = [example["instance_masked_images"] for example in examples]
masked_image_values = [
example["instance_masked_images"] for example in examples
]

# Concat class and instance examples for prior preservation.
# We do this to avoid doing two forward passes.
if examples[0].get("class_prompt_ids", None) is not None:
input_ids += [example["class_prompt_ids"] for example in examples]
pixel_values += [example["class_images"] for example in examples]
mask_values += [example["class_masks"] for example in examples]
masked_image_values += [example["class_masked_images"] for example in examples]
masked_image_values += [
example["class_masked_images"] for example in examples
]

pixel_values = torch.stack(pixel_values).to(memory_format=torch.contiguous_format).float()
mask_values = torch.stack(mask_values).to(memory_format=torch.contiguous_format).float()
masked_image_values = torch.stack(masked_image_values).to(memory_format=torch.contiguous_format).float()
pixel_values = (
torch.stack(pixel_values).to(memory_format=torch.contiguous_format).float()
)
mask_values = (
torch.stack(mask_values).to(memory_format=torch.contiguous_format).float()
)
masked_image_values = (
torch.stack(masked_image_values)
.to(memory_format=torch.contiguous_format)
.float()
)

input_ids = tokenizer.pad(
{"input_ids": input_ids},
@@ -198,7 +239,7 @@ def collate_fn(examples):
"input_ids": input_ids,
"pixel_values": pixel_values,
"mask_values": mask_values,
"masked_image_values": masked_image_values
"masked_image_values": masked_image_values,
}

if examples[0].get("mask", None) is not None:
@@ -215,6 +256,7 @@ def collate_fn(examples):

return train_dataloader


def loss_step(
batch,
unet,
@@ -225,23 +267,30 @@ def loss_step(
t_mutliplier=1.0,
mixed_precision=False,
mask_temperature=1.0,
cached_latents: bool = False,
):
weight_dtype = torch.float32

latents = vae.encode(
batch["pixel_values"].to(dtype=weight_dtype).to(unet.device)
).latent_dist.sample()
latents = latents * 0.18215

if train_inpainting:
masked_image_latents = vae.encode(
batch["masked_image_values"].to(dtype=weight_dtype).to(unet.device)
if not cached_latents:
latents = vae.encode(
batch["pixel_values"].to(dtype=weight_dtype).to(unet.device)
).latent_dist.sample()
masked_image_latents = masked_image_latents * 0.18215
mask = F.interpolate(
batch["mask_values"].to(dtype=weight_dtype).to(unet.device),
scale_factor=1/8
)
latents = latents * 0.18215

if train_inpainting:
masked_image_latents = vae.encode(
batch["masked_image_values"].to(dtype=weight_dtype).to(unet.device)
).latent_dist.sample()
masked_image_latents = masked_image_latents * 0.18215
mask = F.interpolate(
batch["mask_values"].to(dtype=weight_dtype).to(unet.device),
scale_factor=1 / 8,
)
else:
latents = batch["pixel_values"]

if train_inpainting:
masked_image_latents = batch["masked_image_latents"]
mask = batch["mask_values"]

noise = torch.randn_like(latents)
bsz = latents.shape[0]
@@ -257,7 +306,9 @@ def loss_step(
noisy_latents = scheduler.add_noise(latents, noise, timesteps)

if train_inpainting:
latent_model_input = torch.cat([noisy_latents, mask, masked_image_latents], dim=1)
latent_model_input = torch.cat(
[noisy_latents, mask, masked_image_latents], dim=1
)
else:
latent_model_input = noisy_latents

@@ -268,7 +319,9 @@ def loss_step(
batch["input_ids"].to(text_encoder.device)
)[0]

model_pred = unet(latent_model_input, timesteps, encoder_hidden_states).sample
model_pred = unet(
latent_model_input, timesteps, encoder_hidden_states
).sample
else:

encoder_hidden_states = text_encoder(
@@ -308,7 +361,12 @@ def collate_fn(examples):

target = target * mask

loss = F.mse_loss(model_pred.float(), target.float(), reduction="mean")
loss = (
F.mse_loss(model_pred.float(), target.float(), reduction="none")
.mean([1, 2, 3])
.mean()
)

return loss
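
Two things changed in `loss_step` above: when `cached_latents` is set, `batch["pixel_values"]` already holds latents, so the VAE encode is skipped entirely; and the MSE is now reduced per sample (`.mean([1, 2, 3])`) before averaging over the batch. For equally shaped samples the new reduction is numerically identical to the old `reduction="mean"`, but it exposes a per-example loss that masking or weighting can act on. A quick sanity check of the equivalence, assuming 4-channel 64×64 latents:

```python
import torch
import torch.nn.functional as F

pred = torch.randn(4, 4, 64, 64)
target = torch.randn(4, 4, 64, 64)

flat = F.mse_loss(pred, target, reduction="mean")
per_sample = F.mse_loss(pred, target, reduction="none").mean([1, 2, 3])

# Identical when every sample in the batch has the same shape.
assert torch.allclose(flat, per_sample.mean())
```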


@@ -328,6 +386,7 @@ def train_inversion(
tokenizer,
lr_scheduler,
test_image_path: str,
cached_latents: bool,
accum_iter: int = 1,
log_wandb: bool = False,
wandb_log_prompt_cnt: int = 10,
@@ -367,6 +426,7 @@ def train_inversion(
scheduler,
train_inpainting=train_inpainting,
mixed_precision=mixed_precision,
cached_latents=cached_latents,
)
/ accum_iter
)
@@ -375,6 +435,13 @@
loss_sum += loss.detach().item()

if global_step % accum_iter == 0:
# print gradient of text encoder embedding
print(
text_encoder.get_input_embeddings()
.weight.grad[index_updates, :]
.norm(dim=-1)
.mean()
)
optimizer.step()
optimizer.zero_grad()

@@ -448,7 +515,11 @@ def train_inversion(
# open all images in test_image_path
images = []
for file in os.listdir(test_image_path):
if file.lower().endswith(".png") or file.lower().endswith(".jpg") or file.lower().endswith(".jpeg"):
if (
file.lower().endswith(".png")
or file.lower().endswith(".jpg")
or file.lower().endswith(".jpeg")
):
images.append(
Image.open(os.path.join(test_image_path, file))
)
@@ -490,6 +561,7 @@ def perform_tuning(
out_name: str,
tokenizer,
test_image_path: str,
cached_latents: bool,
log_wandb: bool = False,
wandb_log_prompt_cnt: int = 10,
class_token: str = "person",
@@ -526,6 +598,7 @@ def perform_tuning(
t_mutliplier=0.8,
mixed_precision=True,
mask_temperature=mask_temperature,
cached_latents=cached_latents,
)
loss_sum += loss.detach().item()

@@ -627,18 +700,12 @@ def train(
train_text_encoder: bool = True,
pretrained_vae_name_or_path: str = None,
revision: Optional[str] = None,
class_data_dir: Optional[str] = None,
stochastic_attribute: Optional[str] = None,
perform_inversion: bool = True,
use_template: Literal[None, "object", "style"] = None,
train_inpainting: bool = False,
placeholder_tokens: str = "",
placeholder_token_at_data: Optional[str] = None,
initializer_tokens: Optional[str] = None,
class_prompt: Optional[str] = None,
with_prior_preservation: bool = False,
prior_loss_weight: float = 1.0,
num_class_images: int = 100,
seed: int = 42,
resolution: int = 512,
color_jitter: bool = True,
@@ -649,7 +716,6 @@ def train(
save_steps: int = 100,
gradient_accumulation_steps: int = 4,
gradient_checkpointing: bool = False,
mixed_precision="fp16",
lora_rank: int = 4,
lora_unet_target_modules={"CrossAttention", "Attention", "GEGLU"},
lora_clip_target_modules={"CLIPAttention"},
@@ -663,6 +729,7 @@ def train(
continue_inversion: bool = False,
continue_inversion_lr: Optional[float] = None,
use_face_segmentation_condition: bool = False,
cached_latents: bool = True,
use_mask_captioned_data: bool = False,
mask_temperature: float = 1.0,
scale_lr: bool = False,
@@ -773,11 +840,8 @@ def train(

train_dataset = PivotalTuningDatasetCapation(
instance_data_root=instance_data_dir,
stochastic_attribute=stochastic_attribute,
token_map=token_map,
use_template=use_template,
class_data_root=class_data_dir if with_prior_preservation else None,
class_prompt=class_prompt,
tokenizer=tokenizer,
size=resolution,
color_jitter=color_jitter,
@@ -789,12 +853,19 @@ def train(
train_dataset.blur_amount = 200

if train_inpainting:
assert not cached_latents, "Cached latents not supported for inpainting"

train_dataloader = inpainting_dataloader(
train_dataset, train_batch_size, tokenizer, vae, text_encoder
)
else:
train_dataloader = text2img_dataloader(
train_dataset, train_batch_size, tokenizer, vae, text_encoder
train_dataset,
train_batch_size,
tokenizer,
vae,
text_encoder,
cached_latents=cached_latents,
)

index_no_updates = torch.arange(len(tokenizer)) != -1
@@ -813,6 +884,8 @@ def train(
for param in params_to_freeze:
param.requires_grad = False

if cached_latents:
vae = None
# STEP 1 : Perform Inversion
if perform_inversion:
ti_optimizer = optim.AdamW(
@@ -836,6 +909,7 @@ def train(
text_encoder,
train_dataloader,
max_train_steps_ti,
cached_latents=cached_latents,
accum_iter=gradient_accumulation_steps,
scheduler=noise_scheduler,
index_no_updates=index_no_updates,
@@ -941,6 +1015,7 @@ def train(
text_encoder,
train_dataloader,
max_train_steps_tuning,
cached_latents=cached_latents,
scheduler=noise_scheduler,
optimizer=lora_optimizers,
save_steps=save_steps,
(Diffs for the remaining 3 changed files are not shown.)
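
For reference, `train()` now defaults to `cached_latents=True` and asserts it off for inpainting (caching isn't implemented for the inpainting path). A hypothetical invocation — the argument names follow the signature in this diff, while the paths and model name are placeholders:

```python
from lora_diffusion.cli_lora_pti import train

train(
    instance_data_dir="./my_images",
    pretrained_model_name_or_path="runwayml/stable-diffusion-v1-5",
    output_dir="./output",
    train_inpainting=False,  # caching is asserted off for inpainting
    cached_latents=True,     # encode once up front, then vae is set to None
)
```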
