Commit d81100f
authored
[TORCH] Added flex_attention hop function (#4366)
## Description
- Added support for PyTorch's flex_attention Higher-Order Operator in
torch-mlir.
- Implemented `Torch_AtenFlexAttentionOp` with 6 operands (query, key,
value, scale, return_lse, return_max_score) and 2 optional attributes
(score_mod_fn, mask_mod_fn) for function references.
- The FX importer (`_import_hop_flex_attention`) correctly extracts
score/mask modification functions from `get_attr` nodes using module
IDs, following the while_loop HOP pattern.
- Includes TODO markers for `kernel_options` performance tuning
parameters.
> The call to `flex_attention_hop` internally in
`torch.nn.attention.flex_attention` uses the `kernel_options` dict to
pass the `return_lse` and `return_max_score` options (`OUTPUT_LOGSUMEXP`
and `OUTPUT_MAX` key names respectively). While I've implemented that
support, other fine grained controls (including blocking heuristics) are
not supported yet.
- Imports flex_attention from PyTorch FX graphs into valid MLIR.
---------
Signed-off-by: Keshav Vinayak Jha <keshavvinayakjha@gmail.com>1 parent 7712b97 commit d81100f
File tree
5 files changed
+430
-1
lines changed- include/torch-mlir/Dialect/Torch/IR
- lib/Dialect/Torch/IR
- python/torch_mlir/extras
- test
- Dialect/Torch
- python/fx_importer
5 files changed
+430
-1
lines changed| Original file line number | Diff line number | Diff line change | |
|---|---|---|---|
| |||
1442 | 1442 | | |
1443 | 1443 | | |
1444 | 1444 | | |
| 1445 | + | |
| 1446 | + | |
| 1447 | + | |
| 1448 | + | |
| 1449 | + | |
| 1450 | + | |
| 1451 | + | |
| 1452 | + | |
| 1453 | + | |
| 1454 | + | |
| 1455 | + | |
| 1456 | + | |
| 1457 | + | |
| 1458 | + | |
| 1459 | + | |
| 1460 | + | |
| 1461 | + | |
| 1462 | + | |
| 1463 | + | |
| 1464 | + | |
| 1465 | + | |
| 1466 | + | |
| 1467 | + | |
| 1468 | + | |
| 1469 | + | |
| 1470 | + | |
| 1471 | + | |
| 1472 | + | |
| 1473 | + | |
| 1474 | + | |
| 1475 | + | |
| 1476 | + | |
| 1477 | + | |
| 1478 | + | |
| 1479 | + | |
| 1480 | + | |
| 1481 | + | |
| 1482 | + | |
| 1483 | + | |
| 1484 | + | |
| 1485 | + | |
| 1486 | + | |
| 1487 | + | |
| 1488 | + | |
| 1489 | + | |
| 1490 | + | |
| 1491 | + | |
| 1492 | + | |
| 1493 | + | |
| 1494 | + | |
| 1495 | + | |
| 1496 | + | |
| 1497 | + | |
| 1498 | + | |
| 1499 | + | |
| 1500 | + | |
| 1501 | + | |
| 1502 | + | |
| 1503 | + | |
| 1504 | + | |
| 1505 | + | |
| 1506 | + | |
| 1507 | + | |
| 1508 | + | |
| 1509 | + | |
| 1510 | + | |
| 1511 | + | |
1445 | 1512 | | |
| Original file line number | Diff line number | Diff line change | |
|---|---|---|---|
| |||
235 | 235 | | |
236 | 236 | | |
237 | 237 | | |
| 238 | + | |
| 239 | + | |
| 240 | + | |
| 241 | + | |
| 242 | + | |
| 243 | + | |
| 244 | + | |
| 245 | + | |
| 246 | + | |
| 247 | + | |
| 248 | + | |
| 249 | + | |
| 250 | + | |
| 251 | + | |
| 252 | + | |
| 253 | + | |
| 254 | + | |
| 255 | + | |
| 256 | + | |
| 257 | + | |
| 258 | + | |
| 259 | + | |
| 260 | + | |
| 261 | + | |
| 262 | + | |
| 263 | + | |
| 264 | + | |
| 265 | + | |
| 266 | + | |
| 267 | + | |
| 268 | + | |
| 269 | + | |
| 270 | + | |
| 271 | + | |
| 272 | + | |
| 273 | + | |
| 274 | + | |
| 275 | + | |
| 276 | + | |
| 277 | + | |
238 | 278 | | |
239 | 279 | | |
240 | 280 | | |
| |||
| Original file line number | Diff line number | Diff line change | |
|---|---|---|---|
| |||
1905 | 1905 | | |
1906 | 1906 | | |
1907 | 1907 | | |
| 1908 | + | |
| 1909 | + | |
| 1910 | + | |
| 1911 | + | |
| 1912 | + | |
| 1913 | + | |
| 1914 | + | |
| 1915 | + | |
| 1916 | + | |
| 1917 | + | |
| 1918 | + | |
| 1919 | + | |
| 1920 | + | |
| 1921 | + | |
| 1922 | + | |
| 1923 | + | |
| 1924 | + | |
| 1925 | + | |
| 1926 | + | |
| 1927 | + | |
| 1928 | + | |
| 1929 | + | |
| 1930 | + | |
| 1931 | + | |
| 1932 | + | |
| 1933 | + | |
| 1934 | + | |
| 1935 | + | |
| 1936 | + | |
| 1937 | + | |
| 1938 | + | |
| 1939 | + | |
| 1940 | + | |
| 1941 | + | |
| 1942 | + | |
| 1943 | + | |
| 1944 | + | |
| 1945 | + | |
| 1946 | + | |
| 1947 | + | |
| 1948 | + | |
| 1949 | + | |
| 1950 | + | |
| 1951 | + | |
| 1952 | + | |
| 1953 | + | |
| 1954 | + | |
| 1955 | + | |
| 1956 | + | |
| 1957 | + | |
| 1958 | + | |
| 1959 | + | |
| 1960 | + | |
| 1961 | + | |
| 1962 | + | |
| 1963 | + | |
| 1964 | + | |
| 1965 | + | |
| 1966 | + | |
| 1967 | + | |
| 1968 | + | |
| 1969 | + | |
| 1970 | + | |
| 1971 | + | |
| 1972 | + | |
| 1973 | + | |
| 1974 | + | |
| 1975 | + | |
| 1976 | + | |
| 1977 | + | |
| 1978 | + | |
| 1979 | + | |
| 1980 | + | |
| 1981 | + | |
| 1982 | + | |
| 1983 | + | |
| 1984 | + | |
| 1985 | + | |
| 1986 | + | |
| 1987 | + | |
| 1988 | + | |
| 1989 | + | |
| 1990 | + | |
| 1991 | + | |
| 1992 | + | |
| 1993 | + | |
| 1994 | + | |
| 1995 | + | |
| 1996 | + | |
| 1997 | + | |
| 1998 | + | |
| 1999 | + | |
| 2000 | + | |
| 2001 | + | |
| 2002 | + | |
| 2003 | + | |
| 2004 | + | |
| 2005 | + | |
| 2006 | + | |
| 2007 | + | |
| 2008 | + | |
| 2009 | + | |
| 2010 | + | |
| 2011 | + | |
| 2012 | + | |
| 2013 | + | |
| 2014 | + | |
| 2015 | + | |
| 2016 | + | |
| 2017 | + | |
| 2018 | + | |
| 2019 | + | |
| 2020 | + | |
| 2021 | + | |
| 2022 | + | |
| 2023 | + | |
| 2024 | + | |
| 2025 | + | |
| 2026 | + | |
| 2027 | + | |
| 2028 | + | |
| 2029 | + | |
| 2030 | + | |
| 2031 | + | |
| 2032 | + | |
| 2033 | + | |
| 2034 | + | |
| 2035 | + | |
| 2036 | + | |
| 2037 | + | |
| 2038 | + | |
| 2039 | + | |
| 2040 | + | |
| 2041 | + | |
| 2042 | + | |
| 2043 | + | |
| 2044 | + | |
| 2045 | + | |
| 2046 | + | |
| 2047 | + | |
| 2048 | + | |
| 2049 | + | |
| 2050 | + | |
| 2051 | + | |
1908 | 2052 | | |
1909 | 2053 | | |
1910 | 2054 | | |
| |||
1932 | 2076 | | |
1933 | 2077 | | |
1934 | 2078 | | |
1935 | | - | |
| 2079 | + | |
1936 | 2080 | | |
1937 | 2081 | | |
1938 | 2082 | | |
| |||
| Original file line number | Diff line number | Diff line change | |
|---|---|---|---|
| |||
205 | 205 | | |
206 | 206 | | |
207 | 207 | | |
| 208 | + | |
| 209 | + | |
| 210 | + | |
| 211 | + | |
| 212 | + | |
| 213 | + | |
| 214 | + | |
| 215 | + | |
| 216 | + | |
| 217 | + | |
| 218 | + | |
| 219 | + | |
| 220 | + | |
| 221 | + | |
| 222 | + | |
| 223 | + | |
| 224 | + | |
| 225 | + | |
| 226 | + | |
| 227 | + | |
| 228 | + | |
| 229 | + | |
| 230 | + | |
| 231 | + | |
| 232 | + | |
0 commit comments