Skip to content

Commit

Permalink
Changing the semantics of binding filters (tensorlakeai#295)
Browse files Browse the repository at this point in the history
* Changing the semantics of binding filters

* fixed param deserialization
  • Loading branch information
diptanu authored Feb 2, 2024
1 parent a9a7680 commit 5aba436
Show file tree
Hide file tree
Showing 7 changed files with 29 additions and 31 deletions.
2 changes: 1 addition & 1 deletion indexify_extractor_sdk/base_embedding.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
class EmbeddingInputParams:
overlap: int = 0
chunk_size: int = 0
text_splitter: Literal["char", "recursive"] = "recursive"
text_splitter: str = "recursive"


class BaseEmbeddingExtractor(Extractor):
Expand Down
8 changes: 4 additions & 4 deletions indexify_extractor_sdk/base_extractor.py
Original file line number Diff line number Diff line change
Expand Up @@ -96,7 +96,7 @@ class Extractor(ABC):

@abstractmethod
def extract(
self, content: Content, params) -> List[Content]:
self, content: Content, params: None) -> List[Content]:
"""
Extracts information from the content.
"""
Expand All @@ -119,9 +119,9 @@ def __init__(self, module_name: str, class_name: str):

def extract(self, content: List[Content], params: str) -> List[List[Content]]:
params_dict = json.loads(params)
param_instance = (
self._param_cls.from_dict(params_dict) if self._param_cls else None
)
param_instance = None
if self._param_cls is not type(None):
param_instance = self._param_cls.from_dict(params_dict)

# This is because the rust side does batching and on python we don't batch
out = []
Expand Down
2 changes: 1 addition & 1 deletion indexify_extractor_sdk/version.py
Original file line number Diff line number Diff line change
@@ -1 +1 @@
__version__ = "0.0.14"
__version__ = "0.0.15"
17 changes: 4 additions & 13 deletions src/api.rs
Original file line number Diff line number Diff line change
Expand Up @@ -19,23 +19,18 @@ use crate::{api_utils, metadata_index, vectordbs};
pub struct ExtractorBinding {
pub extractor: String,
pub name: String,
#[serde(default)]
pub filters: HashMap<String, serde_json::Value>,
#[serde(default, deserialize_with = "api_utils::deserialize_labels_eq_filter")]
pub filters_eq: Option<HashMap<String, String>>,
pub input_params: Option<serde_json::Value>,
pub content_source: Option<String>,
}

impl From<ExtractorBinding> for indexify_coordinator::ExtractorBinding {
fn from(value: ExtractorBinding) -> Self {
let mut filters = HashMap::new();
for filter in value.filters {
filters.insert(filter.0, filter.1.to_string());
}

Self {
extractor: value.extractor,
name: value.name,
filters,
filters: value.filters_eq.unwrap_or_default(),
input_params: value
.input_params
.map(|v| v.to_string())
Expand All @@ -60,11 +55,7 @@ impl TryFrom<indexify_coordinator::Repository> for DataRepository {
extractor_bindings.push(ExtractorBinding {
extractor: binding.extractor,
name: binding.name,
filters: binding
.filters
.into_iter()
.map(|(k, v)| (k, serde_json::from_str(&v).unwrap()))
.collect(),
filters_eq: Some(binding.filters),
input_params: Some(serde_json::from_str(&binding.input_params)?),
content_source: Some(binding.content_source),
});
Expand Down
19 changes: 10 additions & 9 deletions src/api_utils.rs
Original file line number Diff line number Diff line change
@@ -1,8 +1,9 @@
use std::collections::HashMap;

use anyhow::{anyhow, Result};
use serde::Deserialize;

pub fn validate_label_key(key: &str) -> Result<(), String> {
pub fn validate_label_key(key: &str) -> Result<()> {
let validations = [
(key.is_ascii(), "must be ASCII"),
(key.len() <= 63, "must be 63 characters or less"),
Expand Down Expand Up @@ -35,15 +36,15 @@ pub fn validate_label_key(key: &str) -> Result<(), String> {
if err_msgs.is_empty() {
Ok(())
} else {
Err(format!(
Err(anyhow!(
"label key invalid - {} - found key : \"{}\"",
err_msgs.join(", "),
key
))
}
}

pub fn validate_label_value(value: &str) -> Result<(), String> {
pub fn validate_label_value(value: &str) -> Result<()> {
// empty string is ok
if value.is_empty() {
return Ok(());
Expand Down Expand Up @@ -83,15 +84,15 @@ pub fn validate_label_value(value: &str) -> Result<(), String> {
if err_msgs.is_empty() {
Ok(())
} else {
Err(format!(
Err(anyhow!(
"label value invalid - {} - found value : \"{}\"",
err_msgs.join(", "),
value
))
}
}

pub fn parse_validate_label_raw(raw: &str) -> Result<(String, String), String> {
pub fn parse_validate_label_raw(raw: &str) -> Result<(String, String)> {
let mut split = raw.split(':');

let mut err_msgs = vec![];
Expand All @@ -114,7 +115,7 @@ pub fn parse_validate_label_raw(raw: &str) -> Result<(String, String), String> {
let value = split.next().unwrap_or("").to_string();
Ok((key, value))
} else {
Err(format!(
Err(anyhow!(
"query invalid - {} - raw : \"{}\"",
err_msgs.join(", "),
raw
Expand Down Expand Up @@ -258,7 +259,7 @@ where
let mut labels_eq = HashMap::new();
for label in labels {
let (key, value) = parse_validate_label_raw(label)
.map_err(|e| err_formatter("query invalid".to_string(), e))?;
.map_err(|e| err_formatter("query invalid".to_string(), e.to_string()))?;

// if the key already exists, then it's a duplicate
if labels_eq.contains_key(&key) {
Expand All @@ -268,9 +269,9 @@ where
));
}
validate_label_key(key.as_str())
.map_err(|e| err_formatter("key invalid".to_string(), e))?;
.map_err(|e| err_formatter("key invalid".to_string(), e.to_string()))?;
validate_label_value(value.as_str())
.map_err(|e| err_formatter("value invalid".to_string(), e))?;
.map_err(|e| err_formatter("value invalid".to_string(), e.to_string()))?;

labels_eq.insert(key, value);
}
Expand Down
9 changes: 7 additions & 2 deletions src/data_repository_manager.rs
Original file line number Diff line number Diff line change
Expand Up @@ -413,9 +413,14 @@ impl DataRepositoryManager {
.await?;
new_content_metadata.push(content_metadata.clone());
for feature in content.features {
let index_table_name = extracted_content.output_to_index_table_mapping.get(&feature.name);
let index_table_name = extracted_content
.output_to_index_table_mapping
.get(&feature.name);
if index_table_name.is_none() {
error!("unable to find index table name for feature {}", feature.name);
error!(
"unable to find index table name for feature {}",
feature.name
);
continue;
}
let index_table_name = index_table_name.unwrap();
Expand Down
3 changes: 2 additions & 1 deletion src/executor_server.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,8 @@ use anyhow::Result;
use axum::{
extract::{DefaultBodyLimit, State},
routing::{get, post},
Json, Router,
Json,
Router,
};
use axum_otel_metrics::HttpMetricsLayerBuilder;
use axum_tracing_opentelemetry::middleware::OtelAxumLayer;
Expand Down

0 comments on commit 5aba436

Please sign in to comment.