Commit d318cc8

fix: Set default pooling_type to mean, check for null pointer.

1 parent: dd0ee56
File tree: 2 files changed, +8 −3 lines changed

‎llama_cpp/llama.py (8 additions & 2 deletions)
```diff
@@ -79,6 +79,7 @@ def __init__(
         n_threads: Optional[int] = None,
         n_threads_batch: Optional[int] = None,
         rope_scaling_type: Optional[int] = llama_cpp.LLAMA_ROPE_SCALING_TYPE_UNSPECIFIED,
+        pooling_type: int = llama_cpp.LLAMA_POOLING_TYPE_MEAN,
         rope_freq_base: float = 0.0,
         rope_freq_scale: float = 0.0,
         yarn_ext_factor: float = -1.0,
@@ -151,6 +152,7 @@ def __init__(
             n_threads: Number of threads to use for generation
             n_threads_batch: Number of threads to use for batch processing
             rope_scaling_type: RoPE scaling type, from `enum llama_rope_scaling_type`. ref: https://github.com/ggerganov/llama.cpp/pull/2054
+            pooling_type: Pooling type, from `enum llama_pooling_type`.
             rope_freq_base: RoPE base frequency, 0 = from model
             rope_freq_scale: RoPE frequency scaling factor, 0 = from model
             yarn_ext_factor: YaRN extrapolation mix factor, negative = from model
@@ -271,6 +273,7 @@ def __init__(
             if rope_scaling_type is not None
             else llama_cpp.LLAMA_ROPE_SCALING_TYPE_UNSPECIFIED
         )
+        self.context_params.pooling_type = pooling_type
         self.context_params.rope_freq_base = (
             rope_freq_base if rope_freq_base != 0.0 else 0
         )
```
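
These three hunks thread the new `pooling_type` parameter through the `Llama` constructor, defaulting to mean pooling so that `llama_get_embeddings_seq` yields one pooled vector per sequence. A minimal usage sketch against this version of the API; the model path is a placeholder, and any embedding-capable GGUF model would do:

```python
import llama_cpp
from llama_cpp import Llama

# Hypothetical model path, for illustration only.
llm = Llama(
    model_path="./models/embedding-model.gguf",
    embedding=True,  # extract embeddings together with logits
    # The new default; another `enum llama_pooling_type` value
    # (e.g. llama_cpp.LLAMA_POOLING_TYPE_CLS) can be passed to override it.
    pooling_type=llama_cpp.LLAMA_POOLING_TYPE_MEAN,
)

# One pooled embedding vector per input sequence.
vector = llm.embed("The quick brown fox")
```

The final hunk in this file adds the null-pointer guard named in the commit message: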
```diff
@@ -814,9 +817,12 @@ def decode_batch(n_seq: int):
 
             # store embeddings
             for i in range(n_seq):
-                embedding: List[float] = llama_cpp.llama_get_embeddings_seq(
-                    self._ctx.ctx, i
-                )[:n_embd]
+                ptr = llama_cpp.llama_get_embeddings_seq(
+                    self._ctx.ctx, i
+                )
+                if not ptr:
+                    raise RuntimeError("Failed to get embeddings from sequence; pooling type is not set")
+                embedding: List[float] = ptr[:n_embd]
                 if normalize:
                     norm = float(np.linalg.norm(embedding))
                     embedding = [v / norm for v in embedding]
```
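
The guard matters because `llama_get_embeddings_seq` returns a raw C `float *` through ctypes: when the context holds no pooled sequence embeddings (for example, when pooling is disabled), that pointer is NULL, and slicing it from Python segfaults rather than raising. A standalone sketch of the same guard pattern in plain ctypes; `read_floats` and the buffer are illustrative, not part of the library:

```python
import ctypes
from typing import List

def read_floats(ptr, n: int) -> List[float]:
    """Copy n floats out of a C float* after checking it is non-NULL."""
    # A NULL ctypes pointer is falsy, so `not ptr` is the idiomatic check;
    # slicing a NULL pointer would crash the process instead of raising.
    if not ptr:
        raise RuntimeError("C call returned NULL; nothing to read")
    return ptr[:n]  # ctypes pointer slicing copies into a Python list

# A buffer we own, standing in for llama.cpp's internal embedding storage.
buf = (ctypes.c_float * 4)(0.1, 0.2, 0.3, 0.4)
good = ctypes.cast(buf, ctypes.POINTER(ctypes.c_float))
print(read_floats(good, 4))  # ~[0.1, 0.2, 0.3, 0.4]

null = ctypes.POINTER(ctypes.c_float)()  # a NULL float*
# read_floats(null, 4) would raise RuntimeError instead of segfaulting.
```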

‎llama_cpp/llama_cpp.py (0 additions & 1 deletion)
```diff
@@ -579,7 +579,6 @@ class llama_model_params(ctypes.Structure):
 # bool embeddings;  // if true, extract embeddings (together with logits)
 # bool offload_kqv; // whether to offload the KQV ops (including the KV cache) to GPU
 
-
 # // Abort callback
 # // if it returns true, execution of llama_decode() will be aborted
 # // currently works only with CPU execution
```
