Skip to content

Navigation Menu

Sign in
Appearance settings

Search code, repositories, users, issues, pull requests...

Provide feedback

We read every piece of feedback, and take your input very seriously.

Saved searches

Use saved searches to filter your results more quickly

Appearance settings

Commit 8b3beb6

Browse filesBrowse files
yinghsienwu copybara-github
authored and committed
feat: Add support for user-configurable 1P embedding models and quota for RAG
PiperOrigin-RevId: 642414350
1 parent cf8bc3d commit 8b3beb6
Copy full SHA for 8b3beb6

6 files changed

+223-3Lines changed: 223 additions & 3 deletions

File tree

Expand file treeCollapse file tree
Open diff view settings
Filter options
Expand file treeCollapse file tree
Open diff view settings
Collapse file

‎tests/unit/vertex_rag/test_rag_constants.py‎

Copy file name to clipboardExpand all lines: tests/unit/vertex_rag/test_rag_constants.py
+12Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616
#
1717

1818
from vertexai.preview.rag.utils.resources import (
19+
EmbeddingModelConfig,
1920
RagCorpus,
2021
RagFile,
2122
RagResource,
@@ -49,10 +50,19 @@
4950
display_name=TEST_CORPUS_DISPLAY_NAME,
5051
description=TEST_CORPUS_DISCRIPTION,
5152
)
53+
TEST_GAPIC_RAG_CORPUS.rag_embedding_model_config.vertex_prediction_endpoint.endpoint = (
54+
"projects/{}/locations/{}/publishers/google/models/textembedding-gecko".format(
55+
TEST_PROJECT, TEST_REGION
56+
)
57+
)
58+
TEST_EMBEDDING_MODEL_CONFIG = EmbeddingModelConfig(
59+
publisher_model="publishers/google/models/textembedding-gecko",
60+
)
5261
TEST_RAG_CORPUS = RagCorpus(
5362
name=TEST_RAG_CORPUS_RESOURCE_NAME,
5463
display_name=TEST_CORPUS_DISPLAY_NAME,
5564
description=TEST_CORPUS_DISCRIPTION,
65+
embedding_model_config=TEST_EMBEDDING_MODEL_CONFIG,
5666
)
5767
TEST_PAGE_TOKEN = "test-page-token"
5868

@@ -114,6 +124,8 @@
114124
chunk_overlap=TEST_CHUNK_OVERLAP,
115125
)
116126
)
127+
TEST_IMPORT_FILES_CONFIG_DRIVE_FILE.max_embedding_requests_per_min = 800
128+
117129
TEST_IMPORT_FILES_CONFIG_DRIVE_FILE.google_drive_source.resource_ids = [
118130
GoogleDriveSource.ResourceId(
119131
resource_id=TEST_DRIVE_FILE_ID,
Collapse file

‎tests/unit/vertex_rag/test_rag_data.py‎

Copy file name to clipboardExpand all lines: tests/unit/vertex_rag/test_rag_data.py
+45-1Lines changed: 45 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@
1919
from vertexai.preview import rag
2020
from vertexai.preview.rag.utils._gapic_utils import (
2121
prepare_import_files_request,
22+
set_embedding_model_config,
2223
)
2324
from google.cloud.aiplatform_v1beta1 import (
2425
VertexRagDataServiceAsyncClient,
@@ -171,7 +172,10 @@ def teardown_method(self):
171172

172173
@pytest.mark.usefixtures("create_rag_corpus_mock")
173174
def test_create_corpus_success(self):
174-
rag_corpus = rag.create_corpus(display_name=tc.TEST_CORPUS_DISPLAY_NAME)
175+
rag_corpus = rag.create_corpus(
176+
display_name=tc.TEST_CORPUS_DISPLAY_NAME,
177+
embedding_model_config=tc.TEST_EMBEDDING_MODEL_CONFIG,
178+
)
175179

176180
rag_corpus_eq(rag_corpus, tc.TEST_RAG_CORPUS)
177181

@@ -391,6 +395,7 @@ def test_prepare_import_files_request_drive_files(self):
391395
paths=paths,
392396
chunk_size=tc.TEST_CHUNK_SIZE,
393397
chunk_overlap=tc.TEST_CHUNK_OVERLAP,
398+
max_embedding_requests_per_min=800,
394399
)
395400
import_files_request_eq(request, tc.TEST_IMPORT_REQUEST_DRIVE_FILE)
396401

@@ -415,3 +420,42 @@ def test_prepare_import_files_request_invalid_path(self):
415420
chunk_overlap=tc.TEST_CHUNK_OVERLAP,
416421
)
417422
e.match("path must be a Google Cloud Storage uri or a Google Drive url")
423+
424+
def test_set_embedding_model_config_set_both_error(self):
    """A config with both publisher_model and endpoint set raises ValueError."""
    conflicting_config = rag.EmbeddingModelConfig(
        publisher_model="whatever",
        endpoint="whatever",
    )
    # The expected message is checked by pytest.raises itself.
    with pytest.raises(
        ValueError,
        match="publisher_model and endpoint cannot be set at the same time",
    ):
        set_embedding_model_config(conflicting_config, tc.TEST_GAPIC_RAG_CORPUS)
435+
436+
def test_set_embedding_model_config_not_set_error(self):
    """A config with neither publisher_model nor endpoint set raises ValueError."""
    empty_config = rag.EmbeddingModelConfig()
    # The expected message is checked by pytest.raises itself.
    with pytest.raises(
        ValueError,
        match="At least one of publisher_model and endpoint must be set",
    ):
        set_embedding_model_config(empty_config, tc.TEST_GAPIC_RAG_CORPUS)
444+
445+
def test_set_embedding_model_config_wrong_publisher_model_format_error(self):
    """A publisher_model that is not a recognised resource name raises ValueError."""
    malformed_config = rag.EmbeddingModelConfig(publisher_model="whatever")
    # The expected message prefix is checked by pytest.raises itself.
    with pytest.raises(ValueError, match="publisher_model must be of the format "):
        set_embedding_model_config(malformed_config, tc.TEST_GAPIC_RAG_CORPUS)
453+
454+
def test_set_embedding_model_config_wrong_endpoint_format_error(self):
    """An endpoint that is not a recognised resource name raises ValueError."""
    malformed_config = rag.EmbeddingModelConfig(endpoint="whatever")
    # The expected message prefix is checked by pytest.raises itself.
    with pytest.raises(ValueError, match="endpoint must be of the format "):
        set_embedding_model_config(malformed_config, tc.TEST_GAPIC_RAG_CORPUS)
Collapse file

‎vertexai/preview/rag/__init__.py‎

Copy file name to clipboardExpand all lines: vertexai/preview/rag/__init__.py
+2Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,7 @@
3737
VertexRagStore,
3838
)
3939
from vertexai.preview.rag.utils.resources import (
40+
EmbeddingModelConfig,
4041
RagResource,
4142
)
4243

