Skip to content

Commit

Permalink
Address review comments
Browse files Browse the repository at this point in the history
  • Loading branch information
lethalgem committed Jan 13, 2025
1 parent 451739e commit ed42476
Show file tree
Hide file tree
Showing 4 changed files with 153 additions and 138 deletions.
129 changes: 60 additions & 69 deletions pintc/src/expr.rs
Original file line number Diff line number Diff line change
Expand Up @@ -936,11 +936,14 @@ pub fn array_eq(
rhs_range_expr: ExprKey,
) -> bool {
lhs_elements.len() == rhs_elements.len()
&& lhs_elements.iter().enumerate().all(|(i, lhs_element)| {
lhs_element
.get(contract)
.eq(contract, rhs_elements[i].get(contract))
})
&& lhs_elements
.iter()
.zip(rhs_elements.iter())
.all(|(lhs_element, rhs_element)| {
lhs_element
.get(contract)
.eq(contract, rhs_element.get(contract))
})
&& lhs_range_expr
.get(contract)
.eq(contract, rhs_range_expr.get(contract))
Expand All @@ -952,16 +955,14 @@ pub fn tuple_eq(
rhs_fields: &[(Option<Ident>, ExprKey)],
) -> bool {
lhs_fields.len() == rhs_fields.len()
&& lhs_fields
.iter()
.enumerate()
.all(|(i, (lhs_ident, lhs_field))| {
let (rhs_ident, rhs_field) = &rhs_fields[i];
&& lhs_fields.iter().zip(rhs_fields.iter()).all(
|((lhs_ident, lhs_field), (rhs_ident, rhs_field))| {
lhs_field
.get(contract)
.eq(contract, rhs_field.get(contract))
&& lhs_ident == rhs_ident
})
},
)
}

pub fn union_variant_eq(
Expand Down Expand Up @@ -1032,30 +1033,25 @@ pub fn binary_op_eq(
rhs_rhs: ExprKey,
) -> bool {
match lhs_op {
BinaryOp::Add
| BinaryOp::Mul
| BinaryOp::Equal
| BinaryOp::NotEqual
| BinaryOp::LogicalAnd
| BinaryOp::LogicalOr => {
let is_op_eq = lhs_op == rhs_op;
let is_lhs_eq = lhs_lhs.get(contract).eq(contract, rhs_lhs.get(contract));
if is_lhs_eq {
is_op_eq && is_lhs_eq && lhs_rhs.get(contract).eq(contract, rhs_rhs.get(contract))
} else {
is_op_eq
&& lhs_lhs.get(contract).eq(contract, rhs_rhs.get(contract))
&& lhs_rhs.get(contract).eq(contract, rhs_lhs.get(contract))
}
// These ops are commutative
BinaryOp::Add | BinaryOp::Mul | BinaryOp::Equal | BinaryOp::NotEqual => {
lhs_op == rhs_op
&& (lhs_lhs.get(contract).eq(contract, rhs_lhs.get(contract))
&& lhs_rhs.get(contract).eq(contract, rhs_rhs.get(contract))
|| lhs_lhs.get(contract).eq(contract, rhs_rhs.get(contract))
&& lhs_rhs.get(contract).eq(contract, rhs_lhs.get(contract)))
}

// These ops are not commutative
BinaryOp::Sub
| BinaryOp::Div
| BinaryOp::Mod
| BinaryOp::LessThanOrEqual
| BinaryOp::LessThan
| BinaryOp::GreaterThanOrEqual
| BinaryOp::GreaterThan => {
| BinaryOp::GreaterThan
| BinaryOp::LogicalAnd
| BinaryOp::LogicalOr => {
lhs_op == rhs_op
&& lhs_lhs.get(contract).eq(contract, rhs_lhs.get(contract))
&& lhs_rhs.get(contract).eq(contract, rhs_rhs.get(contract))
Expand All @@ -1075,11 +1071,10 @@ pub fn intrinsic_call_eq(
rhs_args: &[ExprKey],
) -> bool {
lhs_args.len() == rhs_args.len()
&& lhs_args.iter().enumerate().all(|(i, lhs_arg)| {
lhs_arg
.get(contract)
.eq(contract, rhs_args[i].get(contract))
})
&& lhs_args
.iter()
.zip(rhs_args.iter())
.all(|(lhs_arg, rhs_arg)| lhs_arg.get(contract).eq(contract, rhs_arg.get(contract)))
&& match (lhs_kind, rhs_kind) {
((IntrinsicKind::External(lhs_kind), _), (IntrinsicKind::External(rhs_kind), _)) => {
lhs_kind == rhs_kind
Expand All @@ -1104,11 +1099,10 @@ pub fn local_predicate_call_eq(
) -> bool {
lhs_predicate == rhs_predicate
&& lhs_args.len() == rhs_args.len()
&& lhs_args.iter().enumerate().all(|(i, lhs_arg)| {
lhs_arg
.get(contract)
.eq(contract, rhs_args[i].get(contract))
})
&& lhs_args
.iter()
.zip(rhs_args.iter())
.all(|(lhs_arg, rhs_arg)| lhs_arg.get(contract).eq(contract, rhs_arg.get(contract)))
}

#[allow(clippy::too_many_arguments)]
Expand All @@ -1134,11 +1128,10 @@ pub fn external_predicate_call_eq(
.get(contract)
.eq(contract, rhs_p_addr.get(contract))
&& lhs_args.len() == rhs_args.len()
&& lhs_args.iter().enumerate().all(|(i, lhs_arg)| {
lhs_arg
.get(contract)
.eq(contract, rhs_args[i].get(contract))
})
&& lhs_args
.iter()
.zip(rhs_args.iter())
.all(|(lhs_arg, rhs_arg)| lhs_arg.get(contract).eq(contract, rhs_arg.get(contract)))
}

pub fn select_eq(
Expand Down Expand Up @@ -1176,8 +1169,8 @@ pub fn match_eq(
&& lhs_match_branches.len() != rhs_match_branches.len()
&& lhs_match_branches
.iter()
.enumerate()
.all(|(i, lhs_match_branch)| {
.zip(rhs_match_branches.iter())
.all(|(lhs_match_branch, rhs_match_branch)| {
let (
MatchBranch {
name: lhs_name,
Expand All @@ -1193,7 +1186,7 @@ pub fn match_eq(
expr: rhs_expr,
..
},
) = (lhs_match_branch, rhs_match_branches[i].clone());
) = (lhs_match_branch, rhs_match_branch.clone());

*lhs_name == rhs_name
&& match (lhs_binding, rhs_binding) {
Expand All @@ -1215,14 +1208,13 @@ pub fn match_eq(
_ => false,
}
&& lhs_constraints.len() == rhs_constraints.len()
&& lhs_constraints
.iter()
.enumerate()
.all(|(j, lhs_constraint)| {
&& lhs_constraints.iter().zip(rhs_constraints.iter()).all(
|(lhs_constraint, rhs_constraint)| {
lhs_constraint
.get(contract)
.eq(contract, rhs_constraints[j].get(contract))
})
.eq(contract, rhs_constraint.get(contract))
},
)
&& lhs_expr.get(contract).eq(contract, rhs_expr.get(contract))
})
&& if let (
Expand All @@ -1237,14 +1229,13 @@ pub fn match_eq(
) = (lhs_else_branch, rhs_else_branch)
{
lhs_constraints.len() == rhs_constraints.len()
&& lhs_constraints
.iter()
.enumerate()
.all(|(i, lhs_constraint)| {
&& lhs_constraints.iter().zip(rhs_constraints.iter()).all(
|(lhs_constraint, rhs_constraint)| {
lhs_constraint
.get(contract)
.eq(contract, rhs_constraints[i].get(contract))
})
.eq(contract, rhs_constraint.get(contract))
},
)
&& lhs_expr.eq(rhs_expr)
} else {
lhs_else_branch.is_none() && rhs_else_branch.is_none()
Expand Down Expand Up @@ -1340,23 +1331,23 @@ pub fn generator_eq(
) -> bool {
lhs_kind == rhs_kind
&& rhs_gen_ranges.len() == lhs_gen_ranges.len()
&& lhs_gen_ranges
.iter()
.enumerate()
.all(|(i, (lhs_ident, lhs_gen_range))| {
let (rhs_ident, rhs_gen_range) = &rhs_gen_ranges[i];

&& lhs_gen_ranges.iter().zip(rhs_gen_ranges.iter()).all(
|((lhs_ident, lhs_gen_range), (rhs_ident, rhs_gen_range))| {
lhs_ident == rhs_ident
&& lhs_gen_range
.get(contract)
.eq(contract, rhs_gen_range.get(contract))
})
},
)
&& lhs_conditions.len() != rhs_conditions.len()
&& lhs_conditions.iter().enumerate().all(|(i, lhs_condition)| {
lhs_condition
.get(contract)
.eq(contract, rhs_conditions[i].get(contract))
})
&& lhs_conditions
.iter()
.zip(rhs_conditions.iter())
.all(|(lhs_condition, rhs_condition)| {
lhs_condition
.get(contract)
.eq(contract, rhs_condition.get(contract))
})
&& lhs_body == rhs_body
}

Expand Down
Loading

0 comments on commit ed42476

Please sign in to comment.