[Auto-diff] Overhaul auto-diff type tracking + Overhaul dynamic dispatch for differentiable functions #5866

Open
wants to merge 20 commits into
base: master

saipraveenb25
Collaborator

@saipraveenb25 saipraveenb25 commented Dec 13, 2024

This is a large PR containing several overhauls and fixes:

Overhauled differentiable type lowering:

  • IRDifferentiableTypeDictionaryDecoration is deprecated in favor of IRDifferentiableTypeAnnotation instructions that can appear anywhere in the code (module/generic/function scopes) depending on the scope of the type. This makes it much simpler to lower differentiable run-time types.
  • IRDifferentiableTypeAnnotation(type, witness) instructions are generated by registering a 'type lowering hook' callback during IR lowering that checks the type against the list of differentiable types & generates an annotation for that type.
  • IRDifferentiableTypeAnnotation is hoistable, so it is emitted at the same parent as the type itself.
  • DifferentiableTypeConformanceContext, used by the auto-diff passes, scans the function being differentiated, as well as all parent scopes, to build the full list of differentiable types in the context of the function.
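The scope walk described in the last bullet can be pictured with a small stand-alone sketch. This is only an illustration of the idea, not the compiler's implementation: `Scope`, `collectDifferentiableTypes`, and the string-based type/witness names are hypothetical stand-ins, and the "innermost annotation wins" rule is an assumption made for the sketch.

```cpp
#include <map>
#include <string>
#include <utility>
#include <vector>

// Toy stand-in for an IR scope (module, generic, or function).
// NOT Slang's real IR classes.
struct Scope
{
    const Scope* parent = nullptr;
    // (type name, witness name) pairs, standing in for
    // IRDifferentiableTypeAnnotation(type, witness) insts in this scope.
    std::vector<std::pair<std::string, std::string>> annotations;
};

// Walk from the function scope outward through all parent scopes and
// record the witness for each annotated type.
// Assumption of this sketch: if a type is annotated more than once,
// the innermost annotation is kept.
std::map<std::string, std::string> collectDifferentiableTypes(const Scope& func)
{
    std::map<std::string, std::string> result;
    for (const Scope* s = &func; s != nullptr; s = s->parent)
    {
        for (const auto& annotation : s->annotations)
        {
            // emplace does nothing if the key already exists, so the
            // first (innermost) entry is the one that survives.
            result.emplace(annotation.first, annotation.second);
        }
    }
    return result;
}
```

Because annotations live next to the types they describe, a function nested in a generic inside a module picks up entries from all three scopes in one pass.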

Overhauled differentiation of dynamic types:

  • Previously, a limited version of dynamic-dispatch lowering was performed before auto-diff to make things simpler for auto-diff, but this meant that advanced uses (associated types) were not supported properly.
  • The overhauled system moves all dynamic-dispatch lowering to after the auto-diff passes.
  • Under the new system, the differential pair of an interface type IInterface is a pair of two interface types: IInterface & IDifferentiable, while the differential pairs of existential types & associated types are replaced with lookups of a synthesized differential pair type.
  • E.g. for a system like this:
interface IFoo : IDifferentiable
{
     associatedtype Bar : IDifferentiable;
     [Differentiable] float doThing(This, Bar);
};

struct A : IFoo 
{ 
     float x; 
     /* ... */
};

The derivative of doThing() is represented using associated types that are synthesized to represent the pairs.
In the above example, IFoo is automatically extended to add the new type requirements (in addition to the derivative method requirements, which are already an existing feature):

interface IFoo : IDifferentiable
{
     associatedtype Bar : IDifferentiable;

     // Synthesized type requirements.
     associatedtype DiffPair_This;
     associatedtype DiffPair_Bar;

     // Synthesized makePair requirements.
     DiffPair_This This_makePair(This, This.Differential);
     DiffPair_Bar Bar_makePair(Bar, Bar.Differential);

     // Synthesized getPrimal requirements.
     This This_getPrimal(DiffPair_This);
     Bar Bar_getPrimal(DiffPair_Bar);

     // Synthesized getDiff requirements.
     This.Differential This_getDiff(DiffPair_This);
     Bar.Differential Bar_getDiff(DiffPair_Bar);

     float doThing(This, Bar);

     // Synthesized diff requirements (logic for this is already present
     // before this PR).
     void doThing_bwd(inout DiffPair_This, inout DiffPair_Bar, float);
     DiffPair<float> doThing_fwd(DiffPair_This, DiffPair_Bar);
};
  • After the interface is modified, all available concrete witness tables are extended with synthesized methods to meet these new requirements.

  • Note that, for now, this type & method synthesis happens on the IR side during auto-diff pair lowering.
    During auto-diff, these types are still represented as DifferentialPair<LookupWitness(table, assoc_type_key)>, even though that is not a proper type.

  • Ideally, the next step is to perform this on the AST side when creating the method requirements for the differentiable methods.
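As a rough analogy in C++ (not Slang, and not what the compiler actually emits), extending the concrete witness table for A can be pictured as a traits struct that supplies the synthesized pair type and its accessors. Everything named here is hypothetical: `ADifferential`, `DiffPair_A`, `FooWitnessForA`, and `primalPlusDiffX` exist only for this sketch, and it assumes A's differential is a struct of the same shape as A.

```cpp
// Concrete type from the example above.
struct A { float x; };

// Assumption: A's Differential has the same shape as A.
struct ADifferential { float x; };

// Hypothetical synthesized pair type that satisfies DiffPair_This for A.
struct DiffPair_A
{
    A primal;
    ADifferential diff;
};

// Hypothetical "witness table" for A : IFoo, showing only the This_* entries.
struct FooWitnessForA
{
    using This = A;
    using Differential = ADifferential;
    using DiffPair_This = DiffPair_A;

    static DiffPair_This This_makePair(This p, Differential d)
    {
        return DiffPair_This{p, d};
    }
    static This This_getPrimal(const DiffPair_This& pair) { return pair.primal; }
    static Differential This_getDiff(const DiffPair_This& pair) { return pair.diff; }
};

// Generic code touches the pair only through the witness, the way
// dynamically dispatched code would go through witness-table lookups.
template <typename W>
float primalPlusDiffX(typename W::This p, typename W::Differential d)
{
    typename W::DiffPair_This pair = W::This_makePair(p, d);
    return W::This_getPrimal(pair).x + W::This_getDiff(pair).x;
}
```

The point of the indirection is that the caller never names the pair type directly, which is what lets the pair type differ per conforming type.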

Additional fixes

There are several other minor fixes for bugs that appeared once dynamic-dispatch lowering was moved to after the auto-diff step.

  • UseGraph: During the primal availability step, primal instructions that are used in differential contexts are stored into intermediate variables & loaded later. Many instructions (such as ExtractExistentialType) cannot be stored, and so are considered 'passthrough' insts that should be cloned in rather than stored. The previous system used a set of such 'use paths' called UseChains to facilitate this, but it fails in cases where there is a common node in these paths. This PR introduces UseGraph to represent & clone a full graph of instructions between a primal inst and its differential uses.
  • Support for lowering differential pairs of type-packs: We were previously lowering these as a struct of two type-packs, but this is not a valid use & will break specialization. We now lower these as type-packs of lowered pair types (with similar logic for value packs).
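The UseGraph idea of cloning a graph rather than a set of independent paths can be sketched as a memoized clone over a small DAG. This is a toy model under stated assumptions, not the PR's code: `Inst` and `Cloner` are hypothetical, and non-passthrough operands are simply assumed to be available already (stored and reloaded elsewhere).

```cpp
#include <map>
#include <memory>
#include <string>
#include <vector>

// Toy instruction: an opcode name plus operand pointers. NOT Slang's IRInst.
struct Inst
{
    std::string op;
    std::vector<Inst*> operands;
    bool passthrough = false; // true for insts that must be cloned, not stored
};

struct Cloner
{
    std::map<const Inst*, Inst*> cloned;        // memo: original -> clone
    std::vector<std::unique_ptr<Inst>> storage; // owns the clones

    // Clone `inst` and every passthrough operand reachable from it.
    // Non-passthrough operands are reused as-is (assumed stored/loaded).
    Inst* clone(Inst* inst)
    {
        if (!inst->passthrough)
            return inst;

        auto it = cloned.find(inst);
        if (it != cloned.end())
            return it->second; // shared node: reuse the existing clone

        storage.push_back(std::make_unique<Inst>());
        Inst* copy = storage.back().get();
        copy->op = inst->op;
        copy->passthrough = true;
        cloned[inst] = copy;

        for (Inst* operand : inst->operands)
            copy->operands.push_back(clone(operand));
        return copy;
    }
};
```

With the memo table, a passthrough inst that sits on two different paths to differential uses is cloned once and both uses refer to the same clone, which is the shared-node case that separate per-path chains have trouble expressing.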

Fixes: #5829

@saipraveenb25 saipraveenb25 requested a review from a team as a code owner December 13, 2024 23:02
@saipraveenb25 saipraveenb25 added the pr: non-breaking PRs without breaking changes label Dec 13, 2024
@saipraveenb25 saipraveenb25 changed the title from "[Auto-diff] [Do-Not-Merge] Overhaul auto-diff type tracking + Overhaul dynamic dispatch for differentiable functions" to "[Auto-diff] Overhaul auto-diff type tracking + Overhaul dynamic dispatch for differentiable functions" Dec 16, 2024
{
SpecializationOptions specOptions;
specOptions.lowerWitnessLookups = true;
specializeModule(targetProgram, irModule, codeGenContext->getSink(), specOptions);
Collaborator

Performing lowerWitnessLookups can open up new opportunities for specializations, so we need to run that specialization-optimization loop again.

This is really calling for cleaning up the specialization pass to be just a peephole optimization pass.

Collaborator Author

I believe specializeModule() already runs the witness lowering in a loop with the rest of the specialization pass?

@@ -827,17 +860,72 @@ InstPair ForwardDiffTranscriber::transcribeCall(IRBuilder* builder, IRCall* orig
}
}
}

{
// --WORKAROUND--
Collaborator

Put all these in a function, and do early returns instead of nested if.
