Skip to content

Commit

Permalink
Add functions for comparisons, remove desugaring hint
Browse files Browse the repository at this point in the history
  • Loading branch information
ap29600 committed Sep 23, 2022
1 parent 41cec35 commit 5a96d16
Show file tree
Hide file tree
Showing 4 changed files with 166 additions and 14 deletions.
144 changes: 144 additions & 0 deletions src/funcs/arithmetic.c
Original file line number Diff line number Diff line change
Expand Up @@ -219,6 +219,150 @@ Eval_Node func_equal (Eval_Node *left, Eval_Node *right) {
return (Eval_Node){.type = Node_Array, .as.array = result};
}

Eval_Node func_greater_equal (Eval_Node *left, Eval_Node *right) {
assert(left ->type == Node_Array);
assert(right->type == Node_Array);
Array *left_ = borrow_array(left ->as.array);
Array *right_ = borrow_array(right->as.array);
assert(left_->shape == right_->shape);

Element_Type natural_type = supertype(left_->type, right_->type);

left_ = array_cast(left_, natural_type);
right_ = array_cast(right_, natural_type);

Array *result = make_array(NULL, left_->shape, Type_Bool);

u64 n = result->shape;
switch (natural_type) {
case Type_Char:
for(i64 i = 0; i < n; i++) ((bool*)result->data)[i] = ((char*)left_->data)[i] >= ((char*)right_->data)[i];
break;
case Type_Int:
for(i64 i = 0; i < n; i++) ((bool*)result->data)[i] = ((i64*)left_->data)[i] >= ((i64*)right_->data)[i];
break;
case Type_Float:
for(i64 i = 0; i < n; i++) ((bool*)result->data)[i] = ((f64*)left_->data)[i] >= ((f64*)right_->data)[i];
break;
case Type_Bool:
for(i64 i = 0; i < n; i++) ((bool*)result->data)[i] = ((bool*)left_->data)[i] >= ((bool*)right_->data)[i];
break;
case Types_Count: assert(false);
}

release_array(left_);
release_array(right_);
return (Eval_Node){.type = Node_Array, .as.array = result};
}

Eval_Node func_greater (Eval_Node *left, Eval_Node *right) {
assert(left ->type == Node_Array);
assert(right->type == Node_Array);
Array *left_ = borrow_array(left ->as.array);
Array *right_ = borrow_array(right->as.array);
assert(left_->shape == right_->shape);

Element_Type natural_type = supertype(left_->type, right_->type);

left_ = array_cast(left_, natural_type);
right_ = array_cast(right_, natural_type);

Array *result = make_array(NULL, left_->shape, Type_Bool);

u64 n = result->shape;
switch (natural_type) {
case Type_Char:
for(i64 i = 0; i < n; i++) ((bool*)result->data)[i] = ((char*)left_->data)[i] > ((char*)right_->data)[i];
break;
case Type_Int:
for(i64 i = 0; i < n; i++) ((bool*)result->data)[i] = ((i64*)left_->data)[i] > ((i64*)right_->data)[i];
break;
case Type_Float:
for(i64 i = 0; i < n; i++) ((bool*)result->data)[i] = ((f64*)left_->data)[i] > ((f64*)right_->data)[i];
break;
case Type_Bool:
for(i64 i = 0; i < n; i++) ((bool*)result->data)[i] = ((bool*)left_->data)[i] > ((bool*)right_->data)[i];
break;
case Types_Count: assert(false);
}

release_array(left_);
release_array(right_);
return (Eval_Node){.type = Node_Array, .as.array = result};
}

Eval_Node func_less (Eval_Node *left, Eval_Node *right) {
assert(left ->type == Node_Array);
assert(right->type == Node_Array);
Array *left_ = borrow_array(left ->as.array);
Array *right_ = borrow_array(right->as.array);
assert(left_->shape == right_->shape);

Element_Type natural_type = supertype(left_->type, right_->type);

left_ = array_cast(left_, natural_type);
right_ = array_cast(right_, natural_type);

Array *result = make_array(NULL, left_->shape, Type_Bool);

u64 n = result->shape;
switch (natural_type) {
case Type_Char:
for(i64 i = 0; i < n; i++) ((bool*)result->data)[i] = ((char*)left_->data)[i] < ((char*)right_->data)[i];
break;
case Type_Int:
for(i64 i = 0; i < n; i++) ((bool*)result->data)[i] = ((i64*)left_->data)[i] < ((i64*)right_->data)[i];
break;
case Type_Float:
for(i64 i = 0; i < n; i++) ((bool*)result->data)[i] = ((f64*)left_->data)[i] < ((f64*)right_->data)[i];
break;
case Type_Bool:
for(i64 i = 0; i < n; i++) ((bool*)result->data)[i] = ((bool*)left_->data)[i] < ((bool*)right_->data)[i];
break;
case Types_Count: assert(false);
}

release_array(left_);
release_array(right_);
return (Eval_Node){.type = Node_Array, .as.array = result};
}

