@@ -67,6 +67,12 @@ def _load_shared_library(lib_base_name):
67
67
_lib = _load_shared_library (_lib_base_name )
68
68
69
69
# C types
70
+ LLAMA_FILE_VERSION = ctypes .c_int (1 )
71
+ LLAMA_FILE_MAGIC = b"ggjt"
72
+ LLAMA_FILE_MAGIC_UNVERSIONED = b"ggml"
73
+ LLAMA_SESSION_MAGIC = b"ggsn"
74
+ LLAMA_SESSION_VERSION = ctypes .c_int (0 )
75
+
70
76
llama_context_p = c_void_p
71
77
72
78
@@ -77,13 +83,24 @@ def _load_shared_library(lib_base_name):
77
83
class llama_token_data (Structure ):
78
84
_fields_ = [
79
85
("id" , llama_token ), # token id
86
+ ("logit" , c_float ), # log-odds of the token
80
87
("p" , c_float ), # probability of the token
81
- ("plog" , c_float ), # log probability of the token
82
88
]
83
89
84
90
85
91
llama_token_data_p = POINTER (llama_token_data )
86
92
93
+
94
+ class llama_token_data_array (Structure ):
95
+ _fields_ = [
96
+ ("data" , llama_token_data_p ),
97
+ ("size" , c_size_t ),
98
+ ("sorted" , c_bool ),
99
+ ]
100
+
101
+
102
+ llama_token_data_array_p = POINTER (llama_token_data_array )
103
+
87
104
llama_progress_callback = ctypes .CFUNCTYPE (None , c_float , c_void_p )
88
105
89
106
@@ -118,7 +135,7 @@ class llama_context_params(Structure):
118
135
4
119
136
) # tok_embeddings.weight and output.weight are F16
120
137
LLAMA_FTYPE_MOSTLY_Q4_2 = ctypes .c_int (5 ) # except 1d tensors
121
- LLAMA_FTYPE_MOSTYL_Q4_3 = ctypes .c_int (6 ) # except 1d tensors
138
+ # LLAMA_FTYPE_MOSTYL_Q4_3 = ctypes.c_int(6) # except 1d tensors
122
139
LLAMA_FTYPE_MOSTYL_Q8_0 = ctypes .c_int (7 ) # except 1d tensors
123
140
LLAMA_FTYPE_MOSTYL_Q5_0 = ctypes .c_int (8 ) # except 1d tensors
124
141
LLAMA_FTYPE_MOSTYL_Q5_1 = ctypes .c_int (9 ) # except 1d tensors
@@ -401,31 +418,214 @@ def llama_token_eos() -> llama_token:
401
418
_lib .llama_token_eos .restype = llama_token
402
419
403
420
404
- # TODO: improve the last_n_tokens interface ?
405
- def llama_sample_top_p_top_k (
421
+ def llama_token_nl () -> llama_token :
422
+ return _lib .llama_token_nl ()
423
+
424
+
425
+ _lib .llama_token_nl .argtypes = []
426
+ _lib .llama_token_nl .restype = llama_token
427
+
428
+
429
+ # Sampling functions
430
+ def llama_sample_repetition_penalty (
431
+ ctx : llama_context_p ,
432
+ candidates ,
433
+ last_tokens_data ,
434
+ last_tokens_size : c_int ,
435
+ penalty : c_float ,
436
+ ) -> llama_token :
437
+ return _lib .llama_sample_repetition_penalty (
438
+ ctx , candidates , last_tokens_data , last_tokens_size , penalty
439
+ )
440
+
441
+
442
+ _lib .llama_sample_repetition_penalty .argtypes = [
443
+ llama_context_p ,
444
+ llama_token_data_array_p ,
445
+ llama_token_p ,
446
+ c_int ,
447
+ c_float ,
448
+ ]
449
+ _lib .llama_sample_repetition_penalty .restype = llama_token
450
+
451
+
452
+ # LLAMA_API void llama_sample_frequency_and_presence_penalties(struct llama_context * ctx, llama_token_data_array * candidates, llama_token * last_tokens, size_t last_tokens_size, float alpha_frequency, float alpha_presence);
453
+ def llama_sample_frequency_and_presence_penalties (
406
454
ctx : llama_context_p ,
407
- last_n_tokens_data , # type: Array[llama_token]
408
- last_n_tokens_size : c_int ,
409
- top_k : c_int ,
410
- top_p : c_float ,
411
- temp : c_float ,
412
- repeat_penalty : c_float ,
455
+ candidates ,
456
+ last_tokens_data ,
457
+ last_tokens_size : c_int ,
458
+ alpha_frequency : c_float ,
459
+ alpha_presence : c_float ,
413
460
) -> llama_token :
414
- return _lib .llama_sample_top_p_top_k (
415
- ctx , last_n_tokens_data , last_n_tokens_size , top_k , top_p , temp , repeat_penalty
461
+ return _lib .llama_sample_frequency_and_presence_penalties (
462
+ ctx ,
463
+ candidates ,
464
+ last_tokens_data ,
465
+ last_tokens_size ,
466
+ alpha_frequency ,
467
+ alpha_presence ,
416
468
)
417
469
418
470
419
- _lib .llama_sample_top_p_top_k .argtypes = [
471
+ _lib .llama_sample_frequency_and_presence_penalties .argtypes = [
420
472
llama_context_p ,
473
+ llama_token_data_array_p ,
421
474
llama_token_p ,
422
475
c_int ,
476
+ c_float ,
477
+ c_float ,
478
+ ]
479
+ _lib .llama_sample_frequency_and_presence_penalties .restype = llama_token
480
+
481
+
482
+ # LLAMA_API void llama_sample_softmax(struct llama_context * ctx, llama_token_data_array * candidates);
483
+ def llama_sample_softmax (ctx : llama_context_p , candidates ) -> llama_token :
484
+ return _lib .llama_sample_softmax (ctx , candidates )
485
+
486
+
487
+ _lib .llama_sample_softmax .argtypes = [
488
+ llama_context_p ,
489
+ llama_token_data_array_p ,
490
+ ]
491
+ _lib .llama_sample_softmax .restype = llama_token
492
+
493
+
494
+ # LLAMA_API void llama_sample_top_k(struct llama_context * ctx, llama_token_data_array * candidates, int k, size_t min_keep = 1);
495
+ def llama_sample_top_k (
496
+ ctx : llama_context_p , candidates , k : c_int , min_keep : c_int
497
+ ) -> llama_token :
498
+ return _lib .llama_sample_top_k (ctx , candidates , k , min_keep )
499
+
500
+
501
+ _lib .llama_sample_top_k .argtypes = [
502
+ llama_context_p ,
503
+ llama_token_data_array_p ,
504
+ c_int ,
505
+ c_int ,
506
+ ]
507
+
508
+
509
+ # LLAMA_API void llama_sample_top_p(struct llama_context * ctx, llama_token_data_array * candidates, float p, size_t min_keep = 1);
510
+ def llama_sample_top_p (
511
+ ctx : llama_context_p , candidates , p : c_float , min_keep : c_int
512
+ ) -> llama_token :
513
+ return _lib .llama_sample_top_p (ctx , candidates , p , min_keep )
514
+
515
+
516
+ _lib .llama_sample_top_p .argtypes = [
517
+ llama_context_p ,
518
+ llama_token_data_array_p ,
519
+ c_float ,
520
+ c_int ,
521
+ ]
522
+ _lib .llama_sample_top_p .restype = llama_token
523
+
524
+
525
+ # LLAMA_API void llama_sample_tail_free(struct llama_context * ctx, llama_token_data_array * candidates, float z, size_t min_keep = 1);
526
+ def llama_sample_tail_free (
527
+ ctx : llama_context_p , candidates , z : c_float , min_keep : c_int
528
+ ) -> llama_token :
529
+ return _lib .llama_sample_tail_free (ctx , candidates , z , min_keep )
530
+
531
+
532
+ _lib .llama_sample_tail_free .argtypes = [
533
+ llama_context_p ,
534
+ llama_token_data_array_p ,
535
+ c_float ,
536
+ c_int ,
537
+ ]
538
+ _lib .llama_sample_tail_free .restype = llama_token
539
+
540
+
541
+ # LLAMA_API void llama_sample_typical(struct llama_context * ctx, llama_token_data_array * candidates, float p, size_t min_keep = 1);
542
+ def llama_sample_typical (
543
+ ctx : llama_context_p , candidates , p : c_float , min_keep : c_int
544
+ ) -> llama_token :
545
+ return _lib .llama_sample_typical (ctx , candidates , p , min_keep )
546
+
547
+
548
+ _lib .llama_sample_typical .argtypes = [
549
+ llama_context_p ,
550
+ llama_token_data_array_p ,
551
+ c_float ,
423
552
c_int ,
553
+ ]
554
+ _lib .llama_sample_typical .restype = llama_token
555
+
556
+
557
+ # LLAMA_API void llama_sample_temperature(struct llama_context * ctx, llama_token_data_array * candidates, float temp);
558
+ def llama_sample_temperature (
559
+ ctx : llama_context_p , candidates , temp : c_float
560
+ ) -> llama_token :
561
+ return _lib .llama_sample_temperature (ctx , candidates , temp )
562
+
563
+
564
+ _lib .llama_sample_temperature .argtypes = [
565
+ llama_context_p ,
566
+ llama_token_data_array_p ,
424
567
c_float ,
568
+ ]
569
+ _lib .llama_sample_temperature .restype = llama_token
570
+
571
+
572
+ # LLAMA_API llama_token llama_sample_token_mirostat(struct llama_context * ctx, llama_token_data_array * candidates, float tau, float eta, int m, float * mu);
573
+ def llama_sample_token_mirostat (
574
+ ctx : llama_context_p , candidates , tau : c_float , eta : c_float , m : c_int , mu
575
+ ) -> llama_token :
576
+ return _lib .llama_sample_token_mirostat (ctx , candidates , tau , eta , m , mu )
577
+
578
+
579
+ _lib .llama_sample_token_mirostat .argtypes = [
580
+ llama_context_p ,
581
+ llama_token_data_array_p ,
582
+ c_float ,
583
+ c_float ,
584
+ c_int ,
585
+ POINTER (c_float ),
586
+ ]
587
+ _lib .llama_sample_token_mirostat .restype = llama_token
588
+
589
+
590
+ # LLAMA_API llama_token llama_sample_token_mirostat_v2(struct llama_context * ctx, llama_token_data_array * candidates, float tau, float eta, float * mu);
591
+ def llama_sample_token_mirostat_v2 (
592
+ ctx : llama_context_p , candidates , tau : c_float , eta : c_float , mu
593
+ ) -> llama_token :
594
+ return _lib .llama_sample_token_mirostat_v2 (ctx , candidates , tau , eta , mu )
595
+
596
+
597
+ _lib .llama_sample_token_mirostat_v2 .argtypes = [
598
+ llama_context_p ,
599
+ llama_token_data_array_p ,
425
600
c_float ,
426
601
c_float ,
602
+ POINTER (c_float ),
603
+ ]
604
+ _lib .llama_sample_token_mirostat_v2 .restype = llama_token
605
+
606
+
607
+ # LLAMA_API llama_token llama_sample_token_greedy(struct llama_context * ctx, llama_token_data_array * candidates);
608
+ def llama_sample_token_greedy (ctx : llama_context_p , candidates ) -> llama_token :
609
+ return _lib .llama_sample_token_greedy (ctx , candidates )
610
+
611
+
612
+ _lib .llama_sample_token_greedy .argtypes = [
613
+ llama_context_p ,
614
+ llama_token_data_array_p ,
615
+ ]
616
+ _lib .llama_sample_token_greedy .restype = llama_token
617
+
618
+
619
+ # LLAMA_API llama_token llama_sample_token(struct llama_context * ctx, llama_token_data_array * candidates);
620
+ def llama_sample_token (ctx : llama_context_p , candidates ) -> llama_token :
621
+ return _lib .llama_sample_token (ctx , candidates )
622
+
623
+
624
+ _lib .llama_sample_token .argtypes = [
625
+ llama_context_p ,
626
+ llama_token_data_array_p ,
427
627
]
428
- _lib .llama_sample_top_p_top_k .restype = llama_token
628
+ _lib .llama_sample_token .restype = llama_token
429
629
430
630
431
631
# Performance information
0 commit comments