Rewrite all load API calls to return RecordBatch
KSDaemon committed Jan 14, 2025
1 parent d330d7f commit 5f40798
Showing 5 changed files with 188 additions and 95 deletions.
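
At a high level, this commit moves the JSON-to-Arrow conversion behind the transport boundary: load() no longer returns a raw TransportLoadResponse for callers to transform later, but ready-made Arrow RecordBatches. A minimal sketch of the shape of that change, using placeholder types rather than the real cubesql definitions (the actual trait is async and takes more parameters, as the diff below shows):

// Illustrative placeholders only, not the real Cube types.
struct TransportLoadResponse; // JSON-shaped load result (old return type)
struct RecordBatch; // Arrow batch (new return type)
struct SchemaRef;
struct MemberField;
struct CubeError;

trait TransportServiceBefore {
    // Old: callers received the raw response and converted it themselves.
    fn load(&self, schema: SchemaRef) -> Result<TransportLoadResponse, CubeError>;
}

trait TransportServiceAfter {
    // New: the transport converts rows using the schema and the member
    // mapping, and hands back Arrow record batches directly.
    fn load(
        &self,
        schema: SchemaRef,
        member_fields: Vec<MemberField>,
    ) -> Result<Vec<RecordBatch>, CubeError>;
}
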
79 changes: 74 additions & 5 deletions packages/cubejs-backend-native/src/orchestrator.rs
@@ -2,9 +2,10 @@ use crate::node_obj_deserializer::JsValueDeserializer;
use crate::transport::MapCubeErrExt;
use cubeorchestrator::query_message_parser::QueryResult;
use cubeorchestrator::query_result_transform::{
RequestResultData, RequestResultDataMulti, TransformedData,
DBResponsePrimitive, RequestResultData, RequestResultDataMulti, TransformedData,
};
use cubeorchestrator::transport::{JsRawData, TransformDataRequest};
use cubesql::compile::engine::df::scan::{FieldValue, ValueObject};
use cubesql::CubeError;
use neon::context::{Context, FunctionContext, ModuleContext};
use neon::handle::Handle;
@@ -15,6 +16,7 @@ use neon::prelude::{
};
use neon::types::buffer::TypedArray;
use serde::Deserialize;
use std::borrow::Cow;
use std::sync::Arc;

pub fn register_module(cx: &mut ModuleContext) -> NeonResult<()> {
@@ -39,6 +41,7 @@ pub fn register_module(cx: &mut ModuleContext) -> NeonResult<()> {
pub struct ResultWrapper {
transform_data: TransformDataRequest,
data: Arc<QueryResult>,
transformed_data: Option<TransformedData>,
}

impl ResultWrapper {
@@ -115,14 +118,80 @@ impl ResultWrapper {
Ok(Self {
transform_data: transform_request,
data: query_result,
transformed_data: None,
})
}

pub fn transform_result(&self) -> Result<TransformedData, CubeError> {
let transformed = TransformedData::transform(&self.transform_data, &self.data)
.map_cube_err("Can't prepare transformed data")?;
pub fn transform_result(&mut self) -> Result<(), CubeError> {
self.transformed_data = Some(
TransformedData::transform(&self.transform_data, &self.data)
.map_cube_err("Can't prepare transformed data")?,
);

Ok(transformed)
Ok(())
}
}

impl ValueObject for ResultWrapper {
fn len(&mut self) -> Result<usize, CubeError> {
if self.transformed_data.is_none() {
self.transform_result()?;
}

let data = self.transformed_data.as_ref().unwrap();

match data {
TransformedData::Compact {
members: _members,
dataset,
} => Ok(dataset.len()),
TransformedData::Vanilla(dataset) => Ok(dataset.len()),
}
}

fn get(&mut self, index: usize, field_name: &str) -> Result<FieldValue, CubeError> {
if self.transformed_data.is_none() {
self.transform_result()?;
}

let data = self.transformed_data.as_ref().unwrap();

let value = match data {
TransformedData::Compact { members, dataset } => {
let Some(row) = dataset.get(index) else {
return Err(CubeError::user(format!(
"Unexpected response from Cube, can't get {} row",
index
)));
};

let Some(member_index) = members.iter().position(|m| m == field_name) else {
return Err(CubeError::user(format!(
"Field name '{}' not found in members",
field_name
)));
};

row.get(member_index).unwrap_or(&DBResponsePrimitive::Null)
}
TransformedData::Vanilla(dataset) => {
let Some(row) = dataset.get(index) else {
return Err(CubeError::user(format!(
"Unexpected response from Cube, can't get {} row",
index
)));
};

row.get(field_name).unwrap_or(&DBResponsePrimitive::Null)
}
};

Ok(match value {
DBResponsePrimitive::String(s) => FieldValue::String(Cow::Borrowed(s)),
DBResponsePrimitive::Number(n) => FieldValue::Number(*n),
DBResponsePrimitive::Boolean(b) => FieldValue::Bool(*b),
DBResponsePrimitive::Null => FieldValue::Null,
})
}
}
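
With this ValueObject impl in place, anything that accepts a ValueObject, notably transform_response in cubesql's scan module further down, can read rows straight out of the native ResultWrapper with no intermediate JSON serialization. A rough sketch of the consumption pattern, using simplified stand-ins for cubesql's FieldValue and ValueObject (the real definitions differ, for example FieldValue borrows strings via Cow):

// Simplified stand-ins, for illustration only.
enum FieldValue {
    String(String),
    Number(f64),
    Bool(bool),
    Null,
}

trait ValueObject {
    fn len(&mut self) -> Result<usize, String>;
    fn get(&mut self, index: usize, field_name: &str) -> Result<FieldValue, String>;
}

// Collect one member's column as display strings, roughly what a column
// builder does before appending values to an Arrow array.
fn collect_column<V: ValueObject>(rows: &mut V, field: &str) -> Result<Vec<String>, String> {
    let row_count = rows.len()?;
    let mut out = Vec::with_capacity(row_count);
    for i in 0..row_count {
        out.push(match rows.get(i, field)? {
            FieldValue::String(s) => s,
            FieldValue::Number(n) => n.to_string(),
            FieldValue::Bool(b) => b.to_string(),
            FieldValue::Null => "NULL".to_string(),
        });
    }
    Ok(out)
}
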

35 changes: 20 additions & 15 deletions packages/cubejs-backend-native/src/transport.rs
@@ -12,7 +12,9 @@ use crate::{
stream::call_js_with_stream_as_callback,
};
use async_trait::async_trait;
use cubesql::compile::engine::df::scan::{MemberField, SchemaRef};
use cubesql::compile::engine::df::scan::{
convert_transport_response, transform_response, MemberField, RecordBatch, SchemaRef,
};
use cubesql::compile::engine::df::wrapper::SqlQuery;
use cubesql::transport::{
SpanId, SqlGenerator, SqlResponse, TransportLoadRequestQuery, TransportLoadResponse,
@@ -334,9 +336,9 @@ impl TransportService for NodeBridgeTransport {
sql_query: Option<SqlQuery>,
ctx: AuthContextRef,
meta: LoadRequestMeta,
_schema: SchemaRef,
// ) -> Result<Vec<RecordBatch>, CubeError> {
) -> Result<TransportLoadResponse, CubeError> {
schema: SchemaRef,
member_fields: Vec<MemberField>,
) -> Result<Vec<RecordBatch>, CubeError> {
trace!("[transport] Request ->");

let native_auth = ctx
@@ -461,20 +463,23 @@
}
};

break serde_json::from_value::<TransportLoadResponse>(response)
let response = match serde_json::from_value::<TransportLoadResponse>(response) {
Ok(v) => v,
Err(err) => {
return Err(CubeError::user(err.to_string()));
}
};

break convert_transport_response(response, schema.clone(), member_fields)
.map_err(|err| CubeError::user(err.to_string()));
}
ValueFromJs::ResultWrapper(result_wrappers) => {
let response = TransportLoadResponse {
pivot_query: None,
slow_query: None,
query_type: None,
results: result_wrappers
.into_iter()
.map(|v| v.transform_result().unwrap().into())
.collect(),
};
break Ok(response);
break result_wrappers
.into_iter()
.map(|mut wrapper| {
transform_response(&mut wrapper, schema.clone(), &member_fields)
})
.collect::<Result<Vec<_>, _>>();
}
}
}
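
Both branches above now converge on the same return type, Vec<RecordBatch>. The ResultWrapper branch leans on a standard-library idiom worth noting: collecting an iterator of Result values into a single Result, where the first Err short-circuits the whole collection. A tiny self-contained example of that idiom (not Cube-specific):

fn parse_all(inputs: &[&str]) -> Result<Vec<i32>, std::num::ParseIntError> {
    inputs
        .iter()
        .map(|s| s.parse::<i32>())
        .collect::<Result<Vec<_>, _>>()
}

fn main() {
    assert_eq!(parse_all(&["1", "2", "3"]), Ok(vec![1, 2, 3]));
    assert!(parse_all(&["1", "oops"]).is_err());
}
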
134 changes: 71 additions & 63 deletions rust/cubesql/cubesql/src/compile/engine/df/scan.rs
@@ -1,5 +1,5 @@
use async_trait::async_trait;
use cubeclient::models::{V1LoadRequestQuery, V1LoadResult, V1LoadResultAnnotation};
use cubeclient::models::{V1LoadRequestQuery, V1LoadResponse};
pub use datafusion::{
arrow::{
array::{
@@ -52,7 +52,7 @@ use datafusion::{
logical_plan::JoinType,
scalar::ScalarValue,
};
use serde_json::{json, Value};
use serde_json::Value;

#[derive(Debug, Clone, Eq, PartialEq)]
pub enum MemberField {
@@ -655,28 +655,22 @@ impl ExecutionPlan for CubeScanExecutionPlan {
)));
}

let mut response = JsonValueObject::new(
load_data(
self.span_id.clone(),
request,
self.auth_context.clone(),
self.transport.clone(),
meta.clone(),
self.schema.clone(),
self.options.clone(),
self.wrapped_sql.clone(),
)
.await?
.data,
);
one_shot_stream.data = Some(
transform_response(
&mut response,
one_shot_stream.schema.clone(),
&one_shot_stream.member_fields,
)
.map_err(|e| DataFusionError::Execution(e.message.to_string()))?,
);
let response = load_data(
self.span_id.clone(),
request,
self.auth_context.clone(),
self.transport.clone(),
meta.clone(),
self.schema.clone(),
self.member_fields.clone(),
self.options.clone(),
self.wrapped_sql.clone(),
)
.await?;

// For now execute method executes only one query at a time, so we
// take the first result
one_shot_stream.data = Some(response.first().unwrap().clone());

Ok(Box::pin(CubeScanStreamRouter::new(
None,
@@ -846,9 +840,10 @@ async fn load_data(
transport: Arc<dyn TransportService>,
meta: LoadRequestMeta,
schema: SchemaRef,
member_fields: Vec<MemberField>,
options: CubeScanOptions,
sql_query: Option<SqlQuery>,
) -> ArrowResult<V1LoadResult> {
) -> ArrowResult<Vec<RecordBatch>> {
let no_members_query = request.measures.as_ref().map(|v| v.len()).unwrap_or(0) == 0
&& request.dimensions.as_ref().map(|v| v.len()).unwrap_or(0) == 0
&& request
@@ -866,22 +861,27 @@
data.push(serde_json::Value::Null)
}

V1LoadResult::new(
V1LoadResultAnnotation {
measures: json!(Vec::<serde_json::Value>::new()),
dimensions: json!(Vec::<serde_json::Value>::new()),
segments: json!(Vec::<serde_json::Value>::new()),
time_dimensions: json!(Vec::<serde_json::Value>::new()),
},
data,
)
let mut response = JsonValueObject::new(data);
let rec = transform_response(&mut response, schema.clone(), &member_fields)
.map_err(|e| DataFusionError::Execution(e.message.to_string()))?;

rec
} else {
let result = transport
.load(span_id, request, sql_query, auth_context, meta, schema)
.await;
let mut response = result.map_err(|err| ArrowError::ComputeError(err.to_string()))?;
if let Some(data) = response.results.pop() {
match (options.max_records, data.data.len()) {
.load(
span_id,
request,
sql_query,
auth_context,
meta,
schema,
member_fields,
)
.await
.map_err(|err| ArrowError::ComputeError(err.to_string()))?;
let response = result.first();
if let Some(data) = response.cloned() {
match (options.max_records, data.num_rows()) {
(Some(max_records), len) if len >= max_records => {
return Err(ArrowError::ComputeError(format!("One of the Cube queries exceeded the maximum row limit ({}). JOIN/UNION is not possible as it will produce incorrect results. Try filtering the results more precisely or moving post-processing functions to an outer query.", max_records)));
}
@@ -896,7 +896,7 @@
}
};

Ok(result)
Ok(vec![result])
}

fn load_to_stream_sync(one_shot_stream: &mut CubeScanOneShotStream) -> Result<()> {
@@ -906,6 +906,7 @@ fn load_to_stream_sync(one_shot_stream: &mut CubeScanOneShotStream) -> Result<()
let transport = one_shot_stream.transport.clone();
let meta = one_shot_stream.meta.clone();
let schema = one_shot_stream.schema.clone();
let member_fields = one_shot_stream.member_fields.clone();
let options = one_shot_stream.options.clone();
let wrapped_sql = one_shot_stream.wrapped_sql.clone();

@@ -918,22 +919,16 @@ fn load_to_stream_sync(one_shot_stream: &mut CubeScanOneShotStream) -> Result<()
transport,
meta,
schema,
member_fields,
options,
wrapped_sql,
))
})
.join()
.map_err(|_| DataFusionError::Execution(format!("Can't load to stream")))?;

let mut response = JsonValueObject::new(res.unwrap().data);
one_shot_stream.data = Some(
transform_response(
&mut response,
one_shot_stream.schema.clone(),
&one_shot_stream.member_fields,
)
.map_err(|e| DataFusionError::Execution(e.message.to_string()))?,
);
.map_err(|_| DataFusionError::Execution(format!("Can't load to stream")))??;

let response = res.first();
one_shot_stream.data = Some(response.cloned().unwrap());

Ok(())
}
@@ -1339,6 +1334,21 @@ pub fn transform_response<V: ValueObject>(
Ok(RecordBatch::try_new(schema.clone(), columns)?)
}

pub fn convert_transport_response(
response: V1LoadResponse,
schema: SchemaRef,
member_fields: Vec<MemberField>,
) -> std::result::Result<Vec<RecordBatch>, CubeError> {
response
.results
.into_iter()
.map(|r| {
let mut response = JsonValueObject::new(r.data.clone());
transform_response(&mut response, schema.clone(), &member_fields)
})
.collect::<std::result::Result<Vec<RecordBatch>, CubeError>>()
}
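
For intuition, here is a self-contained sketch of the row-JSON-to-columnar conversion that transform_response and convert_transport_response perform above, stripped of Cube's member mapping, annotations, and error handling. It assumes the arrow and serde_json crates are available as direct dependencies, and the column names are illustrative only:

use std::sync::Arc;

use arrow::array::{ArrayRef, Float64Array, StringArray};
use arrow::datatypes::{DataType, Field, Schema};
use arrow::error::ArrowError;
use arrow::record_batch::RecordBatch;
use serde_json::{json, Value};

// Turn row-oriented JSON (the shape the Cube REST API returns) into a
// columnar Arrow RecordBatch, one array per field.
fn rows_to_batch(rows: &[Value]) -> Result<RecordBatch, ArrowError> {
    let cities: StringArray = rows.iter().map(|r| r["city"].as_str()).collect();
    let prices: Float64Array = rows.iter().map(|r| r["maxPrice"].as_f64()).collect();

    let schema = Arc::new(Schema::new(vec![
        Field::new("city", DataType::Utf8, true),
        Field::new("maxPrice", DataType::Float64, true),
    ]));

    RecordBatch::try_new(
        schema,
        vec![Arc::new(cities) as ArrayRef, Arc::new(prices) as ArrayRef],
    )
}

fn main() {
    let rows = vec![
        json!({ "city": "New York", "maxPrice": 10.5 }),
        json!({ "city": null, "maxPrice": null }),
    ];
    let batch = rows_to_batch(&rows).unwrap();
    assert_eq!(batch.num_rows(), 2);
    assert_eq!(batch.num_columns(), 2);
}
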

#[cfg(test)]
mod tests {
use super::*;
@@ -1402,10 +1412,12 @@ mod tests {
_sql_query: Option<SqlQuery>,
_ctx: AuthContextRef,
_meta_fields: LoadRequestMeta,
_schema: SchemaRef,
) -> Result<V1LoadResponse, CubeError> {
schema: SchemaRef,
member_fields: Vec<MemberField>,
) -> Result<Vec<RecordBatch>, CubeError> {
let response = r#"
{
{
"results": [{
"annotation": {
"measures": [],
"dimensions": [],
Expand All @@ -1419,17 +1431,13 @@ mod tests {
{"KibanaSampleDataEcommerce.count": null, "KibanaSampleDataEcommerce.maxPrice": null, "KibanaSampleDataEcommerce.isBool": "true", "KibanaSampleDataEcommerce.orderDate": "9999-12-31 00:00:00.000", "KibanaSampleDataEcommerce.city": "City 4"},
{"KibanaSampleDataEcommerce.count": null, "KibanaSampleDataEcommerce.maxPrice": null, "KibanaSampleDataEcommerce.isBool": "false", "KibanaSampleDataEcommerce.orderDate": null, "KibanaSampleDataEcommerce.city": null}
]
}
}]
}
"#;

let result: V1LoadResult = serde_json::from_str(response).unwrap();

Ok(V1LoadResponse {
pivot_query: None,
slow_query: None,
query_type: None,
results: vec![result],
})
let result: V1LoadResponse = serde_json::from_str(response).unwrap();
convert_transport_response(result, schema.clone(), member_fields)
.map_err(|err| CubeError::user(err.to_string()))
}

async fn load_stream(
