Commit a335292

Add model_alias option to override model_path in completions. Closes abetlen#39
Parent: 214589e

2 files changed: 34 additions, 9 deletions

llama_cpp/llama.py (14 additions, 5 deletions)
@@ -522,7 +522,7 @@ def generate(
             if tokens_or_none is not None:
                 tokens.extend(tokens_or_none)

-    def create_embedding(self, input: str) -> Embedding:
+    def create_embedding(self, input: str, model: Optional[str] = None) -> Embedding:
         """Embed a string.

         Args:
@@ -532,6 +532,7 @@ def create_embedding(self, input: str) -> Embedding:
             An embedding object.
         """
         assert self.ctx is not None
+        _model: str = model if model is not None else self.model_path

         if self.params.embedding == False:
             raise RuntimeError(
@@ -561,7 +562,7 @@ def create_embedding(self, input: str) -> Embedding:
                     "index": 0,
                 }
             ],
-            "model": self.model_path,
+            "model": _model,
             "usage": {
                 "prompt_tokens": n_tokens,
                 "total_tokens": n_tokens,
@@ -598,6 +599,7 @@ def _create_completion(
         mirostat_mode: int = 0,
         mirostat_tau: float = 5.0,
         mirostat_eta: float = 0.1,
+        model: Optional[str] = None,
     ) -> Union[Iterator[Completion], Iterator[CompletionChunk]]:
         assert self.ctx is not None
         completion_id: str = f"cmpl-{str(uuid.uuid4())}"
@@ -610,6 +612,7 @@ def _create_completion(
         text: bytes = b""
         returned_characters: int = 0
         stop = stop if stop is not None else []
+        _model: str = model if model is not None else self.model_path

         if self.verbose:
             llama_cpp.llama_reset_timings(self.ctx)
@@ -708,7 +711,7 @@ def _create_completion(
                     "id": completion_id,
                     "object": "text_completion",
                     "created": created,
-                    "model": self.model_path,
+                    "model": _model,
                     "choices": [
                         {
                             "text": text[start:].decode("utf-8", errors="ignore"),
@@ -737,7 +740,7 @@ def _create_completion(
                 "id": completion_id,
                 "object": "text_completion",
                 "created": created,
-                "model": self.model_path,
+                "model": _model,
                 "choices": [
                     {
                         "text": text[returned_characters:].decode(
@@ -807,7 +810,7 @@ def _create_completion(
             "id": completion_id,
             "object": "text_completion",
             "created": created,
-            "model": self.model_path,
+            "model": _model,
             "choices": [
                 {
                     "text": text_str,
@@ -842,6 +845,7 @@ def create_completion(
         mirostat_mode: int = 0,
         mirostat_tau: float = 5.0,
         mirostat_eta: float = 0.1,
+        model: Optional[str] = None,
     ) -> Union[Completion, Iterator[CompletionChunk]]:
         """Generate text from a prompt.

@@ -883,6 +887,7 @@ def create_completion(
             mirostat_mode=mirostat_mode,
             mirostat_tau=mirostat_tau,
             mirostat_eta=mirostat_eta,
+            model=model,
         )
         if stream:
             chunks: Iterator[CompletionChunk] = completion_or_chunks
@@ -909,6 +914,7 @@ def __call__(
         mirostat_mode: int = 0,
         mirostat_tau: float = 5.0,
         mirostat_eta: float = 0.1,
+        model: Optional[str] = None,
     ) -> Union[Completion, Iterator[CompletionChunk]]:
         """Generate text from a prompt.

@@ -950,6 +956,7 @@ def __call__(
             mirostat_mode=mirostat_mode,
             mirostat_tau=mirostat_tau,
             mirostat_eta=mirostat_eta,
+            model=model,
         )

     def _convert_text_completion_to_chat(
@@ -1026,6 +1033,7 @@ def create_chat_completion(
         mirostat_mode: int = 0,
         mirostat_tau: float = 5.0,
         mirostat_eta: float = 0.1,
+        model: Optional[str] = None,
     ) -> Union[ChatCompletion, Iterator[ChatCompletionChunk]]:
         """Generate a chat completion from a list of messages.

@@ -1064,6 +1072,7 @@ def create_chat_completion(
             mirostat_mode=mirostat_mode,
             mirostat_tau=mirostat_tau,
             mirostat_eta=mirostat_eta,
+            model=model,
         )
         if stream:
             chunks: Iterator[CompletionChunk] = completion_or_chunks  # type: ignore
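
For reference, a minimal usage sketch (not part of the commit) of the new `model` argument added above. The model path and alias below are placeholders; the argument only overrides the name reported in the "model" field of completion and embedding responses, while inference still uses the file loaded at `model_path`.

# Hypothetical sketch; the model path and alias are placeholders.
from llama_cpp import Llama

llama = Llama(model_path="./models/ggml-model.bin", embedding=True)

# Without `model`, responses report the on-disk model_path.
out = llama("Q: What is two plus two? A:", max_tokens=8)
print(out["model"])  # ./models/ggml-model.bin

# With `model`, only the reported name changes; generation is unaffected.
out = llama("Q: What is two plus two? A:", max_tokens=8, model="my-alias")
print(out["model"])  # my-alias

emb = llama.create_embedding("Hello, world!", model="my-alias")
print(emb["model"])  # my-alias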

llama_cpp/server/app.py (20 additions, 4 deletions)
@@ -16,6 +16,10 @@ class Settings(BaseSettings):
     model: str = Field(
         description="The path to the model to use for generating completions."
     )
+    model_alias: Optional[str] = Field(
+        default=None,
+        description="The alias of the model to use for generating completions.",
+    )
     n_ctx: int = Field(default=2048, ge=1, description="The context size.")
     n_gpu_layers: int = Field(
         default=0,
@@ -64,6 +68,7 @@ class Settings(BaseSettings):

 router = APIRouter()

+settings: Optional[Settings] = None
 llama: Optional[llama_cpp.Llama] = None


@@ -101,6 +106,12 @@ def create_app(settings: Optional[Settings] = None):
     if settings.cache:
         cache = llama_cpp.LlamaCache(capacity_bytes=settings.cache_size)
         llama.set_cache(cache)
+
+    def set_settings(_settings: Settings):
+        global settings
+        settings = _settings
+
+    set_settings(settings)
     return app


@@ -112,6 +123,10 @@ def get_llama():
         yield llama


+def get_settings():
+    yield settings
+
+
 model_field = Field(description="The model to use for generating completions.")

 max_tokens_field = Field(
@@ -236,7 +251,6 @@ def create_completion(
     completion_or_chunks = llama(
         **request.dict(
             exclude={
-                "model",
                 "n",
                 "best_of",
                 "logit_bias",
@@ -274,7 +288,7 @@ class Config:
 def create_embedding(
     request: CreateEmbeddingRequest, llama: llama_cpp.Llama = Depends(get_llama)
 ):
-    return llama.create_embedding(**request.dict(exclude={"model", "user"}))
+    return llama.create_embedding(**request.dict(exclude={"user"}))


 class ChatCompletionRequestMessage(BaseModel):
@@ -335,7 +349,6 @@ def create_chat_completion(
     completion_or_chunks = llama.create_chat_completion(
         **request.dict(
             exclude={
-                "model",
                 "n",
                 "logit_bias",
                 "user",
@@ -378,13 +391,16 @@ class ModelList(TypedDict):

 @router.get("/v1/models", response_model=GetModelResponse)
 def get_models(
+    settings: Settings = Depends(get_settings),
     llama: llama_cpp.Llama = Depends(get_llama),
 ) -> ModelList:
     return {
         "object": "list",
         "data": [
             {
-                "id": llama.model_path,
+                "id": settings.model_alias
+                if settings.model_alias is not None
+                else llama.model_path,
                 "object": "model",
                 "owned_by": "me",
                 "permissions": [],

0 commit comments
