diff --git a/compiler/ast/src/types/array.rs b/compiler/ast/src/types/array.rs index 132f6cd928..6e0d644c3d 100644 --- a/compiler/ast/src/types/array.rs +++ b/compiler/ast/src/types/array.rs @@ -41,6 +41,14 @@ impl ArrayType { pub fn length(&self) -> usize { self.length.to_usize() } + + /// Returns the base element type of the array. + pub fn base_element_type(&self) -> &Type { + match self.element_type.as_ref() { + Type::Array(array_type) => array_type.base_element_type(), + type_ => type_, + } + } } impl fmt::Display for ArrayType { diff --git a/compiler/parser/src/parser/expression.rs b/compiler/parser/src/parser/expression.rs index b01cd78166..bd62fdaad9 100644 --- a/compiler/parser/src/parser/expression.rs +++ b/compiler/parser/src/parser/expression.rs @@ -512,7 +512,7 @@ impl ParserContext<'_> { }); } // Check if next token is a dot to see if we are calling recursive method. - if !self.check(&Token::Dot) { + if !(self.check(&Token::Dot) || self.check(&Token::LeftSquare)) { break; } } diff --git a/compiler/passes/src/type_checking/check_expressions.rs b/compiler/passes/src/type_checking/check_expressions.rs index 317068d8fb..e5e1f7e029 100644 --- a/compiler/passes/src/type_checking/check_expressions.rs +++ b/compiler/passes/src/type_checking/check_expressions.rs @@ -20,6 +20,8 @@ use leo_ast::*; use leo_errors::{emitter::Handler, TypeCheckerError}; use leo_span::{sym, Span}; +use itertools::Itertools; +use snarkvm_console::network::{Network, Testnet3}; use std::str::FromStr; fn return_incorrect_type(t1: Option, t2: Option, expected: &Option) -> Option { @@ -42,7 +44,29 @@ impl<'a> ExpressionVisitor<'a> for TypeChecker<'a> { fn visit_access(&mut self, input: &'a AccessExpression, expected: &Self::AdditionalInput) -> Self::Output { match input { - AccessExpression::Array(array) => todo!(), + AccessExpression::Array(access) => { + // Check that the expression is an array. + let array_type = self.visit_expression(&access.array, &None); + self.assert_array_type(&array_type, access.array.span()); + + // Check that the index is an integer type. + let index_type = self.visit_expression(&access.index, &None); + self.assert_int_type(&index_type, access.index.span()); + + // Get the element type of the array. + let element_type = match array_type { + Some(Type::Array(array_type)) => Some(array_type.element_type().clone()), + _ => None, + }; + + // If the expected type is known, then check that the element type is the same as the expected type. + if let Some(expected) = expected { + self.assert_type(&element_type, expected, input.span()); + } + + // Return the element type of the array. + return element_type; + } AccessExpression::AssociatedFunction(access) => { // Check core struct name and function. if let Some(core_instruction) = self.get_core_function_call(&access.ty, &access.name) { @@ -209,6 +233,51 @@ impl<'a> ExpressionVisitor<'a> for TypeChecker<'a> { None } + fn visit_array(&mut self, input: &'a ArrayExpression, additional: &Self::AdditionalInput) -> Self::Output { + // Get the types of each element expression. + let element_types = + input.elements.iter().map(|element| self.visit_expression(element, &None)).collect::>(); + + // Construct the array type. + let return_type = match element_types.len() { + // The array cannot be empty. + 0 => { + self.emit_err(TypeCheckerError::array_empty(input.span())); + None + } + // Check that the element types match. + 1..=Testnet3::MAX_ARRAY_ELEMENTS => { + let mut element_types = element_types.into_iter(); + // Note that this unwrap is safe because we already checked that the array is not empty. + element_types.next().unwrap().map(|first_type| { + // Check that all elements have the same type. + for (element_type, element) in element_types.zip_eq(input.elements.iter().skip(1)) { + self.assert_type(&element_type, &first_type, element.span()); + } + // Return the array type. + Type::Array(ArrayType::new(first_type, PositiveNumber { value: input.elements.len().to_string() })) + }) + } + // The array cannot have more than `MAX_ARRAY_ELEMENTS` elements. + num_elements => { + self.emit_err(TypeCheckerError::array_too_large( + num_elements, + Testnet3::MAX_ARRAY_ELEMENTS, + input.span(), + )); + None + } + }; + + // If the expected type is known, then check that the array type is the same as the expected type. + if let Some(expected) = additional { + self.assert_type(&return_type, expected, input.span()); + } + + // Return the array type. + return_type + } + fn visit_binary(&mut self, input: &'a BinaryExpression, destination: &Self::AdditionalInput) -> Self::Output { match input.op { BinaryOperation::And | BinaryOperation::Or | BinaryOperation::Nand | BinaryOperation::Nor => { diff --git a/compiler/passes/src/type_checking/check_program.rs b/compiler/passes/src/type_checking/check_program.rs index 2de2e765b3..d4cb6eefc7 100644 --- a/compiler/passes/src/type_checking/check_program.rs +++ b/compiler/passes/src/type_checking/check_program.rs @@ -109,7 +109,7 @@ impl<'a> ProgramVisitor<'a> for TypeChecker<'a> { // TODO: Better span to target duplicate member. if !input.members.iter().all(|Member { identifier, type_, span, .. }| { // Check that the member types are defined. - self.assert_type_is_defined(type_, *span); + self.assert_type_is_valid(type_, *span); used.insert(identifier.name) }) { self.emit_err(if input.is_record { @@ -146,11 +146,20 @@ impl<'a> ProgramVisitor<'a> for TypeChecker<'a> { } // Ensure that there are no record members. self.assert_member_is_not_record(identifier.span, input.identifier.name, type_); + // If the member is a struct, add it to the struct dependency graph. // Note that we have already checked that each member is defined and valid. if let Type::Identifier(member_type) = type_ { self.struct_graph.add_edge(input.identifier.name, member_type.name); + } else if let Type::Array(array_type) = type_ { + // Get the base element type. + let base_element_type = array_type.base_element_type(); + // If the base element type is a struct, then add it to the struct dependency graph. + if let Type::Identifier(member_type) = base_element_type { + self.struct_graph.add_edge(input.identifier.name, member_type.name); + } } + // If the input is a struct, then check that the member does not have a mode. if !input.is_record && !matches!(mode, Mode::None) { self.emit_err(TypeCheckerError::struct_cannot_have_member_mode(*span)); @@ -160,7 +169,7 @@ impl<'a> ProgramVisitor<'a> for TypeChecker<'a> { fn visit_mapping(&mut self, input: &'a Mapping) { // Check that a mapping's key type is valid. - self.assert_type_is_defined(&input.key_type, input.span); + self.assert_type_is_valid(&input.key_type, input.span); // Check that a mapping's key type is not a tuple, record, or mapping. match input.key_type { Type::Tuple(_) => self.emit_err(TypeCheckerError::invalid_mapping_type("key", "tuple", input.span)), @@ -177,7 +186,7 @@ impl<'a> ProgramVisitor<'a> for TypeChecker<'a> { } // Check that a mapping's value type is valid. - self.assert_type_is_defined(&input.value_type, input.span); + self.assert_type_is_valid(&input.value_type, input.span); // Check that a mapping's value type is not a tuple, record or mapping. match input.value_type { Type::Tuple(_) => self.emit_err(TypeCheckerError::invalid_mapping_type("value", "tuple", input.span)), @@ -226,7 +235,7 @@ impl<'a> ProgramVisitor<'a> for TypeChecker<'a> { // Type check the function's parameters. function.input.iter().for_each(|input_var| { // Check that the type of input parameter is defined. - self.assert_type_is_defined(&input_var.type_(), input_var.span()); + self.assert_type_is_valid(&input_var.type_(), input_var.span()); // Check that the type of the input parameter is not a tuple. if matches!(input_var.type_(), Type::Tuple(_)) { self.emit_err(TypeCheckerError::function_cannot_take_tuple_as_input(input_var.span())) @@ -272,7 +281,7 @@ impl<'a> ProgramVisitor<'a> for TypeChecker<'a> { } Output::Internal(function_output) => { // Check that the type of output is defined. - if self.assert_type_is_defined(&function_output.type_, function_output.span) { + if self.assert_type_is_valid(&function_output.type_, function_output.span) { // If the function is not a transition function, then it cannot output a record. if let Type::Identifier(identifier) = function_output.type_ { if !matches!(function.variant, Variant::Transition) @@ -337,7 +346,7 @@ impl<'a> ProgramVisitor<'a> for TypeChecker<'a> { finalize.input.iter().for_each(|input_var| { // Check that the type of input parameter is defined. - if self.assert_type_is_defined(&input_var.type_(), input_var.span()) { + if self.assert_type_is_valid(&input_var.type_(), input_var.span()) { // Check that the input parameter is not a tuple. if matches!(input_var.type_(), Type::Tuple(_)) { self.emit_err(TypeCheckerError::finalize_cannot_take_tuple_as_input(input_var.span())) @@ -378,7 +387,7 @@ impl<'a> ProgramVisitor<'a> for TypeChecker<'a> { // Note that checking that each of the component types are defined is sufficient to guarantee that the `output_type` is defined. finalize.output.iter().for_each(|output_type| { // Check that the type of output is defined. - if self.assert_type_is_defined(&output_type.type_(), output_type.span()) { + if self.assert_type_is_valid(&output_type.type_(), output_type.span()) { // Check that the output is not a tuple. This is necessary to forbid nested tuples. if matches!(&output_type.type_(), Type::Tuple(_)) { self.emit_err(TypeCheckerError::nested_tuple_type(output_type.span())) @@ -408,7 +417,7 @@ impl<'a> ProgramVisitor<'a> for TypeChecker<'a> { self.visit_block(&finalize.block); // Check that the return type is defined. Note that the component types are already checked. - self.assert_type_is_defined(&finalize.output_type, finalize.span); + self.assert_type_is_valid(&finalize.output_type, finalize.span); // If the function has a return type, then check that it has a return. if finalize.output_type != Type::Unit && !self.has_return { diff --git a/compiler/passes/src/type_checking/check_statements.rs b/compiler/passes/src/type_checking/check_statements.rs index a33364080c..bda7770dcc 100644 --- a/compiler/passes/src/type_checking/check_statements.rs +++ b/compiler/passes/src/type_checking/check_statements.rs @@ -203,7 +203,7 @@ impl<'a> StatementVisitor<'a> for TypeChecker<'a> { fn visit_definition(&mut self, input: &'a DefinitionStatement) { // Check that the type of the definition is defined. - self.assert_type_is_defined(&input.type_, input.span); + self.assert_type_is_valid(&input.type_, input.span); // Check that the type of the definition is not a unit type, singleton tuple type, or nested tuple type. match &input.type_ { diff --git a/compiler/passes/src/type_checking/checker.rs b/compiler/passes/src/type_checking/checker.rs index f0c5c1364c..69873b0331 100644 --- a/compiler/passes/src/type_checking/checker.rs +++ b/compiler/passes/src/type_checking/checker.rs @@ -16,11 +16,12 @@ use crate::{CallGraph, StructGraph, SymbolTable}; -use leo_ast::{CoreConstant, CoreFunction, Identifier, IntegerType, MappingType, Node, Type, Variant}; +use leo_ast::{ArrayType, CoreConstant, CoreFunction, Identifier, IntegerType, MappingType, Node, Type, Variant}; use leo_errors::{emitter::Handler, TypeCheckerError}; use leo_span::{Span, Symbol}; use itertools::Itertools; +use snarkvm_console::network::{Network, Testnet3}; use std::cell::RefCell; pub struct TypeChecker<'a> { @@ -1054,8 +1055,8 @@ impl<'a> TypeChecker<'a> { } } - /// Emits an error if the type or its constituent types are not defined. - pub(crate) fn assert_type_is_defined(&self, type_: &Type, span: Span) -> bool { + /// Emits an error if the type or its constituent types is not valid. + pub(crate) fn assert_type_is_valid(&self, type_: &Type, span: Span) -> bool { let mut is_defined = true; match type_ { // String types are temporarily disabled. @@ -1071,13 +1072,25 @@ impl<'a> TypeChecker<'a> { // Check that the constituent types of the tuple are valid. Type::Tuple(tuple_type) => { for type_ in tuple_type.iter() { - is_defined &= self.assert_type_is_defined(type_, span) + is_defined &= self.assert_type_is_valid(type_, span) } } // Check that the constituent types of mapping are valid. Type::Mapping(mapping_type) => { - is_defined &= self.assert_type_is_defined(&mapping_type.key, span); - is_defined &= self.assert_type_is_defined(&mapping_type.value, span); + is_defined &= self.assert_type_is_valid(&mapping_type.key, span); + is_defined &= self.assert_type_is_valid(&mapping_type.value, span); + } + // Check that the array element types are valid. + Type::Array(array_type) => { + // Check that the array length is valid. + match array_type.length() { + 0 => self.emit_err(TypeCheckerError::array_empty(span)), + 1..=Testnet3::MAX_ARRAY_ELEMENTS => {} + length => { + self.emit_err(TypeCheckerError::array_too_large(length, Testnet3::MAX_ARRAY_ELEMENTS, span)) + } + } + is_defined &= self.assert_type_is_valid(array_type.element_type(), span) } _ => {} // Do nothing. } @@ -1092,6 +1105,11 @@ impl<'a> TypeChecker<'a> { _ => None, } } + + /// Emits an error if the type is not an array. + pub(crate) fn assert_array_type(&self, type_: &Option, span: Span) { + self.check_type(|type_| matches!(type_, Type::Array(_)), "array".to_string(), type_, span); + } } fn types_to_string(types: &[Type]) -> String { diff --git a/errors/src/errors/type_checker/type_checker_error.rs b/errors/src/errors/type_checker/type_checker_error.rs index 2b020a5016..c013b597d7 100644 --- a/errors/src/errors/type_checker/type_checker_error.rs +++ b/errors/src/errors/type_checker/type_checker_error.rs @@ -684,4 +684,18 @@ create_messages!( msg: format!("A constant declaration statement can only bind a single value"), help: None, } + + @formatted + array_empty { + args: (), + msg: format!("An array cannot be empty"), + help: None, + } + + @formatted + array_too_large { + args: (size: impl Display, max: impl Display), + msg: format!("An array cannot have more than {max} elements, found one with {size} elements"), + help: None, + } ); diff --git a/tests/tests/compiler/array/access_array_with_loop_counter.leo b/tests/tests/compiler/array/access_array_with_loop_counter.leo index 534ab6d605..23e38e4bde 100644 --- a/tests/tests/compiler/array/access_array_with_loop_counter.leo +++ b/tests/tests/compiler/array/access_array_with_loop_counter.leo @@ -6,7 +6,7 @@ expectation: Pass program test.aleo { transition foo(a: [bool; 4]) { - for i: i32 in 0u32..4u32 { + for i: u32 in 0u32..4u32 { assert(a[i]); } } diff --git a/tests/tests/compiler/array/array_initialization_fail.leo b/tests/tests/compiler/array/array_initialization_fail.leo index 99e1ce2315..46d07e37a2 100644 --- a/tests/tests/compiler/array/array_initialization_fail.leo +++ b/tests/tests/compiler/array/array_initialization_fail.leo @@ -1,6 +1,6 @@ /* namespace: Compile -expectation: Pass +expectation: Fail */ program test.aleo { diff --git a/tests/tests/compiler/array/array_too_large_fail.leo b/tests/tests/compiler/array/array_too_large_fail.leo index 6c8696d830..46b600adac 100644 --- a/tests/tests/compiler/array/array_too_large_fail.leo +++ b/tests/tests/compiler/array/array_too_large_fail.leo @@ -1,6 +1,6 @@ /* namespace: Compile -expectation: Pass +expectation: Fail */ program test.aleo { diff --git a/tests/tests/compiler/array/array_with_units_fail.leo b/tests/tests/compiler/array/array_with_units_fail.leo index bbd887b8d7..77e0f38c28 100644 --- a/tests/tests/compiler/array/array_with_units_fail.leo +++ b/tests/tests/compiler/array/array_with_units_fail.leo @@ -1,6 +1,6 @@ /* namespace: Compile -expectation: Pass +expectation: Fail */ program test.aleo {