Eval_Node func_less_equal (Eval_Node *left, Eval_Node *right) {
assert(left ->type == Node_Array);
assert(right->type == Node_Array);
Array *left_ = borrow_array(left ->as.array);
Array *right_ = borrow_array(right->as.array);
assert(left_->shape == right_->shape);

Element_Type natural_type = supertype(left_->type, right_->type);

left_ = array_cast(left_, natural_type);
right_ = array_cast(right_, natural_type);

Array *result = make_array(NULL, left_->shape, Type_Bool);

u64 n = result->shape;
switch (natural_type) {
case Type_Char:
for(i64 i = 0; i < n; i++) ((bool*)result->data)[i] = ((char*)left_->data)[i] <= ((char*)right_->data)[i];
break;
case Type_Int:
for(i64 i = 0; i < n; i++) ((bool*)result->data)[i] = ((i64*)left_->data)[i] <= ((i64*)right_->data)[i];
break;
case Type_Float:
for(i64 i = 0; i < n; i++) ((bool*)result->data)[i] = ((f64*)left_->data)[i] <= ((f64*)right_->data)[i];
break;
case Type_Bool:
for(i64 i = 0; i < n; i++) ((bool*)result->data)[i] = ((bool*)left_->data)[i] <= ((bool*)right_->data)[i];
break;
case Types_Count: assert(false);
}

release_array(left_);
release_array(right_);
return (Eval_Node){.type = Node_Array, .as.array = result};
}

Eval_Node func_negate (Eval_Node *left, Eval_Node *right) {
assert(right->type == Node_Array);
Array *right_ = borrow_array(right->as.array);
Expand Down
4 changes: 4 additions & 0 deletions src/funcs/funcs.c
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,10 @@ void init_default_scope() {
scope_insert(&default_scope, (Lookup_Entry){ .name = "/", .as_monadic = func_square_root, .as_dyadic = func_divide });
scope_insert(&default_scope, (Lookup_Entry){ .name = "~", .as_monadic = func_complement, .as_dyadic = func_mismatch });
scope_insert(&default_scope, (Lookup_Entry){ .name = "=", .as_monadic = NULL, .as_dyadic = func_equal });
scope_insert(&default_scope, (Lookup_Entry){ .name = ">=", .as_monadic = NULL, .as_dyadic = func_greater_equal });
scope_insert(&default_scope, (Lookup_Entry){ .name = ">", .as_monadic = NULL, .as_dyadic = func_greater });
scope_insert(&default_scope, (Lookup_Entry){ .name = "<=", .as_monadic = NULL, .as_dyadic = func_less_equal });
scope_insert(&default_scope, (Lookup_Entry){ .name = "<", .as_monadic = NULL, .as_dyadic = func_less });
scope_insert(&default_scope, (Lookup_Entry){ .name = "?", .as_monadic = NULL, .as_dyadic = func_filter });
scope_insert(&default_scope, (Lookup_Entry){ .name = "$", .as_monadic = func_shape, .as_dyadic = func_reshape });
}
Expand Down
28 changes: 16 additions & 12 deletions src/funcs/funcs.h
Original file line number Diff line number Diff line change
Expand Up @@ -6,18 +6,22 @@
void init_default_scope();
extern Lookup_Scope default_scope;

Eval_Node func_plus (Eval_Node *left, Eval_Node *right);
Eval_Node func_minus (Eval_Node *left, Eval_Node *right);
Eval_Node func_multiply (Eval_Node *left, Eval_Node *right);
Eval_Node func_divide (Eval_Node *left, Eval_Node *right);
Eval_Node func_mismatch (Eval_Node *left, Eval_Node *right);
Eval_Node func_equal (Eval_Node *left, Eval_Node *right);
Eval_Node func_negate (Eval_Node *left, Eval_Node *right);
Eval_Node func_complement (Eval_Node *left, Eval_Node *right);
Eval_Node func_square_root(Eval_Node *left, Eval_Node *right);
Eval_Node func_plus ( Eval_Node *left, Eval_Node *right ) ;
Eval_Node func_minus ( Eval_Node *left, Eval_Node *right ) ;
Eval_Node func_multiply ( Eval_Node *left, Eval_Node *right ) ;
Eval_Node func_divide ( Eval_Node *left, Eval_Node *right ) ;
Eval_Node func_mismatch ( Eval_Node *left, Eval_Node *right ) ;
Eval_Node func_equal ( Eval_Node *left, Eval_Node *right ) ;
Eval_Node func_greater_equal ( Eval_Node *left, Eval_Node *right ) ;
Eval_Node func_greater ( Eval_Node *left, Eval_Node *right ) ;
Eval_Node func_less_equal ( Eval_Node *left, Eval_Node *right ) ;
Eval_Node func_less ( Eval_Node *left, Eval_Node *right ) ;
Eval_Node func_negate ( Eval_Node *left, Eval_Node *right ) ;
Eval_Node func_complement ( Eval_Node *left, Eval_Node *right ) ;
Eval_Node func_square_root ( Eval_Node *left, Eval_Node *right ) ;

Eval_Node func_filter (Eval_Node *left, Eval_Node *right);
Eval_Node func_shape (Eval_Node *left, Eval_Node *right);
Eval_Node func_reshape (Eval_Node *left, Eval_Node *right);
Eval_Node func_filter ( Eval_Node *left, Eval_Node *right ) ;
Eval_Node func_shape ( Eval_Node *left, Eval_Node *right ) ;
Eval_Node func_reshape ( Eval_Node *left, Eval_Node *right ) ;

#endif // FUNCS_H
4 changes: 2 additions & 2 deletions src/main.c
Original file line number Diff line number Diff line change
Expand Up @@ -26,8 +26,8 @@ i32 main () {

Scanner scanner = {.source = src, .location.fname = "stdin"};
Ast ast = parse_expressions(&scanner);
set_format_user_ptr(ast.nodes);
format_println("desugars to: {expr}", ast.nodes[ast.parent]);
// set_format_user_ptr(ast.nodes);
// format_println("desugars to: {expr}", ast.nodes[ast.parent]);

Eval_Context ctx = {.scope = &default_scope};
Node_Handle expr = apply(&ctx, ast.nodes, ast.parent);
Expand Down

0 comments on commit 5a96d16

Please sign in to comment.