package services

import (
	"database/sql"
	"errors"
	"fmt"
	"log/slog"

	"github.com/golang-migrate/migrate/v4"
	"github.com/golang-migrate/migrate/v4/database/sqlite3"
	"github.com/golang-migrate/migrate/v4/source/file"

	"git.grosinger.net/tgrosinger/saasitone/config"
	"git.grosinger.net/tgrosinger/saasitone/pkg/models"
	"git.grosinger.net/tgrosinger/saasitone/pkg/models/sqlc"
)

// DBClient bundles the application's SQLite connection pool with the query
// helpers built on top of it. Construct it with NewDBClient and release it
// with Close.
type DBClient struct {
	db *sql.DB

	// C holds the sqlc-generated queries bound to the connection pool.
	C *sqlc.Queries

	// User and Post are sub-clients that share the same connection pool.
	User *models.DBUserClient
	Post *models.DBPostClient
}

// NewDBClient opens the SQLite database named by cfg.Storage.DatabaseFile and
// applies any pending schema migrations found in cfg.Storage.MigrationsDir.
//
// When cfg.App.Environment is config.EnvTest an in-memory database is used
// instead of the configured file. If any step fails, the connection pool that
// was opened is closed before the error is returned, so nothing is leaked.
func NewDBClient(cfg *config.Config) (*DBClient, error) {
	logger := slog.Default()

	dbFilepath := cfg.Storage.DatabaseFile
	if cfg.App.Environment == config.EnvTest {
		// In memory only
		dbFilepath = ":memory:"
	}

	logger.Info("Opening database file", "filepath", dbFilepath)
	// DSN options: foreign keys on, WAL journal, shared cache, and a 5s busy
	// timeout. With ":memory:", cache=shared makes every pooled connection see
	// the same in-memory database (standard SQLite shared-cache behavior).
	fn := fmt.Sprintf("file:%s?_fk=1&_journal=WAL&cache=shared&_busy_timeout=5000", dbFilepath)
	db, err := sql.Open("sqlite3", fn)
	if err != nil {
		return nil, fmt.Errorf("opening database %q: %w", dbFilepath, err)
	}

	client := &DBClient{
		db:   db,
		C:    sqlc.New(db),
		User: &models.DBUserClient{DB: db},
		Post: &models.DBPostClient{DB: db},
	}

	migrationsDirPath := cfg.Storage.MigrationsDir
	logger.Info("Loading schema migrations", "filepath", migrationsDirPath)
	if err := client.initSchema(migrationsDirPath); err != nil {
		// Don't leak the pool on a failed startup. errors.Join drops a nil
		// Close error, so the caller normally sees only the schema error.
		return nil, errors.Join(
			fmt.Errorf("initializing schema from %q: %w", migrationsDirPath, err),
			db.Close(),
		)
	}

	return client, nil
}

// initSchema brings the database up to date by applying every pending
// migration found in migrationsDir. Having no pending migrations
// (migrate.ErrNoChange) is treated as success.
func (c *DBClient) initSchema(migrationsDir string) error {
	driver, err := sqlite3.WithInstance(c.db, &sqlite3.Config{})
	if err != nil {
		return fmt.Errorf("creating migration driver: %w", err)
	}

	fSrc, err := (&file.File{}).Open(migrationsDir)
	if err != nil {
		return fmt.Errorf("opening migrations source: %w", err)
	}

	// NOTE(review): m is intentionally never closed. migrate's Close also
	// closes the database driver, which presumably closes the shared c.db
	// pool. Confirm this before adding an m.Close() call.
	m, err := migrate.NewWithInstance("file", fSrc, "sqlite", driver)
	if err != nil {
		return fmt.Errorf("initializing migrator: %w", err)
	}

	if err := m.Up(); err != nil && !errors.Is(err, migrate.ErrNoChange) {
		return fmt.Errorf("running migrations: %w", err)
	}
	return nil
}

// WithTx executes the provided callback with access to a database transaction.
// If the callback returns an error the transaction is rolled back and that
// error is returned; otherwise the transaction is committed.
func (c *DBClient) WithTx(fn func(tx *sql.Tx) error) error { tx, err := c.db.Begin() if err != nil { return err } err = fn(tx) if err != nil { _ = tx.Rollback() return err } return tx.Commit() } func (c *DBClient) WithSqlcTx(fn func(*sqlc.Queries) error) error { return c.WithTx( func(tx *sql.Tx) error { return fn(c.C.WithTx(tx)) }, ) } // DB returns the underlying database object. Avoid whenever possible and use // either sqlc (preferred) or sub-clients. func (c *DBClient) DB() *sql.DB { return c.db } func (c *DBClient) Close() error { return c.db.Close() }