libgo: Update to weekly.2012-03-13.

From-SVN: r186023
This commit is contained in:
Ian Lance Taylor 2012-03-30 21:27:11 +00:00
parent e0be8a5c20
commit 456fba2651
137 changed files with 4893 additions and 1673 deletions

View File

@ -1,4 +1,4 @@
f4470a54e6db 3cdba7b0650c
The first line of this file holds the Mercurial revision number of the The first line of this file holds the Mercurial revision number of the
last merge done from the master library sources. last merge done from the master library sources.

View File

@ -813,6 +813,7 @@ go_net_rpc_files = \
go/net/rpc/server.go go/net/rpc/server.go
go_runtime_files = \ go_runtime_files = \
go/runtime/compiler.go \
go/runtime/debug.go \ go/runtime/debug.go \
go/runtime/error.go \ go/runtime/error.go \
go/runtime/extern.go \ go/runtime/extern.go \
@ -843,6 +844,7 @@ go_strconv_files = \
go/strconv/decimal.go \ go/strconv/decimal.go \
go/strconv/extfloat.go \ go/strconv/extfloat.go \
go/strconv/ftoa.go \ go/strconv/ftoa.go \
go/strconv/isprint.go \
go/strconv/itoa.go \ go/strconv/itoa.go \
go/strconv/quote.go go/strconv/quote.go
@ -1000,12 +1002,13 @@ go_crypto_tls_files = \
go/crypto/tls/handshake_server.go \ go/crypto/tls/handshake_server.go \
go/crypto/tls/key_agreement.go \ go/crypto/tls/key_agreement.go \
go/crypto/tls/prf.go \ go/crypto/tls/prf.go \
go/crypto/tls/root_unix.go \
go/crypto/tls/tls.go go/crypto/tls/tls.go
go_crypto_x509_files = \ go_crypto_x509_files = \
go/crypto/x509/cert_pool.go \ go/crypto/x509/cert_pool.go \
go/crypto/x509/pkcs1.go \ go/crypto/x509/pkcs1.go \
go/crypto/x509/pkcs8.go \ go/crypto/x509/pkcs8.go \
go/crypto/x509/root.go \
go/crypto/x509/root_unix.go \
go/crypto/x509/verify.go \ go/crypto/x509/verify.go \
go/crypto/x509/x509.go go/crypto/x509/x509.go
@ -1320,7 +1323,8 @@ go_os_user_files = \
go_path_filepath_files = \ go_path_filepath_files = \
go/path/filepath/match.go \ go/path/filepath/match.go \
go/path/filepath/path.go \ go/path/filepath/path.go \
go/path/filepath/path_unix.go go/path/filepath/path_unix.go \
go/path/filepath/symlink.go
go_regexp_syntax_files = \ go_regexp_syntax_files = \
go/regexp/syntax/compile.go \ go/regexp/syntax/compile.go \

View File

@ -1131,6 +1131,7 @@ go_net_rpc_files = \
go/net/rpc/server.go go/net/rpc/server.go
go_runtime_files = \ go_runtime_files = \
go/runtime/compiler.go \
go/runtime/debug.go \ go/runtime/debug.go \
go/runtime/error.go \ go/runtime/error.go \
go/runtime/extern.go \ go/runtime/extern.go \
@ -1150,6 +1151,7 @@ go_strconv_files = \
go/strconv/decimal.go \ go/strconv/decimal.go \
go/strconv/extfloat.go \ go/strconv/extfloat.go \
go/strconv/ftoa.go \ go/strconv/ftoa.go \
go/strconv/isprint.go \
go/strconv/itoa.go \ go/strconv/itoa.go \
go/strconv/quote.go go/strconv/quote.go
@ -1315,13 +1317,14 @@ go_crypto_tls_files = \
go/crypto/tls/handshake_server.go \ go/crypto/tls/handshake_server.go \
go/crypto/tls/key_agreement.go \ go/crypto/tls/key_agreement.go \
go/crypto/tls/prf.go \ go/crypto/tls/prf.go \
go/crypto/tls/root_unix.go \
go/crypto/tls/tls.go go/crypto/tls/tls.go
go_crypto_x509_files = \ go_crypto_x509_files = \
go/crypto/x509/cert_pool.go \ go/crypto/x509/cert_pool.go \
go/crypto/x509/pkcs1.go \ go/crypto/x509/pkcs1.go \
go/crypto/x509/pkcs8.go \ go/crypto/x509/pkcs8.go \
go/crypto/x509/root.go \
go/crypto/x509/root_unix.go \
go/crypto/x509/verify.go \ go/crypto/x509/verify.go \
go/crypto/x509/x509.go go/crypto/x509/x509.go
@ -1677,7 +1680,8 @@ go_os_user_files = \
go_path_filepath_files = \ go_path_filepath_files = \
go/path/filepath/match.go \ go/path/filepath/match.go \
go/path/filepath/path.go \ go/path/filepath/path.go \
go/path/filepath/path_unix.go go/path/filepath/path_unix.go \
go/path/filepath/symlink.go
go_regexp_syntax_files = \ go_regexp_syntax_files = \
go/regexp/syntax/compile.go \ go/regexp/syntax/compile.go \

View File

@ -18,7 +18,7 @@ import (
) )
var ( var (
ErrHeader = errors.New("invalid tar header") ErrHeader = errors.New("archive/tar: invalid tar header")
) )
// A Reader provides sequential access to the contents of a tar archive. // A Reader provides sequential access to the contents of a tar archive.

View File

@ -5,18 +5,19 @@
package tar package tar
// TODO(dsymonds): // TODO(dsymonds):
// - catch more errors (no first header, write after close, etc.) // - catch more errors (no first header, etc.)
import ( import (
"errors" "errors"
"fmt"
"io" "io"
"strconv" "strconv"
) )
var ( var (
ErrWriteTooLong = errors.New("write too long") ErrWriteTooLong = errors.New("archive/tar: write too long")
ErrFieldTooLong = errors.New("header field too long") ErrFieldTooLong = errors.New("archive/tar: header field too long")
ErrWriteAfterClose = errors.New("write after close") ErrWriteAfterClose = errors.New("archive/tar: write after close")
) )
// A Writer provides sequential writing of a tar archive in POSIX.1 format. // A Writer provides sequential writing of a tar archive in POSIX.1 format.
@ -48,6 +49,11 @@ func NewWriter(w io.Writer) *Writer { return &Writer{w: w} }
// Flush finishes writing the current file (optional). // Flush finishes writing the current file (optional).
func (tw *Writer) Flush() error { func (tw *Writer) Flush() error {
if tw.nb > 0 {
tw.err = fmt.Errorf("archive/tar: missed writing %d bytes", tw.nb)
return tw.err
}
n := tw.nb + tw.pad n := tw.nb + tw.pad
for n > 0 && tw.err == nil { for n > 0 && tw.err == nil {
nr := n nr := n
@ -193,6 +199,9 @@ func (tw *Writer) Close() error {
} }
tw.Flush() tw.Flush()
tw.closed = true tw.closed = true
if tw.err != nil {
return tw.err
}
// trailer: two zero blocks // trailer: two zero blocks
for i := 0; i < 2; i++ { for i := 0; i < 2; i++ {

View File

@ -9,6 +9,7 @@ import (
"fmt" "fmt"
"io" "io"
"io/ioutil" "io/ioutil"
"strings"
"testing" "testing"
"testing/iotest" "testing/iotest"
"time" "time"
@ -95,7 +96,8 @@ var writerTests = []*writerTest{
Uname: "dsymonds", Uname: "dsymonds",
Gname: "eng", Gname: "eng",
}, },
// no contents // fake contents
contents: strings.Repeat("\x00", 4<<10),
}, },
}, },
}, },
@ -150,7 +152,9 @@ testLoop:
buf := new(bytes.Buffer) buf := new(bytes.Buffer)
tw := NewWriter(iotest.TruncateWriter(buf, 4<<10)) // only catch the first 4 KB tw := NewWriter(iotest.TruncateWriter(buf, 4<<10)) // only catch the first 4 KB
big := false
for j, entry := range test.entries { for j, entry := range test.entries {
big = big || entry.header.Size > 1<<10
if err := tw.WriteHeader(entry.header); err != nil { if err := tw.WriteHeader(entry.header); err != nil {
t.Errorf("test %d, entry %d: Failed writing header: %v", i, j, err) t.Errorf("test %d, entry %d: Failed writing header: %v", i, j, err)
continue testLoop continue testLoop
@ -160,7 +164,8 @@ testLoop:
continue testLoop continue testLoop
} }
} }
if err := tw.Close(); err != nil { // Only interested in Close failures for the small tests.
if err := tw.Close(); err != nil && !big {
t.Errorf("test %d: Failed closing archive: %v", i, err) t.Errorf("test %d: Failed closing archive: %v", i, err)
continue testLoop continue testLoop
} }

View File

