Skip to content

[AutoDiff] Run AutoDiff closure spec pass for all VJPs #81548

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 7 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view

Large diffs are not rendered by default.

Original file line number Diff line number Diff line change
Expand Up @@ -160,7 +160,7 @@ public func registerOptimizerTests() {
enclosingValuesTest,
forwardingDefUseTest,
forwardingUseDefTest,
gatherCallSitesTest,
getPullbackClosureInfoTest,
interiorLivenessTest,
lifetimeDependenceRootTest,
lifetimeDependenceScopeTest,
Expand Down
509 changes: 0 additions & 509 deletions test/AutoDiff/SILOptimizer/closure_specialization.sil

This file was deleted.

689 changes: 689 additions & 0 deletions test/AutoDiff/SILOptimizer/closure_specialization/multi_bb_bte.sil

Large diffs are not rendered by default.

181 changes: 181 additions & 0 deletions test/AutoDiff/SILOptimizer/closure_specialization/multi_bb_no_bte1.sil
Original file line number Diff line number Diff line change
@@ -0,0 +1,181 @@
/// Multi basic block VJP, pullback not accepting branch tracing enum argument.

// RUN: %target-sil-opt -sil-print-types -test-runner %s -o /dev/null 2>&1 | %FileCheck %s --check-prefixes=TRUNNER,CHECK
// RUN: %target-sil-opt -sil-print-types -autodiff-closure-specialization -sil-combine %s -o - | %FileCheck %s --check-prefixes=COMBINE,CHECK

// REQUIRES: swift_in_compiler

sil_stage canonical

import Builtin
import Swift
import SwiftShims

import _Differentiation

/// This SIL corresponds to the following Swift:
///
/// struct Class: Differentiable {
/// var stored: Float
/// var optional: Float?
///
/// init(stored: Float, optional: Float?) {
/// self.stored = stored
/// self.optional = optional
/// }
///
/// @differentiable(reverse)
/// func method() -> Float {
/// let c: Class
/// do {
/// let tmp = Class(stored: 1 * stored, optional: optional)
/// let tuple = (tmp, tmp)
/// c = tuple.0
/// }
/// if let x = c.optional {
/// return x * c.stored
/// }
/// return 1 * c.stored
/// }
/// }
///
/// @differentiable(reverse)
/// func methodWrapper(_ x: Class) -> Float {
/// x.method()
/// }

struct Class : Differentiable {
@_hasStorage var stored: Float { get set }
@_hasStorage @_hasInitialValue var optional: Float? { get set }
init(stored: Float, optional: Float?)
@differentiable(reverse, wrt: self)
func method() -> Float
struct TangentVector : AdditiveArithmetic, Differentiable {
@_hasStorage var stored: Float { get set }
@_hasStorage var optional: Optional<Float>.TangentVector { get set }
static func + (lhs: Class.TangentVector, rhs: Class.TangentVector) -> Class.TangentVector
static func - (lhs: Class.TangentVector, rhs: Class.TangentVector) -> Class.TangentVector
typealias TangentVector = Class.TangentVector
@_implements(Equatable, ==(_:_:)) static func __derived_struct_equals(_ a: Class.TangentVector, _ b: Class.TangentVector) -> Bool
init(stored: Float, optional: Optional<Float>.TangentVector)
static var zero: Class.TangentVector { get }
}
mutating func move(by offset: Class.TangentVector)
}

enum _AD__$s4test5ClassV6methodSfyF_bb1__Pred__src_0_wrt_0 {
case bb0(((Float) -> Float, (Class.TangentVector) -> (Float, Optional<Float>.TangentVector)))
}

enum _AD__$s4test5ClassV6methodSfyF_bb2__Pred__src_0_wrt_0 {
case bb0(((Float) -> Float, (Class.TangentVector) -> (Float, Optional<Float>.TangentVector)))
}

enum _AD__$s4test5ClassV6methodSfyF_bb3__Pred__src_0_wrt_0 {
case bb2((predecessor: _AD__$s4test5ClassV6methodSfyF_bb2__Pred__src_0_wrt_0, (Float) -> Float))
case bb1((predecessor: _AD__$s4test5ClassV6methodSfyF_bb1__Pred__src_0_wrt_0, (Float) -> (Float, Float)))
}

sil @$sSf16_DifferentiationE12_vjpMultiply3lhs3rhsSf5value_Sf_SftSfc8pullbacktSf_SftFZSf_SftSfcfU_ : $@convention(thin) (Float, Float, Float) -> (Float, Float)
sil [transparent] [thunk] @$sS3fIegydd_TJSpSSUpSrUSUP : $@convention(thin) (Float, @guaranteed @callee_guaranteed (Float) -> (Float, Float)) -> Float
sil @$s4test5ClassV6stored8optionalACSf_SfSgtcfCTJpSSUpSr : $@convention(thin) (Class.TangentVector) -> (Float, Optional<Float>.TangentVector)
sil @$s4test5ClassV6methodSfyFTJpSpSr : $@convention(thin) (Float, @owned _AD__$s4test5ClassV6methodSfyF_bb3__Pred__src_0_wrt_0) -> Class.TangentVector

// pullback of methodWrapper(_:)
sil private [signature_optimized_thunk] [always_inline] @$s4test13methodWrapperySfAA5ClassVFTJpSpSr : $@convention(thin) (Float, @owned @callee_guaranteed (Float) -> Class.TangentVector) -> Class.TangentVector {
bb0(%0 : $Float, %1 : $@callee_guaranteed (Float) -> Class.TangentVector):
%2 = apply %1(%0) : $@callee_guaranteed (Float) -> Class.TangentVector
strong_release %1
return %2
} // end sil function '$s4test13methodWrapperySfAA5ClassVFTJpSpSr'

// reverse-mode derivative of methodWrapper(_:)
sil hidden @$s4test13methodWrapperySfAA5ClassVFTJrSpSr : $@convention(thin) (Class) -> (Float, @owned @callee_guaranteed (Float) -> Class.TangentVector) {
bb0(%0 : $Class):
//=========== Test callsite and closure gathering logic ===========//
specify_test "autodiff_closure_specialize_get_pullback_closure_info"
// TRUNNER-LABEL: Specializing closures in function: $s4test13methodWrapperySfAA5ClassVFTJrSpSr
// TRUNNER: PartialApply of pullback: %[[#]] = partial_apply [callee_guaranteed] %[[#]](%[[#A36:]]) : $@convention(thin) (Float, @owned @callee_guaranteed (Float) -> Class.TangentVector) -> Class.TangentVector
// TRUNNER-NEXT: Passed in closures:
// TRUNNER-NEXT: 1. %[[#A36]] = partial_apply [callee_guaranteed] %[[#]](%[[#]]) : $@convention(thin) (Float, @owned _AD__$s4test5ClassV6methodSfyF_bb3__Pred__src_0_wrt_0) -> Class.TangentVector
// TRUNNER-EMPTY:

//=========== Test specialized function signature and body ===========//
specify_test "autodiff_closure_specialize_specialized_function_signature_and_body"
// TRUNNER-LABEL: Generated specialized function: $s4test13methodWrapperySfAA5ClassVFTJpSpSr08$s4test5D19V6methodSfyFTJpSpSr4main05_AD__edfG24F_bb3__Pred__src_0_wrt_0OTf1nc_n
// CHECK: sil private [signature_optimized_thunk] [always_inline] @$s4test13methodWrapperySfAA5ClassVFTJpSpSr08$s4test5D19V6methodSfyFTJpSpSr4main05_AD__edfG24F_bb3__Pred__src_0_wrt_0OTf1nc_n : $@convention(thin) (Float, @owned _AD__$s4test5ClassV6methodSfyF_bb3__Pred__src_0_wrt_0) -> Class.TangentVector {
// CHECK: bb0(%0 : $Float, %1 : $_AD__$s4test5ClassV6methodSfyF_bb3__Pred__src_0_wrt_0):
// CHECK: %[[#B2:]] = function_ref @$s4test5ClassV6methodSfyFTJpSpSr : $@convention(thin) (Float, @owned _AD__$s4test5ClassV6methodSfyF_bb3__Pred__src_0_wrt_0) -> Class.TangentVector
// TRUNNER: %[[#B3:]] = partial_apply [callee_guaranteed] %[[#B2]](%1) : $@convention(thin) (Float, @owned _AD__$s4test5ClassV6methodSfyF_bb3__Pred__src_0_wrt_0) -> Class.TangentVector
// TRUNNER: %[[#B4:]] = apply %[[#B3]](%0) : $@callee_guaranteed (Float) -> Class.TangentVector
// COMBINE-NOT: partial_apply
// COMBINE: %[[#B4:]] = apply %[[#B2]](%0, %1) : $@convention(thin) (Float, @owned _AD__$s4test5ClassV6methodSfyF_bb3__Pred__src_0_wrt_0) -> Class.TangentVector
// TRUNNER: strong_release %[[#B3]] : $@callee_guaranteed (Float) -> Class.TangentVector
// CHECK: return %[[#B4]]

//=========== Test rewritten body ===========//
specify_test "autodiff_closure_specialize_rewritten_caller_body"
// TRUNNER-LABEL: Rewritten caller body for: $s4test13methodWrapperySfAA5ClassVFTJrSpSr:
// CHECK: sil hidden @$s4test13methodWrapperySfAA5ClassVFTJrSpSr : $@convention(thin) (Class) -> (Float, @owned @callee_guaranteed (Float) -> Class.TangentVector) {
// CHECK: bb3(%[[#C33:]] : $Float, %[[#C34:]] : $_AD__$s4test5ClassV6methodSfyF_bb3__Pred__src_0_wrt_0):
// TRUNNER: %[[#C35:]] = function_ref @$s4test5ClassV6methodSfyFTJpSpSr : $@convention(thin) (Float, @owned _AD__$s4test5ClassV6methodSfyF_bb3__Pred__src_0_wrt_0) -> Class.TangentVector
// TRUNNER: %[[#C37:]] = partial_apply [callee_guaranteed] %[[#C35]](%[[#C34]]) : $@convention(thin) (Float, @owned _AD__$s4test5ClassV6methodSfyF_bb3__Pred__src_0_wrt_0) -> Class.TangentVector
// TRUNNER: %[[#C38:]] = function_ref @$s4test13methodWrapperySfAA5ClassVFTJpSpSr : $@convention(thin) (Float, @owned @callee_guaranteed (Float) -> Class.TangentVector) -> Class.TangentVector
// COMBINE-NOT: function_ref @$s4test5ClassV6methodSfyFTJpSpSr
// COMBINE-NOT: partial_apply
// COMBINE-NOT: function_ref @$s4test13methodWrapperySfAA5ClassVFTJpSpSr
// CHECK: %[[#C39:]] = function_ref @$s4test13methodWrapperySfAA5ClassVFTJpSpSr08$s4test5D19V6methodSfyFTJpSpSr4main05_AD__edfG24F_bb3__Pred__src_0_wrt_0OTf1nc_n : $@convention(thin) (Float, @owned _AD__$s4test5ClassV6methodSfyF_bb3__Pred__src_0_wrt_0) -> Class.TangentVector
// CHECK: %[[#C40:]] = partial_apply [callee_guaranteed] %[[#C39]](%[[#C34]]) : $@convention(thin) (Float, @owned _AD__$s4test5ClassV6methodSfyF_bb3__Pred__src_0_wrt_0) -> Class.TangentVector
// TRUNNER: release_value %[[#C37]] : $@callee_guaranteed (Float) -> Class.TangentVector
// CHECK: %[[#C42:]] = tuple (%[[#C33]] : $Float, %[[#C40]] : $@callee_guaranteed (Float) -> Class.TangentVector)
// CHECK: return %[[#C42]]

%3 = float_literal $Builtin.FPIEEE32, 0x3F800000 // 1
%4 = struct $Float (%3)
%5 = struct_extract %0, #Class.stored
%6 = struct_extract %5, #Float._value
%7 = builtin "fmul_FPIEEE32"(%3, %6) : $Builtin.FPIEEE32
%8 = struct $Float (%7)
// function_ref closure #1 in static Float._vjpMultiply(lhs:rhs:)
%9 = function_ref @$sSf16_DifferentiationE12_vjpMultiply3lhs3rhsSf5value_Sf_SftSfc8pullbacktSf_SftFZSf_SftSfcfU_ : $@convention(thin) (Float, Float, Float) -> (Float, Float)
%10 = partial_apply [callee_guaranteed] %9(%5, %4) : $@convention(thin) (Float, Float, Float) -> (Float, Float)
// function_ref autodiff subset parameters thunk for pullback from @escaping @callee_guaranteed (@unowned Float) -> (@unowned Float, @unowned Float)
%11 = function_ref @$sS3fIegydd_TJSpSSUpSrUSUP : $@convention(thin) (Float, @guaranteed @callee_guaranteed (Float) -> (Float, Float)) -> Float
%12 = partial_apply [callee_guaranteed] %11(%10) : $@convention(thin) (Float, @guaranteed @callee_guaranteed (Float) -> (Float, Float)) -> Float
%13 = struct_extract %0, #Class.optional
// function_ref pullback of Class.init(stored:optional:)
%26 = function_ref @$s4test5ClassV6stored8optionalACSf_SfSgtcfCTJpSSUpSr : $@convention(thin) (Class.TangentVector) -> (Float, Optional<Float>.TangentVector)
%27 = thin_to_thick_function %26 to $@callee_guaranteed (Class.TangentVector) -> (Float, Optional<Float>.TangentVector)
%28 = tuple (%12, %27)
switch_enum %13, case #Optional.some!enumelt: bb1, case #Optional.none!enumelt: bb2

bb1(%30 : $Float):
%31 = enum $_AD__$s4test5ClassV6methodSfyF_bb1__Pred__src_0_wrt_0, #_AD__$s4test5ClassV6methodSfyF_bb1__Pred__src_0_wrt_0.bb0!enumelt, %28
%33 = struct_extract %30, #Float._value
%34 = builtin "fmul_FPIEEE32"(%33, %7) : $Builtin.FPIEEE32
%35 = struct $Float (%34)
%36 = partial_apply [callee_guaranteed] %9(%8, %30) : $@convention(thin) (Float, Float, Float) -> (Float, Float)
%37 = tuple $(predecessor: _AD__$s4test5ClassV6methodSfyF_bb1__Pred__src_0_wrt_0, @callee_guaranteed (Float) -> (Float, Float)) (%31, %36)
%38 = enum $_AD__$s4test5ClassV6methodSfyF_bb3__Pred__src_0_wrt_0, #_AD__$s4test5ClassV6methodSfyF_bb3__Pred__src_0_wrt_0.bb1!enumelt, %37
br bb3(%35, %38)

bb2:
%40 = enum $_AD__$s4test5ClassV6methodSfyF_bb2__Pred__src_0_wrt_0, #_AD__$s4test5ClassV6methodSfyF_bb2__Pred__src_0_wrt_0.bb0!enumelt, %28
%41 = builtin "fmul_FPIEEE32"(%3, %7) : $Builtin.FPIEEE32
%42 = struct $Float (%41)
%43 = partial_apply [callee_guaranteed] %9(%8, %4) : $@convention(thin) (Float, Float, Float) -> (Float, Float)
%44 = partial_apply [callee_guaranteed] %11(%43) : $@convention(thin) (Float, @guaranteed @callee_guaranteed (Float) -> (Float, Float)) -> Float
%45 = tuple $(predecessor: _AD__$s4test5ClassV6methodSfyF_bb2__Pred__src_0_wrt_0, @callee_guaranteed (Float) -> Float) (%40, %44)
%46 = enum $_AD__$s4test5ClassV6methodSfyF_bb3__Pred__src_0_wrt_0, #_AD__$s4test5ClassV6methodSfyF_bb3__Pred__src_0_wrt_0.bb2!enumelt, %45
br bb3(%42, %46)

bb3(%48 : $Float, %49 : $_AD__$s4test5ClassV6methodSfyF_bb3__Pred__src_0_wrt_0):
// function_ref pullback of Class.method()
%50 = function_ref @$s4test5ClassV6methodSfyFTJpSpSr : $@convention(thin) (Float, @owned _AD__$s4test5ClassV6methodSfyF_bb3__Pred__src_0_wrt_0) -> Class.TangentVector
%51 = partial_apply [callee_guaranteed] %50(%49) : $@convention(thin) (Float, @owned _AD__$s4test5ClassV6methodSfyF_bb3__Pred__src_0_wrt_0) -> Class.TangentVector
// function_ref pullback of methodWrapper(_:)
%52 = function_ref @$s4test13methodWrapperySfAA5ClassVFTJpSpSr : $@convention(thin) (Float, @owned @callee_guaranteed (Float) -> Class.TangentVector) -> Class.TangentVector
%53 = partial_apply [callee_guaranteed] %52(%51) : $@convention(thin) (Float, @owned @callee_guaranteed (Float) -> Class.TangentVector) -> Class.TangentVector
%54 = tuple (%48, %53)
return %54
} // end sil function '$s4test13methodWrapperySfAA5ClassVFTJrSpSr'
Loading