diff --git a/internal/cmd/new.go b/internal/cmd/new.go index d8b2d3b48cc..2fa8f76424d 100644 --- a/internal/cmd/new.go +++ b/internal/cmd/new.go @@ -2,6 +2,8 @@ package cmd import ( "fmt" + "io" + "strings" "github.com/spf13/cobra" "github.com/spf13/pflag" @@ -23,12 +25,12 @@ func (c *newScriptCmd) flagSet() *pflag.FlagSet { flags := pflag.NewFlagSet("", pflag.ContinueOnError) flags.SortFlags = false flags.BoolVarP(&c.overwriteFiles, "force", "f", false, "overwrite existing files") - flags.StringVar(&c.templateType, "template", "minimal", "template type (choices: minimal, protocol, browser)") + flags.StringVar(&c.templateType, "template", "minimal", "template type (choices: minimal, protocol, browser) or relative/absolute path to a custom template file") //nolint:lll flags.StringVar(&c.projectID, "project-id", "", "specify the Grafana Cloud project ID for the test") return flags } -func (c *newScriptCmd) run(_ *cobra.Command, args []string) error { +func (c *newScriptCmd) run(_ *cobra.Command, args []string) (err error) { target := defaultNewScriptName if len(args) > 0 { target = args[0] @@ -42,27 +44,8 @@ func (c *newScriptCmd) run(_ *cobra.Command, args []string) error { return fmt.Errorf("%s already exists. Use the `--force` flag to overwrite it", target) } - fd, err := c.gs.FS.Create(target) - if err != nil { - return err - } - - var closeErr error - defer func() { - if cerr := fd.Close(); cerr != nil { - if _, err := fmt.Fprintf(c.gs.Stderr, "error closing file: %v\n", cerr); err != nil { - closeErr = fmt.Errorf("error writing error message to stderr: %w", err) - } else { - closeErr = cerr - } - } - }() - - if closeErr != nil { - return closeErr - } - - tm, err := templates.NewTemplateManager() + // Initialize template manager and validate template before creating any files + tm, err := templates.NewTemplateManager(c.gs.FS) if err != nil { return fmt.Errorf("error initializing template manager: %w", err) } @@ -72,12 +55,36 @@ func (c *newScriptCmd) run(_ *cobra.Command, args []string) error { return fmt.Errorf("error retrieving template: %w", err) } + // Prepare template arguments argsStruct := templates.TemplateArgs{ ScriptName: target, ProjectID: c.projectID, } - if err := templates.ExecuteTemplate(fd, tmpl, argsStruct); err != nil { + // First render the template to a buffer to validate it + var buf strings.Builder + if err := templates.ExecuteTemplate(&buf, tmpl, argsStruct); err != nil { + return fmt.Errorf("failed to execute template %s: %w", c.templateType, err) + } + + // Only create the file after template rendering succeeds + fd, err := c.gs.FS.Create(target) + if err != nil { + return err + } + + defer func() { + if cerr := fd.Close(); cerr != nil { + if _, werr := fmt.Fprintf(c.gs.Stderr, "error closing file: %v\n", cerr); werr != nil { + err = fmt.Errorf("error writing error message to stderr: %w", werr) + } else { + err = cerr + } + } + }() + + // Write the rendered content to the file + if _, err := io.WriteString(fd, buf.String()); err != nil { return err } diff --git a/internal/cmd/new_test.go b/internal/cmd/new_test.go index 8a669037e98..c7f718afcc1 100644 --- a/internal/cmd/new_test.go +++ b/internal/cmd/new_test.go @@ -1,6 +1,7 @@ package cmd import ( + "path/filepath" "testing" "github.com/stretchr/testify/assert" @@ -99,11 +100,15 @@ func TestNewScriptCmd_InvalidTemplateType(t *testing.T) { ts := tests.NewGlobalTestState(t) ts.CmdArgs = []string{"k6", "new", "--template", "invalid-template"} - ts.ExpectedExitCode = -1 newRootCommand(ts.GlobalState).execute() assert.Contains(t, ts.Stderr.String(), "invalid template type") + + // Verify that no script file was created + exists, err := fsext.Exists(ts.FS, defaultNewScriptName) + require.NoError(t, err) + assert.False(t, exists, "script file should not exist") } func TestNewScriptCmd_ProjectID(t *testing.T) { @@ -119,3 +124,101 @@ func TestNewScriptCmd_ProjectID(t *testing.T) { assert.Contains(t, string(data), "projectID: 1422") } + +func TestNewScriptCmd_LocalTemplate(t *testing.T) { + t.Parallel() + + ts := tests.NewGlobalTestState(t) + + // Create template file in test temp directory + templatePath := filepath.Join(t.TempDir(), "template.js") + templateContent := `export default function() { + console.log("Hello, world!"); +}` + require.NoError(t, fsext.WriteFile(ts.FS, templatePath, []byte(templateContent), 0o600)) + + ts.CmdArgs = []string{"k6", "new", "--template", templatePath} + + newRootCommand(ts.GlobalState).execute() + + data, err := fsext.ReadFile(ts.FS, defaultNewScriptName) + require.NoError(t, err) + + assert.Equal(t, templateContent, string(data), "generated file should match the template content") +} + +func TestNewScriptCmd_LocalTemplateWith_ProjectID(t *testing.T) { + t.Parallel() + + ts := tests.NewGlobalTestState(t) + + // Create template file in test temp directory + templatePath := filepath.Join(t.TempDir(), "template.js") + templateContent := `export default function() { + // Template with {{ .ProjectID }} project ID + console.log("Hello from project {{ .ProjectID }}"); +}` + require.NoError(t, fsext.WriteFile(ts.FS, templatePath, []byte(templateContent), 0o600)) + + ts.CmdArgs = []string{"k6", "new", "--template", templatePath, "--project-id", "9876"} + + newRootCommand(ts.GlobalState).execute() + + data, err := fsext.ReadFile(ts.FS, defaultNewScriptName) + require.NoError(t, err) + + expectedContent := `export default function() { + // Template with 9876 project ID + console.log("Hello from project 9876"); +}` + assert.Equal(t, expectedContent, string(data), "generated file should have project ID interpolated") +} + +func TestNewScriptCmd_LocalTemplate_NonExistentFile(t *testing.T) { + t.Parallel() + + ts := tests.NewGlobalTestState(t) + ts.ExpectedExitCode = -1 + + // Use a path that we know doesn't exist in the temp directory + nonExistentPath := filepath.Join(t.TempDir(), "nonexistent.js") + + ts.CmdArgs = []string{"k6", "new", "--template", nonExistentPath} + ts.ExpectedExitCode = -1 + + newRootCommand(ts.GlobalState).execute() + + assert.Contains(t, ts.Stderr.String(), "failed to read template file") + + // Verify that no script file was created + exists, err := fsext.Exists(ts.FS, defaultNewScriptName) + require.NoError(t, err) + assert.False(t, exists, "script file should not exist") +} + +func TestNewScriptCmd_LocalTemplate_SyntaxError(t *testing.T) { + t.Parallel() + + ts := tests.NewGlobalTestState(t) + ts.ExpectedExitCode = -1 + + // Create template file with invalid content in test temp directory + templatePath := filepath.Join(t.TempDir(), "template.js") + invalidTemplateContent := `export default function() { + // Invalid template with {{ .InvalidField }} field + console.log("This will cause an error"); +}` + require.NoError(t, fsext.WriteFile(ts.FS, templatePath, []byte(invalidTemplateContent), 0o600)) + + ts.CmdArgs = []string{"k6", "new", "--template", templatePath, "--project-id", "9876"} + ts.ExpectedExitCode = -1 + + newRootCommand(ts.GlobalState).execute() + + assert.Contains(t, ts.Stderr.String(), "failed to execute template") + + // Verify that no script file was created + exists, err := fsext.Exists(ts.FS, defaultNewScriptName) + require.NoError(t, err) + assert.False(t, exists, "script file should not exist") +} diff --git a/internal/cmd/templates/templates.go b/internal/cmd/templates/templates.go index 55fc22b54ce..1d2783933ea 100644 --- a/internal/cmd/templates/templates.go +++ b/internal/cmd/templates/templates.go @@ -5,7 +5,11 @@ import ( _ "embed" "fmt" "io" + "path/filepath" + "strings" "text/template" + + "go.k6.io/k6/lib/fsext" ) //go:embed minimal.js @@ -18,6 +22,7 @@ var protocolTemplateContent string var browserTemplateContent string // Constants for template types +// Template names should not contain path separators to not to be confused with file paths const ( MinimalTemplate = "minimal" ProtocolTemplate = "protocol" @@ -29,10 +34,11 @@ type TemplateManager struct { minimalTemplate *template.Template protocolTemplate *template.Template browserTemplate *template.Template + fs fsext.Fs } // NewTemplateManager initializes a new TemplateManager with parsed templates -func NewTemplateManager() (*TemplateManager, error) { +func NewTemplateManager(fs fsext.Fs) (*TemplateManager, error) { minimalTmpl, err := template.New(MinimalTemplate).Parse(minimalTemplateContent) if err != nil { return nil, fmt.Errorf("failed to parse minimal template: %w", err) @@ -52,21 +58,56 @@ func NewTemplateManager() (*TemplateManager, error) { minimalTemplate: minimalTmpl, protocolTemplate: protocolTmpl, browserTemplate: browserTmpl, + fs: fs, }, nil } // GetTemplate selects the appropriate template based on the type -func (tm *TemplateManager) GetTemplate(templateType string) (*template.Template, error) { - switch templateType { +func (tm *TemplateManager) GetTemplate(tpl string) (*template.Template, error) { + // First check built-in templates + switch tpl { case MinimalTemplate: return tm.minimalTemplate, nil case ProtocolTemplate: return tm.protocolTemplate, nil case BrowserTemplate: return tm.browserTemplate, nil - default: - return nil, fmt.Errorf("invalid template type: %s", templateType) } + + // Then check if it's a file path + if isFilePath(tpl) { + tplPath, err := filepath.Abs(tpl) + if err != nil { + return nil, fmt.Errorf("failed to get absolute path for template %s: %w", tpl, err) + } + + // Read the template content using the provided filesystem + content, err := fsext.ReadFile(tm.fs, tplPath) + if err != nil { + return nil, fmt.Errorf("failed to read template file %s: %w", tpl, err) + } + + tmpl, err := template.New(filepath.Base(tplPath)).Parse(string(content)) + if err != nil { + return nil, fmt.Errorf("failed to parse template file %s: %w", tpl, err) + } + + return tmpl, nil + } + + // Check if there's a file with this name in current directory + exists, err := fsext.Exists(tm.fs, fsext.JoinFilePath(".", tpl)) + if err == nil && exists { + return nil, fmt.Errorf("invalid template type %q, did you mean ./%s?", tpl, tpl) + } + + return nil, fmt.Errorf("invalid template type %q", tpl) +} + +// isFilePath checks if the given string looks like a file path by detecting path separators +// We assume that built-in template names don't contain path separators +func isFilePath(path string) bool { + return strings.ContainsRune(path, filepath.Separator) || strings.ContainsRune(path, '/') } // TemplateArgs represents arguments passed to templates