Skip to content

Navigation Menu

Sign in
Appearance settings

Search code, repositories, users, issues, pull requests...

Provide feedback

We read every piece of feedback, and take your input very seriously.

Saved searches

Use saved searches to filter your results more quickly

Appearance settings

Commit aa33b20

Browse filesBrowse files
jsaied99jas8dz
and
jas8dz
authored
Adding filtering based on metadata using kwargs (#682)
Co-authored-by: jas8dz <jsaied@mail.missouri.edu>
1 parent 3f712a6 commit aa33b20
Copy full SHA for aa33b20

File tree

2 files changed

+14
-2
lines changed
Filter options

2 files changed

+14
-2
lines changed

‎pgml-sdks/python/pgml/examples/question_answering.py

Copy file name to clipboardExpand all lines: pgml-sdks/python/pgml/examples/question_answering.py
+1-1Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,7 @@
3333

3434
start = time()
3535
query = "Who won 20 grammy awards?"
36-
results = collection.vector_search(query, top_k=5)
36+
results = collection.vector_search(query, top_k=5, title="Beyoncé")
3737
_end = time()
3838
console.print("\nResults for '%s'" % (query), style="bold")
3939
console.print(results)

‎pgml-sdks/python/pgml/pgml/collection.py

Copy file name to clipboardExpand all lines: pgml-sdks/python/pgml/pgml/collection.py
+13-1Lines changed: 13 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -683,6 +683,7 @@ def vector_search(
683683
top_k: int = 5,
684684
model_id: int = 1,
685685
splitter_id: int = 1,
686+
**kwargs: Any,
686687
) -> List[Dict[str, Any]]:
687688
"""
688689
This function performs a vector search on a database using a query and returns the top matching
@@ -702,6 +703,9 @@ def vector_search(
702703
splitter used to split the documents into chunks. It is used to retrieve the embeddings table
703704
associated with the specified splitter, defaults to 1
704705
:type splitter_id: int (optional)
706+
:param kwargs: Additional filtering parameters to be used in the search query. These parameters
707+
are from the metadata of the documents and can be used to filter the search results based on
708+
metadata values.
705709
:return: a list of dictionaries containing search results for a given query. Each dictionary
706710
contains the following keys: "score", "text", and "metadata". The "score" key contains a float
707711
value representing the similarity score between the query and the search result. The "text" key
@@ -749,6 +753,13 @@ def vector_search(
749753
% (model_id, splitter_id, model_id, splitter_id)
750754
)
751755
return []
756+
757+
if kwargs:
758+
metadata_filter = [f"documents.metadata->>'{k}' = '{v}'" if isinstance(v, str) else f"documents.metadata->>'{k}' = {v}" for k, v in kwargs.items()]
759+
metadata_filter = " AND ".join(metadata_filter)
760+
metadata_filter = f"AND {metadata_filter}"
761+
else:
762+
metadata_filter = ""
752763

753764
cte_select_statement = """
754765
WITH query_cte AS (
@@ -764,7 +775,7 @@ def vector_search(
764775
SELECT cte.score, chunks.chunk, documents.metadata
765776
FROM cte
766777
INNER JOIN {chunks_table} chunks ON chunks.id = cte.chunk_id
767-
INNER JOIN {documents_table} documents ON documents.id = chunks.document_id;
778+
INNER JOIN {documents_table} documents ON documents.id = chunks.document_id {metadata_filter}
768779
""".format(
769780
model=sql.Literal(model).as_string(conn),
770781
query_text=query,
@@ -773,6 +784,7 @@ def vector_search(
773784
top_k=top_k,
774785
chunks_table=self.chunks_table,
775786
documents_table=self.documents_table,
787+
metadata_filter=metadata_filter,
776788
)
777789

778790
search_results = run_select_statement(

0 commit comments

Comments
0 (0)
Morty Proxy This is a proxified and sanitized view of the page, visit original site.