Skip to content

Commit cbb00e0

Browse files
authored
Add —config to allow the user to specify the config (#2291)
* Add —config to allow the user to specify the config * —config allows the user to tell cog which file to look at * This is useful when they have multiple configs for different environments * Fix lint * Inject configFilename into migrator * Change —config to be -f to conform to docker --------- Signed-off-by: Will Sackfield <[email protected]>
1 parent 497a8d6 commit cbb00e0

14 files changed

+46
-36
lines changed

pkg/cli/build.go

+8-1
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,7 @@ var buildStrip bool
2929
var buildPrecompile bool
3030
var buildFast bool
3131
var buildLocalImage bool
32+
var configFilename string
3233

3334
const useCogBaseImageFlagKey = "use-cog-base-image"
3435

@@ -53,6 +54,7 @@ func newBuildCommand() *cobra.Command {
5354
addPrecompileFlag(cmd)
5455
addFastFlag(cmd)
5556
addLocalImage(cmd)
57+
addConfigFlag(cmd)
5658
cmd.Flags().StringVarP(&buildTag, "tag", "t", "", "A name for the built image in the form 'repository:tag'")
5759
return cmd
5860
}
@@ -68,7 +70,7 @@ func buildCommand(cmd *cobra.Command, args []string) error {
6870
logClient := coglog.NewClient(client)
6971
logCtx := logClient.StartBuild(buildFast, buildLocalImage)
7072

71-
cfg, projectDir, err := config.GetConfig()
73+
cfg, projectDir, err := config.GetConfig(configFilename)
7274
if err != nil {
7375
logClient.EndBuild(ctx, err, logCtx)
7476
return err
@@ -172,6 +174,11 @@ func addLocalImage(cmd *cobra.Command) {
172174
_ = cmd.Flags().MarkHidden(localImage)
173175
}
174176

177+
func addConfigFlag(cmd *cobra.Command) {
178+
const configFlag = "f"
179+
cmd.Flags().StringVar(&configFilename, configFlag, "cog.yaml", "The name of the config file.")
180+
}
181+
175182
func checkMutuallyExclusiveFlags(cmd *cobra.Command, args []string) error {
176183
flags := []string{useCogBaseImageFlagKey, "use-cuda-base-image", "dockerfile"}
177184
var flagsSet []string

pkg/cli/debug.go

+3-3
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,6 @@ import (
88
"github.com/replicate/cog/pkg/config"
99
"github.com/replicate/cog/pkg/docker"
1010
"github.com/replicate/cog/pkg/dockerfile"
11-
"github.com/replicate/cog/pkg/global"
1211
"github.com/replicate/cog/pkg/util/console"
1312
)
1413

@@ -18,7 +17,7 @@ func newDebugCommand() *cobra.Command {
1817
cmd := &cobra.Command{
1918
Use: "debug",
2019
Hidden: true,
21-
Short: "Generate a Dockerfile from " + global.ConfigFilename,
20+
Short: "Generate a Dockerfile from cog",
2221
RunE: cmdDockerfile,
2322
}
2423

@@ -29,6 +28,7 @@ func newDebugCommand() *cobra.Command {
2928
addBuildTimestampFlag(cmd)
3029
addFastFlag(cmd)
3130
addLocalImage(cmd)
31+
addConfigFlag(cmd)
3232
cmd.Flags().StringVarP(&imageName, "image-name", "", "", "The image name to use for the generated Dockerfile")
3333

3434
return cmd
@@ -37,7 +37,7 @@ func newDebugCommand() *cobra.Command {
3737
func cmdDockerfile(cmd *cobra.Command, args []string) error {
3838
ctx := cmd.Context()
3939

40-
cfg, projectDir, err := config.GetConfig()
40+
cfg, projectDir, err := config.GetConfig(configFilename)
4141
if err != nil {
4242
return err
4343
}

pkg/cli/migrate.go

+2-1
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@ This will attempt to migrate your cog project to be compatible with fast boots.`
2121
}
2222

2323
addYesFlag(cmd)
24+
addConfigFlag(cmd)
2425

2526
return cmd
2627
}
@@ -31,7 +32,7 @@ func cmdMigrate(cmd *cobra.Command, args []string) error {
3132
if err != nil {
3233
return err
3334
}
34-
err = migrator.Migrate(ctx)
35+
err = migrator.Migrate(ctx, configFilename)
3536
if err != nil {
3637
return err
3738
}

pkg/cli/predict.go

+2-1
Original file line numberDiff line numberDiff line change
@@ -58,6 +58,7 @@ the prediction on that.`,
5858
addSetupTimeoutFlag(cmd)
5959
addFastFlag(cmd)
6060
addLocalImage(cmd)
61+
addConfigFlag(cmd)
6162

6263
cmd.Flags().StringArrayVarP(&inputFlags, "input", "i", []string{}, "Inputs, in the form name=value. if value is prefixed with @, then it is read from a file on disk. E.g. -i [email protected]")
6364
cmd.Flags().StringVarP(&outPath, "output", "o", "", "Output path")
@@ -78,7 +79,7 @@ func cmdPredict(cmd *cobra.Command, args []string) error {
7879
if len(args) == 0 {
7980
// Build image
8081

81-
cfg, projectDir, err := config.GetConfig()
82+
cfg, projectDir, err := config.GetConfig(configFilename)
8283
if err != nil {
8384
return err
8485
}

pkg/cli/push.go

+2-1
Original file line numberDiff line numberDiff line change
@@ -39,6 +39,7 @@ func newPushCommand() *cobra.Command {
3939
addPrecompileFlag(cmd)
4040
addFastFlag(cmd)
4141
addLocalImage(cmd)
42+
addConfigFlag(cmd)
4243

4344
return cmd
4445
}
@@ -54,7 +55,7 @@ func push(cmd *cobra.Command, args []string) error {
5455
logClient := coglog.NewClient(client)
5556
logCtx := logClient.StartPush(buildFast, buildLocalImage)
5657

57-
cfg, projectDir, err := config.GetConfig()
58+
cfg, projectDir, err := config.GetConfig(configFilename)
5859
if err != nil {
5960
logClient.EndPush(ctx, err, logCtx)
6061
return err

pkg/cli/run.go

+2-1
Original file line numberDiff line numberDiff line change
@@ -38,6 +38,7 @@ func newRunCommand() *cobra.Command {
3838
addGpusFlag(cmd)
3939
addFastFlag(cmd)
4040
addLocalImage(cmd)
41+
addConfigFlag(cmd)
4142

4243
flags := cmd.Flags()
4344
// Flags after first argument are considered args and passed to command
@@ -54,7 +55,7 @@ func newRunCommand() *cobra.Command {
5455
func run(cmd *cobra.Command, args []string) error {
5556
ctx := cmd.Context()
5657

57-
cfg, projectDir, err := config.GetConfig()
58+
cfg, projectDir, err := config.GetConfig(configFilename)
5859
if err != nil {
5960
return err
6061
}

pkg/cli/serve.go

+2-1
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,7 @@ Generate and run an HTTP server based on the declared model inputs and outputs.`
3434
addUseCogBaseImageFlag(cmd)
3535
addGpusFlag(cmd)
3636
addFastFlag(cmd)
37+
addConfigFlag(cmd)
3738

3839
cmd.Flags().IntVarP(&port, "port", "p", port, "Port on which to listen")
3940

@@ -43,7 +44,7 @@ Generate and run an HTTP server based on the declared model inputs and outputs.`
4344
func cmdServe(cmd *cobra.Command, arg []string) error {
4445
ctx := cmd.Context()
4546

46-
cfg, projectDir, err := config.GetConfig()
47+
cfg, projectDir, err := config.GetConfig(configFilename)
4748
if err != nil {
4849
return err
4950
}

pkg/cli/train.go

+2-1
Original file line numberDiff line numberDiff line change
@@ -44,6 +44,7 @@ Otherwise, it will build the model in the current directory and train it.`,
4444
addGpusFlag(cmd)
4545
addUseCogBaseImageFlag(cmd)
4646
addFastFlag(cmd)
47+
addConfigFlag(cmd)
4748

4849
cmd.Flags().StringArrayVarP(&trainInputFlags, "input", "i", []string{}, "Inputs, in the form name=value. if value is prefixed with @, then it is read from a file on disk. E.g. -i [email protected]")
4950
cmd.Flags().StringArrayVarP(&trainEnvFlags, "env", "e", []string{}, "Environment variables, in the form name=value")
@@ -61,7 +62,7 @@ func cmdTrain(cmd *cobra.Command, args []string) error {
6162
volumes := []docker.Volume{}
6263
gpus := gpusFlag
6364

64-
cfg, projectDir, err := config.GetConfig()
65+
cfg, projectDir, err := config.GetConfig(configFilename)
6566
if err != nil {
6667
return err
6768
}

pkg/config/load.go

+13-13
Original file line numberDiff line numberDiff line change
@@ -7,30 +7,30 @@ import (
77
"path/filepath"
88

99
"github.com/replicate/cog/pkg/errors"
10-
"github.com/replicate/cog/pkg/global"
1110
"github.com/replicate/cog/pkg/util/files"
1211
)
1312

1413
const maxSearchDepth = 100
1514

1615
// Returns the project's root directory, or the directory specified by the --project-dir flag
17-
func GetProjectDir() (string, error) {
16+
func GetProjectDir(configFilename string) (string, error) {
1817
cwd, err := os.Getwd()
1918
if err != nil {
2019
return "", err
2120
}
22-
return findProjectRootDir(cwd)
21+
return findProjectRootDir(cwd, configFilename)
2322
}
2423

2524
// Loads and instantiates a Config object
2625
// customDir can be specified to override the default - current working directory
27-
func GetConfig() (*Config, string, error) {
26+
func GetConfig(configFilename string) (*Config, string, error) {
2827
// Find the root project directory
29-
rootDir, err := GetProjectDir()
28+
rootDir, err := GetProjectDir(configFilename)
29+
3030
if err != nil {
3131
return nil, "", err
3232
}
33-
configPath := path.Join(rootDir, global.ConfigFilename)
33+
configPath := path.Join(rootDir, configFilename)
3434

3535
// Then try to load the config file from there
3636
config, err := loadConfigFromFile(configPath)
@@ -51,7 +51,7 @@ func loadConfigFromFile(file string) (*Config, error) {
5151
}
5252

5353
if !exists {
54-
return nil, fmt.Errorf("%s does not exist in %s. Are you in the right directory?", global.ConfigFilename, filepath.Dir(file))
54+
return nil, fmt.Errorf("%s does not exist in %s. Are you in the right directory?", filepath.Base(file), filepath.Dir(file))
5555
}
5656

5757
contents, err := os.ReadFile(file)
@@ -69,30 +69,30 @@ func loadConfigFromFile(file string) (*Config, error) {
6969
}
7070

7171
// Given a directory, find the cog config file in that directory
72-
func findConfigPathInDirectory(dir string) (configPath string, err error) {
73-
filePath := path.Join(dir, global.ConfigFilename)
72+
func findConfigPathInDirectory(dir string, configFilename string) (configPath string, err error) {
73+
filePath := path.Join(dir, configFilename)
7474
exists, err := files.Exists(filePath)
7575
if err != nil {
7676
return "", fmt.Errorf("Failed to scan directory %s for %s: %s", dir, filePath, err)
7777
} else if exists {
7878
return filePath, nil
7979
}
8080

81-
return "", errors.ConfigNotFound(fmt.Sprintf("%s not found in %s", global.ConfigFilename, dir))
81+
return "", errors.ConfigNotFound(fmt.Sprintf("%s not found in %s", configFilename, dir))
8282
}
8383

8484
// Walk up the directory tree to find the root of the project.
8585
// The project root is defined as the directory housing a `cog.yaml` file.
86-
func findProjectRootDir(startDir string) (string, error) {
86+
func findProjectRootDir(startDir string, configFilename string) (string, error) {
8787
dir := startDir
8888
for i := 0; i < maxSearchDepth; i++ {
89-
switch _, err := findConfigPathInDirectory(dir); {
89+
switch _, err := findConfigPathInDirectory(dir, configFilename); {
9090
case err != nil && !errors.IsConfigNotFound(err):
9191
return "", err
9292
case err == nil:
9393
return dir, nil
9494
case dir == "." || dir == "/":
95-
return "", errors.ConfigNotFound(fmt.Sprintf("%s not found in %s (or in any parent directories)", global.ConfigFilename, startDir))
95+
return "", errors.ConfigNotFound(fmt.Sprintf("%s not found in %s (or in any parent directories)", configFilename, startDir))
9696
}
9797

9898
dir = filepath.Dir(dir)

pkg/config/load_test.go

+2-2
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,7 @@ func TestFindProjectRootDirShouldFindParentDir(t *testing.T) {
2828
err = os.MkdirAll(subdir, 0o700)
2929
require.NoError(t, err)
3030

31-
foundDir, err := findProjectRootDir(subdir)
31+
foundDir, err := findProjectRootDir(subdir, "cog.yaml")
3232
require.NoError(t, err)
3333
require.Equal(t, foundDir, projectDir)
3434
}
@@ -40,6 +40,6 @@ func TestFindProjectRootDirShouldReturnErrIfNoConfig(t *testing.T) {
4040
err := os.MkdirAll(subdir, 0o700)
4141
require.NoError(t, err)
4242

43-
_, err = findProjectRootDir(subdir)
43+
_, err = findProjectRootDir(subdir, "cog.yaml")
4444
require.Error(t, err)
4545
}

pkg/global/global.go

-1
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,6 @@ var (
66
BuildTime = "none"
77
Debug = false
88
ProfilingEnabled = false
9-
ConfigFilename = "cog.yaml"
109
ReplicateRegistryHost = "r8.im"
1110
ReplicateWebsiteHost = "replicate.com"
1211
LabelNamespace = "run.cog."

pkg/migrate/migrator.go

+1-1
Original file line numberDiff line numberDiff line change
@@ -3,5 +3,5 @@ package migrate
33
import "context"
44

55
type Migrator interface {
6-
Migrate(ctx context.Context) error
6+
Migrate(ctx context.Context, configFilename string) error
77
}

pkg/migrate/migrator_v1_v1fast.go

+5-6
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,6 @@ import (
1616

1717
"github.com/replicate/cog/pkg/config"
1818
"github.com/replicate/cog/pkg/dockerfile"
19-
"github.com/replicate/cog/pkg/global"
2019
"github.com/replicate/cog/pkg/requirements"
2120
"github.com/replicate/cog/pkg/util"
2221
"github.com/replicate/cog/pkg/util/console"
@@ -40,8 +39,8 @@ func NewMigratorV1ToV1Fast(interactive bool) *MigratorV1ToV1Fast {
4039
}
4140
}
4241

43-
func (g *MigratorV1ToV1Fast) Migrate(ctx context.Context) error {
44-
cfg, projectDir, err := config.GetConfig()
42+
func (g *MigratorV1ToV1Fast) Migrate(ctx context.Context, configFilename string) error {
43+
cfg, projectDir, err := config.GetConfig(configFilename)
4544
if err != nil {
4645
return err
4746
}
@@ -57,7 +56,7 @@ func (g *MigratorV1ToV1Fast) Migrate(ctx context.Context) error {
5756
if err != nil {
5857
return err
5958
}
60-
err = g.flushConfig(cfg, projectDir)
59+
err = g.flushConfig(cfg, projectDir, configFilename)
6160
return err
6261
}
6362

@@ -167,7 +166,7 @@ func (g *MigratorV1ToV1Fast) checkPythonCode(ctx context.Context, cfg *config.Co
167166
return nil
168167
}
169168

170-
func (g *MigratorV1ToV1Fast) flushConfig(cfg *config.Config, dir string) error {
169+
func (g *MigratorV1ToV1Fast) flushConfig(cfg *config.Config, dir string, configFilename string) error {
171170
if cfg.Build == nil {
172171
cfg.Build = config.DefaultConfig().Build
173172
}
@@ -182,7 +181,7 @@ func (g *MigratorV1ToV1Fast) flushConfig(cfg *config.Config, dir string) error {
182181
}
183182
configStr := string(data)
184183

185-
configFilepath := filepath.Join(dir, global.ConfigFilename)
184+
configFilepath := filepath.Join(dir, configFilename)
186185
file, err := os.Open(configFilepath)
187186
if err != nil {
188187
return err

pkg/migrate/migrator_v1_v1fast_test.go

+2-3
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,6 @@ import (
88

99
"github.com/stretchr/testify/require"
1010

11-
"github.com/replicate/cog/pkg/global"
1211
"github.com/replicate/cog/pkg/requirements"
1312
)
1413

@@ -25,7 +24,7 @@ func TestMigrate(t *testing.T) {
2524
require.NoError(t, err)
2625

2726
// Write our test configs/code
28-
configFilepath := filepath.Join(dir, global.ConfigFilename)
27+
configFilepath := filepath.Join(dir, "cog.yaml")
2928
file, err := os.Create(configFilepath)
3029
require.NoError(t, err)
3130
_, err = file.WriteString(`build:
@@ -56,7 +55,7 @@ class Predictor(BasePredictor):
5655

5756
// Perform the migration
5857
migrator := NewMigratorV1ToV1Fast(false)
59-
err = migrator.Migrate(t.Context())
58+
err = migrator.Migrate(t.Context(), "cog.yaml")
6059
require.NoError(t, err)
6160

6261
// Check config output

0 commit comments

Comments
 (0)