Added auth to the container.

This commit is contained in:
mikestefanello 2021-12-15 09:29:43 -05:00
parent c9d50cb3d4
commit a33a76f8bc
9 changed files with 81 additions and 32 deletions

View File

@ -3,6 +3,10 @@ package auth
import ( import (
"errors" "errors"
"goweb/config"
"goweb/ent"
"goweb/ent/user"
"github.com/labstack/echo-contrib/session" "github.com/labstack/echo-contrib/session"
"github.com/labstack/echo/v4" "github.com/labstack/echo/v4"
"golang.org/x/crypto/bcrypt" "golang.org/x/crypto/bcrypt"
@ -14,28 +18,39 @@ const (
sessionKeyAuthenticated = "authenticated" sessionKeyAuthenticated = "authenticated"
) )
func Login(c echo.Context, userID int) error { type Client struct {
sess, err := session.Get(sessionName, c) config *config.Config
orm *ent.Client
}
func NewClient(cfg *config.Config, orm *ent.Client) *Client {
return &Client{
config: cfg,
orm: orm,
}
}
func (c *Client) Login(ctx echo.Context, userID int) error {
sess, err := session.Get(sessionName, ctx)
if err != nil { if err != nil {
return err return err
} }
sess.Values[sessionKeyUserID] = userID sess.Values[sessionKeyUserID] = userID
sess.Values[sessionKeyAuthenticated] = true sess.Values[sessionKeyAuthenticated] = true
// TODO: max age? return sess.Save(ctx.Request(), ctx.Response())
return sess.Save(c.Request(), c.Response())
} }
func Logout(c echo.Context) error { func (c *Client) Logout(ctx echo.Context) error {
sess, err := session.Get(sessionName, c) sess, err := session.Get(sessionName, ctx)
if err != nil { if err != nil {
return err return err
} }
sess.Values[sessionKeyAuthenticated] = false sess.Values[sessionKeyAuthenticated] = false
return sess.Save(c.Request(), c.Response()) return sess.Save(ctx.Request(), ctx.Response())
} }
func GetUserID(c echo.Context) (int, error) { func (c *Client) GetAuthenticatedUserID(ctx echo.Context) (int, error) {
sess, err := session.Get(sessionName, c) sess, err := session.Get(sessionName, ctx)
if err != nil { if err != nil {
return 0, err return 0, err
} }
@ -47,7 +62,17 @@ func GetUserID(c echo.Context) (int, error) {
return 0, errors.New("user not authenticated") return 0, errors.New("user not authenticated")
} }
func HashPassword(password string) (string, error) { func (c *Client) GetAuthenticatedUser(ctx echo.Context) (*ent.User, error) {
if userID, err := c.GetAuthenticatedUserID(ctx); err == nil {
return c.orm.User.Query().
Where(user.ID(userID)).
First(ctx.Request().Context())
}
return nil, errors.New("user not authenticated")
}
func (c *Client) HashPassword(password string) (string, error) {
hash, err := bcrypt.GenerateFromPassword([]byte(password), bcrypt.DefaultCost) hash, err := bcrypt.GenerateFromPassword([]byte(password), bcrypt.DefaultCost)
if err != nil { if err != nil {
return "", err return "", err
@ -55,6 +80,6 @@ func HashPassword(password string) (string, error) {
return string(hash), nil return string(hash), nil
} }
func CheckPassword(password, hash string) error { func (c *Client) CheckPassword(password, hash string) error {
return bcrypt.CompareHashAndPassword([]byte(hash), []byte(password)) return bcrypt.CompareHashAndPassword([]byte(hash), []byte(password))
} }

View File

@ -5,6 +5,7 @@ import (
"database/sql" "database/sql"
"fmt" "fmt"
"goweb/auth"
"goweb/mail" "goweb/mail"
"entgo.io/ent/dialect" "entgo.io/ent/dialect"
@ -27,6 +28,7 @@ type Container struct {
Database *sql.DB Database *sql.DB
ORM *ent.Client ORM *ent.Client
Mail *mail.Client Mail *mail.Client
Auth *auth.Client
} }
func NewContainer() *Container { func NewContainer() *Container {
@ -37,6 +39,7 @@ func NewContainer() *Container {
c.initDatabase() c.initDatabase()
c.initORM() c.initORM()
c.initMail() c.initMail()
c.initAuth()
return c return c
} }
@ -125,3 +128,7 @@ func (c *Container) initMail() {
panic(fmt.Sprintf("failed to create mail client: %v", err)) panic(fmt.Sprintf("failed to create mail client: %v", err))
} }
} }
func (c *Container) initAuth() {
c.Auth = auth.NewClient(c.Config, c.ORM)
}

View File

@ -5,7 +5,7 @@ import (
"net/http" "net/http"
"time" "time"
"goweb/auth" "goweb/context"
"goweb/msg" "goweb/msg"
"goweb/pager" "goweb/pager"
@ -62,7 +62,7 @@ func NewPage(c echo.Context) Page {
p.CSRF = csrf.(string) p.CSRF = csrf.(string)
} }
if _, err := auth.GetUserID(c); err == nil { if u := c.Get(context.AuthenticatedUserKey); u != nil {
p.IsAuth = true p.IsAuth = true
} }

View File

@ -6,25 +6,20 @@ import (
"goweb/auth" "goweb/auth"
"goweb/context" "goweb/context"
"goweb/ent" "goweb/ent"
"goweb/ent/user"
"github.com/labstack/echo/v4" "github.com/labstack/echo/v4"
) )
func LoadAuthenticatedUser(orm *ent.Client) echo.MiddlewareFunc { func LoadAuthenticatedUser(authClient *auth.Client) echo.MiddlewareFunc {
return func(next echo.HandlerFunc) echo.HandlerFunc { return func(next echo.HandlerFunc) echo.HandlerFunc {
return func(c echo.Context) error { return func(c echo.Context) error {
if userID, err := auth.GetUserID(c); err == nil { if user, err := authClient.GetAuthenticatedUser(c); err == nil {
u, err := orm.User.Query().
Where(user.ID(userID)).
First(c.Request().Context())
switch err.(type) { switch err.(type) {
case *ent.NotFoundError: case *ent.NotFoundError:
c.Logger().Debug("auth user not found: %d", userID) c.Logger().Debug("auth user not found")
case nil: case nil:
c.Set(context.AuthenticatedUserKey, u) c.Set(context.AuthenticatedUserKey, user)
c.Logger().Info("auth user loaded in to context: %d", userID) c.Logger().Info("auth user loaded in to context: %d", user.ID)
default: default:
c.Logger().Errorf("error querying for authenticated user: %v", err) c.Logger().Errorf("error querying for authenticated user: %v", err)
} }

View File

@ -3,6 +3,8 @@ package routes
import ( import (
"goweb/context" "goweb/context"
"goweb/controller" "goweb/controller"
"goweb/ent"
"goweb/ent/user"
"goweb/msg" "goweb/msg"
"github.com/labstack/echo/v4" "github.com/labstack/echo/v4"
@ -39,6 +41,11 @@ func (f *ForgotPassword) Post(c echo.Context) error {
return f.Get(c) return f.Get(c)
} }
succeed := func() error {
msg.Success(c, "An email containing a link to reset your password will be sent to this address if it exists in our system.")
return f.Get(c)
}
// Parse the form values // Parse the form values
form := new(ForgotPasswordForm) form := new(ForgotPasswordForm)
if err := c.Bind(form); err != nil { if err := c.Bind(form); err != nil {
@ -52,7 +59,25 @@ func (f *ForgotPassword) Post(c echo.Context) error {
return f.Get(c) return f.Get(c)
} }
// Attempt to load the user
u, err := f.Container.ORM.User.
Query().
Where(user.Email(form.Email)).
First(c.Request().Context())
if err != nil {
switch err.(type) {
case *ent.NotFoundError:
return succeed()
default:
return fail("error querying user during forgot password", err)
}
}
// TODO: generate and email a token // TODO: generate and email a token
if u != nil {
}
return f.Redirect(c, "home") return f.Redirect(c, "home")
} }

View File

@ -3,7 +3,6 @@ package routes
import ( import (
"fmt" "fmt"
"goweb/auth"
"goweb/context" "goweb/context"
"goweb/controller" "goweb/controller"
"goweb/ent" "goweb/ent"
@ -75,14 +74,14 @@ func (l *Login) Post(c echo.Context) error {
} }
// Check if the password is correct // Check if the password is correct
err = auth.CheckPassword(form.Password, u.Password) err = l.Container.Auth.CheckPassword(form.Password, u.Password)
if err != nil { if err != nil {
msg.Danger(c, "Invalid credentials. Please try again.") msg.Danger(c, "Invalid credentials. Please try again.")
return l.Get(c) return l.Get(c)
} }
// Log the user in // Log the user in
err = auth.Login(c, u.ID) err = l.Container.Auth.Login(c, u.ID)
if err != nil { if err != nil {
return fail("unable to log in user", err) return fail("unable to log in user", err)
} }

View File

@ -1,7 +1,6 @@
package routes package routes
import ( import (
"goweb/auth"
"goweb/controller" "goweb/controller"
"goweb/msg" "goweb/msg"
@ -13,7 +12,7 @@ type Logout struct {
} }
func (l *Logout) Get(c echo.Context) error { func (l *Logout) Get(c echo.Context) error {
if err := auth.Logout(c); err == nil { if err := l.Container.Auth.Logout(c); err == nil {
msg.Success(c, "You have been logged out successfully.") msg.Success(c, "You have been logged out successfully.")
} }
return l.Redirect(c, "home") return l.Redirect(c, "home")

View File

@ -1,7 +1,6 @@
package routes package routes
import ( import (
"goweb/auth"
"goweb/context" "goweb/context"
"goweb/controller" "goweb/controller"
"goweb/msg" "goweb/msg"
@ -57,7 +56,7 @@ func (r *Register) Post(c echo.Context) error {
} }
// Hash the password // Hash the password
pwHash, err := auth.HashPassword(form.Password) pwHash, err := r.Container.Auth.HashPassword(form.Password)
if err != nil { if err != nil {
return fail("unable to hash password", err) return fail("unable to hash password", err)
} }
@ -76,7 +75,7 @@ func (r *Register) Post(c echo.Context) error {
c.Logger().Infof("user created: %s", u.Name) c.Logger().Infof("user created: %s", u.Name)
err = auth.Login(c, u.ID) err = r.Container.Auth.Login(c, u.ID)
if err != nil { if err != nil {
c.Logger().Errorf("unable to log in: %v", err) c.Logger().Errorf("unable to log in: %v", err)
msg.Info(c, "Your account has been created.") msg.Info(c, "Your account has been created.")

View File

@ -57,7 +57,7 @@ func BuildRouter(c *container.Container) {
echomw.CSRFWithConfig(echomw.CSRFConfig{ echomw.CSRFWithConfig(echomw.CSRFConfig{
TokenLookup: "form:csrf", TokenLookup: "form:csrf",
}), }),
middleware.LoadAuthenticatedUser(c.ORM), middleware.LoadAuthenticatedUser(c.Auth),
) )
// Base controller // Base controller