Skip to content

Commit

Permalink
fix: fixed range_constant plonk
Browse files Browse the repository at this point in the history
  • Loading branch information
ThomasPiellard authored and ivokub committed Dec 22, 2021
1 parent 64f2e28 commit 4d8f2b1
Show file tree
Hide file tree
Showing 5 changed files with 89 additions and 107 deletions.
17 changes: 16 additions & 1 deletion frontend/cs/plonk/api.go
Original file line number Diff line number Diff line change
Expand Up @@ -299,6 +299,10 @@ func (system *SparseR1CS) Xor(a, b cs.Variable) cs.Variable {
// Or returns a | b
// a and b must be 0 or 1
func (system *SparseR1CS) Or(a, b cs.Variable) cs.Variable {

var zero, one big.Int
one.SetUint64(1)

if system.IsConstant(a) && system.IsConstant(b) {
_a := utils.FromInterface(a)
_b := utils.FromInterface(b)
Expand All @@ -310,9 +314,16 @@ func (system *SparseR1CS) Or(a, b cs.Variable) cs.Variable {
a, b = b, a
}
if system.IsConstant(b) {
_b := utils.FromInterface(b)

l := a.(compiled.Term)
r := l
_b := utils.FromInterface(b)

if _b.Cmp(&one) != 0 && _b.Cmp(&zero) != 0 {
panic(fmt.Sprintf("%s should be 0 or 1", _b.String()))
}
system.AssertIsBoolean(a)

one := big.NewInt(1)
_b.Sub(&_b, one)
idl := system.CoeffID(&_b)
Expand All @@ -321,13 +332,17 @@ func (system *SparseR1CS) Or(a, b cs.Variable) cs.Variable {
}
l := a.(compiled.Term)
r := b.(compiled.Term)
system.AssertIsBoolean(l)
system.AssertIsBoolean(r)
system.addPlonkConstraint(l, r, res, compiled.CoeffIdMinusOne, compiled.CoeffIdMinusOne, compiled.CoeffIdOne, compiled.CoeffIdOne, compiled.CoeffIdOne, compiled.CoeffIdZero)
return res
}

// Or returns a & b
// a and b must be 0 or 1
func (system *SparseR1CS) And(a, b cs.Variable) cs.Variable {
system.AssertIsBoolean(a)
system.AssertIsBoolean(b)
return system.Mul(a, b)
}

Expand Down
113 changes: 62 additions & 51 deletions frontend/cs/plonk/assertions.go
Original file line number Diff line number Diff line change
Expand Up @@ -86,6 +86,53 @@ func (system *SparseR1CS) AssertIsLessOrEqual(v cs.Variable, bound cs.Variable)
}
}

func (system *SparseR1CS) mustBeLessOrEqVar(a compiled.Term, bound compiled.Term) {

debug := system.AddDebugInfo("mustBeLessOrEq", a, " <= ", bound)

nbBits := system.BitLen()

aBits := system.toBinary(a, nbBits, true)
boundBits := system.ToBinary(bound, nbBits)

p := make([]cs.Variable, nbBits+1)
p[nbBits] = 1

for i := nbBits - 1; i >= 0; i-- {

// if bound[i] == 0
// p[i] = p[i+1]
// t = p[i+1]
// else
// p[i] = p[i+1] * a[i]
// t = 0
v := system.Mul(p[i+1], aBits[i])
p[i] = system.Select(boundBits[i], v, p[i+1])

t := system.Select(boundBits[i], 0, p[i+1])

// (1 - t - ai) * ai == 0
l := system.Sub(1, t, aBits[i])

// note if bound[i] == 1, this constraint is (1 - ai) * ai == 0
// --> this is a boolean constraint
// if bound[i] == 0, t must be 0 or 1, thus ai must be 0 or 1 too
// system.markBoolean(aBits[i].(compiled.Term)) // this does not create a constraint

system.addPlonkConstraint(
l.(compiled.Term),
aBits[i].(compiled.Term),
system.zero(),
compiled.CoeffIdZero,
compiled.CoeffIdZero,
compiled.CoeffIdOne,
compiled.CoeffIdOne,
compiled.CoeffIdZero,
compiled.CoeffIdZero, debug)
}

}

func (system *SparseR1CS) mustBeLessOrEqCst(a compiled.Term, bound big.Int) {

nbBits := system.BitLen()
Expand Down Expand Up @@ -127,63 +174,27 @@ func (system *SparseR1CS) mustBeLessOrEqCst(a compiled.Term, bound big.Int) {
}

for i := nbBits - 1; i >= 0; i-- {

if bound.Bit(i) == 0 {
// (1 - p(i+1) - ai) * ai == 0
l := system.Sub(1, p[i+1]).(compiled.Term)
l = system.Sub(l, aBits[i]).(compiled.Term)

system.addPlonkConstraint(l, aBits[i].(compiled.Term), system.zero(), compiled.CoeffIdZero, compiled.CoeffIdZero, compiled.CoeffIdOne, compiled.CoeffIdOne, compiled.CoeffIdZero, compiled.CoeffIdZero, debug)
l := system.Sub(1, p[i+1], aBits[i]).(compiled.Term)
//l = system.Sub(l, ).(compiled.Term)

system.addPlonkConstraint(
l,
aBits[i].(compiled.Term),
system.zero(),
compiled.CoeffIdZero,
compiled.CoeffIdZero,
compiled.CoeffIdOne,
compiled.CoeffIdOne,
compiled.CoeffIdZero,
compiled.CoeffIdZero,
debug)
// system.markBoolean(aBits[i].(compiled.Term))
} else {
system.AssertIsBoolean(aBits[i])
}
}

}

func (system *SparseR1CS) mustBeLessOrEqVar(a compiled.Term, bound compiled.Term) {

debug := system.AddDebugInfo("mustBeLessOrEq", a, " <= ", bound)

nbBits := system.BitLen()

aBits := system.toBinary(a, nbBits, true)
boundBits := system.ToBinary(bound, nbBits)

p := make([]cs.Variable, nbBits+1)
p[nbBits] = 1

for i := nbBits - 1; i >= 0; i-- {

// if bound[i] == 0
// p[i] = p[i+1]
// t = p[i+1]
// else
// p[i] = p[i+1] * a[i]
// t = 0
v := system.Mul(p[i+1], aBits[i])
p[i] = system.Select(boundBits[i], v, p[i+1])

t := system.Select(boundBits[i], 0, p[i+1])

// (1 - t - ai) * ai == 0
l := system.Sub(1, t, aBits[i])

// note if bound[i] == 1, this constraint is (1 - ai) * ai == 0
// --> this is a boolean constraint
// if bound[i] == 0, t must be 0 or 1, thus ai must be 0 or 1 too
// system.markBoolean(aBits[i].(compiled.Term)) // this does not create a constraint

system.addPlonkConstraint(
l.(compiled.Term),
aBits[i].(compiled.Term),
system.zero(),
compiled.CoeffIdZero,
compiled.CoeffIdZero,
compiled.CoeffIdOne,
compiled.CoeffIdOne,
compiled.CoeffIdZero,
compiled.CoeffIdZero, debug)
}

}
52 changes: 0 additions & 52 deletions frontend/cs/r1cs/assertions.go
Original file line number Diff line number Diff line change
Expand Up @@ -182,56 +182,4 @@ func (system *R1CSRefactor) mustBeLessOrEqCst(a compiled.Variable, bound big.Int
}
}

