Inject custom attention op into MultiHeadAttention #408

Merged: JRosenkranz merged 4 commits into foundation-model-stack/foundation-model-stack:main from foundation-model-stack/foundation-model-stack:paged_attn_mock_api_change2 on May 16, 2025

Conversation

@JRosenkranz (Collaborator) commented on May 12, 2025:

This PR removes the assumption that fms uses SDPA as the backend for MultiHeadAttention by introducing an AttentionKwargs TypedDict that is passed through forward. It is fully backwards compatible with the prior API: mask and attn_algorithm are part of SDPAAttentionKwargs (which extends AttentionKwargs), so they can still be passed by name as before.
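As a rough sketch of what that TypedDict-based API implies (the AttentionKwargs / SDPAAttentionKwargs names and the mask / attn_algorithm keys come from this description; everything else below is assumed for illustration and is not the actual fms source):

```python
from typing import Optional, TypedDict

import torch


class AttentionKwargs(TypedDict, total=False):
    # base keys shared by all attention implementations (illustrative only)
    attn_name: str


class SDPAAttentionKwargs(AttentionKwargs, total=False):
    # the prior API's keyword arguments, so existing callers keep working
    mask: Optional[torch.Tensor]
    attn_algorithm: Optional[str]


# existing call sites can still pass the old arguments by name, e.g.:
# model(input_ids, mask=mask, attn_algorithm="math")
```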

Within attention, the attention implementation is chosen via a get_attention_type call on the AttentionKwargs input. New attention types can be registered through the register_attention_op method, which gives other repos (as well as fms itself) the ability to introduce their own attention types.
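A minimal sketch of how such a registry could work; the register_attention_op and get_attention_type names come from this description, but the signatures, the dispatch key, and the default below are assumptions rather than the actual fms implementation:

```python
from typing import Callable, Dict

# illustrative registry mapping an attention-type name to its implementation
_ATTENTION_OPS: Dict[str, Callable] = {}


def register_attention_op(name: str, op: Callable) -> None:
    """Register an attention implementation under a name (other repos could call this too)."""
    _ATTENTION_OPS[name] = op


def get_attention_type(attn_kwargs: dict) -> Callable:
    """Pick the implementation based on the AttentionKwargs; falls back to SDPA (assumed key and default)."""
    return _ATTENTION_OPS[attn_kwargs.get("attn_name", "sdpa")]
```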

…on SDPA

Co-authored-by: Joshua Rosenkranz <jmrosenk@us.ibm.com>
Signed-off-by: Antoni Viros i Martin <aviros@ibm.com>
ani300 force-pushed the paged_attn_mock_api_change2 branch from 143bb1b to a9f34a3 on May 16, 2025 13:53
fms/models/roberta.py (review thread resolved)
mask = torch.where(mask.logical_not(), -torch.inf, 0.0)  # boolean mask -> additive mask (0 = keep, -inf = block)

padding_kwargs["mask"] = mask
# FIXME: this method should be per attn type (for now default it)
Collaborator: seems like it is already dependent on attn type?

JRosenkranz (Collaborator, Author): Yes, my thought here was that we could eventually add an attn_name param to pad_input_ids, which would provide the right prefill padding metadata per attention type. I did not include that as part of this PR.
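
A purely hypothetical illustration of that idea (explicitly not part of this PR): dispatch the prefill padding metadata on an attn_name argument. The function name, signature, and pad-token assumption below are made up for illustration; the real pad_input_ids in fms may look different.

```python
import torch
import torch.nn.functional as F


def pad_input_ids_sketch(input_ids_list, min_pad_length=0, attn_name="sdpa"):
    # left-pad every sequence to a common length (pad id 0 assumed for illustration)
    max_len = max(min_pad_length, max(len(ids) for ids in input_ids_list))
    input_ids = torch.stack(
        [F.pad(ids, (max_len - len(ids), 0)) for ids in input_ids_list]
    )

    padding_kwargs = {"attn_name": attn_name}
    if attn_name == "sdpa":
        # boolean -> additive mask, as in the quoted snippet above
        keep = input_ids != 0
        pair_mask = keep.unsqueeze(-2) & keep.unsqueeze(-1)
        padding_kwargs["mask"] = torch.where(pair_mask.logical_not(), -torch.inf, 0.0)
    # other registered attention types could attach their own prefill metadata here

    return input_ids, padding_kwargs
```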

fms/modules/attention.py (seven review threads resolved, three of them on outdated diffs)
@ani300 (Collaborator) left a comment: Mostly LGTM; a few more comments documenting the API better and some variable renaming will bring it to a good place.

Signed-off-by: Joshua Rosenkranz <jmrosenk@us.ibm.com>
JRosenkranz merged commit 47569ee into main on May 16, 2025 (4 checks passed).
ani300 deleted the paged_attn_mock_api_change2 branch on September 9, 2025.