Commit 7499fc1

Merge pull request abetlen#126 from Stonelinks/deprecate-example-server
Deprecate example server
2 parents 1971514 + 0fcc25c commit 7499fc1

1 file changed: +19 −244 lines changed
examples/high_level_api/fastapi_server.py

````diff
@@ -4,259 +4,34 @@
 
 ```bash
 pip install fastapi uvicorn sse-starlette
-export MODEL=../models/7B/ggml-model.bin
-uvicorn fastapi_server_chat:app --reload
+export MODEL=../models/7B/...
 ```
 
-Then visit http://localhost:8000/docs to see the interactive API docs.
-
-"""
-import os
-import json
-from typing import List, Optional, Literal, Union, Iterator, Dict
-from typing_extensions import TypedDict
-
-import llama_cpp
-
-from fastapi import FastAPI
-from fastapi.middleware.cors import CORSMiddleware
-from pydantic import BaseModel, BaseSettings, Field, create_model_from_typeddict
-from sse_starlette.sse import EventSourceResponse
-
-
-class Settings(BaseSettings):
-    model: str
-    n_ctx: int = 2048
-    n_batch: int = 8
-    n_threads: int = int(os.cpu_count() / 2) or 1
-    f16_kv: bool = True
-    use_mlock: bool = False  # This causes a silent failure on platforms that don't support mlock (e.g. Windows) took forever to figure out...
-    embedding: bool = True
-    last_n_tokens_size: int = 64
-
-
-app = FastAPI(
-    title="🦙 llama.cpp Python API",
-    version="0.0.1",
-)
-app.add_middleware(
-    CORSMiddleware,
-    allow_origins=["*"],
-    allow_credentials=True,
-    allow_methods=["*"],
-    allow_headers=["*"],
-)
-settings = Settings()
-llama = llama_cpp.Llama(
-    settings.model,
-    f16_kv=settings.f16_kv,
-    use_mlock=settings.use_mlock,
-    embedding=settings.embedding,
-    n_threads=settings.n_threads,
-    n_batch=settings.n_batch,
-    n_ctx=settings.n_ctx,
-    last_n_tokens_size=settings.last_n_tokens_size,
-)
-
-
-class CreateCompletionRequest(BaseModel):
-    prompt: str
-    suffix: Optional[str] = Field(None)
-    max_tokens: int = 16
-    temperature: float = 0.8
-    top_p: float = 0.95
-    echo: bool = False
-    stop: List[str] = []
-    stream: bool = False
-
-    # ignored or currently unsupported
-    model: Optional[str] = Field(None)
-    n: Optional[int] = 1
-    logprobs: Optional[int] = Field(None)
-    presence_penalty: Optional[float] = 0
-    frequency_penalty: Optional[float] = 0
-    best_of: Optional[int] = 1
-    logit_bias: Optional[Dict[str, float]] = Field(None)
-    user: Optional[str] = Field(None)
-
-    # llama.cpp specific parameters
-    top_k: int = 40
-    repeat_penalty: float = 1.1
-
-    class Config:
-        schema_extra = {
-            "example": {
-                "prompt": "\n\n### Instructions:\nWhat is the capital of France?\n\n### Response:\n",
-                "stop": ["\n", "###"],
-            }
-        }
-
-
-CreateCompletionResponse = create_model_from_typeddict(llama_cpp.Completion)
-
-
-@app.post(
-    "/v1/completions",
-    response_model=CreateCompletionResponse,
-)
-def create_completion(request: CreateCompletionRequest):
-    if request.stream:
-        chunks: Iterator[llama_cpp.CompletionChunk] = llama(**request.dict())  # type: ignore
-        return EventSourceResponse(dict(data=json.dumps(chunk)) for chunk in chunks)
-    return llama(
-        **request.dict(
-            exclude={
-                "model",
-                "n",
-                "logprobs",
-                "frequency_penalty",
-                "presence_penalty",
-                "best_of",
-                "logit_bias",
-                "user",
-            }
-        )
-    )
-
-
-class CreateEmbeddingRequest(BaseModel):
-    model: Optional[str]
-    input: str
-    user: Optional[str]
-
-    class Config:
-        schema_extra = {
-            "example": {
-                "input": "The food was delicious and the waiter...",
-            }
-        }
-
-
-CreateEmbeddingResponse = create_model_from_typeddict(llama_cpp.Embedding)
-
-
-@app.post(
-    "/v1/embeddings",
-    response_model=CreateEmbeddingResponse,
-)
-def create_embedding(request: CreateEmbeddingRequest):
-    return llama.create_embedding(**request.dict(exclude={"model", "user"}))
-
-
-class ChatCompletionRequestMessage(BaseModel):
-    role: Union[Literal["system"], Literal["user"], Literal["assistant"]]
-    content: str
-    user: Optional[str] = None
-
-
-class CreateChatCompletionRequest(BaseModel):
-    model: Optional[str]
-    messages: List[ChatCompletionRequestMessage]
-    temperature: float = 0.8
-    top_p: float = 0.95
-    stream: bool = False
-    stop: List[str] = []
-    max_tokens: int = 128
-
-    # ignored or currently unsupported
-    model: Optional[str] = Field(None)
-    n: Optional[int] = 1
-    presence_penalty: Optional[float] = 0
-    frequency_penalty: Optional[float] = 0
-    logit_bias: Optional[Dict[str, float]] = Field(None)
-    user: Optional[str] = Field(None)
-
-    # llama.cpp specific parameters
-    repeat_penalty: float = 1.1
-
-    class Config:
-        schema_extra = {
-            "example": {
-                "messages": [
-                    ChatCompletionRequestMessage(
-                        role="system", content="You are a helpful assistant."
-                    ),
-                    ChatCompletionRequestMessage(
-                        role="user", content="What is the capital of France?"
-                    ),
-                ]
-            }
-        }
-
-
-CreateChatCompletionResponse = create_model_from_typeddict(llama_cpp.ChatCompletion)
-
-
-@app.post(
-    "/v1/chat/completions",
-    response_model=CreateChatCompletionResponse,
-)
-async def create_chat_completion(
-    request: CreateChatCompletionRequest,
-) -> Union[llama_cpp.ChatCompletion, EventSourceResponse]:
-    completion_or_chunks = llama.create_chat_completion(
-        **request.dict(
-            exclude={
-                "model",
-                "n",
-                "presence_penalty",
-                "frequency_penalty",
-                "logit_bias",
-                "user",
-            }
-        ),
-    )
-
-    if request.stream:
-
-        async def server_sent_events(
-            chat_chunks: Iterator[llama_cpp.ChatCompletionChunk],
-        ):
-            for chat_chunk in chat_chunks:
-                yield dict(data=json.dumps(chat_chunk))
-            yield dict(data="[DONE]")
-
-        chunks: Iterator[llama_cpp.ChatCompletionChunk] = completion_or_chunks  # type: ignore
-
-        return EventSourceResponse(
-            server_sent_events(chunks),
-        )
-    completion: llama_cpp.ChatCompletion = completion_or_chunks  # type: ignore
-    return completion
-
-
-class ModelData(TypedDict):
-    id: str
-    object: Literal["model"]
-    owned_by: str
-    permissions: List[str]
+Then run:
+```
+uvicorn llama_cpp.server.app:app --reload
+```
 
+or
 
-class ModelList(TypedDict):
-    object: Literal["list"]
-    data: List[ModelData]
+```
+python3 -m llama_cpp.server
+```
 
+Then visit http://localhost:8000/docs to see the interactive API docs.
 
-GetModelResponse = create_model_from_typeddict(ModelList)
 
+To actually see the implementation of the server, see llama_cpp/server/app.py
 
-@app.get("/v1/models", response_model=GetModelResponse)
-def get_models() -> ModelList:
-    return {
-        "object": "list",
-        "data": [
-            {
-                "id": llama.model_path,
-                "object": "model",
-                "owned_by": "me",
-                "permissions": [],
-            }
-        ],
-    }
+"""
+import os
+import uvicorn
 
+from llama_cpp.server.app import create_app
 
 if __name__ == "__main__":
-    import os
-    import uvicorn
+    app = create_app()
 
-    uvicorn.run(app, host=os.getenv("HOST", "localhost"), port=os.getenv("PORT", 8000))
+    uvicorn.run(
+        app, host=os.getenv("HOST", "localhost"), port=int(os.getenv("PORT", 8000))
+    )
````
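Since the example now simply defers to the packaged server, a client talks to `llama_cpp.server` the same way it talked to the old example app. The sketch below is not part of this commit: it assumes the server is already running locally on port 8000 (e.g. `python3 -m llama_cpp.server` with `MODEL` set) and that `/v1/completions` still accepts the same fields as the removed `CreateCompletionRequest` model (`prompt`, `max_tokens`, `stop`, ...).

```python
# Minimal client sketch (not part of the commit). Assumes the packaged server was
# started with `python3 -m llama_cpp.server` and exposes the same /v1/completions
# request shape as the removed example above.
import json
import urllib.request

payload = {
    "prompt": "\n\n### Instructions:\nWhat is the capital of France?\n\n### Response:\n",
    "max_tokens": 16,
    "stop": ["\n", "###"],
}

req = urllib.request.Request(
    "http://localhost:8000/v1/completions",
    data=json.dumps(payload).encode("utf-8"),
    headers={"Content-Type": "application/json"},
    method="POST",
)

with urllib.request.urlopen(req) as resp:
    completion = json.loads(resp.read())

# The response follows the OpenAI-style completion format (llama_cpp.Completion).
print(completion["choices"][0]["text"])
```

Setting `"stream": true` in the payload should instead yield server-sent events, mirroring the `EventSourceResponse` branch that the removed example implemented.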
