Skip to content

Commit e6f8682

Browse files
author
Andreas Humenberger
committed
Let the "write-tests" prompt ask the model to extend the given source file when the tests belong in the same file as the source
1 parent 5d5eb2a commit e6f8682

File tree

3 files changed

+27
-35
lines changed

3 files changed

+27
-35
lines changed

evaluate/task/test-integration/task_test.go

Lines changed: 14 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -109,9 +109,9 @@ func TestWriteTestsRun(t *testing.T) {
109109

110110
Setup: func(t *testing.T) {
111111
var query any = bytesutil.StringTrimIndentations(`
112-
Given the following Rust code file "src/plain.rs", provide tests for this code.
112+
Given the following Rust code file "src/plain.rs", extend the code to include tests.
113113
The tests should produce 100 percent code coverage and must compile.
114-
The response must contain only the test code in a fenced code block and nothing else.
114+
The response must contain the original source code and the tests in a fenced code block and nothing else.
115115
116116
` + "```" + `rust
117117
pub fn plain() {
@@ -127,6 +127,10 @@ func TestWriteTestsRun(t *testing.T) {
127127
&provider.QueryResult{
128128
Message: bytesutil.StringTrimIndentations(`
129129
` + "```rust`" + `
130+
pub fn plain() {
131+
// This does not do anything but it gives us a line to cover.
132+
}
133+
130134
#[cfg(test)]
131135
mod tests {
132136
use super::*;
@@ -151,8 +155,8 @@ func TestWriteTestsRun(t *testing.T) {
151155
ExpectedRepositoryAssessment: map[string]map[evaltask.Identifier]metrics.Assessments{
152156
filepath.Join("src", "plain.rs"): {
153157
evaluatetask.IdentifierWriteTests: metrics.Assessments{
154-
metrics.AssessmentKeyGenerateTestsForFileCharacterCount: 84,
155-
metrics.AssessmentKeyResponseCharacterCount: 98,
158+
metrics.AssessmentKeyGenerateTestsForFileCharacterCount: 167,
159+
metrics.AssessmentKeyResponseCharacterCount: 181,
156160
metrics.AssessmentKeyCoverage: 3,
157161
metrics.AssessmentKeyFilesExecuted: 1,
158162
metrics.AssessmentKeyFilesExecutedMaximumReachable: 1,
@@ -161,8 +165,8 @@ func TestWriteTestsRun(t *testing.T) {
161165
metrics.AssessmentKeyResponseWithCode: 1,
162166
},
163167
evaluatetask.IdentifierWriteTestsSymflowerFix: metrics.Assessments{
164-
metrics.AssessmentKeyGenerateTestsForFileCharacterCount: 84,
165-
metrics.AssessmentKeyResponseCharacterCount: 98,
168+
metrics.AssessmentKeyGenerateTestsForFileCharacterCount: 167,
169+
metrics.AssessmentKeyResponseCharacterCount: 181,
166170
metrics.AssessmentKeyCoverage: 3,
167171
metrics.AssessmentKeyFilesExecuted: 1,
168172
metrics.AssessmentKeyFilesExecutedMaximumReachable: 1,
@@ -171,8 +175,8 @@ func TestWriteTestsRun(t *testing.T) {
171175
metrics.AssessmentKeyResponseWithCode: 1,
172176
},
173177
evaluatetask.IdentifierWriteTestsSymflowerTemplate: metrics.Assessments{
174-
metrics.AssessmentKeyGenerateTestsForFileCharacterCount: 84,
175-
metrics.AssessmentKeyResponseCharacterCount: 98,
178+
metrics.AssessmentKeyGenerateTestsForFileCharacterCount: 167,
179+
metrics.AssessmentKeyResponseCharacterCount: 181,
176180
metrics.AssessmentKeyCoverage: 3,
177181
metrics.AssessmentKeyFilesExecuted: 1,
178182
metrics.AssessmentKeyFilesExecutedMaximumReachable: 1,
@@ -181,8 +185,8 @@ func TestWriteTestsRun(t *testing.T) {
181185
metrics.AssessmentKeyResponseWithCode: 1,
182186
},
183187
evaluatetask.IdentifierWriteTestsSymflowerTemplateSymflowerFix: metrics.Assessments{
184-
metrics.AssessmentKeyGenerateTestsForFileCharacterCount: 84,
185-
metrics.AssessmentKeyResponseCharacterCount: 98,
188+
metrics.AssessmentKeyGenerateTestsForFileCharacterCount: 167,
189+
metrics.AssessmentKeyResponseCharacterCount: 181,
186190
metrics.AssessmentKeyCoverage: 3,
187191
metrics.AssessmentKeyFilesExecuted: 1,
188192
metrics.AssessmentKeyFilesExecutedMaximumReachable: 1,

model/llm/llm.go

Lines changed: 11 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,6 @@ package llm
22

33
import (
44
"context"
5-
"errors"
65
"os"
76
"path/filepath"
87
"strings"
@@ -148,9 +147,13 @@ type llmWriteTestSourceFilePromptContext struct {
148147

149148
// llmWriteTestForFilePromptTemplate is the template for generating an LLM test generation prompt.
150149
var llmWriteTestForFilePromptTemplate = template.Must(template.New("model-llm-write-test-for-file-prompt").Parse(bytesutil.StringTrimIndentations(`
151-
Given the following {{ .Language.Name }} code file "{{ .FilePath }}" {{- with .ImportPath }} with package "{{ . }}" {{- end }}, provide {{- if .HasTestsInSource }} tests {{ else }} a test file {{ end -}} for this code{{ with .TestFramework }} with {{ . }} as a test framework{{ end }}.
150+
Given the following {{ .Language.Name }} code file "{{ .FilePath }}" {{- with .ImportPath }} with package "{{ . }}" {{- end }}, {{- if .HasTestsInSource }} extend the code to include tests{{ else }} provide a test file for this code{{ with .TestFramework }} with {{ . }} as a test framework{{ end }}{{ end -}}.
152151
The tests should produce 100 percent code coverage and must compile.
152+
{{- if .HasTestsInSource }}
153+
The response must contain the original source code and the tests in a fenced code block and nothing else.
154+
{{- else }}
153155
The response must contain only the test code in a fenced code block and nothing else.
156+
{{- end }}
154157
155158
` + "```" + `{{ .Language.ID }}
156159
{{ .Code }}
@@ -337,7 +340,7 @@ func (m *Model) WriteTests(ctx model.Context) (assessment metrics.Assessments, e
337340
filePath = filepath.Join(ctx.RepositoryPath, ctx.Language.TestFilePath(ctx.RepositoryPath, ctx.FilePath))
338341
}
339342

340-
return handleQueryResult(queryResult, filePath, ctx.HasTestsInSource)
343+
return handleQueryResult(queryResult, filePath)
341344
}
342345

343346
func (m *Model) query(logger *log.Logger, request string) (queryResult *provider.QueryResult, err error) {
@@ -422,7 +425,7 @@ func (m *Model) RepairCode(ctx model.Context) (assessment metrics.Assessments, e
422425
return nil, pkgerrors.WithStack(err)
423426
}
424427

425-
return handleQueryResult(queryResult, filepath.Join(ctx.RepositoryPath, ctx.FilePath), false)
428+
return handleQueryResult(queryResult, filepath.Join(ctx.RepositoryPath, ctx.FilePath))
426429
}
427430

428431
var _ model.CapabilityTranspile = (*Model)(nil)
@@ -469,7 +472,7 @@ func (m *Model) Transpile(ctx model.Context) (assessment metrics.Assessments, er
469472
return nil, pkgerrors.WithStack(err)
470473
}
471474

472-
return handleQueryResult(queryResult, filepath.Join(ctx.RepositoryPath, ctx.FilePath), false)
475+
return handleQueryResult(queryResult, filepath.Join(ctx.RepositoryPath, ctx.FilePath))
473476
}
474477

475478
var _ model.CapabilityMigrate = (*Model)(nil)
@@ -509,10 +512,10 @@ func (m *Model) Migrate(ctx model.Context) (assessment metrics.Assessments, err
509512
return nil, pkgerrors.WithStack(err)
510513
}
511514

512-
return handleQueryResult(queryResult, filepath.Join(ctx.RepositoryPath, ctx.FilePath), false)
515+
return handleQueryResult(queryResult, filepath.Join(ctx.RepositoryPath, ctx.FilePath))
513516
}
514517

515-
func handleQueryResult(queryResult *provider.QueryResult, filePathAbsolute string, appendFile bool) (assessment metrics.Assessments, err error) {
518+
func handleQueryResult(queryResult *provider.QueryResult, filePathAbsolute string) (assessment metrics.Assessments, err error) {
516519
assessment, sourceFileContent, err := prompt.ParseResponse(queryResult.Message)
517520
if err != nil {
518521
return nil, pkgerrors.WithStack(err)
@@ -536,22 +539,7 @@ func handleQueryResult(queryResult *provider.QueryResult, filePathAbsolute strin
536539
return nil, pkgerrors.WithStack(err)
537540
}
538541

539-
flags := os.O_WRONLY | os.O_CREATE
540-
if appendFile {
541-
flags = flags | os.O_APPEND
542-
} else {
543-
flags = flags | os.O_TRUNC
544-
}
545-
file, err := os.OpenFile(filePathAbsolute, flags, 0644)
546-
if err != nil {
547-
return nil, pkgerrors.WithStack(err)
548-
}
549-
defer func() {
550-
if closeErr := file.Close(); closeErr != nil {
551-
err = errors.Join(err, pkgerrors.WithStack(closeErr))
552-
}
553-
}()
554-
if _, err := file.WriteString(sourceFileContent); err != nil {
542+
if err := os.WriteFile(filePathAbsolute, []byte(sourceFileContent), 0644); err != nil {
555543
return nil, pkgerrors.WithStack(err)
556544
}
557545

model/llm/llm_test.go

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -708,9 +708,9 @@ func TestFormatPromptContext(t *testing.T) {
708708
},
709709

710710
ExpectedMessage: bytesutil.StringTrimIndentations(`
711-
Given the following Rust code file "path/to/main.rs", provide tests for this code.
711+
Given the following Rust code file "path/to/main.rs", extend the code to include tests.
712712
The tests should produce 100 percent code coverage and must compile.
713-
The response must contain only the test code in a fenced code block and nothing else.
713+
The response must contain the original source code and the tests in a fenced code block and nothing else.
714714
715715
` + "```" + `rust
716716
fn main() {

0 commit comments

Comments
 (0)