From 8bc1cca3f09ce09f80049864e39380447c963006 Mon Sep 17 00:00:00 2001 From: Daniela Date: Mon, 11 Nov 2024 21:59:44 +0000 Subject: [PATCH 01/18] feat: create arima_plus_predict_attribution method --- bigframes/ml/core.py | 4 +++ bigframes/ml/forecasting.py | 38 +++++++++++++++++++++++ bigframes/ml/sql.py | 6 ++++ tests/system/small/ml/test_forecasting.py | 29 +++++++++++++++++ 4 files changed, 77 insertions(+) diff --git a/bigframes/ml/core.py b/bigframes/ml/core.py index 4bc61c5015..af5d5b70db 100644 --- a/bigframes/ml/core.py +++ b/bigframes/ml/core.py @@ -172,6 +172,10 @@ def forecast(self, options: Mapping[str, int | float]) -> bpd.DataFrame: sql = self._model_manipulation_sql_generator.ml_forecast(struct_options=options) return self._session.read_gbq(sql, index_col="forecast_timestamp").reset_index() + def explain_forecast(self, options: Mapping[str, int | float]) -> bpd.DataFrame: + sql = self._model_manipulation_sql_generator.ml_explain_forecast(struct_options=options) + return self._session.read_gbq(sql, index_col="time_series_timestamp").reset_index() + def evaluate(self, input_data: Optional[bpd.DataFrame] = None): sql = self._model_manipulation_sql_generator.ml_evaluate( input_data.sql if (input_data is not None) else None diff --git a/bigframes/ml/forecasting.py b/bigframes/ml/forecasting.py index d27801caa3..818b18a39a 100644 --- a/bigframes/ml/forecasting.py +++ b/bigframes/ml/forecasting.py @@ -252,6 +252,44 @@ def predict( return self._bqml_model.forecast( options={"horizon": horizon, "confidence_level": confidence_level} ) + + def predict_attribution( + self, X=None, *, horizon: int = 3, confidence_level: float = 0.95 + ) -> bpd.DataFrame: + """Forecast time series at future horizon. + + .. note:: + + Output matches that of the BigQuery ML.FORECAST function. + See: https://cloud.google.com/bigquery/docs/reference/standard-sql/bigqueryml-syntax-forecast + + Args: + X (default None): + ignored, to be compatible with other APIs. + horizon (int, default: 3): + an int value that specifies the number of time points to forecast. + The default value is 3, and the maximum value is 1000. + confidence_level (float, default 0.95): + A float value that specifies percentage of the future values that fall in the prediction interval. + The valid input range is [0.0, 1.0). + + Returns: + bigframes.dataframe.DataFrame: The predicted DataFrames. Which + contains 2 columns: "forecast_timestamp" and "forecast_value". + """ + if horizon < 1 or horizon > 1000: + raise ValueError(f"horizon must be [1, 1000], but is {horizon}.") + if confidence_level < 0.0 or confidence_level >= 1.0: + raise ValueError( + f"confidence_level must be [0.0, 1.0), but is {confidence_level}." + ) + + if not self._bqml_model: + raise RuntimeError("A model must be fitted before predict") + + return self._bqml_model.explain_forecast( + options={"horizon": horizon, "confidence_level": confidence_level} + ) @property def coef_( diff --git a/bigframes/ml/sql.py b/bigframes/ml/sql.py index b7d550ac63..02ebd97ef1 100644 --- a/bigframes/ml/sql.py +++ b/bigframes/ml/sql.py @@ -309,6 +309,12 @@ def ml_forecast(self, struct_options: Mapping[str, Union[int, float]]) -> str: struct_options_sql = self.struct_options(**struct_options) return f"""SELECT * FROM ML.FORECAST(MODEL {self._model_ref_sql()}, {struct_options_sql})""" + + def ml_explain_forecast(self, struct_options: Mapping[str, Union[int, float]]) -> str: + """Encode ML.EXPLAIN_FORECAST for BQML""" + struct_options_sql = self.struct_options(**struct_options) + return f"""SELECT * FROM ML.EXPLAIN_FORECAST(MODEL {self._model_ref_sql()}, + {struct_options_sql})""" def ml_generate_text( self, source_sql: str, struct_options: Mapping[str, Union[int, float]] diff --git a/tests/system/small/ml/test_forecasting.py b/tests/system/small/ml/test_forecasting.py index 7fef189550..7ea16396c0 100644 --- a/tests/system/small/ml/test_forecasting.py +++ b/tests/system/small/ml/test_forecasting.py @@ -64,6 +64,35 @@ def test_arima_plus_predict_default( check_index_type=False, ) +def test_arima_plus_predict_attribution_default( + time_series_arima_plus_model: forecasting.ARIMAPlus, +): + utc = pytz.utc + predictions = time_series_arima_plus_model.predict_attribution().to_pandas() + assert predictions.shape == (3, 8) + result = predictions[["forecast_timestamp", "forecast_value"]] + expected = pd.DataFrame( + { + "forecast_timestamp": [ + datetime(2017, 8, 2, tzinfo=utc), + datetime(2017, 8, 3, tzinfo=utc), + datetime(2017, 8, 4, tzinfo=utc), + ], + "forecast_value": [2724.472284, 2593.368389, 2353.613034], + } + ) + expected["forecast_value"] = expected["forecast_value"].astype(pd.Float64Dtype()) + expected["forecast_timestamp"] = expected["forecast_timestamp"].astype( + pd.ArrowDtype(pa.timestamp("us", tz="UTC")) + ) + + pd.testing.assert_frame_equal( + result, + expected, + rtol=0.1, + check_index_type=False, + ) + def test_arima_plus_predict_params(time_series_arima_plus_model: forecasting.ARIMAPlus): utc = pytz.utc From f6dd455893e9bcd9612ba6e31c81de6166b72c7e Mon Sep 17 00:00:00 2001 From: Chelsea Lin Date: Tue, 12 Nov 2024 21:56:55 +0000 Subject: [PATCH 02/18] tmp: debug notes for time_series_arima_plus_model.predict_attribution --- notebooks/debug.ipynb | 1093 +++++++++++++++++++++++++++++++++++++++++ 1 file changed, 1093 insertions(+) create mode 100644 notebooks/debug.ipynb diff --git a/notebooks/debug.ipynb b/notebooks/debug.ipynb new file mode 100644 index 0000000000..88d5557e66 --- /dev/null +++ b/notebooks/debug.ipynb @@ -0,0 +1,1093 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# Creating ARIMAPlus forcasting model" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## load_test_data_tables in tests/system/conftest.py" + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "[SchemaField('parsed_date', 'TIMESTAMP', 'NULLABLE', None, None, (), None), SchemaField('total_visits', 'INTEGER', 'NULLABLE', None, None, (), None)]\n", + "/usr/local/google/home/chelsealin/src/bigframes/tests/data/time_series.jsonl\n" + ] + }, + { + "data": { + "text/plain": [ + "LoadJob" + ] + }, + "execution_count": 3, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "import google.cloud.bigquery as bigquery\n", + "\n", + "DATA_DIR=\"/usr/local/google/home/chelsealin/src/bigframes/tests/data/\"\n", + "schema_filename=DATA_DIR + \"time_series_schema.json\"\n", + "data_filename=DATA_DIR + \"time_series.jsonl\"\n", + "\n", + "time_series_table_id='bigframes-dev.chelsealin.time_series_0'\n", + "\n", + "client = bigquery.Client(project='bigframes-dev')\n", + "\n", + "job_config = bigquery.LoadJobConfig()\n", + "job_config.source_format = bigquery.SourceFormat.NEWLINE_DELIMITED_JSON\n", + "job_config.schema = tuple(\n", + " client.schema_from_json(schema_filename)\n", + ")\n", + "print(job_config.schema)\n", + "job_config.write_disposition = bigquery.WriteDisposition.WRITE_TRUNCATE\n", + "\n", + "with open(data_filename, \"rb\") as input_file:\n", + " print(data_filename)\n", + " job = client.load_table_from_file(\n", + " input_file,\n", + " time_series_table_id,\n", + " job_config=job_config,\n", + " )\n", + "job.result()" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## time_series_arima_plus_model in tests/system/conftest.py: " + ] + }, + { + "cell_type": "code", + "execution_count": 6, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "" + ] + }, + "execution_count": 6, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "import bigframes.pandas as bpd\n", + "\n", + "time_series_arima_plus_model_name = \"bigframes-dev.chelsealin.time_series_arima_plus_0\"\n", + "sql = f\"\"\"\n", + "CREATE OR REPLACE MODEL `{time_series_arima_plus_model_name}`\n", + "OPTIONS (\n", + " model_type='ARIMA_PLUS',\n", + " time_series_timestamp_col = 'parsed_date',\n", + " time_series_data_col = 'total_visits'\n", + ") AS SELECT\n", + " *\n", + "FROM `{time_series_table_id}`\"\"\"\n", + "\n", + "client.query(sql).result()" + ] + }, + { + "cell_type": "code", + "execution_count": 7, + "metadata": {}, + "outputs": [], + "source": [ + "time_series_arima_plus_model = bpd.read_gbq_model(time_series_arima_plus_model_name)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## test_arima_plus_predict_attribution_default" + ] + }, + { + "cell_type": "code", + "execution_count": 8, + "metadata": {}, + "outputs": [ + { + "data": { + "text/html": [ + "Query job f525fa55-228c-45b6-8ea4-eb4700513538 is RUNNING. Open Job" + ], + "text/plain": [ + "" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "Query job d8e97c97-e3e3-47ff-8c40-53f3eb100069 is DONE. 43.7 kB processed. Open Job" + ], + "text/plain": [ + "" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "Query job 6fb13978-b613-417b-b79b-86a8c100036d is DONE. 0 Bytes processed. Open Job" + ], + "text/plain": [ + "" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
\n", + "\n", + "\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "
forecast_timestampforecast_valuestandard_errorconfidence_levelprediction_interval_lower_boundprediction_interval_upper_boundconfidence_interval_lower_boundconfidence_interval_upper_bound
02017-08-02 00:00:00+00:002727.693349190.6147780.952354.7635253100.6231742354.7635253100.623174
12017-08-03 00:00:00+00:002595.290749222.0933380.952160.7744413029.8070562160.7744413029.807056
22017-08-04 00:00:00+00:002370.86767255.6459420.951870.7070972871.0282421870.7070972871.028242
\n", + "

3 rows × 8 columns

\n", + "
[3 rows x 8 columns in total]" + ], + "text/plain": [ + " forecast_timestamp forecast_value standard_error \\\n", + "0 2017-08-02 00:00:00+00:00 2727.693349 190.614778 \n", + "1 2017-08-03 00:00:00+00:00 2595.290749 222.093338 \n", + "2 2017-08-04 00:00:00+00:00 2370.86767 255.645942 \n", + "\n", + " confidence_level prediction_interval_lower_bound \\\n", + "0 0.95 2354.763525 \n", + "1 0.95 2160.774441 \n", + "2 0.95 1870.707097 \n", + "\n", + " prediction_interval_upper_bound confidence_interval_lower_bound \\\n", + "0 3100.623174 2354.763525 \n", + "1 3029.807056 2160.774441 \n", + "2 2871.028242 1870.707097 \n", + "\n", + " confidence_interval_upper_bound \n", + "0 3100.623174 \n", + "1 3029.807056 \n", + "2 2871.028242 \n", + "\n", + "[3 rows x 8 columns]" + ] + }, + "execution_count": 8, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "time_series_arima_plus_model.predict()" + ] + }, + { + "cell_type": "code", + "execution_count": 10, + "metadata": {}, + "outputs": [ + { + "data": { + "text/html": [ + "Query job e9f35c3d-5b45-41c9-99c8-4fdde8e90517 is RUNNING. Open Job" + ], + "text/plain": [ + "" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "Query job a0c929b2-f4a0-4f53-a34e-771a99cf3661 is DONE. 82.4 kB processed. Open Job" + ], + "text/plain": [ + "" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "Query job ffa9464d-9d5d-4df0-9b9d-dd6c5a1902dd is DONE. 0 Bytes processed. Open Job" + ], + "text/plain": [ + "" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "Query job 0eb80464-157c-4dd1-ad46-18603bffda45 is DONE. 32.9 kB processed. Open Job" + ], + "text/plain": [ + "" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
\n", + "\n", + "\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "
time_series_timestamptime_series_typetime_series_datatime_series_adjusted_datastandard_errorconfidence_levelprediction_interval_lower_boundprediction_interval_upper_boundtrendseasonal_period_yearlyseasonal_period_quarterlyseasonal_period_monthlyseasonal_period_weeklyseasonal_period_dailyholiday_effectspikes_and_dipsstep_changesresidual
02016-08-01 00:00:00+00:00history1711.0505.716474190.614736<NA><NA><NA>0.0<NA><NA><NA>169.611938<NA><NA><NA>1205.283526336.104536
12016-08-02 00:00:00+00:00history2140.0625.750155190.614736<NA><NA><NA>338.716882<NA><NA><NA>287.033273<NA><NA><NA>1205.283526308.966319
22016-08-03 00:00:00+00:00history2890.0995.111101190.614736<NA><NA><NA>549.970223<NA><NA><NA>445.140878<NA><NA><NA>1205.283526689.605373
32016-08-04 00:00:00+00:00history3161.01408.363927190.614736<NA><NA><NA>1005.271573<NA><NA><NA>403.092354<NA><NA><NA>1205.283526547.352547
42016-08-05 00:00:00+00:00history2702.01381.96532190.614736<NA><NA><NA>1236.276965<NA><NA><NA>145.688355<NA><NA><NA>1205.283526114.751155
52016-08-06 00:00:00+00:00history1663.0349.426733190.614736<NA><NA><NA>1100.924343<NA><NA><NA>-751.49761<NA><NA><NA>1205.283526108.289741
62016-08-07 00:00:00+00:00history1622.0432.654477190.614736<NA><NA><NA>1134.891238<NA><NA><NA>-702.236761<NA><NA><NA>1205.283526-15.938002
72016-08-08 00:00:00+00:00history2815.01414.879653190.614736<NA><NA><NA>1239.098627<NA><NA><NA>175.781026<NA><NA><NA>1205.283526194.836821
82016-08-09 00:00:00+00:00history2851.01622.001182190.614736<NA><NA><NA>1328.328933<NA><NA><NA>293.672249<NA><NA><NA>1205.28352623.715292
92016-08-10 00:00:00+00:00history2757.01718.174551190.614736<NA><NA><NA>1273.916936<NA><NA><NA>444.257615<NA><NA><NA>1205.283526-166.458077
102016-08-11 00:00:00+00:00history2667.01537.591365190.614736<NA><NA><NA>1152.629705<NA><NA><NA>384.96166<NA><NA><NA>1205.283526-75.874891
112016-08-12 00:00:00+00:00history2619.01294.642238190.614736<NA><NA><NA>1156.721596<NA><NA><NA>137.920642<NA><NA><NA>1205.283526119.074236
122016-08-13 00:00:00+00:00history1596.0543.975173190.614736<NA><NA><NA>1290.728466<NA><NA><NA>-746.753292<NA><NA><NA>1205.283526-153.258699
132016-08-14 00:00:00+00:00history1801.0488.712986190.614736<NA><NA><NA>1187.781282<NA><NA><NA>-699.068296<NA><NA><NA>1205.283526107.003488
142016-08-15 00:00:00+00:00history3043.01375.529186190.614736<NA><NA><NA>1174.170307<NA><NA><NA>201.358878<NA><NA><NA>1205.283526462.187288
152016-08-16 00:00:00+00:00history2873.01833.311376190.614736<NA><NA><NA>1520.329586<NA><NA><NA>312.98179<NA><NA><NA>1205.283526-165.594902
162016-08-17 00:00:00+00:00history2799.01876.156796190.614736<NA><NA><NA>1441.156251<NA><NA><NA>435.000545<NA><NA><NA>1205.283526-282.440321
172016-08-18 00:00:00+00:00history2725.01412.776022190.614736<NA><NA><NA>1073.311017<NA><NA><NA>339.465005<NA><NA><NA>1205.283526106.940453
182016-08-19 00:00:00+00:00history2379.01328.455858190.614736<NA><NA><NA>1207.351708<NA><NA><NA>121.104149<NA><NA><NA>1205.283526-154.739384
192016-08-20 00:00:00+00:00history1664.0604.913116190.614736<NA><NA><NA>1344.614057<NA><NA><NA>-739.700941<NA><NA><NA>1205.283526-146.196642
202016-08-21 00:00:00+00:00history1730.0453.019443190.614736<NA><NA><NA>1140.422374<NA><NA><NA>-687.402931<NA><NA><NA>1205.28352671.697031
212016-08-22 00:00:00+00:00history2584.01341.060308190.614736<NA><NA><NA>1096.801611<NA><NA><NA>244.258697<NA><NA><NA>1205.28352637.656166
222016-08-23 00:00:00+00:00history2754.01623.867189190.614736<NA><NA><NA>1268.964143<NA><NA><NA>354.903046<NA><NA><NA>1205.283526-75.150715
232016-08-24 00:00:00+00:00history2627.01681.35638190.614736<NA><NA><NA>1242.488663<NA><NA><NA>438.867718<NA><NA><NA>1205.283526-259.639906
242016-08-25 00:00:00+00:00history2539.01241.619428190.614736<NA><NA><NA>973.603865<NA><NA><NA>268.015563<NA><NA><NA>1205.28352692.097046
\n", + "

25 rows × 18 columns

\n", + "
[369 rows x 18 columns in total]" + ], + "text/plain": [ + " time_series_timestamp time_series_type time_series_data \\\n", + "0 2016-08-01 00:00:00+00:00 history 1711.0 \n", + "1 2016-08-02 00:00:00+00:00 history 2140.0 \n", + "2 2016-08-03 00:00:00+00:00 history 2890.0 \n", + "3 2016-08-04 00:00:00+00:00 history 3161.0 \n", + "4 2016-08-05 00:00:00+00:00 history 2702.0 \n", + "5 2016-08-06 00:00:00+00:00 history 1663.0 \n", + "6 2016-08-07 00:00:00+00:00 history 1622.0 \n", + "7 2016-08-08 00:00:00+00:00 history 2815.0 \n", + "8 2016-08-09 00:00:00+00:00 history 2851.0 \n", + "9 2016-08-10 00:00:00+00:00 history 2757.0 \n", + "10 2016-08-11 00:00:00+00:00 history 2667.0 \n", + "11 2016-08-12 00:00:00+00:00 history 2619.0 \n", + "12 2016-08-13 00:00:00+00:00 history 1596.0 \n", + "13 2016-08-14 00:00:00+00:00 history 1801.0 \n", + "14 2016-08-15 00:00:00+00:00 history 3043.0 \n", + "15 2016-08-16 00:00:00+00:00 history 2873.0 \n", + "16 2016-08-17 00:00:00+00:00 history 2799.0 \n", + "17 2016-08-18 00:00:00+00:00 history 2725.0 \n", + "18 2016-08-19 00:00:00+00:00 history 2379.0 \n", + "19 2016-08-20 00:00:00+00:00 history 1664.0 \n", + "20 2016-08-21 00:00:00+00:00 history 1730.0 \n", + "21 2016-08-22 00:00:00+00:00 history 2584.0 \n", + "22 2016-08-23 00:00:00+00:00 history 2754.0 \n", + "23 2016-08-24 00:00:00+00:00 history 2627.0 \n", + "24 2016-08-25 00:00:00+00:00 history 2539.0 \n", + "\n", + " time_series_adjusted_data standard_error confidence_level \\\n", + "0 505.716474 190.614736 \n", + "1 625.750155 190.614736 \n", + "2 995.111101 190.614736 \n", + "3 1408.363927 190.614736 \n", + "4 1381.96532 190.614736 \n", + "5 349.426733 190.614736 \n", + "6 432.654477 190.614736 \n", + "7 1414.879653 190.614736 \n", + "8 1622.001182 190.614736 \n", + "9 1718.174551 190.614736 \n", + "10 1537.591365 190.614736 \n", + "11 1294.642238 190.614736 \n", + "12 543.975173 190.614736 \n", + "13 488.712986 190.614736 \n", + "14 1375.529186 190.614736 \n", + "15 1833.311376 190.614736 \n", + "16 1876.156796 190.614736 \n", + "17 1412.776022 190.614736 \n", + "18 1328.455858 190.614736 \n", + "19 604.913116 190.614736 \n", + "20 453.019443 190.614736 \n", + "21 1341.060308 190.614736 \n", + "22 1623.867189 190.614736 \n", + "23 1681.35638 190.614736 \n", + "24 1241.619428 190.614736 \n", + "\n", + " prediction_interval_lower_bound prediction_interval_upper_bound \\\n", + "0 \n", + "1 \n", + "2 \n", + "3 \n", + "4 \n", + "5 \n", + "6 \n", + "7 \n", + "8 \n", + "9 \n", + "10 \n", + "11 \n", + "12 \n", + "13 \n", + "14 \n", + "15 \n", + "16 \n", + "17 \n", + "18 \n", + "19 \n", + "20 \n", + "21 \n", + "22 \n", + "23 \n", + "24 \n", + "\n", + " trend seasonal_period_yearly seasonal_period_quarterly \\\n", + "0 0.0 \n", + "1 338.716882 \n", + "2 549.970223 \n", + "3 1005.271573 \n", + "4 1236.276965 \n", + "5 1100.924343 \n", + "6 1134.891238 \n", + "7 1239.098627 \n", + "8 1328.328933 \n", + "9 1273.916936 \n", + "10 1152.629705 \n", + "11 1156.721596 \n", + "12 1290.728466 \n", + "13 1187.781282 \n", + "14 1174.170307 \n", + "15 1520.329586 \n", + "16 1441.156251 \n", + "17 1073.311017 \n", + "18 1207.351708 \n", + "19 1344.614057 \n", + "20 1140.422374 \n", + "21 1096.801611 \n", + "22 1268.964143 \n", + "23 1242.488663 \n", + "24 973.603865 \n", + "\n", + " seasonal_period_monthly seasonal_period_weekly seasonal_period_daily \\\n", + "0 169.611938 \n", + "1 287.033273 \n", + "2 445.140878 \n", + "3 403.092354 \n", + "4 145.688355 \n", + "5 -751.49761 \n", + "6 -702.236761 \n", + "7 175.781026 \n", + "8 293.672249 \n", + "9 444.257615 \n", + "10 384.96166 \n", + "11 137.920642 \n", + "12 -746.753292 \n", + "13 -699.068296 \n", + "14 201.358878 \n", + "15 312.98179 \n", + "16 435.000545 \n", + "17 339.465005 \n", + "18 121.104149 \n", + "19 -739.700941 \n", + "20 -687.402931 \n", + "21 244.258697 \n", + "22 354.903046 \n", + "23 438.867718 \n", + "24 268.015563 \n", + "\n", + " holiday_effect spikes_and_dips step_changes residual \n", + "0 1205.283526 336.104536 \n", + "1 1205.283526 308.966319 \n", + "2 1205.283526 689.605373 \n", + "3 1205.283526 547.352547 \n", + "4 1205.283526 114.751155 \n", + "5 1205.283526 108.289741 \n", + "6 1205.283526 -15.938002 \n", + "7 1205.283526 194.836821 \n", + "8 1205.283526 23.715292 \n", + "9 1205.283526 -166.458077 \n", + "10 1205.283526 -75.874891 \n", + "11 1205.283526 119.074236 \n", + "12 1205.283526 -153.258699 \n", + "13 1205.283526 107.003488 \n", + "14 1205.283526 462.187288 \n", + "15 1205.283526 -165.594902 \n", + "16 1205.283526 -282.440321 \n", + "17 1205.283526 106.940453 \n", + "18 1205.283526 -154.739384 \n", + "19 1205.283526 -146.196642 \n", + "20 1205.283526 71.697031 \n", + "21 1205.283526 37.656166 \n", + "22 1205.283526 -75.150715 \n", + "23 1205.283526 -259.639906 \n", + "24 1205.283526 92.097046 \n", + "...\n", + "\n", + "[369 rows x 18 columns]" + ] + }, + "execution_count": 10, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "time_series_arima_plus_model.predict_attribution()" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "venv", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.12.1" + } + }, + "nbformat": 4, + "nbformat_minor": 2 +} From b8ec20df2eeefdb66f036e0ff6fac42048b8591a Mon Sep 17 00:00:00 2001 From: Daniela Date: Mon, 18 Nov 2024 22:34:37 +0000 Subject: [PATCH 03/18] update test_arima_plus_predict_explain_default test and create test_arima_plus_predict_explain_params test --- bigframes/ml/forecasting.py | 11 +++--- tests/system/small/ml/test_forecasting.py | 46 +++++++++++++++++++---- 2 files changed, 44 insertions(+), 13 deletions(-) diff --git a/bigframes/ml/forecasting.py b/bigframes/ml/forecasting.py index 818b18a39a..faa8603a15 100644 --- a/bigframes/ml/forecasting.py +++ b/bigframes/ml/forecasting.py @@ -253,15 +253,15 @@ def predict( options={"horizon": horizon, "confidence_level": confidence_level} ) - def predict_attribution( + def predict_explain( self, X=None, *, horizon: int = 3, confidence_level: float = 0.95 ) -> bpd.DataFrame: - """Forecast time series at future horizon. + """Explain Forecast time series at future horizon. .. note:: - Output matches that of the BigQuery ML.FORECAST function. - See: https://cloud.google.com/bigquery/docs/reference/standard-sql/bigqueryml-syntax-forecast + Output matches that of the BigQuery ML.EXPLAIN_FORECAST function. + See: https://cloud.google.com/bigquery/docs/reference/standard-sql/bigqueryml-syntax-explain-forecast Args: X (default None): @@ -274,8 +274,7 @@ def predict_attribution( The valid input range is [0.0, 1.0). Returns: - bigframes.dataframe.DataFrame: The predicted DataFrames. Which - contains 2 columns: "forecast_timestamp" and "forecast_value". + bigframes.dataframe.DataFrame: The predicted DataFrames. """ if horizon < 1 or horizon > 1000: raise ValueError(f"horizon must be [1, 1000], but is {horizon}.") diff --git a/tests/system/small/ml/test_forecasting.py b/tests/system/small/ml/test_forecasting.py index 7ea16396c0..5cdd5230cf 100644 --- a/tests/system/small/ml/test_forecasting.py +++ b/tests/system/small/ml/test_forecasting.py @@ -64,12 +64,44 @@ def test_arima_plus_predict_default( check_index_type=False, ) -def test_arima_plus_predict_attribution_default( +def test_arima_plus_predict_explain_default( time_series_arima_plus_model: forecasting.ARIMAPlus, ): utc = pytz.utc - predictions = time_series_arima_plus_model.predict_attribution().to_pandas() - assert predictions.shape == (3, 8) + predictions = time_series_arima_plus_model.predict_explain().to_pandas() + assert predictions.shape[0] == 369 + predictions = predictions[predictions["time_series_type"] == "forecast"].reset_index(drop=True) + assert predictions.shape[0] == 3 + result = predictions[["time_series_timestamp", "time_series_data"]] + expected = pd.DataFrame( + { + "time_series_timestamp": [ + datetime(2017, 8, 2, tzinfo=utc), + datetime(2017, 8, 3, tzinfo=utc), + datetime(2017, 8, 4, tzinfo=utc), + ], + "time_series_data": [2727.693349, 2595.290749, 2370.86767], + } + ) + expected["time_series_data"] = expected["time_series_data"].astype(pd.Float64Dtype()) + expected["time_series_timestamp"] = expected["time_series_timestamp"].astype( + pd.ArrowDtype(pa.timestamp("us", tz="UTC")) + ) + + pd.testing.assert_frame_equal( + result, + expected, + rtol=0.1, + check_index_type=False, + ) + + +def test_arima_plus_predict_params(time_series_arima_plus_model: forecasting.ARIMAPlus): + utc = pytz.utc + predictions = time_series_arima_plus_model.predict( + horizon=4, confidence_level=0.9 + ).to_pandas() + assert predictions.shape == (4, 8) result = predictions[["forecast_timestamp", "forecast_value"]] expected = pd.DataFrame( { @@ -77,8 +109,9 @@ def test_arima_plus_predict_attribution_default( datetime(2017, 8, 2, tzinfo=utc), datetime(2017, 8, 3, tzinfo=utc), datetime(2017, 8, 4, tzinfo=utc), + datetime(2017, 8, 5, tzinfo=utc), ], - "forecast_value": [2724.472284, 2593.368389, 2353.613034], + "forecast_value": [2724.472284, 2593.368389, 2353.613034, 1781.623071], } ) expected["forecast_value"] = expected["forecast_value"].astype(pd.Float64Dtype()) @@ -93,10 +126,9 @@ def test_arima_plus_predict_attribution_default( check_index_type=False, ) - -def test_arima_plus_predict_params(time_series_arima_plus_model: forecasting.ARIMAPlus): +def test_arima_plus_predict_explain_params(time_series_arima_plus_model: forecasting.ARIMAPlus): utc = pytz.utc - predictions = time_series_arima_plus_model.predict( + predictions = time_series_arima_plus_model.predict_explain( horizon=4, confidence_level=0.9 ).to_pandas() assert predictions.shape == (4, 8) From 8056c92efb8e86012fa45a48baf983f2fc93bd41 Mon Sep 17 00:00:00 2001 From: Daniela Date: Mon, 18 Nov 2024 22:40:57 +0000 Subject: [PATCH 04/18] Merge branch 'ml-predict-explain' of github.com:googleapis/python-bigquery-dataframes into ml-predict-explain --- notebooks/debug.ipynb | 1093 +++++++++++++++++++++++++++++++++++++++++ 1 file changed, 1093 insertions(+) create mode 100644 notebooks/debug.ipynb diff --git a/notebooks/debug.ipynb b/notebooks/debug.ipynb new file mode 100644 index 0000000000..88d5557e66 --- /dev/null +++ b/notebooks/debug.ipynb @@ -0,0 +1,1093 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# Creating ARIMAPlus forcasting model" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## load_test_data_tables in tests/system/conftest.py" + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "[SchemaField('parsed_date', 'TIMESTAMP', 'NULLABLE', None, None, (), None), SchemaField('total_visits', 'INTEGER', 'NULLABLE', None, None, (), None)]\n", + "/usr/local/google/home/chelsealin/src/bigframes/tests/data/time_series.jsonl\n" + ] + }, + { + "data": { + "text/plain": [ + "LoadJob" + ] + }, + "execution_count": 3, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "import google.cloud.bigquery as bigquery\n", + "\n", + "DATA_DIR=\"/usr/local/google/home/chelsealin/src/bigframes/tests/data/\"\n", + "schema_filename=DATA_DIR + \"time_series_schema.json\"\n", + "data_filename=DATA_DIR + \"time_series.jsonl\"\n", + "\n", + "time_series_table_id='bigframes-dev.chelsealin.time_series_0'\n", + "\n", + "client = bigquery.Client(project='bigframes-dev')\n", + "\n", + "job_config = bigquery.LoadJobConfig()\n", + "job_config.source_format = bigquery.SourceFormat.NEWLINE_DELIMITED_JSON\n", + "job_config.schema = tuple(\n", + " client.schema_from_json(schema_filename)\n", + ")\n", + "print(job_config.schema)\n", + "job_config.write_disposition = bigquery.WriteDisposition.WRITE_TRUNCATE\n", + "\n", + "with open(data_filename, \"rb\") as input_file:\n", + " print(data_filename)\n", + " job = client.load_table_from_file(\n", + " input_file,\n", + " time_series_table_id,\n", + " job_config=job_config,\n", + " )\n", + "job.result()" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## time_series_arima_plus_model in tests/system/conftest.py: " + ] + }, + { + "cell_type": "code", + "execution_count": 6, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "" + ] + }, + "execution_count": 6, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "import bigframes.pandas as bpd\n", + "\n", + "time_series_arima_plus_model_name = \"bigframes-dev.chelsealin.time_series_arima_plus_0\"\n", + "sql = f\"\"\"\n", + "CREATE OR REPLACE MODEL `{time_series_arima_plus_model_name}`\n", + "OPTIONS (\n", + " model_type='ARIMA_PLUS',\n", + " time_series_timestamp_col = 'parsed_date',\n", + " time_series_data_col = 'total_visits'\n", + ") AS SELECT\n", + " *\n", + "FROM `{time_series_table_id}`\"\"\"\n", + "\n", + "client.query(sql).result()" + ] + }, + { + "cell_type": "code", + "execution_count": 7, + "metadata": {}, + "outputs": [], + "source": [ + "time_series_arima_plus_model = bpd.read_gbq_model(time_series_arima_plus_model_name)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## test_arima_plus_predict_attribution_default" + ] + }, + { + "cell_type": "code", + "execution_count": 8, + "metadata": {}, + "outputs": [ + { + "data": { + "text/html": [ + "Query job f525fa55-228c-45b6-8ea4-eb4700513538 is RUNNING. Open Job" + ], + "text/plain": [ + "" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "Query job d8e97c97-e3e3-47ff-8c40-53f3eb100069 is DONE. 43.7 kB processed. Open Job" + ], + "text/plain": [ + "" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "Query job 6fb13978-b613-417b-b79b-86a8c100036d is DONE. 0 Bytes processed. Open Job" + ], + "text/plain": [ + "" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
\n", + "\n", + "\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "
forecast_timestampforecast_valuestandard_errorconfidence_levelprediction_interval_lower_boundprediction_interval_upper_boundconfidence_interval_lower_boundconfidence_interval_upper_bound
02017-08-02 00:00:00+00:002727.693349190.6147780.952354.7635253100.6231742354.7635253100.623174
12017-08-03 00:00:00+00:002595.290749222.0933380.952160.7744413029.8070562160.7744413029.807056
22017-08-04 00:00:00+00:002370.86767255.6459420.951870.7070972871.0282421870.7070972871.028242
\n", + "

3 rows × 8 columns

\n", + "
[3 rows x 8 columns in total]" + ], + "text/plain": [ + " forecast_timestamp forecast_value standard_error \\\n", + "0 2017-08-02 00:00:00+00:00 2727.693349 190.614778 \n", + "1 2017-08-03 00:00:00+00:00 2595.290749 222.093338 \n", + "2 2017-08-04 00:00:00+00:00 2370.86767 255.645942 \n", + "\n", + " confidence_level prediction_interval_lower_bound \\\n", + "0 0.95 2354.763525 \n", + "1 0.95 2160.774441 \n", + "2 0.95 1870.707097 \n", + "\n", + " prediction_interval_upper_bound confidence_interval_lower_bound \\\n", + "0 3100.623174 2354.763525 \n", + "1 3029.807056 2160.774441 \n", + "2 2871.028242 1870.707097 \n", + "\n", + " confidence_interval_upper_bound \n", + "0 3100.623174 \n", + "1 3029.807056 \n", + "2 2871.028242 \n", + "\n", + "[3 rows x 8 columns]" + ] + }, + "execution_count": 8, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "time_series_arima_plus_model.predict()" + ] + }, + { + "cell_type": "code", + "execution_count": 10, + "metadata": {}, + "outputs": [ + { + "data": { + "text/html": [ + "Query job e9f35c3d-5b45-41c9-99c8-4fdde8e90517 is RUNNING. Open Job" + ], + "text/plain": [ + "" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "Query job a0c929b2-f4a0-4f53-a34e-771a99cf3661 is DONE. 82.4 kB processed. Open Job" + ], + "text/plain": [ + "" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "Query job ffa9464d-9d5d-4df0-9b9d-dd6c5a1902dd is DONE. 0 Bytes processed. Open Job" + ], + "text/plain": [ + "" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "Query job 0eb80464-157c-4dd1-ad46-18603bffda45 is DONE. 32.9 kB processed. Open Job" + ], + "text/plain": [ + "" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
\n", + "\n", + "\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "
time_series_timestamptime_series_typetime_series_datatime_series_adjusted_datastandard_errorconfidence_levelprediction_interval_lower_boundprediction_interval_upper_boundtrendseasonal_period_yearlyseasonal_period_quarterlyseasonal_period_monthlyseasonal_period_weeklyseasonal_period_dailyholiday_effectspikes_and_dipsstep_changesresidual
02016-08-01 00:00:00+00:00history1711.0505.716474190.614736<NA><NA><NA>0.0<NA><NA><NA>169.611938<NA><NA><NA>1205.283526336.104536
12016-08-02 00:00:00+00:00history2140.0625.750155190.614736<NA><NA><NA>338.716882<NA><NA><NA>287.033273<NA><NA><NA>1205.283526308.966319
22016-08-03 00:00:00+00:00history2890.0995.111101190.614736<NA><NA><NA>549.970223<NA><NA><NA>445.140878<NA><NA><NA>1205.283526689.605373
32016-08-04 00:00:00+00:00history3161.01408.363927190.614736<NA><NA><NA>1005.271573<NA><NA><NA>403.092354<NA><NA><NA>1205.283526547.352547
42016-08-05 00:00:00+00:00history2702.01381.96532190.614736<NA><NA><NA>1236.276965<NA><NA><NA>145.688355<NA><NA><NA>1205.283526114.751155
52016-08-06 00:00:00+00:00history1663.0349.426733190.614736<NA><NA><NA>1100.924343<NA><NA><NA>-751.49761<NA><NA><NA>1205.283526108.289741
62016-08-07 00:00:00+00:00history1622.0432.654477190.614736<NA><NA><NA>1134.891238<NA><NA><NA>-702.236761<NA><NA><NA>1205.283526-15.938002
72016-08-08 00:00:00+00:00history2815.01414.879653190.614736<NA><NA><NA>1239.098627<NA><NA><NA>175.781026<NA><NA><NA>1205.283526194.836821
82016-08-09 00:00:00+00:00history2851.01622.001182190.614736<NA><NA><NA>1328.328933<NA><NA><NA>293.672249<NA><NA><NA>1205.28352623.715292
92016-08-10 00:00:00+00:00history2757.01718.174551190.614736<NA><NA><NA>1273.916936<NA><NA><NA>444.257615<NA><NA><NA>1205.283526-166.458077
102016-08-11 00:00:00+00:00history2667.01537.591365190.614736<NA><NA><NA>1152.629705<NA><NA><NA>384.96166<NA><NA><NA>1205.283526-75.874891
112016-08-12 00:00:00+00:00history2619.01294.642238190.614736<NA><NA><NA>1156.721596<NA><NA><NA>137.920642<NA><NA><NA>1205.283526119.074236
122016-08-13 00:00:00+00:00history1596.0543.975173190.614736<NA><NA><NA>1290.728466<NA><NA><NA>-746.753292<NA><NA><NA>1205.283526-153.258699
132016-08-14 00:00:00+00:00history1801.0488.712986190.614736<NA><NA><NA>1187.781282<NA><NA><NA>-699.068296<NA><NA><NA>1205.283526107.003488
142016-08-15 00:00:00+00:00history3043.01375.529186190.614736<NA><NA><NA>1174.170307<NA><NA><NA>201.358878<NA><NA><NA>1205.283526462.187288
152016-08-16 00:00:00+00:00history2873.01833.311376190.614736<NA><NA><NA>1520.329586<NA><NA><NA>312.98179<NA><NA><NA>1205.283526-165.594902
162016-08-17 00:00:00+00:00history2799.01876.156796190.614736<NA><NA><NA>1441.156251<NA><NA><NA>435.000545<NA><NA><NA>1205.283526-282.440321
172016-08-18 00:00:00+00:00history2725.01412.776022190.614736<NA><NA><NA>1073.311017<NA><NA><NA>339.465005<NA><NA><NA>1205.283526106.940453
182016-08-19 00:00:00+00:00history2379.01328.455858190.614736<NA><NA><NA>1207.351708<NA><NA><NA>121.104149<NA><NA><NA>1205.283526-154.739384
192016-08-20 00:00:00+00:00history1664.0604.913116190.614736<NA><NA><NA>1344.614057<NA><NA><NA>-739.700941<NA><NA><NA>1205.283526-146.196642
202016-08-21 00:00:00+00:00history1730.0453.019443190.614736<NA><NA><NA>1140.422374<NA><NA><NA>-687.402931<NA><NA><NA>1205.28352671.697031
212016-08-22 00:00:00+00:00history2584.01341.060308190.614736<NA><NA><NA>1096.801611<NA><NA><NA>244.258697<NA><NA><NA>1205.28352637.656166
222016-08-23 00:00:00+00:00history2754.01623.867189190.614736<NA><NA><NA>1268.964143<NA><NA><NA>354.903046<NA><NA><NA>1205.283526-75.150715
232016-08-24 00:00:00+00:00history2627.01681.35638190.614736<NA><NA><NA>1242.488663<NA><NA><NA>438.867718<NA><NA><NA>1205.283526-259.639906
242016-08-25 00:00:00+00:00history2539.01241.619428190.614736<NA><NA><NA>973.603865<NA><NA><NA>268.015563<NA><NA><NA>1205.28352692.097046
\n", + "

25 rows × 18 columns

\n", + "
[369 rows x 18 columns in total]" + ], + "text/plain": [ + " time_series_timestamp time_series_type time_series_data \\\n", + "0 2016-08-01 00:00:00+00:00 history 1711.0 \n", + "1 2016-08-02 00:00:00+00:00 history 2140.0 \n", + "2 2016-08-03 00:00:00+00:00 history 2890.0 \n", + "3 2016-08-04 00:00:00+00:00 history 3161.0 \n", + "4 2016-08-05 00:00:00+00:00 history 2702.0 \n", + "5 2016-08-06 00:00:00+00:00 history 1663.0 \n", + "6 2016-08-07 00:00:00+00:00 history 1622.0 \n", + "7 2016-08-08 00:00:00+00:00 history 2815.0 \n", + "8 2016-08-09 00:00:00+00:00 history 2851.0 \n", + "9 2016-08-10 00:00:00+00:00 history 2757.0 \n", + "10 2016-08-11 00:00:00+00:00 history 2667.0 \n", + "11 2016-08-12 00:00:00+00:00 history 2619.0 \n", + "12 2016-08-13 00:00:00+00:00 history 1596.0 \n", + "13 2016-08-14 00:00:00+00:00 history 1801.0 \n", + "14 2016-08-15 00:00:00+00:00 history 3043.0 \n", + "15 2016-08-16 00:00:00+00:00 history 2873.0 \n", + "16 2016-08-17 00:00:00+00:00 history 2799.0 \n", + "17 2016-08-18 00:00:00+00:00 history 2725.0 \n", + "18 2016-08-19 00:00:00+00:00 history 2379.0 \n", + "19 2016-08-20 00:00:00+00:00 history 1664.0 \n", + "20 2016-08-21 00:00:00+00:00 history 1730.0 \n", + "21 2016-08-22 00:00:00+00:00 history 2584.0 \n", + "22 2016-08-23 00:00:00+00:00 history 2754.0 \n", + "23 2016-08-24 00:00:00+00:00 history 2627.0 \n", + "24 2016-08-25 00:00:00+00:00 history 2539.0 \n", + "\n", + " time_series_adjusted_data standard_error confidence_level \\\n", + "0 505.716474 190.614736 \n", + "1 625.750155 190.614736 \n", + "2 995.111101 190.614736 \n", + "3 1408.363927 190.614736 \n", + "4 1381.96532 190.614736 \n", + "5 349.426733 190.614736 \n", + "6 432.654477 190.614736 \n", + "7 1414.879653 190.614736 \n", + "8 1622.001182 190.614736 \n", + "9 1718.174551 190.614736 \n", + "10 1537.591365 190.614736 \n", + "11 1294.642238 190.614736 \n", + "12 543.975173 190.614736 \n", + "13 488.712986 190.614736 \n", + "14 1375.529186 190.614736 \n", + "15 1833.311376 190.614736 \n", + "16 1876.156796 190.614736 \n", + "17 1412.776022 190.614736 \n", + "18 1328.455858 190.614736 \n", + "19 604.913116 190.614736 \n", + "20 453.019443 190.614736 \n", + "21 1341.060308 190.614736 \n", + "22 1623.867189 190.614736 \n", + "23 1681.35638 190.614736 \n", + "24 1241.619428 190.614736 \n", + "\n", + " prediction_interval_lower_bound prediction_interval_upper_bound \\\n", + "0 \n", + "1 \n", + "2 \n", + "3 \n", + "4 \n", + "5 \n", + "6 \n", + "7 \n", + "8 \n", + "9 \n", + "10 \n", + "11 \n", + "12 \n", + "13 \n", + "14 \n", + "15 \n", + "16 \n", + "17 \n", + "18 \n", + "19 \n", + "20 \n", + "21 \n", + "22 \n", + "23 \n", + "24 \n", + "\n", + " trend seasonal_period_yearly seasonal_period_quarterly \\\n", + "0 0.0 \n", + "1 338.716882 \n", + "2 549.970223 \n", + "3 1005.271573 \n", + "4 1236.276965 \n", + "5 1100.924343 \n", + "6 1134.891238 \n", + "7 1239.098627 \n", + "8 1328.328933 \n", + "9 1273.916936 \n", + "10 1152.629705 \n", + "11 1156.721596 \n", + "12 1290.728466 \n", + "13 1187.781282 \n", + "14 1174.170307 \n", + "15 1520.329586 \n", + "16 1441.156251 \n", + "17 1073.311017 \n", + "18 1207.351708 \n", + "19 1344.614057 \n", + "20 1140.422374 \n", + "21 1096.801611 \n", + "22 1268.964143 \n", + "23 1242.488663 \n", + "24 973.603865 \n", + "\n", + " seasonal_period_monthly seasonal_period_weekly seasonal_period_daily \\\n", + "0 169.611938 \n", + "1 287.033273 \n", + "2 445.140878 \n", + "3 403.092354 \n", + "4 145.688355 \n", + "5 -751.49761 \n", + "6 -702.236761 \n", + "7 175.781026 \n", + "8 293.672249 \n", + "9 444.257615 \n", + "10 384.96166 \n", + "11 137.920642 \n", + "12 -746.753292 \n", + "13 -699.068296 \n", + "14 201.358878 \n", + "15 312.98179 \n", + "16 435.000545 \n", + "17 339.465005 \n", + "18 121.104149 \n", + "19 -739.700941 \n", + "20 -687.402931 \n", + "21 244.258697 \n", + "22 354.903046 \n", + "23 438.867718 \n", + "24 268.015563 \n", + "\n", + " holiday_effect spikes_and_dips step_changes residual \n", + "0 1205.283526 336.104536 \n", + "1 1205.283526 308.966319 \n", + "2 1205.283526 689.605373 \n", + "3 1205.283526 547.352547 \n", + "4 1205.283526 114.751155 \n", + "5 1205.283526 108.289741 \n", + "6 1205.283526 -15.938002 \n", + "7 1205.283526 194.836821 \n", + "8 1205.283526 23.715292 \n", + "9 1205.283526 -166.458077 \n", + "10 1205.283526 -75.874891 \n", + "11 1205.283526 119.074236 \n", + "12 1205.283526 -153.258699 \n", + "13 1205.283526 107.003488 \n", + "14 1205.283526 462.187288 \n", + "15 1205.283526 -165.594902 \n", + "16 1205.283526 -282.440321 \n", + "17 1205.283526 106.940453 \n", + "18 1205.283526 -154.739384 \n", + "19 1205.283526 -146.196642 \n", + "20 1205.283526 71.697031 \n", + "21 1205.283526 37.656166 \n", + "22 1205.283526 -75.150715 \n", + "23 1205.283526 -259.639906 \n", + "24 1205.283526 92.097046 \n", + "...\n", + "\n", + "[369 rows x 18 columns]" + ] + }, + "execution_count": 10, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "time_series_arima_plus_model.predict_attribution()" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "venv", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.12.1" + } + }, + "nbformat": 4, + "nbformat_minor": 2 +} From a161b338fa995a1dd6ea459e5507ad3e262db287 Mon Sep 17 00:00:00 2001 From: Daniela Date: Mon, 18 Nov 2024 22:44:50 +0000 Subject: [PATCH 05/18] update test_arima_plus_predict_explain_params test --- tests/system/small/ml/test_forecasting.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/tests/system/small/ml/test_forecasting.py b/tests/system/small/ml/test_forecasting.py index 5cdd5230cf..00016c189a 100644 --- a/tests/system/small/ml/test_forecasting.py +++ b/tests/system/small/ml/test_forecasting.py @@ -132,20 +132,20 @@ def test_arima_plus_predict_explain_params(time_series_arima_plus_model: forecas horizon=4, confidence_level=0.9 ).to_pandas() assert predictions.shape == (4, 8) - result = predictions[["forecast_timestamp", "forecast_value"]] + result = predictions[["time_series_timestamp", "time_series_data"]] expected = pd.DataFrame( { - "forecast_timestamp": [ + "time_series_timestamp": [ datetime(2017, 8, 2, tzinfo=utc), datetime(2017, 8, 3, tzinfo=utc), datetime(2017, 8, 4, tzinfo=utc), datetime(2017, 8, 5, tzinfo=utc), ], - "forecast_value": [2724.472284, 2593.368389, 2353.613034, 1781.623071], + "time_series_data": [2724.472284, 2593.368389, 2353.613034, 1781.623071], } ) - expected["forecast_value"] = expected["forecast_value"].astype(pd.Float64Dtype()) - expected["forecast_timestamp"] = expected["forecast_timestamp"].astype( + expected["time_series_data"] = expected["time_series_data"].astype(pd.Float64Dtype()) + expected["time_series_timestamp"] = expected["time_series_timestamp"].astype( pd.ArrowDtype(pa.timestamp("us", tz="UTC")) ) From 347c3c4c32e2b17fc3f82fb1507bc71265441cd9 Mon Sep 17 00:00:00 2001 From: Chelsea Lin Date: Tue, 26 Nov 2024 18:07:32 +0000 Subject: [PATCH 06/18] Revert "tmp: debug notes for time_series_arima_plus_model.predict_attribution" This reverts commit f6dd455893e9bcd9612ba6e31c81de6166b72c7e. --- notebooks/debug.ipynb | 1093 ----------------------------------------- 1 file changed, 1093 deletions(-) delete mode 100644 notebooks/debug.ipynb diff --git a/notebooks/debug.ipynb b/notebooks/debug.ipynb deleted file mode 100644 index 88d5557e66..0000000000 --- a/notebooks/debug.ipynb +++ /dev/null @@ -1,1093 +0,0 @@ -{ - "cells": [ - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "# Creating ARIMAPlus forcasting model" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "## load_test_data_tables in tests/system/conftest.py" - ] - }, - { - "cell_type": "code", - "execution_count": 3, - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "[SchemaField('parsed_date', 'TIMESTAMP', 'NULLABLE', None, None, (), None), SchemaField('total_visits', 'INTEGER', 'NULLABLE', None, None, (), None)]\n", - "/usr/local/google/home/chelsealin/src/bigframes/tests/data/time_series.jsonl\n" - ] - }, - { - "data": { - "text/plain": [ - "LoadJob" - ] - }, - "execution_count": 3, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "import google.cloud.bigquery as bigquery\n", - "\n", - "DATA_DIR=\"/usr/local/google/home/chelsealin/src/bigframes/tests/data/\"\n", - "schema_filename=DATA_DIR + \"time_series_schema.json\"\n", - "data_filename=DATA_DIR + \"time_series.jsonl\"\n", - "\n", - "time_series_table_id='bigframes-dev.chelsealin.time_series_0'\n", - "\n", - "client = bigquery.Client(project='bigframes-dev')\n", - "\n", - "job_config = bigquery.LoadJobConfig()\n", - "job_config.source_format = bigquery.SourceFormat.NEWLINE_DELIMITED_JSON\n", - "job_config.schema = tuple(\n", - " client.schema_from_json(schema_filename)\n", - ")\n", - "print(job_config.schema)\n", - "job_config.write_disposition = bigquery.WriteDisposition.WRITE_TRUNCATE\n", - "\n", - "with open(data_filename, \"rb\") as input_file:\n", - " print(data_filename)\n", - " job = client.load_table_from_file(\n", - " input_file,\n", - " time_series_table_id,\n", - " job_config=job_config,\n", - " )\n", - "job.result()" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "## time_series_arima_plus_model in tests/system/conftest.py: " - ] - }, - { - "cell_type": "code", - "execution_count": 6, - "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "" - ] - }, - "execution_count": 6, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "import bigframes.pandas as bpd\n", - "\n", - "time_series_arima_plus_model_name = \"bigframes-dev.chelsealin.time_series_arima_plus_0\"\n", - "sql = f\"\"\"\n", - "CREATE OR REPLACE MODEL `{time_series_arima_plus_model_name}`\n", - "OPTIONS (\n", - " model_type='ARIMA_PLUS',\n", - " time_series_timestamp_col = 'parsed_date',\n", - " time_series_data_col = 'total_visits'\n", - ") AS SELECT\n", - " *\n", - "FROM `{time_series_table_id}`\"\"\"\n", - "\n", - "client.query(sql).result()" - ] - }, - { - "cell_type": "code", - "execution_count": 7, - "metadata": {}, - "outputs": [], - "source": [ - "time_series_arima_plus_model = bpd.read_gbq_model(time_series_arima_plus_model_name)" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "## test_arima_plus_predict_attribution_default" - ] - }, - { - "cell_type": "code", - "execution_count": 8, - "metadata": {}, - "outputs": [ - { - "data": { - "text/html": [ - "Query job f525fa55-228c-45b6-8ea4-eb4700513538 is RUNNING. Open Job" - ], - "text/plain": [ - "" - ] - }, - "metadata": {}, - "output_type": "display_data" - }, - { - "data": { - "text/html": [ - "Query job d8e97c97-e3e3-47ff-8c40-53f3eb100069 is DONE. 43.7 kB processed. Open Job" - ], - "text/plain": [ - "" - ] - }, - "metadata": {}, - "output_type": "display_data" - }, - { - "data": { - "text/html": [ - "Query job 6fb13978-b613-417b-b79b-86a8c100036d is DONE. 0 Bytes processed. Open Job" - ], - "text/plain": [ - "" - ] - }, - "metadata": {}, - "output_type": "display_data" - }, - { - "data": { - "text/html": [ - "
\n", - "\n", - "\n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - "
forecast_timestampforecast_valuestandard_errorconfidence_levelprediction_interval_lower_boundprediction_interval_upper_boundconfidence_interval_lower_boundconfidence_interval_upper_bound
02017-08-02 00:00:00+00:002727.693349190.6147780.952354.7635253100.6231742354.7635253100.623174
12017-08-03 00:00:00+00:002595.290749222.0933380.952160.7744413029.8070562160.7744413029.807056
22017-08-04 00:00:00+00:002370.86767255.6459420.951870.7070972871.0282421870.7070972871.028242
\n", - "

3 rows × 8 columns

\n", - "
[3 rows x 8 columns in total]" - ], - "text/plain": [ - " forecast_timestamp forecast_value standard_error \\\n", - "0 2017-08-02 00:00:00+00:00 2727.693349 190.614778 \n", - "1 2017-08-03 00:00:00+00:00 2595.290749 222.093338 \n", - "2 2017-08-04 00:00:00+00:00 2370.86767 255.645942 \n", - "\n", - " confidence_level prediction_interval_lower_bound \\\n", - "0 0.95 2354.763525 \n", - "1 0.95 2160.774441 \n", - "2 0.95 1870.707097 \n", - "\n", - " prediction_interval_upper_bound confidence_interval_lower_bound \\\n", - "0 3100.623174 2354.763525 \n", - "1 3029.807056 2160.774441 \n", - "2 2871.028242 1870.707097 \n", - "\n", - " confidence_interval_upper_bound \n", - "0 3100.623174 \n", - "1 3029.807056 \n", - "2 2871.028242 \n", - "\n", - "[3 rows x 8 columns]" - ] - }, - "execution_count": 8, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "time_series_arima_plus_model.predict()" - ] - }, - { - "cell_type": "code", - "execution_count": 10, - "metadata": {}, - "outputs": [ - { - "data": { - "text/html": [ - "Query job e9f35c3d-5b45-41c9-99c8-4fdde8e90517 is RUNNING. Open Job" - ], - "text/plain": [ - "" - ] - }, - "metadata": {}, - "output_type": "display_data" - }, - { - "data": { - "text/html": [ - "Query job a0c929b2-f4a0-4f53-a34e-771a99cf3661 is DONE. 82.4 kB processed. Open Job" - ], - "text/plain": [ - "" - ] - }, - "metadata": {}, - "output_type": "display_data" - }, - { - "data": { - "text/html": [ - "Query job ffa9464d-9d5d-4df0-9b9d-dd6c5a1902dd is DONE. 0 Bytes processed. Open Job" - ], - "text/plain": [ - "" - ] - }, - "metadata": {}, - "output_type": "display_data" - }, - { - "data": { - "text/html": [ - "Query job 0eb80464-157c-4dd1-ad46-18603bffda45 is DONE. 32.9 kB processed. Open Job" - ], - "text/plain": [ - "" - ] - }, - "metadata": {}, - "output_type": "display_data" - }, - { - "data": { - "text/html": [ - "
\n", - "\n", - "\n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - "
time_series_timestamptime_series_typetime_series_datatime_series_adjusted_datastandard_errorconfidence_levelprediction_interval_lower_boundprediction_interval_upper_boundtrendseasonal_period_yearlyseasonal_period_quarterlyseasonal_period_monthlyseasonal_period_weeklyseasonal_period_dailyholiday_effectspikes_and_dipsstep_changesresidual
02016-08-01 00:00:00+00:00history1711.0505.716474190.614736<NA><NA><NA>0.0<NA><NA><NA>169.611938<NA><NA><NA>1205.283526336.104536
12016-08-02 00:00:00+00:00history2140.0625.750155190.614736<NA><NA><NA>338.716882<NA><NA><NA>287.033273<NA><NA><NA>1205.283526308.966319
22016-08-03 00:00:00+00:00history2890.0995.111101190.614736<NA><NA><NA>549.970223<NA><NA><NA>445.140878<NA><NA><NA>1205.283526689.605373
32016-08-04 00:00:00+00:00history3161.01408.363927190.614736<NA><NA><NA>1005.271573<NA><NA><NA>403.092354<NA><NA><NA>1205.283526547.352547
42016-08-05 00:00:00+00:00history2702.01381.96532190.614736<NA><NA><NA>1236.276965<NA><NA><NA>145.688355<NA><NA><NA>1205.283526114.751155
52016-08-06 00:00:00+00:00history1663.0349.426733190.614736<NA><NA><NA>1100.924343<NA><NA><NA>-751.49761<NA><NA><NA>1205.283526108.289741
62016-08-07 00:00:00+00:00history1622.0432.654477190.614736<NA><NA><NA>1134.891238<NA><NA><NA>-702.236761<NA><NA><NA>1205.283526-15.938002
72016-08-08 00:00:00+00:00history2815.01414.879653190.614736<NA><NA><NA>1239.098627<NA><NA><NA>175.781026<NA><NA><NA>1205.283526194.836821
82016-08-09 00:00:00+00:00history2851.01622.001182190.614736<NA><NA><NA>1328.328933<NA><NA><NA>293.672249<NA><NA><NA>1205.28352623.715292
92016-08-10 00:00:00+00:00history2757.01718.174551190.614736<NA><NA><NA>1273.916936<NA><NA><NA>444.257615<NA><NA><NA>1205.283526-166.458077
102016-08-11 00:00:00+00:00history2667.01537.591365190.614736<NA><NA><NA>1152.629705<NA><NA><NA>384.96166<NA><NA><NA>1205.283526-75.874891
112016-08-12 00:00:00+00:00history2619.01294.642238190.614736<NA><NA><NA>1156.721596<NA><NA><NA>137.920642<NA><NA><NA>1205.283526119.074236
122016-08-13 00:00:00+00:00history1596.0543.975173190.614736<NA><NA><NA>1290.728466<NA><NA><NA>-746.753292<NA><NA><NA>1205.283526-153.258699
132016-08-14 00:00:00+00:00history1801.0488.712986190.614736<NA><NA><NA>1187.781282<NA><NA><NA>-699.068296<NA><NA><NA>1205.283526107.003488
142016-08-15 00:00:00+00:00history3043.01375.529186190.614736<NA><NA><NA>1174.170307<NA><NA><NA>201.358878<NA><NA><NA>1205.283526462.187288
152016-08-16 00:00:00+00:00history2873.01833.311376190.614736<NA><NA><NA>1520.329586<NA><NA><NA>312.98179<NA><NA><NA>1205.283526-165.594902
162016-08-17 00:00:00+00:00history2799.01876.156796190.614736<NA><NA><NA>1441.156251<NA><NA><NA>435.000545<NA><NA><NA>1205.283526-282.440321
172016-08-18 00:00:00+00:00history2725.01412.776022190.614736<NA><NA><NA>1073.311017<NA><NA><NA>339.465005<NA><NA><NA>1205.283526106.940453
182016-08-19 00:00:00+00:00history2379.01328.455858190.614736<NA><NA><NA>1207.351708<NA><NA><NA>121.104149<NA><NA><NA>1205.283526-154.739384
192016-08-20 00:00:00+00:00history1664.0604.913116190.614736<NA><NA><NA>1344.614057<NA><NA><NA>-739.700941<NA><NA><NA>1205.283526-146.196642
202016-08-21 00:00:00+00:00history1730.0453.019443190.614736<NA><NA><NA>1140.422374<NA><NA><NA>-687.402931<NA><NA><NA>1205.28352671.697031
212016-08-22 00:00:00+00:00history2584.01341.060308190.614736<NA><NA><NA>1096.801611<NA><NA><NA>244.258697<NA><NA><NA>1205.28352637.656166
222016-08-23 00:00:00+00:00history2754.01623.867189190.614736<NA><NA><NA>1268.964143<NA><NA><NA>354.903046<NA><NA><NA>1205.283526-75.150715
232016-08-24 00:00:00+00:00history2627.01681.35638190.614736<NA><NA><NA>1242.488663<NA><NA><NA>438.867718<NA><NA><NA>1205.283526-259.639906
242016-08-25 00:00:00+00:00history2539.01241.619428190.614736<NA><NA><NA>973.603865<NA><NA><NA>268.015563<NA><NA><NA>1205.28352692.097046
\n", - "

25 rows × 18 columns

\n", - "
[369 rows x 18 columns in total]" - ], - "text/plain": [ - " time_series_timestamp time_series_type time_series_data \\\n", - "0 2016-08-01 00:00:00+00:00 history 1711.0 \n", - "1 2016-08-02 00:00:00+00:00 history 2140.0 \n", - "2 2016-08-03 00:00:00+00:00 history 2890.0 \n", - "3 2016-08-04 00:00:00+00:00 history 3161.0 \n", - "4 2016-08-05 00:00:00+00:00 history 2702.0 \n", - "5 2016-08-06 00:00:00+00:00 history 1663.0 \n", - "6 2016-08-07 00:00:00+00:00 history 1622.0 \n", - "7 2016-08-08 00:00:00+00:00 history 2815.0 \n", - "8 2016-08-09 00:00:00+00:00 history 2851.0 \n", - "9 2016-08-10 00:00:00+00:00 history 2757.0 \n", - "10 2016-08-11 00:00:00+00:00 history 2667.0 \n", - "11 2016-08-12 00:00:00+00:00 history 2619.0 \n", - "12 2016-08-13 00:00:00+00:00 history 1596.0 \n", - "13 2016-08-14 00:00:00+00:00 history 1801.0 \n", - "14 2016-08-15 00:00:00+00:00 history 3043.0 \n", - "15 2016-08-16 00:00:00+00:00 history 2873.0 \n", - "16 2016-08-17 00:00:00+00:00 history 2799.0 \n", - "17 2016-08-18 00:00:00+00:00 history 2725.0 \n", - "18 2016-08-19 00:00:00+00:00 history 2379.0 \n", - "19 2016-08-20 00:00:00+00:00 history 1664.0 \n", - "20 2016-08-21 00:00:00+00:00 history 1730.0 \n", - "21 2016-08-22 00:00:00+00:00 history 2584.0 \n", - "22 2016-08-23 00:00:00+00:00 history 2754.0 \n", - "23 2016-08-24 00:00:00+00:00 history 2627.0 \n", - "24 2016-08-25 00:00:00+00:00 history 2539.0 \n", - "\n", - " time_series_adjusted_data standard_error confidence_level \\\n", - "0 505.716474 190.614736 \n", - "1 625.750155 190.614736 \n", - "2 995.111101 190.614736 \n", - "3 1408.363927 190.614736 \n", - "4 1381.96532 190.614736 \n", - "5 349.426733 190.614736 \n", - "6 432.654477 190.614736 \n", - "7 1414.879653 190.614736 \n", - "8 1622.001182 190.614736 \n", - "9 1718.174551 190.614736 \n", - "10 1537.591365 190.614736 \n", - "11 1294.642238 190.614736 \n", - "12 543.975173 190.614736 \n", - "13 488.712986 190.614736 \n", - "14 1375.529186 190.614736 \n", - "15 1833.311376 190.614736 \n", - "16 1876.156796 190.614736 \n", - "17 1412.776022 190.614736 \n", - "18 1328.455858 190.614736 \n", - "19 604.913116 190.614736 \n", - "20 453.019443 190.614736 \n", - "21 1341.060308 190.614736 \n", - "22 1623.867189 190.614736 \n", - "23 1681.35638 190.614736 \n", - "24 1241.619428 190.614736 \n", - "\n", - " prediction_interval_lower_bound prediction_interval_upper_bound \\\n", - "0 \n", - "1 \n", - "2 \n", - "3 \n", - "4 \n", - "5 \n", - "6 \n", - "7 \n", - "8 \n", - "9 \n", - "10 \n", - "11 \n", - "12 \n", - "13 \n", - "14 \n", - "15 \n", - "16 \n", - "17 \n", - "18 \n", - "19 \n", - "20 \n", - "21 \n", - "22 \n", - "23 \n", - "24 \n", - "\n", - " trend seasonal_period_yearly seasonal_period_quarterly \\\n", - "0 0.0 \n", - "1 338.716882 \n", - "2 549.970223 \n", - "3 1005.271573 \n", - "4 1236.276965 \n", - "5 1100.924343 \n", - "6 1134.891238 \n", - "7 1239.098627 \n", - "8 1328.328933 \n", - "9 1273.916936 \n", - "10 1152.629705 \n", - "11 1156.721596 \n", - "12 1290.728466 \n", - "13 1187.781282 \n", - "14 1174.170307 \n", - "15 1520.329586 \n", - "16 1441.156251 \n", - "17 1073.311017 \n", - "18 1207.351708 \n", - "19 1344.614057 \n", - "20 1140.422374 \n", - "21 1096.801611 \n", - "22 1268.964143 \n", - "23 1242.488663 \n", - "24 973.603865 \n", - "\n", - " seasonal_period_monthly seasonal_period_weekly seasonal_period_daily \\\n", - "0 169.611938 \n", - "1 287.033273 \n", - "2 445.140878 \n", - "3 403.092354 \n", - "4 145.688355 \n", - "5 -751.49761 \n", - "6 -702.236761 \n", - "7 175.781026 \n", - "8 293.672249 \n", - "9 444.257615 \n", - "10 384.96166 \n", - "11 137.920642 \n", - "12 -746.753292 \n", - "13 -699.068296 \n", - "14 201.358878 \n", - "15 312.98179 \n", - "16 435.000545 \n", - "17 339.465005 \n", - "18 121.104149 \n", - "19 -739.700941 \n", - "20 -687.402931 \n", - "21 244.258697 \n", - "22 354.903046 \n", - "23 438.867718 \n", - "24 268.015563 \n", - "\n", - " holiday_effect spikes_and_dips step_changes residual \n", - "0 1205.283526 336.104536 \n", - "1 1205.283526 308.966319 \n", - "2 1205.283526 689.605373 \n", - "3 1205.283526 547.352547 \n", - "4 1205.283526 114.751155 \n", - "5 1205.283526 108.289741 \n", - "6 1205.283526 -15.938002 \n", - "7 1205.283526 194.836821 \n", - "8 1205.283526 23.715292 \n", - "9 1205.283526 -166.458077 \n", - "10 1205.283526 -75.874891 \n", - "11 1205.283526 119.074236 \n", - "12 1205.283526 -153.258699 \n", - "13 1205.283526 107.003488 \n", - "14 1205.283526 462.187288 \n", - "15 1205.283526 -165.594902 \n", - "16 1205.283526 -282.440321 \n", - "17 1205.283526 106.940453 \n", - "18 1205.283526 -154.739384 \n", - "19 1205.283526 -146.196642 \n", - "20 1205.283526 71.697031 \n", - "21 1205.283526 37.656166 \n", - "22 1205.283526 -75.150715 \n", - "23 1205.283526 -259.639906 \n", - "24 1205.283526 92.097046 \n", - "...\n", - "\n", - "[369 rows x 18 columns]" - ] - }, - "execution_count": 10, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "time_series_arima_plus_model.predict_attribution()" - ] - } - ], - "metadata": { - "kernelspec": { - "display_name": "venv", - "language": "python", - "name": "python3" - }, - "language_info": { - "codemirror_mode": { - "name": "ipython", - "version": 3 - }, - "file_extension": ".py", - "mimetype": "text/x-python", - "name": "python", - "nbconvert_exporter": "python", - "pygments_lexer": "ipython3", - "version": "3.12.1" - } - }, - "nbformat": 4, - "nbformat_minor": 2 -} From 54175c0c76e62bb1a2be0071bc78100c025aac6c Mon Sep 17 00:00:00 2001 From: Daniela Date: Tue, 26 Nov 2024 20:12:29 +0000 Subject: [PATCH 07/18] format and lint --- bigframes/ml/core.py | 8 ++++++-- bigframes/ml/forecasting.py | 2 +- bigframes/ml/sql.py | 6 ++++-- tests/system/small/ml/test_forecasting.py | 18 ++++++++++++++---- 4 files changed, 25 insertions(+), 9 deletions(-) diff --git a/bigframes/ml/core.py b/bigframes/ml/core.py index c73877c558..810fb1f7bd 100644 --- a/bigframes/ml/core.py +++ b/bigframes/ml/core.py @@ -173,8 +173,12 @@ def forecast(self, options: Mapping[str, int | float]) -> bpd.DataFrame: return self._session.read_gbq(sql, index_col="forecast_timestamp").reset_index() def explain_forecast(self, options: Mapping[str, int | float]) -> bpd.DataFrame: - sql = self._model_manipulation_sql_generator.ml_explain_forecast(struct_options=options) - return self._session.read_gbq(sql, index_col="time_series_timestamp").reset_index() + sql = self._model_manipulation_sql_generator.ml_explain_forecast( + struct_options=options + ) + return self._session.read_gbq( + sql, index_col="time_series_timestamp" + ).reset_index() def evaluate(self, input_data: Optional[bpd.DataFrame] = None): sql = self._model_manipulation_sql_generator.ml_evaluate( diff --git a/bigframes/ml/forecasting.py b/bigframes/ml/forecasting.py index faa8603a15..205ac1db38 100644 --- a/bigframes/ml/forecasting.py +++ b/bigframes/ml/forecasting.py @@ -252,7 +252,7 @@ def predict( return self._bqml_model.forecast( options={"horizon": horizon, "confidence_level": confidence_level} ) - + def predict_explain( self, X=None, *, horizon: int = 3, confidence_level: float = 0.95 ) -> bpd.DataFrame: diff --git a/bigframes/ml/sql.py b/bigframes/ml/sql.py index 02ebd97ef1..1ef43d9ce5 100644 --- a/bigframes/ml/sql.py +++ b/bigframes/ml/sql.py @@ -309,8 +309,10 @@ def ml_forecast(self, struct_options: Mapping[str, Union[int, float]]) -> str: struct_options_sql = self.struct_options(**struct_options) return f"""SELECT * FROM ML.FORECAST(MODEL {self._model_ref_sql()}, {struct_options_sql})""" - - def ml_explain_forecast(self, struct_options: Mapping[str, Union[int, float]]) -> str: + + def ml_explain_forecast( + self, struct_options: Mapping[str, Union[int, float]] + ) -> str: """Encode ML.EXPLAIN_FORECAST for BQML""" struct_options_sql = self.struct_options(**struct_options) return f"""SELECT * FROM ML.EXPLAIN_FORECAST(MODEL {self._model_ref_sql()}, diff --git a/tests/system/small/ml/test_forecasting.py b/tests/system/small/ml/test_forecasting.py index 00016c189a..3aeb9d5598 100644 --- a/tests/system/small/ml/test_forecasting.py +++ b/tests/system/small/ml/test_forecasting.py @@ -64,13 +64,16 @@ def test_arima_plus_predict_default( check_index_type=False, ) + def test_arima_plus_predict_explain_default( time_series_arima_plus_model: forecasting.ARIMAPlus, ): utc = pytz.utc predictions = time_series_arima_plus_model.predict_explain().to_pandas() assert predictions.shape[0] == 369 - predictions = predictions[predictions["time_series_type"] == "forecast"].reset_index(drop=True) + predictions = predictions[ + predictions["time_series_type"] == "forecast" + ].reset_index(drop=True) assert predictions.shape[0] == 3 result = predictions[["time_series_timestamp", "time_series_data"]] expected = pd.DataFrame( @@ -83,7 +86,9 @@ def test_arima_plus_predict_explain_default( "time_series_data": [2727.693349, 2595.290749, 2370.86767], } ) - expected["time_series_data"] = expected["time_series_data"].astype(pd.Float64Dtype()) + expected["time_series_data"] = expected["time_series_data"].astype( + pd.Float64Dtype() + ) expected["time_series_timestamp"] = expected["time_series_timestamp"].astype( pd.ArrowDtype(pa.timestamp("us", tz="UTC")) ) @@ -126,7 +131,10 @@ def test_arima_plus_predict_params(time_series_arima_plus_model: forecasting.ARI check_index_type=False, ) -def test_arima_plus_predict_explain_params(time_series_arima_plus_model: forecasting.ARIMAPlus): + +def test_arima_plus_predict_explain_params( + time_series_arima_plus_model: forecasting.ARIMAPlus, +): utc = pytz.utc predictions = time_series_arima_plus_model.predict_explain( horizon=4, confidence_level=0.9 @@ -144,7 +152,9 @@ def test_arima_plus_predict_explain_params(time_series_arima_plus_model: forecas "time_series_data": [2724.472284, 2593.368389, 2353.613034, 1781.623071], } ) - expected["time_series_data"] = expected["time_series_data"].astype(pd.Float64Dtype()) + expected["time_series_data"] = expected["time_series_data"].astype( + pd.Float64Dtype() + ) expected["time_series_timestamp"] = expected["time_series_timestamp"].astype( pd.ArrowDtype(pa.timestamp("us", tz="UTC")) ) From 448e63af616e31efa269baaecdfd5044565d11fd Mon Sep 17 00:00:00 2001 From: rey-esp Date: Mon, 2 Dec 2024 09:56:15 -0600 Subject: [PATCH 08/18] Update bigframes/ml/forecasting.py MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Co-authored-by: Tim Sweña (Swast) --- bigframes/ml/forecasting.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/bigframes/ml/forecasting.py b/bigframes/ml/forecasting.py index c754c1ea80..6079e0ea22 100644 --- a/bigframes/ml/forecasting.py +++ b/bigframes/ml/forecasting.py @@ -276,8 +276,8 @@ def predict_explain( Returns: bigframes.dataframe.DataFrame: The predicted DataFrames. """ - if horizon < 1 or horizon > 1000: - raise ValueError(f"horizon must be [1, 1000], but is {horizon}.") + if horizon < 1: + raise ValueError(f"horizon must be at least 1, but is {horizon}.") if confidence_level < 0.0 or confidence_level >= 1.0: raise ValueError( f"confidence_level must be [0.0, 1.0), but is {confidence_level}." From 8347c4b9bffca0d2131f15d5864169cf29a450a0 Mon Sep 17 00:00:00 2001 From: Daniela Date: Mon, 2 Dec 2024 21:23:26 +0000 Subject: [PATCH 09/18] update predict explain params test --- tests/system/small/ml/test_forecasting.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/system/small/ml/test_forecasting.py b/tests/system/small/ml/test_forecasting.py index 3aeb9d5598..68eec767c6 100644 --- a/tests/system/small/ml/test_forecasting.py +++ b/tests/system/small/ml/test_forecasting.py @@ -139,7 +139,7 @@ def test_arima_plus_predict_explain_params( predictions = time_series_arima_plus_model.predict_explain( horizon=4, confidence_level=0.9 ).to_pandas() - assert predictions.shape == (4, 8) + assert predictions.shape == (370, 17) result = predictions[["time_series_timestamp", "time_series_data"]] expected = pd.DataFrame( { From 1fe2d37051954aca0a260699af445597f02365ea Mon Sep 17 00:00:00 2001 From: Daniela Date: Tue, 3 Dec 2024 20:03:10 +0000 Subject: [PATCH 10/18] update test --- tests/system/small/ml/test_forecasting.py | 43 ++++++++++------------- 1 file changed, 18 insertions(+), 25 deletions(-) diff --git a/tests/system/small/ml/test_forecasting.py b/tests/system/small/ml/test_forecasting.py index 68eec767c6..b692cfc3c5 100644 --- a/tests/system/small/ml/test_forecasting.py +++ b/tests/system/small/ml/test_forecasting.py @@ -139,32 +139,25 @@ def test_arima_plus_predict_explain_params( predictions = time_series_arima_plus_model.predict_explain( horizon=4, confidence_level=0.9 ).to_pandas() - assert predictions.shape == (370, 17) - result = predictions[["time_series_timestamp", "time_series_data"]] - expected = pd.DataFrame( - { - "time_series_timestamp": [ - datetime(2017, 8, 2, tzinfo=utc), - datetime(2017, 8, 3, tzinfo=utc), - datetime(2017, 8, 4, tzinfo=utc), - datetime(2017, 8, 5, tzinfo=utc), - ], - "time_series_data": [2724.472284, 2593.368389, 2353.613034, 1781.623071], + assert predictions.shape[0] >= 1 + prediction_columns = set(predictions.columns) + expected_columns = { + 'time_series_timestamp', + 'time_series_type', + 'time_series_data', + 'time_series_adjusted_data', + 'standard_error', + 'confidence_level', + 'prediction_interval_lower_bound', + 'trend', + 'seasonal_period_yearly', + 'seasonal_period_quarterly', + 'seasonal_period_monthly', + 'seasonal_period_weekly', + 'seasonal_period_daily', + 'holiday_effect', } - ) - expected["time_series_data"] = expected["time_series_data"].astype( - pd.Float64Dtype() - ) - expected["time_series_timestamp"] = expected["time_series_timestamp"].astype( - pd.ArrowDtype(pa.timestamp("us", tz="UTC")) - ) - - pd.testing.assert_frame_equal( - result, - expected, - rtol=0.1, - check_index_type=False, - ) + assert expected_columns <= prediction_columns def test_arima_plus_detect_anomalies( From 48c81ed7f94da9fd5a6998745224f7afdbe99d91 Mon Sep 17 00:00:00 2001 From: Owl Bot Date: Tue, 3 Dec 2024 20:09:34 +0000 Subject: [PATCH 11/18] =?UTF-8?q?=F0=9F=A6=89=20Updates=20from=20OwlBot=20?= =?UTF-8?q?post-processor?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit See https://github.com/googleapis/repo-automation-bots/blob/main/packages/owl-bot/README.md --- tests/system/small/ml/test_forecasting.py | 30 +++++++++++------------ 1 file changed, 15 insertions(+), 15 deletions(-) diff --git a/tests/system/small/ml/test_forecasting.py b/tests/system/small/ml/test_forecasting.py index b692cfc3c5..711b9f2177 100644 --- a/tests/system/small/ml/test_forecasting.py +++ b/tests/system/small/ml/test_forecasting.py @@ -142,21 +142,21 @@ def test_arima_plus_predict_explain_params( assert predictions.shape[0] >= 1 prediction_columns = set(predictions.columns) expected_columns = { - 'time_series_timestamp', - 'time_series_type', - 'time_series_data', - 'time_series_adjusted_data', - 'standard_error', - 'confidence_level', - 'prediction_interval_lower_bound', - 'trend', - 'seasonal_period_yearly', - 'seasonal_period_quarterly', - 'seasonal_period_monthly', - 'seasonal_period_weekly', - 'seasonal_period_daily', - 'holiday_effect', - } + "time_series_timestamp", + "time_series_type", + "time_series_data", + "time_series_adjusted_data", + "standard_error", + "confidence_level", + "prediction_interval_lower_bound", + "trend", + "seasonal_period_yearly", + "seasonal_period_quarterly", + "seasonal_period_monthly", + "seasonal_period_weekly", + "seasonal_period_daily", + "holiday_effect", + } assert expected_columns <= prediction_columns From 706a1aeb4590351dc58667630b2e4c63ed294d81 Mon Sep 17 00:00:00 2001 From: Daniela Date: Wed, 4 Dec 2024 23:09:46 +0000 Subject: [PATCH 12/18] add unit test file - bare bones --- tests/unit/ml/test_forecasting.py | 70 +++++++++++++++++++++++++++++++ 1 file changed, 70 insertions(+) create mode 100644 tests/unit/ml/test_forecasting.py diff --git a/tests/unit/ml/test_forecasting.py b/tests/unit/ml/test_forecasting.py new file mode 100644 index 0000000000..b45b8a6473 --- /dev/null +++ b/tests/unit/ml/test_forecasting.py @@ -0,0 +1,70 @@ +# Copyright 2023 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# from unittest import mock + +# from google.cloud import bigquery +import pytest + +from bigframes.ml import forecasting + +# import bigframes.pandas as bpd +# from . import resources + + +def test_predict_explain_confidence_level(): + confidence_level = 0.9 + + forecasting.ARIMAPlus.predict_explain( + horizon=4, confidence_level=confidence_level + ) + +def test_predict_explain_low_confidence_level(): + confidence_level = -0.5 + + with pytest.raises( + ValueError, + match=f"confidence_level must be 0.0 and 1.0, but is {confidence_level}.", + ): + forecasting.ARIMAPlus.predict_explain( + horizon=4, confidence_level=confidence_level + ) + +def test_predict_high_explain_confidence_level(): + confidence_level = 2.1 + + with pytest.raises( + ValueError, + match=f"confidence_level must be 0.0 and 1.0, but is {confidence_level}.", + ): + forecasting.ARIMAPlus.predict_explain( + horizon=4, confidence_level=confidence_level + ) + + +def test_predict_explain_horizon(): + horizon = 1 + + forecasting.ARIMAPlus.predict_explain( + horizon=horizon, confidence_level=0.9 + ) + +def test_predict_explain_low_horizon(): + horizon = 0.5 + + with pytest.raises( + ValueError, match=f"horizon must be at least 1, but is {horizon}." + ): + forecasting.ARIMAPlus.predict_explain( + horizon=horizon, confidence_level=0.9 + ) From 79a535923b70dde088af9c590ad8e81e372590f3 Mon Sep 17 00:00:00 2001 From: Owl Bot Date: Wed, 4 Dec 2024 23:12:30 +0000 Subject: [PATCH 13/18] =?UTF-8?q?=F0=9F=A6=89=20Updates=20from=20OwlBot=20?= =?UTF-8?q?post-processor?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit See https://github.com/googleapis/repo-automation-bots/blob/main/packages/owl-bot/README.md --- tests/unit/ml/test_forecasting.py | 15 ++++++--------- 1 file changed, 6 insertions(+), 9 deletions(-) diff --git a/tests/unit/ml/test_forecasting.py b/tests/unit/ml/test_forecasting.py index b45b8a6473..c840306fdf 100644 --- a/tests/unit/ml/test_forecasting.py +++ b/tests/unit/ml/test_forecasting.py @@ -25,9 +25,8 @@ def test_predict_explain_confidence_level(): confidence_level = 0.9 - forecasting.ARIMAPlus.predict_explain( - horizon=4, confidence_level=confidence_level - ) + forecasting.ARIMAPlus.predict_explain(horizon=4, confidence_level=confidence_level) + def test_predict_explain_low_confidence_level(): confidence_level = -0.5 @@ -40,6 +39,7 @@ def test_predict_explain_low_confidence_level(): horizon=4, confidence_level=confidence_level ) + def test_predict_high_explain_confidence_level(): confidence_level = 2.1 @@ -55,9 +55,8 @@ def test_predict_high_explain_confidence_level(): def test_predict_explain_horizon(): horizon = 1 - forecasting.ARIMAPlus.predict_explain( - horizon=horizon, confidence_level=0.9 - ) + forecasting.ARIMAPlus.predict_explain(horizon=horizon, confidence_level=0.9) + def test_predict_explain_low_horizon(): horizon = 0.5 @@ -65,6 +64,4 @@ def test_predict_explain_low_horizon(): with pytest.raises( ValueError, match=f"horizon must be at least 1, but is {horizon}." ): - forecasting.ARIMAPlus.predict_explain( - horizon=horizon, confidence_level=0.9 - ) + forecasting.ARIMAPlus.predict_explain(horizon=horizon, confidence_level=0.9) From 3befd2ec1f3b78707eb83b0cf88ed32911a50bc1 Mon Sep 17 00:00:00 2001 From: Daniela Date: Mon, 9 Dec 2024 20:53:33 +0000 Subject: [PATCH 14/18] fixed tests --- tests/unit/ml/test_forecasting.py | 35 ++++++++++++------------------- 1 file changed, 13 insertions(+), 22 deletions(-) diff --git a/tests/unit/ml/test_forecasting.py b/tests/unit/ml/test_forecasting.py index c840306fdf..5d56f32694 100644 --- a/tests/unit/ml/test_forecasting.py +++ b/tests/unit/ml/test_forecasting.py @@ -11,31 +11,24 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -# from unittest import mock -# from google.cloud import bigquery +import re + import pytest from bigframes.ml import forecasting -# import bigframes.pandas as bpd -# from . import resources - - -def test_predict_explain_confidence_level(): - confidence_level = 0.9 - - forecasting.ARIMAPlus.predict_explain(horizon=4, confidence_level=confidence_level) - def test_predict_explain_low_confidence_level(): confidence_level = -0.5 + model = forecasting.ARIMAPlus() + with pytest.raises( ValueError, - match=f"confidence_level must be 0.0 and 1.0, but is {confidence_level}.", + match=re.escape(f"confidence_level must be [0.0, 1.0), but is {confidence_level}."), ): - forecasting.ARIMAPlus.predict_explain( + model.predict_explain( horizon=4, confidence_level=confidence_level ) @@ -43,25 +36,23 @@ def test_predict_explain_low_confidence_level(): def test_predict_high_explain_confidence_level(): confidence_level = 2.1 + model = forecasting.ARIMAPlus() + with pytest.raises( ValueError, - match=f"confidence_level must be 0.0 and 1.0, but is {confidence_level}.", + match=re.escape(f"confidence_level must be [0.0, 1.0), but is {confidence_level}."), ): - forecasting.ARIMAPlus.predict_explain( + model.predict_explain( horizon=4, confidence_level=confidence_level ) -def test_predict_explain_horizon(): - horizon = 1 - - forecasting.ARIMAPlus.predict_explain(horizon=horizon, confidence_level=0.9) - - def test_predict_explain_low_horizon(): horizon = 0.5 + model = forecasting.ARIMAPlus() + with pytest.raises( ValueError, match=f"horizon must be at least 1, but is {horizon}." ): - forecasting.ARIMAPlus.predict_explain(horizon=horizon, confidence_level=0.9) + model.predict_explain(horizon=horizon, confidence_level=0.9) From 3fbcb64db66238dbac96de094dc50c3f95651c5f Mon Sep 17 00:00:00 2001 From: Owl Bot Date: Mon, 9 Dec 2024 20:55:53 +0000 Subject: [PATCH 15/18] =?UTF-8?q?=F0=9F=A6=89=20Updates=20from=20OwlBot=20?= =?UTF-8?q?post-processor?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit See https://github.com/googleapis/repo-automation-bots/blob/main/packages/owl-bot/README.md --- tests/unit/ml/test_forecasting.py | 16 ++++++++-------- 1 file changed, 8 insertions(+), 8 deletions(-) diff --git a/tests/unit/ml/test_forecasting.py b/tests/unit/ml/test_forecasting.py index 5d56f32694..b35fc7bf9b 100644 --- a/tests/unit/ml/test_forecasting.py +++ b/tests/unit/ml/test_forecasting.py @@ -26,11 +26,11 @@ def test_predict_explain_low_confidence_level(): with pytest.raises( ValueError, - match=re.escape(f"confidence_level must be [0.0, 1.0), but is {confidence_level}."), + match=re.escape( + f"confidence_level must be [0.0, 1.0), but is {confidence_level}." + ), ): - model.predict_explain( - horizon=4, confidence_level=confidence_level - ) + model.predict_explain(horizon=4, confidence_level=confidence_level) def test_predict_high_explain_confidence_level(): @@ -40,11 +40,11 @@ def test_predict_high_explain_confidence_level(): with pytest.raises( ValueError, - match=re.escape(f"confidence_level must be [0.0, 1.0), but is {confidence_level}."), + match=re.escape( + f"confidence_level must be [0.0, 1.0), but is {confidence_level}." + ), ): - model.predict_explain( - horizon=4, confidence_level=confidence_level - ) + model.predict_explain(horizon=4, confidence_level=confidence_level) def test_predict_explain_low_horizon(): From 04d1fb48959fe8b3d26964634ab795005f546ca7 Mon Sep 17 00:00:00 2001 From: Daniela Date: Mon, 9 Dec 2024 21:26:03 +0000 Subject: [PATCH 16/18] lint --- tests/unit/ml/test_forecasting.py | 16 ++++++++-------- 1 file changed, 8 insertions(+), 8 deletions(-) diff --git a/tests/unit/ml/test_forecasting.py b/tests/unit/ml/test_forecasting.py index 5d56f32694..b35fc7bf9b 100644 --- a/tests/unit/ml/test_forecasting.py +++ b/tests/unit/ml/test_forecasting.py @@ -26,11 +26,11 @@ def test_predict_explain_low_confidence_level(): with pytest.raises( ValueError, - match=re.escape(f"confidence_level must be [0.0, 1.0), but is {confidence_level}."), + match=re.escape( + f"confidence_level must be [0.0, 1.0), but is {confidence_level}." + ), ): - model.predict_explain( - horizon=4, confidence_level=confidence_level - ) + model.predict_explain(horizon=4, confidence_level=confidence_level) def test_predict_high_explain_confidence_level(): @@ -40,11 +40,11 @@ def test_predict_high_explain_confidence_level(): with pytest.raises( ValueError, - match=re.escape(f"confidence_level must be [0.0, 1.0), but is {confidence_level}."), + match=re.escape( + f"confidence_level must be [0.0, 1.0), but is {confidence_level}." + ), ): - model.predict_explain( - horizon=4, confidence_level=confidence_level - ) + model.predict_explain(horizon=4, confidence_level=confidence_level) def test_predict_explain_low_horizon(): From 6bfb1d3f0c6de71943531492cf2da21f27f59427 Mon Sep 17 00:00:00 2001 From: Daniela Date: Mon, 9 Dec 2024 21:35:48 +0000 Subject: [PATCH 17/18] lint --- tests/system/small/ml/test_forecasting.py | 1 - 1 file changed, 1 deletion(-) diff --git a/tests/system/small/ml/test_forecasting.py b/tests/system/small/ml/test_forecasting.py index 711b9f2177..1b3a650388 100644 --- a/tests/system/small/ml/test_forecasting.py +++ b/tests/system/small/ml/test_forecasting.py @@ -135,7 +135,6 @@ def test_arima_plus_predict_params(time_series_arima_plus_model: forecasting.ARI def test_arima_plus_predict_explain_params( time_series_arima_plus_model: forecasting.ARIMAPlus, ): - utc = pytz.utc predictions = time_series_arima_plus_model.predict_explain( horizon=4, confidence_level=0.9 ).to_pandas() From e2eb29d388d2c3d6c5cd04db588e64ce17fd29cf Mon Sep 17 00:00:00 2001 From: Daniela Date: Mon, 9 Dec 2024 22:13:42 +0000 Subject: [PATCH 18/18] fix test: float -> int --- tests/unit/ml/test_forecasting.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/unit/ml/test_forecasting.py b/tests/unit/ml/test_forecasting.py index b35fc7bf9b..3bbf4c777e 100644 --- a/tests/unit/ml/test_forecasting.py +++ b/tests/unit/ml/test_forecasting.py @@ -48,7 +48,7 @@ def test_predict_high_explain_confidence_level(): def test_predict_explain_low_horizon(): - horizon = 0.5 + horizon = -1 model = forecasting.ARIMAPlus()