Skip to content

Navigation Menu

Sign in
Appearance settings

Search code, repositories, users, issues, pull requests...

Provide feedback

We read every piece of feedback, and take your input very seriously.

Saved searches

Use saved searches to filter your results more quickly

Appearance settings

Commit ba932df

Browse files
authored
ggml : fix quantized cpy op (ggml-org#12310)
* ggml : fix quantized cpy op
  ggml-ci
* tests : add cpy tests for all types
  ggml-ci
* tests : add BF16 copy tests
  ggml-ci
* tests : fix loop for same-type copy
  ggml-ci
* tests : add option to permute the dst tensor
  ggml-ci
1 parent fac63a3 commit ba932df
Copy full SHA for ba932df

File tree

Expand file tree / Collapse file tree

2 files changed

+61
-37
lines changed
Filter options
Expand file tree / Collapse file tree

2 files changed

+61
-37
lines changed

‎ggml/src/ggml-cpu/ggml-cpu.c

Copy file name to clipboard | Expand all lines: ggml/src/ggml-cpu/ggml-cpu.c
+31 -27 | Lines changed: 31 additions & 27 deletions
Original file line number | Diff line number | Diff line change
@@ -3110,17 +3110,17 @@ static void ggml_compute_forward_dup_same_cont(
31103110
const int ith = params->ith; // thread index
31113111
const int nth = params->nth; // number of threads
31123112

3113-
// parallelize by elements
3114-
const int ne = ggml_nelements(dst);
3115-
const int dr = (ne + nth - 1) / nth;
3116-
const int ie0 = dr * ith;
3117-
const int ie1 = MIN(ie0 + dr, ne);
3113+
// parallelize by blocks
3114+
const int nk = ggml_nelements(src0)/ggml_blck_size(src0->type);
3115+
const int dr = (nk + nth - 1) / nth;
3116+
const int k0 = dr * ith;
3117+
const int k1 = MIN(k0 + dr, nk);
31183118

3119-
if (ie0 < ie1) {
3119+
if (k0 < k1) {
31203120
memcpy(
3121-
((char *) dst->data + ie0*nb0),
3122-
((char *) src0->data + ie0*nb0),
3123-
(ie1 - ie0) * nb0);
3121+
((char *) dst->data + k0*nb0),
3122+
((char *) src0->data + k0*nb0),
3123+
(k1 - k0) * nb0);
31243124
}
31253125
}
31263126

@@ -4055,7 +4055,6 @@ static void ggml_compute_forward_dup_f32(
40554055
static void ggml_compute_forward_dup_bytes(
40564056
const struct ggml_compute_params * params,
40574057
struct ggml_tensor * dst) {
4058-
40594058
const struct ggml_tensor * src0 = dst->src[0];
40604059

40614060
GGML_ASSERT(ggml_nelements(dst) == ggml_nelements(src0));
@@ -4069,10 +4068,10 @@ static void ggml_compute_forward_dup_bytes(
40694068
}
40704069

40714070
const size_t type_size = ggml_type_size(src0->type);
4071+
40724072
const int ith = params->ith; // thread index
40734073
const int nth = params->nth; // number of threads
40744074

4075-
40764075
// parallelize by rows
40774076
const int nr = ne01;
40784077
// number of rows per thread
@@ -4082,10 +4081,10 @@ static void ggml_compute_forward_dup_bytes(
40824081
const int ir1 = MIN(ir0 + dr, nr);
40834082

40844083
if (src0->type == dst->type &&
4085-
ne00 == ne0 &&
4084+
ggml_are_same_shape(src0, dst) &&
40864085
nb00 == type_size && nb0 == type_size) {
40874086
// copy by rows
4088-
const size_t rs = ne00 * type_size;
4087+
const size_t rs = ggml_row_size(src0->type, ne00);
40894088
for (int64_t i03 = 0; i03 < ne03; i03++) {
40904089
for (int64_t i02 = 0; i02 < ne02; i02++) {
40914090
for (int64_t i01 = ir0; i01 < ir1; i01++) {
@@ -4140,17 +4139,20 @@ static void ggml_compute_forward_dup_bytes(
41404139
}
41414140

41424141
// dst counters
4143-
4144-
int64_t i10 = 0;
4142+
int64_t k10 = 0;
41454143
int64_t i11 = 0;
41464144
int64_t i12 = 0;
41474145
int64_t i13 = 0;
41484146

4147+
// number of blocks in a row
4148+
const int64_t nk00 = ne00 / ggml_blck_size(src0->type);
4149+
const int64_t nk0 = ne0 / ggml_blck_size(dst->type);
4150+
41494151
for (int64_t i03 = 0; i03 < ne03; i03++) {
41504152
for (int64_t i02 = 0; i02 < ne02; i02++) {
4151-
i10 += ne00 * ir0;
4152-
while (i10 >= ne0) {
4153-
i10 -= ne0;
4153+
k10 += nk00 * ir0;
4154+
while (k10 >= nk0) {
4155+
k10 -= nk0;
41544156
if (++i11 == ne1) {
41554157
i11 = 0;
41564158
if (++i12 == ne2) {
@@ -4162,14 +4164,14 @@ static void ggml_compute_forward_dup_bytes(
41624164
}
41634165
}
41644166
for (int64_t i01 = ir0; i01 < ir1; i01++) {
4165-
for (int64_t i00 = 0; i00 < ne00; i00++) {
4166-
const char * src0_ptr = ((char *) src0->data + i00*nb00 + i01*nb01 + i02*nb02 + i03*nb03);
4167-
char * dst_ptr = ((char *) dst->data + i10*nb0 + i11*nb1 + i12*nb2 + i13*nb3);
4167+
for (int64_t k00 = 0; k00 < nk00; k00++) {
4168+
const char * src0_ptr = ((char *) src0->data + k00*nb00 + i01*nb01 + i02*nb02 + i03*nb03);
4169+
char * dst_ptr = ((char *) dst->data + k10*nb0 + i11*nb1 + i12*nb2 + i13*nb3);
41684170

41694171
memcpy(dst_ptr, src0_ptr, type_size);
41704172

4171-
if (++i10 == ne0) {
4172-
i10 = 0;
4173+
if (++k10 == nk0) {
4174+
k10 = 0;
41734175
if (++i11 == ne1) {
41744176
i11 = 0;
41754177
if (++i12 == ne2) {
@@ -4182,9 +4184,9 @@ static void ggml_compute_forward_dup_bytes(
41824184
}
41834185
}
41844186
}
4185-
i10 += ne00 * (ne01 - ir1);
4186-
while (i10 >= ne0) {
4187-
i10 -= ne0;
4187+
k10 += nk00 * (ne01 - ir1);
4188+
while (k10 >= nk0) {
4189+
k10 -= nk0;
41884190
if (++i11 == ne1) {
41894191
i11 = 0;
41904192
if (++i12 == ne2) {
@@ -14308,7 +14310,9 @@ static void ggml_compute_forward(struct ggml_compute_params * params, struct ggm
1430814310
}
1430914311

1431014312
// extra_buffer op?
14311-
if (ggml_cpu_extra_compute_forward(params, tensor)) return;
14313+
if (ggml_cpu_extra_compute_forward(params, tensor)) {
14314+
return;
14315+
}
1431214316

1431314317
switch (tensor->op) {
1431414318
case GGML_OP_DUP:

‎tests/test-backend-ops.cpp

Copy file name to clipboard | Expand all lines: tests/test-backend-ops.cpp
+30 -10 | Lines changed: 30 additions & 10 deletions
Original file line number | Diff line number | Diff line change
@@ -1463,11 +1463,13 @@ struct test_cpy : public test_case {
14631463
const ggml_type type_src;
14641464
const ggml_type type_dst;
14651465
const std::array<int64_t, 4> ne;
1466-
const std::array<int64_t, 4> permute;
1466+
const std::array<int64_t, 4> permute_src;
1467+
const std::array<int64_t, 4> permute_dst;
14671468
bool _src_use_permute;
1469+
bool _dst_use_permute;
14681470

14691471
std::string vars() override {
1470-
return VARS_TO_STR4(type_src, type_dst, ne, permute);
1472+
return VARS_TO_STR5(type_src, type_dst, ne, permute_src, permute_dst);
14711473
}
14721474

14731475
double max_nmse_err() override {
@@ -1480,23 +1482,30 @@ struct test_cpy : public test_case {
14801482

14811483
test_cpy(ggml_type type_src = GGML_TYPE_F32, ggml_type type_dst = GGML_TYPE_F32,
14821484
std::array<int64_t, 4> ne = {10, 10, 10, 1},
1483-
std::array<int64_t, 4> permute = {0, 0, 0, 0})
1484-
: type_src(type_src), type_dst(type_dst), ne(ne), permute(permute),
1485-
_src_use_permute(permute[0] + permute[1] + permute[2] + permute[3] > 0) {}
1485+
std::array<int64_t, 4> permute_src = {0, 0, 0, 0},
1486+
std::array<int64_t, 4> permute_dst = {0, 0, 0, 0})
1487+
: type_src(type_src), type_dst(type_dst), ne(ne), permute_src(permute_src), permute_dst(permute_dst),
1488+
_src_use_permute(permute_src[0] + permute_src[1] + permute_src[2] + permute_src[3] > 0),
1489+
_dst_use_permute(permute_dst[0] + permute_dst[1] + permute_dst[2] + permute_dst[3] > 0) {}
14861490

14871491
ggml_tensor * build_graph(ggml_context * ctx) override {
14881492
ggml_tensor * src = ggml_new_tensor(ctx, type_src, 4, ne.data());
14891493
ggml_set_param(ctx, src);
14901494
ggml_set_name(src, "src");
14911495

14921496
if (_src_use_permute) {
1493-
src = ggml_permute(ctx, src, permute[0], permute[1], permute[2], permute[3]);
1497+
src = ggml_permute(ctx, src, permute_src[0], permute_src[1], permute_src[2], permute_src[3]);
14941498
ggml_set_name(src, "src_permuted");
14951499
}
14961500

1497-
ggml_tensor* dst = ggml_new_tensor(ctx, type_dst, 4, src->ne);
1501+
ggml_tensor * dst = ggml_new_tensor(ctx, type_dst, 4, src->ne);
14981502
ggml_set_name(dst, "dst");
14991503

1504+
if (_dst_use_permute) {
1505+
dst = ggml_permute(ctx, dst, permute_dst[0], permute_dst[1], permute_dst[2], permute_dst[3]);
1506+
ggml_set_name(dst, "dst_permuted");
1507+
}
1508+
15001509
ggml_tensor * out = ggml_cpy(ctx, src, dst);
15011510
ggml_set_name(out, "out");
15021511

@@ -4004,14 +4013,25 @@ static std::vector<std::unique_ptr<test_case>> make_test_cases_eval() {
40044013
test_cases.emplace_back(new test_set(GGML_TYPE_I32, GGML_TYPE_I32, {6, 5, 4, 3}, dim));
40054014
}
40064015

4007-
for (ggml_type type_src : {GGML_TYPE_F16, GGML_TYPE_F32}) {
4016+
// same-type copy
4017+
for (ggml_type type : all_types) {
4018+
const auto nk = ggml_blck_size(type);
4019+
4020+
for (int k = 1; k < 4; ++k) {
4021+
test_cases.emplace_back(new test_cpy(type, type, {k*nk, 2, 3, 4}));
4022+
test_cases.emplace_back(new test_cpy(type, type, {k*nk, 2, 3, 4}, {0, 2, 1, 3}));
4023+
test_cases.emplace_back(new test_cpy(type, type, {k*nk, 2, 3, 4}, {0, 3, 1, 2}, {0, 2, 1, 3}));
4024+
}
4025+
}
4026+
4027+
for (ggml_type type_src : {GGML_TYPE_F16, GGML_TYPE_BF16, GGML_TYPE_F32}) {
40084028
for (ggml_type type_dst : all_types) {
40094029
test_cases.emplace_back(new test_cpy(type_src, type_dst, {256, 4, 4, 4}));
40104030
test_cases.emplace_back(new test_cpy(type_src, type_dst, {256, 2, 3, 4}, {0, 2, 1, 3})); // cpy by rows
40114031
}
40124032
}
4013-
for (ggml_type type_dst : {GGML_TYPE_F32}) {
4014-
for (ggml_type type_src : all_types) {
4033+
for (ggml_type type_src : all_types) {
4034+
for (ggml_type type_dst : {GGML_TYPE_F32}) {
40154035
test_cases.emplace_back(new test_cpy(type_src, type_dst, {256, 4, 4, 4}));
40164036
test_cases.emplace_back(new test_cpy(type_src, type_dst, {256, 2, 3, 4}, {0, 2, 1, 3})); // cpy by rows
40174037
}

0 commit comments

Comments
0 (0)
Morty Proxy This is a proxified and sanitized view of the page, visit original site.