Skip to content

Commit

Permalink
dev: feat: add compile-time if and recursive function (#11)
Browse files Browse the repository at this point in the history
  • Loading branch information
Myriad-Dreamin authored Oct 8, 2024
1 parent 8ef3ef4 commit 4ca32ad
Show file tree
Hide file tree
Showing 5 changed files with 132 additions and 89 deletions.
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
block {
none
def fib@49(n@50: u32): DeclExpr(u32@52) = if ("<="(n@50, 1)) block {
def fib@49(n@50: u32): DeclExpr(u32@52) = if (BinInst(Int(u32,Le),n@50,Int64(1))) block {
n@50
} else block {
"+"(fib@49("-"(n@50, 1)), fib@49("-"(n@50, 2)))
"+"(fib@49(BinInst(Int(u32,Sub),n@50,Int64(1))), fib@49(BinInst(Int(u32,Sub),n@50,Int64(2))))
}
}
186 changes: 111 additions & 75 deletions packages/cosmo/src/main/scala/cosmo/Eval.scala
Original file line number Diff line number Diff line change
Expand Up @@ -189,70 +189,82 @@ class Env(val source: Source, val pacMgr: cosmo.PackageManager)
}
}

def term(lhs: Term)(implicit rhs: Term): ir.Term = lhs match {
// todo: lift
case Region(stmts, semi) => Region(stmts.map(term), semi)
case Break() => Break()
case Continue() => Continue()
case Return(value) => Return(term(value))
case If(cond, x, y) => If(term(cond), term(x), y.map(term))
case Loop(body) => Loop(term(body))
case While(cond, body) => While(term(cond), term(body))
case For(name, iter, body) =>
For(term(name), term(iter), term(body))
case UnOp(op, lhs) => UnOp(op, term(lhs))
case BinOp(op, lhs, rhs) => binOp(op, term(lhs), term(rhs))
case BinInst(op, lhs, rhs) => binInst(op, term(lhs), term(rhs))
case Apply(lhs, rhs) => $apply(term(lhs), rhs.map(term))
case Select(lhs, rhs) => select(term(lhs), rhs)
case As(lhs, rhs) => As(term(lhs), term(rhs))
case KeyedArg(k, v) => KeyedArg(term(k), term(v))
case i: InferVar => i
case t: TmplApply =>
tmplApply(t.lhs, t.strings, t.rhs.map { case (a, b) => (term(a), b) })
case t: SelectExpr => select(t.lhs, t.rhs)
case t: Name => byRef(t.id)
case t: Hole => Hole(t.id)
// todo: fold these expressions
case t: DestructExpr => t
case t: MatchExpr => t
case t: CaseRegion => t
case t: CaseExpr => t
case v: BoundField => v
case item: CIdent if rhs == UniverseTy =>
CIdent(item.name, item.ns, UniverseTy)
case item: CppInsType if rhs == UniverseTy =>
CppInsType(
lift(item.target).asInstanceOf[CIdent],
item.arguments.map(lift),
)
case SelfVal if rhs == UniverseTy => SelfTy
// todo: fold these instances
case c: CIdent => c
case c: CppInsType => c
case v: Value => v
case d: Def => d
case t: SimpleType => t
case o: Opaque => o
case t: ValueMatch => t
case t: TypeMatch => t
case t: ClassExpr => t
case t: Impl => t
case t: Var => t
case t: Param => t
case t: CModule => t
case t: NativeModule => t
case t: Class => t
case t: ClassInstance => t
case t: ClassDestruct => t
case t: EnumDestruct => t
case t: HKTInstance => t
case r @ Ref(id, v) if rhs == UniverseTy =>
items.get(id.id).getOrElse(r)
case r: Ref => r
case RefItem(lhs, isMut) => RefItem(term(lhs), isMut)
case TupleLit(elems) => TupleLit(elems.map(term))
}
def term(lhs: Term)(implicit rhs: Term): ir.Term =
debugln(s"term $lhs: $rhs")
lhs match {
// todo: lift
case Region(stmts, true) if rhs == UniverseTy => NoneItem
case Region(stmts, false) if rhs == UniverseTy =>
stmts.lastOption.map(term).getOrElse(NoneItem)
case Region(stmts, semi) => Region(stmts.map(term), semi)
case Break() => Break()
case Continue() => Continue()
case Return(value) => Return(term(value))
case If(cond, x, y) =>
term(cond) match {
case Bool(true) => term(x)
case Bool(false) => y.map(term).getOrElse(NoneItem)
case c => If(c, x, y)
}
case Loop(body) => Loop(term(body))
case While(cond, body) => While(term(cond), term(body))
case For(name, iter, body) =>
For(term(name), term(iter), term(body))
case UnOp(op, lhs) => UnOp(op, term(lhs))
case BinOp(op, lhs, rhs) => binOp(op, term(lhs), term(rhs))
case BinInst(op, lhs, rhs) => binInst(op, term(lhs), term(rhs))
case Apply(lhs, rhs) => $apply(term(lhs), rhs.map(term))
case Select(lhs, rhs) => select(term(lhs), rhs)
case As(lhs, rhs) => As(term(lhs), term(rhs))
case KeyedArg(k, v) => KeyedArg(term(k), term(v))
case i: InferVar => i
case t: TmplApply =>
tmplApply(t.lhs, t.strings, t.rhs.map { case (a, b) => (term(a), b) })
case t: SelectExpr => select(t.lhs, t.rhs)
case t: Name => byRef(t.id)
case t: Hole => Hole(t.id)
// todo: fold these expressions
case t: DestructExpr => t
case t: MatchExpr => t
case t: CaseRegion => t
case t: CaseExpr => t
case v: BoundField => v
case item: CIdent if rhs == UniverseTy =>
CIdent(item.name, item.ns, UniverseTy)
case item: CppInsType if rhs == UniverseTy =>
CppInsType(
lift(item.target).asInstanceOf[CIdent],
item.arguments.map(lift),
)
case SelfVal if rhs == UniverseTy => SelfTy
// todo: fold these instances
case c: CIdent => c
case c: CppInsType => c
case v: Value => v
case d: Def => d
case t: SimpleType => t
case o: Opaque => o
case t: ValueMatch => t
case t: TypeMatch => t
case t: ClassExpr => t
case t: Impl => t
case t: Param => t
case t: CModule => t
case t: NativeModule => t
case t: Class => t
case t: ClassInstance => t
case t: ClassDestruct => t
case t: EnumDestruct => t
case t: HKTInstance => t
case t @ Var(id, _, _) if rhs == UniverseTy =>
items.get(id.id).getOrElse(t)
case t: Var => t
case r @ Ref(id, v) if rhs == UniverseTy =>
items.get(id.id).getOrElse(r)
case r: Ref => r
case RefItem(lhs, isMut) => RefItem(term(lhs), isMut)
case TupleLit(elems) => TupleLit(elems.map(term))
}

