Skip to content

Commit 0443734

Browse files
authored
ENH Validation for task_type in PEFT config (#2210)
Raises an error when invalid task type is provided.
1 parent 0155fa8 commit 0443734

File tree

23 files changed

+48
-3
lines changed

23 files changed

+48
-3
lines changed

src/peft/config.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -55,11 +55,19 @@ class PeftConfigMixin(PushToHubMixin):
5555
peft_type (Union[[`~peft.utils.config.PeftType`], `str`]): The type of Peft method to use.
5656
"""
5757

58+
task_type: Optional[TaskType] = field(default=None, metadata={"help": "The type of task."})
5859
peft_type: Optional[PeftType] = field(default=None, metadata={"help": "The type of PEFT model."})
5960
auto_mapping: Optional[dict] = field(
6061
default=None, metadata={"help": "An auto mapping dict to help retrieve the base model class if needed."}
6162
)
6263

64+
def __post_init__(self):
65+
# check for invalid task type
66+
if (self.task_type is not None) and (self.task_type not in list(TaskType)):
67+
raise ValueError(
68+
f"Invalid task type: '{self.task_type}'. Must be one of the following task types: {', '.join(TaskType)}."
69+
)
70+
6371
def to_dict(self) -> Dict:
6472
r"""
6573
Returns the configuration for your adapter model as a dictionary.

src/peft/tuners/adalora/config.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -50,6 +50,7 @@ class AdaLoraConfig(LoraConfig):
5050
rank_pattern: Optional[dict] = field(default=None, metadata={"help": "The saved rank pattern."})
5151

5252
def __post_init__(self):
53+
super().__post_init__()
5354
self.peft_type = PeftType.ADALORA
5455

5556
if self.use_dora:

src/peft/tuners/adaption_prompt/config.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,7 @@ class AdaptionPromptConfig(PeftConfig):
3232
adapter_layers: int = field(default=None, metadata={"help": "Number of adapter layers (from the top)"})
3333

3434
def __post_init__(self):
35+
super().__post_init__()
3536
self.peft_type = PeftType.ADAPTION_PROMPT
3637

3738
@property

src/peft/tuners/boft/config.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -139,6 +139,7 @@ class BOFTConfig(PeftConfig):
139139
)
140140

141141
def __post_init__(self):
142+
super().__post_init__()
142143
self.peft_type = PeftType.BOFT
143144
self.target_modules = (
144145
set(self.target_modules) if isinstance(self.target_modules, list) else self.target_modules

src/peft/tuners/bone/config.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -108,6 +108,7 @@ class BoneConfig(PeftConfig):
108108
)
109109

110110
def __post_init__(self):
111+
super().__post_init__()
111112
self.peft_type = PeftType.BONE
112113
self.target_modules = (
113114
set(self.target_modules) if isinstance(self.target_modules, list) else self.target_modules

src/peft/tuners/fourierft/config.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -185,6 +185,7 @@ class FourierFTConfig(PeftConfig):
185185
)
186186

187187
def __post_init__(self):
188+
super().__post_init__()
188189
self.peft_type = PeftType.FOURIERFT
189190
self.target_modules = (
190191
set(self.target_modules) if isinstance(self.target_modules, list) else self.target_modules

src/peft/tuners/hra/config.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -115,6 +115,7 @@ class HRAConfig(PeftConfig):
115115
)
116116

117117
def __post_init__(self):
118+
super().__post_init__()
118119
self.peft_type = PeftType.HRA
119120
self.target_modules = (
120121
set(self.target_modules) if isinstance(self.target_modules, list) else self.target_modules

src/peft/tuners/ia3/config.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -94,6 +94,7 @@ class IA3Config(PeftConfig):
9494
)
9595

9696
def __post_init__(self):
97+
super().__post_init__()
9798
self.peft_type = PeftType.IA3
9899
self.target_modules = (
99100
set(self.target_modules) if isinstance(self.target_modules, list) else self.target_modules

src/peft/tuners/ln_tuning/config.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -66,4 +66,5 @@ class LNTuningConfig(PeftConfig):
6666
)
6767

6868
def __post_init__(self):
69+
super().__post_init__()
6970
self.peft_type = PeftType.LN_TUNING

src/peft/tuners/loha/config.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -126,6 +126,7 @@ class LoHaConfig(LycorisConfig):
126126
)
127127

128128
def __post_init__(self):
129+
super().__post_init__()
129130
self.peft_type = PeftType.LOHA
130131
self.target_modules = (
131132
set(self.target_modules) if isinstance(self.target_modules, list) else self.target_modules

0 commit comments

Comments
 (0)