Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Support non-trivial classes_ in LogisticRegression #6346

Open
wants to merge 5 commits into
base: branch-25.04
Choose a base branch
from

Conversation

jcrist
Copy link
Member

@jcrist jcrist commented Feb 20, 2025

Scikit-Learn's LogisticRegression is a bit unique in that it natively supports complex labels (e.g. raw strings/categories/non-monotonically increasing ints), rather than requiring that the labels are pre-encoded. It does this by internally using a LabelEncoder during fit, then converting the predicted labels back to the original label dtype in predict.

This PR adds support for this in cuml, improving compatibility with sklearn. This required:

  • A small improvement to LabelEncoder to support all the input types that cuml natively supports.
  • Addition of a LabelEncoder in LogisticRegression.fit
  • Changing how we store classes_. Previously this used the descriptor functionality to support different container types. However, CumlArray/cupy/numba don't support non-numeric types, so we can't use that to store the classes anymore. We now always store classes_ as a numpy array. I think this is fine - the size of classes is small - and also makes us a bit more compatible with sklearn since we can better ensure our dtypes match theirs.
  • Changing predict to convert the numeric output back into the original classes. This was complicated since CumlArray/cupy/numba don't support non-numeric types, which means that most of our existing output_type machinery fails for these cases. I hacked something in that I think is sufficient, but it's definitely a hack.
  • Addition of a new test to check things work properly across dtypes and container types.

This is an alternative to #6328. The fix here doesn't add any additional state, and I believe the test cases added here provide better coverage of the behavior we're trying to ensure.

@jcrist jcrist requested a review from a team as a code owner February 20, 2025 19:32
@jcrist jcrist requested review from csadorf and vyasr February 20, 2025 19:32
@github-actions github-actions bot added the Cython / Python Cython or Python issue label Feb 20, 2025
if is_numeric:
if (self.classes_ == np.arange(nclasses)).all():
# Fast path for common case of monotonically increasing numeric classes
out = indices.to_output("cupy", output_dtype=self.classes_.dtype)
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Note that any code working prior to this change will take this path.

@jcrist jcrist added improvement Improvement / enhancement to an existing function cuml-cpu non-breaking Non-breaking change labels Feb 20, 2025
@jcrist jcrist force-pushed the logreg-complex-classes branch from f36aaf7 to 8eb40ba Compare February 20, 2025 19:35
Copy link
Contributor

@csadorf csadorf left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

From my initial review, this looks overall very sound. It's a bit of hack to work around the CumlArray limitation of course, but that was expected and we can of course address that moving forward.

I have not had the chance yet to investigate whether our tests fully cover cases where we call LinearRegression.predict() internally and whether we maintain behavior in that case. Since we are replicating the api decorator in predict() there is a small chance that we are not covering all edge cases. That is not the only reason I'm not approving just yet.

@jcrist
Copy link
Member Author

jcrist commented Feb 20, 2025

Looks like there's some test failures that I missed fixing locally. Most look pretty straightforward - if we want to still try and get this in pre-patch I can work on resolving these tomorrow.

Copy link
Contributor

@viclafargue viclafargue left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM, would just add a small test to prevent any regression for issue https://github.com/rapidsai/cuml-accel/issues/94

@jcrist
Copy link
Member Author

jcrist commented Feb 21, 2025

That case should already sufficiently tested by test_logistic_regression_complex_classes added here (the int32 and float32 cases to be specific)

@csadorf
Copy link
Contributor

csadorf commented Feb 21, 2025

Looks like there's some test failures that I missed fixing locally. Most look pretty straightforward - if we want to still try and get this in pre-patch I can work on resolving these tomorrow.

@jcrist and I just had a brief offline chat and agreed that it's worth trying to address the issues before code freeze.

Previously assumptions were made that prevented supporting all the
possible input types `cuml` normally supports.

`LabelEncoder` should probably be fixed to play nicely with cuml's
output type handling, but that issue is beyond the requirements of this
PR.
Scikit-learn's `LogisticRegression` contains support for non-trivial
classes (those that would typically require encoding before processing).

This PR adds support for that in both `fit` and `predict`. This is
complicated by `CumlArray`/`cupy`/`numba` not supporting non-numeric
types, which means we need to special case the output handling in
`predict`. It's gross, but functional.
@jcrist jcrist force-pushed the logreg-complex-classes branch from 8eb40ba to 17c4d07 Compare February 21, 2025 22:08
@jcrist
Copy link
Member Author

jcrist commented Feb 21, 2025

I think I've fixed all the test failures, but who knows. Also added a fix for categorical y support, which was another sklearn compatibility bug we had.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
cuml-cpu Cython / Python Cython or Python issue improvement Improvement / enhancement to an existing function non-breaking Non-breaking change
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants