Skip to content

Fix ClassifierChain error message for multiclass-multioutput targets #31797

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 15 additions & 1 deletion sklearn/multioutput.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@
process_routing,
)
from .utils.metaestimators import available_if
from .utils.multiclass import check_classification_targets
from .utils.multiclass import check_classification_targets, type_of_target
from .utils.parallel import Parallel, delayed
from .utils.validation import (
_check_method_params,
Expand Down Expand Up @@ -1086,6 +1086,20 @@ def fit(self, X, Y, **fit_params):
"""
_raise_for_params(fit_params, self, "fit")

# Validate input data
X, Y = validate_data(self, X, Y, multi_output=True, accept_sparse=True)

# Check if we have multiclass-multioutput targets, which are not supported
target_type = type_of_target(Y)
if target_type == "multiclass-multioutput":
raise ValueError(
"ClassifierChain does not support multiclass-multioutput "
"targets. ClassifierChain is designed for multilabel "
"classification where each target is binary (0 or 1). "
"Your target has multiple classes per output. "
"Consider using MultiOutputClassifier instead."
)

super().fit(X, Y, **fit_params)
self.classes_ = [estimator.classes_ for estimator in self.estimators_]
return self
Expand Down
48 changes: 48 additions & 0 deletions sklearn/tests/test_multioutput.py
Original file line number Diff line number Diff line change
Expand Up @@ -878,3 +878,51 @@ def test_base_estimator_deprecation(Estimator):

with pytest.raises(ValueError):
Estimator(base_estimator=estimator, estimator=estimator).fit(X, y)


def test_classifier_chain_multiclass_multioutput_error():
    """Check ClassifierChain rejects multiclass-multioutput targets clearly.

    Fitting on a 2D target whose columns each contain three distinct classes
    must raise a `ValueError` carrying the full user-facing message.
    """
    # Imported locally so the test does not depend on the module-level import
    # block, which is not visible in this hunk.
    import re

    X = np.array([[1, 2], [3, 4], [5, 6], [7, 8], [9, 10], [11, 12]])
    # Two outputs (columns); each takes the values 0, 1 and 2.
    y_multiclass_multioutput = np.array(
        [
            [0, 1],
            [1, 2],
            [2, 0],
            [0, 2],
            [1, 1],
            [2, 0],
        ]
    )

    chain = ClassifierChain(LogisticRegression(random_state=42))

    # Keep the message as plain text and let `re.escape` build the pattern:
    # `pytest.raises(match=...)` performs a regex search, so unescaped "." or
    # "(" in the message would otherwise be interpreted as metacharacters.
    expected_msg = (
        "ClassifierChain does not support multiclass-multioutput targets. "
        "ClassifierChain is designed for multilabel classification where "
        "each target is binary (0 or 1). Your target has multiple classes "
        "per output. Consider using MultiOutputClassifier instead."
    )

    with pytest.raises(ValueError, match=re.escape(expected_msg)):
        chain.fit(X, y_multiclass_multioutput)


def test_classifier_chain_multilabel_still_works():
    """Check that ClassifierChain fits and predicts on binary multilabel data."""
    features = np.array([[1, 2], [3, 4], [5, 6], [7, 8]])
    # One row per combination of the two binary labels.
    labels = np.array([[0, 1], [1, 1], [0, 0], [1, 0]])

    # Fitting must complete without raising.
    model = ClassifierChain(LogisticRegression(random_state=42))
    model.fit(features, labels)

    predicted = model.predict(features)

    # Predictions keep the target's shape and contain only 0/1 entries.
    assert predicted.shape == labels.shape
    assert np.isin(predicted, [0, 1]).all()
Loading