package core

import (
	"database/sql"
	"errors"
	"fmt"
	"log"
	"time"

	_ "github.com/mattn/go-sqlite3"
)

// AddSource inserts a row for name into the sources table, with its
// lastUpdated column initialized to 0.
//
// Any error reported by db.Exec is returned unchanged.
func AddSource(db DB, name string) error {
	const stmt = `
		insert into sources (name, lastUpdated)
		values (?, 0)
	`
	if _, err := db.Exec(stmt, name); err != nil {
		return err
	}
	return nil
}

// GetSources returns the name column of every row in the sources table, in
// the order the query yields them.
//
// The returned slice is nil when the query produces no rows. An error from
// running the query, scanning a row, or iterating the result set is returned
// unchanged together with a nil slice.
func GetSources(db DB) ([]string, error) {
	const query = `
		select name
		from sources
	`
	rows, err := db.Query(query)
	if err != nil {
		return nil, err
	}
	defer rows.Close()

	var sources []string
	for rows.Next() {
		var current string
		if scanErr := rows.Scan(&current); scanErr != nil {
			return nil, scanErr
		}
		sources = append(sources, current)
	}
	// rows.Next returning false can mean either "done" or "failed"; Err
	// distinguishes the two.
	if iterErr := rows.Err(); iterErr != nil {
		return nil, iterErr
	}
	return sources, nil
}

// SourceExists reports whether the sources table contains at least one row
// whose name equals source.
//
// The answer comes from a count(*) query. If scanning the count fails, the
// scan error is returned alongside the boolean computed from whatever count
// was scanned.
func SourceExists(db DB, source string) (bool, error) {
	var matches int
	scanErr := db.QueryRow("select count(*) from sources where name = ?", source).Scan(&matches)
	exists := matches > 0
	return exists, scanErr
}

// DeleteSource deletes every row in the sources table whose name equals
// name.
//
// Any error reported by db.Exec is returned unchanged.
func DeleteSource(db DB, name string) error {
	const stmt = `
		delete from sources
		where name = ?
	`
	if _, err := db.Exec(stmt, name); err != nil {
		return err
	}
	return nil
}

// GetState reads the state column of the sources row whose name equals
// source.
//
// The scanned bytes and the scan error are returned together; on a scan
// failure the bytes are whatever the destination held at that point.
func GetState(db DB, source string) ([]byte, error) {
	var blob []byte
	scanErr := db.QueryRow("select state from sources where name = ?", source).Scan(&blob)
	return blob, scanErr
}

// SetState writes state into the state column of the sources row whose name
// equals source.
//
// Any error reported by db.Exec is returned unchanged.
func SetState(db DB, source string, state []byte) error {
	if _, err := db.Exec("update sources set state = ? where name = ?", state, source); err != nil {
		return err
	}
	return nil
}

// GetLastUpdated reads the lastUpdated column of the sources row whose name
// equals source and interprets it as whole seconds since the Unix epoch.
//
// The returned time is in UTC. If the scan fails, the scan error is returned
// together with the time corresponding to whatever value was scanned (the
// Unix epoch when nothing was).
func GetLastUpdated(db DB, source string) (time.Time, error) {
	// Scan into int64, the type time.Unix takes, so the stored timestamp is
	// never limited by the platform's int width.
	var updated int64
	err := db.QueryRow("select lastUpdated from sources where name = ?", source).Scan(&updated)
	return time.Unix(updated, 0).UTC(), err
}

// BumpLastUpdated sets the lastUpdated column of the sources row whose name
// equals source to now, stored as whole seconds since the Unix epoch.
//
// Any error reported by db.Exec is returned unchanged.
func BumpLastUpdated(db DB, source string, now time.Time) error {
	const stmt = `
		update sources
		set lastUpdated = ?
		where name = ?
	`
	timestamp := now.Unix()
	if _, err := db.Exec(stmt, timestamp, source); err != nil {
		return err
	}
	return nil
}

// getSourceTtx looks up the integer value of the env named env for the given
// source in the envs table.
//
// A missing row is not an error: it yields 0. Any other scan failure is
// returned unchanged with a value of 0.
func getSourceTtx(db DB, source string, env string) (int, error) {
	const query = `
		select value
		from envs
		where source = ?
		and name = ?
	`
	var value int
	switch scanErr := db.QueryRow(query, source, env).Scan(&value); {
	case scanErr == nil:
		return value, nil
	case errors.Is(scanErr, sql.ErrNoRows):
		// The env is simply not configured for this source.
		return 0, nil
	default:
		return 0, scanErr
	}
}

