Skip to content

Commit

Permalink
Add ability to load local and HF models (#212)
Browse files Browse the repository at this point in the history
  • Loading branch information
christinaexyou authored Sep 6, 2024
1 parent fa3e66a commit a4a3da8
Showing 1 changed file with 19 additions and 2 deletions.
21 changes: 19 additions & 2 deletions src/trustyai/language/detoxify/tmarco.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
# -*- coding: utf-8 -*-
# pylint: disable = invalid-name, line-too-long, use-dict-literal, consider-using-f-string, too-many-nested-blocks, self-assigning-variable
"""TMaRCo detoxification."""
import os
import math
import itertools
import numpy as np
Expand All @@ -17,6 +18,7 @@
TrainingArguments,
pipeline,
AutoModelForCausalLM,
AutoModelForMaskedLM,
BartForConditionalGeneration,
BartTokenizer,
AutoModelForSeq2SeqLM,
Expand Down Expand Up @@ -102,18 +104,33 @@ def __init__(
"cuda" if torch.cuda.is_available() else "cpu"
)

def load_models(self, experts: list = None, expert_weights: list = None):
    """Load expert models.

    Each entry of ``experts`` may be either an already-instantiated model
    object (appended unchanged) or a string identifier, resolved in order:

    1. ``"trustyai/gplus"`` / ``"trustyai/gminus"`` — the TMaRCo BART experts.
    2. An existing local path — loaded as a masked-LM checkpoint.
    3. Anything else — treated as a Hugging Face Hub id for a causal LM.

    :param experts: model identifiers and/or model objects to load.
    :param expert_weights: optional per-expert weights; replaces
        ``self.expert_weights`` when provided.
    :raises ValueError: if ``experts`` is None.
    """
    if experts is None:
        # Iterating None previously raised an opaque TypeError; fail early
        # with an actionable message instead.
        raise ValueError("load_models requires a list of expert models or identifiers")
    if expert_weights is not None:
        self.expert_weights = expert_weights
    expert_models = []
    for expert in experts:
        if isinstance(expert, str):
            if expert in ("trustyai/gplus", "trustyai/gminus"):
                # TMaRCo's own experts are BART seq2seq checkpoints.
                expert = BartForConditionalGeneration.from_pretrained(
                    expert,
                    forced_bos_token_id=self.tokenizer.bos_token_id,
                    device_map="auto",
                )
            elif os.path.exists(expert):
                # Existing path on disk -> local checkpoint. Note: test the
                # path itself, not its dirname — a Hub id like "org/model"
                # whose "org" happens to be a local directory must still
                # fall through to the Hub branch below.
                expert = AutoModelForMaskedLM.from_pretrained(
                    expert,
                    forced_bos_token_id=self.tokenizer.bos_token_id,
                    device_map="auto",
                )
            else:
                # Fall back to a Hugging Face Hub model id.
                expert = AutoModelForCausalLM.from_pretrained(
                    expert,
                    forced_bos_token_id=self.tokenizer.bos_token_id,
                    device_map="auto",
                )
        expert_models.append(expert)
    self.experts = expert_models

Expand Down

0 comments on commit a4a3da8

Please sign in to comment.