Skip to content

Commit

Permalink
Merge pull request #189 from asmeurer/2022-fix
Browse files Browse the repository at this point in the history
Some fixes for v2022.12
  • Loading branch information
asmeurer authored Jun 16, 2023
2 parents fb49802 + 9064d5d commit f82c7bc
Show file tree
Hide file tree
Showing 6 changed files with 30 additions and 13 deletions.
5 changes: 2 additions & 3 deletions array_api_tests/dtype_helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -157,19 +157,18 @@ def is_int_dtype(dtype):
return dtype in all_int_dtypes


def is_float_dtype(dtype):
def is_float_dtype(dtype, *, include_complex=True):
# None equals NumPy's xp.float64 object, so we specifically check it here.
# xp.float64 is in fact an alias of np.dtype('float64'), and its equality
# with None is meant to be deprecated at some point.
# See https://github.com/numpy/numpy/issues/18434
if dtype is None:
return False
valid_dtypes = real_float_dtypes
if api_version > "2021.12":
if api_version > "2021.12" and include_complex:
valid_dtypes += complex_dtypes
return dtype in valid_dtypes


def get_scalar_type(dtype: DataType) -> ScalarType:
if dtype in all_int_dtypes:
return int
Expand Down
4 changes: 2 additions & 2 deletions array_api_tests/pytest_helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -464,8 +464,8 @@ def assert_array_elements(
f"{sh.fmt_idx(out_repr, idx)}={at_out}, should be {at_expected} "
f"{f_func}"
)
_assert_float_element(at_out.real, at_expected.real, msg)
_assert_float_element(at_out.imag, at_expected.imag, msg)
_assert_float_element(xp.real(at_out), xp.real(at_expected), msg)
_assert_float_element(xp.imag(at_out), xp.imag(at_expected), msg)
else:
assert xp.all(
out == expected
Expand Down
2 changes: 2 additions & 0 deletions array_api_tests/test_data_type_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -162,6 +162,7 @@ def test_finfo(dtype_name):
assert isinstance(
value, stype
), f"type(out.{attr})={type(value)!r}, but should be {stype.__name__} {f_func}"
assert hasattr(out, "dtype"), f"out has no attribute 'dtype' {f_func}"
# TODO: test values


Expand All @@ -179,6 +180,7 @@ def test_iinfo(dtype_name):
assert isinstance(
value, int
), f"type(out.{attr})={type(value)!r}, but should be int {f_func}"
assert hasattr(out, "dtype"), f"out has no attribute 'dtype' {f_func}"
# TODO: test values


Expand Down
2 changes: 1 addition & 1 deletion array_api_tests/test_indexing_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,7 @@ def test_take(x, data):
f_axis_idx = sh.fmt_idx("x", axis_idx)
for i in _indices:
f_take_idx = sh.fmt_idx(f_axis_idx, i)
indexed_x = x[axis_idx][i]
indexed_x = x[axis_idx][i, ...]
for at_idx in sh.ndindex(indexed_x.shape):
out_idx = next(out_indices)
ph.assert_0d_equals(
Expand Down
8 changes: 4 additions & 4 deletions array_api_tests/test_set_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -112,7 +112,7 @@ def test_unique_all(x):

if dh.is_float_dtype(out.values.dtype):
assume(math.prod(x.shape) <= 128) # may not be representable
expected = sum(v for k, v in counts.items() if math.isnan(k))
expected = sum(v for k, v in counts.items() if cmath.isnan(k))
assert nans == expected, f"{nans} NaNs in out, but should be {expected}"


Expand All @@ -137,7 +137,7 @@ def test_unique_counts(x):
for idx in sh.ndindex(out.values.shape):
val = scalar_type(out.values[idx])
count = int(out.counts[idx])
if math.isnan(val):
if cmath.isnan(val):
nans += 1
assert count == 1, (
f"out.counts[{idx}]={count} for out.values[{idx}]={val}, "
Expand All @@ -159,7 +159,7 @@ def test_unique_counts(x):
vals_idx[val] = idx
if dh.is_float_dtype(out.values.dtype):
assume(math.prod(x.shape) <= 128) # may not be representable
expected = sum(v for k, v in counts.items() if math.isnan(k))
expected = sum(v for k, v in counts.items() if cmath.isnan(k))
assert nans == expected, f"{nans} NaNs in out, but should be {expected}"


Expand Down Expand Up @@ -188,7 +188,7 @@ def test_unique_inverse(x):
nans = 0
for idx in sh.ndindex(out.values.shape):
val = scalar_type(out.values[idx])
if math.isnan(val):
if cmath.isnan(val):
nans += 1
else:
assert (
Expand Down
22 changes: 19 additions & 3 deletions array_api_tests/test_statistical_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
from . import hypothesis_helpers as hh
from . import pytest_helpers as ph
from . import shape_helpers as sh
from . import xps
from . import xps, api_version
from ._array_module import _UndefinedStub
from .typing import DataType

Expand Down Expand Up @@ -145,11 +145,19 @@ def test_prod(x, data):
_dtype = x.dtype
else:
_dtype = default_dtype
else:
elif dh.is_float_dtype(x.dtype, include_complex=False):
if dh.dtype_nbits[x.dtype] > dh.dtype_nbits[dh.default_float]:
_dtype = x.dtype
else:
_dtype = dh.default_float
elif api_version > "2021.12":
# Complex dtype
if dh.dtype_nbits[x.dtype] > dh.dtype_nbits[dh.default_complex]:
_dtype = x.dtype
else:
_dtype = dh.default_complex
else:
raise RuntimeError("Unexpected dtype. This indicates a bug in the test suite.")
else:
_dtype = dtype
if _dtype is None:
Expand Down Expand Up @@ -253,11 +261,19 @@ def test_sum(x, data):
_dtype = x.dtype
else:
_dtype = default_dtype
else:
elif dh.is_float_dtype(x.dtype, include_complex=False):
if dh.dtype_nbits[x.dtype] > dh.dtype_nbits[dh.default_float]:
_dtype = x.dtype
else:
_dtype = dh.default_float
elif api_version > "2021.12":
# Complex dtype
if dh.dtype_nbits[x.dtype] > dh.dtype_nbits[dh.default_complex]:
_dtype = x.dtype
else:
_dtype = dh.default_complex
else:
raise RuntimeError("Unexpected dtype. This indicates a bug in the test suite.")
else:
_dtype = dtype
if _dtype is None:
Expand Down

0 comments on commit f82c7bc

Please sign in to comment.