// getSourceBatcher builds the batching function for a source from its
// INTAKE_BATCH env, whose value is parsed as "<hour>:<minute>".
//
// It returns (nil, nil) when the source has no INTAKE_BATCH env. Otherwise
// the returned function maps a timestamp to the number of whole seconds from
// that timestamp until the next batch cutoff, where the cutoff is
// <hour>:<minute> UTC on the timestamp's UTC date, or on the following day if
// the timestamp is already past it. A timestamp exactly at the cutoff yields
// 0.
//
// An error is returned if the env cannot be read or if its value does not
// parse as two colon-separated integers; the underlying error is wrapped.
func getSourceBatcher(db DB, source string) (func(createdTime time.Time) (tts int), error) {
	row := db.QueryRow(`
		select value
		from envs
		where source = ?
		and name = 'INTAKE_BATCH'
	`, source)
	var batchSpec string
	if err := row.Scan(&batchSpec); err != nil {
		if errors.Is(err, sql.ErrNoRows) {
			// Batching is not configured for this source.
			return nil, nil
		}
		return nil, fmt.Errorf("failed to get batch spec: %w", err)
	}

	var hour, minute int
	if _, err := fmt.Sscanf(batchSpec, "%d:%d", &hour, &minute); err != nil {
		return nil, fmt.Errorf("failed to parse batch spec: %w", err)
	}

	return func(createdTime time.Time) (tts int) {
		// The post-processor in this file passes the processing time here
		// rather than the item's own creation time, because the latter isn't
		// populated for new items when post-processing occurs. Since
		// post-processing is only applied to new items post-fetch, this is
		// close enough.
		//
		// Normalize to UTC first: the cutoff is defined in UTC, so the
		// calendar date must be read in UTC too. Reading Year/Month/Day in
		// the input's own zone can select the wrong day and produce a cutoff
		// that is already in the past.
		created := createdTime.UTC()
		batchCutoff := time.Date(created.Year(), created.Month(), created.Day(), hour, minute, 0, 0, time.UTC)
		if created.After(batchCutoff) {
			// Today's cutoff has passed; roll over to tomorrow's.
			batchCutoff = batchCutoff.AddDate(0, 0, 1)
		}
		return int(batchCutoff.Sub(created).Seconds())
	}, nil
}

// GetSourcePostProcessor builds the function that applies a source's intake
// overrides to an item.
//
// It reads the INTAKE_TTL, INTAKE_TTD, and INTAKE_TTS envs (in that order)
// and then the INTAKE_BATCH batcher for the source; the first lookup error
// aborts and is returned unchanged.
//
// The returned function takes an item by value and returns the modified
// copy:
//   - Ttl is overwritten when the INTAKE_TTL value is non-zero.
//   - Ttd is overwritten when the INTAKE_TTD value is non-zero.
//   - Tts is set from the batcher, evaluated at now, when one is configured;
//     otherwise it is overwritten when the INTAKE_TTS value is non-zero.
func GetSourcePostProcessor(db DB, source string) (func(item Item, now time.Time) Item, error) {
	var ttl, ttd, tts int
	for _, override := range []struct {
		env  string
		dest *int
	}{
		{"INTAKE_TTL", &ttl},
		{"INTAKE_TTD", &ttd},
		{"INTAKE_TTS", &tts},
	} {
		value, err := getSourceTtx(db, source, override.env)
		if err != nil {
			return nil, err
		}
		*override.dest = value
	}

	batcher, err := getSourceBatcher(db, source)
	if err != nil {
		return nil, err
	}

	process := func(item Item, now time.Time) Item {
		if ttl != 0 {
			item.Ttl = ttl
		}
		if ttd != 0 {
			item.Ttd = ttd
		}
		// A configured batcher takes precedence over the fixed TTS value.
		switch {
		case batcher != nil:
			item.Tts = batcher(now)
		case tts != 0:
			item.Tts = tts
		}
		return item
	}
	return process, nil
}

// GetSourceActionInputs loads everything needed to run an action for a
// source: the source's saved state (via GetState), its envs (via GetEnvs),
// and the argv for the named action (via GetArgvForAction), in that order.
//
// On the first failure it returns nil for all three values and an error that
// names the failing step and wraps the underlying error, so callers can
// still inspect the cause with errors.Is or errors.As.
func GetSourceActionInputs(
	db DB,
	source string,
	action string,
) (
	state []byte,
	envs []string,
	argv []string,
	err error,
) {
	state, err = GetState(db, source)
	if err != nil {
		return nil, nil, nil, fmt.Errorf("failed to load state for %s: %w", source, err)
	}

	envs, err = GetEnvs(db, source)
	if err != nil {
		return nil, nil, nil, fmt.Errorf("failed to get envs for %s: %w", source, err)
	}

	argv, err = GetArgvForAction(db, source, action)
	if err != nil {
		return nil, nil, nil, fmt.Errorf("failed to get %s action for %s: %w", action, source, err)
	}

	return state, envs, argv, nil
}

// UpdateWithFetchedItems applies the results of a fetch to a source: it adds
// new items, updates existing items, and deletes expired items. All of the
// work is delegated to updateWithFetchedItemsTx, which is run through
// db.Transact.
//
// It returns the number of new items, the number of deleted items, and the
// error reported by db.Transact.
func UpdateWithFetchedItems(
	db DB,
	source string,
	state []byte,
	items []Item,
	now time.Time,
) (int, int, error) {
	var (
		added   int
		deleted int
	)
	txErr := db.Transact(func(tx DB) error {
		var err error
		added, deleted, err = updateWithFetchedItemsTx(tx, source, state, items, now)
		return err
	})
	return added, deleted, txErr
}

