From 97d92594ae400e3ad6aafade862380ff1bf48e58 Mon Sep 17 00:00:00 2001 From: Shuowei Li Date: Tue, 3 Dec 2024 00:16:38 +0000 Subject: [PATCH 01/12] docs(bigquery): update minor parts in base.py --- bigframes/ml/base.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/bigframes/ml/base.py b/bigframes/ml/base.py index 5662e54d6d..26918f67cd 100644 --- a/bigframes/ml/base.py +++ b/bigframes/ml/base.py @@ -33,7 +33,7 @@ class BaseEstimator(bigframes_vendored.sklearn.base.BaseEstimator, abc.ABC): """ - A BigQuery DataFrames machine learning component following the SKLearn API + A BigQuery DataFrames machine learning component follows SKLearn API design Ref: https://bit.ly/3NyhKjN The estimator is the fundamental abstraction for all learning components. This includes learning From c9318d0212f4b357806e6be94a9bc04820c03e23 Mon Sep 17 00:00:00 2001 From: Shuowei Li Date: Tue, 3 Dec 2024 00:20:51 +0000 Subject: [PATCH 02/12] docs(bigquery): update minor changes for bigframes/ml/base.py --- bigframes/ml/base.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/bigframes/ml/base.py b/bigframes/ml/base.py index 26918f67cd..4058647adb 100644 --- a/bigframes/ml/base.py +++ b/bigframes/ml/base.py @@ -33,7 +33,7 @@ class BaseEstimator(bigframes_vendored.sklearn.base.BaseEstimator, abc.ABC): """ - A BigQuery DataFrames machine learning component follows SKLearn API + A BigQuery DataFrames machine learning component follows sklearn API design Ref: https://bit.ly/3NyhKjN The estimator is the fundamental abstraction for all learning components. This includes learning From e9f28f48b4758ccbce1d059a769912279c364bfa Mon Sep 17 00:00:00 2001 From: Shuowei Li Date: Wed, 11 Dec 2024 18:29:19 +0000 Subject: [PATCH 03/12] feat: Update lUpdate GeminiTextGenerator Tuning and Support score() method in Gemini-pro-1.5 \n Bug: b/381936588 and b/344891364 --- bigframes/ml/llm.py | 20 +++++++++++++------- tests/system/load/test_llm.py | 13 ++++++++++--- tests/system/small/ml/test_llm.py | 13 ++++++++++--- 3 files changed, 33 insertions(+), 13 deletions(-) diff --git a/bigframes/ml/llm.py b/bigframes/ml/llm.py index 427e99583d..a9c024236b 100644 --- a/bigframes/ml/llm.py +++ b/bigframes/ml/llm.py @@ -874,7 +874,7 @@ def fit( X: utils.ArrayType, y: utils.ArrayType, ) -> GeminiTextGenerator: - """Fine tune GeminiTextGenerator model. Only support "gemini-pro" model for now. + """Fine tune GeminiTextGenerator model. Only support "gemini-pro" and "gemini-1.5" models for now. .. note:: @@ -892,8 +892,11 @@ def fit( Returns: GeminiTextGenerator: Fitted estimator. """ - if self._bqml_model.model_name.startswith("gemini-1.5"): - raise NotImplementedError("Fit is not supported for gemini-1.5 model.") + supported_models = ["gemini-pro", "gemini-1.5-pro-002", "gemini-1.5-flash-002"] + if self.model_name not in supported_models: + raise NotImplementedError( + "Score is not supported models other than gemini-pro or gemini-1.5 model." + ) X, y = utils.batch_convert_to_dataframe(X, y) @@ -1009,7 +1012,7 @@ def score( "text_generation", "classification", "summarization", "question_answering" ] = "text_generation", ) -> bpd.DataFrame: - """Calculate evaluation metrics of the model. Only "gemini-pro" model is supported for now. + """Calculate evaluation metrics of the model. Only "gemini-pro" and "gemini-1.5" models are supported for now. .. note:: @@ -1041,9 +1044,12 @@ def score( if not self._bqml_model: raise RuntimeError("A model must be fitted before score") - # TODO(ashleyxu): Support gemini-1.5 when the rollout is ready. b/344891364. - if self._bqml_model.model_name.startswith("gemini-1.5"): - raise NotImplementedError("Score is not supported for gemini-1.5 model.") + # Support gemini-1.5 and gemini-pro + supported_models = ["gemini-pro", "gemini-1.5-pro-002", "gemini-1.5-flash-002"] + if self.model_name not in supported_models: + raise NotImplementedError( + "Score is not supported models other than gemini-pro or gemini-1.5 model." + ) X, y = utils.batch_convert_to_dataframe(X, y, session=self._bqml_model.session) diff --git a/tests/system/load/test_llm.py b/tests/system/load/test_llm.py index 9ef60bae0b..7e15583421 100644 --- a/tests/system/load/test_llm.py +++ b/tests/system/load/test_llm.py @@ -38,12 +38,19 @@ def llm_remote_text_df(session, llm_remote_text_pandas_df): return session.read_pandas(llm_remote_text_pandas_df) -@pytest.mark.flaky(retries=2) +@pytest.mark.parametrize( + "model_name", + ( + "gemini-pro", + "gemini-1.5-pro-002", + "gemini-1.5-flash-002", + ), +) def test_llm_gemini_configure_fit( - session, llm_fine_tune_df_default_index, llm_remote_text_df + session, model_name, llm_fine_tune_df_default_index, llm_remote_text_df ): model = llm.GeminiTextGenerator( - session=session, model_name="gemini-pro", max_iterations=1 + session=session, model_name=model_name, max_iterations=1 ) X_train = llm_fine_tune_df_default_index[["prompt"]] diff --git a/tests/system/small/ml/test_llm.py b/tests/system/small/ml/test_llm.py index 4bc1bd63be..18538ac404 100644 --- a/tests/system/small/ml/test_llm.py +++ b/tests/system/small/ml/test_llm.py @@ -413,9 +413,16 @@ def test_llm_palm_score_params(llm_fine_tune_df_default_index): ) -@pytest.mark.flaky(retries=2) -def test_llm_gemini_pro_score(llm_fine_tune_df_default_index): - model = llm.GeminiTextGenerator(model_name="gemini-pro") +@pytest.mark.parametrize( + "model_name", + ( + "gemini-pro", + "gemini-1.5-pro-002", + "gemini-1.5-flash-002", + ), +) +def test_llm_gemini_pro_score(model_name, llm_fine_tune_df_default_index): + model = llm.GeminiTextGenerator(model_name=model_name) # Check score to ensure the model was fitted score_result = model.score( From 5d2a807186d6e65f1cd0bcdbee2ecc9770f63762 Mon Sep 17 00:00:00 2001 From: Shuowei Li Date: Wed, 11 Dec 2024 18:46:06 +0000 Subject: [PATCH 04/12] feat: Update lUpdate GeminiTextGenerator Tuning and Support score() method in Gemini-pro-1.5 \n Bug: b/381936588 and b/344891364 --- bigframes/ml/llm.py | 1 + tests/system/load/test_llm.py | 1 - tests/system/small/ml/test_llm.py | 3 ++- 3 files changed, 3 insertions(+), 2 deletions(-) diff --git a/bigframes/ml/llm.py b/bigframes/ml/llm.py index a9c024236b..5b705b4fee 100644 --- a/bigframes/ml/llm.py +++ b/bigframes/ml/llm.py @@ -892,6 +892,7 @@ def fit( Returns: GeminiTextGenerator: Fitted estimator. """ + # Support gemini-1.5 and gemini-pro supported_models = ["gemini-pro", "gemini-1.5-pro-002", "gemini-1.5-flash-002"] if self.model_name not in supported_models: raise NotImplementedError( diff --git a/tests/system/load/test_llm.py b/tests/system/load/test_llm.py index 7e15583421..45dd1667a6 100644 --- a/tests/system/load/test_llm.py +++ b/tests/system/load/test_llm.py @@ -76,7 +76,6 @@ def test_llm_gemini_configure_fit( ], index=3, ) - # TODO(ashleyxu b/335492787): After bqml rolled out version control: save, load, check parameters to ensure configuration was kept @pytest.mark.flaky(retries=2) diff --git a/tests/system/small/ml/test_llm.py b/tests/system/small/ml/test_llm.py index 18538ac404..86d970f56b 100644 --- a/tests/system/small/ml/test_llm.py +++ b/tests/system/small/ml/test_llm.py @@ -413,6 +413,7 @@ def test_llm_palm_score_params(llm_fine_tune_df_default_index): ) +# test score() function for "gemini-pro" and "gemini-1.5" model @pytest.mark.parametrize( "model_name", ( @@ -421,7 +422,7 @@ def test_llm_palm_score_params(llm_fine_tune_df_default_index): "gemini-1.5-flash-002", ), ) -def test_llm_gemini_pro_score(model_name, llm_fine_tune_df_default_index): +def test_llm_gemini_score(llm_fine_tune_df_default_index, model_name): model = llm.GeminiTextGenerator(model_name=model_name) # Check score to ensure the model was fitted From 8be16d3bc1194d421323414f73722d8050591bcc Mon Sep 17 00:00:00 2001 From: Shuowei Li Date: Wed, 11 Dec 2024 23:46:34 +0000 Subject: [PATCH 05/12] update testcase and docs for better clarification --- bigframes/ml/llm.py | 11 ++++++----- tests/system/small/ml/test_llm.py | 14 ++++++++++---- 2 files changed, 16 insertions(+), 9 deletions(-) diff --git a/bigframes/ml/llm.py b/bigframes/ml/llm.py index 103d747807..0632a85d83 100644 --- a/bigframes/ml/llm.py +++ b/bigframes/ml/llm.py @@ -890,7 +890,8 @@ def fit( X: utils.ArrayType, y: utils.ArrayType, ) -> GeminiTextGenerator: - """Fine tune GeminiTextGenerator model. Only support "gemini-pro" and "gemini-1.5" models for now. + """Fine tune GeminiTextGenerator model. Only support "gemini-pro", "gemini-1.5-pro-002", + "gemini-1.5-flash-002" models for now. .. note:: @@ -908,11 +909,11 @@ def fit( Returns: GeminiTextGenerator: Fitted estimator. """ - # Support gemini-1.5 and gemini-pro supported_models = ["gemini-pro", "gemini-1.5-pro-002", "gemini-1.5-flash-002"] if self.model_name not in supported_models: raise NotImplementedError( - "Score is not supported models other than gemini-pro or gemini-1.5 model." + "Score is not supported for models other than gemini-pro, \ + gemini-1.5-pro-002, or gemini-1.5-flash-002 model." ) X, y = utils.batch_convert_to_dataframe(X, y) @@ -1061,11 +1062,11 @@ def score( if not self._bqml_model: raise RuntimeError("A model must be fitted before score") - # Support gemini-1.5 and gemini-pro supported_models = ["gemini-pro", "gemini-1.5-pro-002", "gemini-1.5-flash-002"] if self.model_name not in supported_models: raise NotImplementedError( - "Score is not supported models other than gemini-pro or gemini-1.5 model." + "Score is not supported models other than gemini-pro \ + , gemini-1.5-pro-002, or gemini-1.5-flash-2 model." ) X, y = utils.batch_convert_to_dataframe(X, y, session=self._bqml_model.session) diff --git a/tests/system/small/ml/test_llm.py b/tests/system/small/ml/test_llm.py index 86d970f56b..a0813f276c 100644 --- a/tests/system/small/ml/test_llm.py +++ b/tests/system/small/ml/test_llm.py @@ -413,7 +413,6 @@ def test_llm_palm_score_params(llm_fine_tune_df_default_index): ) -# test score() function for "gemini-pro" and "gemini-1.5" model @pytest.mark.parametrize( "model_name", ( @@ -443,9 +442,16 @@ def test_llm_gemini_score(llm_fine_tune_df_default_index, model_name): ) -@pytest.mark.flaky(retries=2) -def test_llm_gemini_pro_score_params(llm_fine_tune_df_default_index): - model = llm.GeminiTextGenerator(model_name="gemini-pro") +@pytest.mark.parametrize( + "model_name", + ( + "gemini-pro", + "gemini-1.5-pro-002", + "gemini-1.5-flash-002", + ), +) +def test_llm_gemini_pro_score_params(llm_fine_tune_df_default_index, model_name): + model = llm.GeminiTextGenerator(model_name=model_name) # Check score to ensure the model was fitted score_result = model.score( From 8a39a0202de11d959f1a948e836749faeb99b5cf Mon Sep 17 00:00:00 2001 From: Shuowei Li Date: Thu, 12 Dec 2024 01:30:49 +0000 Subject: [PATCH 06/12] update endpoint to corresponding endpoint for fine tuning. --- bigframes/ml/llm.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/bigframes/ml/llm.py b/bigframes/ml/llm.py index 0632a85d83..6043470a25 100644 --- a/bigframes/ml/llm.py +++ b/bigframes/ml/llm.py @@ -919,7 +919,9 @@ def fit( X, y = utils.batch_convert_to_dataframe(X, y) options = self._bqml_options - options["endpoint"] = "gemini-1.0-pro-002" + options["endpoint"] = ( + "gemini-1.0-pro-002" if self.model_name == "gemini-pro" else self.model_name + ) options["prompt_col"] = X.columns.tolist()[0] self._bqml_model = self._bqml_model_factory.create_llm_remote_model( From 9de2c0eb928dc9422de87dbecf94ab22741d6aef Mon Sep 17 00:00:00 2001 From: Shuowei Li Date: Tue, 3 Dec 2024 00:16:38 +0000 Subject: [PATCH 07/12] docs(bigquery): update minor parts in base.py --- bigframes/ml/base.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/bigframes/ml/base.py b/bigframes/ml/base.py index 4058647adb..26918f67cd 100644 --- a/bigframes/ml/base.py +++ b/bigframes/ml/base.py @@ -33,7 +33,7 @@ class BaseEstimator(bigframes_vendored.sklearn.base.BaseEstimator, abc.ABC): """ - A BigQuery DataFrames machine learning component follows sklearn API + A BigQuery DataFrames machine learning component follows SKLearn API design Ref: https://bit.ly/3NyhKjN The estimator is the fundamental abstraction for all learning components. This includes learning From ed001b82875ee15e8fc95002e35f693ee2eff3a4 Mon Sep 17 00:00:00 2001 From: Shuowei Li Date: Thu, 12 Dec 2024 20:23:51 +0000 Subject: [PATCH 08/12] fix syntax issue --- bigframes/ml/llm.py | 24 ++++++++++++++++-------- 1 file changed, 16 insertions(+), 8 deletions(-) diff --git a/bigframes/ml/llm.py b/bigframes/ml/llm.py index 6043470a25..f1631b8333 100644 --- a/bigframes/ml/llm.py +++ b/bigframes/ml/llm.py @@ -79,6 +79,16 @@ _GEMINI_1P5_PRO_FLASH_PREVIEW_ENDPOINT, _GEMINI_2_FLASH_EXP_ENDPOINT, ) +_GEMINI_FINE_TUNE_ENDPOINTS = ( + _GEMINI_PRO_ENDPOINT, + _GEMINI_1P5_PRO_002_ENDPOINT, + _GEMINI_1P5_FLASH_002_ENDPOINT, +) +_GEMINI_SCORE_ENDPOINTS = ( + _GEMINI_PRO_ENDPOINT, + _GEMINI_1P5_PRO_002_ENDPOINT, + _GEMINI_1P5_FLASH_002_ENDPOINT, +) _CLAUDE_3_SONNET_ENDPOINT = "claude-3-sonnet" _CLAUDE_3_HAIKU_ENDPOINT = "claude-3-haiku" @@ -909,10 +919,9 @@ def fit( Returns: GeminiTextGenerator: Fitted estimator. """ - supported_models = ["gemini-pro", "gemini-1.5-pro-002", "gemini-1.5-flash-002"] - if self.model_name not in supported_models: + if self.model_name not in _GEMINI_FINE_TUNE_ENDPOINTS: raise NotImplementedError( - "Score is not supported for models other than gemini-pro, \ + "fit() only supports gemini-pro, \ gemini-1.5-pro-002, or gemini-1.5-flash-002 model." ) @@ -1032,7 +1041,7 @@ def score( "text_generation", "classification", "summarization", "question_answering" ] = "text_generation", ) -> bpd.DataFrame: - """Calculate evaluation metrics of the model. Only "gemini-pro" and "gemini-1.5" models are supported for now. + """Calculate evaluation metrics of the model. Only support "gemini-pro" and "gemini-1.5-pro-002", and "gemini-1.5-flash-002". .. note:: @@ -1064,11 +1073,10 @@ def score( if not self._bqml_model: raise RuntimeError("A model must be fitted before score") - supported_models = ["gemini-pro", "gemini-1.5-pro-002", "gemini-1.5-flash-002"] - if self.model_name not in supported_models: + if self.model_name not in _GEMINI_SCORE_ENDPOINTS: raise NotImplementedError( - "Score is not supported models other than gemini-pro \ - , gemini-1.5-pro-002, or gemini-1.5-flash-2 model." + "score() only supports gemini-pro \ + , gemini-1.5-pro-002, and gemini-1.5-flash-2 model." ) X, y = utils.batch_convert_to_dataframe(X, y, session=self._bqml_model.session) From 9928f105a8d382ca1efa894b28e28c862fca239a Mon Sep 17 00:00:00 2001 From: Shuowei Li Date: Thu, 12 Dec 2024 21:15:47 +0000 Subject: [PATCH 09/12] Revert "docs(bigquery): update minor parts in base.py" This reverts commit 9de2c0eb928dc9422de87dbecf94ab22741d6aef. --- bigframes/ml/base.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/bigframes/ml/base.py b/bigframes/ml/base.py index 26918f67cd..4058647adb 100644 --- a/bigframes/ml/base.py +++ b/bigframes/ml/base.py @@ -33,7 +33,7 @@ class BaseEstimator(bigframes_vendored.sklearn.base.BaseEstimator, abc.ABC): """ - A BigQuery DataFrames machine learning component follows SKLearn API + A BigQuery DataFrames machine learning component follows sklearn API design Ref: https://bit.ly/3NyhKjN The estimator is the fundamental abstraction for all learning components. This includes learning From 241ae7300b373e8f71f5f5c295e9684caf16bbde Mon Sep 17 00:00:00 2001 From: Shuowei Li Date: Fri, 13 Dec 2024 19:02:20 +0000 Subject: [PATCH 10/12] merge gemini_fine_tune_endpoints and gemini_score_endpoints together, since they are identical --- bigframes/ml/llm.py | 11 +++-------- 1 file changed, 3 insertions(+), 8 deletions(-) diff --git a/bigframes/ml/llm.py b/bigframes/ml/llm.py index f1631b8333..9b7228fe83 100644 --- a/bigframes/ml/llm.py +++ b/bigframes/ml/llm.py @@ -79,12 +79,7 @@ _GEMINI_1P5_PRO_FLASH_PREVIEW_ENDPOINT, _GEMINI_2_FLASH_EXP_ENDPOINT, ) -_GEMINI_FINE_TUNE_ENDPOINTS = ( - _GEMINI_PRO_ENDPOINT, - _GEMINI_1P5_PRO_002_ENDPOINT, - _GEMINI_1P5_FLASH_002_ENDPOINT, -) -_GEMINI_SCORE_ENDPOINTS = ( +_GEMINI_FINE_TUNE_SCORE_ENDPOINTS = ( _GEMINI_PRO_ENDPOINT, _GEMINI_1P5_PRO_002_ENDPOINT, _GEMINI_1P5_FLASH_002_ENDPOINT, @@ -919,7 +914,7 @@ def fit( Returns: GeminiTextGenerator: Fitted estimator. """ - if self.model_name not in _GEMINI_FINE_TUNE_ENDPOINTS: + if self.model_name not in _GEMINI_FINE_TUNE_SCORE_ENDPOINTS: raise NotImplementedError( "fit() only supports gemini-pro, \ gemini-1.5-pro-002, or gemini-1.5-flash-002 model." @@ -1073,7 +1068,7 @@ def score( if not self._bqml_model: raise RuntimeError("A model must be fitted before score") - if self.model_name not in _GEMINI_SCORE_ENDPOINTS: + if self.model_name not in _GEMINI_FINE_TUNE_SCORE_ENDPOINTS: raise NotImplementedError( "score() only supports gemini-pro \ , gemini-1.5-pro-002, and gemini-1.5-flash-2 model." From 205e1738d6d21edfd9ad8f753fd659c248b54475 Mon Sep 17 00:00:00 2001 From: Shuowei Li Date: Fri, 13 Dec 2024 19:24:42 +0000 Subject: [PATCH 11/12] merge genimi_fine_tune_endpoints and genimi_score_endpoints, since they are identical --- bigframes/ml/llm.py | 19 ++++++++++--------- 1 file changed, 10 insertions(+), 9 deletions(-) diff --git a/bigframes/ml/llm.py b/bigframes/ml/llm.py index 9b7228fe83..e27461c266 100644 --- a/bigframes/ml/llm.py +++ b/bigframes/ml/llm.py @@ -895,8 +895,7 @@ def fit( X: utils.ArrayType, y: utils.ArrayType, ) -> GeminiTextGenerator: - """Fine tune GeminiTextGenerator model. Only support "gemini-pro", "gemini-1.5-pro-002", - "gemini-1.5-flash-002" models for now. + """Fine tune GeminiTextGenerator model. Only support "gemini-pro" and "gemini-1.5" model for now. .. note:: @@ -914,10 +913,11 @@ def fit( Returns: GeminiTextGenerator: Fitted estimator. """ - if self.model_name not in _GEMINI_FINE_TUNE_SCORE_ENDPOINTS: + # Support gemini-1.5 and gemini-pro + supported_models = ["gemini-pro", "gemini-1.5-pro-002", "gemini-1.5-flash-002"] + if self.model_name not in supported_models: raise NotImplementedError( - "fit() only supports gemini-pro, \ - gemini-1.5-pro-002, or gemini-1.5-flash-002 model." + "Score is not supported models other than gemini-pro or gemini-1.5 model." ) X, y = utils.batch_convert_to_dataframe(X, y) @@ -1036,7 +1036,7 @@ def score( "text_generation", "classification", "summarization", "question_answering" ] = "text_generation", ) -> bpd.DataFrame: - """Calculate evaluation metrics of the model. Only support "gemini-pro" and "gemini-1.5-pro-002", and "gemini-1.5-flash-002". + """Calculate evaluation metrics of the model. Only "gemini-pro" and "gemini-1.5" model is supported for now. .. note:: @@ -1068,10 +1068,11 @@ def score( if not self._bqml_model: raise RuntimeError("A model must be fitted before score") - if self.model_name not in _GEMINI_FINE_TUNE_SCORE_ENDPOINTS: + # Support gemini-1.5 and gemini-pro + supported_models = ["gemini-pro", "gemini-1.5-pro-002", "gemini-1.5-flash-002"] + if self.model_name not in supported_models: raise NotImplementedError( - "score() only supports gemini-pro \ - , gemini-1.5-pro-002, and gemini-1.5-flash-2 model." + "Score is not supported models other than gemini-pro or gemini-1.5 model." ) X, y = utils.batch_convert_to_dataframe(X, y, session=self._bqml_model.session) From 6a44e7b6c7cbd7542be12865019dafab105c1f1e Mon Sep 17 00:00:00 2001 From: Shuowei Li Date: Fri, 13 Dec 2024 19:28:02 +0000 Subject: [PATCH 12/12] Revert "merge genimi_fine_tune_endpoints and genimi_score_endpoints, since they are identical" This reverts commit 205e1738d6d21edfd9ad8f753fd659c248b54475. --- bigframes/ml/llm.py | 19 +++++++++---------- 1 file changed, 9 insertions(+), 10 deletions(-) diff --git a/bigframes/ml/llm.py b/bigframes/ml/llm.py index e27461c266..9b7228fe83 100644 --- a/bigframes/ml/llm.py +++ b/bigframes/ml/llm.py @@ -895,7 +895,8 @@ def fit( X: utils.ArrayType, y: utils.ArrayType, ) -> GeminiTextGenerator: - """Fine tune GeminiTextGenerator model. Only support "gemini-pro" and "gemini-1.5" model for now. + """Fine tune GeminiTextGenerator model. Only support "gemini-pro", "gemini-1.5-pro-002", + "gemini-1.5-flash-002" models for now. .. note:: @@ -913,11 +914,10 @@ def fit( Returns: GeminiTextGenerator: Fitted estimator. """ - # Support gemini-1.5 and gemini-pro - supported_models = ["gemini-pro", "gemini-1.5-pro-002", "gemini-1.5-flash-002"] - if self.model_name not in supported_models: + if self.model_name not in _GEMINI_FINE_TUNE_SCORE_ENDPOINTS: raise NotImplementedError( - "Score is not supported models other than gemini-pro or gemini-1.5 model." + "fit() only supports gemini-pro, \ + gemini-1.5-pro-002, or gemini-1.5-flash-002 model." ) X, y = utils.batch_convert_to_dataframe(X, y) @@ -1036,7 +1036,7 @@ def score( "text_generation", "classification", "summarization", "question_answering" ] = "text_generation", ) -> bpd.DataFrame: - """Calculate evaluation metrics of the model. Only "gemini-pro" and "gemini-1.5" model is supported for now. + """Calculate evaluation metrics of the model. Only support "gemini-pro" and "gemini-1.5-pro-002", and "gemini-1.5-flash-002". .. note:: @@ -1068,11 +1068,10 @@ def score( if not self._bqml_model: raise RuntimeError("A model must be fitted before score") - # Support gemini-1.5 and gemini-pro - supported_models = ["gemini-pro", "gemini-1.5-pro-002", "gemini-1.5-flash-002"] - if self.model_name not in supported_models: + if self.model_name not in _GEMINI_FINE_TUNE_SCORE_ENDPOINTS: raise NotImplementedError( - "Score is not supported models other than gemini-pro or gemini-1.5 model." + "score() only supports gemini-pro \ + , gemini-1.5-pro-002, and gemini-1.5-flash-2 model." ) X, y = utils.batch_convert_to_dataframe(X, y, session=self._bqml_model.session)