From fb65cc24655d084f1f60b5edced6c12696022921 Mon Sep 17 00:00:00 2001 From: makslevental Date: Tue, 16 Dec 2025 19:41:23 -0800 Subject: [PATCH] [MLIR][TblGen] add AttrOrTypeCAPIGen --- mlir/include/mlir/IR/EnumAttr.td | 1 + mlir/tools/mlir-tblgen/AttrOrTypeCAPIGen.cpp | 271 +++++++++++++++++++ mlir/tools/mlir-tblgen/AttrOrTypeDefGen.cpp | 42 +-- mlir/tools/mlir-tblgen/AttrOrTypeFormatGen.h | 44 +++ mlir/tools/mlir-tblgen/CMakeLists.txt | 1 + 5 files changed, 320 insertions(+), 39 deletions(-) create mode 100644 mlir/tools/mlir-tblgen/AttrOrTypeCAPIGen.cpp diff --git a/mlir/include/mlir/IR/EnumAttr.td b/mlir/include/mlir/IR/EnumAttr.td index ec57626ebde65..a3069d5837183 100644 --- a/mlir/include/mlir/IR/EnumAttr.td +++ b/mlir/include/mlir/IR/EnumAttr.td @@ -499,6 +499,7 @@ class EnumParameter !cast(enumInfo).parameterParser, ?); let printer = !if(!isa(enumInfo), !cast(enumInfo).parameterPrinter, ?); + string underlyingEnumName = enumInfo.className; } // An attribute backed by a C++ enum. The attribute contains a single diff --git a/mlir/tools/mlir-tblgen/AttrOrTypeCAPIGen.cpp b/mlir/tools/mlir-tblgen/AttrOrTypeCAPIGen.cpp new file mode 100644 index 0000000000000..056e3f54b1c5a --- /dev/null +++ b/mlir/tools/mlir-tblgen/AttrOrTypeCAPIGen.cpp @@ -0,0 +1,271 @@ +//===- AttrOrTypeCAPIGen.cpp - MLIR Attribute and Type CAPI generation ----===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#include "AttrOrTypeFormatGen.h" +#include "CppGenUtilities.h" +#include "mlir/TableGen/AttrOrTypeDef.h" +#include "mlir/TableGen/Class.h" +#include "mlir/TableGen/EnumInfo.h" +#include "mlir/TableGen/GenInfo.h" +#include "mlir/TableGen/Interfaces.h" +#include "mlir/TableGen/Pass.h" +#include "llvm/ADT/StringExtras.h" +#include "llvm/Support/CommandLine.h" +#include "llvm/Support/Debug.h" +#include "llvm/Support/FormatVariadic.h" +#include "llvm/TableGen/Error.h" +#include "llvm/TableGen/Record.h" +#include "llvm/TableGen/TableGenBackend.h" + +#define DEBUG_TYPE "mlir-tblgen-attr-or-type-capi-gen" + +using namespace mlir; +using namespace mlir::tblgen; +using llvm::formatv; +using llvm::Record; +using llvm::RecordKeeper; + +static llvm::cl::OptionCategory attrOrTypeCAPIDefGenCat( + "Options for -gen-attr-capi-* and -gen-typedef-capi-*"); +static llvm::cl::opt + capiDialect("attr-or-type-capi-dialect", + llvm::cl::desc("Generate C APIs for this dialect"), + llvm::cl::cat(attrOrTypeCAPIDefGenCat), + llvm::cl::CommaSeparated); +static llvm::cl::opt capiNamespacePrefix( + "attr-or-type-capi-namespace-prefix", + llvm::cl::desc("Generate C APIs with this namespace prefix"), + llvm::cl::cat(attrOrTypeCAPIDefGenCat)); + +static std::string makeIdentifier(StringRef str) { + if (!str.empty() && llvm::isDigit(static_cast(str.front()))) { + std::string newStr = std::string("_") + str.str(); + return newStr; + } + return str.str(); +} + +static std::string withCapitalFirstLetter(std::string name) { + name[0] = static_cast( + std::toupper(static_cast(name[0]))); + return name; +} + +static StringRef namespacePrefix() { + static const std::string prefix = [] { + if (!capiNamespacePrefix.empty()) + return "mlir" + capiNamespacePrefix; + return "mlir" + withCapitalFirstLetter(capiDialect.getValue()); + }(); + return prefix; +} + +static void emitEnums(const RecordKeeper &records, raw_ostream &os) { + for (const Record *it : + records.getAllDerivedDefinitionsIfDefined("EnumInfo")) { + EnumInfo enumInfo(*it); + os << "// " << enumInfo.getSummary() << "\n"; + os << "enum " << namespacePrefix() << enumInfo.getEnumClassName(); + + if (!enumInfo.getUnderlyingType().empty()) + os << " : " << enumInfo.getUnderlyingType(); + os << " {\n"; + + for (const EnumCase &enumerant : enumInfo.getAllCases()) { + auto symbol = makeIdentifier(enumerant.getSymbol()); + auto value = enumerant.getValue(); + if (value >= 0) + os << formatv(" {0} = {1},\n", symbol, value); + else + os << formatv(" {0},\n", symbol); + } + os << "};\n\n"; + } +} + +namespace { +struct CAPIDefGenerator : DefGenerator { + CAPIDefGenerator(const RecordKeeper &records, StringRef className, + raw_ostream &os, const StringRef &defType, + const StringRef &valueType, bool isAttrGenerator) + : DefGenerator(records.getAllDerivedDefinitionsIfDefined(className), os, + defType, valueType, isAttrGenerator), + records(records) {} + + bool emitDecls(StringRef selectedDialect) override; + const RecordKeeper &records; +}; +} // namespace + +static llvm::Twine mapParamTypeToCAPI(const AttrOrTypeParameter ¶m) { + if (const llvm::DefInit *defInit = dyn_cast(param.getDef())) { + if (defInit->getDef()->isSubClassOf("EnumParameter")) + return namespacePrefix() + + defInit->getDef()->getValueAsString("underlyingEnumName"); + } + StringRef cppType = param.getCppType(); + if (cppType == "Type") + return "MlirType"; + return cppType; +} + +static void emitGettorDecl(const AttrOrTypeDef &def, raw_ostream &os, + bool isAttrGenerator) { + os << "MLIR_CAPI_EXPORTED "; + if (isAttrGenerator) + os << "MlirAttribute "; + else + os << "MlirType "; + os << namespacePrefix() << def.getCppClassName() << "Get(MlirContext context"; + ArrayRef params = def.getParameters(); + if (!params.empty()) + os << ", "; + for (auto [i, param] : llvm::enumerate(params)) { + os << mapParamTypeToCAPI(param) << " " << param.getName() + << (i < params.size() - 1 ? ", " : ""); + } + os << ");\n"; +} + +static void emitAccessorDecls(const AttrOrTypeDef &def, raw_ostream &os, + bool isAttrGenerator) { + ArrayRef params = def.getParameters(); + if (params.empty()) + return; + for (auto param : params) { + os << "MLIR_CAPI_EXPORTED "; + std::string paramName = param.getName().str(); + os << mapParamTypeToCAPI(param) << " " << namespacePrefix() + << def.getCppClassName() << "Get" + << withCapitalFirstLetter(param.getName().str()); + if (isAttrGenerator) + os << "(MlirAttribute attr);"; + else + os << "(MlirType type);"; + os << "\n"; + } +} + +static void emitTypeIDDecl(const AttrOrTypeDef &def, raw_ostream &os) { + os << "MLIR_CAPI_EXPORTED MlirTypeID " << namespacePrefix() + << def.getCppClassName() << "GetTypeID();\n"; +} + +static void emitIsADecl(const AttrOrTypeDef &def, raw_ostream &os, + bool isAttrGenerator) { + os << "MLIR_CAPI_EXPORTED bool mlir"; + if (isAttrGenerator) + os << "Attribute"; + else + os << "Type"; + os << "IsA" << def.getCppClassName() << "("; + if (isAttrGenerator) + os << "(MlirAttribute attr);"; + else + os << "(MlirType type);"; + os << "\n"; +} + +static void emitAttrTypeHeader(const AttrOrTypeDef &def, raw_ostream &os) { + const char *const header = R"( +//===----------------------------------------------------------------------===// +// {0} +//===----------------------------------------------------------------------===// + +)"; + os << formatv(header, def.getCppClassName()); +} + +bool CAPIDefGenerator::emitDecls(StringRef selectedDialect) { + emitSourceFileHeader((defType + "Def C API Def Declarations").str(), os); + + SmallVector defs; + collectAllDefs(selectedDialect, defRecords, defs); + if (defs.empty()) + return false; + + for (const AttrOrTypeDef &def : defs) { + emitAttrTypeHeader(def, os); + emitGettorDecl(def, os, isAttrGenerator); + emitTypeIDDecl(def, os); + emitIsADecl(def, os, isAttrGenerator); + if (def.genAccessors()) + emitAccessorDecls(def, os, isAttrGenerator); + } + + os << "\n"; + + return false; +} + +namespace { +/// A specialized generator for AttrDefs. +struct CAPIAttrDefGenerator : public CAPIDefGenerator { + CAPIAttrDefGenerator(const RecordKeeper &records, raw_ostream &os) + : CAPIDefGenerator(records, "AttrDef", os, "Attr", "Attribute", + /*isAttrGenerator=*/true) {} +}; +/// A specialized generator for TypeDefs. +struct CAPITypeDefGenerator : public CAPIDefGenerator { + CAPITypeDefGenerator(const RecordKeeper &records, raw_ostream &os) + : CAPIDefGenerator(records, "TypeDef", os, "Type", "Type", + /*isAttrGenerator=*/false) {} +}; +} // namespace + +//===----------------------------------------------------------------------===// +// GEN: Registration hooks +//===----------------------------------------------------------------------===// + +//===----------------------------------------------------------------------===// +// AttrDef +//===----------------------------------------------------------------------===// + +static mlir::GenRegistration + genEnumDecls("gen-enum-capi-decls", "Generate Enum C API declarations", + [](const RecordKeeper &records, raw_ostream &os) { + emitSourceFileHeader("Enum C API Declarations", os); + emitEnums(records, os); + return false; + }); + +static mlir::GenRegistration + genAttrDecls("gen-attrdef-capi-decls", + "Generate AttrDef C API declarations", + [](const RecordKeeper &records, raw_ostream &os) { + CAPIAttrDefGenerator generator(records, os); + return generator.emitDecls(capiDialect); + }); + +// static mlir::GenRegistration +// genAttrDefs("gen-attrdef-capi-defs", "Generate AttrDef C API +// definitions", +// [](const RecordKeeper &records, raw_ostream &os) { +// CAPIAttrDefGenerator generator(records, os); +// return generator.emitDefs(attrDialect); +// }); + +//===----------------------------------------------------------------------===// +// TypeDef +//===----------------------------------------------------------------------===// + +static mlir::GenRegistration + genTypeDecls("gen-typedef-capi-decls", + "Generate TypeDef C API declarations", + [](const RecordKeeper &records, raw_ostream &os) { + CAPITypeDefGenerator generator(records, os); + return generator.emitDecls(capiDialect); + }); + +// static mlir::GenRegistration +// genTypeDefs("gen-typedef-capi-defs", "Generate TypeDef C API +// definitions", +// [](const RecordKeeper &records, raw_ostream &os) { +// CAPITypeDefGenerator generator(records, os); +// return generator.emitDefs(capiDialect); +// }); diff --git a/mlir/tools/mlir-tblgen/AttrOrTypeDefGen.cpp b/mlir/tools/mlir-tblgen/AttrOrTypeDefGen.cpp index 2a513c3b8cc9b..c828dc6f67746 100644 --- a/mlir/tools/mlir-tblgen/AttrOrTypeDefGen.cpp +++ b/mlir/tools/mlir-tblgen/AttrOrTypeDefGen.cpp @@ -32,9 +32,9 @@ using llvm::RecordKeeper; /// Find all the AttrOrTypeDef for the specified dialect. If no dialect /// specified and can only find one dialect's defs, use that. -static void collectAllDefs(StringRef selectedDialect, - ArrayRef records, - SmallVectorImpl &resultDefs) { +void mlir::tblgen::collectAllDefs(StringRef selectedDialect, + ArrayRef records, + SmallVectorImpl &resultDefs) { // Nothing to do if no defs were found. if (records.empty()) return; @@ -804,42 +804,6 @@ void DefGen::emitStorageClass() { //===----------------------------------------------------------------------===// namespace { -/// This struct is the base generator used when processing tablegen interfaces. -class DefGenerator { -public: - bool emitDecls(StringRef selectedDialect); - bool emitDefs(StringRef selectedDialect); - -protected: - DefGenerator(ArrayRef defs, raw_ostream &os, - StringRef defType, StringRef valueType, bool isAttrGenerator) - : defRecords(defs), os(os), defType(defType), valueType(valueType), - isAttrGenerator(isAttrGenerator) { - // Sort by occurrence in file. - llvm::sort(defRecords, [](const Record *lhs, const Record *rhs) { - return lhs->getID() < rhs->getID(); - }); - } - - /// Emit the list of def type names. - void emitTypeDefList(ArrayRef defs); - /// Emit the code to dispatch between different defs during parsing/printing. - void emitParsePrintDispatch(ArrayRef defs); - - /// The set of def records to emit. - std::vector defRecords; - /// The attribute or type class to emit. - /// The stream to emit to. - raw_ostream &os; - /// The prefix of the tablegen def name, e.g. Attr or Type. - StringRef defType; - /// The C++ base value type of the def, e.g. Attribute or Type. - StringRef valueType; - /// Flag indicating if this generator is for Attributes. False if the - /// generator is for types. - bool isAttrGenerator; -}; - /// A specialized generator for AttrDefs. struct AttrDefGenerator : public DefGenerator { AttrDefGenerator(const RecordKeeper &records, raw_ostream &os) diff --git a/mlir/tools/mlir-tblgen/AttrOrTypeFormatGen.h b/mlir/tools/mlir-tblgen/AttrOrTypeFormatGen.h index d4711532a79bb..ca20fdba5ba96 100644 --- a/mlir/tools/mlir-tblgen/AttrOrTypeFormatGen.h +++ b/mlir/tools/mlir-tblgen/AttrOrTypeFormatGen.h @@ -20,6 +20,50 @@ class AttrOrTypeDef; void generateAttrOrTypeFormat(const AttrOrTypeDef &def, MethodBody &parser, MethodBody &printer); +/// Find all the AttrOrTypeDef for the specified dialect. If no dialect +/// specified and can only find one dialect's defs, use that. +void collectAllDefs(StringRef selectedDialect, + ArrayRef records, + SmallVectorImpl &resultDefs); + +/// This struct is the base generator used when processing tablegen interfaces. +class DefGenerator { +public: + virtual ~DefGenerator() = default; + virtual bool emitDecls(StringRef selectedDialect); + virtual bool emitDefs(StringRef selectedDialect); + +protected: + DefGenerator(ArrayRef defs, raw_ostream &os, + StringRef defType, StringRef valueType, bool isAttrGenerator) + : defRecords(defs), os(os), defType(defType), valueType(valueType), + isAttrGenerator(isAttrGenerator) { + // Sort by occurrence in file. + llvm::sort(defRecords, + [](const llvm::Record *lhs, const llvm::Record *rhs) { + return lhs->getID() < rhs->getID(); + }); + } + + /// Emit the list of def type names. + void emitTypeDefList(ArrayRef defs); + /// Emit the code to dispatch between different defs during parsing/printing. + void emitParsePrintDispatch(ArrayRef defs); + + /// The set of def records to emit. + std::vector defRecords; + /// The attribute or type class to emit. + /// The stream to emit to. + raw_ostream &os; + /// The prefix of the tablegen def name, e.g. Attr or Type. + StringRef defType; + /// The C++ base value type of the def, e.g. Attribute or Type. + StringRef valueType; + /// Flag indicating if this generator is for Attributes. False if the + /// generator is for types. + bool isAttrGenerator; +}; + } // namespace tblgen } // namespace mlir diff --git a/mlir/tools/mlir-tblgen/CMakeLists.txt b/mlir/tools/mlir-tblgen/CMakeLists.txt index d7087cba3c874..4256613ce4848 100644 --- a/mlir/tools/mlir-tblgen/CMakeLists.txt +++ b/mlir/tools/mlir-tblgen/CMakeLists.txt @@ -8,6 +8,7 @@ set(LLVM_LINK_COMPONENTS add_tablegen(mlir-tblgen MLIR DESTINATION "${MLIR_TOOLS_INSTALL_DIR}" EXPORT MLIR + AttrOrTypeCAPIGen.cpp AttrOrTypeDefGen.cpp AttrOrTypeFormatGen.cpp BytecodeDialectGen.cpp