From 5b45e6bd80c0bdfb565184d05bc69b37147b540f Mon Sep 17 00:00:00 2001 From: Ashley Xu Date: Mon, 29 Apr 2024 18:21:32 +0000 Subject: [PATCH 1/2] fix: llm palm score tests --- tests/system/load/test_llm.py | 8 +++----- 1 file changed, 3 insertions(+), 5 deletions(-) diff --git a/tests/system/load/test_llm.py b/tests/system/load/test_llm.py index 835b31955e..0bfbb2741c 100644 --- a/tests/system/load/test_llm.py +++ b/tests/system/load/test_llm.py @@ -54,7 +54,7 @@ def test_llm_palm_configure_fit(llm_fine_tune_df_default_index, llm_remote_text_ model_name="text-bison", max_iterations=1 ) - df = llm_fine_tune_df_default_index.dropna() + df = llm_fine_tune_df_default_index.dropna().sample(n=100) X_train = df[["prompt"]] y_train = df[["label"]] model.fit(X_train, y_train) @@ -102,12 +102,10 @@ def test_llm_palm_score_params(llm_fine_tune_df_default_index): ).to_pandas() score_result_col = score_result.columns.to_list() expected_col = [ - "trial_id", "precision", "recall", - "accuracy", "f1_score", - "log_loss", - "roc_auc", + "label", + "evaluation_status", ] assert all(col in score_result_col for col in expected_col) From 4c808485df4633e0b3d83fe88b6c8adb66a55210 Mon Sep 17 00:00:00 2001 From: Ashley Xu Date: Mon, 29 Apr 2024 18:47:56 +0000 Subject: [PATCH 2/2] address comments --- tests/system/load/test_llm.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/tests/system/load/test_llm.py b/tests/system/load/test_llm.py index 0bfbb2741c..fd13662275 100644 --- a/tests/system/load/test_llm.py +++ b/tests/system/load/test_llm.py @@ -49,6 +49,7 @@ def llm_remote_text_df(session, llm_remote_text_pandas_df): return session.read_pandas(llm_remote_text_pandas_df) +@pytest.mark.flaky(retries=2) def test_llm_palm_configure_fit(llm_fine_tune_df_default_index, llm_remote_text_df): model = bigframes.ml.llm.PaLM2TextGenerator( model_name="text-bison", max_iterations=1 @@ -70,6 +71,7 @@ def test_llm_palm_configure_fit(llm_fine_tune_df_default_index, llm_remote_text_ # TODO(ashleyxu b/335492787): After bqml rolled out version control: save, load, check parameters to ensure configuration was kept +@pytest.mark.flaky(retries=2) def test_llm_palm_score(llm_fine_tune_df_default_index): model = bigframes.ml.llm.PaLM2TextGenerator(model_name="text-bison") @@ -89,6 +91,7 @@ def test_llm_palm_score(llm_fine_tune_df_default_index): assert all(col in score_result_col for col in expected_col) +@pytest.mark.flaky(retries=2) def test_llm_palm_score_params(llm_fine_tune_df_default_index): model = bigframes.ml.llm.PaLM2TextGenerator( model_name="text-bison", max_iterations=1