// saasitone/pkg/services/db.go
// (file-viewer metadata: 121 lines, 2.6 KiB, Go)

package services
import (
"database/sql"
"fmt"
"log/slog"
"github.com/golang-migrate/migrate/v4"
"github.com/golang-migrate/migrate/v4/database/sqlite3"
"github.com/golang-migrate/migrate/v4/source/file"
"git.grosinger.net/tgrosinger/saasitone/config"
"git.grosinger.net/tgrosinger/saasitone/pkg/models"
"git.grosinger.net/tgrosinger/saasitone/pkg/models/sqlc"
)
// DBClient bundles the application's SQLite connection with the query helpers
// built on top of it. Create it with NewDBClient, which populates every field;
// the zero value has a nil db and is not usable.
type DBClient struct {
	// db is the underlying connection pool. It is returned by DB() and closed
	// by Close().
	db *sql.DB
	// C is the sqlc query set created with sqlc.New(db); prefer it for queries.
	C *sqlc.Queries
	// User is the user sub-client, constructed over the same db.
	User *models.DBUserClient
	// Post is the post sub-client, constructed over the same db.
	Post *models.DBPostClient
}
// NewDBClient opens the SQLite database configured in cfg, brings its schema up
// to date with the migrations in cfg.Storage.MigrationsDir, and returns a
// client wrapping the connection.
//
// When cfg.App.Environment is config.EnvTest the database is opened in memory
// (":memory:") instead of at cfg.Storage.DatabaseFile.
// NOTE(review): the DSN query parameters (_fk, _journal, cache, _busy_timeout)
// follow mattn/go-sqlite3's syntax — confirm that is the registered "sqlite3"
// driver.
//
// If applying migrations fails, the database handle is closed before the error
// is returned, so callers never receive a half-initialized client and have
// nothing to clean up.
func NewDBClient(cfg *config.Config) (*DBClient, error) {
	logger := slog.Default()

	dbFilepath := cfg.Storage.DatabaseFile
	if cfg.App.Environment == config.EnvTest {
		// In memory only
		dbFilepath = ":memory:"
	}

	logger.Info("Opening database file",
		"filepath", dbFilepath)
	fn := fmt.Sprintf("file:%s?_fk=1&_journal=WAL&cache=shared&_busy_timeout=5000", dbFilepath)
	db, err := sql.Open("sqlite3", fn)
	if err != nil {
		return nil, fmt.Errorf("opening database %q: %w", dbFilepath, err)
	}

	client := DBClient{
		db: db,
		C:  sqlc.New(db),
	}
	client.User = &models.DBUserClient{DB: db}
	client.Post = &models.DBPostClient{DB: db}

	migrationsDirPath := cfg.Storage.MigrationsDir
	logger.Info("Loading schema migrations",
		"filepath", migrationsDirPath)
	if err := client.initSchema(migrationsDirPath); err != nil {
		// The caller gets no client to Close, so release the pool here.
		// Best-effort: the migration error is the one worth reporting.
		_ = db.Close()
		return nil, fmt.Errorf("applying migrations from %q: %w", migrationsDirPath, err)
	}

	return &client, nil
}
// initSchema applies every pending "up" migration found in migrationsDir to
// c.db using golang-migrate. A database that is already at the latest version
// (migrate.ErrNoChange) is treated as success.
//
// migrationsDir is handed to golang-migrate's file source driver.
//
// NOTE(review): the migrate instance is deliberately not closed. golang-migrate's
// sqlite3 driver closes the *sql.DB it wraps on Close, which would also close
// c.db while the client still needs it. Confirm this before adding m.Close().
func (c *DBClient) initSchema(migrationsDir string) error {
	driver, err := sqlite3.WithInstance(c.db, &sqlite3.Config{})
	if err != nil {
		return fmt.Errorf("creating sqlite3 migration driver: %w", err)
	}

	src, err := (&file.File{}).Open(migrationsDir)
	if err != nil {
		return fmt.Errorf("opening migrations directory %q: %w", migrationsDir, err)
	}

	m, err := migrate.NewWithInstance("file", src, "sqlite", driver)
	if err != nil {
		return fmt.Errorf("initializing migrator: %w", err)
	}

	// Being already up to date is not an error for our purposes.
	if err := m.Up(); err != nil && err != migrate.ErrNoChange {
		return fmt.Errorf("running migrations: %w", err)
	}
	return nil
}
// WithTx runs fn inside a new transaction on the client's database.
//
// If fn returns nil, the transaction is committed and Commit's result is
// returned. If fn returns an error, the transaction is rolled back and fn's
// error is returned unchanged. If fn panics, the transaction is rolled back
// before the panic continues to propagate, so the connection is not left
// holding an open transaction.
func (c *DBClient) WithTx(fn func(tx *sql.Tx) error) error {
	tx, err := c.db.Begin()
	if err != nil {
		return fmt.Errorf("beginning transaction: %w", err)
	}
	// Deferred so it also runs when fn panics. After a successful Commit,
	// Rollback is a no-op returning sql.ErrTxDone. Rollback errors are ignored
	// on purpose: the error from fn or Commit is the one callers need.
	defer func() { _ = tx.Rollback() }()

	if err := fn(tx); err != nil {
		return err
	}
	return tx.Commit()
}
// WithSqlcTx is the sqlc-flavoured counterpart of WithTx. It opens a
// transaction and hands fn the query set returned by c.C.WithTx for that
// transaction. Commit and rollback are handled by WithTx according to fn's
// result.
func (c *DBClient) WithSqlcTx(fn func(*sqlc.Queries) error) error {
	runInTx := func(tx *sql.Tx) error {
		txQueries := c.C.WithTx(tx)
		return fn(txQueries)
	}
	return c.WithTx(runInTx)
}
// DB hands out the raw *sql.DB behind the client. It is an escape hatch: reach
// for the sqlc queries in C first, then the User/Post sub-clients, and use the
// raw handle only when neither covers the need.
func (c *DBClient) DB() *sql.DB { return c.db }

// Close closes the underlying database handle and reports any error from doing
// so. Call it once, when the client is no longer needed.
func (c *DBClient) Close() error { return c.db.Close() }