Skip to content

Commit 1bc7bbb

Browse files
authored
Merge pull request #129 from NoahSchiro/fp16-fix
BUG FIX: fp16 "TypeError: unflatten_dense_tensors(): argument 'tensors' (position 2) must be tuple of Tensors, not generator"
2 parents 783b674 + 7bac936 commit 1bc7bbb

File tree

1 file changed

+1
-1
lines changed

1 file changed

+1
-1
lines changed

improved_diffusion/fp16_util.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -65,7 +65,7 @@ def unflatten_master_params(model_params, master_params):
6565
"""
6666
Unflatten the master parameters to look like model_params.
6767
"""
68-
return _unflatten_dense_tensors(master_params[0].detach(), model_params)
68+
return _unflatten_dense_tensors(master_params[0].detach(), tuple(tensor for tensor in model_params))
6969

7070

7171
def zero_grad(model_params):

0 commit comments

Comments
 (0)