Conversation

ckvermaAI (Contributor)

Issue: If the input to Linear4bit's forward is torch.float32 and compute_dtype is set to torch.bfloat16, the matmul executes in torch.float32 rather than the requested bfloat16. This reproduces on CPU and HPU (Intel Gaudi).
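A minimal reproduction sketch (assuming a bitsandbytes build where Linear4bit runs on the target device, e.g. the multi-backend-refactor branch this PR targets; layer sizes are arbitrary):

```python
import torch
import bitsandbytes as bnb

# Request bfloat16 as the compute dtype, then feed a float32 input.
layer = bnb.nn.Linear4bit(64, 64, bias=False, compute_dtype=torch.bfloat16)
x = torch.randn(1, 64, dtype=torch.float32)

print(layer.compute_dtype)  # torch.bfloat16, as requested
_ = layer(x)
print(layer.compute_dtype)  # torch.float32: overwritten during the forward pass
```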

Fix: During initialization, compute_type_is_set is set to False, so the first forward pass overwrites compute_dtype based on the input's dtype. Initializing compute_type_is_set according to whether a compute_dtype was supplied, as done in this PR, resolves the issue (and lets us drop unnecessary casting operations).
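The snippet below is a self-contained toy model of the dtype logic, not the real Linear4bit; it only illustrates how initializing the flag from compute_dtype changes the behavior walked through in the cases below:

```python
import torch

class Linear4bitSketch:
    """Toy model of Linear4bit's compute-dtype handling (illustration only)."""

    def __init__(self, compute_dtype=None):
        self.compute_dtype = compute_dtype
        # The fix: treat the compute type as already set whenever the user
        # supplied one, so forward() will not overwrite it from the input.
        self.compute_type_is_set = compute_dtype is not None

    def forward(self, x):
        if not self.compute_type_is_set:
            # Only infer the compute dtype from the input if none was given.
            if x.dtype in (torch.float32, torch.bfloat16):
                self.compute_dtype = x.dtype
            self.compute_type_is_set = True
        if self.compute_dtype is not None:
            x = x.to(self.compute_dtype)  # cast the input to the compute dtype
        return x  # the real code would run the 4-bit matmul here

layer = Linear4bitSketch(compute_dtype=torch.bfloat16)
y = layer.forward(torch.randn(2, 4, dtype=torch.float32))
print(y.dtype)  # torch.bfloat16: the requested compute dtype is honored
```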

Details:

Case I: Without this change

  1. At the end of Linear4bit's forward, x is torch.float32 and the weight is uint8.
  2. From here, control passes to MatMul4bit's forward (see the sketch after this list):
    a) First, the weights are dequantized; the result is in bfloat16.
    b) Then the dequantized weights are cast to match the input dtype, i.e. back to float32.
    c) Finally, torch.nn.functional.linear runs, so the matmul executes in float32.
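A short sketch of this dequantize-then-cast path (the function and names are simplified stand-ins, not the real MatMul4bit code):

```python
import torch
import torch.nn.functional as F

def matmul_4bit_sketch(x, dequantize_fn):
    w = dequantize_fn()    # a) dequantize the packed weights -> bfloat16
    w = w.to(x.dtype)      # b) cast the weights to match the input dtype
    return F.linear(x, w)  # c) matmul runs in the input's dtype

# Toy stand-in for the real dequantization kernel.
def fake_dequantize():
    return torch.randn(8, 8, dtype=torch.bfloat16)

x = torch.randn(2, 8, dtype=torch.float32)           # Case I: float32 input
print(matmul_4bit_sketch(x, fake_dequantize).dtype)  # torch.float32
```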

Case II: With this change

  1. Because compute_type_is_set is now True, the forward pass does not overwrite compute_dtype based on the input.
  2. The input x is therefore cast to compute_dtype (bfloat16).
  3. At the end of Linear4bit's forward, x is bfloat16 and the weight is uint8.
  4. From here, control passes to MatMul4bit's forward (see the sketch after this list):
    a) First, the weights are dequantized; the result is in bfloat16.
    b) Then the dequantized weights are cast to match the input dtype, which is already bfloat16 (a no-op).
    c) Finally, torch.nn.functional.linear runs with both the input and the weights in bfloat16.
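Running the same sketch from Case I with an input that has already been cast to bfloat16 shows the difference: the weight cast in step b) becomes a no-op and the matmul stays in bfloat16.

```python
x_bf16 = torch.randn(2, 8, dtype=torch.bfloat16)          # Case II input
print(matmul_4bit_sketch(x_bf16, fake_dequantize).dtype)  # torch.bfloat16
```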

@matthewdouglas added the Bug (Something isn't working) and Cross Platform labels on May 5, 2025
@matthewdouglas self-assigned this on May 5, 2025

github-actions bot commented May 5, 2025

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.

@matthewdouglas (Member)

Seems reasonable to me, thanks! We'll ignore the lint failures on this PR as that's unrelated.

@matthewdouglas merged commit 5e267f5 into bitsandbytes-foundation:multi-backend-refactor on May 5, 2025 (1 of 2 checks passed)