
Add onnx GridSample support for border padding mode #3819

Open · wants to merge 4 commits into main

Conversation

@Ax9D Ax9D commented Oct 25, 2024

  • Adds support for padding_mode="border"

  • When this padding mode is used, clamps the grid coordinates to the range [0, size - 1] (see the scalar sketch after this list):

x_result = min(max(0, x), W - 1)
y_result = min(max(0, y), H - 1)

  • Adds lit tests for both the TorchToLinalg and OnnxToTorch lowerings
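
For reference, here is a minimal scalar sketch of the clamp described above (the function name and types are illustrative, not part of the PR):

```cpp
#include <algorithm>
#include <cstdint>

// padding_mode="border": out-of-range grid samples reuse the nearest edge
// pixel, i.e. the unnormalized coordinate is clamped to [0, size - 1].
float clampToBorder(float x, int64_t size) {
  return std::min(std::max(x, 0.0f), static_cast<float>(size - 1));
}

// e.g. clampToBorder(-2.5f, 8) == 0.0f and clampToBorder(9.3f, 8) == 7.0f
```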

@zjgarvey zjgarvey (Collaborator) left a comment


Thanks for adding this support!

I have a couple of changes to request, and I'd also like you to sync your branch with the head of main to avoid the failing CI.

@@ -157,7 +164,8 @@ void mlir::torch::onnx_c::populateDefaultDomainGtoP(
       Value paddingMode = rewriter.create<Torch::ConstantIntOp>(
           binder.getLoc(), rewriter.getType<Torch::IntType>(),
-          rewriter.getIntegerAttr(rewriter.getIntegerType(64), 0));
+          rewriter.getIntegerAttr(rewriter.getIntegerType(64),

zjgarvey (Collaborator):

I think it is simpler as:

Value paddingMode = rewriter.create<Torch::ConstantIntOp>(binder.getLoc(), paddingModeInt);

Comment on lines +2577 to +2578
Value zeroInt =
b.create<arith::ConstantOp>(loc, b.getIntegerAttr(int64type, 0));
zjgarvey (Collaborator):

It might be better to keep this constant outside the lambda function and just get it from the capture (otherwise the constant will be re-emitted multiple times in the IR).
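
A rough sketch of what that could look like (variable names follow the diff, but this is illustrative rather than the PR's actual code):

```cpp
// Materialize the constant once, before the lambda is defined, so it is
// emitted a single time in the IR; the lambda reads it via the capture.
Value zeroInt =
    b.create<arith::ConstantOp>(loc, b.getIntegerAttr(int64type, 0));

auto isZeroPaddingMode = [&](Value paddingMode) -> Value {
  // Reuses the captured zeroInt instead of creating a new constant here.
  return b.create<arith::CmpIOp>(loc, arith::CmpIPredicate::eq, paddingMode,
                                 zeroInt);
};
```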

Comment on lines +2579 to +2580
Value isZero = b.create<arith::CmpIOp>(loc, arith::CmpIPredicate::eq,
paddingMode, zeroInt);
zjgarvey (Collaborator):

The padding mode is always a constant (coming from ONNX, at least). We should be able to get paddingModeInt via matchPattern(paddingMode, m_TorchConstantInt(&paddingModeInt)). This way we can avoid writing IR to check conditions that can be determined statically at compile time. Since the select op and cmp op will likely get folded anyway, this is not hugely problematic, but on principle I'd prefer to generate the simplest possible IR from the get-go.
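
A hedged sketch of that approach (the op handle and variable names are assumed for illustration; this isn't the PR's code):

```cpp
// Resolve the padding mode at compile time instead of emitting a runtime
// arith.cmpi / arith.select chain in the lowered IR.
int64_t paddingModeInt;
if (!matchPattern(paddingMode, m_TorchConstantInt(&paddingModeInt)))
  return rewriter.notifyMatchFailure(
      op, "expected padding_mode to be a constant integer");

// In PyTorch's grid_sampler convention, 0 = zeros and 1 = border, so the
// clamp only needs to be generated when paddingModeInt == 1.
bool isBorderPadding = paddingModeInt == 1;
```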

Comment on lines +2568 to +2575
auto lambdaBorder = [&](OpBuilder &b, Location loc, Value x,
Value SizeSubOne) -> Value {
Value xMaxZero = b.create<arith::MaximumFOp>(loc, x, zeroFloat);
return b.create<arith::MinimumFOp>(loc, xMaxZero, SizeSubOne);
};

auto lambdaPadding = [&](OpBuilder &b, Location loc, Value paddingMode,
Value x, Value SizeSubOne) -> Value {
zjgarvey (Collaborator):

x and sizeSubOne should be the only inputs to both of these lambda functions; the rest can come in through the default capture. Just be sure to define these lambdas after the other values have been introduced.
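
For example, something along these lines (illustrative only, not the PR's code):

```cpp
// zeroFloat (and the other shared values) are defined first, so the lambda
// can pick them up through the default reference capture and only take the
// per-coordinate values as parameters.
auto clampToBorder = [&](Value x, Value sizeSubOne) -> Value {
  Value xMaxZero = b.create<arith::MaximumFOp>(loc, x, zeroFloat);
  return b.create<arith::MinimumFOp>(loc, xMaxZero, sizeSubOne);
};
```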

@Ax9D Ax9D (Author) commented Nov 14, 2024

Thanks for the feedback! I'll work on the requested changes.
