diff --git a/common.hpp b/common.hpp index 74b218ab7..b17c11e35 100644 --- a/common.hpp +++ b/common.hpp @@ -28,7 +28,7 @@ class DownSampleBlock : public GGMLBlock { if (vae_downsample) { auto conv = std::dynamic_pointer_cast(blocks["conv"]); - x = ggml_pad(ctx->ggml_ctx, x, 1, 1, 0, 0); + x = ggml_ext_pad(ctx->ggml_ctx, x, 1, 1, 0, 0, ctx->circular_x_enabled, ctx->circular_y_enabled); x = conv->forward(ctx, x); } else { auto conv = std::dynamic_pointer_cast(blocks["op"]); diff --git a/denoiser.hpp b/denoiser.hpp index fc5230d7b..7a8242e7d 100644 --- a/denoiser.hpp +++ b/denoiser.hpp @@ -366,18 +366,18 @@ struct KLOptimalScheduler : SigmaScheduler { for (uint32_t i = 0; i < n; ++i) { // t goes from 0.0 to 1.0 - float t = static_cast(i) / static_cast(n-1); + float t = static_cast(i) / static_cast(n - 1); // Interpolate in the angle domain float angle = t * alpha_min + (1.0f - t) * alpha_max; // Convert back to sigma sigmas.push_back(std::tan(angle)); - } + } // Append the final zero to sigma sigmas.push_back(0.0f); - + return sigmas; } }; diff --git a/diffusion_model.hpp b/diffusion_model.hpp index 8c741fdc4..c4e0ba1d0 100644 --- a/diffusion_model.hpp +++ b/diffusion_model.hpp @@ -37,8 +37,9 @@ struct DiffusionModel { virtual void get_param_tensors(std::map& tensors) = 0; virtual size_t get_params_buffer_size() = 0; virtual void set_weight_adapter(const std::shared_ptr& adapter){}; - virtual int64_t get_adm_in_channels() = 0; - virtual void set_flash_attn_enabled(bool enabled) = 0; + virtual int64_t get_adm_in_channels() = 0; + virtual void set_flash_attn_enabled(bool enabled) = 0; + virtual void set_circular_axes(bool circular_x, bool circular_y) = 0; }; struct UNetModel : public DiffusionModel { @@ -87,6 +88,10 @@ struct UNetModel : public DiffusionModel { unet.set_flash_attention_enabled(enabled); } + void set_circular_axes(bool circular_x, bool circular_y) override { + unet.set_circular_axes(circular_x, circular_y); + } + bool compute(int n_threads, DiffusionParams diffusion_params, struct ggml_tensor** output = nullptr, @@ -148,6 +153,10 @@ struct MMDiTModel : public DiffusionModel { mmdit.set_flash_attention_enabled(enabled); } + void set_circular_axes(bool circular_x, bool circular_y) override { + mmdit.set_circular_axes(circular_x, circular_y); + } + bool compute(int n_threads, DiffusionParams diffusion_params, struct ggml_tensor** output = nullptr, @@ -210,6 +219,10 @@ struct FluxModel : public DiffusionModel { flux.set_flash_attention_enabled(enabled); } + void set_circular_axes(bool circular_x, bool circular_y) override { + flux.set_circular_axes(circular_x, circular_y); + } + bool compute(int n_threads, DiffusionParams diffusion_params, struct ggml_tensor** output = nullptr, @@ -277,6 +290,10 @@ struct WanModel : public DiffusionModel { wan.set_flash_attention_enabled(enabled); } + void set_circular_axes(bool circular_x, bool circular_y) override { + wan.set_circular_axes(circular_x, circular_y); + } + bool compute(int n_threads, DiffusionParams diffusion_params, struct ggml_tensor** output = nullptr, @@ -343,6 +360,10 @@ struct QwenImageModel : public DiffusionModel { qwen_image.set_flash_attention_enabled(enabled); } + void set_circular_axes(bool circular_x, bool circular_y) override { + qwen_image.set_circular_axes(circular_x, circular_y); + } + bool compute(int n_threads, DiffusionParams diffusion_params, struct ggml_tensor** output = nullptr, @@ -406,6 +427,10 @@ struct ZImageModel : public DiffusionModel { z_image.set_flash_attention_enabled(enabled); } + void 
set_circular_axes(bool circular_x, bool circular_y) override { + z_image.set_circular_axes(circular_x, circular_y); + } + bool compute(int n_threads, DiffusionParams diffusion_params, struct ggml_tensor** output = nullptr, diff --git a/examples/common/common.hpp b/examples/common/common.hpp index b81dab784..5168730aa 100644 --- a/examples/common/common.hpp +++ b/examples/common/common.hpp @@ -449,6 +449,10 @@ struct SDContextParams { bool diffusion_conv_direct = false; bool vae_conv_direct = false; + bool circular = false; + bool circular_x = false; + bool circular_y = false; + bool chroma_use_dit_mask = true; bool chroma_use_t5_mask = false; int chroma_t5_mask_pad = 1; @@ -605,6 +609,18 @@ struct SDContextParams { "--vae-conv-direct", "use ggml_conv2d_direct in the vae model", true, &vae_conv_direct}, + {"", + "--circular", + "enable circular padding for convolutions", + true, &circular}, + {"", + "--circularx", + "enable circular RoPE wrapping on x-axis (width) only", + true, &circular_x}, + {"", + "--circulary", + "enable circular RoPE wrapping on y-axis (height) only", + true, &circular_y}, {"", "--chroma-disable-dit-mask", "disable dit mask for chroma", @@ -868,6 +884,9 @@ struct SDContextParams { << " diffusion_flash_attn: " << (diffusion_flash_attn ? "true" : "false") << ",\n" << " diffusion_conv_direct: " << (diffusion_conv_direct ? "true" : "false") << ",\n" << " vae_conv_direct: " << (vae_conv_direct ? "true" : "false") << ",\n" + << " circular: " << (circular ? "true" : "false") << ",\n" + << " circular_x: " << (circular_x ? "true" : "false") << ",\n" + << " circular_y: " << (circular_y ? "true" : "false") << ",\n" << " chroma_use_dit_mask: " << (chroma_use_dit_mask ? "true" : "false") << ",\n" << " chroma_use_t5_mask: " << (chroma_use_t5_mask ? 
"true" : "false") << ",\n" << " chroma_t5_mask_pad: " << chroma_t5_mask_pad << ",\n" @@ -928,6 +947,8 @@ struct SDContextParams { taesd_preview, diffusion_conv_direct, vae_conv_direct, + circular || circular_x, + circular || circular_y, force_sdxl_vae_conv_scale, chroma_use_dit_mask, chroma_use_t5_mask, diff --git a/flux.hpp b/flux.hpp index 7ce263569..bcff04cfa 100644 --- a/flux.hpp +++ b/flux.hpp @@ -860,14 +860,14 @@ namespace Flux { } } - struct ggml_tensor* pad_to_patch_size(struct ggml_context* ctx, + struct ggml_tensor* pad_to_patch_size(GGMLRunnerContext* ctx, struct ggml_tensor* x) { int64_t W = x->ne[0]; int64_t H = x->ne[1]; int pad_h = (params.patch_size - H % params.patch_size) % params.patch_size; int pad_w = (params.patch_size - W % params.patch_size) % params.patch_size; - x = ggml_pad(ctx, x, pad_w, pad_h, 0, 0); // [N, C, H + pad_h, W + pad_w] + x = ggml_ext_pad(ctx->ggml_ctx, x, pad_w, pad_h, 0, 0, ctx->circular_x_enabled, ctx->circular_y_enabled); return x; } @@ -893,11 +893,11 @@ namespace Flux { return x; } - struct ggml_tensor* process_img(struct ggml_context* ctx, + struct ggml_tensor* process_img(GGMLRunnerContext* ctx, struct ggml_tensor* x) { // img = rearrange(x, "b c (h ph) (w pw) -> b (h w) (c ph pw)", ph=patch_size, pw=patch_size) x = pad_to_patch_size(ctx, x); - x = patchify(ctx, x); + x = patchify(ctx->ggml_ctx, x); return x; } @@ -1076,7 +1076,7 @@ namespace Flux { int pad_h = (patch_size - H % patch_size) % patch_size; int pad_w = (patch_size - W % patch_size) % patch_size; - auto img = pad_to_patch_size(ctx->ggml_ctx, x); + auto img = pad_to_patch_size(ctx, x); auto orig_img = img; if (params.chroma_radiance_params.use_patch_size_32) { @@ -1150,7 +1150,7 @@ namespace Flux { int pad_h = (patch_size - H % patch_size) % patch_size; int pad_w = (patch_size - W % patch_size) % patch_size; - auto img = process_img(ctx->ggml_ctx, x); + auto img = process_img(ctx, x); uint64_t img_tokens = img->ne[1]; if (params.version == VERSION_FLUX_FILL) { @@ -1158,8 +1158,8 @@ namespace Flux { ggml_tensor* masked = ggml_view_4d(ctx->ggml_ctx, c_concat, c_concat->ne[0], c_concat->ne[1], C, 1, c_concat->nb[1], c_concat->nb[2], c_concat->nb[3], 0); ggml_tensor* mask = ggml_view_4d(ctx->ggml_ctx, c_concat, c_concat->ne[0], c_concat->ne[1], 8 * 8, 1, c_concat->nb[1], c_concat->nb[2], c_concat->nb[3], c_concat->nb[2] * C); - masked = process_img(ctx->ggml_ctx, masked); - mask = process_img(ctx->ggml_ctx, mask); + masked = process_img(ctx, masked); + mask = process_img(ctx, mask); img = ggml_concat(ctx->ggml_ctx, img, ggml_concat(ctx->ggml_ctx, masked, mask, 0), 0); } else if (params.version == VERSION_FLEX_2) { @@ -1168,21 +1168,21 @@ namespace Flux { ggml_tensor* mask = ggml_view_4d(ctx->ggml_ctx, c_concat, c_concat->ne[0], c_concat->ne[1], 1, 1, c_concat->nb[1], c_concat->nb[2], c_concat->nb[3], c_concat->nb[2] * C); ggml_tensor* control = ggml_view_4d(ctx->ggml_ctx, c_concat, c_concat->ne[0], c_concat->ne[1], C, 1, c_concat->nb[1], c_concat->nb[2], c_concat->nb[3], c_concat->nb[2] * (C + 1)); - masked = process_img(ctx->ggml_ctx, masked); - mask = process_img(ctx->ggml_ctx, mask); - control = process_img(ctx->ggml_ctx, control); + masked = process_img(ctx, masked); + mask = process_img(ctx, mask); + control = process_img(ctx, control); img = ggml_concat(ctx->ggml_ctx, img, ggml_concat(ctx->ggml_ctx, ggml_concat(ctx->ggml_ctx, masked, mask, 0), control, 0), 0); } else if (params.version == VERSION_FLUX_CONTROLS) { GGML_ASSERT(c_concat != nullptr); - auto control = 
process_img(ctx->ggml_ctx, c_concat); + auto control = process_img(ctx, c_concat); img = ggml_concat(ctx->ggml_ctx, img, control, 0); } if (ref_latents.size() > 0) { for (ggml_tensor* ref : ref_latents) { - ref = process_img(ctx->ggml_ctx, ref); + ref = process_img(ctx, ref); img = ggml_concat(ctx->ggml_ctx, img, ref, 1); } } @@ -1472,6 +1472,8 @@ namespace Flux { increase_ref_index, flux_params.ref_index_scale, flux_params.theta, + circular_y_enabled, + circular_x_enabled, flux_params.axes_dim); int pos_len = pe_vec.size() / flux_params.axes_dim_sum / 2; // LOG_DEBUG("pos_len %d", pos_len); diff --git a/ggml b/ggml index f5425c0ee..3e9f2ba3b 160000 --- a/ggml +++ b/ggml @@ -1 +1 @@ -Subproject commit f5425c0ee5e582a7d64411f06139870bff3e52e0 +Subproject commit 3e9f2ba3b934c20b26873b3c60dbf41b116978ff diff --git a/ggml_extend.hpp b/ggml_extend.hpp index fcaa92c9e..3849562f0 100644 --- a/ggml_extend.hpp +++ b/ggml_extend.hpp @@ -5,6 +5,7 @@ #include #include #include +#include #include #include #include @@ -993,6 +994,48 @@ __STATIC_INLINE__ struct ggml_tensor* ggml_ext_linear(struct ggml_context* ctx, return x; } +__STATIC_INLINE__ struct ggml_tensor* ggml_ext_pad_ext(struct ggml_context* ctx, + struct ggml_tensor* x, + int lp0, + int rp0, + int lp1, + int rp1, + int lp2, + int rp2, + int lp3, + int rp3, + bool circular_x = false, + bool circular_y = false) { + if (circular_x && circular_y) { + return ggml_pad_ext_circular(ctx, x, lp0, rp0, lp1, rp1, lp2, rp2, lp3, rp3); + } + + if (circular_x && (lp0 != 0 || rp0 != 0)) { + x = ggml_pad_ext_circular(ctx, x, lp0, rp0, 0, 0, 0, 0, 0, 0); + lp0 = rp0 = 0; + } + if (circular_y && (lp1 != 0 || rp1 != 0)) { + x = ggml_pad_ext_circular(ctx, x, 0, 0, lp1, rp1, 0, 0, 0, 0); + lp1 = rp1 = 0; + } + + if (lp0 != 0 || rp0 != 0 || lp1 != 0 || rp1 != 0 || lp2 != 0 || rp2 != 0 || lp3 != 0 || rp3 != 0) { + x = ggml_pad_ext(ctx, x, lp0, rp0, lp1, rp1, lp2, rp2, lp3, rp3); + } + return x; +} + +__STATIC_INLINE__ struct ggml_tensor* ggml_ext_pad(struct ggml_context* ctx, + struct ggml_tensor* x, + int p0, + int p1, + int p2 = 0, + int p3 = 0, + bool circular_x = false, + bool circular_y = false) { + return ggml_ext_pad_ext(ctx, x, p0, p0, p1, p1, p2, p2, p3, p3, circular_x, circular_y); +} + // w: [OC,IC, KH, KW] // x: [N, IC, IH, IW] // b: [OC,] @@ -1001,20 +1044,29 @@ __STATIC_INLINE__ struct ggml_tensor* ggml_ext_conv_2d(struct ggml_context* ctx, struct ggml_tensor* x, struct ggml_tensor* w, struct ggml_tensor* b, - int s0 = 1, - int s1 = 1, - int p0 = 0, - int p1 = 0, - int d0 = 1, - int d1 = 1, - bool direct = false, - float scale = 1.f) { + int s0 = 1, + int s1 = 1, + int p0 = 0, + int p1 = 0, + int d0 = 1, + int d1 = 1, + bool direct = false, + bool circular_x = false, + bool circular_y = false, + float scale = 1.f) { if (scale != 1.f) { x = ggml_scale(ctx, x, scale); } if (w->ne[2] != x->ne[2] && ggml_n_dims(w) == 2) { w = ggml_reshape_4d(ctx, w, 1, 1, w->ne[0], w->ne[1]); } + + if ((p0 != 0 || p1 != 0) && (circular_x || circular_y)) { + x = ggml_ext_pad(ctx, x, p0, p1, 0, 0, circular_x, circular_y); + p0 = 0; + p1 = 0; + } + if (direct) { x = ggml_conv_2d_direct(ctx, w, x, s0, s1, p0, p1, d0, d1); } else { @@ -1521,14 +1573,16 @@ struct WeightAdapter { float scale = 1.f; } linear; struct { - int s0 = 1; - int s1 = 1; - int p0 = 0; - int p1 = 0; - int d0 = 1; - int d1 = 1; - bool direct = false; - float scale = 1.f; + int s0 = 1; + int s1 = 1; + int p0 = 0; + int p1 = 0; + int d0 = 1; + int d1 = 1; + bool direct = false; + bool circular_x = false; + 
bool circular_y = false; + float scale = 1.f; } conv2d; }; virtual ggml_tensor* patch_weight(ggml_context* ctx, ggml_tensor* weight, const std::string& weight_name) = 0; @@ -1546,6 +1600,8 @@ struct GGMLRunnerContext { ggml_context* ggml_ctx = nullptr; bool flash_attn_enabled = false; bool conv2d_direct_enabled = false; + bool circular_x_enabled = false; + bool circular_y_enabled = false; std::shared_ptr weight_adapter = nullptr; }; @@ -1582,6 +1638,8 @@ struct GGMLRunner { bool flash_attn_enabled = false; bool conv2d_direct_enabled = false; + bool circular_x_enabled = false; + bool circular_y_enabled = false; void alloc_params_ctx() { struct ggml_init_params params; @@ -1859,6 +1917,8 @@ struct GGMLRunner { runner_ctx.backend = runtime_backend; runner_ctx.flash_attn_enabled = flash_attn_enabled; runner_ctx.conv2d_direct_enabled = conv2d_direct_enabled; + runner_ctx.circular_x_enabled = circular_x_enabled; + runner_ctx.circular_y_enabled = circular_y_enabled; runner_ctx.weight_adapter = weight_adapter; return runner_ctx; } @@ -2003,6 +2063,11 @@ struct GGMLRunner { conv2d_direct_enabled = enabled; } + void set_circular_axes(bool circular_x, bool circular_y) { + circular_x_enabled = circular_x; + circular_y_enabled = circular_y; + } + void set_weight_adapter(const std::shared_ptr& adapter) { weight_adapter = adapter; } @@ -2266,15 +2331,17 @@ class Conv2d : public UnaryBlock { } if (ctx->weight_adapter) { WeightAdapter::ForwardParams forward_params; - forward_params.op_type = WeightAdapter::ForwardParams::op_type_t::OP_CONV2D; - forward_params.conv2d.s0 = stride.second; - forward_params.conv2d.s1 = stride.first; - forward_params.conv2d.p0 = padding.second; - forward_params.conv2d.p1 = padding.first; - forward_params.conv2d.d0 = dilation.second; - forward_params.conv2d.d1 = dilation.first; - forward_params.conv2d.direct = ctx->conv2d_direct_enabled; - forward_params.conv2d.scale = scale; + forward_params.op_type = WeightAdapter::ForwardParams::op_type_t::OP_CONV2D; + forward_params.conv2d.s0 = stride.second; + forward_params.conv2d.s1 = stride.first; + forward_params.conv2d.p0 = padding.second; + forward_params.conv2d.p1 = padding.first; + forward_params.conv2d.d0 = dilation.second; + forward_params.conv2d.d1 = dilation.first; + forward_params.conv2d.direct = ctx->conv2d_direct_enabled; + forward_params.conv2d.circular_x = ctx->circular_x_enabled; + forward_params.conv2d.circular_y = ctx->circular_y_enabled; + forward_params.conv2d.scale = scale; return ctx->weight_adapter->forward_with_lora(ctx->ggml_ctx, x, w, b, prefix, forward_params); } return ggml_ext_conv_2d(ctx->ggml_ctx, @@ -2288,6 +2355,8 @@ class Conv2d : public UnaryBlock { dilation.second, dilation.first, ctx->conv2d_direct_enabled, + ctx->circular_x_enabled, + ctx->circular_y_enabled, scale); } }; diff --git a/lora.hpp b/lora.hpp index b847f044c..7d83ec5cd 100644 --- a/lora.hpp +++ b/lora.hpp @@ -599,6 +599,8 @@ struct LoraModel : public GGMLRunner { forward_params.conv2d.d0, forward_params.conv2d.d1, forward_params.conv2d.direct, + forward_params.conv2d.circular_x, + forward_params.conv2d.circular_y, forward_params.conv2d.scale); if (lora_mid) { lx = ggml_ext_conv_2d(ctx, @@ -612,6 +614,8 @@ struct LoraModel : public GGMLRunner { 1, 1, forward_params.conv2d.direct, + forward_params.conv2d.circular_x, + forward_params.conv2d.circular_y, forward_params.conv2d.scale); } lx = ggml_ext_conv_2d(ctx, @@ -625,6 +629,8 @@ struct LoraModel : public GGMLRunner { 1, 1, forward_params.conv2d.direct, + forward_params.conv2d.circular_x, + 
forward_params.conv2d.circular_y, forward_params.conv2d.scale); } @@ -779,6 +785,8 @@ struct MultiLoraAdapter : public WeightAdapter { forward_params.conv2d.d0, forward_params.conv2d.d1, forward_params.conv2d.direct, + forward_params.conv2d.circular_x, + forward_params.conv2d.circular_y, forward_params.conv2d.scale); } for (auto& lora_model : lora_models) { diff --git a/mmdit.hpp b/mmdit.hpp index 38bdc2e74..eeb74a268 100644 --- a/mmdit.hpp +++ b/mmdit.hpp @@ -983,4 +983,4 @@ struct MMDiTRunner : public GGMLRunner { } }; -#endif \ No newline at end of file +#endif diff --git a/qwen_image.hpp b/qwen_image.hpp index eeb823d50..ed1c98308 100644 --- a/qwen_image.hpp +++ b/qwen_image.hpp @@ -354,14 +354,14 @@ namespace Qwen { blocks["proj_out"] = std::shared_ptr(new Linear(inner_dim, params.patch_size * params.patch_size * params.out_channels)); } - struct ggml_tensor* pad_to_patch_size(struct ggml_context* ctx, + struct ggml_tensor* pad_to_patch_size(GGMLRunnerContext* ctx, struct ggml_tensor* x) { int64_t W = x->ne[0]; int64_t H = x->ne[1]; int pad_h = (params.patch_size - H % params.patch_size) % params.patch_size; int pad_w = (params.patch_size - W % params.patch_size) % params.patch_size; - x = ggml_pad(ctx, x, pad_w, pad_h, 0, 0); // [N, C, H + pad_h, W + pad_w] + x = ggml_ext_pad(ctx->ggml_ctx, x, pad_w, pad_h, 0, 0, ctx->circular_x_enabled, ctx->circular_y_enabled); return x; } @@ -387,10 +387,10 @@ namespace Qwen { return x; } - struct ggml_tensor* process_img(struct ggml_context* ctx, + struct ggml_tensor* process_img(GGMLRunnerContext* ctx, struct ggml_tensor* x) { x = pad_to_patch_size(ctx, x); - x = patchify(ctx, x); + x = patchify(ctx->ggml_ctx, x); return x; } @@ -466,12 +466,12 @@ namespace Qwen { int64_t C = x->ne[2]; int64_t N = x->ne[3]; - auto img = process_img(ctx->ggml_ctx, x); + auto img = process_img(ctx, x); uint64_t img_tokens = img->ne[1]; if (ref_latents.size() > 0) { for (ggml_tensor* ref : ref_latents) { - ref = process_img(ctx->ggml_ctx, ref); + ref = process_img(ctx, ref); img = ggml_concat(ctx->ggml_ctx, img, ref, 1); } } @@ -565,6 +565,8 @@ namespace Qwen { ref_latents, increase_ref_index, qwen_image_params.theta, + circular_y_enabled, + circular_x_enabled, qwen_image_params.axes_dim); int pos_len = pe_vec.size() / qwen_image_params.axes_dim_sum / 2; // LOG_DEBUG("pos_len %d", pos_len); @@ -684,4 +686,4 @@ namespace Qwen { } // namespace name -#endif // __QWEN_IMAGE_HPP__ \ No newline at end of file +#endif // __QWEN_IMAGE_HPP__ diff --git a/rope.hpp b/rope.hpp index 12047e3e9..4e6136c11 100644 --- a/rope.hpp +++ b/rope.hpp @@ -1,6 +1,8 @@ #ifndef __ROPE_HPP__ #define __ROPE_HPP__ +#include +#include #include #include "ggml_extend.hpp" @@ -39,7 +41,10 @@ namespace Rope { return flat_vec; } - __STATIC_INLINE__ std::vector> rope(const std::vector& pos, int dim, int theta) { + __STATIC_INLINE__ std::vector> rope(const std::vector& pos, + int dim, + int theta, + const std::vector& axis_wrap_dims = {}) { assert(dim % 2 == 0); int half_dim = dim / 2; @@ -47,14 +52,31 @@ namespace Rope { std::vector omega(half_dim); for (int i = 0; i < half_dim; ++i) { - omega[i] = 1.0 / std::pow(theta, scale[i]); + omega[i] = 1.0f / std::pow(theta, scale[i]); } int pos_size = pos.size(); std::vector> out(pos_size, std::vector(half_dim)); for (int i = 0; i < pos_size; ++i) { for (int j = 0; j < half_dim; ++j) { - out[i][j] = pos[i] * omega[j]; + float angle = pos[i] * omega[j]; + if (!axis_wrap_dims.empty()) { + size_t wrap_size = axis_wrap_dims.size(); + // mod batch size since we 
only store this for one item in the batch + size_t wrap_idx = wrap_size > 0 ? (i % wrap_size) : 0; + int wrap_dim = axis_wrap_dims[wrap_idx]; + if (wrap_dim > 0) { + constexpr float TWO_PI = 6.28318530717958647692f; + float cycles = omega[j] * wrap_dim / TWO_PI; + // closest periodic harmonic, necessary to ensure things neatly tile + // without this round, things don't tile at the boundaries and you end up + // with the model knowing what is "center" + float rounded = std::round(cycles); + angle = pos[i] * TWO_PI * rounded / wrap_dim; + } + } + + out[i][j] = angle; } } @@ -89,9 +111,9 @@ namespace Rope { int patch_size, int bs, int axes_dim_num, - int index = 0, - int h_offset = 0, - int w_offset = 0, + int index = 0, + int h_offset = 0, + int w_offset = 0, bool scale_rope = false) { int h_len = (h + (patch_size / 2)) / patch_size; int w_len = (w + (patch_size / 2)) / patch_size; @@ -146,7 +168,8 @@ namespace Rope { __STATIC_INLINE__ std::vector embed_nd(const std::vector>& ids, int bs, int theta, - const std::vector& axes_dim) { + const std::vector& axes_dim, + const std::vector>& wrap_dims = {}) { std::vector> trans_ids = transpose(ids); size_t pos_len = ids.size() / bs; int num_axes = axes_dim.size(); @@ -161,7 +184,12 @@ namespace Rope { std::vector> emb(bs * pos_len, std::vector(emb_dim * 2 * 2, 0.0)); int offset = 0; for (int i = 0; i < num_axes; ++i) { - std::vector> rope_emb = rope(trans_ids[i], axes_dim[i], theta); // [bs*pos_len, axes_dim[i]/2 * 2 * 2] + std::vector axis_wrap_dims; + if (!wrap_dims.empty() && i < (int)wrap_dims.size()) { + axis_wrap_dims = wrap_dims[i]; + } + std::vector> rope_emb = + rope(trans_ids[i], axes_dim[i], theta, axis_wrap_dims); // [bs*pos_len, axes_dim[i]/2 * 2 * 2] for (int b = 0; b < bs; ++b) { for (int j = 0; j < pos_len; ++j) { for (int k = 0; k < rope_emb[0].size(); ++k) { @@ -251,6 +279,8 @@ namespace Rope { bool increase_ref_index, float ref_index_scale, int theta, + bool circular_h, + bool circular_w, const std::vector& axes_dim) { std::vector> ids = gen_flux_ids(h, w, @@ -262,7 +292,47 @@ namespace Rope { ref_latents, increase_ref_index, ref_index_scale); - return embed_nd(ids, bs, theta, axes_dim); + std::vector> wrap_dims; + if ((circular_h || circular_w) && bs > 0 && axes_dim.size() >= 3) { + int h_len = (h + (patch_size / 2)) / patch_size; + int w_len = (w + (patch_size / 2)) / patch_size; + if (h_len > 0 && w_len > 0) { + size_t pos_len = ids.size() / bs; + wrap_dims.assign(axes_dim.size(), std::vector(pos_len, 0)); + size_t cursor = context_len; // text first + const size_t img_tokens = static_cast(h_len) * static_cast(w_len); + for (size_t token_i = 0; token_i < img_tokens; ++token_i) { + if (circular_h) { + wrap_dims[1][cursor + token_i] = h_len; + } + if (circular_w) { + wrap_dims[2][cursor + token_i] = w_len; + } + } + cursor += img_tokens; + // reference latents + for (ggml_tensor* ref : ref_latents) { + if (ref == nullptr) { + continue; + } + int ref_h = static_cast(ref->ne[1]); + int ref_w = static_cast(ref->ne[0]); + int ref_h_l = (ref_h + (patch_size / 2)) / patch_size; + int ref_w_l = (ref_w + (patch_size / 2)) / patch_size; + size_t ref_tokens = static_cast(ref_h_l) * static_cast(ref_w_l); + for (size_t token_i = 0; token_i < ref_tokens; ++token_i) { + if (circular_h) { + wrap_dims[1][cursor + token_i] = ref_h_l; + } + if (circular_w) { + wrap_dims[2][cursor + token_i] = ref_w_l; + } + } + cursor += ref_tokens; + } + } + } + return embed_nd(ids, bs, theta, axes_dim, wrap_dims); } __STATIC_INLINE__ std::vector> 
gen_qwen_image_ids(int h, @@ -301,9 +371,57 @@ namespace Rope { const std::vector& ref_latents, bool increase_ref_index, int theta, + bool circular_h, + bool circular_w, const std::vector& axes_dim) { std::vector> ids = gen_qwen_image_ids(h, w, patch_size, bs, context_len, ref_latents, increase_ref_index); - return embed_nd(ids, bs, theta, axes_dim); + std::vector> wrap_dims; + // This logic simply stores the (pad and patch_adjusted) sizes of images so we can make sure rope correctly tiles + if ((circular_h || circular_w) && bs > 0 && axes_dim.size() >= 3) { + int pad_h = (patch_size - (h % patch_size)) % patch_size; + int pad_w = (patch_size - (w % patch_size)) % patch_size; + int h_len = (h + pad_h) / patch_size; + int w_len = (w + pad_w) / patch_size; + if (h_len > 0 && w_len > 0) { + const size_t total_tokens = ids.size(); + // Track per-token wrap lengths for the row/column axes so only spatial tokens become periodic. + wrap_dims.assign(axes_dim.size(), std::vector(total_tokens / bs, 0)); + size_t cursor = context_len; // ignore text tokens + const size_t img_tokens = static_cast(h_len) * static_cast(w_len); + for (size_t token_i = 0; token_i < img_tokens; ++token_i) { + if (circular_h) { + wrap_dims[1][cursor + token_i] = h_len; + } + if (circular_w) { + wrap_dims[2][cursor + token_i] = w_len; + } + } + cursor += img_tokens; + // For each reference image, store wrap sizes as well + for (ggml_tensor* ref : ref_latents) { + if (ref == nullptr) { + continue; + } + int ref_h = static_cast(ref->ne[1]); + int ref_w = static_cast(ref->ne[0]); + int ref_pad_h = (patch_size - (ref_h % patch_size)) % patch_size; + int ref_pad_w = (patch_size - (ref_w % patch_size)) % patch_size; + int ref_h_len = (ref_h + ref_pad_h) / patch_size; + int ref_w_len = (ref_w + ref_pad_w) / patch_size; + size_t ref_n_tokens = static_cast(ref_h_len) * static_cast(ref_w_len); + for (size_t token_i = 0; token_i < ref_n_tokens; ++token_i) { + if (circular_h) { + wrap_dims[1][cursor + token_i] = ref_h_len; + } + if (circular_w) { + wrap_dims[2][cursor + token_i] = ref_w_len; + } + } + cursor += ref_n_tokens; + } + } + } + return embed_nd(ids, bs, theta, axes_dim, wrap_dims); } __STATIC_INLINE__ std::vector> gen_vid_ids(int t, @@ -440,9 +558,33 @@ namespace Rope { const std::vector& ref_latents, bool increase_ref_index, int theta, + bool circular_h, + bool circular_w, const std::vector& axes_dim) { std::vector> ids = gen_z_image_ids(h, w, patch_size, bs, context_len, seq_multi_of, ref_latents, increase_ref_index); - return embed_nd(ids, bs, theta, axes_dim); + std::vector> wrap_dims; + if ((circular_h || circular_w) && bs > 0 && axes_dim.size() >= 3) { + int pad_h = (patch_size - (h % patch_size)) % patch_size; + int pad_w = (patch_size - (w % patch_size)) % patch_size; + int h_len = (h + pad_h) / patch_size; + int w_len = (w + pad_w) / patch_size; + if (h_len > 0 && w_len > 0) { + size_t pos_len = ids.size() / bs; + wrap_dims.assign(axes_dim.size(), std::vector(pos_len, 0)); + size_t cursor = context_len + bound_mod(context_len, seq_multi_of); // skip text (and its padding) + size_t img_tokens = static_cast(h_len) * static_cast(w_len); + for (size_t token_i = 0; token_i < img_tokens; ++token_i) { + if (circular_h) { + wrap_dims[1][cursor + token_i] = h_len; + } + if (circular_w) { + wrap_dims[2][cursor + token_i] = w_len; + } + } + } + } + + return embed_nd(ids, bs, theta, axes_dim, wrap_dims); } __STATIC_INLINE__ struct ggml_tensor* apply_rope(struct ggml_context* ctx, diff --git a/stable-diffusion.cpp 
b/stable-diffusion.cpp index 74519938d..24516a9bb 100644 --- a/stable-diffusion.cpp +++ b/stable-diffusion.cpp @@ -405,6 +405,10 @@ class StableDiffusionGGML { vae_decode_only = false; } + if (sd_ctx_params->circular_x || sd_ctx_params->circular_y) { + LOG_INFO("Using circular padding for convolutions"); + } + bool clip_on_cpu = sd_ctx_params->keep_clip_on_cpu; { @@ -705,6 +709,20 @@ class StableDiffusionGGML { } pmid_model->get_param_tensors(tensors, "pmid"); } + + diffusion_model->set_circular_axes(sd_ctx_params->circular_x, sd_ctx_params->circular_y); + if (high_noise_diffusion_model) { + high_noise_diffusion_model->set_circular_axes(sd_ctx_params->circular_x, sd_ctx_params->circular_y); + } + if (control_net) { + control_net->set_circular_axes(sd_ctx_params->circular_x, sd_ctx_params->circular_y); + } + if (first_stage_model) { + first_stage_model->set_circular_axes(sd_ctx_params->circular_x, sd_ctx_params->circular_y); + } + if (tae_first_stage) { + tae_first_stage->set_circular_axes(sd_ctx_params->circular_x, sd_ctx_params->circular_y); + } } struct ggml_init_params params; @@ -1519,7 +1537,7 @@ class StableDiffusionGGML { } std::vector skip_layers(guidance.slg.layers, guidance.slg.layers + guidance.slg.layer_count); - float cfg_scale = guidance.txt_cfg; + float cfg_scale = guidance.txt_cfg; if (cfg_scale < 1.f) { if (cfg_scale == 0.f) { // Diffusers follow the convention from the original paper @@ -2559,6 +2577,8 @@ void sd_ctx_params_init(sd_ctx_params_t* sd_ctx_params) { sd_ctx_params->keep_control_net_on_cpu = false; sd_ctx_params->keep_vae_on_cpu = false; sd_ctx_params->diffusion_flash_attn = false; + sd_ctx_params->circular_x = false; + sd_ctx_params->circular_y = false; sd_ctx_params->chroma_use_dit_mask = true; sd_ctx_params->chroma_use_t5_mask = false; sd_ctx_params->chroma_t5_mask_pad = 1; @@ -2598,6 +2618,8 @@ char* sd_ctx_params_to_str(const sd_ctx_params_t* sd_ctx_params) { "keep_control_net_on_cpu: %s\n" "keep_vae_on_cpu: %s\n" "diffusion_flash_attn: %s\n" + "circular_x: %s\n" + "circular_y: %s\n" "chroma_use_dit_mask: %s\n" "chroma_use_t5_mask: %s\n" "chroma_t5_mask_pad: %d\n", @@ -2627,6 +2649,8 @@ char* sd_ctx_params_to_str(const sd_ctx_params_t* sd_ctx_params) { BOOL_STR(sd_ctx_params->keep_control_net_on_cpu), BOOL_STR(sd_ctx_params->keep_vae_on_cpu), BOOL_STR(sd_ctx_params->diffusion_flash_attn), + BOOL_STR(sd_ctx_params->circular_x), + BOOL_STR(sd_ctx_params->circular_y), BOOL_STR(sd_ctx_params->chroma_use_dit_mask), BOOL_STR(sd_ctx_params->chroma_use_t5_mask), sd_ctx_params->chroma_t5_mask_pad); diff --git a/stable-diffusion.h b/stable-diffusion.h index adb65a1d2..30583ea13 100644 --- a/stable-diffusion.h +++ b/stable-diffusion.h @@ -189,6 +189,8 @@ typedef struct { bool tae_preview_only; bool diffusion_conv_direct; bool vae_conv_direct; + bool circular_x; + bool circular_y; bool force_sdxl_vae_conv_scale; bool chroma_use_dit_mask; bool chroma_use_t5_mask; diff --git a/wan.hpp b/wan.hpp index 75333bfe1..31ecf33f7 100644 --- a/wan.hpp +++ b/wan.hpp @@ -75,7 +75,7 @@ namespace WAN { lp2 -= (int)cache_x->ne[2]; } - x = ggml_pad_ext(ctx->ggml_ctx, x, lp0, rp0, lp1, rp1, lp2, rp2, 0, 0); + x = ggml_ext_pad_ext(ctx->ggml_ctx, x, lp0, rp0, lp1, rp1, lp2, rp2, 0, 0, ctx->circular_x_enabled, ctx->circular_y_enabled); return ggml_ext_conv_3d(ctx->ggml_ctx, x, w, b, in_channels, std::get<2>(stride), std::get<1>(stride), std::get<0>(stride), 0, 0, 0, @@ -206,9 +206,9 @@ namespace WAN { } else if (mode == "upsample3d") { x = ggml_upscale(ctx->ggml_ctx, x, 2, 
GGML_SCALE_MODE_NEAREST); } else if (mode == "downsample2d") { - x = ggml_pad(ctx->ggml_ctx, x, 1, 1, 0, 0); + x = ggml_ext_pad(ctx->ggml_ctx, x, 1, 1, 0, 0, ctx->circular_x_enabled, ctx->circular_y_enabled); } else if (mode == "downsample3d") { - x = ggml_pad(ctx->ggml_ctx, x, 1, 1, 0, 0); + x = ggml_ext_pad(ctx->ggml_ctx, x, 1, 1, 0, 0, ctx->circular_x_enabled, ctx->circular_y_enabled); } x = resample_1->forward(ctx, x); x = ggml_ext_cont(ctx->ggml_ctx, ggml_ext_torch_permute(ctx->ggml_ctx, x, 0, 1, 3, 2)); // (c, t, h, w) @@ -1826,7 +1826,7 @@ namespace WAN { } } - struct ggml_tensor* pad_to_patch_size(struct ggml_context* ctx, + struct ggml_tensor* pad_to_patch_size(GGMLRunnerContext* ctx, struct ggml_tensor* x) { int64_t W = x->ne[0]; int64_t H = x->ne[1]; @@ -1835,8 +1835,7 @@ namespace WAN { int pad_t = (std::get<0>(params.patch_size) - T % std::get<0>(params.patch_size)) % std::get<0>(params.patch_size); int pad_h = (std::get<1>(params.patch_size) - H % std::get<1>(params.patch_size)) % std::get<1>(params.patch_size); int pad_w = (std::get<2>(params.patch_size) - W % std::get<2>(params.patch_size)) % std::get<2>(params.patch_size); - x = ggml_pad(ctx, x, pad_w, pad_h, pad_t, 0); // [N*C, T + pad_t, H + pad_h, W + pad_w] - + x = ggml_ext_pad(ctx->ggml_ctx, x, pad_w, pad_h, pad_t, 0, ctx->circular_x_enabled, ctx->circular_y_enabled); return x; } @@ -1986,14 +1985,14 @@ namespace WAN { int64_t T = x->ne[2]; int64_t C = x->ne[3]; - x = pad_to_patch_size(ctx->ggml_ctx, x); + x = pad_to_patch_size(ctx, x); int64_t t_len = ((T + (std::get<0>(params.patch_size) / 2)) / std::get<0>(params.patch_size)); int64_t h_len = ((H + (std::get<1>(params.patch_size) / 2)) / std::get<1>(params.patch_size)); int64_t w_len = ((W + (std::get<2>(params.patch_size) / 2)) / std::get<2>(params.patch_size)); if (time_dim_concat != nullptr) { - time_dim_concat = pad_to_patch_size(ctx->ggml_ctx, time_dim_concat); + time_dim_concat = pad_to_patch_size(ctx, time_dim_concat); x = ggml_concat(ctx->ggml_ctx, x, time_dim_concat, 2); // [N*C, (T+pad_t) + (T2+pad_t2), H + pad_h, W + pad_w] t_len = ((x->ne[2] + (std::get<0>(params.patch_size) / 2)) / std::get<0>(params.patch_size)); } } diff --git a/z_image.hpp b/z_image.hpp index bc554f177..af8d57e04 100644 --- a/z_image.hpp +++ b/z_image.hpp @@ -324,14 +324,14 @@ namespace ZImage { blocks["final_layer"] = std::make_shared(z_image_params.hidden_size, z_image_params.patch_size, z_image_params.out_channels); } - struct ggml_tensor* pad_to_patch_size(struct ggml_context* ctx, + struct ggml_tensor* pad_to_patch_size(GGMLRunnerContext* ctx, struct ggml_tensor* x) { int64_t W = x->ne[0]; int64_t H = x->ne[1]; int pad_h = (z_image_params.patch_size - H % z_image_params.patch_size) % z_image_params.patch_size; int pad_w = (z_image_params.patch_size - W % z_image_params.patch_size) % z_image_params.patch_size; - x = ggml_pad(ctx, x, pad_w, pad_h, 0, 0); // [N, C, H + pad_h, W + pad_w] + x = ggml_ext_pad(ctx->ggml_ctx, x, pad_w, pad_h, 0, 0, ctx->circular_x_enabled, ctx->circular_y_enabled); return x; } @@ -357,10 +357,10 @@ namespace ZImage { return x; } - struct ggml_tensor* process_img(struct ggml_context* ctx, + struct ggml_tensor* process_img(GGMLRunnerContext* ctx, struct ggml_tensor* x) { x = pad_to_patch_size(ctx, x); - x = patchify(ctx, x); + x = patchify(ctx->ggml_ctx, x); return x; } @@ -473,12 +473,12 @@ namespace ZImage { int64_t C = x->ne[2]; int64_t N = x->ne[3]; - auto img = process_img(ctx->ggml_ctx, x); + auto img = process_img(ctx, x); uint64_t n_img_token = 
img->ne[1]; if (ref_latents.size() > 0) { for (ggml_tensor* ref : ref_latents) { - ref = process_img(ctx->ggml_ctx, ref); + ref = process_img(ctx, ref); img = ggml_concat(ctx->ggml_ctx, img, ref, 1); } } @@ -552,6 +552,8 @@ namespace ZImage { ref_latents, increase_ref_index, z_image_params.theta, + circular_y_enabled, + circular_x_enabled, z_image_params.axes_dim); int pos_len = pe_vec.size() / z_image_params.axes_dim_sum / 2; // LOG_DEBUG("pos_len %d", pos_len);
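
Note (not part of the patch): below is a minimal standalone sketch of the "closest periodic harmonic" rounding that the rope.hpp change applies when circular_h/circular_w are set. The values dim, theta and wrap_dim are illustrative assumptions, not taken from the patch; the point is that snapping each RoPE frequency to an integer number of cycles over the wrap length makes the rotary angle at position p and at p + wrap_dim differ by an exact multiple of 2*pi, so the positional encoding tiles seamlessly across the image boundary.

// circular_rope_sketch.cpp -- illustrative only; mirrors the rounding done in Rope::rope()
#include <cmath>
#include <cstdio>

int main() {
    const float TWO_PI   = 6.28318530717958647692f;
    const int   dim      = 8;       // per-axis RoPE dimension (must be even)
    const float theta    = 10000.f; // RoPE base
    const int   wrap_dim = 12;      // axis length in patches (hypothetical)

    for (int j = 0; j < dim / 2; ++j) {
        float omega   = 1.0f / std::pow(theta, (2.0f * j) / dim); // raw RoPE frequency
        float cycles  = omega * wrap_dim / TWO_PI;                // cycles per wrap length
        float rounded = std::round(cycles);                       // nearest integer harmonic
        float omega_c = TWO_PI * rounded / wrap_dim;              // snapped frequency

        float p  = 3.0f;                     // arbitrary token position
        float a0 = p * omega_c;              // angle at p
        float a1 = (p + wrap_dim) * omega_c; // angle one full wrap later
        // a1 - a0 == TWO_PI * rounded, so cos/sin (and hence the rotary rotation) match exactly.
        std::printf("j=%d cycles=%.3f rounded=%.0f cos(a0)=%.6f cos(a1)=%.6f\n",
                    j, cycles, rounded, std::cos(a0), std::cos(a1));
    }
    return 0;
}

Without the rounding, omega * wrap_dim is generally not a multiple of 2*pi, so tokens on opposite sides of the seam receive distinguishable phases and the model can still infer where the "center" of the canvas is, which is exactly what the comment added in rope.hpp warns about.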