(perf) generate_mask functions optimizations #3203

Merged
merged 2 commits into from
May 20, 2025

Conversation

tafia
Contributor

@tafia tafia commented May 16, 2025

Pull Request Template

Checklist

  • Confirmed that cargo run-checks command has been executed.

There were 8 failing tests on my system (Asus PN50-E1, Ryzen 4700U), which seem unrelated to these changes:

run-checks errors
failures:
                                                                                                                                                                                                                    
---- tests::cube::kernel::conv2d::tests::nchw_to_nhwc_should_match_into_contiguous stdout ----
                                                                                                                                                                                                                    
thread 'tests::cube::kernel::conv2d::tests::nchw_to_nhwc_should_match_into_contiguous' panicked at crates/burn-wgpu/src/lib.rs:129:5:
Tensors are not approx eq:
  => Position 0: 0.8401348 != 0.4808194
     diff (rel = +2.72e-1, abs = +3.59e-1), tol (rel = +7.63e-6, abs = +1.88e-37)
  => Position 1: 0.2639603 != 0.8517316
     diff (rel = +5.27e-1, abs = +5.88e-1), tol (rel = +7.63e-6, abs = +1.88e-37)
  => Position 2: 0.09077653 != 0.08293253
     diff (rel = +4.52e-2, abs = +7.84e-3), tol (rel = +7.63e-6, abs = +1.88e-37)
  => Position 3: 0.64059836 != 0.18346767
     diff (rel = +5.55e-1, abs = +4.57e-1), tol (rel = +7.63e-6, abs = +1.88e-37)
  => Position 4: 0.3280799 != 0.057279117
     diff (rel = +7.03e-1, abs = +2.71e-1), tol (rel = +7.63e-6, abs = +1.88e-37)
474865 more errors...
                                                                                                                                                                                                                    
---- tests::cube::kernel::conv2d::tests::nchw_to_nhwc_should_match_into_contiguous_regression stdout ----
                                                                                                                                                                                                                    
thread 'tests::cube::kernel::conv2d::tests::nchw_to_nhwc_should_match_into_contiguous_regression' panicked at crates/burn-wgpu/src/lib.rs:129:5:
Tensors are not approx eq:
  => Position 288: 0.16456887 != 0.020139458
     diff (rel = +7.82e-1, abs = +1.44e-1), tol (rel = +7.63e-6, abs = +1.88e-37)
  => Position 289: 0.15798609 != 0.0511629
     diff (rel = +5.11e-1, abs = +1.07e-1), tol (rel = +7.63e-6, abs = +1.88e-37)
  => Position 290: 0.15135513 != 0.91991514
     diff (rel = +7.17e-1, abs = +7.69e-1), tol (rel = +7.63e-6, abs = +1.88e-37)
  => Position 291: 0 != 0.95116186
     diff (rel = +1.00e0, abs = +9.51e-1), tol (rel = +7.63e-6, abs = +1.88e-37)
  => Position 292: 0.165027 != 0.054931425
     diff (rel = +5.01e-1, abs = +1.10e-1), tol (rel = +7.63e-6, abs = +1.88e-37)
13 more errors...
                                                                                                                                                                                                                    
---- tests::cube::tensor::f32_ty::remainder::tests::should_support_remainder_basic stdout ----
                                                                                                                                                                                                                    
thread 'tests::cube::tensor::f32_ty::remainder::tests::should_support_remainder_basic' panicked at crates/burn-wgpu/src/lib.rs:129:5:
Tensors are not approx eq:
  => Position 5: 3 != 0
     diff (rel = +1.00e0, abs = +3.00e0), tol (rel = +7.63e-6, abs = +1.88e-37)
                                                                                                                                                                                                                    
---- tests::cube::tensor::f32_ty::remainder::tests::should_support_remainder_op stdout ----
                                                                                                                                                                                                                    
thread 'tests::cube::tensor::f32_ty::remainder::tests::should_support_remainder_op' panicked at crates/burn-wgpu/src/lib.rs:129:5:
Tensors are not approx eq:
  => Position 5: 3 != 0
     diff (rel = +1.00e0, abs = +3.00e0), tol (rel = +7.63e-6, abs = +1.88e-37)
                                                                                                                                                                                                                    
---- tests::cube::tensor::f32_ty::vector_norm::tests::test_normalize stdout ----
                                                                                                                                                                                                                    
thread 'tests::cube::tensor::f32_ty::vector_norm::tests::test_normalize' panicked at crates/burn-wgpu/src/lib.rs:129:5:
Tensors are not eq:
  => Position 1: 0.3333333 != 0.33333334
  => Position 3: 0.6666666 != 0.6666667
                                                                                                                                                                                                                    
---- tests::cube_fusion::tensor::f32_ty::remainder::tests::should_support_remainder_op stdout ----
                                                                                                                                                                                                                    
thread 'tests::cube_fusion::tensor::f32_ty::remainder::tests::should_support_remainder_op' panicked at crates/burn-wgpu/src/lib.rs:129:5:
Tensors are not approx eq:
  => Position 5: 3 != 0
     diff (rel = +1.00e0, abs = +3.00e0), tol (rel = +7.63e-6, abs = +1.88e-37)
                                                                                                                                                                                                                    
---- tests::cube_fusion::tensor::f32_ty::remainder::tests::should_support_remainder_basic stdout ----
                                                                                                                                                                                                                    
thread 'tests::cube_fusion::tensor::f32_ty::remainder::tests::should_support_remainder_basic' panicked at crates/burn-wgpu/src/lib.rs:129:5:
Tensors are not approx eq:
  => Position 5: 3 != 0
     diff (rel = +1.00e0, abs = +3.00e0), tol (rel = +7.63e-6, abs = +1.88e-37)
                                                                                                                                                                                                                    
---- tests::cube_fusion::tensor::f32_ty::vector_norm::tests::test_normalize stdout ----
                                                                                                                                                                                                                    
thread 'tests::cube_fusion::tensor::f32_ty::vector_norm::tests::test_normalize' panicked at crates/burn-wgpu/src/lib.rs:129:5:
Tensors are not eq:
  => Position 1: 0.3333333 != 0.33333334
  => Position 3: 0.6666666 != 0.6666667
                                                                                                                                                                                                                    
                                                                                                                                                                                                                    
failures:
    tests::cube::kernel::conv2d::tests::nchw_to_nhwc_should_match_into_contiguous
    tests::cube::kernel::conv2d::tests::nchw_to_nhwc_should_match_into_contiguous_regression
    tests::cube::tensor::f32_ty::remainder::tests::should_support_remainder_basic
    tests::cube::tensor::f32_ty::remainder::tests::should_support_remainder_op
    tests::cube::tensor::f32_ty::vector_norm::tests::test_normalize
    tests::cube_fusion::tensor::f32_ty::remainder::tests::should_support_remainder_basic
    tests::cube_fusion::tensor::f32_ty::remainder::tests::should_support_remainder_op
    tests::cube_fusion::tensor::f32_ty::vector_norm::tests::test_normalize
                                                                                                                                                                                                                    
test result: FAILED. 2967 passed; 8 failed; 16 ignored; 0 measured; 0 filtered out; finished in 60.44s
  • Made sure the book is up to date with changes in this PR.

Related Issues/PRs

No particular issue, as this is a simple performance-related PR.

Changes

There are 2 perf changes:

  • generate_autoregressive_mask: as per the TODO, use triangular tensors and expand instead of manually populating the mask
  • generate_padding_mask: avoid split_off calls and instead only take up to max_size elements when iterating
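The second change can be illustrated without Burn at all. The following is a minimal plain-Rust sketch, not the code from this PR: `PaddedBatch` and `pad_and_mask` are hypothetical names, and the buffers are flat `Vec`s rather than tensors. It assumes the mask marks every position equal to the pad token. The point is only that `iter().take(seq_len)` caps each sequence while borrowing it, so no `split_off` (and no mutation of the input) is needed.

```rust
/// Hypothetical stand-in for the padded result: flat, row-major buffers.
#[derive(Debug, PartialEq)]
pub struct PaddedBatch {
    /// `batch_size * seq_len` token ids, padded with the pad token.
    pub tokens: Vec<usize>,
    /// `true` where the position holds the pad token (assumed convention).
    pub mask: Vec<bool>,
    /// Padded length of every row.
    pub seq_len: usize,
}

/// Pads every sequence to a common length, capped at `max_size` if given.
pub fn pad_and_mask(
    pad_token: usize,
    sequences: &[Vec<usize>],
    max_size: Option<usize>,
) -> PaddedBatch {
    // Longest sequence in the batch, capped when a maximum is supplied.
    let longest = sequences.iter().map(Vec::len).max().unwrap_or(0);
    let seq_len = max_size.map_or(longest, |m| longest.min(m));

    // Start from an all-padding buffer and overwrite the real tokens.
    let mut tokens = vec![pad_token; sequences.len() * seq_len];
    for (row, seq) in sequences.iter().enumerate() {
        // `take(seq_len)` stops at the cap without splitting or
        // reallocating the input sequence.
        for (col, &tok) in seq.iter().take(seq_len).enumerate() {
            tokens[row * seq_len + col] = tok;
        }
    }

    // Mask is derived from the padded buffer: pad positions are `true`.
    let mask = tokens.iter().map(|&t| t == pad_token).collect();
    PaddedBatch { tokens, mask, seq_len }
}
```

Calling `pad_and_mask(0, &[vec![1, 2, 3, 4], vec![5, 6]], Some(3))` truncates the first sequence to three tokens and pads the second with a single `0`, leaving both inputs untouched.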

As explained in the commit messages, I am no expert in terms of code generation (for the autoregressive change). I don't really know if it ends up being better or not.

Testing

I mostly wrote a simple script to confirm that the tensors stay the same.

use burn::backend::Wgpu;
use burn::nn::attention::generate_autoregressive_mask;
use burn::tensor::{Bool, Tensor};

fn main() {
    let ssz = 5; // sequence length
    let bsz = 4; // batch size
    type MyBackend = Wgpu<f32, i32>;
    let device = burn::backend::wgpu::WgpuDevice::default();
    // New approach: build one lower-triangular mask, then expand it over the batch.
    let t: Tensor<MyBackend, 2, _> = Tensor::tril_mask([ssz, ssz], 0, &device);
    let t: Tensor<MyBackend, 3, Bool> = t.expand([bsz, ssz, ssz]);
    // Reference: the library function being optimized.
    let t2: Tensor<MyBackend, 3, Bool> = generate_autoregressive_mask(bsz, ssz, &device);
    // Print both tensors so they can be compared by eye.
    println!("tensor: {t}, {t2}");
}
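
The same equivalence can be sketched in plain Rust, with no Burn dependency, to show why a triangular mask plus a broadcast yields the same values as populating the mask by hand. This is an illustration under stated assumptions, not the PR's code: `mask_by_population` and `BroadcastMask` are hypothetical names, and `true` is assumed to mean "masked" (column index greater than row index), which is how I understand the `tril_mask` convention.

```rust
/// Loop-based construction: fill the future positions of each row,
/// then materialize one copy of the matrix per batch entry.
pub fn mask_by_population(batch_size: usize, seq_len: usize) -> Vec<bool> {
    let mut one = vec![false; seq_len * seq_len];
    for i in 0..seq_len.saturating_sub(1) {
        for j in (i + 1)..seq_len {
            one[i * seq_len + j] = true; // position j is in the future of i
        }
    }
    let mut out = Vec::with_capacity(batch_size * seq_len * seq_len);
    for _ in 0..batch_size {
        out.extend_from_slice(&one);
    }
    out
}

/// Triangular mask stored once and broadcast over the batch dimension.
pub struct BroadcastMask {
    tri: Vec<bool>,
    batch_size: usize,
    seq_len: usize,
}

impl BroadcastMask {
    pub fn new(batch_size: usize, seq_len: usize) -> Self {
        // Entry (row, col) is masked exactly when col > row.
        let tri = (0..seq_len * seq_len)
            .map(|k| k % seq_len > k / seq_len)
            .collect();
        Self { tri, batch_size, seq_len }
    }

    /// Reads entry (b, i, j); the batch index only gets bounds-checked,
    /// since every batch entry shares the same triangular matrix.
    pub fn get(&self, b: usize, i: usize, j: usize) -> bool {
        assert!(b < self.batch_size && i < self.seq_len && j < self.seq_len);
        self.tri[i * self.seq_len + j]
    }

    /// Materializes the broadcast mask for comparison with the loop version.
    pub fn to_vec(&self) -> Vec<bool> {
        let mut out = Vec::with_capacity(self.batch_size * self.tri.len());
        for _ in 0..self.batch_size {
            out.extend_from_slice(&self.tri);
        }
        out
    }
}
```

Comparing `BroadcastMask::new(4, 5).to_vec()` against `mask_by_population(4, 5)` with `assert_eq!` gives an automatic check instead of reading two printed tensors. Whether the expanded tensor is actually cheaper on a given backend is a separate question that this sketch does not answer.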

tafia added 2 commits May 16, 2025 15:49
This is merely doing what the TODO was suggesting (using tril instead of
triu).

I am not an expert on the generated code and I don't really know
how to validate that this is indeed more efficient.
@tafia tafia changed the title Autoregressive mask (perf) generate_mask function optimizations May 16, 2025
@tafia tafia changed the title (perf) generate_mask function optimizations (perf) generate_mask functions optimizations May 16, 2025

codecov bot commented May 16, 2025

Codecov Report

All modified and coverable lines are covered by tests ✅

Project coverage is 82.22%. Comparing base (db3fb49) to head (a3aaaad).
Report is 12 commits behind head on main.

Additional details and impacted files
@@            Coverage Diff             @@
##             main    #3203      +/-   ##
==========================================
+ Coverage   82.19%   82.22%   +0.02%     
==========================================
  Files         962      962              
  Lines      122542   122528      -14     
==========================================
+ Hits       100727   100751      +24     
+ Misses      21815    21777      -38     


@antimora antimora requested a review from nathanielsimard May 16, 2025 18:32
@tafia
Contributor Author

tafia commented May 19, 2025

The failing test is unrelated to the PR.

@nathanielsimard nathanielsimard merged commit d12aa69 into tracel-ai:main May 20, 2025
12 of 13 checks passed
@tafia tafia deleted the autoregressive_mask branch May 21, 2025 03:08