From 15a2f128331195baac9580c267724a66b140d522 Mon Sep 17 00:00:00 2001 From: Victor Date: Thu, 25 Jul 2024 15:13:11 +0800 Subject: [PATCH 1/3] feat: add exactly sampler for FATE 1.x --- python/fate_arch/computing/eggroll/_table.py | 76 ++++++++++++++------ python/fate_arch/computing/spark/_table.py | 5 +- 2 files changed, 60 insertions(+), 21 deletions(-) diff --git a/python/fate_arch/computing/eggroll/_table.py b/python/fate_arch/computing/eggroll/_table.py index a33af09302..49650df4cc 100644 --- a/python/fate_arch/computing/eggroll/_table.py +++ b/python/fate_arch/computing/eggroll/_table.py @@ -126,32 +126,17 @@ def glom(self, **kwargs): @computing_profile def sample(self, *, fraction: typing.Optional[float] = None, num: typing.Optional[int] = None, seed=None): + if fraction is not None and num is not None: + raise ValueError("specify only one of `fraction` or `num`, not both.") + if fraction is not None: return Table(self._rp.sample(fraction=fraction, seed=seed)) if num is not None: - total = self._rp.count() - if num > total: - raise ValueError(f"not enough data to sample, own {total} but required {num}") - - frac = num / float(total) - while True: - sampled_table = self._rp.sample(fraction=frac, seed=seed) - sampled_count = sampled_table.count() - if sampled_count < num: - frac *= 1.1 - else: - break - - if sampled_count > num: - drops = sampled_table.take(sampled_count - num) - for k, v in drops: - sampled_table.delete(k) - - return Table(sampled_table) + return self._exactly_sample(num, seed) raise ValueError(f"exactly one of `fraction` or `num` required, fraction={fraction}, num={num}") - + @computing_profile def subtractByKey(self, other: 'Table', **kwargs): return Table(self._rp.subtract_by_key(other._rp)) @@ -169,3 +154,54 @@ def flatMap(self, func, **kwargs): flat_map = self._rp.flat_map(func) shuffled = flat_map.map(lambda k, v: (k, v)) # trigger shuffle return Table(shuffled) + + def _exactly_sample(self, num: int, seed: int): + split_size = list(self._rp.map_partitions_with_index( + lambda s, it: [(s, sum(1 for _ in it))] + ).get_all()) + LOGGER.info(f"{split_size}") + + if not split_size: + raise ValueError("no data available to sample") + + total = sum(v for _, v in split_size) + if num > total: + raise ValueError(f"not enough data to sample, own {total} but required {num}") + + sampled_size = {} + for split, size in split_size: + if size <= 0: + sampled_size[split] = 0 + else: + if num == 0: + sampled_size[split] = 0 + else: + sampled_size[split] = hypergeom.rvs(M=total, n=size, N=num) + total -= size + num -= sampled_size[split] + + LOGGER.info(f"{sampled_size}") + + return self._rp.map_partitions_with_index(self._reservoir_sample_func(sampled_size, seed)) + + def _reservoir_sample_func(self, split_sample_size: dict, seed=None): + def func(split, iterator): + size = split_sample_size[split] + sample = [] + random_seed = seed + + if random_seed is None: + random_seed = random.randint(0, sys.maxsize) + random_state = random.Random(random_seed ^ split) + + for counter, obj in enumerate(iterator, start=1): + if len(sample) < size: + sample.append(obj) + else: + randint = random_state.randint(1, counter) + if randint <= size: + sample[randint - 1] = obj + + return iter(sample) + + return func diff --git a/python/fate_arch/computing/spark/_table.py b/python/fate_arch/computing/spark/_table.py index 0a25c34431..5b64f17ef6 100644 --- a/python/fate_arch/computing/spark/_table.py +++ b/python/fate_arch/computing/spark/_table.py @@ -314,7 +314,10 @@ def _exactly_sample(rdd, num: int, seed: int): # random the size of each split sampled_size = {} for split, size in split_size.items(): - sampled_size[split] = hypergeom.rvs(M=total, n=size, N=num) + if num == 0: + sampled_size[split] = 0 + else: + sampled_size[split] = hypergeom.rvs(M=total, n=size, N=num) total = total - size num = num - sampled_size[split] From f95660dfce94c4e5b4a88944f4255357ebe86e5d Mon Sep 17 00:00:00 2001 From: Surfacebuaa <408083193@qq.com> Date: Thu, 25 Jul 2024 16:19:09 +0800 Subject: [PATCH 2/3] feat: add OU homomorphic encryption for FATE 1.x --- python/federatedml/secureprotol/fate_ou.py | 374 +++++++++++++++++++++ 1 file changed, 374 insertions(+) create mode 100644 python/federatedml/secureprotol/fate_ou.py diff --git a/python/federatedml/secureprotol/fate_ou.py b/python/federatedml/secureprotol/fate_ou.py new file mode 100644 index 0000000000..2cd51b446a --- /dev/null +++ b/python/federatedml/secureprotol/fate_ou.py @@ -0,0 +1,374 @@ +"""OU encryption library for partially homomorphic encryption.""" + +import numpy as np +import random + +from federatedml.secureprotol import gmpy_math +from federatedml.secureprotol.fixedpoint import FixedPointNumber + + +# according to this paper +# << Accelerating Okamoto-Uchiyama’s Public-Key Cryptosystem >> +# and NIST's recommendation: +# https://www.keylength.com/en/4/ +# 160 bits for key size 1024 +# 224 bits for key size 2048 +# 256 bits for key size 3072 +kPrimeFactorSize1024 = 160 +kPrimeFactorSize2048 = 224 +kPrimeFactorSize3072 = 256 + +class OUKeypair(object): + def __init__(self): + pass + + @staticmethod + def random_monic_exact_bits(bits): + global last_generated + new_value = random.getrandbits(bits) + + if 'last_generated' not in globals(): + last_generated = new_value + else: + if new_value <= last_generated: + new_value = last_generated + 1 + + last_generated = new_value + return new_value + + def generate_keypair(self, n_length=1024): + """return a new :class:`OUPublicKey` and :class:`OUPrivateKey`. + """ + secret_size = (n_length + 2) // 3 + + prime_factor_size = kPrimeFactorSize1024 + if n_length >= 3072: + prime_factor_size = kPrimeFactorSize3072 + elif n_length >= 2048: + prime_factor_size = kPrimeFactorSize2048 + + assert prime_factor_size * 2 <= secret_size, \ + "Key size must be larger than {} bits".format(prime_factor_size * 2 * 3 - 2) + + # generate p + while True: + prime_factor = gmpy_math.getprimeover(prime_factor_size) + # bits_of(a * b) <= bits_of(a) + bits_of(b), + # So we add extra two bits to u: + # one bit for prime_factor * u; another one bit for p^2; + # Also, make sure that u > prime_factor + u = self.random_monic_exact_bits(secret_size - prime_factor_size + 2) # p - 1 has a large prime factor + p = prime_factor * u + 1 + + if gmpy_math.is_prime(p): + break + + # since bits_of(a * b) <= bits_of(a) + bits_of(b) + # add another 1 bit for q + q = gmpy_math.getprimeover(secret_size + 1) + p_square = p ** 2 + t = prime_factor + n = p_square * q + + # calculate g_p + while True: + while True: + g = random.randint(1, n-1) + gcd = np.gcd(g, p) + if gcd == 1: + break + + gp = gmpy_math.powmod(g % p_square, p - 1, p_square) + check = gmpy_math.powmod(gp, p, p_square) + + if check == 1: + break + + # calculate G + capital_g = gmpy_math.powmod(g, u, n) + + while True: + g = random.randint(1, n-1) + if g % p != 0: + break + + # calculate H + capital_h = gmpy_math.powmod(g, n * u, n) + + # max_plaintext_ must be a power of 2, for ease of use + max_plaintext = pow(10, prime_factor_size // 2) // 2 + + public_key = OUPublicKey(n, capital_g, capital_h, max_plaintext) + private_key = OUPrivateKey(public_key, p, q, t, gp, max_plaintext) + + return public_key, private_key + + +class OUPublicKey(object): + """Contains a public key and associated encryption methods. + """ + + def __init__(self, n, capital_g, capital_h, max_plaintext): + self.n = n # n = p^2 * q + self.capital_g = capital_g # G = g^u mod n for some random g \in [0, n) + self.capital_h = capital_h # H = g'^{n*u} mod n for some random g' \in [0, n) + self.max_plaintext = max_plaintext # always power of 2, e.g. max_plaintext_ == 2^681 + + def __repr__(self): + hashcode = hex(hash(self))[2:] + + return "".format(hashcode[:10]) + + def __eq__(self, other): + return self.n == other.n and self.capital_g == other.capital_g and self.capital_h == other.capital_h + + def __hash__(self): + return hash(self.n) + + # multi H^r + # r is a random number < n + # H and n is public key + def apply_obfuscator(self, ciphertext, random_value=None): + """ + """ + r = random_value or random.SystemRandom().randrange(1, self.n) + obfuscator = gmpy_math.powmod(self.capital_h, r, self.n) + + return (ciphertext * obfuscator) % self.n + + def raw_encrypt(self, plaintext, random_value=None): + """ + """ + if not isinstance(plaintext, int): + raise TypeError("plaintext should be int, but got: %s" % + type(plaintext)) + + if plaintext >= self.max_plaintext: + plaintext -= self.max_plaintext * 2 + + gm = gmpy_math.powmod(self.capital_g, plaintext, self.n) + + ciphertext = self.apply_obfuscator(gm, random_value) + + return ciphertext + + def encrypt(self, value, precision=None, random_value=None): + """Encode and OU encrypt a real number value. + """ + if isinstance(value, FixedPointNumber): + value = value.decode() + encoding = FixedPointNumber.encode(value, self.max_plaintext * 2, self.max_plaintext, precision) + obfuscator = random_value or 1 + ciphertext = self.raw_encrypt(encoding.encoding, random_value=obfuscator) + encryptednumber = OUEncryptedNumber(self, ciphertext, encoding.exponent) + + return encryptednumber + + +class OUPrivateKey(object): + """Contains a private key and associated decryption method. + """ + + def __init__(self, public_key, p, q, t, gp, max_plaintext): + self.public_key = public_key + self.p = p + self.q = q # primes such that log2(p), log2(q) ~ n_bits / 3 + self.t = t # a big prime factor of p - 1, i.e., p = t * u + 1 + self.gp = gp + self.gp_inv = gmpy_math.invert((self.gp - 1) // p, p) # L(g^{p-1} mod p^2))^{-1} mod p + self.p_square = p ** 2 + self.max_plaintext = max_plaintext + + def __repr__(self): + hashcode = hex(hash(self))[2:] + + return "".format(hashcode[:10]) + + def __eq__(self, other): + return self.p == other.p and self.q == other.q and self.t == other.t and self.gp_inv == other.gp_inv + + def __hash__(self): + return hash((self.p, self.q)) + + def raw_decrypt(self, ciphertext): + """return raw plaintext. + """ + if not isinstance(ciphertext, int): + raise TypeError("ciphertext should be an int, not: %s" % + type(ciphertext)) + + plaintext = 0 + + ct = gmpy_math.powmod(ciphertext % self.p_square, self.t, self.p_square) + + plaintext = ((ct // self.p) * self.gp_inv) % self.p + + if plaintext >= self.p / 2: + plaintext -= self.p + if plaintext >= self.max_plaintext: + plaintext = plaintext % (self.max_plaintext * 2) + + return plaintext + + def decrypt(self, encrypted_number): + """return the decrypted & decoded plaintext of encrypted_number. + """ + if not isinstance(encrypted_number, OUEncryptedNumber): + raise TypeError("encrypted_number should be an OUEncryptedNumber, \ + not: %s" % type(encrypted_number)) + + if self.public_key != encrypted_number.public_key: + raise ValueError("encrypted_number was encrypted against a different key!") + + encoded = self.raw_decrypt(encrypted_number.ciphertext(be_secure=False)) + encoded = FixedPointNumber(encoded, + encrypted_number.exponent, + self.public_key.max_plaintext * 2, + self.public_key.max_plaintext) + decrypt_value = encoded.decode() + + return decrypt_value + + +class OUEncryptedNumber(object): + """Represents the OU encryption of a float or int. + """ + + def __init__(self, public_key, ciphertext, exponent=0): + self.public_key = public_key + self.__ciphertext = ciphertext + self.exponent = exponent + self.__is_obfuscator = False + + if not isinstance(self.__ciphertext, int): + raise TypeError("ciphertext should be an int, not: %s" % type(self.__ciphertext)) + + if not isinstance(self.public_key, OUPublicKey): + raise TypeError("public_key should be a OUPublicKey, not: %s" % type(self.public_key)) + + def ciphertext(self, be_secure=True): + """return the ciphertext of the OUEncryptedNumber. + """ + if be_secure and not self.__is_obfuscator: + self.apply_obfuscator() + + return self.__ciphertext + + def apply_obfuscator(self): + """ciphertext by multiplying by H ** r with random r + """ + self.__ciphertext = self.public_key.apply_obfuscator(self.__ciphertext) + self.__is_obfuscator = True + + def __add__(self, other): + if isinstance(other, OUEncryptedNumber): + return self.__add_encryptednumber(other) + else: + return self.__add_scalar(other) + + def __radd__(self, other): + return self.__add__(other) + + def __sub__(self, other): + + return self + (other * -1) + + def __rsub__(self, other): + return other + (self * -1) + + def __rmul__(self, scalar): + return self.__mul__(scalar) + + def __truediv__(self, scalar): + return self.__mul__(1 / scalar) + + def __mul__(self, scalar): + """return Multiply by an scalar(such as int, float) + """ + if isinstance(scalar, FixedPointNumber): + scalar = scalar.decode() + encode = FixedPointNumber.encode(scalar, self.public_key.max_plaintext * 2, self.public_key.max_plaintext) + plaintext = encode.encoding + + if plaintext < 0 or plaintext >= (self.public_key.max_plaintext * 2): + raise ValueError("Scalar out of bounds: %i" % plaintext) + + if plaintext > self.public_key.max_plaintext: + # Very large plaintext, play a sneaky trick using inverses + plaintext -= self.public_key.max_plaintext * 2 + + ciphertext = gmpy_math.powmod(self.ciphertext(False), plaintext, self.public_key.n) + + exponent = self.exponent + encode.exponent + + return OUEncryptedNumber(self.public_key, ciphertext, exponent) + + def increase_exponent_to(self, new_exponent): + """return OUEncryptedNumber: + new OUEncryptedNumber with same value but having great exponent. + """ + if new_exponent < self.exponent: + raise ValueError("New exponent %i should be great than old exponent %i" % (new_exponent, self.exponent)) + + factor = pow(FixedPointNumber.BASE, new_exponent - self.exponent) + new_encryptednumber = self.__mul__(factor) + new_encryptednumber.exponent = new_exponent + + return new_encryptednumber + + def __align_exponent(self, x, y): + """return x,y with same exponet + """ + if x.exponent < y.exponent: + x = x.increase_exponent_to(y.exponent) + elif x.exponent > y.exponent: + y = y.increase_exponent_to(x.exponent) + + return x, y + + def __add_scalar(self, scalar): + """return OUEncryptedNumber: z = E(x) + y + """ + if isinstance(scalar, FixedPointNumber): + scalar = scalar.decode() + + encoded = FixedPointNumber.encode(scalar, + self.public_key.max_plaintext * 2, + self.public_key.max_plaintext, + max_exponent=self.exponent) + + return self.__add_fixpointnumber(encoded) + + def __add_fixpointnumber(self, encoded): + """return OUEncryptedNumber: z = E(x) + FixedPointNumber(y) + # """ + if self.public_key.max_plaintext != encoded.max_int: + raise ValueError("Attempted to add numbers encoded against different public keys!") + + # their exponents must match, and align. + x, y = self.__align_exponent(self, encoded) + + encrypted_scalar = x.public_key.raw_encrypt(y.encoding, 1) + encryptednumber = self.__raw_add(x.ciphertext(False), encrypted_scalar, x.exponent) + + return encryptednumber + + def __add_encryptednumber(self, other): + """return OUEncryptedNumber: z = E(x) + E(y) + """ + if self.public_key != other.public_key: + raise ValueError("add two numbers have different public key!") + + # their exponents must match, and align. + x, y = self.__align_exponent(self, other) + + encryptednumber = self.__raw_add(x.ciphertext(False), y.ciphertext(False), x.exponent) + + return encryptednumber + + def __raw_add(self, e_x, e_y, exponent): + """return the integer E(x + y) given ints E(x) and E(y). + """ + ciphertext = gmpy_math.mpz(e_x) * gmpy_math.mpz(e_y) % self.public_key.n + + return OUEncryptedNumber(self.public_key, int(ciphertext), exponent) From 86c9c8e213bcde2ff15208a2a2cc34efc6736570 Mon Sep 17 00:00:00 2001 From: Surfacebuaa <408083193@qq.com> Date: Fri, 26 Jul 2024 09:54:53 +0800 Subject: [PATCH 3/3] Revert "feat: add OU homomorphic encryption for FATE 1.x" This reverts commit f95660dfce94c4e5b4a88944f4255357ebe86e5d. --- python/federatedml/secureprotol/fate_ou.py | 374 --------------------- 1 file changed, 374 deletions(-) delete mode 100644 python/federatedml/secureprotol/fate_ou.py diff --git a/python/federatedml/secureprotol/fate_ou.py b/python/federatedml/secureprotol/fate_ou.py deleted file mode 100644 index 2cd51b446a..0000000000 --- a/python/federatedml/secureprotol/fate_ou.py +++ /dev/null @@ -1,374 +0,0 @@ -"""OU encryption library for partially homomorphic encryption.""" - -import numpy as np -import random - -from federatedml.secureprotol import gmpy_math -from federatedml.secureprotol.fixedpoint import FixedPointNumber - - -# according to this paper -# << Accelerating Okamoto-Uchiyama’s Public-Key Cryptosystem >> -# and NIST's recommendation: -# https://www.keylength.com/en/4/ -# 160 bits for key size 1024 -# 224 bits for key size 2048 -# 256 bits for key size 3072 -kPrimeFactorSize1024 = 160 -kPrimeFactorSize2048 = 224 -kPrimeFactorSize3072 = 256 - -class OUKeypair(object): - def __init__(self): - pass - - @staticmethod - def random_monic_exact_bits(bits): - global last_generated - new_value = random.getrandbits(bits) - - if 'last_generated' not in globals(): - last_generated = new_value - else: - if new_value <= last_generated: - new_value = last_generated + 1 - - last_generated = new_value - return new_value - - def generate_keypair(self, n_length=1024): - """return a new :class:`OUPublicKey` and :class:`OUPrivateKey`. - """ - secret_size = (n_length + 2) // 3 - - prime_factor_size = kPrimeFactorSize1024 - if n_length >= 3072: - prime_factor_size = kPrimeFactorSize3072 - elif n_length >= 2048: - prime_factor_size = kPrimeFactorSize2048 - - assert prime_factor_size * 2 <= secret_size, \ - "Key size must be larger than {} bits".format(prime_factor_size * 2 * 3 - 2) - - # generate p - while True: - prime_factor = gmpy_math.getprimeover(prime_factor_size) - # bits_of(a * b) <= bits_of(a) + bits_of(b), - # So we add extra two bits to u: - # one bit for prime_factor * u; another one bit for p^2; - # Also, make sure that u > prime_factor - u = self.random_monic_exact_bits(secret_size - prime_factor_size + 2) # p - 1 has a large prime factor - p = prime_factor * u + 1 - - if gmpy_math.is_prime(p): - break - - # since bits_of(a * b) <= bits_of(a) + bits_of(b) - # add another 1 bit for q - q = gmpy_math.getprimeover(secret_size + 1) - p_square = p ** 2 - t = prime_factor - n = p_square * q - - # calculate g_p - while True: - while True: - g = random.randint(1, n-1) - gcd = np.gcd(g, p) - if gcd == 1: - break - - gp = gmpy_math.powmod(g % p_square, p - 1, p_square) - check = gmpy_math.powmod(gp, p, p_square) - - if check == 1: - break - - # calculate G - capital_g = gmpy_math.powmod(g, u, n) - - while True: - g = random.randint(1, n-1) - if g % p != 0: - break - - # calculate H - capital_h = gmpy_math.powmod(g, n * u, n) - - # max_plaintext_ must be a power of 2, for ease of use - max_plaintext = pow(10, prime_factor_size // 2) // 2 - - public_key = OUPublicKey(n, capital_g, capital_h, max_plaintext) - private_key = OUPrivateKey(public_key, p, q, t, gp, max_plaintext) - - return public_key, private_key - - -class OUPublicKey(object): - """Contains a public key and associated encryption methods. - """ - - def __init__(self, n, capital_g, capital_h, max_plaintext): - self.n = n # n = p^2 * q - self.capital_g = capital_g # G = g^u mod n for some random g \in [0, n) - self.capital_h = capital_h # H = g'^{n*u} mod n for some random g' \in [0, n) - self.max_plaintext = max_plaintext # always power of 2, e.g. max_plaintext_ == 2^681 - - def __repr__(self): - hashcode = hex(hash(self))[2:] - - return "".format(hashcode[:10]) - - def __eq__(self, other): - return self.n == other.n and self.capital_g == other.capital_g and self.capital_h == other.capital_h - - def __hash__(self): - return hash(self.n) - - # multi H^r - # r is a random number < n - # H and n is public key - def apply_obfuscator(self, ciphertext, random_value=None): - """ - """ - r = random_value or random.SystemRandom().randrange(1, self.n) - obfuscator = gmpy_math.powmod(self.capital_h, r, self.n) - - return (ciphertext * obfuscator) % self.n - - def raw_encrypt(self, plaintext, random_value=None): - """ - """ - if not isinstance(plaintext, int): - raise TypeError("plaintext should be int, but got: %s" % - type(plaintext)) - - if plaintext >= self.max_plaintext: - plaintext -= self.max_plaintext * 2 - - gm = gmpy_math.powmod(self.capital_g, plaintext, self.n) - - ciphertext = self.apply_obfuscator(gm, random_value) - - return ciphertext - - def encrypt(self, value, precision=None, random_value=None): - """Encode and OU encrypt a real number value. - """ - if isinstance(value, FixedPointNumber): - value = value.decode() - encoding = FixedPointNumber.encode(value, self.max_plaintext * 2, self.max_plaintext, precision) - obfuscator = random_value or 1 - ciphertext = self.raw_encrypt(encoding.encoding, random_value=obfuscator) - encryptednumber = OUEncryptedNumber(self, ciphertext, encoding.exponent) - - return encryptednumber - - -class OUPrivateKey(object): - """Contains a private key and associated decryption method. - """ - - def __init__(self, public_key, p, q, t, gp, max_plaintext): - self.public_key = public_key - self.p = p - self.q = q # primes such that log2(p), log2(q) ~ n_bits / 3 - self.t = t # a big prime factor of p - 1, i.e., p = t * u + 1 - self.gp = gp - self.gp_inv = gmpy_math.invert((self.gp - 1) // p, p) # L(g^{p-1} mod p^2))^{-1} mod p - self.p_square = p ** 2 - self.max_plaintext = max_plaintext - - def __repr__(self): - hashcode = hex(hash(self))[2:] - - return "".format(hashcode[:10]) - - def __eq__(self, other): - return self.p == other.p and self.q == other.q and self.t == other.t and self.gp_inv == other.gp_inv - - def __hash__(self): - return hash((self.p, self.q)) - - def raw_decrypt(self, ciphertext): - """return raw plaintext. - """ - if not isinstance(ciphertext, int): - raise TypeError("ciphertext should be an int, not: %s" % - type(ciphertext)) - - plaintext = 0 - - ct = gmpy_math.powmod(ciphertext % self.p_square, self.t, self.p_square) - - plaintext = ((ct // self.p) * self.gp_inv) % self.p - - if plaintext >= self.p / 2: - plaintext -= self.p - if plaintext >= self.max_plaintext: - plaintext = plaintext % (self.max_plaintext * 2) - - return plaintext - - def decrypt(self, encrypted_number): - """return the decrypted & decoded plaintext of encrypted_number. - """ - if not isinstance(encrypted_number, OUEncryptedNumber): - raise TypeError("encrypted_number should be an OUEncryptedNumber, \ - not: %s" % type(encrypted_number)) - - if self.public_key != encrypted_number.public_key: - raise ValueError("encrypted_number was encrypted against a different key!") - - encoded = self.raw_decrypt(encrypted_number.ciphertext(be_secure=False)) - encoded = FixedPointNumber(encoded, - encrypted_number.exponent, - self.public_key.max_plaintext * 2, - self.public_key.max_plaintext) - decrypt_value = encoded.decode() - - return decrypt_value - - -class OUEncryptedNumber(object): - """Represents the OU encryption of a float or int. - """ - - def __init__(self, public_key, ciphertext, exponent=0): - self.public_key = public_key - self.__ciphertext = ciphertext - self.exponent = exponent - self.__is_obfuscator = False - - if not isinstance(self.__ciphertext, int): - raise TypeError("ciphertext should be an int, not: %s" % type(self.__ciphertext)) - - if not isinstance(self.public_key, OUPublicKey): - raise TypeError("public_key should be a OUPublicKey, not: %s" % type(self.public_key)) - - def ciphertext(self, be_secure=True): - """return the ciphertext of the OUEncryptedNumber. - """ - if be_secure and not self.__is_obfuscator: - self.apply_obfuscator() - - return self.__ciphertext - - def apply_obfuscator(self): - """ciphertext by multiplying by H ** r with random r - """ - self.__ciphertext = self.public_key.apply_obfuscator(self.__ciphertext) - self.__is_obfuscator = True - - def __add__(self, other): - if isinstance(other, OUEncryptedNumber): - return self.__add_encryptednumber(other) - else: - return self.__add_scalar(other) - - def __radd__(self, other): - return self.__add__(other) - - def __sub__(self, other): - - return self + (other * -1) - - def __rsub__(self, other): - return other + (self * -1) - - def __rmul__(self, scalar): - return self.__mul__(scalar) - - def __truediv__(self, scalar): - return self.__mul__(1 / scalar) - - def __mul__(self, scalar): - """return Multiply by an scalar(such as int, float) - """ - if isinstance(scalar, FixedPointNumber): - scalar = scalar.decode() - encode = FixedPointNumber.encode(scalar, self.public_key.max_plaintext * 2, self.public_key.max_plaintext) - plaintext = encode.encoding - - if plaintext < 0 or plaintext >= (self.public_key.max_plaintext * 2): - raise ValueError("Scalar out of bounds: %i" % plaintext) - - if plaintext > self.public_key.max_plaintext: - # Very large plaintext, play a sneaky trick using inverses - plaintext -= self.public_key.max_plaintext * 2 - - ciphertext = gmpy_math.powmod(self.ciphertext(False), plaintext, self.public_key.n) - - exponent = self.exponent + encode.exponent - - return OUEncryptedNumber(self.public_key, ciphertext, exponent) - - def increase_exponent_to(self, new_exponent): - """return OUEncryptedNumber: - new OUEncryptedNumber with same value but having great exponent. - """ - if new_exponent < self.exponent: - raise ValueError("New exponent %i should be great than old exponent %i" % (new_exponent, self.exponent)) - - factor = pow(FixedPointNumber.BASE, new_exponent - self.exponent) - new_encryptednumber = self.__mul__(factor) - new_encryptednumber.exponent = new_exponent - - return new_encryptednumber - - def __align_exponent(self, x, y): - """return x,y with same exponet - """ - if x.exponent < y.exponent: - x = x.increase_exponent_to(y.exponent) - elif x.exponent > y.exponent: - y = y.increase_exponent_to(x.exponent) - - return x, y - - def __add_scalar(self, scalar): - """return OUEncryptedNumber: z = E(x) + y - """ - if isinstance(scalar, FixedPointNumber): - scalar = scalar.decode() - - encoded = FixedPointNumber.encode(scalar, - self.public_key.max_plaintext * 2, - self.public_key.max_plaintext, - max_exponent=self.exponent) - - return self.__add_fixpointnumber(encoded) - - def __add_fixpointnumber(self, encoded): - """return OUEncryptedNumber: z = E(x) + FixedPointNumber(y) - # """ - if self.public_key.max_plaintext != encoded.max_int: - raise ValueError("Attempted to add numbers encoded against different public keys!") - - # their exponents must match, and align. - x, y = self.__align_exponent(self, encoded) - - encrypted_scalar = x.public_key.raw_encrypt(y.encoding, 1) - encryptednumber = self.__raw_add(x.ciphertext(False), encrypted_scalar, x.exponent) - - return encryptednumber - - def __add_encryptednumber(self, other): - """return OUEncryptedNumber: z = E(x) + E(y) - """ - if self.public_key != other.public_key: - raise ValueError("add two numbers have different public key!") - - # their exponents must match, and align. - x, y = self.__align_exponent(self, other) - - encryptednumber = self.__raw_add(x.ciphertext(False), y.ciphertext(False), x.exponent) - - return encryptednumber - - def __raw_add(self, e_x, e_y, exponent): - """return the integer E(x + y) given ints E(x) and E(y). - """ - ciphertext = gmpy_math.mpz(e_x) * gmpy_math.mpz(e_y) % self.public_key.n - - return OUEncryptedNumber(self.public_key, int(ciphertext), exponent)