# Duplicated from OmniGen2/OmniGen2 (commit 119e1fd)
import torch.nn.functional as F


def swiglu(x, y):
    # SwiGLU gating: SiLU (swish) of the gate `x`, computed in float32 for
    # numerical stability and cast back to the input dtype, times the value `y`.
    return F.silu(x.float(), inplace=False).to(x.dtype) * y
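

# Minimal usage sketch (not part of the original file): it only assumes PyTorch
# is installed; the tensor shapes and bfloat16 dtype are illustrative. `x` plays
# the role of the gate projection and `y` the value projection, as in a SwiGLU
# feed-forward block.
if __name__ == "__main__":
    import torch

    x = torch.randn(2, 8, dtype=torch.bfloat16)  # hypothetical gate projection
    y = torch.randn(2, 8, dtype=torch.bfloat16)  # hypothetical value projection
    out = swiglu(x, y)
    print(out.shape, out.dtype)  # torch.Size([2, 8]) torch.bfloat16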