
Support proper custom class reflexive operator applied to xarray objects #9944

Open
Li9htmare opened this issue Jan 13, 2025 · 2 comments

@Li9htmare

Is your feature request related to a problem?

I would like to implement reflexive operators (such as __radd__) on a custom class and have them work when the other operand is an xarray object.

Here is a demo snippet:

import numpy as np
import xarray as xr


class DemoObj:
    def __add__(self, other):
        print(f'__add__ call: type={other.__class__}, value={other}')
        return other

    def __radd__(self, other):
        print(f'__radd__ call: type={other.__class__}, value={other}')
        return other


obj = DemoObj()
da = xr.DataArray(np.arange(8))

print('#### Test __add__ ####')
obj + da
print('\n')

print('#### Test __radd__ ####')
da + obj

Actual Output:

#### Test __add__ ####
__add__ call: type=<class 'xarray.core.dataarray.DataArray'>, value=<xarray.DataArray (dim_0: 8)>
array([0, 1, 2, 3, 4, 5, 6, 7])
Dimensions without coordinates: dim_0

#### Test __radd__ ####
__radd__ call: type=<class 'int'>, value=0
__radd__ call: type=<class 'int'>, value=1
__radd__ call: type=<class 'int'>, value=2
__radd__ call: type=<class 'int'>, value=3
__radd__ call: type=<class 'int'>, value=4
__radd__ call: type=<class 'int'>, value=5
__radd__ call: type=<class 'int'>, value=6
__radd__ call: type=<class 'int'>, value=7

__add__ is called once and receives the xr.DataArray object, but __radd__ is called 8 times (once per element) and receives plain ints. This causes two problems:

  • Poor performance on a large xr.DataArray, since __radd__ runs once per element
  • No access to the xr.DataArray coords, which a more realistic use case needs
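The per-element calls can be reproduced with NumPy alone, which is consistent with xarray handing the operation to NumPy under the hood. The sketch below uses a hypothetical `Recorder` class in place of `DemoObj`. Because `Recorder` defines neither `__array_ufunc__` nor `__array_priority__`, `ndarray.__add__` does not defer to it; NumPy instead treats it as an object scalar and runs an object-dtype loop that evaluates `element + rec` for every element.

```python
import numpy as np


class Recorder:
    """Hypothetical stand-in for DemoObj that records what __radd__ receives."""

    def __init__(self):
        self.calls = []

    def __radd__(self, other):
        # Record the type of the left operand NumPy hands us.
        self.calls.append(type(other))
        return other


rec = Recorder()
# np.add runs an object loop here: each step is `python_int + rec`,
# int.__add__ returns NotImplemented, so Python calls rec.__radd__(python_int).
result = np.arange(4) + rec
print(len(rec.calls), rec.calls[0].__name__, result.dtype)
```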

Describe the solution you'd like

I would like a mechanism such that, in the example above, DemoObj.__radd__ is called only once and receives the xr.DataArray instance.

Describe alternatives you've considered

Option 1:

The most naive workaround is to call obj.__radd__(da) explicitly instead of writing da + obj. This defeats the purpose of implementing the reflexive operator and hurts readability.

Option 2:

Since xr.DataArray._binary_op relies on NumPy's operator-resolution mechanism under the hood, I could improve the situation by setting __array_ufunc__ = None on my class, e.g.:

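class DemoObj:
    # Opt out of NumPy's ufunc machinery so ndarray defers binary ops to us.
    __array_ufunc__ = None

    def __add__(self, other):
        print(f'__add__ call: type={other.__class__}, value={other}')
        return other

    def __radd__(self, other):
        print(f'__radd__ call: type={other.__class__}, value={other}')
        return other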

With this change, __radd__ is called once with an np.ndarray instead of 8 times with ints. This solves the performance concern; however, it still doesn't cover the case where xr.DataArray.coords is needed.
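A NumPy-only sketch of this workaround, using a hypothetical `Deferring` class: setting `__array_ufunc__ = None` makes `ndarray`'s binary operators return `NotImplemented` for that type, so Python calls the right operand's `__radd__` exactly once with the whole array.

```python
import numpy as np


class Deferring:
    """Hypothetical class that opts out of NumPy's ufunc machinery."""

    # With __array_ufunc__ = None, ndarray.__add__ returns NotImplemented
    # for this type, so Python falls back to Deferring.__radd__.
    __array_ufunc__ = None

    def __init__(self):
        self.calls = []

    def __radd__(self, other):
        self.calls.append(type(other))
        return other


obj = Deferring()
arr = np.arange(8)
out = arr + obj  # a single __radd__ call that receives the whole ndarray
```

Inside xarray, the operand that reaches `__radd__` this way is the wrapped `.data` array rather than the `DataArray`, which is why the coords remain out of reach.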

Additional context

Since xr.DataArray._binary_op already returns NotImplemented for a list of classes:
https://github.com/pydata/xarray/blob/v2025.01.1/xarray/core/dataarray.py#L4808-L4809

I'm wondering whether we should do the same for classes that set __array_ufunc__ = None, i.e.:

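def _binary_op(
    self: T_DataArray,
    other: Any,
    f: Callable,
    reflexive: bool = False,
) -> T_DataArray:
    # Defer to operands that opt out of NumPy's ufunc protocol, so Python
    # falls back to the other operand's reflected method (e.g. other.__radd__).
    if hasattr(other, '__array_ufunc__') and other.__array_ufunc__ is None:
        return NotImplemented
    ...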
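To illustrate the proposed rule end to end, here is a sketch built on a hypothetical toy wrapper, `MiniArray`, that mimics the shape of `_binary_op` (it is not xarray's code). When the other operand sets `__array_ufunc__ = None`, `__add__` returns `NotImplemented`, and Python then calls `obj.__radd__(da)` with the wrapper itself, coords included.

```python
import operator

import numpy as np


class MiniArray:
    """Hypothetical toy stand-in for a DataArray (not xarray's implementation)."""

    def __init__(self, data, coords=None):
        self.data = np.asarray(data)
        self.coords = coords if coords is not None else {}

    def _binary_op(self, other, f, reflexive=False):
        # Proposed rule: defer to operands that set __array_ufunc__ = None.
        if hasattr(other, "__array_ufunc__") and other.__array_ufunc__ is None:
            return NotImplemented
        other_data = other.data if isinstance(other, MiniArray) else other
        if reflexive:
            result = f(other_data, self.data)
        else:
            result = f(self.data, other_data)
        return MiniArray(result, self.coords)

    def __add__(self, other):
        return self._binary_op(other, operator.add)

    def __radd__(self, other):
        return self._binary_op(other, operator.add, reflexive=True)


class DemoObj:
    __array_ufunc__ = None

    def __init__(self):
        self.received = []

    def __radd__(self, other):
        self.received.append(other)
        return other


da = MiniArray(np.arange(8), coords={"x": list(range(10, 18))})
obj = DemoObj()
result = da + obj  # MiniArray.__add__ defers, so obj.__radd__(da) runs once
```

Ordinary operands such as `da + 1` are unaffected, because Python scalars do not define `__array_ufunc__` at all.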

I'm also happy to use a similar xarray-specific property if you prefer. I'm happy to make the PR as well once you confirm which mechanism / property name you prefer.

Many thanks in advance!


welcome bot commented Jan 13, 2025

Thanks for opening your first issue here at xarray! Be sure to follow the issue template!
If you have an idea for a solution, we would really welcome a Pull Request with proposed changes.
See the Contributing Guide for more.
It may take us a while to respond here, but we really value your contribution. Contributors like you help make xarray better.
Thank you!

@shoyer
Member

shoyer commented Jan 17, 2025

Indeed, currently Xarray very aggressively attempts to take control of all binary arithmetic operations (by applying them to the wrapped .data of the xarray object). I agree that this is definitely not ideal.

Xarray should only attempt to do this for objects with an API that works like multi-dimensional arrays. I see at least two ways to determine this:

  1. As you suggest, we could use __array_ufunc__ = None like NumPy to indicate that an object explicitly does not have an API like NumPy arrays.
  2. Alternatively, we could return NotImplemented except for types that explicitly indicate that they work like NumPy arrays. In principle, these should be the same types that are valid when wrapped inside xarray objects, because they implement one of the two generations of NumPy compatibility APIs (__array_ufunc__/__array_function__ or __array_namespace__). Here is where the current code to check for compatibility with these objects lives:
    hasattr(data, "__array_function__") or hasattr(data, "__array_namespace__")

My inclination would be to try the second solution first (I think it's a little cleaner / more comprehensive) but if that doesn't work I would be OK to fall back to the first one.
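A rough sketch of the second approach, written as a standalone predicate rather than actual xarray code. `handles_operand` is a hypothetical name; treating Python and NumPy scalars as acceptable is an added assumption so that expressions like `da + 1` keep working, and a real implementation would also need to accept xarray's own types.

```python
import numbers

import numpy as np


def is_duck_array(obj):
    # Same test as the compatibility check quoted in the comment above.
    return hasattr(obj, "__array_function__") or hasattr(obj, "__array_namespace__")


def handles_operand(other):
    """Hypothetical option-2 rule: True means "apply the op to .data",
    False means "return NotImplemented and let the other operand try"."""
    # Assumption: scalars are accepted so that `da + 1` keeps working.
    return isinstance(other, numbers.Number) or is_duck_array(other)


class DemoObj:
    # No __array_ufunc__ = None needed: it simply lacks a NumPy-like API.
    def __radd__(self, other):
        return other


decisions = {
    "int": handles_operand(1),
    "numpy scalar": handles_operand(np.float64(1.5)),
    "ndarray": handles_operand(np.arange(3)),
    "DemoObj": handles_operand(DemoObj()),
    "list": handles_operand([1, 2, 3]),
}
print(decisions)
```

The `list` entry shows one consequence of this particular sketch: a plain list is neither a `numbers.Number` nor a duck array, so it would be deferred as well. Whether that is acceptable is a design question the sketch does not settle.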
