Skip to content

Conversation

@pggPL
Copy link
Collaborator

@pggPL pggPL commented Dec 19, 2025

Description

This PR addresses several issues related to CPU offloading performance and compatibility.

1. CPU Overhead Reduction

This PR reduces CPU overhead through multiple optimizations:

  • Skip processing for non-offloaded layers: Tensors are no longer processed when the layer is known to be non-offloaded (the case for non-manual synchronization when it is known which layers are offloaded in advance). Manual synchronization overhead may be addressed in future work.
  • Remove expensive checks in __torch_function__ hook: Previously costly validation checks have been eliminated.
  • Skip offloading small tensors: Small tensors are now excluded from offloading to avoid overhead.

2. Out of Memory Error with Fused Optimizer and DTensor

PyTorch introduced JAX-like DTensor, and some workloads use our fused optimizer with this tensor type. The previous implementation used .empty_like, which works correctly for standard tensors but does not respect sharding for DTensor—resulting in full tensors being created on each device. This has been fixed by switching to .empty with explicit shape specification.

3. Synchronization Issues When Offloading Small Tensors

For grouped tensors, allocation is performed in bulk, requiring an all-or-nothing offloading approach. This meant small tensors like scales were also offloaded, which caused issues with comm-gemm overlap when CUDA_DEVICE_MAX_CONNECTIONS=1 was set. In these cases, tensors were small enough that SMs were used for copying instead of copy engines, leading to synchronization problems.

Fixes:

  • Added a minimum tensor size threshold for offloading to mitigate this issue.
  • Added an option to disable bulk allocation for grouped tensor offloading (enabled automatically when offloading is active).

Type of change

  • Documentation change (change only to the documentation, either a fix or a new content)
  • Bug fix (non-breaking change which fixes an issue)
  • New feature (non-breaking change which adds functionality)
  • Breaking change (fix or feature that would cause existing functionality to not work as expected)
  • Infra/Build change
  • Code refactoring

Checklist:

  • I have read and followed the contributing guidelines
  • The functionality is complete
  • I have commented my code, particularly in hard-to-understand areas
  • I have made corresponding changes to the documentation
  • My changes generate no new warnings
  • I have added tests that prove my fix is effective or that my feature works
  • New and existing unit tests pass locally with my changes

Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>
Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>
pre-commit-ci bot and others added 4 commits December 19, 2025 14:22
Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>
Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>
Comment on lines 453 to 455
# Only offload tensors with at least 256k elements (~1MB for float32)
if t.numel() < 256 * 1024:
return False
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If I understand, this is the reason we need to expose an option to disable bulk allocation in split_quantize? Bulk-allocated tensors hold on to memory untill all are deallocated, but this condition means that some small tensor might keep a large memory block alive.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes. And we cannot offload small tensors, because it causes the synchronization of compute/communication operations when CUDA_DEVICE_MAX_CONNECTIONS=1 is set - which is needed by the comm/gemm overlap.

Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>
pggPL and others added 6 commits January 8, 2026 13:50
Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>
Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>
Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>
@pggPL pggPL marked this pull request as ready for review January 9, 2026 15:57
@pggPL
Copy link
Collaborator Author

pggPL commented Jan 9, 2026

/te-ci pytorch

Copy link
Collaborator

@timmoon10 timmoon10 left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM

Copy link
Contributor

@greptile-apps greptile-apps bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Greptile Overview

Greptile Summary

Summary

This PR addresses CPU offloading performance and compatibility issues through multiple coordinated changes:

Key Changes

  1. CPU Overhead Reduction: Added layer-level offloading skipping in DefaultOffloadSynchronizer.push_tensor() to avoid processing tensors when a layer won't be offloaded. Also conditionally guards mark_not_offload() calls.

  2. QuantizedTensor Offloading Support: Extended CPU offloading to handle QuantizedTensor types by decomposing them into component tensors, offloading each component recursively, and reconstructing them during reload.

  3. DTensor Compatibility: Changed from torch.empty(shape, device) to torch.empty_like(tensor, dtype) in FusedAdam to properly respect DTensor sharding annotations.

  4. Small Tensor Offloading Threshold: Added 256K element minimum threshold to prevent offloading of tiny tensors that would cause synchronization issues with CUDA_DEVICE_MAX_CONNECTIONS=1.

  5. Bulk Allocation Control: Added disable_bulk_allocation parameter to split_quantize() C++ function, enabled when CPU offloading is active to avoid grouping small tensors with large ones.

