Add PointwiseAffineRewritePass for XNNPACK #16436
Conversation
🔗 Helpful Links: 🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/executorch/16436
Note: Links to docs will display an error until the docs builds have been completed.
❌ 1 New Failure, 1 Cancelled Job, 2 Unrelated Failures as of commit ae870b3 with merge base 09b5bdb:
- NEW FAILURE - The following job has failed:
- CANCELLED JOB - The following job was cancelled. Please retry:
- FLAKY - The following job failed but was likely due to flakiness present on trunk:
- UNSTABLE - The following job is marked as unstable, possibly due to flakiness on trunk:

This comment was automatically generated by Dr. CI and updates every 15 minutes.
Pull request overview
This PR adds a new optimization pass that converts pointwise Linear/MatMul operations to Conv2d(1x1) or optimized MatMul operations, reducing transpose overhead in vision and transformer models. The implementation takes a conservative approach with multiple safety measures to prevent miscompilation.
Key Changes:
- New `PointwiseAffineRewritePass` that pattern-matches and rewrites pointwise affine operations (the targeted pattern is illustrated in the sketch below)
- NCHW tensors (rank 4, channel axis 1) are converted to Conv2d(1x1), eliminating all transposes
- Other patterns are converted to optimized MatMul with explicit layout operations
- Comprehensive test suite covering positive cases, negative cases, and numerical equivalence
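To make the targeted pattern concrete, here is a minimal, hypothetical sketch (not code from the PR) of the NCHW case and the equivalent 1x1 convolution it can be rewritten to:

```python
import torch
import torch.nn.functional as F


# Hypothetical "before" pattern: a pointwise Linear applied to an NCHW tensor
# by permuting the channel axis to the last position and back.
def before(x: torch.Tensor, w: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
    # x: (N, C_in, H, W), w: (C_out, C_in), b: (C_out,)
    y = x.permute(0, 2, 3, 1)          # NCHW -> NHWC
    y = F.linear(y, w, b)              # pointwise affine over the channel axis
    return y.permute(0, 3, 1, 2)       # NHWC -> NCHW


# Rough "after" shape: the same computation expressed as a 1x1 convolution,
# with no transposes at all.
def after(x: torch.Tensor, w: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
    return F.conv2d(x, w.view(*w.shape, 1, 1), b)


x = torch.randn(1, 8, 4, 4)
w = torch.randn(16, 8)
b = torch.randn(16)
assert torch.allclose(before(x, w, b), after(x, w, b), atol=1e-5)
```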
Reviewed changes
Copilot reviewed 3 out of 3 changed files in this pull request and generated 11 comments.
| File | Description |
|---|---|
| backends/xnnpack/_passes/pointwise_affine_pass.py | Core implementation with matcher and lowering logic, includes safety measures for unique channel axis detection, single consumer validation, and explicit op allowlists |
| backends/xnnpack/test/passes/test_pointwise_affine.py | Test suite covering NCHW/NHWC patterns, transformer patterns, bias handling, negative cases, and numerical equivalence validation |
| backends/xnnpack/_passes/__init__.py | Integration of the new pass into the XNNPACK pass pipeline, positioned after ConvertToLinearPass and before FuseBatchNormPass |
        Future implementation would:
        - Detect if reference came from a DQ node
        - Insert matching Q/DQ around the new node
        - Preserve quantization parameters
        """
        # TODO: Implement when quantization support is needed
        # if _is_dq_node(reference):
        #     ... insert Q/DQ pair around node ...
Copilot AI commented on Jan 5, 2026:
The TODO comment references future quantization support but provides incomplete implementation guidance. Consider either removing this placeholder comment or providing a more concrete specification of what the implementation should look like, including which Q/DQ ops to handle and how to preserve quantization parameters.
Suggested change:

        Expected future implementation:
        - Detect whether ``reference`` is the output of a dequantize (DQ) op
          or has an immediately preceding DQ op, e.g.:
            * ``torch.ops.aten.dequantize.default``
            * backend-specific variants such as
              ``torch.ops.quantized_decomposed.dequantize_per_tensor.default``
              or ``dequantize_per_channel``.
        - Detect whether ``reference``'s users include a quantize (Q) op, e.g.:
            * ``torch.ops.quantized_decomposed.quantize_per_tensor.default``
            * ``quantize_per_channel`` or other backend-specific Q ops.
        - Recreate an equivalent Q/DQ pattern around ``node`` so that:
            * The input to ``node`` is the output of a cloned DQ op whose
              scale, zero_point, and dtype are copied from the original DQ.
            * The outputs of ``node`` feed into cloned Q ops with quantization
              parameters preserved from the original Q nodes.
        - Preserve node ordering and metadata so that observer / calibration
          information remains valid and downstream passes see an equivalent
          quantization structure.

        A sketch of the intended logic (for future implementers) could be:

        .. code-block:: python

            dq_node = _get_preceding_dq_node(reference)
            q_users = _get_q_users(reference)
            if dq_node is not None:
                # Clone DQ with same quantization parameters
                new_dq = graph.call_function(
                    dq_node.target,
                    dq_node.args,
                    dq_node.kwargs,
                )
                new_dq.meta = copy.deepcopy(dq_node.meta)
                # Wire new_dq to feed node instead of reference
                node.replace_input_with(reference, new_dq)
            for q in q_users:
                new_q = graph.call_function(
                    q.target,
                    (node,),
                    q.kwargs,
                )
                new_q.meta = copy.deepcopy(q.meta)
                q.replace_all_uses_with(new_q)

        The helper functions (``_get_preceding_dq_node``, ``_get_q_users``)
        are intentionally not implemented here; they should encapsulate
        backend-specific knowledge of which ops are treated as Q/DQ and how
        to extract their quantization parameters.
        """
        # Quantization-aware behavior is intentionally not implemented yet.
        # When quantization support is added, this method should follow the
        # algorithm sketched above to rebuild Q/DQ adjacency around ``node``.
Force-pushed: f0e406c to a99644f (Compare)
Pull request overview
Copilot reviewed 3 out of 3 changed files in this pull request and generated 13 comments.
        Create a node with properly copied metadata.
        """
Copilot AI commented on Jan 5, 2026:
The _create_node method creates nodes with metadata but doesn't validate that new_val has compatible shape/dtype with the operation being created. If the computed new_val shape is incorrect due to a bug in the lowering logic, the metadata would be wrong but the node would still be created. Consider adding assertions to validate that the created node's expected output matches the provided new_val.
Suggested change:

        Create a node with properly copied metadata.

        We additionally validate that the provided ``new_val`` is compatible with
        the reference node's expected output (when such information is available
        in ``meta_like.meta["val"]``). This helps catch bugs in the lowering
        logic where the computed value has an unexpected shape or dtype.
        """
        # Validate shape/dtype compatibility with reference metadata, if present.
        ref_val = meta_like.meta.get("val", None)
        if isinstance(ref_val, torch.Tensor):
            # Ensure we are propagating tensor metadata and that the lowered
            # tensor agrees in shape and dtype with the reference.
            assert isinstance(
                new_val, torch.Tensor
            ), f"_create_node expected new_val to be a Tensor, got {type(new_val)}"
            assert (
                new_val.shape == ref_val.shape
            ), f"_create_node shape mismatch: new_val.shape={tuple(new_val.shape)} != ref_val.shape={tuple(ref_val.shape)}"
            assert (
                new_val.dtype == ref_val.dtype
            ), f"_create_node dtype mismatch: new_val.dtype={new_val.dtype} != ref_val.dtype={ref_val.dtype}"
Force-pushed: a99644f to 308edd0 (Compare)
GregoryComer left a comment:
Thanks for putting this up. I have a few concerns with the soundness of the matching logic, as it currently stands. I'm good to merge it if we can add tests to verify the cases called out in the comments.
As a nit, there are a lot of duplicated utility functions in this pass. It would be good to consolidate on the existing utils, where possible. I've left comments on most of them to mention the existing util.
Do you have performance numbers for a motivating case? I believe that this should be a net benefit, mostly just curious on the magnitude. In the long-term, zero cost views + transpose optimizations should give us this "for free" without this pass, but I'm happy to merge this for now.
                return tensor, b
            return None, None

        def _extract_add_bias(self, add_node: fx.Node, pred: fx.Node, cout: int):
It doesn't hurt to have this code, but is this a pattern that we actually see in models? Having an explicit bias add node instead of using linear/addmm?
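For context, a tiny hypothetical example of the two forms being compared here (a fused linear versus a bias-free matmul followed by an explicit add node):

```python
import torch
import torch.nn.functional as F

x = torch.randn(2, 8)
w = torch.randn(16, 8)
b = torch.randn(16)

# Fused form: bias handled inside linear/addmm.
fused = F.linear(x, w, b)

# Explicit form: a separate add node carries the bias, which is the
# pattern _extract_add_bias would have to recognize.
explicit = torch.matmul(x, w.t()) + b

assert torch.allclose(fused, explicit, atol=1e-5)
```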
        return result


    def _is_dq_node(node: fx.Node) -> bool:
Nit: Duplicates existing is_dequant function in backends/xnnpack/utils/quant_utils.py.
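A minimal sketch of consolidating on the shared helper; the exact signature of is_dequant (a single fx.Node argument) is assumed here:

```python
# Hypothetical consolidation: delegate to the shared utility so the set of
# recognized dequantize ops lives in quant_utils, not in this pass.
import torch.fx as fx

from executorch.backends.xnnpack.utils.quant_utils import is_dequant


def _is_dq_node(node: fx.Node) -> bool:
    return is_dequant(node)
```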
        return node.op == "call_function" and _op_in(node.target, DQ_OPS)


    def _get_dq_source(node: fx.Node) -> Optional[fx.Node]:
This function seems a little odd. It's only handling dq nodes that take a direct graph input (placeholder node)? The quantized source is, in general, the output of some op node, not a placeholder.
I will remove the quantization logic for now.
                return t.reshape(cout), bias_arg
            return None, None

        def _get_param(self, node: fx.Node) -> Optional[torch.Tensor]:
Nit: Duplicates existing get_param_tensor function in backends/xnnpack/utils/utils.py.
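Along the same lines, a sketch of reusing the shared parameter lookup; the signature (exported program plus node) and the availability of self.exported_program on the pass instance are assumptions:

```python
# Hypothetical consolidation onto the shared helper from utils.py; intended
# to live as a method on the pass class.
from executorch.backends.xnnpack.utils.utils import get_param_tensor


def _get_param(self, node):
    # Resolve the parameter/buffer tensor behind the node via the shared
    # utility instead of a pass-local reimplementation.
    return get_param_tensor(self.exported_program, node)
```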
        .export()
        .to_edge()
        .run_passes(self.PassStage)
        .check_node_count({MM_OP: 1})  # mm exists - pass worked!
Can you add an E2E test to verify that the linear/mm is delegated post-transform for this case? Also, what is the intent of this re-write?
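A rough sketch of the kind of end-to-end check being requested, using the XNNPACK Tester; the module, input shape, and exact Tester stages named here are assumptions and may need adjusting to the repo's current test helpers:

```python
import torch

from executorch.backends.xnnpack.test.tester import Tester


class PointwiseLinear(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.linear = torch.nn.Linear(8, 16)

    def forward(self, x):
        return self.linear(x)


def test_pointwise_linear_is_delegated():
    # Verify that, after the rewrite, the op is still picked up by the
    # XNNPACK partitioner and produces matching outputs.
    (
        Tester(PointwiseLinear(), (torch.randn(1, 4, 4, 8),))
        .export()
        .to_edge_transform_and_lower()
        .check_count({"torch.ops.higher_order.executorch_call_delegate": 1})
        .to_executorch()
        .run_method_and_compare_outputs()
    )
```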
    def _dtype(node: fx.Node) -> torch.dtype:
        """Get dtype from node."""
        val = node.meta.get("val")
        return val.dtype if val is not None and hasattr(val, "dtype") else torch.float32
Nit: This logic seems suspect - it defaults to float32? If it's actually a valid tensor, it should always have a dtype. Falling back to float32 silently seems error-prone.
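One way to address this, sketched below: return None when tensor metadata is missing and let the matcher bail out, rather than silently assuming float32:

```python
from typing import Optional

import torch
import torch.fx as fx


def _dtype(node: fx.Node) -> Optional[torch.dtype]:
    """Get dtype from node metadata, or None if it is unavailable."""
    val = node.meta.get("val")
    if val is None or not hasattr(val, "dtype"):
        # No tensor metadata: let the caller treat this as "don't match"
        # instead of guessing float32.
        return None
    return val.dtype
```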
        if val is None or not hasattr(val, "shape"):
            return None
        try:
            return tuple(int(s) for s in val.shape)
Nit: We can use is_shape_dynamic in exir/backend/utils.py to check for symbolic dimensions, instead of relying on it throwing on dynamic dims, which is slow.
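If the shared util turns out not to be a drop-in fit, a sketch of checking for symbolic dimensions explicitly (torch.SymInt is the symbolic-dimension type produced by export) instead of relying on an exception:

```python
from typing import Optional, Tuple

import torch


def _static_shape(val) -> Optional[Tuple[int, ...]]:
    """Return a concrete shape tuple, or None if any dim is symbolic."""
    if val is None or not hasattr(val, "shape"):
        return None
    if any(isinstance(s, torch.SymInt) for s in val.shape):
        return None  # dynamic dim: refuse to match
    return tuple(int(s) for s in val.shape)
```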
| """ | ||
| return torch.empty(shape, dtype=dtype) | ||
|
|
||
| def _maybe_rewrap_qdq(self, node: fx.Node, reference: fx.Node) -> fx.Node: |
Can we either fully implement handling for quantized subgraphs (w/ tests) or clean up this logic? I'm hesitant to include a half-implementation.
Additionally, it would be good to, at a minimum, add some tests to verify that the pass doesn't break the lowering when it receives a quantized graph.
I removed the quantization code and added a test verifying that lowering a quantized graph is not broken (the pass does not match it).
Also, are we confident that the LLM-related failures are unrelated to these changes? I see some are non-CPU, but there are some CPU failures. The Phi3 job, for example, shows wrong outputs. It could be a flaky job, but it would be good to confirm this.
Force-pushed: 308edd0 to f066580 (Compare)
Addressing @GregoryComer's feedback.
Force-pushed: f066580 to 70d9f32 (Compare)
Pull request overview
Copilot reviewed 3 out of 3 changed files in this pull request and generated 6 comments.
        logger.debug(
            "PointwiseAffineRewritePass: _trace_forward hit depth limit %d at node %s",
            _MAX_TRACE_DEPTH,
            start.name,
        )
        return None, set(), None, None, None
Copilot AI commented on Jan 6, 2026:
The same depth limit logging pattern appears in both _trace_back and _trace_forward. Consider extracting this into a helper method to reduce code duplication and ensure consistent logging behavior.
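A small sketch of the suggested extraction; the helper name and the constant's value are hypothetical:

```python
import logging

import torch.fx as fx

logger = logging.getLogger(__name__)

_MAX_TRACE_DEPTH = 32  # assumed value; the pass defines its own constant


def _log_depth_limit(direction: str, start: fx.Node) -> None:
    """Shared logging for when a trace gives up at the depth limit."""
    logger.debug(
        "PointwiseAffineRewritePass: %s hit depth limit %d at node %s",
        direction,
        _MAX_TRACE_DEPTH,
        start.name,
    )


# Call sites (sketch):
#   _log_depth_limit("_trace_back", start)
#   _log_depth_limit("_trace_forward", start)
```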
Conservative implementation that converts pointwise Linear/MatMul to Conv2d(1x1) or optimized MatMul. Designed to avoid false positives at the cost of potentially missing some valid patterns.

Safety measures to prevent miscompiles:
1. Unique channel axis - bail if multiple axes could match cin
2. Single consumer path - require linear has exactly one user
3. Explicit op allowlist - no substring matching for layout ops
4. Symbolic shape rejection - only accept concrete integer shapes
5. Edge op handling - check underlying _op for edge dialect

Test results: 9/9 functional tests pass
- Pattern matching (NCHW, NHWC, transformer, separate bias)
- Negative cases (gather, spatial mixing, broken restore)
- Numerical equivalence
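To illustrate safety measure 1 from the list above, a hypothetical helper (not the PR's actual code) showing what "bail if multiple axes could match cin" amounts to:

```python
from typing import Optional, Tuple


def _unique_channel_axis(shape: Tuple[int, ...], cin: int) -> Optional[int]:
    """Return the single axis whose size equals cin, or None if it is
    ambiguous (zero or multiple candidate axes)."""
    candidates = [i for i, s in enumerate(shape) if s == cin]
    return candidates[0] if len(candidates) == 1 else None


# Example: for shape (1, 64, 64, 32) with cin == 64, both axis 1 and axis 2
# match, so the pass refuses to rewrite; with (1, 64, 56, 56) it proceeds.
assert _unique_channel_axis((1, 64, 64, 32), 64) is None
assert _unique_channel_axis((1, 64, 56, 56), 64) == 1
```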
Force-pushed: 70d9f32 to c522126 (Compare)
@mergennachin has imported this pull request. If you are a Meta employee, you can view this in D90197421.
Pull request overview
Copilot reviewed 2 out of 2 changed files in this pull request and generated 2 comments.
    def _trace_back(self, start: fx.Node) -> Tuple[Optional[fx.Node], Set[fx.Node]]:
        """Trace backward through layout ops."""
        visited, cur = set(), start
        for _ in range(_MAX_TRACE_DEPTH):
Copilot AI commented on Jan 7, 2026:
This 'for' statement has a redundant 'else' as no 'break' is present in the body.
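For reference, a small generic sketch of the flagged pattern: without a break in the loop body, the else clause always runs, so it can simply be dropped:

```python
def _process(item):
    print("processing", item)


def _finish():
    print("done")


# Flagged pattern: a for/else where the body has no break, so the else
# clause always executes and is redundant.
def walk_with_redundant_else(items):
    for item in items:
        _process(item)
    else:  # always runs -> same as placing _finish() after the loop
        _finish()


# Equivalent, clearer form:
def walk(items):
    for item in items:
        _process(item)
    _finish()
```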
        activation = None
        seen_layout_op = False

        for _ in range(_MAX_TRACE_DEPTH):
Copilot AI commented on Jan 7, 2026:
This 'for' statement has a redundant 'else' as no 'break' is present in the body.
Converts pointwise Linear/MatMul to Conv2d(1x1) to eliminate transpose overhead in vision and transformer models. The pass pattern-matches through layout ops (permute, reshape) and replaces the entire chain with optimized ops.
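As a rough illustration of the matching strategy (tracing backward through layout-only ops to find the producing affine op), a hypothetical FX-level sketch; the op allowlist and depth limit are assumptions, not the PR's exact code:

```python
from typing import Optional, Set, Tuple

import torch
import torch.fx as fx

# Hypothetical allowlist of layout-only ops that the matcher traces through.
_LAYOUT_OPS = {
    torch.ops.aten.permute.default,
    torch.ops.aten.reshape.default,
    torch.ops.aten.view.default,
}
_MAX_TRACE_DEPTH = 8  # assumed limit


def trace_back_through_layout_ops(
    start: fx.Node,
) -> Tuple[Optional[fx.Node], Set[fx.Node]]:
    """Walk producers through layout-only ops, collecting the visited chain."""
    visited: Set[fx.Node] = set()
    cur = start
    for _ in range(_MAX_TRACE_DEPTH):
        if cur.op != "call_function" or cur.target not in _LAYOUT_OPS:
            return cur, visited  # reached a non-layout producer (e.g. linear)
        visited.add(cur)
        nxt = cur.args[0]
        if not isinstance(nxt, fx.Node):
            return None, visited  # unexpected structure: give up
        cur = nxt
    return None, visited  # depth limit hit: bail out conservatively
```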
How it works
Key features