Skip to content

Commit 8d6a4e0

Browse files
danielsuoFlax Authors
authored andcommitted
[flax:examples:seq2seq] Create main and default config based on seq2seq.ipynb.
PiperOrigin-RevId: 838872307
1 parent 697f4e5 commit 8d6a4e0

File tree

2 files changed

+84
-0
lines changed

2 files changed

+84
-0
lines changed
Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,32 @@
1+
# Copyright 2025 The Flax Authors.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
"""Default Hyperparameter configuration."""
16+
17+
import ml_collections
18+
19+
20+
def get_config():
21+
"""Get the default hyperparameter configuration."""
22+
config = ml_collections.ConfigDict()
23+
24+
config.workdir = '/tmp/seq2seq'
25+
config.learning_rate = 0.003
26+
config.batch_size = 128
27+
config.hidden_size = 512
28+
config.num_train_steps = 10000
29+
config.decode_frequency = 200
30+
config.max_len_query_digit = 3
31+
32+
return config

examples/seq2seq/main.py

Lines changed: 52 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,52 @@
1+
# Copyright 2025 The Flax Authors.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
"""Main script for seq2seq example."""
16+
17+
from absl import app
18+
from absl import flags
19+
from absl import logging
20+
import train
21+
from ml_collections import config_flags
22+
23+
FLAGS = flags.FLAGS
24+
25+
config_flags.DEFINE_config_file(
26+
'config',
27+
None,
28+
'File path to the training hyperparameter configuration.',
29+
lock_config=True,
30+
)
31+
32+
33+
def main(argv):
34+
del argv
35+
36+
config = FLAGS.config
37+
38+
# Set train.FLAGS values from config
39+
train.FLAGS.workdir = config.workdir
40+
train.FLAGS.learning_rate = config.learning_rate
41+
train.FLAGS.batch_size = config.batch_size
42+
train.FLAGS.hidden_size = config.hidden_size
43+
train.FLAGS.num_train_steps = config.num_train_steps
44+
train.FLAGS.decode_frequency = config.decode_frequency
45+
train.FLAGS.max_len_query_digit = config.max_len_query_digit
46+
47+
logging.info('Starting training with config: %s', config)
48+
_ = train.train_and_evaluate(train.FLAGS.workdir)
49+
50+
51+
if __name__ == '__main__':
52+
app.run(main)

0 commit comments

Comments
 (0)