diff --git a/controller/controller_test.go b/controller/controller_test.go index 7ca095d..7f1ec1f 100644 --- a/controller/controller_test.go +++ b/controller/controller_test.go @@ -6,21 +6,18 @@ import ( "net/http" "net/http/httptest" "os" - "strings" "testing" "goweb/config" "goweb/middleware" "goweb/msg" "goweb/services" + "goweb/tests" "github.com/eko/gocache/v2/store" "github.com/eko/gocache/v2/marshaler" - "github.com/gorilla/sessions" - "github.com/labstack/echo-contrib/session" - "github.com/go-playground/validator/v10" "github.com/stretchr/testify/assert" @@ -50,21 +47,8 @@ func TestMain(m *testing.M) { os.Exit(exitVal) } -func newContext(url string) (echo.Context, *httptest.ResponseRecorder) { - req := httptest.NewRequest(http.MethodGet, url, strings.NewReader("")) - rec := httptest.NewRecorder() - return c.Web.NewContext(req, rec), rec -} - -func initSesssion(t *testing.T, ctx echo.Context) { - // Simulate an HTTP request through the session middleware to initiate the session - mw := session.Middleware(sessions.NewCookieStore([]byte("secret"))) - handler := mw(echo.NotFoundHandler) - assert.Error(t, handler(ctx)) -} - func TestController_Redirect(t *testing.T) { - ctx, _ := newContext("/abc") + ctx, _ := tests.NewContext(c.Web, "/abc") ctr := NewController(c) err := ctr.Redirect(ctx, "home") require.NoError(t, err) @@ -81,8 +65,8 @@ func TestController_SetValidationErrorMessages(t *testing.T) { err := v.Struct(e) require.Error(t, err) - ctx, _ := newContext("/") - initSesssion(t, ctx) + ctx, _ := tests.NewContext(c.Web, "/") + tests.InitSession(ctx) ctr := NewController(c) ctr.SetValidationErrorMessages(ctx, err, e) @@ -93,8 +77,8 @@ func TestController_SetValidationErrorMessages(t *testing.T) { func TestController_RenderPage(t *testing.T) { setup := func() (echo.Context, *httptest.ResponseRecorder, Controller, Page) { - ctx, rec := newContext("/test/TestController_RenderPage") - initSesssion(t, ctx) + ctx, rec := tests.NewContext(c.Web, "/test/TestController_RenderPage") + tests.InitSession(ctx) ctr := NewController(c) p := NewPage(ctx) diff --git a/controller/page_test.go b/controller/page_test.go index e4a0fdb..e45bbe6 100644 --- a/controller/page_test.go +++ b/controller/page_test.go @@ -6,13 +6,14 @@ import ( "goweb/context" "goweb/msg" + "goweb/tests" echomw "github.com/labstack/echo/v4/middleware" "github.com/stretchr/testify/assert" ) func TestNewPage(t *testing.T) { - ctx, _ := newContext("/") + ctx, _ := tests.NewContext(c.Web, "/") p := NewPage(ctx) assert.Same(t, ctx, p.Context) assert.NotNil(t, p.ToURL) @@ -27,7 +28,7 @@ func TestNewPage(t *testing.T) { assert.Empty(t, p.RequestID) assert.False(t, p.Cache.Enabled) - ctx, _ = newContext("/abc?def=123") + ctx, _ = tests.NewContext(c.Web, "/abc?def=123") ctx.Set(context.AuthenticatedUserKey, 1) ctx.Set(echomw.DefaultCSRFConfig.ContextKey, "csrf") p = NewPage(ctx) @@ -39,9 +40,9 @@ func TestNewPage(t *testing.T) { } func TestPage_GetMessages(t *testing.T) { - ctx, _ := newContext("/") + ctx, _ := tests.NewContext(c.Web, "/") + tests.InitSession(ctx) p := NewPage(ctx) - initSesssion(t, ctx) // Set messages msgTests := make(map[msg.Type][]string) diff --git a/controller/pager_test.go b/controller/pager_test.go index 8f82010..f7fb89d 100644 --- a/controller/pager_test.go +++ b/controller/pager_test.go @@ -4,28 +4,30 @@ import ( "fmt" "testing" + "goweb/tests" + "github.com/stretchr/testify/assert" ) func TestNewPager(t *testing.T) { - ctx, _ := newContext("/") + ctx, _ := tests.NewContext(c.Web, "/") pgr := NewPager(ctx, 10) assert.Equal(t, 10, pgr.ItemsPerPage) assert.Equal(t, 1, pgr.Page) assert.Equal(t, 0, pgr.Items) assert.Equal(t, 0, pgr.Pages) - ctx, _ = newContext(fmt.Sprintf("/abc?%s=%d", PageQueryKey, 2)) + ctx, _ = tests.NewContext(c.Web, fmt.Sprintf("/abc?%s=%d", PageQueryKey, 2)) pgr = NewPager(ctx, 10) assert.Equal(t, 2, pgr.Page) - ctx, _ = newContext(fmt.Sprintf("/abc?%s=%d", PageQueryKey, -2)) + ctx, _ = tests.NewContext(c.Web, fmt.Sprintf("/abc?%s=%d", PageQueryKey, -2)) pgr = NewPager(ctx, 10) assert.Equal(t, 1, pgr.Page) } func TestPager_SetItems(t *testing.T) { - ctx, _ := newContext("/") + ctx, _ := tests.NewContext(c.Web, "/") pgr := NewPager(ctx, 20) pgr.SetItems(100) assert.Equal(t, 100, pgr.Items) @@ -33,7 +35,7 @@ func TestPager_SetItems(t *testing.T) { } func TestPager_IsBeginning(t *testing.T) { - ctx, _ := newContext("/") + ctx, _ := tests.NewContext(c.Web, "/") pgr := NewPager(ctx, 20) pgr.Pages = 10 assert.True(t, pgr.IsBeginning()) @@ -44,7 +46,7 @@ func TestPager_IsBeginning(t *testing.T) { } func TestPager_IsEnd(t *testing.T) { - ctx, _ := newContext("/") + ctx, _ := tests.NewContext(c.Web, "/") pgr := NewPager(ctx, 20) pgr.Pages = 10 assert.False(t, pgr.IsEnd()) @@ -55,7 +57,7 @@ func TestPager_IsEnd(t *testing.T) { } func TestPager_GetOffset(t *testing.T) { - ctx, _ := newContext("/") + ctx, _ := tests.NewContext(c.Web, "/") pgr := NewPager(ctx, 20) assert.Equal(t, 0, pgr.GetOffset()) pgr.Page = 2 diff --git a/middleware/auth_test.go b/middleware/auth_test.go new file mode 100644 index 0000000..f5e3d73 --- /dev/null +++ b/middleware/auth_test.go @@ -0,0 +1,85 @@ +package middleware + +import ( + "net/http" + "testing" + + "goweb/context" + "goweb/ent" + "goweb/tests" + + "github.com/labstack/echo/v4" + + "github.com/stretchr/testify/require" + + "github.com/stretchr/testify/assert" +) + +func TestLoadAuthenticatedUser(t *testing.T) { + ctx, _ := tests.NewContext(c.Web, "/") + tests.InitSession(ctx) + mw := LoadAuthenticatedUser(c.Auth) + + // Not authenticated + _ = tests.ExecuteMiddleware(ctx, mw) + assert.Nil(t, ctx.Get(context.AuthenticatedUserKey)) + + // Login + err := c.Auth.Login(ctx, usr.ID) + require.NoError(t, err) + + // Verify the midldeware returns the authenticated user + _ = tests.ExecuteMiddleware(ctx, mw) + require.NotNil(t, ctx.Get(context.AuthenticatedUserKey)) + ctxUsr, ok := ctx.Get(context.AuthenticatedUserKey).(*ent.User) + require.True(t, ok) + assert.Equal(t, usr.ID, ctxUsr.ID) +} + +func TestRequireAuthentication(t *testing.T) { + ctx, _ := tests.NewContext(c.Web, "/") + tests.InitSession(ctx) + + // Not logged in + err := tests.ExecuteMiddleware(ctx, RequireAuthentication()) + httpError, ok := err.(*echo.HTTPError) + require.True(t, ok) + assert.Equal(t, http.StatusUnauthorized, httpError.Code) + + // Login + err = c.Auth.Login(ctx, usr.ID) + require.NoError(t, err) + _ = tests.ExecuteMiddleware(ctx, LoadAuthenticatedUser(c.Auth)) + + // Logged in + err = tests.ExecuteMiddleware(ctx, RequireAuthentication()) + httpError, ok = err.(*echo.HTTPError) + require.True(t, ok) + assert.NotEqual(t, http.StatusUnauthorized, httpError.Code) +} + +func TestRequireNoAuthentication(t *testing.T) { + ctx, _ := tests.NewContext(c.Web, "/") + tests.InitSession(ctx) + + // Not logged in + err := tests.ExecuteMiddleware(ctx, RequireNoAuthentication()) + httpError, ok := err.(*echo.HTTPError) + require.True(t, ok) + assert.NotEqual(t, http.StatusForbidden, httpError.Code) + + // Login + err = c.Auth.Login(ctx, usr.ID) + require.NoError(t, err) + _ = tests.ExecuteMiddleware(ctx, LoadAuthenticatedUser(c.Auth)) + + // Logged in + err = tests.ExecuteMiddleware(ctx, RequireNoAuthentication()) + httpError, ok = err.(*echo.HTTPError) + require.True(t, ok) + assert.Equal(t, http.StatusForbidden, httpError.Code) +} + +func TestLoadValidPasswordToken(t *testing.T) { + +} diff --git a/middleware/middleware_test.go b/middleware/middleware_test.go new file mode 100644 index 0000000..2aa60c8 --- /dev/null +++ b/middleware/middleware_test.go @@ -0,0 +1,39 @@ +package middleware + +import ( + "os" + "testing" + + "goweb/config" + "goweb/ent" + "goweb/services" + "goweb/tests" +) + +var ( + c *services.Container + usr *ent.User +) + +func TestMain(m *testing.M) { + // Set the environment to test + config.SwitchEnvironment(config.EnvTest) + + // Create a new container + c = services.NewContainer() + defer func() { + if err := c.Shutdown(); err != nil { + c.Web.Logger.Fatal(err) + } + }() + + // Create a user + var err error + if usr, err = tests.CreateUser(c.ORM); err != nil { + panic(err) + } + + // Run tests + exitVal := m.Run() + os.Exit(exitVal) +} diff --git a/services/auth_test.go b/services/auth_test.go index 9032b60..806415d 100644 --- a/services/auth_test.go +++ b/services/auth_test.go @@ -9,20 +9,12 @@ import ( "goweb/ent/passwordtoken" "goweb/ent/user" - "github.com/gorilla/sessions" - "github.com/labstack/echo-contrib/session" - "github.com/labstack/echo/v4" "github.com/stretchr/testify/require" "github.com/stretchr/testify/assert" ) func TestAuth(t *testing.T) { - // Simulate an HTTP request through the session middleware to initiate the session - mw := session.Middleware(sessions.NewCookieStore([]byte("secret"))) - handler := mw(echo.NotFoundHandler) - assert.Error(t, handler(ctx)) - assertNoAuth := func() { _, err := c.Auth.GetAuthenticatedUserID(ctx) assert.True(t, errors.Is(err, NotAuthenticatedError{})) diff --git a/services/mail.go b/services/mail.go index a907b56..69a7e81 100644 --- a/services/mail.go +++ b/services/mail.go @@ -45,20 +45,16 @@ func (c *MailClient) SendTemplate(ctx echo.Context, to, template string, data in ctx.Logger().Debugf("skipping template email sent to: %s") } - // Parse the template, if needed - if err := c.templates.Parse( + // Parse and execute template + // Uncomment the first variable when ready to use + _, err := c.templates.ParseAndExecute( "mail", template, template, []string{fmt.Sprintf("email/%s", template)}, []string{}, - ); err != nil { - return err - } - - // Execute the template - // Uncomment the first variable when ready to use - _, err := c.templates.Execute("mail", template, template, data) + data, + ) if err != nil { return err } diff --git a/services/services_test.go b/services/services_test.go index ea02a90..a67fb07 100644 --- a/services/services_test.go +++ b/services/services_test.go @@ -1,15 +1,12 @@ package services import ( - "context" - "net/http" - "net/http/httptest" "os" - "strings" "testing" "goweb/config" "goweb/ent" + "goweb/tests" "github.com/labstack/echo/v4" ) @@ -18,7 +15,6 @@ var ( c *Container ctx echo.Context usr *ent.User - rec *httptest.ResponseRecorder ) func TestMain(m *testing.M) { @@ -34,20 +30,12 @@ func TestMain(m *testing.M) { }() // Create a web context - req := httptest.NewRequest(http.MethodPost, "/", strings.NewReader("")) - rec = httptest.NewRecorder() - ctx = c.Web.NewContext(req, rec) + ctx, _ = tests.NewContext(c.Web, "/") + tests.InitSession(ctx) // Create a test user var err error - usr, err = c.ORM.User. - Create(). - SetEmail("test@test.dev"). - SetPassword("abc"). - SetName("Test User"). - Save(context.Background()) - - if err != nil { + if usr, err = tests.CreateUser(c.ORM); err != nil { panic(err) } diff --git a/tests/tests.go b/tests/tests.go new file mode 100644 index 0000000..e512c38 --- /dev/null +++ b/tests/tests.go @@ -0,0 +1,44 @@ +package tests + +import ( + "context" + "fmt" + "net/http" + "net/http/httptest" + "strings" + "time" + + "goweb/ent" + + "k8s.io/apimachinery/pkg/util/rand" + + "github.com/gorilla/sessions" + "github.com/labstack/echo-contrib/session" + "github.com/labstack/echo/v4" +) + +func NewContext(e *echo.Echo, url string) (echo.Context, *httptest.ResponseRecorder) { + req := httptest.NewRequest(http.MethodGet, url, strings.NewReader("")) + rec := httptest.NewRecorder() + return e.NewContext(req, rec), rec +} + +func InitSession(ctx echo.Context) { + mw := session.Middleware(sessions.NewCookieStore([]byte("secret"))) + _ = ExecuteMiddleware(ctx, mw) +} + +func ExecuteMiddleware(ctx echo.Context, mw echo.MiddlewareFunc) error { + handler := mw(echo.NotFoundHandler) + return handler(ctx) +} + +func CreateUser(orm *ent.Client) (*ent.User, error) { + seed := fmt.Sprintf("%d-%d", time.Now().UnixMilli(), rand.IntnRange(10, 1000000)) + return orm.User. + Create(). + SetEmail(fmt.Sprintf("testuser-%s@localhost.localhost", seed)). + SetPassword("password"). + SetName(fmt.Sprintf("Test User %s", seed)). + Save(context.Background()) +}