mirror of
https://github.com/writefreely/writefreely
synced 2024-11-24 09:33:11 +00:00
Merging T705-oauth into T710-oauth-slack. T705,T710
This commit is contained in:
parent
4266154749
commit
13121cb266
7 changed files with 144 additions and 93 deletions
|
@ -56,6 +56,25 @@ type (
|
|||
Port int `ini:"port"`
|
||||
}
|
||||
|
||||
OAuthCfg struct {
|
||||
Enabled bool `ini:"enable"`
|
||||
|
||||
// write.as
|
||||
WriteAsProviderAuthLocation string `ini:"wa_auth_location"`
|
||||
WriteAsProviderTokenLocation string `ini:"wa_token_location"`
|
||||
WriteAsProviderInspectLocation string `ini:"wa_inspect_location"`
|
||||
WriteAsClientCallbackLocation string `ini:"wa_callback_location"`
|
||||
WriteAsClientID string `ini:"wa_client_id"`
|
||||
WriteAsClientSecret string `ini:"wa_client_secret"`
|
||||
WriteAsAuthLocation string
|
||||
|
||||
// slack
|
||||
SlackClientID string `ini:"slack_client_id"`
|
||||
SlackClientSecret string `ini:"slack_client_secret"`
|
||||
SlackTeamID string `init:"slack_team_id"`
|
||||
SlackAuthLocation string
|
||||
}
|
||||
|
||||
// AppCfg holds values that affect how the application functions
|
||||
AppCfg struct {
|
||||
SiteName string `ini:"site_name"`
|
||||
|
@ -92,17 +111,10 @@ type (
|
|||
LocalTimeline bool `ini:"local_timeline"`
|
||||
UserInvites string `ini:"user_invites"`
|
||||
|
||||
// OAuth
|
||||
EnableOAuth bool `ini:"enable_oauth"`
|
||||
OAuthProviderAuthLocation string `ini:"oauth_auth_location"`
|
||||
OAuthProviderTokenLocation string `ini:"oauth_token_location"`
|
||||
OAuthProviderInspectLocation string `ini:"oauth_inspect_location"`
|
||||
OAuthClientCallbackLocation string `ini:"oauth_callback_location"`
|
||||
OAuthClientID string `ini:"oauth_client_id"`
|
||||
OAuthClientSecret string `ini:"oauth_client_secret"`
|
||||
|
||||
// Defaults
|
||||
DefaultVisibility string `ini:"default_visibility"`
|
||||
|
||||
OAuth OAuthCfg `ini:"oauth"`
|
||||
}
|
||||
|
||||
// Config holds the complete configuration for running a writefreely instance
|
||||
|
|
18
database.go
18
database.go
|
@ -125,10 +125,10 @@ type writestore interface {
|
|||
GetUserLastPostTime(id int64) (*time.Time, error)
|
||||
GetCollectionLastPostTime(id int64) (*time.Time, error)
|
||||
|
||||
GetIDForRemoteUser(ctx context.Context, remoteUserID int64) (int64, error)
|
||||
RecordRemoteUserID(ctx context.Context, localUserID, remoteUserID int64) error
|
||||
ValidateOAuthState(ctx context.Context, state string) error
|
||||
GenerateOAuthState(ctx context.Context) (string, error)
|
||||
GetIDForRemoteUser(context.Context, int64) (int64, error)
|
||||
RecordRemoteUserID(context.Context, int64, int64) error
|
||||
ValidateOAuthState(context.Context, string, string, string) error
|
||||
GenerateOAuthState(context.Context, string, string) (string, error)
|
||||
|
||||
DatabaseInitialized() bool
|
||||
}
|
||||
|
@ -138,6 +138,8 @@ type datastore struct {
|
|||
driverName string
|
||||
}
|
||||
|
||||
var _ writestore = &datastore{}
|
||||
|
||||
func (db *datastore) now() string {
|
||||
if db.driverName == driverSQLite {
|
||||
return "strftime('%Y-%m-%d %H:%M:%S','now')"
|
||||
|
@ -2459,17 +2461,17 @@ func (db *datastore) GetCollectionLastPostTime(id int64) (*time.Time, error) {
|
|||
return &t, nil
|
||||
}
|
||||
|
||||
func (db *datastore) GenerateOAuthState(ctx context.Context) (string, error) {
|
||||
func (db *datastore) GenerateOAuthState(ctx context.Context, provider, clientID string) (string, error) {
|
||||
state := store.Generate62RandomString(24)
|
||||
_, err := db.ExecContext(ctx, "INSERT INTO oauth_client_state (state, used, created_at) VALUES (?, FALSE, NOW())", state)
|
||||
_, err := db.ExecContext(ctx, "INSERT INTO oauth_client_state (state, provider, client_id, used, created_at) VALUES (?, ?, ?, FALSE, NOW())", state, provider, clientID)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("unable to record oauth client state: %w", err)
|
||||
}
|
||||
return state, nil
|
||||
}
|
||||
|
||||
func (db *datastore) ValidateOAuthState(ctx context.Context, state string) error {
|
||||
res, err := db.ExecContext(ctx, "UPDATE oauth_client_state SET used = TRUE WHERE state = ?", state)
|
||||
func (db *datastore) ValidateOAuthState(ctx context.Context, state, provider, clientID string) error {
|
||||
res, err := db.ExecContext(ctx, "UPDATE oauth_client_state SET used = TRUE WHERE state = ? AND provider = ? AND client_id = ?", state, provider, clientID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
|
|
@ -18,7 +18,7 @@ func TestOAuthDatastore(t *testing.T) {
|
|||
driverName: "",
|
||||
}
|
||||
|
||||
state, err := ds.GenerateOAuthState(ctx)
|
||||
state, err := ds.GenerateOAuthState(ctx, "", "")
|
||||
assert.NoError(t, err)
|
||||
assert.Len(t, state, 24)
|
||||
|
||||
|
|
1
go.mod
1
go.mod
|
@ -19,6 +19,7 @@ require (
|
|||
github.com/guregu/null v3.4.0+incompatible
|
||||
github.com/ikeikeikeike/go-sitemap-generator/v2 v2.0.2
|
||||
github.com/jtolds/gls v4.2.1+incompatible // indirect
|
||||
github.com/kr/pretty v0.1.0
|
||||
github.com/kylemcc/twitter-text-go v0.0.0-20180726194232-7f582f6736ec
|
||||
github.com/lunixbochs/vtclean v1.0.0 // indirect
|
||||
github.com/manifoldco/promptui v0.3.2
|
||||
|
|
1
go.sum
1
go.sum
|
@ -64,6 +64,7 @@ github.com/jtolds/gls v4.2.1+incompatible h1:fSuqC+Gmlu6l/ZYAoZzx2pyucC8Xza35fpR
|
|||
github.com/jtolds/gls v4.2.1+incompatible/go.mod h1:QJZ7F/aHp+rZTRtaJ1ow/lLfFfVYBRgL+9YlvaHOwJU=
|
||||
github.com/juju/ansiterm v0.0.0-20180109212912-720a0952cc2a h1:FaWFmfWdAUKbSCtOU2QjDaorUexogfaMgbipgYATUMU=
|
||||
github.com/juju/ansiterm v0.0.0-20180109212912-720a0952cc2a/go.mod h1:UJSiEoRfvx3hP73CvoARgeLjaIOjybY9vj8PUPPFGeU=
|
||||
github.com/kr/pretty v0.1.0 h1:L/CwN0zerZDmRFUapSPitk6f+Q3+0za1rQkzVuMiMFI=
|
||||
github.com/kr/pretty v0.1.0/go.mod h1:dAy3ld7l9f0ibDNOQOHHMYYIIbhfbHSm3C4ZsoJORNo=
|
||||
github.com/kr/pty v1.1.1/go.mod h1:pFQYn66WHrOpPYNljwOMqo10TkYh1fy3cYio2l3bCsQ=
|
||||
github.com/kr/text v0.1.0 h1:45sCR5RtlFHMR4UwH9sdQ5TC8v0qDQCHnXt+kaKSTVE=
|
||||
|
|
181
oauth.go
181
oauth.go
|
@ -2,14 +2,17 @@ package writefreely
|
|||
|
||||
import (
|
||||
"context"
|
||||
"encoding/hex"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"github.com/gorilla/mux"
|
||||
"github.com/gorilla/sessions"
|
||||
"github.com/guregu/null/zero"
|
||||
"github.com/writeas/nerds/store"
|
||||
"github.com/writeas/web-core/auth"
|
||||
"github.com/writeas/web-core/log"
|
||||
"github.com/writeas/writefreely/config"
|
||||
"hash/fnv"
|
||||
"io"
|
||||
"io/ioutil"
|
||||
"net/http"
|
||||
|
@ -55,11 +58,12 @@ type OAuthDatastoreProvider interface {
|
|||
// OAuthDatastore provides a minimal interface of data store methods used in
|
||||
// oauth functionality.
|
||||
type OAuthDatastore interface {
|
||||
GenerateOAuthState(context.Context) (string, error)
|
||||
ValidateOAuthState(context.Context, string) error
|
||||
GetIDForRemoteUser(context.Context, int64) (int64, error)
|
||||
CreateUser(*config.Config, *User, string) error
|
||||
RecordRemoteUserID(context.Context, int64, int64) error
|
||||
ValidateOAuthState(context.Context, string, string, string) error
|
||||
GenerateOAuthState(context.Context, string, string) (string, error)
|
||||
|
||||
CreateUser(*config.Config, *User, string) error
|
||||
GetUserForAuthByID(int64) (*User, error)
|
||||
}
|
||||
|
||||
|
@ -75,8 +79,8 @@ type oauthHandler struct {
|
|||
}
|
||||
|
||||
// buildAuthURL returns a URL used to initiate authentication.
|
||||
func buildAuthURL(db OAuthDatastore, ctx context.Context, clientID, authLocation, callbackURL string) (string, error) {
|
||||
state, err := db.GenerateOAuthState(ctx)
|
||||
func buildAuthURL(db OAuthDatastore, ctx context.Context, provider, clientID, authLocation, callbackURL string) (string, error) {
|
||||
state, err := db.GenerateOAuthState(ctx, provider, clientID)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
@ -95,9 +99,8 @@ func buildAuthURL(db OAuthDatastore, ctx context.Context, clientID, authLocation
|
|||
return u.String(), nil
|
||||
}
|
||||
|
||||
// app *App, w http.ResponseWriter, r *http.Request
|
||||
func (h oauthHandler) viewOauthInit(w http.ResponseWriter, r *http.Request) {
|
||||
location, err := buildAuthURL(h.DB, r.Context(), h.Config.App.OAuthClientID, h.Config.App.OAuthProviderAuthLocation, h.Config.App.OAuthClientCallbackLocation)
|
||||
func (h oauthHandler) viewOauthInitWriteAs(w http.ResponseWriter, r *http.Request) {
|
||||
location, err := buildAuthURL(h.DB, r.Context(), "write.as", h.Config.App.OAuth.WriteAsClientID, h.Config.App.OAuth.WriteAsProviderAuthLocation, h.Config.App.OAuth.WriteAsClientCallbackLocation)
|
||||
if err != nil {
|
||||
failOAuthRequest(w, http.StatusInternalServerError, "could not prepare oauth redirect url")
|
||||
return
|
||||
|
@ -105,94 +108,128 @@ func (h oauthHandler) viewOauthInit(w http.ResponseWriter, r *http.Request) {
|
|||
http.Redirect(w, r, location, http.StatusTemporaryRedirect)
|
||||
}
|
||||
|
||||
func (h oauthHandler) viewOauthCallback(w http.ResponseWriter, r *http.Request) {
|
||||
ctx := r.Context()
|
||||
|
||||
code := r.FormValue("code")
|
||||
state := r.FormValue("state")
|
||||
|
||||
err := h.DB.ValidateOAuthState(ctx, state)
|
||||
func (h oauthHandler) viewOauthInitSlack(w http.ResponseWriter, r *http.Request) {
|
||||
location, err := buildAuthURL(h.DB, r.Context(), "slack", h.Config.App.OAuth.WriteAsClientID, h.Config.App.OAuth.WriteAsProviderAuthLocation, h.Config.App.OAuth.WriteAsClientCallbackLocation)
|
||||
if err != nil {
|
||||
failOAuthRequest(w, http.StatusInternalServerError, err.Error())
|
||||
failOAuthRequest(w, http.StatusInternalServerError, "could not prepare oauth redirect url")
|
||||
return
|
||||
}
|
||||
http.Redirect(w, r, location, http.StatusTemporaryRedirect)
|
||||
}
|
||||
|
||||
tokenResponse, err := h.exchangeOauthCode(ctx, code)
|
||||
if err != nil {
|
||||
failOAuthRequest(w, http.StatusInternalServerError, err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
// Now that we have the access token, let's use it real quick to make sur
|
||||
// it really really works.
|
||||
tokenInfo, err := h.inspectOauthAccessToken(ctx, tokenResponse.AccessToken)
|
||||
if err != nil {
|
||||
failOAuthRequest(w, http.StatusInternalServerError, err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
localUserID, err := h.DB.GetIDForRemoteUser(ctx, tokenInfo.UserID)
|
||||
if err != nil {
|
||||
failOAuthRequest(w, http.StatusInternalServerError, err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
fmt.Println("local user id", localUserID)
|
||||
|
||||
if localUserID == -1 {
|
||||
// We don't have, nor do we want, the password from the origin, so we
|
||||
//create a random string. If the user needs to set a password, they
|
||||
//can do so through the settings page or through the password reset
|
||||
//flow.
|
||||
randPass := store.Generate62RandomString(14)
|
||||
hashedPass, err := auth.HashPass([]byte(randPass))
|
||||
if err != nil {
|
||||
log.ErrorLog.Println(err)
|
||||
failOAuthRequest(w, http.StatusInternalServerError, "unable to create password hash")
|
||||
return
|
||||
func (h oauthHandler) configureRoutes(r *mux.Router) {
|
||||
if h.Config.App.OAuth.Enabled {
|
||||
if h.Config.App.OAuth.WriteAsClientID != "" {
|
||||
callbackHash := oauthProviderHash("write.as", h.Config.App.OAuth.WriteAsClientID)
|
||||
log.InfoLog.Println("write.as oauth callback URL", "/oauth/callback/"+callbackHash)
|
||||
r.HandleFunc("/oauth/write.as", h.viewOauthInitWriteAs).Methods("GET")
|
||||
r.HandleFunc("/oauth/callback/"+callbackHash, h.viewOauthCallback("write.as", h.Config.App.OAuth.WriteAsClientID)).Methods("GET")
|
||||
}
|
||||
newUser := &User{
|
||||
Username: tokenInfo.Username,
|
||||
HashedPass: hashedPass,
|
||||
HasPass: true,
|
||||
Email: zero.NewString("", tokenInfo.Email != ""),
|
||||
Created: time.Now().Truncate(time.Second).UTC(),
|
||||
if h.Config.App.OAuth.SlackClientID != "" {
|
||||
callbackHash := oauthProviderHash("slack", h.Config.App.OAuth.SlackClientID)
|
||||
log.InfoLog.Println("slack oauth callback URL", "/oauth/callback/"+callbackHash)
|
||||
r.HandleFunc("/oauth/slack", h.viewOauthInitSlack).Methods("GET")
|
||||
r.HandleFunc("/oauth/callback/"+callbackHash, h.viewOauthCallback("slack", h.Config.App.OAuth.SlackClientID)).Methods("GET")
|
||||
}
|
||||
}
|
||||
|
||||
err = h.DB.CreateUser(h.Config, newUser, newUser.Username)
|
||||
}
|
||||
|
||||
func oauthProviderHash(provider, clientID string) string {
|
||||
hasher := fnv.New32()
|
||||
return hex.EncodeToString(hasher.Sum([]byte(provider + clientID)))
|
||||
}
|
||||
|
||||
func (h oauthHandler) viewOauthCallback(provider, clientID string) http.HandlerFunc {
|
||||
return func(w http.ResponseWriter, r *http.Request) {
|
||||
ctx := r.Context()
|
||||
|
||||
code := r.FormValue("code")
|
||||
state := r.FormValue("state")
|
||||
|
||||
err := h.DB.ValidateOAuthState(ctx, state, provider, clientID)
|
||||
if err != nil {
|
||||
failOAuthRequest(w, http.StatusInternalServerError, err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
err = h.DB.RecordRemoteUserID(ctx, newUser.ID, tokenInfo.UserID)
|
||||
tokenResponse, err := h.exchangeOauthCode(ctx, code)
|
||||
if err != nil {
|
||||
failOAuthRequest(w, http.StatusInternalServerError, err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
if err := loginOrFail(h.Store, w, r, newUser); err != nil {
|
||||
// Now that we have the access token, let's use it real quick to make sur
|
||||
// it really really works.
|
||||
tokenInfo, err := h.inspectOauthAccessToken(ctx, tokenResponse.AccessToken)
|
||||
if err != nil {
|
||||
failOAuthRequest(w, http.StatusInternalServerError, err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
localUserID, err := h.DB.GetIDForRemoteUser(ctx, tokenInfo.UserID)
|
||||
if err != nil {
|
||||
failOAuthRequest(w, http.StatusInternalServerError, err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
fmt.Println("local user id", localUserID)
|
||||
|
||||
if localUserID == -1 {
|
||||
// We don't have, nor do we want, the password from the origin, so we
|
||||
//create a random string. If the user needs to set a password, they
|
||||
//can do so through the settings page or through the password reset
|
||||
//flow.
|
||||
randPass := store.Generate62RandomString(14)
|
||||
hashedPass, err := auth.HashPass([]byte(randPass))
|
||||
if err != nil {
|
||||
log.ErrorLog.Println(err)
|
||||
failOAuthRequest(w, http.StatusInternalServerError, "unable to create password hash")
|
||||
return
|
||||
}
|
||||
newUser := &User{
|
||||
Username: tokenInfo.Username,
|
||||
HashedPass: hashedPass,
|
||||
HasPass: true,
|
||||
Email: zero.NewString("", tokenInfo.Email != ""),
|
||||
Created: time.Now().Truncate(time.Second).UTC(),
|
||||
}
|
||||
|
||||
err = h.DB.CreateUser(h.Config, newUser, newUser.Username)
|
||||
if err != nil {
|
||||
failOAuthRequest(w, http.StatusInternalServerError, err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
err = h.DB.RecordRemoteUserID(ctx, newUser.ID, tokenInfo.UserID)
|
||||
if err != nil {
|
||||
failOAuthRequest(w, http.StatusInternalServerError, err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
if err := loginOrFail(h.Store, w, r, newUser); err != nil {
|
||||
failOAuthRequest(w, http.StatusInternalServerError, err.Error())
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
user, err := h.DB.GetUserForAuthByID(localUserID)
|
||||
if err != nil {
|
||||
failOAuthRequest(w, http.StatusInternalServerError, err.Error())
|
||||
return
|
||||
}
|
||||
if err = loginOrFail(h.Store, w, r, user); err != nil {
|
||||
failOAuthRequest(w, http.StatusInternalServerError, err.Error())
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
user, err := h.DB.GetUserForAuthByID(localUserID)
|
||||
if err != nil {
|
||||
failOAuthRequest(w, http.StatusInternalServerError, err.Error())
|
||||
return
|
||||
}
|
||||
if err = loginOrFail(h.Store, w, r, user); err != nil {
|
||||
failOAuthRequest(w, http.StatusInternalServerError, err.Error())
|
||||
}
|
||||
}
|
||||
|
||||
func (h oauthHandler) exchangeOauthCode(ctx context.Context, code string) (*TokenResponse, error) {
|
||||
form := url.Values{}
|
||||
form.Add("grant_type", "authorization_code")
|
||||
form.Add("redirect_uri", h.Config.App.OAuthClientCallbackLocation)
|
||||
form.Add("redirect_uri", h.Config.App.OAuth.WriteAsClientCallbackLocation)
|
||||
form.Add("code", code)
|
||||
req, err := http.NewRequest("POST", h.Config.App.OAuthProviderTokenLocation, strings.NewReader(form.Encode()))
|
||||
req, err := http.NewRequest("POST", h.Config.App.OAuth.WriteAsProviderTokenLocation, strings.NewReader(form.Encode()))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
@ -200,7 +237,7 @@ func (h oauthHandler) exchangeOauthCode(ctx context.Context, code string) (*Toke
|
|||
req.Header.Set("User-Agent", "writefreely")
|
||||
req.Header.Set("Accept", "application/json")
|
||||
req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
|
||||
req.SetBasicAuth(h.Config.App.OAuthClientID, h.Config.App.OAuthClientSecret)
|
||||
req.SetBasicAuth(h.Config.App.OAuth.WriteAsClientID, h.Config.App.OAuth.WriteAsClientSecret)
|
||||
|
||||
resp, err := h.HttpClient.Do(req)
|
||||
if err != nil {
|
||||
|
@ -224,7 +261,7 @@ func (h oauthHandler) exchangeOauthCode(ctx context.Context, code string) (*Toke
|
|||
}
|
||||
|
||||
func (h oauthHandler) inspectOauthAccessToken(ctx context.Context, accessToken string) (*InspectResponse, error) {
|
||||
req, err := http.NewRequest("GET", h.Config.App.OAuthProviderInspectLocation, nil)
|
||||
req, err := http.NewRequest("GET", h.Config.App.OAuth.WriteAsProviderInspectLocation, nil)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
|
|
@ -86,9 +86,7 @@ func InitRoutes(apper Apper, r *mux.Router) *mux.Router {
|
|||
DB: apper.App().DB(),
|
||||
Store: apper.App().SessionStore(),
|
||||
}
|
||||
|
||||
write.HandleFunc("/oauth/write.as", oauthHandler.viewOauthInit).Methods("GET")
|
||||
write.HandleFunc("/oauth/callback", oauthHandler.viewOauthCallback).Methods("GET")
|
||||
oauthHandler.configureRoutes(write)
|
||||
|
||||
// Handle logged in user sections
|
||||
me := write.PathPrefix("/me").Subrouter()
|
||||
|
|
Loading…
Reference in a new issue