@@ -386,6 +386,22 @@ def _getTargetClass(self):
386
386
from gcloud .bigtable .row_data import PartialRowsData
387
387
return PartialRowsData
388
388
389
+ def _getDoNothingClass (self ):
390
+ klass = self ._getTargetClass ()
391
+
392
+ class FakePartialRowsData (klass ):
393
+
394
+ def __init__ (self , * args , ** kwargs ):
395
+ super (FakePartialRowsData , self ).__init__ (* args , ** kwargs )
396
+ self ._consumed = []
397
+
398
+ def consume_next (self ):
399
+ value = self ._response_iterator .next ()
400
+ self ._consumed .append (value )
401
+ return value
402
+
403
+ return FakePartialRowsData
404
+
389
405
def _makeOne (self , * args , ** kwargs ):
390
406
return self ._getTargetClass ()(* args , ** kwargs )
391
407
@@ -425,3 +441,84 @@ def test_rows_getter(self):
425
441
partial_rows_data = self ._makeOne (None )
426
442
partial_rows_data ._rows = value = object ()
427
443
self .assertTrue (partial_rows_data .rows is value )
444
+
445
+ def test_cancel (self ):
446
+ response_iterator = _MockCancellableIterator ()
447
+ partial_rows_data = self ._makeOne (response_iterator )
448
+ self .assertEqual (response_iterator .cancel_calls , 0 )
449
+ partial_rows_data .cancel ()
450
+ self .assertEqual (response_iterator .cancel_calls , 1 )
451
+
452
+ def test_consume_next (self ):
453
+ from gcloud .bigtable ._generated import (
454
+ bigtable_service_messages_pb2 as messages_pb2 )
455
+ from gcloud .bigtable .row_data import PartialRowData
456
+
457
+ row_key = b'row-key'
458
+ value_pb = messages_pb2 .ReadRowsResponse (row_key = row_key )
459
+ response_iterator = _MockCancellableIterator (value_pb )
460
+ partial_rows_data = self ._makeOne (response_iterator )
461
+ self .assertEqual (partial_rows_data .rows , {})
462
+ partial_rows_data .consume_next ()
463
+ expected_rows = {row_key : PartialRowData (row_key )}
464
+ self .assertEqual (partial_rows_data .rows , expected_rows )
465
+
466
+ def test_consume_next_row_exists (self ):
467
+ from gcloud .bigtable ._generated import (
468
+ bigtable_service_messages_pb2 as messages_pb2 )
469
+ from gcloud .bigtable .row_data import PartialRowData
470
+
471
+ row_key = b'row-key'
472
+ chunk = messages_pb2 .ReadRowsResponse .Chunk (commit_row = True )
473
+ value_pb = messages_pb2 .ReadRowsResponse (row_key = row_key ,
474
+ chunks = [chunk ])
475
+ response_iterator = _MockCancellableIterator (value_pb )
476
+ partial_rows_data = self ._makeOne (response_iterator )
477
+ existing_values = PartialRowData (row_key )
478
+ partial_rows_data ._rows [row_key ] = existing_values
479
+ self .assertFalse (existing_values .committed )
480
+ partial_rows_data .consume_next ()
481
+ self .assertTrue (existing_values .committed )
482
+ self .assertEqual (existing_values .cells , {})
483
+
484
+ def test_consume_next_empty_iter (self ):
485
+ response_iterator = _MockCancellableIterator ()
486
+ partial_rows_data = self ._makeOne (response_iterator )
487
+ with self .assertRaises (StopIteration ):
488
+ partial_rows_data .consume_next ()
489
+
490
+ def test_consume_all (self ):
491
+ klass = self ._getDoNothingClass ()
492
+
493
+ value1 , value2 , value3 = object (), object (), object ()
494
+ response_iterator = _MockCancellableIterator (value1 , value2 , value3 )
495
+ partial_rows_data = klass (response_iterator )
496
+ self .assertEqual (partial_rows_data ._consumed , [])
497
+ partial_rows_data .consume_all ()
498
+ self .assertEqual (partial_rows_data ._consumed , [value1 , value2 , value3 ])
499
+
500
+ def test_consume_all_with_max_loops (self ):
501
+ klass = self ._getDoNothingClass ()
502
+
503
+ value1 , value2 , value3 = object (), object (), object ()
504
+ response_iterator = _MockCancellableIterator (value1 , value2 , value3 )
505
+ partial_rows_data = klass (response_iterator )
506
+ self .assertEqual (partial_rows_data ._consumed , [])
507
+ partial_rows_data .consume_all (max_loops = 1 )
508
+ self .assertEqual (partial_rows_data ._consumed , [value1 ])
509
+ # Make sure the iterator still has the remaining values.
510
+ self .assertEqual (list (response_iterator .iter_values ), [value2 , value3 ])
511
+
512
+
513
+ class _MockCancellableIterator (object ):
514
+
515
+ cancel_calls = 0
516
+
517
+ def __init__ (self , * values ):
518
+ self .iter_values = iter (values )
519
+
520
+ def cancel (self ):
521
+ self .cancel_calls += 1
522
+
523
+ def next (self ):
524
+ return next (self .iter_values )
0 commit comments