diff --git a/pkg/provider/handlers/function_list.go b/pkg/provider/handlers/function_list.go index 2ffe26fe..406b9385 100644 --- a/pkg/provider/handlers/function_list.go +++ b/pkg/provider/handlers/function_list.go @@ -4,6 +4,7 @@ import ( "context" "errors" "log" + "strings" "time" "github.com/containerd/containerd" @@ -51,7 +52,10 @@ func ListFunctions(client *containerd.Client, namespace string) (map[string]*Fun name := c.ID() f, err := GetFunction(client, name, namespace) if err != nil { - log.Printf("skipping %s, error: %s", name, err) + if !strings.Contains(err.Error(), "unable to get IP address for container") { + log.Printf("List functions, skipping: %s, error: %s", name, err) + } + } else { functions[name] = &f } diff --git a/pkg/service/service.go b/pkg/service/service.go index a6ad9713..9025a184 100644 --- a/pkg/service/service.go +++ b/pkg/service/service.go @@ -6,6 +6,7 @@ import ( "log" "os" "path/filepath" + "strings" "sync" "time" @@ -45,10 +46,24 @@ func Remove(ctx context.Context, client *containerd.Client, name string) error { log.Printf("Status of %s is: %s\n", name, status.Status) } - log.Printf("Need to kill task: %s\n", name) - if err = killTask(ctx, t); err != nil { + var gracePeriod = time.Second * 30 + spec, err := t.Spec(ctx) + if err == nil { + for _, p := range spec.Process.Env { + k, v, ok := strings.Cut(p, "=") + if ok && k == "grace_period" { + periodVal, err := time.ParseDuration(v) + if err == nil { + gracePeriod = periodVal + } + } + } + } + + if err = killTask(ctx, t, gracePeriod); err != nil { return fmt.Errorf("error killing task %s, %s, %w", container.ID(), name, err) } + } if err := container.Delete(ctx, containerd.WithSnapshotCleanup); err != nil { @@ -66,14 +81,13 @@ func Remove(ctx context.Context, client *containerd.Client, name string) error { } // Adapted from Stellar - https://github.com/stellar -func killTask(ctx context.Context, task containerd.Task) error { - - killTimeout := 30 * time.Second +func killTask(ctx context.Context, task containerd.Task, gracePeriod time.Duration) error { wg := &sync.WaitGroup{} wg.Add(1) var err error + waited := false go func() { defer wg.Done() if task != nil { @@ -89,18 +103,35 @@ func killTask(ctx context.Context, task containerd.Task) error { select { case <-wait: - task.Delete(ctx) + waited = true return - case <-time.After(killTimeout): + case <-time.After(gracePeriod): + log.Printf("Sending SIGKILL to: %s after: %s", task.ID(), gracePeriod.Round(time.Second).String()) if err := task.Kill(ctx, unix.SIGKILL, containerd.WithKillAll); err != nil { - log.Printf("error force killing container task: %s", err) + log.Printf("error sending SIGKILL to task: %s", err) } + return } } }() wg.Wait() + if task != nil { + if !waited { + wait, err := task.Wait(ctx) + if err != nil { + log.Printf("error waiting on task after kill: %s", err) + } + + <-wait + } + + if _, err := task.Delete(ctx); err != nil { + return err + } + } + return err }