mirror of
https://github.com/writefreely/writefreely
synced 2024-09-20 13:41:54 +00:00
ca4a576c31
This adds any OAuth login buttons to the invite signup page, stores the invite code for the flow duration, and associates the new user with it once successfully registered. It enables invite-only instances with OAuth-based registration.
253 lines
8.2 KiB
Go
package writefreely
|
|
|
|
import (
|
|
"context"
|
|
"fmt"
|
|
"github.com/gorilla/sessions"
|
|
"github.com/stretchr/testify/assert"
|
|
"github.com/writeas/impart"
|
|
"github.com/writeas/nerds/store"
|
|
"github.com/writeas/writefreely/config"
|
|
"net/http"
|
|
"net/http/httptest"
|
|
"net/url"
|
|
"strings"
|
|
"testing"
|
|
)
|
|
|
|
type MockOAuthDatastoreProvider struct {
|
|
DoDB func() OAuthDatastore
|
|
DoConfig func() *config.Config
|
|
DoSessionStore func() sessions.Store
|
|
}
|
|
|
|
type MockOAuthDatastore struct {
|
|
DoGenerateOAuthState func(context.Context, string, string, int64, string) (string, error)
|
|
DoValidateOAuthState func(context.Context, string) (string, string, int64, string, error)
|
|
DoGetIDForRemoteUser func(context.Context, string, string, string) (int64, error)
|
|
DoCreateUser func(*config.Config, *User, string) error
|
|
DoRecordRemoteUserID func(context.Context, int64, string, string, string, string) error
|
|
DoGetUserByID func(int64) (*User, error)
|
|
}
|
|
|
|
// Compile-time check that *MockOAuthDatastore satisfies OAuthDatastore.
var _ OAuthDatastore = (*MockOAuthDatastore)(nil)
|
|
|
|
// StringReadCloser wraps a *strings.Reader and adds a Close method so that it
// can be used where a closable reader is required (the tests in this file use
// it as an http.Response body).
type StringReadCloser struct {
	*strings.Reader
}

// Close does nothing and always reports success.
func (*StringReadCloser) Close() error { return nil }
|
|
|
|
// MockHTTPClient is a test double for an HTTP client exposing a Do method.
type MockHTTPClient struct {
	// DoDo, when non-nil, handles every call to Do.
	DoDo func(req *http.Request) (*http.Response, error)
}

