From 0cb311eb210d3b1667f2c3fa8724efc1b2f7c3fe Mon Sep 17 00:00:00 2001 From: Alex Wu Date: Mon, 9 Jun 2025 10:45:28 +0800 Subject: [PATCH 1/5] add plugin for each connector deployment Signed-off-by: Alex Wu --- flyte-single-binary-local.yaml | 24 ++++ .../go/tasks/pluginmachinery/registry.go | 34 +++++- .../tasks/plugins/webapi/connector/client.go | 48 ++------ .../tasks/plugins/webapi/connector/config.go | 3 + .../webapi/connector/integration_test.go | 8 -- .../tasks/plugins/webapi/connector/plugin.go | 112 +++++++----------- .../plugins/webapi/connector/plugin_test.go | 79 ++++++------ .../pkg/controller/nodes/task/handler.go | 49 ++++++-- .../pkg/controller/nodes/task/handler_test.go | 2 - 9 files changed, 190 insertions(+), 169 deletions(-) diff --git a/flyte-single-binary-local.yaml b/flyte-single-binary-local.yaml index 0b3caf45f1..7d197a823c 100644 --- a/flyte-single-binary-local.yaml +++ b/flyte-single-binary-local.yaml @@ -68,6 +68,30 @@ plugins: kubernetes-template-uri: http://localhost:30080/kubernetes-dashboard/#/log/{{.namespace }}/{{ .podName }}/pod?namespace={{ .namespace }} cloudwatch-enabled: false stackdriver-enabled: false + connector-service: + supportedTaskTypes: + - bigquery_query_job_task + - sensor + - chatgpt + connectors: + my-test-connector: + endpoint: "localhost:8000" + insecure: true + timeouts: + CreateTask: 5s + GetTask: 30s + DeleteTask: 30s + defaultTimeout: 30s + webApi: + caching: + resyncInterval: 60s + defaultConnector: + endpoint: "localhost:8000" + webApi: + caching: + resyncInterval: 120s + connectorForTaskTypes: + - my_test_task1: my-test-connector database: postgres: diff --git a/flyteplugins/go/tasks/pluginmachinery/registry.go b/flyteplugins/go/tasks/pluginmachinery/registry.go index 39a241a50e..3a5afe9380 100644 --- a/flyteplugins/go/tasks/pluginmachinery/registry.go +++ b/flyteplugins/go/tasks/pluginmachinery/registry.go @@ -11,14 +11,20 @@ import ( "github.com/flyteorg/flyte/flytestdlib/logger" ) +const defaultPluginBufferSize = 5 + type taskPluginRegistry struct { m sync.Mutex k8sPlugin []k8s.PluginEntry corePlugin []core.PluginEntry + pluginRegistrationChan chan webapi.PluginEntry + registeredTaskTypes map[string]struct{} } // A singleton variable that maintains a registry of all plugins. The framework uses this to access all plugins -var pluginRegistry = &taskPluginRegistry{} +var pluginRegistry = &taskPluginRegistry{ + pluginRegistrationChan: make(chan webapi.PluginEntry, defaultPluginBufferSize), +} func PluginRegistry() TaskPluginRegistry { return pluginRegistry @@ -41,6 +47,29 @@ func (p *taskPluginRegistry) RegisterRemotePlugin(info webapi.PluginEntry) { p.m.Lock() defer p.m.Unlock() p.corePlugin = append(p.corePlugin, internalRemote.CreateRemotePlugin(info)) + p.AddRegisteredTaskType(info.ID) +} + +func (p *taskPluginRegistry) GetPluginRegistrationChan() chan webapi.PluginEntry { + return p.pluginRegistrationChan +} + +// IsTaskTypeRegistered checks if a task type is registered +func (p *taskPluginRegistry) IsTaskTypeRegistered(taskType string) bool { + p.m.Lock() + defer p.m.Unlock() + _, exists := p.registeredTaskTypes[taskType] + return exists +} + +// RegisterTaskType registers a single task type +func (p *taskPluginRegistry) AddRegisteredTaskType(taskType string) { + p.m.Lock() + defer p.m.Unlock() + if p.registeredTaskTypes == nil { + p.registeredTaskTypes = make(map[string]struct{}) + } + p.registeredTaskTypes[taskType] = struct{}{} } func CreateRemotePlugin(pluginEntry webapi.PluginEntry) core.PluginEntry { @@ -107,4 +136,7 @@ type TaskPluginRegistry interface { RegisterRemotePlugin(info webapi.PluginEntry) GetCorePlugins() []core.PluginEntry GetK8sPlugins() []k8s.PluginEntry + GetPluginRegistrationChan() chan webapi.PluginEntry + IsTaskTypeRegistered(taskType string) bool + AddRegisteredTaskType(taskType string) } diff --git a/flyteplugins/go/tasks/plugins/webapi/connector/client.go b/flyteplugins/go/tasks/plugins/webapi/connector/client.go index 2bc88cb328..fae7b1ea2d 100644 --- a/flyteplugins/go/tasks/plugins/webapi/connector/client.go +++ b/flyteplugins/go/tasks/plugins/webapi/connector/client.go @@ -3,7 +3,6 @@ package connector import ( "context" "crypto/x509" - "strings" "golang.org/x/exp/maps" "google.golang.org/grpc" @@ -15,6 +14,7 @@ import ( "github.com/flyteorg/flyte/flyteidl/gen/pb-go/flyteidl/admin" "github.com/flyteorg/flyte/flyteidl/gen/pb-go/flyteidl/service" + "github.com/flyteorg/flyte/flyteplugins/go/tasks/pluginmachinery" "github.com/flyteorg/flyte/flytestdlib/config" "github.com/flyteorg/flyte/flytestdlib/logger" ) @@ -91,8 +91,7 @@ func getFinalContext(ctx context.Context, operation string, connector *Deploymen return context.WithTimeout(ctx, timeout) } -func getConnectorRegistry(ctx context.Context, cs *ClientSet) Registry { - newConnectorRegistry := make(Registry) +func watchConnectors(ctx context.Context, cs *ClientSet) { cfg := GetConfig() var connectorDeployments []*Deployment @@ -127,48 +126,27 @@ func getConnectorRegistry(ctx context.Context, cs *ClientSet) Registry { logger.Errorf(finalCtx, "failed to list connector: [%v] with error: [%v]", connectorDeployment.Endpoint, err) continue } - - connectorSupportedTaskCategories := make(map[string]struct{}) + // connectorSupportedTaskCategories := make(map[string]struct{}) + // If a connector's support task type plugin was not registered yet, we should do registration for _, connector := range res.GetAgents() { deprecatedSupportedTaskTypes := connector.GetSupportedTaskTypes() for _, supportedTaskType := range deprecatedSupportedTaskTypes { - connector := &Connector{ConnectorDeployment: connectorDeployment, IsSync: connector.GetIsSync()} - newConnectorRegistry[supportedTaskType] = map[int32]*Connector{defaultTaskTypeVersion: connector} - connectorSupportedTaskCategories[supportedTaskType] = struct{}{} - } - - supportedTaskCategories := connector.GetSupportedTaskCategories() - for _, supportedCategory := range supportedTaskCategories { - connector := &Connector{ConnectorDeployment: connectorDeployment, IsSync: connector.GetIsSync()} - supportedCategoryName := supportedCategory.GetName() - newConnectorRegistry[supportedCategoryName] = map[int32]*Connector{supportedCategory.GetVersion(): connector} - connectorSupportedTaskCategories[supportedCategoryName] = struct{}{} + if ok := pluginmachinery.PluginRegistry().IsTaskTypeRegistered(supportedTaskType); !ok { + plugin := createPluginEntry(supportedTaskType, connectorDeployment, cs) + pluginmachinery.PluginRegistry().RegisterRemotePlugin(plugin) + pluginmachinery.PluginRegistry().GetPluginRegistrationChan() <- plugin + } } } - logger.Infof(ctx, "ConnectorDeployment [%v] supports the following task types: [%v]", connectorDeployment.Endpoint, - strings.Join(maps.Keys(connectorSupportedTaskCategories), ", ")) } - - // If the connector doesn't implement the metadata service, we construct the registry based on the configuration - for taskType, connectorDeploymentID := range cfg.ConnectorForTaskTypes { - if connectorDeployment, ok := cfg.ConnectorDeployments[connectorDeploymentID]; ok { - if _, ok := newConnectorRegistry[taskType]; !ok { - connector := &Connector{ConnectorDeployment: connectorDeployment, IsSync: false} - newConnectorRegistry[taskType] = map[int32]*Connector{defaultTaskTypeVersion: connector} - } - } - } - // Ensure that the old configuration is backward compatible for _, taskType := range cfg.SupportedTaskTypes { - if _, ok := newConnectorRegistry[taskType]; !ok { - connector := &Connector{ConnectorDeployment: &cfg.DefaultConnector, IsSync: false} - newConnectorRegistry[taskType] = map[int32]*Connector{defaultTaskTypeVersion: connector} + if ok := pluginmachinery.PluginRegistry().IsTaskTypeRegistered(taskType); !ok { + plugin := createPluginEntry(taskType, &cfg.DefaultConnector, cs) + pluginmachinery.PluginRegistry().RegisterRemotePlugin(plugin) + pluginmachinery.PluginRegistry().GetPluginRegistrationChan() <- plugin } } - - logger.Infof(ctx, "ConnectorDeployments support the following task types: [%v]", strings.Join(maps.Keys(newConnectorRegistry), ", ")) - return newConnectorRegistry } func getConnectorClientSets(ctx context.Context) *ClientSet { diff --git a/flyteplugins/go/tasks/plugins/webapi/connector/config.go b/flyteplugins/go/tasks/plugins/webapi/connector/config.go index 7fd9ec0ebc..82edd44ef9 100644 --- a/flyteplugins/go/tasks/plugins/webapi/connector/config.go +++ b/flyteplugins/go/tasks/plugins/webapi/connector/config.go @@ -93,6 +93,9 @@ type Deployment struct { // DefaultTimeout gives the default RPC timeout if a more specific one is not defined in Timeouts; if neither DefaultTimeout nor Timeouts is defined for an operation, RPC timeout will not be enforced DefaultTimeout config.Duration `json:"defaultTimeout"` + + // WebAPI defines config for the base WebAPI plugin + WebAPI webapi.PluginConfig `json:"webApi" pflag:",Defines config for the base WebAPI plugin."` } func GetConfig() *Config { diff --git a/flyteplugins/go/tasks/plugins/webapi/connector/integration_test.go b/flyteplugins/go/tasks/plugins/webapi/connector/integration_test.go index dc2c858e5d..9b065f4ddc 100644 --- a/flyteplugins/go/tasks/plugins/webapi/connector/integration_test.go +++ b/flyteplugins/go/tasks/plugins/webapi/connector/integration_test.go @@ -254,9 +254,6 @@ func getTaskContext(t *testing.T) *pluginCoreMocks.TaskExecutionContext { func newMockAsyncConnectorPlugin() webapi.PluginEntry { asyncAgentClient := new(agentMocks.AsyncAgentServiceClient) - connectorRegistry := Registry{ - "spark": {defaultTaskTypeVersion: {ConnectorDeployment: &Deployment{Endpoint: defaultConnectorEndpoint}, IsSync: false}}, - } mockCreateRequestMatcher := mock.MatchedBy(func(request *admin.CreateTaskRequest) bool { expectedArgs := []string{"pyflyte-fast-execute", "--output-prefix", "/tmp/123"} @@ -289,16 +286,12 @@ func newMockAsyncConnectorPlugin() webapi.PluginEntry { defaultConnectorEndpoint: asyncAgentClient, }, }, - registry: connectorRegistry, }, nil }, } } func newMockSyncConnectorPlugin() webapi.PluginEntry { - agentRegistry := Registry{ - "openai": {defaultTaskTypeVersion: {ConnectorDeployment: &Deployment{Endpoint: defaultConnectorEndpoint}, IsSync: true}}, - } syncAgentClient := new(agentMocks.SyncAgentServiceClient) output, _ := coreutils.MakeLiteralMap(map[string]interface{}{"x": 1}) @@ -334,7 +327,6 @@ func newMockSyncConnectorPlugin() webapi.PluginEntry { defaultConnectorEndpoint: syncAgentClient, }, }, - registry: agentRegistry, }, nil }, } diff --git a/flyteplugins/go/tasks/plugins/webapi/connector/plugin.go b/flyteplugins/go/tasks/plugins/webapi/connector/plugin.go index 79931ce537..17ee09101a 100644 --- a/flyteplugins/go/tasks/plugins/webapi/connector/plugin.go +++ b/flyteplugins/go/tasks/plugins/webapi/connector/plugin.go @@ -4,11 +4,10 @@ import ( "context" "encoding/gob" "fmt" - "slices" + "reflect" "sync" "time" - "golang.org/x/exp/maps" "google.golang.org/protobuf/types/known/structpb" "k8s.io/apimachinery/pkg/util/wait" @@ -26,35 +25,11 @@ import ( "github.com/flyteorg/flyte/flytestdlib/promutils" ) -const ID = "connector-service" - -type ConnectorService struct { - mu sync.RWMutex - supportedTaskTypes []string - CorePlugin core.Plugin -} - -// ContainTaskType check if connector supports this task type. -func (p *ConnectorService) ContainTaskType(taskType string) bool { - p.mu.RLock() - defer p.mu.RUnlock() - return slices.Contains(p.supportedTaskTypes, taskType) -} - -// SetSupportedTaskType set supportTaskType in the connector service. -func (p *ConnectorService) SetSupportedTaskType(taskTypes []string) { - p.mu.Lock() - defer p.mu.Unlock() - p.supportedTaskTypes = taskTypes -} - -type Registry map[string]map[int32]*Connector // map[taskTypeName][taskTypeVersion] => Connector - type Plugin struct { metricScope promutils.Scope cfg *Config cs *ClientSet - registry Registry + deployment Connector mu sync.RWMutex } @@ -80,14 +55,12 @@ type ResourceMetaWrapper struct { TaskCategory admin.TaskCategory } -func (p *Plugin) setRegistry(r Registry) { - p.mu.Lock() - defer p.mu.Unlock() - p.registry = r -} - func (p *Plugin) GetConfig() webapi.PluginConfig { - return GetConfig().WebAPI + if p.deployment.ConnectorDeployment != nil && !reflect.DeepEqual(p.deployment.ConnectorDeployment.WebAPI, webapi.PluginConfig{}) { + return p.deployment.ConnectorDeployment.WebAPI + } else { + return p.cfg.WebAPI + } } func (p *Plugin) ResourceRequirements(_ context.Context, _ webapi.TaskExecutionContextReader) ( @@ -130,11 +103,11 @@ func (p *Plugin) Create(ctx context.Context, taskCtx webapi.TaskExecutionContext outputPrefix := taskCtx.OutputWriter().GetOutputPrefixPath().String() taskCategory := admin.TaskCategory{Name: taskTemplate.GetType(), Version: taskTemplate.GetTaskTypeVersion()} - connector, isSync := p.getFinalConnector(&taskCategory, p.cfg) + connector := p.deployment.ConnectorDeployment taskExecutionMetadata := buildTaskExecutionMetadata(taskCtx.TaskExecutionMetadata()) - if isSync { + if p.deployment.IsSync { finalCtx, cancel := getFinalContext(ctx, "ExecuteTaskSync", connector) defer cancel() client, err := p.getSyncConnectorClient(ctx, connector) @@ -229,7 +202,7 @@ func (p *Plugin) ExecuteTaskSync( func (p *Plugin) Get(ctx context.Context, taskCtx webapi.GetContext) (latest webapi.Resource, err error) { metadata := taskCtx.ResourceMeta().(ResourceMetaWrapper) - connector, _ := p.getFinalConnector(&metadata.TaskCategory, p.cfg) + connector := p.deployment.ConnectorDeployment client, err := p.getAsyncConnectorClient(ctx, connector) if err != nil { @@ -264,7 +237,7 @@ func (p *Plugin) Delete(ctx context.Context, taskCtx webapi.DeleteContext) error return nil } metadata := taskCtx.ResourceMeta().(ResourceMetaWrapper) - connector, _ := p.getFinalConnector(&metadata.TaskCategory, p.cfg) + connector := p.deployment.ConnectorDeployment client, err := p.getAsyncConnectorClient(ctx, connector) if err != nil { @@ -366,25 +339,14 @@ func (p *Plugin) getAsyncConnectorClient(ctx context.Context, connector *Deploym return client, nil } -func (p *Plugin) watchConnectors(ctx context.Context, connectorService *ConnectorService) { +func WatchConnectors(ctx context.Context) { + cfg := GetConfig() go wait.Until(func() { childCtx, cancel := context.WithCancel(ctx) defer cancel() clientSet := getConnectorClientSets(childCtx) - connectorRegistry := getConnectorRegistry(childCtx, clientSet) - p.setRegistry(connectorRegistry) - connectorService.SetSupportedTaskType(maps.Keys(connectorRegistry)) - }, p.cfg.PollInterval.Duration, ctx.Done()) -} - -func (p *Plugin) getFinalConnector(taskCategory *admin.TaskCategory, cfg *Config) (*Deployment, bool) { - p.mu.RLock() - defer p.mu.RUnlock() - - if connector, exists := p.registry[taskCategory.GetName()][taskCategory.GetVersion()]; exists { - return connector.ConnectorDeployment, connector.IsSync - } - return &cfg.DefaultConnector, false + watchConnectors(ctx, clientSet) + }, cfg.PollInterval.Duration, ctx.Done()) } func writeOutput(ctx context.Context, taskCtx webapi.StatusContext, outputs *flyteIdl.LiteralMap) error { @@ -423,35 +385,43 @@ func buildTaskExecutionMetadata(taskExecutionMetadata core.TaskExecutionMetadata } } -func newConnectorPlugin(connectorService *ConnectorService) webapi.PluginEntry { - ctx := context.Background() - gob.Register(ResourceMetaWrapper{}) - gob.Register(ResourceWrapper{}) - - clientSet := getConnectorClientSets(ctx) - connectorRegistry := getConnectorRegistry(ctx, clientSet) - supportedTaskTypes := maps.Keys(connectorRegistry) - connectorService.SetSupportedTaskType(supportedTaskTypes) - +func createPluginEntry(taskType core.TaskType, deployment *Deployment, clientSet *ClientSet) webapi.PluginEntry { plugin := &Plugin{ - metricScope: promutils.NewScope("connector_plugin"), + metricScope: promutils.NewScope(taskType), cfg: GetConfig(), cs: clientSet, - registry: connectorRegistry, + deployment: Connector{IsSync: false, ConnectorDeployment: deployment}, } - plugin.watchConnectors(ctx, connectorService) - return webapi.PluginEntry{ - ID: ID, - SupportedTaskTypes: supportedTaskTypes, + ID: taskType, + SupportedTaskTypes: []core.TaskType{taskType}, PluginLoader: func(ctx context.Context, iCtx webapi.PluginSetupContext) (webapi.AsyncPlugin, error) { return plugin, nil }, } } -func RegisterConnectorPlugin(connectorService *ConnectorService) { +func newConnectorPlugins(ctx context.Context) []webapi.PluginEntry { + clientSet := getConnectorClientSets(ctx) + // Get deployments from config in order to init plugins + cfg := GetConfig() + plugins := make([]webapi.PluginEntry, 0) + for taskType, deploymentID := range cfg.ConnectorForTaskTypes { + if deployment, ok := cfg.ConnectorDeployments[deploymentID]; ok { + plugins = append(plugins, createPluginEntry(taskType, deployment, clientSet)) + } + } + return plugins +} + +func RegisterConnectorPlugin() { + ctx := context.Background() gob.Register(ResourceMetaWrapper{}) gob.Register(ResourceWrapper{}) - pluginmachinery.PluginRegistry().RegisterRemotePlugin(newConnectorPlugin(connectorService)) + plugins := newConnectorPlugins(ctx) + for _, plugin := range plugins { + // Register a remote plugin to CorePlugins for task handler to read + pluginmachinery.PluginRegistry().RegisterRemotePlugin(plugin) + } + WatchConnectors(ctx) } diff --git a/flyteplugins/go/tasks/plugins/webapi/connector/plugin_test.go b/flyteplugins/go/tasks/plugins/webapi/connector/plugin_test.go index 884e9069b1..cd344b9154 100644 --- a/flyteplugins/go/tasks/plugins/webapi/connector/plugin_test.go +++ b/flyteplugins/go/tasks/plugins/webapi/connector/plugin_test.go @@ -7,13 +7,11 @@ import ( "github.com/stretchr/testify/assert" "github.com/stretchr/testify/mock" - "golang.org/x/exp/maps" "google.golang.org/protobuf/types/known/structpb" agentMocks "github.com/flyteorg/flyte/flyteidl/clients/go/admin/mocks" "github.com/flyteorg/flyte/flyteidl/gen/pb-go/flyteidl/admin" flyteIdlCore "github.com/flyteorg/flyte/flyteidl/gen/pb-go/flyteidl/core" - "github.com/flyteorg/flyte/flyteidl/gen/pb-go/flyteidl/service" pluginErrors "github.com/flyteorg/flyte/flyteplugins/go/tasks/errors" pluginsCore "github.com/flyteorg/flyte/flyteplugins/go/tasks/pluginmachinery/core" pluginCoreMocks "github.com/flyteorg/flyte/flyteplugins/go/tasks/pluginmachinery/core/mocks" @@ -35,12 +33,11 @@ func TestPlugin(t *testing.T) { cfg.ConnectorDeployments = map[string]*Deployment{"spark_connector": {Endpoint: "localhost:80"}} cfg.ConnectorForTaskTypes = map[string]string{"spark": "spark_connector", "bar": "bar_connector"} - connector := &Connector{ConnectorDeployment: &Deployment{Endpoint: "localhost:80"}} - connectorRegistry := Registry{"spark": {defaultTaskTypeVersion: connector}} + // connector := &Connector{ConnectorDeployment: &Deployment{Endpoint: "localhost:80"}} + plugin := Plugin{ metricScope: fakeSetupContext.MetricsScope(), cfg: GetConfig(), - registry: connectorRegistry, } t.Run("get config", func(t *testing.T) { err := SetConfig(&cfg) @@ -61,17 +58,17 @@ func TestPlugin(t *testing.T) { assert.NotNil(t, p.PluginLoader) }) - t.Run("test getFinalConnector", func(t *testing.T) { - spark := &admin.TaskCategory{Name: "spark", Version: defaultTaskTypeVersion} - foo := &admin.TaskCategory{Name: "foo", Version: defaultTaskTypeVersion} - bar := &admin.TaskCategory{Name: "bar", Version: defaultTaskTypeVersion} - connectorDeployment, _ := plugin.getFinalConnector(spark, &cfg) - assert.Equal(t, connectorDeployment.Endpoint, "localhost:80") - connectorDeployment, _ = plugin.getFinalConnector(foo, &cfg) - assert.Equal(t, connectorDeployment.Endpoint, cfg.DefaultConnector.Endpoint) - connectorDeployment, _ = plugin.getFinalConnector(bar, &cfg) - assert.Equal(t, connectorDeployment.Endpoint, cfg.DefaultConnector.Endpoint) - }) + // t.Run("test getFinalConnector", func(t *testing.T) { + // spark := &admin.TaskCategory{Name: "spark", Version: defaultTaskTypeVersion} + // foo := &admin.TaskCategory{Name: "foo", Version: defaultTaskTypeVersion} + // bar := &admin.TaskCategory{Name: "bar", Version: defaultTaskTypeVersion} + // connectorDeployment, _ := plugin.getFinalConnector(spark, &cfg) + // assert.Equal(t, connectorDeployment.Endpoint, "localhost:80") + // connectorDeployment, _ = plugin.getFinalConnector(foo, &cfg) + // assert.Equal(t, connectorDeployment.Endpoint, cfg.DefaultConnector.Endpoint) + // connectorDeployment, _ = plugin.getFinalConnector(bar, &cfg) + // assert.Equal(t, connectorDeployment.Endpoint, cfg.DefaultConnector.Endpoint) + // }) t.Run("test getFinalTimeout", func(t *testing.T) { timeout := getFinalTimeout("CreateTask", &Deployment{Endpoint: "localhost:8080", Timeouts: map[string]config.Duration{"CreateTask": {Duration: 1 * time.Millisecond}}}) @@ -330,28 +327,28 @@ func getMockMetadataServiceClient() *agentMocks.AgentMetadataServiceClient { return mockMetadataServiceClient } -func TestInitializeConnectorRegistry(t *testing.T) { - connectorClients := make(map[string]service.AsyncAgentServiceClient) - connectorMetadataClients := make(map[string]service.AgentMetadataServiceClient) - connectorClients[defaultConnectorEndpoint] = &agentMocks.AsyncAgentServiceClient{} - connectorMetadataClients[defaultConnectorEndpoint] = getMockMetadataServiceClient() - - cs := &ClientSet{ - asyncConnectorClients: connectorClients, - connectorMetadataClients: connectorMetadataClients, - } - - cfg := defaultConfig - cfg.ConnectorDeployments = map[string]*Deployment{"custom_connector": {Endpoint: defaultConnectorEndpoint}} - cfg.ConnectorForTaskTypes = map[string]string{"task1": "connector-deployment-1", "task2": "connector-deployment-2"} - err := SetConfig(&cfg) - assert.NoError(t, err) - - connectorRegistry := getConnectorRegistry(context.Background(), cs) - connectorRegistryKeys := maps.Keys(connectorRegistry) - expectedKeys := []string{"task1", "task2", "task3", "task_type_3", "task_type_4"} - - for _, key := range expectedKeys { - assert.Contains(t, connectorRegistryKeys, key) - } -} +// func TestInitializeConnectorRegistry(t *testing.T) { +// connectorClients := make(map[string]service.AsyncAgentServiceClient) +// connectorMetadataClients := make(map[string]service.AgentMetadataServiceClient) +// connectorClients[defaultConnectorEndpoint] = &agentMocks.AsyncAgentServiceClient{} +// connectorMetadataClients[defaultConnectorEndpoint] = getMockMetadataServiceClient() + +// cs := &ClientSet{ +// asyncConnectorClients: connectorClients, +// connectorMetadataClients: connectorMetadataClients, +// } + +// cfg := defaultConfig +// cfg.ConnectorDeployments = map[string]*Deployment{"custom_connector": {Endpoint: defaultConnectorEndpoint}} +// cfg.ConnectorForTaskTypes = map[string]string{"task1": "connector-deployment-1", "task2": "connector-deployment-2"} +// err := SetConfig(&cfg) +// assert.NoError(t, err) + +// connectorRegistry := getConnectorRegistry(context.Background(), cs) +// connectorRegistryKeys := maps.Keys(connectorRegistry) +// expectedKeys := []string{"task1", "task2", "task3", "task_type_3", "task_type_4"} + +// for _, key := range expectedKeys { +// assert.Contains(t, connectorRegistryKeys, key) +// } +// } diff --git a/flytepropeller/pkg/controller/nodes/task/handler.go b/flytepropeller/pkg/controller/nodes/task/handler.go index b2f9476f69..cf799f8d6c 100644 --- a/flytepropeller/pkg/controller/nodes/task/handler.go +++ b/flytepropeller/pkg/controller/nodes/task/handler.go @@ -251,7 +251,6 @@ type Handler struct { eventConfig *controllerConfig.EventConfig clusterID string agentService *agent.AgentService - connectorService *connector.ConnectorService } func (t *Handler) FinalizeRequired() bool { @@ -268,6 +267,40 @@ func (t *Handler) setDefault(ctx context.Context, p pluginCore.Plugin) error { return nil } +func (t *Handler) watchPlugins(ctx context.Context, sCtx interfaces.SetupContext) error { + for { + select { + case wpe := <-pluginMachinery.PluginRegistry().GetPluginRegistrationChan(): + tSCtx := t.newSetupContext(sCtx) + // Create a new base resource negotiator + resourceManagerConfig := rmConfig.GetConfig() + newResourceManagerBuilder, err := resourcemanager.GetResourceManagerBuilderByType(ctx, resourceManagerConfig.Type, t.metrics.scope) + if err != nil { + return err + } + pluginResourceNamespacePrefix := pluginCore.ResourceNamespace(newResourceManagerBuilder.GetID()).CreateSubNamespace(pluginCore.ResourceNamespace(wpe.ID)) + sCtxFinal := newNameSpacedSetupCtx( + tSCtx, newResourceManagerBuilder.GetResourceRegistrar(pluginResourceNamespacePrefix), wpe.ID) + logger.Infof(ctx, "Loading Plugin [%s] ENABLED", wpe.ID) + // register core plugin + for _, cpe := range pluginMachinery.PluginRegistry().GetCorePlugins() { + if cpe.ID == wpe.ID { + cp, err := pluginCore.LoadPlugin(ctx, sCtxFinal, cpe) + if err != nil { + return regErrors.Wrapf(err, "failed to load plugin - %s", wpe.ID) + } + // register the plugin to task handler local plugin registry + t.defaultPlugins[cp.GetID()] = cp + pluginMachinery.PluginRegistry().AddRegisteredTaskType(cp.GetID()) + } + } + case <-ctx.Done(): + logger.Infof(ctx, "Plugin watcher stopped due to context cancellation") + return nil + } + } +} + func (t *Handler) Setup(ctx context.Context, sCtx interfaces.SetupContext) error { tSCtx := t.newSetupContext(sCtx) @@ -281,7 +314,7 @@ func (t *Handler) Setup(ctx context.Context, sCtx interfaces.SetupContext) error once.Do(func() { // The agent service plugin is deprecated and will be removed in the future agent.RegisterAgentPlugin(t.agentService) - connector.RegisterConnectorPlugin(t.connectorService) + connector.RegisterConnectorPlugin() }) // Create the resource negotiator here @@ -308,10 +341,6 @@ func (t *Handler) Setup(ctx context.Context, sCtx interfaces.SetupContext) error return regErrors.Wrapf(err, "failed to load plugin - %s", p.ID) } - if cp.GetID() == connector.ID { - t.connectorService.CorePlugin = cp - } - // The agent service plugin is deprecated and will be removed in the future if cp.GetID() == agent.ID { t.agentService.CorePlugin = cp @@ -368,6 +397,9 @@ func (t *Handler) Setup(ctx context.Context, sCtx interfaces.SetupContext) error } } + // execute watcher to monitor plugins waiting for register from connector/agent plugin + go t.watchPlugins(ctx, sCtx) + rm, err := newResourceManagerBuilder.BuildResourceManager(ctx) if err != nil { logger.Errorf(ctx, "Failed to build a resource manager") @@ -406,10 +438,6 @@ func (t Handler) ResolvePlugin(ctx context.Context, ttype string, executionConfi return p, nil } - if t.connectorService != nil && t.connectorService.ContainTaskType(ttype) { - return t.connectorService.CorePlugin, nil - } - // The agent service plugin is deprecated and will be removed in the future if t.agentService != nil && t.agentService.ContainTaskType(ttype) { return t.agentService.CorePlugin, nil @@ -1052,6 +1080,5 @@ func New(ctx context.Context, kubeClient executors.Client, kubeClientset kuberne eventConfig: eventConfig, clusterID: clusterID, agentService: &agent.AgentService{}, - connectorService: &connector.ConnectorService{}, }, nil } diff --git a/flytepropeller/pkg/controller/nodes/task/handler_test.go b/flytepropeller/pkg/controller/nodes/task/handler_test.go index 8f51627fb8..f437ca52c9 100644 --- a/flytepropeller/pkg/controller/nodes/task/handler_test.go +++ b/flytepropeller/pkg/controller/nodes/task/handler_test.go @@ -29,7 +29,6 @@ import ( pluginK8s "github.com/flyteorg/flyte/flyteplugins/go/tasks/pluginmachinery/k8s" pluginK8sMocks "github.com/flyteorg/flyte/flyteplugins/go/tasks/pluginmachinery/k8s/mocks" "github.com/flyteorg/flyte/flyteplugins/go/tasks/plugins/webapi/agent" - "github.com/flyteorg/flyte/flyteplugins/go/tasks/plugins/webapi/connector" eventsErr "github.com/flyteorg/flyte/flytepropeller/events/errors" "github.com/flyteorg/flyte/flytepropeller/pkg/apis/flyteworkflow/v1alpha1" flyteMocks "github.com/flyteorg/flyte/flytepropeller/pkg/apis/flyteworkflow/v1alpha1/mocks" @@ -353,7 +352,6 @@ func Test_task_ResolvePlugin(t *testing.T) { defaultPlugin: tt.fields.defaultPlugin, pluginsForType: tt.fields.pluginsForType, agentService: &agent.AgentService{}, - connectorService: &connector.ConnectorService{}, } got, err := tk.ResolvePlugin(context.TODO(), tt.args.ttype, tt.args.executionConfig) if (err != nil) != tt.wantErr { From 2d7e090a05ccf94b831c01a2292a42b5292e0edd Mon Sep 17 00:00:00 2001 From: Alex Wu Date: Thu, 12 Jun 2025 17:57:55 +0800 Subject: [PATCH 2/5] fix errors Signed-off-by: Alex Wu --- flyte-single-binary-local.yaml | 42 +++++++++++---- .../go/tasks/pluginmachinery/registry.go | 2 - .../tasks/plugins/webapi/connector/client.go | 2 +- .../tasks/plugins/webapi/connector/plugin.go | 51 +++++++++++++++++-- .../pkg/controller/nodes/task/handler.go | 4 ++ 5 files changed, 84 insertions(+), 17 deletions(-) diff --git a/flyte-single-binary-local.yaml b/flyte-single-binary-local.yaml index 7d197a823c..04ba39f121 100644 --- a/flyte-single-binary-local.yaml +++ b/flyte-single-binary-local.yaml @@ -43,8 +43,9 @@ tasks: - container - sidecar - K8S-ARRAY - - connector-service - echo + - agent-service + # - noop_async_agent_task default-for-task-types: - container: container - container_array: K8S-ARRAY @@ -68,13 +69,37 @@ plugins: kubernetes-template-uri: http://localhost:30080/kubernetes-dashboard/#/log/{{.namespace }}/{{ .podName }}/pod?namespace={{ .namespace }} cloudwatch-enabled: false stackdriver-enabled: false - connector-service: + # connector-service: + # supportedTaskTypes: + # - bigquery_query_job_task + # - sensor + # - chatgpt + # connectors: + # my-test-connector1: + # endpoint: "localhost:8000" + # insecure: true + # timeouts: + # CreateTask: 5s + # GetTask: 30s + # DeleteTask: 30s + # defaultTimeout: 30s + # webApi: + # caching: + # resyncInterval: 60s + # defaultConnector: + # endpoint: "localhost:8000" + # webApi: + # caching: + # resyncInterval: 120s + # connectorForTaskTypes: + # - noop_async_agent_task: my-test-connector1 + agent-service: supportedTaskTypes: - bigquery_query_job_task - sensor - chatgpt - connectors: - my-test-connector: + agents: + my-test-connector1: endpoint: "localhost:8000" insecure: true timeouts: @@ -82,16 +107,13 @@ plugins: GetTask: 30s DeleteTask: 30s defaultTimeout: 30s - webApi: - caching: - resyncInterval: 60s - defaultConnector: + defaultAgent: endpoint: "localhost:8000" webApi: caching: resyncInterval: 120s - connectorForTaskTypes: - - my_test_task1: my-test-connector + agentForTaskTypes: + - noop_async_agent_task: my-test-connector1 database: postgres: diff --git a/flyteplugins/go/tasks/pluginmachinery/registry.go b/flyteplugins/go/tasks/pluginmachinery/registry.go index 3a5afe9380..46c43a16a3 100644 --- a/flyteplugins/go/tasks/pluginmachinery/registry.go +++ b/flyteplugins/go/tasks/pluginmachinery/registry.go @@ -64,8 +64,6 @@ func (p *taskPluginRegistry) IsTaskTypeRegistered(taskType string) bool { // RegisterTaskType registers a single task type func (p *taskPluginRegistry) AddRegisteredTaskType(taskType string) { - p.m.Lock() - defer p.m.Unlock() if p.registeredTaskTypes == nil { p.registeredTaskTypes = make(map[string]struct{}) } diff --git a/flyteplugins/go/tasks/plugins/webapi/connector/client.go b/flyteplugins/go/tasks/plugins/webapi/connector/client.go index fae7b1ea2d..aef8590d11 100644 --- a/flyteplugins/go/tasks/plugins/webapi/connector/client.go +++ b/flyteplugins/go/tasks/plugins/webapi/connector/client.go @@ -95,10 +95,10 @@ func watchConnectors(ctx context.Context, cs *ClientSet) { cfg := GetConfig() var connectorDeployments []*Deployment + connectorDeployments = append(connectorDeployments, maps.Values(cfg.ConnectorDeployments)...) if len(cfg.DefaultConnector.Endpoint) != 0 { connectorDeployments = append(connectorDeployments, &cfg.DefaultConnector) } - connectorDeployments = append(connectorDeployments, maps.Values(cfg.ConnectorDeployments)...) for _, connectorDeployment := range connectorDeployments { client, ok := cs.connectorMetadataClients[connectorDeployment.Endpoint] if !ok { diff --git a/flyteplugins/go/tasks/plugins/webapi/connector/plugin.go b/flyteplugins/go/tasks/plugins/webapi/connector/plugin.go index 17ee09101a..5eefab8e3a 100644 --- a/flyteplugins/go/tasks/plugins/webapi/connector/plugin.go +++ b/flyteplugins/go/tasks/plugins/webapi/connector/plugin.go @@ -4,7 +4,6 @@ import ( "context" "encoding/gob" "fmt" - "reflect" "sync" "time" @@ -56,11 +55,55 @@ type ResourceMetaWrapper struct { } func (p *Plugin) GetConfig() webapi.PluginConfig { - if p.deployment.ConnectorDeployment != nil && !reflect.DeepEqual(p.deployment.ConnectorDeployment.WebAPI, webapi.PluginConfig{}) { - return p.deployment.ConnectorDeployment.WebAPI - } else { + // Return default config if deployment is nil + if p.deployment.ConnectorDeployment == nil { return p.cfg.WebAPI } + + // Create a new config object by copying deployment's config + config := p.deployment.ConnectorDeployment.WebAPI + + // 1. Check if ResourceQuotas is nil + if config.ResourceQuotas == nil { + config.ResourceQuotas = p.cfg.WebAPI.ResourceQuotas + } + + // 2. Check ReadRateLimiter values individually + if config.ReadRateLimiter.QPS == 0 { + config.ReadRateLimiter.QPS = p.cfg.WebAPI.ReadRateLimiter.QPS + } + if config.ReadRateLimiter.Burst == 0 { + config.ReadRateLimiter.Burst = p.cfg.WebAPI.ReadRateLimiter.Burst + } + + // 3. Check WriteRateLimiter values individually + if config.WriteRateLimiter.QPS == 0 { + config.WriteRateLimiter.QPS = p.cfg.WebAPI.WriteRateLimiter.QPS + } + if config.WriteRateLimiter.Burst == 0 { + config.WriteRateLimiter.Burst = p.cfg.WebAPI.WriteRateLimiter.Burst + } + + // 4. Check Caching configuration values individually + if config.Caching.ResyncInterval.Duration == time.Duration(0) { + config.Caching.ResyncInterval = p.cfg.WebAPI.Caching.ResyncInterval + } + if config.Caching.Size == 0 { + config.Caching.Size = p.cfg.WebAPI.Caching.Size + } + if config.Caching.Workers == 0 { + config.Caching.Workers = p.cfg.WebAPI.Caching.Workers + } + if config.Caching.MaxSystemFailures == 0 { + config.Caching.MaxSystemFailures = p.cfg.WebAPI.Caching.MaxSystemFailures + } + + // 5. Check if ResourceMeta is nil + if config.ResourceMeta == nil { + config.ResourceMeta = p.cfg.WebAPI.ResourceMeta + } + + return config } func (p *Plugin) ResourceRequirements(_ context.Context, _ webapi.TaskExecutionContextReader) ( diff --git a/flytepropeller/pkg/controller/nodes/task/handler.go b/flytepropeller/pkg/controller/nodes/task/handler.go index cf799f8d6c..d5620c6c52 100644 --- a/flytepropeller/pkg/controller/nodes/task/handler.go +++ b/flytepropeller/pkg/controller/nodes/task/handler.go @@ -2,6 +2,7 @@ package task import ( "context" + "encoding/json" "fmt" "runtime/debug" "time" @@ -675,6 +676,9 @@ func (t Handler) Handle(ctx context.Context, nCtx interfaces.NodeExecutionContex ttype := nCtx.TaskReader().GetTaskType() ctx = contextutils.WithTaskType(ctx, ttype) p, err := t.ResolvePlugin(ctx, ttype, nCtx.ExecutionContext().GetExecutionConfig()) + jsonBytes, _ := json.MarshalIndent(p, "", " ") + logger.Debug(ctx, "The task type is [%s]", string(ttype)) + logger.Debug(ctx, "The deployment config is [%s]", string(jsonBytes)) if err != nil { return handler.UnknownTransition, errors.Wrapf(errors.UnsupportedTaskTypeError, nCtx.NodeID(), err, "unable to resolve plugin") } From 2a6cbd9164a8656c127d41a2dcbd81eb462ac5b9 Mon Sep 17 00:00:00 2001 From: Alex Wu Date: Sat, 26 Jul 2025 14:23:12 +0800 Subject: [PATCH 3/5] add feature Signed-off-by: Alex Wu --- flyte-single-binary-local.yaml | 71 +++++----- flyteplugins/go.mod | 2 +- .../pluginmachinery/internal/webapi/cache.go | 4 +- .../pluginmachinery/internal/webapi/core.go | 4 +- .../go/tasks/pluginmachinery/registry.go | 64 ++++++--- .../tasks/plugins/webapi/connector/client.go | 77 ++++++++--- .../tasks/plugins/webapi/connector/plugin.go | 51 ++++--- .../pkg/controller/nodes/task/cache.go | 24 ++-- .../pkg/controller/nodes/task/handler.go | 129 ++++++++++++++---- .../pkg/controller/nodes/task/handler_test.go | 2 +- flytestdlib/cache/in_memory_auto_refresh.go | 2 + 11 files changed, 301 insertions(+), 129 deletions(-) diff --git a/flyte-single-binary-local.yaml b/flyte-single-binary-local.yaml index 04ba39f121..c3233da8ab 100644 --- a/flyte-single-binary-local.yaml +++ b/flyte-single-binary-local.yaml @@ -45,7 +45,6 @@ tasks: - K8S-ARRAY - echo - agent-service - # - noop_async_agent_task default-for-task-types: - container: container - container_array: K8S-ARRAY @@ -69,51 +68,61 @@ plugins: kubernetes-template-uri: http://localhost:30080/kubernetes-dashboard/#/log/{{.namespace }}/{{ .podName }}/pod?namespace={{ .namespace }} cloudwatch-enabled: false stackdriver-enabled: false - # connector-service: - # supportedTaskTypes: - # - bigquery_query_job_task - # - sensor - # - chatgpt - # connectors: - # my-test-connector1: - # endpoint: "localhost:8000" - # insecure: true - # timeouts: - # CreateTask: 5s - # GetTask: 30s - # DeleteTask: 30s - # defaultTimeout: 30s - # webApi: - # caching: - # resyncInterval: 60s - # defaultConnector: - # endpoint: "localhost:8000" - # webApi: - # caching: - # resyncInterval: 120s - # connectorForTaskTypes: - # - noop_async_agent_task: my-test-connector1 - agent-service: + connector-service: supportedTaskTypes: - bigquery_query_job_task - sensor - - chatgpt - agents: + connectors: my-test-connector1: - endpoint: "localhost:8000" + endpoint: "localhost:8001" + insecure: true + timeouts: + CreateTask: 5s + GetTask: 30s + DeleteTask: 30s + defaultTimeout: 30s + webApi: + caching: + resyncInterval: 60s + my-test-connector2: + endpoint: "localhost:8002" insecure: true timeouts: CreateTask: 5s GetTask: 30s DeleteTask: 30s defaultTimeout: 30s - defaultAgent: + webApi: + caching: + resyncInterval: 120s + defaultConnector: endpoint: "localhost:8000" webApi: caching: resyncInterval: 120s - agentForTaskTypes: + connectorForTaskTypes: - noop_async_agent_task: my-test-connector1 + # agent-service: + # supportedTaskTypes: + # - bigquery_query_job_task + # - sensor + # - chatgpt + # agents: + # my-test-connector1: + # endpoint: "localhost:8000" + # insecure: true + # timeouts: + # CreateTask: 5s + # GetTask: 30s + # DeleteTask: 30s + # defaultTimeout: 30s + # defaultAgent: + # endpoint: "localhost:8000" + # webApi: + # caching: + # resyncInterval: 120s + # agentForTaskTypes: + # - noop_async_agent_task: my-test-connector1 database: postgres: diff --git a/flyteplugins/go.mod b/flyteplugins/go.mod index 43364f5f44..7c11af2505 100644 --- a/flyteplugins/go.mod +++ b/flyteplugins/go.mod @@ -15,6 +15,7 @@ require ( github.com/flyteorg/flyte/flytestdlib v0.0.0-00010101000000-000000000000 github.com/go-test/deep v1.0.7 github.com/golang/protobuf v1.5.4 + github.com/google/uuid v1.6.0 github.com/hashicorp/golang-lru v0.5.4 github.com/imdario/mergo v0.3.13 github.com/kubeflow/training-operator v1.8.0 @@ -85,7 +86,6 @@ require ( github.com/google/go-cmp v0.6.0 // indirect github.com/google/gofuzz v1.2.0 // indirect github.com/google/s2a-go v0.1.7 // indirect - github.com/google/uuid v1.6.0 // indirect github.com/googleapis/enterprise-certificate-proxy v0.3.2 // indirect github.com/googleapis/gax-go/v2 v2.12.0 // indirect github.com/grpc-ecosystem/grpc-gateway/v2 v2.19.0 // indirect diff --git a/flyteplugins/go/tasks/pluginmachinery/internal/webapi/cache.go b/flyteplugins/go/tasks/pluginmachinery/internal/webapi/cache.go index 956848ed7a..e2a4ac4cdf 100644 --- a/flyteplugins/go/tasks/pluginmachinery/internal/webapi/cache.go +++ b/flyteplugins/go/tasks/pluginmachinery/internal/webapi/cache.go @@ -2,6 +2,7 @@ package webapi import ( "context" + "fmt" "time" "golang.org/x/time/rate" @@ -175,12 +176,13 @@ func NewResourceCache(ctx context.Context, name string, client Client, cfg webap cfg: cfg, } + scopeName := fmt.Sprintf("cache_%s", name) autoRefreshCache, err := cache.NewAutoRefreshCache(name, q.SyncResource, workqueue.NewMaxOfRateLimiter( workqueue.NewItemExponentialFailureRateLimiter(5*time.Millisecond, 1000*time.Second), &workqueue.BucketRateLimiter{Limiter: rate.NewLimiter(rate.Limit(rateCfg.QPS), rateCfg.Burst)}, ), cfg.ResyncInterval.Duration, uint(cfg.Workers), uint(cfg.Size), // #nosec G115 - scope.NewSubScope("cache")) + scope.NewSubScope(scopeName)) if err != nil { logger.Errorf(ctx, "Could not create AutoRefreshCache. Error: [%s]", err) diff --git a/flyteplugins/go/tasks/pluginmachinery/internal/webapi/core.go b/flyteplugins/go/tasks/pluginmachinery/internal/webapi/core.go index 4ba7ffea85..b986ba78d1 100644 --- a/flyteplugins/go/tasks/pluginmachinery/internal/webapi/core.go +++ b/flyteplugins/go/tasks/pluginmachinery/internal/webapi/core.go @@ -195,9 +195,9 @@ func createRemotePlugin(pluginEntry webapi.PluginEntry, c clock.Clock) core.Plug } } } - + scopeName := fmt.Sprintf("cache_%s", pluginEntry.ID) resourceCache, err := NewResourceCache(ctx, pluginEntry.ID, p, p.GetConfig().Caching, - p.GetConfig().ReadRateLimiter, iCtx.MetricsScope().NewSubScope("cache")) + p.GetConfig().ReadRateLimiter, iCtx.MetricsScope().NewSubScope(scopeName)) if err != nil { return nil, err diff --git a/flyteplugins/go/tasks/pluginmachinery/registry.go b/flyteplugins/go/tasks/pluginmachinery/registry.go index 46c43a16a3..b48fc5766e 100644 --- a/flyteplugins/go/tasks/pluginmachinery/registry.go +++ b/flyteplugins/go/tasks/pluginmachinery/registry.go @@ -11,19 +11,33 @@ import ( "github.com/flyteorg/flyte/flytestdlib/logger" ) -const defaultPluginBufferSize = 5 +const defaultPluginBufferSize = 100 + +// PluginRegistrationInfo contains information about plugin registration +type PluginRegistrationInfo struct { + Plugin webapi.PluginEntry + DeploymentID string +} + +type PluginUpdateInfo struct { + TaskType string + DeploymentID string +} type taskPluginRegistry struct { m sync.Mutex k8sPlugin []k8s.PluginEntry corePlugin []core.PluginEntry - pluginRegistrationChan chan webapi.PluginEntry - registeredTaskTypes map[string]struct{} + pluginRegistrationChan chan PluginRegistrationInfo + pluginUpdateChan chan PluginUpdateInfo + registeredPlugins map[string]map[string]struct{} } // A singleton variable that maintains a registry of all plugins. The framework uses this to access all plugins var pluginRegistry = &taskPluginRegistry{ - pluginRegistrationChan: make(chan webapi.PluginEntry, defaultPluginBufferSize), + pluginRegistrationChan: make(chan PluginRegistrationInfo, defaultPluginBufferSize), + pluginUpdateChan: make(chan PluginUpdateInfo, defaultPluginBufferSize), + registeredPlugins: make(map[string]map[string]struct{}), } func PluginRegistry() TaskPluginRegistry { @@ -47,27 +61,44 @@ func (p *taskPluginRegistry) RegisterRemotePlugin(info webapi.PluginEntry) { p.m.Lock() defer p.m.Unlock() p.corePlugin = append(p.corePlugin, internalRemote.CreateRemotePlugin(info)) - p.AddRegisteredTaskType(info.ID) } -func (p *taskPluginRegistry) GetPluginRegistrationChan() chan webapi.PluginEntry { +func (p *taskPluginRegistry) GetPluginRegistrationChan() chan PluginRegistrationInfo { return p.pluginRegistrationChan } -// IsTaskTypeRegistered checks if a task type is registered -func (p *taskPluginRegistry) IsTaskTypeRegistered(taskType string) bool { +func (p *taskPluginRegistry) GetPluginUpdateChan() chan PluginUpdateInfo { + return p.pluginUpdateChan +} + +// IsPluginForTaskTypeRegistered checks if a task type is registered +func (p *taskPluginRegistry) IsPluginForTaskTypeRegistered(taskType string, deploymentID string) bool { p.m.Lock() defer p.m.Unlock() - _, exists := p.registeredTaskTypes[taskType] + + if p.registeredPlugins == nil { + return false + } + + deploymentMap, exists := p.registeredPlugins[taskType] + if !exists { + return false + } + + _, exists = deploymentMap[deploymentID] return exists } // RegisterTaskType registers a single task type -func (p *taskPluginRegistry) AddRegisteredTaskType(taskType string) { - if p.registeredTaskTypes == nil { - p.registeredTaskTypes = make(map[string]struct{}) +func (p *taskPluginRegistry) AddRegisteredPluginForTaskType(taskType string, deploymentID string) { + p.m.Lock() + defer p.m.Unlock() + + if p.registeredPlugins[taskType] == nil { + p.registeredPlugins[taskType] = make(map[string]struct{}) } - p.registeredTaskTypes[taskType] = struct{}{} + + p.registeredPlugins[taskType][deploymentID] = struct{}{} } func CreateRemotePlugin(pluginEntry webapi.PluginEntry) core.PluginEntry { @@ -134,7 +165,8 @@ type TaskPluginRegistry interface { RegisterRemotePlugin(info webapi.PluginEntry) GetCorePlugins() []core.PluginEntry GetK8sPlugins() []k8s.PluginEntry - GetPluginRegistrationChan() chan webapi.PluginEntry - IsTaskTypeRegistered(taskType string) bool - AddRegisteredTaskType(taskType string) + GetPluginRegistrationChan() chan PluginRegistrationInfo + GetPluginUpdateChan() chan PluginUpdateInfo + IsPluginForTaskTypeRegistered(taskType string, deploymentID string) bool + AddRegisteredPluginForTaskType(taskType string, deploymentID string) } diff --git a/flyteplugins/go/tasks/plugins/webapi/connector/client.go b/flyteplugins/go/tasks/plugins/webapi/connector/client.go index aef8590d11..22ee9e8144 100644 --- a/flyteplugins/go/tasks/plugins/webapi/connector/client.go +++ b/flyteplugins/go/tasks/plugins/webapi/connector/client.go @@ -3,8 +3,9 @@ package connector import ( "context" "crypto/x509" + "fmt" + "strings" - "golang.org/x/exp/maps" "google.golang.org/grpc" "google.golang.org/grpc/codes" "google.golang.org/grpc/credentials" @@ -20,6 +21,7 @@ import ( ) const defaultTaskTypeVersion = 0 +const defaultDeploymentID = "default" type Connector struct { // IsSync indicates whether this connector is a sync connector. Sync connectors are expected to return their @@ -91,15 +93,33 @@ func getFinalContext(ctx context.Context, operation string, connector *Deploymen return context.WithTimeout(ctx, timeout) } +// processTaskType handles the registration or update of a task type plugin +func processTaskType(ctx context.Context, taskName string, taskVersion int32, deploymentID string, connectorDeployment *Deployment, cs *ClientSet) string { + versionedTaskType := fmt.Sprintf("%s_%d", taskName, taskVersion) + + // Register default version if not registered + if !pluginmachinery.PluginRegistry().IsPluginForTaskTypeRegistered(versionedTaskType, deploymentID) { + registerNewPlugin(taskName, taskVersion, deploymentID, connectorDeployment, cs) + } else { + updatePlugin(versionedTaskType, deploymentID) + } + return versionedTaskType +} + func watchConnectors(ctx context.Context, cs *ClientSet) { cfg := GetConfig() - var connectorDeployments []*Deployment + connectorDeployments := make(map[string]*Deployment) - connectorDeployments = append(connectorDeployments, maps.Values(cfg.ConnectorDeployments)...) + // Merge ConnectorDeployments + for key, deployment := range cfg.ConnectorDeployments { + connectorDeployments[key] = deployment + } + + // Merge DefaultConnector (if endpoint is not empty) if len(cfg.DefaultConnector.Endpoint) != 0 { - connectorDeployments = append(connectorDeployments, &cfg.DefaultConnector) + connectorDeployments[defaultDeploymentID] = &cfg.DefaultConnector } - for _, connectorDeployment := range connectorDeployments { + for deploymentID, connectorDeployment := range connectorDeployments { client, ok := cs.connectorMetadataClients[connectorDeployment.Endpoint] if !ok { logger.Warningf(ctx, "Connector client not found in the clientSet for the endpoint: %v", connectorDeployment.Endpoint) @@ -126,25 +146,41 @@ func watchConnectors(ctx context.Context, cs *ClientSet) { logger.Errorf(finalCtx, "failed to list connector: [%v] with error: [%v]", connectorDeployment.Endpoint, err) continue } - // connectorSupportedTaskCategories := make(map[string]struct{}) + // If a connector's support task type plugin was not registered yet, we should do registration + connectorSupportedTaskCategories := make(map[string]struct{}) for _, connector := range res.GetAgents() { deprecatedSupportedTaskTypes := connector.GetSupportedTaskTypes() + supportedTaskCategories := connector.GetSupportedTaskCategories() + // Process deprecated supported task types for _, supportedTaskType := range deprecatedSupportedTaskTypes { - if ok := pluginmachinery.PluginRegistry().IsTaskTypeRegistered(supportedTaskType); !ok { - plugin := createPluginEntry(supportedTaskType, connectorDeployment, cs) - pluginmachinery.PluginRegistry().RegisterRemotePlugin(plugin) - pluginmachinery.PluginRegistry().GetPluginRegistrationChan() <- plugin - } + versionedTaskType := processTaskType(ctx, supportedTaskType, defaultTaskTypeVersion, deploymentID, connectorDeployment, cs) + connectorSupportedTaskCategories[versionedTaskType] = struct{}{} + } + // Process supported task categories + for _, supportedCategory := range supportedTaskCategories { + versionedTaskType := processTaskType(ctx, supportedCategory.Name, supportedCategory.Version, deploymentID, connectorDeployment, cs) + connectorSupportedTaskCategories[versionedTaskType] = struct{}{} } } + keys := make([]string, 0, len(connectorSupportedTaskCategories)) + for k := range connectorSupportedTaskCategories { + keys = append(keys, k) + } + logger.Infof(ctx, "ConnectorDeployment [%v] supports the following task types: [%v]", connectorDeployment.Endpoint, + strings.Join(keys, ", ")) + } + // always overwrite with connectorForTaskTypes config + for taskType, connectorDeploymentID := range cfg.ConnectorForTaskTypes { + if deployment, ok := cfg.ConnectorDeployments[connectorDeploymentID]; ok { + processTaskType(ctx, taskType, defaultTaskTypeVersion, connectorDeploymentID, deployment, cs) + } } // Ensure that the old configuration is backward compatible for _, taskType := range cfg.SupportedTaskTypes { - if ok := pluginmachinery.PluginRegistry().IsTaskTypeRegistered(taskType); !ok { - plugin := createPluginEntry(taskType, &cfg.DefaultConnector, cs) - pluginmachinery.PluginRegistry().RegisterRemotePlugin(plugin) - pluginmachinery.PluginRegistry().GetPluginRegistrationChan() <- plugin + versionedTaskType := fmt.Sprintf("%s_%d", taskType, defaultTaskTypeVersion) + if ok := pluginmachinery.PluginRegistry().IsPluginForTaskTypeRegistered(versionedTaskType, defaultDeploymentID); !ok { + processTaskType(ctx, taskType, defaultTaskTypeVersion, defaultDeploymentID, &cfg.DefaultConnector, cs) } } } @@ -156,13 +192,18 @@ func getConnectorClientSets(ctx context.Context) *ClientSet { connectorMetadataClients: make(map[string]service.AgentMetadataServiceClient), } - var connectorDeployments []*Deployment + connectorDeployments := make(map[string]*Deployment) cfg := GetConfig() + // Merge ConnectorDeployments + for key, deployment := range cfg.ConnectorDeployments { + connectorDeployments[key] = deployment + } + + // Merge DefaultConnector (if endpoint is not empty) if len(cfg.DefaultConnector.Endpoint) != 0 { - connectorDeployments = append(connectorDeployments, &cfg.DefaultConnector) + connectorDeployments["default"] = &cfg.DefaultConnector } - connectorDeployments = append(connectorDeployments, maps.Values(cfg.ConnectorDeployments)...) for _, connectorDeployment := range connectorDeployments { if _, ok := clientSet.connectorMetadataClients[connectorDeployment.Endpoint]; ok { logger.Infof(ctx, "Connector client already initialized for [%v]", connectorDeployment.Endpoint) diff --git a/flyteplugins/go/tasks/plugins/webapi/connector/plugin.go b/flyteplugins/go/tasks/plugins/webapi/connector/plugin.go index 5eefab8e3a..63ef677648 100644 --- a/flyteplugins/go/tasks/plugins/webapi/connector/plugin.go +++ b/flyteplugins/go/tasks/plugins/webapi/connector/plugin.go @@ -382,6 +382,14 @@ func (p *Plugin) getAsyncConnectorClient(ctx context.Context, connector *Deploym return client, nil } +// UpdateDeployment updates the deployment configuration for the plugin. +// This method is thread-safe and can be called concurrently. +func (p *Plugin) UpdateDeployment(deployment Connector) { + p.mu.Lock() + defer p.mu.Unlock() + p.deployment = deployment +} + func WatchConnectors(ctx context.Context) { cfg := GetConfig() go wait.Until(func() { @@ -428,15 +436,16 @@ func buildTaskExecutionMetadata(taskExecutionMetadata core.TaskExecutionMetadata } } -func createPluginEntry(taskType core.TaskType, deployment *Deployment, clientSet *ClientSet) webapi.PluginEntry { +func createPluginEntry(taskType core.TaskType, taskVersion int32, deployment *Deployment, clientSet *ClientSet) webapi.PluginEntry { + versionedTaskType := fmt.Sprintf("%s_%d", taskType, taskVersion) plugin := &Plugin{ - metricScope: promutils.NewScope(taskType), + metricScope: promutils.NewScope(versionedTaskType), cfg: GetConfig(), cs: clientSet, deployment: Connector{IsSync: false, ConnectorDeployment: deployment}, } return webapi.PluginEntry{ - ID: taskType, + ID: versionedTaskType, SupportedTaskTypes: []core.TaskType{taskType}, PluginLoader: func(ctx context.Context, iCtx webapi.PluginSetupContext) (webapi.AsyncPlugin, error) { return plugin, nil @@ -444,27 +453,33 @@ func createPluginEntry(taskType core.TaskType, deployment *Deployment, clientSet } } -func newConnectorPlugins(ctx context.Context) []webapi.PluginEntry { - clientSet := getConnectorClientSets(ctx) - // Get deployments from config in order to init plugins - cfg := GetConfig() - plugins := make([]webapi.PluginEntry, 0) - for taskType, deploymentID := range cfg.ConnectorForTaskTypes { - if deployment, ok := cfg.ConnectorDeployments[deploymentID]; ok { - plugins = append(plugins, createPluginEntry(taskType, deployment, clientSet)) - } +func updatePlugin(versionedTaskType string, deploymentID string) { + select { + case pluginmachinery.PluginRegistry().GetPluginUpdateChan() <- pluginmachinery.PluginUpdateInfo{ + TaskType: versionedTaskType, + DeploymentID: deploymentID, + }: + default: + logger.Errorf(context.Background(), "Failed to update plugin for task type %s: channel is full", versionedTaskType) + } +} + +func registerNewPlugin(taskType core.TaskType, taskTypeVersion int32, deploymentID string, deployment *Deployment, cs *ClientSet) { + plugin := createPluginEntry(taskType, taskTypeVersion, deployment, cs) + pluginmachinery.PluginRegistry().RegisterRemotePlugin(plugin) + select { + case pluginmachinery.PluginRegistry().GetPluginRegistrationChan() <- pluginmachinery.PluginRegistrationInfo{ + Plugin: plugin, + DeploymentID: deploymentID, + }: + default: + logger.Errorf(context.Background(), "Failed to register plugin %s: channel is full", plugin.ID) } - return plugins } func RegisterConnectorPlugin() { ctx := context.Background() gob.Register(ResourceMetaWrapper{}) gob.Register(ResourceWrapper{}) - plugins := newConnectorPlugins(ctx) - for _, plugin := range plugins { - // Register a remote plugin to CorePlugins for task handler to read - pluginmachinery.PluginRegistry().RegisterRemotePlugin(plugin) - } WatchConnectors(ctx) } diff --git a/flytepropeller/pkg/controller/nodes/task/cache.go b/flytepropeller/pkg/controller/nodes/task/cache.go index d408a5af85..43e187f391 100644 --- a/flytepropeller/pkg/controller/nodes/task/cache.go +++ b/flytepropeller/pkg/controller/nodes/task/cache.go @@ -37,18 +37,6 @@ func (t *Handler) GetCatalogKey(ctx context.Context, nCtx interfaces.NodeExecuti func (t *Handler) IsCacheable(ctx context.Context, nCtx interfaces.NodeExecutionContext) (bool, bool, error) { // check if plugin has caching disabled ttype := nCtx.TaskReader().GetTaskType() - ctx = contextutils.WithTaskType(ctx, ttype) - p, err := t.ResolvePlugin(ctx, ttype, nCtx.ExecutionContext().GetExecutionConfig()) - if err != nil { - return false, false, errors2.Wrapf(errors2.UnsupportedTaskTypeError, nCtx.NodeID(), err, "unable to resolve plugin") - } - - checkCatalog := !p.GetProperties().DisableNodeLevelCaching - if !checkCatalog { - logger.Infof(ctx, "Node level caching is disabled. Skipping catalog read.") - return false, false, nil - } - // read task template taskTemplatePath, err := ioutils.GetTaskTemplatePath(ctx, nCtx.DataStore(), nCtx.NodeStatus().GetDataDir()) if err != nil { @@ -61,6 +49,18 @@ func (t *Handler) IsCacheable(ctx context.Context, nCtx interfaces.NodeExecution logger.Errorf(ctx, "failed to read TaskTemplate, error :%s", err.Error()) return false, false, err } + tVersion := taskTemplate.GetTaskTypeVersion() + ctx = contextutils.WithTaskType(ctx, ttype) + p, err := t.ResolvePlugin(ctx, ttype, nCtx.ExecutionContext().GetExecutionConfig(), tVersion) + if err != nil { + return false, false, errors2.Wrapf(errors2.UnsupportedTaskTypeError, nCtx.NodeID(), err, "unable to resolve plugin") + } + + checkCatalog := !p.GetProperties().DisableNodeLevelCaching + if !checkCatalog { + logger.Infof(ctx, "Node level caching is disabled. Skipping catalog read.") + return false, false, nil + } return taskTemplate.GetMetadata().GetDiscoverable(), taskTemplate.GetMetadata().GetDiscoverable() && taskTemplate.GetMetadata().GetCacheSerializable(), nil } diff --git a/flytepropeller/pkg/controller/nodes/task/handler.go b/flytepropeller/pkg/controller/nodes/task/handler.go index d5620c6c52..8f1e0218ff 100644 --- a/flytepropeller/pkg/controller/nodes/task/handler.go +++ b/flytepropeller/pkg/controller/nodes/task/handler.go @@ -5,6 +5,7 @@ import ( "encoding/json" "fmt" "runtime/debug" + "sync" "time" regErrors "github.com/pkg/errors" @@ -238,6 +239,7 @@ type Handler struct { catalog catalog.Client asyncCatalog catalog.AsyncClient defaultPlugins map[pluginCore.TaskType]pluginCore.Plugin + connectorDeploymentsForType map[pluginCore.TaskType]map[string]pluginCore.Plugin pluginsForType map[pluginCore.TaskType]map[pluginID]pluginCore.Plugin taskMetricsMap map[MetricKey]*taskMetrics defaultPlugin pluginCore.Plugin @@ -252,6 +254,7 @@ type Handler struct { eventConfig *controllerConfig.EventConfig clusterID string agentService *agent.AgentService + mu sync.RWMutex } func (t *Handler) FinalizeRequired() bool { @@ -271,29 +274,66 @@ func (t *Handler) setDefault(ctx context.Context, p pluginCore.Plugin) error { func (t *Handler) watchPlugins(ctx context.Context, sCtx interfaces.SetupContext) error { for { select { - case wpe := <-pluginMachinery.PluginRegistry().GetPluginRegistrationChan(): - tSCtx := t.newSetupContext(sCtx) - // Create a new base resource negotiator - resourceManagerConfig := rmConfig.GetConfig() - newResourceManagerBuilder, err := resourcemanager.GetResourceManagerBuilderByType(ctx, resourceManagerConfig.Type, t.metrics.scope) - if err != nil { - return err + case registerInfo := <-pluginMachinery.PluginRegistry().GetPluginRegistrationChan(): + if !pluginMachinery.PluginRegistry().IsPluginForTaskTypeRegistered(registerInfo.Plugin.ID, registerInfo.DeploymentID){ + wpe := registerInfo.Plugin + tSCtx := t.newSetupContext(sCtx) + // Create a new base resource negotiator + resourceManagerConfig := rmConfig.GetConfig() + newResourceManagerBuilder, err := resourcemanager.GetResourceManagerBuilderByType(ctx, resourceManagerConfig.Type, t.metrics.scope) + if err != nil { + return err + } + pluginResourceNamespacePrefix := pluginCore.ResourceNamespace(newResourceManagerBuilder.GetID()).CreateSubNamespace(pluginCore.ResourceNamespace(wpe.ID)) + sCtxFinal := newNameSpacedSetupCtx( + tSCtx, newResourceManagerBuilder.GetResourceRegistrar(pluginResourceNamespacePrefix), wpe.ID) + // register core plugin + for _, cpe := range pluginMachinery.PluginRegistry().GetCorePlugins() { + if cpe.ID == wpe.ID { + cp, err := pluginCore.LoadPlugin(ctx, sCtxFinal, cpe) + if err != nil { + return regErrors.Wrapf(err, "failed to load plugin - %s", wpe.ID) + } + // register the plugin to task handler local plugin registry + t.mu.Lock() + t.defaultPlugins[cp.GetID()] = cp + if t.connectorDeploymentsForType[cp.GetID()] == nil { + t.connectorDeploymentsForType[cp.GetID()] = make(map[string]pluginCore.Plugin) + } + t.connectorDeploymentsForType[cp.GetID()][registerInfo.DeploymentID] = cp + t.mu.Unlock() + pluginMachinery.PluginRegistry().AddRegisteredPluginForTaskType(cp.GetID(), registerInfo.DeploymentID) + logger.Infof(ctx, "Plugin of TaskType [%s] and deployment ID [%s] registered", registerInfo.Plugin.ID, registerInfo.DeploymentID) + break + } + } + } else { + logger.Infof(ctx, "Plugin of TaskType [%s] and deployment ID [%s] already registered", registerInfo.Plugin.ID, registerInfo.DeploymentID) } - pluginResourceNamespacePrefix := pluginCore.ResourceNamespace(newResourceManagerBuilder.GetID()).CreateSubNamespace(pluginCore.ResourceNamespace(wpe.ID)) - sCtxFinal := newNameSpacedSetupCtx( - tSCtx, newResourceManagerBuilder.GetResourceRegistrar(pluginResourceNamespacePrefix), wpe.ID) - logger.Infof(ctx, "Loading Plugin [%s] ENABLED", wpe.ID) - // register core plugin - for _, cpe := range pluginMachinery.PluginRegistry().GetCorePlugins() { - if cpe.ID == wpe.ID { - cp, err := pluginCore.LoadPlugin(ctx, sCtxFinal, cpe) - if err != nil { - return regErrors.Wrapf(err, "failed to load plugin - %s", wpe.ID) + case updateInfo := <-pluginMachinery.PluginRegistry().GetPluginUpdateChan(): + if pluginMachinery.PluginRegistry().IsPluginForTaskTypeRegistered(updateInfo.TaskType, updateInfo.DeploymentID) { + t.mu.RLock() + deploymentMap, exists := t.connectorDeploymentsForType[updateInfo.TaskType] + t.mu.RUnlock() + + if exists { + t.mu.RLock() + plugin, pluginExists := deploymentMap[updateInfo.DeploymentID] + t.mu.RUnlock() + + if pluginExists { + t.mu.Lock() + t.defaultPlugins[updateInfo.TaskType] = plugin + t.mu.Unlock() + logger.Infof(ctx, "The default plugin for TaskType [%s] has been updated to Deployment ID [%s]", updateInfo.TaskType, updateInfo.DeploymentID) + } else { + logger.Warningf(ctx, "Plugin for TaskType [%s] and deployment ID [%s] not found", updateInfo.TaskType, updateInfo.DeploymentID) } - // register the plugin to task handler local plugin registry - t.defaultPlugins[cp.GetID()] = cp - pluginMachinery.PluginRegistry().AddRegisteredTaskType(cp.GetID()) + } else { + logger.Warningf(ctx, "Deployment ID [%s] for TaskType [%s] not found", updateInfo.DeploymentID, updateInfo.TaskType) } + } else { + logger.Infof(ctx, "Plugin of TaskType [%s] and deployment ID [%s] not yet registered", updateInfo.TaskType, updateInfo.DeploymentID) } case <-ctx.Done(): logger.Infof(ctx, "Plugin watcher stopped due to context cancellation") @@ -312,6 +352,9 @@ func (t *Handler) Setup(ctx context.Context, sCtx interfaces.SetupContext) error return err } + // execute watcher to monitor plugins waiting for register from connector/agent plugin + go t.watchPlugins(ctx, sCtx) + once.Do(func() { // The agent service plugin is deprecated and will be removed in the future agent.RegisterAgentPlugin(t.agentService) @@ -398,9 +441,6 @@ func (t *Handler) Setup(ctx context.Context, sCtx interfaces.SetupContext) error } } - // execute watcher to monitor plugins waiting for register from connector/agent plugin - go t.watchPlugins(ctx, sCtx) - rm, err := newResourceManagerBuilder.BuildResourceManager(ctx) if err != nil { logger.Errorf(ctx, "Failed to build a resource manager") @@ -411,12 +451,15 @@ func (t *Handler) Setup(ctx context.Context, sCtx interfaces.SetupContext) error return nil } -func (t Handler) ResolvePlugin(ctx context.Context, ttype string, executionConfig v1alpha1.ExecutionConfig) (pluginCore.Plugin, error) { +func (t Handler) ResolvePlugin(ctx context.Context, ttype string, executionConfig v1alpha1.ExecutionConfig, taskVersion int32) (pluginCore.Plugin, error) { // If the workflow specifies plugin overrides, check to see if any of the specified plugins for that type are // registered in this deployment of flytepropeller. if len(executionConfig.TaskPluginImpls[ttype].PluginIDs) > 0 { - if len(t.pluginsForType[ttype]) > 0 { - pluginsForType := t.pluginsForType[ttype] + t.mu.RLock() + pluginsForType, exists := t.pluginsForType[ttype] + t.mu.RUnlock() + + if exists && len(pluginsForType) > 0 { for _, pluginImplID := range executionConfig.TaskPluginImpls[ttype].PluginIDs { pluginImpl := pluginsForType[pluginImplID] if pluginImpl != nil { @@ -433,12 +476,24 @@ func (t Handler) ResolvePlugin(ctx context.Context, ttype string, executionConfi } } + t.mu.RLock() p, ok := t.defaultPlugins[ttype] + t.mu.RUnlock() if ok { logger.Debugf(ctx, "Plugin [%s] resolved for Handler type [%s]", p.GetID(), ttype) return p, nil } + // check if the task type is a connector/agent task type + versionedTaskType := fmt.Sprintf("%s_%d", ttype, taskVersion) + t.mu.RLock() + p, ok = t.defaultPlugins[versionedTaskType] + t.mu.RUnlock() + if ok { + logger.Debugf(ctx, "Plugin [%s] resolved for versioned task type [%s]", p.GetID(), versionedTaskType) + return p, nil + } + // The agent service plugin is deprecated and will be removed in the future if t.agentService != nil && t.agentService.ContainTaskType(ttype) { return t.agentService.CorePlugin, nil @@ -674,8 +729,13 @@ func (t Handler) Handle(ctx context.Context, nCtx interfaces.NodeExecutionContex defer span.End() ttype := nCtx.TaskReader().GetTaskType() + taskTemplate, err := nCtx.TaskReader().Read(ctx) + if err != nil { + return handler.UnknownTransition, errors.Wrapf(errors.UnsupportedTaskTypeError, nCtx.NodeID(), err, "unable to read task template") + } + tVersion := taskTemplate.GetTaskTypeVersion() ctx = contextutils.WithTaskType(ctx, ttype) - p, err := t.ResolvePlugin(ctx, ttype, nCtx.ExecutionContext().GetExecutionConfig()) + p, err := t.ResolvePlugin(ctx, ttype, nCtx.ExecutionContext().GetExecutionConfig(), tVersion) jsonBytes, _ := json.MarshalIndent(p, "", " ") logger.Debug(ctx, "The task type is [%s]", string(ttype)) logger.Debug(ctx, "The deployment config is [%s]", string(jsonBytes)) @@ -927,7 +987,12 @@ func (t Handler) Abort(ctx context.Context, nCtx interfaces.NodeExecutionContext } ttype := nCtx.TaskReader().GetTaskType() - p, err := t.ResolvePlugin(ctx, ttype, nCtx.ExecutionContext().GetExecutionConfig()) + taskTemplate, err := nCtx.TaskReader().Read(ctx) + if err != nil { + return errors.Wrapf(errors.UnsupportedTaskTypeError, nCtx.NodeID(), err, "unable to read task template") + } + tVersion := taskTemplate.GetTaskTypeVersion() + p, err := t.ResolvePlugin(ctx, ttype, nCtx.ExecutionContext().GetExecutionConfig(), tVersion) if err != nil { return errors.Wrapf(errors.UnsupportedTaskTypeError, nCtx.NodeID(), err, "unable to resolve plugin") } @@ -1022,7 +1087,12 @@ func (t Handler) Abort(ctx context.Context, nCtx interfaces.NodeExecutionContext func (t Handler) Finalize(ctx context.Context, nCtx interfaces.NodeExecutionContext) error { logger.Debugf(ctx, "Finalize invoked.") ttype := nCtx.TaskReader().GetTaskType() - p, err := t.ResolvePlugin(ctx, ttype, nCtx.ExecutionContext().GetExecutionConfig()) + taskTemplate, err := nCtx.TaskReader().Read(ctx) + if err != nil { + return errors.Wrapf(errors.UnsupportedTaskTypeError, nCtx.NodeID(), err, "unable to read task template") + } + tVersion := taskTemplate.GetTaskTypeVersion() + p, err := t.ResolvePlugin(ctx, ttype, nCtx.ExecutionContext().GetExecutionConfig(), tVersion) if err != nil { return errors.Wrapf(errors.UnsupportedTaskTypeError, nCtx.NodeID(), err, "unable to resolve plugin") } @@ -1064,6 +1134,7 @@ func New(ctx context.Context, kubeClient executors.Client, kubeClientset kuberne return &Handler{ pluginRegistry: pluginMachinery.PluginRegistry(), defaultPlugins: make(map[pluginCore.TaskType]pluginCore.Plugin), + connectorDeploymentsForType: make(map[pluginCore.TaskType]map[string]pluginCore.Plugin), pluginsForType: make(map[pluginCore.TaskType]map[pluginID]pluginCore.Plugin), taskMetricsMap: make(map[MetricKey]*taskMetrics), metrics: &metrics{ diff --git a/flytepropeller/pkg/controller/nodes/task/handler_test.go b/flytepropeller/pkg/controller/nodes/task/handler_test.go index f437ca52c9..af1ff067c1 100644 --- a/flytepropeller/pkg/controller/nodes/task/handler_test.go +++ b/flytepropeller/pkg/controller/nodes/task/handler_test.go @@ -353,7 +353,7 @@ func Test_task_ResolvePlugin(t *testing.T) { pluginsForType: tt.fields.pluginsForType, agentService: &agent.AgentService{}, } - got, err := tk.ResolvePlugin(context.TODO(), tt.args.ttype, tt.args.executionConfig) + got, err := tk.ResolvePlugin(context.TODO(), tt.args.ttype, tt.args.executionConfig, 0) if (err != nil) != tt.wantErr { t.Errorf("Handler.ResolvePlugin() error = %v, wantErr %v", err, tt.wantErr) return diff --git a/flytestdlib/cache/in_memory_auto_refresh.go b/flytestdlib/cache/in_memory_auto_refresh.go index c3072f5b26..e1db4121d7 100644 --- a/flytestdlib/cache/in_memory_auto_refresh.go +++ b/flytestdlib/cache/in_memory_auto_refresh.go @@ -133,6 +133,8 @@ func NewInMemoryAutoRefresh( } metrics := newMetrics(scope) + ctx := context.Background() + logger.Debug(ctx, "the scope is %s", metrics) // #nosec G115 lruCache, err := lru.NewWithEvict(int(size), getEvictionFunction(metrics.Evictions)) if err != nil { From 1473709492f76964f644ea6138b0772f15bb96d2 Mon Sep 17 00:00:00 2001 From: Alex Wu Date: Tue, 29 Jul 2025 11:13:15 +0800 Subject: [PATCH 4/5] add connector core plugins map in plugin registry Signed-off-by: Alex Wu --- .../pluginmachinery/internal/webapi/cache.go | 4 +- .../internal/webapi/metrics.go | 31 +++++++- .../go/tasks/pluginmachinery/registry.go | 48 ++++++++++++ .../tasks/plugins/webapi/connector/client.go | 26 ++++--- .../tasks/plugins/webapi/connector/plugin.go | 8 +- .../pkg/controller/nodes/task/handler.go | 77 ++++++++++--------- flytestdlib/cache/in_memory_auto_refresh.go | 33 +++++++- 7 files changed, 169 insertions(+), 58 deletions(-) diff --git a/flyteplugins/go/tasks/pluginmachinery/internal/webapi/cache.go b/flyteplugins/go/tasks/pluginmachinery/internal/webapi/cache.go index e2a4ac4cdf..956848ed7a 100644 --- a/flyteplugins/go/tasks/pluginmachinery/internal/webapi/cache.go +++ b/flyteplugins/go/tasks/pluginmachinery/internal/webapi/cache.go @@ -2,7 +2,6 @@ package webapi import ( "context" - "fmt" "time" "golang.org/x/time/rate" @@ -176,13 +175,12 @@ func NewResourceCache(ctx context.Context, name string, client Client, cfg webap cfg: cfg, } - scopeName := fmt.Sprintf("cache_%s", name) autoRefreshCache, err := cache.NewAutoRefreshCache(name, q.SyncResource, workqueue.NewMaxOfRateLimiter( workqueue.NewItemExponentialFailureRateLimiter(5*time.Millisecond, 1000*time.Second), &workqueue.BucketRateLimiter{Limiter: rate.NewLimiter(rate.Limit(rateCfg.QPS), rateCfg.Burst)}, ), cfg.ResyncInterval.Duration, uint(cfg.Workers), uint(cfg.Size), // #nosec G115 - scope.NewSubScope(scopeName)) + scope.NewSubScope("cache")) if err != nil { logger.Errorf(ctx, "Could not create AutoRefreshCache. Error: [%s]", err) diff --git a/flyteplugins/go/tasks/pluginmachinery/internal/webapi/metrics.go b/flyteplugins/go/tasks/pluginmachinery/internal/webapi/metrics.go index f2fa03104c..0f05c8df17 100644 --- a/flyteplugins/go/tasks/pluginmachinery/internal/webapi/metrics.go +++ b/flyteplugins/go/tasks/pluginmachinery/internal/webapi/metrics.go @@ -1,6 +1,7 @@ package webapi import ( + "sync" "time" "github.com/prometheus/client_golang/prometheus" @@ -9,6 +10,12 @@ import ( "github.com/flyteorg/flyte/flytestdlib/promutils/labeled" ) +// Global metrics cache to avoid recreating metrics for the same scope +var ( + metricsCache = make(map[string]*Metrics) + metricsMutex sync.RWMutex +) + type Metrics struct { Scope promutils.Scope ResourceReleased labeled.Counter @@ -25,7 +32,26 @@ var ( ) func newMetrics(scope promutils.Scope) Metrics { - return Metrics{ + scopeName := scope.CurrentScope() + + // Check if we already have metrics for this scope + metricsMutex.RLock() + if cachedMetrics, exists := metricsCache[scopeName]; exists { + defer metricsMutex.RUnlock() + return *cachedMetrics + } + metricsMutex.RUnlock() + + // Create new metrics and store globally + metricsMutex.Lock() + defer metricsMutex.Unlock() + + // Double-check in case another goroutine created metrics while we were acquiring the lock + if cachedMetrics, exists := metricsCache[scopeName]; exists { + return *cachedMetrics + } + + newMetrics := Metrics{ Scope: scope, ResourceReleased: labeled.NewCounter("resource_release_success", "Resource allocation token released", scope, labeled.EmitUnlabeledMetric), @@ -42,4 +68,7 @@ func newMetrics(scope promutils.Scope) Metrics { FailedUnmarshalState: labeled.NewCounter("unmarshal_state_failed", "Failed to unmarshal state", scope, labeled.EmitUnlabeledMetric), } + + metricsCache[scopeName] = &newMetrics + return newMetrics } diff --git a/flyteplugins/go/tasks/pluginmachinery/registry.go b/flyteplugins/go/tasks/pluginmachinery/registry.go index b48fc5766e..9d81dca1f5 100644 --- a/flyteplugins/go/tasks/pluginmachinery/registry.go +++ b/flyteplugins/go/tasks/pluginmachinery/registry.go @@ -28,6 +28,7 @@ type taskPluginRegistry struct { m sync.Mutex k8sPlugin []k8s.PluginEntry corePlugin []core.PluginEntry + connectorCorePlugin map[string]map[string]core.PluginEntry pluginRegistrationChan chan PluginRegistrationInfo pluginUpdateChan chan PluginUpdateInfo registeredPlugins map[string]map[string]struct{} @@ -35,6 +36,8 @@ type taskPluginRegistry struct { // A singleton variable that maintains a registry of all plugins. The framework uses this to access all plugins var pluginRegistry = &taskPluginRegistry{ + corePlugin: []core.PluginEntry{}, + connectorCorePlugin: make(map[string]map[string]core.PluginEntry), pluginRegistrationChan: make(chan PluginRegistrationInfo, defaultPluginBufferSize), pluginUpdateChan: make(chan PluginUpdateInfo, defaultPluginBufferSize), registeredPlugins: make(map[string]map[string]struct{}), @@ -63,6 +66,49 @@ func (p *taskPluginRegistry) RegisterRemotePlugin(info webapi.PluginEntry) { p.corePlugin = append(p.corePlugin, internalRemote.CreateRemotePlugin(info)) } +// RegisterConnectorCorePlugin registers a core plugin for a specific connector deployment +func (p *taskPluginRegistry) RegisterConnectorCorePlugin(info webapi.PluginEntry, deploymentID string) { + ctx := context.Background() + if info.ID == "" { + logger.Panicf(ctx, "ID is required attribute for connector core plugin") + } + + if len(info.SupportedTaskTypes) == 0 { + logger.Panicf(ctx, "Plugin should be registered to handle at least one task type") + } + + p.m.Lock() + defer p.m.Unlock() + + if p.connectorCorePlugin == nil { + p.connectorCorePlugin = make(map[string]map[string]core.PluginEntry) + } + + if p.connectorCorePlugin[info.ID] == nil { + p.connectorCorePlugin[info.ID] = make(map[string]core.PluginEntry) + } + + p.connectorCorePlugin[info.ID][deploymentID] = internalRemote.CreateRemotePlugin(info) +} + +// GetConnectorCorePlugin returns a specific connector core plugin for a task type and deployment ID +func (p *taskPluginRegistry) GetConnectorCorePlugin(taskType string, deploymentID string) (core.PluginEntry, bool) { + p.m.Lock() + defer p.m.Unlock() + + if p.connectorCorePlugin == nil { + return core.PluginEntry{}, false + } + + if plugins, exists := p.connectorCorePlugin[taskType]; exists { + if plugin, exists := plugins[deploymentID]; exists { + return plugin, true + } + } + + return core.PluginEntry{}, false +} + func (p *taskPluginRegistry) GetPluginRegistrationChan() chan PluginRegistrationInfo { return p.pluginRegistrationChan } @@ -163,6 +209,8 @@ type TaskPluginRegistry interface { RegisterK8sPlugin(info k8s.PluginEntry) RegisterCorePlugin(info core.PluginEntry) RegisterRemotePlugin(info webapi.PluginEntry) + RegisterConnectorCorePlugin(info webapi.PluginEntry, deploymentID string) + GetConnectorCorePlugin(taskType string, deploymentID string) (core.PluginEntry, bool) GetCorePlugins() []core.PluginEntry GetK8sPlugins() []k8s.PluginEntry GetPluginRegistrationChan() chan PluginRegistrationInfo diff --git a/flyteplugins/go/tasks/plugins/webapi/connector/client.go b/flyteplugins/go/tasks/plugins/webapi/connector/client.go index 22ee9e8144..c0e0243d9e 100644 --- a/flyteplugins/go/tasks/plugins/webapi/connector/client.go +++ b/flyteplugins/go/tasks/plugins/webapi/connector/client.go @@ -6,6 +6,7 @@ import ( "fmt" "strings" + "golang.org/x/exp/maps" "google.golang.org/grpc" "google.golang.org/grpc/codes" "google.golang.org/grpc/credentials" @@ -99,7 +100,7 @@ func processTaskType(ctx context.Context, taskName string, taskVersion int32, de // Register default version if not registered if !pluginmachinery.PluginRegistry().IsPluginForTaskTypeRegistered(versionedTaskType, deploymentID) { - registerNewPlugin(taskName, taskVersion, deploymentID, connectorDeployment, cs) + registerNewPlugin(taskName, taskVersion, deploymentID, *connectorDeployment, cs) } else { updatePlugin(versionedTaskType, deploymentID) } @@ -108,18 +109,19 @@ func processTaskType(ctx context.Context, taskName string, taskVersion int32, de func watchConnectors(ctx context.Context, cs *ClientSet) { cfg := GetConfig() - connectorDeployments := make(map[string]*Deployment) + var connectorDeploymentIDs []string + var connectorDeployments []*Deployment - // Merge ConnectorDeployments - for key, deployment := range cfg.ConnectorDeployments { - connectorDeployments[key] = deployment - } - // Merge DefaultConnector (if endpoint is not empty) if len(cfg.DefaultConnector.Endpoint) != 0 { - connectorDeployments[defaultDeploymentID] = &cfg.DefaultConnector + connectorDeploymentIDs = append(connectorDeploymentIDs, defaultDeploymentID) + connectorDeployments = append(connectorDeployments, &cfg.DefaultConnector) } - for deploymentID, connectorDeployment := range connectorDeployments { + connectorDeploymentIDs = append(connectorDeploymentIDs, maps.Keys(cfg.ConnectorDeployments)...) + connectorDeployments = append(connectorDeployments, maps.Values(cfg.ConnectorDeployments)...) + + for idx, connectorDeployment := range connectorDeployments { + deploymentID := connectorDeploymentIDs[idx] client, ok := cs.connectorMetadataClients[connectorDeployment.Endpoint] if !ok { logger.Warningf(ctx, "Connector client not found in the clientSet for the endpoint: %v", connectorDeployment.Endpoint) @@ -159,8 +161,10 @@ func watchConnectors(ctx context.Context, cs *ClientSet) { } // Process supported task categories for _, supportedCategory := range supportedTaskCategories { - versionedTaskType := processTaskType(ctx, supportedCategory.Name, supportedCategory.Version, deploymentID, connectorDeployment, cs) - connectorSupportedTaskCategories[versionedTaskType] = struct{}{} + if supportedCategory.Version != defaultTaskTypeVersion { + versionedTaskType := processTaskType(ctx, supportedCategory.Name, supportedCategory.Version, deploymentID, connectorDeployment, cs) + connectorSupportedTaskCategories[versionedTaskType] = struct{}{} + } } } keys := make([]string, 0, len(connectorSupportedTaskCategories)) diff --git a/flyteplugins/go/tasks/plugins/webapi/connector/plugin.go b/flyteplugins/go/tasks/plugins/webapi/connector/plugin.go index 63ef677648..2992267ca5 100644 --- a/flyteplugins/go/tasks/plugins/webapi/connector/plugin.go +++ b/flyteplugins/go/tasks/plugins/webapi/connector/plugin.go @@ -436,13 +436,13 @@ func buildTaskExecutionMetadata(taskExecutionMetadata core.TaskExecutionMetadata } } -func createPluginEntry(taskType core.TaskType, taskVersion int32, deployment *Deployment, clientSet *ClientSet) webapi.PluginEntry { +func createPluginEntry(taskType core.TaskType, taskVersion int32, deployment Deployment, clientSet *ClientSet) webapi.PluginEntry { versionedTaskType := fmt.Sprintf("%s_%d", taskType, taskVersion) plugin := &Plugin{ metricScope: promutils.NewScope(versionedTaskType), cfg: GetConfig(), cs: clientSet, - deployment: Connector{IsSync: false, ConnectorDeployment: deployment}, + deployment: Connector{IsSync: false, ConnectorDeployment: &deployment}, } return webapi.PluginEntry{ ID: versionedTaskType, @@ -464,9 +464,9 @@ func updatePlugin(versionedTaskType string, deploymentID string) { } } -func registerNewPlugin(taskType core.TaskType, taskTypeVersion int32, deploymentID string, deployment *Deployment, cs *ClientSet) { +func registerNewPlugin(taskType core.TaskType, taskTypeVersion int32, deploymentID string, deployment Deployment, cs *ClientSet) { plugin := createPluginEntry(taskType, taskTypeVersion, deployment, cs) - pluginmachinery.PluginRegistry().RegisterRemotePlugin(plugin) + pluginmachinery.PluginRegistry().RegisterConnectorCorePlugin(plugin, deploymentID) select { case pluginmachinery.PluginRegistry().GetPluginRegistrationChan() <- pluginmachinery.PluginRegistrationInfo{ Plugin: plugin, diff --git a/flytepropeller/pkg/controller/nodes/task/handler.go b/flytepropeller/pkg/controller/nodes/task/handler.go index 8f1e0218ff..1e9e738209 100644 --- a/flytepropeller/pkg/controller/nodes/task/handler.go +++ b/flytepropeller/pkg/controller/nodes/task/handler.go @@ -261,6 +261,33 @@ func (t *Handler) FinalizeRequired() bool { return true } +// getConnectorPlugin retrieves a plugin from connectorDeploymentsForType map with proper locking +func (t *Handler) getConnectorPlugin(taskType pluginCore.TaskType, deploymentID string) (pluginCore.Plugin, bool) { + t.mu.RLock() + deploymentMap, exists := t.connectorDeploymentsForType[taskType] + t.mu.RUnlock() + + if !exists { + return nil, false + } + + t.mu.RLock() + plugin, pluginExists := deploymentMap[deploymentID] + t.mu.RUnlock() + + return plugin, pluginExists +} + +func (t *Handler) registerConnectorPlugin(corePlugin pluginCore.Plugin, deploymentID string) { + t.mu.Lock() + t.defaultPlugins[corePlugin.GetID()] = corePlugin + if t.connectorDeploymentsForType[corePlugin.GetID()] == nil { + t.connectorDeploymentsForType[corePlugin.GetID()] = make(map[string]pluginCore.Plugin) + } + t.connectorDeploymentsForType[corePlugin.GetID()][deploymentID] = corePlugin + t.mu.Unlock() +} + func (t *Handler) setDefault(ctx context.Context, p pluginCore.Plugin) error { if t.defaultPlugin != nil { logger.Errorf(ctx, "cannot set plugin [%s] as default as plugin [%s] is already configured as default", p.GetID(), t.defaultPlugin.GetID()) @@ -288,49 +315,29 @@ func (t *Handler) watchPlugins(ctx context.Context, sCtx interfaces.SetupContext sCtxFinal := newNameSpacedSetupCtx( tSCtx, newResourceManagerBuilder.GetResourceRegistrar(pluginResourceNamespacePrefix), wpe.ID) // register core plugin - for _, cpe := range pluginMachinery.PluginRegistry().GetCorePlugins() { - if cpe.ID == wpe.ID { - cp, err := pluginCore.LoadPlugin(ctx, sCtxFinal, cpe) - if err != nil { - return regErrors.Wrapf(err, "failed to load plugin - %s", wpe.ID) - } - // register the plugin to task handler local plugin registry - t.mu.Lock() - t.defaultPlugins[cp.GetID()] = cp - if t.connectorDeploymentsForType[cp.GetID()] == nil { - t.connectorDeploymentsForType[cp.GetID()] = make(map[string]pluginCore.Plugin) - } - t.connectorDeploymentsForType[cp.GetID()][registerInfo.DeploymentID] = cp - t.mu.Unlock() - pluginMachinery.PluginRegistry().AddRegisteredPluginForTaskType(cp.GetID(), registerInfo.DeploymentID) - logger.Infof(ctx, "Plugin of TaskType [%s] and deployment ID [%s] registered", registerInfo.Plugin.ID, registerInfo.DeploymentID) - break + if cpe, ok := pluginMachinery.PluginRegistry().GetConnectorCorePlugin(wpe.ID, registerInfo.DeploymentID); ok { + cp, err := pluginCore.LoadPlugin(ctx, sCtxFinal, cpe) + if err != nil { + return regErrors.Wrapf(err, "failed to load plugin - %s", wpe.ID) } + // register the plugin to task handler local plugin registry + t.registerConnectorPlugin(cp, registerInfo.DeploymentID) + pluginMachinery.PluginRegistry().AddRegisteredPluginForTaskType(cp.GetID(), registerInfo.DeploymentID) + logger.Infof(ctx, "Plugin of TaskType [%s] and deployment ID [%s] registered", registerInfo.Plugin.ID, registerInfo.DeploymentID) } } else { logger.Infof(ctx, "Plugin of TaskType [%s] and deployment ID [%s] already registered", registerInfo.Plugin.ID, registerInfo.DeploymentID) } case updateInfo := <-pluginMachinery.PluginRegistry().GetPluginUpdateChan(): if pluginMachinery.PluginRegistry().IsPluginForTaskTypeRegistered(updateInfo.TaskType, updateInfo.DeploymentID) { - t.mu.RLock() - deploymentMap, exists := t.connectorDeploymentsForType[updateInfo.TaskType] - t.mu.RUnlock() - - if exists { - t.mu.RLock() - plugin, pluginExists := deploymentMap[updateInfo.DeploymentID] - t.mu.RUnlock() - - if pluginExists { - t.mu.Lock() - t.defaultPlugins[updateInfo.TaskType] = plugin - t.mu.Unlock() - logger.Infof(ctx, "The default plugin for TaskType [%s] has been updated to Deployment ID [%s]", updateInfo.TaskType, updateInfo.DeploymentID) - } else { - logger.Warningf(ctx, "Plugin for TaskType [%s] and deployment ID [%s] not found", updateInfo.TaskType, updateInfo.DeploymentID) - } + plugin, pluginExists := t.getConnectorPlugin(updateInfo.TaskType, updateInfo.DeploymentID) + if pluginExists { + t.mu.Lock() + t.defaultPlugins[updateInfo.TaskType] = plugin + t.mu.Unlock() + logger.Infof(ctx, "The default plugin for TaskType [%s] has been updated to Deployment ID [%s]", updateInfo.TaskType, updateInfo.DeploymentID) } else { - logger.Warningf(ctx, "Deployment ID [%s] for TaskType [%s] not found", updateInfo.DeploymentID, updateInfo.TaskType) + logger.Warningf(ctx, "Plugin for TaskType [%s] and deployment ID [%s] not found", updateInfo.TaskType, updateInfo.DeploymentID) } } else { logger.Infof(ctx, "Plugin of TaskType [%s] and deployment ID [%s] not yet registered", updateInfo.TaskType, updateInfo.DeploymentID) diff --git a/flytestdlib/cache/in_memory_auto_refresh.go b/flytestdlib/cache/in_memory_auto_refresh.go index e1db4121d7..e96b7d53cd 100644 --- a/flytestdlib/cache/in_memory_auto_refresh.go +++ b/flytestdlib/cache/in_memory_auto_refresh.go @@ -19,6 +19,12 @@ import ( "github.com/flyteorg/flyte/flytestdlib/promutils" ) +// Global metrics cache to avoid recreating metrics for the same scope +var ( + metricsCache = make(map[string]*metrics) + metricsMutex sync.RWMutex +) + type metrics struct { SyncErrors prometheus.Counter Evictions prometheus.Counter @@ -30,7 +36,25 @@ type metrics struct { } func newMetrics(scope promutils.Scope) metrics { - return metrics{ + scopeName := scope.CurrentScope() + // Check if we already have metrics for this scope + metricsMutex.RLock() + if cachedMetrics, exists := metricsCache[scopeName]; exists { + defer metricsMutex.RUnlock() + return *cachedMetrics + } + metricsMutex.RUnlock() + + // Create new metrics and store globally + metricsMutex.Lock() + defer metricsMutex.Unlock() + + // Double-check in case another goroutine created metrics while we were acquiring the lock + if cachedMetrics, exists := metricsCache[scopeName]; exists { + return *cachedMetrics + } + + newMetrics := metrics{ SyncErrors: scope.MustNewCounter("sync_errors", "Counter for sync errors."), Evictions: scope.MustNewCounter("lru_evictions", "Counter for evictions from LRU."), SyncLatency: scope.MustNewStopWatch("latency", "Latency for sync operations.", time.Millisecond), @@ -39,6 +63,9 @@ func newMetrics(scope promutils.Scope) metrics { Size: scope.MustNewGauge("size", "Current size of the cache"), scope: scope, } + + metricsCache[scopeName] = &newMetrics + return newMetrics } func getEvictionFunction(counter prometheus.Counter) func(key interface{}, value interface{}) { @@ -133,8 +160,6 @@ func NewInMemoryAutoRefresh( } metrics := newMetrics(scope) - ctx := context.Background() - logger.Debug(ctx, "the scope is %s", metrics) // #nosec G115 lruCache, err := lru.NewWithEvict(int(size), getEvictionFunction(metrics.Evictions)) if err != nil { @@ -152,7 +177,7 @@ func NewInMemoryAutoRefresh( toDelete: newSyncSet(), syncPeriod: resyncPeriod, workqueue: workqueue.NewRateLimitingQueueWithConfig(syncRateLimiter, workqueue.RateLimitingQueueConfig{ - Name: scope.CurrentScope(), + // Name: scope.CurrentScope(), Clock: opts.clock, }), clock: opts.clock, From 1bfaa73b5b53170cd2d10d4ab1622680d6ad49f7 Mon Sep 17 00:00:00 2001 From: Alex Wu Date: Wed, 30 Jul 2025 13:16:28 +0800 Subject: [PATCH 5/5] restructure connector registration business logic Signed-off-by: Alex Wu --- .../go/tasks/pluginmachinery/registry.go | 65 ++++---------- .../tasks/plugins/webapi/connector/client.go | 29 ++---- .../tasks/plugins/webapi/connector/plugin.go | 47 +++++----- .../pkg/controller/nodes/task/handler.go | 89 ++++++++++++------- flytestdlib/cache/in_memory_auto_refresh.go | 5 -- 5 files changed, 104 insertions(+), 131 deletions(-) diff --git a/flyteplugins/go/tasks/pluginmachinery/registry.go b/flyteplugins/go/tasks/pluginmachinery/registry.go index 9d81dca1f5..3975c5397f 100644 --- a/flyteplugins/go/tasks/pluginmachinery/registry.go +++ b/flyteplugins/go/tasks/pluginmachinery/registry.go @@ -13,14 +13,8 @@ import ( const defaultPluginBufferSize = 100 -// PluginRegistrationInfo contains information about plugin registration -type PluginRegistrationInfo struct { - Plugin webapi.PluginEntry - DeploymentID string -} - -type PluginUpdateInfo struct { - TaskType string +type PluginInfo struct { + VersionedTaskType string DeploymentID string } @@ -29,18 +23,14 @@ type taskPluginRegistry struct { k8sPlugin []k8s.PluginEntry corePlugin []core.PluginEntry connectorCorePlugin map[string]map[string]core.PluginEntry - pluginRegistrationChan chan PluginRegistrationInfo - pluginUpdateChan chan PluginUpdateInfo - registeredPlugins map[string]map[string]struct{} + pluginChan chan PluginInfo } // A singleton variable that maintains a registry of all plugins. The framework uses this to access all plugins var pluginRegistry = &taskPluginRegistry{ corePlugin: []core.PluginEntry{}, connectorCorePlugin: make(map[string]map[string]core.PluginEntry), - pluginRegistrationChan: make(chan PluginRegistrationInfo, defaultPluginBufferSize), - pluginUpdateChan: make(chan PluginUpdateInfo, defaultPluginBufferSize), - registeredPlugins: make(map[string]map[string]struct{}), + pluginChan: make(chan PluginInfo, defaultPluginBufferSize), } func PluginRegistry() TaskPluginRegistry { @@ -109,42 +99,25 @@ func (p *taskPluginRegistry) GetConnectorCorePlugin(taskType string, deploymentI return core.PluginEntry{}, false } -func (p *taskPluginRegistry) GetPluginRegistrationChan() chan PluginRegistrationInfo { - return p.pluginRegistrationChan -} - -func (p *taskPluginRegistry) GetPluginUpdateChan() chan PluginUpdateInfo { - return p.pluginUpdateChan -} - -// IsPluginForTaskTypeRegistered checks if a task type is registered -func (p *taskPluginRegistry) IsPluginForTaskTypeRegistered(taskType string, deploymentID string) bool { +func (p *taskPluginRegistry) IsConnectorCorePluginRegistered(taskType string, deploymentID string) bool { p.m.Lock() defer p.m.Unlock() - - if p.registeredPlugins == nil { + + if p.connectorCorePlugin == nil { return false } - - deploymentMap, exists := p.registeredPlugins[taskType] - if !exists { - return false + + if plugins, exists := p.connectorCorePlugin[taskType]; exists { + if _, exists := plugins[deploymentID]; exists { + return true + } } - - _, exists = deploymentMap[deploymentID] - return exists + + return false } -// RegisterTaskType registers a single task type -func (p *taskPluginRegistry) AddRegisteredPluginForTaskType(taskType string, deploymentID string) { - p.m.Lock() - defer p.m.Unlock() - - if p.registeredPlugins[taskType] == nil { - p.registeredPlugins[taskType] = make(map[string]struct{}) - } - - p.registeredPlugins[taskType][deploymentID] = struct{}{} +func (p *taskPluginRegistry) GetPluginChan() chan PluginInfo { + return p.pluginChan } func CreateRemotePlugin(pluginEntry webapi.PluginEntry) core.PluginEntry { @@ -211,10 +184,8 @@ type TaskPluginRegistry interface { RegisterRemotePlugin(info webapi.PluginEntry) RegisterConnectorCorePlugin(info webapi.PluginEntry, deploymentID string) GetConnectorCorePlugin(taskType string, deploymentID string) (core.PluginEntry, bool) + IsConnectorCorePluginRegistered(taskType string, deploymentID string) bool GetCorePlugins() []core.PluginEntry GetK8sPlugins() []k8s.PluginEntry - GetPluginRegistrationChan() chan PluginRegistrationInfo - GetPluginUpdateChan() chan PluginUpdateInfo - IsPluginForTaskTypeRegistered(taskType string, deploymentID string) bool - AddRegisteredPluginForTaskType(taskType string, deploymentID string) + GetPluginChan() chan PluginInfo } diff --git a/flyteplugins/go/tasks/plugins/webapi/connector/client.go b/flyteplugins/go/tasks/plugins/webapi/connector/client.go index c0e0243d9e..3b1d414095 100644 --- a/flyteplugins/go/tasks/plugins/webapi/connector/client.go +++ b/flyteplugins/go/tasks/plugins/webapi/connector/client.go @@ -94,19 +94,6 @@ func getFinalContext(ctx context.Context, operation string, connector *Deploymen return context.WithTimeout(ctx, timeout) } -// processTaskType handles the registration or update of a task type plugin -func processTaskType(ctx context.Context, taskName string, taskVersion int32, deploymentID string, connectorDeployment *Deployment, cs *ClientSet) string { - versionedTaskType := fmt.Sprintf("%s_%d", taskName, taskVersion) - - // Register default version if not registered - if !pluginmachinery.PluginRegistry().IsPluginForTaskTypeRegistered(versionedTaskType, deploymentID) { - registerNewPlugin(taskName, taskVersion, deploymentID, *connectorDeployment, cs) - } else { - updatePlugin(versionedTaskType, deploymentID) - } - return versionedTaskType -} - func watchConnectors(ctx context.Context, cs *ClientSet) { cfg := GetConfig() var connectorDeploymentIDs []string @@ -156,35 +143,31 @@ func watchConnectors(ctx context.Context, cs *ClientSet) { supportedTaskCategories := connector.GetSupportedTaskCategories() // Process deprecated supported task types for _, supportedTaskType := range deprecatedSupportedTaskTypes { - versionedTaskType := processTaskType(ctx, supportedTaskType, defaultTaskTypeVersion, deploymentID, connectorDeployment, cs) + versionedTaskType := createOrUpdatePlugin(ctx, supportedTaskType, defaultTaskTypeVersion, deploymentID, connectorDeployment, cs) connectorSupportedTaskCategories[versionedTaskType] = struct{}{} } // Process supported task categories for _, supportedCategory := range supportedTaskCategories { if supportedCategory.Version != defaultTaskTypeVersion { - versionedTaskType := processTaskType(ctx, supportedCategory.Name, supportedCategory.Version, deploymentID, connectorDeployment, cs) + versionedTaskType := createOrUpdatePlugin(ctx, supportedCategory.Name, supportedCategory.Version, deploymentID, connectorDeployment, cs) connectorSupportedTaskCategories[versionedTaskType] = struct{}{} } } } - keys := make([]string, 0, len(connectorSupportedTaskCategories)) - for k := range connectorSupportedTaskCategories { - keys = append(keys, k) - } logger.Infof(ctx, "ConnectorDeployment [%v] supports the following task types: [%v]", connectorDeployment.Endpoint, - strings.Join(keys, ", ")) + strings.Join(maps.Keys(connectorSupportedTaskCategories), ", ")) } // always overwrite with connectorForTaskTypes config for taskType, connectorDeploymentID := range cfg.ConnectorForTaskTypes { if deployment, ok := cfg.ConnectorDeployments[connectorDeploymentID]; ok { - processTaskType(ctx, taskType, defaultTaskTypeVersion, connectorDeploymentID, deployment, cs) + createOrUpdatePlugin(ctx, taskType, defaultTaskTypeVersion, connectorDeploymentID, deployment, cs) } } // Ensure that the old configuration is backward compatible for _, taskType := range cfg.SupportedTaskTypes { versionedTaskType := fmt.Sprintf("%s_%d", taskType, defaultTaskTypeVersion) - if ok := pluginmachinery.PluginRegistry().IsPluginForTaskTypeRegistered(versionedTaskType, defaultDeploymentID); !ok { - processTaskType(ctx, taskType, defaultTaskTypeVersion, defaultDeploymentID, &cfg.DefaultConnector, cs) + if ok := pluginmachinery.PluginRegistry().IsConnectorCorePluginRegistered(versionedTaskType, defaultDeploymentID); !ok { + createOrUpdatePlugin(ctx, taskType, defaultTaskTypeVersion, defaultDeploymentID, &cfg.DefaultConnector, cs) } } } diff --git a/flyteplugins/go/tasks/plugins/webapi/connector/plugin.go b/flyteplugins/go/tasks/plugins/webapi/connector/plugin.go index 2992267ca5..ee978b6137 100644 --- a/flyteplugins/go/tasks/plugins/webapi/connector/plugin.go +++ b/flyteplugins/go/tasks/plugins/webapi/connector/plugin.go @@ -63,12 +63,12 @@ func (p *Plugin) GetConfig() webapi.PluginConfig { // Create a new config object by copying deployment's config config := p.deployment.ConnectorDeployment.WebAPI - // 1. Check if ResourceQuotas is nil + // Check if ResourceQuotas is nil if config.ResourceQuotas == nil { config.ResourceQuotas = p.cfg.WebAPI.ResourceQuotas } - // 2. Check ReadRateLimiter values individually + // Check ReadRateLimiter values individually if config.ReadRateLimiter.QPS == 0 { config.ReadRateLimiter.QPS = p.cfg.WebAPI.ReadRateLimiter.QPS } @@ -76,7 +76,7 @@ func (p *Plugin) GetConfig() webapi.PluginConfig { config.ReadRateLimiter.Burst = p.cfg.WebAPI.ReadRateLimiter.Burst } - // 3. Check WriteRateLimiter values individually + // Check WriteRateLimiter values individually if config.WriteRateLimiter.QPS == 0 { config.WriteRateLimiter.QPS = p.cfg.WebAPI.WriteRateLimiter.QPS } @@ -84,7 +84,7 @@ func (p *Plugin) GetConfig() webapi.PluginConfig { config.WriteRateLimiter.Burst = p.cfg.WebAPI.WriteRateLimiter.Burst } - // 4. Check Caching configuration values individually + // Check Caching configuration values individually if config.Caching.ResyncInterval.Duration == time.Duration(0) { config.Caching.ResyncInterval = p.cfg.WebAPI.Caching.ResyncInterval } @@ -98,7 +98,7 @@ func (p *Plugin) GetConfig() webapi.PluginConfig { config.Caching.MaxSystemFailures = p.cfg.WebAPI.Caching.MaxSystemFailures } - // 5. Check if ResourceMeta is nil + // Check if ResourceMeta is nil if config.ResourceMeta == nil { config.ResourceMeta = p.cfg.WebAPI.ResourceMeta } @@ -115,6 +115,7 @@ func (p *Plugin) ResourceRequirements(_ context.Context, _ webapi.TaskExecutionC func (p *Plugin) Create(ctx context.Context, taskCtx webapi.TaskExecutionContextReader) (webapi.ResourceMeta, webapi.Resource, error) { + logger.Debug(ctx, "create task for deployment %s", p.deployment.ConnectorDeployment.Endpoint) taskTemplate, err := taskCtx.TaskReader().Read(ctx) if err != nil { return nil, nil, fmt.Errorf("failed to read task template with error: %v", err) @@ -453,28 +454,26 @@ func createPluginEntry(taskType core.TaskType, taskVersion int32, deployment Dep } } -func updatePlugin(versionedTaskType string, deploymentID string) { - select { - case pluginmachinery.PluginRegistry().GetPluginUpdateChan() <- pluginmachinery.PluginUpdateInfo{ - TaskType: versionedTaskType, - DeploymentID: deploymentID, - }: - default: - logger.Errorf(context.Background(), "Failed to update plugin for task type %s: channel is full", versionedTaskType) +// createOrUpdatePlugin handles the registration or update of a task type plugin +func createOrUpdatePlugin(ctx context.Context, taskName string, taskVersion int32, deploymentID string, connectorDeployment *Deployment, cs *ClientSet) string { + versionedTaskType := fmt.Sprintf("%s_%d", taskName, taskVersion) + + // Register core plugin if not registered + if !pluginmachinery.PluginRegistry().IsConnectorCorePluginRegistered(versionedTaskType, deploymentID) { + plugin := createPluginEntry(taskName, taskVersion, *connectorDeployment, cs) + pluginmachinery.PluginRegistry().RegisterConnectorCorePlugin(plugin, deploymentID) } -} -func registerNewPlugin(taskType core.TaskType, taskTypeVersion int32, deploymentID string, deployment Deployment, cs *ClientSet) { - plugin := createPluginEntry(taskType, taskTypeVersion, deployment, cs) - pluginmachinery.PluginRegistry().RegisterConnectorCorePlugin(plugin, deploymentID) + // send message to Flyte Propeller TaskHandler to register or update plugin select { - case pluginmachinery.PluginRegistry().GetPluginRegistrationChan() <- pluginmachinery.PluginRegistrationInfo{ - Plugin: plugin, - DeploymentID: deploymentID, - }: - default: - logger.Errorf(context.Background(), "Failed to register plugin %s: channel is full", plugin.ID) - } + case pluginmachinery.PluginRegistry().GetPluginChan() <- pluginmachinery.PluginInfo{ + VersionedTaskType: versionedTaskType, + DeploymentID: deploymentID, + }: + default: + logger.Errorf(context.Background(), "Failed to create/update plugin for task type %s: channel is full", versionedTaskType) + } + return versionedTaskType } func RegisterConnectorPlugin() { diff --git a/flytepropeller/pkg/controller/nodes/task/handler.go b/flytepropeller/pkg/controller/nodes/task/handler.go index 1e9e738209..aff7969333 100644 --- a/flytepropeller/pkg/controller/nodes/task/handler.go +++ b/flytepropeller/pkg/controller/nodes/task/handler.go @@ -261,6 +261,24 @@ func (t *Handler) FinalizeRequired() bool { return true } +func (t *Handler) createResourceManagerAndSetupCtx(ctx context.Context, sCtx interfaces.SetupContext, taskType string) (*nameSpacedSetupCtx, error) { + tSCtx := t.newSetupContext(sCtx) + // Create a new base resource negotiator + resourceManagerConfig := rmConfig.GetConfig() + newResourceManagerBuilder, err := resourcemanager.GetResourceManagerBuilderByType(ctx, resourceManagerConfig.Type, t.metrics.scope) + if err != nil { + return nil, err + } + return t.createNameSpacedSetupCtx(tSCtx, newResourceManagerBuilder, taskType), nil +} + +func (t *Handler) createNameSpacedSetupCtx(tSCtx *setupContext, newResourceManagerBuilder resourcemanager.Builder, taskType string) *nameSpacedSetupCtx { + pluginResourceNamespacePrefix := pluginCore.ResourceNamespace(newResourceManagerBuilder.GetID()).CreateSubNamespace(pluginCore.ResourceNamespace(taskType)) + sCtxFinal := newNameSpacedSetupCtx( + tSCtx, newResourceManagerBuilder.GetResourceRegistrar(pluginResourceNamespacePrefix), taskType) + return &sCtxFinal +} + // getConnectorPlugin retrieves a plugin from connectorDeploymentsForType map with proper locking func (t *Handler) getConnectorPlugin(taskType pluginCore.TaskType, deploymentID string) (pluginCore.Plugin, bool) { t.mu.RLock() @@ -288,6 +306,18 @@ func (t *Handler) registerConnectorPlugin(corePlugin pluginCore.Plugin, deployme t.mu.Unlock() } +func (t *Handler) isConnectorPluginRegistered(versionedTaskType string, deploymentID string) bool { + t.mu.RLock() + defer t.mu.RUnlock() + pluginsMap, ok := t.connectorDeploymentsForType[versionedTaskType] + if !ok { + return false + } + + _, ok = pluginsMap[deploymentID] + return ok +} + func (t *Handler) setDefault(ctx context.Context, p pluginCore.Plugin) error { if t.defaultPlugin != nil { logger.Errorf(ctx, "cannot set plugin [%s] as default as plugin [%s] is already configured as default", p.GetID(), t.defaultPlugin.GetID()) @@ -298,53 +328,50 @@ func (t *Handler) setDefault(ctx context.Context, p pluginCore.Plugin) error { return nil } -func (t *Handler) watchPlugins(ctx context.Context, sCtx interfaces.SetupContext) error { +func (t *Handler) watchPlugins(ctx context.Context, sCtx interfaces.SetupContext) { + defer func() { + if r := recover(); r != nil { + logger.Errorf(ctx, "WatchPlugins goroutine panicked: %v", r) + logger.Errorf(ctx, "Stack trace: %s", debug.Stack()) + } + }() + for { select { - case registerInfo := <-pluginMachinery.PluginRegistry().GetPluginRegistrationChan(): - if !pluginMachinery.PluginRegistry().IsPluginForTaskTypeRegistered(registerInfo.Plugin.ID, registerInfo.DeploymentID){ - wpe := registerInfo.Plugin - tSCtx := t.newSetupContext(sCtx) - // Create a new base resource negotiator - resourceManagerConfig := rmConfig.GetConfig() - newResourceManagerBuilder, err := resourcemanager.GetResourceManagerBuilderByType(ctx, resourceManagerConfig.Type, t.metrics.scope) + case info := <-pluginMachinery.PluginRegistry().GetPluginChan(): + // If plugin not registered yet, do registeration + if !t.isConnectorPluginRegistered(info.VersionedTaskType, info.DeploymentID){ + sCtxFinal, err := t.createResourceManagerAndSetupCtx(ctx, sCtx, info.VersionedTaskType) if err != nil { - return err + logger.Errorf(ctx, "Failed to create resource manager and setup context for task type [%s]: %v", info.VersionedTaskType, err) + continue } - pluginResourceNamespacePrefix := pluginCore.ResourceNamespace(newResourceManagerBuilder.GetID()).CreateSubNamespace(pluginCore.ResourceNamespace(wpe.ID)) - sCtxFinal := newNameSpacedSetupCtx( - tSCtx, newResourceManagerBuilder.GetResourceRegistrar(pluginResourceNamespacePrefix), wpe.ID) // register core plugin - if cpe, ok := pluginMachinery.PluginRegistry().GetConnectorCorePlugin(wpe.ID, registerInfo.DeploymentID); ok { + if cpe, ok := pluginMachinery.PluginRegistry().GetConnectorCorePlugin(info.VersionedTaskType, info.DeploymentID); ok { cp, err := pluginCore.LoadPlugin(ctx, sCtxFinal, cpe) if err != nil { - return regErrors.Wrapf(err, "failed to load plugin - %s", wpe.ID) + logger.Errorf(ctx, "Failed to load plugin of task type [%s]: %v", info.VersionedTaskType, err) + continue } // register the plugin to task handler local plugin registry - t.registerConnectorPlugin(cp, registerInfo.DeploymentID) - pluginMachinery.PluginRegistry().AddRegisteredPluginForTaskType(cp.GetID(), registerInfo.DeploymentID) - logger.Infof(ctx, "Plugin of TaskType [%s] and deployment ID [%s] registered", registerInfo.Plugin.ID, registerInfo.DeploymentID) + t.registerConnectorPlugin(cp, info.DeploymentID) + logger.Infof(ctx, "Plugin of TaskType [%s] and deployment ID [%s] registered", info.VersionedTaskType, info.DeploymentID) } + // If plugin already registered, update deployment } else { - logger.Infof(ctx, "Plugin of TaskType [%s] and deployment ID [%s] already registered", registerInfo.Plugin.ID, registerInfo.DeploymentID) - } - case updateInfo := <-pluginMachinery.PluginRegistry().GetPluginUpdateChan(): - if pluginMachinery.PluginRegistry().IsPluginForTaskTypeRegistered(updateInfo.TaskType, updateInfo.DeploymentID) { - plugin, pluginExists := t.getConnectorPlugin(updateInfo.TaskType, updateInfo.DeploymentID) + plugin, pluginExists := t.getConnectorPlugin(info.VersionedTaskType, info.DeploymentID) if pluginExists { t.mu.Lock() - t.defaultPlugins[updateInfo.TaskType] = plugin + t.defaultPlugins[info.VersionedTaskType] = plugin t.mu.Unlock() - logger.Infof(ctx, "The default plugin for TaskType [%s] has been updated to Deployment ID [%s]", updateInfo.TaskType, updateInfo.DeploymentID) + logger.Infof(ctx, "The default plugin for TaskType [%s] has been updated to Deployment ID [%s]", info.VersionedTaskType, info.DeploymentID) } else { - logger.Warningf(ctx, "Plugin for TaskType [%s] and deployment ID [%s] not found", updateInfo.TaskType, updateInfo.DeploymentID) + logger.Warningf(ctx, "Plugin for TaskType [%s] and deployment ID [%s] not found", info.VersionedTaskType, info.DeploymentID) } - } else { - logger.Infof(ctx, "Plugin of TaskType [%s] and deployment ID [%s] not yet registered", updateInfo.TaskType, updateInfo.DeploymentID) } case <-ctx.Done(): logger.Infof(ctx, "Plugin watcher stopped due to context cancellation") - return nil + return } } } @@ -359,7 +386,7 @@ func (t *Handler) Setup(ctx context.Context, sCtx interfaces.SetupContext) error return err } - // execute watcher to monitor plugins waiting for register from connector/agent plugin + // execute watcher to monitor plugins waiting for register/update from connector/agent plugin go t.watchPlugins(ctx, sCtx) once.Do(func() { @@ -382,9 +409,7 @@ func (t *Handler) Setup(ctx context.Context, sCtx interfaces.SetupContext) error for _, p := range enabledPlugins { // create a new resource registrar proxy for each plugin, and pass it into the plugin's LoadPlugin() via a setup context - pluginResourceNamespacePrefix := pluginCore.ResourceNamespace(newResourceManagerBuilder.GetID()).CreateSubNamespace(pluginCore.ResourceNamespace(p.ID)) - sCtxFinal := newNameSpacedSetupCtx( - tSCtx, newResourceManagerBuilder.GetResourceRegistrar(pluginResourceNamespacePrefix), p.ID) + sCtxFinal := t.createNameSpacedSetupCtx(tSCtx, newResourceManagerBuilder, p.ID) logger.Infof(ctx, "Loading Plugin [%s] ENABLED", p.ID) cp, err := pluginCore.LoadPlugin(ctx, sCtxFinal, p) diff --git a/flytestdlib/cache/in_memory_auto_refresh.go b/flytestdlib/cache/in_memory_auto_refresh.go index e96b7d53cd..35501098b3 100644 --- a/flytestdlib/cache/in_memory_auto_refresh.go +++ b/flytestdlib/cache/in_memory_auto_refresh.go @@ -49,11 +49,6 @@ func newMetrics(scope promutils.Scope) metrics { metricsMutex.Lock() defer metricsMutex.Unlock() - // Double-check in case another goroutine created metrics while we were acquiring the lock - if cachedMetrics, exists := metricsCache[scopeName]; exists { - return *cachedMetrics - } - newMetrics := metrics{ SyncErrors: scope.MustNewCounter("sync_errors", "Counter for sync errors."), Evictions: scope.MustNewCounter("lru_evictions", "Counter for evictions from LRU."),