Skip to content

Commit 95566b3

Browse files
committed
fix tfb bug
1 parent e274406 commit 95566b3

File tree

4 files changed

+253
-9
lines changed

4 files changed

+253
-9
lines changed

README.md

+1-1
Original file line numberDiff line numberDiff line change
@@ -218,7 +218,7 @@ To run the script, one should download the Taxi data following the above instruc
218218

219219
```shell
220220
cd examples
221-
python benchmark_script.py
221+
python run_retweet.py
222222
```
223223

224224

easy_tpp/config_factory/runner_config.py

-3
Original file line numberDiff line numberDiff line change
@@ -94,9 +94,6 @@ def ensure_valid_config(self):
9494
# during testing we dont do shuffle by default
9595
self.trainer_config.shuffle = False
9696

97-
# during testing we dont apply tfb by default
98-
self.trainer_config.use_tfb = False
99-
10097
return
10198

10299
def update_config(self):

easy_tpp/torch_wrapper.py

+4-5
Original file line numberDiff line numberDiff line change
@@ -33,11 +33,10 @@ def __init__(self, model, base_config, model_config, trainer_config):
3333
self.opt = set_optimizer(optimizer, self.model.parameters(), self.learning_rate)
3434

3535
# set up tensorboard
36-
self.use_tfb = self.trainer_config.use_tfb
3736
self.train_summary_writer, self.valid_summary_writer = None, None
38-
if self.use_tfb:
39-
self.train_summary_writer = SummaryWriter(log_dir=self.base_config.spec['tfb_train_dir'])
40-
self.valid_summary_writer = SummaryWriter(log_dir=self.base_config.spec['tfb_valid_dir'])
37+
if self.trainer_config.use_tfb:
38+
self.train_summary_writer = SummaryWriter(log_dir=self.base_config.specs['tfb_train_dir'])
39+
self.valid_summary_writer = SummaryWriter(log_dir=self.base_config.specs['tfb_valid_dir'])
4140

4241
def restore(self, ckpt_dir):
4342
"""Load the checkpoint to restore the model.
@@ -64,7 +63,7 @@ def write_summary(self, epoch, kv_pairs, phase):
6463
kv_pairs (dict): metrics dict.
6564
phase (RunnerPhase): a const that defines the stage of model runner.
6665
"""
67-
if self.use_tfb:
66+
if self.trainer_config.use_tfb:
6867
summary_writer = None
6968
if phase == RunnerPhase.TRAIN:
7069
summary_writer = self.train_summary_writer

notebooks/easytpp_2_tfb_wb.ipynb

