Skip to content

Commit 6014878

Browse files
authored
Merge pull request #29 from loft-sh/feat/POD-221-ec2-instance-connect-endpoint
feat: add support for ec2 instance connect endpoints
2 parents 8c0f655 + 646938e commit 6014878

File tree

7 files changed

+447
-30
lines changed

7 files changed

+447
-30
lines changed

cmd/command.go

Lines changed: 79 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,12 +3,17 @@ package cmd
33
import (
44
"context"
55
"fmt"
6+
"net"
67
"os"
8+
"os/exec"
9+
"strconv"
10+
"time"
711

812
"github.com/loft-sh/devpod-provider-aws/pkg/aws"
913
"github.com/loft-sh/devpod/pkg/log"
1014
"github.com/loft-sh/devpod/pkg/provider"
1115
"github.com/loft-sh/devpod/pkg/ssh"
16+
devssh "github.com/loft-sh/devpod/pkg/ssh"
1217
"github.com/spf13/cobra"
1318
)
1419

@@ -69,6 +74,53 @@ func (cmd *CommandCmd) Run(
6974
return fmt.Errorf("instance %s doesn't exist", providerAws.Config.MachineID)
7075
}
7176

77+
if providerAws.Config.UseInstanceConnectEndpoint {
78+
instanceID := *instance.Reservations[0].Instances[0].InstanceId
79+
endpointID := providerAws.Config.InstanceConnectEndpointID
80+
81+
var err error
82+
port, err := findAvailablePort()
83+
if err != nil {
84+
return err
85+
}
86+
addr := "localhost:" + port
87+
cancelCtx, cancel := context.WithCancel(ctx)
88+
defer cancel()
89+
connectArgs := []string{
90+
"ec2-instance-connect",
91+
"open-tunnel",
92+
"--instance-id", instanceID,
93+
"--local-port", port,
94+
}
95+
if endpointID != "" {
96+
connectArgs = append(connectArgs, "--instance-connect-endpoint-id", endpointID)
97+
}
98+
cmd := exec.CommandContext(cancelCtx, "aws", connectArgs...)
99+
// open tunnel in background
100+
if err = cmd.Start(); err != nil {
101+
return fmt.Errorf("start tunnel: %w", err)
102+
}
103+
defer func() {
104+
err = cmd.Process.Kill()
105+
}()
106+
107+
timeoutCtx, cancelFn := context.WithTimeout(ctx, 30*time.Second)
108+
defer cancelFn()
109+
waitForPort(timeoutCtx, addr)
110+
111+
client, err := devssh.NewSSHClient("devpod", addr, privateKey)
112+
if err != nil {
113+
return err
114+
}
115+
116+
err = devssh.Run(ctx, client, command, os.Stdin, os.Stdout, os.Stderr)
117+
if err != nil {
118+
return err
119+
}
120+
121+
return err
122+
}
123+
72124
// try public ip
73125
if instance.Reservations[0].Instances[0].PublicIpAddress != nil {
74126
ip := *instance.Reservations[0].Instances[0].PublicIpAddress
@@ -104,3 +156,30 @@ func (cmd *CommandCmd) Run(
104156
providerAws.Config.MachineID,
105157
)
106158
}
159+
160+
func waitForPort(ctx context.Context, addr string) {
161+
for {
162+
select {
163+
case <-ctx.Done():
164+
return
165+
default:
166+
l, err := net.Listen("tcp", addr)
167+
if err != nil {
168+
// port is taken
169+
return
170+
}
171+
_ = l.Close()
172+
time.Sleep(1 * time.Second)
173+
}
174+
}
175+
176+
}
177+
func findAvailablePort() (string, error) {
178+
l, err := net.Listen("tcp", ":0")
179+
if err != nil {
180+
return "", err
181+
}
182+
defer l.Close()
183+
184+
return strconv.Itoa(l.Addr().(*net.TCPAddr).Port), nil
185+
}

hack/build.sh

100644100755
Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,16 @@ fi
2222
GO_BUILD_CMD="go build"
2323
GO_BUILD_LDFLAGS="-s -w"
2424

25+
BUILD_VERSION="prod"
26+
for arg in "$@"; do
27+
if [ "$arg" == "--dev" ]; then
28+
BUILD_VERSION="dev"
29+
break
30+
fi
31+
done
32+
33+
echo "Building version: ${BUILD_VERSION}"
34+
2535
if [[ -z "${PROVIDER_BUILD_PLATFORMS}" ]]; then
2636
PROVIDER_BUILD_PLATFORMS="linux windows darwin"
2737
fi
@@ -60,4 +70,4 @@ for OS in ${PROVIDER_BUILD_PLATFORMS[@]}; do
6070
done
6171

6272
# generate provider.yaml
63-
go run -mod vendor "${PROVIDER_ROOT}/hack/provider/main.go" ${RELEASE_VERSION} > "${PROVIDER_ROOT}/release/provider.yaml"
73+
go run -mod vendor "${PROVIDER_ROOT}/hack/provider/main.go" ${RELEASE_VERSION} ${BUILD_VERSION} ${PROVIDER_ROOT} > "${PROVIDER_ROOT}/release/provider.yaml"

hack/provider/main.go

Lines changed: 19 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -18,19 +18,27 @@ var checksumMap = map[string]string{
1818
}
1919

2020
func main() {
21-
if len(os.Args) != 2 {
21+
if len(os.Args) != 4 {
2222
fmt.Fprintln(os.Stderr, "Expected version as argument")
2323
os.Exit(1)
2424

2525
return
2626
}
2727

28-
content, err := os.ReadFile("./hack/provider/provider.yaml")
28+
releaseVersion := os.Args[1]
29+
buildVersion := os.Args[2]
30+
projectRoot := os.Args[3]
31+
32+
content, err := os.ReadFile(providerConfigPath(buildVersion))
2933
if err != nil {
3034
panic(err)
3135
}
3236

33-
replaced := strings.Replace(string(content), "##VERSION##", os.Args[1], -1)
37+
replaced := strings.Replace(string(content), "##VERSION##", releaseVersion, -1)
38+
39+
if buildVersion == "dev" {
40+
replaced = strings.Replace(replaced, "##PROJECT_ROOT##", projectRoot, -1)
41+
}
3442

3543
for k, v := range checksumMap {
3644
checksum, err := File(k)
@@ -53,11 +61,18 @@ func File(filePath string) (string, error) {
5361
defer file.Close()
5462

5563
hash := sha256.New()
56-
5764
_, err = io.Copy(hash, file)
5865
if err != nil {
5966
return "", err
6067
}
6168

6269
return strings.ToLower(hex.EncodeToString(hash.Sum(nil))), nil
6370
}
71+
72+
func providerConfigPath(buildVersion string) string {
73+
if buildVersion == "prod" {
74+
return "./hack/provider/provider.yaml"
75+
} else {
76+
return "./hack/provider/provider-dev.yaml"
77+
}
78+
}

0 commit comments

Comments
 (0)