WIP: Adding Passive Aggressive learning rates #1259

Merged
merged 23 commits into from
Nov 5, 2012

Conversation

zaxtax
Contributor

@zaxtax zaxtax commented Oct 21, 2012

I have added Passive Aggressive learning rates as described in
http://jmlr.csail.mit.edu/papers/volume7/crammer06a/crammer06a.pdf

I have tried to make as few changes as possible, so these rates should integrate well with everything else.

@ogrisel
Member

ogrisel commented Oct 21, 2012

Hi Rob, could you please update the documentation to mention those learning rates (both in docstrings and narrative documentation of the SGD module)? Don't forget to update the reference section to add a link to the crammer06a paper.

Also have you tried to use those in practice? What do they bring? What was your motivation to implement those?

eta = 1.0 / sqnorm(x_data_ptr, x_ind_ptr, xnnz)
eta = min(alpha/loss.dloss(p,y), eta)
elif learning_rate == PA2:
eta = 1.0 / (sqnorm(x_data_ptr, x_ind_ptr, xnnz) + 1.0/2*alpha)
Member

pep8: 1.0 / 2 * alpha which should even be simplified as 0.5 * alpha
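
For readers following the review, here is a minimal sketch of the three step sizes as I read them from the Crammer et al. paper linked at the top of the thread. The function name `pa_step_size` and the `variant` strings are illustrative assumptions, not the PR's Cython code, and the hinge loss value is passed in rather than computed. The last lines illustrate the precedence point behind the pep8 remark: `1.0 / 2 * alpha` parses as `0.5 * alpha`, which is not the same quantity as `1 / (2 * alpha)`.

```python
def sqnorm(x):
    """Squared Euclidean norm of a dense feature vector."""
    return sum(v * v for v in x)


def pa_step_size(loss, x, C=1.0, variant="pa"):
    """Step size tau for one example (x is assumed to be nonzero).

    pa  : tau = loss / ||x||^2
    pa1 : tau = min(C, loss / ||x||^2)
    pa2 : tau = loss / (||x||^2 + 1 / (2 C))
    """
    norm2 = sqnorm(x)
    if variant == "pa":
        return loss / norm2
    if variant == "pa1":
        return min(C, loss / norm2)
    if variant == "pa2":
        return loss / (norm2 + 1.0 / (2.0 * C))
    raise ValueError("unknown variant: %r" % variant)


# Operator precedence: division and multiplication associate left to right.
alpha = 4.0
half_alpha = 1.0 / 2 * alpha          # same as 0.5 * alpha
inverse_two_alpha = 1.0 / (2 * alpha)  # a different quantity
print(half_alpha, inverse_two_alpha)
```

In this parameterization both capped variants approach the vanilla step as `C` grows. In the diff excerpt above, by contrast, setting `alpha` to 0 in the `PA2` branch recovers `1.0 / sqnorm(...)`.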

@zaxtax
Contributor Author

zaxtax commented Oct 21, 2012

Let me add the documentation. My motivation is partly completeness, as this is frequently the baseline other SGD classifiers try to beat, and partly since passive-aggressive algorithms seem to give smoother updates of the weights. What would be a good test to show how well this performs?

@ogrisel
Member

ogrisel commented Oct 21, 2012

Let me add the documentation.

Thanks, please also extend the existing unit tests for the sklearn.linear_model.stochastic_gradient module to check the convergence of those learning rate strategies on toy data.

since passive-aggressive algorithms seem to give smoother updates of the weights. What would be a good test to show how well this performs?

I don't know, you tell me :) Such a demonstration would ideally be done as a new example or by extending an existing example with one of the default datasets.

Maybe you could plot the online regret (cumulative number of misclassifications) using the partial_fit API on the digits dataset of scikit-learn to try to approximately reproduce figure 6 of the paper? You could also include other SGD learning rates on that plot (setting other parameters to their optimal cross-validated values).
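
The bookkeeping for such a curve can be sketched without any library. The example below is a pure-Python illustration on a synthetic binary stream, not the digits experiment suggested above: the learner is a hand-written PA-I update, and `online_mistakes` is a hypothetical helper name. Each example is scored before the model is updated on it, which is what makes the count an online quantity.

```python
import random


def online_mistakes(stream, C=1.0):
    """Run a PA-I learner over (x, y) pairs with y in {-1, +1}.

    Returns the cumulative number of mistakes after each example.
    """
    w = None
    mistakes = 0
    curve = []
    for x, y in stream:
        if w is None:
            w = [0.0] * len(x)
        score = sum(wi * xi for wi, xi in zip(w, x))
        if y * score <= 0:  # predict first, count the error, then learn
            mistakes += 1
        loss = max(0.0, 1.0 - y * score)
        norm2 = sum(xi * xi for xi in x)
        if loss > 0.0 and norm2 > 0.0:
            tau = min(C, loss / norm2)
            w = [wi + tau * y * xi for wi, xi in zip(w, x)]
        curve.append(mistakes)
    return curve


# Synthetic, linearly separable stream (the last feature acts as a bias term).
rng = random.Random(0)
stream = []
for _ in range(200):
    x = [rng.uniform(-1, 1), rng.uniform(-1, 1), 1.0]
    y = 1 if x[0] + 0.5 * x[1] > 0 else -1
    stream.append((x, y))

curve = online_mistakes(stream)
print(curve[-1], "mistakes after", len(curve), "examples")
```

Plotting `curve` against the example index for several learners on the same stream gives the kind of comparison described above.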

@zaxtax
Contributor Author

zaxtax commented Oct 21, 2012

It's a bit of a drag that we can't access the regret information from SGD. As a meta-level discussion, do you get the sense that SGD has been a bit monolithic in structure, which will make it harder over time to add SGD learners?

@@ -358,6 +358,9 @@ user via ``eta0`` and ``power_t``, resp.
For a constant learning rate use ``learning_rate='constant'`` and use ``eta0``
to specify the learning rate.

For passive-aggressive learning rate use ``learning_rate='pa'``, ``learning_rate='pa1'``
or ``learning_rate='pa2'``.
Member

You should explain what the passive-aggressive learning rates mean in practice here and also give the mathematical formulation of those losses in:

http://scikit-learn.org/dev/modules/sgd.html#id1

@ogrisel
Member

ogrisel commented Oct 21, 2012

The monolithic design of the current implementation is primarily motivated by the speed constraints. Every bit of flexibility we add in the main SGD loop has a cost. However we will add a callback based monitoring / checkpointing API in the scikit at some point.

@zaxtax
Contributor Author

zaxtax commented Oct 21, 2012

Also in writing this explanation I am wondering if I should remove "PA" since PA-1 and PA-2 reduce to it when alpha = 0.

@ogrisel
Member

ogrisel commented Oct 21, 2012

+1 for removing PA and explaining how vanilla PA can be obtained from "PA-1" or "PA-2".

@ogrisel
Member

ogrisel commented Oct 21, 2012

Thanks for the explanation, this is much clearer now. Do you plan to do the online regret plot using the partial_fit method with small chunks (e.g. 10 samples at a time) on the digits data? Please put a WIP: prefix in the pull request title to state that this is not yet ready for merging.

Also more importantly, this PR needs to add unit tests (extending the existing tests for the stochastic_gradient module) for the new learning rate options. Smoke tests are good enough for testing learning rates. We don't want to have long running tests in the test suite.

@zaxtax
Contributor Author

zaxtax commented Oct 21, 2012

I am considering doing a test-error test on 5, 10, 25, 50, 100 percent of training data. I am worried that doing a partial_fit on small chunks won't work well enough.

@ogrisel
Member

ogrisel commented Oct 21, 2012

The regular fit on a large data array with n_iter=1 and partial_fit on the same row-chunked array should yield the same results whatever the size of the chunks.
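
A toy illustration of why chunk size should not matter: if the learner visits rows in order and keeps all of its state between calls, a single pass over the whole array and a sequence of `partial_fit` calls perform exactly the same updates. `TinyPerceptron` below is a hypothetical pure-Python stand-in, not scikit-learn code. The property would not hold if rows were shuffled or if some state (for example a step counter driving a learning rate schedule) were reset per call.

```python
class TinyPerceptron:
    """Toy online learner exposing fit (one pass) and partial_fit."""

    def __init__(self, n_features):
        self.w = [0.0] * n_features

    def partial_fit(self, X, y):
        # Visit rows in order; update only on mistakes.
        for xi, yi in zip(X, y):
            score = sum(a * b for a, b in zip(self.w, xi))
            if yi * score <= 0:
                self.w = [a + yi * b for a, b in zip(self.w, xi)]
        return self

    def fit(self, X, y):
        # A single pass from a fresh start, i.e. the n_iter=1 case.
        self.w = [0.0] * len(self.w)
        return self.partial_fit(X, y)


X = [[1.0, 2.0], [-1.0, 0.5], [2.0, -1.0],
     [-2.0, -1.5], [0.5, 0.5], [-0.5, 1.0]]
y = [1, -1, 1, -1, 1, -1]

full = TinyPerceptron(2).fit(X, y)


def chunked_weights(chunk_size):
    """Feed the same rows in order, chunk_size rows at a time."""
    model = TinyPerceptron(2)
    for start in range(0, len(X), chunk_size):
        stop = start + chunk_size
        model.partial_fit(X[start:stop], y[start:stop])
    return model.w


for size in (1, 2, 4, 6):
    print(size, chunked_weights(size) == full.w)
```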

@zaxtax
Contributor Author

zaxtax commented Oct 21, 2012

Also, where are the stochastic_gradient tests?

@mblondel
Member

I think the idea is nice but I have two concerns:

  • The regularizer in passive-aggressive should be called C, not alpha. The bigger C, the less regularized. The converse holds for alpha.
  • Whether or not this theoretically works with other losses than the hinge loss (classification) and epsilon-insensitive loss (regression). I haven't looked into the regret bounds but something may not hold if we change the loss.

For those reasons, we may want to expose the passive-aggressive algorithm only in its own estimators (PassiveAggressiveClassifier and PassiveAggressiveRegressor) while keeping the cython code as it is now.

@zaxtax
Contributor Author

zaxtax commented Oct 21, 2012

@mblondel I feel like C and alpha do very similar things. PA should be defined in terms of alpha.

Giving it its own estimators is fine, especially as I am going to be pushing AROW, AdaGrad, and Averaged SGD/Perceptron soon.

@ogrisel
Member

ogrisel commented Oct 21, 2012

In the scikit-learn project: C means the opposite of alpha: the larger C is, the less the model gets regularized while the opposite holds for alpha.

The C parameter is used in the following models: SVC, LinearSVC, LogisticRegression:

http://scikit-learn.org/dev/modules/svm.html#mathematical-formulation

The alpha parameter is used as a constructor parameter for instance in models: SGDClassifier, ElasticNet, LassoLars.

http://scikit-learn.org/dev/modules/linear_model.html#lasso
http://scikit-learn.org/dev/modules/linear_model.html#elastic-net
http://scikit-learn.org/dev/modules/sgd.html#mathematical-formulation
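
The two conventions can be compared on a one-dimensional toy problem. The sketch below uses two generic textbook objectives, which are assumptions for illustration and not the exact objectives of any scikit-learn estimator: one weights the summed hinge loss by `C`, the other weights the L2 penalty by `alpha`. With `alpha = 1 / C` the first objective is just `C` times the second, so they share the same minimizer, and a larger `C` corresponds to a smaller `alpha`, hence less regularization.

```python
def hinge(margin):
    return max(0.0, 1.0 - margin)


# (x, y) pairs with scalar features, purely for illustration.
data = [(1.0, 1), (2.0, 1), (-1.5, -1), (0.3, -1)]


def objective_C(w, C):
    """Penalty has unit weight; C scales the summed loss."""
    return 0.5 * w * w + C * sum(hinge(y * w * x) for x, y in data)


def objective_alpha(w, alpha):
    """Summed loss has unit weight; alpha scales the penalty."""
    return 0.5 * alpha * w * w + sum(hinge(y * w * x) for x, y in data)


grid = [i / 100.0 for i in range(-300, 301)]

for C in (0.1, 1.0, 10.0):
    best_w = min(grid, key=lambda w: objective_C(w, C))
    print("C =", C, "alpha =", 1.0 / C, "grid minimizer w =", best_w)
```

If the loss is averaged over `n` samples instead of summed, the same argument gives `alpha = 1 / (C * n)`.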

@@ -358,6 +358,23 @@ user via ``eta0`` and ``power_t``, resp.
For a constant learning rate use ``learning_rate='constant'`` and use ``eta0``
to specify the learning rate.

For passive-aggressive algorithms, there is no learning rate, instead
step-size is taken as large as would guarantee the example would have
been correctly classified. In practice, this only works on seperable
Member

separable
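
The "no learning rate" idea in the documentation excerpt above can be checked numerically. In this pure-Python sketch (the helper names `dot` and `pa_update` are illustrative), the vanilla passive-aggressive step is just large enough that the current example ends up with a margin of 1, so its hinge loss drops to zero. When the example already has a margin of at least 1 the weights are left untouched, which is the "passive" half of the name.

```python
def dot(a, b):
    return sum(ai * bi for ai, bi in zip(a, b))


def pa_update(w, x, y):
    """One vanilla PA step for a label y in {-1, +1}; returns new weights."""
    loss = max(0.0, 1.0 - y * dot(w, x))
    norm2 = dot(x, x)
    if loss == 0.0 or norm2 == 0.0:
        return list(w)  # passive: nothing to correct
    tau = loss / norm2  # aggressive: smallest step that reaches margin 1
    return [wi + tau * y * xi for wi, xi in zip(w, x)]


w = [0.2, -0.4, 0.1]
x = [1.0, 2.0, -1.0]
y = 1

w_new = pa_update(w, x, y)
print("margin before:", y * dot(w, x))
print("margin after: ", y * dot(w_new, x))
```

On noisy data this full correction can overreact to a single mislabeled example, which is the motivation for the capped PA-I and PA-II variants discussed in the thread.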

@ogrisel
Member

ogrisel commented Oct 21, 2012

As the alpha of PA does not have the same meaning as in the other regularized models of scikit-learn I think we should give it another name. For instance pa_regularization in the cython code and regularization in the PassiveAggressiveClassifier constructor. WDYT?

@zaxtax
Contributor Author

zaxtax commented Oct 21, 2012

I think C is probably the right variable name. The larger question is whether it should be using the same cython code as a base.

@mblondel
Member

If we create PassiveAggressiveClassifier, there is no reason not to call it C. It's called that way in the paper and it has the same meaning as in SVC.

You can pass both alpha and C to the underlying cython routine (only one of them will be used, depending on the selected learning rate).

Inheriting from SGDClassifier or another base class could still be nice to get support for one-vs-rest, partial fit and warm-start.

@mblondel
Member

Giving it its own estimators is fine, especially as I am going to be pushing AROW, AdaGrad, and Averaged SGD/Perceptron soon.

You may want to ask people's opinion on the mailing-list for those ones. Averaged SGD/Perceptron seems fine to me but the other two may be too recent (only 30 to 40 citations).

PA-2 respectively. Setting ``C`` to 0 gives the vanilla passive-
aggressive algorithm.

..math::
Member

This does not display right in the docs. I believe that there is a missing white space between '..' and 'math'.

@GaelVaroquaux
Member

I have made a bunch of minor comments. I cannot comment on the example yet, as it doesn't run for me.

A few other minor remarks:

  • Could you please add the 2 contributed estimators to doc/modules/classes.rst, as right now the classes entry in the documentation cannot link to the reference documentation.
  • Something that I do not understand: a parameter 'C' has been added. However, as far as I can tell, it is useful only for passive-aggressive objects. It is not documented in the other classes. What I do not understand is why this parameter is given in the constructor of the other classes. I find it confusing. Maybe I am just confused :)

@ogrisel
Member

ogrisel commented Oct 27, 2012

A couple of remarks:

  • if C is indeed the same as for linear SVC, then the default value should be set to 1.0 to be consistent. However I am not yet sure that LinearSVC(C=1.0, loss='l1') converges to the same solution as PassiveAggressiveClassifier(C=1.0, loss='pa1') (and the same remark holds for loss='l2' / loss='pa2').
  • if the above equivalence should theoretically be true, then we should add some more tests in a new test_passive_aggressive.py to check that assertion and maybe debug the equivalence.
  • if @mblondel's explanation about the PA optimization objective functions is correct (I have not yet found the time to read the PA paper) then it should be put in the documentation
  • there is something that breaks clone (hence cross validation and grid search) in the current state of the implementation:
>>> from sklearn.linear_model import PassiveAggressiveClassifier
>>> from sklearn.base import clone
>>> clone(PassiveAggressiveClassifier())
Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
  File "/Users/ogrisel/coding/scikit-learn/sklearn/base.py", line 47, in clone
    new_object = klass(**new_object_params)
  File "/Users/ogrisel/coding/scikit-learn/sklearn/linear_model/pa.py", line 94, in __init__
    n_jobs=n_jobs)
  File "/Users/ogrisel/coding/scikit-learn/sklearn/linear_model/stochastic_gradient.py", line 384, in __init__
    warm_start=warm_start)
  File "/Users/ogrisel/coding/scikit-learn/sklearn/linear_model/stochastic_gradient.py", line 73, in __init__
    self._validate_params()
  File "/Users/ogrisel/coding/scikit-learn/sklearn/linear_model/stochastic_gradient.py", line 104, in _validate_params
    self._get_learning_rate_type(self.learning_rate)
  File "/Users/ogrisel/coding/scikit-learn/sklearn/linear_model/stochastic_gradient.py", line 139, in _get_learning_rate_type
    "is not supported. " % learning_rate)
ValueError: learning rate hingeis not supported. 
  • some other test_common checks fail too (to run the complete test suite, run make from the top level folder):
======================================================================
ERROR: sklearn.feature_selection.tests.test_selector_mixin.test_transform_linear_model
----------------------------------------------------------------------
Traceback (most recent call last):
  File "/usr/local/lib/python2.7/site-packages/nose/case.py", line 197, in runTest
    self.test(*self.arg)
  File "/Users/ogrisel/coding/scikit-learn/sklearn/feature_selection/tests/test_selector_mixin.py", line 19, in test_transform_linear_model
    SGDClassifier(alpha=0.1, n_iter=10, shuffle=True, seed=0)):
TypeError: __init__() got an unexpected keyword argument 'seed'

======================================================================
ERROR: sklearn.feature_selection.tests.test_selector_mixin.test_invalid_input
----------------------------------------------------------------------
Traceback (most recent call last):
  File "/usr/local/lib/python2.7/site-packages/nose/case.py", line 197, in runTest
    self.test(*self.arg)
  File "/Users/ogrisel/coding/scikit-learn/sklearn/feature_selection/tests/test_selector_mixin.py", line 37, in test_invalid_input
    clf = SGDClassifier(alpha=0.1, n_iter=10, shuffle=True, seed=0)
TypeError: __init__() got an unexpected keyword argument 'seed'

======================================================================
ERROR: sklearn.linear_model.tests.test_perceptron.test_perceptron_accuracy
----------------------------------------------------------------------
Traceback (most recent call last):
  File "/usr/local/lib/python2.7/site-packages/nose/case.py", line 197, in runTest
    self.test(*self.arg)
  File "/Users/ogrisel/coding/scikit-learn/sklearn/linear_model/tests/test_perceptron.py", line 47, in test_perceptron_accuracy
    clf = Perceptron(n_iter=30, shuffle=False, seed=0)
TypeError: __init__() got an unexpected keyword argument 'seed'

======================================================================
ERROR: sklearn.tests.test_common.test_all_estimators
----------------------------------------------------------------------
Traceback (most recent call last):
  File "/usr/local/lib/python2.7/site-packages/nose/case.py", line 197, in runTest
    self.test(*self.arg)
  File "/Users/ogrisel/coding/scikit-learn/sklearn/tests/test_common.py", line 73, in test_all_estimators
    clone(e)
  File "/Users/ogrisel/coding/scikit-learn/sklearn/base.py", line 47, in clone
    new_object = klass(**new_object_params)
  File "/Users/ogrisel/coding/scikit-learn/sklearn/linear_model/pa.py", line 94, in __init__
    n_jobs=n_jobs)
  File "/Users/ogrisel/coding/scikit-learn/sklearn/linear_model/stochastic_gradient.py", line 384, in __init__
    warm_start=warm_start)
  File "/Users/ogrisel/coding/scikit-learn/sklearn/linear_model/stochastic_gradient.py", line 73, in __init__
    self._validate_params()
  File "/Users/ogrisel/coding/scikit-learn/sklearn/linear_model/stochastic_gradient.py", line 104, in _validate_params
    self._get_learning_rate_type(self.learning_rate)
  File "/Users/ogrisel/coding/scikit-learn/sklearn/linear_model/stochastic_gradient.py", line 139, in _get_learning_rate_type
    "is not supported. " % learning_rate)
ValueError: learning rate hingeis not supported. 

======================================================================
ERROR: sklearn.tests.test_common.test_estimators_sparse_data
----------------------------------------------------------------------
Traceback (most recent call last):
  File "/usr/local/lib/python2.7/site-packages/nose/case.py", line 197, in runTest
    self.test(*self.arg)
  File "/Users/ogrisel/coding/scikit-learn/sklearn/tests/test_common.py", line 94, in test_estimators_sparse_data
    clf = Clf()
  File "/Users/ogrisel/coding/scikit-learn/sklearn/linear_model/pa.py", line 185, in __init__
    n_jobs=n_jobs)
TypeError: __init__() got an unexpected keyword argument 'class_weight'

======================================================================
ERROR: sklearn.tests.test_common.test_transformers
----------------------------------------------------------------------
Traceback (most recent call last):
  File "/usr/local/lib/python2.7/site-packages/nose/case.py", line 197, in runTest
    self.test(*self.arg)
  File "/Users/ogrisel/coding/scikit-learn/sklearn/tests/test_common.py", line 133, in test_transformers
    trans = Trans()
  File "/Users/ogrisel/coding/scikit-learn/sklearn/linear_model/pa.py", line 185, in __init__
    n_jobs=n_jobs)
TypeError: __init__() got an unexpected keyword argument 'class_weight'

======================================================================
ERROR: sklearn.tests.test_common.test_transformers_sparse_data
----------------------------------------------------------------------
Traceback (most recent call last):
  File "/usr/local/lib/python2.7/site-packages/nose/case.py", line 197, in runTest
    self.test(*self.arg)
  File "/Users/ogrisel/coding/scikit-learn/sklearn/tests/test_common.py", line 202, in test_transformers_sparse_data
    trans = Trans()
  File "/Users/ogrisel/coding/scikit-learn/sklearn/linear_model/pa.py", line 185, in __init__
    n_jobs=n_jobs)
TypeError: __init__() got an unexpected keyword argument 'class_weight'

======================================================================
ERROR: sklearn.tests.test_common.test_regressors_int
----------------------------------------------------------------------
Traceback (most recent call last):
  File "/usr/local/lib/python2.7/site-packages/nose/case.py", line 197, in runTest
    self.test(*self.arg)
  File "/Users/ogrisel/coding/scikit-learn/sklearn/tests/test_common.py", line 436, in test_regressors_int
    reg1 = Reg()
  File "/Users/ogrisel/coding/scikit-learn/sklearn/linear_model/pa.py", line 185, in __init__
    n_jobs=n_jobs)
TypeError: __init__() got an unexpected keyword argument 'class_weight'

======================================================================
ERROR: sklearn.tests.test_common.test_regressors_train
----------------------------------------------------------------------
Traceback (most recent call last):
  File "/usr/local/lib/python2.7/site-packages/nose/case.py", line 197, in runTest
    self.test(*self.arg)
  File "/Users/ogrisel/coding/scikit-learn/sklearn/tests/test_common.py", line 476, in test_regressors_train
    reg = Reg()
  File "/Users/ogrisel/coding/scikit-learn/sklearn/linear_model/pa.py", line 185, in __init__
    n_jobs=n_jobs)
TypeError: __init__() got an unexpected keyword argument 'class_weight'

======================================================================
FAIL: Doctest: sklearn.linear_model.stochastic_gradient.SGDClassifier
----------------------------------------------------------------------
Traceback (most recent call last):
  File "/usr/local/Cellar/python/2.7.3/Frameworks/Python.framework/Versions/2.7/lib/python2.7/doctest.py", line 2201, in runTest
    raise self.failureException(self.format_failure(new.getvalue()))
AssertionError: Failed doctest test for sklearn.linear_model.stochastic_gradient.SGDClassifier
  File "/Users/ogrisel/coding/scikit-learn/sklearn/linear_model/stochastic_gradient.py", line 240, in SGDClassifier

----------------------------------------------------------------------
File "/Users/ogrisel/coding/scikit-learn/sklearn/linear_model/stochastic_gradient.py", line 345, in sklearn.linear_model.stochastic_gradient.SGDClassifier
Failed example:
    clf.fit(X, Y)
    #doctest: +NORMALIZE_WHITESPACE
Expected:
    SGDClassifier(alpha=0.0001, class_weight=None, epsilon=0.1, eta0=0.0,
            fit_intercept=True, l1_ratio=0.15, learning_rate='optimal',
            loss='hinge', n_iter=5, n_jobs=1, penalty='l2', power_t=0.5,
            rho=None, random_state=0, shuffle=False, verbose=0, warm_start=False)
Got:
    SGDClassifier(C=1.0, alpha=0.0001, class_weight=None, epsilon=0.1, eta0=0.0,
           fit_intercept=True, l1_ratio=0.15, learning_rate='optimal',
           loss='hinge', n_iter=5, n_jobs=1, penalty='l2', power_t=0.5,
           random_state=0, rho=None, shuffle=False, verbose=0,
           warm_start=False)

>>  raise self.failureException(self.format_failure(<StringIO.StringIO instance at 0x10fc40cb0>.getvalue()))


======================================================================
FAIL: Doctest: sklearn.linear_model.stochastic_gradient.SGDRegressor
----------------------------------------------------------------------
Traceback (most recent call last):
  File "/usr/local/Cellar/python/2.7.3/Frameworks/Python.framework/Versions/2.7/lib/python2.7/doctest.py", line 2201, in runTest
    raise self.failureException(self.format_failure(new.getvalue()))
AssertionError: Failed doctest test for sklearn.linear_model.stochastic_gradient.SGDRegressor
  File "/Users/ogrisel/coding/scikit-learn/sklearn/linear_model/stochastic_gradient.py", line 681, in SGDRegressor

----------------------------------------------------------------------
File "/Users/ogrisel/coding/scikit-learn/sklearn/linear_model/stochastic_gradient.py", line 775, in sklearn.linear_model.stochastic_gradient.SGDRegressor
Failed example:
    clf.fit(X, y)
Expected:
    SGDRegressor(alpha=0.0001, epsilon=0.1, eta0=0.01, fit_intercept=True,
           l1_ratio=0.15, learning_rate='invscaling', loss='squared_loss',
           n_iter=5, p=None, penalty='l2', power_t=0.25, rho=None, random_state=0,
           shuffle=False, verbose=0, warm_start=False)
Got:
    SGDRegressor(C=1.0, alpha=0.0001, epsilon=0.1, eta0=0.01, fit_intercept=True,
           l1_ratio=0.15, learning_rate='invscaling', loss='squared_loss',
           n_iter=5, p=None, penalty='l2', power_t=0.25, random_state=0,
           rho=None, shuffle=False, verbose=0, warm_start=False)

>>  raise self.failureException(self.format_failure(<StringIO.StringIO instance at 0x10fc40cb0>.getvalue()))

The class_weight stuff might just be that we need to update the test suite to not check for rebalanced class support for those new estimators if they do not support it (shall this be the case or not?).

The fact that C is now a part of the public API of SGDClassifier and SGDRegressor is a mistake. C should only be exposed for PassiveAggressiveClassifier/Regressor.

Now if this is too much for you to fix @zaxtax then I can take over from here. However I have very little spare bandwidth so don't be surprised if I cannot find the time to fix the remaining issues of this PR before a couple of months :)

However I think this is a very important contrib as it seems that in the few tests I did when checking this PR, PA models were always on par with or beating LinearSVC in terms of cross-validated score. The fact that it additionally supports the partial_fit API for large scale learning makes it even more appealing.
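
The clone failure in the tracebacks above ("learning rate hinge is not supported") has the shape of a constructor round-trip problem. The sketch below shows one plausible way such an error can arise; it is an assumption for illustration, not a diagnosis of the PR's actual code. `toy_clone`, `ToySGD`, `BrokenPA` and `FixedPA` are all hypothetical names, and `toy_clone` only imitates the idea of rebuilding an estimator from attributes named after its `__init__` arguments.

```python
import inspect


def toy_clone(estimator):
    """Rebuild an estimator from attributes named after its __init__ args.

    Simplified stand-in for the idea behind clone, not the real function.
    """
    cls = type(estimator)
    names = [name for name in inspect.signature(cls.__init__).parameters
             if name != "self"]
    params = {name: getattr(estimator, name) for name in names}
    return cls(**params)


class ToySGD:
    SUPPORTED = ("optimal", "constant", "pa1", "pa2")

    def __init__(self, loss="hinge", learning_rate="optimal"):
        if learning_rate not in self.SUPPORTED:
            raise ValueError(
                "learning rate %s is not supported." % learning_rate)
        self.loss = loss
        self.learning_rate = learning_rate


class BrokenPA(ToySGD):
    def __init__(self, loss="pa1"):
        # The parent stores "hinge" in self.loss, so the value the user
        # passed as `loss` no longer round-trips through the attribute.
        ToySGD.__init__(self, loss="hinge", learning_rate=loss)


class FixedPA(ToySGD):
    def __init__(self, loss="pa1"):
        ToySGD.__init__(self, loss="hinge", learning_rate=loss)
        self.loss = loss  # keep the constructor argument exactly as given


try:
    toy_clone(BrokenPA())
except ValueError as exc:
    print("clone failed:", exc)

print("clone ok, loss =", toy_clone(FixedPA()).loss)
```

The general rule this illustrates is that each constructor argument should be stored unchanged under an attribute of the same name, with any translation done later, for example at fit time.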

@ogrisel
Member

ogrisel commented Oct 27, 2012

@mblondel I'm not sure if PA optimizes a global objective, as it solves an optimization problem based on one instance at each iteration. Setting C in PA and C / n_samples in LinearSVC might work though.

Hum alright, then calling the regularization parameter C might be misleading if it's not the same C as in the SVM literature. Let me do a couple more checks.

@GaelVaroquaux
Member

On Sat, Oct 27, 2012 at 10:09:28AM -0700, Olivier Grisel wrote:

Hum alright, then calling the regularization parameter C might be misleading if
it's not the same C as in the SVM literature.

If that's indeed the case, then I would push for not using 'C', and using
a more explicit name.

@zaxtax
Contributor Author

zaxtax commented Oct 27, 2012

@ogrisel I am committed to fixing all issues in this commit.

@mblondel
Member

@ogrisel It would indeed be nice if you could check that. I would personally be ok with keeping it C in any case as it is a weight with respect to the loss rather than the norm (like alpha) and bigger means less regularized like LinearSVC.

@mblondel
Member

Apparently, SGDRegressor and Lasso (coordinate descent) have the same issue. They both use alpha but the scale is different: SGD minimizes 1 / (2n) \sum (w . x_i - y_i) ^2 + alpha * penalty while CD minimizes 1 / 2 \sum (w . x_i - y_i) ^2 + alpha * penalty.
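
Taking the two objectives exactly as written in the comment above (whether they match the actual implementations is not checked here), the scale difference can be made explicit. The subscripts on alpha and the symbol P(w) for the penalty are notation introduced for this note only.

```latex
\begin{align*}
\text{SGD:}\quad & \min_w \; \frac{1}{2n} \sum_{i=1}^{n} (w \cdot x_i - y_i)^2
                   + \alpha_{\mathrm{SGD}} \, P(w) \\
\text{CD:}\quad  & \min_w \; \frac{1}{2} \sum_{i=1}^{n} (w \cdot x_i - y_i)^2
                   + \alpha_{\mathrm{CD}} \, P(w)
\end{align*}
% Multiplying the SGD objective by n does not change its minimizer:
\[
n \left( \frac{1}{2n} \sum_{i=1}^{n} (w \cdot x_i - y_i)^2
         + \alpha_{\mathrm{SGD}} \, P(w) \right)
= \frac{1}{2} \sum_{i=1}^{n} (w \cdot x_i - y_i)^2
  + n \, \alpha_{\mathrm{SGD}} \, P(w)
\]
% so the two problems coincide when \alpha_{\mathrm{CD}} = n \, \alpha_{\mathrm{SGD}}.
```

Under that reading, the same numeric alpha regularizes the averaged (SGD) form n times more strongly than the summed (CD) form.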

@ogrisel
Member

ogrisel commented Oct 27, 2012

@mblondel @ogrisel It would indeed be nice if you could check that. I would personally be ok with keeping it C in any case as it is a weight with respect to the loss rather than the norm (like alpha) and bigger means less regularized like LinearSVC.

I tried using the digits dataset and LinearSVC & PA do not seem to converge to the same coef_ vector although the scores are similar. I guess that's kind of expected as classification losses based on the hinge loss are not strictly convex and the (random) ordering of the samples has an impact on the PA optimization.

However, for loss = 'l1' / loss='pa1', setting C_pa=C_svc / X_train.shape[0] with low values of C_svc (e.g. C_svc = 0.0000001) to emphasize the impact of the regularizer, I get exactly the same scores for the two models when I vary the random_state of the train_test_split. So L1-SVC and PA-I are indeed equivalent using the above parameterization.

For L2-SVC and PA-II the equivalent parameterization seems to be different: I need to set C_pa=1. / C_svc with low values of C_svc to observe the same kind of score correlations when I vary the train / test split distribution randomly.

So this might be a bug of the PA-II regularization.

Note that the digits dataset might not be noisy enough to have the regularization play an important role. Maybe we should find a better dataset to test those equivalences. If you are interested:

import numpy as np
from sklearn.linear_model import PassiveAggressiveClassifier
from sklearn.svm import LinearSVC
from sklearn.datasets import load_digits
from sklearn.cross_validation import train_test_split
digits = load_digits()

Then for L2 / PA-II:

rng = np.random.RandomState(None)
seed = rng.randint(1000)
C = 0.0000001
X_train, X_test, y_train, y_test = train_test_split(
    digits.data, digits.target, random_state=seed)
linear_svm = LinearSVC(C=C, loss='l2').fit(X_train, y_train)
linear_pa = PassiveAggressiveClassifier(C=1. / C, loss='pa2', random_state=seed, n_iter=5).fit(X_train, y_train)
print linear_svm.score(X_test, y_test), linear_pa.score(X_test, y_test)

and for L1 / PA-I:

rng = np.random.RandomState(None)
seed = rng.randint(1000)
C = 0.0000001
X_train, X_test, y_train, y_test = train_test_split(
    digits.data, digits.target, random_state=seed)
linear_svm = LinearSVC(C=C, loss='l1').fit(X_train, y_train)
linear_pa = PassiveAggressiveClassifier(C=C / X_train.shape[0], loss='pa1', random_state=seed, n_iter=5).fit(X_train, y_train)
print linear_svm.score(X_test, y_test), linear_pa.score(X_test, y_test)

@mblondel
Member

mblondel commented Nov 2, 2012

@zaxtax I will work on merging this PR over the week-end or at the beginning of next week. Could you push your latest local changes if you have any?

@ogrisel I will double-check the PA-II implementation but the different behavior is not totally unexpected.

@zaxtax
Contributor Author

zaxtax commented Nov 2, 2012

@mblondel Will do.

@mblondel mblondel merged commit f67e0d2 into scikit-learn:master Nov 5, 2012
@mblondel
Member

mblondel commented Nov 5, 2012

Merged.

Turns out the implementation was still incorrect (along with all the issues at the API level)...

@amueller
Member

amueller commented Nov 5, 2012

Err did you fix it yet?

@amueller
Member

amueller commented Nov 5, 2012

Or rather: are you planning to work on it?

@mblondel
Member

mblondel commented Nov 5, 2012

Everything is good now (except that I spent 5 hours on this PR T_T)

@GaelVaroquaux
Member

except that I spent 5 hours on this PR T_T

As often. It is my experience that quite often I pick up a PR that looks
finished (according to the author) and a close look in the code reveals
work to be done. I then decide to do it rather than bother the original
author; and eventually realize that I had grossly underestimated the
amount of time to spend on the PR. If we could find a way to avoid this
scenario, it would probably greatly decrease the number of PRs waiting to
be merged.

@amueller
Member

amueller commented Nov 5, 2012

Wow thanks a lot for putting all this work in to merge the PR! This is a great feature to have.
Also thanks a lot to @zaxtax for bearing with us and our nit-picking and API discussions :)

@amueller
Member

amueller commented Nov 5, 2012

@GaelVaroquaux how about Travis?
Passing tests is obviously not sufficient but necessary. And having broken tests will definitely alert the authors!
Also missing test coverage maybe (+pep8 + pyflakes)? These things could be easily checked and would help at least a bit.

@ogrisel
Member

ogrisel commented Nov 5, 2012

I am not sure that the class_weight reweighting of SGDClassifier is correct for PassiveAggressiveClassifier. Any opinion on this? Maybe we should remove the class_weight parameter for PassiveAggressiveClassifier unless someone has a theoretical proof or an empirical justification for it to stay.

Also, AFAIK the narrative documentation lacks the mathematical formulation of the objective functions being optimized by this family of models.

Next time @mblondel it would be better to open a new PR from your branch to master to allow others to follow and review the merge work.

@GaelVaroquaux
Member

@GaelVaroquaux how about Travis?

Do we have the time budget necessary? If so, it would be great.

Also missing test coverage maybe (+pep8 + pyflakes)? These things could be
easily checked and would help at least a bit.

Yes, it's often a lot of simple things.

@ogrisel
Member

ogrisel commented Nov 5, 2012

Would be great to set up Travis for scikit-learn but we need to find a way not to install scipy from source as part of the Travis build: it means finding a way to configure the Travis VM to use a binary package for numpy / scipy (e.g. the Ubuntu package if the Travis VM is an Ubuntu distribution).

@mblondel
Member

mblondel commented Nov 5, 2012

I am not sure that the class_weight reweighting of SGDClassifier is correct for PassiveAggressiveClassifier. Any opinion on this? Maybe we should remove the class_weight parameter for PassiveAggressiveClassifier unless someone has a theoretical proof or an empirical justification for it to stay.

My intuition is that it's ok (it downscales the step size) but you're right that there's no mathematical grounding for it. I'm ok with removing it but we need to remove sample_weight too then.

BTW, the paper has a section on cost sensitive learning but it's focused on multiclass classification.

Also, AFAIK the narrative documentation lacks the mathematical formulation of the objective functions being optimized by this family of models.

I removed it on purpose because I thought it didn't bring anything. You can add it back if you want.

Next time @mblondel it would be better to open a new PR from your branch to master to allow others to follow and review the merge work.

I should have but I have other stuff I want to do this week. At least now it is robustly tested (I implemented a correctness test like I did for the Perceptron). Let's do the final polishing in master.

@amueller
Member

amueller commented Nov 5, 2012

On 05.11.2012 at 16:33, Olivier Grisel wrote:

Would be great to set up Travis for scikit-learn but we need to find a
way not to install scipy from source as part of the Travis build: it
means finding a way to configure the Travis VM to use a binary package
for numpy / scipy (e.g. the Ubuntu package if the Travis VM is an Ubuntu
distribution).

scikit-image has a solution, I think. Forgot what it was, though.

@ogrisel
Member

ogrisel commented Nov 5, 2012

My intuition is that it's ok (it downscales the step size) but you're right that there's no mathematical grounding for it. I'm ok with removing it but we need to remove sample_weight too then.
BTW, the paper has a section on cost sensitive learning but it's focused on multiclass classification.

+0 for removing stuff that has neither theoretical grounding nor empirical validation that it works as expected.

Also, AFAIK the narrative documentation lacks the mathematical formulation of the objective functions being optimized by this family of models.
I removed it on purpose because I thought it didn't bring anything. You can add it back if you want.

Yes, also please include the equivalence with LinearSVC objective functions as done in your earlier comment.

@zaxtax
Contributor Author

zaxtax commented Nov 5, 2012

@amueller thanks, next time I hope to have enough of the work needed done ahead of time to not burden you guys

6 participants