From dcedd1b6bf1e890ff425bdf1fcd8a2e0850778b5 Mon Sep 17 00:00:00 2001
From: devonh <devon.dmytro@gmail.com>
Date: Thu, 13 Oct 2022 14:38:13 +0000
Subject: [PATCH] Federation backoff fixes and tests (#2792)

This fixes some edge cases where federation queue backoffs and
blacklisting weren't behaving as expected.
It also adds new tests for the federation queues to ensure their
behaviour continues to work correctly.
---
 federationapi/queue/destinationqueue.go |  10 +
 federationapi/queue/queue.go            |   1 +
 federationapi/queue/queue_test.go       | 422 ++++++++++++++++++++++++
 federationapi/statistics/statistics.go  |   8 +-
 go.mod                                  |   2 +-
 5 files changed, 441 insertions(+), 2 deletions(-)
 create mode 100644 federationapi/queue/queue_test.go

diff --git a/federationapi/queue/destinationqueue.go b/federationapi/queue/destinationqueue.go
index 5cb8cae1f..4ae554ef3 100644
--- a/federationapi/queue/destinationqueue.go
+++ b/federationapi/queue/destinationqueue.go
@@ -75,6 +75,7 @@ func (oq *destinationQueue) sendEvent(event *gomatrixserverlib.HeaderedEvent, re
 		logrus.Errorf("attempt to send nil PDU with destination %q", oq.destination)
 		return
 	}
+
 	// Create a database entry that associates the given PDU NID with
 	// this destination queue. We'll then be able to retrieve the PDU
 	// later.
@@ -108,6 +109,8 @@ func (oq *destinationQueue) sendEvent(event *gomatrixserverlib.HeaderedEvent, re
 		case oq.notify <- struct{}{}:
 		default:
 		}
+	} else {
+		oq.overflowed.Store(true)
 	}
 }
 
@@ -153,6 +156,8 @@ func (oq *destinationQueue) sendEDU(event *gomatrixserverlib.EDU, receipt *share
 		case oq.notify <- struct{}{}:
 		default:
 		}
+	} else {
+		oq.overflowed.Store(true)
 	}
 }
 
