Skip to content

Commit 872d50b

Browse files
danielsuo and Flax Authors
authored and committed
[flax:examples:nlp_seq] Create a main.py file to run tests with config files to match other examples. #jax-fixit
PiperOrigin-RevId: 840822537
1 parent 881b8ed commit 872d50b

File tree

2 files changed

+110
-0
lines changed

2 files changed

+110
-0
lines changed
Lines changed: 49 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,49 @@
1+
# Copyright 2025 The Flax Authors.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
"""Default hyperparameters for NLP sequence tagging."""
16+
17+
import ml_collections
18+
19+
20+
def get_config():
  """Build the default hyperparameter configuration.

  Returns:
    An `ml_collections.ConfigDict` with training, optimizer, model, and
    data-path settings for the NLP sequence tagging example.
  """
  cfg = ml_collections.ConfigDict()

  # Model directory for checkpoints and logs.
  cfg.model_dir = ''
  # Experiment name.
  cfg.experiment = 'xpos'

  # Training hyperparameters.
  cfg.batch_size = 64
  cfg.num_train_steps = 75000
  cfg.eval_frequency = 100

  # Optimizer hyperparameters.
  cfg.learning_rate = 0.05
  cfg.weight_decay = 1e-1

  # Model hyperparameters.
  cfg.max_length = 256

  # Random seed.
  cfg.random_seed = 0

  # Paths to the training and development data (empty by default).
  cfg.train = ''
  cfg.dev = ''

  return cfg

examples/nlp_seq/main.py

Lines changed: 61 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,61 @@
1+
# Copyright 2024 The Flax Authors.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
"""Main file for running the NLP sequence tagging example.
16+
17+
This file is intentionally kept short to allow config-based execution.
18+
"""
19+
20+
from absl import app
21+
from absl import flags
22+
import train
23+
from ml_collections import config_flags
24+
25+
26+
# Handle to absl's global flag registry. The flags that main() overrides
# below (model_dir, batch_size, ...) are presumably defined by train.py,
# which is imported above — TODO confirm.
FLAGS = flags.FLAGS

# Register a --config flag that loads a ConfigDict from a Python file,
# defaulting to this example's bundled configuration.
config_flags.DEFINE_config_file(
    'config',
    'configs/default.py',
    'File path to the training hyperparameter configuration.',
    lock_config=True,
)
34+
35+
36+
def main(argv):
  """Copy config values onto train.py's flags, then run training.

  Args:
    argv: Positional command-line arguments remaining after absl flag
      parsing; anything beyond the program name is an error.

  Raises:
    app.UsageError: If extra command-line arguments were supplied.
  """
  if len(argv) > 1:
    raise app.UsageError('Too many command-line arguments.')

  cfg = FLAGS.config

  # train.py reads its hyperparameters from absl FLAGS rather than a
  # ConfigDict, so mirror every config value onto the matching flag.
  for flag_name in (
      'model_dir',
      'experiment',
      'batch_size',
      'eval_frequency',
      'num_train_steps',
      'learning_rate',
      'weight_decay',
      'max_length',
      'random_seed',
      'train',
      'dev',
  ):
    setattr(FLAGS, flag_name, getattr(cfg, flag_name))

  # Run the training.
  train.main(argv)
58+
59+
60+
# Script entry point: let absl parse command-line flags and invoke main().
if __name__ == '__main__':
  app.run(main)

0 commit comments

Comments
 (0)