
Commit 50ff966

Phylliida and leejet authored
feat: add seamless texture generation support (#914)
* global bool
* reworked circular to global flag
* cleaner implementation of tiling support in sd cpp
* cleaned rope
* working simplified but still need wraps
* Further clean of rope
* resolve flux conflict
* switch to pad op circular only
* Set ggml to most recent
* Revert ggml temp
* Update ggml to most recent
* Revert unneded flux change
* move circular flag to the GGMLRunnerContext
* Pass through circular param in all places where conv is called
* fix of constant and minor cleanup
* Added back --circular option
* Conv2d circular in vae and various models
* Fix temporal padding for qwen image and other vaes
* Z Image circular tiling
* x and y axis seamless only
* First attempt at chroma seamless x and y
* refactor into pure x and y, almost there
* Fix crash on chroma
* Refactor into cleaner variable choices
* Removed redundant set_circular_enabled
* Sync ggml
* simplify circular parameter
* format code
* no need to perform circular pad on the clip
* simplify circular_axes setting
* unify function naming
* remove unnecessary member variables
* simplify rope

---------

Co-authored-by: Phylliida <phylliidadev@gmail.com>
Co-authored-by: leejet <leejet714@gmail.com>
1 parent 88ec9d3 commit 50ff966

File tree: 15 files changed, +375 −79 lines


common.hpp — 1 addition, 1 deletion

@@ -28,7 +28,7 @@ class DownSampleBlock : public GGMLBlock {
         if (vae_downsample) {
             auto conv = std::dynamic_pointer_cast<Conv2d>(blocks["conv"]);

-            x = ggml_pad(ctx->ggml_ctx, x, 1, 1, 0, 0);
+            x = ggml_ext_pad(ctx->ggml_ctx, x, 1, 1, 0, 0, ctx->circular_x_enabled, ctx->circular_y_enabled);
             x = conv->forward(ctx, x);
         } else {
             auto conv = std::dynamic_pointer_cast<Conv2d>(blocks["op"]);
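
The functional change above is the swap from plain zero padding (ggml_pad) to ggml_ext_pad, which takes per-axis circular flags from the GGMLRunnerContext. As a rough sketch of what wrap-around padding means (a plain C++ illustration; pad_2d, its row-major layout, and right/bottom-only padding are assumptions made for the example, not the ggml implementation): a circular axis is filled from the opposite edge instead of with zeros, so a convolution sliding across the image boundary sees the texture continue seamlessly.

#include <vector>

// Pad a row-major W x H image by pad_w columns (right) and pad_h rows
// (bottom). A circular axis wraps around to the opposite edge; a
// non-circular axis keeps the default zero fill, like ggml_pad.
std::vector<float> pad_2d(const std::vector<float>& img, int W, int H,
                          int pad_w, int pad_h,
                          bool circular_x, bool circular_y) {
    const int outW = W + pad_w, outH = H + pad_h;
    std::vector<float> out(outW * outH, 0.0f);  // zero fill by default
    for (int y = 0; y < outH; ++y) {
        for (int x = 0; x < outW; ++x) {
            int sx = x, sy = y;
            if (sx >= W) { if (circular_x) sx %= W; else continue; }
            if (sy >= H) { if (circular_y) sy %= H; else continue; }
            out[y * outW + x] = img[sy * W + sx];
        }
    }
    return out;
}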

denoiser.hpp — 3 additions, 3 deletions

@@ -366,18 +366,18 @@ struct KLOptimalScheduler : SigmaScheduler {

         for (uint32_t i = 0; i < n; ++i) {
             // t goes from 0.0 to 1.0
-            float t = static_cast<float>(i) / static_cast<float>(n-1);
+            float t = static_cast<float>(i) / static_cast<float>(n - 1);

             // Interpolate in the angle domain
             float angle = t * alpha_min + (1.0f - t) * alpha_max;

             // Convert back to sigma
             sigmas.push_back(std::tan(angle));
-    }
+        }

         // Append the final zero to sigma
         sigmas.push_back(0.0f);
-
+
         return sigmas;
     }
 };
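
This hunk is whitespace-only (part of the "format code" pass), but the context lines show the whole KL-optimal schedule: interpolate linearly in the arctangent domain, then map back through tan. In formula form, with t_i = i / (n − 1) and, going by the scheduler's naming (the definitions sit outside this hunk), alpha_min/alpha_max presumably the arctangents of the sigma range endpoints:

    sigma_i = tan(t_i * alpha_min + (1 − t_i) * alpha_max),  i = 0 … n − 1

followed by an appended sigma_n = 0, so the schedule runs from sigma_max down toward sigma_min and ends at zero.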

diffusion_model.hpp — 27 additions, 2 deletions

@@ -37,8 +37,9 @@ struct DiffusionModel {
     virtual void get_param_tensors(std::map<std::string, struct ggml_tensor*>& tensors) = 0;
     virtual size_t get_params_buffer_size() = 0;
     virtual void set_weight_adapter(const std::shared_ptr<WeightAdapter>& adapter){};
-    virtual int64_t get_adm_in_channels() = 0;
-    virtual void set_flash_attn_enabled(bool enabled) = 0;
+    virtual int64_t get_adm_in_channels() = 0;
+    virtual void set_flash_attn_enabled(bool enabled) = 0;
+    virtual void set_circular_axes(bool circular_x, bool circular_y) = 0;
 };

 struct UNetModel : public DiffusionModel {
@@ -87,6 +88,10 @@ struct UNetModel : public DiffusionModel {
         unet.set_flash_attention_enabled(enabled);
     }

+    void set_circular_axes(bool circular_x, bool circular_y) override {
+        unet.set_circular_axes(circular_x, circular_y);
+    }
+
     bool compute(int n_threads,
                  DiffusionParams diffusion_params,
                  struct ggml_tensor** output = nullptr,
@@ -148,6 +153,10 @@ struct MMDiTModel : public DiffusionModel {
         mmdit.set_flash_attention_enabled(enabled);
     }

+    void set_circular_axes(bool circular_x, bool circular_y) override {
+        mmdit.set_circular_axes(circular_x, circular_y);
+    }
+
     bool compute(int n_threads,
                  DiffusionParams diffusion_params,
                  struct ggml_tensor** output = nullptr,
@@ -210,6 +219,10 @@ struct FluxModel : public DiffusionModel {
         flux.set_flash_attention_enabled(enabled);
     }

+    void set_circular_axes(bool circular_x, bool circular_y) override {
+        flux.set_circular_axes(circular_x, circular_y);
+    }
+
     bool compute(int n_threads,
                  DiffusionParams diffusion_params,
                  struct ggml_tensor** output = nullptr,
@@ -277,6 +290,10 @@ struct WanModel : public DiffusionModel {
         wan.set_flash_attention_enabled(enabled);
     }

+    void set_circular_axes(bool circular_x, bool circular_y) override {
+        wan.set_circular_axes(circular_x, circular_y);
+    }
+
     bool compute(int n_threads,
                  DiffusionParams diffusion_params,
                  struct ggml_tensor** output = nullptr,
@@ -343,6 +360,10 @@ struct QwenImageModel : public DiffusionModel {
         qwen_image.set_flash_attention_enabled(enabled);
     }

+    void set_circular_axes(bool circular_x, bool circular_y) override {
+        qwen_image.set_circular_axes(circular_x, circular_y);
+    }
+
     bool compute(int n_threads,
                  DiffusionParams diffusion_params,
                  struct ggml_tensor** output = nullptr,
@@ -406,6 +427,10 @@ struct ZImageModel : public DiffusionModel {
         z_image.set_flash_attention_enabled(enabled);
     }

+    void set_circular_axes(bool circular_x, bool circular_y) override {
+        z_image.set_circular_axes(circular_x, circular_y);
+    }
+
     bool compute(int n_threads,
                  DiffusionParams diffusion_params,
                  struct ggml_tensor** output = nullptr,
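
Each backend gets the same two-line override forwarding to its runner, so callers only ever touch the abstract DiffusionModel interface. A minimal, self-contained illustration of that dispatch (StubModel and main are invented for the example; the circular || circular_x collapsing mirrors the examples/common/common.hpp hunk below):

#include <cstdio>
#include <memory>

// Abstract interface, as extended in this diff.
struct DiffusionModel {
    virtual ~DiffusionModel() = default;
    virtual void set_circular_axes(bool circular_x, bool circular_y) = 0;
};

// Stand-in for UNetModel / FluxModel / ...: each override simply
// forwards the flags to the model's runner.
struct StubModel : DiffusionModel {
    void set_circular_axes(bool circular_x, bool circular_y) override {
        std::printf("wrap width=%d, wrap height=%d\n", circular_x, circular_y);
    }
};

int main() {
    bool circular = true, circular_x = false, circular_y = false;
    std::shared_ptr<DiffusionModel> model = std::make_shared<StubModel>();
    // --circular enables both axes; --circularx / --circulary one each.
    model->set_circular_axes(circular || circular_x, circular || circular_y);
    return 0;
}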

examples/common/common.hpp — 21 additions, 0 deletions

@@ -449,6 +449,10 @@ struct SDContextParams {
     bool diffusion_conv_direct = false;
     bool vae_conv_direct = false;

+    bool circular = false;
+    bool circular_x = false;
+    bool circular_y = false;
+
     bool chroma_use_dit_mask = true;
     bool chroma_use_t5_mask = false;
     int chroma_t5_mask_pad = 1;
@@ -605,6 +609,18 @@ struct SDContextParams {
          "--vae-conv-direct",
          "use ggml_conv2d_direct in the vae model",
          true, &vae_conv_direct},
+        {"",
+         "--circular",
+         "enable circular padding for convolutions",
+         true, &circular},
+        {"",
+         "--circularx",
+         "enable circular RoPE wrapping on x-axis (width) only",
+         true, &circular_x},
+        {"",
+         "--circulary",
+         "enable circular RoPE wrapping on y-axis (height) only",
+         true, &circular_y},
         {"",
          "--chroma-disable-dit-mask",
          "disable dit mask for chroma",
@@ -868,6 +884,9 @@ struct SDContextParams {
            << " diffusion_flash_attn: " << (diffusion_flash_attn ? "true" : "false") << ",\n"
            << " diffusion_conv_direct: " << (diffusion_conv_direct ? "true" : "false") << ",\n"
            << " vae_conv_direct: " << (vae_conv_direct ? "true" : "false") << ",\n"
+           << " circular: " << (circular ? "true" : "false") << ",\n"
+           << " circular_x: " << (circular_x ? "true" : "false") << ",\n"
+           << " circular_y: " << (circular_y ? "true" : "false") << ",\n"
            << " chroma_use_dit_mask: " << (chroma_use_dit_mask ? "true" : "false") << ",\n"
            << " chroma_use_t5_mask: " << (chroma_use_t5_mask ? "true" : "false") << ",\n"
            << " chroma_t5_mask_pad: " << chroma_t5_mask_pad << ",\n"
@@ -928,6 +947,8 @@ struct SDContextParams {
         taesd_preview,
         diffusion_conv_direct,
         vae_conv_direct,
+        circular || circular_x,
+        circular || circular_y,
         force_sdxl_vae_conv_scale,
         chroma_use_dit_mask,
         chroma_use_t5_mask,
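
In user terms: --circular turns on wrap-around on both axes, while --circularx / --circulary select a single axis, since the argument list in the last hunk above passes circular || circular_x and circular || circular_y downstream. A hypothetical invocation of the example CLI (the binary name, model path, and prompt are illustrative, not taken from this diff):

    # texture that tiles seamlessly in both directions
    ./sd --circular -m model.safetensors -p "mossy cobblestone texture"

    # wrap horizontally only, e.g. for a 360-degree panorama strip
    ./sd --circularx -m model.safetensors -p "mountain skyline panorama"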

flux.hpp — 15 additions, 13 deletions

@@ -860,14 +860,14 @@ namespace Flux {
             }
         }

-        struct ggml_tensor* pad_to_patch_size(struct ggml_context* ctx,
+        struct ggml_tensor* pad_to_patch_size(GGMLRunnerContext* ctx,
                                               struct ggml_tensor* x) {
            int64_t W = x->ne[0];
            int64_t H = x->ne[1];

            int pad_h = (params.patch_size - H % params.patch_size) % params.patch_size;
            int pad_w = (params.patch_size - W % params.patch_size) % params.patch_size;
-           x = ggml_pad(ctx, x, pad_w, pad_h, 0, 0);  // [N, C, H + pad_h, W + pad_w]
+           x = ggml_ext_pad(ctx->ggml_ctx, x, pad_w, pad_h, 0, 0, ctx->circular_x_enabled, ctx->circular_y_enabled);
            return x;
        }

@@ -893,11 +893,11 @@ namespace Flux {
            return x;
        }

-        struct ggml_tensor* process_img(struct ggml_context* ctx,
+        struct ggml_tensor* process_img(GGMLRunnerContext* ctx,
                                         struct ggml_tensor* x) {
            // img = rearrange(x, "b c (h ph) (w pw) -> b (h w) (c ph pw)", ph=patch_size, pw=patch_size)
            x = pad_to_patch_size(ctx, x);
-           x = patchify(ctx, x);
+           x = patchify(ctx->ggml_ctx, x);
            return x;
        }

@@ -1076,7 +1076,7 @@ namespace Flux {
            int pad_h = (patch_size - H % patch_size) % patch_size;
            int pad_w = (patch_size - W % patch_size) % patch_size;

-           auto img = pad_to_patch_size(ctx->ggml_ctx, x);
+           auto img = pad_to_patch_size(ctx, x);
            auto orig_img = img;

            if (params.chroma_radiance_params.use_patch_size_32) {
@@ -1150,16 +1150,16 @@ namespace Flux {
            int pad_h = (patch_size - H % patch_size) % patch_size;
            int pad_w = (patch_size - W % patch_size) % patch_size;

-           auto img = process_img(ctx->ggml_ctx, x);
+           auto img = process_img(ctx, x);
            uint64_t img_tokens = img->ne[1];

            if (params.version == VERSION_FLUX_FILL) {
                GGML_ASSERT(c_concat != nullptr);
                ggml_tensor* masked = ggml_view_4d(ctx->ggml_ctx, c_concat, c_concat->ne[0], c_concat->ne[1], C, 1, c_concat->nb[1], c_concat->nb[2], c_concat->nb[3], 0);
                ggml_tensor* mask = ggml_view_4d(ctx->ggml_ctx, c_concat, c_concat->ne[0], c_concat->ne[1], 8 * 8, 1, c_concat->nb[1], c_concat->nb[2], c_concat->nb[3], c_concat->nb[2] * C);

-               masked = process_img(ctx->ggml_ctx, masked);
-               mask = process_img(ctx->ggml_ctx, mask);
+               masked = process_img(ctx, masked);
+               mask = process_img(ctx, mask);

                img = ggml_concat(ctx->ggml_ctx, img, ggml_concat(ctx->ggml_ctx, masked, mask, 0), 0);
            } else if (params.version == VERSION_FLEX_2) {
@@ -1168,21 +1168,21 @@ namespace Flux {
                ggml_tensor* mask = ggml_view_4d(ctx->ggml_ctx, c_concat, c_concat->ne[0], c_concat->ne[1], 1, 1, c_concat->nb[1], c_concat->nb[2], c_concat->nb[3], c_concat->nb[2] * C);
                ggml_tensor* control = ggml_view_4d(ctx->ggml_ctx, c_concat, c_concat->ne[0], c_concat->ne[1], C, 1, c_concat->nb[1], c_concat->nb[2], c_concat->nb[3], c_concat->nb[2] * (C + 1));

-               masked = process_img(ctx->ggml_ctx, masked);
-               mask = process_img(ctx->ggml_ctx, mask);
-               control = process_img(ctx->ggml_ctx, control);
+               masked = process_img(ctx, masked);
+               mask = process_img(ctx, mask);
+               control = process_img(ctx, control);

                img = ggml_concat(ctx->ggml_ctx, img, ggml_concat(ctx->ggml_ctx, ggml_concat(ctx->ggml_ctx, masked, mask, 0), control, 0), 0);
            } else if (params.version == VERSION_FLUX_CONTROLS) {
                GGML_ASSERT(c_concat != nullptr);

-               auto control = process_img(ctx->ggml_ctx, c_concat);
+               auto control = process_img(ctx, c_concat);
                img = ggml_concat(ctx->ggml_ctx, img, control, 0);
            }

            if (ref_latents.size() > 0) {
                for (ggml_tensor* ref : ref_latents) {
-                   ref = process_img(ctx->ggml_ctx, ref);
+                   ref = process_img(ctx, ref);
                    img = ggml_concat(ctx->ggml_ctx, img, ref, 1);
                }
            }
@@ -1472,6 +1472,8 @@ namespace Flux {
                                                       increase_ref_index,
                                                       flux_params.ref_index_scale,
                                                       flux_params.theta,
+                                                      circular_y_enabled,
+                                                      circular_x_enabled,
                                                       flux_params.axes_dim);
            int pos_len = pe_vec.size() / flux_params.axes_dim_sum / 2;
            // LOG_DEBUG("pos_len %d", pos_len);
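
The final hunk threads the per-axis flags into the positional-embedding builder alongside flux_params.theta and flux_params.axes_dim (note the argument order, circular_y_enabled before circular_x_enabled, matching the height-then-width axis order of that call; the builder itself is outside this diff). The commit log calls this circular RoPE wrapping. One standard way to make rotary embeddings periodic, offered purely as a sketch of the idea and not as this repo's implementation, is to snap each axis frequency to a whole number of turns over the axis length N, so positions p and p + N get identical embeddings:

#include <cmath>
#include <cstdio>

// Snap the usual RoPE frequency theta^(-2k/d) for channel pair k to the
// nearest integer number of full rotations across an axis of length N.
// With w * N an exact multiple of 2*pi, cos/sin(w * p) wrap seamlessly.
float circular_rope_freq(int k, int d, int N, float theta = 10000.0f) {
    const float two_pi = 6.28318530718f;
    float freq = std::pow(theta, -2.0f * k / d);   // standard RoPE frequency
    float cycles = std::round(freq * N / two_pi);  // whole turns over the axis
    if (cycles < 1.0f) cycles = 1.0f;              // keep at least one turn
    return two_pi * cycles / N;
}

int main() {
    const int N = 64, d = 16;  // axis length, per-axis embedding dim
    for (int k = 0; k < d / 2; ++k) {
        float w = circular_rope_freq(k, d, N);
        // By construction cos(w * 0) == cos(w * N): the embedding is N-periodic.
        std::printf("k=%d  w=%.5f  turns=%.0f\n", k, w, w * N / 6.28318530718f);
    }
    return 0;
}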
