MPS CI Support#1278
MPS CI Support#1278jlarson4 merged 29 commits intoTransformerLensOrg:devTransformerLensOrg/TransformerLens:devfrom huseyincavusbi:feat/mps-ci-supporthuseyincavusbi/TransformerLens:feat/mps-ci-supportCopy head branch name to clipboard
Conversation
* Fix type of HookedTransformerConfig.device This is typed as `Optional[str]` but sometimes returns `torch.device`. Updated the code to just return the `str` instead of wrapping with a device. I'm not confident that every function which takes a device will always be passed a string, so I didn't change functions like warn_if_mps. Found while working on TransformerLensOrg#1219 * more cleanup * 3.0 CI Bugs (TransformerLensOrg#1261) * Fixing `utils` imports * skip gated notebooks on PR from forks * Updating notebooks * Ensure LLaMA only runs when HF_TOKEN is available --------- Co-authored-by: jlarson4 <jonahalarson@comcast.net>
TransformerLens 3.1.0
There was a problem hiding this comment.
Pull request overview
Adds Apple Silicon MPS coverage to CI by introducing an MPS-specific test suite and a macOS GitHub Actions job, alongside device-selection tweaks to make MPS opt-in by default.
Changes:
- Added a new
tests/mpssmoke-test suite that validates basic tensor ops and a smallHookedTransformerrun on MPS. - Added an
mps-checksGitHub Actions job onmacos-latestto run unit/integration tests plus the new MPS smoke tests on PRs tomainand pushes tomain. - Updated device utilities and configs to better support MPS opt-in behavior, plus proactive
torch.mps.empty_cache()cleanup in pytest fixtures.
Reviewed changes
Copilot reviewed 9 out of 9 changed files in this pull request and generated 4 comments.
Show a summary per file
| File | Description |
|---|---|
| transformer_lens/utilities/devices.py | Adjusts device selection behavior/signatures to support MPS opt-in and updated warning typing. |
| transformer_lens/train.py | Updates training config typing and default device assignment. |
| transformer_lens/config/HookedTransformerConfig.py | Uses get_device() directly when defaulting cfg.device. |
| tests/unit/utilities/test_devices.py | Updates device utility unit tests for the new get_device() return type. |
| tests/mps/test_mps_basic.py | Adds MPS-only smoke tests covering device detection, core ops, and small-model forward/cache paths. |
| tests/mps/init.py | Declares the MPS test package. |
| tests/conftest.py | Adds MPS cache clearing after tests/classes/session to reduce CI OOM risk. |
| pyproject.toml | Registers a no_mps pytest marker. |
| .github/workflows/checks.yml | Adds the mps-checks CI job that runs on macOS and executes the MPS tests. |
💡 Add Copilot custom instructions for smarter, more guided reviews. Learn how to get started.
|
Hi @jlarson4, I've updated the PR to address the automated feedback:
|
|
Hi @jlarson4, here are the latest changes:
Thanks for the review! |
|
Great work @huseyincavusbi! This looks good to me. One followup thought on the GQA test(s) that are still not working as expected: GQA precision tests are a test-strictness issue. The diffs are last-bit fp32 noise (33226.9414 vs 33226.9453). |
|
Good point @jlarson4 ! I was too focused on getting them to pass without failure, but this is definitely a better approach. On a related note, I noticed some notebook tests also fail intermittently due to similar floating point noise across different runs, even without code changes. For example, this run. Also, I'm seeing Node.js 20 deprecation warnings. I can update the CI workflows to use the latest versions. |
|
@huseyincavusbi If you want to tackle those in a follow up PR, feel free! The Notebook Checks that fail intermittently have given me pause. I want to keep some level of numerical comparison, but I haven't searched very deeply for a possible solution. Let me know if you find anything, I am going to merge this work as done! Thanks for taking this on, it looks great! |
|
Thanks for the merge, @jlarson4! I really enjoyed working on this. |
Hi @jlarson4,
This PR implements MPS (Metal Performance Shaders) CI Runner Support as proposed in #1264.
The goal is to provide automated testing for the Apple Silicon research community while working within the limits of GitHub's Mac runners.
Key Changes:
tests/mps/test_mps_basic.pywith 11 smoke tests covering device detection, core tensor ops on Metal, andHookedTransformerforward passes/caching with small models (TinyStories-1M).mps-checksjob in.github/workflows/checks.yml. It usesmacos-latestand runs only on PRs/pushes to main.tests/conftest.pyto proactively clear the MPS cache after every test usingtorch.mps.empty_cache().model_bridge) to ensure stability.TRANSFORMERLENS_ALLOW_MPS=1to ensure safe defaults for Mac users.Type of change
Checklist: