mirror of
https://github.com/writefreely/writefreely
synced 2024-11-10 11:24:13 +00:00
Pass OAuth requests through new OAuth handler
This gives us our standard logging and passes around errors with impart.HTTPError. Ref T705
This commit is contained in:
parent
4266154749
commit
af23e28d05
5 changed files with 72 additions and 40 deletions
2
go.mod
2
go.mod
|
@ -38,7 +38,7 @@ require (
|
|||
github.com/writeas/go-strip-markdown v2.0.1+incompatible
|
||||
github.com/writeas/go-webfinger v0.0.0-20190106002315-85cf805c86d2
|
||||
github.com/writeas/httpsig v1.0.0
|
||||
github.com/writeas/impart v1.1.0
|
||||
github.com/writeas/impart v1.1.1-0.20191230230525-d3c45ced010d
|
||||
github.com/writeas/monday v0.0.0-20181024183321-54a7dd579219
|
||||
github.com/writeas/nerds v1.0.0
|
||||
github.com/writeas/saturday v1.7.1
|
||||
|
|
2
go.sum
2
go.sum
|
@ -122,6 +122,8 @@ github.com/writeas/httpsig v1.0.0 h1:peIAoIA3DmlP8IG8tMNZqI4YD1uEnWBmkcC9OFPjt3A
|
|||
github.com/writeas/httpsig v1.0.0/go.mod h1:7ClMGSrSVXJbmiLa17bZ1LrG1oibGZmUMlh3402flPY=
|
||||
github.com/writeas/impart v1.1.0 h1:nPnoO211VscNkp/gnzir5UwCDEvdHThL5uELU60NFSE=
|
||||
github.com/writeas/impart v1.1.0/go.mod h1:g0MpxdnTOHHrl+Ca/2oMXUHJ0PcRAEWtkCzYCJUXC9Y=
|
||||
github.com/writeas/impart v1.1.1-0.20191230230525-d3c45ced010d h1:PK7DOj3JE6MGf647esPrKzXEHFjGWX2hl22uX79ixaE=
|
||||
github.com/writeas/impart v1.1.1-0.20191230230525-d3c45ced010d/go.mod h1:g0MpxdnTOHHrl+Ca/2oMXUHJ0PcRAEWtkCzYCJUXC9Y=
|
||||
github.com/writeas/monday v0.0.0-20181024183321-54a7dd579219 h1:baEp0631C8sT2r/hqwypIw2snCFZa6h7U6TojoLHu/c=
|
||||
github.com/writeas/monday v0.0.0-20181024183321-54a7dd579219/go.mod h1:NyM35ayknT7lzO6O/1JpfgGyv+0W9Z9q7aE0J8bXxfQ=
|
||||
github.com/writeas/nerds v1.0.0 h1:ZzRcCN+Sr3MWID7o/x1cr1ZbLvdpej9Y1/Ho+JKlqxo=
|
||||
|
|
50
handle.go
50
handle.go
|
@ -549,6 +549,37 @@ func (h *Handler) All(f handlerFunc) http.HandlerFunc {
|
|||
}
|
||||
}
|
||||
|
||||
func (h *Handler) OAuth(f handlerFunc) http.HandlerFunc {
|
||||
return func(w http.ResponseWriter, r *http.Request) {
|
||||
h.handleOAuthError(w, r, func() error {
|
||||
// TODO: return correct "success" status
|
||||
status := 200
|
||||
start := time.Now()
|
||||
|
||||
defer func() {
|
||||
if e := recover(); e != nil {
|
||||
log.Error("%s:\n%s", e, debug.Stack())
|
||||
impart.WriteError(w, impart.HTTPError{http.StatusInternalServerError, "Something didn't work quite right."})
|
||||
status = 500
|
||||
}
|
||||
|
||||
log.Info(h.app.ReqLog(r, status, time.Since(start)))
|
||||
}()
|
||||
|
||||
err := f(h.app.App(), w, r)
|
||||
if err != nil {
|
||||
if err, ok := err.(impart.HTTPError); ok {
|
||||
status = err.Status
|
||||
} else {
|
||||
status = 500
|
||||
}
|
||||
}
|
||||
|
||||
return err
|
||||
}())
|
||||
}
|
||||
}
|
||||
|
||||
func (h *Handler) AllReader(f handlerFunc) http.HandlerFunc {
|
||||
return func(w http.ResponseWriter, r *http.Request) {
|
||||
h.handleError(w, r, func() error {
|
||||
|
@ -779,6 +810,25 @@ func (h *Handler) handleError(w http.ResponseWriter, r *http.Request, err error)
|
|||
h.errors.InternalServerError.ExecuteTemplate(w, "base", pageForReq(h.app.App(), r))
|
||||
}
|
||||
|
||||
func (h *Handler) handleOAuthError(w http.ResponseWriter, r *http.Request, err error) {
|
||||
if err == nil {
|
||||
return
|
||||
}
|
||||
|
||||
if err, ok := err.(impart.HTTPError); ok {
|
||||
if err.Status >= 300 && err.Status < 400 {
|
||||
sendRedirect(w, err.Status, err.Message)
|
||||
return
|
||||
}
|
||||
|
||||
impart.WriteOAuthError(w, err)
|
||||
return
|
||||
}
|
||||
|
||||
impart.WriteOAuthError(w, impart.HTTPError{http.StatusInternalServerError, "This is an unhelpful error message for a miscellaneous internal error."})
|
||||
return
|
||||
}
|
||||
|
||||
func correctPageFromLoginAttempt(r *http.Request) string {
|
||||
to := r.FormValue("to")
|
||||
if to == "" {
|
||||
|
|
54
oauth.go
54
oauth.go
|
@ -6,6 +6,7 @@ import (
|
|||
"fmt"
|
||||
"github.com/gorilla/sessions"
|
||||
"github.com/guregu/null/zero"
|
||||
"github.com/writeas/impart"
|
||||
"github.com/writeas/nerds/store"
|
||||
"github.com/writeas/web-core/auth"
|
||||
"github.com/writeas/web-core/log"
|
||||
|
@ -96,16 +97,15 @@ func buildAuthURL(db OAuthDatastore, ctx context.Context, clientID, authLocation
|
|||
}
|
||||
|
||||
// app *App, w http.ResponseWriter, r *http.Request
|
||||
func (h oauthHandler) viewOauthInit(w http.ResponseWriter, r *http.Request) {
|
||||
func (h oauthHandler) viewOauthInit(app *App, w http.ResponseWriter, r *http.Request) error {
|
||||
location, err := buildAuthURL(h.DB, r.Context(), h.Config.App.OAuthClientID, h.Config.App.OAuthProviderAuthLocation, h.Config.App.OAuthClientCallbackLocation)
|
||||
if err != nil {
|
||||
failOAuthRequest(w, http.StatusInternalServerError, "could not prepare oauth redirect url")
|
||||
return
|
||||
return impart.HTTPError{http.StatusInternalServerError, "could not prepare oauth redirect url"}
|
||||
}
|
||||
http.Redirect(w, r, location, http.StatusTemporaryRedirect)
|
||||
return impart.HTTPError{http.StatusTemporaryRedirect, location}
|
||||
}
|
||||
|
||||
func (h oauthHandler) viewOauthCallback(w http.ResponseWriter, r *http.Request) {
|
||||
func (h oauthHandler) viewOauthCallback(app *App, w http.ResponseWriter, r *http.Request) error {
|
||||
ctx := r.Context()
|
||||
|
||||
code := r.FormValue("code")
|
||||
|
@ -113,28 +113,24 @@ func (h oauthHandler) viewOauthCallback(w http.ResponseWriter, r *http.Request)
|
|||
|
||||
err := h.DB.ValidateOAuthState(ctx, state)
|
||||
if err != nil {
|
||||
failOAuthRequest(w, http.StatusInternalServerError, err.Error())
|
||||
return
|
||||
return impart.HTTPError{http.StatusInternalServerError, err.Error()}
|
||||
}
|
||||
|
||||
tokenResponse, err := h.exchangeOauthCode(ctx, code)
|
||||
if err != nil {
|
||||
failOAuthRequest(w, http.StatusInternalServerError, err.Error())
|
||||
return
|
||||
return impart.HTTPError{http.StatusInternalServerError, err.Error()}
|
||||
}
|
||||
|
||||
// Now that we have the access token, let's use it real quick to make sur
|
||||
// it really really works.
|
||||
tokenInfo, err := h.inspectOauthAccessToken(ctx, tokenResponse.AccessToken)
|
||||
if err != nil {
|
||||
failOAuthRequest(w, http.StatusInternalServerError, err.Error())
|
||||
return
|
||||
return impart.HTTPError{http.StatusInternalServerError, err.Error()}
|
||||
}
|
||||
|
||||
localUserID, err := h.DB.GetIDForRemoteUser(ctx, tokenInfo.UserID)
|
||||
if err != nil {
|
||||
failOAuthRequest(w, http.StatusInternalServerError, err.Error())
|
||||
return
|
||||
return impart.HTTPError{http.StatusInternalServerError, err.Error()}
|
||||
}
|
||||
|
||||
fmt.Println("local user id", localUserID)
|
||||
|
@ -148,8 +144,7 @@ func (h oauthHandler) viewOauthCallback(w http.ResponseWriter, r *http.Request)
|
|||
hashedPass, err := auth.HashPass([]byte(randPass))
|
||||
if err != nil {
|
||||
log.ErrorLog.Println(err)
|
||||
failOAuthRequest(w, http.StatusInternalServerError, "unable to create password hash")
|
||||
return
|
||||
return impart.HTTPError{http.StatusInternalServerError, "unable to create password hash"}
|
||||
}
|
||||
newUser := &User{
|
||||
Username: tokenInfo.Username,
|
||||
|
@ -161,30 +156,28 @@ func (h oauthHandler) viewOauthCallback(w http.ResponseWriter, r *http.Request)
|
|||
|
||||
err = h.DB.CreateUser(h.Config, newUser, newUser.Username)
|
||||
if err != nil {
|
||||
failOAuthRequest(w, http.StatusInternalServerError, err.Error())
|
||||
return
|
||||
return impart.HTTPError{http.StatusInternalServerError, err.Error()}
|
||||
}
|
||||
|
||||
err = h.DB.RecordRemoteUserID(ctx, newUser.ID, tokenInfo.UserID)
|
||||
if err != nil {
|
||||
failOAuthRequest(w, http.StatusInternalServerError, err.Error())
|
||||
return
|
||||
return impart.HTTPError{http.StatusInternalServerError, err.Error()}
|
||||
}
|
||||
|
||||
if err := loginOrFail(h.Store, w, r, newUser); err != nil {
|
||||
failOAuthRequest(w, http.StatusInternalServerError, err.Error())
|
||||
return impart.HTTPError{http.StatusInternalServerError, err.Error()}
|
||||
}
|
||||
return
|
||||
return nil
|
||||
}
|
||||
|
||||
user, err := h.DB.GetUserForAuthByID(localUserID)
|
||||
if err != nil {
|
||||
failOAuthRequest(w, http.StatusInternalServerError, err.Error())
|
||||
return
|
||||
return impart.HTTPError{http.StatusInternalServerError, err.Error()}
|
||||
}
|
||||
if err = loginOrFail(h.Store, w, r, user); err != nil {
|
||||
failOAuthRequest(w, http.StatusInternalServerError, err.Error())
|
||||
return impart.HTTPError{http.StatusInternalServerError, err.Error()}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (h oauthHandler) exchangeOauthCode(ctx context.Context, code string) (*TokenResponse, error) {
|
||||
|
@ -265,16 +258,3 @@ func loginOrFail(store sessions.Store, w http.ResponseWriter, r *http.Request, u
|
|||
http.Redirect(w, r, "/", http.StatusTemporaryRedirect)
|
||||
return nil
|
||||
}
|
||||
|
||||
// failOAuthRequest is an HTTP handler helper that formats returned error
|
||||
// messages.
|
||||
func failOAuthRequest(w http.ResponseWriter, statusCode int, message string) {
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
w.WriteHeader(statusCode)
|
||||
err := json.NewEncoder(w).Encode(map[string]interface{}{
|
||||
"error": message,
|
||||
})
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
}
|
||||
|
|
|
@ -87,8 +87,8 @@ func InitRoutes(apper Apper, r *mux.Router) *mux.Router {
|
|||
Store: apper.App().SessionStore(),
|
||||
}
|
||||
|
||||
write.HandleFunc("/oauth/write.as", oauthHandler.viewOauthInit).Methods("GET")
|
||||
write.HandleFunc("/oauth/callback", oauthHandler.viewOauthCallback).Methods("GET")
|
||||
write.HandleFunc("/oauth/write.as", handler.OAuth(oauthHandler.viewOauthInit)).Methods("GET")
|
||||
write.HandleFunc("/oauth/callback", handler.OAuth(oauthHandler.viewOauthCallback)).Methods("GET")
|
||||
|
||||
// Handle logged in user sections
|
||||
me := write.PathPrefix("/me").Subrouter()
|
||||
|
|
Loading…
Reference in a new issue