diff --git a/claasp/components/theta_gaston_component.py b/claasp/components/theta_gaston_component.py index c4145006..6c91c5db 100644 --- a/claasp/components/theta_gaston_component.py +++ b/claasp/components/theta_gaston_component.py @@ -1,3 +1,6 @@ +import os +import pickle +from typing import Any # **************************************************************************** # Copyright 2023 Technology Innovation Institute @@ -22,11 +25,37 @@ from claasp.component import linear_layer_to_binary_matrix from claasp.components.linear_layer_component import LinearLayer +# Global matrix cache +_cached_matrices: dict[str, Any] = {} + +# File to persist the matrix cache +THIS_DIR = os.path.dirname(__file__) +ROOT_DIR = os.path.abspath(os.path.join(THIS_DIR, "..")) +CACHE_DIR = os.path.join(ROOT_DIR, "ciphers", "permutations") +os.makedirs(CACHE_DIR, exist_ok=True) + +def _matrix_cache_path(cipher_id): + return os.path.join(CACHE_DIR, f"gaston_theta_{cipher_id}.pkl") class ThetaGaston(LinearLayer): def __init__(self, current_round_number, current_round_number_of_components, input_id_links, input_bit_positions, output_bit_size, rotation_amounts_parameter): - binary_matrix = linear_layer_to_binary_matrix(THETA_GASTON, output_bit_size, output_bit_size, [rotation_amounts_parameter]) + + matrix_id = "_".join(str(p) for p in rotation_amounts_parameter) + if matrix_id in _cached_matrices: + binary_matrix = _cached_matrices[matrix_id] + else: + path = _matrix_cache_path(matrix_id) + if os.path.exists(path): + with open(path, "rb") as f: + binary_matrix = pickle.load(f) + else: + binary_matrix = linear_layer_to_binary_matrix( + THETA_GASTON, output_bit_size, output_bit_size, [rotation_amounts_parameter]) + with open(path, "wb") as f: + pickle.dump(binary_matrix, f) + + _cached_matrices[matrix_id] = binary_matrix description = list(binary_matrix.transpose()) super().__init__(current_round_number, current_round_number_of_components, input_id_links, input_bit_positions, output_bit_size, description)