Skip to content

Commit

Permalink
Merge pull request #29 from loft-sh/feat/POD-221-ec2-instance-connect…
Browse files Browse the repository at this point in the history
…-endpoint

feat: add support for ec2 instance connect endpoints
  • Loading branch information
89luca89 authored Apr 17, 2024
2 parents 8c0f655 + 646938e commit 6014878
Show file tree
Hide file tree
Showing 7 changed files with 447 additions and 30 deletions.
79 changes: 79 additions & 0 deletions cmd/command.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,12 +3,17 @@ package cmd
import (
"context"
"fmt"
"net"
"os"
"os/exec"
"strconv"
"time"

"github.com/loft-sh/devpod-provider-aws/pkg/aws"
"github.com/loft-sh/devpod/pkg/log"
"github.com/loft-sh/devpod/pkg/provider"
"github.com/loft-sh/devpod/pkg/ssh"
devssh "github.com/loft-sh/devpod/pkg/ssh"
"github.com/spf13/cobra"
)

Expand Down Expand Up @@ -69,6 +74,53 @@ func (cmd *CommandCmd) Run(
return fmt.Errorf("instance %s doesn't exist", providerAws.Config.MachineID)
}

if providerAws.Config.UseInstanceConnectEndpoint {
instanceID := *instance.Reservations[0].Instances[0].InstanceId
endpointID := providerAws.Config.InstanceConnectEndpointID

var err error
port, err := findAvailablePort()
if err != nil {
return err
}
addr := "localhost:" + port
cancelCtx, cancel := context.WithCancel(ctx)
defer cancel()
connectArgs := []string{
"ec2-instance-connect",
"open-tunnel",
"--instance-id", instanceID,
"--local-port", port,
}
if endpointID != "" {
connectArgs = append(connectArgs, "--instance-connect-endpoint-id", endpointID)
}
cmd := exec.CommandContext(cancelCtx, "aws", connectArgs...)
// open tunnel in background
if err = cmd.Start(); err != nil {
return fmt.Errorf("start tunnel: %w", err)
}
defer func() {
err = cmd.Process.Kill()
}()

timeoutCtx, cancelFn := context.WithTimeout(ctx, 30*time.Second)
defer cancelFn()
waitForPort(timeoutCtx, addr)

client, err := devssh.NewSSHClient("devpod", addr, privateKey)
if err != nil {
return err
}

err = devssh.Run(ctx, client, command, os.Stdin, os.Stdout, os.Stderr)
if err != nil {
return err
}

return err
}

// try public ip
if instance.Reservations[0].Instances[0].PublicIpAddress != nil {
ip := *instance.Reservations[0].Instances[0].PublicIpAddress
Expand Down Expand Up @@ -104,3 +156,30 @@ func (cmd *CommandCmd) Run(
providerAws.Config.MachineID,
)
}

func waitForPort(ctx context.Context, addr string) {
for {
select {
case <-ctx.Done():
return
default:
l, err := net.Listen("tcp", addr)
if err != nil {
// port is taken
return
}
_ = l.Close()
time.Sleep(1 * time.Second)
}
}

}
func findAvailablePort() (string, error) {
l, err := net.Listen("tcp", ":0")
if err != nil {
return "", err
}
defer l.Close()

return strconv.Itoa(l.Addr().(*net.TCPAddr).Port), nil
}
12 changes: 11 additions & 1 deletion hack/build.sh
100644 → 100755
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,16 @@ fi
GO_BUILD_CMD="go build"
GO_BUILD_LDFLAGS="-s -w"

BUILD_VERSION="prod"
for arg in "$@"; do
if [ "$arg" == "--dev" ]; then
BUILD_VERSION="dev"
break
fi
done

echo "Building version: ${BUILD_VERSION}"

if [[ -z "${PROVIDER_BUILD_PLATFORMS}" ]]; then
PROVIDER_BUILD_PLATFORMS="linux windows darwin"
fi
Expand Down Expand Up @@ -60,4 +70,4 @@ for OS in ${PROVIDER_BUILD_PLATFORMS[@]}; do
done

# generate provider.yaml
go run -mod vendor "${PROVIDER_ROOT}/hack/provider/main.go" ${RELEASE_VERSION} > "${PROVIDER_ROOT}/release/provider.yaml"
go run -mod vendor "${PROVIDER_ROOT}/hack/provider/main.go" ${RELEASE_VERSION} ${BUILD_VERSION} ${PROVIDER_ROOT} > "${PROVIDER_ROOT}/release/provider.yaml"
23 changes: 19 additions & 4 deletions hack/provider/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,19 +18,27 @@ var checksumMap = map[string]string{
}

func main() {
if len(os.Args) != 2 {
if len(os.Args) != 4 {
fmt.Fprintln(os.Stderr, "Expected version as argument")
os.Exit(1)

return
}

content, err := os.ReadFile("./hack/provider/provider.yaml")
releaseVersion := os.Args[1]
buildVersion := os.Args[2]
projectRoot := os.Args[3]

content, err := os.ReadFile(providerConfigPath(buildVersion))
if err != nil {
panic(err)
}

replaced := strings.Replace(string(content), "##VERSION##", os.Args[1], -1)
replaced := strings.Replace(string(content), "##VERSION##", releaseVersion, -1)

if buildVersion == "dev" {
replaced = strings.Replace(replaced, "##PROJECT_ROOT##", projectRoot, -1)
}

for k, v := range checksumMap {
checksum, err := File(k)
Expand All @@ -53,11 +61,18 @@ func File(filePath string) (string, error) {
defer file.Close()

hash := sha256.New()

_, err = io.Copy(hash, file)
if err != nil {
return "", err
}

return strings.ToLower(hex.EncodeToString(hash.Sum(nil))), nil
}

func providerConfigPath(buildVersion string) string {
if buildVersion == "prod" {
return "./hack/provider/provider.yaml"
} else {
return "./hack/provider/provider-dev.yaml"
}
}
Loading

0 comments on commit 6014878

Please sign in to comment.