diff --git a/bigframes/dataframe.py b/bigframes/dataframe.py index 48c4af7a37..092c8ab82f 100644 --- a/bigframes/dataframe.py +++ b/bigframes/dataframe.py @@ -3397,6 +3397,17 @@ def _set_block(self, block: blocks.Block): def _get_block(self) -> blocks.Block: return self._block + def cache(self): + """ + Materializes the DataFrame to a temporary table. + + Useful if the dataframe will be used multiple times, as this will avoid recomputating the shared intermediate value. + + Returns: + DataFrame: Self + """ + return self._cached(force=True) + def _cached(self, *, force: bool = False) -> DataFrame: """Materialize dataframe to a temporary table. No-op if the dataframe represents a trivial transformation of an existing materialization. diff --git a/bigframes/ml/core.py b/bigframes/ml/core.py index 12c881c19a..7b4638157e 100644 --- a/bigframes/ml/core.py +++ b/bigframes/ml/core.py @@ -83,7 +83,7 @@ def distance( """ assert len(x.columns) == 1 and len(y.columns) == 1 - input_data = x._cached().join(y._cached(), how="outer") + input_data = x.cache().join(y.cache(), how="outer") x_column_id, y_column_id = x._block.value_columns[0], y._block.value_columns[0] return self._apply_sql( @@ -310,11 +310,9 @@ def create_model( # Cache dataframes to make sure base table is not a snapshot # cached dataframe creates a full copy, never uses snapshot if y_train is None: - input_data = X_train._cached(force=True) + input_data = X_train.cache() else: - input_data = X_train._cached(force=True).join( - y_train._cached(force=True), how="outer" - ) + input_data = X_train.cache().join(y_train.cache(), how="outer") options.update({"INPUT_LABEL_COLS": y_train.columns.tolist()}) session = X_train._session @@ -354,9 +352,7 @@ def create_llm_remote_model( options = dict(options) # Cache dataframes to make sure base table is not a snapshot # cached dataframe creates a full copy, never uses snapshot - input_data = X_train._cached(force=True).join( - y_train._cached(force=True), how="outer" - ) + input_data = X_train.cache().join(y_train.cache(), how="outer") options.update({"INPUT_LABEL_COLS": y_train.columns.tolist()}) session = X_train._session @@ -389,9 +385,7 @@ def create_time_series_model( options = dict(options) # Cache dataframes to make sure base table is not a snapshot # cached dataframe creates a full copy, never uses snapshot - input_data = X_train._cached(force=True).join( - y_train._cached(force=True), how="outer" - ) + input_data = X_train.cache().join(y_train.cache(), how="outer") options.update({"TIME_SERIES_TIMESTAMP_COL": X_train.columns.tolist()[0]}) options.update({"TIME_SERIES_DATA_COL": y_train.columns.tolist()[0]}) diff --git a/bigframes/series.py b/bigframes/series.py index 5184d4bf1d..3986d38445 100644 --- a/bigframes/series.py +++ b/bigframes/series.py @@ -1682,6 +1682,17 @@ def _slice( ), ) + def cache(self): + """ + Materializes the Series to a temporary table. + + Useful if the series will be used multiple times, as this will avoid recomputating the shared intermediate value. + + Returns: + Series: Self + """ + return self._cached(force=True) + def _cached(self, *, force: bool = True) -> Series: self._set_block(self._block.cached(force=force)) return self diff --git a/tests/system/small/test_dataframe.py b/tests/system/small/test_dataframe.py index 2a4b53403d..b428207314 100644 --- a/tests/system/small/test_dataframe.py +++ b/tests/system/small/test_dataframe.py @@ -4204,7 +4204,7 @@ def test_df_cached(scalars_df_index): ) df = df[df["rowindex_2"] % 2 == 0] - df_cached_copy = df._cached() + df_cached_copy = df.cache() pandas.testing.assert_frame_equal(df.to_pandas(), df_cached_copy.to_pandas()) diff --git a/tests/unit/ml/test_golden_sql.py b/tests/unit/ml/test_golden_sql.py index bcb220b107..48fb7011ea 100644 --- a/tests/unit/ml/test_golden_sql.py +++ b/tests/unit/ml/test_golden_sql.py @@ -63,7 +63,7 @@ def bqml_model_factory(mocker: pytest_mock.MockerFixture): def mock_y(): mock_y = mock.create_autospec(spec=bpd.DataFrame) mock_y.columns = pd.Index(["input_column_label"]) - mock_y._cached.return_value = mock_y + mock_y.cache.return_value = mock_y return mock_y @@ -83,7 +83,7 @@ def mock_X(mock_y, mock_session): ["index_column_id"], ["index_column_label"], ) - mock_X._cached.return_value = mock_X + mock_X.cache.return_value = mock_X return mock_X