Skip to content

Commit bbbe43a

Browse files
RahulC7 authored and facebook-github-bot committed
Adding Tests for CadenceDefaultQuantizer (pytorch#16364)
Summary: As title. Note: we finish off testing this in the next diff; we don't test all the patterns in this diff because some changes still need to be made. Reviewed By: zonglinpeng, hsharma35. Differential Revision: D88899457
1 parent 63f41ea commit bbbe43a

File tree

1 file changed

+100
-1
lines changed

1 file changed

+100
-1
lines changed

backends/cadence/aot/tests/test_quantizer_ops.py

Lines changed: 100 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,7 @@
3333
CadenceWithSoftmaxQuantizer,
3434
qconfig_A16,
3535
qconfig_A8W8,
36+
qconfig_A8W8sym,
3637
)
3738
from executorch.exir.pass_base import NodeMetadata
3839
from parameterized import parameterized
@@ -53,7 +54,6 @@
5354
# Quantizers intentionally excluded from annotation testing.
5455
# These should be explicitly justified when added.
5556
EXCLUDED_FROM_ANNOTATION_TESTING: set[type[CadenceQuantizer]] = {
56-
CadenceDefaultQuantizer, # TODO: T247438143 Add test coverage
5757
CadenceFusedConvReluQuantizer, # TODO: T247438151 Add test coverage
5858
CadenceNopQuantizer, # No-op quantizer, doesn't annotate anything
5959
CadenceW8A32MixedQuantizer, # TODO: T247438158 Add test coverage
@@ -137,6 +137,61 @@
137137
# For add: both inputs are activations
138138
[qconfig_A8W8.input_activation, qconfig_A8W8.input_activation],
139139
),
140+
# CadenceDefaultQuantizer test cases
141+
(
142+
"default_matmul_A8W8",
143+
lambda self: self._build_matmul_graph(),
144+
CadenceDefaultQuantizer(),
145+
torch.ops.aten.matmul.default,
146+
qconfig_A8W8.output_activation,
147+
# For matmul: both inputs are activations
148+
[qconfig_A8W8.input_activation, qconfig_A8W8.input_activation],
149+
),
150+
(
151+
"default_linear_A8W8",
152+
lambda self: self._build_linear_graph(),
153+
CadenceDefaultQuantizer(),
154+
torch.ops.aten.linear.default,
155+
qconfig_A8W8.output_activation,
156+
# For linear: [input_activation, weight]
157+
[qconfig_A8W8.input_activation, qconfig_A8W8.weight],
158+
),
159+
(
160+
"default_conv1d_A8W8sym",
161+
lambda self: self._build_conv1d_graph(),
162+
CadenceDefaultQuantizer(),
163+
torch.ops.aten.conv1d.default,
164+
qconfig_A8W8sym.output_activation,
165+
# For conv1d: [input_activation, weight] with symmetric weights
166+
[qconfig_A8W8sym.input_activation, qconfig_A8W8sym.weight],
167+
),
168+
(
169+
"default_conv2d_A8W8sym",
170+
lambda self: self._build_conv2d_graph(),
171+
CadenceDefaultQuantizer(),
172+
torch.ops.aten.conv2d.default,
173+
qconfig_A8W8sym.output_activation,
174+
# For conv2d: [input_activation, weight] with symmetric weights
175+
[qconfig_A8W8sym.input_activation, qconfig_A8W8sym.weight],
176+
),
177+
(
178+
"default_bmm_A8W8",
179+
lambda self: self._build_bmm_graph(),
180+
CadenceDefaultQuantizer(),
181+
torch.ops.aten.bmm.default,
182+
qconfig_A8W8.output_activation,
183+
# For bmm: both inputs are activations
184+
[qconfig_A8W8.input_activation, qconfig_A8W8.input_activation],
185+
),
186+
(
187+
"default_relu_A8W8",
188+
lambda self: self._build_relu_graph(),
189+
CadenceDefaultQuantizer(),
190+
torch.ops.aten.relu.default,
191+
qconfig_A8W8.output_activation,
192+
# For relu: only input_activation
193+
[qconfig_A8W8.input_activation],
194+
),
140195
]
141196

142197
# Derive the set of tested quantizer classes from the test cases.
@@ -309,6 +364,50 @@ def _build_add_graph(self) -> tuple[torch.fx.GraphModule, torch.fx.Node]:
309364
self.assertEqual(len(add_nodes), 1, "Should find exactly one add node")
310365
return gm, add_nodes[0]
311366

367+
def _build_bmm_graph(self) -> tuple[torch.fx.GraphModule, torch.fx.Node]:
    """Build a minimal graph containing a single aten.bmm node.

    Returns:
        The graph module together with the bmm node, so the annotation
        test can inspect the node's quantization metadata.
    """
    gb = GraphBuilder()
    # bmm requires 3D operands: (batch, n, m) @ (batch, m, p) -> (batch, n, p).
    lhs = gb.placeholder("x", torch.randn(2, 4, 8))
    rhs = gb.placeholder("y", torch.randn(2, 8, 4))
    node = gb.call_operator(
        op=torch.ops.aten.bmm.default,
        args=(lhs, rhs),
        # source_fn_stack is what the quantizer's pattern matcher keys on.
        meta=NodeMetadata(
            {"source_fn_stack": [("bmm", torch.ops.aten.bmm.default)]}
        ),
    )
    gb.output([node])
    gm = gb.get_graph_module()

    matches = gm.graph.find_nodes(
        op="call_function",
        target=torch.ops.aten.bmm.default,
    )
    self.assertEqual(len(matches), 1, "Should find exactly one bmm node")
    return gm, matches[0]
390+
def _build_relu_graph(self) -> tuple[torch.fx.GraphModule, torch.fx.Node]:
    """Build a minimal graph containing a single aten.relu node.

    Returns:
        The graph module together with the relu node, so the annotation
        test can inspect the node's quantization metadata.
    """
    gb = GraphBuilder()
    inp = gb.placeholder("x", torch.randn(1, 10))
    node = gb.call_operator(
        op=torch.ops.aten.relu.default,
        args=(inp,),
        # source_fn_stack is what the quantizer's pattern matcher keys on.
        meta=NodeMetadata(
            {"source_fn_stack": [("relu", torch.ops.aten.relu.default)]}
        ),
    )
    gb.output([node])
    gm = gb.get_graph_module()

    matches = gm.graph.find_nodes(
        op="call_function",
        target=torch.ops.aten.relu.default,
    )
    self.assertEqual(len(matches), 1, "Should find exactly one relu node")
    return gm, matches[0]
312411
@parameterized.expand(QUANTIZER_ANNOTATION_TEST_CASES)
313412
def test_quantizer_annotation(
314413
self,

0 commit comments

Comments
 (0)