@@ -664,7 +664,12 @@ def iterator() -> Iterator[llama_cpp.CompletionChunk]:
 
 class CreateEmbeddingRequest(BaseModel):
     model: Optional[str] = model_field
-    input: Union[str, List[str]] = Field(description="The input to embed.")
+    input: Union[  # make this accept both strings and token IDs
+        str,
+        List[str],
+        List[int],
+        List[List[int]],
+    ] = Field(description="The input to embed.")
     user: Optional[str] = Field(default=None)
 
     model_config = {
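For context, here is a minimal standalone sketch (not part of the diff) of what the widened `input` field accepts, assuming pydantic v2's smart-union validation (consistent with the `model_config`/`model_dump` usage in this file); `model_field` is replaced with a plain default so the snippet is self-contained:

```python
from typing import List, Optional, Union
from pydantic import BaseModel, Field

class CreateEmbeddingRequest(BaseModel):
    model: Optional[str] = None  # stand-in for model_field, for self-containment
    input: Union[str, List[str], List[int], List[List[int]]] = Field(
        description="The input to embed."
    )
    user: Optional[str] = Field(default=None)

# All four input shapes now validate:
CreateEmbeddingRequest(input="hello world")        # a single string
CreateEmbeddingRequest(input=["hello", "world"])   # a batch of strings
CreateEmbeddingRequest(input=[15339, 1917])        # a single token-ID list
CreateEmbeddingRequest(input=[[15339], [1917]])    # a batch of token-ID lists
```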
@@ -677,13 +682,32 @@ class CreateEmbeddingRequest(BaseModel):
         }
     }
 
+import tiktoken
+openai_tiktoken_encoding = tiktoken.get_encoding("cl100k_base")
 
 @router.post(
     "/v1/embeddings",
 )
 async def create_embedding(
     request: CreateEmbeddingRequest, llama: llama_cpp.Llama = Depends(get_llama)
 ):
+    # Force the input to be a list
+    if isinstance(request.input, list):
+        if isinstance(request.input[0], int):  # a single list of token IDs
+            list_input: List = [request.input]
+        else:
+            list_input: List = request.input  # a list of strings or a list of token-ID lists
+    else:
+        list_input: List = [request.input]  # a single string
+
+    # Force the input to be a list of str (decode with tiktoken if it comes from langchain's OpenAIEmbeddings)
+    request.input = (
+        [
+            openai_tiktoken_encoding.decode(tokenArr) for tokenArr in list_input
+        ]
+        if isinstance(list_input[0], list)  # a list of token-ID lists
+        else list_input  # already a list of strings
+    )  # type: ignore
     return await run_in_threadpool(
         llama.create_embedding, **request.model_dump(exclude={"user"})
     )
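To see why the tiktoken decode step works, here is a small sketch (the printed token IDs are illustrative) of the round trip a client such as langchain's OpenAIEmbeddings performs: it encodes text to cl100k_base token IDs before calling `/v1/embeddings`, and the handler above decodes them back to strings for `llama.create_embedding`:

```python
import tiktoken

enc = tiktoken.get_encoding("cl100k_base")  # same encoding as the handler above

# A client like langchain's OpenAIEmbeddings sends token IDs rather than text:
token_ids = enc.encode("hello world")
print(token_ids)  # e.g. [15339, 1917]

# The handler decodes them back before embedding, so the round trip is lossless:
assert enc.decode(token_ids) == "hello world"
```

Note the decode only round-trips if the client tokenized with the same cl100k_base encoding, which is what OpenAI-compatible clients assume.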