Skip to content

[AutoDiff] Run AutoDiff closure spec pass for all VJPs #81548

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 6 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -122,8 +122,7 @@ let autodiffClosureSpecialization = FunctionPass(name: "autodiff-closure-special
(function: Function, context: FunctionPassContext) in

guard !function.isDefinedExternally,
function.isAutodiffVJP,
function.blocks.singleElement != nil else {
function.isAutodiffVJP else {
return
}

Expand All @@ -132,26 +131,24 @@ let autodiffClosureSpecialization = FunctionPass(name: "autodiff-closure-special
repeat {
// TODO: Names here are pretty misleading. We are looking for a place where
// the pullback closure is created (so for `partial_apply` instruction).
var callSites = gatherCallSites(in: function, context)
guard !callSites.isEmpty else {
return
let callSiteOpt = gatherCallSite(in: function, context)
if callSiteOpt == nil {
break
}

for callSite in callSites {
var (specializedFunction, alreadyExists) = getOrCreateSpecializedFunction(basedOn: callSite, context)
let callSite = callSiteOpt!
var (specializedFunction, alreadyExists) = getOrCreateSpecializedFunction(basedOn: callSite, context)

if !alreadyExists {
context.notifyNewFunction(function: specializedFunction, derivedFrom: callSite.applyCallee)
}

rewriteApplyInstruction(using: specializedFunction, callSite: callSite, context)
if !alreadyExists {
context.notifyNewFunction(function: specializedFunction, derivedFrom: callSite.applyCallee)
}

var deadClosures: InstructionWorklist = callSites.reduce(into: InstructionWorklist(context)) { deadClosures, callSite in
callSite.closureArgDescriptors
rewriteApplyInstruction(using: specializedFunction, callSite: callSite, context)

var deadClosures = InstructionWorklist(context)
callSite.closureArgDescriptors
.map { $0.closure }
.forEach { deadClosures.pushIfNotVisited($0) }
}

defer {
deadClosures.deinitialize()
Expand All @@ -176,7 +173,7 @@ let autodiffClosureSpecialization = FunctionPass(name: "autodiff-closure-special

private let specializationLevelLimit = 2

private func gatherCallSites(in caller: Function, _ context: FunctionPassContext) -> [CallSite] {
private func gatherCallSite(in caller: Function, _ context: FunctionPassContext) -> CallSite? {
/// __Root__ closures created via `partial_apply` or `thin_to_thick_function` may be converted and reabstracted
/// before finally being used at an apply site. We do not want to handle these intermediate closures separately
/// as they are handled and cloned into the specialized function as part of the root closures. Therefore, we keep
Expand Down Expand Up @@ -208,18 +205,18 @@ private func gatherCallSites(in caller: Function, _ context: FunctionPassContext
convertedAndReabstractedClosures.deinitialize()
}

var callSiteMap = CallSiteMap()
var callSiteOpt = CallSite?(nil)

for inst in caller.instructions {
if !convertedAndReabstractedClosures.contains(inst),
let rootClosure = inst.asSupportedClosure
{
updateCallSites(for: rootClosure, in: &callSiteMap,
convertedAndReabstractedClosures: &convertedAndReabstractedClosures, context)
updateCallSite(for: rootClosure, in: &callSiteOpt,
convertedAndReabstractedClosures: &convertedAndReabstractedClosures, context)
}
}

return callSiteMap.callSites
return callSiteOpt
}

private func getOrCreateSpecializedFunction(basedOn callSite: CallSite, _ context: FunctionPassContext)
Expand Down Expand Up @@ -307,8 +304,8 @@ private func rewriteApplyInstruction(using specializedCallee: Function, callSite

// ===================== Utility functions and extensions ===================== //

private func updateCallSites(for rootClosure: SingleValueInstruction, in callSiteMap: inout CallSiteMap,
convertedAndReabstractedClosures: inout InstructionSet, _ context: FunctionPassContext) {
private func updateCallSite(for rootClosure: SingleValueInstruction, in callSiteOpt: inout CallSite?,
convertedAndReabstractedClosures: inout InstructionSet, _ context: FunctionPassContext) {
var rootClosurePossibleLiveRange = InstructionRange(begin: rootClosure, context)
defer {
rootClosurePossibleLiveRange.deinitialize()
Expand Down Expand Up @@ -343,14 +340,18 @@ private func updateCallSites(for rootClosure: SingleValueInstruction, in callSit
}

let intermediateClosureArgDescriptorData =
handleApplies(for: rootClosure, callSiteMap: &callSiteMap, rootClosureApplies: &rootClosureApplies,
handleApplies(for: rootClosure, callSiteOpt: &callSiteOpt, rootClosureApplies: &rootClosureApplies,
rootClosurePossibleLiveRange: &rootClosurePossibleLiveRange,
convertedAndReabstractedClosures: &convertedAndReabstractedClosures,
haveUsedReabstraction: haveUsedReabstraction, context)

finalizeCallSites(for: rootClosure, in: &callSiteMap,
rootClosurePossibleLiveRange: rootClosurePossibleLiveRange,
intermediateClosureArgDescriptorData: intermediateClosureArgDescriptorData, context)
if callSiteOpt == nil {
return
}

finalizeCallSite(for: rootClosure, in: &callSiteOpt,
rootClosurePossibleLiveRange: rootClosurePossibleLiveRange,
intermediateClosureArgDescriptorData: intermediateClosureArgDescriptorData, context)
}

/// Handles all non-apply direct and transitive uses of `rootClosure`.
Expand Down Expand Up @@ -479,7 +480,7 @@ private func handleNonApplies(for rootClosure: SingleValueInstruction,

private typealias IntermediateClosureArgDescriptorDatum = (applySite: SingleValueInstruction, closureArgIndex: Int, paramInfo: ParameterInfo)

private func handleApplies(for rootClosure: SingleValueInstruction, callSiteMap: inout CallSiteMap,
private func handleApplies(for rootClosure: SingleValueInstruction, callSiteOpt: inout CallSite?,
rootClosureApplies: inout OperandWorklist,
rootClosurePossibleLiveRange: inout InstructionRange,
convertedAndReabstractedClosures: inout InstructionSet, haveUsedReabstraction: Bool,
Expand Down Expand Up @@ -586,8 +587,10 @@ private func handleApplies(for rootClosure: SingleValueInstruction, callSiteMap:
convertedAndReabstractedClosures: &convertedAndReabstractedClosures)
}

if callSiteMap[pai] == nil {
callSiteMap.insert(key: pai, value: CallSite(applySite: pai))
if callSiteOpt == nil {
callSiteOpt = CallSite(applySite: pai)
} else {
assert(callSiteOpt!.applySite == pai)
}

intermediateClosureArgDescriptorData
Expand All @@ -599,21 +602,21 @@ private func handleApplies(for rootClosure: SingleValueInstruction, callSiteMap:

/// Finalizes the call sites for a given root closure by adding a corresponding `ClosureArgDescriptor`
/// to all call sites where the closure is ultimately passed as an argument.
private func finalizeCallSites(for rootClosure: SingleValueInstruction, in callSiteMap: inout CallSiteMap,
rootClosurePossibleLiveRange: InstructionRange,
intermediateClosureArgDescriptorData: [IntermediateClosureArgDescriptorDatum],
_ context: FunctionPassContext)
{
private func finalizeCallSite(for rootClosure: SingleValueInstruction, in callSiteOpt: inout CallSite?,
rootClosurePossibleLiveRange: InstructionRange,
intermediateClosureArgDescriptorData: [IntermediateClosureArgDescriptorDatum],
_ context: FunctionPassContext) {
assert(callSiteOpt != nil)

let closureInfo = ClosureInfo(closure: rootClosure, lifetimeFrontier: Array(rootClosurePossibleLiveRange.ends))

for (applySite, closureArgumentIndex, parameterInfo) in intermediateClosureArgDescriptorData {
guard var callSite = callSiteMap[applySite] else {
if callSiteOpt!.applySite != applySite {
fatalError("While finalizing call sites, call site descriptor not found for call site: \(applySite)!")
}
let closureArgDesc = ClosureArgDescriptor(closureInfo: closureInfo, closureArgumentIndex: closureArgumentIndex,
parameterInfo: parameterInfo)
callSite.appendClosureArgDescriptor(closureArgDesc)
callSiteMap.update(key: applySite, value: callSite)
callSiteOpt!.appendClosureArgDescriptor(closureArgDesc)
}
}

Expand Down Expand Up @@ -1202,14 +1205,6 @@ private struct OrderedDict<Key: Hashable, Value> {
}
}

private typealias CallSiteMap = OrderedDict<SingleValueInstruction, CallSite>

private extension CallSiteMap {
var callSites: [CallSite] {
Array(self.values)
}
}

/// Represents all the information required to represent a closure in isolation, i.e., outside of a callsite context
/// where the closure may be getting passed as an argument.
///
Expand Down Expand Up @@ -1339,42 +1334,35 @@ private struct CallSite {

// ===================== Unit tests ===================== //

let gatherCallSitesTest = FunctionTest("closure_specialize_gather_call_sites") { function, arguments, context in
let gatherCallSiteTest = FunctionTest("closure_specialize_gather_call_site") { function, arguments, context in
print("Specializing closures in function: \(function.name)")
print("===============================================")
var callSites = gatherCallSites(in: function, context)

callSites.forEach { callSite in
print("PartialApply call site: \(callSite.applySite)")
print("Passed in closures: ")
for index in callSite.closureArgDescriptors.indices {
var closureArgDescriptor = callSite.closureArgDescriptors[index]
print("\(index+1). \(closureArgDescriptor.closureInfo.closure)")
}
let callSite = gatherCallSite(in: function, context)!
print("PartialApply call site: \(callSite.applySite)")
print("Passed in closures: ")
for index in callSite.closureArgDescriptors.indices {
var closureArgDescriptor = callSite.closureArgDescriptors[index]
print("\(index+1). \(closureArgDescriptor.closureInfo.closure)")
}
print("\n")
}

let specializedFunctionSignatureAndBodyTest = FunctionTest(
"closure_specialize_specialized_function_signature_and_body") { function, arguments, context in

var callSites = gatherCallSites(in: function, context)
let callSite = gatherCallSite(in: function, context)!

for callSite in callSites {
let (specializedFunction, _) = getOrCreateSpecializedFunction(basedOn: callSite, context)
print("Generated specialized function: \(specializedFunction.name)")
print("\(specializedFunction)\n")
}
let (specializedFunction, _) = getOrCreateSpecializedFunction(basedOn: callSite, context)
print("Generated specialized function: \(specializedFunction.name)")
print("\(specializedFunction)\n")
}

let rewrittenCallerBodyTest = FunctionTest("closure_specialize_rewritten_caller_body") { function, arguments, context in
var callSites = gatherCallSites(in: function, context)
let callSite = gatherCallSite(in: function, context)!

for callSite in callSites {
let (specializedFunction, _) = getOrCreateSpecializedFunction(basedOn: callSite, context)
rewriteApplyInstruction(using: specializedFunction, callSite: callSite, context)
let (specializedFunction, _) = getOrCreateSpecializedFunction(basedOn: callSite, context)
rewriteApplyInstruction(using: specializedFunction, callSite: callSite, context)

print("Rewritten caller body for: \(function.name):")
print("\(function)\n")
}
print("Rewritten caller body for: \(function.name):")
print("\(function)\n")
}
Original file line number Diff line number Diff line change
Expand Up @@ -160,7 +160,7 @@ public func registerOptimizerTests() {
enclosingValuesTest,
forwardingDefUseTest,
forwardingUseDefTest,
gatherCallSitesTest,
gatherCallSiteTest,
interiorLivenessTest,
lifetimeDependenceRootTest,
lifetimeDependenceScopeTest,
Expand Down
Loading