Skip to content

Commit 0a9aa70

Browse files
authored
[AutoDiff] Fix @differentiating derivative registration. (swiftlang#23418)
When type-checking `@differentiating` attribute and registering a derivative in a `@differentiable` attribute on the original function, emit an error if the `@differentiable` attribute already has either a JVP/VJP name or function. Add test, update affected existing tests.
1 parent c8ba90a commit 0a9aa70

File tree

2 files changed

+27
-35
lines changed

2 files changed

+27
-35
lines changed

lib/Sema/TypeCheckAttr.cpp

Lines changed: 9 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -2802,7 +2802,7 @@ void AttributeChecker::visitDifferentiableAttr(DifferentiableAttr *attr) {
28022802

28032803
auto insertion =
28042804
ctx.DifferentiableAttrs.try_emplace({D, checkedWrtParamIndices}, attr);
2805-
// Differentiable attributes are uniqued by their parameter indices.
2805+
// `@differentiable` attributes are uniqued by their parameter indices.
28062806
// Reject duplicate attributes for the same decl and parameter indices pair.
28072807
if (!insertion.second && insertion.first->getSecond() != attr) {
28082808
diagnoseAndRemoveAttr(attr, diag::differentiable_attr_duplicate);
@@ -3249,29 +3249,32 @@ void AttributeChecker::visitDifferentiatingAttr(DifferentiatingAttr *attr) {
32493249
None, None, derivativeRequirements);
32503250
auto insertion = ctx.DifferentiableAttrs.try_emplace(
32513251
{originalFn, checkedWrtParamIndices}, da);
3252-
// Differentiable attributes are uniqued by their parameter indices.
3252+
// `@differentiable` attributes are uniqued by their parameter indices.
32533253
// Reject duplicate attributes for the same decl and parameter indices pair.
32543254
if (!insertion.second && insertion.first->getSecond() != da) {
32553255
diagnoseAndRemoveAttr(da, diag::differentiable_attr_duplicate);
32563256
return;
32573257
}
32583258
originalFn->getAttrs().add(da);
32593259
}
3260+
// Check if the `@differentiable` attribute already has a registered
3261+
// derivative. If so, emit an error on the `@differentiating` attribute.
3262+
// Otherwise, register the derivative in the `@differentiable` attribute.
32603263
switch (kind) {
32613264
case AutoDiffAssociatedFunctionKind::JVP:
3262-
if (auto jvp = da->getJVP()) {
3265+
if (da->getJVP() || da->getJVPFunction()) {
32633266
diagnoseAndRemoveAttr(
32643267
attr, diag::differentiating_attr_original_already_has_derivative,
3265-
jvp->Name);
3268+
originalFn->getFullName());
32663269
return;
32673270
}
32683271
da->setJVPFunction(derivative);
32693272
break;
32703273
case AutoDiffAssociatedFunctionKind::VJP:
3271-
if (auto vjp = da->getVJP()) {
3274+
if (da->getVJP() || da->getVJPFunction()) {
32723275
diagnoseAndRemoveAttr(
32733276
attr, diag::differentiating_attr_original_already_has_derivative,
3274-
vjp->Name);
3277+
originalFn->getFullName());
32753278
return;
32763279
}
32773280
da->setVJPFunction(derivative);

test/AutoDiff/differentiating_attr_type_checking.swift

Lines changed: 18 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -6,20 +6,20 @@
66
func sin(_ x: Float) -> Float {
77
return x // dummy implementation
88
}
9-
109
@differentiating(sin) // ok
11-
func vjpSin(x: Float) -> (value: Float, pullback: (Float) -> Float) {
10+
func jvpSin(x: @nondiff Float) -> (value: Float, differential: (Float) -> (Float)) {
1211
return (x, { $0 })
1312
}
1413
@differentiating(sin, wrt: x) // ok
1514
func vjpSinExplicitWrt(x: Float) -> (value: Float, pullback: (Float) -> Float) {
1615
return (x, { $0 })
1716
}
18-
@differentiating(sin) // ok
19-
func jvpSin(x: @nondiff Float) -> (value: Float, differential: (Float) -> (Float)) {
17+
18+
// expected-error @+1 {{a derivative already exists for 'sin'}}
19+
@differentiating(sin)
20+
func vjpDuplicate(x: Float) -> (value: Float, pullback: (Float) -> Float) {
2021
return (x, { $0 })
2122
}
22-
2323
// expected-error @+1 {{'@differentiating' attribute requires function to return a two-element tuple of type '(value: T..., pullback: (U.CotangentVector) -> T.CotangentVector...)' or '(value: T..., differential: (T.TangentVector...) -> U.TangentVector)'}}
2424
@differentiating(sin)
2525
func jvpSinResultInvalid(x: @nondiff Float) -> Float {
@@ -46,10 +46,6 @@ func generic<T : Differentiable>(_ x: T, _ y: T) -> T {
4646
return x
4747
}
4848
@differentiating(generic) // ok
49-
func vjpGeneric<T : Differentiable>(x: T, y: T) -> (value: T, pullback: (T.CotangentVector) -> (T.CotangentVector, T.CotangentVector)) {
50-
return (x, { ($0, $0) })
51-
}
52-
@differentiating(generic) // ok
5349
func jvpGeneric<T : Differentiable>(x: T, y: T) -> (value: T, differential: (T.TangentVector, T.TangentVector) -> T.TangentVector) {
5450
return (x, { $0 + $1 })
5551
}
@@ -130,10 +126,6 @@ func vjpFoo<T : AdditiveArithmetic & Differentiable>(_ x: T) -> (value: T, pullb
130126
return (x, { $0 })
131127
}
132128
@differentiating(foo)
133-
func vjpFoo<T : FloatingPoint & Differentiable>(_ x: T) -> (value: T, pullback: (T.CotangentVector) -> (T.CotangentVector)) {
134-
return (x, { $0 })
135-
}
136-
@differentiating(foo)
137129
func vjpFooExtraGenericRequirements<T : FloatingPoint & Differentiable & BinaryInteger>(_ x: T) -> (value: T, pullback: (T) -> (T)) where T == T.CotangentVector {
138130
return (x, { $0 })
139131
}
@@ -177,26 +169,28 @@ extension AdditiveArithmetic where Self : Differentiable, Self == Self.Cotangent
177169
protocol InstanceMethod : Differentiable {
178170
// expected-note @+1 {{'foo' defined here}}
179171
func foo(_ x: Self) -> Self
172+
func foo2(_ x: Self) -> Self
180173
// expected-note @+1 {{'bar' defined here}}
181174
func bar<T : Differentiable>(_ x: T) -> Self
175+
func bar2<T : Differentiable>(_ x: T) -> Self
182176
}
183177

184178
extension InstanceMethod {
185179
// If `Self` conforms to `Differentiable`, then `Self` is currently always inferred to be a differentiation parameter.
186180
// expected-error @+2 {{function result's 'pullback' type does not match 'foo'}}
187181
// expected-note @+2 {{'pullback' does not have expected type '(Self.CotangentVector) -> (Self.CotangentVector, Self.CotangentVector)'}}
188182
@differentiating(foo)
189-
func vjpFoo(x: Self) -> (value: Self, pullback: (Self.CotangentVector) -> Self.CotangentVector) {
183+
func vjpFoo(x: Self) -> (value: Self, pullback: (CotangentVector) -> CotangentVector) {
190184
return (x, { $0 })
191185
}
192186

193187
@differentiating(foo)
194-
func vjpFoo(x: Self) -> (value: Self, pullback: (Self.CotangentVector) -> (Self.CotangentVector, Self.CotangentVector)) {
195-
return (x, { ($0, $0) })
188+
func jvpFoo(x: Self) -> (value: Self, differential: (TangentVector, TangentVector) -> (TangentVector)) {
189+
return (x, { $0 + $1 })
196190
}
197191

198192
@differentiating(foo, wrt: (self, x))
199-
func vjpFooWrt(x: Self) -> (value: Self, pullback: (Self.CotangentVector) -> (Self.CotangentVector, Self.CotangentVector)) {
193+
func vjpFooWrt(x: Self) -> (value: Self, pullback: (CotangentVector) -> (CotangentVector, CotangentVector)) {
200194
return (x, { ($0, $0) })
201195
}
202196
}
@@ -205,43 +199,38 @@ extension InstanceMethod {
205199
// expected-error @+2 {{function result's 'pullback' type does not match 'bar'}}
206200
// expected-note @+2 {{'pullback' does not have expected type '(Self.CotangentVector) -> (Self.CotangentVector, T.CotangentVector)'}}
207201
@differentiating(bar)
208-
func vjpBar<T : Differentiable>(_ x: T) -> (value: Self, pullback: (Self.CotangentVector) -> T.CotangentVector) {
202+
func vjpBar<T : Differentiable>(_ x: T) -> (value: Self, pullback: (CotangentVector) -> T.CotangentVector) {
209203
return (self, { _ in .zero })
210204
}
211205

212206
@differentiating(bar)
213-
func vjpBar<T : Differentiable>(_ x: T) -> (value: Self, pullback: (Self.CotangentVector) -> (Self.CotangentVector, T.CotangentVector)) {
207+
func vjpBar<T : Differentiable>(_ x: T) -> (value: Self, pullback: (CotangentVector) -> (CotangentVector, T.CotangentVector)) {
214208
return (self, { ($0, .zero) })
215209
}
216210

217-
@differentiating(bar)
218-
func jvpBar<T : Differentiable>(_ x: T) -> (value: Self, differential: (Self.TangentVector, T.TangentVector) -> Self.TangentVector) {
219-
return (self, { dself, dx in dself })
220-
}
221-
222211
@differentiating(bar, wrt: (self, x))
223-
func jvpBarWrt<T : Differentiable>(_ x: T) -> (value: Self, differential: (Self.TangentVector, T.TangentVector) -> Self.TangentVector) {
212+
func jvpBarWrt<T : Differentiable>(_ x: T) -> (value: Self, differential: (TangentVector, T.TangentVector) -> TangentVector) {
224213
return (self, { dself, dx in dself })
225214
}
226215
}
227216

228217
extension InstanceMethod where Self == Self.TangentVector, Self == Self.CotangentVector {
229-
@differentiating(foo)
218+
@differentiating(foo2)
230219
func vjpFooExtraRequirements(x: Self) -> (value: Self, pullback: (Self) -> (Self, Self)) {
231220
return (x, { ($0, $0) })
232221
}
233222

234-
@differentiating(foo)
223+
@differentiating(foo2)
235224
func jvpFooExtraRequirements(x: Self) -> (value: Self, differential: (Self, Self) -> (Self)) {
236225
return (x, { $0 + $1 })
237226
}
238227

239-
@differentiating(bar)
228+
@differentiating(bar2)
240229
func vjpBarExtraRequirements<T : Differentiable>(x: T) -> (value: Self, pullback: (Self) -> (Self, T.CotangentVector)) {
241230
return (self, { ($0, .zero) })
242231
}
243232

244-
@differentiating(bar)
233+
@differentiating(bar2)
245234
func jvpBarExtraRequirements<T : Differentiable>(_ x: T) -> (value: Self, differential: (Self, T.TangentVector) -> Self) {
246235
return (self, { dself, dx in dself })
247236
}

0 commit comments

Comments
 (0)