Skip to content

Commit 3756bf1

Browse files
Add support for specifying revisions when pushing to Hub via internal Trainer call (#36852)
* Update training_args.py
* Update trainer.py
* fixes
* fix
* remove extraneous comments
* explicit revision arg
* add msg
* fixup
* fix field name
* rename field revision to hub_revision
* restore gradient_checkpointing doc
* fix ws

---------

Co-authored-by: Arthur <[email protected]>
1 parent 458e0b3 commit 3756bf1

File tree

2 files changed

+19 -1 lines changed

2 files changed

+19 -1 lines changed

src/transformers/trainer.py

Lines changed: 7 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -3938,7 +3938,7 @@ def save_model(self, output_dir: Optional[str] = None, _internal_call: bool = Fa
39383938

39393939
# Push to the Hub when `save_model` is called by the user.
39403940
if self.args.push_to_hub and not _internal_call:
3941-
self.push_to_hub(commit_message="Model save")
3941+
self.push_to_hub(commit_message="Model save", revision=self.args.hub_revision)
39423942

39433943
def _save_tpu(self, output_dir: Optional[str] = None):
39443944
output_dir = output_dir if output_dir is not None else self.args.output_dir
@@ -4788,6 +4788,7 @@ def _push_from_checkpoint(self, checkpoint_folder):
47884788
token=self.args.hub_token,
47894789
run_as_future=True,
47904790
ignore_patterns=["_*", f"{PREFIX_CHECKPOINT_DIR}-*"],
4791+
revision=self.args.hub_revision,
47914792
)
47924793

47934794
push_jobs = [model_push_job]
@@ -4803,6 +4804,7 @@ def _push_from_checkpoint(self, checkpoint_folder):
48034804
commit_message=commit_message + ", checkpoint",
48044805
token=self.args.hub_token,
48054806
run_as_future=True,
4807+
revision=self.args.hub_revision,
48064808
)
48074809
push_jobs.append(checkpoint_push)
48084810

@@ -4882,8 +4884,12 @@ def push_to_hub(
48824884

48834885
self.create_model_card(model_name=model_name, **kwargs)
48844886

4887+
if revision is None:
4888+
revision = self.args.hub_revision
4889+
48854890
# Wait for the current upload to be finished.
48864891
self._finish_current_push()
4892+
48874893
return upload_folder(
48884894
repo_id=self.hub_model_id,
48894895
folder_path=self.args.output_dir,

src/transformers/training_args.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -693,6 +693,8 @@ class TrainingArguments:
693693
Whether to make the repo private. If `None` (default), the repo will be public unless the organization's default is private. This value is ignored if the repo already exists.
694694
hub_always_push (`bool`, *optional*, defaults to `False`):
695695
Unless this is `True`, the `Trainer` will skip pushing a checkpoint when the previous push is not finished.
696+
hub_revision (`str`, *optional*):
697+
The revision to use when pushing to the Hub. Can be a branch name, a tag, or a commit hash.
696698
gradient_checkpointing (`bool`, *optional*, defaults to `False`):
697699
If True, use gradient checkpointing to save memory at the expense of slower backward pass.
698700
gradient_checkpointing_kwargs (`dict`, *optional*, defaults to `None`):
@@ -1361,6 +1363,12 @@ class TrainingArguments:
13611363
default=False,
13621364
metadata={"help": "Unless `True`, the Trainer will skip pushes if the previous one wasn't finished yet."},
13631365
)
1366+
hub_revision: Optional[str] = field(
1367+
default=None,
1368+
metadata={
1369+
"help": "The revision to use when pushing to the Hub. Can be a branch name, a tag, or a commit hash."
1370+
},
1371+
)
13641372
gradient_checkpointing: bool = field(
13651373
default=False,
13661374
metadata={
@@ -2861,6 +2869,7 @@ def set_push_to_hub(
28612869
token: Optional[str] = None,
28622870
private_repo: Optional[bool] = None,
28632871
always_push: bool = False,
2872+
revision: Optional[str] = None,
28642873
):
28652874
"""
28662875
A method that regroups all arguments linked to synchronizing checkpoints with the Hub.
@@ -2904,6 +2913,8 @@ def set_push_to_hub(
29042913
always_push (`bool`, *optional*, defaults to `False`):
29052914
Unless this is `True`, the `Trainer` will skip pushing a checkpoint when the previous push is not
29062915
finished.
2916+
revision (`str`, *optional*):
2917+
The revision to use when pushing to the Hub. Can be a branch name, a tag, or a commit hash.
29072918
29082919
Example:
29092920
@@ -2922,6 +2933,7 @@ def set_push_to_hub(
29222933
self.hub_token = token
29232934
self.hub_private_repo = private_repo
29242935
self.hub_always_push = always_push
2936+
self.hub_revision = revision
29252937
return self
29262938

29272939
def set_optimizer(

0 commit comments

Comments
 (0)