diff --git a/cedar-drt/src/lean_impl.rs b/cedar-drt/src/lean_impl.rs index 73c23f23f..2ccb16094 100644 --- a/cedar-drt/src/lean_impl.rs +++ b/cedar-drt/src/lean_impl.rs @@ -22,11 +22,13 @@ use core::panic; use std::collections::HashMap; -use std::{env, ffi::CString}; +use std::ffi::CString; +use std::sync::Once; use crate::cedar_test_impl::*; use crate::definitional_request_types::*; use cedar_policy::integration_testing::{CustomCedarImpl, IntegrationTestValidationResult}; +use cedar_policy::Decision; use cedar_policy_core::ast::{Expr, Value}; pub use cedar_policy_core::*; pub use cedar_policy_validator::{ValidationMode, ValidatorSchema}; @@ -34,12 +36,14 @@ pub use entities::Entities; pub use lean_sys::init::lean_initialize; pub use lean_sys::lean_object; pub use lean_sys::string::lean_mk_string; +use lean_sys::{lean_dec, lean_dec_ref, lean_io_result_is_ok, lean_io_result_show_error}; use lean_sys::{ - lean_initialize_runtime_module, lean_io_mark_end_initialization, lean_io_mk_world, + lean_initialize_runtime_module_locked, lean_io_mark_end_initialization, lean_io_mk_world, lean_string_cstr, }; use log::info; use serde::Deserialize; +use std::collections::HashSet; use std::ffi::CStr; use std::str::FromStr; @@ -50,13 +54,15 @@ extern "C" { fn isAuthorizedDRT(req: *mut lean_object) -> *mut lean_object; fn validateDRT(req: *mut lean_object) -> *mut lean_object; fn evaluateDRT(req: *mut lean_object) -> *mut lean_object; - fn initialize_DiffTest_Main(builtin: i8, ob: *mut lean_object) -> *mut lean_object; + fn initialize_DiffTest_Main(builtin: u8, ob: *mut lean_object) -> *mut lean_object; } pub const LEAN_AUTH_MSG: &str = "Lean authorization time (ns) : "; pub const LEAN_EVAL_MSG: &str = "Lean evaluation time (ns) : "; pub const LEAN_VAL_MSG: &str = "Lean validation time (ns) : "; +static START: Once = Once::new(); + #[derive(Debug, Deserialize)] struct ListDef { l: Vec, @@ -108,44 +114,32 @@ type ValidationResponse = ResultDef>; pub struct LeanDefinitionalEngine {} -fn lean_obj_to_string(o: *mut lean_object) -> String { - let lean_obj_p = unsafe { lean_string_cstr(o) }; - let lean_obj_cstr = unsafe { CStr::from_ptr(lean_obj_p as *const i8) }; - lean_obj_cstr - .to_str() - .expect("failed to convert Lean object to string") - .to_owned() -} - impl LeanDefinitionalEngine { pub fn new() -> Self { - if env::var("RUST_LEAN_INTERFACE_INIT").is_err() { - unsafe { lean_initialize_runtime_module() }; - unsafe { lean_initialize() }; - unsafe { initialize_DiffTest_Main(1, lean_io_mk_world()) }; - unsafe { lean_io_mark_end_initialization() }; - env::set_var("RUST_LEAN_INTERFACE_INIT", "1"); - } + // We run this once per thread: + unsafe { + lean_initialize_runtime_module_locked(); + }; + + START.call_once(|| { + unsafe { + // following: https://lean-lang.org/lean4/doc/dev/ffi.html + let builtin: u8 = 1; + let res = initialize_DiffTest_Main(builtin, lean_io_mk_world()); + if lean_io_result_is_ok(res) { + lean_dec_ref(res); + } else { + lean_io_result_show_error(res); + lean_dec(res); + panic!("Failed to initialize Lean"); + } + lean_io_mark_end_initialization(); + }; + }); Self {} } - fn serialize_authorization_request( - request: &ast::Request, - policies: &ast::PolicySet, - entities: &Entities, - ) -> *mut lean_object { - let request: String = serde_json::to_string(&AuthorizationRequest { - request, - policies, - entities, - }) - .expect("failed to serialize request, policies, or entities"); - let cstring = CString::new(request).expect("`CString::new` failed"); - unsafe { lean_mk_string(cstring.as_ptr() as *const u8) } - } - - fn deserialize_authorization_response(response: *mut lean_object) -> TestResult { - let response_string = lean_obj_to_string(response); + fn deserialize_authorization_response(response_string: String) -> TestResult { let resp: AuthorizationResponse = serde_json::from_str(&response_string).expect("could not deserialize json"); match resp { @@ -197,30 +191,31 @@ impl LeanDefinitionalEngine { policies: &ast::PolicySet, entities: &Entities, ) -> TestResult { - let req = Self::serialize_authorization_request(request, policies, entities); - let response = unsafe { isAuthorizedDRT(req) }; - Self::deserialize_authorization_response(response) - } - - fn serialize_evaluation_request( - request: &ast::Request, - entities: &Entities, - expr: &Expr, - expected: Option<&Expr>, - ) -> *mut lean_object { - let request: String = serde_json::to_string(&EvaluationRequest { + let request: String = serde_json::to_string(&AuthorizationRequest { request, + policies, entities, - expr, - expected, }) - .expect("failed to serialize request, expression, or entities"); + .expect("failed to serialize request, policies, or entities"); let cstring = CString::new(request).expect("`CString::new` failed"); - unsafe { lean_mk_string(cstring.as_ptr() as *const u8) } + // Lean with decrement the reference count when we pass this object: https://github.com/leanprover/lean4/blob/master/src/include/lean/lean.h + let req = unsafe { lean_mk_string(cstring.as_ptr() as *const u8) }; + let response = unsafe { isAuthorizedDRT(req) }; + // req can no longer be assumed to exist + + let lean_obj_p = unsafe { lean_string_cstr(response) }; + let lean_obj_cstr = unsafe { CStr::from_ptr(lean_obj_p as *const i8) }; + let response_string = lean_obj_cstr + .to_str() + .expect("failed to convert Lean object to string") + .to_owned(); + unsafe { + lean_dec(response); + }; + Self::deserialize_authorization_response(response_string) } - fn deserialize_evaluation_response(response: *mut lean_object) -> TestResult { - let response_string = lean_obj_to_string(response); + fn deserialize_evaluation_response(response_string: String) -> TestResult { let resp: EvaluationResponse = serde_json::from_str(&response_string).expect("could not deserialize json"); match resp { @@ -241,31 +236,35 @@ impl LeanDefinitionalEngine { expr: &Expr, expected: Option, ) -> TestResult { - let expected_as_expr = expected.map(|v| v.into()); - let req = - Self::serialize_evaluation_request(request, entities, expr, expected_as_expr.as_ref()); - let response = unsafe { evaluateDRT(req) }; - Self::deserialize_evaluation_response(response) - } - - fn serialize_validation_request( - schema: &ValidatorSchema, - policies: &ast::PolicySet, - ) -> *mut lean_object { - let request: String = serde_json::to_string(&ValidationRequest { - schema, - policies, - mode: cedar_policy_validator::ValidationMode::default(), // == Strict + let expected_as_expr: Option = expected.map(|v| v.into()); + let request: String = serde_json::to_string(&EvaluationRequest { + request, + entities, + expr, + expected: expected_as_expr.as_ref(), }) - .expect("failed to serialize schema or policies"); + .expect("failed to serialize request, expression, or entities"); let cstring = CString::new(request).expect("`CString::new` failed"); - unsafe { lean_mk_string(cstring.as_ptr() as *const u8) } + // Lean with decrement the reference count when we pass this object: https://github.com/leanprover/lean4/blob/master/src/include/lean/lean.h + let req = unsafe { lean_mk_string(cstring.as_ptr() as *const u8) }; + let response = unsafe { evaluateDRT(req) }; + // req can no longer be assumed to exist + + let lean_obj_p = unsafe { lean_string_cstr(response) }; + let lean_obj_cstr = unsafe { CStr::from_ptr(lean_obj_p as *const i8) }; + let response_string = lean_obj_cstr + .to_str() + .expect("failed to convert Lean object to string") + .to_owned(); + unsafe { + lean_dec(response); + }; + Self::deserialize_evaluation_response(response_string) } fn deserialize_validation_response( - response: *mut lean_object, + response_string: String, ) -> TestResult { - let response_string = lean_obj_to_string(response); let resp: ValidationResponse = serde_json::from_str(&response_string).expect("could not deserialize json"); match resp { @@ -291,9 +290,27 @@ impl LeanDefinitionalEngine { schema: &ValidatorSchema, policies: &ast::PolicySet, ) -> TestResult { - let req = Self::serialize_validation_request(schema, policies); + let request: String = serde_json::to_string(&ValidationRequest { + schema, + policies, + mode: cedar_policy_validator::ValidationMode::default(), // == Strict + }) + .expect("failed to serialize schema or policies"); + let cstring = CString::new(request).expect("`CString::new` failed"); + // Lean with decrement the reference count when we pass this object: https://github.com/leanprover/lean4/blob/master/src/include/lean/lean.h + let req = unsafe { lean_mk_string(cstring.as_ptr() as *const u8) }; let response = unsafe { validateDRT(req) }; - Self::deserialize_validation_response(response) + // req can no longer be assumed to exist + let lean_obj_p = unsafe { lean_string_cstr(response) }; + let lean_obj_cstr = unsafe { CStr::from_ptr(lean_obj_p as *const i8) }; + let response_string = lean_obj_cstr + .to_str() + .expect("failed to convert Lean object to string") + .to_owned(); + unsafe { + lean_dec(response); + }; + Self::deserialize_validation_response(response_string) } } diff --git a/cedar-drt/tests/benchmark.rs b/cedar-drt/tests/benchmark.rs index 5e24fbf3c..df099e973 100644 --- a/cedar-drt/tests/benchmark.rs +++ b/cedar-drt/tests/benchmark.rs @@ -165,7 +165,6 @@ fn print_summary(auth_times: HashMap<&str, Vec>, val_times: HashMap<&str, V #[test] // Currently, running this in conjunction with existing tests will cause an error (#227). // In order see the printed output from this test, run `cargo test -- --ignored --nocapture`. -#[ignore] fn run_all_tests() { let rust_impl = RustEngine::new(); let lean_impl = LeanDefinitionalEngine::new();