Skip to content

Commit

Permalink
transform scf index switch to if-else
Browse files Browse the repository at this point in the history
  • Loading branch information
jiahanxie353 committed Nov 8, 2024
1 parent 12785bb commit 577722d
Show file tree
Hide file tree
Showing 5 changed files with 127 additions and 67 deletions.
1 change: 1 addition & 0 deletions include/circt/Transforms/Passes.h
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,7 @@ std::unique_ptr<mlir::Pass> createStripDebugInfoWithPredPass(
std::unique_ptr<mlir::Pass> createMaximizeSSAPass();
std::unique_ptr<mlir::Pass> createInsertMergeBlocksPass();
std::unique_ptr<mlir::Pass> createPrintOpCountPass();
std::unique_ptr<mlir::Pass> createIndexSwitchToIfPass();

//===----------------------------------------------------------------------===//
// Utility functions.
Expand Down
7 changes: 7 additions & 0 deletions include/circt/Transforms/Passes.td
Original file line number Diff line number Diff line change
Expand Up @@ -123,4 +123,11 @@ def PrintOpCount : Pass<"print-op-count", "::mlir::ModuleOp"> {
];
}

def IndexSwitchToIf : Pass<"switch-to-if", "::mlir::ModuleOp"> {
let summary = "Index switch to if";
let description = [{ SCF index switch to if-else. }];
let constructor = "circt::createIndexSwitchToIfPass()";
let dependentDialects = ["mlir::scf::SCFDialect"];
}

#endif // CIRCT_TRANSFORMS_PASSES
67 changes: 0 additions & 67 deletions lib/Conversion/SCFToCalyx/SCFToCalyx.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1481,70 +1481,6 @@ class BuildIfGroups : public calyx::FuncOpPartialLoweringPattern {
}
};

class BuildSwitchGroups : public calyx::FuncOpPartialLoweringPattern {
using FuncOpPartialLoweringPattern::FuncOpPartialLoweringPattern;

LogicalResult
partiallyLowerFuncToComp(FuncOp funcOp,
PatternRewriter &rewriter) const override {
LogicalResult res = success();
funcOp.walk([&](Operation *op) {
if (!isa<scf::IndexSwitchOp>(op))
return WalkResult::advance();

auto switchOp = cast<scf::IndexSwitchOp>(op);
auto loc = switchOp.getLoc();

Region &defaultRegion = switchOp.getDefaultRegion();
Operation *yieldOp = defaultRegion.front().getTerminator();
Value defaultResult = yieldOp->getOperand(0);

Value finalResult = defaultResult;
scf::IfOp prevIfOp = nullptr;

rewriter.setInsertionPointAfter(switchOp);
for (size_t i = 0; i < switchOp.getCases().size(); i++) {
auto caseValueInt = switchOp.getCases()[i];
if (prevIfOp)
rewriter.setInsertionPointToStart(&prevIfOp.getElseRegion().front());

Value caseValue = rewriter.create<ConstantIndexOp>(loc, caseValueInt);
Value cond = rewriter.create<CmpIOp>(
loc, CmpIPredicate::eq, *switchOp.getODSOperands(0).begin(),
caseValue);

auto ifOp = rewriter.create<scf::IfOp>(loc, switchOp.getResultTypes(),
cond, /*hasElseRegion=*/true);

Region &caseRegion = switchOp.getCaseRegions()[i];
IRMapping mapping;
Block &emptyThenBlock = ifOp.getThenRegion().front();
emptyThenBlock.erase();
caseRegion.cloneInto(&ifOp.getThenRegion(), mapping);

if (i == switchOp.getCases().size() - 1) {
rewriter.setInsertionPointToEnd(&ifOp.getElseRegion().front());
rewriter.create<scf::YieldOp>(loc, defaultResult);
}

if (prevIfOp) {
rewriter.setInsertionPointToEnd(&prevIfOp.getElseRegion().front());
rewriter.create<scf::YieldOp>(loc, ifOp.getResult(0));
}

if (i == 0)
finalResult = ifOp.getResult(0);
prevIfOp = ifOp;
}

rewriter.replaceOp(switchOp, finalResult);

return WalkResult::advance();
});
return res;
}
};

