27
27
from ._fastcan import FastCan
28
28
from ._narx_fast import _predict_step , _update_cfd , _update_terms # type: ignore
29
29
from ._refine import refine
30
+ from .utils import mask_missing_values
30
31
31
32
32
33
@validate_params (
@@ -273,14 +274,6 @@ def make_poly_ids(
273
274
return np .delete (ids , const_id , 0 ) # remove the constant feature
274
275
275
276
276
- def _mask_missing_value (* arr , return_mask = False ):
277
- """Remove missing value for all arrays."""
278
- mask_nomissing = np .all (np .isfinite (np .c_ [arr ]), axis = 1 )
279
- if return_mask :
280
- return mask_nomissing
281
- return tuple ([x [mask_nomissing ] for x in arr ])
282
-
283
-
284
277
def _valiate_time_shift_poly_ids (
285
278
time_shift_ids , poly_ids , n_samples = None , n_features = None , n_outputs = None
286
279
):
@@ -374,7 +367,7 @@ def _validate_feat_delay_ids(
374
367
)
375
368
if (delay_ids_ .min () < - 1 ) or (delay_ids_ .max () >= n_samples ):
376
369
raise ValueError (
377
- "The element x of delay_ids should " f" satisfy -1 <= x < { n_samples } ."
370
+ f "The element x of delay_ids should satisfy -1 <= x < { n_samples } ."
378
371
)
379
372
return feat_ids_ , delay_ids_
380
373
@@ -783,7 +776,7 @@ def fit(self, X, y, sample_weight=None, coef_init=None, **params):
783
776
time_shift_vars = make_time_shift_features (xy_hstack , time_shift_ids )
784
777
poly_terms = make_poly_features (time_shift_vars , poly_ids )
785
778
# Remove missing values
786
- poly_terms_masked , y_masked , sample_weight_masked = _mask_missing_value (
779
+ poly_terms_masked , y_masked , sample_weight_masked = mask_missing_values (
787
780
poly_terms , y , sample_weight
788
781
)
789
782
coef = np .zeros (n_terms , dtype = float )
@@ -1060,7 +1053,7 @@ def _loss(
1060
1053
output_ids ,
1061
1054
)
1062
1055
1063
- y_masked , y_hat_masked , sample_weight_sqrt_masked = _mask_missing_value (
1056
+ y_masked , y_hat_masked , sample_weight_sqrt_masked = mask_missing_values (
1064
1057
y , y_hat , sample_weight_sqrt
1065
1058
)
1066
1059
@@ -1115,12 +1108,10 @@ def _grad(
1115
1108
grad_delay_ids ,
1116
1109
)
1117
1110
1118
- mask_nomissing = _mask_missing_value (
1119
- y , y_hat , sample_weight_sqrt , return_mask = True
1120
- )
1111
+ mask_valid = mask_missing_values (y , y_hat , sample_weight_sqrt , return_mask = True )
1121
1112
1122
- sample_weight_sqrt_masked = sample_weight_sqrt [mask_nomissing ]
1123
- dydx_masked = dydx [mask_nomissing ]
1113
+ sample_weight_sqrt_masked = sample_weight_sqrt [mask_valid ]
1114
+ dydx_masked = dydx [mask_valid ]
1124
1115
1125
1116
return dydx_masked .sum (axis = 1 ) * sample_weight_sqrt_masked
1126
1117
@@ -1264,7 +1255,7 @@ def _get_term_str(term_feat_ids, term_delay_ids):
1264
1255
else :
1265
1256
term_str += f"*X[k-{ delay_id } ,{ feat_id } ]"
1266
1257
elif feat_id >= narx .n_features_in_ :
1267
- term_str += f"*y_hat[k-{ delay_id } ,{ feat_id - narx .n_features_in_ } ]"
1258
+ term_str += f"*y_hat[k-{ delay_id } ,{ feat_id - narx .n_features_in_ } ]"
1268
1259
return term_str [1 :]
1269
1260
1270
1261
yid_space = 5
@@ -1472,7 +1463,7 @@ def make_narx(
1472
1463
poly_terms = make_poly_features (time_shift_vars , poly_ids_all )
1473
1464
1474
1465
# Remove missing values
1475
- poly_terms_masked , y_masked = _mask_missing_value (poly_terms , y )
1466
+ poly_terms_masked , y_masked = mask_missing_values (poly_terms , y )
1476
1467
1477
1468
selected_poly_ids = []
1478
1469
for i in range (n_outputs ):
0 commit comments