Commit c1cc67d

FEA Add QuantileRegressor estimator (#9978)

Authored by avidale (David Dale) and lorentzenchr (Christian Lorentzen)
Co-authored-by: David Dale <ddale@yandex-team.ru>
Co-authored-by: Christian Lorentzen <lorentzen.ch@gmail.com>

1 parent 88be3c1 commit c1cc67d

File tree

7 files changed: +729 -0 lines changed

doc/modules/classes.rst
+1 line changed: 1 addition & 0 deletions

@@ -839,6 +839,7 @@ Any estimator using the Huber loss would also be robust to outliers, e.g.
    :template: class.rst
 
    linear_model.HuberRegressor
+   linear_model.QuantileRegressor
    linear_model.RANSACRegressor
    linear_model.TheilSenRegressor
doc/modules/linear_model.rst
+77 lines changed: 77 additions & 0 deletions

@@ -1423,6 +1423,83 @@ Note that this estimator is different from the R implementation of Robust Regression
 squares implementation with weights given to each sample on the basis of how much the residual is
 greater than a certain threshold.
 
+.. _quantile_regression:
+
+Quantile Regression
+===================
+
+Quantile regression estimates the median or other quantiles of :math:`y`
+conditional on :math:`X`, while ordinary least squares (OLS) estimates the
+conditional mean.
+
+As a linear model, the :class:`QuantileRegressor` gives linear predictions
+:math:`\hat{y}(w, X) = Xw` for the :math:`q`-th quantile, :math:`q \in (0, 1)`.
+The weights or coefficients :math:`w` are then found by the following
+minimization problem:
+
+.. math::
+    \min_{w} {\frac{1}{n_{\text{samples}}}
+    \sum_i PB_q(y_i - X_i w) + \alpha ||w||_1}.
+
+This consists of the pinball loss (also known as linear loss),
+see also :func:`~sklearn.metrics.mean_pinball_loss`,
+
+.. math::
+    PB_q(t) = q \max(t, 0) + (1 - q) \max(-t, 0) =
+    \begin{cases}
+        q t, & t > 0, \\
+        0, & t = 0, \\
+        (q - 1) t, & t < 0
+    \end{cases}
+
+and the L1 penalty controlled by parameter ``alpha``, similar to
+:class:`Lasso`.
+
+As the pinball loss is only linear in the residuals, quantile regression is
+much more robust to outliers than squared error based estimation of the mean.
+Somewhat in between is the :class:`HuberRegressor`.
+
+Quantile regression may be useful if one is interested in predicting an
+interval instead of a point prediction. Sometimes, prediction intervals are
+calculated based on the assumption that the prediction error is normally
+distributed with zero mean and constant variance. Quantile regression
+provides sensible prediction intervals even for errors with non-constant
+(but predictable) variance or a non-normal distribution.
+
+.. figure:: /auto_examples/linear_model/images/sphx_glr_plot_quantile_regression_001.png
+   :target: ../auto_examples/linear_model/plot_quantile_regression.html
+   :align: center
+   :scale: 50%
+
+Based on minimizing the pinball loss, conditional quantiles can also be
+estimated by models other than linear models. For example,
+:class:`~sklearn.ensemble.GradientBoostingRegressor` can predict conditional
+quantiles if its parameter ``loss`` is set to ``"quantile"`` and parameter
+``alpha`` is set to the quantile that should be predicted. See the example in
+:ref:`sphx_glr_auto_examples_ensemble_plot_gradient_boosting_quantile.py`.
+
+Most implementations of quantile regression are based on a linear
+programming formulation of the problem. The current implementation is based
+on :func:`scipy.optimize.linprog`.
+
+.. topic:: Examples:
+
+  * :ref:`sphx_glr_auto_examples_linear_model_plot_quantile_regression.py`
+
+.. topic:: References:
+
+  * Koenker, R., & Bassett Jr, G. (1978). `Regression quantiles.
+    <https://gib.people.uic.edu/RQ.pdf>`_
+    Econometrica: Journal of the Econometric Society, 33-50.
+
+  * Portnoy, S., & Koenker, R. (1997). The Gaussian hare and the Laplacian
+    tortoise: computability of squared-error versus absolute-error estimators.
+    Statistical Science, 12, 279-300. https://doi.org/10.1214/ss/1030037960
+
+  * Koenker, R. (2005). Quantile Regression.
+    Cambridge University Press. https://doi.org/10.1017/CBO9780511754098
+
 .. _polynomial_regression:
 
 Polynomial regression: extending linear models with basis functions
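
The documentation above notes that the implementation reduces quantile regression to a linear program solved by scipy.optimize.linprog. A rough sketch of that reduction (the standard textbook LP formulation, not the code added in this commit): the residual is split as y - Xw = u - v with u, v >= 0, so the pinball loss becomes q * sum(u) + (1 - q) * sum(v).

import numpy as np
from scipy.optimize import linprog

rng = np.random.RandomState(0)
n, q = 50, 0.8
X = np.c_[np.ones(n), rng.uniform(0, 10, size=n)]  # intercept + one feature
y = 1.0 + 0.5 * X[:, 1] + rng.normal(size=n)

n_features = X.shape[1]
# Variables: [w (free), u >= 0, v >= 0]; objective = mean pinball loss.
c = np.concatenate(
    [np.zeros(n_features), q * np.ones(n), (1 - q) * np.ones(n)]
) / n
A_eq = np.concatenate([X, np.eye(n), -np.eye(n)], axis=1)  # X w + u - v = y
bounds = [(None, None)] * n_features + [(0, None)] * (2 * n)
res = linprog(c, A_eq=A_eq, b_eq=y, bounds=bounds, method="highs")
print(res.x[:n_features])  # coefficients of the fitted 0.8-quantile line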

doc/whats_new/v1.0.rst
+5 lines changed: 5 additions & 0 deletions

@@ -282,6 +282,11 @@ Changelog
 :mod:`sklearn.linear_model`
 ...........................
 
+- |Feature| Added :class:`linear_model.QuantileRegressor` which implements
+  linear quantile regression with L1 penalty.
+  :pr:`9978` by :user:`David Dale <avidale>` and
+  :user:`Christian Lorentzen <lorentzenchr>`.
+
 - |Feature| The new :class:`linear_model.SGDOneClassSVM` provides an SGD
   implementation of the linear One-Class SVM. Combined with kernel
   approximation techniques, this implementation approximates the solution of
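
As a quick illustration of the feature described in this changelog entry (a hedged sketch, assuming a scikit-learn build that contains this commit): the L1 penalty ``alpha`` shrinks coefficients toward zero, as in :class:`Lasso`.

import numpy as np
from sklearn.linear_model import QuantileRegressor

rng = np.random.RandomState(0)
X = rng.normal(size=(200, 5))
y = 3 * X[:, 0] + rng.standard_cauchy(size=200)  # heavy-tailed noise

# Increasing alpha strengthens the L1 penalty and sparsifies coef_.
for alpha in [0.0, 0.1, 1.0]:
    reg = QuantileRegressor(quantile=0.5, alpha=alpha).fit(X, y)
    print(alpha, np.round(reg.coef_, 2))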
examples/linear_model/plot_quantile_regression.py
+110 lines changed: 110 additions & 0 deletions

"""
===================
Quantile regression
===================

This example illustrates how quantile regression can predict non-trivial
conditional quantiles.

The left figure shows the case when the error distribution is normal,
but has non-constant variance, i.e. with heteroscedasticity.

The right figure shows an example of an asymmetric error distribution,
namely the Pareto distribution.
"""
print(__doc__)

# Authors: David Dale <dale.david@mail.ru>
#          Christian Lorentzen <lorentzen.ch@gmail.com>
# License: BSD 3 clause
import numpy as np
import matplotlib.pyplot as plt

from sklearn.linear_model import QuantileRegressor, LinearRegression
from sklearn.metrics import mean_absolute_error, mean_squared_error
from sklearn.model_selection import cross_val_score


def plot_points_highlighted(x, y, model_low, model_high, ax):
    """Scatter (x, y); mark points outside the outer quantile predictions."""
    X = x[:, np.newaxis]
    low, high = model_low.predict(X), model_high.predict(X)
    mask = y <= low
    ax.scatter(x[mask], y[mask], c="k", marker="x")
    mask = y > high
    ax.scatter(x[mask], y[mask], c="k", marker="x")
    mask = (y > low) & (y <= high)
    ax.scatter(x[mask], y[mask], c="k")


fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(15, 5), sharey=True)

rng = np.random.RandomState(42)
x = np.linspace(0, 10, 100)
X = x[:, np.newaxis]
y = 10 + 0.5 * x + rng.normal(loc=0, scale=0.5 + 0.5 * x, size=x.shape[0])
y_mean = 10 + 0.5 * x
ax1.plot(x, y_mean, "k--")

quantiles = [0.05, 0.5, 0.95]
models = []
for quantile in quantiles:
    qr = QuantileRegressor(quantile=quantile, alpha=0)
    qr.fit(X, y)
    ax1.plot(x, qr.predict(X))
    models.append(qr)

plot_points_highlighted(x, y, models[0], models[2], ax1)
ax1.set_xlabel("x")
ax1.set_ylabel("y")
ax1.set_title("Quantiles of heteroscedastic Normal distributed target")
ax1.legend(["true mean"] + quantiles)


# Same mean as above: E[pareto(a)] = 1 / (a - 1), so the shift cancels it.
a = 5
y = 10 + 0.5 * x + 10 * (rng.pareto(a, size=x.shape[0]) - 1 / (a - 1))
ax2.plot(x, y_mean, "k--")

models = []
for quantile in quantiles:
    qr = QuantileRegressor(quantile=quantile, alpha=0)
    qr.fit(X, y)
    # Predictions are linear, so two points suffice to draw the line.
    ax2.plot([0, 10], qr.predict([[0], [10]]))
    models.append(qr)

plot_points_highlighted(x, y, models[0], models[2], ax2)
ax2.set_xlabel("x")
ax2.set_ylabel("y")
ax2.set_title("Quantiles of asymmetric Pareto distributed target")
ax2.legend(["true mean"] + quantiles, loc="lower right")
ax2.yaxis.set_tick_params(labelbottom=True)

plt.show()

# %%
# Note that both targets have the same mean value, indicated by the dashed
# black line. As the Normal distribution is symmetric, mean and median are
# identical and the predicted 0.5 quantile almost hits the true mean.
# In the Pareto case, the difference between the predicted median and the
# true mean is evident. We also marked the points below the 0.05 and above
# the 0.95 predicted quantiles with small crosses. You may count them,
# bearing in mind that there are 100 samples in total.
#
# The second part of the example shows that LinearRegression minimizes MSE
# in order to predict the mean, while QuantileRegressor with `quantile=0.5`
# minimizes MAE in order to predict the median. Both do their own job well.

models = [LinearRegression(), QuantileRegressor(alpha=0)]
names = ["OLS", "Quantile"]

print("# In-sample performance")
for model_name, model in zip(names, models):
    print(model_name + ":")
    model.fit(X, y)
    mae = mean_absolute_error(model.predict(X), y)
    rmse = np.sqrt(mean_squared_error(model.predict(X), y))
    print(f"MAE = {mae:.4} RMSE = {rmse:.4}")

print("\n# Cross-validated performance")
for model_name, model in zip(names, models):
    print(model_name + ":")
    mae = -cross_val_score(model, X, y, cv=3,
                           scoring="neg_mean_absolute_error").mean()
    rmse = np.sqrt(-cross_val_score(model, X, y, cv=3,
                                    scoring="neg_mean_squared_error").mean())
    print(f"MAE = {mae:.4} RMSE = {rmse:.4}")

sklearn/linear_model/__init__.py
+2 lines changed: 2 additions & 0 deletions

@@ -28,6 +28,7 @@
 from ._passive_aggressive import PassiveAggressiveRegressor
 from ._perceptron import Perceptron
 
+from ._quantile import QuantileRegressor
 from ._ransac import RANSACRegressor
 from ._theil_sen import TheilSenRegressor
 
@@ -59,6 +60,7 @@
     'PassiveAggressiveClassifier',
     'PassiveAggressiveRegressor',
     'Perceptron',
+    'QuantileRegressor',
     'Ridge',
     'RidgeCV',
     'RidgeClassifier',
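With these two added lines, the estimator becomes part of the public sklearn.linear_model namespace. A quick smoke test (assuming an installed build containing this commit):

from sklearn.linear_model import QuantileRegressor

print(QuantileRegressor())  # prints the estimator repr, QuantileRegressor()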