diff --git a/.mockery.yaml b/.mockery.yaml index 0b2e6c8c63b..3d583f42561 100644 --- a/.mockery.yaml +++ b/.mockery.yaml @@ -17,6 +17,9 @@ packages: github.com/elastic/elastic-agent/internal/pkg/agent/application/info: interfaces: Agent: {} + github.com/elastic/elastic-agent/internal/pkg/agent/application/gateway/fleet: + interfaces: + rollbacksSource: {} github.com/elastic/elastic-agent/internal/pkg/agent/application/upgrade: interfaces: WatcherHelper: {} diff --git a/internal/pkg/agent/application/actions/handlers/handler_action_upgrade.go b/internal/pkg/agent/application/actions/handlers/handler_action_upgrade.go index 25628efda76..0cdaf441f12 100644 --- a/internal/pkg/agent/application/actions/handlers/handler_action_upgrade.go +++ b/internal/pkg/agent/application/actions/handlers/handler_action_upgrade.go @@ -59,7 +59,9 @@ func (h *Upgrade) Handle(ctx context.Context, a fleetapi.Action, ack acker.Acker return nil } - var uOpts []coordinator.UpgradeOpt + uOpts := []coordinator.UpgradeOpt{ + coordinator.WithRollback(action.Data.Rollback), + } if h.tamperProtectionFn() { // Find inputs that want to receive UPGRADE action // Endpoint needs to receive a signed UPGRADE action in order to be able to uncontain itself diff --git a/internal/pkg/agent/application/actions/handlers/handler_action_upgrade_test.go b/internal/pkg/agent/application/actions/handlers/handler_action_upgrade_test.go index 9291ebc3dd8..6987679c3ac 100644 --- a/internal/pkg/agent/application/actions/handlers/handler_action_upgrade_test.go +++ b/internal/pkg/agent/application/actions/handlers/handler_action_upgrade_test.go @@ -451,11 +451,19 @@ func TestEndpointPreUpgradeCallback(t *testing.T) { } upgradeCalledChan := make(chan struct{}) - mockCoordinator.EXPECT().Upgrade(mock.Anything, tc.upgradeAction.Data.Version, tc.upgradeAction.Data.SourceURI, mock.Anything, mock.Anything). - RunAndReturn(func(ctx context.Context, s string, s2 string, actionUpgrade *fleetapi.ActionUpgrade, opt ...coordinator.UpgradeOpt) error { - upgradeCalledChan <- struct{}{} - return tc.coordUpgradeErr - }) + if tc.shouldProxyToEndpoint { + mockCoordinator.EXPECT().Upgrade(mock.Anything, tc.upgradeAction.Data.Version, tc.upgradeAction.Data.SourceURI, mock.Anything, mock.AnythingOfType("coordinator.UpgradeOpt"), mock.AnythingOfType("coordinator.UpgradeOpt")). + RunAndReturn(func(ctx context.Context, s string, s2 string, actionUpgrade *fleetapi.ActionUpgrade, opt ...coordinator.UpgradeOpt) error { + upgradeCalledChan <- struct{}{} + return tc.coordUpgradeErr + }) + } else { + mockCoordinator.EXPECT().Upgrade(mock.Anything, tc.upgradeAction.Data.Version, tc.upgradeAction.Data.SourceURI, mock.Anything, mock.AnythingOfType("coordinator.UpgradeOpt")). + RunAndReturn(func(ctx context.Context, s string, s2 string, actionUpgrade *fleetapi.ActionUpgrade, opt ...coordinator.UpgradeOpt) error { + upgradeCalledChan <- struct{}{} + return tc.coordUpgradeErr + }) + } log, _ := logger.New("", false) u := NewUpgrade(log, mockCoordinator) diff --git a/internal/pkg/agent/application/application.go b/internal/pkg/agent/application/application.go index c141815eef8..4a35e1abc85 100644 --- a/internal/pkg/agent/application/application.go +++ b/internal/pkg/agent/application/application.go @@ -7,15 +7,13 @@ package application import ( "context" "fmt" - "os" "path/filepath" "time" "go.elastic.co/apm/v2" componentmonitoring "github.com/elastic/elastic-agent/internal/pkg/agent/application/monitoring/component" - "github.com/elastic/elastic-agent/internal/pkg/agent/install" - + "github.com/elastic/elastic-agent/internal/pkg/agent/application/upgrade/ttl" "github.com/elastic/go-ucfg" "github.com/elastic/elastic-agent-libs/logp" @@ -52,8 +50,8 @@ import ( ) type rollbacksSource interface { - Set(map[string]upgrade.TTLMarker) error - Get() (map[string]upgrade.TTLMarker, error) + Set(map[string]ttl.TTLMarker) error + Get() (map[string]ttl.TTLMarker, error) } // CfgOverrider allows for application driven overrides of configuration read from disk. @@ -135,7 +133,7 @@ func New( // monitoring is not supported in bootstrap mode https://github.com/elastic/elastic-agent/issues/1761 isMonitoringSupported := !disableMonitoring && cfg.Settings.V1MonitoringEnabled - availableRollbacksSource := upgrade.NewTTLMarkerRegistry(log, paths.Top()) + availableRollbacksSource := ttl.NewTTLMarkerRegistry(log, paths.Top()) if upgrade.IsUpgradeable() { // If we are not running in a container, check and normalize the install descriptor before we start the agent normalizeAgentInstalls(log, paths.Top(), time.Now(), initialUpdateMarker, availableRollbacksSource) @@ -255,7 +253,7 @@ func New( } // TODO: stop using global state - managed, err = newManagedConfigManager(ctx, log, agentInfo, cfg, store, runtime, fleetInitTimeout, paths.Top(), client, fleetAcker, actionAcker, retrier, stateStorage, actionQueue, upgrader) + managed, err = newManagedConfigManager(ctx, log, agentInfo, cfg, store, runtime, fleetInitTimeout, paths.Top(), client, fleetAcker, actionAcker, retrier, stateStorage, actionQueue, availableRollbacksSource, upgrader) if err != nil { return nil, nil, nil, err } @@ -331,57 +329,15 @@ func normalizeAgentInstalls(log *logger.Logger, topDir string, now time.Time, in } } - // check if we need to cleanup old agent installs - rollbacks, err := rollbackSource.Get() + absHomePath := paths.Home() + relHomePath, err := filepath.Rel(topDir, absHomePath) if err != nil { - log.Warnf("Error getting available rollbacks during startup check: %s", err) + log.Warnf("Error calculating home path %q relative to top path %q: %s", absHomePath, topDir, err) return } - - var versionedHomesToCleanup []string - for versionedHome, ttlMarker := range rollbacks { - - versionedHomeAbsPath := filepath.Join(topDir, versionedHome) - - if versionedHomeAbsPath == paths.HomeFrom(topDir) { - // skip the current install - log.Warnf("Found a TTL marker for the currently running agent at %s. Skipping cleanup...", versionedHome) - continue - } - - _, err = os.Stat(versionedHomeAbsPath) - if errors.Is(err, os.ErrNotExist) { - log.Warnf("Versioned home %s corresponding to agent TTL marker %+v is not found on disk", versionedHomeAbsPath, ttlMarker) - versionedHomesToCleanup = append(versionedHomesToCleanup, versionedHome) - continue - } - - if err != nil { - log.Warnf("error checking versioned home %s for agent install: %s", versionedHomeAbsPath, err.Error()) - continue - } - - if now.After(ttlMarker.ValidUntil) { - // the install directory exists but it's expired. Remove the files. - log.Infof("agent TTL marker %+v marks %q as expired, removing directory", ttlMarker, versionedHomeAbsPath) - if cleanupErr := install.RemoveBut(versionedHomeAbsPath, true); cleanupErr != nil { - log.Warnf("Error removing directory %q: %s", versionedHomeAbsPath, cleanupErr) - } else { - log.Infof("Directory %q was removed", versionedHomeAbsPath) - versionedHomesToCleanup = append(versionedHomesToCleanup, versionedHome) - } - } - } - - if len(versionedHomesToCleanup) > 0 { - log.Infof("removing install descriptor(s) for %v", versionedHomesToCleanup) - for _, versionedHomeToCleanup := range versionedHomesToCleanup { - delete(rollbacks, versionedHomeToCleanup) - } - err = rollbackSource.Set(rollbacks) - if err != nil { - log.Warnf("Error removing install descriptor(s): %s", err) - } + _, err = upgrade.CleanAvailableRollbacks(log, rollbackSource, topDir, relHomePath, upgrade.PreserveActiveUpgradeVersions(initialUpdateMarker, upgrade.CleanupExpiredRollbacks)) + if err != nil { + log.Warnf("Error cleaning available rollbacks: %s", err) } } diff --git a/internal/pkg/agent/application/application_test.go b/internal/pkg/agent/application/application_test.go index 0c7fb737f7f..63bbff1be73 100644 --- a/internal/pkg/agent/application/application_test.go +++ b/internal/pkg/agent/application/application_test.go @@ -21,6 +21,7 @@ import ( "github.com/elastic/elastic-agent/internal/pkg/agent/application/paths" "github.com/elastic/elastic-agent/internal/pkg/agent/application/upgrade" "github.com/elastic/elastic-agent/internal/pkg/agent/application/upgrade/details" + "github.com/elastic/elastic-agent/internal/pkg/agent/application/upgrade/ttl" "github.com/elastic/elastic-agent/internal/pkg/config" "github.com/elastic/elastic-agent/internal/pkg/testutils" "github.com/elastic/elastic-agent/pkg/core/logger/loggertest" @@ -337,7 +338,7 @@ func Test_normalizeInstallDescriptorAtStartup(t *testing.T) { oldAgentInstallPath := createFakeAgentInstall(t, topDir, "1.2.3", "oldversionhash", true) mockRollbackSource := newMockRollbacksSource(t) - mockRollbackSource.EXPECT().Get().Return(map[string]upgrade.TTLMarker{ + mockRollbackSource.EXPECT().Get().Return(map[string]ttl.TTLMarker{ oldAgentInstallPath: { Version: "1.2.3", Hash: "oldversionhash", @@ -366,7 +367,7 @@ func Test_normalizeInstallDescriptorAtStartup(t *testing.T) { } // expect code to clear the rollback - mockRollbackSource.EXPECT().Set(map[string]upgrade.TTLMarker{}).Return(nil) + mockRollbackSource.EXPECT().Set(map[string]ttl.TTLMarker{}).Return(nil) return updateMarker, mockRollbackSource }, postNormalizeAssertions: nil, @@ -379,7 +380,7 @@ func Test_normalizeInstallDescriptorAtStartup(t *testing.T) { mockRollbackSource := newMockRollbacksSource(t) nonExistingVersionedHome := filepath.Join("data", "thisdirectorydoesnotexist") - mockRollbackSource.EXPECT().Get().Return(map[string]upgrade.TTLMarker{ + mockRollbackSource.EXPECT().Get().Return(map[string]ttl.TTLMarker{ oldAgentInstallPath: { Version: "1.2.3", Hash: "oldversionhash", @@ -392,7 +393,7 @@ func Test_normalizeInstallDescriptorAtStartup(t *testing.T) { }, }, nil) - mockRollbackSource.EXPECT().Set(map[string]upgrade.TTLMarker{ + mockRollbackSource.EXPECT().Set(map[string]ttl.TTLMarker{ oldAgentInstallPath: { Version: "1.2.3", Hash: "oldversionhash", @@ -415,7 +416,7 @@ func Test_normalizeInstallDescriptorAtStartup(t *testing.T) { mockRollbackSource := newMockRollbacksSource(t) mockRollbackSource.EXPECT().Get().Return( - map[string]upgrade.TTLMarker{ + map[string]ttl.TTLMarker{ oldAgentInstallPath: { Version: "1.2.3", Hash: "oldver", @@ -425,7 +426,7 @@ func Test_normalizeInstallDescriptorAtStartup(t *testing.T) { nil, ) // expect removal of the existing ttlmarker - mockRollbackSource.EXPECT().Set(map[string]upgrade.TTLMarker{}).Return(nil) + mockRollbackSource.EXPECT().Set(map[string]ttl.TTLMarker{}).Return(nil) return nil, mockRollbackSource }, postNormalizeAssertions: func(t *testing.T, topDir string, _ *upgrade.UpdateMarker) { diff --git a/internal/pkg/agent/application/coordinator/coordinator.go b/internal/pkg/agent/application/coordinator/coordinator.go index 5fb4bc31876..6825889dd5a 100644 --- a/internal/pkg/agent/application/coordinator/coordinator.go +++ b/internal/pkg/agent/application/coordinator/coordinator.go @@ -16,6 +16,7 @@ import ( "github.com/elastic/elastic-agent/internal/pkg/core/backoff" "github.com/elastic/elastic-agent/internal/pkg/otel/translate" + "github.com/elastic/elastic-agent/internal/pkg/release" "go.opentelemetry.io/collector/component/componentstatus" @@ -850,6 +851,22 @@ func (c *Coordinator) Upgrade(ctx context.Context, version string, sourceURI str return c.upgradeMgr.AckAction(ctx, c.fleetAcker, action) } + if errors.Is(err, upgrade.ErrNoRollbacksAvailable) && action != nil && release.VersionWithSnapshot() == action.Data.Version { + // when manually rolling back the action store is not copied back, so it's likely that the rolled back agent + // will receive (again) the rollback action because it's using an ackToken from before the rollback action + // was received by the "upgraded" elastic-agent. + // This block here is to avoid setting an error state because the rollback requested no longer exist after + // having performed the rollback once. + // A better test would be to compare actionIDs but there's no way to persist the actionID of the rollback action + // from the upgraded agent to the rolled back agent (upgrade details is reset when the upgrade marker is deleted) + c.logger.Infow( + "Received a rollback action with the same version as current and no rollbacks available, ignoring the likely replayed action", + "action_id", action.ID()) + c.ClearOverrideState() + det.SetState(details.StateRollback) + return c.upgradeMgr.AckAction(ctx, c.fleetAcker, action) + } + c.logger.Errorw("upgrade failed", "error", logp.Error(err)) // If ErrInsufficientDiskSpace is in the error chain, we want to set the // the error to ErrInsufficientDiskSpace so that the error message is @@ -861,10 +878,17 @@ func (c *Coordinator) Upgrade(ctx context.Context, version string, sourceURI str det.Fail(err) return err } + if cb != nil { det.SetState(details.StateRestarting) c.ReExec(cb) } + + if uOpts.rollback { + // Ack the rollback action, since there's no restart callback returned, this is still run + return c.upgradeMgr.AckAction(ctx, c.fleetAcker, action) + } + return nil } diff --git a/internal/pkg/agent/application/coordinator/coordinator_unit_test.go b/internal/pkg/agent/application/coordinator/coordinator_unit_test.go index b15dbd1baef..a3cb16fb462 100644 --- a/internal/pkg/agent/application/coordinator/coordinator_unit_test.go +++ b/internal/pkg/agent/application/coordinator/coordinator_unit_test.go @@ -27,6 +27,7 @@ import ( "time" "github.com/elastic/elastic-agent-client/v7/pkg/proto" + "github.com/elastic/elastic-agent/internal/pkg/agent/application/upgrade/ttl" "github.com/elastic/elastic-agent/internal/pkg/fleetapi/acker" "github.com/elastic/elastic-agent/internal/pkg/testutils" @@ -468,7 +469,7 @@ func TestCoordinatorReportsInvalidPolicy(t *testing.T) { }() tmpDir := t.TempDir() - upgradeMgr, err := upgrade.NewUpgrader(log, &artifact.Config{}, nil, &info.AgentInfo{}, new(upgrade.AgentWatcherHelper), upgrade.NewTTLMarkerRegistry(nil, tmpDir)) + upgradeMgr, err := upgrade.NewUpgrader(log, &artifact.Config{}, nil, &info.AgentInfo{}, new(upgrade.AgentWatcherHelper), ttl.NewTTLMarkerRegistry(nil, tmpDir)) require.NoError(t, err, "errored when creating a new upgrader") // Channels have buffer length 1, so we don't have to run on multiple diff --git a/internal/pkg/agent/application/gateway/fleet/fleet_gateway.go b/internal/pkg/agent/application/gateway/fleet/fleet_gateway.go index 555735c7ae8..52409a506b9 100644 --- a/internal/pkg/agent/application/gateway/fleet/fleet_gateway.go +++ b/internal/pkg/agent/application/gateway/fleet/fleet_gateway.go @@ -14,6 +14,7 @@ import ( "github.com/elastic/elastic-agent-libs/logp" + "github.com/elastic/elastic-agent/internal/pkg/agent/application/upgrade/ttl" agentclient "github.com/elastic/elastic-agent/pkg/control/v2/client" eaclient "github.com/elastic/elastic-agent-client/v7/pkg/client" @@ -79,6 +80,10 @@ type stateStore interface { Action() fleetapi.Action } +type rollbacksSource interface { + Get() (map[string]ttl.TTLMarker, error) +} + type FleetGateway struct { log *logger.Logger client client.Sender @@ -92,6 +97,7 @@ type FleetGateway struct { stateFetcher StateFetcher errCh chan error actionCh chan []fleetapi.Action + rollbackSource rollbacksSource } // New creates a new fleet gateway @@ -103,6 +109,7 @@ func New( stateStore stateStore, stateFetcher StateFetcher, cfg *configuration.FleetCheckin, + source rollbacksSource, ) (*FleetGateway, error) { scheduler := scheduler.NewPeriodicJitter(defaultGatewaySettings.Duration, defaultGatewaySettings.Jitter) st := defaultGatewaySettings @@ -116,6 +123,7 @@ func New( acker, stateStore, stateFetcher, + source, ) } @@ -128,18 +136,20 @@ func newFleetGatewayWithScheduler( acker acker.Acker, stateStore stateStore, stateFetcher StateFetcher, + source rollbacksSource, ) (*FleetGateway, error) { return &FleetGateway{ - log: log, - client: client, - settings: settings, - agentInfo: agentInfo, - scheduler: scheduler, - acker: acker, - stateFetcher: stateFetcher, - stateStore: stateStore, - errCh: make(chan error), - actionCh: make(chan []fleetapi.Action, 1), + log: log, + client: client, + settings: settings, + agentInfo: agentInfo, + scheduler: scheduler, + acker: acker, + stateFetcher: stateFetcher, + stateStore: stateStore, + errCh: make(chan error), + actionCh: make(chan []fleetapi.Action, 1), + rollbackSource: source, }, nil } @@ -180,6 +190,7 @@ func (f *FleetGateway) Run(ctx context.Context) error { actions := make([]fleetapi.Action, len(resp.Actions)) copy(actions, resp.Actions) if len(actions) > 0 { + f.log.Infow("received new actions from Fleet checkin", "actions", actions) f.actionCh <- actions } } @@ -393,17 +404,41 @@ func (f *FleetGateway) execute(ctx context.Context) (*fleetapi.CheckinResponse, agentPolicyID := getPolicyID(action) policyRevisionIDX := getPolicyRevisionIDX(action) + // get available rollbacks + rollbacks, err := f.rollbackSource.Get() + if err != nil { + f.log.Warnf("error getting available rollbacks: %s", err.Error()) + // this should already be nil but let's make sure that we don't include rollbacks in checkin body when encountering errors + rollbacks = nil + } + + var validRollbacks []fleetapi.CheckinRollback + if len(rollbacks) > 0 { + now := time.Now() + validRollbacks = make([]fleetapi.CheckinRollback, 0, len(rollbacks)) + for _, rollback := range rollbacks { + if rollback.ValidUntil.After(now) { + // map the `ttl.Marker` to the `fleetapi.CheckinRollback` + validRollbacks = append(validRollbacks, fleetapi.CheckinRollback{ + Version: rollback.Version, + ValidUntil: rollback.ValidUntil, + }) + } + } + } + // checkin cmd := fleetapi.NewCheckinCmd(f.agentInfo, f.client) req := &fleetapi.CheckinRequest{ - AckToken: ackToken, - Metadata: ecsMeta, - Status: agentStateToString(state.State), - Message: state.Message, - Components: components, - UpgradeDetails: state.UpgradeDetails, - AgentPolicyID: agentPolicyID, - PolicyRevisionIDX: policyRevisionIDX, + AckToken: ackToken, + Metadata: ecsMeta, + Status: agentStateToString(state.State), + Message: state.Message, + Components: components, + UpgradeDetails: state.UpgradeDetails, + AgentPolicyID: agentPolicyID, + PolicyRevisionIDX: policyRevisionIDX, + AvailableRollbacks: validRollbacks, } resp, took, err := cmd.Execute(stateCtx, req) diff --git a/internal/pkg/agent/application/gateway/fleet/fleet_gateway_test.go b/internal/pkg/agent/application/gateway/fleet/fleet_gateway_test.go index f79d36d1039..67cb657b33e 100644 --- a/internal/pkg/agent/application/gateway/fleet/fleet_gateway_test.go +++ b/internal/pkg/agent/application/gateway/fleet/fleet_gateway_test.go @@ -22,11 +22,12 @@ import ( "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" - "github.com/elastic/elastic-agent-libs/logp" - "github.com/open-telemetry/opentelemetry-collector-contrib/pkg/status" "go.opentelemetry.io/collector/component/componentstatus" + "github.com/elastic/elastic-agent-libs/logp" + "github.com/elastic/elastic-agent/internal/pkg/agent/application/upgrade/ttl" + eaclient "github.com/elastic/elastic-agent-client/v7/pkg/client" "github.com/elastic/elastic-agent/internal/pkg/agent/application/coordinator" "github.com/elastic/elastic-agent/internal/pkg/agent/application/upgrade/details" @@ -92,16 +93,10 @@ func withGateway(agentInfo agentInfo, settings *fleetGatewaySettings, fn withGat stateStore := newStateStore(t, log) - gateway, err := newFleetGatewayWithScheduler( - log, - settings, - agentInfo, - client, - scheduler, - noop.New(), - stateStore, - NewCheckinStateFetcher(emptyStateFetcher), - ) + mockRollbacksSrc := newMockRollbacksSource(t) + mockRollbacksSrc.EXPECT().Get().Return(nil, nil) + + gateway, err := newFleetGatewayWithScheduler(log, settings, agentInfo, client, scheduler, noop.New(), stateStore, NewCheckinStateFetcher(emptyStateFetcher), mockRollbacksSrc) require.NoError(t, err) @@ -231,16 +226,10 @@ func TestFleetGateway(t *testing.T) { log, _ := logger.New("tst", false) stateStore := newStateStore(t, log) - gateway, err := newFleetGatewayWithScheduler( - log, - settings, - agentInfo, - client, - scheduler, - noop.New(), - stateStore, - NewCheckinStateFetcher(emptyStateFetcher), - ) + mockRollbacksSrc := newMockRollbacksSource(t) + mockRollbacksSrc.EXPECT().Get().Return(nil, nil) + + gateway, err := newFleetGatewayWithScheduler(log, settings, agentInfo, client, scheduler, noop.New(), stateStore, NewCheckinStateFetcher(emptyStateFetcher), mockRollbacksSrc) require.NoError(t, err) waitFn := ackSeq( @@ -280,19 +269,13 @@ func TestFleetGateway(t *testing.T) { log, _ := logger.New("tst", false) stateStore := newStateStore(t, log) - gateway, err := newFleetGatewayWithScheduler( - log, - &fleetGatewaySettings{ - Duration: d, - Backoff: &backoffSettings{Init: 1 * time.Second, Max: 30 * time.Second}, - }, - agentInfo, - client, - scheduler, - noop.New(), - stateStore, - NewCheckinStateFetcher(emptyStateFetcher), - ) + mockRollbacksSrc := newMockRollbacksSource(t) + mockRollbacksSrc.EXPECT().Get().Return(nil, nil) + + gateway, err := newFleetGatewayWithScheduler(log, &fleetGatewaySettings{ + Duration: d, + Backoff: &backoffSettings{Init: 1 * time.Second, Max: 30 * time.Second}, + }, agentInfo, client, scheduler, noop.New(), stateStore, NewCheckinStateFetcher(emptyStateFetcher), mockRollbacksSrc) require.NoError(t, err) ch2 := client.Answer(func(_ context.Context, headers http.Header, body io.Reader) (*http.Response, error) { @@ -342,16 +325,10 @@ func TestFleetGateway(t *testing.T) { } } - gateway, err := newFleetGatewayWithScheduler( - log, - settings, - agentInfo, - client, - scheduler, - noop.New(), - stateStore, - NewCheckinStateFetcher(stateFetcher), - ) + mockRollbacksSrc := newMockRollbacksSource(t) + mockRollbacksSrc.EXPECT().Get().Return(nil, nil) + + gateway, err := newFleetGatewayWithScheduler(log, settings, agentInfo, client, scheduler, noop.New(), stateStore, NewCheckinStateFetcher(stateFetcher), mockRollbacksSrc) require.NoError(t, err) @@ -411,16 +388,10 @@ func TestFleetGateway(t *testing.T) { err := stateStore.Save() require.NoError(t, err) - gateway, err := newFleetGatewayWithScheduler( - log, - settings, - agentInfo, - client, - scheduler, - noop.New(), - stateStore, - NewCheckinStateFetcher(emptyStateFetcher), - ) + mockRollbacksSrc := newMockRollbacksSource(t) + mockRollbacksSrc.EXPECT().Get().Return(nil, nil) + + gateway, err := newFleetGatewayWithScheduler(log, settings, agentInfo, client, scheduler, noop.New(), stateStore, NewCheckinStateFetcher(emptyStateFetcher), mockRollbacksSrc) require.NoError(t, err) waitFn := ackSeq( @@ -469,19 +440,13 @@ func TestFleetGateway(t *testing.T) { stateFetcher := NewFastCheckinStateFetcher(log, emptyStateFetcher, stateChannel) - gateway, err := newFleetGatewayWithScheduler( - log, - &fleetGatewaySettings{ - Duration: 5 * time.Second, - Backoff: &backoffSettings{Init: 10 * time.Millisecond, Max: 30 * time.Second}, - }, - agentInfo, - client, - scheduler, - noop.New(), - stateStore, - stateFetcher, - ) + mockRollbacksSrc := newMockRollbacksSource(t) + mockRollbacksSrc.EXPECT().Get().Return(nil, nil) + + gateway, err := newFleetGatewayWithScheduler(log, &fleetGatewaySettings{ + Duration: 5 * time.Second, + Backoff: &backoffSettings{Init: 10 * time.Millisecond, Max: 30 * time.Second}, + }, agentInfo, client, scheduler, noop.New(), stateStore, stateFetcher, mockRollbacksSrc) require.NoError(t, err) requestSent := make(chan struct{}, 10) @@ -1144,3 +1109,118 @@ func TestConvertToCheckingComponents(t *testing.T) { }) } } + +func TestAvailableRollbacks(t *testing.T) { + testcases := []struct { + name string + setup func(t *testing.T, rbSource *mockRollbacksSource, client *testingClient) + wantErr assert.ErrorAssertionFunc + assertCheckinResponse func(t *testing.T, resp *fleetapi.CheckinResponse) + }{ + { + name: "no available rollbacks - normal checkin", + setup: func(t *testing.T, rbSource *mockRollbacksSource, client *testingClient) { + rbSource.EXPECT().Get().Return(nil, nil) + client.Answer(func(_ context.Context, _ http.Header, body io.Reader) (*http.Response, error) { + unmarshaled := map[string]interface{}{} + err := json.NewDecoder(body).Decode(&unmarshaled) + assert.NoError(t, err, "error decoding checkin body") + assert.NotContains(t, unmarshaled, "available_rollbacks") + return &http.Response{ + StatusCode: http.StatusOK, + Body: io.NopCloser(strings.NewReader("{}")), + }, + nil + }) + }, + wantErr: assert.NoError, + assertCheckinResponse: nil, + }, + { + name: "valid available rollbacks - assert key and value", + setup: func(t *testing.T, rbSource *mockRollbacksSource, client *testingClient) { + + validUntil := time.Now().UTC().Add(time.Minute) + // truncate to the second to avoid different precision due to marshal/unmarshal + validUntil = validUntil.Truncate(time.Second) + + rbSource.EXPECT().Get().Return(map[string]ttl.TTLMarker{ + "data/elastic-agent-1.2.3-abcdef": { + Version: "1.2.3", + Hash: "abcdef", + ValidUntil: validUntil, + }, + }, nil) + client.Answer(func(_ context.Context, _ http.Header, body io.Reader) (*http.Response, error) { + unmarshaled := map[string]json.RawMessage{} + err := json.NewDecoder(body).Decode(&unmarshaled) + assert.NoError(t, err, "error decoding checkin body") + if assert.Contains(t, unmarshaled, "available_rollbacks") { + // verify that we got the correct data + var actual []fleetapi.CheckinRollback + err = json.Unmarshal(unmarshaled["available_rollbacks"], &actual) + require.NoError(t, err, "error decoding available rollbacks from checkin body") + + expected := []fleetapi.CheckinRollback{{ + Version: "1.2.3", + ValidUntil: validUntil, + }} + assert.Equal(t, expected, actual) + } + + return &http.Response{ + StatusCode: http.StatusOK, + Body: io.NopCloser(strings.NewReader("{}")), + }, + nil + }) + }, + wantErr: assert.NoError, + assertCheckinResponse: nil, + }, + { + name: "Error getting rollbacks should not make the checkin error out, just omit available_rollbacks", + setup: func(t *testing.T, rbSource *mockRollbacksSource, client *testingClient) { + rbSource.EXPECT().Get().Return(nil, errors.New("some error getting rollbacks")) + client.Answer(func(_ context.Context, _ http.Header, body io.Reader) (*http.Response, error) { + unmarshaled := map[string]interface{}{} + err := json.NewDecoder(body).Decode(&unmarshaled) + assert.NoError(t, err, "error decoding checkin body") + assert.NotContains(t, unmarshaled, "available_rollbacks") + + return &http.Response{ + StatusCode: http.StatusOK, + Body: io.NopCloser(strings.NewReader("{}")), + }, + nil + }) + }, + wantErr: assert.NoError, + assertCheckinResponse: nil, + }, + } + for _, tc := range testcases { + t.Run(tc.name, func(t *testing.T) { + + stepperScheduler := scheduler.NewStepper() + testClient := newTestingClient() + log, _ := logger.New("fleet_gateway", false) + + stateStore := newStateStore(t, log) + + mockRollbacksSrc := newMockRollbacksSource(t) + + mockAgentInfo := new(testAgentInfo) + + tc.setup(t, mockRollbacksSrc, testClient) + + gateway, err := newFleetGatewayWithScheduler(log, defaultGatewaySettings, mockAgentInfo, testClient, stepperScheduler, noop.New(), stateStore, NewCheckinStateFetcher(emptyStateFetcher), mockRollbacksSrc) + require.NoError(t, err, "error creating gateway") + checkinResponse, _, err := gateway.execute(t.Context()) + tc.wantErr(t, err) + if tc.assertCheckinResponse != nil { + tc.assertCheckinResponse(t, checkinResponse) + } + }) + } +} diff --git a/internal/pkg/agent/application/gateway/fleet/mocks.go b/internal/pkg/agent/application/gateway/fleet/mocks.go new file mode 100644 index 00000000000..a4c2c9f0356 --- /dev/null +++ b/internal/pkg/agent/application/gateway/fleet/mocks.go @@ -0,0 +1,97 @@ +// Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one +// or more contributor license agreements. Licensed under the Elastic License 2.0; +// you may not use this file except in compliance with the Elastic License 2.0. + +// Code generated by mockery; DO NOT EDIT. +// github.com/vektra/mockery +// template: testify + +package fleet + +import ( + mock "github.com/stretchr/testify/mock" + + "github.com/elastic/elastic-agent/internal/pkg/agent/application/upgrade/ttl" +) + +// newMockRollbacksSource creates a new instance of mockRollbacksSource. It also registers a testing interface on the mock and a cleanup function to assert the mocks expectations. +// The first argument is typically a *testing.T value. +func newMockRollbacksSource(t interface { + mock.TestingT + Cleanup(func()) +}) *mockRollbacksSource { + mock := &mockRollbacksSource{} + mock.Mock.Test(t) + + t.Cleanup(func() { mock.AssertExpectations(t) }) + + return mock +} + +// mockRollbacksSource is an autogenerated mock type for the rollbacksSource type +type mockRollbacksSource struct { + mock.Mock +} + +type mockRollbacksSource_Expecter struct { + mock *mock.Mock +} + +func (_m *mockRollbacksSource) EXPECT() *mockRollbacksSource_Expecter { + return &mockRollbacksSource_Expecter{mock: &_m.Mock} +} + +// Get provides a mock function for the type mockRollbacksSource +func (_mock *mockRollbacksSource) Get() (map[string]ttl.TTLMarker, error) { + ret := _mock.Called() + + if len(ret) == 0 { + panic("no return value specified for Get") + } + + var r0 map[string]ttl.TTLMarker + var r1 error + if returnFunc, ok := ret.Get(0).(func() (map[string]ttl.TTLMarker, error)); ok { + return returnFunc() + } + if returnFunc, ok := ret.Get(0).(func() map[string]ttl.TTLMarker); ok { + r0 = returnFunc() + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(map[string]ttl.TTLMarker) + } + } + if returnFunc, ok := ret.Get(1).(func() error); ok { + r1 = returnFunc() + } else { + r1 = ret.Error(1) + } + return r0, r1 +} + +// mockRollbacksSource_Get_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'Get' +type mockRollbacksSource_Get_Call struct { + *mock.Call +} + +// Get is a helper method to define mock.On call +func (_e *mockRollbacksSource_Expecter) Get() *mockRollbacksSource_Get_Call { + return &mockRollbacksSource_Get_Call{Call: _e.mock.On("Get")} +} + +func (_c *mockRollbacksSource_Get_Call) Run(run func()) *mockRollbacksSource_Get_Call { + _c.Call.Run(func(args mock.Arguments) { + run() + }) + return _c +} + +func (_c *mockRollbacksSource_Get_Call) Return(stringToTTLMarker map[string]ttl.TTLMarker, err error) *mockRollbacksSource_Get_Call { + _c.Call.Return(stringToTTLMarker, err) + return _c +} + +func (_c *mockRollbacksSource_Get_Call) RunAndReturn(run func() (map[string]ttl.TTLMarker, error)) *mockRollbacksSource_Get_Call { + _c.Call.Return(run) + return _c +} diff --git a/internal/pkg/agent/application/managed_mode.go b/internal/pkg/agent/application/managed_mode.go index 8339ff8900e..6295e042557 100644 --- a/internal/pkg/agent/application/managed_mode.go +++ b/internal/pkg/agent/application/managed_mode.go @@ -38,65 +38,51 @@ import ( const dispatchFlushInterval = time.Minute * 5 type managedConfigManager struct { - log *logger.Logger - agentInfo info.Agent - cfg *configuration.Configuration - client *remote.Client - store storage.Store - stateStore *store.StateStore - actionQueue *queue.ActionQueue - dispatcher *dispatcher.ActionDispatcher - runtime *runtime.Manager - coord *coordinator.Coordinator - fleetInitTimeout time.Duration - initialClientSetters []actions.ClientSetter - fleetAcker *fleet.Acker - actionAcker acker.Acker - retrier *retrier.Retrier + log *logger.Logger + agentInfo info.Agent + cfg *configuration.Configuration + client *remote.Client + store storage.Store + stateStore *store.StateStore + actionQueue *queue.ActionQueue + dispatcher *dispatcher.ActionDispatcher + runtime *runtime.Manager + coord *coordinator.Coordinator + fleetInitTimeout time.Duration + initialClientSetters []actions.ClientSetter + fleetAcker *fleet.Acker + actionAcker acker.Acker + retrier *retrier.Retrier + availableRollbacksSource rollbacksSource ch chan coordinator.ConfigChange errCh chan error } -func newManagedConfigManager( - ctx context.Context, - log *logger.Logger, - agentInfo info.Agent, - cfg *configuration.Configuration, - storeSaver storage.Store, - runtime *runtime.Manager, - fleetInitTimeout time.Duration, - topPath string, - client *remote.Client, - fleetAcker *fleet.Acker, - actionAcker acker.Acker, - retrier *retrier.Retrier, - stateStore *store.StateStore, - actionQueue *queue.ActionQueue, - clientSetters ...actions.ClientSetter, -) (*managedConfigManager, error) { +func newManagedConfigManager(ctx context.Context, log *logger.Logger, agentInfo info.Agent, cfg *configuration.Configuration, storeSaver storage.Store, runtime *runtime.Manager, fleetInitTimeout time.Duration, topPath string, client *remote.Client, fleetAcker *fleet.Acker, actionAcker acker.Acker, retrier *retrier.Retrier, stateStore *store.StateStore, actionQueue *queue.ActionQueue, source rollbacksSource, clientSetters ...actions.ClientSetter) (*managedConfigManager, error) { actionDispatcher, err := dispatcher.New(log, topPath, handlers.NewDefault(log), actionQueue) if err != nil { return nil, fmt.Errorf("unable to initialize action dispatcher: %w", err) } return &managedConfigManager{ - log: log, - agentInfo: agentInfo, - cfg: cfg, - client: client, - store: storeSaver, - stateStore: stateStore, - actionQueue: actionQueue, - dispatcher: actionDispatcher, - runtime: runtime, - fleetInitTimeout: fleetInitTimeout, - ch: make(chan coordinator.ConfigChange), - errCh: make(chan error), - initialClientSetters: clientSetters, - fleetAcker: fleetAcker, - actionAcker: actionAcker, - retrier: retrier, + log: log, + agentInfo: agentInfo, + cfg: cfg, + client: client, + store: storeSaver, + stateStore: stateStore, + actionQueue: actionQueue, + dispatcher: actionDispatcher, + runtime: runtime, + fleetInitTimeout: fleetInitTimeout, + ch: make(chan coordinator.ConfigChange), + errCh: make(chan error), + initialClientSetters: clientSetters, + fleetAcker: fleetAcker, + actionAcker: actionAcker, + retrier: retrier, + availableRollbacksSource: source, }, nil } @@ -185,6 +171,7 @@ func (m *managedConfigManager) Run(ctx context.Context) error { m.stateStore, stateFetcher, m.cfg.Fleet.Checkin, + m.availableRollbacksSource, ) if err != nil { return err diff --git a/internal/pkg/agent/application/mocks.go b/internal/pkg/agent/application/mocks.go index e2dd99b9e9c..c2c7703c61b 100644 --- a/internal/pkg/agent/application/mocks.go +++ b/internal/pkg/agent/application/mocks.go @@ -11,7 +11,7 @@ package application import ( mock "github.com/stretchr/testify/mock" - "github.com/elastic/elastic-agent/internal/pkg/agent/application/upgrade" + "github.com/elastic/elastic-agent/internal/pkg/agent/application/upgrade/ttl" ) // newMockRollbacksSource creates a new instance of mockRollbacksSource. It also registers a testing interface on the mock and a cleanup function to assert the mocks expectations. @@ -42,23 +42,23 @@ func (_m *mockRollbacksSource) EXPECT() *mockRollbacksSource_Expecter { } // Get provides a mock function for the type mockRollbacksSource -func (_mock *mockRollbacksSource) Get() (map[string]upgrade.TTLMarker, error) { +func (_mock *mockRollbacksSource) Get() (map[string]ttl.TTLMarker, error) { ret := _mock.Called() if len(ret) == 0 { panic("no return value specified for Get") } - var r0 map[string]upgrade.TTLMarker + var r0 map[string]ttl.TTLMarker var r1 error - if returnFunc, ok := ret.Get(0).(func() (map[string]upgrade.TTLMarker, error)); ok { + if returnFunc, ok := ret.Get(0).(func() (map[string]ttl.TTLMarker, error)); ok { return returnFunc() } - if returnFunc, ok := ret.Get(0).(func() map[string]upgrade.TTLMarker); ok { + if returnFunc, ok := ret.Get(0).(func() map[string]ttl.TTLMarker); ok { r0 = returnFunc() } else { if ret.Get(0) != nil { - r0 = ret.Get(0).(map[string]upgrade.TTLMarker) + r0 = ret.Get(0).(map[string]ttl.TTLMarker) } } if returnFunc, ok := ret.Get(1).(func() error); ok { @@ -86,18 +86,18 @@ func (_c *mockRollbacksSource_Get_Call) Run(run func()) *mockRollbacksSource_Get return _c } -func (_c *mockRollbacksSource_Get_Call) Return(stringToTTLMarker map[string]upgrade.TTLMarker, err error) *mockRollbacksSource_Get_Call { +func (_c *mockRollbacksSource_Get_Call) Return(stringToTTLMarker map[string]ttl.TTLMarker, err error) *mockRollbacksSource_Get_Call { _c.Call.Return(stringToTTLMarker, err) return _c } -func (_c *mockRollbacksSource_Get_Call) RunAndReturn(run func() (map[string]upgrade.TTLMarker, error)) *mockRollbacksSource_Get_Call { +func (_c *mockRollbacksSource_Get_Call) RunAndReturn(run func() (map[string]ttl.TTLMarker, error)) *mockRollbacksSource_Get_Call { _c.Call.Return(run) return _c } // Set provides a mock function for the type mockRollbacksSource -func (_mock *mockRollbacksSource) Set(stringToTTLMarker map[string]upgrade.TTLMarker) error { +func (_mock *mockRollbacksSource) Set(stringToTTLMarker map[string]ttl.TTLMarker) error { ret := _mock.Called(stringToTTLMarker) if len(ret) == 0 { @@ -105,7 +105,7 @@ func (_mock *mockRollbacksSource) Set(stringToTTLMarker map[string]upgrade.TTLMa } var r0 error - if returnFunc, ok := ret.Get(0).(func(map[string]upgrade.TTLMarker) error); ok { + if returnFunc, ok := ret.Get(0).(func(map[string]ttl.TTLMarker) error); ok { r0 = returnFunc(stringToTTLMarker) } else { r0 = ret.Error(0) @@ -119,16 +119,16 @@ type mockRollbacksSource_Set_Call struct { } // Set is a helper method to define mock.On call -// - stringToTTLMarker map[string]upgrade.TTLMarker +// - stringToTTLMarker map[string]ttl.TTLMarker func (_e *mockRollbacksSource_Expecter) Set(stringToTTLMarker interface{}) *mockRollbacksSource_Set_Call { return &mockRollbacksSource_Set_Call{Call: _e.mock.On("Set", stringToTTLMarker)} } -func (_c *mockRollbacksSource_Set_Call) Run(run func(stringToTTLMarker map[string]upgrade.TTLMarker)) *mockRollbacksSource_Set_Call { +func (_c *mockRollbacksSource_Set_Call) Run(run func(stringToTTLMarker map[string]ttl.TTLMarker)) *mockRollbacksSource_Set_Call { _c.Call.Run(func(args mock.Arguments) { - var arg0 map[string]upgrade.TTLMarker + var arg0 map[string]ttl.TTLMarker if args[0] != nil { - arg0 = args[0].(map[string]upgrade.TTLMarker) + arg0 = args[0].(map[string]ttl.TTLMarker) } run( arg0, @@ -142,7 +142,7 @@ func (_c *mockRollbacksSource_Set_Call) Return(err error) *mockRollbacksSource_S return _c } -func (_c *mockRollbacksSource_Set_Call) RunAndReturn(run func(stringToTTLMarker map[string]upgrade.TTLMarker) error) *mockRollbacksSource_Set_Call { +func (_c *mockRollbacksSource_Set_Call) RunAndReturn(run func(stringToTTLMarker map[string]ttl.TTLMarker) error) *mockRollbacksSource_Set_Call { _c.Call.Return(run) return _c } diff --git a/internal/pkg/agent/application/upgrade/manual_rollback.go b/internal/pkg/agent/application/upgrade/manual_rollback.go index 664ba586da1..8f9d8231e9c 100644 --- a/internal/pkg/agent/application/upgrade/manual_rollback.go +++ b/internal/pkg/agent/application/upgrade/manual_rollback.go @@ -16,6 +16,8 @@ import ( "github.com/elastic/elastic-agent/internal/pkg/agent/application/paths" "github.com/elastic/elastic-agent/internal/pkg/agent/application/reexec" "github.com/elastic/elastic-agent/internal/pkg/agent/application/upgrade/details" + "github.com/elastic/elastic-agent/internal/pkg/agent/application/upgrade/ttl" + "github.com/elastic/elastic-agent/internal/pkg/agent/install" "github.com/elastic/elastic-agent/internal/pkg/fleetapi" "github.com/elastic/elastic-agent/internal/pkg/release" "github.com/elastic/elastic-agent/pkg/core/logger" @@ -44,12 +46,12 @@ func (u *Upgrader) rollbackToPreviousVersion(ctx context.Context, topDir string, if errors.Is(err, os.ErrNotExist) { // there is no upgrade marker (the rollback was requested after the watcher grace period had elapsed), we need // to extract available rollbacks from agent installs - watcherExecutable, versionedHomeToRollbackTo, err = rollbackUsingAgentInstalls(u.log, u.watcherHelper, u.availableRollbacksSource, topDir, now, version, u.markUpgrade) + watcherExecutable, versionedHomeToRollbackTo, err = rollbackUsingAgentInstalls(u.log, u.watcherHelper, u.availableRollbacksSource, topDir, now, version, u.markUpgrade, action) } else { // If upgrade marker is available, we need to gracefully stop any watcher process, read the available rollbacks from // the upgrade marker and then proceed with rollback updateMarkerExistsBeforeRollback = true - watcherExecutable, versionedHomeToRollbackTo, err = rollbackUsingUpgradeMarker(ctx, u.log, u.watcherHelper, topDir, now, version) + watcherExecutable, versionedHomeToRollbackTo, err = rollbackUsingUpgradeMarker(ctx, u.log, u.watcherHelper, topDir, now, version, action) } if err != nil { @@ -80,7 +82,7 @@ func (u *Upgrader) rollbackToPreviousVersion(ctx context.Context, topDir string, return nil, nil } -func rollbackUsingAgentInstalls(log *logger.Logger, watcherHelper WatcherHelper, source availableRollbacksSource, topDir string, now time.Time, rollbackVersion string, markUpgrade markUpgradeFunc) (string, string, error) { +func rollbackUsingAgentInstalls(log *logger.Logger, watcherHelper WatcherHelper, source availableRollbacksSource, topDir string, now time.Time, rollbackVersion string, markUpgrade markUpgradeFunc, action *fleetapi.ActionUpgrade) (string, string, error) { // read the available installs availableRollbacks, err := source.Get() if err != nil { @@ -88,7 +90,7 @@ func rollbackUsingAgentInstalls(log *logger.Logger, watcherHelper WatcherHelper, } // check for the version we want to rollback to var targetInstall string - var targetTTLMarker TTLMarker + var targetTTLMarker ttl.TTLMarker for versionedHome, ttlMarker := range availableRollbacks { if ttlMarker.Version == rollbackVersion && now.Before(ttlMarker.ValidUntil) { // found a valid target @@ -127,8 +129,12 @@ func rollbackUsingAgentInstalls(log *logger.Logger, watcherHelper WatcherHelper, versionedHome: targetInstall, } - upgradeDetails := details.NewDetails(release.VersionWithSnapshot(), details.StateRequested, "" /*action.ID*/) - err = markUpgrade(log, paths.DataFrom(topDir), now, curAgentInstall, prevAgentInstall, nil /*action*/, upgradeDetails, nil) + actionId := "" + if action != nil { + actionId = action.ActionID + } + upgradeDetails := details.NewDetails(release.VersionWithSnapshot(), details.StateRequested, actionId) + err = markUpgrade(log, paths.DataFrom(topDir), now, curAgentInstall, prevAgentInstall, action, upgradeDetails, nil) if err != nil { return "", "", fmt.Errorf("creating upgrade marker: %w", err) } @@ -138,7 +144,7 @@ func rollbackUsingAgentInstalls(log *logger.Logger, watcherHelper WatcherHelper, return watcherExecutable, targetInstall, nil } -func rollbackUsingUpgradeMarker(ctx context.Context, log *logger.Logger, watcherHelper WatcherHelper, topDir string, now time.Time, version string) (string, string, error) { +func rollbackUsingUpgradeMarker(ctx context.Context, log *logger.Logger, watcherHelper WatcherHelper, topDir string, now time.Time, version string, _ *fleetapi.ActionUpgrade) (string, string, error) { // read the upgrade marker updateMarker, err := LoadMarker(paths.DataFrom(topDir)) if err != nil { @@ -238,7 +244,7 @@ func extractAgentInstallsFromMarker(updateMarker *UpdateMarker) (previous agentI return previous, current, nil } -func getAvailableRollbacks(rollbackWindow time.Duration, now time.Time, currentVersion string, parsedCurrentVersion *version.ParsedSemVer, currentVersionedHome string, currentHash string) map[string]TTLMarker { +func getAvailableRollbacks(rollbackWindow time.Duration, now time.Time, currentVersion string, parsedCurrentVersion *version.ParsedSemVer, currentVersionedHome string, currentHash string) map[string]ttl.TTLMarker { if rollbackWindow == disableRollbackWindow { // if there's no rollback window it means that no rollback should survive the watcher cleanup at the end of the grace period. return nil @@ -251,8 +257,8 @@ func getAvailableRollbacks(rollbackWindow time.Duration, now time.Time, currentV // when multiple rollbacks will be supported, read the existing descriptor // at this stage we can get by with a single rollback - res := make(map[string]TTLMarker, 1) - res[currentVersionedHome] = TTLMarker{ + res := make(map[string]ttl.TTLMarker, 1) + res[currentVersionedHome] = ttl.TTLMarker{ Version: currentVersion, Hash: currentHash, ValidUntil: now.Add(rollbackWindow), @@ -260,3 +266,90 @@ func getAvailableRollbacks(rollbackWindow time.Duration, now time.Time, currentV return res } + +type RollbackCleanupFilter func(log *logger.Logger, now time.Time, versionedHome string, ttl ttl.TTLMarker) bool + +// CleanupAllRollbacks is a filter that will match all available rollbacks +func CleanupAllRollbacks(_ *logger.Logger, _ time.Time, _ string, _ ttl.TTLMarker) bool { + return true +} + +// CleanupExpiredRollbacks is a filter that will match all expired rollback targets +func CleanupExpiredRollbacks(log *logger.Logger, now time.Time, versionedHome string, ttl ttl.TTLMarker) bool { + if now.After(ttl.ValidUntil) { + // the install directory exists but it's expired. Remove the files. + log.Infof("agent TTL marker %+v marks %q as expired, removing directory", ttl, versionedHome) + return true + } + + return false +} + +// PreserveActiveUpgradeVersions is a decorator to a filter function that will preserve versions involved in an ongoing upgrade +func PreserveActiveUpgradeVersions(marker *UpdateMarker, innerFilter RollbackCleanupFilter) RollbackCleanupFilter { + return func(log *logger.Logger, now time.Time, versionedHome string, ttl ttl.TTLMarker) bool { + if marker != nil && !IsTerminalState(marker) { + // we are in the middle of an active upgrade + if marker.PrevVersionedHome == versionedHome { + // if the versionedHome matches the old versioned home, skip that cleanup. + return false + } + } + return innerFilter(log, now, versionedHome, ttl) + } +} + +// CleanAvailableRollbacks will remove the extra agent installs that can be used as manual rollback target. Invoked before triggering +// an update in order to free disk space for the new agent version or whenever a cleanup should happen. +// This function has basic protection for the current home and it will remove any available rollback for which the filter function +// returns true. +// This function will return the leftover available rollbacks that will survive the cleanup, can be used to schedule another launch +// of the cleanup in the future +func CleanAvailableRollbacks(log *logger.Logger, source availableRollbacksSource, topDir string, currentHomeRelPath string, filter RollbackCleanupFilter) (map[string]ttl.TTLMarker, error) { + rollbacks, err := source.Get() + if err != nil { + return nil, fmt.Errorf("unable to get available rollbacks: %w", err) + } + + if len(rollbacks) == 0 { + log.Debugf("No available rollbacks returned, exiting cleanup") + return nil, nil + } + + // Clean the currentHomeRel path to normalize it + currentHomeRelPath = filepath.Clean(currentHomeRelPath) + + log.Debugw("preparing to cleanup rollbacks", "rollbacks", rollbacks) + var aggregateErr error + now := time.Now().UTC() + + leftoverRollbacks := map[string]ttl.TTLMarker{} + + for versionedHome, ttlMarker := range rollbacks { + + if currentHomeRelPath == filepath.Clean(versionedHome) { + log.Warnf("skipping cleanup of available rollback located in %q as it matches the current home", versionedHome) + continue + } + + versionedHomeAbsPath := filepath.Join(topDir, versionedHome) + _, err = os.Stat(versionedHomeAbsPath) + if errors.Is(err, os.ErrNotExist) { + log.Warnf("Versioned home %s corresponding to agent TTL marker %+v is not found on disk", versionedHomeAbsPath, ttlMarker) + continue + } + + if filter(log, now, versionedHome, ttlMarker) { + if cleanupErr := install.RemoveBut(versionedHomeAbsPath, true); cleanupErr != nil { + aggregateErr = errors.Join(aggregateErr, fmt.Errorf("removing directory %q: %w", versionedHomeAbsPath, cleanupErr)) + } + } else { + leftoverRollbacks[versionedHome] = ttlMarker + } + } + err = source.Set(leftoverRollbacks) + if err != nil { + aggregateErr = errors.Join(aggregateErr, fmt.Errorf("unable to update available rollbacks on source: %w", err)) + } + return leftoverRollbacks, aggregateErr +} diff --git a/internal/pkg/agent/application/upgrade/manual_rollback_test.go b/internal/pkg/agent/application/upgrade/manual_rollback_test.go index ee2e840a7f6..7e63ae4fa2c 100644 --- a/internal/pkg/agent/application/upgrade/manual_rollback_test.go +++ b/internal/pkg/agent/application/upgrade/manual_rollback_test.go @@ -23,8 +23,10 @@ import ( "github.com/elastic/elastic-agent/internal/pkg/agent/application/paths" "github.com/elastic/elastic-agent/internal/pkg/agent/application/upgrade/artifact" "github.com/elastic/elastic-agent/internal/pkg/agent/application/upgrade/details" + "github.com/elastic/elastic-agent/internal/pkg/agent/application/upgrade/ttl" "github.com/elastic/elastic-agent/internal/pkg/agent/configuration" "github.com/elastic/elastic-agent/internal/pkg/release" + "github.com/elastic/elastic-agent/pkg/core/logger" "github.com/elastic/elastic-agent/pkg/core/logger/loggertest" "github.com/elastic/elastic-agent/pkg/version" agtversion "github.com/elastic/elastic-agent/version" @@ -321,7 +323,7 @@ func TestManualRollback(t *testing.T) { { name: "no update marker, available install for rollback with valid TTL - rollback", setup: func(t *testing.T, topDir string, agent *info.MockAgent, watcherHelper *MockWatcherHelper, rollbacksSource *mockAvailableRollbacksSource) { - rollbacksSource.EXPECT().Get().Return(map[string]TTLMarker{ + rollbacksSource.EXPECT().Get().Return(map[string]ttl.TTLMarker{ "data/elastic-agent-1.2.3-oldver": { Version: "1.2.3", Hash: "oldver", @@ -372,7 +374,7 @@ func TestManualRollback(t *testing.T) { name: "no update marker, available install for rollback with expired TTL - error", setup: func(t *testing.T, topDir string, agent *info.MockAgent, watcherHelper *MockWatcherHelper, rollbacksSource *mockAvailableRollbacksSource) { rollbacksSource.EXPECT().Get().Return( - map[string]TTLMarker{ + map[string]ttl.TTLMarker{ "data/elastic-agent-1.2.3-oldver": { Version: "1.2.3", Hash: "oldver", @@ -402,7 +404,7 @@ func TestManualRollback(t *testing.T) { name: "no update marker, no available install for the version - error", setup: func(t *testing.T, topDir string, agent *info.MockAgent, watcherHelper *MockWatcherHelper, rollbacksSource *mockAvailableRollbacksSource) { rollbacksSource.EXPECT().Get().Return( - map[string]TTLMarker{ + map[string]ttl.TTLMarker{ "data/elastic-agent-1.2.3-oldver": { Version: "1.2.3", Hash: "oldver", @@ -447,7 +449,7 @@ func TestManualRollback(t *testing.T) { name: "no update marker, invoking watcher fails - error", setup: func(t *testing.T, topDir string, agent *info.MockAgent, watcherHelper *MockWatcherHelper, rollbacksSource *mockAvailableRollbacksSource) { rollbacksSource.EXPECT().Get().Return( - map[string]TTLMarker{ + map[string]ttl.TTLMarker{ "data/elastic-agent-1.2.3-oldver": { Version: "1.2.3", Hash: "oldver", @@ -501,3 +503,321 @@ func TestManualRollback(t *testing.T) { }) } } + +func TestCleanAvailableRollbacks(t *testing.T) { + // various timestamps + now := time.Now().UTC().Truncate(time.Millisecond) + oneHourAgo := now.Add(-1 * time.Hour) + oneHourFromNow := now.Add(1 * time.Hour) + + // Convenience test agent version structs + v123Expired := testAgentVersion{ + version: "1.2.3", + hash: "expire", + } + v456Valid := testAgentVersion{ + version: "4.5.6", + hash: "valid1", + } + v789Actual := testAgentVersion{ + version: "7.8.9", + hash: "actual", + } + + type args struct { + currentHomeRelPath string + filter RollbackCleanupFilter + } + tests := []struct { + name string + setup func(t *testing.T, log *logger.Logger, topDir string, rollbackSource *mockAvailableRollbacksSource) + args args + want map[string]ttl.TTLMarker + wantErr assert.ErrorAssertionFunc + postCleanupAssertions func(t *testing.T, topDir string) + }{ + { + name: "Clear all available rollbacks regardless of ttl when using CleanupAllRollbacks", + setup: func(t *testing.T, log *logger.Logger, topDir string, rollbackSource *mockAvailableRollbacksSource) { + rollbackSource.EXPECT().Get().Return(map[string]ttl.TTLMarker{ + filepath.Join("data", "elastic-agent-1.2.3-expire"): { + Version: "1.2.3", + Hash: "expire", + ValidUntil: oneHourAgo, // expired 1 hour ago + }, + filepath.Join("data", "elastic-agent-4.5.6-valid1"): { + Version: "4.5.6", + Hash: "valid1", + ValidUntil: oneHourFromNow, // still valid + }, + }, nil) + rollbackSource.EXPECT().Set(map[string]ttl.TTLMarker{}).Return(nil) + + // setup the fake agent installations + setupAgents(t, log, topDir, setupAgentInstallations{ + installedAgents: []testAgentInstall{ + { + version: v123Expired, + useVersionInPath: true, + }, + { + version: v456Valid, + useVersionInPath: true, + }, + { + version: v789Actual, + useVersionInPath: true, + }, + }, + upgradeFrom: testAgentVersion{}, + upgradeTo: testAgentVersion{}, + currentAgent: v789Actual, + }, + false, + ) + }, + args: args{ + currentHomeRelPath: filepath.Join("data", "elastic-agent-7.8.9-actual"), + filter: CleanupAllRollbacks, + }, + want: map[string]ttl.TTLMarker{}, + wantErr: assert.NoError, + postCleanupAssertions: func(t *testing.T, topDir string) { + assert.NoDirExists(t, filepath.Join(topDir, "data", "elastic-agent-1.2.3-expire"), "expired rollback should have been removed") + assert.NoDirExists(t, filepath.Join(topDir, "data", "elastic-agent-4.5.6-valid1"), "valid rollback should have been removed") + assert.DirExists(t, filepath.Join(topDir, "data", "elastic-agent-7.8.9-actual"), "current agent install should have been preserved") + }, + }, + { + name: "Clear expired available rollbacks when using CleanupExpiredRollbacks", + setup: func(t *testing.T, log *logger.Logger, topDir string, rollbackSource *mockAvailableRollbacksSource) { + rollbackSource.EXPECT().Get().Return(map[string]ttl.TTLMarker{ + filepath.Join("data", "elastic-agent-1.2.3-expire"): { + Version: "1.2.3", + Hash: "expire", + ValidUntil: oneHourAgo, // expired 1 hour ago + }, + filepath.Join("data", "elastic-agent-4.5.6-valid1"): { + Version: "4.5.6", + Hash: "valid1", + ValidUntil: oneHourFromNow, // still valid + }, + }, nil) + + rollbackSource.EXPECT().Set(map[string]ttl.TTLMarker{ + filepath.Join("data", "elastic-agent-4.5.6-valid1"): { + Version: "4.5.6", + Hash: "valid1", + ValidUntil: oneHourFromNow, // still valid + }, + }).Return(nil) + + // setup the fake agent installations + setupAgents(t, log, topDir, setupAgentInstallations{ + installedAgents: []testAgentInstall{ + { + version: v123Expired, + useVersionInPath: true, + }, + { + version: v456Valid, + useVersionInPath: true, + }, + { + version: v789Actual, + useVersionInPath: true, + }, + }, + upgradeFrom: testAgentVersion{}, + upgradeTo: testAgentVersion{}, + currentAgent: v789Actual, + }, + false, + ) + }, + args: args{ + currentHomeRelPath: filepath.Join("data", "elastic-agent-7.8.9-actual"), + filter: CleanupExpiredRollbacks, + }, + want: map[string]ttl.TTLMarker{ + filepath.Join("data", "elastic-agent-4.5.6-valid1"): { + Version: "4.5.6", + Hash: "valid1", + ValidUntil: oneHourFromNow, // still valid + }, + }, + wantErr: assert.NoError, + postCleanupAssertions: func(t *testing.T, topDir string) { + assert.NoDirExists(t, filepath.Join(topDir, "data", "elastic-agent-1.2.3-expire"), "expired rollback should have been removed") + assert.DirExists(t, filepath.Join(topDir, "data", "elastic-agent-4.5.6-valid1"), "valid rollback should have not been removed") + assert.DirExists(t, filepath.Join(topDir, "data", "elastic-agent-7.8.9-actual"), "current agent install should have been preserved") + }, + }, + { + name: "Current install should be preserved when using CleanupAllRollbacks even if marked as an available rollback", + setup: func(t *testing.T, log *logger.Logger, topDir string, rollbackSource *mockAvailableRollbacksSource) { + rollbackSource.EXPECT().Get().Return(map[string]ttl.TTLMarker{ + filepath.Join("data", "elastic-agent-7.8.9-actual"): { + Version: "7.8.9", + Hash: "actual", + ValidUntil: oneHourAgo, // expired 1 hour ago + }, + }, nil) + + rollbackSource.EXPECT().Set(map[string]ttl.TTLMarker{}).Return(nil) + + // setup the fake agent installations + setupAgents(t, log, topDir, setupAgentInstallations{ + installedAgents: []testAgentInstall{ + { + version: v789Actual, + useVersionInPath: true, + }, + }, + upgradeFrom: testAgentVersion{}, + upgradeTo: testAgentVersion{}, + currentAgent: v789Actual, + }, + false, + ) + }, + args: args{ + currentHomeRelPath: filepath.Join("data", "elastic-agent-7.8.9-actual"), + filter: CleanupAllRollbacks, + }, + want: map[string]ttl.TTLMarker{}, + wantErr: assert.NoError, + postCleanupAssertions: func(t *testing.T, topDir string) { + assert.DirExists(t, filepath.Join(topDir, "data", "elastic-agent-7.8.9-actual"), "current agent install should have been preserved") + }, + }, + { + name: "Current install should be preserved when using CleanupExpiredRollbacks even if marked as an available rollback", + setup: func(t *testing.T, log *logger.Logger, topDir string, rollbackSource *mockAvailableRollbacksSource) { + rollbackSource.EXPECT().Get().Return(map[string]ttl.TTLMarker{ + filepath.Join("data", "elastic-agent-7.8.9-actual"): { + Version: "7.8.9", + Hash: "actual", + ValidUntil: oneHourAgo, // expired 1 hour ago + }, + }, nil) + rollbackSource.EXPECT().Set(map[string]ttl.TTLMarker{}).Return(nil) + + // setup the fake agent installations + setupAgents(t, log, topDir, setupAgentInstallations{ + installedAgents: []testAgentInstall{ + { + version: v789Actual, + useVersionInPath: true, + }, + }, + upgradeFrom: testAgentVersion{}, + upgradeTo: testAgentVersion{}, + currentAgent: v789Actual, + }, + false, + ) + }, + args: args{ + currentHomeRelPath: filepath.Join("data", "elastic-agent-7.8.9-actual"), + filter: CleanupExpiredRollbacks, + }, + want: map[string]ttl.TTLMarker{}, + wantErr: assert.NoError, + postCleanupAssertions: func(t *testing.T, topDir string) { + assert.DirExists(t, filepath.Join(topDir, "data", "elastic-agent-7.8.9-actual"), "current agent install should have been preserved") + }, + }, + { + name: "Preserve available rollbacks if involved in an active upgrade", + setup: func(t *testing.T, log *logger.Logger, topDir string, rollbackSource *mockAvailableRollbacksSource) { + + rollbackSource.EXPECT().Get().Return(map[string]ttl.TTLMarker{ + filepath.Join("data", "elastic-agent-1.2.3-oldver"): { + Version: "1.2.3", + Hash: "oldver", + ValidUntil: oneHourAgo, // expired 1 hour ago + }, + }, nil) + + rollbackSource.EXPECT().Set(map[string]ttl.TTLMarker{ + filepath.Join("data", "elastic-agent-1.2.3-oldver"): { + Version: "1.2.3", + Hash: "oldver", + ValidUntil: oneHourAgo, // expired 1 hour ago + }, + }).Return(nil) + + fromVersion := testAgentVersion{ + version: "1.2.3", + hash: "oldver", + } + + // setup the fake agent installations + toVersion := testAgentVersion{ + version: "4.5.6", + hash: "newver", + } + setupAgents(t, log, topDir, setupAgentInstallations{ + installedAgents: []testAgentInstall{ + { + version: fromVersion, + useVersionInPath: true, + }, + { + version: toVersion, + useVersionInPath: true, + }, + }, + upgradeFrom: fromVersion, + upgradeTo: toVersion, + currentAgent: toVersion, + }, true) + }, + args: args{ + currentHomeRelPath: filepath.Join("data", "elastic-agent-4.5.6-newver"), + filter: PreserveActiveUpgradeVersions(&UpdateMarker{ + Version: "4.5.6", + Hash: "newver", + VersionedHome: filepath.Join("data", "elastic-agent-4.5.6-newver"), + UpdatedOn: now, + PrevVersion: "1.2.3", + PrevHash: "oldver", + PrevVersionedHome: filepath.Join("data", "elastic-agent-1.2.3-oldver"), + Acked: false, + Action: nil, + Details: nil, + RollbacksAvailable: nil, + }, + CleanupExpiredRollbacks, + ), + }, + want: map[string]ttl.TTLMarker{ + filepath.Join("data", "elastic-agent-1.2.3-oldver"): { + Version: "1.2.3", + Hash: "oldver", + ValidUntil: now.Add(-1 * time.Hour), + }, + }, + wantErr: assert.NoError, + postCleanupAssertions: nil, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + log, _ := loggertest.New(t.Name()) + topDir := t.TempDir() + err := os.MkdirAll(filepath.Join(topDir, "data"), 0755) + require.NoError(t, err, "error creating data directory in topDir %q", topDir) + mockRollbacksSource := newMockAvailableRollbacksSource(t) + + tt.setup(t, log, topDir, mockRollbacksSource) + got, err := CleanAvailableRollbacks(log, mockRollbacksSource, topDir, tt.args.currentHomeRelPath, tt.args.filter) + tt.wantErr(t, err) + assert.Equal(t, tt.want, got) + if tt.postCleanupAssertions != nil { + tt.postCleanupAssertions(t, topDir) + } + }) + } +} diff --git a/internal/pkg/agent/application/upgrade/mocks.go b/internal/pkg/agent/application/upgrade/mocks.go index 5ef4993d084..1300327674d 100644 --- a/internal/pkg/agent/application/upgrade/mocks.go +++ b/internal/pkg/agent/application/upgrade/mocks.go @@ -16,6 +16,7 @@ import ( mock "github.com/stretchr/testify/mock" "github.com/elastic/elastic-agent/internal/pkg/agent/application/filelock" + "github.com/elastic/elastic-agent/internal/pkg/agent/application/upgrade/ttl" "github.com/elastic/elastic-agent/pkg/core/logger" ) @@ -367,23 +368,23 @@ func (_m *mockAvailableRollbacksSource) EXPECT() *mockAvailableRollbacksSource_E } // Get provides a mock function for the type mockAvailableRollbacksSource -func (_mock *mockAvailableRollbacksSource) Get() (map[string]TTLMarker, error) { +func (_mock *mockAvailableRollbacksSource) Get() (map[string]ttl.TTLMarker, error) { ret := _mock.Called() if len(ret) == 0 { panic("no return value specified for Get") } - var r0 map[string]TTLMarker + var r0 map[string]ttl.TTLMarker var r1 error - if returnFunc, ok := ret.Get(0).(func() (map[string]TTLMarker, error)); ok { + if returnFunc, ok := ret.Get(0).(func() (map[string]ttl.TTLMarker, error)); ok { return returnFunc() } - if returnFunc, ok := ret.Get(0).(func() map[string]TTLMarker); ok { + if returnFunc, ok := ret.Get(0).(func() map[string]ttl.TTLMarker); ok { r0 = returnFunc() } else { if ret.Get(0) != nil { - r0 = ret.Get(0).(map[string]TTLMarker) + r0 = ret.Get(0).(map[string]ttl.TTLMarker) } } if returnFunc, ok := ret.Get(1).(func() error); ok { @@ -411,18 +412,18 @@ func (_c *mockAvailableRollbacksSource_Get_Call) Run(run func()) *mockAvailableR return _c } -func (_c *mockAvailableRollbacksSource_Get_Call) Return(stringToTTLMarker map[string]TTLMarker, err error) *mockAvailableRollbacksSource_Get_Call { +func (_c *mockAvailableRollbacksSource_Get_Call) Return(stringToTTLMarker map[string]ttl.TTLMarker, err error) *mockAvailableRollbacksSource_Get_Call { _c.Call.Return(stringToTTLMarker, err) return _c } -func (_c *mockAvailableRollbacksSource_Get_Call) RunAndReturn(run func() (map[string]TTLMarker, error)) *mockAvailableRollbacksSource_Get_Call { +func (_c *mockAvailableRollbacksSource_Get_Call) RunAndReturn(run func() (map[string]ttl.TTLMarker, error)) *mockAvailableRollbacksSource_Get_Call { _c.Call.Return(run) return _c } // Set provides a mock function for the type mockAvailableRollbacksSource -func (_mock *mockAvailableRollbacksSource) Set(stringToTTLMarker map[string]TTLMarker) error { +func (_mock *mockAvailableRollbacksSource) Set(stringToTTLMarker map[string]ttl.TTLMarker) error { ret := _mock.Called(stringToTTLMarker) if len(ret) == 0 { @@ -430,7 +431,7 @@ func (_mock *mockAvailableRollbacksSource) Set(stringToTTLMarker map[string]TTLM } var r0 error - if returnFunc, ok := ret.Get(0).(func(map[string]TTLMarker) error); ok { + if returnFunc, ok := ret.Get(0).(func(map[string]ttl.TTLMarker) error); ok { r0 = returnFunc(stringToTTLMarker) } else { r0 = ret.Error(0) @@ -444,16 +445,16 @@ type mockAvailableRollbacksSource_Set_Call struct { } // Set is a helper method to define mock.On call -// - stringToTTLMarker map[string]TTLMarker +// - stringToTTLMarker map[string]ttl.TTLMarker func (_e *mockAvailableRollbacksSource_Expecter) Set(stringToTTLMarker interface{}) *mockAvailableRollbacksSource_Set_Call { return &mockAvailableRollbacksSource_Set_Call{Call: _e.mock.On("Set", stringToTTLMarker)} } -func (_c *mockAvailableRollbacksSource_Set_Call) Run(run func(stringToTTLMarker map[string]TTLMarker)) *mockAvailableRollbacksSource_Set_Call { +func (_c *mockAvailableRollbacksSource_Set_Call) Run(run func(stringToTTLMarker map[string]ttl.TTLMarker)) *mockAvailableRollbacksSource_Set_Call { _c.Call.Run(func(args mock.Arguments) { - var arg0 map[string]TTLMarker + var arg0 map[string]ttl.TTLMarker if args[0] != nil { - arg0 = args[0].(map[string]TTLMarker) + arg0 = args[0].(map[string]ttl.TTLMarker) } run( arg0, @@ -467,7 +468,7 @@ func (_c *mockAvailableRollbacksSource_Set_Call) Return(err error) *mockAvailabl return _c } -func (_c *mockAvailableRollbacksSource_Set_Call) RunAndReturn(run func(stringToTTLMarker map[string]TTLMarker) error) *mockAvailableRollbacksSource_Set_Call { +func (_c *mockAvailableRollbacksSource_Set_Call) RunAndReturn(run func(stringToTTLMarker map[string]ttl.TTLMarker) error) *mockAvailableRollbacksSource_Set_Call { _c.Call.Return(run) return _c } diff --git a/internal/pkg/agent/application/upgrade/step_mark.go b/internal/pkg/agent/application/upgrade/step_mark.go index 11e7b208430..5cdca85c39b 100644 --- a/internal/pkg/agent/application/upgrade/step_mark.go +++ b/internal/pkg/agent/application/upgrade/step_mark.go @@ -14,6 +14,7 @@ import ( "github.com/elastic/elastic-agent/internal/pkg/agent/application/paths" "github.com/elastic/elastic-agent/internal/pkg/agent/application/upgrade/details" + "github.com/elastic/elastic-agent/internal/pkg/agent/application/upgrade/ttl" "github.com/elastic/elastic-agent/internal/pkg/agent/errors" "github.com/elastic/elastic-agent/internal/pkg/fleetapi" "github.com/elastic/elastic-agent/pkg/core/logger" @@ -22,13 +23,6 @@ import ( const markerFilename = ".update-marker" -// TTLMarker marks an elastic-agent install available for rollback -type TTLMarker struct { - Version string `json:"version" yaml:"version"` - Hash string `json:"hash" yaml:"hash"` - ValidUntil time.Time `json:"valid_until" yaml:"valid_until"` -} - // UpdateMarker is a marker holding necessary information about ongoing upgrade. type UpdateMarker struct { // Version represents the version the agent is upgraded to @@ -54,7 +48,7 @@ type UpdateMarker struct { Details *details.Details `json:"details,omitempty" yaml:"details,omitempty"` - RollbacksAvailable map[string]TTLMarker `json:"rollbacks_available,omitempty" yaml:"rollbacks_available,omitempty"` + RollbacksAvailable map[string]ttl.TTLMarker `json:"rollbacks_available,omitempty" yaml:"rollbacks_available,omitempty"` } // GetActionID returns the Fleet Action ID associated with the @@ -101,17 +95,17 @@ func convertToActionUpgrade(a *MarkerActionUpgrade) *fleetapi.ActionUpgrade { } type updateMarkerSerializer struct { - Version string `yaml:"version"` - Hash string `yaml:"hash"` - VersionedHome string `yaml:"versioned_home"` - UpdatedOn time.Time `yaml:"updated_on"` - PrevVersion string `yaml:"prev_version"` - PrevHash string `yaml:"prev_hash"` - PrevVersionedHome string `yaml:"prev_versioned_home"` - Acked bool `yaml:"acked"` - Action *MarkerActionUpgrade `yaml:"action"` - Details *details.Details `yaml:"details"` - RollbacksAvailable map[string]TTLMarker `yaml:"rollbacks_available,omitempty"` + Version string `yaml:"version"` + Hash string `yaml:"hash"` + VersionedHome string `yaml:"versioned_home"` + UpdatedOn time.Time `yaml:"updated_on"` + PrevVersion string `yaml:"prev_version"` + PrevHash string `yaml:"prev_hash"` + PrevVersionedHome string `yaml:"prev_versioned_home"` + Acked bool `yaml:"acked"` + Action *MarkerActionUpgrade `yaml:"action"` + Details *details.Details `yaml:"details"` + RollbacksAvailable map[string]ttl.TTLMarker `yaml:"rollbacks_available,omitempty"` } func newMarkerSerializer(m *UpdateMarker) *updateMarkerSerializer { @@ -141,7 +135,7 @@ type updateActiveCommitFunc func(log *logger.Logger, topDirPath, hash string, wr // markUpgrade marks update happened so we can handle grace period func markUpgradeProvider(updateActiveCommit updateActiveCommitFunc, writeFile writeFileFunc) markUpgradeFunc { - return func(log *logger.Logger, dataDirPath string, updatedOn time.Time, agent, previousAgent agentInstall, action *fleetapi.ActionUpgrade, upgradeDetails *details.Details, availableRollbacks map[string]TTLMarker) error { + return func(log *logger.Logger, dataDirPath string, updatedOn time.Time, agent, previousAgent agentInstall, action *fleetapi.ActionUpgrade, upgradeDetails *details.Details, availableRollbacks map[string]ttl.TTLMarker) error { if len(previousAgent.hash) > HashLen { previousAgent.hash = previousAgent.hash[:HashLen] @@ -269,3 +263,21 @@ func saveMarkerToPath(marker *UpdateMarker, markerFile string, shouldFsync bool) func markerFilePath(dataDirPath string) string { return filepath.Join(dataDirPath, markerFilename) } + +// IsTerminalState returns true if the state in the upgrade marker contains details and the upgrade details state is a +// terminal one: UPG_COMPLETE, UPG_ROLLBACK and UPG_FAILED +// If the upgrade marker or the upgrade marker details are nil the function will return false: as +// no state is specified, having simply a marker without details would mean that some upgrade operation is ongoing +// (probably initiated by an older agent). +func IsTerminalState(marker *UpdateMarker) bool { + if marker.Details == nil { + return false + } + + switch marker.Details.State { + case details.StateCompleted, details.StateRollback, details.StateFailed: + return true + default: + return false + } +} diff --git a/internal/pkg/agent/application/upgrade/step_mark_test.go b/internal/pkg/agent/application/upgrade/step_mark_test.go index 9daed169e58..9f0342411db 100644 --- a/internal/pkg/agent/application/upgrade/step_mark_test.go +++ b/internal/pkg/agent/application/upgrade/step_mark_test.go @@ -15,6 +15,7 @@ import ( "github.com/stretchr/testify/require" "github.com/elastic/elastic-agent/internal/pkg/agent/application/upgrade/details" + "github.com/elastic/elastic-agent/internal/pkg/agent/application/upgrade/ttl" "github.com/elastic/elastic-agent/internal/pkg/fleetapi" "github.com/elastic/elastic-agent/pkg/core/logger" "github.com/elastic/elastic-agent/pkg/core/logger/loggertest" @@ -94,7 +95,7 @@ func TestMarkUpgrade(t *testing.T) { previousAgent agentInstall action *fleetapi.ActionUpgrade details *details.Details - availableRollbacks map[string]TTLMarker + availableRollbacks map[string]ttl.TTLMarker } type workingDirHook func(t *testing.T, dataDir string) @@ -200,7 +201,7 @@ func TestMarkUpgrade(t *testing.T) { }, action: nil, details: details.NewDetails("9.2.0-SNAPSHOT", details.StateReplacing, ""), - availableRollbacks: map[string]TTLMarker{ + availableRollbacks: map[string]ttl.TTLMarker{ filepath.Join("data", "elastic-agent-1.2.3-SNAPSHOT-prvagt"): { Version: "1.2.3-SNAPSHOT", ValidUntil: twentyFourHoursFromNow, @@ -228,7 +229,7 @@ func TestMarkUpgrade(t *testing.T) { ActionID: "", Metadata: details.Metadata{}, }, - RollbacksAvailable: map[string]TTLMarker{ + RollbacksAvailable: map[string]ttl.TTLMarker{ filepath.Join("data", "elastic-agent-1.2.3-SNAPSHOT-prvagt"): { Version: "1.2.3-SNAPSHOT", ValidUntil: twentyFourHoursFromNow, diff --git a/internal/pkg/agent/application/upgrade/ttl/marker.go b/internal/pkg/agent/application/upgrade/ttl/marker.go new file mode 100644 index 00000000000..5eca887393a --- /dev/null +++ b/internal/pkg/agent/application/upgrade/ttl/marker.go @@ -0,0 +1,14 @@ +// Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one +// or more contributor license agreements. Licensed under the Elastic License 2.0; +// you may not use this file except in compliance with the Elastic License 2.0. + +package ttl + +import "time" + +// TTLMarker marks an elastic-agent install available for rollback +type TTLMarker struct { + Version string `json:"version" yaml:"version"` + Hash string `json:"hash" yaml:"hash"` + ValidUntil time.Time `json:"valid_until" yaml:"valid_until"` +} diff --git a/internal/pkg/agent/application/upgrade/ttl_marker_source.go b/internal/pkg/agent/application/upgrade/ttl/ttl_marker_source.go similarity index 99% rename from internal/pkg/agent/application/upgrade/ttl_marker_source.go rename to internal/pkg/agent/application/upgrade/ttl/ttl_marker_source.go index 4445ff4526e..4b795266d51 100644 --- a/internal/pkg/agent/application/upgrade/ttl_marker_source.go +++ b/internal/pkg/agent/application/upgrade/ttl/ttl_marker_source.go @@ -2,7 +2,7 @@ // or more contributor license agreements. Licensed under the Elastic License 2.0; // you may not use this file except in compliance with the Elastic License 2.0. -package upgrade +package ttl import ( "fmt" diff --git a/internal/pkg/agent/application/upgrade/ttl_marker_source_test.go b/internal/pkg/agent/application/upgrade/ttl/ttl_marker_source_test.go similarity index 99% rename from internal/pkg/agent/application/upgrade/ttl_marker_source_test.go rename to internal/pkg/agent/application/upgrade/ttl/ttl_marker_source_test.go index dd79961102c..fb8e89b252e 100644 --- a/internal/pkg/agent/application/upgrade/ttl_marker_source_test.go +++ b/internal/pkg/agent/application/upgrade/ttl/ttl_marker_source_test.go @@ -2,7 +2,7 @@ // or more contributor license agreements. Licensed under the Elastic License 2.0; // you may not use this file except in compliance with the Elastic License 2.0. -package upgrade +package ttl import ( "bytes" diff --git a/internal/pkg/agent/application/upgrade/upgrade.go b/internal/pkg/agent/application/upgrade/upgrade.go index cbcc670aead..f00aa7e745b 100644 --- a/internal/pkg/agent/application/upgrade/upgrade.go +++ b/internal/pkg/agent/application/upgrade/upgrade.go @@ -27,6 +27,7 @@ import ( "github.com/elastic/elastic-agent/internal/pkg/agent/application/upgrade/artifact/download" upgradeErrors "github.com/elastic/elastic-agent/internal/pkg/agent/application/upgrade/artifact/download/errors" "github.com/elastic/elastic-agent/internal/pkg/agent/application/upgrade/details" + "github.com/elastic/elastic-agent/internal/pkg/agent/application/upgrade/ttl" "github.com/elastic/elastic-agent/internal/pkg/agent/configuration" "github.com/elastic/elastic-agent/internal/pkg/agent/errors" "github.com/elastic/elastic-agent/internal/pkg/agent/install" @@ -90,7 +91,7 @@ type unpackHandler interface { type copyActionStoreFunc func(log *logger.Logger, newHome string) error type copyRunDirectoryFunc func(log *logger.Logger, oldRunPath, newRunPath string) error type fileDirCopyFunc func(from, to string, opts ...filecopy.Options) error -type markUpgradeFunc func(log *logger.Logger, dataDirPath string, updatedOn time.Time, agent, previousAgent agentInstall, action *fleetapi.ActionUpgrade, upgradeDetails *details.Details, availableRollbacks map[string]TTLMarker) error +type markUpgradeFunc func(log *logger.Logger, dataDirPath string, updatedOn time.Time, agent, previousAgent agentInstall, action *fleetapi.ActionUpgrade, upgradeDetails *details.Details, availableRollbacks map[string]ttl.TTLMarker) error type changeSymlinkFunc func(log *logger.Logger, topDirPath, symlinkPath, newTarget string) error type rollbackInstallFunc func(ctx context.Context, log *logger.Logger, topDirPath, versionedHome, oldVersionedHome string, rollbackSource availableRollbacksSource) error @@ -117,8 +118,8 @@ type WatcherHelper interface { } type availableRollbacksSource interface { - Set(map[string]TTLMarker) error - Get() (map[string]TTLMarker, error) + Set(map[string]ttl.TTLMarker) error + Get() (map[string]ttl.TTLMarker, error) } // Upgrader performs an upgrade @@ -337,6 +338,16 @@ func (u *Upgrader) Upgrade(ctx context.Context, version string, rollback bool, s u.log.Errorw("Unable to clean downloads before update", "error.message", err, "downloads.path", paths.Downloads()) } + currentVersionedHome, err := filepath.Rel(paths.Top(), paths.Home()) + if err != nil { + return nil, fmt.Errorf("calculating home path relative to top, home: %q top: %q : %w", paths.Home(), paths.Top(), err) + } + + _, err = CleanAvailableRollbacks(u.log, u.availableRollbacksSource, paths.Top(), currentVersionedHome, CleanupAllRollbacks) + if err != nil { + u.log.Warnw("Unable to clean all available rollbacks", "error.message", err) + } + det.SetState(details.StateDownloading) sourceURI = u.sourceURI(sourceURI) @@ -435,11 +446,6 @@ func (u *Upgrader) Upgrade(ctx context.Context, version string, rollback bool, s // paths.BinaryPath properly derives the binary directory depending on the platform. The path to the binary for macOS is inside of the app bundle. newPath := paths.BinaryPath(filepath.Join(paths.Top(), hashedDir), AgentName) - currentVersionedHome, err := filepath.Rel(paths.Top(), paths.Home()) - if err != nil { - return nil, fmt.Errorf("calculating home path relative to top, home: %q top: %q : %w", paths.Home(), paths.Top(), err) - } - if err := u.changeSymlink(u.log, paths.Top(), symlinkPath, newPath); err != nil { u.log.Errorw("Rolling back: changing symlink failed", "error.message", err) rollbackErr := u.rollbackInstall(ctx, u.log, paths.Top(), hashedDir, currentVersionedHome, u.availableRollbacksSource) diff --git a/internal/pkg/agent/application/upgrade/upgrade_test.go b/internal/pkg/agent/application/upgrade/upgrade_test.go index dc91be46bc2..33bbb4d2d99 100644 --- a/internal/pkg/agent/application/upgrade/upgrade_test.go +++ b/internal/pkg/agent/application/upgrade/upgrade_test.go @@ -31,6 +31,7 @@ import ( "github.com/elastic/elastic-agent/internal/pkg/agent/application/upgrade/artifact" upgradeErrors "github.com/elastic/elastic-agent/internal/pkg/agent/application/upgrade/artifact/download/errors" "github.com/elastic/elastic-agent/internal/pkg/agent/application/upgrade/details" + "github.com/elastic/elastic-agent/internal/pkg/agent/application/upgrade/ttl" "github.com/elastic/elastic-agent/internal/pkg/agent/errors" "github.com/elastic/elastic-agent/internal/pkg/config" "github.com/elastic/elastic-agent/internal/pkg/fleetapi" @@ -1096,6 +1097,7 @@ func TestUpgradeErrorHandling(t *testing.T) { checkArchiveCleanup: true, setupMocks: func(t *testing.T, mockAgentInfo *info.MockAgent, mockRollbackSrc *mockAvailableRollbacksSource, mockWatcherHelper *MockWatcherHelper) { mockAgentInfo.EXPECT().Version().Return("9.0.0") + mockRollbackSrc.EXPECT().Get().Return(nil, nil) }, }, "should return error if getPackageMetadata fails": { @@ -1112,6 +1114,7 @@ func TestUpgradeErrorHandling(t *testing.T) { checkArchiveCleanup: true, setupMocks: func(t *testing.T, mockAgentInfo *info.MockAgent, mockRollbackSrc *mockAvailableRollbacksSource, mockWatcherHelper *MockWatcherHelper) { mockAgentInfo.EXPECT().Version().Return("9.0.0") + mockRollbackSrc.EXPECT().Get().Return(nil, nil) }, }, "should return error and cleanup downloaded archive if unpack fails before extracting": { @@ -1139,6 +1142,7 @@ func TestUpgradeErrorHandling(t *testing.T) { checkArchiveCleanup: true, setupMocks: func(t *testing.T, mockAgentInfo *info.MockAgent, mockRollbackSrc *mockAvailableRollbacksSource, mockWatcherHelper *MockWatcherHelper) { mockAgentInfo.EXPECT().Version().Return("9.0.0") + mockRollbackSrc.EXPECT().Get().Return(nil, nil) }, }, "should return error and cleanup downloaded archive if unpack fails after extracting": { @@ -1171,6 +1175,7 @@ func TestUpgradeErrorHandling(t *testing.T) { checkVersionedHomeCleanup: true, setupMocks: func(t *testing.T, mockAgentInfo *info.MockAgent, mockRollbackSrc *mockAvailableRollbacksSource, mockWatcherHelper *MockWatcherHelper) { mockAgentInfo.EXPECT().Version().Return("9.0.0") + mockRollbackSrc.EXPECT().Get().Return(nil, nil) }, }, "should return error and cleanup downloaded artifact and extracted archive if copyActionStore fails": { @@ -1205,6 +1210,7 @@ func TestUpgradeErrorHandling(t *testing.T) { checkVersionedHomeCleanup: true, setupMocks: func(t *testing.T, mockAgentInfo *info.MockAgent, mockRollbackSrc *mockAvailableRollbacksSource, mockWatcherHelper *MockWatcherHelper) { mockAgentInfo.EXPECT().Version().Return("9.0.0") + mockRollbackSrc.EXPECT().Get().Return(nil, nil) }, }, "should return error and cleanup downloaded artifact and extracted archive if copyRunDirectory fails": { @@ -1243,6 +1249,7 @@ func TestUpgradeErrorHandling(t *testing.T) { checkVersionedHomeCleanup: true, setupMocks: func(t *testing.T, mockAgentInfo *info.MockAgent, mockRollbackSrc *mockAvailableRollbacksSource, mockWatcherHelper *MockWatcherHelper) { mockAgentInfo.EXPECT().Version().Return("9.0.0") + mockRollbackSrc.EXPECT().Get().Return(nil, nil) }, }, "should return error and cleanup downloaded artifact and extracted archive if changeSymlink fails": { @@ -1286,6 +1293,7 @@ func TestUpgradeErrorHandling(t *testing.T) { checkVersionedHomeCleanup: true, setupMocks: func(t *testing.T, mockAgentInfo *info.MockAgent, mockRollbackSrc *mockAvailableRollbacksSource, mockWatcherHelper *MockWatcherHelper) { mockAgentInfo.EXPECT().Version().Return("9.0.0") + mockRollbackSrc.EXPECT().Get().Return(nil, nil) }, }, "should return error and cleanup downloaded artifact and extracted archive if markUpgrade fails": { @@ -1324,7 +1332,7 @@ func TestUpgradeErrorHandling(t *testing.T) { upgrader.rollbackInstall = func(ctx context.Context, log *logger.Logger, topDirPath, versionedHome, oldVersionedHome string, source availableRollbacksSource) error { return nil } - upgrader.markUpgrade = func(log *logger.Logger, dataDirPath string, updatedOn time.Time, agent, previousAgent agentInstall, action *fleetapi.ActionUpgrade, upgradeDetails *details.Details, availableRollbacks map[string]TTLMarker) error { + upgrader.markUpgrade = func(log *logger.Logger, dataDirPath string, updatedOn time.Time, agent, previousAgent agentInstall, action *fleetapi.ActionUpgrade, upgradeDetails *details.Details, availableRollbacks map[string]ttl.TTLMarker) error { return testError } }, @@ -1332,7 +1340,8 @@ func TestUpgradeErrorHandling(t *testing.T) { checkVersionedHomeCleanup: true, setupMocks: func(t *testing.T, mockAgentInfo *info.MockAgent, mockRollbackSrc *mockAvailableRollbacksSource, mockWatcherHelper *MockWatcherHelper) { mockAgentInfo.EXPECT().Version().Return("9.0.0") - mockRollbackSrc.EXPECT().Set(map[string]TTLMarker(nil)).Return(nil) + mockRollbackSrc.EXPECT().Get().Return(nil, nil) + mockRollbackSrc.EXPECT().Set(map[string]ttl.TTLMarker(nil)).Return(nil) }, }, "should add disk space error to the error chain if downloadArtifact fails with disk space error": { @@ -1345,6 +1354,7 @@ func TestUpgradeErrorHandling(t *testing.T) { }, setupMocks: func(t *testing.T, mockAgentInfo *info.MockAgent, mockRollbackSrc *mockAvailableRollbacksSource, mockWatcherHelper *MockWatcherHelper) { mockAgentInfo.EXPECT().Version().Return("9.0.0") + mockRollbackSrc.EXPECT().Get().Return(nil, nil) }, }, } diff --git a/internal/pkg/agent/cmd/watch.go b/internal/pkg/agent/cmd/watch.go index 2308a7630d5..bc38847d3d2 100644 --- a/internal/pkg/agent/cmd/watch.go +++ b/internal/pkg/agent/cmd/watch.go @@ -188,7 +188,7 @@ func watchCmd(log *logp.Logger, topDir string, cfg *configuration.UpgradeWatcher log.With("marker", marker, "details", marker.Details).Info("Loaded update marker") isWithinGrace, tilGrace := gracePeriod(marker, cfg.GracePeriod) - if isTerminalState(marker) || !isWithinGrace { + if upgrade.IsTerminalState(marker) || !isWithinGrace { stateString := "" if marker.Details != nil { stateString = string(marker.Details.State) @@ -328,24 +328,6 @@ func rollback(log *logp.Logger, topDir string, client client.Client, installModi return nil } -// isTerminalState returns true if the state in the upgrade marker contains details and the upgrade details state is a -// terminal one: UPG_COMPLETE, UPG_ROLLBACK and UPG_FAILED -// If the upgrade marker or the upgrade marker details are nil the function will return false: as -// no state is specified, having simply a marker without details would mean that some upgrade operation is ongoing -// (probably initiated by an older agent). -func isTerminalState(marker *upgrade.UpdateMarker) bool { - if marker.Details == nil { - return false - } - - switch marker.Details.State { - case details.StateCompleted, details.StateRollback, details.StateFailed: - return true - default: - return false - } -} - func isWindows() bool { return runtime.GOOS == "windows" } diff --git a/internal/pkg/fleetapi/action.go b/internal/pkg/fleetapi/action.go index 08f339f4c3e..32b0026ea63 100644 --- a/internal/pkg/fleetapi/action.go +++ b/internal/pkg/fleetapi/action.go @@ -267,7 +267,8 @@ type ActionUpgradeData struct { Version string `json:"version" yaml:"version,omitempty" mapstructure:"-"` SourceURI string `json:"source_uri,omitempty" yaml:"source_uri,omitempty" mapstructure:"-"` // TODO: update fleet open api schema - Retry int `json:"retry_attempt,omitempty" yaml:"retry_attempt,omitempty" mapstructure:"-"` + Retry int `json:"retry_attempt,omitempty" yaml:"retry_attempt,omitempty" mapstructure:"-"` + Rollback bool `json:"rollback,omitempty" yaml:"rollback,omitempty" mapstructure:"-"` } func (a *ActionUpgrade) String() string { diff --git a/internal/pkg/fleetapi/checkin_cmd.go b/internal/pkg/fleetapi/checkin_cmd.go index fb204b6ad3a..b521f7698d9 100644 --- a/internal/pkg/fleetapi/checkin_cmd.go +++ b/internal/pkg/fleetapi/checkin_cmd.go @@ -39,16 +39,22 @@ type CheckinComponent struct { Units []CheckinUnit `json:"units,omitempty"` } +type CheckinRollback struct { + Version string `json:"version"` + ValidUntil time.Time `json:"valid_until"` +} + // CheckinRequest consists of multiple events reported to fleet ui. type CheckinRequest struct { - Status string `json:"status"` - AckToken string `json:"ack_token,omitempty"` - Metadata *info.ECSMeta `json:"local_metadata,omitempty"` - Message string `json:"message"` // V2 Agent message - Components []CheckinComponent `json:"components"` // V2 Agent components - UpgradeDetails *details.Details `json:"upgrade_details,omitempty"` - AgentPolicyID string `json:"agent_policy_id,omitempty"` - PolicyRevisionIDX int64 `json:"policy_revision_idx,omitempty"` + Status string `json:"status"` + AckToken string `json:"ack_token,omitempty"` + Metadata *info.ECSMeta `json:"local_metadata,omitempty"` + Message string `json:"message"` // V2 Agent message + Components []CheckinComponent `json:"components"` // V2 Agent components + UpgradeDetails *details.Details `json:"upgrade_details,omitempty"` + AgentPolicyID string `json:"agent_policy_id,omitempty"` + PolicyRevisionIDX int64 `json:"policy_revision_idx,omitempty"` + AvailableRollbacks []CheckinRollback `json:"available_rollbacks,omitempty"` } // SerializableEvent is a representation of the event to be send to the Fleet Server API via the checkin