Skip to content

Commit 2834194

Browse files
authored
server: replace use of setec.Watcher with setec.Updater (#31)
Update to the latest version of setec, which removes the Watcher type from the public API. Instead, use the Updater to manage connections, which turns out to simplify the code a bit along the way.
1 parent 232c461 commit 2834194

File tree

6 files changed

+113
-64
lines changed

6 files changed

+113
-64
lines changed

go.mod

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@ require (
99
github.com/google/go-cmp v0.6.0
1010
github.com/klauspost/compress v1.17.8
1111
github.com/tailscale/hujson v0.0.0-20221223112325-20486734a56a
12-
github.com/tailscale/setec v0.0.0-20240729215356-5eb656b60dfe
12+
github.com/tailscale/setec v0.0.0-20240924182055-66c76d47f816
1313
github.com/tailscale/squibble v0.0.0-20240909231413-32a80b9743f7
1414
honnef.co/go/tools v0.5.1
1515
modernc.org/sqlite v1.29.10

go.sum

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -199,8 +199,8 @@ github.com/tailscale/netlink v1.1.1-0.20240822203006-4d49adab4de7 h1:uFsXVBE9Qr4
199199
github.com/tailscale/netlink v1.1.1-0.20240822203006-4d49adab4de7/go.mod h1:NzVQi3Mleb+qzq8VmcWpSkcSYxXIg0DkI6XDzpVkhJ0=
200200
github.com/tailscale/peercred v0.0.0-20240214030740-b535050b2aa4 h1:Gz0rz40FvFVLTBk/K8UNAenb36EbDSnh+q7Z9ldcC8w=
201201
github.com/tailscale/peercred v0.0.0-20240214030740-b535050b2aa4/go.mod h1:phI29ccmHQBc+wvroosENp1IF9195449VDnFDhJ4rJU=
202-
github.com/tailscale/setec v0.0.0-20240729215356-5eb656b60dfe h1:uKpae9D8yEuqUuEqys45NYo3xFcEsBrJBX7JWilAwGc=
203-
github.com/tailscale/setec v0.0.0-20240729215356-5eb656b60dfe/go.mod h1:6xMcr3yo4pQchoVF7O+Az9A2D6M+9SD1Y8an+uy1ZoA=
202+
github.com/tailscale/setec v0.0.0-20240924182055-66c76d47f816 h1:rIRp7ytaQ1sjHlBUFocC1MsFnHJD43fnGg1Rwgql0F8=
203+
github.com/tailscale/setec v0.0.0-20240924182055-66c76d47f816/go.mod h1:nexjfRM8veJVJ5PTbqYI2YrUj/jbk3deffEHO3DH9Q4=
204204
github.com/tailscale/squibble v0.0.0-20240909231413-32a80b9743f7 h1:nfklwaP8uNz2IbUygSKOQ1aDzzRRRLaIbPpnQWUUMGc=
205205
github.com/tailscale/squibble v0.0.0-20240909231413-32a80b9743f7/go.mod h1:YH/J7n7jNZOq10nTxxPANv2ha/Eg47/6J5b7NnOYAhQ=
206206
github.com/tailscale/web-client-prebuilt v0.0.0-20240226180453-5db17b287bf1 h1:tdUdyPqJ0C97SJfjB9tW6EylTtreyee9C44de+UBG0g=

server/tailsql/internal_test.go

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
package tailsql
55

66
import (
7+
"context"
78
"database/sql"
89
"os"
910
"testing"
@@ -90,13 +91,14 @@ func TestOptions(t *testing.T) {
9091

9192
// Test that we can populate options from the config.
9293
t.Run("Options", func(t *testing.T) {
93-
dbs, err := opts.openSources(nil)
94+
dbs, err := opts.openSources(context.Background(), nil)
9495
if err != nil {
9596
t.Fatalf("Options: unexpected error: %v", err)
9697
}
9798

9899
// The handles should be equinumerous and in the same order as the config.
99-
for i, h := range dbs {
100+
for i, u := range dbs {
101+
h := u.Get()
100102
if got, want := h.Source(), opts.Sources[i].Source; got != want {
101103
t.Errorf("Database %d: got src %q, want %q", i+1, got, want)
102104
}

server/tailsql/options.go

Lines changed: 34 additions & 42 deletions
Original file line numberDiff line numberDiff line change
@@ -112,33 +112,58 @@ func (o Options) checkQuery() func(Query) (Query, error) {
112112
// openSources opens database handles to each of the sources defined by o.
113113
// Sources that require secrets will get them from store.
114114
// Precondition: All the sources of o have already been validated.
115-
func (o Options) openSources(store *setec.Store) ([]*dbHandle, error) {
115+
func (o Options) openSources(ctx context.Context, store *setec.Store) ([]*setec.Updater[*dbHandle], error) {
116116
if len(o.Sources) == 0 {
117117
return nil, nil
118118
}
119119

120-
srcs := make([]*dbHandle, len(o.Sources))
120+
srcs := make([]*setec.Updater[*dbHandle], len(o.Sources))
121121
for i, spec := range o.Sources {
122122
if spec.Label == "" {
123123
spec.Label = "(unidentified database)"
124124
}
125125

126126
// Case 1: A programmatic source.
127127
if spec.DB != nil {
128-
srcs[i] = &dbHandle{
128+
srcs[i] = setec.StaticUpdater(&dbHandle{
129129
src: spec.Source,
130130
label: spec.Label,
131131
named: spec.Named,
132132
db: spec.DB,
133+
})
134+
continue
135+
}
136+
137+
// Case 2: A database managed by database/sql, with a secret from setec.
138+
if spec.Secret != "" {
139+
// We actually only maintain a single value, that is updated in-place.
140+
h := &dbHandle{src: spec.Source, label: spec.Label, named: spec.Named}
141+
u, err := setec.NewUpdater(ctx, store, spec.Secret, func(secret []byte) (*dbHandle, error) {
142+
db, err := openAndPing(spec.Driver, string(secret))
143+
if err != nil {
144+
return nil, err
145+
}
146+
o.logf()("[tailsql] opened new connection for source %q", spec.Source)
147+
h.mu.Lock()
148+
defer h.mu.Unlock()
149+
if h.db != nil {
150+
h.db.Close() // close the active handle
151+
}
152+
if up := h.checkUpdate(); up != nil {
153+
up.newDB.Close() // close a previous pending update
154+
}
155+
h.db = sqlDB{DB: db}
156+
return h, nil
157+
})
158+
if err != nil {
159+
return nil, err
133160
}
161+
srcs[i] = u
134162
continue
135163
}
136164

137-
// Case 2: A database managed by database/sql.
138-
//
139-
// Resolve the connection string.
165+
// Case 3: A database managed by database/sql, with a fixed URL.
140166
var connString string
141-
var w setec.Watcher
142167
switch {
143168
case spec.URL != "":
144169
connString = spec.URL
@@ -148,9 +173,6 @@ func (o Options) openSources(store *setec.Store) ([]*dbHandle, error) {
148173
return nil, fmt.Errorf("read key file for %q: %w", spec.Source, err)
149174
}
150175
connString = strings.TrimSpace(string(data))
151-
case spec.Secret != "":
152-
w = store.Watcher(spec.Secret)
153-
connString = string(w.Get())
154176
default:
155177
panic("unexpected: no connection source is defined after validation")
156178
}
@@ -160,16 +182,13 @@ func (o Options) openSources(store *setec.Store) ([]*dbHandle, error) {
160182
if err != nil {
161183
return nil, err
162184
}
163-
srcs[i] = &dbHandle{
185+
srcs[i] = setec.StaticUpdater(&dbHandle{
164186
src: spec.Source,
165187
driver: spec.Driver,
166188
label: spec.Label,
167189
named: spec.Named,
168190
db: sqlDB{DB: db},
169-
}
170-
if spec.Secret != "" {
171-
go srcs[i].handleUpdates(spec.Secret, w, o.logf())
172-
}
191+
})
173192
}
174193
return srcs, nil
175194
}
@@ -325,33 +344,6 @@ type dbHandle struct {
325344
named map[string]string
326345
}
327346

328-
// handleUpdates polls w indefinitely for updates to the connection string for
329-
// h, and reopens the database with the new string when a new value arrives.
330-
// This method should be called in a goroutine.
331-
func (h *dbHandle) handleUpdates(name string, w setec.Watcher, logf logger.Logf) {
332-
logf("[tailsql] starting updater for secret %q", name)
333-
for range w.Ready() {
334-
// N.B. Don't log the secret value itself. It's fine to log the name of
335-
// the secret and the source, those are already in the config.
336-
connString := string(w.Get())
337-
db, err := openAndPing(h.driver, connString)
338-
if err != nil {
339-
logf("WARNING: opening new database for %q: %v", h.src, err)
340-
continue
341-
}
342-
logf("[tailsql] opened new connection for source %q", h.src)
343-
h.mu.Lock()
344-
// Close the existing active handle.
345-
h.db.Close()
346-
// If there's a pending update, close it too.
347-
if up := h.checkUpdate(); up != nil {
348-
up.newDB.Close()
349-
}
350-
h.db = sqlDB{DB: db}
351-
h.mu.Unlock()
352-
}
353-
}
354-
355347
// checkUpdate returns nil if there is no pending update, otherwise it swaps
356348
// out the pending database update and returns it.
357349
func (h *dbHandle) checkUpdate() *dbUpdate {

server/tailsql/tailsql.go

Lines changed: 15 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -68,6 +68,7 @@ import (
6868
"time"
6969
"unicode/utf8"
7070

71+
"github.com/tailscale/setec/client/setec"
7172
"tailscale.com/client/tailscale/apitype"
7273
"tailscale.com/types/logger"
7374
"tailscale.com/util/httpm"
@@ -119,7 +120,7 @@ type Server struct {
119120
logf logger.Logf
120121

121122
mu sync.Mutex
122-
dbs []*dbHandle
123+
dbs []*setec.Updater[*dbHandle]
123124
}
124125

125126
// NewServer constructs a new server with the given Options.
@@ -134,7 +135,7 @@ func NewServer(opts Options) (*Server, error) {
134135
return nil, fmt.Errorf("have %d named secrets but no secret store", len(sec))
135136
}
136137

137-
dbs, err := opts.openSources(opts.SecretStore)
138+
dbs, err := opts.openSources(context.Background(), opts.SecretStore)
138139
if err != nil {
139140
return nil, fmt.Errorf("opening sources: %w", err)
140141
}
@@ -143,14 +144,14 @@ func NewServer(opts Options) (*Server, error) {
143144
return nil, fmt.Errorf("local state: %w", err)
144145
}
145146
if state != nil && opts.LocalSource != "" {
146-
dbs = append(dbs, &dbHandle{
147+
dbs = append(dbs, setec.StaticUpdater(&dbHandle{
147148
src: opts.LocalSource,
148149
label: "tailsql local state",
149150
db: state,
150151
named: map[string]string{
151152
"schema": `select * from sqlite_schema`,
152153
},
153-
})
154+
}))
154155
}
155156

156157
if opts.Metrics != nil {
@@ -192,18 +193,18 @@ func (s *Server) SetSource(source string, db Queryable, opts *DBOptions) bool {
192193
s.mu.Lock()
193194
defer s.mu.Unlock()
194195

195-
for _, src := range s.dbs {
196-
if src.Source() == source {
196+
for _, u := range s.dbs {
197+
if src := u.Get(); src.Source() == source {
197198
src.swap(db, opts)
198199
return true
199200
}
200201
}
201-
s.dbs = append(s.dbs, &dbHandle{
202+
s.dbs = append(s.dbs, setec.StaticUpdater(&dbHandle{
202203
db: db,
203204
src: source,
204205
label: opts.label(),
205206
named: opts.namedQueries(),
206-
})
207+
}))
207208
return false
208209
}
209210

@@ -613,12 +614,15 @@ func (s *Server) getHandles() []*dbHandle {
613614
s.mu.Lock()
614615
defer s.mu.Unlock()
615616

617+
out := make([]*dbHandle, len(s.dbs))
618+
616619
// Check for pending updates.
617-
for _, h := range s.dbs {
618-
h.tryUpdate()
620+
for i, u := range s.dbs {
621+
out[i] = u.Get()
622+
out[i].tryUpdate()
619623
}
620624

621625
// It is safe to return the slice because we never remove any elements, new
622626
// data are only ever appended to the end.
623-
return s.dbs
627+
return out
624628
}

server/tailsql/tailsql_test.go

Lines changed: 57 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -6,11 +6,13 @@ package tailsql_test
66
import (
77
"context"
88
"database/sql"
9+
"database/sql/driver"
910
"errors"
1011
"fmt"
1112
"html"
1213
"html/template"
1314
"io"
15+
"math/rand/v2"
1416
"net/http"
1517
"net/http/httptest"
1618
"net/url"
@@ -128,10 +130,17 @@ var testUIRules = []tailsql.UIRewriteRule{
128130
}
129131

130132
func TestSecrets(t *testing.T) {
133+
// Register a fake driver so we can probe for connection URLs.
134+
// We have to use a new name each time, because there is no way to
135+
// unregister and duplicate names trigger a panic.
136+
driver := new(fakeDriver)
137+
driverName := fmt.Sprintf("%s-driver-%d", t.Name(), rand.Int())
138+
sql.Register(driverName, driver)
139+
t.Logf("Test driver name is %q", driverName)
140+
131141
const secretName = "connection-string"
132-
url, _ := mustInitSQLite(t)
133142
db := setectest.NewDB(t, nil)
134-
db.MustPut(db.Superuser, secretName, url)
143+
db.MustPut(db.Superuser, secretName, "string 1")
135144

136145
ss := setectest.NewServer(t, db, nil)
137146
hs := httptest.NewServer(ss.Mux)
@@ -141,17 +150,23 @@ func TestSecrets(t *testing.T) {
141150
Sources: []tailsql.DBSpec{{
142151
Source: "test",
143152
Label: "Test Database",
144-
Driver: "sqlite",
153+
Driver: driverName,
145154
Secret: secretName,
146155
}},
156+
RoutePrefix: "/tsql",
147157
}
158+
159+
// Verify we found the expected secret names in the options.
148160
secrets, err := opts.CheckSources()
149161
if err != nil {
150162
t.Fatalf("Invalid sources: %v", err)
151163
}
164+
165+
tick := setectest.NewFakeTicker()
152166
st, err := setec.NewStore(context.Background(), setec.StoreConfig{
153-
Client: setec.Client{Server: hs.URL},
154-
Secrets: secrets,
167+
Client: setec.Client{Server: hs.URL},
168+
Secrets: secrets,
169+
PollTicker: tick,
155170
})
156171
if err != nil {
157172
t.Fatalf("Creating setec store: %v", err)
@@ -162,7 +177,28 @@ func TestSecrets(t *testing.T) {
162177
if err != nil {
163178
t.Fatalf("Creating tailsql server: %v", err)
164179
}
165-
ts.Close()
180+
ss.Mux.Handle("/tsql/", ts.NewMux()) // so we can call /meta below
181+
defer ts.Close()
182+
183+
// After opening the server, the database should have the initial secret
184+
// value provided on initialization.
185+
if got, want := driver.OpenedURL, "string 1"; got != want {
186+
t.Errorf("Initial URL: got %q, want %q", got, want)
187+
}
188+
189+
// Update the secret.
190+
db.MustActivate(db.Superuser, secretName, db.MustPut(db.Superuser, secretName, "string 2"))
191+
tick.Poll()
192+
193+
// Make the database fetch the latest value.
194+
if _, err := hs.Client().Get(hs.URL + "/tsql/meta"); err != nil {
195+
t.Errorf("Get tailsql meta: %v", err)
196+
}
197+
198+
// After the update, the database should have the new secret value.
199+
if got, want := driver.OpenedURL, "string 2"; got != want {
200+
t.Errorf("Updated URL: got %q, want %q", got, want)
201+
}
166202
}
167203

168204
func TestServer(t *testing.T) {
@@ -567,3 +603,18 @@ func TestRoutePrefix(t *testing.T) {
567603
}
568604
})
569605
}
606+
607+
type fakeDriver struct {
608+
OpenedURL string
609+
}
610+
611+
func (f *fakeDriver) Open(url string) (driver.Conn, error) {
612+
f.OpenedURL = url
613+
return fakeConn{}, nil
614+
}
615+
616+
// fakeConn is a fake implementation of driver.Conn to satisfy the interface,
617+
// it will panic if actually used.
618+
type fakeConn struct{ driver.Conn }
619+
620+
func (fakeConn) Close() error { return nil }

0 commit comments

Comments
 (0)