@@ -61,6 +61,7 @@ def build_causal_mask_cache(
     size: int,
     dtype: torch.dtype = torch.float32,
     device: torch.device = None,
+    mask_value: float = float('-inf'),
 ) -> torch.Tensor:
   """Build a cache for causal attention mask.
 
@@ -70,14 +71,16 @@ def build_causal_mask_cache(
           torch.float32.
       device (torch.device, optional): Output tensor's device. Defaults to
           None in which case "cpu" is used.
+      mask_value (float, optional): The value to set the mask to. Defaults to
+          float('-inf').
 
   Returns:
       torch.Tensor: Causal attention mask.
   """
 
   if device is None:
     device = torch.device('cpu')
-  mask = torch.full((size, size), float('-inf'), dtype=dtype, device=device)
+  mask = torch.full((size, size), mask_value, dtype=dtype, device=device)
   return torch.triu(mask, diagonal=1).unsqueeze(0).unsqueeze(0)
 
 
@@ -86,6 +89,7 @@ def build_sliding_window_mask_cache(
     window_size: int,
     dtype: torch.dtype = torch.float32,
     device: torch.device = None,
+    mask_value: float = float('-inf'),
 ) -> torch.Tensor:
   """Build a cache for a sliding window mask.
 
@@ -96,18 +100,20 @@ def build_sliding_window_mask_cache(
           torch.float32.
       device (torch.device, optional): Output tensor's device. Defaults to
           None in which case "cpu" is used.
+      mask_value (float, optional): The value to set the mask to. Defaults to
+          float('-inf').
 
   Returns:
       torch.Tensor: Sliding window attention mask.
   """
 
-  mask = build_causal_mask_cache(size, dtype, device)
+  mask = build_causal_mask_cache(size, dtype, device, mask_value)
   all_ones = torch.ones_like(mask)
   window_size = min(size, window_size)
   sliding_mask = torch.triu(all_ones, -1 * window_size + 1) * torch.tril(
       all_ones, window_size - 1
   )
-  return torch.where(sliding_mask == 1, mask, float('-inf'))
+  return torch.where(sliding_mask == 1, mask, mask_value)
 
 
 def relative_position_bucket(
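For context, a minimal usage sketch of the new `mask_value` parameter follows. This is an assumption-laden sketch, not part of the patch: the `ai_edge_torch.generative.layers.attention_utils` import path and the motivation (keeping masked positions finite for low-precision arithmetic) are not stated in this diff; only the two builder signatures above come from it.

```python
# Sketch of calling the patched mask builders with a custom mask_value.
# The import path below is an assumption; only the function signatures
# are taken from the diff.
import torch

from ai_edge_torch.generative.layers.attention_utils import (
    build_causal_mask_cache,
    build_sliding_window_mask_cache,
)

# Default behavior is unchanged: masked positions hold float('-inf').
causal = build_causal_mask_cache(size=4)
assert causal.shape == (1, 1, 4, 4)
assert causal[0, 0, 0, 1] == float('-inf')  # above the diagonal is masked

# A finite mask value, e.g. the dtype minimum, stays representable in
# low-precision dtypes and avoids infinities in downstream arithmetic.
fp16_min = torch.finfo(torch.float16).min
causal_fp16 = build_causal_mask_cache(
    size=4, dtype=torch.float16, mask_value=fp16_min
)

# The sliding window builder forwards mask_value to the causal builder
# and also applies it to positions outside the window.
sliding = build_sliding_window_mask_cache(
    size=8, window_size=3, mask_value=-1e9
)
```

Note that `build_sliding_window_mask_cache` passes `mask_value` through to `build_causal_mask_cache`, so in-window causal masking and out-of-window masking use the same fill value.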