-
Notifications
You must be signed in to change notification settings - Fork 507
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Add onnx GridSample support for border padding mode #3819
base: main
Are you sure you want to change the base?
Conversation
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thanks for adding this support!
I have a couple of changes to request, and I'd also like you to sync your branch with head of main to avoid the failing CI.
@@ -157,7 +164,8 @@ void mlir::torch::onnx_c::populateDefaultDomainGtoP( | |||
|
|||
Value paddingMode = rewriter.create<Torch::ConstantIntOp>( | |||
binder.getLoc(), rewriter.getType<Torch::IntType>(), | |||
rewriter.getIntegerAttr(rewriter.getIntegerType(64), 0)); | |||
rewriter.getIntegerAttr(rewriter.getIntegerType(64), |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I think it is simpler as:
Value paddingMode = rewriter.create<Torch::ConstantIntOp>(binder.getLoc(), paddingModeInt);
Value zeroInt = | ||
b.create<arith::ConstantOp>(loc, b.getIntegerAttr(int64type, 0)); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
It might be better to keep this constant outside the lambda function, and just get it from the capture (otherwise you will re-write the constant multiple times in the IR).
Value isZero = b.create<arith::CmpIOp>(loc, arith::CmpIPredicate::eq, | ||
paddingMode, zeroInt); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The padding mode is always a constant (coming from ONNX, at least). We should be able to get the paddingModeInt
by matchPattern(paddingMode, m_TorchConstantInt(&paddingModeInt))
. This way, we can avoid writing IR to check conditions that can be determined statically at compile time. Since the select op and cmp op will likely get folded anyway, this is not hugely problematic, but on principal I'd prefer to generate the simplest possible IR from the get-go.
auto lambdaBorder = [&](OpBuilder &b, Location loc, Value x, | ||
Value SizeSubOne) -> Value { | ||
Value xMaxZero = b.create<arith::MaximumFOp>(loc, x, zeroFloat); | ||
return b.create<arith::MinimumFOp>(loc, xMaxZero, SizeSubOne); | ||
}; | ||
|
||
auto lambdaPadding = [&](OpBuilder &b, Location loc, Value paddingMode, | ||
Value x, Value SizeSubOne) -> Value { |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
x
and sizeSubOne
should be the only inputs for both of these lambda functions. The others can be gotten through the default capture. Just be sure to define these functions after you introduce the other arguments.
Thanks for the feedback! I'll work on the requested changes. |
Adds support for padding_mode="border"
Clamps the grid coordinates between 0 and size - 1 like this when using this padding mode: