Skip to content

Commit

Permalink
Add type checking for arrays
Browse files Browse the repository at this point in the history
  • Loading branch information
d0cd committed Oct 27, 2023
1 parent 22766a4 commit 7e471b7
Show file tree
Hide file tree
Showing 11 changed files with 139 additions and 21 deletions.
8 changes: 8 additions & 0 deletions compiler/ast/src/types/array.rs
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,14 @@ impl ArrayType {
pub fn length(&self) -> usize {
self.length.to_usize()
}

/// Returns the base element type of the array.
pub fn base_element_type(&self) -> &Type {
match self.element_type.as_ref() {
Type::Array(array_type) => array_type.base_element_type(),
type_ => type_,
}
}
}

impl fmt::Display for ArrayType {
Expand Down
2 changes: 1 addition & 1 deletion compiler/parser/src/parser/expression.rs
Original file line number Diff line number Diff line change
Expand Up @@ -512,7 +512,7 @@ impl ParserContext<'_> {
});
}
// Check if next token is a dot to see if we are calling recursive method.
if !self.check(&Token::Dot) {
if !(self.check(&Token::Dot) || self.check(&Token::LeftSquare)) {
break;
}
}
Expand Down
71 changes: 70 additions & 1 deletion compiler/passes/src/type_checking/check_expressions.rs
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,8 @@ use leo_ast::*;
use leo_errors::{emitter::Handler, TypeCheckerError};
use leo_span::{sym, Span};

use itertools::Itertools;
use snarkvm_console::network::{Network, Testnet3};
use std::str::FromStr;

fn return_incorrect_type(t1: Option<Type>, t2: Option<Type>, expected: &Option<Type>) -> Option<Type> {
Expand All @@ -42,7 +44,29 @@ impl<'a> ExpressionVisitor<'a> for TypeChecker<'a> {

fn visit_access(&mut self, input: &'a AccessExpression, expected: &Self::AdditionalInput) -> Self::Output {
match input {
AccessExpression::Array(array) => todo!(),
AccessExpression::Array(access) => {
// Check that the expression is an array.
let array_type = self.visit_expression(&access.array, &None);
self.assert_array_type(&array_type, access.array.span());

// Check that the index is an integer type.
let index_type = self.visit_expression(&access.index, &None);
self.assert_int_type(&index_type, access.index.span());

// Get the element type of the array.
let element_type = match array_type {
Some(Type::Array(array_type)) => Some(array_type.element_type().clone()),
_ => None,
};

// If the expected type is known, then check that the element type is the same as the expected type.
if let Some(expected) = expected {
self.assert_type(&element_type, expected, input.span());
}

// Return the element type of the array.
return element_type;
}
AccessExpression::AssociatedFunction(access) => {
// Check core struct name and function.
if let Some(core_instruction) = self.get_core_function_call(&access.ty, &access.name) {
Expand Down Expand Up @@ -209,6 +233,51 @@ impl<'a> ExpressionVisitor<'a> for TypeChecker<'a> {
None
}

fn visit_array(&mut self, input: &'a ArrayExpression, additional: &Self::AdditionalInput) -> Self::Output {
// Get the types of each element expression.
let element_types =
input.elements.iter().map(|element| self.visit_expression(element, &None)).collect::<Vec<_>>();

// Construct the array type.
let return_type = match element_types.len() {
// The array cannot be empty.
0 => {
self.emit_err(TypeCheckerError::array_empty(input.span()));
None
}
// Check that the element types match.
1..=Testnet3::MAX_ARRAY_ELEMENTS => {
let mut element_types = element_types.into_iter();
// Note that this unwrap is safe because we already checked that the array is not empty.
element_types.next().unwrap().map(|first_type| {
// Check that all elements have the same type.
for (element_type, element) in element_types.zip_eq(input.elements.iter().skip(1)) {
self.assert_type(&element_type, &first_type, element.span());
}
// Return the array type.
Type::Array(ArrayType::new(first_type, PositiveNumber { value: input.elements.len().to_string() }))
})
}
// The array cannot have more than `MAX_ARRAY_ELEMENTS` elements.
num_elements => {
self.emit_err(TypeCheckerError::array_too_large(
num_elements,
Testnet3::MAX_ARRAY_ELEMENTS,
input.span(),
));
None
}
};

// If the expected type is known, then check that the array type is the same as the expected type.
if let Some(expected) = additional {
self.assert_type(&return_type, expected, input.span());
}

// Return the array type.
return_type
}

fn visit_binary(&mut self, input: &'a BinaryExpression, destination: &Self::AdditionalInput) -> Self::Output {
match input.op {
BinaryOperation::And | BinaryOperation::Or | BinaryOperation::Nand | BinaryOperation::Nor => {
Expand Down
25 changes: 17 additions & 8 deletions compiler/passes/src/type_checking/check_program.rs
Original file line number Diff line number Diff line change
Expand Up @@ -109,7 +109,7 @@ impl<'a> ProgramVisitor<'a> for TypeChecker<'a> {
// TODO: Better span to target duplicate member.
if !input.members.iter().all(|Member { identifier, type_, span, .. }| {
// Check that the member types are defined.
self.assert_type_is_defined(type_, *span);
self.assert_type_is_valid(type_, *span);
used.insert(identifier.name)
}) {
self.emit_err(if input.is_record {
Expand Down Expand Up @@ -146,11 +146,20 @@ impl<'a> ProgramVisitor<'a> for TypeChecker<'a> {
}
// Ensure that there are no record members.
self.assert_member_is_not_record(identifier.span, input.identifier.name, type_);

// If the member is a struct, add it to the struct dependency graph.
// Note that we have already checked that each member is defined and valid.
if let Type::Identifier(member_type) = type_ {
self.struct_graph.add_edge(input.identifier.name, member_type.name);
} else if let Type::Array(array_type) = type_ {
// Get the base element type.
let base_element_type = array_type.base_element_type();
// If the base element type is a struct, then add it to the struct dependency graph.
if let Type::Identifier(member_type) = base_element_type {
self.struct_graph.add_edge(input.identifier.name, member_type.name);
}
}

// If the input is a struct, then check that the member does not have a mode.
if !input.is_record && !matches!(mode, Mode::None) {
self.emit_err(TypeCheckerError::struct_cannot_have_member_mode(*span));
Expand All @@ -160,7 +169,7 @@ impl<'a> ProgramVisitor<'a> for TypeChecker<'a> {

fn visit_mapping(&mut self, input: &'a Mapping) {
// Check that a mapping's key type is valid.
self.assert_type_is_defined(&input.key_type, input.span);
self.assert_type_is_valid(&input.key_type, input.span);
// Check that a mapping's key type is not a tuple, record, or mapping.
match input.key_type {
Type::Tuple(_) => self.emit_err(TypeCheckerError::invalid_mapping_type("key", "tuple", input.span)),
Expand All @@ -177,7 +186,7 @@ impl<'a> ProgramVisitor<'a> for TypeChecker<'a> {
}

// Check that a mapping's value type is valid.
self.assert_type_is_defined(&input.value_type, input.span);
self.assert_type_is_valid(&input.value_type, input.span);
// Check that a mapping's value type is not a tuple, record or mapping.
match input.value_type {
Type::Tuple(_) => self.emit_err(TypeCheckerError::invalid_mapping_type("value", "tuple", input.span)),
Expand Down Expand Up @@ -226,7 +235,7 @@ impl<'a> ProgramVisitor<'a> for TypeChecker<'a> {
// Type check the function's parameters.
function.input.iter().for_each(|input_var| {
// Check that the type of input parameter is defined.
self.assert_type_is_defined(&input_var.type_(), input_var.span());
self.assert_type_is_valid(&input_var.type_(), input_var.span());
// Check that the type of the input parameter is not a tuple.
if matches!(input_var.type_(), Type::Tuple(_)) {
self.emit_err(TypeCheckerError::function_cannot_take_tuple_as_input(input_var.span()))
Expand Down Expand Up @@ -272,7 +281,7 @@ impl<'a> ProgramVisitor<'a> for TypeChecker<'a> {
}
Output::Internal(function_output) => {
// Check that the type of output is defined.
if self.assert_type_is_defined(&function_output.type_, function_output.span) {
if self.assert_type_is_valid(&function_output.type_, function_output.span) {
// If the function is not a transition function, then it cannot output a record.
if let Type::Identifier(identifier) = function_output.type_ {
if !matches!(function.variant, Variant::Transition)
Expand Down Expand Up @@ -337,7 +346,7 @@ impl<'a> ProgramVisitor<'a> for TypeChecker<'a> {

finalize.input.iter().for_each(|input_var| {
// Check that the type of input parameter is defined.
if self.assert_type_is_defined(&input_var.type_(), input_var.span()) {
if self.assert_type_is_valid(&input_var.type_(), input_var.span()) {
// Check that the input parameter is not a tuple.
if matches!(input_var.type_(), Type::Tuple(_)) {
self.emit_err(TypeCheckerError::finalize_cannot_take_tuple_as_input(input_var.span()))
Expand Down Expand Up @@ -378,7 +387,7 @@ impl<'a> ProgramVisitor<'a> for TypeChecker<'a> {
// Note that checking that each of the component types are defined is sufficient to guarantee that the `output_type` is defined.
finalize.output.iter().for_each(|output_type| {
// Check that the type of output is defined.
if self.assert_type_is_defined(&output_type.type_(), output_type.span()) {
if self.assert_type_is_valid(&output_type.type_(), output_type.span()) {
// Check that the output is not a tuple. This is necessary to forbid nested tuples.
if matches!(&output_type.type_(), Type::Tuple(_)) {
self.emit_err(TypeCheckerError::nested_tuple_type(output_type.span()))
Expand Down Expand Up @@ -408,7 +417,7 @@ impl<'a> ProgramVisitor<'a> for TypeChecker<'a> {
self.visit_block(&finalize.block);

// Check that the return type is defined. Note that the component types are already checked.
self.assert_type_is_defined(&finalize.output_type, finalize.span);
self.assert_type_is_valid(&finalize.output_type, finalize.span);

// If the function has a return type, then check that it has a return.
if finalize.output_type != Type::Unit && !self.has_return {
Expand Down
2 changes: 1 addition & 1 deletion compiler/passes/src/type_checking/check_statements.rs
Original file line number Diff line number Diff line change
Expand Up @@ -203,7 +203,7 @@ impl<'a> StatementVisitor<'a> for TypeChecker<'a> {

fn visit_definition(&mut self, input: &'a DefinitionStatement) {
// Check that the type of the definition is defined.
self.assert_type_is_defined(&input.type_, input.span);
self.assert_type_is_valid(&input.type_, input.span);

// Check that the type of the definition is not a unit type, singleton tuple type, or nested tuple type.
match &input.type_ {
Expand Down
30 changes: 24 additions & 6 deletions compiler/passes/src/type_checking/checker.rs
Original file line number Diff line number Diff line change
Expand Up @@ -16,11 +16,12 @@

use crate::{CallGraph, StructGraph, SymbolTable};

use leo_ast::{CoreConstant, CoreFunction, Identifier, IntegerType, MappingType, Node, Type, Variant};
use leo_ast::{ArrayType, CoreConstant, CoreFunction, Identifier, IntegerType, MappingType, Node, Type, Variant};
use leo_errors::{emitter::Handler, TypeCheckerError};
use leo_span::{Span, Symbol};

use itertools::Itertools;
use snarkvm_console::network::{Network, Testnet3};
use std::cell::RefCell;

pub struct TypeChecker<'a> {
Expand Down Expand Up @@ -1054,8 +1055,8 @@ impl<'a> TypeChecker<'a> {
}
}

/// Emits an error if the type or its constituent types are not defined.
pub(crate) fn assert_type_is_defined(&self, type_: &Type, span: Span) -> bool {
/// Emits an error if the type or its constituent types is not valid.
pub(crate) fn assert_type_is_valid(&self, type_: &Type, span: Span) -> bool {
let mut is_defined = true;
match type_ {
// String types are temporarily disabled.
Expand All @@ -1071,13 +1072,25 @@ impl<'a> TypeChecker<'a> {
// Check that the constituent types of the tuple are valid.
Type::Tuple(tuple_type) => {
for type_ in tuple_type.iter() {
is_defined &= self.assert_type_is_defined(type_, span)
is_defined &= self.assert_type_is_valid(type_, span)
}
}
// Check that the constituent types of mapping are valid.
Type::Mapping(mapping_type) => {
is_defined &= self.assert_type_is_defined(&mapping_type.key, span);
is_defined &= self.assert_type_is_defined(&mapping_type.value, span);
is_defined &= self.assert_type_is_valid(&mapping_type.key, span);
is_defined &= self.assert_type_is_valid(&mapping_type.value, span);
}
// Check that the array element types are valid.
Type::Array(array_type) => {
// Check that the array length is valid.
match array_type.length() {
0 => self.emit_err(TypeCheckerError::array_empty(span)),
1..=Testnet3::MAX_ARRAY_ELEMENTS => {}
length => {
self.emit_err(TypeCheckerError::array_too_large(length, Testnet3::MAX_ARRAY_ELEMENTS, span))
}
}
is_defined &= self.assert_type_is_valid(array_type.element_type(), span)
}
_ => {} // Do nothing.
}
Expand All @@ -1092,6 +1105,11 @@ impl<'a> TypeChecker<'a> {
_ => None,
}
}

/// Emits an error if the type is not an array.
pub(crate) fn assert_array_type(&self, type_: &Option<Type>, span: Span) {
self.check_type(|type_| matches!(type_, Type::Array(_)), "array".to_string(), type_, span);
}
}

fn types_to_string(types: &[Type]) -> String {
Expand Down
14 changes: 14 additions & 0 deletions errors/src/errors/type_checker/type_checker_error.rs
Original file line number Diff line number Diff line change
Expand Up @@ -684,4 +684,18 @@ create_messages!(
msg: format!("A constant declaration statement can only bind a single value"),
help: None,
}

@formatted
array_empty {
args: (),
msg: format!("An array cannot be empty"),
help: None,
}

@formatted
array_too_large {
args: (size: impl Display, max: impl Display),
msg: format!("An array cannot have more than {max} elements, found one with {size} elements"),
help: None,
}
);
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ expectation: Pass

program test.aleo {
transition foo(a: [bool; 4]) {
for i: i32 in 0u32..4u32 {
for i: u32 in 0u32..4u32 {
assert(a[i]);
}
}
Expand Down
2 changes: 1 addition & 1 deletion tests/tests/compiler/array/array_initialization_fail.leo
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
/*
namespace: Compile
expectation: Pass
expectation: Fail
*/

program test.aleo {
Expand Down
2 changes: 1 addition & 1 deletion tests/tests/compiler/array/array_too_large_fail.leo
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
/*
namespace: Compile
expectation: Pass
expectation: Fail
*/

program test.aleo {
Expand Down
2 changes: 1 addition & 1 deletion tests/tests/compiler/array/array_with_units_fail.leo
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
/*
namespace: Compile
expectation: Pass
expectation: Fail
*/

program test.aleo {
Expand Down

0 comments on commit 7e471b7

Please sign in to comment.