Commit ba35aef

handle batched embeddings
1 parent 07a7837 commit ba35aef

2 files changed: +111 −35

llama_cpp/_internals.py
22 additions & 0 deletions

@@ -506,6 +506,14 @@ def __del__(self):
         self._llama_batch_free(self.batch)
         self.batch = None

+    def n_tokens(self) -> int:
+        assert self.batch is not None
+        return self.batch.n_tokens
+
+    def reset(self):
+        assert self.batch is not None
+        self.batch.n_tokens = 0
+
     def set_batch(self, batch: Sequence[int], n_past: int, logits_all: bool):
         assert self.batch is not None
         n_tokens = len(batch)
@@ -518,6 +526,20 @@ def set_batch(self, batch: Sequence[int], n_past: int, logits_all: bool):
             self.batch.logits[i] = logits_all
         self.batch.logits[n_tokens - 1] = True

+    def add_sequence(self, batch: Sequence[int], seq_id: int, logits_all: bool):
+        assert self.batch is not None
+        n_tokens = len(batch)
+        n_tokens0 = self.batch.n_tokens
+        self.batch.n_tokens += n_tokens
+        for i in range(n_tokens):
+            j = n_tokens0 + i
+            self.batch.token[j] = batch[i]
+            self.batch.pos[j] = i
+            self.batch.seq_id[j][0] = seq_id
+            self.batch.n_seq_id[j] = 1
+            self.batch.logits[j] = logits_all
+        self.batch.logits[n_tokens0 + n_tokens - 1] = True
+

 class _LlamaTokenDataArray:
     def __init__(self, *, n_vocab: int):
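
The new `add_sequence` packs several token sequences into a single `llama_batch`: each call appends its tokens after the current `n_tokens`, restarts `pos` at 0 for the new sequence, and tags every slot with its own `seq_id`, so one `llama_decode` call can process them all. A minimal sketch of the resulting layout, assuming a `_LlamaBatch` instance `batch` and hypothetical toy token ids (neither appears in this commit):

    batch.reset()
    batch.add_sequence([101, 102, 103], seq_id=0, logits_all=False)  # slots 0-2, pos 0-2
    batch.add_sequence([201, 202], seq_id=1, logits_all=False)       # slots 3-4, pos 0-1
    assert batch.n_tokens() == 5  # two sequences packed into one batch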

llama_cpp/llama.py
89 additions & 35 deletions

@@ -717,10 +717,44 @@ def create_embedding(
         Returns:
             An embedding object.
         """
-        assert self._ctx.ctx is not None
         assert self._model.model is not None
         model_name: str = model if model is not None else self.model_path

+        # get numeric embeddings
+        embeds, total_tokens = self.embed(input, return_count=True)
+
+        # convert to CreateEmbeddingResponse
+        data = [
+            {
+                "object": "embedding",
+                "embedding": emb,
+                "index": idx,
+            } for idx, emb in enumerate(embeds)
+        ]
+
+        return {
+            "object": "list",
+            "data": data,
+            "model": model_name,
+            "usage": {
+                "prompt_tokens": total_tokens,
+                "total_tokens": total_tokens,
+            },
+        }
+
+    def embed(self, input: str, normalize: bool = True, truncate: bool = True, return_count: bool = False) -> List[float]:
+        """Embed a string.
+
+        Args:
+            input: The utf-8 encoded string (or list of strings) to embed.
+            normalize: Whether to L2-normalize each embedding.
+            truncate: Whether to truncate inputs to the context window.
+            return_count: Whether to also return the total token count.
+
+        Returns:
+            A list of embeddings
+        """
+        assert self._ctx.ctx is not None
+        n_embd = self.n_embd()
+        n_ctx = self.n_ctx()
+
         if self.context_params.embedding == False:
             raise RuntimeError(
                 "Llama model must be created with embedding=True to call this method"
@@ -734,48 +768,68 @@ def create_embedding(
         else:
             inputs = input

+        # named _normalize to avoid shadowing the boolean `normalize` parameter
+        def _normalize(x):
+            norm = np.linalg.norm(x)
+            return [v / norm for v in x]
+
+        # reset batch
+        self._batch.reset()
+
+        # decode and fetch embeddings
         data: List[Embedding] = []
+        def decode_batch(n_seq):
+            llama_cpp.llama_kv_cache_clear(self._ctx.ctx)
+            self._ctx.decode(self._batch)
+            self._batch.reset()
+
+            # store embeddings
+            for i in range(n_seq):
+                embedding = llama_cpp.llama_get_embeddings_ith(self._ctx.ctx, i)[:n_embd]
+                if normalize:
+                    embedding = _normalize(embedding)
+                data.append(embedding)
+
+        # init state
         total_tokens = 0
-        for index, input in enumerate(inputs):
-            tokens = self.tokenize(input.encode("utf-8"), special=True)
-            self.reset()
-            self.eval(tokens)
+        p_batch = 0
+        t_batch = 0
+
+        # accumulate batches and encode
+        for text in inputs:
+            tokens = self.tokenize(text.encode("utf-8"))
+            if truncate:
+                tokens = tokens[:n_ctx]
             n_tokens = len(tokens)
-            total_tokens += n_tokens
-            embedding = llama_cpp.llama_get_embeddings(self._ctx.ctx)[
-                : llama_cpp.llama_n_embd(self._model.model)
-            ]

-            data.append(
-                {
-                    "object": "embedding",
-                    "embedding": embedding,
-                    "index": index,
-                }
-            )
-        if self.verbose:
-            llama_cpp.llama_print_timings(self._ctx.ctx)
+            # check for overrun
+            if n_tokens > n_ctx:
+                raise ValueError(
+                    f"Requested tokens ({n_tokens}) exceed context window of {n_ctx}"
+                )

-        return {
-            "object": "list",
-            "data": data,
-            "model": model_name,
-            "usage": {
-                "prompt_tokens": total_tokens,
-                "total_tokens": total_tokens,
-            },
-        }
+            # flush the pending batch if this sequence would not fit
+            if n_tokens + t_batch > self._n_ctx:
+                decode_batch(p_batch)
+                total_tokens += t_batch
+                p_batch = 0
+                t_batch = 0

-    def embed(self, input: str) -> List[float]:
-        """Embed a string.
+            # add to batch
+            self._batch.add_sequence(tokens, p_batch, False)
+            p_batch += 1
+            t_batch += n_tokens

-        Args:
-            input: The utf-8 encoded string to embed.
+        # handle last batch
+        decode_batch(p_batch)
+        total_tokens += t_batch

-        Returns:
-            A list of embeddings
-        """
-        return list(map(float, self.create_embedding(input)["data"][0]["embedding"]))
+        if self.verbose:
+            llama_cpp.llama_print_timings(self._ctx.ctx)
+
+        if return_count:
+            return data, total_tokens
+        else:
+            return data

     def _create_completion(
         self,

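Together, the two files let `Llama.embed` accept either a string or a list of strings, pack as many sequences as fit into the context window before each decode, and read each sequence's vector back via `llama_get_embeddings_ith`. A usage sketch under stated assumptions (the model path is a placeholder; `embedding=True` is required, per the `RuntimeError` in the diff):

    from llama_cpp import Llama

    llm = Llama(model_path="model.gguf", embedding=True)  # hypothetical path

    # one embedding per input string, L2-normalized by default
    vecs = llm.embed(["first document", "second document"])

    # OpenAI-style response with token usage, built on top of embed()
    resp = llm.create_embedding(["first document", "second document"])
    print(len(resp["data"]), resp["usage"]["total_tokens"])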