4
4
"context"
5
5
"fmt"
6
6
"os"
7
+ "os/signal"
7
8
"strconv"
8
9
"strings"
9
10
@@ -80,7 +81,7 @@ func NewRootCmd() *cobra.Command {
80
81
PersistentPostRun : func (cmd * cobra.Command , args []string ) {
81
82
ctx := cmd .Context ()
82
83
ctx .Value (spanKey ).(trace.Span ).End ()
83
- ctx .Value (systemManagerKey ).(* system.CleanupManager ).Cleanup (cmd . Context () )
84
+ ctx .Value (systemManagerKey ).(* system.CleanupManager ).Cleanup (ctx )
84
85
},
85
86
}
86
87
// ====== Start a job
@@ -138,9 +139,12 @@ Ignored if BACALHAU_API_PORT environment variable is set.`,
138
139
}
139
140
140
141
func Execute () {
141
- RootCmd := NewRootCmd ()
142
- // ANCHOR: Set global context here
143
- RootCmd .SetContext (context .Background ())
142
+ rootCmd := NewRootCmd ()
143
+
144
+ // Ensure commands are able to stop cleanly if someone presses ctrl+c
145
+ ctx , cancel := signal .NotifyContext (context .Background (), ShutdownSignals ... )
146
+ defer cancel ()
147
+ rootCmd .SetContext (ctx )
144
148
145
149
doNotTrack = false
146
150
if doNotTrackValue , foundDoNotTrack := os .LookupEnv ("DO_NOT_TRACK" ); foundDoNotTrack {
@@ -152,40 +156,36 @@ func Execute() {
152
156
153
157
viper .SetEnvPrefix ("BACALHAU" )
154
158
155
- err := viper .BindEnv ("API_HOST" )
156
- if err != nil {
157
- log .Ctx (RootCmd .Context ()).Fatal ().Msgf ("API_HOST was set, but could not bind." )
159
+ if err := viper .BindEnv ("API_HOST" ); err != nil {
160
+ log .Ctx (ctx ).Fatal ().Msgf ("API_HOST was set, but could not bind." )
158
161
}
159
162
160
- err = viper .BindEnv ("API_PORT" )
161
- if err != nil {
162
- log .Ctx (RootCmd .Context ()).Fatal ().Msgf ("API_PORT was set, but could not bind." )
163
+ if err := viper .BindEnv ("API_PORT" ); err != nil {
164
+ log .Ctx (ctx ).Fatal ().Msgf ("API_PORT was set, but could not bind." )
163
165
}
164
166
165
167
viper .AutomaticEnv ()
166
- envAPIHost := viper .Get ("API_HOST" )
167
- envAPIPort := viper .Get ("API_PORT" )
168
168
169
- if envAPIHost != nil && envAPIHost != "" {
170
- apiHost = envAPIHost .( string )
169
+ if envAPIHost := viper . GetString ( "API_HOST" ); envAPIHost != "" {
170
+ apiHost = envAPIHost
171
171
}
172
172
173
- if envAPIPort != nil && envAPIPort != "" {
173
+ if envAPIPort := viper . GetString ( "API_PORT" ); envAPIPort != "" {
174
174
var parseErr error
175
- apiPort , parseErr = strconv .Atoi (envAPIPort .( string ) )
175
+ apiPort , parseErr = strconv .Atoi (envAPIPort )
176
176
if parseErr != nil {
177
- log .Ctx (RootCmd . Context () ).Fatal ().Msgf ("could not parse API_PORT into an int. %s" , envAPIPort )
177
+ log .Ctx (ctx ).Fatal ().Msgf ("could not parse API_PORT into an int. %s" , envAPIPort )
178
178
}
179
179
}
180
180
181
181
// Use stdout, not stderr for cmd.Print output, so that
182
182
// e.g. ID=$(bacalhau run) works
183
- RootCmd .SetOut (system .Stdout )
183
+ rootCmd .SetOut (system .Stdout )
184
184
// TODO this is from fixing a deprecation warning for SetOutput. Shouldn't this be system.Stderr?
185
- RootCmd .SetErr (system .Stdout )
185
+ rootCmd .SetErr (system .Stdout )
186
186
187
- if err := RootCmd .Execute (); err != nil {
188
- Fatal (RootCmd , err .Error (), 1 )
187
+ if err := rootCmd .Execute (); err != nil {
188
+ Fatal (rootCmd , err .Error (), 1 )
189
189
}
190
190
}
191
191
@@ -197,6 +197,8 @@ var systemManagerKey = contextKey{name: "context key for storing the system mana
197
197
var spanKey = contextKey {name : "context key for storing the root span" }
198
198
199
199
func checkVersion (cmd * cobra.Command , args []string ) error {
200
+ ctx := cmd .Context ()
201
+
200
202
// corba doesn't do PersistentPreRun{,E} chaining yet
201
203
// https://github.com/spf13/cobra/issues/252
202
204
root := cmd
@@ -205,8 +207,8 @@ func checkVersion(cmd *cobra.Command, args []string) error {
205
207
root .PersistentPreRun (cmd , args )
206
208
207
209
// Check that the server version is compatible with the client version
208
- serverVersion , _ := GetAPIClient ().Version (cmd . Context () ) // Ok if this fails, version validation will skip
209
- if err := ensureValidVersion (cmd . Context () , version .Get (), serverVersion ); err != nil {
210
+ serverVersion , _ := GetAPIClient ().Version (ctx ) // Ok if this fails, version validation will skip
211
+ if err := ensureValidVersion (ctx , version .Get (), serverVersion ); err != nil {
210
212
Fatal (cmd , fmt .Sprintf ("version validation failed: %s" , err ), 1 )
211
213
return err
212
214
}
0 commit comments