Skip to content

Navigation Menu

Sign in
Appearance settings

Search code, repositories, users, issues, pull requests...

Provide feedback

We read every piece of feedback, and take your input very seriously.

Saved searches

Use saved searches to filter your results more quickly

Appearance settings

feat: Add fine tuning fit() for Palm2TextGenerator #616

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We鈥檒l occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 5 commits into from
Apr 18, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
40 changes: 40 additions & 0 deletions 40 bigframes/ml/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -321,6 +321,46 @@ def create_model(

return self._create_model_with_sql(session=session, sql=sql)

def create_llm_remote_model(
self,
X_train: bpd.DataFrame,
y_train: bpd.DataFrame,
connection_name: str,
options: Mapping[str, Union[str, int, float, Iterable[str]]] = {},
) -> BqmlModel:
"""Create a session-temporary BQML model with the CREATE OR REPLACE MODEL statement

Args:
X_train: features columns for training
y_train: labels columns for training
options: a dict of options to configure the model. Generates a BQML OPTIONS
clause
connection_name:
a BQ connection to talk with Vertex AI, of the format <PROJECT_NUMBER>.<REGION>.<CONNECTION_NAME>. https://cloud.google.com/bigquery/docs/create-cloud-resource-connection

Returns: a BqmlModel, wrapping a trained model in BigQuery
"""
options = dict(options)
# Cache dataframes to make sure base table is not a snapshot
# cached dataframe creates a full copy, never uses snapshot
input_data = X_train._cached(force=True).join(
y_train._cached(force=True), how="outer"
)
options.update({"INPUT_LABEL_COLS": y_train.columns.tolist()})

session = X_train._session

model_ref = self._create_model_ref(session._anonymous_dataset)

sql = self._model_creation_sql_generator.create_llm_remote_model(
source_df=input_data,
model_ref=model_ref,
options=options,
connection_name=connection_name,
)

return self._create_model_with_sql(session=session, sql=sql)

def create_time_series_model(
self,
X_train: bpd.DataFrame,
Expand Down
71 changes: 70 additions & 1 deletion 71 bigframes/ml/llm.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,10 @@
from bigframes.ml import base, core, globals, utils
import bigframes.pandas as bpd

_BQML_PARAMS_MAPPING = {
"max_iterations": "maxIterations",
}

_TEXT_GENERATOR_BISON_ENDPOINT = "text-bison"
_TEXT_GENERATOR_BISON_32K_ENDPOINT = "text-bison-32k"
_TEXT_GENERATOR_ENDPOINTS = (
Expand Down Expand Up @@ -62,6 +66,8 @@ class PaLM2TextGenerator(base.BaseEstimator):
Connection to connect with remote service. str of the format <PROJECT_NUMBER/PROJECT_ID>.<LOCATION>.<CONNECTION_ID>.
if None, use default connection in session context. BigQuery DataFrame will try to create the connection and attach
permission if the connection isn't fully setup.
max_iterations (Optional[int], Default to 300):
The number of steps to run when performing supervised tuning.
"""

def __init__(
Expand All @@ -70,9 +76,11 @@ def __init__(
model_name: Literal["text-bison", "text-bison-32k"] = "text-bison",
session: Optional[bigframes.Session] = None,
connection_name: Optional[str] = None,
max_iterations: int = 300,
):
self.model_name = model_name
self.session = session or bpd.get_global_session()
self.max_iterations = max_iterations
self._bq_connection_manager = self.session.bqconnectionmanager

connection_name = connection_name or self.session._bq_connection
Expand Down Expand Up @@ -132,12 +140,73 @@ def _from_bq(
model_connection = model._properties["remoteModelInfo"]["connection"]
model_endpoint = bqml_endpoint.split("/")[-1]

# Get the optional params
kwargs: dict = {}
last_fitting = model.training_runs[-1]["trainingOptions"]

dummy_text_generator = cls()
for bf_param, _ in dummy_text_generator.__dict__.items():
bqml_param = _BQML_PARAMS_MAPPING.get(bf_param)
if bqml_param in last_fitting:
# Convert types
if bf_param in ["max_iterations"]:
kwargs[bf_param] = int(last_fitting[bqml_param])

text_generator_model = cls(
session=session, model_name=model_endpoint, connection_name=model_connection
**kwargs,
session=session,
model_name=model_endpoint,
connection_name=model_connection,
)
text_generator_model._bqml_model = core.BqmlModel(session, model)
return text_generator_model

@property
def _bqml_options(self) -> dict:
"""The model options as they will be set for BQML"""
options = {
"max_iterations": self.max_iterations,
"data_split_method": "NO_SPLIT",
}
return options

def fit(
self,
X: Union[bpd.DataFrame, bpd.Series],
y: Union[bpd.DataFrame, bpd.Series],
) -> PaLM2TextGenerator:
"""Fine tune PaLM2TextGenerator model.

.. note::

This product or feature is subject to the "Pre-GA Offerings Terms" in the General Service Terms section of the
Service Specific Terms(https://cloud.google.com/terms/service-terms#1). Pre-GA products and features are available "as is"
and might have limited support. For more information, see the launch stage descriptions
(https://cloud.google.com/products#product-launch-stages).

Args:
X (bigframes.dataframe.DataFrame or bigframes.series.Series):
DataFrame of shape (n_samples, n_features). Training data.
y (bigframes.dataframe.DataFrame or bigframes.series.Series:
Training labels.

Returns:
PaLM2TextGenerator: Fitted Estimator.
"""
X, y = utils.convert_to_dataframe(X, y)

options = self._bqml_options
options["endpoint"] = self.model_name + "@001"
ashleyxuu marked this conversation as resolved.
Show resolved Hide resolved
options["prompt_col"] = X.columns.tolist()[0]

self._bqml_model = self._bqml_model_factory.create_llm_remote_model(
X,
y,
options=options,
connection_name=self.connection_name,
)
return self

def predict(
self,
X: Union[bpd.DataFrame, bpd.Series],
Expand Down
17 changes: 17 additions & 0 deletions 17 bigframes/ml/sql.py
Original file line number Diff line number Diff line change
Expand Up @@ -177,6 +177,23 @@ def create_model(
parts.append(f"AS {source_sql}")
return "\n".join(parts)

def create_llm_remote_model(
ashleyxuu marked this conversation as resolved.
Show resolved Hide resolved
self,
source_df: bpd.DataFrame,
connection_name: str,
model_ref: google.cloud.bigquery.ModelReference,
options: Mapping[str, Union[str, int, float, Iterable[str]]] = {},
) -> str:
"""Encode the CREATE OR REPLACE MODEL statement for BQML"""
source_sql = source_df.sql

parts = [f"CREATE OR REPLACE MODEL {self._model_id_sql(model_ref)}"]
parts.append(self.connection(connection_name))
if options:
parts.append(self.options(**options))
parts.append(f"AS {source_sql}")
return "\n".join(parts)

def create_remote_model(
self,
connection_name: str,
Expand Down
68 changes: 68 additions & 0 deletions 68 tests/system/load/test_llm.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,68 @@
# Copyright 2023 Google LLC
ashleyxuu marked this conversation as resolved.
Show resolved Hide resolved
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import pandas as pd
import pytest

import bigframes.ml.llm


@pytest.fixture(scope="session")
def llm_fine_tune_df_default_index(
session: bigframes.Session,
) -> bigframes.dataframe.DataFrame:
sql = """
SELECT
CONCAT("Please do sentiment analysis on the following text and only output a number from 0 to 5 where 0 means sadness, 1 means joy, 2 means love, 3 means anger, 4 means fear, and 5 means surprise. Text: ", text) as prompt,
CAST(label AS STRING) as label
FROM `llm_tuning.emotion_classification_train`
"""
return session.read_gbq(sql)


@pytest.fixture(scope="session")
def llm_remote_text_pandas_df():
"""Additional data matching the penguins dataset, with a new index"""
return pd.DataFrame(
{
"prompt": [
"Please do sentiment analysis on the following text and only output a number from 0 to 5where 0 means sadness, 1 means joy, 2 means love, 3 means anger, 4 means fear, and 5 means surprise. Text: i feel beautifully emotional knowing that these women of whom i knew just a handful were holding me and my baba on our journey",
"Please do sentiment analysis on the following text and only output a number from 0 to 5 where 0 means sadness, 1 means joy, 2 means love, 3 means anger, 4 means fear, and 5 means surprise. Text: i was feeling a little vain when i did this one",
"Please do sentiment analysis on the following text and only output a number from 0 to 5 where 0 means sadness, 1 means joy, 2 means love, 3 means anger, 4 means fear, and 5 means surprise. Text: a father of children killed in an accident",
],
}
)


def test_llm_palm_configure_fit(
llm_fine_tune_df_default_index, llm_remote_text_pandas_df
):
model = bigframes.ml.llm.PaLM2TextGenerator(
model_name="text-bison", max_iterations=1
)

df = llm_fine_tune_df_default_index.dropna()
X_train = df[["prompt"]]
y_train = df[["label"]]
model.fit(X_train, y_train)

assert model is not None

df = model.predict(llm_remote_text_pandas_df).to_pandas()
assert df.shape == (3, 4)
assert "ml_generate_text_llm_result" in df.columns
series = df["ml_generate_text_llm_result"]
assert all(series.str.len() == 1)

# TODO(ashleyxu b/335492787): After bqml rolled out version control: save, load, check parameters to ensure configuration was kept
2 changes: 1 addition & 1 deletion 2 tests/system/small/ml/test_llm.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
# Copyright 2023 Google LLC
# Copyright 2024 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
Expand Down
23 changes: 23 additions & 0 deletions 23 tests/unit/ml/test_sql.py
Original file line number Diff line number Diff line change
Expand Up @@ -181,6 +181,29 @@ def test_create_model_transform_correct(
)


def test_create_llm_remote_model_correct(
model_creation_sql_generator: ml_sql.ModelCreationSqlGenerator,
mock_df: bpd.DataFrame,
):
sql = model_creation_sql_generator.create_llm_remote_model(
source_df=mock_df,
connection_name="my_project.us.my_connection",
model_ref=bigquery.ModelReference.from_string(
"test-proj._anonXYZ.create_remote_model"
),
options={"option_key1": "option_value1", "option_key2": 2},
)
assert (
sql
== """CREATE OR REPLACE MODEL `test-proj`.`_anonXYZ`.`create_remote_model`
REMOTE WITH CONNECTION `my_project.us.my_connection`
OPTIONS(
option_key1="option_value1",
option_key2=2)
AS input_X_y_sql"""
)


def test_create_remote_model_correct(
model_creation_sql_generator: ml_sql.ModelCreationSqlGenerator,
):
Expand Down
Morty Proxy This is a proxified and sanitized view of the page, visit original site.