
implement faster RoPE embedding #238

Merged

danielhanchen merged 1 commit into unslothai:main from HuyNguyen-hust:faster-rope on Mar 15, 2024

Conversation

@HuyNguyen-hust (Contributor) commented Mar 12, 2024

This PR proposes a small change to the current RoPE embedding kernel:

  • The current implementation launches one block per head on grid axis 1, so every block has to reload the same sin/cos values, which is inefficient.
  • This PR reorganizes the grid so that, on axis 1, each block handles a group of heads (4-8 heads) instead of a single head. The sin/cos values are then loaded only once per block and reused to compute all the heads in that group.
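The grouping idea can be sketched in plain NumPy. This is an illustrative model, not the actual Triton kernel: `rope_rotate_group`, the rotate-half layout, and the default group size are assumptions for the sketch.

```python
import numpy as np

def rope_rotate_group(q_group, cos, sin):
    """Apply rotary embedding to a group of heads with shared sin/cos.

    q_group: (group_size, seq_len, head_dim)
    cos/sin: (seq_len, head_dim // 2) -- loaded once for the whole group.
    """
    half = q_group.shape[-1] // 2
    q1, q2 = q_group[..., :half], q_group[..., half:]
    # The same cos/sin arrays are broadcast over every head in the group,
    # mirroring how one block now serves several heads on grid axis 1
    # instead of reloading sin/cos per head.
    out1 = q1 * cos - q2 * sin
    out2 = q2 * cos + q1 * sin
    return np.concatenate([out1, out2], axis=-1)

def rope_all_heads(q, cos, sin, group_size=4):
    """q: (n_heads, seq_len, head_dim); process heads group by group."""
    out = np.empty_like(q)
    for start in range(0, q.shape[0], group_size):
        out[start:start + group_size] = rope_rotate_group(
            q[start:start + group_size], cos, sin)
    return out
```

Since each head's rotation is independent, grouping changes only how often sin/cos is loaded, not the result: any `group_size` produces identical output.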

Benchmark with batch_size=4, head_dim=128, n_heads=32 ("// 2" means BLOCK_SIZE = head_dim // 2; otherwise BLOCK_SIZE = head_dim):

[benchmark figure: original vs. grouped kernel]

The figure indicates that my implementation is more sensitive to BLOCK_SIZE.

@danielhanchen (Member)

Thanks a lot, @HuyNguyen-hust! As per our discussion on Discord - I just want to say thank you again - super appreciate this! I'll do some tests on my end and expedite this PR!

@danielhanchen (Member)

@HuyNguyen-hust I tested the kernel! I can confirm RoPE itself is faster. Sadly, the effect on a full training run is less pronounced: according to PyTorch's Profiler, RoPE now takes around 1% of the total runtime, with matrix multiplications taking the bulk of the time. DPO, for example - with your RoPE fix: 1553 seconds; original: 1542 seconds. So within the margin of error. This was on a Colab T4, so I'm pretty sure A100s see more noticeable effects.

However, your kernel works absolute wonders when long sequence lengths come into play! The RoPE kernel does creep up to around 2-3% of the total runtime, which means savings are well deserved!

Thanks so much for the wonderful contribution - added this in! :)

I'll probably play around with the group size - it seems like this might be an auto-tunable number!!!
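Picking the group size empirically could look something like the following sketch, in the spirit of Triton's autotuning: time each candidate and keep the fastest. The `autotune_group_size` helper, the candidate list, and the stand-in kernel are all hypothetical; a real version would wrap the actual Triton RoPE kernel (e.g. via `triton.autotune`).

```python
import time

def autotune_group_size(run_kernel, candidates=(1, 2, 4, 8), repeats=3):
    """Return the candidate group size with the lowest wall-clock time.

    run_kernel: callable taking a group size and executing the kernel once.
    This is a toy timing loop, not Triton's autotuner.
    """
    best, best_t = candidates[0], float("inf")
    for g in candidates:
        t0 = time.perf_counter()
        for _ in range(repeats):
            run_kernel(g)
        elapsed = time.perf_counter() - t0
        if elapsed < best_t:
            best, best_t = g, elapsed
    return best
```

The best value likely depends on head_dim, sequence length, and the GPU, which is exactly why treating it as a tunable parameter rather than a constant makes sense.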

@danielhanchen danielhanchen merged commit 809bdbe into unslothai:main Mar 15, 2024
@chiennv2000

awesome @HuyNguyen-hust, congrats on your great work!

1 similar comment
@hieule88 commented Apr 9, 2024

awesome @HuyNguyen-hust, congrats on your great work!

@mohsen202

thanks

@namng194

cool :O

@ngocbh commented Sep 21, 2024

Congrats @HuyNguyen-hust! Great contribution!

