
Commit f8c6861

malfet authored and pytorchmergebot committed
[MPS][BE] Introduce LookUpOrCreateCachedGraph (pytorch#99422)
A template that replaces the following common pattern:

```cpp
MPSGraphCache* cache_ = MPSGraphCache::getInstance();
CachedGraph* cachedGraph = cache_->LookUpAs<CachedGraph>(key);
if (!cachedGraph) {
  cachedGraph = cache_->CreateCachedGraphAs<CachedGraph>(key, ^MPSCachedGraph*() {
    CachedGraph* newCachedGraph = nil;
    @autoreleasepool {
      MPSGraph* mpsGraph = make_mps_graph();
      newCachedGraph = new PoolingCachedGraph(mpsGraph);
      ...
    }
    return newCachedGraph;
  });
}
```

with

```cpp
auto cachedGraph = LookUpOrCreateCachedGraph<CachedGraph>(key, [&](auto mpsGraph, auto newCachedGraph) {
  ...
});
```

Fixes a memory leak in addmv_out_mps_impl, where new entries were added to the cache without doing the lookup first.

Pull Request resolved: pytorch#99422
Approved by: https://github.com/albanD, https://github.com/kulinseth
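For orientation, here is a condensed sketch of how a call site drives the new helper end to end. `ExampleCachedGraph`, the `"example_op"` key, and `self` (standing in for an input `Tensor`) are illustrative placeholders rather than code from this commit; the helper itself and utilities such as `getTensorsStringKey`, `mpsGraphRankedPlaceHolder`, and `identityWithTensor:name:` are taken from the diffs below:

```cpp
// Hypothetical call site, modeled on the BinaryOps.mm / Blas.mm / ConstantOps.mm changes in this commit.
struct ExampleCachedGraph : public MPSCachedGraph {
  ExampleCachedGraph(MPSGraph* graph) : MPSCachedGraph(graph) {}
  MPSGraphTensor* inputTensor_ = nil;
  MPSGraphTensor* outputTensor_ = nil;
};

@autoreleasepool {
  // The key identifies the shape/dtype combination whose graph should be reused.
  string key = "example_op" + getTensorsStringKey({self});
  // LookUpAs<ExampleCachedGraph>(key) runs first; the lambda executes only on a cache miss.
  auto cachedGraph = LookUpOrCreateCachedGraph<ExampleCachedGraph>(key, [&](auto mpsGraph, auto newCachedGraph) {
    newCachedGraph->inputTensor_ = mpsGraphRankedPlaceHolder(mpsGraph, self);
    newCachedGraph->outputTensor_ = [mpsGraph identityWithTensor:newCachedGraph->inputTensor_ name:nil];
  });
  // cachedGraph is valid here whether it was found in the cache or freshly created.
}
```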
1 parent d29cf18 commit f8c6861

34 files changed (+4018 −5544 lines)

aten/src/ATen/native/mps/OperationUtils.h

+19
```diff
@@ -281,6 +281,25 @@ struct MPSGraphCache
 
 };
 
+// Common template for creating graph with a specified cache if missing
+template<typename T>
+inline T* LookUpOrCreateCachedGraph(const std::string& key, std::function<void(MPSGraph*, T*)> instantiate) {
+  auto cache_ = MPSGraphCache::getInstance();
+  if (auto rc = cache_->LookUpAs<T>(key)) {
+    return rc;
+  }
+  return cache_->CreateCachedGraphAs<T>(key, ^mps::MPSCachedGraph*() {
+    T* newCachedGraph = nil;
+    @autoreleasepool {
+      // Initialize graph
+      auto mpsGraph = mps::make_mps_graph();
+      newCachedGraph = new T(mpsGraph);
+      instantiate(mpsGraph, newCachedGraph);
+    }
+    return newCachedGraph;
+  });
+}
+
 // Common math operations
 MPSGraphTensor* log1p(MPSGraph* mpsGraph, MPSGraphTensor* inputTensor);
```

aten/src/ATen/native/mps/operations/Activation.mm

+718 −1,147
Large diffs are not rendered by default.

aten/src/ATen/native/mps/operations/BinaryOps.mm

+33 −46
```diff
@@ -106,55 +106,42 @@ void binaryOpTensor(const Tensor& self,
     }
   }
 
-  MPSGraphCache* cache_ = MPSGraphCache::getInstance();
   @autoreleasepool {
     string key = op_name + getTensorsStringKey({self, other, output_});
-    BinaryOpCachedGraph* cachedGraph = static_cast<BinaryOpCachedGraph*>(cache_->LookUp(key));
-
-    if (!cachedGraph) {
-      MPSCachedGraph* tmpCachedGraph = cache_->CreateCachedGraph(key, ^MPSCachedGraph*() {
-        BinaryOpCachedGraph* newCachedGraph = nil;
-        @autoreleasepool {
-          MPSGraph* mpsGraph = make_mps_graph();
-          newCachedGraph = new BinaryOpCachedGraph(mpsGraph);
-          newCachedGraph->primaryTensor =
-              mpsGraphRankedPlaceHolder(mpsGraph, getMPSScalarType(inputDataType), getMPSShape(self));
-          newCachedGraph->secondaryTensor =
-              mpsGraphRankedPlaceHolder(mpsGraph, getMPSScalarType(otherDataType), getMPSShape(other));
-
-          MPSGraphTensor* primaryCastTensor = newCachedGraph->primaryTensor;
-          MPSGraphTensor* secondaryCastTensor = newCachedGraph->secondaryTensor;
-
-          // this type inference is only required at the time of graph creation
-          ScalarType common_dtype = c10::promoteTypes(inputDataType, otherDataType);
-          if (isIntegralType(common_dtype, true)) {
-            // integer inputs must be cast to float, if output is float
-            if (isFloatingType(outputDataType)) {
-              common_dtype = outputDataType;
-              // in boolean comparison ops with signed vs. unsigned integers, we always cast to the unsigned type
-            } else if (outputDataType == ScalarType::Bool &&
-                       (inputDataType == ScalarType::Byte || otherDataType == ScalarType::Byte)) {
-              common_dtype = ScalarType::Byte;
-            }
-          }
-          if (inputDataType != common_dtype) {
-            primaryCastTensor = castMPSTensor(mpsGraph, newCachedGraph->primaryTensor, common_dtype);
-          }
-          if (otherDataType != common_dtype) {
-            secondaryCastTensor = castMPSTensor(mpsGraph, newCachedGraph->secondaryTensor, common_dtype);
-          }
-          newCachedGraph->outputTensor = binaryBlock(newCachedGraph, primaryCastTensor, secondaryCastTensor);
-          // Cast output tensor to an expected type if needed, which addresses discrepancy when int64 scalar is added to
-          // int32 tensor Output tensor should have been promoted but it remains an int32 tensor
-          if (outputDataType != common_dtype ||
-              [newCachedGraph->outputTensor dataType] != getMPSDataType(outputDataType)) {
-            newCachedGraph->outputTensor = castMPSTensor(mpsGraph, newCachedGraph->outputTensor, outputDataType);
-          }
+    auto cachedGraph = LookUpOrCreateCachedGraph<BinaryOpCachedGraph>(key, [&](auto mpsGraph, auto newCachedGraph) {
+      newCachedGraph->primaryTensor =
+          mpsGraphRankedPlaceHolder(mpsGraph, getMPSScalarType(inputDataType), getMPSShape(self));
+      newCachedGraph->secondaryTensor =
+          mpsGraphRankedPlaceHolder(mpsGraph, getMPSScalarType(otherDataType), getMPSShape(other));
+
+      MPSGraphTensor* primaryCastTensor = newCachedGraph->primaryTensor;
+      MPSGraphTensor* secondaryCastTensor = newCachedGraph->secondaryTensor;
+
+      // this type inference is only required at the time of graph creation
+      ScalarType common_dtype = c10::promoteTypes(inputDataType, otherDataType);
+      if (isIntegralType(common_dtype, true)) {
+        // integer inputs must be cast to float, if output is float
+        if (isFloatingType(outputDataType)) {
+          common_dtype = outputDataType;
+          // in boolean comparison ops with signed vs. unsigned integers, we always cast to the unsigned type
+        } else if (outputDataType == ScalarType::Bool &&
+                   (inputDataType == ScalarType::Byte || otherDataType == ScalarType::Byte)) {
+          common_dtype = ScalarType::Byte;
         }
-        return newCachedGraph;
-      });
-      cachedGraph = static_cast<BinaryOpCachedGraph*>(tmpCachedGraph);
-    }
+      }
+      if (inputDataType != common_dtype) {
+        primaryCastTensor = castMPSTensor(mpsGraph, newCachedGraph->primaryTensor, common_dtype);
+      }
+      if (otherDataType != common_dtype) {
+        secondaryCastTensor = castMPSTensor(mpsGraph, newCachedGraph->secondaryTensor, common_dtype);
+      }
+      newCachedGraph->outputTensor = binaryBlock(newCachedGraph, primaryCastTensor, secondaryCastTensor);
+      // Cast output tensor to an expected type if needed, which addresses discrepancy when int64 scalar is added to
+      // int32 tensor Output tensor should have been promoted but it remains an int32 tensor
+      if (outputDataType != common_dtype || [newCachedGraph->outputTensor dataType] != getMPSDataType(outputDataType)) {
+        newCachedGraph->outputTensor = castMPSTensor(mpsGraph, newCachedGraph->outputTensor, outputDataType);
+      }
+    });
 
     NSMutableDictionary* feeds = [[NSMutableDictionary new] autorelease];
     Placeholder selfPlaceholder;
```

aten/src/ATen/native/mps/operations/Blas.mm

+66 −93
```diff
@@ -24,57 +24,43 @@ Tensor dot_mps(const Tensor& self, const Tensor& other) {
   using CachedGraph = MPSBinaryCachedGraph;
   auto output = at::empty({}, self.scalar_type(), c10::nullopt, kMPS, c10::nullopt, c10::nullopt);
 
-  MPSGraphCache* cache_ = MPSGraphCache::getInstance();
-
   MPSStream* stream = at::mps::getCurrentMPSStream();
 
   @autoreleasepool {
     string key = "dot_mps" + getTensorsStringKey({self, other});
 
-    CachedGraph* cachedGraph = static_cast<CachedGraph*>(cache_->LookUp(key));
-    if (!cachedGraph) {
-      mps::MPSCachedGraph* tmpCachedGraph = cache_->CreateCachedGraph(key, ^mps::MPSCachedGraph*() {
-        CachedGraph* newCachedGraph = nil;
-
-        @autoreleasepool {
-          MPSGraph* mpsGraph = mps::make_mps_graph();
-          newCachedGraph = new CachedGraph(mpsGraph);
-
-          MPSGraphTensor* selfTensor = mps::mpsGraphRankedPlaceHolder(mpsGraph, self);
-          MPSGraphTensor* otherTensor = mps::mpsGraphRankedPlaceHolder(mpsGraph, other);
-
-          MPSGraphTensor* castSelf = nil;
-          MPSGraphTensor* castOther = nil;
-
-          if (self.scalar_type() == ScalarType::Short || self.scalar_type() == ScalarType::Byte ||
-              self.scalar_type() == ScalarType::Char) {
-            castSelf = [mpsGraph castTensor:selfTensor toType:MPSDataTypeInt32 name:@"castSelfTensor"];
-            castOther = [mpsGraph castTensor:otherTensor toType:MPSDataTypeInt32 name:@"castOtherTensor"];
-          } else {
-            castSelf = selfTensor;
-            castOther = otherTensor;
-          }
-
-          MPSGraphTensor* dot = [mpsGraph multiplicationWithPrimaryTensor:castSelf
-                                                          secondaryTensor:castOther
-                                                                     name:@"multiplication"];
-
-          MPSGraphTensor* dotProductTensor = [mpsGraph reductionSumWithTensor:dot axes:nil name:@"dotProduct"];
-
-          if (self.scalar_type() == ScalarType::Short || self.scalar_type() == ScalarType::Byte ||
-              self.scalar_type() == ScalarType::Char)
-            dotProductTensor = [mpsGraph castTensor:dotProductTensor
-                                             toType:getMPSDataType(self)
-                                               name:@"castDotProductTensor"];
-
-          newCachedGraph->inputTensor_ = selfTensor;
-          newCachedGraph->otherTensor_ = otherTensor;
-          newCachedGraph->outputTensor_ = dotProductTensor;
-        }
-        return newCachedGraph;
-      });
-      cachedGraph = static_cast<CachedGraph*>(tmpCachedGraph);
-    }
+    auto cachedGraph = LookUpOrCreateCachedGraph<CachedGraph>(key, [&](auto mpsGraph, auto newCachedGraph) {
+      MPSGraphTensor* selfTensor = mpsGraphRankedPlaceHolder(mpsGraph, self);
+      MPSGraphTensor* otherTensor = mpsGraphRankedPlaceHolder(mpsGraph, other);
+
+      MPSGraphTensor* castSelf = nil;
+      MPSGraphTensor* castOther = nil;
+
+      if (self.scalar_type() == ScalarType::Short || self.scalar_type() == ScalarType::Byte ||
+          self.scalar_type() == ScalarType::Char) {
+        castSelf = [mpsGraph castTensor:selfTensor toType:MPSDataTypeInt32 name:@"castSelfTensor"];
+        castOther = [mpsGraph castTensor:otherTensor toType:MPSDataTypeInt32 name:@"castOtherTensor"];
+      } else {
+        castSelf = selfTensor;
+        castOther = otherTensor;
+      }
+
+      MPSGraphTensor* dot = [mpsGraph multiplicationWithPrimaryTensor:castSelf
+                                                      secondaryTensor:castOther
+                                                                 name:@"multiplication"];
+
+      MPSGraphTensor* dotProductTensor = [mpsGraph reductionSumWithTensor:dot axes:nil name:@"dotProduct"];
+
+      if (self.scalar_type() == ScalarType::Short || self.scalar_type() == ScalarType::Byte ||
+          self.scalar_type() == ScalarType::Char)
+        dotProductTensor = [mpsGraph castTensor:dotProductTensor
+                                         toType:getMPSDataType(self)
+                                           name:@"castDotProductTensor"];
+
+      newCachedGraph->inputTensor_ = selfTensor;
+      newCachedGraph->otherTensor_ = otherTensor;
+      newCachedGraph->outputTensor_ = dotProductTensor;
+    });
 
     Placeholder selfPlaceholder = Placeholder(cachedGraph->inputTensor_, self);
     Placeholder otherPlaceholder = Placeholder(cachedGraph->otherTensor_, other);
@@ -110,64 +96,51 @@ Tensor dot_mps(const Tensor& self, const Tensor& other) {
   c10::MaybeOwned<Tensor> self_ = expand_size(self, {mat.size(0)});
   auto betaval = beta_.toComplexDouble();
 
-  struct CachedGraph : public mps::MPSCachedGraph {
+  struct CachedGraph : public MPSCachedGraph {
     CachedGraph(MPSGraph* graph) : MPSCachedGraph(graph) {}
     MPSGraphTensor* selfTensor_ = nil;
     MPSGraphTensor* matMulVecTensor_ = nil;
     MPSGraphTensor* outputTensor_ = nil;
   };
-  mps::MPSGraphCache* cache_ = mps::MPSGraphCache::getInstance();
 
   MPSStream* stream = at::mps::getCurrentMPSStream();
   Tensor matMulVec = at::mm(mat, vec.unsqueeze(1)).squeeze(1);
 
   @autoreleasepool {
     string key = "addmv_out_mps_impl" + getTensorsStringKey({self, matMulVec}) + ":" + to_string(beta_.toDouble()) +
         ":" + to_string(alpha_.toDouble());
-    CachedGraph* cachedGraph = nil;
-    if (!cachedGraph) {
-      mps::MPSCachedGraph* tmpCachedGraph = cache_->CreateCachedGraph(key, ^mps::MPSCachedGraph*() {
-        CachedGraph* newCachedGraph = nil;
-
-        @autoreleasepool {
-          MPSGraph* mpsGraph = mps::make_mps_graph();
-          newCachedGraph = new CachedGraph(mpsGraph);
-
-          MPSGraphTensor* matMulVecTensor = mps::mpsGraphRankedPlaceHolder(mpsGraph, matMulVec);
-          MPSGraphTensor* selfTensor = mps::mpsGraphRankedPlaceHolder(mpsGraph, self);
-
-          // Intermediates for beta and alpha
-          MPSGraphTensor* alphaTensor = [mpsGraph constantWithScalar:alpha_.toDouble()
-                                                             dataType:getMPSScalarType(mat.scalar_type())];
-
-          // Intermediates for multiplying by beta and alpha
-          MPSGraphTensor* productTimesAlphaTensor = [mpsGraph multiplicationWithPrimaryTensor:matMulVecTensor
-                                                                               secondaryTensor:alphaTensor
-                                                                                          name:@"MM/alpha*(mat@vec)"];
-          newCachedGraph->outputTensor_ = productTimesAlphaTensor;
-
-          if (betaval != 0.0) {
-            MPSGraphTensor* betaTensor = [mpsGraph constantWithScalar:beta_.toDouble()
-                                                             dataType:getMPSScalarType(self.scalar_type())];
-
-            MPSGraphTensor* selfTimesBetaTensor = [mpsGraph multiplicationWithPrimaryTensor:selfTensor
-                                                                            secondaryTensor:betaTensor
-                                                                                       name:@"MM/beta*input"];
-
-            MPSGraphTensor* outputTensor = [mpsGraph additionWithPrimaryTensor:productTimesAlphaTensor
-                                                               secondaryTensor:selfTimesBetaTensor
-                                                                          name:@"MM/beta*input + alpha*(mat@vec)"];
-
-            newCachedGraph->outputTensor_ = outputTensor;
-          }
-
-          newCachedGraph->selfTensor_ = selfTensor;
-          newCachedGraph->matMulVecTensor_ = matMulVecTensor;
-        }
-        return newCachedGraph;
-      });
-      cachedGraph = static_cast<CachedGraph*>(tmpCachedGraph);
-    }
+    auto cachedGraph = LookUpOrCreateCachedGraph<CachedGraph>(key, [&](auto mpsGraph, auto newCachedGraph) {
+      MPSGraphTensor* matMulVecTensor = mpsGraphRankedPlaceHolder(mpsGraph, matMulVec);
+      MPSGraphTensor* selfTensor = mpsGraphRankedPlaceHolder(mpsGraph, self);
+
+      // Intermediates for beta and alpha
+      MPSGraphTensor* alphaTensor = [mpsGraph constantWithScalar:alpha_.toDouble()
+                                                         dataType:getMPSScalarType(mat.scalar_type())];
+
+      // Intermediates for multiplying by beta and alpha
+      MPSGraphTensor* productTimesAlphaTensor = [mpsGraph multiplicationWithPrimaryTensor:matMulVecTensor
+                                                                           secondaryTensor:alphaTensor
+                                                                                      name:@"MM/alpha*(mat@vec)"];
+      newCachedGraph->outputTensor_ = productTimesAlphaTensor;
+
+      if (betaval != 0.0) {
+        MPSGraphTensor* betaTensor = [mpsGraph constantWithScalar:beta_.toDouble()
+                                                         dataType:getMPSScalarType(self.scalar_type())];
+
+        MPSGraphTensor* selfTimesBetaTensor = [mpsGraph multiplicationWithPrimaryTensor:selfTensor
+                                                                        secondaryTensor:betaTensor
+                                                                                   name:@"MM/beta*input"];
+
+        MPSGraphTensor* outputTensor = [mpsGraph additionWithPrimaryTensor:productTimesAlphaTensor
+                                                           secondaryTensor:selfTimesBetaTensor
+                                                                      name:@"MM/beta*input + alpha*(mat@vec)"];
+
+        newCachedGraph->outputTensor_ = outputTensor;
+      }
+
+      newCachedGraph->selfTensor_ = selfTensor;
+      newCachedGraph->matMulVecTensor_ = matMulVecTensor;
+    });
 
     Placeholder matMulVecPlaceholder = Placeholder(cachedGraph->matMulVecTensor_, matMulVec);
     Placeholder outputPlaceholder = Placeholder(cachedGraph->outputTensor_, result);
@@ -182,7 +155,7 @@ Tensor dot_mps(const Tensor& self, const Tensor& other) {
     NSDictionary<MPSGraphTensor*, MPSGraphTensorData*>* results =
         @{outputPlaceholder.getMPSGraphTensor() : outputPlaceholder.getMPSGraphTensorData()};
 
-    mps::runMPSGraph(stream, cachedGraph->graph(), feeds, results);
+    runMPSGraph(stream, cachedGraph->graph(), feeds, results);
   }
 
   return result;
```
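To make the leak fix in `addmv_out_mps_impl` easier to spot in the hunk above, here is the change condensed, with the graph-building bodies elided in the same way the commit message elides them:

```cpp
// Removed pattern (leaked): cachedGraph was initialized to nil and never looked up,
// so every call created another cache entry under the same key.
CachedGraph* cachedGraph = nil;
if (!cachedGraph) {
  mps::MPSCachedGraph* tmpCachedGraph = cache_->CreateCachedGraph(key, ^mps::MPSCachedGraph*() { ... });
  cachedGraph = static_cast<CachedGraph*>(tmpCachedGraph);
}

// Replacement: the helper performs LookUpAs<CachedGraph>(key) before creating anything,
// so the graph-building lambda runs only on a cache miss and the entry is reused afterwards.
auto cachedGraph = LookUpOrCreateCachedGraph<CachedGraph>(key, [&](auto mpsGraph, auto newCachedGraph) { ... });
```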

aten/src/ATen/native/mps/operations/ConstantOps.mm

+23 −36
```diff
@@ -22,45 +22,32 @@
     MPSGraphTensor* outputTensor_ = nil;
   };
 
-  MPSGraphCache* cache_ = MPSGraphCache::getInstance();
-
   @autoreleasepool {
     string key = "fill_scalar_mps_impl" + getTensorsStringKey(self) + ":" + to_string(value.toDouble());
 
-    CachedGraph* cachedGraph = cache_->LookUpAs<CachedGraph>(key);
-    if (!cachedGraph) {
-      cachedGraph = cache_->CreateCachedGraphAs<CachedGraph>(key, ^MPSCachedGraph*() {
-        CachedGraph* newCachedGraph = nil;
-
-        @autoreleasepool {
-          MPSGraph* mpsGraph = make_mps_graph();
-          newCachedGraph = new CachedGraph(mpsGraph);
-          auto isBool = self.scalar_type() == c10::ScalarType::Bool;
-          auto isUInt8 = self.scalar_type() == c10::ScalarType::Byte;
-          auto dataType =
-              !isUInt8 ? !isBool ? getMPSScalarType(self.scalar_type()) : MPSDataTypeInt8 : MPSDataTypeUInt32;
-          // constantWithScalar does not work for boolTypes on MacOS-12.[34]
-          // workaround by filing it as int8 tensor and than casting to bool
-          // See https://github.com/pytorch/pytorch/issues/82427
-          // constantWithScalar does not work for UInt8 Types on MacOS-12.[34]/Ventura preview
-          // workaround by filing it as uint32 tensor and than casting to uint8
-          // See https://github.com/pytorch/pytorch/issues/83692
-          MPSGraphTensor* inputTensor = [mpsGraph constantWithScalar:value.toDouble()
-                                                                shape:getMPSShape(self)
-                                                             dataType:dataType];
-          MPSGraphTensor* outputTensor = [mpsGraph identityWithTensor:inputTensor name:nil];
-          if (isBool) {
-            outputTensor = [mpsGraph castTensor:outputTensor toType:MPSDataTypeBool name:@"constWithBool-workaround"];
-          }
-          if (isUInt8) {
-            outputTensor = [mpsGraph castTensor:outputTensor toType:MPSDataTypeUInt8 name:@"constWithUInt8-workaround"];
-          }
-
-          newCachedGraph->outputTensor_ = outputTensor;
-        }
-        return newCachedGraph;
-      });
-    }
+    auto cachedGraph = LookUpOrCreateCachedGraph<CachedGraph>(key, [&](auto mpsGraph, auto newCachedGraph) {
+      auto isBool = self.scalar_type() == c10::ScalarType::Bool;
+      auto isUInt8 = self.scalar_type() == c10::ScalarType::Byte;
+      auto dataType = !isUInt8 ? !isBool ? getMPSScalarType(self.scalar_type()) : MPSDataTypeInt8 : MPSDataTypeUInt32;
+      // constantWithScalar does not work for boolTypes on MacOS-12.[34]
+      // workaround by filing it as int8 tensor and than casting to bool
+      // See https://github.com/pytorch/pytorch/issues/82427
+      // constantWithScalar does not work for UInt8 Types on MacOS-12.[34]/Ventura preview
+      // workaround by filing it as uint32 tensor and than casting to uint8
+      // See https://github.com/pytorch/pytorch/issues/83692
+      MPSGraphTensor* inputTensor = [mpsGraph constantWithScalar:value.toDouble()
+                                                            shape:getMPSShape(self)
+                                                         dataType:dataType];
+      MPSGraphTensor* outputTensor = [mpsGraph identityWithTensor:inputTensor name:nil];
+      if (isBool) {
+        outputTensor = [mpsGraph castTensor:outputTensor toType:MPSDataTypeBool name:@"constWithBool-workaround"];
+      }
+      if (isUInt8) {
+        outputTensor = [mpsGraph castTensor:outputTensor toType:MPSDataTypeUInt8 name:@"constWithUInt8-workaround"];
+      }
+
+      newCachedGraph->outputTensor_ = outputTensor;
+    });
 
     Placeholder outputPlaceholder =
         Placeholder(cachedGraph->outputTensor_, needsCopyToOutput ? output : self, nullptr, !needsCopyToOutput);
```
