From 12f442439a64a4b76e8b9c7e93c0aa5de74d213c Mon Sep 17 00:00:00 2001 From: Bethuel Mmbaga Date: Fri, 15 Nov 2024 20:09:32 +0300 Subject: [PATCH] [management] Refactor group to use store methods (#2867) * Refactor setup key handling to use store methods Signed-off-by: bcmmbaga * add lock to get account groups Signed-off-by: bcmmbaga * add check for regular user Signed-off-by: bcmmbaga * get only required groups for auto-group validation Signed-off-by: bcmmbaga * add account lock and return auto groups map on validation Signed-off-by: bcmmbaga * refactor account peers update Signed-off-by: bcmmbaga * Refactor groups to use store methods Signed-off-by: bcmmbaga * refactor GetGroupByID and add NewGroupNotFoundError Signed-off-by: bcmmbaga * fix tests Signed-off-by: bcmmbaga * Add AddPeer and RemovePeer methods to Group struct Signed-off-by: bcmmbaga * Preserve store engine in SqlStore transactions Signed-off-by: bcmmbaga * Run groups ops in transaction Signed-off-by: bcmmbaga * fix missing group removed from setup key activity Signed-off-by: bcmmbaga * fix merge Signed-off-by: bcmmbaga * fix merge Signed-off-by: bcmmbaga * fix sonar Signed-off-by: bcmmbaga * Change setup key log level to debug for missing group Signed-off-by: bcmmbaga * Retrieve modified peers once for group events Signed-off-by: bcmmbaga * Add tests Signed-off-by: bcmmbaga * Add account locking and merge group deletion methods Signed-off-by: bcmmbaga * Fix tests Signed-off-by: bcmmbaga --------- Signed-off-by: bcmmbaga --- management/server/account.go | 29 +- management/server/account_test.go | 12 +- management/server/dns.go | 2 +- management/server/group.go | 490 ++++++++++-------- management/server/group/group.go | 27 + management/server/group/group_test.go | 90 ++++ management/server/group_test.go | 2 +- management/server/integrated_validator.go | 27 +- management/server/mock_server/account_mock.go | 9 - management/server/nameserver.go | 6 +- management/server/peer.go | 41 +- management/server/peer_test.go | 2 +- management/server/policy.go | 4 +- management/server/posture_checks.go | 2 +- management/server/route.go | 6 +- management/server/route_test.go | 2 +- management/server/setupkey.go | 6 +- management/server/sql_store.go | 120 ++++- management/server/sql_store_test.go | 285 +++++++++- management/server/status/error.go | 11 +- management/server/store.go | 8 +- management/server/user.go | 8 +- 22 files changed, 866 insertions(+), 323 deletions(-) create mode 100644 management/server/group/group_test.go diff --git a/management/server/account.go b/management/server/account.go index 4afadb4e9ae..1bd8a99a9aa 100644 --- a/management/server/account.go +++ b/management/server/account.go @@ -110,7 +110,6 @@ type AccountManager interface { SaveGroups(ctx context.Context, accountID, userID string, newGroups []*nbgroup.Group) error DeleteGroup(ctx context.Context, accountId, userId, groupID string) error DeleteGroups(ctx context.Context, accountId, userId string, groupIDs []string) error - ListGroups(ctx context.Context, accountId string) ([]*nbgroup.Group, error) GroupAddPeer(ctx context.Context, accountId, groupID, peerID string) error GroupDeletePeer(ctx context.Context, accountId, groupID, peerID string) error GetPolicy(ctx context.Context, accountID, policyID, userID string) (*Policy, error) @@ -1435,7 +1434,7 @@ func isNil(i idp.Manager) bool { // addAccountIDToIDPAppMeta update user's app metadata in idp manager func (am *DefaultAccountManager) addAccountIDToIDPAppMeta(ctx context.Context, userID string, accountID string) error { if !isNil(am.idpManager) { - accountUsers, err := am.Store.GetAccountUsers(ctx, accountID) + accountUsers, err := am.Store.GetAccountUsers(ctx, LockingStrengthShare, accountID) if err != nil { return err } @@ -2083,7 +2082,7 @@ func (am *DefaultAccountManager) syncJWTGroups(ctx context.Context, accountID st return fmt.Errorf("error saving groups: %w", err) } - if err = transaction.IncrementNetworkSerial(ctx, accountID); err != nil { + if err = transaction.IncrementNetworkSerial(ctx, LockingStrengthUpdate, accountID); err != nil { return fmt.Errorf("error incrementing network serial: %w", err) } } @@ -2101,7 +2100,7 @@ func (am *DefaultAccountManager) syncJWTGroups(ctx context.Context, accountID st } for _, g := range addNewGroups { - group, err := am.Store.GetGroupByID(ctx, LockingStrengthShare, g, accountID) + group, err := am.Store.GetGroupByID(ctx, LockingStrengthShare, accountID, g) if err != nil { log.WithContext(ctx).Debugf("group %s not found while saving user activity event of account %s", g, accountID) } else { @@ -2114,7 +2113,7 @@ func (am *DefaultAccountManager) syncJWTGroups(ctx context.Context, accountID st } for _, g := range removeOldGroups { - group, err := am.Store.GetGroupByID(ctx, LockingStrengthShare, g, accountID) + group, err := am.Store.GetGroupByID(ctx, LockingStrengthShare, accountID, g) if err != nil { log.WithContext(ctx).Debugf("group %s not found while saving user activity event of account %s", g, accountID) } else { @@ -2127,14 +2126,19 @@ func (am *DefaultAccountManager) syncJWTGroups(ctx context.Context, accountID st } if settings.GroupsPropagationEnabled { - account, err := am.requestBuffer.GetAccountWithBackpressure(ctx, accountID) + removedGroupAffectsPeers, err := areGroupChangesAffectPeers(ctx, am.Store, accountID, removeOldGroups) if err != nil { - return status.NewGetAccountError(err) + return err } - if areGroupChangesAffectPeers(account, addNewGroups) || areGroupChangesAffectPeers(account, removeOldGroups) { + newGroupsAffectsPeers, err := areGroupChangesAffectPeers(ctx, am.Store, accountID, addNewGroups) + if err != nil { + return err + } + + if removedGroupAffectsPeers || newGroupsAffectsPeers { log.WithContext(ctx).Tracef("user %s: JWT group membership changed, updating account peers", claims.UserId) - am.updateAccountPeers(ctx, account) + am.updateAccountPeers(ctx, accountID) } } @@ -2401,12 +2405,7 @@ func (am *DefaultAccountManager) CheckUserAccessByJWTGroups(ctx context.Context, func (am *DefaultAccountManager) onPeersInvalidated(ctx context.Context, accountID string) { log.WithContext(ctx).Debugf("validated peers has been invalidated for account %s", accountID) - updatedAccount, err := am.Store.GetAccount(ctx, accountID) - if err != nil { - log.WithContext(ctx).Errorf("failed to get account %s: %v", accountID, err) - return - } - am.updateAccountPeers(ctx, updatedAccount) + am.updateAccountPeers(ctx, accountID) } func (am *DefaultAccountManager) FindExistingPostureCheck(accountID string, checks *posture.ChecksDefinition) (*posture.Checks, error) { diff --git a/management/server/account_test.go b/management/server/account_test.go index fdf004a3b8a..97e0d45f016 100644 --- a/management/server/account_test.go +++ b/management/server/account_test.go @@ -1413,11 +1413,13 @@ func TestAccountManager_NetworkUpdates_DeleteGroup(t *testing.T) { updMsg := manager.peersUpdateManager.CreateChannel(context.Background(), peer1.ID) defer manager.peersUpdateManager.CloseChannel(context.Background(), peer1.ID) - group := group.Group{ + err := manager.SaveGroup(context.Background(), account.Id, userID, &group.Group{ ID: "groupA", Name: "GroupA", Peers: []string{peer1.ID, peer2.ID, peer3.ID}, - } + }) + + require.NoError(t, err, "failed to save group") policy := Policy{ Enabled: true, @@ -1460,7 +1462,7 @@ func TestAccountManager_NetworkUpdates_DeleteGroup(t *testing.T) { return } - if err := manager.DeleteGroup(context.Background(), account.Id, "", group.ID); err != nil { + if err := manager.DeleteGroup(context.Background(), account.Id, userID, "groupA"); err != nil { t.Errorf("delete group: %v", err) return } @@ -2714,7 +2716,7 @@ func TestAccount_SetJWTGroups(t *testing.T) { assert.NoError(t, err, "unable to get user") assert.Len(t, user.AutoGroups, 0) - group1, err := manager.Store.GetGroupByID(context.Background(), LockingStrengthShare, "group1", "accountID") + group1, err := manager.Store.GetGroupByID(context.Background(), LockingStrengthShare, "accountID", "group1") assert.NoError(t, err, "unable to get group") assert.Equal(t, group1.Issued, group.GroupIssuedAPI, "group should be api issued") }) @@ -2734,7 +2736,7 @@ func TestAccount_SetJWTGroups(t *testing.T) { assert.NoError(t, err, "unable to get user") assert.Len(t, user.AutoGroups, 1) - group1, err := manager.Store.GetGroupByID(context.Background(), LockingStrengthShare, "group1", "accountID") + group1, err := manager.Store.GetGroupByID(context.Background(), LockingStrengthShare, "accountID", "group1") assert.NoError(t, err, "unable to get group") assert.Equal(t, group1.Issued, group.GroupIssuedAPI, "group should be api issued") }) diff --git a/management/server/dns.go b/management/server/dns.go index 256b8b12512..4551be5ab92 100644 --- a/management/server/dns.go +++ b/management/server/dns.go @@ -146,7 +146,7 @@ func (am *DefaultAccountManager) SaveDNSSettings(ctx context.Context, accountID } if anyGroupHasPeers(account, addedGroups) || anyGroupHasPeers(account, removedGroups) { - am.updateAccountPeers(ctx, account) + am.updateAccountPeers(ctx, accountID) } return nil diff --git a/management/server/group.go b/management/server/group.go index 7b4f079487d..a36213f0493 100644 --- a/management/server/group.go +++ b/management/server/group.go @@ -33,8 +33,12 @@ func (am *DefaultAccountManager) CheckGroupPermissions(ctx context.Context, acco return err } - if !user.IsAdminOrServiceUser() || user.AccountID != accountID { - return status.Errorf(status.PermissionDenied, "groups are blocked for users") + if user.AccountID != accountID { + return status.NewUserNotPartOfAccountError() + } + + if user.IsRegularUser() { + return status.NewAdminPermissionError() } return nil @@ -45,8 +49,7 @@ func (am *DefaultAccountManager) GetGroup(ctx context.Context, accountID, groupI if err := am.CheckGroupPermissions(ctx, accountID, userID); err != nil { return nil, err } - - return am.Store.GetGroupByID(ctx, LockingStrengthShare, groupID, accountID) + return am.Store.GetGroupByID(ctx, LockingStrengthShare, accountID, groupID) } // GetAllGroups returns all groups in an account @@ -54,13 +57,12 @@ func (am *DefaultAccountManager) GetAllGroups(ctx context.Context, accountID, us if err := am.CheckGroupPermissions(ctx, accountID, userID); err != nil { return nil, err } - return am.Store.GetAccountGroups(ctx, LockingStrengthShare, accountID) } // GetGroupByName filters all groups in an account by name and returns the one with the most peers func (am *DefaultAccountManager) GetGroupByName(ctx context.Context, groupName, accountID string) (*nbgroup.Group, error) { - return am.Store.GetGroupByName(ctx, LockingStrengthShare, groupName, accountID) + return am.Store.GetGroupByName(ctx, LockingStrengthShare, accountID, groupName) } // SaveGroup object of the peers @@ -73,79 +75,74 @@ func (am *DefaultAccountManager) SaveGroup(ctx context.Context, accountID, userI // SaveGroups adds new groups to the account. // Note: This function does not acquire the global lock. // It is the caller's responsibility to ensure proper locking is in place before invoking this method. -func (am *DefaultAccountManager) SaveGroups(ctx context.Context, accountID, userID string, newGroups []*nbgroup.Group) error { - account, err := am.Store.GetAccount(ctx, accountID) +func (am *DefaultAccountManager) SaveGroups(ctx context.Context, accountID, userID string, groups []*nbgroup.Group) error { + user, err := am.Store.GetUserByUserID(ctx, LockingStrengthShare, userID) if err != nil { return err } - var eventsToStore []func() + if user.AccountID != accountID { + return status.NewUserNotPartOfAccountError() + } - for _, newGroup := range newGroups { - if newGroup.ID == "" && newGroup.Issued != nbgroup.GroupIssuedAPI { - return status.Errorf(status.InvalidArgument, "%s group without ID set", newGroup.Issued) - } + if user.IsRegularUser() { + return status.NewAdminPermissionError() + } - if newGroup.ID == "" && newGroup.Issued == nbgroup.GroupIssuedAPI { - existingGroup, err := account.FindGroupByName(newGroup.Name) - if err != nil { - s, ok := status.FromError(err) - if !ok || s.ErrorType != status.NotFound { - return err - } - } + var eventsToStore []func() + var groupsToSave []*nbgroup.Group + var updateAccountPeers bool - // Avoid duplicate groups only for the API issued groups. - // Integration or JWT groups can be duplicated as they are coming from the IdP that we don't have control of. - if existingGroup != nil { - return status.Errorf(status.AlreadyExists, "group with name %s already exists", newGroup.Name) + err = am.Store.ExecuteInTransaction(ctx, func(transaction Store) error { + groupIDs := make([]string, 0, len(groups)) + for _, newGroup := range groups { + if err = validateNewGroup(ctx, transaction, accountID, newGroup); err != nil { + return err } - newGroup.ID = xid.New().String() - } + newGroup.AccountID = accountID + groupsToSave = append(groupsToSave, newGroup) + groupIDs = append(groupIDs, newGroup.ID) - for _, peerID := range newGroup.Peers { - if account.Peers[peerID] == nil { - return status.Errorf(status.InvalidArgument, "peer with ID \"%s\" not found", peerID) - } + events := am.prepareGroupEvents(ctx, transaction, accountID, userID, newGroup) + eventsToStore = append(eventsToStore, events...) } - oldGroup := account.Groups[newGroup.ID] - account.Groups[newGroup.ID] = newGroup - - events := am.prepareGroupEvents(ctx, userID, accountID, newGroup, oldGroup, account) - eventsToStore = append(eventsToStore, events...) - } + updateAccountPeers, err = areGroupChangesAffectPeers(ctx, transaction, accountID, groupIDs) + if err != nil { + return err + } - newGroupIDs := make([]string, 0, len(newGroups)) - for _, newGroup := range newGroups { - newGroupIDs = append(newGroupIDs, newGroup.ID) - } + if err = transaction.IncrementNetworkSerial(ctx, LockingStrengthUpdate, accountID); err != nil { + return err + } - account.Network.IncSerial() - if err = am.Store.SaveAccount(ctx, account); err != nil { + return transaction.SaveGroups(ctx, LockingStrengthUpdate, groupsToSave) + }) + if err != nil { return err } - if areGroupChangesAffectPeers(account, newGroupIDs) { - am.updateAccountPeers(ctx, account) - } - for _, storeEvent := range eventsToStore { storeEvent() } + if updateAccountPeers { + am.updateAccountPeers(ctx, accountID) + } + return nil } // prepareGroupEvents prepares a list of event functions to be stored. -func (am *DefaultAccountManager) prepareGroupEvents(ctx context.Context, userID string, accountID string, newGroup, oldGroup *nbgroup.Group, account *Account) []func() { +func (am *DefaultAccountManager) prepareGroupEvents(ctx context.Context, transaction Store, accountID, userID string, newGroup *nbgroup.Group) []func() { var eventsToStore []func() addedPeers := make([]string, 0) removedPeers := make([]string, 0) - if oldGroup != nil { + oldGroup, err := transaction.GetGroupByID(ctx, LockingStrengthShare, accountID, newGroup.ID) + if err == nil && oldGroup != nil { addedPeers = difference(newGroup.Peers, oldGroup.Peers) removedPeers = difference(oldGroup.Peers, newGroup.Peers) } else { @@ -155,35 +152,42 @@ func (am *DefaultAccountManager) prepareGroupEvents(ctx context.Context, userID }) } - for _, p := range addedPeers { - peer := account.Peers[p] - if peer == nil { - log.WithContext(ctx).Errorf("peer %s not found under account %s while saving group", p, accountID) + modifiedPeers := slices.Concat(addedPeers, removedPeers) + peers, err := transaction.GetPeersByIDs(ctx, LockingStrengthShare, accountID, modifiedPeers) + if err != nil { + log.WithContext(ctx).Debugf("failed to get peers for group events: %v", err) + return nil + } + + for _, peerID := range addedPeers { + peer, ok := peers[peerID] + if !ok { + log.WithContext(ctx).Debugf("skipped adding peer: %s GroupAddedToPeer activity: peer not found in store", peerID) continue } - peerCopy := peer // copy to avoid closure issues + eventsToStore = append(eventsToStore, func() { - am.StoreEvent(ctx, userID, peerCopy.ID, accountID, activity.GroupAddedToPeer, - map[string]any{ - "group": newGroup.Name, "group_id": newGroup.ID, "peer_ip": peerCopy.IP.String(), - "peer_fqdn": peerCopy.FQDN(am.GetDNSDomain()), - }) + meta := map[string]any{ + "group": newGroup.Name, "group_id": newGroup.ID, + "peer_ip": peer.IP.String(), "peer_fqdn": peer.FQDN(am.GetDNSDomain()), + } + am.StoreEvent(ctx, userID, peer.ID, accountID, activity.GroupAddedToPeer, meta) }) } - for _, p := range removedPeers { - peer := account.Peers[p] - if peer == nil { - log.WithContext(ctx).Errorf("peer %s not found under account %s while saving group", p, accountID) + for _, peerID := range removedPeers { + peer, ok := peers[peerID] + if !ok { + log.WithContext(ctx).Debugf("skipped adding peer: %s GroupRemovedFromPeer activity: peer not found in store", peerID) continue } - peerCopy := peer // copy to avoid closure issues + eventsToStore = append(eventsToStore, func() { - am.StoreEvent(ctx, userID, peerCopy.ID, accountID, activity.GroupRemovedFromPeer, - map[string]any{ - "group": newGroup.Name, "group_id": newGroup.ID, "peer_ip": peerCopy.IP.String(), - "peer_fqdn": peerCopy.FQDN(am.GetDNSDomain()), - }) + meta := map[string]any{ + "group": newGroup.Name, "group_id": newGroup.ID, + "peer_ip": peer.IP.String(), "peer_fqdn": peer.FQDN(am.GetDNSDomain()), + } + am.StoreEvent(ctx, userID, peer.ID, accountID, activity.GroupRemovedFromPeer, meta) }) } @@ -206,42 +210,10 @@ func difference(a, b []string) []string { } // DeleteGroup object of the peers. -func (am *DefaultAccountManager) DeleteGroup(ctx context.Context, accountId, userId, groupID string) error { - unlock := am.Store.AcquireWriteLockByUID(ctx, accountId) +func (am *DefaultAccountManager) DeleteGroup(ctx context.Context, accountID, userID, groupID string) error { + unlock := am.Store.AcquireWriteLockByUID(ctx, accountID) defer unlock() - - account, err := am.Store.GetAccount(ctx, accountId) - if err != nil { - return err - } - - group, ok := account.Groups[groupID] - if !ok { - return nil - } - - allGroup, err := account.GetGroupAll() - if err != nil { - return err - } - - if allGroup.ID == groupID { - return status.Errorf(status.InvalidArgument, "deleting group ALL is not allowed") - } - - if err = validateDeleteGroup(account, group, userId); err != nil { - return err - } - delete(account.Groups, groupID) - - account.Network.IncSerial() - if err = am.Store.SaveAccount(ctx, account); err != nil { - return err - } - - am.StoreEvent(ctx, userId, groupID, accountId, activity.GroupDeleted, group.EventMeta()) - - return nil + return am.DeleteGroups(ctx, accountID, userID, []string{groupID}) } // DeleteGroups deletes groups from an account. @@ -250,58 +222,56 @@ func (am *DefaultAccountManager) DeleteGroup(ctx context.Context, accountId, use // // If an error occurs while deleting a group, the function skips it and continues deleting other groups. // Errors are collected and returned at the end. -func (am *DefaultAccountManager) DeleteGroups(ctx context.Context, accountId, userId string, groupIDs []string) error { - account, err := am.Store.GetAccount(ctx, accountId) +func (am *DefaultAccountManager) DeleteGroups(ctx context.Context, accountID, userID string, groupIDs []string) error { + user, err := am.Store.GetUserByUserID(ctx, LockingStrengthShare, userID) if err != nil { return err } - var allErrors error - - deletedGroups := make([]*nbgroup.Group, 0, len(groupIDs)) - for _, groupID := range groupIDs { - group, ok := account.Groups[groupID] - if !ok { - continue - } - - if err := validateDeleteGroup(account, group, userId); err != nil { - allErrors = errors.Join(allErrors, fmt.Errorf("failed to delete group %s: %w", groupID, err)) - continue - } - - delete(account.Groups, groupID) - deletedGroups = append(deletedGroups, group) + if user.AccountID != accountID { + return status.NewUserNotPartOfAccountError() } - account.Network.IncSerial() - if err = am.Store.SaveAccount(ctx, account); err != nil { - return err + if user.IsRegularUser() { + return status.NewAdminPermissionError() } - for _, g := range deletedGroups { - am.StoreEvent(ctx, userId, g.ID, accountId, activity.GroupDeleted, g.EventMeta()) - } + var allErrors error + var groupIDsToDelete []string + var deletedGroups []*nbgroup.Group - return allErrors -} + err = am.Store.ExecuteInTransaction(ctx, func(transaction Store) error { + for _, groupID := range groupIDs { + group, err := transaction.GetGroupByID(ctx, LockingStrengthUpdate, accountID, groupID) + if err != nil { + allErrors = errors.Join(allErrors, err) + continue + } -// ListGroups objects of the peers -func (am *DefaultAccountManager) ListGroups(ctx context.Context, accountID string) ([]*nbgroup.Group, error) { - unlock := am.Store.AcquireWriteLockByUID(ctx, accountID) - defer unlock() + if err := validateDeleteGroup(ctx, transaction, group, userID); err != nil { + allErrors = errors.Join(allErrors, err) + continue + } + + groupIDsToDelete = append(groupIDsToDelete, groupID) + deletedGroups = append(deletedGroups, group) + } - account, err := am.Store.GetAccount(ctx, accountID) + if err = transaction.IncrementNetworkSerial(ctx, LockingStrengthUpdate, accountID); err != nil { + return err + } + + return transaction.DeleteGroups(ctx, LockingStrengthUpdate, accountID, groupIDsToDelete) + }) if err != nil { - return nil, err + return err } - groups := make([]*nbgroup.Group, 0, len(account.Groups)) - for _, item := range account.Groups { - groups = append(groups, item) + for _, group := range deletedGroups { + am.StoreEvent(ctx, userID, group.ID, accountID, activity.GroupDeleted, group.EventMeta()) } - return groups, nil + return allErrors } // GroupAddPeer appends peer to the group @@ -309,34 +279,37 @@ func (am *DefaultAccountManager) GroupAddPeer(ctx context.Context, accountID, gr unlock := am.Store.AcquireWriteLockByUID(ctx, accountID) defer unlock() - account, err := am.Store.GetAccount(ctx, accountID) - if err != nil { - return err - } + var group *nbgroup.Group + var updateAccountPeers bool + var err error - group, ok := account.Groups[groupID] - if !ok { - return status.Errorf(status.NotFound, "group with ID %s not found", groupID) - } + err = am.Store.ExecuteInTransaction(ctx, func(transaction Store) error { + group, err = transaction.GetGroupByID(context.Background(), LockingStrengthUpdate, accountID, groupID) + if err != nil { + return err + } - add := true - for _, itemID := range group.Peers { - if itemID == peerID { - add = false - break + if updated := group.AddPeer(peerID); !updated { + return nil + } + + updateAccountPeers, err = areGroupChangesAffectPeers(ctx, transaction, accountID, []string{groupID}) + if err != nil { + return err + } + + if err = transaction.IncrementNetworkSerial(ctx, LockingStrengthUpdate, accountID); err != nil { + return err } - } - if add { - group.Peers = append(group.Peers, peerID) - } - account.Network.IncSerial() - if err = am.Store.SaveAccount(ctx, account); err != nil { + return transaction.SaveGroup(ctx, LockingStrengthUpdate, group) + }) + if err != nil { return err } - if areGroupChangesAffectPeers(account, []string{group.ID}) { - am.updateAccountPeers(ctx, account) + if updateAccountPeers { + am.updateAccountPeers(ctx, accountID) } return nil @@ -347,90 +320,162 @@ func (am *DefaultAccountManager) GroupDeletePeer(ctx context.Context, accountID, unlock := am.Store.AcquireWriteLockByUID(ctx, accountID) defer unlock() - account, err := am.Store.GetAccount(ctx, accountID) + var group *nbgroup.Group + var updateAccountPeers bool + var err error + + err = am.Store.ExecuteInTransaction(ctx, func(transaction Store) error { + group, err = transaction.GetGroupByID(context.Background(), LockingStrengthUpdate, accountID, groupID) + if err != nil { + return err + } + + if updated := group.RemovePeer(peerID); !updated { + return nil + } + + updateAccountPeers, err = areGroupChangesAffectPeers(ctx, transaction, accountID, []string{groupID}) + if err != nil { + return err + } + + if err = transaction.IncrementNetworkSerial(ctx, LockingStrengthUpdate, accountID); err != nil { + return err + } + + return transaction.SaveGroup(ctx, LockingStrengthUpdate, group) + }) if err != nil { return err } - group, ok := account.Groups[groupID] - if !ok { - return status.Errorf(status.NotFound, "group with ID %s not found", groupID) + if updateAccountPeers { + am.updateAccountPeers(ctx, accountID) } - account.Network.IncSerial() - for i, itemID := range group.Peers { - if itemID == peerID { - group.Peers = append(group.Peers[:i], group.Peers[i+1:]...) - if err := am.Store.SaveAccount(ctx, account); err != nil { + return nil +} + +// validateNewGroup validates the new group for existence and required fields. +func validateNewGroup(ctx context.Context, transaction Store, accountID string, newGroup *nbgroup.Group) error { + if newGroup.ID == "" && newGroup.Issued != nbgroup.GroupIssuedAPI { + return status.Errorf(status.InvalidArgument, "%s group without ID set", newGroup.Issued) + } + + if newGroup.ID == "" && newGroup.Issued == nbgroup.GroupIssuedAPI { + existingGroup, err := transaction.GetGroupByName(ctx, LockingStrengthShare, accountID, newGroup.Name) + if err != nil { + if s, ok := status.FromError(err); !ok || s.Type() != status.NotFound { return err } } + + // Prevent duplicate groups for API-issued groups. + // Integration or JWT groups can be duplicated as they are coming from the IdP that we don't have control of. + if existingGroup != nil { + return status.Errorf(status.AlreadyExists, "group with name %s already exists", newGroup.Name) + } + + newGroup.ID = xid.New().String() } - if areGroupChangesAffectPeers(account, []string{group.ID}) { - am.updateAccountPeers(ctx, account) + for _, peerID := range newGroup.Peers { + _, err := transaction.GetPeerByID(ctx, LockingStrengthShare, accountID, peerID) + if err != nil { + return status.Errorf(status.InvalidArgument, "peer with ID \"%s\" not found", peerID) + } } return nil } -func validateDeleteGroup(account *Account, group *nbgroup.Group, userID string) error { +func validateDeleteGroup(ctx context.Context, transaction Store, group *nbgroup.Group, userID string) error { // disable a deleting integration group if the initiator is not an admin service user if group.Issued == nbgroup.GroupIssuedIntegration { - executingUser := account.Users[userID] - if executingUser == nil { - return status.Errorf(status.NotFound, "user not found") + executingUser, err := transaction.GetUserByUserID(ctx, LockingStrengthShare, userID) + if err != nil { + return err } if executingUser.Role != UserRoleAdmin || !executingUser.IsServiceUser { return status.Errorf(status.PermissionDenied, "only service users with admin power can delete integration group") } } - if isLinked, linkedRoute := isGroupLinkedToRoute(account.Routes, group.ID); isLinked { + if group.IsGroupAll() { + return status.Errorf(status.InvalidArgument, "deleting group ALL is not allowed") + } + + if isLinked, linkedRoute := isGroupLinkedToRoute(ctx, transaction, group.AccountID, group.ID); isLinked { return &GroupLinkError{"route", string(linkedRoute.NetID)} } - if isLinked, linkedDns := isGroupLinkedToDns(account.NameServerGroups, group.ID); isLinked { + if isLinked, linkedDns := isGroupLinkedToDns(ctx, transaction, group.AccountID, group.ID); isLinked { return &GroupLinkError{"name server groups", linkedDns.Name} } - if isLinked, linkedPolicy := isGroupLinkedToPolicy(account.Policies, group.ID); isLinked { + if isLinked, linkedPolicy := isGroupLinkedToPolicy(ctx, transaction, group.AccountID, group.ID); isLinked { return &GroupLinkError{"policy", linkedPolicy.Name} } - if isLinked, linkedSetupKey := isGroupLinkedToSetupKey(account.SetupKeys, group.ID); isLinked { + if isLinked, linkedSetupKey := isGroupLinkedToSetupKey(ctx, transaction, group.AccountID, group.ID); isLinked { return &GroupLinkError{"setup key", linkedSetupKey.Name} } - if isLinked, linkedUser := isGroupLinkedToUser(account.Users, group.ID); isLinked { + if isLinked, linkedUser := isGroupLinkedToUser(ctx, transaction, group.AccountID, group.ID); isLinked { return &GroupLinkError{"user", linkedUser.Id} } - if slices.Contains(account.DNSSettings.DisabledManagementGroups, group.ID) { + return checkGroupLinkedToSettings(ctx, transaction, group) +} + +// checkGroupLinkedToSettings verifies if a group is linked to any settings in the account. +func checkGroupLinkedToSettings(ctx context.Context, transaction Store, group *nbgroup.Group) error { + dnsSettings, err := transaction.GetAccountDNSSettings(ctx, LockingStrengthShare, group.AccountID) + if err != nil { + return err + } + + if slices.Contains(dnsSettings.DisabledManagementGroups, group.ID) { return &GroupLinkError{"disabled DNS management groups", group.Name} } - if account.Settings.Extra != nil { - if slices.Contains(account.Settings.Extra.IntegratedValidatorGroups, group.ID) { - return &GroupLinkError{"integrated validator", group.Name} - } + settings, err := transaction.GetAccountSettings(ctx, LockingStrengthShare, group.AccountID) + if err != nil { + return err + } + + if settings.Extra != nil && slices.Contains(settings.Extra.IntegratedValidatorGroups, group.ID) { + return &GroupLinkError{"integrated validator", group.Name} } return nil } // isGroupLinkedToRoute checks if a group is linked to any route in the account. -func isGroupLinkedToRoute(routes map[route.ID]*route.Route, groupID string) (bool, *route.Route) { +func isGroupLinkedToRoute(ctx context.Context, transaction Store, accountID string, groupID string) (bool, *route.Route) { + routes, err := transaction.GetAccountRoutes(ctx, LockingStrengthShare, accountID) + if err != nil { + log.WithContext(ctx).Errorf("error retrieving routes while checking group linkage: %v", err) + return false, nil + } + for _, r := range routes { if slices.Contains(r.Groups, groupID) || slices.Contains(r.PeerGroups, groupID) { return true, r } } + return false, nil } // isGroupLinkedToPolicy checks if a group is linked to any policy in the account. -func isGroupLinkedToPolicy(policies []*Policy, groupID string) (bool, *Policy) { +func isGroupLinkedToPolicy(ctx context.Context, transaction Store, accountID string, groupID string) (bool, *Policy) { + policies, err := transaction.GetAccountPolicies(ctx, LockingStrengthShare, accountID) + if err != nil { + log.WithContext(ctx).Errorf("error retrieving policies while checking group linkage: %v", err) + return false, nil + } + for _, policy := range policies { for _, rule := range policy.Rules { if slices.Contains(rule.Sources, groupID) || slices.Contains(rule.Destinations, groupID) { @@ -442,7 +487,13 @@ func isGroupLinkedToPolicy(policies []*Policy, groupID string) (bool, *Policy) { } // isGroupLinkedToDns checks if a group is linked to any nameserver group in the account. -func isGroupLinkedToDns(nameServerGroups map[string]*nbdns.NameServerGroup, groupID string) (bool, *nbdns.NameServerGroup) { +func isGroupLinkedToDns(ctx context.Context, transaction Store, accountID string, groupID string) (bool, *nbdns.NameServerGroup) { + nameServerGroups, err := transaction.GetAccountNameServerGroups(ctx, LockingStrengthShare, accountID) + if err != nil { + log.WithContext(ctx).Errorf("error retrieving name server groups while checking group linkage: %v", err) + return false, nil + } + for _, dns := range nameServerGroups { for _, g := range dns.Groups { if g == groupID { @@ -450,11 +501,18 @@ func isGroupLinkedToDns(nameServerGroups map[string]*nbdns.NameServerGroup, grou } } } + return false, nil } // isGroupLinkedToSetupKey checks if a group is linked to any setup key in the account. -func isGroupLinkedToSetupKey(setupKeys map[string]*SetupKey, groupID string) (bool, *SetupKey) { +func isGroupLinkedToSetupKey(ctx context.Context, transaction Store, accountID string, groupID string) (bool, *SetupKey) { + setupKeys, err := transaction.GetAccountSetupKeys(ctx, LockingStrengthShare, accountID) + if err != nil { + log.WithContext(ctx).Errorf("error retrieving setup keys while checking group linkage: %v", err) + return false, nil + } + for _, setupKey := range setupKeys { if slices.Contains(setupKey.AutoGroups, groupID) { return true, setupKey @@ -464,7 +522,13 @@ func isGroupLinkedToSetupKey(setupKeys map[string]*SetupKey, groupID string) (bo } // isGroupLinkedToUser checks if a group is linked to any user in the account. -func isGroupLinkedToUser(users map[string]*User, groupID string) (bool, *User) { +func isGroupLinkedToUser(ctx context.Context, transaction Store, accountID string, groupID string) (bool, *User) { + users, err := transaction.GetAccountUsers(ctx, LockingStrengthShare, accountID) + if err != nil { + log.WithContext(ctx).Errorf("error retrieving users while checking group linkage: %v", err) + return false, nil + } + for _, user := range users { if slices.Contains(user.AutoGroups, groupID) { return true, user @@ -473,31 +537,41 @@ func isGroupLinkedToUser(users map[string]*User, groupID string) (bool, *User) { return false, nil } -// anyGroupHasPeers checks if any of the given groups in the account have peers. -func anyGroupHasPeers(account *Account, groupIDs []string) bool { - for _, groupID := range groupIDs { - if group, exists := account.Groups[groupID]; exists && group.HasPeers() { - return true - } +// areGroupChangesAffectPeers checks if any changes to the specified groups will affect peers. +func areGroupChangesAffectPeers(ctx context.Context, transaction Store, accountID string, groupIDs []string) (bool, error) { + if len(groupIDs) == 0 { + return false, nil + } + + dnsSettings, err := transaction.GetAccountDNSSettings(ctx, LockingStrengthShare, accountID) + if err != nil { + return false, err } - return false -} -func areGroupChangesAffectPeers(account *Account, groupIDs []string) bool { for _, groupID := range groupIDs { - if slices.Contains(account.DNSSettings.DisabledManagementGroups, groupID) { - return true + if slices.Contains(dnsSettings.DisabledManagementGroups, groupID) { + return true, nil } - if linked, _ := isGroupLinkedToDns(account.NameServerGroups, groupID); linked { - return true + if linked, _ := isGroupLinkedToDns(ctx, transaction, accountID, groupID); linked { + return true, nil } - if linked, _ := isGroupLinkedToPolicy(account.Policies, groupID); linked { - return true + if linked, _ := isGroupLinkedToPolicy(ctx, transaction, accountID, groupID); linked { + return true, nil } - if linked, _ := isGroupLinkedToRoute(account.Routes, groupID); linked { - return true + if linked, _ := isGroupLinkedToRoute(ctx, transaction, accountID, groupID); linked { + return true, nil } } + return false, nil +} + +// anyGroupHasPeers checks if any of the given groups in the account have peers. +func anyGroupHasPeers(account *Account, groupIDs []string) bool { + for _, groupID := range groupIDs { + if group, exists := account.Groups[groupID]; exists && group.HasPeers() { + return true + } + } return false } diff --git a/management/server/group/group.go b/management/server/group/group.go index e98e5ecc4b5..24c60d3ceef 100644 --- a/management/server/group/group.go +++ b/management/server/group/group.go @@ -54,3 +54,30 @@ func (g *Group) HasPeers() bool { func (g *Group) IsGroupAll() bool { return g.Name == "All" } + +// AddPeer adds peerID to Peers if not present, returning true if added. +func (g *Group) AddPeer(peerID string) bool { + if peerID == "" { + return false + } + + for _, itemID := range g.Peers { + if itemID == peerID { + return false + } + } + + g.Peers = append(g.Peers, peerID) + return true +} + +// RemovePeer removes peerID from Peers if present, returning true if removed. +func (g *Group) RemovePeer(peerID string) bool { + for i, itemID := range g.Peers { + if itemID == peerID { + g.Peers = append(g.Peers[:i], g.Peers[i+1:]...) + return true + } + } + return false +} diff --git a/management/server/group/group_test.go b/management/server/group/group_test.go new file mode 100644 index 00000000000..cb002f8d9e1 --- /dev/null +++ b/management/server/group/group_test.go @@ -0,0 +1,90 @@ +package group + +import ( + "testing" + + "github.com/stretchr/testify/assert" +) + +func TestAddPeer(t *testing.T) { + t.Run("add new peer to empty slice", func(t *testing.T) { + group := &Group{Peers: []string{}} + peerID := "peer1" + assert.True(t, group.AddPeer(peerID)) + assert.Contains(t, group.Peers, peerID) + }) + + t.Run("add new peer to nil slice", func(t *testing.T) { + group := &Group{Peers: nil} + peerID := "peer1" + assert.True(t, group.AddPeer(peerID)) + assert.Contains(t, group.Peers, peerID) + }) + + t.Run("add new peer to non-empty slice", func(t *testing.T) { + group := &Group{Peers: []string{"peer1", "peer2"}} + peerID := "peer3" + assert.True(t, group.AddPeer(peerID)) + assert.Contains(t, group.Peers, peerID) + }) + + t.Run("add duplicate peer", func(t *testing.T) { + group := &Group{Peers: []string{"peer1", "peer2"}} + peerID := "peer1" + assert.False(t, group.AddPeer(peerID)) + assert.Equal(t, 2, len(group.Peers)) + }) + + t.Run("add empty peer", func(t *testing.T) { + group := &Group{Peers: []string{"peer1", "peer2"}} + peerID := "" + assert.False(t, group.AddPeer(peerID)) + assert.Equal(t, 2, len(group.Peers)) + }) +} + +func TestRemovePeer(t *testing.T) { + t.Run("remove existing peer from slice", func(t *testing.T) { + group := &Group{Peers: []string{"peer1", "peer2", "peer3"}} + peerID := "peer2" + assert.True(t, group.RemovePeer(peerID)) + assert.NotContains(t, group.Peers, peerID) + assert.Equal(t, 2, len(group.Peers)) + }) + + t.Run("remove peer from empty slice", func(t *testing.T) { + group := &Group{Peers: []string{}} + peerID := "peer1" + assert.False(t, group.RemovePeer(peerID)) + assert.Equal(t, 0, len(group.Peers)) + }) + + t.Run("remove peer from nil slice", func(t *testing.T) { + group := &Group{Peers: nil} + peerID := "peer1" + assert.False(t, group.RemovePeer(peerID)) + assert.Nil(t, group.Peers) + }) + + t.Run("remove non-existent peer", func(t *testing.T) { + group := &Group{Peers: []string{"peer1", "peer2"}} + peerID := "peer3" + assert.False(t, group.RemovePeer(peerID)) + assert.Equal(t, 2, len(group.Peers)) + }) + + t.Run("remove peer from single-item slice", func(t *testing.T) { + group := &Group{Peers: []string{"peer1"}} + peerID := "peer1" + assert.True(t, group.RemovePeer(peerID)) + assert.Equal(t, 0, len(group.Peers)) + assert.NotContains(t, group.Peers, peerID) + }) + + t.Run("remove empty peer", func(t *testing.T) { + group := &Group{Peers: []string{"peer1", "peer2"}} + peerID := "" + assert.False(t, group.RemovePeer(peerID)) + assert.Equal(t, 2, len(group.Peers)) + }) +} diff --git a/management/server/group_test.go b/management/server/group_test.go index 89184e81927..59094a23e92 100644 --- a/management/server/group_test.go +++ b/management/server/group_test.go @@ -208,7 +208,7 @@ func TestDefaultAccountManager_DeleteGroups(t *testing.T) { { name: "delete non-existent group", groupIDs: []string{"non-existent-group"}, - expectedDeleted: []string{"non-existent-group"}, + expectedReasons: []string{"group: non-existent-group not found"}, }, { name: "delete multiple groups with mixed results", diff --git a/management/server/integrated_validator.go b/management/server/integrated_validator.go index 99e6b204c2b..0c70b702a01 100644 --- a/management/server/integrated_validator.go +++ b/management/server/integrated_validator.go @@ -52,25 +52,22 @@ func (am *DefaultAccountManager) UpdateIntegratedValidatorGroups(ctx context.Con return am.Store.SaveAccount(ctx, a) } -func (am *DefaultAccountManager) GroupValidation(ctx context.Context, accountId string, groups []string) (bool, error) { - if len(groups) == 0 { +func (am *DefaultAccountManager) GroupValidation(ctx context.Context, accountID string, groupIDs []string) (bool, error) { + if len(groupIDs) == 0 { return true, nil } - accountsGroups, err := am.ListGroups(ctx, accountId) - if err != nil { - return false, err - } - for _, group := range groups { - var found bool - for _, accountGroup := range accountsGroups { - if accountGroup.ID == group { - found = true - break + + err := am.Store.ExecuteInTransaction(ctx, func(transaction Store) error { + for _, groupID := range groupIDs { + _, err := transaction.GetGroupByID(context.Background(), LockingStrengthShare, accountID, groupID) + if err != nil { + return err } } - if !found { - return false, nil - } + return nil + }) + if err != nil { + return false, err } return true, nil diff --git a/management/server/mock_server/account_mock.go b/management/server/mock_server/account_mock.go index d7139bb2a5f..aa6a47b152e 100644 --- a/management/server/mock_server/account_mock.go +++ b/management/server/mock_server/account_mock.go @@ -45,7 +45,6 @@ type MockAccountManager struct { SaveGroupsFunc func(ctx context.Context, accountID, userID string, groups []*group.Group) error DeleteGroupFunc func(ctx context.Context, accountID, userId, groupID string) error DeleteGroupsFunc func(ctx context.Context, accountId, userId string, groupIDs []string) error - ListGroupsFunc func(ctx context.Context, accountID string) ([]*group.Group, error) GroupAddPeerFunc func(ctx context.Context, accountID, groupID, peerID string) error GroupDeletePeerFunc func(ctx context.Context, accountID, groupID, peerID string) error DeleteRuleFunc func(ctx context.Context, accountID, ruleID, userID string) error @@ -354,14 +353,6 @@ func (am *MockAccountManager) DeleteGroups(ctx context.Context, accountId, userI return status.Errorf(codes.Unimplemented, "method DeleteGroups is not implemented") } -// ListGroups mock implementation of ListGroups from server.AccountManager interface -func (am *MockAccountManager) ListGroups(ctx context.Context, accountID string) ([]*group.Group, error) { - if am.ListGroupsFunc != nil { - return am.ListGroupsFunc(ctx, accountID) - } - return nil, status.Errorf(codes.Unimplemented, "method ListGroups is not implemented") -} - // GroupAddPeer mock implementation of GroupAddPeer from server.AccountManager interface func (am *MockAccountManager) GroupAddPeer(ctx context.Context, accountID, groupID, peerID string) error { if am.GroupAddPeerFunc != nil { diff --git a/management/server/nameserver.go b/management/server/nameserver.go index 5ebd263dcc2..957008714e5 100644 --- a/management/server/nameserver.go +++ b/management/server/nameserver.go @@ -71,7 +71,7 @@ func (am *DefaultAccountManager) CreateNameServerGroup(ctx context.Context, acco } if anyGroupHasPeers(account, newNSGroup.Groups) { - am.updateAccountPeers(ctx, account) + am.updateAccountPeers(ctx, accountID) } am.StoreEvent(ctx, userID, newNSGroup.ID, accountID, activity.NameserverGroupCreated, newNSGroup.EventMeta()) @@ -106,7 +106,7 @@ func (am *DefaultAccountManager) SaveNameServerGroup(ctx context.Context, accoun } if areNameServerGroupChangesAffectPeers(account, nsGroupToSave, oldNSGroup) { - am.updateAccountPeers(ctx, account) + am.updateAccountPeers(ctx, accountID) } am.StoreEvent(ctx, userID, nsGroupToSave.ID, accountID, activity.NameserverGroupUpdated, nsGroupToSave.EventMeta()) @@ -136,7 +136,7 @@ func (am *DefaultAccountManager) DeleteNameServerGroup(ctx context.Context, acco } if anyGroupHasPeers(account, nsGroup.Groups) { - am.updateAccountPeers(ctx, account) + am.updateAccountPeers(ctx, accountID) } am.StoreEvent(ctx, userID, nsGroup.ID, accountID, activity.NameserverGroupDeleted, nsGroup.EventMeta()) diff --git a/management/server/peer.go b/management/server/peer.go index 87b5b0e4ea7..1405dead85a 100644 --- a/management/server/peer.go +++ b/management/server/peer.go @@ -133,7 +133,7 @@ func (am *DefaultAccountManager) MarkPeerConnected(ctx context.Context, peerPubK if expired { // we need to update other peers because when peer login expires all other peers are notified to disconnect from // the expired one. Here we notify them that connection is now allowed again. - am.updateAccountPeers(ctx, account) + am.updateAccountPeers(ctx, account.Id) } return nil @@ -271,7 +271,7 @@ func (am *DefaultAccountManager) UpdatePeer(ctx context.Context, accountID, user } if peerLabelUpdated || requiresPeerUpdates { - am.updateAccountPeers(ctx, account) + am.updateAccountPeers(ctx, accountID) } return peer, nil @@ -335,7 +335,10 @@ func (am *DefaultAccountManager) DeletePeer(ctx context.Context, accountID, peer return err } - updateAccountPeers := isPeerInActiveGroup(account, peerID) + updateAccountPeers, err := am.isPeerInActiveGroup(ctx, account, peerID) + if err != nil { + return err + } err = am.deletePeers(ctx, account, []string{peerID}, userID) if err != nil { @@ -348,7 +351,7 @@ func (am *DefaultAccountManager) DeletePeer(ctx context.Context, accountID, peer } if updateAccountPeers { - am.updateAccountPeers(ctx, account) + am.updateAccountPeers(ctx, accountID) } return nil @@ -555,7 +558,7 @@ func (am *DefaultAccountManager) AddPeer(ctx context.Context, setupKey, userID s return fmt.Errorf("failed to add peer to account: %w", err) } - err = transaction.IncrementNetworkSerial(ctx, accountID) + err = transaction.IncrementNetworkSerial(ctx, LockingStrengthUpdate, accountID) if err != nil { return fmt.Errorf("failed to increment network serial: %w", err) } @@ -598,10 +601,15 @@ func (am *DefaultAccountManager) AddPeer(ctx context.Context, setupKey, userID s if err != nil { return nil, nil, nil, fmt.Errorf("error getting all group ID: %w", err) } - groupsToAdd = append(groupsToAdd, allGroup.ID) - if areGroupChangesAffectPeers(account, groupsToAdd) { - am.updateAccountPeers(ctx, account) + + newGroupsAffectsPeers, err := areGroupChangesAffectPeers(ctx, am.Store, accountID, groupsToAdd) + if err != nil { + return nil, nil, nil, err + } + + if newGroupsAffectsPeers { + am.updateAccountPeers(ctx, accountID) } approvedPeersMap, err := am.GetValidatedPeers(account) @@ -666,7 +674,7 @@ func (am *DefaultAccountManager) SyncPeer(ctx context.Context, sync PeerSync, ac } if sync.UpdateAccountPeers { - am.updateAccountPeers(ctx, account) + am.updateAccountPeers(ctx, account.Id) } } @@ -685,7 +693,7 @@ func (am *DefaultAccountManager) SyncPeer(ctx context.Context, sync PeerSync, ac } if isStatusChanged { - am.updateAccountPeers(ctx, account) + am.updateAccountPeers(ctx, account.Id) } validPeersMap, err := am.GetValidatedPeers(account) @@ -816,7 +824,7 @@ func (am *DefaultAccountManager) LoginPeer(ctx context.Context, login PeerLogin) } if updateRemotePeers || isStatusChanged { - am.updateAccountPeers(ctx, account) + am.updateAccountPeers(ctx, accountID) } return am.getValidatedPeerWithMap(ctx, isRequiresApproval, account, peer) @@ -979,7 +987,7 @@ func (am *DefaultAccountManager) GetPeer(ctx context.Context, accountID, peerID, // updateAccountPeers updates all peers that belong to an account. // Should be called when changes have to be synced to peers. -func (am *DefaultAccountManager) updateAccountPeers(ctx context.Context, account *Account) { +func (am *DefaultAccountManager) updateAccountPeers(ctx context.Context, accountID string) { start := time.Now() defer func() { if am.metrics != nil { @@ -987,6 +995,11 @@ func (am *DefaultAccountManager) updateAccountPeers(ctx context.Context, account } }() + account, err := am.requestBuffer.GetAccountWithBackpressure(ctx, accountID) + if err != nil { + log.WithContext(ctx).Errorf("failed to send out updates to peers: %v", err) + return + } peers := account.GetPeers() approvedPeersMap, err := am.GetValidatedPeers(account) @@ -1033,12 +1046,12 @@ func ConvertSliceToMap(existingLabels []string) map[string]struct{} { // IsPeerInActiveGroup checks if the given peer is part of a group that is used // in an active DNS, route, or ACL configuration. -func isPeerInActiveGroup(account *Account, peerID string) bool { +func (am *DefaultAccountManager) isPeerInActiveGroup(ctx context.Context, account *Account, peerID string) (bool, error) { peerGroupIDs := make([]string, 0) for _, group := range account.Groups { if slices.Contains(group.Peers, peerID) { peerGroupIDs = append(peerGroupIDs, group.ID) } } - return areGroupChangesAffectPeers(account, peerGroupIDs) + return areGroupChangesAffectPeers(ctx, am.Store, account.Id, peerGroupIDs) } diff --git a/management/server/peer_test.go b/management/server/peer_test.go index 78885ea1b72..4e2dcb2c313 100644 --- a/management/server/peer_test.go +++ b/management/server/peer_test.go @@ -877,7 +877,7 @@ func BenchmarkUpdateAccountPeers(b *testing.B) { start := time.Now() for i := 0; i < b.N; i++ { - manager.updateAccountPeers(ctx, account) + manager.updateAccountPeers(ctx, account.Id) } duration := time.Since(start) diff --git a/management/server/policy.go b/management/server/policy.go index 43a925f8850..8a5733f011c 100644 --- a/management/server/policy.go +++ b/management/server/policy.go @@ -377,7 +377,7 @@ func (am *DefaultAccountManager) SavePolicy(ctx context.Context, accountID, user am.StoreEvent(ctx, userID, policy.ID, accountID, action, policy.EventMeta()) if updateAccountPeers { - am.updateAccountPeers(ctx, account) + am.updateAccountPeers(ctx, accountID) } return nil @@ -406,7 +406,7 @@ func (am *DefaultAccountManager) DeletePolicy(ctx context.Context, accountID, po am.StoreEvent(ctx, userID, policy.ID, accountID, activity.PolicyRemoved, policy.EventMeta()) if anyGroupHasPeers(account, policy.ruleGroups()) { - am.updateAccountPeers(ctx, account) + am.updateAccountPeers(ctx, accountID) } return nil diff --git a/management/server/posture_checks.go b/management/server/posture_checks.go index 2dccd8f590c..096cff3f5c9 100644 --- a/management/server/posture_checks.go +++ b/management/server/posture_checks.go @@ -69,7 +69,7 @@ func (am *DefaultAccountManager) SavePostureChecks(ctx context.Context, accountI am.StoreEvent(ctx, userID, postureChecks.ID, accountID, action, postureChecks.EventMeta()) if arePostureCheckChangesAffectingPeers(account, postureChecks.ID, exists) { - am.updateAccountPeers(ctx, account) + am.updateAccountPeers(ctx, accountID) } return nil diff --git a/management/server/route.go b/management/server/route.go index 1cf00b37c46..dcf2cb0d32c 100644 --- a/management/server/route.go +++ b/management/server/route.go @@ -238,7 +238,7 @@ func (am *DefaultAccountManager) CreateRoute(ctx context.Context, accountID stri } if isRouteChangeAffectPeers(account, &newRoute) { - am.updateAccountPeers(ctx, account) + am.updateAccountPeers(ctx, accountID) } am.StoreEvent(ctx, userID, string(newRoute.ID), accountID, activity.RouteCreated, newRoute.EventMeta()) @@ -324,7 +324,7 @@ func (am *DefaultAccountManager) SaveRoute(ctx context.Context, accountID, userI } if isRouteChangeAffectPeers(account, oldRoute) || isRouteChangeAffectPeers(account, routeToSave) { - am.updateAccountPeers(ctx, account) + am.updateAccountPeers(ctx, accountID) } am.StoreEvent(ctx, userID, string(routeToSave.ID), accountID, activity.RouteUpdated, routeToSave.EventMeta()) @@ -356,7 +356,7 @@ func (am *DefaultAccountManager) DeleteRoute(ctx context.Context, accountID stri am.StoreEvent(ctx, userID, string(routy.ID), accountID, activity.RouteRemoved, routy.EventMeta()) if isRouteChangeAffectPeers(account, routy) { - am.updateAccountPeers(ctx, account) + am.updateAccountPeers(ctx, accountID) } return nil diff --git a/management/server/route_test.go b/management/server/route_test.go index 4893e19b9f3..5c848f68c7b 100644 --- a/management/server/route_test.go +++ b/management/server/route_test.go @@ -1091,7 +1091,7 @@ func TestGetNetworkMap_RouteSyncPeerGroups(t *testing.T) { require.NoError(t, err) assert.Len(t, peer4Routes.Routes, 1, "HA route should have 1 server route") - groups, err := am.ListGroups(context.Background(), account.Id) + groups, err := am.Store.GetAccountGroups(context.Background(), LockingStrengthShare, account.Id) require.NoError(t, err) var groupHA1, groupHA2 *nbgroup.Group for _, group := range groups { diff --git a/management/server/setupkey.go b/management/server/setupkey.go index 960532bf9f1..cae0dfecbde 100644 --- a/management/server/setupkey.go +++ b/management/server/setupkey.go @@ -453,14 +453,14 @@ func (am *DefaultAccountManager) prepareSetupKeyEvents(ctx context.Context, tran modifiedGroups := slices.Concat(addedGroups, removedGroups) groups, err := transaction.GetGroupsByIDs(ctx, LockingStrengthShare, accountID, modifiedGroups) if err != nil { - log.WithContext(ctx).Errorf("issue getting groups for setup key events: %v", err) + log.WithContext(ctx).Debugf("failed to get groups for setup key events: %v", err) return nil } for _, g := range removedGroups { group, ok := groups[g] if !ok { - log.WithContext(ctx).Debugf("skipped adding group: %s GroupRemovedFromSetupKey activity: %v", g, err) + log.WithContext(ctx).Debugf("skipped adding group: %s GroupRemovedFromSetupKey activity: group not found", g) continue } @@ -473,7 +473,7 @@ func (am *DefaultAccountManager) prepareSetupKeyEvents(ctx context.Context, tran for _, g := range addedGroups { group, ok := groups[g] if !ok { - log.WithContext(ctx).Debugf("skipped adding group: %s GroupAddedToSetupKey activity: %v", g, err) + log.WithContext(ctx).Debugf("skipped adding group: %s GroupAddedToSetupKey activity: group not found", g) continue } diff --git a/management/server/sql_store.go b/management/server/sql_store.go index 9dd3e778d43..0ebda6440c1 100644 --- a/management/server/sql_store.go +++ b/management/server/sql_store.go @@ -33,12 +33,13 @@ import ( ) const ( - storeSqliteFileName = "store.db" - idQueryCondition = "id = ?" - keyQueryCondition = "key = ?" - accountAndIDQueryCondition = "account_id = ? and id = ?" - accountIDCondition = "account_id = ?" - peerNotFoundFMT = "peer %s not found" + storeSqliteFileName = "store.db" + idQueryCondition = "id = ?" + keyQueryCondition = "key = ?" + accountAndIDQueryCondition = "account_id = ? and id = ?" + accountAndIDsQueryCondition = "account_id = ? AND id IN ?" + accountIDCondition = "account_id = ?" + peerNotFoundFMT = "peer %s not found" ) // SqlStore represents an account storage backed by a Sql DB persisted to disk @@ -555,9 +556,9 @@ func (s *SqlStore) GetUserByUserID(ctx context.Context, lockStrength LockingStre return &user, nil } -func (s *SqlStore) GetAccountUsers(ctx context.Context, accountID string) ([]*User, error) { +func (s *SqlStore) GetAccountUsers(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*User, error) { var users []*User - result := s.db.Find(&users, accountIDCondition, accountID) + result := s.db.Clauses(clause.Locking{Strength: string(lockStrength)}).Find(&users, accountIDCondition, accountID) if result.Error != nil { if errors.Is(result.Error, gorm.ErrRecordNotFound) { return nil, status.Errorf(status.NotFound, "accountID not found: index lookup failed") @@ -857,7 +858,6 @@ func (s *SqlStore) SaveUserLastLogin(ctx context.Context, accountID, userID stri if errors.Is(result.Error, gorm.ErrRecordNotFound) { return status.NewUserNotFoundError(userID) } - return status.NewGetUserFromStoreError() } user.LastLogin = lastLogin @@ -1045,7 +1045,7 @@ func (s *SqlStore) AddPeerToGroup(ctx context.Context, accountId string, peerId result := s.db.Where(accountAndIDQueryCondition, accountId, groupID).First(&group) if result.Error != nil { if errors.Is(result.Error, gorm.ErrRecordNotFound) { - return status.Errorf(status.NotFound, "group not found for account") + return status.NewGroupNotFoundError(groupID) } return status.Errorf(status.Internal, "issue finding group: %s", result.Error) @@ -1079,10 +1079,45 @@ func (s *SqlStore) AddPeerToAccount(ctx context.Context, peer *nbpeer.Peer) erro return nil } -func (s *SqlStore) IncrementNetworkSerial(ctx context.Context, accountId string) error { - result := s.db.Model(&Account{}).Where(idQueryCondition, accountId).Update("network_serial", gorm.Expr("network_serial + 1")) +// GetPeerByID retrieves a peer by its ID and account ID. +func (s *SqlStore) GetPeerByID(ctx context.Context, lockStrength LockingStrength, accountID, peerID string) (*nbpeer.Peer, error) { + var peer *nbpeer.Peer + result := s.db.Clauses(clause.Locking{Strength: string(lockStrength)}). + First(&peer, accountAndIDQueryCondition, accountID, peerID) + if result.Error != nil { + if errors.Is(result.Error, gorm.ErrRecordNotFound) { + return nil, status.Errorf(status.NotFound, "peer not found") + } + log.WithContext(ctx).Errorf("failed to get peer from store: %s", result.Error) + return nil, status.Errorf(status.Internal, "failed to get peer from store") + } + + return peer, nil +} + +// GetPeersByIDs retrieves peers by their IDs and account ID. +func (s *SqlStore) GetPeersByIDs(ctx context.Context, lockStrength LockingStrength, accountID string, peerIDs []string) (map[string]*nbpeer.Peer, error) { + var peers []*nbpeer.Peer + result := s.db.Clauses(clause.Locking{Strength: string(lockStrength)}).Find(&peers, accountAndIDsQueryCondition, accountID, peerIDs) if result.Error != nil { - return status.Errorf(status.Internal, "issue incrementing network serial count: %s", result.Error) + log.WithContext(ctx).Errorf("failed to get peers by ID's from the store: %s", result.Error) + return nil, status.Errorf(status.Internal, "failed to get peers by ID's from the store") + } + + peersMap := make(map[string]*nbpeer.Peer) + for _, peer := range peers { + peersMap[peer.ID] = peer + } + + return peersMap, nil +} + +func (s *SqlStore) IncrementNetworkSerial(ctx context.Context, lockStrength LockingStrength, accountId string) error { + result := s.db.Clauses(clause.Locking{Strength: string(lockStrength)}). + Model(&Account{}).Where(idQueryCondition, accountId).Update("network_serial", gorm.Expr("network_serial + 1")) + if result.Error != nil { + log.WithContext(ctx).Errorf("failed to increment network serial count in store: %v", result.Error) + return status.Errorf(status.Internal, "failed to increment network serial count in store") } return nil } @@ -1103,7 +1138,8 @@ func (s *SqlStore) ExecuteInTransaction(ctx context.Context, operation func(stor func (s *SqlStore) withTx(tx *gorm.DB) Store { return &SqlStore{ - db: tx, + db: tx, + storeEngine: s.storeEngine, } } @@ -1155,12 +1191,22 @@ func (s *SqlStore) GetAccountDomainAndCategory(ctx context.Context, lockStrength } // GetGroupByID retrieves a group by ID and account ID. -func (s *SqlStore) GetGroupByID(ctx context.Context, lockStrength LockingStrength, groupID, accountID string) (*nbgroup.Group, error) { - return getRecordByID[nbgroup.Group](s.db.Preload(clause.Associations), lockStrength, groupID, accountID) +func (s *SqlStore) GetGroupByID(ctx context.Context, lockStrength LockingStrength, accountID, groupID string) (*nbgroup.Group, error) { + var group *nbgroup.Group + result := s.db.Clauses(clause.Locking{Strength: string(lockStrength)}).First(&group, accountAndIDQueryCondition, accountID, groupID) + if err := result.Error; err != nil { + if errors.Is(err, gorm.ErrRecordNotFound) { + return nil, status.NewGroupNotFoundError(groupID) + } + log.WithContext(ctx).Errorf("failed to get group from store: %s", err) + return nil, status.Errorf(status.Internal, "failed to get group from store") + } + + return group, nil } // GetGroupByName retrieves a group by name and account ID. -func (s *SqlStore) GetGroupByName(ctx context.Context, lockStrength LockingStrength, groupName, accountID string) (*nbgroup.Group, error) { +func (s *SqlStore) GetGroupByName(ctx context.Context, lockStrength LockingStrength, accountID, groupName string) (*nbgroup.Group, error) { var group nbgroup.Group // TODO: This fix is accepted for now, but if we need to handle this more frequently @@ -1172,12 +1218,13 @@ func (s *SqlStore) GetGroupByName(ctx context.Context, lockStrength LockingStren query = query.Order("json_array_length(peers) DESC") } - result := query.First(&group, "name = ? and account_id = ?", groupName, accountID) + result := query.First(&group, "account_id = ? AND name = ?", accountID, groupName) if err := result.Error; err != nil { if errors.Is(result.Error, gorm.ErrRecordNotFound) { - return nil, status.Errorf(status.NotFound, "group not found") + return nil, status.NewGroupNotFoundError(groupName) } - return nil, status.Errorf(status.Internal, "failed to get group from store: %s", result.Error) + log.WithContext(ctx).Errorf("failed to get group by name from store: %v", result.Error) + return nil, status.Errorf(status.Internal, "failed to get group by name from store") } return &group, nil } @@ -1185,7 +1232,7 @@ func (s *SqlStore) GetGroupByName(ctx context.Context, lockStrength LockingStren // GetGroupsByIDs retrieves groups by their IDs and account ID. func (s *SqlStore) GetGroupsByIDs(ctx context.Context, lockStrength LockingStrength, accountID string, groupIDs []string) (map[string]*nbgroup.Group, error) { var groups []*nbgroup.Group - result := s.db.Clauses(clause.Locking{Strength: string(lockStrength)}).Find(&groups, "account_id = ? AND id in ?", accountID, groupIDs) + result := s.db.Clauses(clause.Locking{Strength: string(lockStrength)}).Find(&groups, accountAndIDsQueryCondition, accountID, groupIDs) if result.Error != nil { log.WithContext(ctx).Errorf("failed to get groups by ID's from the store: %s", result.Error) return nil, status.Errorf(status.Internal, "failed to get groups by ID's from the store") @@ -1203,8 +1250,37 @@ func (s *SqlStore) GetGroupsByIDs(ctx context.Context, lockStrength LockingStren func (s *SqlStore) SaveGroup(ctx context.Context, lockStrength LockingStrength, group *nbgroup.Group) error { result := s.db.Clauses(clause.Locking{Strength: string(lockStrength)}).Save(group) if result.Error != nil { - return status.Errorf(status.Internal, "failed to save group to store: %v", result.Error) + log.WithContext(ctx).Errorf("failed to save group to store: %v", result.Error) + return status.Errorf(status.Internal, "failed to save group to store") + } + return nil +} + +// DeleteGroup deletes a group from the database. +func (s *SqlStore) DeleteGroup(ctx context.Context, lockStrength LockingStrength, accountID, groupID string) error { + result := s.db.Clauses(clause.Locking{Strength: string(lockStrength)}). + Delete(&nbgroup.Group{}, accountAndIDQueryCondition, accountID, groupID) + if err := result.Error; err != nil { + log.WithContext(ctx).Errorf("failed to delete group from store: %s", result.Error) + return status.Errorf(status.Internal, "failed to delete group from store") } + + if result.RowsAffected == 0 { + return status.NewGroupNotFoundError(groupID) + } + + return nil +} + +// DeleteGroups deletes groups from the database. +func (s *SqlStore) DeleteGroups(ctx context.Context, strength LockingStrength, accountID string, groupIDs []string) error { + result := s.db.Clauses(clause.Locking{Strength: string(strength)}). + Delete(&nbgroup.Group{}, accountAndIDsQueryCondition, accountID, groupIDs) + if result.Error != nil { + log.WithContext(ctx).Errorf("failed to delete groups from store: %v", result.Error) + return status.Errorf(status.Internal, "failed to delete groups from store") + } + return nil } diff --git a/management/server/sql_store_test.go b/management/server/sql_store_test.go index 3f3b2a453d4..114da1ee6f6 100644 --- a/management/server/sql_store_test.go +++ b/management/server/sql_store_test.go @@ -14,11 +14,10 @@ import ( "time" "github.com/google/uuid" - "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/require" - nbdns "github.com/netbirdio/netbird/dns" nbgroup "github.com/netbirdio/netbird/management/server/group" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" route2 "github.com/netbirdio/netbird/route" @@ -1181,7 +1180,7 @@ func TestSqlite_CreateAndGetObjectInTransaction(t *testing.T) { t.Fatal("failed to save group") return err } - group, err = transaction.GetGroupByID(context.Background(), LockingStrengthUpdate, group.ID, group.AccountID) + group, err = transaction.GetGroupByID(context.Background(), LockingStrengthUpdate, group.AccountID, group.ID) if err != nil { t.Fatal("failed to get group") return err @@ -1201,7 +1200,7 @@ func TestSqlite_GetAccoundUsers(t *testing.T) { accountID := "bf1c8084-ba50-4ce7-9439-34653001fc3b" account, err := store.GetAccount(context.Background(), accountID) require.NoError(t, err) - users, err := store.GetAccountUsers(context.Background(), accountID) + users, err := store.GetAccountUsers(context.Background(), LockingStrengthShare, accountID) require.NoError(t, err) require.Len(t, users, len(account.Users)) } @@ -1260,9 +1259,9 @@ func TestSqlite_GetGroupByName(t *testing.T) { } accountID := "bf1c8084-ba50-4ce7-9439-34653001fc3b" - group, err := store.GetGroupByName(context.Background(), LockingStrengthShare, "All", accountID) + group, err := store.GetGroupByName(context.Background(), LockingStrengthShare, accountID, "All") require.NoError(t, err) - require.Equal(t, "All", group.Name) + require.True(t, group.IsGroupAll()) } func Test_DeleteSetupKeySuccessfully(t *testing.T) { @@ -1293,3 +1292,275 @@ func Test_DeleteSetupKeyFailsForNonExistingKey(t *testing.T) { err = store.DeleteSetupKey(context.Background(), LockingStrengthUpdate, accountID, nonExistingKeyID) require.Error(t, err) } + +func TestSqlStore_GetGroupsByIDs(t *testing.T) { + store, cleanup, err := NewTestStoreFromSQL(context.Background(), "testdata/extended-store.sql", t.TempDir()) + t.Cleanup(cleanup) + require.NoError(t, err) + + accountID := "bf1c8084-ba50-4ce7-9439-34653001fc3b" + + tests := []struct { + name string + groupIDs []string + expectedCount int + }{ + { + name: "retrieve existing groups by existing IDs", + groupIDs: []string{"cfefqs706sqkneg59g4g", "cfefqs706sqkneg59g3g"}, + expectedCount: 2, + }, + { + name: "empty group IDs list", + groupIDs: []string{}, + expectedCount: 0, + }, + { + name: "non-existing group IDs", + groupIDs: []string{"nonexistent1", "nonexistent2"}, + expectedCount: 0, + }, + { + name: "mixed existing and non-existing group IDs", + groupIDs: []string{"cfefqs706sqkneg59g4g", "nonexistent"}, + expectedCount: 1, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + groups, err := store.GetGroupsByIDs(context.Background(), LockingStrengthShare, accountID, tt.groupIDs) + require.NoError(t, err) + require.Len(t, groups, tt.expectedCount) + }) + } +} + +func TestSqlStore_SaveGroup(t *testing.T) { + store, cleanup, err := NewTestStoreFromSQL(context.Background(), "testdata/extended-store.sql", t.TempDir()) + t.Cleanup(cleanup) + require.NoError(t, err) + + accountID := "bf1c8084-ba50-4ce7-9439-34653001fc3b" + + group := &nbgroup.Group{ + ID: "group-id", + AccountID: accountID, + Issued: "api", + Peers: []string{"peer1", "peer2"}, + } + err = store.SaveGroup(context.Background(), LockingStrengthUpdate, group) + require.NoError(t, err) + + savedGroup, err := store.GetGroupByID(context.Background(), LockingStrengthShare, accountID, "group-id") + require.NoError(t, err) + require.Equal(t, savedGroup, group) +} + +func TestSqlStore_SaveGroups(t *testing.T) { + store, cleanup, err := NewTestStoreFromSQL(context.Background(), "testdata/extended-store.sql", t.TempDir()) + t.Cleanup(cleanup) + require.NoError(t, err) + + accountID := "bf1c8084-ba50-4ce7-9439-34653001fc3b" + + groups := []*nbgroup.Group{ + { + ID: "group-1", + AccountID: accountID, + Issued: "api", + Peers: []string{"peer1", "peer2"}, + }, + { + ID: "group-2", + AccountID: accountID, + Issued: "integration", + Peers: []string{"peer3", "peer4"}, + }, + } + err = store.SaveGroups(context.Background(), LockingStrengthUpdate, groups) + require.NoError(t, err) +} + +func TestSqlStore_DeleteGroup(t *testing.T) { + store, cleanup, err := NewTestStoreFromSQL(context.Background(), "testdata/extended-store.sql", t.TempDir()) + t.Cleanup(cleanup) + require.NoError(t, err) + + accountID := "bf1c8084-ba50-4ce7-9439-34653001fc3b" + + tests := []struct { + name string + groupID string + expectError bool + }{ + { + name: "delete existing group", + groupID: "cfefqs706sqkneg59g4g", + expectError: false, + }, + { + name: "delete non-existing group", + groupID: "non-existing-group-id", + expectError: true, + }, + { + name: "delete with empty group ID", + groupID: "", + expectError: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + err := store.DeleteGroup(context.Background(), LockingStrengthUpdate, accountID, tt.groupID) + if tt.expectError { + require.Error(t, err) + sErr, ok := status.FromError(err) + require.True(t, ok) + require.Equal(t, sErr.Type(), status.NotFound) + } else { + require.NoError(t, err) + + group, err := store.GetGroupByID(context.Background(), LockingStrengthShare, accountID, tt.groupID) + require.Error(t, err) + require.Nil(t, group) + } + }) + } +} + +func TestSqlStore_DeleteGroups(t *testing.T) { + store, cleanup, err := NewTestStoreFromSQL(context.Background(), "testdata/extended-store.sql", t.TempDir()) + t.Cleanup(cleanup) + require.NoError(t, err) + + accountID := "bf1c8084-ba50-4ce7-9439-34653001fc3b" + + tests := []struct { + name string + groupIDs []string + expectError bool + }{ + { + name: "delete multiple existing groups", + groupIDs: []string{"cfefqs706sqkneg59g4g", "cfefqs706sqkneg59g3g"}, + expectError: false, + }, + { + name: "delete non-existing groups", + groupIDs: []string{"non-existing-id-1", "non-existing-id-2"}, + expectError: false, + }, + { + name: "delete with empty group IDs list", + groupIDs: []string{}, + expectError: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + err := store.DeleteGroups(context.Background(), LockingStrengthUpdate, accountID, tt.groupIDs) + if tt.expectError { + require.Error(t, err) + } else { + require.NoError(t, err) + + for _, groupID := range tt.groupIDs { + group, err := store.GetGroupByID(context.Background(), LockingStrengthShare, accountID, groupID) + require.Error(t, err) + require.Nil(t, group) + } + } + }) + } +} + +func TestSqlStore_GetPeerByID(t *testing.T) { + store, cleanup, err := NewTestStoreFromSQL(context.Background(), "testdata/store_policy_migrate.sql", t.TempDir()) + t.Cleanup(cleanup) + require.NoError(t, err) + + accountID := "bf1c8084-ba50-4ce7-9439-34653001fc3b" + tests := []struct { + name string + peerID string + expectError bool + }{ + { + name: "retrieve existing peer", + peerID: "cfefqs706sqkneg59g4g", + expectError: false, + }, + { + name: "retrieve non-existing peer", + peerID: "non-existing", + expectError: true, + }, + { + name: "retrieve with empty peer ID", + peerID: "", + expectError: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + peer, err := store.GetPeerByID(context.Background(), LockingStrengthShare, accountID, tt.peerID) + if tt.expectError { + require.Error(t, err) + sErr, ok := status.FromError(err) + require.True(t, ok) + require.Equal(t, sErr.Type(), status.NotFound) + require.Nil(t, peer) + } else { + require.NoError(t, err) + require.NotNil(t, peer) + require.Equal(t, tt.peerID, peer.ID) + } + }) + } +} + +func TestSqlStore_GetPeersByIDs(t *testing.T) { + store, cleanup, err := NewTestStoreFromSQL(context.Background(), "testdata/store_policy_migrate.sql", t.TempDir()) + t.Cleanup(cleanup) + require.NoError(t, err) + + accountID := "bf1c8084-ba50-4ce7-9439-34653001fc3b" + tests := []struct { + name string + peerIDs []string + expectedCount int + }{ + { + name: "retrieve existing peers by existing IDs", + peerIDs: []string{"cfefqs706sqkneg59g4g", "cfeg6sf06sqkneg59g50"}, + expectedCount: 2, + }, + { + name: "empty peer IDs list", + peerIDs: []string{}, + expectedCount: 0, + }, + { + name: "non-existing peer IDs", + peerIDs: []string{"nonexistent1", "nonexistent2"}, + expectedCount: 0, + }, + { + name: "mixed existing and non-existing peer IDs", + peerIDs: []string{"cfeg6sf06sqkneg59g50", "nonexistent"}, + expectedCount: 1, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + peers, err := store.GetPeersByIDs(context.Background(), LockingStrengthShare, accountID, tt.peerIDs) + require.NoError(t, err) + require.Len(t, peers, tt.expectedCount) + }) + } +} diff --git a/management/server/status/error.go b/management/server/status/error.go index f1f3f16e675..8b6d0077b4b 100644 --- a/management/server/status/error.go +++ b/management/server/status/error.go @@ -3,7 +3,6 @@ package status import ( "errors" "fmt" - "time" ) const ( @@ -126,11 +125,6 @@ func NewAdminPermissionError() error { return Errorf(PermissionDenied, "admin role required to perform this action") } -// NewStoreContextCanceledError creates a new Error with Internal type for a canceled store context -func NewStoreContextCanceledError(duration time.Duration) error { - return Errorf(Internal, "store access: context canceled after %v", duration) -} - // NewInvalidKeyIDError creates a new Error with InvalidArgument type for an issue getting a setup key func NewInvalidKeyIDError() error { return Errorf(InvalidArgument, "invalid key ID") @@ -140,3 +134,8 @@ func NewInvalidKeyIDError() error { func NewGetAccountError(err error) error { return Errorf(Internal, "error getting account: %s", err) } + +// NewGroupNotFoundError creates a new Error with NotFound type for a missing group +func NewGroupNotFoundError(groupID string) error { + return Errorf(NotFound, "group: %s not found", groupID) +} diff --git a/management/server/store.go b/management/server/store.go index c7c066629e3..71b0d457b4c 100644 --- a/management/server/store.go +++ b/management/server/store.go @@ -62,7 +62,7 @@ type Store interface { GetUserByTokenID(ctx context.Context, tokenID string) (*User, error) GetUserByUserID(ctx context.Context, lockStrength LockingStrength, userID string) (*User, error) - GetAccountUsers(ctx context.Context, accountID string) ([]*User, error) + GetAccountUsers(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*User, error) SaveUsers(accountID string, users map[string]*User) error SaveUser(ctx context.Context, lockStrength LockingStrength, user *User) error SaveUserLastLogin(ctx context.Context, accountID, userID string, lastLogin time.Time) error @@ -76,6 +76,8 @@ type Store interface { GetGroupsByIDs(ctx context.Context, lockStrength LockingStrength, accountID string, groupIDs []string) (map[string]*nbgroup.Group, error) SaveGroups(ctx context.Context, lockStrength LockingStrength, groups []*nbgroup.Group) error SaveGroup(ctx context.Context, lockStrength LockingStrength, group *nbgroup.Group) error + DeleteGroup(ctx context.Context, lockStrength LockingStrength, accountID, groupID string) error + DeleteGroups(ctx context.Context, strength LockingStrength, accountID string, groupIDs []string) error GetAccountPolicies(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*Policy, error) GetPolicyByID(ctx context.Context, lockStrength LockingStrength, policyID string, accountID string) (*Policy, error) @@ -90,6 +92,8 @@ type Store interface { AddPeerToAccount(ctx context.Context, peer *nbpeer.Peer) error GetPeerByPeerPubKey(ctx context.Context, lockStrength LockingStrength, peerKey string) (*nbpeer.Peer, error) GetUserPeers(ctx context.Context, lockStrength LockingStrength, accountID, userID string) ([]*nbpeer.Peer, error) + GetPeerByID(ctx context.Context, lockStrength LockingStrength, accountID string, peerID string) (*nbpeer.Peer, error) + GetPeersByIDs(ctx context.Context, lockStrength LockingStrength, accountID string, peerIDs []string) (map[string]*nbpeer.Peer, error) SavePeer(ctx context.Context, accountID string, peer *nbpeer.Peer) error SavePeerStatus(accountID, peerID string, status nbpeer.PeerStatus) error SavePeerLocation(accountID string, peer *nbpeer.Peer) error @@ -108,7 +112,7 @@ type Store interface { GetNameServerGroupByID(ctx context.Context, lockStrength LockingStrength, nameServerGroupID string, accountID string) (*dns.NameServerGroup, error) GetTakenIPs(ctx context.Context, lockStrength LockingStrength, accountId string) ([]net.IP, error) - IncrementNetworkSerial(ctx context.Context, accountId string) error + IncrementNetworkSerial(ctx context.Context, lockStrength LockingStrength, accountId string) error GetAccountNetwork(ctx context.Context, lockStrength LockingStrength, accountId string) (*Network, error) GetInstallationID() string diff --git a/management/server/user.go b/management/server/user.go index 5e0d9d034e3..74062112af6 100644 --- a/management/server/user.go +++ b/management/server/user.go @@ -494,7 +494,7 @@ func (am *DefaultAccountManager) deleteRegularUser(ctx context.Context, account am.StoreEvent(ctx, initiatorUserID, targetUserID, account.Id, activity.UserDeleted, meta) if updateAccountPeers { - am.updateAccountPeers(ctx, account) + am.updateAccountPeers(ctx, account.Id) } return nil @@ -835,7 +835,7 @@ func (am *DefaultAccountManager) SaveOrAddUsers(ctx context.Context, accountID, } if account.Settings.GroupsPropagationEnabled && areUsersLinkedToPeers(account, userIDs) { - am.updateAccountPeers(ctx, account) + am.updateAccountPeers(ctx, account.Id) } for _, storeEvent := range eventsToStore { @@ -1132,7 +1132,7 @@ func (am *DefaultAccountManager) expireAndUpdatePeers(ctx context.Context, accou if len(peerIDs) != 0 { // this will trigger peer disconnect from the management service am.peersUpdateManager.CloseChannels(ctx, peerIDs) - am.updateAccountPeers(ctx, account) + am.updateAccountPeers(ctx, account.Id) } return nil } @@ -1240,7 +1240,7 @@ func (am *DefaultAccountManager) DeleteRegularUsers(ctx context.Context, account } if updateAccountPeers { - am.updateAccountPeers(ctx, account) + am.updateAccountPeers(ctx, accountID) } for targetUserID, meta := range deletedUsersMeta {