Skip to content

Commit

Permalink
[IR][inliner] Added a customization option for inliner
Browse files Browse the repository at this point in the history
... to use it in a specific K/N pre-codegen inliner

 #KT-67480
  • Loading branch information
homuroll authored and Space Team committed Sep 16, 2024
1 parent 9d52924 commit 874c0a1
Show file tree
Hide file tree
Showing 10 changed files with 52 additions and 37 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -27,10 +27,7 @@ import org.jetbrains.kotlin.ir.backend.js.lower.inline.*
import org.jetbrains.kotlin.ir.backend.js.utils.compileSuspendAsJsGenerator
import org.jetbrains.kotlin.ir.declarations.IrModuleFragment
import org.jetbrains.kotlin.ir.expressions.IrFunctionReference
import org.jetbrains.kotlin.ir.inline.DumpSyntheticAccessors
import org.jetbrains.kotlin.ir.inline.FunctionInlining
import org.jetbrains.kotlin.ir.inline.SyntheticAccessorLowering
import org.jetbrains.kotlin.ir.inline.isConsideredAsPrivateForInlining
import org.jetbrains.kotlin.ir.inline.*
import org.jetbrains.kotlin.ir.interpreter.IrInterpreterConfiguration
import org.jetbrains.kotlin.platform.js.JsPlatforms

