Skip to content

Commit

Permalink
[FIRRTL] Add a new FIRRTL annotation to specify type lowering behavio…
Browse files Browse the repository at this point in the history
…r of module body

This allows more fine-grained control over how types are lowered in different contexts.

This also adds an "includeHierarchy" option to Convention annotations that
allows applying the convention to all modules in the hierarchy below the
annotated module.
  • Loading branch information
uenoku committed Nov 14, 2024
1 parent 823c948 commit 96b6abd
Show file tree
Hide file tree
Showing 7 changed files with 161 additions and 31 deletions.
37 changes: 31 additions & 6 deletions docs/Dialects/FIRRTL/FIRRTLAnnotations.md
Original file line number Diff line number Diff line change
Expand Up @@ -323,11 +323,12 @@ Example:

### Convention

| Property | Type | Description |
| ---------- | ------ | --------------------------------------- |
| class | string | `circt.ConventionAnnotation` |
| convention | string | `scalarized` |
| target | string | Reference target |
| Property | Type | Description |
| ---------------- | ------ | ---------------------------------------------------- |
| class | string | `circt.ConventionAnnotation` |
| convention | string | `scalarized` |
| target | string | Reference target |
| includeHierarchy | bool | Apply the convention to all modules in the hierarchy |

Specify the port convention for a module. The port convention controls how a
module's ports are transformed, and how that module can be instantiated, in the
Expand All @@ -341,7 +342,31 @@ The options are:
{
"class": "circt.ConventionAnnotation",
"convention": "scalarized",
"target": "~Foo|Bar/d:Baz"
"target": "~Foo|Bar",
"includeHierarchy": true
}
```

### BodyTypeLoweringAnnotation

| Property | Type | Description |
| ---------------- | ------ | ---------------------------------- |
| class | string | `circt.BodyTypeLoweringAnnotation` |
| convention | string | See `Convention` annotation |
| target | string | See `Convention` annotation |
| includeHierarchy | bool | See `Convention` annotation |

Specify the type lowering option for module internal signals.
This is similar to the `Convention` annotation, but for internal signals
rather than module ports. Refer to the `Convention` annotation for each
property description.

```json
{
"class": "circt.BodyTypeLoweringAnnotation",
"convention": "scalarized",
"target": "~Foo|Bar",
"includeHierarchy": true
}
```

Expand Down
1 change: 1 addition & 0 deletions include/circt/Dialect/FIRRTL/AnnotationDetails.h
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@ constexpr const char *rawAnnotations = "rawAnnotations";
//===----------------------------------------------------------------------===//

constexpr const char *conventionAnnoClass = "circt.ConventionAnnotation";
constexpr const char *typeLoweringAnnoClass = "circt.BodyTypeLoweringAnnotation";
constexpr const char *dontTouchAnnoClass =
"firrtl.transforms.DontTouchAnnotation";
constexpr const char *enumComponentAnnoClass =
Expand Down
48 changes: 41 additions & 7 deletions lib/Dialect/FIRRTL/Transforms/LowerAnnotations.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -275,14 +275,17 @@ static std::optional<Convention> parseConvention(llvm::StringRef str) {
.Default(std::nullopt);
}

static LogicalResult applyConventionAnno(const AnnoPathValue &target,
DictionaryAttr anno,
ApplyState &state) {
template <bool IsConventionAnno>
static LogicalResult
applyConventionOrTypeLoweringAnno(const AnnoPathValue &target,
DictionaryAttr anno, ApplyState &state) {
auto *op = target.ref.getOp();
auto loc = op->getLoc();
auto error = [&]() {
auto diag = mlir::emitError(loc);
diag << "circuit.ConventionAnnotation ";
diag << (IsConventionAnno ? "circuit.ConventionAnnotation "
: "circuit.TypeLoweringAnnotation ")
<< " ";
return diag;
};

Expand All @@ -305,13 +308,41 @@ static LogicalResult applyConventionAnno(const AnnoPathValue &target,

auto convention = *conventionOpt;

if (convention == Convention::Internal)
// Convention is internal by default so there is nothing to change
return success();

auto includeHierarchy = anno.getAs<BoolAttr>("includeHierarchy");
auto conventionAttr = ConventionAttr::get(op->getContext(), convention);
auto setConvention = [&](Operation *moduleOp) {
TypeSwitch<Operation *>(moduleOp)
.Case<FModuleOp, FExtModuleOp>([&](auto moduleOp) {
if (IsConventionAnno)
moduleOp.setConventionAttr(conventionAttr);
else
moduleOp->setDiscardableAttr("body_type_lowering", conventionAttr);
})
.Default([](auto) {});
};

if (auto moduleOp = dyn_cast<FModuleOp>(op)) {
moduleOp.setConvention(convention);
if (includeHierarchy && includeHierarchy.getValue()) {
// If includeHierarchy is true, update the convention for all modules in
// the hierarchy.
for (auto *node :
llvm::post_order(state.instancePathCache.instanceGraph[moduleOp])) {
if (node && isa<FModuleOp, FExtModuleOp>(*node->getModule()))
setConvention(node->getModule());
}
} else {
// Update the convention.
setConvention(moduleOp);
}
return success();
}

if (auto extModuleOp = dyn_cast<FExtModuleOp>(op)) {
extModuleOp.setConvention(convention);
setConvention(extModuleOp);
return success();
}

Expand Down Expand Up @@ -563,7 +594,10 @@ static llvm::StringMap<AnnoRecord> annotationRecords{{
{omirTrackerAnnoClass, {stdResolve, applyWithoutTarget<true>}},
{omirFileAnnoClass, NoTargetAnnotation},
// Miscellaneous Annotations
{conventionAnnoClass, {stdResolve, applyConventionAnno}},
{conventionAnnoClass,
{stdResolve, applyConventionOrTypeLoweringAnno<true>}},
{typeLoweringAnnoClass,
{stdResolve, applyConventionOrTypeLoweringAnno<false>}},
{dontTouchAnnoClass,
{stdResolve, applyWithoutTarget<true, true, WireOp, NodeOp, RegOp,
RegResetOp, InstanceOp, MemOp, CombMemOp,
Expand Down
42 changes: 27 additions & 15 deletions lib/Dialect/FIRRTL/Transforms/LowerTypes.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -339,12 +339,17 @@ struct TypeLoweringVisitor : public FIRRTLVisitor<TypeLoweringVisitor, bool> {

TypeLoweringVisitor(
MLIRContext *context, PreserveAggregate::PreserveMode preserveAggregate,
Convention bodyConvention,
PreserveAggregate::PreserveMode memoryPreservationMode,
SymbolTable &symTbl, const AttrCache &cache,
const llvm::DenseMap<FModuleLike, Convention> &conventionTable)
: context(context), aggregatePreservationMode(preserveAggregate),
: context(context), defaultAggregatePreservationMode(preserveAggregate),
memoryPreservationMode(memoryPreservationMode), symTbl(symTbl),
cache(cache), conventionTable(conventionTable) {}
cache(cache), conventionTable(conventionTable) {
bodyAggregatePreservationMode = bodyConvention == Convention::Scalarized
? PreserveAggregate::None
: defaultAggregatePreservationMode;
}
using FIRRTLVisitor<TypeLoweringVisitor, bool>::visitDecl;
using FIRRTLVisitor<TypeLoweringVisitor, bool>::visitExpr;
using FIRRTLVisitor<TypeLoweringVisitor, bool>::visitStmt;
Expand Down Expand Up @@ -429,7 +434,7 @@ struct TypeLoweringVisitor : public FIRRTLVisitor<TypeLoweringVisitor, bool> {
Location errorLoc);

PreserveAggregate::PreserveMode
getPreservationModeForModule(FModuleLike moduleLike);
getPreservationModeForPorts(FModuleLike moduleLike);
Value getSubWhatever(Value val, size_t index);

size_t uniqueIdx = 0;
Expand All @@ -441,7 +446,8 @@ struct TypeLoweringVisitor : public FIRRTLVisitor<TypeLoweringVisitor, bool> {
MLIRContext *context;

/// Aggregate preservation mode.
PreserveAggregate::PreserveMode aggregatePreservationMode;
PreserveAggregate::PreserveMode defaultAggregatePreservationMode;
PreserveAggregate::PreserveMode bodyAggregatePreservationMode;
PreserveAggregate::PreserveMode memoryPreservationMode;

/// The builder is set and maintained in the main loop.
Expand All @@ -460,21 +466,21 @@ struct TypeLoweringVisitor : public FIRRTLVisitor<TypeLoweringVisitor, bool> {
};
} // namespace

/// Return aggregate preservation mode for the module. If the module has a
/// Return aggregate preservation mode for the module ports. If the module has a
/// scalarized linkage, then we may not preserve it's aggregate ports.
PreserveAggregate::PreserveMode
TypeLoweringVisitor::getPreservationModeForModule(FModuleLike module) {
TypeLoweringVisitor::getPreservationModeForPorts(FModuleLike module) {
auto lookup = conventionTable.find(module);
if (lookup == conventionTable.end())
return aggregatePreservationMode;
return defaultAggregatePreservationMode;
switch (lookup->second) {
case Convention::Scalarized:
return PreserveAggregate::None;
case Convention::Internal:
return aggregatePreservationMode;
return defaultAggregatePreservationMode;
}
llvm_unreachable("Unknown convention");
return aggregatePreservationMode;
return defaultAggregatePreservationMode;
}

Value TypeLoweringVisitor::getSubWhatever(Value val, size_t index) {
Expand Down Expand Up @@ -643,7 +649,7 @@ bool TypeLoweringVisitor::lowerProducer(
return false;
SmallVector<FlatBundleFieldEntry, 8> fieldTypes;

if (!peelType(srcFType, fieldTypes, aggregatePreservationMode))
if (!peelType(srcFType, fieldTypes, bodyAggregatePreservationMode))
return false;

SmallVector<Value> lowered;
Expand Down Expand Up @@ -809,7 +815,7 @@ bool TypeLoweringVisitor::lowerArg(FModuleLike module, size_t argIndex,
// Flatten any bundle types.
SmallVector<FlatBundleFieldEntry> fieldTypes;
auto srcType = type_cast<FIRRTLType>(newArgs[argIndex].pi.type);
if (!peelType(srcType, fieldTypes, getPreservationModeForModule(module)))
if (!peelType(srcType, fieldTypes, getPreservationModeForPorts(module)))
return false;

// Ports with internalPath set cannot be lowered.
Expand Down Expand Up @@ -929,7 +935,7 @@ bool TypeLoweringVisitor::visitStmt(RefDefineOp op) {
// Attempt to get the bundle types.
SmallVector<FlatBundleFieldEntry> fields;

if (!peelType(op.getDest().getType(), fields, aggregatePreservationMode))
if (!peelType(op.getDest().getType(), fields, bodyAggregatePreservationMode))
return false;

// Loop over the leaf aggregates.
Expand Down Expand Up @@ -1454,7 +1460,7 @@ bool TypeLoweringVisitor::visitDecl(InstanceOp op) {
SmallVector<Direction> newDirs;
SmallVector<Attribute> newNames;
SmallVector<Attribute> newPortAnno;
PreserveAggregate::PreserveMode mode = getPreservationModeForModule(
PreserveAggregate::PreserveMode mode = getPreservationModeForPorts(
cast<FModuleLike>(op.getReferencedOperation(symTbl)));

endFields.push_back(0);
Expand Down Expand Up @@ -1667,9 +1673,15 @@ void LowerTypesPass::runOnOperation() {

// This lambda, executes in parallel for each Op within the circt.
auto lowerModules = [&](FModuleLike op) -> LogicalResult {
// Use body type lowering attribute if it exists, otherwise use internal.
Convention convention = Convention::Internal;
if (auto conventionAttr = dyn_cast_or_null<ConventionAttr>(
op->getDiscardableAttr("body_type_lowering")))
convention = conventionAttr.getValue();

auto tl =
TypeLoweringVisitor(&getContext(), preserveAggregate, preserveMemories,
symTbl, cache, conventionTable);
TypeLoweringVisitor(&getContext(), preserveAggregate, convention,
preserveMemories, symTbl, cache, conventionTable);
tl.lowerModule(op);

return LogicalResult::failure(tl.isFailed());
Expand Down
2 changes: 1 addition & 1 deletion llvm
Submodule llvm updated 5164 files
23 changes: 21 additions & 2 deletions test/Dialect/FIRRTL/annotations.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -734,14 +734,33 @@ firrtl.circuit "Test" attributes {rawAnnotations = [
// -----

firrtl.circuit "Test" attributes {rawAnnotations =[
{class = "circt.ConventionAnnotation", target = "~Test|Test", convention = "scalarized"}
{class = "circt.ConventionAnnotation", target = "~Test|Test", convention = "scalarized"},
{class = "circt.BodyTypeLoweringAnnotation", target = "~Test|Test", convention = "scalarized"}
]} {
// CHECK: attributes {convention = #firrtl<convention scalarized>}
// CHECK: attributes {body_type_lowering = #firrtl<convention scalarized>, convention = #firrtl<convention scalarized>}
firrtl.module @Test() attributes {convention = #firrtl<convention internal>} {}
}

// -----

firrtl.circuit "Test" attributes {rawAnnotations = [
{class = "circt.ConventionAnnotation", target = "~Test|Test", convention = "scalarized", includeHierarchy = true},
{class = "circt.BodyTypeLoweringAnnotation", target = "~Test|Test", convention = "scalarized", includeHierarchy = true}
]} {
// CHECK: @Test() attributes {body_type_lowering = #firrtl<convention scalarized>, convention = #firrtl<convention scalarized>}
firrtl.module @Test() attributes {convention = #firrtl<convention internal>} {
firrtl.instance child @Child()
}

// CHECK: @Child() attributes {body_type_lowering = #firrtl<convention scalarized>, convention = #firrtl<convention scalarized>}
firrtl.module @Child() attributes {convention = #firrtl<convention internal>} {}

// CHECK: @Child2() {
firrtl.module @Child2() attributes {convention = #firrtl<convention internal>} {}
}

// -----

firrtl.circuit "Test" attributes {rawAnnotations =[
{class = "chisel3.ModulePrefixAnnotation", target = "~Test|Test>comb", prefix = "Prefix_"},
{class = "chisel3.ModulePrefixAnnotation", target = "~Test|Test>seq", prefix = "Prefix_"},
Expand Down
39 changes: 39 additions & 0 deletions test/Dialect/FIRRTL/lower-types.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -1404,3 +1404,42 @@ firrtl.circuit "UnrealizedConversion" {
firrtl.matchingconnect %w, %b : !firrtl.bundle<data: uint<64>, tag: uint<1>>
}
}

firrtl.circuit "Conventions1" {
// COMMON-LABEL: @Conventions1
// AGGREGATE-SAME: %input_0
// AGGREGATE-NEXT: firrtl.reg
// AGGREGATE-SAME: !firrtl.vector<uint<8>, 1>
firrtl.module public @Conventions1(in %input: !firrtl.vector<uint<8>, 1>, in %clk: !firrtl.clock, out %port: !firrtl.vector<uint<8>, 1>) attributes {convention = #firrtl<convention scalarized>, body_type_lowering = #firrtl<convention internal>}{
%r = firrtl.reg interesting_name %clk : !firrtl.clock, !firrtl.vector<uint<8>, 1>
firrtl.matchingconnect %r, %input : !firrtl.vector<uint<8>, 1>
firrtl.matchingconnect %port, %r : !firrtl.vector<uint<8>, 1>
}
// COMMON-LABEL: @Conventions2
// AGGREGATE-SAME: %input_0: !firrtl.uint<8>
// AGGREGATE-NEXT: firrtl.reg
// AGGREGATE-SAME: !firrtl.uint<8>
firrtl.module private @Conventions2(in %input: !firrtl.vector<uint<8>, 1>, in %clk: !firrtl.clock, out %port: !firrtl.vector<uint<8>, 1>) attributes {convention = #firrtl<convention scalarized>, body_type_lowering = #firrtl<convention scalarized>}{
%r = firrtl.reg interesting_name %clk : !firrtl.clock, !firrtl.vector<uint<8>, 1>
firrtl.matchingconnect %r, %input : !firrtl.vector<uint<8>, 1>
firrtl.matchingconnect %port, %r : !firrtl.vector<uint<8>, 1>
}
// COMMON-LABEL: @Conventions3
// AGGREGATE-SAME: %input: !firrtl.vector<uint<8>, 1>
// AGGREGATE-NEXT: firrtl.reg
// AGGREGATE-SAME: !firrtl.vector<uint<8>, 1>
firrtl.module private @Conventions3(in %input: !firrtl.vector<uint<8>, 1>, in %clk: !firrtl.clock, out %port: !firrtl.vector<uint<8>, 1>) attributes {convention = #firrtl<convention internal>, body_type_lowering = #firrtl<convention internal>}{
%r = firrtl.reg interesting_name %clk : !firrtl.clock, !firrtl.vector<uint<8>, 1>
firrtl.matchingconnect %r, %input : !firrtl.vector<uint<8>, 1>
firrtl.matchingconnect %port, %r : !firrtl.vector<uint<8>, 1>
}
// COMMON-LABEL: @Conventions4
// AGGREGATE-SAME: %input: !firrtl.vector<uint<8>, 1>
// AGGREGATE-NEXT: firrtl.reg
// AGGREGATE-SAME: !firrtl.uint<8>
firrtl.module private @Conventions4(in %input: !firrtl.vector<uint<8>, 1>, in %clk: !firrtl.clock, out %port: !firrtl.vector<uint<8>, 1>) attributes {convention = #firrtl<convention internal>, body_type_lowering = #firrtl<convention scalarized>}{
%r = firrtl.reg interesting_name %clk : !firrtl.clock, !firrtl.vector<uint<8>, 1>
firrtl.matchingconnect %r, %input : !firrtl.vector<uint<8>, 1>
firrtl.matchingconnect %port, %r : !firrtl.vector<uint<8>, 1>
}
}

0 comments on commit 96b6abd

Please sign in to comment.