|
from diffusers import DDPMScheduler |
|
import torch |
|
|
|
class HookedNoiseScheduler: |
|
scheduler: DDPMScheduler |
|
pre_hooks: list |
|
post_hooks: list |
|
|
|
def __init__(self, scheduler): |
|
object.__setattr__(self, 'scheduler', scheduler) |
|
object.__setattr__(self, 'pre_hooks', []) |
|
object.__setattr__(self, 'post_hooks', []) |
|
|
|
def step( |
|
self, |
|
model_output, timestep, sample, generator, return_dict |
|
): |
|
assert return_dict == False, "return_dict == True is not implemented" |
|
for hook in self.pre_hooks: |
|
hook_output = hook(model_output, timestep, sample, generator) |
|
if hook_output is not None: |
|
model_output, timestep, sample, generator = hook_output |
|
|
|
(pred_prev_sample, ) = self.scheduler.step(model_output, timestep, sample, generator, return_dict) |
|
|
|
for hook in self.post_hooks: |
|
hook_output = hook(pred_prev_sample) |
|
if hook_output is not None: |
|
pred_prev_sample = hook_output |
|
|
|
return (pred_prev_sample, ) |
|
|
|
def __getattr__(self, name): |
|
return getattr(self.scheduler, name) |
|
|
|
def __setattr__(self, name, value): |
|
if name in {'scheduler', 'pre_hooks', 'post_hooks'}: |
|
object.__setattr__(self, name, value) |
|
else: |
|
setattr(self.scheduler, name, value) |