Skip to content

Commit 2998b61

Browse files
authored
Removed unused mask_params return value (#1626)
1 parent be69a74 commit 2998b61

File tree

2 files changed

+3
-15
lines changed

2 files changed

+3
-15
lines changed

src/weathergen/datasets/multi_stream_data_sampler.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -379,7 +379,7 @@ def _build_stream_data_input(
379379
stream_data.source_is_spoof = rdata.is_spoof
380380

381381
# preprocess data for model input
382-
(source_cells, source_cells_lens, mask_state) = self.tokenizer.get_source(
382+
(source_cells, source_cells_lens) = self.tokenizer.get_source(
383383
stream_info,
384384
rdata,
385385
token_data,

src/weathergen/datasets/tokenizer_masking.py

Lines changed: 2 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -126,12 +126,7 @@ def get_source(
126126
if is_diagnostic or rdata.data.shape[1] == 0 or len(rdata.data) < 2:
127127
source_tokens_cells = [torch.tensor([])]
128128
source_tokens_lens = torch.zeros([self.num_healpix_cells_source], dtype=torch.int32)
129-
mask_state = {
130-
"strategy": self.masker.current_strategy,
131-
"mask_tokens": None,
132-
"mask_channels": None,
133-
}
134-
return (source_tokens_cells, source_tokens_lens, mask_state)
129+
return (source_tokens_cells, source_tokens_lens)
135130

136131
# create tokenization index
137132
(idxs_cells, idxs_cells_lens) = idxs_cells_data
@@ -154,14 +149,7 @@ def get_source(
154149
encode_times_source,
155150
)
156151

157-
# capture per-view mask state to later produce consistent targets
158-
mask_state = {
159-
"strategy": None, # self.masker.current_strategy,
160-
"mask_tokens": mask_tokens,
161-
"mask_channels": mask_channels,
162-
}
163-
164-
return (source_tokens_cells, source_tokens_lens, mask_state)
152+
return (source_tokens_cells, source_tokens_lens)
165153

166154
# batchify_target_for_view now unified into batchify_target via optional mask_state
167155

0 commit comments

Comments
 (0)