Skip to content

Commit

Permalink
Merge pull request #100 from skip-mev/mergify/bp/release/v1.x.x/pr-99
Browse files Browse the repository at this point in the history
fix(core/provider): add RWMutex to task operations (backport #99)
  • Loading branch information
Zygimantass authored Feb 1, 2024
2 parents 2f025de + 34aba5e commit 1993eef
Show file tree
Hide file tree
Showing 2 changed files with 46 additions and 2 deletions.
5 changes: 3 additions & 2 deletions core/provider/provider.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ package provider
import (
"context"
"go.uber.org/zap"
"sync"
)

// TaskStatus defines the status of a task's underlying workload
Expand All @@ -23,8 +24,8 @@ type Task struct {
Definition TaskDefinition
Sidecars []*Task

logger *zap.Logger

logger *zap.Logger
mu sync.RWMutex
PreStart func(context.Context, *Task) error
PostStop func(context.Context, *Task) error
}
Expand Down
43 changes: 43 additions & 0 deletions core/provider/task.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,16 +4,21 @@ import (
"context"
"errors"
"go.uber.org/zap"
"sync"
)

// CreateTask creates a task structure and sets up its underlying workload on a provider, including sidecars if there are any in the definition
func CreateTask(ctx context.Context, logger *zap.Logger, provider Provider, definition TaskDefinition) (*Task, error) {
task := &Task{
Provider: provider,
Definition: definition,
mu: sync.RWMutex{},
logger: logger.Named("task"),
}

task.mu.Lock()
defer task.mu.Unlock()

sidecarTasks := make([]*Task, 0)

for _, sidecar := range definition.Sidecars {
Expand Down Expand Up @@ -49,6 +54,9 @@ func CreateTask(ctx context.Context, logger *zap.Logger, provider Provider, defi

// Start starts the underlying task's workload including its sidecars if startSidecars is set to true
func (t *Task) Start(ctx context.Context, startSidecars bool) error {
t.mu.Lock()
defer t.mu.Unlock()

if startSidecars {
for _, sidecar := range t.Sidecars {
err := sidecar.Start(ctx, startSidecars)
Expand All @@ -75,6 +83,9 @@ func (t *Task) Start(ctx context.Context, startSidecars bool) error {

// Stop stops the underlying task's workload including its sidecars if stopSidecars is set to true
func (t *Task) Stop(ctx context.Context, stopSidecars bool) error {
t.mu.Lock()
defer t.mu.Unlock()

if stopSidecars {
for _, sidecar := range t.Sidecars {
err := sidecar.Stop(ctx, stopSidecars)
Expand Down Expand Up @@ -102,27 +113,42 @@ func (t *Task) Stop(ctx context.Context, stopSidecars bool) error {

// WriteFile writes to a file in the task's volume at a relative path
func (t *Task) WriteFile(ctx context.Context, path string, bz []byte) error {
t.mu.Lock()
defer t.mu.Unlock()

return t.Provider.WriteFile(ctx, t.ID, path, bz)
}

// ReadFile returns a file's contents in the task's volume at a relative path
func (t *Task) ReadFile(ctx context.Context, path string) ([]byte, error) {
t.mu.RLock()
defer t.mu.RUnlock()

return t.Provider.ReadFile(ctx, t.ID, path)
}

// DownloadDir downloads a directory from the task's volume at path relPath to a local path localPath
func (t *Task) DownloadDir(ctx context.Context, relPath, localPath string) error {
t.mu.RLock()
defer t.mu.RUnlock()

return t.Provider.DownloadDir(ctx, t.ID, relPath, localPath)
}

// GetIP returns the task's IP
func (t *Task) GetIP(ctx context.Context) (string, error) {
t.mu.RLock()
defer t.mu.RUnlock()

return t.Provider.GetIP(ctx, t.ID)
}

// GetExternalAddress returns the external address for a specific task port in format host:port.
// Providers choose the protocol to return the port for themselves.
func (t *Task) GetExternalAddress(ctx context.Context, port string) (string, error) {
t.mu.RLock()
defer t.mu.RUnlock()

return t.Provider.GetExternalAddress(ctx, t.ID, port)
}

Expand All @@ -134,19 +160,30 @@ func (t *Task) RunCommand(ctx context.Context, command []string) (string, string
}

if status == TASK_RUNNING {
t.mu.Lock()
defer t.mu.Unlock()
return t.Provider.RunCommand(ctx, t.ID, command)
}

t.mu.Lock()
defer t.mu.Unlock()

return t.Provider.RunCommandWhileStopped(ctx, t.ID, t.Definition, command)
}

// GetStatus returns the task's underlying workload's status
func (t *Task) GetStatus(ctx context.Context) (TaskStatus, error) {
t.mu.RLock()
defer t.mu.RUnlock()

return t.Provider.GetTaskStatus(ctx, t.ID)
}

// Destroy destroys the task's underlying workload, including it's sidecars if destroySidecars is set to true
func (t *Task) Destroy(ctx context.Context, destroySidecars bool) error {
t.mu.Lock()
defer t.mu.Unlock()

if destroySidecars {
for _, sidecar := range t.Sidecars {
err := sidecar.Destroy(ctx, destroySidecars)
Expand All @@ -166,10 +203,16 @@ func (t *Task) Destroy(ctx context.Context, destroySidecars bool) error {

// SetPreStart sets a task's hook function that gets called right before the task's underlying workload is about to be started
func (t *Task) SetPreStart(f func(context.Context, *Task) error) {
t.mu.Lock()
defer t.mu.Unlock()

t.PreStart = f
}

// SetPostStop sets a task's hook function that gets called right after the task's underlying workload is stopped
func (t *Task) SetPostStop(f func(context.Context, *Task) error) {
t.mu.Lock()
defer t.mu.Unlock()

t.PostStop = f
}

0 comments on commit 1993eef

Please sign in to comment.