@ -124,10 +124,6 @@ func (f *File) Open() (rc io.ReadCloser, err error) {
return return
} }
size := int64(f.CompressedSize) size := int64(f.CompressedSize)
if size == 0 && f.hasDataDescriptor() {
// permit SectionReader to see the rest of the file
size = f.zipsize - (f.headerOffset + bodyOffset)
}
r := io.NewSectionReader(f.zipr, f.headerOffset+bodyOffset, size) r := io.NewSectionReader(f.zipr, f.headerOffset+bodyOffset, size)
switch f.Method { switch f.Method {
case Store: // (no compression) case Store: // (no compression)
@ -136,10 +132,13 @@ func (f *File) Open() (rc io.ReadCloser, err error) {
rc = flate.NewReader(r) rc = flate.NewReader(r)
default: default:
err = ErrAlgorithm err = ErrAlgorithm
return
} }
if rc != nil { var desr io.Reader
rc = &checksumReader{rc, crc32.NewIEEE(), f, r} if f.hasDataDescriptor() {
desr = io.NewSectionReader(f.zipr, f.headerOffset+bodyOffset+size, dataDescriptorLen)
} }
rc = &checksumReader{rc, crc32.NewIEEE(), f, desr, nil}
return return
} }
@ -147,23 +146,36 @@ type checksumReader struct {
rc io.ReadCloser rc io.ReadCloser
hash hash.Hash32 hash hash.Hash32
f *File f *File
zipr io.Reader // for reading the data descriptor desr io.Reader // if non-nil, where to read the data descriptor
err error // sticky error
} }
func (r *checksumReader) Read(b []byte) (n int, err error) { func (r *checksumReader) Read(b []byte) (n int, err error) {
if r.err != nil {
return 0, r.err
}
n, err = r.rc.Read(b) n, err = r.rc.Read(b)
r.hash.Write(b[:n]) r.hash.Write(b[:n])
if err != io.EOF { if err == nil {
return return
} }
if r.f.hasDataDescriptor() { if err == io.EOF {
if err = readDataDescriptor(r.zipr, r.f); err != nil { if r.desr != nil {
return if err1 := readDataDescriptor(r.desr, r.f); err1 != nil {
err = err1
} else if r.hash.Sum32() != r.f.CRC32 {
err = ErrChecksum
}
} else {
// If there's not a data descriptor, we still compare
// the CRC32 of what we've read against the file header
// or TOC's CRC32, if it seems like it was set.
if r.f.CRC32 != 0 && r.hash.Sum32() != r.f.CRC32 {
err = ErrChecksum
}
} }
} }
if r.hash.Sum32() != r.f.CRC32 { r.err = err
err = ErrChecksum
}
return return
} }
@ -226,10 +238,31 @@ func readDirectoryHeader(f *File, r io.Reader) error {
func readDataDescriptor(r io.Reader, f *File) error { func readDataDescriptor(r io.Reader, f *File) error {
var buf [dataDescriptorLen]byte var buf [dataDescriptorLen]byte
if _, err := io.ReadFull(r, buf[:]); err != nil {
// The spec says: "Although not originally assigned a
// signature, the value 0x08074b50 has commonly been adopted
// as a signature value for the data descriptor record.
// Implementers should be aware that ZIP files may be
// encountered with or without this signature marking data
// descriptors and should account for either case when reading
// ZIP files to ensure compatibility."
//
// dataDescriptorLen includes the size of the signature but
// first read just those 4 bytes to see if it exists.
if _, err := io.ReadFull(r, buf[:4]); err != nil {
return err return err
} }
b := readBuf(buf[:]) off := 0
maybeSig := readBuf(buf[:4])
if maybeSig.uint32() != dataDescriptorSignature {
// No data descriptor signature. Keep these four
// bytes.
off += 4
}
if _, err := io.ReadFull(r, buf[off:12]); err != nil {
return err
}
b := readBuf(buf[:12])
f.CRC32 = b.uint32() f.CRC32 = b.uint32()
f.CompressedSize = b.uint32() f.CompressedSize = b.uint32()
f.UncompressedSize = b.uint32() f.UncompressedSize = b.uint32()

View File

@ -10,23 +10,26 @@ import (
"io" "io"
"io/ioutil" "io/ioutil"
"os" "os"
"path/filepath"
"testing" "testing"
"time" "time"
) )
type ZipTest struct { type ZipTest struct {
Name string Name string
Source func() (r io.ReaderAt, size int64) // if non-nil, used instead of testdata/<Name> file
Comment string Comment string
File []ZipTestFile File []ZipTestFile
Error error // the error that Opening this file should return Error error // the error that Opening this file should return
} }
type ZipTestFile struct { type ZipTestFile struct {
Name string Name string
Content []byte // if blank, will attempt to compare against File Content []byte // if blank, will attempt to compare against File
File string // name of file to compare to (relative to testdata/) ContentErr error
Mtime string // modified time in format "mm-dd-yy hh:mm:ss" File string // name of file to compare to (relative to testdata/)
Mode os.FileMode Mtime string // modified time in format "mm-dd-yy hh:mm:ss"
Mode os.FileMode
} }
// Caution: The Mtime values found for the test files should correspond to // Caution: The Mtime values found for the test files should correspond to
@ -107,6 +110,99 @@ var tests = []ZipTest{
Name: "unix.zip", Name: "unix.zip",
File: crossPlatform, File: crossPlatform,
}, },
{
// created by Go, before we wrote the "optional" data
// descriptor signatures (which are required by OS X)
Name: "go-no-datadesc-sig.zip",
File: []ZipTestFile{
{
Name: "foo.txt",
Content: []byte("foo\n"),
Mtime: "03-08-12 16:59:10",
Mode: 0644,
},
{
Name: "bar.txt",
Content: []byte("bar\n"),
Mtime: "03-08-12 16:59:12",
Mode: 0644,
},
},
},
{
// created by Go, after we wrote the "optional" data
// descriptor signatures (which are required by OS X)
Name: "go-with-datadesc-sig.zip",
File: []ZipTestFile{
{
Name: "foo.txt",
Content: []byte("foo\n"),
Mode: 0666,
},
{
Name: "bar.txt",
Content: []byte("bar\n"),
Mode: 0666,
},
},
},
{
Name: "Bad-CRC32-in-data-descriptor",
Source: returnCorruptCRC32Zip,
File: []ZipTestFile{
{
Name: "foo.txt",
Content: []byte("foo\n"),
Mode: 0666,
ContentErr: ErrChecksum,
},
{
Name: "bar.txt",
Content: []byte("bar\n"),
Mode: 0666,
},
},
},
// Tests that we verify (and accept valid) crc32s on files
// with crc32s in their file header (not in data descriptors)
{
Name: "crc32-not-streamed.zip",
File: []ZipTestFile{
{
Name: "foo.txt",
Content: []byte("foo\n"),
Mtime: "03-08-12 16:59:10",
Mode: 0644,
},
{
Name: "bar.txt",
Content: []byte("bar\n"),
Mtime: "03-08-12 16:59:12",
Mode: 0644,
},
},
},
// Tests that we verify (and reject invalid) crc32s on files
// with crc32s in their file header (not in data descriptors)
{
Name: "crc32-not-streamed.zip",
Source: returnCorruptNotStreamedZip,
File: []ZipTestFile{
{
Name: "foo.txt",
Content: []byte("foo\n"),
Mtime: "03-08-12 16:59:10",
Mode: 0644,
ContentErr: ErrChecksum,
},
{
Name: "bar.txt",
Content: []byte("bar\n"),
Mtime: "03-08-12 16:59:12",
Mode: 0644,
},
},
},
} }
var crossPlatform = []ZipTestFile{ var crossPlatform = []ZipTestFile{
@ -139,7 +235,18 @@ func TestReader(t *testing.T) {
} }
func readTestZip(t *testing.T, zt ZipTest) { func readTestZip(t *testing.T, zt ZipTest) {
z, err := OpenReader("testdata/" + zt.Name) var z *Reader
var err error
if zt.Source != nil {
rat, size := zt.Source()
z, err = NewReader(rat, size)
} else {
var rc *ReadCloser
rc, err = OpenReader(filepath.Join("testdata", zt.Name))
if err == nil {
z = &rc.Reader
}
}
if err != zt.Error { if err != zt.Error {
t.Errorf("error=%v, want %v", err, zt.Error) t.Errorf("error=%v, want %v", err, zt.Error)
return return
@ -149,11 +256,6 @@ func readTestZip(t *testing.T, zt ZipTest) {
if err == ErrFormat { if err == ErrFormat {
return return
} }
defer func() {
if err := z.Close(); err != nil {
t.Errorf("error %q when closing zip file", err)
}
}()
// bail here if no Files expected to be tested // bail here if no Files expected to be tested
// (there may actually be files in the zip, but we don't care) // (there may actually be files in the zip, but we don't care)
@ -170,7 +272,7 @@ func readTestZip(t *testing.T, zt ZipTest) {
// test read of each file // test read of each file
for i, ft := range zt.File { for i, ft := range zt.File {
readTestFile(t, ft, z.File[i]) readTestFile(t, zt, ft, z.File[i])
} }
// test simultaneous reads // test simultaneous reads
@ -179,7 +281,7 @@ func readTestZip(t *testing.T, zt ZipTest) {
for i := 0; i < 5; i++ { for i := 0; i < 5; i++ {
for j, ft := range zt.File { for j, ft := range zt.File {
go func(j int, ft ZipTestFile) { go func(j int, ft ZipTestFile) {
readTestFile(t, ft, z.File[j]) readTestFile(t, zt, ft, z.File[j])
done <- true done <- true
}(j, ft) }(j, ft)
n++ n++
@ -188,26 +290,11 @@ func readTestZip(t *testing.T, zt ZipTest) {
for ; n > 0; n-- { for ; n > 0; n-- {
<-done <-done
} }
// test invalid checksum
if !z.File[0].hasDataDescriptor() { // skip test when crc32 in dd
z.File[0].CRC32++ // invalidate
r, err := z.File[0].Open()
if err != nil {
t.Error(err)
return
}
var b bytes.Buffer
_, err = io.Copy(&b, r)
if err != ErrChecksum {
t.Errorf("%s: copy error=%v, want %v", z.File[0].Name, err, ErrChecksum)
}
}
} }
func readTestFile(t *testing.T, ft ZipTestFile, f *File) { func readTestFile(t *testing.T, zt ZipTest, ft ZipTestFile, f *File) {
if f.Name != ft.Name { if f.Name != ft.Name {
t.Errorf("name=%q, want %q", f.Name, ft.Name) t.Errorf("%s: name=%q, want %q", zt.Name, f.Name, ft.Name)
} }
if ft.Mtime != "" { if ft.Mtime != "" {
@ -217,11 +304,11 @@ func readTestFile(t *testing.T, ft ZipTestFile, f *File) {
return return
} }
if ft := f.ModTime(); !ft.Equal(mtime) { if ft := f.ModTime(); !ft.Equal(mtime) {
t.Errorf("%s: mtime=%s, want %s", f.Name, ft, mtime) t.Errorf("%s: %s: mtime=%s, want %s", zt.Name, f.Name, ft, mtime)
} }
} }
testFileMode(t, f, ft.Mode) testFileMode(t, zt.Name, f, ft.Mode)
size0 := f.UncompressedSize size0 := f.UncompressedSize
@ -237,8 +324,10 @@ func readTestFile(t *testing.T, ft ZipTestFile, f *File) {
} }
_, err = io.Copy(&b, r) _, err = io.Copy(&b, r)
if err != ft.ContentErr {
t.Errorf("%s: copying contents: %v (want %v)", zt.Name, err, ft.ContentErr)
}
if err != nil { if err != nil {
t.Error(err)
return return
} }
r.Close() r.Close()
@ -264,12 +353,12 @@ func readTestFile(t *testing.T, ft ZipTestFile, f *File) {
} }
} }
func testFileMode(t *testing.T, f *File, want os.FileMode) { func testFileMode(t *testing.T, zipName string, f *File, want os.FileMode) {
mode := f.Mode() mode := f.Mode()
if want == 0 { if want == 0 {
t.Errorf("%s mode: got %v, want none", f.Name, mode) t.Errorf("%s: %s mode: got %v, want none", zipName, f.Name, mode)
} else if mode != want { } else if mode != want {
t.Errorf("%s mode: want %v, got %v", f.Name, want, mode) t.Errorf("%s: %s mode: want %v, got %v", zipName, f.Name, want, mode)
} }
} }
@ -294,3 +383,35 @@ func TestInvalidFiles(t *testing.T) {
t.Errorf("sigs: error=%v, want %v", err, ErrFormat) t.Errorf("sigs: error=%v, want %v", err, ErrFormat)
} }
} }
func messWith(fileName string, corrupter func(b []byte)) (r io.ReaderAt, size int64) {
data, err := ioutil.ReadFile(filepath.Join("testdata", fileName))
if err != nil {
panic("Error reading " + fileName + ": " + err.Error())
}
corrupter(data)
return bytes.NewReader(data), int64(len(data))
}
func returnCorruptCRC32Zip() (r io.ReaderAt, size int64) {
return messWith("go-with-datadesc-sig.zip", func(b []byte) {
// Corrupt one of the CRC32s in the data descriptor:
b[0x2d]++
})
}
func returnCorruptNotStreamedZip() (r io.ReaderAt, size int64) {
return messWith("crc32-not-streamed.zip", func(b []byte) {
// Corrupt foo.txt's final crc32 byte, in both
// the file header and TOC. (0x7e -> 0x7f)
b[0x11]++
b[0x9d]++
// TODO(bradfitz): add a new test that only corrupts
// one of these values, and verify that that's also an
// error. Currently, the reader code doesn't verify the
// fileheader and TOC's crc32 match if they're both
// non-zero and only the second line above, the TOC,
// is what matters.
})
}

View File

@ -27,10 +27,11 @@ const (
fileHeaderSignature = 0x04034b50 fileHeaderSignature = 0x04034b50
directoryHeaderSignature = 0x02014b50 directoryHeaderSignature = 0x02014b50
directoryEndSignature = 0x06054b50 directoryEndSignature = 0x06054b50
fileHeaderLen = 30 // + filename + extra dataDescriptorSignature = 0x08074b50 // de-facto standard; required by OS X Finder
directoryHeaderLen = 46 // + filename + extra + comment fileHeaderLen = 30 // + filename + extra
directoryEndLen = 22 // + comment directoryHeaderLen = 46 // + filename + extra + comment
dataDescriptorLen = 12 directoryEndLen = 22 // + comment
dataDescriptorLen = 16 // four uint32: descriptor signature, crc32, compressed size, size
// Constants for the first byte in CreatorVersion // Constants for the first byte in CreatorVersion
creatorFAT = 0 creatorFAT = 0

Binary file not shown.

Binary file not shown.

Binary file not shown.

View File

@ -224,6 +224,7 @@ func (w *fileWriter) close() error {
// write data descriptor // write data descriptor
var buf [dataDescriptorLen]byte var buf [dataDescriptorLen]byte
b := writeBuf(buf[:]) b := writeBuf(buf[:])
b.uint32(dataDescriptorSignature) // de-facto standard, required by OS X
b.uint32(fh.CRC32) b.uint32(fh.CRC32)
b.uint32(fh.CompressedSize) b.uint32(fh.CompressedSize)
b.uint32(fh.UncompressedSize) b.uint32(fh.UncompressedSize)

View File

@ -108,7 +108,7 @@ func testReadFile(t *testing.T, f *File, wt *WriteTest) {
if f.Name != wt.Name { if f.Name != wt.Name {
t.Fatalf("File name: got %q, want %q", f.Name, wt.Name) t.Fatalf("File name: got %q, want %q", f.Name, wt.Name)
} }
testFileMode(t, f, wt.Mode) testFileMode(t, wt.Name, f, wt.Mode)
rc, err := f.Open() rc, err := f.Open()
if err != nil { if err != nil {
t.Fatal("opening:", err) t.Fatal("opening:", err)

View File

@ -198,14 +198,6 @@ func (c *Config) time() time.Time {
return t() return t()
} }
func (c *Config) rootCAs() *x509.CertPool {
s := c.RootCAs
if s == nil {
s = defaultRoots()
}
return s
}
func (c *Config) cipherSuites() []uint16 { func (c *Config) cipherSuites() []uint16 {
s := c.CipherSuites s := c.CipherSuites
if s == nil { if s == nil {
@ -311,28 +303,16 @@ func defaultConfig() *Config {
return &emptyConfig return &emptyConfig
} }
var once sync.Once
func defaultRoots() *x509.CertPool {
once.Do(initDefaults)
return varDefaultRoots
}
func defaultCipherSuites() []uint16 {
once.Do(initDefaults)
return varDefaultCipherSuites
}
func initDefaults() {
initDefaultRoots()
initDefaultCipherSuites()
}
var ( var (
varDefaultRoots *x509.CertPool once sync.Once
varDefaultCipherSuites []uint16 varDefaultCipherSuites []uint16
) )
func defaultCipherSuites() []uint16 {
once.Do(initDefaultCipherSuites)
return varDefaultCipherSuites
}
func initDefaultCipherSuites() { func initDefaultCipherSuites() {
varDefaultCipherSuites = make([]uint16, len(cipherSuites)) varDefaultCipherSuites = make([]uint16, len(cipherSuites))
for i, suite := range cipherSuites { for i, suite := range cipherSuites {

View File

@ -102,7 +102,7 @@ func (c *Conn) clientHandshake() error {
if !c.config.InsecureSkipVerify { if !c.config.InsecureSkipVerify {
opts := x509.VerifyOptions{ opts := x509.VerifyOptions{
Roots: c.config.rootCAs(), Roots: c.config.RootCAs,
CurrentTime: c.config.time(), CurrentTime: c.config.time(),
DNSName: c.config.ServerName, DNSName: c.config.ServerName,
Intermediates: x509.NewCertPool(), Intermediates: x509.NewCertPool(),

View File

@ -143,7 +143,7 @@ func testServerScript(t *testing.T, name string, serverScript [][]byte, config *
if peers != nil { if peers != nil {
gotpeers := <-pchan gotpeers := <-pchan
if len(peers) == len(gotpeers) { if len(peers) == len(gotpeers) {
for i, _ := range peers { for i := range peers {
if !peers[i].Equal(gotpeers[i]) { if !peers[i].Equal(gotpeers[i]) {
t.Fatalf("%s: mismatch on peer cert %d", name, i) t.Fatalf("%s: mismatch on peer cert %d", name, i)
} }

View File

@ -5,25 +5,25 @@
package tls package tls
import ( import (
"crypto/x509"
"runtime"
"testing" "testing"
) )
var tlsServers = []string{ var tlsServers = []string{
"google.com:443", "google.com",
"github.com:443", "github.com",
"twitter.com:443", "twitter.com",
} }
func TestOSCertBundles(t *testing.T) { func TestOSCertBundles(t *testing.T) {
defaultRoots()
if testing.Short() { if testing.Short() {
t.Logf("skipping certificate tests in short mode") t.Logf("skipping certificate tests in short mode")
return return
} }
for _, addr := range tlsServers { for _, addr := range tlsServers {
conn, err := Dial("tcp", addr, nil) conn, err := Dial("tcp", addr+":443", &Config{ServerName: addr})
if err != nil { if err != nil {
t.Errorf("unable to verify %v: %v", addr, err) t.Errorf("unable to verify %v: %v", addr, err)
continue continue
@ -34,3 +34,28 @@ func TestOSCertBundles(t *testing.T) {
} }
} }
} }
func TestCertHostnameVerifyWindows(t *testing.T) {
if runtime.GOOS != "windows" {
return
}
if testing.Short() {
t.Logf("skipping certificate tests in short mode")
return
}
for _, addr := range tlsServers {
cfg := &Config{ServerName: "example.com"}
conn, err := Dial("tcp", addr+":443", cfg)
if err == nil {
conn.Close()
t.Errorf("should fail to verify for example.com: %v", addr)
continue
}
_, ok := err.(x509.HostnameError)
if !ok {
t.Errorf("error type mismatch, got: %v", err)
}
}
}

View File

@ -1,47 +0,0 @@
// Copyright 2011 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package tls
import (
"crypto/x509"
"syscall"
"unsafe"
)
func loadStore(roots *x509.CertPool, name string) {
store, err := syscall.CertOpenSystemStore(syscall.InvalidHandle, syscall.StringToUTF16Ptr(name))
if err != nil {
return
}
defer syscall.CertCloseStore(store, 0)
var cert *syscall.CertContext
for {
cert, err = syscall.CertEnumCertificatesInStore(store, cert)
if err != nil {
return
}
buf := (*[1 << 20]byte)(unsafe.Pointer(cert.EncodedCert))[:]
// ParseCertificate requires its own copy of certificate data to keep.
buf2 := make([]byte, cert.Length)
copy(buf2, buf)
if c, err := x509.ParseCertificate(buf2); err == nil {
roots.AddCert(c)
}
}
}
func initDefaultRoots() {
roots := x509.NewCertPool()
// Roots
loadStore(roots, "ROOT")
// Intermediates
loadStore(roots, "CA")
varDefaultRoots = roots
}

View File

@ -2,8 +2,7 @@
// Use of this source code is governed by a BSD-style // Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file. // license that can be found in the LICENSE file.
// Package tls partially implements the TLS 1.1 protocol, as specified in RFC // Package tls partially implements TLS 1.0, as specified in RFC 2246.
// 4346.
package tls package tls
import ( import (
@ -98,7 +97,9 @@ func Dial(network, addr string, config *Config) (*Conn, error) {
if config == nil { if config == nil {
config = defaultConfig() config = defaultConfig()
} }
if config.ServerName != "" { // If no ServerName is set, infer the ServerName
// from the hostname we're connecting to.
if config.ServerName == "" {
// Make a copy to avoid polluting argument or default. // Make a copy to avoid polluting argument or default.
c := *config c := *config
c.ServerName = hostname c.ServerName = hostname

View File

@ -24,7 +24,7 @@ type pkcs1PrivateKey struct {
Dq *big.Int `asn1:"optional"` Dq *big.Int `asn1:"optional"`
Qinv *big.Int `asn1:"optional"` Qinv *big.Int `asn1:"optional"`
AdditionalPrimes []pkcs1AdditionalRSAPrime `asn1:"optional"` AdditionalPrimes []pkcs1AdditionalRSAPrime `asn1:"optional,omitempty"`
} }
type pkcs1AdditionalRSAPrime struct { type pkcs1AdditionalRSAPrime struct {

View File

@ -0,0 +1,17 @@
// Copyright 2012 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package x509
import "sync"
var (
once sync.Once
systemRoots *CertPool
)
func systemRootsPool() *CertPool {
once.Do(initSystemRoots)
return systemRoots
}

View File

@ -2,7 +2,7 @@
// Use of this source code is governed by a BSD-style // Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file. // license that can be found in the LICENSE file.
package tls package x509
/* /*
#cgo CFLAGS: -mmacosx-version-min=10.6 -D__MAC_OS_X_VERSION_MAX_ALLOWED=1060 #cgo CFLAGS: -mmacosx-version-min=10.6 -D__MAC_OS_X_VERSION_MAX_ALLOWED=1060
@ -59,13 +59,14 @@ int FetchPEMRoots(CFDataRef *pemRoots) {
} }
*/ */
import "C" import "C"
import ( import "unsafe"
"crypto/x509"
"unsafe"
)
func initDefaultRoots() { func (c *Certificate) systemVerify(opts *VerifyOptions) (chains [][]*Certificate, err error) {
roots := x509.NewCertPool() return nil, nil
}
func initSystemRoots() {
roots := NewCertPool()
var data C.CFDataRef = nil var data C.CFDataRef = nil
err := C.FetchPEMRoots(&data) err := C.FetchPEMRoots(&data)
@ -75,5 +76,5 @@ func initDefaultRoots() {
roots.AppendCertsFromPEM(buf) roots.AppendCertsFromPEM(buf)
} }
varDefaultRoots = roots systemRoots = roots
} }

View File

@ -4,7 +4,12 @@
// +build plan9 darwin,!cgo // +build plan9 darwin,!cgo
package tls package x509
func initDefaultRoots() { func (c *Certificate) systemVerify(opts *VerifyOptions) (chains [][]*Certificate, err error) {
return nil, nil
}
func initSystemRoots() {
systemRoots = NewCertPool()
} }

View File

@ -4,12 +4,9 @@
// +build freebsd linux openbsd netbsd // +build freebsd linux openbsd netbsd
package tls package x509
import ( import "io/ioutil"
"crypto/x509"
"io/ioutil"
)
// Possible certificate files; stop after finding one. // Possible certificate files; stop after finding one.
var certFiles = []string{ var certFiles = []string{
@ -20,8 +17,12 @@ var certFiles = []string{
"/usr/local/share/certs/ca-root-nss.crt", // FreeBSD "/usr/local/share/certs/ca-root-nss.crt", // FreeBSD
} }
func initDefaultRoots() { func (c *Certificate) systemVerify(opts *VerifyOptions) (chains [][]*Certificate, err error) {
roots := x509.NewCertPool() return nil, nil
}
func initSystemRoots() {
roots := NewCertPool()
for _, file := range certFiles { for _, file := range certFiles {
data, err := ioutil.ReadFile(file) data, err := ioutil.ReadFile(file)
if err == nil { if err == nil {
@ -29,5 +30,6 @@ func initDefaultRoots() {
break break
} }
} }
varDefaultRoots = roots
systemRoots = roots
} }

View File

@ -0,0 +1,226 @@
// Copyright 2012 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package x509
import (
"errors"
"syscall"
"unsafe"
)
// Creates a new *syscall.CertContext representing the leaf certificate in an in-memory
// certificate store containing itself and all of the intermediate certificates specified
// in the opts.Intermediates CertPool.
//
// A pointer to the in-memory store is available in the returned CertContext's Store field.
// The store is automatically freed when the CertContext is freed using
// syscall.CertFreeCertificateContext.
func createStoreContext(leaf *Certificate, opts *VerifyOptions) (*syscall.CertContext, error) {
var storeCtx *syscall.CertContext
leafCtx, err := syscall.CertCreateCertificateContext(syscall.X509_ASN_ENCODING|syscall.PKCS_7_ASN_ENCODING, &leaf.Raw[0], uint32(len(leaf.Raw)))
if err != nil {
return nil, err
}
defer syscall.CertFreeCertificateContext(leafCtx)
handle, err := syscall.CertOpenStore(syscall.CERT_STORE_PROV_MEMORY, 0, 0, syscall.CERT_STORE_DEFER_CLOSE_UNTIL_LAST_FREE_FLAG, 0)
if err != nil {
return nil, err
}
defer syscall.CertCloseStore(handle, 0)
err = syscall.CertAddCertificateContextToStore(handle, leafCtx, syscall.CERT_STORE_ADD_ALWAYS, &storeCtx)
if err != nil {
return nil, err
}
if opts.Intermediates != nil {
for _, intermediate := range opts.Intermediates.certs {
ctx, err := syscall.CertCreateCertificateContext(syscall.X509_ASN_ENCODING|syscall.PKCS_7_ASN_ENCODING, &intermediate.Raw[0], uint32(len(intermediate.Raw)))
if err != nil {
return nil, err
}
err = syscall.CertAddCertificateContextToStore(handle, ctx, syscall.CERT_STORE_ADD_ALWAYS, nil)
syscall.CertFreeCertificateContext(ctx)
if err != nil {
return nil, err
}
}
}
return storeCtx, nil
}
// extractSimpleChain extracts the final certificate chain from a CertSimpleChain.
func extractSimpleChain(simpleChain **syscall.CertSimpleChain, count int) (chain []*Certificate, err error) {
if simpleChain == nil || count == 0 {
return nil, errors.New("x509: invalid simple chain")
}
simpleChains := (*[1 << 20]*syscall.CertSimpleChain)(unsafe.Pointer(simpleChain))[:]
lastChain := simpleChains[count-1]
elements := (*[1 << 20]*syscall.CertChainElement)(unsafe.Pointer(lastChain.Elements))[:]
for i := 0; i < int(lastChain.NumElements); i++ {
// Copy the buf, since ParseCertificate does not create its own copy.
cert := elements[i].CertContext
encodedCert := (*[1 << 20]byte)(unsafe.Pointer(cert.EncodedCert))[:]
buf := make([]byte, cert.Length)
copy(buf, encodedCert[:])
parsedCert, err := ParseCertificate(buf)
if err != nil {
return nil, err
}
chain = append(chain, parsedCert)
}
return chain, nil
}
// checkChainTrustStatus checks the trust status of the certificate chain, translating
// any errors it finds into Go errors in the process.
func checkChainTrustStatus(c *Certificate, chainCtx *syscall.CertChainContext) error {
if chainCtx.TrustStatus.ErrorStatus != syscall.CERT_TRUST_NO_ERROR {
status := chainCtx.TrustStatus.ErrorStatus
switch status {
case syscall.CERT_TRUST_IS_NOT_TIME_VALID:
return CertificateInvalidError{c, Expired}
default:
return UnknownAuthorityError{c}
}
}
return nil
}
// checkChainSSLServerPolicy checks that the certificate chain in chainCtx is valid for
// use as a certificate chain for a SSL/TLS server.
func checkChainSSLServerPolicy(c *Certificate, chainCtx *syscall.CertChainContext, opts *VerifyOptions) error {
sslPara := &syscall.SSLExtraCertChainPolicyPara{
AuthType: syscall.AUTHTYPE_SERVER,
ServerName: syscall.StringToUTF16Ptr(opts.DNSName),
}
sslPara.Size = uint32(unsafe.Sizeof(*sslPara))
para := &syscall.CertChainPolicyPara{
ExtraPolicyPara: uintptr(unsafe.Pointer(sslPara)),
}
para.Size = uint32(unsafe.Sizeof(*para))
status := syscall.CertChainPolicyStatus{}
err := syscall.CertVerifyCertificateChainPolicy(syscall.CERT_CHAIN_POLICY_SSL, chainCtx, para, &status)
if err != nil {
return err
}
// TODO(mkrautz): use the lChainIndex and lElementIndex fields
// of the CertChainPolicyStatus to provide proper context, instead
// using c.
if status.Error != 0 {
switch status.Error {
case syscall.CERT_E_EXPIRED:
return CertificateInvalidError{c, Expired}
case syscall.CERT_E_CN_NO_MATCH:
return HostnameError{c, opts.DNSName}
case syscall.CERT_E_UNTRUSTEDROOT:
return UnknownAuthorityError{c}
default:
return UnknownAuthorityError{c}
}
}
return nil
}
// systemVerify is like Verify, except that it uses CryptoAPI calls
// to build certificate chains and verify them.
func (c *Certificate) systemVerify(opts *VerifyOptions) (chains [][]*Certificate, err error) {
hasDNSName := opts != nil && len(opts.DNSName) > 0
storeCtx, err := createStoreContext(c, opts)
if err != nil {
return nil, err
}
defer syscall.CertFreeCertificateContext(storeCtx)
para := new(syscall.CertChainPara)
para.Size = uint32(unsafe.Sizeof(*para))
// If there's a DNSName set in opts, assume we're verifying
// a certificate from a TLS server.
if hasDNSName {
oids := []*byte{
&syscall.OID_PKIX_KP_SERVER_AUTH[0],
// Both IE and Chrome allow certificates with
// Server Gated Crypto as well. Some certificates
// in the wild require them.
&syscall.OID_SERVER_GATED_CRYPTO[0],
&syscall.OID_SGC_NETSCAPE[0],
}
para.RequestedUsage.Type = syscall.USAGE_MATCH_TYPE_OR
para.RequestedUsage.Usage.Length = uint32(len(oids))
para.RequestedUsage.Usage.UsageIdentifiers = &oids[0]
} else {
para.RequestedUsage.Type = syscall.USAGE_MATCH_TYPE_AND
para.RequestedUsage.Usage.Length = 0
para.RequestedUsage.Usage.UsageIdentifiers = nil
}
var verifyTime *syscall.Filetime
if opts != nil && !opts.CurrentTime.IsZero() {
ft := syscall.NsecToFiletime(opts.CurrentTime.UnixNano())
verifyTime = &ft
}
// CertGetCertificateChain will traverse Windows's root stores
// in an attempt to build a verified certificate chain. Once
// it has found a verified chain, it stops. MSDN docs on
// CERT_CHAIN_CONTEXT:
//
// When a CERT_CHAIN_CONTEXT is built, the first simple chain
// begins with an end certificate and ends with a self-signed
// certificate. If that self-signed certificate is not a root
// or otherwise trusted certificate, an attempt is made to
// build a new chain. CTLs are used to create the new chain
// beginning with the self-signed certificate from the original
// chain as the end certificate of the new chain. This process
// continues building additional simple chains until the first
// self-signed certificate is a trusted certificate or until
// an additional simple chain cannot be built.
//
// The result is that we'll only get a single trusted chain to
// return to our caller.
var chainCtx *syscall.CertChainContext
err = syscall.CertGetCertificateChain(syscall.Handle(0), storeCtx, verifyTime, storeCtx.Store, para, 0, 0, &chainCtx)
if err != nil {
return nil, err
}
defer syscall.CertFreeCertificateChain(chainCtx)
err = checkChainTrustStatus(c, chainCtx)
if err != nil {
return nil, err
}
if hasDNSName {
err = checkChainSSLServerPolicy(c, chainCtx, opts)
if err != nil {
return nil, err
}
}
chain, err := extractSimpleChain(chainCtx.Chains, int(chainCtx.ChainCount))
if err != nil {
return nil, err
}
chains = append(chains, chain)
return chains, nil
}
func initSystemRoots() {
systemRoots = NewCertPool()
}

View File

@ -5,6 +5,7 @@
package x509 package x509
import ( import (
"runtime"
"strings" "strings"
"time" "time"
"unicode/utf8" "unicode/utf8"
@ -23,6 +24,9 @@ const (
// certificate has a name constraint which doesn't include the name // certificate has a name constraint which doesn't include the name
// being checked. // being checked.
CANotAuthorizedForThisName CANotAuthorizedForThisName
// TooManyIntermediates results when a path length constraint is
// violated.
TooManyIntermediates
) )
// CertificateInvalidError results when an odd error occurs. Users of this // CertificateInvalidError results when an odd error occurs. Users of this
@ -40,6 +44,8 @@ func (e CertificateInvalidError) Error() string {
return "x509: certificate has expired or is not yet valid" return "x509: certificate has expired or is not yet valid"
case CANotAuthorizedForThisName: case CANotAuthorizedForThisName:
return "x509: a root or intermediate certificate is not authorized to sign in this domain" return "x509: a root or intermediate certificate is not authorized to sign in this domain"
case TooManyIntermediates:
return "x509: too many intermediates for path length constraint"
} }
return "x509: unknown error" return "x509: unknown error"
} }
@ -76,7 +82,7 @@ func (e UnknownAuthorityError) Error() string {
type VerifyOptions struct { type VerifyOptions struct {
DNSName string DNSName string
Intermediates *CertPool Intermediates *CertPool
Roots *CertPool Roots *CertPool // if nil, the system roots are used
CurrentTime time.Time // if zero, the current time is used CurrentTime time.Time // if zero, the current time is used
} }
@ -87,7 +93,7 @@ const (
) )
// isValid performs validity checks on the c. // isValid performs validity checks on the c.
func (c *Certificate) isValid(certType int, opts *VerifyOptions) error { func (c *Certificate) isValid(certType int, currentChain []*Certificate, opts *VerifyOptions) error {
now := opts.CurrentTime now := opts.CurrentTime
if now.IsZero() { if now.IsZero() {
now = time.Now() now = time.Now()
@ -130,26 +136,44 @@ func (c *Certificate) isValid(certType int, opts *VerifyOptions) error {
return CertificateInvalidError{c, NotAuthorizedToSign} return CertificateInvalidError{c, NotAuthorizedToSign}
} }
if c.BasicConstraintsValid && c.MaxPathLen >= 0 {
numIntermediates := len(currentChain) - 1
if numIntermediates > c.MaxPathLen {
return CertificateInvalidError{c, TooManyIntermediates}
}
}
return nil return nil
} }
// Verify attempts to verify c by building one or more chains from c to a // Verify attempts to verify c by building one or more chains from c to a
// certificate in opts.roots, using certificates in opts.Intermediates if // certificate in opts.Roots, using certificates in opts.Intermediates if
// needed. If successful, it returns one or more chains where the first // needed. If successful, it returns one or more chains where the first
// element of the chain is c and the last element is from opts.Roots. // element of the chain is c and the last element is from opts.Roots.
// //
// WARNING: this doesn't do any revocation checking. // WARNING: this doesn't do any revocation checking.
func (c *Certificate) Verify(opts VerifyOptions) (chains [][]*Certificate, err error) { func (c *Certificate) Verify(opts VerifyOptions) (chains [][]*Certificate, err error) {
err = c.isValid(leafCertificate, &opts) // Use Windows's own verification and chain building.
if opts.Roots == nil && runtime.GOOS == "windows" {
return c.systemVerify(&opts)
}
if opts.Roots == nil {
opts.Roots = systemRootsPool()
}
err = c.isValid(leafCertificate, nil, &opts)
if err != nil { if err != nil {
return return
} }
if len(opts.DNSName) > 0 { if len(opts.DNSName) > 0 {
err = c.VerifyHostname(opts.DNSName) err = c.VerifyHostname(opts.DNSName)
if err != nil { if err != nil {
return return
} }
} }
return c.buildChains(make(map[int][][]*Certificate), []*Certificate{c}, &opts) return c.buildChains(make(map[int][][]*Certificate), []*Certificate{c}, &opts)
} }
@ -163,7 +187,7 @@ func appendToFreshChain(chain []*Certificate, cert *Certificate) []*Certificate
func (c *Certificate) buildChains(cache map[int][][]*Certificate, currentChain []*Certificate, opts *VerifyOptions) (chains [][]*Certificate, err error) { func (c *Certificate) buildChains(cache map[int][][]*Certificate, currentChain []*Certificate, opts *VerifyOptions) (chains [][]*Certificate, err error) {
for _, rootNum := range opts.Roots.findVerifiedParents(c) { for _, rootNum := range opts.Roots.findVerifiedParents(c) {
root := opts.Roots.certs[rootNum] root := opts.Roots.certs[rootNum]
err = root.isValid(rootCertificate, opts) err = root.isValid(rootCertificate, currentChain, opts)
if err != nil { if err != nil {
continue continue
} }
@ -178,7 +202,7 @@ nextIntermediate:
continue nextIntermediate continue nextIntermediate
} }
} }
err = intermediate.isValid(intermediateCertificate, opts) err = intermediate.isValid(intermediateCertificate, currentChain, opts)
if err != nil { if err != nil {
continue continue
} }

View File

@ -8,6 +8,7 @@ import (
"crypto/x509/pkix" "crypto/x509/pkix"
"encoding/pem" "encoding/pem"
"errors" "errors"
"runtime"
"strings" "strings"
"testing" "testing"
"time" "time"
@ -19,7 +20,7 @@ type verifyTest struct {
roots []string roots []string
currentTime int64 currentTime int64
dnsName string dnsName string
nilRoots bool systemSkip bool
errorCallback func(*testing.T, int, error) bool errorCallback func(*testing.T, int, error) bool
expectedChains [][]string expectedChains [][]string
@ -57,14 +58,6 @@ var verifyTests = []verifyTest{
errorCallback: expectHostnameError, errorCallback: expectHostnameError,
}, },
{
leaf: googleLeaf,
intermediates: []string{thawteIntermediate},
nilRoots: true, // verifies that we don't crash
currentTime: 1302726541,
dnsName: "www.google.com",
errorCallback: expectAuthorityUnknown,
},
{ {
leaf: googleLeaf, leaf: googleLeaf,
intermediates: []string{thawteIntermediate}, intermediates: []string{thawteIntermediate},
@ -80,6 +73,9 @@ var verifyTests = []verifyTest{
currentTime: 1302726541, currentTime: 1302726541,
dnsName: "www.google.com", dnsName: "www.google.com",
// Skip when using systemVerify, since Windows
// *will* find the missing intermediate cert.
systemSkip: true,
errorCallback: expectAuthorityUnknown, errorCallback: expectAuthorityUnknown,
}, },
{ {
@ -109,6 +105,9 @@ var verifyTests = []verifyTest{
roots: []string{startComRoot}, roots: []string{startComRoot},
currentTime: 1302726541, currentTime: 1302726541,
// Skip when using systemVerify, since Windows
// can only return a single chain to us (for now).
systemSkip: true,
expectedChains: [][]string{ expectedChains: [][]string{
{"dnssec-exp", "StartCom Class 1", "StartCom Certification Authority"}, {"dnssec-exp", "StartCom Class 1", "StartCom Certification Authority"},
{"dnssec-exp", "StartCom Class 1", "StartCom Certification Authority", "StartCom Certification Authority"}, {"dnssec-exp", "StartCom Class 1", "StartCom Certification Authority", "StartCom Certification Authority"},
@ -148,23 +147,26 @@ func certificateFromPEM(pemBytes string) (*Certificate, error) {
return ParseCertificate(block.Bytes) return ParseCertificate(block.Bytes)
} }
func TestVerify(t *testing.T) { func testVerify(t *testing.T, useSystemRoots bool) {
for i, test := range verifyTests { for i, test := range verifyTests {
if useSystemRoots && test.systemSkip {
continue
}
opts := VerifyOptions{ opts := VerifyOptions{
Roots: NewCertPool(),
Intermediates: NewCertPool(), Intermediates: NewCertPool(),
DNSName: test.dnsName, DNSName: test.dnsName,
CurrentTime: time.Unix(test.currentTime, 0), CurrentTime: time.Unix(test.currentTime, 0),
} }
if test.nilRoots {
opts.Roots = nil
}
for j, root := range test.roots { if !useSystemRoots {
ok := opts.Roots.AppendCertsFromPEM([]byte(root)) opts.Roots = NewCertPool()
if !ok { for j, root := range test.roots {
t.Errorf("#%d: failed to parse root #%d", i, j) ok := opts.Roots.AppendCertsFromPEM([]byte(root))
return if !ok {
t.Errorf("#%d: failed to parse root #%d", i, j)
return
}
} }
} }
@ -225,6 +227,19 @@ func TestVerify(t *testing.T) {
} }
} }
func TestGoVerify(t *testing.T) {
testVerify(t, false)
}
func TestSystemVerify(t *testing.T) {
if runtime.GOOS != "windows" {
t.Logf("skipping verify test using system APIs on %q", runtime.GOOS)
return
}
testVerify(t, true)
}
func chainToDebugString(chain []*Certificate) string { func chainToDebugString(chain []*Certificate) string {
var chainStr string var chainStr string
for _, cert := range chain { for _, cert := range chain {

View File

@ -429,7 +429,7 @@ func (h UnhandledCriticalExtension) Error() string {
type basicConstraints struct { type basicConstraints struct {
IsCA bool `asn1:"optional"` IsCA bool `asn1:"optional"`
MaxPathLen int `asn1:"optional"` MaxPathLen int `asn1:"optional,default:-1"`
} }
// RFC 5280 4.2.1.4 // RFC 5280 4.2.1.4

View File

@ -43,6 +43,17 @@ type Driver interface {
// documented. // documented.
var ErrSkip = errors.New("driver: skip fast-path; continue as if unimplemented") var ErrSkip = errors.New("driver: skip fast-path; continue as if unimplemented")
// ErrBadConn should be returned by a driver to signal to the sql
// package that a driver.Conn is in a bad state (such as the server
// having earlier closed the connection) and the sql package should
// retry on a new connection.
//
// To prevent duplicate operations, ErrBadConn should NOT be returned
// if there's a possibility that the database server might have
// performed the operation. Even if the server sends back an error,
// you shouldn't return ErrBadConn.
var ErrBadConn = errors.New("driver: bad connection")
// Execer is an optional interface that may be implemented by a Conn. // Execer is an optional interface that may be implemented by a Conn.
// //
// If a Conn does not implement Execer, the db package's DB.Exec will // If a Conn does not implement Execer, the db package's DB.Exec will

View File

@ -82,6 +82,7 @@ type fakeConn struct {
mu sync.Mutex mu sync.Mutex
stmtsMade int stmtsMade int
stmtsClosed int stmtsClosed int
numPrepare int
} }
func (c *fakeConn) incrStat(v *int) { func (c *fakeConn) incrStat(v *int) {
@ -208,10 +209,13 @@ func (c *fakeConn) Begin() (driver.Tx, error) {
func (c *fakeConn) Close() error { func (c *fakeConn) Close() error {
if c.currTx != nil { if c.currTx != nil {
return errors.New("can't close; in a Transaction") return errors.New("can't close fakeConn; in a Transaction")
} }
if c.db == nil { if c.db == nil {
return errors.New("can't close; already closed") return errors.New("can't close fakeConn; already closed")
}
if c.stmtsMade > c.stmtsClosed {
return errors.New("can't close; dangling statement(s)")
} }
c.db = nil c.db = nil
return nil return nil
@ -249,6 +253,7 @@ func errf(msg string, args ...interface{}) error {
// just a limitation for fakedb) // just a limitation for fakedb)
func (c *fakeConn) prepareSelect(stmt *fakeStmt, parts []string) (driver.Stmt, error) { func (c *fakeConn) prepareSelect(stmt *fakeStmt, parts []string) (driver.Stmt, error) {
if len(parts) != 3 { if len(parts) != 3 {
stmt.Close()
return nil, errf("invalid SELECT syntax with %d parts; want 3", len(parts)) return nil, errf("invalid SELECT syntax with %d parts; want 3", len(parts))
} }
stmt.table = parts[0] stmt.table = parts[0]
@ -259,14 +264,17 @@ func (c *fakeConn) prepareSelect(stmt *fakeStmt, parts []string) (driver.Stmt, e
} }
nameVal := strings.Split(colspec, "=") nameVal := strings.Split(colspec, "=")
if len(nameVal) != 2 { if len(nameVal) != 2 {
stmt.Close()
return nil, errf("SELECT on table %q has invalid column spec of %q (index %d)", stmt.table, colspec, n) return nil, errf("SELECT on table %q has invalid column spec of %q (index %d)", stmt.table, colspec, n)
} }
column, value := nameVal[0], nameVal[1] column, value := nameVal[0], nameVal[1]
_, ok := c.db.columnType(stmt.table, column) _, ok := c.db.columnType(stmt.table, column)
if !ok { if !ok {
stmt.Close()
return nil, errf("SELECT on table %q references non-existent column %q", stmt.table, column) return nil, errf("SELECT on table %q references non-existent column %q", stmt.table, column)
} }
if value != "?" { if value != "?" {
stmt.Close()
return nil, errf("SELECT on table %q has pre-bound value for where column %q; need a question mark", return nil, errf("SELECT on table %q has pre-bound value for where column %q; need a question mark",
stmt.table, column) stmt.table, column)
} }
@ -279,12 +287,14 @@ func (c *fakeConn) prepareSelect(stmt *fakeStmt, parts []string) (driver.Stmt, e
// parts are table|col=type,col2=type2 // parts are table|col=type,col2=type2
func (c *fakeConn) prepareCreate(stmt *fakeStmt, parts []string) (driver.Stmt, error) { func (c *fakeConn) prepareCreate(stmt *fakeStmt, parts []string) (driver.Stmt, error) {
if len(parts) != 2 { if len(parts) != 2 {
stmt.Close()
return nil, errf("invalid CREATE syntax with %d parts; want 2", len(parts)) return nil, errf("invalid CREATE syntax with %d parts; want 2", len(parts))
} }
stmt.table = parts[0] stmt.table = parts[0]
for n, colspec := range strings.Split(parts[1], ",") { for n, colspec := range strings.Split(parts[1], ",") {
nameType := strings.Split(colspec, "=") nameType := strings.Split(colspec, "=")
if len(nameType) != 2 { if len(nameType) != 2 {
stmt.Close()
return nil, errf("CREATE table %q has invalid column spec of %q (index %d)", stmt.table, colspec, n) return nil, errf("CREATE table %q has invalid column spec of %q (index %d)", stmt.table, colspec, n)
} }
stmt.colName = append(stmt.colName, nameType[0]) stmt.colName = append(stmt.colName, nameType[0])
@ -296,17 +306,20 @@ func (c *fakeConn) prepareCreate(stmt *fakeStmt, parts []string) (driver.Stmt, e
// parts are table|col=?,col2=val // parts are table|col=?,col2=val
func (c *fakeConn) prepareInsert(stmt *fakeStmt, parts []string) (driver.Stmt, error) { func (c *fakeConn) prepareInsert(stmt *fakeStmt, parts []string) (driver.Stmt, error) {
if len(parts) != 2 { if len(parts) != 2 {
stmt.Close()
return nil, errf("invalid INSERT syntax with %d parts; want 2", len(parts)) return nil, errf("invalid INSERT syntax with %d parts; want 2", len(parts))
} }
stmt.table = parts[0] stmt.table = parts[0]
for n, colspec := range strings.Split(parts[1], ",") { for n, colspec := range strings.Split(parts[1], ",") {
nameVal := strings.Split(colspec, "=") nameVal := strings.Split(colspec, "=")
if len(nameVal) != 2 { if len(nameVal) != 2 {
stmt.Close()
return nil, errf("INSERT table %q has invalid column spec of %q (index %d)", stmt.table, colspec, n) return nil, errf("INSERT table %q has invalid column spec of %q (index %d)", stmt.table, colspec, n)
} }
column, value := nameVal[0], nameVal[1] column, value := nameVal[0], nameVal[1]
ctype, ok := c.db.columnType(stmt.table, column) ctype, ok := c.db.columnType(stmt.table, column)
if !ok { if !ok {
stmt.Close()
return nil, errf("INSERT table %q references non-existent column %q", stmt.table, column) return nil, errf("INSERT table %q references non-existent column %q", stmt.table, column)
} }
stmt.colName = append(stmt.colName, column) stmt.colName = append(stmt.colName, column)
@ -322,10 +335,12 @@ func (c *fakeConn) prepareInsert(stmt *fakeStmt, parts []string) (driver.Stmt, e
case "int32": case "int32":
i, err := strconv.Atoi(value) i, err := strconv.Atoi(value)
if err != nil { if err != nil {
stmt.Close()
return nil, errf("invalid conversion to int32 from %q", value) return nil, errf("invalid conversion to int32 from %q", value)
} }
subsetVal = int64(i) // int64 is a subset type, but not int32 subsetVal = int64(i) // int64 is a subset type, but not int32
default: default:
stmt.Close()
return nil, errf("unsupported conversion for pre-bound parameter %q to type %q", value, ctype) return nil, errf("unsupported conversion for pre-bound parameter %q to type %q", value, ctype)
} }
stmt.colValue = append(stmt.colValue, subsetVal) stmt.colValue = append(stmt.colValue, subsetVal)
@ -339,6 +354,7 @@ func (c *fakeConn) prepareInsert(stmt *fakeStmt, parts []string) (driver.Stmt, e
} }
func (c *fakeConn) Prepare(query string) (driver.Stmt, error) { func (c *fakeConn) Prepare(query string) (driver.Stmt, error) {
c.numPrepare++
if c.db == nil { if c.db == nil {
panic("nil c.db; conn = " + fmt.Sprintf("%#v", c)) panic("nil c.db; conn = " + fmt.Sprintf("%#v", c))
} }
@ -360,6 +376,7 @@ func (c *fakeConn) Prepare(query string) (driver.Stmt, error) {
case "INSERT": case "INSERT":
return c.prepareInsert(stmt, parts) return c.prepareInsert(stmt, parts)
default: default:
stmt.Close()
return nil, errf("unsupported command type %q", cmd) return nil, errf("unsupported command type %q", cmd)
} }
return stmt, nil return stmt, nil

View File

@ -175,6 +175,16 @@ var ErrNoRows = errors.New("sql: no rows in result set")
// DB is a database handle. It's safe for concurrent use by multiple // DB is a database handle. It's safe for concurrent use by multiple
// goroutines. // goroutines.
//
// If the underlying database driver has the concept of a connection
// and per-connection session state, the sql package manages creating
// and freeing connections automatically, including maintaining a free
// pool of idle connections. If observing session state is required,
// either do not share a *DB between multiple concurrent goroutines or
// create and observe all state only within a transaction. Once
// DB.Open is called, the returned Tx is bound to a single isolated
// connection. Once Tx.Commit or Tx.Rollback is called, that
// connection is returned to DB's idle connection pool.
type DB struct { type DB struct {
driver driver.Driver driver driver.Driver
dsn string dsn string
@ -241,34 +251,56 @@ func (db *DB) conn() (driver.Conn, error) {
func (db *DB) connIfFree(wanted driver.Conn) (conn driver.Conn, ok bool) { func (db *DB) connIfFree(wanted driver.Conn) (conn driver.Conn, ok bool) {
db.mu.Lock() db.mu.Lock()
defer db.mu.Unlock() defer db.mu.Unlock()
for n, conn := range db.freeConn { for i, conn := range db.freeConn {
if conn == wanted { if conn != wanted {
db.freeConn[n] = db.freeConn[len(db.freeConn)-1] continue
db.freeConn = db.freeConn[:len(db.freeConn)-1]
return wanted, true
} }
db.freeConn[i] = db.freeConn[len(db.freeConn)-1]
db.freeConn = db.freeConn[:len(db.freeConn)-1]
return wanted, true
} }
return nil, false return nil, false
} }
func (db *DB) putConn(c driver.Conn) { // putConnHook is a hook for testing.
db.mu.Lock() var putConnHook func(*DB, driver.Conn)
defer db.mu.Unlock()
if n := len(db.freeConn); !db.closed && n < db.maxIdleConns() { // putConn adds a connection to the db's free pool.
db.freeConn = append(db.freeConn, c) // err is optionally the last error that occured on this connection.
func (db *DB) putConn(c driver.Conn, err error) {
if err == driver.ErrBadConn {
// Don't reuse bad connections.
return return
} }
db.closeConn(c) // TODO(bradfitz): release lock before calling this? db.mu.Lock()
} if putConnHook != nil {
putConnHook(db, c)
func (db *DB) closeConn(c driver.Conn) { }
// TODO: check to see if we need this Conn for any prepared statements if n := len(db.freeConn); !db.closed && n < db.maxIdleConns() {
// that are active. db.freeConn = append(db.freeConn, c)
db.mu.Unlock()
return
}
// TODO: check to see if we need this Conn for any prepared
// statements which are still active?
db.mu.Unlock()
c.Close() c.Close()
} }
// Prepare creates a prepared statement for later execution. // Prepare creates a prepared statement for later execution.
func (db *DB) Prepare(query string) (*Stmt, error) { func (db *DB) Prepare(query string) (*Stmt, error) {
var stmt *Stmt
var err error
for i := 0; i < 10; i++ {
stmt, err = db.prepare(query)
if err != driver.ErrBadConn {
break
}
}
return stmt, err
}
func (db *DB) prepare(query string) (stmt *Stmt, err error) {
// TODO: check if db.driver supports an optional // TODO: check if db.driver supports an optional
// driver.Preparer interface and call that instead, if so, // driver.Preparer interface and call that instead, if so,
// otherwise we make a prepared statement that's bound // otherwise we make a prepared statement that's bound
@ -279,12 +311,12 @@ func (db *DB) Prepare(query string) (*Stmt, error) {
if err != nil { if err != nil {
return nil, err return nil, err
} }
defer db.putConn(ci) defer db.putConn(ci, err)
si, err := ci.Prepare(query) si, err := ci.Prepare(query)
if err != nil { if err != nil {
return nil, err return nil, err
} }
stmt := &Stmt{ stmt = &Stmt{
db: db, db: db,
query: query, query: query,
css: []connStmt{{ci, si}}, css: []connStmt{{ci, si}},
@ -295,15 +327,22 @@ func (db *DB) Prepare(query string) (*Stmt, error) {
// Exec executes a query without returning any rows. // Exec executes a query without returning any rows.
func (db *DB) Exec(query string, args ...interface{}) (Result, error) { func (db *DB) Exec(query string, args ...interface{}) (Result, error) {
sargs, err := subsetTypeArgs(args) sargs, err := subsetTypeArgs(args)
if err != nil { var res Result
return nil, err for i := 0; i < 10; i++ {
res, err = db.exec(query, sargs)
if err != driver.ErrBadConn {
break
}
} }
return res, err
}
func (db *DB) exec(query string, sargs []driver.Value) (res Result, err error) {
ci, err := db.conn() ci, err := db.conn()
if err != nil { if err != nil {
return nil, err return nil, err
} }
defer db.putConn(ci) defer db.putConn(ci, err)
if execer, ok := ci.(driver.Execer); ok { if execer, ok := ci.(driver.Execer); ok {
resi, err := execer.Exec(query, sargs) resi, err := execer.Exec(query, sargs)
@ -354,13 +393,25 @@ func (db *DB) QueryRow(query string, args ...interface{}) *Row {
// Begin starts a transaction. The isolation level is dependent on // Begin starts a transaction. The isolation level is dependent on
// the driver. // the driver.
func (db *DB) Begin() (*Tx, error) { func (db *DB) Begin() (*Tx, error) {
var tx *Tx
var err error
for i := 0; i < 10; i++ {
tx, err = db.begin()
if err != driver.ErrBadConn {
break
}
}
return tx, err
}
func (db *DB) begin() (tx *Tx, err error) {
ci, err := db.conn() ci, err := db.conn()
if err != nil { if err != nil {
return nil, err return nil, err
} }
txi, err := ci.Begin() txi, err := ci.Begin()
if err != nil { if err != nil {
db.putConn(ci) db.putConn(ci, err)
return nil, fmt.Errorf("sql: failed to Begin transaction: %v", err) return nil, fmt.Errorf("sql: failed to Begin transaction: %v", err)
} }
return &Tx{ return &Tx{
@ -406,7 +457,7 @@ func (tx *Tx) close() {
panic("double close") // internal error panic("double close") // internal error
} }
tx.done = true tx.done = true
tx.db.putConn(tx.ci) tx.db.putConn(tx.ci, nil)
tx.ci = nil tx.ci = nil
tx.txi = nil tx.txi = nil
} }
@ -561,9 +612,11 @@ func (tx *Tx) Query(query string, args ...interface{}) (*Rows, error) {
return nil, err return nil, err
} }
rows, err := stmt.Query(args...) rows, err := stmt.Query(args...)
if err == nil { if err != nil {
rows.closeStmt = stmt stmt.Close()
return nil, err
} }
rows.closeStmt = stmt
return rows, err return rows, err
} }
@ -609,7 +662,7 @@ func (s *Stmt) Exec(args ...interface{}) (Result, error) {
if err != nil { if err != nil {
return nil, err return nil, err
} }
defer releaseConn() defer releaseConn(nil)
// -1 means the driver doesn't know how to count the number of // -1 means the driver doesn't know how to count the number of
// placeholders, so we won't sanity check input here and instead let the // placeholders, so we won't sanity check input here and instead let the
@ -672,7 +725,7 @@ func (s *Stmt) Exec(args ...interface{}) (Result, error) {
// connStmt returns a free driver connection on which to execute the // connStmt returns a free driver connection on which to execute the
// statement, a function to call to release the connection, and a // statement, a function to call to release the connection, and a
// statement bound to that connection. // statement bound to that connection.
func (s *Stmt) connStmt() (ci driver.Conn, releaseConn func(), si driver.Stmt, err error) { func (s *Stmt) connStmt() (ci driver.Conn, releaseConn func(error), si driver.Stmt, err error) {
if err = s.stickyErr; err != nil { if err = s.stickyErr; err != nil {
return return
} }
@ -691,7 +744,7 @@ func (s *Stmt) connStmt() (ci driver.Conn, releaseConn func(), si driver.Stmt, e
if err != nil { if err != nil {
return return
} }
releaseConn = func() { s.tx.releaseConn() } releaseConn = func(error) { s.tx.releaseConn() }
return ci, releaseConn, s.txsi, nil return ci, releaseConn, s.txsi, nil
} }
@ -700,7 +753,7 @@ func (s *Stmt) connStmt() (ci driver.Conn, releaseConn func(), si driver.Stmt, e
for _, v := range s.css { for _, v := range s.css {
// TODO(bradfitz): lazily clean up entries in this // TODO(bradfitz): lazily clean up entries in this
// list with dead conns while enumerating // list with dead conns while enumerating
if _, match = s.db.connIfFree(cs.ci); match { if _, match = s.db.connIfFree(v.ci); match {
cs = v cs = v
break break
} }
@ -710,22 +763,28 @@ func (s *Stmt) connStmt() (ci driver.Conn, releaseConn func(), si driver.Stmt, e
// Make a new conn if all are busy. // Make a new conn if all are busy.
// TODO(bradfitz): or wait for one? make configurable later? // TODO(bradfitz): or wait for one? make configurable later?
if !match { if !match {
ci, err := s.db.conn() for i := 0; ; i++ {
if err != nil { ci, err := s.db.conn()
return nil, nil, nil, err if err != nil {
return nil, nil, nil, err
}
si, err := ci.Prepare(s.query)
if err == driver.ErrBadConn && i < 10 {
continue
}
if err != nil {
return nil, nil, nil, err
}
s.mu.Lock()
cs = connStmt{ci, si}
s.css = append(s.css, cs)
s.mu.Unlock()
break
} }
si, err := ci.Prepare(s.query)
if err != nil {
return nil, nil, nil, err
}
s.mu.Lock()
cs = connStmt{ci, si}
s.css = append(s.css, cs)
s.mu.Unlock()
} }
conn := cs.ci conn := cs.ci
releaseConn = func() { s.db.putConn(conn) } releaseConn = func(err error) { s.db.putConn(conn, err) }
return conn, releaseConn, cs.si, nil return conn, releaseConn, cs.si, nil
} }
@ -749,7 +808,7 @@ func (s *Stmt) Query(args ...interface{}) (*Rows, error) {
} }
rowsi, err := si.Query(sargs) rowsi, err := si.Query(sargs)
if err != nil { if err != nil {
s.db.putConn(ci) releaseConn(err)
return nil, err return nil, err
} }
// Note: ownership of ci passes to the *Rows, to be freed // Note: ownership of ci passes to the *Rows, to be freed
@ -800,7 +859,7 @@ func (s *Stmt) Close() error {
for _, v := range s.css { for _, v := range s.css {
if ci, match := s.db.connIfFree(v.ci); match { if ci, match := s.db.connIfFree(v.ci); match {
v.si.Close() v.si.Close()
s.db.putConn(ci) s.db.putConn(ci, nil)
} else { } else {
// TODO(bradfitz): care that we can't close // TODO(bradfitz): care that we can't close
// this statement because the statement's // this statement because the statement's
@ -827,7 +886,7 @@ func (s *Stmt) Close() error {
type Rows struct { type Rows struct {
db *DB db *DB
ci driver.Conn // owned; must call putconn when closed to release ci driver.Conn // owned; must call putconn when closed to release
releaseConn func() releaseConn func(error)
rowsi driver.Rows rowsi driver.Rows
closed bool closed bool
@ -939,7 +998,7 @@ func (rs *Rows) Close() error {
} }
rs.closed = true rs.closed = true
err := rs.rowsi.Close() err := rs.rowsi.Close()
rs.releaseConn() rs.releaseConn(err)
if rs.closeStmt != nil { if rs.closeStmt != nil {
rs.closeStmt.Close() rs.closeStmt.Close()
} }

View File

@ -5,13 +5,35 @@
package sql package sql
import ( import (
"database/sql/driver"
"fmt" "fmt"
"reflect" "reflect"
"runtime"
"strings" "strings"
"testing" "testing"
"time" "time"
) )
func init() {
type dbConn struct {
db *DB
c driver.Conn
}
freedFrom := make(map[dbConn]string)
putConnHook = func(db *DB, c driver.Conn) {
for _, oc := range db.freeConn {
if oc == c {
// print before panic, as panic may get lost due to conflicting panic
// (all goroutines asleep) elsewhere, since we might not unlock
// the mutex in freeConn here.
println("double free of conn. conflicts are:\nA) " + freedFrom[dbConn{db, c}] + "\n\nand\nB) " + stack())
panic("double free of conn.")
}
}
freedFrom[dbConn{db, c}] = stack()
}
}
const fakeDBName = "foo" const fakeDBName = "foo"
var chrisBirthday = time.Unix(123456789, 0) var chrisBirthday = time.Unix(123456789, 0)
@ -47,9 +69,19 @@ func closeDB(t *testing.T, db *DB) {
} }
} }
// numPrepares assumes that db has exactly 1 idle conn and returns
// its count of calls to Prepare
func numPrepares(t *testing.T, db *DB) int {
if n := len(db.freeConn); n != 1 {
t.Fatalf("free conns = %d; want 1", n)
}
return db.freeConn[0].(*fakeConn).numPrepare
}
func TestQuery(t *testing.T) { func TestQuery(t *testing.T) {
db := newTestDB(t, "people") db := newTestDB(t, "people")
defer closeDB(t, db) defer closeDB(t, db)
prepares0 := numPrepares(t, db)
rows, err := db.Query("SELECT|people|age,name|") rows, err := db.Query("SELECT|people|age,name|")
if err != nil { if err != nil {
t.Fatalf("Query: %v", err) t.Fatalf("Query: %v", err)
@ -83,7 +115,10 @@ func TestQuery(t *testing.T) {
// And verify that the final rows.Next() call, which hit EOF, // And verify that the final rows.Next() call, which hit EOF,
// also closed the rows connection. // also closed the rows connection.
if n := len(db.freeConn); n != 1 { if n := len(db.freeConn); n != 1 {
t.Errorf("free conns after query hitting EOF = %d; want 1", n) t.Fatalf("free conns after query hitting EOF = %d; want 1", n)
}
if prepares := numPrepares(t, db) - prepares0; prepares != 1 {
t.Errorf("executed %d Prepare statements; want 1", prepares)
} }
} }
@ -216,6 +251,7 @@ func TestStatementQueryRow(t *testing.T) {
if err != nil { if err != nil {
t.Fatalf("Prepare: %v", err) t.Fatalf("Prepare: %v", err)
} }
defer stmt.Close()
var age int var age int
for n, tt := range []struct { for n, tt := range []struct {
name string name string
@ -256,6 +292,7 @@ func TestExec(t *testing.T) {
if err != nil { if err != nil {
t.Errorf("Stmt, err = %v, %v", stmt, err) t.Errorf("Stmt, err = %v, %v", stmt, err)
} }
defer stmt.Close()
type execTest struct { type execTest struct {
args []interface{} args []interface{}
@ -297,11 +334,14 @@ func TestTxStmt(t *testing.T) {
if err != nil { if err != nil {
t.Fatalf("Stmt, err = %v, %v", stmt, err) t.Fatalf("Stmt, err = %v, %v", stmt, err)
} }
defer stmt.Close()
tx, err := db.Begin() tx, err := db.Begin()
if err != nil { if err != nil {
t.Fatalf("Begin = %v", err) t.Fatalf("Begin = %v", err)
} }
_, err = tx.Stmt(stmt).Exec("Bobby", 7) txs := tx.Stmt(stmt)
defer txs.Close()
_, err = txs.Exec("Bobby", 7)
if err != nil { if err != nil {
t.Fatalf("Exec = %v", err) t.Fatalf("Exec = %v", err)
} }
@ -330,6 +370,7 @@ func TestTxQuery(t *testing.T) {
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
defer r.Close()
if !r.Next() { if !r.Next() {
if r.Err() != nil { if r.Err() != nil {
@ -345,6 +386,22 @@ func TestTxQuery(t *testing.T) {
} }
} }
func TestTxQueryInvalid(t *testing.T) {
db := newTestDB(t, "")
defer closeDB(t, db)
tx, err := db.Begin()
if err != nil {
t.Fatal(err)
}
defer tx.Rollback()
_, err = tx.Query("SELECT|t1|name|")
if err == nil {
t.Fatal("Error expected")
}
}
// Tests fix for issue 2542, that we release a lock when querying on // Tests fix for issue 2542, that we release a lock when querying on
// a closed connection. // a closed connection.
func TestIssue2542Deadlock(t *testing.T) { func TestIssue2542Deadlock(t *testing.T) {
@ -450,48 +507,48 @@ type nullTestSpec struct {
func TestNullStringParam(t *testing.T) { func TestNullStringParam(t *testing.T) {
spec := nullTestSpec{"nullstring", "string", [6]nullTestRow{ spec := nullTestSpec{"nullstring", "string", [6]nullTestRow{
nullTestRow{NullString{"aqua", true}, "", NullString{"aqua", true}}, {NullString{"aqua", true}, "", NullString{"aqua", true}},
nullTestRow{NullString{"brown", false}, "", NullString{"", false}}, {NullString{"brown", false}, "", NullString{"", false}},
nullTestRow{"chartreuse", "", NullString{"chartreuse", true}}, {"chartreuse", "", NullString{"chartreuse", true}},
nullTestRow{NullString{"darkred", true}, "", NullString{"darkred", true}}, {NullString{"darkred", true}, "", NullString{"darkred", true}},
nullTestRow{NullString{"eel", false}, "", NullString{"", false}}, {NullString{"eel", false}, "", NullString{"", false}},
nullTestRow{"foo", NullString{"black", false}, nil}, {"foo", NullString{"black", false}, nil},
}} }}
nullTestRun(t, spec) nullTestRun(t, spec)
} }
func TestNullInt64Param(t *testing.T) { func TestNullInt64Param(t *testing.T) {
spec := nullTestSpec{"nullint64", "int64", [6]nullTestRow{ spec := nullTestSpec{"nullint64", "int64", [6]nullTestRow{
nullTestRow{NullInt64{31, true}, 1, NullInt64{31, true}}, {NullInt64{31, true}, 1, NullInt64{31, true}},
nullTestRow{NullInt64{-22, false}, 1, NullInt64{0, false}}, {NullInt64{-22, false}, 1, NullInt64{0, false}},
nullTestRow{22, 1, NullInt64{22, true}}, {22, 1, NullInt64{22, true}},
nullTestRow{NullInt64{33, true}, 1, NullInt64{33, true}}, {NullInt64{33, true}, 1, NullInt64{33, true}},
nullTestRow{NullInt64{222, false}, 1, NullInt64{0, false}}, {NullInt64{222, false}, 1, NullInt64{0, false}},
nullTestRow{0, NullInt64{31, false}, nil}, {0, NullInt64{31, false}, nil},
}} }}
nullTestRun(t, spec) nullTestRun(t, spec)
} }
func TestNullFloat64Param(t *testing.T) { func TestNullFloat64Param(t *testing.T) {
spec := nullTestSpec{"nullfloat64", "float64", [6]nullTestRow{ spec := nullTestSpec{"nullfloat64", "float64", [6]nullTestRow{
nullTestRow{NullFloat64{31.2, true}, 1, NullFloat64{31.2, true}}, {NullFloat64{31.2, true}, 1, NullFloat64{31.2, true}},
nullTestRow{NullFloat64{13.1, false}, 1, NullFloat64{0, false}}, {NullFloat64{13.1, false}, 1, NullFloat64{0, false}},
nullTestRow{-22.9, 1, NullFloat64{-22.9, true}}, {-22.9, 1, NullFloat64{-22.9, true}},
nullTestRow{NullFloat64{33.81, true}, 1, NullFloat64{33.81, true}}, {NullFloat64{33.81, true}, 1, NullFloat64{33.81, true}},
nullTestRow{NullFloat64{222, false}, 1, NullFloat64{0, false}}, {NullFloat64{222, false}, 1, NullFloat64{0, false}},
nullTestRow{10, NullFloat64{31.2, false}, nil}, {10, NullFloat64{31.2, false}, nil},
}} }}
nullTestRun(t, spec) nullTestRun(t, spec)
} }
func TestNullBoolParam(t *testing.T) { func TestNullBoolParam(t *testing.T) {
spec := nullTestSpec{"nullbool", "bool", [6]nullTestRow{ spec := nullTestSpec{"nullbool", "bool", [6]nullTestRow{
nullTestRow{NullBool{false, true}, true, NullBool{false, true}}, {NullBool{false, true}, true, NullBool{false, true}},
nullTestRow{NullBool{true, false}, false, NullBool{false, false}}, {NullBool{true, false}, false, NullBool{false, false}},
nullTestRow{true, true, NullBool{true, true}}, {true, true, NullBool{true, true}},
nullTestRow{NullBool{true, true}, false, NullBool{true, true}}, {NullBool{true, true}, false, NullBool{true, true}},
nullTestRow{NullBool{true, false}, true, NullBool{false, false}}, {NullBool{true, false}, true, NullBool{false, false}},
nullTestRow{true, NullBool{true, false}, nil}, {true, NullBool{true, false}, nil},
}} }}
nullTestRun(t, spec) nullTestRun(t, spec)
} }
@ -510,6 +567,7 @@ func nullTestRun(t *testing.T, spec nullTestSpec) {
if err != nil { if err != nil {
t.Fatalf("prepare: %v", err) t.Fatalf("prepare: %v", err)
} }
defer stmt.Close()
if _, err := stmt.Exec(3, "chris", spec.rows[2].nullParam, spec.rows[2].notNullParam); err != nil { if _, err := stmt.Exec(3, "chris", spec.rows[2].nullParam, spec.rows[2].notNullParam); err != nil {
t.Errorf("exec insert chris: %v", err) t.Errorf("exec insert chris: %v", err)
} }
@ -549,3 +607,8 @@ func nullTestRun(t *testing.T, spec nullTestSpec) {
} }
} }
} }
func stack() string {
buf := make([]byte, 1024)
return string(buf[:runtime.Stack(buf, false)])
}

View File

@ -250,10 +250,14 @@ func parseBase128Int(bytes []byte, initOffset int) (ret, offset int, err error)
func parseUTCTime(bytes []byte) (ret time.Time, err error) { func parseUTCTime(bytes []byte) (ret time.Time, err error) {
s := string(bytes) s := string(bytes)
ret, err = time.Parse("0601021504Z0700", s) ret, err = time.Parse("0601021504Z0700", s)
if err == nil { if err != nil {
return ret, err = time.Parse("060102150405Z0700", s)
} }
ret, err = time.Parse("060102150405Z0700", s) if err == nil && ret.Year() >= 2050 {
// UTCTime only encodes times prior to 2050. See https://tools.ietf.org/html/rfc5280#section-4.1.2.5.1
ret = ret.AddDate(-100, 0, 0)
}
return return
} }

View File

@ -321,7 +321,7 @@ var parseFieldParametersTestData []parseFieldParametersTest = []parseFieldParame
{"default:42", fieldParameters{defaultValue: newInt64(42)}}, {"default:42", fieldParameters{defaultValue: newInt64(42)}},
{"tag:17", fieldParameters{tag: newInt(17)}}, {"tag:17", fieldParameters{tag: newInt(17)}},
{"optional,explicit,default:42,tag:17", fieldParameters{optional: true, explicit: true, defaultValue: newInt64(42), tag: newInt(17)}}, {"optional,explicit,default:42,tag:17", fieldParameters{optional: true, explicit: true, defaultValue: newInt64(42), tag: newInt(17)}},
{"optional,explicit,default:42,tag:17,rubbish1", fieldParameters{true, true, false, newInt64(42), newInt(17), 0, false}}, {"optional,explicit,default:42,tag:17,rubbish1", fieldParameters{true, true, false, newInt64(42), newInt(17), 0, false, false}},
{"set", fieldParameters{set: true}}, {"set", fieldParameters{set: true}},
} }

View File

@ -75,6 +75,7 @@ type fieldParameters struct {
tag *int // the EXPLICIT or IMPLICIT tag (maybe nil). tag *int // the EXPLICIT or IMPLICIT tag (maybe nil).
stringType int // the string tag to use when marshaling. stringType int // the string tag to use when marshaling.
set bool // true iff this should be encoded as a SET set bool // true iff this should be encoded as a SET
omitEmpty bool // true iff this should be omitted if empty when marshaling.
// Invariants: // Invariants:
// if explicit is set, tag is non-nil. // if explicit is set, tag is non-nil.
@ -116,6 +117,8 @@ func parseFieldParameters(str string) (ret fieldParameters) {
if ret.tag == nil { if ret.tag == nil {
ret.tag = new(int) ret.tag = new(int)
} }
case part == "omitempty":
ret.omitEmpty = true
} }
} }
return return

View File

@ -463,6 +463,10 @@ func marshalField(out *forkableWriter, v reflect.Value, params fieldParameters)
return marshalField(out, v.Elem(), params) return marshalField(out, v.Elem(), params)
} }
if v.Kind() == reflect.Slice && v.Len() == 0 && params.omitEmpty {
return
}
if params.optional && reflect.DeepEqual(v.Interface(), reflect.Zero(v.Type()).Interface()) { if params.optional && reflect.DeepEqual(v.Interface(), reflect.Zero(v.Type()).Interface()) {
return return
} }

View File

@ -54,6 +54,10 @@ type optionalRawValueTest struct {
A RawValue `asn1:"optional"` A RawValue `asn1:"optional"`
} }
type omitEmptyTest struct {
A []string `asn1:"omitempty"`
}
type testSET []int type testSET []int
var PST = time.FixedZone("PST", -8*60*60) var PST = time.FixedZone("PST", -8*60*60)
@ -116,6 +120,8 @@ var marshalTests = []marshalTest{
{rawContentsStruct{[]byte{0x30, 3, 1, 2, 3}, 64}, "3003010203"}, {rawContentsStruct{[]byte{0x30, 3, 1, 2, 3}, 64}, "3003010203"},
{RawValue{Tag: 1, Class: 2, IsCompound: false, Bytes: []byte{1, 2, 3}}, "8103010203"}, {RawValue{Tag: 1, Class: 2, IsCompound: false, Bytes: []byte{1, 2, 3}}, "8103010203"},
{testSET([]int{10}), "310302010a"}, {testSET([]int{10}), "310302010a"},
{omitEmptyTest{[]string{}}, "3000"},
{omitEmptyTest{[]string{"1"}}, "30053003130131"},
} }
func TestMarshal(t *testing.T) { func TestMarshal(t *testing.T) {

View File

@ -2,12 +2,17 @@
// Use of this source code is governed by a BSD-style // Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file. // license that can be found in the LICENSE file.
// Package binary implements translation between // Package binary implements translation between numbers and byte sequences
// unsigned integer values and byte sequences // and encoding and decoding of varints.
// and the reading and writing of fixed-size values. //
// Numbers are translated by reading and writing fixed-size values.
// A fixed-size value is either a fixed-size arithmetic // A fixed-size value is either a fixed-size arithmetic
// type (int8, uint8, int16, float32, complex64, ...) // type (int8, uint8, int16, float32, complex64, ...)
// or an array or struct containing only fixed-size values. // or an array or struct containing only fixed-size values.
//
// Varints are a method of encoding integers using one or more bytes;
// numbers with smaller absolute value take a smaller number of bytes.
// For a specification, see http://code.google.com/apis/protocolbuffers/docs/encoding.html.
package binary package binary
import ( import (

View File

@ -92,7 +92,8 @@ var (
// If FieldsPerRecord is positive, Read requires each record to // If FieldsPerRecord is positive, Read requires each record to
// have the given number of fields. If FieldsPerRecord is 0, Read sets it to // have the given number of fields. If FieldsPerRecord is 0, Read sets it to
// the number of fields in the first record, so that future records must // the number of fields in the first record, so that future records must
// have the same field count. // have the same field count. If FieldsPerRecord is negative, no check is
// made and records may have a variable number of fields.
// //
// If LazyQuotes is true, a quote may appear in an unquoted field and a // If LazyQuotes is true, a quote may appear in an unquoted field and a
// non-doubled quote may appear in a quoted field. // non-doubled quote may appear in a quoted field.

View File

@ -707,6 +707,9 @@ func (dec *Decoder) decodeInterface(ityp reflect.Type, state *decoderState, p ui
if name == "" { if name == "" {
// Copy the representation of the nil interface value to the target. // Copy the representation of the nil interface value to the target.
// This is horribly unsafe and special. // This is horribly unsafe and special.
if indir > 0 {
p = allocate(ityp, p, 1) // All but the last level has been allocated by dec.Indirect
}
*(*[2]uintptr)(unsafe.Pointer(p)) = ivalue.InterfaceData() *(*[2]uintptr)(unsafe.Pointer(p)) = ivalue.InterfaceData()
return return
} }

View File

@ -694,8 +694,8 @@ type Bug3 struct {
func TestGobPtrSlices(t *testing.T) { func TestGobPtrSlices(t *testing.T) {
in := []*Bug3{ in := []*Bug3{
&Bug3{1, nil}, {1, nil},
&Bug3{2, nil}, {2, nil},
} }
b := new(bytes.Buffer) b := new(bytes.Buffer)
err := NewEncoder(b).Encode(&in) err := NewEncoder(b).Encode(&in)

View File

@ -573,3 +573,22 @@ func TestGobEncodeIsZero(t *testing.T) {
t.Fatalf("%v != %v", x, y) t.Fatalf("%v != %v", x, y)
} }
} }
func TestGobEncodePtrError(t *testing.T) {
var err error
b := new(bytes.Buffer)
enc := NewEncoder(b)
err = enc.Encode(&err)
if err != nil {
t.Fatal("encode:", err)
}
dec := NewDecoder(b)
err2 := fmt.Errorf("foo")
err = dec.Decode(&err2)
if err != nil {
t.Fatal("decode:", err)
}
if err2 != nil {
t.Fatalf("expected nil, got %v", err2)
}
}

View File

@ -43,7 +43,8 @@ import (
// to keep some browsers from misinterpreting JSON output as HTML. // to keep some browsers from misinterpreting JSON output as HTML.
// //
// Array and slice values encode as JSON arrays, except that // Array and slice values encode as JSON arrays, except that
// []byte encodes as a base64-encoded string. // []byte encodes as a base64-encoded string, and a nil slice
// encodes as the null JSON object.
// //
// Struct values encode as JSON objects. Each exported struct field // Struct values encode as JSON objects. Each exported struct field
// becomes a member of the object unless // becomes a member of the object unless

View File

@ -577,7 +577,7 @@ type decompSet [4]map[string]bool
func makeDecompSet() decompSet { func makeDecompSet() decompSet {
m := decompSet{} m := decompSet{}
for i, _ := range m { for i := range m {
m[i] = make(map[string]bool) m[i] = make(map[string]bool)
} }
return m return m
@ -646,7 +646,7 @@ func printCharInfoTables() int {
fmt.Println("const (") fmt.Println("const (")
for i, m := range decompSet { for i, m := range decompSet {
sa := []string{} sa := []string{}
for s, _ := range m { for s := range m {
sa = append(sa, s) sa = append(sa, s)
} }
sort.Strings(sa) sort.Strings(sa)

View File

@ -1,155 +0,0 @@
// Copyright 2011 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
// +build windows
package main
import (
"fmt"
"os"
"syscall"
"unsafe"
)
// some help functions
func abortf(format string, a ...interface{}) {
fmt.Fprintf(os.Stdout, format, a...)
os.Exit(1)
}
func abortErrNo(funcname string, err error) {
errno, _ := err.(syscall.Errno)
abortf("%s failed: %d %s\n", funcname, uint32(errno), err)
}
// global vars
var (
mh syscall.Handle
bh syscall.Handle
)
// WinProc called by windows to notify us of all windows events we might be interested in.
func WndProc(hwnd syscall.Handle, msg uint32, wparam, lparam uintptr) (rc uintptr) {
switch msg {
case WM_CREATE:
var e error
// CreateWindowEx
bh, e = CreateWindowEx(
0,
syscall.StringToUTF16Ptr("button"),
syscall.StringToUTF16Ptr("Quit"),
WS_CHILD|WS_VISIBLE|BS_DEFPUSHBUTTON,
75, 70, 140, 25,
hwnd, 1, mh, 0)
if e != nil {
abortErrNo("CreateWindowEx", e)
}
fmt.Printf("button handle is %x\n", bh)
rc = DefWindowProc(hwnd, msg, wparam, lparam)
case WM_COMMAND:
switch syscall.Handle(lparam) {
case bh:
e := PostMessage(hwnd, WM_CLOSE, 0, 0)
if e != nil {
abortErrNo("PostMessage", e)
}
default:
rc = DefWindowProc(hwnd, msg, wparam, lparam)
}
case WM_CLOSE:
DestroyWindow(hwnd)
case WM_DESTROY:
PostQuitMessage(0)
default:
rc = DefWindowProc(hwnd, msg, wparam, lparam)
}
//fmt.Printf("WndProc(0x%08x, %d, 0x%08x, 0x%08x) (%d)\n", hwnd, msg, wparam, lparam, rc)
return
}
func rungui() int {
var e error
// GetModuleHandle
mh, e = GetModuleHandle(nil)
if e != nil {
abortErrNo("GetModuleHandle", e)
}
// Get icon we're going to use.
myicon, e := LoadIcon(0, IDI_APPLICATION)
if e != nil {
abortErrNo("LoadIcon", e)
}
// Get cursor we're going to use.
mycursor, e := LoadCursor(0, IDC_ARROW)
if e != nil {
abortErrNo("LoadCursor", e)
}
// Create callback
wproc := syscall.NewCallback(WndProc)
// RegisterClassEx
wcname := syscall.StringToUTF16Ptr("myWindowClass")
var wc Wndclassex
wc.Size = uint32(unsafe.Sizeof(wc))
wc.WndProc = wproc
wc.Instance = mh
wc.Icon = myicon
wc.Cursor = mycursor
wc.Background = COLOR_BTNFACE + 1
wc.MenuName = nil
wc.ClassName = wcname
wc.IconSm = myicon
if _, e := RegisterClassEx(&wc); e != nil {
abortErrNo("RegisterClassEx", e)
}
// CreateWindowEx
wh, e := CreateWindowEx(
WS_EX_CLIENTEDGE,
wcname,
syscall.StringToUTF16Ptr("My window"),
WS_OVERLAPPEDWINDOW,
CW_USEDEFAULT, CW_USEDEFAULT, 300, 200,
0, 0, mh, 0)
if e != nil {
abortErrNo("CreateWindowEx", e)
}
fmt.Printf("main window handle is %x\n", wh)
// ShowWindow
ShowWindow(wh, SW_SHOWDEFAULT)
// UpdateWindow
if e := UpdateWindow(wh); e != nil {
abortErrNo("UpdateWindow", e)
}
// Process all windows messages until WM_QUIT.
var m Msg
for {
r, e := GetMessage(&m, 0, 0, 0)
if e != nil {
abortErrNo("GetMessage", e)
}
if r == 0 {
// WM_QUIT received -> get out
break
}
TranslateMessage(&m)
DispatchMessage(&m)
}
return int(m.Wparam)
}
func main() {
rc := rungui()
os.Exit(rc)
}

View File

@ -1,134 +0,0 @@
// Copyright 2011 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
// +build windows
package main
import (
"syscall"
"unsafe"
)
type Wndclassex struct {
Size uint32
Style uint32
WndProc uintptr
ClsExtra int32
WndExtra int32
Instance syscall.Handle
Icon syscall.Handle
Cursor syscall.Handle
Background syscall.Handle
MenuName *uint16
ClassName *uint16
IconSm syscall.Handle
}
type Point struct {
X uintptr
Y uintptr
}
type Msg struct {
Hwnd syscall.Handle
Message uint32
Wparam uintptr
Lparam uintptr
Time uint32
Pt Point
}
const (
// Window styles
WS_OVERLAPPED = 0
WS_POPUP = 0x80000000
WS_CHILD = 0x40000000
WS_MINIMIZE = 0x20000000
WS_VISIBLE = 0x10000000
WS_DISABLED = 0x8000000
WS_CLIPSIBLINGS = 0x4000000
WS_CLIPCHILDREN = 0x2000000
WS_MAXIMIZE = 0x1000000
WS_CAPTION = WS_BORDER | WS_DLGFRAME
WS_BORDER = 0x800000
WS_DLGFRAME = 0x400000
WS_VSCROLL = 0x200000
WS_HSCROLL = 0x100000
WS_SYSMENU = 0x80000
WS_THICKFRAME = 0x40000
WS_GROUP = 0x20000
WS_TABSTOP = 0x10000
WS_MINIMIZEBOX = 0x20000
WS_MAXIMIZEBOX = 0x10000
WS_TILED = WS_OVERLAPPED
WS_ICONIC = WS_MINIMIZE
WS_SIZEBOX = WS_THICKFRAME
// Common Window Styles
WS_OVERLAPPEDWINDOW = WS_OVERLAPPED | WS_CAPTION | WS_SYSMENU | WS_THICKFRAME | WS_MINIMIZEBOX | WS_MAXIMIZEBOX
WS_TILEDWINDOW = WS_OVERLAPPEDWINDOW
WS_POPUPWINDOW = WS_POPUP | WS_BORDER | WS_SYSMENU
WS_CHILDWINDOW = WS_CHILD
WS_EX_CLIENTEDGE = 0x200
// Some windows messages
WM_CREATE = 1
WM_DESTROY = 2
WM_CLOSE = 16
WM_COMMAND = 273
// Some button control styles
BS_DEFPUSHBUTTON = 1
// Some color constants
COLOR_WINDOW = 5
COLOR_BTNFACE = 15
// Default window position
CW_USEDEFAULT = 0x80000000 - 0x100000000
// Show window default style
SW_SHOWDEFAULT = 10
)
var (
// Some globally known cursors
IDC_ARROW = MakeIntResource(32512)
IDC_IBEAM = MakeIntResource(32513)
IDC_WAIT = MakeIntResource(32514)
IDC_CROSS = MakeIntResource(32515)
// Some globally known icons
IDI_APPLICATION = MakeIntResource(32512)
IDI_HAND = MakeIntResource(32513)
IDI_QUESTION = MakeIntResource(32514)
IDI_EXCLAMATION = MakeIntResource(32515)
IDI_ASTERISK = MakeIntResource(32516)
IDI_WINLOGO = MakeIntResource(32517)
IDI_WARNING = IDI_EXCLAMATION
IDI_ERROR = IDI_HAND
IDI_INFORMATION = IDI_ASTERISK
)
//sys GetModuleHandle(modname *uint16) (handle syscall.Handle, err error) = GetModuleHandleW
//sys RegisterClassEx(wndclass *Wndclassex) (atom uint16, err error) = user32.RegisterClassExW
//sys CreateWindowEx(exstyle uint32, classname *uint16, windowname *uint16, style uint32, x int32, y int32, width int32, height int32, wndparent syscall.Handle, menu syscall.Handle, instance syscall.Handle, param uintptr) (hwnd syscall.Handle, err error) = user32.CreateWindowExW
//sys DefWindowProc(hwnd syscall.Handle, msg uint32, wparam uintptr, lparam uintptr) (lresult uintptr) = user32.DefWindowProcW
//sys DestroyWindow(hwnd syscall.Handle) (err error) = user32.DestroyWindow
//sys PostQuitMessage(exitcode int32) = user32.PostQuitMessage
//sys ShowWindow(hwnd syscall.Handle, cmdshow int32) (wasvisible bool) = user32.ShowWindow
//sys UpdateWindow(hwnd syscall.Handle) (err error) = user32.UpdateWindow
//sys GetMessage(msg *Msg, hwnd syscall.Handle, MsgFilterMin uint32, MsgFilterMax uint32) (ret int32, err error) [failretval==-1] = user32.GetMessageW
//sys TranslateMessage(msg *Msg) (done bool) = user32.TranslateMessage
//sys DispatchMessage(msg *Msg) (ret int32) = user32.DispatchMessageW
//sys LoadIcon(instance syscall.Handle, iconname *uint16) (icon syscall.Handle, err error) = user32.LoadIconW
//sys LoadCursor(instance syscall.Handle, cursorname *uint16) (cursor syscall.Handle, err error) = user32.LoadCursorW
//sys SetCursor(cursor syscall.Handle) (precursor syscall.Handle, err error) = user32.SetCursor
//sys SendMessage(hwnd syscall.Handle, msg uint32, wparam uintptr, lparam uintptr) (lresult uintptr) = user32.SendMessageW
//sys PostMessage(hwnd syscall.Handle, msg uint32, wparam uintptr, lparam uintptr) (err error) = user32.PostMessageW
func MakeIntResource(id uint16) *uint16 {
return (*uint16)(unsafe.Pointer(uintptr(id)))
}

View File

@ -1,192 +0,0 @@
// +build windows
// mksyscall_windows.pl winapi.go
// MACHINE GENERATED BY THE COMMAND ABOVE; DO NOT EDIT
package main
import "unsafe"
import "syscall"
var (
modkernel32 = syscall.NewLazyDLL("kernel32.dll")
moduser32 = syscall.NewLazyDLL("user32.dll")
procGetModuleHandleW = modkernel32.NewProc("GetModuleHandleW")
procRegisterClassExW = moduser32.NewProc("RegisterClassExW")
procCreateWindowExW = moduser32.NewProc("CreateWindowExW")
procDefWindowProcW = moduser32.NewProc("DefWindowProcW")
procDestroyWindow = moduser32.NewProc("DestroyWindow")
procPostQuitMessage = moduser32.NewProc("PostQuitMessage")
procShowWindow = moduser32.NewProc("ShowWindow")
procUpdateWindow = moduser32.NewProc("UpdateWindow")
procGetMessageW = moduser32.NewProc("GetMessageW")
procTranslateMessage = moduser32.NewProc("TranslateMessage")
procDispatchMessageW = moduser32.NewProc("DispatchMessageW")
procLoadIconW = moduser32.NewProc("LoadIconW")
procLoadCursorW = moduser32.NewProc("LoadCursorW")
procSetCursor = moduser32.NewProc("SetCursor")
procSendMessageW = moduser32.NewProc("SendMessageW")
procPostMessageW = moduser32.NewProc("PostMessageW")
)
func GetModuleHandle(modname *uint16) (handle syscall.Handle, err error) {
r0, _, e1 := syscall.Syscall(procGetModuleHandleW.Addr(), 1, uintptr(unsafe.Pointer(modname)), 0, 0)
handle = syscall.Handle(r0)
if handle == 0 {
if e1 != 0 {
err = error(e1)
} else {
err = syscall.EINVAL
}
}
return
}
func RegisterClassEx(wndclass *Wndclassex) (atom uint16, err error) {
r0, _, e1 := syscall.Syscall(procRegisterClassExW.Addr(), 1, uintptr(unsafe.Pointer(wndclass)), 0, 0)
atom = uint16(r0)
if atom == 0 {
if e1 != 0 {
err = error(e1)
} else {
err = syscall.EINVAL
}
}
return
}
func CreateWindowEx(exstyle uint32, classname *uint16, windowname *uint16, style uint32, x int32, y int32, width int32, height int32, wndparent syscall.Handle, menu syscall.Handle, instance syscall.Handle, param uintptr) (hwnd syscall.Handle, err error) {
r0, _, e1 := syscall.Syscall12(procCreateWindowExW.Addr(), 12, uintptr(exstyle), uintptr(unsafe.Pointer(classname)), uintptr(unsafe.Pointer(windowname)), uintptr(style), uintptr(x), uintptr(y), uintptr(width), uintptr(height), uintptr(wndparent), uintptr(menu), uintptr(instance), uintptr(param))
hwnd = syscall.Handle(r0)
if hwnd == 0 {
if e1 != 0 {
err = error(e1)
} else {
err = syscall.EINVAL
}
}
return
}
func DefWindowProc(hwnd syscall.Handle, msg uint32, wparam uintptr, lparam uintptr) (lresult uintptr) {
r0, _, _ := syscall.Syscall6(procDefWindowProcW.Addr(), 4, uintptr(hwnd), uintptr(msg), uintptr(wparam), uintptr(lparam), 0, 0)
lresult = uintptr(r0)
return
}
func DestroyWindow(hwnd syscall.Handle) (err error) {
r1, _, e1 := syscall.Syscall(procDestroyWindow.Addr(), 1, uintptr(hwnd), 0, 0)
if int(r1) == 0 {
if e1 != 0 {
err = error(e1)
} else {
err = syscall.EINVAL
}
}
return
}
func PostQuitMessage(exitcode int32) {
syscall.Syscall(procPostQuitMessage.Addr(), 1, uintptr(exitcode), 0, 0)
return
}
func ShowWindow(hwnd syscall.Handle, cmdshow int32) (wasvisible bool) {
r0, _, _ := syscall.Syscall(procShowWindow.Addr(), 2, uintptr(hwnd), uintptr(cmdshow), 0)
wasvisible = bool(r0 != 0)
return
}
func UpdateWindow(hwnd syscall.Handle) (err error) {
r1, _, e1 := syscall.Syscall(procUpdateWindow.Addr(), 1, uintptr(hwnd), 0, 0)
if int(r1) == 0 {
if e1 != 0 {
err = error(e1)
} else {
err = syscall.EINVAL
}
}
return
}
func GetMessage(msg *Msg, hwnd syscall.Handle, MsgFilterMin uint32, MsgFilterMax uint32) (ret int32, err error) {
r0, _, e1 := syscall.Syscall6(procGetMessageW.Addr(), 4, uintptr(unsafe.Pointer(msg)), uintptr(hwnd), uintptr(MsgFilterMin), uintptr(MsgFilterMax), 0, 0)
ret = int32(r0)
if ret == -1 {
if e1 != 0 {
err = error(e1)
} else {
err = syscall.EINVAL
}
}
return
}
func TranslateMessage(msg *Msg) (done bool) {
r0, _, _ := syscall.Syscall(procTranslateMessage.Addr(), 1, uintptr(unsafe.Pointer(msg)), 0, 0)
done = bool(r0 != 0)
return
}
func DispatchMessage(msg *Msg) (ret int32) {
r0, _, _ := syscall.Syscall(procDispatchMessageW.Addr(), 1, uintptr(unsafe.Pointer(msg)), 0, 0)
ret = int32(r0)
return
}
func LoadIcon(instance syscall.Handle, iconname *uint16) (icon syscall.Handle, err error) {
r0, _, e1 := syscall.Syscall(procLoadIconW.Addr(), 2, uintptr(instance), uintptr(unsafe.Pointer(iconname)), 0)
icon = syscall.Handle(r0)
if icon == 0 {
if e1 != 0 {
err = error(e1)
} else {
err = syscall.EINVAL
}
}
return
}
func LoadCursor(instance syscall.Handle, cursorname *uint16) (cursor syscall.Handle, err error) {
r0, _, e1 := syscall.Syscall(procLoadCursorW.Addr(), 2, uintptr(instance), uintptr(unsafe.Pointer(cursorname)), 0)
cursor = syscall.Handle(r0)
if cursor == 0 {
if e1 != 0 {
err = error(e1)
} else {
err = syscall.EINVAL
}
}
return
}
func SetCursor(cursor syscall.Handle) (precursor syscall.Handle, err error) {
r0, _, e1 := syscall.Syscall(procSetCursor.Addr(), 1, uintptr(cursor), 0, 0)
precursor = syscall.Handle(r0)
if precursor == 0 {
if e1 != 0 {
err = error(e1)
} else {
err = syscall.EINVAL
}
}
return
}
func SendMessage(hwnd syscall.Handle, msg uint32, wparam uintptr, lparam uintptr) (lresult uintptr) {
r0, _, _ := syscall.Syscall6(procSendMessageW.Addr(), 4, uintptr(hwnd), uintptr(msg), uintptr(wparam), uintptr(lparam), 0, 0)
lresult = uintptr(r0)
return
}
func PostMessage(hwnd syscall.Handle, msg uint32, wparam uintptr, lparam uintptr) (err error) {
r1, _, e1 := syscall.Syscall6(procPostMessageW.Addr(), 4, uintptr(hwnd), uintptr(msg), uintptr(wparam), uintptr(lparam), 0, 0)
if int(r1) == 0 {
if e1 != 0 {
err = error(e1)
} else {
err = syscall.EINVAL
}
}
return
}

View File

@ -41,10 +41,14 @@ type Var interface {
// Int is a 64-bit integer variable that satisfies the Var interface. // Int is a 64-bit integer variable that satisfies the Var interface.
type Int struct { type Int struct {
i int64 i int64
mu sync.Mutex mu sync.RWMutex
} }
func (v *Int) String() string { return strconv.FormatInt(v.i, 10) } func (v *Int) String() string {
v.mu.RLock()
defer v.mu.RUnlock()
return strconv.FormatInt(v.i, 10)
}
func (v *Int) Add(delta int64) { func (v *Int) Add(delta int64) {
v.mu.Lock() v.mu.Lock()
@ -61,10 +65,14 @@ func (v *Int) Set(value int64) {
// Float is a 64-bit float variable that satisfies the Var interface. // Float is a 64-bit float variable that satisfies the Var interface.
type Float struct { type Float struct {
f float64 f float64
mu sync.Mutex mu sync.RWMutex
} }
func (v *Float) String() string { return strconv.FormatFloat(v.f, 'g', -1, 64) } func (v *Float) String() string {
v.mu.RLock()
defer v.mu.RUnlock()
return strconv.FormatFloat(v.f, 'g', -1, 64)
}
// Add adds delta to v. // Add adds delta to v.
func (v *Float) Add(delta float64) { func (v *Float) Add(delta float64) {
@ -95,17 +103,17 @@ type KeyValue struct {
func (v *Map) String() string { func (v *Map) String() string {
v.mu.RLock() v.mu.RLock()
defer v.mu.RUnlock() defer v.mu.RUnlock()
b := new(bytes.Buffer) var b bytes.Buffer
fmt.Fprintf(b, "{") fmt.Fprintf(&b, "{")
first := true first := true
for key, val := range v.m { for key, val := range v.m {
if !first { if !first {
fmt.Fprintf(b, ", ") fmt.Fprintf(&b, ", ")
} }
fmt.Fprintf(b, "\"%s\": %v", key, val) fmt.Fprintf(&b, "\"%s\": %v", key, val)
first = false first = false
} }
fmt.Fprintf(b, "}") fmt.Fprintf(&b, "}")
return b.String() return b.String()
} }
@ -180,12 +188,21 @@ func (v *Map) Do(f func(KeyValue)) {
// String is a string variable, and satisfies the Var interface. // String is a string variable, and satisfies the Var interface.
type String struct { type String struct {
s string s string
mu sync.RWMutex
} }
func (v *String) String() string { return strconv.Quote(v.s) } func (v *String) String() string {
v.mu.RLock()
defer v.mu.RUnlock()
return strconv.Quote(v.s)
}
func (v *String) Set(value string) { v.s = value } func (v *String) Set(value string) {
v.mu.Lock()
defer v.mu.Unlock()
v.s = value
}
// Func implements Var by calling the function // Func implements Var by calling the function
// and formatting the returned value using JSON. // and formatting the returned value using JSON.

View File

@ -7,7 +7,8 @@
to C's printf and scanf. The format 'verbs' are derived from C's but to C's printf and scanf. The format 'verbs' are derived from C's but
are simpler. are simpler.
Printing:
Printing
The verbs: The verbs:
@ -127,7 +128,8 @@
by a single character (the verb) and end with a parenthesized by a single character (the verb) and end with a parenthesized
description. description.
Scanning:
Scanning
An analogous set of functions scans formatted text to yield An analogous set of functions scans formatted text to yield
values. Scan, Scanf and Scanln read from os.Stdin; Fscan, values. Scan, Scanf and Scanln read from os.Stdin; Fscan,

View File

@ -0,0 +1,7 @@
// Copyright 2012 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package fmt
var IsSpace = isSpace

View File

@ -13,6 +13,7 @@ import (
"strings" "strings"
"testing" "testing"
"time" "time"
"unicode"
) )
type ( type (
@ -830,3 +831,13 @@ func TestBadVerbRecursion(t *testing.T) {
t.Error("fail with value") t.Error("fail with value")
} }
} }
func TestIsSpace(t *testing.T) {
// This tests the internal isSpace function.
// IsSpace = isSpace is defined in export_test.go.
for i := rune(0); i <= unicode.MaxRune; i++ {
if IsSpace(i) != unicode.IsSpace(i) {
t.Errorf("isSpace(%U) = %v, want %v", i, IsSpace(i), unicode.IsSpace(i))
}
}
}

View File

@ -5,9 +5,7 @@
package fmt package fmt
import ( import (
"bytes"
"strconv" "strconv"
"unicode"
"unicode/utf8" "unicode/utf8"
) )
@ -36,10 +34,10 @@ func init() {
} }
// A fmt is the raw formatter used by Printf etc. // A fmt is the raw formatter used by Printf etc.
// It prints into a bytes.Buffer that must be set up externally. // It prints into a buffer that must be set up separately.
type fmt struct { type fmt struct {
intbuf [nByte]byte intbuf [nByte]byte
buf *bytes.Buffer buf *buffer
// width, precision // width, precision
wid int wid int
prec int prec int
@ -69,7 +67,7 @@ func (f *fmt) clearflags() {
f.zero = false f.zero = false
} }
func (f *fmt) init(buf *bytes.Buffer) { func (f *fmt) init(buf *buffer) {
f.buf = buf f.buf = buf
f.clearflags() f.clearflags()
} }
@ -247,7 +245,7 @@ func (f *fmt) integer(a int64, base uint64, signedness bool, digits string) {
} }
// If we want a quoted char for %#U, move the data up to make room. // If we want a quoted char for %#U, move the data up to make room.
if f.unicode && f.uniQuote && a >= 0 && a <= unicode.MaxRune && unicode.IsPrint(rune(a)) { if f.unicode && f.uniQuote && a >= 0 && a <= utf8.MaxRune && strconv.IsPrint(rune(a)) {
runeWidth := utf8.RuneLen(rune(a)) runeWidth := utf8.RuneLen(rune(a))
width := 1 + 1 + runeWidth + 1 // space, quote, rune, quote width := 1 + 1 + runeWidth + 1 // space, quote, rune, quote
copy(buf[i-width:], buf[i:]) // guaranteed to have enough room. copy(buf[i-width:], buf[i:]) // guaranteed to have enough room.
@ -290,16 +288,15 @@ func (f *fmt) fmt_s(s string) {
// fmt_sx formats a string as a hexadecimal encoding of its bytes. // fmt_sx formats a string as a hexadecimal encoding of its bytes.
func (f *fmt) fmt_sx(s, digits string) { func (f *fmt) fmt_sx(s, digits string) {
// TODO: Avoid buffer by pre-padding. // TODO: Avoid buffer by pre-padding.
var b bytes.Buffer var b []byte
for i := 0; i < len(s); i++ { for i := 0; i < len(s); i++ {
if i > 0 && f.space { if i > 0 && f.space {
b.WriteByte(' ') b = append(b, ' ')
} }
v := s[i] v := s[i]
b.WriteByte(digits[v>>4]) b = append(b, digits[v>>4], digits[v&0xF])
b.WriteByte(digits[v&0xF])
} }
f.pad(b.Bytes()) f.pad(b)
} }
// fmt_q formats a string as a double-quoted, escaped Go string constant. // fmt_q formats a string as a double-quoted, escaped Go string constant.

View File

@ -5,13 +5,11 @@
package fmt package fmt
import ( import (
"bytes"
"errors" "errors"
"io" "io"
"os" "os"
"reflect" "reflect"
"sync" "sync"
"unicode"
"unicode/utf8" "unicode/utf8"
) )
@ -71,11 +69,45 @@ type GoStringer interface {
GoString() string GoString() string
} }
// Use simple []byte instead of bytes.Buffer to avoid large dependency.
type buffer []byte
func (b *buffer) Write(p []byte) (n int, err error) {
*b = append(*b, p...)
return len(p), nil
}
func (b *buffer) WriteString(s string) (n int, err error) {
*b = append(*b, s...)
return len(s), nil
}
func (b *buffer) WriteByte(c byte) error {
*b = append(*b, c)
return nil
}
func (bp *buffer) WriteRune(r rune) error {
if r < utf8.RuneSelf {
*bp = append(*bp, byte(r))
return nil
}
b := *bp
n := len(b)
for n+utf8.UTFMax > cap(b) {
b = append(b, 0)
}
w := utf8.EncodeRune(b[n:n+utf8.UTFMax], r)
*bp = b[:n+w]
return nil
}
type pp struct { type pp struct {
n int n int
panicking bool panicking bool
erroring bool // printing an error condition erroring bool // printing an error condition
buf bytes.Buffer buf buffer
// field holds the current item, as an interface{}. // field holds the current item, as an interface{}.
field interface{} field interface{}
// value holds the current item, as a reflect.Value, and will be // value holds the current item, as a reflect.Value, and will be
@ -133,10 +165,10 @@ func newPrinter() *pp {
// Save used pp structs in ppFree; avoids an allocation per invocation. // Save used pp structs in ppFree; avoids an allocation per invocation.
func (p *pp) free() { func (p *pp) free() {
// Don't hold on to pp structs with large buffers. // Don't hold on to pp structs with large buffers.
if cap(p.buf.Bytes()) > 1024 { if cap(p.buf) > 1024 {
return return
} }
p.buf.Reset() p.buf = p.buf[:0]
p.field = nil p.field = nil
p.value = reflect.Value{} p.value = reflect.Value{}
ppFree.put(p) ppFree.put(p)
@ -179,7 +211,7 @@ func (p *pp) Write(b []byte) (ret int, err error) {
func Fprintf(w io.Writer, format string, a ...interface{}) (n int, err error) { func Fprintf(w io.Writer, format string, a ...interface{}) (n int, err error) {
p := newPrinter() p := newPrinter()
p.doPrintf(format, a) p.doPrintf(format, a)
n64, err := p.buf.WriteTo(w) n64, err := w.Write(p.buf)
p.free() p.free()
return int(n64), err return int(n64), err
} }
@ -194,7 +226,7 @@ func Printf(format string, a ...interface{}) (n int, err error) {
func Sprintf(format string, a ...interface{}) string { func Sprintf(format string, a ...interface{}) string {
p := newPrinter() p := newPrinter()
p.doPrintf(format, a) p.doPrintf(format, a)
s := p.buf.String() s := string(p.buf)
p.free() p.free()
return s return s
} }
@ -213,7 +245,7 @@ func Errorf(format string, a ...interface{}) error {
func Fprint(w io.Writer, a ...interface{}) (n int, err error) { func Fprint(w io.Writer, a ...interface{}) (n int, err error) {
p := newPrinter() p := newPrinter()
p.doPrint(a, false, false) p.doPrint(a, false, false)
n64, err := p.buf.WriteTo(w) n64, err := w.Write(p.buf)
p.free() p.free()
return int(n64), err return int(n64), err
} }
@ -230,7 +262,7 @@ func Print(a ...interface{}) (n int, err error) {
func Sprint(a ...interface{}) string { func Sprint(a ...interface{}) string {
p := newPrinter() p := newPrinter()
p.doPrint(a, false, false) p.doPrint(a, false, false)
s := p.buf.String() s := string(p.buf)
p.free() p.free()
return s return s
} }
@ -245,7 +277,7 @@ func Sprint(a ...interface{}) string {
func Fprintln(w io.Writer, a ...interface{}) (n int, err error) { func Fprintln(w io.Writer, a ...interface{}) (n int, err error) {
p := newPrinter() p := newPrinter()
p.doPrint(a, true, true) p.doPrint(a, true, true)
n64, err := p.buf.WriteTo(w) n64, err := w.Write(p.buf)
p.free() p.free()
return int(n64), err return int(n64), err
} }
@ -262,7 +294,7 @@ func Println(a ...interface{}) (n int, err error) {
func Sprintln(a ...interface{}) string { func Sprintln(a ...interface{}) string {
p := newPrinter() p := newPrinter()
p.doPrint(a, true, true) p.doPrint(a, true, true)
s := p.buf.String() s := string(p.buf)
p.free() p.free()
return s return s
} }
@ -352,7 +384,7 @@ func (p *pp) fmtInt64(v int64, verb rune) {
case 'o': case 'o':
p.fmt.integer(v, 8, signed, ldigits) p.fmt.integer(v, 8, signed, ldigits)
case 'q': case 'q':
if 0 <= v && v <= unicode.MaxRune { if 0 <= v && v <= utf8.MaxRune {
p.fmt.fmt_qc(v) p.fmt.fmt_qc(v)
} else { } else {
p.badVerb(verb) p.badVerb(verb)
@ -416,7 +448,7 @@ func (p *pp) fmtUint64(v uint64, verb rune, goSyntax bool) {
case 'o': case 'o':
p.fmt.integer(int64(v), 8, unsigned, ldigits) p.fmt.integer(int64(v), 8, unsigned, ldigits)
case 'q': case 'q':
if 0 <= v && v <= unicode.MaxRune { if 0 <= v && v <= utf8.MaxRune {
p.fmt.fmt_qc(int64(v)) p.fmt.fmt_qc(int64(v))
} else { } else {
p.badVerb(verb) p.badVerb(verb)

View File

@ -5,15 +5,12 @@
package fmt package fmt
import ( import (
"bytes"
"errors" "errors"
"io" "io"
"math" "math"
"os" "os"
"reflect" "reflect"
"strconv" "strconv"
"strings"
"unicode"
"unicode/utf8" "unicode/utf8"
) )
@ -87,25 +84,36 @@ func Scanf(format string, a ...interface{}) (n int, err error) {
return Fscanf(os.Stdin, format, a...) return Fscanf(os.Stdin, format, a...)
} }
type stringReader string
func (r *stringReader) Read(b []byte) (n int, err error) {
n = copy(b, *r)
*r = (*r)[n:]
if n == 0 {
err = io.EOF
}
return
}
// Sscan scans the argument string, storing successive space-separated // Sscan scans the argument string, storing successive space-separated
// values into successive arguments. Newlines count as space. It // values into successive arguments. Newlines count as space. It
// returns the number of items successfully scanned. If that is less // returns the number of items successfully scanned. If that is less
// than the number of arguments, err will report why. // than the number of arguments, err will report why.
func Sscan(str string, a ...interface{}) (n int, err error) { func Sscan(str string, a ...interface{}) (n int, err error) {
return Fscan(strings.NewReader(str), a...) return Fscan((*stringReader)(&str), a...)
} }
// Sscanln is similar to Sscan, but stops scanning at a newline and // Sscanln is similar to Sscan, but stops scanning at a newline and
// after the final item there must be a newline or EOF. // after the final item there must be a newline or EOF.
func Sscanln(str string, a ...interface{}) (n int, err error) { func Sscanln(str string, a ...interface{}) (n int, err error) {
return Fscanln(strings.NewReader(str), a...) return Fscanln((*stringReader)(&str), a...)
} }
// Sscanf scans the argument string, storing successive space-separated // Sscanf scans the argument string, storing successive space-separated
// values into successive arguments as determined by the format. It // values into successive arguments as determined by the format. It
// returns the number of items successfully parsed. // returns the number of items successfully parsed.
func Sscanf(str string, format string, a ...interface{}) (n int, err error) { func Sscanf(str string, format string, a ...interface{}) (n int, err error) {
return Fscanf(strings.NewReader(str), format, a...) return Fscanf((*stringReader)(&str), format, a...)
} }
// Fscan scans text read from r, storing successive space-separated // Fscan scans text read from r, storing successive space-separated
@ -149,7 +157,7 @@ const eof = -1
// ss is the internal implementation of ScanState. // ss is the internal implementation of ScanState.
type ss struct { type ss struct {
rr io.RuneReader // where to read input rr io.RuneReader // where to read input
buf bytes.Buffer // token accumulator buf buffer // token accumulator
peekRune rune // one-rune lookahead peekRune rune // one-rune lookahead
prevRune rune // last rune returned by ReadRune prevRune rune // last rune returned by ReadRune
count int // runes consumed so far. count int // runes consumed so far.
@ -262,14 +270,46 @@ func (s *ss) Token(skipSpace bool, f func(rune) bool) (tok []byte, err error) {
if f == nil { if f == nil {
f = notSpace f = notSpace
} }
s.buf.Reset() s.buf = s.buf[:0]
tok = s.token(skipSpace, f) tok = s.token(skipSpace, f)
return return
} }
// space is a copy of the unicode.White_Space ranges,
// to avoid depending on package unicode.
var space = [][2]uint16{
{0x0009, 0x000d},
{0x0020, 0x0020},
{0x0085, 0x0085},
{0x00a0, 0x00a0},
{0x1680, 0x1680},
{0x180e, 0x180e},
{0x2000, 0x200a},
{0x2028, 0x2029},
{0x202f, 0x202f},
{0x205f, 0x205f},
{0x3000, 0x3000},
}
func isSpace(r rune) bool {
if r >= 1<<16 {
return false
}
rx := uint16(r)
for _, rng := range space {
if rx < rng[0] {
return false
}
if rx <= rng[1] {
return true
}
}
return false
}
// notSpace is the default scanning function used in Token. // notSpace is the default scanning function used in Token.
func notSpace(r rune) bool { func notSpace(r rune) bool {
return !unicode.IsSpace(r) return !isSpace(r)
} }
// skipSpace provides Scan() methods the ability to skip space and newline characters // skipSpace provides Scan() methods the ability to skip space and newline characters
@ -378,10 +418,10 @@ func (s *ss) free(old ssave) {
return return
} }
// Don't hold on to ss structs with large buffers. // Don't hold on to ss structs with large buffers.
if cap(s.buf.Bytes()) > 1024 { if cap(s.buf) > 1024 {
return return
} }
s.buf.Reset() s.buf = s.buf[:0]
s.rr = nil s.rr = nil
ssFree.put(s) ssFree.put(s)
} }
@ -403,7 +443,7 @@ func (s *ss) skipSpace(stopAtNewline bool) {
s.errorString("unexpected newline") s.errorString("unexpected newline")
return return
} }
if !unicode.IsSpace(r) { if !isSpace(r) {
s.UnreadRune() s.UnreadRune()
break break
} }
@ -429,7 +469,7 @@ func (s *ss) token(skipSpace bool, f func(rune) bool) []byte {
} }
s.buf.WriteRune(r) s.buf.WriteRune(r)
} }
return s.buf.Bytes() return s.buf
} }
// typeError indicates that the type of the operand did not match the format // typeError indicates that the type of the operand did not match the format
@ -440,6 +480,15 @@ func (s *ss) typeError(field interface{}, expected string) {
var complexError = errors.New("syntax error scanning complex number") var complexError = errors.New("syntax error scanning complex number")
var boolError = errors.New("syntax error scanning boolean") var boolError = errors.New("syntax error scanning boolean")
func indexRune(s string, r rune) int {
for i, c := range s {
if c == r {
return i
}
}
return -1
}
// consume reads the next rune in the input and reports whether it is in the ok string. // consume reads the next rune in the input and reports whether it is in the ok string.
// If accept is true, it puts the character into the input token. // If accept is true, it puts the character into the input token.
func (s *ss) consume(ok string, accept bool) bool { func (s *ss) consume(ok string, accept bool) bool {
@ -447,7 +496,7 @@ func (s *ss) consume(ok string, accept bool) bool {
if r == eof { if r == eof {
return false return false
} }
if strings.IndexRune(ok, r) >= 0 { if indexRune(ok, r) >= 0 {
if accept { if accept {
s.buf.WriteRune(r) s.buf.WriteRune(r)
} }
@ -465,7 +514,7 @@ func (s *ss) peek(ok string) bool {
if r != eof { if r != eof {
s.UnreadRune() s.UnreadRune()
} }
return strings.IndexRune(ok, r) >= 0 return indexRune(ok, r) >= 0
} }
func (s *ss) notEOF() { func (s *ss) notEOF() {
@ -560,7 +609,7 @@ func (s *ss) scanNumber(digits string, haveDigits bool) string {
} }
for s.accept(digits) { for s.accept(digits) {
} }
return s.buf.String() return string(s.buf)
} }
// scanRune returns the next rune value in the input. // scanRune returns the next rune value in the input.
@ -660,16 +709,16 @@ func (s *ss) scanUint(verb rune, bitSize int) uint64 {
// if the width is specified. It's not rigorous about syntax because it doesn't check that // if the width is specified. It's not rigorous about syntax because it doesn't check that
// we have at least some digits, but Atof will do that. // we have at least some digits, but Atof will do that.
func (s *ss) floatToken() string { func (s *ss) floatToken() string {
s.buf.Reset() s.buf = s.buf[:0]
// NaN? // NaN?
if s.accept("nN") && s.accept("aA") && s.accept("nN") { if s.accept("nN") && s.accept("aA") && s.accept("nN") {
return s.buf.String() return string(s.buf)
} }
// leading sign? // leading sign?
s.accept(sign) s.accept(sign)
// Inf? // Inf?
if s.accept("iI") && s.accept("nN") && s.accept("fF") { if s.accept("iI") && s.accept("nN") && s.accept("fF") {
return s.buf.String() return string(s.buf)
} }
// digits? // digits?
for s.accept(decimalDigits) { for s.accept(decimalDigits) {
@ -688,7 +737,7 @@ func (s *ss) floatToken() string {
for s.accept(decimalDigits) { for s.accept(decimalDigits) {
} }
} }
return s.buf.String() return string(s.buf)
} }
// complexTokens returns the real and imaginary parts of the complex number starting here. // complexTokens returns the real and imaginary parts of the complex number starting here.
@ -698,13 +747,13 @@ func (s *ss) complexTokens() (real, imag string) {
// TODO: accept N and Ni independently? // TODO: accept N and Ni independently?
parens := s.accept("(") parens := s.accept("(")
real = s.floatToken() real = s.floatToken()
s.buf.Reset() s.buf = s.buf[:0]
// Must now have a sign. // Must now have a sign.
if !s.accept("+-") { if !s.accept("+-") {
s.error(complexError) s.error(complexError)
} }
// Sign is now in buffer // Sign is now in buffer
imagSign := s.buf.String() imagSign := string(s.buf)
imag = s.floatToken() imag = s.floatToken()
if !s.accept("i") { if !s.accept("i") {
s.error(complexError) s.error(complexError)
@ -717,7 +766,7 @@ func (s *ss) complexTokens() (real, imag string) {
// convertFloat converts the string to a float64value. // convertFloat converts the string to a float64value.
func (s *ss) convertFloat(str string, n int) float64 { func (s *ss) convertFloat(str string, n int) float64 {
if p := strings.Index(str, "p"); p >= 0 { if p := indexRune(str, 'p'); p >= 0 {
// Atof doesn't handle power-of-2 exponents, // Atof doesn't handle power-of-2 exponents,
// but they're easy to evaluate. // but they're easy to evaluate.
f, err := strconv.ParseFloat(str[:p], n) f, err := strconv.ParseFloat(str[:p], n)
@ -794,7 +843,7 @@ func (s *ss) quotedString() string {
} }
s.buf.WriteRune(r) s.buf.WriteRune(r)
} }
return s.buf.String() return string(s.buf)
case '"': case '"':
// Double-quoted: Include the quotes and let strconv.Unquote do the backslash escapes. // Double-quoted: Include the quotes and let strconv.Unquote do the backslash escapes.
s.buf.WriteRune(quote) s.buf.WriteRune(quote)
@ -811,7 +860,7 @@ func (s *ss) quotedString() string {
break break
} }
} }
result, err := strconv.Unquote(s.buf.String()) result, err := strconv.Unquote(string(s.buf))
if err != nil { if err != nil {
s.error(err) s.error(err)
} }
@ -844,7 +893,7 @@ func (s *ss) hexByte() (b byte, ok bool) {
if rune1 == eof { if rune1 == eof {
return return
} }
if unicode.IsSpace(rune1) { if isSpace(rune1) {
s.UnreadRune() s.UnreadRune()
return return
} }
@ -862,11 +911,11 @@ func (s *ss) hexString() string {
} }
s.buf.WriteByte(b) s.buf.WriteByte(b)
} }
if s.buf.Len() == 0 { if len(s.buf) == 0 {
s.errorString("Scan: no hex data for %x string") s.errorString("Scan: no hex data for %x string")
return "" return ""
} }
return s.buf.String() return string(s.buf)
} }
const floatVerbs = "beEfFgGv" const floatVerbs = "beEfFgGv"
@ -875,7 +924,7 @@ const hugeWid = 1 << 30
// scanOne scans a single value, deriving the scanner from the type of the argument. // scanOne scans a single value, deriving the scanner from the type of the argument.
func (s *ss) scanOne(verb rune, field interface{}) { func (s *ss) scanOne(verb rune, field interface{}) {
s.buf.Reset() s.buf = s.buf[:0]
var err error var err error
// If the parameter has its own Scan method, use that. // If the parameter has its own Scan method, use that.
if v, ok := field.(Scanner); ok { if v, ok := field.(Scanner); ok {
@ -1004,7 +1053,7 @@ func (s *ss) doScan(a []interface{}) (numProcessed int, err error) {
if r == '\n' || r == eof { if r == '\n' || r == eof {
break break
} }
if !unicode.IsSpace(r) { if !isSpace(r) {
s.errorString("Scan: expected newline") s.errorString("Scan: expected newline")
break break
} }
@ -1032,7 +1081,7 @@ func (s *ss) advance(format string) (i int) {
i += w // skip the first % i += w // skip the first %
} }
sawSpace := false sawSpace := false
for unicode.IsSpace(fmtc) && i < len(format) { for isSpace(fmtc) && i < len(format) {
sawSpace = true sawSpace = true
i += w i += w
fmtc, w = utf8.DecodeRuneInString(format[i:]) fmtc, w = utf8.DecodeRuneInString(format[i:])
@ -1044,7 +1093,7 @@ func (s *ss) advance(format string) (i int) {
if inputc == eof { if inputc == eof {
return return
} }
if !unicode.IsSpace(inputc) { if !isSpace(inputc) {
// Space in format but not in input: error // Space in format but not in input: error
s.errorString("expected space in input to match format") s.errorString("expected space in input to match format")
} }

View File

@ -34,7 +34,7 @@ type Context struct {
CgoEnabled bool // whether cgo can be used CgoEnabled bool // whether cgo can be used
BuildTags []string // additional tags to recognize in +build lines BuildTags []string // additional tags to recognize in +build lines
UseAllFiles bool // use files regardless of +build lines, file names UseAllFiles bool // use files regardless of +build lines, file names
Gccgo bool // assume use of gccgo when computing object paths Compiler string // compiler to assume when computing target paths
// By default, Import uses the operating system's file system calls // By default, Import uses the operating system's file system calls
// to read directories and files. To read from other sources, // to read directories and files. To read from other sources,
@ -210,6 +210,7 @@ func (ctxt *Context) SrcDirs() []string {
// if set, or else the compiled code's GOARCH, GOOS, and GOROOT. // if set, or else the compiled code's GOARCH, GOOS, and GOROOT.
var Default Context = defaultContext() var Default Context = defaultContext()
// This list is also known to ../../../cmd/dist/build.c.
var cgoEnabled = map[string]bool{ var cgoEnabled = map[string]bool{
"darwin/386": true, "darwin/386": true,
"darwin/amd64": true, "darwin/amd64": true,
@ -228,6 +229,7 @@ func defaultContext() Context {
c.GOOS = envOr("GOOS", runtime.GOOS) c.GOOS = envOr("GOOS", runtime.GOOS)
c.GOROOT = runtime.GOROOT() c.GOROOT = runtime.GOROOT()
c.GOPATH = envOr("GOPATH", "") c.GOPATH = envOr("GOPATH", "")
c.Compiler = runtime.Compiler
switch os.Getenv("CGO_ENABLED") { switch os.Getenv("CGO_ENABLED") {
case "1": case "1":
@ -277,11 +279,12 @@ type Package struct {
PkgObj string // installed .a file PkgObj string // installed .a file
// Source files // Source files
GoFiles []string // .go source files (excluding CgoFiles, TestGoFiles, XTestGoFiles) GoFiles []string // .go source files (excluding CgoFiles, TestGoFiles, XTestGoFiles)
CgoFiles []string // .go source files that import "C" CgoFiles []string // .go source files that import "C"
CFiles []string // .c source files CFiles []string // .c source files
HFiles []string // .h source files HFiles []string // .h source files
SFiles []string // .s source files SFiles []string // .s source files
SysoFiles []string // .syso system object files to add to archive
// Cgo directives // Cgo directives
CgoPkgConfig []string // Cgo pkg-config directives CgoPkgConfig []string // Cgo pkg-config directives
@ -314,6 +317,16 @@ func (ctxt *Context) ImportDir(dir string, mode ImportMode) (*Package, error) {
return ctxt.Import(".", dir, mode) return ctxt.Import(".", dir, mode)
} }
// NoGoError is the error used by Import to describe a directory
// containing no Go source files.
type NoGoError struct {
Dir string
}
func (e *NoGoError) Error() string {
return "no Go source files in " + e.Dir
}
// Import returns details about the Go package named by the import path, // Import returns details about the Go package named by the import path,
// interpreting local import paths relative to the src directory. If the path // interpreting local import paths relative to the src directory. If the path
// is a local import path naming a package that can be imported using a // is a local import path naming a package that can be imported using a
@ -336,11 +349,16 @@ func (ctxt *Context) Import(path string, src string, mode ImportMode) (*Package,
} }
var pkga string var pkga string
if ctxt.Gccgo { var pkgerr error
switch ctxt.Compiler {
case "gccgo":
dir, elem := pathpkg.Split(p.ImportPath) dir, elem := pathpkg.Split(p.ImportPath)
pkga = "pkg/gccgo/" + dir + "lib" + elem + ".a" pkga = "pkg/gccgo/" + dir + "lib" + elem + ".a"
} else { case "gc":
pkga = "pkg/" + ctxt.GOOS + "_" + ctxt.GOARCH + "/" + p.ImportPath + ".a" pkga = "pkg/" + ctxt.GOOS + "_" + ctxt.GOARCH + "/" + p.ImportPath + ".a"
default:
// Save error for end of function.
pkgerr = fmt.Errorf("import %q: unknown compiler %q", path, ctxt.Compiler)
} }
binaryOnly := false binaryOnly := false
@ -396,7 +414,7 @@ func (ctxt *Context) Import(path string, src string, mode ImportMode) (*Package,
if ctxt.GOROOT != "" { if ctxt.GOROOT != "" {
dir := ctxt.joinPath(ctxt.GOROOT, "src", "pkg", path) dir := ctxt.joinPath(ctxt.GOROOT, "src", "pkg", path)
isDir := ctxt.isDir(dir) isDir := ctxt.isDir(dir)
binaryOnly = !isDir && mode&AllowBinary != 0 && ctxt.isFile(ctxt.joinPath(ctxt.GOROOT, pkga)) binaryOnly = !isDir && mode&AllowBinary != 0 && pkga != "" && ctxt.isFile(ctxt.joinPath(ctxt.GOROOT, pkga))
if isDir || binaryOnly { if isDir || binaryOnly {
p.Dir = dir p.Dir = dir
p.Goroot = true p.Goroot = true
@ -407,7 +425,7 @@ func (ctxt *Context) Import(path string, src string, mode ImportMode) (*Package,
for _, root := range ctxt.gopath() { for _, root := range ctxt.gopath() {
dir := ctxt.joinPath(root, "src", path) dir := ctxt.joinPath(root, "src", path)
isDir := ctxt.isDir(dir) isDir := ctxt.isDir(dir)
binaryOnly = !isDir && mode&AllowBinary != 0 && ctxt.isFile(ctxt.joinPath(root, pkga)) binaryOnly = !isDir && mode&AllowBinary != 0 && pkga != "" && ctxt.isFile(ctxt.joinPath(root, pkga))
if isDir || binaryOnly { if isDir || binaryOnly {
p.Dir = dir p.Dir = dir
p.Root = root p.Root = root
@ -426,14 +444,16 @@ Found:
} }
p.PkgRoot = ctxt.joinPath(p.Root, "pkg") p.PkgRoot = ctxt.joinPath(p.Root, "pkg")
p.BinDir = ctxt.joinPath(p.Root, "bin") p.BinDir = ctxt.joinPath(p.Root, "bin")
p.PkgObj = ctxt.joinPath(p.Root, pkga) if pkga != "" {
p.PkgObj = ctxt.joinPath(p.Root, pkga)
}
} }
if mode&FindOnly != 0 { if mode&FindOnly != 0 {
return p, nil return p, pkgerr
} }
if binaryOnly && (mode&AllowBinary) != 0 { if binaryOnly && (mode&AllowBinary) != 0 {
return p, nil return p, pkgerr
} }
dirs, err := ctxt.readDir(p.Dir) dirs, err := ctxt.readDir(p.Dir)
@ -467,7 +487,13 @@ Found:
ext := name[i:] ext := name[i:]
switch ext { switch ext {
case ".go", ".c", ".s", ".h", ".S": case ".go", ".c", ".s", ".h", ".S":
// tentatively okay // tentatively okay - read to make sure
case ".syso":
// binary objects to add to package archive
// Likely of the form foo_windows.syso, but
// the name was vetted above with goodOSArchFile.
p.SysoFiles = append(p.SysoFiles, name)
continue
default: default:
// skip // skip
continue continue
@ -586,7 +612,7 @@ Found:
} }
} }
if p.Name == "" { if p.Name == "" {
return p, fmt.Errorf("no Go source files in %s", p.Dir) return p, &NoGoError{p.Dir}
} }
p.Imports, p.ImportPos = cleanImports(imported) p.Imports, p.ImportPos = cleanImports(imported)
@ -601,7 +627,7 @@ Found:
sort.Strings(p.SFiles) sort.Strings(p.SFiles)
} }
return p, nil return p, pkgerr
} }
func cleanImports(m map[string][]token.Position) ([]string, map[string][]token.Position) { func cleanImports(m map[string][]token.Position) ([]string, map[string][]token.Position) {

View File

@ -0,0 +1,424 @@
// Copyright 2012 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
// This file exercises the import parser but also checks that
// some low-level packages do not have new dependencies added.
package build_test
import (
"go/build"
"sort"
"testing"
)
// pkgDeps defines the expected dependencies between packages in
// the Go source tree. It is a statement of policy.
// Changes should not be made to this map without prior discussion.
//
// The map contains two kinds of entries:
// 1) Lower-case keys are standard import paths and list the
// allowed imports in that package.
// 2) Upper-case keys define aliases for package sets, which can then
// be used as dependencies by other rules.
//
// DO NOT CHANGE THIS DATA TO FIX BUILDS.
//
var pkgDeps = map[string][]string{
// L0 is the lowest level, core, nearly unavoidable packages.
"errors": {},
"io": {"errors", "sync"},
"runtime": {"unsafe"},
"sync": {"sync/atomic"},
"sync/atomic": {"unsafe"},
"unsafe": {},
"L0": {
"errors",
"io",
"runtime",
"sync",
"sync/atomic",
"unsafe",
},
// L1 adds simple functions and strings processing,
// but not Unicode tables.
"math": {"unsafe"},
"math/cmplx": {"math"},
"math/rand": {"L0", "math"},
"sort": {"math"},
"strconv": {"L0", "unicode/utf8", "math"},
"unicode/utf16": {},
"unicode/utf8": {},
"L1": {
"L0",
"math",
"math/cmplx",
"math/rand",
"sort",
"strconv",
"unicode/utf16",
"unicode/utf8",
},
// L2 adds Unicode and strings processing.
"bufio": {"L0", "unicode/utf8", "bytes"},
"bytes": {"L0", "unicode", "unicode/utf8"},
"path": {"L0", "unicode/utf8", "strings"},
"strings": {"L0", "unicode", "unicode/utf8"},
"unicode": {},
"L2": {
"L1",
"bufio",
"bytes",
"path",
"strings",
"unicode",
},
// L3 adds reflection and some basic utility packages
// and interface definitions, but nothing that makes
// system calls.
"crypto": {"L2", "hash"}, // interfaces
"crypto/cipher": {"L2"}, // interfaces
"encoding/base32": {"L2"},
"encoding/base64": {"L2"},
"encoding/binary": {"L2", "reflect"},
"hash": {"L2"}, // interfaces
"hash/adler32": {"L2", "hash"},
"hash/crc32": {"L2", "hash"},
"hash/crc64": {"L2", "hash"},
"hash/fnv": {"L2", "hash"},
"image": {"L2", "image/color"}, // interfaces
"image/color": {"L2"}, // interfaces
"reflect": {"L2"},
"L3": {
"L2",
"crypto",
"crypto/cipher",
"encoding/base32",
"encoding/base64",
"encoding/binary",
"hash",
"hash/adler32",
"hash/crc32",
"hash/crc64",
"hash/fnv",
"image",
"image/color",
"reflect",
},
// End of linear dependency definitions.
// Operating system access.
"syscall": {"L0", "unicode/utf16"},
"time": {"L0", "syscall"},
"os": {"L1", "os", "syscall", "time"},
"path/filepath": {"L2", "os", "syscall"},
"io/ioutil": {"L2", "os", "path/filepath", "time"},
"os/exec": {"L2", "os", "syscall"},
"os/signal": {"L2", "os", "syscall"},
// OS enables basic operating system functionality,
// but not direct use of package syscall, nor os/signal.
"OS": {
"io/ioutil",
"os",
"os/exec",
"path/filepath",
"time",
},
// Formatted I/O: few dependencies (L1) but we must add reflect.
"fmt": {"L1", "os", "reflect"},
"log": {"L1", "os", "fmt", "time"},
// Packages used by testing must be low-level (L2+fmt).
"regexp": {"L2", "regexp/syntax"},
"regexp/syntax": {"L2"},
"runtime/debug": {"L2", "fmt", "io/ioutil", "os"},
"runtime/pprof": {"L2", "fmt", "text/tabwriter"},
"text/tabwriter": {"L2"},
"testing": {"L2", "flag", "fmt", "os", "runtime/pprof", "time"},
"testing/iotest": {"L2", "log"},
"testing/quick": {"L2", "flag", "fmt", "reflect"},
// L4 is defined as L3+fmt+log+time, because in general once
// you're using L3 packages, use of fmt, log, or time is not a big deal.
"L4": {
"L3",
"fmt",
"log",
"time",
},
// Go parser.
"go/ast": {"L4", "OS", "go/scanner", "go/token"},
"go/doc": {"L4", "go/ast", "go/token", "regexp", "text/template"},
"go/parser": {"L4", "OS", "go/ast", "go/scanner", "go/token"},
"go/printer": {"L4", "OS", "go/ast", "go/scanner", "go/token", "text/tabwriter"},
"go/scanner": {"L4", "OS", "go/token"},
"go/token": {"L4"},
"GOPARSER": {
"go/ast",
"go/doc",
"go/parser",
"go/printer",
"go/scanner",
"go/token",
},
// One of a kind.
"archive/tar": {"L4", "OS"},
"archive/zip": {"L4", "OS", "compress/flate"},
"compress/bzip2": {"L4"},
"compress/flate": {"L4"},
"compress/gzip": {"L4", "compress/flate"},
"compress/lzw": {"L4"},
"compress/zlib": {"L4", "compress/flate"},
"database/sql": {"L4", "database/sql/driver"},
"database/sql/driver": {"L4", "time"},
"debug/dwarf": {"L4"},
"debug/elf": {"L4", "OS", "debug/dwarf"},
"debug/gosym": {"L4"},
"debug/macho": {"L4", "OS", "debug/dwarf"},
"debug/pe": {"L4", "OS", "debug/dwarf"},
"encoding/ascii85": {"L4"},
"encoding/asn1": {"L4", "math/big"},
"encoding/csv": {"L4"},
"encoding/gob": {"L4", "OS"},
"encoding/hex": {"L4"},
"encoding/json": {"L4"},
"encoding/pem": {"L4"},
"encoding/xml": {"L4"},
"flag": {"L4", "OS"},
"go/build": {"L4", "OS", "GOPARSER"},
"html": {"L4"},
"image/draw": {"L4"},
"image/gif": {"L4", "compress/lzw"},
"image/jpeg": {"L4"},
"image/png": {"L4", "compress/zlib"},
"index/suffixarray": {"L4", "regexp"},
"math/big": {"L4"},
"mime": {"L4", "OS", "syscall"},
"net/url": {"L4"},
"text/scanner": {"L4", "OS"},
"text/template/parse": {"L4"},
"html/template": {
"L4", "OS", "encoding/json", "html", "text/template",
"text/template/parse",
},
"text/template": {
"L4", "OS", "net/url", "text/template/parse",
},
// Cgo.
"runtime/cgo": {"L0", "C"},
"CGO": {"C", "runtime/cgo"},
// Fake entry to satisfy the pseudo-import "C"
// that shows up in programs that use cgo.
"C": {},
"os/user": {"L4", "CGO", "syscall"},
// Basic networking.
// Because net must be used by any package that wants to
// do networking portably, it must have a small dependency set: just L1+basic os.
"net": {"L1", "CGO", "os", "syscall", "time"},
// NET enables use of basic network-related packages.
"NET": {
"net",
"mime",
"net/textproto",
"net/url",
},
// Uses of networking.
"log/syslog": {"L4", "OS", "net"},
"net/mail": {"L4", "NET", "OS"},
"net/textproto": {"L4", "OS", "net"},
// Core crypto.
"crypto/aes": {"L3"},
"crypto/des": {"L3"},
"crypto/hmac": {"L3"},
"crypto/md5": {"L3"},
"crypto/rc4": {"L3"},
"crypto/sha1": {"L3"},
"crypto/sha256": {"L3"},
"crypto/sha512": {"L3"},
"crypto/subtle": {"L3"},
"CRYPTO": {
"crypto/aes",
"crypto/des",
"crypto/hmac",
"crypto/md5",
"crypto/rc4",
"crypto/sha1",
"crypto/sha256",
"crypto/sha512",
"crypto/subtle",
},
// Random byte, number generation.
// This would be part of core crypto except that it imports
// math/big, which imports fmt.
"crypto/rand": {"L4", "CRYPTO", "OS", "math/big", "syscall"},
// Mathematical crypto: dependencies on fmt (L4) and math/big.
// We could avoid some of the fmt, but math/big imports fmt anyway.
"crypto/dsa": {"L4", "CRYPTO", "math/big"},
"crypto/ecdsa": {"L4", "CRYPTO", "crypto/elliptic", "math/big"},
"crypto/elliptic": {"L4", "CRYPTO", "math/big"},
"crypto/rsa": {"L4", "CRYPTO", "crypto/rand", "math/big"},
"CRYPTO-MATH": {
"CRYPTO",
"crypto/dsa",
"crypto/ecdsa",
"crypto/elliptic",
"crypto/rand",
"crypto/rsa",
"encoding/asn1",
"math/big",
},
// SSL/TLS.
"crypto/tls": {
"L4", "CRYPTO-MATH", "CGO", "OS",
"crypto/x509", "encoding/pem", "net", "syscall",
},
"crypto/x509": {"L4", "CRYPTO-MATH", "OS", "CGO", "crypto/x509/pkix", "encoding/pem", "syscall"},
"crypto/x509/pkix": {"L4", "CRYPTO-MATH"},
// Simple net+crypto-aware packages.
"mime/multipart": {"L4", "OS", "mime", "crypto/rand", "net/textproto"},
"net/smtp": {"L4", "CRYPTO", "NET", "crypto/tls"},
// HTTP, kingpin of dependencies.
"net/http": {
"L4", "NET", "OS",
"compress/gzip", "crypto/tls", "mime/multipart", "runtime/debug",
},
// HTTP-using packages.
"expvar": {"L4", "OS", "encoding/json", "net/http"},
"net/http/cgi": {"L4", "NET", "OS", "crypto/tls", "net/http", "regexp"},
"net/http/fcgi": {"L4", "NET", "OS", "net/http", "net/http/cgi"},
"net/http/httptest": {"L4", "NET", "OS", "crypto/tls", "flag", "net/http"},
"net/http/httputil": {"L4", "NET", "OS", "net/http"},
"net/http/pprof": {"L4", "OS", "html/template", "net/http", "runtime/pprof"},
"net/rpc": {"L4", "NET", "encoding/gob", "net/http", "text/template"},
"net/rpc/jsonrpc": {"L4", "NET", "encoding/json", "net/rpc"},
}
// isMacro reports whether p is a package dependency macro
// (uppercase name).
func isMacro(p string) bool {
return 'A' <= p[0] && p[0] <= 'Z'
}
func allowed(pkg string) map[string]bool {
m := map[string]bool{}
var allow func(string)
allow = func(p string) {
if m[p] {
return
}
m[p] = true // set even for macros, to avoid loop on cycle
// Upper-case names are macro-expanded.
if isMacro(p) {
for _, pp := range pkgDeps[p] {
allow(pp)
}
}
}
for _, pp := range pkgDeps[pkg] {
allow(pp)
}
return m
}
var bools = []bool{false, true}
var geese = []string{"darwin", "freebsd", "linux", "netbsd", "openbsd", "plan9", "windows"}
var goarches = []string{"386", "amd64", "arm"}
type osPkg struct {
goos, pkg string
}
// allowedErrors are the operating systems and packages known to contain errors
// (currently just "no Go source files")
var allowedErrors = map[osPkg]bool{
osPkg{"windows", "log/syslog"}: true,
osPkg{"plan9", "log/syslog"}: true,
}
func TestDependencies(t *testing.T) {
var all []string
for k := range pkgDeps {
all = append(all, k)
}
sort.Strings(all)
ctxt := build.Default
test := func(mustImport bool) {
for _, pkg := range all {
if isMacro(pkg) {
continue
}
p, err := ctxt.Import(pkg, "", 0)
if err != nil {
if allowedErrors[osPkg{ctxt.GOOS, pkg}] {
continue
}
// Some of the combinations we try might not
// be reasonable (like arm,plan9,cgo), so ignore
// errors for the auto-generated combinations.
if !mustImport {
continue
}
t.Errorf("%s/%s/cgo=%v %v", ctxt.GOOS, ctxt.GOARCH, ctxt.CgoEnabled, err)
continue
}
ok := allowed(pkg)
var bad []string
for _, imp := range p.Imports {
if !ok[imp] {
bad = append(bad, imp)
}
}
if bad != nil {
t.Errorf("%s/%s/cgo=%v unexpected dependency: %s imports %v", ctxt.GOOS, ctxt.GOARCH, ctxt.CgoEnabled, pkg, bad)
}
}
}
test(true)
if testing.Short() {
t.Logf("skipping other systems")
return
}
for _, ctxt.GOOS = range geese {
for _, ctxt.GOARCH = range goarches {
for _, ctxt.CgoEnabled = range bools {
test(false)
}
}
}
}

View File

@ -0,0 +1,166 @@
// Copyright 2012 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
// This file implements a parser test harness. The files in the testdata
// directory are parsed and the errors reported are compared against the
// error messages expected in the test files. The test files must end in
// .src rather than .go so that they are not disturbed by gofmt runs.
//
// Expected errors are indicated in the test files by putting a comment
// of the form /* ERROR "rx" */ immediately following an offending token.
// The harness will verify that an error matching the regular expression
// rx is reported at that source position.
//
// For instance, the following test file indicates that a "not declared"
// error should be reported for the undeclared variable x:
//
// package p
// func f() {
// _ = x /* ERROR "not declared" */ + 1
// }
package parser
import (
"go/scanner"
"go/token"
"io/ioutil"
"path/filepath"
"regexp"
"strings"
"testing"
)
const testdata = "testdata"
// getFile assumes that each filename occurs at most once
func getFile(filename string) (file *token.File) {
fset.Iterate(func(f *token.File) bool {
if f.Name() == filename {
if file != nil {
panic(filename + " used multiple times")
}
file = f
}
return true
})
return file
}
func getPos(filename string, offset int) token.Pos {
if f := getFile(filename); f != nil {
return f.Pos(offset)
}
return token.NoPos
}
// ERROR comments must be of the form /* ERROR "rx" */ and rx is
// a regular expression that matches the expected error message.
//
var errRx = regexp.MustCompile(`^/\* *ERROR *"([^"]*)" *\*/$`)
// expectedErrors collects the regular expressions of ERROR comments found
// in files and returns them as a map of error positions to error messages.
//
func expectedErrors(t *testing.T, filename string, src []byte) map[token.Pos]string {
errors := make(map[token.Pos]string)
var s scanner.Scanner
// file was parsed already - do not add it again to the file
// set otherwise the position information returned here will
// not match the position information collected by the parser
s.Init(getFile(filename), src, nil, scanner.ScanComments)
var prev token.Pos // position of last non-comment, non-semicolon token
for {
pos, tok, lit := s.Scan()
switch tok {
case token.EOF:
return errors
case token.COMMENT:
s := errRx.FindStringSubmatch(lit)
if len(s) == 2 {
errors[prev] = string(s[1])
}
default:
prev = pos
}
}
panic("unreachable")
}
// compareErrors compares the map of expected error messages with the list
// of found errors and reports discrepancies.
//
func compareErrors(t *testing.T, expected map[token.Pos]string, found scanner.ErrorList) {
for _, error := range found {
// error.Pos is a token.Position, but we want
// a token.Pos so we can do a map lookup
pos := getPos(error.Pos.Filename, error.Pos.Offset)
if msg, found := expected[pos]; found {
// we expect a message at pos; check if it matches
rx, err := regexp.Compile(msg)
if err != nil {
t.Errorf("%s: %v", error.Pos, err)
continue
}
if match := rx.MatchString(error.Msg); !match {
t.Errorf("%s: %q does not match %q", error.Pos, error.Msg, msg)
continue
}
// we have a match - eliminate this error
delete(expected, pos)
} else {
// To keep in mind when analyzing failed test output:
// If the same error position occurs multiple times in errors,
// this message will be triggered (because the first error at
// the position removes this position from the expected errors).
t.Errorf("%s: unexpected error: %s", error.Pos, error.Msg)
}
}
// there should be no expected errors left
if len(expected) > 0 {
t.Errorf("%d errors not reported:", len(expected))
for pos, msg := range expected {
t.Errorf("%s: %s\n", fset.Position(pos), msg)
}
}
}
func checkErrors(t *testing.T, filename string, input interface{}) {
src, err := readSource(filename, input)
if err != nil {
t.Error(err)
return
}
_, err = ParseFile(fset, filename, src, DeclarationErrors)
found, ok := err.(scanner.ErrorList)
if err != nil && !ok {
t.Error(err)
return
}
// we are expecting the following errors
// (collect these after parsing a file so that it is found in the file set)
expected := expectedErrors(t, filename, src)
// verify errors returned by the parser
compareErrors(t, expected, found)
}
func TestErrors(t *testing.T) {
list, err := ioutil.ReadDir(testdata)
if err != nil {
t.Fatal(err)
}
for _, fi := range list {
name := fi.Name()
if !fi.IsDir() && !strings.HasPrefix(name, ".") && strings.HasSuffix(name, ".src") {
checkErrors(t, filepath.Join(testdata, name), nil)
}
}
}

View File

@ -40,6 +40,13 @@ type parser struct {
tok token.Token // one token look-ahead tok token.Token // one token look-ahead
lit string // token literal lit string // token literal
// Error recovery
// (used to limit the number of calls to syncXXX functions
// w/o making scanning progress - avoids potential endless
// loops across multiple parser functions during error recovery)
syncPos token.Pos // last synchronization position
syncCnt int // number of calls to syncXXX without progress
// Non-syntactic parser control // Non-syntactic parser control
exprLev int // < 0: in control clause, >= 0: in expression exprLev int // < 0: in control clause, >= 0: in expression
@ -362,26 +369,106 @@ func (p *parser) expect(tok token.Token) token.Pos {
// expectClosing is like expect but provides a better error message // expectClosing is like expect but provides a better error message
// for the common case of a missing comma before a newline. // for the common case of a missing comma before a newline.
// //
func (p *parser) expectClosing(tok token.Token, construct string) token.Pos { func (p *parser) expectClosing(tok token.Token, context string) token.Pos {
if p.tok != tok && p.tok == token.SEMICOLON && p.lit == "\n" { if p.tok != tok && p.tok == token.SEMICOLON && p.lit == "\n" {
p.error(p.pos, "missing ',' before newline in "+construct) p.error(p.pos, "missing ',' before newline in "+context)
p.next() p.next()
} }
return p.expect(tok) return p.expect(tok)
} }
func (p *parser) expectSemi() { func (p *parser) expectSemi() {
// semicolon is optional before a closing ')' or '}'
if p.tok != token.RPAREN && p.tok != token.RBRACE { if p.tok != token.RPAREN && p.tok != token.RBRACE {
p.expect(token.SEMICOLON) if p.tok == token.SEMICOLON {
p.next()
} else {
p.errorExpected(p.pos, "';'")
syncStmt(p)
}
} }
} }
func (p *parser) atComma(context string) bool {
if p.tok == token.COMMA {
return true
}
if p.tok == token.SEMICOLON && p.lit == "\n" {
p.error(p.pos, "missing ',' before newline in "+context)
return true // "insert" the comma and continue
}
return false
}
func assert(cond bool, msg string) { func assert(cond bool, msg string) {
if !cond { if !cond {
panic("go/parser internal error: " + msg) panic("go/parser internal error: " + msg)
} }
} }
// syncStmt advances to the next statement.
// Used for synchronization after an error.
//
func syncStmt(p *parser) {
for {
switch p.tok {
case token.BREAK, token.CONST, token.CONTINUE, token.DEFER,
token.FALLTHROUGH, token.FOR, token.GO, token.GOTO,
token.IF, token.RETURN, token.SELECT, token.SWITCH,
token.TYPE, token.VAR:
// Return only if parser made some progress since last
// sync or if it has not reached 10 sync calls without
// progress. Otherwise consume at least one token to
// avoid an endless parser loop (it is possible that
// both parseOperand and parseStmt call syncStmt and
// correctly do not advance, thus the need for the
// invocation limit p.syncCnt).
if p.pos == p.syncPos && p.syncCnt < 10 {
p.syncCnt++
return
}
if p.pos > p.syncPos {
p.syncPos = p.pos
p.syncCnt = 0
return
}
// Reaching here indicates a parser bug, likely an
// incorrect token list in this function, but it only
// leads to skipping of possibly correct code if a
// previous error is present, and thus is preferred
// over a non-terminating parse.
case token.EOF:
return
}
p.next()
}
}
// syncDecl advances to the next declaration.
// Used for synchronization after an error.
//
func syncDecl(p *parser) {
for {
switch p.tok {
case token.CONST, token.TYPE, token.VAR:
// see comments in syncStmt
if p.pos == p.syncPos && p.syncCnt < 10 {
p.syncCnt++
return
}
if p.pos > p.syncPos {
p.syncPos = p.pos
p.syncCnt = 0
return
}
case token.EOF:
return
}
p.next()
}
}
// ---------------------------------------------------------------------------- // ----------------------------------------------------------------------------
// Identifiers // Identifiers
@ -522,9 +609,11 @@ func (p *parser) makeIdentList(list []ast.Expr) []*ast.Ident {
for i, x := range list { for i, x := range list {
ident, isIdent := x.(*ast.Ident) ident, isIdent := x.(*ast.Ident)
if !isIdent { if !isIdent {
pos := x.Pos() if _, isBad := x.(*ast.BadExpr); !isBad {
p.errorExpected(pos, "identifier") // only report error if it's a new one
ident = &ast.Ident{NamePos: pos, Name: "_"} p.errorExpected(x.Pos(), "identifier")
}
ident = &ast.Ident{NamePos: x.Pos(), Name: "_"}
} }
idents[i] = ident idents[i] = ident
} }
@ -688,7 +777,7 @@ func (p *parser) parseParameterList(scope *ast.Scope, ellipsisOk bool) (params [
// Go spec: The scope of an identifier denoting a function // Go spec: The scope of an identifier denoting a function
// parameter or result variable is the function body. // parameter or result variable is the function body.
p.declare(field, nil, scope, ast.Var, idents...) p.declare(field, nil, scope, ast.Var, idents...)
if p.tok != token.COMMA { if !p.atComma("parameter list") {
break break
} }
p.next() p.next()
@ -991,19 +1080,19 @@ func (p *parser) parseOperand(lhs bool) ast.Expr {
case token.FUNC: case token.FUNC:
return p.parseFuncTypeOrLit() return p.parseFuncTypeOrLit()
default:
if typ := p.tryIdentOrType(true); typ != nil {
// could be type for composite literal or conversion
_, isIdent := typ.(*ast.Ident)
assert(!isIdent, "type cannot be identifier")
return typ
}
} }
if typ := p.tryIdentOrType(true); typ != nil {
// could be type for composite literal or conversion
_, isIdent := typ.(*ast.Ident)
assert(!isIdent, "type cannot be identifier")
return typ
}
// we have an error
pos := p.pos pos := p.pos
p.errorExpected(pos, "operand") p.errorExpected(pos, "operand")
p.next() // make progress syncStmt(p)
return &ast.BadExpr{From: pos, To: p.pos} return &ast.BadExpr{From: pos, To: p.pos}
} }
@ -1078,7 +1167,7 @@ func (p *parser) parseCallOrConversion(fun ast.Expr) *ast.CallExpr {
ellipsis = p.pos ellipsis = p.pos
p.next() p.next()
} }
if p.tok != token.COMMA { if !p.atComma("argument list") {
break break
} }
p.next() p.next()
@ -1118,7 +1207,7 @@ func (p *parser) parseElementList() (list []ast.Expr) {
for p.tok != token.RBRACE && p.tok != token.EOF { for p.tok != token.RBRACE && p.tok != token.EOF {
list = append(list, p.parseElement(true)) list = append(list, p.parseElement(true))
if p.tok != token.COMMA { if !p.atComma("composite literal") {
break break
} }
p.next() p.next()
@ -1262,8 +1351,8 @@ L:
x = p.parseTypeAssertion(p.checkExpr(x)) x = p.parseTypeAssertion(p.checkExpr(x))
default: default:
pos := p.pos pos := p.pos
p.next() // make progress
p.errorExpected(pos, "selector or type assertion") p.errorExpected(pos, "selector or type assertion")
p.next() // make progress
x = &ast.BadExpr{From: pos, To: p.pos} x = &ast.BadExpr{From: pos, To: p.pos}
} }
case token.LBRACK: case token.LBRACK:
@ -1471,7 +1560,10 @@ func (p *parser) parseCallExpr() *ast.CallExpr {
if call, isCall := x.(*ast.CallExpr); isCall { if call, isCall := x.(*ast.CallExpr); isCall {
return call return call
} }
p.errorExpected(x.Pos(), "function/method call") if _, isBad := x.(*ast.BadExpr); !isBad {
// only report error if it's a new one
p.errorExpected(x.Pos(), "function/method call")
}
return nil return nil
} }
@ -1862,7 +1954,7 @@ func (p *parser) parseStmt() (s ast.Stmt) {
switch p.tok { switch p.tok {
case token.CONST, token.TYPE, token.VAR: case token.CONST, token.TYPE, token.VAR:
s = &ast.DeclStmt{Decl: p.parseDecl()} s = &ast.DeclStmt{Decl: p.parseDecl(syncStmt)}
case case
// tokens that may start an expression // tokens that may start an expression
token.IDENT, token.INT, token.FLOAT, token.IMAG, token.CHAR, token.STRING, token.FUNC, token.LPAREN, // operands token.IDENT, token.INT, token.FLOAT, token.IMAG, token.CHAR, token.STRING, token.FUNC, token.LPAREN, // operands
@ -1904,7 +1996,7 @@ func (p *parser) parseStmt() (s ast.Stmt) {
// no statement found // no statement found
pos := p.pos pos := p.pos
p.errorExpected(pos, "statement") p.errorExpected(pos, "statement")
p.next() // make progress syncStmt(p)
s = &ast.BadStmt{From: pos, To: p.pos} s = &ast.BadStmt{From: pos, To: p.pos}
} }
@ -2095,8 +2187,13 @@ func (p *parser) parseReceiver(scope *ast.Scope) *ast.FieldList {
recv := par.List[0] recv := par.List[0]
base := deref(recv.Type) base := deref(recv.Type)
if _, isIdent := base.(*ast.Ident); !isIdent { if _, isIdent := base.(*ast.Ident); !isIdent {
p.errorExpected(base.Pos(), "(unqualified) identifier") if _, isBad := base.(*ast.BadExpr); !isBad {
par.List = []*ast.Field{{Type: &ast.BadExpr{From: recv.Pos(), To: recv.End()}}} // only report error if it's a new one
p.errorExpected(base.Pos(), "(unqualified) identifier")
}
par.List = []*ast.Field{
{Type: &ast.BadExpr{From: recv.Pos(), To: recv.End()}},
}
} }
return par return par
@ -2152,7 +2249,7 @@ func (p *parser) parseFuncDecl() *ast.FuncDecl {
return decl return decl
} }
func (p *parser) parseDecl() ast.Decl { func (p *parser) parseDecl(sync func(*parser)) ast.Decl {
if p.trace { if p.trace {
defer un(trace(p, "Declaration")) defer un(trace(p, "Declaration"))
} }
@ -2174,9 +2271,8 @@ func (p *parser) parseDecl() ast.Decl {
default: default:
pos := p.pos pos := p.pos
p.errorExpected(pos, "declaration") p.errorExpected(pos, "declaration")
p.next() // make progress sync(p)
decl := &ast.BadDecl{From: pos, To: p.pos} return &ast.BadDecl{From: pos, To: p.pos}
return decl
} }
return p.parseGenDecl(p.tok, f) return p.parseGenDecl(p.tok, f)
@ -2215,7 +2311,7 @@ func (p *parser) parseFile() *ast.File {
if p.mode&ImportsOnly == 0 { if p.mode&ImportsOnly == 0 {
// rest of package body // rest of package body
for p.tok != token.EOF { for p.tok != token.EOF {
decls = append(decls, p.parseDecl()) decls = append(decls, p.parseDecl(syncDecl))
} }
} }
} }

View File

@ -14,87 +14,14 @@ import (
var fset = token.NewFileSet() var fset = token.NewFileSet()
var illegalInputs = []interface{}{
nil,
3.14,
[]byte(nil),
"foo!",
`package p; func f() { if /* should have condition */ {} };`,
`package p; func f() { if ; /* should have condition */ {} };`,
`package p; func f() { if f(); /* should have condition */ {} };`,
`package p; const c; /* should have constant value */`,
`package p; func f() { if _ = range x; true {} };`,
`package p; func f() { switch _ = range x; true {} };`,
`package p; func f() { for _ = range x ; ; {} };`,
`package p; func f() { for ; ; _ = range x {} };`,
`package p; func f() { for ; _ = range x ; {} };`,
`package p; func f() { switch t = t.(type) {} };`,
`package p; func f() { switch t, t = t.(type) {} };`,
`package p; func f() { switch t = t.(type), t {} };`,
`package p; var a = [1]int; /* illegal expression */`,
`package p; var a = [...]int; /* illegal expression */`,
`package p; var a = struct{} /* illegal expression */`,
`package p; var a = func(); /* illegal expression */`,
`package p; var a = interface{} /* illegal expression */`,
`package p; var a = []int /* illegal expression */`,
`package p; var a = map[int]int /* illegal expression */`,
`package p; var a = chan int; /* illegal expression */`,
`package p; var a = []int{[]int}; /* illegal expression */`,
`package p; var a = ([]int); /* illegal expression */`,
`package p; var a = a[[]int:[]int]; /* illegal expression */`,
`package p; var a = <- chan int; /* illegal expression */`,
`package p; func f() { select { case _ <- chan int: } };`,
}
func TestParseIllegalInputs(t *testing.T) {
for _, src := range illegalInputs {
_, err := ParseFile(fset, "", src, 0)
if err == nil {
t.Errorf("ParseFile(%v) should have failed", src)
}
}
}
var validPrograms = []string{
"package p\n",
`package p;`,
`package p; import "fmt"; func f() { fmt.Println("Hello, World!") };`,
`package p; func f() { if f(T{}) {} };`,
`package p; func f() { _ = (<-chan int)(x) };`,
`package p; func f() { _ = (<-chan <-chan int)(x) };`,
`package p; func f(func() func() func());`,
`package p; func f(...T);`,
`package p; func f(float, ...int);`,
`package p; func f(x int, a ...int) { f(0, a...); f(1, a...,) };`,
`package p; func f(int,) {};`,
`package p; func f(...int,) {};`,
`package p; func f(x ...int,) {};`,
`package p; type T []int; var a []bool; func f() { if a[T{42}[0]] {} };`,
`package p; type T []int; func g(int) bool { return true }; func f() { if g(T{42}[0]) {} };`,
`package p; type T []int; func f() { for _ = range []int{T{42}[0]} {} };`,
`package p; var a = T{{1, 2}, {3, 4}}`,
`package p; func f() { select { case <- c: case c <- d: case c <- <- d: case <-c <- d: } };`,
`package p; func f() { select { case x := (<-c): } };`,
`package p; func f() { if ; true {} };`,
`package p; func f() { switch ; {} };`,
`package p; func f() { for _ = range "foo" + "bar" {} };`,
}
func TestParseValidPrograms(t *testing.T) {
for _, src := range validPrograms {
_, err := ParseFile(fset, "", src, SpuriousErrors)
if err != nil {
t.Errorf("ParseFile(%q): %v", src, err)
}
}
}
var validFiles = []string{ var validFiles = []string{
"parser.go", "parser.go",
"parser_test.go", "parser_test.go",
"error_test.go",
"short_test.go",
} }
func TestParse3(t *testing.T) { func TestParse(t *testing.T) {
for _, filename := range validFiles { for _, filename := range validFiles {
_, err := ParseFile(fset, filename, nil, DeclarationErrors) _, err := ParseFile(fset, filename, nil, DeclarationErrors)
if err != nil { if err != nil {
@ -116,7 +43,7 @@ func nameFilter(filename string) bool {
func dirFilter(f os.FileInfo) bool { return nameFilter(f.Name()) } func dirFilter(f os.FileInfo) bool { return nameFilter(f.Name()) }
func TestParse4(t *testing.T) { func TestParseDir(t *testing.T) {
path := "." path := "."
pkgs, err := ParseDir(fset, path, dirFilter, 0) pkgs, err := ParseDir(fset, path, dirFilter, 0)
if err != nil { if err != nil {
@ -158,7 +85,7 @@ func TestParseExpr(t *testing.T) {
} }
// it must not crash // it must not crash
for _, src := range validPrograms { for _, src := range valids {
ParseExpr(src) ParseExpr(src)
} }
} }

View File

@ -0,0 +1,75 @@
// Copyright 2009 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
// This file contains test cases for short valid and invalid programs.
package parser
import "testing"
var valids = []string{
"package p\n",
`package p;`,
`package p; import "fmt"; func f() { fmt.Println("Hello, World!") };`,
`package p; func f() { if f(T{}) {} };`,
`package p; func f() { _ = (<-chan int)(x) };`,
`package p; func f() { _ = (<-chan <-chan int)(x) };`,
`package p; func f(func() func() func());`,
`package p; func f(...T);`,
`package p; func f(float, ...int);`,
`package p; func f(x int, a ...int) { f(0, a...); f(1, a...,) };`,
`package p; func f(int,) {};`,
`package p; func f(...int,) {};`,
`package p; func f(x ...int,) {};`,
`package p; type T []int; var a []bool; func f() { if a[T{42}[0]] {} };`,
`package p; type T []int; func g(int) bool { return true }; func f() { if g(T{42}[0]) {} };`,
`package p; type T []int; func f() { for _ = range []int{T{42}[0]} {} };`,
`package p; var a = T{{1, 2}, {3, 4}}`,
`package p; func f() { select { case <- c: case c <- d: case c <- <- d: case <-c <- d: } };`,
`package p; func f() { select { case x := (<-c): } };`,
`package p; func f() { if ; true {} };`,
`package p; func f() { switch ; {} };`,
`package p; func f() { for _ = range "foo" + "bar" {} };`,
}
func TestValid(t *testing.T) {
for _, src := range valids {
checkErrors(t, src, src)
}
}
var invalids = []string{
`foo /* ERROR "expected 'package'" */ !`,
`package p; func f() { if { /* ERROR "expected operand" */ } };`,
`package p; func f() { if ; { /* ERROR "expected operand" */ } };`,
`package p; func f() { if f(); { /* ERROR "expected operand" */ } };`,
`package p; const c; /* ERROR "expected '='" */`,
`package p; func f() { if _ /* ERROR "expected condition" */ = range x; true {} };`,
`package p; func f() { switch _ /* ERROR "expected condition" */ = range x; true {} };`,
`package p; func f() { for _ = range x ; /* ERROR "expected '{'" */ ; {} };`,
`package p; func f() { for ; ; _ = range /* ERROR "expected operand" */ x {} };`,
`package p; func f() { for ; _ /* ERROR "expected condition" */ = range x ; {} };`,
`package p; func f() { switch t /* ERROR "expected condition" */ = t.(type) {} };`,
`package p; func f() { switch t /* ERROR "expected condition" */ , t = t.(type) {} };`,
`package p; func f() { switch t /* ERROR "expected condition" */ = t.(type), t {} };`,
`package p; var a = [ /* ERROR "expected expression" */ 1]int;`,
`package p; var a = [ /* ERROR "expected expression" */ ...]int;`,
`package p; var a = struct /* ERROR "expected expression" */ {}`,
`package p; var a = func /* ERROR "expected expression" */ ();`,
`package p; var a = interface /* ERROR "expected expression" */ {}`,
`package p; var a = [ /* ERROR "expected expression" */ ]int`,
`package p; var a = map /* ERROR "expected expression" */ [int]int`,
`package p; var a = chan /* ERROR "expected expression" */ int;`,
`package p; var a = []int{[ /* ERROR "expected expression" */ ]int};`,
`package p; var a = ( /* ERROR "expected expression" */ []int);`,
`package p; var a = a[[ /* ERROR "expected expression" */ ]int:[]int];`,
`package p; var a = <- /* ERROR "expected expression" */ chan int;`,
`package p; func f() { select { case _ <- chan /* ERROR "expected expression" */ int: } };`,
}
func TestInvalid(t *testing.T) {
for _, src := range invalids {
checkErrors(t, src, src)
}
}

19
libgo/go/go/parser/testdata/commas.src vendored Normal file
View File

@ -0,0 +1,19 @@
// Copyright 2012 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
// Test case for error messages/parser synchronization
// after missing commas.
package p
var _ = []int{
0 /* ERROR "missing ','" */
}
var _ = []int{
0,
1,
2,
3 /* ERROR "missing ','" */
}

View File

@ -0,0 +1,46 @@
// Copyright 2012 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
// Test case for issue 3106: Better synchronization of
// parser after certain syntax errors.
package main
func f() {
var m Mutex
c := MakeCond(&m)
percent := 0
const step = 10
for i := 0; i < 5; i++ {
go func() {
for {
// Emulates some useful work.
time.Sleep(1e8)
m.Lock()
defer
if /* ERROR "expected operand, found 'if'" */ percent == 100 {
m.Unlock()
break
}
percent++
if percent % step == 0 {
//c.Signal()
}
m.Unlock()
}
}()
}
for {
m.Lock()
if percent == 0 || percent % step != 0 {
c.Wait()
}
fmt.Print(",")
if percent == 100 {
m.Unlock()
break
}
m.Unlock()
}
}

View File

@ -15,7 +15,7 @@ import (
"unicode/utf8" "unicode/utf8"
) )
// Other formatting issues: // Formatting issues:
// - better comment formatting for /*-style comments at the end of a line (e.g. a declaration) // - better comment formatting for /*-style comments at the end of a line (e.g. a declaration)
// when the comment spans multiple lines; if such a comment is just two lines, formatting is // when the comment spans multiple lines; if such a comment is just two lines, formatting is
// not idempotent // not idempotent
@ -964,6 +964,41 @@ func (p *printer) controlClause(isForStmt bool, init ast.Stmt, expr ast.Expr, po
} }
} }
// indentList reports whether an expression list would look better if it
// were indented wholesale (starting with the very first element, rather
// than starting at the first line break).
//
func (p *printer) indentList(list []ast.Expr) bool {
// Heuristic: indentList returns true if there are more than one multi-
// line element in the list, or if there is any element that is not
// starting on the same line as the previous one ends.
if len(list) >= 2 {
var b = p.lineFor(list[0].Pos())
var e = p.lineFor(list[len(list)-1].End())
if 0 < b && b < e {
// list spans multiple lines
n := 0 // multi-line element count
line := b
for _, x := range list {
xb := p.lineFor(x.Pos())
xe := p.lineFor(x.End())
if line < xb {
// x is not starting on the same
// line as the previous one ended
return true
}
if xb < xe {
// x is a multi-line element
n++
}
line = xe
}
return n > 1
}
}
return false
}
func (p *printer) stmt(stmt ast.Stmt, nextIsRBrace bool) { func (p *printer) stmt(stmt ast.Stmt, nextIsRBrace bool) {
p.print(stmt.Pos()) p.print(stmt.Pos())
@ -1030,7 +1065,18 @@ func (p *printer) stmt(stmt ast.Stmt, nextIsRBrace bool) {
p.print(token.RETURN) p.print(token.RETURN)
if s.Results != nil { if s.Results != nil {
p.print(blank) p.print(blank)
p.exprList(s.Pos(), s.Results, 1, 0, token.NoPos) // Use indentList heuristic to make corner cases look
// better (issue 1207). A more systematic approach would
// always indent, but this would cause significant
// reformatting of the code base and not necessarily
// lead to more nicely formatted code in general.
if p.indentList(s.Results) {
p.print(indent)
p.exprList(s.Pos(), s.Results, 1, noIndent, token.NoPos)
p.print(unindent)
} else {
p.exprList(s.Pos(), s.Results, 1, 0, token.NoPos)
}
} }
case *ast.BranchStmt: case *ast.BranchStmt:
@ -1200,9 +1246,9 @@ func keepTypeColumn(specs []ast.Spec) []bool {
return m return m
} }
func (p *printer) valueSpec(s *ast.ValueSpec, keepType, doIndent bool) { func (p *printer) valueSpec(s *ast.ValueSpec, keepType bool) {
p.setComment(s.Doc) p.setComment(s.Doc)
p.identList(s.Names, doIndent) // always present p.identList(s.Names, false) // always present
extraTabs := 3 extraTabs := 3
if s.Type != nil || keepType { if s.Type != nil || keepType {
p.print(vtab) p.print(vtab)
@ -1290,7 +1336,7 @@ func (p *printer) genDecl(d *ast.GenDecl) {
if i > 0 { if i > 0 {
p.linebreak(p.lineFor(s.Pos()), 1, ignore, newSection) p.linebreak(p.lineFor(s.Pos()), 1, ignore, newSection)
} }
p.valueSpec(s.(*ast.ValueSpec), keepType[i], false) p.valueSpec(s.(*ast.ValueSpec), keepType[i])
newSection = p.isMultiLine(s) newSection = p.isMultiLine(s)
} }
} else { } else {

View File

@ -55,12 +55,24 @@ func _f() {
return T{ return T{
1, 1,
2, 2,
}, }, nil
return T{
1,
2,
},
T{
x: 3,
y: 4,
}, nil
return T{
1,
2,
},
nil nil
return T{ return T{
1, 1,
2, 2,
}, },
T{ T{
x: 3, x: 3,
y: 4, y: 4,
@ -70,10 +82,10 @@ func _f() {
z z
return func() {} return func() {}
return func() { return func() {
_ = 0 _ = 0
}, T{ }, T{
1, 2, 1, 2,
} }
return func() { return func() {
_ = 0 _ = 0
} }
@ -84,6 +96,37 @@ func _f() {
} }
} }
// Formatting of multi-line returns: test cases from issue 1207.
func F() (*T, os.Error) {
return &T{
X: 1,
Y: 2,
},
nil
}
func G() (*T, *T, os.Error) {
return &T{
X: 1,
Y: 2,
},
&T{
X: 3,
Y: 4,
},
nil
}
func _() interface{} {
return &fileStat{
name: basename(file.name),
size: mkSize(d.FileSizeHigh, d.FileSizeLow),
modTime: mkModTime(d.LastWriteTime),
mode: mkMode(d.FileAttributes),
sys: mkSysFromFI(&d),
}, nil
}
// Formatting of if-statement headers. // Formatting of if-statement headers.
func _() { func _() {
if true { if true {

View File

@ -52,6 +52,18 @@ func _f() {
3}, 3},
3, 3,
} }
return T{
1,
2,
}, nil
return T{
1,
2,
},
T{
x: 3,
y: 4,
}, nil
return T{ return T{
1, 1,
2, 2,
@ -84,6 +96,37 @@ func _f() {
} }
} }
// Formatting of multi-line returns: test cases from issue 1207.
func F() (*T, os.Error) {
return &T{
X: 1,
Y: 2,
},
nil
}
func G() (*T, *T, os.Error) {
return &T{
X: 1,
Y: 2,
},
&T{
X: 3,
Y: 4,
},
nil
}
func _() interface{} {
return &fileStat{
name: basename(file.name),
size: mkSize(d.FileSizeHigh, d.FileSizeLow),
modTime: mkModTime(d.LastWriteTime),
mode: mkMode(d.FileAttributes),
sys: mkSysFromFI(&d),
}, nil
}
// Formatting of if-statement headers. // Formatting of if-statement headers.
func _() { func _() {
if true {} if true {}

View File

@ -109,7 +109,7 @@ const (
func (s *Scanner) Init(file *token.File, src []byte, err ErrorHandler, mode Mode) { func (s *Scanner) Init(file *token.File, src []byte, err ErrorHandler, mode Mode) {
// Explicitly initialize all fields since a scanner may be reused. // Explicitly initialize all fields since a scanner may be reused.
if file.Size() != len(src) { if file.Size() != len(src) {
panic("file size does not match src len") panic(fmt.Sprintf("file size (%d) does not match src len (%d)", file.Size(), len(src)))
} }
s.file = file s.file = file
s.dir, _ = filepath.Split(file.Name()) s.dir, _ = filepath.Split(file.Name())

View File

@ -29,7 +29,7 @@ can be safely embedded in an HTML document. The escaping is contextual, so
actions can appear within JavaScript, CSS, and URI contexts. actions can appear within JavaScript, CSS, and URI contexts.
The security model used by this package assumes that template authors are The security model used by this package assumes that template authors are
trusted, while text/template Execute's data parameter is not. More details are trusted, while Execute's data parameter is not. More details are
provided below. provided below.
Example Example

View File

@ -173,6 +173,13 @@ type ReaderAt interface {
// at offset off. It returns the number of bytes written from p (0 <= n <= len(p)) // at offset off. It returns the number of bytes written from p (0 <= n <= len(p))
// and any error encountered that caused the write to stop early. // and any error encountered that caused the write to stop early.
// WriteAt must return a non-nil error if it returns n < len(p). // WriteAt must return a non-nil error if it returns n < len(p).
//
// If WriteAt is writing to a destination with a seek offset,
// WriteAt should not affect nor be affected by the underlying
// seek offset.
//
// Clients of WriteAt can execute parallel WriteAt calls on the same
// destination if the ranges do not overlap.
type WriterAt interface { type WriterAt interface {
WriteAt(p []byte, off int64) (n int, err error) WriteAt(p []byte, off int64) (n int, err error)
} }

View File

@ -13,8 +13,6 @@
package log package log
import ( import (
"bytes"
_ "debug/elf"
"fmt" "fmt"
"io" "io"
"os" "os"
@ -29,7 +27,7 @@ const (
// order they appear (the order listed here) or the format they present (as // order they appear (the order listed here) or the format they present (as
// described in the comments). A colon appears after these items: // described in the comments). A colon appears after these items:
// 2009/0123 01:23:23.123123 /a/b/c/d.go:23: message // 2009/0123 01:23:23.123123 /a/b/c/d.go:23: message
Ldate = 1 << iota // the date: 2009/0123 Ldate = 1 << iota // the date: 2009/01/23
Ltime // the time: 01:23:23 Ltime // the time: 01:23:23
Lmicroseconds // microsecond resolution: 01:23:23.123123. assumes Ltime. Lmicroseconds // microsecond resolution: 01:23:23.123123. assumes Ltime.
Llongfile // full file name and line number: /a/b/c/d.go:23 Llongfile // full file name and line number: /a/b/c/d.go:23
@ -42,11 +40,11 @@ const (
// the Writer's Write method. A Logger can be used simultaneously from // the Writer's Write method. A Logger can be used simultaneously from
// multiple goroutines; it guarantees to serialize access to the Writer. // multiple goroutines; it guarantees to serialize access to the Writer.
type Logger struct { type Logger struct {
mu sync.Mutex // ensures atomic writes; protects the following fields mu sync.Mutex // ensures atomic writes; protects the following fields
prefix string // prefix to write at beginning of each line prefix string // prefix to write at beginning of each line
flag int // properties flag int // properties
out io.Writer // destination for output out io.Writer // destination for output
buf bytes.Buffer // for accumulating text to write buf []byte // for accumulating text to write
} }
// New creates a new Logger. The out variable sets the // New creates a new Logger. The out variable sets the
@ -61,10 +59,10 @@ var std = New(os.Stderr, "", LstdFlags)
// Cheap integer to fixed-width decimal ASCII. Give a negative width to avoid zero-padding. // Cheap integer to fixed-width decimal ASCII. Give a negative width to avoid zero-padding.
// Knows the buffer has capacity. // Knows the buffer has capacity.
func itoa(buf *bytes.Buffer, i int, wid int) { func itoa(buf *[]byte, i int, wid int) {
var u uint = uint(i) var u uint = uint(i)
if u == 0 && wid <= 1 { if u == 0 && wid <= 1 {
buf.WriteByte('0') *buf = append(*buf, '0')
return return
} }
@ -76,38 +74,33 @@ func itoa(buf *bytes.Buffer, i int, wid int) {
wid-- wid--
b[bp] = byte(u%10) + '0' b[bp] = byte(u%10) + '0'
} }
*buf = append(*buf, b[bp:]...)
// avoid slicing b to avoid an allocation.
for bp < len(b) {
buf.WriteByte(b[bp])
bp++
}
} }
func (l *Logger) formatHeader(buf *bytes.Buffer, t time.Time, file string, line int) { func (l *Logger) formatHeader(buf *[]byte, t time.Time, file string, line int) {
buf.WriteString(l.prefix) *buf = append(*buf, l.prefix...)
if l.flag&(Ldate|Ltime|Lmicroseconds) != 0 { if l.flag&(Ldate|Ltime|Lmicroseconds) != 0 {
if l.flag&Ldate != 0 { if l.flag&Ldate != 0 {
year, month, day := t.Date() year, month, day := t.Date()
itoa(buf, year, 4) itoa(buf, year, 4)
buf.WriteByte('/') *buf = append(*buf, '/')
itoa(buf, int(month), 2) itoa(buf, int(month), 2)
buf.WriteByte('/') *buf = append(*buf, '/')
itoa(buf, day, 2) itoa(buf, day, 2)
buf.WriteByte(' ') *buf = append(*buf, ' ')
} }
if l.flag&(Ltime|Lmicroseconds) != 0 { if l.flag&(Ltime|Lmicroseconds) != 0 {
hour, min, sec := t.Clock() hour, min, sec := t.Clock()
itoa(buf, hour, 2) itoa(buf, hour, 2)
buf.WriteByte(':') *buf = append(*buf, ':')
itoa(buf, min, 2) itoa(buf, min, 2)
buf.WriteByte(':') *buf = append(*buf, ':')
itoa(buf, sec, 2) itoa(buf, sec, 2)
if l.flag&Lmicroseconds != 0 { if l.flag&Lmicroseconds != 0 {
buf.WriteByte('.') *buf = append(*buf, '.')
itoa(buf, t.Nanosecond()/1e3, 6) itoa(buf, t.Nanosecond()/1e3, 6)
} }
buf.WriteByte(' ') *buf = append(*buf, ' ')
} }
} }
if l.flag&(Lshortfile|Llongfile) != 0 { if l.flag&(Lshortfile|Llongfile) != 0 {
@ -121,10 +114,10 @@ func (l *Logger) formatHeader(buf *bytes.Buffer, t time.Time, file string, line
} }
file = short file = short
} }
buf.WriteString(file) *buf = append(*buf, file...)
buf.WriteByte(':') *buf = append(*buf, ':')
itoa(buf, line, -1) itoa(buf, line, -1)
buf.WriteString(": ") *buf = append(*buf, ": "...)
} }
} }
@ -151,13 +144,13 @@ func (l *Logger) Output(calldepth int, s string) error {
} }
l.mu.Lock() l.mu.Lock()
} }
l.buf.Reset() l.buf = l.buf[:0]
l.formatHeader(&l.buf, now, file, line) l.formatHeader(&l.buf, now, file, line)
l.buf.WriteString(s) l.buf = append(l.buf, s...)
if len(s) > 0 && s[len(s)-1] != '\n' { if len(s) > 0 && s[len(s)-1] != '\n' {
l.buf.WriteByte('\n') l.buf = append(l.buf, '\n')
} }
_, err := l.out.Write(l.buf.Bytes()) _, err := l.out.Write(l.buf)
return err return err
} }

View File

@ -6,6 +6,7 @@ package net
import ( import (
"flag" "flag"
"fmt"
"regexp" "regexp"
"runtime" "runtime"
"testing" "testing"
@ -32,7 +33,7 @@ func TestDialTimeout(t *testing.T) {
numConns := listenerBacklog + 10 numConns := listenerBacklog + 10
// TODO(bradfitz): It's hard to test this in a portable // TODO(bradfitz): It's hard to test this in a portable
// way. This is unforunate, but works for now. // way. This is unfortunate, but works for now.
switch runtime.GOOS { switch runtime.GOOS {
case "linux": case "linux":
// The kernel will start accepting TCP connections before userspace // The kernel will start accepting TCP connections before userspace
@ -44,13 +45,25 @@ func TestDialTimeout(t *testing.T) {
errc <- err errc <- err
}() }()
} }
case "darwin": case "darwin", "windows":
// At least OS X 10.7 seems to accept any number of // At least OS X 10.7 seems to accept any number of
// connections, ignoring listen's backlog, so resort // connections, ignoring listen's backlog, so resort
// to connecting to a hopefully-dead 127/8 address. // to connecting to a hopefully-dead 127/8 address.
// Same for windows. // Same for windows.
//
// Use an IANA reserved port (49151) instead of 80, because
// on our 386 builder, this Dial succeeds, connecting
// to an IIS web server somewhere. The data center
// or VM or firewall must be stealing the TCP connection.
//
// IANA Service Name and Transport Protocol Port Number Registry
// <http://www.iana.org/assignments/service-names-port-numbers/service-names-port-numbers.xml>
go func() { go func() {
_, err := DialTimeout("tcp", "127.0.71.111:80", 200*time.Millisecond) c, err := DialTimeout("tcp", "127.0.71.111:49151", 200*time.Millisecond)
if err == nil {
err = fmt.Errorf("unexpected: connected to %s!", c.RemoteAddr())
c.Close()
}
errc <- err errc <- err
}() }()
default: default:

View File

@ -5,8 +5,6 @@
package net package net
import ( import (
"bytes"
"fmt"
"math/rand" "math/rand"
"sort" "sort"
) )
@ -45,20 +43,22 @@ func reverseaddr(addr string) (arpa string, err error) {
return "", &DNSError{Err: "unrecognized address", Name: addr} return "", &DNSError{Err: "unrecognized address", Name: addr}
} }
if ip.To4() != nil { if ip.To4() != nil {
return fmt.Sprintf("%d.%d.%d.%d.in-addr.arpa.", ip[15], ip[14], ip[13], ip[12]), nil return itoa(int(ip[15])) + "." + itoa(int(ip[14])) + "." + itoa(int(ip[13])) + "." +
itoa(int(ip[12])) + ".in-addr.arpa.", nil
} }
// Must be IPv6 // Must be IPv6
var buf bytes.Buffer buf := make([]byte, 0, len(ip)*4+len("ip6.arpa."))
// Add it, in reverse, to the buffer // Add it, in reverse, to the buffer
for i := len(ip) - 1; i >= 0; i-- { for i := len(ip) - 1; i >= 0; i-- {
s := fmt.Sprintf("%02x", ip[i]) v := ip[i]
buf.WriteByte(s[1]) buf = append(buf, hexDigit[v&0xF])
buf.WriteByte('.') buf = append(buf, '.')
buf.WriteByte(s[0]) buf = append(buf, hexDigit[v>>4])
buf.WriteByte('.') buf = append(buf, '.')
} }
// Append "ip6.arpa." and return (buf already has the final .) // Append "ip6.arpa." and return (buf already has the final .)
return buf.String() + "ip6.arpa.", nil buf = append(buf, "ip6.arpa."...)
return string(buf), nil
} }
// Find answer for name in dns message. // Find answer for name in dns message.

View File

@ -7,11 +7,10 @@
// This is intended to support name resolution during Dial. // This is intended to support name resolution during Dial.
// It doesn't have to be blazing fast. // It doesn't have to be blazing fast.
// //
// Rather than write the usual handful of routines to pack and // Each message structure has a Walk method that is used by
// unpack every message that can appear on the wire, we use // a generic pack/unpack routine. Thus, if in the future we need
// reflection to write a generic pack/unpack for structs and then // to define new message structs, no new pack/unpack/printing code
// use it. Thus, if in the future we need to define new message // needs to be written.
// structs, no new pack/unpack/printing code needs to be written.
// //
// The first half of this file defines the DNS message formats. // The first half of this file defines the DNS message formats.
// The second half implements the conversion to and from wire format. // The second half implements the conversion to and from wire format.
@ -23,12 +22,6 @@
package net package net
import (
"fmt"
"os"
"reflect"
)
// Packet formats // Packet formats
// Wire constants. // Wire constants.
@ -75,6 +68,20 @@ const (
dnsRcodeRefused = 5 dnsRcodeRefused = 5
) )
// A dnsStruct describes how to iterate over its fields to emulate
// reflective marshalling.
type dnsStruct interface {
// Walk iterates over fields of a structure and calls f
// with a reference to that field, the name of the field
// and a tag ("", "domain", "ipv4", "ipv6") specifying
// particular encodings. Possible concrete types
// for v are *uint16, *uint32, *string, or []byte, and
// *int, *bool in the case of dnsMsgHdr.
// Whenever f returns false, Walk must stop and return
// false, and otherwise return true.
Walk(f func(v interface{}, name, tag string) (ok bool)) (ok bool)
}
// The wire format for the DNS packet header. // The wire format for the DNS packet header.
type dnsHeader struct { type dnsHeader struct {
Id uint16 Id uint16
@ -82,6 +89,15 @@ type dnsHeader struct {
Qdcount, Ancount, Nscount, Arcount uint16 Qdcount, Ancount, Nscount, Arcount uint16
} }
func (h *dnsHeader) Walk(f func(v interface{}, name, tag string) bool) bool {
return f(&h.Id, "Id", "") &&
f(&h.Bits, "Bits", "") &&
f(&h.Qdcount, "Qdcount", "") &&
f(&h.Ancount, "Ancount", "") &&
f(&h.Nscount, "Nscount", "") &&
f(&h.Arcount, "Arcount", "")
}
const ( const (
// dnsHeader.Bits // dnsHeader.Bits
_QR = 1 << 15 // query/response (response=1) _QR = 1 << 15 // query/response (response=1)
@ -98,6 +114,12 @@ type dnsQuestion struct {
Qclass uint16 Qclass uint16
} }
func (q *dnsQuestion) Walk(f func(v interface{}, name, tag string) bool) bool {
return f(&q.Name, "Name", "domain") &&
f(&q.Qtype, "Qtype", "") &&
f(&q.Qclass, "Qclass", "")
}
// DNS responses (resource records). // DNS responses (resource records).
// There are many types of messages, // There are many types of messages,
// but they all share the same header. // but they all share the same header.
@ -113,7 +135,16 @@ func (h *dnsRR_Header) Header() *dnsRR_Header {
return h return h
} }
func (h *dnsRR_Header) Walk(f func(v interface{}, name, tag string) bool) bool {
return f(&h.Name, "Name", "domain") &&
f(&h.Rrtype, "Rrtype", "") &&
f(&h.Class, "Class", "") &&
f(&h.Ttl, "Ttl", "") &&
f(&h.Rdlength, "Rdlength", "")
}
type dnsRR interface { type dnsRR interface {
dnsStruct
Header() *dnsRR_Header Header() *dnsRR_Header
} }
@ -128,6 +159,10 @@ func (rr *dnsRR_CNAME) Header() *dnsRR_Header {
return &rr.Hdr return &rr.Hdr
} }
func (rr *dnsRR_CNAME) Walk(f func(v interface{}, name, tag string) bool) bool {
return rr.Hdr.Walk(f) && f(&rr.Cname, "Cname", "domain")
}
type dnsRR_HINFO struct { type dnsRR_HINFO struct {
Hdr dnsRR_Header Hdr dnsRR_Header
Cpu string Cpu string
@ -138,6 +173,10 @@ func (rr *dnsRR_HINFO) Header() *dnsRR_Header {
return &rr.Hdr return &rr.Hdr
} }
func (rr *dnsRR_HINFO) Walk(f func(v interface{}, name, tag string) bool) bool {
return rr.Hdr.Walk(f) && f(&rr.Cpu, "Cpu", "") && f(&rr.Os, "Os", "")
}
type dnsRR_MB struct { type dnsRR_MB struct {
Hdr dnsRR_Header Hdr dnsRR_Header
Mb string `net:"domain-name"` Mb string `net:"domain-name"`
@ -147,6 +186,10 @@ func (rr *dnsRR_MB) Header() *dnsRR_Header {
return &rr.Hdr return &rr.Hdr
} }
func (rr *dnsRR_MB) Walk(f func(v interface{}, name, tag string) bool) bool {
return rr.Hdr.Walk(f) && f(&rr.Mb, "Mb", "domain")
}
type dnsRR_MG struct { type dnsRR_MG struct {
Hdr dnsRR_Header Hdr dnsRR_Header
Mg string `net:"domain-name"` Mg string `net:"domain-name"`
@ -156,6 +199,10 @@ func (rr *dnsRR_MG) Header() *dnsRR_Header {
return &rr.Hdr return &rr.Hdr
} }
func (rr *dnsRR_MG) Walk(f func(v interface{}, name, tag string) bool) bool {
return rr.Hdr.Walk(f) && f(&rr.Mg, "Mg", "domain")
}
type dnsRR_MINFO struct { type dnsRR_MINFO struct {
Hdr dnsRR_Header Hdr dnsRR_Header
Rmail string `net:"domain-name"` Rmail string `net:"domain-name"`
@ -166,6 +213,10 @@ func (rr *dnsRR_MINFO) Header() *dnsRR_Header {
return &rr.Hdr return &rr.Hdr
} }
func (rr *dnsRR_MINFO) Walk(f func(v interface{}, name, tag string) bool) bool {
return rr.Hdr.Walk(f) && f(&rr.Rmail, "Rmail", "domain") && f(&rr.Email, "Email", "domain")
}
type dnsRR_MR struct { type dnsRR_MR struct {
Hdr dnsRR_Header Hdr dnsRR_Header
Mr string `net:"domain-name"` Mr string `net:"domain-name"`
@ -175,6 +226,10 @@ func (rr *dnsRR_MR) Header() *dnsRR_Header {
return &rr.Hdr return &rr.Hdr
} }
func (rr *dnsRR_MR) Walk(f func(v interface{}, name, tag string) bool) bool {
return rr.Hdr.Walk(f) && f(&rr.Mr, "Mr", "domain")
}
type dnsRR_MX struct { type dnsRR_MX struct {
Hdr dnsRR_Header Hdr dnsRR_Header
Pref uint16 Pref uint16
@ -185,6 +240,10 @@ func (rr *dnsRR_MX) Header() *dnsRR_Header {
return &rr.Hdr return &rr.Hdr
} }
func (rr *dnsRR_MX) Walk(f func(v interface{}, name, tag string) bool) bool {
return rr.Hdr.Walk(f) && f(&rr.Pref, "Pref", "") && f(&rr.Mx, "Mx", "domain")
}
type dnsRR_NS struct { type dnsRR_NS struct {
Hdr dnsRR_Header Hdr dnsRR_Header
Ns string `net:"domain-name"` Ns string `net:"domain-name"`
@ -194,6 +253,10 @@ func (rr *dnsRR_NS) Header() *dnsRR_Header {
return &rr.Hdr return &rr.Hdr
} }
func (rr *dnsRR_NS) Walk(f func(v interface{}, name, tag string) bool) bool {
return rr.Hdr.Walk(f) && f(&rr.Ns, "Ns", "domain")
}
type dnsRR_PTR struct { type dnsRR_PTR struct {
Hdr dnsRR_Header Hdr dnsRR_Header
Ptr string `net:"domain-name"` Ptr string `net:"domain-name"`
@ -203,6 +266,10 @@ func (rr *dnsRR_PTR) Header() *dnsRR_Header {
return &rr.Hdr return &rr.Hdr
} }
func (rr *dnsRR_PTR) Walk(f func(v interface{}, name, tag string) bool) bool {
return rr.Hdr.Walk(f) && f(&rr.Ptr, "Ptr", "domain")
}
type dnsRR_SOA struct { type dnsRR_SOA struct {
Hdr dnsRR_Header Hdr dnsRR_Header
Ns string `net:"domain-name"` Ns string `net:"domain-name"`
@ -218,6 +285,17 @@ func (rr *dnsRR_SOA) Header() *dnsRR_Header {
return &rr.Hdr return &rr.Hdr
} }
func (rr *dnsRR_SOA) Walk(f func(v interface{}, name, tag string) bool) bool {
return rr.Hdr.Walk(f) &&
f(&rr.Ns, "Ns", "domain") &&
f(&rr.Mbox, "Mbox", "domain") &&
f(&rr.Serial, "Serial", "") &&
f(&rr.Refresh, "Refresh", "") &&
f(&rr.Retry, "Retry", "") &&
f(&rr.Expire, "Expire", "") &&
f(&rr.Minttl, "Minttl", "")
}
type dnsRR_TXT struct { type dnsRR_TXT struct {
Hdr dnsRR_Header Hdr dnsRR_Header
Txt string // not domain name Txt string // not domain name
@ -227,6 +305,10 @@ func (rr *dnsRR_TXT) Header() *dnsRR_Header {
return &rr.Hdr return &rr.Hdr
} }
func (rr *dnsRR_TXT) Walk(f func(v interface{}, name, tag string) bool) bool {
return rr.Hdr.Walk(f) && f(&rr.Txt, "Txt", "")
}
type dnsRR_SRV struct { type dnsRR_SRV struct {
Hdr dnsRR_Header Hdr dnsRR_Header
Priority uint16 Priority uint16
@ -239,6 +321,14 @@ func (rr *dnsRR_SRV) Header() *dnsRR_Header {
return &rr.Hdr return &rr.Hdr
} }
func (rr *dnsRR_SRV) Walk(f func(v interface{}, name, tag string) bool) bool {
return rr.Hdr.Walk(f) &&
f(&rr.Priority, "Priority", "") &&
f(&rr.Weight, "Weight", "") &&
f(&rr.Port, "Port", "") &&
f(&rr.Target, "Target", "domain")
}
type dnsRR_A struct { type dnsRR_A struct {
Hdr dnsRR_Header Hdr dnsRR_Header
A uint32 `net:"ipv4"` A uint32 `net:"ipv4"`
@ -248,6 +338,10 @@ func (rr *dnsRR_A) Header() *dnsRR_Header {
return &rr.Hdr return &rr.Hdr
} }
func (rr *dnsRR_A) Walk(f func(v interface{}, name, tag string) bool) bool {
return rr.Hdr.Walk(f) && f(&rr.A, "A", "ipv4")
}
type dnsRR_AAAA struct { type dnsRR_AAAA struct {
Hdr dnsRR_Header Hdr dnsRR_Header
AAAA [16]byte `net:"ipv6"` AAAA [16]byte `net:"ipv6"`
@ -257,6 +351,10 @@ func (rr *dnsRR_AAAA) Header() *dnsRR_Header {
return &rr.Hdr return &rr.Hdr
} }
func (rr *dnsRR_AAAA) Walk(f func(v interface{}, name, tag string) bool) bool {
return rr.Hdr.Walk(f) && f(rr.AAAA[:], "AAAA", "ipv6")
}
// Packing and unpacking. // Packing and unpacking.
// //
// All the packers and unpackers take a (msg []byte, off int) // All the packers and unpackers take a (msg []byte, off int)
@ -386,134 +484,107 @@ Loop:
return s, off1, true return s, off1, true
} }
// TODO(rsc): Move into generic library? // packStruct packs a structure into msg at specified offset off, and
// Pack a reflect.StructValue into msg. Struct members can only be uint16, uint32, string, // returns off1 such that msg[off:off1] is the encoded data.
// [n]byte, and other (often anonymous) structs. func packStruct(any dnsStruct, msg []byte, off int) (off1 int, ok bool) {
func packStructValue(val reflect.Value, msg []byte, off int) (off1 int, ok bool) { ok = any.Walk(func(field interface{}, name, tag string) bool {
for i := 0; i < val.NumField(); i++ { switch fv := field.(type) {
f := val.Type().Field(i)
switch fv := val.Field(i); fv.Kind() {
default: default:
fmt.Fprintf(os.Stderr, "net: dns: unknown packing type %v", f.Type) println("net: dns: unknown packing type")
return len(msg), false return false
case reflect.Struct: case *uint16:
off, ok = packStructValue(fv, msg, off) i := *fv
case reflect.Uint16:
if off+2 > len(msg) { if off+2 > len(msg) {
return len(msg), false return false
} }
i := fv.Uint()
msg[off] = byte(i >> 8) msg[off] = byte(i >> 8)
msg[off+1] = byte(i) msg[off+1] = byte(i)
off += 2 off += 2
case reflect.Uint32: case *uint32:
if off+4 > len(msg) { i := *fv
return len(msg), false
}
i := fv.Uint()
msg[off] = byte(i >> 24) msg[off] = byte(i >> 24)
msg[off+1] = byte(i >> 16) msg[off+1] = byte(i >> 16)
msg[off+2] = byte(i >> 8) msg[off+2] = byte(i >> 8)
msg[off+3] = byte(i) msg[off+3] = byte(i)
off += 4 off += 4
case reflect.Array: case []byte:
if fv.Type().Elem().Kind() != reflect.Uint8 { n := len(fv)
fmt.Fprintf(os.Stderr, "net: dns: unknown packing type %v", f.Type)
return len(msg), false
}
n := fv.Len()
if off+n > len(msg) { if off+n > len(msg) {
return len(msg), false return false
} }
reflect.Copy(reflect.ValueOf(msg[off:off+n]), fv) copy(msg[off:off+n], fv)
off += n off += n
case reflect.String: case *string:
// There are multiple string encodings. s := *fv
// The tag distinguishes ordinary strings from domain names. switch tag {
s := fv.String()
switch f.Tag {
default: default:
fmt.Fprintf(os.Stderr, "net: dns: unknown string tag %v", f.Tag) println("net: dns: unknown string tag", tag)
return len(msg), false return false
case `net:"domain-name"`: case "domain":
off, ok = packDomainName(s, msg, off) off, ok = packDomainName(s, msg, off)
if !ok { if !ok {
return len(msg), false return false
} }
case "": case "":
// Counted string: 1 byte length. // Counted string: 1 byte length.
if len(s) > 255 || off+1+len(s) > len(msg) { if len(s) > 255 || off+1+len(s) > len(msg) {
return len(msg), false return false
} }
msg[off] = byte(len(s)) msg[off] = byte(len(s))
off++ off++
off += copy(msg[off:], s) off += copy(msg[off:], s)
} }
} }
return true
})
if !ok {
return len(msg), false
} }
return off, true return off, true
} }
func structValue(any interface{}) reflect.Value { // unpackStruct decodes msg[off:] into the given structure, and
return reflect.ValueOf(any).Elem() // returns off1 such that msg[off:off1] is the encoded data.
} func unpackStruct(any dnsStruct, msg []byte, off int) (off1 int, ok bool) {
ok = any.Walk(func(field interface{}, name, tag string) bool {
func packStruct(any interface{}, msg []byte, off int) (off1 int, ok bool) { switch fv := field.(type) {
off, ok = packStructValue(structValue(any), msg, off)
return off, ok
}
// TODO(rsc): Move into generic library?
// Unpack a reflect.StructValue from msg.
// Same restrictions as packStructValue.
func unpackStructValue(val reflect.Value, msg []byte, off int) (off1 int, ok bool) {
for i := 0; i < val.NumField(); i++ {
f := val.Type().Field(i)
switch fv := val.Field(i); fv.Kind() {
default: default:
fmt.Fprintf(os.Stderr, "net: dns: unknown packing type %v", f.Type) println("net: dns: unknown packing type")
return len(msg), false return false
case reflect.Struct: case *uint16:
off, ok = unpackStructValue(fv, msg, off)
case reflect.Uint16:
if off+2 > len(msg) { if off+2 > len(msg) {
return len(msg), false return false
} }
i := uint16(msg[off])<<8 | uint16(msg[off+1]) *fv = uint16(msg[off])<<8 | uint16(msg[off+1])
fv.SetUint(uint64(i))
off += 2 off += 2
case reflect.Uint32: case *uint32:
if off+4 > len(msg) { if off+4 > len(msg) {
return len(msg), false return false
} }
i := uint32(msg[off])<<24 | uint32(msg[off+1])<<16 | uint32(msg[off+2])<<8 | uint32(msg[off+3]) *fv = uint32(msg[off])<<24 | uint32(msg[off+1])<<16 |
fv.SetUint(uint64(i)) uint32(msg[off+2])<<8 | uint32(msg[off+3])
off += 4 off += 4
case reflect.Array: case []byte:
if fv.Type().Elem().Kind() != reflect.Uint8 { n := len(fv)
fmt.Fprintf(os.Stderr, "net: dns: unknown packing type %v", f.Type)
return len(msg), false
}
n := fv.Len()
if off+n > len(msg) { if off+n > len(msg) {
return len(msg), false return false
} }
reflect.Copy(fv, reflect.ValueOf(msg[off:off+n])) copy(fv, msg[off:off+n])
off += n off += n
case reflect.String: case *string:
var s string var s string
switch f.Tag { switch tag {
default: default:
fmt.Fprintf(os.Stderr, "net: dns: unknown string tag %v", f.Tag) println("net: dns: unknown string tag", tag)
return len(msg), false return false
case `net:"domain-name"`: case "domain":
s, off, ok = unpackDomainName(msg, off) s, off, ok = unpackDomainName(msg, off)
if !ok { if !ok {
return len(msg), false return false
} }
case "": case "":
if off >= len(msg) || off+1+int(msg[off]) > len(msg) { if off >= len(msg) || off+1+int(msg[off]) > len(msg) {
return len(msg), false return false
} }
n := int(msg[off]) n := int(msg[off])
off++ off++
@ -524,51 +595,77 @@ func unpackStructValue(val reflect.Value, msg []byte, off int) (off1 int, ok boo
off += n off += n
s = string(b) s = string(b)
} }
fv.SetString(s) *fv = s
} }
return true
})
if !ok {
return len(msg), false
} }
return off, true return off, true
} }
func unpackStruct(any interface{}, msg []byte, off int) (off1 int, ok bool) { // Generic struct printer. Prints fields with tag "ipv4" or "ipv6"
off, ok = unpackStructValue(structValue(any), msg, off) // as IP addresses.
return off, ok func printStruct(any dnsStruct) string {
}
// Generic struct printer.
// Doesn't care about the string tag `net:"domain-name"`,
// but does look for an `net:"ipv4"` tag on uint32 variables
// and the `net:"ipv6"` tag on array variables,
// printing them as IP addresses.
func printStructValue(val reflect.Value) string {
s := "{" s := "{"
for i := 0; i < val.NumField(); i++ { i := 0
if i > 0 { any.Walk(func(val interface{}, name, tag string) bool {
i++
if i > 1 {
s += ", " s += ", "
} }
f := val.Type().Field(i) s += name + "="
if !f.Anonymous { switch tag {
s += f.Name + "=" case "ipv4":
} i := val.(uint32)
fval := val.Field(i)
if fv := fval; fv.Kind() == reflect.Struct {
s += printStructValue(fv)
} else if fv := fval; (fv.Kind() == reflect.Uint || fv.Kind() == reflect.Uint8 || fv.Kind() == reflect.Uint16 || fv.Kind() == reflect.Uint32 || fv.Kind() == reflect.Uint64 || fv.Kind() == reflect.Uintptr) && f.Tag == `net:"ipv4"` {
i := fv.Uint()
s += IPv4(byte(i>>24), byte(i>>16), byte(i>>8), byte(i)).String() s += IPv4(byte(i>>24), byte(i>>16), byte(i>>8), byte(i)).String()
} else if fv := fval; fv.Kind() == reflect.Array && f.Tag == `net:"ipv6"` { case "ipv6":
i := fv.Interface().([]byte) i := val.([]byte)
s += IP(i).String() s += IP(i).String()
} else { default:
s += fmt.Sprint(fval.Interface()) var i int64
switch v := val.(type) {
default:
// can't really happen.
s += "<unknown type>"
return true
case *string:
s += *v
return true
case []byte:
s += string(v)
return true
case *bool:
if *v {
s += "true"
} else {
s += "false"
}
return true
case *int:
i = int64(*v)
case *uint:
i = int64(*v)
case *uint8:
i = int64(*v)
case *uint16:
i = int64(*v)
case *uint32:
i = int64(*v)
case *uint64:
i = int64(*v)
case *uintptr:
i = int64(*v)
}
s += itoa(int(i))
} }
} return true
})
s += "}" s += "}"
return s return s
} }
func printStruct(any interface{}) string { return printStructValue(structValue(any)) }
// Resource record packer. // Resource record packer.
func packRR(rr dnsRR, msg []byte, off int) (off2 int, ok bool) { func packRR(rr dnsRR, msg []byte, off int) (off2 int, ok bool) {
var off1 int var off1 int
@ -627,6 +724,17 @@ type dnsMsgHdr struct {
rcode int rcode int
} }
func (h *dnsMsgHdr) Walk(f func(v interface{}, name, tag string) bool) bool {
return f(&h.id, "id", "") &&
f(&h.response, "response", "") &&
f(&h.opcode, "opcode", "") &&
f(&h.authoritative, "authoritative", "") &&
f(&h.truncated, "truncated", "") &&
f(&h.recursion_desired, "recursion_desired", "") &&
f(&h.recursion_available, "recursion_available", "") &&
f(&h.rcode, "rcode", "")
}
type dnsMsg struct { type dnsMsg struct {
dnsMsgHdr dnsMsgHdr
question []dnsQuestion question []dnsQuestion

View File

@ -6,6 +6,7 @@ package net
import ( import (
"encoding/hex" "encoding/hex"
"reflect"
"testing" "testing"
) )
@ -19,6 +20,7 @@ func TestDNSParseSRVReply(t *testing.T) {
if !ok { if !ok {
t.Fatalf("unpacking packet failed") t.Fatalf("unpacking packet failed")
} }
msg.String() // exercise this code path
if g, e := len(msg.answer), 5; g != e { if g, e := len(msg.answer), 5; g != e {
t.Errorf("len(msg.answer) = %d; want %d", g, e) t.Errorf("len(msg.answer) = %d; want %d", g, e)
} }
@ -38,6 +40,16 @@ func TestDNSParseSRVReply(t *testing.T) {
t.Errorf("len(addrs) = %d; want %d", g, e) t.Errorf("len(addrs) = %d; want %d", g, e)
t.Logf("addrs = %#v", addrs) t.Logf("addrs = %#v", addrs)
} }
// repack and unpack.
data2, ok := msg.Pack()
msg2 := new(dnsMsg)
msg2.Unpack(data2)
switch {
case !ok:
t.Errorf("failed to repack message")
case !reflect.DeepEqual(msg, msg2):
t.Errorf("repacked message differs from original")
}
} }
func TestDNSParseCorruptSRVReply(t *testing.T) { func TestDNSParseCorruptSRVReply(t *testing.T) {
@ -50,6 +62,7 @@ func TestDNSParseCorruptSRVReply(t *testing.T) {
if !ok { if !ok {
t.Fatalf("unpacking packet failed") t.Fatalf("unpacking packet failed")
} }
msg.String() // exercise this code path
if g, e := len(msg.answer), 5; g != e { if g, e := len(msg.answer), 5; g != e {
t.Errorf("len(msg.answer) = %d; want %d", g, e) t.Errorf("len(msg.answer) = %d; want %d", g, e)
} }

View File

@ -84,7 +84,8 @@ func (p *pollster) StopWaiting(fd int, bits uint) {
events, already := p.events[fd] events, already := p.events[fd]
if !already { if !already {
print("Epoll unexpected fd=", fd, "\n") // The fd returned by the kernel may have been
// cancelled already; return silently.
return return
} }

View File

@ -27,7 +27,8 @@ type connFile interface {
} }
func testFileListener(t *testing.T, net, laddr string) { func testFileListener(t *testing.T, net, laddr string) {
if net == "tcp" { switch net {
case "tcp", "tcp4", "tcp6":
laddr += ":0" // any available port laddr += ":0" // any available port
} }
l, err := Listen(net, laddr) l, err := Listen(net, laddr)
@ -55,20 +56,52 @@ func testFileListener(t *testing.T, net, laddr string) {
} }
} }
var fileListenerTests = []struct {
net string
laddr string
ipv6 bool // test with underlying AF_INET6 socket
linux bool // test with abstract unix domain socket, a Linux-ism
}{
{net: "tcp", laddr: ""},
{net: "tcp", laddr: "0.0.0.0"},
{net: "tcp", laddr: "[::ffff:0.0.0.0]"},
{net: "tcp", laddr: "[::]", ipv6: true},
{net: "tcp", laddr: "127.0.0.1"},
{net: "tcp", laddr: "[::ffff:127.0.0.1]"},
{net: "tcp", laddr: "[::1]", ipv6: true},
{net: "tcp4", laddr: ""},
{net: "tcp4", laddr: "0.0.0.0"},
{net: "tcp4", laddr: "[::ffff:0.0.0.0]"},
{net: "tcp4", laddr: "127.0.0.1"},
{net: "tcp4", laddr: "[::ffff:127.0.0.1]"},
{net: "tcp6", laddr: "", ipv6: true},
{net: "tcp6", laddr: "[::]", ipv6: true},
{net: "tcp6", laddr: "[::1]", ipv6: true},
{net: "unix", laddr: "@gotest/net", linux: true},
{net: "unixpacket", laddr: "@gotest/net", linux: true},
}
func TestFileListener(t *testing.T) { func TestFileListener(t *testing.T) {
if runtime.GOOS == "windows" || runtime.GOOS == "plan9" { switch runtime.GOOS {
case "plan9", "windows":
t.Logf("skipping test on %q", runtime.GOOS)
return return
} }
testFileListener(t, "tcp", "127.0.0.1")
testFileListener(t, "tcp", "127.0.0.1") for _, tt := range fileListenerTests {
if supportsIPv6 && supportsIPv4map { if skipServerTest(tt.net, "unix", tt.laddr, tt.ipv6, false, tt.linux) {
testFileListener(t, "tcp", "[::ffff:127.0.0.1]") continue
testFileListener(t, "tcp", "127.0.0.1") }
testFileListener(t, "tcp", "[::ffff:127.0.0.1]") if skipServerTest(tt.net, "unixpacket", tt.laddr, tt.ipv6, false, tt.linux) {
} continue
if runtime.GOOS == "linux" { }
testFileListener(t, "unix", "@gotest/net") testFileListener(t, tt.net, tt.laddr)
testFileListener(t, "unixpacket", "@gotest/net")
} }
} }
@ -98,9 +131,13 @@ func testFilePacketConn(t *testing.T, pcf packetConnFile, listen bool) {
} }
func testFilePacketConnListen(t *testing.T, net, laddr string) { func testFilePacketConnListen(t *testing.T, net, laddr string) {
switch net {
case "udp", "udp4", "udp6":
laddr += ":0" // any available port
}
l, err := ListenPacket(net, laddr) l, err := ListenPacket(net, laddr)
if err != nil { if err != nil {
t.Fatalf("Listen failed: %v", err) t.Fatalf("ListenPacket failed: %v", err)
} }
testFilePacketConn(t, l.(packetConnFile), true) testFilePacketConn(t, l.(packetConnFile), true)
if err := l.Close(); err != nil { if err := l.Close(); err != nil {
@ -109,6 +146,10 @@ func testFilePacketConnListen(t *testing.T, net, laddr string) {
} }
func testFilePacketConnDial(t *testing.T, net, raddr string) { func testFilePacketConnDial(t *testing.T, net, raddr string) {
switch net {
case "udp", "udp4", "udp6":
raddr += ":12345"
}
c, err := Dial(net, raddr) c, err := Dial(net, raddr)
if err != nil { if err != nil {
t.Fatalf("Dial failed: %v", err) t.Fatalf("Dial failed: %v", err)
@ -119,19 +160,42 @@ func testFilePacketConnDial(t *testing.T, net, raddr string) {
} }
} }
var filePacketConnTests = []struct {
net string
addr string
ipv6 bool // test with underlying AF_INET6 socket
linux bool // test with abstract unix domain socket, a Linux-ism
}{
{net: "udp", addr: "127.0.0.1"},
{net: "udp", addr: "[::ffff:127.0.0.1]"},
{net: "udp", addr: "[::1]", ipv6: true},
{net: "udp4", addr: "127.0.0.1"},
{net: "udp4", addr: "[::ffff:127.0.0.1]"},
{net: "udp6", addr: "[::1]", ipv6: true},
{net: "unixgram", addr: "@gotest3/net", linux: true},
}
func TestFilePacketConn(t *testing.T) { func TestFilePacketConn(t *testing.T) {
if runtime.GOOS == "windows" || runtime.GOOS == "plan9" { switch runtime.GOOS {
case "plan9", "windows":
t.Logf("skipping test on %q", runtime.GOOS)
return return
} }
testFilePacketConnListen(t, "udp", "127.0.0.1:0")
testFilePacketConnDial(t, "udp", "127.0.0.1:12345") for _, tt := range filePacketConnTests {
if supportsIPv6 { if skipServerTest(tt.net, "unixgram", tt.addr, tt.ipv6, false, tt.linux) {
testFilePacketConnListen(t, "udp", "[::1]:0") continue
} }
if supportsIPv6 && supportsIPv4map { testFilePacketConnListen(t, tt.net, tt.addr)
testFilePacketConnDial(t, "udp", "[::ffff:127.0.0.1]:12345") switch tt.addr {
} case "", "0.0.0.0", "[::ffff:0.0.0.0]", "[::]":
if runtime.GOOS == "linux" { default:
testFilePacketConnListen(t, "unixgram", "@gotest1/net") if tt.net != "unixgram" {
testFilePacketConnDial(t, tt.net, tt.addr)
}
}
} }
} }

View File

@ -238,9 +238,9 @@ func TestRedirects(t *testing.T) {
} }
var expectedCookies = []*Cookie{ var expectedCookies = []*Cookie{
&Cookie{Name: "ChocolateChip", Value: "tasty"}, {Name: "ChocolateChip", Value: "tasty"},
&Cookie{Name: "First", Value: "Hit"}, {Name: "First", Value: "Hit"},
&Cookie{Name: "Second", Value: "Hit"}, {Name: "Second", Value: "Hit"},
} }
var echoCookiesRedirectHandler = HandlerFunc(func(w ResponseWriter, r *Request) { var echoCookiesRedirectHandler = HandlerFunc(func(w ResponseWriter, r *Request) {

View File

@ -455,11 +455,13 @@ func ReadRequest(b *bufio.Reader) (req *Request, err error) {
// First line: GET /index.html HTTP/1.0 // First line: GET /index.html HTTP/1.0
var s string var s string
if s, err = tp.ReadLine(); err != nil { if s, err = tp.ReadLine(); err != nil {
return nil, err
}
defer func() {
if err == io.EOF { if err == io.EOF {
err = io.ErrUnexpectedEOF err = io.ErrUnexpectedEOF
} }
return nil, err }()
}
var f []string var f []string
if f = strings.SplitN(s, " ", 3); len(f) < 3 { if f = strings.SplitN(s, " ", 3); len(f) < 3 {

View File

@ -5,6 +5,7 @@
package http_test package http_test
import ( import (
"bufio"
"bytes" "bytes"
"fmt" "fmt"
"io" "io"
@ -177,6 +178,24 @@ func TestRequestMultipartCallOrder(t *testing.T) {
} }
} }
var readRequestErrorTests = []struct {
in string
err error
}{
{"GET / HTTP/1.1\r\nheader:foo\r\n\r\n", nil},
{"GET / HTTP/1.1\r\nheader:foo\r\n", io.ErrUnexpectedEOF},
{"", io.EOF},
}
func TestReadRequestErrors(t *testing.T) {
for i, tt := range readRequestErrorTests {
_, err := ReadRequest(bufio.NewReader(strings.NewReader(tt.in)))
if err != tt.err {
t.Errorf("%d. got error = %v; want %v", i, err, tt.err)
}
}
}
func testMissingFile(t *testing.T, req *Request) { func testMissingFile(t *testing.T, req *Request) {
f, fh, err := req.FormFile("missing") f, fh, err := req.FormFile("missing")
if f != nil { if f != nil {

View File

@ -601,7 +601,7 @@ func (c *conn) serve() {
// while they're still writing their // while they're still writing their
// request. Undefined behavior. // request. Undefined behavior.
msg = "413 Request Entity Too Large" msg = "413 Request Entity Too Large"
} else if err == io.ErrUnexpectedEOF { } else if err == io.EOF {
break // Don't reply break // Don't reply
} else if neterr, ok := err.(net.Error); ok && neterr.Timeout() { } else if neterr, ok := err.(net.Error); ok && neterr.Timeout() {
break // Don't reply break // Don't reply

View File

@ -196,7 +196,7 @@ func (t *Transport) CloseIdleConnections() {
pconn.close() pconn.close()
} }
} }
t.idleConn = nil t.idleConn = make(map[string][]*persistConn)
} }
// //

View File

@ -698,6 +698,32 @@ func TestTransportPersistConnLeak(t *testing.T) {
} }
} }
// This used to crash; http://golang.org/issue/3266
func TestTransportIdleConnCrash(t *testing.T) {
tr := &Transport{}
c := &Client{Transport: tr}
unblockCh := make(chan bool, 1)
ts := httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) {
<-unblockCh
tr.CloseIdleConnections()
}))
defer ts.Close()
didreq := make(chan bool)
go func() {
res, err := c.Get(ts.URL)
if err != nil {
t.Error(err)
} else {
res.Body.Close() // returns idle conn
}
didreq <- true
}()
unblockCh <- true
<-didreq
}
type fooProto struct{} type fooProto struct{}
func (fooProto) RoundTrip(req *Request) (*Response, error) { func (fooProto) RoundTrip(req *Request) (*Response, error) {

View File

@ -78,7 +78,7 @@ func (ifi *Interface) MulticastAddrs() ([]Addr, error) {
return interfaceMulticastAddrTable(ifi.Index) return interfaceMulticastAddrTable(ifi.Index)
} }
// Interfaces returns a list of the systems's network interfaces. // Interfaces returns a list of the system's network interfaces.
func Interfaces() ([]Interface, error) { func Interfaces() ([]Interface, error) {
return interfaceTable(0) return interfaceTable(0)
} }

View File

@ -7,7 +7,6 @@
package net package net
import ( import (
"fmt"
"os" "os"
"syscall" "syscall"
"unsafe" "unsafe"
@ -194,7 +193,9 @@ func parseProcNetIGMP(path string, ifi *Interface) []Addr {
name = f[1] name = f[1]
case len(f[0]) == 8: case len(f[0]) == 8:
if ifi == nil || name == ifi.Name { if ifi == nil || name == ifi.Name {
fmt.Sscanf(f[0], "%08x", &b) for i := 0; i+1 < len(f[0]); i += 2 {
b[i/2], _ = xtoi2(f[0][i:i+2], 0)
}
ifma := IPAddr{IP: IPv4(b[3], b[2], b[1], b[0])} ifma := IPAddr{IP: IPv4(b[3], b[2], b[1], b[0])}
ifmat = append(ifmat, ifma.toAddr()) ifmat = append(ifmat, ifma.toAddr())
} }
@ -218,10 +219,11 @@ func parseProcNetIGMP6(path string, ifi *Interface) []Addr {
continue continue
} }
if ifi == nil || f[1] == ifi.Name { if ifi == nil || f[1] == ifi.Name {
fmt.Sscanf(f[2], "%32x", &b) for i := 0; i+1 < len(f[2]); i += 2 {
b[i/2], _ = xtoi2(f[2][i:i+2], 0)
}
ifma := IPAddr{IP: IP{b[0], b[1], b[2], b[3], b[4], b[5], b[6], b[7], b[8], b[9], b[10], b[11], b[12], b[13], b[14], b[15]}} ifma := IPAddr{IP: IP{b[0], b[1], b[2], b[3], b[4], b[5], b[6], b[7], b[8], b[9], b[10], b[11], b[12], b[13], b[14], b[15]}}
ifmat = append(ifmat, ifma.toAddr()) ifmat = append(ifmat, ifma.toAddr())
} }
} }
return ifmat return ifmat

View File

@ -34,6 +34,13 @@ func (a *IPAddr) family() int {
return syscall.AF_INET6 return syscall.AF_INET6
} }
func (a *IPAddr) isWildcard() bool {
if a == nil || a.IP == nil {
return true
}
return a.IP.IsUnspecified()
}
func (a *IPAddr) sockaddr(family int) (syscall.Sockaddr, error) { func (a *IPAddr) sockaddr(family int) (syscall.Sockaddr, error) {
return ipToSockaddr(family, a.IP, 0) return ipToSockaddr(family, a.IP, 0)
} }

View File

@ -38,6 +38,7 @@ func probeIPv6Stack() (supportsIPv6, supportsIPv4map bool) {
continue continue
} }
defer closesocket(s) defer closesocket(s)
syscall.SetsockoptInt(s, syscall.IPPROTO_IPV6, syscall.IPV6_V6ONLY, 0)
sa, err := probes[i].la.toAddr().sockaddr(syscall.AF_INET6) sa, err := probes[i].la.toAddr().sockaddr(syscall.AF_INET6)
if err != nil { if err != nil {
continue continue
@ -55,58 +56,75 @@ func probeIPv6Stack() (supportsIPv6, supportsIPv4map bool) {
// favoriteAddrFamily returns the appropriate address family to // favoriteAddrFamily returns the appropriate address family to
// the given net, laddr, raddr and mode. At first it figures // the given net, laddr, raddr and mode. At first it figures
// address family out from the net. If mode indicates "listen" // address family out from the net. If mode indicates "listen"
// and laddr.(type).IP is nil, it assumes that the user wants to // and laddr is a wildcard, it assumes that the user wants to
// make a passive connection with wildcard address family, both // make a passive connection with a wildcard address family, both
// INET and INET6, and wildcard address. Otherwise guess: if the // AF_INET and AF_INET6, and a wildcard address like following:
// addresses are IPv4 then returns INET, or else returns INET6. //
func favoriteAddrFamily(net string, laddr, raddr sockaddr, mode string) int { // 1. A wild-wild listen, "tcp" + ""
// If the platform supports both IPv6 and IPv6 IPv4-mapping
// capabilities, we assume that the user want to listen on
// both IPv4 and IPv6 wildcard address over an AF_INET6
// socket with IPV6_V6ONLY=0. Otherwise we prefer an IPv4
// wildcard address listen over an AF_INET socket.
//
// 2. A wild-ipv4wild listen, "tcp" + "0.0.0.0"
// Same as 1.
//
// 3. A wild-ipv6wild listen, "tcp" + "[::]"
// Almost same as 1 but we prefer an IPv6 wildcard address
// listen over an AF_INET6 socket with IPV6_V6ONLY=0 when
// the platform supports IPv6 capability but not IPv6 IPv4-
// mapping capability.
//
// 4. A ipv4-ipv4wild listen, "tcp4" + "" or "0.0.0.0"
// We use an IPv4 (AF_INET) wildcard address listen.
//
// 5. A ipv6-ipv6wild listen, "tcp6" + "" or "[::]"
// We use an IPv6 (AF_INET6, IPV6_V6ONLY=1) wildcard address
// listen.
//
// Otherwise guess: if the addresses are IPv4 then returns AF_INET,
// or else returns AF_INET6. It also returns a boolean value what
// designates IPV6_V6ONLY option.
//
// Note that OpenBSD allows neither "net.inet6.ip6.v6only=1" change
// nor IPPROTO_IPV6 level IPV6_V6ONLY socket option setting.
func favoriteAddrFamily(net string, laddr, raddr sockaddr, mode string) (family int, ipv6only bool) {
switch net[len(net)-1] { switch net[len(net)-1] {
case '4': case '4':
return syscall.AF_INET return syscall.AF_INET, false
case '6': case '6':
return syscall.AF_INET6 return syscall.AF_INET6, true
} }
if mode == "listen" { if mode == "listen" && laddr.isWildcard() {
// Note that OpenBSD allows neither "net.inet6.ip6.v6only" if supportsIPv4map {
// change nor IPPROTO_IPV6 level IPV6_V6ONLY socket option return syscall.AF_INET6, false
// setting.
switch a := laddr.(type) {
case *TCPAddr:
if a.IP == nil && supportsIPv6 && supportsIPv4map {
return syscall.AF_INET6
}
case *UDPAddr:
if a.IP == nil && supportsIPv6 && supportsIPv4map {
return syscall.AF_INET6
}
case *IPAddr:
if a.IP == nil && supportsIPv6 && supportsIPv4map {
return syscall.AF_INET6
}
} }
return laddr.family(), false
} }
if (laddr == nil || laddr.family() == syscall.AF_INET) && if (laddr == nil || laddr.family() == syscall.AF_INET) &&
(raddr == nil || raddr.family() == syscall.AF_INET) { (raddr == nil || raddr.family() == syscall.AF_INET) {
return syscall.AF_INET return syscall.AF_INET, false
} }
return syscall.AF_INET6 return syscall.AF_INET6, false
} }
// Internet sockets (TCP, UDP) // Internet sockets (TCP, UDP, IP)
// A sockaddr represents a TCP or UDP network address that can // A sockaddr represents a TCP, UDP or IP network address that can
// be converted into a syscall.Sockaddr. // be converted into a syscall.Sockaddr.
type sockaddr interface { type sockaddr interface {
Addr Addr
sockaddr(family int) (syscall.Sockaddr, error)
family() int family() int
isWildcard() bool
sockaddr(family int) (syscall.Sockaddr, error)
} }
func internetSocket(net string, laddr, raddr sockaddr, sotype, proto int, mode string, toAddr func(syscall.Sockaddr) Addr) (fd *netFD, err error) { func internetSocket(net string, laddr, raddr sockaddr, sotype, proto int, mode string, toAddr func(syscall.Sockaddr) Addr) (fd *netFD, err error) {
var la, ra syscall.Sockaddr var la, ra syscall.Sockaddr
family := favoriteAddrFamily(net, laddr, raddr, mode) family, ipv6only := favoriteAddrFamily(net, laddr, raddr, mode)
if laddr != nil { if laddr != nil {
if la, err = laddr.sockaddr(family); err != nil { if la, err = laddr.sockaddr(family); err != nil {
goto Error goto Error
@ -117,7 +135,7 @@ func internetSocket(net string, laddr, raddr sockaddr, sotype, proto int, mode s
goto Error goto Error
} }
} }
fd, err = socket(net, family, sotype, proto, la, ra, toAddr) fd, err = socket(net, family, sotype, proto, ipv6only, la, ra, toAddr)
if err != nil { if err != nil {
goto Error goto Error
} }
@ -152,7 +170,7 @@ func ipToSockaddr(family int, ip IP, port int) (syscall.Sockaddr, error) {
} }
// IPv4 callers use 0.0.0.0 to mean "announce on any available address". // IPv4 callers use 0.0.0.0 to mean "announce on any available address".
// In IPv6 mode, Linux treats that as meaning "announce on 0.0.0.0", // In IPv6 mode, Linux treats that as meaning "announce on 0.0.0.0",
// which it refuses to do. Rewrite to the IPv6 all zeros. // which it refuses to do. Rewrite to the IPv6 unspecified address.
if ip.Equal(IPv4zero) { if ip.Equal(IPv4zero) {
ip = IPv6zero ip = IPv6zero
} }

View File

@ -6,24 +6,26 @@
package net package net
import ( import "errors"
"bytes"
"errors" const hexDigit = "0123456789abcdef"
"fmt"
)
// A HardwareAddr represents a physical hardware address. // A HardwareAddr represents a physical hardware address.
type HardwareAddr []byte type HardwareAddr []byte
func (a HardwareAddr) String() string { func (a HardwareAddr) String() string {
var buf bytes.Buffer if len(a) == 0 {
return ""
}
buf := make([]byte, 0, len(a)*3-1)
for i, b := range a { for i, b := range a {
if i > 0 { if i > 0 {
buf.WriteByte(':') buf = append(buf, ':')
} }
fmt.Fprintf(&buf, "%02x", b) buf = append(buf, hexDigit[b>>4])
buf = append(buf, hexDigit[b&0xF])
} }
return buf.String() return string(buf)
} }
// ParseMAC parses s as an IEEE 802 MAC-48, EUI-48, or EUI-64 using one of the // ParseMAC parses s as an IEEE 802 MAC-48, EUI-48, or EUI-64 using one of the

View File

@ -43,12 +43,24 @@ func match(err error, s string) bool {
return err != nil && strings.Contains(err.Error(), s) return err != nil && strings.Contains(err.Error(), s)
} }
func TestParseMAC(t *testing.T) { func TestMACParseString(t *testing.T) {
for _, tt := range mactests { for i, tt := range mactests {
out, err := ParseMAC(tt.in) out, err := ParseMAC(tt.in)
if !reflect.DeepEqual(out, tt.out) || !match(err, tt.err) { if !reflect.DeepEqual(out, tt.out) || !match(err, tt.err) {
t.Errorf("ParseMAC(%q) = %v, %v, want %v, %v", tt.in, out, err, tt.out, t.Errorf("ParseMAC(%q) = %v, %v, want %v, %v", tt.in, out, err, tt.out,
tt.err) tt.err)
} }
if tt.err == "" {
// Verify that serialization works too, and that it round-trips.
s := out.String()
out2, err := ParseMAC(s)
if err != nil {
t.Errorf("%d. ParseMAC(%q) = %v", i, s, err)
continue
}
if !reflect.DeepEqual(out2, out) {
t.Errorf("%d. ParseMAC(%q) = %v, want %v", i, s, out2, out)
}
}
} }
} }

View File

@ -394,8 +394,7 @@ func (p *addrParser) consumeAtom(dot bool) (atom string, err error) {
i := 1 i := 1
for ; i < p.len() && isAtext((*p)[i], dot); i++ { for ; i < p.len() && isAtext((*p)[i], dot); i++ {
} }
// TODO(dsymonds): Remove the []byte() conversion here when 6g doesn't need it. atom, *p = string((*p)[:i]), (*p)[i:]
atom, *p = string([]byte((*p)[:i])), (*p)[i:]
return atom, nil return atom, nil
} }

View File

@ -47,9 +47,11 @@ var multicastListenerTests = []struct {
func TestMulticastListener(t *testing.T) { func TestMulticastListener(t *testing.T) {
switch runtime.GOOS { switch runtime.GOOS {
case "netbsd", "openbsd", "plan9", "windows": case "netbsd", "openbsd", "plan9", "windows":
t.Logf("skipping test on %q", runtime.GOOS)
return return
case "linux": case "linux":
if runtime.GOARCH == "arm" || runtime.GOARCH == "alpha" { if runtime.GOARCH == "arm" || runtime.GOARCH == "alpha" {
t.Logf("skipping test on %q/%q", runtime.GOOS, runtime.GOARCH)
return return
} }
} }
@ -86,7 +88,13 @@ func TestMulticastListener(t *testing.T) {
func TestSimpleMulticastListener(t *testing.T) { func TestSimpleMulticastListener(t *testing.T) {
switch runtime.GOOS { switch runtime.GOOS {
case "plan9": case "plan9":
t.Logf("skipping test on %q", runtime.GOOS)
return return
case "windows":
if testing.Short() || !*testExternal {
t.Logf("skipping test on windows to avoid firewall")
return
}
} }
for _, tt := range multicastListenerTests { for _, tt := range multicastListenerTests {

View File

@ -54,6 +54,8 @@ type Addr interface {
} }
// Conn is a generic stream-oriented network connection. // Conn is a generic stream-oriented network connection.
//
// Multiple goroutines may invoke methods on a Conn simultaneously.
type Conn interface { type Conn interface {
// Read reads data from the connection. // Read reads data from the connection.
// Read can be made to time out and return a Error with Timeout() == true // Read can be made to time out and return a Error with Timeout() == true
@ -66,6 +68,7 @@ type Conn interface {
Write(b []byte) (n int, err error) Write(b []byte) (n int, err error)
// Close closes the connection. // Close closes the connection.
// Any blocked Read or Write operations will be unblocked and return errors.
Close() error Close() error
// LocalAddr returns the local network address. // LocalAddr returns the local network address.
@ -89,11 +92,11 @@ type Conn interface {
// A zero value for t means I/O operations will not time out. // A zero value for t means I/O operations will not time out.
SetDeadline(t time.Time) error SetDeadline(t time.Time) error
// SetReadDeadline sets the deadline for Read calls. // SetReadDeadline sets the deadline for future Read calls.
// A zero value for t means Read will not time out. // A zero value for t means Read will not time out.
SetReadDeadline(t time.Time) error SetReadDeadline(t time.Time) error
// SetWriteDeadline sets the deadline for Write calls. // SetWriteDeadline sets the deadline for future Write calls.
// Even if write times out, it may return n > 0, indicating that // Even if write times out, it may return n > 0, indicating that
// some of the data was successfully written. // some of the data was successfully written.
// A zero value for t means Write will not time out. // A zero value for t means Write will not time out.
@ -108,6 +111,8 @@ type Error interface {
} }
// PacketConn is a generic packet-oriented network connection. // PacketConn is a generic packet-oriented network connection.
//
// Multiple goroutines may invoke methods on a PacketConn simultaneously.
type PacketConn interface { type PacketConn interface {
// ReadFrom reads a packet from the connection, // ReadFrom reads a packet from the connection,
// copying the payload into b. It returns the number of // copying the payload into b. It returns the number of
@ -126,6 +131,7 @@ type PacketConn interface {
WriteTo(b []byte, addr Addr) (n int, err error) WriteTo(b []byte, addr Addr) (n int, err error)
// Close closes the connection. // Close closes the connection.
// Any blocked ReadFrom or WriteTo operations will be unblocked and return errors.
Close() error Close() error
// LocalAddr returns the local network address. // LocalAddr returns the local network address.
@ -135,13 +141,13 @@ type PacketConn interface {
// with the connection. // with the connection.
SetDeadline(t time.Time) error SetDeadline(t time.Time) error
// SetReadDeadline sets the deadline for all Read calls to return. // SetReadDeadline sets the deadline for future Read calls.
// If the deadline is reached, Read will fail with a timeout // If the deadline is reached, Read will fail with a timeout
// (see type Error) instead of blocking. // (see type Error) instead of blocking.
// A zero value for t means Read will not time out. // A zero value for t means Read will not time out.
SetReadDeadline(t time.Time) error SetReadDeadline(t time.Time) error
// SetWriteDeadline sets the deadline for all Write calls to return. // SetWriteDeadline sets the deadline for future Write calls.
// If the deadline is reached, Write will fail with a timeout // If the deadline is reached, Write will fail with a timeout
// (see type Error) instead of blocking. // (see type Error) instead of blocking.
// A zero value for t means Write will not time out. // A zero value for t means Write will not time out.
@ -151,11 +157,14 @@ type PacketConn interface {
} }
// A Listener is a generic network listener for stream-oriented protocols. // A Listener is a generic network listener for stream-oriented protocols.
//
// Multiple goroutines may invoke methods on a Listener simultaneously.
type Listener interface { type Listener interface {
// Accept waits for and returns the next connection to the listener. // Accept waits for and returns the next connection to the listener.
Accept() (c Conn, err error) Accept() (c Conn, err error)
// Close closes the listener. // Close closes the listener.
// Any blocked Accept operations will be unblocked and return errors.
Close() error Close() error
// Addr returns the listener's network address. // Addr returns the listener's network address.

View File

@ -13,6 +13,7 @@ import (
func TestShutdown(t *testing.T) { func TestShutdown(t *testing.T) {
if runtime.GOOS == "plan9" { if runtime.GOOS == "plan9" {
t.Logf("skipping test on %q", runtime.GOOS)
return return
} }
l, err := Listen("tcp", "127.0.0.1:0") l, err := Listen("tcp", "127.0.0.1:0")

View File

@ -13,7 +13,9 @@ import (
func TestReadLine(t *testing.T) { func TestReadLine(t *testing.T) {
// /etc/services file does not exist on windows and Plan 9. // /etc/services file does not exist on windows and Plan 9.
if runtime.GOOS == "windows" || runtime.GOOS == "plan9" { switch runtime.GOOS {
case "plan9", "windows":
t.Logf("skipping test on %q", runtime.GOOS)
return return
} }
filename := "/etc/services" // a nice big file filename := "/etc/services" // a nice big file

View File

@ -36,7 +36,8 @@ type Call struct {
// Client represents an RPC Client. // Client represents an RPC Client.
// There may be multiple outstanding Calls associated // There may be multiple outstanding Calls associated
// with a single Client. // with a single Client, and a Client may be used by
// multiple goroutines simultaneously.
type Client struct { type Client struct {
mutex sync.Mutex // protects pending, seq, request mutex sync.Mutex // protects pending, seq, request
sending sync.Mutex sending sync.Mutex

View File

@ -9,234 +9,461 @@ import (
"io" "io"
"os" "os"
"runtime" "runtime"
"strings"
"testing" "testing"
"time" "time"
) )
func skipServerTest(net, unixsotype, addr string, ipv6, ipv4map, linuxonly bool) bool {
switch runtime.GOOS {
case "linux":
case "plan9", "windows":
// "unix" sockets are not supported on Windows and Plan 9.
if net == unixsotype {
return true
}
default:
if net == unixsotype && linuxonly {
return true
}
}
switch addr {
case "", "0.0.0.0", "[::ffff:0.0.0.0]", "[::]":
if testing.Short() || !*testExternal {
return true
}
}
if ipv6 && !supportsIPv6 {
return true
}
if ipv4map && !supportsIPv4map {
return true
}
return false
}
var streamConnServerTests = []struct {
snet string // server side
saddr string
cnet string // client side
caddr string
ipv6 bool // test with underlying AF_INET6 socket
ipv4map bool // test with IPv6 IPv4-mapping functionality
empty bool // test with empty data
linux bool // test with abstract unix domain socket, a Linux-ism
}{
{snet: "tcp", saddr: "", cnet: "tcp", caddr: "127.0.0.1"},
{snet: "tcp", saddr: "0.0.0.0", cnet: "tcp", caddr: "127.0.0.1"},
{snet: "tcp", saddr: "[::ffff:0.0.0.0]", cnet: "tcp", caddr: "127.0.0.1"},
{snet: "tcp", saddr: "[::]", cnet: "tcp", caddr: "[::1]", ipv6: true},
{snet: "tcp", saddr: "", cnet: "tcp", caddr: "[::1]", ipv4map: true},
{snet: "tcp", saddr: "0.0.0.0", cnet: "tcp", caddr: "[::1]", ipv4map: true},
{snet: "tcp", saddr: "[::ffff:0.0.0.0]", cnet: "tcp", caddr: "[::1]", ipv4map: true},
{snet: "tcp", saddr: "[::]", cnet: "tcp", caddr: "127.0.0.1", ipv4map: true},
{snet: "tcp", saddr: "", cnet: "tcp4", caddr: "127.0.0.1"},
{snet: "tcp", saddr: "0.0.0.0", cnet: "tcp4", caddr: "127.0.0.1"},
{snet: "tcp", saddr: "[::ffff:0.0.0.0]", cnet: "tcp4", caddr: "127.0.0.1"},
{snet: "tcp", saddr: "[::]", cnet: "tcp6", caddr: "[::1]", ipv6: true},
{snet: "tcp", saddr: "", cnet: "tcp6", caddr: "[::1]", ipv4map: true},
{snet: "tcp", saddr: "0.0.0.0", cnet: "tcp6", caddr: "[::1]", ipv4map: true},
{snet: "tcp", saddr: "[::ffff:0.0.0.0]", cnet: "tcp6", caddr: "[::1]", ipv4map: true},
{snet: "tcp", saddr: "[::]", cnet: "tcp4", caddr: "127.0.0.1", ipv4map: true},
{snet: "tcp", saddr: "127.0.0.1", cnet: "tcp", caddr: "127.0.0.1"},
{snet: "tcp", saddr: "[::ffff:127.0.0.1]", cnet: "tcp", caddr: "127.0.0.1"},
{snet: "tcp", saddr: "[::1]", cnet: "tcp", caddr: "[::1]", ipv6: true},
{snet: "tcp4", saddr: "", cnet: "tcp4", caddr: "127.0.0.1"},
{snet: "tcp4", saddr: "0.0.0.0", cnet: "tcp4", caddr: "127.0.0.1"},
{snet: "tcp4", saddr: "[::ffff:0.0.0.0]", cnet: "tcp4", caddr: "127.0.0.1"},
{snet: "tcp4", saddr: "127.0.0.1", cnet: "tcp4", caddr: "127.0.0.1"},
{snet: "tcp6", saddr: "", cnet: "tcp6", caddr: "[::1]", ipv6: true},
{snet: "tcp6", saddr: "[::]", cnet: "tcp6", caddr: "[::1]", ipv6: true},
{snet: "tcp6", saddr: "[::1]", cnet: "tcp6", caddr: "[::1]", ipv6: true},
{snet: "unix", saddr: "/tmp/gotest1.net", cnet: "unix", caddr: "/tmp/gotest1.net.local"},
{snet: "unix", saddr: "@gotest2/net", cnet: "unix", caddr: "@gotest2/net.local", linux: true},
}
func TestStreamConnServer(t *testing.T) {
for _, tt := range streamConnServerTests {
if skipServerTest(tt.snet, "unix", tt.saddr, tt.ipv6, tt.ipv4map, tt.linux) {
continue
}
listening := make(chan string)
done := make(chan int)
switch tt.snet {
case "tcp", "tcp4", "tcp6":
tt.saddr += ":0"
case "unix":
os.Remove(tt.saddr)
os.Remove(tt.caddr)
}
go runStreamConnServer(t, tt.snet, tt.saddr, listening, done)
taddr := <-listening // wait for server to start
switch tt.cnet {
case "tcp", "tcp4", "tcp6":
_, port, err := SplitHostPort(taddr)
if err != nil {
t.Errorf("SplitHostPort(%q) failed: %v", taddr, err)
return
}
taddr = tt.caddr + ":" + port
}
runStreamConnClient(t, tt.cnet, taddr, tt.empty)
<-done // make sure server stopped
switch tt.snet {
case "unix":
os.Remove(tt.saddr)
os.Remove(tt.caddr)
}
}
}
var seqpacketConnServerTests = []struct {
net string
saddr string // server address
caddr string // client address
empty bool // test with empty data
}{
{net: "unixpacket", saddr: "/tmp/gotest3.net", caddr: "/tmp/gotest3.net.local"},
{net: "unixpacket", saddr: "@gotest4/net", caddr: "@gotest4/net.local"},
}
func TestSeqpacketConnServer(t *testing.T) {
if runtime.GOOS != "linux" {
t.Logf("skipping test on %q", runtime.GOOS)
return
}
for _, tt := range seqpacketConnServerTests {
listening := make(chan string)
done := make(chan int)
switch tt.net {
case "unixpacket":
os.Remove(tt.saddr)
os.Remove(tt.caddr)
}
go runStreamConnServer(t, tt.net, tt.saddr, listening, done)
taddr := <-listening // wait for server to start
runStreamConnClient(t, tt.net, taddr, tt.empty)
<-done // make sure server stopped
switch tt.net {
case "unixpacket":
os.Remove(tt.saddr)
os.Remove(tt.caddr)
}
}
}
func runStreamConnServer(t *testing.T, net, laddr string, listening chan<- string, done chan<- int) {
l, err := Listen(net, laddr)
if err != nil {
t.Errorf("Listen(%q, %q) failed: %v", net, laddr, err)
listening <- "<nil>"
done <- 1
return
}
defer l.Close()
listening <- l.Addr().String()
echo := func(rw io.ReadWriter, done chan<- int) {
buf := make([]byte, 1024)
for {
n, err := rw.Read(buf[0:])
if err != nil || n == 0 || string(buf[:n]) == "END" {
break
}
rw.Write(buf[0:n])
}
done <- 1
}
run:
for {
c, err := l.Accept()
if err != nil {
continue run
}
echodone := make(chan int)
go echo(c, echodone)
<-echodone // make sure echo stopped
c.Close()
break run
}
done <- 1
}
func runStreamConnClient(t *testing.T, net, taddr string, isEmpty bool) {
c, err := Dial(net, taddr)
if err != nil {
t.Errorf("Dial(%q, %q) failed: %v", net, taddr, err)
return
}
defer c.Close()
c.SetReadDeadline(time.Now().Add(1 * time.Second))
var wb []byte
if !isEmpty {
wb = []byte("StreamConnClient by Dial\n")
}
if n, err := c.Write(wb); err != nil || n != len(wb) {
t.Errorf("Write failed: %v, %v; want %v, <nil>", n, err, len(wb))
return
}
rb := make([]byte, 1024)
if n, err := c.Read(rb[0:]); err != nil || n != len(wb) {
t.Errorf("Read failed: %v, %v; want %v, <nil>", n, err, len(wb))
return
}
// Send explicit ending for unixpacket.
// Older Linux kernels do not stop reads on close.
switch net {
case "unixpacket":
c.Write([]byte("END"))
}
}
// Do not test empty datagrams by default. // Do not test empty datagrams by default.
// It causes unexplained timeouts on some systems, // It causes unexplained timeouts on some systems,
// including Snow Leopard. I think that the kernel // including Snow Leopard. I think that the kernel
// doesn't quite expect them. // doesn't quite expect them.
var testUDP = flag.Bool("udp", false, "whether to test UDP datagrams") var testDatagram = flag.Bool("datagram", false, "whether to test udp and unixgram")
func runEcho(fd io.ReadWriter, done chan<- int) { var datagramPacketConnServerTests = []struct {
var buf [1024]byte snet string // server side
saddr string
cnet string // client side
caddr string
ipv6 bool // test with underlying AF_INET6 socket
ipv4map bool // test with IPv6 IPv4-mapping functionality
dial bool // test with Dial or DialUnix
empty bool // test with empty data
linux bool // test with abstract unix domain socket, a Linux-ism
}{
{snet: "udp", saddr: "", cnet: "udp", caddr: "127.0.0.1"},
{snet: "udp", saddr: "0.0.0.0", cnet: "udp", caddr: "127.0.0.1"},
{snet: "udp", saddr: "[::ffff:0.0.0.0]", cnet: "udp", caddr: "127.0.0.1"},
{snet: "udp", saddr: "[::]", cnet: "udp", caddr: "[::1]", ipv6: true},
for { {snet: "udp", saddr: "", cnet: "udp", caddr: "[::1]", ipv4map: true},
n, err := fd.Read(buf[0:]) {snet: "udp", saddr: "0.0.0.0", cnet: "udp", caddr: "[::1]", ipv4map: true},
if err != nil || n == 0 || string(buf[:n]) == "END" { {snet: "udp", saddr: "[::ffff:0.0.0.0]", cnet: "udp", caddr: "[::1]", ipv4map: true},
break {snet: "udp", saddr: "[::]", cnet: "udp", caddr: "127.0.0.1", ipv4map: true},
}
fd.Write(buf[0:n]) {snet: "udp", saddr: "", cnet: "udp4", caddr: "127.0.0.1"},
} {snet: "udp", saddr: "0.0.0.0", cnet: "udp4", caddr: "127.0.0.1"},
done <- 1 {snet: "udp", saddr: "[::ffff:0.0.0.0]", cnet: "udp4", caddr: "127.0.0.1"},
{snet: "udp", saddr: "[::]", cnet: "udp6", caddr: "[::1]", ipv6: true},
{snet: "udp", saddr: "", cnet: "udp6", caddr: "[::1]", ipv4map: true},
{snet: "udp", saddr: "0.0.0.0", cnet: "udp6", caddr: "[::1]", ipv4map: true},
{snet: "udp", saddr: "[::ffff:0.0.0.0]", cnet: "udp6", caddr: "[::1]", ipv4map: true},
{snet: "udp", saddr: "[::]", cnet: "udp4", caddr: "127.0.0.1", ipv4map: true},
{snet: "udp", saddr: "127.0.0.1", cnet: "udp", caddr: "127.0.0.1"},
{snet: "udp", saddr: "[::ffff:127.0.0.1]", cnet: "udp", caddr: "127.0.0.1"},
{snet: "udp", saddr: "[::1]", cnet: "udp", caddr: "[::1]", ipv6: true},
{snet: "udp4", saddr: "", cnet: "udp4", caddr: "127.0.0.1"},
{snet: "udp4", saddr: "0.0.0.0", cnet: "udp4", caddr: "127.0.0.1"},
{snet: "udp4", saddr: "[::ffff:0.0.0.0]", cnet: "udp4", caddr: "127.0.0.1"},
{snet: "udp4", saddr: "127.0.0.1", cnet: "udp4", caddr: "127.0.0.1"},
{snet: "udp6", saddr: "", cnet: "udp6", caddr: "[::1]", ipv6: true},
{snet: "udp6", saddr: "[::]", cnet: "udp6", caddr: "[::1]", ipv6: true},
{snet: "udp6", saddr: "[::1]", cnet: "udp6", caddr: "[::1]", ipv6: true},
{snet: "udp", saddr: "127.0.0.1", cnet: "udp", caddr: "127.0.0.1", dial: true},
{snet: "udp", saddr: "127.0.0.1", cnet: "udp", caddr: "127.0.0.1", empty: true},
{snet: "udp", saddr: "127.0.0.1", cnet: "udp", caddr: "127.0.0.1", dial: true, empty: true},
{snet: "udp", saddr: "[::1]", cnet: "udp", caddr: "[::1]", ipv6: true, dial: true},
{snet: "udp", saddr: "[::1]", cnet: "udp", caddr: "[::1]", ipv6: true, empty: true},
{snet: "udp", saddr: "[::1]", cnet: "udp", caddr: "[::1]", ipv6: true, dial: true, empty: true},
{snet: "unixgram", saddr: "/tmp/gotest5.net", cnet: "unixgram", caddr: "/tmp/gotest5.net.local"},
{snet: "unixgram", saddr: "/tmp/gotest5.net", cnet: "unixgram", caddr: "/tmp/gotest5.net.local", dial: true},
{snet: "unixgram", saddr: "/tmp/gotest5.net", cnet: "unixgram", caddr: "/tmp/gotest5.net.local", empty: true},
{snet: "unixgram", saddr: "/tmp/gotest5.net", cnet: "unixgram", caddr: "/tmp/gotest5.net.local", dial: true, empty: true},
{snet: "unixgram", saddr: "@gotest6/net", cnet: "unixgram", caddr: "@gotest6/net.local", linux: true},
} }
func runServe(t *testing.T, network, addr string, listening chan<- string, done chan<- int) { func TestDatagramPacketConnServer(t *testing.T) {
l, err := Listen(network, addr) if !*testDatagram {
if err != nil {
t.Fatalf("net.Listen(%q, %q) = _, %v", network, addr, err)
}
listening <- l.Addr().String()
for {
fd, err := l.Accept()
if err != nil {
break
}
echodone := make(chan int)
go runEcho(fd, echodone)
<-echodone // make sure Echo stops
l.Close()
}
done <- 1
}
func connect(t *testing.T, network, addr string, isEmpty bool) {
var fd Conn
var err error
if network == "unixgram" {
fd, err = DialUnix(network, &UnixAddr{addr + ".local", network}, &UnixAddr{addr, network})
} else {
fd, err = Dial(network, addr)
}
if err != nil {
t.Fatalf("net.Dial(%q, %q) = _, %v", network, addr, err)
}
fd.SetReadDeadline(time.Now().Add(1 * time.Second))
var b []byte
if !isEmpty {
b = []byte("hello, world\n")
}
var b1 [100]byte
n, err1 := fd.Write(b)
if n != len(b) {
t.Fatalf("fd.Write(%q) = %d, %v", b, n, err1)
}
n, err1 = fd.Read(b1[0:])
if n != len(b) || err1 != nil {
t.Fatalf("fd.Read() = %d, %v (want %d, nil)", n, err1, len(b))
}
// Send explicit ending for unixpacket.
// Older Linux kernels do not stop reads on close.
if network == "unixpacket" {
fd.Write([]byte("END"))
}
fd.Close()
}
func doTest(t *testing.T, network, listenaddr, dialaddr string) {
t.Logf("Test %q %q %q", network, listenaddr, dialaddr)
switch listenaddr {
case "", "0.0.0.0", "[::]", "[::ffff:0.0.0.0]":
if testing.Short() || !*testExternal {
t.Logf("skip wildcard listen during short test")
return
}
}
listening := make(chan string)
done := make(chan int)
if network == "tcp" || network == "tcp4" || network == "tcp6" {
listenaddr += ":0" // any available port
}
go runServe(t, network, listenaddr, listening, done)
addr := <-listening // wait for server to start
if network == "tcp" || network == "tcp4" || network == "tcp6" {
dialaddr += addr[strings.LastIndex(addr, ":"):]
}
connect(t, network, dialaddr, false)
<-done // make sure server stopped
}
func TestTCPServer(t *testing.T) {
doTest(t, "tcp", "", "127.0.0.1")
doTest(t, "tcp", "0.0.0.0", "127.0.0.1")
doTest(t, "tcp", "127.0.0.1", "127.0.0.1")
doTest(t, "tcp4", "", "127.0.0.1")
doTest(t, "tcp4", "0.0.0.0", "127.0.0.1")
doTest(t, "tcp4", "127.0.0.1", "127.0.0.1")
if supportsIPv6 {
doTest(t, "tcp", "[::]", "[::1]")
doTest(t, "tcp", "[::1]", "[::1]")
doTest(t, "tcp6", "", "[::1]")
doTest(t, "tcp6", "[::]", "[::1]")
doTest(t, "tcp6", "[::1]", "[::1]")
}
if supportsIPv6 && supportsIPv4map {
doTest(t, "tcp", "[::ffff:0.0.0.0]", "127.0.0.1")
doTest(t, "tcp", "[::]", "127.0.0.1")
doTest(t, "tcp4", "[::ffff:0.0.0.0]", "127.0.0.1")
doTest(t, "tcp6", "", "127.0.0.1")
doTest(t, "tcp6", "[::ffff:0.0.0.0]", "127.0.0.1")
doTest(t, "tcp6", "[::]", "127.0.0.1")
doTest(t, "tcp", "127.0.0.1", "[::ffff:127.0.0.1]")
doTest(t, "tcp", "[::ffff:127.0.0.1]", "127.0.0.1")
doTest(t, "tcp4", "127.0.0.1", "[::ffff:127.0.0.1]")
doTest(t, "tcp4", "[::ffff:127.0.0.1]", "127.0.0.1")
doTest(t, "tcp6", "127.0.0.1", "[::ffff:127.0.0.1]")
doTest(t, "tcp6", "[::ffff:127.0.0.1]", "127.0.0.1")
}
}
func TestUnixServer(t *testing.T) {
// "unix" sockets are not supported on windows and Plan 9.
if runtime.GOOS == "windows" || runtime.GOOS == "plan9" {
return return
} }
os.Remove("/tmp/gotest.net")
doTest(t, "unix", "/tmp/gotest.net", "/tmp/gotest.net") for _, tt := range datagramPacketConnServerTests {
os.Remove("/tmp/gotest.net") if skipServerTest(tt.snet, "unixgram", tt.saddr, tt.ipv6, tt.ipv4map, tt.linux) {
if runtime.GOOS == "linux" { continue
doTest(t, "unixpacket", "/tmp/gotest.net", "/tmp/gotest.net") }
os.Remove("/tmp/gotest.net")
// Test abstract unix domain socket, a Linux-ism listening := make(chan string)
doTest(t, "unix", "@gotest/net", "@gotest/net") done := make(chan int)
doTest(t, "unixpacket", "@gotest/net", "@gotest/net") switch tt.snet {
case "udp", "udp4", "udp6":
tt.saddr += ":0"
case "unixgram":
os.Remove(tt.saddr)
os.Remove(tt.caddr)
}
go runDatagramPacketConnServer(t, tt.snet, tt.saddr, listening, done)
taddr := <-listening // wait for server to start
switch tt.cnet {
case "udp", "udp4", "udp6":
_, port, err := SplitHostPort(taddr)
if err != nil {
t.Errorf("SplitHostPort(%q) failed: %v", taddr, err)
return
}
taddr = tt.caddr + ":" + port
tt.caddr += ":0"
}
if tt.dial {
runDatagramConnClient(t, tt.cnet, tt.caddr, taddr, tt.empty)
} else {
runDatagramPacketConnClient(t, tt.cnet, tt.caddr, taddr, tt.empty)
}
<-done // tell server to stop
<-done // make sure server stopped
switch tt.snet {
case "unixgram":
os.Remove(tt.saddr)
os.Remove(tt.caddr)
}
} }
} }
func runPacket(t *testing.T, network, addr string, listening chan<- string, done chan<- int) { func runDatagramPacketConnServer(t *testing.T, net, laddr string, listening chan<- string, done chan<- int) {
c, err := ListenPacket(network, addr) c, err := ListenPacket(net, laddr)
if err != nil { if err != nil {
t.Fatalf("net.ListenPacket(%q, %q) = _, %v", network, addr, err) t.Errorf("ListenPacket(%q, %q) failed: %v", net, laddr, err)
listening <- "<nil>"
done <- 1
return
} }
defer c.Close()
listening <- c.LocalAddr().String() listening <- c.LocalAddr().String()
var buf [1000]byte
Run: buf := make([]byte, 1024)
run:
for { for {
c.SetReadDeadline(time.Now().Add(10 * time.Millisecond)) c.SetReadDeadline(time.Now().Add(10 * time.Millisecond))
n, addr, err := c.ReadFrom(buf[0:]) n, ra, err := c.ReadFrom(buf[0:])
if e, ok := err.(Error); ok && e.Timeout() { if nerr, ok := err.(Error); ok && nerr.Timeout() {
select { select {
case done <- 1: case done <- 1:
break Run break run
default: default:
continue Run continue run
} }
} }
if err != nil { if err != nil {
break break run
} }
if _, err = c.WriteTo(buf[0:n], addr); err != nil { if _, err = c.WriteTo(buf[0:n], ra); err != nil {
t.Fatalf("WriteTo %v: %v", addr, err) t.Errorf("WriteTo(%v) failed: %v", ra, err)
break run
} }
} }
c.Close()
done <- 1 done <- 1
} }
func doTestPacket(t *testing.T, network, listenaddr, dialaddr string, isEmpty bool) { func runDatagramConnClient(t *testing.T, net, laddr, taddr string, isEmpty bool) {
t.Logf("TestPacket %q %q %q", network, listenaddr, dialaddr) var c Conn
listening := make(chan string) var err error
done := make(chan int) switch net {
if network == "udp" { case "udp", "udp4", "udp6":
listenaddr += ":0" // any available port c, err = Dial(net, taddr)
if err != nil {
t.Errorf("Dial(%q, %q) failed: %v", net, taddr, err)
return
}
case "unixgram":
c, err = DialUnix(net, &UnixAddr{laddr, net}, &UnixAddr{taddr, net})
if err != nil {
t.Errorf("DialUnix(%q, {%q, %q}) failed: %v", net, laddr, taddr, err)
return
}
} }
go runPacket(t, network, listenaddr, listening, done) defer c.Close()
addr := <-listening // wait for server to start c.SetReadDeadline(time.Now().Add(1 * time.Second))
if network == "udp" {
dialaddr += addr[strings.LastIndex(addr, ":"):]
}
connect(t, network, dialaddr, isEmpty)
<-done // tell server to stop
<-done // wait for stop
}
func TestUDPServer(t *testing.T) { var wb []byte
if !*testUDP { if !isEmpty {
wb = []byte("DatagramConnClient by Dial\n")
}
if n, err := c.Write(wb[0:]); err != nil || n != len(wb) {
t.Errorf("Write failed: %v, %v; want %v, <nil>", n, err, len(wb))
return return
} }
for _, isEmpty := range []bool{false, true} {
doTestPacket(t, "udp", "0.0.0.0", "127.0.0.1", isEmpty) rb := make([]byte, 1024)
doTestPacket(t, "udp", "", "127.0.0.1", isEmpty) if n, err := c.Read(rb[0:]); err != nil || n != len(wb) {
if supportsIPv6 && supportsIPv4map { t.Errorf("Read failed: %v, %v; want %v, <nil>", n, err, len(wb))
doTestPacket(t, "udp", "[::]", "[::ffff:127.0.0.1]", isEmpty) return
doTestPacket(t, "udp", "[::]", "127.0.0.1", isEmpty)
doTestPacket(t, "udp", "0.0.0.0", "[::ffff:127.0.0.1]", isEmpty)
}
} }
} }
func TestUnixDatagramServer(t *testing.T) { func runDatagramPacketConnClient(t *testing.T, net, laddr, taddr string, isEmpty bool) {
// "unix" sockets are not supported on windows and Plan 9. var ra Addr
if runtime.GOOS == "windows" || runtime.GOOS == "plan9" { var err error
return switch net {
} case "udp", "udp4", "udp6":
for _, isEmpty := range []bool{false} { ra, err = ResolveUDPAddr(net, taddr)
os.Remove("/tmp/gotest1.net") if err != nil {
os.Remove("/tmp/gotest1.net.local") t.Errorf("ResolveUDPAddr(%q, %q) failed: %v", net, taddr, err)
doTestPacket(t, "unixgram", "/tmp/gotest1.net", "/tmp/gotest1.net", isEmpty) return
os.Remove("/tmp/gotest1.net") }
os.Remove("/tmp/gotest1.net.local") case "unixgram":
if runtime.GOOS == "linux" { ra, err = ResolveUnixAddr(net, taddr)
// Test abstract unix domain socket, a Linux-ism if err != nil {
doTestPacket(t, "unixgram", "@gotest1/net", "@gotest1/net", isEmpty) t.Errorf("ResolveUxixAddr(%q, %q) failed: %v", net, taddr, err)
return
} }
} }
c, err := ListenPacket(net, laddr)
if err != nil {
t.Errorf("ListenPacket(%q, %q) faild: %v", net, laddr, err)
return
}
defer c.Close()
c.SetReadDeadline(time.Now().Add(1 * time.Second))
var wb []byte
if !isEmpty {
wb = []byte("DatagramPacketConnClient by ListenPacket\n")
}
if n, err := c.WriteTo(wb[0:], ra); err != nil || n != len(wb) {
t.Errorf("WriteTo(%v) failed: %v, %v; want %v, <nil>", ra, n, err, len(wb))
return
}
rb := make([]byte, 1024)
if n, _, err := c.ReadFrom(rb[0:]); err != nil || n != len(wb) {
t.Errorf("ReadFrom failed: %v, %v; want %v, <nil>", n, err, len(wb))
return
}
} }

View File

@ -16,7 +16,7 @@ import (
var listenerBacklog = maxListenerBacklog() var listenerBacklog = maxListenerBacklog()
// Generic socket creation. // Generic socket creation.
func socket(net string, f, t, p int, la, ra syscall.Sockaddr, toAddr func(syscall.Sockaddr) Addr) (fd *netFD, err error) { func socket(net string, f, t, p int, ipv6only bool, la, ra syscall.Sockaddr, toAddr func(syscall.Sockaddr) Addr) (fd *netFD, err error) {
// See ../syscall/exec.go for description of ForkLock. // See ../syscall/exec.go for description of ForkLock.
syscall.ForkLock.RLock() syscall.ForkLock.RLock()
s, err := syscall.Socket(f, t, p) s, err := syscall.Socket(f, t, p)
@ -27,7 +27,7 @@ func socket(net string, f, t, p int, la, ra syscall.Sockaddr, toAddr func(syscal
syscall.CloseOnExec(s) syscall.CloseOnExec(s)
syscall.ForkLock.RUnlock() syscall.ForkLock.RUnlock()
err = setDefaultSockopts(s, f, t) err = setDefaultSockopts(s, f, t, ipv6only)
if err != nil { if err != nil {
closesocket(s) closesocket(s)
return nil, err return nil, err

View File

@ -9,7 +9,6 @@
package net package net
import ( import (
"bytes"
"os" "os"
"syscall" "syscall"
"time" "time"
@ -98,7 +97,7 @@ func setIPv4MreqToInterface(mreq *syscall.IPMreq, ifi *Interface) error {
} }
} }
done: done:
if bytes.Equal(mreq.Multiaddr[:], IPv4zero.To4()) { if bytesEqual(mreq.Multiaddr[:], IPv4zero.To4()) {
return errNoSuchMulticastInterface return errNoSuchMulticastInterface
} }
return nil return nil

View File

@ -13,12 +13,17 @@ import (
"syscall" "syscall"
) )
func setDefaultSockopts(s, f, t int) error { func setDefaultSockopts(s, f, t int, ipv6only bool) error {
switch f { switch f {
case syscall.AF_INET6: case syscall.AF_INET6:
// Allow both IP versions even if the OS default is otherwise. if ipv6only {
// Note that some operating systems never admit this option. syscall.SetsockoptInt(s, syscall.IPPROTO_IPV6, syscall.IPV6_V6ONLY, 1)
syscall.SetsockoptInt(s, syscall.IPPROTO_IPV6, syscall.IPV6_V6ONLY, 0) } else {
// Allow both IP versions even if the OS default
// is otherwise. Note that some operating systems
// never admit this option.
syscall.SetsockoptInt(s, syscall.IPPROTO_IPV6, syscall.IPV6_V6ONLY, 0)
}
} }
// Allow broadcast. // Allow broadcast.
err := syscall.SetsockoptInt(s, syscall.SOL_SOCKET, syscall.SO_BROADCAST, 1) err := syscall.SetsockoptInt(s, syscall.SOL_SOCKET, syscall.SO_BROADCAST, 1)

View File

@ -11,12 +11,17 @@ import (
"syscall" "syscall"
) )
func setDefaultSockopts(s, f, t int) error { func setDefaultSockopts(s, f, t int, ipv6only bool) error {
switch f { switch f {
case syscall.AF_INET6: case syscall.AF_INET6:
// Allow both IP versions even if the OS default is otherwise. if ipv6only {
// Note that some operating systems never admit this option. syscall.SetsockoptInt(s, syscall.IPPROTO_IPV6, syscall.IPV6_V6ONLY, 1)
syscall.SetsockoptInt(s, syscall.IPPROTO_IPV6, syscall.IPV6_V6ONLY, 0) } else {
// Allow both IP versions even if the OS default
// is otherwise. Note that some operating systems
// never admit this option.
syscall.SetsockoptInt(s, syscall.IPPROTO_IPV6, syscall.IPV6_V6ONLY, 0)
}
} }
// Allow broadcast. // Allow broadcast.
err := syscall.SetsockoptInt(s, syscall.SOL_SOCKET, syscall.SO_BROADCAST, 1) err := syscall.SetsockoptInt(s, syscall.SOL_SOCKET, syscall.SO_BROADCAST, 1)

View File

@ -11,12 +11,17 @@ import (
"syscall" "syscall"
) )
func setDefaultSockopts(s syscall.Handle, f, t int) error { func setDefaultSockopts(s syscall.Handle, f, t int, ipv6only bool) error {
switch f { switch f {
case syscall.AF_INET6: case syscall.AF_INET6:
// Allow both IP versions even if the OS default is otherwise. if ipv6only {
// Note that some operating systems never admit this option. syscall.SetsockoptInt(s, syscall.IPPROTO_IPV6, syscall.IPV6_V6ONLY, 1)
syscall.SetsockoptInt(s, syscall.IPPROTO_IPV6, syscall.IPV6_V6ONLY, 0) } else {
// Allow both IP versions even if the OS default
// is otherwise. Note that some operating systems
// never admit this option.
syscall.SetsockoptInt(s, syscall.IPPROTO_IPV6, syscall.IPV6_V6ONLY, 0)
}
} }
// Allow broadcast. // Allow broadcast.
syscall.SetsockoptInt(s, syscall.SOL_SOCKET, syscall.SO_BROADCAST, 1) syscall.SetsockoptInt(s, syscall.SOL_SOCKET, syscall.SO_BROADCAST, 1)

Some files were not shown because too many files have changed in this diff Show More