Commit af9f64c

sycl: cleanup oneDNN related code
1 parent 9f7add1 commit af9f64c

5 files changed, +80 -68 lines changed

‎docs/backend/SYCL.md

Lines changed: 11 additions & 2 deletions
@@ -227,6 +227,15 @@ cmake -B buildWithCublas -DCMAKE_CXX_COMPILER=icpx -DCMAKE_C_COMPILER=icx -DENAB
 cmake --build buildWithCublas --config Release
 ```

+**oneDNN**: The current oneDNN releases *(shipped with the oneAPI base-toolkit)* do not include the NVIDIA backend. Therefore, oneDNN must be compiled from source to enable the NVIDIA target:
+
+```sh
+git clone https://github.com/oneapi-src/oneDNN.git
+cd oneDNN
+cmake -GNinja -Bbuild-nvidia -DDNNL_CPU_RUNTIME=DPCPP -DDNNL_GPU_RUNTIME=DPCPP -DDNNL_GPU_VENDOR=NVIDIA -DONEDNN_BUILD_GRAPH=OFF -DCMAKE_C_COMPILER=icx -DCMAKE_CXX_COMPILER=icpx
+cmake --build build-nvidia --config Release
+```
+
 - **Adding support to AMD GPUs**

   **oneAPI Plugin**: In order to enable SYCL support on AMD GPUs, please install the [Codeplay oneAPI Plugin for AMD GPUs](https://developer.codeplay.com/products/oneapi/amd/download). As with Nvidia GPUs, the user should also make sure the plugin version matches the installed base toolkit.
@@ -317,10 +326,10 @@ export CPLUS_INCLUDE_DIR=/path/to/oneMKL/include:$CPLUS_INCLUDE_DIR
 GGML_SYCL_DEVICE_ARCH=sm_80 # Example architecture

 # Option 1: Use FP32 (recommended for better performance in most cases)
-cmake -B build -DGGML_SYCL=ON -DGGML_SYCL_TARGET=NVIDIA -DGGML_SYCL_DEVICE_ARCH=${GGML_SYCL_DEVICE_ARCH} -DCMAKE_C_COMPILER=icx -DCMAKE_CXX_COMPILER=icpx
+cmake -B build -DGGML_SYCL=ON -DGGML_SYCL_TARGET=NVIDIA -DGGML_SYCL_DEVICE_ARCH=${GGML_SYCL_DEVICE_ARCH} -DCMAKE_C_COMPILER=icx -DCMAKE_CXX_COMPILER=icpx -DDNNL_DIR=/path/to/oneDNN/build-nvidia/install/lib/cmake/dnnl

 # Option 2: Use FP16
-cmake -B build -DGGML_SYCL=ON -DGGML_SYCL_TARGET=NVIDIA -DGGML_SYCL_DEVICE_ARCH=${GGML_SYCL_DEVICE_ARCH} -DCMAKE_C_COMPILER=icx -DCMAKE_CXX_COMPILER=icpx -DGGML_SYCL_F16=ON
+cmake -B build -DGGML_SYCL=ON -DGGML_SYCL_TARGET=NVIDIA -DGGML_SYCL_DEVICE_ARCH=${GGML_SYCL_DEVICE_ARCH} -DCMAKE_C_COMPILER=icx -DCMAKE_CXX_COMPILER=icpx -DGGML_SYCL_F16=ON -DDNNL_DIR=/path/to/oneDNN/build-nvidia/install/lib/cmake/dnnl

 # build all binary
 cmake --build build --config Release -j -v

‎ggml/src/ggml-sycl/CMakeLists.txt

Lines changed: 20 additions & 12 deletions
@@ -21,6 +21,26 @@ ggml_add_backend_library(ggml-sycl
                          ../../include/ggml-sycl.h
                         )

+find_package(DNNL)
+set(GGML_SYCL_DNNL 0)
+if(DNNL_FOUND)
+    get_target_property(CONFIG DNNL::dnnl IMPORTED_CONFIGURATIONS)
+    get_target_property(DNNL_LIB DNNL::dnnl IMPORTED_LOCATION_${CONFIG})
+    message(STATUS "Found oneDNN: ${DNNL_LIB}")
+    if("${GGML_SYCL_TARGET}" STREQUAL "${DNNL_GPU_VENDOR}")
+        target_link_libraries(ggml-sycl PRIVATE DNNL::dnnl)
+        set(GGML_SYCL_DNNL 1)
+    else()
+        message(WARNING
+            "oneDNN must be compiled for the same target as llama.cpp.
+            llama.cpp: ${GGML_SYCL_TARGET}, oneDNN: ${DNNL_GPU_VENDOR}.
+            Disabling oneDNN support.")
+    endif()
+else()
+    message(STATUS "oneDNN not found, disabling oneDNN support")
+endif()
+target_compile_definitions(ggml-sycl PRIVATE GGML_SYCL_DNNL=${GGML_SYCL_DNNL})
+
 if (GGML_SYCL_F16)
     if (GGML_SYCL_TARGET STREQUAL "AMD")
         message(WARNING "AMD target does not entirely support FP16 in the SYCL backend.")
@@ -46,18 +66,6 @@ file(GLOB GGML_HEADERS_SYCL "*.hpp")
 file(GLOB GGML_SOURCES_SYCL "*.cpp")
 target_sources(ggml-sycl PRIVATE ${GGML_HEADERS_SYCL} ${GGML_SOURCES_SYCL})

-find_package(DNNL)
-message("-- DNNL found:" ${DNNL_FOUND})
-
-if (GGML_SYCL_TARGET STREQUAL "INTEL")
-    add_compile_definitions(GGML_SYCL_DNNL=${DNNL_FOUND})
-else()
-    add_compile_definitions(GGML_SYCL_DNNL=0)
-endif()
-
-if (${DNNL_FOUND} AND GGML_SYCL_TARGET STREQUAL "INTEL")
-    target_link_libraries(ggml-sycl PRIVATE DNNL::dnnl)
-endif()

 if (WIN32)
     find_package(IntelSYCL REQUIRED)
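
Since `GGML_SYCL_DNNL` is now always passed to the compiler as `0` or `1` for the `ggml-sycl` target, the sources can branch on it with a plain `#if`. A minimal sketch of the consuming side, matching the guard already used in `gemm.hpp` (the comments are illustrative, not part of the commit):

```cpp
// GGML_SYCL_DNNL is injected by the target_compile_definitions() call above,
// so it always expands to 0 or 1 and a plain #if is sufficient.
#if GGML_SYCL_DNNL
// oneDNN was found and was built for the same vendor target:
// the oneDNN-backed GEMM wrapper (DnnlGemmWrapper) is compiled in.
#else
// oneDNN is missing or targets a different vendor:
// the backend keeps using its non-oneDNN GEMM paths.
#endif
```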

‎ggml/src/ggml-sycl/common.hpp

Lines changed: 29 additions & 3 deletions
@@ -163,9 +163,8 @@ static size_t g_scratch_offset = 0;
 int get_current_device_id();

 inline dpct::err0 ggml_sycl_set_device(const int device) try {
-
-    int current_device_id;
-    SYCL_CHECK(CHECK_TRY_ERROR(current_device_id = get_current_device_id()));
+    int current_device_id;
+    SYCL_CHECK(CHECK_TRY_ERROR(current_device_id = get_current_device_id()));

     // GGML_SYCL_DEBUG("ggml_sycl_set_device device_id=%d,
     // current_device_id=%d\n", device, current_device);
@@ -229,6 +228,14 @@ struct ggml_sycl_pool_alloc {
         }
     }

+    T * realloc(size_t size) {
+        GGML_ASSERT(pool != nullptr);
+        if (ptr)
+            pool->free(ptr, actual_size);
+        ptr = (T *) pool->alloc(size * sizeof(T), &this->actual_size);
+        return ptr;
+    }
+
     // size is in number of elements
     T * alloc(size_t size) {
         GGML_ASSERT(pool != nullptr);
@@ -328,10 +335,29 @@ struct ggml_backend_sycl_context {
     dnnl::stream stream_dnnl() {
         return stream_dnnl(device, 0);
     }
+    dnnl::memory get_scratchpad_mem(const dnnl::memory::desc & scratchpad_md,
+                                    const dnnl::engine & eng, const queue_ptr q) {
+        ggml_sycl_pool_alloc<uint8_t> * pool;
+        auto it = scratchpad_map.find(q);
+        if (it == scratchpad_map.end()) {
+            scratchpad_map[q] = std::make_unique<ggml_sycl_pool_alloc<uint8_t>>(this->pool());
+            pool = scratchpad_map[q].get();
+        } else {
+            pool = it->second.get();
+        }
+
+        size_t scratchpad_size = scratchpad_md.get_size();
+        if (scratchpad_size > pool->actual_size) {
+            pool->realloc(scratchpad_size);
+        }
+        void * mem_ptr = pool->get();
+        return dnnl::memory(scratchpad_md, eng, mem_ptr);
+    }
 #endif

     // pool
     std::unique_ptr<ggml_sycl_pool> pools[GGML_SYCL_MAX_DEVICES];
+    std::unordered_map<sycl::queue *, std::unique_ptr<ggml_sycl_pool_alloc<uint8_t>>> scratchpad_map;

     std::unique_ptr<ggml_sycl_pool> host_pools[GGML_SYCL_MAX_DEVICES];

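
Taken together, `realloc()` and `get_scratchpad_mem()` give each SYCL queue one pooled scratchpad buffer that only grows when a primitive needs more space. The sketch below shows how a caller is expected to wire this into a oneDNN primitive with a user-managed scratchpad; it mirrors the `row_gemm()` change in `gemm.hpp` further down, and the free-standing helper name and its parameters are illustrative only, not part of the commit:

```cpp
// Illustrative only: a hypothetical caller of the new scratchpad helper,
// following the same pattern as DnnlGemmWrapper::row_gemm() in gemm.hpp.
#include <unordered_map>
#include "dnnl.hpp"
#include "common.hpp"  // ggml_backend_sycl_context, queue_ptr

static void matmul_with_pooled_scratchpad(ggml_backend_sycl_context & ctx, const queue_ptr & q,
                                          const dnnl::memory::desc & a_md, const dnnl::memory::desc & b_md,
                                          const dnnl::memory::desc & c_md, void * a, void * b, void * c) {
    auto eng    = ctx.engine_dnnl(q);   // engine and stream are cached per queue by the context
    auto stream = ctx.stream_dnnl(q);

    // Ask oneDNN for a user-managed scratchpad instead of a per-primitive allocation.
    dnnl::primitive_attr attr;
    attr.set_scratchpad_mode(dnnl::scratchpad_mode::user);
    auto pd = dnnl::matmul::primitive_desc(eng, a_md, b_md, c_md, attr);

    // The buffer comes from the per-queue pool and is reused across calls;
    // get_scratchpad_mem() only reallocates when the requested size grows.
    auto scratchpad = ctx.get_scratchpad_mem(pd.scratchpad_desc(), eng, q);

    std::unordered_map<int, dnnl::memory> args;
    args.insert({ DNNL_ARG_SRC,        dnnl::memory(a_md, eng, a) });
    args.insert({ DNNL_ARG_WEIGHTS,    dnnl::memory(b_md, eng, b) });
    args.insert({ DNNL_ARG_DST,        dnnl::memory(pd.dst_desc(), eng, c) });
    args.insert({ DNNL_ARG_SCRATCHPAD, scratchpad });
    dnnl::matmul(pd).execute(stream, args);
}
```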

‎ggml/src/ggml-sycl/gemm.hpp

Lines changed: 14 additions & 45 deletions
@@ -13,9 +13,6 @@
 #ifndef GGML_SYCL_GEMM_HPP
 #define GGML_SYCL_GEMM_HPP

-#include <fstream>
-#include <iostream>
-
 #include "ggml-sycl.h"

 #if GGML_SYCL_DNNL
@@ -35,62 +32,34 @@ class DnnlGemmWrapper {
         else static_assert(0);
     }

-    static inline void row_gemm(sycl::queue& q, bool a_trans,
-        bool b_trans, int m, int n, int k,
-        const void* a, dt at, const void* b, dt bt, void* c, dt ct)
-    {
-        // Get the device associated with the queue
-        sycl::device dev = q.get_device();
-        // Get the context associated with the queue
-        sycl::context ctx = q.get_context();
-        const dnnl::engine eng = dnnl::sycl_interop::make_engine(dev, ctx);
-        const dnnl::stream stream = dnnl::sycl_interop::make_stream(eng, q);
+    static inline void row_gemm(ggml_backend_sycl_context & ctx, bool a_trans, bool b_trans, int m, int n, int k,
+                                const void * a, dt at, const void * b, dt bt, void * c, dt ct, const queue_ptr & q) {
+        auto stream = ctx.stream_dnnl(q);
+        auto eng = ctx.engine_dnnl(q);
         dnnl::memory::dims a_dims = { m, k };
         dnnl::memory::dims b_dims = { k, n };
         dnnl::memory::dims c_dims = { m, n };
         const auto a_in_md = dnnl::memory::desc(a_dims, at, a_trans ? tag::ba : tag::ab);
         const auto b_in_md = dnnl::memory::desc(b_dims, bt, b_trans ? tag::ba : tag::ab);
-        const auto c_md = dnnl::memory::desc(c_dims, ct, tag::ab);
-        auto a_mem = dnnl::memory(a_in_md, eng, const_cast<void*>(a));
-        auto b_mem = dnnl::memory(b_in_md, eng, const_cast<void*>(b));
-        auto matmul_pd = dnnl::matmul::primitive_desc(eng, a_in_md, b_in_md, c_md);
-        auto c_mem = dnnl::memory(matmul_pd.dst_desc(), eng, c);
+        const auto c_md = dnnl::memory::desc(c_dims, ct, tag::ab);

-        // Create the primitive.
-        auto matmul_prim = dnnl::matmul(matmul_pd);
-        // Primitive arguments.
-        std::unordered_map<int, dnnl::memory> matmul_args;
-        matmul_args.insert({ DNNL_ARG_SRC, a_mem });
-        matmul_args.insert({ DNNL_ARG_WEIGHTS, b_mem });
-        matmul_args.insert({ DNNL_ARG_DST, c_mem });
+        dnnl::primitive_attr primitive_attr;
+        primitive_attr.set_scratchpad_mode(dnnl::scratchpad_mode::user);

-        matmul_prim.execute(stream, matmul_args);
-    }
-
-
-    static inline void row_gemm(const dnnl::stream& stream, bool a_trans,
-        bool b_trans, int m, int n, int k,
-        const void* a, dt at, const void* b, dt bt, void* c, dt ct)
-    {
-        auto const eng = stream.get_engine();
-        dnnl::memory::dims a_dims = { m, k };
-        dnnl::memory::dims b_dims = { k, n };
-        dnnl::memory::dims c_dims = { m, n };
-        const auto a_in_md = dnnl::memory::desc(a_dims, at, a_trans ? tag::ba : tag::ab);
-        const auto b_in_md = dnnl::memory::desc(b_dims, bt, b_trans ? tag::ba : tag::ab);
-        const auto c_md = dnnl::memory::desc(c_dims, ct, tag::ab);
-        auto a_mem = dnnl::memory(a_in_md, eng, const_cast<void*>(a));
-        auto b_mem = dnnl::memory(b_in_md, eng, const_cast<void*>(b));
-        auto matmul_pd = dnnl::matmul::primitive_desc(eng, a_in_md, b_in_md, c_md);
+        auto a_mem = dnnl::memory(a_in_md, eng, (void *) a);
+        auto b_mem = dnnl::memory(b_in_md, eng, (void *) b);
+        auto matmul_pd = dnnl::matmul::primitive_desc(eng, a_in_md, b_in_md, c_md, primitive_attr);
         auto c_mem = dnnl::memory(matmul_pd.dst_desc(), eng, c);

-        // Create the primitive.
+        auto scratchpad_md = matmul_pd.scratchpad_desc();
+        auto scratchpad_mem = ctx.get_scratchpad_mem(scratchpad_md, eng, q);
         auto matmul_prim = dnnl::matmul(matmul_pd);
-        // Primitive arguments.
+
         std::unordered_map<int, dnnl::memory> matmul_args;
         matmul_args.insert({ DNNL_ARG_SRC, a_mem });
         matmul_args.insert({ DNNL_ARG_WEIGHTS, b_mem });
         matmul_args.insert({ DNNL_ARG_DST, c_mem });
+        matmul_args.insert({ DNNL_ARG_SCRATCHPAD, scratchpad_mem });

         matmul_prim.execute(stream, matmul_args);
     }

‎ggml/src/ggml-sycl/ggml-sycl.cpp

Lines changed: 6 additions & 6 deletions
@@ -2629,9 +2629,9 @@ inline void ggml_sycl_op_mul_mat_sycl(
         const to_fp32_sycl_t to_fp32_sycl = ggml_get_to_fp32_sycl(GGML_TYPE_F16);
         to_fp32_sycl(dst_f16.get(), dst_dd_i, row_diff*src1_ncols, stream);
 #else
-        auto dnnl_stream = ctx.stream_dnnl(stream);
-        DnnlGemmWrapper::row_gemm(dnnl_stream, false, true, src1_ncols, row_diff, ne10, src1_ptr, DnnlGemmWrapper::to_dt<sycl::half>(),
-            src0_ptr, DnnlGemmWrapper::to_dt<sycl::half>(), dst_f16.get(), DnnlGemmWrapper::to_dt<sycl::half>());
+        DnnlGemmWrapper::row_gemm(ctx, false, true, src1_ncols, row_diff, ne10, src1_ptr,
+                                  DnnlGemmWrapper::to_dt<sycl::half>(), src0_ptr, DnnlGemmWrapper::to_dt<sycl::half>(),
+                                  dst_f16.get(), DnnlGemmWrapper::to_dt<sycl::half>(), stream);
         const to_fp32_sycl_t to_fp32_sycl = ggml_get_to_fp32_sycl(GGML_TYPE_F16);
         to_fp32_sycl(dst_f16.get(), dst_dd_i, row_diff* src1_ncols, stream);
 #endif
@@ -2670,9 +2670,9 @@
             dst_dd_i, ldc)));
 # endif
 #else
-        auto dnnl_stream = ctx.stream_dnnl(stream);
-        DnnlGemmWrapper::row_gemm(dnnl_stream, false, true, src1_ncols, row_diff, ne10, src1_ddf1_i, DnnlGemmWrapper::to_dt<float>(),
-            src0_ddf_i, DnnlGemmWrapper::to_dt<float>(), dst_dd_i, DnnlGemmWrapper::to_dt<float>());
+        DnnlGemmWrapper::row_gemm(ctx, false, true, src1_ncols, row_diff, ne10, src1_ddf1_i,
+                                  DnnlGemmWrapper::to_dt<float>(), src0_ddf_i, DnnlGemmWrapper::to_dt<float>(),
+                                  dst_dd_i, DnnlGemmWrapper::to_dt<float>(), stream);
 #endif
     }
     GGML_UNUSED(dst);