Expand Down Expand Up @@ -264,7 +261,7 @@ private val inlineOnlyPrivateFunctionsPhase = makeIrModulePhase(
{ context: JsIrBackendContext ->
FunctionInlining(
context,
JsInlineFunctionResolver(context, inlineOnlyPrivateFunctions = true),
JsInlineFunctionResolver(context, inlineMode = InlineMode.PRIVATE_INLINE_FUNCTIONS),
produceOuterThisFields = false,
insertAdditionalImplicitCasts = true,
)
Expand Down Expand Up @@ -313,7 +310,7 @@ private val inlineAllFunctionsPhase = makeIrModulePhase(
{ context: JsIrBackendContext ->
FunctionInlining(
context,
JsInlineFunctionResolver(context, inlineOnlyPrivateFunctions = false),
JsInlineFunctionResolver(context, inlineMode = InlineMode.ALL_INLINE_FUNCTIONS),
produceOuterThisFields = false,
insertAdditionalImplicitCasts = true,
)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,8 @@ package org.jetbrains.kotlin.ir.backend.js.lower
import org.jetbrains.kotlin.ir.backend.js.JsIrBackendContext
import org.jetbrains.kotlin.ir.backend.js.lower.inline.JsInlineFunctionResolver
import org.jetbrains.kotlin.ir.inline.CommonInlineCallableReferenceToLambdaPhase
import org.jetbrains.kotlin.ir.inline.InlineMode

internal class JsInlineCallableReferenceToLambdaPhase(context: JsIrBackendContext) : CommonInlineCallableReferenceToLambdaPhase(
context, JsInlineFunctionResolver(context, inlineOnlyPrivateFunctions = false)
context, JsInlineFunctionResolver(context, inlineMode = InlineMode.ALL_INLINE_FUNCTIONS)
)
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ import org.jetbrains.kotlin.ir.backend.js.JsIrBackendContext
import org.jetbrains.kotlin.ir.declarations.IrDeclaration
import org.jetbrains.kotlin.ir.declarations.IrFunction
import org.jetbrains.kotlin.ir.inline.InlineFunctionResolverReplacingCoroutineIntrinsics
import org.jetbrains.kotlin.ir.inline.InlineMode
import org.jetbrains.kotlin.ir.inline.isConsideredAsPrivateForInlining
import org.jetbrains.kotlin.ir.symbols.IrFunctionSymbol
import org.jetbrains.kotlin.ir.util.deepCopyWithSymbols
Expand All @@ -33,8 +34,8 @@ internal class SaveInlineFunctionsBeforeInlining(

internal class JsInlineFunctionResolver(
context: JsIrBackendContext,
override val inlineOnlyPrivateFunctions: Boolean
) : InlineFunctionResolverReplacingCoroutineIntrinsics<JsIrBackendContext>(context) {
inlineMode: InlineMode,
) : InlineFunctionResolverReplacingCoroutineIntrinsics<JsIrBackendContext>(context, inlineMode) {
private val enumEntriesIntrinsic = context.intrinsics.enumEntriesIntrinsic
private val inlineFunctionsBeforeInlining = context.mapping.inlineFunctionsBeforeInlining

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ import org.jetbrains.kotlin.ir.declarations.IrFile
import org.jetbrains.kotlin.ir.declarations.IrFunction
import org.jetbrains.kotlin.ir.inline.FunctionInlining
import org.jetbrains.kotlin.ir.inline.InlineFunctionResolver
import org.jetbrains.kotlin.ir.inline.InlineMode

@PhaseDescription(
name = "FunctionInliningPhase",
Expand All @@ -32,6 +33,6 @@ class JvmIrInliner(context: JvmBackendContext) : FunctionInlining(
}
}

class JvmInlineFunctionResolver(private val context: JvmBackendContext) : InlineFunctionResolver() {
class JvmInlineFunctionResolver(private val context: JvmBackendContext) : InlineFunctionResolver(InlineMode.ALL_INLINE_FUNCTIONS) {
override fun needsInlining(function: IrFunction): Boolean = function.isInlineFunctionCall(context)
}
Original file line number Diff line number Diff line change
Expand Up @@ -7,11 +7,12 @@ package org.jetbrains.kotlin.backend.wasm.lower

import org.jetbrains.kotlin.backend.wasm.WasmBackendContext
import org.jetbrains.kotlin.ir.inline.InlineFunctionResolverReplacingCoroutineIntrinsics
import org.jetbrains.kotlin.ir.inline.InlineMode
import org.jetbrains.kotlin.ir.symbols.IrFunctionSymbol

class WasmInlineFunctionResolver(
context: WasmBackendContext
) : InlineFunctionResolverReplacingCoroutineIntrinsics<WasmBackendContext>(context) {
) : InlineFunctionResolverReplacingCoroutineIntrinsics<WasmBackendContext>(context, inlineMode = InlineMode.ALL_INLINE_FUNCTIONS) {
private val enumEntriesIntrinsic = context.wasmSymbols.enumEntriesIntrinsic

override fun shouldExcludeFunctionFromInlining(symbol: IrFunctionSymbol): Boolean {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,13 @@ interface CallInlinerStrategy {
}
}

abstract class InlineFunctionResolver {
enum class InlineMode {
PRIVATE_INLINE_FUNCTIONS,
ALL_INLINE_FUNCTIONS,
ALL_FUNCTIONS,
}

abstract class InlineFunctionResolver(val inlineMode: InlineMode) {
open val callInlinerStrategy: CallInlinerStrategy
get() = CallInlinerStrategy.DEFAULT
open val allowExternalInlining: Boolean
Expand All @@ -80,11 +86,9 @@ abstract class InlineFunctionResolver {
}

abstract class InlineFunctionResolverReplacingCoroutineIntrinsics<Ctx : CommonBackendContext>(
protected val context: Ctx
) : InlineFunctionResolver() {
protected open val inlineOnlyPrivateFunctions: Boolean
get() = false

protected val context: Ctx,
inlineMode: InlineMode,
) : InlineFunctionResolver(inlineMode) {
override fun getFunctionDeclaration(symbol: IrFunctionSymbol): IrFunction? {
val function = super.getFunctionDeclaration(symbol) ?: return null
// TODO: Remove these hacks when coroutine intrinsics are fixed.
Expand All @@ -101,7 +105,7 @@ abstract class InlineFunctionResolverReplacingCoroutineIntrinsics<Ctx : CommonBa

override fun shouldExcludeFunctionFromInlining(symbol: IrFunctionSymbol): Boolean {
return super.shouldExcludeFunctionFromInlining(symbol) ||
(inlineOnlyPrivateFunctions && !symbol.owner.isConsideredAsPrivateForInlining())
(inlineMode == InlineMode.PRIVATE_INLINE_FUNCTIONS && !symbol.owner.isConsideredAsPrivateForInlining())
}
}

Expand Down Expand Up @@ -233,14 +237,17 @@ open class FunctionInlining(
*/
val irBuilder = context.createIrBuilder(irReturnableBlockSymbol, endOffset, endOffset)

// Investigate the difference (KT-71425).
val returnType = if (inlineFunctionResolver.inlineMode == InlineMode.ALL_FUNCTIONS) callee.returnType else callSite.type

val transformer = ParameterSubstitutor()
val newStatements = statements.map { it.transform(transformer, data = null) as IrStatement }

val inlinedBlock = IrInlinedFunctionBlockImpl(
startOffset = callSite.startOffset,
endOffset = callSite.endOffset,
type = callSite.type,
inlineFunction = callee.originalFunction,
type = returnType,
inlineFunction = if (inlineFunctionResolver.inlineMode == InlineMode.ALL_FUNCTIONS) callee else callee.originalFunction,
origin = null,
statements = evaluationStatements + newStatements
).apply {
Expand All @@ -256,7 +263,7 @@ open class FunctionInlining(
return IrReturnableBlockImpl(
startOffset = callSite.startOffset,
endOffset = callSite.endOffset,
type = callSite.type,
type = returnType,
symbol = irReturnableBlockSymbol,
origin = null,
statements = listOf(inlinedBlock),
Expand All @@ -266,10 +273,10 @@ open class FunctionInlining(
expression.transformChildrenVoid(this)

if (expression.returnTargetSymbol == copiedCallee.symbol) {
val expr = if (callSite.type.isUnit()) {
val expr = if (returnType.isUnit()) {
expression.value.coerceToUnit(context.irBuiltIns, context.typeSystem)
} else {
expression.value.doImplicitCastIfNeededTo(callSite.type)
expression.value.doImplicitCastIfNeededTo(returnType)
}
return irBuilder.at(expression).irReturn(expr)
}
Expand Down Expand Up @@ -702,7 +709,9 @@ open class FunctionInlining(
* For simplicity and to produce simpler IR we don't create temporaries for every immutable variable,
* not only for those referring to inlinable lambdas.
*/
if (argument.isInlinableLambdaArgument || argument.isInlinablePropertyReference) {
if ((argument.isInlinableLambdaArgument || argument.isInlinablePropertyReference)
&& inlineFunctionResolver.inlineMode != InlineMode.ALL_FUNCTIONS
) {
substituteMap[parameter] = argument.argumentExpression
val arg = argument.argumentExpression
when {
Expand Down Expand Up @@ -756,7 +765,8 @@ open class FunctionInlining(
}

private fun ParameterToArgument.doesNotNeedTemporaryVariable(): Boolean =
argumentExpression.isPure(false, context = context) && parameter.isInlineParameter()
argumentExpression.isPure(false, context = context)
&& (inlineFunctionResolver.inlineMode == InlineMode.ALL_FUNCTIONS || parameter.isInlineParameter())

private fun createTemporaryVariable(
parameter: IrValueParameter,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@ import org.jetbrains.kotlin.ir.expressions.IrBody
import org.jetbrains.kotlin.ir.expressions.IrFunctionReference
import org.jetbrains.kotlin.ir.expressions.IrSuspensionPoint
import org.jetbrains.kotlin.ir.inline.DumpSyntheticAccessors
import org.jetbrains.kotlin.ir.inline.InlineMode
import org.jetbrains.kotlin.ir.inline.SyntheticAccessorLowering
import org.jetbrains.kotlin.ir.inline.isConsideredAsPrivateForInlining
import org.jetbrains.kotlin.ir.interpreter.IrInterpreterConfiguration
Expand Down Expand Up @@ -407,7 +408,7 @@ private val cacheOnlyPrivateFunctionsPhase: SimpleNamedCompilerPhase<NativeGener

private val inlineOnlyPrivateFunctionsPhase = createFileLoweringPhase(
lowering = { context: NativeGenerationState ->
NativeIrInliner(context, inlineOnlyPrivateFunctions = true)
NativeIrInliner(context, inlineMode = InlineMode.PRIVATE_INLINE_FUNCTIONS)
},
name = "InlineOnlyPrivateFunctions",
description = "The first phase of inlining (inline only private functions)",
Expand All @@ -430,7 +431,7 @@ private val cacheAllFunctionsPhase: SimpleNamedCompilerPhase<NativeGenerationSta

internal val inlineAllFunctionsPhase = createFileLoweringPhase(
lowering = { context: NativeGenerationState ->
NativeIrInliner(context, inlineOnlyPrivateFunctions = false)
NativeIrInliner(context, inlineMode = InlineMode.ALL_INLINE_FUNCTIONS)
},
name = "InlineAllFunctions",
description = "The second phase of inlining (inline all functions)",
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,9 +7,10 @@ package org.jetbrains.kotlin.backend.konan.lower

import org.jetbrains.kotlin.backend.konan.NativeGenerationState
import org.jetbrains.kotlin.ir.inline.CommonInlineCallableReferenceToLambdaPhase
import org.jetbrains.kotlin.ir.inline.InlineMode

internal class NativeInlineCallableReferenceToLambdaPhase(
context: NativeGenerationState
) : CommonInlineCallableReferenceToLambdaPhase(
context.context, NativeInlineFunctionResolver(context, inlineOnlyPrivateFunctions = false)
context.context, NativeInlineFunctionResolver(context, inlineMode = InlineMode.ALL_INLINE_FUNCTIONS)
)
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ import org.jetbrains.kotlin.ir.expressions.IrCall
import org.jetbrains.kotlin.ir.expressions.IrExpression
import org.jetbrains.kotlin.ir.inline.CallInlinerStrategy
import org.jetbrains.kotlin.ir.inline.InlineFunctionResolverReplacingCoroutineIntrinsics
import org.jetbrains.kotlin.ir.inline.InlineMode
import org.jetbrains.kotlin.ir.inline.SyntheticAccessorLowering
import org.jetbrains.kotlin.ir.irAttribute
import org.jetbrains.kotlin.ir.symbols.IrFunctionSymbol
Expand All @@ -42,8 +43,8 @@ internal fun IrFunction.getOrSaveLoweredInlineFunction(): IrFunction =
// TODO: This is a bit hacky. Think about adopting persistent IR ideas.
internal class NativeInlineFunctionResolver(
private val generationState: NativeGenerationState,
override val inlineOnlyPrivateFunctions: Boolean
) : InlineFunctionResolverReplacingCoroutineIntrinsics<Context>(generationState.context) {
inlineMode: InlineMode,
) : InlineFunctionResolverReplacingCoroutineIntrinsics<Context>(generationState.context, inlineMode) {
override fun getFunctionDeclaration(symbol: IrFunctionSymbol): IrFunction? {
val function = super.getFunctionDeclaration(symbol) ?: return null

Expand Down Expand Up @@ -113,7 +114,7 @@ internal class NativeInlineFunctionResolver(
WrapInlineDeclarationsWithReifiedTypeParametersLowering(context).lower(body, function)

if (experimentalDoubleInlining) {
NativeIrInliner(generationState, inlineOnlyPrivateFunctions = true).lower(body, function)
NativeIrInliner(generationState, inlineMode = InlineMode.PRIVATE_INLINE_FUNCTIONS).lower(body, function)
SyntheticAccessorLowering(context).lowerWithoutAddingAccessorsToParents(function)
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -11,18 +11,19 @@ import org.jetbrains.kotlin.ir.IrElement
import org.jetbrains.kotlin.ir.declarations.IrFile
import org.jetbrains.kotlin.ir.declarations.IrFunction
import org.jetbrains.kotlin.ir.inline.FunctionInlining
import org.jetbrains.kotlin.ir.inline.InlineMode
import org.jetbrains.kotlin.ir.inline.isConsideredAsPrivateForInlining
import org.jetbrains.kotlin.ir.visitors.IrElementVisitorVoid
import org.jetbrains.kotlin.ir.visitors.acceptChildrenVoid

internal class NativeIrInliner(
context: NativeGenerationState,
inlineOnlyPrivateFunctions: Boolean,
context: NativeGenerationState,
inlineMode: InlineMode,
) : FunctionInlining(
context = context.context,
NativeInlineFunctionResolver(context, inlineOnlyPrivateFunctions),
insertAdditionalImplicitCasts = true,
produceOuterThisFields = false,
context = context.context,
NativeInlineFunctionResolver(context, inlineMode),
insertAdditionalImplicitCasts = true,
produceOuterThisFields = false,
)

internal class CacheInlineFunctionsBeforeInlining(private val cacheOnlyPrivateFunctions: Boolean) : FileLoweringPass {
Expand Down

0 comments on commit 874c0a1

Please sign in to comment.