Skip to content

Navigation Menu

Sign in
Appearance settings

Search code, repositories, users, issues, pull requests...

Provide feedback

We read every piece of feedback, and take your input very seriously.

Saved searches

Use saved searches to filter your results more quickly

Appearance settings

Commit 5dc9dd7

Browse files
RefractAI, S, S, slaren, ggerganov
authored
llama : add Command R Plus support (#6491)
* Add Command R Plus GGUF
* Add Command R Plus GGUF
* Loading works up to LayerNorm2D
* Export new tensors in 1D so they are not quantized.
* Fix embedding layer based on Noeda's example
* Whitespace
* Add line
* Fix unexpected tokens on MPS. Re-add F16 fix. (Noeda)
* dranger003: Fix block index overflow in CUDA dequantizing.
* Reverted blocked multiplication code as it still has issues and could affect other Llama arches
* export norms as f32
* fix overflow issues during quant and other cleanup
* Type convention
  Co-authored-by: Georgi Gerganov <ggerganov@gmail.com>
* dranger003: Fix more int overflow during quant.
--------
Co-authored-by: S <seast@Ss-Mac-Studio.local>
Co-authored-by: S <s@example.com>
Co-authored-by: slaren <slarengh@gmail.com>
Co-authored-by: Georgi Gerganov <ggerganov@gmail.com>
1 parent e11a899 commit 5dc9dd7
Copy full SHA for 5dc9dd7
Expand file tree / Collapse file tree

16 files changed

+366
-326
lines changed

‎convert-hf-to-gguf.py

Copy file name to clipboard · Expand all lines: convert-hf-to-gguf.py
+1 -1 · Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -160,7 +160,7 @@ def write_tensors(self):
160160
data = data.astype(np.float32)
161161

162162
# TODO: Why cant we use these float16 as-is? There should be not reason to store float16 as float32
163-
if self.ftype == 1 and data_dtype == np.float16 and n_dims == 1:
163+
if self.ftype == 1 and data_dtype == np.float16 and (n_dims == 1 or new_name.endswith("_norm.weight")):
164164
data = data.astype(np.float32)
165165

166166
# if f16 desired, convert any float32 2-dim weight tensors to float16

‎ggml-cuda.cu

Copy file name to clipboard · Expand all lines: ggml-cuda.cu
+3 -3 · Lines changed: 3 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -1225,7 +1225,7 @@ static void ggml_cuda_op_mul_mat_cublas(
12251225

12261226
// the main device has a larger memory buffer to hold the results from all GPUs
12271227
// ldc == nrows of the matrix that cuBLAS writes into
1228-
int ldc = id == ctx.device ? ne0 : row_diff;
1228+
int64_t ldc = id == ctx.device ? ne0 : row_diff;
12291229

12301230
const int compute_capability = ggml_cuda_info().devices[id].cc;
12311231

@@ -1377,8 +1377,8 @@ static void ggml_cuda_op_mul_mat(
13771377
const int64_t ne0 = dst->ne[0];
13781378
const int64_t ne1 = dst->ne[1];
13791379

1380-
const int nb2 = dst->nb[2];
1381-
const int nb3 = dst->nb[3];
1380+
const int64_t nb2 = dst->nb[2];
1381+
const int64_t nb3 = dst->nb[3];
13821382

13831383
GGML_ASSERT(ggml_backend_buffer_is_cuda(dst->buffer));
13841384
GGML_ASSERT(ggml_backend_buffer_is_cuda(src1->buffer));

‎ggml-cuda/common.cuh

Copy file name to clipboard · Expand all lines: ggml-cuda/common.cuh
+1 -1 · Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -394,7 +394,7 @@ static __device__ __forceinline__ int __dp4a(const int a, const int b, int c) {
394394
// TODO: move to ggml-common.h
395395
static const __device__ int8_t kvalues_iq4nl[16] = {-127, -104, -83, -65, -49, -35, -22, -10, 1, 13, 25, 38, 53, 69, 89, 113};
396396

397-
typedef void (*dequantize_kernel_t)(const void * vx, const int ib, const int iqs, dfloat2 & v);
397+
typedef void (*dequantize_kernel_t)(const void * vx, const int64_t ib, const int iqs, dfloat2 & v);
398398

399399

400400
//////////////////////

‎ggml-cuda/convert.cu

Copy file name to clipboard · Expand all lines: ggml-cuda/convert.cu
+37 -37 · Lines changed: 37 additions & 37 deletions
Original file line number | Diff line number | Diff line change
@@ -4,14 +4,14 @@
44
#define CUDA_Q8_0_NE_ALIGN 2048
55

66
template <int qk, int qr, dequantize_kernel_t dequantize_kernel, typename dst_t>
7-
static __global__ void dequantize_block(const void * __restrict__ vx, dst_t * __restrict__ y, const int k) {
8-
const int i = 2*(blockDim.x*blockIdx.x + threadIdx.x);
7+
static __global__ void dequantize_block(const void * __restrict__ vx, dst_t * __restrict__ y, const int64_t k) {
8+
const int64_t i = 2*(blockDim.x*blockIdx.x + threadIdx.x);
99

1010
if (i >= k) {
1111
return;
1212
}
1313

14-
const int ib = i/qk; // block index
14+
const int64_t ib = i/qk; // block index
1515
const int iqs = (i%qk)/qr; // quant index
1616
const int iybs = i - i%qk; // y block start index
1717
const int y_offset = qr == 1 ? 1 : qk/2;
@@ -25,7 +25,7 @@ static __global__ void dequantize_block(const void * __restrict__ vx, dst_t * __
2525
}
2626

2727
template <bool need_check>
28-
static __global__ void dequantize_block_q8_0_f16(const void * __restrict__ vx, half * __restrict__ y, const int k) {
28+
static __global__ void dequantize_block_q8_0_f16(const void * __restrict__ vx, half * __restrict__ y, const int64_t k) {
2929
#if __CUDA_ARCH__ >= CC_PASCAL
3030
constexpr int nint = CUDA_Q8_0_NE_ALIGN/sizeof(int) + WARP_SIZE;
3131

@@ -68,13 +68,13 @@ static __global__ void dequantize_block_q8_0_f16(const void * __restrict__ vx, h
6868
template<typename dst_t>
6969
static __global__ void dequantize_block_q4_0(const void * __restrict__ vx, dst_t * __restrict__ yy, int nb32) {
7070

71-
const int i = blockIdx.x;
71+
const int64_t i = blockIdx.x;
7272

7373
// assume 32 threads
7474
const int tid = threadIdx.x;
7575
const int il = tid/8;
7676
const int ir = tid%8;
77-
const int ib = 8*i + ir;
77+
const int64_t ib = 8*i + ir;
7878
if (ib >= nb32) {
7979
return;
8080
}
@@ -96,13 +96,13 @@ static __global__ void dequantize_block_q4_0(const void * __restrict__ vx, dst_t
9696
template<typename dst_t>
9797
static __global__ void dequantize_block_q4_1(const void * __restrict__ vx, dst_t * __restrict__ yy, int nb32) {
9898

99-
const int i = blockIdx.x;
99+
const int64_t i = blockIdx.x;
100100

101101
// assume 32 threads
102102
const int tid = threadIdx.x;
103103
const int il = tid/8;
104104
const int ir = tid%8;
105-
const int ib = 8*i + ir;
105+
const int64_t ib = 8*i + ir;
106106
if (ib >= nb32) {
107107
return;
108108
}
@@ -313,14 +313,14 @@ template<typename dst_t>
313313
static __global__ void dequantize_block_q6_K(const void * __restrict__ vx, dst_t * __restrict__ yy) {
314314
const block_q6_K * x = (const block_q6_K *) vx;
315315

316-
const int i = blockIdx.x;
316+
const int64_t i = blockIdx.x;
317317
#if QK_K == 256
318318

319319
// assume 64 threads - this is very slightly better than the one below
320-
const int tid = threadIdx.x;
321-
const int ip = tid/32; // ip is 0 or 1
322-
const int il = tid - 32*ip; // 0...32
323-
const int is = 8*ip + il/16;
320+
const int64_t tid = threadIdx.x;
321+
const int64_t ip = tid/32; // ip is 0 or 1
322+
const int64_t il = tid - 32*ip; // 0...32
323+
const int64_t is = 8*ip + il/16;
324324

325325
dst_t * y = yy + i*QK_K + 128*ip + il;
326326

@@ -337,9 +337,9 @@ static __global__ void dequantize_block_q6_K(const void * __restrict__ vx, dst_t
337337
#else
338338

339339
// assume 32 threads
340-
const int tid = threadIdx.x;
341-
const int ip = tid/16; // 0 or 1
342-
const int il = tid - 16*ip; // 0...15
340+
const int64_t tid = threadIdx.x;
341+
const int64_t ip = tid/16; // 0 or 1
342+
const int64_t il = tid - 16*ip; // 0...15
343343

344344
dst_t * y = yy + i*QK_K + 16*ip + il;
345345

@@ -571,12 +571,12 @@ static __global__ void dequantize_block_iq4_xs(const void * __restrict__ vx, dst
571571
#endif
572572

573573
template <int qk, int qr, dequantize_kernel_t dequantize_kernel, typename dst_t>
574-
static void dequantize_block_cuda(const void * __restrict__ vx, dst_t * __restrict__ y, const int k, cudaStream_t stream) {
574+
static void dequantize_block_cuda(const void * __restrict__ vx, dst_t * __restrict__ y, const int64_t k, cudaStream_t stream) {
575575
const int num_blocks = (k + 2*CUDA_DEQUANTIZE_BLOCK_SIZE - 1) / (2*CUDA_DEQUANTIZE_BLOCK_SIZE);
576576
dequantize_block<qk, qr, dequantize_kernel><<<num_blocks, CUDA_DEQUANTIZE_BLOCK_SIZE, 0, stream>>>(vx, y, k);
577577
}
578578

579-
static void dequantize_block_q8_0_f16_cuda(const void * __restrict__ vx, half * __restrict__ y, const int k, cudaStream_t stream) {
579+
static void dequantize_block_q8_0_f16_cuda(const void * __restrict__ vx, half * __restrict__ y, const int64_t k, cudaStream_t stream) {
580580
const int num_blocks = (k + CUDA_Q8_0_NE_ALIGN - 1) / CUDA_Q8_0_NE_ALIGN;
581581
if (k % CUDA_Q8_0_NE_ALIGN == 0) {
582582
const bool need_check = false;
@@ -588,7 +588,7 @@ static void dequantize_block_q8_0_f16_cuda(const void * __restrict__ vx, half *
588588
}
589589

590590
template<typename dst_t>
591-
static void dequantize_row_q2_K_cuda(const void * vx, dst_t * y, const int k, cudaStream_t stream) {
591+
static void dequantize_row_q2_K_cuda(const void * vx, dst_t * y, const int64_t k, cudaStream_t stream) {
592592
const int nb = k / QK_K;
593593
#if QK_K == 256
594594
dequantize_block_q2_K<<<nb, 64, 0, stream>>>(vx, y);
@@ -598,7 +598,7 @@ static void dequantize_row_q2_K_cuda(const void * vx, dst_t * y, const int k, cu
598598
}
599599

600600
template<typename dst_t>
601-
static void dequantize_row_q3_K_cuda(const void * vx, dst_t * y, const int k, cudaStream_t stream) {
601+
static void dequantize_row_q3_K_cuda(const void * vx, dst_t * y, const int64_t k, cudaStream_t stream) {
602602
const int nb = k / QK_K;
603603
#if QK_K == 256
604604
dequantize_block_q3_K<<<nb, 64, 0, stream>>>(vx, y);
@@ -608,27 +608,27 @@ static void dequantize_row_q3_K_cuda(const void * vx, dst_t * y, const int k, cu
608608
}
609609

610610
template<typename dst_t>
611-
static void dequantize_row_q4_0_cuda(const void * vx, dst_t * y, const int k, cudaStream_t stream) {
611+
static void dequantize_row_q4_0_cuda(const void * vx, dst_t * y, const int64_t k, cudaStream_t stream) {
612612
const int nb32 = k / 32;
613613
const int nb = (k + 255) / 256;
614614
dequantize_block_q4_0<<<nb, 32, 0, stream>>>(vx, y, nb32);
615615
}
616616

617617
template<typename dst_t>
618-
static void dequantize_row_q4_1_cuda(const void * vx, dst_t * y, const int k, cudaStream_t stream) {
618+
static void dequantize_row_q4_1_cuda(const void * vx, dst_t * y, const int64_t k, cudaStream_t stream) {
619619
const int nb32 = k / 32;
620620
const int nb = (k + 255) / 256;
621621
dequantize_block_q4_1<<<nb, 32, 0, stream>>>(vx, y, nb32);
622622
}
623623

624624
template<typename dst_t>
625-
static void dequantize_row_q4_K_cuda(const void * vx, dst_t * y, const int k, cudaStream_t stream) {
625+
static void dequantize_row_q4_K_cuda(const void * vx, dst_t * y, const int64_t k, cudaStream_t stream) {
626626
const int nb = k / QK_K;
627627
dequantize_block_q4_K<<<nb, 32, 0, stream>>>(vx, y);
628628
}
629629

630630
template<typename dst_t>
631-
static void dequantize_row_q5_K_cuda(const void * vx, dst_t * y, const int k, cudaStream_t stream) {
631+
static void dequantize_row_q5_K_cuda(const void * vx, dst_t * y, const int64_t k, cudaStream_t stream) {
632632
const int nb = k / QK_K;
633633
#if QK_K == 256
634634
dequantize_block_q5_K<<<nb, 64, 0, stream>>>(vx, y);
@@ -638,7 +638,7 @@ static void dequantize_row_q5_K_cuda(const void * vx, dst_t * y, const int k, cu
638638
}
639639

640640
template<typename dst_t>
641-
static void dequantize_row_q6_K_cuda(const void * vx, dst_t * y, const int k, cudaStream_t stream) {
641+
static void dequantize_row_q6_K_cuda(const void * vx, dst_t * y, const int64_t k, cudaStream_t stream) {
642642
const int nb = k / QK_K;
643643
#if QK_K == 256
644644
dequantize_block_q6_K<<<nb, 64, 0, stream>>>(vx, y);
@@ -648,55 +648,55 @@ static void dequantize_row_q6_K_cuda(const void * vx, dst_t * y, const int k, cu
648648
}
649649

650650
template<typename dst_t>
651-
static void dequantize_row_iq2_xxs_cuda(const void * vx, dst_t * y, const int k, cudaStream_t stream) {
651+
static void dequantize_row_iq2_xxs_cuda(const void * vx, dst_t * y, const int64_t k, cudaStream_t stream) {
652652
const int nb = k / QK_K;
653653
dequantize_block_iq2_xxs<<<nb, 32, 0, stream>>>(vx, y);
654654
}
655655

656656
template<typename dst_t>
657-
static void dequantize_row_iq2_xs_cuda(const void * vx, dst_t * y, const int k, cudaStream_t stream) {
657+
static void dequantize_row_iq2_xs_cuda(const void * vx, dst_t * y, const int64_t k, cudaStream_t stream) {
658658
const int nb = k / QK_K;
659659
dequantize_block_iq2_xs<<<nb, 32, 0, stream>>>(vx, y);
660660
}
661661

662662
template<typename dst_t>
663-
static void dequantize_row_iq2_s_cuda(const void * vx, dst_t * y, const int k, cudaStream_t stream) {
663+
static void dequantize_row_iq2_s_cuda(const void * vx, dst_t * y, const int64_t k, cudaStream_t stream) {
664664
const int nb = k / QK_K;
665665
dequantize_block_iq2_s<<<nb, 32, 0, stream>>>(vx, y);
666666
}
667667

668668
template<typename dst_t>
669-
static void dequantize_row_iq3_xxs_cuda(const void * vx, dst_t * y, const int k, cudaStream_t stream) {
669+
static void dequantize_row_iq3_xxs_cuda(const void * vx, dst_t * y, const int64_t k, cudaStream_t stream) {
670670
const int nb = k / QK_K;
671671
dequantize_block_iq3_xxs<<<nb, 32, 0, stream>>>(vx, y);
672672
}
673673

674674
template<typename dst_t>
675-
static void dequantize_row_iq3_s_cuda(const void * vx, dst_t * y, const int k, cudaStream_t stream) {
675+
static void dequantize_row_iq3_s_cuda(const void * vx, dst_t * y, const int64_t k, cudaStream_t stream) {
676676
const int nb = k / QK_K;
677677
dequantize_block_iq3_s<<<nb, 32, 0, stream>>>(vx, y);
678678
}
679679

680680
template<typename dst_t>
681-
static void dequantize_row_iq1_s_cuda(const void * vx, dst_t * y, const int k, cudaStream_t stream) {
681+
static void dequantize_row_iq1_s_cuda(const void * vx, dst_t * y, const int64_t k, cudaStream_t stream) {
682682
const int nb = k / QK_K;
683683
dequantize_block_iq1_s<<<nb, 32, 0, stream>>>(vx, y);
684684
}
685685

686686
template<typename dst_t>
687-
static void dequantize_row_iq4_nl_cuda(const void * vx, dst_t * y, const int k, cudaStream_t stream) {
687+
static void dequantize_row_iq4_nl_cuda(const void * vx, dst_t * y, const int64_t k, cudaStream_t stream) {
688688
const int nb = (k + QK_K - 1) / QK_K;
689689
dequantize_block_iq4_nl<<<nb, 32, 0, stream>>>(vx, y);
690690
}
691691

692692
template<typename dst_t>
693-
static void dequantize_row_iq1_m_cuda(const void * vx, dst_t * y, const int k, cudaStream_t stream) {
693+
static void dequantize_row_iq1_m_cuda(const void * vx, dst_t * y, const int64_t k, cudaStream_t stream) {
694694
const int nb = k / QK_K;
695695
dequantize_block_iq1_m<<<nb, 32, 0, stream>>>(vx, y);
696696
}
697697

698698
template<typename dst_t>
699-
static void dequantize_row_iq4_xs_cuda(const void * vx, dst_t * y, const int k, cudaStream_t stream) {
699+
static void dequantize_row_iq4_xs_cuda(const void * vx, dst_t * y, const int64_t k, cudaStream_t stream) {
700700
const int nb = (k + QK_K - 1) / QK_K;
701701
#if QK_K == 64
702702
dequantize_block_iq4_nl<<<nb, 32, 0, stream>>>(vx, y);
@@ -706,8 +706,8 @@ static void dequantize_row_iq4_xs_cuda(const void * vx, dst_t * y, const int k,
706706
}
707707

708708
template <typename src_t, typename dst_t>
709-
static __global__ void convert_unary(const void * __restrict__ vx, dst_t * __restrict__ y, const int k) {
710-
const int i = blockDim.x*blockIdx.x + threadIdx.x;
709+
static __global__ void convert_unary(const void * __restrict__ vx, dst_t * __restrict__ y, const int64_t k) {
710+
const int64_t i = (int64_t)blockDim.x*blockIdx.x + threadIdx.x;
711711

712712
if (i >= k) {
713713
return;
@@ -719,7 +719,7 @@ static __global__ void convert_unary(const void * __restrict__ vx, dst_t * __res
719719
}
720720

721721
template <typename src_t, typename dst_t>
722-
static void convert_unary_cuda(const void * __restrict__ vx, dst_t * __restrict__ y, const int k, cudaStream_t stream) {
722+
static void convert_unary_cuda(const void * __restrict__ vx, dst_t * __restrict__ y, const int64_t k, cudaStream_t stream) {
723723
const int num_blocks = (k + CUDA_DEQUANTIZE_BLOCK_SIZE - 1) / CUDA_DEQUANTIZE_BLOCK_SIZE;
724724
convert_unary<src_t><<<num_blocks, CUDA_DEQUANTIZE_BLOCK_SIZE, 0, stream>>>(vx, y, k);
725725
}

‎ggml-cuda/convert.cuh

Copy file name to clipboard · Expand all lines: ggml-cuda/convert.cuh
+1 -1 · Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -3,7 +3,7 @@
33
#define CUDA_DEQUANTIZE_BLOCK_SIZE 256
44

55
template<typename T>
6-
using to_t_cuda_t = void (*)(const void * __restrict__ x, T * __restrict__ y, int k, cudaStream_t stream);
6+
using to_t_cuda_t = void (*)(const void * __restrict__ x, T * __restrict__ y, int64_t k, cudaStream_t stream);
77

88
typedef to_t_cuda_t<float> to_fp32_cuda_t;
99
typedef to_t_cuda_t<half> to_fp16_cuda_t;

‎ggml-cuda/dequantize.cuh

Copy file name to clipboard · Expand all lines: ggml-cuda/dequantize.cuh
+5 -5 · Lines changed: 5 additions & 5 deletions
Original file line number | Diff line number | Diff line change
@@ -1,6 +1,6 @@
11
#include "common.cuh"
22

3-
static __device__ __forceinline__ void dequantize_q4_0(const void * vx, const int ib, const int iqs, dfloat2 & v){
3+
static __device__ __forceinline__ void dequantize_q4_0(const void * vx, const int64_t ib, const int iqs, dfloat2 & v){
44
const block_q4_0 * x = (const block_q4_0 *) vx;
55

66
const dfloat d = x[ib].d;
@@ -19,7 +19,7 @@ static __device__ __forceinline__ void dequantize_q4_0(const void * vx, const in
1919
#endif // GGML_CUDA_F16
2020
}
2121

22-
static __device__ __forceinline__ void dequantize_q4_1(const void * vx, const int ib, const int iqs, dfloat2 & v){
22+
static __device__ __forceinline__ void dequantize_q4_1(const void * vx, const int64_t ib, const int iqs, dfloat2 & v){
2323
const block_q4_1 * x = (const block_q4_1 *) vx;
2424

2525
const dfloat d = __low2half(x[ib].dm);
@@ -39,7 +39,7 @@ static __device__ __forceinline__ void dequantize_q4_1(const void * vx, const in
3939
#endif // GGML_CUDA_F16
4040
}
4141

42-
static __device__ __forceinline__ void dequantize_q5_0(const void * vx, const int ib, const int iqs, dfloat2 & v){
42+
static __device__ __forceinline__ void dequantize_q5_0(const void * vx, const int64_t ib, const int iqs, dfloat2 & v){
4343
const block_q5_0 * x = (const block_q5_0 *) vx;
4444

4545
const dfloat d = x[ib].d;
@@ -62,7 +62,7 @@ static __device__ __forceinline__ void dequantize_q5_0(const void * vx, const in
6262
#endif // GGML_CUDA_F16
6363
}
6464

65-
static __device__ __forceinline__ void dequantize_q5_1(const void * vx, const int ib, const int iqs, dfloat2 & v){
65+
static __device__ __forceinline__ void dequantize_q5_1(const void * vx, const int64_t ib, const int iqs, dfloat2 & v){
6666
const block_q5_1 * x = (const block_q5_1 *) vx;
6767

6868
const dfloat d = __low2half(x[ib].dm);
@@ -86,7 +86,7 @@ static __device__ __forceinline__ void dequantize_q5_1(const void * vx, const in
8686
#endif // GGML_CUDA_F16
8787
}
8888

89-
static __device__ __forceinline__ void dequantize_q8_0(const void * vx, const int ib, const int iqs, dfloat2 & v){
89+
static __device__ __forceinline__ void dequantize_q8_0(const void * vx, const int64_t ib, const int iqs, dfloat2 & v){
9090
const block_q8_0 * x = (const block_q8_0 *) vx;
9191

9292
const dfloat d = x[ib].d;

‎ggml-cuda/dmmv.cu

Copy file name to clipboard · Expand all lines: ggml-cuda/dmmv.cu
+3 -3 · Lines changed: 3 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -565,7 +565,7 @@ static __global__ void dequantize_mul_mat_vec_q6_k(const void * __restrict__ vx,
565565
}
566566
}
567567

568-
static __device__ void convert_f16(const void * vx, const int ib, const int iqs, dfloat2 & v){
568+
static __device__ void convert_f16(const void * vx, const int64_t ib, const int iqs, dfloat2 & v){
569569
const half * x = (const half *) vx;
570570

571571
// automatic half -> float type cast if dfloat == float
@@ -577,7 +577,7 @@ template <int qk, int qr, dequantize_kernel_t dequantize_kernel>
577577
static __global__ void dequantize_mul_mat_vec(const void * __restrict__ vx, const dfloat * __restrict__ y, float * __restrict__ dst, const int ncols, const int nrows) {
578578
// qk = quantized weights per x block
579579
// qr = number of quantized weights per data value in x block
580-
const int row = blockIdx.x*blockDim.y + threadIdx.y;
580+
const int64_t row = (int64_t)blockIdx.x*blockDim.y + threadIdx.y;
581581

582582
if (row >= nrows) {
583583
return;
@@ -598,7 +598,7 @@ static __global__ void dequantize_mul_mat_vec(const void * __restrict__ vx, cons
598598

599599
for (int i = 0; i < ncols; i += iter_stride) {
600600
const int col = i + vals_per_iter*tid;
601-
const int ib = (row*ncols + col)/qk; // x block index
601+
const int64_t ib = ((int64_t)row*ncols + col)/qk; // x block index
602602
const int iqs = (col%qk)/qr; // x quant index
603603
const int iybs = col - col%qk; // y block start index
604604

0 commit comments

Comments
0 (0)
Morty Proxy This is a proxified and sanitized view of the page, visit original site.