
Hi @rasbt,

Please check out this MHSA implementation. If you like it, I can add it to the repo. If you love it and decide to include it in the book, then I am willing to make as many changes as necessary to get contributor credits 😎

import torch
import torch.nn as nn


class MultiHeadedSelfAttention(nn.Module):
    def __init__(
        self,
        embed_dim,
        num_heads,
        ctx_len,
        attn_drop=0.0,
        proj_drop=0.0,
        qkv_bias=False,
    ) -> None:
        super().__init__()

        assert embed_dim % num_heads == 0, "embed_dim is indivisible by num_heads"

        self.num_heads = num_heads
        self.ctx_len = ctx_len
        self.head_dim = embed_dim // num_heads
        self.scale = self.head_dim**-0.5

        self.qkv = nn.Linear(embed_dim, 3 * embed_dim, bias=qkv_bias)
        self.proj = nn.Linear(embed_dim, embed_dim)
        self.attn_drop = nn.Dropout(attn_drop)
        self.proj_drop = nn.Dropout(proj_drop)

        self.register_buffer(
            "mask", torch.triu(torch.ones(ctx_len, ctx_len), diagonal=1)
        )

    def forward(self, x):
        batch_size, seq_len, embed_dim = x.shape

        # B, S, E --> B, S, 3 * E
        qkv = self.qkv(x)
        # B, S, 3 * E --> B, S, 3, num_heads, head_dim
        qkv = qkv.reshape(batch_size, seq_len, 3, self.num_heads, self.head_dim)
        # B, S, 3, num_heads, head_dim --> 3, B, num_heads, S, head_dim
        qkv = qkv.permute(2, 0, 3, 1, 4)
        # each of q, k, v: B, num_heads, S, head_dim
        q, k, v = qkv.unbind(0)

        # B, num_heads, S, head_dim --> B, num_heads, S, S
        attn_scores = q @ k.transpose(-2, -1)
        attn_scores = attn_scores.masked_fill(
            self.mask.bool()[:seq_len, :seq_len], -torch.inf
        )

        # scale = head_dim**-0.5, so multiply (dividing by it would scale the scores up)
        attn_weights = torch.softmax(attn_scores * self.scale, dim=-1)
        attn_weights = self.attn_drop(attn_weights)

        # B, num_heads, S, S --> B, num_heads, S, head_dim
        context_vec = attn_weights @ v
        # B, num_heads, S, head_dim --> B, S, num_heads, head_dim
        context_vec = context_vec.transpose(1, 2)
        # B, S, num_heads, head_dim --> B, S, E
        context_vec = context_vec.reshape(batch_size, seq_len, embed_dim)

        x = self.proj(context_vec)
        x = self.proj_drop(x)

        return x
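
A quick usage sketch to show the expected shapes (the hyperparameters below are just illustrative):

# Usage sketch (illustrative hyperparameters)
torch.manual_seed(123)

mhsa = MultiHeadedSelfAttention(
    embed_dim=768,
    num_heads=12,
    ctx_len=1024,
    attn_drop=0.0,
    proj_drop=0.0,
    qkv_bias=False,
)

x = torch.randn(8, 1024, 768)   # B, S, E
out = mhsa(x)
print(out.shape)                # torch.Size([8, 1024, 768])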

@rasbt (Maintainer)

Thanks for this! I do like your implementation. My original implementation also used similar, more compact approaches, but after trying to explain it in the text, and to make it a bit more accessible for beginners, I evolved it into the more verbose version I currently have.

However, this could work as an optional alternative implementation, perhaps inside a ch03/02_alt_mha_implementations folder. And I could add some more variants (incl. Flash Attention) later.
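
For the Flash Attention variant, here is a rough sketch of one possible approach (just illustrative, not the final code for the repo): keep the same qkv projection but delegate the masked, scaled attention to torch.nn.functional.scaled_dot_product_attention, which dispatches to a Flash Attention kernel when one is available.

import torch
import torch.nn as nn
import torch.nn.functional as F


class MultiHeadedSelfAttentionSDPA(nn.Module):
    # Sketch only: same interface idea as the class above, but the attention
    # math (scaling, causal masking, dropout) is handled by PyTorch's
    # scaled_dot_product_attention.
    def __init__(self, embed_dim, num_heads, attn_drop=0.0, proj_drop=0.0, qkv_bias=False):
        super().__init__()
        assert embed_dim % num_heads == 0, "embed_dim is indivisible by num_heads"
        self.num_heads = num_heads
        self.head_dim = embed_dim // num_heads
        self.attn_drop = attn_drop
        self.qkv = nn.Linear(embed_dim, 3 * embed_dim, bias=qkv_bias)
        self.proj = nn.Linear(embed_dim, embed_dim)
        self.proj_drop = nn.Dropout(proj_drop)

    def forward(self, x):
        batch_size, seq_len, embed_dim = x.shape
        qkv = self.qkv(x).reshape(batch_size, seq_len, 3, self.num_heads, self.head_dim)
        q, k, v = qkv.permute(2, 0, 3, 1, 4).unbind(0)
        # is_causal=True applies the upper-triangular mask internally,
        # so no explicit mask buffer is needed.
        context_vec = F.scaled_dot_product_attention(
            q, k, v,
            dropout_p=self.attn_drop if self.training else 0.0,
            is_causal=True,
        )
        context_vec = context_vec.transpose(1, 2).reshape(batch_size, seq_len, embed_dim)
        return self.proj_drop(self.proj(context_vec))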

@d-kleine

There is also a very nice step-by-step implementation from Harvard NLP researchers, but it's very detailed.


Could you please also add CUDA support for the MHA implementations?

It won't change the relative time differences between the implementations, but it speeds things up when running the notebook on a supported GPU/TPU/NPU 😃

For instance, like this:

import torch
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

torch.manual_seed(123)

batch_size = 8
context_len = 1024
embed_dim = 768
embeddings = torch.randn((batch_size, context_len, embed_dim))

embeddings = embeddings.to(device)

...

mha_alt = MultiHeadAttentionAlt(
    d_in=embed_dim,
    d_out=embed_dim,
    block_size=context_len,
    dropout=0.0,
    num_heads=12,
    qkv_bias=False
)

mha_alt = mha_alt.to(device)

out = mha_alt(embeddings)
print(out.shape)

BTW: The PyTorch MHA implementation with Flash Attention 2 is more than 2x faster, which is incredible!
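
One way to double-check that the speedup really comes from the Flash kernel (just a sketch; torch.backends.cuda.sdp_kernel is the PyTorch 2.x context manager and may be superseded in newer releases) would be to pin scaled_dot_product_attention to a single backend:

import torch
import torch.nn.functional as F

# Sketch: restrict scaled_dot_product_attention to the Flash backend only.
# Flash Attention requires a CUDA device and fp16/bf16 tensors.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
q = k = v = torch.randn(8, 12, 1024, 64, device=device, dtype=torch.float16)

if device.type == "cuda":
    # With only the Flash backend enabled, this raises a RuntimeError
    # if the Flash kernel cannot be used for these inputs.
    with torch.backends.cuda.sdp_kernel(
        enable_flash=True, enable_math=False, enable_mem_efficient=False
    ):
        out = F.scaled_dot_product_attention(q, k, v, is_causal=True)
    print(out.shape)  # torch.Size([8, 12, 1024, 64])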

@rasbt (Maintainer) · Mar 8, 2024

Good call, I actually already had the file prepared like this anyway 😅. Uploaded it.

@rasbt (Maintainer) · Mar 8, 2024

I also added the MHA class from PyTorch (https://pytorch.org/docs/stable/generated/torch.nn.MultiheadAttention.html).
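
For reference, a rough sketch of how torch.nn.MultiheadAttention can be used with a causal mask (the hyperparameters below are just illustrative, not necessarily what the benchmark script uses):

import torch
import torch.nn as nn

# Sketch: causal self-attention with torch.nn.MultiheadAttention.
embed_dim, num_heads, context_len = 768, 12, 1024

mha_pytorch = nn.MultiheadAttention(
    embed_dim=embed_dim,
    num_heads=num_heads,
    dropout=0.0,
    bias=False,
    batch_first=True,   # expect (batch, seq, embed) inputs
)

x = torch.randn(8, context_len, embed_dim)

# Boolean causal mask: True marks positions that are NOT allowed to attend.
causal_mask = torch.triu(
    torch.ones(context_len, context_len, dtype=torch.bool), diagonal=1
)

out, _ = mha_pytorch(x, x, x, attn_mask=causal_mask, need_weights=False)
print(out.shape)  # torch.Size([8, 1024, 768])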

On GPU it falls between our two implementations:

[Screenshot: GPU timing comparison]

On CPU, it's a tad slower:

[Screenshot: CPU timing comparison]

@d-kleine

Interesting, it looks like torch.nn.MultiheadAttention() does not (yet) use Flash Attention, right?

@rasbt (Maintainer) · Mar 8, 2024

Yeah, I think it doesn't. In fact, I've also never seen anyone use this torch.nn.MultiheadAttention() class in practice.
