Description
Describe the workflow you want to enable
The Naive Bayes classifiers (GaussianNB, MultinomialNB, BernoulliNB, ComplementNB, CategoricalNB) do not currently accept a class_weight parameter, although many other scikit-learn classifiers (for example LogisticRegression, SVC, and RandomForestClassifier) do. This makes the API inconsistent and forces users to compute sample_weight themselves and pass it to fit() when working with imbalanced datasets.
Current Workaround
from sklearn.naive_bayes import GaussianNB
from sklearn.utils.class_weight import compute_sample_weight
sample_weight = compute_sample_weight('balanced', y)
gnb = GaussianNB()
gnb.fit(X, y, sample_weight=sample_weight)
Proposed Enhancement
gnb = GaussianNB(class_weight='balanced')
gnb.fit(X, y)
Describe your proposed solution
Add a class_weight parameter to all Naive Bayes classifiers that internally converts class weights to sample weights before calling the existing fit() logic. This leverages the fact that all NB classifiers already support sample_weight in their fitting methods.
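To make the conversion step concrete, here is a minimal standard-library sketch of how class weights could be expanded into per-sample weights before fitting. The helper name class_weight_to_sample_weight is hypothetical and this is not scikit-learn's implementation; in practice sklearn.utils.class_weight.compute_sample_weight would likely do this work. The sketch assumes that labels missing from a class_weight dict get weight 1.0 and that any user-supplied sample_weight is multiplied with the class-derived weights.

```python
from collections import Counter


def class_weight_to_sample_weight(y, class_weight=None, sample_weight=None):
    """Expand per-class weights into one weight per sample (illustrative only).

    Hypothetical helper, not part of scikit-learn. `class_weight` may be
    None, the string "balanced", or a dict mapping class label -> weight.
    """
    y = list(y)
    n_samples = len(y)

    if class_weight is None:
        per_class = {}
    elif isinstance(class_weight, dict):
        per_class = class_weight
    elif class_weight == "balanced":
        # "balanced" heuristic: n_samples / (n_classes * count_of_class)
        counts = Counter(y)
        per_class = {
            label: n_samples / (len(counts) * count)
            for label, count in counts.items()
        }
    else:
        raise ValueError("class_weight must be None, 'balanced', or a dict")

    # Assumption of this sketch: labels absent from the mapping get weight 1.0.
    weights = [float(per_class.get(label, 1.0)) for label in y]

    if sample_weight is not None:
        if len(sample_weight) != n_samples:
            raise ValueError("sample_weight must have one entry per sample")
        # Combine with user-supplied weights by element-wise multiplication.
        weights = [w * s for w, s in zip(weights, sample_weight)]

    return weights
```

For y = [0, 0, 0, 1], the "balanced" mode gives each class-0 sample a weight of 2/3 and the single class-1 sample a weight of 2, so both classes carry the same total weight. The resulting list can be passed as sample_weight to the existing fit() logic.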
class GaussianNB(_BaseNB):
    # Interval and StrOptions are the validators from sklearn.utils._param_validation.
    _parameter_constraints: dict = {
        "priors": ["array-like", None],
        "var_smoothing": [Interval(Real, 0, None, closed="left")],
        "class_weight": [dict, StrOptions({"balanced"}), None],  # NEW
    }

    def __init__(self, *, priors=None, var_smoothing=1e-9, class_weight=None):
        self.priors = priors
        self.var_smoothing = var_smoothing
        self.class_weight = class_weight  # NEW
Describe alternatives you've considered, if relevant
No response
Additional context
This is an improvement for API consistency. Although the workaround exists, uniform behavior across classifiers reduces confusion and makes the option easier to discover, especially for beginners.