Inject custom attention op into MultiHeadAttention #408
Merged
JRosenkranz merged 4 commits into foundation-model-stack/foundation-model-stack:main from foundation-model-stack/foundation-model-stack:paged_attn_mock_api_change2 on May 16, 2025
Conversation
…on SDPA Co-authored-by: Joshua Rosenkranz <jmrosenk@us.ibm.com> Signed-off-by: Antoni Viros i Martin <aviros@ibm.com>
143bb1b to a9f34a3
Signed-off-by: Antoni Viros i Martin <aviros@ibm.com>
ani300 reviewed May 16, 2025
mask = torch.where(mask.logical_not(), -torch.inf, 0.0)

padding_kwargs["mask"] = mask
# FIXME: this method should be per attn type (for now default it)
Collaborator
seems like it is already dependent on attn type?
Collaborator
Author
yes, my thinking here was that we could eventually add an attn_name param to pad_input_ids, which would provide the right prefill padding metadata per attention type. I did not include that as part of this PR.
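For reference, the conversion in the diff above can be reproduced in isolation. This is a minimal standalone sketch with made-up tensor values, not the fms code path itself:

```python
import torch

# Boolean attention mask: True means "attend", False means "masked out".
bool_mask = torch.tensor([[True, True, False, False]])

# SDPA expects an additive float mask: 0.0 where attention is allowed and
# -inf where it is not, so masked positions vanish after the softmax.
additive_mask = torch.where(bool_mask.logical_not(), -torch.inf, 0.0)
print(additive_mask)  # tensor([[0., 0., -inf, -inf]])
```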
ani300 approved these changes May 16, 2025
ani300 (Collaborator) left a comment:
mostly lgtm, a few more comments documenting the API better and some variable renaming will bring it to a good place
Signed-off-by: Joshua Rosenkranz <jmrosenk@us.ibm.com>
This PR removes the assumption that fms uses SDPA as the backend for MultiHeadAttention by introducing an AttentionKwargs TypedDict that is passed to forward. It is fully backwards compatible with the prior API, as mask and attn_algorithm are part of SDPAAttentionKwargs (which extends AttentionKwargs) and can still be passed by name. Within attention, the attention implementation is chosen based on a get_attention_type call (which is given the AttentionKwargs input). Attention types can be registered through the register_attention_op method, giving other repos (including fms) the ability to introduce their own attention type.
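As a rough illustration of the flow described above, the sketch below shows what a registry-based attention dispatch could look like. The class fields, function signatures, and the "sdpa" default name are assumptions for illustration, not the actual fms interface:

```python
import torch
from typing import TypedDict


# Hypothetical shapes of the constructs named in the description; the real
# fms definitions may differ in fields, defaults, and module location.
class AttentionKwargs(TypedDict, total=False):
    attn_name: str


class SDPAAttentionKwargs(AttentionKwargs, total=False):
    mask: torch.Tensor
    attn_algorithm: str


# Registry mapping an attention-type name to its implementation.
_ATTN_OPS = {}


def register_attention_op(name, op):
    # Other repos (including fms itself) can introduce their own attention type.
    _ATTN_OPS[name] = op


def get_attention_type(attn_kwargs):
    # Fall back to SDPA when no attention type is requested, which keeps the
    # old mask/attn_algorithm call pattern working unchanged.
    return _ATTN_OPS[attn_kwargs.get("attn_name", "sdpa")]


def sdpa_op(q, k, v, **attn_kwargs):
    return torch.nn.functional.scaled_dot_product_attention(
        q, k, v, attn_mask=attn_kwargs.get("mask")
    )


register_attention_op("sdpa", sdpa_op)

# Inside MultiHeadAttention.forward, the op would then be looked up and called:
# attn_op = get_attention_type(attn_kwargs)
# out = attn_op(q, k, v, **attn_kwargs)
```

In a shape like this, existing callers that only pass mask and attn_algorithm keep working, since those keys simply travel through the SDPAAttentionKwargs dict to the default SDPA op.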