File tree Expand file tree Collapse file tree 2 files changed +17
-1
lines changed
Filter options
Expand file tree Collapse file tree 2 files changed +17
-1
lines changed
Original file line number Diff line number Diff line change @@ -814,7 +814,7 @@ def decode_batch(n_seq: int):
814
814
815
815
# store embeddings
816
816
for i in range (n_seq ):
817
- embedding : List [float ] = llama_cpp .llama_get_embeddings_ith (
817
+ embedding : List [float ] = llama_cpp .llama_get_embeddings_seq (
818
818
self ._ctx .ctx , i
819
819
)[:n_embd ]
820
820
if normalize :
Original file line number Diff line number Diff line change @@ -1803,6 +1803,22 @@ def llama_get_embeddings_ith(
1803
1803
...
1804
1804
1805
1805
1806
+ # // Get the embeddings for sequence seq_id
1807
+ # // shape: [n_embd] (1-dimensional)
1808
+ # LLAMA_API float * llama_get_embeddings_seq(struct llama_context * ctx, llama_seq_id seq_id);
1809
+ @ctypes_function (
1810
+ "llama_get_embeddings_seq" ,
1811
+ [llama_context_p_ctypes , ctypes .c_int32 ],
1812
+ ctypes .POINTER (ctypes .c_float ),
1813
+ )
1814
+ def llama_get_embeddings_seq (
1815
+ ctx : llama_context_p , i : Union [ctypes .c_int32 , int ], /
1816
+ ) -> CtypesArray [ctypes .c_float ]:
1817
+ """Get the embeddings for sequence seq_id
1818
+ shape: [n_embd] (1-dimensional)"""
1819
+ ...
1820
+
1821
+
1806
1822
# // Get the embeddings for a sequence id
1807
1823
# // Returns NULL if pooling_type is LLAMA_POOLING_TYPE_NONE
1808
1824
# // shape: [n_embd] (1-dimensional)
You can’t perform that action at this time.
0 commit comments