+248
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,248 @@
1+
{
2+
"cells": [
3+
{
4+
"cell_type": "markdown",
5+
"metadata": {},
6+
"source": [
7+
"[![Open in Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/ant-research/EasyTemporalPointProcess/blob/main/notebooks/easytpp_2_tfb_wb.ipynb)\n",
8+
"\n",
9+
"\n",
10+
"# Tutorial 2: Tensorboard and Weights & Biases in EasyTPP\n",
11+
"\n",
12+
"EasyTPP provides built-in support for both Tensorboard and Weights & Biases (W&B) to help you track and visualize your model training. These tools allow you to monitor metrics, compare experiments, and debug your models effectively.\n",
13+
"\n",
14+
"\n",
15+
"## Example of using Tensorboard"
16+
]
17+
},
18+
{
19+
"cell_type": "code",
20+
"execution_count": 4,
21+
"metadata": {
22+
"ExecuteTime": {
23+
"end_time": "2025-02-03T02:24:56.584850Z",
24+
"start_time": "2025-02-03T02:24:56.580600Z"
25+
}
26+
},
27+
"outputs": [],
28+
"source": [
29+
"# As an illustrative example, we write the YAML content to a file\n",
30+
"yaml_content = \"\"\"\n",
31+
"pipeline_config_id: runner_config\n",
32+
"\n",
33+
"data:\n",
34+
" taxi:\n",
35+
" data_format: json\n",
36+
" train_dir: easytpp/taxi # ./data/taxi/train.json\n",
37+
" valid_dir: easytpp/taxi # ./data/taxi/dev.json\n",
38+
" test_dir: easytpp/taxi # ./data/taxi/test.json\n",
39+
" data_specs:\n",
40+
" num_event_types: 10\n",
41+
" pad_token_id: 10\n",
42+
" padding_side: right\n",
43+
"\n",
44+
"\n",
45+
"NHP_train:\n",
46+
" base_config:\n",
47+
" stage: train\n",
48+
" backend: torch\n",
49+
" dataset_id: taxi\n",
50+
" runner_id: std_tpp\n",
51+
" model_id: NHP # model name\n",
52+
" base_dir: './checkpoints/'\n",
53+
" trainer_config:\n",
54+
" batch_size: 256\n",
55+
" max_epoch: 2\n",
56+
" shuffle: False\n",
57+
" optimizer: adam\n",
58+
" learning_rate: 1.e-3\n",
59+
" valid_freq: 1\n",
60+
" use_tfb: True\n",
61+
" metrics: [ 'acc', 'rmse' ]\n",
62+
" seed: 2019\n",
63+
" gpu: -1\n",
64+
" model_config:\n",
65+
" hidden_size: 32\n",
66+
" loss_integral_num_sample_per_step: 20\n",
67+
" thinning:\n",
68+
" num_seq: 10\n",
69+
" num_sample: 1\n",
70+
" num_exp: 500 # number of i.i.d. Exp(intensity_bound) draws at one time in thinning algorithm\n",
71+
" look_ahead_time: 10\n",
72+
" patience_counter: 5 # the maximum iteration used in adaptive thinning\n",
73+
" over_sample_rate: 5\n",
74+
" num_samples_boundary: 5\n",
75+
" dtime_max: 5\n",
76+
" num_step_gen: 1\n",
77+
"\"\"\"\n",
78+
"\n",
79+
"# Save the content to a file named config.yaml\n",
80+
"with open(\"config.yaml\", \"w\") as file:\n",
81+
" file.write(yaml_content)"
82+
]
83+
},
84+
{
85+
"cell_type": "markdown",
86+
"metadata": {},
87+
"source": [
88+
"Then we run the following command to train the model:"
89+
]
90+
},
91+
{
92+
"cell_type": "code",
93+
"execution_count": 5,
94+
"metadata": {},
95+
"outputs": [
96+
{
97+
"name": "stdout",
98+
"output_type": "stream",
99+
"text": [
100+
"\u001b[31;1m2025-02-03 10:32:32,085 - config.py[pid:91053;line:34:build_from_yaml_file] - CRITICAL: Load pipeline config class RunnerConfig\u001b[0m\n",
101+
"\u001b[31;1m2025-02-03 10:32:32,089 - runner_config.py[pid:91053;line:161:update_config] - CRITICAL: train model NHP using CPU with torch backend\u001b[0m\n",
102+
"\u001b[38;20m2025-02-03 10:32:32,098 - runner_config.py[pid:91053;line:36:__init__] - INFO: Save the config to ./checkpoints/91053_8345177088_250203-103232/NHP_train_output.yaml\u001b[0m\n",
103+
"\u001b[38;20m2025-02-03 10:32:32,099 - base_runner.py[pid:91053;line:176:save_log] - INFO: Save the log to ./checkpoints/91053_8345177088_250203-103232/log\u001b[0m\n"
104+
]
105+
},
106+
{
107+
"name": "stderr",
108+
"output_type": "stream",
109+
"text": [
110+
"/opt/miniconda3/envs/llm/lib/python3.10/site-packages/tqdm/auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n",
111+
" from .autonotebook import tqdm as notebook_tqdm\n",
112+
"Downloading readme: 100%|██████████| 28.0/28.0 [00:00<00:00, 119B/s]\n"
113+
]
114+
},
115+
{
116+
"name": "stdout",
117+
"output_type": "stream",
118+
"text": [
119+
"0.2244252199397379 0.29228809611195583\n",
120+
"min_dt: 0.000277777777777\n",
121+
"max_dt: 5.721388888888889\n",
122+
"\u001b[38;20m2025-02-03 10:32:38,267 - tpp_runner.py[pid:91053;line:60:_init_model] - INFO: Num of model parameters 15252\u001b[0m\n",
123+
"\u001b[38;20m2025-02-03 10:32:45,909 - base_runner.py[pid:91053;line:98:train] - INFO: Data 'taxi' loaded...\u001b[0m\n",
124+
"\u001b[38;20m2025-02-03 10:32:45,910 - base_runner.py[pid:91053;line:103:train] - INFO: Start NHP training...\u001b[0m\n",
125+
"\u001b[38;20m2025-02-03 10:32:46,425 - tpp_runner.py[pid:91053;line:96:_train_model] - INFO: [ Epoch 0 (train) ]: train loglike is -1.7553733776992408, num_events is 50454\u001b[0m\n",
126+
"\u001b[38;20m2025-02-03 10:32:47,128 - tpp_runner.py[pid:91053;line:107:_train_model] - INFO: [ Epoch 0 (valid) ]: valid loglike is -1.6691416010202664, num_events is 7204, acc is 0.4414214325374792, rmse is 0.3327808472052436\u001b[0m\n",
127+
"\u001b[38;20m2025-02-03 10:32:48,150 - tpp_runner.py[pid:91053;line:122:_train_model] - INFO: [ Epoch 0 (test) ]: test loglike is -1.6577474861303745, num_events is 14420, acc is 0.44667128987517335, rmse is 0.3408341129976238\u001b[0m\n",
128+
"\u001b[31;1m2025-02-03 10:32:48,150 - tpp_runner.py[pid:91053;line:124:_train_model] - CRITICAL: current best loglike on valid set is -1.6691 (updated at epoch-0), best updated at this epoch\u001b[0m\n",
129+
"\u001b[38;20m2025-02-03 10:32:48,487 - tpp_runner.py[pid:91053;line:96:_train_model] - INFO: [ Epoch 1 (train) ]: train loglike is -1.6284447180538213, num_events is 50454\u001b[0m\n",
130+
"\u001b[38;20m2025-02-03 10:32:48,995 - tpp_runner.py[pid:91053;line:107:_train_model] - INFO: [ Epoch 1 (valid) ]: valid loglike is -1.5259201159945863, num_events is 7204, acc is 0.4582176568573015, rmse is 0.33537458414488913\u001b[0m\n",
131+
"\u001b[38;20m2025-02-03 10:32:49,999 - tpp_runner.py[pid:91053;line:122:_train_model] - INFO: [ Epoch 1 (test) ]: test loglike is -1.5121817706527392, num_events is 14420, acc is 0.45977808599167824, rmse is 0.34166548827945314\u001b[0m\n",
132+
"\u001b[31;1m2025-02-03 10:32:50,000 - tpp_runner.py[pid:91053;line:124:_train_model] - CRITICAL: current best loglike on valid set is -1.5259 (updated at epoch-1), best updated at this epoch\u001b[0m\n",
133+
"\u001b[38;20m2025-02-03 10:32:50,000 - base_runner.py[pid:91053;line:110:train] - INFO: End NHP train! Cost time: 0.068m\u001b[0m\n"
134+
]
135+
}
136+
],
137+
"source": [
138+
"from easy_tpp.config_factory import Config\n",
139+
"from easy_tpp.runner import Runner\n",
140+
"\n",
141+
"config = Config.build_from_yaml_file('./config.yaml', experiment_id='NHP_train')\n",
142+
"\n",
143+
"model_runner = Runner.build_from_config(config)\n",
144+
"\n",
145+
"model_runner.run()"
146+
]
147+
},
148+
{
149+
"cell_type": "markdown",
150+
"metadata": {
151+
"vscode": {
152+
"languageId": "plaintext"
153+
}
154+
},
155+
"source": [
156+
"After the training is done, we can see the tensorboard files in the `./checkpoints/` directory. "
157+
]
158+
},
159+
{
160+
"cell_type": "code",
161+
"execution_count": 9,
162+
"metadata": {},
163+
"outputs": [
164+
{
165+
"name": "stdout",
166+
"output_type": "stream",
167+
"text": [
168+
"\u001b[34mcheckpoints\u001b[m\u001b[m easytpp_1_dataset.ipynb\n",
169+
"config.yaml easytpp_2_tfb_wb.ipynb\n",
170+
"\n",
171+
"./checkpoints:\n",
172+
"\u001b[34m91053_8345177088_250203-103232\u001b[m\u001b[m\n",
173+
"\n",
174+
"./checkpoints/91053_8345177088_250203-103232:\n",
175+
"NHP_train_output.yaml \u001b[34mmodels\u001b[m\u001b[m \u001b[34mtfb_valid\u001b[m\u001b[m\n",
176+
"log \u001b[34mtfb_train\u001b[m\u001b[m\n",
177+
"\n",
178+
"./checkpoints/91053_8345177088_250203-103232/models:\n",
179+
"saved_model\n",
180+
"\n",
181+
"./checkpoints/91053_8345177088_250203-103232/tfb_train:\n",
182+
"events.out.tfevents.1738549958.siqiaodeMacBook-Pro.local.91053.0\n",
183+
"\n",
184+
"./checkpoints/91053_8345177088_250203-103232/tfb_valid:\n",
185+
"events.out.tfevents.1738549958.siqiaodeMacBook-Pro.local.91053.1\n"
186+
]
187+
}
188+
],
189+
"source": [
190+
"!ls -R"
191+
]
192+
},
193+
{
194+
"cell_type": "markdown",
195+
"metadata": {},
196+
"source": [
197+
"Then we can use the following script to visualize the training process:"
198+
]
199+
},
200+
{
201+
"cell_type": "code",
202+
"execution_count": null,
203+
"metadata": {},
204+
"outputs": [
205+
{
206+
"name": "stdout",
207+
"output_type": "stream",
208+
"text": [
209+
"TensorFlow installation not found - running with reduced feature set.\n",
210+
"Serving TensorBoard on localhost; to expose to the network, use a proxy or pass --bind_all\n",
211+
"TensorBoard 2.17.1 at http://localhost:6006/ (Press CTRL+C to quit)\n"
212+
]
213+
}
214+
],
215+
"source": [
216+
"! tensorboard --logdir \"./checkpoints/91053_8345177088_250203-103232/tfb_train/\""
217+
]
218+
},
219+
{
220+
"cell_type": "code",
221+
"execution_count": null,
222+
"metadata": {},
223+
"outputs": [],
224+
"source": []
225+
}
226+
],
227+
"metadata": {
228+
"kernelspec": {
229+
"display_name": "Python 3 (ipykernel)",
230+
"language": "python",
231+
"name": "python3"
232+
},
233+
"language_info": {
234+
"codemirror_mode": {
235+
"name": "ipython",
236+
"version": 3
237+
},
238+
"file_extension": ".py",
239+
"mimetype": "text/x-python",
240+
"name": "python",
241+
"nbconvert_exporter": "python",
242+
"pygments_lexer": "ipython3",
243+
"version": "3.10.14"
244+
}
245+
},
246+
"nbformat": 4,
247+
"nbformat_minor": 4
248+
}

0 commit comments

Comments
 (0)