forked from TomMoore515/material_stable_diffusion
-
Notifications
You must be signed in to change notification settings - Fork 0
/
predict.py
136 lines (120 loc) · 4.88 KB
/
predict.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
import os
from typing import Optional, List
import torch
import torch.nn as nn
from torch import autocast
from diffusers import PNDMScheduler, LMSDiscreteScheduler
from PIL import Image
from cog import BasePredictor, Input, Path
from image_to_image import (
StableDiffusionImg2ImgPipeline,
preprocess_init_image,
preprocess_mask,
)
def patch_conv(**patch):
    """Force keyword arguments onto every ``torch.nn.Conv2d`` built afterwards.

    Replaces ``torch.nn.Conv2d.__init__`` with a wrapper that merges ``patch``
    into the keyword arguments of each constructor call. The replacement is
    process-wide and only affects layers constructed after this function runs.

    Args:
        **patch: Constructor keyword arguments to enforce, for example
            ``padding_mode='circular'``.

    Notes:
        * Values in ``patch`` take precedence over the same keyword supplied
          by the caller. Previously both were forwarded, which raised
          ``TypeError`` (duplicate keyword argument) in that situation.
        * Positional arguments are forwarded untouched, so a patched
          parameter that is also given positionally still raises
          ``TypeError`` from the original constructor.
        * Repeated calls stack wrappers; for a key patched more than once,
          the value from the earliest call is the one that reaches the
          original constructor.
    """
    import functools  # function-scope import keeps this block self-contained

    cls = torch.nn.Conv2d
    init = cls.__init__

    @functools.wraps(init)
    def __init__(self, *args, **kwargs):
        # Merge rather than double-splat: `f(**kwargs, **patch)` fails with
        # a duplicate-keyword TypeError when both mappings share a key.
        return init(self, *args, **{**kwargs, **patch})

    cls.__init__ = __init__
# Patch Conv2d at import time so that every convolution constructed afterwards
# (including those created when Predictor.setup loads the pipeline) is built
# with padding_mode='circular'.
# NOTE(review): presumably this is what makes generated textures tile
# seamlessly -- confirm against the project's README.
patch_conv(padding_mode='circular')
# Directory handed to from_pretrained() as cache_dir when loading the weights.
MODEL_CACHE = "diffusers-cache"
class Predictor(BasePredictor):
    """Cog predictor that serves the Stable Diffusion img2img pipeline."""

    def setup(self):
        """Load the model into memory once so repeated predictions are fast.

        Builds the fp16 ``CompVis/stable-diffusion-v1-4`` pipeline from the
        local ``MODEL_CACHE`` directory (no network access, since
        ``local_files_only=True``) and moves it to the CUDA device.
        """
        print("Loading pipeline...")
        pipeline = StableDiffusionImg2ImgPipeline.from_pretrained(
            "CompVis/stable-diffusion-v1-4",
            scheduler=PNDMScheduler(
                beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear"
            ),
            revision="fp16",
            torch_dtype=torch.float16,
            cache_dir=MODEL_CACHE,
            local_files_only=True,
        )
        self.pipe = pipeline.to("cuda")

    @torch.inference_mode()
    @torch.cuda.amp.autocast()
    def predict(
        self,
        prompt: str = Input(description="Input prompt", default=""),
        width: int = Input(
            description="Width of output image. Maximum size is 1024x768 or 768x1024 because of memory limits",
            choices=[128, 256, 512, 768, 1024],
            default=512,
        ),
        height: int = Input(
            description="Height of output image. Maximum size is 1024x768 or 768x1024 because of memory limits",
            choices=[128, 256, 512, 768, 1024],
            default=512,
        ),
        init_image: Path = Input(
            description="Inital image to generate variations of. Will be resized to the specified width and height",
            default=None,
        ),
        mask: Path = Input(
            description="Black and white image to use as mask for inpainting over init_image. Black pixels are inpainted and white pixels are preserved. Experimental feature, tends to work better with prompt strength of 0.5-0.7",
            default=None,
        ),
        prompt_strength: float = Input(
            description="Prompt strength when using init image. 1.0 corresponds to full destruction of information in init image",
            default=0.8,
        ),
        num_outputs: int = Input(
            description="Number of images to output", choices=[1, 4], default=1
        ),
        num_inference_steps: int = Input(
            description="Number of denoising steps", ge=1, le=500, default=50
        ),
        guidance_scale: float = Input(
            description="Scale for classifier-free guidance", ge=1, le=20, default=7.5
        ),
        seed: int = Input(
            description="Random seed. Leave blank to randomize the seed", default=None
        ),
    ) -> List[Path]:
        """Run a single prediction on the model.

        Returns:
            Paths of the generated PNG files, written as ``/tmp/out-<i>.png``.

        Raises:
            ValueError: If both ``width`` and ``height`` are 1024.
            Exception: If the pipeline flags any output as NSFW.
        """
        # Draw a two-byte random seed when the caller did not supply one.
        if seed is None:
            seed = int.from_bytes(os.urandom(2), "big")
        print(f"Using seed: {seed}")

        if width == 1024 and height == 1024:
            raise ValueError(
                "Maximum size is 1024x768 or 768x1024 pixels, because of memory limits. Please select a lower width or height."
            )

        # Decide on the scheduler before `init_image` is rebound to the
        # preprocessed tensor: PNDM with an init image, LMS without one.
        has_init_image = bool(init_image)
        scheduler_cls = PNDMScheduler if has_init_image else LMSDiscreteScheduler
        if has_init_image:
            rgb_init = Image.open(init_image).convert("RGB")
            init_image = preprocess_init_image(rgb_init, width, height).to("cuda")
        self.pipe.scheduler = scheduler_cls(
            beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear"
        )

        if mask:
            rgb_mask = Image.open(mask).convert("RGB")
            mask = preprocess_mask(rgb_mask, width, height).to("cuda")

        # One copy of the prompt per requested output image.
        prompts = None if prompt is None else [prompt] * num_outputs
        rng = torch.Generator("cuda").manual_seed(seed)
        result = self.pipe(
            prompt=prompts,
            init_image=init_image,
            mask=mask,
            width=width,
            height=height,
            prompt_strength=prompt_strength,
            guidance_scale=guidance_scale,
            generator=rng,
            num_inference_steps=num_inference_steps,
        )

        if any(result["nsfw_content_detected"]):
            raise Exception("NSFW content detected, please try a different prompt")

        # Save each sample to a fixed, index-based location under /tmp.
        saved_paths = []
        for index, image in enumerate(result["sample"]):
            filename = f"/tmp/out-{index}.png"
            image.save(filename)
            saved_paths.append(Path(filename))
        return saved_paths