Skip to content

Commit 6cfc32d

Browse files
committed
feat(RHOAIENG-22430): Add custom YAML task support to LMEval
1 parent 7042777 commit 6cfc32d

File tree

9 files changed

+625
-58
lines changed

9 files changed

+625
-58
lines changed

api/lmes/v1alpha1/lmevaljob_types.go

+39
Original file line numberDiff line numberDiff line change
@@ -159,13 +159,52 @@ type TaskRecipe struct {
159159
DemosPoolSize *int `json:"demosPoolSize,omitempty"`
160160
}
161161

162+
// GitSource specifies the git location of external tasks.
// A git source only counts as configured when URL is non-empty (this is what
// TaskList.HasCustomTasksWithGit checks); Branch, Commit and Path are optional.
// NOTE(review): Path is described as "the task file", but the driver's
// fetchGitCustomTasks copies "<path>/." recursively, which suggests a
// directory is expected - confirm and align the field description.
163+
type GitSource struct {
164+
// URL specifies the git repository URL
165+
URL string `json:"url,omitempty"`
166+
// Branch specifies the git branch to use
167+
// +optional
168+
Branch *string `json:"branch,omitempty"`
169+
// Commit specifies the git commit to use
170+
// +optional
171+
Commit *string `json:"commit,omitempty"`
172+
// Path specifies the path to the task file
173+
// +optional
174+
Path string `json:"path,omitempty"`
175+
}
176+
177+
// CustomTaskSource specifies the source of custom tasks.
// Git is the only source kind declared on this type.
// NOTE: encoding/json never treats a struct value as empty, so "omitempty"
// has no effect on the GitSource field below and "git" is always emitted.
178+
type CustomTaskSource struct {
179+
// GitSource specifies the git location of external tasks
180+
GitSource GitSource `json:"git,omitempty"`
181+
}
182+
183+
// CustomTasks specifies the custom (external) tasks to use.
// It pairs the location the tasks are fetched from (Source) with the names of
// the tasks to select (TaskNames).
// NOTE: encoding/json never treats a struct value as empty, so "omitempty"
// has no effect on the Source field below and "source" is always emitted.
184+
type CustomTasks struct {
185+
// Source specifies the source location of custom tasks
186+
Source CustomTaskSource `json:"source,omitempty"`
187+
// TaskNames specifies the names of the external tasks to use
188+
TaskNames []string `json:"taskNames,omitempty"`
189+
}
190+
162191
// TaskList groups the tasks of an evaluation job: lm-eval task names, Unitxt
// task recipes with their custom artifacts, and custom (external) tasks.
type TaskList struct {
163192
// TaskNames from lm-eval's task list
164193
TaskNames []string `json:"taskNames,omitempty"`
165194
// Task Recipes specifically for Unitxt
166195
TaskRecipes []TaskRecipe `json:"taskRecipes,omitempty"`
167196
// Custom Unitxt artifacts that can be used in a TaskRecipe
168197
CustomArtifacts *CustomArtifacts `json:"custom,omitempty"`
198+
// CustomTasks specifies external tasks: where to fetch them from and which task names to use
199+
CustomTasks *CustomTasks `json:"customTasks,omitempty"`
200+
}
201+
202+
func (t *TaskList) HasCustomTasks() bool {
203+
return t.CustomTasks != nil
204+
}
205+
206+
func (t *TaskList) HasCustomTasksWithGit() bool {
207+
return t.CustomTasks != nil && t.CustomTasks.Source.GitSource.URL != ""
169208
}
170209

171210
// Use the tp_idx and sp_idx to point to the corresponding custom template

api/lmes/v1alpha1/zz_generated.deepcopy.go

+67
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

cmd/lmes_driver/main.go

+24-11
Original file line numberDiff line numberDiff line change
@@ -54,13 +54,19 @@ var (
5454
customCards strArrayArg
5555
customTemplates strArrayArg
5656
customSystemPrompts strArrayArg
57+
taskNames strArrayArg
5758
copy = flag.String("copy", "", "copy this binary to specified destination path")
5859
getStatus = flag.Bool("get-status", false, "Get current status")
5960
shutdown = flag.Bool("shutdown", false, "Shutdown the driver")
6061
outputPath = flag.String("output-path", OutputPath, "output path")
6162
detectDevice = flag.Bool("detect-device", false, "detect available device(s), CUDA or CPU")
6263
commPort = flag.Int("listen-port", driver.DefaultPort, "driver serves APIs on the port")
6364
downloadAssetsS3 = flag.Bool("download-assets-s3", false, "Download assets from S3")
65+
customTaskGitURL = flag.String("custom-task-git-url", "", "Git repository URL for custom tasks")
66+
customTaskGitBranch = flag.String("custom-task-git-branch", "", "Git repository branch for custom tasks")
67+
customTaskGitCommit = flag.String("custom-task-git-commit", "", "Git commit for custom tasks")
68+
customTaskGitPath = flag.String("custom-task-git-path", "", "Custom task path")
69+
allowOnline = flag.Bool("allow-online", false, "Allow LMEval online access")
6470
driverLog = ctrl.Log.WithName("driver")
6571
)
6672

