"""
import os
import json
+from threading import Lock
from typing import List, Optional, Literal, Union, Iterator, Dict
from typing_extensions import TypedDict

import llama_cpp

-from fastapi import FastAPI
+from fastapi import Depends, FastAPI
from fastapi.middleware.cors import CORSMiddleware
from pydantic import BaseModel, BaseSettings, Field, create_model_from_typeddict
from sse_starlette.sse import EventSourceResponse
@@ -59,6 +60,13 @@ class Settings(BaseSettings):
    n_ctx=settings.n_ctx,
    last_n_tokens_size=settings.last_n_tokens_size,
)
+llama_lock = Lock()
+
+
+def get_llama():
+    with llama_lock:
+        yield llama
+


class CreateCompletionRequest(BaseModel):
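The hunk above introduces the core pattern of this change: get_llama() is a FastAPI generator dependency that acquires a module-level threading.Lock, yields the single shared Llama instance to the endpoint, and lets the lock go when FastAPI tears the dependency down after the request, so concurrent requests are served one at a time instead of racing on the non-thread-safe model. Below is a minimal standalone sketch of the same lock-guarded dependency pattern; the SharedModel class and /generate route are made-up stand-ins, not code from this PR.

# Sketch of a lock-guarded FastAPI dependency (assumed names, not project code).
from threading import Lock

from fastapi import Depends, FastAPI

app = FastAPI()


class SharedModel:
    # Hypothetical stand-in for the single shared llama_cpp.Llama instance.
    def generate(self, prompt: str) -> str:
        return prompt.upper()


model = SharedModel()
model_lock = Lock()


def get_model():
    # Generator dependency: the lock is held while the request is handled
    # and released when the dependency is torn down.
    with model_lock:
        yield model


@app.post("/generate")
def generate(prompt: str, model: SharedModel = Depends(get_model)):
    # Only one request can reach this line at a time.
    return {"text": model.generate(prompt)}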
@@ -101,7 +109,7 @@ class Config:
    "/v1/completions",
    response_model=CreateCompletionResponse,
)
-def create_completion(request: CreateCompletionRequest):
+def create_completion(request: CreateCompletionRequest, llama: llama_cpp.Llama = Depends(get_llama)):
    if isinstance(request.prompt, list):
        request.prompt = "".join(request.prompt)

@@ -146,7 +154,7 @@ class Config:
    "/v1/embeddings",
    response_model=CreateEmbeddingResponse,
)
-def create_embedding(request: CreateEmbeddingRequest):
+def create_embedding(request: CreateEmbeddingRequest, llama: llama_cpp.Llama = Depends(get_llama)):
    return llama.create_embedding(**request.dict(exclude={"model", "user"}))


@@ -200,6 +208,7 @@ class Config:
)
def create_chat_completion(
    request: CreateChatCompletionRequest,
+    llama: llama_cpp.Llama = Depends(get_llama),
) -> Union[llama_cpp.ChatCompletion, EventSourceResponse]:
    completion_or_chunks = llama.create_chat_completion(
        **request.dict(
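With the dependency added to all three routes, the endpoints look unchanged from a client's point of view; the locking is purely server-side. A hedged usage sketch (not part of the diff), assuming the server is running locally on port 8000 and that the requests package is installed:

# Fire two completion requests concurrently; both succeed, but the lock inside
# get_llama() makes the server run the model for one request at a time.
import concurrent.futures

import requests


def complete(prompt: str) -> dict:
    resp = requests.post(
        "http://localhost:8000/v1/completions",
        json={"prompt": prompt, "max_tokens": 16},
    )
    resp.raise_for_status()
    return resp.json()


with concurrent.futures.ThreadPoolExecutor(max_workers=2) as pool:
    results = list(pool.map(complete, ["Hello,", "Bonjour,"]))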