1414import anyio
1515from anyio .streams .memory import MemoryObjectSendStream
1616from starlette .concurrency import run_in_threadpool , iterate_in_threadpool
17- from fastapi import Depends , FastAPI , APIRouter , Request , Response
17+ from fastapi import Depends , FastAPI , APIRouter , Request , Response , HTTPException , status
1818from fastapi .middleware import Middleware
1919from fastapi .middleware .cors import CORSMiddleware
2020from fastapi .responses import JSONResponse
2121from fastapi .routing import APIRoute
22+ from fastapi .security import HTTPBearer
2223from pydantic import BaseModel , Field
2324from pydantic_settings import BaseSettings
2425from sse_starlette .sse import EventSourceResponse
@@ -163,6 +164,10 @@ class Settings(BaseSettings):
163164 default = True ,
164165 description = "Whether to interrupt requests when a new request is received." ,
165166 )
167+ api_key : Optional [str ] = Field (
168+ default = None ,
169+ description = "API key for authentication. If set all requests need to be authenticated."
170+ )
166171
167172
168173class ErrorResponse (TypedDict ):
@@ -314,6 +319,9 @@ async def custom_route_handler(request: Request) -> Response:
314319 elapsed_time_ms = int ((time .perf_counter () - start_sec ) * 1000 )
315320 response .headers ["openai-processing-ms" ] = f"{ elapsed_time_ms } "
316321 return response
322+ except HTTPException as unauthorized :
323+ # api key check failed
324+ raise unauthorized
317325 except Exception as exc :
318326 json_body = await request .json ()
319327 try :
@@ -658,6 +666,27 @@ def _logit_bias_tokens_to_input_ids(
658666 return to_bias
659667
660668
669+ # Setup Bearer authentication scheme
670+ bearer_scheme = HTTPBearer (auto_error = False )
671+
672+
673+ async def authenticate (settings : Settings = Depends (get_settings ), authorization : Optional [str ] = Depends (bearer_scheme )):
674+ # Skip API key check if it's not set in settings
675+ if settings .api_key is None :
676+ return True
677+
678+ # check bearer credentials against the api_key
679+ if authorization and authorization .credentials == settings .api_key :
680+ # api key is valid
681+ return authorization .credentials
682+
683+ # raise http error 401
684+ raise HTTPException (
685+ status_code = status .HTTP_401_UNAUTHORIZED ,
686+ detail = "Invalid API key" ,
687+ )
688+
689+
661690@router .post (
662691 "/v1/completions" ,
663692 summary = "Completion"
@@ -667,6 +696,7 @@ async def create_completion(
667696 request : Request ,
668697 body : CreateCompletionRequest ,
669698 llama : llama_cpp .Llama = Depends (get_llama ),
699+ authenticated : str = Depends (authenticate ),
670700) -> llama_cpp .Completion :
671701 if isinstance (body .prompt , list ):
672702 assert len (body .prompt ) <= 1
@@ -740,7 +770,9 @@ class CreateEmbeddingRequest(BaseModel):
740770 summary = "Embedding"
741771)
742772async def create_embedding (
743- request : CreateEmbeddingRequest , llama : llama_cpp .Llama = Depends (get_llama )
773+ request : CreateEmbeddingRequest ,
774+ llama : llama_cpp .Llama = Depends (get_llama ),
775+ authenticated : str = Depends (authenticate ),
744776):
745777 return await run_in_threadpool (
746778 llama .create_embedding , ** request .model_dump (exclude = {"user" })
@@ -834,6 +866,7 @@ async def create_chat_completion(
834866 body : CreateChatCompletionRequest ,
835867 llama : llama_cpp .Llama = Depends (get_llama ),
836868 settings : Settings = Depends (get_settings ),
869+ authenticated : str = Depends (authenticate ),
837870) -> llama_cpp .ChatCompletion :
838871 exclude = {
839872 "n" ,
@@ -895,6 +928,7 @@ class ModelList(TypedDict):
895928@router .get ("/v1/models" , summary = "Models" )
896929async def get_models (
897930 settings : Settings = Depends (get_settings ),
931+ authenticated : str = Depends (authenticate ),
898932) -> ModelList :
899933 assert llama is not None
900934 return {
0 commit comments