Files Modified

  • transformer_engine/pytorch/cpu_offload.py: Core offloading logic with QuantizedTensor support
  • transformer_engine/pytorch/optimizers/fused_adam.py: DTensor-aware state initialization
  • transformer_engine/pytorch/module/linear.py: Conditional mark_not_offload() guarding
  • transformer_engine/pytorch/module/grouped_linear.py: disable_bulk_allocation parameter passing
  • transformer_engine/pytorch/quantized_tensor.py: Removed CPU operation validation that blocked offloading
  • C++ files: Added disable_bulk_allocation parameter and logic
  • Tests: Updated tensor sizes to ensure components exceed 256K threshold

Issues Found

Critical Issue: The FusedAdam DTensor fix is incomplete. When a QuantizedTensor parameter wraps a DTensor, calling dequantize() creates a new plain tensor that loses DTensor sharding metadata. The fix should use the original parameter directly with .empty_like().

Type Annotation Issue: DefaultOffloadSynchronizer.push_tensor() return type annotation doesn't reflect actual return type (missing tuple[list, list]).

Behavioral Change: Default for retain_pinned_cpu_buffers changed from False to True, affecting memory usage patterns and performance characteristics. This change is not documented in the PR description.

Confidence Score: 2/5

  • This PR has a critical bug that breaks DTensor parameter handling in FusedAdam, and incomplete type annotations. The DTensor fix is fundamentally broken for QuantizedTensor parameters.
  • The PR contains one critical logic bug that makes the DTensor fix incomplete/incorrect. The FusedAdam change dequantizes QuantizedTensor parameters, which destroys DTensor sharding information that the empty_like() call is meant to preserve. Additionally, return type annotations are incomplete, and an undocumented behavioral default change (retain_pinned_cpu_buffers) could affect existing users. While the core CPU offloading improvements are sound, these issues need resolution before merging.
  • transformer_engine/pytorch/optimizers/fused_adam.py (critical DTensor bug), transformer_engine/pytorch/cpu_offload.py (type annotation and default value)

Important Files Changed

File Analysis

Filename Score Overview
transformer_engine/pytorch/optimizers/fused_adam.py 2/5 FusedAdam state initialization broken for QuantizedTensor parameters with DTensor sharding. Calling dequantize() loses DTensor metadata that should be preserved with empty_like().
transformer_engine/pytorch/cpu_offload.py 3/5 Multiple changes: QuantizedTensor offloading support added, CPU overhead reduced with layer skipping optimization, but return type annotation mismatch and behavioral default change. Default retain_pinned_cpu_buffers changed from False to True.
transformer_engine/pytorch/quantized_tensor.py 4/5 Removed CPU operation validation checks. This change is safe as it enables QuantizedTensor component tensors to be offloaded to CPU while preserving the original tensor type checks elsewhere.
transformer_engine/pytorch/module/grouped_linear.py 5/5 Disables bulk allocation in split_quantize when CPU offloading is active, addressing synchronization issues with small tensors and CUDA_DEVICE_MAX_CONNECTIONS=1.
transformer_engine/pytorch/csrc/extensions/cast.cpp 5/5 Correctly implements disable_bulk_allocation feature by wrapping bulk allocation logic in conditional, allowing fallback to unfused allocation when needed.

Sequence Diagram

sequenceDiagram
    participant User as User Code
    participant Linear as Linear Module
    participant CPU_Offload as CPU Offload System
    participant QuantTensor as QuantizedTensor
    participant OffloadState as OffloadableLayerState
    
    User->>Linear: Forward pass with QuantizedTensor
    Linear->>OffloadState: push_tensor(quantized_tensor)
    alt QuantizedTensor Case
        OffloadState->>QuantTensor: detach()
        QuantTensor->>QuantTensor: prepare_for_saving()
        Note over QuantTensor: Decompose into [data, transpose, scale_inv]
        loop For each component
            OffloadState->>OffloadState: push_tensor(component)
            OffloadState->>OffloadState: Check if numel() > 256K
        end
        OffloadState-->>Linear: Return (push_results, tensor_objs)
    else Regular Tensor Case
        OffloadState->>OffloadState: Check if offload needed
        OffloadState-->>Linear: Return tensor_id or tensor
    end
    
    Note over OffloadState: Backward pass
    Linear->>OffloadState: pop_tensor(tensor_id_or_tuple)
    alt QuantizedTensor Tuple Case
        OffloadState->>OffloadState: pop_tensor(each_component)
        OffloadState->>QuantTensor: restore_from_saved(components)
        OffloadState-->>Linear: Return reconstructed QuantizedTensor
    else Regular Case
        OffloadState-->>Linear: Return tensor
    end
