From 570fa7c3598118ded6df7ced0a5326f54e7a43e2 Mon Sep 17 00:00:00 2001 From: tobi <31960611+tsmethurst@users.noreply.github.com> Date: Sat, 20 Aug 2022 22:47:19 +0200 Subject: [PATCH] [bugfix] Fix potential dereference of accounts on own instance (#757) * add GetAccountByUsernameDomain * simplify search * add escape to not deref accounts on own domain * check if local + we have account by ap uri --- internal/cache/account.go | 15 +++ internal/cache/account_test.go | 4 + internal/db/account.go | 3 + internal/db/bundb/account.go | 20 ++++ internal/db/bundb/account_test.go | 12 +++ internal/federation/dereferencing/account.go | 96 ++++++++++------- .../federation/dereferencing/account_test.go | 102 +++++++++++++++++- internal/processing/search.go | 83 ++++++-------- 8 files changed, 243 insertions(+), 92 deletions(-) diff --git a/internal/cache/account.go b/internal/cache/account.go index ac67b5d07..1f958ebb8 100644 --- a/internal/cache/account.go +++ b/internal/cache/account.go @@ -37,6 +37,7 @@ func NewAccountCache() *AccountCache { RegisterLookups: func(lm *cache.LookupMap[string, string]) { lm.RegisterLookup("uri") lm.RegisterLookup("url") + lm.RegisterLookup("usernamedomain") }, AddLookups: func(lm *cache.LookupMap[string, string], acc *gtsmodel.Account) { @@ -46,6 +47,7 @@ func NewAccountCache() *AccountCache { if url := acc.URL; url != "" { lm.Set("url", url, acc.ID) } + lm.Set("usernamedomain", usernameDomainKey(acc.Username, acc.Domain), acc.ID) }, DeleteLookups: func(lm *cache.LookupMap[string, string], acc *gtsmodel.Account) { @@ -55,6 +57,7 @@ func NewAccountCache() *AccountCache { if url := acc.URL; url != "" { lm.Delete("url", url) } + lm.Delete("usernamedomain", usernameDomainKey(acc.Username, acc.Domain)) }, }) c.cache.SetTTL(time.Minute*5, false) @@ -77,6 +80,10 @@ func (c *AccountCache) GetByURI(uri string) (*gtsmodel.Account, bool) { return c.cache.GetBy("uri", uri) } +func (c *AccountCache) GetByUsernameDomain(username string, domain string) (*gtsmodel.Account, bool) { + return c.cache.GetBy("usernamedomain", usernameDomainKey(username, domain)) +} + // Put places a account in the cache, ensuring that the object place is a copy for thread-safety func (c *AccountCache) Put(account *gtsmodel.Account) { if account == nil || account.ID == "" { @@ -135,3 +142,11 @@ func copyAccount(account *gtsmodel.Account) *gtsmodel.Account { SuspensionOrigin: account.SuspensionOrigin, } } + +func usernameDomainKey(username string, domain string) string { + u := "@" + username + if domain != "" { + return u + "@" + domain + } + return u +} diff --git a/internal/cache/account_test.go b/internal/cache/account_test.go index ff882cc3d..a6d3c6b7d 100644 --- a/internal/cache/account_test.go +++ b/internal/cache/account_test.go @@ -69,6 +69,10 @@ func (suite *AccountCacheTestSuite) TestAccountCache() { if account.URL != "" && !ok && !accountIs(account, check) { suite.Fail("Failed to fetch expected account with URL: %s", account.URL) } + check, ok = suite.cache.GetByUsernameDomain(account.Username, account.Domain) + if !ok && !accountIs(account, check) { + suite.Fail("Failed to fetch expected account with username/domain: %s/%s", account.Username, account.Domain) + } } } diff --git a/internal/db/account.go b/internal/db/account.go index 79e7c01a5..155bd666c 100644 --- a/internal/db/account.go +++ b/internal/db/account.go @@ -36,6 +36,9 @@ type Account interface { // GetAccountByURL returns one account with the given URL, or an error if something goes wrong. GetAccountByURL(ctx context.Context, uri string) (*gtsmodel.Account, Error) + // GetAccountByUsernameDomain returns one account with the given username and domain, or an error if something goes wrong. + GetAccountByUsernameDomain(ctx context.Context, username string, domain string) (*gtsmodel.Account, Error) + // UpdateAccount updates one account by ID. UpdateAccount(ctx context.Context, account *gtsmodel.Account) (*gtsmodel.Account, Error) diff --git a/internal/db/bundb/account.go b/internal/db/bundb/account.go index 201de6f02..95c3d80d8 100644 --- a/internal/db/bundb/account.go +++ b/internal/db/bundb/account.go @@ -84,6 +84,26 @@ func (a *accountDB) GetAccountByURL(ctx context.Context, url string) (*gtsmodel. ) } +func (a *accountDB) GetAccountByUsernameDomain(ctx context.Context, username string, domain string) (*gtsmodel.Account, db.Error) { + return a.getAccount( + ctx, + func() (*gtsmodel.Account, bool) { + return a.cache.GetByUsernameDomain(username, domain) + }, + func(account *gtsmodel.Account) error { + q := a.newAccountQ(account).Where("account.username = ?", username) + + if domain != "" { + q = q.Where("account.domain = ?", domain) + } else { + q = q.Where("account.domain IS NULL") + } + + return q.Scan(ctx) + }, + ) +} + func (a *accountDB) getAccount(ctx context.Context, cacheGet func() (*gtsmodel.Account, bool), dbQuery func(*gtsmodel.Account) error) (*gtsmodel.Account, db.Error) { // Attempt to fetch cached account account, cached := cacheGet() diff --git a/internal/db/bundb/account_test.go b/internal/db/bundb/account_test.go index 59b51386d..3c19e84d9 100644 --- a/internal/db/bundb/account_test.go +++ b/internal/db/bundb/account_test.go @@ -58,6 +58,18 @@ func (suite *AccountTestSuite) TestGetAccountByIDWithExtras() { suite.NotEmpty(account.HeaderMediaAttachment.URL) } +func (suite *AccountTestSuite) TestGetAccountByUsernameDomain() { + testAccount1 := suite.testAccounts["local_account_1"] + account1, err := suite.db.GetAccountByUsernameDomain(context.Background(), testAccount1.Username, testAccount1.Domain) + suite.NoError(err) + suite.NotNil(account1) + + testAccount2 := suite.testAccounts["remote_account_1"] + account2, err := suite.db.GetAccountByUsernameDomain(context.Background(), testAccount2.Username, testAccount2.Domain) + suite.NoError(err) + suite.NotNil(account2) +} + func (suite *AccountTestSuite) TestUpdateAccount() { testAccount := suite.testAccounts["local_account_1"] diff --git a/internal/federation/dereferencing/account.go b/internal/federation/dereferencing/account.go index a0e2b87ae..cbb9466ff 100644 --- a/internal/federation/dereferencing/account.go +++ b/internal/federation/dereferencing/account.go @@ -32,6 +32,7 @@ import ( "github.com/superseriousbusiness/activity/streams" "github.com/superseriousbusiness/activity/streams/vocab" "github.com/superseriousbusiness/gotosocial/internal/ap" + "github.com/superseriousbusiness/gotosocial/internal/config" "github.com/superseriousbusiness/gotosocial/internal/db" "github.com/superseriousbusiness/gotosocial/internal/gtsmodel" "github.com/superseriousbusiness/gotosocial/internal/id" @@ -78,7 +79,10 @@ type GetRemoteAccountParams struct { // GetRemoteAccount completely dereferences a remote account, converts it to a GtS model account, // puts or updates it in the database (if necessary), and returns it to a caller. -func (d *deref) GetRemoteAccount(ctx context.Context, params GetRemoteAccountParams) (remoteAccount *gtsmodel.Account, err error) { +// +// If a local account is passed into this function for whatever reason (hey, it happens!), then it +// will be returned from the database without making any remote calls. +func (d *deref) GetRemoteAccount(ctx context.Context, params GetRemoteAccountParams) (foundAccount *gtsmodel.Account, err error) { /* In this function we want to retrieve a gtsmodel representation of a remote account, with its proper accountDomain set, while making as few calls to remote instances as possible to save time and bandwidth. @@ -99,23 +103,40 @@ func (d *deref) GetRemoteAccount(ctx context.Context, params GetRemoteAccountPar from that. */ - // first check if we can retrieve the account locally just with what we've been given + skipResolve := params.SkipResolve + + // this first step checks if we have the + // account in the database somewhere already switch { case params.RemoteAccountID != nil: - // try with uri - if a, dbErr := d.db.GetAccountByURI(ctx, params.RemoteAccountID.String()); dbErr == nil { - remoteAccount = a + uri := params.RemoteAccountID + host := uri.Host + if host == config.GetHost() || host == config.GetAccountDomain() { + // this is actually a local account, + // make sure we don't try to resolve + skipResolve = true + } + + if a, dbErr := d.db.GetAccountByURI(ctx, uri.String()); dbErr == nil { + foundAccount = a } else if dbErr != db.ErrNoEntries { - err = fmt.Errorf("GetRemoteAccount: database error looking for account %s: %s", params.RemoteAccountID, err) + err = fmt.Errorf("GetRemoteAccount: database error looking for account with uri %s: %s", uri, err) + } + case params.RemoteAccountUsername != "" && (params.RemoteAccountHost == "" || params.RemoteAccountHost == config.GetHost() || params.RemoteAccountHost == config.GetAccountDomain()): + // either no domain is provided or this seems + // to be a local account, so don't resolve + skipResolve = true + + if a, dbErr := d.db.GetLocalAccountByUsername(ctx, params.RemoteAccountUsername); dbErr == nil { + foundAccount = a + } else if dbErr != db.ErrNoEntries { + err = fmt.Errorf("GetRemoteAccount: database error looking for local account with username %s: %s", params.RemoteAccountUsername, err) } case params.RemoteAccountUsername != "" && params.RemoteAccountHost != "": - // try with username/host - a := >smodel.Account{} - where := []db.Where{{Key: "username", Value: params.RemoteAccountUsername}, {Key: "domain", Value: params.RemoteAccountHost}} - if dbErr := d.db.GetWhere(ctx, where, a); dbErr == nil { - remoteAccount = a + if a, dbErr := d.db.GetAccountByUsernameDomain(ctx, params.RemoteAccountUsername, params.RemoteAccountHost); dbErr == nil { + foundAccount = a } else if dbErr != db.ErrNoEntries { - err = fmt.Errorf("GetRemoteAccount: database error looking for account with username %s and host %s: %s", params.RemoteAccountUsername, params.RemoteAccountHost, err) + err = fmt.Errorf("GetRemoteAccount: database error looking for account with username %s and domain %s: %s", params.RemoteAccountUsername, params.RemoteAccountHost, err) } default: err = errors.New("GetRemoteAccount: no identifying parameters were set so we cannot get account") @@ -125,10 +146,11 @@ func (d *deref) GetRemoteAccount(ctx context.Context, params GetRemoteAccountPar return } - if params.SkipResolve { - // if we can't resolve, return already since there's nothing more we can do - if remoteAccount == nil { - err = errors.New("GetRemoteAccount: error retrieving account with skipResolve set true") + if skipResolve { + // if we can't resolve, return already + // since there's nothing more we can do + if foundAccount == nil { + err = errors.New("GetRemoteAccount: couldn't retrieve account locally and won't try to resolve it") } return } @@ -141,8 +163,8 @@ func (d *deref) GetRemoteAccount(ctx context.Context, params GetRemoteAccountPar // ... but we still need the username so we can do a finger for the accountDomain // check if we had the account stored already and got it earlier - if remoteAccount != nil { - params.RemoteAccountUsername = remoteAccount.Username + if foundAccount != nil { + params.RemoteAccountUsername = foundAccount.Username } else { // if we didn't already have it, we have dereference it from remote and just... accountable, err = d.dereferenceAccountable(ctx, params.RequestingUsername, params.RemoteAccountID) @@ -167,8 +189,8 @@ func (d *deref) GetRemoteAccount(ctx context.Context, params GetRemoteAccountPar // already about what the account domain might be; this var will be overwritten later if necessary var accountDomain string switch { - case remoteAccount != nil: - accountDomain = remoteAccount.Domain + case foundAccount != nil: + accountDomain = foundAccount.Domain case params.RemoteAccountID != nil: accountDomain = params.RemoteAccountID.Host default: @@ -178,7 +200,7 @@ func (d *deref) GetRemoteAccount(ctx context.Context, params GetRemoteAccountPar // to save on remote calls: only webfinger if we don't have a remoteAccount yet, or if we haven't // fingered the remote account for at least 2 days; don't finger instance accounts var fingered time.Time - if remoteAccount == nil || (remoteAccount.LastWebfingeredAt.Before(time.Now().Add(webfingerInterval)) && !instanceAccount(remoteAccount)) { + if foundAccount == nil || (foundAccount.LastWebfingeredAt.Before(time.Now().Add(webfingerInterval)) && !instanceAccount(foundAccount)) { accountDomain, params.RemoteAccountID, err = d.fingerRemoteAccount(ctx, params.RequestingUsername, params.RemoteAccountUsername, params.RemoteAccountHost) if err != nil { err = fmt.Errorf("GetRemoteAccount: error while fingering: %s", err) @@ -187,14 +209,14 @@ func (d *deref) GetRemoteAccount(ctx context.Context, params GetRemoteAccountPar fingered = time.Now() } - if !fingered.IsZero() && remoteAccount == nil { + if !fingered.IsZero() && foundAccount == nil { // if we just fingered and now have a discovered account domain but still no account, // we should do a final lookup in the database with the discovered username + accountDomain // to make absolutely sure we don't already have this account a := >smodel.Account{} where := []db.Where{{Key: "username", Value: params.RemoteAccountUsername}, {Key: "domain", Value: accountDomain}} if dbErr := d.db.GetWhere(ctx, where, a); dbErr == nil { - remoteAccount = a + foundAccount = a } else if dbErr != db.ErrNoEntries { err = fmt.Errorf("GetRemoteAccount: database error looking for account with username %s and host %s: %s", params.RemoteAccountUsername, params.RemoteAccountHost, err) return @@ -203,7 +225,7 @@ func (d *deref) GetRemoteAccount(ctx context.Context, params GetRemoteAccountPar // we may also have some extra information already, like the account we had in the db, or the // accountable representation that we dereferenced from remote - if remoteAccount == nil { + if foundAccount == nil { // we still don't have the account, so deference it if we didn't earlier if accountable == nil { accountable, err = d.dereferenceAccountable(ctx, params.RequestingUsername, params.RemoteAccountID) @@ -214,7 +236,7 @@ func (d *deref) GetRemoteAccount(ctx context.Context, params GetRemoteAccountPar } // then convert - remoteAccount, err = d.typeConverter.ASRepresentationToAccount(ctx, accountable, accountDomain, false) + foundAccount, err = d.typeConverter.ASRepresentationToAccount(ctx, accountable, accountDomain, false) if err != nil { err = fmt.Errorf("GetRemoteAccount: error converting accountable to account: %s", err) return @@ -227,18 +249,18 @@ func (d *deref) GetRemoteAccount(ctx context.Context, params GetRemoteAccountPar err = fmt.Errorf("GetRemoteAccount: error generating new id for account: %s", err) return } - remoteAccount.ID = ulid + foundAccount.ID = ulid - _, err = d.populateAccountFields(ctx, remoteAccount, params.RequestingUsername, params.Blocking) + _, err = d.populateAccountFields(ctx, foundAccount, params.RequestingUsername, params.Blocking) if err != nil { err = fmt.Errorf("GetRemoteAccount: error populating further account fields: %s", err) return } - remoteAccount.LastWebfingeredAt = fingered - remoteAccount.UpdatedAt = time.Now() + foundAccount.LastWebfingeredAt = fingered + foundAccount.UpdatedAt = time.Now() - err = d.db.Put(ctx, remoteAccount) + err = d.db.Put(ctx, foundAccount) if err != nil { err = fmt.Errorf("GetRemoteAccount: error putting new account: %s", err) return @@ -248,9 +270,9 @@ func (d *deref) GetRemoteAccount(ctx context.Context, params GetRemoteAccountPar } // we had the account already, but now we know the account domain, so update it if it's different - if !strings.EqualFold(remoteAccount.Domain, accountDomain) { - remoteAccount.Domain = accountDomain - remoteAccount, err = d.db.UpdateAccount(ctx, remoteAccount) + if !strings.EqualFold(foundAccount.Domain, accountDomain) { + foundAccount.Domain = accountDomain + foundAccount, err = d.db.UpdateAccount(ctx, foundAccount) if err != nil { err = fmt.Errorf("GetRemoteAccount: error updating account: %s", err) return @@ -260,7 +282,7 @@ func (d *deref) GetRemoteAccount(ctx context.Context, params GetRemoteAccountPar // make sure the account fields are populated before returning: // the caller might want to block until everything is loaded var fieldsChanged bool - fieldsChanged, err = d.populateAccountFields(ctx, remoteAccount, params.RequestingUsername, params.Blocking) + fieldsChanged, err = d.populateAccountFields(ctx, foundAccount, params.RequestingUsername, params.Blocking) if err != nil { return nil, fmt.Errorf("GetRemoteAccount: error populating remoteAccount fields: %s", err) } @@ -268,12 +290,12 @@ func (d *deref) GetRemoteAccount(ctx context.Context, params GetRemoteAccountPar var fingeredChanged bool if !fingered.IsZero() { fingeredChanged = true - remoteAccount.LastWebfingeredAt = fingered + foundAccount.LastWebfingeredAt = fingered } if fieldsChanged || fingeredChanged { - remoteAccount.UpdatedAt = time.Now() - remoteAccount, err = d.db.UpdateAccount(ctx, remoteAccount) + foundAccount.UpdatedAt = time.Now() + foundAccount, err = d.db.UpdateAccount(ctx, foundAccount) if err != nil { return nil, fmt.Errorf("GetRemoteAccount: error updating remoteAccount: %s", err) } diff --git a/internal/federation/dereferencing/account_test.go b/internal/federation/dereferencing/account_test.go index 72092951b..77ebb7cac 100644 --- a/internal/federation/dereferencing/account_test.go +++ b/internal/federation/dereferencing/account_test.go @@ -21,9 +21,11 @@ package dereferencing_test import ( "context" "testing" + "time" "github.com/stretchr/testify/suite" "github.com/superseriousbusiness/gotosocial/internal/ap" + "github.com/superseriousbusiness/gotosocial/internal/config" "github.com/superseriousbusiness/gotosocial/internal/federation/dereferencing" "github.com/superseriousbusiness/gotosocial/testrig" ) @@ -42,11 +44,11 @@ func (suite *AccountTestSuite) TestDereferenceGroup() { }) suite.NoError(err) suite.NotNil(group) - suite.NotNil(group) // group values should be set suite.Equal("https://unknown-instance.com/groups/some_group", group.URI) suite.Equal("https://unknown-instance.com/@some_group", group.URL) + suite.WithinDuration(time.Now(), group.LastWebfingeredAt, 5*time.Second) // group should be in the database dbGroup, err := suite.db.GetAccountByURI(context.Background(), group.URI) @@ -65,11 +67,11 @@ func (suite *AccountTestSuite) TestDereferenceService() { }) suite.NoError(err) suite.NotNil(service) - suite.NotNil(service) // service values should be set suite.Equal("https://owncast.example.org/federation/user/rgh", service.URI) suite.Equal("https://owncast.example.org/federation/user/rgh", service.URL) + suite.WithinDuration(time.Now(), service.LastWebfingeredAt, 5*time.Second) // service should be in the database dbService, err := suite.db.GetAccountByURI(context.Background(), service.URI) @@ -79,6 +81,102 @@ func (suite *AccountTestSuite) TestDereferenceService() { suite.Equal("example.org", dbService.Domain) } +/* + We shouldn't try webfingering or making http calls to dereference local accounts + that might be passed into GetRemoteAccount for whatever reason, so these tests are + here to make sure that such cases are (basically) short-circuit evaluated and given + back as-is without trying to make any calls to one's own instance. +*/ + +func (suite *AccountTestSuite) TestDereferenceLocalAccountAsRemoteURL() { + fetchingAccount := suite.testAccounts["local_account_1"] + targetAccount := suite.testAccounts["local_account_2"] + + fetchedAccount, err := suite.dereferencer.GetRemoteAccount(context.Background(), dereferencing.GetRemoteAccountParams{ + RequestingUsername: fetchingAccount.Username, + RemoteAccountID: testrig.URLMustParse(targetAccount.URI), + }) + suite.NoError(err) + suite.NotNil(fetchedAccount) + suite.Empty(fetchedAccount.Domain) +} + +func (suite *AccountTestSuite) TestDereferenceLocalAccountAsUsername() { + fetchingAccount := suite.testAccounts["local_account_1"] + targetAccount := suite.testAccounts["local_account_2"] + + fetchedAccount, err := suite.dereferencer.GetRemoteAccount(context.Background(), dereferencing.GetRemoteAccountParams{ + RequestingUsername: fetchingAccount.Username, + RemoteAccountUsername: targetAccount.Username, + }) + suite.NoError(err) + suite.NotNil(fetchedAccount) + suite.Empty(fetchedAccount.Domain) +} + +func (suite *AccountTestSuite) TestDereferenceLocalAccountAsUsernameDomain() { + fetchingAccount := suite.testAccounts["local_account_1"] + targetAccount := suite.testAccounts["local_account_2"] + + fetchedAccount, err := suite.dereferencer.GetRemoteAccount(context.Background(), dereferencing.GetRemoteAccountParams{ + RequestingUsername: fetchingAccount.Username, + RemoteAccountUsername: targetAccount.Username, + RemoteAccountHost: config.GetHost(), + }) + suite.NoError(err) + suite.NotNil(fetchedAccount) + suite.Empty(fetchedAccount.Domain) +} + +func (suite *AccountTestSuite) TestDereferenceLocalAccountAsUsernameDomainAndURL() { + fetchingAccount := suite.testAccounts["local_account_1"] + targetAccount := suite.testAccounts["local_account_2"] + + fetchedAccount, err := suite.dereferencer.GetRemoteAccount(context.Background(), dereferencing.GetRemoteAccountParams{ + RequestingUsername: fetchingAccount.Username, + RemoteAccountID: testrig.URLMustParse(targetAccount.URI), + RemoteAccountUsername: targetAccount.Username, + RemoteAccountHost: config.GetHost(), + }) + suite.NoError(err) + suite.NotNil(fetchedAccount) + suite.Empty(fetchedAccount.Domain) +} + +func (suite *AccountTestSuite) TestDereferenceLocalAccountWithUnknownUsername() { + fetchingAccount := suite.testAccounts["local_account_1"] + + fetchedAccount, err := suite.dereferencer.GetRemoteAccount(context.Background(), dereferencing.GetRemoteAccountParams{ + RequestingUsername: fetchingAccount.Username, + RemoteAccountUsername: "thisaccountdoesnotexist", + }) + suite.EqualError(err, "GetRemoteAccount: couldn't retrieve account locally and won't try to resolve it") + suite.Nil(fetchedAccount) +} + +func (suite *AccountTestSuite) TestDereferenceLocalAccountWithUnknownUsernameDomain() { + fetchingAccount := suite.testAccounts["local_account_1"] + + fetchedAccount, err := suite.dereferencer.GetRemoteAccount(context.Background(), dereferencing.GetRemoteAccountParams{ + RequestingUsername: fetchingAccount.Username, + RemoteAccountUsername: "thisaccountdoesnotexist", + RemoteAccountHost: "localhost:8080", + }) + suite.EqualError(err, "GetRemoteAccount: couldn't retrieve account locally and won't try to resolve it") + suite.Nil(fetchedAccount) +} + +func (suite *AccountTestSuite) TestDereferenceLocalAccountWithUnknownUserURI() { + fetchingAccount := suite.testAccounts["local_account_1"] + + fetchedAccount, err := suite.dereferencer.GetRemoteAccount(context.Background(), dereferencing.GetRemoteAccountParams{ + RequestingUsername: fetchingAccount.Username, + RemoteAccountID: testrig.URLMustParse("http://localhost:8080/users/thisaccountdoesnotexist"), + }) + suite.EqualError(err, "GetRemoteAccount: couldn't retrieve account locally and won't try to resolve it") + suite.Nil(fetchedAccount) +} + func TestAccountTestSuite(t *testing.T) { suite.Run(t, new(AccountTestSuite)) } diff --git a/internal/processing/search.go b/internal/processing/search.go index d25bee2ae..b766b4ba2 100644 --- a/internal/processing/search.go +++ b/internal/processing/search.go @@ -39,7 +39,6 @@ import ( func (p *processor) SearchGet(ctx context.Context, authed *oauth.Auth, search *apimodel.SearchQuery) (*apimodel.SearchResult, gtserror.WithCode) { l := log.WithFields(kv.Fields{ - {"query", search.Query}, }...) @@ -62,7 +61,7 @@ func (p *processor) SearchGet(ctx context.Context, authed *oauth.Auth, search *a /* SEARCH BY MENTION - check if the query is something like @whatever_username@example.org -- this means it's a remote account + check if the query is something like @whatever_username@example.org -- this means it's likely a remote account */ maybeNamestring := query if maybeNamestring[0] != '@' { @@ -135,7 +134,6 @@ func (p *processor) SearchGet(ctx context.Context, authed *oauth.Auth, search *a func (p *processor) searchStatusByURI(ctx context.Context, authed *oauth.Auth, uri *url.URL, resolve bool) (*gtsmodel.Status, error) { l := log.WithFields(kv.Fields{ - {"uri", uri.String()}, {"resolve", resolve}, }...) @@ -161,67 +159,46 @@ func (p *processor) searchStatusByURI(ctx context.Context, authed *oauth.Auth, u } func (p *processor) searchAccountByURI(ctx context.Context, authed *oauth.Auth, uri *url.URL, resolve bool) (*gtsmodel.Account, error) { - if maybeAccount, err := p.db.GetAccountByURI(ctx, uri.String()); err == nil { - return maybeAccount, nil - } else if maybeAccount, err := p.db.GetAccountByURL(ctx, uri.String()); err == nil { + // it might be a web url like http://example.org/@user instead + // of an AP uri like http://example.org/users/user, check first + if maybeAccount, err := p.db.GetAccountByURL(ctx, uri.String()); err == nil { return maybeAccount, nil } - if resolve { - // we don't have it locally so try and dereference it - account, err := p.federator.GetRemoteAccount(ctx, dereferencing.GetRemoteAccountParams{ - RequestingUsername: authed.Account.Username, - RemoteAccountID: uri, - }) - if err != nil { - return nil, fmt.Errorf("searchAccountByURI: error dereferencing account with uri %s: %s", uri.String(), err) + if uri.Host == config.GetHost() || uri.Host == config.GetAccountDomain() { + // this is a local account; if we don't have it now then + // we should just bail instead of trying to get it remote + if maybeAccount, err := p.db.GetAccountByURI(ctx, uri.String()); err == nil { + return maybeAccount, nil } - return account, nil + return nil, nil } - return nil, nil + + // we don't have it yet, try to find it remotely + return p.federator.GetRemoteAccount(ctx, dereferencing.GetRemoteAccountParams{ + RequestingUsername: authed.Account.Username, + RemoteAccountID: uri, + Blocking: true, + SkipResolve: !resolve, + }) } func (p *processor) searchAccountByMention(ctx context.Context, authed *oauth.Auth, username string, domain string, resolve bool) (*gtsmodel.Account, error) { - maybeAcct := >smodel.Account{} - var err error - // if it's a local account we can skip a whole bunch of stuff if domain == config.GetHost() || domain == config.GetAccountDomain() || domain == "" { - maybeAcct, err = p.db.GetLocalAccountByUsername(ctx, username) - if err != nil && err != db.ErrNoEntries { - return nil, fmt.Errorf("searchAccountByMention: error getting local account by username: %s", err) + maybeAcct, err := p.db.GetLocalAccountByUsername(ctx, username) + if err == nil || err == db.ErrNoEntries { + return maybeAcct, nil } - return maybeAcct, nil + return nil, fmt.Errorf("searchAccountByMention: error getting local account by username: %s", err) } - // it's not a local account so first we'll check if it's in the database already... - where := []db.Where{ - {Key: "username", Value: username, CaseInsensitive: true}, - {Key: "domain", Value: domain, CaseInsensitive: true}, - } - err = p.db.GetWhere(ctx, where, maybeAcct) - if err == nil { - // we've got it stored locally already! - return maybeAcct, nil - } - - if err != db.ErrNoEntries { - // if it's not errNoEntries there's been a real database error so bail at this point - return nil, fmt.Errorf("searchAccountByMention: database error: %s", err) - } - - // we got a db.ErrNoEntries, so we just don't have the account locally stored -- check if we can dereference it - if resolve { - maybeAcct, err = p.federator.GetRemoteAccount(ctx, dereferencing.GetRemoteAccountParams{ - RequestingUsername: authed.Account.Username, - RemoteAccountUsername: username, - RemoteAccountHost: domain, - }) - if err != nil { - return nil, fmt.Errorf("searchAccountByMention: error getting remote account: %s", err) - } - return maybeAcct, nil - } - - return nil, nil + // we don't have it yet, try to find it remotely + return p.federator.GetRemoteAccount(ctx, dereferencing.GetRemoteAccountParams{ + RequestingUsername: authed.Account.Username, + RemoteAccountUsername: username, + RemoteAccountHost: domain, + Blocking: true, + SkipResolve: !resolve, + }) }