Skip to content

Navigation Menu

Sign in
Appearance settings

Search code, repositories, users, issues, pull requests...

Provide feedback

We read every piece of feedback, and take your input very seriously.

Saved searches

Use saved searches to filter your results more quickly

Appearance settings

Commit a88a460

Browse filesBrowse files
Merge pull request #81 from MatthewSZhang/missing-values
FEAT add mask_missing_values in utils
2 parents c939d8c + 4006c59 commit a88a460
Copy full SHA for a88a460

File tree

Expand file treeCollapse file tree

5 files changed

+426
-400
lines changed
Filter options
Expand file treeCollapse file tree

5 files changed

+426
-400
lines changed

‎fastcan/narx.py

Copy file name to clipboardExpand all lines: fastcan/narx.py
+9-18Lines changed: 9 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,7 @@
2727
from ._fastcan import FastCan
2828
from ._narx_fast import _predict_step, _update_cfd, _update_terms # type: ignore
2929
from ._refine import refine
30+
from .utils import mask_missing_values
3031

3132

3233
@validate_params(
@@ -273,14 +274,6 @@ def make_poly_ids(
273274
return np.delete(ids, const_id, 0) # remove the constant featrue
274275

275276

276-
def _mask_missing_value(*arr, return_mask=False):
277-
"""Remove missing value for all arrays."""
278-
mask_nomissing = np.all(np.isfinite(np.c_[arr]), axis=1)
279-
if return_mask:
280-
return mask_nomissing
281-
return tuple([x[mask_nomissing] for x in arr])
282-
283-
284277
def _valiate_time_shift_poly_ids(
285278
time_shift_ids, poly_ids, n_samples=None, n_features=None, n_outputs=None
286279
):
@@ -374,7 +367,7 @@ def _validate_feat_delay_ids(
374367
)
375368
if (delay_ids_.min() < -1) or (delay_ids_.max() >= n_samples):
376369
raise ValueError(
377-
"The element x of delay_ids should " f"satisfy -1 <= x < {n_samples}."
370+
f"The element x of delay_ids should satisfy -1 <= x < {n_samples}."
378371
)
379372
return feat_ids_, delay_ids_
380373

@@ -783,7 +776,7 @@ def fit(self, X, y, sample_weight=None, coef_init=None, **params):
783776
time_shift_vars = make_time_shift_features(xy_hstack, time_shift_ids)
784777
poly_terms = make_poly_features(time_shift_vars, poly_ids)
785778
# Remove missing values
786-
poly_terms_masked, y_masked, sample_weight_masked = _mask_missing_value(
779+
poly_terms_masked, y_masked, sample_weight_masked = mask_missing_values(
787780
poly_terms, y, sample_weight
788781
)
789782
coef = np.zeros(n_terms, dtype=float)
@@ -1060,7 +1053,7 @@ def _loss(
10601053
output_ids,
10611054
)
10621055

1063-
y_masked, y_hat_masked, sample_weight_sqrt_masked = _mask_missing_value(
1056+
y_masked, y_hat_masked, sample_weight_sqrt_masked = mask_missing_values(
10641057
y, y_hat, sample_weight_sqrt
10651058
)
10661059

@@ -1115,12 +1108,10 @@ def _grad(
11151108
grad_delay_ids,
11161109
)
11171110

1118-
mask_nomissing = _mask_missing_value(
1119-
y, y_hat, sample_weight_sqrt, return_mask=True
1120-
)
1111+
mask_valid = mask_missing_values(y, y_hat, sample_weight_sqrt, return_mask=True)
11211112

1122-
sample_weight_sqrt_masked = sample_weight_sqrt[mask_nomissing]
1123-
dydx_masked = dydx[mask_nomissing]
1113+
sample_weight_sqrt_masked = sample_weight_sqrt[mask_valid]
1114+
dydx_masked = dydx[mask_valid]
11241115

11251116
return dydx_masked.sum(axis=1) * sample_weight_sqrt_masked
11261117

@@ -1264,7 +1255,7 @@ def _get_term_str(term_feat_ids, term_delay_ids):
12641255
else:
12651256
term_str += f"*X[k-{delay_id},{feat_id}]"
12661257
elif feat_id >= narx.n_features_in_:
1267-
term_str += f"*y_hat[k-{delay_id},{feat_id-narx.n_features_in_}]"
1258+
term_str += f"*y_hat[k-{delay_id},{feat_id - narx.n_features_in_}]"
12681259
return term_str[1:]
12691260

12701261
yid_space = 5
@@ -1472,7 +1463,7 @@ def make_narx(
14721463
poly_terms = make_poly_features(time_shift_vars, poly_ids_all)
14731464

14741465
# Remove missing values
1475-
poly_terms_masked, y_masked = _mask_missing_value(poly_terms, y)
1466+
poly_terms_masked, y_masked = mask_missing_values(poly_terms, y)
14761467

14771468
selected_poly_ids = []
14781469
for i in range(n_outputs):

‎fastcan/utils.py

Copy file name to clipboardExpand all lines: fastcan/utils.py
+50-1Lines changed: 50 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@
77

88
import numpy as np
99
from sklearn.cross_decomposition import CCA
10-
from sklearn.utils import check_X_y
10+
from sklearn.utils import _safe_indexing, check_consistent_length, check_X_y
1111
from sklearn.utils._param_validation import Interval, validate_params
1212

1313

@@ -120,3 +120,52 @@ def ols(X, y, t=1):
120120
if not mask[j]:
121121
w[:, j] = w[:, j] - w[:, d] * (w[:, d] @ w[:, j])
122122
w[:, j] /= np.linalg.norm(w[:, j], axis=0)
123+
124+
125+
@validate_params(
126+
{
127+
"return_mask": ["boolean"],
128+
},
129+
prefer_skip_nested_validation=True,
130+
)
131+
def mask_missing_values(*arrays, return_mask=False):
132+
"""Remove missing values for all arrays.
133+
134+
Parameters
135+
----------
136+
*arrays : sequence of array-like of shape (n_samples,) or \
137+
(n_samples, n_outputs)
138+
Arrays with consistent first dimension.
139+
140+
return_mask : bool, default=False
141+
If True, return a mask of valid values.
142+
If False, return the arrays with missing values removed.
143+
144+
Returns
145+
-------
146+
mask_valid : ndarray of shape (n_samples,)
147+
Mask of valid values.
148+
149+
masked_arrays : sequence of array-like of shape (n_samples,) or \
150+
(n_samples, n_outputs)
151+
Arrays with missing values removed.
152+
The order of the arrays is the same as the input arrays.
153+
154+
Examples
155+
--------
156+
>>> import numpy as np
157+
>>> from fastcan.utils import mask_missing_values
158+
>>> a = [[1, 2], [3, np.nan], [5, 6]]
159+
>>> b = [1, 2, 3]
160+
>>> mask_missing_values(a, b)
161+
[[[1, 2], [5, 6]], [1, 3]]
162+
>>> mask_missing_values(a, b, return_mask=True)
163+
array([ True, False, True])
164+
"""
165+
if len(arrays) == 0:
166+
return None
167+
check_consistent_length(*arrays)
168+
mask_valid = np.all(np.isfinite(np.c_[arrays]), axis=1)
169+
if return_mask:
170+
return mask_valid
171+
return [_safe_indexing(x, mask_valid) for x in arrays]

0 commit comments

Comments
0 (0)
Morty Proxy This is a proxified and sanitized view of the page, visit original site.