diff --git a/syncapi/routing/routing.go b/syncapi/routing/routing.go index 78188d1b6c..c63983b41a 100644 --- a/syncapi/routing/routing.go +++ b/syncapi/routing/routing.go @@ -157,6 +157,18 @@ func Setup( }, httputil.WithAllowGuests()), ).Methods(http.MethodGet, http.MethodOptions) + v1unstablemux.Handle("/rooms/{roomId}/threads", + httputil.MakeAuthAPI("threads", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse { + vars, err := httputil.URLDecodeMapValues(mux.Vars(req)) + if err != nil { + return util.ErrorResponse(err) + } + + return Threads( + req, device, syncDB, rsAPI, vars["roomId"], + ) + })).Methods(http.MethodGet) + v3mux.Handle("/search", httputil.MakeAuthAPI("search", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse { if !cfg.Fulltext.Enabled { @@ -200,4 +212,5 @@ func Setup( return GetMemberships(req, device, vars["roomID"], syncDB, rsAPI, membership, notMembership, at) }, httputil.WithAllowGuests()), ).Methods(http.MethodGet, http.MethodOptions) + } diff --git a/syncapi/routing/threads.go b/syncapi/routing/threads.go new file mode 100644 index 0000000000..815c6326ca --- /dev/null +++ b/syncapi/routing/threads.go @@ -0,0 +1,107 @@ +package routing + +import ( + rstypes "github.com/matrix-org/dendrite/roomserver/types" + "net/http" + "strconv" + + "github.com/matrix-org/dendrite/internal/sqlutil" + "github.com/matrix-org/dendrite/roomserver/api" + "github.com/matrix-org/dendrite/syncapi/storage" + "github.com/matrix-org/dendrite/syncapi/synctypes" + "github.com/matrix-org/dendrite/syncapi/types" + userapi "github.com/matrix-org/dendrite/userapi/api" + "github.com/matrix-org/gomatrixserverlib/spec" + "github.com/matrix-org/util" + "github.com/sirupsen/logrus" +) + +type ThreadsResponse struct { + Chunk []synctypes.ClientEvent `json:"chunk"` + NextBatch string `json:"next_batch,omitempty"` +} + +func Threads( + req *http.Request, + device *userapi.Device, + syncDB storage.Database, + rsAPI api.SyncRoomserverAPI, + rawRoomID string) util.JSONResponse { + var err error + roomID, err := spec.NewRoomID(rawRoomID) + if err != nil { + return util.JSONResponse{ + Code: http.StatusBadRequest, + JSON: spec.InvalidParam("invalid room ID"), + } + } + + limit, err := strconv.ParseUint(req.URL.Query().Get("limit"), 10, 64) + if err != nil { + limit = 50 + } + if limit > 100 { + limit = 100 + } + + var from types.StreamPosition + if f := req.URL.Query().Get("from"); f != "" { + if from, err = types.NewStreamPositionFromString(f); err != nil { + return util.ErrorResponse(err) + } + } + + include := req.URL.Query().Get("include") + + snapshot, err := syncDB.NewDatabaseSnapshot(req.Context()) + if err != nil { + logrus.WithError(err).Error("Failed to get snapshot for relations") + return util.JSONResponse{ + Code: http.StatusInternalServerError, + JSON: spec.InternalServerError{}, + } + } + var succeeded bool + defer sqlutil.EndTransactionWithCheck(snapshot, &succeeded, &err) + + res := &ThreadsResponse{ + Chunk: []synctypes.ClientEvent{}, + } + + var userID string + if include == "participated" { + _, err := spec.NewUserID(device.UserID, true) + if err != nil { + util.GetLogger(req.Context()).WithError(err).Error("device.UserID invalid") + return util.JSONResponse{ + Code: http.StatusInternalServerError, + JSON: spec.Unknown("internal server error"), + } + } + userID = device.UserID + } else { + userID = "" + } + var headeredEvents []*rstypes.HeaderedEvent + headeredEvents, _, res.NextBatch, err = snapshot.ThreadsFor( + req.Context(), roomID.String(), userID, from, limit, + ) + if err != nil { + return util.ErrorResponse(err) + } + + for _, event := range headeredEvents { + ce, err := synctypes.ToClientEvent(event, synctypes.FormatAll, func(roomID spec.RoomID, senderID spec.SenderID) (*spec.UserID, error) { + return rsAPI.QueryUserIDForSender(req.Context(), roomID, senderID) + }) + if err != nil { + return util.ErrorResponse(err) + } + res.Chunk = append(res.Chunk, *ce) + } + + return util.JSONResponse{ + Code: http.StatusOK, + JSON: res, + } +} diff --git a/syncapi/storage/postgres/relations_table.go b/syncapi/storage/postgres/relations_table.go index 5a76e9c336..97f5b828fc 100644 --- a/syncapi/storage/postgres/relations_table.go +++ b/syncapi/storage/postgres/relations_table.go @@ -63,6 +63,23 @@ const selectRelationsInRangeDescSQL = "" + " AND id >= $5 AND id < $6" + " ORDER BY id DESC LIMIT $7" +const selectThreadsSQL = "" + + "SELECT syncapi_relations.id, syncapi_relations.event_id FROM syncapi_relations" + + " JOIN syncapi_output_room_events ON syncapi_output_room_events.event_id = syncapi_relations.event_id" + + " WHERE syncapi_relations.room_id = $1" + + " AND syncapi_relations.rel_type = 'm.thread'" + + " AND syncapi_relations.id >= $2" + + " ORDER BY syncapi_relations.id LIMIT $3" + +const selectThreadsWithSenderSQL = "" + + "SELECT syncapi_relations.id, syncapi_relations.event_id FROM syncapi_relations" + + " JOIN syncapi_output_room_events ON syncapi_output_room_events.event_id = syncapi_relations.event_id" + + " WHERE syncapi_relations.room_id = $1" + + " AND syncapi_output_room_events.sender = $2" + + " AND syncapi_relations.rel_type = 'm.thread'" + + " AND syncapi_relations.id >= $3" + + " ORDER BY syncapi_relations.id LIMIT $4" + const selectMaxRelationIDSQL = "" + "SELECT COALESCE(MAX(id), 0) FROM syncapi_relations" @@ -70,6 +87,8 @@ type relationsStatements struct { insertRelationStmt *sql.Stmt selectRelationsInRangeAscStmt *sql.Stmt selectRelationsInRangeDescStmt *sql.Stmt + selectThreadsStmt *sql.Stmt + selectThreadsWithSenderStmt *sql.Stmt deleteRelationStmt *sql.Stmt selectMaxRelationIDStmt *sql.Stmt } @@ -84,6 +103,8 @@ func NewPostgresRelationsTable(db *sql.DB) (tables.Relations, error) { {&s.insertRelationStmt, insertRelationSQL}, {&s.selectRelationsInRangeAscStmt, selectRelationsInRangeAscSQL}, {&s.selectRelationsInRangeDescStmt, selectRelationsInRangeDescSQL}, + {&s.selectThreadsStmt, selectThreadsSQL}, + {&s.selectThreadsWithSenderStmt, selectThreadsWithSenderSQL}, {&s.deleteRelationStmt, deleteRelationSQL}, {&s.selectMaxRelationIDStmt, selectMaxRelationIDSQL}, }.Prepare(db) @@ -149,6 +170,49 @@ func (s *relationsStatements) SelectRelationsInRange( return result, lastPos, rows.Err() } +func (s *relationsStatements) SelectThreads( + ctx context.Context, + txn *sql.Tx, + roomID, userID string, + from types.StreamPosition, + limit uint64, +) ([]string, types.StreamPosition, error) { + var lastPos types.StreamPosition + var stmt *sql.Stmt + var rows *sql.Rows + var err error + + if userID == "" { + stmt = sqlutil.TxStmt(txn, s.selectThreadsStmt) + rows, err = stmt.QueryContext(ctx, roomID, from, limit) + } else { + stmt = sqlutil.TxStmt(txn, s.selectThreadsWithSenderStmt) + rows, err = stmt.QueryContext(ctx, roomID, userID, from, limit) + } + if err != nil { + return nil, lastPos, err + } + + defer internal.CloseAndLogIfError(ctx, rows, "selectThreads: rows.close() failed") + var result []string + var ( + id types.StreamPosition + eventId string + ) + + for rows.Next() { + if err = rows.Scan(&id, &eventId); err != nil { + return nil, lastPos, err + } + if id > lastPos { + lastPos = id + } + result = append(result, eventId) + } + + return result, lastPos, rows.Err() +} + func (s *relationsStatements) SelectMaxRelationID( ctx context.Context, txn *sql.Tx, ) (id int64, err error) { diff --git a/syncapi/storage/shared/storage_sync.go b/syncapi/storage/shared/storage_sync.go index cd17fdc699..1c52df0eeb 100644 --- a/syncapi/storage/shared/storage_sync.go +++ b/syncapi/storage/shared/storage_sync.go @@ -811,3 +811,39 @@ func (d *DatabaseTransaction) RelationsFor(ctx context.Context, roomID, eventID, return events, prevBatch, nextBatch, nil } + +func (d *DatabaseTransaction) ThreadsFor(ctx context.Context, roomID, userID string, from types.StreamPosition, limit uint64) ( + events []*rstypes.HeaderedEvent, prevBatch, nextBatch string, err error, +) { + r := types.Range{ + From: from, + } + + if r.From == 0 { + // If we're working backwards (dir=b) and there's no ?from= specified then + // we will automatically want to work backwards from the current position, + // so find out what that is. + if r.From, err = d.MaxStreamPositionForRelations(ctx); err != nil { + return nil, "", "", fmt.Errorf("d.MaxStreamPositionForRelations: %w", err) + } + // The result normally isn't inclusive of the event *at* the ?from= + // position, so add 1 here so that we include the most recent relation. + r.From++ + } + + // First look up any threads from the database. We add one to the limit here + // so that we can tell if we're overflowing, as we will only set the "next_batch" + // in the response if we are. + eventIDs, pos, err := d.Relations.SelectThreads(ctx, d.txn, roomID, userID, from, limit+1) + + if err != nil { + return nil, "", "", fmt.Errorf("d.Relations.SelectRelationsInRange: %w", err) + } + + events, err = d.Events(ctx, eventIDs) + if err != nil { + return nil, "", "", fmt.Errorf("d.OutputEvents.SelectEvents: %w", err) + } + + return events, prevBatch, fmt.Sprintf("%d", pos), nil +} diff --git a/syncapi/storage/sqlite3/relations_table.go b/syncapi/storage/sqlite3/relations_table.go index 7cbb5408f1..917ac2bc5e 100644 --- a/syncapi/storage/sqlite3/relations_table.go +++ b/syncapi/storage/sqlite3/relations_table.go @@ -64,11 +64,30 @@ const selectRelationsInRangeDescSQL = "" + const selectMaxRelationIDSQL = "" + "SELECT COALESCE(MAX(id), 0) FROM syncapi_relations" +const selectThreadsSQL = "" + + "SELECT syncapi_relations.id, syncapi_relations.event_id FROM syncapi_relations" + + " JOIN syncapi_output_room_events ON syncapi_output_room_events.event_id = syncapi_relations.event_id" + + " WHERE syncapi_relations.room_id = $1" + + " AND syncapi_relations.rel_type = 'm.thread'" + + " AND syncapi_relations.id >= $2" + + " ORDER BY syncapi_relations.id LIMIT $3" + +const selectThreadsWithSenderSQL = "" + + "SELECT syncapi_relations.id, syncapi_relations.event_id FROM syncapi_relations" + + " JOIN syncapi_output_room_events ON syncapi_output_room_events.event_id = syncapi_relations.event_id" + + " WHERE syncapi_relations.room_id = $1" + + " AND syncapi_output_room_events.sender = $2" + + " AND syncapi_relations.rel_type = 'm.thread'" + + " AND syncapi_relations.id >= $3" + + " ORDER BY syncapi_relations.id LIMIT $4" + type relationsStatements struct { streamIDStatements *StreamIDStatements insertRelationStmt *sql.Stmt selectRelationsInRangeAscStmt *sql.Stmt selectRelationsInRangeDescStmt *sql.Stmt + selectThreadsStmt *sql.Stmt + selectThreadsWithSenderStmt *sql.Stmt deleteRelationStmt *sql.Stmt selectMaxRelationIDStmt *sql.Stmt } @@ -85,6 +104,8 @@ func NewSqliteRelationsTable(db *sql.DB, streamID *StreamIDStatements) (tables.R {&s.insertRelationStmt, insertRelationSQL}, {&s.selectRelationsInRangeAscStmt, selectRelationsInRangeAscSQL}, {&s.selectRelationsInRangeDescStmt, selectRelationsInRangeDescSQL}, + {&s.selectThreadsStmt, selectThreadsSQL}, + {&s.selectThreadsWithSenderStmt, selectThreadsWithSenderSQL}, {&s.deleteRelationStmt, deleteRelationSQL}, {&s.selectMaxRelationIDStmt, selectMaxRelationIDSQL}, }.Prepare(db) @@ -154,6 +175,49 @@ func (s *relationsStatements) SelectRelationsInRange( return result, lastPos, rows.Err() } +func (s *relationsStatements) SelectThreads( + ctx context.Context, + txn *sql.Tx, + roomID, userID string, + from types.StreamPosition, + limit uint64, +) ([]string, types.StreamPosition, error) { + var lastPos types.StreamPosition + var stmt *sql.Stmt + var rows *sql.Rows + var err error + + if userID == "" { + stmt = sqlutil.TxStmt(txn, s.selectThreadsStmt) + rows, err = stmt.QueryContext(ctx, roomID, from, limit) + } else { + stmt = sqlutil.TxStmt(txn, s.selectThreadsWithSenderStmt) + rows, err = stmt.QueryContext(ctx, roomID, userID, from, limit) + } + if err != nil { + return nil, lastPos, err + } + + defer internal.CloseAndLogIfError(ctx, rows, "selectThreads: rows.close() failed") + var result []string + var ( + id types.StreamPosition + eventId string + ) + + for rows.Next() { + if err = rows.Scan(&id, &eventId); err != nil { + return nil, lastPos, err + } + if id > lastPos { + lastPos = id + } + result = append(result, eventId) + } + + return result, lastPos, rows.Err() +} + func (s *relationsStatements) SelectMaxRelationID( ctx context.Context, txn *sql.Tx, ) (id int64, err error) { diff --git a/syncapi/storage/tables/interface.go b/syncapi/storage/tables/interface.go index 45117d6d30..14105a15ad 100644 --- a/syncapi/storage/tables/interface.go +++ b/syncapi/storage/tables/interface.go @@ -223,10 +223,10 @@ type Presence interface { } type Relations interface { - // Inserts a relation which refers from the child event ID to the event ID in the given room. + // InsertRelation Inserts a relation which refers from the child event ID to the event ID in the given room. // If the relation already exists then this function will do nothing and return no error. InsertRelation(ctx context.Context, txn *sql.Tx, roomID, eventID, childEventID, childEventType, relType string) (err error) - // Deletes a relation which already exists as the result of an event redaction. If the relation + // DeleteRelation Deletes a relation which already exists as the result of an event redaction. If the relation // does not exist then this function will do nothing and return no error. DeleteRelation(ctx context.Context, txn *sql.Tx, roomID, childEventID string) error // SelectRelationsInRange will return relations grouped by relation type within the given range. @@ -235,6 +235,9 @@ type Relations interface { // will be returned, inclusive of the "to" position but excluding the "from" position. The stream // position returned is the maximum position of the returned results. SelectRelationsInRange(ctx context.Context, txn *sql.Tx, roomID, eventID, relType, eventType string, r types.Range, limit int) (map[string][]types.RelationEntry, types.StreamPosition, error) + // SelectThreads will find threads from a room, if userID is not empty + // then it will only include the threads that the user has participated in. + SelectThreads(ctx context.Context, txn *sql.Tx, roomID, userID string, from types.StreamPosition, limit uint64) ([]string, types.StreamPosition, error) // SelectMaxRelationID returns the maximum ID of all relations, used to determine what the boundaries // should be if there are no boundaries supplied (i.e. we want to work backwards but don't have a // "from" or want to work forwards and don't have a "to"). diff --git a/syncapi/storage/tables/relations_test.go b/syncapi/storage/tables/relations_test.go index 46270e36dc..05fda6399d 100644 --- a/syncapi/storage/tables/relations_test.go +++ b/syncapi/storage/tables/relations_test.go @@ -24,7 +24,22 @@ func newRelationsTable(t *testing.T, dbType test.DBType) (tables.Relations, *sql t.Fatalf("failed to open db: %s", err) } + switch dbType { + case test.DBTypePostgres: + _, err = postgres.NewPostgresEventsTable(db) + case test.DBTypeSQLite: + var stream sqlite3.StreamIDStatements + if err = stream.Prepare(db); err != nil { + t.Fatalf("failed to prepare stream stmts: %s", err) + } + _, err = sqlite3.NewSqliteEventsTable(db, &stream) + } + if err != nil { + t.Fatalf("failed to make new table: %s", err) + } + var tab tables.Relations + switch dbType { case test.DBTypePostgres: tab, err = postgres.NewPostgresRelationsTable(db) @@ -184,3 +199,53 @@ func TestRelationsTable(t *testing.T) { } }) } + +const threadRelType = "m.thread" + +func TestThreads(t *testing.T) { + var err error + ctx := context.Background() + alice := test.NewUser(t) + room := test.NewRoom(t, alice) + + firstEvent := room.CreateAndInsert(t, alice, "m.room.message", map[string]interface{}{ + "body": "first message", + }) + threadReplyEvent := room.CreateAndInsert(t, alice, "m.room.message", map[string]interface{}{ + "body": "thread reply", + "m.relates_to": map[string]interface{}{ + "event_id": firstEvent.EventID(), + "rel_type": threadRelType, + }, + }) + + test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) { + tab, _, close := newRelationsTable(t, dbType) + defer close() + + err = tab.InsertRelation(ctx, nil, room.ID, firstEvent.EventID(), threadReplyEvent.EventID(), "m.room.message", threadReplyEvent.EventID()) + if err != nil { + t.Fatal(err) + } + var eventIds []string + eventIds, _, err = tab.SelectThreads(ctx, nil, room.ID, "", 0, 100) + + for i, expected := range []string{ + firstEvent.EventID(), + } { + eventID := eventIds[i] + if eventID != expected { + t.Fatalf("eventID mismatch: got %s, want %s", eventID, expected) + } + } + eventIds, _, err = tab.SelectThreads(ctx, nil, room.ID, alice.ID, 0, 100) + for i, expected := range []string{ + firstEvent.EventID(), + } { + eventID := eventIds[i] + if eventID != expected { + t.Fatalf("eventID mismatch: got %s, want %s", eventID, expected) + } + } + }) +}