// Copyright 2018 The Go Authors. All rights reserved. // Use of this source code is governed by a BSD-style // license that can be found in the LICENSE file. // +build linux package net import ( "bytes" "fmt" "io" "io/ioutil" "os" "sync" "syscall" "testing" ) func TestSplice(t *testing.T) { t.Run("simple", testSpliceSimple) t.Run("multipleWrite", testSpliceMultipleWrite) t.Run("big", testSpliceBig) t.Run("honorsLimitedReader", testSpliceHonorsLimitedReader) t.Run("readerAtEOF", testSpliceReaderAtEOF) t.Run("issue25985", testSpliceIssue25985) } func testSpliceSimple(t *testing.T) { srv, err := newSpliceTestServer() if err != nil { t.Fatal(err) } defer srv.Close() copyDone := srv.Copy() msg := []byte("splice test") if _, err := srv.Write(msg); err != nil { t.Fatal(err) } got := make([]byte, len(msg)) if _, err := io.ReadFull(srv, got); err != nil { t.Fatal(err) } if !bytes.Equal(got, msg) { t.Errorf("got %q, wrote %q", got, msg) } srv.CloseWrite() srv.CloseRead() if err := <-copyDone; err != nil { t.Errorf("splice: %v", err) } } func testSpliceMultipleWrite(t *testing.T) { srv, err := newSpliceTestServer() if err != nil { t.Fatal(err) } defer srv.Close() copyDone := srv.Copy() msg1 := []byte("splice test part 1 ") msg2 := []byte(" splice test part 2") if _, err := srv.Write(msg1); err != nil { t.Fatalf("Write: %v", err) } if _, err := srv.Write(msg2); err != nil { t.Fatal(err) } got := make([]byte, len(msg1)+len(msg2)) if _, err := io.ReadFull(srv, got); err != nil { t.Fatal(err) } want := append(msg1, msg2...) if !bytes.Equal(got, want) { t.Errorf("got %q, wrote %q", got, want) } srv.CloseWrite() srv.CloseRead() if err := <-copyDone; err != nil { t.Errorf("splice: %v", err) } } func testSpliceBig(t *testing.T) { // The maximum amount of data that internal/poll.Splice will use in a // splice(2) call is 4 << 20. Use a bigger size here so that we test an // amount that doesn't fit in a single call. size := 5 << 20 srv, err := newSpliceTestServer() if err != nil { t.Fatal(err) } defer srv.Close() big := make([]byte, size) copyDone := srv.Copy() type readResult struct { b []byte err error } readDone := make(chan readResult) go func() { got := make([]byte, len(big)) _, err := io.ReadFull(srv, got) readDone <- readResult{got, err} }() if _, err := srv.Write(big); err != nil { t.Fatal(err) } res := <-readDone if res.err != nil { t.Fatal(res.err) } got := res.b if !bytes.Equal(got, big) { t.Errorf("input and output differ") } srv.CloseWrite() srv.CloseRead() if err := <-copyDone; err != nil { t.Errorf("splice: %v", err) } } func testSpliceHonorsLimitedReader(t *testing.T) { t.Run("stopsAfterN", testSpliceStopsAfterN) t.Run("updatesN", testSpliceUpdatesN) t.Run("readerAtLimit", testSpliceReaderAtLimit) } func testSpliceStopsAfterN(t *testing.T) { clientUp, serverUp, err := spliceTestSocketPair("tcp") if err != nil { t.Fatal(err) } defer clientUp.Close() defer serverUp.Close() clientDown, serverDown, err := spliceTestSocketPair("tcp") if err != nil { t.Fatal(err) } defer clientDown.Close() defer serverDown.Close() count := 128 copyDone := make(chan error) lr := &io.LimitedReader{ N: int64(count), R: serverUp, } go func() { _, err := io.Copy(serverDown, lr) serverDown.Close() copyDone <- err }() msg := make([]byte, 2*count) if _, err := clientUp.Write(msg); err != nil { t.Fatal(err) } clientUp.Close() var buf bytes.Buffer if _, err := io.Copy(&buf, clientDown); err != nil { t.Fatal(err) } if buf.Len() != count { t.Errorf("splice transferred %d bytes, want to stop after %d", buf.Len(), count) } clientDown.Close() if err := <-copyDone; err != nil { t.Errorf("splice: %v", err) } } func testSpliceUpdatesN(t *testing.T) { clientUp, serverUp, err := spliceTestSocketPair("tcp") if err != nil { t.Fatal(err) } defer clientUp.Close() defer serverUp.Close() clientDown, serverDown, err := spliceTestSocketPair("tcp") if err != nil { t.Fatal(err) } defer clientDown.Close() defer serverDown.Close() count := 128 copyDone := make(chan error) lr := &io.LimitedReader{ N: int64(100 + count), R: serverUp, } go func() { _, err := io.Copy(serverDown, lr) copyDone <- err }() msg := make([]byte, count) if _, err := clientUp.Write(msg); err != nil { t.Fatal(err) } clientUp.Close() got := make([]byte, count) if _, err := io.ReadFull(clientDown, got); err != nil { t.Fatal(err) } clientDown.Close() if err := <-copyDone; err != nil { t.Errorf("splice: %v", err) } wantN := int64(100) if lr.N != wantN { t.Errorf("lr.N = %d, want %d", lr.N, wantN) } } func testSpliceReaderAtLimit(t *testing.T) { clientUp, serverUp, err := spliceTestSocketPair("tcp") if err != nil { t.Fatal(err) } defer clientUp.Close() defer serverUp.Close() clientDown, serverDown, err := spliceTestSocketPair("tcp") if err != nil { t.Fatal(err) } defer clientDown.Close() defer serverDown.Close() lr := &io.LimitedReader{ N: 0, R: serverUp, } _, err, handled := splice(serverDown.(*TCPConn).fd, lr) if !handled { if serr, ok := err.(*os.SyscallError); ok && serr.Syscall == "pipe2" && serr.Err == syscall.ENOSYS { t.Skip("pipe2 not supported") } t.Errorf("exhausted LimitedReader: got err = %v, handled = %t, want handled = true", err, handled) } } func testSpliceReaderAtEOF(t *testing.T) { clientUp, serverUp, err := spliceTestSocketPair("tcp") if err != nil { t.Fatal(err) } defer clientUp.Close() clientDown, serverDown, err := spliceTestSocketPair("tcp") if err != nil { t.Fatal(err) } defer clientDown.Close() serverUp.Close() // We'd like to call net.splice here and check the handled return // value, but we disable splice on old Linux kernels. // // In that case, poll.Splice and net.splice return a non-nil error // and handled == false. We'd ideally like to see handled == true // because the source reader is at EOF, but if we're running on an old // kernel, and splice is disabled, we won't see EOF from net.splice, // because we won't touch the reader at all. // // Trying to untangle the errors from net.splice and match them // against the errors created by the poll package would be brittle, // so this is a higher level test. // // The following ReadFrom should return immediately, regardless of // whether splice is disabled or not. The other side should then // get a goodbye signal. Test for the goodbye signal. msg := "bye" go func() { serverDown.(*TCPConn).ReadFrom(serverUp) io.WriteString(serverDown, msg) serverDown.Close() }() buf := make([]byte, 3) _, err = io.ReadFull(clientDown, buf) if err != nil { t.Errorf("clientDown: %v", err) } if string(buf) != msg { t.Errorf("clientDown got %q, want %q", buf, msg) } } func testSpliceIssue25985(t *testing.T) { front, err := newLocalListener("tcp") if err != nil { t.Fatal(err) } defer front.Close() back, err := newLocalListener("tcp") if err != nil { t.Fatal(err) } defer back.Close() var wg sync.WaitGroup wg.Add(2) proxy := func() { src, err := front.Accept() if err != nil { return } dst, err := Dial("tcp", back.Addr().String()) if err != nil { return } defer dst.Close() defer src.Close() go func() { io.Copy(src, dst) wg.Done() }() go func() { io.Copy(dst, src) wg.Done() }() } go proxy() toFront, err := Dial("tcp", front.Addr().String()) if err != nil { t.Fatal(err) } io.WriteString(toFront, "foo") toFront.Close() fromProxy, err := back.Accept() if err != nil { t.Fatal(err) } defer fromProxy.Close() _, err = ioutil.ReadAll(fromProxy) if err != nil { t.Fatal(err) } wg.Wait() } func BenchmarkTCPReadFrom(b *testing.B) { testHookUninstaller.Do(uninstallTestHooks) var chunkSizes []int for i := uint(10); i <= 20; i++ { chunkSizes = append(chunkSizes, 1<