Skip to content

Commit 5432f60

Browse files
committed
fix(security): harden checkpoint loading and fix LoRA initialization bugs
- loader: use precise ".pt" detection (glob pattern "*.pt", file suffix ".pt")
- loader: enforce torch.load onto CPU with weights_only=True for safety
- lora: use torch.rand for float dtypes in random() to avoid a runtime error
- cleanup: remove unused code and simplify operations
1 parent 46ff6d0 commit 5432f60

File tree

2 files changed

+13
-9
lines changed

2 files changed

+13
-9
lines changed

ai_edge_torch/generative/layers/lora.py

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -268,9 +268,7 @@ def random(
268268
LoRA weights with random values.
269269
"""
270270
return cls._from_tensor_generator(
271-
tensor_generator=lambda shape, dtype: torch.randint(
272-
low=0, high=128, size=shape, dtype=dtype
273-
),
271+
tensor_generator=lambda shape, dtype: torch.rand(shape, dtype=dtype),
274272
rank=rank,
275273
config=config,
276274
dtype=dtype,

ai_edge_torch/generative/utilities/loader.py

Lines changed: 12 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -48,11 +48,15 @@ def get_custom_loader(
4848
if checkpoint_format == "safetensors":
4949
return load_file
5050
if checkpoint_format == "pt":
51-
return lambda path: torch.load(path, weights_only=True)
51+
return lambda path: torch.load(
52+
path, weights_only=True, map_location=torch.device("cpu")
53+
)
5254
raise ValueError(f"Unsupported checkpoint format: {checkpoint_format}")
5355

5456
if os.path.splitext(checkpoint_path)[1] in [".bin", ".pt", ".ckpt"]:
55-
return lambda path: torch.load(path, weights_only=True)
57+
return lambda path: torch.load(
58+
path, weights_only=True, map_location=torch.device("cpu")
59+
)
5660
if checkpoint_path.endswith(".safetensors"):
5761
return load_file
5862
raise ValueError(f"Unsupported checkpoint format: {checkpoint_path}")
@@ -126,7 +130,7 @@ def load_pytorch_statedict(full_path: str):
126130
patterns = []
127131
if os.path.isdir(full_path):
128132
patterns.append(os.path.join(full_path, "*.bin"))
129-
patterns.append(os.path.join(full_path, "*pt"))
133+
patterns.append(os.path.join(full_path, "*.pt"))
130134
else:
131135
patterns.append(full_path)
132136
for pattern in patterns:
@@ -135,7 +139,9 @@ def load_pytorch_statedict(full_path: str):
135139

136140
tensors = {}
137141
for file in files:
138-
this_file_tensors = torch.load(file)
142+
this_file_tensors = torch.load(
143+
file, map_location=torch.device("cpu"), weights_only=True
144+
)
139145
for k in this_file_tensors:
140146
assert k not in tensors
141147
tensors.update(this_file_tensors)
@@ -279,14 +285,14 @@ def _get_loader(self) -> Callable[[str], Dict[str, torch.Tensor]]:
279285
if glob.glob(os.path.join(self._file_name, "*.safetensors")):
280286
return load_safetensors
281287
if glob.glob(os.path.join(self._file_name, "*.bin")) or glob.glob(
282-
os.path.join(self._file_name, "*pt")
288+
os.path.join(self._file_name, "*.pt")
283289
):
284290
return load_pytorch_statedict
285291

286292
if self._file_name.endswith(".safetensors"):
287293
return load_safetensors
288294

289-
if self._file_name.endswith(".bin") or self._file_name.endswith("pt"):
295+
if self._file_name.endswith(".bin") or self._file_name.endswith(".pt"):
290296
return load_pytorch_statedict
291297

292298
raise ValueError("File format not supported.")

0 commit comments

Comments (0)