124 lines
3.1 KiB
Go
124 lines
3.1 KiB
Go
package web
|
|
|
|
import (
|
|
"crypto/rand"
|
|
"database/sql"
|
|
"errors"
|
|
"fmt"
|
|
"log"
|
|
"net/http"
|
|
"time"
|
|
|
|
"github.com/Jaculabilis/intake/core"
|
|
"github.com/Jaculabilis/intake/web/html"
|
|
)
|
|
|
|
var (
	// AuthCookieName is the name of the cookie that carries the login session ID.
	AuthCookieName = "intake_auth"

	// AuthDuration is how long a newly created login session stays valid.
	AuthDuration = 7 * 24 * time.Hour
)
|
|
|
|
func newSession(db core.DB) (string, error) {
|
|
bytes := make([]byte, 32)
|
|
_, err := rand.Read(bytes)
|
|
if err != nil {
|
|
return "", err
|
|
}
|
|
session := fmt.Sprintf("%x", bytes)
|
|
expires := int(time.Now().Add(AuthDuration).Unix())
|
|
_, err = db.Exec(`
|
|
insert into sessions (id, expires)
|
|
values (?, ?)
|
|
`, session, expires)
|
|
if err != nil {
|
|
return "", err
|
|
}
|
|
return session, nil
|
|
}
|
|
|
|
func checkSession(db core.DB, session string) (bool, error) {
|
|
row := db.QueryRow(`
|
|
select expires
|
|
from sessions
|
|
where id = ?
|
|
`, session)
|
|
var expires int
|
|
if err := row.Scan(&expires); err != nil {
|
|
if errors.Is(err, sql.ErrNoRows) {
|
|
return false, nil
|
|
}
|
|
return false, err
|
|
}
|
|
expiration := time.Unix(int64(expires), 0).UTC()
|
|
if time.Now().After(expiration) {
|
|
return false, nil
|
|
}
|
|
return true, nil
|
|
}
|
|
|
|
func renderLoginWithErrorMessage(writer http.ResponseWriter, req *http.Request, message string) {
|
|
// If an htmx interaction caused the auth error, refresh the page to get the login rendered
|
|
if req.Header.Get("HX-Request") != "" {
|
|
writer.Header()["HX-Refresh"] = []string{"true"}
|
|
writer.WriteHeader(http.StatusForbidden)
|
|
return
|
|
}
|
|
|
|
data := html.LoginData{Error: message}
|
|
if err := html.Login(writer, data); err != nil {
|
|
log.Printf("render error: %v", err)
|
|
}
|
|
}
|
|
|
|
func (env *Env) authed(handler http.HandlerFunc) http.HandlerFunc {
|
|
return func(writer http.ResponseWriter, req *http.Request) {
|
|
required, err := core.HasPassword(env.db)
|
|
if err != nil {
|
|
renderLoginWithErrorMessage(writer, req, fmt.Sprintf("error: failed to check for password: %v", err))
|
|
return
|
|
}
|
|
if required {
|
|
cookie, err := req.Cookie(AuthCookieName)
|
|
if errors.Is(err, http.ErrNoCookie) {
|
|
renderLoginWithErrorMessage(writer, req, "Your session is expired or invalid")
|
|
return
|
|
}
|
|
if valid, err := checkSession(env.db, cookie.Value); !valid || err != nil {
|
|
renderLoginWithErrorMessage(writer, req, "Your session is expired or invalid")
|
|
return
|
|
}
|
|
}
|
|
handler(writer, req)
|
|
}
|
|
}
|
|
|
|
func (env *Env) login(writer http.ResponseWriter, req *http.Request) {
|
|
if err := req.ParseForm(); err != nil {
|
|
http.Error(writer, fmt.Sprintf("error: failed to parse form: %v", err), http.StatusOK)
|
|
return
|
|
}
|
|
password := req.PostForm.Get("password")
|
|
|
|
pass, err := core.CheckPassword(env.db, password)
|
|
if err != nil {
|
|
http.Error(writer, fmt.Sprintf("error: failed to check password: %v", err), http.StatusOK)
|
|
return
|
|
}
|
|
if !pass {
|
|
http.Error(writer, "Incorrect password", http.StatusOK)
|
|
return
|
|
}
|
|
|
|
session, err := newSession(env.db)
|
|
if err != nil {
|
|
http.Error(writer, fmt.Sprintf("error: failed to start session: %v", err), http.StatusOK)
|
|
return
|
|
}
|
|
|
|
cookie := http.Cookie{
|
|
Name: AuthCookieName,
|
|
Value: session,
|
|
}
|
|
http.SetCookie(writer, &cookie)
|
|
writer.Header()["HX-Refresh"] = []string{"true"}
|
|
writer.WriteHeader(http.StatusNoContent)
|
|
}
|