def tyckValO(node: Option[Expr]): Term =
node.map(tyckVal).getOrElse(NoneItem)
Expand Down Expand Up @@ -404,29 +416,48 @@ class Env(val source: Source, val pacMgr: cosmo.PackageManager)
def binOp(op: String, lhs: Term, rhs: Term): Term = {
val lhsTy = tyOf(lhs)
val rhsTy = tyOf(rhs)
debugln(s"binOp $lhsTy $op $rhsTy")
debugln(s"binOp $lhs ($lhsTy) $op $rhs ($rhsTy)")

(lhsTy, op) match {
case (_, "<:") => Bool(isSubtype(lhs, rhs))
case (l: IntegerTy, s @ ("+" | "-" | "*" | "/")) =>
if (rhsTy != l) {
case (
l: IntegerTy,
s @ ("+" | "-" | "*" | "/" | "%" | "&" | "|" | "^" | "<<" | ">>" |
">>>" | "==" | "!=" | "<" | "<=" | ">" | ">="),
) =>
if (rhsTy != l && !rhs.isInstanceOf[Int64]) {
err(s"mismatched types $lhsTy $s $rhsTy")
return BinOp(s, lhs, rhs)
}
// todo: move to parser
val op = s match {
case "+" => BinInstIntOp.Add
case "-" => BinInstIntOp.Sub
case "*" => BinInstIntOp.Mul
case "/" => BinInstIntOp.Div
case "+" => BinInstIntOp.Add
case "-" => BinInstIntOp.Sub
case "*" => BinInstIntOp.Mul
case "/" => BinInstIntOp.Div
case "%" => BinInstIntOp.Rem
case "&" => BinInstIntOp.And
case "|" => BinInstIntOp.Or
case "^" => BinInstIntOp.Xor
case "<<" => BinInstIntOp.Shl
case ">>" => BinInstIntOp.Shr
case ">>>" => BinInstIntOp.Sar
case "==" => BinInstIntOp.Eq
case "!=" => BinInstIntOp.Ne
case "<" => BinInstIntOp.Lt
case "<=" => BinInstIntOp.Le
case ">" => BinInstIntOp.Gt
case ">=" => BinInstIntOp.Ge
}
binInst(BinInstOp.Int(l, op), lhs, rhs)
case _ => BinOp(op, lhs, rhs)
}
}

def binInst(op: BinInstOp, lhs: Term, rhs: Term): Term = {
debugln(s"binInst $lhs $op $rhs")
(op, lhs, rhs) match {
case (BinInstOp.Int(ir.IntegerTy(64, _), op), Int64(l), Int64(r)) =>
case (BinInstOp.Int(ir.IntegerTy(_, _), op), Int64(l), Int64(r)) =>
op match {
case BinInstIntOp.Add => Int64(l + r)
case BinInstIntOp.Sub => Int64(l - r)
Expand Down Expand Up @@ -635,18 +666,23 @@ class Env(val source: Source, val pacMgr: cosmo.PackageManager)
if (isValLevel(fn) && anno != UniverseTy) then
return Apply(fn, castArgs(fn.params, args))
val f = fn.checked
return scopes.withScope {
val itemsStack = items
val res = scopes.withScope {
val castedArgs = f.params.zip(args).map { case (p, a) =>
val info = p.id
scopes.set(info.name, info)
val casted = castTo(a, info.ty)
items += (info.id -> casted)
val casted = castTo(term(a), info.ty)
// todo: constrain types
casted
}
f.params.zip(castedArgs).foreach { case (p, a) =>
items += (p.id.id -> a)
}
val value = f.body.map(lift).map(eval).getOrElse(NoneItem)
hktRef(Some(f), castedArgs.toList, value)
}
items = itemsStack
res
}

def applyC(node: Class, args: List[Term]): Term = {
Expand Down
4 changes: 2 additions & 2 deletions packages/cosmo/src/test/scala/cosmo/SampleTests.scala
Original file line number Diff line number Diff line change
Expand Up @@ -65,8 +65,8 @@ class SampleTest extends TestBase:
test("CompileTime/addFn") {
compilePath("samples/CompileTime/addFn.cos")
}
test("CompileTime/loop".ignore) {
compilePath("samples/CompileTime/loop.cos")
test("CompileTime/recursive") {
compilePath("samples/CompileTime/recursive.cos")
}
test("Reflect/name") {
compilePath("samples/Reflect/name.cos")
Expand Down
10 changes: 0 additions & 10 deletions samples/CompileTime/loop.cos

This file was deleted.

17 changes: 17 additions & 0 deletions samples/CompileTime/recursive.cos
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@

@noCore();

def fib(n: u32): u32 = if (n <= 1) { n } else {
fib(n - 1) + fib(n - 2)
}

def main() = {
println(Type(fib(1)));
println(Type(fib(2)));
println(Type(fib(3)));
println(Type(fib(4)));
println(Type(fib(5)));
println(Type(fib(6)));
println(Type(fib(10)));
println(Type(fib(22)));
}

0 comments on commit 4ca32ad

Please sign in to comment.