@@ -53,6 +54,7 @@
5354
"list_files",
5455
"delete_file",
5556
"retrieval_query",
57+
"EmbeddingModelConfig",
5658
"Retrieval",
5759
"VertexRagStore",
5860
"RagResource",
Collapse file

‎vertexai/preview/rag/rag_data.py‎

Copy file name to clipboardExpand all lines: vertexai/preview/rag/rag_data.py
+33-1Lines changed: 33 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -43,13 +43,16 @@
4343
_gapic_utils,
4444
)
4545
from vertexai.preview.rag.utils.resources import (
46+
EmbeddingModelConfig,
4647
RagCorpus,
4748
RagFile,
4849
)
4950

5051

5152
def create_corpus(
52-
display_name: Optional[str] = None, description: Optional[str] = None
53+
display_name: Optional[str] = None,
54+
description: Optional[str] = None,
55+
embedding_model_config: Optional[EmbeddingModelConfig] = None,
5356
) -> RagCorpus:
5457
"""Creates a new RagCorpus resource.
5558
@@ -69,6 +72,7 @@ def create_corpus(
6972
the RagCorpus. The name can be up to 128 characters long and can
7073
consist of any UTF-8 characters.
7174
description: The description of the RagCorpus.
75+
embedding_model_config: The embedding model config.
7276
Returns:
7377
RagCorpus.
7478
Raises:
@@ -80,6 +84,12 @@ def create_corpus(
8084
parent = initializer.global_config.common_location_path(project=None, location=None)
8185

8286
rag_corpus = GapicRagCorpus(display_name=display_name, description=description)
87+
if embedding_model_config:
88+
rag_corpus = _gapic_utils.set_embedding_model_config(
89+
embedding_model_config,
90+
rag_corpus,
91+
)
92+
8393
request = CreateRagCorpusRequest(
8494
parent=parent,
8595
rag_corpus=rag_corpus,
@@ -264,6 +274,7 @@ def import_files(
264274
chunk_size: int = 1024,
265275
chunk_overlap: int = 200,
266276
timeout: int = 600,
277+
max_embedding_requests_per_min: int = 1000,
267278
) -> ImportRagFilesResponse:
268279
"""
269280
Import files to an existing RagCorpus, wait until completion.
@@ -299,6 +310,15 @@ def import_files(
299310
"https://drive.google.com/corp/drive/folders/...").
300311
chunk_size: The size of the chunks.
301312
chunk_overlap: The overlap between chunks.
313+
max_embedding_requests_per_min:
314+
Optional. The max number of queries per
315+
minute that this job is allowed to make to the
316+
embedding model specified on the corpus. This
317+
value is specific to this job and not shared
318+
across other import jobs. Consult the Quotas
319+
page on the project to set an appropriate value
320+
here. If unspecified, a default value of 1,000
321+
QPM would be used.
302322
timeout: Default is 600 seconds.
303323
Returns:
304324
ImportRagFilesResponse.
@@ -309,6 +329,7 @@ def import_files(
309329
paths=paths,
310330
chunk_size=chunk_size,
311331
chunk_overlap=chunk_overlap,
332+
max_embedding_requests_per_min=max_embedding_requests_per_min,
312333
)
313334
client = _gapic_utils.create_rag_data_service_client()
314335
try:
@@ -324,6 +345,7 @@ async def import_files_async(
324345
paths: Sequence[str],
325346
chunk_size: int = 1024,
326347
chunk_overlap: int = 200,
348+
max_embedding_requests_per_min: int = 1000,
327349
) -> operation_async.AsyncOperation:
328350
"""
329351
Import files to an existing RagCorpus asynchronously.
@@ -361,6 +383,15 @@ async def import_files_async(
361383
"https://drive.google.com/corp/drive/folders/...").
362384
chunk_size: The size of the chunks.
363385
chunk_overlap: The overlap between chunks.
386+
max_embedding_requests_per_min:
387+
Optional. The max number of queries per
388+
minute that this job is allowed to make to the
389+
embedding model specified on the corpus. This
390+
value is specific to this job and not shared
391+
across other import jobs. Consult the Quotas
392+
page on the project to set an appropriate value
393+
here. If unspecified, a default value of 1,000
394+
QPM would be used.
364395
Returns:
365396
operation_async.AsyncOperation.
366397
"""
@@ -370,6 +401,7 @@ async def import_files_async(
370401
paths=paths,
371402
chunk_size=chunk_size,
372403
chunk_overlap=chunk_overlap,
404+
max_embedding_requests_per_min=max_embedding_requests_per_min,
373405
)
374406
async_client = _gapic_utils.create_rag_data_service_async_client()
375407
try:
Collapse file

