feat(cache): add Cloudflare SigV4 S3 signature compatibility fix and compile locally
Some checks failed
Go CI with S3 Caching / build-and-test (push) Failing after 4s
Some checks failed
Go CI with S3 Caching / build-and-test (push) Failing after 4s
This commit is contained in:
267
go-cache-plugin-src/lib/gobuild/gobuild.go
Normal file
267
go-cache-plugin-src/lib/gobuild/gobuild.go
Normal file
@@ -0,0 +1,267 @@
|
||||
// Copyright (c) Tailscale Inc & AUTHORS
|
||||
// SPDX-License-Identifier: BSD-3-Clause
|
||||
|
||||
// Package gobuild implements callbacks for a gocache.Server that store data
|
||||
// into an S3 bucket through a local directory.
|
||||
package gobuild
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"expvar"
|
||||
"fmt"
|
||||
"io/fs"
|
||||
"os"
|
||||
"path"
|
||||
"runtime"
|
||||
"strconv"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/creachadair/gocache"
|
||||
"github.com/creachadair/gocache/cachedir"
|
||||
"github.com/creachadair/taskgroup"
|
||||
"github.com/tailscale/go-cache-plugin/lib/s3util"
|
||||
)
|
||||
|
||||
// S3Cache implements callbacks for a gocache.Server using an S3 bucket for
|
||||
// backing store with a local directory for staging.
|
||||
//
|
||||
// # Remote Cache Layout
|
||||
//
|
||||
// Within the designated S3 bucket, keys are organized into two groups. Each
|
||||
// action is stored in a file named:
|
||||
//
|
||||
// [<prefix>/]action/<xx>/<action-id>
|
||||
//
|
||||
// Each output object is stored in a file named:
|
||||
//
|
||||
// [<prefix>/]output/<xx>/<object-id>
|
||||
//
|
||||
// The object and action IDs are encoded as lower-case hexadecimal strings,
|
||||
// with "<xx>" denoting the first two bytes of the ID to partition the space.
|
||||
//
|
||||
// The contents of each action file have the format:
|
||||
//
|
||||
// <output-id> <timestamp>
|
||||
//
|
||||
// where the object ID is hex encoded and the timestamp is Unix nanoseconds.
|
||||
// The object file contains just the binary data of the object.
|
||||
type S3Cache struct {
|
||||
// Local is the local cache directory where actions and objects are staged.
|
||||
// It must be non-nil. A local stage is required because the Go toolchain
|
||||
// needs direct access to read the files reported by the cache.
|
||||
// It is safe to use a tmpfs directory.
|
||||
Local *cachedir.Dir
|
||||
|
||||
// S3Client is the S3 client used to read and write cache entries to the
|
||||
// backing store. It must be non-nil.
|
||||
S3Client *s3util.Client
|
||||
|
||||
// KeyPrefix, if non-empty, is prepended to each key stored into S3, with an
|
||||
// intervening slash.
|
||||
KeyPrefix string
|
||||
|
||||
// MinUploadSize, if positive, defines a minimum object size in bytes below
|
||||
// which the cache will not write the object to S3.
|
||||
MinUploadSize int64
|
||||
|
||||
// UploadConcurrency, if positive, defines the maximum number of concurrent
|
||||
// tasks for writing cache entries to S3. If zero or negative, it uses
|
||||
// runtime.NumCPU.
|
||||
UploadConcurrency int
|
||||
|
||||
// Tracks tasks pushing cache writes to S3.
|
||||
initOnce sync.Once
|
||||
push *taskgroup.Group
|
||||
start func(taskgroup.Task)
|
||||
|
||||
getLocalHit expvar.Int // count of Get hits in the local cache
|
||||
getFaultHit expvar.Int // count of Get hits faulted in from S3
|
||||
getFaultMiss expvar.Int // count of Get faults that were misses
|
||||
putSkipSmall expvar.Int // count of "small" objects not written to S3
|
||||
putS3Found expvar.Int // count of objects not written to S3 because they were already present
|
||||
putS3Action expvar.Int // count of actions written to S3
|
||||
putS3Object expvar.Int // count of objects written to S3
|
||||
putS3Error expvar.Int // count of errors writing to S3
|
||||
}
|
||||
|
||||
func (s *S3Cache) init() {
|
||||
s.initOnce.Do(func() {
|
||||
s.push, s.start = taskgroup.New(nil).Limit(s.uploadConcurrency())
|
||||
})
|
||||
}
|
||||
|
||||
// Get implements the corresponding callback of the cache protocol.
|
||||
func (s *S3Cache) Get(ctx context.Context, actionID string) (outputID, diskPath string, _ error) {
|
||||
s.init()
|
||||
|
||||
objID, diskPath, err := s.Local.Get(ctx, actionID)
|
||||
if err == nil && objID != "" && diskPath != "" {
|
||||
s.getLocalHit.Add(1)
|
||||
return objID, diskPath, nil // cache hit, OK
|
||||
}
|
||||
|
||||
// Reaching here, either we got a cache miss or an error reading from local.
|
||||
// Try reading the action from S3.
|
||||
action, err := s.S3Client.GetData(ctx, s.actionKey(actionID))
|
||||
if err != nil {
|
||||
if errors.Is(err, fs.ErrNotExist) {
|
||||
s.getFaultMiss.Add(1)
|
||||
return "", "", nil // cache miss, OK
|
||||
}
|
||||
return "", "", fmt.Errorf("[s3] read action %s: %w", actionID, err)
|
||||
}
|
||||
|
||||
// We got an action hit remotely, try to update the local copy.
|
||||
outputID, mtime, err := parseAction(action)
|
||||
if err != nil {
|
||||
return "", "", err
|
||||
}
|
||||
|
||||
object, size, err := s.S3Client.Get(ctx, s.outputKey(outputID))
|
||||
if err != nil {
|
||||
// At this point we know the action exists, so if we can't read the
|
||||
// object report it as an error rather than a cache miss.
|
||||
return "", "", fmt.Errorf("[s3] read object %s: %w", outputID, err)
|
||||
}
|
||||
defer object.Close()
|
||||
s.getFaultHit.Add(1)
|
||||
|
||||
// Now we should have the body; poke it into the local cache. Preserve the
|
||||
// modification timestamp recorded with the original action.
|
||||
diskPath, err = s.Local.Put(ctx, gocache.Object{
|
||||
ActionID: actionID,
|
||||
OutputID: outputID,
|
||||
Size: size,
|
||||
Body: object,
|
||||
ModTime: mtime,
|
||||
})
|
||||
return outputID, diskPath, err
|
||||
}
|
||||
|
||||
// Put implements the corresponding callback of the cache protocol.
|
||||
func (s *S3Cache) Put(ctx context.Context, obj gocache.Object) (diskPath string, _ error) {
|
||||
s.init()
|
||||
|
||||
// Compute an etag so we can do a conditional put on the object data.
|
||||
// We do not rely on it as a secure checksum. The toolchain verifies the
|
||||
// content address against the bits we actually store.
|
||||
etr := s3util.NewETagReader(obj.Body)
|
||||
obj.Body = etr
|
||||
|
||||
diskPath, err := s.Local.Put(ctx, obj)
|
||||
if err != nil {
|
||||
return "", err // don't bother trying to forward it to the remote
|
||||
}
|
||||
if obj.Size < s.MinUploadSize {
|
||||
s.putSkipSmall.Add(1)
|
||||
return diskPath, nil // don't bother uploading this, it's too small
|
||||
}
|
||||
|
||||
// Try to push the record to S3 in the background.
|
||||
s.start(func() error {
|
||||
// Override the context with a separate timeout in case S3 is farkakte.
|
||||
sctx, cancel := context.WithTimeout(context.WithoutCancel(ctx), 1*time.Minute)
|
||||
defer cancel()
|
||||
|
||||
// Stage 1: Maybe write the object. Do this before writing the action
|
||||
// record so we are less likely to get a spurious miss later.
|
||||
mtime, err := s.maybePutObject(sctx, obj.OutputID, diskPath, etr.ETag())
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Stage 2: Write the action record.
|
||||
if err := s.S3Client.Put(ctx, s.actionKey(obj.ActionID),
|
||||
strings.NewReader(fmt.Sprintf("%s %d", obj.OutputID, mtime.UnixNano()))); err != nil {
|
||||
gocache.Logf(ctx, "write action %s: %v", obj.ActionID, err)
|
||||
return err
|
||||
}
|
||||
s.putS3Action.Add(1)
|
||||
return nil
|
||||
})
|
||||
|
||||
return diskPath, nil
|
||||
}
|
||||
|
||||
// Close implements the corresponding callback of the cache protocol.
|
||||
func (s *S3Cache) Close(ctx context.Context) error {
|
||||
if s.push != nil {
|
||||
gocache.Logf(ctx, "waiting for uploads...")
|
||||
wstart := time.Now()
|
||||
s.push.Wait()
|
||||
gocache.Logf(ctx, "uploads complete (%v elapsed)", time.Since(wstart).Round(10*time.Microsecond))
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// SetMetrics implements the corresponding server callback.
|
||||
func (s *S3Cache) SetMetrics(_ context.Context, m *expvar.Map) {
|
||||
m.Set("get_local_hit", &s.getLocalHit)
|
||||
m.Set("get_fault_hit", &s.getFaultHit)
|
||||
m.Set("get_fault_miss", &s.getFaultMiss)
|
||||
m.Set("put_skip_small", &s.putSkipSmall)
|
||||
m.Set("put_s3_found", &s.putS3Found)
|
||||
m.Set("put_s3_action", &s.putS3Action)
|
||||
m.Set("put_s3_object", &s.putS3Object)
|
||||
m.Set("put_s3_error", &s.putS3Error)
|
||||
}
|
||||
|
||||
// maybePutObject writes the specified object contents to S3 if there is not
|
||||
// already a matching key with the same etag. It returns the modified time of
|
||||
// the object file, whether or not it was sent to S3.
|
||||
func (s *S3Cache) maybePutObject(ctx context.Context, outputID, diskPath, etag string) (time.Time, error) {
|
||||
f, err := os.Open(diskPath)
|
||||
if err != nil {
|
||||
gocache.Logf(ctx, "[s3] open local object %s: %v", outputID, err)
|
||||
return time.Time{}, err
|
||||
}
|
||||
defer f.Close()
|
||||
fi, err := f.Stat()
|
||||
if err != nil {
|
||||
return time.Time{}, err
|
||||
}
|
||||
|
||||
written, err := s.S3Client.PutCond(ctx, s.outputKey(outputID), etag, f)
|
||||
if err != nil {
|
||||
s.putS3Error.Add(1)
|
||||
gocache.Logf(ctx, "[s3] put object %s: %v", outputID, err)
|
||||
return fi.ModTime(), err
|
||||
}
|
||||
if written {
|
||||
s.putS3Found.Add(1)
|
||||
return fi.ModTime(), nil // already present and matching
|
||||
}
|
||||
s.putS3Object.Add(1)
|
||||
return fi.ModTime(), nil
|
||||
}
|
||||
|
||||
// makeKey assembles a complete key from the specified parts, including the key
|
||||
// prefix if one is defined.
|
||||
func (s *S3Cache) makeKey(parts ...string) string {
|
||||
return path.Join(s.KeyPrefix, path.Join(parts...))
|
||||
}
|
||||
|
||||
func (s *S3Cache) actionKey(id string) string { return s.makeKey("action", id[:2], id) }
|
||||
func (s *S3Cache) outputKey(id string) string { return s.makeKey("output", id[:2], id) }
|
||||
|
||||
func (s *S3Cache) uploadConcurrency() int {
|
||||
if s.UploadConcurrency <= 0 {
|
||||
return runtime.NumCPU()
|
||||
}
|
||||
return s.UploadConcurrency
|
||||
}
|
||||
|
||||
// parseAction decodes an action record of the form "<output-id> <timestamp>",
// where the timestamp is a decimal count of Unix nanoseconds. The record must
// consist of exactly two whitespace-separated fields.
func parseAction(data []byte) (outputID string, mtime time.Time, _ error) {
	fields := strings.Fields(string(data))
	if len(fields) != 2 {
		return "", time.Time{}, errors.New("invalid action record")
	}
	id, stamp := fields[0], fields[1]

	nanos, err := strconv.ParseInt(stamp, 10, 64)
	if err != nil {
		return "", time.Time{}, fmt.Errorf("invalid timestamp: %w", err)
	}
	// time.Unix accepts a nanosecond value outside [0, 1e9) and normalizes it.
	return id, time.Unix(0, nanos), nil
}
|
||||
323
go-cache-plugin-src/lib/modproxy/modproxy.go
Normal file
323
go-cache-plugin-src/lib/modproxy/modproxy.go
Normal file
@@ -0,0 +1,323 @@
|
||||
// Copyright (c) Tailscale Inc & AUTHORS
|
||||
// SPDX-License-Identifier: BSD-3-Clause
|
||||
|
||||
// Package modproxy implements components of a Go module proxy that caches
|
||||
// files locally on disk, backed by objects in an S3 bucket.
|
||||
package modproxy
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"crypto/sha256"
|
||||
"errors"
|
||||
"expvar"
|
||||
"fmt"
|
||||
"io"
|
||||
"io/fs"
|
||||
"os"
|
||||
"path"
|
||||
"path/filepath"
|
||||
"runtime"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/creachadair/atomicfile"
|
||||
"github.com/creachadair/taskgroup"
|
||||
"github.com/goproxy/goproxy"
|
||||
"github.com/tailscale/go-cache-plugin/lib/s3util"
|
||||
"golang.org/x/sync/semaphore"
|
||||
)
|
||||
|
||||
var _ goproxy.Cacher = (*S3Cacher)(nil)
|
||||
|
||||
// S3Cacher implements the [github.com/goproxy/goproxy.Cacher] interface using
|
||||
// a local disk cache backed by an S3 bucket.
|
||||
//
|
||||
// # Cache Layout
|
||||
//
|
||||
// Module cache files are stored under a SHA256 digest of the filename
|
||||
// presented to the cache, encoded as hex and partitioned by the first two
|
||||
// bytes of the digest:
|
||||
//
|
||||
// For example:
|
||||
//
|
||||
// SHA256("fizzlepug") → 160db4d719252162c87a9169e26deda33d2340770d0d540fd4c580c55008b2d6
|
||||
// <cache-dir>/module/16/160db4d719252162c87a9169e26deda33d2340770d0d540fd4c580c55008b2d6
|
||||
//
|
||||
// When files are stored in S3, the same naming convention is used, but with
|
||||
// the specified key prefix instead:
|
||||
//
|
||||
// <key-prefix>/module/16/0db4d719252162c87a9169e26deda33d2340770d0d540fd4c580c55008b2d6
|
||||
type S3Cacher struct {
|
||||
// Local is the path of a local cache directory where modules are cached.
|
||||
// It must be non-empty.
|
||||
Local string
|
||||
|
||||
// S3Client is the S3 client used to read and write cache entries to the
|
||||
// backing store. It must be non-nil.
|
||||
S3Client *s3util.Client
|
||||
|
||||
// KeyPrefix, if non-empty, is prepended to each key stored into S3, with an
|
||||
// intervening slash.
|
||||
KeyPrefix string
|
||||
|
||||
// MaxTasks, if positive, limits the number of concurrent tasks that may be
|
||||
// interacting with S3. If zero or negative, the default is
|
||||
// [runtime.NumCPU].
|
||||
MaxTasks int
|
||||
|
||||
// Logf, if non-nil, is used to write log messages. If nil, logs are
|
||||
// discarded.
|
||||
Logf func(string, ...any)
|
||||
|
||||
// LogRequests, if true, enables detailed (but noisy) debug logging of all
|
||||
// requests handled by the cache. Logs are written to Logf.
|
||||
//
|
||||
// Each result is presented in the format:
|
||||
//
|
||||
// B <op> "<name>" (<digest>)
|
||||
// E <op> "<name>", err=<error>, <time> elapsed
|
||||
//
|
||||
// Where the operations are "GET" and "PUT". The "B" line is when the
|
||||
// operation began, and "E" when it ended. When a GET operation successfully
|
||||
// faults in a result from S3, the log is:
|
||||
//
|
||||
// F GET "<name>" hit (<digest>)
|
||||
//
|
||||
// When a PUT operation finishes writing a value behind to S3, the log is:
|
||||
//
|
||||
// W PUT "<name>", err=<error>, <time> elapsed
|
||||
//
|
||||
LogRequests bool
|
||||
|
||||
// Tracks tasks interacting with S3 in the background.
|
||||
initOnce sync.Once
|
||||
tasks *taskgroup.Group
|
||||
start func(taskgroup.Task)
|
||||
sema *semaphore.Weighted
|
||||
|
||||
pathError expvar.Int // errors constructing file paths
|
||||
getRequest expvar.Int // total number of Get requests
|
||||
getLocalHit expvar.Int // get: hit in local directory
|
||||
getLocalMiss expvar.Int // get: miss in local directory
|
||||
getFaultHit expvar.Int // get: hit in S3
|
||||
getFaultMiss expvar.Int // get: miss in S3
|
||||
getLocalError expvar.Int // get: error reading the local directory
|
||||
getFaultError expvar.Int // get: error reading from S3
|
||||
getLocalBytes expvar.Int // get: total bytes fetched from the local directory
|
||||
getS3Bytes expvar.Int // get: total bytes fetched from S3
|
||||
putRequest expvar.Int // total number of Put requests
|
||||
putLocalHit expvar.Int // put: put of object already stored locally
|
||||
putLocalError expvar.Int // put: error writing the local directory
|
||||
putS3Error expvar.Int // put: error writing to S3
|
||||
putLocalBytes expvar.Int // put: total bytes written to the local directory
|
||||
putS3Bytes expvar.Int // put: total bytes written to S3
|
||||
}
|
||||
|
||||
func (c *S3Cacher) init() {
|
||||
c.initOnce.Do(func() {
|
||||
nt := c.MaxTasks
|
||||
if nt <= 0 {
|
||||
nt = runtime.NumCPU()
|
||||
}
|
||||
c.tasks, c.start = taskgroup.New(nil).Limit(nt)
|
||||
c.sema = semaphore.NewWeighted(int64(nt))
|
||||
})
|
||||
}
|
||||
|
||||
// Get implements a method of the goproxy.Cacher interface. It reports cache
|
||||
// hits out of the local directory if available, or faults in from S3.
|
||||
func (c *S3Cacher) Get(ctx context.Context, name string) (_ io.ReadCloser, oerr error) {
|
||||
c.init()
|
||||
c.getRequest.Add(1)
|
||||
start := time.Now()
|
||||
hash, path, err := c.makePath(name)
|
||||
|
||||
c.vlogf("mc B GET %q (%s)", name, hash)
|
||||
defer func() { c.vlogf("mc E GET %q, err=%v, %v elapsed", name, oerr, time.Since(start)) }()
|
||||
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Check whether the file already exists locally.
|
||||
if rc, size, err := openReader(path); err == nil {
|
||||
c.getLocalHit.Add(1)
|
||||
c.getLocalBytes.Add(size)
|
||||
return rc, nil
|
||||
} else if errors.Is(err, os.ErrNotExist) {
|
||||
c.getLocalMiss.Add(1)
|
||||
} else {
|
||||
c.getLocalError.Add(1)
|
||||
c.logf("get %q local: %v (treating as miss)", name, err)
|
||||
}
|
||||
|
||||
// Local cache miss, fault in from S3.
|
||||
if err := c.sema.Acquire(ctx, 1); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer c.sema.Release(1)
|
||||
|
||||
obj, _, err := c.S3Client.Get(ctx, c.makeKey(hash))
|
||||
if errors.Is(err, fs.ErrNotExist) {
|
||||
c.getFaultMiss.Add(1)
|
||||
return nil, err
|
||||
} else if err != nil {
|
||||
c.getFaultError.Add(1)
|
||||
return nil, err
|
||||
}
|
||||
defer obj.Close()
|
||||
c.getFaultHit.Add(1)
|
||||
c.vlogf("mc F GET %q hit (%s)", name, hash)
|
||||
|
||||
if _, err := c.putLocal(ctx, name, path, obj); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
rc, _, err := openReader(path)
|
||||
return rc, err
|
||||
}
|
||||
|
||||
// putLocal reports whether the specified path already exists in the local
|
||||
// cache, and if not, writes data atomically into the path.
|
||||
func (c *S3Cacher) putLocal(ctx context.Context, name, path string, data io.Reader) (bool, error) {
|
||||
if _, err := os.Stat(path); err == nil {
|
||||
return true, nil
|
||||
}
|
||||
nw, err := atomicfile.WriteAll(path, data, 0644)
|
||||
c.putLocalBytes.Add(nw)
|
||||
if err != nil {
|
||||
c.putLocalError.Add(1)
|
||||
}
|
||||
return false, err
|
||||
}
|
||||
|
||||
// Put implements a method of the goproxy.Cacher interface. It stores data into
|
||||
// the local directory and then writes it back to S3 in the background.
|
||||
func (c *S3Cacher) Put(ctx context.Context, name string, data io.ReadSeeker) (oerr error) {
|
||||
c.init()
|
||||
c.putRequest.Add(1)
|
||||
start := time.Now()
|
||||
hash, path, err := c.makePath(name)
|
||||
|
||||
c.vlogf("mc B PUT %q (%s)", name, hash)
|
||||
defer func() { c.vlogf("mc E PUT %q, err=%v, %v elapsed", name, oerr, time.Since(start)) }()
|
||||
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if ok, err := c.putLocal(ctx, name, path, data); err != nil {
|
||||
return err
|
||||
} else if ok {
|
||||
c.putLocalHit.Add(1)
|
||||
return nil
|
||||
}
|
||||
|
||||
// Try to push the object to S3 in the background.
|
||||
f, size, err := openFileSize(path)
|
||||
if err != nil {
|
||||
c.putLocalError.Add(1)
|
||||
return err
|
||||
}
|
||||
c.start(func() error {
|
||||
defer f.Close()
|
||||
start := time.Now()
|
||||
|
||||
// Override the context with a separate timeout in case S3 is farkakte.
|
||||
sctx, cancel := context.WithTimeout(context.WithoutCancel(ctx), 1*time.Minute)
|
||||
defer cancel()
|
||||
|
||||
if err := c.S3Client.Put(sctx, c.makeKey(hash), f); err != nil {
|
||||
c.putS3Error.Add(1)
|
||||
c.logf("[s3] put %q failed: %v", name, err)
|
||||
} else {
|
||||
c.putS3Bytes.Add(size)
|
||||
}
|
||||
c.vlogf("mc W PUT %q, err=%v %v elapsed", name, err, time.Since(start))
|
||||
return err
|
||||
})
|
||||
return nil
|
||||
}
|
||||
|
||||
// Close waits until all background updates are complete.
|
||||
func (c *S3Cacher) Close() error {
|
||||
c.init()
|
||||
return c.tasks.Wait()
|
||||
}
|
||||
|
||||
// Metrics returns a map of cacher metrics. The caller is responsible for
|
||||
// publishing these metrics.
|
||||
func (c *S3Cacher) Metrics() *expvar.Map {
|
||||
m := new(expvar.Map)
|
||||
m.Set("path_error", &c.pathError)
|
||||
m.Set("get_request", &c.getRequest)
|
||||
m.Set("get_local_hit", &c.getLocalHit)
|
||||
m.Set("get_local_miss", &c.getLocalMiss)
|
||||
m.Set("get_fault_hit", &c.getFaultHit)
|
||||
m.Set("get_fault_miss", &c.getFaultMiss)
|
||||
m.Set("get_local_error", &c.getLocalError)
|
||||
m.Set("get_local_bytes", &c.getLocalBytes)
|
||||
m.Set("get_s3_bytes", &c.getS3Bytes)
|
||||
m.Set("put_request", &c.putRequest)
|
||||
m.Set("put_local_hit", &c.putLocalHit)
|
||||
m.Set("put_local_error", &c.putLocalError)
|
||||
m.Set("put_s3_error", &c.putS3Error)
|
||||
m.Set("put_local_bytes", &c.putLocalBytes)
|
||||
m.Set("put_s3_bytes", &c.putS3Bytes)
|
||||
return m
|
||||
}
|
||||
|
||||
// hashName returns the SHA256 digest of name as a lower-case hex string.
func hashName(name string) string {
	digest := sha256.Sum256([]byte(name))
	return fmt.Sprintf("%x", digest[:])
}
|
||||
|
||||
// makeKey assembles a complete S3 key from the specified parts, including the
|
||||
// key prefix if one is defined.
|
||||
func (c *S3Cacher) makeKey(hash string) string {
|
||||
return path.Join(c.KeyPrefix, hash[:2], hash)
|
||||
}
|
||||
|
||||
// makePath assembles a complete local cache path for the given name, creating
|
||||
// the enclosing directory if needed.
|
||||
func (c *S3Cacher) makePath(name string) (hash, path string, err error) {
|
||||
hash = hashName(name)
|
||||
path = filepath.Join(c.Local, hash[:2], hash)
|
||||
err = os.MkdirAll(filepath.Dir(path), 0755)
|
||||
if err != nil {
|
||||
c.pathError.Add(1)
|
||||
}
|
||||
return hash, path, err
|
||||
}
|
||||
|
||||
func (c *S3Cacher) logf(msg string, args ...any) {
|
||||
if c.Logf != nil {
|
||||
c.Logf(msg, args...)
|
||||
}
|
||||
}
|
||||
|
||||
func (c *S3Cacher) vlogf(msg string, args ...any) {
|
||||
if c.LogRequests {
|
||||
c.logf(msg, args...)
|
||||
}
|
||||
}
|
||||
|
||||
// openReader reads the whole file at path into memory and returns a reader
// over the contents together with their length in bytes. Closing the returned
// reader does nothing, as no file handle stays open.
func openReader(path string) (_ io.ReadCloser, size int64, _ error) {
	contents, err := os.ReadFile(path)
	if err != nil {
		return nil, 0, err
	}
	rc := io.NopCloser(bytes.NewReader(contents))
	return rc, int64(len(contents)), nil
}
|
||||
|
||||
// openFileSize opens the file at path and returns it together with its size
// in bytes. The caller must close the returned file. If the size cannot be
// determined, the file is closed before the error is returned.
func openFileSize(path string) (io.ReadCloser, int64, error) {
	f, err := os.Open(path)
	if err != nil {
		return nil, 0, err
	}
	info, statErr := f.Stat()
	if statErr == nil {
		return f, info.Size(), nil
	}
	f.Close()
	return nil, 0, statErr
}
|
||||
159
go-cache-plugin-src/lib/revproxy/cache.go
Normal file
159
go-cache-plugin-src/lib/revproxy/cache.go
Normal file
@@ -0,0 +1,159 @@
|
||||
// Copyright (c) Tailscale Inc & AUTHORS
|
||||
// SPDX-License-Identifier: BSD-3-Clause
|
||||
|
||||
package revproxy
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"io/fs"
|
||||
"net/http"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/creachadair/atomicfile"
|
||||
"github.com/creachadair/scheddle"
|
||||
"github.com/creachadair/taskgroup"
|
||||
)
|
||||
|
||||
// cacheLoadLocal reads cached headers and body from the local cache.
|
||||
func (s *Server) cacheLoadLocal(hash string) ([]byte, http.Header, error) {
|
||||
data, err := os.ReadFile(s.makePath(hash))
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
return parseCacheObject(data)
|
||||
}
|
||||
|
||||
// cacheStoreLocal writes the contents of body to the local cache.
|
||||
//
|
||||
// The file format is a plain-text section at the top recording a subset of the
|
||||
// response headers, followed by "\n\n", followed by the response body.
|
||||
func (s *Server) cacheStoreLocal(hash string, hdr http.Header, body []byte) error {
|
||||
path := s.makePath(hash)
|
||||
if err := os.MkdirAll(filepath.Dir(path), 0755); err != nil {
|
||||
return err
|
||||
}
|
||||
return atomicfile.Tx(s.makePath(hash), 0644, func(f io.Writer) error {
|
||||
return writeCacheObject(f, hdr, body)
|
||||
})
|
||||
}
|
||||
|
||||
// cacheLoadS3 reads cached headers and body from the remote S3 cache.
|
||||
func (s *Server) cacheLoadS3(ctx context.Context, hash string) ([]byte, http.Header, error) {
|
||||
data, err := s.S3Client.GetData(ctx, s.makeKey(hash))
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
return parseCacheObject(data)
|
||||
}
|
||||
|
||||
// cacheStoreS3 returns a task that writes the contents of body to the remote
|
||||
// S3 cache.
|
||||
func (s *Server) cacheStoreS3(hash string, hdr http.Header, body []byte) taskgroup.Task {
|
||||
var buf bytes.Buffer
|
||||
writeCacheObject(&buf, hdr, body)
|
||||
nb := buf.Len()
|
||||
return func() error {
|
||||
sctx, cancel := context.WithTimeout(context.Background(), 1*time.Minute)
|
||||
defer cancel()
|
||||
|
||||
if err := s.S3Client.Put(sctx, s.makeKey(hash), &buf); err != nil {
|
||||
s.logf("[s3] put %q failed: %v", hash, err)
|
||||
s.rspPushError.Add(1)
|
||||
} else {
|
||||
s.rspPush.Add(1)
|
||||
s.rspPushBytes.Add(int64(nb))
|
||||
}
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
// cacheLoadMemory reads cached headers and body from the memory cache.
|
||||
func (s *Server) cacheLoadMemory(hash string) ([]byte, http.Header, error) {
|
||||
e, ok := s.mcache.Get(hash)
|
||||
if !ok {
|
||||
return nil, nil, fs.ErrNotExist
|
||||
}
|
||||
return e.body, e.header, nil
|
||||
}
|
||||
|
||||
// cacheStoreMemory writes the contents of body to the memory cache.
|
||||
func (s *Server) cacheStoreMemory(hash string, maxAge time.Duration, hdr http.Header, body []byte) {
|
||||
s.mcache.Put(hash, memCacheEntry{
|
||||
header: trimCacheHeader(hdr),
|
||||
body: body,
|
||||
})
|
||||
s.expire.After(maxAge, scheddle.Run(func() {
|
||||
s.mcache.Remove(hash)
|
||||
}))
|
||||
}
|
||||
|
||||
// keepHeader lists the response headers that trimCacheHeader retains.
var keepHeader = []string{
	"Cache-Control",
	"Content-Type",
	"Date",
	"Etag",
}

// trimCacheHeader returns a new header holding only the headers named in
// keepHeader. For each of them the first value in h is copied; a header that
// is absent or whose first value is empty is left out.
func trimCacheHeader(h http.Header) http.Header {
	trimmed := make(http.Header)
	for i := range keepHeader {
		name := keepHeader[i]
		value := h.Get(name)
		if value == "" {
			continue
		}
		trimmed.Set(name, value)
	}
	return trimmed
}
|
||||
|
||||
// parseCacheObject splits cached object data into its body and headers.
//
// The header section is everything before the first blank line ("\n\n"); data
// without one is rejected. Each header line of the form "Name: value" is added
// to the result, and lines without a ": " separator are skipped.
func parseCacheObject(data []byte) ([]byte, http.Header, error) {
	sep := []byte("\n\n")
	cut := bytes.Index(data, sep)
	if cut < 0 {
		return nil, nil, errors.New("invalid cache object: missing header")
	}
	head, body := string(data[:cut]), data[cut+len(sep):]

	parsed := make(http.Header)
	for _, line := range strings.Split(head, "\n") {
		name, value, found := strings.Cut(line, ": ")
		if !found {
			continue
		}
		parsed.Add(name, value)
	}
	return body, parsed, nil
}
|
||||
|
||||
// writeCacheObject writes the specified response data into a cache object at w.
//
// The object consists of the Content-Type header (defaulting to
// "application/octet-stream"), then the Date and Etag headers when h has them,
// then a blank line, then body. The header section is assembled in memory and
// written with a single call, so a failure writing it is reported to the
// caller and is not lost.
func writeCacheObject(w io.Writer, h http.Header, body []byte) error {
	var hdr bytes.Buffer // writes to a bytes.Buffer do not fail
	hprintf(&hdr, h, "Content-Type", "application/octet-stream")
	hprintf(&hdr, h, "Date", "")
	hprintf(&hdr, h, "Etag", "")
	hdr.WriteByte('\n')

	if _, err := w.Write(hdr.Bytes()); err != nil {
		return err
	}
	_, err := w.Write(body)
	return err
}

// hprintf writes a "name: value" header line to w using the first value of
// name in h, or fallback when h has no non-empty value for it. Nothing is
// written when both are empty. Write errors are not reported, so callers that
// care should pass a writer that cannot fail.
func hprintf(w io.Writer, h http.Header, name, fallback string) {
	if v := h.Get(name); v != "" {
		fmt.Fprintf(w, "%s: %s\n", name, v)
	} else if fallback != "" {
		fmt.Fprintf(w, "%s: %s\n", name, fallback)
	}
}
|
||||
|
||||
// setXCacheInfo sets the X-Cache header of h to result. When hash is
// non-empty it also sets X-Cache-Id to the first 12 characters of hash, so a
// non-empty hash must have at least 12.
func setXCacheInfo(h http.Header, result, hash string) {
	h.Set("X-Cache", result)
	if hash == "" {
		return
	}
	h.Set("X-Cache-Id", hash[:12])
}
|
||||
|
||||
// memCacheEntry is one entry of the memory cache: the retained response
// headers together with the response body.
type memCacheEntry struct {
	header http.Header
	body   []byte
}

// entrySize reports the size of e as the length of its body in bytes. The
// headers are not counted.
func entrySize(e memCacheEntry) int64 {
	n := len(e.body)
	return int64(n)
}
|
||||
407
go-cache-plugin-src/lib/revproxy/revproxy.go
Normal file
407
go-cache-plugin-src/lib/revproxy/revproxy.go
Normal file
@@ -0,0 +1,407 @@
|
||||
// Copyright (c) Tailscale Inc & AUTHORS
|
||||
// SPDX-License-Identifier: BSD-3-Clause
|
||||
|
||||
// Package revproxy implements a minimal HTTP reverse proxy that caches files
|
||||
// locally on disk, backed by objects in an S3 bucket.
|
||||
//
|
||||
// # Limitations
|
||||
//
|
||||
// By default, only objects marked "immutable" by the target server are
|
||||
// eligible to be cached. Volatile objects that specify a max-age are also
|
||||
// cached in-memory, but are not persisted on disk or in S3. If we think it's
|
||||
// worthwhile we can spend some time to add more elaborate cache pruning, but
|
||||
// for now we're doing the simpler thing.
|
||||
package revproxy
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"crypto/sha256"
|
||||
"expvar"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"net/http/httputil"
|
||||
"net/url"
|
||||
"path"
|
||||
"path/filepath"
|
||||
"runtime"
|
||||
"slices"
|
||||
"strconv"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/creachadair/mds/cache"
|
||||
"github.com/creachadair/mds/mapset"
|
||||
"github.com/creachadair/scheddle"
|
||||
"github.com/creachadair/taskgroup"
|
||||
"github.com/tailscale/go-cache-plugin/lib/s3util"
|
||||
)
|
||||
|
||||
// Server is a caching reverse proxy server that caches successful responses to
|
||||
// GET requests for certain designated domains.
|
||||
//
|
||||
// The host field of the request URL must match one of the configured targets.
|
||||
// If not, the request is rejected with HTTP 502 (Bad Gateway). Otherwise, the
|
||||
// request is forwarded. A successful response will be cached if the server's
|
||||
// Cache-Control does not include "no-store", and does include "immutable".
|
||||
//
|
||||
// In addition, a successful response that is not immutable and specifies a
|
||||
// max-age will be cached temporarily in-memory.
|
||||
//
|
||||
// # Cache Format
|
||||
//
|
||||
// A cached response is a file with a header section and the body, separated by
|
||||
// a blank line. Only a subset of response headers are saved.
|
||||
//
|
||||
// # Cache Responses
|
||||
//
|
||||
// For requests handled by the proxy, the response includes an "X-Cache" header
|
||||
// indicating how the response was obtained:
|
||||
//
|
||||
// - "hit, memory": The response was served out of the memory cache.
|
||||
// - "hit, local": The response was served out of the local cache.
|
||||
// - "hit, remote": The response was faulted in from S3.
|
||||
// - "fetch, cached": The response was forwarded to the target and cached.
|
||||
// - "fetch, uncached": The response was forwarded to the target and not cached.
|
||||
//
|
||||
// For results intersecting with the cache, it also reports a X-Cache-Id giving
|
||||
// the storage key of the cache object.
|
||||
type Server struct {
	// Targets is the list of hosts for which the proxy should forward requests.
	// Host names should be fully-qualified ("host.example.com").
	Targets []string

	// Local is the path of a local cache directory where responses are cached.
	// It must be non-empty.
	Local string

	// S3Client is the S3 client used to read and write cache entries to the
	// backing store. It must be non-nil.
	S3Client *s3util.Client

	// KeyPrefix, if non-empty, is prepended to each key stored into S3, with an
	// intervening slash.
	KeyPrefix string

	// Logf, if non-nil, is used to write log messages. If nil, logs are
	// discarded.
	Logf func(string, ...any)

	// LogRequests, if true, enables detailed (but noisy) debug logging of all
	// requests handled by the reverse proxy. Logs are written to Logf.
	//
	// Each request is presented in the format:
	//
	//    B U:"<url>" H:<digest> C:<bool>
	//    E H:<digest> <disposition> B:<bytes> (<time> elapsed)
	//    - H:<digest> miss
	//
	// The "B" line is when the request began, and "E" when it was finished.
	// The abbreviated fields are:
	//
	//    U:  -- request URL
	//    H:  -- request URL digest (cache key)
	//    C:  -- whether the request is cacheable (true/false)
	//    B:  -- body size in bytes (for hits)
	//
	// The dispositions of a request are:
	//
	//    hit mem   -- cache hit in memory (volatile)
	//    hit disk  -- cache hit in local disk
	//    hit S3    -- cache hit in S3 (faulted to disk)
	//    fetch     -- fetched from the origin server
	//
	// On fetches, the "RC" tag indicates whether the response is cacheable,
	// with "no" meaning it was not cached at all, "mem" meaning it was cached
	// as a short-lived volatile response in memory, and "yes" meaning it was
	// cached on disk (and S3).
	LogRequests bool

	// The fields below are populated lazily by init, which runs at most once
	// under initOnce.
	initOnce sync.Once
	tasks    *taskgroup.Group                    // group for background tasks, limited to NumCPU concurrent tasks
	start    func(taskgroup.Task)                // starts a task in tasks, subject to its limit
	mcache   *cache.Cache[string, memCacheEntry] // short-lived mutable objects
	expire   *scheddle.Queue                     // cache expirations

	// Counters published by Metrics.
	reqReceived  expvar.Int // total requests received
	reqMemoryHit expvar.Int // hit in memory cache (volatile)
	reqLocalHit  expvar.Int // hit in local cache
	reqLocalMiss expvar.Int // miss in local cache
	reqFaultHit  expvar.Int // hit in remote (S3) cache
	reqFaultMiss expvar.Int // miss in remote (S3) cache
	reqForward   expvar.Int // request forwarded directly to upstream
	rspSave      expvar.Int // successful response saved in local cache
	rspSaveMem   expvar.Int // response saved in memory cache
	rspSaveError expvar.Int // error saving to local cache
	rspSaveBytes expvar.Int // bytes written to local cache
	rspPush      expvar.Int // successful response saved in S3
	rspPushError expvar.Int // error saving to S3
	rspPushBytes expvar.Int // bytes written to S3
	rspNotCached expvar.Int // response not cached anywhere
}
|
||||
|
||||
func (s *Server) init() {
|
||||
s.initOnce.Do(func() {
|
||||
nt := runtime.NumCPU()
|
||||
s.tasks, s.start = taskgroup.New(nil).Limit(nt)
|
||||
s.mcache = cache.New(cache.LRU[string, memCacheEntry]().
|
||||
WithLimit(10 << 20).
|
||||
WithSizeFunc(entrySize),
|
||||
)
|
||||
s.expire = scheddle.NewQueue(nil)
|
||||
})
|
||||
}
|
||||
|
||||
// Metrics returns a map of cache server metrics for s. The caller is
|
||||
// responsible to publish these metrics as desired.
|
||||
func (s *Server) Metrics() *expvar.Map {
|
||||
m := new(expvar.Map)
|
||||
m.Set("req_received", &s.reqReceived)
|
||||
m.Set("req_memory_hit", &s.reqMemoryHit)
|
||||
m.Set("req_local_hit", &s.reqLocalHit)
|
||||
m.Set("req_local_miss", &s.reqLocalMiss)
|
||||
m.Set("req_fault_hit", &s.reqFaultHit)
|
||||
m.Set("req_fault_miss", &s.reqFaultMiss)
|
||||
m.Set("req_forward", &s.reqForward)
|
||||
m.Set("rsp_save", &s.rspSave)
|
||||
m.Set("rsp_save_memory", &s.rspSaveMem)
|
||||
m.Set("rsp_save_error", &s.rspSaveError)
|
||||
m.Set("rsp_save_bytes", &s.rspSaveBytes)
|
||||
m.Set("rsp_push", &s.rspPush)
|
||||
m.Set("rsp_push_error", &s.rspPushError)
|
||||
m.Set("rsp_push_bytes", &s.rspPushBytes)
|
||||
m.Set("rsp_not_cached", &s.rspNotCached)
|
||||
return m
|
||||
}
|
||||
|
||||
// ServeHTTP implements the [http.Handler] interface for the proxy.
//
// Requests whose Host is not listed in Targets are rejected with 502 Bad
// Gateway. For a cacheable request, the memory cache, the local cache, and S3
// are consulted in that order, and the first hit is written to the caller.
// Otherwise the request is forwarded through a reverse proxy, and a cacheable
// response is recorded in the appropriate cache once the proxy has returned.
func (s *Server) ServeHTTP(w http.ResponseWriter, r *http.Request) {
	s.init()
	s.reqReceived.Add(1)

	// Check whether this request is to a target we are permitted to proxy for.
	if !hostMatchesTarget(r.Host, s.Targets) {
		s.logf("reject proxy request for non-target %q", r.Host)
		http.Error(w, http.StatusText(http.StatusBadGateway), http.StatusBadGateway)
		return
	}

	// The digest of the request URL is the cache key at every level.
	hash := hashRequestURL(r.URL)
	canCache := s.canCacheRequest(r)
	s.vlogf("rp B U:%q H:%s C:%v", r.URL, hash, canCache)
	start := time.Now()
	if canCache {
		// Check for a hit on this object in the memory cache.
		if data, hdr, err := s.cacheLoadMemory(hash); err == nil {
			s.reqMemoryHit.Add(1)
			setXCacheInfo(hdr, "hit, memory", hash)
			writeCachedResponse(w, hdr, data)
			s.vlogf("rp E H:%s hit mem B:%d (%v elapsed)", hash, len(data), time.Since(start))
			return
		}

		// Check for a hit on this object in the local cache.
		if data, hdr, err := s.cacheLoadLocal(hash); err == nil {
			s.reqLocalHit.Add(1)
			setXCacheInfo(hdr, "hit, local", hash)
			writeCachedResponse(w, hdr, data)
			s.vlogf("rp E H:%s hit disk B:%d (%v elapsed)", hash, len(data), time.Since(start))
			return
		}
		s.reqLocalMiss.Add(1)

		// Fault in from S3.
		if data, hdr, err := s.cacheLoadS3(r.Context(), hash); err == nil {
			s.reqFaultHit.Add(1)
			// A failure to populate the local cache is logged but does not
			// prevent serving the response we already have in hand.
			if err := s.cacheStoreLocal(hash, hdr, data); err != nil {
				s.logf("update %q local: %v", hash, err)
			}
			setXCacheInfo(hdr, "hit, remote", hash)
			writeCachedResponse(w, hdr, data)
			s.vlogf("rp E H:%s hit S3 B:%d (%v elapsed)", hash, len(data), time.Since(start))
			return
		}
		s.reqFaultMiss.Add(1)
		s.vlogf("rp - H:%s miss", hash)
	}

	// Reaching here, the object is not already cached locally so we have to
	// talk to the backend to get it. We need to do this whether or not it is
	// cacheable. Note we handle each request with its own proxy instance, so
	// that we can handle each response in context of this request.
	s.reqForward.Add(1)
	proxy := &httputil.ReverseProxy{Rewrite: s.rewriteRequest}

	// updateCache is replaced by ModifyResponse when the response turns out to
	// be cacheable. It is invoked only after the proxy has returned, at which
	// point buf holds whatever part of the body the proxy read.
	updateCache := func() {}
	if canCache {
		proxy.ModifyResponse = func(rsp *http.Response) error {
			maxAge, isVolatile := s.canMemoryCache(rsp)
			canCacheResponse := s.canCacheResponse(rsp)
			if !canCacheResponse && !isVolatile {
				// A response we cannot cache at all.
				setXCacheInfo(rsp.Header, "fetch, uncached", "")
				s.rspNotCached.Add(1)
				s.vlogf("rp E H:%s fetch RC:no (%v elapsed)", hash, time.Since(start))
				return nil
			}

			// Read out the whole response body so we can update the cache, and
			// replace the response reader so we can copy it back to the caller.
			var buf bytes.Buffer
			rsp.Body = copyReader{
				Reader: io.TeeReader(rsp.Body, &buf),
				Closer: rsp.Body,
			}
			if !canCacheResponse && isVolatile {
				// A volatile response we can cache temporarily.
				setXCacheInfo(rsp.Header, "fetch, cached, volatile", hash)
				updateCache = func() {
					body := buf.Bytes()
					s.cacheStoreMemory(hash, maxAge, rsp.Header, body)
					s.rspSaveMem.Add(1)

					// N.B. Don't persist on disk or in S3.
					s.vlogf("rp E H:%s fetch RC:mem B:%d (%v elapsed)", hash, len(body), time.Since(start))
				}
			} else {
				setXCacheInfo(rsp.Header, "fetch, cached", hash)
				updateCache = func() {
					body := buf.Bytes()
					if err := s.cacheStoreLocal(hash, rsp.Header, body); err != nil {
						s.rspSaveError.Add(1)
						s.logf("save %q to cache: %v", hash, err)

						// N.B.: Don't bother trying to forward to S3 in this case.
					} else {
						s.rspSave.Add(1)
						s.rspSaveBytes.Add(int64(len(body)))
						// The S3 upload is handed to the task group started
						// by init rather than run inline here.
						s.start(s.cacheStoreS3(hash, rsp.Header, body))
					}
					s.vlogf("rp E H:%s fetch RC:yes B:%d (%v elapsed)", hash, len(body), time.Since(start))
				}
			}
			return nil
		}
	}
	proxy.ServeHTTP(w, r)
	updateCache()
}
|
||||
|
||||
// rewriteRequest rewrites the inbound request for routing to a target.
|
||||
func (s *Server) rewriteRequest(pr *httputil.ProxyRequest) {
|
||||
u, _ := url.ParseRequestURI(pr.In.RequestURI)
|
||||
u.Host = pr.In.Host
|
||||
if u.Scheme == "" {
|
||||
u.Scheme = "https"
|
||||
}
|
||||
pr.Out.URL = u
|
||||
pr.Out.Host = u.Host
|
||||
}
|
||||
|
||||
// copyReader combines an independent Reader and Closer into a single value
// that satisfies io.ReadCloser. It allows the reading side of a body to be
// replaced (e.g., by a tee) while Close is still delivered to the original.
type copyReader struct {
	io.Reader
	io.Closer
}
|
||||
|
||||
// makePath returns the local cache path for the specified request hash.
|
||||
func (s *Server) makePath(hash string) string { return filepath.Join(s.Local, hash[:2], hash) }
|
||||
|
||||
// makeKey returns the S3 object key for the specified request hash.
|
||||
func (s *Server) makeKey(hash string) string { return path.Join(s.KeyPrefix, hash[:2], hash) }
|
||||
|
||||
func (s *Server) logf(msg string, args ...any) {
|
||||
if s.Logf != nil {
|
||||
s.Logf(msg, args...)
|
||||
}
|
||||
}
|
||||
|
||||
func (s *Server) vlogf(msg string, args ...any) {
|
||||
if s.LogRequests {
|
||||
s.logf(msg, args...)
|
||||
}
|
||||
}
|
||||
|
||||
// hostMatchesTarget reports whether host is exactly equal to one of the
// entries of targets. The comparison is case-sensitive and does not strip
// ports.
func hostMatchesTarget(host string, targets []string) bool {
	pos := slices.Index(targets, host)
	return pos >= 0
}
|
||||
|
||||
// canCacheRequest reports whether r is a request whose response can be cached.
|
||||
func (s *Server) canCacheRequest(r *http.Request) bool {
|
||||
return r.Method == "GET" && !parseCacheControl(r.Header.Get("Cache-Control")).Keys.Has("no-store")
|
||||
}
|
||||
|
||||
// canCacheResponse reports whether r is a response whose body can be cached.
|
||||
func (s *Server) canCacheResponse(rsp *http.Response) bool {
|
||||
if rsp.StatusCode != http.StatusOK {
|
||||
return false
|
||||
}
|
||||
cc := parseCacheControl(rsp.Header.Get("Cache-Control"))
|
||||
if cc.Keys.Has("no-store") {
|
||||
return false
|
||||
} else if cc.Keys.Has("immutable") {
|
||||
return true
|
||||
}
|
||||
|
||||
// We treat a response that is not immutable but requires validation as
|
||||
// cacheable if its max-age is so long it doesn't matter.
|
||||
const goodLongTime = 60 * 24 * time.Hour
|
||||
return cc.Keys.Has("must-revalidate") && cc.MaxAge > goodLongTime
|
||||
}
|
||||
|
||||
// cacheControl is the parsed form of a Cache-Control header value, as
// produced by parseCacheControl.
type cacheControl struct {
	Keys   mapset.Set[string] // names of the directives present in the header
	MaxAge time.Duration      // value of the max-age directive; zero if absent or unparseable
}
|
||||
|
||||
func parseCacheControl(s string) (out cacheControl) {
|
||||
for _, v := range strings.Split(s, ",") {
|
||||
key, val, ok := strings.Cut(strings.TrimSpace(v), "=")
|
||||
if ok && key == "max-age" {
|
||||
sec, err := strconv.Atoi(val)
|
||||
if err == nil {
|
||||
out.MaxAge = time.Duration(sec) * time.Second
|
||||
}
|
||||
}
|
||||
out.Keys.Add(key)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
// canMemoryCache reports whether r is a volatile response whose body can be
|
||||
// cached temporarily, and if so returns the maxmimum length of time the cache
|
||||
// entry should be valid for.
|
||||
func (s *Server) canMemoryCache(rsp *http.Response) (time.Duration, bool) {
|
||||
if rsp.StatusCode != http.StatusOK {
|
||||
return 0, false
|
||||
}
|
||||
cc := parseCacheControl(rsp.Header.Get("Cache-Control"))
|
||||
if cc.Keys.Has("no-store") || cc.Keys.Has("no-cache") {
|
||||
// While no-cache doesn't mean we can't cache it, it requires
|
||||
// re-validation before reusing the response, so treat that as if it were
|
||||
// no-store.
|
||||
return 0, false
|
||||
}
|
||||
|
||||
// We'll cache things in memory if they aren't expected to last too long.
|
||||
if cc.MaxAge > 0 && cc.MaxAge < time.Hour {
|
||||
return cc.MaxAge, true
|
||||
}
|
||||
return 0, false
|
||||
}
|
||||
|
||||
// hashRequestURL computes the storage digest for the given request URL: the
// SHA-256 of the URL's string form, encoded as lower-case hexadecimal.
func hashRequestURL(u *url.URL) string {
	digest := sha256.Sum256([]byte(u.String()))
	return fmt.Sprintf("%x", digest[:])
}
|
||||
|
||||
// writeCachedResponse replies to the client from a cached object: every header
// value in hdr is appended to the response headers, and then body is written
// as the response body. Any error from writing the body is ignored.
func writeCachedResponse(w http.ResponseWriter, hdr http.Header, body []byte) {
	out := w.Header()
	for key := range hdr {
		for i := range hdr[key] {
			out.Add(key, hdr[key][i])
		}
	}
	w.Write(body)
}
|
||||
169
go-cache-plugin-src/lib/s3util/s3util.go
Normal file
169
go-cache-plugin-src/lib/s3util/s3util.go
Normal file
@@ -0,0 +1,169 @@
|
||||
// Copyright (c) Tailscale Inc & AUTHORS
|
||||
// SPDX-License-Identifier: BSD-3-Clause
|
||||
|
||||
// Package s3util defines some helpful utilities for working with S3.
|
||||
package s3util
|
||||
|
||||
import (
|
||||
"cmp"
|
||||
"context"
|
||||
"crypto/md5"
|
||||
"errors"
|
||||
"fmt"
|
||||
"hash"
|
||||
"io"
|
||||
"io/fs"
|
||||
"os"
|
||||
|
||||
"github.com/aws/aws-sdk-go-v2/config"
|
||||
"github.com/aws/aws-sdk-go-v2/service/s3"
|
||||
"github.com/aws/aws-sdk-go-v2/service/s3/types"
|
||||
"github.com/creachadair/mds/value"
|
||||
)
|
||||
|
||||
// IsNotExist reports whether err is an error indicating the requested resource
|
||||
// was not found, taking into account S3 and standard library types.
|
||||
func IsNotExist(err error) bool {
|
||||
var e1 *types.NotFound
|
||||
var e2 *types.NoSuchKey
|
||||
if errors.As(err, &e1) || errors.As(err, &e2) {
|
||||
return true
|
||||
}
|
||||
return errors.Is(err, os.ErrNotExist)
|
||||
}
|
||||
|
||||
// BucketRegion reports the specified region for the given bucket using the
|
||||
// GetBucketLocation API.
|
||||
func BucketRegion(ctx context.Context, bucket string) (string, error) {
|
||||
// The default AWS region, which we use for resolving the bucket location
|
||||
// and also serves as the fallback if the API reports an empty region name.
|
||||
// The API returns "" for buckets in this region for historical reasons.
|
||||
const defaultRegion = "us-east-1"
|
||||
|
||||
cfg, err := config.LoadDefaultConfig(ctx, config.WithRegion(defaultRegion))
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
cli := s3.NewFromConfig(cfg)
|
||||
loc, err := cli.GetBucketLocation(ctx, &s3.GetBucketLocationInput{Bucket: &bucket})
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
return cmp.Or(string(loc.LocationConstraint), defaultRegion), nil
|
||||
}
|
||||
|
||||
// NewETagReader constructs an ETagReader that reads from r, accumulating a
// digest of every byte that passes through it.
func NewETagReader(r io.Reader) ETagReader {
	// MD5 is used only because the S3 ETag format calls for it; it is not
	// relied upon as a secure checksum.
	digest := md5.New()
	return ETagReader{
		hash: digest,
		r:    io.TeeReader(r, digest),
	}
}

// ETagReader is an [io.Reader] that delegates to a wrapped reader while
// hashing the data it delivers. Its ETag method reports an S3-style ETag
// covering everything read so far, which initially is nothing.
type ETagReader struct {
	r    io.Reader
	hash hash.Hash
}

// Read satisfies [io.Reader] by reading from the wrapped reader.
func (e ETagReader) Read(data []byte) (int, error) {
	return e.r.Read(data)
}

// ETag reports the S3 etag of the data read from e so far, formatted as
// lower-case hexadecimal digits.
func (e ETagReader) ETag() string {
	sum := e.hash.Sum(nil)
	return fmt.Sprintf("%x", sum)
}
|
||||
|
||||
// Client is a wrapper for an S3 client that provides basic read and write
// facilities to a specific bucket.
type Client struct {
	Client *s3.Client // the underlying S3 API client used to issue requests
	Bucket string     // the name of the bucket that all operations address
}
|
||||
|
||||
// Put writes the specified data to S3 under the given key.
|
||||
func (c *Client) Put(ctx context.Context, key string, data io.Reader) error {
|
||||
// Attempt to find the size of the input to send as a content length.
|
||||
// If we can't do this, let the SDK figure it out.
|
||||
var sizePtr *int64
|
||||
switch t := data.(type) {
|
||||
case sizer:
|
||||
sizePtr = value.Ptr(t.Size())
|
||||
case statter:
|
||||
fi, err := t.Stat()
|
||||
if err == nil {
|
||||
sizePtr = value.Ptr(fi.Size())
|
||||
}
|
||||
case io.Seeker:
|
||||
v, err := t.Seek(0, io.SeekEnd)
|
||||
if err == nil {
|
||||
sizePtr = &v
|
||||
|
||||
// Try to seek back to the beginning. If we cannot do this, fail out
|
||||
// so we don't try to write a partial object.
|
||||
_, err = t.Seek(0, io.SeekStart)
|
||||
if err != nil {
|
||||
return fmt.Errorf("[unexpected] seek failed: %w", err)
|
||||
}
|
||||
}
|
||||
}
|
||||
_, err := c.Client.PutObject(ctx, &s3.PutObjectInput{
|
||||
Bucket: &c.Bucket,
|
||||
Key: &key,
|
||||
Body: data,
|
||||
ContentLength: sizePtr,
|
||||
})
|
||||
return err
|
||||
}
|
||||
|
||||
// Get returns the contents of the specified key from S3. On success, the
|
||||
// returned reader contains the contents of the object, and the caller must
|
||||
// close the reader when finished.
|
||||
//
|
||||
// If the key is not found, the resulting error satisfies [fs.ErrNotExist].
|
||||
func (c *Client) Get(ctx context.Context, key string) (io.ReadCloser, int64, error) {
|
||||
rsp, err := c.Client.GetObject(ctx, &s3.GetObjectInput{
|
||||
Bucket: &c.Bucket,
|
||||
Key: &key,
|
||||
})
|
||||
if err != nil {
|
||||
if IsNotExist(err) {
|
||||
return nil, -1, fmt.Errorf("key %q: %w", key, fs.ErrNotExist)
|
||||
}
|
||||
return nil, -1, err
|
||||
}
|
||||
return rsp.Body, *rsp.ContentLength, nil
|
||||
}
|
||||
|
||||
// GetData returns the contents of the specified key from S3. It is a shorthand
|
||||
// for calling Get followed by io.ReadAll on the result.
|
||||
func (c *Client) GetData(ctx context.Context, key string) ([]byte, error) {
|
||||
rc, _, err := c.Get(ctx, key)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer rc.Close()
|
||||
return io.ReadAll(rc)
|
||||
}
|
||||
|
||||
// PutCond writes the specified data to S3 under the given key if the key does
|
||||
// not already exist, or if its content differs from the given etag.
|
||||
// The etag is an MD5 of the expected contents, encoded as lowercase hex digits.
|
||||
// On success, written reports whether the object was written.
|
||||
func (c *Client) PutCond(ctx context.Context, key, etag string, data io.Reader) (written bool, _ error) {
|
||||
if _, err := c.Client.HeadObject(ctx, &s3.HeadObjectInput{
|
||||
Bucket: &c.Bucket,
|
||||
Key: &key,
|
||||
IfMatch: &etag,
|
||||
}); err == nil {
|
||||
return false, nil
|
||||
}
|
||||
return true, c.Put(ctx, key, data)
|
||||
}
|
||||
|
||||
// A sizer is a value that can report its own size through a Size method, as
// [bytes.Reader] and similar types do.
type sizer interface {
	Size() int64
}
|
||||
|
||||
// A statter is a value that can describe itself through a Stat method, as
// [os.File] and similar types do.
type statter interface {
	Stat() (fs.FileInfo, error)
}
|
||||
42
go-cache-plugin-src/lib/s3util/s3util_test.go
Normal file
42
go-cache-plugin-src/lib/s3util/s3util_test.go
Normal file
@@ -0,0 +1,42 @@
|
||||
// Copyright (c) Tailscale Inc & AUTHORS
|
||||
// SPDX-License-Identifier: BSD-3-Clause
|
||||
|
||||
package s3util_test
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"crypto/md5"
|
||||
"encoding/hex"
|
||||
"io"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"github.com/tailscale/go-cache-plugin/lib/s3util"
|
||||
)
|
||||
|
||||
func TestETagReader(t *testing.T) {
|
||||
const testInput = "the once and future kitten"
|
||||
|
||||
want := md5.Sum([]byte(testInput))
|
||||
t.Logf("MD5(%q) = %x", testInput, want)
|
||||
|
||||
r := s3util.NewETagReader(strings.NewReader(testInput))
|
||||
|
||||
nr, err := io.Copy(io.Discard, r)
|
||||
if err != nil {
|
||||
t.Fatalf("Copy failed; %v", err)
|
||||
} else if nr != int64(len(testInput)) {
|
||||
t.Errorf("Copied %d bytes, want %d", nr, len(testInput))
|
||||
}
|
||||
|
||||
etag := r.ETag()
|
||||
t.Logf("Got etag %s for input %q", etag, testInput)
|
||||
|
||||
got, err := hex.DecodeString(r.ETag())
|
||||
if err != nil {
|
||||
t.Fatalf("Result is not valid hex: %s", r.ETag())
|
||||
}
|
||||
if !bytes.Equal(got, want[:]) {
|
||||
t.Errorf("Wrong result: got %x, want %x", got, want)
|
||||
}
|
||||
}
|
||||
Reference in New Issue
Block a user