Skip to content

Navigation Menu

Sign in
Appearance settings

Search code, repositories, users, issues, pull requests...

Provide feedback

We read every piece of feedback, and take your input very seriously.

Saved searches

Use saved searches to filter your results more quickly

Appearance settings

MPS CI Support#1278

Merged
jlarson4 merged 29 commits intoTransformerLensOrg:devTransformerLensOrg/TransformerLens:devfrom
huseyincavusbi:feat/mps-ci-supporthuseyincavusbi/TransformerLens:feat/mps-ci-supportCopy head branch name to clipboard
May 7, 2026
Merged

MPS CI Support#1278
jlarson4 merged 29 commits intoTransformerLensOrg:devTransformerLensOrg/TransformerLens:devfrom
huseyincavusbi:feat/mps-ci-supporthuseyincavusbi/TransformerLens:feat/mps-ci-supportCopy head branch name to clipboard

Conversation

@huseyincavusbi
Copy link
Copy Markdown
Contributor

@huseyincavusbi huseyincavusbi commented May 2, 2026

Hi @jlarson4,

This PR implements MPS (Metal Performance Shaders) CI Runner Support as proposed in #1264.

The goal is to provide automated testing for the Apple Silicon research community while working within the limits of GitHub's Mac runners.

Key Changes:

  • New Test Suite: Added tests/mps/test_mps_basic.py with 11 smoke tests covering device detection, core tensor ops on Metal, and HookedTransformer forward passes/caching with small models (TinyStories-1M).
  • CI Automation: Introduced the mps-checks job in .github/workflows/checks.yml. It uses macos-latest and runs only on PRs/pushes to main.
  • Memory Management:
    • Updated tests/conftest.py to proactively clear the MPS cache after every test using torch.mps.empty_cache().
    • Configured the CI to ignore memory-intensive modules (e.g., model_bridge) to ensure stability.
  • Opt-in Mechanism: Respects TRANSFORMERLENS_ALLOW_MPS=1 to ensure safe defaults for Mac users.

Type of change

  • New feature (non-breaking change which adds functionality)

Checklist:

  • I have commented my code, particularly in hard-to-understand areas
  • My changes generate no new warnings
  • I have added tests that prove my fix is effective or that my feature works
  • New and existing unit tests pass locally with my changes
  • I have not rewritten tests relating to key interfaces which would affect backward compatibility

brendanlong and others added 10 commits April 20, 2026 14:50
* Fix type of HookedTransformerConfig.device

This is typed as `Optional[str]` but sometimes returns `torch.device`.
Updated the code to just return the `str` instead of wrapping with a
device.

I'm not confident that every function which takes a device will
always be passed a string, so I didn't change functions like
warn_if_mps.

Found while working on TransformerLensOrg#1219

* more cleanup

