Commit 9eafc4c

Refactor server to use factory
1 parent dd9ad1c commit 9eafc4c

3 files changed: +47 −31 lines
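This commit replaces the module-level FastAPI app, which was configured at import time via init_llama(), with an application factory, create_app(), and moves route registration onto an APIRouter. The snippet below is a minimal, generic sketch of the factory pattern itself, not the project's code; the endpoint and names are illustrative only.

# Generic sketch of the "application factory" pattern adopted here:
# the app is built and configured inside a function and returned,
# rather than configured on a module-level singleton at import time.
from fastapi import FastAPI


def create_app(title: str = "example") -> FastAPI:
    app = FastAPI(title=title)

    @app.get("/health")
    def health():
        return {"status": "ok"}

    return app


app = create_app()  # built on demand by an entrypoint, a test, or an ASGI server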

llama_cpp/server/__main__.py

+2 −2 (2 additions & 2 deletions)
@@ -24,10 +24,10 @@
 import os
 import uvicorn
 
-from llama_cpp.server.app import app, init_llama
+from llama_cpp.server.app import create_app
 
 if __name__ == "__main__":
-    init_llama()
+    app = create_app()
 
     uvicorn.run(
         app, host=os.getenv("HOST", "localhost"), port=int(os.getenv("PORT", 8000))
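With the factory in place, the same server can also be assembled programmatically instead of through environment variables. A hedged sketch, assuming the llama_cpp.server.app module as of this commit; the model path is a placeholder, not a file shipped with the repo.

# Build the app from an explicit Settings and run it with uvicorn.
# The model path below is a placeholder for illustration only.
import uvicorn

from llama_cpp.server.app import Settings, create_app

app = create_app(Settings(model="./models/ggml-model-q4_0.bin"))
uvicorn.run(app, host="127.0.0.1", port=8000)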

llama_cpp/server/app.py

+30 −21 (30 additions & 21 deletions)
@@ -2,18 +2,18 @@
 import json
 from threading import Lock
 from typing import List, Optional, Union, Iterator, Dict
-from typing_extensions import TypedDict, Literal
+from typing_extensions import TypedDict, Literal, Annotated
 
 import llama_cpp
 
-from fastapi import Depends, FastAPI
+from fastapi import Depends, FastAPI, APIRouter
 from fastapi.middleware.cors import CORSMiddleware
 from pydantic import BaseModel, BaseSettings, Field, create_model_from_typeddict
 from sse_starlette.sse import EventSourceResponse
 
 
 class Settings(BaseSettings):
-    model: str = os.environ.get("MODEL", "null")
+    model: str
     n_ctx: int = 2048
     n_batch: int = 512
     n_threads: int = max((os.cpu_count() or 2) // 2, 1)
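Dropping the hard-coded os.environ.get("MODEL", "null") default does not lose the environment-variable behaviour: Settings extends pydantic's BaseSettings, which in pydantic v1 (where BaseSettings lives in pydantic itself) populates fields from environment variables with a case-insensitive name match, so MODEL still fills model. A small stand-alone sketch of that mechanism, with placeholder names and paths.

# Stand-alone illustration of BaseSettings reading a required field from the
# environment (pydantic v1 semantics); ExampleSettings and the path are
# placeholders, not the project's classes.
import os

from pydantic import BaseSettings


class ExampleSettings(BaseSettings):
    model: str  # required, no default: must come from kwargs or the env
    n_ctx: int = 2048


os.environ["MODEL"] = "/tmp/example-model.bin"
settings = ExampleSettings()
assert settings.model == "/tmp/example-model.bin"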
@@ -27,25 +27,29 @@ class Settings(BaseSettings):
     vocab_only: bool = False
 
 
-app = FastAPI(
-    title="🦙 llama.cpp Python API",
-    version="0.0.1",
-)
-app.add_middleware(
-    CORSMiddleware,
-    allow_origins=["*"],
-    allow_credentials=True,
-    allow_methods=["*"],
-    allow_headers=["*"],
-)
+router = APIRouter()
+
+llama: Optional[llama_cpp.Llama] = None
 
-llama: llama_cpp.Llama = None
-def init_llama(settings: Settings = None):
+
+def create_app(settings: Optional[Settings] = None):
     if settings is None:
         settings = Settings()
+    app = FastAPI(
+        title="🦙 llama.cpp Python API",
+        version="0.0.1",
+    )
+    app.add_middleware(
+        CORSMiddleware,
+        allow_origins=["*"],
+        allow_credentials=True,
+        allow_methods=["*"],
+        allow_headers=["*"],
+    )
+    app.include_router(router)
     global llama
     llama = llama_cpp.Llama(
-        settings.model,
+        model_path=settings.model,
         f16_kv=settings.f16_kv,
         use_mlock=settings.use_mlock,
         use_mmap=settings.use_mmap,
@@ -60,12 +64,17 @@ def init_llama(settings: Settings = None):
     if settings.cache:
         cache = llama_cpp.LlamaCache()
         llama.set_cache(cache)
+    return app
+
 
 llama_lock = Lock()
+
+
 def get_llama():
     with llama_lock:
         yield llama
 
+
 class CreateCompletionRequest(BaseModel):
     prompt: Union[str, List[str]]
     suffix: Optional[str] = Field(None)
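get_llama is a generator dependency that holds llama_lock while it yields the shared Llama instance, so requests that depend on it are serialized around the single model. The sketch below shows how such a yield-based dependency is typically consumed with Depends; the route and resource names are illustrative, not taken from this file.

# Illustrative wiring of a lock-guarded, yield-based dependency; generic
# names stand in for the module-level Llama instance.
from threading import Lock
from typing import Iterator, Optional

from fastapi import APIRouter, Depends

router = APIRouter()

shared_resource: Optional[str] = None  # stand-in for the shared Llama
resource_lock = Lock()


def get_resource() -> Iterator[Optional[str]]:
    # Held for the whole request, so concurrent requests take turns.
    with resource_lock:
        yield shared_resource


@router.post("/v1/echo")
def echo(resource: Optional[str] = Depends(get_resource)):
    return {"resource": resource}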
@@ -102,7 +111,7 @@ class Config:
 CreateCompletionResponse = create_model_from_typeddict(llama_cpp.Completion)
 
 
-@app.post(
+@router.post(
     "/v1/completions",
     response_model=CreateCompletionResponse,
 )

@@ -148,7 +157,7 @@ class Config:
 CreateEmbeddingResponse = create_model_from_typeddict(llama_cpp.Embedding)
 
 
-@app.post(
+@router.post(
     "/v1/embeddings",
     response_model=CreateEmbeddingResponse,
 )

@@ -202,7 +211,7 @@ class Config:
 CreateChatCompletionResponse = create_model_from_typeddict(llama_cpp.ChatCompletion)
 
 
-@app.post(
+@router.post(
     "/v1/chat/completions",
     response_model=CreateChatCompletionResponse,
 )

@@ -256,7 +265,7 @@ class ModelList(TypedDict):
 GetModelResponse = create_model_from_typeddict(ModelList)
 
 
-@app.get("/v1/models", response_model=GetModelResponse)
+@router.get("/v1/models", response_model=GetModelResponse)
 def get_models() -> ModelList:
     return {
         "object": "list",

tests/test_llama.py

+15 −8 (15 additions & 8 deletions)
@@ -22,9 +22,11 @@ def test_llama_patch(monkeypatch):
     ## Set up mock function
     def mock_eval(*args, **kwargs):
         return 0
-
+
     def mock_get_logits(*args, **kwargs):
-        return (llama_cpp.c_float * n_vocab)(*[llama_cpp.c_float(0) for _ in range(n_vocab)])
+        return (llama_cpp.c_float * n_vocab)(
+            *[llama_cpp.c_float(0) for _ in range(n_vocab)]
+        )
 
     monkeypatch.setattr("llama_cpp.llama_cpp.llama_eval", mock_eval)
     monkeypatch.setattr("llama_cpp.llama_cpp.llama_get_logits", mock_get_logits)

@@ -88,6 +90,7 @@ def mock_sample(*args, **kwargs):
 def test_llama_pickle():
     import pickle
     import tempfile
+
     fp = tempfile.TemporaryFile()
     llama = llama_cpp.Llama(model_path=MODEL, vocab_only=True)
     pickle.dump(llama, fp)

@@ -101,6 +104,7 @@ def test_llama_pickle():
 
     assert llama.detokenize(llama.tokenize(text)) == text
 
+
 def test_utf8(monkeypatch):
     llama = llama_cpp.Llama(model_path=MODEL, vocab_only=True)
     n_vocab = int(llama_cpp.llama_n_vocab(llama.ctx))

@@ -110,7 +114,9 @@ def mock_eval(*args, **kwargs):
         return 0
 
     def mock_get_logits(*args, **kwargs):
-        return (llama_cpp.c_float * n_vocab)(*[llama_cpp.c_float(0) for _ in range(n_vocab)])
+        return (llama_cpp.c_float * n_vocab)(
+            *[llama_cpp.c_float(0) for _ in range(n_vocab)]
+        )
 
     monkeypatch.setattr("llama_cpp.llama_cpp.llama_eval", mock_eval)
     monkeypatch.setattr("llama_cpp.llama_cpp.llama_get_logits", mock_get_logits)
@@ -143,11 +149,12 @@ def mock_sample(*args, **kwargs):
 
 def test_llama_server():
     from fastapi.testclient import TestClient
-    from llama_cpp.server.app import app, init_llama, Settings
-    s = Settings()
-    s.model = MODEL
-    s.vocab_only = True
-    init_llama(s)
+    from llama_cpp.server.app import create_app, Settings
+
+    settings = Settings()
+    settings.model = MODEL
+    settings.vocab_only = True
+    app = create_app(settings)
     client = TestClient(app)
     response = client.get("/v1/models")
     assert response.json() == {
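The updated test shows the main payoff of the factory: a test builds its own app from explicit Settings and exercises it with FastAPI's TestClient, with no import-time side effects. A hedged usage sketch along the same lines; the model path is a placeholder rather than the repository's test fixture.

# Exercise a factory-built app in isolation with TestClient; placeholder
# settings values, not the repo's actual test model.
from fastapi.testclient import TestClient

from llama_cpp.server.app import Settings, create_app

settings = Settings(model="./models/ggml-vocab.bin", vocab_only=True)
app = create_app(settings)

client = TestClient(app)
response = client.get("/v1/models")
assert response.status_code == 200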
