Skip to content

Commit

Permalink
feat: implement support for externally-implemented foreign functions (g…
Browse files Browse the repository at this point in the history
  • Loading branch information
morgante authored Mar 13, 2024
1 parent 29dcd72 commit 57df119
Show file tree
Hide file tree
Showing 5 changed files with 90 additions and 21 deletions.
7 changes: 5 additions & 2 deletions crates/core/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,10 @@ default = [
"language-parsers",
"non_wasm",
]
external_functions = ["dep:marzano-externals"]
external_functions_common = []
external_functions = ["external_functions_common", "dep:marzano-externals"]
# Use external functions via FFI
external_functions_ffi = ["external_functions_common"]
embeddings = ["dep:embeddings"]
test_ci = ["external_functions"]
test_all = ["embeddings", "external_functions"]
Expand All @@ -71,7 +74,7 @@ grit_alpha = ["external_functions", "embeddings"]
network_requests_common = ["marzano-util/network_requests_common"]
network_requests = ["reqwest", "tokio", "tokio/rt", "network_requests_common", "marzano-util/network_requests"]
network_requests_external = ["network_requests_common", "marzano-util/network_requests_external"]
wasm_core = ["getrandom/js", "web-sys", "network_requests_external"]
wasm_core = ["getrandom/js", "web-sys", "network_requests_external", "external_functions_common", "external_functions_ffi", "marzano-util/external_functions_ffi"]
grit_tracing = ["dep:tracing-opentelemetry"]
language-parsers = ["marzano-language/builtin-parser"]
grit-parser = ["tree-sitter-gritql"]
Expand Down
14 changes: 12 additions & 2 deletions crates/core/src/pattern/function_definition.rs
Original file line number Diff line number Diff line change
Expand Up @@ -200,7 +200,7 @@ impl ForeignFunctionDefinition {
}

impl FunctionDefinition for ForeignFunctionDefinition {
#[cfg(not(feature = "external_functions"))]
#[cfg(not(feature = "external_functions_common"))]
fn call<'a>(
&'a self,
_state: &mut State<'a>,
Expand All @@ -210,7 +210,7 @@ impl FunctionDefinition for ForeignFunctionDefinition {
) -> Result<FuncEvaluation> {
bail!("External functions are not enabled in your environment")
}
#[cfg(feature = "external_functions")]
#[cfg(feature = "external_functions_common")]
fn call<'a>(
&'a self,
state: &mut State<'a>,
Expand Down Expand Up @@ -239,13 +239,23 @@ impl FunctionDefinition for ForeignFunctionDefinition {

let resolved_str: Vec<&str> = cow_resolved.iter().map(Cow::as_ref).collect();

// START Simple externalized version
#[cfg(feature = "external_functions_ffi")]
let result = (context.runtime.exec_external)(&self.code, param_names, &resolved_str)?;

// END Simple externalized version

// START embedded version
// Really, we should compile ahead of time and then call the compiled function
// But, the WebAssembly function model is currently *mutable* so state would be contaminated
#[cfg(feature = "external_functions")]
let mut function = ExternalFunction::new_js(&self.code, param_names)?;

#[cfg(feature = "external_functions")]
let result = function
.call(&resolved_str)
.or_else(|e| bail!("failed to call function {}: {}", self.name, e))?;
// END embedded version

let string = String::from_utf8(result).or_else(|_| {
bail!(
Expand Down
2 changes: 2 additions & 0 deletions crates/util/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -30,3 +30,5 @@ finder = ["log", "ignore"]
network_requests_common = []
network_requests = ["reqwest", "tokio", "tokio/rt", "network_requests_common"]
network_requests_external = ["network_requests_common"]

external_functions_ffi = []
13 changes: 13 additions & 0 deletions crates/util/src/runtime.rs
Original file line number Diff line number Diff line change
Expand Up @@ -29,12 +29,15 @@ pub struct ExecutionContext {
/// It is particularly useful for the WebAssembly variant of Marzano.
// This variant is compiled only when both host-delegation features
// (`network_requests_external` and `external_functions_ffi`) are enabled and
// the `network_requests` feature is disabled.
#[cfg(all(
feature = "network_requests_external",
feature = "external_functions_ffi",
not(feature = "network_requests")
))]
#[derive(Clone, Debug)]
pub struct ExecutionContext {
/// Optional language-model API configuration; `new` initializes it to `None`.
llm_api: Option<LanguageModelAPI>,
/// Caller-supplied function pointer taking a URL, a header map and a JSON
/// value, and returning a string on success (presumably the response body).
fetch: fn(url: &str, headers: &HeaderMap, json: &serde_json::Value) -> Result<String>,
/// Caller-supplied function pointer that executes a foreign function
/// outside this crate. It receives the function's code bytes, its parameter
/// names and the resolved argument strings, and returns the raw output bytes.
pub exec_external:
fn(code: &[u8], param_names: Vec<String>, input_bindings: &[&str]) -> Result<Vec<u8>>,
/// `new` sets this to `false`.
/// NOTE(review): presumably disables `limit` patterns when `true` — confirm.
pub ignore_limit_pattern: bool,
}

Expand Down Expand Up @@ -79,14 +82,21 @@ impl ExecutionContext {

/// Creates an `ExecutionContext` whose network requests and foreign-function
/// calls are both delegated to caller-supplied function pointers.
///
/// * `fetch` - stored as-is; takes a URL, headers and a JSON value and returns
///   a string result.
/// * `exec_external` - stored as-is; takes a foreign function's code bytes,
///   its parameter names and the argument strings, and returns output bytes.
///
/// The returned context has no language-model API configured and
/// `ignore_limit_pattern` set to `false`.
#[cfg(all(
feature = "network_requests_external",
feature = "external_functions_ffi",
not(feature = "network_requests")
))]
pub fn new(
fetch: fn(url: &str, headers: &HeaderMap, json: &serde_json::Value) -> Result<String>,
exec_external: fn(
code: &[u8],
param_names: Vec<String>,
input_bindings: &[&str],
) -> Result<Vec<u8>>,
) -> ExecutionContext {
Self {
// No LLM API is attached at construction time.
llm_api: None,
fetch,
exec_external,
ignore_limit_pattern: false,
}
}
Expand Down Expand Up @@ -186,6 +196,9 @@ impl Default for ExecutionContext {
fetch: |_url: &str, _headers: &HeaderMap, _json: &serde_json::Value| {
Err(anyhow::anyhow!("Network requests are disabled"))
},
exec_external: |_code: &[u8], _param_names: Vec<String>, _input_bindings: &[&str]| {
Err(anyhow::anyhow!("External functions are disabled"))
},
ignore_limit_pattern: false,
}
}
Expand Down
75 changes: 58 additions & 17 deletions crates/wasm-bindings/src/match_pattern.rs
Original file line number Diff line number Diff line change
@@ -1,11 +1,13 @@
use marzano_core::{
pattern::{
api::{AnalysisLog, InputFile, MatchResult, PatternInfo}, built_in_functions::BuiltIns, compiler::{src_to_problem_libs_for_language, CompilationResult}
api::{AnalysisLog, InputFile, MatchResult, PatternInfo},
built_in_functions::BuiltIns,
compiler::{src_to_problem_libs_for_language, CompilationResult},
},
tree_sitter_serde::tree_sitter_node_to_json,
};
use marzano_util::runtime::{ExecutionContext, LanguageModelAPI};
use marzano_language::target_language::{PatternLanguage, TargetLanguage};
use marzano_util::runtime::{ExecutionContext, LanguageModelAPI};
use marzano_util::{position::Position, rich_path::RichFile};
use std::{
collections::{BTreeMap, HashMap},
Expand Down Expand Up @@ -47,6 +49,12 @@ pub async fn initialize_tree_sitter() -> Result<(), JsError> {
// Functions provided by the JavaScript host. With `catch`, a JavaScript
// exception thrown by the host function is returned as `Err(JsValue)`.
extern "C" {
// Host-side request hook: takes the URL plus headers and body as strings
// (callers in this file pass both JSON-serialized) and returns a string.
#[wasm_bindgen(catch)]
pub(crate) fn gritApiRequest(url: &str, headers: &str, body: &str) -> Result<String, JsValue>;
// Host-side executor for foreign functions: takes the function's source
// code, the names of its parameters and the argument values as strings,
// and returns the function's output as a string.
// NOTE(review): presumably `arg_names` and `arg_values` are paired by
// position — confirm against the host implementation.
#[wasm_bindgen(catch)]
pub(crate) fn gritExternalFunctionCall(
code: &str,
arg_names: Vec<String>,
arg_values: Vec<String>,
) -> Result<String, JsValue>;
}

#[wasm_bindgen(js_name = parseInputFiles)]
Expand Down Expand Up @@ -89,7 +97,15 @@ pub async fn parse_input_files(
let injected_builtins: Option<BuiltIns> = None;
#[cfg(feature = "ai_builtins")]
let injected_builtins = Some(ai_builtins::ai_builtins::get_ai_built_in_functions());
match src_to_problem_libs_for_language(pattern.clone(), &libs, lang, None, None, parser, injected_builtins) {
match src_to_problem_libs_for_language(
pattern.clone(),
&libs,
lang,
None,
None,
parser,
injected_builtins,
) {
Ok(c) => {
let warning_logs = c
.compilation_warnings
Expand Down Expand Up @@ -152,19 +168,36 @@ pub async fn match_pattern(
let ParsedPattern { libs, lang, .. } =
get_parsed_pattern(&pattern, lib_paths, lib_contents, parser).await?;

let context = ExecutionContext::new(|url, headers, json| {
let body = serde_json::to_string(json)?;
let mut header_map = HashMap::<&str, &str>::new();
for (k, v) in headers.iter() {
header_map.insert(k.as_str(), v.to_str()?);
}
let headers_str = serde_json::to_string(&header_map)?;
let result = gritApiRequest(url, &headers_str, &body);
match result {
Ok(s) => Ok(s),
Err(_e) => Err(anyhow::anyhow!("HTTP error when making AI request, likely due to a network error. Please make sure you are logged in, or try again later.")),
}
});
let context = ExecutionContext::new(
|url, headers, json| {
let body = serde_json::to_string(json)?;
let mut header_map = HashMap::<&str, &str>::new();
for (k, v) in headers.iter() {
header_map.insert(k.as_str(), v.to_str()?);
}
let headers_str = serde_json::to_string(&header_map)?;
let result = gritApiRequest(url, &headers_str, &body);
match result {
Ok(s) => Ok(s),
Err(_e) => Err(anyhow::anyhow!("HTTP error when making AI request, likely due to a network error. Please make sure you are logged in, or try again later.")),
}
},
|code: &[u8], param_names: Vec<String>, input_bindings: &[&str]| {
let result = gritExternalFunctionCall(
&String::from_utf8_lossy(code),
param_names,
input_bindings.iter().map(|s| s.to_string()).collect(),
);
match result {
Ok(s) => Ok(s.into()),
Err(e) => {
// TODO: figure out why we don't get the real error here
let unwrapped = e.as_string().unwrap_or_else(|| "unknown error, check console for details".to_string());
Err(anyhow::anyhow!("Error calling external function: {}", unwrapped))
}
}
},
);

let context = if !llm_api_base.is_empty() {
let llm_api = LanguageModelAPI {
Expand All @@ -182,7 +215,15 @@ pub async fn match_pattern(
let injected_builtins = Some(ai_builtins::ai_builtins::get_ai_built_in_functions());
let CompilationResult {
problem: pattern, ..
} = match src_to_problem_libs_for_language(pattern, &libs, lang, None, None, parser, injected_builtins) {
} = match src_to_problem_libs_for_language(
pattern,
&libs,
lang,
None,
None,
parser,
injected_builtins,
) {
Ok(c) => c,
Err(e) => {
let log = match e.downcast::<marzano_util::analysis_logs::AnalysisLog>() {
Expand Down

0 comments on commit 57df119

Please sign in to comment.