‎vertexai/preview/rag/utils/_gapic_utils.py‎

Copy file name to clipboardExpand all lines: vertexai/preview/rag/utils/_gapic_utils.py
+98-1Lines changed: 98 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717
import re
1818
from typing import Any, Dict, Sequence, Union
1919
from google.cloud.aiplatform_v1beta1 import (
20+
RagEmbeddingModelConfig,
2021
GoogleDriveSource,
2122
ImportRagFilesConfig,
2223
ImportRagFilesRequest,
@@ -31,6 +32,7 @@
3132
VertexRagClientWithOverride,
3233
)
3334
from vertexai.preview.rag.utils.resources import (
35+
EmbeddingModelConfig,
3436
RagCorpus,
3537
RagFile,
3638
)
@@ -57,12 +59,43 @@ def create_rag_service_client():
5759
)
5860

5961

62+
def convert_gapic_to_embedding_model_config(
    gapic_embedding_model_config: RagEmbeddingModelConfig,
) -> EmbeddingModelConfig:
    """Translate a GAPIC RagEmbeddingModelConfig into an EmbeddingModelConfig.

    The resource path stored on ``vertex_prediction_endpoint.endpoint`` is
    tested against two shapes:

    * ``projects/.../locations/.../publishers/google/models/...``: the path
      is copied to ``publisher_model``.
    * ``projects/.../locations/.../endpoints/...``: the path is copied to
      ``endpoint``, and ``model`` and ``model_version_id`` are copied from
      the same GAPIC sub-message.

    A path matching neither shape (an empty string included) leaves the
    returned config exactly as its constructor produced it.

    Args:
        gapic_embedding_model_config: The GAPIC embedding model config to read.

    Returns:
        The resulting EmbeddingModelConfig.
    """
    prediction_endpoint = gapic_embedding_model_config.vertex_prediction_endpoint
    resource_path = prediction_endpoint.endpoint
    converted = EmbeddingModelConfig()

    publisher_model_match = re.match(
        r"^projects/(?P<project>.+?)/locations/(?P<location>.+?)/publishers/google/models/(?P<model_id>.+?)$",
        resource_path,
    )
    endpoint_match = re.match(
        r"^projects/(?P<project>.+?)/locations/(?P<location>.+?)/endpoints/(?P<endpoint>.+?)$",
        resource_path,
    )

    if publisher_model_match is not None:
        converted.publisher_model = resource_path
    if endpoint_match is not None:
        # Model details are only carried over for endpoint-shaped paths.
        converted.endpoint = resource_path
        converted.model = prediction_endpoint.model
        converted.model_version_id = prediction_endpoint.model_version_id

    return converted
