diff --git a/client/build_test.go b/client/build_test.go index 1cd04222f..463088ec0 100644 --- a/client/build_test.go +++ b/client/build_test.go @@ -40,6 +40,7 @@ import ( "github.com/stretchr/testify/require" "github.com/tonistiigi/fsutil" "golang.org/x/crypto/ssh/agent" + "golang.org/x/sync/errgroup" "google.golang.org/grpc" "google.golang.org/grpc/codes" ) @@ -54,6 +55,7 @@ func TestClientGatewayIntegration(t *testing.T) { testClientGatewayContainerExecPipe, testClientGatewayContainerExecPipeRelease, testClientGatewayContainerExecPipeSignalKill, + testClientGatewayContainerExecStdinCloseRace, testClientGatewayContainerCancelOnRelease, testClientGatewayContainerPID1Fail, testClientGatewayContainerPID1Exit, @@ -548,6 +550,137 @@ func testClientGatewayContainerExecPipeWithCleanup(t *testing.T, sb integration. checkAllReleasable(t, c, sb, true) } +// testClientGatewayContainerExecStdinCloseRace starts a gateway container and +// runs many concurrent execs that copy stdin into files. It verifies that every +// exec receives stdin before EOF and uses the timeout only to catch stuck execs. +func testClientGatewayContainerExecStdinCloseRace(t *testing.T, sb integration.Sandbox) { + requiresLinux(t) + + ctx, cancel := context.WithTimeoutCause(sb.Context(), 30*time.Second, errors.WithStack(context.DeadlineExceeded)) + defer cancel() + + c, err := New(ctx, sb.Address()) + require.NoError(t, err) + defer c.Close() + + product := "buildkit_test" + + b := func(ctx context.Context, c client.Client) (_ *client.Result, retErr error) { + st := llb.Image("busybox:latest") + + def, err := st.Marshal(ctx) + if err != nil { + return nil, errors.Wrap(err, "failed to marshal state") + } + + r, err := c.Solve(ctx, client.SolveRequest{ + Definition: def.ToPB(), + }) + if err != nil { + return nil, errors.Wrap(err, "failed to solve") + } + + ctr, err := c.NewContainer(ctx, client.NewContainerRequest{ + Mounts: []client.Mount{{ + Dest: "/", + MountType: pb.MountType_BIND, + Ref: r.Ref, + }}, + }) + if err != nil { + return nil, err + } + defer func() { + if err := ctr.Release(context.WithoutCancel(ctx)); retErr == nil && err != nil { + retErr = errors.WithStack(err) + } + }() + + pid1StdinR, pid1StdinW := io.Pipe() + pid1, err := ctr.Start(ctx, client.StartRequest{ + Args: []string{"cat"}, + Stdin: pid1StdinR, + }) + if err != nil { + return nil, err + } + defer pid1StdinR.Close() + defer pid1StdinW.Close() + + const iterations = 600 + expected := bytes.NewBuffer(nil) + eg, execCtx := errgroup.WithContext(ctx) + eg.SetLimit(16) + for i := range iterations { + msg := []byte("hello-" + strconv.Itoa(i)) + expected.Write(msg) + expected.WriteByte('\n') + + eg.Go(func() error { + stdout := bytes.NewBuffer(nil) + stderr := bytes.NewBuffer(nil) + pid, err := ctr.Start(execCtx, client.StartRequest{ + Args: []string{"/bin/sh", "-c", "cat > /tmp/msg.$MSG_ID"}, + Env: []string{"MSG_ID=" + strconv.Itoa(i)}, + Stdin: io.NopCloser(bytes.NewReader(msg)), + Stdout: &iohelper.NopWriteCloser{Writer: stdout}, + Stderr: &iohelper.NopWriteCloser{Writer: stderr}, + }) + if err != nil { + return errors.Wrapf(err, "exec %d", i) + } + if err := pid.Wait(); err != nil { + return errors.Wrapf(err, "exec %d", i) + } + if stdout.Len() != 0 { + return errors.Errorf("exec %d stdout: %q", i, stdout.String()) + } + if stderr.Len() != 0 { + return errors.Errorf("exec %d stderr: %q", i, stderr.String()) + } + return nil + }) + } + if err := eg.Wait(); err != nil { + return nil, err + } + + stdout := bytes.NewBuffer(nil) + stderr := bytes.NewBuffer(nil) + pid, err := ctr.Start(ctx, client.StartRequest{ + Args: []string{"/bin/sh", "-c", "i=0; while [ $i -lt $MSG_COUNT ]; do cat /tmp/msg.$i; echo; i=$((i+1)); done"}, + Env: []string{"MSG_COUNT=" + strconv.Itoa(iterations)}, + Stdout: &iohelper.NopWriteCloser{Writer: stdout}, + Stderr: &iohelper.NopWriteCloser{Writer: stderr}, + }) + if err != nil { + return nil, err + } + if err := pid.Wait(); err != nil { + return nil, err + } + if stdout.String() != expected.String() { + return nil, errors.Errorf("unexpected stdout: %q", stdout.String()) + } + if stderr.Len() != 0 { + return nil, errors.Errorf("stderr: %q", stderr.String()) + } + + if err := pid1StdinW.Close(); err != nil { + return nil, errors.WithStack(err) + } + if err := pid1.Wait(); err != nil { + return nil, err + } + + return &client.Result{}, nil + } + + _, err = c.Build(ctx, SolveOpt{}, product, b, nil) + require.NoError(t, err) + checkAllReleasable(t, c, sb, true) +} + // testClientGatewayContainerPID1Fail is testing clean shutdown and release // of resources when the primary pid1 exits with non-zero exit status func testClientGatewayContainerPID1Fail(t *testing.T, sb integration.Sandbox) { diff --git a/executor/containerdexecutor/executor.go b/executor/containerdexecutor/executor.go index d87a31126..3c66eb636 100644 --- a/executor/containerdexecutor/executor.go +++ b/executor/containerdexecutor/executor.go @@ -225,6 +225,7 @@ func (w *containerdExecutor) Run(ctx context.Context, id string, root executor.M } }() + stdinDone := trackStdinEOF(&process) fixProcessOutput(&process) cioOpts := []cio.Opt{cio.WithStreams(process.Stdin, process.Stdout, process.Stderr)} if meta.Tty { @@ -256,7 +257,7 @@ func (w *containerdExecutor) Run(ctx context.Context, id string, root executor.M } trace.SpanFromContext(ctx).AddEvent("Container created") - err = w.runProcess(ctx, task, process.Resize, process.Signal, process.Meta.ValidExitCodes, func() { + err = w.runProcess(ctx, task, process.Resize, process.Signal, stdinDone, process.Meta.ValidExitCodes, func() { startedOnce.Do(func() { trace.SpanFromContext(ctx).AddEvent("Container started") if started != nil { @@ -336,6 +337,7 @@ func (w *containerdExecutor) Exec(ctx context.Context, id string, process execut spec.Process.Env = process.Meta.Env } + stdinDone := trackStdinEOF(&process) fixProcessOutput(&process) cioOpts := []cio.Opt{cio.WithStreams(process.Stdin, process.Stdout, process.Stderr)} if meta.Tty { @@ -347,10 +349,53 @@ func (w *containerdExecutor) Exec(ctx context.Context, id string, process execut return errors.WithStack(err) } - err = w.runProcess(ctx, taskProcess, process.Resize, process.Signal, process.Meta.ValidExitCodes, nil) + err = w.runProcess(ctx, taskProcess, process.Resize, process.Signal, stdinDone, process.Meta.ValidExitCodes, nil) return err } +type stdinEOFTracker struct { + io.ReadCloser + once sync.Once + done chan struct{} + err error +} + +func trackStdinEOF(process *executor.ProcessInfo) <-chan struct{} { + if process.Stdin == nil { + return nil + } + tracker := &stdinEOFTracker{ + ReadCloser: process.Stdin, + done: make(chan struct{}), + } + process.Stdin = tracker + return tracker.done +} + +func (r *stdinEOFTracker) Read(p []byte) (int, error) { + if r.err != nil { + err := r.err + r.err = nil + r.close() + return 0, err + } + n, err := r.ReadCloser.Read(p) + if err != nil { + if n > 0 { + r.err = err + return n, nil + } + r.close() + } + return n, err +} + +func (r *stdinEOFTracker) close() { + r.once.Do(func() { + close(r.done) + }) +} + func fixProcessOutput(process *executor.ProcessInfo) { // It seems like if containerd has one of stdin, stdout or stderr then the // others need to be present as well otherwise we get this error: @@ -364,7 +409,7 @@ func fixProcessOutput(process *executor.ProcessInfo) { } } -func (w *containerdExecutor) runProcess(ctx context.Context, p ctd.Process, resize <-chan executor.WinSize, signal <-chan syscall.Signal, validExitCodes []int, started func()) error { +func (w *containerdExecutor) runProcess(ctx context.Context, p ctd.Process, resize <-chan executor.WinSize, signal <-chan syscall.Signal, stdinDone <-chan struct{}, validExitCodes []int, started func()) error { // Not using `ctx` here because the context passed only affects the statusCh which we // don't want cancelled when ctx.Done is sent. We want to process statusCh on cancel. statusCh, err := p.Wait(context.Background()) @@ -387,12 +432,25 @@ func (w *containerdExecutor) runProcess(ctx context.Context, p ctd.Process, resi started() } - p.CloseIO(ctx, ctd.WithStdinCloser) - // handle signals (and resize) in separate go loop so it does not // potentially block the container cancel/exit status loop below. eventCtx, eventCancel := context.WithCancelCause(ctx) defer eventCancel(errors.WithStack(context.Canceled)) + if stdinDone == nil { + p.CloseIO(ctx, ctd.WithStdinCloser) + } else { + go func() { + select { + case <-eventCtx.Done(): + case <-stdinDone: + if err := p.CloseIO(eventCtx, ctd.WithStdinCloser); err != nil { + if !errors.Is(err, context.Canceled) && !errors.Is(context.Cause(eventCtx), context.Canceled) { + bklog.G(eventCtx).Warnf("Failed to close stdin for %s: %s", p.ID(), err) + } + } + } + }() + } go func() { for { select { diff --git a/worker/tests/common.go b/worker/tests/common.go index bdd61f708..11d12497e 100644 --- a/worker/tests/common.go +++ b/worker/tests/common.go @@ -65,6 +65,7 @@ func NewCtx(s string) context.Context { func TestWorkerExec(t *testing.T, w *base.Worker) { ctx := NewCtx("buildkit-test") ctx, cancel := context.WithCancelCause(ctx) + defer cancel(errors.WithStack(context.Canceled)) sm, err := session.NewManager() require.NoError(t, err) @@ -110,13 +111,16 @@ func TestWorkerExec(t *testing.T, w *base.Worker) { execID := identity.NewID() eg := errgroup.Group{} started = make(chan struct{}) + pid1StdinR, pid1StdinW := io.Pipe() + defer pid1StdinW.Close() eg.Go(func() error { _, err := w.WorkerOpt.Executor.Run(ctx, execID, execMount(root), nil, executor.ProcessInfo{ Meta: executor.Meta{ - Args: []string{"sleep", "10"}, + Args: []string{"cat"}, Cwd: "/", Env: []string{"PATH=/bin:/usr/bin:/sbin:/usr/sbin"}, }, + Stdin: pid1StdinR, }, started) return err }) @@ -130,7 +134,7 @@ func TestWorkerExec(t *testing.T, w *base.Worker) { stdout.Reset() stderr.Reset() - // verify pid1 is the sleep command via Exec + // verify pid1 is the cat command via Exec err = w.WorkerOpt.Executor.Exec(ctx, execID, executor.ProcessInfo{ Meta: executor.Meta{ Args: []string{"ps", "-o", "pid,comm"}, @@ -141,8 +145,8 @@ func TestWorkerExec(t *testing.T, w *base.Worker) { t.Logf("Stdout: %s", stdout.String()) t.Logf("Stderr: %s", stderr.String()) require.NoError(t, err) - // verify pid1 is sleep - require.Contains(t, stdout.String(), "1 sleep") + // verify pid1 is cat + require.Contains(t, stdout.String(), "1 cat") require.Empty(t, stderr.String()) // simulate: echo -n "hello" | cat > /tmp/msg @@ -178,11 +182,24 @@ func TestWorkerExec(t *testing.T, w *base.Worker) { require.Empty(t, stderr.String()) // stop pid1 - cancel(errors.WithStack(context.Canceled)) + require.NoError(t, pid1StdinW.Close()) - err = eg.Wait() - // we expect pid1 to get canceled after we test the exec - require.True(t, errors.Is(err, context.Canceled)) + waitCh := make(chan error, 1) + go func() { + waitCh <- eg.Wait() + }() + select { + case err = <-waitCh: + require.NoError(t, err) + case <-time.After(10 * time.Second): + cancel(errors.WithStack(context.Canceled)) + select { + case err = <-waitCh: + require.Failf(t, "timed out waiting for pid1 to exit", "pid1 returned after cancellation: %+v", err) + case <-time.After(5 * time.Second): + require.FailNow(t, "timed out waiting for pid1 to exit after cancellation") + } + } err = snap.Release(ctx) require.NoError(t, err)