Skip to content

Navigation Menu

Sign in
Appearance settings

Search code, repositories, users, issues, pull requests...

Provide feedback

We read every piece of feedback, and take your input very seriously.

Saved searches

Use saved searches to filter your results more quickly

Appearance settings

Commit 79ba9ed

Browse filesBrowse files
authored
Merge pull request abetlen#125 from Stonelinks/app-server-module-importable
Make app server module importable
2 parents 755f9fa + efe8e6f commit 79ba9ed
Copy full SHA for 79ba9ed

File tree

Expand file treeCollapse file tree

6 files changed

+401
-270
lines changed
Filter options
Expand file treeCollapse file tree

6 files changed

+401
-270
lines changed

‎llama_cpp/server/__init__.py

Copy file name to clipboardExpand all lines: llama_cpp/server/__init__.py
Whitespace-only changes.

‎llama_cpp/server/__main__.py

Copy file name to clipboardExpand all lines: llama_cpp/server/__main__.py
+14-268Lines changed: 14 additions & 268 deletions
Original file line numberDiff line numberDiff line change
@@ -5,283 +5,29 @@
55
```bash
66
pip install fastapi uvicorn sse-starlette
77
export MODEL=../models/7B/...
8-
uvicorn fastapi_server_chat:app --reload
98
```
109
11-
Then visit http://localhost:8000/docs to see the interactive API docs.
12-
13-
"""
14-
import os
15-
import json
16-
from threading import Lock
17-
from typing import List, Optional, Literal, Union, Iterator, Dict
18-
from typing_extensions import TypedDict
19-
20-
import llama_cpp
21-
22-
from fastapi import Depends, FastAPI
23-
from fastapi.middleware.cors import CORSMiddleware
24-
from pydantic import BaseModel, BaseSettings, Field, create_model_from_typeddict
25-
from sse_starlette.sse import EventSourceResponse
26-
27-
28-
class Settings(BaseSettings):
29-
model: str
30-
n_ctx: int = 2048
31-
n_batch: int = 512
32-
n_threads: int = max((os.cpu_count() or 2) // 2, 1)
33-
f16_kv: bool = True
34-
use_mlock: bool = False # This causes a silent failure on platforms that don't support mlock (e.g. Windows) took forever to figure out...
35-
use_mmap: bool = True
36-
embedding: bool = True
37-
last_n_tokens_size: int = 64
38-
logits_all: bool = False
39-
cache: bool = False # WARNING: This is an experimental feature
40-
41-
42-
app = FastAPI(
43-
title="🦙 llama.cpp Python API",
44-
version="0.0.1",
45-
)
46-
app.add_middleware(
47-
CORSMiddleware,
48-
allow_origins=["*"],
49-
allow_credentials=True,
50-
allow_methods=["*"],
51-
allow_headers=["*"],
52-
)
53-
settings = Settings()
54-
llama = llama_cpp.Llama(
55-
settings.model,
56-
f16_kv=settings.f16_kv,
57-
use_mlock=settings.use_mlock,
58-
use_mmap=settings.use_mmap,
59-
embedding=settings.embedding,
60-
logits_all=settings.logits_all,
61-
n_threads=settings.n_threads,
62-
n_batch=settings.n_batch,
63-
n_ctx=settings.n_ctx,
64-
last_n_tokens_size=settings.last_n_tokens_size,
65-
)
66-
if settings.cache:
67-
cache = llama_cpp.LlamaCache()
68-
llama.set_cache(cache)
69-
llama_lock = Lock()
70-
71-
72-
def get_llama():
73-
with llama_lock:
74-
yield llama
75-
76-
77-
class CreateCompletionRequest(BaseModel):
78-
prompt: Union[str, List[str]]
79-
suffix: Optional[str] = Field(None)
80-
max_tokens: int = 16
81-
temperature: float = 0.8
82-
top_p: float = 0.95
83-
echo: bool = False
84-
stop: Optional[List[str]] = []
85-
stream: bool = False
86-
87-
# ignored or currently unsupported
88-
model: Optional[str] = Field(None)
89-
n: Optional[int] = 1
90-
logprobs: Optional[int] = Field(None)
91-
presence_penalty: Optional[float] = 0
92-
frequency_penalty: Optional[float] = 0
93-
best_of: Optional[int] = 1
94-
logit_bias: Optional[Dict[str, float]] = Field(None)
95-
user: Optional[str] = Field(None)
96-
97-
# llama.cpp specific parameters
98-
top_k: int = 40
99-
repeat_penalty: float = 1.1
100-
101-
class Config:
102-
schema_extra = {
103-
"example": {
104-
"prompt": "\n\n### Instructions:\nWhat is the capital of France?\n\n### Response:\n",
105-
"stop": ["\n", "###"],
106-
}
107-
}
108-
109-
110-
CreateCompletionResponse = create_model_from_typeddict(llama_cpp.Completion)
111-
112-
113-
@app.post(
114-
"/v1/completions",
115-
response_model=CreateCompletionResponse,
116-
)
117-
def create_completion(
118-
request: CreateCompletionRequest, llama: llama_cpp.Llama = Depends(get_llama)
119-
):
120-
if isinstance(request.prompt, list):
121-
request.prompt = "".join(request.prompt)
122-
123-
completion_or_chunks = llama(
124-
**request.dict(
125-
exclude={
126-
"model",
127-
"n",
128-
"frequency_penalty",
129-
"presence_penalty",
130-
"best_of",
131-
"logit_bias",
132-
"user",
133-
}
134-
)
135-
)
136-
if request.stream:
137-
chunks: Iterator[llama_cpp.CompletionChunk] = completion_or_chunks # type: ignore
138-
return EventSourceResponse(dict(data=json.dumps(chunk)) for chunk in chunks)
139-
completion: llama_cpp.Completion = completion_or_chunks # type: ignore
140-
return completion
141-
142-
143-
class CreateEmbeddingRequest(BaseModel):
144-
model: Optional[str]
145-
input: str
146-
user: Optional[str]
147-
148-
class Config:
149-
schema_extra = {
150-
"example": {
151-
"input": "The food was delicious and the waiter...",
152-
}
153-
}
154-
155-
156-
CreateEmbeddingResponse = create_model_from_typeddict(llama_cpp.Embedding)
157-
158-
159-
@app.post(
160-
"/v1/embeddings",
161-
response_model=CreateEmbeddingResponse,
162-
)
163-
def create_embedding(
164-
request: CreateEmbeddingRequest, llama: llama_cpp.Llama = Depends(get_llama)
165-
):
166-
return llama.create_embedding(**request.dict(exclude={"model", "user"}))
167-
168-
169-
class ChatCompletionRequestMessage(BaseModel):
170-
role: Union[Literal["system"], Literal["user"], Literal["assistant"]]
171-
content: str
172-
user: Optional[str] = None
173-
174-
175-
class CreateChatCompletionRequest(BaseModel):
176-
model: Optional[str]
177-
messages: List[ChatCompletionRequestMessage]
178-
temperature: float = 0.8
179-
top_p: float = 0.95
180-
stream: bool = False
181-
stop: Optional[List[str]] = []
182-
max_tokens: int = 128
183-
184-
# ignored or currently unsupported
185-
model: Optional[str] = Field(None)
186-
n: Optional[int] = 1
187-
presence_penalty: Optional[float] = 0
188-
frequency_penalty: Optional[float] = 0
189-
logit_bias: Optional[Dict[str, float]] = Field(None)
190-
user: Optional[str] = Field(None)
191-
192-
# llama.cpp specific parameters
193-
repeat_penalty: float = 1.1
194-
195-
class Config:
196-
schema_extra = {
197-
"example": {
198-
"messages": [
199-
ChatCompletionRequestMessage(
200-
role="system", content="You are a helpful assistant."
201-
),
202-
ChatCompletionRequestMessage(
203-
role="user", content="What is the capital of France?"
204-
),
205-
]
206-
}
207-
}
208-
209-
210-
CreateChatCompletionResponse = create_model_from_typeddict(llama_cpp.ChatCompletion)
211-
212-
213-
@app.post(
214-
"/v1/chat/completions",
215-
response_model=CreateChatCompletionResponse,
216-
)
217-
def create_chat_completion(
218-
request: CreateChatCompletionRequest,
219-
llama: llama_cpp.Llama = Depends(get_llama),
220-
) -> Union[llama_cpp.ChatCompletion, EventSourceResponse]:
221-
completion_or_chunks = llama.create_chat_completion(
222-
**request.dict(
223-
exclude={
224-
"model",
225-
"n",
226-
"presence_penalty",
227-
"frequency_penalty",
228-
"logit_bias",
229-
"user",
230-
}
231-
),
232-
)
233-
234-
if request.stream:
235-
236-
async def server_sent_events(
237-
chat_chunks: Iterator[llama_cpp.ChatCompletionChunk],
238-
):
239-
for chat_chunk in chat_chunks:
240-
yield dict(data=json.dumps(chat_chunk))
241-
yield dict(data="[DONE]")
242-
243-
chunks: Iterator[llama_cpp.ChatCompletionChunk] = completion_or_chunks # type: ignore
244-
245-
return EventSourceResponse(
246-
server_sent_events(chunks),
247-
)
248-
completion: llama_cpp.ChatCompletion = completion_or_chunks # type: ignore
249-
return completion
250-
251-
252-
class ModelData(TypedDict):
253-
id: str
254-
object: Literal["model"]
255-
owned_by: str
256-
permissions: List[str]
257-
258-
259-
class ModelList(TypedDict):
260-
object: Literal["list"]
261-
data: List[ModelData]
10+
Then run:
11+
```
12+
uvicorn llama_cpp.server.app:app --reload
13+
```
26214
15+
or
26316
264-
GetModelResponse = create_model_from_typeddict(ModelList)
17+
```
18+
python3 -m llama_cpp.server
19+
```
26520
21+
Then visit http://localhost:8000/docs to see the interactive API docs.
26622
267-
@app.get("/v1/models", response_model=GetModelResponse)
268-
def get_models() -> ModelList:
269-
return {
270-
"object": "list",
271-
"data": [
272-
{
273-
"id": llama.model_path,
274-
"object": "model",
275-
"owned_by": "me",
276-
"permissions": [],
277-
}
278-
],
279-
}
23+
"""
24+
import os
25+
import uvicorn
28026

27+
from llama_cpp.server.app import app, init_llama
28128

28229
if __name__ == "__main__":
283-
import os
284-
import uvicorn
30+
init_llama()
28531

28632
uvicorn.run(
28733
app, host=os.getenv("HOST", "localhost"), port=int(os.getenv("PORT", 8000))

0 commit comments

Comments
0 (0)
Morty Proxy This is a proxified and sanitized view of the page, visit original site.