Skip to content

Navigation Menu

Sign in
Appearance settings

Search code, repositories, users, issues, pull requests...

Provide feedback

We read every piece of feedback, and take your input very seriously.

Saved searches

Use saved searches to filter your results more quickly

Appearance settings

Commit 75aa035

Browse filesBrowse files
rprkhglemaitrejeremiedbb
authored
MAINT Parameter Validation for covariance.graphical_lasso (#25053)
Co-authored-by: Guillaume Lemaitre <g.lemaitre58@gmail.com> Co-authored-by: Jérémie du Boisberranger <34657725+jeremiedbb@users.noreply.github.com>
1 parent 6305fa4 commit 75aa035
Copy full SHA for 75aa035

File tree

2 files changed

+11
-1
lines changed
Filter options

2 files changed

+11
-1
lines changed

‎sklearn/covariance/_graph_lasso.py

Copy file name to clipboardExpand all lines: sklearn/covariance/_graph_lasso.py
+10-1Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@
2424
)
2525
from ..utils.parallel import delayed, Parallel
2626
from ..utils._param_validation import Interval, StrOptions
27+
from ..utils._param_validation import validate_params
2728

2829
# mypy error: Module 'sklearn.linear_model' has no attribute '_cd_fast'
2930
from ..linear_model import _cd_fast as cd_fast # type: ignore
@@ -211,6 +212,14 @@ def alpha_max(emp_cov):
211212
return np.max(np.abs(A))
212213

213214

215+
@validate_params(
216+
{
217+
"emp_cov": ["array-like"],
218+
"cov_init": ["array-like", None],
219+
"return_costs": ["boolean"],
220+
"return_n_iter": ["boolean"],
221+
}
222+
)
214223
def graphical_lasso(
215224
emp_cov,
216225
alpha,
@@ -234,7 +243,7 @@ def graphical_lasso(
234243
235244
Parameters
236245
----------
237-
emp_cov : ndarray of shape (n_features, n_features)
246+
emp_cov : array-like of shape (n_features, n_features)
238247
Empirical covariance from which to compute the covariance estimate.
239248
240249
alpha : float

‎sklearn/tests/test_public_functions.py

Copy file name to clipboardExpand all lines: sklearn/tests/test_public_functions.py
+1Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -280,6 +280,7 @@ def test_function_param_validation(func_module):
280280
("sklearn.cluster.affinity_propagation", "sklearn.cluster.AffinityPropagation"),
281281
("sklearn.cluster.mean_shift", "sklearn.cluster.MeanShift"),
282282
("sklearn.cluster.spectral_clustering", "sklearn.cluster.SpectralClustering"),
283+
("sklearn.covariance.graphical_lasso", "sklearn.covariance.GraphicalLasso"),
283284
("sklearn.covariance.ledoit_wolf", "sklearn.covariance.LedoitWolf"),
284285
("sklearn.covariance.oas", "sklearn.covariance.OAS"),
285286
("sklearn.decomposition.dict_learning", "sklearn.decomposition.DictionaryLearning"),

0 commit comments

Comments
0 (0)
Morty Proxy This is a proxified and sanitized view of the page, visit original site.