|
33 | 33 | CadenceWithSoftmaxQuantizer, |
34 | 34 | qconfig_A16, |
35 | 35 | qconfig_A8W8, |
| 36 | + qconfig_A8W8sym, |
36 | 37 | ) |
37 | 38 | from executorch.exir.pass_base import NodeMetadata |
38 | 39 | from parameterized import parameterized |
|
53 | 54 | # Quantizers intentionally excluded from annotation testing. |
54 | 55 | # These should be explicitly justified when added. |
55 | 56 | EXCLUDED_FROM_ANNOTATION_TESTING: set[type[CadenceQuantizer]] = { |
56 | | - CadenceDefaultQuantizer, # TODO: T247438143 Add test coverage |
57 | 57 | CadenceFusedConvReluQuantizer, # TODO: T247438151 Add test coverage |
58 | 58 | CadenceNopQuantizer, # No-op quantizer, doesn't annotate anything |
59 | 59 | CadenceW8A32MixedQuantizer, # TODO: T247438158 Add test coverage |
|
137 | 137 | # For add: both inputs are activations |
138 | 138 | [qconfig_A8W8.input_activation, qconfig_A8W8.input_activation], |
139 | 139 | ), |
| 140 | + # CadenceDefaultQuantizer test cases |
| 141 | + ( |
| 142 | + "default_matmul_A8W8", |
| 143 | + lambda self: self._build_matmul_graph(), |
| 144 | + CadenceDefaultQuantizer(), |
| 145 | + torch.ops.aten.matmul.default, |
| 146 | + qconfig_A8W8.output_activation, |
| 147 | + # For matmul: both inputs are activations |
| 148 | + [qconfig_A8W8.input_activation, qconfig_A8W8.input_activation], |
| 149 | + ), |
| 150 | + ( |
| 151 | + "default_linear_A8W8", |
| 152 | + lambda self: self._build_linear_graph(), |
| 153 | + CadenceDefaultQuantizer(), |
| 154 | + torch.ops.aten.linear.default, |
| 155 | + qconfig_A8W8.output_activation, |
| 156 | + # For linear: [input_activation, weight] |
| 157 | + [qconfig_A8W8.input_activation, qconfig_A8W8.weight], |
| 158 | + ), |
| 159 | + ( |
| 160 | + "default_conv1d_A8W8sym", |
| 161 | + lambda self: self._build_conv1d_graph(), |
| 162 | + CadenceDefaultQuantizer(), |
| 163 | + torch.ops.aten.conv1d.default, |
| 164 | + qconfig_A8W8sym.output_activation, |
| 165 | + # For conv1d: [input_activation, weight] with symmetric weights |
| 166 | + [qconfig_A8W8sym.input_activation, qconfig_A8W8sym.weight], |
| 167 | + ), |
| 168 | + ( |
| 169 | + "default_conv2d_A8W8sym", |
| 170 | + lambda self: self._build_conv2d_graph(), |
| 171 | + CadenceDefaultQuantizer(), |
| 172 | + torch.ops.aten.conv2d.default, |
| 173 | + qconfig_A8W8sym.output_activation, |
| 174 | + # For conv2d: [input_activation, weight] with symmetric weights |
| 175 | + [qconfig_A8W8sym.input_activation, qconfig_A8W8sym.weight], |
| 176 | + ), |
| 177 | + ( |
| 178 | + "default_bmm_A8W8", |
| 179 | + lambda self: self._build_bmm_graph(), |
| 180 | + CadenceDefaultQuantizer(), |
| 181 | + torch.ops.aten.bmm.default, |
| 182 | + qconfig_A8W8.output_activation, |
| 183 | + # For bmm: both inputs are activations |
| 184 | + [qconfig_A8W8.input_activation, qconfig_A8W8.input_activation], |
| 185 | + ), |
| 186 | + ( |
| 187 | + "default_relu_A8W8", |
| 188 | + lambda self: self._build_relu_graph(), |
| 189 | + CadenceDefaultQuantizer(), |
| 190 | + torch.ops.aten.relu.default, |
| 191 | + qconfig_A8W8.output_activation, |
| 192 | + # For relu: only input_activation |
| 193 | + [qconfig_A8W8.input_activation], |
| 194 | + ), |
140 | 195 | ] |
141 | 196 |
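For reference, a minimal sketch of how one of the new CadenceDefaultQuantizer cases could be exercised. This is illustrative only: it assumes the harness follows the usual PT2E convention of reading the QuantizationAnnotation stored under node.meta["quantization_annotation"]; the actual assertions live in test_quantizer_annotation further down.

    # Illustrative sketch, not the real test body. Assumes the PT2E-style
    # annotation is stored under node.meta["quantization_annotation"].
    gm, node = self._build_matmul_graph()
    CadenceDefaultQuantizer().annotate(gm)
    annotation = node.meta["quantization_annotation"]
    self.assertEqual(annotation.output_qspec, qconfig_A8W8.output_activation)
    self.assertEqual(
        [annotation.input_qspec_map[arg] for arg in node.args],
        [qconfig_A8W8.input_activation, qconfig_A8W8.input_activation],
    )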
|
142 | 197 | # Derive the set of tested quantizer classes from the test cases. |
@@ -309,6 +364,50 @@ def _build_add_graph(self) -> tuple[torch.fx.GraphModule, torch.fx.Node]: |
309 | 364 | self.assertEqual(len(add_nodes), 1, "Should find exactly one add node") |
310 | 365 | return gm, add_nodes[0] |
311 | 366 |
|
| 367 | + def _build_bmm_graph(self) -> tuple[torch.fx.GraphModule, torch.fx.Node]: |
| 368 | + """Build a simple graph with a bmm (batch matrix multiply) operation.""" |
| 369 | + builder = GraphBuilder() |
| 370 | + # BMM requires 3D tensors: (batch, n, m) @ (batch, m, p) -> (batch, n, p) |
| 371 | + x = builder.placeholder("x", torch.randn(2, 4, 8)) |
| 372 | + y = builder.placeholder("y", torch.randn(2, 8, 4)) |
| 373 | + bmm = builder.call_operator( |
| 374 | + op=torch.ops.aten.bmm.default, |
| 375 | + args=(x, y), |
| 376 | + meta=NodeMetadata( |
| 377 | + {"source_fn_stack": [("bmm", torch.ops.aten.bmm.default)]} |
| 378 | + ), |
| 379 | + ) |
| 380 | + builder.output([bmm]) |
| 381 | + gm = builder.get_graph_module() |
| 382 | + |
| 383 | + bmm_nodes = gm.graph.find_nodes( |
| 384 | + op="call_function", |
| 385 | + target=torch.ops.aten.bmm.default, |
| 386 | + ) |
| 387 | + self.assertEqual(len(bmm_nodes), 1, "Should find exactly one bmm node") |
| 388 | + return gm, bmm_nodes[0] |
| 389 | + |
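As a quick sanity reference for the tensor shapes chosen in _build_bmm_graph (illustrative only, mirroring the shape comment above):

    # (batch, n, m) @ (batch, m, p) -> (batch, n, p)
    out = torch.bmm(torch.randn(2, 4, 8), torch.randn(2, 8, 4))
    assert out.shape == (2, 4, 4)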
| 390 | + def _build_relu_graph(self) -> tuple[torch.fx.GraphModule, torch.fx.Node]: |
| 391 | + """Build a simple graph with a relu operation.""" |
| 392 | + builder = GraphBuilder() |
| 393 | + x = builder.placeholder("x", torch.randn(1, 10)) |
| 394 | + relu = builder.call_operator( |
| 395 | + op=torch.ops.aten.relu.default, |
| 396 | + args=(x,), |
| 397 | + meta=NodeMetadata( |
| 398 | + {"source_fn_stack": [("relu", torch.ops.aten.relu.default)]} |
| 399 | + ), |
| 400 | + ) |
| 401 | + builder.output([relu]) |
| 402 | + gm = builder.get_graph_module() |
| 403 | + |
| 404 | + relu_nodes = gm.graph.find_nodes( |
| 405 | + op="call_function", |
| 406 | + target=torch.ops.aten.relu.default, |
| 407 | + ) |
| 408 | + self.assertEqual(len(relu_nodes), 1, "Should find exactly one relu node") |
| 409 | + return gm, relu_nodes[0] |
| 410 | + |
312 | 411 | @parameterized.expand(QUANTIZER_ANNOTATION_TEST_CASES) |
313 | 412 | def test_quantizer_annotation( |
314 | 413 | self, |
|