Seeking alternatives for setting array values with integer indexing #62

Closed
fcogidi opened this issue Oct 11, 2023 · 3 comments

fcogidi commented Oct 11, 2023

On a few occasions while using this library, I've run into the problem of having to set array values using advanced (integer array) indexing. Here is an example:

import numpy as np
import numpy.array_api as anp

def one_hot_np(array, num_classes):
    n = array.shape[0]
    categorical = np.zeros((n, num_classes))
    categorical[np.arange(n), array] = 1
    return categorical


def one_hot_anp(array, num_classes):
    one_hot = anp.zeros((array.shape[0], num_classes))
    indices = anp.stack(
        (anp.arange(array.shape[0]), anp.reshape(array, (-1,))), axis=-1
    )
    indices = anp.reshape(indices, shape=(-1, indices.shape[-1]))

    for idx in range(indices.shape[0]):
        one_hot[tuple(indices[idx, ...])] = 1

    return one_hot

I'm using the numpy.array_api namespace because it follows the API standard closely.

Is there a different (better) way of setting array values from integer (array) indices that still adheres to the 2021.12 version of the array API standard?

For the example I gave, I'm aware that I can do something like this (but not with numpy.array_api namespace, as it only supports v2021.12):

import numpy as np
import numpy.array_api as anp

def one_hot(array, num_classes):
    id_arr = anp.eye(num_classes)
    return np.take(id_arr, array, axis=0)

But I have other cases in my codebase that follow the first pattern - looping through array indices and using basic indexing to set array values. For example, using the indices from xp.argsort to mark the top-k values. Is there a better way than looping through the indices?
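For the top-k case specifically, one loop-free possibility is to turn the argsort result into ranks and compare, so that no integer-array assignment is needed. This is a minimal sketch, not something from the original codebase: the name `top_k_mask` is hypothetical, and plain NumPy stands in for the array namespace. It relies only on `argsort` and an elementwise comparison.

```python
import numpy as np


def top_k_mask(x, k):
    """Return a boolean mask marking the k largest entries along the last axis.

    Hypothetical helper for illustration only.
    """
    # Applying argsort twice yields each element's rank (0 = smallest).
    ranks = np.argsort(np.argsort(x, axis=-1), axis=-1)
    # The k largest elements are those with rank >= n - k.
    return ranks >= x.shape[-1] - k


scores = np.asarray([[0.1, 0.7, 0.2, 0.9], [0.5, 0.4, 0.3, 0.6]])
mask = top_k_mask(scores, 2)
```

The trade-off is a second sort in place of the Python loop. With tied values the mask still marks exactly k entries per row, but which of the tied entries are chosen depends on the sort's tie-breaking.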

@asmeurer
Member

There's a plan to add a guide for this sort of thing to the standard (data-apis/array-api#668), although there's nothing there yet on alternatives to integer indexing. Most likely your best bet is to manually use put or integer indexing for libraries that you know have that functionality.

See also data-apis/array-api#177 and data-apis/array-api#629
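As a rough sketch of what "manually use integer indexing for libraries that you know have it" could look like: the helper below takes a fast path for NumPy arrays and otherwise falls back to the element-by-element loop from the question. The name `set_at` and the `isinstance` check are assumptions made for illustration, not an API of any library.

```python
import numpy as np


def set_at(array, rows, cols, value):
    """Set array[rows[i], cols[i]] = value for every i, in place.

    Hypothetical helper for illustration only.
    """
    if isinstance(array, np.ndarray):
        # Fast path: NumPy is known to support integer-array indexing.
        array[rows, cols] = value
    else:
        # Portable fallback: basic indexing, one element at a time.
        for i in range(rows.shape[0]):
            array[int(rows[i]), int(cols[i])] = value
    return array


labels = np.asarray([2, 0, 1])
one_hot = set_at(np.zeros((3, 3)), np.arange(3), labels, 1)
```

Other libraries known to support integer-array assignment could be given their own branches in the same way, with the loop kept as the last resort.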

seberg commented Oct 12, 2023

That doesn't mean that partial advanced-indexing isn't very useful in other cases.

But for one-hot, there is a conceptual rewrite available that is probably faster anyway, as long as num_classes is relatively small.

def one_hot(array, num_classes):
    classes = np.arange(num_classes, dtype=array.dtype)
    return classes == array[..., np.newaxis]

(In NumPy, you could use out=... to force whichever dtype you want on the result, which should be faster when arrays are large, but also adds a fair bit of overhead.)
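To make the `out=...` remark concrete, here is a minimal NumPy-only sketch. The function name `one_hot_float` and the `float32` default are assumptions for illustration. It writes the comparison result straight into a preallocated array of the desired dtype.

```python
import numpy as np


def one_hot_float(array, num_classes, dtype=np.float32):
    """One-hot encode integer labels via broadcasting, with a chosen dtype.

    Hypothetical helper for illustration only.
    """
    classes = np.arange(num_classes, dtype=array.dtype)
    # Preallocate the result so np.equal writes into it with the target dtype.
    out = np.empty(array.shape + (num_classes,), dtype=dtype)
    np.equal(classes, array[..., np.newaxis], out=out)
    return out


encoded = one_hot_float(np.asarray([2, 0, 1]), 3)
```

Without `out`, the comparison returns a boolean array, which can be converted to another dtype afterwards at the cost of an extra copy.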

Author

fcogidi commented Oct 16, 2023

Thank you @asmeurer and @seberg for your quick and helpful reply!

@fcogidi fcogidi closed this as completed Oct 16, 2023