-
Notifications
You must be signed in to change notification settings - Fork 17
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
feat: add side_effect and backend_config to jit_call #425
base: main
Are you sure you want to change the base?
Conversation
test/lit_tests/lowering/cpujit.mlir
Outdated
@@ -36,6 +36,6 @@ module { | |||
// CHECK-LABEL: @main | |||
// CHECK-SAME: (%[[ARG0:.+]]: tensor<64xi64>) -> tensor<64xi64> { | |||
// CHECK-NEXT: %[[CALL:.+]] = stablehlo.custom_call @enzymexla_compile_cpu(%arg0) | |||
// CHECK-SAME: {api_version = 3 : i32, backend_config = "\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00", | |||
// CHECK-SAME: {api_version = 4 : i32, backend_config = {attr = "\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00"}, |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
@wsmoses this PR does make this change and I am not sure it affects the kernel stuff in any way
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
no, this actually does need to stay the same (see above)
/* api_version*/ | ||
CustomCallApiVersionAttr::get( | ||
rewriter.getContext(), | ||
mlir::stablehlo::CustomCallApiVersion:: |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
oh no this is actually an issue, this needs to keep its API as it was before (it is the calling convention for the setup)
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
API_VERSION_STATUS_RETURNING_UNIFIED
doesn't allow for a DictAttr
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
yeah, but we shouldn't change the attributes passed in
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
if we pass backend_config to jit_call those need to be forwarded to the custom_call as well right?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
so currently the backend config out of jitcall is the function pointer of the jitt'd function, which then is directly executed
Just to make it clear, tests here are failing, independently of the failure introduced by #426, so please don't merge this PR before both issues are fixed. |
The 3 "Build Enzyme-JAX" jobs are all expected to pass, like the code formatting ones. I see there are 3 expected jobs which don't exist anymore, those can be removed. |
|
||
// CHECK-LABEL: func.func @main(%arg0: tensor<4xf32>) -> tensor<4xf32> { | ||
func.func @main(%arg0: tensor<4xf32>) -> tensor<4xf32> { | ||
// CHECK: stablehlo.custom_call @enzymexla_compile_cpu() {api_version = 3 : i32, backend_config = "\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00", has_side_effect = true} : () -> () |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
shouldn't the backend_config
remain the same and get passed to the underlying func?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
this needs fixing, the backend config needs to be compiled as part of CompileCall and this backend config should just be the attribute (at least that is what I figured out from the discussion above)
No description provided.