6
6
func sin( _ x: Float ) -> Float {
7
7
return x // dummy implementation
8
8
}
9
-
10
9
@differentiating ( sin) // ok
11
- func vjpSin ( x: Float ) -> ( value: Float , pullback : ( Float ) -> Float ) {
10
+ func jvpSin ( x: @ nondiff Float ) -> ( value: Float , differential : ( Float ) -> ( Float ) ) {
12
11
return ( x, { $0 } )
13
12
}
14
13
@differentiating ( sin, wrt: x) // ok
15
14
func vjpSinExplicitWrt( x: Float ) -> ( value: Float , pullback: ( Float ) -> Float ) {
16
15
return ( x, { $0 } )
17
16
}
18
- @differentiating ( sin) // ok
19
- func jvpSin( x: @nondiff Float ) -> ( value: Float , differential: ( Float ) -> ( Float ) ) {
17
+
18
+ // expected-error @+1 {{a derivative already exists for 'sin'}}
19
+ @differentiating ( sin)
20
+ func vjpDuplicate( x: Float ) -> ( value: Float , pullback: ( Float ) -> Float ) {
20
21
return ( x, { $0 } )
21
22
}
22
-
23
23
// expected-error @+1 {{'@differentiating' attribute requires function to return a two-element tuple of type '(value: T..., pullback: (U.CotangentVector) -> T.CotangentVector...)' or '(value: T..., differential: (T.TangentVector...) -> U.TangentVector)'}}
24
24
@differentiating ( sin)
25
25
func jvpSinResultInvalid( x: @nondiff Float ) -> Float {
@@ -46,10 +46,6 @@ func generic<T : Differentiable>(_ x: T, _ y: T) -> T {
46
46
return x
47
47
}
48
48
@differentiating ( generic) // ok
49
- func vjpGeneric< T : Differentiable > ( x: T , y: T ) -> ( value: T , pullback: ( T . CotangentVector ) -> ( T . CotangentVector , T . CotangentVector ) ) {
50
- return ( x, { ( $0, $0) } )
51
- }
52
- @differentiating ( generic) // ok
53
49
func jvpGeneric< T : Differentiable > ( x: T , y: T ) -> ( value: T , differential: ( T . TangentVector , T . TangentVector ) -> T . TangentVector ) {
54
50
return ( x, { $0 + $1 } )
55
51
}
@@ -130,10 +126,6 @@ func vjpFoo<T : AdditiveArithmetic & Differentiable>(_ x: T) -> (value: T, pullb
130
126
return ( x, { $0 } )
131
127
}
132
128
@differentiating ( foo)
133
- func vjpFoo< T : FloatingPoint & Differentiable > ( _ x: T ) -> ( value: T , pullback: ( T . CotangentVector ) -> ( T . CotangentVector ) ) {
134
- return ( x, { $0 } )
135
- }
136
- @differentiating ( foo)
137
129
func vjpFooExtraGenericRequirements< T : FloatingPoint & Differentiable & BinaryInteger > ( _ x: T ) -> ( value: T , pullback: ( T ) -> ( T ) ) where T == T . CotangentVector {
138
130
return ( x, { $0 } )
139
131
}
@@ -177,26 +169,28 @@ extension AdditiveArithmetic where Self : Differentiable, Self == Self.Cotangent
177
169
protocol InstanceMethod : Differentiable {
178
170
// expected-note @+1 {{'foo' defined here}}
179
171
func foo( _ x: Self ) -> Self
172
+ func foo2( _ x: Self ) -> Self
180
173
// expected-note @+1 {{'bar' defined here}}
181
174
func bar< T : Differentiable > ( _ x: T ) -> Self
175
+ func bar2< T : Differentiable > ( _ x: T ) -> Self
182
176
}
183
177
184
178
extension InstanceMethod {
185
179
// If `Self` conforms to `Differentiable`, then `Self` is currently always inferred to be a differentiation parameter.
186
180
// expected-error @+2 {{function result's 'pullback' type does not match 'foo'}}
187
181
// expected-note @+2 {{'pullback' does not have expected type '(Self.CotangentVector) -> (Self.CotangentVector, Self.CotangentVector)'}}
188
182
@differentiating ( foo)
189
- func vjpFoo( x: Self ) -> ( value: Self , pullback: ( Self . CotangentVector ) -> Self . CotangentVector ) {
183
+ func vjpFoo( x: Self ) -> ( value: Self , pullback: ( CotangentVector ) -> CotangentVector ) {
190
184
return ( x, { $0 } )
191
185
}
192
186
193
187
@differentiating ( foo)
194
- func vjpFoo ( x: Self ) -> ( value: Self , pullback : ( Self . CotangentVector ) -> ( Self . CotangentVector , Self . CotangentVector ) ) {
195
- return ( x, { ( $0 , $0 ) } )
188
+ func jvpFoo ( x: Self ) -> ( value: Self , differential : ( TangentVector , TangentVector ) -> ( TangentVector ) ) {
189
+ return ( x, { $0 + $1 } )
196
190
}
197
191
198
192
@differentiating ( foo, wrt: ( self , x) )
199
- func vjpFooWrt( x: Self ) -> ( value: Self , pullback: ( Self . CotangentVector ) -> ( Self . CotangentVector , Self . CotangentVector ) ) {
193
+ func vjpFooWrt( x: Self ) -> ( value: Self , pullback: ( CotangentVector ) -> ( CotangentVector , CotangentVector ) ) {
200
194
return ( x, { ( $0, $0) } )
201
195
}
202
196
}
@@ -205,43 +199,38 @@ extension InstanceMethod {
205
199
// expected-error @+2 {{function result's 'pullback' type does not match 'bar'}}
206
200
// expected-note @+2 {{'pullback' does not have expected type '(Self.CotangentVector) -> (Self.CotangentVector, T.CotangentVector)'}}
207
201
@differentiating ( bar)
208
- func vjpBar< T : Differentiable > ( _ x: T ) -> ( value: Self , pullback: ( Self . CotangentVector ) -> T . CotangentVector ) {
202
+ func vjpBar< T : Differentiable > ( _ x: T ) -> ( value: Self , pullback: ( CotangentVector ) -> T . CotangentVector ) {
209
203
return ( self , { _ in . zero } )
210
204
}
211
205
212
206
@differentiating ( bar)
213
- func vjpBar< T : Differentiable > ( _ x: T ) -> ( value: Self , pullback: ( Self . CotangentVector ) -> ( Self . CotangentVector , T . CotangentVector ) ) {
207
+ func vjpBar< T : Differentiable > ( _ x: T ) -> ( value: Self , pullback: ( CotangentVector ) -> ( CotangentVector , T . CotangentVector ) ) {
214
208
return ( self , { ( $0, . zero) } )
215
209
}
216
210
217
- @differentiating ( bar)
218
- func jvpBar< T : Differentiable > ( _ x: T ) -> ( value: Self , differential: ( Self . TangentVector , T . TangentVector ) -> Self . TangentVector ) {
219
- return ( self , { dself, dx in dself } )
220
- }
221
-
222
211
@differentiating ( bar, wrt: ( self , x) )
223
- func jvpBarWrt< T : Differentiable > ( _ x: T ) -> ( value: Self , differential: ( Self . TangentVector , T . TangentVector ) -> Self . TangentVector ) {
212
+ func jvpBarWrt< T : Differentiable > ( _ x: T ) -> ( value: Self , differential: ( TangentVector , T . TangentVector ) -> TangentVector ) {
224
213
return ( self , { dself, dx in dself } )
225
214
}
226
215
}
227
216
228
217
extension InstanceMethod where Self == Self . TangentVector , Self == Self . CotangentVector {
229
- @differentiating ( foo )
218
+ @differentiating ( foo2 )
230
219
func vjpFooExtraRequirements( x: Self ) -> ( value: Self , pullback: ( Self ) -> ( Self , Self ) ) {
231
220
return ( x, { ( $0, $0) } )
232
221
}
233
222
234
- @differentiating ( foo )
223
+ @differentiating ( foo2 )
235
224
func jvpFooExtraRequirements( x: Self ) -> ( value: Self , differential: ( Self , Self ) -> ( Self ) ) {
236
225
return ( x, { $0 + $1 } )
237
226
}
238
227
239
- @differentiating ( bar )
228
+ @differentiating ( bar2 )
240
229
func vjpBarExtraRequirements< T : Differentiable > ( x: T ) -> ( value: Self , pullback: ( Self ) -> ( Self , T . CotangentVector ) ) {
241
230
return ( self , { ( $0, . zero) } )
242
231
}
243
232
244
- @differentiating ( bar )
233
+ @differentiating ( bar2 )
245
234
func jvpBarExtraRequirements< T : Differentiable > ( _ x: T ) -> ( value: Self , differential: ( Self , T . TangentVector ) -> Self ) {
246
235
return ( self , { dself, dx in dself } )
247
236
}
0 commit comments