/// Builds a control schedule by traversing the CFG of the function and
/// associating this with the previously created groups.
/// For simplicity, the generated control flow is expanded for all possible
Expand Down Expand Up @@ -2305,9 +2241,6 @@ void SCFToCalyxPass::runOnOperation() {
/// This pass inlines scf.ExecuteRegionOp's by adding control-flow.
addGreedyPattern<InlineExecuteRegionOpPattern>(loweringPatterns);

addOncePattern<BuildSwitchGroups>(loweringPatterns, patternState, funcMap,
*loweringState);

/// This pattern converts all index typed values to an i32 integer.
addOncePattern<calyx::ConvertIndexTypes>(loweringPatterns, patternState,
funcMap, *loweringState);
Expand Down
2 changes: 2 additions & 0 deletions lib/Transforms/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ add_circt_library(CIRCTTransforms
MaximizeSSA.cpp
InsertMergeBlocks.cpp
PrintOpCount.cpp
IndexSwitchToIf.cpp

ADDITIONAL_HEADER_DIRS
${CIRCT_MAIN_INCLUDE_DIR}/circt/Transforms
Expand All @@ -18,6 +19,7 @@ add_circt_library(CIRCTTransforms
MLIRFuncDialect
MLIRIR
MLIRMemRefDialect
MLIRSCFDialect
MLIRSupport
MLIRTransforms

Expand Down
117 changes: 117 additions & 0 deletions lib/Transforms/IndexSwitchToIf.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,117 @@
//===- IndexSwitchToIf.cpp - Index switch to if-else pass ---*-C++-*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// Contains the definitions of the SCF IndexSwitch to If-Else pass.
//
//===----------------------------------------------------------------------===//

#include "circt/Transforms/Passes.h"
#include "mlir/Conversion/LLVMCommon/ConversionTarget.h"
#include "mlir/Conversion/LLVMCommon/Pattern.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/SCF/IR/SCF.h"
#include "mlir/IR/OperationSupport.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Transforms/DialectConversion.h"

namespace circt {
#define GEN_PASS_DEF_INDEXSWITCHTOIF
#include "circt/Transforms/Passes.h.inc"
} // namespace circt

using namespace mlir;
using namespace circt;

struct SwitchToIfConversion : public OpConversionPattern<scf::IndexSwitchOp> {
using OpConversionPattern::OpConversionPattern;

LogicalResult
matchAndRewrite(scf::IndexSwitchOp switchOp, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
auto loc = switchOp.getLoc();

Region &defaultRegion = switchOp.getDefaultRegion();

Value finalResult;
scf::IfOp prevIfOp = nullptr;

rewriter.setInsertionPointAfter(switchOp);
for (size_t i = 0; i < switchOp.getCases().size(); i++) {
auto caseValueInt = switchOp.getCases()[i];
if (prevIfOp)
rewriter.setInsertionPointToStart(&prevIfOp.getElseRegion().front());

Value caseValue =
rewriter.create<arith::ConstantIndexOp>(loc, caseValueInt);
Value cond = rewriter.create<arith::CmpIOp>(
loc, arith::CmpIPredicate::eq, switchOp.getOperand(), caseValue);

auto ifOp = rewriter.create<scf::IfOp>(loc, switchOp.getResultTypes(),
cond, /*hasElseRegion=*/true);

Region &caseRegion = switchOp.getCaseRegions()[i];
IRMapping mapping;
Block &emptyThenBlock = ifOp.getThenRegion().front();
emptyThenBlock.erase();
caseRegion.cloneInto(&ifOp.getThenRegion(), mapping);

if (i == switchOp.getCases().size() - 1) {
rewriter.setInsertionPointToEnd(&ifOp.getElseRegion().front());
Block &elseBlock = ifOp.getElseRegion().front();
elseBlock.erase();
defaultRegion.cloneInto(&ifOp.getElseRegion(), mapping);
}

if (prevIfOp) {
rewriter.setInsertionPointToEnd(&prevIfOp.getElseRegion().front());
rewriter.create<scf::YieldOp>(loc, ifOp.getResult(0));
}

if (i == 0)
finalResult = ifOp.getResult(0);
prevIfOp = ifOp;
}

rewriter.replaceOp(switchOp, finalResult);

return success();
}
};

namespace {

struct IndexSwitchToIfPass
: public circt::impl::IndexSwitchToIfBase<IndexSwitchToIfPass> {
public:
void runOnOperation() override {
auto *ctx = &getContext();
RewritePatternSet patterns(ctx);
ConversionTarget target(*ctx);

target.addLegalDialect<scf::SCFDialect>();
target.addLegalDialect<arith::ArithDialect>();
target.addLegalOp<ModuleOp, func::FuncOp, func::ReturnOp>();
target.addIllegalOp<scf::IndexSwitchOp>();

patterns.add<SwitchToIfConversion>(ctx);

if (applyPartialConversion(getOperation(), target, std::move(patterns))
.failed()) {
signalPassFailure();
return;
}
}
};

} // namespace

namespace circt {
std::unique_ptr<mlir::Pass> createIndexSwitchToIfPass() {
return std::make_unique<IndexSwitchToIfPass>();
}
} // namespace circt

0 comments on commit 577722d

Please sign in to comment.