diff --git a/syz-cluster/controller/processor.go b/syz-cluster/controller/processor.go index 07851a9fa635..92e783b04232 100644 --- a/syz-cluster/controller/processor.go +++ b/syz-cluster/controller/processor.go @@ -201,10 +201,11 @@ func (sp *SeriesProcessor) stopRunningTests(ctx context.Context, sessionID strin } log.Printf("session %q is finished, but the test %q is running: marking it stopped", sessionID, test.TestName) - err = sp.sessionTestRepo.InsertOrUpdate(ctx, test, func(entity *db.SessionTest) { + err = sp.sessionTestRepo.InsertOrUpdate(ctx, test, func(entity *db.SessionTest) error { if entity.Result == api.TestRunning { entity.Result = api.TestError } + return nil }) if err != nil { return fmt.Errorf("failed to update the step %q: %w", test.TestName, err) @@ -214,19 +215,24 @@ func (sp *SeriesProcessor) stopRunningTests(ctx context.Context, sessionID strin } func (sp *SeriesProcessor) updateSessionLog(ctx context.Context, session *db.Session, log []byte) error { - return sp.sessionRepo.Update(ctx, session.ID, func(session *db.Session) error { + var logURI string + err := sp.sessionRepo.Update(ctx, session.ID, func(session *db.Session) error { if session.LogURI == "" { - path, err := sp.blobStorage.Store(bytes.NewReader(log)) - if err != nil { - return fmt.Errorf("failed to save the log: %w", err) - } - session.LogURI = path - } else { - err := sp.blobStorage.Update(session.LogURI, bytes.NewReader(log)) + var err error + session.LogURI, err = sp.blobStorage.NewURI() if err != nil { - return fmt.Errorf("failed to update the log %q: %w", session.LogURI, err) + return fmt.Errorf("failed to generate blob storage URI: %w", err) } } + logURI = session.LogURI return nil }) + if err != nil { + return err + } + err = sp.blobStorage.Write(logURI, bytes.NewReader(log)) + if err != nil { + return fmt.Errorf("failed to update the log %q: %w", session.LogURI, err) + } + return nil } diff --git a/syz-cluster/pkg/blob/gcs.go b/syz-cluster/pkg/blob/gcs.go index f6f5797d50a5..f830a2dbc8fd 100644 --- a/syz-cluster/pkg/blob/gcs.go +++ b/syz-cluster/pkg/blob/gcs.go @@ -29,54 +29,45 @@ func NewGCSClient(ctx context.Context, bucket string) (Storage, error) { }, nil } -func (gcs *gcsDriver) Store(source io.Reader) (string, error) { - object := uuid.NewString() - err := gcs.writeObject(object, source) +func (gcs *gcsDriver) NewURI() (string, error) { + key, err := uuid.NewRandom() if err != nil { return "", err } - return gcs.objectURI(object), nil + return fmt.Sprintf("gcs://%s/%s", gcs.bucket, key.String()), nil } -func (gcs *gcsDriver) Update(uri string, source io.Reader) error { - object, err := gcs.objectName(uri) +func (gcs *gcsDriver) Write(uri string, source io.Reader) error { + bucket, object, err := gcs.parseURI(uri) if err != nil { return err } - return gcs.writeObject(object, source) + w, err := gcs.client.FileWriter(fmt.Sprintf("%s/%s", bucket, object), "", "") + if err != nil { + return err + } + defer w.Close() + + _, err = io.Copy(w, source) + return err } func (gcs *gcsDriver) Read(uri string) (io.ReadCloser, error) { - object, err := gcs.objectName(uri) + bucket, object, err := gcs.parseURI(uri) if err != nil { return nil, err } - return gcs.client.FileReader(fmt.Sprintf("%s/%s", gcs.bucket, object)) + return gcs.client.FileReader(fmt.Sprintf("%s/%s", bucket, object)) } var gcsObjectRe = regexp.MustCompile(`^gcs://([\w-]+)/([\w-]+)$`) -func (gcs *gcsDriver) objectName(uri string) (string, error) { +func (gcs *gcsDriver) parseURI(uri string) (string, string, error) { match := gcsObjectRe.FindStringSubmatch(uri) if len(match) == 0 { - return "", fmt.Errorf("invalid GCS URI") + return "", "", fmt.Errorf("invalid GCS URI") } else if match[1] != gcs.bucket { - return "", fmt.Errorf("unexpected GCS bucket") - } - return match[2], nil -} - -func (gcs *gcsDriver) objectURI(object string) string { - return fmt.Sprintf("gcs://%s/%s", gcs.bucket, object) -} - -func (gcs *gcsDriver) writeObject(object string, source io.Reader) error { - w, err := gcs.client.FileWriter(fmt.Sprintf("%s/%s", gcs.bucket, object), "", "") - if err != nil { - return err + return "", "", fmt.Errorf("unexpected GCS bucket") } - defer w.Close() - - _, err = io.Copy(w, source) - return err + return gcs.bucket, match[2], nil } diff --git a/syz-cluster/pkg/blob/storage.go b/syz-cluster/pkg/blob/storage.go index be1d8f493c0b..0f4eccc48ba7 100644 --- a/syz-cluster/pkg/blob/storage.go +++ b/syz-cluster/pkg/blob/storage.go @@ -4,6 +4,7 @@ package blob import ( + "bytes" "fmt" "io" "os" @@ -16,9 +17,9 @@ import ( // Storage is not assumed to be used for partciularly large objects (e.g. GB of size), // but rather for blobs that risk overwhelming Spanner column size limits. type Storage interface { - // Store returns a URI to use later. - Store(source io.Reader) (string, error) - Update(key string, source io.Reader) error + // Returns a random URI to use later. + NewURI() (string, error) + Write(uri string, source io.Reader) error Read(uri string) (io.ReadCloser, error) } @@ -36,33 +37,20 @@ func NewLocalStorage(baseFolder string) *LocalStorage { const localStoragePrefix = "local://" -func (ls *LocalStorage) Store(source io.Reader) (string, error) { - name := uuid.NewString() - err := ls.writeFile(name, source) +func (ls *LocalStorage) NewURI() (string, error) { + key, err := uuid.NewRandom() if err != nil { return "", err } - return localStoragePrefix + name, nil + return localStoragePrefix + key.String(), nil } -func (ls *LocalStorage) Update(uri string, source io.Reader) error { - if !strings.HasPrefix(uri, localStoragePrefix) { - return fmt.Errorf("unsupported URI type") - } - return ls.writeFile(strings.TrimPrefix(uri, localStoragePrefix), source) -} - -func (ls *LocalStorage) Read(uri string) (io.ReadCloser, error) { - if !strings.HasPrefix(uri, localStoragePrefix) { - return nil, fmt.Errorf("unsupported URI type") +func (ls *LocalStorage) Write(uri string, source io.Reader) error { + path, err := ls.uriToPath(uri) + if err != nil { + return err } - // TODO: add some other URI validation checks? - path := filepath.Join(ls.baseFolder, strings.TrimPrefix(uri, localStoragePrefix)) - return os.Open(path) -} - -func (ls *LocalStorage) writeFile(name string, source io.Reader) error { - file, err := os.Create(filepath.Join(ls.baseFolder, name)) + file, err := os.Create(path) if err != nil { return err } @@ -74,6 +62,21 @@ func (ls *LocalStorage) writeFile(name string, source io.Reader) error { return nil } +func (ls *LocalStorage) Read(uri string) (io.ReadCloser, error) { + path, err := ls.uriToPath(uri) + if err != nil { + return nil, err + } + return os.Open(path) +} + +func (ls *LocalStorage) uriToPath(uri string) (string, error) { + if !strings.HasPrefix(uri, localStoragePrefix) { + return "", fmt.Errorf("unsupported URI type") + } + return filepath.Join(ls.baseFolder, strings.TrimPrefix(uri, localStoragePrefix)), nil +} + func ReadAllBytes(storage Storage, uri string) ([]byte, error) { if uri == "" { return nil, nil @@ -85,3 +88,15 @@ func ReadAllBytes(storage Storage, uri string) ([]byte, error) { defer reader.Close() return io.ReadAll(reader) } + +func StoreBytes(storage Storage, data []byte) (string, error) { + uri, err := storage.NewURI() + if err != nil { + return "", fmt.Errorf("failed to generate URI: %w", err) + } + err = storage.Write(uri, bytes.NewReader(data)) + if err != nil { + return "", err + } + return uri, nil +} diff --git a/syz-cluster/pkg/blob/storage_test.go b/syz-cluster/pkg/blob/storage_test.go index 9d2d830ce686..3057ab3bf75a 100644 --- a/syz-cluster/pkg/blob/storage_test.go +++ b/syz-cluster/pkg/blob/storage_test.go @@ -9,20 +9,23 @@ import ( "testing" "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" ) func TestLocalStorage(t *testing.T) { storage := NewLocalStorage(t.TempDir()) var uris []string for i := 0; i < 2; i++ { + uri, err := storage.NewURI() + require.NoError(t, err) content := fmt.Sprintf("object #%d", i) - uri, err := storage.Store(bytes.NewReader([]byte(content))) - assert.NoError(t, err) + err = storage.Write(uri, bytes.NewReader([]byte(content))) + require.NoError(t, err) uris = append(uris, uri) } for i, uri := range uris { readBytes, err := ReadAllBytes(storage, uri) - assert.NoError(t, err) + require.NoError(t, err) assert.EqualValues(t, fmt.Sprintf("object #%d", i), readBytes) } _, err := storage.Read(localStoragePrefix + "abcdef") diff --git a/syz-cluster/pkg/db/session_test_repo.go b/syz-cluster/pkg/db/session_test_repo.go index 7043b838975a..8f8caf6c4002 100644 --- a/syz-cluster/pkg/db/session_test_repo.go +++ b/syz-cluster/pkg/db/session_test_repo.go @@ -22,7 +22,7 @@ func NewSessionTestRepository(client *spanner.Client) *SessionTestRepository { // If the beforeSave callback is specified, it will be called before saving the entity. func (repo *SessionTestRepository) InsertOrUpdate(ctx context.Context, test *SessionTest, - beforeSave func(*SessionTest)) error { + beforeSave func(*SessionTest) error) error { _, err := repo.client.ReadWriteTransaction(ctx, func(ctx context.Context, txn *spanner.ReadWriteTransaction) error { // Check if the test already exists. @@ -41,7 +41,10 @@ func (repo *SessionTestRepository) InsertOrUpdate(ctx context.Context, test *Ses _, iterErr := iter.Next() if iterErr == nil { if beforeSave != nil { - beforeSave(test) + err := beforeSave(test) + if err != nil { + return err + } } m, err := spanner.UpdateStruct("SessionTests", test) if err != nil { @@ -52,7 +55,10 @@ func (repo *SessionTestRepository) InsertOrUpdate(ctx context.Context, test *Ses return iterErr } else { if beforeSave != nil { - beforeSave(test) + err := beforeSave(test) + if err != nil { + return err + } } m, err := spanner.InsertStruct("SessionTests", test) if err != nil { diff --git a/syz-cluster/pkg/service/finding.go b/syz-cluster/pkg/service/finding.go index bb605ba73edf..b522df4ae162 100644 --- a/syz-cluster/pkg/service/finding.go +++ b/syz-cluster/pkg/service/finding.go @@ -4,7 +4,6 @@ package service import ( - "bytes" "context" "fmt" @@ -30,13 +29,13 @@ func (s *FindingService) Save(ctx context.Context, req *api.NewFinding) error { var reportURI, logURI string var err error if len(req.Log) > 0 { - logURI, err = s.blobStorage.Store(bytes.NewReader(req.Log)) + logURI, err = blob.StoreBytes(s.blobStorage, req.Log) if err != nil { return fmt.Errorf("failed to save the log: %w", err) } } if len(req.Report) > 0 { - reportURI, err = s.blobStorage.Store(bytes.NewReader(req.Report)) + reportURI, err = blob.StoreBytes(s.blobStorage, req.Report) if err != nil { return fmt.Errorf("failed to save the report: %w", err) } diff --git a/syz-cluster/pkg/service/series.go b/syz-cluster/pkg/service/series.go index 8ab8b5cef039..735de6f5fa96 100644 --- a/syz-cluster/pkg/service/series.go +++ b/syz-cluster/pkg/service/series.go @@ -4,7 +4,6 @@ package service import ( - "bytes" "context" "errors" "fmt" @@ -66,7 +65,7 @@ func (s *SeriesService) UploadSeries(ctx context.Context, series *api.Series) (* for _, patch := range series.Patches { // In case of errors, we will waste some space, but let's ignore it for simplicity. // Patches are not super big. - uri, err := s.blobStorage.Store(bytes.NewReader(patch.Body)) + uri, err := blob.StoreBytes(s.blobStorage, patch.Body) if err != nil { return nil, fmt.Errorf("failed to upload patch body: %w", err) } diff --git a/syz-cluster/pkg/service/session.go b/syz-cluster/pkg/service/session.go index 9ad0688fdd5f..4713258b5dda 100644 --- a/syz-cluster/pkg/service/session.go +++ b/syz-cluster/pkg/service/session.go @@ -34,22 +34,30 @@ var ErrSessionNotFound = errors.New("session not found") func (s *SessionService) SkipSession(ctx context.Context, sessionID string, skip *api.SkipRequest) error { var triageLogURI string - if len(skip.TriageLog) > 0 { - var err error - triageLogURI, err = s.blobStorage.Store(bytes.NewReader(skip.TriageLog)) - if err != nil { - return fmt.Errorf("failed to save the log: %w", err) - } - } err := s.sessionRepo.Update(ctx, sessionID, func(session *db.Session) error { - session.TriageLogURI = triageLogURI + if len(skip.TriageLog) > 0 && session.TriageLogURI == "" { + var err error + session.TriageLogURI, err = s.blobStorage.NewURI() + if err != nil { + return err + } + } + triageLogURI = session.TriageLogURI session.SetSkipReason(skip.Reason) return nil }) if errors.Is(err, db.ErrEntityNotFound) { return ErrSessionNotFound + } else if err != nil { + return err + } + if triageLogURI != "" { + err = s.blobStorage.Write(triageLogURI, bytes.NewReader(skip.TriageLog)) + if err != nil { + return fmt.Errorf("failed to save the triage log: %w", err) + } } - return err + return nil } func (s *SessionService) UploadSession(ctx context.Context, req *api.NewSession) (*api.UploadSessionResp, error) { diff --git a/syz-cluster/pkg/service/sessiontest.go b/syz-cluster/pkg/service/sessiontest.go index be410fdb7a6b..412ad95278b4 100644 --- a/syz-cluster/pkg/service/sessiontest.go +++ b/syz-cluster/pkg/service/sessiontest.go @@ -39,24 +39,35 @@ func (s *SessionTestService) Save(ctx context.Context, req *api.TestResult) erro TestName: req.TestName, } } - logURI := entity.LogURI - if len(req.Log) > 0 { - logURI, err = s.uploadOrUpdate(ctx, logURI, bytes.NewReader(req.Log)) - if err != nil { - return fmt.Errorf("failed to save the log: %w", err) - } - } - return s.testRepo.InsertOrUpdate(ctx, entity, func(test *db.SessionTest) { + var logURI string + if err := s.testRepo.InsertOrUpdate(ctx, entity, func(test *db.SessionTest) error { test.Result = req.Result test.UpdatedAt = time.Now() - test.LogURI = logURI + if len(req.Log) > 0 { + var err error + test.LogURI, err = s.blobStorage.NewURI() + if err != nil { + return err + } + } + logURI = test.LogURI if req.BaseBuildID != "" { test.BaseBuildID = spanner.NullString{StringVal: req.BaseBuildID, Valid: true} } if req.PatchedBuildID != "" { test.PatchedBuildID = spanner.NullString{StringVal: req.PatchedBuildID, Valid: true} } - }) + return nil + }); err != nil { + return err + } + if logURI != "" { + err := s.blobStorage.Write(logURI, bytes.NewReader(req.Log)) + if err != nil { + return fmt.Errorf("failed to save the log: %w", err) + } + } + return nil } func (s *SessionTestService) SaveArtifacts(ctx context.Context, sessionID, testName string, reader io.Reader) error { @@ -66,27 +77,22 @@ func (s *SessionTestService) SaveArtifacts(ctx context.Context, sessionID, testN } else if entity == nil { return fmt.Errorf("the test has not been submitted yet") } - newArchiveURI, err := s.uploadOrUpdate(ctx, entity.ArtifactsArchiveURI, reader) - if err != nil { - return fmt.Errorf("failed to save the artifacts archive: %w", err) - } - return s.testRepo.InsertOrUpdate(ctx, entity, func(test *db.SessionTest) { - test.ArtifactsArchiveURI = newArchiveURI - }) -} - -func (s *SessionTestService) uploadOrUpdate(ctx context.Context, uri string, reader io.Reader) (string, error) { - if uri != "" { - err := s.blobStorage.Update(uri, reader) - if err != nil { - return "", fmt.Errorf("failed to update: %w", err) + var archiveURI string + if err := s.testRepo.InsertOrUpdate(ctx, entity, func(test *db.SessionTest) error { + if test.ArtifactsArchiveURI == "" { + var err error + test.ArtifactsArchiveURI, err = s.blobStorage.NewURI() + if err != nil { + return err + } } - return uri, nil + archiveURI = test.ArtifactsArchiveURI + return nil + }); err != nil { + return err } - // TODO: it will leak if we fail to save the entity. - uri, err := s.blobStorage.Store(reader) - if err != nil { - return "", fmt.Errorf("failed to save: %w", err) + if err := s.blobStorage.Write(archiveURI, reader); err != nil { + return fmt.Errorf("failed to upload the archive: %w", err) } - return uri, nil + return nil }