Commit 470763e

Merge branch 'master' of https://github.com/ddbourgin/numpy-ml into master

2 parents: 84f65b9 + 7c210a6

3 files changed: +168 lines, -18 lines
‎numpy_ml/linear_models/README.md

1 addition, 1 deletion
@@ -1,7 +1,7 @@
 # Linear Models
 The `lm.py` module implements:
 
-1. [OLS linear regression](https://en.wikipedia.org/wiki/Ordinary_least_squares) with maximum likelihood parameter estimates via the normal equation.
+1. [OLS linear regression](https://en.wikipedia.org/wiki/Ordinary_least_squares) with maximum likelihood parameter estimates via the normal equation, supporting both batch and online (recursive least-squares) updates.
 2. [Ridge regression / Tikhonov regularization](https://en.wikipedia.org/wiki/Tikhonov_regularization)
    with maximum likelihood parameter estimates via the normal equation.
 2. [Logistic regression](https://en.wikipedia.org/wiki/Logistic_regression) with maximum likelihood parameter estimates via gradient descent.
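As a usage sketch of the batch-plus-online workflow described by the updated README line, assuming the `LinearRegression` class, its new `update` method, and `predict` from the `lm.py` diff below (the toy arrays here are made up):

import numpy as np
from numpy_ml.linear_models.lm import LinearRegression

rng = np.random.RandomState(0)
X, y = rng.randn(50, 3), rng.randn(50, 1)   # made-up toy data

model = LinearRegression(fit_intercept=True)
model.fit(X[:40], y[:40])                   # batch fit via the normal equation

for x_new, y_new in zip(X[40:], y[40:]):    # online updates as new examples arrive
    model.update(x_new, y_new)

preds = model.predict(X)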
‎numpy_ml/linear_models/lm.py

109 additions, 17 deletions
@@ -11,23 +11,113 @@ def __init__(self, fit_intercept=True):
 
         Notes
         -----
-        Given data matrix *X* and target vector *y*, the maximum-likelihood estimate
-        for the regression coefficients, :math:`\\beta`, is:
+        Given data matrix **X** and target vector **y**, the maximum-likelihood
+        estimate for the regression coefficients, :math:`\beta`, is:
 
         .. math::
 
-            \hat{\beta} =
-                \left(\mathbf{X}^\top \mathbf{X}\right)^{-1} \mathbf{X}^\top \mathbf{y}
+            \hat{\beta} = \Sigma^{-1} \mathbf{X}^\top \mathbf{y}
+
+        where :math:`\Sigma^{-1} = (\mathbf{X}^\top \mathbf{X})^{-1}`.
 
         Parameters
         ----------
         fit_intercept : bool
-            Whether to fit an additional intercept term in addition to the
-            model coefficients. Default is True.
+            Whether to fit an intercept term in addition to the model
+            coefficients. Default is True.
         """
         self.beta = None
+        self.sigma_inv = None
         self.fit_intercept = fit_intercept
 
+        self._is_fit = False
+
+    def update(self, X, y):
+        r"""
+        Incrementally update the least-squares coefficients for a set of new
+        examples.
+
+        Notes
+        -----
+        The recursive least-squares algorithm [1]_ [2]_ is used to efficiently
+        update the regression parameters as new examples become available. For
+        a single new example :math:`(\mathbf{x}_{t+1}, \mathbf{y}_{t+1})`, the
+        parameter update is
+
+        .. math::
+
+            \beta_{t+1} = \left(
+                \mathbf{X}_{1:t}^\top \mathbf{X}_{1:t} +
+                \mathbf{x}_{t+1}\mathbf{x}_{t+1}^\top \right)^{-1}
+                \left( \mathbf{X}_{1:t}^\top \mathbf{Y}_{1:t} +
+                \mathbf{x}_{t+1}^\top \mathbf{y}_{t+1} \right)
+
+        where :math:`\beta_{t+1}` are the updated regression coefficients and
+        :math:`\mathbf{X}_{1:t}` and :math:`\mathbf{Y}_{1:t}` are the examples
+        and targets observed from timestep 1 to *t*.
+
+        In the single-example case, the RLS algorithm uses the Sherman-Morrison
+        formula [3]_ to avoid re-inverting the covariance matrix on each new
+        update. In the multi-example case (i.e., where :math:`\mathbf{X}_{t+1}`
+        and :math:`\mathbf{y}_{t+1}` are matrices of `N` examples each), we use
+        the generalized Woodbury matrix identity [4]_ to update the inverse
+        covariance. This comes at a performance cost, but is still more
+        performant than doing multiple single-example updates if *N* is large.
+
+        References
+        ----------
+        .. [1] Gauss, C. F. (1821) _Theoria combinationis observationum
+           erroribus minimis obnoxiae_, Werke, 4. Göttingen.
+        .. [2] https://en.wikipedia.org/wiki/Recursive_least_squares_filter
+        .. [3] https://en.wikipedia.org/wiki/Sherman%E2%80%93Morrison_formula
+        .. [4] https://en.wikipedia.org/wiki/Woodbury_matrix_identity
+
+        Parameters
+        ----------
+        X : :py:class:`ndarray <numpy.ndarray>` of shape `(N, M)`
+            A dataset consisting of `N` examples, each of dimension `M`
+        y : :py:class:`ndarray <numpy.ndarray>` of shape `(N, K)`
+            The targets for each of the `N` examples in `X`, where each target
+            has dimension `K`
+        """
+        if not self._is_fit:
+            raise RuntimeError("You must call the `fit` method before calling `update`")
+
+        X, y = np.atleast_2d(X), np.atleast_2d(y)
+
+        X1, Y1 = X.shape[0], y.shape[0]
+        self._update1D(X, y) if X1 == Y1 == 1 else self._update2D(X, y)
+
+    def _update1D(self, x, y):
+        """Sherman-Morrison update for a single example"""
+        beta, S_inv = self.beta, self.sigma_inv
+
+        # convert x to a design vector if we're fitting an intercept
+        if self.fit_intercept:
+            x = np.c_[1, x]
+
+        # update the inverse of the covariance matrix via Sherman-Morrison
+        S_inv -= (S_inv @ x.T @ x @ S_inv) / (1 + x @ S_inv @ x.T)
+
+        # update the model coefficients
+        beta += S_inv @ x.T @ (y - x @ beta)
+
+    def _update2D(self, X, y):
+        """Woodbury update for multiple examples"""
+        beta, S_inv = self.beta, self.sigma_inv
+
+        # convert X to a design matrix if we're fitting an intercept
+        if self.fit_intercept:
+            X = np.c_[np.ones(X.shape[0]), X]
+
+        I = np.eye(X.shape[0])
+
+        # update the inverse of the covariance matrix via the Woodbury identity
+        S_inv -= S_inv @ X.T @ np.linalg.pinv(I + X @ S_inv @ X.T) @ X @ S_inv
+
+        # update the model coefficients
+        beta += S_inv @ X.T @ (y - X @ beta)
+
     def fit(self, X, y):
         """
         Fit the regression coefficients via maximum likelihood.
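As a sanity check on the Sherman-Morrison step used in `_update1D` above, a small standalone NumPy sketch (with arbitrary, made-up shapes) confirming that the rank-one update of the inverse covariance agrees with recomputing the inverse after appending the new example:

import numpy as np

rng = np.random.RandomState(42)
X = rng.randn(20, 4)              # existing design matrix
x = rng.randn(1, 4)               # one new example, as a row vector

S_inv = np.linalg.pinv(X.T @ X)   # current inverse covariance

# Sherman-Morrison rank-one update of the inverse
S_inv_updated = S_inv - (S_inv @ x.T @ x @ S_inv) / (1 + x @ S_inv @ x.T)

# recompute the inverse directly from the augmented design matrix
X_aug = np.vstack([X, x])
S_inv_direct = np.linalg.pinv(X_aug.T @ X_aug)

assert np.allclose(S_inv_updated, S_inv_direct)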
@@ -44,8 +134,10 @@ def fit(self, X, y):
         if self.fit_intercept:
             X = np.c_[np.ones(X.shape[0]), X]
 
-        pseudo_inverse = np.linalg.inv(X.T @ X) @ X.T
-        self.beta = np.dot(pseudo_inverse, y)
+        self.sigma_inv = np.linalg.pinv(X.T @ X)
+        self.beta = np.atleast_2d(self.sigma_inv @ X.T @ y)
+
+        self._is_fit = True
 
     def predict(self, X):
         """
@@ -166,22 +258,22 @@ def __init__(self, penalty="l2", gamma=0, fit_intercept=True):
             \left(
                 \sum_{i=0}^N y_i \log(\hat{y}_i) +
                     (1-y_i) \log(1-\hat{y}_i)
-            \right) - R(\mathbf{b}, \gamma)
+            \right) - R(\mathbf{b}, \gamma)
             \right]
-
+
         where
-
+
         .. math::
-
+
             R(\mathbf{b}, \gamma) = \left\{
                 \begin{array}{lr}
                     \frac{\gamma}{2} ||\mathbf{beta}||_2^2 & :\texttt{ penalty = 'l2'}\\
                     \gamma ||\beta||_1 & :\texttt{ penalty = 'l1'}
                 \end{array}
                 \right.
-
-        is a regularization penalty, :math:`\gamma` is a regularization weight,
-        `N` is the number of examples in **y**, and **b** is the vector of model
+
+        is a regularization penalty, :math:`\gamma` is a regularization weight,
+        `N` is the number of examples in **y**, and **b** is the vector of model
         coefficients.
 
         Parameters
@@ -251,10 +343,10 @@ def _NLL(self, X, y, y_pred):
             \right]
         """
         N, M = X.shape
-        beta, gamma = self.beta, self.gamma
+        beta, gamma = self.beta, self.gamma
         order = 2 if self.penalty == "l2" else 1
         norm_beta = np.linalg.norm(beta, ord=order)
-
+
         nll = -np.log(y_pred[y == 1]).sum() - np.log(1 - y_pred[y == 0]).sum()
         penalty = (gamma / 2) * norm_beta ** 2 if order == 2 else gamma * norm_beta
         return (penalty + nll) / N
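To make the penalized objective in the `LogisticRegression` docstring concrete, a standalone sketch (made-up arrays) that evaluates the same quantity `_NLL` returns under the L2 penalty:

import numpy as np

rng = np.random.RandomState(1)
N = 100
y = rng.randint(0, 2, size=N)                    # binary targets
y_pred = np.clip(rng.rand(N), 1e-6, 1 - 1e-6)    # predicted probabilities
beta = rng.randn(4)                              # model coefficients
gamma = 0.1                                      # regularization weight

nll = -np.log(y_pred[y == 1]).sum() - np.log(1 - y_pred[y == 0]).sum()
penalty = (gamma / 2) * np.linalg.norm(beta, ord=2) ** 2   # R(b, gamma) for penalty='l2'
loss = (nll + penalty) / N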
New file: 58 additions, 0 deletions

@@ -0,0 +1,58 @@
+# flake8: noqa
+import numpy as np
+
+from sklearn.linear_model import LinearRegression as LinearRegressionGold
+
+from numpy_ml.linear_models.lm import LinearRegression
+from numpy_ml.utils.testing import random_tensor
+
+
+def test_linear_regression(N=10):
+    np.random.seed(12345)
+    N = np.inf if N is None else N
+
+    i = 1
+    while i < N + 1:
+        train_samples = np.random.randint(2, 30)
+        update_samples = np.random.randint(1, 30)
+        n_samples = train_samples + update_samples
+
+        # ensure n_feats < train_samples, otherwise multiple solutions are
+        # possible
+        n_feats = np.random.randint(1, train_samples)
+        target_dim = np.random.randint(1, 10)
+
+        fit_intercept = np.random.choice([True, False])
+
+        X = random_tensor((n_samples, n_feats), standardize=True)
+        y = random_tensor((n_samples, target_dim), standardize=True)
+
+        X_train, X_update = X[:train_samples], X[train_samples:]
+        y_train, y_update = y[:train_samples], y[train_samples:]
+
+        # Fit gold standard model on the entire dataset
+        lr_gold = LinearRegressionGold(fit_intercept=fit_intercept, normalize=False)
+        lr_gold.fit(X, y)
+
+        # Fit our model on just (X_train, y_train)...
+        lr = LinearRegression(fit_intercept=fit_intercept)
+        lr.fit(X_train, y_train)
+
+        do_single_sample_update = np.random.choice([True, False])
+
+        # ...then update our model on the examples (X_update, y_update)
+        if do_single_sample_update:
+            for x_new, y_new in zip(X_update, y_update):
+                lr.update(x_new, y_new)
+        else:
+            lr.update(X_update, y_update)
+
+        # check that model predictions match
+        np.testing.assert_almost_equal(lr.predict(X), lr_gold.predict(X), decimal=5)
+
+        # check that model coefficients match
+        beta = lr.beta.T[:, 1:] if fit_intercept else lr.beta.T
+        np.testing.assert_almost_equal(beta, lr_gold.coef_, decimal=6)
+
+        print("\tPASSED")
+        i += 1
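Mirroring the test's `do_single_sample_update` branch, a small sketch (assuming the `LinearRegression` class from this commit is importable) checking that one multi-example Woodbury update produces the same coefficients as a sequence of single-example Sherman-Morrison updates:

import numpy as np
from numpy_ml.linear_models.lm import LinearRegression

rng = np.random.RandomState(7)
X, y = rng.randn(25, 3), rng.randn(25, 2)

lr_batch = LinearRegression(fit_intercept=True)
lr_seq = LinearRegression(fit_intercept=True)
lr_batch.fit(X[:15], y[:15])
lr_seq.fit(X[:15], y[:15])

lr_batch.update(X[15:], y[15:])              # single Woodbury (multi-example) update
for x_new, y_new in zip(X[15:], y[15:]):     # equivalent single-example updates
    lr_seq.update(x_new, y_new)

np.testing.assert_almost_equal(lr_batch.beta, lr_seq.beta, decimal=6)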

0 commit comments