only train and infer stage 2 model #119
Hi :)
class CustomPipeline(Pipeline):

    def __init__(self, cldm: ControlLDM, diffusion: Diffusion, cond_fn: Optional[Guidance], device: str) -> None:
        super().__init__(None, cldm, diffusion, cond_fn, device)

    @count_vram_usage
    def run_stage1(self, lq: torch.Tensor) -> torch.Tensor:
        # In our experiments, the output of the restoration module (a.k.a. the condition for the subsequent IRControlNet)
        # is resized to a resolution >= 512, since both pretrained SD2 and IRControlNet are trained on 512x512.
        # Here we directly use the lq as the condition.
        # Default to the input so that `clean` is defined when no resize is needed.
        clean = lq
        if min(lq.shape[2:]) < 512:
            clean = resize_short_edge_to(lq, size=512)
        return clean


class CustomInferenceLoop(InferenceLoop):

    @count_vram_usage
    def init_stage1_model(self) -> None:
        # Nothing to do.
        pass

    def init_pipeline(self) -> None:
        # Instantiate our custom pipeline.
        self.pipeline = CustomPipeline(self.cldm, self.diffusion, self.cond_fn, self.args.device)


def main():
    args = parse_args()
    args.device = check_device(args.device)
    set_seed(args.seed)
    if args.version == "v1":
        V1InferenceLoop(args).run()
    else:
        supported_tasks = {
            "sr": BSRInferenceLoop,
            "dn": BIDInferenceLoop,
            "fr": BFRInferenceLoop,
            "fr_bg": UnAlignedBFRInferenceLoop,
            "custom": CustomInferenceLoop
        }
        supported_tasks[args.task](args).run()
        print("done!")

The code above was written directly in the GitHub comment box and has not been tested, but I think it will work. As for fine-tuning the stage 2 model, there are two ways to do it:
The second way is the recommended one.
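For readers who want to try the idea without the DiffBIR codebase, here is a minimal, self-contained sketch of the same pattern: a subclass overrides the stage 1 hook with a no-op, the condition keeps its size unless its short edge is below 512, and the loop is picked from a task dictionary. Every name in it (`short_edge_target_size`, `BaseLoop`, `Stage2OnlyLoop`, `LOOPS`) is a hypothetical stand-in, not part of DiffBIR, and the helper only computes a target shape rather than resizing a tensor.

```python
from typing import Dict, Tuple, Type


def short_edge_target_size(h: int, w: int, size: int = 512) -> Tuple[int, int]:
    """Return the (h, w) needed so that the shorter edge is at least `size`.

    Hypothetical helper for illustration only. It keeps the aspect ratio and
    is not DiffBIR's resize_short_edge_to.
    """
    short = min(h, w)
    if short >= size:
        return h, w  # already large enough, leave unchanged
    # Integer ceiling division avoids floating-point rounding surprises.
    new_h = -(-h * size // short)
    new_w = -(-w * size // short)
    return new_h, new_w


class BaseLoop:
    """Stand-in for an inference loop with an overridable stage 1 hook."""

    def __init__(self) -> None:
        self.stage1_loaded = False
        self.init_stage1_model()

    def init_stage1_model(self) -> None:
        # Pretend to load the restoration (stage 1) weights.
        self.stage1_loaded = True

    def condition_size(self, h: int, w: int) -> Tuple[int, int]:
        # Size of the condition image handed to stage 2.
        return short_edge_target_size(h, w)


class Stage2OnlyLoop(BaseLoop):
    """Skips stage 1 because the input was already restored elsewhere."""

    def init_stage1_model(self) -> None:
        pass  # nothing to load


# Task name -> loop class, mirroring the dictionary dispatch in main().
LOOPS: Dict[str, Type[BaseLoop]] = {
    "default": BaseLoop,
    "custom": Stage2OnlyLoop,
}

if __name__ == "__main__":
    loop = LOOPS["custom"]()
    print(loop.stage1_loaded, loop.condition_size(256, 384))
```

The point of the design is that only the stage 1 hook changes, so everything downstream of it is inherited untouched by the custom loop.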
Hi, @0x3f3f3f3fun
I want to train and run inference with only the stage 2 model (stage 1 will be handled separately by another model, which generates the restored image). Could you provide the command structure for this: how to fine-tune only the stage 2 model and save it, and how to run inference with only the stage 2 model? Thanks!
BRs,
tzayuan