[Feature] Add LoRA Inference Support for WAN Models via Flax NNX #308
base: main
Conversation
Does LoRA support the I2V pipelines as well?
Force-pushed from 9290a6e to e1b7221.
Added examples of I2V support
entrpn left a comment
can this implementation load multiple loras at once?
Comment on:
```python
class Wan2_2NnxLoraLoader(LoRABaseMixin):
```
nit - make NNX all upper case.
Comment on:
```python
class Wan2_1NnxLoraLoader(LoRABaseMixin):
```
nit - make NNX all upper case.
Comment on:
```python
def _to_jax_array(v):
```
should dtype be set as well here based on the dtypes set in the config?
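For illustration, a minimal sketch of what that suggestion could look like; the `dtype` parameter and its default are assumptions, not code from this PR:

```python
import jax.numpy as jnp

def _to_jax_array(v, dtype=None):
  # Convert an incoming (PyTorch/NumPy) tensor to a JAX array, optionally
  # casting to the dtype configured for the run (e.g. bfloat16).
  return jnp.array(v, dtype=dtype)
```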
Comment on:
```python
def parse_lora_dict(state_dict):
```
do you know which lora formats are supported by this function? There are a couple lora trainers out there, might want to specify in a comment or readme which ones we're specifically targeting (diffusers, or others).
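For reference, a sketch of one common key-naming convention LoRA trainers emit; the key strings and grouping below are illustrative examples, not a statement of what `parse_lora_dict` currently accepts:

```python
def group_lora_pairs(state_dict):
  # Diffusers/PEFT-style checkpoints pair "<module>.lora_A.weight" with
  # "<module>.lora_B.weight"; older Kohya-style files instead use
  # "lora_down.weight" / "lora_up.weight" plus an optional "alpha" scalar.
  pairs = {}
  for key, tensor in state_dict.items():
    for suffix, slot in ((".lora_A.weight", "A"), (".lora_B.weight", "B")):
      if key.endswith(suffix):
        pairs.setdefault(key[: -len(suffix)], {})[slot] = tensor
  return pairs
```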
Comment on:
```python
  num_layers = module.kernel.shape[0]
  in_feat, out_feat = module.kernel.shape[3], module.kernel.shape[4]
else:
  # Should not happen based on is_scanned logic
```
put a warning message here in case it does happen or throw an error.
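A minimal sketch of that suggestion, slotted into the branch quoted above (the message wording is illustrative):

```python
else:
  # Fail loudly instead of silently skipping an unexpected kernel layout.
  raise ValueError(f"Unexpected kernel shape {module.kernel.shape} for a scanned layer.")
```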
Comment on:
```python
is_scanned = module.kernel.ndim == 5

# If layer is not scanned, merge it using single-layer logic
if not is_scanned:
```
why use continue instead of if else? Let's make it explicit for readability.
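A sketch of the suggested restructuring, assuming the loop currently handles the unscanned case and then `continue`s; the helper names are placeholders, not functions from this PR:

```python
if not is_scanned:
  # Unscanned layer: merge with the single-layer logic.
  merge_single_layer(module, lora_a, lora_b)
else:
  # Scanned layer: merge the whole stacked kernel at once.
  merge_scanned_layer(module, lora_a, lora_b)
```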
Summary
This PR introduces full Low-Rank Adaptation (LoRA) inference support for the WAN family of models in MaxDiffusion.
Unlike previous implementations in this codebase that rely on `flax.linen`, this implementation leverages `flax.nnx`. This allows for a more Pythonic, object-oriented approach to weight injection, enabling us to modify the transformer model in-place.

Key Features
1. Transition to `flax.nnx`

WAN models in MaxDiffusion are implemented using `flax.nnx`. To support LoRA, we implemented a native NNX loader rather than wrapping `linen` modules. The loader walks the model graph (`nnx.iter_graph`) to identify target layers (`nnx.Linear`, `nnx.Conv`, `nnx.Embed`, `nnx.LayerNorm`) and merges LoRA weights directly into the kernel values.
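As a rough illustration of this traversal-and-merge approach (not the PR's exact code), assuming a `lora_pairs` dict keyed by '/'-joined module paths with `(A, B)` factors of shapes `(in_features, rank)` and `(rank, out_features)`:

```python
import jax.numpy as jnp
from flax import nnx

def merge_lora_sketch(model: nnx.Module, lora_pairs: dict, scale: float = 1.0):
  # Walk every node in the NNX graph and add scale * (A @ B) to the kernel
  # of each nnx.Linear that has a matching LoRA pair.
  for path, node in nnx.iter_graph(model):
    if isinstance(node, nnx.Linear):
      key = "/".join(str(p) for p in path)
      if key in lora_pairs:
        a, b = lora_pairs[key]
        delta = (a @ b) * scale  # same shape as the (in_features, out_features) kernel
        node.kernel.value = node.kernel.value + delta.astype(node.kernel.value.dtype)
```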
2. Robust Weight Merging Strategy

This implementation solves several critical distributed training/inference challenges:
- JIT-compiled merging (`jax.jit`): To avoid `ShardingMismatch` and `DeviceArray` errors that occur when mixing sharded TPU weights with CPU-based LoRA weights, all merge computations (kernel + delta) are performed within JIT-compiled functions (`_compute_and_add_*_jit`). This ensures weight updates occur efficiently on-device across the TPU mesh.
- Uses `jax.dlpack` where possible to efficiently move PyTorch tensors to JAX arrays without unnecessary memory overhead.
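A simplified sketch of the on-device merge idea; the real `_compute_and_add_*_jit` helpers are more involved, and the buffer donation shown here is an illustrative choice, not necessarily what the PR does:

```python
from functools import partial

import jax

@partial(jax.jit, donate_argnums=(0,))
def _compute_and_add(kernel, lora_a, lora_b, scale):
  # Everything runs on-device: the sharded kernel stays on the TPU mesh,
  # and the low-rank delta is computed and added under jit, so the result
  # can keep the input kernel's sharding.
  delta = (lora_a @ lora_b) * scale
  return kernel + delta.astype(kernel.dtype)
```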
3. Advanced LoRA Support

Beyond standard `Linear` rank reduction, this PR supports:
- `diff` weights before device-side merging.
- Difference injection (`diff`, `diff_b`): Supports checkpoints that include full-parameter fine-tuning offsets (difference injections) and bias tuning, which are common in high-fidelity WAN fine-tunes.
- `text_embedding`, `time_embedding`, and `LayerNorm`/`RMSNorm` scales and biases.
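For the difference-injection case the "merge" is just an element-wise add of the stored offsets; a minimal sketch (attribute handling is simplified, and the offsets are assumed to already be JAX/NumPy arrays):

```python
def apply_diff_offsets(module, diff=None, diff_b=None):
  # `diff` has the same shape as the kernel; `diff_b` matches the bias.
  if diff is not None:
    module.kernel.value = module.kernel.value + diff.astype(module.kernel.value.dtype)
  if diff_b is not None and getattr(module.bias, "value", None) is not None:
    module.bias.value = module.bias.value + diff_b.astype(module.bias.value.dtype)
```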
4. Scanned vs. Unscanned Layers

MaxDiffusion supports enabling `jax.scan` for transformer layers via the `scan_layers: True` configuration flag. This improves training memory efficiency by stacking weights of repeated layers (e.g., Attention, FFN) along a new leading dimension. Since users may run inference with or without this flag enabled, this LoRA implementation is designed to transparently support both modes.

The loader distinguishes between:
- Unscanned mode: the `merge_lora()` function is used, which iterates through each layer and merges weights individually via efficient, on-device JIT calls (`_compute_and_add_single_jit`).
- Scanned mode: the `merge_lora_for_scanned()` function is used. It detects which parameters are stacked (e.g., `kernel.ndim > 2`) and which are not.
  - Stacked parameters are merged via `_compute_and_add_scanned_jit`. This updates all layers in the stack at once on-device, which is significantly more efficient than merging layer-by-layer.
  - Unstacked parameters (e.g., `embeddings`, `proj_out`): It merges them individually using the single-layer JIT logic.

This dual approach ensures correct weight injection whether or not layers are scanned, while maximizing performance in scanned mode through batching.
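For the scanned path, a rough sketch of merging stacked per-layer LoRA factors in one vectorized call; the shapes and stacking order are assumptions made for the example:

```python
import jax
import jax.numpy as jnp

@jax.jit
def _compute_and_add_scanned(kernel, lora_a, lora_b, scale):
  # kernel: (num_layers, in_features, out_features) -- scan-stacked Linear weights
  # lora_a: (num_layers, in_features, rank)
  # lora_b: (num_layers, rank, out_features)
  # A single einsum builds the delta for every layer at once, on-device.
  delta = jnp.einsum("lir,lro->lio", lora_a, lora_b) * scale
  return kernel + delta.astype(kernel.dtype)
```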
Files Added / Modified
- `src/maxdiffusion/models/lora_nnx.py`: [NEW] Core logic. Contains the JIT merge functions, `parse_lora_dict`, and the graph traversal logic (`merge_lora`, `merge_lora_for_scanned`) to inject weights into NNX modules.
- `src/maxdiffusion/loaders/wan_lora_nnx_loader.py`: [NEW] Orchestrates the loading process. Handles the download of safetensors, conversion of keys, and delegation to the merge functions.
- `src/maxdiffusion/generate_wan.py`: Updated the generation pipeline to identify if `lora` is enabled and trigger the loading sequence before inference.
- `src/maxdiffusion/lora_conversion_utils.py`: Updated `translate_wan_nnx_path_to_diffusers_lora` to accurately map NNX paths (including embeddings and time projections) to Diffusers-style keys.

Testing