Introduce New Lookup-Table(LUT)-Based Matrix Multiplication Method (TMAC) #13206


Open · wants to merge 20 commits into master
Conversation


@QingtaoLi1 commented Apr 30, 2025

This is a re-submission of PR #10181. We have refactored all the code to meet the requirements and to follow the constructive discussions with @slaren in the previous PR.

Unlike #10181, we integrate all the LUT code under ggml/, so no third-party dependencies are needed and the CMakeLists.txt changes are minor.

Instead of a single new data type INT_N, this time we introduce a series of TMAC_* data types to avoid loading external meta information. The new data types include one for BitNet-like models (one tensor with one scale value) and several for GPTQ models (group-quantized, e.g. w2g64). We have listed some common GPTQ dtypes, and it is easy to extend to more bit widths and group sizes. Following the existing data types, _0 means no zero points and _1 means zero points are present.

(image: table of the new TMAC_* data types)
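To make the naming scheme concrete, here is a minimal sketch of what a group-quantized "_1" type such as w2g64 could look like conceptually. The struct and field names are hypothetical; the actual TMAC_* types in this PR pack weights in a LUT-friendly layout rather than per-group structs.

#include <cstdint>

// Hypothetical illustration only: a 2-bit, group-size-64 quantized group in the
// "_1" flavor (per-group scale and zero point). A "_0" type would drop the zero
// point, and TMAC_BN_0 keeps a single scale for the whole tensor instead.
struct w2g64_1_group_sketch {
    uint8_t  qs[64 * 2 / 8];  // 64 weights x 2 bits = 16 bytes of packed quants
    uint16_t scale;           // per-group scale (fp16)
    uint16_t zero;            // per-group zero point (fp16)
};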

How to Use It

Since there are no third-party dependencies, the build/run pipeline is quite similar to the existing one.

# Some examples. Currently, we support GPTQ models with "desc_act=False".
huggingface-cli download 1bitLLM/bitnet_b1_58-3B --local-dir %MODEL_DIR%
huggingface-cli download jakiAJK/DeepSeek-R1-Distill-Qwen-7B_GPTQ-int4 --local-dir %MODEL_DIR%
huggingface-cli download kaitchup/Qwen3-1.7B-autoround-4bit-gptq --local-dir %MODEL_DIR%
huggingface-cli download ChenMnZ/Llama-3-8b-instruct-EfficientQAT-w2g64-GPTQ --local-dir %MODEL_DIR%
huggingface-cli download jakiAJK/microsoft-phi-4_GPTQ-int4 --local-dir %MODEL_DIR%

# Convert HF model to GGUF (GPTQ models and BitNet)
python convert_hf_to_gguf.py %MODEL_DIR% --outtype auto --outfile %MODEL_OUTPUT_PATH%.gguf --enable-t-mac
python convert_hf_to_gguf.py %MODEL_DIR% --outtype tmac_bn_0 --outfile %MODEL_OUTPUT_PATH%.gguf --enable-t-mac

# Build on Windows
mkdir build
cd build
cmake .. -DGGML_TMAC=ON -DCMAKE_BUILD_TYPE=Release -DCMAKE_C_COMPILER=clang -DCMAKE_CXX_COMPILER=clang++ -T ClangCL
cmake --build . --target llama-cli llama-bench llama-quantize llama-perplexity --config Release

# Run on Windows
.\bin\Release\llama-cli.exe -m %MODEL_OUTPUT_PATH%.gguf -no-cnv -n 128 -p "Hello, world!" -ngl 0 -c 2048 -t 4

For other devices, e.g. Apple, the scripts are similar or identical.

Speed

TBD

Model size

TBD

Perplexity

TBD

Note

There is a desc_act option in the GPTQ model config: True means the weight columns are re-ordered during quantization, while False means the weights stay in their original order. We only support desc_act=False for now. desc_act=True could likely be supported as well, but it needs a fair amount of engineering effort, so we think it is better to address it in a follow-up PR if required.
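For readers unfamiliar with the option, here is a minimal sketch (hypothetical helper, not code from this PR) of the difference desc_act makes when picking the quantization group for a weight column; with desc_act=True the checkpoint's g_idx tensor decides the group, so the contiguous-group assumption no longer holds.

#include <cstdint>
#include <vector>

// Illustration only: which scale/zero-point group a weight column belongs to.
// desc_act=False : column j uses group j / group_size (contiguous groups).
// desc_act=True  : column j uses group g_idx[j] (columns re-ordered by GPTQ).
int column_group_sketch(const std::vector<int32_t> &g_idx,
                        int j, int group_size, bool desc_act) {
    return desc_act ? g_idx[j] : j / group_size;
}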

TODO list

[x] Adapt and test Bitnet.
[x] Adapt and test Q4_0 and TQ types.
[x] Support BF16 model conversion.
[x] Support F16 scales and zero points.
[ ] Fix T-MAC gguf quantization of embed/output_weights.
[ ] Support kernel tuning with threadpool. This will probably break the current build encapsulation of targets llama/ggml/ggml-cpu/ggml-base.

[WIP] Fit llama.cpp build visibility. Runtime error-free. Wrong outputs.

[WIP] Remove some deprecated codes.

[Fix] ggml_tmac_transform_tensor should use *data as the original data.
And gather code logics in ggml_tmac_can_mul_mat.

Change tuning profile time back to 5000ms.

Hard code bits/groupsize/sym. GPTQ Llama correct.

Unify quantization_config loading.
The github-actions bot added the labels python (python script changes) and ggml (changes relating to the ggml tensor library for machine learning) on Apr 30, 2025.
@zhouwg (Contributor) commented May 1, 2025

I only found this PR today (I left GitHub on 07/18/2024, came back on 01/29/2025, and missed a lot of wonderful/standout PRs), then found the impressive paper behind it, and I am reading the paper again and again. I am currently working on an int8-based matrix multiplication implementation for the Qualcomm Hexagon NPU; this standout approach from MSRA might be usable in ggml-hexagon (a dedicated llama.cpp backend for the Qualcomm Hexagon NPU).

@zhouwg (Contributor) commented May 2, 2025

After reading your team's outstanding paper again, I tried to dig into the source code of this PR, but I cannot build it:

zhouwg:$ make
[  1%] Building C object ggml/src/CMakeFiles/ggml-base.dir/ggml.c.o
[  1%] Building C object ggml/src/CMakeFiles/ggml-base.dir/ggml-alloc.c.o
[  2%] Building CXX object ggml/src/CMakeFiles/ggml-base.dir/ggml-backend.cpp.o
[  2%] Building CXX object ggml/src/CMakeFiles/ggml-base.dir/ggml-opt.cpp.o
[  2%] Building CXX object ggml/src/CMakeFiles/ggml-base.dir/ggml-threading.cpp.o
[  3%] Building C object ggml/src/CMakeFiles/ggml-base.dir/ggml-quants.c.o
[  3%] Building CXX object ggml/src/CMakeFiles/ggml-base.dir/gguf.cpp.o
[  4%] Linking CXX shared library ../../bin/libggml-base.so
[  4%] Built target ggml-base
[  4%] Building C object ggml/src/CMakeFiles/ggml-cpu.dir/ggml-cpu/ggml-cpu.c.o
[  5%] Building CXX object ggml/src/CMakeFiles/ggml-cpu.dir/ggml-cpu/ggml-cpu.cpp.o
[  5%] Building CXX object ggml/src/CMakeFiles/ggml-cpu.dir/ggml-cpu/ggml-cpu-aarch64.cpp.o
[  6%] Building CXX object ggml/src/CMakeFiles/ggml-cpu.dir/ggml-cpu/ggml-cpu-hbm.cpp.o
[  6%] Building C object ggml/src/CMakeFiles/ggml-cpu.dir/ggml-cpu/ggml-cpu-quants.c.o
[  6%] Building CXX object ggml/src/CMakeFiles/ggml-cpu.dir/ggml-cpu/ggml-cpu-traits.cpp.o
[  7%] Building CXX object ggml/src/CMakeFiles/ggml-cpu.dir/ggml-cpu/amx/amx.cpp.o
[  7%] Building CXX object ggml/src/CMakeFiles/ggml-cpu.dir/ggml-cpu/amx/mmq.cpp.o
[  8%] Building CXX object ggml/src/CMakeFiles/ggml-cpu.dir/ggml-cpu/tmac/tmac.cpp.o
In file included from /home/zhouwg/kantvai/tmac_llama.cpp/ggml/src/ggml-cpu/tmac/tmac.cpp:8:
/home/zhouwg/kantvai/tmac_llama.cpp/ggml/src/ggml-cpu/tmac/lut_mul_mat.h:22:62: error: declaration of ‘std::unordered_map<std::__cxx11::basic_string<char>, tmac_tensor_extra*> ggml::cpu::tmac::tensor_traits::tmac_tensor_extra’ changes meaning of ‘tmac_tensor_extra’ [-fpermissive]
   22 |         std::unordered_map<std::string, tmac_tensor_extra *> tmac_tensor_extra;
      |                                                              ^~~~~~~~~~~~~~~~~
/home/zhouwg/kantvai/tmac_llama.cpp/ggml/src/ggml-cpu/tmac/lut_mul_mat.h:12:8: note: ‘tmac_tensor_extra’ declared here as ‘struct tmac_tensor_extra’
   12 | struct tmac_tensor_extra {
      |        ^~~~~~~~~~~~~~~~~
make[2]: *** [ggml/src/CMakeFiles/ggml-cpu.dir/build.make:188: ggml/src/CMakeFiles/ggml-cpu.dir/ggml-cpu/tmac/tmac.cpp.o] Error 1
make[1]: *** [CMakeFiles/Makefile2:1751: ggml/src/CMakeFiles/ggml-cpu.dir/all] Error 2
make: *** [Makefile:146: all] Error 2
zhouwg:$ git branch
* 202504_tmac
  dc/matmul
  master
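For context on the error above: GCC rejects a class member whose name reuses the name of an outer struct, because the declaration "changes the meaning" of that identifier inside the class. A minimal reproduction and one possible fix, with a hypothetical member name (not necessarily how the PR resolves it):

#include <string>
#include <unordered_map>

struct tmac_tensor_extra { /* ... */ };

class tensor_traits_sketch {
    // error: declaration of '... tmac_tensor_extra' changes meaning of
    // 'tmac_tensor_extra' [-fpermissive]
    // std::unordered_map<std::string, tmac_tensor_extra *> tmac_tensor_extra;

    // one possible fix: give the member a name distinct from the struct
    std::unordered_map<std::string, tmac_tensor_extra *> tensor_extras;
};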

@Zant12 commented May 2, 2025

Plain make may be unsupported; try cmake on Linux:

cd tmac_llama.cpp
mkdir build && cd build
cmake .. -DGGML_TMAC=ON -DCMAKE_BUILD_TYPE=Release -DCMAKE_C_COMPILER=clang -DCMAKE_CXX_COMPILER=clang++ -DLLAMA_CURL=OFF
cmake --build . --target llama-cli llama-bench llama-quantize llama-perplexity --config Release

@zhouwg (Contributor) commented May 3, 2025

It works on my Linux machine. Thanks so much!

@QingtaoLi1 (Author) commented:

Hi @slaren, could you take a quick look at this new pull request and see if it is basically all right?

Review comment on ggml/src/kompute (outdated):


Is there a reason the kompute module is added back here, instead of staying out of ggml under ggml-kompute?

@QingtaoLi1 (Author) replied:

@inforithmics Nope, it was added back by mistake. I have removed it.

@Zijie-Tian commented May 14, 2025

Unexpectedly SLOW performance on Apple M4 MAX for Llama-3-8b-EfficientQAT-w2g128-GPTQ compared to AGX Orin.

I used the following command to run your code on the AGX Orin and the M4 Max:

./build-arm64/bin/llama-cli -m /gguf/Llama-3-8b-EfficientQAT-w2g128-GPTQ-GGUF/llama-3-8b-w2g128.gguf  -p Hi -n 1 -ngl 0

The performance on the Apple M4 MAX is considerably slower than on the NVIDIA AGX Orin.

On the M4 Max:

load_tensors: loading model tensors, this can take a while... (mmap = true)
Tuned kernel config: M=4096, N=1, K=4096, bm=256, n=8, kfactor=16, bits=2, g=4, ngroups_per_elem=2, q_group_size=128, act_group_size=64  TIME: 9.0884 ms
Tuned kernel config: M=4096, N=1, K=4096, bm=512, n=8, kfactor=16, bits=2, g=4, ngroups_per_elem=2, q_group_size=128, act_group_size=64  TIME: 9.0201 ms
Tuned kernel config: M=4096, N=1, K=4096, bm=1024, n=8, kfactor=16, bits=2, g=4, ngroups_per_elem=2, q_group_size=128, act_group_size=64         TIME: 8.9925 ms
Tuned kernel config: M=4096, N=1, K=4096, bm=2048, n=8, kfactor=16, bits=2, g=4, ngroups_per_elem=2, q_group_size=128, act_group_size=64         TIME: 9.0045 ms
Tuned kernel config: M=1024, N=1, K=4096, bm=256, n=8, kfactor=16, bits=2, g=4, ngroups_per_elem=2, q_group_size=128, act_group_size=64  TIME: 2.2312 ms
Tuned kernel config: M=1024, N=1, K=4096, bm=512, n=8, kfactor=16, bits=2, g=4, ngroups_per_elem=2, q_group_size=128, act_group_size=64  TIME: 2.2180 ms
Tuned kernel config: M=1024, N=1, K=4096, bm=1024, n=8, kfactor=16, bits=2, g=4, ngroups_per_elem=2, q_group_size=128, act_group_size=64         TIME: 2.2191 ms
Tuned kernel config: M=1024, N=1, K=4096, bm=2048, n=8, kfactor=16, bits=2, g=4, ngroups_per_elem=2, q_group_size=128, act_group_size=64         TIME: 2.2125 ms
Tuned kernel config: M=14336, N=1, K=4096, bm=256, n=8, kfactor=16, bits=2, g=4, ngroups_per_elem=2, q_group_size=128, act_group_size=64         TIME: 30.9397 ms
Tuned kernel config: M=14336, N=1, K=4096, bm=512, n=8, kfactor=16, bits=2, g=4, ngroups_per_elem=2, q_group_size=128, act_group_size=64         TIME: 30.7267 ms
Tuned kernel config: M=14336, N=1, K=4096, bm=1024, n=8, kfactor=16, bits=2, g=4, ngroups_per_elem=2, q_group_size=128, act_group_size=64        TIME: 30.6523 ms
Tuned kernel config: M=14336, N=1, K=4096, bm=2048, n=8, kfactor=16, bits=2, g=4, ngroups_per_elem=2, q_group_size=128, act_group_size=64        TIME: 30.5683 ms
Tuned kernel config: M=4096, N=1, K=14336, bm=256, n=8, kfactor=16, bits=2, g=4, ngroups_per_elem=2, q_group_size=128, act_group_size=64         TIME: 30.9330 ms
Tuned kernel config: M=4096, N=1, K=14336, bm=512, n=8, kfactor=16, bits=2, g=4, ngroups_per_elem=2, q_group_size=128, act_group_size=64         TIME: 30.8253 ms
Tuned kernel config: M=4096, N=1, K=14336, bm=1024, n=8, kfactor=16, bits=2, g=4, ngroups_per_elem=2, q_group_size=128, act_group_size=64        TIME: 30.7018 ms
Tuned kernel config: M=4096, N=1, K=14336, bm=2048, n=8, kfactor=16, bits=2, g=4, ngroups_per_elem=2, q_group_size=128, act_group_size=64        TIME: 30.6082 ms
load_tensors: offloading 0 repeating layers to GPU
load_tensors: offloaded 0/33 layers to GPU
load_tensors:         TMAC model buffer size =  4085.02 MiB
.....................................................
llama_context: constructing llama_context
llama_context: n_seq_max     = 1
llama_context: n_ctx         = 4096
llama_context: n_ctx_per_seq = 4096
llama_context: n_batch       = 2048
llama_context: n_ubatch      = 512
llama_context: causal_attn   = 1
llama_context: flash_attn    = 1
llama_context: freq_base     = 500000.0
llama_context: freq_scale    = 1
llama_context: n_ctx_per_seq (4096) < n_ctx_train (8192) -- the full capacity of the model will not be utilized
ggml_metal_init: allocating
ggml_metal_init: found device: Apple M4 Max
ggml_metal_init: picking default device: Apple M4 Max
ggml_metal_load_library: using embedded metal library
ggml_metal_init: GPU name:   Apple M4 Max
ggml_metal_init: GPU family: MTLGPUFamilyApple9  (1009)
ggml_metal_init: GPU family: MTLGPUFamilyCommon3 (3003)
ggml_metal_init: GPU family: MTLGPUFamilyMetal3  (5001)
ggml_metal_init: simdgroup reduction   = true
ggml_metal_init: simdgroup matrix mul. = true
ggml_metal_init: has residency sets    = true
ggml_metal_init: has bfloat            = true
ggml_metal_init: use bfloat            = false
ggml_metal_init: hasUnifiedMemory      = true
ggml_metal_init: recommendedMaxWorkingSetSize  = 103079.22 MB
ggml_metal_init: skipping kernel_get_rows_bf16                     (not supported)
ggml_metal_init: skipping kernel_mul_mv_bf16_f32                   (not supported)
ggml_metal_init: skipping kernel_mul_mv_bf16_f32_1row              (not supported)
ggml_metal_init: skipping kernel_mul_mv_bf16_f32_l4                (not supported)
ggml_metal_init: skipping kernel_mul_mv_bf16_bf16                  (not supported)
ggml_metal_init: skipping kernel_mul_mv_id_bf16_f32                (not supported)
ggml_metal_init: skipping kernel_mul_mm_bf16_f32                   (not supported)
ggml_metal_init: skipping kernel_mul_mm_id_bf16_f16                (not supported)
ggml_metal_init: skipping kernel_flash_attn_ext_bf16_h64           (not supported)
ggml_metal_init: skipping kernel_flash_attn_ext_bf16_h80           (not supported)
ggml_metal_init: skipping kernel_flash_attn_ext_bf16_h96           (not supported)
ggml_metal_init: skipping kernel_flash_attn_ext_bf16_h112          (not supported)
ggml_metal_init: skipping kernel_flash_attn_ext_bf16_h128          (not supported)
ggml_metal_init: skipping kernel_flash_attn_ext_bf16_h192          (not supported)
ggml_metal_init: skipping kernel_flash_attn_ext_bf16_hk192_hv128   (not supported)
ggml_metal_init: skipping kernel_flash_attn_ext_bf16_h256          (not supported)
ggml_metal_init: skipping kernel_flash_attn_ext_bf16_hk576_hv512   (not supported)
ggml_metal_init: skipping kernel_flash_attn_ext_vec_bf16_h96       (not supported)
ggml_metal_init: skipping kernel_flash_attn_ext_vec_bf16_h128      (not supported)
ggml_metal_init: skipping kernel_flash_attn_ext_vec_bf16_h192      (not supported)
ggml_metal_init: skipping kernel_flash_attn_ext_vec_bf16_hk192_hv128 (not supported)
ggml_metal_init: skipping kernel_flash_attn_ext_vec_bf16_h256      (not supported)
ggml_metal_init: skipping kernel_flash_attn_ext_vec_bf16_hk576_hv512 (not supported)
ggml_metal_init: skipping kernel_cpy_f32_bf16                      (not supported)
ggml_metal_init: skipping kernel_cpy_bf16_f32                      (not supported)
ggml_metal_init: skipping kernel_cpy_bf16_bf16                     (not supported)
llama_context:        CPU  output buffer size =     0.49 MiB
llama_kv_cache_unified: kv_size = 4096, type_k = 'f16', type_v = 'f16', n_layer = 32, can_shift = 1, padding = 256
llama_kv_cache_unified:        CPU KV buffer size =   512.00 MiB
llama_kv_cache_unified: KV self size  =  512.00 MiB, K (f16):  256.00 MiB, V (f16):  256.00 MiB
llama_context:        CPU compute buffer size =   258.50 MiB
llama_context: graph nodes  = 967
llama_context: graph splits = 2 (with bs=512), 1 (with bs=1)
common_init_from_params: setting dry_penalty_last_n to ctx_size = 4096
common_init_from_params: warming up the model with an empty run - please wait ... (--no-warmup to disable)
main: llama threadpool init, n_threads = 12

system_info: n_threads = 12 (n_threads_batch = 12) / 16 | Metal : EMBED_LIBRARY = 1 | CPU : ARM_FMA = 1 | FP16_VA = 1 | LLAMAFILE = 1 | ACCELERATE = 1 | OPENMP = 1 | AARCH64_REPACK = 1 |

sampler seed: 3940678951
sampler params:
        repeat_last_n = 64, repeat_penalty = 1.000, frequency_penalty = 0.000, presence_penalty = 0.000
        dry_multiplier = 0.000, dry_base = 1.750, dry_allowed_length = 2, dry_penalty_last_n = 4096
        top_k = 40, top_p = 0.950, min_p = 0.050, xtc_probability = 0.000, xtc_threshold = 0.100, typical_p = 1.000, top_n_sigma = -1.000, temp = 0.800
        mirostat = 0, mirostat_lr = 0.100, mirostat_ent = 5.000
sampler chain: logits -> logit-bias -> penalties -> dry -> top-n-sigma -> top-k -> typical -> top-p -> min-p -> xtc -> temp-ext -> dist
generate: n_ctx = 4096, n_batch = 2048, n_predict = 1, n_keep = 0

Hi,

llama_perf_sampler_print:    sampling time =       0.17 ms /     2 runs   (    0.08 ms per token, 11834.32 tokens per second)
llama_perf_context_print:        load time =  230953.18 ms
llama_perf_context_print: prompt eval time =       0.00 ms /     1 tokens (    0.00 ms per token,      inf tokens per second)
llama_perf_context_print:        eval time =     820.90 ms /     1 runs   (  820.90 ms per token,     1.22 tokens per second)
llama_perf_context_print:       total time =     828.30 ms /     2 tokens
ggml_metal_free: deallocating

On AGX Orin 64 GB:

load_tensors: loading model tensors, this can take a while... (mmap = true)
Tuned kernel config: M=4096, N=1, K=4096, bm=256, n=8, kfactor=16, bits=2, g=4, ngroups_per_elem=2, q_group_size=128, act_group_size=64  TIME: 0.5989 ms
Tuned kernel config: M=4096, N=1, K=4096, bm=512, n=8, kfactor=16, bits=2, g=4, ngroups_per_elem=2, q_group_size=128, act_group_size=64  TIME: 0.5925 ms
Tuned kernel config: M=4096, N=1, K=4096, bm=1024, n=8, kfactor=16, bits=2, g=4, ngroups_per_elem=2, q_group_size=128, act_group_size=64         TIME: 0.5883 ms
Tuned kernel config: M=4096, N=1, K=4096, bm=2048, n=8, kfactor=16, bits=2, g=4, ngroups_per_elem=2, q_group_size=128, act_group_size=64         TIME: 0.5929 ms
Tuned kernel config: M=1024, N=1, K=4096, bm=256, n=8, kfactor=16, bits=2, g=4, ngroups_per_elem=2, q_group_size=128, act_group_size=64  TIME: 0.1605 ms
Tuned kernel config: M=1024, N=1, K=4096, bm=512, n=8, kfactor=16, bits=2, g=4, ngroups_per_elem=2, q_group_size=128, act_group_size=64  TIME: 0.1592 ms
Tuned kernel config: M=1024, N=1, K=4096, bm=1024, n=8, kfactor=16, bits=2, g=4, ngroups_per_elem=2, q_group_size=128, act_group_size=64         TIME: 0.1579 ms
Tuned kernel config: M=1024, N=1, K=4096, bm=2048, n=8, kfactor=16, bits=2, g=4, ngroups_per_elem=2, q_group_size=128, act_group_size=64         TIME: 0.1590 ms
Tuned kernel config: M=14336, N=1, K=4096, bm=256, n=8, kfactor=16, bits=2, g=4, ngroups_per_elem=2, q_group_size=128, act_group_size=64         TIME: 2.0651 ms
Tuned kernel config: M=14336, N=1, K=4096, bm=512, n=8, kfactor=16, bits=2, g=4, ngroups_per_elem=2, q_group_size=128, act_group_size=64         TIME: 2.0359 ms
Tuned kernel config: M=14336, N=1, K=4096, bm=1024, n=8, kfactor=16, bits=2, g=4, ngroups_per_elem=2, q_group_size=128, act_group_size=64        TIME: 2.0216 ms
Tuned kernel config: M=14336, N=1, K=4096, bm=2048, n=8, kfactor=16, bits=2, g=4, ngroups_per_elem=2, q_group_size=128, act_group_size=64        TIME: 2.0263 ms
Tuned kernel config: M=4096, N=1, K=14336, bm=256, n=8, kfactor=16, bits=2, g=4, ngroups_per_elem=2, q_group_size=128, act_group_size=64         TIME: 2.0720 ms
Tuned kernel config: M=4096, N=1, K=14336, bm=512, n=8, kfactor=16, bits=2, g=4, ngroups_per_elem=2, q_group_size=128, act_group_size=64         TIME: 2.0503 ms
Tuned kernel config: M=4096, N=1, K=14336, bm=1024, n=8, kfactor=16, bits=2, g=4, ngroups_per_elem=2, q_group_size=128, act_group_size=64        TIME: 2.0335 ms
Tuned kernel config: M=4096, N=1, K=14336, bm=2048, n=8, kfactor=16, bits=2, g=4, ngroups_per_elem=2, q_group_size=128, act_group_size=64        TIME: 2.0375 ms
load_tensors:         TMAC model buffer size =  4085.02 MiB
.....................................................
llama_context: constructing llama_context
llama_context: n_seq_max     = 1
llama_context: n_ctx         = 4096
llama_context: n_ctx_per_seq = 4096
llama_context: n_batch       = 2048
llama_context: n_ubatch      = 512
llama_context: causal_attn   = 1
llama_context: flash_attn    = 0
llama_context: freq_base     = 500000.0
llama_context: freq_scale    = 1
llama_context: n_ctx_per_seq (4096) < n_ctx_train (8192) -- the full capacity of the model will not be utilized
llama_context:        CPU  output buffer size =     0.49 MiB
llama_kv_cache_unified: kv_size = 4096, type_k = 'f16', type_v = 'f16', n_layer = 32, can_shift = 1, padding = 32
llama_kv_cache_unified:        CPU KV buffer size =   512.00 MiB
llama_kv_cache_unified: KV self size  =  512.00 MiB, K (f16):  256.00 MiB, V (f16):  256.00 MiB
llama_context:        CPU compute buffer size =   296.01 MiB
llama_context: graph nodes  = 1094
llama_context: graph splits = 1
common_init_from_params: setting dry_penalty_last_n to ctx_size = 4096
common_init_from_params: warming up the model with an empty run - please wait ... (--no-warmup to disable)
main: llama threadpool init, n_threads = 12

system_info: n_threads = 12 (n_threads_batch = 12) / 12 | CPU : NEON = 1 | ARM_FMA = 1 | FP16_VA = 1 | LLAMAFILE = 1 | OPENMP = 1 | AARCH64_REPACK = 1 |

sampler seed: 3584698764
sampler params:
        repeat_last_n = 64, repeat_penalty = 1.000, frequency_penalty = 0.000, presence_penalty = 0.000
        dry_multiplier = 0.000, dry_base = 1.750, dry_allowed_length = 2, dry_penalty_last_n = 4096
        top_k = 40, top_p = 0.950, min_p = 0.050, xtc_probability = 0.000, xtc_threshold = 0.100, typical_p = 1.000, top_n_sigma = -1.000, temp = 0.800
        mirostat = 0, mirostat_lr = 0.100, mirostat_ent = 5.000
sampler chain: logits -> logit-bias -> penalties -> dry -> top-n-sigma -> top-k -> typical -> top-p -> min-p -> xtc -> temp-ext -> dist
generate: n_ctx = 4096, n_batch = 2048, n_predict = 128, n_keep = 0

Hello, world, I’m the guy with the best idea for the world and I just happen to be the guy in charge of the world. And I’m not even going to take a break because I don’t even have to think about that because I’m so brilliant.
I’m not sure where this is going to go but I’ll let you know if it ever gets anywhere.
You have the ability to change the world. To make a difference. To make a difference in your life. To make a difference in the lives of others. To make a difference in the world. To make a difference in the world you live in. To make a

llama_perf_sampler_print:    sampling time =      18.01 ms /   131 runs   (    0.14 ms per token,  7273.74 tokens per second)
llama_perf_context_print:        load time =  319866.72 ms
llama_perf_context_print: prompt eval time =     205.26 ms /     3 tokens (   68.42 ms per token,    14.62 tokens per second)
llama_perf_context_print:        eval time =   12163.65 ms /   127 runs   (   95.78 ms per token,    10.44 tokens per second)
llama_perf_context_print:       total time =   12423.95 ms /   130 tokens

This is contrary to my expectation, as the M4 Max's CPU is generally considered to be 3-5 times faster than the AGX Orin's CPU.

@slaren (Member) commented May 14, 2025

Hi @QingtaoLi1, sorry for the delay. Before going into more depth, can you briefly explain the motivation for implementing this as an extra buffer type in the CPU backend?

The reason I ask is because we would need the T-MAC quantization types to be integrated into ggml seamlessly along with the current quantization types. If the way quantization types are currently represented in ggml is not compatible with T-MAC, we would need to investigate in which ways it could be adapted to support it. The core functions such as ggml_nbytes, ggml_type_size, ggml_block_size, ggml_row_size, etc, can be changed if necessary, but should still work with all types.
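For reference, a rough sketch of how those helpers relate for the existing block-quantized types (paraphrasing ggml's behavior, not quoting its implementation): a row of ne elements consists of ne / block_size blocks of type_size bytes each.

#include <cstddef>
#include <cstdint>

// Sketch of the row-size relation for existing fixed-size-block types.
// Assumes ne is a multiple of the block size, as ggml requires for quantized rows.
size_t row_size_sketch(size_t type_size, int64_t block_size, int64_t ne) {
    return type_size * (size_t)(ne / block_size);
}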

@QingtaoLi1 (Author) commented May 15, 2025

Unexpectedly SLOW performance on Apple M4 MAX for Llama-3-8b-EfficientQAT-w2g128-GPTQ compared to AGX Orin.

@Zijie-Tian It may be a bug. We will test on Apple devices later.

@QingtaoLi1 (Author) commented May 15, 2025

@slaren Okay!

can you explain briefly what was the motivation of implementing this as an extra buffer type in the CPU backend?

For the newly added data types, we need to re-order the weight layout for efficient LUT computation, as well as to adapt the float type of the scales. In the previous PR, you mentioned #10446 (AMX) as an example of packing the weights, so I imitated the AMX implementation when adding the new buffer type.
Doing this in the convert_hf_to_gguf script may be a viable option. However, the conversion would also have to be applied to other types that we can support, such as Q4_0 and the TQ types, because we need the specific weight layout.

If the way quantization types are currently represented in ggml is not compatible with T-MAC, we would need to investigate in which ways it could be adapted to support it.

We have studied the existing ggml types. Our conclusion is that i-quants cannot be supported by the current T-MAC method, because i-quants are vector-quantization methods while T-MAC targets scalar quantization.
For k-quants and some of the simple quant types, the key problem is that we need group_size >= 32 to achieve good efficiency, while most existing types use an (innermost) group_size of 16.

The core functions such as ggml_nbytes, ggml_type_size, ggml_block_size, ggml_row_size, etc, can be changed if necessary, but should still work with all types.

Here we have two main differences.

  1. For types other than TMAC_BN_0, we do not store the weights as "blocks"; instead we store all the weights of a tensor contiguously, followed by all scales and zero points (see the sketch below). This does not affect size or nbytes computation, but it may confuse functions that rely on nb, such as ggml_is_contiguous.
  2. For TMAC_BN_0, a tensor has only one scale, which affects the existing block-based logic even for size computation. For now I have put a patch in ggml_tmac_get_nbytes in lut_mul_mat.cpp.

In our first implementation, this change was indeed made in ggml functions such as ggml_nbytes. The reason I moved it out is build encapsulation: ggml sits lower in the stack than ggml-cpu, so implementing it in ggml.c would mean moving some T-MAC-related logic there as well. I am not sure whether that is appropriate.
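A minimal sketch of the size computation implied by point 1 above (illustrative only, not the PR's ggml_tmac_get_nbytes): all packed weights are stored first, followed by all scales and, for "_1" types, all zero points.

#include <cstddef>
#include <cstdint>

// Illustration only: tensor size for a layout that stores packed weights
// contiguously, then per-group scales, then optional per-group zero points,
// instead of interleaving them per block as existing ggml types do.
size_t tmac_like_nbytes_sketch(int64_t n_elements, int bits, int group_size,
                               size_t scale_bytes, bool has_zero_point) {
    const size_t weight_bytes = (size_t)(n_elements * bits / 8);
    const size_t n_groups     = (size_t)(n_elements / group_size);
    const size_t meta_bytes   = n_groups * scale_bytes * (has_zero_point ? 2 : 1);
    return weight_bytes + meta_bytes;
}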

@QingtaoLi1 (Author) commented:

@Zijie-Tian I have tested the w2g128 model on an M2 Ultra, and it seems to work well now.

@slaren (Member) commented May 22, 2025

I have been thinking about this, and I think it would be ok to add new tensor types that do not conform exactly to the ggml structure of fixed-size blocks organized sequentially in the tensor data stream. To do this, however, we would need to add some safety checks to ensure that these types are not used in an unsupported way, for example, by forbidding creating views of tensors of these types. Functions like ggml_nbytes can be modified to handle these types as a special case.

About the use of extra buffer types, this is intended for cases where the standard layout of a type can be reorganized to perform better on the current hardware. If these types only have one layout, there is no need to use an extra buffer type, and the code should be integrated into the CPU backend normally. ggml_type_traits_cpu could be extended to define a custom matrix multiplication function for these types. Extra buffer types add complexity, increase load time and prevent using mmap, so they should not be used unless strictly necessary. All the processing and memory allocations that are currently being done in ggml_tmac_transform_tensor in response to a call to ggml_backend_tmac_buffer_set_tensor should be removed, the tensor data should be usable in the same format as it is stored on the model files. If necessary the conversion scripts should be modified to allow this.

On the last point, I do not think that we would want to add types that cannot be created using the tools in ggml and llama.cpp. I would expect the quantization code to be integrated into ggml as well, and the types supported by tools such as llama-quantize. This would also be necessary to support these types in other backends and enable the tests from test-backend-ops.
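A minimal sketch of the kind of safety check suggested above (hypothetical names, standalone for illustration; a real check would live wherever ggml creates views):

// Illustration only: views assume data can be addressed through nb[] strides
// over fixed-size blocks, which does not hold when scales/zero points are
// stored after all weights, so such types would be rejected up front.
enum class layout_kind_sketch { block_quantized, tmac_contiguous };

bool view_allowed_sketch(layout_kind_sketch kind) {
    return kind != layout_kind_sketch::tmac_contiguous;
}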

@QingtaoLi1 (Author) commented May 26, 2025

@slaren Thanks for your reply! Let me respond to some of your concerns.

I would expect the quantization code to be integrated into ggml as well, and the types supported by tools such as llama-quantize.

=============== Update ================
@slaren Sorry, I think I made a mistake here. GPTQ-format models can theoretically be converted and quantized by the tools. One difference I can think of is that a GPTQ model does not store a weight in a single ".weight" tensor, but in three or four tensors: ".qweight", ".scales", ".qzeros" and an optional ".g_idx". This breaks the per-tensor loop in llama-quantize and would need major modifications.
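For context, a minimal sketch of how those GPTQ tensors combine into one logical weight element (conceptual formula only; in the checkpoint, .qweight and .qzeros are bit-packed into int32 tensors, and some GPTQ variants offset the zero point by one):

#include <cstdint>

// Illustration only: conceptually, w[i][j] = scales[g][j] * (q[i][j] - zeros[g][j]),
// where g = g_idx[i] (or i / group_size when desc_act=False).
float gptq_dequant_element_sketch(int32_t q, int32_t zero, float scale) {
    return scale * (float)(q - zero);
}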

@QingtaoLi1 (Author) commented May 26, 2025

we would need to add some safety checks to ensure that these types are not used in a unsupported way, for example, by forbidding creating views of tensors of these types. Functions like ggml_nbytes can be modified to handle these types as a special case.

That's great! We can try to move the code and add some of the safety checks. We may need your review to check whether anything is missed.

About the use of extra buffer types, this is intended for cases where the standard layout of a type can be reorganized to perform better on the current hardware.

I think this is our case. We want to adjust some parameters, such as the tile size, to achieve better performance on different devices and thread counts. The weight layout changes accordingly.

Again, there are at least the following trade-offs. The conflicting factors are: complexity of the data types, whether to convert the weight layout in advance, and running speed.

  1. Fix the tile size and accept slower running speed, estimated 10%~20% slower and more in some cases. Slightly more data types are needed to cover different weight shapes.
  2. Add an independent script that decides the parameters before the model conversion script, as in #10181, and add many more data types such as W2G64_1_m512, W2G64_1_m256, W2G64_1_m320.
  3. Add an independent script that decides the parameters before the model conversion script, as in #10181, and use a config file to pass the parameters to the runtime. This is almost the approach of #10181, and the data types can be reduced to only one or two; however, it needs major changes in basic ggml functions.

We think it is better to keep the flexible layout for a simpler implementation. The transformation does take some time: in my experiments on an Intel i7-12700, it is currently about 30 seconds for a 3B model, 1 minute for a 7B model, and 2 minutes for a 14B model. What load time would you expect here?

@nigelzzz commented May 28, 2025

Hi @QingtaoLi1, can we use test-backend-ops to measure FLOPS? I would like to reproduce the results from figures 6 and 7 of the paper. Thanks.

Also, I checked the T-MAC repo; it can use TVM to generate optimized kernel code. How can that be done in this pull request?
