@@ -51,22 +51,29 @@ const (
51
51
CustomCardPrefix = "custom"
52
52
ShutdownURI = "/Shutdown"
53
53
GetStatusURI = "/GetStatus"
54
+ DefaultGitBranch = "main"
54
55
)
55
56
56
57
type DriverOption struct {
57
- Context context.Context
58
- OutputPath string
59
- DetectDevice bool
60
- TaskRecipesPath string
61
- TaskRecipes []string
62
- CatalogPath string
63
- CustomCards []string
64
- CustomTemplates []string
65
- CustomSystemPrompt []string
66
- Logger logr.Logger
67
- Args []string
68
- CommPort int
69
- DownloadAssetsS3 bool
58
+ Context context.Context
59
+ OutputPath string
60
+ DetectDevice bool
61
+ TaskRecipesPath string
62
+ TaskRecipes []string
63
+ CatalogPath string
64
+ CustomCards []string
65
+ CustomTemplates []string
66
+ CustomSystemPrompt []string
67
+ Logger logr.Logger
68
+ Args []string
69
+ CommPort int
70
+ DownloadAssetsS3 bool
71
+ CustomTaskGitURL string
72
+ CustomTaskGitBranch string
73
+ CustomTaskGitCommit string
74
+ CustomTaskGitPath string
75
+ TaskNames []string
76
+ AllowOnline bool
70
77
}
71
78
72
79
type Driver interface {
@@ -332,6 +339,10 @@ func (d *driverImpl) exec() error {
332
339
return fmt .Errorf ("failed to create custom cards: %v" , err )
333
340
}
334
341
342
+ if err := d .fetchGitCustomTasks (); err != nil {
343
+ return fmt .Errorf ("failed to set up custom tasks: %v" , err )
344
+ }
345
+
335
346
// Copy S3 assets if needed
336
347
if err := d .downloadS3Assets (); err != nil {
337
348
return err
@@ -377,9 +388,10 @@ func (d *driverImpl) exec() error {
377
388
}
378
389
executor .Stdout = stdout
379
390
executor .Stderr = mwriter
380
- executor .Env = append (os .Environ (),
381
- "UNITXT_ALLOW_UNVERIFIED_CODE=True" ,
382
- )
391
+
392
+ env := append (os .Environ (), "UNITXT_ALLOW_UNVERIFIED_CODE=True" )
393
+
394
+ executor .Env = env
383
395
384
396
var freeRes = func () {
385
397
stdin .Close ()
@@ -508,7 +520,7 @@ func (d *driverImpl) prepDir4CustomArtifacts() error {
508
520
subDirs := []string {"cards" , "templates" , "system_prompts" }
509
521
var errs []error
510
522
for _ , dir := range subDirs {
511
- errs = append (errs , mkdirIfNotExist (filepath .Join (d .Option .CatalogPath , dir )))
523
+ errs = append (errs , createDirectory (filepath .Join (d .Option .CatalogPath , dir )))
512
524
}
513
525
return errors .Join (errs ... )
514
526
}
@@ -557,7 +569,7 @@ func (d *driverImpl) createCustomSystemPrompts() error {
557
569
return nil
558
570
}
559
571
560
- func mkdirIfNotExist (path string ) error {
572
+ func createDirectory (path string ) error {
561
573
fi , err := os .Stat (path )
562
574
if err == nil && ! fi .IsDir () {
563
575
return fmt .Errorf ("%s is a file. can not create a directory" , path )
@@ -567,3 +579,74 @@ func mkdirIfNotExist(path string) error {
567
579
}
568
580
return nil
569
581
}
582
+
583
+ func (d * driverImpl ) fetchGitCustomTasks () error {
584
+ // No-op if git url not set
585
+ if d .Option .CustomTaskGitURL == "" {
586
+ return nil
587
+ }
588
+
589
+ // If online is disable, also disable fetching external tasks
590
+ if ! d .Option .AllowOnline {
591
+ return fmt .Errorf ("fetching external git tasks is not allowed when allowOnline is false" )
592
+ }
593
+
594
+ repositoryDestination := filepath .Join ("/tmp" , "custom_tasks" )
595
+ if err := createDirectory (repositoryDestination ); err != nil {
596
+ return err
597
+ }
598
+
599
+ cloneCommand := exec .Command ("git" , "clone" , d .Option .CustomTaskGitURL , repositoryDestination )
600
+ if output , err := cloneCommand .CombinedOutput (); err != nil {
601
+ return fmt .Errorf ("failed to clone git repository: %v, output: %s" , err , string (output ))
602
+ }
603
+
604
+ clonedDirectory := fmt .Sprintf ("--git-dir=%s" , filepath .Join (repositoryDestination , ".git" ))
605
+ workTree := fmt .Sprintf ("--work-tree=%s" , repositoryDestination )
606
+
607
+ // Checkout a specific branch, if specified
608
+ if d .Option .CustomTaskGitBranch != "" {
609
+ checkoutCommand := exec .Command ("git" , clonedDirectory , workTree , "checkout" , d .Option .CustomTaskGitBranch )
610
+ if output , err := checkoutCommand .CombinedOutput (); err != nil {
611
+ return fmt .Errorf ("failed to checkout branch %s: %v, output: %s" ,
612
+ d .Option .CustomTaskGitBranch , err , string (output ))
613
+ }
614
+ } else {
615
+ checkoutCmd := exec .Command ("git" , clonedDirectory , workTree , "checkout" , DefaultGitBranch )
616
+ if output , err := checkoutCmd .CombinedOutput (); err != nil {
617
+ d .Option .Logger .Info ("failed to checkout main branch, using default branch from clone" ,
618
+ "error" , err , "output" , string (output ))
619
+ }
620
+ }
621
+
622
+ // Checkout a specific commit, if specified
623
+ if d .Option .CustomTaskGitCommit != "" {
624
+ checkoutCommand := exec .Command ("git" , clonedDirectory , workTree , "checkout" , d .Option .CustomTaskGitCommit )
625
+ if output , err := checkoutCommand .CombinedOutput (); err != nil {
626
+ return fmt .Errorf ("failed to checkout commit %s: %v, output: %s" ,
627
+ d .Option .CustomTaskGitCommit , err , string (output ))
628
+ }
629
+ }
630
+
631
+ // Use the specified repository path for copying
632
+ taskPath := repositoryDestination
633
+ if d .Option .CustomTaskGitPath != "" {
634
+ taskPath = filepath .Join (repositoryDestination , d .Option .CustomTaskGitPath )
635
+ if _ , err := os .Stat (taskPath ); os .IsNotExist (err ) {
636
+ return fmt .Errorf ("specified path '%s' does not exist in the repository" , d .Option .CustomTaskGitPath )
637
+ }
638
+ }
639
+
640
+ // Create destination path for copy
641
+ if err := createDirectory (d .Option .TaskRecipesPath ); err != nil {
642
+ return err
643
+ }
644
+
645
+ copyCmd := exec .Command ("cp" , "-r" , taskPath + "/." , d .Option .TaskRecipesPath )
646
+ output , err := copyCmd .CombinedOutput ()
647
+ if err != nil {
648
+ return fmt .Errorf ("failed to copy tasks to %s: %v, output: %s" , d .Option .TaskRecipesPath , err , string (output ))
649
+ }
650
+
651
+ return nil
652
+ }
0 commit comments