Loading

Comment on lines 375 to +381
dtype = self.name_to_dtype_map[state_name]
# Handle QuantizedTensor by dequantizing first
param_for_empty = param.dequantize() if isinstance(param, QuantizedTensor) else param
if store_param_remainders:
data = torch.zeros(param.shape, dtype=torch.int16, device=param.device)
data = torch.zeros_like(param_for_empty, dtype=torch.int16)
else:
data = torch.empty(param.shape, dtype=dtype, device=param.device)
data = torch.empty_like(param_for_empty, dtype=dtype)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The dequantization of QuantizedTensor parameters here breaks DTensor sharding preservation. When param is a QuantizedTensor wrapping a DTensor, calling dequantize() creates a new plain tensor that loses the DTensor sharding metadata. This defeats the purpose of using .empty_like() to preserve DTensor sharding.

The fix should use the original parameter directly without dequantization, since .empty_like() respects the sharding annotations of the input tensor regardless of whether it's quantized:

Suggested change
dtype = self.name_to_dtype_map[state_name]
# Handle QuantizedTensor by dequantizing first
param_for_empty = param.dequantize() if isinstance(param, QuantizedTensor) else param
if store_param_remainders:
data = torch.zeros(param.shape, dtype=torch.int16, device=param.device)
data = torch.zeros_like(param_for_empty, dtype=torch.int16)
else:
data = torch.empty(param.shape, dtype=dtype, device=param.device)
data = torch.empty_like(param_for_empty, dtype=dtype)
data = torch.zeros_like(param, dtype=torch.int16)
...
data = torch.empty_like(param, dtype=dtype)

Alternatively, if dequantization is necessary for some reason, the sharding information from the original parameter should be explicitly preserved.

Comment on lines +631 to +636
def push_tensor(self, tensor: torch.Tensor) -> int | torch.Tensor:
"""Push tensor - skip processing if layer won't be offloaded to reduce CPU overhead."""
if not self.offload_layer_map.get(self.num_of_fwds, False):
return tensor
return self.layer_states[self.num_of_fwds].push_tensor(tensor)

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The return type annotation for DefaultOffloadSynchronizer.push_tensor() is incomplete and doesn't match the base class signature. This method can return a tuple (when processing QuantizedTensors), but the annotation only specifies int | torch.Tensor.

Suggested change
def push_tensor(self, tensor: torch.Tensor) -> int | torch.Tensor:
"""Push tensor - skip processing if layer won't be offloaded to reduce CPU overhead."""
if not self.offload_layer_map.get(self.num_of_fwds, False):
return tensor
return self.layer_states[self.num_of_fwds].push_tensor(tensor)
def push_tensor(self, tensor: torch.Tensor) -> int | torch.Tensor | tuple[list, list]:
"""Push tensor - skip processing if layer won't be offloaded to reduce CPU overhead."""
if not self.offload_layer_map.get(self.num_of_fwds, False):
return tensor
return self.layer_states[self.num_of_fwds].push_tensor(tensor)

num_layers: Optional[int] = 1,
model_layers: int = 1,
offload_activations: bool = True,
offload_weights: bool = False,
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The default value for retain_pinned_cpu_buffers has been changed from False to True. While this may improve performance by reusing CPU buffers across iterations, this is a significant behavioral change that affects memory usage patterns and is not mentioned in the PR description. Consider documenting this change in the commit message and PR description, as it could impact existing users' performance characteristics.

If backward compatibility is important, consider keeping the default as False or provide a migration path for existing code.

Note: If this suggestion doesn't match your team's coding style, reply to this and let me know. I'll remember it for next time!

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

2 participants