@@ -522,7 +522,7 @@ def generate(
             if tokens_or_none is not None:
                 tokens.extend(tokens_or_none)
 
-    def create_embedding(self, input: str) -> Embedding:
+    def create_embedding(self, input: str, model: Optional[str] = None) -> Embedding:
         """Embed a string.
 
         Args:
@@ -532,6 +532,7 @@ def create_embedding(self, input: str) -> Embedding:
             An embedding object.
         """
         assert self.ctx is not None
+        _model: str = model if model is not None else self.model_path
 
         if self.params.embedding == False:
             raise RuntimeError(
@@ -561,7 +562,7 @@ def create_embedding(self, input: str) -> Embedding:
                     "index": 0,
                 }
             ],
-            "model": self.model_path,
+            "model": _model,
             "usage": {
                 "prompt_tokens": n_tokens,
                 "total_tokens": n_tokens,
@@ -598,6 +599,7 @@ def _create_completion(
         mirostat_mode: int = 0,
         mirostat_tau: float = 5.0,
         mirostat_eta: float = 0.1,
+        model: Optional[str] = None,
     ) -> Union[Iterator[Completion], Iterator[CompletionChunk]]:
         assert self.ctx is not None
         completion_id: str = f"cmpl-{str(uuid.uuid4())}"
@@ -610,6 +612,7 @@ def _create_completion(
         text: bytes = b""
         returned_characters: int = 0
         stop = stop if stop is not None else []
+        _model: str = model if model is not None else self.model_path
 
         if self.verbose:
             llama_cpp.llama_reset_timings(self.ctx)
@@ -708,7 +711,7 @@ def _create_completion(
                     "id": completion_id,
                     "object": "text_completion",
                     "created": created,
-                    "model": self.model_path,
+                    "model": _model,
                     "choices": [
                         {
                             "text": text[start:].decode("utf-8", errors="ignore"),
@@ -737,7 +740,7 @@ def _create_completion(
             "id": completion_id,
             "object": "text_completion",
             "created": created,
-            "model": self.model_path,
+            "model": _model,
             "choices": [
                 {
                     "text": text[returned_characters:].decode(
@@ -807,7 +810,7 @@ def _create_completion(
             "id": completion_id,
             "object": "text_completion",
             "created": created,
-            "model": self.model_path,
+            "model": _model,
             "choices": [
                 {
                     "text": text_str,
@@ -842,6 +845,7 @@ def create_completion(
         mirostat_mode: int = 0,
         mirostat_tau: float = 5.0,
         mirostat_eta: float = 0.1,
+        model: Optional[str] = None,
     ) -> Union[Completion, Iterator[CompletionChunk]]:
         """Generate text from a prompt.
 
@@ -883,6 +887,7 @@ def create_completion(
             mirostat_mode=mirostat_mode,
             mirostat_tau=mirostat_tau,
             mirostat_eta=mirostat_eta,
+            model=model,
         )
         if stream:
             chunks: Iterator[CompletionChunk] = completion_or_chunks
@@ -909,6 +914,7 @@ def __call__(
         mirostat_mode: int = 0,
         mirostat_tau: float = 5.0,
         mirostat_eta: float = 0.1,
+        model: Optional[str] = None,
     ) -> Union[Completion, Iterator[CompletionChunk]]:
         """Generate text from a prompt.
 
@@ -950,6 +956,7 @@ def __call__(
             mirostat_mode=mirostat_mode,
             mirostat_tau=mirostat_tau,
             mirostat_eta=mirostat_eta,
+            model=model,
         )
 
     def _convert_text_completion_to_chat(
@@ -1026,6 +1033,7 @@ def create_chat_completion(
         mirostat_mode: int = 0,
         mirostat_tau: float = 5.0,
         mirostat_eta: float = 0.1,
+        model: Optional[str] = None,
     ) -> Union[ChatCompletion, Iterator[ChatCompletionChunk]]:
         """Generate a chat completion from a list of messages.
 
@@ -1064,6 +1072,7 @@ def create_chat_completion(
             mirostat_mode=mirostat_mode,
             mirostat_tau=mirostat_tau,
             mirostat_eta=mirostat_eta,
+            model=model,
        )
         if stream:
             chunks: Iterator[CompletionChunk] = completion_or_chunks  # type: ignore
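A similar sketch for the completion paths: the override threads from `__call__`, `create_completion`, and `create_chat_completion` into `_create_completion`, so streamed chunks report it as well (same assumptions as above; prompts and alias are illustrative):

```python
from llama_cpp import Llama

llama = Llama(model_path="./models/ggml-model.bin")

# Non-streaming: the returned Completion dict carries the override.
completion = llama("Q: Name the planets. A:", max_tokens=32, model="my-model-alias")
assert completion["model"] == "my-model-alias"

# Streaming: each CompletionChunk yielded by _create_completion uses _model too.
for chunk in llama("Q: Name the moons. A:", max_tokens=32, stream=True, model="my-model-alias"):
    assert chunk["model"] == "my-model-alias"
```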