Skip to content

Commit 11a6186

Browse files
committed
modify the scheme to put pyarrow array in
1 parent a06771c commit 11a6186

File tree

3 files changed

+293
-132
lines changed

3 files changed

+293
-132
lines changed

sklearn/utils/_testing.py

Lines changed: 70 additions & 50 deletions
Original file line numberDiff line numberDiff line change
@@ -974,14 +974,14 @@ def _convert_container(
974974
container,
975975
constructor_type,
976976
dtype=None,
977-
sparse_container=None,
978-
sparse_format="csr",
979-
constructor_lib="pandas",
977+
constructor_lib=None,
980978
minversion=None,
979+
sparse_container="matrix",
980+
sparse_format="csr",
981981
column_names=None,
982982
categorical_feature_names=None,
983983
):
984-
"""Convert a given container to a specific array-like with a dtype.
984+
"""Convert a given container to another specified container type.
985985
986986
Parameters
987987
----------
@@ -1001,31 +1001,35 @@ def _convert_container(
10011001
Force the dtype of the container. Does not apply when `constructor_type` is
10021002
"slice".
10031003
1004-
sparse_container : {"matrix", "array"}, default=None
1005-
The sparse container to use. Only applies when `constructor_type` is "array"`.
1006-
Note that if this parameter is not `None` and `container` is 1D of length n,
1007-
then the converted container will be 2D of shape (1, n).
1008-
1009-
- `None` returns a dense numpy array
1010-
- "matrix" returns a sparse matrix
1011-
- "array" returns a sparse array and skip test if scipy < 1.8
1012-
1013-
sparse_format : {"csc", "csr"}, default="csr"
1014-
The sparse format to use. Only applies when `constructor_type` is "array" and
1015-
`sparse_container` is not `None`.
1016-
1017-
constructor_lib : {"pandas", "polars", "pyarrow"}, default="pandas"
1018-
The library to use. Only applies when `constructor_type` is one of "dataframe",
1019-
"series", and "index". Skip the test if the specified library is not available.
1004+
constructor_lib : {"numpy", "scipy", "pandas", "polars", "pyarrow"}, default=None
1005+
The library to use. Only applies when `constructor_type` is one of "array",
1006+
"dataframe", "series", and "index". Skip the test if the specified library is
1007+
not available.
10201008
1021-
- "pandas" supports `constructor_type`s "dataframe", "series", and "index"
1009+
- "numpy" supports "array"
1010+
- "scipy" supports "array"
1011+
- "pandas" supports "dataframe", "series", and "index"
10221012
- "polars" supports "dataframe" and "series"
1023-
- "pyarrow" supports "dataframe" (i.e., Table)
1013+
- "pyarrow" supports "array" and "dataframe"
10241014
10251015
minversion : str, default=None
10261016
Minimum version for package to install. Only applies when `constructor_lib` is
10271017
applicable.
10281018
1019+
sparse_container : {"matrix", "array"}, default="matrix"
1020+
The sparse container to use. Only applies when `constructor_type` is "array"
1021+
and `constructor_lib` is "scipy". Note that if this parameter is not `None` and
1022+
`container` is 1D of length n, then the converted container will be 2D of shape
1023+
(1, n).
1024+
1025+
- `None` returns a dense numpy array
1026+
- "matrix" returns a sparse matrix
1027+
- "array" returns a sparse array and skip test if scipy < 1.8
1028+
1029+
sparse_format : {"csr", "csc"}, default="csr"
1030+
The sparse format to use. Only applies when `constructor_type` is "array" and
1031+
`constructor_lib` is "scipy".
1032+
10291033
column_names : index or array-like, default=None
10301034
The column names of the container. Only applies when `constructor_type` is
10311035
"dataframe".
@@ -1039,7 +1043,7 @@ def _convert_container(
10391043
converted_container
10401044
"""
10411045
if sp.sparse.issparse(container) and not (
1042-
constructor_type == "array" and sparse_container is not None
1046+
constructor_type == "array" and constructor_lib == "scipy"
10431047
):
10441048
# If the container is sparse but the target is not sparse, convert to dense
10451049
# array in the first place; otherwise `np.asarray` may complain
@@ -1064,6 +1068,8 @@ def _convert_container(
10641068
return _convert_container_to_array(
10651069
container,
10661070
dtype,
1071+
constructor_lib or "numpy",
1072+
minversion,
10671073
sparse_container,
10681074
sparse_format,
10691075
)
@@ -1072,7 +1078,7 @@ def _convert_container(
10721078
return _convert_container_to_dataframe(
10731079
container,
10741080
dtype,
1075-
constructor_lib,
1081+
constructor_lib or "pandas",
10761082
minversion,
10771083
column_names,
10781084
categorical_feature_names,
@@ -1082,50 +1088,64 @@ def _convert_container(
10821088
return _convert_container_to_series(
10831089
container,
10841090
dtype,
1085-
constructor_lib,
1091+
constructor_lib or "pandas",
10861092
minversion,
10871093
)
10881094

10891095
if constructor_type == "index":
10901096
return _convert_container_to_index(
10911097
container,
10921098
dtype,
1093-
constructor_lib,
1099+
constructor_lib or "pandas",
10941100
minversion,
10951101
)
10961102

10971103

1098-
def _convert_container_to_array(container, dtype, sparse_container, sparse_format):
1104+
def _convert_container_to_array(
1105+
container,
1106+
dtype,
1107+
constructor_lib,
1108+
minversion,
1109+
sparse_container,
1110+
sparse_format,
1111+
):
10991112
"""Helper for `_convert_container` when `constructor_type` is "array"."""
1100-
if sparse_container is None:
1101-
# Convert to dense numpy array
1113+
if constructor_lib not in ("numpy", "scipy", "pyarrow"):
1114+
raise ValueError(f"{constructor_lib=} is incompatible with array")
1115+
1116+
if constructor_lib == "numpy":
11021117
return np.asarray(container, dtype=dtype)
11031118

1104-
if sp_version < parse_version("1.8") and sparse_container == "array":
1105-
pytest.skip("Sparse arrays require scipy >= 1.8")
1119+
if constructor_lib == "pyarrow":
1120+
pa = pytest.importorskip("pyarrow", minversion=minversion)
1121+
return pa.array(container, type=dtype)
11061122

1107-
if not sp.sparse.issparse(container):
1108-
# For scipy >= 1.13, sparse array constructed from 1d array may be
1109-
# 1d or raise an exception. To avoid this, we make sure that the
1110-
# input container is 2d. For more details, see
1111-
# https://github.com/scipy/scipy/pull/18530#issuecomment-1878005149
1112-
container = np.atleast_2d(container)
1123+
if constructor_lib == "scipy":
1124+
if sp_version < parse_version("1.8") and sparse_container == "array":
1125+
pytest.skip("Sparse arrays require scipy >= 1.8")
11131126

1114-
supported_sparse_containers = ("matrix", "array")
1115-
if sparse_container not in supported_sparse_containers:
1116-
raise ValueError(
1117-
f"Invalid {sparse_container=}; expected one of"
1118-
f" {supported_sparse_containers}"
1119-
)
1127+
if not sp.sparse.issparse(container):
1128+
# For scipy >= 1.13, sparse array constructed from 1d array may be
1129+
# 1d or raise an exception. To avoid this, we make sure that the
1130+
# input container is 2d. For more details, see
1131+
# https://github.com/scipy/scipy/pull/18530#issuecomment-1878005149
1132+
container = np.atleast_2d(container)
11201133

1121-
supported_sparse_formats = ("csr", "csc")
1122-
if sparse_format not in supported_sparse_formats:
1123-
raise ValueError(
1124-
f"Invalid {sparse_format=}; expected one of {supported_sparse_formats}"
1125-
)
1134+
supported_sparse_containers = ("matrix", "array")
1135+
if sparse_container not in supported_sparse_containers:
1136+
raise ValueError(
1137+
f"Invalid {sparse_container=}; expected one of"
1138+
f" {supported_sparse_containers}"
1139+
)
1140+
1141+
supported_sparse_formats = ("csr", "csc")
1142+
if sparse_format not in supported_sparse_formats:
1143+
raise ValueError(
1144+
f"Invalid {sparse_format=}; expected one of {supported_sparse_formats}"
1145+
)
11261146

1127-
sparse_constructor = getattr(sp.sparse, f"{sparse_format}_{sparse_container}")
1128-
return sparse_constructor(container, dtype=dtype)
1147+
sparse_constructor = getattr(sp.sparse, f"{sparse_format}_{sparse_container}")
1148+
return sparse_constructor(container, dtype=dtype)
11291149

11301150

11311151
def _convert_container_to_dataframe(

0 commit comments

Comments
 (0)