diff --git a/auth/auth.go b/auth/auth.go index 73d88a9..70b7c0f 100644 --- a/auth/auth.go +++ b/auth/auth.go @@ -3,6 +3,10 @@ package auth import ( "errors" + "goweb/config" + "goweb/ent" + "goweb/ent/user" + "github.com/labstack/echo-contrib/session" "github.com/labstack/echo/v4" "golang.org/x/crypto/bcrypt" @@ -14,28 +18,39 @@ const ( sessionKeyAuthenticated = "authenticated" ) -func Login(c echo.Context, userID int) error { - sess, err := session.Get(sessionName, c) +type Client struct { + config *config.Config + orm *ent.Client +} + +func NewClient(cfg *config.Config, orm *ent.Client) *Client { + return &Client{ + config: cfg, + orm: orm, + } +} + +func (c *Client) Login(ctx echo.Context, userID int) error { + sess, err := session.Get(sessionName, ctx) if err != nil { return err } sess.Values[sessionKeyUserID] = userID sess.Values[sessionKeyAuthenticated] = true - // TODO: max age? - return sess.Save(c.Request(), c.Response()) + return sess.Save(ctx.Request(), ctx.Response()) } -func Logout(c echo.Context) error { - sess, err := session.Get(sessionName, c) +func (c *Client) Logout(ctx echo.Context) error { + sess, err := session.Get(sessionName, ctx) if err != nil { return err } sess.Values[sessionKeyAuthenticated] = false - return sess.Save(c.Request(), c.Response()) + return sess.Save(ctx.Request(), ctx.Response()) } -func GetUserID(c echo.Context) (int, error) { - sess, err := session.Get(sessionName, c) +func (c *Client) GetAuthenticatedUserID(ctx echo.Context) (int, error) { + sess, err := session.Get(sessionName, ctx) if err != nil { return 0, err } @@ -47,7 +62,17 @@ func GetUserID(c echo.Context) (int, error) { return 0, errors.New("user not authenticated") } -func HashPassword(password string) (string, error) { +func (c *Client) GetAuthenticatedUser(ctx echo.Context) (*ent.User, error) { + if userID, err := c.GetAuthenticatedUserID(ctx); err == nil { + return c.orm.User.Query(). + Where(user.ID(userID)). + First(ctx.Request().Context()) + } + + return nil, errors.New("user not authenticated") +} + +func (c *Client) HashPassword(password string) (string, error) { hash, err := bcrypt.GenerateFromPassword([]byte(password), bcrypt.DefaultCost) if err != nil { return "", err @@ -55,6 +80,6 @@ func HashPassword(password string) (string, error) { return string(hash), nil } -func CheckPassword(password, hash string) error { +func (c *Client) CheckPassword(password, hash string) error { return bcrypt.CompareHashAndPassword([]byte(hash), []byte(password)) } diff --git a/container/container.go b/container/container.go index b5e1819..7f4aa68 100644 --- a/container/container.go +++ b/container/container.go @@ -5,6 +5,7 @@ import ( "database/sql" "fmt" + "goweb/auth" "goweb/mail" "entgo.io/ent/dialect" @@ -27,6 +28,7 @@ type Container struct { Database *sql.DB ORM *ent.Client Mail *mail.Client + Auth *auth.Client } func NewContainer() *Container { @@ -37,6 +39,7 @@ func NewContainer() *Container { c.initDatabase() c.initORM() c.initMail() + c.initAuth() return c } @@ -125,3 +128,7 @@ func (c *Container) initMail() { panic(fmt.Sprintf("failed to create mail client: %v", err)) } } + +func (c *Container) initAuth() { + c.Auth = auth.NewClient(c.Config, c.ORM) +} diff --git a/controller/page.go b/controller/page.go index acba57f..bc948c3 100644 --- a/controller/page.go +++ b/controller/page.go @@ -5,7 +5,7 @@ import ( "net/http" "time" - "goweb/auth" + "goweb/context" "goweb/msg" "goweb/pager" @@ -62,7 +62,7 @@ func NewPage(c echo.Context) Page { p.CSRF = csrf.(string) } - if _, err := auth.GetUserID(c); err == nil { + if u := c.Get(context.AuthenticatedUserKey); u != nil { p.IsAuth = true } diff --git a/middleware/auth.go b/middleware/auth.go index 279da5a..ff02f27 100644 --- a/middleware/auth.go +++ b/middleware/auth.go @@ -6,25 +6,20 @@ import ( "goweb/auth" "goweb/context" "goweb/ent" - "goweb/ent/user" "github.com/labstack/echo/v4" ) -func LoadAuthenticatedUser(orm *ent.Client) echo.MiddlewareFunc { +func LoadAuthenticatedUser(authClient *auth.Client) echo.MiddlewareFunc { return func(next echo.HandlerFunc) echo.HandlerFunc { return func(c echo.Context) error { - if userID, err := auth.GetUserID(c); err == nil { - u, err := orm.User.Query(). - Where(user.ID(userID)). - First(c.Request().Context()) - + if user, err := authClient.GetAuthenticatedUser(c); err == nil { switch err.(type) { case *ent.NotFoundError: - c.Logger().Debug("auth user not found: %d", userID) + c.Logger().Debug("auth user not found") case nil: - c.Set(context.AuthenticatedUserKey, u) - c.Logger().Info("auth user loaded in to context: %d", userID) + c.Set(context.AuthenticatedUserKey, user) + c.Logger().Info("auth user loaded in to context: %d", user.ID) default: c.Logger().Errorf("error querying for authenticated user: %v", err) } diff --git a/routes/forgot_password.go b/routes/forgot_password.go index c57a34d..50346eb 100644 --- a/routes/forgot_password.go +++ b/routes/forgot_password.go @@ -3,6 +3,8 @@ package routes import ( "goweb/context" "goweb/controller" + "goweb/ent" + "goweb/ent/user" "goweb/msg" "github.com/labstack/echo/v4" @@ -39,6 +41,11 @@ func (f *ForgotPassword) Post(c echo.Context) error { return f.Get(c) } + succeed := func() error { + msg.Success(c, "An email containing a link to reset your password will be sent to this address if it exists in our system.") + return f.Get(c) + } + // Parse the form values form := new(ForgotPasswordForm) if err := c.Bind(form); err != nil { @@ -52,7 +59,25 @@ func (f *ForgotPassword) Post(c echo.Context) error { return f.Get(c) } + // Attempt to load the user + u, err := f.Container.ORM.User. + Query(). + Where(user.Email(form.Email)). + First(c.Request().Context()) + + if err != nil { + switch err.(type) { + case *ent.NotFoundError: + return succeed() + default: + return fail("error querying user during forgot password", err) + } + } + // TODO: generate and email a token + if u != nil { + + } return f.Redirect(c, "home") } diff --git a/routes/login.go b/routes/login.go index d4bec9f..3da625c 100644 --- a/routes/login.go +++ b/routes/login.go @@ -3,7 +3,6 @@ package routes import ( "fmt" - "goweb/auth" "goweb/context" "goweb/controller" "goweb/ent" @@ -75,14 +74,14 @@ func (l *Login) Post(c echo.Context) error { } // Check if the password is correct - err = auth.CheckPassword(form.Password, u.Password) + err = l.Container.Auth.CheckPassword(form.Password, u.Password) if err != nil { msg.Danger(c, "Invalid credentials. Please try again.") return l.Get(c) } // Log the user in - err = auth.Login(c, u.ID) + err = l.Container.Auth.Login(c, u.ID) if err != nil { return fail("unable to log in user", err) } diff --git a/routes/logout.go b/routes/logout.go index bcf7f6d..2338339 100644 --- a/routes/logout.go +++ b/routes/logout.go @@ -1,7 +1,6 @@ package routes import ( - "goweb/auth" "goweb/controller" "goweb/msg" @@ -13,7 +12,7 @@ type Logout struct { } func (l *Logout) Get(c echo.Context) error { - if err := auth.Logout(c); err == nil { + if err := l.Container.Auth.Logout(c); err == nil { msg.Success(c, "You have been logged out successfully.") } return l.Redirect(c, "home") diff --git a/routes/register.go b/routes/register.go index b8b7a16..264223b 100644 --- a/routes/register.go +++ b/routes/register.go @@ -1,7 +1,6 @@ package routes import ( - "goweb/auth" "goweb/context" "goweb/controller" "goweb/msg" @@ -57,7 +56,7 @@ func (r *Register) Post(c echo.Context) error { } // Hash the password - pwHash, err := auth.HashPassword(form.Password) + pwHash, err := r.Container.Auth.HashPassword(form.Password) if err != nil { return fail("unable to hash password", err) } @@ -76,7 +75,7 @@ func (r *Register) Post(c echo.Context) error { c.Logger().Infof("user created: %s", u.Name) - err = auth.Login(c, u.ID) + err = r.Container.Auth.Login(c, u.ID) if err != nil { c.Logger().Errorf("unable to log in: %v", err) msg.Info(c, "Your account has been created.") diff --git a/routes/router.go b/routes/router.go index 617913c..ff85605 100644 --- a/routes/router.go +++ b/routes/router.go @@ -57,7 +57,7 @@ func BuildRouter(c *container.Container) { echomw.CSRFWithConfig(echomw.CSRFConfig{ TokenLookup: "form:csrf", }), - middleware.LoadAuthenticatedUser(c.ORM), + middleware.LoadAuthenticatedUser(c.Auth), ) // Base controller