[bugfix] avoid v. long notification clear query (#3007)

This commit is contained in:
tobi 2024-06-14 12:14:55 +02:00 committed by GitHub
parent b789fe2bc7
commit db803617db
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
2 changed files with 57 additions and 72 deletions

View file

@ -22,7 +22,6 @@ import (
"errors" "errors"
"slices" "slices"
"github.com/superseriousbusiness/gotosocial/internal/db"
"github.com/superseriousbusiness/gotosocial/internal/gtscontext" "github.com/superseriousbusiness/gotosocial/internal/gtscontext"
"github.com/superseriousbusiness/gotosocial/internal/gtserror" "github.com/superseriousbusiness/gotosocial/internal/gtserror"
"github.com/superseriousbusiness/gotosocial/internal/gtsmodel" "github.com/superseriousbusiness/gotosocial/internal/gtsmodel"
@ -108,6 +107,11 @@ func (n *notificationDB) GetNotificationsByIDs(ctx context.Context, ids []string
notifs, err := n.state.Caches.GTS.Notification.LoadIDs("ID", notifs, err := n.state.Caches.GTS.Notification.LoadIDs("ID",
ids, ids,
func(uncached []string) ([]*gtsmodel.Notification, error) { func(uncached []string) ([]*gtsmodel.Notification, error) {
// Skip query if everything was cached.
if len(uncached) == 0 {
return nil, nil
}
// Preallocate expected length of uncached notifications. // Preallocate expected length of uncached notifications.
notifs := make([]*gtsmodel.Notification, 0, len(uncached)) notifs := make([]*gtsmodel.Notification, 0, len(uncached))
@ -282,26 +286,18 @@ func (n *notificationDB) PutNotification(ctx context.Context, notif *gtsmodel.No
} }
func (n *notificationDB) DeleteNotificationByID(ctx context.Context, id string) error { func (n *notificationDB) DeleteNotificationByID(ctx context.Context, id string) error {
defer n.state.Caches.GTS.Notification.Invalidate("ID", id) // Delete notif from DB.
if _, err := n.db.
// Load notif into cache before attempting a delete, NewDelete().
// as we need it cached in order to trigger the invalidate Table("notifications").
// callback. This in turn invalidates others. Where("? = ?", bun.Ident("id"), id).
_, err := n.GetNotificationByID(gtscontext.SetBarebones(ctx), id) Exec(ctx); err != nil {
if err != nil {
if errors.Is(err, db.ErrNoEntries) {
// not an issue.
err = nil
}
return err return err
} }
// Finally delete notif from DB. // Invalidate deleted notification by ID.
_, err = n.db.NewDelete(). n.state.Caches.GTS.Notification.Invalidate("ID", id)
TableExpr("? AS ?", bun.Ident("notifications"), bun.Ident("notification")). return nil
Where("? = ?", bun.Ident("notification.id"), id).
Exec(ctx)
return err
} }
func (n *notificationDB) DeleteNotifications(ctx context.Context, types []string, targetAccountID string, originAccountID string) error { func (n *notificationDB) DeleteNotifications(ctx context.Context, types []string, targetAccountID string, originAccountID string) error {
@ -309,11 +305,8 @@ func (n *notificationDB) DeleteNotifications(ctx context.Context, types []string
return errors.New("DeleteNotifications: one of targetAccountID or originAccountID must be set") return errors.New("DeleteNotifications: one of targetAccountID or originAccountID must be set")
} }
var notifIDs []string
q := n.db. q := n.db.
NewSelect(). NewDelete().
Column("id").
Table("notifications") Table("notifications")
if len(types) > 0 { if len(types) > 0 {
@ -328,61 +321,33 @@ func (n *notificationDB) DeleteNotifications(ctx context.Context, types []string
q = q.Where("? = ?", bun.Ident("origin_account_id"), originAccountID) q = q.Where("? = ?", bun.Ident("origin_account_id"), originAccountID)
} }
if _, err := q.Exec(ctx, &notifIDs); err != nil { var notifIDs []string
q = q.Returning("?", bun.Ident("id"))
// Delete from DB.
if _, err := q.
Exec(ctx, &notifIDs); err != nil {
return err return err
} }
// Invalidate all cached notifications by IDs on return. // Invalidate all deleted notifications by IDs.
defer n.state.Caches.GTS.Notification.InvalidateIDs("ID", notifIDs) n.state.Caches.GTS.Notification.InvalidateIDs("ID", notifIDs)
return nil
// Load all notif into cache, this *really* isn't great
// but it is the only way we can ensure we invalidate all
// related caches correctly (e.g. visibility).
for _, id := range notifIDs {
_, err := n.GetNotificationByID(ctx, id)
if err != nil && !errors.Is(err, db.ErrNoEntries) {
return err
}
}
// Finally delete all from DB.
_, err := n.db.NewDelete().
Table("notifications").
Where("? IN (?)", bun.Ident("id"), bun.In(notifIDs)).
Exec(ctx)
return err
} }
func (n *notificationDB) DeleteNotificationsForStatus(ctx context.Context, statusID string) error { func (n *notificationDB) DeleteNotificationsForStatus(ctx context.Context, statusID string) error {
var notifIDs []string var notifIDs []string
q := n.db. if _, err := n.db.
NewSelect(). NewDelete().
Column("id").
Table("notifications"). Table("notifications").
Where("? = ?", bun.Ident("status_id"), statusID) Where("? = ?", bun.Ident("status_id"), statusID).
Returning("?", bun.Ident("id")).
if _, err := q.Exec(ctx, &notifIDs); err != nil { Exec(ctx, &notifIDs); err != nil {
return err return err
} }
// Invalidate all cached notifications by IDs on return. // Invalidate all deleted notifications by IDs.
defer n.state.Caches.GTS.Notification.InvalidateIDs("ID", notifIDs) n.state.Caches.GTS.Notification.InvalidateIDs("ID", notifIDs)
return nil
// Load all notif into cache, this *really* isn't great
// but it is the only way we can ensure we invalidate all
// related caches correctly (e.g. visibility).
for _, id := range notifIDs {
_, err := n.GetNotificationByID(ctx, id)
if err != nil && !errors.Is(err, db.ErrNoEntries) {
return err
}
}
// Finally delete all from DB.
_, err := n.db.NewDelete().
Table("notifications").
Where("? IN (?)", bun.Ident("id"), bun.In(notifIDs)).
Exec(ctx)
return err
} }

View file

@ -73,7 +73,7 @@ func (suite *NotificationTestSuite) spamNotifs() {
Read: util.Ptr(false), Read: util.Ptr(false),
} }
if err := suite.db.Put(context.Background(), notif); err != nil { if err := suite.db.PutNotification(context.Background(), notif); err != nil {
panic(err) panic(err)
} }
} }
@ -133,9 +133,8 @@ func (suite *NotificationTestSuite) TestGetAccountNotificationsWithoutSpam() {
func (suite *NotificationTestSuite) TestDeleteNotificationsWithSpam() { func (suite *NotificationTestSuite) TestDeleteNotificationsWithSpam() {
suite.spamNotifs() suite.spamNotifs()
testAccount := suite.testAccounts["local_account_1"] testAccount := suite.testAccounts["local_account_1"]
err := suite.db.DeleteNotifications(context.Background(), nil, testAccount.ID, "")
suite.NoError(err)
// Test getting notifs first.
notifications, err := suite.db.GetAccountNotifications( notifications, err := suite.db.GetAccountNotifications(
gtscontext.SetBarebones(context.Background()), gtscontext.SetBarebones(context.Background()),
testAccount.ID, testAccount.ID,
@ -145,8 +144,29 @@ func (suite *NotificationTestSuite) TestDeleteNotificationsWithSpam() {
20, 20,
nil, nil,
) )
suite.NoError(err) if err != nil {
suite.Nil(notifications) suite.FailNow(err.Error())
}
suite.Len(notifications, 20)
// Now delete.
if err := suite.db.DeleteNotifications(context.Background(), nil, testAccount.ID, ""); err != nil {
suite.FailNow(err.Error())
}
// Now try getting again.
notifications, err = suite.db.GetAccountNotifications(
gtscontext.SetBarebones(context.Background()),
testAccount.ID,
id.Highest,
id.Lowest,
"",
20,
nil,
)
if err != nil {
suite.FailNow(err.Error())
}
suite.Empty(notifications) suite.Empty(notifications)
} }