From ed9413ee17ed8a05e8f59d42b2be7da509bee0b0 Mon Sep 17 00:00:00 2001 From: mikestefanello Date: Thu, 16 Dec 2021 21:27:52 -0500 Subject: [PATCH] Finished password reset workflow. Remove all password tokens upon successful reset. --- auth/auth.go | 9 ++++++ context/context.go | 1 + middleware/auth.go | 11 ++++--- middleware/entity.go | 41 ++++++++++++++++++++++++ routes/forgot_password.go | 16 +++++----- routes/login.go | 15 +++++---- routes/reset_password.go | 65 +++++++++++++++++++-------------------- routes/router.go | 5 ++- 8 files changed, 107 insertions(+), 56 deletions(-) create mode 100644 middleware/entity.go diff --git a/auth/auth.go b/auth/auth.go index 3421b53..b99580f 100644 --- a/auth/auth.go +++ b/auth/auth.go @@ -152,6 +152,15 @@ func (c *Client) GetValidPasswordToken(ctx echo.Context, token string, userID in return nil, InvalidTokenError{} } +func (c *Client) DeletePasswordTokens(ctx echo.Context, userID int) error { + _, err := c.orm.PasswordToken. + Delete(). + Where(passwordtoken.HasUserWith(user.ID(userID))). + Exec(ctx.Request().Context()) + + return err +} + func (c *Client) RandomToken(length int) (string, error) { b := make([]byte, length) if _, err := rand.Read(b); err != nil { diff --git a/context/context.go b/context/context.go index b1c5290..92a3b96 100644 --- a/context/context.go +++ b/context/context.go @@ -2,6 +2,7 @@ package context const ( AuthenticatedUserKey = "auth_user" + UserKey = "user" FormKey = "form" PasswordTokenKey = "password_token" ) diff --git a/middleware/auth.go b/middleware/auth.go index 58a834d..12c46e3 100644 --- a/middleware/auth.go +++ b/middleware/auth.go @@ -2,7 +2,6 @@ package middleware import ( "net/http" - "strconv" "goweb/auth" "goweb/context" @@ -35,14 +34,16 @@ func LoadAuthenticatedUser(authClient *auth.Client) echo.MiddlewareFunc { func LoadValidPasswordToken(authClient *auth.Client) echo.MiddlewareFunc { return func(next echo.HandlerFunc) echo.HandlerFunc { return func(c echo.Context) error { - userID, err := strconv.Atoi(c.Param("user")) - if err != nil { - return echo.NewHTTPError(http.StatusNotFound, "Not found") + var usr *ent.User + + if c.Get(context.UserKey) == nil { + return echo.NewHTTPError(http.StatusInternalServerError, "Internal server error") } + usr = c.Get(context.UserKey).(*ent.User) tokenParam := c.Param("password_token") + token, err := authClient.GetValidPasswordToken(c, tokenParam, usr.ID) - token, err := authClient.GetValidPasswordToken(c, tokenParam, userID) switch err.(type) { case nil: case auth.InvalidTokenError: diff --git a/middleware/entity.go b/middleware/entity.go new file mode 100644 index 0000000..e7e4608 --- /dev/null +++ b/middleware/entity.go @@ -0,0 +1,41 @@ +package middleware + +import ( + "net/http" + "strconv" + + "goweb/context" + "goweb/ent" + "goweb/ent/user" + + "github.com/labstack/echo/v4" +) + +func LoadUser(orm *ent.Client) echo.MiddlewareFunc { + return func(next echo.HandlerFunc) echo.HandlerFunc { + return func(c echo.Context) error { + userID, err := strconv.Atoi(c.Param("user")) + if err != nil { + return echo.NewHTTPError(http.StatusNotFound, "Not found") + } + + u, err := orm.User. + Query(). + Where(user.ID(userID)). + Only(c.Request().Context()) + + switch err.(type) { + case nil: + case *ent.NotFoundError: + return echo.NewHTTPError(http.StatusNotFound, "Not found") + default: + c.Logger().Error(err) + return echo.NewHTTPError(http.StatusInternalServerError, "Internal server error") + } + + c.Set(context.UserKey, u) + + return next(c) + } + } +} diff --git a/routes/forgot_password.go b/routes/forgot_password.go index a7223ce..f7a19d3 100644 --- a/routes/forgot_password.go +++ b/routes/forgot_password.go @@ -68,13 +68,12 @@ func (f *ForgotPassword) Post(c echo.Context) error { Where(user.Email(form.Email)). Only(c.Request().Context()) - if err != nil { - switch err.(type) { - case *ent.NotFoundError: - return succeed() - default: - return fail("error querying user during forgot password", err) - } + switch err.(type) { + case *ent.NotFoundError: + return succeed() + case nil: + default: + return fail("error querying user during forgot password", err) } // Generate the token @@ -85,7 +84,8 @@ func (f *ForgotPassword) Post(c echo.Context) error { c.Logger().Infof("generated password reset token for user %d", u.ID) // Email the user - err = f.Container.Mail.Send(c, u.Email, fmt.Sprintf("Go here to reset your password: %s", token)) // TODO: route + // TODO: better email + err = f.Container.Mail.Send(c, u.Email, fmt.Sprintf("Go here to reset your password: %s", c.Echo().Reverse("reset_password", u.ID, token))) if err != nil { return fail("error sending password reset email", err) } diff --git a/routes/login.go b/routes/login.go index 277a3b8..f45520f 100644 --- a/routes/login.go +++ b/routes/login.go @@ -63,14 +63,13 @@ func (l *Login) Post(c echo.Context) error { Where(user.Email(form.Email)). Only(c.Request().Context()) - if err != nil { - switch err.(type) { - case *ent.NotFoundError: - msg.Danger(c, "Invalid credentials. Please try again.") - return l.Get(c) - default: - return fail("error querying user during login", err) - } + switch err.(type) { + case *ent.NotFoundError: + msg.Danger(c, "Invalid credentials. Please try again.") + return l.Get(c) + case nil: + default: + return fail("error querying user during login", err) } // Check if the password is correct diff --git a/routes/reset_password.go b/routes/reset_password.go index cee3df5..4024da5 100644 --- a/routes/reset_password.go +++ b/routes/reset_password.go @@ -1,7 +1,10 @@ package routes import ( + "goweb/context" "goweb/controller" + "goweb/ent" + "goweb/ent/user" "goweb/msg" "github.com/labstack/echo/v4" @@ -33,13 +36,8 @@ func (r *ResetPassword) Post(c echo.Context) error { return r.Get(c) } - succeed := func() error { - msg.Success(c, "Your password has been updated.") - return r.Redirect(c, "login") - } - // Parse the form values - form := new(ResetPassword) + form := new(ResetPasswordForm) if err := c.Bind(form); err != nil { return fail("unable to parse forgot password form", err) } @@ -50,33 +48,32 @@ func (r *ResetPassword) Post(c echo.Context) error { return r.Get(c) } - // Attempt to load the user - //u, err := f.Container.ORM.User. - // Query(). - // Where(user.Email(form.Email)). - // First(c.Request().Context()) - // - //if err != nil { - // switch err.(type) { - // case *ent.NotFoundError: - // return succeed() - // default: - // return fail("error querying user during forgot password", err) - // } - //} - // - //// Generate the token - //token, _, err := f.Container.Auth.GeneratePasswordResetToken(c, u.ID) - //if err != nil { - // return fail("error generating password reset token", err) - //} - //c.Logger().Infof("generated password reset token for user %d", u.ID) - // - //// Email the user - //err = f.Container.Mail.Send(c, u.Email, fmt.Sprintf("Go here to reset your password: %s", token)) // TODO: route - //if err != nil { - // return fail("error sending password reset email", err) - //} + // Hash the new password + hash, err := r.Container.Auth.HashPassword(form.Password) + if err != nil { + return fail("unable to hash password", err) + } - return succeed() + // Get the requesting user + usr := c.Get(context.UserKey).(*ent.User) + + // Update the user + _, err = r.Container.ORM.User. + Update(). + SetPassword(hash). + Where(user.ID(usr.ID)). + Save(c.Request().Context()) + + if err != nil { + return fail("unable to update password", err) + } + + // Delete all password tokens for this user + err = r.Container.Auth.DeletePasswordTokens(c, usr.ID) + if err != nil { + return fail("unable to delete password tokens", err) + } + + msg.Success(c, "Your password has been updated.") + return r.Redirect(c, "login") } diff --git a/routes/router.go b/routes/router.go index c5b2d27..825523f 100644 --- a/routes/router.go +++ b/routes/router.go @@ -104,7 +104,10 @@ func userRoutes(c *container.Container, g *echo.Group, ctr controller.Controller noAuth.GET("/password", forgot.Get).Name = "forgot_password" noAuth.POST("/password", forgot.Post).Name = "forgot_password.post" - resetGroup := noAuth.Group("/password/reset", middleware.LoadValidPasswordToken(c.Auth)) + resetGroup := noAuth.Group("/password/reset", + middleware.LoadUser(c.ORM), + middleware.LoadValidPasswordToken(c.Auth), + ) reset := ResetPassword{Controller: ctr} resetGroup.GET("/token/:user/:password_token", reset.Get).Name = "reset_password" resetGroup.POST("/token/:user/:password_token", reset.Post).Name = "reset_password.post"