
Conversation

@mergennachin (Contributor) commented Jan 5, 2026

Converts pointwise Linear/MatMul ops to Conv2d(1x1) to eliminate transpose overhead in vision and transformer models. The pass pattern-matches through layout ops (permute, reshape) and replaces the entire chain with optimized ops.

How it works

  • NCHW patterns (permute→reshape→linear→reshape→permute) → Conv2d(1x1) (see the sketch below)
  • Other patterns (NHWC, transformer) → optimized MatMul with minimal permutes
  • Safety-first: rejects ambiguous patterns rather than risk miscompilation
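
For concreteness, here is a minimal repro of the NCHW chain from the first bullet and the Conv2d(1x1) it is rewritten to. This is an illustrative sketch with invented names and shapes, not code from the pass itself:

```python
import torch

def pointwise_linear_nchw(x, weight, bias):
    # The matched chain: NCHW -> NHWC -> flatten -> linear -> restore NCHW.
    n, c, h, w = x.shape
    y = x.permute(0, 2, 3, 1).reshape(n * h * w, c)
    y = torch.nn.functional.linear(y, weight, bias)
    return y.reshape(n, h, w, -1).permute(0, 3, 1, 2)

def as_conv1x1(x, weight, bias):
    # The replacement: same math, no layout ops at all.
    return torch.nn.functional.conv2d(x, weight[:, :, None, None], bias)

x, w, b = torch.randn(2, 8, 4, 4), torch.randn(16, 8), torch.randn(16)
torch.testing.assert_close(pointwise_linear_nchw(x, w, b), as_conv1x1(x, w, b))
```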

Key features

  • Inherits from ExportPass for pipeline compatibility
  • Unique channel axis validation (reject if multiple candidates; see the sketch below)
  • Single consumer path (avoid breaking shared subgraphs)
  • Bias vs residual detection (only accept provably bias adds)
  • Collision-proof naming for new constants
  • Q/DQ extension points for future quantization support
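
As a hedged illustration of two of the guards above (the helper names are invented for this sketch; the pass's real checks may be structured differently):

```python
import torch.fx as fx

def has_single_consumer(node: fx.Node) -> bool:
    # "Single consumer path": rewriting a linear whose output is shared
    # would also change behavior for the subgraph's other users.
    return len(node.users) == 1

def has_unique_channel_axis(shape: tuple, cin: int) -> bool:
    # "Unique channel axis validation": if more than one axis has size
    # cin (e.g. C == H), the channel axis is ambiguous, so the pattern
    # is rejected rather than risking a miscompile.
    return sum(1 for s in shape if s == cin) == 1
```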

Copilot AI review requested due to automatic review settings January 5, 2026 18:07
pytorch-bot bot commented Jan 5, 2026

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/executorch/16436

Note: Links to docs will display an error until the docs builds have been completed.

❌ 1 New Failure, 1 Cancelled Job, 2 Unrelated Failures

As of commit ae870b3 with merge base 09b5bdb:

NEW FAILURE - The following job has failed:

CANCELLED JOB - The following job was cancelled. Please retry:

FLAKY - The following job failed but was likely due to flakiness present on trunk:

UNSTABLE - The following job is marked as unstable, possibly due to flakiness on trunk:

This comment was automatically generated by Dr. CI and updates every 15 minutes.

meta-cla bot added the "CLA Signed" label Jan 5, 2026
github-actions bot commented Jan 5, 2026

This PR needs a release notes: label

If your change should be included in the release notes (i.e. would users of this library care about this change?), please use a label starting with release notes:. This helps us keep track and include your important work in the next release notes.

To add a label, you can comment to pytorchbot, for example
@pytorchbot label "release notes: none"

For more information, see
https://github.com/pytorch/pytorch/wiki/PyTorch-AutoLabel-Bot#why-categorize-for-release-notes-and-how-does-it-work.

Copilot AI left a comment

Pull request overview

This PR adds a new optimization pass that converts pointwise Linear/MatMul operations to Conv2d(1x1) or optimized MatMul operations, reducing transpose overhead in vision and transformer models. The implementation takes a conservative approach with multiple safety measures to prevent miscompilation.

Key Changes:

  • New PointwiseAffineRewritePass that pattern-matches and rewrites pointwise affine operations
  • NCHW tensors (rank 4, channel axis 1) are converted to Conv2d(1x1), eliminating all transposes
  • Other patterns are converted to optimized MatMul with explicit layout operations
  • Comprehensive test suite covering positive cases, negative cases, and numerical equivalence

Reviewed changes

Copilot reviewed 3 out of 3 changed files in this pull request and generated 11 comments.

| File | Description |
| --- | --- |
| backends/xnnpack/_passes/pointwise_affine_pass.py | Core implementation with matcher and lowering logic; includes safety measures for unique channel axis detection, single consumer validation, and explicit op allowlists |
| backends/xnnpack/test/passes/test_pointwise_affine.py | Test suite covering NCHW/NHWC patterns, transformer patterns, bias handling, negative cases, and numerical equivalence validation |
| backends/xnnpack/_passes/__init__.py | Integration of the new pass into the XNNPACK pass pipeline, positioned after ConvertToLinearPass and before FuseBatchNormPass |


Comment on lines 759 to 834
```python
    Future implementation would:
    - Detect if reference came from a DQ node
    - Insert matching Q/DQ around the new node
    - Preserve quantization parameters
    """
    # TODO: Implement when quantization support is needed
    # if _is_dq_node(reference):
    #     ... insert Q/DQ pair around node ...
```
Copilot AI commented Jan 5, 2026
The TODO comment references future quantization support but provides incomplete implementation guidance. Consider either removing this placeholder comment or providing a more concrete specification of what the implementation should look like, including which Q/DQ ops to handle and how to preserve quantization parameters.

Suggested change (replacing the docstring and TODO above with a fuller specification):

```python
    Expected future implementation:

    - Detect whether ``reference`` is the output of a dequantize (DQ) op
      or has an immediately preceding DQ op, e.g.:
        * ``torch.ops.aten.dequantize.default``
        * backend-specific variants such as
          ``torch.ops.quantized_decomposed.dequantize_per_tensor.default``
          or ``dequantize_per_channel``.
    - Detect whether ``reference``'s users include a quantize (Q) op, e.g.:
        * ``torch.ops.quantized_decomposed.quantize_per_tensor.default``
        * ``quantize_per_channel`` or other backend-specific Q ops.
    - Recreate an equivalent Q/DQ pattern around ``node`` so that:
        * The input to ``node`` is the output of a cloned DQ op whose
          scale, zero_point, and dtype are copied from the original DQ.
        * The outputs of ``node`` feed into cloned Q ops with quantization
          parameters preserved from the original Q nodes.
    - Preserve node ordering and metadata so that observer / calibration
      information remains valid and downstream passes see an equivalent
      quantization structure.

    A sketch of the intended logic (for future implementers) could be:

    .. code-block:: python

        dq_node = _get_preceding_dq_node(reference)
        q_users = _get_q_users(reference)

        if dq_node is not None:
            # Clone DQ with same quantization parameters
            new_dq = graph.call_function(
                dq_node.target,
                dq_node.args,
                dq_node.kwargs,
            )
            new_dq.meta = copy.deepcopy(dq_node.meta)
            # Wire new_dq to feed node instead of reference
            node.replace_input_with(reference, new_dq)

        for q in q_users:
            new_q = graph.call_function(
                q.target,
                (node,),
                q.kwargs,
            )
            new_q.meta = copy.deepcopy(q.meta)
            q.replace_all_uses_with(new_q)

    The helper functions (``_get_preceding_dq_node``, ``_get_q_users``)
    are intentionally not implemented here; they should encapsulate
    backend-specific knowledge of which ops are treated as Q/DQ and how
    to extract their quantization parameters.
    """
    # Quantization-aware behavior is intentionally not implemented yet.
    # When quantization support is added, this method should follow the
    # algorithm sketched above to rebuild Q/DQ adjacency around ``node``.
```

@mergennachin mergennachin marked this pull request as draft January 5, 2026 18:14
@mergennachin mergennachin force-pushed the pointwise-affine-simplified branch 3 times, most recently from f0e406c to a99644f Compare January 5, 2026 19:55
@mergennachin mergennachin marked this pull request as ready for review January 5, 2026 19:56
Copilot AI review requested due to automatic review settings January 5, 2026 19:56
Copilot AI left a comment

Pull request overview

Copilot reviewed 3 out of 3 changed files in this pull request and generated 13 comments.



Comment on lines +763 to +790
```python
        Create a node with properly copied metadata.
        """
```
Copilot AI commented Jan 5, 2026

The _create_node method creates nodes with metadata but doesn't validate that new_val has compatible shape/dtype with the operation being created. If the computed new_val shape is incorrect due to a bug in the lowering logic, the metadata would be wrong but the node would still be created. Consider adding assertions to validate that the created node's expected output matches the provided new_val.

Suggested change (expanding the docstring and adding validation):

```python
    """
    Create a node with properly copied metadata.

    We additionally validate that the provided ``new_val`` is compatible with
    the reference node's expected output (when such information is available
    in ``meta_like.meta["val"]``). This helps catch bugs in the lowering
    logic where the computed value has an unexpected shape or dtype.
    """
    # Validate shape/dtype compatibility with reference metadata, if present.
    ref_val = meta_like.meta.get("val", None)
    if isinstance(ref_val, torch.Tensor):
        # Ensure we are propagating tensor metadata and that the lowered
        # tensor agrees in shape and dtype with the reference.
        assert isinstance(
            new_val, torch.Tensor
        ), f"_create_node expected new_val to be a Tensor, got {type(new_val)}"
        assert (
            new_val.shape == ref_val.shape
        ), f"_create_node shape mismatch: new_val.shape={tuple(new_val.shape)} != ref_val.shape={tuple(ref_val.shape)}"
        assert (
            new_val.dtype == ref_val.dtype
        ), f"_create_node dtype mismatch: new_val.dtype={new_val.dtype} != ref_val.dtype={ref_val.dtype}"
```

@mergennachin mergennachin force-pushed the pointwise-affine-simplified branch from a99644f to 308edd0 Compare January 5, 2026 20:22
@GregoryComer (Member) left a comment

Thanks for putting this up. I have a few concerns with the soundness of the matching logic, as it currently stands. I'm good to merge it if we can add tests to verify the cases called out in the comments.

As a nit, there are a lot of duplicated utility functions in this pass. It would be good to consolidate on the existing utils, where possible. I've left comments on most of them to mention the existing util.

Do you have performance numbers for a motivating case? I believe this should be a net benefit; I'm mostly curious about the magnitude. In the long term, zero-cost views + transpose optimizations should give us this "for free" without this pass, but I'm happy to merge this for now.

```python
            return tensor, b
        return None, None

    def _extract_add_bias(self, add_node: fx.Node, pred: fx.Node, cout: int):
```
@GregoryComer (Member) commented:

It doesn't hurt to have this code, but is this a pattern that we actually see in models? Having an explicit bias add node instead of using linear/addmm?
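
For reference, the pattern in question would look something like this (illustrative sketch only):

```python
import torch

def linear_with_explicit_bias(x, weight, bias):
    y = torch.nn.functional.linear(x, weight)  # aten.linear with bias=None
    return y + bias  # separate, explicit aten.add node carrying the bias
```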

```python
    return result


def _is_dq_node(node: fx.Node) -> bool:
```
@GregoryComer (Member) commented:

Nit: Duplicates existing is_dequant function in backends/xnnpack/utils/quant_utils.py.

```python
    return node.op == "call_function" and _op_in(node.target, DQ_OPS)


def _get_dq_source(node: fx.Node) -> Optional[fx.Node]:
```
@GregoryComer (Member) commented:

This function seems a little odd. It's only handling DQ nodes that take a direct graph input (placeholder node)? The quantized source is, in general, the output of some op node, not a placeholder.

@mergennachin (Contributor Author) replied:

I will remove the quantization logic for now.

```python
            return t.reshape(cout), bias_arg
        return None, None

    def _get_param(self, node: fx.Node) -> Optional[torch.Tensor]:
```
@GregoryComer (Member) commented:

Nit: Duplicates existing get_param_tensor function in backends/xnnpack/utils/utils.py.

```python
    .export()
    .to_edge()
    .run_passes(self.PassStage)
    .check_node_count({MM_OP: 1})  # mm exists - pass worked!
```
@GregoryComer (Member) commented:

Can you add an E2E test to verify that the linear/mm is delegated post-transform for this case? Also, what is the intent of this re-write?
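
A hedged sketch of what the requested E2E delegation check could look like with the XNNPACK Tester utilities (the exact chain and assertions in the final test may differ):

```python
# Assumes the Tester from backends/xnnpack/test/tester; MM_OP and
# PassStage are the names already used in this test file.
(
    Tester(module, example_inputs)
    .export()
    .to_edge()
    .run_passes(self.PassStage)
    .check_node_count({MM_OP: 1})  # rewrite produced an mm
    .partition()
    # After partitioning, the rewritten mm should be absorbed into the
    # XNNPACK delegate rather than remain in the graph.
    .check_count({"torch.ops.higher_order.executorch_call_delegate": 1})
    .run_method_and_compare_outputs()
)
```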

```python
def _dtype(node: fx.Node) -> torch.dtype:
    """Get dtype from node."""
    val = node.meta.get("val")
    return val.dtype if val is not None and hasattr(val, "dtype") else torch.float32
```
@GregoryComer (Member) commented:

Nit: This logic seems suspect - it defaults to float32? If it's actually a valid tensor, it should always have a dtype. Falling back to float32 silently seems error-prone.
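
Per the author's follow-up below, the eventual fix was to raise instead of defaulting. A sketch of that behavior (assumed shape, not the exact code that landed):

```python
import torch
import torch.fx as fx

def _dtype(node: fx.Node) -> torch.dtype:
    """Get dtype from node metadata; raise rather than silently guess."""
    val = node.meta.get("val")
    if val is None or not hasattr(val, "dtype"):
        raise ValueError(f"Node {node.name} has no valid tensor metadata")
    return val.dtype
```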

```python
    if val is None or not hasattr(val, "shape"):
        return None
    try:
        return tuple(int(s) for s in val.shape)
```
@GregoryComer (Member) commented:

Nit: We can use is_shape_dynamic in exir/backend/utils.py to check for symbolic dimensions, instead of relying on it throwing on dynamic dims, which is slow.

"""
return torch.empty(shape, dtype=dtype)

def _maybe_rewrap_qdq(self, node: fx.Node, reference: fx.Node) -> fx.Node:
@GregoryComer (Member) commented:

Can we either fully implement handling for quantized subgraphs (w/ tests) or clean up this logic? I'm hesitant to include a half-implementation.

@GregoryComer (Member) commented:

Additionally, it would be good to, at a minimum, add some tests verifying that the pass doesn't break the lowering when it receives a quantized graph.

@mergennachin (Contributor Author) replied:

I removed the quantization code and added a test verifying that quantized graph lowering is not broken (the pass does not match quantized graphs).

@GregoryComer (Member) commented:

Also, are we confident that the LLM-related failures are unrelated to these changes? I see some are non-CPU, but there are some CPU failures. The Phi3 job, for example, shows wrong outputs. It could be a flaky job, but it would be good to confirm this.

Copilot AI review requested due to automatic review settings January 6, 2026 17:24
@mergennachin mergennachin force-pushed the pointwise-affine-simplified branch from 308edd0 to f066580 Compare January 6, 2026 17:24
@mergennachin (Contributor Author) commented:

Addressing @GregoryComer's feedback:

  1. Re-used existing utility functions.

  2. Removed quantization-related code:
     • Removed the is_dequant import from quant_utils
     • Removed the _get_dq_source function (unused; had incorrect placeholder-only logic)
     • Removed the _maybe_rewrap_qdq method (was a no-op with a TODO comment)
     • Simplified _create_node to return the node directly without a wrapper call

  3. Added a test for the reviewer's example (lines 570-609):
     • Added test_channel_axis_not_at_last_position_should_not_match in TestMatcherNegative
     • Tests the specific pattern [1, 2, 3, 4] -> reshape(12, 2) -> linear -> reshape(1, 4, 3, 4) (see the sketch after this list)
     • Verifies the pass correctly rejects patterns where the channel axis is not in the last position and there is no permute

  4. Added quantization compatibility tests (lines 898-1048): new TestQuantizationCompatibility class with 4 tests:
     • test_quantized_nchw_linear_not_matched - explicitly verifies the pass does NOT create conv2d/mm for quantized graphs
     • test_quantized_nhwc_linear_not_matched - same for the NHWC pattern
     • test_quantized_nchw_linear_full_pipeline - verifies the full pipeline still works with quantization
     • test_quantized_linear_relu_full_pipeline - verifies linear+relu with quantization

  5. Fixed the _dtype function (lines 141-148):
     • Changed from silently defaulting to float32 to raising ValueError if the node lacks valid tensor metadata
     • Prevents bugs from being masked by a silent fallback

  6. Added an NHWC E2E delegation test (lines 732-782):
     • Added test_nhwc_delegates_to_xnnpack in TestXNNPACKIntegration
     • Verifies the linear -> mm rewrite and XNNPACK delegation for NHWC patterns
     • Includes a detailed docstring explaining the intent of the NHWC rewrite

  7. Added a weight cleanup verification test (lines 784-850):
     • Added test_original_weights_not_in_transformed_graph in TestXNNPACKIntegration
     • Verifies the original linear weight/bias placeholders are removed after the transformation
     • Confirms new conv weight placeholders are present instead
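
For clarity, here is a sketch of the negative-case pattern added in item 3, with the shapes described above (module and test names are illustrative):

```python
import torch

class ChannelNotLast(torch.nn.Module):
    # [1, 2, 3, 4] -> reshape(12, 2) -> linear -> reshape(1, 4, 3, 4):
    # after this reshape, the trailing dim of size 2 holds adjacent
    # spatial (W) elements, not the channel axis, so the pass must
    # refuse to rewrite the chain as a pointwise affine op.
    def __init__(self):
        super().__init__()
        self.linear = torch.nn.Linear(2, 4)

    def forward(self, x):  # x: [1, 2, 3, 4]
        y = x.reshape(12, 2)
        y = self.linear(y)  # (12, 4)
        return y.reshape(1, 4, 3, 4)
```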

@mergennachin mergennachin force-pushed the pointwise-affine-simplified branch from f066580 to 70d9f32 Compare January 6, 2026 17:27
Copilot AI left a comment

Pull request overview

Copilot reviewed 3 out of 3 changed files in this pull request and generated 6 comments.



Comment on lines 591 to 596
```python
            logger.debug(
                "PointwiseAffineRewritePass: _trace_forward hit depth limit %d at node %s",
                _MAX_TRACE_DEPTH,
                start.name,
            )
            return None, set(), None, None, None
```
Copilot AI commented Jan 6, 2026

The same depth limit logging pattern appears in both _trace_back and _trace_forward. Consider extracting this into a helper method to reduce code duplication and ensure consistent logging behavior.
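
A possible shape for the shared helper Copilot is suggesting (method name invented for this sketch; it assumes the pass's existing `logger` and `_MAX_TRACE_DEPTH` names):

```python
import torch.fx as fx

# Inside PointwiseAffineRewritePass (sketch):
def _log_depth_limit(self, fn_name: str, node: fx.Node) -> None:
    # Shared by _trace_back and _trace_forward when tracing gives up.
    logger.debug(
        "PointwiseAffineRewritePass: %s hit depth limit %d at node %s",
        fn_name,
        _MAX_TRACE_DEPTH,
        node.name,
    )
```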

Conservative implementation that converts pointwise Linear/MatMul
to Conv2d(1x1) or optimized MatMul. Designed to avoid false positives
at the cost of potentially missing some valid patterns.

Safety measures to prevent miscompiles:
1. Unique channel axis - bail if multiple axes could match cin
2. Single consumer path - require linear has exactly one user
3. Explicit op allowlist - no substring matching for layout ops
4. Symbolic shape rejection - only accept concrete integer shapes
5. Edge op handling - check underlying _op for edge dialect

Test results: 9/9 functional tests pass
- Pattern matching (NCHW, NHWC, transformer, separate bias)
- Negative cases (gather, spatial mixing, broken restore)
- Numerical equivalence
@mergennachin mergennachin force-pushed the pointwise-affine-simplified branch from 70d9f32 to c522126 Compare January 6, 2026 18:24
meta-codesync bot commented Jan 6, 2026

@mergennachin has imported this pull request. If you are a Meta employee, you can view this in D90197421.

Copilot AI review requested due to automatic review settings January 7, 2026 16:15
Copilot AI left a comment

Pull request overview

Copilot reviewed 2 out of 2 changed files in this pull request and generated 2 comments.



```python
    def _trace_back(self, start: fx.Node) -> Tuple[Optional[fx.Node], Set[fx.Node]]:
        """Trace backward through layout ops."""
        visited, cur = set(), start
        for _ in range(_MAX_TRACE_DEPTH):
```
Copilot AI commented Jan 7, 2026

This 'for' statement has a redundant 'else' as no 'break' is present in the body.
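
To illustrate the lint finding with a generic sketch (not the pass's actual code; `step` and `give_up` are placeholders):

```python
# `for ... else` only matters when the body can `break`; without one,
# the else block always runs, so the construct can be flattened:
for _ in range(_MAX_TRACE_DEPTH):
    step()  # never breaks
else:
    give_up()  # always reached

# Equivalent and clearer:
for _ in range(_MAX_TRACE_DEPTH):
    step()
give_up()
```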

```python
        activation = None
        seen_layout_op = False

        for _ in range(_MAX_TRACE_DEPTH):
```
Copilot AI commented Jan 7, 2026

This 'for' statement has a redundant 'else' as no 'break' is present in the body.


Labels

ciflow/trunk, CLA Signed