Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

server: replace use of setec.Watcher with setec.Updater #31

Merged
merged 2 commits into from
Sep 24, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion go.mod
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ require (
github.com/google/go-cmp v0.6.0
github.com/klauspost/compress v1.17.8
github.com/tailscale/hujson v0.0.0-20221223112325-20486734a56a
github.com/tailscale/setec v0.0.0-20240729215356-5eb656b60dfe
github.com/tailscale/setec v0.0.0-20240924182055-66c76d47f816
github.com/tailscale/squibble v0.0.0-20240909231413-32a80b9743f7
honnef.co/go/tools v0.5.1
modernc.org/sqlite v1.29.10
Expand Down
4 changes: 2 additions & 2 deletions go.sum
Original file line number Diff line number Diff line change
Expand Up @@ -199,8 +199,8 @@ github.com/tailscale/netlink v1.1.1-0.20240822203006-4d49adab4de7 h1:uFsXVBE9Qr4
github.com/tailscale/netlink v1.1.1-0.20240822203006-4d49adab4de7/go.mod h1:NzVQi3Mleb+qzq8VmcWpSkcSYxXIg0DkI6XDzpVkhJ0=
github.com/tailscale/peercred v0.0.0-20240214030740-b535050b2aa4 h1:Gz0rz40FvFVLTBk/K8UNAenb36EbDSnh+q7Z9ldcC8w=
github.com/tailscale/peercred v0.0.0-20240214030740-b535050b2aa4/go.mod h1:phI29ccmHQBc+wvroosENp1IF9195449VDnFDhJ4rJU=
github.com/tailscale/setec v0.0.0-20240729215356-5eb656b60dfe h1:uKpae9D8yEuqUuEqys45NYo3xFcEsBrJBX7JWilAwGc=
github.com/tailscale/setec v0.0.0-20240729215356-5eb656b60dfe/go.mod h1:6xMcr3yo4pQchoVF7O+Az9A2D6M+9SD1Y8an+uy1ZoA=
github.com/tailscale/setec v0.0.0-20240924182055-66c76d47f816 h1:rIRp7ytaQ1sjHlBUFocC1MsFnHJD43fnGg1Rwgql0F8=
github.com/tailscale/setec v0.0.0-20240924182055-66c76d47f816/go.mod h1:nexjfRM8veJVJ5PTbqYI2YrUj/jbk3deffEHO3DH9Q4=
github.com/tailscale/squibble v0.0.0-20240909231413-32a80b9743f7 h1:nfklwaP8uNz2IbUygSKOQ1aDzzRRRLaIbPpnQWUUMGc=
github.com/tailscale/squibble v0.0.0-20240909231413-32a80b9743f7/go.mod h1:YH/J7n7jNZOq10nTxxPANv2ha/Eg47/6J5b7NnOYAhQ=
github.com/tailscale/web-client-prebuilt v0.0.0-20240226180453-5db17b287bf1 h1:tdUdyPqJ0C97SJfjB9tW6EylTtreyee9C44de+UBG0g=
Expand Down
6 changes: 4 additions & 2 deletions server/tailsql/internal_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
package tailsql

import (
"context"
"database/sql"
"os"
"testing"
Expand Down Expand Up @@ -90,13 +91,14 @@ func TestOptions(t *testing.T) {

// Test that we can populate options from the config.
t.Run("Options", func(t *testing.T) {
dbs, err := opts.openSources(nil)
dbs, err := opts.openSources(context.Background(), nil)
if err != nil {
t.Fatalf("Options: unexpected error: %v", err)
}

// The handles should be equinumerous and in the same order as the config.
for i, h := range dbs {
for i, u := range dbs {
h := u.Get()
if got, want := h.Source(), opts.Sources[i].Source; got != want {
t.Errorf("Database %d: got src %q, want %q", i+1, got, want)
}
Expand Down
76 changes: 34 additions & 42 deletions server/tailsql/options.go
Original file line number Diff line number Diff line change
Expand Up @@ -112,33 +112,58 @@ func (o Options) checkQuery() func(Query) (Query, error) {
// openSources opens database handles to each of the sources defined by o.
// Sources that require secrets will get them from store.
// Precondition: All the sources of o have already been validated.
func (o Options) openSources(store *setec.Store) ([]*dbHandle, error) {
func (o Options) openSources(ctx context.Context, store *setec.Store) ([]*setec.Updater[*dbHandle], error) {
if len(o.Sources) == 0 {
return nil, nil
}

srcs := make([]*dbHandle, len(o.Sources))
srcs := make([]*setec.Updater[*dbHandle], len(o.Sources))
for i, spec := range o.Sources {
if spec.Label == "" {
spec.Label = "(unidentified database)"
}

// Case 1: A programmatic source.
if spec.DB != nil {
srcs[i] = &dbHandle{
srcs[i] = setec.StaticUpdater(&dbHandle{
src: spec.Source,
label: spec.Label,
named: spec.Named,
db: spec.DB,
})
continue
}

// Case 2: A database managed by database/sql, with a secret from setec.
if spec.Secret != "" {
// We actually only maintain a single value, that is updated in-place.
h := &dbHandle{src: spec.Source, label: spec.Label, named: spec.Named}
u, err := setec.NewUpdater(ctx, store, spec.Secret, func(secret []byte) (*dbHandle, error) {
db, err := openAndPing(spec.Driver, string(secret))
if err != nil {
return nil, err
}
o.logf()("[tailsql] opened new connection for source %q", spec.Source)
h.mu.Lock()
defer h.mu.Unlock()
if h.db != nil {
h.db.Close() // close the active handle
}
if up := h.checkUpdate(); up != nil {
up.newDB.Close() // close a previous pending update
}
h.db = sqlDB{DB: db}
return h, nil
})
if err != nil {
return nil, err
}
srcs[i] = u
continue
}

// Case 2: A database managed by database/sql.
//
// Resolve the connection string.
// Case 3: A database managed by database/sql, with a fixed URL.
var connString string
var w setec.Watcher
switch {
case spec.URL != "":
connString = spec.URL
Expand All @@ -148,9 +173,6 @@ func (o Options) openSources(store *setec.Store) ([]*dbHandle, error) {
return nil, fmt.Errorf("read key file for %q: %w", spec.Source, err)
}
connString = strings.TrimSpace(string(data))
case spec.Secret != "":
w = store.Watcher(spec.Secret)
connString = string(w.Get())
default:
panic("unexpected: no connection source is defined after validation")
}
Expand All @@ -160,16 +182,13 @@ func (o Options) openSources(store *setec.Store) ([]*dbHandle, error) {
if err != nil {
return nil, err
}
srcs[i] = &dbHandle{
srcs[i] = setec.StaticUpdater(&dbHandle{
src: spec.Source,
driver: spec.Driver,
label: spec.Label,
named: spec.Named,
db: sqlDB{DB: db},
}
if spec.Secret != "" {
go srcs[i].handleUpdates(spec.Secret, w, o.logf())
}
})
}
return srcs, nil
}
Expand Down Expand Up @@ -325,33 +344,6 @@ type dbHandle struct {
named map[string]string
}

// handleUpdates polls w indefinitely for updates to the connection string for
// h, and reopens the database with the new string when a new value arrives.
// This method should be called in a goroutine.
func (h *dbHandle) handleUpdates(name string, w setec.Watcher, logf logger.Logf) {
logf("[tailsql] starting updater for secret %q", name)
for range w.Ready() {
// N.B. Don't log the secret value itself. It's fine to log the name of
// the secret and the source, those are already in the config.
connString := string(w.Get())
db, err := openAndPing(h.driver, connString)
if err != nil {
logf("WARNING: opening new database for %q: %v", h.src, err)
continue
}
logf("[tailsql] opened new connection for source %q", h.src)
h.mu.Lock()
// Close the existing active handle.
h.db.Close()
// If there's a pending update, close it too.
if up := h.checkUpdate(); up != nil {
up.newDB.Close()
}
h.db = sqlDB{DB: db}
h.mu.Unlock()
}
}

// checkUpdate returns nil if there is no pending update, otherwise it swaps
// out the pending database update and returns it.
func (h *dbHandle) checkUpdate() *dbUpdate {
Expand Down
26 changes: 15 additions & 11 deletions server/tailsql/tailsql.go
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,7 @@ import (
"time"
"unicode/utf8"

"github.com/tailscale/setec/client/setec"
"tailscale.com/client/tailscale/apitype"
"tailscale.com/types/logger"
"tailscale.com/util/httpm"
Expand Down Expand Up @@ -119,7 +120,7 @@ type Server struct {
logf logger.Logf

mu sync.Mutex
dbs []*dbHandle
dbs []*setec.Updater[*dbHandle]
}

// NewServer constructs a new server with the given Options.
Expand All @@ -134,7 +135,7 @@ func NewServer(opts Options) (*Server, error) {
return nil, fmt.Errorf("have %d named secrets but no secret store", len(sec))
}

dbs, err := opts.openSources(opts.SecretStore)
dbs, err := opts.openSources(context.Background(), opts.SecretStore)
if err != nil {
return nil, fmt.Errorf("opening sources: %w", err)
}
Expand All @@ -143,14 +144,14 @@ func NewServer(opts Options) (*Server, error) {
return nil, fmt.Errorf("local state: %w", err)
}
if state != nil && opts.LocalSource != "" {
dbs = append(dbs, &dbHandle{
dbs = append(dbs, setec.StaticUpdater(&dbHandle{
src: opts.LocalSource,
label: "tailsql local state",
db: state,
named: map[string]string{
"schema": `select * from sqlite_schema`,
},
})
}))
}

if opts.Metrics != nil {
Expand Down Expand Up @@ -192,18 +193,18 @@ func (s *Server) SetSource(source string, db Queryable, opts *DBOptions) bool {
s.mu.Lock()
defer s.mu.Unlock()

for _, src := range s.dbs {
if src.Source() == source {
for _, u := range s.dbs {
if src := u.Get(); src.Source() == source {
src.swap(db, opts)
return true
}
}
s.dbs = append(s.dbs, &dbHandle{
s.dbs = append(s.dbs, setec.StaticUpdater(&dbHandle{
db: db,
src: source,
label: opts.label(),
named: opts.namedQueries(),
})
}))
return false
}

Expand Down Expand Up @@ -613,12 +614,15 @@ func (s *Server) getHandles() []*dbHandle {
s.mu.Lock()
defer s.mu.Unlock()

out := make([]*dbHandle, len(s.dbs))

// Check for pending updates.
for _, h := range s.dbs {
h.tryUpdate()
for i, u := range s.dbs {
out[i] = u.Get()
out[i].tryUpdate()
}

// It is safe to return the slice because we never remove any elements, new
// data are only ever appended to the end.
return s.dbs
return out
}
63 changes: 57 additions & 6 deletions server/tailsql/tailsql_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,11 +6,13 @@ package tailsql_test
import (
"context"
"database/sql"
"database/sql/driver"
"errors"
"fmt"
"html"
"html/template"
"io"
"math/rand/v2"
"net/http"
"net/http/httptest"
"net/url"
Expand Down Expand Up @@ -128,10 +130,17 @@ var testUIRules = []tailsql.UIRewriteRule{
}

func TestSecrets(t *testing.T) {
// Register a fake driver so we can probe for connection URLs.
// We have to use a new name each time, because there is no way to
// unregister and duplicate names trigger a panic.
driver := new(fakeDriver)
driverName := fmt.Sprintf("%s-driver-%d", t.Name(), rand.Int())
sql.Register(driverName, driver)
t.Logf("Test driver name is %q", driverName)

const secretName = "connection-string"
url, _ := mustInitSQLite(t)
db := setectest.NewDB(t, nil)
db.MustPut(db.Superuser, secretName, url)
db.MustPut(db.Superuser, secretName, "string 1")

ss := setectest.NewServer(t, db, nil)
hs := httptest.NewServer(ss.Mux)
Expand All @@ -141,17 +150,23 @@ func TestSecrets(t *testing.T) {
Sources: []tailsql.DBSpec{{
Source: "test",
Label: "Test Database",
Driver: "sqlite",
Driver: driverName,
Secret: secretName,
}},
RoutePrefix: "/tsql",
}

// Verify we found the expected secret names in the options.
secrets, err := opts.CheckSources()
if err != nil {
t.Fatalf("Invalid sources: %v", err)
}

tick := setectest.NewFakeTicker()
st, err := setec.NewStore(context.Background(), setec.StoreConfig{
Client: setec.Client{Server: hs.URL},
Secrets: secrets,
Client: setec.Client{Server: hs.URL},
Secrets: secrets,
PollTicker: tick,
})
if err != nil {
t.Fatalf("Creating setec store: %v", err)
Expand All @@ -162,7 +177,28 @@ func TestSecrets(t *testing.T) {
if err != nil {
t.Fatalf("Creating tailsql server: %v", err)
}
ts.Close()
ss.Mux.Handle("/tsql/", ts.NewMux()) // so we can call /meta below
defer ts.Close()

// After opening the server, the database should have the initial secret
// value provided on initialization.
if got, want := driver.OpenedURL, "string 1"; got != want {
t.Errorf("Initial URL: got %q, want %q", got, want)
}

// Update the secret.
db.MustActivate(db.Superuser, secretName, db.MustPut(db.Superuser, secretName, "string 2"))
tick.Poll()

// Make the database fetch the latest value.
if _, err := hs.Client().Get(hs.URL + "/tsql/meta"); err != nil {
t.Errorf("Get tailsql meta: %v", err)
}

// After the update, the database should have the new secret value.
if got, want := driver.OpenedURL, "string 2"; got != want {
t.Errorf("Updated URL: got %q, want %q", got, want)
}
}

func TestServer(t *testing.T) {
Expand Down Expand Up @@ -567,3 +603,18 @@ func TestRoutePrefix(t *testing.T) {
}
})
}

type fakeDriver struct {
OpenedURL string
}

func (f *fakeDriver) Open(url string) (driver.Conn, error) {
f.OpenedURL = url
return fakeConn{}, nil
}

// fakeConn is a fake implementation of driver.Conn to satisfy the interface,
// it will panic if actually used.
type fakeConn struct{ driver.Conn }

func (fakeConn) Close() error { return nil }