2024-07-11 21:31:54 -07:00
|
|
|
package services
|
2024-07-11 21:09:15 -07:00
|
|
|
|
|
|
|
import (
	"database/sql"
	"errors"
	"fmt"
	"log/slog"

	"github.com/golang-migrate/migrate/v4"
	"github.com/golang-migrate/migrate/v4/database/sqlite3"
	"github.com/golang-migrate/migrate/v4/source/file"

	"git.grosinger.net/tgrosinger/saasitone/config"
	"git.grosinger.net/tgrosinger/saasitone/pkg/models"
	"git.grosinger.net/tgrosinger/saasitone/pkg/models/sqlc"
)
|
|
|
|
|
2024-07-11 21:31:54 -07:00
|
|
|
// DBClient bundles the application's SQLite handle with the query helpers
// layered on top of it. Construct it with NewDBClient, which also applies
// pending schema migrations.
type DBClient struct {
	// db is the underlying connection pool. It is returned by DB(), used by
	// WithTx to begin transactions, and closed by Close().
	db *sql.DB
	// C holds the (presumably sqlc-generated) queries, created with
	// sqlc.New(db) in NewDBClient. Use WithSqlcTx to get a copy bound to a
	// transaction.
	C *sqlc.Queries

	// User and Post are sub-clients that NewDBClient wires to the same db.
	User *models.DBUserClient
	Post *models.DBPostClient
}
|
|
|
|
|
2024-07-11 21:31:54 -07:00
|
|
|
func NewDBClient(cfg *config.Config) (*DBClient, error) {
|
2024-07-11 21:09:15 -07:00
|
|
|
logger := slog.Default()
|
|
|
|
|
|
|
|
dbFilepath := cfg.Storage.DatabaseFile
|
|
|
|
if cfg.App.Environment == config.EnvTest {
|
|
|
|
// In memory only
|
|
|
|
dbFilepath = ":memory:"
|
|
|
|
}
|
|
|
|
|
|
|
|
logger.Info("Opening database file",
|
|
|
|
"filepath", dbFilepath)
|
|
|
|
|
|
|
|
fn := fmt.Sprintf("file:%s?_fk=1&_journal=WAL&cache=shared&_busy_timeout=5000", dbFilepath)
|
|
|
|
db, err := sql.Open("sqlite3", fn)
|
|
|
|
if err != nil {
|
|
|
|
return nil, err
|
|
|
|
}
|
|
|
|
|
2024-07-11 21:31:54 -07:00
|
|
|
client := DBClient{
|
2024-07-11 21:09:15 -07:00
|
|
|
db: db,
|
|
|
|
C: sqlc.New(db),
|
|
|
|
}
|
2024-07-19 20:44:09 -07:00
|
|
|
client.User = &models.DBUserClient{DB: db}
|
|
|
|
client.Post = &models.DBPostClient{DB: db}
|
2024-07-11 21:09:15 -07:00
|
|
|
|
|
|
|
migrationsDirPath := cfg.Storage.MigrationsDir
|
|
|
|
logger.Info("Loading schema migrations",
|
|
|
|
"filepath", migrationsDirPath)
|
|
|
|
err = client.initSchema(migrationsDirPath)
|
|
|
|
if err != nil {
|
|
|
|
return nil, err
|
|
|
|
}
|
|
|
|
|
|
|
|
return &client, nil
|
|
|
|
}
|
|
|
|
|
|
|
|
// initSchema ensures that the database is current with the migrations contained
|
|
|
|
// in db/migrations.
|
2024-07-11 21:31:54 -07:00
|
|
|
func (c *DBClient) initSchema(migrationsDir string) error {
|
2024-07-11 21:09:15 -07:00
|
|
|
driver, err := sqlite3.WithInstance(c.db, &sqlite3.Config{})
|
|
|
|
if err != nil {
|
|
|
|
return err
|
|
|
|
}
|
|
|
|
|
|
|
|
fSrc, err := (&file.File{}).Open(migrationsDir)
|
|
|
|
if err != nil {
|
|
|
|
fmt.Println("Got here 2: " + migrationsDir)
|
|
|
|
return err
|
|
|
|
}
|
|
|
|
|
|
|
|
m, err := migrate.NewWithInstance("file", fSrc, "sqlite", driver)
|
|
|
|
if err != nil {
|
|
|
|
return err
|
|
|
|
}
|
|
|
|
|
|
|
|
err = m.Up()
|
|
|
|
if err == migrate.ErrNoChange {
|
|
|
|
return nil
|
|
|
|
}
|
|
|
|
return err
|
|
|
|
}
|
|
|
|
|
|
|
|
// WithTx executes the provided callback with access to a database transaction.
|
|
|
|
// If the callback returns an error the transaction will be rolled back.
|
2024-07-11 21:31:54 -07:00
|
|
|
func (c *DBClient) WithTx(fn func(tx *sql.Tx) error) error {
|
2024-07-11 21:09:15 -07:00
|
|
|
tx, err := c.db.Begin()
|
|
|
|
if err != nil {
|
|
|
|
return err
|
|
|
|
}
|
|
|
|
|
|
|
|
err = fn(tx)
|
|
|
|
if err != nil {
|
|
|
|
_ = tx.Rollback()
|
|
|
|
return err
|
|
|
|
}
|
|
|
|
|
|
|
|
return tx.Commit()
|
|
|
|
}
|
|
|
|
|
2024-07-11 21:31:54 -07:00
|
|
|
func (c *DBClient) WithSqlcTx(fn func(*sqlc.Queries) error) error {
|
2024-07-11 21:09:15 -07:00
|
|
|
return c.WithTx(
|
|
|
|
func(tx *sql.Tx) error {
|
|
|
|
return fn(c.C.WithTx(tx))
|
|
|
|
},
|
|
|
|
)
|
|
|
|
}
|
|
|
|
|
|
|
|
// DB returns the underlying database object. Avoid whenever possible and use
|
|
|
|
// either sqlc (preferred) or sub-clients.
|
2024-07-11 21:31:54 -07:00
|
|
|
func (c *DBClient) DB() *sql.DB {
|
2024-07-11 21:09:15 -07:00
|
|
|
return c.db
|
|
|
|
}
|
|
|
|
|
2024-07-11 21:31:54 -07:00
|
|
|
func (c *DBClient) Close() error {
|
2024-07-11 21:09:15 -07:00
|
|
|
return c.db.Close()
|
|
|
|
}
|