Commit fb38ab7

Fix for MPS regression in #122016 and #123178 (#123385)
Fixes #122016 and #123178. This regression is related to an OS-side change that requires a slight adjustment on the PyTorch side to restore the previous behavior. Additionally, we cleared out pre-macOS 13 workarounds.

Before the fix on macOS 14.4:

```
python -c "import torch;x=torch.zeros(3, device='mps');x[1] = 1; x[2] = 3; print(x)"
tensor([0., 3., 3.], device='mps:0')
```

After the fix:

```
python -c "import torch;x=torch.zeros(3, device='mps');x[1] = 1; x[2] = 3; print(x)"
tensor([0., 1., 3.], device='mps:0')
```

This also fixes complex number initialization and, as such, makes `nn.functional.rms_norm` pass on macOS 14+.

Co-authored-by: Nikita Shulga <2453524+malfet@users.noreply.github.com>
Pull Request resolved: #123234
Approved by: https://github.com/malfet, https://github.com/kulinseth
(cherry picked from commit 05289a2)
Co-authored-by: Joona Havukainen <jhavukainen@apple.com>
1 parent 23961ce commit fb38ab7
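For reference, the repro from the commit message can be written as a small sanity check (a sketch, not part of the commit; it assumes an MPS-capable machine with this fix applied):

```python
import torch

# Minimal check of the behavior restored by this fix (assumes MPS is available
# and the patch is applied; before the fix, macOS 14.4 returned [0., 3., 3.]).
x = torch.zeros(3, device="mps")
x[1] = 1
x[2] = 3
assert torch.equal(x.cpu(), torch.tensor([0.0, 1.0, 3.0])), x
```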

2 files changed: +9, −20 lines

aten/src/ATen/native/mps/operations/ConstantOps.mm

8 additions, 20 deletions
```diff
@@ -28,43 +28,31 @@
 
   struct CachedGraph : public MPSCachedGraph {
     CachedGraph(MPSGraph* graph) : MPSCachedGraph(graph) {}
+    MPSGraphTensor* inputTensor_ = nil;
     MPSGraphTensor* outputTensor_ = nil;
   };
 
   @autoreleasepool {
     string key = "fill_scalar_mps_impl" + getTensorsStringKey(self) + ":" + to_string(value.toDouble());
 
     auto cachedGraph = LookUpOrCreateCachedGraph<CachedGraph>(key, [&](auto mpsGraph, auto newCachedGraph) {
-      auto isBool = self.scalar_type() == c10::ScalarType::Bool;
-      auto isUInt8 = self.scalar_type() == c10::ScalarType::Byte;
-      auto dataType = !isUInt8 ? !isBool ? getMPSScalarType(self.scalar_type()) : MPSDataTypeInt8 : MPSDataTypeUInt32;
-      // constantWithScalar does not work for boolTypes on MacOS-12.[34]
-      // workaround by filing it as int8 tensor and than casting to bool
-      // See https://github.com/pytorch/pytorch/issues/82427
-      // constantWithScalar does not work for UInt8 Types on MacOS-12.[34]/Ventura preview
-      // workaround by filing it as uint32 tensor and than casting to uint8
-      // See https://github.com/pytorch/pytorch/issues/83692
-      MPSGraphTensor* inputTensor = [mpsGraph constantWithScalar:value.toDouble()
-                                                            shape:getMPSShape(self)
-                                                         dataType:dataType];
+      MPSGraphTensor* inputTensor = mpsGraphScalarPlaceHolder(mpsGraph, getMPSDataType(self.scalar_type()));
       MPSGraphTensor* outputTensor = [mpsGraph identityWithTensor:inputTensor name:nil];
-      if (isBool) {
-        outputTensor = [mpsGraph castTensor:outputTensor toType:MPSDataTypeBool name:@"constWithBool-workaround"];
-      }
-      if (isUInt8) {
-        outputTensor = [mpsGraph castTensor:outputTensor toType:MPSDataTypeUInt8 name:@"constWithUInt8-workaround"];
-      }
-
+      newCachedGraph->inputTensor_ = inputTensor;
       newCachedGraph->outputTensor_ = outputTensor;
     });
 
+    auto mpsScalar = getMPSScalar(value, self.scalar_type());
+    auto mpsScalarData = getMPSGraphTensorFromScalar(getCurrentMPSStream(), mpsScalar);
+    NSDictionary<MPSGraphTensor*, MPSGraphTensorData*>* feeds = @{cachedGraph->inputTensor_ : mpsScalarData};
+
     Placeholder outputPlaceholder =
         Placeholder(cachedGraph->outputTensor_, needsCopyToOutput ? output : self, nullptr, !needsCopyToOutput);
 
     NSDictionary<MPSGraphTensor*, MPSGraphTensorData*>* results =
         @{outputPlaceholder.getMPSGraphTensor() : outputPlaceholder.getMPSGraphTensorData()};
 
-    runMPSGraph(getCurrentMPSStream(), cachedGraph->graph(), /*feeds*/ nil, results);
+    runMPSGraph(getCurrentMPSStream(), cachedGraph->graph(), feeds, results);
 
     if (needsCopyToOutput) {
       self.copy_(output);
```
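In short, the old path baked the fill value into the cached graph via `constantWithScalar` (with extra bool/uint8 casting workarounds for macOS 12), while the new path builds the graph once with a scalar placeholder and passes the value through the `feeds` dictionary on every call. A rough Python-level spot-check of the affected `fill_` path might look like the sketch below (hypothetical, not part of the PR; it assumes an MPS device and, for the complex case, macOS 14+ with complex support enabled):

```python
import torch

# Sketch: exercise fill_ for dtypes that previously needed workarounds
# (bool, uint8) plus complex, which the commit message says is also fixed.
for dtype in (torch.bool, torch.uint8, torch.float32, torch.complex64):
    t = torch.empty(4, device="mps", dtype=dtype)
    t.fill_(1)
    expected = torch.ones(4, dtype=dtype)
    assert torch.equal(t.cpu(), expected), (dtype, t)
```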

test/test_mps.py

1 addition, 0 deletions
```diff
@@ -412,6 +412,7 @@ def mps_ops_modifier(ops):
         'mean',
         'ne',
         'neg',
+        'nn.functional.rms_norm',
         'nn.functional.padconstant',
         'nn.functional.padreflect',
         'nn.functional.padreplicate',
```
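The newly listed opinfo corresponds to `torch.nn.functional.rms_norm`. A minimal check of what the entry now exercises might look like this (a sketch; it assumes macOS 14+, an MPS device, and a PyTorch build that exposes `F.rms_norm`):

```python
import torch
import torch.nn.functional as F

# Sketch: rms_norm on MPS should now match the CPU reference on macOS 14+.
x = torch.randn(2, 8, device="mps")
out = F.rms_norm(x, [8])            # normalized over the last dimension
ref = F.rms_norm(x.cpu(), [8])
torch.testing.assert_close(out.cpu(), ref, rtol=1e-4, atol=1e-4)
```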
