Skip to content

syz-cluster: refactor blob storage #6018

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
wants to merge 1 commit into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
26 changes: 16 additions & 10 deletions syz-cluster/controller/processor.go
Original file line number Diff line number Diff line change
Expand Up @@ -201,10 +201,11 @@ func (sp *SeriesProcessor) stopRunningTests(ctx context.Context, sessionID strin
}
log.Printf("session %q is finished, but the test %q is running: marking it stopped",
sessionID, test.TestName)
err = sp.sessionTestRepo.InsertOrUpdate(ctx, test, func(entity *db.SessionTest) {
err = sp.sessionTestRepo.InsertOrUpdate(ctx, test, func(entity *db.SessionTest) error {
if entity.Result == api.TestRunning {
entity.Result = api.TestError
}
return nil
})
if err != nil {
return fmt.Errorf("failed to update the step %q: %w", test.TestName, err)
Expand All @@ -214,19 +215,24 @@ func (sp *SeriesProcessor) stopRunningTests(ctx context.Context, sessionID strin
}

func (sp *SeriesProcessor) updateSessionLog(ctx context.Context, session *db.Session, log []byte) error {
return sp.sessionRepo.Update(ctx, session.ID, func(session *db.Session) error {
var logURI string
err := sp.sessionRepo.Update(ctx, session.ID, func(session *db.Session) error {
if session.LogURI == "" {
path, err := sp.blobStorage.Store(bytes.NewReader(log))
if err != nil {
return fmt.Errorf("failed to save the log: %w", err)
}
session.LogURI = path
} else {
err := sp.blobStorage.Update(session.LogURI, bytes.NewReader(log))
var err error
session.LogURI, err = sp.blobStorage.NewURI()
if err != nil {
return fmt.Errorf("failed to update the log %q: %w", session.LogURI, err)
return fmt.Errorf("failed to generate blob storage URI: %w", err)
}
}
logURI = session.LogURI
return nil
})
if err != nil {
return err
}
err = sp.blobStorage.Write(logURI, bytes.NewReader(log))
if err != nil {
return fmt.Errorf("failed to update the log %q: %w", session.LogURI, err)
}
return nil
}
47 changes: 19 additions & 28 deletions syz-cluster/pkg/blob/gcs.go
Original file line number Diff line number Diff line change
Expand Up @@ -29,54 +29,45 @@ func NewGCSClient(ctx context.Context, bucket string) (Storage, error) {
}, nil
}

func (gcs *gcsDriver) Store(source io.Reader) (string, error) {
object := uuid.NewString()
err := gcs.writeObject(object, source)
func (gcs *gcsDriver) NewURI() (string, error) {
key, err := uuid.NewRandom()
if err != nil {
return "", err
}
return gcs.objectURI(object), nil
return fmt.Sprintf("gcs://%s/%s", gcs.bucket, key.String()), nil
}

func (gcs *gcsDriver) Update(uri string, source io.Reader) error {
object, err := gcs.objectName(uri)
func (gcs *gcsDriver) Write(uri string, source io.Reader) error {
bucket, object, err := gcs.parseURI(uri)
if err != nil {
return err
}
return gcs.writeObject(object, source)
w, err := gcs.client.FileWriter(fmt.Sprintf("%s/%s", bucket, object), "", "")
if err != nil {
return err
}
defer w.Close()

_, err = io.Copy(w, source)
return err
}

func (gcs *gcsDriver) Read(uri string) (io.ReadCloser, error) {
object, err := gcs.objectName(uri)
bucket, object, err := gcs.parseURI(uri)
if err != nil {
return nil, err
}
return gcs.client.FileReader(fmt.Sprintf("%s/%s", gcs.bucket, object))
return gcs.client.FileReader(fmt.Sprintf("%s/%s", bucket, object))
}

var gcsObjectRe = regexp.MustCompile(`^gcs://([\w-]+)/([\w-]+)$`)

func (gcs *gcsDriver) objectName(uri string) (string, error) {
func (gcs *gcsDriver) parseURI(uri string) (string, string, error) {
match := gcsObjectRe.FindStringSubmatch(uri)
if len(match) == 0 {
return "", fmt.Errorf("invalid GCS URI")
return "", "", fmt.Errorf("invalid GCS URI")
} else if match[1] != gcs.bucket {
return "", fmt.Errorf("unexpected GCS bucket")
}
return match[2], nil
}

func (gcs *gcsDriver) objectURI(object string) string {
return fmt.Sprintf("gcs://%s/%s", gcs.bucket, object)
}

func (gcs *gcsDriver) writeObject(object string, source io.Reader) error {
w, err := gcs.client.FileWriter(fmt.Sprintf("%s/%s", gcs.bucket, object), "", "")
if err != nil {
return err
return "", "", fmt.Errorf("unexpected GCS bucket")
}
defer w.Close()

_, err = io.Copy(w, source)
return err
return gcs.bucket, match[2], nil
}
63 changes: 39 additions & 24 deletions syz-cluster/pkg/blob/storage.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
package blob

import (
"bytes"
"fmt"
"io"
"os"
Expand All @@ -16,9 +17,9 @@ import (
// Storage is not assumed to be used for particularly large objects (e.g. GB of size),
// but rather for blobs that risk overwhelming Spanner column size limits.
type Storage interface {
// Store returns a URI to use later.
Store(source io.Reader) (string, error)
Update(key string, source io.Reader) error
// NewURI returns a random URI to use later.
NewURI() (string, error)
Write(uri string, source io.Reader) error
Read(uri string) (io.ReadCloser, error)
}

Expand All @@ -36,33 +37,20 @@ func NewLocalStorage(baseFolder string) *LocalStorage {

const localStoragePrefix = "local://"

func (ls *LocalStorage) Store(source io.Reader) (string, error) {
name := uuid.NewString()
err := ls.writeFile(name, source)
func (ls *LocalStorage) NewURI() (string, error) {
key, err := uuid.NewRandom()
if err != nil {
return "", err
}
return localStoragePrefix + name, nil
return localStoragePrefix + key.String(), nil
}

func (ls *LocalStorage) Update(uri string, source io.Reader) error {
if !strings.HasPrefix(uri, localStoragePrefix) {
return fmt.Errorf("unsupported URI type")
}
return ls.writeFile(strings.TrimPrefix(uri, localStoragePrefix), source)
}

func (ls *LocalStorage) Read(uri string) (io.ReadCloser, error) {
if !strings.HasPrefix(uri, localStoragePrefix) {
return nil, fmt.Errorf("unsupported URI type")
func (ls *LocalStorage) Write(uri string, source io.Reader) error {
path, err := ls.uriToPath(uri)
if err != nil {
return err
}
// TODO: add some other URI validation checks?
path := filepath.Join(ls.baseFolder, strings.TrimPrefix(uri, localStoragePrefix))
return os.Open(path)
}

func (ls *LocalStorage) writeFile(name string, source io.Reader) error {
file, err := os.Create(filepath.Join(ls.baseFolder, name))
file, err := os.Create(path)
if err != nil {
return err
}
Expand All @@ -74,6 +62,21 @@ func (ls *LocalStorage) writeFile(name string, source io.Reader) error {
return nil
}

// Read opens the blob identified by uri for reading.
//
// The URI is translated to a file path by ls.uriToPath; an error from that
// translation or from os.Open is returned unchanged. On success the caller
// owns the returned reader and must close it.
func (ls *LocalStorage) Read(uri string) (io.ReadCloser, error) {
	path, err := ls.uriToPath(uri)
	if err != nil {
		return nil, err
	}
	file, err := os.Open(path)
	if err != nil {
		// Return a literal nil: forwarding os.Open's results directly would
		// wrap a nil *os.File into a non-nil io.ReadCloser.
		return nil, err
	}
	return file, nil
}

// uriToPath maps a URI that starts with localStoragePrefix to a file path
// under ls.baseFolder. URIs without that prefix are rejected with an error.
func (ls *LocalStorage) uriToPath(uri string) (string, error) {
	if strings.HasPrefix(uri, localStoragePrefix) {
		name := strings.TrimPrefix(uri, localStoragePrefix)
		return filepath.Join(ls.baseFolder, name), nil
	}
	return "", fmt.Errorf("unsupported URI type")
}

func ReadAllBytes(storage Storage, uri string) ([]byte, error) {
if uri == "" {
return nil, nil
Expand All @@ -85,3 +88,15 @@ func ReadAllBytes(storage Storage, uri string) ([]byte, error) {
defer reader.Close()
return io.ReadAll(reader)
}

// StoreBytes asks storage for a fresh URI, writes data under it and returns
// that URI. If either step fails, an empty string is returned with the error.
func StoreBytes(storage Storage, data []byte) (string, error) {
	newURI, genErr := storage.NewURI()
	if genErr != nil {
		return "", fmt.Errorf("failed to generate URI: %w", genErr)
	}
	if writeErr := storage.Write(newURI, bytes.NewReader(data)); writeErr != nil {
		return "", writeErr
	}
	return newURI, nil
}
9 changes: 6 additions & 3 deletions syz-cluster/pkg/blob/storage_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,20 +9,23 @@ import (
"testing"

"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)

func TestLocalStorage(t *testing.T) {
storage := NewLocalStorage(t.TempDir())
var uris []string
for i := 0; i < 2; i++ {
uri, err := storage.NewURI()
require.NoError(t, err)
content := fmt.Sprintf("object #%d", i)
uri, err := storage.Store(bytes.NewReader([]byte(content)))
assert.NoError(t, err)
err = storage.Write(uri, bytes.NewReader([]byte(content)))
require.NoError(t, err)
uris = append(uris, uri)
}
for i, uri := range uris {
readBytes, err := ReadAllBytes(storage, uri)
assert.NoError(t, err)
require.NoError(t, err)
assert.EqualValues(t, fmt.Sprintf("object #%d", i), readBytes)
}
_, err := storage.Read(localStoragePrefix + "abcdef")
Expand Down
12 changes: 9 additions & 3 deletions syz-cluster/pkg/db/session_test_repo.go
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ func NewSessionTestRepository(client *spanner.Client) *SessionTestRepository {

// If the beforeSave callback is specified, it will be called before saving the entity.
func (repo *SessionTestRepository) InsertOrUpdate(ctx context.Context, test *SessionTest,
beforeSave func(*SessionTest)) error {
beforeSave func(*SessionTest) error) error {
_, err := repo.client.ReadWriteTransaction(ctx,
func(ctx context.Context, txn *spanner.ReadWriteTransaction) error {
// Check if the test already exists.
Expand All @@ -41,7 +41,10 @@ func (repo *SessionTestRepository) InsertOrUpdate(ctx context.Context, test *Ses
_, iterErr := iter.Next()
if iterErr == nil {
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

By reordering this if -> else if -> else chain, you can merge the two beforeSave calls and shift some of the code left.

Copy link
Collaborator

@tarasmadan tarasmadan May 16, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

if iterErr != iterator.Done {
	return iterErr
}
if beforeSave != nil {
	err := beforeSave(test)
	if err != nil {
		return err
	}
}
...
if iterErr == nil {
	...
}

if beforeSave != nil {
beforeSave(test)
err := beforeSave(test)
if err != nil {
return err
}
}
m, err := spanner.UpdateStruct("SessionTests", test)
if err != nil {
Expand All @@ -52,7 +55,10 @@ func (repo *SessionTestRepository) InsertOrUpdate(ctx context.Context, test *Ses
return iterErr
} else {
if beforeSave != nil {
beforeSave(test)
err := beforeSave(test)
if err != nil {
return err
}
}
m, err := spanner.InsertStruct("SessionTests", test)
if err != nil {
Expand Down
5 changes: 2 additions & 3 deletions syz-cluster/pkg/service/finding.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@
package service

import (
"bytes"
"context"
"fmt"

Expand All @@ -30,13 +29,13 @@ func (s *FindingService) Save(ctx context.Context, req *api.NewFinding) error {
var reportURI, logURI string
var err error
if len(req.Log) > 0 {
logURI, err = s.blobStorage.Store(bytes.NewReader(req.Log))
logURI, err = blob.StoreBytes(s.blobStorage, req.Log)
if err != nil {
return fmt.Errorf("failed to save the log: %w", err)
}
}
if len(req.Report) > 0 {
reportURI, err = s.blobStorage.Store(bytes.NewReader(req.Report))
reportURI, err = blob.StoreBytes(s.blobStorage, req.Report)
if err != nil {
return fmt.Errorf("failed to save the report: %w", err)
}
Expand Down
3 changes: 1 addition & 2 deletions syz-cluster/pkg/service/series.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@
package service

import (
"bytes"
"context"
"errors"
"fmt"
Expand Down Expand Up @@ -66,7 +65,7 @@ func (s *SeriesService) UploadSeries(ctx context.Context, series *api.Series) (*
for _, patch := range series.Patches {
// In case of errors, we will waste some space, but let's ignore it for simplicity.
// Patches are not super big.
uri, err := s.blobStorage.Store(bytes.NewReader(patch.Body))
uri, err := blob.StoreBytes(s.blobStorage, patch.Body)
if err != nil {
return nil, fmt.Errorf("failed to upload patch body: %w", err)
}
Expand Down
26 changes: 17 additions & 9 deletions syz-cluster/pkg/service/session.go
Original file line number Diff line number Diff line change
Expand Up @@ -34,22 +34,30 @@ var ErrSessionNotFound = errors.New("session not found")

func (s *SessionService) SkipSession(ctx context.Context, sessionID string, skip *api.SkipRequest) error {
var triageLogURI string
if len(skip.TriageLog) > 0 {
var err error
triageLogURI, err = s.blobStorage.Store(bytes.NewReader(skip.TriageLog))
if err != nil {
return fmt.Errorf("failed to save the log: %w", err)
}
}
err := s.sessionRepo.Update(ctx, sessionID, func(session *db.Session) error {
session.TriageLogURI = triageLogURI
if len(skip.TriageLog) > 0 && session.TriageLogURI == "" {
var err error
session.TriageLogURI, err = s.blobStorage.NewURI()
if err != nil {
return err
}
}
triageLogURI = session.TriageLogURI
session.SetSkipReason(skip.Reason)
return nil
})
if errors.Is(err, db.ErrEntityNotFound) {
return ErrSessionNotFound
} else if err != nil {
return err
}
if triageLogURI != "" {
err = s.blobStorage.Write(triageLogURI, bytes.NewReader(skip.TriageLog))
if err != nil {
return fmt.Errorf("failed to save the triage log: %w", err)
}
}
return err
return nil
}

func (s *SessionService) UploadSession(ctx context.Context, req *api.NewSession) (*api.UploadSessionResp, error) {
Expand Down
Loading