Commit e783f1c

feat: make embedding support list of string as input

Makes the /v1/embedding route similar to the OpenAI API.

1 parent: 01a010b

2 files changed: 30 additions, 18 deletions
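
To illustrate the widened API, here is a minimal usage sketch; the model path and the `embedding=True` constructor flag are illustrative assumptions, not something this commit touches:

```python
# Minimal sketch of the widened create_embedding API. The model path and
# the embedding=True flag are illustrative assumptions, not part of this commit.
from llama_cpp import Llama

llm = Llama(model_path="./models/ggml-model.bin", embedding=True)

# Before this commit, only a single string was accepted:
single = llm.create_embedding("Hello, world!")

# After it, a list of strings also works, as in the OpenAI API:
batch = llm.create_embedding(["Hello, world!", "Goodbye, world!"])

print(len(batch["data"]))               # one embedding entry per input
print(batch["usage"]["prompt_tokens"])  # token count summed across inputs
```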

llama_cpp/llama.py (29 additions, 17 deletions)

@@ -531,7 +531,9 @@ def generate(
             if tokens_or_none is not None:
                 tokens.extend(tokens_or_none)

-    def create_embedding(self, input: str, model: Optional[str] = None) -> Embedding:
+    def create_embedding(
+        self, input: Union[str, List[str]], model: Optional[str] = None
+    ) -> Embedding:
         """Embed a string.

         Args:
@@ -551,30 +553,40 @@ def create_embedding(self, input: str, model: Optional[str] = None) -> Embedding
         if self.verbose:
             llama_cpp.llama_reset_timings(self.ctx)

-        tokens = self.tokenize(input.encode("utf-8"))
-        self.reset()
-        self.eval(tokens)
-        n_tokens = len(tokens)
-        embedding = llama_cpp.llama_get_embeddings(self.ctx)[
-            : llama_cpp.llama_n_embd(self.ctx)
-        ]
+        if isinstance(input, str):
+            inputs = [input]
+        else:
+            inputs = input

-        if self.verbose:
-            llama_cpp.llama_print_timings(self.ctx)
+        data = []
+        total_tokens = 0
+        for input in inputs:
+            tokens = self.tokenize(input.encode("utf-8"))
+            self.reset()
+            self.eval(tokens)
+            n_tokens = len(tokens)
+            total_tokens += n_tokens
+            embedding = llama_cpp.llama_get_embeddings(self.ctx)[
+                : llama_cpp.llama_n_embd(self.ctx)
+            ]

-        return {
-            "object": "list",
-            "data": [
+            if self.verbose:
+                llama_cpp.llama_print_timings(self.ctx)
+            data.append(
                 {
                     "object": "embedding",
                     "embedding": embedding,
                     "index": 0,
                 }
-            ],
-            "model": model_name,
+            )
+
+        return {
+            "object": "list",
+            "data": data,
+            "model": self.model_path,
             "usage": {
-                "prompt_tokens": n_tokens,
-                "total_tokens": n_tokens,
+                "prompt_tokens": total_tokens,
+                "total_tokens": total_tokens,
             },
         }
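
Each input is evaluated sequentially: `self.reset()` runs before `self.eval(tokens)` so one string's context does not leak into the next, and the per-input token counts are summed for the usage block. A sketch of the resulting response shape for two inputs, with made-up values:

```python
# Illustrative response shape for create_embedding(["first", "second"]);
# the floats and token counts are invented, not real model output.
response = {
    "object": "list",
    "data": [
        {"object": "embedding", "embedding": [0.12, -0.34], "index": 0},
        {"object": "embedding", "embedding": [0.56, -0.78], "index": 0},
    ],
    "model": "./models/ggml-model.bin",  # self.model_path
    "usage": {"prompt_tokens": 4, "total_tokens": 4},  # summed over inputs
}
```

Note that each entry's `index` is hardcoded to 0 here, whereas the OpenAI API numbers entries by their position in the input list.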

llama_cpp/server/app.py (1 addition, 1 deletion)

@@ -275,7 +275,7 @@ async def server_sent_events(

 class CreateEmbeddingRequest(BaseModel):
     model: Optional[str] = model_field
-    input: str = Field(description="The input to embed.")
+    input: Union[str, List[str]] = Field(description="The input to embed.")
     user: Optional[str]

     class Config:
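
With the request schema widened, clients can post either form to the server. A minimal sketch using the `requests` library; the host, port, and exact route path are assumptions based on the OpenAI-style routing this commit targets, not confirmed by the diff:

```python
# Hedged sketch of a batched embedding request to the local server.
# Host, port, and the /v1/embeddings path are assumptions, not confirmed
# by this commit; adjust them to your server configuration.
import requests

resp = requests.post(
    "http://localhost:8000/v1/embeddings",
    json={"input": ["first text", "second text"]},
)
body = resp.json()
print(len(body["data"]))  # 2: one embedding per input string
print(body["usage"])      # prompt_tokens / total_tokens summed over inputs
```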