Skip to content

Commit

Permalink
Structs as function params working!
Browse files Browse the repository at this point in the history
  • Loading branch information
JonathanBrouwer committed Nov 5, 2023
1 parent 81da3f2 commit c3b2f85
Show file tree
Hide file tree
Showing 8 changed files with 197 additions and 95 deletions.
29 changes: 12 additions & 17 deletions compiler/src/passes/atomize/atomize.rs
Original file line number Diff line number Diff line change
Expand Up @@ -72,34 +72,29 @@ fn atomize_expr(expr: RExpr) -> AExpr {
},
RExpr::Apply { fun, args, typ } => {
let (args, extras): (Vec<_>, Vec<_>) = args.into_iter().map(atomize_atom).unzip();
let fn_typ = fun.typ().clone();

let (fun, fun_expr) = atomize_atom(*fun);

fun_expr
.into_iter()
.chain(extras.into_iter().flatten())
.rfold(AExpr::Apply { fun, args, typ }, |bdy, (sym, bnd)| {
AExpr::Let {
.rfold(
AExpr::Apply {
fun,
args,
typ,
fn_typ,
},
|bdy, (sym, bnd)| AExpr::Let {
typ: bnd.typ().clone(),
sym,
bnd: Box::new(bnd),
bdy: Box::new(bdy),
}
})
}
RExpr::FunRef { sym, typ } => {
AExpr::FunRef { sym, typ }

// let tmp = gen_sym("tmp");
// AExpr::Let {
// typ,
// sym: tmp,
// bnd: Box::new(AExpr::FunRef { sym }),
// bdy: Box::new(AExpr::Atom {
// atm: Atom::Var { sym: tmp },
// }),
// }
},
)
}
RExpr::FunRef { sym, typ } => AExpr::FunRef { sym, typ },
RExpr::Loop { bdy, typ } => AExpr::Loop {
bdy: Box::new(atomize_expr(*bdy)),
typ,
Expand Down
3 changes: 2 additions & 1 deletion compiler/src/passes/atomize/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@ pub enum AExpr<'p> {
Apply {
fun: Atom<'p>,
args: Vec<Atom<'p>>,
fn_typ: Type<UniqueSym<'p>>,
typ: Type<UniqueSym<'p>>,
},
FunRef {
Expand Down Expand Up @@ -175,7 +176,7 @@ impl<'p> From<AExpr<'p>> for TExpr<'p, UniqueSym<'p>> {
els: Box::new((*els).into()),
typ,
},
AExpr::Apply { fun, args, typ } => TExpr::Apply {
AExpr::Apply { fun, args, typ, .. } => TExpr::Apply {
fun: Box::new(fun.into()),
args: args.into_iter().map(Into::into).collect(),
typ,
Expand Down
146 changes: 73 additions & 73 deletions compiler/src/passes/eliminate_algebraic/eliminate.rs
Original file line number Diff line number Diff line change
@@ -1,13 +1,14 @@
use crate::passes::atomize::Atom;
use crate::passes::eliminate_algebraic::eliminate_params::{eliminate_params, flatten_params};
use crate::passes::eliminate_algebraic::{EExpr, PrgEliminated};
use crate::passes::explicate::{CExpr, PrgExplicated, Tail};
use crate::passes::parse::types::Type;
use crate::passes::parse::{Param, TypeDef};
use crate::passes::parse::TypeDef;
use crate::utils::gen_sym::{gen_sym, UniqueSym};
use std::collections::HashMap;

// (Old variable name, field name) -> New variable name
type Ctx<'p> = HashMap<(UniqueSym<'p>, &'p str), UniqueSym<'p>>;
pub type Ctx<'p> = HashMap<(UniqueSym<'p>, &'p str), UniqueSym<'p>>;

impl<'p> PrgExplicated<'p> {
pub fn eliminate(self) -> PrgEliminated<'p> {
Expand All @@ -28,56 +29,6 @@ impl<'p> PrgExplicated<'p> {
}
}

fn eliminate_params<'p>(
fn_params: HashMap<UniqueSym<'p>, Vec<Param<UniqueSym<'p>>>>,
ctx: &mut Ctx<'p>,
defs: &HashMap<UniqueSym<'p>, TypeDef<'p, UniqueSym<'p>>>,
) -> HashMap<UniqueSym<'p>, Vec<Param<UniqueSym<'p>>>> {
fn_params
.into_iter()
.map(|(sym, params)| {
(
sym,
params
.into_iter()
.flat_map(|param| {
flatten_params(param.sym, &param.typ, param.mutable, ctx, defs)
})
.collect(),
)
})
.collect()
}

fn flatten_params<'p>(
param_sym: UniqueSym<'p>,
param_type: &Type<UniqueSym<'p>>,
mutable: bool,
ctx: &mut Ctx<'p>,
defs: &HashMap<UniqueSym<'p>, TypeDef<'p, UniqueSym<'p>>>,
) -> Vec<Param<UniqueSym<'p>>> {
match param_type {
Type::Int | Type::Bool | Type::Unit | Type::Never | Type::Fn { .. } => vec![Param {
sym: param_sym,
typ: param_type.clone(),
mutable,
}],
Type::Var { sym } => match &defs[&sym] {
TypeDef::Struct { fields } => fields
.iter()
.flat_map(|(field_name, field_type)| {
let new_sym = *ctx
.entry((param_sym, field_name))
.or_insert_with(|| gen_sym(param_sym.sym));

flatten_params(new_sym, field_type, mutable, ctx, defs).into_iter()
})
.collect(),
TypeDef::Enum { .. } => todo!(),
},
}
}

fn eliminate_tail<'p>(
tail: Tail<'p, CExpr<'p>>,
ctx: &mut Ctx<'p>,
Expand Down Expand Up @@ -106,27 +57,66 @@ fn eliminate_seq<'p>(
tail: Tail<'p, EExpr<'p>>,
defs: &HashMap<UniqueSym<'p>, TypeDef<'p, UniqueSym<'p>>>,
) -> Tail<'p, EExpr<'p>> {
if let CExpr::AccessField {
strct,
field,
typ: field_typ,
} = bnd
{
let strct = strct.var();
let new_sym = *ctx
.entry((strct, field))
.or_insert_with(|| gen_sym(sym.sym));
let bnd = match bnd {
CExpr::AccessField {
strct,
field,
typ: field_typ,
} => {
let strct = strct.var();
let new_sym = *ctx
.entry((strct, field))
.or_insert_with(|| gen_sym(sym.sym));

return eliminate_seq(
sym,
ctx,
CExpr::Atom {
atm: Atom::Var { sym: new_sym },
typ: field_typ,
},
tail,
defs,
);
return eliminate_seq(
sym,
ctx,
CExpr::Atom {
atm: Atom::Var { sym: new_sym },
typ: field_typ,
},
tail,
defs,
);
}
CExpr::Apply {
fun,
args,
fn_typ,
typ,
} => {
#[rustfmt::skip]
let Type::Fn { params, typ: rtrn_typ} = fn_typ else {
unreachable!("fn_type should be a function type")
};

let (args, params): (Vec<_>, Vec<_>) = args
.into_iter()
.zip(params.into_iter())
.flat_map(|(atom, typ)| {
match atom {
Atom::Val { val } => vec![(Atom::Val { val }, typ)],
Atom::Var { sym } => {
flatten_params(sym, &typ, ctx, defs)
.into_iter()
.map(|(sym, typ)| (Atom::Var { sym }, typ))
.collect()
}
}
})
.unzip();

CExpr::Apply {
fun,
args,
fn_typ: Type::Fn {
params,
typ: rtrn_typ.clone(),
},
typ,
}
}
_ => bnd,
};

match bnd.typ() {
Expand Down Expand Up @@ -179,7 +169,17 @@ fn map_expr(e: CExpr) -> EExpr {
match e {
CExpr::Atom { atm, typ } => EExpr::Atom { atm, typ },
CExpr::Prim { op, args, typ } => EExpr::Prim { op, args, typ },
CExpr::Apply { fun, args, typ } => EExpr::Apply { fun, args, typ },
CExpr::Apply {
fun,
args,
typ,
fn_typ,
} => EExpr::Apply {
fun,
args,
typ,
fn_typ,
},
CExpr::FunRef { sym, typ } => EExpr::FunRef { sym, typ },
CExpr::Struct { .. } | CExpr::AccessField { .. } => unreachable!(),
}
Expand Down
58 changes: 58 additions & 0 deletions compiler/src/passes/eliminate_algebraic/eliminate_params.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,58 @@
use crate::passes::eliminate_algebraic::eliminate::Ctx;
use crate::passes::parse::types::Type;
use crate::passes::parse::{Param, TypeDef};
use crate::utils::gen_sym::{gen_sym, UniqueSym};
use std::collections::HashMap;

pub fn eliminate_params<'p>(
fn_params: HashMap<UniqueSym<'p>, Vec<Param<UniqueSym<'p>>>>,
ctx: &mut Ctx<'p>,
defs: &HashMap<UniqueSym<'p>, TypeDef<'p, UniqueSym<'p>>>,
) -> HashMap<UniqueSym<'p>, Vec<Param<UniqueSym<'p>>>> {
fn_params
.into_iter()
.map(|(sym, params)| {
(
sym,
params
.into_iter()
.flat_map(|param| {
flatten_params(param.sym, &param.typ, ctx, defs)
.into_iter()
.map(move |(sym, typ)| Param {
sym,
typ,
mutable: param.mutable,
})
})
.collect(),
)
})
.collect()
}

pub fn flatten_params<'p>(
param_sym: UniqueSym<'p>,
param_type: &Type<UniqueSym<'p>>,
ctx: &mut Ctx<'p>,
defs: &HashMap<UniqueSym<'p>, TypeDef<'p, UniqueSym<'p>>>,
) -> Vec<(UniqueSym<'p>, Type<UniqueSym<'p>>)> {
match param_type {
Type::Int | Type::Bool | Type::Unit | Type::Never | Type::Fn { .. } => {
vec![(param_sym, param_type.clone())]
}
Type::Var { sym } => match &defs[&sym] {
TypeDef::Struct { fields } => fields
.iter()
.flat_map(|(field_name, field_type)| {
let new_sym = *ctx
.entry((param_sym, field_name))
.or_insert_with(|| gen_sym(param_sym.sym));

flatten_params(new_sym, field_type, ctx, defs).into_iter()
})
.collect(),
TypeDef::Enum { .. } => todo!(),
},
}
}
14 changes: 13 additions & 1 deletion compiler/src/passes/eliminate_algebraic/mod.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
pub mod eliminate;
mod eliminate_params;

use crate::passes::atomize::Atom;
use crate::passes::explicate::{CExpr, PrgExplicated, Tail};
Expand Down Expand Up @@ -31,6 +32,7 @@ pub enum EExpr<'p> {
fun: Atom<'p>,
args: Vec<Atom<'p>>,
typ: Type<UniqueSym<'p>>,
fn_typ: Type<UniqueSym<'p>>,
},
FunRef {
sym: UniqueSym<'p>,
Expand Down Expand Up @@ -73,7 +75,17 @@ impl<'p> From<EExpr<'p>> for CExpr<'p> {
match value {
EExpr::Atom { atm, typ } => CExpr::Atom { atm, typ },
EExpr::Prim { op, args, typ } => CExpr::Prim { op, args, typ },
EExpr::Apply { fun, args, typ } => CExpr::Apply { fun, args, typ },
EExpr::Apply {
fun,
args,
typ,
fn_typ,
} => CExpr::Apply {
fun,
args,
typ,
fn_typ,
},
EExpr::FunRef { sym, typ } => CExpr::FunRef { sym, typ },
}
}
Expand Down
19 changes: 16 additions & 3 deletions compiler/src/passes/explicate/explicate.rs
Original file line number Diff line number Diff line change
Expand Up @@ -88,9 +88,19 @@ fn explicate_assign<'p>(
};

