Add z-image-omni-base implementation #12857
base: main
Conversation
yiyixuxu left a comment:
Thanks a lot for the PR! I left some comments; mainly I'm trying to simplify the code in the transformer as much as possible by removing unused code paths, etc.
Let me know what you think :)
SEQ_MULTI_OF = 32
...
class TimestepEmbedder(nn.Module):
can we add a `# Copied from`?
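For reference, the diffusers `# Copied from` convention marks a class or function as a verbatim copy of an existing one so that `make fix-copies` can keep them in sync. A minimal sketch; the source path below is illustrative, not necessarily the one this code would reference:

```python
import torch.nn as nn

# Copied from diffusers.models.transformers.transformer_z_image.TimestepEmbedder
class TimestepEmbedder(nn.Module):
    """Annotated as a verbatim copy; the check tooling enforces that it stays identical to the source."""
    ...
```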
Fixed by merging into one transformer_z_image.
    return t_emb
...
class ZSingleStreamAttnProcessor:
can we add a `# Copied from`?
Same as before.
@maybe_allow_in_graph
class ZImageTransformerBlock(nn.Module):
Suggested change:
- class ZImageTransformerBlock(nn.Module):
+ class ZOmniImageTransformerBlock(nn.Module):
Ignored, since everything was merged into one transformer.
    adaln_clean: Optional[torch.Tensor] = None,
):
    if self.modulation:
        if noise_mask is not None and adaln_noisy is not None and adaln_clean is not None:
In the current codebase as of 4c14cf3 this check is needed, but it could be optimized with a redesign in the next PR.
    else:
        # Original global modulation
        assert adaln_input is not None
        scale_msa, gate_msa, scale_mlp, gate_mlp = self.adaLN_modulation(adaln_input).unsqueeze(1).chunk(4, dim=2)
        gate_msa, gate_mlp = gate_msa.tanh(), gate_mlp.tanh()
        scale_msa, scale_mlp = 1.0 + scale_msa, 1.0 + scale_mlp
can we remove this code path if it is not used?
After merging into one transformer, this path is needed.
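For context, a minimal sketch of how the two modulation paths can coexist in the merged block. The function name, signature, and the `torch.where`-based per-token selection are assumptions for illustration, not the PR's exact code:

```python
from typing import Optional

import torch
import torch.nn as nn


def compute_modulation(
    adaLN_modulation: nn.Module,                 # projection producing 4 * dim features
    adaln_input: Optional[torch.Tensor],         # (B, dim) global conditioning
    adaln_noisy: Optional[torch.Tensor] = None,  # (B, dim) conditioning for noised tokens
    adaln_clean: Optional[torch.Tensor] = None,  # (B, dim) conditioning for clean tokens
    noise_mask: Optional[torch.Tensor] = None,   # (B, L), 1 where a token is noised
):
    if noise_mask is not None and adaln_noisy is not None and adaln_clean is not None:
        # Per-token path: each token picks the noisy or the clean modulation params.
        mod_noisy = adaLN_modulation(adaln_noisy).unsqueeze(1)  # (B, 1, 4*dim)
        mod_clean = adaLN_modulation(adaln_clean).unsqueeze(1)  # (B, 1, 4*dim)
        mod = torch.where(noise_mask.unsqueeze(-1).bool(), mod_noisy, mod_clean)  # (B, L, 4*dim)
    else:
        # Global path: one modulation vector broadcast over the whole sequence.
        mod = adaLN_modulation(adaln_input).unsqueeze(1)  # (B, 1, 4*dim)
    scale_msa, gate_msa, scale_mlp, gate_mlp = mod.chunk(4, dim=2)
    gate_msa, gate_mlp = gate_msa.tanh(), gate_mlp.tanh()
    return 1.0 + scale_msa, gate_msa, 1.0 + scale_mlp, gate_mlp
```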
    patch_size=2,
    f_patch_size=1,
I don't think these two arguments are used in the pipeline; can we remove them? That could simplify the code a lot, I think, and it would also help remove the ModuleDict pattern.
Same as #12857 (comment)
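To illustrate the suggested simplification, a hypothetical before/after sketch; the class names, dimensions, and the exact ModuleDict layout are assumptions, not the PR's actual modules:

```python
import torch.nn as nn


# Before (assumed): one embedder per supported patch size, picked at runtime.
class PatchEmbedderDict(nn.Module):
    def __init__(self, in_channels=16, dim=1024, all_patch_size=(1, 2)):
        super().__init__()
        self.proj = nn.ModuleDict(
            {str(p): nn.Linear(in_channels * p * p, dim) for p in all_patch_size}
        )

    def forward(self, x, patch_size):
        return self.proj[str(patch_size)](x)


# After: a single fixed patch size, so one Linear embedder suffices.
class PatchEmbedder(nn.Module):
    def __init__(self, in_channels=16, dim=1024, patch_size=2):
        super().__init__()
        self.proj = nn.Linear(in_channels * patch_size * patch_size, dim)

    def forward(self, x):
        return self.proj(x)
```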
assert patch_size in self.all_patch_size
assert f_patch_size in self.all_f_patch_size
Same as #12857 (comment)
    cap_noise_mask,
    siglip_noise_mask
) = self.patchify_and_embed(
    x, cap_feats, siglip_feats, patch_size, f_patch_size, image_noise_mask
Suggested change:
-     x, cap_feats, siglip_feats, patch_size, f_patch_size, image_noise_mask
+     x, cap_feats, siglip_feats, image_noise_mask
    grids = torch.meshgrid(axes, indexing="ij")
    return torch.stack(grids, dim=-1)

def patchify_and_embed(
this method is really hard to follow here, do you think it's possible to break it into 3?
like

for x, cap_feat, siglip_feat in zip(all_x, all_cap_feats, all_siglip_feats):
    cap_item_cu_len = 1
    cap_padded, ..., cap_item_cu_len = self.patchify_and_embed_cap(...)
    all_cap_padded.append(cap_padded)
    x_padded, ..., cap_item_cu_len = self.patchify_and_embed_x(..., cap_item_cu_len)
    all_x_padded.append(x_padded)
    ...
    siglip_padded, ..., cap_item_cu_len = self.patchify_and_embed_siglip(..., cap_item_cu_len)
    all_siglip_padded.append(siglip_padded)
assert all(_ % SEQ_MULTI_OF == 0 for _ in x_item_seqlens)
x_max_item_seqlen = max(x_item_seqlens)
...
x = torch.cat(x, dim=0)
hopefully we can simplify to x = self.x_embedder(x) here
Same as #12857 (comment)
this gets forgotten all the time lol

diff --git a/src/diffusers/pipelines/auto_pipeline.py b/src/diffusers/pipelines/auto_pipeline.py
index db0268a2a..2c36ce36b 100644
--- a/src/diffusers/pipelines/auto_pipeline.py
+++ b/src/diffusers/pipelines/auto_pipeline.py
@@ -119,7 +119,7 @@ from .stable_diffusion_xl import (
)
from .wan import WanImageToVideoPipeline, WanPipeline, WanVideoToVideoPipeline
from .wuerstchen import WuerstchenCombinedPipeline, WuerstchenDecoderPipeline
-from .z_image import ZImageImg2ImgPipeline, ZImagePipeline
+from .z_image import ZImageImg2ImgPipeline, ZImageOmniPipeline, ZImagePipeline
AUTO_TEXT2IMAGE_PIPELINES_MAPPING = OrderedDict(
@@ -164,6 +164,7 @@ AUTO_TEXT2IMAGE_PIPELINES_MAPPING = OrderedDict(
("qwenimage", QwenImagePipeline),
("qwenimage-controlnet", QwenImageControlNetPipeline),
("z-image", ZImagePipeline),
+ ("z-image-omni", ZImageOmniPipeline),
]
)
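With the mapping entry in place, the new pipeline can also be resolved through the auto pipeline. A hedged usage sketch; the checkpoint id below is an assumption, not an official repo name:

```python
import torch
from diffusers import AutoPipelineForText2Image

# Assumed checkpoint id for illustration; the real repo id may differ.
pipe = AutoPipelineForText2Image.from_pretrained(
    "Tongyi-MAI/Z-Image-Omni-Base", torch_dtype=torch.bfloat16
).to("cuda")

image = pipe(prompt="a watercolor fox in a misty forest").images[0]
image.save("fox.png")
```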
Thanks for the useful comments yiyi, I'll review these and make the fixes today ~ 😊
Hi @yiyixuxu, this branch is ready to merge 😊. It should address most of your earlier concerns (including the `Copied from` annotations, cond_latents, auto_pipeline, and styling) by merging everything into one transformer model and incorporating the new features from the main branch on top of the starting point. More feature updates and code cleanup will come in another PR; you can review the current state and leave comments, and I'll push further updates asap ~ Thanks!!!
Thanks!! Fixed in 4c14cf3 ~
- Add select_per_token function for per-token value selection
- Separate adaptive modulation logic
- Clean up t_noisy/clean variable naming
- Move image_noise_mask handling from forward to the pipeline
(force-pushed from 70bc2c8 to 5bc676c)
Ready, let's merge it at 732c527 ~ 😊
What does this PR do?
This PR adds support for the Z-Image-Omni-Base model. Z-Image-Omni-Base is a foundation model designed for easy fine-tuning, unifying core capabilities in both image generation and editing to empower the community to explore custom development and innovative applications.
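For a sense of how it would be used, a hypothetical text-to-image sketch following the usual diffusers pipeline pattern. The top-level import, checkpoint id, and call arguments are assumptions; the editing inputs are not shown because their exact signature isn't covered in this excerpt:

```python
import torch
from diffusers import ZImageOmniPipeline  # assumed to be exported at the top level

pipe = ZImageOmniPipeline.from_pretrained(
    "Tongyi-MAI/Z-Image-Omni-Base",  # assumed checkpoint id
    torch_dtype=torch.bfloat16,
).to("cuda")

image = pipe(
    prompt="a cozy cabin in the snow at dusk",
    num_inference_steps=28,  # assumed value
    guidance_scale=4.0,      # assumed value
).images[0]
image.save("cabin.png")
```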
Who can review?
@yiyixuxu @apolinario @JerryWu-code