@@ -168,7 +168,11 @@ def test_kmeans_elkan_results(distribution, array_constr, tol, global_random_see

     km_lloyd = KMeans(n_clusters=5, random_state=global_random_seed, n_init=1, tol=tol)
     km_elkan = KMeans(
-        algorithm="elkan", n_clusters=5, random_state=global_random_seed, n_init=1, tol=tol
+        algorithm="elkan",
+        n_clusters=5,
+        random_state=global_random_seed,
+        n_init=1,
+        tol=tol,
     )

     km_lloyd.fit(X)
@@ -333,7 +337,10 @@ def test_fortran_aligned_data(Estimator, global_random_seed):
         n_clusters=n_clusters, init=centers, n_init=1, random_state=global_random_seed
     ).fit(X)
     km_f = Estimator(
-        n_clusters=n_clusters, init=centers_fortran, n_init=1, random_state=global_random_seed
+        n_clusters=n_clusters,
+        init=centers_fortran,
+        n_init=1,
+        random_state=global_random_seed,
     ).fit(X_fortran)
     assert_allclose(km_c.cluster_centers_, km_f.cluster_centers_)
     assert_array_equal(km_c.labels_, km_f.labels_)
@@ -400,7 +407,9 @@ def test_minibatch_sensible_reassign(global_random_seed):
     # check that identical initial clusters are reassigned
     # also a regression test for when there are more desired reassignments than
     # samples.
-    zeroed_X, true_labels = make_blobs(n_samples=100, centers=5, random_state=global_random_seed)
+    zeroed_X, true_labels = make_blobs(
+        n_samples=100, centers=5, random_state=global_random_seed
+    )
     zeroed_X[::2, :] = 0

     km = MiniBatchKMeans(
@@ -626,10 +635,16 @@ def test_kmeans_predict(
 @pytest.mark.parametrize("Estimator", [KMeans, MiniBatchKMeans])
 def test_dense_sparse(Estimator, global_random_seed):
     # Check that the results are the same for dense and sparse input.
-    sample_weight = np.random.RandomState(global_random_seed).random_sample((n_samples,))
-    km_dense = Estimator(n_clusters=n_clusters, random_state=global_random_seed, n_init=1)
+    sample_weight = np.random.RandomState(global_random_seed).random_sample(
+        (n_samples,)
+    )
+    km_dense = Estimator(
+        n_clusters=n_clusters, random_state=global_random_seed, n_init=1
+    )
     km_dense.fit(X, sample_weight=sample_weight)
-    km_sparse = Estimator(n_clusters=n_clusters, random_state=global_random_seed, n_init=1)
+    km_sparse = Estimator(
+        n_clusters=n_clusters, random_state=global_random_seed, n_init=1
+    )
     km_sparse.fit(X_csr, sample_weight=sample_weight)

     assert_array_equal(km_dense.labels_, km_sparse.labels_)
@@ -774,7 +789,8 @@ def test_float_precision(Estimator, data, global_random_seed):
     assert_allclose(inertia[np.float32], inertia[np.float64], rtol=1e-4)
     assert_allclose(Xt[np.float32], Xt[np.float64], atol=Xt[np.float64].max() * 1e-4)
     assert_allclose(
-        centers[np.float32], centers[np.float64], atol=centers[np.float64].max() * 1e-4)
+        centers[np.float32], centers[np.float64], atol=centers[np.float64].max() * 1e-4
+    )
     assert_array_equal(labels[np.float32], labels[np.float64])


@@ -829,10 +845,14 @@ def test_weighted_vs_repeated(global_random_seed):
     # repetition of the sample. Valid only if init is precomputed, otherwise
     # rng produces different results. Not valid for MinibatchKMeans due to rng
     # to extract minibatches.
-    sample_weight = np.random.RandomState(global_random_seed).randint(1, 5, size=n_samples)
+    sample_weight = np.random.RandomState(global_random_seed).randint(
+        1, 5, size=n_samples
+    )
     X_repeat = np.repeat(X, sample_weight, axis=0)

-    km = KMeans(init=centers, n_init=1, n_clusters=n_clusters, random_state=global_random_seed)
+    km = KMeans(
+        init=centers, n_init=1, n_clusters=n_clusters, random_state=global_random_seed
+    )

     km_weighted = clone(km).fit(X, sample_weight=sample_weight)
     repeated_labels = np.repeat(km_weighted.labels_, sample_weight)
@@ -908,9 +928,17 @@ def test_result_equal_in_diff_n_threads(Estimator, global_random_seed):
     X = rnd.normal(size=(50, 10))

     with threadpool_limits(limits=1, user_api="openmp"):
-        result_1 = Estimator(n_clusters=n_clusters, random_state=global_random_seed).fit(X).labels_
+        result_1 = (
+            Estimator(n_clusters=n_clusters, random_state=global_random_seed)
+            .fit(X)
+            .labels_
+        )
     with threadpool_limits(limits=2, user_api="openmp"):
-        result_2 = Estimator(n_clusters=n_clusters, random_state=global_random_seed).fit(X).labels_
+        result_2 = (
+            Estimator(n_clusters=n_clusters, random_state=global_random_seed)
+            .fit(X)
+            .labels_
+        )
     assert_array_equal(result_1, result_2)


@@ -1119,7 +1147,9 @@ def test_kmeans_plusplus_wrong_params(param, match):
 def test_kmeans_plusplus_output(data, dtype, global_random_seed):
     # Check for the correct number of seeds and all positive values
     data = data.astype(dtype)
-    centers, indices = kmeans_plusplus(data, n_clusters, random_state=global_random_seed)
+    centers, indices = kmeans_plusplus(
+        data, n_clusters, random_state=global_random_seed
+    )

     # Check there are the correct number of indices and that all indices are
     # positive and within the number of samples
@@ -1152,7 +1182,9 @@ def test_kmeans_plusplus_dataorder(global_random_seed):

     X_fortran = np.asfortranarray(X)

-    centers_fortran, _ = kmeans_plusplus(X_fortran, n_clusters, random_state=global_random_seed)
+    centers_fortran, _ = kmeans_plusplus(
+        X_fortran, n_clusters, random_state=global_random_seed
+    )

     assert_allclose(centers_c, centers_fortran)

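Every hunk above applies the same mechanical change: a call whose single-line form exceeds the enforced line length (presumably black's default 88-character limit, which this repository appears to follow) is wrapped so its arguments move onto indented continuation lines. A minimal, self-contained sketch of the two wrapped forms, using illustrative values rather than the module-level test fixtures:

import numpy as np
from sklearn.cluster import KMeans, kmeans_plusplus

X = np.random.RandomState(0).normal(size=(100, 10))

# Form 1: the whole argument list still fits on one indented continuation line,
# so it is kept together and the closing parenthesis moves to its own line.
centers, indices = kmeans_plusplus(
    X, 5, random_state=0
)

# Form 2: the argument list is too long for a single continuation line, so each
# argument gets its own line, with a trailing comma after the last one.
km = KMeans(
    algorithm="elkan",
    n_clusters=5,
    random_state=0,
    n_init=1,
    tol=1e-4,
).fit(X)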