@@ -717,10 +717,53 @@ def create_embedding(
717
717
Returns:
718
718
An embedding object.
719
719
"""
720
- assert self ._ctx .ctx is not None
721
720
assert self ._model .model is not None
722
721
model_name : str = model if model is not None else self .model_path
723
722
723
+ # get numeric embeddings
724
+ embeds : List [List [float ]]
725
+ total_tokens : int
726
+ embeds , total_tokens = self .embed (input , return_count = True ) # type: ignore
727
+
728
+ # convert to CreateEmbeddingResponse
729
+ data : List [Embedding ] = [
730
+ {
731
+ "object" : "embedding" ,
732
+ "embedding" : emb ,
733
+ "index" : idx ,
734
+ }
735
+ for idx , emb in enumerate (embeds )
736
+ ]
737
+
738
+ return {
739
+ "object" : "list" ,
740
+ "data" : data ,
741
+ "model" : model_name ,
742
+ "usage" : {
743
+ "prompt_tokens" : total_tokens ,
744
+ "total_tokens" : total_tokens ,
745
+ },
746
+ }
747
+
748
+ def embed (
749
+ self ,
750
+ input : Union [str , List [str ]],
751
+ normalize : bool = True ,
752
+ truncate : bool = True ,
753
+ return_count : bool = False ,
754
+ ):
755
+ """Embed a string.
756
+
757
+ Args:
758
+ input: The utf-8 encoded string to embed.
759
+
760
+ Returns:
761
+ A list of embeddings
762
+ """
763
+ assert self ._ctx .ctx is not None
764
+ n_embd = self .n_embd ()
765
+ n_ctx = self .n_ctx ()
766
+
724
767
if self .context_params .embedding == False :
725
768
raise RuntimeError (
726
769
"Llama model must be created with embedding=True to call this method"
@@ -734,48 +777,72 @@ def create_embedding(
734
777
else :
735
778
inputs = input
736
779
737
- data : List [Embedding ] = []
780
+ # reset batch
781
+ self ._batch .reset ()
782
+
783
+ # decode and fetch embeddings
784
+ data : List [List [float ]] = []
785
+ def decode_batch (sizes : List [int ]):
786
+ assert self ._ctx .ctx is not None
787
+ llama_cpp .llama_kv_cache_clear (self ._ctx .ctx )
788
+ self ._ctx .decode (self ._batch )
789
+ self ._batch .reset ()
790
+
791
+ # store embeddings
792
+ for i , s in enumerate (sizes ):
793
+ embedding = llama_cpp .llama_get_embeddings_ith (self ._ctx .ctx , i )[
794
+ :n_embd
795
+ ]
796
+ norm = np .linalg .norm (embedding ) if normalize else s
797
+ embedding : List [float ] = [v / float (norm ) for v in embedding ]
798
+ data .append (embedding )
799
+
800
+ # init state
738
801
total_tokens = 0
739
- for index , input in enumerate (inputs ):
740
- tokens = self .tokenize (input .encode ("utf-8" ), special = True )
741
- self .reset ()
742
- self .eval (tokens )
802
+ t_batch = 0
803
+ s_sizes : List [int ] = []
804
+
805
+ # accumulate batches and encode
806
+ for text in inputs :
807
+ tokens = self .tokenize (text .encode ("utf-8" ))
808
+ if truncate :
809
+ tokens = tokens [:n_ctx ]
810
+
743
811
n_tokens = len (tokens )
744
812
total_tokens += n_tokens
745
- embedding = llama_cpp .llama_get_embeddings (self ._ctx .ctx )[
746
- : llama_cpp .llama_n_embd (self ._model .model )
747
- ]
748
813
749
- data .append (
750
- {
751
- "object" : "embedding" ,
752
- "embedding" : embedding ,
753
- "index" : index ,
754
- }
755
- )
814
+ # check for overrun
815
+ if n_tokens > n_ctx :
816
+ raise ValueError (
817
+ f"Requested tokens ({ n_tokens } ) exceed context window of { n_ctx } "
818
+ )
819
+
820
+ # time to eval batch
821
+ if t_batch + n_tokens > self ._n_ctx :
822
+ decode_batch (s_sizes )
823
+ t_batch = 0
824
+ s_sizes = []
825
+
826
+ # add to batch
827
+ self ._batch .add_sequence (tokens , len (s_sizes ), False )
828
+ t_batch += n_tokens
829
+ s_sizes .append (n_tokens )
830
+
831
+ # handle last batch
832
+ decode_batch (s_sizes )
833
+
756
834
if self .verbose :
757
835
llama_cpp .llama_print_timings (self ._ctx .ctx )
758
836
759
- return {
760
- "object" : "list" ,
761
- "data" : data ,
762
- "model" : model_name ,
763
- "usage" : {
764
- "prompt_tokens" : total_tokens ,
765
- "total_tokens" : total_tokens ,
766
- },
767
- }
768
-
769
- def embed (self , input : str ) -> List [float ]:
770
- """Embed a string.
837
+ output = data [0 ] if isinstance (input , str ) else data
771
838
772
- Args:
773
- input: The utf-8 encoded string to embed.
839
+ llama_cpp . llama_kv_cache_clear ( self . _ctx . ctx )
840
+ self . reset ()
774
841
775
- Returns :
776
- A list of embeddings
777
- """
778
- return list ( map ( float , self . create_embedding ( input )[ "data" ][ 0 ][ "embedding" ]))
842
+ if return_count :
843
+ return output , total_tokens
844
+ else :
845
+ return output
779
846
780
847
def _create_completion (
781
848
self ,
0 commit comments