From faeb0b4ba0149eb49512283734997c502d3e4a24 Mon Sep 17 00:00:00 2001 From: qwqtoday Date: Wed, 31 Jul 2024 13:55:58 +0800 Subject: [PATCH 01/10] half done --- syncapi/routing/threads.go | 77 +++++++++++++++++++++ syncapi/storage/postgres/relations_table.go | 72 ++++++++++++++++++- syncapi/storage/shared/storage_sync.go | 63 +++++++++++++++++ syncapi/storage/tables/interface.go | 7 +- 4 files changed, 216 insertions(+), 3 deletions(-) create mode 100644 syncapi/routing/threads.go diff --git a/syncapi/routing/threads.go b/syncapi/routing/threads.go new file mode 100644 index 0000000000..6d0e2518a7 --- /dev/null +++ b/syncapi/routing/threads.go @@ -0,0 +1,77 @@ +package routing + +import ( + "github.com/matrix-org/dendrite/internal/sqlutil" + "github.com/matrix-org/dendrite/roomserver/api" + "github.com/matrix-org/dendrite/syncapi/storage" + "github.com/matrix-org/dendrite/syncapi/synctypes" + "github.com/matrix-org/dendrite/syncapi/types" + userapi "github.com/matrix-org/dendrite/userapi/api" + "github.com/matrix-org/gomatrixserverlib/spec" + "github.com/matrix-org/util" + "github.com/sirupsen/logrus" + "net/http" + "strconv" +) + +type ThreadsResponse struct { + Chunk []synctypes.ClientEvent `json:"chunk"` + NextBatch string `json:"next_batch,omitempty"` +} + +func Threads( + req *http.Request, + device userapi.Device, + syncDB storage.Database, + rsAPI api.SyncRoomserverAPI, + rawRoomID string) util.JSONResponse { + var err error + roomID, err := spec.NewRoomID(rawRoomID) + if err != nil { + return util.JSONResponse{ + Code: http.StatusBadRequest, + JSON: spec.InvalidParam("invalid room ID"), + } + } + + limit, err := strconv.ParseUint(req.URL.Query().Get("limit"), 10, 64) + if err != nil { + limit = 50 + } + if limit > 100 { + limit = 100 + } + + from := req.URL.Query().Get("from") + include := req.URL.Query().Get("include") + + snapshot, err := syncDB.NewDatabaseSnapshot(req.Context()) + if err != nil { + logrus.WithError(err).Error("Failed to get snapshot for relations") + return util.JSONResponse{ + Code: http.StatusInternalServerError, + JSON: spec.InternalServerError{}, + } + } + var succeeded bool + defer sqlutil.EndTransactionWithCheck(snapshot, &succeeded, &err) + + res := &ThreadsResponse{ + Chunk: []synctypes.ClientEvent{}, + } + + if include == "participated" { + userID, err := spec.NewUserID(device.UserID, true) + if err != nil { + util.GetLogger(req.Context()).WithError(err).Error("device.UserID invalid") + return util.JSONResponse{ + Code: http.StatusInternalServerError, + JSON: spec.Unknown("internal server error"), + } + } + var events []types.StreamEvent + events, res.PrevBatch, res.NextBatch, err = snapshot.RelationsFor( + req.Context(), roomID.String(), "", relType, eventType, from, to, dir == "b", limit, + ) + } +} diff --git a/syncapi/storage/postgres/relations_table.go b/syncapi/storage/postgres/relations_table.go index 5a76e9c336..861eb3d3f0 100644 --- a/syncapi/storage/postgres/relations_table.go +++ b/syncapi/storage/postgres/relations_table.go @@ -17,9 +17,10 @@ package postgres import ( "context" "database/sql" - + "encoding/json" "github.com/matrix-org/dendrite/internal" "github.com/matrix-org/dendrite/internal/sqlutil" + types2 "github.com/matrix-org/dendrite/roomserver/types" "github.com/matrix-org/dendrite/syncapi/storage/tables" "github.com/matrix-org/dendrite/syncapi/types" ) @@ -63,6 +64,23 @@ const selectRelationsInRangeDescSQL = "" + " AND id >= $5 AND id < $6" + " ORDER BY id DESC LIMIT $7" +const selectThreadsSQL = "" + + "SELECT syncapi_relations.id, syncapi_relations.child_event_id, syncapi_output_room_events.sender, syncapi_relations.event_id, syncapi_output_room_events.headered_event_json FROM syncapi_relations" + + " JOIN syncapi_output_room_events ON syncapi_output_room_events.event_id = syncapi_relations.event_id" + + " WHERE syncapi_relations.room_id = $1" + + " AND syncapi_relations.rel_type = 'm.thread'" + + " AND syncapi_relations.id >= $2 AND syncapi_relations.id < $3" + + " ORDER BY syncapi_relations.id LIMIT $4" + +const selectThreadsWithSenderSQL = "" + + "SELECT syncapi_relations.id, syncapi_relations.child_event_id, syncapi_output_room_events.sender, syncapi_relations.event_id, syncapi_output_room_events.headered_event_json FROM syncapi_relations" + + " JOIN syncapi_output_room_events ON syncapi_output_room_events.event_id = syncapi_relations.event_id" + + " WHERE syncapi_relations.room_id = $1" + + " AND syncapi_output_room_events.sender = $2" + + " AND syncapi_relations.rel_type = 'm.thread'" + + " AND syncapi_relations.id >= $3 AND syncapi_relations.id < $4" + + " ORDER BY syncapi_relations.id LIMIT $5" + const selectMaxRelationIDSQL = "" + "SELECT COALESCE(MAX(id), 0) FROM syncapi_relations" @@ -70,6 +88,8 @@ type relationsStatements struct { insertRelationStmt *sql.Stmt selectRelationsInRangeAscStmt *sql.Stmt selectRelationsInRangeDescStmt *sql.Stmt + selectThreadsStmt *sql.Stmt + selectThreadsWithSenderStmt *sql.Stmt deleteRelationStmt *sql.Stmt selectMaxRelationIDStmt *sql.Stmt } @@ -84,6 +104,8 @@ func NewPostgresRelationsTable(db *sql.DB) (tables.Relations, error) { {&s.insertRelationStmt, insertRelationSQL}, {&s.selectRelationsInRangeAscStmt, selectRelationsInRangeAscSQL}, {&s.selectRelationsInRangeDescStmt, selectRelationsInRangeDescSQL}, + {&s.selectThreadsStmt, selectThreadsSQL}, + {&s.selectThreadsWithSenderStmt, selectThreadsWithSenderSQL}, {&s.deleteRelationStmt, deleteRelationSQL}, {&s.selectMaxRelationIDStmt, selectMaxRelationIDSQL}, }.Prepare(db) @@ -149,6 +171,54 @@ func (s *relationsStatements) SelectRelationsInRange( return result, lastPos, rows.Err() } +func (s *relationsStatements) SelectThreads( + ctx context.Context, + txn *sql.Tx, + roomID, userID string, + r types.Range, + limit int, +) ([]map[string]any, types.StreamPosition, error) { + var lastPos types.StreamPosition + var stmt *sql.Stmt + var rows *sql.Rows + var err error + + if userID == "" { + stmt = sqlutil.TxStmt(txn, s.selectThreadsStmt) + rows, err = stmt.QueryContext(ctx, roomID, r.Low(), r.High(), limit) + } else { + stmt = sqlutil.TxStmt(txn, s.selectThreadsWithSenderStmt) + rows, err = stmt.QueryContext(ctx, roomID, userID, r.Low(), r.High(), limit) + } + if err != nil { + return nil, lastPos, err + } + + defer internal.CloseAndLogIfError(ctx, rows, "selectThreads: rows.close() failed") + var result []map[string]any + var ( + id types.StreamPosition + childEventID string + sender string + eventId string + headeredEventJson string + ) + + for rows.Next() { + if err = rows.Scan(&id, &childEventID, &sender, &eventId, &headeredEventJson); err != nil { + return nil, lastPos, err + } + if id > lastPos { + lastPos = id + } + var event types2.HeaderedEvent + json.Unmarshal(event) + result = append(result) + } + + return result, lastPos, rows.Err() +} + func (s *relationsStatements) SelectMaxRelationID( ctx context.Context, txn *sql.Tx, ) (id int64, err error) { diff --git a/syncapi/storage/shared/storage_sync.go b/syncapi/storage/shared/storage_sync.go index cd17fdc699..0c70025e33 100644 --- a/syncapi/storage/shared/storage_sync.go +++ b/syncapi/storage/shared/storage_sync.go @@ -811,3 +811,66 @@ func (d *DatabaseTransaction) RelationsFor(ctx context.Context, roomID, eventID, return events, prevBatch, nextBatch, nil } + +func (d *DatabaseTransaction) ThreadsFor(ctx context.Context, roomID, userID string, from types.StreamPosition, limit int) ( + events []types.StreamEvent, prevBatch, nextBatch string, err error, +) { + r := types.Range{ + From: from, + } + + if r.From == 0 { + // If we're working backwards (dir=b) and there's no ?from= specified then + // we will automatically want to work backwards from the current position, + // so find out what that is. + if r.From, err = d.MaxStreamPositionForRelations(ctx); err != nil { + return nil, "", "", fmt.Errorf("d.MaxStreamPositionForRelations: %w", err) + } + // The result normally isn't inclusive of the event *at* the ?from= + // position, so add 1 here so that we include the most recent relation. + r.From++ + } + + // First look up any relations from the database. We add one to the limit here + // so that we can tell if we're overflowing, as we will only set the "next_batch" + // in the response if we are. + relations, _, err := d.Relations.SelectThreads(ctx, d.txn, roomID, userID, limit+1) + if err != nil { + return nil, "", "", fmt.Errorf("d.Relations.SelectRelationsInRange: %w", err) + } + + // If we specified a relation type then just get those results, otherwise collate + // them from all of the returned relation types. + entries := []types.RelationEntry{} + for _, e := range relations { + entries = append(entries, e...) + } + + // If there were no entries returned, there were no relations, so stop at this point. + if len(entries) == 0 { + return nil, "", "", nil + } + + // Otherwise, let's try and work out what sensible prev_batch and next_batch values + // could be. We've requested an extra event by adding one to the limit already so + // that we can determine whether or not to provide a "next_batch", so trim off that + // event off the end if needs be. + if len(entries) > limit { + entries = entries[:len(entries)-1] + nextBatch = fmt.Sprintf("%d", entries[len(entries)-1].Position) + } + // TODO: set prevBatch? doesn't seem to affect the tests... + + // Extract all of the event IDs from the relation entries so that we can pull the + // events out of the database. Then go and fetch the events. + eventIDs := make([]string, 0, len(entries)) + for _, entry := range entries { + eventIDs = append(eventIDs, entry.EventID) + } + events, err = d.OutputEvents.SelectEvents(ctx, d.txn, eventIDs, nil, true) + if err != nil { + return nil, "", "", fmt.Errorf("d.OutputEvents.SelectEvents: %w", err) + } + + return events, prevBatch, nextBatch, nil +} diff --git a/syncapi/storage/tables/interface.go b/syncapi/storage/tables/interface.go index 45117d6d30..adf46a70fa 100644 --- a/syncapi/storage/tables/interface.go +++ b/syncapi/storage/tables/interface.go @@ -223,10 +223,10 @@ type Presence interface { } type Relations interface { - // Inserts a relation which refers from the child event ID to the event ID in the given room. + // InsertRelation Inserts a relation which refers from the child event ID to the event ID in the given room. // If the relation already exists then this function will do nothing and return no error. InsertRelation(ctx context.Context, txn *sql.Tx, roomID, eventID, childEventID, childEventType, relType string) (err error) - // Deletes a relation which already exists as the result of an event redaction. If the relation + // DeleteRelation Deletes a relation which already exists as the result of an event redaction. If the relation // does not exist then this function will do nothing and return no error. DeleteRelation(ctx context.Context, txn *sql.Tx, roomID, childEventID string) error // SelectRelationsInRange will return relations grouped by relation type within the given range. @@ -235,6 +235,9 @@ type Relations interface { // will be returned, inclusive of the "to" position but excluding the "from" position. The stream // position returned is the maximum position of the returned results. SelectRelationsInRange(ctx context.Context, txn *sql.Tx, roomID, eventID, relType, eventType string, r types.Range, limit int) (map[string][]types.RelationEntry, types.StreamPosition, error) + // SelectThreads this will find some threads from a room + // if userID is not empty then it will only include the threads that the user has participated + SelectThreads(ctx context.Context, txn *sql.Tx, roomID, userID string, limit int) (map[string][]types.RelationEntry, types.StreamPosition, error) // SelectMaxRelationID returns the maximum ID of all relations, used to determine what the boundaries // should be if there are no boundaries supplied (i.e. we want to work backwards but don't have a // "from" or want to work forwards and don't have a "to"). From 9934f9543b6d7a57a7569a10346b91223824e862 Mon Sep 17 00:00:00 2001 From: qwqtoday Date: Fri, 9 Aug 2024 15:24:12 +0800 Subject: [PATCH 02/10] closer to complete --- syncapi/routing/threads.go | 42 +++++++++++++++++---- syncapi/storage/postgres/relations_table.go | 34 ++++++++--------- syncapi/storage/shared/storage_sync.go | 39 +++---------------- syncapi/storage/tables/interface.go | 2 +- 4 files changed, 57 insertions(+), 60 deletions(-) diff --git a/syncapi/routing/threads.go b/syncapi/routing/threads.go index 6d0e2518a7..175ece1977 100644 --- a/syncapi/routing/threads.go +++ b/syncapi/routing/threads.go @@ -1,6 +1,10 @@ package routing import ( + rstypes "github.com/matrix-org/dendrite/roomserver/types" + "net/http" + "strconv" + "github.com/matrix-org/dendrite/internal/sqlutil" "github.com/matrix-org/dendrite/roomserver/api" "github.com/matrix-org/dendrite/syncapi/storage" @@ -10,8 +14,6 @@ import ( "github.com/matrix-org/gomatrixserverlib/spec" "github.com/matrix-org/util" "github.com/sirupsen/logrus" - "net/http" - "strconv" ) type ThreadsResponse struct { @@ -42,7 +44,13 @@ func Threads( limit = 100 } - from := req.URL.Query().Get("from") + var from types.StreamPosition + if f := req.URL.Query().Get("from"); f != "" { + if from, err = types.NewStreamPositionFromString(f); err != nil { + return util.ErrorResponse(err) + } + } + include := req.URL.Query().Get("include") snapshot, err := syncDB.NewDatabaseSnapshot(req.Context()) @@ -60,8 +68,9 @@ func Threads( Chunk: []synctypes.ClientEvent{}, } + var userID string if include == "participated" { - userID, err := spec.NewUserID(device.UserID, true) + _, err := spec.NewUserID(device.UserID, true) if err != nil { util.GetLogger(req.Context()).WithError(err).Error("device.UserID invalid") return util.JSONResponse{ @@ -69,9 +78,26 @@ func Threads( JSON: spec.Unknown("internal server error"), } } - var events []types.StreamEvent - events, res.PrevBatch, res.NextBatch, err = snapshot.RelationsFor( - req.Context(), roomID.String(), "", relType, eventType, from, to, dir == "b", limit, - ) + userID = device.UserID + } else { + userID = "" + } + var headeredEvents []*rstypes.HeaderedEvent + headeredEvents, _, res.NextBatch, err = snapshot.ThreadsFor( + req.Context(), roomID.String(), userID, from, limit, + ) + + for _, event := range headeredEvents { + ce, err := synctypes.ToClientEvent(event, synctypes.FormatAll, func(roomID spec.RoomID, senderID spec.SenderID) (*spec.UserID, error) { + return rsAPI.QueryUserIDForSender(req.Context(), roomID, senderID) + }) + if err != nil { + return util.ErrorResponse(err) + } + res.Chunk = append(res.Chunk, *ce) + } + + return util.JSONResponse{ + JSON: res, } } diff --git a/syncapi/storage/postgres/relations_table.go b/syncapi/storage/postgres/relations_table.go index 861eb3d3f0..0607c9c265 100644 --- a/syncapi/storage/postgres/relations_table.go +++ b/syncapi/storage/postgres/relations_table.go @@ -17,10 +17,9 @@ package postgres import ( "context" "database/sql" - "encoding/json" + "github.com/matrix-org/dendrite/internal" "github.com/matrix-org/dendrite/internal/sqlutil" - types2 "github.com/matrix-org/dendrite/roomserver/types" "github.com/matrix-org/dendrite/syncapi/storage/tables" "github.com/matrix-org/dendrite/syncapi/types" ) @@ -65,21 +64,21 @@ const selectRelationsInRangeDescSQL = "" + " ORDER BY id DESC LIMIT $7" const selectThreadsSQL = "" + - "SELECT syncapi_relations.id, syncapi_relations.child_event_id, syncapi_output_room_events.sender, syncapi_relations.event_id, syncapi_output_room_events.headered_event_json FROM syncapi_relations" + + "SELECT syncapi_relations.id, syncapi_relations.child_event_id, syncapi_output_room_events.sender, syncapi_relations.event_id, syncapi_output_room_events.headered_event_json, syncapi_output_room_events.type FROM syncapi_relations" + " JOIN syncapi_output_room_events ON syncapi_output_room_events.event_id = syncapi_relations.event_id" + " WHERE syncapi_relations.room_id = $1" + " AND syncapi_relations.rel_type = 'm.thread'" + - " AND syncapi_relations.id >= $2 AND syncapi_relations.id < $3" + - " ORDER BY syncapi_relations.id LIMIT $4" + " AND syncapi_relations.id >= $2 AND" + + " ORDER BY syncapi_relations.id LIMIT $3" const selectThreadsWithSenderSQL = "" + - "SELECT syncapi_relations.id, syncapi_relations.child_event_id, syncapi_output_room_events.sender, syncapi_relations.event_id, syncapi_output_room_events.headered_event_json FROM syncapi_relations" + + "SELECT syncapi_relations.id, syncapi_relations.child_event_id, syncapi_output_room_events.sender, syncapi_relations.event_id, syncapi_output_room_events.headered_event_json, syncapi_output_room_events.type FROM syncapi_relations" + " JOIN syncapi_output_room_events ON syncapi_output_room_events.event_id = syncapi_relations.event_id" + " WHERE syncapi_relations.room_id = $1" + " AND syncapi_output_room_events.sender = $2" + " AND syncapi_relations.rel_type = 'm.thread'" + - " AND syncapi_relations.id >= $3 AND syncapi_relations.id < $4" + - " ORDER BY syncapi_relations.id LIMIT $5" + " AND syncapi_relations.id >= $3" + + " ORDER BY syncapi_relations.id LIMIT $4" const selectMaxRelationIDSQL = "" + "SELECT COALESCE(MAX(id), 0) FROM syncapi_relations" @@ -175,9 +174,9 @@ func (s *relationsStatements) SelectThreads( ctx context.Context, txn *sql.Tx, roomID, userID string, - r types.Range, - limit int, -) ([]map[string]any, types.StreamPosition, error) { + from types.StreamPosition, + limit uint64, +) ([]string, types.StreamPosition, error) { var lastPos types.StreamPosition var stmt *sql.Stmt var rows *sql.Rows @@ -185,35 +184,34 @@ func (s *relationsStatements) SelectThreads( if userID == "" { stmt = sqlutil.TxStmt(txn, s.selectThreadsStmt) - rows, err = stmt.QueryContext(ctx, roomID, r.Low(), r.High(), limit) + rows, err = stmt.QueryContext(ctx, roomID, from, limit) } else { stmt = sqlutil.TxStmt(txn, s.selectThreadsWithSenderStmt) - rows, err = stmt.QueryContext(ctx, roomID, userID, r.Low(), r.High(), limit) + rows, err = stmt.QueryContext(ctx, roomID, userID, from, limit) } if err != nil { return nil, lastPos, err } defer internal.CloseAndLogIfError(ctx, rows, "selectThreads: rows.close() failed") - var result []map[string]any + var result []string var ( id types.StreamPosition childEventID string sender string eventId string headeredEventJson string + eventType string ) for rows.Next() { - if err = rows.Scan(&id, &childEventID, &sender, &eventId, &headeredEventJson); err != nil { + if err = rows.Scan(&id, &childEventID, &sender, &eventId, &headeredEventJson, &eventType); err != nil { return nil, lastPos, err } if id > lastPos { lastPos = id } - var event types2.HeaderedEvent - json.Unmarshal(event) - result = append(result) + result = append(result, eventId) } return result, lastPos, rows.Err() diff --git a/syncapi/storage/shared/storage_sync.go b/syncapi/storage/shared/storage_sync.go index 0c70025e33..a7d783d6f2 100644 --- a/syncapi/storage/shared/storage_sync.go +++ b/syncapi/storage/shared/storage_sync.go @@ -812,8 +812,8 @@ func (d *DatabaseTransaction) RelationsFor(ctx context.Context, roomID, eventID, return events, prevBatch, nextBatch, nil } -func (d *DatabaseTransaction) ThreadsFor(ctx context.Context, roomID, userID string, from types.StreamPosition, limit int) ( - events []types.StreamEvent, prevBatch, nextBatch string, err error, +func (d *DatabaseTransaction) ThreadsFor(ctx context.Context, roomID, userID string, from types.StreamPosition, limit uint64) ( + events []*rstypes.HeaderedEvent, prevBatch, nextBatch string, err error, ) { r := types.Range{ From: from, @@ -831,43 +831,16 @@ func (d *DatabaseTransaction) ThreadsFor(ctx context.Context, roomID, userID str r.From++ } - // First look up any relations from the database. We add one to the limit here + // First look up any threads from the database. We add one to the limit here // so that we can tell if we're overflowing, as we will only set the "next_batch" // in the response if we are. - relations, _, err := d.Relations.SelectThreads(ctx, d.txn, roomID, userID, limit+1) + eventIDs, _, err := d.Relations.SelectThreads(ctx, d.txn, roomID, userID, from, limit+1) + if err != nil { return nil, "", "", fmt.Errorf("d.Relations.SelectRelationsInRange: %w", err) } - // If we specified a relation type then just get those results, otherwise collate - // them from all of the returned relation types. - entries := []types.RelationEntry{} - for _, e := range relations { - entries = append(entries, e...) - } - - // If there were no entries returned, there were no relations, so stop at this point. - if len(entries) == 0 { - return nil, "", "", nil - } - - // Otherwise, let's try and work out what sensible prev_batch and next_batch values - // could be. We've requested an extra event by adding one to the limit already so - // that we can determine whether or not to provide a "next_batch", so trim off that - // event off the end if needs be. - if len(entries) > limit { - entries = entries[:len(entries)-1] - nextBatch = fmt.Sprintf("%d", entries[len(entries)-1].Position) - } - // TODO: set prevBatch? doesn't seem to affect the tests... - - // Extract all of the event IDs from the relation entries so that we can pull the - // events out of the database. Then go and fetch the events. - eventIDs := make([]string, 0, len(entries)) - for _, entry := range entries { - eventIDs = append(eventIDs, entry.EventID) - } - events, err = d.OutputEvents.SelectEvents(ctx, d.txn, eventIDs, nil, true) + events, err = d.Events(ctx, eventIDs) if err != nil { return nil, "", "", fmt.Errorf("d.OutputEvents.SelectEvents: %w", err) } diff --git a/syncapi/storage/tables/interface.go b/syncapi/storage/tables/interface.go index adf46a70fa..5470349a0b 100644 --- a/syncapi/storage/tables/interface.go +++ b/syncapi/storage/tables/interface.go @@ -237,7 +237,7 @@ type Relations interface { SelectRelationsInRange(ctx context.Context, txn *sql.Tx, roomID, eventID, relType, eventType string, r types.Range, limit int) (map[string][]types.RelationEntry, types.StreamPosition, error) // SelectThreads this will find some threads from a room // if userID is not empty then it will only include the threads that the user has participated - SelectThreads(ctx context.Context, txn *sql.Tx, roomID, userID string, limit int) (map[string][]types.RelationEntry, types.StreamPosition, error) + SelectThreads(ctx context.Context, txn *sql.Tx, roomID, userID string, from types.StreamPosition, limit uint64) ([]string, types.StreamPosition, error) // SelectMaxRelationID returns the maximum ID of all relations, used to determine what the boundaries // should be if there are no boundaries supplied (i.e. we want to work backwards but don't have a // "from" or want to work forwards and don't have a "to"). From 8e9dd3f1f8fbecc6454e06a439d554f1bb542cff Mon Sep 17 00:00:00 2001 From: qwqtoday Date: Fri, 9 Aug 2024 15:30:58 +0800 Subject: [PATCH 03/10] cool now it's done I guess --- syncapi/storage/postgres/relations_table.go | 14 ++--- syncapi/storage/sqlite3/relations_table.go | 64 +++++++++++++++++++++ 2 files changed, 69 insertions(+), 9 deletions(-) diff --git a/syncapi/storage/postgres/relations_table.go b/syncapi/storage/postgres/relations_table.go index 0607c9c265..2b710a3a90 100644 --- a/syncapi/storage/postgres/relations_table.go +++ b/syncapi/storage/postgres/relations_table.go @@ -64,7 +64,7 @@ const selectRelationsInRangeDescSQL = "" + " ORDER BY id DESC LIMIT $7" const selectThreadsSQL = "" + - "SELECT syncapi_relations.id, syncapi_relations.child_event_id, syncapi_output_room_events.sender, syncapi_relations.event_id, syncapi_output_room_events.headered_event_json, syncapi_output_room_events.type FROM syncapi_relations" + + "SELECT syncapi_relations.id, syncapi_relations.event_id FROM syncapi_relations" + " JOIN syncapi_output_room_events ON syncapi_output_room_events.event_id = syncapi_relations.event_id" + " WHERE syncapi_relations.room_id = $1" + " AND syncapi_relations.rel_type = 'm.thread'" + @@ -72,7 +72,7 @@ const selectThreadsSQL = "" + " ORDER BY syncapi_relations.id LIMIT $3" const selectThreadsWithSenderSQL = "" + - "SELECT syncapi_relations.id, syncapi_relations.child_event_id, syncapi_output_room_events.sender, syncapi_relations.event_id, syncapi_output_room_events.headered_event_json, syncapi_output_room_events.type FROM syncapi_relations" + + "SELECT syncapi_relations.id, syncapi_relations.event_id FROM syncapi_relations" + " JOIN syncapi_output_room_events ON syncapi_output_room_events.event_id = syncapi_relations.event_id" + " WHERE syncapi_relations.room_id = $1" + " AND syncapi_output_room_events.sender = $2" + @@ -196,16 +196,12 @@ func (s *relationsStatements) SelectThreads( defer internal.CloseAndLogIfError(ctx, rows, "selectThreads: rows.close() failed") var result []string var ( - id types.StreamPosition - childEventID string - sender string - eventId string - headeredEventJson string - eventType string + id types.StreamPosition + eventId string ) for rows.Next() { - if err = rows.Scan(&id, &childEventID, &sender, &eventId, &headeredEventJson, &eventType); err != nil { + if err = rows.Scan(&id, &eventId); err != nil { return nil, lastPos, err } if id > lastPos { diff --git a/syncapi/storage/sqlite3/relations_table.go b/syncapi/storage/sqlite3/relations_table.go index 7cbb5408f1..512178c5bf 100644 --- a/syncapi/storage/sqlite3/relations_table.go +++ b/syncapi/storage/sqlite3/relations_table.go @@ -64,11 +64,30 @@ const selectRelationsInRangeDescSQL = "" + const selectMaxRelationIDSQL = "" + "SELECT COALESCE(MAX(id), 0) FROM syncapi_relations" +const selectThreadsSQL = "" + + "SELECT syncapi_relations.id, syncapi_relations.event_id FROM syncapi_relations" + + " JOIN syncapi_output_room_events ON syncapi_output_room_events.event_id = syncapi_relations.event_id" + + " WHERE syncapi_relations.room_id = $1" + + " AND syncapi_relations.rel_type = 'm.thread'" + + " AND syncapi_relations.id >= $2 AND" + + " ORDER BY syncapi_relations.id LIMIT $3" + +const selectThreadsWithSenderSQL = "" + + "SELECT syncapi_relations.id, syncapi_relations.event_id FROM syncapi_relations" + + " JOIN syncapi_output_room_events ON syncapi_output_room_events.event_id = syncapi_relations.event_id" + + " WHERE syncapi_relations.room_id = $1" + + " AND syncapi_output_room_events.sender = $2" + + " AND syncapi_relations.rel_type = 'm.thread'" + + " AND syncapi_relations.id >= $3" + + " ORDER BY syncapi_relations.id LIMIT $4" + type relationsStatements struct { streamIDStatements *StreamIDStatements insertRelationStmt *sql.Stmt selectRelationsInRangeAscStmt *sql.Stmt selectRelationsInRangeDescStmt *sql.Stmt + selectThreadsStmt *sql.Stmt + selectThreadsWithSenderStmt *sql.Stmt deleteRelationStmt *sql.Stmt selectMaxRelationIDStmt *sql.Stmt } @@ -85,6 +104,8 @@ func NewSqliteRelationsTable(db *sql.DB, streamID *StreamIDStatements) (tables.R {&s.insertRelationStmt, insertRelationSQL}, {&s.selectRelationsInRangeAscStmt, selectRelationsInRangeAscSQL}, {&s.selectRelationsInRangeDescStmt, selectRelationsInRangeDescSQL}, + {&s.selectThreadsStmt, selectThreadsSQL}, + {&s.selectThreadsWithSenderStmt, selectThreadsWithSenderSQL}, {&s.deleteRelationStmt, deleteRelationSQL}, {&s.selectMaxRelationIDStmt, selectMaxRelationIDSQL}, }.Prepare(db) @@ -154,6 +175,49 @@ func (s *relationsStatements) SelectRelationsInRange( return result, lastPos, rows.Err() } +func (s *relationsStatements) SelectThreads( + ctx context.Context, + txn *sql.Tx, + roomID, userID string, + from types.StreamPosition, + limit uint64, +) ([]string, types.StreamPosition, error) { + var lastPos types.StreamPosition + var stmt *sql.Stmt + var rows *sql.Rows + var err error + + if userID == "" { + stmt = sqlutil.TxStmt(txn, s.selectThreadsStmt) + rows, err = stmt.QueryContext(ctx, roomID, from, limit) + } else { + stmt = sqlutil.TxStmt(txn, s.selectThreadsWithSenderStmt) + rows, err = stmt.QueryContext(ctx, roomID, userID, from, limit) + } + if err != nil { + return nil, lastPos, err + } + + defer internal.CloseAndLogIfError(ctx, rows, "selectThreads: rows.close() failed") + var result []string + var ( + id types.StreamPosition + eventId string + ) + + for rows.Next() { + if err = rows.Scan(&id, &eventId); err != nil { + return nil, lastPos, err + } + if id > lastPos { + lastPos = id + } + result = append(result, eventId) + } + + return result, lastPos, rows.Err() +} + func (s *relationsStatements) SelectMaxRelationID( ctx context.Context, txn *sql.Tx, ) (id int64, err error) { From a6f5918715ba50de5bcaac8c319258a63953dc00 Mon Sep 17 00:00:00 2001 From: qwqtoday Date: Fri, 9 Aug 2024 15:35:47 +0800 Subject: [PATCH 04/10] oo there is a extra and --- syncapi/storage/postgres/relations_table.go | 2 +- syncapi/storage/sqlite3/relations_table.go | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/syncapi/storage/postgres/relations_table.go b/syncapi/storage/postgres/relations_table.go index 2b710a3a90..97f5b828fc 100644 --- a/syncapi/storage/postgres/relations_table.go +++ b/syncapi/storage/postgres/relations_table.go @@ -68,7 +68,7 @@ const selectThreadsSQL = "" + " JOIN syncapi_output_room_events ON syncapi_output_room_events.event_id = syncapi_relations.event_id" + " WHERE syncapi_relations.room_id = $1" + " AND syncapi_relations.rel_type = 'm.thread'" + - " AND syncapi_relations.id >= $2 AND" + + " AND syncapi_relations.id >= $2" + " ORDER BY syncapi_relations.id LIMIT $3" const selectThreadsWithSenderSQL = "" + diff --git a/syncapi/storage/sqlite3/relations_table.go b/syncapi/storage/sqlite3/relations_table.go index 512178c5bf..917ac2bc5e 100644 --- a/syncapi/storage/sqlite3/relations_table.go +++ b/syncapi/storage/sqlite3/relations_table.go @@ -69,7 +69,7 @@ const selectThreadsSQL = "" + " JOIN syncapi_output_room_events ON syncapi_output_room_events.event_id = syncapi_relations.event_id" + " WHERE syncapi_relations.room_id = $1" + " AND syncapi_relations.rel_type = 'm.thread'" + - " AND syncapi_relations.id >= $2 AND" + + " AND syncapi_relations.id >= $2" + " ORDER BY syncapi_relations.id LIMIT $3" const selectThreadsWithSenderSQL = "" + From 0f80b26b4e81184d6e917c370ab346d51c479907 Mon Sep 17 00:00:00 2001 From: qwqtoday Date: Fri, 9 Aug 2024 15:46:48 +0800 Subject: [PATCH 05/10] add the route --- syncapi/routing/routing.go | 11 +++++++++++ syncapi/routing/threads.go | 2 +- 2 files changed, 12 insertions(+), 1 deletion(-) diff --git a/syncapi/routing/routing.go b/syncapi/routing/routing.go index 78188d1b6c..bf1d9e9780 100644 --- a/syncapi/routing/routing.go +++ b/syncapi/routing/routing.go @@ -157,6 +157,16 @@ func Setup( }, httputil.WithAllowGuests()), ).Methods(http.MethodGet, http.MethodOptions) + v1unstablemux.Handle("/rooms/{roomId}/threads", httputil.MakeAuthAPI("threads", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse { + vars, err := httputil.URLDecodeMapValues(mux.Vars(req)) + if err != nil { + return util.ErrorResponse(err) + } + + return Threads( + req, device, syncDB, rsAPI, vars["roomId"], + ) + })) v3mux.Handle("/search", httputil.MakeAuthAPI("search", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse { if !cfg.Fulltext.Enabled { @@ -200,4 +210,5 @@ func Setup( return GetMemberships(req, device, vars["roomID"], syncDB, rsAPI, membership, notMembership, at) }, httputil.WithAllowGuests()), ).Methods(http.MethodGet, http.MethodOptions) + } diff --git a/syncapi/routing/threads.go b/syncapi/routing/threads.go index 175ece1977..76e50a55be 100644 --- a/syncapi/routing/threads.go +++ b/syncapi/routing/threads.go @@ -23,7 +23,7 @@ type ThreadsResponse struct { func Threads( req *http.Request, - device userapi.Device, + device *userapi.Device, syncDB storage.Database, rsAPI api.SyncRoomserverAPI, rawRoomID string) util.JSONResponse { From 06fa53cbbb5e466662df4f4aab5c70960e9f7c0a Mon Sep 17 00:00:00 2001 From: qwqtoday Date: Fri, 9 Aug 2024 19:11:25 +0800 Subject: [PATCH 06/10] set methods to get only --- syncapi/routing/routing.go | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/syncapi/routing/routing.go b/syncapi/routing/routing.go index bf1d9e9780..e98e05239a 100644 --- a/syncapi/routing/routing.go +++ b/syncapi/routing/routing.go @@ -166,7 +166,8 @@ func Setup( return Threads( req, device, syncDB, rsAPI, vars["roomId"], ) - })) + })).Methods(http.MethodGet) + v3mux.Handle("/search", httputil.MakeAuthAPI("search", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse { if !cfg.Fulltext.Enabled { From 133c77d93b04d5f3a806c89334bcae6e6a992436 Mon Sep 17 00:00:00 2001 From: qwqtoday Date: Fri, 9 Aug 2024 19:15:20 +0800 Subject: [PATCH 07/10] make it respond with status ok --- syncapi/routing/routing.go | 19 ++++++++++--------- syncapi/routing/threads.go | 1 + 2 files changed, 11 insertions(+), 9 deletions(-) diff --git a/syncapi/routing/routing.go b/syncapi/routing/routing.go index e98e05239a..c63983b41a 100644 --- a/syncapi/routing/routing.go +++ b/syncapi/routing/routing.go @@ -157,16 +157,17 @@ func Setup( }, httputil.WithAllowGuests()), ).Methods(http.MethodGet, http.MethodOptions) - v1unstablemux.Handle("/rooms/{roomId}/threads", httputil.MakeAuthAPI("threads", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse { - vars, err := httputil.URLDecodeMapValues(mux.Vars(req)) - if err != nil { - return util.ErrorResponse(err) - } + v1unstablemux.Handle("/rooms/{roomId}/threads", + httputil.MakeAuthAPI("threads", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse { + vars, err := httputil.URLDecodeMapValues(mux.Vars(req)) + if err != nil { + return util.ErrorResponse(err) + } - return Threads( - req, device, syncDB, rsAPI, vars["roomId"], - ) - })).Methods(http.MethodGet) + return Threads( + req, device, syncDB, rsAPI, vars["roomId"], + ) + })).Methods(http.MethodGet) v3mux.Handle("/search", httputil.MakeAuthAPI("search", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse { diff --git a/syncapi/routing/threads.go b/syncapi/routing/threads.go index 76e50a55be..2865874a89 100644 --- a/syncapi/routing/threads.go +++ b/syncapi/routing/threads.go @@ -98,6 +98,7 @@ func Threads( } return util.JSONResponse{ + Code: http.StatusOK, JSON: res, } } From 1f44a77037e762b19fa8098a96376c61081d018e Mon Sep 17 00:00:00 2001 From: qwqtoday Date: Fri, 9 Aug 2024 19:25:26 +0800 Subject: [PATCH 08/10] update --- syncapi/routing/threads.go | 3 +++ syncapi/storage/shared/storage_sync.go | 4 ++-- 2 files changed, 5 insertions(+), 2 deletions(-) diff --git a/syncapi/routing/threads.go b/syncapi/routing/threads.go index 2865874a89..815c6326ca 100644 --- a/syncapi/routing/threads.go +++ b/syncapi/routing/threads.go @@ -86,6 +86,9 @@ func Threads( headeredEvents, _, res.NextBatch, err = snapshot.ThreadsFor( req.Context(), roomID.String(), userID, from, limit, ) + if err != nil { + return util.ErrorResponse(err) + } for _, event := range headeredEvents { ce, err := synctypes.ToClientEvent(event, synctypes.FormatAll, func(roomID spec.RoomID, senderID spec.SenderID) (*spec.UserID, error) { diff --git a/syncapi/storage/shared/storage_sync.go b/syncapi/storage/shared/storage_sync.go index a7d783d6f2..1c52df0eeb 100644 --- a/syncapi/storage/shared/storage_sync.go +++ b/syncapi/storage/shared/storage_sync.go @@ -834,7 +834,7 @@ func (d *DatabaseTransaction) ThreadsFor(ctx context.Context, roomID, userID str // First look up any threads from the database. We add one to the limit here // so that we can tell if we're overflowing, as we will only set the "next_batch" // in the response if we are. - eventIDs, _, err := d.Relations.SelectThreads(ctx, d.txn, roomID, userID, from, limit+1) + eventIDs, pos, err := d.Relations.SelectThreads(ctx, d.txn, roomID, userID, from, limit+1) if err != nil { return nil, "", "", fmt.Errorf("d.Relations.SelectRelationsInRange: %w", err) @@ -845,5 +845,5 @@ func (d *DatabaseTransaction) ThreadsFor(ctx context.Context, roomID, userID str return nil, "", "", fmt.Errorf("d.OutputEvents.SelectEvents: %w", err) } - return events, prevBatch, nextBatch, nil + return events, prevBatch, fmt.Sprintf("%d", pos), nil } From bf6a960ccbf397097e473b8191f476a04e225beb Mon Sep 17 00:00:00 2001 From: qwqtoday Date: Thu, 15 Aug 2024 14:02:26 +0800 Subject: [PATCH 09/10] add test for threads --- syncapi/storage/tables/relations_test.go | 65 ++++++++++++++++++++++++ 1 file changed, 65 insertions(+) diff --git a/syncapi/storage/tables/relations_test.go b/syncapi/storage/tables/relations_test.go index 46270e36dc..05fda6399d 100644 --- a/syncapi/storage/tables/relations_test.go +++ b/syncapi/storage/tables/relations_test.go @@ -24,7 +24,22 @@ func newRelationsTable(t *testing.T, dbType test.DBType) (tables.Relations, *sql t.Fatalf("failed to open db: %s", err) } + switch dbType { + case test.DBTypePostgres: + _, err = postgres.NewPostgresEventsTable(db) + case test.DBTypeSQLite: + var stream sqlite3.StreamIDStatements + if err = stream.Prepare(db); err != nil { + t.Fatalf("failed to prepare stream stmts: %s", err) + } + _, err = sqlite3.NewSqliteEventsTable(db, &stream) + } + if err != nil { + t.Fatalf("failed to make new table: %s", err) + } + var tab tables.Relations + switch dbType { case test.DBTypePostgres: tab, err = postgres.NewPostgresRelationsTable(db) @@ -184,3 +199,53 @@ func TestRelationsTable(t *testing.T) { } }) } + +const threadRelType = "m.thread" + +func TestThreads(t *testing.T) { + var err error + ctx := context.Background() + alice := test.NewUser(t) + room := test.NewRoom(t, alice) + + firstEvent := room.CreateAndInsert(t, alice, "m.room.message", map[string]interface{}{ + "body": "first message", + }) + threadReplyEvent := room.CreateAndInsert(t, alice, "m.room.message", map[string]interface{}{ + "body": "thread reply", + "m.relates_to": map[string]interface{}{ + "event_id": firstEvent.EventID(), + "rel_type": threadRelType, + }, + }) + + test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) { + tab, _, close := newRelationsTable(t, dbType) + defer close() + + err = tab.InsertRelation(ctx, nil, room.ID, firstEvent.EventID(), threadReplyEvent.EventID(), "m.room.message", threadReplyEvent.EventID()) + if err != nil { + t.Fatal(err) + } + var eventIds []string + eventIds, _, err = tab.SelectThreads(ctx, nil, room.ID, "", 0, 100) + + for i, expected := range []string{ + firstEvent.EventID(), + } { + eventID := eventIds[i] + if eventID != expected { + t.Fatalf("eventID mismatch: got %s, want %s", eventID, expected) + } + } + eventIds, _, err = tab.SelectThreads(ctx, nil, room.ID, alice.ID, 0, 100) + for i, expected := range []string{ + firstEvent.EventID(), + } { + eventID := eventIds[i] + if eventID != expected { + t.Fatalf("eventID mismatch: got %s, want %s", eventID, expected) + } + } + }) +} From 1b4fc3728f99a996ece7f82c5eafadb97cb1c870 Mon Sep 17 00:00:00 2001 From: Lukas <77849373+qwqtoday@users.noreply.github.com> Date: Tue, 27 Aug 2024 23:29:04 +0800 Subject: [PATCH 10/10] make the explainitation better Co-authored-by: Till <2353100+S7evinK@users.noreply.github.com> --- syncapi/storage/tables/interface.go | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/syncapi/storage/tables/interface.go b/syncapi/storage/tables/interface.go index 5470349a0b..14105a15ad 100644 --- a/syncapi/storage/tables/interface.go +++ b/syncapi/storage/tables/interface.go @@ -235,8 +235,8 @@ type Relations interface { // will be returned, inclusive of the "to" position but excluding the "from" position. The stream // position returned is the maximum position of the returned results. SelectRelationsInRange(ctx context.Context, txn *sql.Tx, roomID, eventID, relType, eventType string, r types.Range, limit int) (map[string][]types.RelationEntry, types.StreamPosition, error) - // SelectThreads this will find some threads from a room - // if userID is not empty then it will only include the threads that the user has participated + // SelectThreads will find threads from a room, if userID is not empty + // then it will only include the threads that the user has participated in. SelectThreads(ctx context.Context, txn *sql.Tx, roomID, userID string, from types.StreamPosition, limit uint64) ([]string, types.StreamPosition, error) // SelectMaxRelationID returns the maximum ID of all relations, used to determine what the boundaries // should be if there are no boundaries supplied (i.e. we want to work backwards but don't have a