diff --git a/firewall/request_logger.go b/firewall/request_logger.go index dad96339b..7c7abb043 100644 --- a/firewall/request_logger.go +++ b/firewall/request_logger.go @@ -194,6 +194,7 @@ func (r *RequestLogger) addNewAction(ri *RequestInfo, } action := &firewalldb.Action{ + SessionID: sessionID, RPCMethod: ri.URI, AttemptedAt: time.Now(), State: firewalldb.ActionStateInit, @@ -222,7 +223,7 @@ func (r *RequestLogger) addNewAction(ri *RequestInfo, } } - id, err := r.actionsDB.AddAction(sessionID, action) + id, err := r.actionsDB.AddAction(action) if err != nil { return err } diff --git a/firewalldb/actions.go b/firewalldb/actions.go index 72e141556..57e43e0d6 100644 --- a/firewalldb/actions.go +++ b/firewalldb/actions.go @@ -1,55 +1,10 @@ package firewalldb import ( - "bytes" "context" - "encoding/binary" - "errors" - "fmt" - "io" "time" "github.com/lightninglabs/lightning-terminal/session" - "github.com/lightningnetwork/lnd/tlv" - "go.etcd.io/bbolt" -) - -const ( - typeActorName tlv.Type = 1 - typeFeature tlv.Type = 2 - typeTrigger tlv.Type = 3 - typeIntent tlv.Type = 4 - typeStructuredJsonData tlv.Type = 5 - typeRPCMethod tlv.Type = 6 - typeRPCParamsJson tlv.Type = 7 - typeAttemptedAt tlv.Type = 8 - typeState tlv.Type = 9 - typeErrorReason tlv.Type = 10 - - typeLocatorSessionID tlv.Type = 1 - typeLocatorActionID tlv.Type = 2 -) - -/* - The Actions are stored in the following structure in the KV db: - - actions-bucket -> actions -> -> -> serialised action - - -> actions-index -> -> {sessionID:action-index} -*/ - -var ( - // actionsBucketKey is the key that will be used for the main Actions - // bucket. - actionsBucketKey = []byte("actions-bucket") - - // actionsKey is the key used for the sub-bucket containing the - // session actions. - actionsKey = []byte("actions") - - // actionsIndex is the key used for the sub-bucket containing a map - // from monotonically increasing IDs to action locators. - actionsIndex = []byte("actions-index") ) // ActionState represents the state of an action. @@ -116,154 +71,6 @@ type Action struct { ErrorReason string } -// AddAction serialises and adds an Action to the DB under the given sessionID. -func (db *BoltDB) AddAction(sessionID session.ID, action *Action) (uint64, - error) { - - var buf bytes.Buffer - if err := SerializeAction(&buf, action); err != nil { - return 0, err - } - - var id uint64 - err := db.DB.Update(func(tx *bbolt.Tx) error { - mainActionsBucket, err := getBucket(tx, actionsBucketKey) - if err != nil { - return err - } - - actionsBucket := mainActionsBucket.Bucket(actionsKey) - if actionsBucket == nil { - return ErrNoSuchKeyFound - } - - sessBucket, err := actionsBucket.CreateBucketIfNotExists( - sessionID[:], - ) - if err != nil { - return err - } - - nextActionIndex, err := sessBucket.NextSequence() - if err != nil { - return err - } - id = nextActionIndex - - var actionIndex [8]byte - byteOrder.PutUint64(actionIndex[:], nextActionIndex) - err = sessBucket.Put(actionIndex[:], buf.Bytes()) - if err != nil { - return err - } - - actionsIndexBucket := mainActionsBucket.Bucket(actionsIndex) - if actionsIndexBucket == nil { - return ErrNoSuchKeyFound - } - - nextSeq, err := actionsIndexBucket.NextSequence() - if err != nil { - return err - } - - locator := ActionLocator{ - SessionID: sessionID, - ActionID: nextActionIndex, - } - - var buf bytes.Buffer - err = serializeActionLocator(&buf, &locator) - if err != nil { - return err - } - - var seqNoBytes [8]byte - byteOrder.PutUint64(seqNoBytes[:], nextSeq) - return actionsIndexBucket.Put(seqNoBytes[:], buf.Bytes()) - }) - if err != nil { - return 0, err - } - - return id, nil -} - -func putAction(tx *bbolt.Tx, al *ActionLocator, a *Action) error { - var buf bytes.Buffer - if err := SerializeAction(&buf, a); err != nil { - return err - } - - mainActionsBucket, err := getBucket(tx, actionsBucketKey) - if err != nil { - return err - } - - actionsBucket := mainActionsBucket.Bucket(actionsKey) - if actionsBucket == nil { - return ErrNoSuchKeyFound - } - - sessBucket := actionsBucket.Bucket(al.SessionID[:]) - if sessBucket == nil { - return fmt.Errorf("session bucket for session ID %x does not "+ - "exist", al.SessionID) - } - - var id [8]byte - binary.BigEndian.PutUint64(id[:], al.ActionID) - - return sessBucket.Put(id[:], buf.Bytes()) -} - -func getAction(actionsBkt *bbolt.Bucket, al *ActionLocator) (*Action, error) { - sessBucket := actionsBkt.Bucket(al.SessionID[:]) - if sessBucket == nil { - return nil, fmt.Errorf("session bucket for session ID "+ - "%x does not exist", al.SessionID) - } - - var id [8]byte - binary.BigEndian.PutUint64(id[:], al.ActionID) - - actionBytes := sessBucket.Get(id[:]) - return DeserializeAction(bytes.NewReader(actionBytes), al.SessionID) -} - -// SetActionState finds the action specified by the ActionLocator and sets its -// state to the given state. -func (db *BoltDB) SetActionState(al *ActionLocator, state ActionState, - errorReason string) error { - - if errorReason != "" && state != ActionStateError { - return fmt.Errorf("error reason should only be set for " + - "ActionStateError") - } - - return db.DB.Update(func(tx *bbolt.Tx) error { - mainActionsBucket, err := getBucket(tx, actionsBucketKey) - if err != nil { - return err - } - - actionsBucket := mainActionsBucket.Bucket(actionsKey) - if actionsBucket == nil { - return ErrNoSuchKeyFound - } - - action, err := getAction(actionsBucket, al) - if err != nil { - return err - } - - action.State = state - action.ErrorReason = errorReason - - return putAction(tx, al, action) - }) -} - // ListActionsQuery can be used to tweak the query to ListActions and // ListSessionActions. type ListActionsQuery struct { @@ -284,277 +91,10 @@ type ListActionsQuery struct { CountAll bool } -// ListActionsFilterFn defines a function that can be used to determine if an -// action should be included in a set of results or not. The reversed parameter -// indicates if the actions are being traversed in reverse order or not. -// The first return boolean indicates if the action should be included or not -// and the second one indicates if the iteration should be stopped or not. -type ListActionsFilterFn func(a *Action, reversed bool) (bool, bool) - -// ListActions returns a list of Actions that pass the filterFn requirements. -// The indexOffset and maxNum params can be used to control the number of -// actions returned. The return values are the list of actions, the last index -// and the total count (iff query.CountTotal is set). -func (db *BoltDB) ListActions(filterFn ListActionsFilterFn, - query *ListActionsQuery) ([]*Action, uint64, uint64, error) { - - var ( - actions []*Action - totalCount uint64 - lastIndex uint64 - ) - err := db.View(func(tx *bbolt.Tx) error { - mainActionsBucket, err := getBucket(tx, actionsBucketKey) - if err != nil { - return err - } - - actionsBucket := mainActionsBucket.Bucket(actionsKey) - if actionsBucket == nil { - return ErrNoSuchKeyFound - } - - actionsIndexBucket := mainActionsBucket.Bucket(actionsIndex) - if actionsIndexBucket == nil { - return ErrNoSuchKeyFound - } - - readAction := func(index, locatorBytes []byte) (*Action, - error) { - - locator, err := deserializeActionLocator( - bytes.NewReader(locatorBytes), - ) - if err != nil { - return nil, err - } - - return getAction(actionsBucket, locator) - } - - actions, lastIndex, totalCount, err = paginateActions( - query, actionsIndexBucket.Cursor(), readAction, - filterFn, - ) - return err - }) - if err != nil { - return nil, 0, 0, err - } - - return actions, lastIndex, totalCount, nil -} - -// ListSessionActions returns a list of the given session's Actions that pass -// the filterFn requirements. -func (db *BoltDB) ListSessionActions(sessionID session.ID, - filterFn ListActionsFilterFn, query *ListActionsQuery) ([]*Action, - uint64, uint64, error) { - - var ( - actions []*Action - totalCount uint64 - lastIndex uint64 - ) - err := db.View(func(tx *bbolt.Tx) error { - mainActionsBucket, err := getBucket(tx, actionsBucketKey) - if err != nil { - return err - } - - actionsBucket := mainActionsBucket.Bucket(actionsKey) - if actionsBucket == nil { - return ErrNoSuchKeyFound - } - - sessionsBucket := actionsBucket.Bucket(sessionID[:]) - if sessionsBucket == nil { - return nil - } - - readAction := func(_, v []byte) (*Action, error) { - return DeserializeAction(bytes.NewReader(v), sessionID) - } - - actions, lastIndex, totalCount, err = paginateActions( - query, sessionsBucket.Cursor(), readAction, filterFn, - ) - - return err - }) - if err != nil { - return nil, 0, 0, err - } - - return actions, lastIndex, totalCount, nil -} - -// ListGroupActions returns a list of the given session group's Actions that -// pass the filterFn requirements. -// -// TODO: update to allow for pagination. -func (db *BoltDB) ListGroupActions(ctx context.Context, groupID session.ID, - filterFn ListActionsFilterFn) ([]*Action, error) { - - if filterFn == nil { - filterFn = func(a *Action, reversed bool) (bool, bool) { - return true, true - } - } - - sessionIDs, err := db.sessionIDIndex.GetSessionIDs(ctx, groupID) - if err != nil { - return nil, err - } - - var ( - actions []*Action - errDone = errors.New("done iterating") - ) - err = db.View(func(tx *bbolt.Tx) error { - mainActionsBucket, err := getBucket(tx, actionsBucketKey) - if err != nil { - return err - } - - actionsBucket := mainActionsBucket.Bucket(actionsKey) - if actionsBucket == nil { - return ErrNoSuchKeyFound - } - - // Iterate over each session ID in this group. - for _, sessionID := range sessionIDs { - sessionsBucket := actionsBucket.Bucket(sessionID[:]) - if sessionsBucket == nil { - return nil - } - - err = sessionsBucket.ForEach(func(_, v []byte) error { - action, err := DeserializeAction( - bytes.NewReader(v), sessionID, - ) - if err != nil { - return err - } - - include, cont := filterFn(action, false) - if include { - actions = append(actions, action) - } - - if !cont { - return errDone - } - - return nil - }) - if err != nil { - return err - } - } - - return nil - }) - if err != nil && !errors.Is(err, errDone) { - return nil, err - } - - return actions, nil -} - -// SerializeAction binary serializes the given action to the writer using the -// tlv format. -func SerializeAction(w io.Writer, action *Action) error { - if action == nil { - return fmt.Errorf("action cannot be nil") - } - - var ( - actor = []byte(action.ActorName) - feature = []byte(action.FeatureName) - trigger = []byte(action.Trigger) - intent = []byte(action.Intent) - data = []byte(action.StructuredJsonData) - rpcMethod = []byte(action.RPCMethod) - params = action.RPCParamsJson - attemptedAt = uint64(action.AttemptedAt.Unix()) - state = uint8(action.State) - errorReason = []byte(action.ErrorReason) - ) - - tlvRecords := []tlv.Record{ - tlv.MakePrimitiveRecord(typeActorName, &actor), - tlv.MakePrimitiveRecord(typeFeature, &feature), - tlv.MakePrimitiveRecord(typeTrigger, &trigger), - tlv.MakePrimitiveRecord(typeIntent, &intent), - tlv.MakePrimitiveRecord(typeStructuredJsonData, &data), - tlv.MakePrimitiveRecord(typeRPCMethod, &rpcMethod), - tlv.MakePrimitiveRecord(typeRPCParamsJson, ¶ms), - tlv.MakePrimitiveRecord(typeAttemptedAt, &attemptedAt), - tlv.MakePrimitiveRecord(typeState, &state), - tlv.MakePrimitiveRecord(typeErrorReason, &errorReason), - } - - tlvStream, err := tlv.NewStream(tlvRecords...) - if err != nil { - return err - } - - return tlvStream.Encode(w) -} - -// DeserializeAction deserializes an action from the given reader, expecting -// the data to be encoded in the tlv format. -func DeserializeAction(r io.Reader, sessionID session.ID) (*Action, error) { - var ( - action = Action{} - actor, featureName []byte - trigger, intent, data []byte - rpcMethod, params []byte - attemptedAt uint64 - state uint8 - errorReason []byte - ) - tlvStream, err := tlv.NewStream( - tlv.MakePrimitiveRecord(typeActorName, &actor), - tlv.MakePrimitiveRecord(typeFeature, &featureName), - tlv.MakePrimitiveRecord(typeTrigger, &trigger), - tlv.MakePrimitiveRecord(typeIntent, &intent), - tlv.MakePrimitiveRecord(typeStructuredJsonData, &data), - tlv.MakePrimitiveRecord(typeRPCMethod, &rpcMethod), - tlv.MakePrimitiveRecord(typeRPCParamsJson, ¶ms), - tlv.MakePrimitiveRecord(typeAttemptedAt, &attemptedAt), - tlv.MakePrimitiveRecord(typeState, &state), - tlv.MakePrimitiveRecord(typeErrorReason, &errorReason), - ) - if err != nil { - return nil, err - } - - _, err = tlvStream.DecodeWithParsedTypes(r) - if err != nil { - return nil, err - } - - action.SessionID = sessionID - action.ActorName = string(actor) - action.FeatureName = string(featureName) - action.Trigger = string(trigger) - action.Intent = string(intent) - action.StructuredJsonData = string(data) - action.RPCMethod = string(rpcMethod) - action.RPCParamsJson = params - action.AttemptedAt = time.Unix(int64(attemptedAt), 0) - action.State = ActionState(state) - action.ErrorReason = string(errorReason) - - return &action, nil -} - // ActionsWriteDB is an abstraction over the Actions DB that will allow a // caller to add new actions as well as change the values of an existing action. type ActionsWriteDB interface { - AddAction(sessionID session.ID, action *Action) (uint64, error) + AddAction(action *Action) (uint64, error) SetActionState(al *ActionLocator, state ActionState, errReason string) error } @@ -569,10 +109,10 @@ type RuleAction struct { PerformedAt time.Time } -// ActionsDB represents a DB backend that contains Action entries that can +// ActionsListDB represents a DB backend that contains Action entries that can // be queried. It allows us to abstract away the details of the data storage // method. -type ActionsDB interface { +type ActionsListDB interface { // ListActions returns a list of past Action items. ListActions(ctx context.Context) ([]*RuleAction, error) } @@ -580,8 +120,8 @@ type ActionsDB interface { // ActionsReadDB is an abstraction gives a caller access to either a group // specific or group and feature specific rules.ActionDB. type ActionsReadDB interface { - GroupActionsDB() ActionsDB - GroupFeatureActionsDB() ActionsDB + GroupActionsDB() ActionsListDB + GroupFeatureActionsDB() ActionsListDB } // ActionReadDBGetter represents a function that can be used to construct @@ -610,25 +150,25 @@ type allActionsReadDB struct { var _ ActionsReadDB = (*allActionsReadDB)(nil) -// GroupActionsDB returns a rules.ActionsDB that will give the caller access +// GroupActionsDB returns a rules.ActionsListDB that will give the caller access // to all of a groups Actions. -func (a *allActionsReadDB) GroupActionsDB() ActionsDB { +func (a *allActionsReadDB) GroupActionsDB() ActionsListDB { return &groupActionsReadDB{a} } -// GroupFeatureActionsDB returns a rules.ActionsDB that will give the caller +// GroupFeatureActionsDB returns a rules.ActionsListDB that will give the caller // access to only a specific features Actions in a specific group. -func (a *allActionsReadDB) GroupFeatureActionsDB() ActionsDB { +func (a *allActionsReadDB) GroupFeatureActionsDB() ActionsListDB { return &groupFeatureActionsReadDB{a} } -// groupActionsReadDB is an implementation of the rules.ActionsDB that will +// groupActionsReadDB is an implementation of the rules.ActionsListDB that will // provide read access to all the Actions of a particular group. type groupActionsReadDB struct { *allActionsReadDB } -var _ ActionsDB = (*groupActionsReadDB)(nil) +var _ ActionsListDB = (*groupActionsReadDB)(nil) // ListActions will return all the Actions for a particular group. func (s *groupActionsReadDB) ListActions(ctx context.Context) ([]*RuleAction, @@ -651,14 +191,14 @@ func (s *groupActionsReadDB) ListActions(ctx context.Context) ([]*RuleAction, return actions, nil } -// groupFeatureActionsReadDB is an implementation of the rules.ActionsDB that +// groupFeatureActionsReadDB is an implementation of the rules.ActionsListDB that // will provide read access to all the Actions of a feature within a particular // group. type groupFeatureActionsReadDB struct { *allActionsReadDB } -var _ ActionsDB = (*groupFeatureActionsReadDB)(nil) +var _ ActionsListDB = (*groupFeatureActionsReadDB)(nil) // ListActions will return all the Actions for a particular group that were // executed by a particular feature. @@ -695,59 +235,3 @@ type ActionLocator struct { SessionID session.ID ActionID uint64 } - -// serializeActionLocator binary serializes the given ActionLocator to the -// writer using the tlv format. -func serializeActionLocator(w io.Writer, al *ActionLocator) error { - if al == nil { - return fmt.Errorf("action locator cannot be nil") - } - - var ( - sessionID = al.SessionID[:] - actionID = al.ActionID - ) - - tlvRecords := []tlv.Record{ - tlv.MakePrimitiveRecord(typeLocatorSessionID, &sessionID), - tlv.MakePrimitiveRecord(typeLocatorActionID, &actionID), - } - - tlvStream, err := tlv.NewStream(tlvRecords...) - if err != nil { - return err - } - - return tlvStream.Encode(w) -} - -// deserializeActionLocator deserializes an ActionLocator from the given reader, -// expecting the data to be encoded in the tlv format. -func deserializeActionLocator(r io.Reader) (*ActionLocator, error) { - var ( - sessionID []byte - actionID uint64 - ) - tlvStream, err := tlv.NewStream( - tlv.MakePrimitiveRecord(typeLocatorSessionID, &sessionID), - tlv.MakePrimitiveRecord(typeLocatorActionID, &actionID), - ) - if err != nil { - return nil, err - } - - _, err = tlvStream.DecodeWithParsedTypes(r) - if err != nil { - return nil, err - } - - id, err := session.IDFromBytes(sessionID) - if err != nil { - return nil, err - } - - return &ActionLocator{ - SessionID: id, - ActionID: actionID, - }, nil -} diff --git a/firewalldb/actions_kvdb.go b/firewalldb/actions_kvdb.go new file mode 100644 index 000000000..d92f95543 --- /dev/null +++ b/firewalldb/actions_kvdb.go @@ -0,0 +1,522 @@ +package firewalldb + +import ( + "bytes" + "context" + "encoding/binary" + "errors" + "fmt" + "io" + "time" + + "github.com/lightninglabs/lightning-terminal/session" + "github.com/lightningnetwork/lnd/tlv" + "go.etcd.io/bbolt" +) + +const ( + typeActorName tlv.Type = 1 + typeFeature tlv.Type = 2 + typeTrigger tlv.Type = 3 + typeIntent tlv.Type = 4 + typeStructuredJsonData tlv.Type = 5 + typeRPCMethod tlv.Type = 6 + typeRPCParamsJson tlv.Type = 7 + typeAttemptedAt tlv.Type = 8 + typeState tlv.Type = 9 + typeErrorReason tlv.Type = 10 + + typeLocatorSessionID tlv.Type = 1 + typeLocatorActionID tlv.Type = 2 +) + +/* + The Actions are stored in the following structure in the KV db: + + actions-bucket -> actions -> -> -> serialised action + + -> actions-index -> -> {sessionID:action-index} +*/ + +var ( + // actionsBucketKey is the key that will be used for the main Actions + // bucket. + actionsBucketKey = []byte("actions-bucket") + + // actionsKey is the key used for the sub-bucket containing the + // session actions. + actionsKey = []byte("actions") + + // actionsIndex is the key used for the sub-bucket containing a map + // from monotonically increasing IDs to action locators. + actionsIndex = []byte("actions-index") +) + +// AddAction serialises and adds an Action to the DB under the given sessionID. +func (db *BoltDB) AddAction(action *Action) (uint64, error) { + var buf bytes.Buffer + if err := SerializeAction(&buf, action); err != nil { + return 0, err + } + + var id uint64 + err := db.DB.Update(func(tx *bbolt.Tx) error { + mainActionsBucket, err := getBucket(tx, actionsBucketKey) + if err != nil { + return err + } + + actionsBucket := mainActionsBucket.Bucket(actionsKey) + if actionsBucket == nil { + return ErrNoSuchKeyFound + } + + sessBucket, err := actionsBucket.CreateBucketIfNotExists( + action.SessionID[:], + ) + if err != nil { + return err + } + + nextActionIndex, err := sessBucket.NextSequence() + if err != nil { + return err + } + id = nextActionIndex + + var actionIndex [8]byte + byteOrder.PutUint64(actionIndex[:], nextActionIndex) + err = sessBucket.Put(actionIndex[:], buf.Bytes()) + if err != nil { + return err + } + + actionsIndexBucket := mainActionsBucket.Bucket(actionsIndex) + if actionsIndexBucket == nil { + return ErrNoSuchKeyFound + } + + nextSeq, err := actionsIndexBucket.NextSequence() + if err != nil { + return err + } + + locator := ActionLocator{ + SessionID: action.SessionID, + ActionID: nextActionIndex, + } + + var buf bytes.Buffer + err = serializeActionLocator(&buf, &locator) + if err != nil { + return err + } + + var seqNoBytes [8]byte + byteOrder.PutUint64(seqNoBytes[:], nextSeq) + return actionsIndexBucket.Put(seqNoBytes[:], buf.Bytes()) + }) + if err != nil { + return 0, err + } + + return id, nil +} + +func putAction(tx *bbolt.Tx, al *ActionLocator, a *Action) error { + var buf bytes.Buffer + if err := SerializeAction(&buf, a); err != nil { + return err + } + + mainActionsBucket, err := getBucket(tx, actionsBucketKey) + if err != nil { + return err + } + + actionsBucket := mainActionsBucket.Bucket(actionsKey) + if actionsBucket == nil { + return ErrNoSuchKeyFound + } + + sessBucket := actionsBucket.Bucket(al.SessionID[:]) + if sessBucket == nil { + return fmt.Errorf("session bucket for session ID %x does not "+ + "exist", al.SessionID) + } + + var id [8]byte + binary.BigEndian.PutUint64(id[:], al.ActionID) + + return sessBucket.Put(id[:], buf.Bytes()) +} + +func getAction(actionsBkt *bbolt.Bucket, al *ActionLocator) (*Action, error) { + sessBucket := actionsBkt.Bucket(al.SessionID[:]) + if sessBucket == nil { + return nil, fmt.Errorf("session bucket for session ID "+ + "%x does not exist", al.SessionID) + } + + var id [8]byte + binary.BigEndian.PutUint64(id[:], al.ActionID) + + actionBytes := sessBucket.Get(id[:]) + return DeserializeAction(bytes.NewReader(actionBytes), al.SessionID) +} + +// SetActionState finds the action specified by the ActionLocator and sets its +// state to the given state. +func (db *BoltDB) SetActionState(al *ActionLocator, state ActionState, + errorReason string) error { + + if errorReason != "" && state != ActionStateError { + return fmt.Errorf("error reason should only be set for " + + "ActionStateError") + } + + return db.DB.Update(func(tx *bbolt.Tx) error { + mainActionsBucket, err := getBucket(tx, actionsBucketKey) + if err != nil { + return err + } + + actionsBucket := mainActionsBucket.Bucket(actionsKey) + if actionsBucket == nil { + return ErrNoSuchKeyFound + } + + action, err := getAction(actionsBucket, al) + if err != nil { + return err + } + + action.State = state + action.ErrorReason = errorReason + + return putAction(tx, al, action) + }) +} + +// ListActionsFilterFn defines a function that can be used to determine if an +// action should be included in a set of results or not. The reversed parameter +// indicates if the actions are being traversed in reverse order or not. +// The first return boolean indicates if the action should be included or not +// and the second one indicates if the iteration should be stopped or not. +type ListActionsFilterFn func(a *Action, reversed bool) (bool, bool) + +// ListActions returns a list of Actions that pass the filterFn requirements. +// The indexOffset and maxNum params can be used to control the number of +// actions returned. The return values are the list of actions, the last index +// and the total count (iff query.CountTotal is set). +func (db *BoltDB) ListActions(filterFn ListActionsFilterFn, + query *ListActionsQuery) ([]*Action, uint64, uint64, error) { + + var ( + actions []*Action + totalCount uint64 + lastIndex uint64 + ) + err := db.View(func(tx *bbolt.Tx) error { + mainActionsBucket, err := getBucket(tx, actionsBucketKey) + if err != nil { + return err + } + + actionsBucket := mainActionsBucket.Bucket(actionsKey) + if actionsBucket == nil { + return ErrNoSuchKeyFound + } + + actionsIndexBucket := mainActionsBucket.Bucket(actionsIndex) + if actionsIndexBucket == nil { + return ErrNoSuchKeyFound + } + + readAction := func(index, locatorBytes []byte) (*Action, + error) { + + locator, err := deserializeActionLocator( + bytes.NewReader(locatorBytes), + ) + if err != nil { + return nil, err + } + + return getAction(actionsBucket, locator) + } + + actions, lastIndex, totalCount, err = paginateActions( + query, actionsIndexBucket.Cursor(), readAction, + filterFn, + ) + return err + }) + if err != nil { + return nil, 0, 0, err + } + + return actions, lastIndex, totalCount, nil +} + +// ListSessionActions returns a list of the given session's Actions that pass +// the filterFn requirements. +func (db *BoltDB) ListSessionActions(sessionID session.ID, + filterFn ListActionsFilterFn, query *ListActionsQuery) ([]*Action, + uint64, uint64, error) { + + var ( + actions []*Action + totalCount uint64 + lastIndex uint64 + ) + err := db.View(func(tx *bbolt.Tx) error { + mainActionsBucket, err := getBucket(tx, actionsBucketKey) + if err != nil { + return err + } + + actionsBucket := mainActionsBucket.Bucket(actionsKey) + if actionsBucket == nil { + return ErrNoSuchKeyFound + } + + sessionsBucket := actionsBucket.Bucket(sessionID[:]) + if sessionsBucket == nil { + return nil + } + + readAction := func(_, v []byte) (*Action, error) { + return DeserializeAction(bytes.NewReader(v), sessionID) + } + + actions, lastIndex, totalCount, err = paginateActions( + query, sessionsBucket.Cursor(), readAction, filterFn, + ) + + return err + }) + if err != nil { + return nil, 0, 0, err + } + + return actions, lastIndex, totalCount, nil +} + +// ListGroupActions returns a list of the given session group's Actions that +// pass the filterFn requirements. +// +// TODO: update to allow for pagination. +func (db *BoltDB) ListGroupActions(ctx context.Context, groupID session.ID, + filterFn ListActionsFilterFn) ([]*Action, error) { + + if filterFn == nil { + filterFn = func(a *Action, reversed bool) (bool, bool) { + return true, true + } + } + + sessionIDs, err := db.sessionIDIndex.GetSessionIDs(ctx, groupID) + if err != nil { + return nil, err + } + + var ( + actions []*Action + errDone = errors.New("done iterating") + ) + err = db.View(func(tx *bbolt.Tx) error { + mainActionsBucket, err := getBucket(tx, actionsBucketKey) + if err != nil { + return err + } + + actionsBucket := mainActionsBucket.Bucket(actionsKey) + if actionsBucket == nil { + return ErrNoSuchKeyFound + } + + // Iterate over each session ID in this group. + for _, sessionID := range sessionIDs { + sessionsBucket := actionsBucket.Bucket(sessionID[:]) + if sessionsBucket == nil { + return nil + } + + err = sessionsBucket.ForEach(func(_, v []byte) error { + action, err := DeserializeAction( + bytes.NewReader(v), sessionID, + ) + if err != nil { + return err + } + + include, cont := filterFn(action, false) + if include { + actions = append(actions, action) + } + + if !cont { + return errDone + } + + return nil + }) + if err != nil { + return err + } + } + + return nil + }) + if err != nil && !errors.Is(err, errDone) { + return nil, err + } + + return actions, nil +} + +// SerializeAction binary serializes the given action to the writer using the +// tlv format. +func SerializeAction(w io.Writer, action *Action) error { + if action == nil { + return fmt.Errorf("action cannot be nil") + } + + var ( + actor = []byte(action.ActorName) + feature = []byte(action.FeatureName) + trigger = []byte(action.Trigger) + intent = []byte(action.Intent) + data = []byte(action.StructuredJsonData) + rpcMethod = []byte(action.RPCMethod) + params = action.RPCParamsJson + attemptedAt = uint64(action.AttemptedAt.Unix()) + state = uint8(action.State) + errorReason = []byte(action.ErrorReason) + ) + + tlvRecords := []tlv.Record{ + tlv.MakePrimitiveRecord(typeActorName, &actor), + tlv.MakePrimitiveRecord(typeFeature, &feature), + tlv.MakePrimitiveRecord(typeTrigger, &trigger), + tlv.MakePrimitiveRecord(typeIntent, &intent), + tlv.MakePrimitiveRecord(typeStructuredJsonData, &data), + tlv.MakePrimitiveRecord(typeRPCMethod, &rpcMethod), + tlv.MakePrimitiveRecord(typeRPCParamsJson, ¶ms), + tlv.MakePrimitiveRecord(typeAttemptedAt, &attemptedAt), + tlv.MakePrimitiveRecord(typeState, &state), + tlv.MakePrimitiveRecord(typeErrorReason, &errorReason), + } + + tlvStream, err := tlv.NewStream(tlvRecords...) + if err != nil { + return err + } + + return tlvStream.Encode(w) +} + +// DeserializeAction deserializes an action from the given reader, expecting +// the data to be encoded in the tlv format. +func DeserializeAction(r io.Reader, sessionID session.ID) (*Action, error) { + var ( + action = Action{} + actor, featureName []byte + trigger, intent, data []byte + rpcMethod, params []byte + attemptedAt uint64 + state uint8 + errorReason []byte + ) + tlvStream, err := tlv.NewStream( + tlv.MakePrimitiveRecord(typeActorName, &actor), + tlv.MakePrimitiveRecord(typeFeature, &featureName), + tlv.MakePrimitiveRecord(typeTrigger, &trigger), + tlv.MakePrimitiveRecord(typeIntent, &intent), + tlv.MakePrimitiveRecord(typeStructuredJsonData, &data), + tlv.MakePrimitiveRecord(typeRPCMethod, &rpcMethod), + tlv.MakePrimitiveRecord(typeRPCParamsJson, ¶ms), + tlv.MakePrimitiveRecord(typeAttemptedAt, &attemptedAt), + tlv.MakePrimitiveRecord(typeState, &state), + tlv.MakePrimitiveRecord(typeErrorReason, &errorReason), + ) + if err != nil { + return nil, err + } + + _, err = tlvStream.DecodeWithParsedTypes(r) + if err != nil { + return nil, err + } + + action.SessionID = sessionID + action.ActorName = string(actor) + action.FeatureName = string(featureName) + action.Trigger = string(trigger) + action.Intent = string(intent) + action.StructuredJsonData = string(data) + action.RPCMethod = string(rpcMethod) + action.RPCParamsJson = params + action.AttemptedAt = time.Unix(int64(attemptedAt), 0) + action.State = ActionState(state) + action.ErrorReason = string(errorReason) + + return &action, nil +} + +// serializeActionLocator binary serializes the given ActionLocator to the +// writer using the tlv format. +func serializeActionLocator(w io.Writer, al *ActionLocator) error { + if al == nil { + return fmt.Errorf("action locator cannot be nil") + } + + var ( + sessionID = al.SessionID[:] + actionID = al.ActionID + ) + + tlvRecords := []tlv.Record{ + tlv.MakePrimitiveRecord(typeLocatorSessionID, &sessionID), + tlv.MakePrimitiveRecord(typeLocatorActionID, &actionID), + } + + tlvStream, err := tlv.NewStream(tlvRecords...) + if err != nil { + return err + } + + return tlvStream.Encode(w) +} + +// deserializeActionLocator deserializes an ActionLocator from the given reader, +// expecting the data to be encoded in the tlv format. +func deserializeActionLocator(r io.Reader) (*ActionLocator, error) { + var ( + sessionID []byte + actionID uint64 + ) + tlvStream, err := tlv.NewStream( + tlv.MakePrimitiveRecord(typeLocatorSessionID, &sessionID), + tlv.MakePrimitiveRecord(typeLocatorActionID, &actionID), + ) + if err != nil { + return nil, err + } + + _, err = tlvStream.DecodeWithParsedTypes(r) + if err != nil { + return nil, err + } + + id, err := session.IDFromBytes(sessionID) + if err != nil { + return nil, err + } + + return &ActionLocator{ + SessionID: id, + ActionID: actionID, + }, nil +} diff --git a/firewalldb/actions_test.go b/firewalldb/actions_test.go index 63a77ec45..da5dff147 100644 --- a/firewalldb/actions_test.go +++ b/firewalldb/actions_test.go @@ -39,7 +39,7 @@ var ( } ) -// TestActionStorage tests that the ActionsDB CRUD logic. +// TestActionStorage tests that the ActionsListDB CRUD logic. func TestActionStorage(t *testing.T) { tmpDir := t.TempDir() @@ -67,11 +67,11 @@ func TestActionStorage(t *testing.T) { require.NoError(t, err) require.Len(t, actions, 0) - id, err := db.AddAction(sessionID1, action1) + id, err := db.AddAction(action1) require.NoError(t, err) require.Equal(t, uint64(1), id) - id, err = db.AddAction(sessionID2, action2) + id, err = db.AddAction(action2) require.NoError(t, err) require.Equal(t, uint64(1), id) @@ -104,7 +104,7 @@ func TestActionStorage(t *testing.T) { action2.State = ActionStateDone require.Equal(t, action2, actions[0]) - id, err = db.AddAction(sessionID1, action1) + id, err = db.AddAction(action1) require.NoError(t, err) require.Equal(t, uint64(2), id) @@ -176,7 +176,7 @@ func TestListActions(t *testing.T) { State: ActionStateDone, } - _, err := db.AddAction(sessionID, action) + _, err := db.AddAction(action) require.NoError(t, err) } @@ -365,7 +365,7 @@ func TestListGroupActions(t *testing.T) { require.Empty(t, al) // Add an action under session 1. - _, err = db.AddAction(sessionID1, action1) + _, err = db.AddAction(action1) require.NoError(t, err) // There should now be one action in the group. @@ -375,7 +375,7 @@ func TestListGroupActions(t *testing.T) { require.Equal(t, sessionID1, al[0].SessionID) // Add an action under session 2. - _, err = db.AddAction(sessionID2, action2) + _, err = db.AddAction(action2) require.NoError(t, err) // There should now be actions in the group. diff --git a/rules/config.go b/rules/config.go index a109c351d..2bd92f014 100644 --- a/rules/config.go +++ b/rules/config.go @@ -16,7 +16,7 @@ type Config interface { // GetActionsDB can be used by rules to list any past actions that were // made for the specific session or feature. - GetActionsDB() firewalldb.ActionsDB + GetActionsDB() firewalldb.ActionsListDB // GetMethodPerms returns a map that contains URIs and the permissions // required to use them. @@ -48,7 +48,7 @@ type ConfigImpl struct { // ActionsDB can be used by rules to list any past actions that were // made for the specific session or feature. - ActionsDB firewalldb.ActionsDB + ActionsDB firewalldb.ActionsListDB // MethodPerms is a function that can be used to fetch the permissions // required for a URI. @@ -76,7 +76,7 @@ func (c *ConfigImpl) GetStores() firewalldb.KVStores { } // GetActionsDB returns the list of past actions. -func (c *ConfigImpl) GetActionsDB() firewalldb.ActionsDB { +func (c *ConfigImpl) GetActionsDB() firewalldb.ActionsListDB { return c.ActionsDB } diff --git a/rules/rate_limit.go b/rules/rate_limit.go index f324721a0..df2302bff 100644 --- a/rules/rate_limit.go +++ b/rules/rate_limit.go @@ -87,7 +87,7 @@ func (r *RateLimitMgr) EmptyValue() Values { // rateLimitConfig is the config required by RateLimitMgr. It can be derived // from the main rules Config struct. type rateLimitConfig interface { - GetActionsDB() firewalldb.ActionsDB + GetActionsDB() firewalldb.ActionsListDB GetMethodPerms() func(string) ([]bakery.Op, bool) } diff --git a/rules/rate_limit_test.go b/rules/rate_limit_test.go index 257232b65..1f291d291 100644 --- a/rules/rate_limit_test.go +++ b/rules/rate_limit_test.go @@ -216,7 +216,7 @@ type mockRateLimitCfg struct { var _ rateLimitConfig = (*mockRateLimitCfg)(nil) -func (m *mockRateLimitCfg) GetActionsDB() firewalldb.ActionsDB { +func (m *mockRateLimitCfg) GetActionsDB() firewalldb.ActionsListDB { return m.db } @@ -233,7 +233,7 @@ type mockActionsDB struct { actions []*firewalldb.RuleAction } -var _ firewalldb.ActionsDB = (*mockActionsDB)(nil) +var _ firewalldb.ActionsListDB = (*mockActionsDB)(nil) func (m *mockActionsDB) addAction(uri string, timestamp time.Time) { m.actions = append(m.actions, &firewalldb.RuleAction{