|
28 | 28 |
|
29 | 29 | struct CachedGraph : public MPSCachedGraph {
|
30 | 30 | CachedGraph(MPSGraph* graph) : MPSCachedGraph(graph) {}
|
| 31 | + MPSGraphTensor* inputTensor_ = nil; |
31 | 32 | MPSGraphTensor* outputTensor_ = nil;
|
32 | 33 | };
|
33 | 34 |
|
34 | 35 | @autoreleasepool {
|
35 | 36 | string key = "fill_scalar_mps_impl" + getTensorsStringKey(self) + ":" + to_string(value.toDouble());
|
36 | 37 |
|
37 | 38 | auto cachedGraph = LookUpOrCreateCachedGraph<CachedGraph>(key, [&](auto mpsGraph, auto newCachedGraph) {
|
38 |
| - auto isBool = self.scalar_type() == c10::ScalarType::Bool; |
39 |
| - auto isUInt8 = self.scalar_type() == c10::ScalarType::Byte; |
40 |
| - auto dataType = !isUInt8 ? !isBool ? getMPSScalarType(self.scalar_type()) : MPSDataTypeInt8 : MPSDataTypeUInt32; |
41 |
| - // constantWithScalar does not work for boolTypes on MacOS-12.[34] |
42 |
| - // workaround by filing it as int8 tensor and than casting to bool |
43 |
| - // See https://github.com/pytorch/pytorch/issues/82427 |
44 |
| - // constantWithScalar does not work for UInt8 Types on MacOS-12.[34]/Ventura preview |
45 |
| - // workaround by filing it as uint32 tensor and than casting to uint8 |
46 |
| - // See https://github.com/pytorch/pytorch/issues/83692 |
47 |
| - MPSGraphTensor* inputTensor = [mpsGraph constantWithScalar:value.toDouble() |
48 |
| - shape:getMPSShape(self) |
49 |
| - dataType:dataType]; |
| 39 | + MPSGraphTensor* inputTensor = mpsGraphScalarPlaceHolder(mpsGraph, getMPSDataType(self.scalar_type())); |
50 | 40 | MPSGraphTensor* outputTensor = [mpsGraph identityWithTensor:inputTensor name:nil];
|
51 |
| - if (isBool) { |
52 |
| - outputTensor = [mpsGraph castTensor:outputTensor toType:MPSDataTypeBool name:@"constWithBool-workaround"]; |
53 |
| - } |
54 |
| - if (isUInt8) { |
55 |
| - outputTensor = [mpsGraph castTensor:outputTensor toType:MPSDataTypeUInt8 name:@"constWithUInt8-workaround"]; |
56 |
| - } |
57 |
| - |
| 41 | + newCachedGraph->inputTensor_ = inputTensor; |
58 | 42 | newCachedGraph->outputTensor_ = outputTensor;
|
59 | 43 | });
|
60 | 44 |
|
| 45 | + auto mpsScalar = getMPSScalar(value, self.scalar_type()); |
| 46 | + auto mpsScalarData = getMPSGraphTensorFromScalar(getCurrentMPSStream(), mpsScalar); |
| 47 | + NSDictionary<MPSGraphTensor*, MPSGraphTensorData*>* feeds = @{cachedGraph->inputTensor_ : mpsScalarData}; |
| 48 | + |
61 | 49 | Placeholder outputPlaceholder =
|
62 | 50 | Placeholder(cachedGraph->outputTensor_, needsCopyToOutput ? output : self, nullptr, !needsCopyToOutput);
|
63 | 51 |
|
64 | 52 | NSDictionary<MPSGraphTensor*, MPSGraphTensorData*>* results =
|
65 | 53 | @{outputPlaceholder.getMPSGraphTensor() : outputPlaceholder.getMPSGraphTensorData()};
|
66 | 54 |
|
67 |
| - runMPSGraph(getCurrentMPSStream(), cachedGraph->graph(), /*feeds*/ nil, results); |
| 55 | + runMPSGraph(getCurrentMPSStream(), cachedGraph->graph(), feeds, results); |
68 | 56 |
|
69 | 57 | if (needsCopyToOutput) {
|
70 | 58 | self.copy_(output);
|
|
0 commit comments