Skip to content

Navigation Menu

Sign in
Appearance settings

Search code, repositories, users, issues, pull requests...

Provide feedback

We read every piece of feedback, and take your input very seriously.

Saved searches

Use saved searches to filter your results more quickly

Appearance settings

Commit 9ffcc9e

Browse files
authored
sycl: cleanup oneDNN related code (ggml-org#12097)
1 parent e046430 commit 9ffcc9e
Copy full SHA for 9ffcc9e

File tree

Expand file tree / Collapse file tree

5 files changed

+88
-64
lines changed
Filter options
Expand file tree / Collapse file tree

5 files changed

+88
-64
lines changed

‎docs/backend/SYCL.md

Copy file name to clipboardExpand all lines: docs/backend/SYCL.md
+11-2Lines changed: 11 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -237,6 +237,15 @@ cmake -B buildWithCublas -DCMAKE_CXX_COMPILER=icpx -DCMAKE_C_COMPILER=icx -DENAB
237237
cmake --build buildWithCublas --config Release
238238
```
239239

240+
**oneDNN**: The current oneDNN releases *(shipped with the oneAPI base-toolkit)* do not include the NVIDIA backend. Therefore, oneDNN must be compiled from source to enable the NVIDIA target:
241+
242+
```sh
243+
git clone https://github.com/oneapi-src/oneDNN.git
244+
cd oneDNN
245+
cmake -GNinja -Bbuild-nvidia -DDNNL_CPU_RUNTIME=DPCPP -DDNNL_GPU_RUNTIME=DPCPP -DDNNL_GPU_VENDOR=NVIDIA -DONEDNN_BUILD_GRAPH=OFF -DCMAKE_C_COMPILER=icx -DCMAKE_CXX_COMPILER=icpx
246+
cmake --build build-nvidia --config Release
247+
```
248+
240249
- **Adding support to AMD GPUs**
241250

242251
**oneAPI Plugin**: In order to enable SYCL support on AMD GPUs, please install the [Codeplay oneAPI Plugin for AMD GPUs](https://developer.codeplay.com/products/oneapi/amd/download). As with Nvidia GPUs, the user should also make sure the plugin version matches the installed base toolkit.
@@ -327,10 +336,10 @@ export CPLUS_INCLUDE_DIR=/path/to/oneMKL/include:$CPLUS_INCLUDE_DIR
327336
GGML_SYCL_DEVICE_ARCH=sm_80 # Example architecture
328337

329338
# Option 1: Use FP32 (recommended for better performance in most cases)
330-
cmake -B build -DGGML_SYCL=ON -DGGML_SYCL_TARGET=NVIDIA -DGGML_SYCL_DEVICE_ARCH=${GGML_SYCL_DEVICE_ARCH} -DCMAKE_C_COMPILER=icx -DCMAKE_CXX_COMPILER=icpx
339+
cmake -B build -DGGML_SYCL=ON -DGGML_SYCL_TARGET=NVIDIA -DGGML_SYCL_DEVICE_ARCH=${GGML_SYCL_DEVICE_ARCH} -DCMAKE_C_COMPILER=icx -DCMAKE_CXX_COMPILER=icpx -DDNNL_DIR=/path/to/oneDNN/build-nvidia/install/lib/cmake/dnnl
331340

332341
# Option 2: Use FP16
333-
cmake -B build -DGGML_SYCL=ON -DGGML_SYCL_TARGET=NVIDIA -DGGML_SYCL_DEVICE_ARCH=${GGML_SYCL_DEVICE_ARCH} -DCMAKE_C_COMPILER=icx -DCMAKE_CXX_COMPILER=icpx -DGGML_SYCL_F16=ON
342+
cmake -B build -DGGML_SYCL=ON -DGGML_SYCL_TARGET=NVIDIA -DGGML_SYCL_DEVICE_ARCH=${GGML_SYCL_DEVICE_ARCH} -DCMAKE_C_COMPILER=icx -DCMAKE_CXX_COMPILER=icpx -DGGML_SYCL_F16=ON -DDNNL_DIR=/path/to/oneDNN/build-nvidia/install/lib/cmake/dnnl
334343

335344
# build all binary
336345
cmake --build build --config Release -j -v

‎ggml/src/ggml-sycl/CMakeLists.txt

Copy file name to clipboardExpand all lines: ggml/src/ggml-sycl/CMakeLists.txt
+32-12Lines changed: 32 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,38 @@ ggml_add_backend_library(ggml-sycl
2323
../../include/ggml-sycl.h
2424
)
2525

26+
find_package(DNNL)
27+
set(GGML_SYCL_DNNL 0)
28+
if(DNNL_FOUND)
29+
if (DEFINED ENV{ONEAPI_ROOT} AND NOT DEFINED DNNL_GPU_VENDOR)
30+
# Assuming oneDNN packaged with oneapi release is used which
31+
# supports only intel target
32+
set(DNNL_GPU_VENDOR "INTEL")
33+
if(NOT "${GGML_SYCL_TARGET}" STREQUAL "INTEL")
34+
message(WARNING "oneDNN builds bundled with oneapi release only support INTEL target")
35+
endif()
36+
endif()
37+
38+
# Verify oneDNN was compiled for the same target as llama
39+
if("${GGML_SYCL_TARGET}" STREQUAL "${DNNL_GPU_VENDOR}")
40+
target_link_libraries(ggml-sycl PRIVATE DNNL::dnnl)
41+
set(GGML_SYCL_DNNL 1)
42+
get_target_property(CONFIGS DNNL::dnnl IMPORTED_CONFIGURATIONS)
43+
foreach(CONFIG ${CONFIGS})
44+
get_target_property(DNNL_LIB DNNL::dnnl IMPORTED_LOCATION_${CONFIG})
45+
message(STATUS "Found oneDNN: ${DNNL_LIB}")
46+
endforeach()
47+
else()
48+
message(WARNING
49+
"oneDNN must be compiled for the same target as llama.cpp.
50+
llama.cpp: ${GGML_SYCL_TARGET}, oneDNN: ${DNNL_GPU_VENDOR}.
51+
Disabling oneDNN support.")
52+
endif()
53+
else()
54+
message(STATUS "oneDNN not found, disabling oneDNN support")
55+
endif()
56+
target_compile_definitions(ggml-sycl PRIVATE GGML_SYCL_DNNL=${GGML_SYCL_DNNL})
57+
2658
if (GGML_SYCL_F16)
2759
if (GGML_SYCL_TARGET STREQUAL "AMD")
2860
message(WARNING "AMD target does not entirely support FP16 in the SYCL backend.")
@@ -48,18 +80,6 @@ file(GLOB GGML_HEADERS_SYCL "*.hpp")
4880
file(GLOB GGML_SOURCES_SYCL "*.cpp")
4981
target_sources(ggml-sycl PRIVATE ${GGML_HEADERS_SYCL} ${GGML_SOURCES_SYCL})
5082

51-
find_package(DNNL)
52-
message("-- DNNL found:" ${DNNL_FOUND})
53-
54-
if (GGML_SYCL_TARGET STREQUAL "INTEL")
55-
add_compile_definitions(GGML_SYCL_DNNL=${DNNL_FOUND})
56-
else()
57-
add_compile_definitions(GGML_SYCL_DNNL=0)
58-
endif()
59-
60-
if (${DNNL_FOUND} AND GGML_SYCL_TARGET STREQUAL "INTEL")
61-
target_link_libraries(ggml-sycl PRIVATE DNNL::dnnl)
62-
endif()
6383

6484
if (WIN32)
6585
find_package(IntelSYCL REQUIRED)

‎ggml/src/ggml-sycl/common.hpp

Copy file name to clipboardExpand all lines: ggml/src/ggml-sycl/common.hpp
+27-1Lines changed: 27 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -170,7 +170,6 @@ static size_t g_scratch_offset = 0;
170170
int get_current_device_id();
171171

172172
inline dpct::err0 ggml_sycl_set_device(const int device) try {
173-
174173
int current_device_id;
175174
SYCL_CHECK(CHECK_TRY_ERROR(current_device_id = get_current_device_id()));
176175

@@ -242,6 +241,14 @@ struct ggml_sycl_pool_alloc {
242241
}
243242
}
244243

244+
T * realloc(size_t size) {
245+
GGML_ASSERT(pool != nullptr);
246+
if (ptr)
247+
pool->free(ptr, actual_size);
248+
ptr = (T *) pool->alloc(size * sizeof(T), &this->actual_size);
249+
return ptr;
250+
}
251+
245252
// size is in number of elements
246253
T * alloc(size_t size) {
247254
GGML_ASSERT(pool != nullptr);
@@ -371,10 +378,29 @@ struct ggml_backend_sycl_context {
371378
dnnl::stream stream_dnnl() {
372379
return stream_dnnl(device, 0);
373380
}
381+
dnnl::memory get_scratchpad_mem(const dnnl::memory::desc & scratchpad_md,
382+
const dnnl::engine & eng, const queue_ptr q) {
383+
ggml_sycl_pool_alloc<uint8_t> * pool;
384+
auto it = scratchpad_map.find(q);
385+
if (it == scratchpad_map.end()) {
386+
scratchpad_map[q] = std::make_unique<ggml_sycl_pool_alloc<uint8_t>>(this->pool());
387+
pool = scratchpad_map[q].get();
388+
} else {
389+
pool = it->second.get();
390+
}
391+
392+
size_t scratchpad_size = scratchpad_md.get_size();
393+
if (scratchpad_size > pool->actual_size) {
394+
pool->realloc(scratchpad_size);
395+
}
396+
void * mem_ptr = pool->get();
397+
return dnnl::memory(scratchpad_md, eng, mem_ptr);
398+
}
374399
#endif
375400

376401
// pool
377402
std::unique_ptr<ggml_sycl_pool> pools[GGML_SYCL_MAX_DEVICES];
403+
std::unordered_map<sycl::queue *, std::unique_ptr<ggml_sycl_pool_alloc<uint8_t>>> scratchpad_map;
378404

379405
std::unique_ptr<ggml_sycl_pool> host_pools[GGML_SYCL_MAX_DEVICES];
380406

‎ggml/src/ggml-sycl/gemm.hpp

Copy file name to clipboardExpand all lines: ggml/src/ggml-sycl/gemm.hpp
+12-43Lines changed: 12 additions & 43 deletions
Original file line numberDiff line numberDiff line change
@@ -13,9 +13,6 @@
1313
#ifndef GGML_SYCL_GEMM_HPP
1414
#define GGML_SYCL_GEMM_HPP
1515

16-
#include <fstream>
17-
#include <iostream>
18-
1916
#include "ggml-sycl.h"
2017

2118
#if GGML_SYCL_DNNL
@@ -35,62 +32,34 @@ class DnnlGemmWrapper {
3532
else static_assert(0);
3633
}
3734

38-
static inline void row_gemm(sycl::queue& q, bool a_trans,
39-
bool b_trans, int m, int n, int k,
40-
const void* a, dt at, const void* b, dt bt, void* c, dt ct)
41-
{
42-
// Get the device associated with the queue
43-
sycl::device dev = q.get_device();
44-
// Get the context associated with the queue
45-
sycl::context ctx = q.get_context();
46-
const dnnl::engine eng = dnnl::sycl_interop::make_engine(dev, ctx);
47-
const dnnl::stream stream = dnnl::sycl_interop::make_stream(eng, q);
35+
static inline void row_gemm(ggml_backend_sycl_context & ctx, bool a_trans, bool b_trans, int m, int n, int k,
36+
const void * a, dt at, const void * b, dt bt, void * c, dt ct, const queue_ptr & q) {
37+
auto stream = ctx.stream_dnnl(q);
38+
auto eng = ctx.engine_dnnl(q);
4839
dnnl::memory::dims a_dims = { m, k };
4940
dnnl::memory::dims b_dims = { k, n };
5041
dnnl::memory::dims c_dims = { m, n };
5142
const auto a_in_md = dnnl::memory::desc(a_dims, at, a_trans ? tag::ba : tag::ab);
5243
const auto b_in_md = dnnl::memory::desc(b_dims, bt, b_trans ? tag::ba : tag::ab);
53-
const auto c_md = dnnl::memory::desc(c_dims, ct, tag::ab);
54-
auto a_mem = dnnl::memory(a_in_md, eng, const_cast<void*>(a));
55-
auto b_mem = dnnl::memory(b_in_md, eng, const_cast<void*>(b));
56-
auto matmul_pd = dnnl::matmul::primitive_desc(eng, a_in_md, b_in_md, c_md);
57-
auto c_mem = dnnl::memory(matmul_pd.dst_desc(), eng, c);
44+
const auto c_md = dnnl::memory::desc(c_dims, ct, tag::ab);
5845

59-
// Create the primitive.
60-
auto matmul_prim = dnnl::matmul(matmul_pd);
61-
// Primitive arguments.
62-
std::unordered_map<int, dnnl::memory> matmul_args;
63-
matmul_args.insert({ DNNL_ARG_SRC, a_mem });
64-
matmul_args.insert({ DNNL_ARG_WEIGHTS, b_mem });
65-
matmul_args.insert({ DNNL_ARG_DST, c_mem });
46+
dnnl::primitive_attr primitive_attr;
47+
primitive_attr.set_scratchpad_mode(dnnl::scratchpad_mode::user);
6648

67-
matmul_prim.execute(stream, matmul_args);
68-
}
69-
70-
71-
static inline void row_gemm(const dnnl::stream& stream, bool a_trans,
72-
bool b_trans, int m, int n, int k,
73-
const void* a, dt at, const void* b, dt bt, void* c, dt ct)
74-
{
75-
auto const eng = stream.get_engine();
76-
dnnl::memory::dims a_dims = { m, k };
77-
dnnl::memory::dims b_dims = { k, n };
78-
dnnl::memory::dims c_dims = { m, n };
79-
const auto a_in_md = dnnl::memory::desc(a_dims, at, a_trans ? tag::ba : tag::ab);
80-
const auto b_in_md = dnnl::memory::desc(b_dims, bt, b_trans ? tag::ba : tag::ab);
81-
const auto c_md = dnnl::memory::desc(c_dims, ct, tag::ab);
8249
auto a_mem = dnnl::memory(a_in_md, eng, const_cast<void*>(a));
8350
auto b_mem = dnnl::memory(b_in_md, eng, const_cast<void*>(b));
84-
auto matmul_pd = dnnl::matmul::primitive_desc(eng, a_in_md, b_in_md, c_md);
51+
auto matmul_pd = dnnl::matmul::primitive_desc(eng, a_in_md, b_in_md, c_md, primitive_attr);
8552
auto c_mem = dnnl::memory(matmul_pd.dst_desc(), eng, c);
8653

87-
// Create the primitive.
54+
auto scratchpad_md = matmul_pd.scratchpad_desc();
55+
auto scratchpad_mem = ctx.get_scratchpad_mem(scratchpad_md, eng, q);
8856
auto matmul_prim = dnnl::matmul(matmul_pd);
89-
// Primitive arguments.
57+
9058
std::unordered_map<int, dnnl::memory> matmul_args;
9159
matmul_args.insert({ DNNL_ARG_SRC, a_mem });
9260
matmul_args.insert({ DNNL_ARG_WEIGHTS, b_mem });
9361
matmul_args.insert({ DNNL_ARG_DST, c_mem });
62+
matmul_args.insert({ DNNL_ARG_SCRATCHPAD, scratchpad_mem });
9463

9564
matmul_prim.execute(stream, matmul_args);
9665
}

‎ggml/src/ggml-sycl/ggml-sycl.cpp

Copy file name to clipboardExpand all lines: ggml/src/ggml-sycl/ggml-sycl.cpp
+6-6Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -2058,9 +2058,9 @@ inline void ggml_sycl_op_mul_mat_sycl(
20582058
const to_fp32_sycl_t to_fp32_sycl = ggml_get_to_fp32_sycl(GGML_TYPE_F16, dst);
20592059
to_fp32_sycl(dst_f16.get(), dst_dd_i, row_diff*src1_ncols, stream);
20602060
#else
2061-
auto dnnl_stream = ctx.stream_dnnl(stream);
2062-
DnnlGemmWrapper::row_gemm(dnnl_stream, false, true, src1_ncols, row_diff, ne10, src1_ptr, DnnlGemmWrapper::to_dt<sycl::half>(),
2063-
src0_ptr, DnnlGemmWrapper::to_dt<sycl::half>(), dst_f16.get(), DnnlGemmWrapper::to_dt<sycl::half>());
2061+
DnnlGemmWrapper::row_gemm(ctx, false, true, src1_ncols, row_diff, ne10, src1_ptr,
2062+
DnnlGemmWrapper::to_dt<sycl::half>(), src0_ptr, DnnlGemmWrapper::to_dt<sycl::half>(),
2063+
dst_f16.get(), DnnlGemmWrapper::to_dt<sycl::half>(), stream);
20642064
const to_fp32_sycl_t to_fp32_sycl = ggml_get_to_fp32_sycl(GGML_TYPE_F16, dst);
20652065
to_fp32_sycl(dst_f16.get(), dst_dd_i, row_diff* src1_ncols, stream);
20662066
#endif
@@ -2099,9 +2099,9 @@ inline void ggml_sycl_op_mul_mat_sycl(
20992099
dst_dd_i, ldc)));
21002100
# endif
21012101
#else
2102-
auto dnnl_stream = ctx.stream_dnnl(stream);
2103-
DnnlGemmWrapper::row_gemm(dnnl_stream, false, true, src1_ncols, row_diff, ne10, src1_ddf1_i, DnnlGemmWrapper::to_dt<float>(),
2104-
src0_ddf_i, DnnlGemmWrapper::to_dt<float>(), dst_dd_i, DnnlGemmWrapper::to_dt<float>());
2102+
DnnlGemmWrapper::row_gemm(ctx, false, true, src1_ncols, row_diff, ne10, src1_ddf1_i,
2103+
DnnlGemmWrapper::to_dt<float>(), src0_ddf_i, DnnlGemmWrapper::to_dt<float>(),
2104+
dst_dd_i, DnnlGemmWrapper::to_dt<float>(), stream);
21052105
#endif
21062106
}
21072107
GGML_UNUSED(dst);

0 commit comments

Comments
0 (0)
Morty Proxy This is a proxified and sanitized view of the page, visit original site.