Layer-wise inference that greatly reduces memory usage #4310

sorasoras started this conversation in Ideas
AirLLM optimizes inference memory usage, allowing 70B large language models to run inference on a single 4GB GPU card. No quantization, distillation, pruning or other model compression techniques that would result in degraded model performance are needed.
https://github.com/lyogavin/Anima/tree/main/air_llm
It's an interesting method for reducing inference memory usage.
https://huggingface.co/blog/lyogavin/airllm
I was wondering if this could be implemented in llama.cpp in some way to make small-VRAM GPUs usable.
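
The core idea, as I understand it, is to keep only one transformer layer's weights resident at a time: load a layer, run it over the activations, free it, then load the next. Below is a minimal sketch of that loop (my own illustration, not AirLLM's actual code; load_layer_weights and run_layer are hypothetical stand-ins for reading a layer from disk and evaluating it):

#include <cstdio>
#include <vector>

struct LayerWeights { std::vector<float> w; };   // stand-in for one layer's tensors
struct Activations  { std::vector<float> x; };   // hidden state carried between layers

// Hypothetical: read one layer's tensors from disk (and, on a GPU, upload them).
static LayerWeights load_layer_weights(int layer_idx) {
    (void) layer_idx;
    return LayerWeights{ std::vector<float>(4096, 0.01f) };
}

// Hypothetical: apply one transformer layer (attention + FFN) to the activations.
static void run_layer(const LayerWeights & w, Activations & act) {
    for (size_t i = 0; i < act.x.size(); ++i) act.x[i] += w.w[i % w.w.size()];
}

int main() {
    const int n_layers = 80;  // e.g. a 70B model
    Activations act{ std::vector<float>(4096, 1.0f) };

    for (int il = 0; il < n_layers; ++il) {
        LayerWeights w = load_layer_weights(il);  // peak memory ~ one layer, not the whole model
        run_layer(w, act);
    }                                             // w is freed before the next layer is loaded

    printf("done, x[0] = %f\n", act.x[0]);
    return 0;
}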

This has been supported forever in llama.cpp.
Build with CUBLAS and set -ngl 0 - am I missing something?
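
(For reference, a hedged example of what that looked like with the build flags of this era; the model path is a placeholder:)

make LLAMA_CUBLAS=1
./main -m models/llama-2-70b.Q4_K_M.gguf -ngl 0 -p "Hello"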

@cmp-nct

Yep, that still works, though it's hardly optimized for the use case. I'm not sure there is much practical use in it, though; even if optimized, it's so slow that you are better off doing it on the CPU.

I just tried it: 2.77 tokens/sec on GPU (4090) and 2.72 tokens/sec on CPU (i9).

For modern cards with fp8e4 cuBLAS (that support hasn't been added to this repo yet) I suppose it would beat the CPU, given fp8 is 2-3 times faster in processing than fp16. It's still a strange use case.

@ggerganov

> For modern cards with fp8e4 cuBLAS (that support hasn't been added to this repo yet)

How much work is it to add support for fp8 cuBLAS GEMM?

@cmp-nct

It wasn't that bad, though the codebase is now quite different from what it was half a year ago.
CUDA supports two native fp8 types with different precision and range; I chose the FP8 E4M3 variant as likely the better suited one (the other one is FP8 E5M2, which trades mantissa precision for a wider exponent range):
typedef unsigned char __fp8e4;

You need to use their newer API, which is included with cuBLAS:

#include <cuda_fp16.h>
#include <cuda_fp8.h>
#include <cublasLt.h>

The usual conversion function in ggml_cuda is
typedef void (*to_fp32_cuda_t)(const void * __restrict__ x, float * __restrict__ y, int k, cudaStream_t stream);
So I added
typedef void (*to_fp8e4_cuda_t)(const void * x, __fp8e4 * y, int k, cudaStream_t stream);

Then I modified the dequantize kernels to return fp8e4. It's very similar to supporting half/fp16, using this small conversion routine:

#define __FLOAT2FP8E4(x) __nv_cvt_float_to_fp8(x, __NV_SATFINITE, __NV_E4M3)
//CUDA_R_8F_E4M3 - saturation to maxnorm instead of nan/inf
static __global__ void float_to_fp8e4(const float* src,  __fp8e4 * dst, int n) {
    int idx = blockIdx.x * blockDim.x + threadIdx.x;
    if (idx < n) {
        dst[idx] = __nv_cvt_float_to_fp8((src[idx]), __NV_SATFINITE, __NV_E4M3);
    }
}
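
Not part of the original patch, but a small stand-alone sanity check I'd suggest: round-trip a few values through the kernel above plus a hypothetical inverse kernel (fp8e4_to_float is my addition) to see the E4M3 precision loss and the saturation at its maximum of 448:

#include <cstdio>
#include <cuda_fp16.h>
#include <cuda_fp8.h>

typedef unsigned char __fp8e4;  // same storage type as above

static __global__ void float_to_fp8e4(const float * src, __fp8e4 * dst, int n) {
    int idx = blockIdx.x * blockDim.x + threadIdx.x;
    if (idx < n) {
        dst[idx] = __nv_cvt_float_to_fp8(src[idx], __NV_SATFINITE, __NV_E4M3);
    }
}

// Hypothetical inverse, only for checking: fp8 -> half -> float
static __global__ void fp8e4_to_float(const __fp8e4 * src, float * dst, int n) {
    int idx = blockIdx.x * blockDim.x + threadIdx.x;
    if (idx < n) {
        dst[idx] = __half2float(__half(__nv_cvt_fp8_to_halfraw(src[idx], __NV_E4M3)));
    }
}

int main() {
    const int n = 4;
    const float h_in[n] = {0.1f, 1.0f, 123.456f, 1000.0f};  // 1000 saturates to 448 (E4M3 max)
    float *d_in, *d_out, h_out[n];
    __fp8e4 *d_fp8;
    cudaMalloc(&d_in,  n * sizeof(float));
    cudaMalloc(&d_out, n * sizeof(float));
    cudaMalloc(&d_fp8, n * sizeof(__fp8e4));
    cudaMemcpy(d_in, h_in, n * sizeof(float), cudaMemcpyHostToDevice);
    float_to_fp8e4<<<1, 256>>>(d_in, d_fp8, n);
    fp8e4_to_float<<<1, 256>>>(d_fp8, d_out, n);
    cudaMemcpy(h_out, d_out, n * sizeof(float), cudaMemcpyDeviceToHost);
    for (int i = 0; i < n; ++i) {
        printf("%10.3f -> %10.3f\n", h_in[i], h_out[i]);
    }
    return 0;
}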

Example of a dequantize kernel modification:

static __device__ void dequantize_q4_1_f8e4(const void * __restrict__ vx, const int ib, const int iqs, __fp8e4 & v0, __fp8e4 & v1){
    const block_q4_1 * x = (const block_q4_1 *) vx;

    const __fp8e4 d = __FLOAT2FP8E4(x[ib].dm.x);
    const __fp8e4 m = __FLOAT2FP8E4(x[ib].dm.y);
    const dfloat d_float = x[ib].dm.x;
    const dfloat m_float = x[ib].dm.y;

    const uint8_t vui = x[ib].qs[iqs];

    const int8_t vi0 = vui & 0xF;
    const int8_t vi1 = vui >> 4;

    v0 = __FLOAT2FP8E4(vi0*d_float + m_float);
    v1 = __FLOAT2FP8E4(vi1*d_float + m_float);
}

It's just like using __float2half().
The same applies to all other kernels; it's a few lines each, so they work the same internally but return FP8, as in the sketch below.
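
For illustration, here is what the q4_0 variant could look like (my sketch based on the q4_0 block layout of that era, not copied from ggllm.cpp; q4_0 has a scale but no offset, so the nibbles are centered by subtracting 8):

static __device__ void dequantize_q4_0_f8e4(const void * __restrict__ vx, const int ib, const int iqs, __fp8e4 & v0, __fp8e4 & v1){
    const block_q4_0 * x = (const block_q4_0 *) vx;

    const float d = x[ib].d;  // per-block scale

    const uint8_t vui = x[ib].qs[iqs];

    const int8_t vi0 = vui & 0xF;
    const int8_t vi1 = vui >> 4;

    v0 = __FLOAT2FP8E4((vi0 - 8)*d);
    v1 = __FLOAT2FP8E4((vi1 - 8)*d);
}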

Then I created a wrapper around cuBLAS that takes exactly the same arguments as usual, but uses the new API function internally:
ggml_cuda_op_mul_mat_cublas_f8e4_f32(src0, src1, dst, src0_ddq_i, src0_ddf_i_half, src1_ddf_i, dst_ddf_i, i02, i01_low, i01_high, i1, cudaStream_main);

Internally in the wrapper I did a quick conversion of src1 into fp8 using an on-the-fly kernel:

    __fp8e4* src1_ddf_i_fp8e4;
    size_t src1_size = ne10 * ne11 * sizeof(__fp8e4);
    size_t actual_size = 0;
    src1_ddf_i_fp8e4 = (__fp8e4 *)ggml_cuda_pool_malloc(src1_size,&actual_size); // 16 bit alignment is required (malloc guarantees it)
    float_to_fp8e4<<<(ne10 * ne11 + 255) / 256, 256, 0, cudaStream_main>>>(src1_ddf_i, src1_ddf_i_fp8e4, ne10 * ne11);
    CUDA_CHECK(cudaStreamSynchronize(cudaStream_main));

The new API looks like this:

 ggllm_lggemm(g_cublas_handles[id], CUBLAS_OP_T, CUBLAS_OP_N, // transposed a, transposed b
            i01_diff, ne11, ne10, // m, n, k
            &alpha, 
            src0_ddf_i, CUDA_R_8F_E4M3, ne00, // alpha, A, Atype, lda
            src1_ddf_i_fp8e4, CUDA_R_8F_E4M3, ne10, // B, Btype, ldb
            &beta,  
            dst_ddf_i,  CUDA_R_32F, ldc, // beta, C, Ctype
            CUBLAS_COMPUTE_32F, // computeType
            CUBLAS_GEMM_DEFAULT,cudaStream_main); 
    (void) dst;
    (void) src0_ddq_i;
    (void) i02;
    (void) i1;
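
For readers who haven't touched cublasLt: below is a hedged sketch of roughly what such a wrapper might do internally (my assumption only; the real ggllm_lggemm is in the ggml-cuda.cu linked below, and algorithm-heuristic selection plus workspace handling are omitted here). FP8 GEMM requires the TN layout (A transposed, B not), FP32 compute/scale type, and Ada/Hopper hardware:

static void fp8_gemm_sketch(cublasLtHandle_t lt, cudaStream_t stream,
                            int m, int n, int k,
                            const __fp8e4 * A, int lda,   // stored k x m (column-major), used transposed
                            const __fp8e4 * B, int ldb,   // stored k x n (column-major)
                            float * C, int ldc) {         // m x n result in fp32
    const float alpha = 1.0f, beta = 0.0f;
    cublasOperation_t op_a = CUBLAS_OP_T, op_b = CUBLAS_OP_N;

    cublasLtMatmulDesc_t op_desc;
    cublasLtMatmulDescCreate(&op_desc, CUBLAS_COMPUTE_32F, CUDA_R_32F);
    cublasLtMatmulDescSetAttribute(op_desc, CUBLASLT_MATMUL_DESC_TRANSA, &op_a, sizeof(op_a));
    cublasLtMatmulDescSetAttribute(op_desc, CUBLASLT_MATMUL_DESC_TRANSB, &op_b, sizeof(op_b));

    // Layouts describe the matrices as stored, before the transpose ops are applied.
    cublasLtMatrixLayout_t a_desc, b_desc, c_desc;
    cublasLtMatrixLayoutCreate(&a_desc, CUDA_R_8F_E4M3, k, m, lda);
    cublasLtMatrixLayoutCreate(&b_desc, CUDA_R_8F_E4M3, k, n, ldb);
    cublasLtMatrixLayoutCreate(&c_desc, CUDA_R_32F,     m, n, ldc);

    // algo = NULL lets cublasLt pick one heuristically; no extra workspace is provided here.
    cublasLtMatmul(lt, op_desc, &alpha,
                   A, a_desc,
                   B, b_desc,
                   &beta,
                   C, c_desc,   // C
                   C, c_desc,   // D (in-place output)
                   NULL, NULL, 0, stream);

    cublasLtMatrixLayoutDestroy(a_desc);
    cublasLtMatrixLayoutDestroy(b_desc);
    cublasLtMatrixLayoutDestroy(c_desc);
    cublasLtMatmulDescDestroy(op_desc);
}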

I'm not a CUDA developer, so it took me a few days, but in the end the result was good.
The memory footprint of using cuBLAS went down by a factor of 4 (32-bit -> 8-bit) and the performance increased significantly.
I also did the same for fp16/half, so in my 4090/3090 tensor-split setup I could run 16-bit on the 3090 and 8-bit on the 4090.

All the kernels (dequantize, etc.) and code are still here, but the codebase has been outdated since July:
https://github.com/cmp-nct/ggllm.cpp/blob/ggfalcon_dev/ggml-cuda.cu

I couldn't keep up with the massive pace of llama.cpp as new projects knocked on my door and I went on vacation, though quite a few parts of ggllm.cpp are probably still a bit ahead. It would be nice to see some of it be useful.
The kernels were all tested and worked fine.

@ggerganov

Thank you very much for the detailed response! Much appreciated!

This looks like something that we should try integrating at some point in llama.cpp.

@rudiservo

Hi, is this on the roadmap, or is there a better alternative?

> For modern cards with fp8e4 cuBLAS (that support hasn't been added to this repo yet) I suppose it would beat the CPU, given fp8 is 2-3 times faster in processing than fp16. It's still a strange use case.

It's not that strange a use case. I know a few people who want to run a Q4 13B Qwen model on GPU, but 8 GB is not going to be enough, and partially offloading layers to the GPU is just too slow.
I also tried to run a 2-bit 70B Llama 2 on a single 24 GB GPU, but got 3 tokens/sec at most with a partial load. I think layer-wise inference is going to be better than running part of the model on the CPU. Anyway, thanks for the help.


Doesn't AirLLM benefit a lot from Apple unified memory in particular? That seems to lower memory use and speed up inference.
