diff --git a/pkg/controller/controller_test.go b/pkg/controller/controller_test.go index 60c3063..d042c69 100644 --- a/pkg/controller/controller_test.go +++ b/pkg/controller/controller_test.go @@ -99,7 +99,7 @@ func TestController_RenderPage(t *testing.T) { expectedTemplates := make(map[string]bool) expectedTemplates[p.Name+config.TemplateExt] = true expectedTemplates[p.Layout+config.TemplateExt] = true - components, err := templates.Templates.ReadDir("components") + components, err := templates.Get().ReadDir("components") require.NoError(t, err) for _, f := range components { expectedTemplates[f.Name()] = true @@ -132,7 +132,7 @@ func TestController_RenderPage(t *testing.T) { expectedTemplates := make(map[string]bool) expectedTemplates[p.Name+config.TemplateExt] = true expectedTemplates["htmx"+config.TemplateExt] = true - components, err := templates.Templates.ReadDir("components") + components, err := templates.Get().ReadDir("components") require.NoError(t, err) for _, f := range components { expectedTemplates[f.Name()] = true diff --git a/pkg/services/template_renderer.go b/pkg/services/template_renderer.go index de60a0e..d21bbb7 100644 --- a/pkg/services/template_renderer.go +++ b/pkg/services/template_renderer.go @@ -5,6 +5,7 @@ import ( "errors" "fmt" "html/template" + "io/fs" "sync" "github.com/mikestefanello/pagoda/config" @@ -103,25 +104,28 @@ func (t *TemplateRenderer) parse(build *templateBuild) (*TemplateParsed, error) parsed := template.New(build.base + config.TemplateExt). Funcs(t.funcMap) - // Parse all files provided - if len(build.files) > 0 { - for k, v := range build.files { - build.files[k] = fmt.Sprintf("%s%s", v, config.TemplateExt) - } - - parsed, err = parsed.ParseFS(templates.Templates, build.files...) - if err != nil { - return nil, err - } + // Format the requested files + for k, v := range build.files { + build.files[k] = fmt.Sprintf("%s%s", v, config.TemplateExt) } - // Parse all templates within the provided directories - for _, dir := range build.directories { - dir = fmt.Sprintf("%s/*%s", dir, config.TemplateExt) - parsed, err = parsed.ParseFS(templates.Templates, dir) - if err != nil { - return nil, err - } + // Include all files within the requested directories + for k, v := range build.directories { + build.directories[k] = fmt.Sprintf("%s/*%s", v, config.TemplateExt) + } + + // Get the templates + var tpl fs.FS + if t.config.App.Environment == config.EnvLocal { + tpl = templates.GetOS() + } else { + tpl = templates.Get() + } + + // Parse the templates + parsed, err = parsed.ParseFS(tpl, append(build.files, build.directories...)...) + if err != nil { + return nil, err } // Store the template so this process only happens once diff --git a/pkg/services/template_renderer_test.go b/pkg/services/template_renderer_test.go index 6e6805c..3c8d8f8 100644 --- a/pkg/services/template_renderer_test.go +++ b/pkg/services/template_renderer_test.go @@ -37,7 +37,7 @@ func TestTemplateRenderer(t *testing.T) { expectedTemplates := make(map[string]bool) expectedTemplates["htmx"+config.TemplateExt] = true expectedTemplates["error"+config.TemplateExt] = true - components, err := templates.Templates.ReadDir("components") + components, err := templates.Get().ReadDir("components") require.NoError(t, err) for _, f := range components { expectedTemplates[f.Name()] = true diff --git a/templates/templates.go b/templates/templates.go index 83a8320..257a286 100644 --- a/templates/templates.go +++ b/templates/templates.go @@ -2,7 +2,28 @@ package templates import ( "embed" + "io/fs" + "os" + "path" + "path/filepath" + "runtime" ) //go:embed * -var Templates embed.FS +var templates embed.FS + +// Get returns a file system containing all templates via embed.FS +func Get() embed.FS { + return templates +} + +// GetOS returns a file system containing all templates which will load the files directly from the operating system. +// This should only be used for local development in order to faciliate live reloading. +func GetOS() fs.FS { + // Gets the complete templates directory path + // This is needed in case this is called from a package outside of main, such as within tests + _, b, _, _ := runtime.Caller(0) + d := path.Join(path.Dir(b)) + p := filepath.Join(filepath.Dir(d), "templates") + return os.DirFS(p) +} diff --git a/templates/templates_test.go b/templates/templates_test.go new file mode 100644 index 0000000..44abec0 --- /dev/null +++ b/templates/templates_test.go @@ -0,0 +1,17 @@ +package templates + +import ( + "testing" + + "github.com/stretchr/testify/require" +) + +func TestGet(t *testing.T) { + _, err := Get().Open("pages/home.gohtml") + require.NoError(t, err) +} + +func TestGetOS(t *testing.T) { + _, err := GetOS().Open("pages/home.gohtml") + require.NoError(t, err) +}