@@ -69,6 +75,7 @@ func init() {
6975
flag.Var(&customCards, "custom-card", "A JSON string represents a custom card")
7076
flag.Var(&customTemplates, "custom-template", "A JSON string represents a custom template")
7177
flag.Var(&customSystemPrompts, "custom-prompt", "A string represents a custom system_prompt")
78+
flag.Var(&taskNames, "task-name", "A task name for custom tasks")
7279
}
7380

7481
func main() {
@@ -111,17 +118,23 @@ func main() {
111118
}
112119

113120
driverOpt := driver.DriverOption{
114-
Context: ctx,
115-
OutputPath: *outputPath,
116-
DetectDevice: *detectDevice,
117-
Logger: driverLog,
118-
TaskRecipes: taskRecipes,
119-
CustomCards: customCards,
120-
CustomTemplates: customTemplates,
121-
CustomSystemPrompt: customSystemPrompts,
122-
Args: args,
123-
CommPort: *commPort,
124-
DownloadAssetsS3: *downloadAssetsS3,
121+
Context: ctx,
122+
OutputPath: *outputPath,
123+
DetectDevice: *detectDevice,
124+
Logger: driverLog,
125+
TaskRecipes: taskRecipes,
126+
CustomCards: customCards,
127+
CustomTemplates: customTemplates,
128+
CustomSystemPrompt: customSystemPrompts,
129+
Args: args,
130+
CommPort: *commPort,
131+
DownloadAssetsS3: *downloadAssetsS3,
132+
CustomTaskGitURL: *customTaskGitURL,
133+
CustomTaskGitBranch: *customTaskGitBranch,
134+
CustomTaskGitCommit: *customTaskGitCommit,
135+
CustomTaskGitPath: *customTaskGitPath,
136+
TaskNames: taskNames,
137+
AllowOnline: *allowOnline,
125138
}
126139

127140
driver, err := driver.NewDriver(&driverOpt)

config/crd/bases/trustyai.opendatahub.io_lmevaljobs.yaml

+32
Original file line numberDiff line numberDiff line change
@@ -4757,6 +4757,38 @@ spec:
47574757
type: object
47584758
type: array
47594759
type: object
4760+
customTasks:
4761+
description: CustomTasks is a list of external tasks
4762+
properties:
4763+
source:
4764+
description: Source specifies the source location of custom
4765+
tasks
4766+
properties:
4767+
git:
4768+
description: GitSource specifies the git location of external
4769+
tasks
4770+
properties:
4771+
branch:
4772+
description: Branch specifies the git branch to use
4773+
type: string
4774+
commit:
4775+
description: Commit specifies the git commit to use
4776+
type: string
4777+
path:
4778+
description: Path specifies the path to the task file
4779+
type: string
4780+
url:
4781+
description: URL specifies the git repository URL
4782+
type: string
4783+
type: object
4784+
type: object
4785+
taskNames:
4786+
description: TaskNames specifies the names of the external
4787+
tasks to use
4788+
items:
4789+
type: string
4790+
type: array
4791+
type: object
47604792
taskNames:
47614793
description: TaskNames from lm-eval's task list
47624794
items:

controllers/lmes/driver/driver.go

+101-18
Original file line numberDiff line numberDiff line change
@@ -51,22 +51,29 @@ const (
5151
CustomCardPrefix = "custom"
5252
ShutdownURI = "/Shutdown"
5353
GetStatusURI = "/GetStatus"
54+
DefaultGitBranch = "main"
5455
)
5556

5657
type DriverOption struct {
57-
Context context.Context
58-
OutputPath string
59-
DetectDevice bool
60-
TaskRecipesPath string
61-
TaskRecipes []string
62-
CatalogPath string
63-
CustomCards []string
64-
CustomTemplates []string
65-
CustomSystemPrompt []string
66-
Logger logr.Logger
67-
Args []string
68-
CommPort int
69-
DownloadAssetsS3 bool
58+
Context context.Context
59+
OutputPath string
60+
DetectDevice bool
61+
TaskRecipesPath string
62+
TaskRecipes []string
63+
CatalogPath string
64+
CustomCards []string
65+
CustomTemplates []string
66+
CustomSystemPrompt []string
67+
Logger logr.Logger
68+
Args []string
69+
CommPort int
70+
DownloadAssetsS3 bool
71+
CustomTaskGitURL string
72+
CustomTaskGitBranch string
73+
CustomTaskGitCommit string
74+
CustomTaskGitPath string
75+
TaskNames []string
76+
AllowOnline bool
7077
}
7178

7279
type Driver interface {
@@ -332,6 +339,10 @@ func (d *driverImpl) exec() error {
332339
return fmt.Errorf("failed to create custom cards: %v", err)
333340
}
334341

342+
if err := d.fetchGitCustomTasks(); err != nil {
343+
return fmt.Errorf("failed to set up custom tasks: %v", err)
344+
}
345+
335346
// Copy S3 assets if needed
336347
if err := d.downloadS3Assets(); err != nil {
337348
return err
@@ -377,9 +388,10 @@ func (d *driverImpl) exec() error {
377388
}
378389
executor.Stdout = stdout
379390
executor.Stderr = mwriter
380-
executor.Env = append(os.Environ(),
381-
"UNITXT_ALLOW_UNVERIFIED_CODE=True",
382-
)
391+
392+
env := append(os.Environ(), "UNITXT_ALLOW_UNVERIFIED_CODE=True")
393+
394+
executor.Env = env
383395

384396
var freeRes = func() {
385397
stdin.Close()
@@ -508,7 +520,7 @@ func (d *driverImpl) prepDir4CustomArtifacts() error {
508520
subDirs := []string{"cards", "templates", "system_prompts"}
509521
var errs []error
510522
for _, dir := range subDirs {
511-
errs = append(errs, mkdirIfNotExist(filepath.Join(d.Option.CatalogPath, dir)))
523+
errs = append(errs, createDirectory(filepath.Join(d.Option.CatalogPath, dir)))
512524
}
513525
return errors.Join(errs...)
514526
}
@@ -557,7 +569,7 @@ func (d *driverImpl) createCustomSystemPrompts() error {
557569
return nil
558570
}
559571

560-
func mkdirIfNotExist(path string) error {
572+
func createDirectory(path string) error {
561573
fi, err := os.Stat(path)
562574
if err == nil && !fi.IsDir() {
563575
return fmt.Errorf("%s is a file. can not create a directory", path)
@@ -567,3 +579,74 @@ func mkdirIfNotExist(path string) error {
567579
}
568580
return nil
569581
}
582+
583+
func (d *driverImpl) fetchGitCustomTasks() error {
584+
// No-op if git url not set
585+
if d.Option.CustomTaskGitURL == "" {
586+
return nil
587+
}
588+
589+
// If online is disable, also disable fetching external tasks
590+
if !d.Option.AllowOnline {
591+
return fmt.Errorf("fetching external git tasks is not allowed when allowOnline is false")
592+
}
593+
594+
repositoryDestination := filepath.Join("/tmp", "custom_tasks")
595+
if err := createDirectory(repositoryDestination); err != nil {
596+
return err
597+
}
598+
599+
cloneCommand := exec.Command("git", "clone", d.Option.CustomTaskGitURL, repositoryDestination)
600+
if output, err := cloneCommand.CombinedOutput(); err != nil {
601+
return fmt.Errorf("failed to clone git repository: %v, output: %s", err, string(output))
602+
}
603+
604+
clonedDirectory := fmt.Sprintf("--git-dir=%s", filepath.Join(repositoryDestination, ".git"))
605+
workTree := fmt.Sprintf("--work-tree=%s", repositoryDestination)
606+
607+
// Checkout a specific branch, if specified
608+
if d.Option.CustomTaskGitBranch != "" {
609+
checkoutCommand := exec.Command("git", clonedDirectory, workTree, "checkout", d.Option.CustomTaskGitBranch)
610+
if output, err := checkoutCommand.CombinedOutput(); err != nil {
611+
return fmt.Errorf("failed to checkout branch %s: %v, output: %s",
612+
d.Option.CustomTaskGitBranch, err, string(output))
613+
}
614+
} else {
615+
checkoutCmd := exec.Command("git", clonedDirectory, workTree, "checkout", DefaultGitBranch)
616+
if output, err := checkoutCmd.CombinedOutput(); err != nil {
617+
d.Option.Logger.Info("failed to checkout main branch, using default branch from clone",
618+
"error", err, "output", string(output))
619+
}
620+
}
621+
622+
// Checkout a specific commit, if specified
623+
if d.Option.CustomTaskGitCommit != "" {
624+
checkoutCommand := exec.Command("git", clonedDirectory, workTree, "checkout", d.Option.CustomTaskGitCommit)
625+
if output, err := checkoutCommand.CombinedOutput(); err != nil {
626+
return fmt.Errorf("failed to checkout commit %s: %v, output: %s",
627+
d.Option.CustomTaskGitCommit, err, string(output))
628+
}
629+
}
630+
631+
// Use the specified repository path for copying
632+
taskPath := repositoryDestination
633+
if d.Option.CustomTaskGitPath != "" {
634+
taskPath = filepath.Join(repositoryDestination, d.Option.CustomTaskGitPath)
635+
if _, err := os.Stat(taskPath); os.IsNotExist(err) {
636+
return fmt.Errorf("specified path '%s' does not exist in the repository", d.Option.CustomTaskGitPath)
637+
}
638+
}
639+
640+
// Create destination path for copy
641+
if err := createDirectory(d.Option.TaskRecipesPath); err != nil {
642+
return err
643+
}
644+
645+
copyCmd := exec.Command("cp", "-r", taskPath+"/.", d.Option.TaskRecipesPath)
646+
output, err := copyCmd.CombinedOutput()
647+
if err != nil {
648+
return fmt.Errorf("failed to copy tasks to %s: %v, output: %s", d.Option.TaskRecipesPath, err, string(output))
649+
}
650+
651+
return nil
652+
}

0 commit comments

Comments
 (0)