diff --git a/src/trustyai/language/detoxify/tmarco.py b/src/trustyai/language/detoxify/tmarco.py
index 01a9840..73904fa 100644
--- a/src/trustyai/language/detoxify/tmarco.py
+++ b/src/trustyai/language/detoxify/tmarco.py
@@ -1,6 +1,7 @@
 # -*- coding: utf-8 -*-
 # pylint: disable = invalid-name, line-too-long, use-dict-literal, consider-using-f-string, too-many-nested-blocks, self-assigning-variable
 """TMaRCo detoxification."""
+import os
 import math
 import itertools
 import numpy as np
@@ -17,6 +18,7 @@
     TrainingArguments,
     pipeline,
     AutoModelForCausalLM,
+    AutoModelForMaskedLM,
     BartForConditionalGeneration,
     BartTokenizer,
     AutoModelForSeq2SeqLM,
@@ -102,18 +104,33 @@ def __init__(
             "cuda" if torch.cuda.is_available() else "cpu"
         )
 
-    def load_models(self, experts: list, expert_weights: list = None):
+    def load_models(self, experts: list[str] = None, expert_weights: list = None):
         """Load expert models."""
         if expert_weights is not None:
             self.expert_weights = expert_weights
         expert_models = []
         for expert in experts:
-            if isinstance(expert, str):
+            # Load TMaRCo's own expert models
+            if expert in ("trustyai/gplus", "trustyai/gminus"):
                 expert = BartForConditionalGeneration.from_pretrained(
                     expert,
                     forced_bos_token_id=self.tokenizer.bos_token_id,
                     device_map="auto",
                 )
+            # Load models from a local path
+            elif os.path.exists(os.path.dirname(expert)):
+                expert = AutoModelForMaskedLM.from_pretrained(
+                    expert,
+                    forced_bos_token_id=self.tokenizer.bos_token_id,
+                    device_map="auto",
+                )
+            # Load models from the Hugging Face Hub
+            else:
+                expert = AutoModelForCausalLM.from_pretrained(
+                    expert,
+                    forced_bos_token_id=self.tokenizer.bos_token_id,
+                    device_map="auto",
+                )
             expert_models.append(expert)
         self.experts = expert_models
 
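For reference, a minimal sketch of how the three dispatch branches in the revised load_models could be exercised. It assumes the default TMaRCo constructor from this module; the local path and the "gpt2" Hub id are illustrative placeholders, not values taken from this change.

from trustyai.language.detoxify.tmarco import TMaRCo

tmarco = TMaRCo()

# Each entry hits a different branch of the new dispatch:
tmarco.load_models([
    "trustyai/gminus",          # TMaRCo expert            -> BartForConditionalGeneration
    "/models/my-local-expert",  # parent dir exists on disk -> AutoModelForMaskedLM (hypothetical path)
    "gpt2",                     # anything else, Hub id     -> AutoModelForCausalLM (illustrative id)
])

Note that the local-path branch tests os.path.dirname(expert), so it triggers whenever the entry's parent directory exists, not only when the full path does.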