@@ -335,6 +340,11 @@ func (oq *destinationQueue) backgroundSend() {
 			// We failed to send the transaction. Mark it as a failure.
 			oq.statistics.Failure()
 
+			// Queue up another attempt since the transaction failed.
+			select {
+			case oq.notify <- struct{}{}:
+			default:
+			}
 		} else if transaction {
 			// If we successfully sent the transaction then clear out
 			// the pending events and EDUs, and wipe our transaction ID.
diff --git a/federationapi/queue/queue.go b/federationapi/queue/queue.go
index 8245aa5bd..5d352eca6 100644
--- a/federationapi/queue/queue.go
+++ b/federationapi/queue/queue.go
@@ -332,6 +332,7 @@ func (oqs *OutgoingQueues) RetryServer(srv gomatrixserverlib.ServerName) {
 	if oqs.disabled {
 		return
 	}
+	oqs.statistics.ForServer(srv).RemoveBlacklist()
 	if queue := oqs.getQueue(srv); queue != nil {
 		queue.wakeQueueIfNeeded()
 	}
diff --git a/federationapi/queue/queue_test.go b/federationapi/queue/queue_test.go
new file mode 100644
index 000000000..8e4a675f4
--- /dev/null
+++ b/federationapi/queue/queue_test.go
@@ -0,0 +1,422 @@
+// Copyright 2022 The Matrix.org Foundation C.I.C.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package queue
+
+import (
+	"context"
+	"encoding/json"
+	"fmt"
+	"sync"
+	"testing"
+	"time"
+
+	"go.uber.org/atomic"
+	"gotest.tools/v3/poll"
+
+	"github.com/matrix-org/dendrite/federationapi/api"
+	"github.com/matrix-org/dendrite/federationapi/statistics"
+	"github.com/matrix-org/dendrite/federationapi/storage"
+	"github.com/matrix-org/dendrite/federationapi/storage/shared"
+	rsapi "github.com/matrix-org/dendrite/roomserver/api"
+	"github.com/matrix-org/dendrite/setup/process"
+	"github.com/matrix-org/dendrite/test"
+	"github.com/matrix-org/gomatrixserverlib"
+	"github.com/pkg/errors"
+	"github.com/stretchr/testify/assert"
+)
+
+var dbMutex sync.Mutex // package-level, so it serialises the map access of every fakeDatabase instance
+
+type fakeDatabase struct {
+	storage.Database
+	pendingPDUServers  map[gomatrixserverlib.ServerName]struct{}
+	pendingEDUServers  map[gomatrixserverlib.ServerName]struct{}
+	blacklistedServers map[gomatrixserverlib.ServerName]struct{}
+	pendingPDUs        map[*shared.Receipt]*gomatrixserverlib.HeaderedEvent
+	pendingEDUs        map[*shared.Receipt]*gomatrixserverlib.EDU
+	associatedPDUs     map[gomatrixserverlib.ServerName]map[*shared.Receipt]struct{}
+	associatedEDUs     map[gomatrixserverlib.ServerName]map[*shared.Receipt]struct{}
+}
+
+// StoreJSON decodes js as a PDU first and as an EDU second, files the result
+// under a freshly allocated receipt and returns that receipt.
+func (d *fakeDatabase) StoreJSON(ctx context.Context, js string) (*shared.Receipt, error) {
+	dbMutex.Lock()
+	defer dbMutex.Unlock()
+
+	pdu := &gomatrixserverlib.HeaderedEvent{}
+	if json.Unmarshal([]byte(js), pdu) == nil {
+		receipt := &shared.Receipt{}
+		d.pendingPDUs[receipt] = pdu
+		return receipt, nil
+	}
+	edu := &gomatrixserverlib.EDU{}
+	if json.Unmarshal([]byte(js), edu) == nil {
+		receipt := &shared.Receipt{}
+		d.pendingEDUs[receipt] = edu
+		return receipt, nil
+	}
+	return nil, errors.New("Failed to determine type of json to store")
+}
+
+// GetPendingPDUs returns every stored PDU whose receipt is associated with
+// serverName. The fake ignores limit and always returns a nil error.
+func (d *fakeDatabase) GetPendingPDUs(ctx context.Context, serverName gomatrixserverlib.ServerName, limit int) (pdus map[*shared.Receipt]*gomatrixserverlib.HeaderedEvent, err error) {
+	dbMutex.Lock()
+	defer dbMutex.Unlock()
+
+	pdus = make(map[*shared.Receipt]*gomatrixserverlib.HeaderedEvent)
+	for receipt := range d.associatedPDUs[serverName] {
+		if event, found := d.pendingPDUs[receipt]; found {
+			pdus[receipt] = event
+		}
+	}
+	return pdus, nil
+}
+
+// GetPendingEDUs returns every stored EDU whose receipt is associated with
+// serverName. The fake ignores limit and always returns a nil error.
+func (d *fakeDatabase) GetPendingEDUs(ctx context.Context, serverName gomatrixserverlib.ServerName, limit int) (edus map[*shared.Receipt]*gomatrixserverlib.EDU, err error) {
+	dbMutex.Lock()
+	defer dbMutex.Unlock()
+
+	edus = make(map[*shared.Receipt]*gomatrixserverlib.EDU)
+	for receipt := range d.associatedEDUs[serverName] {
+		if edu, found := d.pendingEDUs[receipt]; found {
+			edus[receipt] = edu
+		}
+	}
+	return edus, nil
+}
+
+// AssociatePDUWithDestination links an already stored PDU receipt to serverName.
+func (d *fakeDatabase) AssociatePDUWithDestination(ctx context.Context, transactionID gomatrixserverlib.TransactionID, serverName gomatrixserverlib.ServerName, receipt *shared.Receipt) error {
+	dbMutex.Lock()
+	defer dbMutex.Unlock()
+
+	if _, stored := d.pendingPDUs[receipt]; !stored {
+		return errors.New("PDU doesn't exist")
+	}
+	if _, known := d.associatedPDUs[serverName]; !known {
+		d.associatedPDUs[serverName] = make(map[*shared.Receipt]struct{})
+	}
+	d.associatedPDUs[serverName][receipt] = struct{}{}
+	return nil
+}
+
+// AssociateEDUWithDestination links an already stored EDU receipt to serverName.
+func (d *fakeDatabase) AssociateEDUWithDestination(ctx context.Context, serverName gomatrixserverlib.ServerName, receipt *shared.Receipt, eduType string, expireEDUTypes map[string]time.Duration) error {
+	dbMutex.Lock()
+	defer dbMutex.Unlock()
+
+	if _, stored := d.pendingEDUs[receipt]; !stored {
+		return errors.New("EDU doesn't exist")
+	}
+	if _, known := d.associatedEDUs[serverName]; !known {
+		d.associatedEDUs[serverName] = make(map[*shared.Receipt]struct{})
+	}
+	d.associatedEDUs[serverName][receipt] = struct{}{}
+	return nil
+}
+
+// CleanPDUs drops the given receipts from serverName's PDU associations.
+// Unknown servers and receipts are ignored; it always returns nil.
+func (d *fakeDatabase) CleanPDUs(ctx context.Context, serverName gomatrixserverlib.ServerName, receipts []*shared.Receipt) error {
+	dbMutex.Lock()
+	defer dbMutex.Unlock()
+
+	associated := d.associatedPDUs[serverName]
+	for _, receipt := range receipts {
+		delete(associated, receipt)
+	}
+	return nil
+}
+
+// CleanEDUs drops the given receipts from serverName's EDU associations.
+// Unknown servers and receipts are ignored; it always returns nil.
+func (d *fakeDatabase) CleanEDUs(ctx context.Context, serverName gomatrixserverlib.ServerName, receipts []*shared.Receipt) error {
+	dbMutex.Lock()
+	defer dbMutex.Unlock()
+
+	associated := d.associatedEDUs[serverName]
+	for _, receipt := range receipts {
+		delete(associated, receipt)
+	}
+	return nil
+}
+
+// GetPendingPDUCount reports how many PDU receipts are currently associated
+// with serverName. A server with no recorded associations yields zero, and
+// the returned error is always nil.
+func (d *fakeDatabase) GetPendingPDUCount(ctx context.Context, serverName gomatrixserverlib.ServerName) (int64, error) {
+	dbMutex.Lock()
+	defer dbMutex.Unlock()
+
+	// len of the nil map returned for an unknown server is 0.
+	return int64(len(d.associatedPDUs[serverName])), nil
+}
+
+// GetPendingEDUCount reports how many EDU receipts are currently associated
+// with serverName. A server with no recorded associations yields zero, and
+// the returned error is always nil.
+func (d *fakeDatabase) GetPendingEDUCount(ctx context.Context, serverName gomatrixserverlib.ServerName) (int64, error) {
+	dbMutex.Lock()
+	defer dbMutex.Unlock()
+
+	// len of the nil map returned for an unknown server is 0.
+	return int64(len(d.associatedEDUs[serverName])), nil
+}
+
+// GetPendingPDUServerNames returns the keys of pendingPDUServers in map iteration order.
+func (d *fakeDatabase) GetPendingPDUServerNames(ctx context.Context) ([]gomatrixserverlib.ServerName, error) {
+	dbMutex.Lock()
+	defer dbMutex.Unlock()
+	names := make([]gomatrixserverlib.ServerName, 0, len(d.pendingPDUServers))
+	for name := range d.pendingPDUServers {
+		names = append(names, name)
+	}
+	return names, nil
+}
+
+// GetPendingEDUServerNames returns the keys of pendingEDUServers in map iteration order.
+func (d *fakeDatabase) GetPendingEDUServerNames(ctx context.Context) ([]gomatrixserverlib.ServerName, error) {
+	dbMutex.Lock()
+	defer dbMutex.Unlock()
+	names := make([]gomatrixserverlib.ServerName, 0, len(d.pendingEDUServers))
+	for name := range d.pendingEDUServers {
+		names = append(names, name)
+	}
+	return names, nil
+}
+
+func (d *fakeDatabase) AddServerToBlacklist(serverName gomatrixserverlib.ServerName) error {
+	dbMutex.Lock()
+	defer dbMutex.Unlock()
+
+	d.blacklistedServers[serverName] = struct{}{} // the map acts as a set: key presence marks the server as blacklisted
+	return nil
+}
+
+func (d *fakeDatabase) RemoveServerFromBlacklist(serverName gomatrixserverlib.ServerName) error {
+	dbMutex.Lock()
+	defer dbMutex.Unlock()
+
+	delete(d.blacklistedServers, serverName) // a no-op when serverName is not in the set
+	return nil
+}
+
+func (d *fakeDatabase) RemoveAllServersFromBlacklist() error {
+	dbMutex.Lock()
+	defer dbMutex.Unlock()
+
+	d.blacklistedServers = make(map[gomatrixserverlib.ServerName]struct{}) // swap in a fresh empty set instead of deleting key by key
+	return nil
+}
+
+// IsServerBlacklisted reports whether serverName is present in the in-memory
+// blacklist set. The returned error is always nil.
+func (d *fakeDatabase) IsServerBlacklisted(serverName gomatrixserverlib.ServerName) (bool, error) {
+	dbMutex.Lock()
+	defer dbMutex.Unlock()
+
+	// Membership alone decides the answer; the stored value is an empty struct.
+	_, blacklisted := d.blacklistedServers[serverName]
+
+	return blacklisted, nil
+}
+
+type stubFederationRoomServerAPI struct {
+	rsapi.FederationRoomserverAPI // embedded; QueryServerBannedFromRoom is the only method this file overrides
+}
+
+func (r *stubFederationRoomServerAPI) QueryServerBannedFromRoom(ctx context.Context, req *rsapi.QueryServerBannedFromRoomRequest, res *rsapi.QueryServerBannedFromRoomResponse) error {
+	res.Banned = false // the stub never reports a server as banned, whatever req contains
+	return nil
+}
+
+type stubFederationClient struct {
+	api.FederationClient
+	shouldTxSucceed bool          // when false, SendTransaction returns an error
+	txCount         atomic.Uint32 // incremented on every SendTransaction call
+}
+
+// SendTransaction counts the attempt and fails it unless shouldTxSucceed is set.
+func (f *stubFederationClient) SendTransaction(ctx context.Context, t gomatrixserverlib.Transaction) (res gomatrixserverlib.RespSend, err error) {
+	if !f.shouldTxSucceed {
+		err = fmt.Errorf("transaction failed")
+	}
+	f.txCount.Add(1)
+	return res, err
+}
+
+func createDatabase() storage.Database {
+	return &fakeDatabase{
+		pendingPDUServers:  make(map[gomatrixserverlib.ServerName]struct{}),
+		pendingEDUServers:  make(map[gomatrixserverlib.ServerName]struct{}),
+		blacklistedServers: make(map[gomatrixserverlib.ServerName]struct{}),
+		pendingPDUs:        make(map[*shared.Receipt]*gomatrixserverlib.HeaderedEvent),
+		pendingEDUs:        make(map[*shared.Receipt]*gomatrixserverlib.EDU),
+		associatedPDUs:     make(map[gomatrixserverlib.ServerName]map[*shared.Receipt]struct{}),
+		associatedEDUs:     make(map[gomatrixserverlib.ServerName]map[*shared.Receipt]struct{}), // must be non-nil: AssociateEDUWithDestination writes to it
+	}
+}
+
+// mustCreateEvent builds a minimal m.room.message event, failing the test on error.
+func mustCreateEvent(t *testing.T) *gomatrixserverlib.HeaderedEvent {
+	t.Helper()
+	ev, err := gomatrixserverlib.NewEventFromTrustedJSON([]byte(`{"type":"m.room.message"}`), false, gomatrixserverlib.RoomVersionV10)
+	if err != nil {
+		t.Fatalf("failed to create event: %v", err)
+	}
+	return ev.Headered(gomatrixserverlib.RoomVersionV10)
+}
+
+// testSetup builds the fixtures shared by the queue tests: a fake database, a
+// stub federation client whose transactions succeed or fail according to
+// shouldTxSucceed, and an OutgoingQueues for "localhost" wired to both.
+func testSetup(failuresUntilBlacklist uint32, shouldTxSucceed bool) (storage.Database, *stubFederationClient, *OutgoingQueues) {
+	db := createDatabase()
+	// txCount is left at its zero value, which already reads as 0.
+	fc := &stubFederationClient{shouldTxSucceed: shouldTxSucceed}
+	stats := &statistics.Statistics{
+		DB:                     db,
+		FailuresUntilBlacklist: failuresUntilBlacklist,
+	}
+	signingInfo := &SigningInfo{
+		KeyID:      "ed25519:auto",
+		PrivateKey: test.PrivateKeyA,
+		ServerName: "localhost",
+	}
+
+	rs := &stubFederationRoomServerAPI{}
+	queues := NewOutgoingQueues(db, process.NewProcessContext(), false, "localhost", fc, rs, stats, signingInfo)
+	return db, fc, queues
+}
+
+// TestSendTransactionOnSuccessRemovedFromDB checks that after a successful
+// send the event is no longer pending for the destination.
+func TestSendTransactionOnSuccessRemovedFromDB(t *testing.T) {
+	ctx := context.Background()
+	destination := gomatrixserverlib.ServerName("remotehost")
+	db, fc, queues := testSetup(16, true)
+
+	err := queues.SendEvent(mustCreateEvent(t), "localhost", []gomatrixserverlib.ServerName{destination})
+	assert.NoError(t, err)
+
+	check := func(log poll.LogT) poll.Result {
+		if fc.txCount.Load() < 1 {
+			return poll.Continue("waiting for more send attempts before checking database")
+		}
+		pending, dbErr := db.GetPendingPDUs(ctx, destination, 100)
+		assert.NoError(t, dbErr)
+		if len(pending) != 0 {
+			return poll.Continue("waiting for event to be removed from database")
+		}
+		return poll.Success()
+	}
+	poll.WaitOn(t, check, poll.WithTimeout(5*time.Second), poll.WithDelay(100*time.Millisecond))
+}
+
+// TestSendTransactionOnFailStoredInDB checks that an event whose transaction
+// keeps failing stays pending for the destination.
+func TestSendTransactionOnFailStoredInDB(t *testing.T) {
+	ctx := context.Background()
+	destination := gomatrixserverlib.ServerName("remotehost")
+	db, fc, queues := testSetup(16, false)
+
+	err := queues.SendEvent(mustCreateEvent(t), "localhost", []gomatrixserverlib.ServerName{destination})
+	assert.NoError(t, err)
+
+	check := func(log poll.LogT) poll.Result {
+		// Wait for 2 backoff attempts to ensure there was adequate time to attempt sending
+		if fc.txCount.Load() < 2 {
+			return poll.Continue("waiting for more send attempts before checking database")
+		}
+		pending, dbErr := db.GetPendingPDUs(ctx, destination, 100)
+		assert.NoError(t, dbErr)
+		if len(pending) != 1 {
+			return poll.Continue("waiting for event to be added to database")
+		}
+		return poll.Success()
+	}
+	poll.WaitOn(t, check, poll.WithTimeout(5*time.Second), poll.WithDelay(100*time.Millisecond))
+}
+
+// TestSendTransactionMultipleFailuresBlacklisted checks that repeated failures blacklist the destination.
+func TestSendTransactionMultipleFailuresBlacklisted(t *testing.T) {
+	ctx := context.Background()
+	failuresUntilBlacklist := uint32(2)
+	destination := gomatrixserverlib.ServerName("remotehost")
+	db, fc, queues := testSetup(failuresUntilBlacklist, false)
+
+	err := queues.SendEvent(mustCreateEvent(t), "localhost", []gomatrixserverlib.ServerName{destination})
+	assert.NoError(t, err)
+
+	check := func(log poll.LogT) poll.Result {
+		if fc.txCount.Load() < failuresUntilBlacklist {
+			return poll.Continue("waiting for more send attempts before checking database")
+		}
+		pending, dbErr := db.GetPendingPDUs(ctx, destination, 100)
+		assert.NoError(t, dbErr)
+		if len(pending) != 1 {
+			return poll.Continue("waiting for event to be added to database")
+		}
+		if blacklisted, _ := db.IsServerBlacklisted(destination); !blacklisted {
+			return poll.Continue("waiting for server to be blacklisted")
+		}
+		return poll.Success()
+	}
+	poll.WaitOn(t, check, poll.WithTimeout(5*time.Second), poll.WithDelay(100*time.Millisecond))
+}
+
+// TestRetryServerSendsSuccessfully checks that RetryServer drains a previously blacklisted destination.
+func TestRetryServerSendsSuccessfully(t *testing.T) {
+	ctx := context.Background()
+	failuresUntilBlacklist := uint32(1)
+	destination := gomatrixserverlib.ServerName("remotehost")
+	db, fc, queues := testSetup(failuresUntilBlacklist, false)
+
+	err := queues.SendEvent(mustCreateEvent(t), "localhost", []gomatrixserverlib.ServerName{destination})
+	assert.NoError(t, err)
+
+	checkBlacklisted := func(log poll.LogT) poll.Result {
+		if fc.txCount.Load() < failuresUntilBlacklist {
+			return poll.Continue("waiting for more send attempts before checking database")
+		}
+		pending, dbErr := db.GetPendingPDUs(ctx, destination, 100)
+		assert.NoError(t, dbErr)
+		if len(pending) != 1 {
+			return poll.Continue("waiting for event to be added to database")
+		}
+		if blacklisted, _ := db.IsServerBlacklisted(destination); !blacklisted {
+			return poll.Continue("waiting for server to be blacklisted")
+		}
+		return poll.Success()
+	}
+	poll.WaitOn(t, checkBlacklisted, poll.WithTimeout(5*time.Second), poll.WithDelay(100*time.Millisecond))
+
+	fc.shouldTxSucceed = true
+	assert.NoError(t, db.RemoveServerFromBlacklist(destination))
+	queues.RetryServer(destination)
+	checkRetry := func(log poll.LogT) poll.Result {
+		pending, dbErr := db.GetPendingPDUs(ctx, destination, 100)
+		assert.NoError(t, dbErr)
+		if len(pending) != 0 {
+			return poll.Continue("waiting for event to be removed from database")
+		}
+		return poll.Success()
+	}
+	poll.WaitOn(t, checkRetry, poll.WithTimeout(5*time.Second), poll.WithDelay(100*time.Millisecond))
+}
diff --git a/federationapi/statistics/statistics.go b/federationapi/statistics/statistics.go
index db6d5c735..61a965791 100644
--- a/federationapi/statistics/statistics.go
+++ b/federationapi/statistics/statistics.go
@@ -95,8 +95,8 @@ func (s *ServerStatistics) cancel() {
 // we will unblacklist it.
 func (s *ServerStatistics) Success() {
 	s.cancel()
-	s.successCounter.Inc()
 	s.backoffCount.Store(0)
+	s.successCounter.Inc()
 	if s.statistics.DB != nil {
 		if err := s.statistics.DB.RemoveServerFromBlacklist(s.serverName); err != nil {
 			logrus.WithError(err).Errorf("Failed to remove %q from blacklist", s.serverName)
@@ -174,6 +174,12 @@ func (s *ServerStatistics) Blacklisted() bool {
 	return s.blacklisted.Load()
 }
 
+// RemoveBlacklist removes the blacklisted status from the server.
+func (s *ServerStatistics) RemoveBlacklist() {
+	s.cancel() // NOTE(review): cancel is defined outside this hunk; presumably it is what clears the blacklisted flag — confirm
+	s.backoffCount.Store(0)
+}
+
 // SuccessCount returns the number of successful requests. This is
 // usually useful in constructing transaction IDs.
 func (s *ServerStatistics) SuccessCount() uint32 {
diff --git a/go.mod b/go.mod
index eefad89e6..eeae9608f 100644
--- a/go.mod
+++ b/go.mod
@@ -50,6 +50,7 @@ require (
 	golang.org/x/term v0.0.0-20220919170432-7a66f970e087
 	gopkg.in/h2non/bimg.v1 v1.1.9
 	gopkg.in/yaml.v2 v2.4.0
+	gotest.tools/v3 v3.0.3
 	nhooyr.io/websocket v1.8.7
 )
 
@@ -128,7 +129,6 @@ require (
 	gopkg.in/macaroon.v2 v2.1.0 // indirect
 	gopkg.in/tomb.v1 v1.0.0-20141024135613-dd632973f1e7 // indirect
 	gopkg.in/yaml.v3 v3.0.1 // indirect
-	gotest.tools/v3 v3.0.3 // indirect
 )
 
 go 1.18