@@ -717,10 +717,44 @@ def create_embedding(
717
717
Returns:
718
718
An embedding object.
719
719
"""
720
- assert self ._ctx .ctx is not None
721
720
assert self ._model .model is not None
722
721
model_name : str = model if model is not None else self .model_path
723
722
723
+ # get numeric embeddings
724
+ embeds , total_tokens = self .embed (input , return_count = True )
725
+
726
+ # convert to CreateEmbeddingResponse
727
+ data = [
728
+ {
729
+ "object" : "embedding" ,
730
+ "embedding" : emb ,
731
+ "index" : idx ,
732
+ } for idx , emb in enumerate (embeds )
733
+ ]
734
+
735
+ return {
736
+ "object" : "list" ,
737
+ "data" : data ,
738
+ "model" : model_name ,
739
+ "usage" : {
740
+ "prompt_tokens" : total_tokens ,
741
+ "total_tokens" : total_tokens ,
742
+ },
743
+ }
744
+
745
+ def embed (self , input : str , normalize : bool = True , truncate : bool = True , return_count : bool = False ) -> List [float ]:
746
+ """Embed a string.
747
+
748
+ Args:
749
+ input: The utf-8 encoded string to embed.
750
+
751
+ Returns:
752
+ A list of embeddings
753
+ """
754
+ assert self ._ctx .ctx is not None
755
+ n_embd = self .n_embd ()
756
+ n_ctx = self .n_ctx ()
757
+
724
758
if self .context_params .embedding == False :
725
759
raise RuntimeError (
726
760
"Llama model must be created with embedding=True to call this method"
@@ -734,48 +768,68 @@ def create_embedding(
734
768
else :
735
769
inputs = input
736
770
771
+ def normalize (x ):
772
+ norm = np .linalg .norm (x )
773
+ return [v / norm for v in x ]
774
+
775
+ # reset batch
776
+ self ._batch .reset ()
777
+
778
+ # decode and fetch embeddings
737
779
data : List [Embedding ] = []
780
+ def decode_batch (n_seq ):
781
+ llama_cpp .llama_kv_cache_clear (self ._ctx .ctx )
782
+ self ._ctx .decode (self ._batch )
783
+ self ._batch .reset ()
784
+
785
+ # store embeddings
786
+ for i in range (n_seq ):
787
+ embedding = llama_cpp .llama_get_embeddings_ith (self ._ctx .ctx , i )[:n_embd ]
788
+ if normalize :
789
+ embedding = normalize (embedding )
790
+ data .append (embedding )
791
+
792
+ # init state
738
793
total_tokens = 0
739
- for index , input in enumerate (inputs ):
740
- tokens = self .tokenize (input .encode ("utf-8" ), special = True )
741
- self .reset ()
742
- self .eval (tokens )
794
+ p_batch = 0
795
+ t_batch = 0
796
+
797
+ # accumulate batches and encode
798
+ for text in inputs :
799
+ tokens = self .tokenize (text .encode ("utf-8" ))
800
+ if truncate :
801
+ tokens = tokens [:n_ctx ]
743
802
n_tokens = len (tokens )
744
- total_tokens += n_tokens
745
- embedding = llama_cpp .llama_get_embeddings (self ._ctx .ctx )[
746
- : llama_cpp .llama_n_embd (self ._model .model )
747
- ]
748
803
749
- data .append (
750
- {
751
- "object" : "embedding" ,
752
- "embedding" : embedding ,
753
- "index" : index ,
754
- }
755
- )
756
- if self .verbose :
757
- llama_cpp .llama_print_timings (self ._ctx .ctx )
804
+ # check for overrun
805
+ if n_tokens > n_ctx :
806
+ raise ValueError (
807
+ f"Requested tokens ({ n_tokens } ) exceed context window of { n_ctx } "
808
+ )
758
809
759
- return {
760
- "object" : "list" ,
761
- "data" : data ,
762
- "model" : model_name ,
763
- "usage" : {
764
- "prompt_tokens" : total_tokens ,
765
- "total_tokens" : total_tokens ,
766
- },
767
- }
810
+ # time to eval batch
811
+ if n_tokens + t_batch > self ._n_ctx :
812
+ decode_batch (p_batch )
813
+ total_tokens += t_batch
814
+ p_batch = 0
815
+ t_batch = 0
768
816
769
- def embed (self , input : str ) -> List [float ]:
770
- """Embed a string.
817
+ # add to batch
818
+ self ._batch .add_sequence (tokens , p_batch , False )
819
+ p_batch += 1
820
+ t_batch += n_tokens
771
821
772
- Args:
773
- input: The utf-8 encoded string to embed.
822
+ # handle last batch
823
+ decode_batch (p_batch )
824
+ total_tokens += t_batch
774
825
775
- Returns:
776
- A list of embeddings
777
- """
778
- return list (map (float , self .create_embedding (input )["data" ][0 ]["embedding" ]))
826
+ if self .verbose :
827
+ llama_cpp .llama_print_timings (self ._ctx .ctx )
828
+
829
+ if return_count :
830
+ return data , total_tokens
831
+ else :
832
+ return data
779
833
780
834
def _create_completion (
781
835
self ,
0 commit comments