// nbBits := system.BitLen()

// // ensure the bound is positive, it's bit-len doesn't matter
// if bound.Sign() == -1 {
// panic("AssertIsLessOrEqual: bound must be positive")
// }
// if bound.BitLen() > nbBits {
// panic("AssertIsLessOrEqual: bound is too large, constraint will never be satisfied")
// }

// // debug info
// debug := system.AddDebugInfo("mustBeLessOrEq", a, " <= ", system.constant(bound))

// // note that at this stage, we didn't boolean-constraint these new variables yet
// // (as opposed to ToBinary)
// aBits := system.toBinary(a, nbBits, true)

// // t trailing bits in the bound
// t := 0
// for i := 0; i < nbBits; i++ {
// if bound.Bit(i) == 0 {
// break
// }
// t++
// }

// p := make([]cs.Variable, nbBits+1)
// // p[i] == 1 --> a[j] == c[j] for all j >= i
// p[nbBits] = system.constant(1)

// for i := nbBits - 1; i >= t; i-- {
// if bound.Bit(i) == 0 {
// p[i] = p[i+1]
// } else {
// p[i] = system.Mul(p[i+1], aBits[i])
// }
// }

// for i := nbBits - 1; i >= 0; i-- {
// if bound.Bit(i) == 0 {
// // (1 - p(i+1) - ai) * ai == 0
// l := system.one()
// l = system.Sub(l, p[i+1]).(compiled.Variable)
// l = system.Sub(l, aBits[i]).(compiled.Variable)

// system.addConstraint(newR1C(l, aBits[i], system.constant(0)), debug)
// system.markBoolean(aBits[i].(compiled.Variable))
// } else {
// system.AssertIsBoolean(aBits[i])
// }
// }

}
1 change: 1 addition & 0 deletions frontend/frontend.go
Original file line number Diff line number Diff line change
Expand Up @@ -156,6 +156,7 @@ func IgnoreUnconstrainedInputs(opt *CompileOption) error {

var tVariable reflect.Type

// TODO @tpiellard change struct{ A Variable } to struct{ A cs.Variable}, and refactor ...
func init() {
tVariable = reflect.ValueOf(struct{ A Variable }{}).FieldByName("A").Type()
}
13 changes: 10 additions & 3 deletions integration_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -41,17 +41,24 @@ func TestIntegrationAPI(t *testing.T) {
t.Log(k)
for _, w := range tData.ValidWitnesses {
// assert.ProverSucceeded(tData.Circuit, w, test.WithProverOpts(backend.WithHints(tData.HintFunctions...)))
assert.ProverSucceeded(tData.Circuit, w, test.WithProverOpts(backend.WithHints(tData.HintFunctions...)), test.WithBackends(backend.GROTH16))
assert.ProverSucceeded(
tData.Circuit,
w,
test.WithProverOpts(backend.WithHints(tData.HintFunctions...)),
test.WithBackends(backend.PLONK))
}

for _, w := range tData.InvalidWitnesses {
assert.ProverFailed(tData.Circuit, w, test.WithProverOpts(backend.WithHints(tData.HintFunctions...)))
assert.ProverFailed(
tData.Circuit,
w,
test.WithProverOpts(backend.WithHints(tData.HintFunctions...)),
test.WithBackends(backend.PLONK))
}

// we put that here now, but will be into a proper fuzz target with go1.18
const fuzzCount = 30
assert.Fuzz(tData.Circuit, fuzzCount, test.WithProverOpts(backend.WithHints(tData.HintFunctions...)), test.WithBackends(backend.GROTH16))

}

}

0 comments on commit 4d8f2b1

Please sign in to comment.