// Implementation logic for [UpdateWithFetchedItems], which executes this inside a transaction.
//
// In order, it:
//  1. loads the source's envs, post-processor, and on_create argv;
//  2. loads the source's existing items and splits the fetched items into
//     new ones (id not seen before) and updates (id already stored);
//  3. post-processes and inserts the new items;
//  4. backfills and updates the existing items;
//  5. runs the on_create action, if any, on each new item, threading state
//     through each invocation, and saves the resulting items;
//  6. deletes existing items that were absent from the fetch and report
//     themselves deletable at now;
//  7. stores the final state and bumps the source's lastUpdated to now.
//
// It returns the number of new items and the number of deleted items. On
// error both counts are 0.
func updateWithFetchedItemsTx(
	db DB,
	source string,
	state []byte,
	items []Item,
	now time.Time,
) (int, int, error) {
	envs, err := GetEnvs(db, source)
	if err != nil {
		return 0, 0, fmt.Errorf("failed to get envs for %s: %w", source, err)
	}

	postProcess, err := GetSourcePostProcessor(db, source)
	if err != nil {
		return 0, 0, fmt.Errorf("failed to get post-processor for %s: %w", source, err)
	}

	// A failure to look up the on_create action is only logged; the update
	// proceeds with whatever argv was returned.
	onCreateArgv, err := GetArgvForAction(db, source, "on_create")
	if err != nil {
		log.Printf("error: failed to get on_create action for %s: %v", source, err)
	}

	// Get all existing items and index them by id.
	existingItems, err := GetAllItemsForSource(db, source, 0, -1)
	if err != nil {
		return 0, 0, err
	}
	existingItemsById := make(map[string]*Item, len(existingItems))
	for i := range existingItems {
		// Point at the slice element, not at a range variable: before Go 1.22
		// the range variable is shared across iterations, so taking its
		// address would make every entry alias the last item.
		existingItemsById[existingItems[i].Id] = &existingItems[i]
	}

	// Split the fetch into adds and updates
	var newItems []Item
	var updatedItems []Item
	for _, item := range items {
		if _, exists := existingItemsById[item.Id]; exists {
			updatedItems = append(updatedItems, item)
		} else {
			newItems = append(newItems, item)
		}
	}

	// Apply post-processing to the new items
	if postProcess != nil {
		for i := range newItems {
			newItems[i] = postProcess(newItems[i], now)
		}
	}

	// Bulk insert the new items
	if err = AddItems(db, newItems); err != nil {
		return 0, 0, err
	}

	// Bulk update the existing items
	for i := range updatedItems {
		BackfillItem(&updatedItems[i], existingItemsById[updatedItems[i].Id])
	}
	if err = UpdateItems(db, updatedItems); err != nil {
		return 0, 0, err
	}

	// If the source has an on-create trigger, run it for each new item
	// On-create errors are ignored to avoid failing the fetch
	if len(onCreateArgv) > 0 {
		var updatedNewItems []Item
		for _, item := range newItems {
			var updatedItem Item
			var errItem Item
			// state is replaced by each invocation's result, so later items
			// (and the final SetState below) see the state left by earlier ones.
			updatedItem, state, errItem, err = ExecuteItemAction(item, onCreateArgv, envs, state, DefaultTimeout)
			if err != nil {
				// Any result of AddErrorItem is discarded here.
				AddErrorItem(db, errItem)
				log.Printf("error: on_create failed for %s/%s: %v", item.Source, item.Id, err)
			}
			// NOTE(review): the item returned by a failed action is still
			// queued for update — confirm ExecuteItemAction returns a usable
			// item on error.
			updatedNewItems = append(updatedNewItems, updatedItem)
		}
		// Still best-effort, but surface the failure instead of dropping it.
		if err := UpdateItems(db, updatedNewItems); err != nil {
			log.Printf("error: failed to save on_create results for %s: %v", source, err)
		}
	}

	// An existing item is expired when the fetch no longer contains its id.
	fetchedIds := make(map[string]bool, len(items))
	for _, item := range items {
		fetchedIds[item.Id] = true
	}

	// Check expired items for deletion
	idsToDelete := map[string]bool{}
	for _, item := range existingItems {
		if !fetchedIds[item.Id] && item.Deletable(now) {
			idsToDelete[item.Id] = true
		}
	}

	// Delete each item to be deleted
	for id := range idsToDelete {
		if _, err = DeleteItem(db, source, id); err != nil {
			return 0, 0, err
		}
	}

	if err = SetState(db, source, state); err != nil {
		return 0, 0, err
	}

	if err = BumpLastUpdated(db, source, now); err != nil {
		return 0, 0, err
	}

	return len(newItems), len(idsToDelete), nil
}