diff --git a/tests/system/small/ml/test_core.py b/tests/system/small/ml/test_core.py
index 95719ea0db..6b852e87af 100644
--- a/tests/system/small/ml/test_core.py
+++ b/tests/system/small/ml/test_core.py
@@ -14,7 +14,6 @@
 
 from datetime import datetime
 import typing
-from unittest import TestCase
 
 import pandas as pd
 import pyarrow as pa
@@ -24,7 +23,7 @@
 import bigframes
 import bigframes.features
 from bigframes.ml import core
-import tests.system.utils
+from tests.system import utils
 
 
 def test_model_eval(
@@ -212,7 +211,7 @@ def test_pca_model_principal_components(penguins_bqml_pca_model: core.BqmlModel)
         .reset_index(drop=True)
     )
 
-    tests.system.utils.assert_pandas_df_equal_pca_components(
+    utils.assert_pandas_df_equal_pca_components(
         result,
         expected,
         check_exact=False,
@@ -234,7 +233,7 @@ def test_pca_model_principal_component_info(penguins_bqml_pca_model: core.BqmlMo
             "cumulative_explained_variance_ratio": [0.469357, 0.651283, 0.812383],
         },
     )
-    tests.system.utils.assert_pandas_df_equal(
+    utils.assert_pandas_df_equal(
         result,
         expected,
         check_exact=False,
@@ -349,18 +348,9 @@ def test_model_generate_text(
         llm_text_df, options=options
     ).to_pandas()
 
-    TestCase().assertSequenceEqual(df.shape, (3, 4))
-    TestCase().assertSequenceEqual(
-        [
-            "ml_generate_text_llm_result",
-            "ml_generate_text_rai_result",
-            "ml_generate_text_status",
-            "prompt",
-        ],
-        df.columns.to_list(),
+    utils.check_pandas_df_schema_and_index(
+        df, columns=utils.ML_GENERATE_TEXT_OUTPUT, index=3, col_exact=False
     )
-    series = df["ml_generate_text_llm_result"]
-    assert all(series.str.len() > 20)
 
 
 def test_model_forecast(time_series_bqml_arima_plus_model: core.BqmlModel):
diff --git a/tests/system/small/ml/test_llm.py b/tests/system/small/ml/test_llm.py
index 2f1a16f23c..43e756019d 100644
--- a/tests/system/small/ml/test_llm.py
+++ b/tests/system/small/ml/test_llm.py
@@ -74,10 +74,9 @@ def test_create_text_generator_model_default_session(
     llm_text_df = bpd.read_pandas(llm_text_pandas_df)
 
     df = model.predict(llm_text_df).to_pandas()
-    assert df.shape == (3, 4)
-    assert "ml_generate_text_llm_result" in df.columns
-    series = df["ml_generate_text_llm_result"]
-    assert all(series.str.len() > 20)
+    utils.check_pandas_df_schema_and_index(
+        df, columns=utils.ML_GENERATE_TEXT_OUTPUT, index=3, col_exact=False
+    )
 
 
 @pytest.mark.flaky(retries=2)
@@ -104,10 +103,9 @@ def test_create_text_generator_32k_model_default_session(
     llm_text_df = bpd.read_pandas(llm_text_pandas_df)
 
     df = model.predict(llm_text_df).to_pandas()
-    assert df.shape == (3, 4)
-    assert "ml_generate_text_llm_result" in df.columns
-    series = df["ml_generate_text_llm_result"]
-    assert all(series.str.len() > 20)
+    utils.check_pandas_df_schema_and_index(
+        df, columns=utils.ML_GENERATE_TEXT_OUTPUT, index=3, col_exact=False
+    )
 
 
 @pytest.mark.flaky(retries=2)
@@ -131,10 +129,9 @@ def test_create_text_generator_model_default_connection(
     )
 
     df = model.predict(llm_text_df).to_pandas()
-    assert df.shape == (3, 4)
-    assert "ml_generate_text_llm_result" in df.columns
-    series = df["ml_generate_text_llm_result"]
-    assert all(series.str.len() > 20)
+    utils.check_pandas_df_schema_and_index(
+        df, columns=utils.ML_GENERATE_TEXT_OUTPUT, index=3, col_exact=False
+    )
 
 
 # Marked as flaky only because BQML LLM is in preview, the service only has limited capacity, not stable enough.
@@ -143,10 +140,9 @@ def test_text_generator_predict_default_params_success(
     palm2_text_generator_model, llm_text_df
 ):
     df = palm2_text_generator_model.predict(llm_text_df).to_pandas()
-    assert df.shape == (3, 4)
-    assert "ml_generate_text_llm_result" in df.columns
-    series = df["ml_generate_text_llm_result"]
-    assert all(series.str.len() > 20)
+    utils.check_pandas_df_schema_and_index(
+        df, columns=utils.ML_GENERATE_TEXT_OUTPUT, index=3, col_exact=False
+    )
 
 
 @pytest.mark.flaky(retries=2)
@@ -154,10 +150,9 @@ def test_text_generator_predict_series_default_params_success(
     palm2_text_generator_model, llm_text_df
 ):
     df = palm2_text_generator_model.predict(llm_text_df["prompt"]).to_pandas()
-    assert df.shape == (3, 4)
-    assert "ml_generate_text_llm_result" in df.columns
-    series = df["ml_generate_text_llm_result"]
-    assert all(series.str.len() > 20)
+    utils.check_pandas_df_schema_and_index(
+        df, columns=utils.ML_GENERATE_TEXT_OUTPUT, index=3, col_exact=False
+    )
 
 
 @pytest.mark.flaky(retries=2)
@@ -166,10 +161,9 @@ def test_text_generator_predict_arbitrary_col_label_success(
 ):
     llm_text_df = llm_text_df.rename(columns={"prompt": "arbitrary"})
     df = palm2_text_generator_model.predict(llm_text_df).to_pandas()
-    assert df.shape == (3, 4)
-    assert "ml_generate_text_llm_result" in df.columns
-    series = df["ml_generate_text_llm_result"]
-    assert all(series.str.len() > 20)
+    utils.check_pandas_df_schema_and_index(
+        df, columns=utils.ML_GENERATE_TEXT_OUTPUT, index=3, col_exact=False
+    )
 
 
 @pytest.mark.flaky(retries=2)
@@ -179,10 +173,9 @@ def test_text_generator_predict_with_params_success(
     df = palm2_text_generator_model.predict(
         llm_text_df, temperature=0.5, max_output_tokens=100, top_k=20, top_p=0.5
     ).to_pandas()
-    assert df.shape == (3, 4)
-    assert "ml_generate_text_llm_result" in df.columns
-    series = df["ml_generate_text_llm_result"]
-    assert all(series.str.len() > 20)
+    utils.check_pandas_df_schema_and_index(
+        df, columns=utils.ML_GENERATE_TEXT_OUTPUT, index=3, col_exact=False
+    )
 
 
 def test_create_embedding_generator_model(
@@ -379,10 +372,9 @@ def test_gemini_text_generator_predict_default_params_success(
         model_name=model_name, connection_name=bq_connection, session=session
     )
     df = gemini_text_generator_model.predict(llm_text_df).to_pandas()
-    assert df.shape == (3, 4)
-    assert "ml_generate_text_llm_result" in df.columns
-    series = df["ml_generate_text_llm_result"]
-    assert all(series.str.len() > 20)
+    utils.check_pandas_df_schema_and_index(
+        df, columns=utils.ML_GENERATE_TEXT_OUTPUT, index=3, col_exact=False
+    )
 
 
 @pytest.mark.parametrize(
@@ -399,10 +391,9 @@ def test_gemini_text_generator_predict_with_params_success(
     df = gemini_text_generator_model.predict(
         llm_text_df, temperature=0.5, max_output_tokens=100, top_k=20, top_p=0.5
    ).to_pandas()
-    assert df.shape == (3, 4)
-    assert "ml_generate_text_llm_result" in df.columns
-    series = df["ml_generate_text_llm_result"]
-    assert all(series.str.len() > 20)
+    utils.check_pandas_df_schema_and_index(
+        df, columns=utils.ML_GENERATE_TEXT_OUTPUT, index=3, col_exact=False
+    )
 
 
 @pytest.mark.parametrize(
@@ -444,10 +435,9 @@ def test_claude3_text_generator_predict_default_params_success(
         model_name=model_name, connection_name=bq_connection, session=session
     )
     df = claude3_text_generator_model.predict(llm_text_df).to_pandas()
-    assert df.shape == (3, 3)
-    assert "ml_generate_text_llm_result" in df.columns
-    series = df["ml_generate_text_llm_result"]
-    assert all(series.str.len() > 20)
+    utils.check_pandas_df_schema_and_index(
+        df, columns=utils.ML_GENERATE_TEXT_OUTPUT, index=3, col_exact=False
+    )
 
 
 @pytest.mark.parametrize(
@@ -466,10 +456,9 @@ def test_claude3_text_generator_predict_with_params_success(
     df = claude3_text_generator_model.predict(
         llm_text_df, max_output_tokens=100, top_k=20, top_p=0.5
     ).to_pandas()
-    assert df.shape == (3, 3)
-    assert "ml_generate_text_llm_result" in df.columns
-    series = df["ml_generate_text_llm_result"]
-    assert all(series.str.len() > 20)
+    utils.check_pandas_df_schema_and_index(
+        df, columns=utils.ML_GENERATE_TEXT_OUTPUT, index=3, col_exact=False
+    )
 
 
 @pytest.mark.flaky(retries=2)
diff --git a/tests/system/utils.py b/tests/system/utils.py
index e9054d04c9..26e3e97e24 100644
--- a/tests/system/utils.py
+++ b/tests/system/utils.py
@@ -45,6 +45,11 @@
     "log_loss",
     "roc_auc",
 ]
+ML_GENERATE_TEXT_OUTPUT = [
+    "ml_generate_text_llm_result",
+    "ml_generate_text_status",
+    "prompt",
+]
 
 
 def skip_legacy_pandas(test):
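
Note: `utils.check_pandas_df_schema_and_index` is called throughout this patch but its implementation is not part of the diff. A minimal sketch of what such a helper could look like, inferred purely from the call sites above (`columns=` as the expected column list, `index=3` as an expected row count, `col_exact=False` meaning extra columns are tolerated), might be:

    import pandas as pd


    def check_pandas_df_schema_and_index(
        pd_df: pd.DataFrame,
        columns: list,
        index: int,
        col_exact: bool = True,
    ) -> None:
        # Assumed sketch; the real helper lives in tests/system/utils.py and
        # may differ (e.g. it could also accept explicit index values).
        if col_exact:
            # Exact mode: column names and their order must match precisely.
            assert list(pd_df.columns) == columns
        else:
            # Loose mode (col_exact=False): every expected column must be
            # present, but additional columns are allowed.
            assert set(columns) <= set(pd_df.columns)
        # The tests above pass index=3, i.e. they expect three result rows.
        assert len(pd_df) == index

Centralizing the schema check this way also explains the `ML_GENERATE_TEXT_OUTPUT` constant added to tests/system/utils.py: the expected LLM output columns are declared once instead of being repeated in every test.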