Skip to content

Commit 7aa5e82

Browse files
committed
refactor, Allow the "write-tests" test framework to be set within the task so it may be overwritten
1 parent 08d55eb commit 7aa5e82

File tree

2 files changed

+12
-5
lines changed

2 files changed

+12
-5
lines changed

evaluate/task/write-test.go

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,8 @@ var _ evaltask.Task = (*WriteTests)(nil)
2424
type ArgumentsWriteTest struct {
2525
// Template holds the template data to base the tests onto.
2626
Template string
27+
// TestFramework holds the test framework to use.
28+
TestFramework string
2729
}
2830

2931
// Identifier returns the write test task identifier.
@@ -71,7 +73,9 @@ func (t *WriteTests) Run(ctx evaltask.Context) (repositoryAssessment map[evaltas
7173
ctx.Logger.Panicf("ERROR: unable to reset temporary repository path: %s", err)
7274
}
7375

74-
modelAssessmentFile, withSymflowerFixAssessmentFile, ps, err := runModelAndSymflowerFix(ctx, taskLogger, modelCapability, dataPath, filePath, &ArgumentsWriteTest{})
76+
modelAssessmentFile, withSymflowerFixAssessmentFile, ps, err := runModelAndSymflowerFix(ctx, taskLogger, modelCapability, dataPath, filePath, &ArgumentsWriteTest{
77+
TestFramework: ctx.Language.TestFramework(),
78+
})
7579
problems = append(problems, ps...)
7680
if err != nil {
7781
return nil, problems, err
@@ -105,7 +109,8 @@ func (t *WriteTests) Run(ctx evaltask.Context) (repositoryAssessment map[evaltas
105109
}
106110

107111
modelTemplateAssessmentFile, templateWithSymflowerFixAssessmentFile, ps, err := runModelAndSymflowerFix(ctx, taskLogger, modelCapability, dataPath, filePath, &ArgumentsWriteTest{
108-
Template: string(testTemplate),
112+
Template: string(testTemplate),
113+
TestFramework: ctx.Language.TestFramework(),
109114
})
110115
problems = append(problems, ps...)
111116
if err != nil {

model/llm/llm.go

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -83,11 +83,13 @@ type llmWriteTestSourceFilePromptContext struct {
8383

8484
// Template holds the template data to base the tests onto.
8585
Template string
86+
// TestFramework holds the test framework to use.
87+
TestFramework string
8688
}
8789

8890
// llmWriteTestForFilePromptTemplate is the template for generating an LLM test generation prompt.
8991
var llmWriteTestForFilePromptTemplate = template.Must(template.New("model-llm-write-test-for-file-prompt").Parse(bytesutil.StringTrimIndentations(`
90-
Given the following {{ .Language.Name }} code file "{{ .FilePath }}" with package "{{ .ImportPath }}", provide a test file for this code{{ with $testFramework := .Language.TestFramework }} with {{ $testFramework }} as a test framework{{ end }}.
92+
Given the following {{ .Language.Name }} code file "{{ .FilePath }}" with package "{{ .ImportPath }}", provide a test file for this code{{ with .TestFramework }} with {{ . }} as a test framework{{ end }}.
9193
The tests should produce 100 percent code coverage and must compile.
9294
The response must contain only the test code in a fenced code block and nothing else.
9395
@@ -211,7 +213,6 @@ func (m *Model) WriteTests(ctx model.Context) (assessment metrics.Assessments, e
211213
if !ok {
212214
return nil, pkgerrors.Errorf("unexpected type %T", ctx.Arguments)
213215
}
214-
templateContent := arguments.Template
215216

216217
data, err := os.ReadFile(filepath.Join(ctx.RepositoryPath, ctx.FilePath))
217218
if err != nil {
@@ -230,7 +231,8 @@ func (m *Model) WriteTests(ctx model.Context) (assessment metrics.Assessments, e
230231
ImportPath: importPath,
231232
},
232233

233-
Template: templateContent,
234+
Template: arguments.Template,
235+
TestFramework: arguments.TestFramework,
234236
}).Format()
235237
if err != nil {
236238
return nil, err

0 commit comments

Comments
 (0)