match bnd {
AExpr::Apply { fun, args, typ } => Tail::Seq {
AExpr::Apply {
fun,
args,
typ,
fn_typ,
} => Tail::Seq {
sym,
bnd: CExpr::Apply { fun, args, typ },
bnd: CExpr::Apply {
fun,
args,
typ,
fn_typ,
},
tail: Box::new(tail),
},
AExpr::FunRef { sym: sym_fn, typ } => Tail::Seq {
Expand Down Expand Up @@ -320,14 +330,17 @@ fn explicate_pred<'p>(
env,
)
}
AExpr::Apply { fun, args, .. } => {
AExpr::Apply {
fun, args, fn_typ, ..
} => {
let tmp = gen_sym("tmp");
explicate_assign(
tmp,
AExpr::Apply {
fun,
args,
typ: Type::Bool,
fn_typ,
},
explicate_pred(
AExpr::Atom {
Expand Down
1 change: 1 addition & 0 deletions compiler/src/passes/explicate/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,7 @@ pub enum CExpr<'p> {
fun: Atom<'p>,
args: Vec<Atom<'p>>,
typ: Type<UniqueSym<'p>>,
fn_typ: Type<UniqueSym<'p>>,
},
FunRef {
sym: UniqueSym<'p>,
Expand Down
Loading

0 comments on commit c3b2f85

Please sign in to comment.