Commit 146f9a5

decode openai token ids
1 parent 78eea51 · commit 146f9a5

1 file changed: 25 additions & 1 deletion
llama_cpp/server/app.py

@@ -664,7 +664,12 @@ def iterator() -> Iterator[llama_cpp.CompletionChunk]:
 
 class CreateEmbeddingRequest(BaseModel):
     model: Optional[str] = model_field
-    input: Union[str, List[str]] = Field(description="The input to embed.")
+    input: Union[  # accept both strings and token ids
+        str,
+        List[str],
+        List[int],
+        List[List[int]],
+    ] = Field(description="The input to embed.")
     user: Optional[str] = Field(default=None)
 
     model_config = {
@@ -677,13 +682,32 @@ class CreateEmbeddingRequest(BaseModel):
         }
     }
 
+import tiktoken
+openai_tiktoken_encoding = tiktoken.get_encoding("cl100k_base")
 
 @router.post(
     "/v1/embeddings",
 )
 async def create_embedding(
     request: CreateEmbeddingRequest, llama: llama_cpp.Llama = Depends(get_llama)
 ):
+    # Normalize the input into a list
+    if isinstance(request.input, list):
+        if isinstance(request.input[0], int):  # a single list of token ids
+            list_input: List = [request.input]
+        else:
+            list_input: List = request.input  # a list of strings or of token-id lists
+    else:
+        list_input: List = [request.input]  # a single string
+
+    # Coerce everything to a list of str (decode with tiktoken if the client sent token ids, e.g. langchain's OpenAIEmbeddings)
+    request.input = (
+        [
+            openai_tiktoken_encoding.decode(tokenArr) for tokenArr in list_input
+        ]
+        if isinstance(list_input[0], list)  # a list of token-id lists
+        else list_input  # already a list of strings
+    )  # type: ignore
     return await run_in_threadpool(
         llama.create_embedding, **request.model_dump(exclude={"user"})
     )
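
For reference, the normalization in this patch can be exercised on its own. Below is a minimal standalone sketch of the same logic, assuming tiktoken is installed; the helper name normalize_embedding_input and the sample strings are illustrative and not part of the commit. Clients that follow OpenAI's embeddings protocol (for example langchain's OpenAIEmbeddings) tokenize text with the cl100k_base encoding before calling /v1/embeddings, which is why the server may receive token ids and decodes them back to text with the same encoding.

    from typing import List, Union

    import tiktoken

    # OpenAI-style clients tokenize with cl100k_base before calling the
    # endpoint, so the server decodes with the same encoding.
    encoding = tiktoken.get_encoding("cl100k_base")


    def normalize_embedding_input(
        input: Union[str, List[str], List[int], List[List[int]]],
    ) -> List[str]:
        """Coerce the four accepted input shapes into a list of strings."""
        # Wrap bare inputs so we always work on a batch.
        if isinstance(input, list):
            if isinstance(input[0], int):  # a single list of token ids
                batch = [input]
            else:                          # a list of strings or of token-id lists
                batch = input
        else:                              # a single string
            batch = [input]
        # If the batch holds token-id lists, decode each back to text.
        if isinstance(batch[0], list):
            return [encoding.decode(ids) for ids in batch]
        return batch  # already a list of strings


    if __name__ == "__main__":
        ids = encoding.encode("hello world")
        print(ids)                                    # token ids, e.g. [15339, 1917]
        print(normalize_embedding_input(ids))         # ['hello world']
        print(normalize_embedding_input("hello"))     # ['hello']
        print(normalize_embedding_input(["a", "b"]))  # ['a', 'b']

Decoding on the server keeps the endpoint a drop-in replacement for such clients, at the cost of assuming the ids were produced with cl100k_base; ids from a different encoding would decode to the wrong text.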