// Do forwards req to the DoDo hook when one is configured. Without a hook it
// returns an otherwise zero-valued response and a nil error.
//
// The default response carries http.NoBody rather than a nil Body: real
// net/http clients always return a non-nil Body alongside a nil error, so
// callers that unconditionally close or read the body must not panic when
// talking to the mock.
func (m *MockHTTPClient) Do(req *http.Request) (*http.Response, error) {
	if m.DoDo != nil {
		return m.DoDo(req)
	}
	return &http.Response{Body: http.NoBody}, nil
}
|
|
|
|
func (m *MockOAuthDatastoreProvider) SessionStore() sessions.Store {
|
|
if m.DoSessionStore != nil {
|
|
return m.DoSessionStore()
|
|
}
|
|
return sessions.NewCookieStore([]byte("secret-key"))
|
|
}
|
|
|
|
func (m *MockOAuthDatastoreProvider) DB() OAuthDatastore {
|
|
if m.DoDB != nil {
|
|
return m.DoDB()
|
|
}
|
|
return &MockOAuthDatastore{}
|
|
}
|
|
|
|
func (m *MockOAuthDatastoreProvider) Config() *config.Config {
|
|
if m.DoConfig != nil {
|
|
return m.DoConfig()
|
|
}
|
|
cfg := config.New()
|
|
cfg.UseSQLite(true)
|
|
cfg.WriteAsOauth = config.WriteAsOauthCfg{
|
|
ClientID: "development",
|
|
ClientSecret: "development",
|
|
AuthLocation: "https://write.as/oauth/login",
|
|
TokenLocation: "https://write.as/oauth/token",
|
|
InspectLocation: "https://write.as/oauth/inspect",
|
|
}
|
|
cfg.SlackOauth = config.SlackOauthCfg{
|
|
ClientID: "development",
|
|
ClientSecret: "development",
|
|
TeamID: "development",
|
|
}
|
|
return cfg
|
|
}
|
|
|
|
func (m *MockOAuthDatastore) ValidateOAuthState(ctx context.Context, state string) (string, string, int64, string, error) {
|
|
if m.DoValidateOAuthState != nil {
|
|
return m.DoValidateOAuthState(ctx, state)
|
|
}
|
|
return "", "", 0, "", nil
|
|
}
|
|
|
|
func (m *MockOAuthDatastore) GetIDForRemoteUser(ctx context.Context, remoteUserID, provider, clientID string) (int64, error) {
|
|
if m.DoGetIDForRemoteUser != nil {
|
|
return m.DoGetIDForRemoteUser(ctx, remoteUserID, provider, clientID)
|
|
}
|
|
return -1, nil
|
|
}
|
|
|
|
func (m *MockOAuthDatastore) CreateUser(cfg *config.Config, u *User, username string) error {
|
|
if m.DoCreateUser != nil {
|
|
return m.DoCreateUser(cfg, u, username)
|
|
}
|
|
u.ID = 1
|
|
return nil
|
|
}
|
|
|
|
func (m *MockOAuthDatastore) RecordRemoteUserID(ctx context.Context, localUserID int64, remoteUserID, provider, clientID, accessToken string) error {
|
|
if m.DoRecordRemoteUserID != nil {
|
|
return m.DoRecordRemoteUserID(ctx, localUserID, remoteUserID, provider, clientID, accessToken)
|
|
}
|
|
return nil
|
|
}
|
|
|
|
func (m *MockOAuthDatastore) GetUserByID(userID int64) (*User, error) {
|
|
if m.DoGetUserByID != nil {
|
|
return m.DoGetUserByID(userID)
|
|
}
|
|
user := &User{
|
|
|
|
}
|
|
return user, nil
|
|
}
|
|
|
|
func (m *MockOAuthDatastore) GenerateOAuthState(ctx context.Context, provider string, clientID string, attachUserID int64, inviteCode string) (string, error) {
|
|
if m.DoGenerateOAuthState != nil {
|
|
return m.DoGenerateOAuthState(ctx, provider, clientID, attachUserID, inviteCode)
|
|
}
|
|
return store.Generate62RandomString(14), nil
|
|
}
|
|
|
|
func TestViewOauthInit(t *testing.T) {
|
|
|
|
t.Run("success", func(t *testing.T) {
|
|
app := &MockOAuthDatastoreProvider{}
|
|
h := oauthHandler{
|
|
Config: app.Config(),
|
|
DB: app.DB(),
|
|
Store: app.SessionStore(),
|
|
EmailKey: []byte{0xd, 0xe, 0xc, 0xa, 0xf, 0xf, 0xb, 0xa, 0xd},
|
|
oauthClient: writeAsOauthClient{
|
|
ClientID: app.Config().WriteAsOauth.ClientID,
|
|
ClientSecret: app.Config().WriteAsOauth.ClientSecret,
|
|
ExchangeLocation: app.Config().WriteAsOauth.TokenLocation,
|
|
InspectLocation: app.Config().WriteAsOauth.InspectLocation,
|
|
AuthLocation: app.Config().WriteAsOauth.AuthLocation,
|
|
CallbackLocation: "http://localhost/oauth/callback",
|
|
HttpClient: nil,
|
|
},
|
|
}
|
|
req, err := http.NewRequest("GET", "/oauth/client", nil)
|
|
assert.NoError(t, err)
|
|
rr := httptest.NewRecorder()
|
|
err = h.viewOauthInit(nil, rr, req)
|
|
assert.NotNil(t, err)
|
|
httpErr, ok := err.(impart.HTTPError)
|
|
assert.True(t, ok)
|
|
assert.Equal(t, http.StatusTemporaryRedirect, httpErr.Status)
|
|
assert.NotEmpty(t, httpErr.Message)
|
|
locURI, err := url.Parse(httpErr.Message)
|
|
assert.NoError(t, err)
|
|
assert.Equal(t, "/oauth/login", locURI.Path)
|
|
assert.Equal(t, "development", locURI.Query().Get("client_id"))
|
|
assert.Equal(t, "http://localhost/oauth/callback", locURI.Query().Get("redirect_uri"))
|
|
assert.Equal(t, "code", locURI.Query().Get("response_type"))
|
|
assert.NotEmpty(t, locURI.Query().Get("state"))
|
|
})
|
|
|
|
t.Run("state failure", func(t *testing.T) {
|
|
app := &MockOAuthDatastoreProvider{
|
|
DoDB: func() OAuthDatastore {
|
|
return &MockOAuthDatastore{
|
|
DoGenerateOAuthState: func(ctx context.Context, provider, clientID string, attachUserID int64, inviteCode string) (string, error) {
|
|
return "", fmt.Errorf("pretend unable to write state error")
|
|
},
|
|
}
|
|
},
|
|
}
|
|
h := oauthHandler{
|
|
Config: app.Config(),
|
|
DB: app.DB(),
|
|
Store: app.SessionStore(),
|
|
EmailKey: []byte{0xd, 0xe, 0xc, 0xa, 0xf, 0xf, 0xb, 0xa, 0xd},
|
|
oauthClient: writeAsOauthClient{
|
|
ClientID: app.Config().WriteAsOauth.ClientID,
|
|
ClientSecret: app.Config().WriteAsOauth.ClientSecret,
|
|
ExchangeLocation: app.Config().WriteAsOauth.TokenLocation,
|
|
InspectLocation: app.Config().WriteAsOauth.InspectLocation,
|
|
AuthLocation: app.Config().WriteAsOauth.AuthLocation,
|
|
CallbackLocation: "http://localhost/oauth/callback",
|
|
HttpClient: nil,
|
|
},
|
|
}
|
|
req, err := http.NewRequest("GET", "/oauth/client", nil)
|
|
assert.NoError(t, err)
|
|
rr := httptest.NewRecorder()
|
|
err = h.viewOauthInit(nil, rr, req)
|
|
httpErr, ok := err.(impart.HTTPError)
|
|
assert.True(t, ok)
|
|
assert.NotEmpty(t, httpErr.Message)
|
|
assert.Equal(t, http.StatusInternalServerError, httpErr.Status)
|
|
assert.Equal(t, "could not prepare oauth redirect url", httpErr.Message)
|
|
})
|
|
}
|
|
|
|
func TestViewOauthCallback(t *testing.T) {
|
|
t.Run("success", func(t *testing.T) {
|
|
app := &MockOAuthDatastoreProvider{}
|
|
h := oauthHandler{
|
|
Config: app.Config(),
|
|
DB: app.DB(),
|
|
Store: app.SessionStore(),
|
|
EmailKey: []byte{0xd, 0xe, 0xc, 0xa, 0xf, 0xf, 0xb, 0xa, 0xd},
|
|
oauthClient: writeAsOauthClient{
|
|
ClientID: app.Config().WriteAsOauth.ClientID,
|
|
ClientSecret: app.Config().WriteAsOauth.ClientSecret,
|
|
ExchangeLocation: app.Config().WriteAsOauth.TokenLocation,
|
|
InspectLocation: app.Config().WriteAsOauth.InspectLocation,
|
|
AuthLocation: app.Config().WriteAsOauth.AuthLocation,
|
|
CallbackLocation: "http://localhost/oauth/callback",
|
|
HttpClient: &MockHTTPClient{
|
|
DoDo: func(req *http.Request) (*http.Response, error) {
|
|
switch req.URL.String() {
|
|
case "https://write.as/oauth/token":
|
|
return &http.Response{
|
|
StatusCode: 200,
|
|
Body: &StringReadCloser{strings.NewReader(`{"access_token": "access_token", "expires_in": 1000, "refresh_token": "refresh_token", "token_type": "access"}`)},
|
|
}, nil
|
|
case "https://write.as/oauth/inspect":
|
|
return &http.Response{
|
|
StatusCode: 200,
|
|
Body: &StringReadCloser{strings.NewReader(`{"client_id": "development", "user_id": "1", "expires_at": "2019-12-19T11:42:01Z", "username": "nick", "email": "nick@testing.write.as"}`)},
|
|
}, nil
|
|
}
|
|
|
|
return &http.Response{
|
|
StatusCode: http.StatusNotFound,
|
|
}, nil
|
|
},
|
|
},
|
|
},
|
|
}
|
|
req, err := http.NewRequest("GET", "/oauth/callback", nil)
|
|
assert.NoError(t, err)
|
|
rr := httptest.NewRecorder()
|
|
err = h.viewOauthCallback(nil, rr, req)
|
|
assert.NoError(t, err)
|
|
assert.Equal(t, http.StatusTemporaryRedirect, rr.Code)
|
|
})
|
|
}
|