Commit f300d43

Merge branch 'main' of https://github.com/abetlen/llama-cpp-python into main

2 parents: c336f78 + d7a6791
2 files changed: +123 −34 lines

llama_cpp/_internals.py — 22 additions & 0 deletions
@@ -510,6 +510,14 @@ def __del__(self):
         self._llama_batch_free(self.batch)
         self.batch = None
 
+    def n_tokens(self) -> int:
+        assert self.batch is not None
+        return self.batch.n_tokens
+
+    def reset(self):
+        assert self.batch is not None
+        self.batch.n_tokens = 0
+
     def set_batch(self, batch: Sequence[int], n_past: int, logits_all: bool):
         assert self.batch is not None
         n_tokens = len(batch)
@@ -522,6 +530,20 @@ def set_batch(self, batch: Sequence[int], n_past: int, logits_all: bool):
             self.batch.logits[i] = logits_all
         self.batch.logits[n_tokens - 1] = True
 
+    def add_sequence(self, batch: Sequence[int], seq_id: int, logits_all: bool):
+        assert self.batch is not None
+        n_tokens = len(batch)
+        n_tokens0 = self.batch.n_tokens
+        self.batch.n_tokens += n_tokens
+        for i in range(n_tokens):
+            j = n_tokens0 + i
+            self.batch.token[j] = batch[i]
+            self.batch.pos[j] = i
+            self.batch.seq_id[j][0] = seq_id
+            self.batch.n_seq_id[j] = 1
+            self.batch.logits[j] = logits_all
+        self.batch.logits[n_tokens - 1] = True
+
 
 class _LlamaTokenDataArray:
     def __init__(self, *, n_vocab: int):
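For intuition, here is a minimal pure-Python sketch of the packing layout that the new add_sequence writes into the underlying llama_batch. MockBatch is a made-up, illustrative class (plain lists stand in for the real ctypes arrays, and capacity is an assumed constructor argument): each appended sequence restarts its positions at 0 and tags every slot with its own seq_id, so several inputs can share one decode call.

from typing import Sequence

class MockBatch:
    """Plain-list stand-in for llama_batch (illustrative, not the library's type)."""

    def __init__(self, capacity: int):
        self.n_tokens = 0
        self.token = [0] * capacity
        self.pos = [0] * capacity
        self.seq_id = [[0] for _ in range(capacity)]
        self.n_seq_id = [0] * capacity
        self.logits = [False] * capacity

    def add_sequence(self, batch: Sequence[int], seq_id: int, logits_all: bool):
        # Append after the current fill level; positions restart at 0 for
        # each new sequence, and every slot carries exactly one seq_id.
        n_tokens = len(batch)
        n_tokens0 = self.n_tokens
        self.n_tokens += n_tokens
        for i in range(n_tokens):
            j = n_tokens0 + i
            self.token[j] = batch[i]
            self.pos[j] = i
            self.seq_id[j][0] = seq_id
            self.n_seq_id[j] = 1
            self.logits[j] = logits_all
        # Mirrors the committed code: the final logits flag indexes from the
        # start of the batch, not the start of the appended sequence.
        self.logits[n_tokens - 1] = True

b = MockBatch(capacity=8)
b.add_sequence([11, 12, 13], seq_id=0, logits_all=False)
b.add_sequence([21, 22], seq_id=1, logits_all=False)
print(b.token[:b.n_tokens])   # [11, 12, 13, 21, 22]
print(b.pos[:b.n_tokens])     # [0, 1, 2, 0, 1]
print(b.seq_id[:b.n_tokens])  # [[0], [0], [0], [1], [1]]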

llama_cpp/llama.py — 101 additions & 34 deletions
@@ -717,10 +717,53 @@ def create_embedding(
         Returns:
             An embedding object.
         """
-        assert self._ctx.ctx is not None
         assert self._model.model is not None
         model_name: str = model if model is not None else self.model_path
 
+        # get numeric embeddings
+        embeds: List[List[float]]
+        total_tokens: int
+        embeds, total_tokens = self.embed(input, return_count=True)  # type: ignore
+
+        # convert to CreateEmbeddingResponse
+        data: List[Embedding] = [
+            {
+                "object": "embedding",
+                "embedding": emb,
+                "index": idx,
+            }
+            for idx, emb in enumerate(embeds)
+        ]
+
+        return {
+            "object": "list",
+            "data": data,
+            "model": model_name,
+            "usage": {
+                "prompt_tokens": total_tokens,
+                "total_tokens": total_tokens,
+            },
+        }
+
+    def embed(
+        self,
+        input: Union[str, List[str]],
+        normalize: bool = True,
+        truncate: bool = True,
+        return_count: bool = False,
+    ):
+        """Embed a string.
+
+        Args:
+            input: The utf-8 encoded string to embed.
+
+        Returns:
+            A list of embeddings
+        """
+        assert self._ctx.ctx is not None
+        n_embd = self.n_embd()
+        n_ctx = self.n_ctx()
+
         if self.context_params.embedding == False:
             raise RuntimeError(
                 "Llama model must be created with embedding=True to call this method"
@@ -734,48 +777,72 @@ def create_embedding(
         else:
             inputs = input
 
-        data: List[Embedding] = []
+        # reset batch
+        self._batch.reset()
+
+        # decode and fetch embeddings
+        data: List[List[float]] = []
+        def decode_batch(sizes: List[int]):
+            assert self._ctx.ctx is not None
+            llama_cpp.llama_kv_cache_clear(self._ctx.ctx)
+            self._ctx.decode(self._batch)
+            self._batch.reset()
+
+            # store embeddings
+            for i, s in enumerate(sizes):
+                embedding = llama_cpp.llama_get_embeddings_ith(self._ctx.ctx, i)[
+                    :n_embd
+                ]
+                norm = np.linalg.norm(embedding) if normalize else s
+                embedding: List[float] = [v / float(norm) for v in embedding]
+                data.append(embedding)
+
+        # init state
         total_tokens = 0
-        for index, input in enumerate(inputs):
-            tokens = self.tokenize(input.encode("utf-8"), special=True)
-            self.reset()
-            self.eval(tokens)
+        t_batch = 0
+        s_sizes: List[int] = []
+
+        # accumulate batches and encode
+        for text in inputs:
+            tokens = self.tokenize(text.encode("utf-8"))
+            if truncate:
+                tokens = tokens[:n_ctx]
+
             n_tokens = len(tokens)
             total_tokens += n_tokens
-            embedding = llama_cpp.llama_get_embeddings(self._ctx.ctx)[
-                : llama_cpp.llama_n_embd(self._model.model)
-            ]
 
-            data.append(
-                {
-                    "object": "embedding",
-                    "embedding": embedding,
-                    "index": index,
-                }
-            )
+            # check for overrun
+            if n_tokens > n_ctx:
+                raise ValueError(
+                    f"Requested tokens ({n_tokens}) exceed context window of {n_ctx}"
+                )
+
+            # time to eval batch
+            if t_batch + n_tokens > self._n_ctx:
+                decode_batch(s_sizes)
+                t_batch = 0
+                s_sizes = []
+
+            # add to batch
+            self._batch.add_sequence(tokens, len(s_sizes), False)
+            t_batch += n_tokens
+            s_sizes.append(n_tokens)
+
+        # handle last batch
+        decode_batch(s_sizes)
+
         if self.verbose:
             llama_cpp.llama_print_timings(self._ctx.ctx)
 
-        return {
-            "object": "list",
-            "data": data,
-            "model": model_name,
-            "usage": {
-                "prompt_tokens": total_tokens,
-                "total_tokens": total_tokens,
-            },
-        }
-
-    def embed(self, input: str) -> List[float]:
-        """Embed a string.
+        output = data[0] if isinstance(input, str) else data
 
-        Args:
-            input: The utf-8 encoded string to embed.
+        llama_cpp.llama_kv_cache_clear(self._ctx.ctx)
+        self.reset()
 
-        Returns:
-            A list of embeddings
-        """
-        return list(map(float, self.create_embedding(input)["data"][0]["embedding"]))
+        if return_count:
+            return output, total_tokens
+        else:
+            return output
 
     def _create_completion(
         self,
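A hedged usage sketch of the reworked embedding API, assuming a local embedding-capable GGUF model (the path below is a placeholder): embed now accepts a single string or a list of strings, packs the inputs into shared batches, and returns raw float vectors (L2-normalized by default), while create_embedding wraps the same call in the OpenAI-style response envelope.

from llama_cpp import Llama

# Placeholder path; any embedding-capable GGUF model should work.
llm = Llama(model_path="./models/example.gguf", embedding=True)

# Single string -> one vector (List[float]).
vec = llm.embed("hello world")

# List of strings -> list of vectors; sequences that fit together in the
# context window are decoded in one call via _batch.add_sequence.
vecs = llm.embed(["first document", "second document"])

# return_count=True also returns the total prompt token count, which
# create_embedding uses to fill the usage field.
vecs, total_tokens = llm.embed(["a", "b"], return_count=True)

# OpenAI-style envelope built on top of embed().
resp = llm.create_embedding(["first document", "second document"])
print(resp["usage"]["prompt_tokens"])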
