@@ -531,7 +531,9 @@ def generate(
531
531
if tokens_or_none is not None :
532
532
tokens .extend (tokens_or_none )
533
533
534
- def create_embedding (self , input : str , model : Optional [str ] = None ) -> Embedding :
534
+ def create_embedding (
535
+ self , input : Union [str , List [str ]], model : Optional [str ] = None
536
+ ) -> Embedding :
535
537
"""Embed a string.
536
538
537
539
Args:
@@ -551,30 +553,40 @@ def create_embedding(self, input: str, model: Optional[str] = None) -> Embedding
551
553
if self .verbose :
552
554
llama_cpp .llama_reset_timings (self .ctx )
553
555
554
- tokens = self .tokenize (input .encode ("utf-8" ))
555
- self .reset ()
556
- self .eval (tokens )
557
- n_tokens = len (tokens )
558
- embedding = llama_cpp .llama_get_embeddings (self .ctx )[
559
- : llama_cpp .llama_n_embd (self .ctx )
560
- ]
556
+ if isinstance (input , str ):
557
+ inputs = [input ]
558
+ else :
559
+ inputs = input
561
560
562
- if self .verbose :
563
- llama_cpp .llama_print_timings (self .ctx )
561
+ data = []
562
+ total_tokens = 0
563
+ for input in inputs :
564
+ tokens = self .tokenize (input .encode ("utf-8" ))
565
+ self .reset ()
566
+ self .eval (tokens )
567
+ n_tokens = len (tokens )
568
+ total_tokens += n_tokens
569
+ embedding = llama_cpp .llama_get_embeddings (self .ctx )[
570
+ : llama_cpp .llama_n_embd (self .ctx )
571
+ ]
564
572
565
- return {
566
- "object" : "list" ,
567
- " data" : [
573
+ if self . verbose :
574
+ llama_cpp . llama_print_timings ( self . ctx )
575
+ data . append (
568
576
{
569
577
"object" : "embedding" ,
570
578
"embedding" : embedding ,
571
579
"index" : 0 ,
572
580
}
573
- ],
574
- "model" : model_name ,
581
+ )
582
+
583
+ return {
584
+ "object" : "list" ,
585
+ "data" : data ,
586
+ "model" : self .model_path ,
575
587
"usage" : {
576
- "prompt_tokens" : n_tokens ,
577
- "total_tokens" : n_tokens ,
588
+ "prompt_tokens" : total_tokens ,
589
+ "total_tokens" : total_tokens ,
578
590
},
579
591
}
580
592
0 commit comments