Skip to content

Commit d5978ff

Browse files
authored
Validate subdomain for http tunnels (#138)
* Validate subdomain for http tunnels * Add subdomain validation tests
1 parent b50c75d commit d5978ff

File tree

5 files changed

+123
-1
lines changed

5 files changed

+123
-1
lines changed

tunnel/cmd/portr/start.go

+6
Original file line numberDiff line numberDiff line change
@@ -27,9 +27,15 @@ func startTunnels(c *cli.Context, tunnelFromCli *config.Tunnel) error {
2727

2828
if tunnelFromCli != nil {
2929
tunnelFromCli.SetDefaults()
30+
if err := tunnelFromCli.Validate(); err != nil {
31+
return err
32+
}
3033
_c.ReplaceTunnelsFromCli(*tunnelFromCli)
3134
err = _c.Start(c.Context)
3235
} else {
36+
if err := config.Validate(); err != nil {
37+
return err
38+
}
3339
err = _c.Start(c.Context, c.Args().Slice()...)
3440
}
3541

tunnel/internal/client/config/config.go

+20
Original file line numberDiff line numberDiff line change
@@ -40,6 +40,16 @@ func (t *Tunnel) SetDefaults() {
4040
}
4141
}
4242

43+
func (t *Tunnel) Validate() error {
44+
if t.Type == constants.Http {
45+
if err := utils.ValidateSubdomain(t.Subdomain); err != nil {
46+
return err
47+
}
48+
}
49+
50+
return nil
51+
}
52+
4353
func (t *Tunnel) GetLocalAddr() string {
4454
return t.Host + ":" + fmt.Sprint(t.Port)
4555
}
@@ -84,6 +94,16 @@ func (c *Config) SetDefaults() {
8494
}
8595
}
8696

97+
func (c Config) Validate() error {
98+
for _, tunnel := range c.Tunnels {
99+
if err := tunnel.Validate(); err != nil {
100+
return err
101+
}
102+
}
103+
104+
return nil
105+
}
106+
87107
func (c Config) GetAdminAddress() string {
88108
protocol := "http"
89109
if !c.UseLocalHost {

tunnel/internal/client/ssh/ssh.go

+1-1
Original file line numberDiff line numberDiff line change
@@ -77,7 +77,7 @@ func (s *SshClient) createNewConnection() (string, error) {
7777
if s.config.Debug {
7878
log.Error("Failed to create new connection", "error", reqErr)
7979
}
80-
return "", fmt.Errorf(reqErr.Message)
80+
return "", fmt.Errorf("server error: %s", reqErr.Message)
8181
}
8282
return response.ConnectionId, nil
8383
}

tunnel/internal/utils/subdomain.go

+18
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,18 @@
1+
package utils
2+
3+
import (
4+
"fmt"
5+
"regexp"
6+
)
7+
8+
func ValidateSubdomain(subdomain string) error {
9+
matched, err := regexp.Match(`^[a-zA-Z0-9][-a-zA-Z0-9_]{0,61}[a-zA-Z0-9]$`, []byte(subdomain))
10+
if err != nil {
11+
return fmt.Errorf("error validating subdomain: %v", err)
12+
}
13+
if !matched {
14+
return fmt.Errorf("invalid subdomain '%s'. Must not contain special characters other than '-', `_`", subdomain)
15+
}
16+
17+
return nil
18+
}
+78
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,78 @@
1+
package utils
2+
3+
import (
4+
"testing"
5+
)
6+
7+
func TestValidateSubdomain(t *testing.T) {
8+
tests := []struct {
9+
name string
10+
subdomain string
11+
wantErr bool
12+
}{
13+
{
14+
name: "valid subdomain",
15+
subdomain: "test",
16+
wantErr: false,
17+
},
18+
{
19+
name: "subdomain with dash",
20+
subdomain: "test-subdomain",
21+
wantErr: false,
22+
},
23+
{
24+
name: "subdomain with underscore",
25+
subdomain: "test_subdomain",
26+
wantErr: false,
27+
},
28+
{
29+
name: "subdomain with uppercase letters",
30+
subdomain: "TestSubdomain",
31+
wantErr: false,
32+
},
33+
{
34+
name: "subdomain with leading dash",
35+
subdomain: "-test",
36+
wantErr: true,
37+
},
38+
{
39+
name: "subdomain with trailing dash",
40+
subdomain: "test-",
41+
wantErr: true,
42+
},
43+
{
44+
name: "subdomain with leading underscore",
45+
subdomain: "_test",
46+
wantErr: true,
47+
},
48+
{
49+
name: "subdomain with trailing underscore",
50+
subdomain: "test_",
51+
wantErr: true,
52+
},
53+
{
54+
name: "subdomain with special characters",
55+
subdomain: "test@subdomain",
56+
wantErr: true,
57+
},
58+
{
59+
name: "subdomain with dot",
60+
subdomain: "test.subdomain",
61+
wantErr: true,
62+
},
63+
{
64+
name: "subdomain with multiple dots",
65+
subdomain: "test.subdomain.com",
66+
wantErr: true,
67+
},
68+
}
69+
70+
for _, test := range tests {
71+
t.Run(test.name, func(t *testing.T) {
72+
err := ValidateSubdomain(test.subdomain)
73+
if (err != nil) != test.wantErr {
74+
t.Errorf("ValidateSubdomain(%q) = %v, wantErr %v", test.subdomain, err, test.wantErr)
75+
}
76+
})
77+
}
78+
}

0 commit comments

Comments
 (0)