Skip to content

Navigation Menu

Sign in
Appearance settings

Search code, repositories, users, issues, pull requests...

Provide feedback

We read every piece of feedback, and take your input very seriously.

Saved searches

Use saved searches to filter your results more quickly

Appearance settings

Commit 583de9f

Browse filesBrowse files
authored
feat(generative_ai): add model evaluation sample (GoogleCloudPlatform#11122)
* feat(generative_ai): add model evaluation sample * set credentials required by pipeline components
1 parent 93429e0 commit 583de9f
Copy full SHA for 583de9f

File tree

Expand file treeCollapse file tree

2 files changed

+93
-0
lines changed
Filter options
Expand file treeCollapse file tree

2 files changed

+93
-0
lines changed

‎generative_ai/evaluate_model.py

Copy file name to clipboard
+59Lines changed: 59 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,59 @@
1+
# Copyright 2024 Google LLC
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# https://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
# [START aiplatform_evaluate_model]
16+
17+
from google.auth import default
18+
import vertexai
19+
from vertexai.preview.language_models import (
20+
EvaluationTextClassificationSpec,
21+
TextGenerationModel,
22+
)
23+
24+
# Set credentials for the pipeline components used in the evaluation task
25+
credentials, _ = default(scopes=["https://www.googleapis.com/auth/cloud-platform"])
26+
27+
28+
def evaluate_model(
29+
project_id: str,
30+
location: str,
31+
) -> object:
32+
"""Evaluate the performance of a generative AI model."""
33+
34+
vertexai.init(project=project_id, location=location, credentials=credentials)
35+
36+
# Create a reference to a generative AI model
37+
model = TextGenerationModel.from_pretrained("text-bison@001")
38+
39+
# Define the evaluation specification for a text classification task
40+
task_spec = EvaluationTextClassificationSpec(
41+
ground_truth_data=[
42+
"gs://cloud-samples-data/ai-platform/generative_ai/llm_classification_bp_input_prompts_with_ground_truth.jsonl"
43+
],
44+
class_names=["nature", "news", "sports", "health", "startups"],
45+
target_column_name="ground_truth",
46+
)
47+
48+
# Evaluate the model
49+
eval_metrics = model.evaluate(task_spec=task_spec)
50+
print(eval_metrics)
51+
52+
return eval_metrics
53+
54+
55+
# [END aiplatform_evaluate_model]
56+
57+
58+
if __name__ == "__main__":
59+
evaluate_model()

‎generative_ai/evaluate_model_test.py

Copy file name to clipboard
+34Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,34 @@
1+
# Copyright 2024 Google LLC
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# https://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
import os
16+
17+
import backoff
18+
from google.api_core.exceptions import ResourceExhausted
19+
20+
import evaluate_model
21+
22+
23+
_PROJECT_ID = os.getenv("GOOGLE_CLOUD_PROJECT")
24+
_LOCATION = "us-central1"
25+
26+
27+
@backoff.on_exception(backoff.expo, ResourceExhausted, max_time=10)
28+
def test_evaluate_model() -> None:
29+
eval_metrics = evaluate_model.evaluate_model(
30+
_PROJECT_ID,
31+
_LOCATION,
32+
)
33+
34+
assert hasattr(eval_metrics, "auRoc")

0 commit comments

Comments
0 (0)
Morty Proxy This is a proxified and sanitized view of the page, visit original site.