Skip to content

Commit

Permalink
Include file name and import path for LLM prompt for better results o…
Browse files Browse the repository at this point in the history
…f LLMs
  • Loading branch information
zimmski committed Apr 3, 2024
1 parent ca66931 commit 62ab016
Show file tree
Hide file tree
Showing 2 changed files with 63 additions and 39 deletions.
56 changes: 36 additions & 20 deletions model/llm/llm.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ import (
"strings"
"text/template"

pkgerrors "github.com/pkg/errors"
"github.com/zimmski/osutil/bytesutil"

"github.com/symflower/eval-symflower-codegen-testing/model"
Expand All @@ -31,27 +32,39 @@ func NewLLMModel(provider provider.QueryProvider, modelIdentifier string) model.
}
}

// llmGenerateTestForFilePrompt is the prompt used to query LLMs for test generation.
var llmGenerateTestForFilePrompt = bytesutil.StringTrimIndentations(`
Given the following Go code file, provide a test file for this code.
The tests should produce 100 percent code coverage and must compile.
The response must contain only the test code and nothing else.
`)

// llmGenerateTestForFilePromptContext is the context for for the prompt template.
// llmGenerateTestForFilePromptContext is the context for template for generating an LLM test generation prompt.
type llmGenerateTestForFilePromptContext struct {
Prompt string
Code string
// Code holds the source code of the file.
Code string
// FilePath holds the file path of the file.
FilePath string
// ImportPath holds the import path of the file.
ImportPath string
}

// llmGenerateTestForFilePromptData is the template for generating an LLM test generation prompt.
var llmGenerateTestForFilePromptData = template.Must(template.New("templateGenerateTestPrompt").Parse(bytesutil.StringTrimIndentations(`
{{ .Prompt }}
// llmGenerateTestForFilePromptTemplate is the template for generating an LLM test generation prompt.
var llmGenerateTestForFilePromptTemplate = template.Must(template.New("model-llm-generate-test-for-file-prompt").Parse(bytesutil.StringTrimIndentations(`
Given the following Go code file "{{ .FilePath }}" with package "{{ .ImportPath }}", provide a test file for this code.
The tests should produce 100 percent code coverage and must compile.
The response must contain only the test code and nothing else.
` + "```" + `
{{ .Code }}
` + "```" + `
`)))

// llmGenerateTestForFilePrompt returns the prompt for generating an LLM test generation.
func llmGenerateTestForFilePrompt(data *llmGenerateTestForFilePromptContext) (message string, err error) {
data.Code = strings.TrimSpace(data.Code)

var b strings.Builder
if err := llmGenerateTestForFilePromptTemplate.Execute(&b, data); err != nil {
return "", pkgerrors.WithStack(err)
}

return b.String(), nil
}

var _ model.Model = (*llm)(nil)

// ID returns the unique ID of this model.
Expand All @@ -67,19 +80,22 @@ func (m *llm) GenerateTestsForFile(repositoryPath string, filePath string) (err
}
fileContent := strings.TrimSpace(string(data))

var promptBuilder strings.Builder
if err = llmGenerateTestForFilePromptData.Execute(&promptBuilder, llmGenerateTestForFilePromptContext{
Prompt: llmGenerateTestForFilePrompt,
Code: fileContent,
}); err != nil {
importPath := filepath.Join(filepath.Base(repositoryPath), filepath.Dir(filePath))

message, err := llmGenerateTestForFilePrompt(&llmGenerateTestForFilePromptContext{
Code: fileContent,
FilePath: filePath,
ImportPath: importPath,
})
if err != nil {
return err
}

response, err := m.provider.Query(context.Background(), m.model, promptBuilder.String())
response, err := m.provider.Query(context.Background(), m.model, message)
if err != nil {
return err
}
log.Printf("Model %q responded to query %q with: %q", m.ID(), promptBuilder.String(), response)
log.Printf("Model %q responded to query %q with: %q", m.ID(), message, response)

testContent := prompt.ParseResponse(response)

Expand Down
46 changes: 27 additions & 19 deletions model/llm/llm_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -30,49 +30,57 @@ func TestModelLLMGenerateTestsForFile(t *testing.T) {

validate := func(t *testing.T, tc *testCase) {
t.Run(tc.Name, func(t *testing.T) {
tempDir := t.TempDir()
require.NoError(t, os.WriteFile(filepath.Join(tempDir, tc.SourceFilePath), []byte(bytesutil.StringTrimIndentations(tc.SourceFileContent)), 0644))
temporaryPath := t.TempDir()
temporaryPath = filepath.Join(temporaryPath, "native")
require.NoError(t, os.Mkdir(temporaryPath, 0755))

require.NoError(t, os.WriteFile(filepath.Join(temporaryPath, tc.SourceFilePath), []byte(bytesutil.StringTrimIndentations(tc.SourceFileContent)), 0644))

mock := &providertesting.MockQueryProvider{}
tc.SetupMock(mock)
llm := NewLLMModel(mock, tc.ModelID)

assert.NoError(t, llm.GenerateTestsForFile(tempDir, tc.SourceFilePath))
assert.NoError(t, llm.GenerateTestsForFile(temporaryPath, tc.SourceFilePath))

actualTestFileContent, err := os.ReadFile(filepath.Join(tempDir, tc.ExpectedTestFilePath))
actualTestFileContent, err := os.ReadFile(filepath.Join(temporaryPath, tc.ExpectedTestFilePath))
assert.NoError(t, err)

assert.Equal(t, strings.TrimSpace(bytesutil.StringTrimIndentations(tc.ExpectedTestFileContent)), string(actualTestFileContent))
})
}

sourceFileContent := `
package native
func main() {}
`
sourceFilePath := "simple.go"
promptMessage, err := llmGenerateTestForFilePrompt(&llmGenerateTestForFilePromptContext{
Code: bytesutil.StringTrimIndentations(sourceFileContent),
FilePath: sourceFilePath,
ImportPath: "native",
})
require.NoError(t, err)
validate(t, &testCase{
Name: "Simple",

SetupMock: func(mockedProvider *providertesting.MockQueryProvider) {
mockedProvider.On("Query", mock.Anything, "model-id",
bytesutil.StringTrimIndentations(`
Given the following Go code file, provide a test file for this code.
The tests should produce 100 percent code coverage and must compile.
The response must contain only the test code and nothing else.
`+"```"+`
func main() {}
`+"```"+`
`)).Return(bytesutil.StringTrimIndentations(`
mockedProvider.On("Query", mock.Anything, "model-id", promptMessage).Return(bytesutil.StringTrimIndentations(`
`+"```"+`
package native
func TestMain() {}
`+"```"+`
`), nil)
},

SourceFileContent: `
func main() {}
`,
SourceFilePath: "simple.go",
ModelID: "model-id",
SourceFileContent: sourceFileContent,
SourceFilePath: sourceFilePath,
ModelID: "model-id",

ExpectedTestFileContent: `
package native
func TestMain() {}
`,
ExpectedTestFilePath: "simple_test.go",
Expand Down

0 comments on commit 62ab016

Please sign in to comment.