diff --git a/cmd/actionExecute.go b/cmd/actionExecute.go index 565cd81..793018b 100644 --- a/cmd/actionExecute.go +++ b/cmd/actionExecute.go @@ -91,7 +91,7 @@ func actionExecute( db := openAndMigrateDb() - state, envs, argv, postProcess, err := core.GetSourceActionInputs(db, source, action) + state, envs, argv, err := core.GetSourceActionInputs(db, source, action) if err != nil { log.Fatalf("error: failed to load data for %s: %v", source, err) } @@ -109,7 +109,7 @@ func actionExecute( } } - newItem, newState, errItem, err := core.ExecuteItemAction(item, argv, envs, state, duration, postProcess) + newItem, newState, errItem, err := core.ExecuteItemAction(item, argv, envs, state, duration) if err != nil { core.AddErrorItem(db, errItem) log.Fatalf("error executing %s: %v", action, err) diff --git a/cmd/sourceFetch.go b/cmd/sourceFetch.go index 6ef6bba..e59dbf5 100644 --- a/cmd/sourceFetch.go +++ b/cmd/sourceFetch.go @@ -55,12 +55,12 @@ func sourceFetch(source string, format string, timeout string, dryRun bool) { db := openAndMigrateDb() - state, envs, argv, postProcess, err := core.GetSourceActionInputs(db, source, "fetch") + state, envs, argv, err := core.GetSourceActionInputs(db, source, "fetch") if err != nil { log.Fatalf("error: failed to load data for %s: %v", source, err) } - items, newState, errItem, err := core.Execute(source, argv, envs, state, "", duration, postProcess) + items, newState, errItem, err := core.Execute(source, argv, envs, state, "", duration) if err != nil { core.AddErrorItem(db, errItem) log.Fatalf("error: failed to execute fetch: %v", err) diff --git a/cmd/sourceTest.go b/cmd/sourceTest.go index 5bf5d01..e356e9e 100644 --- a/cmd/sourceTest.go +++ b/cmd/sourceTest.go @@ -45,7 +45,7 @@ func sourceTest(env []string, format string, timeout string, cmd []string) { log.Fatalf("error: invalid duration: %v", err) } - items, state, _, err := core.Execute("test", cmd, env, nil, "", duration, nil) + items, state, _, err := core.Execute("test", cmd, env, nil, "", duration) log.Printf("returned %d items", len(items)) log.Printf("wrote %d bytes of state", len(state)) if err != nil { diff --git a/core/cron.go b/core/cron.go index 3c4d007..6904d81 100644 --- a/core/cron.go +++ b/core/cron.go @@ -203,13 +203,13 @@ func fetchReadySources(db DB) { } log.Printf("%s: fetching", schedule.Source) - state, envs, argv, postProcess, err := GetSourceActionInputs(db, schedule.Source, "fetch") + state, envs, argv, err := GetSourceActionInputs(db, schedule.Source, "fetch") if err != nil { log.Printf("error: failed to load data for %s: %v", schedule.Source, err) continue } - items, newState, errItem, err := Execute(schedule.Source, argv, envs, state, "", DefaultTimeout, postProcess) + items, newState, errItem, err := Execute(schedule.Source, argv, envs, state, "", DefaultTimeout) if err != nil { AddErrorItem(db, errItem) log.Printf("error: failed to execute fetch: %v", err) diff --git a/core/execute.go b/core/execute.go index 7c45dc1..3b3a7a4 100644 --- a/core/execute.go +++ b/core/execute.go @@ -62,7 +62,6 @@ func Execute( state []byte, input string, timeout time.Duration, - postProcess func(item Item) Item, ) ( items []Item, newState []byte, @@ -193,9 +192,6 @@ monitor: } itemIds[item.Id] = true - if postProcess != nil { - item = postProcess(item) - } item.Active = true // These fields aren't up to item.Created = 0 // the action to set and item.Source = source // shouldn't be overrideable @@ -269,7 +265,6 @@ func ExecuteItemAction( env []string, state []byte, timeout time.Duration, - postProcess func(item Item) Item, ) ( newItem Item, newState []byte, @@ -285,7 +280,7 @@ func ExecuteItemAction( return } - res, newState, errItem, err := Execute(item.Source, argv, env, state, string(itemJson), timeout, postProcess) + res, newState, errItem, err := Execute(item.Source, argv, env, state, string(itemJson), timeout) if err != nil { err = fmt.Errorf("failed to execute action for %s/%s: %v", item.Source, item.Id, err) errItem = makeErrorItem(err, nil) diff --git a/core/execute_test.go b/core/execute_test.go index 0078d26..a8761e4 100644 --- a/core/execute_test.go +++ b/core/execute_test.go @@ -26,7 +26,7 @@ func TestExecute(t *testing.T) { } } execute := func(argv []string) ([]Item, error) { - item, _, _, err := Execute("_", argv, nil, nil, "", time.Minute, nil) + item, _, _, err := Execute("_", argv, nil, nil, "", time.Minute) return item, err } @@ -49,7 +49,7 @@ func TestExecute(t *testing.T) { }) t.Run("Timeout", func(t *testing.T) { - res, _, _, err := Execute("_", []string{"sleep", "10"}, nil, nil, "", time.Millisecond, nil) + res, _, _, err := Execute("_", []string{"sleep", "10"}, nil, nil, "", time.Millisecond) assertNotNil(t, err) assertLen(t, res, 0) }) @@ -64,7 +64,7 @@ func TestExecute(t *testing.T) { }) t.Run("ReadFromStdin", func(t *testing.T) { - res, _, _, err := Execute("_", []string{"jq", "-cR", `{id: .}`}, nil, nil, "bar", time.Minute, nil) + res, _, _, err := Execute("_", []string{"jq", "-cR", `{id: .}`}, nil, nil, "bar", time.Minute) assertNil(t, err) assertLen(t, res, 1) if res[0].Id != "bar" { @@ -73,7 +73,7 @@ func TestExecute(t *testing.T) { }) t.Run("SetEnv", func(t *testing.T) { - res, _, _, err := Execute("_", []string{"jq", "-cn", `{id: env.HELLO}`}, []string{"HELLO=baz"}, nil, "", time.Minute, nil) + res, _, _, err := Execute("_", []string{"jq", "-cn", `{id: env.HELLO}`}, []string{"HELLO=baz"}, nil, "", time.Minute) assertNil(t, err) assertLen(t, res, 1) if res[0].Id != "baz" { @@ -160,7 +160,7 @@ func TestExecute(t *testing.T) { t.Run("ReadState", func(t *testing.T) { argv := []string{"sh", "-c", `cat $STATE_PATH | jq -cR '{id: "greeting", title: .} | .title = "Hello " + .title'`} - res, _, _, err := Execute("_", argv, nil, []byte("world"), "", time.Minute, nil) + res, _, _, err := Execute("_", argv, nil, []byte("world"), "", time.Minute) assertNil(t, err) assertLen(t, res, 1) if res[0].Title != "Hello world" { @@ -170,7 +170,7 @@ func TestExecute(t *testing.T) { t.Run("WriteState", func(t *testing.T) { argv := []string{"sh", "-c", `printf "Hello world" > $STATE_PATH; jq -cn '{id: "test"}'`} - res, newState, _, err := Execute("_", argv, nil, nil, "", time.Minute, nil) + res, newState, _, err := Execute("_", argv, nil, nil, "", time.Minute) assertNil(t, err) assertLen(t, res, 1) if string(newState) != "Hello world" { @@ -178,22 +178,9 @@ func TestExecute(t *testing.T) { } }) - t.Run("PostprocessSetTtl", func(t *testing.T) { - argv := []string{"jq", "-cn", `{id: "foo"}`} - res, _, _, err := Execute("_", argv, nil, nil, "", time.Minute, func(item Item) Item { - item.Ttl = 123456 - return item - }) - assertNil(t, err) - assertLen(t, res, 1) - if res[0].Ttl != 123456 { - t.Fatalf("expected ttl to be set to 123456, got %d", res[0].Ttl) - } - }) - t.Run("ErrorItem", func(t *testing.T) { argv := []string{"sh", "-c", `echo 1>&2 Hello; jq -cn '{id: "box"}'; echo 1>&2 World; printf '{"whoops": "my bad"}'`} - _, _, errItem, err := Execute("test", argv, nil, nil, "", time.Minute, nil) + _, _, errItem, err := Execute("test", argv, nil, nil, "", time.Minute) assertNotNil(t, err) if errItem.Id == "" { t.Error("missing erritem id") diff --git a/core/source.go b/core/source.go index 31bc8e8..907373e 100644 --- a/core/source.go +++ b/core/source.go @@ -138,27 +138,21 @@ func GetSourceActionInputs( state []byte, envs []string, argv []string, - postProcess func(Item) Item, err error, ) { state, err = GetState(db, source) if err != nil { - return nil, nil, nil, nil, fmt.Errorf("failed to load state for %s: %v", source, err) + return nil, nil, nil, fmt.Errorf("failed to load state for %s: %v", source, err) } envs, err = GetEnvs(db, source) if err != nil { - return nil, nil, nil, nil, fmt.Errorf("failed to get envs for %s: %v", source, err) + return nil, nil, nil, fmt.Errorf("failed to get envs for %s: %v", source, err) } argv, err = GetArgvForAction(db, source, action) if err != nil { - return nil, nil, nil, nil, fmt.Errorf("failed to get %s action for %s: %v", action, source, err) - } - - postProcess, err = GetSourcePostProcessor(db, source) - if err != nil { - return nil, nil, nil, nil, fmt.Errorf("failed to get %s post-processor: %v", source, err) + return nil, nil, nil, fmt.Errorf("failed to get %s action for %s: %v", action, source, err) } return } @@ -191,6 +185,21 @@ func updateWithFetchedItemsTx( items []Item, now time.Time, ) (int, int, error) { + envs, err := GetEnvs(db, source) + if err != nil { + return 0, 0, fmt.Errorf("failed to get envs for %s: %v", source, err) + } + + postProcess, err := GetSourcePostProcessor(db, source) + if err != nil { + return 0, 0, fmt.Errorf("failed to get post-processor for %s: %v", source, err) + } + + onCreateArgv, err := GetArgvForAction(db, source, "on_create") + if err != nil { + log.Printf("error: failed to get on_create action for %s: %v", source, err) + } + // Get all existing items existingItems, err := GetAllItemsForSource(db, source, 0, -1) if err != nil { @@ -214,6 +223,13 @@ func updateWithFetchedItemsTx( } } + // Apply post-processing to the new items + if postProcess != nil { + for i := range newItems { + newItems[i] = postProcess(newItems[i]) + } + } + // Bulk insert the new items if err = AddItems(db, newItems); err != nil { return 0, 0, err @@ -227,25 +243,14 @@ func updateWithFetchedItemsTx( return 0, 0, err } - envs, err := GetEnvs(db, source) - if err != nil { - return 0, 0, fmt.Errorf("failed to get envs for %s: %v", source, err) - } - - postProcess, err := GetSourcePostProcessor(db, source) - if err != nil { - return 0, 0, fmt.Errorf("failed to get post-processor for %s: %v", source, err) - } - // If the source has an on-create trigger, run it for each new item // On-create errors are ignored to avoid failing the fetch - onCreateArgv, err := GetArgvForAction(db, source, "on_create") - if err == nil && len(onCreateArgv) > 0 { + if len(onCreateArgv) > 0 { var updatedNewItems []Item for _, item := range newItems { var updatedItem Item var errItem Item - updatedItem, state, errItem, err = ExecuteItemAction(item, onCreateArgv, envs, state, DefaultTimeout, postProcess) + updatedItem, state, errItem, err = ExecuteItemAction(item, onCreateArgv, envs, state, DefaultTimeout) if err != nil { AddErrorItem(db, errItem) log.Printf("error: on_create failed for %s/%s: %v", item.Source, item.Id, err) diff --git a/core/source_test.go b/core/source_test.go index b39c76e..f395f23 100644 --- a/core/source_test.go +++ b/core/source_test.go @@ -147,7 +147,7 @@ func TestOnCreateAction(t *testing.T) { execute := func(argv []string) []Item { t.Helper() - items, _, _, err := Execute("test", argv, nil, nil, "", time.Minute, nil) + items, _, _, err := Execute("test", argv, nil, nil, "", time.Minute) if err != nil { t.Fatalf("unexpected error executing test fetch: %v", err) } @@ -300,7 +300,7 @@ func TestSourceState(t *testing.T) { } } -func TestSourceTtx(t *testing.T) { +func TestSourcePostProcessor(t *testing.T) { db := EphemeralDb(t) if err := AddSource(db, "s"); err != nil { t.Fatal(err) @@ -323,6 +323,34 @@ func TestSourceTtx(t *testing.T) { } } +func TestSourceUpdateAppliesPostProcess(t *testing.T) { + db := EphemeralDb(t) + if err := AddSource(db, "s"); err != nil { + t.Fatal(err) + } + if err := SetEnvs(db, "s", []string{ + "INTAKE_TTL=30", + "INTAKE_TTD=60", + "INTAKE_TTS=90", + }); err != nil { + t.Fatal(err) + } + + item := Item{Source: "s", Id: "i"} + add, del, err := UpdateWithFetchedItems(db, "s", nil, []Item{item}, time.Now()) + if add != 1 || del != 0 || err != nil { + t.Fatalf("expected 1 add, got %d and err %v", add, err) + } + + after, err := GetItem(db, "s", "i") + if err != nil { + t.Fatalf("item not added: %v", err) + } + if after.Ttl != 30 || after.Ttd != 60 || after.Tts != 90 { + t.Fatalf("Missing value after postProcess: ttl = %d, ttd = %d, tts = %d", after.Ttl, after.Ttd, after.Tts) + } +} + func TestSourceLastUpdated(t *testing.T) { db := EphemeralDb(t) if err := AddSource(db, "s"); err != nil { diff --git a/web/item.go b/web/item.go index 4f26136..887eef5 100644 --- a/web/item.go +++ b/web/item.go @@ -69,7 +69,7 @@ func (env *Env) doAction(writer http.ResponseWriter, req *http.Request) { id := req.PathValue("id") action := req.PathValue("action") - state, envs, argv, postProcess, err := core.GetSourceActionInputs(env.db, source, action) + state, envs, argv, err := core.GetSourceActionInputs(env.db, source, action) if err != nil { http.Error(writer, fmt.Sprintf("error: failed to load data for %s: %v", source, err), 500) } @@ -85,7 +85,7 @@ func (env *Env) doAction(writer http.ResponseWriter, req *http.Request) { return } - newItem, newState, errItem, err := core.ExecuteItemAction(item, argv, envs, state, core.DefaultTimeout, postProcess) + newItem, newState, errItem, err := core.ExecuteItemAction(item, argv, envs, state, core.DefaultTimeout) if err != nil { core.AddErrorItem(env.db, errItem) http.Error(writer, err.Error(), 500) diff --git a/web/source.go b/web/source.go index 8fbfc00..00c229d 100644 --- a/web/source.go +++ b/web/source.go @@ -56,13 +56,13 @@ func (env *Env) fetchSource(writer http.ResponseWriter, req *http.Request) { return } - state, envs, argv, postProcess, err := core.GetSourceActionInputs(env.db, source, "fetch") + state, envs, argv, err := core.GetSourceActionInputs(env.db, source, "fetch") if err != nil { http.Error(writer, fmt.Sprintf("error: failed to get data for %s: %v", source, err.Error()), 500) return } - items, newState, errItem, err := core.Execute(source, argv, envs, state, "", core.DefaultTimeout, postProcess) + items, newState, errItem, err := core.Execute(source, argv, envs, state, "", core.DefaultTimeout) if err != nil { core.AddErrorItem(env.db, errItem) http.Error(writer, fmt.Sprintf("error: failed to execute fetch: %v", err.Error()), 500)