@@ -195,6 +195,9 @@ class OneHotEncoder(_BaseEncoder):
195
195
- None : retain all features (the default).
196
196
- 'first' : drop the first category in each feature. If only one
197
197
category is present, the feature will be dropped entirely.
198
+ - 'if_binary' : drop the first category in each feature with two
199
+ categories. Features with 1 or more than 2 categories are
200
+ left intact.
198
201
- array : ``drop[i]`` is the category in feature ``X[:, i]`` that
199
202
should be dropped.
200
203
@@ -222,8 +225,12 @@ class OneHotEncoder(_BaseEncoder):
222
225
223
226
drop_idx_ : array of shape (n_features,)
224
227
``drop_idx_[i]`` is the index in ``categories_[i]`` of the category to
225
- be dropped for each feature. None if all the transformed features will
226
- be retained.
228
+ be dropped for each feature.
229
+ ``drop_idx_[i] = -1`` if no category is to be dropped from the feature
230
+ with index ``i``, e.g. when `drop='if_binary'` and the feature isn't
231
+ binary
232
+
233
+ ``drop_idx_ = None`` if all the transformed features will be retained.
227
234
228
235
See Also
229
236
--------
@@ -293,15 +300,28 @@ def _validate_keywords(self):
293
300
def _compute_drop_idx (self ):
294
301
if self .drop is None :
295
302
return None
296
- elif (isinstance (self .drop , str ) and self .drop == 'first' ):
297
- return np .zeros (len (self .categories_ ), dtype = np .int_ )
298
- elif not isinstance (self .drop , str ):
303
+ elif isinstance (self .drop , str ):
304
+ if self .drop == 'first' :
305
+ return np .zeros (len (self .categories_ ), dtype = np .int_ )
306
+ elif self .drop == 'if_binary' :
307
+ return np .array ([0 if len (cats ) == 2 else - 1
308
+ for cats in self .categories_ ], dtype = np .int_ )
309
+ else :
310
+ msg = (
311
+ "Wrong input for parameter `drop`. Expected "
312
+ "'first', 'if_binary', None or array of objects, got {}"
313
+ )
314
+ raise ValueError (msg .format (type (self .drop )))
315
+
316
+ else :
299
317
try :
300
318
self .drop = np .asarray (self .drop , dtype = object )
301
319
droplen = len (self .drop )
302
320
except (ValueError , TypeError ):
303
- msg = ("Wrong input for parameter `drop`. Expected "
304
- "'first', None or array of objects, got {}" )
321
+ msg = (
322
+ "Wrong input for parameter `drop`. Expected "
323
+ "'first', 'if_binary', None or array of objects, got {}"
324
+ )
305
325
raise ValueError (msg .format (type (self .drop )))
306
326
if droplen != len (self .categories_ ):
307
327
msg = ("`drop` should have length equal to the number "
@@ -321,10 +341,6 @@ def _compute_drop_idx(self):
321
341
return np .array ([np .where (cat_list == val )[0 ][0 ]
322
342
for (val , cat_list ) in
323
343
zip (self .drop , self .categories_ )], dtype = np .int_ )
324
- else :
325
- msg = ("Wrong input for parameter `drop`. Expected "
326
- "'first', None or array of objects, got {}" )
327
- raise ValueError (msg .format (type (self .drop )))
328
344
329
345
def fit (self , X , y = None ):
330
346
"""
@@ -392,15 +408,25 @@ def transform(self, X):
392
408
n_samples , n_features = X_int .shape
393
409
394
410
if self .drop is not None :
395
- to_drop = self .drop_idx_ .reshape (1 , - 1 )
396
-
411
+ to_drop = self .drop_idx_ .copy ()
397
412
# We remove all the dropped categories from mask, and decrement all
398
413
# categories that occur after them to avoid an empty column.
399
-
400
414
keep_cells = X_int != to_drop
401
- X_mask &= keep_cells
415
+ n_values = []
416
+ for i , cats in enumerate (self .categories_ ):
417
+ n_cats = len (cats )
418
+
419
+ # drop='if_binary' but feature isn't binary
420
+ if to_drop [i ] == - 1 :
421
+ # set to cardinality to not drop from X_int
422
+ to_drop [i ] = n_cats
423
+ n_values .append (n_cats )
424
+ else : # dropped
425
+ n_values .append (n_cats - 1 )
426
+
427
+ to_drop = to_drop .reshape (1 , - 1 )
402
428
X_int [X_int > to_drop ] -= 1
403
- n_values = [ len ( cats ) - 1 for cats in self . categories_ ]
429
+ X_mask &= keep_cells
404
430
else :
405
431
n_values = [len (cats ) for cats in self .categories_ ]
406
432
@@ -447,6 +473,10 @@ def inverse_transform(self, X):
447
473
if self .drop is None :
448
474
n_transformed_features = sum (len (cats )
449
475
for cats in self .categories_ )
476
+ elif isinstance (self .drop , str ) and self .drop == 'if_binary' :
477
+ n_transformed_features = sum (1 if len (cats ) == 2
478
+ else len (cats )
479
+ for cats in self .categories_ )
450
480
else :
451
481
n_transformed_features = sum (len (cats ) - 1
452
482
for cats in self .categories_ )
0 commit comments