Skip to content

Navigation Menu

Sign in
Appearance settings

Search code, repositories, users, issues, pull requests...

Provide feedback

We read every piece of feedback, and take your input very seriously.

Saved searches

Use saved searches to filter your results more quickly

Appearance settings
This repository was archived by the owner on May 7, 2026. It is now read-only.

Commit 95763ff

Browse filesBrowse files
perf: Avoid requery for some result downsample methods (#2219)
Co-authored-by: Chelsea Lin <chelsealin@google.com>
1 parent 0396278 commit 95763ff
Copy full SHA for 95763ff

6 files changed

+64-59Lines changed: 64 additions & 59 deletions

File tree

Expand file treeCollapse file tree
Open diff view settings
Filter options
Expand file treeCollapse file tree
Open diff view settings
Collapse file

‎bigframes/core/blocks.py‎

Copy file name to clipboardExpand all lines: bigframes/core/blocks.py
+23-42Lines changed: 23 additions & 42 deletions
Original file line numberDiff line numberDiff line change
@@ -818,49 +818,30 @@ def _materialize_local(
818818
total_rows = result_batches.approx_total_rows
819819
# Remove downsampling config from subsequent invocations, as otherwise could result in many
820820
# iterations if downsampling undershoots
821-
return self._downsample(
822-
total_rows=total_rows,
823-
sampling_method=sample_config.sampling_method,
824-
fraction=fraction,
825-
random_state=sample_config.random_state,
826-
)._materialize_local(
827-
MaterializationOptions(ordered=materialize_options.ordered)
828-
)
829-
else:
830-
df = result_batches.to_pandas()
831-
df = self._copy_index_to_pandas(df)
832-
df.set_axis(self.column_labels, axis=1, copy=False)
833-
return df, execute_result.query_job
834-
835-
def _downsample(
836-
self, total_rows: int, sampling_method: str, fraction: float, random_state
837-
) -> Block:
838-
# either selecting fraction or number of rows
839-
if sampling_method == _HEAD:
840-
filtered_block = self.slice(stop=int(total_rows * fraction))
841-
return filtered_block
842-
elif (sampling_method == _UNIFORM) and (random_state is None):
843-
filtered_expr = self.expr._uniform_sampling(fraction)
844-
block = Block(
845-
filtered_expr,
846-
index_columns=self.index_columns,
847-
column_labels=self.column_labels,
848-
index_labels=self.index.names,
849-
)
850-
return block
851-
elif sampling_method == _UNIFORM:
852-
block = self.split(
853-
fracs=(fraction,),
854-
random_state=random_state,
855-
sort=False,
856-
)[0]
857-
return block
821+
if sample_config.sampling_method == "head":
822+
# Just truncates the result iterator without a follow-up query
823+
raw_df = result_batches.to_pandas(limit=int(total_rows * fraction))
824+
elif (
825+
sample_config.sampling_method == "uniform"
826+
and sample_config.random_state is None
827+
):
828+
# Pushes sample into result without new query
829+
sampled_batches = execute_result.batches(sample_rate=fraction)
830+
raw_df = sampled_batches.to_pandas()
831+
else: # uniform sample with random state requires a full follow-up query
832+
down_sampled_block = self.split(
833+
fracs=(fraction,),
834+
random_state=sample_config.random_state,
835+
sort=False,
836+
)[0]
837+
return down_sampled_block._materialize_local(
838+
MaterializationOptions(ordered=materialize_options.ordered)
839+
)
858840
else:
859-
# This part should never be called, just in case.
860-
raise NotImplementedError(
861-
f"The downsampling method {sampling_method} is not implemented, "
862-
f"please choose from {','.join(_SAMPLING_METHODS)}."
863-
)
841+
raw_df = result_batches.to_pandas()
842+
df = self._copy_index_to_pandas(raw_df)
843+
df.set_axis(self.column_labels, axis=1, copy=False)
844+
return df, execute_result.query_job
864845

865846
def split(
866847
self,
Collapse file

‎bigframes/core/bq_data.py‎

Copy file name to clipboardExpand all lines: bigframes/core/bq_data.py
+12-1Lines changed: 12 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -186,11 +186,22 @@ def get_arrow_batches(
186186
columns: Sequence[str],
187187
storage_read_client: bigquery_storage_v1.BigQueryReadClient,
188188
project_id: str,
189+
sample_rate: Optional[float] = None,
189190
) -> ReadResult:
190191
table_mod_options = {}
191192
read_options_dict: dict[str, Any] = {"selected_fields": list(columns)}
193+
194+
predicates = []
192195
if data.sql_predicate:
193-
read_options_dict["row_restriction"] = data.sql_predicate
196+
predicates.append(data.sql_predicate)
197+
if sample_rate is not None:
198+
assert isinstance(sample_rate, float)
199+
predicates.append(f"RAND() < {sample_rate}")
200+
201+
if predicates:
202+
full_predicates = " AND ".join(f"( {pred} )" for pred in predicates)
203+
read_options_dict["row_restriction"] = full_predicates
204+
194205
read_options = bq_storage_types.ReadSession.TableReadOptions(**read_options_dict)
195206

196207
if data.at_time:
Collapse file

‎bigframes/core/local_data.py‎

Copy file name to clipboardExpand all lines: bigframes/core/local_data.py
+10-1Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@
2525
import uuid
2626

2727
import geopandas # type: ignore
28+
import numpy
2829
import numpy as np
2930
import pandas as pd
3031
import pyarrow as pa
@@ -124,13 +125,21 @@ def to_arrow(
124125
geo_format: Literal["wkb", "wkt"] = "wkt",
125126
duration_type: Literal["int", "duration"] = "duration",
126127
json_type: Literal["string"] = "string",
128+
sample_rate: Optional[float] = None,
127129
max_chunksize: Optional[int] = None,
128130
) -> tuple[pa.Schema, Iterable[pa.RecordBatch]]:
129131
if geo_format != "wkt":
130132
raise NotImplementedError(f"geo format {geo_format} not yet implemented")
131133
assert json_type == "string"
132134

133-
batches = self.data.to_batches(max_chunksize=max_chunksize)
135+
data = self.data
136+
137+
# This exists for symmetry with remote sources, but sampling local data like this shouldn't really happen
138+
if sample_rate is not None:
139+
to_take = numpy.random.rand(data.num_rows) < sample_rate
140+
data = data.filter(to_take)
141+
142+
batches = data.to_batches(max_chunksize=max_chunksize)
134143
schema = self.data.schema
135144
if duration_type == "int":
136145
schema = _schema_durations_to_ints(schema)
Collapse file

‎bigframes/session/executor.py‎

Copy file name to clipboardExpand all lines: bigframes/session/executor.py
+15-11Lines changed: 15 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -88,7 +88,7 @@ def arrow_batches(self) -> Iterator[pyarrow.RecordBatch]:
8888

8989
yield batch
9090

91-
def to_arrow_table(self) -> pyarrow.Table:
91+
def to_arrow_table(self, limit: Optional[int] = None) -> pyarrow.Table:
9292
# Need to provide schema if no result rows, as arrow can't infer
9393
# If ther are rows, it is safest to infer schema from batches.
9494
# Any discrepencies between predicted schema and actual schema will produce errors.
@@ -97,18 +97,21 @@ def to_arrow_table(self) -> pyarrow.Table:
9797
peek_value = list(peek_it)
9898
# TODO: Enforce our internal schema on the table for consistency
9999
if len(peek_value) > 0:
100-
return pyarrow.Table.from_batches(
101-
itertools.chain(peek_value, batches), # reconstruct
102-
)
100+
batches = itertools.chain(peek_value, batches) # reconstruct
101+
if limit:
102+
batches = pyarrow_utils.truncate_pyarrow_iterable(
103+
batches, max_results=limit
104+
)
105+
return pyarrow.Table.from_batches(batches)
103106
else:
104107
try:
105108
return self._schema.to_pyarrow().empty_table()
106109
except pa.ArrowNotImplementedError:
107110
# Bug with some pyarrow versions, empty_table only supports base storage types, not extension types.
108111
return self._schema.to_pyarrow(use_storage_types=True).empty_table()
109112

110-
def to_pandas(self) -> pd.DataFrame:
111-
return io_pandas.arrow_to_pandas(self.to_arrow_table(), self._schema)
113+
def to_pandas(self, limit: Optional[int] = None) -> pd.DataFrame:
114+
return io_pandas.arrow_to_pandas(self.to_arrow_table(limit=limit), self._schema)
112115

113116
def to_pandas_batches(
114117
self, page_size: Optional[int] = None, max_results: Optional[int] = None
@@ -158,7 +161,7 @@ def schema(self) -> bigframes.core.schema.ArraySchema:
158161
...
159162

160163
@abc.abstractmethod
161-
def batches(self) -> ResultsIterator:
164+
def batches(self, sample_rate: Optional[float] = None) -> ResultsIterator:
162165
...
163166

164167
@property
@@ -200,9 +203,9 @@ def execution_metadata(self) -> ExecutionMetadata:
200203
def schema(self) -> bigframes.core.schema.ArraySchema:
201204
return self._data.schema
202205

203-
def batches(self) -> ResultsIterator:
206+
def batches(self, sample_rate: Optional[float] = None) -> ResultsIterator:
204207
return ResultsIterator(
205-
iter(self._data.to_arrow()[1]),
208+
iter(self._data.to_arrow(sample_rate=sample_rate)[1]),
206209
self.schema,
207210
self._data.metadata.row_count,
208211
self._data.metadata.total_bytes,
@@ -226,7 +229,7 @@ def execution_metadata(self) -> ExecutionMetadata:
226229
def schema(self) -> bigframes.core.schema.ArraySchema:
227230
return self._schema
228231

229-
def batches(self) -> ResultsIterator:
232+
def batches(self, sample_rate: Optional[float] = None) -> ResultsIterator:
230233
return ResultsIterator(iter([]), self.schema, 0, 0)
231234

232235

@@ -260,12 +263,13 @@ def schema(self) -> bigframes.core.schema.ArraySchema:
260263
source_ids = [selection[0] for selection in self._selected_fields]
261264
return self._data.schema.select(source_ids).rename(dict(self._selected_fields))
262265

263-
def batches(self) -> ResultsIterator:
266+
def batches(self, sample_rate: Optional[float] = None) -> ResultsIterator:
264267
read_batches = bq_data.get_arrow_batches(
265268
self._data,
266269
[x[0] for x in self._selected_fields],
267270
self._storage_client,
268271
self._project_id,
272+
sample_rate=sample_rate,
269273
)
270274
arrow_batches: Iterator[pa.RecordBatch] = map(
271275
functools.partial(
Collapse file

‎tests/system/small/test_anywidget.py‎

Copy file name to clipboardExpand all lines: tests/system/small/test_anywidget.py
+1-1Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -165,7 +165,7 @@ def execution_metadata(self) -> ExecutionMetadata:
165165
def schema(self) -> Any:
166166
return schema
167167

168-
def batches(self) -> ResultsIterator:
168+
def batches(self, sample_rate=None) -> ResultsIterator:
169169
return ResultsIterator(
170170
arrow_batches_val,
171171
self.schema,
Collapse file

‎tests/system/small/test_dataframe.py‎

Copy file name to clipboardExpand all lines: tests/system/small/test_dataframe.py
+3-3Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -4524,7 +4524,7 @@ def test_df_kurt(scalars_dfs):
45244524
"n_default",
45254525
],
45264526
)
4527-
def test_sample(scalars_dfs, frac, n, random_state):
4527+
def test_df_to_pandas_sample(scalars_dfs, frac, n, random_state):
45284528
scalars_df, _ = scalars_dfs
45294529
df = scalars_df.sample(frac=frac, n=n, random_state=random_state)
45304530
bf_result = df.to_pandas()
@@ -4535,15 +4535,15 @@ def test_sample(scalars_dfs, frac, n, random_state):
45354535
assert bf_result.shape[1] == scalars_df.shape[1]
45364536

45374537

4538-
def test_sample_determinism(penguins_df_default_index):
4538+
def test_df_to_pandas_sample_determinism(penguins_df_default_index):
45394539
df = penguins_df_default_index.sample(n=100, random_state=12345).head(15)
45404540
bf_result = df.to_pandas()
45414541
bf_result2 = df.to_pandas()
45424542

45434543
pandas.testing.assert_frame_equal(bf_result, bf_result2)
45444544

45454545

4546-
def test_sample_raises_value_error(scalars_dfs):
4546+
def test_df_to_pandas_sample_raises_value_error(scalars_dfs):
45474547
scalars_df, _ = scalars_dfs
45484548
with pytest.raises(
45494549
ValueError, match="Only one of 'n' or 'frac' parameter can be specified."

0 commit comments

Comments
0 (0)
Morty Proxy This is a proxified and sanitized view of the page, visit original site.