Skip to content

Navigation Menu

Sign in
Appearance settings

Search code, repositories, users, issues, pull requests...

Provide feedback

We read every piece of feedback, and take your input very seriously.

Saved searches

Use saved searches to filter your results more quickly

Appearance settings

Commit 93dc56a

Browse filesBrowse files
committed
Update llama.cpp
1 parent 87a6e57 commit 93dc56a
Copy full SHA for 93dc56a

File tree

Expand file treeCollapse file tree

3 files changed

+30
-12
lines changed
Filter options
Expand file treeCollapse file tree

3 files changed

+30
-12
lines changed

‎llama_cpp/llama.py

Copy file name to clipboardExpand all lines: llama_cpp/llama.py
+3-3Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -293,7 +293,7 @@ def __init__(
293293
self.context_params.logits_all = (
294294
logits_all if draft_model is None else True
295295
) # Must be set to True for speculative decoding
296-
self.context_params.embedding = embedding
296+
self.context_params.embeddings = embedding # TODO: Rename to embeddings
297297
self.context_params.offload_kqv = offload_kqv
298298

299299
# Sampling Params
@@ -787,7 +787,7 @@ def embed(
787787
n_embd = self.n_embd()
788788
n_batch = self.n_batch
789789

790-
if self.context_params.embedding == False:
790+
if self.context_params.embeddings == False:
791791
raise RuntimeError(
792792
"Llama model must be created with embedding=True to call this method"
793793
)
@@ -1725,7 +1725,7 @@ def __getstate__(self):
17251725
yarn_beta_slow=self.context_params.yarn_beta_slow,
17261726
yarn_orig_ctx=self.context_params.yarn_orig_ctx,
17271727
logits_all=self.context_params.logits_all,
1728-
embedding=self.context_params.embedding,
1728+
embedding=self.context_params.embeddings,
17291729
# Sampling Params
17301730
last_n_tokens_size=self.last_n_tokens_size,
17311731
# LoRA Params

‎llama_cpp/llama_cpp.py

Copy file name to clipboardExpand all lines: llama_cpp/llama_cpp.py
+26-8Lines changed: 26 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -399,7 +399,7 @@ class llama_token_data_array(ctypes.Structure):
399399
# // - embd : token embeddings (i.e. float vector of size n_embd) (used when token is NULL)
400400
# // - pos : the positions of the respective token in the sequence
401401
# // - seq_id : the sequence to which the respective token belongs
402-
# // - logits : if zero, the logits for the respective token will not be output
402+
# // - logits : if zero, the logits (and/or the embeddings) for the respective token will not be output
403403
# //
404404
# typedef struct llama_batch {
405405
# int32_t n_tokens;
@@ -409,7 +409,7 @@ class llama_token_data_array(ctypes.Structure):
409409
# llama_pos * pos;
410410
# int32_t * n_seq_id;
411411
# llama_seq_id ** seq_id;
412-
# int8_t * logits;
412+
# int8_t * logits; // TODO: rename this to "output"
413413

414414

415415
# // NOTE: helpers for smooth API transition - can be deprecated in the future
@@ -572,7 +572,7 @@ class llama_model_params(ctypes.Structure):
572572

573573
# // Keep the booleans together to avoid misalignment during copy-by-value.
574574
# bool logits_all; // the llama_decode() call computes all logits, not just the last one (DEPRECATED - set llama_batch.logits instead)
575-
# bool embedding; // embedding mode only
575+
# bool embeddings; // if true, extract embeddings (together with logits)
576576
# bool offload_kqv; // whether to offload the KQV ops (including the KV cache) to GPU
577577

578578
# // Abort callback
@@ -605,7 +605,7 @@ class llama_context_params(ctypes.Structure):
605605
type_k (int): data type for K cache
606606
type_v (int): data type for V cache
607607
logits_all (bool): the llama_eval() call computes all logits, not just the last one (DEPRECATED - set llama_batch.logits instead)
608-
embedding (bool): embedding mode only
608+
embeddings (bool): if true, extract embeddings (together with logits)
609609
offload_kqv (bool): whether to offload the KQV ops (including the KV cache) to GPU
610610
abort_callback (ggml_abort_callback): abort callback if it returns true, execution of llama_decode() will be aborted
611611
abort_callback_data (ctypes.ctypes.c_void_p): data for abort_callback
@@ -632,7 +632,7 @@ class llama_context_params(ctypes.Structure):
632632
("type_k", ctypes.c_int),
633633
("type_v", ctypes.c_int),
634634
("logits_all", ctypes.c_bool),
635-
("embedding", ctypes.c_bool),
635+
("embeddings", ctypes.c_bool),
636636
("offload_kqv", ctypes.c_bool),
637637
("abort_callback", ggml_abort_callback),
638638
("abort_callback_data", ctypes.c_void_p),
@@ -1774,8 +1774,8 @@ def llama_get_logits_ith(
17741774
...
17751775

17761776

1777-
# Get the embeddings for the input
1778-
# shape: [n_embd] (1-dimensional)
1777+
# // Get all output token embeddings
1778+
# // shape: [n_tokens*n_embd] (1-dimensional)
17791779
# LLAMA_API float * llama_get_embeddings(struct llama_context * ctx);
17801780
@ctypes_function(
17811781
"llama_get_embeddings", [llama_context_p_ctypes], ctypes.POINTER(ctypes.c_float)
@@ -1786,8 +1786,9 @@ def llama_get_embeddings(ctx: llama_context_p, /) -> CtypesArray[ctypes.c_float]
17861786
...
17871787

17881788

1789-
# // Get the embeddings for the ith sequence
1789+
# // Get the embeddings for the ith token
17901790
# // llama_get_embeddings(ctx) + i*n_embd
1791+
# // shape: [n_embd] (1-dimensional)
17911792
# LLAMA_API float * llama_get_embeddings_ith(struct llama_context * ctx, int32_t i);
17921793
@ctypes_function(
17931794
"llama_get_embeddings_ith",
@@ -1802,6 +1803,23 @@ def llama_get_embeddings_ith(
18021803
...
18031804

18041805

1806+
# // Get the embeddings for a sequence id
1807+
# // Returns NULL if pooling_type is LLAMA_POOLING_TYPE_NONE
1808+
# // shape: [n_embd] (1-dimensional)
1809+
# LLAMA_API float * llama_get_embeddings_seq(struct llama_context * ctx, llama_seq_id seq_id);
1810+
@ctypes_function(
1811+
"llama_get_embeddings_seq",
1812+
[llama_context_p_ctypes, llama_seq_id],
1813+
ctypes.POINTER(ctypes.c_float),
1814+
)
1815+
def llama_get_embeddings_seq(
1816+
ctx: llama_context_p, seq_id: Union[llama_seq_id, int], /
1817+
) -> CtypesArray[ctypes.c_float]:
1818+
"""Get the embeddings for a sequence id
1819+
Returns NULL if pooling_type is LLAMA_POOLING_TYPE_NONE
1820+
shape: [n_embd] (1-dimensional)"""
1821+
...
1822+
18051823
# //
18061824
# // Vocab
18071825
# //

‎vendor/llama.cpp

Copy file name to clipboard

0 commit comments

Comments
0 (0)
Morty Proxy This is a proxified and sanitized view of the page, visit original site.