Skip to content

Commit aa680bc

Browse files
TST fix check_array_api_input device check (#31814)
Co-authored-by: Loïc Estève <[email protected]>
1 parent d80b0c7 commit aa680bc

File tree

3 files changed

+16
-12
lines changed

3 files changed

+16
-12
lines changed

sklearn/model_selection/tests/test_split.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1357,11 +1357,11 @@ def test_array_api_train_test_split(
13571357
assert get_namespace(y_train_xp)[0] == get_namespace(y_xp)[0]
13581358
assert get_namespace(y_test_xp)[0] == get_namespace(y_xp)[0]
13591359

1360-
# Check device and dtype is preserved on output
1361-
assert array_api_device(X_train_xp) == array_api_device(X_xp)
1362-
assert array_api_device(y_train_xp) == array_api_device(y_xp)
1363-
assert array_api_device(X_test_xp) == array_api_device(X_xp)
1364-
assert array_api_device(y_test_xp) == array_api_device(y_xp)
1360+
# Check device and dtype is preserved on output
1361+
assert array_api_device(X_train_xp) == array_api_device(X_xp)
1362+
assert array_api_device(y_train_xp) == array_api_device(y_xp)
1363+
assert array_api_device(X_test_xp) == array_api_device(X_xp)
1364+
assert array_api_device(y_test_xp) == array_api_device(y_xp)
13651365

13661366
assert X_train_xp.dtype == X_xp.dtype
13671367
assert y_train_xp.dtype == y_xp.dtype

sklearn/utils/estimator_checks.py

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1093,7 +1093,8 @@ def check_array_api_input(
10931093
f"got {attribute_ns}"
10941094
)
10951095

1096-
assert array_device(est_xp_param) == array_device(X_xp)
1096+
with config_context(array_api_dispatch=True):
1097+
assert array_device(est_xp_param) == array_device(X_xp)
10971098

10981099
est_xp_param_np = _convert_to_numpy(est_xp_param, xp=xp)
10991100
if check_values:
@@ -1180,7 +1181,9 @@ def check_array_api_input(
11801181
f"got {result_ns}."
11811182
)
11821183

1183-
assert array_device(result_xp) == array_device(X_xp)
1184+
with config_context(array_api_dispatch=True):
1185+
assert array_device(result_xp) == array_device(X_xp)
1186+
11841187
result_xp_np = _convert_to_numpy(result_xp, xp=xp)
11851188

11861189
if check_values:
@@ -1205,7 +1208,8 @@ def check_array_api_input(
12051208
f" {input_ns}, got {inverse_result_ns}."
12061209
)
12071210

1208-
assert array_device(invese_result_xp) == array_device(X_xp)
1211+
with config_context(array_api_dispatch=True):
1212+
assert array_device(invese_result_xp) == array_device(X_xp)
12091213

12101214
invese_result_xp_np = _convert_to_numpy(invese_result_xp, xp=xp)
12111215
if check_values:

sklearn/utils/tests/test_array_api.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -166,10 +166,10 @@ def test_average(
166166
with config_context(array_api_dispatch=True):
167167
result = _average(array_in, axis=axis, weights=weights, normalize=normalize)
168168

169-
if np_version < parse_version("2.0.0") or np_version >= parse_version("2.1.0"):
170-
# NumPy 2.0 has a problem with the device attribute of scalar arrays:
171-
# https://github.com/numpy/numpy/issues/26850
172-
assert device(array_in) == device(result)
169+
if np_version < parse_version("2.0.0") or np_version >= parse_version("2.1.0"):
170+
# NumPy 2.0 has a problem with the device attribute of scalar arrays:
171+
# https://github.com/numpy/numpy/issues/26850
172+
assert device(array_in) == device(result)
173173

174174
result = _convert_to_numpy(result, xp)
175175
assert_allclose(result, expected, atol=_atol_for_type(dtype_name))

0 commit comments

Comments
 (0)