1313from Compiler .types import vectorized_classmethod
1414from Compiler .program import Tape , Program
1515from Compiler .exceptions import *
16- from Compiler import util , oram , floatingpoint , library
16+ from Compiler import util , oram , floatingpoint , library , comparison
1717from Compiler import instructions_base
1818import Compiler .GC .instructions as inst
1919import operator
2020import math
2121import itertools
2222from functools import reduce
2323
24+ class _binary :
25+ def reveal_to (self , * args , ** kwargs ):
26+ raise CompilerError (
27+ '%s does not support revealing to indivual players' % type (self ))
28+
2429class bits (Tape .Register , _structure , _bit ):
2530 n = 40
2631 unit = 64
@@ -149,6 +154,12 @@ def set_length(self, n):
149154 self .n = n
150155 def set_size (self , size ):
151156 pass
157+ def load_int (self , value ):
158+ n_limbs = math .ceil (self .n / self .unit )
159+ for i in range (n_limbs ):
160+ self .conv_regint (min (self .unit , self .n - i * self .unit ),
161+ self [i ], regint (value % 2 ** self .unit ))
162+ value >>= self .unit
152163 def load_other (self , other ):
153164 if isinstance (other , cint ):
154165 assert (self .n == other .size )
@@ -236,12 +247,14 @@ def _new_by_number(self, i, size=1):
236247 return res
237248 def if_else (self , x , y ):
238249 """
239- Vectorized oblivious selection::
250+ Bit-wise oblivious selection::
240251
241252 sb32 = sbits.get_type(32)
242253 print_ln('%s', sb32(3).if_else(sb32(5), sb32(2)).reveal())
243254
244- This will output 1.
255+ This will output 1 because it selects the two least
256+ significant bits from 5 and the rest of the bits from 2.
257+
245258 """
246259 return result_conv (x , y )(self & (x ^ y ) ^ y )
247260 def zero_if_not (self , condition ):
@@ -268,6 +281,9 @@ def copy_from_part(self, source, base, size):
268281 self .bit_compose (source .bit_decompose ()[base :base + size ]))
269282 def vector_size (self ):
270283 return self .n
284+ @staticmethod
285+ def size_for_mem ():
286+ return 1
271287
272288class cbits (bits ):
273289 """ Clear bits register. Helper type with limited functionality. """
@@ -302,13 +318,6 @@ def conv(cls, other):
302318 else :
303319 return super (cbits , cls ).conv (other )
304320 types = {}
305- def load_int (self , value ):
306- n_limbs = math .ceil (self .n / self .unit )
307- tmp = regint (size = n_limbs )
308- for i in range (n_limbs ):
309- tmp [i ].load_int (value % 2 ** self .unit )
310- value >>= self .unit
311- self .load_other (tmp )
312321 def store_in_dynamic_mem (self , address ):
313322 inst .stmsdci (self , cbits .conv (address ))
314323 def clear_op (self , other , c_inst , ci_inst , op ):
@@ -502,11 +511,7 @@ def load_int(self, value):
502511 if self .n <= 32 :
503512 inst .ldbits (self , self .n , value )
504513 else :
505- size = math .ceil (self .n / self .unit )
506- tmp = regint (size = size )
507- for i in range (size ):
508- tmp [i ].load_int ((value >> (i * 64 )) % 2 ** 64 )
509- self .load_other (tmp )
514+ bits .load_int (self , value )
510515 def load_other (self , other ):
511516 if isinstance (other , cbits ) and self .n == other .n :
512517 inst .convcbit2s (self .n , self , other )
@@ -675,7 +680,7 @@ def bit_adder(*args, **kwargs):
675680 def ripple_carry_adder (* args , ** kwargs ):
676681 return sbitint .ripple_carry_adder (* args , ** kwargs )
677682
678- class sbitvec (_vec , _bit ):
683+ class sbitvec (_vec , _bit , _binary ):
679684 """ Vector of registers of secret bits, effectively a matrix of secret bits.
680685 This facilitates parallel arithmetic operations in binary circuits.
681686 Container types are not supported, use :py:obj:`sbitvec.get_type` for that.
@@ -732,15 +737,16 @@ def get_type(cls, n):
732737 :py:obj:`v` and the columns by calling :py:obj:`elements`.
733738 """
734739 class sbitvecn (cls , _structure ):
735- @staticmethod
736- def malloc (size , creator_tape = None ):
737- return sbit .malloc (size * n , creator_tape = creator_tape )
740+ @classmethod
741+ def malloc (cls , size , creator_tape = None ):
742+ return sbit .malloc (
743+ size * cls .mem_size (), creator_tape = creator_tape )
738744 @staticmethod
739745 def n_elements ():
740746 return 1
741747 @staticmethod
742748 def mem_size ():
743- return n
749+ return sbits . get_type ( n ). mem_size ()
744750 @classmethod
745751 def get_input_from (cls , player , size = 1 , f = 0 ):
746752 """ Secret input from :py:obj:`player`. The input is decomposed
@@ -780,38 +786,28 @@ def __init__(self, other=None, size=None):
780786 self .v = sbits .get_type (n )(other ).bit_decompose ()
781787 assert len (self .v ) == n
782788 assert size is None or size == self .v [0 ].n
783- @vectorized_classmethod
784- def load_mem (cls , address ):
785- size = instructions_base .get_global_vector_size ()
786- if size not in (None , 1 ):
787- assert isinstance (address , int ) or len (address ) == 1
788- sb = sbits .get_type (size )
789- return cls .from_vec (sb .bit_compose (
790- sbit .load_mem (address + i + j * n ) for j in range (size ))
791- for i in range (n ))
792- if not isinstance (address , int ):
793- v = [sbit .load_mem (x , size = n ).v [0 ] for x in address ]
794- return cls (v )
789+ @classmethod
790+ def load_mem (cls , address , size = None ):
791+ if isinstance (address , int ) or len (address ) == 1 :
792+ address = [address + i for i in range (size or 1 )]
795793 else :
796- return cls .from_vec (sbit .load_mem (address + i )
797- for i in range (n ))
794+ assert size == None
795+ return cls (
796+ [sbits .get_type (n ).load_mem (x ) for x in address ])
798797 def store_in_mem (self , address ):
799798 size = 1
800799 for x in self .v :
801800 if not util .is_constant (x ):
802801 size = max (size , x .n )
803- v = [sbits .get_type (size ).conv (x ) for x in self .v ]
804- if not isinstance (address , int ) and len (address ) != 1 :
805- v = self .elements ()
806- assert len (v ) == len (address )
807- for x , y in zip (v , address ):
808- for i , xx in enumerate (x .bit_decompose (n )):
809- xx .store_in_mem (y + i )
802+ if isinstance (address , int ):
803+ address = range (address , address + size )
804+ elif len (address ) == 1 :
805+ address = [address + i * self .mem_size ()
806+ for i in range (size )]
810807 else :
811- assert isinstance (address , int ) or len (address ) == 1
812- for i in range (n ):
813- for j , x in enumerate (v [i ].bit_decompose ()):
814- x .store_in_mem (address + i + j * n )
808+ assert size == len (address )
809+ for x , dest in zip (self .elements (), address ):
810+ x .store_in_mem (dest )
815811 @classmethod
816812 def two_power (cls , nn , size = 1 ):
817813 return cls .from_vec (
@@ -864,7 +860,7 @@ def __init__(self, elements=None, length=None, input_length=None):
864860 assert isinstance (elements , sint )
865861 if Program .prog .use_split ():
866862 x = elements .split_to_two_summands (length )
867- v = sbitint .carry_lookahead_adder (x [0 ], x [1 ], fewer_inv = True )
863+ v = sbitint .bit_adder (x [0 ], x [1 ])
868864 else :
869865 prog = Program .prog
870866 if not prog .options .ring :
@@ -877,6 +873,7 @@ def __init__(self, elements=None, length=None, input_length=None):
877873 length , prog .security )
878874 prog .use_edabit (backup )
879875 return
876+ comparison .require_ring_size (length , 'A2B conversion' )
880877 l = int (Program .prog .options .ring )
881878 r , r_bits = sint .get_edabit (length , size = elements .size )
882879 c = ((elements - r ) << (l - length )).reveal ()
@@ -885,6 +882,8 @@ def __init__(self, elements=None, length=None, input_length=None):
885882 x = sbitintvec .from_vec (r_bits ) + sbitintvec .from_vec (cb )
886883 v = x .v
887884 self .v = v [:length ]
885+ elif isinstance (elements , sbitvec ):
886+ self .v = elements .v
888887 elif elements is not None and not (util .is_constant (elements ) and \
889888 elements == 0 ):
890889 self .v = sbits .trans (elements )
@@ -1347,13 +1346,19 @@ def elements(self):
13471346 def __add__ (self , other ):
13481347 if util .is_zero (other ):
13491348 return self
1350- a , b = self .expand (other )
1349+ try :
1350+ a , b = self .expand (other )
1351+ except :
1352+ return NotImplemented
13511353 v = sbitint .bit_adder (a , b )
13521354 return self .get_type (len (v )).from_vec (v )
13531355 __radd__ = __add__
13541356 __sub__ = _bitint .__sub__
13551357 def __rsub__ (self , other ):
1356- a , b = self .expand (other )
1358+ try :
1359+ a , b = self .expand (other )
1360+ except :
1361+ return NotImplemented
13571362 return self .from_vec (b ) - self .from_vec (a )
13581363 def __mul__ (self , other ):
13591364 if isinstance (other , sbits ):
@@ -1447,7 +1452,7 @@ def output(self):
14471452 inst .print_float_plainb (v , cbits .get_type (32 )(- self .f ), cbits (0 ),
14481453 cbits (0 ), cbits (0 ))
14491454
1450- class sbitfix (_fix ):
1455+ class sbitfix (_fix , _binary ):
14511456 """ Secret signed fixed-point number in one binary register.
14521457 Use :py:obj:`set_precision()` to change the precision.
14531458
@@ -1515,7 +1520,7 @@ class cls(_fix):
15151520 cls .set_precision (f , k )
15161521 return cls ._new (cls .int_type (other ), k , f )
15171522
1518- class sbitfixvec (_fix , _vec ):
1523+ class sbitfixvec (_fix , _vec , _binary ):
15191524 """ Vector of fixed-point numbers for parallel binary computation.
15201525
15211526 Use :py:obj:`set_precision()` to change the precision.
0 commit comments