package skykit
import (
	"cmp"
	"fmt"
	"html/template"
	"io"
	"io/fs"
	"log"
	"net/http"
	"os"
	"time"
)
// Application bundles the database connection, the authentication service,
// the HTML template engine and the registered controllers of a skykit app.
type Application struct {
	// Database is the connection opened by ConnectDB in New.
	Database *Database
	// Users is the authentication service built from Database in New.
	Users *Authentication
	// views are the file systems whose views/*.html and views/*/*.html
	// templates Server parses; nil entries are skipped.
	views []fs.FS
	// funcs are extra template functions. Server and Render add them after
	// the built-in functions, so an entry here overrides a built-in of the
	// same name.
	funcs template.FuncMap
	// engine is the template set Server parses views into and Render looks
	// pages up in.
	engine *template.Template
	// controllers maps a name to a Handler. Each one is exposed to templates
	// as a function of that name and can be fetched with Use.
	controllers map[string]Handler
}
// New returns an Application with a freshly connected database, an
// authentication service built on it, and views as its only template source.
// The template function and controller registries start empty. No templates
// are parsed here; that happens in Server.
func New(views fs.FS) *Application {
	app := new(Application)
	app.Database = ConnectDB()
	app.Users = NewAuthentication(app.Database)
	app.views = []fs.FS{views}
	app.funcs = make(template.FuncMap)
	app.engine = template.New("application")
	app.controllers = make(map[string]Handler)
	return app
}
// Serve builds an Application from views, applies each option in order, and
// starts the HTTP server. It blocks while the server runs. When the server
// stops with an error (for example because the port is already in use), the
// error is logged and the process exits with status 1.
func Serve(views fs.FS, opts ...Option) {
	app := New(views)
	for _, opt := range opts {
		opt(app)
	}
	// Start hands back the error from ListenAndServe, which is never nil.
	// Dropping it made startup failures such as a busy port invisible.
	if err := app.Start(); err != nil {
		log.Fatal("Server stopped: ", err)
	}
}
// Use returns the controller registered under name. When no controller has
// that name, it returns the zero value of Handler.
func (app *Application) Use(name string) Handler {
	if ctrl, found := app.controllers[name]; found {
		return ctrl
	}
	var none Handler
	return none
}
// Protect wraps fn so it only runs for requests that accessCheck approves.
// A nil accessCheck lets every request through. When accessCheck returns
// false, fn is skipped and the wrapper writes nothing itself; the response is
// whatever accessCheck wrote to w.
func (app *Application) Protect(fn http.HandlerFunc, accessCheck AccessCheck) http.HandlerFunc {
	return func(w http.ResponseWriter, r *http.Request) {
		denied := accessCheck != nil && !accessCheck(app, w, r)
		if denied {
			return
		}
		fn(w, r)
	}
}
// Start prepares templates and the listen address through Server, then serves
// HTTP on that address with the handler Server returns. It blocks until the
// server fails and returns that error, which is always non-nil.
func (app *Application) Start() error {
	addr, handler := app.Server()
	srv := &http.Server{
		Addr:    addr,
		Handler: handler,
		// http.ListenAndServe sets no timeouts, so a client that trickles its
		// request headers could hold a connection open forever. Bounding only
		// the header phase leaves long-running responses unaffected.
		ReadHeaderTimeout: 10 * time.Second,
	}
	return srv.ListenAndServe()
}
// Server finalizes the template engine and reports where to listen.
//
// It registers these template functions, in order:
//   - app, req, db and CurrentUser;
//   - one function per controller;
//   - the user-supplied funcs, which override earlier entries with the same name.
//
// req and CurrentUser are bound to a nil request here, and controllers return
// the raw Handler. These placeholders exist so templates parse; Render rebinds
// them for each request.
//
// For every non-nil view source it parses views/*.html, and it is fatal if
// that fails (including when nothing matches). It also parses templates one
// directory level below views/. Failures there are logged, and an absence of
// such templates is not an error.
//
// The returned address is 0.0.0.0 on $PORT (default 5000). The handler is
// http.DefaultServeMux.
func (app *Application) Server() (string, http.Handler) {
	funcs := template.FuncMap{
		"app": func() *Application { return app },
		"req": func() *http.Request { return nil },
		"db":  func() *Database { return app.Database },
		"CurrentUser": func() *User {
			user, _ := app.Users.Authenticate(nil)
			return user
		},
	}
	for name, ctrl := range app.controllers {
		funcs[name] = func() Handler { return ctrl }
	}
	for name, fn := range app.funcs {
		funcs[name] = fn
	}
	if app.engine == nil {
		app.engine = template.New("")
	}
	app.engine = app.engine.Funcs(funcs)

	// path.Match has no recursive wildcard: "**" behaves like "*", so this
	// pattern loads exactly one level of subdirectories (views/<dir>/<file>.html).
	const nestedViews = "views/**/*.html"
	for _, source := range app.views {
		if source == nil {
			continue
		}
		tmpl, err := app.engine.ParseFS(source, "views/*.html")
		if err != nil {
			log.Fatal("Failed to parse root views: ", err)
		}
		app.engine = tmpl

		// ParseFS treats an empty match as an error. Skip the parse when there
		// are no subdirectory views instead of logging a false failure. A Glob
		// error falls through so ParseFS reports it.
		if matches, globErr := fs.Glob(source, nestedViews); globErr == nil && len(matches) == 0 {
			continue
		}
		if tmpl, err := app.engine.ParseFS(source, nestedViews); err == nil {
			app.engine = tmpl
		} else {
			log.Print("Failed to parse views: ", err)
		}
	}
	addr := "0.0.0.0:" + cmp.Or(os.Getenv("PORT"), "5000")
	return addr, http.DefaultServeMux
}
// Render executes the template named page with data and writes the output to w.
//
// Template functions are bound to this call:
//   - req returns r;
//   - CurrentUser authenticates r;
//   - db and app return the application's values;
//   - each controller name returns ctrl.Handle(r);
//   - functions in app.funcs are added last and override these built-ins.
//
// The bindings are installed on a per-call clone of the template set.
// Template.Funcs mutates the function table shared by every template
// associated with app.engine. Binding on the shared set therefore let
// concurrent requests observe each other's request and current user.
//
// If page is not defined, an http.ResponseWriter receives a 404. Any other
// writer receives a message, and then the process exits with status 1.
// If the template set cannot be cloned, the error is logged and an HTTP
// writer receives a 500.
// If the page fails to execute, the error is logged and "error-message.html"
// is rendered with the error as its data. Output already written for page
// stays in w.
func (app *Application) Render(w io.Writer, r *http.Request, page string, data any) {
	if app.engine.Lookup(page) == nil {
		log.Printf("Template not found: %s", page)
		if rw, ok := w.(http.ResponseWriter); ok {
			http.Error(rw, fmt.Sprintf("Template not found: %s", page), http.StatusNotFound)
			return
		}
		fmt.Fprintf(w, "Template not found in non-HTTP context: %s", page)
		os.Exit(1)
	}

	// html/template refuses to Clone a set once any of its templates has been
	// executed. So app.engine itself is never executed here; both the page and
	// the error page run on the clone.
	tmpl, err := app.engine.Clone()
	if err != nil {
		log.Print("Error rendering: ", err)
		if rw, ok := w.(http.ResponseWriter); ok {
			http.Error(rw, http.StatusText(http.StatusInternalServerError), http.StatusInternalServerError)
		}
		return
	}

	funcs := template.FuncMap{
		"app": func() *Application { return app },
		"req": func() *http.Request { return r },
		"db":  func() *Database { return app.Database },
		"CurrentUser": func() *User {
			// The authentication error is deliberately dropped; templates get
			// whatever user value Authenticate returns (presumably nil when
			// unauthenticated — TODO confirm).
			user, _ := app.Users.Authenticate(r)
			return user
		},
	}
	for name, ctrl := range app.controllers {
		// Relies on Go 1.22 per-iteration loop variables (the file already
		// needs 1.22 for cmp.Or).
		funcs[name] = func() Handler { return ctrl.Handle(r) }
	}
	for name, fn := range app.funcs {
		funcs[name] = fn
	}
	tmpl.Funcs(funcs)

	if execErr := tmpl.ExecuteTemplate(w, page, data); execErr != nil {
		log.Print("Error rendering: ", execErr)
		if pageErr := tmpl.ExecuteTemplate(w, "error-message.html", execErr); pageErr != nil {
			log.Print("Error rendering error-message.html: ", pageErr)
		}
	}
}