Commit 0bda69b

ai-edge-bot authored and copybara-github committed
Internal changes only
PiperOrigin-RevId: 743058543
1 parent 5f22c45 commit 0bda69b

File tree

1 file changed: +9 −3 lines changed


ai_edge_torch/generative/layers/attention_utils.py

Lines changed: 9 additions & 3 deletions
@@ -61,6 +61,7 @@ def build_causal_mask_cache(
     size: int,
     dtype: torch.dtype = torch.float32,
     device: torch.device = None,
+    mask_value: float = float('-inf'),
 ) -> torch.Tensor:
   """Build a cache for causal attention mask.
 
@@ -70,14 +71,16 @@ def build_causal_mask_cache(
       torch.float32.
     device (torch.device, optional): Output tensor's data type. Defaults to
       None in which case "cpu" is used.
+    mask_value (float, optional): The value to set the mask to. Defaults to
+      float('-inf').
 
   Returns:
     torch.Tensor: Causal attention mask.
   """
 
   if device is None:
     device = torch.device('cpu')
-  mask = torch.full((size, size), float('-inf'), dtype=dtype, device=device)
+  mask = torch.full((size, size), mask_value, dtype=dtype, device=device)
   return torch.triu(mask, diagonal=1).unsqueeze(0).unsqueeze(0)
 
 
@@ -86,6 +89,7 @@ def build_sliding_window_mask_cache(
     window_size: int,
     dtype: torch.dtype = torch.float32,
     device: torch.device = None,
+    mask_value: float = float('-inf'),
 ) -> torch.Tensor:
   """Build a cache for a sliding window mask.
 
@@ -96,18 +100,20 @@ def build_sliding_window_mask_cache(
       torch.float32.
     device (torch.device, optional): Output tensor's data type. Defaults to
       None in which case "cpu" is used.
+    mask_value (float, optional): The value to set the mask to. Defaults to
+      float('-inf').
 
   Returns:
     torch.Tensor: Causal attention mask.
   """
 
-  mask = build_causal_mask_cache(size, dtype, device)
+  mask = build_causal_mask_cache(size, dtype, device, mask_value)
   all_ones = torch.ones_like(mask)
   window_size = min(size, window_size)
   sliding_mask = torch.triu(all_ones, -1 * window_size + 1) * torch.tril(
       all_ones, window_size - 1
   )
-  return torch.where(sliding_mask == 1, mask, float('-inf'))
+  return torch.where(sliding_mask == 1, mask, mask_value)
 
 
 def relative_position_bucket(
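
For reference, a minimal usage sketch of the two updated helpers. This example is not part of the commit; the finite fill value -1e9 is an illustrative choice for callers that prefer a large negative number over float('-inf').

import torch

from ai_edge_torch.generative.layers import attention_utils

# Default behavior is unchanged: masked positions are filled with float('-inf').
causal = attention_utils.build_causal_mask_cache(size=8)

# The new mask_value argument lets callers pick a finite fill value instead
# (illustrative), e.g. when -inf is awkward for a target dtype or backend.
causal_finite = attention_utils.build_causal_mask_cache(size=8, mask_value=-1e9)

# The sliding-window variant forwards mask_value to build_causal_mask_cache
# and also uses it for positions outside the window.
sliding = attention_utils.build_sliding_window_mask_cache(
    size=8, window_size=4, mask_value=-1e9
)

# Both masks are broadcastable over (batch, heads, seq, seq).
print(causal.shape, sliding.shape)  # torch.Size([1, 1, 8, 8]) for both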
