diff --git a/README.rst b/README.rst
index 96bc81fa3822d..5be5e3fea7355 100644
--- a/README.rst
+++ b/README.rst
@@ -58,7 +58,7 @@ scikit-learn 0.23 and later require Python 3.6 or newer.
 Scikit-learn plotting capabilities (i.e., functions start with ``plot_``
 and classes end with "Display") require Matplotlib (>= 2.1.1). For running the
 examples Matplotlib >= 2.1.1 is required. A few examples require
-scikit-image >= 0.13, a few examples require pandas >= 0.18.0.
+scikit-image >= 0.13, a few examples require pandas >= 0.21.0.
 
 User installation
 ~~~~~~~~~~~~~~~~~
diff --git a/doc/install.rst b/doc/install.rst
index 6a2b83605c1a6..31ecfeffc7d44 100644
--- a/doc/install.rst
+++ b/doc/install.rst
@@ -134,7 +134,7 @@ it as ``scikit-learn[alldeps]``.
 Scikit-learn plotting capabilities (i.e., functions start with "plot\_"
 and classes end with "Display") require Matplotlib (>= 2.1.1). For running the
 examples Matplotlib >= 2.1.1 is required. A few examples require
-scikit-image >= 0.13, a few examples require pandas >= 0.18.0.
+scikit-image >= 0.13, a few examples require pandas >= 0.21.0.
 
 .. warning::
diff --git a/doc/modules/ensemble.rst b/doc/modules/ensemble.rst
index deefaaf642c39..3a2f816f98648 100644
--- a/doc/modules/ensemble.rst
+++ b/doc/modules/ensemble.rst
@@ -109,10 +109,10 @@ set of classifiers is created by introducing randomness in the classifier
 construction.  The prediction of the ensemble is given as the averaged
 prediction of the individual classifiers.
 
-As other classifiers, forest classifiers have to be fitted with two
-arrays: a sparse or dense array X of size ``[n_samples, n_features]`` holding the
-training samples, and an array Y of size ``[n_samples]`` holding the
-target values (class labels) for the training samples::
+As other classifiers, forest classifiers have to be fitted with two arrays: a
+sparse or dense array X of size ``[n_samples, n_features]`` holding the
+training samples, and an array Y of size ``[n_samples]`` holding the target
+values (class labels) for the training samples::
 
     >>> from sklearn.ensemble import RandomForestClassifier
     >>> X = [[0, 0], [1, 1]]
@@ -200,6 +200,9 @@ in bias::
 Parameters
 ----------
 
+Impactful parameters
+....................
+
 The main parameters to adjust when using these methods is ``n_estimators`` and
 ``max_features``. The former is the number of trees in the forest. The larger
 the better, but also the longer it will take to compute. In addition, note that
@@ -223,10 +226,50 @@ or out-of-bag samples. This can be enabled by setting ``oob_score=True``.
 
 .. note::
 
-    The size of the model with the default parameters is :math:`O( M * N * log (N) )`,
-    where :math:`M` is the number of trees and :math:`N` is the number of samples.
-    In order to reduce the size of the model, you can change these parameters:
-    ``min_samples_split``, ``max_leaf_nodes``, ``max_depth`` and ``min_samples_leaf``.
+    The size of the model with the default parameters is
+    :math:`O(M * N * log(N))`, where :math:`M` is the number of trees and
+    :math:`N` is the number of samples. In order to reduce the size of the
+    model, you can change these parameters: ``min_samples_split``,
+    ``max_leaf_nodes``, ``max_depth`` and ``min_samples_leaf``.
+
+.. _balanced_bootstrap:
+
+Learning from imbalanced datasets
+.................................
+
+In some datasets, the number of samples per class might vary tremendously
+(e.g., 100 samples in a "majority" class versus a single sample in a
+"minority" class). Learning from such imbalanced datasets is challenging. The
+tree criteria (i.e., gini or entropy) are sensitive to class imbalance and
+will naturally favor the classes with the most samples given during ``fit``.
+
+The :class:`RandomForestClassifier` provides a parameter `class_weight` with
+the option `"balanced_bootstrap"` to alleviate the bias induced by the class
+imbalance. This strategy will create a bootstrap subsample for the "minority"
+class and draw with replacement the same amount of training instances from the
+other classes. Each tree of the ensemble is then fitted on its own balanced
+subsample, as proposed in [CLB2004]_. This algorithm is also called balanced
+random-forest.
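+
+For instance, a balanced random-forest can be fitted as follows (a minimal
+sketch on a small synthetic imbalanced dataset)::
+
+    >>> from sklearn.datasets import make_classification
+    >>> from sklearn.ensemble import RandomForestClassifier
+    >>> X, y = make_classification(n_samples=1000, weights=[0.95, 0.05],
+    ...                            random_state=0)
+    >>> clf = RandomForestClassifier(
+    ...     class_weight="balanced_bootstrap", random_state=0
+    ... ).fit(X, y)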
+
+`class_weight="balanced"` and `class_weight="balanced_subsample"` provide
+alternative balancing strategies which are not as effective when there is a
+large difference between the class frequencies.
+
+.. note::
+    Be aware that `sample_weight` is still taken into account when
+    `class_weight="balanced_bootstrap"` is set. Thus, it is not recommended to
+    manually balance the dataset with `sample_weight` and to use
+    `class_weight="balanced_bootstrap"` at the same time.
+
+.. topic:: Examples:
+
+ * :ref:`sphx_glr_auto_examples_plot_learn_from_imbalanced_dataset.py`
+
+.. topic:: References
+
+ .. [CLB2004] C. Chen, A. Liaw, and L. Breiman, "Using random forest to learn
+    imbalanced data", Tech. Rep. 666, University of California, Berkeley,
+    2004.
 
 Parallelization
 ---------------
diff --git a/doc/themes/scikit-learn-modern/static/css/theme.css b/doc/themes/scikit-learn-modern/static/css/theme.css
index a77fb03e36f65..2b80d6fe2b762 100644
--- a/doc/themes/scikit-learn-modern/static/css/theme.css
+++ b/doc/themes/scikit-learn-modern/static/css/theme.css
@@ -963,6 +963,44 @@ div.sphx-glr-thumbcontainer {
   }
 }
 
+/* Pandas dataframe css */
+/* Taken from: https://github.com/spatialaudio/nbsphinx/blob/fb3ba670fc1ba5f54d4c487573dbc1b4ecf7e9ff/src/nbsphinx.py#L587-L619 */
+/* FIXME: to be removed when sphinx-gallery >= 5.0 is released */
+
+table.dataframe {
+  border: none !important;
+  border-collapse: collapse;
+  border-spacing: 0;
+  border-color: transparent;
+  color: black;
+  font-size: 12px;
+  table-layout: fixed;
+}
+table.dataframe thead {
+  border-bottom: 1px solid black;
+  vertical-align: bottom;
+}
+table.dataframe tr,
+table.dataframe th,
+table.dataframe td {
+  text-align: right;
+  vertical-align: middle;
+  padding: 0.5em 0.5em;
+  line-height: normal;
+  white-space: normal;
+  max-width: none;
+  border: none;
+}
+table.dataframe th {
+  font-weight: bold;
+}
+table.dataframe tbody tr:nth-child(odd) {
+  background: #f5f5f5;
+}
+table.dataframe tbody tr:hover {
+  background: rgba(66, 165, 245, 0.2);
+}
+
 /* rellinks */
 
 .sk-btn-rellink {
diff --git a/doc/whats_new/v0.23.rst b/doc/whats_new/v0.23.rst
index a76fa442db7c4..49accca26465b 100644
--- a/doc/whats_new/v0.23.rst
+++ b/doc/whats_new/v0.23.rst
@@ -48,9 +48,24 @@ Changelog
 :mod:`sklearn.cluster`
 ......................
 
+- |Fix| example fix in model XXX. :pr:`xxxx` or :issue:`xxxx` by
+  :user:`name `
+
 - |Enhancement| :class:`cluster.AgglomerativeClustering` has a faster and more
   more memory efficient implementation of single linkage clustering. :pr:`11514`
   by :user:`Leland McInnes `.
+
+:mod:`sklearn.ensemble`
+.......................
+
+- |Enhancement| Add the option `class_weight="balanced_bootstrap"` in
+  :class:`ensemble.RandomForestClassifier`.
This option will ensure that each + tree is trained on a subsample with equal number of instances from each + class. This algorithm is known as balanced-random forest. + :pr:`13227` by :user:`Eric Potash `, :user:`Christos Aridas ` + and :user:`Guillaume Lemaitre `. + - |Fix| :class:`cluster.KMeans` with ``algorithm="elkan"`` now converges with ``tol=0`` as with the default ``algorithm="full"``. :pr:`16075` by :user:`Erich Schubert `. diff --git a/examples/plot_learn_from_imbalanced_dataset.py b/examples/plot_learn_from_imbalanced_dataset.py new file mode 100644 index 0000000000000..5c6c38cabb816 --- /dev/null +++ b/examples/plot_learn_from_imbalanced_dataset.py @@ -0,0 +1,283 @@ +""" +============================== +Learn from imbalanced datasets +============================== + +This example illustrates the problem induced by learning on datasets having +imbalanced classes. Subsequently, we compare different approaches alleviating +these negative effects using other estimators. + +""" + +# Authors: Guillaume Lemaitre +# License: MIT + +print(__doc__) + +############################################################################### +# Problem definition +############################################################################### + +from sklearn.datasets import fetch_openml + +df, y = fetch_openml('adult', version=2, as_frame=True, return_X_y=True) +# we are dropping the following features: +# - "fnlwgt": this feature was created while studying the "adult" dataset. +# Thus, we will not use this feature which is not acquired during the survey. +# - "education-num": it is encoding the same information than "education". +# Thus, we are removing one of these 2 features. +df = df.drop(columns=['fnlwgt', 'education-num']) + +############################################################################### +# The "adult" dataset as a class ratio of about 3:1 + +classes_count = y.value_counts() +classes_count + +############################################################################### +# This dataset is only slightly imbalanced. To better highlight the effect of +# learning from an imbalanced dataset, we will increase its ratio to 100:1 + +import numpy as np +import pandas as pd + +rng = np.random.RandomState(0) + +# we define a ratio 100:1 +n_samples_minority_class = classes_count.max() // 100 + +mask_minority_class = y == classes_count.idxmin() +indices_minority_class = np.flatnonzero(mask_minority_class) +indices_minority_class_subsampled = rng.choice( + indices_minority_class, size=n_samples_minority_class, replace=False +) + +# sample the dataframe +df_res = pd.concat([df.loc[~mask_minority_class, :], + df.loc[indices_minority_class_subsampled, :]]) +# sample the target +y_res = pd.concat([y.loc[~mask_minority_class], + y.loc[indices_minority_class_subsampled]]) +y_res.value_counts() + +############################################################################### +# For the rest of the notebook, we will make a single split to get training +# and testing data. Note that in practise, you should always use +# cross-validation to have an estimate of the performance variation. 
You can +# refer to the following example showing how to use a scikit-learn pipeline +# within a grid-search: +# :ref:`sphx_glr_auto_examples_compose_plot_compare_reduction.py` + +from sklearn.model_selection import train_test_split + +X_train, X_test, y_train, y_test = train_test_split( + df_res, y_res, stratify=y_res, random_state=42 +) + +############################################################################### +# As a baseline, we use a classifier which will always predict the majority +# class independently of the features provided. + +from sklearn.dummy import DummyClassifier + +dummy_clf = DummyClassifier(strategy="most_frequent") +score = dummy_clf.fit(X_train, y_train).score(X_test, y_test) +print(f"Accuracy score of a dummy classifier: {score:.3f}") + +############################################################################## +# Instead of using the accuracy, we can use the balanced accuracy which will +# take into account the balancing issue. + +from sklearn.metrics import balanced_accuracy_score + +y_pred = dummy_clf.predict(X_test) +score = balanced_accuracy_score(y_test, y_pred) +print(f"Balanced accuracy score of a dummy classifier: {score:.3f}") + +############################################################################### +# Strategies to learn from an imbalanced dataset +############################################################################### + +############################################################################### +# We will first define a helper function which will train a given model +# and compute both accuracy and balanced accuracy. The results will be stored +# in a dataframe + + +def evaluate_classifier(clf, clf_name=None): + from sklearn.pipeline import Pipeline + if clf_name is None: + if isinstance(clf, Pipeline): + clf_name = clf[-1].__class__.__name__ + else: + clf_name = clf.__class__.__name__ + acc = clf.fit(X_train, y_train).score(X_test, y_test) + y_pred = clf.predict(X_test) + bal_acc = balanced_accuracy_score(y_test, y_pred) + clf_score = pd.DataFrame( + {"Accuracy": acc, "Balanced accuracy": bal_acc}, + index=[clf_name] + ) + # to avoid passing df_scores and returning it, we make it a global variable + global df_scores + df_scores = pd.concat([df_scores, clf_score], axis=0).round(decimals=3) + + +# Let's define an empty dataframe to store the results +df_scores = pd.DataFrame() + +############################################################################### +# Dummy baseline +# .............. +# +# Before to train a real machine learning model, we put the +# :class:`sklearn.dummy.DummyClassifier`'s results in the dataframe as a +# baseline. + +evaluate_classifier(dummy_clf) +df_scores + +############################################################################### +# Linear classifier baseline +# .......................... +# +# We will create a machine learning pipeline using a +# :class:`sklearn.linear_model.LogisticRegression` classifier. In this regard, +# we will need to one-hot encode the categorical columns and standardized the +# numerical columns before to inject the data into the +# :class:`sklearn.linear_model.LogisticRegression` classifier. +# +# First, we preprocess the data using imputation of the missing values and +# one-hot encoding for the categorical features and a standard scaler for the +# numerical ones. 
+
+from sklearn.impute import SimpleImputer
+from sklearn.preprocessing import StandardScaler
+from sklearn.preprocessing import OneHotEncoder
+from sklearn.pipeline import make_pipeline
+
+num_pipe = make_pipeline(
+    StandardScaler(), SimpleImputer(strategy="mean", add_indicator=True)
+)
+cat_pipe = make_pipeline(
+    SimpleImputer(strategy="constant", fill_value="missing"),
+    OneHotEncoder(handle_unknown="ignore")
+)
+
+###############################################################################
+# Then, we can create a preprocessor which will dispatch the categorical
+# columns to the categorical pipeline and the numerical columns to the
+# numerical pipeline.
+
+from sklearn.compose import ColumnTransformer
+from sklearn.compose import make_column_selector as selector
+
+preprocessor_linear = ColumnTransformer(
+    [("num-pipe", num_pipe, selector(dtype_include=np.number)),
+     ("cat-pipe", cat_pipe, selector(dtype_include="category"))],
+    n_jobs=2
+)
+
+###############################################################################
+# Finally, we connect our preprocessor with our `LogisticRegression`. We can
+# then evaluate our model.
+
+from sklearn.linear_model import LogisticRegression
+
+lr_clf = make_pipeline(
+    preprocessor_linear, LogisticRegression(max_iter=1000)
+)
+evaluate_classifier(lr_clf)
+df_scores
+
+###############################################################################
+# We can see that our linear model is learning slightly better than our dummy
+# baseline. However, it is impacted by the class imbalance.
+#
+# We can verify that something similar is happening with a tree-based model
+# such as :class:`sklearn.ensemble.RandomForestClassifier`. Note that we don't
+# need to scale the data for the tree-based models. We also use an ordinal
+# encoder instead of a one-hot encoder. This encoding is more efficient since
+# it does not create extra columns that the trees would have to scan when
+# searching for the best split.
+
+from sklearn.preprocessing import OrdinalEncoder
+from sklearn.ensemble import RandomForestClassifier
+
+cat_pipe = make_pipeline(
+    SimpleImputer(strategy="constant", fill_value="missing"),
+    OrdinalEncoder()
+)
+
+preprocessor_tree = ColumnTransformer(
+    [("cat-pipe", cat_pipe, selector(dtype_include="category"))],
+    remainder="passthrough",
+    n_jobs=2
+)
+
+rf_clf = make_pipeline(
+    preprocessor_tree, RandomForestClassifier(random_state=42, n_jobs=2)
+)
+
+evaluate_classifier(rf_clf)
+df_scores
+
+###############################################################################
+# The :class:`sklearn.ensemble.RandomForestClassifier` is also affected by the
+# class imbalance, although slightly less than the linear model. We will now
+# present different approaches to improve the performance of these two models.
+#
+# Use `class_weight`
+# ..................
+#
+# Most of the models in `scikit-learn` have a `class_weight` parameter. This
+# parameter affects the computation of the loss in the linear model, or of the
+# criterion in the tree-based model, so that misclassifying a sample from the
+# minority class is penalized differently from misclassifying a sample from
+# the majority class. We can set `class_weight="balanced"` such that the
+# weight applied to each sample is inversely proportional to its class
+# frequency, as illustrated in the sketch of the next cell.
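+
+###############################################################################
+# As a side note (this cell is only illustrative and not required for the rest
+# of the example), the sketch below shows the average weight that
+# `class_weight="balanced"` would assign to the samples of each class: it is
+# inversely proportional to the class frequencies observed in `y_train`.
+
+from sklearn.utils.class_weight import compute_sample_weight
+
+balanced_weights = pd.Series(
+    compute_sample_weight(class_weight="balanced", y=y_train),
+    index=y_train.index
+)
+# average weight received by the samples of each class
+balanced_weights.groupby(y_train).mean()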
+
+###############################################################################
+# We now test the `class_weight="balanced"` parametrization with both the
+# linear model and the tree-based model.
+
+lr_clf.set_params(logisticregression__class_weight="balanced")
+evaluate_classifier(
+    lr_clf, "LogisticRegression with class weight='balanced'"
+)
+df_scores
+
+###############################################################################
+# This weighting strategy is particularly effective for the logistic
+# regression: the balanced accuracy increased significantly.
+
+rf_clf.set_params(randomforestclassifier__class_weight="balanced")
+evaluate_classifier(
+    rf_clf, "RandomForestClassifier with class weight='balanced'"
+)
+df_scores
+
+###############################################################################
+# However, the same weighting strategy is not as effective with the random
+# forest. Indeed, the chosen criterion (e.g. entropy) is known to be sensitive
+# to class imbalance.
+
+###############################################################################
+# From a random-forest toward a balanced random-forest
+# ....................................................
+#
+# One way to improve the accuracy of the tree-based method is to perform some
+# under-sampling such that each tree in the ensemble learns from a balanced
+# set.
+#
+# The :class:`sklearn.ensemble.RandomForestClassifier` provides the option
+# `class_weight="balanced_bootstrap"` such that each tree learns from a
+# bootstrap sample with an equal number of samples from each class. This
+# algorithm is also known as a balanced random-forest.
+
+rf_clf.set_params(randomforestclassifier__class_weight="balanced_bootstrap")
+evaluate_classifier(
+    rf_clf, "RandomForestClassifier with class_weight='balanced_bootstrap'"
+)
+df_scores
+
+###############################################################################
+# We can observe that taking a balanced bootstrap for each tree alleviates the
+# bias induced by the class imbalance in the random-forest.
diff --git a/sklearn/ensemble/_forest.py b/sklearn/ensemble/_forest.py
index f8141579d7f4f..c2e81985bdd9b 100644
--- a/sklearn/ensemble/_forest.py
+++ b/sklearn/ensemble/_forest.py
@@ -62,6 +62,7 @@ class calls the ``fit`` method of each sub-estimator on random samples
 from ..utils.fixes import _joblib_parallel_args
 from ..utils.multiclass import check_classification_targets
 from ..utils.validation import check_is_fitted, _check_sample_weight
+from ..utils.validation import column_or_1d
 
 
 __all__ = ["RandomForestClassifier",
@@ -112,21 +113,86 @@ def _get_n_samples_bootstrap(n_samples, max_samples):
     raise TypeError(msg.format(type(max_samples)))
 
 
-def _generate_sample_indices(random_state, n_samples, n_samples_bootstrap):
+def _get_class_distribution(y):
+    """Compute the class distribution and associated statistics.
+
+    Parameters
+    ----------
+    y : array-like of shape (n_samples,)
+        Targets.
+
+    Returns
+    -------
+    class_indices : dict
+        Dictionary where the key is the class label and the value is an array
+        of the indices of the samples belonging to this class.
+
+    class_counts : dict
+        Dictionary where the key is the class label and the value is the
+        number of samples in this class.
""" - Private function used to _parallel_build_trees function.""" + try: + y = column_or_1d(y) + except ValueError: + raise NotImplementedError( + "Balanced random-forest not yet implemented for multi-output" + ) - random_instance = check_random_state(random_state) - sample_indices = random_instance.randint(0, n_samples, n_samples_bootstrap) + classes = np.unique(y) + class_indices = {klass: np.flatnonzero(y == klass) for klass in classes} + class_counts = {klass: len(indices) + for klass, indices in class_indices.items()} + + return class_indices, class_counts + + +def _generate_sample_indices(random_state, n_samples, n_samples_bootstrap, + balanced_bootstrap, y): + """Generate bootstrap samples. + + Parameters + ---------- + random_state : int, RandomState + Random state used in the different random draw. + n_samples : int + The number of samples in the dataset. + n_samples_bootstrap : int + The maximum number of samples required in the bootstrap sample. + balanced_bootstrap : bool + Whether or not the class counts should be balanced in the bootstrap + y : array-like of shape (n_samples,) or (n_samples, 1) + The array of targets used when a balanced bootstrap is requested. + Returns + ------- + samples_indices : ndarray of shape (n_bootstrap_sample,) + The indices of the bootstrap sample. + """ + rng = check_random_state(random_state) + if balanced_bootstrap: + class_indices, class_counts = _get_class_distribution(y) + n_classes = len(class_counts) + n_samples_per_class = min(min(class_counts.values()), + n_samples_bootstrap // n_classes) + sample_indices = np.empty( + n_classes * n_samples_per_class, dtype=class_indices[0].dtype + ) + for i, indices in enumerate(class_indices.values()): + sample_indices[i * n_samples_per_class: + (i + 1) * n_samples_per_class] = rng.choice( + indices, size=n_samples_per_class, replace=True + ) + else: + sample_indices = rng.randint(0, n_samples, n_samples_bootstrap) return sample_indices -def _generate_unsampled_indices(random_state, n_samples, n_samples_bootstrap): +def _generate_unsampled_indices(random_state, n_samples, n_samples_bootstrap, + balanced_bootstrap, y): """ Private function used to forest._set_oob_score function.""" - sample_indices = _generate_sample_indices(random_state, n_samples, - n_samples_bootstrap) + sample_indices = _generate_sample_indices( + random_state, n_samples, n_samples_bootstrap, balanced_bootstrap, y + ) sample_counts = np.bincount(sample_indices, minlength=n_samples) unsampled_mask = sample_counts == 0 indices_range = np.arange(n_samples) @@ -137,7 +203,7 @@ def _generate_unsampled_indices(random_state, n_samples, n_samples_bootstrap): def _parallel_build_trees(tree, forest, X, y, sample_weight, tree_idx, n_trees, verbose=0, class_weight=None, - n_samples_bootstrap=None): + n_samples_bootstrap=None, balanced_bootstrap=False): """ Private function used to fit a single tree in parallel.""" if verbose > 1: @@ -150,8 +216,11 @@ def _parallel_build_trees(tree, forest, X, y, sample_weight, tree_idx, n_trees, else: curr_sample_weight = sample_weight.copy() - indices = _generate_sample_indices(tree.random_state, n_samples, - n_samples_bootstrap) + indices = _generate_sample_indices( + tree.random_state, n_samples, n_samples_bootstrap, + balanced_bootstrap, y + ) + sample_counts = np.bincount(indices, minlength=n_samples) curr_sample_weight *= sample_counts @@ -374,6 +443,11 @@ def fit(self, X, y, sample_weight=None): random_state=random_state) for i in range(n_more_estimators)] + if self.class_weight == 'balanced_bootstrap': 
+ self._balanced_bootstrap = True + else: + self._balanced_bootstrap = False + # Parallel loop: we prefer the threading backend as the Cython code # for fitting the trees is internally releasing the Python GIL # making threading more efficient than multiprocessing in @@ -385,7 +459,8 @@ def fit(self, X, y, sample_weight=None): delayed(_parallel_build_trees)( t, self, X, y, sample_weight, i, len(trees), verbose=self.verbose, class_weight=self.class_weight, - n_samples_bootstrap=n_samples_bootstrap) + n_samples_bootstrap=n_samples_bootstrap, + balanced_bootstrap=self._balanced_bootstrap) for i, t in enumerate(trees)) # Collect newly grown trees @@ -522,7 +597,8 @@ def _set_oob_score(self, X, y): for estimator in self.estimators_: unsampled_indices = _generate_unsampled_indices( - estimator.random_state, n_samples, n_samples_bootstrap) + estimator.random_state, n_samples, n_samples_bootstrap, + self._balanced_bootstrap, y) p_estimator = estimator.predict_proba(X[unsampled_indices, :], check_input=False) @@ -554,6 +630,10 @@ def _set_oob_score(self, X, y): def _validate_y_class_weight(self, y): check_classification_targets(y) + class_weight = self.class_weight + if self.class_weight == 'balanced_bootstrap': + class_weight = None + y = np.copy(y) expanded_class_weight = None @@ -572,22 +652,25 @@ def _validate_y_class_weight(self, y): y = y_store_unique_indices if self.class_weight is not None: - valid_presets = ('balanced', 'balanced_subsample') + valid_presets = ('balanced', + 'balanced_subsample', + 'balanced_bootstrap') if isinstance(self.class_weight, str): if self.class_weight not in valid_presets: raise ValueError('Valid presets for class_weight include ' - '"balanced" and "balanced_subsample".' + '"balanced", "balanced_subsample", ' + 'and "balanced_bootstrap". ' 'Given "%s".' - % self.class_weight) + % class_weight) if self.warm_start: - warn('class_weight presets "balanced" or ' - '"balanced_subsample" are ' + warn('class_weight presets "balanced", ' + '"balanced_subsample", or "balanced_bootstrap" are ' 'not recommended for warm_start if the fitted data ' 'differs from the full dataset. In order to use ' - '"balanced" weights, use compute_class_weight ' - '("balanced", classes, y). In place of y you can use ' - 'a large enough sample of the full training set ' - 'target to properly estimate the class frequency ' + '"balanced" weights, use compute_class_weight(' + '"balanced", classes, y). In place of y you can use a' + 'large enough sample of the full training set target ' + 'to properly estimate the class frequency ' 'distributions. Pass the resulting weights as the ' 'class_weight parameter.') @@ -597,8 +680,9 @@ def _validate_y_class_weight(self, y): class_weight = "balanced" else: class_weight = self.class_weight - expanded_class_weight = compute_sample_weight(class_weight, - y_original) + if self.class_weight != 'balanced_bootstrap': + expanded_class_weight = compute_sample_weight(class_weight, + y_original) return y, expanded_class_weight @@ -815,7 +899,8 @@ def _set_oob_score(self, X, y): for estimator in self.estimators_: unsampled_indices = _generate_unsampled_indices( - estimator.random_state, n_samples, n_samples_bootstrap) + estimator.random_state, n_samples, n_samples_bootstrap, + self._balanced_bootstrap, y) p_estimator = estimator.predict( X[unsampled_indices, :], check_input=False) @@ -990,8 +1075,8 @@ class RandomForestClassifier(ForestClassifier): and add more estimators to the ensemble, otherwise, just fit a whole new forest. See :term:`the Glossary `. 
-    class_weight : {"balanced", "balanced_subsample"}, dict or list of dicts, \
-            default=None
+    class_weight : {"balanced", "balanced_subsample", "balanced_bootstrap"}, \
+            dict or list of dicts, default=None
         Weights associated with classes in the form ``{class_label: weight}``.
         If not given, all classes are supposed to have weight one. For
         multi-output problems, a list of dicts can be provided in the same
@@ -1011,11 +1096,19 @@ class RandomForestClassifier(ForestClassifier):
         weights are computed based on the bootstrap sample for every tree
         grown.
 
+        The "balanced_bootstrap" option triggers the balanced random-forest
+        [2]_. Instead of down-weighting the majority class(es), it
+        undersamples them. In this case, multi-output is not supported. You
+        can find more information in the :ref:`User Guide <balanced_bootstrap>`.
+
         For multi-output, the weights of each column of y will be multiplied.
 
         Note that these weights will be multiplied with sample_weight (passed
         through the fit method) if sample_weight is specified.
 
+        .. versionchanged:: 0.23
+           The option `"balanced_bootstrap"` was added.
+
     ccp_alpha : non-negative float, default=0.0
         Complexity parameter used for Minimal Cost-Complexity Pruning. The
         subtree with the largest cost complexity that is smaller than
@@ -1114,6 +1207,9 @@ class labels (multi-output problem).
 
     .. [1] L. Breiman, "Random Forests", Machine Learning, 45(1), 5-32, 2001.
 
+    .. [2] Chen, C., Liaw, A., Breiman, L. (2004). "Using Random Forest to
+       Learn Imbalanced Data", Tech. Rep. 666, University of California,
+       Berkeley.
+
     See Also
     --------
     DecisionTreeClassifier, ExtraTreesClassifier
@@ -1600,8 +1696,8 @@ class ExtraTreesClassifier(ForestClassifier):
         and add more estimators to the ensemble, otherwise, just fit a whole
         new forest. See :term:`the Glossary <warm_start>`.
 
-    class_weight : {"balanced", "balanced_subsample"}, dict or list of dicts, \
-            default=None
+    class_weight : {"balanced", "balanced_subsample", "balanced_bootstrap"}, \
+            dict or list of dicts, default=None
         Weights associated with classes in the form ``{class_label: weight}``.
         If not given, all classes are supposed to have weight one. For
         multi-output problems, a list of dicts can be provided in the same
@@ -1626,6 +1722,9 @@ class ExtraTreesClassifier(ForestClassifier):
         Note that these weights will be multiplied with sample_weight (passed
         through the fit method) if sample_weight is specified.
 
+        .. versionchanged:: 0.23
+           The option `"balanced_bootstrap"` was added.
+
     ccp_alpha : non-negative float, default=0.0
         Complexity parameter used for Minimal Cost-Complexity Pruning.
The subtree with the largest cost complexity that is smaller than diff --git a/sklearn/ensemble/tests/test_forest.py b/sklearn/ensemble/tests/test_forest.py index 9164751bac256..82cb78974c14c 100644 --- a/sklearn/ensemble/tests/test_forest.py +++ b/sklearn/ensemble/tests/test_forest.py @@ -39,12 +39,15 @@ from sklearn import datasets from sklearn.decomposition import TruncatedSVD from sklearn.datasets import make_classification +from sklearn.datasets import make_multilabel_classification from sklearn.ensemble import ExtraTreesClassifier from sklearn.ensemble import ExtraTreesRegressor from sklearn.ensemble import RandomForestClassifier from sklearn.ensemble import RandomForestRegressor from sklearn.ensemble import RandomTreesEmbedding +from sklearn.metrics import balanced_accuracy_score from sklearn.model_selection import GridSearchCV +from sklearn.model_selection import train_test_split from sklearn.svm import LinearSVC from sklearn.utils.validation import check_random_state from sklearn.utils.fixes import comb @@ -1307,6 +1310,70 @@ def test_forest_degenerate_feature_importances(): np.zeros(10, dtype=np.float64)) +def test_forest_balanced_bootstrap_not_implemented(): + # check that an error is raised for unknown target + X, y = make_multilabel_classification(random_state=0) + with pytest.raises(NotImplementedError, match="Balanced random-forest"): + RandomForestClassifier(class_weight="balanced_bootstrap").fit(X, y) + + +def test_forest_class_weight_balanced_bootstrap(): + # check the implementation of balanced random forest + X, y = datasets.make_classification( + n_samples=1000, weights=[0.95, 0.05], random_state=0 + ) + X_train, X_test, y_train, y_test = train_test_split( + X, y, stratify=y, random_state=0 + ) + + rf = RandomForestClassifier(n_estimators=10, random_state=0) + brf = RandomForestClassifier( + n_estimators=10, random_state=0, class_weight="balanced_bootstrap" + ) + + rf.fit(X_train, y_train) + brf.fit(X_train, y_train) + + # The random-forest will be affected by the imbalanced classes while the + # balanced random-forest will be able to fit a proper model + balanced_acc_rf = balanced_accuracy_score(y_test, rf.predict(X_test)) + balanced_acc_brf = balanced_accuracy_score(y_test, brf.predict(X_test)) + assert balanced_acc_brf > balanced_acc_rf + + # the balanced random-forest will be trained with much less samples than + # the random-forest. Thus, the number of nodes will be much less. 
+ for brf_est, rf_est in zip(brf.estimators_, rf.estimators_): + assert brf_est.tree_.node_count < rf_est.tree_.node_count + + +def test_forest_balanced_bootstrap_max_samples(): + # check that we take the minimum between max_samples and the minimum + # class_counts + X, y = datasets.make_classification( + n_samples=1000, weights=[0.8, 0.2], random_state=0 + ) + X_train, X_test, y_train, y_test = train_test_split( + X, y, stratify=y, random_state=0 + ) + + brf = RandomForestClassifier( + n_estimators=10, random_state=0, class_weight="balanced_bootstrap" + ) + brf_max_samples = RandomForestClassifier( + n_estimators=10, random_state=0, class_weight="balanced_bootstrap", + max_samples=60 + ) + + brf.fit(X_train, y_train) + brf_max_samples.fit(X_train, y_train) + + # the random forest with max_samples will have less training samples and + # therefore less number of nodes + for brf_est_max_samples, brf_est in zip(brf_max_samples.estimators_, + brf.estimators_): + assert brf_est_max_samples.tree_.node_count < brf_est.tree_.node_count + + @pytest.mark.parametrize('name', FOREST_CLASSIFIERS_REGRESSORS) @pytest.mark.parametrize( 'max_samples, exc_type, exc_msg',