Skip to content

Navigation Menu

Sign in
Appearance settings

Search code, repositories, users, issues, pull requests...

Provide feedback

We read every piece of feedback, and take your input very seriously.

Saved searches

Use saved searches to filter your results more quickly

Appearance settings

Commit 86b06e0

Browse files
ikawrakow and Kawrakow
authored and committed
iq2_xxs: tune quantization (ggml-org#5320)
We get slightly better PPL (perplexity), and we cut quantization time nearly in half. The trick is to first quantize without forcing points onto the E8-lattice. We can then use a narrower search range around the block scale obtained that way.

Co-authored-by: Iwan Kawrakow <iwan.kawrakow@gmail.com>
1 parent 35b93a6 commit 86b06e0
Copy full SHA for 86b06e0

File tree

Expand file tree / Collapse file tree

1 file changed

+6
-52
lines changed
Filter options
Expand file tree / Collapse file tree

1 file changed

+6
-52
lines changed

‎ggml-quants.c

Copy file name to clipboard | Expand all lines: ggml-quants.c
+6 -52 | Lines changed: 6 additions & 52 deletions
Original file line number | Diff line number | Diff line change
@@ -9048,8 +9048,6 @@ static void quantize_row_iq2_xxs_impl(const float * restrict x, void * restrict
90489048
int8_t L[32];
90499049
int8_t Laux[32];
90509050
float waux[32];
9051-
bool is_on_grid[4];
9052-
bool is_on_grid_aux[4];
90539051
uint8_t block_signs[4];
90549052
uint32_t q2[2*(QK_K/32)];
90559053

@@ -9099,10 +9097,11 @@ static void quantize_row_iq2_xxs_impl(const float * restrict x, void * restrict
90999097
memset(L, 0, 32);
91009098
continue;
91019099
}
9100+
float scale = make_qp_quants(32, kMaxQ+1, xval, (uint8_t*)L, weight);
9101+
float eff_max = scale*kMaxQ;
91029102
float best = 0;
9103-
float scale = max/(2*kMaxQ-1);
9104-
for (int is = -9; is <= 9; ++is) {
9105-
float id = (2*kMaxQ-1+is*0.1f)/max;
9103+
for (int is = -6; is <= 6; ++is) {
9104+
float id = (2*kMaxQ-1+is*0.1f)/eff_max;
91069105
float this_scale = 1/id;
91079106
for (int k = 0; k < 4; ++k) {
91089107
for (int i = 0; i < 8; ++i) {
@@ -9112,9 +9111,7 @@ static void quantize_row_iq2_xxs_impl(const float * restrict x, void * restrict
91129111
uint16_t u = 0;
91139112
for (int i = 0; i < 8; ++i) u |= (Laux[8*k+i] << 2*i);
91149113
int grid_index = kmap_q2xs[u];
9115-
is_on_grid_aux[k] = true;
91169114
if (grid_index < 0) {
9117-
is_on_grid_aux[k] = false;
91189115
const uint16_t * neighbours = kneighbors_q2xs - kmap_q2xs[u] - 1;
91199116
grid_index = iq2_find_best_neighbour(neighbours, kgrid_q2xs, xval + 8*k, waux + 8*k, this_scale, Laux + 8*k);
91209117
}
@@ -9128,16 +9125,12 @@ static void quantize_row_iq2_xxs_impl(const float * restrict x, void * restrict
91289125
}
91299126
if (sumq2 > 0 && sumqx*sumqx > best*sumq2) {
91309127
scale = sumqx/sumq2; best = scale*sumqx;
9131-
for (int i = 0; i < 32; ++i) L[i] = Laux[i];
9132-
for (int k = 0; k < 4; ++k) is_on_grid[k] = is_on_grid_aux[k];
9128+
memcpy(L, Laux, 32);
91339129
}
91349130
}
9135-
int n_not_ongrid = 0;
9136-
for (int k = 0; k < 4; ++k) if (!is_on_grid[k]) ++n_not_ongrid;
9137-
if (n_not_ongrid > 0 && scale > 0) {
9131+
if (scale > 0) {
91389132
float id = 1/scale;
91399133
for (int k = 0; k < 4; ++k) {
9140-
if (is_on_grid[k]) continue;
91419134
uint16_t u = 0;
91429135
for (int i = 0; i < 8; ++i) {
91439136
int l = nearest_int(0.5f*(id*xval[8*k+i]-1));
@@ -9193,49 +9186,10 @@ static void quantize_row_iq2_xxs_impl(const float * restrict x, void * restrict
91939186
float d = max_scale/31;
91949187
y[ibl].d = GGML_FP32_TO_FP16(d);
91959188
float id = 1/d;
9196-
float sumqx = 0, sumq2 = 0;
91979189
for (int ib = 0; ib < QK_K/32; ++ib) {
91989190
int l = nearest_int(0.5f*(id*scales[ib]-1));
91999191
l = MAX(0, MIN(15, l));
92009192
q2[2*ib+1] |= ((uint32_t)l << 28);
9201-
const float * xb = xbl + 32*ib;
9202-
const float * qw = quant_weights + QK_K*ibl + 32*ib;
9203-
for (int i = 0; i < 32; ++i) weight[i] = qw[i] * sqrtf(sigma2 + xb[i]*xb[i]);
9204-
const uint8_t * aux8 = (const uint8_t *)(q2 + 2*ib);
9205-
const float db = d * (1 + 2*l);
9206-
uint32_t u = 0;
9207-
for (int k = 0; k < 4; ++k) {
9208-
const int8_t * signs = keven_signs_q2xs + 8*((q2[2*ib+1] >> 7*k) & 127);
9209-
const float * xk = xb + 8*k;
9210-
const float * wk = weight + 8*k;
9211-
const uint8_t * grid = (const uint8_t *)(kgrid_q2xs + aux8[k]);
9212-
float best_mse = 0; int best_index = aux8[k];
9213-
for (int j = 0; j < 8; ++j) {
9214-
float diff = db * grid[j] * signs[j] - xk[j];
9215-
best_mse += wk[j] * diff * diff;
9216-
}
9217-
for (int idx = 0; idx < 256; ++idx) {
9218-
grid = (const uint8_t *)(kgrid_q2xs + idx);
9219-
float mse = 0;
9220-
for (int j = 0; j < 8; ++j) {
9221-
float diff = db * grid[j] * signs[j] - xk[j];
9222-
mse += wk[j] * diff * diff;
9223-
}
9224-
if (mse < best_mse) {
9225-
best_mse = mse; best_index = idx;
9226-
}
9227-
}
9228-
u |= (best_index << 8*k);
9229-
grid = (const uint8_t *)(kgrid_q2xs + best_index);
9230-
//grid = (const uint8_t *)(kgrid_q2xs + aux8[k]);
9231-
for (int j = 0; j < 8; ++j) {
9232-
float q = db * grid[j] * signs[j];
9233-
sumqx += wk[j] * q * xk[j];
9234-
sumq2 += wk[j] * q * q;
9235-
}
9236-
}
9237-
q2[2*ib] = u;
9238-
if (sumq2 > 0) y[ibl].d = GGML_FP32_TO_FP16(d*sumqx/sumq2);
92399193
}
92409194
memcpy(y[ibl].qs, q2, QK_K/4);
92419195
}

0 commit comments

Comments
0 (0)
Morty Proxy This is a proxified and sanitized view of the page, visit original site.