diff --git a/clientapi/routing/register.go b/clientapi/routing/register.go index af2e99ed0..7d84f2494 100644 --- a/clientapi/routing/register.go +++ b/clientapi/routing/register.go @@ -31,6 +31,7 @@ import ( "github.com/matrix-org/dendrite/internal/eventutil" "github.com/matrix-org/dendrite/setup/config" + "github.com/tidwall/gjson" "github.com/matrix-org/gomatrixserverlib" "github.com/matrix-org/gomatrixserverlib/tokens" @@ -525,22 +526,37 @@ func Register( userAPI userapi.UserRegisterAPI, cfg *config.ClientAPI, ) util.JSONResponse { + defer req.Body.Close() // nolint: errcheck + reqBody, err := ioutil.ReadAll(req.Body) + if err != nil { + return util.JSONResponse{ + Code: http.StatusBadRequest, + JSON: jsonerror.NotJSON("Unable to read request body"), + } + } + var r registerRequest - resErr := httputil.UnmarshalJSONRequest(req, &r) - if resErr != nil { + sessionID := gjson.GetBytes(reqBody, "auth.session").String() + if sessionID == "" { + // Generate a new, random session ID + sessionID = util.RandomString(sessionIDLength) + } else if data, ok := sessions.getParams(sessionID); ok { + // Use the parameters from the session as our defaults. + // Some of these might end up being overwritten if the + // values are specified again in the request body. + r.Username = data.Username + r.Password = data.Password + r.DeviceID = data.DeviceID + r.InitialDisplayName = data.InitialDisplayName + r.InhibitLogin = data.InhibitLogin + } + if resErr := httputil.UnmarshalJSON(reqBody, &r); resErr != nil { return *resErr } if req.URL.Query().Get("kind") == "guest" { return handleGuestRegistration(req, r, cfg, userAPI) } - // Retrieve or generate the sessionID - sessionID := r.Auth.Session - if sessionID == "" { - // Generate a new, random session ID - sessionID = util.RandomString(sessionIDLength) - } - // Don't allow numeric usernames less than MAX_INT64. if _, err := strconv.ParseInt(r.Username, 10, 64); err == nil { return util.JSONResponse{ @@ -568,7 +584,7 @@ func Register( case r.Type == authtypes.LoginTypeApplicationService && accessTokenErr == nil: // Spec-compliant case (the access_token is specified and the login type // is correctly set, so it's an appservice registration) - if resErr = validateApplicationServiceUsername(r.Username); resErr != nil { + if resErr := validateApplicationServiceUsername(r.Username); resErr != nil { return *resErr } case accessTokenErr == nil: @@ -581,11 +597,11 @@ func Register( default: // Spec-compliant case (neither the access_token nor the login type are // specified, so it's a normal user registration) - if resErr = validateUsername(r.Username); resErr != nil { + if resErr := validateUsername(r.Username); resErr != nil { return *resErr } } - if resErr = validatePassword(r.Password); resErr != nil { + if resErr := validatePassword(r.Password); resErr != nil { return *resErr } @@ -835,24 +851,17 @@ func completeRegistration( } }() - if data, ok := sessions.getParams(sessionID); ok { - username = data.Username - password = data.Password - deviceID = data.DeviceID - displayName = data.InitialDisplayName - inhibitLogin = data.InhibitLogin - } if username == "" { return util.JSONResponse{ Code: http.StatusBadRequest, - JSON: jsonerror.BadJSON("missing username"), + JSON: jsonerror.MissingArgument("Missing username"), } } // Blank passwords are only allowed by registered application services if password == "" && appserviceID == "" { return util.JSONResponse{ Code: http.StatusBadRequest, - JSON: jsonerror.BadJSON("missing password"), + JSON: jsonerror.MissingArgument("Missing password"), } } var accRes userapi.PerformAccountCreationResponse