* 3.0 CI Bugs (TransformerLensOrg#1261)

* Fixing `utils` imports

* skip gated notebooks on PR from forks

* Updating notebooks

* Ensure LLaMA only runs when HF_TOKEN is available

---------

Co-authored-by: jlarson4 <jonahalarson@comcast.net>
Copilot AI review requested due to automatic review settings May 2, 2026 16:03
Copy link
Copy Markdown
Contributor

Copilot AI left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Pull request overview

Adds Apple Silicon MPS coverage to CI by introducing an MPS-specific test suite and a macOS GitHub Actions job, alongside device-selection tweaks to make MPS opt-in by default.

Changes:

  • Added a new tests/mps smoke-test suite that validates basic tensor ops and a small HookedTransformer run on MPS.
  • Added an mps-checks GitHub Actions job on macos-latest to run unit/integration tests plus the new MPS smoke tests on PRs to main and pushes to main.
  • Updated device utilities and configs to better support MPS opt-in behavior, plus proactive torch.mps.empty_cache() cleanup in pytest fixtures.

Reviewed changes

Copilot reviewed 9 out of 9 changed files in this pull request and generated 4 comments.

Show a summary per file
File Description
transformer_lens/utilities/devices.py Adjusts device selection behavior/signatures to support MPS opt-in and updated warning typing.
transformer_lens/train.py Updates training config typing and default device assignment.
transformer_lens/config/HookedTransformerConfig.py Uses get_device() directly when defaulting cfg.device.
tests/unit/utilities/test_devices.py Updates device utility unit tests for the new get_device() return type.
tests/mps/test_mps_basic.py Adds MPS-only smoke tests covering device detection, core ops, and small-model forward/cache paths.
tests/mps/init.py Declares the MPS test package.
tests/conftest.py Adds MPS cache clearing after tests/classes/session to reduce CI OOM risk.
pyproject.toml Registers a no_mps pytest marker.
.github/workflows/checks.yml Adds the mps-checks CI job that runs on macOS and executes the MPS tests.

💡 Add Copilot custom instructions for smarter, more guided reviews. Learn how to get started.

Comment thread transformer_lens/utilities/devices.py
Comment thread transformer_lens/utilities/devices.py
Comment thread transformer_lens/train.py
Comment thread tests/mps/test_mps_basic.py Outdated
@huseyincavusbi
Copy link
Copy Markdown
Contributor Author

Hi @jlarson4, I've updated the PR to address the automated feedback:

  • API Stability: Reverted get_device() to return torch.device objects.
  • Type Checks: Updated type hints across model classes to resolve mypy failures.
  • CI Trigger: Strictly restricted mps-checks to the main branch

@jlarson4 jlarson4 mentioned this pull request May 4, 2026
7 tasks
Comment thread .github/workflows/checks.yml
Comment thread .github/workflows/checks.yml Outdated
Comment thread .github/workflows/checks.yml
Comment thread pyproject.toml Outdated
@huseyincavusbi
Copy link
Copy Markdown
Contributor Author

Hi @jlarson4, here are the latest changes:

  • Passed ${{ secrets.HF_TOKEN }} to all MPS job steps.
  • Switched to targeted ignores for the modules you listed. I also added 3 more skips based on CI failures: generation tests (test_bridge_generation.py, test_bridge_integration.py) hit NaNs on MPS. For GQA, create_hooked_encoder runs and passes, but I kept test_grouped_query_attention ignored. It fails because Metal float arithmetic differs slightly from CPU, breaking strict torch.equal(). All other model_bridge tests now run and pass.
  • Standardized to str and cleaned up stale imports.
  • Dropped the unused no_mps marker.

Thanks for the review!

@jlarson4
Copy link
Copy Markdown
Collaborator

jlarson4 commented May 6, 2026

Great work @huseyincavusbi! This looks good to me. One followup thought on the GQA test(s) that are still not working as expected:

GQA precision tests are a test-strictness issue. The diffs are last-bit fp32 noise (33226.9414 vs 33226.9453). torch.equal requires bit-exact equality. It should be safe to update the tests to torch.allclose(..., atol=1e-4, rtol=1e-4) in test_grouped_query_attention.py. Then re-include the test in MPS CI - should pass on CPU and MPS both.

@huseyincavusbi
Copy link
Copy Markdown
Contributor Author

Good point @jlarson4 ! I was too focused on getting them to pass without failure, but this is definitely a better approach.

On a related note, I noticed some notebook tests also fail intermittently due to similar floating point noise across different runs, even without code changes. For example, this run.
Would you like me to make those more flexible as well, or do you prefer keeping them strict?

Also, I'm seeing Node.js 20 deprecation warnings. I can update the CI workflows to use the latest versions.
What do you think about these?

@jlarson4
Copy link
Copy Markdown
Collaborator

jlarson4 commented May 7, 2026

@huseyincavusbi If you want to tackle those in a follow up PR, feel free!

The Notebook Checks that fail intermittently have given me pause. I want to keep some level of numerical comparison, but I haven't searched very deeply for a possible solution. Let me know if you find anything, I am going to merge this work as done! Thanks for taking this on, it looks great!

@jlarson4 jlarson4 merged commit 6a8f470 into TransformerLensOrg:dev May 7, 2026
23 checks passed
@huseyincavusbi
Copy link
Copy Markdown
Contributor Author

Thanks for the merge, @jlarson4! I really enjoyed working on this.
I'll dig into the notebook noise issue for a follow-up PR. I suspect giving them a little flexibility (similar to the allclose approach we used for GQA precision) would be the right balance to stabilize them while still keeping meaningful numerical checks. I'll also update the Node.js versions to clear deprecation warnings.
I'll be back soon with a PR to fix these!

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

4 participants

Morty Proxy This is a proxified and sanitized view of the page, visit original site.