diff --git a/examples/allegro/sample.py b/examples/allegro/sample.py new file mode 100644 index 00000000..ac1a8e5c --- /dev/null +++ b/examples/allegro/sample.py @@ -0,0 +1,125 @@ +from videosys import AllegroConfig, AllegroPABConfig, VideoSysEngine +import io +import imageio +import torch +def run_base(): + # num frames: 65 or 221 + # change num_gpus for multi-gpu inference + config = AllegroConfig(model_path="rhymes-ai/Allegro", + cpu_offload=False, + num_gpus=4) + engine = VideoSysEngine(config) + + positive_prompt = """ +(masterpiece), (best quality), (ultra-detailed), (unwatermarked), +{} +emotional, harmonious, vignette, 4k epic detailed, shot on kodak, 35mm photo, +sharp focus, high budget, cinemascope, moody, epic, gorgeous +""" + + negative_prompt = """ +nsfw, lowres, bad anatomy, bad hands, text, error, missing fingers, extra digit, fewer digits, cropped, worst quality, +low quality, normal quality, jpeg artifacts, signature, watermark, username, blurry. +""" + + user_prompt = "A seaside harbor with bright sunlight and sparkling seawater, with many boats in the water. From an aerial view, the boats vary in size and color, some moving and some stationary. Fishing boats in the water suggest that this location might be a popular spot for docking fishing boats." + num_step, cfg_scale, rand_seed = 100, 7.5, 42 + input_prompt = positive_prompt.format(user_prompt.lower().strip()) + + + video = engine.generate( + input_prompt, + negative_prompt=negative_prompt, + num_frames=88, + height=720, + width=1280, + num_inference_steps=num_step, + guidance_scale=cfg_scale, + max_sequence_length=512, + seed=rand_seed + ).video[0] + + engine.save_video(video, f"./outputs/{user_prompt}.mp4") + + +def run_low_mem(): + + # change num_gpus for multi-gpu inference + config = AllegroConfig(model_path="rhymes-ai/Allegro", + cpu_offload=True) + engine = VideoSysEngine(config) + + + positive_prompt = """ +(masterpiece), (best quality), (ultra-detailed), (unwatermarked), +{} +emotional, harmonious, vignette, 4k epic detailed, shot on kodak, 35mm photo, +sharp focus, high budget, cinemascope, moody, epic, gorgeous +""" + + negative_prompt = """ +nsfw, lowres, bad anatomy, bad hands, text, error, missing fingers, extra digit, fewer digits, cropped, worst quality, +low quality, normal quality, jpeg artifacts, signature, watermark, username, blurry. +""" + + + user_prompt = "A seaside harbor with bright sunlight and sparkling seawater, with many boats in the water. From an aerial view, the boats vary in size and color, some moving and some stationary. Fishing boats in the water suggest that this location might be a popular spot for docking fishing boats." + num_step, cfg_scale, rand_seed = 100, 7.5, 42 + input_prompt = positive_prompt.format(user_prompt.lower().strip()) + + + video = engine.generate( + input_prompt, + negative_prompt=negative_prompt, + num_frames=88, + height=720, + width=1280, + num_inference_steps=num_step, + guidance_scale=cfg_scale, + max_sequence_length=512, + seed=rand_seed + ).video[0] + engine.save_video(video, f"./outputs/{user_prompt}.mp4") + + +def run_pab(): + + config = AllegroConfig(model_path="rhymes-ai/Allegro", + cpu_offload=False, enable_tiling=True, num_gpus=4, enable_pab=True) + engine = VideoSysEngine(config) + + positive_prompt = """ +(masterpiece), (best quality), (ultra-detailed), (unwatermarked), +{} +emotional, harmonious, vignette, 4k epic detailed, shot on kodak, 35mm photo, +sharp focus, high budget, cinemascope, moody, epic, gorgeous +""" + + negative_prompt = """ +nsfw, lowres, bad anatomy, bad hands, text, error, missing fingers, extra digit, fewer digits, cropped, worst quality, +low quality, normal quality, jpeg artifacts, signature, watermark, username, blurry. +""" + + + user_prompt = "A seaside harbor with bright sunlight and sparkling seawater, with many boats in the water. From an aerial view, the boats vary in size and color, some moving and some stationary. Fishing boats in the water suggest that this location might be a popular spot for docking fishing boats." + num_step, cfg_scale, rand_seed = 100, 7.5, 42 + input_prompt = positive_prompt.format(user_prompt.lower().strip()) + + + video = engine.generate( + input_prompt, + negative_prompt=negative_prompt, + num_frames=88, + height=720, + width=1280, + num_inference_steps=num_step, + guidance_scale=cfg_scale, + max_sequence_length=512, + seed=rand_seed + ).video[0] + engine.save_video(video, f"./outputs/{user_prompt}.mp4") + +if __name__ == "__main__": + run_base() + # run_pab() + # run_low_mem() diff --git a/videosys/__init__.py b/videosys/__init__.py index 6c539c0f..343f72b9 100644 --- a/videosys/__init__.py +++ b/videosys/__init__.py @@ -10,6 +10,7 @@ OpenSoraPlanV120PABConfig, ) from .pipelines.vchitect import VchitectConfig, VchitectPABConfig, VchitectXLPipeline +from .pipelines.allegro import AllegroConfig, AllegroPABConfig, AllegroPipeline __all__ = [ "initialize", @@ -18,5 +19,6 @@ "OpenSoraPlanPipeline", "OpenSoraPlanConfig", "OpenSoraPlanV110PABConfig", "OpenSoraPlanV120PABConfig", "OpenSoraPipeline", "OpenSoraConfig", "OpenSoraPABConfig", "CogVideoXPipeline", "CogVideoXConfig", "CogVideoXPABConfig", - "VchitectXLPipeline", "VchitectConfig", "VchitectPABConfig" + "VchitectXLPipeline", "VchitectConfig", "VchitectPABConfig", + "AllegroPipeline", "AllegroConfig", "AllegroPABConfig" ] # fmt: skip diff --git a/videosys/models/autoencoders/autoencoder_kl_allegro.py b/videosys/models/autoencoders/autoencoder_kl_allegro.py new file mode 100644 index 00000000..a70baa72 --- /dev/null +++ b/videosys/models/autoencoders/autoencoder_kl_allegro.py @@ -0,0 +1,978 @@ +import math +from dataclasses import dataclass +import os +from typing import Dict, Optional, Tuple, Union +from einops import rearrange + +import torch +import torch.nn as nn +import torch.nn.functional as F + +from diffusers.configuration_utils import ConfigMixin, register_to_config +from diffusers.models.modeling_utils import ModelMixin +from diffusers.models.modeling_outputs import AutoencoderKLOutput +from diffusers.models.autoencoders.autoencoder_kl import AutoencoderKL +from diffusers.models.autoencoders.vae import DecoderOutput, DiagonalGaussianDistribution +from diffusers.models.attention_processor import Attention +from diffusers.models.resnet import ResnetBlock2D +from diffusers.models.upsampling import Upsample2D +from diffusers.models.downsampling import Downsample2D +from diffusers.models.attention_processor import SpatialNorm + + +class TemporalConvBlock(nn.Module): + """ + Temporal convolutional layer that can be used for video (sequence of images) input Code mostly copied from: + https://github.com/modelscope/modelscope/blob/1509fdb973e5871f37148a4b5e5964cafd43e64d/modelscope/models/multi_modal/video_synthesis/unet_sd.py#L1016 + """ + + def __init__(self, in_dim, out_dim=None, dropout=0.0, up_sample=False, down_sample=False, spa_stride=1): + super().__init__() + out_dim = out_dim or in_dim + self.in_dim = in_dim + self.out_dim = out_dim + spa_pad = int((spa_stride-1)*0.5) + temp_pad = 0 + self.temp_pad = temp_pad + + if down_sample: + self.conv1 = nn.Sequential( + nn.GroupNorm(32, in_dim), + nn.SiLU(), + nn.Conv3d(in_dim, out_dim, (2, spa_stride, spa_stride), stride=(2,1,1), padding=(0, spa_pad, spa_pad)) + ) + elif up_sample: + self.conv1 = nn.Sequential( + nn.GroupNorm(32, in_dim), + nn.SiLU(), + nn.Conv3d(in_dim, out_dim*2, (1, spa_stride, spa_stride), padding=(0, spa_pad, spa_pad)) + ) + else: + self.conv1 = nn.Sequential( + nn.GroupNorm(32, in_dim), + nn.SiLU(), + nn.Conv3d(in_dim, out_dim, (3, spa_stride, spa_stride), padding=(temp_pad, spa_pad, spa_pad)) + ) + self.conv2 = nn.Sequential( + nn.GroupNorm(32, out_dim), + nn.SiLU(), + nn.Dropout(dropout), + nn.Conv3d(out_dim, in_dim, (3, spa_stride, spa_stride), padding=(temp_pad, spa_pad, spa_pad)), + ) + self.conv3 = nn.Sequential( + nn.GroupNorm(32, out_dim), + nn.SiLU(), + nn.Dropout(dropout), + nn.Conv3d(out_dim, in_dim, (3, spa_stride, spa_stride), padding=(temp_pad, spa_pad, spa_pad)), + ) + self.conv4 = nn.Sequential( + nn.GroupNorm(32, out_dim), + nn.SiLU(), + nn.Conv3d(out_dim, in_dim, (3, spa_stride, spa_stride), padding=(temp_pad, spa_pad, spa_pad)), + ) + + # zero out the last layer params,so the conv block is identity + nn.init.zeros_(self.conv4[-1].weight) + nn.init.zeros_(self.conv4[-1].bias) + + self.down_sample = down_sample + self.up_sample = up_sample + + + def forward(self, hidden_states): + identity = hidden_states + + if self.down_sample: + identity = identity[:,:,::2] + elif self.up_sample: + hidden_states_new = torch.cat((hidden_states,hidden_states),dim=2) + hidden_states_new[:, :, 0::2] = hidden_states + hidden_states_new[:, :, 1::2] = hidden_states + identity = hidden_states_new + del hidden_states_new + + if self.down_sample or self.up_sample: + hidden_states = self.conv1(hidden_states) + else: + hidden_states = torch.cat((hidden_states[:,:,0:1], hidden_states), dim=2) + hidden_states = torch.cat((hidden_states,hidden_states[:,:,-1:]), dim=2) + hidden_states = self.conv1(hidden_states) + + + if self.up_sample: + hidden_states = rearrange(hidden_states, 'b (d c) f h w -> b c (f d) h w', d=2) + + hidden_states = torch.cat((hidden_states[:,:,0:1], hidden_states), dim=2) + hidden_states = torch.cat((hidden_states,hidden_states[:,:,-1:]), dim=2) + hidden_states = self.conv2(hidden_states) + hidden_states = torch.cat((hidden_states[:,:,0:1], hidden_states), dim=2) + hidden_states = torch.cat((hidden_states,hidden_states[:,:,-1:]), dim=2) + hidden_states = self.conv3(hidden_states) + hidden_states = torch.cat((hidden_states[:,:,0:1], hidden_states), dim=2) + hidden_states = torch.cat((hidden_states,hidden_states[:,:,-1:]), dim=2) + hidden_states = self.conv4(hidden_states) + + hidden_states = identity + hidden_states + + return hidden_states + + +class DownEncoderBlock3D(nn.Module): + def __init__( + self, + in_channels: int, + out_channels: int, + dropout: float = 0.0, + num_layers: int = 1, + resnet_eps: float = 1e-6, + resnet_time_scale_shift: str = "default", + resnet_act_fn: str = "swish", + resnet_groups: int = 32, + resnet_pre_norm: bool = True, + output_scale_factor=1.0, + add_downsample=True, + add_temp_downsample=False, + downsample_padding=1, + ): + super().__init__() + resnets = [] + temp_convs = [] + + for i in range(num_layers): + in_channels = in_channels if i == 0 else out_channels + resnets.append( + ResnetBlock2D( + in_channels=in_channels, + out_channels=out_channels, + temb_channels=None, + eps=resnet_eps, + groups=resnet_groups, + dropout=dropout, + time_embedding_norm=resnet_time_scale_shift, + non_linearity=resnet_act_fn, + output_scale_factor=output_scale_factor, + pre_norm=resnet_pre_norm, + ) + ) + temp_convs.append( + TemporalConvBlock( + out_channels, + out_channels, + dropout=0.1, + ) + ) + + self.resnets = nn.ModuleList(resnets) + self.temp_convs = nn.ModuleList(temp_convs) + + if add_temp_downsample: + self.temp_convs_down = TemporalConvBlock( + out_channels, + out_channels, + dropout=0.1, + down_sample=True, + spa_stride=3 + ) + self.add_temp_downsample = add_temp_downsample + + if add_downsample: + self.downsamplers = nn.ModuleList( + [ + Downsample2D( + out_channels, use_conv=True, out_channels=out_channels, padding=downsample_padding, name="op" + ) + ] + ) + else: + self.downsamplers = None + + def _set_partial_grad(self): + for temp_conv in self.temp_convs: + temp_conv.requires_grad_(True) + if self.downsamplers: + for down_layer in self.downsamplers: + down_layer.requires_grad_(True) + + def forward(self, hidden_states): + bz = hidden_states.shape[0] + + for resnet, temp_conv in zip(self.resnets, self.temp_convs): + hidden_states = rearrange(hidden_states, 'b c n h w -> (b n) c h w') + hidden_states = resnet(hidden_states, temb=None) + hidden_states = rearrange(hidden_states, '(b n) c h w -> b c n h w', b=bz) + hidden_states = temp_conv(hidden_states) + if self.add_temp_downsample: + hidden_states = self.temp_convs_down(hidden_states) + + if self.downsamplers is not None: + hidden_states = rearrange(hidden_states, 'b c n h w -> (b n) c h w') + for upsampler in self.downsamplers: + hidden_states = upsampler(hidden_states) + hidden_states = rearrange(hidden_states, '(b n) c h w -> b c n h w', b=bz) + return hidden_states + + +class UpDecoderBlock3D(nn.Module): + def __init__( + self, + in_channels: int, + out_channels: int, + dropout: float = 0.0, + num_layers: int = 1, + resnet_eps: float = 1e-6, + resnet_time_scale_shift: str = "default", # default, spatial + resnet_act_fn: str = "swish", + resnet_groups: int = 32, + resnet_pre_norm: bool = True, + output_scale_factor=1.0, + add_upsample=True, + add_temp_upsample=False, + temb_channels=None, + ): + super().__init__() + self.add_upsample = add_upsample + + resnets = [] + temp_convs = [] + + for i in range(num_layers): + input_channels = in_channels if i == 0 else out_channels + + resnets.append( + ResnetBlock2D( + in_channels=input_channels, + out_channels=out_channels, + temb_channels=temb_channels, + eps=resnet_eps, + groups=resnet_groups, + dropout=dropout, + time_embedding_norm=resnet_time_scale_shift, + non_linearity=resnet_act_fn, + output_scale_factor=output_scale_factor, + pre_norm=resnet_pre_norm, + ) + ) + temp_convs.append( + TemporalConvBlock( + out_channels, + out_channels, + dropout=0.1, + ) + ) + + self.resnets = nn.ModuleList(resnets) + self.temp_convs = nn.ModuleList(temp_convs) + + self.add_temp_upsample = add_temp_upsample + if add_temp_upsample: + self.temp_conv_up = TemporalConvBlock( + out_channels, + out_channels, + dropout=0.1, + up_sample=True, + spa_stride=3 + ) + + + if self.add_upsample: + # self.upsamplers = nn.ModuleList([PSUpsample2D(out_channels, use_conv=True, use_pixel_shuffle=True, out_channels=out_channels)]) + self.upsamplers = nn.ModuleList([Upsample2D(out_channels, use_conv=True, out_channels=out_channels)]) + else: + self.upsamplers = None + + def _set_partial_grad(self): + for temp_conv in self.temp_convs: + temp_conv.requires_grad_(True) + if self.add_upsample: + self.upsamplers.requires_grad_(True) + + def forward(self, hidden_states): + bz = hidden_states.shape[0] + + for resnet, temp_conv in zip(self.resnets, self.temp_convs): + hidden_states = rearrange(hidden_states, 'b c n h w -> (b n) c h w') + hidden_states = resnet(hidden_states, temb=None) + hidden_states = rearrange(hidden_states, '(b n) c h w -> b c n h w', b=bz) + hidden_states = temp_conv(hidden_states) + if self.add_temp_upsample: + hidden_states = self.temp_conv_up(hidden_states) + + if self.upsamplers is not None: + hidden_states = rearrange(hidden_states, 'b c n h w -> (b n) c h w') + for upsampler in self.upsamplers: + hidden_states = upsampler(hidden_states) + hidden_states = rearrange(hidden_states, '(b n) c h w -> b c n h w', b=bz) + return hidden_states + + +class UNetMidBlock3DConv(nn.Module): + def __init__( + self, + in_channels: int, + temb_channels: int, + dropout: float = 0.0, + num_layers: int = 1, + resnet_eps: float = 1e-6, + resnet_time_scale_shift: str = "default", # default, spatial + resnet_act_fn: str = "swish", + resnet_groups: int = 32, + resnet_pre_norm: bool = True, + add_attention: bool = True, + attention_head_dim=1, + output_scale_factor=1.0, + ): + super().__init__() + resnet_groups = resnet_groups if resnet_groups is not None else min(in_channels // 4, 32) + self.add_attention = add_attention + + # there is always at least one resnet + resnets = [ + ResnetBlock2D( + in_channels=in_channels, + out_channels=in_channels, + temb_channels=temb_channels, + eps=resnet_eps, + groups=resnet_groups, + dropout=dropout, + time_embedding_norm=resnet_time_scale_shift, + non_linearity=resnet_act_fn, + output_scale_factor=output_scale_factor, + pre_norm=resnet_pre_norm, + ) + ] + temp_convs = [ + TemporalConvBlock( + in_channels, + in_channels, + dropout=0.1, + ) + ] + attentions = [] + + if attention_head_dim is None: + attention_head_dim = in_channels + + for _ in range(num_layers): + if self.add_attention: + attentions.append( + Attention( + in_channels, + heads=in_channels // attention_head_dim, + dim_head=attention_head_dim, + rescale_output_factor=output_scale_factor, + eps=resnet_eps, + norm_num_groups=resnet_groups if resnet_time_scale_shift == "default" else None, + spatial_norm_dim=temb_channels if resnet_time_scale_shift == "spatial" else None, + residual_connection=True, + bias=True, + upcast_softmax=True, + _from_deprecated_attn_block=True, + ) + ) + else: + attentions.append(None) + + resnets.append( + ResnetBlock2D( + in_channels=in_channels, + out_channels=in_channels, + temb_channels=temb_channels, + eps=resnet_eps, + groups=resnet_groups, + dropout=dropout, + time_embedding_norm=resnet_time_scale_shift, + non_linearity=resnet_act_fn, + output_scale_factor=output_scale_factor, + pre_norm=resnet_pre_norm, + ) + ) + + temp_convs.append( + TemporalConvBlock( + in_channels, + in_channels, + dropout=0.1, + ) + ) + + self.resnets = nn.ModuleList(resnets) + self.temp_convs = nn.ModuleList(temp_convs) + self.attentions = nn.ModuleList(attentions) + + def _set_partial_grad(self): + for temp_conv in self.temp_convs: + temp_conv.requires_grad_(True) + + def forward( + self, + hidden_states, + ): + bz = hidden_states.shape[0] + hidden_states = rearrange(hidden_states, 'b c n h w -> (b n) c h w') + + hidden_states = self.resnets[0](hidden_states, temb=None) + hidden_states = rearrange(hidden_states, '(b n) c h w -> b c n h w', b=bz) + hidden_states = self.temp_convs[0](hidden_states) + hidden_states = rearrange(hidden_states, 'b c n h w -> (b n) c h w') + + for attn, resnet, temp_conv in zip( + self.attentions, self.resnets[1:], self.temp_convs[1:] + ): + hidden_states = attn(hidden_states) + hidden_states = resnet(hidden_states, temb=None) + hidden_states = rearrange(hidden_states, '(b n) c h w -> b c n h w', b=bz) + hidden_states = temp_conv(hidden_states) + return hidden_states + + +class Encoder3D(nn.Module): + def __init__( + self, + in_channels=3, + out_channels=3, + num_blocks=4, + blocks_temp_li=[False, False, False, False], + block_out_channels=(64,), + layers_per_block=2, + norm_num_groups=32, + act_fn="silu", + double_z=True, + ): + super().__init__() + self.layers_per_block = layers_per_block + self.blocks_temp_li = blocks_temp_li + + self.conv_in = nn.Conv2d( + in_channels, + block_out_channels[0], + kernel_size=3, + stride=1, + padding=1, + ) + + self.temp_conv_in = nn.Conv3d( + block_out_channels[0], + block_out_channels[0], + (3,1,1), + padding = (1, 0, 0) + ) + + self.mid_block = None + self.down_blocks = nn.ModuleList([]) + + # down + output_channel = block_out_channels[0] + for i in range(num_blocks): + input_channel = output_channel + output_channel = block_out_channels[i] + is_final_block = i == len(block_out_channels) - 1 + + down_block = DownEncoderBlock3D( + num_layers=self.layers_per_block, + in_channels=input_channel, + out_channels=output_channel, + add_downsample=not is_final_block, + add_temp_downsample=blocks_temp_li[i], + resnet_eps=1e-6, + downsample_padding=0, + resnet_act_fn=act_fn, + resnet_groups=norm_num_groups, + ) + self.down_blocks.append(down_block) + + # mid + self.mid_block = UNetMidBlock3DConv( + in_channels=block_out_channels[-1], + resnet_eps=1e-6, + resnet_act_fn=act_fn, + output_scale_factor=1, + resnet_time_scale_shift="default", + attention_head_dim=block_out_channels[-1], + resnet_groups=norm_num_groups, + temb_channels=None, + ) + + # out + self.conv_norm_out = nn.GroupNorm(num_channels=block_out_channels[-1], num_groups=norm_num_groups, eps=1e-6) + self.conv_act = nn.SiLU() + + conv_out_channels = 2 * out_channels if double_z else out_channels + + self.temp_conv_out = nn.Conv3d(block_out_channels[-1], block_out_channels[-1], (3,1,1), padding = (1, 0, 0)) + + self.conv_out = nn.Conv2d(block_out_channels[-1], conv_out_channels, 3, padding=1) + + nn.init.zeros_(self.temp_conv_in.weight) + nn.init.zeros_(self.temp_conv_in.bias) + nn.init.zeros_(self.temp_conv_out.weight) + nn.init.zeros_(self.temp_conv_out.bias) + + self.gradient_checkpointing = False + + def forward(self, x): + ''' + x: [b, c, (tb f), h, w] + ''' + bz = x.shape[0] + sample = rearrange(x, 'b c n h w -> (b n) c h w') + sample = self.conv_in(sample) + + sample = rearrange(sample, '(b n) c h w -> b c n h w', b=bz) + temp_sample = sample + sample = self.temp_conv_in(sample) + sample = sample+temp_sample + # down + for b_id, down_block in enumerate(self.down_blocks): + sample = down_block(sample) + # middle + sample = self.mid_block(sample) + + # post-process + sample = rearrange(sample, 'b c n h w -> (b n) c h w') + sample = self.conv_norm_out(sample) + sample = self.conv_act(sample) + sample = rearrange(sample, '(b n) c h w -> b c n h w', b=bz) + + temp_sample = sample + sample = self.temp_conv_out(sample) + sample = sample+temp_sample + sample = rearrange(sample, 'b c n h w -> (b n) c h w') + + sample = self.conv_out(sample) + sample = rearrange(sample, '(b n) c h w -> b c n h w', b=bz) + return sample + +class Decoder3D(nn.Module): + def __init__( + self, + in_channels=4, + out_channels=3, + num_blocks=4, + blocks_temp_li=[False, False, False, False], + block_out_channels=(64,), + layers_per_block=2, + norm_num_groups=32, + act_fn="silu", + norm_type="group", # group, spatial + ): + super().__init__() + self.layers_per_block = layers_per_block + self.blocks_temp_li = blocks_temp_li + + self.conv_in = nn.Conv2d( + in_channels, + block_out_channels[-1], + kernel_size=3, + stride=1, + padding=1, + ) + + self.temp_conv_in = nn.Conv3d( + block_out_channels[-1], + block_out_channels[-1], + (3,1,1), + padding = (1, 0, 0) + ) + + self.mid_block = None + self.up_blocks = nn.ModuleList([]) + + temb_channels = in_channels if norm_type == "spatial" else None + + # mid + self.mid_block = UNetMidBlock3DConv( + in_channels=block_out_channels[-1], + resnet_eps=1e-6, + resnet_act_fn=act_fn, + output_scale_factor=1, + resnet_time_scale_shift="default" if norm_type == "group" else norm_type, + attention_head_dim=block_out_channels[-1], + resnet_groups=norm_num_groups, + temb_channels=temb_channels, + ) + + # up + reversed_block_out_channels = list(reversed(block_out_channels)) + output_channel = reversed_block_out_channels[0] + for i in range(num_blocks): + prev_output_channel = output_channel + output_channel = reversed_block_out_channels[i] + + is_final_block = i == len(block_out_channels) - 1 + + up_block = UpDecoderBlock3D( + num_layers=self.layers_per_block + 1, + in_channels=prev_output_channel, + out_channels=output_channel, + add_upsample=not is_final_block, + add_temp_upsample=blocks_temp_li[i], + resnet_eps=1e-6, + resnet_act_fn=act_fn, + resnet_groups=norm_num_groups, + temb_channels=temb_channels, + resnet_time_scale_shift=norm_type, + ) + self.up_blocks.append(up_block) + prev_output_channel = output_channel + + # out + if norm_type == "spatial": + self.conv_norm_out = SpatialNorm(block_out_channels[0], temb_channels) + else: + self.conv_norm_out = nn.GroupNorm(num_channels=block_out_channels[0], num_groups=norm_num_groups, eps=1e-6) + self.conv_act = nn.SiLU() + + self.temp_conv_out = nn.Conv3d(block_out_channels[0], block_out_channels[0], (3,1,1), padding = (1, 0, 0)) + self.conv_out = nn.Conv2d(block_out_channels[0], out_channels, 3, padding=1) + + nn.init.zeros_(self.temp_conv_in.weight) + nn.init.zeros_(self.temp_conv_in.bias) + nn.init.zeros_(self.temp_conv_out.weight) + nn.init.zeros_(self.temp_conv_out.bias) + + self.gradient_checkpointing = False + + def forward(self, z): + bz = z.shape[0] + sample = rearrange(z, 'b c n h w -> (b n) c h w') + sample = self.conv_in(sample) + + sample = rearrange(sample, '(b n) c h w -> b c n h w', b=bz) + temp_sample = sample + sample = self.temp_conv_in(sample) + sample = sample+temp_sample + + upscale_dtype = next(iter(self.up_blocks.parameters())).dtype + # middle + sample = self.mid_block(sample) + sample = sample.to(upscale_dtype) + + # up + for b_id, up_block in enumerate(self.up_blocks): + sample = up_block(sample) + + # post-process + sample = rearrange(sample, 'b c n h w -> (b n) c h w') + sample = self.conv_norm_out(sample) + sample = self.conv_act(sample) + + sample = rearrange(sample, '(b n) c h w -> b c n h w', b=bz) + temp_sample = sample + sample = self.temp_conv_out(sample) + sample = sample+temp_sample + sample = rearrange(sample, 'b c n h w -> (b n) c h w') + + sample = self.conv_out(sample) + sample = rearrange(sample, '(b n) c h w -> b c n h w', b=bz) + return sample + + + +class AllegroAutoencoderKL3D(ModelMixin, ConfigMixin): + r""" + A VAE model with KL loss for encoding images into latents and decoding latent representations into images. + + This model inherits from [`ModelMixin`]. Check the superclass documentation for it's generic methods implemented + for all models (such as downloading or saving). + + Parameters: + in_channels (int, *optional*, defaults to 3): Number of channels in the input image. + out_channels (int, *optional*, defaults to 3): Number of channels in the output. + down_block_types (`Tuple[str]`, *optional*, defaults to `("DownEncoderBlock2D",)`): + Tuple of downsample block types. + up_block_types (`Tuple[str]`, *optional*, defaults to `("UpDecoderBlock2D",)`): + Tuple of upsample block types. + block_out_channels (`Tuple[int]`, *optional*, defaults to `(64,)`): + Tuple of block output channels. + act_fn (`str`, *optional*, defaults to `"silu"`): The activation function to use. + latent_channels (`int`, *optional*, defaults to 4): Number of channels in the latent space. + sample_size (`int`, *optional*, defaults to `256`): Spatial Tiling Size. + tile_overlap (`tuple`, *optional*, defaults to `(120, 80`): Spatial overlapping size while tiling (height, width) + chunk_len (`int`, *optional*, defaults to `24`): Temporal Tiling Size. + t_over (`int`, *optional*, defaults to `8`): Temporal overlapping size while tiling + scaling_factor (`float`, *optional*, defaults to 0.13235): + The component-wise standard deviation of the trained latent space computed using the first batch of the + training set. This is used to scale the latent space to have unit variance when training the diffusion + model. The latents are scaled with the formula `z = z * scaling_factor` before being passed to the + diffusion model. When decoding, the latents are scaled back to the original scale with the formula: `z = 1 + / scaling_factor * z`. For more details, refer to sections 4.3.2 and D.1 of the [High-Resolution Image + Synthesis with Latent Diffusion Models](https://arxiv.org/abs/2112.10752) paper. + force_upcast (`bool`, *optional*, default to `True`): + If enabled it will force the VAE to run in float32 for high image resolution pipelines, such as SD-XL. VAE + can be fine-tuned / trained to a lower range without loosing too much precision in which case + `force_upcast` can be set to `False` - see: https://huggingface.co/madebyollin/sdxl-vae-fp16-fix + blocks_tempdown_li (`List`, *optional*, defaults to `[True, True, False, False]`): Each item indicates whether each TemporalBlock in the Encoder performs temporal downsampling. + blocks_tempup_li (`List`, *optional*, defaults to `[False, True, True, False]`): Each item indicates whether each TemporalBlock in the Decoder performs temporal upsampling. + load_mode (`str`, *optional*, defaults to `full`): Load mode for the model. Can be one of `full`, `encoder_only`, `decoder_only`. which corresponds to loading the full model state dicts, only the encoder state dicts, or only the decoder state dicts. + """ + + _supports_gradient_checkpointing = True + + @register_to_config + def __init__( + self, + in_channels: int = 3, + out_channels: int = 3, + down_block_num: int = 4, + up_block_num: int = 4, + block_out_channels: Tuple[int] = (128,256,512,512), + layers_per_block: int = 2, + act_fn: str = "silu", + latent_channels: int = 4, + norm_num_groups: int = 32, + sample_size: int = 320, + tile_overlap: tuple = (120, 80), + force_upcast: bool = True, + chunk_len: int = 24, + t_over: int = 8, + scale_factor: float = 0.13235, + blocks_tempdown_li=[True, True, False, False], + blocks_tempup_li=[False, True, True, False], + load_mode = 'full', + ): + super().__init__() + + self.blocks_tempdown_li = blocks_tempdown_li + self.blocks_tempup_li = blocks_tempup_li + # pass init params to Encoder + self.load_mode = load_mode + if load_mode in ['full', 'encoder_only']: + self.encoder = Encoder3D( + in_channels=in_channels, + out_channels=latent_channels, + num_blocks=down_block_num, + blocks_temp_li=blocks_tempdown_li, + block_out_channels=block_out_channels, + layers_per_block=layers_per_block, + act_fn=act_fn, + norm_num_groups=norm_num_groups, + double_z=True, + ) + self.quant_conv = nn.Conv2d(2 * latent_channels, 2 * latent_channels, 1) + + if load_mode in ['full', 'decoder_only']: + # pass init params to Decoder + self.decoder = Decoder3D( + in_channels=latent_channels, + out_channels=out_channels, + num_blocks=up_block_num, + blocks_temp_li=blocks_tempup_li, + block_out_channels=block_out_channels, + layers_per_block=layers_per_block, + norm_num_groups=norm_num_groups, + act_fn=act_fn, + ) + self.post_quant_conv = nn.Conv2d(latent_channels, latent_channels, 1) + + + # only relevant if vae tiling is enabled + sample_size = ( + sample_size[0] + if isinstance(sample_size, (list, tuple)) + else sample_size + ) + self.tile_overlap = tile_overlap + self.vae_scale_factor=[4, 8, 8] + self.scale_factor = scale_factor + self.sample_size = sample_size + self.chunk_len = chunk_len + self.t_over = t_over + + self.latent_chunk_len = self.chunk_len//4 + self.latent_t_over = self.t_over//4 + self.kernel = (self.chunk_len, self.sample_size, self.sample_size) #(24, 256, 256) + self.stride = (self.chunk_len - self.t_over, self.sample_size-self.tile_overlap[0], self.sample_size-self.tile_overlap[1]) # (16, 112, 192) + + + def encode(self, input_imgs: torch.Tensor, return_dict: bool = True, local_batch_size=1) -> Union[AutoencoderKLOutput, Tuple[DiagonalGaussianDistribution]]: + KERNEL = self.kernel + STRIDE = self.stride + LOCAL_BS = local_batch_size + OUT_C = 8 + + B, C, N, H, W = input_imgs.shape + + + out_n = math.floor((N - KERNEL[0]) / STRIDE[0]) + 1 + out_h = math.floor((H - KERNEL[1]) / STRIDE[1]) + 1 + out_w = math.floor((W - KERNEL[2]) / STRIDE[2]) + 1 + + ## cut video into overlapped small cubes and batch forward + num = 0 + + out_latent = torch.zeros((out_n*out_h*out_w, OUT_C, KERNEL[0]//4, KERNEL[1]//8, KERNEL[2]//8), device=input_imgs.device, dtype=input_imgs.dtype) + vae_batch_input = torch.zeros((LOCAL_BS, C, KERNEL[0], KERNEL[1], KERNEL[2]), device=input_imgs.device, dtype=input_imgs.dtype) + + for i in range(out_n): + for j in range(out_h): + for k in range(out_w): + n_start, n_end = i * STRIDE[0], i * STRIDE[0] + KERNEL[0] + h_start, h_end = j * STRIDE[1], j * STRIDE[1] + KERNEL[1] + w_start, w_end = k * STRIDE[2], k * STRIDE[2] + KERNEL[2] + video_cube = input_imgs[:, :, n_start:n_end, h_start:h_end, w_start:w_end] + vae_batch_input[num%LOCAL_BS] = video_cube + + if num%LOCAL_BS == LOCAL_BS-1 or num == out_n*out_h*out_w-1: + latent = self.encoder(vae_batch_input) + + if num == out_n*out_h*out_w-1 and num%LOCAL_BS != LOCAL_BS-1: + out_latent[num-num%LOCAL_BS:] = latent[:num%LOCAL_BS+1] + else: + out_latent[num-LOCAL_BS+1:num+1] = latent + vae_batch_input = torch.zeros((LOCAL_BS, C, KERNEL[0], KERNEL[1], KERNEL[2]), device=input_imgs.device, dtype=input_imgs.dtype) + num+=1 + + ## flatten the batched out latent to videos and supress the overlapped parts + B, C, N, H, W = input_imgs.shape + + out_video_cube = torch.zeros((B, OUT_C, N//4, H//8, W//8), device=input_imgs.device, dtype=input_imgs.dtype) + OUT_KERNEL = KERNEL[0]//4, KERNEL[1]//8, KERNEL[2]//8 + OUT_STRIDE = STRIDE[0]//4, STRIDE[1]//8, STRIDE[2]//8 + OVERLAP = OUT_KERNEL[0]-OUT_STRIDE[0], OUT_KERNEL[1]-OUT_STRIDE[1], OUT_KERNEL[2]-OUT_STRIDE[2] + + for i in range(out_n): + n_start, n_end = i * OUT_STRIDE[0], i * OUT_STRIDE[0] + OUT_KERNEL[0] + for j in range(out_h): + h_start, h_end = j * OUT_STRIDE[1], j * OUT_STRIDE[1] + OUT_KERNEL[1] + for k in range(out_w): + w_start, w_end = k * OUT_STRIDE[2], k * OUT_STRIDE[2] + OUT_KERNEL[2] + latent_mean_blend = prepare_for_blend((i, out_n, OVERLAP[0]), (j, out_h, OVERLAP[1]), (k, out_w, OVERLAP[2]), out_latent[i*out_h*out_w+j*out_w+k].unsqueeze(0)) + out_video_cube[:, :, n_start:n_end, h_start:h_end, w_start:w_end] += latent_mean_blend + + ## final conv + out_video_cube = rearrange(out_video_cube, 'b c n h w -> (b n) c h w') + out_video_cube = self.quant_conv(out_video_cube) + out_video_cube = rearrange(out_video_cube, '(b n) c h w -> b c n h w', b=B) + + posterior = DiagonalGaussianDistribution(out_video_cube) + + if not return_dict: + return (posterior,) + + return AutoencoderKLOutput(latent_dist=posterior) + + + def decode(self, input_latents: torch.Tensor, return_dict: bool = True, local_batch_size=1) -> Union[DecoderOutput, torch.Tensor]: + KERNEL = self.kernel + STRIDE = self.stride + + LOCAL_BS = local_batch_size + OUT_C = 3 + IN_KERNEL = KERNEL[0]//4, KERNEL[1]//8, KERNEL[2]//8 + IN_STRIDE = STRIDE[0]//4, STRIDE[1]//8, STRIDE[2]//8 + + B, C, N, H, W = input_latents.shape + + ## post quant conv (a mapping) + input_latents = rearrange(input_latents, 'b c n h w -> (b n) c h w') + input_latents = self.post_quant_conv(input_latents) + input_latents = rearrange(input_latents, '(b n) c h w -> b c n h w', b=B) + + ## out tensor shape + out_n = math.floor((N - IN_KERNEL[0]) / IN_STRIDE[0]) + 1 + out_h = math.floor((H - IN_KERNEL[1]) / IN_STRIDE[1]) + 1 + out_w = math.floor((W - IN_KERNEL[2]) / IN_STRIDE[2]) + 1 + + ## cut latent into overlapped small cubes and batch forward + num = 0 + decoded_cube = torch.zeros((out_n*out_h*out_w, OUT_C, KERNEL[0], KERNEL[1], KERNEL[2]), device=input_latents.device, dtype=input_latents.dtype) + vae_batch_input = torch.zeros((LOCAL_BS, C, IN_KERNEL[0], IN_KERNEL[1], IN_KERNEL[2]), device=input_latents.device, dtype=input_latents.dtype) + for i in range(out_n): + for j in range(out_h): + for k in range(out_w): + n_start, n_end = i * IN_STRIDE[0], i * IN_STRIDE[0] + IN_KERNEL[0] + h_start, h_end = j * IN_STRIDE[1], j * IN_STRIDE[1] + IN_KERNEL[1] + w_start, w_end = k * IN_STRIDE[2], k * IN_STRIDE[2] + IN_KERNEL[2] + latent_cube = input_latents[:, :, n_start:n_end, h_start:h_end, w_start:w_end] + vae_batch_input[num%LOCAL_BS] = latent_cube + if num%LOCAL_BS == LOCAL_BS-1 or num == out_n*out_h*out_w-1: + + latent = self.decoder(vae_batch_input) + + if num == out_n*out_h*out_w-1 and num%LOCAL_BS != LOCAL_BS-1: + decoded_cube[num-num%LOCAL_BS:] = latent[:num%LOCAL_BS+1] + else: + decoded_cube[num-LOCAL_BS+1:num+1] = latent + vae_batch_input = torch.zeros((LOCAL_BS, C, IN_KERNEL[0], IN_KERNEL[1], IN_KERNEL[2]), device=input_latents.device, dtype=input_latents.dtype) + num+=1 + B, C, N, H, W = input_latents.shape + + out_video = torch.zeros((B, OUT_C, N*4, H*8, W*8), device=input_latents.device, dtype=input_latents.dtype) + OVERLAP = KERNEL[0]-STRIDE[0], KERNEL[1]-STRIDE[1], KERNEL[2]-STRIDE[2] + for i in range(out_n): + n_start, n_end = i * STRIDE[0], i * STRIDE[0] + KERNEL[0] + for j in range(out_h): + h_start, h_end = j * STRIDE[1], j * STRIDE[1] + KERNEL[1] + for k in range(out_w): + w_start, w_end = k * STRIDE[2], k * STRIDE[2] + KERNEL[2] + out_video_blend = prepare_for_blend((i, out_n, OVERLAP[0]), (j, out_h, OVERLAP[1]), (k, out_w, OVERLAP[2]), decoded_cube[i*out_h*out_w+j*out_w+k].unsqueeze(0)) + out_video[:, :, n_start:n_end, h_start:h_end, w_start:w_end] += out_video_blend + + out_video = rearrange(out_video, 'b c t h w -> b t c h w').contiguous() + + decoded = out_video + if not return_dict: + return (decoded,) + + return DecoderOutput(sample=decoded) + + def forward( + self, + sample: torch.Tensor, + sample_posterior: bool = False, + return_dict: bool = True, + generator: Optional[torch.Generator] = None, + encoder_local_batch_size: int = 2, + decoder_local_batch_size: int = 2, + ) -> Union[DecoderOutput, torch.Tensor]: + r""" + Args: + sample (`torch.Tensor`): Input sample. + sample_posterior (`bool`, *optional*, defaults to `False`): + Whether to sample from the posterior. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`DecoderOutput`] instead of a plain tuple. + generator (`torch.Generator`, *optional*): + PyTorch random number generator. + encoder_local_batch_size (`int`, *optional*, defaults to 2): + Local batch size for the encoder's batch inference. + decoder_local_batch_size (`int`, *optional*, defaults to 2): + Local batch size for the decoder's batch inference. + """ + x = sample + posterior = self.encode(x, local_batch_size=encoder_local_batch_size).latent_dist + if sample_posterior: + z = posterior.sample(generator=generator) + else: + z = posterior.mode() + dec = self.decode(z, local_batch_size=decoder_local_batch_size).sample + + if not return_dict: + return (dec,) + + return DecoderOutput(sample=dec) + + @classmethod + def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.PathLike]], **kwargs): + kwargs["torch_type"] = torch.float32 + return super().from_pretrained(pretrained_model_name_or_path, **kwargs) + + +def prepare_for_blend(n_param, h_param, w_param, x): + n, n_max, overlap_n = n_param + h, h_max, overlap_h = h_param + w, w_max, overlap_w = w_param + if overlap_n > 0: + if n > 0: # the head overlap part decays from 0 to 1 + x[:,:,0:overlap_n,:,:] = x[:,:,0:overlap_n,:,:] * (torch.arange(0, overlap_n).float().to(x.device) / overlap_n).reshape(overlap_n,1,1) + if n < n_max-1: # the tail overlap part decays from 1 to 0 + x[:,:,-overlap_n:,:,:] = x[:,:,-overlap_n:,:,:] * (1 - torch.arange(0, overlap_n).float().to(x.device) / overlap_n).reshape(overlap_n,1,1) + if h > 0: + x[:,:,:,0:overlap_h,:] = x[:,:,:,0:overlap_h,:] * (torch.arange(0, overlap_h).float().to(x.device) / overlap_h).reshape(overlap_h,1) + if h < h_max-1: + x[:,:,:,-overlap_h:,:] = x[:,:,:,-overlap_h:,:] * (1 - torch.arange(0, overlap_h).float().to(x.device) / overlap_h).reshape(overlap_h,1) + if w > 0: + x[:,:,:,:,0:overlap_w] = x[:,:,:,:,0:overlap_w] * (torch.arange(0, overlap_w).float().to(x.device) / overlap_w) + if w < w_max-1: + x[:,:,:,:,-overlap_w:] = x[:,:,:,:,-overlap_w:] * (1 - torch.arange(0, overlap_w).float().to(x.device) / overlap_w) + return x diff --git a/videosys/models/transformers/allegro_transformer_3d.py b/videosys/models/transformers/allegro_transformer_3d.py new file mode 100644 index 00000000..78870ded --- /dev/null +++ b/videosys/models/transformers/allegro_transformer_3d.py @@ -0,0 +1,2013 @@ +# Adapted from Open-Sora-Plan + +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. +# -------------------------------------------------------- +# References: +# Open-Sora-Plan: https://github.com/PKU-YuanGroup/Open-Sora-Plan +# -------------------------------------------------------- + + +import json +import os +from dataclasses import dataclass +from functools import partial +from importlib import import_module +from typing import Any, Callable, Dict, Optional, Tuple + +import numpy as np +import torch +import collections +import torch.distributed +import torch.nn.functional as F +from torch.nn.attention import SDPBackend, sdpa_kernel +from diffusers.configuration_utils import ConfigMixin, register_to_config +from diffusers.models.activations import GEGLU, GELU, ApproximateGELU +from diffusers.models.attention_processor import ( + AttnAddedKVProcessor, + AttnAddedKVProcessor2_0, + AttnProcessor, + CustomDiffusionAttnProcessor, + CustomDiffusionAttnProcessor2_0, + CustomDiffusionXFormersAttnProcessor, + LoRAAttnAddedKVProcessor, + LoRAAttnProcessor, + LoRAAttnProcessor2_0, + LoRAXFormersAttnProcessor, + SlicedAttnAddedKVProcessor, + SlicedAttnProcessor, + SpatialNorm, + XFormersAttnAddedKVProcessor, + XFormersAttnProcessor, +) +from diffusers.models.embeddings import SinusoidalPositionalEmbedding, TimestepEmbedding, Timesteps +from diffusers.models.lora import LoRACompatibleConv, LoRACompatibleLinear +from diffusers.models.modeling_utils import ModelMixin +from diffusers.models.normalization import AdaLayerNorm, AdaLayerNormZero +from diffusers.utils import USE_PEFT_BACKEND, BaseOutput, deprecate, is_xformers_available +from diffusers.utils.torch_utils import maybe_allow_in_graph +from einops import rearrange, repeat +from torch import nn +from diffusers.models.embeddings import PixArtAlphaTextProjection + +from videosys.core.comm import all_to_all_with_pad, gather_sequence, get_pad, set_pad, split_sequence, all_to_all_comm +from videosys.core.pab_mgr import ( + enable_pab, + get_mlp_output, + if_broadcast_cross, + if_broadcast_mlp, + if_broadcast_spatial, + if_broadcast_temporal, + save_mlp_output, +) +from videosys.core.parallel_mgr import ParallelManager +from videosys.utils.logging import logger +from videosys.utils.utils import batch_func + +if is_xformers_available(): + import xformers + import xformers.ops +else: + xformers = None + + +SPATIAL_LIST = [] +TEMPROAL_LIST = [] +CROSS_LIST = [] + + +def to_2tuple(x): + if isinstance(x, collections.abc.Iterable): + return x + return (x, x) + +class CombinedTimestepSizeEmbeddings(nn.Module): + """ + For PixArt-Alpha. + + Reference: + https://github.com/PixArt-alpha/PixArt-alpha/blob/0f55e922376d8b797edd44d25d0e7464b260dcab/diffusion/model/nets/PixArtMS.py#L164C9-L168C29 + """ + + def __init__(self, embedding_dim, size_emb_dim, use_additional_conditions: bool = False): + super().__init__() + + self.outdim = size_emb_dim + self.time_proj = Timesteps(num_channels=256, flip_sin_to_cos=True, downscale_freq_shift=0) + self.timestep_embedder = TimestepEmbedding(in_channels=256, time_embed_dim=embedding_dim) + + self.use_additional_conditions = use_additional_conditions + if use_additional_conditions: + self.use_additional_conditions = True + self.additional_condition_proj = Timesteps(num_channels=256, flip_sin_to_cos=True, downscale_freq_shift=0) + self.resolution_embedder = TimestepEmbedding(in_channels=256, time_embed_dim=size_emb_dim) + self.aspect_ratio_embedder = TimestepEmbedding(in_channels=256, time_embed_dim=size_emb_dim) + + def apply_condition(self, size: torch.Tensor, batch_size: int, embedder: nn.Module): + if size.ndim == 1: + size = size[:, None] + + if size.shape[0] != batch_size: + size = size.repeat(batch_size // size.shape[0], 1) + if size.shape[0] != batch_size: + raise ValueError(f"`batch_size` should be {size.shape[0]} but found {batch_size}.") + + current_batch_size, dims = size.shape[0], size.shape[1] + size = size.reshape(-1) + size_freq = self.additional_condition_proj(size).to(size.dtype) + + size_emb = embedder(size_freq) + size_emb = size_emb.reshape(current_batch_size, dims * self.outdim) + return size_emb + + def forward(self, timestep, resolution, aspect_ratio, batch_size, hidden_dtype): + timesteps_proj = self.time_proj(timestep) + timesteps_emb = self.timestep_embedder(timesteps_proj.to(dtype=hidden_dtype)) # (N, D) + + if self.use_additional_conditions: + resolution = self.apply_condition(resolution, batch_size=batch_size, embedder=self.resolution_embedder) + aspect_ratio = self.apply_condition( + aspect_ratio, batch_size=batch_size, embedder=self.aspect_ratio_embedder + ) + conditioning = timesteps_emb + torch.cat([resolution, aspect_ratio], dim=1) + else: + conditioning = timesteps_emb + + return conditioning + + +class PositionGetter3D(object): + """ return positions of patches """ + + def __init__(self, ): + self.cache_positions = {} + + def __call__(self, b, t, h, w, device): + if not (b, t,h,w) in self.cache_positions: + x = torch.arange(w, device=device) + y = torch.arange(h, device=device) + z = torch.arange(t, device=device) + pos = torch.cartesian_prod(z, y, x) + + pos = pos.reshape(t * h * w, 3).transpose(0, 1).reshape(3, 1, -1).contiguous().expand(3, b, -1).clone() + poses = (pos[0].contiguous(), pos[1].contiguous(), pos[2].contiguous()) + max_poses = (int(poses[0].max()), int(poses[1].max()), int(poses[2].max())) + + self.cache_positions[b, t, h, w] = (poses, max_poses) + pos = self.cache_positions[b, t, h, w] + + return pos + + +class RoPE3D(torch.nn.Module): + + def __init__(self, freq=10000.0, F0=1.0, interpolation_scale_thw=(1, 1, 1)): + super().__init__() + self.base = freq + self.F0 = F0 + self.interpolation_scale_t = interpolation_scale_thw[0] + self.interpolation_scale_h = interpolation_scale_thw[1] + self.interpolation_scale_w = interpolation_scale_thw[2] + self.cache = {} + + def get_cos_sin(self, D, seq_len, device, dtype, interpolation_scale=1): + if (D, seq_len, device, dtype) not in self.cache: + inv_freq = 1.0 / (self.base ** (torch.arange(0, D, 2).float().to(device) / D)) + t = torch.arange(seq_len, device=device, dtype=inv_freq.dtype) / interpolation_scale + freqs = torch.einsum("i,j->ij", t, inv_freq).to(dtype) + freqs = torch.cat((freqs, freqs), dim=-1) + cos = freqs.cos() # (Seq, Dim) + sin = freqs.sin() + self.cache[D, seq_len, device, dtype] = (cos, sin) + return self.cache[D, seq_len, device, dtype] + + @staticmethod + def rotate_half(x): + x1, x2 = x[..., : x.shape[-1] // 2], x[..., x.shape[-1] // 2:] + return torch.cat((-x2, x1), dim=-1) + + def apply_rope1d(self, tokens, pos1d, cos, sin): + assert pos1d.ndim == 2 + + # for (batch_size x ntokens x nheads x dim) + cos = torch.nn.functional.embedding(pos1d, cos)[:, None, :, :] + sin = torch.nn.functional.embedding(pos1d, sin)[:, None, :, :] + return (tokens * cos) + (self.rotate_half(tokens) * sin) + + def forward(self, tokens, positions): + """ + input: + * tokens: batch_size x nheads x ntokens x dim + * positions: batch_size x ntokens x 3 (t, y and x position of each token) + output: + * tokens after appplying RoPE3D (batch_size x nheads x ntokens x x dim) + """ + assert tokens.size(3) % 3 == 0, "number of dimensions should be a multiple of three" + D = tokens.size(3) // 3 + poses, max_poses = positions + assert len(poses) == 3 and poses[0].ndim == 2# Batch, Seq, 3 + cos_t, sin_t = self.get_cos_sin(D, max_poses[0] + 1, tokens.device, tokens.dtype, self.interpolation_scale_t) + cos_y, sin_y = self.get_cos_sin(D, max_poses[1] + 1, tokens.device, tokens.dtype, self.interpolation_scale_h) + cos_x, sin_x = self.get_cos_sin(D, max_poses[2] + 1, tokens.device, tokens.dtype, self.interpolation_scale_w) + # split features into three along the feature dimension, and apply rope1d on each half + t, y, x = tokens.chunk(3, dim=-1) + t = self.apply_rope1d(t, poses[0], cos_t, sin_t) + y = self.apply_rope1d(y, poses[1], cos_y, sin_y) + x = self.apply_rope1d(x, poses[2], cos_x, sin_x) + tokens = torch.cat((t, y, x), dim=-1) + return tokens + +class PatchEmbed2D(nn.Module): + """2D Image to Patch Embedding""" + + def __init__( + self, + num_frames=1, + height=224, + width=224, + patch_size_t=1, + patch_size=16, + in_channels=3, + embed_dim=768, + layer_norm=False, + flatten=True, + bias=True, + interpolation_scale=(1, 1), + interpolation_scale_t=1, + use_abs_pos=False, + ): + super().__init__() + self.use_abs_pos = use_abs_pos + self.flatten = flatten + self.layer_norm = layer_norm + + self.proj = nn.Conv2d( + in_channels, embed_dim, kernel_size=(patch_size, patch_size), stride=(patch_size, patch_size), bias=bias + ) + if layer_norm: + self.norm = nn.LayerNorm(embed_dim, elementwise_affine=False, eps=1e-6) + else: + self.norm = None + + self.patch_size_t = patch_size_t + self.patch_size = patch_size + + def forward(self, latent): + b, _, _, _, _ = latent.shape + video_latent = None + + latent = rearrange(latent, 'b c t h w -> (b t) c h w') + + latent = self.proj(latent) + if self.flatten: + latent = latent.flatten(2).transpose(1, 2) # BT C H W -> BT N C + if self.layer_norm: + latent = self.norm(latent) + + latent = rearrange(latent, '(b t) n c -> b (t n) c', b=b) + video_latent = latent + + return video_latent + + +@maybe_allow_in_graph +class Attention(nn.Module): + r""" + A cross attention layer. + + Parameters: + query_dim (`int`): + The number of channels in the query. + cross_attention_dim (`int`, *optional*): + The number of channels in the encoder_hidden_states. If not given, defaults to `query_dim`. + heads (`int`, *optional*, defaults to 8): + The number of heads to use for multi-head attention. + dim_head (`int`, *optional*, defaults to 64): + The number of channels in each head. + dropout (`float`, *optional*, defaults to 0.0): + The dropout probability to use. + bias (`bool`, *optional*, defaults to False): + Set to `True` for the query, key, and value linear layers to contain a bias parameter. + upcast_attention (`bool`, *optional*, defaults to False): + Set to `True` to upcast the attention computation to `float32`. + upcast_softmax (`bool`, *optional*, defaults to False): + Set to `True` to upcast the softmax computation to `float32`. + cross_attention_norm (`str`, *optional*, defaults to `None`): + The type of normalization to use for the cross attention. Can be `None`, `layer_norm`, or `group_norm`. + cross_attention_norm_num_groups (`int`, *optional*, defaults to 32): + The number of groups to use for the group norm in the cross attention. + added_kv_proj_dim (`int`, *optional*, defaults to `None`): + The number of channels to use for the added key and value projections. If `None`, no projection is used. + norm_num_groups (`int`, *optional*, defaults to `None`): + The number of groups to use for the group norm in the attention. + spatial_norm_dim (`int`, *optional*, defaults to `None`): + The number of channels to use for the spatial normalization. + out_bias (`bool`, *optional*, defaults to `True`): + Set to `True` to use a bias in the output linear layer. + scale_qk (`bool`, *optional*, defaults to `True`): + Set to `True` to scale the query and key by `1 / sqrt(dim_head)`. + only_cross_attention (`bool`, *optional*, defaults to `False`): + Set to `True` to only use cross attention and not added_kv_proj_dim. Can only be set to `True` if + `added_kv_proj_dim` is not `None`. + eps (`float`, *optional*, defaults to 1e-5): + An additional value added to the denominator in group normalization that is used for numerical stability. + rescale_output_factor (`float`, *optional*, defaults to 1.0): + A factor to rescale the output by dividing it with this value. + residual_connection (`bool`, *optional*, defaults to `False`): + Set to `True` to add the residual connection to the output. + _from_deprecated_attn_block (`bool`, *optional*, defaults to `False`): + Set to `True` if the attention block is loaded from a deprecated state dict. + processor (`AttnProcessor`, *optional*, defaults to `None`): + The attention processor to use. If `None`, defaults to `AttnProcessor2_0` if `torch 2.x` is used and + `AttnProcessor` otherwise. + """ + + def __init__( + self, + query_dim: int, + cross_attention_dim: Optional[int] = None, + heads: int = 8, + dim_head: int = 64, + dropout: float = 0.0, + bias: bool = False, + upcast_attention: bool = False, + upcast_softmax: bool = False, + cross_attention_norm: Optional[str] = None, + cross_attention_norm_num_groups: int = 32, + added_kv_proj_dim: Optional[int] = None, + norm_num_groups: Optional[int] = None, + spatial_norm_dim: Optional[int] = None, + out_bias: bool = True, + scale_qk: bool = True, + only_cross_attention: bool = False, + eps: float = 1e-5, + rescale_output_factor: float = 1.0, + residual_connection: bool = False, + _from_deprecated_attn_block: bool = False, + processor: Optional["AttnProcessor"] = None, + attention_mode: str = "xformers", + use_rope: bool = False, + interpolation_scale_thw=None, + ): + super().__init__() + self.inner_dim = dim_head * heads + self.cross_attention_dim = cross_attention_dim if cross_attention_dim is not None else query_dim + self.upcast_attention = upcast_attention + self.upcast_softmax = upcast_softmax + self.rescale_output_factor = rescale_output_factor + self.residual_connection = residual_connection + self.dropout = dropout + self.use_rope = use_rope + + # we make use of this private variable to know whether this class is loaded + # with an deprecated state dict so that we can convert it on the fly + self._from_deprecated_attn_block = _from_deprecated_attn_block + + self.scale_qk = scale_qk + self.scale = dim_head**-0.5 if self.scale_qk else 1.0 + + self.heads = heads + # for slice_size > 0 the attention score computation + # is split across the batch axis to save memory + # You can set slice_size with `set_attention_slice` + self.sliceable_head_dim = heads + + self.added_kv_proj_dim = added_kv_proj_dim + self.only_cross_attention = only_cross_attention + + if self.added_kv_proj_dim is None and self.only_cross_attention: + raise ValueError( + "`only_cross_attention` can only be set to True if `added_kv_proj_dim` is not None. Make sure to set either `only_cross_attention=False` or define `added_kv_proj_dim`." + ) + + if norm_num_groups is not None: + self.group_norm = nn.GroupNorm(num_channels=query_dim, num_groups=norm_num_groups, eps=eps, affine=True) + else: + self.group_norm = None + + if spatial_norm_dim is not None: + self.spatial_norm = SpatialNorm(f_channels=query_dim, zq_channels=spatial_norm_dim) + else: + self.spatial_norm = None + + if cross_attention_norm is None: + self.norm_cross = None + elif cross_attention_norm == "layer_norm": + self.norm_cross = nn.LayerNorm(self.cross_attention_dim) + elif cross_attention_norm == "group_norm": + if self.added_kv_proj_dim is not None: + # The given `encoder_hidden_states` are initially of shape + # (batch_size, seq_len, added_kv_proj_dim) before being projected + # to (batch_size, seq_len, cross_attention_dim). The norm is applied + # before the projection, so we need to use `added_kv_proj_dim` as + # the number of channels for the group norm. + norm_cross_num_channels = added_kv_proj_dim + else: + norm_cross_num_channels = self.cross_attention_dim + + self.norm_cross = nn.GroupNorm( + num_channels=norm_cross_num_channels, num_groups=cross_attention_norm_num_groups, eps=1e-5, affine=True + ) + else: + raise ValueError( + f"unknown cross_attention_norm: {cross_attention_norm}. Should be None, 'layer_norm' or 'group_norm'" + ) + + if USE_PEFT_BACKEND: + linear_cls = nn.Linear + else: + linear_cls = LoRACompatibleLinear + + + self.to_q = linear_cls(query_dim, self.inner_dim, bias=bias) + + if not self.only_cross_attention: + # only relevant for the `AddedKVProcessor` classes + self.to_k = linear_cls(self.cross_attention_dim, self.inner_dim, bias=bias) + self.to_v = linear_cls(self.cross_attention_dim, self.inner_dim, bias=bias) + else: + self.to_k = None + self.to_v = None + + if self.added_kv_proj_dim is not None: + self.add_k_proj = linear_cls(added_kv_proj_dim, self.inner_dim) + self.add_v_proj = linear_cls(added_kv_proj_dim, self.inner_dim) + + self.to_out = nn.ModuleList([]) + self.to_out.append(linear_cls(self.inner_dim, query_dim, bias=out_bias)) + self.to_out.append(nn.Dropout(dropout)) + + # set attention processor + # We use the AttnProcessor2_0 by default when torch 2.x is used which uses + # torch.nn.functional.scaled_dot_product_attention for native Flash/memory_efficient_attention + # but only if it has the default `scale` argument. TODO remove scale_qk check when we move to torch 2.1 + if processor is None: + processor = ( + AttnProcessor2_0( + attention_mode, + use_rope, + interpolation_scale_thw=interpolation_scale_thw, + ) + if hasattr(F, "scaled_dot_product_attention") and self.scale_qk + else AttnProcessor() + ) + self.set_processor(processor) + + def set_use_memory_efficient_attention_xformers( + self, use_memory_efficient_attention_xformers: bool, attention_op: Optional[Callable] = None + ) -> None: + r""" + Set whether to use memory efficient attention from `xformers` or not. + + Args: + use_memory_efficient_attention_xformers (`bool`): + Whether to use memory efficient attention from `xformers` or not. + attention_op (`Callable`, *optional*): + The attention operation to use. Defaults to `None` which uses the default attention operation from + `xformers`. + """ + is_lora = hasattr(self, "processor") + is_custom_diffusion = hasattr(self, "processor") and isinstance( + self.processor, + (CustomDiffusionAttnProcessor, CustomDiffusionXFormersAttnProcessor, CustomDiffusionAttnProcessor2_0), + ) + is_added_kv_processor = hasattr(self, "processor") and isinstance( + self.processor, + ( + AttnAddedKVProcessor, + AttnAddedKVProcessor2_0, + SlicedAttnAddedKVProcessor, + XFormersAttnAddedKVProcessor, + LoRAAttnAddedKVProcessor, + ), + ) + + if use_memory_efficient_attention_xformers: + if is_added_kv_processor and (is_lora or is_custom_diffusion): + raise NotImplementedError( + f"Memory efficient attention is currently not supported for LoRA or custom diffusion for attention processor type {self.processor}" + ) + if not is_xformers_available(): + raise ModuleNotFoundError( + ( + "Refer to https://github.com/facebookresearch/xformers for more information on how to install" + " xformers" + ), + name="xformers", + ) + elif not torch.cuda.is_available(): + raise ValueError( + "torch.cuda.is_available() should be True but is False. xformers' memory efficient attention is" + " only available for GPU " + ) + else: + try: + # Make sure we can run the memory efficient attention + _ = xformers.ops.memory_efficient_attention( + torch.randn((1, 2, 40), device="cuda"), + torch.randn((1, 2, 40), device="cuda"), + torch.randn((1, 2, 40), device="cuda"), + ) + except Exception as e: + raise e + + if is_lora: + # TODO (sayakpaul): should we throw a warning if someone wants to use the xformers + # variant when using PT 2.0 now that we have LoRAAttnProcessor2_0? + processor = LoRAXFormersAttnProcessor( + hidden_size=self.processor.hidden_size, + cross_attention_dim=self.processor.cross_attention_dim, + rank=self.processor.rank, + attention_op=attention_op, + ) + processor.load_state_dict(self.processor.state_dict()) + processor.to(self.processor.to_q_lora.up.weight.device) + elif is_custom_diffusion: + processor = CustomDiffusionXFormersAttnProcessor( + train_kv=self.processor.train_kv, + train_q_out=self.processor.train_q_out, + hidden_size=self.processor.hidden_size, + cross_attention_dim=self.processor.cross_attention_dim, + attention_op=attention_op, + ) + processor.load_state_dict(self.processor.state_dict()) + if hasattr(self.processor, "to_k_custom_diffusion"): + processor.to(self.processor.to_k_custom_diffusion.weight.device) + elif is_added_kv_processor: + # TODO(Patrick, Suraj, William) - currently xformers doesn't work for UnCLIP + # which uses this type of cross attention ONLY because the attention mask of format + # [0, ..., -10.000, ..., 0, ...,] is not supported + # throw warning + logger.info( + "Memory efficient attention with `xformers` might currently not work correctly if an attention mask is required for the attention operation." + ) + processor = XFormersAttnAddedKVProcessor(attention_op=attention_op) + else: + processor = XFormersAttnProcessor(attention_op=attention_op) + else: + if is_lora: + attn_processor_class = ( + LoRAAttnProcessor2_0 if hasattr(F, "scaled_dot_product_attention") else LoRAAttnProcessor + ) + processor = attn_processor_class( + hidden_size=self.processor.hidden_size, + cross_attention_dim=self.processor.cross_attention_dim, + rank=self.processor.rank, + ) + processor.load_state_dict(self.processor.state_dict()) + processor.to(self.processor.to_q_lora.up.weight.device) + elif is_custom_diffusion: + attn_processor_class = ( + CustomDiffusionAttnProcessor2_0 + if hasattr(F, "scaled_dot_product_attention") + else CustomDiffusionAttnProcessor + ) + processor = attn_processor_class( + train_kv=self.processor.train_kv, + train_q_out=self.processor.train_q_out, + hidden_size=self.processor.hidden_size, + cross_attention_dim=self.processor.cross_attention_dim, + ) + processor.load_state_dict(self.processor.state_dict()) + if hasattr(self.processor, "to_k_custom_diffusion"): + processor.to(self.processor.to_k_custom_diffusion.weight.device) + else: + # set attention processor + # We use the AttnProcessor2_0 by default when torch 2.x is used which uses + # torch.nn.functional.scaled_dot_product_attention for native Flash/memory_efficient_attention + # but only if it has the default `scale` argument. TODO remove scale_qk check when we move to torch 2.1 + processor = ( + AttnProcessor2_0() + if hasattr(F, "scaled_dot_product_attention") and self.scale_qk + else AttnProcessor() + ) + + self.set_processor(processor) + + def set_attention_slice(self, slice_size: int) -> None: + r""" + Set the slice size for attention computation. + + Args: + slice_size (`int`): + The slice size for attention computation. + """ + if slice_size is not None and slice_size > self.sliceable_head_dim: + raise ValueError(f"slice_size {slice_size} has to be smaller or equal to {self.sliceable_head_dim}.") + + if slice_size is not None and self.added_kv_proj_dim is not None: + processor = SlicedAttnAddedKVProcessor(slice_size) + elif slice_size is not None: + processor = SlicedAttnProcessor(slice_size) + elif self.added_kv_proj_dim is not None: + processor = AttnAddedKVProcessor() + else: + # set attention processor + # We use the AttnProcessor2_0 by default when torch 2.x is used which uses + # torch.nn.functional.scaled_dot_product_attention for native Flash/memory_efficient_attention + # but only if it has the default `scale` argument. TODO remove scale_qk check when we move to torch 2.1 + processor = ( + AttnProcessor2_0() if hasattr(F, "scaled_dot_product_attention") and self.scale_qk else AttnProcessor() + ) + + self.set_processor(processor) + + def set_processor(self, processor: "AttnProcessor", _remove_lora: bool = False) -> None: + r""" + Set the attention processor to use. + + Args: + processor (`AttnProcessor`): + The attention processor to use. + _remove_lora (`bool`, *optional*, defaults to `False`): + Set to `True` to remove LoRA layers from the model. + """ + if not USE_PEFT_BACKEND and hasattr(self, "processor") and _remove_lora and self.to_q.lora_layer is not None: + deprecate( + "set_processor to offload LoRA", + "0.26.0", + "In detail, removing LoRA layers via calling `set_default_attn_processor` is deprecated. Please make sure to call `pipe.unload_lora_weights()` instead.", + ) + # TODO(Patrick, Sayak) - this can be deprecated once PEFT LoRA integration is complete + # We need to remove all LoRA layers + # Don't forget to remove ALL `_remove_lora` from the codebase + for module in self.modules(): + if hasattr(module, "set_lora_layer"): + module.set_lora_layer(None) + + # if current processor is in `self._modules` and if passed `processor` is not, we need to + # pop `processor` from `self._modules` + if ( + hasattr(self, "processor") + and isinstance(self.processor, torch.nn.Module) + and not isinstance(processor, torch.nn.Module) + ): + logger.info(f"You are removing possibly trained weights of {self.processor} with {processor}") + self._modules.pop("processor") + + self.processor = processor + + def get_processor(self, return_deprecated_lora: bool = False): + r""" + Get the attention processor in use. + + Args: + return_deprecated_lora (`bool`, *optional*, defaults to `False`): + Set to `True` to return the deprecated LoRA attention processor. + + Returns: + "AttentionProcessor": The attention processor in use. + """ + if not return_deprecated_lora: + return self.processor + + # TODO(Sayak, Patrick). The rest of the function is needed to ensure backwards compatible + # serialization format for LoRA Attention Processors. It should be deleted once the integration + # with PEFT is completed. + is_lora_activated = { + name: module.lora_layer is not None + for name, module in self.named_modules() + if hasattr(module, "lora_layer") + } + + # 1. if no layer has a LoRA activated we can return the processor as usual + if not any(is_lora_activated.values()): + return self.processor + + # If doesn't apply LoRA do `add_k_proj` or `add_v_proj` + is_lora_activated.pop("add_k_proj", None) + is_lora_activated.pop("add_v_proj", None) + # 2. else it is not posssible that only some layers have LoRA activated + if not all(is_lora_activated.values()): + raise ValueError( + f"Make sure that either all layers or no layers have LoRA activated, but have {is_lora_activated}" + ) + + # 3. And we need to merge the current LoRA layers into the corresponding LoRA attention processor + non_lora_processor_cls_name = self.processor.__class__.__name__ + lora_processor_cls = getattr(import_module(__name__), "LoRA" + non_lora_processor_cls_name) + + hidden_size = self.inner_dim + + # now create a LoRA attention processor from the LoRA layers + if lora_processor_cls in [LoRAAttnProcessor, LoRAAttnProcessor2_0, LoRAXFormersAttnProcessor]: + kwargs = { + "cross_attention_dim": self.cross_attention_dim, + "rank": self.to_q.lora_layer.rank, + "network_alpha": self.to_q.lora_layer.network_alpha, + "q_rank": self.to_q.lora_layer.rank, + "q_hidden_size": self.to_q.lora_layer.out_features, + "k_rank": self.to_k.lora_layer.rank, + "k_hidden_size": self.to_k.lora_layer.out_features, + "v_rank": self.to_v.lora_layer.rank, + "v_hidden_size": self.to_v.lora_layer.out_features, + "out_rank": self.to_out[0].lora_layer.rank, + "out_hidden_size": self.to_out[0].lora_layer.out_features, + } + + if hasattr(self.processor, "attention_op"): + kwargs["attention_op"] = self.processor.attention_op + + lora_processor = lora_processor_cls(hidden_size, **kwargs) + lora_processor.to_q_lora.load_state_dict(self.to_q.lora_layer.state_dict()) + lora_processor.to_k_lora.load_state_dict(self.to_k.lora_layer.state_dict()) + lora_processor.to_v_lora.load_state_dict(self.to_v.lora_layer.state_dict()) + lora_processor.to_out_lora.load_state_dict(self.to_out[0].lora_layer.state_dict()) + elif lora_processor_cls == LoRAAttnAddedKVProcessor: + lora_processor = lora_processor_cls( + hidden_size, + cross_attention_dim=self.add_k_proj.weight.shape[0], + rank=self.to_q.lora_layer.rank, + network_alpha=self.to_q.lora_layer.network_alpha, + ) + lora_processor.to_q_lora.load_state_dict(self.to_q.lora_layer.state_dict()) + lora_processor.to_k_lora.load_state_dict(self.to_k.lora_layer.state_dict()) + lora_processor.to_v_lora.load_state_dict(self.to_v.lora_layer.state_dict()) + lora_processor.to_out_lora.load_state_dict(self.to_out[0].lora_layer.state_dict()) + + # only save if used + if self.add_k_proj.lora_layer is not None: + lora_processor.add_k_proj_lora.load_state_dict(self.add_k_proj.lora_layer.state_dict()) + lora_processor.add_v_proj_lora.load_state_dict(self.add_v_proj.lora_layer.state_dict()) + else: + lora_processor.add_k_proj_lora = None + lora_processor.add_v_proj_lora = None + else: + raise ValueError(f"{lora_processor_cls} does not exist.") + + return lora_processor + + def forward( + self, + hidden_states: torch.FloatTensor, + encoder_hidden_states: Optional[torch.FloatTensor] = None, + attention_mask: Optional[torch.FloatTensor] = None, + **cross_attention_kwargs, + ) -> torch.Tensor: + r""" + The forward method of the `Attention` class. + + Args: + hidden_states (`torch.Tensor`): + The hidden states of the query. + encoder_hidden_states (`torch.Tensor`, *optional*): + The hidden states of the encoder. + attention_mask (`torch.Tensor`, *optional*): + The attention mask to use. If `None`, no mask is applied. + **cross_attention_kwargs: + Additional keyword arguments to pass along to the cross attention. + + Returns: + `torch.Tensor`: The output of the attention layer. + """ + # The `Attention` class can call different attention processors / attention functions + # here we simply pass along all tensors to the selected processor class + # For standard processors that are defined here, `**cross_attention_kwargs` is empty + return self.processor( + self, + hidden_states, + encoder_hidden_states=encoder_hidden_states, + attention_mask=attention_mask, + **cross_attention_kwargs, + ) + + def batch_to_head_dim(self, tensor: torch.Tensor) -> torch.Tensor: + r""" + Reshape the tensor from `[batch_size, seq_len, dim]` to `[batch_size // heads, seq_len, dim * heads]`. `heads` + is the number of heads initialized while constructing the `Attention` class. + + Args: + tensor (`torch.Tensor`): The tensor to reshape. + + Returns: + `torch.Tensor`: The reshaped tensor. + """ + head_size = self.heads + batch_size, seq_len, dim = tensor.shape + tensor = tensor.reshape(batch_size // head_size, head_size, seq_len, dim) + tensor = tensor.permute(0, 2, 1, 3).reshape(batch_size // head_size, seq_len, dim * head_size) + return tensor + + def head_to_batch_dim(self, tensor: torch.Tensor, out_dim: int = 3) -> torch.Tensor: + r""" + Reshape the tensor from `[batch_size, seq_len, dim]` to `[batch_size, seq_len, heads, dim // heads]` `heads` is + the number of heads initialized while constructing the `Attention` class. + + Args: + tensor (`torch.Tensor`): The tensor to reshape. + out_dim (`int`, *optional*, defaults to `3`): The output dimension of the tensor. If `3`, the tensor is + reshaped to `[batch_size * heads, seq_len, dim // heads]`. + + Returns: + `torch.Tensor`: The reshaped tensor. + """ + head_size = self.heads + batch_size, seq_len, dim = tensor.shape + tensor = tensor.reshape(batch_size, seq_len, head_size, dim // head_size) + tensor = tensor.permute(0, 2, 1, 3) + + if out_dim == 3: + tensor = tensor.reshape(batch_size * head_size, seq_len, dim // head_size) + + return tensor + + def get_attention_scores( + self, query: torch.Tensor, key: torch.Tensor, attention_mask: torch.Tensor = None + ) -> torch.Tensor: + r""" + Compute the attention scores. + + Args: + query (`torch.Tensor`): The query tensor. + key (`torch.Tensor`): The key tensor. + attention_mask (`torch.Tensor`, *optional*): The attention mask to use. If `None`, no mask is applied. + + Returns: + `torch.Tensor`: The attention probabilities/scores. + """ + dtype = query.dtype + if self.upcast_attention: + query = query.float() + key = key.float() + + if attention_mask is None: + baddbmm_input = torch.empty( + query.shape[0], query.shape[1], key.shape[1], dtype=query.dtype, device=query.device + ) + beta = 0 + else: + baddbmm_input = attention_mask + beta = 1 + + attention_scores = torch.baddbmm( + baddbmm_input, + query, + key.transpose(-1, -2), + beta=beta, + alpha=self.scale, + ) + del baddbmm_input + + if self.upcast_softmax: + attention_scores = attention_scores.float() + + attention_probs = attention_scores.softmax(dim=-1) + del attention_scores + + attention_probs = attention_probs.to(dtype) + + return attention_probs + + def prepare_attention_mask( + self, attention_mask: torch.Tensor, target_length: int, batch_size: int, out_dim: int = 3, head_size = None, + ) -> torch.Tensor: + r""" + Prepare the attention mask for the attention computation. + + Args: + attention_mask (`torch.Tensor`): + The attention mask to prepare. + target_length (`int`): + The target length of the attention mask. This is the length of the attention mask after padding. + batch_size (`int`): + The batch size, which is used to repeat the attention mask. + out_dim (`int`, *optional*, defaults to `3`): + The output dimension of the attention mask. Can be either `3` or `4`. + + Returns: + `torch.Tensor`: The prepared attention mask. + """ + head_size = head_size if head_size is not None else self.heads + if attention_mask is None: + return attention_mask + + current_length: int = attention_mask.shape[-1] + if current_length != target_length: + if attention_mask.device.type == "mps": + # HACK: MPS: Does not support padding by greater than dimension of input tensor. + # Instead, we can manually construct the padding tensor. + padding_shape = (attention_mask.shape[0], attention_mask.shape[1], target_length) + padding = torch.zeros(padding_shape, dtype=attention_mask.dtype, device=attention_mask.device) + attention_mask = torch.cat([attention_mask, padding], dim=2) + else: + # TODO: for pipelines such as stable-diffusion, padding cross-attn mask: + # we want to instead pad by (0, remaining_length), where remaining_length is: + # remaining_length: int = target_length - current_length + # TODO: re-enable tests/models/test_models_unet_2d_condition.py#test_model_xattn_padding + attention_mask = F.pad(attention_mask, (0, target_length), value=0.0) + + if out_dim == 3: + if attention_mask.shape[0] < batch_size * head_size: + attention_mask = attention_mask.repeat_interleave(head_size, dim=0) + elif out_dim == 4: + attention_mask = attention_mask.unsqueeze(1) + attention_mask = attention_mask.repeat_interleave(head_size, dim=1) + + return attention_mask + + def norm_encoder_hidden_states(self, encoder_hidden_states: torch.Tensor) -> torch.Tensor: + r""" + Normalize the encoder hidden states. Requires `self.norm_cross` to be specified when constructing the + `Attention` class. + + Args: + encoder_hidden_states (`torch.Tensor`): Hidden states of the encoder. + + Returns: + `torch.Tensor`: The normalized encoder hidden states. + """ + assert self.norm_cross is not None, "self.norm_cross must be defined to call self.norm_encoder_hidden_states" + + if isinstance(self.norm_cross, nn.LayerNorm): + encoder_hidden_states = self.norm_cross(encoder_hidden_states) + elif isinstance(self.norm_cross, nn.GroupNorm): + # Group norm norms along the channels dimension and expects + # input to be in the shape of (N, C, *). In this case, we want + # to norm along the hidden dimension, so we need to move + # (batch_size, sequence_length, hidden_size) -> + # (batch_size, hidden_size, sequence_length) + encoder_hidden_states = encoder_hidden_states.transpose(1, 2) + encoder_hidden_states = self.norm_cross(encoder_hidden_states) + encoder_hidden_states = encoder_hidden_states.transpose(1, 2) + else: + assert False + + return encoder_hidden_states + + def _init_compress(self): + self.sr.bias.data.zero_() + self.norm = nn.LayerNorm(self.inner_dim) + + +class AttnProcessor2_0(nn.Module): + r""" + Processor for implementing scaled dot-product attention (enabled by default if you're using PyTorch 2.0). + """ + + def __init__(self, attention_mode="xformers", use_rope=False, interpolation_scale_thw=None): + super().__init__() + self.attention_mode = attention_mode + self.use_rope = use_rope + self.interpolation_scale_thw = interpolation_scale_thw + + self.parallel_manager: ParallelManager = None + + if self.use_rope: + self._init_rope(interpolation_scale_thw) + + if not hasattr(F, "scaled_dot_product_attention"): + raise ImportError("AttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.") + + def _init_rope(self, interpolation_scale_thw): + self.rope = RoPE3D(interpolation_scale_thw=interpolation_scale_thw) + self.position_getter = PositionGetter3D() + + def __call__( + self, + attn: Attention, + hidden_states: torch.FloatTensor, + encoder_hidden_states: Optional[torch.FloatTensor] = None, + attention_mask: Optional[torch.FloatTensor] = None, + temb: Optional[torch.FloatTensor] = None, + frame: int = 8, + height: int = 16, + width: int = 16, + ) -> torch.FloatTensor: + + residual = hidden_states + + if attn.spatial_norm is not None: + hidden_states = attn.spatial_norm(hidden_states, temb) + + input_ndim = hidden_states.ndim + + if input_ndim == 4: + batch_size, channel, height, width = hidden_states.shape + hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2) + + + batch_size, sequence_length, _ = ( + hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape + ) + + if attention_mask is not None and self.attention_mode == 'xformers': + attention_heads = attn.heads + if self.parallel_manager.sp_size > 1: + attention_heads = attn.heads // self.parallel_manager.sp_size + sequence_length = sequence_length * self.parallel_manager.sp_size + attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size, head_size=attention_heads) + attention_mask = attention_mask.view(batch_size, attention_heads, -1, attention_mask.shape[-1]) + else: + attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size) + # scaled_dot_product_attention expects attention_mask shape to be + # (batch, heads, source_length, target_length) + attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1]) + + if attn.group_norm is not None: + hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2) + + query = attn.to_q(hidden_states) + + if encoder_hidden_states is None: + encoder_hidden_states = hidden_states + elif attn.norm_cross: + encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states) + + key = attn.to_k(encoder_hidden_states) + value = attn.to_v(encoder_hidden_states) + + # b, num_heads, sequence//sp, head_dim + if self.parallel_manager.sp_size > 1: + # query = self.dynamic_switch(query, scatter_dim=2, gather_dim=1) + # key = self.dynamic_switch(key, scatter_dim=2, gather_dim=1) + # value = self.dynamic_switch(value, scatter_dim=2, gather_dim=1) + query, key, value = map( + lambda x: all_to_all_comm(x, self.parallel_manager.sp_group, scatter_dim=2, gather_dim=1), + [query, key, value], + ) + attn_heads = attn.heads // self.parallel_manager.sp_size + else: + attn_heads = attn.heads + + inner_dim = key.shape[-1] + head_dim = inner_dim // attn_heads + + query = query.view(batch_size, -1, attn_heads, head_dim).transpose(1, 2) + key = key.view(batch_size, -1, attn_heads, head_dim).transpose(1, 2) + value = value.view(batch_size, -1, attn_heads, head_dim).transpose(1, 2) + + + if self.use_rope: + # require the shape of (batch_size x nheads x ntokens x dim) + pos_thw = self.position_getter(batch_size, t=frame, h=height, w=width, device=query.device) + query = self.rope(query, pos_thw) + key = self.rope(key, pos_thw) + + # the output of sdp = (batch, num_heads, seq_len, head_dim) + # TODO: add support for attn.scale when we move to Torch 2.1 + if self.attention_mode == 'flash': + # assert attention_mask is None, 'flash-attn do not support attention_mask' + with sdpa_kernel(SDPBackend.FLASH_ATTENTION): + hidden_states = F.scaled_dot_product_attention( + query, key, value, dropout_p=0.0, is_causal=False + ) + elif self.attention_mode == 'xformers': + with sdpa_kernel(SDPBackend.EFFICIENT_ATTENTION): + hidden_states = F.scaled_dot_product_attention( + query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False + ) + + + hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn_heads * head_dim) + hidden_states = hidden_states.to(query.dtype) + + if self.parallel_manager.sp_size > 1: + hidden_states = all_to_all_comm(hidden_states, self.parallel_manager.sp_group, scatter_dim=1, gather_dim=2) + + # linear proj + hidden_states = attn.to_out[0](hidden_states) + # dropout + hidden_states = attn.to_out[1](hidden_states) + + if input_ndim == 4: + hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width) + + if attn.residual_connection: + hidden_states = hidden_states + residual + + hidden_states = hidden_states / attn.rescale_output_factor + + return hidden_states + + def dynamic_switch(self, x, scatter_dim, gather_dim): + + scatter_pad = 0 + gather_pad = 0 + + x = all_to_all_with_pad( + x, + self.parallel_manager.sp_group, + scatter_dim=scatter_dim, + gather_dim=gather_dim, + scatter_pad=scatter_pad, + gather_pad=gather_pad, + ) + return x + +class FeedForward(nn.Module): + r""" + A feed-forward layer. + + Parameters: + dim (`int`): The number of channels in the input. + dim_out (`int`, *optional*): The number of channels in the output. If not given, defaults to `dim`. + mult (`int`, *optional*, defaults to 4): The multiplier to use for the hidden dimension. + dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use. + activation_fn (`str`, *optional*, defaults to `"geglu"`): Activation function to be used in feed-forward. + final_dropout (`bool` *optional*, defaults to False): Apply a final dropout. + """ + + def __init__( + self, + dim: int, + dim_out: Optional[int] = None, + mult: int = 4, + dropout: float = 0.0, + activation_fn: str = "geglu", + final_dropout: bool = False, + ): + super().__init__() + inner_dim = int(dim * mult) + dim_out = dim_out if dim_out is not None else dim + linear_cls = LoRACompatibleLinear if not USE_PEFT_BACKEND else nn.Linear + + if activation_fn == "gelu": + act_fn = GELU(dim, inner_dim) + if activation_fn == "gelu-approximate": + act_fn = GELU(dim, inner_dim, approximate="tanh") + elif activation_fn == "geglu": + act_fn = GEGLU(dim, inner_dim) + elif activation_fn == "geglu-approximate": + act_fn = ApproximateGELU(dim, inner_dim) + + self.net = nn.ModuleList([]) + # project in + self.net.append(act_fn) + # project dropout + self.net.append(nn.Dropout(dropout)) + # project out + self.net.append(linear_cls(inner_dim, dim_out)) + # FF as used in Vision Transformer, MLP-Mixer, etc. have a final dropout + if final_dropout: + self.net.append(nn.Dropout(dropout)) + + def forward(self, hidden_states: torch.Tensor, scale: float = 1.0) -> torch.Tensor: + compatible_cls = (GEGLU,) if USE_PEFT_BACKEND else (GEGLU, LoRACompatibleLinear) + for module in self.net: + if isinstance(module, compatible_cls): + hidden_states = module(hidden_states, scale) + else: + hidden_states = module(hidden_states) + return hidden_states + + +@maybe_allow_in_graph +class BasicTransformerBlock(nn.Module): + r""" + A basic Transformer block. + + Parameters: + dim (`int`): The number of channels in the input and output. + num_attention_heads (`int`): The number of heads to use for multi-head attention. + attention_head_dim (`int`): The number of channels in each head. + dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use. + cross_attention_dim (`int`, *optional*): The size of the encoder_hidden_states vector for cross attention. + activation_fn (`str`, *optional*, defaults to `"geglu"`): Activation function to be used in feed-forward. + num_embeds_ada_norm (: + obj: `int`, *optional*): The number of diffusion steps used during training. See `Transformer2DModel`. + attention_bias (: + obj: `bool`, *optional*, defaults to `False`): Configure if the attentions should contain a bias parameter. + only_cross_attention (`bool`, *optional*): + Whether to use only cross-attention layers. In this case two cross attention layers are used. + double_self_attention (`bool`, *optional*): + Whether to use two self-attention layers. In this case no cross attention layers are used. + upcast_attention (`bool`, *optional*): + Whether to upcast the attention computation to float32. This is useful for mixed precision training. + norm_elementwise_affine (`bool`, *optional*, defaults to `True`): + Whether to use learnable elementwise affine parameters for normalization. + norm_type (`str`, *optional*, defaults to `"layer_norm"`): + The normalization layer to use. Can be `"layer_norm"`, `"ada_norm"` or `"ada_norm_zero"`. + final_dropout (`bool` *optional*, defaults to False): + Whether to apply a final dropout after the last feed-forward layer. + attention_type (`str`, *optional*, defaults to `"default"`): + The type of attention to use. Can be `"default"` or `"gated"` or `"gated-text-image"`. + positional_embeddings (`str`, *optional*, defaults to `None`): + The type of positional embeddings to apply to. + num_positional_embeddings (`int`, *optional*, defaults to `None`): + The maximum number of positional embeddings to apply. + """ + + def __init__( + self, + dim: int, + num_attention_heads: int, + attention_head_dim: int, + dropout=0.0, + cross_attention_dim: Optional[int] = None, + activation_fn: str = "geglu", + num_embeds_ada_norm: Optional[int] = None, + attention_bias: bool = False, + only_cross_attention: bool = False, + double_self_attention: bool = False, + upcast_attention: bool = False, + norm_elementwise_affine: bool = True, + norm_type: str = "layer_norm", # 'layer_norm', 'ada_norm', 'ada_norm_zero', 'ada_norm_single' + norm_eps: float = 1e-5, + final_dropout: bool = False, + attention_type: str = "default", + positional_embeddings: Optional[str] = None, + num_positional_embeddings: Optional[int] = None, + sa_attention_mode: str = "flash", + ca_attention_mode: str = "xformers", + use_rope: bool = False, + interpolation_scale_thw: Tuple[int] = (1, 1, 1), + block_idx: Optional[int] = None, + ): + super().__init__() + self.only_cross_attention = only_cross_attention + + self.use_ada_layer_norm_zero = (num_embeds_ada_norm is not None) and norm_type == "ada_norm_zero" + self.use_ada_layer_norm = (num_embeds_ada_norm is not None) and norm_type == "ada_norm" + self.use_ada_layer_norm_single = norm_type == "ada_norm_single" + self.use_layer_norm = norm_type == "layer_norm" + + if norm_type in ("ada_norm", "ada_norm_zero") and num_embeds_ada_norm is None: + raise ValueError( + f"`norm_type` is set to {norm_type}, but `num_embeds_ada_norm` is not defined. Please make sure to" + f" define `num_embeds_ada_norm` if setting `norm_type` to {norm_type}." + ) + + if positional_embeddings and (num_positional_embeddings is None): + raise ValueError( + "If `positional_embedding` type is defined, `num_positition_embeddings` must also be defined." + ) + + if positional_embeddings == "sinusoidal": + self.pos_embed = SinusoidalPositionalEmbedding(dim, max_seq_length=num_positional_embeddings) + else: + self.pos_embed = None + + # Define 3 blocks. Each block has its own normalization layer. + # 1. Self-Attn + if self.use_ada_layer_norm: + self.norm1 = AdaLayerNorm(dim, num_embeds_ada_norm) + elif self.use_ada_layer_norm_zero: + self.norm1 = AdaLayerNormZero(dim, num_embeds_ada_norm) + else: + self.norm1 = nn.LayerNorm(dim, elementwise_affine=norm_elementwise_affine, eps=norm_eps) + + self.attn1 = Attention( + query_dim=dim, + heads=num_attention_heads, + dim_head=attention_head_dim, + dropout=dropout, + bias=attention_bias, + cross_attention_dim=cross_attention_dim if only_cross_attention else None, + upcast_attention=upcast_attention, + attention_mode=sa_attention_mode, + use_rope=use_rope, + interpolation_scale_thw=interpolation_scale_thw, + ) + + # 2. Cross-Attn + if cross_attention_dim is not None or double_self_attention: + # We currently only use AdaLayerNormZero for self attention where there will only be one attention block. + # I.e. the number of returned modulation chunks from AdaLayerZero would not make sense if returned during + # the second cross attention block. + self.norm2 = ( + AdaLayerNorm(dim, num_embeds_ada_norm) + if self.use_ada_layer_norm + else nn.LayerNorm(dim, elementwise_affine=norm_elementwise_affine, eps=norm_eps) + ) + self.attn2 = Attention( + query_dim=dim, + cross_attention_dim=cross_attention_dim if not double_self_attention else None, + heads=num_attention_heads, + dim_head=attention_head_dim, + dropout=dropout, + bias=attention_bias, + upcast_attention=upcast_attention, + attention_mode=ca_attention_mode, # only xformers support attention_mask + use_rope=False, # do not position in cross attention + interpolation_scale_thw=interpolation_scale_thw, + ) # is self-attn if encoder_hidden_states is none + else: + self.norm2 = None + self.attn2 = None + + # 3. Feed-forward + + if not self.use_ada_layer_norm_single: + self.norm3 = nn.LayerNorm(dim, elementwise_affine=norm_elementwise_affine, eps=norm_eps) + + self.ff = FeedForward( + dim, + dropout=dropout, + activation_fn=activation_fn, + final_dropout=final_dropout, + ) + + # 5. Scale-shift for PixArt-Alpha. + if self.use_ada_layer_norm_single: + self.scale_shift_table = nn.Parameter(torch.randn(6, dim) / dim**0.5) + + # let chunk size default to None + self._chunk_size = None + self._chunk_dim = 0 + + # pab + self.cross_last = None + self.cross_count = 0 + self.spatial_last = None + self.spatial_count = 0 + self.block_idx = block_idx + self.spatila_mlp_count = 0 + + self.parallel_manager: ParallelManager = None + + def set_cross_last(self, last_out: torch.Tensor): + self.cross_last = last_out + + def set_spatial_last(self, last_out: torch.Tensor): + self.spatial_last = last_out + + def set_chunk_feed_forward(self, chunk_size: Optional[int], dim: int = 0): + # Sets chunk feed-forward + self._chunk_size = chunk_size + self._chunk_dim = dim + + def forward( + self, + hidden_states: torch.FloatTensor, + attention_mask: Optional[torch.FloatTensor] = None, + encoder_hidden_states: Optional[torch.FloatTensor] = None, + encoder_attention_mask: Optional[torch.FloatTensor] = None, + timestep: Optional[torch.LongTensor] = None, + cross_attention_kwargs: Dict[str, Any] = None, + class_labels: Optional[torch.LongTensor] = None, + org_timestep: Optional[torch.LongTensor] = None, + all_timesteps=None, + frame: int = None, + height: int = None, + width: int = None, + added_cond_kwargs: Optional[Dict[str, torch.Tensor]] = None, + ) -> torch.FloatTensor: + # Notice that normalization is always applied before the real computation in the following blocks. + # 0. Self-Attention + batch_size = hidden_states.shape[0] + + # 1. Retrieve lora scale. + lora_scale = cross_attention_kwargs.get("scale", 1.0) if cross_attention_kwargs is not None else 1.0 + + # 2. Prepare GLIGEN inputs + cross_attention_kwargs = cross_attention_kwargs.copy() if cross_attention_kwargs is not None else {} + gligen_kwargs = cross_attention_kwargs.pop("gligen", None) + + broadcast_spatial, self.spatial_count = if_broadcast_spatial(int(org_timestep[0]), self.spatial_count) + if broadcast_spatial: + attn_output = self.spatial_last + assert self.use_ada_layer_norm_single + shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = ( + self.scale_shift_table[None] + timestep.reshape(batch_size, 6, -1) + ).chunk(6, dim=1) + else: + if self.use_ada_layer_norm: + norm_hidden_states = self.norm1(hidden_states, timestep) + elif self.use_ada_layer_norm_zero: + norm_hidden_states, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.norm1( + hidden_states, timestep, class_labels, hidden_dtype=hidden_states.dtype + ) + elif self.use_layer_norm: + norm_hidden_states = self.norm1(hidden_states) + elif self.use_ada_layer_norm_single: + shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = ( + self.scale_shift_table[None] + timestep.reshape(batch_size, 6, -1) + ).chunk(6, dim=1) + norm_hidden_states = self.norm1(hidden_states) + norm_hidden_states = norm_hidden_states * (1 + scale_msa) + shift_msa + norm_hidden_states = norm_hidden_states.squeeze(1) + else: + raise ValueError("Incorrect norm used") + + if self.pos_embed is not None: + norm_hidden_states = self.pos_embed(norm_hidden_states) + + attn_output = self.attn1( + norm_hidden_states, + encoder_hidden_states=encoder_hidden_states if self.only_cross_attention else None, + attention_mask=attention_mask, + frame=frame, + height=height, + width=width, + **cross_attention_kwargs, + ) + if self.use_ada_layer_norm_zero: + attn_output = gate_msa.unsqueeze(1) * attn_output + elif self.use_ada_layer_norm_single: + attn_output = gate_msa * attn_output + + if enable_pab(): + self.set_spatial_last(attn_output) + + hidden_states = attn_output + hidden_states + if hidden_states.ndim == 4: + hidden_states = hidden_states.squeeze(1) + + # 1. Cross-Attention + if self.attn2 is not None: + broadcast_cross, self.cross_count = if_broadcast_cross(int(org_timestep[0]), self.cross_count) + if broadcast_cross: + hidden_states = hidden_states + self.cross_last + else: + if self.use_ada_layer_norm: + norm_hidden_states = self.norm2(hidden_states, timestep) + elif self.use_ada_layer_norm_zero or self.use_layer_norm: + norm_hidden_states = self.norm2(hidden_states) + elif self.use_ada_layer_norm_single: + # For PixArt norm2 isn't applied here: + # https://github.com/PixArt-alpha/PixArt-alpha/blob/0f55e922376d8b797edd44d25d0e7464b260dcab/diffusion/model/nets/PixArtMS.py#L70C1-L76C103 + norm_hidden_states = hidden_states + else: + raise ValueError("Incorrect norm") + + if self.pos_embed is not None and self.use_ada_layer_norm_single is False: + norm_hidden_states = self.pos_embed(norm_hidden_states) + + attn_output = self.attn2( + norm_hidden_states, + encoder_hidden_states=encoder_hidden_states, + attention_mask=encoder_attention_mask, + **cross_attention_kwargs, + ) + hidden_states = attn_output + hidden_states + + if enable_pab(): + self.set_cross_last(attn_output) + + if enable_pab(): + broadcast_mlp, self.spatila_mlp_count, broadcast_next, broadcast_range = if_broadcast_mlp( + int(org_timestep[0]), + self.spatila_mlp_count, + self.block_idx, + all_timesteps.tolist(), + is_temporal=False, + ) + + if enable_pab() and broadcast_mlp: + ff_output = get_mlp_output( + broadcast_range, + timestep=int(org_timestep[0]), + block_idx=self.block_idx, + is_temporal=False, + ) + else: + # 2. Feed-forward + if not self.use_ada_layer_norm_single: + norm_hidden_states = self.norm3(hidden_states) + + if self.use_ada_layer_norm_zero: + norm_hidden_states = norm_hidden_states * (1 + scale_mlp[:, None]) + shift_mlp[:, None] + + if self.use_ada_layer_norm_single: + norm_hidden_states = self.norm2(hidden_states) + norm_hidden_states = norm_hidden_states * (1 + scale_mlp) + shift_mlp + + ff_output = self.ff(norm_hidden_states, scale=lora_scale) + + if self.use_ada_layer_norm_zero: + ff_output = gate_mlp.unsqueeze(1) * ff_output + elif self.use_ada_layer_norm_single: + ff_output = gate_mlp * ff_output + + if enable_pab() and broadcast_next: + save_mlp_output( + timestep=int(org_timestep[0]), + block_idx=self.block_idx, + ff_output=ff_output, + is_temporal=False, + ) + + hidden_states = ff_output + hidden_states + if hidden_states.ndim == 4: + hidden_states = hidden_states.squeeze(1) + + return hidden_states + + def dynamic_switch(self, x, to_spatial_shard: bool): + if to_spatial_shard: + scatter_dim, gather_dim = 1, 2 + scatter_pad = get_pad("spatial") + gather_pad = 0 + else: + scatter_dim, gather_dim = 2, 1 + scatter_pad = 0 + gather_pad = get_pad("spatial") + x = all_to_all_with_pad( + x, + self.parallel_manager.sp_group, + scatter_dim=scatter_dim, + gather_dim=gather_dim, + scatter_pad=scatter_pad, + gather_pad=gather_pad, + ) + return x + +class AdaLayerNormSingle(nn.Module): + r""" + Norm layer adaptive layer norm single (adaLN-single). + + As proposed in PixArt-Alpha (see: https://arxiv.org/abs/2310.00426; Section 2.3). + + Parameters: + embedding_dim (`int`): The size of each embedding vector. + use_additional_conditions (`bool`): To use additional conditions for normalization or not. + """ + + def __init__(self, embedding_dim: int, use_additional_conditions: bool = False): + super().__init__() + + self.emb = CombinedTimestepSizeEmbeddings( + embedding_dim, size_emb_dim=embedding_dim // 3, use_additional_conditions=use_additional_conditions + ) + + self.silu = nn.SiLU() + self.linear = nn.Linear(embedding_dim, 6 * embedding_dim, bias=True) + + def forward( + self, + timestep: torch.Tensor, + added_cond_kwargs: Dict[str, torch.Tensor] = None, + batch_size: int = None, + hidden_dtype: Optional[torch.dtype] = None, + ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: + # No modulation happening here. + embedded_timestep = self.emb( + timestep, batch_size=batch_size, hidden_dtype=hidden_dtype, resolution=None, aspect_ratio=None + ) + return self.linear(self.silu(embedded_timestep)), embedded_timestep + + +@dataclass +class Transformer3DModelOutput(BaseOutput): + """ + The output of [`Transformer2DModel`]. + + Args: + sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)` or `(batch size, num_vector_embeds - 1, num_latent_pixels)` if [`Transformer2DModel`] is discrete): + The hidden states output conditioned on the `encoder_hidden_states` input. If discrete, returns probability + distributions for the unnoised latent pixels. + """ + + sample: torch.FloatTensor + + +class AllegroTransformer3DModel(ModelMixin, ConfigMixin): + _supports_gradient_checkpointing = True + + """ + A 2D Transformer model for image-like data. + + Parameters: + num_attention_heads (`int`, *optional*, defaults to 16): The number of heads to use for multi-head attention. + attention_head_dim (`int`, *optional*, defaults to 88): The number of channels in each head. + in_channels (`int`, *optional*): + The number of channels in the input and output (specify if the input is **continuous**). + num_layers (`int`, *optional*, defaults to 1): The number of layers of Transformer blocks to use. + dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use. + cross_attention_dim (`int`, *optional*): The number of `encoder_hidden_states` dimensions to use. + sample_size (`int`, *optional*): The width of the latent images (specify if the input is **discrete**). + This is fixed during training since it is used to learn a number of position embeddings. + num_vector_embeds (`int`, *optional*): + The number of classes of the vector embeddings of the latent pixels (specify if the input is **discrete**). + Includes the class for the masked latent pixel. + activation_fn (`str`, *optional*, defaults to `"geglu"`): Activation function to use in feed-forward. + num_embeds_ada_norm ( `int`, *optional*): + The number of diffusion steps used during training. Pass if at least one of the norm_layers is + `AdaLayerNorm`. This is fixed during training since it is used to learn a number of embeddings that are + added to the hidden states. + + During inference, you can denoise for up to but not more steps than `num_embeds_ada_norm`. + attention_bias (`bool`, *optional*): + Configure if the `TransformerBlocks` attention should contain a bias parameter. + """ + + @register_to_config + def __init__( + self, + num_attention_heads: int = 16, + attention_head_dim: int = 88, + in_channels: Optional[int] = None, + out_channels: Optional[int] = None, + num_layers: int = 1, + dropout: float = 0.0, + cross_attention_dim: Optional[int] = None, + attention_bias: bool = False, + sample_size: Optional[int] = None, + sample_size_t: Optional[int] = None, + patch_size: Optional[int] = None, + patch_size_t: Optional[int] = 1, + activation_fn: str = "geglu", + num_embeds_ada_norm: Optional[int] = 1000, + use_linear_projection: bool = False, + only_cross_attention: bool = False, + double_self_attention: bool = False, + upcast_attention: bool = False, + norm_type: str = "ada_norm", + norm_elementwise_affine: bool = True, + norm_eps: float = 1e-5, + attention_type: str = "default", + caption_channels: int = None, + interpolation_scale_h: float = None, + interpolation_scale_w: float = None, + interpolation_scale_t: float = None, + use_additional_conditions: Optional[bool] = None, + sa_attention_mode: str = "flash", + ca_attention_mode: str = 'xformers', + downsampler: str = None, + use_rope: bool = True, + model_max_length: int = 300, + ): + super().__init__() + self.use_linear_projection = use_linear_projection + self.interpolation_scale_t = interpolation_scale_t + self.interpolation_scale_h = interpolation_scale_h + self.interpolation_scale_w = interpolation_scale_w + self.downsampler = downsampler + self.caption_channels = caption_channels + self.num_attention_heads = num_attention_heads + self.attention_head_dim = attention_head_dim + inner_dim = num_attention_heads * attention_head_dim + self.inner_dim = inner_dim + self.in_channels = in_channels + self.out_channels = in_channels if out_channels is None else out_channels + self.use_rope = use_rope + self.model_max_length = model_max_length + self.num_layers = num_layers + self.config.hidden_size = inner_dim + + + # 1. Transformer3DModel can process both standard continuous images of shape `(batch_size, num_channels, width, height)` as well as quantized image embeddings of shape `(batch_size, num_image_vectors)` + # Define whether input is continuous or discrete depending on configuration + assert in_channels is not None and patch_size is not None + + # 2. Initialize the right blocks. + # Initialize the output blocks and other projection blocks when necessary. + + assert self.config.sample_size_t is not None, "AllegroTransformer3DModel over patched input must provide sample_size_t" + assert self.config.sample_size is not None, "AllegroTransformer3DModel over patched input must provide sample_size" + #assert not (self.config.sample_size_t == 1 and self.config.patch_size_t == 2), "Image do not need patchfy in t-dim" + + self.num_frames = self.config.sample_size_t + self.config.sample_size = to_2tuple(self.config.sample_size) + self.height = self.config.sample_size[0] + self.width = self.config.sample_size[1] + self.patch_size_t = self.config.patch_size_t + self.patch_size = self.config.patch_size + interpolation_scale_t = ((self.config.sample_size_t - 1) // 16 + 1) if self.config.sample_size_t % 2 == 1 else self.config.sample_size_t / 16 + interpolation_scale_t = ( + self.config.interpolation_scale_t if self.config.interpolation_scale_t is not None else interpolation_scale_t + ) + interpolation_scale = ( + self.config.interpolation_scale_h if self.config.interpolation_scale_h is not None else self.config.sample_size[0] / 30, + self.config.interpolation_scale_w if self.config.interpolation_scale_w is not None else self.config.sample_size[1] / 40, + ) + self.pos_embed = PatchEmbed2D( + num_frames=self.config.sample_size_t, + height=self.config.sample_size[0], + width=self.config.sample_size[1], + patch_size_t=self.config.patch_size_t, + patch_size=self.config.patch_size, + in_channels=self.in_channels, + embed_dim=self.inner_dim, + interpolation_scale=interpolation_scale, + interpolation_scale_t=interpolation_scale_t, + use_abs_pos=not self.config.use_rope, + ) + interpolation_scale_thw = (interpolation_scale_t, *interpolation_scale) + + # 3. Define transformers blocks, spatial attention + self.transformer_blocks = nn.ModuleList( + [ + BasicTransformerBlock( + inner_dim, + num_attention_heads, + attention_head_dim, + dropout=dropout, + cross_attention_dim=cross_attention_dim, + activation_fn=activation_fn, + num_embeds_ada_norm=num_embeds_ada_norm, + attention_bias=attention_bias, + only_cross_attention=only_cross_attention, + double_self_attention=double_self_attention, + upcast_attention=upcast_attention, + norm_type=norm_type, + norm_elementwise_affine=norm_elementwise_affine, + norm_eps=norm_eps, + sa_attention_mode=sa_attention_mode, + ca_attention_mode=ca_attention_mode, + use_rope=use_rope, + interpolation_scale_thw=interpolation_scale_thw, + block_idx=d, + ) + for d in range(num_layers) + ] + ) + + # 4. Define output layers + + if norm_type != "ada_norm_single": + self.norm_out = nn.LayerNorm(inner_dim, elementwise_affine=False, eps=1e-6) + self.proj_out_1 = nn.Linear(inner_dim, 2 * inner_dim) + self.proj_out_2 = nn.Linear(inner_dim, patch_size * patch_size * self.out_channels) + elif norm_type == "ada_norm_single": + self.norm_out = nn.LayerNorm(inner_dim, elementwise_affine=False, eps=1e-6) + self.scale_shift_table = nn.Parameter(torch.randn(2, inner_dim) / inner_dim**0.5) + self.proj_out = nn.Linear(inner_dim, patch_size * patch_size * self.out_channels) + + # 5. PixArt-Alpha blocks. + self.adaln_single = None + self.use_additional_conditions = False + if norm_type == "ada_norm_single": + # self.use_additional_conditions = self.config.sample_size[0] == 128 # False, 128 -> 1024 + # TODO(Sayak, PVP) clean this, for now we use sample size to determine whether to use + # additional conditions until we find better name + self.adaln_single = AdaLayerNormSingle(inner_dim, use_additional_conditions=self.use_additional_conditions) + + self.caption_projection = None + if caption_channels is not None: + self.caption_projection = PixArtAlphaTextProjection( + in_features=caption_channels, hidden_size=inner_dim + ) + + self.gradient_checkpointing = False + + # parallel + self.parallel_manager: ParallelManager = None + + def enable_parallel(self, dp_size, sp_size, enable_cp): + # update cfg parallel + if enable_cp and sp_size % 2 == 0: + sp_size = sp_size // 2 + cp_size = 2 + else: + cp_size = 1 + + self.parallel_manager = ParallelManager(dp_size, cp_size, sp_size) + + for _, module in self.named_modules(): + if hasattr(module, "parallel_manager"): + module.parallel_manager = self.parallel_manager + + def _set_gradient_checkpointing(self, module, value=False): + self.gradient_checkpointing = value + + + def forward( + self, + hidden_states: torch.Tensor, + timestep: Optional[torch.LongTensor] = None, + all_timesteps: Optional[torch.LongTensor] = None, + encoder_hidden_states: Optional[torch.Tensor] = None, + added_cond_kwargs: Dict[str, torch.Tensor] = None, + class_labels: Optional[torch.LongTensor] = None, + cross_attention_kwargs: Dict[str, Any] = None, + attention_mask: Optional[torch.Tensor] = None, + encoder_attention_mask: Optional[torch.Tensor] = None, + return_dict: bool = True, + ): + """ + The [`Transformer2DModel`] forward method. + + Args: + hidden_states (`torch.LongTensor` of shape `(batch size, num latent pixels)` if discrete, `torch.FloatTensor` of shape `(batch size, frame, channel, height, width)` if continuous): + Input `hidden_states`. + encoder_hidden_states ( `torch.FloatTensor` of shape `(batch size, sequence len, embed dims)`, *optional*): + Conditional embeddings for cross attention layer. If not given, cross-attention defaults to + self-attention. + timestep ( `torch.LongTensor`, *optional*): + Used to indicate denoising step. Optional timestep to be applied as an embedding in `AdaLayerNorm`. + class_labels ( `torch.LongTensor` of shape `(batch size, num classes)`, *optional*): + Used to indicate class labels conditioning. Optional class labels to be applied as an embedding in + `AdaLayerZeroNorm`. + added_cond_kwargs ( `Dict[str, Any]`, *optional*): + A kwargs dictionary that if specified is passed along to the `AdaLayerNormSingle` + cross_attention_kwargs ( `Dict[str, Any]`, *optional*): + A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under + `self.processor` in + [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py). + attention_mask ( `torch.Tensor`, *optional*): + An attention mask of shape `(batch, key_tokens)` is applied to `encoder_hidden_states`. If `1` the mask + is kept, otherwise if `0` it is discarded. Mask will be converted into a bias, which adds large + negative values to the attention scores corresponding to "discard" tokens. + encoder_attention_mask ( `torch.Tensor`, *optional*): + Cross-attention mask applied to `encoder_hidden_states`. Two formats supported: + + * Mask `(batch, sequence_length)` True = keep, False = discard. + * Bias `(batch, 1, sequence_length)` 0 = keep, -10000 = discard. + + If `ndim == 2`: will be interpreted as a mask, then converted into a bias consistent with the format + above. This bias will be added to the cross-attention scores. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`~models.unet_2d_condition.UNet2DConditionOutput`] instead of a plain + tuple. + + Returns: + If `return_dict` is True, an [`~models.transformer_2d.Transformer2DModelOutput`] is returned, otherwise a + `tuple` where the first element is the sample tensor. + """ + # 0. Split batch + if self.parallel_manager.cp_size > 1: + ( + hidden_states, + timestep, + encoder_hidden_states, + class_labels, + attention_mask, + encoder_attention_mask, + ) = batch_func( + partial(split_sequence, process_group=self.parallel_manager.cp_group, dim=0), + hidden_states, + timestep, + encoder_hidden_states, + class_labels, + attention_mask, + encoder_attention_mask, + ) + batch_size, c, frame, h, w = hidden_states.shape + org_timestep = timestep + + # ensure attention_mask is a bias, and give it a singleton query_tokens dimension. + # we may have done this conversion already, e.g. if we came here via UNet2DConditionModel#forward. + # we can tell by counting dims; if ndim == 2: it's a mask rather than a bias. + # expects mask of shape: + # [batch, key_tokens] + # adds singleton query_tokens dimension: + # [batch, 1, key_tokens] + # this helps to broadcast it as a bias over attention scores, which will be in one of the following shapes: + # [batch, heads, query_tokens, key_tokens] (e.g. torch sdp attn) + # [batch * heads, query_tokens, key_tokens] (e.g. xformers or classic attn) attention_mask_vid, attention_mask_img = None, None + if attention_mask is not None and attention_mask.ndim == 4: + # assume that mask is expressed as: + # (1 = keep, 0 = discard) + # convert mask into a bias that can be added to attention scores: + # (keep = +0, discard = -10000.0) + # b, frame+use_image_num, h, w -> a video with images + # b, 1, h, w -> only images + attention_mask = attention_mask.to(self.dtype) + attention_mask_vid = attention_mask[:, :frame] # b, frame, h, w + + if attention_mask_vid.numel() > 0: + attention_mask_vid = attention_mask_vid.unsqueeze(1) # b 1 t h w + attention_mask_vid = F.max_pool3d(attention_mask_vid, kernel_size=(self.patch_size_t, self.patch_size, self.patch_size), + stride=(self.patch_size_t, self.patch_size, self.patch_size)) + attention_mask_vid = rearrange(attention_mask_vid, 'b 1 t h w -> (b 1) 1 (t h w)') + + attention_mask_vid = (1 - attention_mask_vid.bool().to(self.dtype)) * -10000.0 if attention_mask_vid.numel() > 0 else None + + # convert encoder_attention_mask to a bias the same way we do for attention_mask + if encoder_attention_mask is not None and encoder_attention_mask.ndim == 3: + # b, 1+use_image_num, l -> a video with images + # b, 1, l -> only images + encoder_attention_mask = (1 - encoder_attention_mask.to(self.dtype)) * -10000.0 + encoder_attention_mask_vid = rearrange(encoder_attention_mask, 'b 1 l -> (b 1) 1 l') if encoder_attention_mask.numel() > 0 else None + + # 1. Input + frame = frame // self.patch_size_t # patchfy + # print('frame', frame) + height, width = hidden_states.shape[-2] // self.patch_size, hidden_states.shape[-1] // self.patch_size + + added_cond_kwargs = {"resolution": None, "aspect_ratio": None} if added_cond_kwargs is None else added_cond_kwargs + hidden_states, encoder_hidden_states_vid, \ + timestep_vid, embedded_timestep_vid = self._operate_on_patched_inputs( + hidden_states, encoder_hidden_states, timestep, added_cond_kwargs, batch_size, + ) + + if self.parallel_manager.sp_size > 1: + num_patches = hidden_states.shape[1] + set_pad("spatial", num_patches, self.parallel_manager.sp_group) + hidden_states = self.split_from_second_dim(hidden_states, batch_size) + encoder_hidden_states_vid = self.split_from_second_dim(encoder_hidden_states_vid, batch_size) + # timestep_vid = repeat(timestep_vid, "b d -> (b p) d", p=self.parallel_manager.sp_size).contiguous() + # timestep_vid = self.split_from_second_dim(timestep_vid, batch_size) + # attention_mask_vid = self.split_from_second_dim(attention_mask_vid, batch_size) + # attention_mask_compress = self.split_from_second_dim(attention_mask_compress, batch_size) + + + for _, block in enumerate(self.transformer_blocks): + if self.training and self.gradient_checkpointing: + hidden_states = torch.utils.checkpoint.checkpoint( + block, + hidden_states, + attention_mask_vid, + encoder_hidden_states_vid, + encoder_attention_mask_vid, + timestep_vid, + cross_attention_kwargs, + class_labels, + frame=frame, + height=height, + width=width, + use_reentrant=False, + ) + + else: + + hidden_states = block( + hidden_states, + attention_mask_vid, + encoder_hidden_states_vid, + encoder_attention_mask_vid, + timestep_vid, + cross_attention_kwargs, + class_labels, + org_timestep=org_timestep, + all_timesteps=all_timesteps, + frame=frame, + height=height, + width=width, + ) + + + if self.parallel_manager.sp_size > 1: + hidden_states = self.gather_from_second_dim(hidden_states, batch_size) + + # 3. Output + output = None + if hidden_states is not None: + output = self._get_output_for_patched_inputs( + hidden_states=hidden_states, + timestep=timestep_vid, + class_labels=class_labels, + embedded_timestep=embedded_timestep_vid, + num_frames=frame, + height=height, + width=width, + ) # b c t h w + + + # 3. Gather batch for data parallelism + if self.parallel_manager.cp_size > 1: + output = gather_sequence(output, self.parallel_manager.cp_group, dim=0) + + if not return_dict: + return (output,) + + return Transformer3DModelOutput(sample=output) + + def split_from_second_dim(self, x, batch_size): + # x = x.view(batch_size, -1, *x.shape[1:]) + x = split_sequence(x, self.parallel_manager.sp_group, dim=1, grad_scale="down", pad=get_pad("spatial")) + x = x.reshape(-1, *x.shape[1:]) + return x + + def gather_from_second_dim(self, x, batch_size): + # x = x.view(batch_size, -1, *x.shape[1:]) + x = gather_sequence(x, self.parallel_manager.sp_group, dim=1, grad_scale="up", pad=get_pad("spatial")) + x = x.reshape(-1, *x.shape[1:]) + return x + + def _operate_on_patched_inputs(self, hidden_states, encoder_hidden_states, timestep, added_cond_kwargs, batch_size): + # batch_size = hidden_states.shape[0] + hidden_states_vid = self.pos_embed(hidden_states.to(self.dtype)) + timestep_vid = None + embedded_timestep_vid = None + encoder_hidden_states_vid = None + + if self.adaln_single is not None: + if self.use_additional_conditions and added_cond_kwargs is None: + raise ValueError( + "`added_cond_kwargs` cannot be None when using additional conditions for `adaln_single`." + ) + timestep, embedded_timestep = self.adaln_single( + timestep, added_cond_kwargs, batch_size=batch_size, hidden_dtype=self.dtype + ) # b 6d, b d + + timestep_vid = timestep + embedded_timestep_vid = embedded_timestep + + if self.caption_projection is not None: + encoder_hidden_states = self.caption_projection(encoder_hidden_states) # b, 1+use_image_num, l, d or b, 1, l, d + encoder_hidden_states_vid = rearrange(encoder_hidden_states[:, :1], 'b 1 l d -> (b 1) l d') + + return hidden_states_vid, encoder_hidden_states_vid, timestep_vid, embedded_timestep_vid + + def _get_output_for_patched_inputs( + self, hidden_states, timestep, class_labels, embedded_timestep, num_frames, height=None, width=None + ): + # import ipdb;ipdb.set_trace() + if self.config.norm_type != "ada_norm_single": + conditioning = self.transformer_blocks[0].norm1.emb( + timestep, class_labels, hidden_dtype=self.dtype + ) + shift, scale = self.proj_out_1(F.silu(conditioning)).chunk(2, dim=1) + hidden_states = self.norm_out(hidden_states) * (1 + scale[:, None]) + shift[:, None] + hidden_states = self.proj_out_2(hidden_states) + elif self.config.norm_type == "ada_norm_single": + shift, scale = (self.scale_shift_table[None] + embedded_timestep[:, None]).chunk(2, dim=1) + hidden_states = self.norm_out(hidden_states) + # Modulation + hidden_states = hidden_states * (1 + scale) + shift + hidden_states = self.proj_out(hidden_states) + hidden_states = hidden_states.squeeze(1) + + # unpatchify + if self.adaln_single is None: + height = width = int(hidden_states.shape[1] ** 0.5) + hidden_states = hidden_states.reshape( + shape=(-1, num_frames, height, width, self.patch_size_t, self.patch_size, self.patch_size, self.out_channels) + ) + hidden_states = torch.einsum("nthwopqc->nctohpwq", hidden_states) + output = hidden_states.reshape( + shape=(-1, self.out_channels, num_frames * self.patch_size_t, height * self.patch_size, width * self.patch_size) + ) + return output diff --git a/videosys/pipelines/allegro/__init__.py b/videosys/pipelines/allegro/__init__.py new file mode 100644 index 00000000..5cc46742 --- /dev/null +++ b/videosys/pipelines/allegro/__init__.py @@ -0,0 +1,3 @@ +from .pipeline_allegro import AllegroConfig, AllegroPABConfig, AllegroPipeline + +__all__ = ["AllegroConfig", "AllegroPipeline", "AllegroPABConfig"] diff --git a/videosys/pipelines/allegro/pipeline_allegro.py b/videosys/pipelines/allegro/pipeline_allegro.py new file mode 100644 index 00000000..cbba9d85 --- /dev/null +++ b/videosys/pipelines/allegro/pipeline_allegro.py @@ -0,0 +1,1024 @@ +# Adapted from Open-Sora-Plan + +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. +# -------------------------------------------------------- +# References: +# Open-Sora-Plan: https://github.com/PKU-YuanGroup/Open-Sora-Plan +# -------------------------------------------------------- + +import html +import inspect +import math +import re +import urllib.parse as ul +from typing import Callable, List, Optional, Tuple, Union +import os +import random +import torch.nn.functional as F +from einops import rearrange +import imageio +import ftfy +import torch +import torch.distributed as dist +import torch.distributed +import tqdm +from bs4 import BeautifulSoup +from diffusers.utils import ( + BACKENDS_MAPPING, + is_bs4_available, + is_ftfy_available, + logging, + replace_example_docstring, +) +from diffusers.utils.torch_utils import randn_tensor +from transformers import T5EncoderModel, T5Tokenizer + +from videosys.core.pab_mgr import PABConfig, set_pab_manager, update_steps +from videosys.core.pipeline import VideoSysPipeline, VideoSysPipelineOutput +from videosys.utils.logging import logger +from videosys.utils.utils import save_video, set_seed +from diffusers.schedulers import EulerAncestralDiscreteScheduler +from videosys.models.autoencoders.autoencoder_kl_allegro import AllegroAutoencoderKL3D +from ...models.transformers.allegro_transformer_3d import AllegroTransformer3DModel + + +EXAMPLE_DOC_STRING = """ + Examples: + ```py + >>> import torch + + >>> # You can replace the your_path_to_model with your own path. + >>> pipe = AllegroPipeline.from_pretrained(your_path_to_model, torch_dtype=torch.float16, trust_remote_code=True) + + >>> prompt = "A small cactus with a happy face in the Sahara desert." + >>> image = pipe(prompt).video[0] + ``` +""" + + + +class AllegroPABConfig(PABConfig): + def __init__( + self, + steps: int = 150, + spatial_broadcast: bool = True, + spatial_threshold: list = [100, 850], + spatial_range: int = 2, + temporal_broadcast: bool = True, + temporal_threshold: list = [100, 850], + temporal_range: int = 4, + cross_broadcast: bool = True, + cross_threshold: list = [100, 850], + cross_range: int = 6, + mlp_broadcast: bool = True, + mlp_spatial_broadcast_config: dict = { + 738: {"block": [0, 1, 2, 3, 4, 5, 6], "skip_count": 2}, + 714: {"block": [0, 1, 2, 3, 4, 5, 6], "skip_count": 2}, + 690: {"block": [0, 1, 2, 3, 4, 5, 6], "skip_count": 2}, + 666: {"block": [0, 1, 2, 3, 4, 5, 6], "skip_count": 2}, + 642: {"block": [0, 1, 2, 3, 4, 5, 6], "skip_count": 2}, + 618: {"block": [0, 1, 2, 3, 4, 5, 6], "skip_count": 2}, + 594: {"block": [0, 1, 2, 3, 4, 5, 6], "skip_count": 2}, + 570: {"block": [0, 1, 2, 3, 4, 5, 6], "skip_count": 2}, + 546: {"block": [0, 1, 2, 3, 4, 5, 6], "skip_count": 2}, + 522: {"block": [0, 1, 2, 3, 4, 5, 6], "skip_count": 2}, + 498: {"block": [0, 1, 2, 3, 4, 5, 6], "skip_count": 2}, + 474: {"block": [0, 1, 2, 3, 4, 5, 6], "skip_count": 2}, + 450: {"block": [0, 1, 2, 3, 4, 5, 6], "skip_count": 2}, + 426: {"block": [0, 1, 2, 3, 4, 5, 6], "skip_count": 2}, + }, + mlp_temporal_broadcast_config: dict = { + 738: {"block": [0, 1, 2, 3, 4, 5, 6], "skip_count": 2}, + 714: {"block": [0, 1, 2, 3, 4, 5, 6], "skip_count": 2}, + 690: {"block": [0, 1, 2, 3, 4, 5, 6], "skip_count": 2}, + 666: {"block": [0, 1, 2, 3, 4, 5, 6], "skip_count": 2}, + 642: {"block": [0, 1, 2, 3, 4, 5, 6], "skip_count": 2}, + 618: {"block": [0, 1, 2, 3, 4, 5, 6], "skip_count": 2}, + 594: {"block": [0, 1, 2, 3, 4, 5, 6], "skip_count": 2}, + 570: {"block": [0, 1, 2, 3, 4, 5, 6], "skip_count": 2}, + 546: {"block": [0, 1, 2, 3, 4, 5, 6], "skip_count": 2}, + 522: {"block": [0, 1, 2, 3, 4, 5, 6], "skip_count": 2}, + 498: {"block": [0, 1, 2, 3, 4, 5, 6], "skip_count": 2}, + 474: {"block": [0, 1, 2, 3, 4, 5, 6], "skip_count": 2}, + 450: {"block": [0, 1, 2, 3, 4, 5, 6], "skip_count": 2}, + 426: {"block": [0, 1, 2, 3, 4, 5, 6], "skip_count": 2}, + }, + ): + super().__init__( + spatial_broadcast=spatial_broadcast, + spatial_threshold=spatial_threshold, + spatial_range=spatial_range, + temporal_broadcast=temporal_broadcast, + temporal_threshold=temporal_threshold, + temporal_range=temporal_range, + cross_broadcast=cross_broadcast, + cross_threshold=cross_threshold, + cross_range=cross_range, + mlp_broadcast=mlp_broadcast, + mlp_spatial_broadcast_config=mlp_spatial_broadcast_config, + mlp_temporal_broadcast_config=mlp_temporal_broadcast_config, + ) + + + +class AllegroConfig: + """ + This config is to instantiate a `AllegroPipeline` class for video generation. + + To be specific, this config will be passed to engine by `VideoSysEngine(config)`. + In the engine, it will be used to instantiate the corresponding pipeline class. + And the engine will call the `generate` function of the pipeline to generate the video. + If you want to explore the detail of generation, please refer to the pipeline class below. + + Args: + transformer (str): + The transformer model to use. Defaults to "LanguageBind/Open-Sora-Plan-v1.1.0". + ae (str): + The Autoencoder model to use. Defaults to "CausalVAEModel_4x8x8". + text_encoder (str): + The text encoder model to use. Defaults to "DeepFloyd/t5-v1_1-xxl". + num_frames (int): + The number of frames to generate. Must be one of [65, 221]. + num_gpus (int): + The number of GPUs to use. Defaults to 1. + enable_tiling (bool): + Whether to enable tiling. Defaults to True. + tile_overlap_factor (float): + The overlap factor for tiling. Defaults to 0.25. + enable_pab (bool): + Whether to enable Pyramid Attention Broadcast. Defaults to False. + pab_config (CogVideoXPABConfig): + The configuration for Pyramid Attention Broadcast. Defaults to `LattePABConfig()`. + + Examples: + ```python + from videosys import AllegroConfig, VideoSysEngine + + # num frames: 65 or 221 + # change num_gpus for multi-gpu inference + config = AllegroConfig(num_frames=65, num_gpus=1) + engine = VideoSysEngine(config) + + prompt = "Sunset over the sea." + video = engine.generate( + prompt=prompt, + guidance_scale=7.5, + num_inference_steps=150, + ).video[0] + engine.save_video(video, f"./outputs/{prompt}.mp4") + ``` + """ + + def __init__( + self, + model_path: str = "rhymes-ai/Allegro", + # ======= distributed ======== + num_gpus: int = 1, + # ======= memory ======= + cpu_offload: bool = False, + enable_tiling: bool = True, + tile_overlap_factor: float = 0.25, + # ======= pab ======== + enable_pab: bool = False, + pab_config: PABConfig = AllegroPABConfig(), + ): + self.model_path = model_path + self.pipeline_cls = AllegroPipeline + + # ======= distributed ======== + self.num_gpus = num_gpus + # ======= memory ======== + self.cpu_offload = cpu_offload + self.enable_tiling = enable_tiling + self.tile_overlap_factor = tile_overlap_factor + # ======= pab ======== + self.enable_pab = enable_pab + self.pab_config = pab_config + +# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps +def retrieve_timesteps( + scheduler, + num_inference_steps: Optional[int] = None, + device: Optional[Union[str, torch.device]] = None, + timesteps: Optional[List[int]] = None, + **kwargs, +): + """ + Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles + custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`. + + Args: + scheduler (`SchedulerMixin`): + The scheduler to get timesteps from. + num_inference_steps (`int`): + The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps` + must be `None`. + device (`str` or `torch.device`, *optional*): + The device to which the timesteps should be moved to. If `None`, the timesteps are not moved. + timesteps (`List[int]`, *optional*): + Custom timesteps used to support arbitrary spacing between timesteps. If `None`, then the default + timestep spacing strategy of the scheduler is used. If `timesteps` is passed, `num_inference_steps` + must be `None`. + + Returns: + `Tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the + second element is the number of inference steps. + """ + if timesteps is not None: + accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys()) + if not accepts_timesteps: + raise ValueError( + f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom" + f" timestep schedules. Please check whether you are using the correct scheduler." + ) + scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs) + timesteps = scheduler.timesteps + num_inference_steps = len(timesteps) + else: + scheduler.set_timesteps(num_inference_steps, device=device, **kwargs) + timesteps = scheduler.timesteps + return timesteps, num_inference_steps + + +class AllegroPipeline(VideoSysPipeline): + r""" + Pipeline for text-to-image generation using Allegro. + + This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the + library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.) + + Args: + vae ([`AllegroAutoEncoderKL3D`]): + Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations. + text_encoder ([`T5EncoderModel`]): + Frozen text-encoder. PixArt-Alpha uses + [T5](https://huggingface.co/docs/transformers/model_doc/t5#transformers.T5EncoderModel), specifically the + [t5-v1_1-xxl](https://huggingface.co/PixArt-alpha/PixArt-alpha/tree/main/t5-v1_1-xxl) variant. + tokenizer (`T5Tokenizer`): + Tokenizer of class + [T5Tokenizer](https://huggingface.co/docs/transformers/model_doc/t5#transformers.T5Tokenizer). + transformer ([`AllegroTransformer3DModel`]): + A text conditioned `AllegroTransformer3DModel` to denoise the encoded image latents. + scheduler ([`SchedulerMixin`]): + A scheduler to be used in combination with `transformer` to denoise the encoded image latents. + """ + bad_punct_regex = re.compile( + r"[" + "#®•©™&@·º½¾¿¡§~" + "\)" + "\(" + "\]" + "\[" + "\}" + "\{" + "\|" + "\\" + "\/" + "\*" + r"]{1,}" + ) # noqa + + _optional_components = ["tokenizer", "text_encoder"] + model_cpu_offload_seq = "text_encoder->transformer->vae" + + def __init__( + self, + config: AllegroConfig, + tokenizer: Optional[T5Tokenizer] = None, + text_encoder: Optional[T5EncoderModel] = None, + vae: Optional[AllegroAutoencoderKL3D] = None, + transformer: Optional[AllegroTransformer3DModel] = None, + scheduler: Optional[EulerAncestralDiscreteScheduler] = None, + device: torch.device = torch.device("cuda"), + dtype: torch.dtype = torch.bfloat16, + ): + super().__init__() + self._config = config + # init + if tokenizer is None: + tokenizer = T5Tokenizer.from_pretrained(config.model_path, subfolder="tokenizer") + if text_encoder is None: + text_encoder = T5EncoderModel.from_pretrained(config.model_path, subfolder="text_encoder", torch_dtype=dtype) + if vae is None: + vae = AllegroAutoencoderKL3D.from_pretrained(config.model_path, subfolder="vae", torch_dtype=torch.float32) + if transformer is None: + transformer = AllegroTransformer3DModel.from_pretrained(config.model_path, subfolder="transformer", torch_dtype=dtype) + if scheduler is None: + scheduler = EulerAncestralDiscreteScheduler() + + # set eval and device + self.set_eval_and_device(device, vae, transformer) + + # pab + if config.enable_pab: + set_pab_manager(config.pab_config) + + self.register_modules( + tokenizer=tokenizer, text_encoder=text_encoder, vae=vae, transformer=transformer, scheduler=scheduler + ) + + # cpu offload + if config.cpu_offload: + self.enable_model_cpu_offload() + else: + self.set_eval_and_device(device, text_encoder) + + # self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) + + # parallel + self._set_parallel() + + def _set_seed(self, seed): + if dist.get_world_size() == 1: + set_seed(seed) + else: + set_seed(seed, self.transformer.parallel_manager.dp_rank) + + def _set_parallel( + self, dp_size: Optional[int] = None, sp_size: Optional[int] = None, enable_cp: Optional[bool] = False + ): + # init sequence parallel + if sp_size is None: + sp_size = dist.get_world_size() + dp_size = 1 + else: + assert ( + dist.get_world_size() % sp_size == 0 + ), f"world_size {dist.get_world_size()} must be divisible by sp_size" + dp_size = dist.get_world_size() // sp_size + + # transformer parallel + self.transformer.enable_parallel(dp_size, sp_size, enable_cp) + + # Adapted from https://github.com/PixArt-alpha/PixArt-alpha/blob/master/diffusion/model/utils.py + def mask_text_embeddings(self, emb, mask): + if emb.shape[0] == 1: + keep_index = mask.sum().item() + return emb[:, :, :keep_index, :], keep_index # 1, 120, 4096 -> 1 7 4096 + else: + masked_feature = emb * mask[:, None, :, None] # 1 120 4096 + return masked_feature, emb.shape[2] + + # Adapted from diffusers.pipelines.deepfloyd_if.pipeline_if.encode_prompt + def encode_prompt( + self, + prompt: Union[str, List[str]], + do_classifier_free_guidance: bool = True, + negative_prompt: str = "", + num_images_per_prompt: int = 1, + device: Optional[torch.device] = None, + prompt_embeds: Optional[torch.FloatTensor] = None, + negative_prompt_embeds: Optional[torch.FloatTensor] = None, + prompt_attention_mask: Optional[torch.FloatTensor] = None, + negative_prompt_attention_mask: Optional[torch.FloatTensor] = None, + clean_caption: bool = False, + max_sequence_length: int = 120, + **kwargs, + ): + r""" + Encodes the prompt into text encoder hidden states. + + Args: + prompt (`str` or `List[str]`, *optional*): + prompt to be encoded + negative_prompt (`str` or `List[str]`, *optional*): + The prompt not to guide the image generation. If not defined, one has to pass `negative_prompt_embeds` + instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is less than `1`). For + PixArt-Alpha, this should be "". + do_classifier_free_guidance (`bool`, *optional*, defaults to `True`): + whether to use classifier free guidance or not + num_images_per_prompt (`int`, *optional*, defaults to 1): + number of images that should be generated per prompt + device: (`torch.device`, *optional*): + torch device to place the resulting embeddings on + prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not + provided, text embeddings will be generated from `prompt` input argument. + negative_prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated negative text embeddings. For PixArt-Alpha, it's should be the embeddings of the "" + string. + clean_caption (`bool`, defaults to `False`): + If `True`, the function will preprocess and clean the provided caption before encoding. + max_sequence_length (`int`, defaults to 120): Maximum sequence length to use for the prompt. + """ + embeds_initially_provided = prompt_embeds is not None and negative_prompt_embeds is not None + + if device is None: + device = self._execution_device + + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + # See Section 3.1. of the paper. + max_length = max_sequence_length + + if prompt_embeds is None: + prompt = self._text_preprocessing(prompt, clean_caption=clean_caption) + text_inputs = self.tokenizer( + prompt, + padding="max_length", + max_length=max_length, + truncation=True, + add_special_tokens=True, + return_tensors="pt", + ) + text_input_ids = text_inputs.input_ids + untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids + + if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal( + text_input_ids, untruncated_ids + ): + removed_text = self.tokenizer.batch_decode(untruncated_ids[:, max_length - 1 : -1]) + logger.warning( + "The following part of your input was truncated because CLIP can only handle sequences up to" + f" {max_length} tokens: {removed_text}" + ) + + prompt_attention_mask = text_inputs.attention_mask + prompt_attention_mask = prompt_attention_mask.to(device) + + prompt_embeds = self.text_encoder(text_input_ids.to(device), attention_mask=prompt_attention_mask) + prompt_embeds = prompt_embeds[0] + + if self.text_encoder is not None: + dtype = self.text_encoder.dtype + elif self.transformer is not None: + dtype = self.transformer.dtype + else: + dtype = None + + prompt_embeds = prompt_embeds.to(dtype=dtype, device=device) + + bs_embed, seq_len, _ = prompt_embeds.shape + # duplicate text embeddings and attention mask for each generation per prompt, using mps friendly method + prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1) + prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1) + prompt_attention_mask = prompt_attention_mask.view(bs_embed, -1) + prompt_attention_mask = prompt_attention_mask.repeat(num_images_per_prompt, 1) + + # get unconditional embeddings for classifier free guidance + if do_classifier_free_guidance and negative_prompt_embeds is None: + uncond_tokens = [negative_prompt] * batch_size + uncond_tokens = self._text_preprocessing(uncond_tokens, clean_caption=clean_caption) + max_length = prompt_embeds.shape[1] + uncond_input = self.tokenizer( + uncond_tokens, + padding="max_length", + max_length=max_length, + truncation=True, + return_attention_mask=True, + add_special_tokens=True, + return_tensors="pt", + ) + negative_prompt_attention_mask = uncond_input.attention_mask + negative_prompt_attention_mask = negative_prompt_attention_mask.to(device) + + negative_prompt_embeds = self.text_encoder( + uncond_input.input_ids.to(device), + attention_mask=negative_prompt_attention_mask, + ) + negative_prompt_embeds = negative_prompt_embeds[0] + + if do_classifier_free_guidance: + # duplicate unconditional embeddings for each generation per prompt, using mps friendly method + seq_len = negative_prompt_embeds.shape[1] + + negative_prompt_embeds = negative_prompt_embeds.to(dtype=dtype, device=device) + + negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1) + negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1) + + negative_prompt_attention_mask = negative_prompt_attention_mask.view(bs_embed, -1) + negative_prompt_attention_mask = negative_prompt_attention_mask.repeat(num_images_per_prompt, 1) + else: + negative_prompt_embeds = None + negative_prompt_attention_mask = None + + return prompt_embeds, prompt_attention_mask, negative_prompt_embeds, negative_prompt_attention_mask + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs + def prepare_extra_step_kwargs(self, generator, eta): + # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature + # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers. + # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502 + # and should be between [0, 1] + + accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys()) + extra_step_kwargs = {} + if accepts_eta: + extra_step_kwargs["eta"] = eta + + # check if the scheduler accepts generator + accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys()) + if accepts_generator: + extra_step_kwargs["generator"] = generator + return extra_step_kwargs + + def check_inputs( + self, + prompt, + num_frames, + height, + width, + negative_prompt, + callback_steps, + prompt_embeds=None, + negative_prompt_embeds=None, + prompt_attention_mask=None, + negative_prompt_attention_mask=None, + ): + if num_frames <= 0: + raise ValueError(f"`num_frames` have to be positive but is {num_frames}.") + if height % 8 != 0 or width % 8 != 0: + raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.") + + if (callback_steps is None) or ( + callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0) + ): + raise ValueError( + f"`callback_steps` has to be a positive integer but is {callback_steps} of type" + f" {type(callback_steps)}." + ) + + if prompt is not None and prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to" + " only forward one of the two." + ) + elif prompt is None and prompt_embeds is None: + raise ValueError( + "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined." + ) + elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)): + raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") + + if prompt is not None and negative_prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `prompt`: {prompt} and `negative_prompt_embeds`:" + f" {negative_prompt_embeds}. Please make sure to only forward one of the two." + ) + + if negative_prompt is not None and negative_prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:" + f" {negative_prompt_embeds}. Please make sure to only forward one of the two." + ) + + if prompt_embeds is not None and prompt_attention_mask is None: + raise ValueError("Must provide `prompt_attention_mask` when specifying `prompt_embeds`.") + + if negative_prompt_embeds is not None and negative_prompt_attention_mask is None: + raise ValueError("Must provide `negative_prompt_attention_mask` when specifying `negative_prompt_embeds`.") + + if prompt_embeds is not None and negative_prompt_embeds is not None: + if prompt_embeds.shape != negative_prompt_embeds.shape: + raise ValueError( + "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but" + f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`" + f" {negative_prompt_embeds.shape}." + ) + if prompt_attention_mask.shape != negative_prompt_attention_mask.shape: + raise ValueError( + "`prompt_attention_mask` and `negative_prompt_attention_mask` must have the same shape when passed directly, but" + f" got: `prompt_attention_mask` {prompt_attention_mask.shape} != `negative_prompt_attention_mask`" + f" {negative_prompt_attention_mask.shape}." + ) + + # Copied from diffusers.pipelines.deepfloyd_if.pipeline_if.IFPipeline._text_preprocessing + def _text_preprocessing(self, text, clean_caption=False): + if clean_caption and not is_bs4_available(): + logger.warning(BACKENDS_MAPPING["bs4"][-1].format("Setting `clean_caption=True`")) + logger.warning("Setting `clean_caption` to False...") + clean_caption = False + + if clean_caption and not is_ftfy_available(): + logger.warning(BACKENDS_MAPPING["ftfy"][-1].format("Setting `clean_caption=True`")) + logger.warning("Setting `clean_caption` to False...") + clean_caption = False + + if not isinstance(text, (tuple, list)): + text = [text] + + def process(text: str): + if clean_caption: + text = self._clean_caption(text) + text = self._clean_caption(text) + else: + text = text.lower().strip() + return text + + return [process(t) for t in text] + + # Copied from diffusers.pipelines.deepfloyd_if.pipeline_if.IFPipeline._clean_caption + def _clean_caption(self, caption): + caption = str(caption) + caption = ul.unquote_plus(caption) + caption = caption.strip().lower() + caption = re.sub("", "person", caption) + # urls: + caption = re.sub( + r"\b((?:https?:(?:\/{1,3}|[a-zA-Z0-9%])|[a-zA-Z0-9.\-]+[.](?:com|co|ru|net|org|edu|gov|it)[\w/-]*\b\/?(?!@)))", + # noqa + "", + caption, + ) # regex for urls + caption = re.sub( + r"\b((?:www:(?:\/{1,3}|[a-zA-Z0-9%])|[a-zA-Z0-9.\-]+[.](?:com|co|ru|net|org|edu|gov|it)[\w/-]*\b\/?(?!@)))", + # noqa + "", + caption, + ) # regex for urls + # html: + caption = BeautifulSoup(caption, features="html.parser").text + + # @ + caption = re.sub(r"@[\w\d]+\b", "", caption) + + # 31C0—31EF CJK Strokes + # 31F0—31FF Katakana Phonetic Extensions + # 3200—32FF Enclosed CJK Letters and Months + # 3300—33FF CJK Compatibility + # 3400—4DBF CJK Unified Ideographs Extension A + # 4DC0—4DFF Yijing Hexagram Symbols + # 4E00—9FFF CJK Unified Ideographs + caption = re.sub(r"[\u31c0-\u31ef]+", "", caption) + caption = re.sub(r"[\u31f0-\u31ff]+", "", caption) + caption = re.sub(r"[\u3200-\u32ff]+", "", caption) + caption = re.sub(r"[\u3300-\u33ff]+", "", caption) + caption = re.sub(r"[\u3400-\u4dbf]+", "", caption) + caption = re.sub(r"[\u4dc0-\u4dff]+", "", caption) + # caption = re.sub(r"[\u4e00-\u9fff]+", "", caption) + ####################################################### + + # все виды тире / all types of dash --> "-" + caption = re.sub( + r"[\u002D\u058A\u05BE\u1400\u1806\u2010-\u2015\u2E17\u2E1A\u2E3A\u2E3B\u2E40\u301C\u3030\u30A0\uFE31\uFE32\uFE58\uFE63\uFF0D]+", + # noqa + "-", + caption, + ) + + # кавычки к одному стандарту + caption = re.sub(r"[`´«»“”¨]", '"', caption) + caption = re.sub(r"[‘’]", "'", caption) + + # " + caption = re.sub(r""?", "", caption) + # & + caption = re.sub(r"&", "", caption) + + # ip adresses: + caption = re.sub(r"\d{1,3}\.\d{1,3}\.\d{1,3}\.\d{1,3}", " ", caption) + + # article ids: + caption = re.sub(r"\d:\d\d\s+$", "", caption) + + # \n + caption = re.sub(r"\\n", " ", caption) + + # "#123" + caption = re.sub(r"#\d{1,3}\b", "", caption) + # "#12345.." + caption = re.sub(r"#\d{5,}\b", "", caption) + # "123456.." + caption = re.sub(r"\b\d{6,}\b", "", caption) + # filenames: + caption = re.sub(r"[\S]+\.(?:png|jpg|jpeg|bmp|webp|eps|pdf|apk|mp4)", "", caption) + + # + caption = re.sub(r"[\"\']{2,}", r'"', caption) # """AUSVERKAUFT""" + caption = re.sub(r"[\.]{2,}", r" ", caption) # """AUSVERKAUFT""" + + caption = re.sub(self.bad_punct_regex, r" ", caption) # ***AUSVERKAUFT***, #AUSVERKAUFT + caption = re.sub(r"\s+\.\s+", r" ", caption) # " . " + + # this-is-my-cute-cat / this_is_my_cute_cat + regex2 = re.compile(r"(?:\-|\_)") + if len(re.findall(regex2, caption)) > 3: + caption = re.sub(regex2, " ", caption) + + caption = ftfy.fix_text(caption) + caption = html.unescape(html.unescape(caption)) + + caption = re.sub(r"\b[a-zA-Z]{1,3}\d{3,15}\b", "", caption) # jc6640 + caption = re.sub(r"\b[a-zA-Z]+\d+[a-zA-Z]+\b", "", caption) # jc6640vc + caption = re.sub(r"\b\d+[a-zA-Z]+\d+\b", "", caption) # 6640vc231 + + caption = re.sub(r"(worldwide\s+)?(free\s+)?shipping", "", caption) + caption = re.sub(r"(free\s)?download(\sfree)?", "", caption) + caption = re.sub(r"\bclick\b\s(?:for|on)\s\w+", "", caption) + caption = re.sub(r"\b(?:png|jpg|jpeg|bmp|webp|eps|pdf|apk|mp4)(\simage[s]?)?", "", caption) + caption = re.sub(r"\bpage\s+\d+\b", "", caption) + + caption = re.sub(r"\b\d*[a-zA-Z]+\d+[a-zA-Z]+\d+[a-zA-Z\d]*\b", r" ", caption) # j2d1a2a... + + caption = re.sub(r"\b\d+\.?\d*[xх×]\d+\.?\d*\b", "", caption) + + caption = re.sub(r"\b\s+\:\s+", r": ", caption) + caption = re.sub(r"(\D[,\./])\b", r"\1 ", caption) + caption = re.sub(r"\s+", " ", caption) + + caption.strip() + + caption = re.sub(r"^[\"\']([\w\W]+)[\"\']$", r"\1", caption) + caption = re.sub(r"^[\'\_,\-\:;]", r"", caption) + caption = re.sub(r"[\'\_,\-\:\-\+]$", r"", caption) + caption = re.sub(r"^\.\S+$", "", caption) + return caption.strip() + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_latents + def prepare_latents( + self, batch_size, num_channels_latents, num_frames, height, width, dtype, device, generator, latents=None + ): + shape = ( + batch_size, + num_channels_latents, + (math.ceil((int(num_frames) - 1) / self.vae.vae_scale_factor[0]) + 1) + if int(num_frames) % 2 == 1 + else math.ceil(int(num_frames) / self.vae.vae_scale_factor[0]), + math.ceil(int(height) / self.vae.vae_scale_factor[1]), + math.ceil(int(width) / self.vae.vae_scale_factor[2]), + ) + if isinstance(generator, list) and len(generator) != batch_size: + raise ValueError( + f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" + f" size of {batch_size}. Make sure the batch size matches the length of the generators." + ) + + if latents is None: + latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype) + else: + latents = latents.to(device) + + # scale the initial noise by the standard deviation required by the scheduler + latents = latents * self.scheduler.init_noise_sigma + + + return latents + + @torch.no_grad() + def generate( + self, + prompt: Union[str, List[str]] = None, + negative_prompt: str = "", + num_inference_steps: int = 100, + timesteps: List[int] = None, + guidance_scale: float = 7.5, + num_images_per_prompt: Optional[int] = 1, + num_frames: Optional[int] = None, + height: Optional[int] = None, + width: Optional[int] = None, + eta: float = 0.0, + seed: int = -1, + generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, + latents: Optional[torch.FloatTensor] = None, + prompt_embeds: Optional[torch.FloatTensor] = None, + prompt_attention_mask: Optional[torch.FloatTensor] = None, + negative_prompt_embeds: Optional[torch.FloatTensor] = None, + negative_prompt_attention_mask: Optional[torch.FloatTensor] = None, + output_type: Optional[str] = "pil", + return_dict: bool = True, + callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None, + callback_steps: int = 1, + clean_caption: bool = True, + max_sequence_length: int = 512, + verbose: bool = True, + ) -> Union[VideoSysPipelineOutput, Tuple]: + """ + Function invoked when calling the pipeline for generation. + + Args: + prompt (`str` or `List[str]`, *optional*): + The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`. + instead. + negative_prompt (`str` or `List[str]`, *optional*): + The prompt or prompts not to guide the image generation. If not defined, one has to pass + `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is + less than `1`). + num_inference_steps (`int`, *optional*, defaults to 100): + The number of denoising steps. More denoising steps usually lead to a higher quality image at the + expense of slower inference. + timesteps (`List[int]`, *optional*): + Custom timesteps to use for the denoising process. If not defined, equal spaced `num_inference_steps` + timesteps are used. Must be in descending order. + guidance_scale (`float`, *optional*, defaults to 7.0): + Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598). + `guidance_scale` is defined as `w` of equation 2. of [Imagen + Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale > + 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`, + usually at the expense of lower image quality. + num_images_per_prompt (`int`, *optional*, defaults to 1): + The number of images to generate per prompt. + num_frames: (`int`, *optional*, defaults to 88): + The number controls the generated video frames. + height (`int`, *optional*, defaults to self.unet.config.sample_size): + The height in pixels of the generated image. + width (`int`, *optional*, defaults to self.unet.config.sample_size): + The width in pixels of the generated image. + eta (`float`, *optional*, defaults to 0.0): + Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to + [`schedulers.DDIMScheduler`], will be ignored for others. + generator (`torch.Generator` or `List[torch.Generator]`, *optional*): + One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html) + to make generation deterministic. + latents (`torch.FloatTensor`, *optional*): + Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image + generation. Can be used to tweak the same generation with different prompts. If not provided, a latents + tensor will ge generated by sampling using the supplied random `generator`. + prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not + provided, text embeddings will be generated from `prompt` input argument. + prompt_attention_mask (`torch.FloatTensor`, *optional*): Pre-generated attention mask for text embeddings. + negative_prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated negative text embeddings. For PixArt-Sigma this negative prompt should be "". If not + provided, negative_prompt_embeds will be generated from `negative_prompt` input argument. + negative_prompt_attention_mask (`torch.FloatTensor`, *optional*): + Pre-generated attention mask for negative text embeddings. + output_type (`str`, *optional*, defaults to `"pil"`): + The output format of the generate image. Choose between + [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`~pipelines.stable_diffusion.IFPipelineOutput`] instead of a plain tuple. + callback (`Callable`, *optional*): + A function that will be called every `callback_steps` steps during inference. The function will be + called with the following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`. + callback_steps (`int`, *optional*, defaults to 1): + The frequency at which the `callback` function will be called. If not specified, the callback will be + called at every step. + clean_caption (`bool`, *optional*, defaults to `True`): + Whether or not to clean the caption before creating embeddings. Requires `beautifulsoup4` and `ftfy` to + be installed. If the dependencies are not installed, the embeddings will be created from the raw + prompt. + max_sequence_length (`int` defaults to 512): Maximum sequence length to use with the `prompt`. + + Examples: + + Returns: + [`~pipelines.ImagePipelineOutput`] or `tuple`: + If `return_dict` is `True`, [`~pipelines.ImagePipelineOutput`] is returned, otherwise a `tuple` is + returned where the first element is a list with the generated images + """ + # 1. Check inputs. Raise error if not correct + num_frames = num_frames or self.transformer.config.sample_size_t * self.vae.vae_scale_factor[0] + height = height or self.transformer.config.sample_size[0] * self.vae.vae_scale_factor[1] + width = width or self.transformer.config.sample_size[1] * self.vae.vae_scale_factor[2] + update_steps(num_inference_steps) + self.check_inputs( + prompt, + num_frames, + height, + width, + negative_prompt, + callback_steps, + prompt_embeds, + negative_prompt_embeds, + prompt_attention_mask, + negative_prompt_attention_mask, + ) + self._set_seed(seed) + + # 2. Default height and width to transformer + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + device = self._execution_device + + # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2) + # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1` + # corresponds to doing no classifier free guidance. + do_classifier_free_guidance = guidance_scale > 1.0 + + # 3. Encode input prompt + ( + prompt_embeds, + prompt_attention_mask, + negative_prompt_embeds, + negative_prompt_attention_mask, + ) = self.encode_prompt( + prompt, + do_classifier_free_guidance, + negative_prompt=negative_prompt, + num_images_per_prompt=num_images_per_prompt, + device=device, + prompt_embeds=prompt_embeds, + negative_prompt_embeds=negative_prompt_embeds, + prompt_attention_mask=prompt_attention_mask, + negative_prompt_attention_mask=negative_prompt_attention_mask, + clean_caption=clean_caption, + max_sequence_length=max_sequence_length, + ) + if do_classifier_free_guidance: + prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0) + prompt_attention_mask = torch.cat([negative_prompt_attention_mask, prompt_attention_mask], dim=0) + + # 4. Prepare timesteps + timesteps, num_inference_steps = retrieve_timesteps(self.scheduler, num_inference_steps, device, timesteps) + self.scheduler.set_timesteps(num_inference_steps, device=device) + + # 5. Prepare latents. + latent_channels = self.transformer.config.in_channels + latents = self.prepare_latents( + batch_size * num_images_per_prompt, + latent_channels, + num_frames, + height, + width, + prompt_embeds.dtype, + device, + generator, + latents, + ) + + # 6. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline + extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta) + + # 6.1 Prepare micro-conditions. + added_cond_kwargs = {"resolution": None, "aspect_ratio": None} + + # 7. Denoising loop + num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0) + + progress_wrap = tqdm.tqdm if verbose and dist.get_rank() == 0 else (lambda x: x) + for i, t in progress_wrap(list(enumerate(timesteps))): + + latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents + latent_model_input = self.scheduler.scale_model_input(latent_model_input, t) + + current_timestep = t + if not torch.is_tensor(current_timestep): + # TODO: this requires sync between CPU and GPU. So try to pass timesteps as tensors if you can + # This would be a good case for the `match` statement (Python 3.10+) + is_mps = latent_model_input.device.type == "mps" + if isinstance(current_timestep, float): + dtype = torch.float32 if is_mps else torch.float64 + else: + dtype = torch.int32 if is_mps else torch.int64 + current_timestep = torch.tensor([current_timestep], dtype=dtype, device=latent_model_input.device) + elif len(current_timestep.shape) == 0: + current_timestep = current_timestep[None].to(latent_model_input.device) + # broadcast to batch dimension in a way that's compatible with ONNX/Core ML + current_timestep = current_timestep.expand(latent_model_input.shape[0]) + + if prompt_embeds.ndim == 3: + prompt_embeds = prompt_embeds.unsqueeze(1) # b l d -> b 1 l d + if prompt_attention_mask.ndim == 2: + prompt_attention_mask = prompt_attention_mask.unsqueeze(1) # b l -> b 1 l + # prepare attention_mask. + # b c t h w -> b t h w + attention_mask = torch.ones_like(latent_model_input)[:, 0] + # predict noise model_output + noise_pred = self.transformer( + latent_model_input, + attention_mask=attention_mask, + encoder_hidden_states=prompt_embeds, + encoder_attention_mask=prompt_attention_mask, + timestep=current_timestep, + all_timesteps=timesteps, + added_cond_kwargs=added_cond_kwargs, + return_dict=False, + )[0] + + # perform guidance + if do_classifier_free_guidance: + noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) + noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond) + + # learned sigma + if self.transformer.config.out_channels // 2 == latent_channels: + noise_pred = noise_pred.chunk(2, dim=1)[0] + else: + noise_pred = noise_pred + + # compute previous image: x_t -> x_t-1 + latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0] + + # call the callback, if provided + if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): + if callback is not None and i % callback_steps == 0: + step_idx = i // getattr(self.scheduler, "order", 1) + callback(step_idx, t, latents) + + if not output_type == "latents": + video = self.decode_latents(latents) + video = video[:, :num_frames, :height, :width] + else: + video = latents + return VideoSysPipelineOutput(video=video) + + # Offload all models + self.maybe_free_model_hooks() + + if not return_dict: + return (video,) + + return VideoSysPipelineOutput(video=video) + + def decode_latents(self, latents): + video = self.vae.decode(latents.to(self.vae.dtype) / self.vae.scale_factor).sample + # b t c h w -> b t h w c + video = ((video / 2.0 + 0.5).clamp(0, 1) * 255).to(dtype=torch.uint8).cpu().permute(0, 1, 3, 4, 2).contiguous() + return video + + def save_video(self, video, output_path): + # save_video(video, output_path, fps=15) + if dist.is_initialized() and dist.get_rank() != 0: + return + os.makedirs(os.path.dirname(output_path), exist_ok=True) + imageio.mimwrite(output_path, video, fps=15, quality=5) \ No newline at end of file