Skip to content

Commit b2bec12

Browse files
authored
feat: support custom dataset (trustyai-explainability#309)
Updated the CRD data struct to allow users to specify a custom Unitxt card in JSON format. The custom Unitxt card is equivalent to a custom dataset definition. Also restructured and updated the CRD to support Volumes, VolumeMounts, Env, Resources, Labels, and Annotations. Signed-off-by: Yihong Wang <[email protected]>
1 parent 159842f commit b2bec12

File tree

10 files changed

+3903
-252
lines changed

10 files changed

+3903
-252
lines changed

Dockerfile.lmes-job

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@ WORKDIR /opt/app-root/src
88
RUN mkdir /opt/app-root/src/hf_home && chmod g+rwx /opt/app-root/src/hf_home
99
RUN mkdir /opt/app-root/src/output && chmod g+rwx /opt/app-root/src/output
1010
RUN mkdir /opt/app-root/src/my_tasks && chmod g+rwx /opt/app-root/src/my_tasks
11+
RUN mkdir -p /opt/app-root/src/my_catalogs/cards && chmod -R g+rwx /opt/app-root/src/my_catalogs
1112
RUN mkdir /opt/app-root/src/.cache
1213
ENV PATH="/opt/app-root/bin:/opt/app-root/src/.local/bin/:/opt/app-root/src/bin:/usr/local/bin:/usr/local/sbin:/usr/local/bin:/usr/sbin:/usr/bin:/sbin:/bin"
1314

@@ -23,6 +24,7 @@ RUN python -c 'from lm_eval.tasks.unitxt import task; import os.path; print("cla
2324

2425
ENV PYTHONPATH=/opt/app-root/src/.local/lib/python3.11/site-packages:/opt/app-root/src/lm-evaluation-harness:/opt/app-root/src:/opt/app-root/src/server
2526
ENV HF_HOME=/opt/app-root/src/hf_home
27+
ENV UNITXT_ARTIFACTORIES=/opt/app-root/src/my_catalogs
2628

2729
CMD ["/opt/app-root/bin/python"]
2830

api/lmes/v1alpha1/lmevaljob_types.go

Lines changed: 82 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -63,30 +63,23 @@ type Arg struct {
6363
Value string `json:"value,omitempty"`
6464
}
6565

66-
type EnvSecret struct {
67-
// Environment's name
68-
Env string `json:"env"`
69-
// The secret is from a secret object
66+
type Card struct {
67+
// Unitxt card's ID
7068
// +optional
71-
SecretRef *corev1.SecretKeySelector `json:"secretRef,omitempty"`
72-
// The secret is from a plain text
69+
Name string `json:"name,omitempty"`
70+
// A JSON string for a custom unitxt card which contains the custom dataset.
71+
// Use the documentation here: https://www.unitxt.ai/en/latest/docs/adding_dataset.html#adding-to-the-catalog
72+
// to compose a custom card, store it as a JSON file, and use the JSON content as the value here.
7373
// +optional
74-
Secret *string `json:"secret,omitempty"`
75-
}
76-
77-
type FileSecret struct {
78-
// The secret object
79-
SecretRef corev1.SecretVolumeSource `json:"secretRef,omitempty"`
80-
// The path to mount the secret
81-
MountPath string `json:"mountPath"`
74+
Custom string `json:"custom,omitempty"`
8275
}
8376

8477
// Use a task recipe to form a custom task. It maps to the Unitxt Recipe
8578
// Find details of the Unitxt Recipe here:
8679
// https://www.unitxt.ai/en/latest/unitxt.standard.html#unitxt.standard.StandardRecipe
8780
type TaskRecipe struct {
8881
// The Unitxt dataset card
89-
Card string `json:"card"`
82+
Card Card `json:"card"`
9083
// The Unitxt template
9184
Template string `json:"template"`
9285
// The Unitxt Task
@@ -118,7 +111,7 @@ type TaskList struct {
118111

119112
func (t *TaskRecipe) String() string {
120113
var b strings.Builder
121-
b.WriteString(fmt.Sprintf("card=%s,template=%s", t.Card, t.Template))
114+
b.WriteString(fmt.Sprintf("card=%s,template=%s", t.Card.Name, t.Template))
122115
if t.Task != nil {
123116
b.WriteString(fmt.Sprintf(",task=%s", *t.Task))
124117
}
@@ -140,6 +133,76 @@ func (t *TaskRecipe) String() string {
140133
return b.String()
141134
}
142135

136+
type LMEvalContainer struct {
137+
// Define Env information for the main container
138+
// +optional
139+
Env []corev1.EnvVar `json:"env,omitempty"`
140+
// Define the volume mount information
141+
// +optional
142+
VolumeMounts []corev1.VolumeMount `json:"volumeMounts,omitempty"`
143+
// Compute Resources required by this container.
144+
// More info: https://kubernetes.io/docs/concepts/configuration/manage-resources-containers/
145+
// +optional
146+
Resources *corev1.ResourceRequirements `json:"resources,omitempty"`
147+
}
148+
149+
// The following Getter-ish functions avoid nil pointer panic
150+
func (c *LMEvalContainer) GetEnv() []corev1.EnvVar {
151+
if c == nil {
152+
return nil
153+
}
154+
return c.Env
155+
}
156+
157+
func (c *LMEvalContainer) GetVolumMounts() []corev1.VolumeMount {
158+
if c == nil {
159+
return nil
160+
}
161+
return c.VolumeMounts
162+
}
163+
164+
func (c *LMEvalContainer) GetResources() *corev1.ResourceRequirements {
165+
if c == nil {
166+
return nil
167+
}
168+
return c.Resources
169+
}
170+
171+
type LMEvalPodSpec struct {
172+
// Extra container data for the lm-eval container
173+
// +optional
174+
Container *LMEvalContainer `json:"container,omitempty"`
175+
// Specify the volumes information for the lm-eval and sidecar containers
176+
// +optional
177+
Volumes []corev1.Volume `json:"volumes,omitempty"`
178+
// Specify extra containers for the lm-eval job
179+
// FIXME: aggregate the sidecar containers into the pod
180+
// +optional
181+
SideCars []corev1.Container `json:"sideCars,omitempty"`
182+
}
183+
184+
// The following Getter-ish functions avoid nil pointer panic
185+
func (p *LMEvalPodSpec) GetContainer() *LMEvalContainer {
186+
if p == nil {
187+
return nil
188+
}
189+
return p.Container
190+
}
191+
192+
func (p *LMEvalPodSpec) GetVolumes() []corev1.Volume {
193+
if p == nil {
194+
return nil
195+
}
196+
return p.Volumes
197+
}
198+
199+
func (p *LMEvalPodSpec) GetSideCards() []corev1.Container {
200+
if p == nil {
201+
return nil
202+
}
203+
return p.SideCars
204+
}
205+
143206
// LMEvalJobSpec defines the desired state of LMEvalJob
144207
type LMEvalJobSpec struct {
145208
// INSERT ADDITIONAL SPEC FIELDS - desired state of cluster
@@ -167,14 +230,12 @@ type LMEvalJobSpec struct {
167230
// model, will be saved at per-document granularity
168231
// +optional
169232
LogSamples *bool `json:"logSamples,omitempty"`
170-
// Assign secrets to the environment variables
171-
// +optional
172-
EnvSecrets []EnvSecret `json:"envSecrets,omitempty"`
173-
// Use secrets as files
174-
FileSecrets []FileSecret `json:"fileSecrets,omitempty"`
175233
// Batch size for the evaluation. This is used by the models that run and are loaded
176234
// locally and not apply for the commercial APIs.
177235
BatchSize *int `json:"batchSize,omitempty"`
236+
// Specify extra information for the lm-eval job's pod
237+
// +optional
238+
Pod *LMEvalPodSpec `json:"pod,omitempty"`
178239
}
179240

180241
// LMEvalJobStatus defines the observed state of LMEvalJob

api/lmes/v1alpha1/zz_generated.deepcopy.go

Lines changed: 67 additions & 33 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

cmd/lmes_driver/main.go

Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -37,20 +37,21 @@ const (
3737
OutputPath = "/opt/app-root/src/output"
3838
)
3939

40-
type taskRecipeArg []string
40+
type strArrayArg []string
4141

42-
func (t *taskRecipeArg) Set(value string) error {
42+
func (t *strArrayArg) Set(value string) error {
4343
*t = append(*t, value)
4444
return nil
4545
}
4646

47-
func (t *taskRecipeArg) String() string {
47+
func (t *strArrayArg) String() string {
4848
// supposedly, use ":" as the separator for task recipe should be safe
4949
return strings.Join(*t, ":")
5050
}
5151

5252
var (
53-
taskRecipes taskRecipeArg
53+
taskRecipes strArrayArg
54+
customCards strArrayArg
5455
copy = flag.String("copy", "", "copy this binary to specified destination path")
5556
jobNameSpace = flag.String("job-namespace", "", "Job's namespace ")
5657
jobName = flag.String("job-name", "", "Job's name")
@@ -64,6 +65,7 @@ var (
6465

6566
func init() {
6667
flag.Var(&taskRecipes, "task-recipe", "task recipe")
68+
flag.Var(&customCards, "custom-card", "A JSON string represents a custom card")
6769
}
6870

6971
func main() {
@@ -105,6 +107,7 @@ func main() {
105107
DetectDevice: *detectDevice,
106108
Logger: driverLog,
107109
TaskRecipes: taskRecipes,
110+
CustomCards: customCards,
108111
Args: args,
109112
ReportInterval: *reportInterval,
110113
}

cmd/lmes_driver/main_test.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -60,7 +60,7 @@ func Test_ArgParsing(t *testing.T) {
6060
assert.Equal(t, "/opt/app-root/src/output", *outputPath)
6161
assert.Equal(t, true, *detectDevice)
6262
assert.Equal(t, time.Second*10, *reportInterval)
63-
assert.Equal(t, taskRecipeArg{
63+
assert.Equal(t, strArrayArg{
6464
"card=unitxt.card1,template=unitxt.template,metrics=[unitxt.metric1,unitxt.metric2],format=unitxt.format,num_demos=5,demos_pool_size=10",
6565
"card=unitxt.card2,template=unitxt.template2,metrics=[unitxt.metric3,unitxt.metric4],format=unitxt.format,num_demos=5,demos_pool_size=10",
6666
}, taskRecipes)

0 commit comments

Comments
 (0)