88+
89+
6090
def convert_gapic_to_rag_corpus(gapic_rag_corpus: GapicRagCorpus) -> RagCorpus:
    """Build a RagCorpus from its GAPIC counterpart.

    Copies ``name``, ``display_name`` and ``description`` across and converts
    ``rag_embedding_model_config`` with
    ``convert_gapic_to_embedding_model_config``.

    Args:
        gapic_rag_corpus: The GAPIC corpus message to convert.

    Returns:
        The equivalent RagCorpus.
    """
    embedding_model_config = convert_gapic_to_embedding_model_config(
        gapic_rag_corpus.rag_embedding_model_config
    )
    return RagCorpus(
        name=gapic_rag_corpus.name,
        display_name=gapic_rag_corpus.display_name,
        description=gapic_rag_corpus.description,
        embedding_model_config=embedding_model_config,
    )
68101

@@ -124,6 +157,7 @@ def prepare_import_files_request(
124157
paths: Sequence[str],
125158
chunk_size: int = 1024,
126159
chunk_overlap: int = 200,
160+
max_embedding_requests_per_min: int = 1000,
127161
) -> ImportRagFilesRequest:
128162
if len(corpus_name.split("/")) != 6:
129163
raise ValueError(
@@ -135,7 +169,8 @@ def prepare_import_files_request(
135169
chunk_overlap=chunk_overlap,
136170
)
137171
import_rag_files_config = ImportRagFilesConfig(
138-
rag_file_chunking_config=rag_file_chunking_config
172+
rag_file_chunking_config=rag_file_chunking_config,
173+
max_embedding_requests_per_min=max_embedding_requests_per_min,
139174
)
140175

141176
uris = []
@@ -204,3 +239,65 @@ def get_file_name(
204239
raise ValueError(
205240
"name must be of the format `projects/{project}/locations/{location}/ragCorpora/{rag_corpus}/ragFiles/{rag_file}` or `{rag_file}`"
206241
)
242+
243+
244+
def set_embedding_model_config(
    embedding_model_config: EmbeddingModelConfig,
    rag_corpus: GapicRagCorpus,
) -> GapicRagCorpus:
    """Validate ``embedding_model_config`` and record it on ``rag_corpus``.

    Exactly one of ``publisher_model`` and ``endpoint`` must be set. The value
    may be a fully-qualified resource name, which is stored unchanged, or the
    short form (``publishers/google/models/{model_id}`` or
    ``endpoints/{endpoint}``), which is prefixed with the location path
    obtained from ``initializer.global_config``.

    Args:
        embedding_model_config: The user-supplied embedding model config.
        rag_corpus: The GAPIC corpus to update. It is modified in place.

    Returns:
        The same ``rag_corpus`` object, with
        ``rag_embedding_model_config.vertex_prediction_endpoint.endpoint`` set.

    Raises:
        ValueError: If both or neither of ``publisher_model`` and ``endpoint``
            are set, or if the one that is set matches neither accepted format.
    """
    publisher_model = embedding_model_config.publisher_model
    endpoint = embedding_model_config.endpoint

    if publisher_model and endpoint:
        raise ValueError("publisher_model and endpoint cannot be set at the same time.")
    if not publisher_model and not endpoint:
        raise ValueError("At least one of publisher_model and endpoint must be set.")

    # Exactly one of the two is set from here on; pick its patterns and error.
    if publisher_model:
        resource = publisher_model
        full_pattern = r"^projects/(?P<project>.+?)/locations/(?P<location>.+?)/publishers/google/models/(?P<model_id>.+?)$"
        short_pattern = r"^publishers/google/models/(?P<model_id>.+?)$"
        format_error = "publisher_model must be of the format `projects/{project}/locations/{location}/publishers/google/models/{model_id}` or `publishers/google/models/{model_id}`"
    else:
        resource = endpoint
        full_pattern = r"^projects/(?P<project>.+?)/locations/(?P<location>.+?)/endpoints/(?P<endpoint>.+?)$"
        short_pattern = r"^endpoints/(?P<endpoint>.+?)$"
        format_error = "endpoint must be of the format `projects/{project}/locations/{location}/endpoints/{endpoint}` or `endpoints/{endpoint}`"

    if re.match(full_pattern, resource):
        resolved = resource
    elif re.match(short_pattern, resource):
        # Only the short form needs the location prefix, so the initializer is
        # consulted here instead of on every call (including calls that pass a
        # fully-qualified name or an invalid value).
        parent = initializer.global_config.common_location_path(
            project=None, location=None
        )
        resolved = parent + "/" + resource
    else:
        raise ValueError(format_error)

    rag_corpus.rag_embedding_model_config.vertex_prediction_endpoint.endpoint = resolved
    return rag_corpus

0 commit comments

Comments
0 (0)
Morty Proxy This is a proxified and sanitized view of the page, visit original site.