Skip to content

Commit

Permalink
Boolean constants.
Browse files Browse the repository at this point in the history
  • Loading branch information
markkurossi committed Feb 20, 2020
1 parent a79f998 commit d5b3684
Show file tree
Hide file tree
Showing 8 changed files with 80 additions and 35 deletions.
1 change: 0 additions & 1 deletion apps/garbled/max.mpcl
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
// -*- go -*-
//

package main

Expand Down
6 changes: 3 additions & 3 deletions apps/garbled/millionaire.mpcl
Original file line number Diff line number Diff line change
Expand Up @@ -3,10 +3,10 @@

package main

func main(a, b int64) int {
func main(a, b int64) bool {
if a > b {
return 0
return true
} else {
return 1
return false
}
}
40 changes: 28 additions & 12 deletions compiler/ast/ast.go
Original file line number Diff line number Diff line change
Expand Up @@ -309,36 +309,52 @@ func (ast *VariableRef) Fprint(w io.Writer, ind int) {
}

type Constant struct {
Loc utils.Point
UintVal *uint64
Loc utils.Point
Value interface{}
}

func (ast *Constant) String() string {
return fmt.Sprintf("$%d", *ast.UintVal)
switch val := ast.Value.(type) {
case uint64:
return fmt.Sprintf("$%d", val)
case bool:
return fmt.Sprintf("$%v", val)
default:
return fmt.Sprintf("{undefined constant %v}", ast.Value)
}
}

func (ast *Constant) Variable() (ssa.Variable, error) {
v := ssa.Variable{
Const: true,
Const: true,
ConstValue: ast.Value,
}
if ast.UintVal != nil {
switch val := ast.Value.(type) {
case uint64:
var bits int
// Count minimum bits needed to represent the value.
for bits = 2; bits < 64; bits++ {
if (0xffffffffffffffff<<bits)&*ast.UintVal == 0 {
if (0xffffffffffffffff<<bits)&val == 0 {
bits--
break
}
}

v.Name = fmt.Sprintf("$%d", *ast.UintVal)
v.Name = fmt.Sprintf("$%d", val)
v.Type = types.Info{
Type: types.Uint,
Bits: bits,
}
v.ConstUint = ast.UintVal
} else {
return v, fmt.Errorf("constant %v not implemented yet", ast)

case bool:
v.Name = fmt.Sprintf("$%v", val)
v.Type = types.Info{
Type: types.Bool,
Bits: 1,
}

default:
return v, fmt.Errorf("ast.Constant.Variable(): %v not implemented yet",
ast)
}
return v, nil
}
Expand All @@ -349,5 +365,5 @@ func (ast *Constant) Location() utils.Point {

func (ast *Constant) Fprint(w io.Writer, ind int) {
indent(w, ind)
fmt.Fprintf(w, "%d", ast.UintVal)
fmt.Fprintf(w, "%v", ast.Value)
}
4 changes: 2 additions & 2 deletions compiler/ast/ssagen.go
Original file line number Diff line number Diff line change
Expand Up @@ -110,7 +110,7 @@ func (ast *Assign) SSA(block *ssa.Block, ctx *Codegen,

b, err := block.Bindings.Get(ast.Name)
if err != nil {
return nil, err
return nil, ctx.logger.Errorf(ast.Loc, "%s", err.Error())
}
v, err := gen.NewVar(b.Name, b.Type, ctx.Scope())
if err != nil {
Expand Down Expand Up @@ -343,7 +343,7 @@ func (ast *VariableRef) SSA(block *ssa.Block, ctx *Codegen,

b, err := block.Bindings.Get(ast.Name)
if err != nil {
return nil, err
return nil, ctx.logger.Errorf(ast.Loc, "%s", err.Error())
}

v := b.Value(block, gen)
Expand Down
19 changes: 14 additions & 5 deletions compiler/lexer.go
Original file line number Diff line number Diff line change
Expand Up @@ -156,18 +156,23 @@ type Token struct {
From utils.Point
To utils.Point
StrVal string
UintVal *uint64
ConstVal interface{}
TypeInfo types.Info
}

func (t *Token) String() string {
var str string
if len(t.StrVal) > 0 {
str = t.StrVal
} else if t.UintVal != nil {
str = strconv.FormatUint(*t.UintVal, 10)
} else {
str = t.Type.String()
switch val := t.ConstVal.(type) {
case uint64:
str = strconv.FormatUint(val, 10)
case bool:
str = fmt.Sprintf("%v", val)
default:
str = t.Type.String()
}
}
return str
}
Expand Down Expand Up @@ -487,6 +492,10 @@ func (l *Lexer) Get() (*Token, error) {
}
token.StrVal = symbol
return token, nil
} else if symbol == "true" || symbol == "false" {
token := l.Token(T_Constant)
token.ConstVal = symbol == "true"
return token, nil
}

token := l.Token(T_Identifier)
Expand Down Expand Up @@ -516,7 +525,7 @@ func (l *Lexer) Get() (*Token, error) {
return nil, err
}
token := l.Token(T_Constant)
token.UintVal = &u
token.ConstVal = u
return token, nil
}
l.UnreadRune()
Expand Down
4 changes: 2 additions & 2 deletions compiler/parser.go
Original file line number Diff line number Diff line change
Expand Up @@ -571,8 +571,8 @@ func (p *Parser) parseExprPrimary() (ast.AST, error) {

case T_Constant:
return &ast.Constant{
Loc: t.From,
UintVal: t.UintVal,
Loc: t.From,
Value: t.ConstVal,
}, nil
}
p.lexer.Unget(t)
Expand Down
28 changes: 18 additions & 10 deletions compiler/ssa/instructions.go
Original file line number Diff line number Diff line change
Expand Up @@ -260,12 +260,12 @@ func (i Instr) PP(out io.Writer) {
}

type Variable struct {
Name string
Scope int
Version int
Type types.Info
Const bool
ConstUint *uint64
Name string
Scope int
Version int
Type types.Info
Const bool
ConstValue interface{}
}

func (v Variable) String() string {
Expand Down Expand Up @@ -295,9 +295,17 @@ func (v *Variable) Value(block *Block, gen *Generator) Variable {
}

func (v *Variable) Bit(bit int) bool {
if v.ConstUint != nil {
return (*v.ConstUint & (1 << bit)) != 0
} else {
panic(fmt.Sprintf("BitSet called for a non variable %v", v))
switch val := v.ConstValue.(type) {
case bool:
if bit == 0 {
return val
}
return false

case uint64:
return (val & (1 << bit)) != 0

default:
panic(fmt.Sprintf("Variable.Bit called for a non variable %v", v))
}
}
13 changes: 13 additions & 0 deletions compiler/ssagen_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -141,6 +141,19 @@ package main
func main(a, b int) int {
return a * b
}
`,
},
SSAGenTest{
Enabled: true,
Name: "Bool",
Code: `
package main
func main(a, b int) bool {
if a > b {
return true
}
return false
}
`,
},
}
Expand Down

0 comments on commit d5b3684

Please sign in to comment.