
Commit d81100f

[TORCH] Added flex_attention hop function (#4366)
## Description

- Added support for PyTorch's flex_attention Higher-Order Operator (HOP) in torch-mlir.
- Implemented `Torch_HigherOrderFlexAttentionOp` with 6 operands (query, key, value, scale, return_lse, return_max_scores) and 2 optional attributes (score_mod_fn, mask_mod_fn) holding function references.
- The FX importer (`_import_hop_flex_attention`) extracts the score/mask modification functions from `get_attr` nodes using module IDs, following the while_loop HOP pattern.
- Includes TODO markers for `kernel_options` performance-tuning parameters.

> The internal call to `flex_attention_hop` in `torch.nn.attention.flex_attention` uses the `kernel_options` dict to pass the `return_lse` and `return_max_score` options (under the `OUTPUT_LOGSUMEXP` and `OUTPUT_MAX` keys, respectively). That support is implemented here, but other fine-grained controls (including blocking heuristics) are not supported yet.

- Imports flex_attention from PyTorch FX graphs into valid MLIR.

---------

Signed-off-by: Keshav Vinayak Jha <keshavvinayakjha@gmail.com>
1 parent 7712b97 commit d81100f
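For orientation, here is a minimal end-to-end sketch (not part of this commit) of the kind of program this import path targets. It assumes a recent PyTorch where `torch.export` captures `flex_attention` as its HOP and that the `torch_mlir.fx.export_and_import` helper is available; the `Attn` module, tensor shapes, and generated symbol names are illustrative only.

```python
import torch
from torch.nn.attention.flex_attention import flex_attention, create_block_mask
from torch_mlir import fx


def score_mod(score, b, h, q_idx, kv_idx):
    # Same modification as @sdpa_score0 in the round-trip test below.
    return torch.tanh(score)


def mask_mod(b, h, q_idx, kv_idx):
    # Causal predicate, same as @sdpa_mask0 below.
    return q_idx >= kv_idx


class Attn(torch.nn.Module):
    def __init__(self):
        super().__init__()
        # The BlockMask carries mask_mod; the importer keeps only a symbol
        # reference to it (the block tensors are re-derivable from mask_mod).
        self.block_mask = create_block_mask(
            mask_mod, B=None, H=None, Q_LEN=8, KV_LEN=8, device="cpu"
        )

    def forward(self, q, k, v):
        return flex_attention(q, k, v, score_mod=score_mod, block_mask=self.block_mask)


q = k = v = torch.randn(2, 4, 8, 16)
module = fx.export_and_import(Attn(), q, k, v)
# Expect a torch.hop_flex_attention op whose score_mod_fn / mask_mod_fn attributes
# reference private funcs similar to @sdpa_score0 / @sdpa_mask0 in the test below.
print(module)
```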

File tree

5 files changed (+430, -1 lines)

include/torch-mlir/Dialect/Torch/IR/TorchOps.td

Lines changed: 67 additions & 0 deletions
@@ -1442,4 +1442,71 @@ def Torch_OnnxVariantRotaryEmbeddingOp: Torch_Op<"onnx.rotary_embedding", [
   let hasCustomAssemblyFormat = 1;
 }

+//===----------------------------------------------------------------------===//
+// FlexAttention operation
+
+// NOTE: This op is manually defined because flex_attention exists in
+// PyTorch's Python API (torch.nn.attention.flex_attention) but is not yet
+// registered in PyTorch's JIT operator registry. The update_torch_ods.sh script
+// validates against the JIT registry, so it cannot auto-generate this op.
+// Once PyTorch adds flex_attention to the JIT registry, this can be moved to
+// the auto-generated section.
+//===----------------------------------------------------------------------===//
+def Torch_HigherOrderFlexAttentionOp : Torch_Op<"hop_flex_attention", [
+    AllowsTypeRefinement,
+    HasValueSemantics,
+    ReadOnly
+  ]> {
+  let summary = "Computes the flex_attention operation (1-1 with torch._higher_order_ops.flex_attention)";
+  let description = [{
+    FlexAttention operation with flexible block-sparse attention patterns.
+
+    Args:
+      query: Query tensor [B, H, M, K]
+      key: Key tensor [B, H, N, K]
+      value: Value tensor [B, H, N, Ev]
+      scale: Optional float for scaling attention scores (None means 1/sqrt(head_dim))
+      return_lse: Bool to return log-sum-exp values
+
+    Attributes:
+      score_mod_fn: Optional function symbol reference for score modification
+      mask_mod_fn: Optional function symbol reference for mask modification
+
+    # TODO: kernel_options: Dict attributes for performance tuning (block_size, num_warps, etc.)
+
+    Returns:
+      output: Result tensor [B, H, M, Ev]
+      logsumexp: Optional log-sum-exp tensor [B, H, M] (if return_lse=True)
+      max_scores: Optional max-scores tensor [B, H, M] (if return_max_scores=True)
+  }];
+
+  let arguments = (ins
+    AnyTorchTensorType:$query,
+    AnyTorchTensorType:$key,
+    AnyTorchTensorType:$value,
+    AnyTorchOptionalFloatType:$scale,
+    Torch_BoolType:$return_lse,
+    Torch_BoolType:$return_max_scores,
+    OptionalAttr<FlatSymbolRefAttr>:$score_mod_fn,
+    OptionalAttr<FlatSymbolRefAttr>:$mask_mod_fn
+  );
+
+  let results = (outs
+    AnyTorchTensorType:$output,
+    AnyTorchOptionalTensorType:$logsumexp,
+    AnyTorchOptionalTensorType:$max_scores
+  );
+
+  let hasCustomAssemblyFormat = 1;
+  let extraClassDefinition = [{
+    ParseResult HigherOrderFlexAttentionOp::parse(OpAsmParser &parser, OperationState &result) {
+      return parseDefaultTorchOp(parser, result, 6, 3);
+    }
+    void HigherOrderFlexAttentionOp::print(OpAsmPrinter &printer) {
+      printDefaultTorchOp(printer, *this, 6, 3);
+    }
+  }];
+  let hasVerifier = 1;
+}
+
 #endif // TORCH_OPS
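To make the optional results concrete, here is an eager-mode sketch (not part of this commit) of the shape contract documented above. It assumes PyTorch 2.5+ where `flex_attention` exposes the public `return_lse` flag; the max-scores output is only reachable through the HOP-level `kernel_options`, as noted in the commit description.

```python
import torch
from torch.nn.attention.flex_attention import flex_attention

# [B, H, M, K] / [B, H, N, K] / [B, H, N, Ev] with B=2, H=4, M=N=8, K=Ev=16.
q = torch.randn(2, 4, 8, 16)
k = torch.randn(2, 4, 8, 16)
v = torch.randn(2, 4, 8, 16)

out, lse = flex_attention(q, k, v, return_lse=True)
print(out.shape)  # torch.Size([2, 4, 8, 16])  -> $output    [B, H, M, Ev]
print(lse.shape)  # torch.Size([2, 4, 8])      -> $logsumexp [B, H, M]
```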

lib/Dialect/Torch/IR/TorchOps.cpp

Lines changed: 40 additions & 0 deletions
@@ -235,6 +235,46 @@ static Value getScalarFloatValue(Value input, Location loc,
   return nullptr;
 }

+//===----------------------------------------------------------------------===//
+// HigherOrderFlexAttentionOp
+//===----------------------------------------------------------------------===//
+
+LogicalResult HigherOrderFlexAttentionOp::verify() {
+  static constexpr int kAttentionRank = 4;
+  Value query = getQuery();
+  Value key = getKey();
+  Value value = getValue();
+
+  if (!isa<Torch::BoolType>(getReturnLse().getType())) {
+    return emitError() << "expected return_lse to be a bool type";
+  }
+  if (!isa<Torch::BoolType>(getReturnMaxScores().getType())) {
+    return emitError() << "expected return_max_scores to be a bool type";
+  }
+
+  auto queryType = dyn_cast<ValueTensorType>(query.getType());
+  auto keyType = dyn_cast<ValueTensorType>(key.getType());
+  auto valueType = dyn_cast<ValueTensorType>(value.getType());
+
+  if (!queryType || !keyType || !valueType || !queryType.hasSizes() ||
+      !keyType.hasSizes() || !valueType.hasSizes()) {
+    return emitError() << "expected input(s) types having sizes";
+  }
+
+  ArrayRef<int64_t> queryShape = queryType.getSizes();
+
+  // Query shape: [B, H, M, E].
+  if (queryShape.size() != kAttentionRank) {
+    return emitError() << "expected 4D query tensor";
+  }
+  // Check if the element type is a float.
+  if (!isa<mlir::FloatType>(queryType.getDtype())) {
+    return emitError() << "expected float element type";
+  }
+
+  return success();
+}
+
 //===----------------------------------------------------------------------===//
 // MethodOp
 //===----------------------------------------------------------------------===//

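For readers less familiar with the MLIR C++ API, the verifier above amounts to the following checks, shown here as a rough Python paraphrase on eager values (illustrative only, not code from this commit).

```python
import torch


def check_flex_attention_inputs(query, key, value, return_lse, return_max_scores):
    """Rough mirror of HigherOrderFlexAttentionOp::verify() on eager values."""
    # return_lse / return_max_scores must be booleans.
    if not isinstance(return_lse, bool) or not isinstance(return_max_scores, bool):
        raise TypeError("expected return_lse / return_max_scores to be bools")
    # query, key, value must be tensors with known sizes.
    for name, t in (("query", query), ("key", key), ("value", value)):
        if not isinstance(t, torch.Tensor):
            raise TypeError(f"expected {name} to be a tensor with sizes")
    # The query must be rank-4 ([B, H, M, E]) with a floating-point dtype.
    if query.dim() != 4:
        raise ValueError("expected 4D query tensor")
    if not query.dtype.is_floating_point:
        raise ValueError("expected float element type")
```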
python/torch_mlir/extras/fx_importer.py

Lines changed: 145 additions & 1 deletion
@@ -1905,6 +1905,150 @@ def _import_hop_auto_functionalized(
         for i, value in enumerate(operation.results):
             self.bind_node_value(node, value, i + bind_none)

+    def _import_hop_flex_attention(
+        self, loc: Location, node: torch_fx.Node, hop: HigherOrderOperator
+    ):
+        """Imports the torch._higher_order_ops.flex_attention HOP.
+
+        Args format: (query, key, value, score_mod, block_mask, scale, kernel_options, ...)
+        - query, key, value: Attention input tensors
+        - score_mod: Optional submodule/callable for score modification (imported as function)
+        - block_mask: Optional BlockMask tuple containing mask_mod function and runtime tensors
+        - scale: Optional float for attention score scaling
+        - kernel_options: Optional Dict of performance tuning options:
+            - return_lse: Boolean for whether to return the log-sum-exp tensor
+
+        This creates a call to hop_flex_attention with function symbol references for
+        score_mod and mask_mod.
+        """
+        # flex_attention HOP args from PyTorch:
+        # (query, key, value, score_mod, block_mask, scale, kernel_options, ...)
+        (
+            query_arg,
+            key_arg,
+            value_arg,
+            score_mod_arg,
+            block_mask_arg,
+            scale_arg,
+            kernel_options,
+        ) = node.args[:7]
+
+        # Import Q, K, V tensors
+        query = self._import_argument(loc, query_arg, None)
+        key = self._import_argument(loc, key_arg, None)
+        value = self._import_argument(loc, value_arg, None)
+
+        score_mod_ref = None
+        if score_mod_arg is not None and isinstance(score_mod_arg, torch_fx.Node):
+            assert (
+                score_mod_arg.op == "get_attr"
+            ), f"Expected get_attr for score_mod, got {score_mod_arg.op}"
+            root_module = node.graph.owning_module
+            score_mod_module = getattr(root_module, score_mod_arg.target, None)
+            if score_mod_module is not None:
+                score_mod_func_name = self.fx_importer._graph_module_to_func_name[
+                    id(score_mod_module)
+                ]
+                score_mod_ref = FlatSymbolRefAttr.get(score_mod_func_name)
+
+        # Handle block_mask: extract only mask_mod function reference
+        # Note: BlockMask contains runtime tensors (kv_num_blocks, kv_indices, etc.)
+        # that are materialized by evaluating mask_mod(b, h, q_idx, kv_idx).
+        mask_mod_ref = None
+        if block_mask_arg is not None and isinstance(block_mask_arg, tuple):
+            root_module = node.graph.owning_module
+            # The mask_mod function is the last element in the BlockMask tuple
+            mask_mod_arg = block_mask_arg[-1]
+            if mask_mod_arg is not None and isinstance(mask_mod_arg, torch_fx.Node):
+                assert (
+                    mask_mod_arg.op == "get_attr"
+                ), f"Expected get_attr for mask_mod, got {mask_mod_arg.op}"
+                mask_mod_module = getattr(root_module, mask_mod_arg.target, None)
+                if mask_mod_module is not None:
+                    mask_mod_func_name = self.fx_importer._graph_module_to_func_name[
+                        id(mask_mod_module)
+                    ]
+                    mask_mod_ref = FlatSymbolRefAttr.get(mask_mod_func_name)
+
+        # Import scale (float or None)
+        if scale_arg is None:
+            scale = Operation.create(
+                "torch.constant.none",
+                results=[self._cc.torch_none_type],
+                loc=loc,
+            ).result
+        elif isinstance(scale_arg, (int, float)):
+            with loc:
+                scale = _make_constant_op(
+                    "torch.constant.float",
+                    FloatAttr.get_f64(float(scale_arg)),
+                    self._cc.torch_float_type,
+                ).result
+        else:
+            scale = self._import_argument(loc, scale_arg, None)
+
+        # Determine result types from node metadata
+        node_val = node.meta.get("val")
+        if isinstance(node_val, (list, tuple)) and len(node_val) >= 2:
+            # flex_attention returns (output, logsumexp)
+            result_types = [self._cc.value_info_to_type(v) for v in node_val]
+            self._multi_result_nodes.add(node)
+        else:
+            # Single output
+            result_types = [self._cc.node_val_to_type(node)]
+
+        # Extract OUTPUT_LOGSUMEXP and OUTPUT_MAX from kernel_options
+        with loc:
+            return_lse = _make_constant_op(
+                "torch.constant.bool",
+                self._cc.integer_attr(
+                    bool(kernel_options.get("OUTPUT_LOGSUMEXP", 0)), 1
+                ),
+                self._cc.torch_bool_type,
+            ).result
+            return_max_scores = _make_constant_op(
+                "torch.constant.bool",
+                self._cc.integer_attr(bool(kernel_options.get("OUTPUT_MAX", 0)), 1),
+                self._cc.torch_bool_type,
+            ).result
+
+        # Build operands for torch.hop_flex_attention.
+        # Op expects exactly 6 operands: query, key, value, scale, return_lse, return_max_scores.
+        # Note: score_mod_fn and mask_mod_fn go as ATTRIBUTES, not operands.
+        # Note: block_mask tensors are handled by mask_mod_fn, not passed as operands.
+
+        flat_operands = [
+            query,
+            key,
+            value,
+            scale,
+            return_lse,
+            return_max_scores,
+        ]
+
+        # Build attributes with function references
+        # Only include attributes if they're not None (OptionalAttr in TableGen)
+        attributes = {}
+        if score_mod_ref is not None:
+            attributes["score_mod_fn"] = score_mod_ref
+        if mask_mod_ref is not None:
+            attributes["mask_mod_fn"] = mask_mod_ref
+
+        operation = Operation.create(
+            "torch.hop_flex_attention",
+            results=result_types,
+            operands=flat_operands,
+            attributes=attributes if attributes else None,
+            loc=loc,
+        )
+        # Bind results
+        if len(result_types) > 1:
+            self._multi_result_nodes.add(node)
+            for i, value in enumerate(operation.results):
+                self.bind_node_value(node, value, i)
+        else:
+            self.bind_node_value(node, operation.results[0])
+
     def _import_torch_op_overload(
         self,
         loc: Location,
@@ -1932,7 +2076,7 @@ def _import_torch_op_overload(
         # torch dynamo where it emits the Tensor variant of ops even when processing
         # scalar arguments, therefore we retrieve the schema as well so that we
         # consume the correct typing information when subsequently importing the
-        # function arguments and result types
+        # function arguments and result types.
         # i.e. the code below is basically doing `schema = torch.ops.aten.my_op.Scalar._schema`
         op_attrs = mlir_op_name.split(".")
         op_overload = getattr(torch, "ops")
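To see the argument layout that `_import_hop_flex_attention` unpacks, one can print the flex_attention node of an exported graph. A minimal sketch, not from this commit: it assumes `torch.export` on the installed PyTorch captures `flex_attention` as a HOP, and the trailing args (and whether `kernel_options` carries `OUTPUT_LOGSUMEXP`/`OUTPUT_MAX`) can vary across versions, which is why the importer only unpacks `node.args[:7]`.

```python
import torch
from torch.nn.attention.flex_attention import flex_attention


class Attn(torch.nn.Module):
    def forward(self, q, k, v):
        return flex_attention(q, k, v, score_mod=lambda s, b, h, qi, ki: torch.tanh(s))


q = k = v = torch.randn(2, 4, 8, 16)
ep = torch.export.export(Attn(), (q, k, v))

for node in ep.graph.nodes:
    if node.op == "call_function" and "flex_attention" in str(node.target):
        # Expected: (query, key, value, score_mod, block_mask, scale, kernel_options, ...)
        # where score_mod is a get_attr node pointing at a lifted submodule.
        print(node.target)
        print(node.args)
```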

test/Dialect/Torch/ops.mlir

Lines changed: 25 additions & 0 deletions
@@ -205,3 +205,28 @@ func.func @torch.aten.fake_quantize_per_tensor_affine.tensor_qparams (%arg0: !to
   %1 = torch.aten.fake_quantize_per_tensor_affine.tensor_qparams %arg0, %arg1, %arg2, %int0, %int255 : !torch.vtensor<[3,3],f32>, !torch.vtensor<[3],f32>, !torch.vtensor<[3],si32>, !torch.int, !torch.int -> !torch.vtensor<[3,3],f32>
   return %1 : !torch.vtensor<[3,3],f32>
 }
+
+// Round trip test for flex_attention.
+func.func private @sdpa_score0(%arg0: !torch.vtensor<[],f32>, %arg1: !torch.vtensor<[],si32>, %arg2: !torch.vtensor<[],si32>, %arg3: !torch.vtensor<[],si32>, %arg4: !torch.vtensor<[],si32>) -> !torch.vtensor<[],f32> {
+  %5 = torch.aten.tanh %arg0 : !torch.vtensor<[],f32> -> !torch.vtensor<[],f32>
+  return %5 : !torch.vtensor<[],f32>
+}
+
+func.func private @sdpa_mask0(%arg0: !torch.vtensor<[],si32>, %arg1: !torch.vtensor<[],si32>, %arg2: !torch.vtensor<[],si32>, %arg3: !torch.vtensor<[],si32>) -> !torch.vtensor<[],i1> {
+  %0 = torch.aten.ge.Tensor %arg2, %arg3 : !torch.vtensor<[],si32>, !torch.vtensor<[],si32> -> !torch.vtensor<[],i1>
+  return %0 : !torch.vtensor<[],i1>
+}
+
+// CHECK-LABEL: func.func @torch.hop_flex_attention
+func.func @torch.hop_flex_attention (%arg0: !torch.vtensor<[2,4,8,16],f32>, %arg1: !torch.vtensor<[2,4,8,16],f32>, %arg2: !torch.vtensor<[2,4,8,16],f32>) -> (!torch.vtensor<[2,4,8,16],f32>, !torch.vtensor<[2,4,8],f32>, !torch.vtensor<[2,4,8],f32>) {
+  %float1.0 = torch.constant.float 1.000000e+00
+  %false_0 = torch.constant.bool false
+  // CHECK: %[[FLOAT:.*]] = torch.constant.float 1.000000e+00
+  // CHECK: %[[FALSE:.*]] = torch.constant.bool false
+  // CHECK: torch.hop_flex_attention %arg0, %arg1, %arg2, %[[FLOAT]], %[[FALSE]], %[[FALSE]]
+  // CHECK-SAME: {mask_mod_fn = @sdpa_mask0, score_mod_fn = @sdpa_score0}
+  // CHECK-SAME: : !torch.vtensor<[2,4,8,16],f32>, !torch.vtensor<[2,4,8,16],f32>, !torch.vtensor<[2,4,8,16],f32>, !torch.float, !torch.bool
+  // CHECK-SAME: -> !torch.vtensor<[2,4,8,16],f32>, !torch.vtensor<[2,4,8],f32>
+  %output, %logsumexp, %maxscore = torch.hop_flex_attention %arg0, %arg1, %arg2, %float1.0, %false_0, %false_0 {mask_mod_fn = @sdpa_mask0, score_mod_fn = @sdpa_score0} : !torch.vtensor<[2,4,8,16],f32>, !torch.vtensor<[2,4,8,16],f32>, !torch.vtensor<[2,4,8,16],f32>, !torch.float, !torch.bool, !torch.bool -> !torch.vtensor<[2,4,8,16],f32>, !torch.vtensor<[2,4,8],f32>, !torch.vtensor<[2,4,8],f32>
+  return %output, %logsumexp, %maxscore : !torch.vtensor<[2,4,8,16],f32>, !torch.vtensor<[2,4,8],f32>, !torch.vtensor<[2,4,8],f32>
+}
