Connor McCutcheon
/ Skykit
database.go
go
package skykit
import (
	"cmp"
	"database/sql"
	"fmt"
	"log"
	"os"
	"path/filepath"
	"reflect"
	"strings"
	"time"
	"github.com/pkg/errors"
	"github.com/tursodatabase/go-libsql"
)
// Database bundles a database/sql handle with the bookkeeping this package
// keeps about registered entities.
type Database struct {
	// The embedded handle; its methods are promoted onto Database unless a
	// method declared on Database has the same name.
	*sql.DB
	// Ents collects every entity for which Register completed successfully.
	Ents      []Entity
	// Cols is a string-keyed set of collections. ConnectDB initializes it to an
	// empty map; nothing else in this file reads or writes it.
	Cols      map[string]*Collection[Entity]
	// connector is set by ConnectDB only when an embedded replica is opened.
	// It stays nil for the in-memory database, in which case Sync does nothing.
	connector *libsql.Connector
}
// Sync asks the embedded-replica connector to synchronize with the remote
// database and returns the connector's error, if any.
//
// A Database without a connector (ConnectDB leaves it nil for the in-memory
// database) has nothing to synchronize, so Sync returns nil immediately.
func (db *Database) Sync() error {
	if db.connector == nil {
		return nil
	}
	_, err := db.connector.Sync()
	return err
}
// ConnectDB opens the application database and wraps it in a Database.
// It never returns an error: every failure on the startup path terminates the
// process through log.Fatal.
//
// Configuration comes from three environment variables:
//   - DB_NAME:  file name of the local replica, created under the user's home directory
//   - DB_URL:   URL handed to the embedded-replica connector as the remote database
//   - DB_TOKEN: auth token handed to the connector
//
// If any of them is empty, an in-memory libsql database is opened instead and
// the returned Database has no connector (so Sync is a no-op).
//
// Otherwise an embedded replica is opened with a 30 second sync interval. When
// the local file cannot be stat'ed or was last modified more than an hour ago,
// ConnectDB blocks on a sync before returning; otherwise it returns right away
// and runs one sync in a background goroutine, logging (not aborting on) its
// failure.
func ConnectDB() *Database {
	var (
		name  = os.Getenv("DB_NAME")
		url   = os.Getenv("DB_URL")
		token = os.Getenv("DB_TOKEN")
	)
	if name == "" || url == "" || token == "" {
		db, err := sql.Open("libsql", ":memory:")
		if err != nil {
			log.Fatal("Failed to open in-memory database:", err)
		}
		// Presumably pinned to one connection because each extra connection to
		// ":memory:" would see its own empty database — confirm.
		db.SetMaxOpenConns(1)
		return &Database{DB: db, Ents: []Entity{}, Cols: map[string]*Collection[Entity]{}}
	}

	// Check the error before deriving the replica path from home.
	home, err := os.UserHomeDir()
	if err != nil {
		log.Fatal("Failed to get user home dir:", err)
	}
	path := filepath.Join(home, name)

	connector, err := libsql.NewEmbeddedReplicaConnector(path, url,
		libsql.WithSyncInterval(time.Second*30),
		libsql.WithAuthToken(token))
	if err != nil {
		log.Fatal("Failed to replicate to remote db:", err)
	}

	// NOTE(review): the file is stat'ed after the connector was created. If
	// creating the connector creates or touches the file, the replica will
	// always look fresh here — confirm against go-libsql.
	if info, err := os.Stat(path); err != nil || time.Since(info.ModTime()) > time.Hour {
		log.Println("Syncing database (missing or stale)...")
		if _, err := connector.Sync(); err != nil {
			log.Fatal("Failed to sync to remote db:", err)
		}
	} else {
		log.Println("Using cached database, syncing in background...")
		go func() {
			// The cached replica is already usable, so a failed refresh is
			// reported rather than treated as fatal.
			if _, err := connector.Sync(); err != nil {
				log.Println("Failed to sync to remote db in background:", err)
			}
		}()
	}
	return &Database{DB: sql.OpenDB(connector), Ents: []Entity{}, Cols: map[string]*Collection[Entity]{}, connector: connector}
}
// Model returns a Model whose DB field points at db; every other field is left
// at its zero value.
func (db *Database) Model() Model {
	var model Model
	model.DB = db
	return model
}
// NewModel returns a Model bound to db with the given id. CreatedAt and
// UpdatedAt are both set to the same instant, read from the clock once, so a
// freshly created model never reports an update later than its creation.
func (db *Database) NewModel(id string) Model {
	now := time.Now()
	return Model{DB: db, ID: id, CreatedAt: now, UpdatedAt: now}
}
// Query packages the SQL text and its arguments, together with db's underlying
// connection, into an Iter. Nothing is executed here; the other methods in this
// file run the statement through the Iter's Exec, Scan or All.
//
// Because Database embeds *sql.DB, this method takes precedence over the
// promoted (*sql.DB).Query.
func (db *Database) Query(query string, args ...any) *Iter {
	iter := new(Iter)
	iter.Conn = db.DB
	iter.Text = query
	iter.Args = args
	return iter
}
// Register makes sure table exists with a column for every field that Fields
// reports for ent, then appends ent to db.Ents.
//
// The table is created, if missing, with the base columns ID, CreatedAt and
// UpdatedAt. Each reported field is then added with ALTER TABLE ... ADD COLUMN.
// Re-registering an entity repeats those statements for columns that already
// exist; SQLite rejects them with a "duplicate column name" error, which is
// expected and ignored. Any other ADD COLUMN failure is logged but does not
// abort registration, keeping the step best-effort.
//
// table and the field names, types and defaults are spliced into the SQL text
// as-is, so they must come from trusted code, never from user input.
//
// It returns an error when ent is neither a struct nor a pointer, or when the
// CREATE TABLE statement fails.
func (db *Database) Register(table string, ent Entity) error {
	// Validate first so that a rejected entity does not leave a table behind.
	if kind := reflect.ValueOf(ent).Kind(); kind != reflect.Ptr && kind != reflect.Struct {
		return errors.New("expected struct, got " + kind.String())
	}
	if err := db.Query(`
		CREATE TABLE IF NOT EXISTS ` + table + ` (
			ID        TEXT PRIMARY KEY,
			CreatedAt TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
			UpdatedAt TIMESTAMP DEFAULT CURRENT_TIMESTAMP
		);
	`).Exec(); err != nil {
		return errors.Wrap(err, "failed to create table")
	}
	fields, types, defaults := db.Fields(ent)
	for i, field := range fields {
		err := db.Query(fmt.Sprintf(`
			ALTER TABLE %s ADD COLUMN %s %s DEFAULT %v
		`, table, field, types[i], defaults[i])).Exec()
		if err == nil {
			continue
		}
		// An existing column is the normal case on every registration after
		// the first; only unexpected failures are worth reporting.
		if !strings.Contains(strings.ToLower(err.Error()), "duplicate column") {
			log.Printf("skykit: adding column %s.%s: %v", table, field, err)
		}
	}
	db.Ents = append(db.Ents, ent)
	return nil
}
// Fields describes the columns that back ent's own struct fields. The three
// returned slices are parallel: fields[i] is the Go field name (used verbatim
// as the column name), types[i] its SQL column type and defaults[i] the SQL
// text of its default value.
//
// ent may be a struct or a pointer to one; anything else yields three nil
// slices. The following fields are left out:
//   - embedded (anonymous) fields
//   - unexported fields, which reflection cannot read or scan into
//   - pointers, interfaces and functions
//   - structs other than time.Time
//
// Column types are TEXT for strings, INTEGER for all integer kinds, REAL for
// floats and complex numbers, BOOLEAN for bools, TIMESTAMP for time.Time and
// ANY for every remaining kind (slices, maps, arrays, channels, ...).
//
// A `default:"..."` struct tag overrides the built-in default for the field.
// Its value is used as raw SQL, so string defaults must carry their own quotes.
func (db *Database) Fields(ent Entity) (fields []string, types []string, defaults []string) {
	value := reflect.ValueOf(ent)
	if value.Kind() == reflect.Ptr {
		value = value.Elem()
	}
	if value.Kind() != reflect.Struct {
		return nil, nil, nil
	}
	structType := value.Type()
	for i := range structType.NumField() {
		field := structType.Field(i)
		if field.Anonymous || !field.IsExported() {
			continue
		}
		// time.Time is the one struct type that maps onto a column.
		isTime := field.Type.PkgPath() == "time" && field.Type.Name() == "Time"

		var sqlType, fallback string
		switch kind := field.Type.Kind(); kind {
		case reflect.Ptr, reflect.Interface, reflect.Func:
			continue
		case reflect.Struct:
			if !isTime {
				continue
			}
			sqlType, fallback = "TIMESTAMP", "NULL"
		case reflect.String:
			sqlType, fallback = "TEXT", "''"
		case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64,
			reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Uintptr:
			sqlType, fallback = "INTEGER", "0"
		case reflect.Float32, reflect.Float64, reflect.Complex64, reflect.Complex128:
			sqlType, fallback = "REAL", "0"
		case reflect.Bool:
			sqlType, fallback = "BOOLEAN", "FALSE"
		default:
			sqlType, fallback = "ANY", "NULL"
		}

		fields = append(fields, field.Name)
		types = append(types, sqlType)
		defaults = append(defaults, cmp.Or(field.Tag.Get("default"), fallback))
	}
	return fields, types, defaults
}
// Reflect lists ent's persisted columns together with their current values and
// the addresses of the backing struct fields (suitable as Scan destinations).
//
// The candidate columns are ID, CreatedAt and UpdatedAt followed by whatever
// Fields reports. A candidate that ent has no field for is dropped from all
// three results, so fields[i], values[i] and addrs[i] always describe the same
// column.
//
// ent must be a non-nil pointer to a struct: taking a field's address panics
// when the struct is not addressable (for example when it was passed by
// value). Input that is not a struct at all yields three nil slices.
func (db *Database) Reflect(ent Entity) (fields []string, values []any, addrs []any) {
	value := reflect.ValueOf(ent)
	if value.Kind() == reflect.Ptr {
		value = value.Elem()
	}
	if value.Kind() != reflect.Struct {
		return nil, nil, nil
	}
	own, _, _ := db.Fields(ent)
	candidates := append([]string{"ID", "CreatedAt", "UpdatedAt"}, own...)
	for _, name := range candidates {
		field := value.FieldByName(name)
		if !field.IsValid() {
			continue
		}
		fields = append(fields, name)
		values = append(values, field.Interface())
		addrs = append(addrs, field.Addr().Interface())
	}
	return fields, values, addrs
}
// qualified prefixes each field name with "table." and returns the results in
// the same order. The returned slice is never nil, even for empty input.
func (db *Database) qualified(table string, fields []string) (res []string) {
	res = make([]string, len(fields))
	for i, name := range fields {
		res[i] = fmt.Sprintf("%s.%s", table, name)
	}
	return res
}
// entID reads the string value of ent's ID field. ent may be a struct or a
// pointer to one, and the field may be declared directly or promoted from an
// embedded struct. The result is "" when ent is not a struct (or is a nil
// pointer) or has no field named ID.
func (db *Database) entID(ent Entity) (id string) {
	target := reflect.Indirect(reflect.ValueOf(ent))
	if target.Kind() == reflect.Struct {
		// FieldByName also resolves fields promoted from embedded structs.
		if field := target.FieldByName("ID"); field.IsValid() {
			return field.String()
		}
	}
	return ""
}
// Insert writes ent as a new row of table. Every column reported by Reflect is
// sent as a bound parameter, and the same columns are read back through
// RETURNING into ent's fields, so values assigned by the database end up on
// ent. The error from scanning the returned row is passed back to the caller.
//
// table is spliced into the SQL text as-is and must come from trusted code.
func (db *Database) Insert(table string, ent Entity) error {
	columns, args, dests := db.Reflect(ent)
	var (
		columnList = strings.Join(columns, ", ")
		// One "?" per column, comma separated.
		placeholders = strings.TrimSuffix(strings.Repeat("?, ", len(columns)), ", ")
	)
	statement := fmt.Sprintf(`
		INSERT INTO %[1]s (%[2]s)
		VALUES (%[3]s)
		RETURNING %[2]s
	`, table, columnList, placeholders)
	return db.Query(statement, args...).Scan(dests...)
}
// Get loads the row of table whose ID equals id into ent. The columns selected
// are the ones Reflect reports for ent, qualified with the table name, and they
// are scanned straight into ent's fields. The error from the scan is returned.
//
// table is spliced into the SQL text as-is and must come from trusted code;
// id is passed as a bound parameter.
func (db *Database) Get(table, id string, ent Entity) error {
	fields, _, addrs := db.Reflect(ent)
	columns := strings.Join(db.qualified(table, fields), ", ")
	return db.Query(fmt.Sprintf(`
		SELECT %s
		FROM %s
		WHERE ID = ?
	`, columns, table), id).Scan(addrs...)
}
// Update writes ent's current field values to the row of table that has the
// same ID, and reads the new UpdatedAt back into ent through RETURNING.
//
// Every column reported by Reflect gets a "col = ?" assignment, followed by a
// final "UpdatedAt = CURRENT_TIMESTAMP". UpdatedAt is therefore assigned twice
// in the SET list; this relies on the engine applying the last assignment
// (SQLite documents that behavior — confirm it holds for the deployed libsql).
//
// table is spliced into the SQL text as-is and must come from trusted code.
func (db *Database) Update(table string, ent Entity) error {
	columns, args, dests := db.Reflect(ent)

	var rowID, stampDest any
	assignments := make([]string, 0, len(columns)+1)
	for i, column := range columns {
		if column == "ID" {
			rowID = args[i] // bound to the WHERE clause below
		} else if column == "UpdatedAt" {
			stampDest = dests[i] // receives the RETURNING value
		}
		assignments = append(assignments, fmt.Sprintf("%s = ?", column))
	}
	assignments = append(assignments, "UpdatedAt = CURRENT_TIMESTAMP")

	statement := fmt.Sprintf(`
		UPDATE %s
		SET %s
		WHERE ID = ?
		RETURNING UpdatedAt
	`, table, strings.Join(assignments, ", "))
	return db.Query(statement, append(args, rowID)...).Scan(stampDest)
}
// Delete removes the row of table whose ID matches ent's ID field (as read by
// entID) and returns the error from executing the statement.
//
// table is spliced into the SQL text as-is and must come from trusted code.
func (db *Database) Delete(table string, ent Entity) error {
	statement := fmt.Sprintf(`
		DELETE FROM %s
		WHERE ID = ?
	`, table)
	return db.Query(statement, db.entID(ent)).Exec()
}
// Index creates a non-unique index over columns of table unless one with the
// generated name already exists. With no columns it does nothing and returns nil.
//
// The index name is "idx_<table>_<columns joined by _>", with spaces turned
// into underscores and everything lower-cased, e.g. db.Index("users", "email")
// creates idx_users_email. table and columns are spliced into the SQL text
// as-is and must come from trusted code.
func (db *Database) Index(table string, columns ...string) error {
	if len(columns) < 1 {
		return nil
	}
	name := fmt.Sprintf("idx_%s_%s", table, strings.Join(columns, "_"))
	name = strings.ReplaceAll(name, " ", "_")
	name = strings.ToLower(name)

	columnList := strings.Join(columns, ", ")
	statement := fmt.Sprintf("CREATE INDEX IF NOT EXISTS %s ON %s(%s)", name, table, columnList)
	return db.Query(statement).Exec()
}
// UniqueIndex creates a unique index over columns of table unless one with the
// generated name already exists. With no columns it does nothing and returns nil.
//
// The index name is "uniq_<table>_<columns joined by _>", with spaces turned
// into underscores and everything lower-cased, e.g.
// db.UniqueIndex("users", "email") creates uniq_users_email. table and columns
// are spliced into the SQL text as-is and must come from trusted code.
func (db *Database) UniqueIndex(table string, columns ...string) error {
	if len(columns) < 1 {
		return nil
	}
	name := fmt.Sprintf("uniq_%s_%s", table, strings.Join(columns, "_"))
	name = strings.ReplaceAll(name, " ", "_")
	name = strings.ToLower(name)

	columnList := strings.Join(columns, ", ")
	statement := fmt.Sprintf("CREATE UNIQUE INDEX IF NOT EXISTS %s ON %s(%s)", name, table, columnList)
	return db.Query(statement).Exec()
}
// Cursor prepares a reusable query over table that yields entities of type E.
// ent serves as the prototype: its dynamic type is recorded so One can allocate
// fresh values, and Iter derives the column list from it. query is the SQL that
// follows "SELECT ... FROM table" (a WHERE clause, ordering, and so on) and
// args are its bound parameters. Nothing is executed until Iter or One is called.
func Cursor[E Entity](db *Database, ent E, table, query string, args ...any) *cursor[E] {
	return &cursor[E]{
		db:     db,
		typeOf: reflect.TypeOf(ent),
		entity: ent,
		table:  table,
		query:  query,
		args:   args,
	}
}
// cursor holds everything needed to run one SELECT over a table and map its
// rows onto entities of type E. It is built by Cursor.
type cursor[E Entity] struct {
	// db runs the statement and provides the reflection helpers.
	db     *Database
	// typeOf is reflect.TypeOf of the prototype entity; One calls Elem on it to
	// allocate a new value, so it is expected to be a pointer type.
	typeOf reflect.Type
	// entity is the prototype passed to Cursor; Iter derives its column list from it.
	entity E
	// table is the table named in the FROM clause and used to qualify columns.
	table  string
	// query is the SQL appended after "SELECT ... FROM table".
	query  string
	// args are the values bound to the statement's placeholders.
	args   []any
}
// Iter runs the cursor's SELECT and hands control to visit through the
// callback given to the Iter's All method (presumably once per row — confirm
// against All). visit receives a load function; calling load(ent) scans the
// current row into ent's fields, which lets the caller decide which entity
// value receives each row.
//
// The selected columns are the ones Reflect reports for the prototype entity,
// qualified with the table name. An error returned by visit or by scanning is
// passed back to the caller, except sql.ErrNoRows, which is treated as an
// empty result and turned into nil.
func (c *cursor[E]) Iter(visit func(func(Entity) error) error) error {
	columns, _, _ := c.db.Reflect(c.entity)
	statement := fmt.Sprintf(`SELECT %s FROM %s %s`,
		strings.Join(c.db.qualified(c.table, columns), ", "), c.table, c.query)

	perRow := func(scan ScanFunc) error {
		// load fills whichever entity the visitor supplies from the current row.
		load := func(ent Entity) error {
			_, _, dests := c.db.Reflect(ent)
			return scan(dests...)
		}
		return visit(load)
	}

	err := c.db.Query(statement, c.args...).All(perRow)
	if errors.Is(err, sql.ErrNoRows) {
		return nil
	}
	return err
}
// One runs the cursor's SELECT and scans a single row into a newly allocated
// entity, which is returned together with the scan error.
//
// The entity is created from the prototype's type, so E must be a pointer
// type (typeOf.Elem panics otherwise), and it is bound to the cursor's
// database before scanning. The entity is returned even when the scan fails;
// callers must check the error before relying on its fields.
//
// Columns are qualified with the table name, exactly as Iter and Get do, so a
// query tail that joins another table does not hit ambiguous column names.
func (c *cursor[E]) One() (E, error) {
	ent := reflect.New(c.typeOf.Elem()).Interface().(E)
	ent.GetModel().SetDB(c.db)
	fields, _, addrs := c.db.Reflect(ent)
	columns := strings.Join(c.db.qualified(c.table, fields), ", ")
	statement := fmt.Sprintf(`SELECT %s FROM %s %s`, columns, c.table, c.query)
	return ent, c.db.Query(statement, c.args...).Scan(addrs...)
}
No comments yet.