head_dim expansion managed at get_model() level #476

Merged

rzbhatti merged 8 commits into foundation-model-stack/foundation-model-stack:main from foundation-model-stack/foundation-model-stack:override_head_dim on Oct 27, 2025

Conversation

@rzbhatti (Contributor) commented on Oct 14, 2025

This PR does the following:

  1. Moves the weight expansion adapter function _weight_expansion_for_mismatched_head_dim to serialization.py, so that it can also be registered for other models such as llama 1b, gpt_oss, etc.

  2. Removes the weight expansion adapter registration from granite.py. This allows the application layer (at the get_model() level) to decide when the weight expansion should be done, e.g.

if args.device_type == "aiu" and args.head_dim is not None:
    serialization.extend_adapter("granite", "hf", ["weight_expansion_for_mismatched_head_dim"])
  3. Adds an optional argument override_hf_pretrained_config to get_model(), which allows overriding model config parameters, like head_dim=128, as kwargs when architecture = hf_pretrained, e.g. in inference.py:
    model = get_model(
        args.architecture,
        args.variant,
        model_path=args.model_path,
        device_type="cpu" if is_aiu_backend else args.device_type,
        data_type=default_dtype,
        source=args.model_source,
        distributed_strategy=distr_param,
        group=dist.group.WORLD,
        linear_config=linear_config,
        fused_weights=fused_weights,
        force_override_config=True,
        override_hf_pretrained_config=True if args.device_type == "aiu" and args.head_dim is not None else False, 
        head_dim=args.head_dim,
    )

This is how inference is run from the command line:

python3 ./aiu-fms-testing-utils/scripts/inference.py \
  --architecture=hf_pretrained \
  --variant=ibm-granite/granite-3.3-2b-instruct \
  --tokenizer=ibm-granite/granite-3.3-2b-instruct \
  --device_type=aiu \
  --unfuse_weights \
  --compile_dynamic \
  --compile \
  --default_dtype=fp16 \
  --fixed_prompt_length=64 \
  --max_new_tokens=20 \
  --timing=per-token \
  --batch_size=1 \
  --head_dim=128
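
As context for the override flag, here is a minimal sketch (not the merged FMS implementation) of how get_model() could apply such kwarg overrides onto the loaded Hugging Face config; the helper name _apply_config_overrides is purely illustrative.

def _apply_config_overrides(hf_config, override_hf_pretrained_config: bool, **kwargs):
    # Illustrative only: when the flag is set, extra kwargs such as head_dim=128
    # are written onto the config object loaded from the HF checkpoint.
    if override_hf_pretrained_config:
        for key, value in kwargs.items():
            if value is not None:
                setattr(hf_config, key, value)
    return hf_config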

…ake head_dim as kwarg

Signed-off-by: Rashed Z. Bhatti, PhD <rzbhatti@us.ibm.com>
@rzbhatti rzbhatti requested a review from ani300 October 14, 2025 18:30
Rashed Z. Bhatti, PhD added 2 commits October 14, 2025 18:32
Signed-off-by: Rashed Z. Bhatti, PhD <rzbhatti@us.ibm.com>
…` to `serialization.py`

Signed-off-by: Rashed Z. Bhatti, PhD <rzbhatti@us.ibm.com>
@rzbhatti rzbhatti changed the title from "override_hf_pretrained_config to allow get_model() take head_dim as kwarg" to "head_dim expansion managed at get_model()" on Oct 15, 2025
…step

Signed-off-by: Rashed Z. Bhatti, PhD <rzbhatti@us.ibm.com>
@rzbhatti rzbhatti changed the title from "head_dim expansion managed at get_model()" to "head_dim expansion managed at get_model() level" on Oct 15, 2025
…step

Signed-off-by: Rashed Z. Bhatti, PhD <rzbhatti@us.ibm.com>
@rzbhatti rzbhatti requested a review from JRosenkranz October 15, 2025 01:25
Comment on lines 711 to 717:

if "attn.in_proj.query" in layer:
    expansion_factor = (
        model_config.head_dim
        * model_config.nheads
        // input_sd[layer].size(0)
    )
    break
Collaborator:

we cannot assume that this will happen first before key or value. It might, but it's not guaranteed

Collaborator:

you might have to pick layer_dim first, and then have a second dictionary with whether you need to multiply by model_config.nheads or model_config.kvheads

Contributor (Author):

I have updated the expansion factor calculation based on all QKV and Dense. When you get a chance, please take a look at it to resolve this conversation.
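
One way to realize the reviewer's suggestion is sketched below, assuming query rows scale with nheads while key/value rows scale with kvheads; the dictionary keys and helper name are illustrative and not necessarily the code merged in this PR.

HEAD_COUNT_ATTR_FOR_PROJ = {
    "attn.in_proj.query": "nheads",
    "attn.in_proj.key": "kvheads",
    "attn.in_proj.value": "kvheads",
}

def _qkv_expansion_factor(layer_name: str, weight, model_config) -> int:
    # Find which projection this weight belongs to and compute how much wider
    # it must become to reach head_dim * (number of heads for that projection).
    for proj, head_attr in HEAD_COUNT_ATTR_FOR_PROJ.items():
        if proj in layer_name:
            target_rows = model_config.head_dim * getattr(model_config, head_attr)
            return target_rows // weight.size(0)
    return 1  # non-QKV weights (e.g. attn.dense) would be handled along their input dim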

Rashed Z. Bhatti, PhD added 2 commits October 15, 2025 21:49
Signed-off-by: Rashed Z. Bhatti, PhD <rzbhatti@us.ibm.com>
Signed-off-by: Rashed Z. Bhatti, PhD <rzbhatti@us.ibm.com>
# When emb_dim // nheads < head_dim, expand QKV and Dense Weights
def _weight_expansion_for_mismatched_head_dim(
    input_sd: Mapping[str, Any], model_config
) -> Mapping[str, Any]:
Collaborator:

we may want to assert here that head_dim exists in the config. I don't believe all of the models have a head_dim.

Contributor (Author):

Very good point, I can add an assertion.
The application that registers this adapter extension must make sure that head_dim is either part of the model config or passed as a kwarg override.
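
A minimal sketch of the assertion discussed here, assuming it sits at the top of the adapter in serialization.py; the exact error message wording is illustrative.

from typing import Any, Mapping

def _weight_expansion_for_mismatched_head_dim(
    input_sd: Mapping[str, Any], model_config
) -> Mapping[str, Any]:
    # Guard: head_dim must come from the model config or a kwarg override.
    assert getattr(model_config, "head_dim", None) is not None, (
        "head_dim must be set in the model config or passed as a kwarg override "
        "before registering the weight expansion adapter"
    )
    # ... expansion of the QKV and dense weights follows ...
    return input_sd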

Signed-off-by: Rashed Z. Bhatti, PhD <rzbhatti@us.ibm.com>
@JRosenkranz (Collaborator) left a comment:

lgtm

@rzbhatti rzbhatti merged commit c9bc7ee into main Oct 27, 2025
4 checks passed
@rzbhatti rzbhatti deleted the override_head_dim branch October 27, 2025 19:22