@@ -24,57 +24,43 @@ Tensor dot_mps(const Tensor& self, const Tensor& other) {
24
24
using CachedGraph = MPSBinaryCachedGraph;
25
25
auto output = at::empty ({}, self.scalar_type (), c10::nullopt, kMPS , c10::nullopt, c10::nullopt);
26
26
27
- MPSGraphCache* cache_ = MPSGraphCache::getInstance ();
28
-
29
27
MPSStream* stream = at::mps::getCurrentMPSStream ();
30
28
31
29
@autoreleasepool {
32
30
string key = " dot_mps" + getTensorsStringKey ({self, other});
33
31
34
- CachedGraph* cachedGraph = static_cast <CachedGraph*>(cache_->LookUp (key));
35
- if (!cachedGraph) {
36
- mps::MPSCachedGraph* tmpCachedGraph = cache_->CreateCachedGraph (key, ^mps::MPSCachedGraph*() {
37
- CachedGraph* newCachedGraph = nil ;
38
-
39
- @autoreleasepool {
40
- MPSGraph* mpsGraph = mps::make_mps_graph ();
41
- newCachedGraph = new CachedGraph (mpsGraph);
42
-
43
- MPSGraphTensor* selfTensor = mps::mpsGraphRankedPlaceHolder (mpsGraph, self);
44
- MPSGraphTensor* otherTensor = mps::mpsGraphRankedPlaceHolder (mpsGraph, other);
45
-
46
- MPSGraphTensor* castSelf = nil ;
47
- MPSGraphTensor* castOther = nil ;
48
-
49
- if (self.scalar_type () == ScalarType::Short || self.scalar_type () == ScalarType::Byte ||
50
- self.scalar_type () == ScalarType::Char) {
51
- castSelf = [mpsGraph castTensor: selfTensor toType: MPSDataTypeInt32 name: @" castSelfTensor" ];
52
- castOther = [mpsGraph castTensor: otherTensor toType: MPSDataTypeInt32 name: @" castOtherTensor" ];
53
- } else {
54
- castSelf = selfTensor;
55
- castOther = otherTensor;
56
- }
57
-
58
- MPSGraphTensor* dot = [mpsGraph multiplicationWithPrimaryTensor: castSelf
59
- secondaryTensor: castOther
60
- name: @" multiplication" ];
61
-
62
- MPSGraphTensor* dotProductTensor = [mpsGraph reductionSumWithTensor: dot axes: nil name: @" dotProduct" ];
63
-
64
- if (self.scalar_type () == ScalarType::Short || self.scalar_type () == ScalarType::Byte ||
65
- self.scalar_type () == ScalarType::Char)
66
- dotProductTensor = [mpsGraph castTensor: dotProductTensor
67
- toType: getMPSDataType (self )
68
- name: @" castDotProductTensor" ];
69
-
70
- newCachedGraph->inputTensor_ = selfTensor;
71
- newCachedGraph->otherTensor_ = otherTensor;
72
- newCachedGraph->outputTensor_ = dotProductTensor;
73
- }
74
- return newCachedGraph;
75
- });
76
- cachedGraph = static_cast <CachedGraph*>(tmpCachedGraph);
77
- }
32
+ auto cachedGraph = LookUpOrCreateCachedGraph<CachedGraph>(key, [&](auto mpsGraph, auto newCachedGraph) {
33
+ MPSGraphTensor* selfTensor = mpsGraphRankedPlaceHolder (mpsGraph, self);
34
+ MPSGraphTensor* otherTensor = mpsGraphRankedPlaceHolder (mpsGraph, other);
35
+
36
+ MPSGraphTensor* castSelf = nil ;
37
+ MPSGraphTensor* castOther = nil ;
38
+
39
+ if (self.scalar_type () == ScalarType::Short || self.scalar_type () == ScalarType::Byte ||
40
+ self.scalar_type () == ScalarType::Char) {
41
+ castSelf = [mpsGraph castTensor: selfTensor toType: MPSDataTypeInt32 name: @" castSelfTensor" ];
42
+ castOther = [mpsGraph castTensor: otherTensor toType: MPSDataTypeInt32 name: @" castOtherTensor" ];
43
+ } else {
44
+ castSelf = selfTensor;
45
+ castOther = otherTensor;
46
+ }
47
+
48
+ MPSGraphTensor* dot = [mpsGraph multiplicationWithPrimaryTensor: castSelf
49
+ secondaryTensor: castOther
50
+ name: @" multiplication" ];
51
+
52
+ MPSGraphTensor* dotProductTensor = [mpsGraph reductionSumWithTensor: dot axes: nil name: @" dotProduct" ];
53
+
54
+ if (self.scalar_type () == ScalarType::Short || self.scalar_type () == ScalarType::Byte ||
55
+ self.scalar_type () == ScalarType::Char)
56
+ dotProductTensor = [mpsGraph castTensor: dotProductTensor
57
+ toType: getMPSDataType (self )
58
+ name: @" castDotProductTensor" ];
59
+
60
+ newCachedGraph->inputTensor_ = selfTensor;
61
+ newCachedGraph->otherTensor_ = otherTensor;
62
+ newCachedGraph->outputTensor_ = dotProductTensor;
63
+ });
78
64
79
65
Placeholder selfPlaceholder = Placeholder (cachedGraph->inputTensor_ , self);
80
66
Placeholder otherPlaceholder = Placeholder (cachedGraph->otherTensor_ , other);
@@ -110,64 +96,51 @@ Tensor dot_mps(const Tensor& self, const Tensor& other) {
110
96
c10::MaybeOwned<Tensor> self_ = expand_size (self, {mat.size (0 )});
111
97
auto betaval = beta_.toComplexDouble ();
112
98
113
- struct CachedGraph : public mps :: MPSCachedGraph {
99
+ struct CachedGraph : public MPSCachedGraph {
114
100
CachedGraph (MPSGraph* graph) : MPSCachedGraph(graph) {}
115
101
MPSGraphTensor* selfTensor_ = nil ;
116
102
MPSGraphTensor* matMulVecTensor_ = nil ;
117
103
MPSGraphTensor* outputTensor_ = nil ;
118
104
};
119
- mps::MPSGraphCache* cache_ = mps::MPSGraphCache::getInstance ();
120
105
121
106
MPSStream* stream = at::mps::getCurrentMPSStream ();
122
107
Tensor matMulVec = at::mm (mat, vec.unsqueeze (1 )).squeeze (1 );
123
108
124
109
@autoreleasepool {
125
110
string key = " addmv_out_mps_impl" + getTensorsStringKey ({self, matMulVec}) + " :" + to_string (beta_.toDouble ()) +
126
111
" :" + to_string (alpha_.toDouble ());
127
- CachedGraph* cachedGraph = nil ;
128
- if (!cachedGraph) {
129
- mps::MPSCachedGraph* tmpCachedGraph = cache_->CreateCachedGraph (key, ^mps::MPSCachedGraph*() {
130
- CachedGraph* newCachedGraph = nil ;
131
-
132
- @autoreleasepool {
133
- MPSGraph* mpsGraph = mps::make_mps_graph ();
134
- newCachedGraph = new CachedGraph (mpsGraph);
135
-
136
- MPSGraphTensor* matMulVecTensor = mps::mpsGraphRankedPlaceHolder (mpsGraph, matMulVec);
137
- MPSGraphTensor* selfTensor = mps::mpsGraphRankedPlaceHolder (mpsGraph, self);
138
-
139
- // Intermediates for beta and alpha
140
- MPSGraphTensor* alphaTensor = [mpsGraph constantWithScalar: alpha_.toDouble ()
141
- dataType: getMPSScalarType (mat.scalar_type ())];
142
-
143
- // Intermediates for multiplying by beta and alpha
144
- MPSGraphTensor* productTimesAlphaTensor = [mpsGraph multiplicationWithPrimaryTensor: matMulVecTensor
145
- secondaryTensor: alphaTensor
146
- name: @" MM/alpha*(mat@vec)" ];
147
- newCachedGraph->outputTensor_ = productTimesAlphaTensor;
148
-
149
- if (betaval != 0.0 ) {
150
- MPSGraphTensor* betaTensor = [mpsGraph constantWithScalar: beta_.toDouble ()
151
- dataType: getMPSScalarType (self .scalar_type ())];
152
-
153
- MPSGraphTensor* selfTimesBetaTensor = [mpsGraph multiplicationWithPrimaryTensor: selfTensor
154
- secondaryTensor: betaTensor
155
- name: @" MM/beta*input" ];
156
-
157
- MPSGraphTensor* outputTensor = [mpsGraph additionWithPrimaryTensor: productTimesAlphaTensor
158
- secondaryTensor: selfTimesBetaTensor
159
- name: @" MM/beta*input + alpha*(mat@vec)" ];
160
-
161
- newCachedGraph->outputTensor_ = outputTensor;
162
- }
163
-
164
- newCachedGraph->selfTensor_ = selfTensor;
165
- newCachedGraph->matMulVecTensor_ = matMulVecTensor;
166
- }
167
- return newCachedGraph;
168
- });
169
- cachedGraph = static_cast <CachedGraph*>(tmpCachedGraph);
170
- }
112
+ auto cachedGraph = LookUpOrCreateCachedGraph<CachedGraph>(key, [&](auto mpsGraph, auto newCachedGraph) {
113
+ MPSGraphTensor* matMulVecTensor = mpsGraphRankedPlaceHolder (mpsGraph, matMulVec);
114
+ MPSGraphTensor* selfTensor = mpsGraphRankedPlaceHolder (mpsGraph, self);
115
+
116
+ // Intermediates for beta and alpha
117
+ MPSGraphTensor* alphaTensor = [mpsGraph constantWithScalar: alpha_.toDouble ()
118
+ dataType: getMPSScalarType (mat.scalar_type ())];
119
+
120
+ // Intermediates for multiplying by beta and alpha
121
+ MPSGraphTensor* productTimesAlphaTensor = [mpsGraph multiplicationWithPrimaryTensor: matMulVecTensor
122
+ secondaryTensor: alphaTensor
123
+ name: @" MM/alpha*(mat@vec)" ];
124
+ newCachedGraph->outputTensor_ = productTimesAlphaTensor;
125
+
126
+ if (betaval != 0.0 ) {
127
+ MPSGraphTensor* betaTensor = [mpsGraph constantWithScalar: beta_.toDouble ()
128
+ dataType: getMPSScalarType (self .scalar_type ())];
129
+
130
+ MPSGraphTensor* selfTimesBetaTensor = [mpsGraph multiplicationWithPrimaryTensor: selfTensor
131
+ secondaryTensor: betaTensor
132
+ name: @" MM/beta*input" ];
133
+
134
+ MPSGraphTensor* outputTensor = [mpsGraph additionWithPrimaryTensor: productTimesAlphaTensor
135
+ secondaryTensor: selfTimesBetaTensor
136
+ name: @" MM/beta*input + alpha*(mat@vec)" ];
137
+
138
+ newCachedGraph->outputTensor_ = outputTensor;
139
+ }
140
+
141
+ newCachedGraph->selfTensor_ = selfTensor;
142
+ newCachedGraph->matMulVecTensor_ = matMulVecTensor;
143
+ });
171
144
172
145
Placeholder matMulVecPlaceholder = Placeholder (cachedGraph->matMulVecTensor_ , matMulVec);
173
146
Placeholder outputPlaceholder = Placeholder (cachedGraph->outputTensor_ , result);
@@ -182,7 +155,7 @@ Tensor dot_mps(const Tensor& self, const Tensor& other) {
182
155
NSDictionary <MPSGraphTensor*, MPSGraphTensorData*>* results =
183
156
@{outputPlaceholder.getMPSGraphTensor () : outputPlaceholder.getMPSGraphTensorData ()};
184
157
185
- mps:: runMPSGraph (stream, cachedGraph->graph (), feeds, results);
158
+ runMPSGraph (stream, cachedGraph->graph (), feeds, results);
186
159
}
187
160
188
161
return result;
0 commit comments