libgo: Update to current version of master library.

From-SVN: r193688
This commit is contained in:
Ian Lance Taylor 2012-11-21 07:03:38 +00:00
parent a51fb17f48
commit fabcaa8df3
321 changed files with 62096 additions and 19248 deletions

View File

@ -1,4 +1,4 @@
291d9f1baf75 a070de932857
The first line of this file holds the Mercurial revision number of the The first line of this file holds the Mercurial revision number of the
last merge done from the master library sources. last merge done from the master library sources.

View File

@ -773,7 +773,6 @@ go_net_files = \
go/net/lookup_unix.go \ go/net/lookup_unix.go \
go/net/mac.go \ go/net/mac.go \
go/net/net.go \ go/net/net.go \
go/net/net_posix.go \
go/net/parse.go \ go/net/parse.go \
go/net/pipe.go \ go/net/pipe.go \
go/net/port.go \ go/net/port.go \
@ -1117,6 +1116,7 @@ go_crypto_x509_files = \
go/crypto/x509/pkcs8.go \ go/crypto/x509/pkcs8.go \
go/crypto/x509/root.go \ go/crypto/x509/root.go \
go/crypto/x509/root_unix.go \ go/crypto/x509/root_unix.go \
go/crypto/x509/sec1.go \
go/crypto/x509/verify.go \ go/crypto/x509/verify.go \
go/crypto/x509/x509.go go/crypto/x509/x509.go
@ -1245,10 +1245,17 @@ go_exp_terminal_files = \
go/exp/terminal/terminal.go \ go/exp/terminal/terminal.go \
go/exp/terminal/util.go go/exp/terminal/util.go
go_exp_types_files = \ go_exp_types_files = \
go/exp/types/builtins.go \
go/exp/types/check.go \ go/exp/types/check.go \
go/exp/types/const.go \ go/exp/types/const.go \
go/exp/types/conversions.go \
go/exp/types/errors.go \
go/exp/types/exportdata.go \ go/exp/types/exportdata.go \
go/exp/types/expr.go \
go/exp/types/gcimporter.go \ go/exp/types/gcimporter.go \
go/exp/types/operand.go \
go/exp/types/predicates.go \
go/exp/types/stmt.go \
go/exp/types/types.go \ go/exp/types/types.go \
go/exp/types/universe.go go/exp/types/universe.go
go_exp_utf8string_files = \ go_exp_utf8string_files = \
@ -1329,6 +1336,7 @@ go_image_jpeg_files = \
go/image/jpeg/huffman.go \ go/image/jpeg/huffman.go \
go/image/jpeg/idct.go \ go/image/jpeg/idct.go \
go/image/jpeg/reader.go \ go/image/jpeg/reader.go \
go/image/jpeg/scan.go \
go/image/jpeg/writer.go go/image/jpeg/writer.go
go_image_png_files = \ go_image_png_files = \

View File

@ -1027,7 +1027,6 @@ go_net_files = \
go/net/lookup_unix.go \ go/net/lookup_unix.go \
go/net/mac.go \ go/net/mac.go \
go/net/net.go \ go/net/net.go \
go/net/net_posix.go \
go/net/parse.go \ go/net/parse.go \
go/net/pipe.go \ go/net/pipe.go \
go/net/port.go \ go/net/port.go \
@ -1312,6 +1311,7 @@ go_crypto_x509_files = \
go/crypto/x509/pkcs8.go \ go/crypto/x509/pkcs8.go \
go/crypto/x509/root.go \ go/crypto/x509/root.go \
go/crypto/x509/root_unix.go \ go/crypto/x509/root_unix.go \
go/crypto/x509/sec1.go \
go/crypto/x509/verify.go \ go/crypto/x509/verify.go \
go/crypto/x509/x509.go go/crypto/x509/x509.go
@ -1463,10 +1463,17 @@ go_exp_terminal_files = \
go/exp/terminal/util.go go/exp/terminal/util.go
go_exp_types_files = \ go_exp_types_files = \
go/exp/types/builtins.go \
go/exp/types/check.go \ go/exp/types/check.go \
go/exp/types/const.go \ go/exp/types/const.go \
go/exp/types/conversions.go \
go/exp/types/errors.go \
go/exp/types/exportdata.go \ go/exp/types/exportdata.go \
go/exp/types/expr.go \
go/exp/types/gcimporter.go \ go/exp/types/gcimporter.go \
go/exp/types/operand.go \
go/exp/types/predicates.go \
go/exp/types/stmt.go \
go/exp/types/types.go \ go/exp/types/types.go \
go/exp/types/universe.go go/exp/types/universe.go
@ -1557,6 +1564,7 @@ go_image_jpeg_files = \
go/image/jpeg/huffman.go \ go/image/jpeg/huffman.go \
go/image/jpeg/idct.go \ go/image/jpeg/idct.go \
go/image/jpeg/reader.go \ go/image/jpeg/reader.go \
go/image/jpeg/scan.go \
go/image/jpeg/writer.go go/image/jpeg/writer.go
go_image_png_files = \ go_image_png_files = \

View File

@ -72,6 +72,18 @@ func cString(b []byte) string {
} }
func (tr *Reader) octal(b []byte) int64 { func (tr *Reader) octal(b []byte) int64 {
// Check for binary format first.
if len(b) > 0 && b[0]&0x80 != 0 {
var x int64
for i, c := range b {
if i == 0 {
c &= 0x7f // ignore signal bit in first byte
}
x = x<<8 | int64(c)
}
return x
}
// Removing leading spaces. // Removing leading spaces.
for len(b) > 0 && b[0] == ' ' { for len(b) > 0 && b[0] == ' ' {
b = b[1:] b = b[1:]

View File

@ -5,7 +5,10 @@
package tar package tar
import ( import (
"bytes"
"io/ioutil"
"os" "os"
"reflect"
"testing" "testing"
"time" "time"
) )
@ -54,3 +57,44 @@ func (symlink) Mode() os.FileMode { return os.ModeSymlink }
func (symlink) ModTime() time.Time { return time.Time{} } func (symlink) ModTime() time.Time { return time.Time{} }
func (symlink) IsDir() bool { return false } func (symlink) IsDir() bool { return false }
func (symlink) Sys() interface{} { return nil } func (symlink) Sys() interface{} { return nil }
// TestRoundTrip writes one file through Writer and reads it back through
// Reader, checking that both the header and the contents survive a tar
// round trip unchanged. The oversized Uid forces the binary (base-256)
// numeric encoding path rather than plain octal.
func TestRoundTrip(t *testing.T) {
	data := []byte("some file contents")

	var b bytes.Buffer
	tw := NewWriter(&b)
	hdr := &Header{
		Name:    "file.txt",
		Uid:     1 << 21, // too big for 8 octal digits
		Size:    int64(len(data)),
		ModTime: time.Now(),
	}
	// tar only supports second precision.
	hdr.ModTime = hdr.ModTime.Add(-time.Duration(hdr.ModTime.Nanosecond()) * time.Nanosecond)
	if err := tw.WriteHeader(hdr); err != nil {
		t.Fatalf("tw.WriteHeader: %v", err)
	}
	if _, err := tw.Write(data); err != nil {
		t.Fatalf("tw.Write: %v", err)
	}
	if err := tw.Close(); err != nil {
		t.Fatalf("tw.Close: %v", err)
	}

	// Read it back.
	tr := NewReader(&b)
	rHdr, err := tr.Next()
	if err != nil {
		t.Fatalf("tr.Next: %v", err)
	}
	if !reflect.DeepEqual(rHdr, hdr) {
		t.Errorf("Header mismatch.\n got %+v\nwant %+v", rHdr, hdr)
	}
	rData, err := ioutil.ReadAll(tr)
	if err != nil {
		t.Fatalf("Read: %v", err)
	}
	if !bytes.Equal(rData, data) {
		t.Errorf("Data mismatch.\n got %q\nwant %q", rData, data)
	}
}

View File

@ -12,6 +12,7 @@ import (
"fmt" "fmt"
"io" "io"
"strconv" "strconv"
"time"
) )
var ( var (
@ -110,6 +111,12 @@ func (tw *Writer) numeric(b []byte, x int64) {
b[0] |= 0x80 // highest bit indicates binary format b[0] |= 0x80 // highest bit indicates binary format
} }
var (
minTime = time.Unix(0, 0)
// There is room for 11 octal digits (33 bits) of mtime.
maxTime = minTime.Add((1<<33 - 1) * time.Second)
)
// WriteHeader writes hdr and prepares to accept the file's contents. // WriteHeader writes hdr and prepares to accept the file's contents.
// WriteHeader calls Flush if it is not the first header. // WriteHeader calls Flush if it is not the first header.
// Calling after a Close will return ErrWriteAfterClose. // Calling after a Close will return ErrWriteAfterClose.
@ -133,11 +140,17 @@ func (tw *Writer) WriteHeader(hdr *Header) error {
// TODO(dsymonds): handle names longer than 100 chars // TODO(dsymonds): handle names longer than 100 chars
copy(s.next(100), []byte(hdr.Name)) copy(s.next(100), []byte(hdr.Name))
// Handle out of range ModTime carefully.
var modTime int64
if !hdr.ModTime.Before(minTime) && !hdr.ModTime.After(maxTime) {
modTime = hdr.ModTime.Unix()
}
tw.octal(s.next(8), hdr.Mode) // 100:108 tw.octal(s.next(8), hdr.Mode) // 100:108
tw.numeric(s.next(8), int64(hdr.Uid)) // 108:116 tw.numeric(s.next(8), int64(hdr.Uid)) // 108:116
tw.numeric(s.next(8), int64(hdr.Gid)) // 116:124 tw.numeric(s.next(8), int64(hdr.Gid)) // 116:124
tw.numeric(s.next(12), hdr.Size) // 124:136 tw.numeric(s.next(12), hdr.Size) // 124:136
tw.numeric(s.next(12), hdr.ModTime.Unix()) // 136:148 tw.numeric(s.next(12), modTime) // 136:148
s.next(8) // chksum (148:156) s.next(8) // chksum (148:156)
s.next(1)[0] = hdr.Typeflag // 156:157 s.next(1)[0] = hdr.Typeflag // 156:157
tw.cString(s.next(100), hdr.Linkname) // linkname (157:257) tw.cString(s.next(100), hdr.Linkname) // linkname (157:257)

View File

@ -238,9 +238,12 @@ func readDirectoryHeader(f *File, r io.Reader) error {
if len(f.Extra) > 0 { if len(f.Extra) > 0 {
b := readBuf(f.Extra) b := readBuf(f.Extra)
for len(b) > 0 { for len(b) >= 4 { // need at least tag and size
tag := b.uint16() tag := b.uint16()
size := b.uint16() size := b.uint16()
if int(size) > len(b) {
return ErrFormat
}
if tag == zip64ExtraId { if tag == zip64ExtraId {
// update directory values from the zip64 extra block // update directory values from the zip64 extra block
eb := readBuf(b) eb := readBuf(b)
@ -256,6 +259,10 @@ func readDirectoryHeader(f *File, r io.Reader) error {
} }
b = b[size:] b = b[size:]
} }
// Should have consumed the whole header.
if len(b) != 0 {
return ErrFormat
}
} }
return nil return nil
} }

View File

@ -173,3 +173,85 @@ func TestZip64(t *testing.T) {
t.Errorf("UncompressedSize64 %d, want %d", got, want) t.Errorf("UncompressedSize64 %d, want %d", got, want)
} }
} }
// testInvalidHeader writes a zip archive containing one file with the
// given header, then re-reads the archive and requires that parsing
// fails with ErrFormat (i.e. the header's Extra data is rejected).
func testInvalidHeader(h *FileHeader, t *testing.T) {
	var buf bytes.Buffer
	z := NewWriter(&buf)

	f, err := z.CreateHeader(h)
	if err != nil {
		t.Fatalf("error creating header: %v", err)
	}
	if _, err := f.Write([]byte("hi")); err != nil {
		t.Fatalf("error writing content: %v", err)
	}
	if err := z.Close(); err != nil {
		t.Fatalf("error closing zip writer: %v", err)
	}

	b := buf.Bytes()
	if _, err = NewReader(bytes.NewReader(b), int64(len(b))); err != ErrFormat {
		t.Fatalf("got %v, expected ErrFormat", err)
	}
}
// testValidHeader is the counterpart of testInvalidHeader: it writes a
// one-file archive with the given header and requires that re-reading
// the archive succeeds without error.
func testValidHeader(h *FileHeader, t *testing.T) {
	var buf bytes.Buffer
	z := NewWriter(&buf)

	f, err := z.CreateHeader(h)
	if err != nil {
		t.Fatalf("error creating header: %v", err)
	}
	if _, err := f.Write([]byte("hi")); err != nil {
		t.Fatalf("error writing content: %v", err)
	}
	if err := z.Close(); err != nil {
		t.Fatalf("error closing zip writer: %v", err)
	}

	b := buf.Bytes()
	if _, err = NewReader(bytes.NewReader(b), int64(len(b))); err != nil {
		t.Fatalf("got %v, expected nil", err)
	}
}
// Issue 4302.
// The Extra field below is free-form text rather than the required
// sequence of (tag, size, data) records, so the reader must reject it
// instead of mis-parsing arbitrary bytes as record sizes.
func TestHeaderInvalidTagAndSize(t *testing.T) {
	const timeFormat = "20060102T150405.000.txt"

	ts := time.Now()
	filename := ts.Format(timeFormat)

	h := FileHeader{
		Name:   filename,
		Method: Deflate,
		Extra:  []byte(ts.Format(time.RFC3339Nano)), // missing tag and len
	}
	h.SetModTime(ts)

	testInvalidHeader(&h, t)
}
// TestHeaderTooShort checks that an Extra field shorter than the
// 4 bytes needed for a (tag, size) pair is rejected rather than read
// out of bounds.
func TestHeaderTooShort(t *testing.T) {
	h := FileHeader{
		Name:   "foo.txt",
		Method: Deflate,
		Extra:  []byte{zip64ExtraId}, // missing size
	}
	testInvalidHeader(&h, t)
}
// Issue 4393. It is valid to have an extra data header
// which contains no body: a (tag, size=0) record must be accepted and
// skipped, not treated as a format error.
func TestZeroLengthHeader(t *testing.T) {
	h := FileHeader{
		Name:   "extadata.txt",
		Method: Deflate,
		Extra: []byte{
			85, 84, 5, 0, 3, 154, 144, 195, 77, // tag 21589 size 5
			85, 120, 0, 0, // tag 30805 size 0
		},
	}
	testValidHeader(&h, t)
}

View File

@ -567,6 +567,36 @@ func (b *Writer) WriteString(s string) (int, error) {
return nn, nil return nn, nil
} }
// ReadFrom implements io.ReaderFrom.
// If the buffer is empty and the underlying writer itself implements
// io.ReaderFrom, the copy is delegated directly to it, avoiding double
// buffering. Otherwise data is read into the free tail of the buffer,
// and the buffer is flushed only when it becomes completely full, so
// small copies do not force premature writes to the underlying writer.
func (b *Writer) ReadFrom(r io.Reader) (n int64, err error) {
	if b.Buffered() == 0 {
		if w, ok := b.wr.(io.ReaderFrom); ok {
			return w.ReadFrom(r)
		}
	}
	var m int
	for {
		m, err = r.Read(b.buf[b.n:])
		if m == 0 {
			// No progress: stop with whatever error Read reported.
			break
		}
		b.n += m
		n += int64(m)
		if b.Available() == 0 {
			// Buffer is full; flush before reading more.
			if err1 := b.Flush(); err1 != nil {
				return n, err1
			}
		}
		if err != nil {
			break
		}
	}
	// EOF is the normal end of the copy, not an error for the caller.
	if err == io.EOF {
		err = nil
	}
	return n, err
}
// buffered input and output // buffered input and output
// ReadWriter stores pointers to a Reader and a Writer. // ReadWriter stores pointers to a Reader and a Writer.

View File

@ -763,8 +763,8 @@ func testReadLineNewlines(t *testing.T, input string, expect []readLineResult) {
} }
} }
func TestReaderWriteTo(t *testing.T) { func createTestInput(n int) []byte {
input := make([]byte, 8192) input := make([]byte, n)
for i := range input { for i := range input {
// 101 and 251 are arbitrary prime numbers. // 101 and 251 are arbitrary prime numbers.
// The idea is to create an input sequence // The idea is to create an input sequence
@ -774,7 +774,12 @@ func TestReaderWriteTo(t *testing.T) {
input[i] ^= byte(i / 101) input[i] ^= byte(i / 101)
} }
} }
r := NewReader(bytes.NewBuffer(input)) return input
}
func TestReaderWriteTo(t *testing.T) {
input := createTestInput(8192)
r := NewReader(onlyReader{bytes.NewBuffer(input)})
w := new(bytes.Buffer) w := new(bytes.Buffer)
if n, err := r.WriteTo(w); err != nil || n != int64(len(input)) { if n, err := r.WriteTo(w); err != nil || n != int64(len(input)) {
t.Fatalf("r.WriteTo(w) = %d, %v, want %d, nil", n, err, len(input)) t.Fatalf("r.WriteTo(w) = %d, %v, want %d, nil", n, err, len(input))
@ -817,12 +822,129 @@ func TestReaderWriteToErrors(t *testing.T) {
} }
} }
// TestWriterReadFrom runs Writer.ReadFrom over every combination of a
// destination that does/doesn't expose io.ReaderFrom on the underlying
// writer (onlyWriter hides it) and a source that does/doesn't return
// data together with io.EOF (iotest.DataErrReader), verifying the full
// input is copied intact in each case.
func TestWriterReadFrom(t *testing.T) {
	ws := []func(io.Writer) io.Writer{
		func(w io.Writer) io.Writer { return onlyWriter{w} },
		func(w io.Writer) io.Writer { return w },
	}
	rs := []func(io.Reader) io.Reader{
		iotest.DataErrReader,
		func(r io.Reader) io.Reader { return r },
	}
	for ri, rfunc := range rs {
		for wi, wfunc := range ws {
			input := createTestInput(8192)
			b := new(bytes.Buffer)
			w := NewWriter(wfunc(b))
			r := rfunc(bytes.NewBuffer(input))
			if n, err := w.ReadFrom(r); err != nil || n != int64(len(input)) {
				t.Errorf("ws[%d],rs[%d]: w.ReadFrom(r) = %d, %v, want %d, nil", wi, ri, n, err, len(input))
				continue
			}
			if got, want := b.String(), string(input); got != want {
				t.Errorf("ws[%d], rs[%d]:\ngot  %q\nwant %q\n", wi, ri, got, want)
			}
		}
	}
}
// errorReaderFromTest serves as both the source reader and the
// destination writer for one Writer.ReadFrom error case: rn and wn
// scale the byte counts reported by Read and Write, rerr and werr are
// the errors they return, and expected is the error ReadFrom should
// surface to the caller.
type errorReaderFromTest struct {
	rn, wn     int
	rerr, werr error
	expected   error
}

// Read reports len(p)*rn bytes read (0 or all) and the configured error.
func (r errorReaderFromTest) Read(p []byte) (int, error) {
	return len(p) * r.rn, r.rerr
}

// Write reports len(p)*wn bytes written (0 or all) and the configured error.
func (w errorReaderFromTest) Write(p []byte) (int, error) {
	return len(p) * w.wn, w.werr
}

// errorReaderFromTests: io.EOF from the reader must be swallowed;
// any other read error wins over a concurrent write error; a write
// error alone is reported as-is.
var errorReaderFromTests = []errorReaderFromTest{
	{0, 1, io.EOF, nil, nil},
	{1, 1, io.EOF, nil, nil},
	{0, 1, io.ErrClosedPipe, nil, io.ErrClosedPipe},
	{0, 0, io.ErrClosedPipe, io.ErrShortWrite, io.ErrClosedPipe},
	{1, 0, nil, io.ErrShortWrite, io.ErrShortWrite},
}
// TestWriterReadFromErrors drives Writer.ReadFrom through each entry of
// errorReaderFromTests and checks that the expected error (or nil) is
// propagated.
func TestWriterReadFromErrors(t *testing.T) {
	for i, rw := range errorReaderFromTests {
		w := NewWriter(rw)
		if _, err := w.ReadFrom(rw); err != rw.expected {
			t.Errorf("w.ReadFrom(errorReaderFromTests[%d]) = _, %v, want _,%v", i, err, rw.expected)
		}
	}
}
// TestWriterReadFromCounts tests that using io.Copy to copy into a
// bufio.Writer does not prematurely flush the buffer. For example, when
// buffering writes to a network socket, excessive network writes should be
// avoided.
func TestWriterReadFromCounts(t *testing.T) {
	// Case 1: everything fits in the 1234-byte buffer until the copy
	// that overflows it, which triggers exactly one flush.
	var w0 writeCountingDiscard
	b0 := NewWriterSize(&w0, 1234)
	b0.WriteString(strings.Repeat("x", 1000))
	if w0 != 0 {
		t.Fatalf("write 1000 'x's: got %d writes, want 0", w0)
	}
	b0.WriteString(strings.Repeat("x", 200))
	if w0 != 0 {
		t.Fatalf("write 1200 'x's: got %d writes, want 0", w0)
	}
	io.Copy(b0, onlyReader{strings.NewReader(strings.Repeat("x", 30))})
	if w0 != 0 {
		t.Fatalf("write 1230 'x's: got %d writes, want 0", w0)
	}
	io.Copy(b0, onlyReader{strings.NewReader(strings.Repeat("x", 9))})
	if w0 != 1 {
		t.Fatalf("write 1239 'x's: got %d writes, want 1", w0)
	}

	// Case 2: interleave explicit Flush calls with io.Copy and count
	// each resulting underlying write.
	var w1 writeCountingDiscard
	b1 := NewWriterSize(&w1, 1234)
	b1.WriteString(strings.Repeat("x", 1200))
	b1.Flush()
	if w1 != 1 {
		t.Fatalf("flush 1200 'x's: got %d writes, want 1", w1)
	}
	b1.WriteString(strings.Repeat("x", 89))
	if w1 != 1 {
		t.Fatalf("write 1200 + 89 'x's: got %d writes, want 1", w1)
	}
	io.Copy(b1, onlyReader{strings.NewReader(strings.Repeat("x", 700))})
	if w1 != 1 {
		t.Fatalf("write 1200 + 789 'x's: got %d writes, want 1", w1)
	}
	io.Copy(b1, onlyReader{strings.NewReader(strings.Repeat("x", 600))})
	if w1 != 2 {
		t.Fatalf("write 1200 + 1389 'x's: got %d writes, want 2", w1)
	}
	b1.Flush()
	if w1 != 3 {
		t.Fatalf("flush 1200 + 1389 'x's: got %d writes, want 3", w1)
	}
}
// A writeCountingDiscard is like ioutil.Discard and counts the number of times
// Write is called on it.
type writeCountingDiscard int

// Write discards p, increments the call counter, and reports a full write.
func (w *writeCountingDiscard) Write(p []byte) (int, error) {
	*w++
	return len(p), nil
}
// An onlyReader only implements io.Reader, no matter what other methods the underlying implementation may have. // An onlyReader only implements io.Reader, no matter what other methods the underlying implementation may have.
type onlyReader struct { type onlyReader struct {
r io.Reader r io.Reader
} }
func (r *onlyReader) Read(b []byte) (int, error) { func (r onlyReader) Read(b []byte) (int, error) {
return r.r.Read(b) return r.r.Read(b)
} }
@ -831,7 +953,7 @@ type onlyWriter struct {
w io.Writer w io.Writer
} }
func (w *onlyWriter) Write(b []byte) (int, error) { func (w onlyWriter) Write(b []byte) (int, error) {
return w.w.Write(b) return w.w.Write(b)
} }
@ -840,7 +962,7 @@ func BenchmarkReaderCopyOptimal(b *testing.B) {
for i := 0; i < b.N; i++ { for i := 0; i < b.N; i++ {
b.StopTimer() b.StopTimer()
src := NewReader(bytes.NewBuffer(make([]byte, 8192))) src := NewReader(bytes.NewBuffer(make([]byte, 8192)))
dst := &onlyWriter{new(bytes.Buffer)} dst := onlyWriter{new(bytes.Buffer)}
b.StartTimer() b.StartTimer()
io.Copy(dst, src) io.Copy(dst, src)
} }
@ -850,8 +972,8 @@ func BenchmarkReaderCopyUnoptimal(b *testing.B) {
// Unoptimal case is where the underlying reader doesn't implement io.WriterTo // Unoptimal case is where the underlying reader doesn't implement io.WriterTo
for i := 0; i < b.N; i++ { for i := 0; i < b.N; i++ {
b.StopTimer() b.StopTimer()
src := NewReader(&onlyReader{bytes.NewBuffer(make([]byte, 8192))}) src := NewReader(onlyReader{bytes.NewBuffer(make([]byte, 8192))})
dst := &onlyWriter{new(bytes.Buffer)} dst := onlyWriter{new(bytes.Buffer)}
b.StartTimer() b.StartTimer()
io.Copy(dst, src) io.Copy(dst, src)
} }
@ -860,8 +982,39 @@ func BenchmarkReaderCopyUnoptimal(b *testing.B) {
func BenchmarkReaderCopyNoWriteTo(b *testing.B) { func BenchmarkReaderCopyNoWriteTo(b *testing.B) {
for i := 0; i < b.N; i++ { for i := 0; i < b.N; i++ {
b.StopTimer() b.StopTimer()
src := &onlyReader{NewReader(bytes.NewBuffer(make([]byte, 8192)))} src := onlyReader{NewReader(bytes.NewBuffer(make([]byte, 8192)))}
dst := &onlyWriter{new(bytes.Buffer)} dst := onlyWriter{new(bytes.Buffer)}
b.StartTimer()
io.Copy(dst, src)
}
}
// BenchmarkWriterCopyOptimal measures io.Copy into a bufio.Writer whose
// underlying bytes.Buffer implements io.ReaderFrom, so ReadFrom can
// delegate and skip the intermediate buffer.
func BenchmarkWriterCopyOptimal(b *testing.B) {
	// Optimal case is where the underlying writer implements io.ReaderFrom
	for i := 0; i < b.N; i++ {
		b.StopTimer()
		src := onlyReader{bytes.NewBuffer(make([]byte, 8192))}
		dst := NewWriter(new(bytes.Buffer))
		b.StartTimer()
		io.Copy(dst, src)
	}
}
// BenchmarkWriterCopyUnoptimal measures io.Copy into a bufio.Writer
// whose underlying writer (onlyWriter) hides io.ReaderFrom, forcing the
// buffered read/flush loop.
func BenchmarkWriterCopyUnoptimal(b *testing.B) {
	for i := 0; i < b.N; i++ {
		b.StopTimer()
		src := onlyReader{bytes.NewBuffer(make([]byte, 8192))}
		dst := NewWriter(onlyWriter{new(bytes.Buffer)})
		b.StartTimer()
		io.Copy(dst, src)
	}
}
func BenchmarkWriterCopyNoReadFrom(b *testing.B) {
for i := 0; i < b.N; i++ {
b.StopTimer()
src := onlyReader{bytes.NewBuffer(make([]byte, 8192))}
dst := onlyWriter{NewWriter(new(bytes.Buffer))}
b.StartTimer() b.StartTimer()
io.Copy(dst, src) io.Copy(dst, src)
} }

View File

@ -251,10 +251,10 @@ func TestReadFrom(t *testing.T) {
func TestWriteTo(t *testing.T) { func TestWriteTo(t *testing.T) {
var buf Buffer var buf Buffer
for i := 3; i < 30; i += 3 { for i := 3; i < 30; i += 3 {
s := fillBytes(t, "TestReadFrom (1)", &buf, "", 5, testBytes[0:len(testBytes)/i]) s := fillBytes(t, "TestWriteTo (1)", &buf, "", 5, testBytes[0:len(testBytes)/i])
var b Buffer var b Buffer
buf.WriteTo(&b) buf.WriteTo(&b)
empty(t, "TestReadFrom (2)", &b, s, make([]byte, len(data))) empty(t, "TestWriteTo (2)", &b, s, make([]byte, len(data)))
} }
} }

View File

@ -10,7 +10,7 @@ import (
"unicode/utf8" "unicode/utf8"
) )
// A Reader implements the io.Reader, io.ReaderAt, io.Seeker, // A Reader implements the io.Reader, io.ReaderAt, io.WriterTo, io.Seeker,
// io.ByteScanner, and io.RuneScanner interfaces by reading from // io.ByteScanner, and io.RuneScanner interfaces by reading from
// a byte slice. // a byte slice.
// Unlike a Buffer, a Reader is read-only and supports seeking. // Unlike a Buffer, a Reader is read-only and supports seeking.
@ -121,5 +121,24 @@ func (r *Reader) Seek(offset int64, whence int) (int64, error) {
return abs, nil return abs, nil
} }
// WriteTo implements the io.WriterTo interface.
// It writes the unread portion of the slice to w in a single Write call
// and advances the read position by the count w reports. A short write
// with no error becomes io.ErrShortWrite; a Write count larger than the
// data means w is a broken io.Writer, which is a programmer error and
// panics. If nothing remains unread, it returns 0, io.EOF.
func (r *Reader) WriteTo(w io.Writer) (n int64, err error) {
	// Writing invalidates any pending UnreadRune.
	r.prevRune = -1
	if r.i >= len(r.s) {
		return 0, io.EOF
	}
	b := r.s[r.i:]
	m, err := w.Write(b)
	if m > len(b) {
		panic("bytes.Reader.WriteTo: invalid Write count")
	}
	r.i += m
	n = int64(m)
	if m != len(b) && err == nil {
		err = io.ErrShortWrite
	}
	return
}
}
// NewReader returns a new Reader reading from b. // NewReader returns a new Reader reading from b.
func NewReader(b []byte) *Reader { return &Reader{b, 0, -1} } func NewReader(b []byte) *Reader { return &Reader{b, 0, -1} }

View File

@ -86,3 +86,24 @@ func TestReaderAt(t *testing.T) {
} }
} }
} }
// TestReaderWriteTo checks Reader.WriteTo at several prefix lengths of
// the shared test data: the reported count, absence of error, the bytes
// delivered to the destination Buffer, and that the Reader is fully
// drained afterwards.
func TestReaderWriteTo(t *testing.T) {
	for i := 3; i < 30; i += 3 {
		s := data[:len(data)/i]
		r := NewReader(testBytes[:len(testBytes)/i])
		var b Buffer
		n, err := r.WriteTo(&b)
		if expect := int64(len(s)); n != expect {
			t.Errorf("got %v; want %v", n, expect)
		}
		if err != nil {
			t.Errorf("got error = %v; want nil", err)
		}
		if b.String() != s {
			t.Errorf("got string %q; want %q", b.String(), s)
		}
		if r.Len() != 0 {
			t.Errorf("reader contains %v bytes; want 0", r.Len())
		}
	}
}

View File

@ -14,21 +14,16 @@ import (
// because the error handling was verbose. Instead, any error is kept and can // because the error handling was verbose. Instead, any error is kept and can
// be checked afterwards. // be checked afterwards.
type bitReader struct { type bitReader struct {
r byteReader r io.ByteReader
n uint64 n uint64
bits uint bits uint
err error err error
} }
// bitReader needs to read bytes from an io.Reader. We attempt to convert the // newBitReader returns a new bitReader reading from r. If r is not
// given io.Reader to this interface and, if it doesn't already fit, we wrap in // already an io.ByteReader, it will be converted via a bufio.Reader.
// a bufio.Reader.
type byteReader interface {
ReadByte() (byte, error)
}
func newBitReader(r io.Reader) bitReader { func newBitReader(r io.Reader) bitReader {
byter, ok := r.(byteReader) byter, ok := r.(io.ByteReader)
if !ok { if !ok {
byter = bufio.NewReader(r) byter = bufio.NewReader(r)
} }

View File

@ -208,8 +208,8 @@ type decompressor struct {
h1, h2 huffmanDecoder h1, h2 huffmanDecoder
// Length arrays used to define Huffman codes. // Length arrays used to define Huffman codes.
bits [maxLit + maxDist]int bits *[maxLit + maxDist]int
codebits [numCodes]int codebits *[numCodes]int
// Output history, buffer. // Output history, buffer.
hist *[maxHist]byte hist *[maxHist]byte
@ -692,6 +692,8 @@ func makeReader(r io.Reader) Reader {
// finished reading. // finished reading.
func NewReader(r io.Reader) io.ReadCloser { func NewReader(r io.Reader) io.ReadCloser {
var f decompressor var f decompressor
f.bits = new([maxLit + maxDist]int)
f.codebits = new([numCodes]int)
f.r = makeReader(r) f.r = makeReader(r)
f.hist = new([maxHist]byte) f.hist = new([maxHist]byte)
f.step = (*decompressor).nextBlock f.step = (*decompressor).nextBlock
@ -707,6 +709,8 @@ func NewReaderDict(r io.Reader, dict []byte) io.ReadCloser {
var f decompressor var f decompressor
f.r = makeReader(r) f.r = makeReader(r)
f.hist = new([maxHist]byte) f.hist = new([maxHist]byte)
f.bits = new([maxLit + maxDist]int)
f.codebits = new([numCodes]int)
f.step = (*decompressor).nextBlock f.step = (*decompressor).nextBlock
f.setDict(dict) f.setDict(dict)
return &f return &f

View File

@ -176,7 +176,7 @@ func (l *List) MoveToBack(e *Element) {
l.insert(l.remove(e), l.root.prev) l.insert(l.remove(e), l.root.prev)
} }
// PuchBackList inserts a copy of an other list at the back of list l. // PushBackList inserts a copy of an other list at the back of list l.
// The lists l and other may be the same. // The lists l and other may be the same.
func (l *List) PushBackList(other *List) { func (l *List) PushBackList(other *List) {
l.lazyInit() l.lazyInit()

View File

@ -33,6 +33,9 @@ type cbcEncrypter cbc
// mode, using the given Block. The length of iv must be the same as the // mode, using the given Block. The length of iv must be the same as the
// Block's block size. // Block's block size.
func NewCBCEncrypter(b Block, iv []byte) BlockMode { func NewCBCEncrypter(b Block, iv []byte) BlockMode {
if len(iv) != b.BlockSize() {
panic("cipher.NewCBCEncrypter: IV length must equal block size")
}
return (*cbcEncrypter)(newCBC(b, iv)) return (*cbcEncrypter)(newCBC(b, iv))
} }
@ -58,6 +61,9 @@ type cbcDecrypter cbc
// mode, using the given Block. The length of iv must be the same as the // mode, using the given Block. The length of iv must be the same as the
// Block's block size and must match the iv used to encrypt the data. // Block's block size and must match the iv used to encrypt the data.
func NewCBCDecrypter(b Block, iv []byte) BlockMode { func NewCBCDecrypter(b Block, iv []byte) BlockMode {
if len(iv) != b.BlockSize() {
panic("cipher.NewCBCDecrypter: IV length must equal block size")
}
return (*cbcDecrypter)(newCBC(b, iv)) return (*cbcDecrypter)(newCBC(b, iv))
} }

View File

@ -17,6 +17,9 @@ type cfb struct {
// using the given Block. The iv must be the same length as the Block's block // using the given Block. The iv must be the same length as the Block's block
// size. // size.
func NewCFBEncrypter(block Block, iv []byte) Stream { func NewCFBEncrypter(block Block, iv []byte) Stream {
if len(iv) != block.BlockSize() {
panic("cipher.NewCBFEncrypter: IV length must equal block size")
}
return newCFB(block, iv, false) return newCFB(block, iv, false)
} }
@ -24,6 +27,9 @@ func NewCFBEncrypter(block Block, iv []byte) Stream {
// using the given Block. The iv must be the same length as the Block's block // using the given Block. The iv must be the same length as the Block's block
// size. // size.
func NewCFBDecrypter(block Block, iv []byte) Stream { func NewCFBDecrypter(block Block, iv []byte) Stream {
if len(iv) != block.BlockSize() {
panic("cipher.NewCBFEncrypter: IV length must equal block size")
}
return newCFB(block, iv, true) return newCFB(block, iv, true)
} }

View File

@ -23,7 +23,7 @@ type ctr struct {
// counter mode. The length of iv must be the same as the Block's block size. // counter mode. The length of iv must be the same as the Block's block size.
func NewCTR(block Block, iv []byte) Stream { func NewCTR(block Block, iv []byte) Stream {
if len(iv) != block.BlockSize() { if len(iv) != block.BlockSize() {
panic("cipher.NewCTR: iv length must equal block size") panic("cipher.NewCTR: IV length must equal block size")
} }
return &ctr{ return &ctr{

View File

@ -0,0 +1,283 @@
// Copyright 2012 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package cipher_test
import (
"crypto/aes"
"crypto/cipher"
"crypto/rand"
"encoding/hex"
"fmt"
"io"
"os"
)
// ExampleNewCBCDecrypter decrypts a hex-encoded AES-CBC ciphertext
// whose IV is prepended to the ciphertext.
func ExampleNewCBCDecrypter() {
	key := []byte("example key 1234")
	ciphertext, _ := hex.DecodeString("f363f3ccdcb12bb883abf484ba77d9cd7d32b5baecb3d4b1b3e0e4beffdb3ded")

	block, err := aes.NewCipher(key)
	if err != nil {
		panic(err)
	}

	// The IV needs to be unique, but not secure. Therefore it's common to
	// include it at the beginning of the ciphertext.
	if len(ciphertext) < aes.BlockSize {
		panic("ciphertext too short")
	}
	iv := ciphertext[:aes.BlockSize]
	ciphertext = ciphertext[aes.BlockSize:]

	// CBC mode always works in whole blocks.
	if len(ciphertext)%aes.BlockSize != 0 {
		panic("ciphertext is not a multiple of the block size")
	}

	mode := cipher.NewCBCDecrypter(block, iv)

	// CryptBlocks can work in-place if the two arguments are the same.
	mode.CryptBlocks(ciphertext, ciphertext)

	// If the original plaintext lengths are not a multiple of the block
	// size, padding would have to be added when encrypting, which would be
	// removed at this point. For an example, see
	// https://tools.ietf.org/html/rfc5246#section-6.2.3.2. However, it's
	// critical to note that ciphertexts must be authenticated (i.e. by
	// using crypto/hmac) before being decrypted in order to avoid creating
	// a padding oracle.

	fmt.Printf("%s\n", ciphertext)
	// Output: exampleplaintext
}
// ExampleNewCBCEncrypter encrypts a block-aligned plaintext with
// AES-CBC, storing a random IV at the front of the ciphertext.
func ExampleNewCBCEncrypter() {
	key := []byte("example key 1234")
	plaintext := []byte("exampleplaintext")

	// CBC mode works on blocks so plaintexts may need to be padded to the
	// next whole block. For an example of such padding, see
	// https://tools.ietf.org/html/rfc5246#section-6.2.3.2. Here we'll
	// assume that the plaintext is already of the correct length.
	if len(plaintext)%aes.BlockSize != 0 {
		panic("plaintext is not a multiple of the block size")
	}

	block, err := aes.NewCipher(key)
	if err != nil {
		panic(err)
	}

	// The IV needs to be unique, but not secure. Therefore it's common to
	// include it at the beginning of the ciphertext.
	ciphertext := make([]byte, aes.BlockSize+len(plaintext))
	iv := ciphertext[:aes.BlockSize]
	if _, err := io.ReadFull(rand.Reader, iv); err != nil {
		panic(err)
	}

	mode := cipher.NewCBCEncrypter(block, iv)
	mode.CryptBlocks(ciphertext[aes.BlockSize:], plaintext)

	// It's important to remember that ciphertexts must be authenticated
	// (i.e. by using crypto/hmac) as well as being encrypted in order to
	// be secure.

	fmt.Printf("%x\n", ciphertext)
}
// ExampleNewCFBDecrypter decrypts an AES-CFB ciphertext whose IV is
// prepended to the ciphertext.
func ExampleNewCFBDecrypter() {
	key := []byte("example key 1234")
	ciphertext, _ := hex.DecodeString("22277966616d9bc47177bd02603d08c9a67d5380d0fe8cf3b44438dff7b9")

	block, err := aes.NewCipher(key)
	if err != nil {
		panic(err)
	}

	// The IV needs to be unique, but not secure. Therefore it's common to
	// include it at the beginning of the ciphertext.
	if len(ciphertext) < aes.BlockSize {
		panic("ciphertext too short")
	}
	iv := ciphertext[:aes.BlockSize]
	ciphertext = ciphertext[aes.BlockSize:]

	stream := cipher.NewCFBDecrypter(block, iv)

	// XORKeyStream can work in-place if the two arguments are the same.
	stream.XORKeyStream(ciphertext, ciphertext)
	fmt.Printf("%s", ciphertext)
	// Output: some plaintext
}
// ExampleNewCFBEncrypter encrypts a plaintext with AES-CFB, storing a
// random IV at the front of the ciphertext. CFB is a stream mode, so
// the plaintext need not be block-aligned.
func ExampleNewCFBEncrypter() {
	key := []byte("example key 1234")
	plaintext := []byte("some plaintext")

	block, err := aes.NewCipher(key)
	if err != nil {
		panic(err)
	}

	// The IV needs to be unique, but not secure. Therefore it's common to
	// include it at the beginning of the ciphertext.
	ciphertext := make([]byte, aes.BlockSize+len(plaintext))
	iv := ciphertext[:aes.BlockSize]
	if _, err := io.ReadFull(rand.Reader, iv); err != nil {
		panic(err)
	}

	stream := cipher.NewCFBEncrypter(block, iv)
	stream.XORKeyStream(ciphertext[aes.BlockSize:], plaintext)

	// It's important to remember that ciphertexts must be authenticated
	// (i.e. by using crypto/hmac) as well as being encrypted in order to
	// be secure.
}
// ExampleNewCTR encrypts and then decrypts a plaintext with AES-CTR,
// illustrating that the same stream construction serves both
// directions.
func ExampleNewCTR() {
	key := []byte("example key 1234")
	plaintext := []byte("some plaintext")

	block, err := aes.NewCipher(key)
	if err != nil {
		panic(err)
	}

	// The IV needs to be unique, but not secure. Therefore it's common to
	// include it at the beginning of the ciphertext.
	ciphertext := make([]byte, aes.BlockSize+len(plaintext))
	iv := ciphertext[:aes.BlockSize]
	if _, err := io.ReadFull(rand.Reader, iv); err != nil {
		panic(err)
	}

	stream := cipher.NewCTR(block, iv)
	stream.XORKeyStream(ciphertext[aes.BlockSize:], plaintext)

	// It's important to remember that ciphertexts must be authenticated
	// (i.e. by using crypto/hmac) as well as being encrypted in order to
	// be secure.

	// CTR mode is the same for both encryption and decryption, so we can
	// also decrypt that ciphertext with NewCTR.

	plaintext2 := make([]byte, len(plaintext))
	stream = cipher.NewCTR(block, iv)
	stream.XORKeyStream(plaintext2, ciphertext[aes.BlockSize:])

	fmt.Printf("%s\n", plaintext2)
	// Output: some plaintext
}
// ExampleNewOFB encrypts and then decrypts a plaintext with AES-OFB;
// like CTR, the same stream construction works for both directions.
func ExampleNewOFB() {
	key := []byte("example key 1234")
	plaintext := []byte("some plaintext")

	block, err := aes.NewCipher(key)
	if err != nil {
		panic(err)
	}

	// The IV needs to be unique, but not secure. Therefore it's common to
	// include it at the beginning of the ciphertext.
	ciphertext := make([]byte, aes.BlockSize+len(plaintext))
	iv := ciphertext[:aes.BlockSize]
	if _, err := io.ReadFull(rand.Reader, iv); err != nil {
		panic(err)
	}

	stream := cipher.NewOFB(block, iv)
	stream.XORKeyStream(ciphertext[aes.BlockSize:], plaintext)

	// It's important to remember that ciphertexts must be authenticated
	// (i.e. by using crypto/hmac) as well as being encrypted in order to
	// be secure.

	// OFB mode is the same for both encryption and decryption, so we can
	// also decrypt that ciphertext with NewOFB.

	plaintext2 := make([]byte, len(plaintext))
	stream = cipher.NewOFB(block, iv)
	stream.XORKeyStream(plaintext2, ciphertext[aes.BlockSize:])

	fmt.Printf("%s\n", plaintext2)
	// Output: some plaintext
}
// ExampleStreamReader shows how to decrypt a file by wrapping it in a
// cipher.StreamReader and streaming it through io.Copy.
func ExampleStreamReader() {
	key := []byte("example key 1234")

	inFile, err := os.Open("encrypted-file")
	if err != nil {
		panic(err)
	}
	defer inFile.Close()

	block, err := aes.NewCipher(key)
	if err != nil {
		panic(err)
	}

	// If the key is unique for each ciphertext, then it's ok to use a zero
	// IV.
	var iv [aes.BlockSize]byte
	stream := cipher.NewOFB(block, iv[:])

	outFile, err := os.OpenFile("decrypted-file", os.O_WRONLY|os.O_CREATE|os.O_TRUNC, 0600)
	if err != nil {
		panic(err)
	}
	defer outFile.Close()

	// Use keyed fields so the literal stays valid (and vet-clean) if the
	// struct ever grows.
	reader := &cipher.StreamReader{S: stream, R: inFile}
	// Copy the input file to the output file, decrypting as we go.
	if _, err := io.Copy(outFile, reader); err != nil {
		panic(err)
	}

	// Note that this example is simplistic in that it omits any
	// authentication of the encrypted data. If you were actually to use
	// StreamReader in this manner, an attacker could flip arbitrary bits in
	// the output.
}
// ExampleStreamWriter shows how to encrypt a file by wrapping the output
// in a cipher.StreamWriter and streaming it through io.Copy.
func ExampleStreamWriter() {
	key := []byte("example key 1234")

	inFile, err := os.Open("plaintext-file")
	if err != nil {
		panic(err)
	}
	defer inFile.Close()

	block, err := aes.NewCipher(key)
	if err != nil {
		panic(err)
	}

	// If the key is unique for each ciphertext, then it's ok to use a zero
	// IV.
	var iv [aes.BlockSize]byte
	stream := cipher.NewOFB(block, iv[:])

	outFile, err := os.OpenFile("encrypted-file", os.O_WRONLY|os.O_CREATE|os.O_TRUNC, 0600)
	if err != nil {
		panic(err)
	}
	defer outFile.Close()

	// Use keyed fields so the literal stays valid (and vet-clean) if the
	// struct ever grows; Err defaults to nil.
	writer := &cipher.StreamWriter{S: stream, W: outFile}
	// Copy the input file to the output file, encrypting as we go.
	if _, err := io.Copy(writer, inFile); err != nil {
		panic(err)
	}

	// Note that this example is simplistic in that it omits any
	// authentication of the encrypted data. If you were actually to use
	// StreamWriter in this manner, an attacker could flip arbitrary bits in
	// the decrypted result.
}

View File

@ -2,13 +2,27 @@
// Use of this source code is governed by a BSD-style // Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file. // license that can be found in the LICENSE file.
// Package hmac implements the Keyed-Hash Message Authentication Code (HMAC) as /*
// defined in U.S. Federal Information Processing Standards Publication 198. Package hmac implements the Keyed-Hash Message Authentication Code (HMAC) as
// An HMAC is a cryptographic hash that uses a key to sign a message. defined in U.S. Federal Information Processing Standards Publication 198.
// The receiver verifies the hash by recomputing it using the same key. An HMAC is a cryptographic hash that uses a key to sign a message.
The receiver verifies the hash by recomputing it using the same key.
Receivers should be careful to use Equal to compare MACs in order to avoid
timing side-channels:
// CheckMAC returns true if messageMAC is a valid HMAC tag for message.
func CheckMAC(message, messageMAC, key []byte) bool {
mac := hmac.New(sha256.New, key)
mac.Write(message)
expectedMAC := mac.Sum(nil)
return hmac.Equal(messageMAC, expectedMAC)
}
*/
package hmac package hmac
import ( import (
"crypto/subtle"
"hash" "hash"
) )
@ -57,7 +71,7 @@ func (h *hmac) BlockSize() int { return h.blocksize }
func (h *hmac) Reset() { func (h *hmac) Reset() {
h.inner.Reset() h.inner.Reset()
h.tmpPad(0x36) h.tmpPad(0x36)
h.inner.Write(h.tmp[0:h.blocksize]) h.inner.Write(h.tmp[:h.blocksize])
} }
// New returns a new HMAC hash using the given hash.Hash type and key. // New returns a new HMAC hash using the given hash.Hash type and key.
@ -78,3 +92,11 @@ func New(h func() hash.Hash, key []byte) hash.Hash {
hm.Reset() hm.Reset()
return hm return hm
} }
// Equal compares two MACs for equality without leaking timing information.
func Equal(mac1, mac2 []byte) bool {
// We don't have to be constant time if the lengths of the MACs are
// different as that suggests that a completely different hash function
// was used.
return len(mac1) == len(mac2) && subtle.ConstantTimeCompare(mac1, mac2) == 1
}

View File

@ -491,3 +491,22 @@ func TestHMAC(t *testing.T) {
} }
} }
} }
func TestEqual(t *testing.T) {
a := []byte("test")
b := []byte("test1")
c := []byte("test2")
if !Equal(b, b) {
t.Error("Equal failed with equal arguments")
}
if Equal(a, b) {
t.Error("Equal accepted a prefix of the second argument")
}
if Equal(b, a) {
t.Error("Equal accepted a prefix of the first argument")
}
if Equal(b, c) {
t.Error("Equal accepted unequal slices")
}
}

View File

@ -203,6 +203,8 @@ func block(dig *digest, p []byte) {
// less code and run 1.3x faster if we take advantage of that. // less code and run 1.3x faster if we take advantage of that.
// My apologies. // My apologies.
X = (*[16]uint32)(unsafe.Pointer(&p[0])) X = (*[16]uint32)(unsafe.Pointer(&p[0]))
} else if uintptr(unsafe.Pointer(&p[0]))&(unsafe.Alignof(uint32(0))-1) == 0 {
X = (*[16]uint32)(unsafe.Pointer(&p[0]))
} else { } else {
X = &xbuf X = &xbuf
j := 0 j := 0

View File

@ -9,6 +9,7 @@ import (
"fmt" "fmt"
"io" "io"
"testing" "testing"
"unsafe"
) )
type md5Test struct { type md5Test struct {
@ -54,13 +55,19 @@ func TestGolden(t *testing.T) {
for i := 0; i < len(golden); i++ { for i := 0; i < len(golden); i++ {
g := golden[i] g := golden[i]
c := md5.New() c := md5.New()
for j := 0; j < 3; j++ { buf := make([]byte, len(g.in)+4)
for j := 0; j < 3+4; j++ {
if j < 2 { if j < 2 {
io.WriteString(c, g.in) io.WriteString(c, g.in)
} else { } else if j == 2 {
io.WriteString(c, g.in[0:len(g.in)/2]) io.WriteString(c, g.in[0:len(g.in)/2])
c.Sum(nil) c.Sum(nil)
io.WriteString(c, g.in[len(g.in)/2:]) io.WriteString(c, g.in[len(g.in)/2:])
} else if j > 2 {
// test unaligned write
buf = buf[1:]
copy(buf, g.in)
c.Write(buf[:len(g.in)])
} }
s := fmt.Sprintf("%x", c.Sum(nil)) s := fmt.Sprintf("%x", c.Sum(nil))
if s != g.out { if s != g.out {
@ -80,26 +87,45 @@ func ExampleNew() {
} }
var bench = md5.New() var bench = md5.New()
var buf = makeBuf() var buf = make([]byte, 8192+1)
var sum = make([]byte, bench.Size())
func makeBuf() []byte { func benchmarkSize(b *testing.B, size int, unaligned bool) {
b := make([]byte, 8<<10) b.SetBytes(int64(size))
for i := range b { buf := buf
b[i] = byte(i) if unaligned {
if uintptr(unsafe.Pointer(&buf[0]))&(unsafe.Alignof(uint32(0))-1) == 0 {
buf = buf[1:]
} }
return b }
b.ResetTimer()
for i := 0; i < b.N; i++ {
bench.Reset()
bench.Write(buf[:size])
bench.Sum(sum[:0])
}
}
func BenchmarkHash8Bytes(b *testing.B) {
benchmarkSize(b, 8, false)
} }
func BenchmarkHash1K(b *testing.B) { func BenchmarkHash1K(b *testing.B) {
b.SetBytes(1024) benchmarkSize(b, 1024, false)
for i := 0; i < b.N; i++ {
bench.Write(buf[:1024])
}
} }
func BenchmarkHash8K(b *testing.B) { func BenchmarkHash8K(b *testing.B) {
b.SetBytes(int64(len(buf))) benchmarkSize(b, 8192, false)
for i := 0; i < b.N; i++ {
bench.Write(buf)
} }
func BenchmarkHash8BytesUnaligned(b *testing.B) {
benchmarkSize(b, 8, true)
}
func BenchmarkHash1KUnaligned(b *testing.B) {
benchmarkSize(b, 1024, true)
}
func BenchmarkHash8KUnaligned(b *testing.B) {
benchmarkSize(b, 8192, true)
} }

View File

@ -22,6 +22,8 @@ func block(dig *digest, p []byte) {
// less code and run 1.3x faster if we take advantage of that. // less code and run 1.3x faster if we take advantage of that.
// My apologies. // My apologies.
X = (*[16]uint32)(unsafe.Pointer(&p[0])) X = (*[16]uint32)(unsafe.Pointer(&p[0]))
} else if uintptr(unsafe.Pointer(&p[0]))&(unsafe.Alignof(uint32(0))-1) == 0 {
X = (*[16]uint32)(unsafe.Pointer(&p[0]))
} else { } else {
X = &xbuf X = &xbuf
j := 0 j := 0

View File

@ -116,7 +116,7 @@ func BenchmarkRSA2048Decrypt(b *testing.B) {
} }
priv.Precompute() priv.Precompute()
c := fromBase10("1000") c := fromBase10("8472002792838218989464636159316973636630013835787202418124758118372358261975764365740026024610403138425986214991379012696600761514742817632790916315594342398720903716529235119816755589383377471752116975374952783629225022962092351886861518911824745188989071172097120352727368980275252089141512321893536744324822590480751098257559766328893767334861211872318961900897793874075248286439689249972315699410830094164386544311554704755110361048571142336148077772023880664786019636334369759624917224888206329520528064315309519262325023881707530002540634660750469137117568199824615333883758410040459705787022909848740188613313")
b.StartTimer() b.StartTimer()
@ -141,7 +141,7 @@ func Benchmark3PrimeRSA2048Decrypt(b *testing.B) {
} }
priv.Precompute() priv.Precompute()
c := fromBase10("1000") c := fromBase10("8472002792838218989464636159316973636630013835787202418124758118372358261975764365740026024610403138425986214991379012696600761514742817632790916315594342398720903716529235119816755589383377471752116975374952783629225022962092351886861518911824745188989071172097120352727368980275252089141512321893536744324822590480751098257559766328893767334861211872318961900897793874075248286439689249972315699410830094164386544311554704755110361048571142336148077772023880664786019636334369759624917224888206329520528064315309519262325023881707530002540634660750469137117568199824615333883758410040459705787022909848740188613313")
b.StartTimer() b.StartTimer()

View File

@ -81,26 +81,26 @@ func ExampleNew() {
} }
var bench = sha1.New() var bench = sha1.New()
var buf = makeBuf() var buf = make([]byte, 8192)
func makeBuf() []byte { func benchmarkSize(b *testing.B, size int) {
b := make([]byte, 8<<10) b.SetBytes(int64(size))
for i := range b { sum := make([]byte, bench.Size())
b[i] = byte(i) for i := 0; i < b.N; i++ {
bench.Reset()
bench.Write(buf[:size])
bench.Sum(sum[:0])
} }
return b }
func BenchmarkHash8Bytes(b *testing.B) {
benchmarkSize(b, 8)
} }
func BenchmarkHash1K(b *testing.B) { func BenchmarkHash1K(b *testing.B) {
b.SetBytes(1024) benchmarkSize(b, 1024)
for i := 0; i < b.N; i++ {
bench.Write(buf[:1024])
}
} }
func BenchmarkHash8K(b *testing.B) { func BenchmarkHash8K(b *testing.B) {
b.SetBytes(int64(len(buf))) benchmarkSize(b, 8192)
for i := 0; i < b.N; i++ {
bench.Write(buf)
}
} }

View File

@ -16,7 +16,7 @@ const (
) )
func block(dig *digest, p []byte) { func block(dig *digest, p []byte) {
var w [80]uint32 var w [16]uint32
h0, h1, h2, h3, h4 := dig.h[0], dig.h[1], dig.h[2], dig.h[3], dig.h[4] h0, h1, h2, h3, h4 := dig.h[0], dig.h[1], dig.h[2], dig.h[3], dig.h[4]
for len(p) >= chunk { for len(p) >= chunk {
@ -26,42 +26,56 @@ func block(dig *digest, p []byte) {
j := i * 4 j := i * 4
w[i] = uint32(p[j])<<24 | uint32(p[j+1])<<16 | uint32(p[j+2])<<8 | uint32(p[j+3]) w[i] = uint32(p[j])<<24 | uint32(p[j+1])<<16 | uint32(p[j+2])<<8 | uint32(p[j+3])
} }
for i := 16; i < 80; i++ {
tmp := w[i-3] ^ w[i-8] ^ w[i-14] ^ w[i-16]
w[i] = tmp<<1 | tmp>>(32-1)
}
a, b, c, d, e := h0, h1, h2, h3, h4 a, b, c, d, e := h0, h1, h2, h3, h4
// Each of the four 20-iteration rounds // Each of the four 20-iteration rounds
// differs only in the computation of f and // differs only in the computation of f and
// the choice of K (_K0, _K1, etc). // the choice of K (_K0, _K1, etc).
for i := 0; i < 20; i++ { i := 0
for ; i < 16; i++ {
f := b&c | (^b)&d f := b&c | (^b)&d
a5 := a<<5 | a>>(32-5) a5 := a<<5 | a>>(32-5)
b30 := b<<30 | b>>(32-30) b30 := b<<30 | b>>(32-30)
t := a5 + f + e + w[i] + _K0 t := a5 + f + e + w[i&0xf] + _K0
a, b, c, d, e = t, a, b30, c, d a, b, c, d, e = t, a, b30, c, d
} }
for i := 20; i < 40; i++ { for ; i < 20; i++ {
tmp := w[(i-3)&0xf] ^ w[(i-8)&0xf] ^ w[(i-14)&0xf] ^ w[(i)&0xf]
w[i&0xf] = tmp<<1 | tmp>>(32-1)
f := b&c | (^b)&d
a5 := a<<5 | a>>(32-5)
b30 := b<<30 | b>>(32-30)
t := a5 + f + e + w[i&0xf] + _K0
a, b, c, d, e = t, a, b30, c, d
}
for ; i < 40; i++ {
tmp := w[(i-3)&0xf] ^ w[(i-8)&0xf] ^ w[(i-14)&0xf] ^ w[(i)&0xf]
w[i&0xf] = tmp<<1 | tmp>>(32-1)
f := b ^ c ^ d f := b ^ c ^ d
a5 := a<<5 | a>>(32-5) a5 := a<<5 | a>>(32-5)
b30 := b<<30 | b>>(32-30) b30 := b<<30 | b>>(32-30)
t := a5 + f + e + w[i] + _K1 t := a5 + f + e + w[i&0xf] + _K1
a, b, c, d, e = t, a, b30, c, d a, b, c, d, e = t, a, b30, c, d
} }
for i := 40; i < 60; i++ { for ; i < 60; i++ {
f := b&c | b&d | c&d tmp := w[(i-3)&0xf] ^ w[(i-8)&0xf] ^ w[(i-14)&0xf] ^ w[(i)&0xf]
w[i&0xf] = tmp<<1 | tmp>>(32-1)
f := ((b | c) & d) | (b & c)
a5 := a<<5 | a>>(32-5) a5 := a<<5 | a>>(32-5)
b30 := b<<30 | b>>(32-30) b30 := b<<30 | b>>(32-30)
t := a5 + f + e + w[i] + _K2 t := a5 + f + e + w[i&0xf] + _K2
a, b, c, d, e = t, a, b30, c, d a, b, c, d, e = t, a, b30, c, d
} }
for i := 60; i < 80; i++ { for ; i < 80; i++ {
tmp := w[(i-3)&0xf] ^ w[(i-8)&0xf] ^ w[(i-14)&0xf] ^ w[(i)&0xf]
w[i&0xf] = tmp<<1 | tmp>>(32-1)
f := b ^ c ^ d f := b ^ c ^ d
a5 := a<<5 | a>>(32-5) a5 := a<<5 | a>>(32-5)
b30 := b<<30 | b>>(32-30) b30 := b<<30 | b>>(32-30)
t := a5 + f + e + w[i] + _K3 t := a5 + f + e + w[i&0xf] + _K3
a, b, c, d, e = t, a, b30, c, d a, b, c, d, e = t, a, b30, c, d
} }

View File

@ -125,26 +125,26 @@ func TestGolden(t *testing.T) {
} }
var bench = New() var bench = New()
var buf = makeBuf() var buf = make([]byte, 8192)
func makeBuf() []byte { func benchmarkSize(b *testing.B, size int) {
b := make([]byte, 8<<10) b.SetBytes(int64(size))
for i := range b { sum := make([]byte, bench.Size())
b[i] = byte(i) for i := 0; i < b.N; i++ {
bench.Reset()
bench.Write(buf[:size])
bench.Sum(sum[:0])
} }
return b }
func BenchmarkHash8Bytes(b *testing.B) {
benchmarkSize(b, 8)
} }
func BenchmarkHash1K(b *testing.B) { func BenchmarkHash1K(b *testing.B) {
b.SetBytes(1024) benchmarkSize(b, 1024)
for i := 0; i < b.N; i++ {
bench.Write(buf[:1024])
}
} }
func BenchmarkHash8K(b *testing.B) { func BenchmarkHash8K(b *testing.B) {
b.SetBytes(int64(len(buf))) benchmarkSize(b, 8192)
for i := 0; i < b.N; i++ {
bench.Write(buf)
}
} }

View File

@ -125,26 +125,26 @@ func TestGolden(t *testing.T) {
} }
var bench = New() var bench = New()
var buf = makeBuf() var buf = make([]byte, 8192)
func makeBuf() []byte { func benchmarkSize(b *testing.B, size int) {
b := make([]byte, 8<<10) b.SetBytes(int64(size))
for i := range b { sum := make([]byte, bench.Size())
b[i] = byte(i) for i := 0; i < b.N; i++ {
bench.Reset()
bench.Write(buf[:size])
bench.Sum(sum[:0])
} }
return b }
func BenchmarkHash8Bytes(b *testing.B) {
benchmarkSize(b, 8)
} }
func BenchmarkHash1K(b *testing.B) { func BenchmarkHash1K(b *testing.B) {
b.SetBytes(1024) benchmarkSize(b, 1024)
for i := 0; i < b.N; i++ {
bench.Write(buf[:1024])
}
} }
func BenchmarkHash8K(b *testing.B) { func BenchmarkHash8K(b *testing.B) {
b.SetBytes(int64(len(buf))) benchmarkSize(b, 8192)
for i := 0; i < b.N; i++ {
bench.Write(buf)
}
} }

View File

@ -604,9 +604,11 @@ Again:
// sendAlert sends a TLS alert message. // sendAlert sends a TLS alert message.
// c.out.Mutex <= L. // c.out.Mutex <= L.
func (c *Conn) sendAlertLocked(err alert) error { func (c *Conn) sendAlertLocked(err alert) error {
c.tmp[0] = alertLevelError switch err {
if err == alertNoRenegotiation { case alertNoRenegotiation, alertCloseNotify:
c.tmp[0] = alertLevelWarning c.tmp[0] = alertLevelWarning
default:
c.tmp[0] = alertLevelError
} }
c.tmp[1] = byte(err) c.tmp[1] = byte(err)
c.writeRecord(recordTypeAlert, c.tmp[0:2]) c.writeRecord(recordTypeAlert, c.tmp[0:2])

View File

@ -246,15 +246,15 @@ var ecdheAESClientScript = [][]byte{
}, },
{ {
0x16, 0x03, 0x01, 0x00, 0x54, 0x02, 0x00, 0x00, 0x16, 0x03, 0x01, 0x00, 0x54, 0x02, 0x00, 0x00,
0x50, 0x03, 0x01, 0x4f, 0x7f, 0x24, 0x25, 0x10, 0x50, 0x03, 0x01, 0x50, 0x77, 0x31, 0xf7, 0x5b,
0xa8, 0x9d, 0xb1, 0x33, 0xd6, 0x53, 0x81, 0xce, 0xdb, 0x3d, 0x7a, 0x62, 0x76, 0x70, 0x95, 0x33,
0xb0, 0x69, 0xed, 0x1b, 0x9c, 0x5e, 0x40, 0x3a, 0x73, 0x71, 0x13, 0xfe, 0xa3, 0xb1, 0xd8, 0xb3,
0x4d, 0x06, 0xbc, 0xc7, 0x84, 0x51, 0x5a, 0x30, 0x4d, 0x0d, 0xdc, 0xfe, 0x58, 0x6e, 0x6a, 0x3a,
0x40, 0x50, 0x48, 0x20, 0xcd, 0x91, 0x80, 0x08, 0xf9, 0xde, 0xdc, 0x20, 0x8e, 0xfa, 0x3d, 0x60,
0xff, 0x82, 0x38, 0xc6, 0x03, 0x2d, 0x45, 0x4c, 0xd0, 0xda, 0xa4, 0x0e, 0x36, 0xf0, 0xde, 0xb6,
0x91, 0xbb, 0xcc, 0x27, 0x3d, 0x58, 0xff, 0x0d, 0x81, 0xb4, 0x80, 0x5e, 0xf9, 0xd2, 0x4c, 0xec,
0x26, 0x34, 0x7b, 0x48, 0x7a, 0xce, 0x25, 0x20, 0xd1, 0x9c, 0x2a, 0x81, 0xc3, 0x36, 0x0b, 0x0f,
0x90, 0x0f, 0x35, 0x9f, 0xc0, 0x13, 0x00, 0x00, 0x4a, 0x3d, 0xdf, 0x75, 0xc0, 0x13, 0x00, 0x00,
0x08, 0x00, 0x0b, 0x00, 0x04, 0x03, 0x00, 0x01, 0x08, 0x00, 0x0b, 0x00, 0x04, 0x03, 0x00, 0x01,
0x02, 0x16, 0x03, 0x01, 0x02, 0x39, 0x0b, 0x00, 0x02, 0x16, 0x03, 0x01, 0x02, 0x39, 0x0b, 0x00,
0x02, 0x35, 0x00, 0x02, 0x32, 0x00, 0x02, 0x2f, 0x02, 0x35, 0x00, 0x02, 0x32, 0x00, 0x02, 0x2f,
@ -329,23 +329,23 @@ var ecdheAESClientScript = [][]byte{
0xbb, 0x77, 0xba, 0xe4, 0x12, 0xbb, 0xf4, 0xc8, 0xbb, 0x77, 0xba, 0xe4, 0x12, 0xbb, 0xf4, 0xc8,
0x5e, 0x9c, 0x81, 0xa8, 0x97, 0x60, 0x4c, 0x16, 0x5e, 0x9c, 0x81, 0xa8, 0x97, 0x60, 0x4c, 0x16,
0x03, 0x01, 0x00, 0x8b, 0x0c, 0x00, 0x00, 0x87, 0x03, 0x01, 0x00, 0x8b, 0x0c, 0x00, 0x00, 0x87,
0x03, 0x00, 0x17, 0x41, 0x04, 0x0b, 0xe5, 0x39, 0x03, 0x00, 0x17, 0x41, 0x04, 0xec, 0x06, 0x1f,
0xde, 0x17, 0x7a, 0xaf, 0x96, 0xd5, 0x16, 0x01, 0xa0, 0x5e, 0x29, 0x49, 0x71, 0x8b, 0x04, 0x9f,
0xa8, 0x06, 0x80, 0x98, 0x75, 0x52, 0x56, 0x92, 0x47, 0x87, 0xb1, 0xcb, 0xae, 0x57, 0x8f, 0xd7,
0x15, 0xf9, 0x8d, 0xc0, 0x98, 0x62, 0xed, 0x54, 0xf6, 0xf8, 0x59, 0x74, 0x64, 0x5d, 0x3a, 0x08,
0xb7, 0xef, 0x03, 0x11, 0x34, 0x82, 0x65, 0xd1, 0xaf, 0x20, 0xc6, 0xd9, 0xfc, 0x5e, 0x36, 0x8b,
0xde, 0x25, 0x15, 0x4c, 0xf3, 0xdf, 0x4d, 0xbd, 0x62, 0x0e, 0xdb, 0xee, 0xd8, 0xcd, 0xef, 0x25,
0x6c, 0xed, 0x3d, 0xd6, 0x04, 0xcc, 0xd1, 0xf7, 0x8a, 0x38, 0x88, 0x2d, 0x5c, 0x71, 0x50, 0x22,
0x6d, 0x32, 0xb1, 0x1c, 0x59, 0xca, 0xfb, 0xbc, 0xda, 0x3f, 0x94, 0x06, 0xc9, 0x68, 0x5b, 0x78,
0x61, 0xeb, 0x4b, 0xe6, 0x00, 0x00, 0x40, 0x3e, 0x3d, 0x95, 0xca, 0x54, 0x44, 0x00, 0x40, 0x36,
0xe6, 0x23, 0x54, 0x61, 0x3f, 0x63, 0x16, 0xeb, 0xcf, 0x10, 0x81, 0xb4, 0x32, 0x45, 0x3c, 0xa5,
0x5c, 0xc3, 0xba, 0x8a, 0x19, 0x13, 0x60, 0x9f, 0x2d, 0x3e, 0xb0, 0xf8, 0xf4, 0x51, 0xf5, 0x28,
0x23, 0xbf, 0x36, 0x1a, 0x32, 0x7a, 0xae, 0x34, 0x09, 0x85, 0x71, 0xa6, 0x79, 0x71, 0x4b, 0x4e,
0x7f, 0x2f, 0x89, 0x85, 0xe1, 0x0e, 0x93, 0xd7, 0xda, 0x32, 0x5a, 0xc7, 0xb3, 0x57, 0xfd, 0xe8,
0xf0, 0xab, 0xa1, 0x0d, 0x54, 0x95, 0x79, 0x0b, 0x12, 0xab, 0xd8, 0x29, 0xfb, 0x8b, 0x43, 0x8f,
0xb4, 0xf1, 0x1c, 0x1d, 0x0f, 0x8c, 0x16, 0xec, 0x7e, 0x27, 0x63, 0x91, 0x84, 0x9c, 0x51, 0x0c,
0x82, 0x60, 0xee, 0xa3, 0x71, 0x2f, 0xaf, 0x3e, 0x26, 0x7e, 0x36, 0x3b, 0x37, 0x8d, 0x8f, 0x9e,
0xf1, 0xbd, 0xb5, 0x1b, 0x7f, 0xe0, 0xd2, 0x16, 0xe2, 0x82, 0x62, 0xbb, 0xe5, 0xdf, 0xfc, 0x16,
0x03, 0x01, 0x00, 0x04, 0x0e, 0x00, 0x00, 0x00, 0x03, 0x01, 0x00, 0x04, 0x0e, 0x00, 0x00, 0x00,
}, },
{ {
@ -359,34 +359,34 @@ var ecdheAESClientScript = [][]byte{
0xe2, 0x32, 0x42, 0xe9, 0x58, 0xb6, 0xd7, 0x49, 0xe2, 0x32, 0x42, 0xe9, 0x58, 0xb6, 0xd7, 0x49,
0xa6, 0xb5, 0x68, 0x1a, 0x41, 0x03, 0x56, 0x6b, 0xa6, 0xb5, 0x68, 0x1a, 0x41, 0x03, 0x56, 0x6b,
0xdc, 0x5a, 0x89, 0x14, 0x03, 0x01, 0x00, 0x01, 0xdc, 0x5a, 0x89, 0x14, 0x03, 0x01, 0x00, 0x01,
0x01, 0x16, 0x03, 0x01, 0x00, 0x30, 0x09, 0xac, 0x01, 0x16, 0x03, 0x01, 0x00, 0x30, 0x9a, 0xaa,
0xbe, 0x94, 0x75, 0x4d, 0x73, 0x45, 0xbd, 0xa8, 0xca, 0x5b, 0x57, 0xae, 0x34, 0x92, 0x80, 0x45,
0x0c, 0xe3, 0x5f, 0x72, 0x0b, 0x40, 0x4f, 0xd0, 0x7f, 0xe6, 0xf9, 0x09, 0x19, 0xd0, 0xf0, 0x1e,
0xd2, 0xcb, 0x16, 0x50, 0xfe, 0xdd, 0x1a, 0x33, 0x4b, 0xc3, 0xda, 0x71, 0xce, 0x34, 0x33, 0x56,
0x5c, 0x18, 0x37, 0x98, 0x42, 0xfc, 0x25, 0x42, 0x9f, 0x20, 0x9f, 0xf9, 0xa8, 0x62, 0x6c, 0x38,
0x33, 0xce, 0x60, 0xcf, 0x8e, 0x95, 0x6e, 0x48, 0x1b, 0x41, 0xf5, 0x54, 0xf2, 0x79, 0x42, 0x6c,
0xed, 0x00, 0x35, 0x50, 0x26, 0x7f, 0xb5, 0x0e, 0xe7, 0xe1, 0xbc, 0x54,
}, },
{ {
0x14, 0x03, 0x01, 0x00, 0x01, 0x01, 0x16, 0x03, 0x14, 0x03, 0x01, 0x00, 0x01, 0x01, 0x16, 0x03,
0x01, 0x00, 0x30, 0xf6, 0x6a, 0xdb, 0x83, 0xd4, 0x01, 0x00, 0x30, 0x62, 0x82, 0x41, 0x75, 0x2b,
0x3c, 0x77, 0x52, 0xad, 0xc0, 0x0f, 0x3a, 0x2c, 0xee, 0x0f, 0xdc, 0x6c, 0x48, 0x5a, 0x63, 0xd6,
0x42, 0xb9, 0x60, 0x4b, 0xb2, 0xf6, 0x84, 0xfd, 0xcb, 0x0a, 0xfd, 0x0a, 0x0e, 0xde, 0x8b, 0x41,
0x4e, 0x96, 0xfc, 0x15, 0xe7, 0x94, 0x25, 0xb0, 0x19, 0x0c, 0x13, 0x6b, 0x12, 0xd1, 0xc2, 0x53,
0x59, 0x9f, 0xdd, 0xb6, 0x58, 0x03, 0x13, 0x8d, 0xeb, 0x1e, 0xf3, 0x7a, 0xbf, 0x23, 0xc5, 0xa6,
0xeb, 0xb0, 0xad, 0x30, 0x31, 0x58, 0x6c, 0xa0, 0x81, 0xa1, 0xdb, 0xab, 0x2f, 0x2c, 0xbc, 0x35,
0x8f, 0x57, 0x50, 0x96, 0x72, 0x83,
}, },
{ {
0x17, 0x03, 0x01, 0x00, 0x20, 0xab, 0x64, 0x3d, 0x17, 0x03, 0x01, 0x00, 0x20, 0xaf, 0x5d, 0x35,
0x79, 0x69, 0x3e, 0xba, 0xc4, 0x24, 0x7b, 0xe5, 0x57, 0x10, 0x60, 0xb3, 0x25, 0x7c, 0x26, 0x0f,
0xe5, 0x23, 0x66, 0x6f, 0x32, 0xdf, 0x50, 0x7c, 0xf3, 0x5e, 0xb3, 0x0d, 0xad, 0x14, 0x53, 0xcc,
0x06, 0x2a, 0x02, 0x82, 0x79, 0x40, 0xdb, 0xb1, 0x0c, 0x08, 0xd9, 0xa2, 0x67, 0xab, 0xf4, 0x03,
0x04, 0xc0, 0x2b, 0xdc, 0x3a, 0x15, 0x03, 0x01, 0x17, 0x20, 0xf1, 0x7e, 0xca, 0x15, 0x03, 0x01,
0x00, 0x20, 0xf8, 0xad, 0xca, 0xd7, 0x96, 0xf0, 0x00, 0x20, 0x30, 0xd0, 0xc1, 0xfb, 0x5f, 0xa6,
0xd6, 0xa3, 0x62, 0xe1, 0x03, 0x44, 0xdb, 0xd0, 0x1b, 0xb4, 0x48, 0xc2, 0x0b, 0x98, 0xa8, 0x88,
0xc9, 0x63, 0x3e, 0x1b, 0x70, 0x41, 0x57, 0x0c, 0x7a, 0xba, 0xdf, 0x36, 0x06, 0xd8, 0xcc, 0xe9,
0xd8, 0x8e, 0x71, 0x49, 0x68, 0xe3, 0x04, 0x53, 0x34, 0xdd, 0x64, 0xc8, 0x73, 0xc5, 0xa2, 0x34,
0x5a, 0xbe, 0x64, 0xb7,
}, },
} }

File diff suppressed because it is too large Load Diff

View File

@ -51,7 +51,7 @@ func TestKeysFromPreMasterSecret(t *testing.T) {
masterSecret := masterFromPreMasterSecret(test.version, in, clientRandom, serverRandom) masterSecret := masterFromPreMasterSecret(test.version, in, clientRandom, serverRandom)
if s := hex.EncodeToString(masterSecret); s != test.masterSecret { if s := hex.EncodeToString(masterSecret); s != test.masterSecret {
t.Errorf("#%d: bad master secret %s, want %s", s, test.masterSecret) t.Errorf("#%d: bad master secret %s, want %s", i, s, test.masterSecret)
continue continue
} }

View File

@ -6,6 +6,8 @@
package tls package tls
import ( import (
"crypto"
"crypto/ecdsa"
"crypto/rsa" "crypto/rsa"
"crypto/x509" "crypto/x509"
"encoding/pem" "encoding/pem"
@ -153,30 +155,16 @@ func X509KeyPair(certPEMBlock, keyPEMBlock []byte) (cert Certificate, err error)
err = errors.New("crypto/tls: failed to parse key PEM data") err = errors.New("crypto/tls: failed to parse key PEM data")
return return
} }
if keyDERBlock.Type != "CERTIFICATE" { if strings.HasSuffix(keyDERBlock.Type, " PRIVATE KEY") {
break break
} }
} }
// OpenSSL 0.9.8 generates PKCS#1 private keys by default, while cert.PrivateKey, err = parsePrivateKey(keyDERBlock.Bytes)
// OpenSSL 1.0.0 generates PKCS#8 keys. We try both. if err != nil {
var key *rsa.PrivateKey
if key, err = x509.ParsePKCS1PrivateKey(keyDERBlock.Bytes); err != nil {
var privKey interface{}
if privKey, err = x509.ParsePKCS8PrivateKey(keyDERBlock.Bytes); err != nil {
err = errors.New("crypto/tls: failed to parse key: " + err.Error())
return return
} }
var ok bool
if key, ok = privKey.(*rsa.PrivateKey); !ok {
err = errors.New("crypto/tls: found non-RSA private key in PKCS#8 wrapping")
return
}
}
cert.PrivateKey = key
// We don't need to parse the public key for TLS, but we so do anyway // We don't need to parse the public key for TLS, but we so do anyway
// to check that it looks sane and matches the private key. // to check that it looks sane and matches the private key.
x509Cert, err := x509.ParseCertificate(cert.Certificate[0]) x509Cert, err := x509.ParseCertificate(cert.Certificate[0])
@ -184,10 +172,54 @@ func X509KeyPair(certPEMBlock, keyPEMBlock []byte) (cert Certificate, err error)
return return
} }
if x509Cert.PublicKeyAlgorithm != x509.RSA || x509Cert.PublicKey.(*rsa.PublicKey).N.Cmp(key.PublicKey.N) != 0 { switch pub := x509Cert.PublicKey.(type) {
case *rsa.PublicKey:
priv, ok := cert.PrivateKey.(*rsa.PrivateKey)
if !ok {
err = errors.New("crypto/tls: private key type does not match public key type")
return
}
if pub.N.Cmp(priv.N) != 0 {
err = errors.New("crypto/tls: private key does not match public key") err = errors.New("crypto/tls: private key does not match public key")
return return
} }
case *ecdsa.PublicKey:
priv, ok := cert.PrivateKey.(*ecdsa.PrivateKey)
if !ok {
err = errors.New("crypto/tls: private key type does not match public key type")
return
}
if pub.X.Cmp(priv.X) != 0 || pub.Y.Cmp(priv.Y) != 0 {
err = errors.New("crypto/tls: private key does not match public key")
return
}
default:
err = errors.New("crypto/tls: unknown public key algorithm")
return
}
return return
} }
// Attempt to parse the given private key DER block. OpenSSL 0.9.8 generates
// PKCS#1 private keys by default, while OpenSSL 1.0.0 generates PKCS#8 keys.
// OpenSSL ecparam generates SEC1 EC private keys for ECDSA. We try all three.
func parsePrivateKey(der []byte) (crypto.PrivateKey, error) {
if key, err := x509.ParsePKCS1PrivateKey(der); err == nil {
return key, nil
}
if key, err := x509.ParsePKCS8PrivateKey(der); err == nil {
switch key := key.(type) {
case *rsa.PrivateKey, *ecdsa.PrivateKey:
return key, nil
default:
return nil, errors.New("crypto/tls: found unknown private key type in PKCS#8 wrapping")
}
}
if key, err := x509.ParseECPrivateKey(der); err == nil {
return key, nil
}
return nil, errors.New("crypto/tls: failed to parse private key")
}

View File

@ -8,7 +8,7 @@ import (
"testing" "testing"
) )
var certPEM = `-----BEGIN CERTIFICATE----- var rsaCertPEM = `-----BEGIN CERTIFICATE-----
MIIB0zCCAX2gAwIBAgIJAI/M7BYjwB+uMA0GCSqGSIb3DQEBBQUAMEUxCzAJBgNV MIIB0zCCAX2gAwIBAgIJAI/M7BYjwB+uMA0GCSqGSIb3DQEBBQUAMEUxCzAJBgNV
BAYTAkFVMRMwEQYDVQQIDApTb21lLVN0YXRlMSEwHwYDVQQKDBhJbnRlcm5ldCBX BAYTAkFVMRMwEQYDVQQIDApTb21lLVN0YXRlMSEwHwYDVQQKDBhJbnRlcm5ldCBX
aWRnaXRzIFB0eSBMdGQwHhcNMTIwOTEyMjE1MjAyWhcNMTUwOTEyMjE1MjAyWjBF aWRnaXRzIFB0eSBMdGQwHhcNMTIwOTEyMjE1MjAyWhcNMTUwOTEyMjE1MjAyWjBF
@ -22,7 +22,7 @@ r5QuVbpQhH6u+0UgcW0jp9QwpxoPTLTWGXEWBBBurxFwiCBhkQ+V
-----END CERTIFICATE----- -----END CERTIFICATE-----
` `
var keyPEM = `-----BEGIN RSA PRIVATE KEY----- var rsaKeyPEM = `-----BEGIN RSA PRIVATE KEY-----
MIIBOwIBAAJBANLJhPHhITqQbPklG3ibCVxwGMRfp/v4XqhfdQHdcVfHap6NQ5Wo MIIBOwIBAAJBANLJhPHhITqQbPklG3ibCVxwGMRfp/v4XqhfdQHdcVfHap6NQ5Wo
k/4xIA+ui35/MmNartNuC+BdZ1tMuVCPFZcCAwEAAQJAEJ2N+zsR0Xn8/Q6twa4G k/4xIA+ui35/MmNartNuC+BdZ1tMuVCPFZcCAwEAAQJAEJ2N+zsR0Xn8/Q6twa4G
6OB1M1WO+k+ztnX/1SvNeWu8D6GImtupLTYgjZcHufykj09jiHmjHx8u8ZZB/o1N 6OB1M1WO+k+ztnX/1SvNeWu8D6GImtupLTYgjZcHufykj09jiHmjHx8u8ZZB/o1N
@ -33,15 +33,61 @@ D2lWusoe2/nEqfDVVWGWlyJ7yOmqaVm/iNUN9B2N2g==
-----END RSA PRIVATE KEY----- -----END RSA PRIVATE KEY-----
` `
func TestX509KeyPair(t *testing.T) { var ecdsaCertPEM = `-----BEGIN CERTIFICATE-----
_, err := X509KeyPair([]byte(keyPEM+certPEM), []byte(keyPEM+certPEM)) MIIB/jCCAWICCQDscdUxw16XFDAJBgcqhkjOPQQBMEUxCzAJBgNVBAYTAkFVMRMw
if err != nil { EQYDVQQIEwpTb21lLVN0YXRlMSEwHwYDVQQKExhJbnRlcm5ldCBXaWRnaXRzIFB0
t.Errorf("Failed to load key followed by cert: %s", err) eSBMdGQwHhcNMTIxMTE0MTI0MDQ4WhcNMTUxMTE0MTI0MDQ4WjBFMQswCQYDVQQG
EwJBVTETMBEGA1UECBMKU29tZS1TdGF0ZTEhMB8GA1UEChMYSW50ZXJuZXQgV2lk
Z2l0cyBQdHkgTHRkMIGbMBAGByqGSM49AgEGBSuBBAAjA4GGAAQBY9+my9OoeSUR
lDQdV/x8LsOuLilthhiS1Tz4aGDHIPwC1mlvnf7fg5lecYpMCrLLhauAc1UJXcgl
01xoLuzgtAEAgv2P/jgytzRSpUYvgLBt1UA0leLYBy6mQQbrNEuqT3INapKIcUv8
XxYP0xMEUksLPq6Ca+CRSqTtrd/23uTnapkwCQYHKoZIzj0EAQOBigAwgYYCQXJo
A7Sl2nLVf+4Iu/tAX/IF4MavARKC4PPHK3zfuGfPR3oCCcsAoz3kAzOeijvd0iXb
H5jBImIxPL4WxQNiBTexAkF8D1EtpYuWdlVQ80/h/f4pBcGiXPqX5h2PQSQY7hP1
+jwM1FGS4fREIOvlBYr/SzzQRtwrvrzGYxDEDbsC0ZGRnA==
-----END CERTIFICATE-----
`
var ecdsaKeyPEM = `-----BEGIN EC PARAMETERS-----
BgUrgQQAIw==
-----END EC PARAMETERS-----
-----BEGIN EC PRIVATE KEY-----
MIHcAgEBBEIBrsoKp0oqcv6/JovJJDoDVSGWdirrkgCWxrprGlzB9o0X8fV675X0
NwuBenXFfeZvVcwluO7/Q9wkYoPd/t3jGImgBwYFK4EEACOhgYkDgYYABAFj36bL
06h5JRGUNB1X/Hwuw64uKW2GGJLVPPhoYMcg/ALWaW+d/t+DmV5xikwKssuFq4Bz
VQldyCXTXGgu7OC0AQCC/Y/+ODK3NFKlRi+AsG3VQDSV4tgHLqZBBus0S6pPcg1q
kohxS/xfFg/TEwRSSws+roJr4JFKpO2t3/be5OdqmQ==
-----END EC PRIVATE KEY-----
`
var keyPairTests = []struct {
algo string
cert *string
key *string
}{
{"ECDSA", &ecdsaCertPEM, &ecdsaKeyPEM},
{"RSA", &rsaCertPEM, &rsaKeyPEM},
} }
_, err = X509KeyPair([]byte(certPEM+keyPEM), []byte(certPEM+keyPEM)) func TestX509KeyPair(t *testing.T) {
if err != nil { var pem []byte
t.Errorf("Failed to load cert followed by key: %s", err) for _, test := range keyPairTests {
println(err.Error()) pem = []byte(*test.cert + *test.key)
if _, err := X509KeyPair(pem, pem); err != nil {
t.Errorf("Failed to load %s cert followed by %s key: %s", test.algo, test.algo, err)
}
pem = []byte(*test.key + *test.cert)
if _, err := X509KeyPair(pem, pem); err != nil {
t.Errorf("Failed to load %s key followed by %s cert: %s", test.algo, test.algo, err)
}
}
}
func TestX509MixedKeyPair(t *testing.T) {
if _, err := X509KeyPair([]byte(rsaCertPEM), []byte(ecdsaKeyPEM)); err == nil {
t.Error("Load of RSA certificate succeeded with ECDSA private key")
}
if _, err := X509KeyPair([]byte(ecdsaCertPEM), []byte(rsaKeyPEM)); err == nil {
t.Error("Load of ECDSA certificate succeeded with RSA private key")
} }
} }

View File

@ -16,13 +16,64 @@ import (
"encoding/hex" "encoding/hex"
"encoding/pem" "encoding/pem"
"errors" "errors"
"io"
"strings" "strings"
) )
// rfc1423Algos represents how to create a block cipher for a decryption mode. type PEMCipher int
// Possible values for the EncryptPEMBlock encryption algorithm.
const (
_ PEMCipher = iota
PEMCipherDES
PEMCipher3DES
PEMCipherAES128
PEMCipherAES192
PEMCipherAES256
)
// rfc1423Algo holds a method for enciphering a PEM block.
type rfc1423Algo struct { type rfc1423Algo struct {
cipherFunc func([]byte) (cipher.Block, error) cipher PEMCipher
name string
cipherFunc func(key []byte) (cipher.Block, error)
keySize int keySize int
blockSize int
}
// rfc1423Algos holds a slice of the possible ways to encrypt a PEM
// block. The ivSize numbers were taken from the OpenSSL source.
var rfc1423Algos = []rfc1423Algo{{
cipher: PEMCipherDES,
name: "DES-CBC",
cipherFunc: des.NewCipher,
keySize: 8,
blockSize: des.BlockSize,
}, {
cipher: PEMCipher3DES,
name: "DES-EDE3-CBC",
cipherFunc: des.NewTripleDESCipher,
keySize: 24,
blockSize: des.BlockSize,
}, {
cipher: PEMCipherAES128,
name: "AES-128-CBC",
cipherFunc: aes.NewCipher,
keySize: 16,
blockSize: aes.BlockSize,
}, {
cipher: PEMCipherAES192,
name: "AES-192-CBC",
cipherFunc: aes.NewCipher,
keySize: 24,
blockSize: aes.BlockSize,
}, {
cipher: PEMCipherAES256,
name: "AES-256-CBC",
cipherFunc: aes.NewCipher,
keySize: 32,
blockSize: aes.BlockSize,
},
} }
// deriveKey uses a key derivation function to stretch the password into a key // deriveKey uses a key derivation function to stretch the password into a key
@ -41,20 +92,9 @@ func (c rfc1423Algo) deriveKey(password, salt []byte) []byte {
digest = hash.Sum(digest[:0]) digest = hash.Sum(digest[:0])
copy(out[i:], digest) copy(out[i:], digest)
} }
return out return out
} }
// rfc1423Algos is a mapping of encryption algorithm to an rfc1423Algo that can
// create block ciphers for that mode.
var rfc1423Algos = map[string]rfc1423Algo{
"DES-CBC": {des.NewCipher, 8},
"DES-EDE3-CBC": {des.NewTripleDESCipher, 24},
"AES-128-CBC": {aes.NewCipher, 16},
"AES-192-CBC": {aes.NewCipher, 24},
"AES-256-CBC": {aes.NewCipher, 32},
}
// IsEncryptedPEMBlock returns if the PEM block is password encrypted. // IsEncryptedPEMBlock returns if the PEM block is password encrypted.
func IsEncryptedPEMBlock(b *pem.Block) bool { func IsEncryptedPEMBlock(b *pem.Block) bool {
_, ok := b.Headers["DEK-Info"] _, ok := b.Headers["DEK-Info"]
@ -81,17 +121,16 @@ func DecryptPEMBlock(b *pem.Block, password []byte) ([]byte, error) {
} }
mode, hexIV := dek[:idx], dek[idx+1:] mode, hexIV := dek[:idx], dek[idx+1:]
ciph := cipherByName(mode)
if ciph == nil {
return nil, errors.New("x509: unknown encryption mode")
}
iv, err := hex.DecodeString(hexIV) iv, err := hex.DecodeString(hexIV)
if err != nil { if err != nil {
return nil, err return nil, err
} }
if len(iv) < 8 { if len(iv) != ciph.blockSize {
return nil, errors.New("x509: not enough bytes in IV") return nil, errors.New("x509: incorrect IV size")
}
ciph, ok := rfc1423Algos[mode]
if !ok {
return nil, errors.New("x509: unknown encryption mode")
} }
// Based on the OpenSSL implementation. The salt is the first 8 bytes // Based on the OpenSSL implementation. The salt is the first 8 bytes
@ -107,27 +146,88 @@ func DecryptPEMBlock(b *pem.Block, password []byte) ([]byte, error) {
dec.CryptBlocks(data, b.Bytes) dec.CryptBlocks(data, b.Bytes)
// Blocks are padded using a scheme where the last n bytes of padding are all // Blocks are padded using a scheme where the last n bytes of padding are all
// equal to n. It can pad from 1 to 8 bytes inclusive. See RFC 1423. // equal to n. It can pad from 1 to blocksize bytes inclusive. See RFC 1423.
// For example: // For example:
// [x y z 2 2] // [x y z 2 2]
// [x y 7 7 7 7 7 7 7] // [x y 7 7 7 7 7 7 7]
// If we detect a bad padding, we assume it is an invalid password. // If we detect a bad padding, we assume it is an invalid password.
dlen := len(data) dlen := len(data)
if dlen == 0 { if dlen == 0 || dlen%ciph.blockSize != 0 {
return nil, errors.New("x509: invalid padding") return nil, errors.New("x509: invalid padding")
} }
last := data[dlen-1] last := int(data[dlen-1])
if dlen < int(last) { if dlen < last {
return nil, IncorrectPasswordError return nil, IncorrectPasswordError
} }
if last == 0 || last > 8 { if last == 0 || last > ciph.blockSize {
return nil, IncorrectPasswordError return nil, IncorrectPasswordError
} }
for _, val := range data[dlen-int(last):] { for _, val := range data[dlen-last:] {
if val != last { if int(val) != last {
return nil, IncorrectPasswordError return nil, IncorrectPasswordError
} }
} }
return data[:dlen-last], nil
}
return data[:dlen-int(last)], nil // EncryptPEMBlock returns a PEM block of the specified type holding the
// given DER-encoded data encrypted with the specified algorithm and
// password.
// EncryptPEMBlock returns a PEM block of the specified type holding the
// given DER-encoded data encrypted with the specified algorithm and
// password.
func EncryptPEMBlock(rand io.Reader, blockType string, data, password []byte, alg PEMCipher) (*pem.Block, error) {
	ciph := cipherByKey(alg)
	if ciph == nil {
		return nil, errors.New("x509: unknown encryption mode")
	}

	// A fresh random IV is generated for every block; its first 8 bytes
	// double as the salt for key derivation, matching DecryptPEMBlock.
	iv := make([]byte, ciph.blockSize)
	if _, err := io.ReadFull(rand, iv); err != nil {
		return nil, errors.New("x509: cannot generate IV: " + err.Error())
	}
	key := ciph.deriveKey(password, iv[:8])
	block, err := ciph.cipherFunc(key)
	if err != nil {
		return nil, err
	}

	// RFC 1423, section 1.1 padding: append padLen bytes, each equal to
	// padLen, so the total length is a multiple of the cipher block size.
	// We could avoid this copy by encrypting the whole blocks of data
	// separately, but it doesn't seem worth the additional code.
	padLen := ciph.blockSize - len(data)%ciph.blockSize
	buf := make([]byte, len(data)+padLen)
	copy(buf, data)
	for i := len(data); i < len(buf); i++ {
		buf[i] = byte(padLen)
	}

	mode := cipher.NewCBCEncrypter(block, iv)
	mode.CryptBlocks(buf, buf)

	return &pem.Block{
		Type: blockType,
		Headers: map[string]string{
			"Proc-Type": "4,ENCRYPTED",
			"DEK-Info":  ciph.name + "," + hex.EncodeToString(iv),
		},
		Bytes: buf,
	}, nil
}
// cipherByName returns the rfc1423Algo whose DEK-Info mode name matches
// the given string, or nil if no algorithm matches.
func cipherByName(name string) *rfc1423Algo {
	for i := 0; i < len(rfc1423Algos); i++ {
		if candidate := &rfc1423Algos[i]; candidate.name == name {
			return candidate
		}
	}
	return nil
}
func cipherByKey(key PEMCipher) *rfc1423Algo {
for i := range rfc1423Algos {
alg := &rfc1423Algos[i]
if alg.cipher == key {
return alg
}
}
return nil
} }

View File

@ -5,34 +5,79 @@
package x509 package x509
import ( import (
"bytes"
"crypto/rand"
"encoding/base64"
"encoding/pem" "encoding/pem"
"testing" "testing"
) )
func TestDecrypt(t *testing.T) { func TestDecrypt(t *testing.T) {
for _, data := range testData { for i, data := range testData {
t.Logf("test %d. %s", i, data.kind)
block, rest := pem.Decode(data.pemData) block, rest := pem.Decode(data.pemData)
if len(rest) > 0 { if len(rest) > 0 {
t.Error(data.kind, "extra data") t.Error("extra data")
} }
der, err := DecryptPEMBlock(block, data.password) der, err := DecryptPEMBlock(block, data.password)
if err != nil { if err != nil {
t.Error(data.kind, err) t.Error("decrypt failed: ", err)
continue continue
} }
if _, err := ParsePKCS1PrivateKey(der); err != nil { if _, err := ParsePKCS1PrivateKey(der); err != nil {
t.Error(data.kind, "Invalid private key") t.Error("invalid private key: ", err)
}
plainDER, err := base64.StdEncoding.DecodeString(data.plainDER)
if err != nil {
t.Fatal("cannot decode test DER data: ", err)
}
if !bytes.Equal(der, plainDER) {
t.Error("data mismatch")
}
}
}
func TestEncrypt(t *testing.T) {
for i, data := range testData {
t.Logf("test %d. %s", i, data.kind)
plainDER, err := base64.StdEncoding.DecodeString(data.plainDER)
if err != nil {
t.Fatal("cannot decode test DER data: ", err)
}
password := []byte("kremvax1")
block, err := EncryptPEMBlock(rand.Reader, "RSA PRIVATE KEY", plainDER, password, data.kind)
if err != nil {
t.Error("encrypt: ", err)
continue
}
if !IsEncryptedPEMBlock(block) {
t.Error("PEM block does not appear to be encrypted")
}
if block.Type != "RSA PRIVATE KEY" {
t.Errorf("unexpected block type; got %q want %q", block.Type, "RSA PRIVATE KEY")
}
if block.Headers["Proc-Type"] != "4,ENCRYPTED" {
t.Errorf("block does not have correct Proc-Type header")
}
der, err := DecryptPEMBlock(block, password)
if err != nil {
t.Error("decrypt: ", err)
continue
}
if !bytes.Equal(der, plainDER) {
t.Errorf("data mismatch")
} }
} }
} }
var testData = []struct { var testData = []struct {
kind string kind PEMCipher
password []byte password []byte
pemData []byte pemData []byte
plainDER string
}{ }{
{ {
kind: "DES-CBC", kind: PEMCipherDES,
password: []byte("asdf"), password: []byte("asdf"),
pemData: []byte(` pemData: []byte(`
-----BEGIN RSA PRIVATE KEY----- -----BEGIN RSA PRIVATE KEY-----
@ -47,9 +92,17 @@ XOH9VfTjb52q/I8Suozq9coVQwg4tXfIoYUdT//O+mB7zJb9HI9Ps77b9TxDE6Gm
4C9brwZ3zg2vqXcwwV6QRZMtyll9rOpxkbw6NPlpfBqkc3xS51bbxivbO/Nve4KD 4C9brwZ3zg2vqXcwwV6QRZMtyll9rOpxkbw6NPlpfBqkc3xS51bbxivbO/Nve4KD
r12ymjFNF4stXCfJnNqKoZ50BHmEEUDu5Wb0fpVn82XrGw7CYc4iug== r12ymjFNF4stXCfJnNqKoZ50BHmEEUDu5Wb0fpVn82XrGw7CYc4iug==
-----END RSA PRIVATE KEY-----`), -----END RSA PRIVATE KEY-----`),
plainDER: `
MIIBPAIBAAJBAPASZe+tCPU6p80AjHhDkVsLYa51D35e/YGa8QcZyooeZM8EHozo
KD0fNiKI+53bHdy07N+81VQ8/ejPcRoXPlsCAwEAAQJBAMTxIuSq27VpR+zZ7WJf
c6fvv1OBvpMZ0/d1pxL/KnOAgq2rD5hDtk9b0LGhTPgQAmrrMTKuSeGoIuYE+gKQ
QvkCIQD+GC1m+/do+QRurr0uo46Kx1LzLeSCrjBk34wiOp2+dwIhAPHfTLRXS2fv
7rljm0bYa4+eDZpz+E8RcXEgzhhvcQQ9AiAI5eHZJGOyml3MXnQjiPi55WcDOw0w
glcRgT6QCEtz2wIhANSyqaFtosIkHKqrDUGfz/bb5tqMYTAnBruVPaf/WEOBAiEA
9xORWeRG1tRpso4+dYy4KdDkuLPIO01KY6neYGm3BCM=`,
}, },
{ {
kind: "DES-EDE3-CBC", kind: PEMCipher3DES,
password: []byte("asdf"), password: []byte("asdf"),
pemData: []byte(` pemData: []byte(`
-----BEGIN RSA PRIVATE KEY----- -----BEGIN RSA PRIVATE KEY-----
@ -64,9 +117,17 @@ ldw5w7WC7d13x2LsRkwo8ZrDKgIV+Y9GNvhuCCkTzNP0V3gNeJpd201HZHR+9n3w
3z0VjR/MGqsfcy1ziEWMNOO53At3zlG6zP05aHMnMcZoVXadEK6L1gz++inSSDCq 3z0VjR/MGqsfcy1ziEWMNOO53At3zlG6zP05aHMnMcZoVXadEK6L1gz++inSSDCq
gI0UJP4e3JVB7AkgYymYAwiYALAkoEIuanxoc50njJk= gI0UJP4e3JVB7AkgYymYAwiYALAkoEIuanxoc50njJk=
-----END RSA PRIVATE KEY-----`), -----END RSA PRIVATE KEY-----`),
plainDER: `
MIIBOwIBAAJBANOCXKdoNS/iP/MAbl9cf1/SF3P+Ns7ZeNL27CfmDh0O6Zduaax5
NBiumd2PmjkaCu7lQ5JOibHfWn+xJsc3kw0CAwEAAQJANX/W8d1Q/sCqzkuAn4xl
B5a7qfJWaLHndu1QRLNTRJPn0Ee7OKJ4H0QKOhQM6vpjRrz+P2u9thn6wUxoPsef
QQIhAP/jCkfejFcy4v15beqKzwz08/tslVjF+Yq41eJGejmxAiEA05pMoqfkyjcx
fyvGhpoOyoCp71vSGUfR2I9CR65oKh0CIC1Msjs66LlfJtQctRq6bCEtFCxEcsP+
eEjYo/Sk6WphAiEAxpgWPMJeU/shFT28gS+tmhjPZLpEoT1qkVlC14u0b3ECIQDX
tZZZxCtPAm7shftEib0VU77Lk8MsXJcx2C4voRsjEw==`,
}, },
{ {
kind: "AES-128-CBC", kind: PEMCipherAES128,
password: []byte("asdf"), password: []byte("asdf"),
pemData: []byte(` pemData: []byte(`
-----BEGIN RSA PRIVATE KEY----- -----BEGIN RSA PRIVATE KEY-----
@ -81,9 +142,17 @@ GZbBpf1jDH/pr0iGonuAdl2PCCZUiy+8eLsD2tyviHUkFLOB+ykYoJ5t8ngZ/B6D
080LzLHPCrXKdlr/f50yhNWq08ZxMWQFkui+FDHPDUaEELKAXV8/5PDxw80Rtybo 080LzLHPCrXKdlr/f50yhNWq08ZxMWQFkui+FDHPDUaEELKAXV8/5PDxw80Rtybo
AVYoCVIbZXZCuCO81op8UcOgEpTtyU5Lgh3Mw5scQL0= AVYoCVIbZXZCuCO81op8UcOgEpTtyU5Lgh3Mw5scQL0=
-----END RSA PRIVATE KEY-----`), -----END RSA PRIVATE KEY-----`),
plainDER: `
MIIBOgIBAAJBAMBlj5FxYtqbcy8wY89d/S7n0+r5MzD9F63BA/Lpl78vQKtdJ5dT
cDGh/rBt1ufRrNp0WihcmZi7Mpl/3jHjiWECAwEAAQJABNOHYnKhtDIqFYj1OAJ3
k3GlU0OlERmIOoeY/cL2V4lgwllPBEs7r134AY4wMmZSBUj8UR/O4SNO668ElKPE
cQIhAOuqY7/115x5KCdGDMWi+jNaMxIvI4ETGwV40ykGzqlzAiEA0P9oEC3m9tHB
kbpjSTxaNkrXxDgdEOZz8X0uOUUwHNsCIAwzcSCiGLyYJTULUmP1ESERfW1mlV78
XzzESaJpIM/zAiBQkSTcl9VhcJreQqvjn5BnPZLP4ZHS4gPwJAGdsj5J4QIhAOVR
B3WlRNTXR2WsJ5JdByezg9xzdXzULqmga0OE339a`,
}, },
{ {
kind: "AES-192-CBC", kind: PEMCipherAES192,
password: []byte("asdf"), password: []byte("asdf"),
pemData: []byte(` pemData: []byte(`
-----BEGIN RSA PRIVATE KEY----- -----BEGIN RSA PRIVATE KEY-----
@ -98,9 +167,17 @@ ReUtTw8exmKsY4gsSjhkg5uiw7/ZB1Ihto0qnfQJgjGc680qGkT1d6JfvOfeYAk6
xn5RqS/h8rYAYm64KnepfC9vIujo4NqpaREDmaLdX5MJPQ+SlytITQvgUsUq3q/t xn5RqS/h8rYAYm64KnepfC9vIujo4NqpaREDmaLdX5MJPQ+SlytITQvgUsUq3q/t
Ss85xjQEZH3hzwjQqdJvmA4hYP6SUjxYpBM+02xZ1Xw= Ss85xjQEZH3hzwjQqdJvmA4hYP6SUjxYpBM+02xZ1Xw=
-----END RSA PRIVATE KEY-----`), -----END RSA PRIVATE KEY-----`),
plainDER: `
MIIBOwIBAAJBAMGcRrZiNNmtF20zyS6MQ7pdGx17aFDl+lTl+qnLuJRUCMUG05xs
OmxmL/O1Qlf+bnqR8Bgg65SfKg21SYuLhiMCAwEAAQJBAL94uuHyO4wux2VC+qpj
IzPykjdU7XRcDHbbvksf4xokSeUFjjD3PB0Qa83M94y89ZfdILIqS9x5EgSB4/lX
qNkCIQD6cCIqLfzq/lYbZbQgAAjpBXeQVYsbvVtJrPrXJAlVVQIhAMXpDKMeFPMn
J0g2rbx1gngx0qOa5r5iMU5w/noN4W2XAiBjf+WzCG5yFvazD+dOx3TC0A8+4x3P
uZ3pWbaXf5PNuQIgAcdXarvhelH2w2piY1g3BPeFqhzBSCK/yLGxR82KIh8CIQDD
+qGKsd09NhQ/G27y/DARzOYtml1NvdmCQAgsDIIOLA==`,
}, },
{ {
kind: "AES-256-CBC", kind: PEMCipherAES256,
password: []byte("asdf"), password: []byte("asdf"),
pemData: []byte(` pemData: []byte(`
-----BEGIN RSA PRIVATE KEY----- -----BEGIN RSA PRIVATE KEY-----
@ -115,5 +192,32 @@ Pz3RZScwIuubzTGJ1x8EzdffYOsdCa9Mtgpp3L136+23dOd6L/qK2EG2fzrJSHs/
sv5Z/KwlX+3MDEpPQpUwGPlGGdLnjI3UZ+cjgqBcoMiNc6HfgbBgYJSU6aDSHuCk sv5Z/KwlX+3MDEpPQpUwGPlGGdLnjI3UZ+cjgqBcoMiNc6HfgbBgYJSU6aDSHuCk
clCwByxWkBNgJ2GrkwNrF26v+bGJJJNR4SKouY1jQf0= clCwByxWkBNgJ2GrkwNrF26v+bGJJJNR4SKouY1jQf0=
-----END RSA PRIVATE KEY-----`), -----END RSA PRIVATE KEY-----`),
plainDER: `
MIIBOgIBAAJBAKy3GFkstoCHIEeUU/qO8207m8WSrjksR+p9B4tf1w5k+2O1V/GY
AQ5WFCApItcOkQe/I0yZZJk/PmCqMzSxrc8CAwEAAQJAOCAz0F7AW9oNelVQSP8F
Sfzx7O1yom+qWyAQQJF/gFR11gpf9xpVnnyu1WxIRnDUh1LZwUsjwlDYb7MB74id
oQIhANPcOiLwOPT4sIUpRM5HG6BF1BI7L77VpyGVk8xNP7X/AiEA0LMHZtk4I+lJ
nClgYp4Yh2JZ1Znbu7IoQMCEJCjwKDECIGd8Dzm5tViTkUW6Hs3Tlf73nNs65duF
aRnSglss8I3pAiEAonEnKruawgD8RavDFR+fUgmQiPz4FnGGeVgfwpGG1JECIBYq
PXHYtPqxQIbD2pScR5qum7iGUh11lEUPkmt+2uqS`,
},
{
// generated with:
// openssl genrsa -aes128 -passout pass:asdf -out server.orig.key 128
kind: PEMCipherAES128,
password: []byte("asdf"),
pemData: []byte(`
-----BEGIN RSA PRIVATE KEY-----
Proc-Type: 4,ENCRYPTED
DEK-Info: AES-128-CBC,74611ABC2571AF11B1BF9B69E62C89E7
6ei/MlytjE0FFgZOGQ+jrwomKfpl8kdefeE0NSt/DMRrw8OacHAzBNi3pPEa0eX3
eND9l7C9meCirWovjj9QWVHrXyugFuDIqgdhQ8iHTgCfF3lrmcttVrbIfMDw+smD
hTP8O1mS/MHl92NE0nhv0w==
-----END RSA PRIVATE KEY-----`),
plainDER: `
MGMCAQACEQC6ssxmYuauuHGOCDAI54RdAgMBAAECEQCWIn6Yv2O+kBcDF7STctKB
AgkA8SEfu/2i3g0CCQDGNlXbBHX7kQIIK3Ww5o0cYbECCQDCimPb0dYGsQIIeQ7A
jryIst8=`,
}, },
} }

View File

@ -12,7 +12,8 @@ import (
) )
// pkcs8 reflects an ASN.1, PKCS#8 PrivateKey. See // pkcs8 reflects an ASN.1, PKCS#8 PrivateKey. See
// ftp://ftp.rsasecurity.com/pub/pkcs/pkcs-8/pkcs-8v1_2.asn. // ftp://ftp.rsasecurity.com/pub/pkcs/pkcs-8/pkcs-8v1_2.asn
// and RFC5208.
type pkcs8 struct { type pkcs8 struct {
Version int Version int
Algo pkix.AlgorithmIdentifier Algo pkix.AlgorithmIdentifier
@ -21,7 +22,7 @@ type pkcs8 struct {
} }
// ParsePKCS8PrivateKey parses an unencrypted, PKCS#8 private key. See // ParsePKCS8PrivateKey parses an unencrypted, PKCS#8 private key. See
// http://www.rsa.com/rsalabs/node.asp?id=2130 // http://www.rsa.com/rsalabs/node.asp?id=2130 and RFC5208.
func ParsePKCS8PrivateKey(der []byte) (key interface{}, err error) { func ParsePKCS8PrivateKey(der []byte) (key interface{}, err error) {
var privKey pkcs8 var privKey pkcs8
if _, err := asn1.Unmarshal(der, &privKey); err != nil { if _, err := asn1.Unmarshal(der, &privKey); err != nil {
@ -34,6 +35,19 @@ func ParsePKCS8PrivateKey(der []byte) (key interface{}, err error) {
return nil, errors.New("crypto/x509: failed to parse RSA private key embedded in PKCS#8: " + err.Error()) return nil, errors.New("crypto/x509: failed to parse RSA private key embedded in PKCS#8: " + err.Error())
} }
return key, nil return key, nil
case privKey.Algo.Algorithm.Equal(oidPublicKeyECDSA):
bytes := privKey.Algo.Parameters.FullBytes
namedCurveOID := new(asn1.ObjectIdentifier)
if _, err := asn1.Unmarshal(bytes, namedCurveOID); err != nil {
namedCurveOID = nil
}
key, err = parseECPrivateKey(namedCurveOID, privKey.PrivateKey)
if err != nil {
return nil, errors.New("crypto/x509: failed to parse EC private key embedded in PKCS#8: " + err.Error())
}
return key, nil
default: default:
return nil, fmt.Errorf("crypto/x509: PKCS#8 wrapping contained private key with unknown algorithm: %v", privKey.Algo.Algorithm) return nil, fmt.Errorf("crypto/x509: PKCS#8 wrapping contained private key with unknown algorithm: %v", privKey.Algo.Algorithm)
} }

View File

@ -9,12 +9,20 @@ import (
"testing" "testing"
) )
var pkcs8PrivateKeyHex = `30820278020100300d06092a864886f70d0101010500048202623082025e02010002818100cfb1b5bf9685ffa97b4f99df4ff122b70e59ac9b992f3bc2b3dde17d53c1a34928719b02e8fd17839499bfbd515bd6ef99c7a1c47a239718fe36bfd824c0d96060084b5f67f0273443007a24dfaf5634f7772c9346e10eb294c2306671a5a5e719ae24b4de467291bc571014b0e02dec04534d66a9bb171d644b66b091780e8d020301000102818100b595778383c4afdbab95d2bfed12b3f93bb0a73a7ad952f44d7185fd9ec6c34de8f03a48770f2009c8580bcd275e9632714e9a5e3f32f29dc55474b2329ff0ebc08b3ffcb35bc96e6516b483df80a4a59cceb71918cbabf91564e64a39d7e35dce21cb3031824fdbc845dba6458852ec16af5dddf51a8397a8797ae0337b1439024100ea0eb1b914158c70db39031dd8904d6f18f408c85fbbc592d7d20dee7986969efbda081fdf8bc40e1b1336d6b638110c836bfdc3f314560d2e49cd4fbde1e20b024100e32a4e793b574c9c4a94c8803db5152141e72d03de64e54ef2c8ed104988ca780cd11397bc359630d01b97ebd87067c5451ba777cf045ca23f5912f1031308c702406dfcdbbd5a57c9f85abc4edf9e9e29153507b07ce0a7ef6f52e60dcfebe1b8341babd8b789a837485da6c8d55b29bbb142ace3c24a1f5b54b454d01b51e2ad03024100bd6a2b60dee01e1b3bfcef6a2f09ed027c273cdbbaf6ba55a80f6dcc64e4509ee560f84b4f3e076bd03b11e42fe71a3fdd2dffe7e0902c8584f8cad877cdc945024100aa512fa4ada69881f1d8bb8ad6614f192b83200aef5edf4811313d5ef30a86cbd0a90f7b025c71ea06ec6b34db6306c86b1040670fd8654ad7291d066d06d031` var pkcs8RSAPrivateKeyHex = 
`30820278020100300d06092a864886f70d0101010500048202623082025e02010002818100cfb1b5bf9685ffa97b4f99df4ff122b70e59ac9b992f3bc2b3dde17d53c1a34928719b02e8fd17839499bfbd515bd6ef99c7a1c47a239718fe36bfd824c0d96060084b5f67f0273443007a24dfaf5634f7772c9346e10eb294c2306671a5a5e719ae24b4de467291bc571014b0e02dec04534d66a9bb171d644b66b091780e8d020301000102818100b595778383c4afdbab95d2bfed12b3f93bb0a73a7ad952f44d7185fd9ec6c34de8f03a48770f2009c8580bcd275e9632714e9a5e3f32f29dc55474b2329ff0ebc08b3ffcb35bc96e6516b483df80a4a59cceb71918cbabf91564e64a39d7e35dce21cb3031824fdbc845dba6458852ec16af5dddf51a8397a8797ae0337b1439024100ea0eb1b914158c70db39031dd8904d6f18f408c85fbbc592d7d20dee7986969efbda081fdf8bc40e1b1336d6b638110c836bfdc3f314560d2e49cd4fbde1e20b024100e32a4e793b574c9c4a94c8803db5152141e72d03de64e54ef2c8ed104988ca780cd11397bc359630d01b97ebd87067c5451ba777cf045ca23f5912f1031308c702406dfcdbbd5a57c9f85abc4edf9e9e29153507b07ce0a7ef6f52e60dcfebe1b8341babd8b789a837485da6c8d55b29bbb142ace3c24a1f5b54b454d01b51e2ad03024100bd6a2b60dee01e1b3bfcef6a2f09ed027c273cdbbaf6ba55a80f6dcc64e4509ee560f84b4f3e076bd03b11e42fe71a3fdd2dffe7e0902c8584f8cad877cdc945024100aa512fa4ada69881f1d8bb8ad6614f192b83200aef5edf4811313d5ef30a86cbd0a90f7b025c71ea06ec6b34db6306c86b1040670fd8654ad7291d066d06d031`
// Generated using:
// openssl ecparam -genkey -name secp521r1 | openssl pkcs8 -topk8 -nocrypt
var pkcs8ECPrivateKeyHex = `3081ed020100301006072a8648ce3d020106052b810400230481d53081d20201010441850d81618c5da1aec74c2eed608ba816038506975e6427237c2def150c96a3b13efbfa1f89f1be15cdf4d0ac26422e680e65a0ddd4ad3541ad76165fbf54d6e34ba18189038186000400da97bcedba1eb6d30aeb93c9f9a1454598fa47278df27d6f60ea73eb672d8dc528a9b67885b5b5dcef93c9824f7449ab512ee6a27e76142f56b94b474cfd697e810046c8ca70419365245c1d7d44d0db82c334073835d002232714548abbae6e5700f5ef315ee08b929d8581383dcf2d1c98c2f8a9fccbf79c9579f7b2fd8a90115ac2`
func TestPKCS8(t *testing.T) { func TestPKCS8(t *testing.T) {
derBytes, _ := hex.DecodeString(pkcs8PrivateKeyHex) derBytes, _ := hex.DecodeString(pkcs8RSAPrivateKeyHex)
_, err := ParsePKCS8PrivateKey(derBytes) if _, err := ParsePKCS8PrivateKey(derBytes); err != nil {
if err != nil { t.Errorf("failed to decode PKCS8 with RSA private key: %s", err)
t.Errorf("failed to decode PKCS8 key: %s", err) }
derBytes, _ = hex.DecodeString(pkcs8ECPrivateKeyHex)
if _, err := ParsePKCS8PrivateKey(derBytes); err != nil {
t.Errorf("failed to decode PKCS8 with EC private key: %s", err)
} }
} }

View File

@ -0,0 +1,69 @@
// Copyright 2012 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package x509
import (
"crypto/ecdsa"
"crypto/elliptic"
"encoding/asn1"
"errors"
"fmt"
"math/big"
)
// ecPrivKeyVersion is the only structure version this parser accepts;
// RFC 5915 mandates version 1.
const ecPrivKeyVersion = 1

// ecPrivateKey reflects an ASN.1 Elliptic Curve Private Key Structure.
// References:
//   RFC5915
//   SEC1 - http://www.secg.org/download/aid-780/sec1-v2.pdf
// Per RFC5915 the NamedCurveOID is marked as ASN.1 OPTIONAL, however in
// most cases it is not.
type ecPrivateKey struct {
	Version       int
	PrivateKey    []byte
	NamedCurveOID asn1.ObjectIdentifier `asn1:"optional,explicit,tag:0"` // may instead come from a PKCS#8 wrapper
	PublicKey     asn1.BitString        `asn1:"optional,explicit,tag:1"`
}
// ParseECPrivateKey parses an ASN.1 Elliptic Curve Private Key Structure.
func ParseECPrivateKey(der []byte) (key *ecdsa.PrivateKey, err error) {
	// Delegate to the internal parser with no externally supplied curve OID.
	key, err = parseECPrivateKey(nil, der)
	return key, err
}
// parseECPrivateKey parses an ASN.1 Elliptic Curve Private Key Structure.
// The OID for the named curve may be provided from another source (such as
// the PKCS8 container) - if it is provided then use this instead of the OID
// that may exist in the EC private key structure.
func parseECPrivateKey(namedCurveOID *asn1.ObjectIdentifier, der []byte) (key *ecdsa.PrivateKey, err error) {
	var ecKey ecPrivateKey
	if _, err := asn1.Unmarshal(der, &ecKey); err != nil {
		return nil, errors.New("crypto/x509: failed to parse EC private key: " + err.Error())
	}
	if ecKey.Version != ecPrivKeyVersion {
		return nil, fmt.Errorf("crypto/x509: unknown EC private key version %d", ecKey.Version)
	}

	// Prefer an externally supplied curve OID (e.g. from a PKCS#8
	// wrapper) over the one embedded in the structure itself.
	oid := ecKey.NamedCurveOID
	if namedCurveOID != nil {
		oid = *namedCurveOID
	}
	curve := namedCurveFromOID(oid)
	if curve == nil {
		return nil, errors.New("crypto/x509: unknown elliptic curve")
	}

	// The scalar must be strictly less than the curve order N.
	d := new(big.Int).SetBytes(ecKey.PrivateKey)
	if d.Cmp(curve.Params().N) >= 0 {
		return nil, errors.New("crypto/x509: invalid elliptic curve private key value")
	}

	priv := &ecdsa.PrivateKey{D: d}
	priv.Curve = curve
	priv.X, priv.Y = curve.ScalarBaseMult(ecKey.PrivateKey)
	return priv, nil
}

View File

@ -0,0 +1,22 @@
// Copyright 2012 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package x509
import (
"encoding/hex"
"testing"
)
// Generated using:
// openssl ecparam -genkey -name secp384r1 -outform PEM
var ecPrivateKeyHex = `3081a40201010430bdb9839c08ee793d1157886a7a758a3c8b2a17a4df48f17ace57c72c56b4723cf21dcda21d4e1ad57ff034f19fcfd98ea00706052b81040022a16403620004feea808b5ee2429cfcce13c32160e1c960990bd050bb0fdf7222f3decd0a55008e32a6aa3c9062051c4cba92a7a3b178b24567412d43cdd2f882fa5addddd726fe3e208d2c26d733a773a597abb749714df7256ead5105fa6e7b3650de236b50`
// TestParseECPrivateKey checks that a known-good SEC1 EC private key
// (secp384r1, generated with openssl ecparam) decodes without error.
func TestParseECPrivateKey(t *testing.T) {
	derBytes, _ := hex.DecodeString(ecPrivateKeyHex)
	if _, err := ParseECPrivateKey(derBytes); err != nil {
		t.Errorf("failed to decode EC private key: %s", err)
	}
}

View File

@ -1224,7 +1224,7 @@ func CreateCertificate(rand io.Reader, template, parent *Certificate, pub interf
SerialNumber: template.SerialNumber, SerialNumber: template.SerialNumber,
SignatureAlgorithm: signatureAlgorithm, SignatureAlgorithm: signatureAlgorithm,
Issuer: asn1.RawValue{FullBytes: asn1Issuer}, Issuer: asn1.RawValue{FullBytes: asn1Issuer},
Validity: validity{template.NotBefore, template.NotAfter}, Validity: validity{template.NotBefore.UTC(), template.NotAfter.UTC()},
Subject: asn1.RawValue{FullBytes: asn1Subject}, Subject: asn1.RawValue{FullBytes: asn1Subject},
PublicKey: publicKeyInfo{nil, publicKeyAlgorithm, encodedPublicKey}, PublicKey: publicKeyInfo{nil, publicKeyAlgorithm, encodedPublicKey},
Extensions: extensions, Extensions: extensions,
@ -1314,8 +1314,8 @@ func (c *Certificate) CreateCRL(rand io.Reader, priv interface{}, revokedCerts [
Algorithm: oidSignatureSHA1WithRSA, Algorithm: oidSignatureSHA1WithRSA,
}, },
Issuer: c.Subject.ToRDNSequence(), Issuer: c.Subject.ToRDNSequence(),
ThisUpdate: now, ThisUpdate: now.UTC(),
NextUpdate: expiry, NextUpdate: expiry.UTC(),
RevokedCertificates: revokedCerts, RevokedCertificates: revokedCerts,
} }

View File

@ -417,10 +417,6 @@ func (f *File) getSymbols32(typ SectionType) ([]Symbol, []byte, error) {
return nil, nil, errors.New("cannot load string table section") return nil, nil, errors.New("cannot load string table section")
} }
// The first entry is all zeros.
var skip [Sym32Size]byte
symtab.Read(skip[0:])
symbols := make([]Symbol, symtab.Len()/Sym32Size) symbols := make([]Symbol, symtab.Len()/Sym32Size)
i := 0 i := 0
@ -460,10 +456,6 @@ func (f *File) getSymbols64(typ SectionType) ([]Symbol, []byte, error) {
return nil, nil, errors.New("cannot load string table section") return nil, nil, errors.New("cannot load string table section")
} }
// The first entry is all zeros.
var skip [Sym64Size]byte
symtab.Read(skip[0:])
symbols := make([]Symbol, symtab.Len()/Sym64Size) symbols := make([]Symbol, symtab.Len()/Sym64Size)
i := 0 i := 0
@ -708,8 +700,8 @@ func (f *File) gnuVersionInit(str []byte) {
// gnuVersion adds Library and Version information to sym, // gnuVersion adds Library and Version information to sym,
// which came from offset i of the symbol table. // which came from offset i of the symbol table.
func (f *File) gnuVersion(i int, sym *ImportedSymbol) { func (f *File) gnuVersion(i int, sym *ImportedSymbol) {
// Each entry is two bytes; skip undef entry at beginning. // Each entry is two bytes.
i = (i + 1) * 2 i = i * 2
if i >= len(f.gnuVersym) { if i >= len(f.gnuVersym) {
return return
} }

View File

@ -175,23 +175,41 @@ func TestOpen(t *testing.T) {
} }
} }
type relocationTestEntry struct {
entryNumber int
entry *dwarf.Entry
}
type relocationTest struct { type relocationTest struct {
file string file string
firstEntry *dwarf.Entry entries []relocationTestEntry
} }
var relocationTests = []relocationTest{ var relocationTests = []relocationTest{
{ {
"testdata/go-relocation-test-gcc441-x86-64.obj", "testdata/go-relocation-test-gcc441-x86-64.obj",
&dwarf.Entry{Offset: 0xb, Tag: dwarf.TagCompileUnit, Children: true, Field: []dwarf.Field{{Attr: dwarf.AttrProducer, Val: "GNU C 4.4.1"}, {Attr: dwarf.AttrLanguage, Val: int64(1)}, {Attr: dwarf.AttrName, Val: "go-relocation-test.c"}, {Attr: dwarf.AttrCompDir, Val: "/tmp"}, {Attr: dwarf.AttrLowpc, Val: uint64(0x0)}, {Attr: dwarf.AttrHighpc, Val: uint64(0x6)}, {Attr: dwarf.AttrStmtList, Val: int64(0)}}}, []relocationTestEntry{
{0, &dwarf.Entry{Offset: 0xb, Tag: dwarf.TagCompileUnit, Children: true, Field: []dwarf.Field{{Attr: dwarf.AttrProducer, Val: "GNU C 4.4.1"}, {Attr: dwarf.AttrLanguage, Val: int64(1)}, {Attr: dwarf.AttrName, Val: "go-relocation-test.c"}, {Attr: dwarf.AttrCompDir, Val: "/tmp"}, {Attr: dwarf.AttrLowpc, Val: uint64(0x0)}, {Attr: dwarf.AttrHighpc, Val: uint64(0x6)}, {Attr: dwarf.AttrStmtList, Val: int64(0)}}}},
},
}, },
{ {
"testdata/go-relocation-test-gcc441-x86.obj", "testdata/go-relocation-test-gcc441-x86.obj",
&dwarf.Entry{Offset: 0xb, Tag: dwarf.TagCompileUnit, Children: true, Field: []dwarf.Field{{Attr: dwarf.AttrProducer, Val: "GNU C 4.4.1"}, {Attr: dwarf.AttrLanguage, Val: int64(1)}, {Attr: dwarf.AttrName, Val: "t.c"}, {Attr: dwarf.AttrCompDir, Val: "/tmp"}, {Attr: dwarf.AttrLowpc, Val: uint64(0x0)}, {Attr: dwarf.AttrHighpc, Val: uint64(0x5)}, {Attr: dwarf.AttrStmtList, Val: int64(0)}}}, []relocationTestEntry{
{0, &dwarf.Entry{Offset: 0xb, Tag: dwarf.TagCompileUnit, Children: true, Field: []dwarf.Field{{Attr: dwarf.AttrProducer, Val: "GNU C 4.4.1"}, {Attr: dwarf.AttrLanguage, Val: int64(1)}, {Attr: dwarf.AttrName, Val: "t.c"}, {Attr: dwarf.AttrCompDir, Val: "/tmp"}, {Attr: dwarf.AttrLowpc, Val: uint64(0x0)}, {Attr: dwarf.AttrHighpc, Val: uint64(0x5)}, {Attr: dwarf.AttrStmtList, Val: int64(0)}}}},
},
}, },
{ {
"testdata/go-relocation-test-gcc424-x86-64.obj", "testdata/go-relocation-test-gcc424-x86-64.obj",
&dwarf.Entry{Offset: 0xb, Tag: dwarf.TagCompileUnit, Children: true, Field: []dwarf.Field{{Attr: dwarf.AttrProducer, Val: "GNU C 4.2.4 (Ubuntu 4.2.4-1ubuntu4)"}, {Attr: dwarf.AttrLanguage, Val: int64(1)}, {Attr: dwarf.AttrName, Val: "go-relocation-test-gcc424.c"}, {Attr: dwarf.AttrCompDir, Val: "/tmp"}, {Attr: dwarf.AttrLowpc, Val: uint64(0x0)}, {Attr: dwarf.AttrHighpc, Val: uint64(0x6)}, {Attr: dwarf.AttrStmtList, Val: int64(0)}}}, []relocationTestEntry{
{0, &dwarf.Entry{Offset: 0xb, Tag: dwarf.TagCompileUnit, Children: true, Field: []dwarf.Field{{Attr: dwarf.AttrProducer, Val: "GNU C 4.2.4 (Ubuntu 4.2.4-1ubuntu4)"}, {Attr: dwarf.AttrLanguage, Val: int64(1)}, {Attr: dwarf.AttrName, Val: "go-relocation-test-gcc424.c"}, {Attr: dwarf.AttrCompDir, Val: "/tmp"}, {Attr: dwarf.AttrLowpc, Val: uint64(0x0)}, {Attr: dwarf.AttrHighpc, Val: uint64(0x6)}, {Attr: dwarf.AttrStmtList, Val: int64(0)}}}},
},
},
{
"testdata/gcc-amd64-openbsd-debug-with-rela.obj",
[]relocationTestEntry{
{203, &dwarf.Entry{Offset: 0xc62, Tag: dwarf.TagMember, Children: false, Field: []dwarf.Field{{Attr: dwarf.AttrName, Val: "it_interval"}, {Attr: dwarf.AttrDeclFile, Val: int64(7)}, {Attr: dwarf.AttrDeclLine, Val: int64(236)}, {Attr: dwarf.AttrType, Val: dwarf.Offset(0xb7f)}, {Attr: dwarf.AttrDataMemberLoc, Val: []byte{0x23, 0x0}}}}},
{204, &dwarf.Entry{Offset: 0xc70, Tag: dwarf.TagMember, Children: false, Field: []dwarf.Field{{Attr: dwarf.AttrName, Val: "it_value"}, {Attr: dwarf.AttrDeclFile, Val: int64(7)}, {Attr: dwarf.AttrDeclLine, Val: int64(237)}, {Attr: dwarf.AttrType, Val: dwarf.Offset(0xb7f)}, {Attr: dwarf.AttrDataMemberLoc, Val: []byte{0x23, 0x10}}}}},
},
}, },
} }
@ -207,23 +225,27 @@ func TestDWARFRelocations(t *testing.T) {
t.Error(err) t.Error(err)
continue continue
} }
for _, testEntry := range test.entries {
reader := dwarf.Reader() reader := dwarf.Reader()
// Checking only the first entry is sufficient since it has for j := 0; j < testEntry.entryNumber; j++ {
// many different strings. If the relocation had failed, all entry, err := reader.Next()
// the string offsets would be zero and all the strings would if entry == nil || err != nil {
// end up being the same. t.Errorf("Failed to skip to entry %d: %v", testEntry.entryNumber, err)
firstEntry, err := reader.Next() continue
}
}
entry, err := reader.Next()
if err != nil { if err != nil {
t.Error(err) t.Error(err)
continue continue
} }
if !reflect.DeepEqual(testEntry.entry, entry) {
if !reflect.DeepEqual(test.firstEntry, firstEntry) { t.Errorf("#%d/%d: mismatch: got:%#v want:%#v", i, testEntry.entryNumber, entry, testEntry.entry)
t.Errorf("#%d: mismatch: got:%#v want:%#v", i, firstEntry, test.firstEntry)
continue continue
} }
} }
} }
}
func TestNoSectionOverlaps(t *testing.T) { func TestNoSectionOverlaps(t *testing.T) {
// Ensure 6l outputs sections without overlaps. // Ensure 6l outputs sections without overlaps.

Binary file not shown.

View File

@ -296,8 +296,7 @@ func marshalTwoDigits(out *forkableWriter, v int) (err error) {
} }
func marshalUTCTime(out *forkableWriter, t time.Time) (err error) { func marshalUTCTime(out *forkableWriter, t time.Time) (err error) {
utc := t.UTC() year, month, day := t.Date()
year, month, day := utc.Date()
switch { switch {
case 1950 <= year && year < 2000: case 1950 <= year && year < 2000:
@ -321,7 +320,7 @@ func marshalUTCTime(out *forkableWriter, t time.Time) (err error) {
return return
} }
hour, min, sec := utc.Clock() hour, min, sec := t.Clock()
err = marshalTwoDigits(out, hour) err = marshalTwoDigits(out, hour)
if err != nil { if err != nil {

View File

@ -82,7 +82,7 @@ var marshalTests = []marshalTest{
{explicitTagTest{64}, "3005a503020140"}, {explicitTagTest{64}, "3005a503020140"},
{time.Unix(0, 0).UTC(), "170d3730303130313030303030305a"}, {time.Unix(0, 0).UTC(), "170d3730303130313030303030305a"},
{time.Unix(1258325776, 0).UTC(), "170d3039313131353232353631365a"}, {time.Unix(1258325776, 0).UTC(), "170d3039313131353232353631365a"},
{time.Unix(1258325776, 0).In(PST), "17113039313131353232353631362d30383030"}, {time.Unix(1258325776, 0).In(PST), "17113039313131353134353631362d30383030"},
{BitString{[]byte{0x80}, 1}, "03020780"}, {BitString{[]byte{0x80}, 1}, "03020780"},
{BitString{[]byte{0x81, 0xf0}, 12}, "03030481f0"}, {BitString{[]byte{0x81, 0xf0}, 12}, "03030481f0"},
{ObjectIdentifier([]int{1, 2, 3, 4}), "06032a0304"}, {ObjectIdentifier([]int{1, 2, 3, 4}), "06032a0304"},

View File

@ -125,6 +125,9 @@ func (bigEndian) GoString() string { return "binary.BigEndian" }
// of fixed-size values. // of fixed-size values.
// Bytes read from r are decoded using the specified byte order // Bytes read from r are decoded using the specified byte order
// and written to successive fields of the data. // and written to successive fields of the data.
// When reading into structs, the field data for fields with
// blank (_) field names is skipped; i.e., blank field names
// may be used for padding.
func Read(r io.Reader, order ByteOrder, data interface{}) error { func Read(r io.Reader, order ByteOrder, data interface{}) error {
// Fast path for basic types. // Fast path for basic types.
if n := intDestSize(data); n != 0 { if n := intDestSize(data); n != 0 {
@ -154,7 +157,7 @@ func Read(r io.Reader, order ByteOrder, data interface{}) error {
return nil return nil
} }
// Fallback to reflect-based. // Fallback to reflect-based decoding.
var v reflect.Value var v reflect.Value
switch d := reflect.ValueOf(data); d.Kind() { switch d := reflect.ValueOf(data); d.Kind() {
case reflect.Ptr: case reflect.Ptr:
@ -181,6 +184,8 @@ func Read(r io.Reader, order ByteOrder, data interface{}) error {
// values, or a pointer to such data. // values, or a pointer to such data.
// Bytes written to w are encoded using the specified byte order // Bytes written to w are encoded using the specified byte order
// and read from successive fields of the data. // and read from successive fields of the data.
// When writing structs, zero values are are written for fields
// with blank (_) field names.
func Write(w io.Writer, order ByteOrder, data interface{}) error { func Write(w io.Writer, order ByteOrder, data interface{}) error {
// Fast path for basic types. // Fast path for basic types.
var b [8]byte var b [8]byte
@ -239,6 +244,8 @@ func Write(w io.Writer, order ByteOrder, data interface{}) error {
_, err := w.Write(bs) _, err := w.Write(bs)
return err return err
} }
// Fallback to reflect-based encoding.
v := reflect.Indirect(reflect.ValueOf(data)) v := reflect.Indirect(reflect.ValueOf(data))
size := dataSize(v) size := dataSize(v)
if size < 0 { if size < 0 {
@ -300,15 +307,13 @@ func sizeof(t reflect.Type) int {
return -1 return -1
} }
type decoder struct { type coder struct {
order ByteOrder order ByteOrder
buf []byte buf []byte
} }
type encoder struct { type decoder coder
order ByteOrder type encoder coder
buf []byte
}
func (d *decoder) uint8() uint8 { func (d *decoder) uint8() uint8 {
x := d.buf[0] x := d.buf[0]
@ -379,9 +384,19 @@ func (d *decoder) value(v reflect.Value) {
} }
case reflect.Struct: case reflect.Struct:
t := v.Type()
l := v.NumField() l := v.NumField()
for i := 0; i < l; i++ { for i := 0; i < l; i++ {
d.value(v.Field(i)) // Note: Calling v.CanSet() below is an optimization.
// It would be sufficient to check the field name,
// but creating the StructField info for each field is
// costly (run "go test -bench=ReadStruct" and compare
// results when making changes to this code).
if v := v.Field(i); v.CanSet() || t.Field(i).Name != "_" {
d.value(v)
} else {
d.skip(v)
}
} }
case reflect.Slice: case reflect.Slice:
@ -435,9 +450,15 @@ func (e *encoder) value(v reflect.Value) {
} }
case reflect.Struct: case reflect.Struct:
t := v.Type()
l := v.NumField() l := v.NumField()
for i := 0; i < l; i++ { for i := 0; i < l; i++ {
e.value(v.Field(i)) // see comment for corresponding code in decoder.value()
if v := v.Field(i); v.CanSet() || t.Field(i).Name != "_" {
e.value(v)
} else {
e.skip(v)
}
} }
case reflect.Slice: case reflect.Slice:
@ -492,6 +513,18 @@ func (e *encoder) value(v reflect.Value) {
} }
} }
func (d *decoder) skip(v reflect.Value) {
d.buf = d.buf[dataSize(v):]
}
func (e *encoder) skip(v reflect.Value) {
n := dataSize(v)
for i := range e.buf[0:n] {
e.buf[i] = 0
}
e.buf = e.buf[n:]
}
// intDestSize returns the size of the integer that ptrType points to, // intDestSize returns the size of the integer that ptrType points to,
// or 0 if the type is not supported. // or 0 if the type is not supported.
func intDestSize(ptrType interface{}) int { func intDestSize(ptrType interface{}) int {

View File

@ -121,18 +121,14 @@ func testWrite(t *testing.T, order ByteOrder, b []byte, s1 interface{}) {
checkResult(t, "Write", order, err, buf.Bytes(), b) checkResult(t, "Write", order, err, buf.Bytes(), b)
} }
func TestBigEndianRead(t *testing.T) { testRead(t, BigEndian, big, s) }
func TestLittleEndianRead(t *testing.T) { testRead(t, LittleEndian, little, s) } func TestLittleEndianRead(t *testing.T) { testRead(t, LittleEndian, little, s) }
func TestBigEndianWrite(t *testing.T) { testWrite(t, BigEndian, big, s) }
func TestLittleEndianWrite(t *testing.T) { testWrite(t, LittleEndian, little, s) } func TestLittleEndianWrite(t *testing.T) { testWrite(t, LittleEndian, little, s) }
func TestBigEndianPtrWrite(t *testing.T) { testWrite(t, BigEndian, big, &s) }
func TestLittleEndianPtrWrite(t *testing.T) { testWrite(t, LittleEndian, little, &s) } func TestLittleEndianPtrWrite(t *testing.T) { testWrite(t, LittleEndian, little, &s) }
func TestBigEndianRead(t *testing.T) { testRead(t, BigEndian, big, s) }
func TestBigEndianWrite(t *testing.T) { testWrite(t, BigEndian, big, s) }
func TestBigEndianPtrWrite(t *testing.T) { testWrite(t, BigEndian, big, &s) }
func TestReadSlice(t *testing.T) { func TestReadSlice(t *testing.T) {
slice := make([]int32, 2) slice := make([]int32, 2)
err := Read(bytes.NewBuffer(src), BigEndian, slice) err := Read(bytes.NewBuffer(src), BigEndian, slice)
@ -148,20 +144,75 @@ func TestWriteSlice(t *testing.T) {
func TestWriteT(t *testing.T) { func TestWriteT(t *testing.T) {
buf := new(bytes.Buffer) buf := new(bytes.Buffer)
ts := T{} ts := T{}
err := Write(buf, BigEndian, ts) if err := Write(buf, BigEndian, ts); err == nil {
if err == nil { t.Errorf("WriteT: have err == nil, want non-nil")
t.Errorf("WriteT: have nil, want non-nil")
} }
tv := reflect.Indirect(reflect.ValueOf(ts)) tv := reflect.Indirect(reflect.ValueOf(ts))
for i, n := 0, tv.NumField(); i < n; i++ { for i, n := 0, tv.NumField(); i < n; i++ {
err = Write(buf, BigEndian, tv.Field(i).Interface()) if err := Write(buf, BigEndian, tv.Field(i).Interface()); err == nil {
if err == nil { t.Errorf("WriteT.%v: have err == nil, want non-nil", tv.Field(i).Type())
t.Errorf("WriteT.%v: have nil, want non-nil", tv.Field(i).Type())
} }
} }
} }
type BlankFields struct {
A uint32
_ int32
B float64
_ [4]int16
C byte
_ [7]byte
_ struct {
f [8]float32
}
}
type BlankFieldsProbe struct {
A uint32
P0 int32
B float64
P1 [4]int16
C byte
P2 [7]byte
P3 struct {
F [8]float32
}
}
func TestBlankFields(t *testing.T) {
buf := new(bytes.Buffer)
b1 := BlankFields{A: 1234567890, B: 2.718281828, C: 42}
if err := Write(buf, LittleEndian, &b1); err != nil {
t.Error(err)
}
// zero values must have been written for blank fields
var p BlankFieldsProbe
if err := Read(buf, LittleEndian, &p); err != nil {
t.Error(err)
}
// quick test: only check first value of slices
if p.P0 != 0 || p.P1[0] != 0 || p.P2[0] != 0 || p.P3.F[0] != 0 {
t.Errorf("non-zero values for originally blank fields: %#v", p)
}
// write p and see if we can probe only some fields
if err := Write(buf, LittleEndian, &p); err != nil {
t.Error(err)
}
// read should ignore blank fields in b2
var b2 BlankFields
if err := Read(buf, LittleEndian, &b2); err != nil {
t.Error(err)
}
if b1.A != b2.A || b1.B != b2.B || b1.C != b2.C {
t.Errorf("%#v != %#v", b1, b2)
}
}
type byteSliceReader struct { type byteSliceReader struct {
remain []byte remain []byte
} }

View File

@ -123,7 +123,7 @@ func ReadUvarint(r io.ByteReader) (uint64, error) {
panic("unreachable") panic("unreachable")
} }
// ReadVarint reads an encoded unsigned integer from r and returns it as an int64. // ReadVarint reads an encoded signed integer from r and returns it as an int64.
func ReadVarint(r io.ByteReader) (int64, error) { func ReadVarint(r io.ByteReader) (int64, error) {
ux, err := ReadUvarint(r) // ok to continue in presence of error ux, err := ReadUvarint(r) // ok to continue in presence of error
x := int64(ux >> 1) x := int64(ux >> 1)

View File

@ -67,8 +67,8 @@ func Unmarshal(data []byte, v interface{}) error {
// Unmarshaler is the interface implemented by objects // Unmarshaler is the interface implemented by objects
// that can unmarshal a JSON description of themselves. // that can unmarshal a JSON description of themselves.
// The input can be assumed to be a valid JSON object // The input can be assumed to be a valid encoding of
// encoding. UnmarshalJSON must copy the JSON data // a JSON value. UnmarshalJSON must copy the JSON data
// if it wishes to retain the data after returning. // if it wishes to retain the data after returning.
type Unmarshaler interface { type Unmarshaler interface {
UnmarshalJSON([]byte) error UnmarshalJSON([]byte) error
@ -617,12 +617,10 @@ func (d *decodeState) literalStore(item []byte, v reflect.Value, fromQuoted bool
switch c := item[0]; c { switch c := item[0]; c {
case 'n': // null case 'n': // null
switch v.Kind() { switch v.Kind() {
default:
d.saveError(&UnmarshalTypeError{"null", v.Type()})
case reflect.Interface, reflect.Ptr, reflect.Map, reflect.Slice: case reflect.Interface, reflect.Ptr, reflect.Map, reflect.Slice:
v.Set(reflect.Zero(v.Type())) v.Set(reflect.Zero(v.Type()))
// otherwise, ignore null for primitives/string
} }
case 't', 'f': // true, false case 't', 'f': // true, false
value := c == 't' value := c == 't'
switch v.Kind() { switch v.Kind() {

View File

@ -953,3 +953,50 @@ func TestInterfaceSet(t *testing.T) {
} }
} }
} }
// JSON null values should be ignored for primitives and string values instead of resulting in an error.
// Issue 2540
func TestUnmarshalNulls(t *testing.T) {
jsonData := []byte(`{
"Bool" : null,
"Int" : null,
"Int8" : null,
"Int16" : null,
"Int32" : null,
"Int64" : null,
"Uint" : null,
"Uint8" : null,
"Uint16" : null,
"Uint32" : null,
"Uint64" : null,
"Float32" : null,
"Float64" : null,
"String" : null}`)
nulls := All{
Bool: true,
Int: 2,
Int8: 3,
Int16: 4,
Int32: 5,
Int64: 6,
Uint: 7,
Uint8: 8,
Uint16: 9,
Uint32: 10,
Uint64: 11,
Float32: 12.1,
Float64: 13.1,
String: "14"}
err := Unmarshal(jsonData, &nulls)
if err != nil {
t.Errorf("Unmarshal of null values failed: %v", err)
}
if !nulls.Bool || nulls.Int != 2 || nulls.Int8 != 3 || nulls.Int16 != 4 || nulls.Int32 != 5 || nulls.Int64 != 6 ||
nulls.Uint != 7 || nulls.Uint8 != 8 || nulls.Uint16 != 9 || nulls.Uint32 != 10 || nulls.Uint64 != 11 ||
nulls.Float32 != 12.1 || nulls.Float64 != 13.1 || nulls.String != "14" {
t.Errorf("Unmarshal of null values affected primitives")
}
}

View File

@ -11,6 +11,7 @@ import (
"bytes" "bytes"
"encoding/base64" "encoding/base64"
"io" "io"
"sort"
) )
// A Block represents a PEM encoded structure. // A Block represents a PEM encoded structure.
@ -209,26 +210,46 @@ func (l *lineBreaker) Close() (err error) {
return return
} }
func Encode(out io.Writer, b *Block) (err error) { func writeHeader(out io.Writer, k, v string) error {
_, err = out.Write(pemStart[1:]) _, err := out.Write([]byte(k + ": " + v + "\n"))
if err != nil { return err
return
} }
_, err = out.Write([]byte(b.Type + "-----\n"))
if err != nil { func Encode(out io.Writer, b *Block) error {
return if _, err := out.Write(pemStart[1:]); err != nil {
return err
}
if _, err := out.Write([]byte(b.Type + "-----\n")); err != nil {
return err
} }
if len(b.Headers) > 0 { if len(b.Headers) > 0 {
for k, v := range b.Headers { const procType = "Proc-Type"
_, err = out.Write([]byte(k + ": " + v + "\n")) h := make([]string, 0, len(b.Headers))
if err != nil { hasProcType := false
return for k := range b.Headers {
if k == procType {
hasProcType = true
continue
}
h = append(h, k)
}
// The Proc-Type header must be written first.
// See RFC 1421, section 4.6.1.1
if hasProcType {
if err := writeHeader(out, procType, b.Headers[procType]); err != nil {
return err
} }
} }
_, err = out.Write([]byte{'\n'}) // For consistency of output, write other headers sorted by key.
if err != nil { sort.Strings(h)
return for _, k := range h {
if err := writeHeader(out, k, b.Headers[k]); err != nil {
return err
}
}
if _, err := out.Write([]byte{'\n'}); err != nil {
return err
} }
} }
@ -236,19 +257,17 @@ func Encode(out io.Writer, b *Block) (err error) {
breaker.out = out breaker.out = out
b64 := base64.NewEncoder(base64.StdEncoding, &breaker) b64 := base64.NewEncoder(base64.StdEncoding, &breaker)
_, err = b64.Write(b.Bytes) if _, err := b64.Write(b.Bytes); err != nil {
if err != nil { return err
return
} }
b64.Close() b64.Close()
breaker.Close() breaker.Close()
_, err = out.Write(pemEnd[1:]) if _, err := out.Write(pemEnd[1:]); err != nil {
if err != nil { return err
return
} }
_, err = out.Write([]byte(b.Type + "-----\n")) _, err := out.Write([]byte(b.Type + "-----\n"))
return return err
} }
func EncodeToMemory(b *Block) []byte { func EncodeToMemory(b *Block) []byte {

View File

@ -43,7 +43,7 @@ func TestDecode(t *testing.T) {
if !reflect.DeepEqual(result, privateKey) { if !reflect.DeepEqual(result, privateKey) {
t.Errorf("#1 got:%#v want:%#v", result, privateKey) t.Errorf("#1 got:%#v want:%#v", result, privateKey)
} }
result, _ = Decode([]byte(pemPrivateKey)) result, _ = Decode([]byte(pemPrivateKey2))
if !reflect.DeepEqual(result, privateKey2) { if !reflect.DeepEqual(result, privateKey2) {
t.Errorf("#2 got:%#v want:%#v", result, privateKey2) t.Errorf("#2 got:%#v want:%#v", result, privateKey2)
} }
@ -51,8 +51,8 @@ func TestDecode(t *testing.T) {
func TestEncode(t *testing.T) { func TestEncode(t *testing.T) {
r := EncodeToMemory(privateKey2) r := EncodeToMemory(privateKey2)
if string(r) != pemPrivateKey { if string(r) != pemPrivateKey2 {
t.Errorf("got:%s want:%s", r, pemPrivateKey) t.Errorf("got:%s want:%s", r, pemPrivateKey2)
} }
} }
@ -341,50 +341,64 @@ var privateKey = &Block{Type: "RSA PRIVATE KEY",
}, },
} }
var privateKey2 = &Block{Type: "RSA PRIVATE KEY", var privateKey2 = &Block{
Headers: map[string]string{}, Type: "RSA PRIVATE KEY",
Bytes: []uint8{0x30, 0x82, 0x1, 0x3a, 0x2, 0x1, 0x0, 0x2, Headers: map[string]string{
0x41, 0x0, 0xb2, 0x99, 0xf, 0x49, 0xc4, 0x7d, 0xfa, 0x8c, "Proc-Type": "4,ENCRYPTED",
0xd4, 0x0, 0xae, 0x6a, 0x4d, 0x1b, 0x8a, 0x3b, 0x6a, 0x13, "DEK-Info": "AES-128-CBC,BFCD243FEDBB40A4AA6DDAA1335473A4",
0x64, 0x2b, 0x23, 0xf2, 0x8b, 0x0, 0x3b, 0xfb, 0x97, 0x79, "Content-Domain": "RFC822",
0xa, 0xde, 0x9a, 0x4c, 0xc8, 0x2b, 0x8b, 0x2a, 0x81, 0x74, },
0x7d, 0xde, 0xc0, 0x8b, 0x62, 0x96, 0xe5, 0x3a, 0x8, 0xc3, Bytes: []uint8{
0x31, 0x68, 0x7e, 0xf2, 0x5c, 0x4b, 0xf4, 0x93, 0x6b, 0xa1, 0xa8, 0x35, 0xcc, 0x2b, 0xb9, 0xcb, 0x21, 0xab, 0xc0,
0xc0, 0xe6, 0x4, 0x1e, 0x9d, 0x15, 0x2, 0x3, 0x1, 0x0, 0x1, 0x9d, 0x76, 0x61, 0x0, 0xf4, 0x81, 0xad, 0x69, 0xd2,
0x2, 0x41, 0x0, 0x8a, 0xbd, 0x6a, 0x69, 0xf4, 0xd1, 0xa4, 0xc0, 0x42, 0x41, 0x3b, 0xe4, 0x3c, 0xaf, 0x59, 0x5e,
0xb4, 0x87, 0xf0, 0xab, 0x8d, 0x7a, 0xae, 0xfd, 0x38, 0x60, 0x6d, 0x2a, 0x3c, 0x9c, 0xa1, 0xa4, 0x5e, 0x68, 0x37,
0x94, 0x5, 0xc9, 0x99, 0x98, 0x4e, 0x30, 0xf5, 0x67, 0xe1, 0xc4, 0x8c, 0x70, 0x1c, 0xa9, 0x18, 0xe6, 0xc2, 0x2b,
0xe8, 0xae, 0xef, 0xf4, 0x4e, 0x8b, 0x18, 0xbd, 0xb1, 0xec, 0x8a, 0x91, 0xdc, 0x2d, 0x1f, 0x8, 0x23, 0x39, 0xf1,
0x78, 0xdf, 0xa3, 0x1a, 0x55, 0xe3, 0x2a, 0x48, 0xd7, 0xfb, 0x4b, 0x8b, 0x1b, 0x2f, 0x46, 0xb, 0xb2, 0x26, 0xba,
0x13, 0x1f, 0x5a, 0xf1, 0xf4, 0x4d, 0x7d, 0x6b, 0x2c, 0xed, 0x4f, 0x40, 0x80, 0x39, 0xc4, 0xb1, 0xcb, 0x3b, 0xb4,
0x2a, 0x9d, 0xf5, 0xe5, 0xae, 0x45, 0x35, 0x2, 0x21, 0x0, 0x65, 0x3f, 0x1b, 0xb2, 0xf7, 0x8, 0xd2, 0xc6, 0xd5,
0xda, 0xb2, 0xf1, 0x80, 0x48, 0xba, 0xa6, 0x8d, 0xe7, 0xdf, 0xa8, 0x9f, 0x23, 0x69, 0xb6, 0x3d, 0xf9, 0xac, 0x1c,
0x4, 0xd2, 0xd3, 0x5d, 0x5d, 0x80, 0xe6, 0xe, 0x2d, 0xfa, 0xb3, 0x13, 0x87, 0x64, 0x4, 0x37, 0xdb, 0x40, 0xc8,
0x42, 0xd5, 0xa, 0x9b, 0x4, 0x21, 0x90, 0x32, 0x71, 0x5e, 0x82, 0xc, 0xd0, 0xf8, 0x21, 0x7c, 0xdc, 0xbd, 0x9, 0x4,
0x46, 0xb3, 0x2, 0x21, 0x0, 0xd1, 0xf, 0x2e, 0x66, 0xb1, 0x20, 0x16, 0xb0, 0x97, 0xe2, 0x6d, 0x56, 0x1d, 0xe3,
0xd0, 0xc1, 0x3f, 0x10, 0xef, 0x99, 0x27, 0xbf, 0x53, 0x24, 0xec, 0xf0, 0xfc, 0xe2, 0x56, 0xad, 0xa4, 0x3, 0x70,
0xa3, 0x79, 0xca, 0x21, 0x81, 0x46, 0xcb, 0xf9, 0xca, 0xfc, 0x6d, 0x63, 0x3c, 0x1, 0xbe, 0x3e, 0x28, 0x38, 0x6f,
0x79, 0x52, 0x21, 0xf1, 0x6a, 0x31, 0x17, 0x2, 0x20, 0x21, 0xc0, 0xe6, 0xfd, 0x85, 0xd1, 0x53, 0xa8, 0x9b, 0xcb,
0x2, 0x89, 0x79, 0x37, 0x81, 0x14, 0xca, 0xae, 0x88, 0xf7, 0xd4, 0x4, 0xb1, 0x73, 0xb9, 0x73, 0x32, 0xd6, 0x7a,
0xd, 0x6b, 0x61, 0xd8, 0x4f, 0x30, 0x6a, 0x4b, 0x7e, 0x4e, 0xc6, 0x29, 0x25, 0xa5, 0xda, 0x17, 0x93, 0x7a, 0x10,
0xc0, 0x21, 0x4d, 0xac, 0x9d, 0xf4, 0x49, 0xe8, 0xda, 0xb6, 0xe8, 0x41, 0xfb, 0xa5, 0x17, 0x20, 0xf8, 0x4e, 0xe9,
0x9, 0x2, 0x20, 0x16, 0xb3, 0xec, 0x59, 0x10, 0xa4, 0x57, 0xe3, 0x8f, 0x51, 0x20, 0x13, 0xbb, 0xde, 0xb7, 0x93,
0xe8, 0xe, 0x61, 0xc6, 0xa3, 0xf, 0x5e, 0xeb, 0x12, 0xa9, 0xae, 0x13, 0x8a, 0xf6, 0x9, 0xf4, 0xa6, 0x41, 0xe0,
0xae, 0x2e, 0xb7, 0x48, 0x45, 0xec, 0x69, 0x83, 0xc3, 0x75, 0x2b, 0x51, 0x1a, 0x30, 0x38, 0xd, 0xb1, 0x3b, 0x67,
0xc, 0xe4, 0x97, 0xa0, 0x9f, 0x2, 0x20, 0x69, 0x52, 0xb4, 0x87, 0x64, 0xf5, 0xca, 0x32, 0x67, 0xd1, 0xc8, 0xa5,
0x6, 0xe8, 0x50, 0x60, 0x71, 0x4c, 0x3a, 0xb7, 0x66, 0xba, 0x3d, 0x23, 0x72, 0xc4, 0x6, 0xaf, 0x8f, 0x7b, 0x26,
0xd, 0x8a, 0xc9, 0xb7, 0xd, 0xa3, 0x8, 0x6c, 0xa3, 0xf2, 0xac, 0x3c, 0x75, 0x91, 0xa1, 0x0, 0x13, 0xc6, 0x5c,
0x62, 0xb0, 0x2a, 0x84, 0xaa, 0x2f, 0xd6, 0x1e, 0x55, 0x49, 0xd5, 0x3c, 0xe7, 0xb2, 0xb2, 0x99, 0xe0, 0xd5,
0x25, 0xfa, 0xe2, 0x12, 0x80, 0x37, 0x85, 0xcf, 0x92,
0xca, 0x1b, 0x9f, 0xf3, 0x4e, 0xd8, 0x80, 0xef, 0x3c,
0xce, 0xcd, 0xf5, 0x90, 0x9e, 0xf9, 0xa7, 0xb2, 0xc,
0x49, 0x4, 0xf1, 0x9, 0x8f, 0xea, 0x63, 0xd2, 0x70,
0xbb, 0x86, 0xbf, 0x34, 0xab, 0xb2, 0x3, 0xb1, 0x59,
0x33, 0x16, 0x17, 0xb0, 0xdb, 0x77, 0x38, 0xf4, 0xb4,
0x94, 0xb, 0x25, 0x16, 0x7e, 0x22, 0xd4, 0xf9, 0x22,
0xb9, 0x78, 0xa3, 0x4, 0x84, 0x4, 0xd2, 0xda, 0x84,
0x2d, 0x63, 0xdd, 0xf8, 0x50, 0x6a, 0xf6, 0xe3, 0xf5,
0x65, 0x40, 0x7c, 0xa9,
}, },
} }
var pemPrivateKey = `-----BEGIN RSA PRIVATE KEY----- var pemPrivateKey2 = `-----BEGIN RSA PRIVATE KEY-----
MIIBOgIBAAJBALKZD0nEffqM1ACuak0bijtqE2QrI/KLADv7l3kK3ppMyCuLKoF0 Proc-Type: 4,ENCRYPTED
fd7Ai2KW5ToIwzFofvJcS/STa6HA5gQenRUCAwEAAQJBAIq9amn00aS0h/CrjXqu Content-Domain: RFC822
/ThglAXJmZhOMPVn4eiu7/ROixi9sex436MaVeMqSNf7Ex9a8fRNfWss7Sqd9eWu DEK-Info: AES-128-CBC,BFCD243FEDBB40A4AA6DDAA1335473A4
RTUCIQDasvGASLqmjeffBNLTXV2A5g4t+kLVCpsEIZAycV5GswIhANEPLmax0ME/
EO+ZJ79TJKN5yiGBRsv5yvx5UiHxajEXAiAhAol5N4EUyq6I9w1rYdhPMGpLfk7A qDXMK7nLIavAnXZhAPSBrWnSwEJBO+Q8r1lebSo8nKGkXmg3xIxwHKkY5sIripHc
IU2snfRJ6Nq2CQIgFrPsWRCkV+gOYcajD17rEqmuLrdIRexpg8N1DOSXoJ8CIGlS LR8IIznxS4sbL0YLsia6T0CAOcSxyzu0ZT8bsvcI0sbVqJ8jabY9+awcsxOHZAQ3
tAboUGBxTDq3ZroNism3DaMIbKPyYrAqhKov1h5V 20DIggzQ+CF83L0JBCAWsJfibVYd4+zw/OJWraQDcG1jPAG+Pig4b8Dm/YXRU6ib
y9QEsXO5czLWesYpJaXaF5N6EOhB+6UXIPhO6eOPUSATu963k64TivYJ9KZB4CtR
GjA4DbE7Z4dk9coyZ9HIpT0jcsQGr497Jqw8dZGhABPGXEnVPOeyspng1SX64hKA
N4XPksobn/NO2IDvPM7N9ZCe+aeyDEkE8QmP6mPScLuGvzSrsgOxWTMWF7Dbdzj0
tJQLJRZ+ItT5Irl4owSEBNLahC1j3fhQavbj9WVAfKk=
-----END RSA PRIVATE KEY----- -----END RSA PRIVATE KEY-----
` `

View File

@ -687,6 +687,27 @@ var marshalTests = []struct {
Value: &IgnoreTest{}, Value: &IgnoreTest{},
UnmarshalOnly: true, UnmarshalOnly: true,
}, },
// Test escaping.
{
ExpectXML: `<a><nested><value>dquote: &#34;; squote: &#39;; ampersand: &amp;; less: &lt;; greater: &gt;;</value></nested></a>`,
Value: &AnyTest{
Nested: `dquote: "; squote: '; ampersand: &; less: <; greater: >;`,
},
},
{
ExpectXML: `<a><nested><value>newline: &#xA;; cr: &#xD;; tab: &#x9;;</value></nested></a>`,
Value: &AnyTest{
Nested: "newline: \n; cr: \r; tab: \t;",
},
},
{
ExpectXML: "<a><nested><value>1\r2\r\n3\n\r4\n5</value></nested></a>",
Value: &AnyTest{
Nested: "1\n2\n3\n\n4\n5",
},
UnmarshalOnly: true,
},
} }
func TestMarshal(t *testing.T) { func TestMarshal(t *testing.T) {

View File

@ -181,7 +181,6 @@ type Decoder struct {
ns map[string]string ns map[string]string
err error err error
line int line int
tmp [32]byte
} }
// NewDecoder creates a new XML parser reading from r. // NewDecoder creates a new XML parser reading from r.
@ -877,71 +876,66 @@ Input:
// XML in all its glory allows a document to define and use // XML in all its glory allows a document to define and use
// its own character names with <!ENTITY ...> directives. // its own character names with <!ENTITY ...> directives.
// Parsers are required to recognize lt, gt, amp, apos, and quot // Parsers are required to recognize lt, gt, amp, apos, and quot
// even if they have not been declared. That's all we allow. // even if they have not been declared.
var i int before := d.buf.Len()
var semicolon bool
var valid bool
for i = 0; i < len(d.tmp); i++ {
var ok bool
d.tmp[i], ok = d.getc()
if !ok {
if d.err == io.EOF {
d.err = d.syntaxError("unexpected EOF")
}
return nil
}
c := d.tmp[i]
if c == ';' {
semicolon = true
valid = i > 0
break
}
if 'a' <= c && c <= 'z' ||
'A' <= c && c <= 'Z' ||
'0' <= c && c <= '9' ||
c == '_' || c == '#' {
continue
}
d.ungetc(c)
break
}
s := string(d.tmp[0:i])
if !valid {
if !d.Strict {
b0, b1 = 0, 0
d.buf.WriteByte('&') d.buf.WriteByte('&')
d.buf.Write(d.tmp[0:i]) var ok bool
if semicolon { var text string
d.buf.WriteByte(';') var haveText bool
} if b, ok = d.mustgetc(); !ok {
continue Input
}
semi := ";"
if !semicolon {
semi = " (no semicolon)"
}
if i < len(d.tmp) {
d.err = d.syntaxError("invalid character entity &" + s + semi)
} else {
d.err = d.syntaxError("invalid character entity &" + s + "... too long")
}
return nil return nil
} }
var haveText bool if b == '#' {
var text string d.buf.WriteByte(b)
if i >= 2 && s[0] == '#' { if b, ok = d.mustgetc(); !ok {
var n uint64 return nil
var err error
if i >= 3 && s[1] == 'x' {
n, err = strconv.ParseUint(s[2:], 16, 64)
} else {
n, err = strconv.ParseUint(s[1:], 10, 64)
} }
base := 10
if b == 'x' {
base = 16
d.buf.WriteByte(b)
if b, ok = d.mustgetc(); !ok {
return nil
}
}
start := d.buf.Len()
for '0' <= b && b <= '9' ||
base == 16 && 'a' <= b && b <= 'f' ||
base == 16 && 'A' <= b && b <= 'F' {
d.buf.WriteByte(b)
if b, ok = d.mustgetc(); !ok {
return nil
}
}
if b != ';' {
d.ungetc(b)
} else {
s := string(d.buf.Bytes()[start:])
d.buf.WriteByte(';')
n, err := strconv.ParseUint(s, base, 64)
if err == nil && n <= unicode.MaxRune { if err == nil && n <= unicode.MaxRune {
text = string(n) text = string(n)
haveText = true haveText = true
} }
}
} else { } else {
d.ungetc(b)
if !d.readName() {
if d.err != nil {
return nil
}
ok = false
}
if b, ok = d.mustgetc(); !ok {
return nil
}
if b != ';' {
d.ungetc(b)
} else {
name := d.buf.Bytes()[before+1:]
d.buf.WriteByte(';')
if isName(name) {
s := string(name)
if r, ok := entity[s]; ok { if r, ok := entity[s]; ok {
text = string(r) text = string(r)
haveText = true haveText = true
@ -949,22 +943,36 @@ Input:
text, haveText = d.Entity[s] text, haveText = d.Entity[s]
} }
} }
if !haveText {
if !d.Strict {
b0, b1 = 0, 0
d.buf.WriteByte('&')
d.buf.Write(d.tmp[0:i])
d.buf.WriteByte(';')
continue Input
} }
d.err = d.syntaxError("invalid character entity &" + s + ";")
return nil
} }
if haveText {
d.buf.Truncate(before)
d.buf.Write([]byte(text)) d.buf.Write([]byte(text))
b0, b1 = 0, 0 b0, b1 = 0, 0
continue Input continue Input
} }
if !d.Strict {
b0, b1 = 0, 0
continue Input
}
ent := string(d.buf.Bytes()[before])
if ent[len(ent)-1] != ';' {
ent += " (no semicolon)"
}
d.err = d.syntaxError("invalid character entity " + ent)
return nil
}
// We must rewrite unescaped \r and \r\n into \n.
if b == '\r' {
d.buf.WriteByte('\n')
} else if b1 == '\r' && b == '\n' {
// Skip \r\n--we already wrote \n.
} else {
d.buf.WriteByte(b) d.buf.WriteByte(b)
}
b0, b1 = b1, b b0, b1 = b1, b
} }
data := d.buf.Bytes() data := d.buf.Bytes()
@ -985,20 +993,7 @@ Input:
} }
} }
// Must rewrite \r and \r\n into \n. return data
w := 0
for r := 0; r < len(data); r++ {
b := data[r]
if b == '\r' {
if r+1 < len(data) && data[r+1] == '\n' {
continue
}
b = '\n'
}
data[w] = b
w++
}
return data[0:w]
} }
// Decide whether the given rune is in the XML Character Range, per // Decide whether the given rune is in the XML Character Range, per
@ -1034,18 +1029,34 @@ func (d *Decoder) nsname() (name Name, ok bool) {
// Do not set d.err if the name is missing (unless unexpected EOF is received): // Do not set d.err if the name is missing (unless unexpected EOF is received):
// let the caller provide better context. // let the caller provide better context.
func (d *Decoder) name() (s string, ok bool) { func (d *Decoder) name() (s string, ok bool) {
d.buf.Reset()
if !d.readName() {
return "", false
}
// Now we check the characters.
s = d.buf.String()
if !isName([]byte(s)) {
d.err = d.syntaxError("invalid XML name: " + s)
return "", false
}
return s, true
}
// Read a name and append its bytes to d.buf.
// The name is delimited by any single-byte character not valid in names.
// All multi-byte characters are accepted; the caller must check their validity.
func (d *Decoder) readName() (ok bool) {
var b byte var b byte
if b, ok = d.mustgetc(); !ok { if b, ok = d.mustgetc(); !ok {
return return
} }
// As a first approximation, we gather the bytes [A-Za-z_:.-\x80-\xFF]*
if b < utf8.RuneSelf && !isNameByte(b) { if b < utf8.RuneSelf && !isNameByte(b) {
d.ungetc(b) d.ungetc(b)
return "", false return false
} }
d.buf.Reset()
d.buf.WriteByte(b) d.buf.WriteByte(b)
for { for {
if b, ok = d.mustgetc(); !ok { if b, ok = d.mustgetc(); !ok {
return return
@ -1056,16 +1067,7 @@ func (d *Decoder) name() (s string, ok bool) {
} }
d.buf.WriteByte(b) d.buf.WriteByte(b)
} }
return true
// Then we check the characters.
s = d.buf.String()
for i, c := range s {
if !unicode.Is(first, c) && (i == 0 || !unicode.Is(second, c)) {
d.err = d.syntaxError("invalid XML name: " + s)
return "", false
}
}
return s, true
} }
func isNameByte(c byte) bool { func isNameByte(c byte) bool {
@ -1075,6 +1077,30 @@ func isNameByte(c byte) bool {
c == '_' || c == ':' || c == '.' || c == '-' c == '_' || c == ':' || c == '.' || c == '-'
} }
func isName(s []byte) bool {
if len(s) == 0 {
return false
}
c, n := utf8.DecodeRune(s)
if c == utf8.RuneError && n == 1 {
return false
}
if !unicode.Is(first, c) {
return false
}
for n < len(s) {
s = s[n:]
c, n = utf8.DecodeRune(s)
if c == utf8.RuneError && n == 1 {
return false
}
if !unicode.Is(first, c) && !unicode.Is(second, c) {
return false
}
}
return true
}
// These tables were generated by cut and paste from Appendix B of // These tables were generated by cut and paste from Appendix B of
// the XML spec at http://www.xml.com/axml/testaxml.htm // the XML spec at http://www.xml.com/axml/testaxml.htm
// and then reformatting. First corresponds to (Letter | '_' | ':') // and then reformatting. First corresponds to (Letter | '_' | ':')
@ -1689,6 +1715,9 @@ var (
esc_amp = []byte("&amp;") esc_amp = []byte("&amp;")
esc_lt = []byte("&lt;") esc_lt = []byte("&lt;")
esc_gt = []byte("&gt;") esc_gt = []byte("&gt;")
esc_tab = []byte("&#x9;")
esc_nl = []byte("&#xA;")
esc_cr = []byte("&#xD;")
) )
// Escape writes to w the properly escaped XML equivalent // Escape writes to w the properly escaped XML equivalent
@ -1708,6 +1737,12 @@ func Escape(w io.Writer, s []byte) {
esc = esc_lt esc = esc_lt
case '>': case '>':
esc = esc_gt esc = esc_gt
case '\t':
esc = esc_tab
case '\n':
esc = esc_nl
case '\r':
esc = esc_cr
default: default:
continue continue
} }

View File

@ -19,6 +19,7 @@ const testInput = `
<body xmlns:foo="ns1" xmlns="ns2" xmlns:tag="ns3" ` + <body xmlns:foo="ns1" xmlns="ns2" xmlns:tag="ns3" ` +
"\r\n\t" + ` > "\r\n\t" + ` >
<hello lang="en">World &lt;&gt;&apos;&quot; &#x767d;&#40300;</hello> <hello lang="en">World &lt;&gt;&apos;&quot; &#x767d;&#40300;</hello>
<query>&; &is-it;</query>
<goodbye /> <goodbye />
<outer foo:attr="value" xmlns:tag="ns4"> <outer foo:attr="value" xmlns:tag="ns4">
<inner/> <inner/>
@ -28,6 +29,8 @@ const testInput = `
</tag:name> </tag:name>
</body><!-- missing final newline -->` </body><!-- missing final newline -->`
var testEntity = map[string]string{"何": "What", "is-it": "is it?"}
var rawTokens = []Token{ var rawTokens = []Token{
CharData("\n"), CharData("\n"),
ProcInst{"xml", []byte(`version="1.0" encoding="UTF-8"`)}, ProcInst{"xml", []byte(`version="1.0" encoding="UTF-8"`)},
@ -41,6 +44,10 @@ var rawTokens = []Token{
CharData("World <>'\" 白鵬翔"), CharData("World <>'\" 白鵬翔"),
EndElement{Name{"", "hello"}}, EndElement{Name{"", "hello"}},
CharData("\n "), CharData("\n "),
StartElement{Name{"", "query"}, []Attr{}},
CharData("What is it?"),
EndElement{Name{"", "query"}},
CharData("\n "),
StartElement{Name{"", "goodbye"}, []Attr{}}, StartElement{Name{"", "goodbye"}, []Attr{}},
EndElement{Name{"", "goodbye"}}, EndElement{Name{"", "goodbye"}},
CharData("\n "), CharData("\n "),
@ -74,6 +81,10 @@ var cookedTokens = []Token{
CharData("World <>'\" 白鵬翔"), CharData("World <>'\" 白鵬翔"),
EndElement{Name{"ns2", "hello"}}, EndElement{Name{"ns2", "hello"}},
CharData("\n "), CharData("\n "),
StartElement{Name{"ns2", "query"}, []Attr{}},
CharData("What is it?"),
EndElement{Name{"ns2", "query"}},
CharData("\n "),
StartElement{Name{"ns2", "goodbye"}, []Attr{}}, StartElement{Name{"ns2", "goodbye"}, []Attr{}},
EndElement{Name{"ns2", "goodbye"}}, EndElement{Name{"ns2", "goodbye"}},
CharData("\n "), CharData("\n "),
@ -156,6 +167,7 @@ var xmlInput = []string{
func TestRawToken(t *testing.T) { func TestRawToken(t *testing.T) {
d := NewDecoder(strings.NewReader(testInput)) d := NewDecoder(strings.NewReader(testInput))
d.Entity = testEntity
testRawToken(t, d, rawTokens) testRawToken(t, d, rawTokens)
} }
@ -164,8 +176,14 @@ const nonStrictInput = `
<tag>&unknown;entity</tag> <tag>&unknown;entity</tag>
<tag>&#123</tag> <tag>&#123</tag>
<tag>&#zzz;</tag> <tag>&#zzz;</tag>
<tag>&なまえ3;</tag>
<tag>&lt-gt;</tag>
<tag>&;</tag>
<tag>&0a;</tag>
` `
var nonStringEntity = map[string]string{"": "oops!", "0a": "oops!"}
var nonStrictTokens = []Token{ var nonStrictTokens = []Token{
CharData("\n"), CharData("\n"),
StartElement{Name{"", "tag"}, []Attr{}}, StartElement{Name{"", "tag"}, []Attr{}},
@ -184,6 +202,22 @@ var nonStrictTokens = []Token{
CharData("&#zzz;"), CharData("&#zzz;"),
EndElement{Name{"", "tag"}}, EndElement{Name{"", "tag"}},
CharData("\n"), CharData("\n"),
StartElement{Name{"", "tag"}, []Attr{}},
CharData("&なまえ3;"),
EndElement{Name{"", "tag"}},
CharData("\n"),
StartElement{Name{"", "tag"}, []Attr{}},
CharData("&lt-gt;"),
EndElement{Name{"", "tag"}},
CharData("\n"),
StartElement{Name{"", "tag"}, []Attr{}},
CharData("&;"),
EndElement{Name{"", "tag"}},
CharData("\n"),
StartElement{Name{"", "tag"}, []Attr{}},
CharData("&0a;"),
EndElement{Name{"", "tag"}},
CharData("\n"),
} }
func TestNonStrictRawToken(t *testing.T) { func TestNonStrictRawToken(t *testing.T) {
@ -317,6 +351,7 @@ func TestNestedDirectives(t *testing.T) {
func TestToken(t *testing.T) { func TestToken(t *testing.T) {
d := NewDecoder(strings.NewReader(testInput)) d := NewDecoder(strings.NewReader(testInput))
d.Entity = testEntity
for i, want := range cookedTokens { for i, want := range cookedTokens {
have, err := d.Token() have, err := d.Token()

View File

@ -176,8 +176,7 @@ func processPackage(fset *token.FileSet, files map[string]*ast.File) {
report(err) report(err)
return return
} }
_, err = types.Check(fset, pkg) if err = types.Check(fset, pkg, nil, nil); err != nil {
if err != nil {
report(err) report(err)
} }
} }

View File

@ -35,10 +35,12 @@ var tests = []struct {
// directories // directories
{filepath.Join(runtime.GOROOT(), "src/pkg/go/ast"), "ast"}, {filepath.Join(runtime.GOROOT(), "src/pkg/go/ast"), "ast"},
{filepath.Join(runtime.GOROOT(), "src/pkg/go/build"), "build"},
{filepath.Join(runtime.GOROOT(), "src/pkg/go/doc"), "doc"}, {filepath.Join(runtime.GOROOT(), "src/pkg/go/doc"), "doc"},
{filepath.Join(runtime.GOROOT(), "src/pkg/go/token"), "scanner"},
{filepath.Join(runtime.GOROOT(), "src/pkg/go/scanner"), "scanner"},
{filepath.Join(runtime.GOROOT(), "src/pkg/go/parser"), "parser"}, {filepath.Join(runtime.GOROOT(), "src/pkg/go/parser"), "parser"},
{filepath.Join(runtime.GOROOT(), "src/pkg/go/printer"), "printer"},
{filepath.Join(runtime.GOROOT(), "src/pkg/go/scanner"), "scanner"},
{filepath.Join(runtime.GOROOT(), "src/pkg/go/token"), "token"},
{filepath.Join(runtime.GOROOT(), "src/pkg/exp/types"), "types"}, {filepath.Join(runtime.GOROOT(), "src/pkg/exp/types"), "types"},
} }

View File

@ -58,7 +58,9 @@ type Tailoring struct {
id string id string
builder *Builder builder *Builder
index *ordering index *ordering
// TODO: implement.
anchor *entry
before bool
} }
// NewBuilder returns a new Builder. // NewBuilder returns a new Builder.
@ -80,6 +82,7 @@ func (b *Builder) Tailoring(locale string) *Tailoring {
builder: b, builder: b,
index: b.root.clone(), index: b.root.clone(),
} }
t.index.id = t.id
b.locale = append(b.locale, t) b.locale = append(b.locale, t)
return t return t
} }
@ -95,7 +98,6 @@ func (b *Builder) Tailoring(locale string) *Tailoring {
// a value for each colelem that is a variable. (See the reference above.) // a value for each colelem that is a variable. (See the reference above.)
func (b *Builder) Add(runes []rune, colelems [][]int, variables []int) error { func (b *Builder) Add(runes []rune, colelems [][]int, variables []int) error {
str := string(runes) str := string(runes)
elems := make([][]int, len(colelems)) elems := make([][]int, len(colelems))
for i, ce := range colelems { for i, ce := range colelems {
elems[i] = append(elems[i], ce...) elems[i] = append(elems[i], ce...)
@ -127,7 +129,7 @@ func (b *Builder) Add(runes []rune, colelems [][]int, variables []int) error {
if ce[0] > b.varTop { if ce[0] > b.varTop {
b.varTop = ce[0] b.varTop = ce[0]
} }
} else if ce[0] > 0 { } else if ce[0] > 1 { // 1 is a special primary value reserved for FFFE
if ce[0] <= b.varTop { if ce[0] <= b.varTop {
return fmt.Errorf("primary value %X of non-variable is smaller than the highest variable %X", ce[0], b.varTop) return fmt.Errorf("primary value %X of non-variable is smaller than the highest variable %X", ce[0], b.varTop)
} }
@ -144,6 +146,21 @@ func (b *Builder) Add(runes []rune, colelems [][]int, variables []int) error {
return nil return nil
} }
func (t *Tailoring) setAnchor(anchor string) error {
anchor = norm.NFD.String(anchor)
a := t.index.find(anchor)
if a == nil {
a = t.index.newEntry(anchor, nil)
a.implicit = true
for _, r := range []rune(anchor) {
e := t.index.find(string(r))
e.lock = true
}
}
t.anchor = a
return nil
}
// SetAnchor sets the point after which elements passed in subsequent calls to // SetAnchor sets the point after which elements passed in subsequent calls to
// Insert will be inserted. It is equivalent to the reset directive in an LDML // Insert will be inserted. It is equivalent to the reset directive in an LDML
// specification. See Insert for an example. // specification. See Insert for an example.
@ -151,14 +168,20 @@ func (b *Builder) Add(runes []rune, colelems [][]int, variables []int) error {
// <first_tertiary_ignorable/>, <last_teriary_ignorable/>, <first_primary_ignorable/>, // <first_tertiary_ignorable/>, <last_teriary_ignorable/>, <first_primary_ignorable/>,
// and <last_non_ignorable/>. // and <last_non_ignorable/>.
func (t *Tailoring) SetAnchor(anchor string) error { func (t *Tailoring) SetAnchor(anchor string) error {
// TODO: implement. if err := t.setAnchor(anchor); err != nil {
return err
}
t.before = false
return nil return nil
} }
// SetAnchorBefore is similar to SetAnchor, except that subsequent calls to // SetAnchorBefore is similar to SetAnchor, except that subsequent calls to
// Insert will insert entries before the anchor. // Insert will insert entries before the anchor.
func (t *Tailoring) SetAnchorBefore(anchor string) error { func (t *Tailoring) SetAnchorBefore(anchor string) error {
// TODO: implement. if err := t.setAnchor(anchor); err != nil {
return err
}
t.before = true
return nil return nil
} }
@ -195,7 +218,112 @@ func (t *Tailoring) SetAnchorBefore(anchor string) error {
// t.SetAnchor("<last_primary_ignorable/>") // t.SetAnchor("<last_primary_ignorable/>")
// t.Insert(collate.Primary, "0", "") // t.Insert(collate.Primary, "0", "")
func (t *Tailoring) Insert(level collate.Level, str, extend string) error { func (t *Tailoring) Insert(level collate.Level, str, extend string) error {
// TODO: implement. if t.anchor == nil {
return fmt.Errorf("%s:Insert: no anchor point set for tailoring of %s", t.id, str)
}
str = norm.NFD.String(str)
e := t.index.find(str)
if e == nil {
e = t.index.newEntry(str, nil)
} else if e.logical != noAnchor {
return fmt.Errorf("%s:Insert: cannot reinsert logical reset position %q", t.id, e.str)
}
if e.lock {
return fmt.Errorf("%s:Insert: cannot reinsert element %q", t.id, e.str)
}
a := t.anchor
// Find the first element after the anchor which differs at a level smaller or
// equal to the given level. Then insert at this position.
// See http://unicode.org/reports/tr35/#Collation_Elements, Section 5.14.5 for details.
e.before = t.before
if t.before {
t.before = false
if a.prev == nil {
a.insertBefore(e)
} else {
for a = a.prev; a.level > level; a = a.prev {
}
a.insertAfter(e)
}
e.level = level
} else {
for ; a.level > level; a = a.next {
}
e.level = a.level
if a != e {
a.insertAfter(e)
a.level = level
} else {
// We don't set a to prev itself. This has the effect of the entry
// getting new collation elements that are an increment of itself.
// This is intentional.
a.prev.level = level
}
}
e.extend = norm.NFD.String(extend)
e.exclude = false
e.elems = nil
t.anchor = e
return nil
}
func (o *ordering) getWeight(e *entry) [][]int {
if len(e.elems) == 0 && e.logical == noAnchor {
if e.implicit {
for _, r := range e.runes {
e.elems = append(e.elems, o.getWeight(o.find(string(r)))...)
}
} else if e.before {
count := [collate.Identity + 1]int{}
a := e
for ; a.elems == nil && !a.implicit; a = a.next {
count[a.level]++
}
e.elems = append([][]int(nil), make([]int, len(a.elems[0])))
copy(e.elems[0], a.elems[0])
for i := collate.Primary; i < collate.Quaternary; i++ {
if count[i] != 0 {
e.elems[0][i] -= count[i]
break
}
}
if e.prev != nil {
o.verifyWeights(e.prev, e, e.prev.level)
}
} else {
prev := e.prev
e.elems = nextWeight(prev.level, o.getWeight(prev))
o.verifyWeights(e, e.next, e.level)
}
}
return e.elems
}
func (o *ordering) addExtension(e *entry) {
if ex := o.find(e.extend); ex != nil {
e.elems = append(e.elems, ex.elems...)
} else {
for _, r := range []rune(e.extend) {
e.elems = append(e.elems, o.find(string(r)).elems...)
}
}
e.extend = ""
}
func (o *ordering) verifyWeights(a, b *entry, level collate.Level) error {
if level == collate.Identity || b == nil || b.elems == nil || a.elems == nil {
return nil
}
for i := collate.Primary; i < level; i++ {
if a.elems[0][i] < b.elems[0][i] {
return nil
}
}
if a.elems[0][level] >= b.elems[0][level] {
err := fmt.Errorf("%s:overflow: collation elements of %q (%X) overflows those of %q (%X) at level %d (%X >= %X)", o.id, a.str, a.runes, b.str, b.runes, level, a.elems, b.elems)
log.Println(err)
// TODO: return the error instead, or better, fix the conflicting entry by making room.
}
return nil return nil
} }
@ -205,7 +333,19 @@ func (b *Builder) error(e error) {
} }
} }
func (b *Builder) errorID(locale string, e error) {
if e != nil {
b.err = fmt.Errorf("%s:%v", locale, e)
}
}
func (b *Builder) buildOrdering(o *ordering) { func (b *Builder) buildOrdering(o *ordering) {
for _, e := range o.ordered {
o.getWeight(e)
}
for _, e := range o.ordered {
o.addExtension(e)
}
o.sort() o.sort()
simplify(o) simplify(o)
b.processExpansions(o) // requires simplify b.processExpansions(o) // requires simplify
@ -215,7 +355,7 @@ func (b *Builder) buildOrdering(o *ordering) {
for e := o.front(); e != nil; e, _ = e.nextIndexed() { for e := o.front(); e != nil; e, _ = e.nextIndexed() {
if !e.skip() { if !e.skip() {
ce, err := e.encode() ce, err := e.encode()
b.error(err) b.errorID(o.id, err)
t.insert(e.runes[0], ce) t.insert(e.runes[0], ce)
} }
} }
@ -252,7 +392,11 @@ func (b *Builder) Build() (*collate.Collator, error) {
if err != nil { if err != nil {
return nil, err return nil, err
} }
return collate.Init(t), nil c := collate.Init(t)
if c == nil {
panic("generated table of incompatible type")
}
return c, nil
} }
// Build builds a Collator for Tailoring t. // Build builds a Collator for Tailoring t.
@ -308,6 +452,10 @@ func reproducibleFromNFKD(e *entry, exp, nfkd [][]int) bool {
if i >= 2 && ce[2] != maxTertiary { if i >= 2 && ce[2] != maxTertiary {
return false return false
} }
if _, err := makeCE(ce); err != nil {
// Simply return false. The error will be caught elsewhere.
return false
}
} }
return true return true
} }
@ -332,12 +480,11 @@ func simplify(o *ordering) {
e.remove() e.remove()
} }
} }
// Tag entries for which the runes NFKD decompose to identical values. // Tag entries for which the runes NFKD decompose to identical values.
for e := o.front(); e != nil; e, _ = e.nextIndexed() { for e := o.front(); e != nil; e, _ = e.nextIndexed() {
s := e.str s := e.str
nfkd := norm.NFKD.String(s) nfkd := norm.NFKD.String(s)
if len(e.runes) > 1 || keep[e.runes[0]] || nfkd == s { if e.decompose || len(e.runes) > 1 || len(e.elems) == 1 || keep[e.runes[0]] || nfkd == s {
continue continue
} }
if reproducibleFromNFKD(e, e.elems, o.genColElems(nfkd)) { if reproducibleFromNFKD(e, e.elems, o.genColElems(nfkd)) {
@ -459,7 +606,7 @@ func (b *Builder) processContractions(o *ordering) {
elems := []uint32{} elems := []uint32{}
for _, e := range es { for _, e := range es {
ce, err := e.encodeBase() ce, err := e.encodeBase()
b.error(err) b.errorID(o.id, err)
elems = append(elems, ce) elems = append(elems, ce)
} }
key = fmt.Sprintf("%v", elems) key = fmt.Sprintf("%v", elems)

View File

@ -26,7 +26,7 @@ const (
// For normal collation elements, we assume that a collation element either has // For normal collation elements, we assume that a collation element either has
// a primary or non-default secondary value, not both. // a primary or non-default secondary value, not both.
// Collation elements with a primary value are of the form // Collation elements with a primary value are of the form
// 010ppppp pppppppp pppppppp ssssssss // 01pppppp pppppppp ppppppp0 ssssssss
// - p* is primary collation value // - p* is primary collation value
// - s* is the secondary collation value // - s* is the secondary collation value
// or // or
@ -67,10 +67,10 @@ func makeCE(weights []int) (uint32, error) {
if weights[1] >= 1<<maxSecondaryCompactBits { if weights[1] >= 1<<maxSecondaryCompactBits {
return 0, fmt.Errorf("makeCE: secondary weight with non-zero primary out of bounds: %x >= %x", weights[1], 1<<maxSecondaryCompactBits) return 0, fmt.Errorf("makeCE: secondary weight with non-zero primary out of bounds: %x >= %x", weights[1], 1<<maxSecondaryCompactBits)
} }
ce = uint32(weights[0]<<maxSecondaryCompactBits + weights[1]) ce = uint32(weights[0]<<(maxSecondaryCompactBits+1) + weights[1])
ce |= isPrimary ce |= isPrimary
} else { } else {
d := weights[1] - defaultSecondary d := weights[1] - defaultSecondary + 4
if d >= 1<<maxSecondaryDiffBits || d < 0 { if d >= 1<<maxSecondaryDiffBits || d < 0 {
return 0, fmt.Errorf("makeCE: secondary weight diff out of bounds: %x < 0 || %x > %x", d, d, 1<<maxSecondaryDiffBits) return 0, fmt.Errorf("makeCE: secondary weight diff out of bounds: %x < 0 || %x > %x", d, d, 1<<maxSecondaryDiffBits)
} }
@ -258,21 +258,31 @@ func convertLargeWeights(elems [][]int) (res [][]int, err error) {
// nextWeight computes the first possible collation weights following elems // nextWeight computes the first possible collation weights following elems
// for the given level. // for the given level.
func nextWeight(level collate.Level, elems [][]int) [][]int { func nextWeight(level collate.Level, elems [][]int) [][]int {
nce := make([][]int, len(elems)) if level == collate.Identity {
copy(nce, elems) next := make([][]int, len(elems))
copy(next, elems)
if level != collate.Identity { return next
nce[0] = make([]int, len(elems[0])) }
copy(nce[0], elems[0]) next := [][]int{make([]int, len(elems[0]))}
nce[0][level]++ copy(next[0], elems[0])
next[0][level]++
if level < collate.Secondary { if level < collate.Secondary {
nce[0][collate.Secondary] = defaultSecondary next[0][collate.Secondary] = defaultSecondary
} }
if level < collate.Tertiary { if level < collate.Tertiary {
nce[0][collate.Tertiary] = defaultTertiary next[0][collate.Tertiary] = defaultTertiary
}
// Filter entries that cannot influence ordering.
for _, ce := range elems[1:] {
skip := true
for i := collate.Primary; i < level; i++ {
skip = skip && ce[i] == 0
}
if !skip {
next = append(next, ce)
} }
} }
return nce return next
} }
func nextVal(elems [][]int, i int, level collate.Level) (index, value int) { func nextVal(elems [][]int, i int, level collate.Level) (index, value int) {

View File

@ -34,10 +34,10 @@ func decompCE(in []int) (ce uint32, err error) {
var ceTests = []ceTest{ var ceTests = []ceTest{
{normalCE, []int{0, 0, 0}, 0x80000000}, {normalCE, []int{0, 0, 0}, 0x80000000},
{normalCE, []int{0, 0x28, 3}, 0x80002803}, {normalCE, []int{0, 0x28, 3}, 0x80002803},
{normalCE, []int{100, defaultSecondary, 3}, 0x0000C803}, {normalCE, []int{100, defaultSecondary, 3}, 0x0000C883},
// non-ignorable primary with non-default secondary // non-ignorable primary with non-default secondary
{normalCE, []int{100, 0x28, defaultTertiary}, 0x40006428}, {normalCE, []int{100, 0x28, defaultTertiary}, 0x4000C828},
{normalCE, []int{100, defaultSecondary + 8, 3}, 0x0000C903}, {normalCE, []int{100, defaultSecondary + 8, 3}, 0x0000C983},
{normalCE, []int{100, 0, 3}, 0xFFFF}, // non-ignorable primary with non-supported secondary {normalCE, []int{100, 0, 3}, 0xFFFF}, // non-ignorable primary with non-supported secondary
{normalCE, []int{100, 1, 3}, 0xFFFF}, {normalCE, []int{100, 1, 3}, 0xFFFF},
{normalCE, []int{1 << maxPrimaryBits, defaultSecondary, 0}, 0xFFFF}, {normalCE, []int{1 << maxPrimaryBits, defaultSecondary, 0}, 0xFFFF},
@ -114,18 +114,24 @@ var nextWeightTests = []weightsTest{
}, },
} }
var extra = []int{200, 32, 8, 0} var extra = [][]int{{200, 32, 8, 0}, {0, 32, 8, 0}, {0, 0, 8, 0}, {0, 0, 0, 0}}
func TestNextWeight(t *testing.T) { func TestNextWeight(t *testing.T) {
for i, tt := range nextWeightTests { for i, tt := range nextWeightTests {
test := func(tt weightsTest, a, gold [][]int) { test := func(l collate.Level, tt weightsTest, a, gold [][]int) {
res := nextWeight(tt.level, a) res := nextWeight(tt.level, a)
if !equalCEArrays(gold, res) { if !equalCEArrays(gold, res) {
t.Errorf("%d: expected weights %d; found %d", i, tt.b, res) t.Errorf("%d:%d: expected weights %d; found %d", i, l, gold, res)
}
}
test(-1, tt, tt.a, tt.b)
for l := collate.Primary; l <= collate.Tertiary; l++ {
if tt.level <= l {
test(l, tt, append(tt.a, extra[l]), tt.b)
} else {
test(l, tt, append(tt.a, extra[l]), append(tt.b, extra[l]))
} }
} }
test(tt, tt.a, tt.b)
test(tt, append(tt.a, extra), append(tt.b, extra))
} }
} }
@ -137,7 +143,7 @@ var compareTests = []weightsTest{
0, 0,
}, },
{ {
[][]int{{100, 20, 5, 0}, extra}, [][]int{{100, 20, 5, 0}, extra[0]},
[][]int{{100, 20, 5, 1}}, [][]int{{100, 20, 5, 1}},
collate.Primary, collate.Primary,
1, 1,
@ -192,6 +198,6 @@ func TestCompareWeights(t *testing.T) {
} }
} }
test(tt, tt.a, tt.b) test(tt, tt.a, tt.b)
test(tt, append(tt.a, extra), append(tt.b, extra)) test(tt, append(tt.a, extra[0]), append(tt.b, extra[0]))
} }
} }

View File

@ -26,16 +26,21 @@ const (
// Collation Element Table. // Collation Element Table.
// See http://www.unicode.org/Public/UCA/6.0.0/allkeys.txt. // See http://www.unicode.org/Public/UCA/6.0.0/allkeys.txt.
type entry struct { type entry struct {
runes []rune
elems [][]int // the collation elements for runes
str string // same as string(runes) str string // same as string(runes)
runes []rune
elems [][]int // the collation elements
extend string // weights of extend to be appended to elems
before bool // weights relative to next instead of previous.
lock bool // entry is used in extension and can no longer be moved.
// prev, next, and level are used to keep track of tailorings. // prev, next, and level are used to keep track of tailorings.
prev, next *entry prev, next *entry
level collate.Level // next differs at this level level collate.Level // next differs at this level
skipRemove bool // do not unlink when removed
decompose bool // can use NFKD decomposition to generate elems decompose bool // can use NFKD decomposition to generate elems
exclude bool // do not include in table exclude bool // do not include in table
implicit bool // derived, is not included in the list
logical logicalAnchor logical logicalAnchor
expansionIndex int // used to store index into expansion table expansionIndex int // used to store index into expansion table
@ -44,8 +49,8 @@ type entry struct {
} }
func (e *entry) String() string { func (e *entry) String() string {
return fmt.Sprintf("%X -> %X (ch:%x; ci:%d, ei:%d)", return fmt.Sprintf("%X (%q) -> %X (ch:%x; ci:%d, ei:%d)",
e.runes, e.elems, e.contractionHandle, e.contractionIndex, e.expansionIndex) e.runes, e.str, e.elems, e.contractionHandle, e.contractionIndex, e.expansionIndex)
} }
func (e *entry) skip() bool { func (e *entry) skip() bool {
@ -71,7 +76,7 @@ func (e *entry) contractionStarter() bool {
// examples of entries that will not be indexed. // examples of entries that will not be indexed.
func (e *entry) nextIndexed() (*entry, collate.Level) { func (e *entry) nextIndexed() (*entry, collate.Level) {
level := e.level level := e.level
for e = e.next; e != nil && e.exclude; e = e.next { for e = e.next; e != nil && (e.exclude || len(e.elems) == 0); e = e.next {
if e.level < level { if e.level < level {
level = e.level level = e.level
} }
@ -87,16 +92,20 @@ func (e *entry) remove() {
if e.logical != noAnchor { if e.logical != noAnchor {
log.Fatalf("may not remove anchor %q", e.str) log.Fatalf("may not remove anchor %q", e.str)
} }
// TODO: need to set e.prev.level to e.level if e.level is smaller?
e.elems = nil
if !e.skipRemove {
if e.prev != nil { if e.prev != nil {
e.prev.next = e.next e.prev.next = e.next
} }
if e.next != nil { if e.next != nil {
e.next.prev = e.prev e.next.prev = e.prev
} }
e.elems = nil }
e.skipRemove = false
} }
// insertAfter inserts t after e. // insertAfter inserts n after e.
func (e *entry) insertAfter(n *entry) { func (e *entry) insertAfter(n *entry) {
if e == n { if e == n {
panic("e == anchor") panic("e == anchor")
@ -109,10 +118,31 @@ func (e *entry) insertAfter(n *entry) {
n.next = e.next n.next = e.next
n.prev = e n.prev = e
if e.next != nil {
e.next.prev = n e.next.prev = n
}
e.next = n e.next = n
} }
// insertBefore inserts n before e.
func (e *entry) insertBefore(n *entry) {
if e == n {
panic("e == anchor")
}
if e == nil {
panic("unexpected nil anchor")
}
n.remove()
n.decompose = false // redo decomposition test
n.prev = e.prev
n.next = e
if e.prev != nil {
e.prev.next = n
}
e.prev = n
}
func (e *entry) encodeBase() (ce uint32, err error) { func (e *entry) encodeBase() (ce uint32, err error) {
switch { switch {
case e.expansion(): case e.expansion():
@ -178,6 +208,7 @@ func (s sortedEntries) Less(i, j int) bool {
} }
type ordering struct { type ordering struct {
id string
entryMap map[string]*entry entryMap map[string]*entry
ordered []*entry ordered []*entry
handle *trieHandle handle *trieHandle
@ -187,7 +218,14 @@ type ordering struct {
// Note that insert simply appends e to ordered. To reattain a sorted // Note that insert simply appends e to ordered. To reattain a sorted
// order, o.sort() should be called. // order, o.sort() should be called.
func (o *ordering) insert(e *entry) { func (o *ordering) insert(e *entry) {
if e.logical == noAnchor {
o.entryMap[e.str] = e o.entryMap[e.str] = e
} else {
// Use key format as used in UCA rules.
o.entryMap[fmt.Sprintf("[%s]", e.str)] = e
// Also add index entry for XML format.
o.entryMap[fmt.Sprintf("<%s/>", strings.Replace(e.str, " ", "_", -1))] = e
}
o.ordered = append(o.ordered, e) o.ordered = append(o.ordered, e)
} }
@ -236,13 +274,13 @@ func makeRootOrdering() ordering {
entryMap: make(map[string]*entry), entryMap: make(map[string]*entry),
} }
insert := func(typ logicalAnchor, s string, ce []int) { insert := func(typ logicalAnchor, s string, ce []int) {
// Use key format as used in UCA rules. e := &entry{
e := o.newEntry(fmt.Sprintf("[%s]", s), [][]int{ce}) elems: [][]int{ce},
// Also add index entry for XML format. str: s,
o.entryMap[fmt.Sprintf("<%s/>", strings.Replace(s, " ", "_", -1))] = e exclude: true,
e.runes = nil logical: typ,
e.exclude = true }
e.logical = typ o.insert(e)
} }
insert(firstAnchor, "first tertiary ignorable", []int{0, 0, 0, 0}) insert(firstAnchor, "first tertiary ignorable", []int{0, 0, 0, 0})
insert(lastAnchor, "last tertiary ignorable", []int{0, 0, 0, max}) insert(lastAnchor, "last tertiary ignorable", []int{0, 0, 0, max})
@ -252,6 +290,29 @@ func makeRootOrdering() ordering {
return o return o
} }
// patchForInsert eleminates entries from the list with more than one collation element.
// The next and prev fields of the eliminated entries still point to appropriate
// values in the newly created list.
// It requires that sort has been called.
func (o *ordering) patchForInsert() {
for i := 0; i < len(o.ordered)-1; {
e := o.ordered[i]
lev := e.level
n := e.next
for ; n != nil && len(n.elems) > 1; n = n.next {
if n.level < lev {
lev = n.level
}
n.skipRemove = true
}
for ; o.ordered[i] != n; i++ {
o.ordered[i].level = lev
o.ordered[i].next = n
o.ordered[i+1].prev = e
}
}
}
// clone copies all ordering of es into a new ordering value. // clone copies all ordering of es into a new ordering value.
func (o *ordering) clone() *ordering { func (o *ordering) clone() *ordering {
o.sort() o.sort()
@ -270,6 +331,7 @@ func (o *ordering) clone() *ordering {
oo.insert(ne) oo.insert(ne)
} }
oo.sort() // link all ordering. oo.sort() // link all ordering.
oo.patchForInsert()
return &oo return &oo
} }

View File

@ -128,6 +128,9 @@ func TestInsertAfter(t *testing.T) {
last.insertAfter(es[i]) last.insertAfter(es[i])
last = es[i] last = es[i]
} }
for _, e := range es {
e.elems = es[0].elems
}
e := es[0] e := es[0]
for _, i := range perm { for _, i := range perm {
e, _ = e.nextIndexed() e, _ = e.nextIndexed()
@ -139,6 +142,34 @@ func TestInsertAfter(t *testing.T) {
} }
} }
func TestInsertBefore(t *testing.T) {
const n = 5
orig := makeList(n)
perm := make([]int, n)
for i := range perm {
perm[i] = i + 1
}
for ok := true; ok; ok = nextPerm(perm) {
es := makeList(n)
last := es[len(es)-1]
for _, i := range perm {
last.insertBefore(es[i])
last = es[i]
}
for _, e := range es {
e.elems = es[0].elems
}
e := es[0]
for i := n - 1; i >= 0; i-- {
e, _ = e.nextIndexed()
if e.runes[0] != rune(perm[i]) {
t.Errorf("%d:%d: expected entry %X; found %X", perm, i, orig[i].runes, e.runes)
break
}
}
}
}
type entryLessTest struct { type entryLessTest struct {
a, b *entry a, b *entry
res bool res bool

View File

@ -38,7 +38,8 @@ type trieNode struct {
index []*trieNode index []*trieNode
value []uint32 value []uint32
b byte b byte
ref uint16 refValue uint16
refIndex uint16
} }
func newNode() *trieNode { func newNode() *trieNode {
@ -108,18 +109,20 @@ func (b *trieBuilder) computeOffsets(n *trieNode) *trieNode {
hasher := fnv.New32() hasher := fnv.New32()
if n.index != nil { if n.index != nil {
for i, nn := range n.index { for i, nn := range n.index {
v := uint16(0) var vi, vv uint16
if nn != nil { if nn != nil {
nn = b.computeOffsets(nn) nn = b.computeOffsets(nn)
n.index[i] = nn n.index[i] = nn
v = nn.ref vi = nn.refIndex
vv = nn.refValue
} }
hasher.Write([]byte{byte(v >> 8), byte(v)}) hasher.Write([]byte{byte(vi >> 8), byte(vi)})
hasher.Write([]byte{byte(vv >> 8), byte(vv)})
} }
h := hasher.Sum32() h := hasher.Sum32()
nn, ok := b.lookupBlockIdx[h] nn, ok := b.lookupBlockIdx[h]
if !ok { if !ok {
n.ref = uint16(len(b.lookupBlocks)) - blockOffset n.refIndex = uint16(len(b.lookupBlocks)) - blockOffset
b.lookupBlocks = append(b.lookupBlocks, n) b.lookupBlocks = append(b.lookupBlocks, n)
b.lookupBlockIdx[h] = n b.lookupBlockIdx[h] = n
} else { } else {
@ -132,7 +135,8 @@ func (b *trieBuilder) computeOffsets(n *trieNode) *trieNode {
h := hasher.Sum32() h := hasher.Sum32()
nn, ok := b.valueBlockIdx[h] nn, ok := b.valueBlockIdx[h]
if !ok { if !ok {
n.ref = uint16(len(b.valueBlocks)) - blockOffset n.refValue = uint16(len(b.valueBlocks)) - blockOffset
n.refIndex = n.refValue
b.valueBlocks = append(b.valueBlocks, n) b.valueBlocks = append(b.valueBlocks, n)
b.valueBlockIdx[h] = n b.valueBlockIdx[h] = n
} else { } else {
@ -150,7 +154,8 @@ func (b *trieBuilder) addStartValueBlock(n *trieNode) uint16 {
h := hasher.Sum32() h := hasher.Sum32()
nn, ok := b.valueBlockIdx[h] nn, ok := b.valueBlockIdx[h]
if !ok { if !ok {
n.ref = uint16(len(b.valueBlocks)) n.refValue = uint16(len(b.valueBlocks))
n.refIndex = n.refValue
b.valueBlocks = append(b.valueBlocks, n) b.valueBlocks = append(b.valueBlocks, n)
// Add a dummy block to accommodate the double block size. // Add a dummy block to accommodate the double block size.
b.valueBlocks = append(b.valueBlocks, nil) b.valueBlocks = append(b.valueBlocks, nil)
@ -158,7 +163,7 @@ func (b *trieBuilder) addStartValueBlock(n *trieNode) uint16 {
} else { } else {
n = nn n = nn
} }
return n.ref return n.refValue
} }
func genValueBlock(t *trie, n *trieNode) { func genValueBlock(t *trie, n *trieNode) {
@ -173,7 +178,11 @@ func genLookupBlock(t *trie, n *trieNode) {
for _, nn := range n.index { for _, nn := range n.index {
v := uint16(0) v := uint16(0)
if nn != nil { if nn != nil {
v = nn.ref if n.index != nil {
v = nn.refIndex
} else {
v = nn.refValue
}
} }
t.index = append(t.index, v) t.index = append(t.index, v)
} }
@ -192,7 +201,7 @@ func (b *trieBuilder) addTrie(n *trieNode) *trieHandle {
} }
n = b.computeOffsets(n) n = b.computeOffsets(n)
// Offset by one extra block as the first byte starts at 0xC0 instead of 0x80. // Offset by one extra block as the first byte starts at 0xC0 instead of 0x80.
h.lookupStart = n.ref - 1 h.lookupStart = n.refIndex - 1
return h return h
} }

View File

@ -8,16 +8,6 @@ import (
"unicode" "unicode"
) )
// weights holds the decoded weights per collation level.
type weights struct {
primary uint32
secondary uint16
tertiary uint8
// TODO: compute quaternary on the fly or compress this value into 8 bits
// such that weights fit within 64bit.
quaternary uint32
}
const ( const (
defaultSecondary = 0x20 defaultSecondary = 0x20
defaultTertiary = 0x2 defaultTertiary = 0x2
@ -69,7 +59,7 @@ func (ce colElem) ctype() ceType {
// For normal collation elements, we assume that a collation element either has // For normal collation elements, we assume that a collation element either has
// a primary or non-default secondary value, not both. // a primary or non-default secondary value, not both.
// Collation elements with a primary value are of the form // Collation elements with a primary value are of the form
// 010ppppp pppppppp pppppppp ssssssss // 01pppppp pppppppp ppppppp0 ssssssss
// - p* is primary collation value // - p* is primary collation value
// - s* is the secondary collation value // - s* is the secondary collation value
// or // or
@ -82,25 +72,87 @@ func (ce colElem) ctype() ceType {
// - 16 BMP implicit -> weight // - 16 BMP implicit -> weight
// - 8 bit s // - 8 bit s
// - default tertiary // - default tertiary
func splitCE(ce colElem) weights { // 11qqqqqq qqqqqqqq qqqqqqq0 00000000
const primaryMask = 0x40000000 // - q* quaternary value
const secondaryMask = 0x80000000 const (
w := weights{} ceTypeMask = 0xC0000000
if ce&primaryMask != 0 { ceType1 = 0x40000000
w.tertiary = defaultTertiary ceType2 = 0x00000000
w.secondary = uint16(uint8(ce)) ceType3 = 0x80000000
w.primary = uint32((ce >> 8) & 0x1FFFFF) ceTypeQ = 0xC0000000
} else if ce&secondaryMask == 0 { ceIgnore = ceType3
w.tertiary = uint8(ce & 0x1F) firstNonPrimary = 0x80000000
ce >>= 5 secondaryMask = 0x80000000
w.secondary = defaultSecondary + uint16(ce&0xF) hasTertiaryMask = 0x40000000
ce >>= 4 primaryValueMask = 0x3FFFFE00
w.primary = uint32(ce) primaryShift = 9
} else { compactSecondaryShift = 5
w.tertiary = uint8(ce) minCompactSecondary = defaultSecondary - 4
w.secondary = uint16(ce >> 8) )
func makeImplicitCE(primary int) colElem {
return ceType1 | colElem(primary<<primaryShift) | defaultSecondary
} }
return w
func makeQuaternary(primary int) colElem {
return ceTypeQ | colElem(primary<<primaryShift)
}
func (ce colElem) primary() int {
if ce >= firstNonPrimary {
return 0
}
return int(ce&primaryValueMask) >> primaryShift
}
func (ce colElem) secondary() int {
switch ce & ceTypeMask {
case ceType1:
return int(uint8(ce))
case ceType2:
return minCompactSecondary + int((ce>>compactSecondaryShift)&0xF)
case ceType3:
return int(uint16(ce >> 8))
case ceTypeQ:
return 0
}
panic("should not reach here")
}
func (ce colElem) tertiary() uint8 {
if ce&hasTertiaryMask == 0 {
if ce&ceType3 == 0 {
return uint8(ce & 0x1F)
}
return uint8(ce)
} else if ce&ceTypeMask == ceType1 {
return defaultTertiary
}
// ce is a quaternary value.
return 0
}
func (ce colElem) updateTertiary(t uint8) colElem {
if ce&ceTypeMask == ceType1 {
nce := ce & primaryValueMask
nce |= colElem(uint8(ce)-minCompactSecondary) << compactSecondaryShift
ce = nce
} else {
ce &= ^colElem(maxTertiary)
}
return ce | colElem(t)
}
// quaternary returns the quaternary value if explicitly specified,
// 0 if ce == ceIgnore, or maxQuaternary otherwise.
// Quaternary values are used only for shifted variants.
func (ce colElem) quaternary() int {
if ce&ceTypeMask == ceTypeQ {
return int(ce&primaryValueMask) >> primaryShift
} else if ce == ceIgnore {
return 0
}
return maxQuaternary
} }
// For contractions, collation elements are of the form // For contractions, collation elements are of the form

View File

@ -29,10 +29,10 @@ func makeCE(weights []int) colElem {
var ce colElem var ce colElem
if weights[0] != 0 { if weights[0] != 0 {
if weights[2] == defaultTertiary { if weights[2] == defaultTertiary {
ce = colElem(weights[0]<<maxSecondaryCompactBits + weights[1]) ce = colElem(weights[0]<<(maxSecondaryCompactBits+1) + weights[1])
ce |= isPrimary ce |= isPrimary
} else { } else {
d := weights[1] - defaultSecondary d := weights[1] - defaultSecondary + 4
ce = colElem(weights[0]<<maxSecondaryDiffBits + d) ce = colElem(weights[0]<<maxSecondaryDiffBits + d)
ce = ce<<maxTertiaryCompactBits + colElem(weights[2]) ce = ce<<maxTertiaryCompactBits + colElem(weights[2])
} }
@ -68,10 +68,10 @@ func makeDecompose(t1, t2 int) colElem {
} }
func normalCE(inout []int) (ce colElem, t ceType) { func normalCE(inout []int) (ce colElem, t ceType) {
w := splitCE(makeCE(inout)) w := makeCE(inout)
inout[0] = int(w.primary) inout[0] = w.primary()
inout[1] = int(w.secondary) inout[1] = w.secondary()
inout[2] = int(w.tertiary) inout[2] = int(w.tertiary())
return ce, ceNormal return ce, ceNormal
} }
@ -167,3 +167,20 @@ func TestImplicit(t *testing.T) {
} }
} }
} }
func TestUpdateTertiary(t *testing.T) {
tests := []struct {
in, out colElem
t uint8
}{
{0x4000FE20, 0x0000FE8A, 0x0A},
{0x4000FE21, 0x0000FEAA, 0x0A},
{0x0000FE8B, 0x0000FE83, 0x03},
{0x8000CC02, 0x8000CC1B, 0x1B},
}
for i, tt := range tests {
if out := tt.in.updateTertiary(tt.t); out != tt.out {
t.Errorf("%d: was %X; want %X", i, out, tt.out)
}
}
}

View File

@ -83,9 +83,17 @@ type Collator struct {
f norm.Form f norm.Form
t *table t *table
_iter [2]iter
}
func (c *Collator) iter(i int) *iter {
// TODO: evaluate performance for making the second iterator optional.
return &c._iter[i]
} }
// Locales returns the list of locales for which collating differs from its parent locale. // Locales returns the list of locales for which collating differs from its parent locale.
// The returned value should not be modified.
func Locales() []string { func Locales() []string {
return availableLocales return availableLocales
} }
@ -99,11 +107,18 @@ func New(loc string) *Collator {
t = mainTable.indexedTable(idx) t = mainTable.indexedTable(idx)
} }
} }
return &Collator{ return newCollator(t)
}
func newCollator(t *table) *Collator {
c := &Collator{
Strength: Quaternary, Strength: Quaternary,
f: norm.NFD, f: norm.NFD,
t: t, t: t,
} }
c._iter[0].init(c)
c._iter[1].init(c)
return c
} }
// SetVariableTop sets all runes with primary strength less than the primary // SetVariableTop sets all runes with primary strength less than the primary
@ -112,63 +127,114 @@ func (c *Collator) SetVariableTop(r rune) {
// TODO: implement // TODO: implement
} }
// Buffer holds reusable buffers that can be used during collation. // Buffer holds keys generated by Key and KeyString.
// Reusing a Buffer for the various calls that accept it may avoid
// unnecessary memory allocations.
type Buffer struct { type Buffer struct {
// TODO: try various parameters and techniques, such as using buf [4096]byte
// a chan of buffers for a pool.
ba [4096]byte
wa [512]weights
key []byte key []byte
ce []weights
} }
func (b *Buffer) init() { func (b *Buffer) init() {
if b.ce == nil { if b.key == nil {
b.ce = b.wa[:0] b.key = b.buf[:0]
b.key = b.ba[:0]
} else {
b.ce = b.ce[:0]
} }
} }
// ResetKeys clears the buffer used for generated keys. Calling ResetKeys // Reset clears the buffer from previous results generated by Key and KeyString.
// invalidates keys previously obtained from Key or KeyFromString. func (b *Buffer) Reset() {
func (b *Buffer) ResetKeys() {
b.ce = b.ce[:0]
b.key = b.key[:0] b.key = b.key[:0]
} }
// Compare returns an integer comparing the two byte slices. // Compare returns an integer comparing the two byte slices.
// The result will be 0 if a==b, -1 if a < b, and +1 if a > b. // The result will be 0 if a==b, -1 if a < b, and +1 if a > b.
// Compare calls ResetKeys, thereby invalidating keys func (c *Collator) Compare(a, b []byte) int {
// previously generated using Key or KeyFromString using buf. // TODO: skip identical prefixes once we have a fast way to detect if a rune is
func (c *Collator) Compare(buf *Buffer, a, b []byte) int { // part of a contraction. This would lead to roughly a 10% speedup for the colcmp regtest.
// TODO: for now we simply compute keys and compare. Once we c.iter(0).setInput(c, a)
// have good benchmarks, move to an implementation that works c.iter(1).setInput(c, b)
// incrementally for the majority of cases. if res := c.compare(); res != 0 {
// - Benchmark with long strings that only vary in modifiers. return res
buf.ResetKeys() }
ka := c.Key(buf, a) if Identity == c.Strength {
kb := c.Key(buf, b) return bytes.Compare(a, b)
defer buf.ResetKeys() }
return bytes.Compare(ka, kb) return 0
} }
// CompareString returns an integer comparing the two strings. // CompareString returns an integer comparing the two strings.
// The result will be 0 if a==b, -1 if a < b, and +1 if a > b. // The result will be 0 if a==b, -1 if a < b, and +1 if a > b.
// CompareString calls ResetKeys, thereby invalidating keys func (c *Collator) CompareString(a, b string) int {
// previously generated using Key or KeyFromString using buf. // TODO: skip identical prefixes once we have a fast way to detect if a rune is
func (c *Collator) CompareString(buf *Buffer, a, b string) int { // part of a contraction. This would lead to roughly a 10% speedup for the colcmp regtest.
buf.ResetKeys() c.iter(0).setInputString(c, a)
ka := c.KeyFromString(buf, a) c.iter(1).setInputString(c, b)
kb := c.KeyFromString(buf, b) if res := c.compare(); res != 0 {
defer buf.ResetKeys() return res
return bytes.Compare(ka, kb) }
if Identity == c.Strength {
if a < b {
return -1
} else if a > b {
return 1
}
}
return 0
} }
func (c *Collator) Prefix(buf *Buffer, s, prefix []byte) int { func compareLevel(f func(i *iter) int, a, b *iter) int {
a.pce = 0
b.pce = 0
for {
va := f(a)
vb := f(b)
if va != vb {
if va < vb {
return -1
}
return 1
} else if va == 0 {
break
}
}
return 0
}
func (c *Collator) compare() int {
ia, ib := c.iter(0), c.iter(1)
// Process primary level
if c.Alternate != AltShifted {
// TODO: implement script reordering
// TODO: special hiragana handling
if res := compareLevel((*iter).nextPrimary, ia, ib); res != 0 {
return res
}
} else {
// TODO: handle shifted
}
if Secondary <= c.Strength {
f := (*iter).nextSecondary
if c.Backwards {
f = (*iter).prevSecondary
}
if res := compareLevel(f, ia, ib); res != 0 {
return res
}
}
// TODO: special case handling (Danish?)
if Tertiary <= c.Strength || c.CaseLevel {
if res := compareLevel((*iter).nextTertiary, ia, ib); res != 0 {
return res
}
// TODO: Not needed for the default value of AltNonIgnorable?
if Quaternary <= c.Strength {
if res := compareLevel((*iter).nextQuaternary, ia, ib); res != 0 {
return res
}
}
}
return 0
}
func (c *Collator) Prefix(s, prefix []byte) int {
// iterate over s, track bytes consumed. // iterate over s, track bytes consumed.
return 0 return 0
} }
@ -176,12 +242,11 @@ func (c *Collator) Prefix(buf *Buffer, s, prefix []byte) int {
// Key returns the collation key for str. // Key returns the collation key for str.
// Passing the buffer buf may avoid memory allocations. // Passing the buffer buf may avoid memory allocations.
// The returned slice will point to an allocation in Buffer and will remain // The returned slice will point to an allocation in Buffer and will remain
// valid until the next call to buf.ResetKeys(). // valid until the next call to buf.Reset().
func (c *Collator) Key(buf *Buffer, str []byte) []byte { func (c *Collator) Key(buf *Buffer, str []byte) []byte {
// See http://www.unicode.org/reports/tr10/#Main_Algorithm for more details. // See http://www.unicode.org/reports/tr10/#Main_Algorithm for more details.
buf.init() buf.init()
c.getColElems(buf, str) return c.key(buf, c.getColElems(str))
return c.key(buf, buf.ce)
} }
// KeyFromString returns the collation key for str. // KeyFromString returns the collation key for str.
@ -191,46 +256,73 @@ func (c *Collator) Key(buf *Buffer, str []byte) []byte {
func (c *Collator) KeyFromString(buf *Buffer, str string) []byte { func (c *Collator) KeyFromString(buf *Buffer, str string) []byte {
// See http://www.unicode.org/reports/tr10/#Main_Algorithm for more details. // See http://www.unicode.org/reports/tr10/#Main_Algorithm for more details.
buf.init() buf.init()
c.getColElemsString(buf, str) return c.key(buf, c.getColElemsString(str))
return c.key(buf, buf.ce)
} }
func (c *Collator) key(buf *Buffer, w []weights) []byte { func (c *Collator) key(buf *Buffer, w []colElem) []byte {
processWeights(c.Alternate, c.t.variableTop, w) processWeights(c.Alternate, c.t.variableTop, w)
kn := len(buf.key) kn := len(buf.key)
c.keyFromElems(buf, w) c.keyFromElems(buf, w)
return buf.key[kn:] return buf.key[kn:]
} }
func (c *Collator) getColElems(buf *Buffer, str []byte) { func (c *Collator) getColElems(str []byte) []colElem {
i := c.iter() i := c.iter(0)
i.src.SetInput(c.f, str) i.setInput(c, str)
for !i.done() { for !i.done() {
buf.ce = i.next(buf.ce) i.next()
} }
return i.ce
} }
func (c *Collator) getColElemsString(buf *Buffer, str string) { func (c *Collator) getColElemsString(str string) []colElem {
i := c.iter() i := c.iter(0)
i.src.SetInputString(c.f, str) i.setInputString(c, str)
for !i.done() { for !i.done() {
buf.ce = i.next(buf.ce) i.next()
} }
return i.ce
} }
type iter struct { type iter struct {
src norm.Iter src norm.Iter
ba [1024]byte norm [1024]byte
buf []byte buf []byte
t *table
p int p int
minBufSize int minBufSize int
wa [512]colElem
ce []colElem
pce int
t *table
_done, eof bool _done, eof bool
} }
func (c *Collator) iter() iter { func (i *iter) init(c *Collator) {
i := iter{t: c.t, minBufSize: c.t.maxContractLen} i.t = c.t
i.buf = i.ba[:0] i.minBufSize = c.t.maxContractLen
i.ce = i.wa[:0]
i.buf = i.norm[:0]
}
func (i *iter) reset() {
i.ce = i.ce[:0]
i.buf = i.buf[:0]
i.p = 0
i.eof = i.src.Done()
i._done = i.eof
}
func (i *iter) setInput(c *Collator, s []byte) *iter {
i.src.SetInput(c.f, s)
i.reset()
return i
}
func (i *iter) setInputString(c *Collator, s string) *iter {
i.src.SetInputString(c.f, s)
i.reset()
return i return i
} }
@ -238,7 +330,7 @@ func (i *iter) done() bool {
return i._done return i._done
} }
func (i *iter) next(ce []weights) []weights { func (i *iter) next() {
if !i.eof && len(i.buf)-i.p < i.minBufSize { if !i.eof && len(i.buf)-i.p < i.minBufSize {
// replenish buffer // replenish buffer
n := copy(i.buf, i.buf[i.p:]) n := copy(i.buf, i.buf[i.p:])
@ -249,14 +341,70 @@ func (i *iter) next(ce []weights) []weights {
} }
if i.p == len(i.buf) { if i.p == len(i.buf) {
i._done = true i._done = true
return ce return
} }
ce, sz := i.t.appendNext(ce, i.buf[i.p:]) sz := 0
i.ce, sz = i.t.appendNext(i.ce, i.buf[i.p:])
i.p += sz i.p += sz
return ce
} }
func appendPrimary(key []byte, p uint32) []byte { func (i *iter) nextPrimary() int {
for {
for ; i.pce < len(i.ce); i.pce++ {
if v := i.ce[i.pce].primary(); v != 0 {
i.pce++
return v
}
}
if i.done() {
return 0
}
i.next()
}
panic("should not reach here")
}
func (i *iter) nextSecondary() int {
for ; i.pce < len(i.ce); i.pce++ {
if v := i.ce[i.pce].secondary(); v != 0 {
i.pce++
return v
}
}
return 0
}
func (i *iter) prevSecondary() int {
for ; i.pce < len(i.ce); i.pce++ {
if v := i.ce[len(i.ce)-i.pce-1].secondary(); v != 0 {
i.pce++
return v
}
}
return 0
}
func (i *iter) nextTertiary() int {
for ; i.pce < len(i.ce); i.pce++ {
if v := i.ce[i.pce].tertiary(); v != 0 {
i.pce++
return int(v)
}
}
return 0
}
func (i *iter) nextQuaternary() int {
for ; i.pce < len(i.ce); i.pce++ {
if v := i.ce[i.pce].quaternary(); v != 0 {
i.pce++
return v
}
}
return 0
}
func appendPrimary(key []byte, p int) []byte {
// Convert to variable length encoding; supports up to 23 bits. // Convert to variable length encoding; supports up to 23 bits.
if p <= 0x7FFF { if p <= 0x7FFF {
key = append(key, uint8(p>>8), uint8(p)) key = append(key, uint8(p>>8), uint8(p))
@ -268,9 +416,9 @@ func appendPrimary(key []byte, p uint32) []byte {
// keyFromElems converts the weights ws to a compact sequence of bytes. // keyFromElems converts the weights ws to a compact sequence of bytes.
// The result will be appended to the byte buffer in buf. // The result will be appended to the byte buffer in buf.
func (c *Collator) keyFromElems(buf *Buffer, ws []weights) { func (c *Collator) keyFromElems(buf *Buffer, ws []colElem) {
for _, v := range ws { for _, v := range ws {
if w := v.primary; w > 0 { if w := v.primary(); w > 0 {
buf.key = appendPrimary(buf.key, w) buf.key = appendPrimary(buf.key, w)
} }
} }
@ -279,13 +427,13 @@ func (c *Collator) keyFromElems(buf *Buffer, ws []weights) {
// TODO: we can use one 0 if we can guarantee that all non-zero weights are > 0xFF. // TODO: we can use one 0 if we can guarantee that all non-zero weights are > 0xFF.
if !c.Backwards { if !c.Backwards {
for _, v := range ws { for _, v := range ws {
if w := v.secondary; w > 0 { if w := v.secondary(); w > 0 {
buf.key = append(buf.key, uint8(w>>8), uint8(w)) buf.key = append(buf.key, uint8(w>>8), uint8(w))
} }
} }
} else { } else {
for i := len(ws) - 1; i >= 0; i-- { for i := len(ws) - 1; i >= 0; i-- {
if w := ws[i].secondary; w > 0 { if w := ws[i].secondary(); w > 0 {
buf.key = append(buf.key, uint8(w>>8), uint8(w)) buf.key = append(buf.key, uint8(w>>8), uint8(w))
} }
} }
@ -296,20 +444,20 @@ func (c *Collator) keyFromElems(buf *Buffer, ws []weights) {
if Tertiary <= c.Strength || c.CaseLevel { if Tertiary <= c.Strength || c.CaseLevel {
buf.key = append(buf.key, 0, 0) buf.key = append(buf.key, 0, 0)
for _, v := range ws { for _, v := range ws {
if w := v.tertiary; w > 0 { if w := v.tertiary(); w > 0 {
buf.key = append(buf.key, w) buf.key = append(buf.key, uint8(w))
} }
} }
// Derive the quaternary weights from the options and other levels. // Derive the quaternary weights from the options and other levels.
// Note that we represent maxQuaternary as 0xFF. The first byte of the // Note that we represent maxQuaternary as 0xFF. The first byte of the
// representation of a a primary weight is always smaller than 0xFF, // representation of a a primary weight is always smaller than 0xFF,
// so using this single byte value will compare correctly. // so using this single byte value will compare correctly.
if Quaternary <= c.Strength { if Quaternary <= c.Strength && c.Alternate >= AltShifted {
if c.Alternate == AltShiftTrimmed { if c.Alternate == AltShiftTrimmed {
lastNonFFFF := len(buf.key) lastNonFFFF := len(buf.key)
buf.key = append(buf.key, 0) buf.key = append(buf.key, 0)
for _, v := range ws { for _, v := range ws {
if w := v.quaternary; w == maxQuaternary { if w := v.quaternary(); w == maxQuaternary {
buf.key = append(buf.key, 0xFF) buf.key = append(buf.key, 0xFF)
} else if w > 0 { } else if w > 0 {
buf.key = appendPrimary(buf.key, w) buf.key = appendPrimary(buf.key, w)
@ -320,7 +468,7 @@ func (c *Collator) keyFromElems(buf *Buffer, ws []weights) {
} else { } else {
buf.key = append(buf.key, 0) buf.key = append(buf.key, 0)
for _, v := range ws { for _, v := range ws {
if w := v.quaternary; w == maxQuaternary { if w := v.quaternary(); w == maxQuaternary {
buf.key = append(buf.key, 0xFF) buf.key = append(buf.key, 0xFF)
} else if w > 0 { } else if w > 0 {
buf.key = appendPrimary(buf.key, w) buf.key = appendPrimary(buf.key, w)
@ -331,29 +479,27 @@ func (c *Collator) keyFromElems(buf *Buffer, ws []weights) {
} }
} }
func processWeights(vw AlternateHandling, top uint32, wa []weights) { func processWeights(vw AlternateHandling, top uint32, wa []colElem) {
ignore := false ignore := false
vtop := int(top)
switch vw { switch vw {
case AltShifted, AltShiftTrimmed: case AltShifted, AltShiftTrimmed:
for i := range wa { for i := range wa {
if p := wa[i].primary; p <= top && p != 0 { if p := wa[i].primary(); p <= vtop && p != 0 {
wa[i] = weights{quaternary: p} wa[i] = makeQuaternary(p)
ignore = true ignore = true
} else if p == 0 { } else if p == 0 {
if ignore { if ignore {
wa[i] = weights{} wa[i] = ceIgnore
} else if wa[i].tertiary != 0 {
wa[i].quaternary = maxQuaternary
} }
} else { } else {
wa[i].quaternary = maxQuaternary
ignore = false ignore = false
} }
} }
case AltBlanked: case AltBlanked:
for i := range wa { for i := range wa {
if p := wa[i].primary; p <= top && (ignore || p != 0) { if p := wa[i].primary(); p <= vtop && (ignore || p != 0) {
wa[i] = weights{} wa[i] = ceIgnore
ignore = true ignore = true
} else { } else {
ignore = false ignore = false

View File

@ -4,8 +4,6 @@
package collate package collate
import "exp/norm"
// Init is used by type Builder in exp/locale/collate/build/ // Init is used by type Builder in exp/locale/collate/build/
// to create Collator instances. It is for internal use only. // to create Collator instances. It is for internal use only.
func Init(data interface{}) *Collator { func Init(data interface{}) *Collator {
@ -24,11 +22,7 @@ func Init(data interface{}) *Collator {
t.contractElem = init.ContractElems() t.contractElem = init.ContractElems()
t.maxContractLen = init.MaxContractLen() t.maxContractLen = init.MaxContractLen()
t.variableTop = init.VariableTop() t.variableTop = init.VariableTop()
return &Collator{ return newCollator(t)
Strength: Quaternary,
f: norm.NFD,
t: t,
}
} }
type tableInitializer interface { type tableInitializer interface {

View File

@ -24,6 +24,8 @@ func W(ce ...int) Weights {
} }
if len(ce) > 3 { if len(ce) > 3 {
w.Quaternary = ce[3] w.Quaternary = ce[3]
} else if w.Tertiary != 0 {
w.Quaternary = maxQuaternary
} }
return w return w
} }
@ -33,25 +35,27 @@ func (w Weights) String() string {
type Table struct { type Table struct {
t *table t *table
w []weights
} }
func GetTable(c *Collator) *Table { func GetTable(c *Collator) *Table {
return &Table{c.t, nil} return &Table{c.t}
} }
func convertToWeights(ws []weights) []Weights { func convertToWeights(ws []colElem) []Weights {
out := make([]Weights, len(ws)) out := make([]Weights, len(ws))
for i, w := range ws { for i, w := range ws {
out[i] = Weights{int(w.primary), int(w.secondary), int(w.tertiary), int(w.quaternary)} out[i] = Weights{int(w.primary()), int(w.secondary()), int(w.tertiary()), int(w.quaternary())}
} }
return out return out
} }
func convertFromWeights(ws []Weights) []weights { func convertFromWeights(ws []Weights) []colElem {
out := make([]weights, len(ws)) out := make([]colElem, len(ws))
for i, w := range ws { for i, w := range ws {
out[i] = weights{uint32(w.Primary), uint16(w.Secondary), uint8(w.Tertiary), uint32(w.Quaternary)} out[i] = makeCE([]int{w.Primary, w.Secondary, w.Tertiary})
if out[i] == ceIgnore && w.Quaternary > 0 {
out[i] = makeQuaternary(w.Quaternary)
}
} }
return out return out
} }
@ -68,10 +72,9 @@ func SetTop(c *Collator, top int) {
c.t.variableTop = uint32(top) c.t.variableTop = uint32(top)
} }
func GetColElems(c *Collator, buf *Buffer, str []byte) []Weights { func GetColElems(c *Collator, str []byte) []Weights {
buf.ResetKeys() ce := c.getColElems(str)
c.getColElems(buf, str) return convertToWeights(ce)
return convertToWeights(buf.ce)
} }
func ProcessWeights(h AlternateHandling, top int, w []Weights) []Weights { func ProcessWeights(h AlternateHandling, top int, w []Weights) []Weights {

View File

@ -38,7 +38,7 @@ var (
`URL of the Default Unicode Collation Element Table (DUCET). This can be a zip `URL of the Default Unicode Collation Element Table (DUCET). This can be a zip
file containing the file allkeys_CLDR.txt or an allkeys.txt file.`) file containing the file allkeys_CLDR.txt or an allkeys.txt file.`)
cldr = flag.String("cldr", cldr = flag.String("cldr",
"http://www.unicode.org/Public/cldr/2.0.1/core.zip", "http://www.unicode.org/Public/cldr/22/core.zip",
"URL of CLDR archive.") "URL of CLDR archive.")
test = flag.Bool("test", false, test = flag.Bool("test", false,
"test existing tables; can be used to compare web data with package data.") "test existing tables; can be used to compare web data with package data.")
@ -180,7 +180,7 @@ func skipAlt(a string) bool {
func failOnError(e error) { func failOnError(e error) {
if e != nil { if e != nil {
log.Fatal(e) log.Panic(e)
} }
} }
@ -607,7 +607,7 @@ func insertTailoring(t *build.Tailoring, r RuleElem, context, extend string) {
if *test { if *test {
testInput.add(str) testInput.add(str)
} }
err := t.Insert(lmap[l[0]], str, extend) err := t.Insert(lmap[l[0]], str, context+extend)
failOnError(err) failOnError(err)
} }
case "pc", "sc", "tc", "ic": case "pc", "sc", "tc", "ic":
@ -617,7 +617,7 @@ func insertTailoring(t *build.Tailoring, r RuleElem, context, extend string) {
if *test { if *test {
testInput.add(str) testInput.add(str)
} }
err := t.Insert(level, str, extend) err := t.Insert(level, str, context+extend)
failOnError(err) failOnError(err)
} }
default: default:
@ -677,7 +677,7 @@ func testCollator(c *collate.Collator) {
if bytes.Compare(k0, k) != 0 { if bytes.Compare(k0, k) != 0 {
failOnError(fmt.Errorf("test:%U: keys differ (%x vs %x)", []rune(str), k0, k)) failOnError(fmt.Errorf("test:%U: keys differ (%x vs %x)", []rune(str), k0, k))
} }
buf.ResetKeys() buf.Reset()
} }
fmt.Println("PASS") fmt.Println("PASS")
} }

View File

@ -236,9 +236,9 @@ func doTest(t Test) {
if strings.Contains(t.name, "NON_IGNOR") { if strings.Contains(t.name, "NON_IGNOR") {
c.Alternate = collate.AltNonIgnorable c.Alternate = collate.AltNonIgnorable
} }
prev := t.str[0] prev := t.str[0]
for i := 1; i < len(t.str); i++ { for i := 1; i < len(t.str); i++ {
b.Reset()
s := t.str[i] s := t.str[i]
ka := c.Key(b, prev) ka := c.Key(b, prev)
kb := c.Key(b, s) kb := c.Key(b, s)
@ -247,10 +247,10 @@ func doTest(t Test) {
prev = s prev = s
continue continue
} }
if r := c.Compare(b, prev, s); r == 1 { if r := c.Compare(prev, s); r == 1 {
fail(t, "%d: Compare(%.4X, %.4X) == %d; want -1 or 0", i, runes(prev), runes(s), r) fail(t, "%d: Compare(%.4X, %.4X) == %d; want -1 or 0", i, runes(prev), runes(s), r)
} }
if r := c.Compare(b, s, prev); r == -1 { if r := c.Compare(s, prev); r == -1 {
fail(t, "%d: Compare(%.4X, %.4X) == %d; want 1 or 0", i, runes(s), runes(prev), r) fail(t, "%d: Compare(%.4X, %.4X) == %d; want 1 or 0", i, runes(s), runes(prev), r)
} }
prev = s prev = s

View File

@ -42,12 +42,16 @@ func (t *table) indexedTable(idx tableIndex) *table {
// sequence of runes, the weights for the interstitial runes are // sequence of runes, the weights for the interstitial runes are
// appended as well. It returns a new slice that includes the appended // appended as well. It returns a new slice that includes the appended
// weights and the number of bytes consumed from s. // weights and the number of bytes consumed from s.
func (t *table) appendNext(w []weights, s []byte) ([]weights, int) { func (t *table) appendNext(w []colElem, s []byte) ([]colElem, int) {
v, sz := t.index.lookup(s) v, sz := t.index.lookup(s)
ce := colElem(v) ce := colElem(v)
tp := ce.ctype() tp := ce.ctype()
if tp == ceNormal { if tp == ceNormal {
w = append(w, getWeights(ce, s)) if ce == 0 {
r, _ := utf8.DecodeRune(s)
ce = makeImplicitCE(implicitPrimary(r))
}
w = append(w, ce)
} else if tp == ceExpansionIndex { } else if tp == ceExpansionIndex {
w = t.appendExpansion(w, ce) w = t.appendExpansion(w, ce)
} else if tp == ceContractionIndex { } else if tp == ceContractionIndex {
@ -62,40 +66,28 @@ func (t *table) appendNext(w []weights, s []byte) ([]weights, int) {
for p := 0; len(nfkd) > 0; nfkd = nfkd[p:] { for p := 0; len(nfkd) > 0; nfkd = nfkd[p:] {
w, p = t.appendNext(w, nfkd) w, p = t.appendNext(w, nfkd)
} }
w[i].tertiary = t1 w[i] = w[i].updateTertiary(t1)
if i++; i < len(w) { if i++; i < len(w) {
w[i].tertiary = t2 w[i] = w[i].updateTertiary(t2)
for i++; i < len(w); i++ { for i++; i < len(w); i++ {
w[i].tertiary = maxTertiary w[i] = w[i].updateTertiary(maxTertiary)
} }
} }
} }
return w, sz return w, sz
} }
func getWeights(ce colElem, s []byte) weights { func (t *table) appendExpansion(w []colElem, ce colElem) []colElem {
if ce == 0 { // implicit
r, _ := utf8.DecodeRune(s)
return weights{
primary: uint32(implicitPrimary(r)),
secondary: defaultSecondary,
tertiary: defaultTertiary,
}
}
return splitCE(ce)
}
func (t *table) appendExpansion(w []weights, ce colElem) []weights {
i := splitExpandIndex(ce) i := splitExpandIndex(ce)
n := int(t.expandElem[i]) n := int(t.expandElem[i])
i++ i++
for _, ce := range t.expandElem[i : i+n] { for _, ce := range t.expandElem[i : i+n] {
w = append(w, splitCE(colElem(ce))) w = append(w, colElem(ce))
} }
return w return w
} }
func (t *table) matchContraction(w []weights, ce colElem, suffix []byte) ([]weights, int) { func (t *table) matchContraction(w []colElem, ce colElem, suffix []byte) ([]colElem, int) {
index, n, offset := splitContractIndex(ce) index, n, offset := splitContractIndex(ce)
scan := t.contractTries.scanner(index, n, suffix) scan := t.contractTries.scanner(index, n, suffix)
@ -138,7 +130,7 @@ func (t *table) matchContraction(w []weights, ce colElem, suffix []byte) ([]weig
i, n := scan.result() i, n := scan.result()
ce = colElem(t.contractElem[i+offset]) ce = colElem(t.contractElem[i+offset])
if ce.ctype() == ceNormal { if ce.ctype() == ceNormal {
w = append(w, splitCE(ce)) w = append(w, ce)
} else { } else {
w = t.appendExpansion(w, ce) w = t.appendExpansion(w, ce)
} }

File diff suppressed because it is too large Load Diff

View File

@ -91,5 +91,5 @@ func (c *goCollator) Key(b Input) []byte {
} }
func (c *goCollator) Compare(a, b Input) int { func (c *goCollator) Compare(a, b Input) int {
return c.c.Compare(&c.buf, a.UTF8, b.UTF8) return c.c.Compare(a.UTF8, b.UTF8)
} }

View File

@ -399,7 +399,7 @@ var cmdRegress = &Command{
} }
const failedKeyCompare = ` const failedKeyCompare = `
%d: incorrect comparison result for input: %s:%d: incorrect comparison result for input:
a: %q (%.4X) a: %q (%.4X)
key: %s key: %s
b: %q (%.4X) b: %q (%.4X)
@ -412,7 +412,7 @@ const failedKeyCompare = `
` `
const failedCompare = ` const failedCompare = `
%d: incorrect comparison result for input: %s:%d: incorrect comparison result for input:
a: %q (%.4X) a: %q (%.4X)
b: %q (%.4X) b: %q (%.4X)
Compare(a, b) = %d; want %d. Compare(a, b) = %d; want %d.
@ -453,12 +453,12 @@ func runRegress(ctxt *Context, args []string) {
count++ count++
a := string(ia.UTF8) a := string(ia.UTF8)
b := string(ib.UTF8) b := string(ib.UTF8)
fmt.Printf(failedKeyCompare, i-1, a, []rune(a), keyStr(ia.key), b, []rune(b), keyStr(ib.key), cmp, goldCmp, keyStr(gold.Key(ia)), keyStr(gold.Key(ib))) fmt.Printf(failedKeyCompare, t.Locale, i-1, a, []rune(a), keyStr(ia.key), b, []rune(b), keyStr(ib.key), cmp, goldCmp, keyStr(gold.Key(ia)), keyStr(gold.Key(ib)))
} else if cmp := t.Col.Compare(ia, ib); cmp != goldCmp { } else if cmp := t.Col.Compare(ia, ib); cmp != goldCmp {
count++ count++
a := string(ia.UTF8) a := string(ia.UTF8)
b := string(ib.UTF8) b := string(ib.UTF8)
fmt.Printf(failedKeyCompare, i-1, a, []rune(a), b, []rune(b), cmp, goldCmp) fmt.Printf(failedCompare, t.Locale, i-1, a, []rune(a), b, []rune(b), cmp, goldCmp)
} }
} }
if count > 0 { if count > 0 {

File diff suppressed because it is too large Load Diff

View File

@ -44,7 +44,7 @@ func (check *checker) builtin(x *operand, call *ast.CallExpr, bin *builtin, iota
switch id { switch id {
case _Make, _New: case _Make, _New:
// argument must be a type // argument must be a type
typ0 = underlying(check.typ(arg0, false)) typ0 = check.typ(arg0, false)
if typ0 == Typ[Invalid] { if typ0 == Typ[Invalid] {
goto Error goto Error
} }
@ -191,7 +191,7 @@ func (check *checker) builtin(x *operand, call *ast.CallExpr, bin *builtin, iota
case _Make: case _Make:
var min int // minimum number of arguments var min int // minimum number of arguments
switch typ0.(type) { switch underlying(typ0).(type) {
case *Slice: case *Slice:
min = 2 min = 2
case *Map, *Chan: case *Map, *Chan:
@ -204,13 +204,27 @@ func (check *checker) builtin(x *operand, call *ast.CallExpr, bin *builtin, iota
check.errorf(call.Pos(), "%s expects %d or %d arguments; found %d", call, min, min+1, n) check.errorf(call.Pos(), "%s expects %d or %d arguments; found %d", call, min, min+1, n)
goto Error goto Error
} }
var sizes []interface{} // constant integer arguments, if any
for _, arg := range args[1:] { for _, arg := range args[1:] {
check.expr(x, arg, nil, iota) check.expr(x, arg, nil, iota)
if !x.isInteger() { if x.isInteger() {
if x.mode == constant {
if isNegConst(x.val) {
check.invalidArg(x.pos(), "%s must not be negative", x)
// safe to continue
} else {
sizes = append(sizes, x.val) // x.val >= 0
}
}
} else {
check.invalidArg(x.pos(), "%s must be an integer", x) check.invalidArg(x.pos(), "%s must be an integer", x)
// safe to continue // safe to continue
} }
} }
if len(sizes) == 2 && compareConst(sizes[0], sizes[1], token.GTR) {
check.invalidArg(args[1].Pos(), "length and capacity swapped")
// safe to continue
}
x.mode = variable x.mode = variable
x.typ = typ0 x.typ = typ0
@ -287,7 +301,7 @@ func (check *checker) builtin(x *operand, call *ast.CallExpr, bin *builtin, iota
var t operand var t operand
x1 := x x1 := x
for _, arg := range args { for _, arg := range args {
check.exprOrType(x1, arg, nil, iota, true) // permit trace for types, e.g.: new(trace(T)) check.rawExpr(x1, arg, nil, iota, true) // permit trace for types, e.g.: new(trace(T))
check.dump("%s: %s", x1.pos(), x1) check.dump("%s: %s", x1.pos(), x1)
x1 = &t // use incoming x only for first argument x1 = &t // use incoming x only for first argument
} }

View File

@ -9,244 +9,366 @@ package types
import ( import (
"fmt" "fmt"
"go/ast" "go/ast"
"go/scanner"
"go/token" "go/token"
"strconv" "sort"
) )
const debug = false // enable for debugging
const trace = false
type checker struct { type checker struct {
fset *token.FileSet fset *token.FileSet
errors scanner.ErrorList pkg *ast.Package
types map[ast.Expr]Type errh func(token.Pos, string)
mapf func(ast.Expr, Type)
// lazily initialized
firsterr error
filenames []string // sorted list of package file names for reproducible iteration order
initexprs map[*ast.ValueSpec][]ast.Expr // "inherited" initialization expressions for constant declarations
functypes []*Signature // stack of function signatures; actively typechecked function on top
pos []token.Pos // stack of expr positions; debugging support, used if trace is set
} }
func (c *checker) errorf(pos token.Pos, format string, args ...interface{}) string { // declare declares an object of the given kind and name (ident) in scope;
msg := fmt.Sprintf(format, args...) // decl is the corresponding declaration in the AST. An error is reported
c.errors.Add(c.fset.Position(pos), msg) // if the object was declared before.
return msg
}
// collectFields collects struct fields tok = token.STRUCT), interface methods
// (tok = token.INTERFACE), and function arguments/results (tok = token.FUNC).
// //
func (c *checker) collectFields(tok token.Token, list *ast.FieldList, cycleOk bool) (fields ObjList, tags []string, isVariadic bool) { // TODO(gri) This is very similar to the declare function in go/parser; it
if list != nil { // is only used to associate methods with their respective receiver base types.
for _, field := range list.List { // In a future version, it might be simpler and cleaner do to all the resolution
ftype := field.Type // in the type-checking phase. It would simplify the parser, AST, and also
if t, ok := ftype.(*ast.Ellipsis); ok { // reduce some amount of code duplication.
ftype = t.Elt //
isVariadic = true func (check *checker) declare(scope *ast.Scope, kind ast.ObjKind, ident *ast.Ident, decl ast.Decl) {
} assert(ident.Obj == nil) // identifier already declared or resolved
typ := c.makeType(ftype, cycleOk) obj := ast.NewObj(kind, ident.Name)
tag := "" obj.Decl = decl
if field.Tag != nil { ident.Obj = obj
assert(field.Tag.Kind == token.STRING) if ident.Name != "_" {
tag, _ = strconv.Unquote(field.Tag.Value) if alt := scope.Insert(obj); alt != nil {
} prevDecl := ""
if len(field.Names) > 0 { if pos := alt.Pos(); pos.IsValid() {
// named fields prevDecl = fmt.Sprintf("\n\tprevious declaration at %s", check.fset.Position(pos))
for _, name := range field.Names {
obj := name.Obj
obj.Type = typ
fields = append(fields, obj)
if tok == token.STRUCT {
tags = append(tags, tag)
}
}
} else {
// anonymous field
switch tok {
case token.STRUCT:
tags = append(tags, tag)
fallthrough
case token.FUNC:
obj := ast.NewObj(ast.Var, "")
obj.Type = typ
fields = append(fields, obj)
case token.INTERFACE:
utyp := Underlying(typ)
if typ, ok := utyp.(*Interface); ok {
// TODO(gri) This is not good enough. Check for double declarations!
fields = append(fields, typ.Methods...)
} else if _, ok := utyp.(*Bad); !ok {
// if utyp is Bad, don't complain (the root cause was reported before)
c.errorf(ftype.Pos(), "interface contains embedded non-interface type")
}
default:
panic("unreachable")
} }
check.errorf(ident.Pos(), fmt.Sprintf("%s redeclared in this block%s", ident.Name, prevDecl))
} }
} }
} }
func (check *checker) valueSpec(pos token.Pos, obj *ast.Object, lhs []*ast.Ident, typ ast.Expr, rhs []ast.Expr, iota int) {
if len(lhs) == 0 {
check.invalidAST(pos, "missing lhs in declaration")
return return
} }
// makeType makes a new type for an AST type specification x or returns // determine type for all of lhs, if any
// the type referred to by a type name x. If cycleOk is set, a type may // (but only set it for the object we typecheck!)
// refer to itself directly or indirectly; otherwise cycles are errors. var t Type
// if typ != nil {
func (c *checker) makeType(x ast.Expr, cycleOk bool) (typ Type) { t = check.typ(typ, false)
if debug {
fmt.Printf("makeType (cycleOk = %v)\n", cycleOk)
ast.Print(c.fset, x)
defer func() {
fmt.Printf("-> %T %v\n\n", typ, typ)
}()
} }
switch t := x.(type) { // len(lhs) > 0
case *ast.BadExpr: if len(lhs) == len(rhs) {
return &Bad{} // check only lhs and rhs corresponding to obj
var l, r ast.Expr
case *ast.Ident: for i, name := range lhs {
// type name if name.Obj == obj {
obj := t.Obj l = lhs[i]
if obj == nil { r = rhs[i]
// unresolved identifier (error has been reported before) break
return &Bad{Msg: fmt.Sprintf("%s is unresolved", t.Name)}
}
if obj.Kind != ast.Typ {
msg := c.errorf(t.Pos(), "%s is not a type", t.Name)
return &Bad{Msg: msg}
}
c.checkObj(obj, cycleOk)
if !cycleOk && obj.Type.(*Name).Underlying == nil {
msg := c.errorf(obj.Pos(), "illegal cycle in declaration of %s", obj.Name)
return &Bad{Msg: msg}
}
return obj.Type.(Type)
case *ast.ParenExpr:
return c.makeType(t.X, cycleOk)
case *ast.SelectorExpr:
// qualified identifier
// TODO (gri) eventually, this code belongs to expression
// type checking - here for the time being
if ident, ok := t.X.(*ast.Ident); ok {
if obj := ident.Obj; obj != nil {
if obj.Kind != ast.Pkg {
msg := c.errorf(ident.Pos(), "%s is not a package", obj.Name)
return &Bad{Msg: msg}
}
// TODO(gri) we have a package name but don't
// have the mapping from package name to package
// scope anymore (created in ast.NewPackage).
return &Bad{} // for now
} }
} }
// TODO(gri) can this really happen (the parser should have excluded this)? assert(l != nil)
msg := c.errorf(t.Pos(), "expected qualified identifier") obj.Type = t
return &Bad{Msg: msg} check.assign1to1(l, r, nil, true, iota)
case *ast.StarExpr:
return &Pointer{Base: c.makeType(t.X, true)}
case *ast.ArrayType:
if t.Len != nil {
// TODO(gri) compute length
return &Array{Elt: c.makeType(t.Elt, cycleOk)}
}
return &Slice{Elt: c.makeType(t.Elt, true)}
case *ast.StructType:
fields, tags, _ := c.collectFields(token.STRUCT, t.Fields, cycleOk)
return &Struct{Fields: fields, Tags: tags}
case *ast.FuncType:
params, _, isVariadic := c.collectFields(token.FUNC, t.Params, true)
results, _, _ := c.collectFields(token.FUNC, t.Results, true)
return &Func{Recv: nil, Params: params, Results: results, IsVariadic: isVariadic}
case *ast.InterfaceType:
methods, _, _ := c.collectFields(token.INTERFACE, t.Methods, cycleOk)
methods.Sort()
return &Interface{Methods: methods}
case *ast.MapType:
return &Map{Key: c.makeType(t.Key, true), Elt: c.makeType(t.Value, true)}
case *ast.ChanType:
return &Chan{Dir: t.Dir, Elt: c.makeType(t.Value, true)}
}
panic(fmt.Sprintf("unreachable (%T)", x))
}
// checkObj type checks an object.
func (c *checker) checkObj(obj *ast.Object, ref bool) {
if obj.Type != nil {
// object has already been type checked
return return
} }
// there must be a type or initialization expressions
if t == nil && len(rhs) == 0 {
check.invalidAST(pos, "missing type or initialization expression")
t = Typ[Invalid]
}
// if we have a type, mark all of lhs
if t != nil {
for _, name := range lhs {
name.Obj.Type = t
}
}
// check initial values, if any
if len(rhs) > 0 {
// TODO(gri) should try to avoid this conversion
lhx := make([]ast.Expr, len(lhs))
for i, e := range lhs {
lhx[i] = e
}
check.assignNtoM(lhx, rhs, true, iota)
}
}
// function typechecks a function body against its signature. The signature
// is pushed onto the checker's function-type stack for the duration of the
// body check and popped again afterwards.
func (check *checker) function(typ *Signature, body *ast.BlockStmt) {
	stack := append(check.functypes, typ)
	check.functypes = stack
	check.stmt(body)
	check.functypes = stack[:len(stack)-1]
}
// object typechecks an object by assigning it a type; obj.Type must be nil.
// Callers must check obj.Type before calling object; this eliminates a call
// for each identifier that has been typechecked already, a common scenario.
//
func (check *checker) object(obj *ast.Object, cycleOk bool) {
assert(obj.Type == nil)
switch obj.Kind { switch obj.Kind {
case ast.Bad: case ast.Bad, ast.Pkg:
// ignore // nothing to do
case ast.Con: case ast.Con, ast.Var:
// TODO(gri) complete this // The obj.Data field for constants and variables is initialized
// to the respective (hypothetical, for variables) iota value by
// the parser. The object's fields can be in one of the following
// states:
// Type != nil => the constant value is Data
// Type == nil => the object is not typechecked yet, and Data can be:
// Data is int => Data is the value of iota for this declaration
// Data == nil => the object's expression is being evaluated
if obj.Data == nil {
check.errorf(obj.Pos(), "illegal cycle in initialization of %s", obj.Name)
obj.Type = Typ[Invalid]
return
}
spec := obj.Decl.(*ast.ValueSpec)
iota := obj.Data.(int)
obj.Data = nil
// determine initialization expressions
values := spec.Values
if len(values) == 0 && obj.Kind == ast.Con {
values = check.initexprs[spec]
}
check.valueSpec(spec.Pos(), obj, spec.Names, spec.Type, values, iota)
case ast.Typ: case ast.Typ:
typ := &Name{Obj: obj} typ := &NamedType{Obj: obj}
obj.Type = typ // "mark" object so recursion terminates obj.Type = typ // "mark" object so recursion terminates
typ.Underlying = Underlying(c.makeType(obj.Decl.(*ast.TypeSpec).Type, ref)) typ.Underlying = underlying(check.typ(obj.Decl.(*ast.TypeSpec).Type, cycleOk))
// typecheck associated method signatures
case ast.Var: if obj.Data != nil {
// TODO(gri) complete this scope := obj.Data.(*ast.Scope)
switch t := typ.Underlying.(type) {
case *Struct:
// struct fields must not conflict with methods
for _, f := range t.Fields {
if m := scope.Lookup(f.Name); m != nil {
check.errorf(m.Pos(), "type %s has both field and method named %s", obj.Name, f.Name)
}
}
// ok to continue
case *Interface:
// methods cannot be associated with an interface type
for _, m := range scope.Objects {
recv := m.Decl.(*ast.FuncDecl).Recv.List[0].Type
check.errorf(recv.Pos(), "invalid receiver type %s (%s is an interface type)", obj.Name, obj.Name)
}
// ok to continue
}
// typecheck method signatures
for _, m := range scope.Objects {
mdecl := m.Decl.(*ast.FuncDecl)
// TODO(gri) At the moment, the receiver is type-checked when checking
// the method body. Also, we don't properly track if the receiver is
// a pointer (i.e., currently, method sets are too large). FIX THIS.
mtyp := check.typ(mdecl.Type, cycleOk).(*Signature)
m.Type = mtyp
}
}
case ast.Fun: case ast.Fun:
fdecl := obj.Decl.(*ast.FuncDecl) fdecl := obj.Decl.(*ast.FuncDecl)
ftyp := c.makeType(fdecl.Type, ref).(*Func)
obj.Type = ftyp
if fdecl.Recv != nil { if fdecl.Recv != nil {
recvField := fdecl.Recv.List[0] // This will ensure that the method base type is
if len(recvField.Names) > 0 { // type-checked
ftyp.Recv = recvField.Names[0].Obj check.collectFields(token.FUNC, fdecl.Recv, true)
} else {
ftyp.Recv = ast.NewObj(ast.Var, "_")
ftyp.Recv.Decl = recvField
} }
c.checkObj(ftyp.Recv, ref) ftyp := check.typ(fdecl.Type, cycleOk).(*Signature)
// TODO(axw) add method to a list in the receiver type. obj.Type = ftyp
} check.function(ftyp, fdecl.Body)
// TODO(axw) check function body, if non-nil.
default: default:
panic("unreachable") panic("unreachable")
} }
} }
// Check typechecks a package. // assocInitvals associates "inherited" initialization expressions
// It augments the AST by assigning types to all ast.Objects and returns a map // with the corresponding *ast.ValueSpec in the check.initexprs map
// of types for all expression nodes in statements, and a scanner.ErrorList if // for constant declarations without explicit initialization expressions.
// there are errors.
// //
func Check(fset *token.FileSet, pkg *ast.Package) (types map[ast.Expr]Type, err error) { func (check *checker) assocInitvals(decl *ast.GenDecl) {
// Sort objects so that we get reproducible error var values []ast.Expr
// positions (this is only needed for testing). for _, s := range decl.Specs {
// TODO(gri): Consider ast.Scope implementation that if s, ok := s.(*ast.ValueSpec); ok {
// provides both a list and a map for fast lookup. if len(s.Values) > 0 {
// Would permit the use of scopes instead of ObjMaps values = s.Values
// elsewhere. } else {
list := make(ObjList, len(pkg.Scope.Objects)) check.initexprs[s] = values
i := 0 }
for _, obj := range pkg.Scope.Objects { }
list[i] = obj }
i++ if len(values) == 0 {
check.invalidAST(decl.Pos(), "no initialization values provided")
} }
list.Sort()
var c checker
c.fset = fset
c.types = make(map[ast.Expr]Type)
for _, obj := range list {
c.checkObj(obj, false)
} }
c.errors.RemoveMultiples() // assocMethod associates a method declaration with the respective
return c.types, c.errors.Err() // receiver base type. meth.Recv must exist.
//
// assocMethod associates the method declaration meth with its receiver
// base type by declaring it in the scope stored in the base type object's
// Data field (creating that scope on first use). meth.Recv must exist.
// On any error (non-identifier receiver, receiver that is not a type),
// the method is still declared, but into a fresh dummy scope so that it
// gets an associated object without causing double-declaration errors.
func (check *checker) assocMethod(meth *ast.FuncDecl) {
	// The receiver type is one of the following (enforced by parser):
	// - *ast.Ident
	// - *ast.StarExpr{*ast.Ident}
	// - *ast.BadExpr (parser error)
	typ := meth.Recv.List[0].Type
	if ptr, ok := typ.(*ast.StarExpr); ok {
		// strip the pointer; the base type identifier is what matters here
		typ = ptr.X
	}
	// determine receiver base type object (or nil if error)
	var obj *ast.Object
	if ident, ok := typ.(*ast.Ident); ok && ident.Obj != nil {
		obj = ident.Obj
		if obj.Kind != ast.Typ {
			check.errorf(ident.Pos(), "%s is not a type", ident.Name)
			obj = nil
		}
		// TODO(gri) determine if obj was defined in this package
		/*
			if check.notLocal(obj) {
				check.errorf(ident.Pos(), "cannot define methods on non-local type %s", ident.Name)
				obj = nil
			}
		*/
	} else {
		// If it's not an identifier or the identifier wasn't declared/resolved,
		// the parser/resolver already reported an error. Nothing to do here.
	}
	// determine base type scope (or nil if error)
	var scope *ast.Scope
	if obj != nil {
		if obj.Data != nil {
			// reuse the method scope created by an earlier assocMethod call
			scope = obj.Data.(*ast.Scope)
		} else {
			scope = ast.NewScope(nil)
			obj.Data = scope
		}
	} else {
		// use a dummy scope so that meth can be declared in
		// presence of an error and get an associated object
		// (always use a new scope so that we don't get double
		// declaration errors)
		scope = ast.NewScope(nil)
	}
	check.declare(scope, ast.Fun, meth.Name, meth)
}
// assocInitvalsOrMethod dispatches one top-level declaration to the
// appropriate association pass: constant declarations get their inherited
// initialization expressions recorded, methods get attached to their
// receiver base type. All other declarations are ignored.
func (check *checker) assocInitvalsOrMethod(decl ast.Decl) {
	if gen, ok := decl.(*ast.GenDecl); ok {
		if gen.Tok == token.CONST {
			check.assocInitvals(gen)
		}
		return
	}
	if fn, ok := decl.(*ast.FuncDecl); ok && fn.Recv != nil {
		check.assocMethod(fn)
	}
}
// decl typechecks one package-level declaration by dispatching on its
// concrete node type. Objects that already carry a type (obj.Type != nil)
// were checked earlier and are skipped; all others are handed to
// check.object.
func (check *checker) decl(decl ast.Decl) {
	switch d := decl.(type) {
	case *ast.BadDecl:
		// ignore
	case *ast.GenDecl:
		for _, spec := range d.Specs {
			switch s := spec.(type) {
			case *ast.ImportSpec:
				// nothing to do (handled by ast.NewPackage)
			case *ast.ValueSpec:
				// constants and variables: check each declared name once
				for _, name := range s.Names {
					if obj := name.Obj; obj.Type == nil {
						check.object(obj, false)
					}
				}
			case *ast.TypeSpec:
				if obj := s.Name.Obj; obj.Type == nil {
					check.object(obj, false)
				}
			default:
				check.invalidAST(s.Pos(), "unknown ast.Spec node %T", s)
			}
		}
	case *ast.FuncDecl:
		if d.Name.Name == "init" {
			// initialization function
			// TODO(gri) ignore for now (has no object associated with it)
			// (should probably collect in a first phase and properly initialize)
			return
		}
		if obj := d.Name.Obj; obj.Type == nil {
			check.object(obj, false)
		}
	default:
		check.invalidAST(d.Pos(), "unknown ast.Decl node %T", d)
	}
}
// iterate invokes f once for every package-level declaration, visiting
// files in sorted filename order so that iteration is deterministic.
// The sorted file list is built lazily on first use and cached in
// check.filenames.
func (check *checker) iterate(f func(*checker, ast.Decl)) {
	names := check.filenames
	if names == nil {
		// first call: collect and sort the package's filenames
		for name := range check.pkg.Files {
			names = append(names, name)
		}
		sort.Strings(names)
		check.filenames = names
	}
	for _, name := range names {
		file := check.pkg.Files[name]
		for _, d := range file.Decls {
			f(check, d)
		}
	}
}
// A bailout panic is raised to indicate early termination; it is recovered
// by the deferred handler in check, which then reports check.firsterr.
type bailout struct{}
func check(fset *token.FileSet, pkg *ast.Package, errh func(token.Pos, string), f func(ast.Expr, Type)) (err error) {
// initialize checker
var check checker
check.fset = fset
check.pkg = pkg
check.errh = errh
check.mapf = f
check.initexprs = make(map[*ast.ValueSpec][]ast.Expr)
// handle bailouts
defer func() {
if p := recover(); p != nil {
_ = p.(bailout) // re-panic if not a bailout
}
err = check.firsterr
}()
// determine missing constant initialization expressions
// and associate methods with types
check.iterate((*checker).assocInitvalsOrMethod)
// typecheck all declarations
check.iterate((*checker).decl)
return
} }

View File

@ -23,17 +23,21 @@
package types package types
import ( import (
// "fmt" "flag"
"fmt"
"go/ast" "go/ast"
"go/parser" "go/parser"
"go/scanner" "go/scanner"
"go/token" "go/token"
"io/ioutil" "io/ioutil"
// "os" "os"
"regexp" "regexp"
"runtime"
"testing" "testing"
) )
var listErrors = flag.Bool("list", false, "list errors")
// The test filenames do not end in .go so that they are invisible // The test filenames do not end in .go so that they are invisible
// to gofmt since they contain comments that must not change their // to gofmt since they contain comments that must not change their
// positions relative to surrounding tokens. // positions relative to surrounding tokens.
@ -42,7 +46,17 @@ var tests = []struct {
name string name string
files []string files []string
}{ }{
{"test0", []string{"testdata/test0.src"}}, {"decls0", []string{"testdata/decls0.src"}},
{"decls1", []string{"testdata/decls1.src"}},
{"decls2", []string{"testdata/decls2a.src", "testdata/decls2b.src"}},
{"const0", []string{"testdata/const0.src"}},
{"expr0", []string{"testdata/expr0.src"}},
{"expr1", []string{"testdata/expr1.src"}},
{"expr2", []string{"testdata/expr2.src"}},
{"expr3", []string{"testdata/expr3.src"}},
{"builtins", []string{"testdata/builtins.src"}},
{"conversions", []string{"testdata/conversions.src"}},
{"stmt0", []string{"testdata/stmt0.src"}},
} }
var fset = token.NewFileSet() var fset = token.NewFileSet()
@ -96,8 +110,9 @@ var errRx = regexp.MustCompile(`^/\* *ERROR *"([^"]*)" *\*/$`)
// expectedErrors collects the regular expressions of ERROR comments found // expectedErrors collects the regular expressions of ERROR comments found
// in files and returns them as a map of error positions to error messages. // in files and returns them as a map of error positions to error messages.
// //
func expectedErrors(t *testing.T, testname string, files map[string]*ast.File) map[token.Pos]string { func expectedErrors(t *testing.T, testname string, files map[string]*ast.File) map[token.Pos][]string {
errors := make(map[token.Pos]string) errors := make(map[token.Pos][]string)
for filename := range files { for filename := range files {
src, err := ioutil.ReadFile(filename) src, err := ioutil.ReadFile(filename)
if err != nil { if err != nil {
@ -120,7 +135,8 @@ func expectedErrors(t *testing.T, testname string, files map[string]*ast.File) m
case token.COMMENT: case token.COMMENT:
s := errRx.FindStringSubmatch(lit) s := errRx.FindStringSubmatch(lit)
if len(s) == 2 { if len(s) == 2 {
errors[prev] = string(s[1]) list := errors[prev]
errors[prev] = append(list, string(s[1]))
} }
case token.SEMICOLON: case token.SEMICOLON:
// ignore automatically inserted semicolon // ignore automatically inserted semicolon
@ -133,45 +149,51 @@ func expectedErrors(t *testing.T, testname string, files map[string]*ast.File) m
} }
} }
} }
return errors return errors
} }
func eliminate(t *testing.T, expected map[token.Pos]string, errors error) { func eliminate(t *testing.T, expected map[token.Pos][]string, errors error) {
if errors == nil { if *listErrors || errors == nil {
return return
} }
for _, error := range errors.(scanner.ErrorList) { for _, error := range errors.(scanner.ErrorList) {
// error.Pos is a token.Position, but we want // error.Pos is a token.Position, but we want
// a token.Pos so we can do a map lookup // a token.Pos so we can do a map lookup
pos := getPos(error.Pos.Filename, error.Pos.Offset) pos := getPos(error.Pos.Filename, error.Pos.Offset)
if msg, found := expected[pos]; found { list := expected[pos]
// we expect a message at pos; check if it matches index := -1 // list index of matching message, if any
// we expect one of the messages in list to match the error at pos
for i, msg := range list {
rx, err := regexp.Compile(msg) rx, err := regexp.Compile(msg)
if err != nil { if err != nil {
t.Errorf("%s: %v", error.Pos, err) t.Errorf("%s: %v", error.Pos, err)
continue continue
} }
if match := rx.MatchString(error.Msg); !match { if match := rx.MatchString(error.Msg); match {
t.Errorf("%s: %q does not match %q", error.Pos, error.Msg, msg) index = i
break
}
}
if index >= 0 {
// eliminate from list
n := len(list) - 1
if n > 0 {
// not the last entry - swap in last element and shorten list by 1
list[index] = list[n]
expected[pos] = list[:n]
} else {
// last entry - remove list from map
delete(expected, pos)
}
} else {
t.Errorf("%s: no error expected: %q", error.Pos, error.Msg)
continue continue
} }
// we have a match - eliminate this error
delete(expected, pos)
} else {
// To keep in mind when analyzing failed test output:
// If the same error position occurs multiple times in errors,
// this message will be triggered (because the first error at
// the position removes this position from the expected errors).
t.Errorf("%s: no (multiple?) error expected, but found: %s", error.Pos, error.Msg)
}
} }
} }
/* func checkFiles(t *testing.T, testname string, testfiles []string) {
This test doesn't work with gccgo--it can't read gccgo imports.
func check(t *testing.T, testname string, testfiles []string) {
// TODO(gri) Eventually all these different phases should be // TODO(gri) Eventually all these different phases should be
// subsumed into a single function call that takes // subsumed into a single function call that takes
// a set of files and creates a fully resolved and // a set of files and creates a fully resolved and
@ -192,8 +214,17 @@ func check(t *testing.T, testname string, testfiles []string) {
eliminate(t, errors, err) eliminate(t, errors, err)
// verify errors returned by the typechecker // verify errors returned by the typechecker
_, err = Check(fset, pkg) var list scanner.ErrorList
eliminate(t, errors, err) errh := func(pos token.Pos, msg string) {
list.Add(fset.Position(pos), msg)
}
err = Check(fset, pkg, errh, nil)
eliminate(t, errors, list)
if *listErrors {
scanner.PrintError(os.Stdout, err)
return
}
// there should be no expected errors left // there should be no expected errors left
if len(errors) > 0 { if len(errors) > 0 {
@ -205,19 +236,28 @@ func check(t *testing.T, testname string, testfiles []string) {
} }
func TestCheck(t *testing.T) { func TestCheck(t *testing.T) {
// This package does not yet know how to read gccgo export data.
if runtime.Compiler == "gccgo" {
return
}
// Declare builtins for testing.
// Not done in an init func to avoid an init race with
// the construction of the Universe var.
def(ast.Fun, "assert").Type = &builtin{aType, _Assert, "assert", 1, false, true}
def(ast.Fun, "trace").Type = &builtin{aType, _Trace, "trace", 0, true, true}
// For easy debugging w/o changing the testing code, // For easy debugging w/o changing the testing code,
// if there is a local test file, only test that file. // if there is a local test file, only test that file.
const testfile = "testdata/test.go" const testfile = "testdata/test.go"
if fi, err := os.Stat(testfile); err == nil && !fi.IsDir() { if fi, err := os.Stat(testfile); err == nil && !fi.IsDir() {
fmt.Printf("WARNING: Testing only %s (remove it to run all tests)\n", testfile) fmt.Printf("WARNING: Testing only %s (remove it to run all tests)\n", testfile)
check(t, testfile, []string{testfile}) checkFiles(t, testfile, []string{testfile})
return return
} }
// Otherwise, run all the tests. // Otherwise, run all the tests.
for _, test := range tests { for _, test := range tests {
check(t, test.name, test.files) checkFiles(t, test.name, test.files)
} }
} }
*/

View File

@ -2,282 +2,497 @@
// Use of this source code is governed by a BSD-style // Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file. // license that can be found in the LICENSE file.
// This file implements operations on ideal constants. // This file implements operations on constant values.
package types package types
import ( import (
"fmt"
"go/token" "go/token"
"math/big" "math/big"
"strconv" "strconv"
) )
// TODO(gri) Consider changing the API so Const is an interface // TODO(gri) At the moment, constants are different types
// and operations on consts don't have to type switch. // passed around as interface{} values. Consider introducing
// a Const type and use methods instead of xConst functions.
// A Const implements an ideal constant Value. // Representation of constant values.
// The zero value z for a Const is not a valid constant value. //
type Const struct { // bool -> bool (true, false)
// representation of constant values: // numeric -> int64, *big.Int, *big.Rat, complex (ordered by increasing data structure "size")
// ideal bool -> bool // string -> string
// ideal int -> *big.Int // nil -> nilType (nilConst)
// ideal float -> *big.Rat //
// ideal complex -> cmplx // Numeric constants are normalized after each operation such
// ideal string -> string // that they are represented by the "smallest" data structure
val interface{} // required to represent the constant, independent of actual
} // type. Non-numeric constants are always normalized.
// Representation of complex values. // Representation of complex numbers.
type cmplx struct { type complex struct {
re, im *big.Rat re, im *big.Rat
} }
func assert(cond bool) { func (c complex) String() string {
if !cond { if c.re.Sign() == 0 {
panic("go/types internal error: assertion failed") return fmt.Sprintf("%si", c.im)
} }
// normalized complex values always have an imaginary part
return fmt.Sprintf("(%s + %si)", c.re, c.im)
} }
// MakeConst makes an ideal constant from a literal // Representation of nil.
// token and the corresponding literal string. type nilType struct{}
func MakeConst(tok token.Token, lit string) Const {
switch tok { func (nilType) String() string {
case token.INT: return "nil"
var x big.Int
_, ok := x.SetString(lit, 0)
assert(ok)
return Const{&x}
case token.FLOAT:
var y big.Rat
_, ok := y.SetString(lit)
assert(ok)
return Const{&y}
case token.IMAG:
assert(lit[len(lit)-1] == 'i')
var im big.Rat
_, ok := im.SetString(lit[0 : len(lit)-1])
assert(ok)
return Const{cmplx{big.NewRat(0, 1), &im}}
case token.CHAR:
assert(lit[0] == '\'' && lit[len(lit)-1] == '\'')
code, _, _, err := strconv.UnquoteChar(lit[1:len(lit)-1], '\'')
assert(err == nil)
return Const{big.NewInt(int64(code))}
case token.STRING:
s, err := strconv.Unquote(lit)
assert(err == nil)
return Const{s}
}
panic("unreachable")
} }
// MakeZero returns the zero constant for the given type. // Frequently used constants.
func MakeZero(typ *Type) Const { var (
// TODO(gri) fix this zeroConst = int64(0)
return Const{0} oneConst = int64(1)
} minusOneConst = int64(-1)
nilConst = nilType{}
)
// Match attempts to match the internal constant representations of x and y. // int64 bounds
// If the attempt is successful, the result is the values of x and y, var (
// if necessary converted to have the same internal representation; otherwise minInt64 = big.NewInt(-1 << 63)
// the results are invalid. maxInt64 = big.NewInt(1<<63 - 1)
func (x Const) Match(y Const) (u, v Const) { )
switch a := x.val.(type) {
case bool:
if _, ok := y.val.(bool); ok {
u, v = x, y
}
case *big.Int:
switch y.val.(type) {
case *big.Int:
u, v = x, y
case *big.Rat:
var z big.Rat
z.SetInt(a)
u, v = Const{&z}, y
case cmplx:
var z big.Rat
z.SetInt(a)
u, v = Const{cmplx{&z, big.NewRat(0, 1)}}, y
}
case *big.Rat:
switch y.val.(type) {
case *big.Int:
v, u = y.Match(x)
case *big.Rat:
u, v = x, y
case cmplx:
u, v = Const{cmplx{a, big.NewRat(0, 0)}}, y
}
case cmplx:
switch y.val.(type) {
case *big.Int, *big.Rat:
v, u = y.Match(x)
case cmplx:
u, v = x, y
}
case string:
if _, ok := y.val.(string); ok {
u, v = x, y
}
default:
panic("unreachable")
}
return
}
// Convert attempts to convert the constant x to a given type. // normalizeIntConst returns the smallest constant representation
// If the attempt is successful, the result is the new constant; // for the specific value of x; either an int64 or a *big.Int value.
// otherwise the result is invalid. //
func (x Const) Convert(typ *Type) Const { func normalizeIntConst(x *big.Int) interface{} {
// TODO(gri) implement this if minInt64.Cmp(x) <= 0 && x.Cmp(maxInt64) <= 0 {
switch x.val.(type) { return x.Int64()
case bool:
case *big.Int:
case *big.Rat:
case cmplx:
case string:
} }
return x return x
} }
func (x Const) String() string { // normalizeRatConst returns the smallest constant representation
switch x := x.val.(type) { // for the specific value of x; either an int64, *big.Int value,
case bool: // or *big.Rat value.
if x { //
return "true" func normalizeRatConst(x *big.Rat) interface{} {
if x.IsInt() {
return normalizeIntConst(x.Num())
} }
return "false"
case *big.Int:
return x.String()
case *big.Rat:
return x.FloatString(10) // 10 digits of precision after decimal point seems fine
case cmplx:
// TODO(gri) don't print 0 components
return x.re.FloatString(10) + " + " + x.im.FloatString(10) + "i"
case string:
return x return x
} }
panic("unreachable")
// normalizeComplexConst returns the smallest constant representation
// for the specific value of x; either an int64, *big.Int value, *big.Rat,
// or complex value.
//
func normalizeComplexConst(x complex) interface{} {
if x.im.Sign() == 0 {
return normalizeRatConst(x.re)
}
return x
} }
func (x Const) UnaryOp(op token.Token) Const { // makeRuneConst returns the int64 code point for the rune literal
panic("unimplemented") // lit. The result is nil if lit is not a correct rune literal.
//
// makeRuneConst returns the int64 code point for the rune literal lit
// (surrounding single quotes included). The result is nil if lit is
// not a correct rune literal.
func makeRuneConst(lit string) interface{} {
	n := len(lit)
	if n < 2 {
		return nil
	}
	// strip the surrounding quotes and decode the (possibly escaped) rune
	code, _, _, err := strconv.UnquoteChar(lit[1:n-1], '\'')
	if err != nil {
		return nil
	}
	return int64(code)
}
func (x Const) BinaryOp(op token.Token, y Const) Const { // makeRuneConst returns the smallest integer constant representation
var z interface{} // (int64, *big.Int) for the integer literal lit. The result is nil if
switch x := x.val.(type) { // lit is not a correct integer literal.
case bool: //
z = binaryBoolOp(x, op, y.val.(bool)) func makeIntConst(lit string) interface{} {
if x, err := strconv.ParseInt(lit, 0, 64); err == nil {
return x
}
if x, ok := new(big.Int).SetString(lit, 0); ok {
return x
}
return nil
}
// makeFloatConst returns the smallest floating-point constant representation
// (int64, *big.Int, *big.Rat) for the floating-point literal lit. The result
// is nil if lit is not a correct floating-point literal.
func makeFloatConst(lit string) interface{} {
	x, ok := new(big.Rat).SetString(lit)
	if !ok {
		return nil
	}
	return normalizeRatConst(x)
}
// makeComplexConst returns the complex constant representation (complex) for
// the imaginary literal lit (which must end in 'i'). The result is nil if
// lit is not a correct imaginary literal.
func makeComplexConst(lit string) interface{} {
	n := len(lit)
	if n == 0 || lit[n-1] != 'i' {
		return nil
	}
	im, ok := new(big.Rat).SetString(lit[:n-1])
	if !ok {
		return nil
	}
	// imaginary literals have a zero real part
	return normalizeComplexConst(complex{big.NewRat(0, 1), im})
}
// makeStringConst returns the string constant representation (string) for
// the string literal lit (quotes included). The result is nil if lit is
// not a correct string literal.
func makeStringConst(lit string) interface{} {
	s, err := strconv.Unquote(lit)
	if err != nil {
		return nil
	}
	return s
}
// isZeroConst reports whether the value of constant x is 0.
// x must be normalized, so a zero numeric constant is always represented
// as an int64; any other dynamic type cannot hold the value 0.
func isZeroConst(x interface{}) bool {
	if i, ok := x.(int64); ok {
		return i == 0
	}
	return false
}
// isNegConst reports whether the value of constant x is < 0.
// x must be a non-complex numeric value.
//
func isNegConst(x interface{}) bool {
switch x := x.(type) {
case int64:
return x < 0
case *big.Int: case *big.Int:
z = binaryIntOp(x, op, y.val.(*big.Int)) return x.Sign() < 0
case *big.Rat: case *big.Rat:
z = binaryFloatOp(x, op, y.val.(*big.Rat)) return x.Sign() < 0
case cmplx:
z = binaryCmplxOp(x, op, y.val.(cmplx))
case string:
z = binaryStringOp(x, op, y.val.(string))
default:
panic("unreachable")
} }
return Const{z} unreachable()
return false
} }
func binaryBoolOp(x bool, op token.Token, y bool) interface{} { // isRepresentableConst reports whether the value of constant x can
// be represented as a value of the basic type Typ[as] without loss
// of precision.
//
// isRepresentableConst reports whether the value of constant x can be
// represented as a value of the basic type Typ[as] without loss of
// precision. The outer switch is over the (normalized) constant
// representation; the inner switches enumerate the target basic kinds.
// Floating-point and complex targets currently always answer true
// (see the TODOs below).
func isRepresentableConst(x interface{}, as BasicKind) bool {
	const intBits = 32 // TODO(gri) implementation-specific constant
	const ptrBits = 64 // TODO(gri) implementation-specific constant
	switch x := x.(type) {
	case bool:
		return as == Bool || as == UntypedBool
	case int64:
		// signed ranges are [-2^(n-1), 2^(n-1)-1], unsigned [0, 2^n-1]
		switch as {
		case Int:
			return -1<<(intBits-1) <= x && x <= 1<<(intBits-1)-1
		case Int8:
			return -1<<(8-1) <= x && x <= 1<<(8-1)-1
		case Int16:
			return -1<<(16-1) <= x && x <= 1<<(16-1)-1
		case Int32, UntypedRune:
			return -1<<(32-1) <= x && x <= 1<<(32-1)-1
		case Int64:
			return true
		case Uint:
			return 0 <= x && x <= 1<<intBits-1
		case Uint8:
			return 0 <= x && x <= 1<<8-1
		case Uint16:
			return 0 <= x && x <= 1<<16-1
		case Uint32:
			return 0 <= x && x <= 1<<32-1
		case Uint64:
			return 0 <= x
		case Uintptr:
			assert(ptrBits == 64)
			return 0 <= x
		case Float32:
			return true // TODO(gri) fix this
		case Float64:
			return true // TODO(gri) fix this
		case Complex64:
			return true // TODO(gri) fix this
		case Complex128:
			return true // TODO(gri) fix this
		case UntypedInt, UntypedFloat, UntypedComplex:
			return true
		}
	case *big.Int:
		// a *big.Int is too large for any signed fixed-width target;
		// only unsigned targets with enough bits can hold it
		switch as {
		case Uint:
			return x.Sign() >= 0 && x.BitLen() <= intBits
		case Uint64:
			return x.Sign() >= 0 && x.BitLen() <= 64
		case Uintptr:
			return x.Sign() >= 0 && x.BitLen() <= ptrBits
		case Float32:
			return true // TODO(gri) fix this
		case Float64:
			return true // TODO(gri) fix this
		case Complex64:
			return true // TODO(gri) fix this
		case Complex128:
			return true // TODO(gri) fix this
		case UntypedInt, UntypedFloat, UntypedComplex:
			return true
		}
	case *big.Rat:
		switch as {
		case Float32:
			return true // TODO(gri) fix this
		case Float64:
			return true // TODO(gri) fix this
		case Complex64:
			return true // TODO(gri) fix this
		case Complex128:
			return true // TODO(gri) fix this
		case UntypedFloat, UntypedComplex:
			return true
		}
	case complex:
		switch as {
		case Complex64:
			return true // TODO(gri) fix this
		case Complex128:
			return true // TODO(gri) fix this
		case UntypedComplex:
			return true
		}
	case string:
		return as == String || as == UntypedString
	case nilType:
		return as == UntypedNil
	default:
		unreachable()
	}
	return false
}
// Frequently used values for converting between constant representations
// (see matchConst).
var (
	int1 = big.NewInt(1)
	rat0 = big.NewRat(0, 1)
)
// complexity returns a measure of representation complexity for constant x.
func complexity(x interface{}) int {
switch x.(type) {
case bool, string, nilType:
return 1
case int64:
return 2
case *big.Int:
return 3
case *big.Rat:
return 4
case complex:
return 5
}
unreachable()
return 0
}
// matchConst returns the matching representation (same type) with the
// smallest complexity for two constant values x and y. They must be
// of the same "kind" (boolean, numeric, string, or nilType).
//
// The "smaller" operand is widened to the representation of the larger
// one; a single recursive call normalizes the argument order so that
// complexity(x) <= complexity(y) below.
func matchConst(x, y interface{}) (_, _ interface{}) {
	if complexity(x) > complexity(y) {
		// swap arguments, then swap results back
		y, x = matchConst(y, x)
		return x, y
	}
	// complexity(x) <= complexity(y)
	switch x := x.(type) {
	case bool, complex, string, nilType:
		// x already has maximal complexity for its kind; y must match
		return x, y
	case int64:
		switch y := y.(type) {
		case int64:
			return x, y
		case *big.Int:
			return big.NewInt(x), y
		case *big.Rat:
			return big.NewRat(x, 1), y
		case complex:
			return complex{big.NewRat(x, 1), rat0}, y
		}
	case *big.Int:
		switch y := y.(type) {
		case *big.Int:
			return x, y
		case *big.Rat:
			return new(big.Rat).SetFrac(x, int1), y
		case complex:
			return complex{new(big.Rat).SetFrac(x, int1), rat0}, y
		}
	case *big.Rat:
		switch y := y.(type) {
		case *big.Rat:
			return x, y
		case complex:
			return complex{x, rat0}, y
		}
	}
	unreachable()
	return nil, nil
}
// is32bit reports whether x can be represented using 32 bits
// (i.e. lies within the int32 range).
func is32bit(x int64) bool {
	const lo, hi = -1 << 31, 1<<31 - 1
	return lo <= x && x <= hi
}
// is63bit reports whether x can be represented using 63 bits
// (i.e. lies within [-2^62, 2^62-1]).
func is63bit(x int64) bool {
	const lo, hi = -1 << 62, 1<<62 - 1
	return lo <= x && x <= hi
}
// binaryOpConst returns the result of the constant evaluation x op y;
// both operands must be of the same "kind" (boolean, numeric, or string).
// If intDiv is true, division (op == token.QUO) is using integer division
// (and the result is guaranteed to be integer) rather than floating-point
// division. Division by zero leads to a run-time panic.
//
func binaryOpConst(x, y interface{}, op token.Token, intDiv bool) interface{} {
x, y = matchConst(x, y)
switch x := x.(type) {
case bool:
y := y.(bool)
switch op { switch op {
case token.EQL: case token.LAND:
return x == y return x && y
case token.NEQ: case token.LOR:
return x != y return x || y
} default:
panic("unreachable") unreachable()
} }
func binaryIntOp(x *big.Int, op token.Token, y *big.Int) interface{} { case int64:
y := y.(int64)
switch op {
case token.ADD:
// TODO(gri) can do better than this
if is63bit(x) && is63bit(y) {
return x + y
}
return normalizeIntConst(new(big.Int).Add(big.NewInt(x), big.NewInt(y)))
case token.SUB:
// TODO(gri) can do better than this
if is63bit(x) && is63bit(y) {
return x - y
}
return normalizeIntConst(new(big.Int).Sub(big.NewInt(x), big.NewInt(y)))
case token.MUL:
// TODO(gri) can do better than this
if is32bit(x) && is32bit(y) {
return x * y
}
return normalizeIntConst(new(big.Int).Mul(big.NewInt(x), big.NewInt(y)))
case token.REM:
return x % y
case token.QUO:
if intDiv {
return x / y
}
return normalizeRatConst(new(big.Rat).SetFrac(big.NewInt(x), big.NewInt(y)))
case token.AND:
return x & y
case token.OR:
return x | y
case token.XOR:
return x ^ y
case token.AND_NOT:
return x &^ y
default:
unreachable()
}
case *big.Int:
y := y.(*big.Int)
var z big.Int var z big.Int
switch op { switch op {
case token.ADD: case token.ADD:
return z.Add(x, y) z.Add(x, y)
case token.SUB: case token.SUB:
return z.Sub(x, y) z.Sub(x, y)
case token.MUL: case token.MUL:
return z.Mul(x, y) z.Mul(x, y)
case token.QUO:
return z.Quo(x, y)
case token.REM: case token.REM:
return z.Rem(x, y) z.Rem(x, y)
case token.QUO:
if intDiv {
z.Quo(x, y)
} else {
return normalizeRatConst(new(big.Rat).SetFrac(x, y))
}
case token.AND: case token.AND:
return z.And(x, y) z.And(x, y)
case token.OR: case token.OR:
return z.Or(x, y) z.Or(x, y)
case token.XOR: case token.XOR:
return z.Xor(x, y) z.Xor(x, y)
case token.AND_NOT: case token.AND_NOT:
return z.AndNot(x, y) z.AndNot(x, y)
case token.SHL: default:
panic("unimplemented") unreachable()
case token.SHR:
panic("unimplemented")
case token.EQL:
return x.Cmp(y) == 0
case token.NEQ:
return x.Cmp(y) != 0
case token.LSS:
return x.Cmp(y) < 0
case token.LEQ:
return x.Cmp(y) <= 0
case token.GTR:
return x.Cmp(y) > 0
case token.GEQ:
return x.Cmp(y) >= 0
}
panic("unreachable")
} }
return normalizeIntConst(&z)
func binaryFloatOp(x *big.Rat, op token.Token, y *big.Rat) interface{} { case *big.Rat:
y := y.(*big.Rat)
var z big.Rat var z big.Rat
switch op { switch op {
case token.ADD: case token.ADD:
return z.Add(x, y) z.Add(x, y)
case token.SUB: case token.SUB:
return z.Sub(x, y) z.Sub(x, y)
case token.MUL: case token.MUL:
return z.Mul(x, y) z.Mul(x, y)
case token.QUO: case token.QUO:
return z.Quo(x, y) z.Quo(x, y)
case token.EQL: default:
return x.Cmp(y) == 0 unreachable()
case token.NEQ:
return x.Cmp(y) != 0
case token.LSS:
return x.Cmp(y) < 0
case token.LEQ:
return x.Cmp(y) <= 0
case token.GTR:
return x.Cmp(y) > 0
case token.GEQ:
return x.Cmp(y) >= 0
}
panic("unreachable")
} }
return normalizeRatConst(&z)
func binaryCmplxOp(x cmplx, op token.Token, y cmplx) interface{} { case complex:
y := y.(complex)
a, b := x.re, x.im a, b := x.re, x.im
c, d := y.re, y.im c, d := y.re, y.im
var re, im big.Rat
switch op { switch op {
case token.ADD: case token.ADD:
// (a+c) + i(b+d) // (a+c) + i(b+d)
var re, im big.Rat
re.Add(a, c) re.Add(a, c)
im.Add(b, d) im.Add(b, d)
return cmplx{&re, &im}
case token.SUB: case token.SUB:
// (a-c) + i(b-d) // (a-c) + i(b-d)
var re, im big.Rat
re.Sub(a, c) re.Sub(a, c)
im.Sub(b, d) im.Sub(b, d)
return cmplx{&re, &im}
case token.MUL: case token.MUL:
// (ac-bd) + i(bc+ad) // (ac-bd) + i(bc+ad)
var ac, bd, bc, ad big.Rat var ac, bd, bc, ad big.Rat
@ -285,10 +500,8 @@ func binaryCmplxOp(x cmplx, op token.Token, y cmplx) interface{} {
bd.Mul(b, d) bd.Mul(b, d)
bc.Mul(b, c) bc.Mul(b, c)
ad.Mul(a, d) ad.Mul(a, d)
var re, im big.Rat
re.Sub(&ac, &bd) re.Sub(&ac, &bd)
im.Add(&bc, &ad) im.Add(&bc, &ad)
return cmplx{&re, &im}
case token.QUO: case token.QUO:
// (ac+bd)/s + i(bc-ad)/s, with s = cc + dd // (ac+bd)/s + i(bc-ad)/s, with s = cc + dd
var ac, bd, bc, ad, s big.Rat var ac, bd, bc, ad, s big.Rat
@ -297,36 +510,153 @@ func binaryCmplxOp(x cmplx, op token.Token, y cmplx) interface{} {
bc.Mul(b, c) bc.Mul(b, c)
ad.Mul(a, d) ad.Mul(a, d)
s.Add(c.Mul(c, c), d.Mul(d, d)) s.Add(c.Mul(c, c), d.Mul(d, d))
var re, im big.Rat
re.Add(&ac, &bd) re.Add(&ac, &bd)
re.Quo(&re, &s) re.Quo(&re, &s)
im.Sub(&bc, &ad) im.Sub(&bc, &ad)
im.Quo(&im, &s) im.Quo(&im, &s)
return cmplx{&re, &im} default:
case token.EQL: unreachable()
return a.Cmp(c) == 0 && b.Cmp(d) == 0 }
case token.NEQ: return normalizeComplexConst(complex{&re, &im})
return a.Cmp(c) != 0 || b.Cmp(d) != 0
case string:
if op == token.ADD {
return x + y.(string)
} }
panic("unreachable")
} }
func binaryStringOp(x string, op token.Token, y string) interface{} { unreachable()
return nil
}
// shiftConst returns the result of the constant evaluation x op s
// where op is token.SHL or token.SHR (<< or >>). x must be an
// integer constant.
//
func shiftConst(x interface{}, s uint, op token.Token) interface{} {
switch x := x.(type) {
case int64:
switch op {
case token.SHL:
z := big.NewInt(x)
return normalizeIntConst(z.Lsh(z, s))
case token.SHR:
return x >> s
}
case *big.Int:
var z big.Int
switch op {
case token.SHL:
return normalizeIntConst(z.Lsh(x, s))
case token.SHR:
return normalizeIntConst(z.Rsh(x, s))
}
}
unreachable()
return nil
}
// compareConst returns the result of the constant comparison x op y;
// both operands must be of the same "kind" (boolean, numeric, string,
// or nilType).
//
func compareConst(x, y interface{}, op token.Token) (z bool) {
x, y = matchConst(x, y)
// x == y => x == y
// x != y => x != y
// x > y => y < x
// x >= y => u <= x
swap := false
switch op {
case token.GTR:
swap = true
op = token.LSS
case token.GEQ:
swap = true
op = token.LEQ
}
// x == y => x == y
// x != y => !(x == y)
// x < y => x < y
// x <= y => !(y < x)
negate := false
switch op {
case token.NEQ:
negate = true
op = token.EQL
case token.LEQ:
swap = !swap
negate = true
op = token.LSS
}
if negate {
defer func() { z = !z }()
}
if swap {
x, y = y, x
}
switch x := x.(type) {
case bool:
if op == token.EQL {
return x == y.(bool)
}
case int64:
y := y.(int64)
switch op { switch op {
case token.ADD:
return x + y
case token.EQL: case token.EQL:
return x == y return x == y
case token.NEQ:
return x != y
case token.LSS: case token.LSS:
return x < y return x < y
case token.LEQ:
return x <= y
case token.GTR:
return x > y
case token.GEQ:
return x >= y
} }
panic("unreachable")
case *big.Int:
s := x.Cmp(y.(*big.Int))
switch op {
case token.EQL:
return s == 0
case token.LSS:
return s < 0
}
case *big.Rat:
s := x.Cmp(y.(*big.Rat))
switch op {
case token.EQL:
return s == 0
case token.LSS:
return s < 0
}
case complex:
y := y.(complex)
if op == token.EQL {
return x.re.Cmp(y.re) == 0 && x.im.Cmp(y.im) == 0
}
case string:
y := y.(string)
switch op {
case token.EQL:
return x == y
case token.LSS:
return x < y
}
case nilType:
if op == token.EQL {
return x == y.(nilType)
}
}
fmt.Printf("x = %s (%T), y = %s (%T)\n", x, x, y, y)
unreachable()
return
} }

View File

@ -29,7 +29,9 @@ func (check *checker) conversion(x *operand, conv *ast.CallExpr, typ Type, iota
} }
// TODO(gri) fix this - implement all checks and constant evaluation // TODO(gri) fix this - implement all checks and constant evaluation
if x.mode != constant {
x.mode = value x.mode = value
}
x.expr = conv x.expr = conv
x.typ = typ x.typ = typ
return return

View File

@ -13,10 +13,6 @@ import (
"go/token" "go/token"
) )
// debugging flags
const debug = false
const trace = false
// TODO(gri) eventually assert and unimplemented should disappear. // TODO(gri) eventually assert and unimplemented should disappear.
func assert(p bool) { func assert(p bool) {
if !p { if !p {
@ -25,15 +21,40 @@ func assert(p bool) {
} }
func unimplemented() { func unimplemented() {
if debug { // enable for debugging
panic("unimplemented") // panic("unimplemented")
}
} }
func unreachable() { func unreachable() {
panic("unreachable") panic("unreachable")
} }
func (check *checker) printTrace(format string, args []interface{}) {
const dots = ". . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . "
n := len(check.pos) - 1
i := 2 * n
for i > len(dots) {
fmt.Print(dots)
i -= len(dots)
}
// i <= len(dots)
fmt.Printf("%s: ", check.fset.Position(check.pos[n]))
fmt.Print(dots[0:i])
fmt.Println(check.formatMsg(format, args))
}
func (check *checker) trace(pos token.Pos, format string, args ...interface{}) {
check.pos = append(check.pos, pos)
check.printTrace(format, args)
}
func (check *checker) untrace(format string, args ...interface{}) {
if len(format) > 0 {
check.printTrace(format, args)
}
check.pos = check.pos[:len(check.pos)-1]
}
func (check *checker) formatMsg(format string, args []interface{}) string { func (check *checker) formatMsg(format string, args []interface{}) string {
for i, arg := range args { for i, arg := range args {
switch a := arg.(type) { switch a := arg.(type) {

View File

@ -22,7 +22,8 @@ func readGopackHeader(r *bufio.Reader) (name string, size int, err error) {
if err != nil { if err != nil {
return return
} }
if trace { // leave for debugging
if false {
fmt.Printf("header: %s", hdr) fmt.Printf("header: %s", hdr)
} }
s := strings.TrimSpace(string(hdr[16+12+6+6+8:][:10])) s := strings.TrimSpace(string(hdr[16+12+6+6+8:][:10]))

View File

@ -15,6 +15,7 @@ import (
// TODO(gri) // TODO(gri)
// - don't print error messages referring to invalid types (they are likely spurious errors) // - don't print error messages referring to invalid types (they are likely spurious errors)
// - simplify invalid handling: maybe just use Typ[Invalid] as marker, get rid of invalid Mode for values? // - simplify invalid handling: maybe just use Typ[Invalid] as marker, get rid of invalid Mode for values?
// - rethink error handling: should all callers check if x.mode == valid after making a call?
func (check *checker) tag(field *ast.Field) string { func (check *checker) tag(field *ast.Field) string {
if t := field.Tag; t != nil { if t := field.Tag; t != nil {
@ -94,7 +95,7 @@ func (check *checker) collectStructFields(list *ast.FieldList, cycleOk bool) (fi
fields = append(fields, &StructField{t.Obj.Name, t, tag, true}) fields = append(fields, &StructField{t.Obj.Name, t, tag, true})
default: default:
if typ != Typ[Invalid] { if typ != Typ[Invalid] {
check.errorf(f.Type.Pos(), "invalid anonymous field type %s", typ) check.invalidAST(f.Type.Pos(), "anonymous field type %s must be named", typ)
} }
} }
} }
@ -109,7 +110,6 @@ var unaryOpPredicates = opPredicates{
token.SUB: isNumeric, token.SUB: isNumeric,
token.XOR: isInteger, token.XOR: isInteger,
token.NOT: isBoolean, token.NOT: isBoolean,
token.ARROW: func(typ Type) bool { t, ok := underlying(typ).(*Chan); return ok && t.Dir&ast.RECV != 0 },
} }
func (check *checker) op(m opPredicates, x *operand, op token.Token) bool { func (check *checker) op(m opPredicates, x *operand, op token.Token) bool {
@ -129,20 +129,33 @@ func (check *checker) op(m opPredicates, x *operand, op token.Token) bool {
} }
func (check *checker) unary(x *operand, op token.Token) { func (check *checker) unary(x *operand, op token.Token) {
if op == token.AND { switch op {
case token.AND:
// TODO(gri) need to check for composite literals, somehow (they are not variables, in general) // TODO(gri) need to check for composite literals, somehow (they are not variables, in general)
if x.mode != variable { if x.mode != variable {
check.invalidOp(x.pos(), "cannot take address of %s", x) check.invalidOp(x.pos(), "cannot take address of %s", x)
x.mode = invalid goto Error
return
} }
x.typ = &Pointer{Base: x.typ} x.typ = &Pointer{Base: x.typ}
return return
case token.ARROW:
typ, ok := underlying(x.typ).(*Chan)
if !ok {
check.invalidOp(x.pos(), "cannot receive from non-channel %s", x)
goto Error
}
if typ.Dir&ast.RECV == 0 {
check.invalidOp(x.pos(), "cannot receive from send-only channel %s", x)
goto Error
}
x.mode = valueok
x.typ = typ.Elt
return
} }
if !check.op(unaryOpPredicates, x, op) { if !check.op(unaryOpPredicates, x, op) {
x.mode = invalid goto Error
return
} }
if x.mode == constant { if x.mode == constant {
@ -156,7 +169,7 @@ func (check *checker) unary(x *operand, op token.Token) {
case token.NOT: case token.NOT:
x.val = !x.val.(bool) x.val = !x.val.(bool)
default: default:
unreachable() unreachable() // operators where checked by check.op
} }
// Typed constants must be representable in // Typed constants must be representable in
// their type after each constant operation. // their type after each constant operation.
@ -165,6 +178,11 @@ func (check *checker) unary(x *operand, op token.Token) {
} }
x.mode = value x.mode = value
// x.typ remains unchanged
return
Error:
x.mode = invalid
} }
func isShift(op token.Token) bool { func isShift(op token.Token) bool {
@ -216,8 +234,7 @@ func (check *checker) convertUntyped(x *operand, target Type) {
x.typ = target x.typ = target
} }
} else if xkind != tkind { } else if xkind != tkind {
check.errorf(x.pos(), "cannot convert %s to %s", x, target) goto Error
x.mode = invalid // avoid spurious errors
} }
return return
} }
@ -226,15 +243,22 @@ func (check *checker) convertUntyped(x *operand, target Type) {
switch t := underlying(target).(type) { switch t := underlying(target).(type) {
case *Basic: case *Basic:
check.isRepresentable(x, t) check.isRepresentable(x, t)
case *Interface:
case *Pointer, *Signature, *Interface, *Slice, *Map, *Chan: if !x.isNil() && len(t.Methods) > 0 /* empty interfaces are ok */ {
if x.typ != Typ[UntypedNil] { goto Error
check.errorf(x.pos(), "cannot convert %s to %s", x, target) }
x.mode = invalid case *Pointer, *Signature, *Slice, *Map, *Chan:
if !x.isNil() {
goto Error
} }
} }
x.typ = target x.typ = target
return
Error:
check.errorf(x.pos(), "cannot convert %s to %s", x, target)
x.mode = invalid
} }
func (check *checker) comparison(x, y *operand, op token.Token) { func (check *checker) comparison(x, y *operand, op token.Token) {
@ -244,9 +268,11 @@ func (check *checker) comparison(x, y *operand, op token.Token) {
if x.isAssignable(y.typ) || y.isAssignable(x.typ) { if x.isAssignable(y.typ) || y.isAssignable(x.typ) {
switch op { switch op {
case token.EQL, token.NEQ: case token.EQL, token.NEQ:
valid = isComparable(x.typ) valid = isComparable(x.typ) ||
x.isNil() && hasNil(y.typ) ||
y.isNil() && hasNil(x.typ)
case token.LSS, token.LEQ, token.GTR, token.GEQ: case token.LSS, token.LEQ, token.GTR, token.GEQ:
valid = isOrdered(y.typ) valid = isOrdered(x.typ)
default: default:
unreachable() unreachable()
} }
@ -389,7 +415,7 @@ func (check *checker) binary(x, y *operand, op token.Token, hint Type) {
x.val = binaryOpConst(x.val, y.val, op, isInteger(x.typ)) x.val = binaryOpConst(x.val, y.val, op, isInteger(x.typ))
// Typed constants must be representable in // Typed constants must be representable in
// their type after each constant operation. // their type after each constant operation.
check.isRepresentable(x, x.typ.(*Basic)) check.isRepresentable(x, underlying(x.typ).(*Basic))
return return
} }
@ -398,38 +424,31 @@ func (check *checker) binary(x, y *operand, op token.Token, hint Type) {
} }
// index checks an index expression for validity. If length >= 0, it is the upper // index checks an index expression for validity. If length >= 0, it is the upper
// bound for the index. The result is a valid constant index >= 0, or a negative // bound for the index. The result is a valid integer constant, or nil.
// value.
// //
func (check *checker) index(index ast.Expr, length int64, iota int) int64 { func (check *checker) index(index ast.Expr, length int64, iota int) interface{} {
var x operand var x operand
var i int64 // index value, valid if >= 0
check.expr(&x, index, nil, iota) check.expr(&x, index, nil, iota)
if !x.isInteger() { if !x.isInteger() {
check.errorf(x.pos(), "index %s must be integer", &x) check.errorf(x.pos(), "index %s must be integer", &x)
return -1 return nil
} }
if x.mode != constant { if x.mode != constant {
return -1 // we cannot check more return nil // we cannot check more
} }
// x.mode == constant and the index value must be >= 0 // x.mode == constant and the index value must be >= 0
if isNegConst(x.val) { if isNegConst(x.val) {
check.errorf(x.pos(), "index %s must not be negative", &x) check.errorf(x.pos(), "index %s must not be negative", &x)
return -1 return nil
} }
var ok bool // x.val >= 0
if i, ok = x.val.(int64); !ok { if length >= 0 && compareConst(x.val, length, token.GEQ) {
// index value doesn't fit into an int64
i = length // trigger out of bounds check below if we know length (>= 0)
}
if length >= 0 && i >= length {
check.errorf(x.pos(), "index %s is out of bounds (>= %d)", &x, length) check.errorf(x.pos(), "index %s is out of bounds (>= %d)", &x, length)
return -1 return nil
} }
return i return x.val
} }
func (check *checker) callRecord(x *operand) { func (check *checker) callRecord(x *operand) {
@ -438,20 +457,25 @@ func (check *checker) callRecord(x *operand) {
} }
} }
// expr typechecks expression e and initializes x with the expression // rawExpr typechecks expression e and initializes x with the expression
// value or type. If an error occured, x.mode is set to invalid. // value or type. If an error occured, x.mode is set to invalid.
// A hint != nil is used as operand type for untyped shifted operands; // A hint != nil is used as operand type for untyped shifted operands;
// iota >= 0 indicates that the expression is part of a constant declaration. // iota >= 0 indicates that the expression is part of a constant declaration.
// cycleOk indicates whether it is ok for a type expression to refer to itself. // cycleOk indicates whether it is ok for a type expression to refer to itself.
// //
func (check *checker) exprOrType(x *operand, e ast.Expr, hint Type, iota int, cycleOk bool) { func (check *checker) rawExpr(x *operand, e ast.Expr, hint Type, iota int, cycleOk bool) {
if trace {
check.trace(e.Pos(), "expr(%s, iota = %d, cycleOk = %v)", e, iota, cycleOk)
defer check.untrace("=> %s", x)
}
if check.mapf != nil { if check.mapf != nil {
defer check.callRecord(x) defer check.callRecord(x)
} }
switch e := e.(type) { switch e := e.(type) {
case *ast.BadExpr: case *ast.BadExpr:
x.mode = invalid goto Error // error was reported before
case *ast.Ident: case *ast.Ident:
if e.Name == "_" { if e.Name == "_" {
@ -460,13 +484,14 @@ func (check *checker) exprOrType(x *operand, e ast.Expr, hint Type, iota int, cy
} }
obj := e.Obj obj := e.Obj
if obj == nil { if obj == nil {
// unresolved identifier (error has been reported before) goto Error // error was reported before
goto Error }
if obj.Type == nil {
check.object(obj, cycleOk)
} }
check.ident(e, cycleOk)
switch obj.Kind { switch obj.Kind {
case ast.Bad: case ast.Bad:
goto Error goto Error // error was reported before
case ast.Pkg: case ast.Pkg:
check.errorf(e.Pos(), "use of package %s not in selector", obj.Name) check.errorf(e.Pos(), "use of package %s not in selector", obj.Name)
goto Error goto Error
@ -501,6 +526,9 @@ func (check *checker) exprOrType(x *operand, e ast.Expr, hint Type, iota int, cy
} }
x.typ = obj.Type.(Type) x.typ = obj.Type.(Type)
case *ast.Ellipsis:
unimplemented()
case *ast.BasicLit: case *ast.BasicLit:
x.setConst(e.Kind, e.Value) x.setConst(e.Kind, e.Value)
if x.mode == invalid { if x.mode == invalid {
@ -511,32 +539,41 @@ func (check *checker) exprOrType(x *operand, e ast.Expr, hint Type, iota int, cy
case *ast.FuncLit: case *ast.FuncLit:
x.mode = value x.mode = value
x.typ = check.typ(e.Type, false) x.typ = check.typ(e.Type, false)
check.stmt(e.Body) // TODO(gri) handle errors (e.g. x.typ is not a *Signature)
check.function(x.typ.(*Signature), e.Body)
case *ast.CompositeLit: case *ast.CompositeLit:
// TODO(gri) // TODO(gri)
// - determine element type if nil // - determine element type if nil
// - deal with map elements // - deal with map elements
var typ Type
if e.Type != nil {
// TODO(gri) Fix this - just to get going for now
typ = check.typ(e.Type, false)
}
for _, e := range e.Elts { for _, e := range e.Elts {
var x operand var x operand
check.expr(&x, e, hint, iota) check.expr(&x, e, hint, iota)
// TODO(gri) check assignment compatibility to element type // TODO(gri) check assignment compatibility to element type
} }
x.mode = value // TODO(gri) composite literals are addressable // TODO(gri) this is not correct - leave for now to get going
x.mode = variable
x.typ = typ
case *ast.ParenExpr: case *ast.ParenExpr:
check.exprOrType(x, e.X, hint, iota, cycleOk) check.rawExpr(x, e.X, hint, iota, cycleOk)
case *ast.SelectorExpr: case *ast.SelectorExpr:
sel := e.Sel.Name
// If the identifier refers to a package, handle everything here // If the identifier refers to a package, handle everything here
// so we don't need a "package" mode for operands: package names // so we don't need a "package" mode for operands: package names
// can only appear in qualified identifiers which are mapped to // can only appear in qualified identifiers which are mapped to
// selector expressions. // selector expressions.
if ident, ok := e.X.(*ast.Ident); ok { if ident, ok := e.X.(*ast.Ident); ok {
if obj := ident.Obj; obj != nil && obj.Kind == ast.Pkg { if obj := ident.Obj; obj != nil && obj.Kind == ast.Pkg {
exp := obj.Data.(*ast.Scope).Lookup(e.Sel.Name) exp := obj.Data.(*ast.Scope).Lookup(sel)
if exp == nil { if exp == nil {
check.errorf(e.Sel.Pos(), "cannot refer to unexported %s", e.Sel.Name) check.errorf(e.Sel.Pos(), "cannot refer to unexported %s", sel)
goto Error goto Error
} }
// simplified version of the code for *ast.Idents: // simplified version of the code for *ast.Idents:
@ -561,24 +598,39 @@ func (check *checker) exprOrType(x *operand, e ast.Expr, hint Type, iota int, cy
} }
} }
// TODO(gri) lots of checks missing below - just raw outline check.exprOrType(x, e.X, nil, iota, false)
check.expr(x, e.X, hint, iota) if x.mode == invalid {
switch typ := x.typ.(type) {
case *Struct:
if fld := lookupField(typ, e.Sel.Name); fld != nil {
// TODO(gri) only variable if struct is variable
x.mode = variable
x.expr = e
x.typ = fld.Type
return
}
case *Interface:
unimplemented()
case *NamedType:
unimplemented()
}
check.invalidOp(e.Pos(), "%s has no field or method %s", x.typ, e.Sel.Name)
goto Error goto Error
}
mode, typ := lookupField(x.typ, sel)
if mode == invalid {
check.invalidOp(e.Pos(), "%s has no field or method %s", x, sel)
goto Error
}
if x.mode == typexpr {
// method expression
sig, ok := typ.(*Signature)
if !ok {
check.invalidOp(e.Pos(), "%s has no method %s", x, sel)
goto Error
}
// the receiver type becomes the type of the first function
// argument of the method expression's function type
// TODO(gri) at the moment, method sets don't correctly track
// pointer vs non-pointer receivers -> typechecker is too lenient
arg := ast.NewObj(ast.Var, "")
arg.Type = x.typ
x.mode = value
x.typ = &Signature{
Params: append(ObjList{arg}, sig.Params...),
Results: sig.Results,
IsVariadic: sig.IsVariadic,
}
} else {
// regular selector
x.mode = mode
x.typ = typ
}
case *ast.IndexExpr: case *ast.IndexExpr:
check.expr(x, e.X, hint, iota) check.expr(x, e.X, hint, iota)
@ -614,7 +666,7 @@ func (check *checker) exprOrType(x *operand, e ast.Expr, hint Type, iota int, cy
case *Map: case *Map:
// TODO(gri) check index type // TODO(gri) check index type
x.mode = variable x.mode = valueok
x.typ = typ.Elt x.typ = typ.Elt
return return
} }
@ -672,24 +724,26 @@ func (check *checker) exprOrType(x *operand, e ast.Expr, hint Type, iota int, cy
goto Error goto Error
} }
var lo int64 var lo interface{} = zeroConst
if e.Low != nil { if e.Low != nil {
lo = check.index(e.Low, length, iota) lo = check.index(e.Low, length, iota)
} }
var hi int64 = length var hi interface{}
if e.High != nil { if e.High != nil {
hi = check.index(e.High, length, iota) hi = check.index(e.High, length, iota)
} else if length >= 0 {
hi = length
} }
if hi >= 0 && lo > hi { if lo != nil && hi != nil && compareConst(lo, hi, token.GTR) {
check.errorf(e.Low.Pos(), "inverted slice range: %d > %d", lo, hi) check.errorf(e.Low.Pos(), "inverted slice range: %v > %v", lo, hi)
// ok to continue // ok to continue
} }
case *ast.TypeAssertExpr: case *ast.TypeAssertExpr:
check.expr(x, e.X, hint, iota) check.expr(x, e.X, hint, iota)
if _, ok := x.typ.(*Interface); !ok { if _, ok := underlying(x.typ).(*Interface); !ok {
check.invalidOp(e.X.Pos(), "non-interface type %s in type assertion", x.typ) check.invalidOp(e.X.Pos(), "non-interface type %s in type assertion", x.typ)
// ok to continue // ok to continue
} }
@ -700,9 +754,10 @@ func (check *checker) exprOrType(x *operand, e ast.Expr, hint Type, iota int, cy
case *ast.CallExpr: case *ast.CallExpr:
check.exprOrType(x, e.Fun, nil, iota, false) check.exprOrType(x, e.Fun, nil, iota, false)
if x.mode == typexpr { if x.mode == invalid {
goto Error
} else if x.mode == typexpr {
check.conversion(x, e, x.typ, iota) check.conversion(x, e, x.typ, iota)
} else if sig, ok := underlying(x.typ).(*Signature); ok { } else if sig, ok := underlying(x.typ).(*Signature); ok {
// check parameters // check parameters
// TODO(gri) complete this // TODO(gri) complete this
@ -747,8 +802,7 @@ func (check *checker) exprOrType(x *operand, e ast.Expr, hint Type, iota int, cy
case *ast.StarExpr: case *ast.StarExpr:
check.exprOrType(x, e.X, hint, iota, true) check.exprOrType(x, e.X, hint, iota, true)
switch x.mode { switch x.mode {
case novalue: case invalid:
check.errorf(x.pos(), "%s used as value or type", x)
goto Error goto Error
case typexpr: case typexpr:
x.typ = &Pointer{Base: x.typ} x.typ = &Pointer{Base: x.typ}
@ -777,11 +831,18 @@ func (check *checker) exprOrType(x *operand, e ast.Expr, hint Type, iota int, cy
case *ast.ArrayType: case *ast.ArrayType:
if e.Len != nil { if e.Len != nil {
var n int64 = -1
if ellip, ok := e.Len.(*ast.Ellipsis); ok {
// TODO(gri) need to check somewhere that [...]T types are only used with composite literals
if ellip.Elt != nil {
check.invalidAST(ellip.Pos(), "ellipsis only expected")
// ok to continue
}
} else {
check.expr(x, e.Len, nil, 0) check.expr(x, e.Len, nil, 0)
if x.mode == invalid { if x.mode == invalid {
goto Error goto Error
} }
var n int64 = -1
if x.mode == constant { if x.mode == constant {
if i, ok := x.val.(int64); ok && i == int64(int(i)) { if i, ok := x.val.(int64); ok && i == int64(int(i)) {
n = i n = i
@ -792,6 +853,7 @@ func (check *checker) exprOrType(x *operand, e ast.Expr, hint Type, iota int, cy
// ok to continue // ok to continue
n = 0 n = 0
} }
}
x.typ = &Array{Len: n, Elt: check.typ(e.Elt, cycleOk)} x.typ = &Array{Len: n, Elt: check.typ(e.Elt, cycleOk)}
} else { } else {
x.typ = &Slice{Elt: check.typ(e.Elt, true)} x.typ = &Slice{Elt: check.typ(e.Elt, true)}
@ -836,9 +898,18 @@ Error:
x.expr = e x.expr = e
} }
// expr is like exprOrType but also checks that e represents a value (rather than a type). // exprOrType is like rawExpr but reports an error if e doesn't represents a value or type.
func (check *checker) exprOrType(x *operand, e ast.Expr, hint Type, iota int, cycleOk bool) {
check.rawExpr(x, e, hint, iota, cycleOk)
if x.mode == novalue {
check.errorf(x.pos(), "%s used as value or type", x)
x.mode = invalid
}
}
// expr is like rawExpr but reports an error if e doesn't represents a value.
func (check *checker) expr(x *operand, e ast.Expr, hint Type, iota int) { func (check *checker) expr(x *operand, e ast.Expr, hint Type, iota int) {
check.exprOrType(x, e, hint, iota, false) check.rawExpr(x, e, hint, iota, false)
switch x.mode { switch x.mode {
case novalue: case novalue:
check.errorf(x.pos(), "%s used as value", x) check.errorf(x.pos(), "%s used as value", x)
@ -849,19 +920,21 @@ func (check *checker) expr(x *operand, e ast.Expr, hint Type, iota int) {
} }
} }
// typ is like exprOrType but also checks that e represents a type (rather than a value). // expr is like rawExpr but reports an error if e doesn't represents a type.
// If an error occured, the result is Typ[Invalid]. // It returns e's type, or Typ[Invalid] if an error occured.
// //
func (check *checker) typ(e ast.Expr, cycleOk bool) Type { func (check *checker) typ(e ast.Expr, cycleOk bool) Type {
var x operand var x operand
check.exprOrType(&x, e, nil, -1, cycleOk) check.rawExpr(&x, e, nil, -1, cycleOk)
switch { switch x.mode {
case x.mode == novalue: case invalid:
// ignore - error reported before
case novalue:
check.errorf(x.pos(), "%s used as type", &x) check.errorf(x.pos(), "%s used as type", &x)
x.typ = Typ[Invalid] case typexpr:
case x.mode != typexpr:
check.errorf(x.pos(), "%s is not a type", &x)
x.typ = Typ[Invalid]
}
return x.typ return x.typ
default:
check.errorf(x.pos(), "%s is not a type", &x)
}
return Typ[Invalid]
} }

View File

@ -23,8 +23,6 @@ import (
"text/scanner" "text/scanner"
) )
const trace = false // set to true for debugging
var pkgExts = [...]string{".a", ".5", ".6", ".8"} var pkgExts = [...]string{".a", ".5", ".6", ".8"}
// FindPkg returns the filename and unique package id for an import // FindPkg returns the filename and unique package id for an import
@ -86,10 +84,6 @@ func FindPkg(path, srcDir string) (filename, id string) {
// in error messages. // in error messages.
// //
func GcImportData(imports map[string]*ast.Object, filename, id string, data *bufio.Reader) (pkg *ast.Object, err error) { func GcImportData(imports map[string]*ast.Object, filename, id string, data *bufio.Reader) (pkg *ast.Object, err error) {
if trace {
fmt.Printf("importing %s (%s)\n", id, filename)
}
// support for gcParser error handling // support for gcParser error handling
defer func() { defer func() {
if r := recover(); r != nil { if r := recover(); r != nil {
@ -182,12 +176,13 @@ func (p *gcParser) init(filename, id string, src io.Reader, imports map[string]*
func (p *gcParser) next() { func (p *gcParser) next() {
p.tok = p.scanner.Scan() p.tok = p.scanner.Scan()
switch p.tok { switch p.tok {
case scanner.Ident, scanner.Int, scanner.String, '·': case scanner.Ident, scanner.Int, scanner.Char, scanner.String, '·':
p.lit = p.scanner.TokenText() p.lit = p.scanner.TokenText()
default: default:
p.lit = "" p.lit = ""
} }
if trace { // leave for debugging
if false {
fmt.Printf("%s: %q -> %q\n", scanner.TokenString(p.tok), p.scanner.TokenText(), p.lit) fmt.Printf("%s: %q -> %q\n", scanner.TokenString(p.tok), p.scanner.TokenText(), p.lit)
} }
} }
@ -204,13 +199,13 @@ func (p *gcParser) declare(scope *ast.Scope, kind ast.ObjKind, name string) *ast
// otherwise create a new object and insert it into the package scope // otherwise create a new object and insert it into the package scope
obj := ast.NewObj(kind, name) obj := ast.NewObj(kind, name)
if scope.Insert(obj) != nil { if scope.Insert(obj) != nil {
p.errorf("already declared: %v %s", kind, obj.Name) unreachable() // Lookup should have found it
} }
// a new type object is a named type and may be referred // if the new type object is a named type it may be referred
// to before the underlying type is known - set it up // to before the underlying type is known - set it up
if kind == ast.Typ { if kind == ast.Typ {
obj.Type = &Name{Obj: obj} obj.Type = &NamedType{Obj: obj}
} }
return obj return obj
@ -244,7 +239,6 @@ func (p *gcParser) errorf(format string, args ...interface{}) {
func (p *gcParser) expect(tok rune) string { func (p *gcParser) expect(tok rune) string {
lit := p.lit lit := p.lit
if p.tok != tok { if p.tok != tok {
panic(1)
p.errorf("expected %s, got %s (%s)", scanner.TokenString(tok), scanner.TokenString(p.tok), lit) p.errorf("expected %s, got %s (%s)", scanner.TokenString(tok), scanner.TokenString(p.tok), lit)
} }
p.next() p.next()
@ -350,7 +344,7 @@ func (p *gcParser) parseArrayType() Type {
lit := p.expect(scanner.Int) lit := p.expect(scanner.Int)
p.expect(']') p.expect(']')
elt := p.parseType() elt := p.parseType()
n, err := strconv.ParseUint(lit, 10, 64) n, err := strconv.ParseInt(lit, 10, 64)
if err != nil { if err != nil {
p.error(err) p.error(err)
} }
@ -389,34 +383,33 @@ func (p *gcParser) parseName() (name string) {
// Field = Name Type [ string_lit ] . // Field = Name Type [ string_lit ] .
// //
func (p *gcParser) parseField() (fld *ast.Object, tag string) { func (p *gcParser) parseField() *StructField {
name := p.parseName() var f StructField
ftyp := p.parseType() f.Name = p.parseName()
if name == "" { f.Type = p.parseType()
// anonymous field - ftyp must be T or *T and T must be a type name if p.tok == scanner.String {
if _, ok := Deref(ftyp).(*Name); !ok { f.Tag = p.expect(scanner.String)
}
if f.Name == "" {
// anonymous field - typ must be T or *T and T must be a type name
if typ, ok := deref(f.Type).(*NamedType); ok && typ.Obj != nil {
f.Name = typ.Obj.Name
f.IsAnonymous = true
} else {
p.errorf("anonymous field expected") p.errorf("anonymous field expected")
} }
} }
if p.tok == scanner.String { return &f
tag = p.expect(scanner.String)
}
fld = ast.NewObj(ast.Var, name)
fld.Type = ftyp
return
} }
// StructType = "struct" "{" [ FieldList ] "}" . // StructType = "struct" "{" [ FieldList ] "}" .
// FieldList = Field { ";" Field } . // FieldList = Field { ";" Field } .
// //
func (p *gcParser) parseStructType() Type { func (p *gcParser) parseStructType() Type {
var fields []*ast.Object var fields []*StructField
var tags []string
parseField := func() { parseField := func() {
fld, tag := p.parseField() fields = append(fields, p.parseField())
fields = append(fields, fld)
tags = append(tags, tag)
} }
p.expectKeyword("struct") p.expectKeyword("struct")
@ -430,7 +423,7 @@ func (p *gcParser) parseStructType() Type {
} }
p.expect('}') p.expect('}')
return &Struct{Fields: fields, Tags: tags} return &Struct{Fields: fields}
} }
// Parameter = ( identifier | "?" ) [ "..." ] Type [ string_lit ] . // Parameter = ( identifier | "?" ) [ "..." ] Type [ string_lit ] .
@ -445,9 +438,9 @@ func (p *gcParser) parseParameter() (par *ast.Object, isVariadic bool) {
isVariadic = true isVariadic = true
} }
ptyp := p.parseType() ptyp := p.parseType()
// ignore argument tag // ignore argument tag (e.g. "noescape")
if p.tok == scanner.String { if p.tok == scanner.String {
p.expect(scanner.String) p.next()
} }
par = ast.NewObj(ast.Var, name) par = ast.NewObj(ast.Var, name)
par.Type = ptyp par.Type = ptyp
@ -485,7 +478,7 @@ func (p *gcParser) parseParameters() (list []*ast.Object, isVariadic bool) {
// Signature = Parameters [ Result ] . // Signature = Parameters [ Result ] .
// Result = Type | Parameters . // Result = Type | Parameters .
// //
func (p *gcParser) parseSignature() *Func { func (p *gcParser) parseSignature() *Signature {
params, isVariadic := p.parseParameters() params, isVariadic := p.parseParameters()
// optional result type // optional result type
@ -505,16 +498,16 @@ func (p *gcParser) parseSignature() *Func {
} }
} }
return &Func{Params: params, Results: results, IsVariadic: isVariadic} return &Signature{Params: params, Results: results, IsVariadic: isVariadic}
} }
// InterfaceType = "interface" "{" [ MethodList ] "}" . // InterfaceType = "interface" "{" [ MethodList ] "}" .
// MethodList = Method { ";" Method } . // MethodList = Method { ";" Method } .
// Method = Name Signature . // Method = Name Signature .
// //
// (The methods of embedded interfaces are always "inlined" // The methods of embedded interfaces are always "inlined"
// by the compiler and thus embedded interfaces are never // by the compiler and thus embedded interfaces are never
// visible in the export data.) // visible in the export data.
// //
func (p *gcParser) parseInterfaceType() Type { func (p *gcParser) parseInterfaceType() Type {
var methods ObjList var methods ObjList
@ -563,6 +556,7 @@ func (p *gcParser) parseChanType() Type {
// BasicType | TypeName | ArrayType | SliceType | StructType | // BasicType | TypeName | ArrayType | SliceType | StructType |
// PointerType | FuncType | InterfaceType | MapType | ChanType | // PointerType | FuncType | InterfaceType | MapType | ChanType |
// "(" Type ")" . // "(" Type ")" .
//
// BasicType = ident . // BasicType = ident .
// TypeName = ExportedName . // TypeName = ExportedName .
// SliceType = "[" "]" Type . // SliceType = "[" "]" Type .
@ -635,11 +629,11 @@ func (p *gcParser) parseImportDecl() {
// int_lit = [ "+" | "-" ] { "0" ... "9" } . // int_lit = [ "+" | "-" ] { "0" ... "9" } .
// //
func (p *gcParser) parseInt() (sign, val string) { func (p *gcParser) parseInt() (neg bool, val string) {
switch p.tok { switch p.tok {
case '-': case '-':
p.next() neg = true
sign = "-" fallthrough
case '+': case '+':
p.next() p.next()
} }
@ -649,37 +643,48 @@ func (p *gcParser) parseInt() (sign, val string) {
// number = int_lit [ "p" int_lit ] . // number = int_lit [ "p" int_lit ] .
// //
func (p *gcParser) parseNumber() Const { func (p *gcParser) parseNumber() (x operand) {
x.mode = constant
// mantissa // mantissa
sign, val := p.parseInt() neg, val := p.parseInt()
mant, ok := new(big.Int).SetString(sign+val, 0) mant, ok := new(big.Int).SetString(val, 0)
assert(ok) assert(ok)
if neg {
mant.Neg(mant)
}
if p.lit == "p" { if p.lit == "p" {
// exponent (base 2) // exponent (base 2)
p.next() p.next()
sign, val = p.parseInt() neg, val = p.parseInt()
exp64, err := strconv.ParseUint(val, 10, 0) exp64, err := strconv.ParseUint(val, 10, 0)
if err != nil { if err != nil {
p.error(err) p.error(err)
} }
exp := uint(exp64) exp := uint(exp64)
if sign == "-" { if neg {
denom := big.NewInt(1) denom := big.NewInt(1)
denom.Lsh(denom, exp) denom.Lsh(denom, exp)
return Const{new(big.Rat).SetFrac(mant, denom)} x.typ = Typ[UntypedFloat]
x.val = normalizeRatConst(new(big.Rat).SetFrac(mant, denom))
return
} }
if exp > 0 { if exp > 0 {
mant.Lsh(mant, exp) mant.Lsh(mant, exp)
} }
return Const{new(big.Rat).SetInt(mant)} x.typ = Typ[UntypedFloat]
x.val = normalizeIntConst(mant)
return
} }
return Const{mant} x.typ = Typ[UntypedInt]
x.val = normalizeIntConst(mant)
return
} }
// ConstDecl = "const" ExportedName [ Type ] "=" Literal . // ConstDecl = "const" ExportedName [ Type ] "=" Literal .
// Literal = bool_lit | int_lit | float_lit | complex_lit | string_lit . // Literal = bool_lit | int_lit | float_lit | complex_lit | rune_lit | string_lit .
// bool_lit = "true" | "false" . // bool_lit = "true" | "false" .
// complex_lit = "(" float_lit "+" float_lit "i" ")" . // complex_lit = "(" float_lit "+" float_lit "i" ")" .
// rune_lit = "(" int_lit "+" int_lit ")" . // rune_lit = "(" int_lit "+" int_lit ")" .
@ -689,8 +694,7 @@ func (p *gcParser) parseConstDecl() {
p.expectKeyword("const") p.expectKeyword("const")
pkg, name := p.parseExportedName() pkg, name := p.parseExportedName()
obj := p.declare(pkg.Data.(*ast.Scope), ast.Con, name) obj := p.declare(pkg.Data.(*ast.Scope), ast.Con, name)
var x Const var x operand
var typ Type
if p.tok != '=' { if p.tok != '=' {
obj.Type = p.parseType() obj.Type = p.parseType()
} }
@ -701,25 +705,23 @@ func (p *gcParser) parseConstDecl() {
if p.lit != "true" && p.lit != "false" { if p.lit != "true" && p.lit != "false" {
p.error("expected true or false") p.error("expected true or false")
} }
x = Const{p.lit == "true"} x.typ = Typ[UntypedBool]
typ = Bool.Underlying x.val = p.lit == "true"
p.next() p.next()
case '-', scanner.Int: case '-', scanner.Int:
// int_lit // int_lit
x = p.parseNumber() x = p.parseNumber()
typ = Int.Underlying
if _, ok := x.val.(*big.Rat); ok {
typ = Float64.Underlying
}
case '(': case '(':
// complex_lit or rune_lit // complex_lit or rune_lit
p.next() p.next()
if p.tok == scanner.Char { if p.tok == scanner.Char {
p.next() p.next()
p.expect('+') p.expect('+')
p.parseNumber() x = p.parseNumber()
x.typ = Typ[UntypedRune]
p.expect(')') p.expect(')')
// TODO: x = ...
break break
} }
re := p.parseNumber() re := p.parseNumber()
@ -727,23 +729,29 @@ func (p *gcParser) parseConstDecl() {
im := p.parseNumber() im := p.parseNumber()
p.expectKeyword("i") p.expectKeyword("i")
p.expect(')') p.expect(')')
x = Const{cmplx{re.val.(*big.Rat), im.val.(*big.Rat)}} x.typ = Typ[UntypedComplex]
typ = Complex128.Underlying // TODO(gri) fix this
_, _ = re, im
x.val = zeroConst
case scanner.Char: case scanner.Char:
// TODO: x = ... // rune_lit
x.setConst(token.CHAR, p.lit)
p.next() p.next()
case scanner.String: case scanner.String:
// string_lit // string_lit
x = MakeConst(token.STRING, p.lit) x.setConst(token.STRING, p.lit)
p.next() p.next()
typ = String.Underlying
default: default:
p.errorf("expected literal got %s", scanner.TokenString(p.tok)) p.errorf("expected literal got %s", scanner.TokenString(p.tok))
} }
if obj.Type == nil { if obj.Type == nil {
obj.Type = typ obj.Type = x.typ
} }
obj.Data = x assert(x.val != nil)
obj.Data = x.val
} }
// TypeDecl = "type" ExportedName Type . // TypeDecl = "type" ExportedName Type .
@ -760,8 +768,7 @@ func (p *gcParser) parseTypeDecl() {
// a given type declaration. // a given type declaration.
typ := p.parseType() typ := p.parseType()
if name := obj.Type.(*Name); name.Underlying == nil { if name := obj.Type.(*NamedType); name.Underlying == nil {
assert(Underlying(typ) == typ)
name.Underlying = typ name.Underlying = typ
} }
} }
@ -775,10 +782,14 @@ func (p *gcParser) parseVarDecl() {
obj.Type = p.parseType() obj.Type = p.parseType()
} }
// FuncBody = "{" ... "}" . // Func = Signature [ Body ] .
// Body = "{" ... "}" .
// //
func (p *gcParser) parseFuncBody() { func (p *gcParser) parseFunc(scope *ast.Scope, name string) {
p.expect('{') obj := p.declare(scope, ast.Fun, name)
obj.Type = p.parseSignature()
if p.tok == '{' {
p.next()
for i := 1; i > 0; p.next() { for i := 1; i > 0; p.next() {
switch p.tok { switch p.tok {
case '{': case '{':
@ -788,32 +799,44 @@ func (p *gcParser) parseFuncBody() {
} }
} }
} }
// FuncDecl = "func" ExportedName Signature [ FuncBody ] .
//
func (p *gcParser) parseFuncDecl() {
// "func" already consumed
pkg, name := p.parseExportedName()
obj := p.declare(pkg.Data.(*ast.Scope), ast.Fun, name)
obj.Type = p.parseSignature()
if p.tok == '{' {
p.parseFuncBody()
}
} }
// MethodDecl = "func" Receiver Name Signature . // MethodDecl = "func" Receiver Name Func .
// Receiver = "(" ( identifier | "?" ) [ "*" ] ExportedName ")" [ FuncBody ]. // Receiver = "(" ( identifier | "?" ) [ "*" ] ExportedName ")" .
// //
func (p *gcParser) parseMethodDecl() { func (p *gcParser) parseMethodDecl() {
// "func" already consumed // "func" already consumed
p.expect('(') p.expect('(')
p.parseParameter() // receiver recv, _ := p.parseParameter() // receiver
p.expect(')') p.expect(')')
p.parseName() // unexported method names in imports are qualified with their package.
p.parseSignature() // determine receiver base type object
if p.tok == '{' { typ := recv.Type.(Type)
p.parseFuncBody() if ptr, ok := typ.(*Pointer); ok {
typ = ptr.Base
} }
obj := typ.(*NamedType).Obj
// determine base type scope
var scope *ast.Scope
if obj.Data != nil {
scope = obj.Data.(*ast.Scope)
} else {
scope = ast.NewScope(nil)
obj.Data = scope
}
// declare method in base type scope
name := p.parseName() // unexported method names in imports are qualified with their package.
p.parseFunc(scope, name)
}
// FuncDecl = "func" ExportedName Func .
//
func (p *gcParser) parseFuncDecl() {
// "func" already consumed
pkg, name := p.parseExportedName()
p.parseFunc(pkg.Data.(*ast.Scope), name)
} }
// Decl = [ ImportDecl | ConstDecl | TypeDecl | VarDecl | FuncDecl | MethodDecl ] "\n" . // Decl = [ ImportDecl | ConstDecl | TypeDecl | VarDecl | FuncDecl | MethodDecl ] "\n" .

View File

@ -113,24 +113,24 @@ func TestGcImport(t *testing.T) {
t.Logf("tested %d imports", nimports) t.Logf("tested %d imports", nimports)
} }
/*
Does not work with gccgo.
var importedObjectTests = []struct { var importedObjectTests = []struct {
name string name string
kind ast.ObjKind kind ast.ObjKind
typ string typ string
}{ }{
{"unsafe.Pointer", ast.Typ, "Pointer"}, {"unsafe.Pointer", ast.Typ, "Pointer"},
{"math.Pi", ast.Con, "basicType"}, // TODO(gri) need to complete BasicType {"math.Pi", ast.Con, "untyped float"},
{"io.Reader", ast.Typ, "interface{Read(p []byte) (n int, err error)}"}, {"io.Reader", ast.Typ, "interface{Read(p []byte) (n int, err error)}"},
{"io.ReadWriter", ast.Typ, "interface{Read(p []byte) (n int, err error); Write(p []byte) (n int, err error)}"}, {"io.ReadWriter", ast.Typ, "interface{Read(p []byte) (n int, err error); Write(p []byte) (n int, err error)}"},
{"math.Sin", ast.Fun, "func(x float64) (_ float64)"}, {"math.Sin", ast.Fun, "func(x·2 float64) (_ float64)"},
// TODO(gri) add more tests // TODO(gri) add more tests
} }
func TestGcImportedTypes(t *testing.T) { func TestGcImportedTypes(t *testing.T) {
// This package does not yet know how to read gccgo export data.
if runtime.Compiler == "gccgo" {
return
}
for _, test := range importedObjectTests { for _, test := range importedObjectTests {
s := strings.Split(test.name, ".") s := strings.Split(test.name, ".")
if len(s) != 2 { if len(s) != 2 {
@ -149,11 +149,9 @@ func TestGcImportedTypes(t *testing.T) {
if obj.Kind != test.kind { if obj.Kind != test.kind {
t.Errorf("%s: got kind = %q; want %q", test.name, obj.Kind, test.kind) t.Errorf("%s: got kind = %q; want %q", test.name, obj.Kind, test.kind)
} }
typ := TypeString(Underlying(obj.Type.(Type))) typ := typeString(underlying(obj.Type.(Type)))
if typ != test.typ { if typ != test.typ {
t.Errorf("%s: got type = %q; want %q", test.name, typ, test.typ) t.Errorf("%s: got type = %q; want %q", test.name, typ, test.typ)
} }
} }
} }
*/

View File

@ -69,7 +69,11 @@ func (x *operand) String() string {
} }
buf.WriteString(operandModeString[x.mode]) buf.WriteString(operandModeString[x.mode])
if x.mode == constant { if x.mode == constant {
fmt.Fprintf(&buf, " %v", x.val) format := " %v"
if isString(x.typ) {
format = " %q"
}
fmt.Fprintf(&buf, format, x.val)
} }
if x.mode != novalue && (x.mode != constant || !isUntyped(x.typ)) { if x.mode != novalue && (x.mode != constant || !isUntyped(x.typ)) {
fmt.Fprintf(&buf, " of type %s", typeString(x.typ)) fmt.Fprintf(&buf, " of type %s", typeString(x.typ))
@ -125,6 +129,11 @@ func (x *operand) implements(T *Interface) bool {
return true return true
} }
// isNil reports whether x is the predeclared nil constant.
func (x *operand) isNil() bool {
return x.mode == constant && x.val == nilConst
}
// isAssignable reports whether x is assignable to a variable of type T. // isAssignable reports whether x is assignable to a variable of type T.
func (x *operand) isAssignable(T Type) bool { func (x *operand) isAssignable(T Type) bool {
if x.mode == invalid || T == Typ[Invalid] { if x.mode == invalid || T == Typ[Invalid] {
@ -163,7 +172,7 @@ func (x *operand) isAssignable(T Type) bool {
// x is the predeclared identifier nil and T is a pointer, // x is the predeclared identifier nil and T is a pointer,
// function, slice, map, channel, or interface type // function, slice, map, channel, or interface type
if x.typ == Typ[UntypedNil] { if x.isNil() {
switch Tu.(type) { switch Tu.(type) {
case *Pointer, *Signature, *Slice, *Map, *Chan, *Interface: case *Pointer, *Signature, *Slice, *Map, *Chan, *Interface:
return true return true
@ -185,17 +194,135 @@ func (x *operand) isInteger() bool {
x.mode == constant && isRepresentableConst(x.val, UntypedInt) x.mode == constant && isRepresentableConst(x.val, UntypedInt)
} }
// lookupField returns the struct field with the given name in typ. type lookupResult struct {
// If no such field exists, the result is nil. mode operandMode
// TODO(gri) should this be a method of Struct? typ Type
// }
func lookupField(typ *Struct, name string) *StructField {
// TODO(gri) deal with embedding and conflicts - this is // lookupFieldRecursive is similar to FieldByNameFunc in reflect/type.go
// a very basic version to get going for now. // TODO(gri): FieldByNameFunc seems more complex - what are we missing?
func lookupFieldRecursive(list []*NamedType, name string) (res lookupResult) {
// visited records the types that have been searched already
visited := make(map[Type]bool)
// embedded types of the next lower level
var next []*NamedType
potentialMatch := func(mode operandMode, typ Type) bool {
if res.mode != invalid {
// name appeared multiple times at this level - annihilate
res.mode = invalid
return false
}
res.mode = mode
res.typ = typ
return true
}
// look for name in all types of this level
for len(list) > 0 {
assert(res.mode == invalid)
for _, typ := range list {
if visited[typ] {
// We have seen this type before, at a higher level.
// That higher level shadows the lower level we are
// at now, and either we would have found or not
// found the field before. Ignore this type now.
continue
}
visited[typ] = true
// look for a matching attached method
if data := typ.Obj.Data; data != nil {
if obj := data.(*ast.Scope).Lookup(name); obj != nil {
assert(obj.Type != nil)
if !potentialMatch(value, obj.Type.(Type)) {
return // name collision
}
}
}
switch typ := underlying(typ).(type) {
case *Struct:
// look for a matching fieldm and collect embedded types
for _, f := range typ.Fields { for _, f := range typ.Fields {
if f.Name == name { if f.Name == name {
return f assert(f.Type != nil)
if !potentialMatch(variable, f.Type) {
return // name collision
}
continue
}
// Collect embedded struct fields for searching the next
// lower level, but only if we have not seen a match yet.
// Embedded fields are always of the form T or *T where
// T is a named type.
if f.IsAnonymous && res.mode == invalid {
next = append(next, deref(f.Type).(*NamedType))
} }
} }
return nil
case *Interface:
// look for a matching method
for _, obj := range typ.Methods {
if obj.Name == name {
assert(obj.Type != nil)
if !potentialMatch(value, obj.Type.(Type)) {
return // name collision
}
}
}
}
}
if res.mode != invalid {
// we found a match on this level
return
}
// search the next level
list = append(list[:0], next...) // don't waste underlying arrays
next = next[:0]
}
return
}
func lookupField(typ Type, name string) (operandMode, Type) {
typ = deref(typ)
if typ, ok := typ.(*NamedType); ok {
if data := typ.Obj.Data; data != nil {
if obj := data.(*ast.Scope).Lookup(name); obj != nil {
assert(obj.Type != nil)
return value, obj.Type.(Type)
}
}
}
switch typ := underlying(typ).(type) {
case *Struct:
var list []*NamedType
for _, f := range typ.Fields {
if f.Name == name {
return variable, f.Type
}
if f.IsAnonymous {
list = append(list, deref(f.Type).(*NamedType))
}
}
if len(list) > 0 {
res := lookupFieldRecursive(list, name)
return res.mode, res.typ
}
case *Interface:
for _, obj := range typ.Methods {
if obj.Name == name {
return value, obj.Type.(Type)
}
}
}
// not found
return invalid, nil
} }

View File

@ -59,11 +59,16 @@ func isOrdered(typ Type) bool {
return ok && t.Info&IsOrdered != 0 return ok && t.Info&IsOrdered != 0
} }
func isConstType(typ Type) bool {
t, ok := underlying(typ).(*Basic)
return ok && t.Info&IsConstType != 0
}
func isComparable(typ Type) bool { func isComparable(typ Type) bool {
switch t := underlying(typ).(type) { switch t := underlying(typ).(type) {
case *Basic: case *Basic:
return t.Kind != Invalid return t.Kind != Invalid && t.Kind != UntypedNil
case *Pointer, *Chan, *Interface: case *Pointer, *Interface, *Chan:
// assumes types are equal for pointers and channels // assumes types are equal for pointers and channels
return true return true
case *Struct: case *Struct:
@ -79,6 +84,14 @@ func isComparable(typ Type) bool {
return false return false
} }
func hasNil(typ Type) bool {
switch underlying(typ).(type) {
case *Slice, *Pointer, *Signature, *Interface, *Map, *Chan:
return true
}
return false
}
// identical returns true if x and y are identical. // identical returns true if x and y are identical.
func isIdentical(x, y Type) bool { func isIdentical(x, y Type) bool {
if x == y { if x == y {

View File

@ -1,352 +0,0 @@
// Copyright 2011 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
// This file implements the Check function, which typechecks a package.
package types
import (
"fmt"
"go/ast"
"go/token"
"sort"
)
type checker struct {
fset *token.FileSet
pkg *ast.Package
errh func(token.Pos, string)
mapf func(ast.Expr, Type)
// lazily initialized
firsterr error
filenames []string // sorted list of package file names for reproducible iteration order
initexprs map[*ast.ValueSpec][]ast.Expr // "inherited" initialization expressions for constant declarations
}
// declare declares an object of the given kind and name (ident) in scope;
// decl is the corresponding declaration in the AST. An error is reported
// if the object was declared before.
//
// TODO(gri) This is very similar to the declare function in go/parser; it
// is only used to associate methods with their respective receiver base types.
// In a future version, it might be simpler and cleaner do to all the resolution
// in the type-checking phase. It would simplify the parser, AST, and also
// reduce some amount of code duplication.
//
func (check *checker) declare(scope *ast.Scope, kind ast.ObjKind, ident *ast.Ident, decl ast.Decl) {
assert(ident.Obj == nil) // identifier already declared or resolved
obj := ast.NewObj(kind, ident.Name)
obj.Decl = decl
ident.Obj = obj
if ident.Name != "_" {
if alt := scope.Insert(obj); alt != nil {
prevDecl := ""
if pos := alt.Pos(); pos.IsValid() {
prevDecl = fmt.Sprintf("\n\tprevious declaration at %s", check.fset.Position(pos))
}
check.errorf(ident.Pos(), fmt.Sprintf("%s redeclared in this block%s", ident.Name, prevDecl))
}
}
}
func (check *checker) valueSpec(pos token.Pos, obj *ast.Object, lhs []*ast.Ident, typ ast.Expr, rhs []ast.Expr, iota int) {
if len(lhs) == 0 {
check.invalidAST(pos, "missing lhs in declaration")
return
}
var t Type
if typ != nil {
t = check.typ(typ, false)
}
// len(lhs) >= 1
if len(lhs) == len(rhs) {
// check only corresponding lhs and rhs
var l, r ast.Expr
for i, ident := range lhs {
if ident.Obj == obj {
l = lhs[i]
r = rhs[i]
break
}
}
assert(l != nil)
obj.Type = t
// check rhs
var x operand
check.expr(&x, r, t, iota)
// assign to lhs
check.assignment(l, &x, true)
return
}
if t != nil {
for _, name := range lhs {
name.Obj.Type = t
}
}
// check initial values, if any
if len(rhs) > 0 {
// TODO(gri) should try to avoid this conversion
lhx := make([]ast.Expr, len(lhs))
for i, e := range lhs {
lhx[i] = e
}
check.assignNtoM(lhx, rhs, true, iota)
}
}
// ident type checks an identifier.
func (check *checker) ident(name *ast.Ident, cycleOk bool) {
obj := name.Obj
if obj == nil {
check.invalidAST(name.Pos(), "missing object for %s", name.Name)
return
}
if obj.Type != nil {
// object has already been type checked
return
}
switch obj.Kind {
case ast.Bad, ast.Pkg:
// nothing to do
case ast.Con, ast.Var:
// The obj.Data field for constants and variables is initialized
// to the respective (hypothetical, for variables) iota value by
// the parser. The object's fields can be in one of the following
// states:
// Type != nil => the constant value is Data
// Type == nil => the object is not typechecked yet, and Data can be:
// Data is int => Data is the value of iota for this declaration
// Data == nil => the object's expression is being evaluated
if obj.Data == nil {
check.errorf(obj.Pos(), "illegal cycle in initialization of %s", obj.Name)
return
}
spec := obj.Decl.(*ast.ValueSpec)
iota := obj.Data.(int)
obj.Data = nil
// determine initialization expressions
values := spec.Values
if len(values) == 0 && obj.Kind == ast.Con {
values = check.initexprs[spec]
}
check.valueSpec(spec.Pos(), obj, spec.Names, spec.Type, values, iota)
case ast.Typ:
typ := &NamedType{Obj: obj}
obj.Type = typ // "mark" object so recursion terminates
typ.Underlying = underlying(check.typ(obj.Decl.(*ast.TypeSpec).Type, cycleOk))
// collect associated methods, if any
if obj.Data != nil {
scope := obj.Data.(*ast.Scope)
// struct fields must not conflict with methods
if t, ok := typ.Underlying.(*Struct); ok {
for _, f := range t.Fields {
if m := scope.Lookup(f.Name); m != nil {
check.errorf(m.Pos(), "type %s has both field and method named %s", obj.Name, f.Name)
}
}
}
// collect methods
methods := make(ObjList, len(scope.Objects))
i := 0
for _, m := range scope.Objects {
methods[i] = m
i++
}
methods.Sort()
typ.Methods = methods
// methods cannot be associated with an interface type
// (do this check after sorting for reproducible error positions - needed for testing)
if _, ok := typ.Underlying.(*Interface); ok {
for _, m := range methods {
recv := m.Decl.(*ast.FuncDecl).Recv.List[0].Type
check.errorf(recv.Pos(), "invalid receiver type %s (%s is an interface type)", obj.Name, obj.Name)
}
}
}
case ast.Fun:
fdecl := obj.Decl.(*ast.FuncDecl)
ftyp := check.typ(fdecl.Type, cycleOk).(*Signature)
obj.Type = ftyp
if fdecl.Recv != nil {
// TODO(gri) handle method receiver
}
check.stmt(fdecl.Body)
default:
panic("unreachable")
}
}
// assocInitvals associates "inherited" initialization expressions
// with the corresponding *ast.ValueSpec in the check.initexprs map
// for constant declarations without explicit initialization expressions.
//
func (check *checker) assocInitvals(decl *ast.GenDecl) {
var values []ast.Expr
for _, s := range decl.Specs {
if s, ok := s.(*ast.ValueSpec); ok {
if len(s.Values) > 0 {
values = s.Values
} else {
check.initexprs[s] = values
}
}
}
if len(values) == 0 {
check.invalidAST(decl.Pos(), "no initialization values provided")
}
}
// assocMethod associates a method declaration with the respective
// receiver base type. meth.Recv must exist.
//
func (check *checker) assocMethod(meth *ast.FuncDecl) {
// The receiver type is one of the following (enforced by parser):
// - *ast.Ident
// - *ast.StarExpr{*ast.Ident}
// - *ast.BadExpr (parser error)
typ := meth.Recv.List[0].Type
if ptr, ok := typ.(*ast.StarExpr); ok {
typ = ptr.X
}
// determine receiver base type object (or nil if error)
var obj *ast.Object
if ident, ok := typ.(*ast.Ident); ok && ident.Obj != nil {
obj = ident.Obj
if obj.Kind != ast.Typ {
check.errorf(ident.Pos(), "%s is not a type", ident.Name)
obj = nil
}
// TODO(gri) determine if obj was defined in this package
/*
if check.notLocal(obj) {
check.errorf(ident.Pos(), "cannot define methods on non-local type %s", ident.Name)
obj = nil
}
*/
} else {
// If it's not an identifier or the identifier wasn't declared/resolved,
// the parser/resolver already reported an error. Nothing to do here.
}
// determine base type scope (or nil if error)
var scope *ast.Scope
if obj != nil {
if obj.Data != nil {
scope = obj.Data.(*ast.Scope)
} else {
scope = ast.NewScope(nil)
obj.Data = scope
}
} else {
// use a dummy scope so that meth can be declared in
// presence of an error and get an associated object
// (always use a new scope so that we don't get double
// declaration errors)
scope = ast.NewScope(nil)
}
check.declare(scope, ast.Fun, meth.Name, meth)
}
func (check *checker) assocInitvalsOrMethod(decl ast.Decl) {
switch d := decl.(type) {
case *ast.GenDecl:
if d.Tok == token.CONST {
check.assocInitvals(d)
}
case *ast.FuncDecl:
if d.Recv != nil {
check.assocMethod(d)
}
}
}
func (check *checker) decl(decl ast.Decl) {
switch d := decl.(type) {
case *ast.BadDecl:
// ignore
case *ast.GenDecl:
for _, spec := range d.Specs {
switch s := spec.(type) {
case *ast.ImportSpec:
// nothing to do (handled by ast.NewPackage)
case *ast.ValueSpec:
for _, name := range s.Names {
if name.Name == "_" {
// TODO(gri) why is _ special here?
} else {
check.ident(name, false)
}
}
case *ast.TypeSpec:
check.ident(s.Name, false)
default:
check.invalidAST(s.Pos(), "unknown ast.Spec node %T", s)
}
}
case *ast.FuncDecl:
check.ident(d.Name, false)
default:
check.invalidAST(d.Pos(), "unknown ast.Decl node %T", d)
}
}
// iterate calls f for each package-level declaration.
func (check *checker) iterate(f func(*checker, ast.Decl)) {
list := check.filenames
if list == nil {
// initialize lazily
for filename := range check.pkg.Files {
list = append(list, filename)
}
sort.Strings(list)
check.filenames = list
}
for _, filename := range list {
for _, decl := range check.pkg.Files[filename].Decls {
f(check, decl)
}
}
}
// A bailout panic is raised to indicate early termination.
type bailout struct{}
func check(fset *token.FileSet, pkg *ast.Package, errh func(token.Pos, string), f func(ast.Expr, Type)) (err error) {
// initialize checker
var check checker
check.fset = fset
check.pkg = pkg
check.errh = errh
check.mapf = f
check.initexprs = make(map[*ast.ValueSpec][]ast.Expr)
// handle bailouts
defer func() {
if p := recover(); p != nil {
_ = p.(bailout) // re-panic if not a bailout
}
err = check.firsterr
}()
// determine missing constant initialization expressions
// and associate methods with types
check.iterate((*checker).assocInitvalsOrMethod)
// typecheck all declarations
check.iterate((*checker).decl)
return
}

View File

@ -1,257 +0,0 @@
// Copyright 2011 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
// This file implements a typechecker test harness. The packages specified
// in tests are typechecked. Error messages reported by the typechecker are
// compared against the error messages expected in the test files.
//
// Expected errors are indicated in the test files by putting a comment
// of the form /* ERROR "rx" */ immediately following an offending token.
// The harness will verify that an error matching the regular expression
// rx is reported at that source position. Consecutive comments may be
// used to indicate multiple errors for the same token position.
//
// For instance, the following test file indicates that a "not declared"
// error should be reported for the undeclared variable x:
//
// package p
// func f() {
// _ = x /* ERROR "not declared" */ + 1
// }
package types
import (
"flag"
"fmt"
"go/ast"
"go/parser"
"go/scanner"
"go/token"
"io/ioutil"
"os"
"regexp"
"testing"
)
var listErrors = flag.Bool("list", false, "list errors")
func init() {
// declare builtins for testing
def(ast.Fun, "assert").Type = &builtin{aType, _Assert, "assert", 1, false, true}
def(ast.Fun, "trace").Type = &builtin{aType, _Trace, "trace", 0, true, true}
}
// The test filenames do not end in .go so that they are invisible
// to gofmt since they contain comments that must not change their
// positions relative to surrounding tokens.
var tests = []struct {
name string
files []string
}{
{"decls0", []string{"testdata/decls0.src"}},
{"decls1", []string{"testdata/decls1.src"}},
{"decls2", []string{"testdata/decls2a.src", "testdata/decls2b.src"}},
{"const0", []string{"testdata/const0.src"}},
{"expr0", []string{"testdata/expr0.src"}},
{"expr1", []string{"testdata/expr1.src"}},
{"expr2", []string{"testdata/expr2.src"}},
{"expr3", []string{"testdata/expr3.src"}},
{"builtins", []string{"testdata/builtins.src"}},
{"conversions", []string{"testdata/conversions.src"}},
{"stmt0", []string{"testdata/stmt0.src"}},
}
var fset = token.NewFileSet()
func getFile(filename string) (file *token.File) {
fset.Iterate(func(f *token.File) bool {
if f.Name() == filename {
file = f
return false // end iteration
}
return true
})
return file
}
func getPos(filename string, offset int) token.Pos {
if f := getFile(filename); f != nil {
return f.Pos(offset)
}
return token.NoPos
}
func parseFiles(t *testing.T, testname string, filenames []string) (map[string]*ast.File, error) {
files := make(map[string]*ast.File)
var errors scanner.ErrorList
for _, filename := range filenames {
if _, exists := files[filename]; exists {
t.Fatalf("%s: duplicate file %s", testname, filename)
}
file, err := parser.ParseFile(fset, filename, nil, parser.DeclarationErrors)
if file == nil {
t.Fatalf("%s: could not parse file %s", testname, filename)
}
files[filename] = file
if err != nil {
// if the parser returns a non-scanner.ErrorList error
// the file couldn't be read in the first place and
// file == nil; in that case we shouldn't reach here
errors = append(errors, err.(scanner.ErrorList)...)
}
}
return files, errors
}
// ERROR comments must be of the form /* ERROR "rx" */ and rx is
// a regular expression that matches the expected error message.
//
var errRx = regexp.MustCompile(`^/\* *ERROR *"([^"]*)" *\*/$`)
// expectedErrors collects the regular expressions of ERROR comments found
// in files and returns them as a map of error positions to error messages.
//
func expectedErrors(t *testing.T, testname string, files map[string]*ast.File) map[token.Pos][]string {
errors := make(map[token.Pos][]string)
for filename := range files {
src, err := ioutil.ReadFile(filename)
if err != nil {
t.Fatalf("%s: could not read %s", testname, filename)
}
var s scanner.Scanner
// file was parsed already - do not add it again to the file
// set otherwise the position information returned here will
// not match the position information collected by the parser
s.Init(getFile(filename), src, nil, scanner.ScanComments)
var prev token.Pos // position of last non-comment, non-semicolon token
scanFile:
for {
pos, tok, lit := s.Scan()
switch tok {
case token.EOF:
break scanFile
case token.COMMENT:
s := errRx.FindStringSubmatch(lit)
if len(s) == 2 {
list := errors[prev]
errors[prev] = append(list, string(s[1]))
}
case token.SEMICOLON:
// ignore automatically inserted semicolon
if lit == "\n" {
break
}
fallthrough
default:
prev = pos
}
}
}
return errors
}
func eliminate(t *testing.T, expected map[token.Pos][]string, errors error) {
if *listErrors || errors == nil {
return
}
for _, error := range errors.(scanner.ErrorList) {
// error.Pos is a token.Position, but we want
// a token.Pos so we can do a map lookup
pos := getPos(error.Pos.Filename, error.Pos.Offset)
list := expected[pos]
index := -1 // list index of matching message, if any
// we expect one of the messages in list to match the error at pos
for i, msg := range list {
rx, err := regexp.Compile(msg)
if err != nil {
t.Errorf("%s: %v", error.Pos, err)
continue
}
if match := rx.MatchString(error.Msg); match {
index = i
break
}
}
if index >= 0 {
// eliminate from list
n := len(list) - 1
if n > 0 {
// not the last entry - swap in last element and shorten list by 1
list[index] = list[n]
expected[pos] = list[:n]
} else {
// last entry - remove list from map
delete(expected, pos)
}
} else {
t.Errorf("%s: no error expected: %q", error.Pos, error.Msg)
continue
}
}
}
// checkFiles parses, resolves, and type-checks the given test files
// and verifies that exactly the expected errors (extracted from the
// sources' error comments) are reported — no more, no fewer.
func checkFiles(t *testing.T, testname string, testfiles []string) {
	// TODO(gri) Eventually all these different phases should be
	// subsumed into a single function call that takes
	// a set of files and creates a fully resolved and
	// type-checked AST.
	files, err := parseFiles(t, testname, testfiles)
	// we are expecting the following errors
	// (collect these after parsing the files so that
	// they are found in the file set)
	errors := expectedErrors(t, testname, files)
	// verify errors returned by the parser
	eliminate(t, errors, err)
	// verify errors returned after resolving identifiers
	pkg, err := ast.NewPackage(fset, files, GcImport, Universe)
	eliminate(t, errors, err)
	// verify errors returned by the typechecker
	var list scanner.ErrorList
	errh := func(pos token.Pos, msg string) {
		list.Add(fset.Position(pos), msg)
	}
	err = Check(fset, pkg, errh, nil)
	eliminate(t, errors, list)
	if *listErrors {
		// listing mode: print everything, skip verification
		scanner.PrintError(os.Stdout, err)
		return
	}
	// there should be no expected errors left
	if len(errors) > 0 {
		t.Errorf("%s: %d errors not reported:", testname, len(errors))
		for pos, msg := range errors {
			t.Errorf("%s: %s\n", fset.Position(pos), msg)
		}
	}
}
// TestCheck runs the typechecker test suite. If the local file
// testdata/test.go exists, only that file is checked — a convenience
// for debugging; otherwise every registered test is run.
func TestCheck(t *testing.T) {
	// For easy debugging w/o changing the testing code,
	// if there is a local test file, only test that file.
	const testfile = "testdata/test.go"
	fi, err := os.Stat(testfile)
	if err == nil && !fi.IsDir() {
		fmt.Printf("WARNING: Testing only %s (remove it to run all tests)\n", testfile)
		checkFiles(t, testfile, []string{testfile})
		return
	}
	// Otherwise, run all the tests.
	for _, test := range tests {
		checkFiles(t, test.name, test.files)
	}
}

View File

@ -1,662 +0,0 @@
// Copyright 2011 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
// This file implements operations on constant values.
package types
import (
"fmt"
"go/token"
"math/big"
"strconv"
)
// TODO(gri) At the moment, constants are different types
// passed around as interface{} values. Consider introducing
// a Const type and use methods instead of xConst functions.
// Representation of constant values.
//
// bool -> bool (true, false)
// numeric -> int64, *big.Int, *big.Rat, complex (ordered by increasing data structure "size")
// string -> string
// nil -> nilType (nilConst)
//
// Numeric constants are normalized after each operation such
// that they are represented by the "smallest" data structure
// required to represent the constant, independent of actual
// type. Non-numeric constants are always normalized.
// Representation of complex numbers.
type complex struct {
re, im *big.Rat
}
func (c complex) String() string {
if c.re.Sign() == 0 {
return fmt.Sprintf("%si", c.im)
}
// normalized complex values always have an imaginary part
return fmt.Sprintf("(%s + %si)", c.re, c.im)
}
// nilType is the representation of the predeclared value nil.
type nilType struct{}

// String implements fmt.Stringer for nilType.
func (nilType) String() string { return "nil" }
// Frequently used constants.
// nilConst is allocated once so there is a single canonical nil value.
var (
	zeroConst = int64(0)
	oneConst = int64(1)
	minusOneConst = int64(-1)
	nilConst = new(nilType)
)
// int64 bounds
var (
minInt64 = big.NewInt(-1 << 63)
maxInt64 = big.NewInt(1<<63 - 1)
)
// normalizeIntConst returns the smallest constant representation
// for the specific value of x; either an int64 or a *big.Int value.
//
func normalizeIntConst(x *big.Int) interface{} {
if minInt64.Cmp(x) <= 0 && x.Cmp(maxInt64) <= 0 {
return x.Int64()
}
return x
}
// normalizeRatConst returns the smallest constant representation
// for the specific value of x; either an int64, *big.Int value,
// or *big.Rat value.
//
func normalizeRatConst(x *big.Rat) interface{} {
	if !x.IsInt() {
		return x
	}
	// integral value: hand off to the integer normalizer
	return normalizeIntConst(x.Num())
}
// normalizeComplexConst returns the smallest constant representation
// for the specific value of x; either an int64, *big.Int value, *big.Rat,
// or complex value.
//
func normalizeComplexConst(x complex) interface{} {
if x.im.Sign() == 0 {
return normalizeRatConst(x.re)
}
return x
}
// makeRuneConst returns the int64 code point for the rune literal
// lit. The result is nil if lit is not a correct rune literal.
//
func makeRuneConst(lit string) interface{} {
	if len(lit) < 2 {
		return nil
	}
	// strip the surrounding quotes before unquoting the character
	code, _, _, err := strconv.UnquoteChar(lit[1:len(lit)-1], '\'')
	if err != nil {
		return nil
	}
	return int64(code)
}
// makeRuneConst returns the smallest integer constant representation
// (int64, *big.Int) for the integer literal lit. The result is nil if
// lit is not a correct integer literal.
//
func makeIntConst(lit string) interface{} {
if x, err := strconv.ParseInt(lit, 0, 64); err == nil {
return x
}
if x, ok := new(big.Int).SetString(lit, 0); ok {
return x
}
return nil
}
// makeFloatConst returns the smallest floating-point constant representation
// (int64, *big.Int, *big.Rat) for the floating-point literal lit. The result
// is nil if lit is not a correct floating-point literal.
//
func makeFloatConst(lit string) interface{} {
	x, ok := new(big.Rat).SetString(lit)
	if !ok {
		return nil
	}
	return normalizeRatConst(x)
}
// makeComplexConst returns the complex constant representation (complex) for
// the imaginary literal lit. The result is nil if lit is not a correct imaginary
// literal.
//
func makeComplexConst(lit string) interface{} {
	n := len(lit)
	if n == 0 || lit[n-1] != 'i' {
		return nil
	}
	// parse the mantissa without the trailing 'i'
	im, ok := new(big.Rat).SetString(lit[:n-1])
	if !ok {
		return nil
	}
	return normalizeComplexConst(complex{big.NewRat(0, 1), im})
}
// makeStringConst returns the string constant representation (string) for
// the string literal lit. The result is nil if lit is not a correct string
// literal.
//
func makeStringConst(lit string) interface{} {
	s, err := strconv.Unquote(lit)
	if err != nil {
		return nil
	}
	return s
}
// isZeroConst reports whether the value of constant x is 0.
// x must be normalized.
//
func isZeroConst(x interface{}) bool {
	// a normalized zero is always stored as an int64
	if v, ok := x.(int64); ok {
		return v == 0
	}
	return false
}
// isNegConst reports whether the value of constant x is < 0.
// x must be a non-complex numeric value.
//
func isNegConst(x interface{}) bool {
	switch v := x.(type) {
	case int64:
		return v < 0
	case *big.Int:
		return v.Sign() < 0
	case *big.Rat:
		return v.Sign() < 0
	}
	// non-numeric or complex operand: caller error
	unreachable()
	return false
}
// isRepresentableConst reports whether the value of constant x can
// be represented as a value of the basic type Typ[as] without loss
// of precision.
//
// Since constants are normalized, a *big.Int value is by construction
// outside the int64 range (see normalizeIntConst), so only unsigned
// and floating-point targets need to be considered for that case.
//
func isRepresentableConst(x interface{}, as BasicKind) bool {
	const intBits = 32 // TODO(gri) implementation-specific constant
	const ptrBits = 64 // TODO(gri) implementation-specific constant
	switch x := x.(type) {
	case bool:
		return as == Bool || as == UntypedBool
	case int64:
		switch as {
		case Int:
			return -1<<(intBits-1) <= x && x <= 1<<(intBits-1)-1
		case Int8:
			return -1<<(8-1) <= x && x <= 1<<(8-1)-1
		case Int16:
			return -1<<(16-1) <= x && x <= 1<<(16-1)-1
		case Int32, UntypedRune:
			return -1<<(32-1) <= x && x <= 1<<(32-1)-1
		case Int64:
			return true
		case Uint:
			return 0 <= x && x <= 1<<intBits-1
		case Uint8:
			return 0 <= x && x <= 1<<8-1
		case Uint16:
			return 0 <= x && x <= 1<<16-1
		case Uint32:
			return 0 <= x && x <= 1<<32-1
		case Uint64:
			return 0 <= x
		case Uintptr:
			assert(ptrBits == 64)
			return 0 <= x
		case Float32:
			return true // TODO(gri) fix this
		case Float64:
			return true // TODO(gri) fix this
		case Complex64:
			return true // TODO(gri) fix this
		case Complex128:
			return true // TODO(gri) fix this
		case UntypedInt, UntypedFloat, UntypedComplex:
			return true
		}
	case *big.Int:
		switch as {
		case Uint:
			return x.Sign() >= 0 && x.BitLen() <= intBits
		case Uint64:
			return x.Sign() >= 0 && x.BitLen() <= 64
		case Uintptr:
			return x.Sign() >= 0 && x.BitLen() <= ptrBits
		case Float32:
			return true // TODO(gri) fix this
		case Float64:
			return true // TODO(gri) fix this
		case Complex64:
			return true // TODO(gri) fix this
		case Complex128:
			return true // TODO(gri) fix this
		case UntypedInt, UntypedFloat, UntypedComplex:
			return true
		}
	case *big.Rat:
		switch as {
		case Float32:
			return true // TODO(gri) fix this
		case Float64:
			return true // TODO(gri) fix this
		case Complex64:
			return true // TODO(gri) fix this
		case Complex128:
			return true // TODO(gri) fix this
		case UntypedFloat, UntypedComplex:
			return true
		}
	case complex:
		switch as {
		case Complex64:
			return true // TODO(gri) fix this
		case Complex128:
			return true // TODO(gri) fix this
		case UntypedComplex:
			return true
		}
	case string:
		return as == String || as == UntypedString
	case nilType:
		return as == UntypedNil
	default:
		unreachable()
	}
	// falling out of a kind switch means "not representable as Typ[as]"
	return false
}
// Shared *big values used when lifting a constant to a "bigger"
// representation in matchConst. They are read-only by convention.
var (
	int1 = big.NewInt(1)
	rat0 = big.NewRat(0, 1)
)
// complexity returns a measure of representation complexity for constant x.
// Larger results correspond to "bigger" representations; matchConst uses
// this ordering to decide which operand must be lifted.
func complexity(x interface{}) int {
	switch x.(type) {
	case bool, string, nilType:
		return 1
	case int64:
		return 2
	case *big.Int:
		return 3
	case *big.Rat:
		return 4
	case complex:
		return 5
	}
	// not a constant representation: caller error
	unreachable()
	return 0
}
// matchConst returns the matching representation (same type) with the
// smallest complexity for two constant values x and y. They must be
// of the same "kind" (boolean, numeric, string, or nilType).
//
// The recursive call below swaps the operands so that the body only
// needs to handle the case complexity(x) <= complexity(y): x is then
// lifted to y's representation.
//
func matchConst(x, y interface{}) (_, _ interface{}) {
	if complexity(x) > complexity(y) {
		y, x = matchConst(y, x)
		return x, y
	}
	// complexity(x) <= complexity(y)
	switch x := x.(type) {
	case bool, complex, string, nilType:
		// x already has the maximum complexity of its kind; y matches
		return x, y
	case int64:
		switch y := y.(type) {
		case int64:
			return x, y
		case *big.Int:
			return big.NewInt(x), y
		case *big.Rat:
			return big.NewRat(x, 1), y
		case complex:
			return complex{big.NewRat(x, 1), rat0}, y
		}
	case *big.Int:
		switch y := y.(type) {
		case *big.Int:
			return x, y
		case *big.Rat:
			return new(big.Rat).SetFrac(x, int1), y
		case complex:
			return complex{new(big.Rat).SetFrac(x, int1), rat0}, y
		}
	case *big.Rat:
		switch y := y.(type) {
		case *big.Rat:
			return x, y
		case complex:
			return complex{x, rat0}, y
		}
	}
	// mixed kinds: caller error
	unreachable()
	return nil, nil
}
// is32bit reports whether x can be represented using 32 bits.
func is32bit(x int64) bool {
	const lo, hi = -1 << 31, 1<<31 - 1
	return lo <= x && x <= hi
}
// is63bit reports whether x can be represented using 63 bits.
func is63bit(x int64) bool {
	const lo, hi = -1 << 62, 1<<62 - 1
	return lo <= x && x <= hi
}
// binaryOpConst returns the result of the constant evaluation x op y;
// both operands must be of the same "kind" (boolean, numeric, or string).
// If intDiv is true, division (op == token.QUO) is using integer division
// (and the result is guaranteed to be integer) rather than floating-point
// division. Division by zero leads to a run-time panic.
//
func binaryOpConst(x, y interface{}, op token.Token, intDiv bool) interface{} {
	// bring both operands to the same representation first
	x, y = matchConst(x, y)
	switch x := x.(type) {
	case bool:
		y := y.(bool)
		switch op {
		case token.LAND:
			return x && y
		case token.LOR:
			return x || y
		default:
			unreachable()
		}
	case int64:
		y := y.(int64)
		switch op {
		case token.ADD:
			// TODO(gri) can do better than this
			if is63bit(x) && is63bit(y) {
				return x + y
			}
			return normalizeIntConst(new(big.Int).Add(big.NewInt(x), big.NewInt(y)))
		case token.SUB:
			// TODO(gri) can do better than this
			if is63bit(x) && is63bit(y) {
				return x - y
			}
			return normalizeIntConst(new(big.Int).Sub(big.NewInt(x), big.NewInt(y)))
		case token.MUL:
			// TODO(gri) can do better than this
			if is32bit(x) && is32bit(y) {
				return x * y
			}
			return normalizeIntConst(new(big.Int).Mul(big.NewInt(x), big.NewInt(y)))
		case token.REM:
			return x % y
		case token.QUO:
			if intDiv {
				return x / y
			}
			return normalizeRatConst(new(big.Rat).SetFrac(big.NewInt(x), big.NewInt(y)))
		case token.AND:
			return x & y
		case token.OR:
			return x | y
		case token.XOR:
			return x ^ y
		case token.AND_NOT:
			return x &^ y
		default:
			unreachable()
		}
	case *big.Int:
		y := y.(*big.Int)
		var z big.Int
		switch op {
		case token.ADD:
			z.Add(x, y)
		case token.SUB:
			z.Sub(x, y)
		case token.MUL:
			z.Mul(x, y)
		case token.REM:
			z.Rem(x, y)
		case token.QUO:
			if intDiv {
				z.Quo(x, y)
			} else {
				return normalizeRatConst(new(big.Rat).SetFrac(x, y))
			}
		case token.AND:
			z.And(x, y)
		case token.OR:
			z.Or(x, y)
		case token.XOR:
			z.Xor(x, y)
		case token.AND_NOT:
			z.AndNot(x, y)
		default:
			unreachable()
		}
		return normalizeIntConst(&z)
	case *big.Rat:
		y := y.(*big.Rat)
		var z big.Rat
		switch op {
		case token.ADD:
			z.Add(x, y)
		case token.SUB:
			z.Sub(x, y)
		case token.MUL:
			z.Mul(x, y)
		case token.QUO:
			z.Quo(x, y)
		default:
			unreachable()
		}
		return normalizeRatConst(&z)
	case complex:
		y := y.(complex)
		a, b := x.re, x.im
		c, d := y.re, y.im
		var re, im big.Rat
		switch op {
		case token.ADD:
			// (a+c) + i(b+d)
			re.Add(a, c)
			im.Add(b, d)
		case token.SUB:
			// (a-c) + i(b-d)
			re.Sub(a, c)
			im.Sub(b, d)
		case token.MUL:
			// (ac-bd) + i(bc+ad)
			var ac, bd, bc, ad big.Rat
			ac.Mul(a, c)
			bd.Mul(b, d)
			bc.Mul(b, c)
			ad.Mul(a, d)
			re.Sub(&ac, &bd)
			im.Add(&bc, &ad)
		case token.QUO:
			// (ac+bd)/s + i(bc-ad)/s, with s = cc + dd
			var ac, bd, bc, ad, s big.Rat
			ac.Mul(a, c)
			bd.Mul(b, d)
			bc.Mul(b, c)
			ad.Mul(a, d)
			// NOTE(review): c.Mul(c, c) and d.Mul(d, d) mutate y.re and
			// y.im in place — confirm callers never reuse y afterwards.
			s.Add(c.Mul(c, c), d.Mul(d, d))
			re.Add(&ac, &bd)
			re.Quo(&re, &s)
			im.Sub(&bc, &ad)
			im.Quo(&im, &s)
		default:
			unreachable()
		}
		return normalizeComplexConst(complex{&re, &im})
	case string:
		// only concatenation is defined on string constants
		if op == token.ADD {
			return x + y.(string)
		}
	}
	unreachable()
	return nil
}
// shiftConst returns the result of the constant evaluation x op s
// where op is token.SHL or token.SHR (<< or >>). x must be an
// integer constant.
//
func shiftConst(x interface{}, s uint, op token.Token) interface{} {
	switch x := x.(type) {
	case int64:
		if op == token.SHR {
			// right shift cannot overflow int64
			return x >> s
		}
		if op == token.SHL {
			// left shift may overflow: compute via big.Int and normalize
			z := big.NewInt(x)
			return normalizeIntConst(z.Lsh(z, s))
		}
	case *big.Int:
		var z big.Int
		if op == token.SHL {
			return normalizeIntConst(z.Lsh(x, s))
		}
		if op == token.SHR {
			return normalizeIntConst(z.Rsh(x, s))
		}
	}
	// non-integer operand or unexpected operator
	unreachable()
	return nil
}
// compareConst returns the result of the constant comparison x op y;
// both operands must be of the same "kind" (boolean, numeric, string,
// or nilType).
//
// The comparison is reduced to == and < by swapping operands and/or
// negating the result, so the switch below only needs those two.
//
func compareConst(x, y interface{}, op token.Token) (z bool) {
	x, y = matchConst(x, y)
	// x == y => x == y
	// x != y => x != y
	// x > y => y < x
	// x >= y => y <= x
	swap := false
	switch op {
	case token.GTR:
		swap = true
		op = token.LSS
	case token.GEQ:
		swap = true
		op = token.LEQ
	}
	// x == y => x == y
	// x != y => !(x == y)
	// x < y => x < y
	// x <= y => !(y < x)
	negate := false
	switch op {
	case token.NEQ:
		negate = true
		op = token.EQL
	case token.LEQ:
		swap = !swap
		negate = true
		op = token.LSS
	}
	if negate {
		// deferred so every return below is negated uniformly
		defer func() { z = !z }()
	}
	if swap {
		x, y = y, x
	}
	switch x := x.(type) {
	case bool:
		if op == token.EQL {
			return x == y.(bool)
		}
	case int64:
		y := y.(int64)
		switch op {
		case token.EQL:
			return x == y
		case token.LSS:
			return x < y
		}
	case *big.Int:
		s := x.Cmp(y.(*big.Int))
		switch op {
		case token.EQL:
			return s == 0
		case token.LSS:
			return s < 0
		}
	case *big.Rat:
		s := x.Cmp(y.(*big.Rat))
		switch op {
		case token.EQL:
			return s == 0
		case token.LSS:
			return s < 0
		}
	case complex:
		y := y.(complex)
		if op == token.EQL {
			return x.re.Cmp(y.re) == 0 && x.im.Cmp(y.im) == 0
		}
	case string:
		y := y.(string)
		switch op {
		case token.EQL:
			return x == y
		case token.LSS:
			return x < y
		}
	case nilType:
		if op == token.EQL {
			return x == y.(nilType)
		}
	}
	// debugging aid before dying on an unexpected kind/op combination
	fmt.Printf("x = %s (%T), y = %s (%T)\n", x, x, y, y)
	unreachable()
	return
}

View File

@ -1,110 +0,0 @@
// Copyright 2011 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
// This file implements FindGcExportData.
package types
import (
"bufio"
"errors"
"fmt"
"io"
"strconv"
"strings"
)
// readGopackHeader reads and parses one archive (ar) member header
// from r and returns the member's name and size in bytes.
func readGopackHeader(r *bufio.Reader) (name string, size int, err error) {
	// See $GOROOT/include/ar.h.
	hdr := make([]byte, 16+12+6+6+8+10+2)
	_, err = io.ReadFull(r, hdr)
	if err != nil {
		return
	}
	if trace {
		fmt.Printf("header: %s", hdr)
	}
	// the decimal size field starts at byte offset 48 and is 10 bytes wide
	s := strings.TrimSpace(string(hdr[16+12+6+6+8:][:10]))
	size, err = strconv.Atoi(s)
	if err != nil || hdr[len(hdr)-2] != '`' || hdr[len(hdr)-1] != '\n' {
		err = errors.New("invalid archive header")
		return
	}
	name = strings.TrimSpace(string(hdr[:16]))
	return
}
// FindGcExportData positions the reader r at the beginning of the
// export data section of an underlying GC-created object/archive
// file by reading from it. The reader must be positioned at the
// start of the file before calling this function.
//
func FindGcExportData(r *bufio.Reader) (err error) {
	// Read first line to make sure this is an object file.
	line, err := r.ReadSlice('\n')
	if err != nil {
		return
	}
	if string(line) == "!<arch>\n" {
		// Archive file. Scan to __.PKGDEF, which should
		// be second archive entry.
		var name string
		var size int
		// First entry should be __.GOSYMDEF.
		// Older archives used __.SYMDEF, so allow that too.
		// Read and discard.
		if name, size, err = readGopackHeader(r); err != nil {
			return
		}
		if name != "__.SYMDEF" && name != "__.GOSYMDEF" {
			err = errors.New("go archive does not begin with __.SYMDEF or __.GOSYMDEF")
			return
		}
		// discard the symbol table contents in bounded chunks
		const block = 4096
		tmp := make([]byte, block)
		for size > 0 {
			n := size
			if n > block {
				n = block
			}
			if _, err = io.ReadFull(r, tmp[:n]); err != nil {
				return
			}
			size -= n
		}
		// Second entry should be __.PKGDEF.
		if name, size, err = readGopackHeader(r); err != nil {
			return
		}
		if name != "__.PKGDEF" {
			err = errors.New("go archive is missing __.PKGDEF")
			return
		}
		// Read first line of __.PKGDEF data, so that line
		// is once again the first line of the input.
		if line, err = r.ReadSlice('\n'); err != nil {
			return
		}
	}
	// Now at __.PKGDEF in archive or still at beginning of file.
	// Either way, line should begin with "go object ".
	if !strings.HasPrefix(string(line), "go object ") {
		err = errors.New("not a go object file")
		return
	}
	// Skip over object header to export data.
	// Begins after first line with $$.
	for line[0] != '$' {
		if line, err = r.ReadSlice('\n'); err != nil {
			return
		}
	}
	return
}

View File

@ -1,889 +0,0 @@
// Copyright 2011 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
// This file implements an ast.Importer for gc-generated object files.
// TODO(gri) Eventually move this into a separate package outside types.
package types
import (
"bufio"
"errors"
"fmt"
"go/ast"
"go/build"
"go/token"
"io"
"math/big"
"os"
"path/filepath"
"strconv"
"strings"
"text/scanner"
)
// pkgExts lists the file extensions tried, in order, when locating a
// compiled package file (archive or per-architecture object file).
var pkgExts = [...]string{".a", ".5", ".6", ".8"}
// FindPkg returns the filename and unique package id for an import
// path based on package information provided by build.Import (using
// the build.Default build.Context).
// If no file was found, an empty filename is returned.
//
func FindPkg(path, srcDir string) (filename, id string) {
	if len(path) == 0 {
		return
	}
	id = path
	var noext string
	switch {
	default:
		// "x" -> "$GOPATH/pkg/$GOOS_$GOARCH/x.ext", "x"
		// Don't require the source files to be present.
		bp, _ := build.Import(path, srcDir, build.FindOnly|build.AllowBinary)
		if bp.PkgObj == "" {
			return
		}
		noext = bp.PkgObj
		if strings.HasSuffix(noext, ".a") {
			noext = noext[:len(noext)-len(".a")]
		}
	case build.IsLocalImport(path):
		// "./x" -> "/this/directory/x.ext", "/this/directory/x"
		noext = filepath.Join(srcDir, path)
		id = noext
	case filepath.IsAbs(path):
		// for completeness only - go/build.Import
		// does not support absolute imports
		// "/x" -> "/x.ext", "/x"
		noext = path
	}
	// try extensions
	for _, ext := range pkgExts {
		filename = noext + ext
		if f, err := os.Stat(filename); err == nil && !f.IsDir() {
			return
		}
	}
	filename = "" // not found
	return
}
// GcImportData imports a package by reading the gc-generated export data,
// adds the corresponding package object to the imports map indexed by id,
// and returns the object.
//
// The imports map must contains all packages already imported, and no map
// entry with id as the key must be present. The data reader position must
// be the beginning of the export data section. The filename is only used
// in error messages.
//
func GcImportData(imports map[string]*ast.Object, filename, id string, data *bufio.Reader) (pkg *ast.Object, err error) {
	if trace {
		fmt.Printf("importing %s (%s)\n", id, filename)
	}
	// support for gcParser error handling
	// (gcParser reports errors by panicking with an importError)
	defer func() {
		if r := recover(); r != nil {
			err = r.(importError) // will re-panic if r is not an importError
		}
	}()
	var p gcParser
	p.init(filename, id, data, imports)
	pkg = p.parseExport()
	return
}
// GcImport imports a gc-generated package given its import path, adds the
// corresponding package object to the imports map, and returns the object.
// Local import paths are interpreted relative to the current working directory.
// The imports map must contains all packages already imported.
// GcImport satisfies the ast.Importer signature.
//
func GcImport(imports map[string]*ast.Object, path string) (pkg *ast.Object, err error) {
	if path == "unsafe" {
		// package unsafe is predeclared and has no export data
		return Unsafe, nil
	}
	srcDir, err := os.Getwd()
	if err != nil {
		return
	}
	filename, id := FindPkg(path, srcDir)
	if filename == "" {
		err = errors.New("can't find import: " + id)
		return
	}
	// Note: imports[id] may already contain a partially imported package.
	// We must continue doing the full import here since we don't
	// know if something is missing.
	// TODO: There's no need to re-import a package if we know that we
	// have done a full import before. At the moment we cannot
	// tell from the available information in this function alone.
	// open file
	f, err := os.Open(filename)
	if err != nil {
		return
	}
	defer func() {
		f.Close()
		if err != nil {
			// Add file name to error.
			err = fmt.Errorf("reading export data: %s: %v", filename, err)
		}
	}()
	buf := bufio.NewReader(f)
	if err = FindGcExportData(buf); err != nil {
		return
	}
	pkg, err = GcImportData(imports, filename, id, buf)
	return
}
// ----------------------------------------------------------------------------
// gcParser
// gcParser parses the exports inside a gc compiler-produced
// object/archive file and populates its scope with the results.
// It keeps one token of lookahead (tok/lit).
type gcParser struct {
	scanner scanner.Scanner
	tok rune // current token
	lit string // literal string; only valid for Ident, Int, String tokens
	id string // package id of imported package
	imports map[string]*ast.Object // package id -> package object
}
// init prepares p for parsing export data read from src, which stems
// from the package with the given id in file filename, and primes the
// one-token lookahead.
func (p *gcParser) init(filename, id string, src io.Reader, imports map[string]*ast.Object) {
	p.scanner.Init(src)
	p.scanner.Error = func(_ *scanner.Scanner, msg string) { p.error(msg) }
	p.scanner.Mode = scanner.ScanIdents | scanner.ScanInts | scanner.ScanChars | scanner.ScanStrings | scanner.ScanComments | scanner.SkipComments
	// newlines are significant in export data, so do not skip them
	p.scanner.Whitespace = 1<<'\t' | 1<<' '
	p.scanner.Filename = filename // for good error messages
	p.next()
	p.id = id
	p.imports = imports
}
// next advances to the next token, remembering its literal text for
// the token kinds that carry one.
func (p *gcParser) next() {
	p.tok = p.scanner.Scan()
	switch p.tok {
	case scanner.Ident, scanner.Int, scanner.Char, scanner.String, '·':
		p.lit = p.scanner.TokenText()
	default:
		p.lit = ""
	}
	if trace {
		fmt.Printf("%s: %q -> %q\n", scanner.TokenString(p.tok), p.scanner.TokenText(), p.lit)
	}
}
// declare inserts a named object of the given kind in scope and
// returns it; if an object with that name already exists in scope,
// the existing object is returned instead.
func (p *gcParser) declare(scope *ast.Scope, kind ast.ObjKind, name string) *ast.Object {
	// the object may have been imported before - if it exists
	// already in the respective package scope, return that object
	if obj := scope.Lookup(name); obj != nil {
		assert(obj.Kind == kind)
		return obj
	}
	// otherwise create a new object and insert it into the package scope
	obj := ast.NewObj(kind, name)
	if scope.Insert(obj) != nil {
		p.errorf("already declared: %v %s", kind, obj.Name)
	}
	// if the new type object is a named type it may be referred
	// to before the underlying type is known - set it up
	if kind == ast.Typ {
		obj.Type = &NamedType{Obj: obj}
	}
	return obj
}
// ----------------------------------------------------------------------------
// Error handling
// Internal errors are boxed as importErrors.
// importError records a scanner position together with the underlying
// error; it is the panic value used by gcParser for error reporting.
type importError struct {
	pos scanner.Position
	err error
}

// Error implements the error interface for importError.
func (e importError) Error() string {
	return fmt.Sprintf("import error %s (byte offset = %d): %s", e.pos, e.pos.Offset, e.err)
}
// error aborts parsing by panicking with an importError at the
// current scanner position; GcImportData recovers it.
func (p *gcParser) error(err interface{}) {
	if s, ok := err.(string); ok {
		err = errors.New(s)
	}
	// panic with a runtime.Error if err is not an error
	panic(importError{p.scanner.Pos(), err.(error)})
}
// errorf is like error but formats its arguments Printf-style.
func (p *gcParser) errorf(format string, args ...interface{}) {
	p.error(fmt.Sprintf(format, args...))
}
// expect consumes the current token if it is tok and returns its
// literal text; otherwise it reports an error (and does not return).
func (p *gcParser) expect(tok rune) string {
	lit := p.lit
	if p.tok != tok {
		p.errorf("expected %s, got %s (%s)", scanner.TokenString(tok), scanner.TokenString(p.tok), lit)
	}
	p.next()
	return lit
}
// expectSpecial consumes the multi-character token tok, which must
// appear in the input without intervening white space (e.g. "..."
// or "<-"); otherwise it reports an error.
func (p *gcParser) expectSpecial(tok string) {
	sep := 'x' // not white space
	i := 0
	for i < len(tok) && p.tok == rune(tok[i]) && sep > ' ' {
		sep = p.scanner.Peek() // if sep <= ' ', there is white space before the next token
		p.next()
		i++
	}
	if i < len(tok) {
		p.errorf("expected %q, got %q", tok, tok[0:i])
	}
}
// expectKeyword consumes an identifier and checks that it is the
// given keyword; otherwise it reports an error.
func (p *gcParser) expectKeyword(keyword string) {
	lit := p.expect(scanner.Ident)
	if lit != keyword {
		p.errorf("expected keyword %s, got %q", keyword, lit)
	}
}
// ----------------------------------------------------------------------------
// Import declarations
// ImportPath = string_lit .
//
// parsePkgId parses a quoted package id and returns (creating it if
// necessary) the corresponding package object.
func (p *gcParser) parsePkgId() *ast.Object {
	id, err := strconv.Unquote(p.expect(scanner.String))
	if err != nil {
		p.error(err)
	}
	switch id {
	case "":
		// id == "" stands for the imported package id
		// (only known at time of package installation)
		id = p.id
	case "unsafe":
		// package unsafe is not in the imports map - handle explicitly
		return Unsafe
	}
	pkg := p.imports[id]
	if pkg == nil {
		// first reference to this package: create a placeholder object
		pkg = ast.NewObj(ast.Pkg, "")
		pkg.Data = ast.NewScope(nil)
		p.imports[id] = pkg
	}
	return pkg
}
// dotIdentifier = ( ident | '·' ) { ident | int | '·' } .
//
// parseDotIdent parses an identifier that may contain '·' and digit
// runs, concatenating adjacent pieces not separated by white space.
func (p *gcParser) parseDotIdent() string {
	ident := ""
	if p.tok != scanner.Int {
		sep := 'x' // not white space
		for (p.tok == scanner.Ident || p.tok == scanner.Int || p.tok == '·') && sep > ' ' {
			ident += p.lit
			sep = p.scanner.Peek() // if sep <= ' ', there is white space before the next token
			p.next()
		}
	}
	if ident == "" {
		p.expect(scanner.Ident) // use expect() for error handling
	}
	return ident
}
// ExportedName = "@" ImportPath "." dotIdentifier .
//
// parseExportedName parses a package-qualified name and returns the
// package object and the unqualified name.
func (p *gcParser) parseExportedName() (*ast.Object, string) {
	p.expect('@')
	pkg := p.parsePkgId()
	p.expect('.')
	name := p.parseDotIdent()
	return pkg, name
}
// ----------------------------------------------------------------------------
// Types
// BasicType = identifier .
//
// parseBasicType looks up a predeclared type by name in the Universe
// scope and reports an error for unknown names.
func (p *gcParser) parseBasicType() Type {
	id := p.expect(scanner.Ident)
	obj := Universe.Lookup(id)
	if obj == nil || obj.Kind != ast.Typ {
		p.errorf("not a basic type: %s", id)
	}
	return obj.Type.(Type)
}
// ArrayType = "[" int_lit "]" Type .
//
// parseArrayType parses a fixed-size array type.
func (p *gcParser) parseArrayType() Type {
	// "[" already consumed and lookahead known not to be "]"
	lit := p.expect(scanner.Int)
	p.expect(']')
	elt := p.parseType()
	n, err := strconv.ParseInt(lit, 10, 64)
	if err != nil {
		p.error(err)
	}
	return &Array{Len: n, Elt: elt}
}
// MapType = "map" "[" Type "]" Type .
//
// parseMapType parses a map type.
func (p *gcParser) parseMapType() Type {
	p.expectKeyword("map")
	p.expect('[')
	key := p.parseType()
	p.expect(']')
	elt := p.parseType()
	return &Map{Key: key, Elt: elt}
}
// Name = identifier | "?" | ExportedName .
//
// parseName parses a (possibly package-qualified or anonymous) name;
// the anonymous case returns the empty string.
func (p *gcParser) parseName() (name string) {
	switch p.tok {
	case scanner.Ident:
		name = p.lit
		p.next()
	case '?':
		// anonymous
		p.next()
	case '@':
		// exported name prefixed with package path
		_, name = p.parseExportedName()
	default:
		p.error("name expected")
	}
	return
}
// Field = Name Type [ string_lit ] .
//
// parseField parses one struct field; anonymous fields get the name
// of their (possibly dereferenced) named type.
func (p *gcParser) parseField() *StructField {
	var f StructField
	f.Name = p.parseName()
	f.Type = p.parseType()
	if p.tok == scanner.String {
		f.Tag = p.expect(scanner.String)
	}
	if f.Name == "" {
		// anonymous field - typ must be T or *T and T must be a type name
		if typ, ok := deref(f.Type).(*NamedType); ok && typ.Obj != nil {
			f.Name = typ.Obj.Name
		} else {
			p.errorf("anonymous field expected")
		}
	}
	return &f
}
// StructType = "struct" "{" [ FieldList ] "}" .
// FieldList = Field { ";" Field } .
//
// parseStructType parses a struct type with semicolon-separated fields.
func (p *gcParser) parseStructType() Type {
	var fields []*StructField
	parseField := func() {
		fields = append(fields, p.parseField())
	}
	p.expectKeyword("struct")
	p.expect('{')
	if p.tok != '}' {
		parseField()
		for p.tok == ';' {
			p.next()
			parseField()
		}
	}
	p.expect('}')
	return &Struct{Fields: fields}
}
// Parameter = ( identifier | "?" ) [ "..." ] Type [ string_lit ] .
//
// parseParameter parses one function parameter and reports whether
// it is a variadic ("...") parameter.
func (p *gcParser) parseParameter() (par *ast.Object, isVariadic bool) {
	name := p.parseName()
	if name == "" {
		name = "_" // cannot access unnamed identifiers
	}
	if p.tok == '.' {
		p.expectSpecial("...")
		isVariadic = true
	}
	ptyp := p.parseType()
	// ignore argument tag (e.g. "noescape")
	if p.tok == scanner.String {
		p.expect(scanner.String)
	}
	par = ast.NewObj(ast.Var, name)
	par.Type = ptyp
	return
}
// Parameters = "(" [ ParameterList ] ")" .
// ParameterList = { Parameter "," } Parameter .
//
// parseParameters parses a parenthesized, comma-separated parameter
// list; isVariadic reports whether the final parameter is "...".
func (p *gcParser) parseParameters() (list []*ast.Object, isVariadic bool) {
	parseParameter := func() {
		par, variadic := p.parseParameter()
		list = append(list, par)
		if variadic {
			if isVariadic {
				p.error("... not on final argument")
			}
			isVariadic = true
		}
	}
	p.expect('(')
	if p.tok != ')' {
		parseParameter()
		for p.tok == ',' {
			p.next()
			parseParameter()
		}
	}
	p.expect(')')
	return
}
// Signature = Parameters [ Result ] .
// Result = Type | Parameters .
//
// parseSignature parses a function signature; a result may be a single
// unnamed type or a parenthesized (possibly named) result list.
func (p *gcParser) parseSignature() *Signature {
	params, isVariadic := p.parseParameters()
	// optional result type
	var results []*ast.Object
	switch p.tok {
	case scanner.Ident, '[', '*', '<', '@':
		// single, unnamed result
		result := ast.NewObj(ast.Var, "_")
		result.Type = p.parseType()
		results = []*ast.Object{result}
	case '(':
		// named or multiple result(s)
		var variadic bool
		results, variadic = p.parseParameters()
		if variadic {
			p.error("... not permitted on result type")
		}
	}
	return &Signature{Params: params, Results: results, IsVariadic: isVariadic}
}
// InterfaceType = "interface" "{" [ MethodList ] "}" .
// MethodList = Method { ";" Method } .
// Method = Name Signature .
//
// (The methods of embedded interfaces are always "inlined"
// by the compiler and thus embedded interfaces are never
// visible in the export data.)
//
func (p *gcParser) parseInterfaceType() Type {
	var methods ObjList
	parseMethod := func() {
		obj := ast.NewObj(ast.Fun, p.parseName())
		obj.Type = p.parseSignature()
		methods = append(methods, obj)
	}
	p.expectKeyword("interface")
	p.expect('{')
	if p.tok != '}' {
		parseMethod()
		for p.tok == ';' {
			p.next()
			parseMethod()
		}
	}
	p.expect('}')
	// methods are kept sorted for deterministic comparisons
	methods.Sort()
	return &Interface{Methods: methods}
}
// ChanType = ( "chan" [ "<-" ] | "<-" "chan" ) Type .
//
// parseChanType parses a channel type, including its direction.
func (p *gcParser) parseChanType() Type {
	dir := ast.SEND | ast.RECV
	if p.tok == scanner.Ident {
		// "chan" [ "<-" ] form
		p.expectKeyword("chan")
		if p.tok == '<' {
			p.expectSpecial("<-")
			dir = ast.SEND
		}
	} else {
		// "<-" "chan" form
		p.expectSpecial("<-")
		p.expectKeyword("chan")
		dir = ast.RECV
	}
	elt := p.parseType()
	return &Chan{Dir: dir, Elt: elt}
}
// Type =
// BasicType | TypeName | ArrayType | SliceType | StructType |
// PointerType | FuncType | InterfaceType | MapType | ChanType |
// "(" Type ")" .
// BasicType = ident .
// TypeName = ExportedName .
// SliceType = "[" "]" Type .
// PointerType = "*" Type .
// FuncType = "func" Signature .
//
// parseType dispatches on the lookahead token/literal to the
// appropriate type production.
func (p *gcParser) parseType() Type {
	switch p.tok {
	case scanner.Ident:
		switch p.lit {
		default:
			return p.parseBasicType()
		case "struct":
			return p.parseStructType()
		case "func":
			// FuncType
			p.next()
			return p.parseSignature()
		case "interface":
			return p.parseInterfaceType()
		case "map":
			return p.parseMapType()
		case "chan":
			return p.parseChanType()
		}
	case '@':
		// TypeName
		pkg, name := p.parseExportedName()
		return p.declare(pkg.Data.(*ast.Scope), ast.Typ, name).Type.(Type)
	case '[':
		p.next() // look ahead
		if p.tok == ']' {
			// SliceType
			p.next()
			return &Slice{Elt: p.parseType()}
		}
		return p.parseArrayType()
	case '*':
		// PointerType
		p.next()
		return &Pointer{Base: p.parseType()}
	case '<':
		return p.parseChanType()
	case '(':
		// "(" Type ")"
		p.next()
		typ := p.parseType()
		p.expect(')')
		return typ
	}
	p.errorf("expected type, got %s (%q)", scanner.TokenString(p.tok), p.lit)
	return nil
}
// ----------------------------------------------------------------------------
// Declarations
// ImportDecl = "import" identifier string_lit .
//
// parseImportDecl records the real package name for an import id.
func (p *gcParser) parseImportDecl() {
	p.expectKeyword("import")
	// The identifier has no semantic meaning in the import data.
	// It exists so that error messages can print the real package
	// name: binary.ByteOrder instead of "encoding/binary".ByteOrder.
	name := p.expect(scanner.Ident)
	pkg := p.parsePkgId()
	assert(pkg.Name == "" || pkg.Name == name)
	pkg.Name = name
}
// int_lit = [ "+" | "-" ] { "0" ... "9" } .
//
// parseInt parses an optionally signed integer literal, returning the
// sign separately from the unsigned digit string.
func (p *gcParser) parseInt() (neg bool, val string) {
	switch p.tok {
	case '-':
		neg = true
		fallthrough
	case '+':
		p.next()
	}
	val = p.expect(scanner.Int)
	return
}
// number = int_lit [ "p" int_lit ] .
//
// parseNumber parses an integer or binary-exponent float constant
// (mantissa "p" exponent, base 2) and returns it as a constant operand.
func (p *gcParser) parseNumber() (x operand) {
	x.mode = constant
	// mantissa
	neg, val := p.parseInt()
	mant, ok := new(big.Int).SetString(val, 0)
	assert(ok)
	if neg {
		mant.Neg(mant)
	}
	if p.lit == "p" {
		// exponent (base 2)
		p.next()
		neg, val = p.parseInt()
		exp64, err := strconv.ParseUint(val, 10, 0)
		if err != nil {
			p.error(err)
		}
		exp := uint(exp64)
		if neg {
			// negative exponent: value is mant / 2**exp
			denom := big.NewInt(1)
			denom.Lsh(denom, exp)
			x.typ = Typ[UntypedFloat]
			x.val = normalizeRatConst(new(big.Rat).SetFrac(mant, denom))
			return
		}
		if exp > 0 {
			mant.Lsh(mant, exp)
		}
		x.typ = Typ[UntypedFloat]
		x.val = normalizeIntConst(mant)
		return
	}
	x.typ = Typ[UntypedInt]
	x.val = normalizeIntConst(mant)
	return
}
// ConstDecl = "const" ExportedName [ Type ] "=" Literal .
// Literal = bool_lit | int_lit | float_lit | complex_lit | rune_lit | string_lit .
// bool_lit = "true" | "false" .
// complex_lit = "(" float_lit "+" float_lit "i" ")" .
// rune_lit = "(" int_lit "+" int_lit ")" .
// string_lit = `"` { unicode_char } `"` .
//
// parseConstDecl parses a constant declaration and records the
// constant's type and value on the declared object.
func (p *gcParser) parseConstDecl() {
	p.expectKeyword("const")
	pkg, name := p.parseExportedName()
	obj := p.declare(pkg.Data.(*ast.Scope), ast.Con, name)
	var x operand
	// An explicit type is optional; if absent, the constant takes the
	// (untyped) type of its literal below.
	if p.tok != '=' {
		obj.Type = p.parseType()
	}
	p.expect('=')
	switch p.tok {
	case scanner.Ident:
		// bool_lit
		if p.lit != "true" && p.lit != "false" {
			p.error("expected true or false")
		}
		x.typ = Typ[UntypedBool]
		x.val = p.lit == "true"
		p.next()
	case '-', scanner.Int:
		// int_lit
		x = p.parseNumber()
	case '(':
		// complex_lit or rune_lit
		p.next()
		if p.tok == scanner.Char {
			// rune_lit of the form "(" char "+" int ")"; only the
			// integer part carries the value here.
			p.next()
			p.expect('+')
			x = p.parseNumber()
			x.typ = Typ[UntypedRune]
			p.expect(')')
			break
		}
		re := p.parseNumber()
		p.expect('+')
		im := p.parseNumber()
		p.expectKeyword("i")
		p.expect(')')
		x.typ = Typ[UntypedComplex]
		// TODO(gri) fix this
		_, _ = re, im
		x.val = zeroConst
	case scanner.Char:
		// rune_lit
		x.setConst(token.CHAR, p.lit)
		p.next()
	case scanner.String:
		// string_lit
		x.setConst(token.STRING, p.lit)
		p.next()
	default:
		p.errorf("expected literal got %s", scanner.TokenString(p.tok))
	}
	// Without an explicit type, the constant adopts the literal's type.
	if obj.Type == nil {
		obj.Type = x.typ
	}
	assert(x.val != nil)
	obj.Data = x.val
}
// TypeDecl = "type" ExportedName Type .
//
// parseTypeDecl parses a type declaration and fills in the named
// type's underlying type if it is not already known.
func (p *gcParser) parseTypeDecl() {
	p.expectKeyword("type")
	pkg, name := p.parseExportedName()
	obj := p.declare(pkg.Data.(*ast.Scope), ast.Typ, name)
	// The type object may have been imported before and thus already
	// have a type associated with it. We still need to parse the type
	// structure, but throw it away if the object already has a type.
	// This ensures that all imports refer to the same type object for
	// a given type declaration.
	typ := p.parseType()
	if name := obj.Type.(*NamedType); name.Underlying == nil {
		name.Underlying = typ
	}
}
// VarDecl = "var" ExportedName Type .
//
// parseVarDecl parses a variable declaration and records its type on
// the declared object.
func (p *gcParser) parseVarDecl() {
	p.expectKeyword("var")
	pkg, name := p.parseExportedName()
	obj := p.declare(pkg.Data.(*ast.Scope), ast.Var, name)
	obj.Type = p.parseType()
}
// FuncBody = "{" ... "}" .
//
// parseFuncBody skips over an exported (inlineable) function body
// without interpreting it; bodies are not needed for importing types.
func (p *gcParser) parseFuncBody() {
	p.expect('{')
	// Track brace nesting until the matching closing brace has been
	// seen and consumed.
	depth := 1
	for depth > 0 {
		switch p.tok {
		case '{':
			depth++
		case '}':
			depth--
		}
		p.next()
	}
}
// FuncDecl = "func" ExportedName Signature [ FuncBody ] .
//
// parseFuncDecl parses a package-level function declaration; an
// optional inlined body is skipped.
func (p *gcParser) parseFuncDecl() {
	// "func" already consumed
	pkg, name := p.parseExportedName()
	obj := p.declare(pkg.Data.(*ast.Scope), ast.Fun, name)
	obj.Type = p.parseSignature()
	if p.tok == '{' {
		p.parseFuncBody()
	}
}
// MethodDecl = "func" Receiver Name Signature .
// Receiver = "(" ( identifier | "?" ) [ "*" ] ExportedName ")" [ FuncBody ].
//
// parseMethodDecl parses a method declaration. The parsed signature is
// discarded here (nothing is declared); an optional body is skipped.
func (p *gcParser) parseMethodDecl() {
	// "func" already consumed
	p.expect('(')
	p.parseParameter() // receiver
	p.expect(')')
	p.parseName() // unexported method names in imports are qualified with their package.
	p.parseSignature()
	if p.tok == '{' {
		p.parseFuncBody()
	}
}
// Decl = [ ImportDecl | ConstDecl | TypeDecl | VarDecl | FuncDecl | MethodDecl ] "\n" .
//
// parseDecl dispatches on the leading keyword of a declaration and
// consumes the terminating newline. An empty declaration (no keyword)
// falls straight through to the newline check.
func (p *gcParser) parseDecl() {
	switch p.lit {
	case "const":
		p.parseConstDecl()
	case "func":
		// Peek at the next token to distinguish a method declaration
		// (parenthesized receiver) from a plain function declaration.
		p.next()
		if p.tok == '(' {
			p.parseMethodDecl()
		} else {
			p.parseFuncDecl()
		}
	case "import":
		p.parseImportDecl()
	case "type":
		p.parseTypeDecl()
	case "var":
		p.parseVarDecl()
	}
	p.expect('\n')
}
// ----------------------------------------------------------------------------
// Export

// Export = PackageClause { Decl } "$$" .
// PackageClause = "package" identifier [ "safe" ] "\n" .
//
// parseExport parses the entire export data section and returns the
// populated package object.
func (p *gcParser) parseExport() *ast.Object {
	p.expectKeyword("package")
	name := p.expect(scanner.Ident)
	if p.tok != '\n' {
		// A package is safe if it was compiled with the -u flag,
		// which disables the unsafe package.
		// TODO(gri) remember "safe" package
		p.expectKeyword("safe")
	}
	p.expect('\n')
	// Reuse a package object created by an earlier import of the same
	// id, if any, so all importers share one object per package.
	pkg := p.imports[p.id]
	if pkg == nil {
		pkg = ast.NewObj(ast.Pkg, name)
		pkg.Data = ast.NewScope(nil)
		p.imports[p.id] = pkg
	}
	for p.tok != '$' && p.tok != scanner.EOF {
		p.parseDecl()
	}
	// The export data must end with the "$$" terminator.
	if ch := p.scanner.Peek(); p.tok != '$' || ch != '$' {
		// don't call next()/expect() since reading past the
		// export data may cause scanner errors (e.g. NUL chars)
		p.errorf("expected '$$', got %s %c", scanner.TokenString(p.tok), ch)
	}
	if n := p.scanner.ErrorCount; n != 0 {
		p.errorf("expected no scanner errors, got %d", n)
	}
	return pkg
}

View File

@ -1,153 +0,0 @@
// Copyright 2011 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package types
import (
"go/ast"
"go/build"
"io/ioutil"
"os"
"os/exec"
"path/filepath"
"runtime"
"strings"
"testing"
"time"
)
var gcPath string // Go compiler path

// init determines the path of the gc compiler for the current
// architecture (8g/6g/5g in this toolchain era).
func init() {
	// determine compiler
	var gc string
	switch runtime.GOARCH {
	case "386":
		gc = "8g"
	case "amd64":
		gc = "6g"
	case "arm":
		gc = "5g"
	default:
		// Unsupported architecture: leave a sentinel path so the
		// os.Stat check in TestGcImport fails and the test is skipped.
		gcPath = "unknown-GOARCH-compiler"
		return
	}
	gcPath = filepath.Join(build.ToolDir, gc)
}
// compile runs the gc compiler on dirname/filename and returns the
// path of the resulting object file. It fails the test on a compile error.
func compile(t *testing.T, dirname, filename string) string {
	cmd := exec.Command(gcPath, filename)
	cmd.Dir = dirname
	out, err := cmd.CombinedOutput()
	if err != nil {
		t.Logf("%s", out)
		t.Fatalf("%s %s failed: %s", gcPath, filename, err)
	}
	// The object file extension is the architecture character (8/6/5).
	archCh, _ := build.ArchChar(runtime.GOARCH)
	// filename should end with ".go"
	return filepath.Join(dirname, filename[:len(filename)-2]+archCh)
}
// Use the same global imports map for all tests. The effect is
// as if all tested packages were imported into a single package.
var imports = make(map[string]*ast.Object)

// testPath imports path via GcImport and reports whether it succeeded.
func testPath(t *testing.T, path string) bool {
	_, err := GcImport(imports, path)
	if err != nil {
		t.Errorf("testPath(%s): %s", path, err)
		return false
	}
	return true
}
// maxTime bounds how long testDir keeps importing installed packages.
const maxTime = 3 * time.Second

// testDir walks the installed-package directory dir (relative to
// GOROOT/pkg/GOOS_GOARCH), importing every package file it finds until
// endTime passes, and returns the number of successful imports.
func testDir(t *testing.T, dir string, endTime time.Time) (nimports int) {
	dirname := filepath.Join(runtime.GOROOT(), "pkg", runtime.GOOS+"_"+runtime.GOARCH, dir)
	list, err := ioutil.ReadDir(dirname)
	if err != nil {
		t.Errorf("testDir(%s): %s", dirname, err)
	}
	for _, f := range list {
		// Stop early once the time budget is used up.
		if time.Now().After(endTime) {
			t.Log("testing time used up")
			return
		}
		switch {
		case !f.IsDir():
			// try extensions
			for _, ext := range pkgExts {
				if strings.HasSuffix(f.Name(), ext) {
					name := f.Name()[0 : len(f.Name())-len(ext)] // remove extension
					if testPath(t, filepath.Join(dir, name)) {
						nimports++
					}
				}
			}
		case f.IsDir():
			// Recurse into subdirectories, sharing the same deadline.
			nimports += testDir(t, filepath.Join(dir, f.Name()), endTime)
		}
	}
	return
}
// TestGcImport compiles the testdata fixture, imports it, and then
// imports as many installed packages as fit in the time budget.
func TestGcImport(t *testing.T) {
	// On cross-compile builds, the path will not exist.
	// Need to use GOHOSTOS, which is not available.
	if _, err := os.Stat(gcPath); err != nil {
		t.Logf("skipping test: %v", err)
		return
	}
	// Clean up the compiled object file when done.
	if outFn := compile(t, "testdata", "exports.go"); outFn != "" {
		defer os.Remove(outFn)
	}
	nimports := 0
	if testPath(t, "./testdata/exports") {
		nimports++
	}
	nimports += testDir(t, "", time.Now().Add(maxTime)) // installed packages
	t.Logf("tested %d imports", nimports)
}
// importedObjectTests lists qualified names together with the object
// kind and type string expected after importing them.
var importedObjectTests = []struct {
	name string
	kind ast.ObjKind
	typ  string
}{
	{"unsafe.Pointer", ast.Typ, "Pointer"},
	{"math.Pi", ast.Con, "untyped float"},
	{"io.Reader", ast.Typ, "interface{Read(p []byte) (n int, err error)}"},
	{"io.ReadWriter", ast.Typ, "interface{Read(p []byte) (n int, err error); Write(p []byte) (n int, err error)}"},
	{"math.Sin", ast.Fun, "func(x float64) (_ float64)"},
	// TODO(gri) add more tests
}

// TestGcImportedTypes checks that selected objects from imported
// packages have the expected kind and underlying type string.
func TestGcImportedTypes(t *testing.T) {
	for _, test := range importedObjectTests {
		s := strings.Split(test.name, ".")
		if len(s) != 2 {
			t.Fatal("inconsistent test data")
		}
		importPath := s[0]
		objName := s[1]
		pkg, err := GcImport(imports, importPath)
		if err != nil {
			t.Error(err)
			continue
		}
		// Look the object up in the imported package's scope.
		obj := pkg.Data.(*ast.Scope).Lookup(objName)
		if obj.Kind != test.kind {
			t.Errorf("%s: got kind = %q; want %q", test.name, obj.Kind, test.kind)
		}
		typ := typeString(underlying(obj.Type.(Type)))
		if typ != test.typ {
			t.Errorf("%s: got type = %q; want %q", test.name, typ, test.typ)
		}
	}
}

View File

@ -1,130 +0,0 @@
// Copyright 2011 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package types
import (
"fmt"
"go/ast"
"go/parser"
"go/scanner"
"go/token"
"testing"
)
// sources are small test packages whose qualified identifiers exercise
// the packages listed in pkgnames below.
var sources = []string{
	`package p
	import "fmt"
	import "math"
	const pi = math.Pi
	func sin(x float64) float64 {
		return math.Sin(x)
	}
	var Println = fmt.Println
	`,
	`package p
	import "fmt"
	func f() string {
		return fmt.Sprintf("%d", g())
	}
	`,
	`package p
	import . "go/parser"
	func g() Mode { return ImportsOnly }`,
}

// pkgnames lists the packages the sources above are expected to import.
var pkgnames = []string{
	"fmt",
	"go/parser",
	"math",
}
// ResolveQualifiedIdents resolves the selectors of qualified
// identifiers by associating the correct ast.Object with them.
// TODO(gri): Eventually, this functionality should be subsumed
// by Check.
//
func ResolveQualifiedIdents(fset *token.FileSet, pkg *ast.Package) error {
var errors scanner.ErrorList
findObj := func(pkg *ast.Object, name *ast.Ident) *ast.Object {
scope := pkg.Data.(*ast.Scope)
obj := scope.Lookup(name.Name)
if obj == nil {
errors.Add(fset.Position(name.Pos()), fmt.Sprintf("no %s in package %s", name.Name, pkg.Name))
}
return obj
}
ast.Inspect(pkg, func(n ast.Node) bool {
if s, ok := n.(*ast.SelectorExpr); ok {
if x, ok := s.X.(*ast.Ident); ok && x.Obj != nil && x.Obj.Kind == ast.Pkg {
// find selector in respective package
s.Sel.Obj = findObj(x.Obj, s.Sel)
}
return false
}
return true
})
return errors.Err()
}
// TestResolveQualifiedIdents parses the sources fixture, resolves the
// package AST, and verifies that every qualified identifier and its
// selector end up with an associated object.
func TestResolveQualifiedIdents(t *testing.T) {
	// parse package files
	fset := token.NewFileSet()
	files := make(map[string]*ast.File)
	for i, src := range sources {
		filename := fmt.Sprintf("file%d", i)
		f, err := parser.ParseFile(fset, filename, src, parser.DeclarationErrors)
		if err != nil {
			t.Fatal(err)
		}
		files[filename] = f
	}
	// resolve package AST
	pkg, err := ast.NewPackage(fset, files, GcImport, Universe)
	if err != nil {
		t.Fatal(err)
	}
	// check that all packages were imported
	for _, name := range pkgnames {
		if pkg.Imports[name] == nil {
			t.Errorf("package %s not imported", name)
		}
	}
	// check that there are no top-level unresolved identifiers
	for _, f := range pkg.Files {
		for _, x := range f.Unresolved {
			t.Errorf("%s: unresolved global identifier %s", fset.Position(x.Pos()), x.Name)
		}
	}
	// resolve qualified identifiers
	if err := ResolveQualifiedIdents(fset, pkg); err != nil {
		t.Error(err)
	}
	// check that qualified identifiers are resolved
	ast.Inspect(pkg, func(n ast.Node) bool {
		if s, ok := n.(*ast.SelectorExpr); ok {
			if x, ok := s.X.(*ast.Ident); ok {
				if x.Obj == nil {
					t.Errorf("%s: unresolved qualified identifier %s", fset.Position(x.Pos()), x.Name)
					return false
				}
				// Package-qualified selectors must have been resolved
				// by ResolveQualifiedIdents above.
				if x.Obj.Kind == ast.Pkg && s.Sel != nil && s.Sel.Obj == nil {
					t.Errorf("%s: unresolved selector %s", fset.Position(s.Sel.Pos()), s.Sel.Name)
					return false
				}
				return false
			}
			return false
		}
		return true
	})
}

View File

@ -1,89 +0,0 @@
// Copyright 2011 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
// This file is used to generate an object file which
// serves as test file for gcimporter_test.go.
package exports
import (
"go/ast"
)
// Issue 3682: Correctly read dotted identifiers from export data.
const init1 = 0

func init() {}

// Constants covering the literal kinds that appear in export data.
const (
	C0 int = 0
	C1 = 3.14159265
	C2 = 2.718281828i
	C3 = -123.456e-789
	C4 = +123.456E+789
	C5 = 1234i
	C6 = "foo\n"
	C7 = `bar\n`
)

// Types covering arrays, slices, pointers, channels, structs,
// interfaces, maps, functions, and (mutually) recursive types.
type (
	T1 int
	T2 [10]int
	T3 []int
	T4 *int
	T5 chan int
	T6a chan<- int
	T6b chan (<-chan int)
	T6c chan<- (chan int)
	T7 <-chan *ast.File
	T8 struct{}
	T9 struct {
		a int
		b, c float32
		d []string `go:"tag"`
	}
	T10 struct {
		T8
		T9
		_ *T10
	}
	T11 map[int]string
	T12 interface{}
	T13 interface {
		m1()
		m2(int) float32
	}
	T14 interface {
		T12
		T13
		m3(x ...struct{}) []T9
	}
	T15 func()
	T16 func(int)
	T17 func(x int)
	T18 func() float32
	T19 func() (x float32)
	T20 func(...interface{})
	T21 struct{ next *T21 }
	T22 struct{ link *T23 }
	T23 struct{ link *T22 }
	T24 *T24
	T25 *T26
	T26 *T27
	T27 *T25
	T28 func(T28) T28
)

// Exported variables.
var (
	V0 int
	V1 = -991.0
)

// Functions exercising various signatures; F5 and M1 are declared
// without bodies (implemented elsewhere).
func F1()         {}
func F2(x int)    {}
func F3() int     { return 0 }
func F4() float32 { return 0 }
func F5(a, b, c int, u, v, w struct{ x, y T1 }, more ...interface{}) (p, q, r chan<- T10)

func (p *T1) M1()

View File

@ -1,235 +0,0 @@
// Copyright 2011 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
// Package types declares the data structures for representing
// Go types and implements typechecking of an *ast.Package.
//
// PACKAGE UNDER CONSTRUCTION. ANY AND ALL PARTS MAY CHANGE.
//
package types
import (
"go/ast"
"go/token"
"sort"
)
// Check typechecks a package pkg. It returns the first error, or nil.
//
// Check augments the AST by assigning types to ast.Objects. It
// calls err with the error position and message for each error.
// It calls f with each valid AST expression and corresponding
// type. If err == nil, Check terminates as soon as the first error
// is found. If f is nil, it is not invoked.
//
// The implementation lives in the unexported check function.
func Check(fset *token.FileSet, pkg *ast.Package, err func(token.Pos, string), f func(ast.Expr, Type)) error {
	return check(fset, pkg, err, f)
}
// All types implement the Type interface.
// The unexported aType method restricts implementations to this package.
// TODO(gri) Eventually determine what common Type functionality should be exported.
type Type interface {
	aType()
}
// BasicKind describes the kind of basic type.
type BasicKind int

const (
	Invalid BasicKind = iota // type is invalid

	// predeclared types
	Bool
	Int
	Int8
	Int16
	Int32
	Int64
	Uint
	Uint8
	Uint16
	Uint32
	Uint64
	Uintptr
	Float32
	Float64
	Complex64
	Complex128
	String
	UnsafePointer

	// types for untyped values
	UntypedBool
	UntypedInt
	UntypedRune
	UntypedFloat
	UntypedComplex
	UntypedString
	UntypedNil

	// aliases
	Byte = Uint8
	Rune = Int32
)
// BasicInfo is a set of flags describing properties of a basic type.
type BasicInfo int

// Properties of basic types.
const (
	IsBoolean BasicInfo = 1 << iota
	IsInteger
	IsUnsigned
	IsFloat
	IsComplex
	IsString
	IsUntyped

	// Convenience unions of the flags above.
	IsOrdered = IsInteger | IsFloat | IsString
	IsNumeric = IsInteger | IsFloat | IsComplex
)
// A Basic represents a basic type.
type Basic struct {
	implementsType
	Kind BasicKind
	Info BasicInfo
	Size int64 // > 0 if valid
	Name string
}

// An Array represents an array type [Len]Elt.
type Array struct {
	implementsType
	Len int64
	Elt Type
}

// A Slice represents a slice type []Elt.
type Slice struct {
	implementsType
	Elt Type
}

// A StructField describes a single field of a struct type.
type StructField struct {
	Name string // unqualified type name for anonymous fields
	Type Type
	Tag string
	IsAnonymous bool
}

// A Struct represents a struct type struct{...}.
type Struct struct {
	implementsType
	Fields []*StructField
}

// A Pointer represents a pointer type *Base.
type Pointer struct {
	implementsType
	Base Type
}

// A tuple represents a multi-value function return.
// TODO(gri) use better name to avoid confusion (Go doesn't have tuples).
type tuple struct {
	implementsType
	list []Type
}

// A Signature represents a user-defined function type func(...) (...).
// TODO(gri) consider using "tuples" to represent parameters and results (see comment on tuples).
type Signature struct {
	implementsType
	Recv *ast.Object // nil if not a method
	Params ObjList // (incoming) parameters from left to right; or nil
	Results ObjList // (outgoing) results from left to right; or nil
	IsVariadic bool // true if the last parameter's type is of the form ...T
}
// builtinId is an id of a builtin function.
type builtinId int

// Predeclared builtin functions.
const (
	// Universe scope
	_Append builtinId = iota
	_Cap
	_Close
	_Complex
	_Copy
	_Delete
	_Imag
	_Len
	_Make
	_New
	_Panic
	_Print
	_Println
	_Real
	_Recover

	// Unsafe package
	_Alignof
	_Offsetof
	_Sizeof

	// Testing support
	_Assert
	_Trace
)

// A builtin represents the type of a built-in function.
type builtin struct {
	implementsType
	id builtinId
	name string
	nargs int // number of arguments (minimum if variadic)
	isVariadic bool
	isStatement bool // true if the built-in is valid as an expression statement
}
// An Interface represents an interface type interface{...}.
type Interface struct {
	implementsType
	Methods ObjList // interface methods sorted by name; or nil
}

// A Map represents a map type map[Key]Elt.
type Map struct {
	implementsType
	Key, Elt Type
}

// A Chan represents a channel type chan Elt, <-chan Elt, or chan<-Elt.
type Chan struct {
	implementsType
	Dir ast.ChanDir
	Elt Type
}

// A NamedType represents a named type as declared in a type declaration.
type NamedType struct {
	implementsType
	Obj *ast.Object // corresponding declared object
	Underlying Type // nil if not fully declared yet, never a *NamedType
	Methods ObjList // associated methods; or nil
}
// An ObjList represents an ordered (in some fashion) list of objects.
type ObjList []*ast.Object
// ObjList implements sort.Interface.
func (list ObjList) Len() int { return len(list) }
func (list ObjList) Less(i, j int) bool { return list[i].Name < list[j].Name }
func (list ObjList) Swap(i, j int) { list[i], list[j] = list[j], list[i] }
// Sort sorts an object list by object name.
func (list ObjList) Sort() { sort.Sort(list) }
// All concrete types embed implementsType which
// ensures that all types implement the Type interface.
// Embedding it also keeps external packages from satisfying Type.
type implementsType struct{}

func (*implementsType) aType() {}

View File

@ -1,181 +0,0 @@
// Copyright 2012 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
// This file contains tests verifying the types associated with an AST after
// type checking.
package types
import (
"go/ast"
"go/parser"
"testing"
)
// filename is the synthetic name used for all single-file test packages.
const filename = "<src>"

// makePkg parses src as a single file, resolves it into a package, and
// typechecks it, returning the resulting package or the first error.
func makePkg(t *testing.T, src string) (*ast.Package, error) {
	file, err := parser.ParseFile(fset, filename, src, parser.DeclarationErrors)
	if err != nil {
		return nil, err
	}
	files := map[string]*ast.File{filename: file}
	pkg, err := ast.NewPackage(fset, files, GcImport, Universe)
	if err != nil {
		return nil, err
	}
	if err := Check(fset, pkg, nil, nil); err != nil {
		return nil, err
	}
	return pkg, nil
}
// A testEntry pairs a source fragment with its expected string form.
type testEntry struct {
	src, str string
}

// dup returns a testEntry where both src and str are the same.
func dup(s string) testEntry {
	return testEntry{src: s, str: s}
}
// testTypes maps type-literal sources to their expected canonical
// string form as produced by typeString.
var testTypes = []testEntry{
	// basic types
	dup("int"),
	dup("float32"),
	dup("string"),
	// arrays
	dup("[10]int"),
	// slices
	dup("[]int"),
	dup("[][]int"),
	// structs
	dup("struct{}"),
	dup("struct{x int}"),
	{`struct {
		x, y int
		z float32 "foo"
	}`, `struct{x int; y int; z float32 "foo"}`},
	{`struct {
		string
		elems []T
	}`, `struct{string; elems []T}`},
	// pointers
	dup("*int"),
	dup("***struct{}"),
	dup("*struct{a int; b float32}"),
	// functions
	dup("func()"),
	dup("func(x int)"),
	{"func(x, y int)", "func(x int, y int)"},
	{"func(x, y int, z string)", "func(x int, y int, z string)"},
	dup("func(int)"),
	{"func(int, string, byte)", "func(int, string, byte)"},
	dup("func() int"),
	{"func() (string)", "func() string"},
	dup("func() (u int)"),
	{"func() (u, v int, w string)", "func() (u int, v int, w string)"},
	dup("func(int) string"),
	dup("func(x int) string"),
	dup("func(x int) (u string)"),
	{"func(x, y int) (u string)", "func(x int, y int) (u string)"},
	dup("func(...int) string"),
	dup("func(x ...int) string"),
	dup("func(x ...int) (u string)"),
	{"func(x, y ...int) (u string)", "func(x int, y ...int) (u string)"},
	// interfaces
	dup("interface{}"),
	dup("interface{m()}"),
	{`interface{
		m(int) float32
		String() string
	}`, `interface{String() string; m(int) float32}`}, // methods are sorted
	// TODO(gri) add test for interface w/ anonymous field
	// maps
	dup("map[string]int"),
	{"map[struct{x, y int}][]byte", "map[struct{x int; y int}][]byte"},
	// channels
	dup("chan int"),
	dup("chan<- func()"),
	dup("<-chan []func() int"),
}
// TestTypes typechecks each testTypes entry as a type declaration and
// compares the underlying type's string form with the expected one.
func TestTypes(t *testing.T) {
	for _, test := range testTypes {
		src := "package p; type T " + test.src
		pkg, err := makePkg(t, src)
		if err != nil {
			t.Errorf("%s: %s", src, err)
			continue
		}
		typ := underlying(pkg.Scope.Lookup("T").Type.(Type))
		str := typeString(typ)
		if str != test.str {
			t.Errorf("%s: got %s, want %s", test.src, str, test.str)
		}
	}
}
// testExprs maps expression sources to their expected string form as
// produced by exprString.
var testExprs = []testEntry{
	// basic type literals
	dup("x"),
	dup("true"),
	dup("42"),
	dup("3.1415"),
	dup("2.71828i"),
	dup(`'a'`),
	dup(`"foo"`),
	dup("`bar`"),
	// arbitrary expressions
	dup("&x"),
	dup("*x"),
	dup("(x)"),
	dup("x + y"),
	dup("x + y * 10"),
	dup("s.foo"),
	dup("s[0]"),
	dup("s[x:y]"),
	dup("s[:y]"),
	dup("s[x:]"),
	dup("s[:]"),
	dup("f(1, 2.3)"),
	dup("-f(10, 20)"),
	dup("f(x + y, +3.1415)"),
	{"func(a, b int) {}", "(func literal)"},
	{"func(a, b int) []int {}()[x]", "(func literal)()[x]"},
	{"[]int{1, 2, 3}", "(composite literal)"},
	{"[]int{1, 2, 3}[x:]", "(composite literal)[x:]"},
	{"x.([]string)", "x.(...)"},
}
// TestExprs typechecks each testExprs entry as a package-level value
// and compares its printed expression with the expected string.
func TestExprs(t *testing.T) {
	for _, test := range testExprs {
		src := "package p; var _ = " + test.src + "; var (x, y int; s []string; f func(int, float32))"
		pkg, err := makePkg(t, src)
		if err != nil {
			t.Errorf("%s: %s", src, err)
			continue
		}
		// TODO(gri) writing the code below w/o the decl variable will
		// cause a 386 compiler error (out of fixed registers)
		decl := pkg.Files[filename].Decls[0].(*ast.GenDecl)
		expr := decl.Specs[0].(*ast.ValueSpec).Values[0]
		str := exprString(expr)
		if str != test.str {
			t.Errorf("%s: got %s, want %s", test.src, str, test.str)
		}
	}
}

Some files were not shown because too many files have changed in this diff Show More