Skip to content

Commit

Permalink
Add support for count flag in vm create
Browse files Browse the repository at this point in the history
  • Loading branch information
jdewinne authored Oct 14, 2024
1 parent f2a6981 commit 2e1728e
Show file tree
Hide file tree
Showing 2 changed files with 53 additions and 32 deletions.
65 changes: 42 additions & 23 deletions cli/cmd/vm_create.go
Original file line number Diff line number Diff line change
Expand Up @@ -30,8 +30,8 @@ The command also supports a "--wait" flag to wait for the VMs to be ready before
Example: ` # Create a single Ubuntu 20.04 VM
replicated vm create --distribution ubuntu --version 20.04
# Create 3 RHEL 9 VMs
replicated vm create --distribution ubuntu --version 20.04 --count 3
# Create 3 Ubuntu 22.04 VMs
replicated vm create --distribution ubuntu --version 22.04 --count 3
# Create 5 Ubuntu VMs with a custom instance type and disk size
replicated vm create --distribution ubuntu --version 20.04 --count 5 --instance-type r1.medium --disk 100`,
Expand Down Expand Up @@ -83,7 +83,7 @@ func (r *runners) createVM(_ *cobra.Command, args []string) error {
DryRun: r.args.createVMDryRun,
}

vm, err := r.createAndWaitForVM(opts)
vms, err := r.createAndWaitForVM(opts)
if err != nil {
if errors.Cause(err) == ErrVMWaitDurationExceeded {
defer func() {
Expand All @@ -95,7 +95,11 @@ func (r *runners) createVM(_ *cobra.Command, args []string) error {
}

if opts.DryRun {
estimatedCostMessage := fmt.Sprintf("Estimated cost: %s (if run to TTL of %s)", print.CreditsToDollarsDisplay(vm.EstimatedCost), vm.TTL)
// This should not happen, as count should be > 0
if len(vms) == 0 {
return errors.New("no vm will be created")
}
estimatedCostMessage := fmt.Sprintf("Estimated cost: %s (if run to TTL of %s)", print.CreditsToDollarsDisplay(vms[0].EstimatedCost), vms[0].TTL)
_, err := fmt.Fprintln(r.w, estimatedCostMessage)
if err != nil {
return err
Expand All @@ -104,11 +108,11 @@ func (r *runners) createVM(_ *cobra.Command, args []string) error {
return err
}

return print.VM(r.outputFormat, r.w, vm)
return print.VMs(r.outputFormat, r.w, vms, true)
}

func (r *runners) createAndWaitForVM(opts kotsclient.CreateVMOpts) (*types.VM, error) {
vm, ve, err := r.kotsAPI.CreateVM(opts)
func (r *runners) createAndWaitForVM(opts kotsclient.CreateVMOpts) ([]*types.VM, error) {
vms, ve, err := r.kotsAPI.CreateVM(opts)
if errors.Cause(err) == platformclient.ErrForbidden {
return nil, ErrCompatibilityMatrixTermsNotAccepted
} else if err != nil {
Expand All @@ -127,36 +131,51 @@ func (r *runners) createAndWaitForVM(opts kotsclient.CreateVMOpts) (*types.VM, e
}

if opts.DryRun {
return vm, nil
return vms, nil
}

// if the wait flag was provided, we poll the api until the vm is ready, or a timeout
if r.args.createVMWaitDuration > 0 {
return waitForVM(r.kotsAPI, vm.ID, r.args.createVMWaitDuration)
return waitForVMs(r.kotsAPI, vms, r.args.createVMWaitDuration)
}

return vm, nil
return vms, nil
}

func waitForVM(kotsRestClient *kotsclient.VendorV3Client, id string, duration time.Duration) (*types.VM, error) {
func waitForVMs(kotsRestClient *kotsclient.VendorV3Client, vms []*types.VM, duration time.Duration) ([]*types.VM, error) {
start := time.Now()
runningVMs := map[string]*types.VM{}
for {
vm, err := kotsRestClient.GetVM(id)
if err != nil {
return nil, errors.Wrap(err, "get vm")
}
for _, vm := range vms {
v, err := kotsRestClient.GetVM(vm.ID)
if err != nil {
return nil, errors.Wrap(err, "get vm")
}

if vm.Status == types.VMStatus(types.VMStatusRunning) {
return vm, nil
} else if vm.Status == types.VMStatus(types.VMStatusError) {
return nil, errors.New("vm failed to provision")
} else {
if time.Now().After(start.Add(duration)) {
// In case of timeout, return the vm and a WaitDurationExceeded error
return vm, ErrWaitDurationExceeded
if v.Status == types.VMStatus(types.VMStatusRunning) {
runningVMs[v.ID] = v
if len(runningVMs) == len(vms) {
return mapToSlice(runningVMs), nil
}
} else if vm.Status == types.VMStatus(types.VMStatusError) {
return nil, errors.New("vm failed to provision")
} else {
if time.Now().After(start.Add(duration)) {
// In case of timeout, return the vm and a WaitDurationExceeded error
return mapToSlice(runningVMs), ErrWaitDurationExceeded
}
}
}

time.Sleep(time.Second * 5)
}
}

// Convert map of VMs to slice of VMs
func mapToSlice(vms map[string]*types.VM) []*types.VM {
var slice []*types.VM
for _, v := range vms {
slice = append(slice, v)
}
return slice
}
20 changes: 11 additions & 9 deletions pkg/kotsclient/vm_create.go
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ type CreateVMRequest struct {
}

type CreateVMResponse struct {
VM *types.VM `json:"vm"`
VMs []*types.VM `json:"vms"`
Errors []string `json:"errors"`
SupportedDistributions map[string]string `json:"supported_distributions"`
}
Expand Down Expand Up @@ -60,7 +60,7 @@ type VMValidationError struct {
SupportedDistributions []*types.VMVersion `json:"supported_distributions"`
}

func (c *VendorV3Client) CreateVM(opts CreateVMOpts) (*types.VM, *CreateVMErrorError, error) {
func (c *VendorV3Client) CreateVM(opts CreateVMOpts) ([]*types.VM, *CreateVMErrorError, error) {
req := CreateVMRequest{
Name: opts.Name,
Distribution: opts.Distribution,
Expand All @@ -78,7 +78,7 @@ func (c *VendorV3Client) CreateVM(opts CreateVMOpts) (*types.VM, *CreateVMErrorE
return c.doCreateVMRequest(req)
}

func (c *VendorV3Client) doCreateVMRequest(req CreateVMRequest) (*types.VM, *CreateVMErrorError, error) {
func (c *VendorV3Client) doCreateVMRequest(req CreateVMRequest) ([]*types.VM, *CreateVMErrorError, error) {
resp := CreateVMResponse{}
endpoint := "/v3/vm"
err := c.DoJSON("POST", endpoint, http.StatusCreated, req, &resp)
Expand All @@ -98,10 +98,10 @@ func (c *VendorV3Client) doCreateVMRequest(req CreateVMRequest) (*types.VM, *Cre
return nil, nil, err
}

return resp.VM, nil, nil
return resp.VMs, nil, nil
}

func (c *VendorV3Client) doCreateVMDryRunRequest(req CreateVMRequest) (*types.VM, *CreateVMErrorError, error) {
func (c *VendorV3Client) doCreateVMDryRunRequest(req CreateVMRequest) ([]*types.VM, *CreateVMErrorError, error) {
resp := CreateVMDryRunResponse{}
endpoint := "/v3/vm?dry-run=true"
err := c.DoJSON("POST", endpoint, http.StatusOK, req, &resp)
Expand All @@ -112,10 +112,12 @@ func (c *VendorV3Client) doCreateVMDryRunRequest(req CreateVMRequest) (*types.VM
if resp.Error.Message != "" {
return nil, &resp.Error, nil
}
cl := &types.VM{
EstimatedCost: *resp.TotalCost,
TTL: *resp.TTL,
vms := []*types.VM{
{
EstimatedCost: *resp.TotalCost,
TTL: *resp.TTL,
},
}

return cl, nil, nil
return vms, nil, nil
}

0 comments on commit 2e1728e

Please sign in to comment.