Skip to content

Commit 5b0d787

Browse files
[flytepropeller] Watch agent metadata service dynamically (#5460)
Signed-off-by: Kevin Su <[email protected]> Signed-off-by: Kevin Su <[email protected]> Signed-off-by: Future-Outlier <[email protected]> Co-authored-by: Kevin Su <[email protected]> Co-authored-by: Kevin Su <[email protected]>
1 parent 4643e2a commit 5b0d787

File tree

9 files changed

+146
-87
lines changed

9 files changed

+146
-87
lines changed

flyteplugins/go/tasks/pluginmachinery/core/plugin.go

Lines changed: 24 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,9 @@ package core
33
import (
44
"context"
55
"fmt"
6+
"sync"
7+
8+
"k8s.io/utils/strings/slices"
69
)
710

811
//go:generate mockery -all -case=underscore
@@ -55,7 +58,27 @@ type Plugin interface {
5558
Finalize(ctx context.Context, tCtx TaskExecutionContext) error
5659
}
5760

58-
// Loads and validates a plugin.
61+
type AgentService struct {
62+
mu sync.RWMutex
63+
supportedTaskTypes []TaskType
64+
CorePlugin Plugin
65+
}
66+
67+
// ContainTaskType check if agent supports this task type.
68+
func (p *AgentService) ContainTaskType(taskType TaskType) bool {
69+
p.mu.RLock()
70+
defer p.mu.RUnlock()
71+
return slices.Contains(p.supportedTaskTypes, taskType)
72+
}
73+
74+
// SetSupportedTaskType set supportTaskType in the agent service.
75+
func (p *AgentService) SetSupportedTaskType(taskTypes []TaskType) {
76+
p.mu.Lock()
77+
defer p.mu.Unlock()
78+
p.supportedTaskTypes = taskTypes
79+
}
80+
81+
// LoadPlugin loads and validates a plugin.
5982
func LoadPlugin(ctx context.Context, iCtx SetupContext, entry PluginEntry) (Plugin, error) {
6083
plugin, err := entry.LoadPlugin(ctx, iCtx)
6184
if err != nil {

flyteplugins/go/tasks/pluginmachinery/core/plugin_test.go

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -93,3 +93,17 @@ func TestLoadPlugin(t *testing.T) {
9393
})
9494

9595
}
96+
97+
func TestAgentService(t *testing.T) {
98+
agentService := core.AgentService{}
99+
taskTypes := []core.TaskType{"sensor", "chatgpt"}
100+
101+
for _, taskType := range taskTypes {
102+
assert.Equal(t, false, agentService.ContainTaskType(taskType))
103+
}
104+
105+
agentService.SetSupportedTaskType(taskTypes)
106+
for _, taskType := range taskTypes {
107+
assert.Equal(t, true, agentService.ContainTaskType(taskType))
108+
}
109+
}

flyteplugins/go/tasks/plugins/webapi/agent/client.go

Lines changed: 22 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -90,17 +90,11 @@ func getFinalContext(ctx context.Context, operation string, agent *Deployment) (
9090
return context.WithTimeout(ctx, timeout)
9191
}
9292

93-
func updateAgentRegistry(ctx context.Context, cs *ClientSet) {
94-
agentRegistry := make(Registry)
93+
func getAgentRegistry(ctx context.Context, cs *ClientSet) Registry {
94+
newAgentRegistry := make(Registry)
9595
cfg := GetConfig()
9696
var agentDeployments []*Deployment
9797

98-
// Ensure that the old configuration is backward compatible
99-
for taskType, agentDeploymentID := range cfg.AgentForTaskTypes {
100-
agent := Agent{AgentDeployment: cfg.AgentDeployments[agentDeploymentID], IsSync: false}
101-
agentRegistry[taskType] = map[int32]*Agent{defaultTaskTypeVersion: &agent}
102-
}
103-
10498
if len(cfg.DefaultAgent.Endpoint) != 0 {
10599
agentDeployments = append(agentDeployments, &cfg.DefaultAgent)
106100
}
@@ -137,27 +131,36 @@ func updateAgentRegistry(ctx context.Context, cs *ClientSet) {
137131
deprecatedSupportedTaskTypes := agent.SupportedTaskTypes
138132
for _, supportedTaskType := range deprecatedSupportedTaskTypes {
139133
agent := &Agent{AgentDeployment: agentDeployment, IsSync: agent.IsSync}
140-
agentRegistry[supportedTaskType] = map[int32]*Agent{defaultTaskTypeVersion: agent}
134+
newAgentRegistry[supportedTaskType] = map[int32]*Agent{defaultTaskTypeVersion: agent}
141135
}
142136

143137
supportedTaskCategories := agent.SupportedTaskCategories
144138
for _, supportedCategory := range supportedTaskCategories {
145139
agent := &Agent{AgentDeployment: agentDeployment, IsSync: agent.IsSync}
146-
agentRegistry[supportedCategory.GetName()] = map[int32]*Agent{supportedCategory.GetVersion(): agent}
140+
newAgentRegistry[supportedCategory.GetName()] = map[int32]*Agent{supportedCategory.GetVersion(): agent}
147141
}
148142
}
149-
// If the agent doesn't implement the metadata service, we construct the registry based on the configuration
150-
for taskType, agentDeploymentID := range cfg.AgentForTaskTypes {
151-
if agentDeployment, ok := cfg.AgentDeployments[agentDeploymentID]; ok {
152-
if _, ok := agentRegistry[taskType]; !ok {
153-
agent := &Agent{AgentDeployment: agentDeployment, IsSync: false}
154-
agentRegistry[taskType] = map[int32]*Agent{defaultTaskTypeVersion: agent}
155-
}
143+
}
144+
145+
// If the agent doesn't implement the metadata service, we construct the registry based on the configuration
146+
for taskType, agentDeploymentID := range cfg.AgentForTaskTypes {
147+
if agentDeployment, ok := cfg.AgentDeployments[agentDeploymentID]; ok {
148+
if _, ok := newAgentRegistry[taskType]; !ok {
149+
agent := &Agent{AgentDeployment: agentDeployment, IsSync: false}
150+
newAgentRegistry[taskType] = map[int32]*Agent{defaultTaskTypeVersion: agent}
156151
}
157152
}
158153
}
159-
logger.Debugf(ctx, "AgentDeployment service supports task types: %v", maps.Keys(agentRegistry))
160-
setAgentRegistry(agentRegistry)
154+
155+
// Ensure that the old configuration is backward compatible
156+
for _, taskType := range cfg.SupportedTaskTypes {
157+
if _, ok := newAgentRegistry[taskType]; !ok {
158+
agent := &Agent{AgentDeployment: &cfg.DefaultAgent, IsSync: false}
159+
newAgentRegistry[taskType] = map[int32]*Agent{defaultTaskTypeVersion: agent}
160+
}
161+
}
162+
163+
return newAgentRegistry
161164
}
162165

163166
func getAgentClientSets(ctx context.Context) *ClientSet {

flyteplugins/go/tasks/plugins/webapi/agent/integration_test.go

Lines changed: 12 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -35,10 +35,6 @@ import (
3535
)
3636

3737
func TestEndToEnd(t *testing.T) {
38-
agentRegistry = Registry{
39-
"openai": {defaultTaskTypeVersion: {AgentDeployment: &Deployment{Endpoint: defaultAgentEndpoint}, IsSync: true}},
40-
"spark": {defaultTaskTypeVersion: {AgentDeployment: &Deployment{Endpoint: defaultAgentEndpoint}, IsSync: false}},
41-
}
4238
iter := func(ctx context.Context, tCtx pluginCore.TaskExecutionContext) error {
4339
return nil
4440
}
@@ -117,7 +113,7 @@ func TestEndToEnd(t *testing.T) {
117113
t.Run("failed to create a job", func(t *testing.T) {
118114
agentPlugin := newMockAsyncAgentPlugin()
119115
agentPlugin.PluginLoader = func(ctx context.Context, iCtx webapi.PluginSetupContext) (webapi.AsyncPlugin, error) {
120-
return Plugin{
116+
return &Plugin{
121117
metricScope: iCtx.MetricsScope(),
122118
cfg: GetConfig(),
123119
cs: &ClientSet{
@@ -259,6 +255,9 @@ func getTaskContext(t *testing.T) *pluginCoreMocks.TaskExecutionContext {
259255

260256
func newMockAsyncAgentPlugin() webapi.PluginEntry {
261257
asyncAgentClient := new(agentMocks.AsyncAgentServiceClient)
258+
agentRegistry := Registry{
259+
"spark": {defaultTaskTypeVersion: {AgentDeployment: &Deployment{Endpoint: defaultAgentEndpoint}, IsSync: false}},
260+
}
262261

263262
mockCreateRequestMatcher := mock.MatchedBy(func(request *admin.CreateTaskRequest) bool {
264263
expectedArgs := []string{"pyflyte-fast-execute", "--output-prefix", "/tmp/123"}
@@ -283,20 +282,25 @@ func newMockAsyncAgentPlugin() webapi.PluginEntry {
283282
ID: "agent-service",
284283
SupportedTaskTypes: []core.TaskType{"bigquery_query_job_task", "spark"},
285284
PluginLoader: func(ctx context.Context, iCtx webapi.PluginSetupContext) (webapi.AsyncPlugin, error) {
286-
return Plugin{
285+
return &Plugin{
287286
metricScope: iCtx.MetricsScope(),
288287
cfg: &cfg,
289288
cs: &ClientSet{
290289
asyncAgentClients: map[string]service.AsyncAgentServiceClient{
291290
defaultAgentEndpoint: asyncAgentClient,
292291
},
293292
},
293+
registry: agentRegistry,
294294
}, nil
295295
},
296296
}
297297
}
298298

299299
func newMockSyncAgentPlugin() webapi.PluginEntry {
300+
agentRegistry := Registry{
301+
"openai": {defaultTaskTypeVersion: {AgentDeployment: &Deployment{Endpoint: defaultAgentEndpoint}, IsSync: true}},
302+
}
303+
300304
syncAgentClient := new(agentMocks.SyncAgentServiceClient)
301305
output, _ := coreutils.MakeLiteralMap(map[string]interface{}{"x": 1})
302306
resource := &admin.Resource{Phase: flyteIdlCore.TaskExecution_SUCCEEDED, Outputs: output}
@@ -323,14 +327,15 @@ func newMockSyncAgentPlugin() webapi.PluginEntry {
323327
ID: "agent-service",
324328
SupportedTaskTypes: []core.TaskType{"openai"},
325329
PluginLoader: func(ctx context.Context, iCtx webapi.PluginSetupContext) (webapi.AsyncPlugin, error) {
326-
return Plugin{
330+
return &Plugin{
327331
metricScope: iCtx.MetricsScope(),
328332
cfg: &cfg,
329333
cs: &ClientSet{
330334
syncAgentClients: map[string]service.SyncAgentServiceClient{
331335
defaultAgentEndpoint: syncAgentClient,
332336
},
333337
},
338+
registry: agentRegistry,
334339
}, nil
335340
},
336341
}

0 commit comments

Comments
 (0)