Skip to content

Commit

Permalink
ATLAS.
Browse files Browse the repository at this point in the history
  • Loading branch information
mkskeller committed Aug 6, 2021
1 parent e6dbb4c commit c597554
Show file tree
Hide file tree
Showing 159 changed files with 3,661 additions and 1,372 deletions.
2 changes: 2 additions & 0 deletions BMR/Register.h
Original file line number Diff line number Diff line change
Expand Up @@ -272,6 +272,8 @@ class ProgramRegister : public Phase, public Register
// only true for evaluation
static const bool actual_inputs = false;

static int threshold(int) { throw not_implemented(); }

static Register new_reg();
static Register tmp_reg() { return new_reg(); }
static Register and_reg() { return new_reg(); }
Expand Down
2 changes: 1 addition & 1 deletion BMR/Register.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -107,7 +107,7 @@ void EvalRegister::store(GC::Memory<U>& mem,
//cout << "ext:" << ext << "/" << (int)reg.get_external() << " " << endl;
tmp = spdz_wire.mask + U::constant(ext, (int)party.get_id() - 1, party.get_mac_key());
S.push_back(tmp);
tmp *= gf2n_long(1) << i;
tmp <<= i;
dest += tmp;
const Key& key = reg.external_key(party.get_id());
Key& expected_key = spdz_wire.my_keys[(int)reg.get_external()];
Expand Down
15 changes: 15 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
@@ -1,5 +1,20 @@
The changelog explains changes pulled through from the private development repository. Bug fixes and small enhancements are committed between releases and not documented here.

## 0.2.6 (Aug 6, 2021)

- [ATLAS](https://eprint.iacr.org/2021/833)
- Keras-like interface
- Iterative linear solution approximation
- Binary output
- HighGear/LowGear key generation for wider range of parameters by default
- Dabit generation for smaller primes and malicious security
- More consistent type model
- Improved local computation
- Optimized GF(2^8) for CCD
- NTL only needed for computation with GF(2^40)
- Virtual machines suggest compile-time optimizations
- Improved documentation of types

## 0.2.5 (Jul 2, 2021)

- Training of convolutional neural networks
Expand Down
1 change: 1 addition & 0 deletions CONFIG
Original file line number Diff line number Diff line change
Expand Up @@ -69,6 +69,7 @@ LDLIBS = -lmpirxx -lmpir -lsodium $(MY_LDLIBS)
LDLIBS += -lboost_system -lssl -lcrypto

ifeq ($(USE_NTL),1)
CFLAGS += -DUSE_NTL
LDLIBS := -lntl $(LDLIBS)
endif

Expand Down
31 changes: 26 additions & 5 deletions Compiler/GC/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -118,6 +118,8 @@ def store_in_mem(self, address):
self.store_inst[isinstance(address, int)](self, address)
@classmethod
def new(cls, value=None, n=None):
if util.is_constant(value):
n = value.bit_length()
return cls.get_type(n)(value)
def __init__(self, value=None, n=None, size=None):
assert n == self.n or n is None
Expand Down Expand Up @@ -152,7 +154,7 @@ def load_other(self, other):
and self.n == other.n:
for i in range(math.ceil(self.n / self.unit)):
self.mov(self[i], other[i])
elif isinstance(other, sint):
elif isinstance(other, sint) and isinstance(self, sbits):
self.mov(self, sbitvec(other, self.n).elements()[0])
else:
try:
Expand Down Expand Up @@ -214,7 +216,13 @@ def conv_regint_by_bit(cls, n, res, other):
cls.conv_cint_vec(cint(other, size=other.size), res)
types = {}
def load_int(self, value):
self.load_other(regint(value))
if self.n <= 64:
tmp = regint(value)
elif value == self.long_one():
tmp = cint(1, size=self.n)
else:
raise CompilerError('loading long integers to cbits not supported')
self.load_other(tmp)
def store_in_dynamic_mem(self, address):
inst.stmsdci(self, cbits.conv(address))
def clear_op(self, other, c_inst, ci_inst, op):
Expand All @@ -227,7 +235,7 @@ def clear_op(self, other, c_inst, ci_inst, op):
else:
if util.is_constant(other):
if other >= 2**31 or other < -2**31:
return op(self, cbits(other))
return op(self, cbits.new(other))
res = cbits.get_type(max(self.n, len(bin(other)) - 2))()
ci_inst(res, self, other)
return res
Expand Down Expand Up @@ -269,6 +277,8 @@ def __lshift__(self, other):
res = cbits.get_type(self.n+other)()
inst.shlcbi(res, self, other)
return res
def __invert__(self):
return self ^ self.long_one()
def print_reg(self, desc=''):
inst.print_regb(self, desc)
def print_reg_plain(self):
Expand Down Expand Up @@ -527,6 +537,14 @@ def trans(cls, rows):
return res
@staticmethod
def bit_adder(*args, **kwargs):
""" Binary adder in binary circuits.
:param a: summand (list of 0/1 in compatible type)
:param b: summand (list of 0/1 in compatible type)
:param carry_in: input carry (default 0)
:param get_carry: add final carry to output
:returns: list of 0/1 in relevant type
"""
return sbitint.bit_adder(*args, **kwargs)
@staticmethod
def ripple_carry_adder(*args, **kwargs):
Expand Down Expand Up @@ -889,7 +907,7 @@ def _store(self, value, address):
cbits.dynamic_array = Array

def _complement_two_extend(bits, k):
return bits + [bits[-1]] * (k - len(bits))
return bits[:k] + [bits[-1]] * (k - len(bits))

class _sbitintbase:
def extend(self, n):
Expand Down Expand Up @@ -1096,7 +1114,6 @@ def TruncMul(self, other, k, m, kappa=None, nearest=False):
raise CompilerError('round to nearest not implemented')
if not isinstance(other, sbitintvec):
other = sbitintvec(other)
assert len(self.v) + len(other.v) <= k
a = self.get_type(k).from_vec(_complement_two_extend(self.v, k))
b = self.get_type(k).from_vec(_complement_two_extend(other.v, k))
tmp = a * b
Expand Down Expand Up @@ -1148,6 +1165,10 @@ class sbitfix(_fix):
sub: 0.199997
lt: 0
Note that the default precision (16 bits after the dot, 31 bits in
total) only allows numbers up to :math:`2^{31-16-1} \\approx
16000`. You can increase this using :py:func:`set_precision`.
"""
float_type = type(None)
clear_type = cbitfix
Expand Down
5 changes: 3 additions & 2 deletions Compiler/allocator.py
Original file line number Diff line number Diff line change
Expand Up @@ -167,8 +167,9 @@ def process(self, program, alloc_pool):

def finalize(self, options):
for reg in self.alloc:
for x in reg.vector:
if x not in self.dealloc and reg not in self.dealloc:
for x in reg.get_all():
if x not in self.dealloc and reg not in self.dealloc \
and len(x.duplicates) == 1:
print('Warning: read before write at register', x)
print('\tregister trace: %s' % format_trace(x.caller,
'\t\t'))
Expand Down
41 changes: 13 additions & 28 deletions Compiler/comparison.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
parameter.
Most of these routines were implemented before the cint/sint classes, so use
the old-fasioned Register class and assembly instructions instead of operator
the old-fashioned Register class and assembly instructions instead of operator
overloading.
The PreMulC function has a few variants, depending on whether
Expand Down Expand Up @@ -61,18 +61,13 @@ def ld2i(c, n):
t1 = t2
movc(c, t1)

inverse_of_two = {}

def divide_by_two(res, x, m=1):
""" Faster clear division by two using a cached value of 2^-1 mod p """
tmp = program.curr_block.new_reg('c')
inv2m(tmp, m)
mulc(res, x, tmp)

def require_ring_size(k, op):
if int(program.options.ring) < k:
raise CompilerError('ring size too small for %s, compile '
'with \'-R %d\' or more' % (op, k))
msg = 'ring size too small for %s, compile ' \
'with \'-R %d\' or more' % (op, k)
if k > 64 and k < 128:
msg += ' (maybe \'-R 128\' as it is supported by default)'
raise CompilerError(msg)
program.curr_tape.require_bit_length(k)

@instructions_base.cisc
Expand Down Expand Up @@ -122,20 +117,11 @@ def Trunc(d, a, k, m, kappa, signed):
m: compile-time integer
signed: True/False, describes a
"""
t = program.curr_block.new_reg('s')
c = [program.curr_block.new_reg('c') for i in range(3)]
c2m = program.curr_block.new_reg('c')
if m == 0:
movs(d, a)
return
elif program.options.ring:
return TruncRing(d, a, k, m, signed)
else:
a_prime = program.non_linear.mod2m(a, k, m, signed)
subs(t, a, a_prime)
ldi(c[1], 1)
divide_by_two(c[2], c[1], m)
mulm(d, t, c[2])
movs(d, program.non_linear.trunc(a, k, m, kappa, signed))

def TruncRing(d, a, k, m, signed):
program.curr_tape.require_bit_length(1)
Expand Down Expand Up @@ -489,13 +475,12 @@ def BitLTL(res, a, b, kappa):
"""
k = len(b)
a_bits = b[0].bit_decompose_clear(a, k)
s = [[program.curr_block.new_reg('s') for i in range(k)] for j in range(2)]
t = [program.curr_block.new_reg('s') for i in range(1)]
for i in range(len(b)):
s[0][i] = b[0].long_one() - b[i]
CarryOut(t[0], a_bits[::-1], s[0][::-1], b[0].long_one(), kappa)
subsfi(res, t[0], 1)
return a_bits, s[0]
from .types import sint
movs(res, sint.conv(BitLTL_raw(a_bits, b)))

def BitLTL_raw(a_bits, b):
s = [x.bit_not() for x in b]
return CarryOutRaw(a_bits[::-1], s[::-1], b[0].long_one()).bit_not()

def PreMulC_with_inverses_and_vectors(p, a):
"""
Expand Down
4 changes: 4 additions & 0 deletions Compiler/compilerLib.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,10 @@ def run(args, options):
if options.binary:
VARS['sint'] = GC_types.sbitintvec.get_type(int(options.binary))
VARS['sfix'] = GC_types.sbitfixvec
for i in 'cint', 'cfix', 'cgf2n', 'sintbit', 'sgf2n', 'sgf2nint', \
'sgf2nuint', 'sgf2nuint32', 'sgf2nfloat', 'sfloat', 'cfloat', \
'squant':
del VARS[i]

print('Compiling file', prog.infile)

Expand Down
9 changes: 5 additions & 4 deletions Compiler/floatingpoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -264,8 +264,8 @@ def BitDec(a, k, m, kappa, bits_to_compute=None):
return program.Program.prog.non_linear.bit_dec(a, k, m)

def BitDecRingRaw(a, k, m):
comparison.require_ring_size(m, 'bit decomposition')
n_shift = int(program.Program.prog.options.ring) - m
assert(n_shift >= 0)
if program.Program.prog.use_split():
x = a.split_to_two_summands(m)
bits = types._bitint.carry_lookahead_adder(x[0], x[1], fewer_inv=False)
Expand Down Expand Up @@ -504,7 +504,8 @@ def TruncPrRing(a, k, m, signed=True):
return comparison.TruncLeakyInRing(a, k, m, signed=signed)
else:
from .types import sint
if signed:
prog = program.Program.prog
if signed and prog.use_trunc_pr != -1:
a += (1 << (k - 1))
if program.Program.prog.use_trunc_pr:
res = sint()
Expand All @@ -530,7 +531,7 @@ def TruncPrRing(a, k, m, signed=True):
overflow = msb.bit_xor(masked >> (n_ring - 1))
res = shifted - upper + \
(overflow << (k - m))
if signed:
if signed and prog.use_trunc_pr != -1:
res -= (1 << (k - m - 1))
return res

Expand Down Expand Up @@ -672,7 +673,7 @@ def _():
t = bbits[0].bit_decompose_clear(p - c, bit_length)
c = longint(c, bit_length)
czero = (c==0)
q = bbits[0].long_one() - BITLT(bbits, t, bit_length)
q = bbits[0].long_one() - comparison.BitLTL_raw(bbits, t)
fbar = [bbits[0].clear_type.conv(cint(x))
for x in ((1<<bit_length)+c-p).bit_decompose(bit_length)]
fbard = bbits[0].bit_decompose_clear(cmodp, bit_length)
Expand Down
64 changes: 64 additions & 0 deletions Compiler/instructions.py
Original file line number Diff line number Diff line change
Expand Up @@ -610,6 +610,19 @@ class divc(base.InvertInstruction):
code = base.opcodes['DIVC']
arg_format = ['cw','c','c']

@base.gf2n
@base.vectorize
class floordivc(base.Instruction):
""" Clear integer floor division.
:param: result (cint)
:param: dividend (cint)
:param: divisor (cint)
"""
__slots__ = []
code = base.opcodes['FLOORDIVC']
arg_format = ['cw','c','c']

@base.gf2n
@base.vectorize
class modc(base.Instruction):
Expand Down Expand Up @@ -1406,6 +1419,32 @@ def add_usage(self, req_node):
req_node.increment((self.field_type, 'input', player), \
self.get_size())

class inputpersonal(base.Instruction, base.Mergeable):
""" Private input from cint.
:param: vector size (int)
:param: player (int)
:param: destination (sint)
:param: source (cint)
:param: (repeat from vector size)...
"""
__slots__ = []
code = base.opcodes['INPUTPERSONAL']
arg_format = tools.cycle(['int','p','sw','c'])
field_type = 'modp'

def __init__(self, *args):
super(inputpersonal, self).__init__(*args)
for i in range(0, len(args), 4):
assert args[i + 2].size == args[i]
assert args[i + 3].size == args[i]

def add_usage(self, req_node):
for i in range(0, len(self.args), 4):
player = self.args[i + 1]
req_node.increment((self.field_type, 'input', player), \
self.args[i])

@base.gf2n
@base.vectorize
class print_reg(base.IOInstruction):
Expand Down Expand Up @@ -1677,6 +1716,31 @@ class raw_output(base.PublicFileIOInstruction):
code = base.opcodes['RAWOUTPUT']
arg_format = ['c']

@base.vectorize
class intoutput(base.PublicFileIOInstruction):
""" Binary integer output.
:param: player (int)
:param: regint
"""
__slots__ = []
code = base.opcodes['INTOUTPUT']
arg_format = ['p','ci']

@base.vectorize
class floatoutput(base.PublicFileIOInstruction):
""" Binary floating-point output.
:param: player (int)
:param: significand (cint)
:param: exponent (cint)
:param: zero bit (cint)
:param: sign bit (cint)
"""
__slots__ = []
code = base.opcodes['FLOATOUTPUT']
arg_format = ['p','c','c','c','c']

@base.gf2n
@base.vectorize
class startprivateoutput(base.Instruction):
Expand Down
Loading

0 comments on commit c597554

Please sign in to comment.