torch.bmm

torch.bmm(input, mat2, out_dtype=None, *, out=None) → Tensor

Performs a batch matrix-matrix product of matrices stored in input and mat2.

input and mat2 must be 3-D tensors each containing the same number of matrices.

If input is a (b×n×m) tensor and mat2 is a (b×m×p) tensor, then out will be a (b×n×p) tensor.

out_i = input_i @ mat2_i

This operator supports TensorFloat32.

On certain ROCm devices, when using float16 inputs this function will use different precision for the backward pass.

Note

This function does not broadcast. For broadcasting matrix products, see torch.matmul().
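A minimal sketch of the difference: torch.matmul() broadcasts a 2-D second argument across the batch, while torch.bmm() requires both arguments to be 3-D with matching batch sizes and raises an error otherwise.

```python
import torch

a = torch.randn(10, 3, 4)
b = torch.randn(4, 5)          # 2-D: no batch dimension

# torch.matmul broadcasts b across the batch dimension of a ...
out = torch.matmul(a, b)       # shape (10, 3, 5)

# ... but torch.bmm rejects non-3-D inputs because it never broadcasts.
try:
    torch.bmm(a, b)
except RuntimeError as e:
    print("bmm raised:", e)
```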

Parameters
  • input (Tensor) – the first batch of matrices to be multiplied

  • mat2 (Tensor) – the second batch of matrices to be multiplied

  • out_dtype (dtype, optional) – the dtype of the output tensor. Supported only on CUDA, and only for a torch.float32 output given torch.float16/torch.bfloat16 input dtypes

Keyword Arguments

out (Tensor, optional) – the output tensor.
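A short sketch of the out= keyword: preallocate a tensor of the result shape (b×n×p) and have bmm write into it in place.

```python
import torch

input = torch.randn(10, 3, 4)
mat2 = torch.randn(10, 4, 5)

# Preallocate the (b, n, p) result and write into it via out=.
res = torch.empty(10, 3, 5)
torch.bmm(input, mat2, out=res)
```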

Example:

>>> import torch
>>> input = torch.randn(10, 3, 4)
>>> mat2 = torch.randn(10, 4, 5)
>>> res = torch.bmm(input, mat2)
>>> res.size()
torch.Size([10, 3, 5])
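The result above matches the definition out_i = input_i @ mat2_i: stacking the per-batch matrix products along the batch dimension reproduces bmm's output, as this small sketch checks.

```python
import torch

input = torch.randn(10, 3, 4)
mat2 = torch.randn(10, 4, 5)

res = torch.bmm(input, mat2)

# out_i = input_i @ mat2_i, stacked along the batch dimension
manual = torch.stack([input[i] @ mat2[i] for i in range(input.size(0))])
```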
