@@ -79,6 +79,7 @@ def __init__(
         n_threads: Optional[int] = None,
         n_threads_batch: Optional[int] = None,
         rope_scaling_type: Optional[int] = llama_cpp.LLAMA_ROPE_SCALING_TYPE_UNSPECIFIED,
+        pooling_type: int = llama_cpp.LLAMA_POOLING_TYPE_MEAN,
         rope_freq_base: float = 0.0,
         rope_freq_scale: float = 0.0,
         yarn_ext_factor: float = -1.0,
@@ -151,6 +152,7 @@ def __init__(
             n_threads: Number of threads to use for generation
             n_threads_batch: Number of threads to use for batch processing
             rope_scaling_type: RoPE scaling type, from `enum llama_rope_scaling_type`. ref: https://github.com/ggerganov/llama.cpp/pull/2054
+            pooling_type: Pooling type, from `enum llama_pooling_type`.
             rope_freq_base: RoPE base frequency, 0 = from model
             rope_freq_scale: RoPE frequency scaling factor, 0 = from model
             yarn_ext_factor: YaRN extrapolation mix factor, negative = from model
@@ -271,6 +273,7 @@ def __init__(
             if rope_scaling_type is not None
             else llama_cpp.LLAMA_ROPE_SCALING_TYPE_UNSPECIFIED
         )
+        self.context_params.pooling_type = pooling_type
         self.context_params.rope_freq_base = (
             rope_freq_base if rope_freq_base != 0.0 else 0
         )
@@ -814,9 +817,12 @@ def decode_batch(n_seq: int):
 
             # store embeddings
             for i in range(n_seq):
-                embedding: List[float] = llama_cpp.llama_get_embeddings_seq(
+                ptr = llama_cpp.llama_get_embeddings_seq(
                     self._ctx.ctx, i
-                )[:n_embd]
+                )
+                if not ptr:
+                    raise RuntimeError("Failed to get embeddings from sequence pooling type is not set")
+                embedding: List[float] = ptr[:n_embd]
                 if normalize:
                     norm = float(np.linalg.norm(embedding))
                     embedding = [v / norm for v in embedding]
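For reference, a minimal usage sketch of the new `pooling_type` parameter. The model path is a placeholder, and `embedding=True` / `embed()` are the existing llama-cpp-python embedding entry points; treat this as an illustration of the change, not part of the patch itself.

```python
import llama_cpp
from llama_cpp import Llama

# Hypothetical GGUF path; any embedding-capable model works here.
llm = Llama(
    model_path="./models/embedding-model.gguf",
    embedding=True,
    # New parameter added by this change; LLAMA_POOLING_TYPE_MEAN is its default.
    pooling_type=llama_cpp.LLAMA_POOLING_TYPE_MEAN,
)

# embed() goes through the decode_batch path patched above, so a null
# pointer from llama_get_embeddings_seq now raises RuntimeError instead
# of failing silently.
vector = llm.embed("Hello, world!")
print(len(vector))
```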