libgo: Update to weekly.2011-11-18.

From-SVN: r182266
This commit is contained in:
Ian Lance Taylor 2011-12-12 23:40:51 +00:00
parent 6e456f4cf4
commit ab61e9c4da
223 changed files with 6373 additions and 3999 deletions

View File

@ -1,4 +1,4 @@
2f4482b89a6b b4a91b693374
The first line of this file holds the Mercurial revision number of the The first line of this file holds the Mercurial revision number of the
last merge done from the master library sources. last merge done from the master library sources.

View File

@ -648,7 +648,8 @@ go_math_files = \
go_mime_files = \ go_mime_files = \
go/mime/grammar.go \ go/mime/grammar.go \
go/mime/mediatype.go \ go/mime/mediatype.go \
go/mime/type.go go/mime/type.go \
go/mime/type_unix.go
if LIBGO_IS_RTEMS if LIBGO_IS_RTEMS
go_net_fd_os_file = go/net/fd_select.go go_net_fd_os_file = go/net/fd_select.go
@ -770,7 +771,6 @@ go_os_files = \
$(go_os_dir_file) \ $(go_os_dir_file) \
go/os/dir.go \ go/os/dir.go \
go/os/env.go \ go/os/env.go \
go/os/env_unix.go \
go/os/error.go \ go/os/error.go \
go/os/error_posix.go \ go/os/error_posix.go \
go/os/exec.go \ go/os/exec.go \
@ -1156,6 +1156,7 @@ go_exp_sql_files = \
go/exp/sql/sql.go go/exp/sql/sql.go
go_exp_ssh_files = \ go_exp_ssh_files = \
go/exp/ssh/channel.go \ go/exp/ssh/channel.go \
go/exp/ssh/cipher.go \
go/exp/ssh/client.go \ go/exp/ssh/client.go \
go/exp/ssh/client_auth.go \ go/exp/ssh/client_auth.go \
go/exp/ssh/common.go \ go/exp/ssh/common.go \
@ -1164,10 +1165,11 @@ go_exp_ssh_files = \
go/exp/ssh/server.go \ go/exp/ssh/server.go \
go/exp/ssh/server_shell.go \ go/exp/ssh/server_shell.go \
go/exp/ssh/session.go \ go/exp/ssh/session.go \
go/exp/ssh/tcpip.go \
go/exp/ssh/transport.go go/exp/ssh/transport.go
go_exp_terminal_files = \ go_exp_terminal_files = \
go/exp/terminal/shell.go \ go/exp/terminal/terminal.go \
go/exp/terminal/terminal.go go/exp/terminal/util.go
go_exp_types_files = \ go_exp_types_files = \
go/exp/types/check.go \ go/exp/types/check.go \
go/exp/types/const.go \ go/exp/types/const.go \
@ -1546,6 +1548,7 @@ syscall_netlink_file =
endif endif
go_base_syscall_files = \ go_base_syscall_files = \
go/syscall/env_unix.go \
go/syscall/libcall_support.go \ go/syscall/libcall_support.go \
go/syscall/libcall_posix.go \ go/syscall/libcall_posix.go \
go/syscall/socket.go \ go/syscall/socket.go \

View File

@ -1032,7 +1032,8 @@ go_math_files = \
go_mime_files = \ go_mime_files = \
go/mime/grammar.go \ go/mime/grammar.go \
go/mime/mediatype.go \ go/mime/mediatype.go \
go/mime/type.go go/mime/type.go \
go/mime/type_unix.go
# By default use select with pipes. Most systems should have # By default use select with pipes. Most systems should have
# something better. # something better.
@ -1103,7 +1104,6 @@ go_os_files = \
$(go_os_dir_file) \ $(go_os_dir_file) \
go/os/dir.go \ go/os/dir.go \
go/os/env.go \ go/os/env.go \
go/os/env_unix.go \
go/os/error.go \ go/os/error.go \
go/os/error_posix.go \ go/os/error_posix.go \
go/os/exec.go \ go/os/exec.go \
@ -1521,6 +1521,7 @@ go_exp_sql_files = \
go_exp_ssh_files = \ go_exp_ssh_files = \
go/exp/ssh/channel.go \ go/exp/ssh/channel.go \
go/exp/ssh/cipher.go \
go/exp/ssh/client.go \ go/exp/ssh/client.go \
go/exp/ssh/client_auth.go \ go/exp/ssh/client_auth.go \
go/exp/ssh/common.go \ go/exp/ssh/common.go \
@ -1529,11 +1530,12 @@ go_exp_ssh_files = \
go/exp/ssh/server.go \ go/exp/ssh/server.go \
go/exp/ssh/server_shell.go \ go/exp/ssh/server_shell.go \
go/exp/ssh/session.go \ go/exp/ssh/session.go \
go/exp/ssh/tcpip.go \
go/exp/ssh/transport.go go/exp/ssh/transport.go
go_exp_terminal_files = \ go_exp_terminal_files = \
go/exp/terminal/shell.go \ go/exp/terminal/terminal.go \
go/exp/terminal/terminal.go go/exp/terminal/util.go
go_exp_types_files = \ go_exp_types_files = \
go/exp/types/check.go \ go/exp/types/check.go \
@ -1890,6 +1892,7 @@ go_unicode_utf8_files = \
# Support for netlink sockets and messages. # Support for netlink sockets and messages.
@LIBGO_IS_LINUX_TRUE@syscall_netlink_file = go/syscall/netlink_linux.go @LIBGO_IS_LINUX_TRUE@syscall_netlink_file = go/syscall/netlink_linux.go
go_base_syscall_files = \ go_base_syscall_files = \
go/syscall/env_unix.go \
go/syscall/libcall_support.go \ go/syscall/libcall_support.go \
go/syscall/libcall_posix.go \ go/syscall/libcall_posix.go \
go/syscall/socket.go \ go/syscall/socket.go \

View File

@ -10,7 +10,6 @@ import (
"fmt" "fmt"
"io" "io"
"io/ioutil" "io/ioutil"
"os"
"strings" "strings"
"testing" "testing"
"testing/iotest" "testing/iotest"
@ -425,9 +424,9 @@ var errorWriterTests = []errorWriterTest{
{0, 1, nil, io.ErrShortWrite}, {0, 1, nil, io.ErrShortWrite},
{1, 2, nil, io.ErrShortWrite}, {1, 2, nil, io.ErrShortWrite},
{1, 1, nil, nil}, {1, 1, nil, nil},
{0, 1, os.EPIPE, os.EPIPE}, {0, 1, io.ErrClosedPipe, io.ErrClosedPipe},
{1, 2, os.EPIPE, os.EPIPE}, {1, 2, io.ErrClosedPipe, io.ErrClosedPipe},
{1, 1, os.EPIPE, os.EPIPE}, {1, 1, io.ErrClosedPipe, io.ErrClosedPipe},
} }
func TestWriteErrors(t *testing.T) { func TestWriteErrors(t *testing.T) {

View File

@ -91,6 +91,11 @@ type rune rune
// invocation. // invocation.
type Type int type Type int
// Type1 is here for the purposes of documentation only. It is a stand-in
// for any Go type, but represents the same type for any given function
// invocation.
type Type1 int
// IntegerType is here for the purposes of documentation only. It is a stand-in // IntegerType is here for the purposes of documentation only. It is a stand-in
// for any integer type: int, uint, int8 etc. // for any integer type: int, uint, int8 etc.
type IntegerType int type IntegerType int
@ -119,6 +124,11 @@ func append(slice []Type, elems ...Type) []Type
// len(src) and len(dst). // len(src) and len(dst).
func copy(dst, src []Type) int func copy(dst, src []Type) int
// The delete built-in function deletes the element with the specified key
// (m[key]) from the map. If there is no such element, delete is a no-op.
// If m is nil, delete panics.
func delete(m map[Type]Type1, key Type)
// The len built-in function returns the length of v, according to its type: // The len built-in function returns the length of v, according to its type:
// Array: the number of elements in v. // Array: the number of elements in v.
// Pointer to array: the number of elements in *v (even if v is nil). // Pointer to array: the number of elements in *v (even if v is nil).
@ -171,7 +181,7 @@ func complex(r, i FloatType) ComplexType
// The return value will be floating point type corresponding to the type of c. // The return value will be floating point type corresponding to the type of c.
func real(c ComplexType) FloatType func real(c ComplexType) FloatType
// The imaginary built-in function returns the imaginary part of the complex // The imag built-in function returns the imaginary part of the complex
// number c. The return value will be floating point type corresponding to // number c. The return value will be floating point type corresponding to
// the type of c. // the type of c.
func imag(c ComplexType) FloatType func imag(c ComplexType) FloatType

View File

@ -662,48 +662,49 @@ func TestRunes(t *testing.T) {
} }
type TrimTest struct { type TrimTest struct {
f func([]byte, string) []byte f string
in, cutset, out string in, cutset, out string
} }
var trimTests = []TrimTest{ var trimTests = []TrimTest{
{Trim, "abba", "a", "bb"}, {"Trim", "abba", "a", "bb"},
{Trim, "abba", "ab", ""}, {"Trim", "abba", "ab", ""},
{TrimLeft, "abba", "ab", ""}, {"TrimLeft", "abba", "ab", ""},
{TrimRight, "abba", "ab", ""}, {"TrimRight", "abba", "ab", ""},
{TrimLeft, "abba", "a", "bba"}, {"TrimLeft", "abba", "a", "bba"},
{TrimRight, "abba", "a", "abb"}, {"TrimRight", "abba", "a", "abb"},
{Trim, "<tag>", "<>", "tag"}, {"Trim", "<tag>", "<>", "tag"},
{Trim, "* listitem", " *", "listitem"}, {"Trim", "* listitem", " *", "listitem"},
{Trim, `"quote"`, `"`, "quote"}, {"Trim", `"quote"`, `"`, "quote"},
{Trim, "\u2C6F\u2C6F\u0250\u0250\u2C6F\u2C6F", "\u2C6F", "\u0250\u0250"}, {"Trim", "\u2C6F\u2C6F\u0250\u0250\u2C6F\u2C6F", "\u2C6F", "\u0250\u0250"},
//empty string tests //empty string tests
{Trim, "abba", "", "abba"}, {"Trim", "abba", "", "abba"},
{Trim, "", "123", ""}, {"Trim", "", "123", ""},
{Trim, "", "", ""}, {"Trim", "", "", ""},
{TrimLeft, "abba", "", "abba"}, {"TrimLeft", "abba", "", "abba"},
{TrimLeft, "", "123", ""}, {"TrimLeft", "", "123", ""},
{TrimLeft, "", "", ""}, {"TrimLeft", "", "", ""},
{TrimRight, "abba", "", "abba"}, {"TrimRight", "abba", "", "abba"},
{TrimRight, "", "123", ""}, {"TrimRight", "", "123", ""},
{TrimRight, "", "", ""}, {"TrimRight", "", "", ""},
{TrimRight, "☺\xc0", "☺", "☺\xc0"}, {"TrimRight", "☺\xc0", "☺", "☺\xc0"},
} }
func TestTrim(t *testing.T) { func TestTrim(t *testing.T) {
for _, tc := range trimTests { for _, tc := range trimTests {
actual := string(tc.f([]byte(tc.in), tc.cutset)) name := tc.f
var name string var f func([]byte, string) []byte
switch tc.f { switch name {
case Trim: case "Trim":
name = "Trim" f = Trim
case TrimLeft: case "TrimLeft":
name = "TrimLeft" f = TrimLeft
case TrimRight: case "TrimRight":
name = "TrimRight" f = TrimRight
default: default:
t.Error("Undefined trim function") t.Error("Undefined trim function %s", name)
} }
actual := string(f([]byte(tc.in), tc.cutset))
if actual != tc.out { if actual != tc.out {
t.Errorf("%s(%q, %q) = %q; want %q", name, tc.in, tc.cutset, actual, tc.out) t.Errorf("%s(%q, %q) = %q; want %q", name, tc.in, tc.cutset, actual, tc.out)
} }

View File

@ -19,7 +19,6 @@ import (
"errors" "errors"
"fmt" "fmt"
"io" "io"
"os"
) )
// Order specifies the bit ordering in an LZW data stream. // Order specifies the bit ordering in an LZW data stream.
@ -212,8 +211,10 @@ func (d *decoder) flush() {
d.o = 0 d.o = 0
} }
var errClosed = errors.New("compress/lzw: reader/writer is closed")
func (d *decoder) Close() error { func (d *decoder) Close() error {
d.err = os.EINVAL // in case any Reads come along d.err = errClosed // in case any Reads come along
return nil return nil
} }

View File

@ -9,7 +9,6 @@ import (
"errors" "errors"
"fmt" "fmt"
"io" "io"
"os"
) )
// A writer is a buffered, flushable writer. // A writer is a buffered, flushable writer.
@ -49,8 +48,9 @@ const (
type encoder struct { type encoder struct {
// w is the writer that compressed bytes are written to. // w is the writer that compressed bytes are written to.
w writer w writer
// write, bits, nBits and width are the state for converting a code stream // order, write, bits, nBits and width are the state for
// into a byte stream. // converting a code stream into a byte stream.
order Order
write func(*encoder, uint32) error write func(*encoder, uint32) error
bits uint32 bits uint32
nBits uint nBits uint
@ -64,7 +64,7 @@ type encoder struct {
// call. It is equal to invalidCode if there was no such call. // call. It is equal to invalidCode if there was no such call.
savedCode uint32 savedCode uint32
// err is the first error encountered during writing. Closing the encoder // err is the first error encountered during writing. Closing the encoder
// will make any future Write calls return os.EINVAL. // will make any future Write calls return errClosed
err error err error
// table is the hash table from 20-bit keys to 12-bit values. Each table // table is the hash table from 20-bit keys to 12-bit values. Each table
// entry contains key<<12|val and collisions resolve by linear probing. // entry contains key<<12|val and collisions resolve by linear probing.
@ -191,13 +191,13 @@ loop:
// flush e's underlying writer. // flush e's underlying writer.
func (e *encoder) Close() error { func (e *encoder) Close() error {
if e.err != nil { if e.err != nil {
if e.err == os.EINVAL { if e.err == errClosed {
return nil return nil
} }
return e.err return e.err
} }
// Make any future calls to Write return os.EINVAL. // Make any future calls to Write return errClosed.
e.err = os.EINVAL e.err = errClosed
// Write the savedCode if valid. // Write the savedCode if valid.
if e.savedCode != invalidCode { if e.savedCode != invalidCode {
if err := e.write(e, e.savedCode); err != nil { if err := e.write(e, e.savedCode); err != nil {
@ -214,7 +214,7 @@ func (e *encoder) Close() error {
} }
// Write the final bits. // Write the final bits.
if e.nBits > 0 { if e.nBits > 0 {
if e.write == (*encoder).writeMSB { if e.order == MSB {
e.bits >>= 24 e.bits >>= 24
} }
if err := e.w.WriteByte(uint8(e.bits)); err != nil { if err := e.w.WriteByte(uint8(e.bits)); err != nil {
@ -250,6 +250,7 @@ func NewWriter(w io.Writer, order Order, litWidth int) io.WriteCloser {
lw := uint(litWidth) lw := uint(litWidth)
return &encoder{ return &encoder{
w: bw, w: bw,
order: order,
write: write, write: write,
width: 1 + lw, width: 1 + lw,
litWidth: lw, litWidth: lw,

View File

@ -50,10 +50,6 @@ func testFile(t *testing.T, fn string, order Order, litWidth int) {
return return
} }
_, err1 := lzww.Write(b[:n]) _, err1 := lzww.Write(b[:n])
if err1 == os.EPIPE {
// Fail, but do not report the error, as some other (presumably reportable) error broke the pipe.
return
}
if err1 != nil { if err1 != nil {
t.Errorf("%s (order=%d litWidth=%d): %v", fn, order, litWidth, err1) t.Errorf("%s (order=%d litWidth=%d): %v", fn, order, litWidth, err1)
return return

View File

@ -59,10 +59,6 @@ func testLevelDict(t *testing.T, fn string, b0 []byte, level int, d string) {
} }
defer zlibw.Close() defer zlibw.Close()
_, err = zlibw.Write(b0) _, err = zlibw.Write(b0)
if err == os.EPIPE {
// Fail, but do not report the error, as some other (presumably reported) error broke the pipe.
return
}
if err != nil { if err != nil {
t.Errorf("%s (level=%d, dict=%q): %v", fn, level, d, err) t.Errorf("%s (level=%d, dict=%q): %v", fn, level, d, err)
return return

View File

@ -41,7 +41,7 @@ func NewCipher(key []byte) (*Cipher, error) {
} }
// BlockSize returns the AES block size, 16 bytes. // BlockSize returns the AES block size, 16 bytes.
// It is necessary to satisfy the Cipher interface in the // It is necessary to satisfy the Block interface in the
// package "crypto/cipher". // package "crypto/cipher".
func (c *Cipher) BlockSize() int { return BlockSize } func (c *Cipher) BlockSize() int { return BlockSize }

View File

@ -54,7 +54,7 @@ func NewSaltedCipher(key, salt []byte) (*Cipher, error) {
} }
// BlockSize returns the Blowfish block size, 8 bytes. // BlockSize returns the Blowfish block size, 8 bytes.
// It is necessary to satisfy the Cipher interface in the // It is necessary to satisfy the Block interface in the
// package "crypto/cipher". // package "crypto/cipher".
func (c *Cipher) BlockSize() int { return BlockSize } func (c *Cipher) BlockSize() int { return BlockSize }

View File

@ -28,16 +28,16 @@ func (r *rngReader) Read(b []byte) (n int, err error) {
if r.prov == 0 { if r.prov == 0 {
const provType = syscall.PROV_RSA_FULL const provType = syscall.PROV_RSA_FULL
const flags = syscall.CRYPT_VERIFYCONTEXT | syscall.CRYPT_SILENT const flags = syscall.CRYPT_VERIFYCONTEXT | syscall.CRYPT_SILENT
errno := syscall.CryptAcquireContext(&r.prov, nil, nil, provType, flags) err := syscall.CryptAcquireContext(&r.prov, nil, nil, provType, flags)
if errno != 0 { if err != nil {
r.mu.Unlock() r.mu.Unlock()
return 0, os.NewSyscallError("CryptAcquireContext", errno) return 0, os.NewSyscallError("CryptAcquireContext", err)
} }
} }
r.mu.Unlock() r.mu.Unlock()
errno := syscall.CryptGenRandom(r.prov, uint32(len(b)), &b[0]) err = syscall.CryptGenRandom(r.prov, uint32(len(b)), &b[0])
if errno != 0 { if err != nil {
return 0, os.NewSyscallError("CryptGenRandom", errno) return 0, os.NewSyscallError("CryptGenRandom", err)
} }
return len(b), nil return len(b), nil
} }

View File

@ -5,16 +5,16 @@
package rand package rand
import ( import (
"errors"
"io" "io"
"math/big" "math/big"
"os"
) )
// Prime returns a number, p, of the given size, such that p is prime // Prime returns a number, p, of the given size, such that p is prime
// with high probability. // with high probability.
func Prime(rand io.Reader, bits int) (p *big.Int, err error) { func Prime(rand io.Reader, bits int) (p *big.Int, err error) {
if bits < 1 { if bits < 1 {
err = os.EINVAL err = errors.New("crypto/rand: prime size must be positive")
} }
b := uint(bits % 8) b := uint(bits % 8)

View File

@ -93,7 +93,8 @@ func (c *Conn) SetTimeout(nsec int64) error {
} }
// SetReadTimeout sets the time (in nanoseconds) that // SetReadTimeout sets the time (in nanoseconds) that
// Read will wait for data before returning os.EAGAIN. // Read will wait for data before returning a net.Error
// with Timeout() == true.
// Setting nsec == 0 (the default) disables the deadline. // Setting nsec == 0 (the default) disables the deadline.
func (c *Conn) SetReadTimeout(nsec int64) error { func (c *Conn) SetReadTimeout(nsec int64) error {
return c.conn.SetReadTimeout(nsec) return c.conn.SetReadTimeout(nsec)
@ -737,7 +738,7 @@ func (c *Conn) Write(b []byte) (n int, err error) {
return c.writeRecord(recordTypeApplicationData, b) return c.writeRecord(recordTypeApplicationData, b)
} }
// Read can be made to time out and return err == os.EAGAIN // Read can be made to time out and return a net.Error with Timeout() == true
// after a fixed time limit; see SetTimeout and SetReadTimeout. // after a fixed time limit; see SetTimeout and SetReadTimeout.
func (c *Conn) Read(b []byte) (n int, err error) { func (c *Conn) Read(b []byte) (n int, err error) {
if err = c.Handshake(); err != nil { if err = c.Handshake(); err != nil {

View File

@ -4,6 +4,8 @@
package tls package tls
import "bytes"
type clientHelloMsg struct { type clientHelloMsg struct {
raw []byte raw []byte
vers uint16 vers uint16
@ -18,6 +20,25 @@ type clientHelloMsg struct {
supportedPoints []uint8 supportedPoints []uint8
} }
func (m *clientHelloMsg) equal(i interface{}) bool {
m1, ok := i.(*clientHelloMsg)
if !ok {
return false
}
return bytes.Equal(m.raw, m1.raw) &&
m.vers == m1.vers &&
bytes.Equal(m.random, m1.random) &&
bytes.Equal(m.sessionId, m1.sessionId) &&
eqUint16s(m.cipherSuites, m1.cipherSuites) &&
bytes.Equal(m.compressionMethods, m1.compressionMethods) &&
m.nextProtoNeg == m1.nextProtoNeg &&
m.serverName == m1.serverName &&
m.ocspStapling == m1.ocspStapling &&
eqUint16s(m.supportedCurves, m1.supportedCurves) &&
bytes.Equal(m.supportedPoints, m1.supportedPoints)
}
func (m *clientHelloMsg) marshal() []byte { func (m *clientHelloMsg) marshal() []byte {
if m.raw != nil { if m.raw != nil {
return m.raw return m.raw
@ -309,6 +330,23 @@ type serverHelloMsg struct {
ocspStapling bool ocspStapling bool
} }
func (m *serverHelloMsg) equal(i interface{}) bool {
m1, ok := i.(*serverHelloMsg)
if !ok {
return false
}
return bytes.Equal(m.raw, m1.raw) &&
m.vers == m1.vers &&
bytes.Equal(m.random, m1.random) &&
bytes.Equal(m.sessionId, m1.sessionId) &&
m.cipherSuite == m1.cipherSuite &&
m.compressionMethod == m1.compressionMethod &&
m.nextProtoNeg == m1.nextProtoNeg &&
eqStrings(m.nextProtos, m1.nextProtos) &&
m.ocspStapling == m1.ocspStapling
}
func (m *serverHelloMsg) marshal() []byte { func (m *serverHelloMsg) marshal() []byte {
if m.raw != nil { if m.raw != nil {
return m.raw return m.raw
@ -463,6 +501,16 @@ type certificateMsg struct {
certificates [][]byte certificates [][]byte
} }
func (m *certificateMsg) equal(i interface{}) bool {
m1, ok := i.(*certificateMsg)
if !ok {
return false
}
return bytes.Equal(m.raw, m1.raw) &&
eqByteSlices(m.certificates, m1.certificates)
}
func (m *certificateMsg) marshal() (x []byte) { func (m *certificateMsg) marshal() (x []byte) {
if m.raw != nil { if m.raw != nil {
return m.raw return m.raw
@ -540,6 +588,16 @@ type serverKeyExchangeMsg struct {
key []byte key []byte
} }
func (m *serverKeyExchangeMsg) equal(i interface{}) bool {
m1, ok := i.(*serverKeyExchangeMsg)
if !ok {
return false
}
return bytes.Equal(m.raw, m1.raw) &&
bytes.Equal(m.key, m1.key)
}
func (m *serverKeyExchangeMsg) marshal() []byte { func (m *serverKeyExchangeMsg) marshal() []byte {
if m.raw != nil { if m.raw != nil {
return m.raw return m.raw
@ -571,6 +629,17 @@ type certificateStatusMsg struct {
response []byte response []byte
} }
func (m *certificateStatusMsg) equal(i interface{}) bool {
m1, ok := i.(*certificateStatusMsg)
if !ok {
return false
}
return bytes.Equal(m.raw, m1.raw) &&
m.statusType == m1.statusType &&
bytes.Equal(m.response, m1.response)
}
func (m *certificateStatusMsg) marshal() []byte { func (m *certificateStatusMsg) marshal() []byte {
if m.raw != nil { if m.raw != nil {
return m.raw return m.raw
@ -622,6 +691,11 @@ func (m *certificateStatusMsg) unmarshal(data []byte) bool {
type serverHelloDoneMsg struct{} type serverHelloDoneMsg struct{}
func (m *serverHelloDoneMsg) equal(i interface{}) bool {
_, ok := i.(*serverHelloDoneMsg)
return ok
}
func (m *serverHelloDoneMsg) marshal() []byte { func (m *serverHelloDoneMsg) marshal() []byte {
x := make([]byte, 4) x := make([]byte, 4)
x[0] = typeServerHelloDone x[0] = typeServerHelloDone
@ -637,6 +711,16 @@ type clientKeyExchangeMsg struct {
ciphertext []byte ciphertext []byte
} }
func (m *clientKeyExchangeMsg) equal(i interface{}) bool {
m1, ok := i.(*clientKeyExchangeMsg)
if !ok {
return false
}
return bytes.Equal(m.raw, m1.raw) &&
bytes.Equal(m.ciphertext, m1.ciphertext)
}
func (m *clientKeyExchangeMsg) marshal() []byte { func (m *clientKeyExchangeMsg) marshal() []byte {
if m.raw != nil { if m.raw != nil {
return m.raw return m.raw
@ -671,6 +755,16 @@ type finishedMsg struct {
verifyData []byte verifyData []byte
} }
func (m *finishedMsg) equal(i interface{}) bool {
m1, ok := i.(*finishedMsg)
if !ok {
return false
}
return bytes.Equal(m.raw, m1.raw) &&
bytes.Equal(m.verifyData, m1.verifyData)
}
func (m *finishedMsg) marshal() (x []byte) { func (m *finishedMsg) marshal() (x []byte) {
if m.raw != nil { if m.raw != nil {
return m.raw return m.raw
@ -698,6 +792,16 @@ type nextProtoMsg struct {
proto string proto string
} }
func (m *nextProtoMsg) equal(i interface{}) bool {
m1, ok := i.(*nextProtoMsg)
if !ok {
return false
}
return bytes.Equal(m.raw, m1.raw) &&
m.proto == m1.proto
}
func (m *nextProtoMsg) marshal() []byte { func (m *nextProtoMsg) marshal() []byte {
if m.raw != nil { if m.raw != nil {
return m.raw return m.raw
@ -759,6 +863,17 @@ type certificateRequestMsg struct {
certificateAuthorities [][]byte certificateAuthorities [][]byte
} }
func (m *certificateRequestMsg) equal(i interface{}) bool {
m1, ok := i.(*certificateRequestMsg)
if !ok {
return false
}
return bytes.Equal(m.raw, m1.raw) &&
bytes.Equal(m.certificateTypes, m1.certificateTypes) &&
eqByteSlices(m.certificateAuthorities, m1.certificateAuthorities)
}
func (m *certificateRequestMsg) marshal() (x []byte) { func (m *certificateRequestMsg) marshal() (x []byte) {
if m.raw != nil { if m.raw != nil {
return m.raw return m.raw
@ -859,6 +974,16 @@ type certificateVerifyMsg struct {
signature []byte signature []byte
} }
func (m *certificateVerifyMsg) equal(i interface{}) bool {
m1, ok := i.(*certificateVerifyMsg)
if !ok {
return false
}
return bytes.Equal(m.raw, m1.raw) &&
bytes.Equal(m.signature, m1.signature)
}
func (m *certificateVerifyMsg) marshal() (x []byte) { func (m *certificateVerifyMsg) marshal() (x []byte) {
if m.raw != nil { if m.raw != nil {
return m.raw return m.raw
@ -902,3 +1027,39 @@ func (m *certificateVerifyMsg) unmarshal(data []byte) bool {
return true return true
} }
func eqUint16s(x, y []uint16) bool {
if len(x) != len(y) {
return false
}
for i, v := range x {
if y[i] != v {
return false
}
}
return true
}
func eqStrings(x, y []string) bool {
if len(x) != len(y) {
return false
}
for i, v := range x {
if y[i] != v {
return false
}
}
return true
}
func eqByteSlices(x, y [][]byte) bool {
if len(x) != len(y) {
return false
}
for i, v := range x {
if !bytes.Equal(v, y[i]) {
return false
}
}
return true
}

View File

@ -27,10 +27,12 @@ var tests = []interface{}{
type testMessage interface { type testMessage interface {
marshal() []byte marshal() []byte
unmarshal([]byte) bool unmarshal([]byte) bool
equal(interface{}) bool
} }
func TestMarshalUnmarshal(t *testing.T) { func TestMarshalUnmarshal(t *testing.T) {
rand := rand.New(rand.NewSource(0)) rand := rand.New(rand.NewSource(0))
for i, iface := range tests { for i, iface := range tests {
ty := reflect.ValueOf(iface).Type() ty := reflect.ValueOf(iface).Type()
@ -54,7 +56,7 @@ func TestMarshalUnmarshal(t *testing.T) {
} }
m2.marshal() // to fill any marshal cache in the message m2.marshal() // to fill any marshal cache in the message
if !reflect.DeepEqual(m1, m2) { if !m1.equal(m2) {
t.Errorf("#%d got:%#v want:%#v %x", i, m2, m1, marshaled) t.Errorf("#%d got:%#v want:%#v %x", i, m2, m1, marshaled)
break break
} }

View File

@ -12,8 +12,8 @@ import (
) )
func loadStore(roots *x509.CertPool, name string) { func loadStore(roots *x509.CertPool, name string) {
store, errno := syscall.CertOpenSystemStore(syscall.InvalidHandle, syscall.StringToUTF16Ptr(name)) store, err := syscall.CertOpenSystemStore(syscall.InvalidHandle, syscall.StringToUTF16Ptr(name))
if errno != 0 { if err != nil {
return return
} }

View File

@ -44,7 +44,7 @@ func NewCipher(key []byte) (*Cipher, error) {
} }
// BlockSize returns the XTEA block size, 8 bytes. // BlockSize returns the XTEA block size, 8 bytes.
// It is necessary to satisfy the Cipher interface in the // It is necessary to satisfy the Block interface in the
// package "crypto/cipher". // package "crypto/cipher".
func (c *Cipher) BlockSize() int { return BlockSize } func (c *Cipher) BlockSize() int { return BlockSize }

View File

@ -0,0 +1,157 @@
// Copyright 2011 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
// Large data benchmark.
// The JSON data is a summary of agl's changes in the
// go, webkit, and chromium open source projects.
// We benchmark converting between the JSON form
// and in-memory data structures.
package json
import (
"bytes"
"compress/gzip"
"io/ioutil"
"os"
"testing"
)
type codeResponse struct {
Tree *codeNode `json:"tree"`
Username string `json:"username"`
}
type codeNode struct {
Name string `json:"name"`
Kids []*codeNode `json:"kids"`
CLWeight float64 `json:"cl_weight"`
Touches int `json:"touches"`
MinT int64 `json:"min_t"`
MaxT int64 `json:"max_t"`
MeanT int64 `json:"mean_t"`
}
var codeJSON []byte
var codeStruct codeResponse
func codeInit() {
f, err := os.Open("testdata/code.json.gz")
if err != nil {
panic(err)
}
defer f.Close()
gz, err := gzip.NewReader(f)
if err != nil {
panic(err)
}
data, err := ioutil.ReadAll(gz)
if err != nil {
panic(err)
}
codeJSON = data
if err := Unmarshal(codeJSON, &codeStruct); err != nil {
panic("unmarshal code.json: " + err.Error())
}
if data, err = Marshal(&codeStruct); err != nil {
panic("marshal code.json: " + err.Error())
}
if !bytes.Equal(data, codeJSON) {
println("different lengths", len(data), len(codeJSON))
for i := 0; i < len(data) && i < len(codeJSON); i++ {
if data[i] != codeJSON[i] {
println("re-marshal: changed at byte", i)
println("orig: ", string(codeJSON[i-10:i+10]))
println("new: ", string(data[i-10:i+10]))
break
}
}
panic("re-marshal code.json: different result")
}
}
func BenchmarkCodeEncoder(b *testing.B) {
if codeJSON == nil {
b.StopTimer()
codeInit()
b.StartTimer()
}
enc := NewEncoder(ioutil.Discard)
for i := 0; i < b.N; i++ {
if err := enc.Encode(&codeStruct); err != nil {
panic(err)
}
}
b.SetBytes(int64(len(codeJSON)))
}
func BenchmarkCodeMarshal(b *testing.B) {
if codeJSON == nil {
b.StopTimer()
codeInit()
b.StartTimer()
}
for i := 0; i < b.N; i++ {
if _, err := Marshal(&codeStruct); err != nil {
panic(err)
}
}
b.SetBytes(int64(len(codeJSON)))
}
func BenchmarkCodeDecoder(b *testing.B) {
if codeJSON == nil {
b.StopTimer()
codeInit()
b.StartTimer()
}
var buf bytes.Buffer
dec := NewDecoder(&buf)
var r codeResponse
for i := 0; i < b.N; i++ {
buf.Write(codeJSON)
// hide EOF
buf.WriteByte('\n')
buf.WriteByte('\n')
buf.WriteByte('\n')
if err := dec.Decode(&r); err != nil {
panic(err)
}
}
b.SetBytes(int64(len(codeJSON)))
}
func BenchmarkCodeUnmarshal(b *testing.B) {
if codeJSON == nil {
b.StopTimer()
codeInit()
b.StartTimer()
}
for i := 0; i < b.N; i++ {
var r codeResponse
if err := Unmarshal(codeJSON, &r); err != nil {
panic(err)
}
}
b.SetBytes(int64(len(codeJSON)))
}
func BenchmarkCodeUnmarshalReuse(b *testing.B) {
if codeJSON == nil {
b.StopTimer()
codeInit()
b.StartTimer()
}
var r codeResponse
for i := 0; i < b.N; i++ {
if err := Unmarshal(codeJSON, &r); err != nil {
panic(err)
}
}
b.SetBytes(int64(len(codeJSON)))
}

View File

@ -227,7 +227,7 @@ func (d *decodeState) value(v reflect.Value) {
// d.scan thinks we're still at the beginning of the item. // d.scan thinks we're still at the beginning of the item.
// Feed in an empty string - the shortest, simplest value - // Feed in an empty string - the shortest, simplest value -
// so that it knows we got to the end of the value. // so that it knows we got to the end of the value.
if d.scan.step == stateRedo { if d.scan.redo {
panic("redo") panic("redo")
} }
d.scan.step(&d.scan, '"') d.scan.step(&d.scan, '"')
@ -381,6 +381,7 @@ func (d *decodeState) array(v reflect.Value) {
d.error(errPhase) d.error(errPhase)
} }
} }
if i < av.Len() { if i < av.Len() {
if !sv.IsValid() { if !sv.IsValid() {
// Array. Zero the rest. // Array. Zero the rest.
@ -392,6 +393,9 @@ func (d *decodeState) array(v reflect.Value) {
sv.SetLen(i) sv.SetLen(i)
} }
} }
if i == 0 && av.Kind() == reflect.Slice && sv.IsNil() {
sv.Set(reflect.MakeSlice(sv.Type(), 0, 0))
}
} }
// object consumes an object from d.data[d.off-1:], decoding into the value v. // object consumes an object from d.data[d.off-1:], decoding into the value v.

View File

@ -80,6 +80,9 @@ type scanner struct {
// on a 64-bit Mac Mini, and it's nicer to read. // on a 64-bit Mac Mini, and it's nicer to read.
step func(*scanner, int) int step func(*scanner, int) int
// Reached end of top-level value.
endTop bool
// Stack of what we're in the middle of - array values, object keys, object values. // Stack of what we're in the middle of - array values, object keys, object values.
parseState []int parseState []int
@ -87,6 +90,7 @@ type scanner struct {
err error err error
// 1-byte redo (see undo method) // 1-byte redo (see undo method)
redo bool
redoCode int redoCode int
redoState func(*scanner, int) int redoState func(*scanner, int) int
@ -135,6 +139,8 @@ func (s *scanner) reset() {
s.step = stateBeginValue s.step = stateBeginValue
s.parseState = s.parseState[0:0] s.parseState = s.parseState[0:0]
s.err = nil s.err = nil
s.redo = false
s.endTop = false
} }
// eof tells the scanner that the end of input has been reached. // eof tells the scanner that the end of input has been reached.
@ -143,11 +149,11 @@ func (s *scanner) eof() int {
if s.err != nil { if s.err != nil {
return scanError return scanError
} }
if s.step == stateEndTop { if s.endTop {
return scanEnd return scanEnd
} }
s.step(s, ' ') s.step(s, ' ')
if s.step == stateEndTop { if s.endTop {
return scanEnd return scanEnd
} }
if s.err == nil { if s.err == nil {
@ -166,8 +172,10 @@ func (s *scanner) pushParseState(p int) {
func (s *scanner) popParseState() { func (s *scanner) popParseState() {
n := len(s.parseState) - 1 n := len(s.parseState) - 1
s.parseState = s.parseState[0:n] s.parseState = s.parseState[0:n]
s.redo = false
if n == 0 { if n == 0 {
s.step = stateEndTop s.step = stateEndTop
s.endTop = true
} else { } else {
s.step = stateEndValue s.step = stateEndValue
} }
@ -269,6 +277,7 @@ func stateEndValue(s *scanner, c int) int {
if n == 0 { if n == 0 {
// Completed top-level before the current byte. // Completed top-level before the current byte.
s.step = stateEndTop s.step = stateEndTop
s.endTop = true
return stateEndTop(s, c) return stateEndTop(s, c)
} }
if c <= ' ' && (c == ' ' || c == '\t' || c == '\r' || c == '\n') { if c <= ' ' && (c == ' ' || c == '\t' || c == '\r' || c == '\n') {
@ -606,16 +615,18 @@ func quoteChar(c int) string {
// undo causes the scanner to return scanCode from the next state transition. // undo causes the scanner to return scanCode from the next state transition.
// This gives callers a simple 1-byte undo mechanism. // This gives callers a simple 1-byte undo mechanism.
func (s *scanner) undo(scanCode int) { func (s *scanner) undo(scanCode int) {
if s.step == stateRedo { if s.redo {
panic("invalid use of scanner") panic("json: invalid use of scanner")
} }
s.redoCode = scanCode s.redoCode = scanCode
s.redoState = s.step s.redoState = s.step
s.step = stateRedo s.step = stateRedo
s.redo = true
} }
// stateRedo helps implement the scanner's 1-byte undo. // stateRedo helps implement the scanner's 1-byte undo.
func stateRedo(s *scanner, c int) int { func stateRedo(s *scanner, c int) int {
s.redo = false
s.step = s.redoState s.step = s.redoState
return s.redoCode return s.redoCode
} }

View File

@ -186,11 +186,12 @@ func TestNextValueBig(t *testing.T) {
} }
} }
var benchScan scanner
func BenchmarkSkipValue(b *testing.B) { func BenchmarkSkipValue(b *testing.B) {
initBig() initBig()
var scan scanner
for i := 0; i < b.N; i++ { for i := 0; i < b.N; i++ {
nextValue(jsonBig, &scan) nextValue(jsonBig, &benchScan)
} }
b.SetBytes(int64(len(jsonBig))) b.SetBytes(int64(len(jsonBig)))
} }

View File

@ -7,7 +7,6 @@ package xml
import ( import (
"bytes" "bytes"
"io" "io"
"os"
"reflect" "reflect"
"strings" "strings"
"testing" "testing"
@ -43,17 +42,17 @@ var rawTokens = []Token{
CharData([]byte("World <>'\" 白鵬翔")), CharData([]byte("World <>'\" 白鵬翔")),
EndElement{Name{"", "hello"}}, EndElement{Name{"", "hello"}},
CharData([]byte("\n ")), CharData([]byte("\n ")),
StartElement{Name{"", "goodbye"}, nil}, StartElement{Name{"", "goodbye"}, []Attr{}},
EndElement{Name{"", "goodbye"}}, EndElement{Name{"", "goodbye"}},
CharData([]byte("\n ")), CharData([]byte("\n ")),
StartElement{Name{"", "outer"}, []Attr{{Name{"foo", "attr"}, "value"}, {Name{"xmlns", "tag"}, "ns4"}}}, StartElement{Name{"", "outer"}, []Attr{{Name{"foo", "attr"}, "value"}, {Name{"xmlns", "tag"}, "ns4"}}},
CharData([]byte("\n ")), CharData([]byte("\n ")),
StartElement{Name{"", "inner"}, nil}, StartElement{Name{"", "inner"}, []Attr{}},
EndElement{Name{"", "inner"}}, EndElement{Name{"", "inner"}},
CharData([]byte("\n ")), CharData([]byte("\n ")),
EndElement{Name{"", "outer"}}, EndElement{Name{"", "outer"}},
CharData([]byte("\n ")), CharData([]byte("\n ")),
StartElement{Name{"tag", "name"}, nil}, StartElement{Name{"tag", "name"}, []Attr{}},
CharData([]byte("\n ")), CharData([]byte("\n ")),
CharData([]byte("Some text here.")), CharData([]byte("Some text here.")),
CharData([]byte("\n ")), CharData([]byte("\n ")),
@ -77,17 +76,17 @@ var cookedTokens = []Token{
CharData([]byte("World <>'\" 白鵬翔")), CharData([]byte("World <>'\" 白鵬翔")),
EndElement{Name{"ns2", "hello"}}, EndElement{Name{"ns2", "hello"}},
CharData([]byte("\n ")), CharData([]byte("\n ")),
StartElement{Name{"ns2", "goodbye"}, nil}, StartElement{Name{"ns2", "goodbye"}, []Attr{}},
EndElement{Name{"ns2", "goodbye"}}, EndElement{Name{"ns2", "goodbye"}},
CharData([]byte("\n ")), CharData([]byte("\n ")),
StartElement{Name{"ns2", "outer"}, []Attr{{Name{"ns1", "attr"}, "value"}, {Name{"xmlns", "tag"}, "ns4"}}}, StartElement{Name{"ns2", "outer"}, []Attr{{Name{"ns1", "attr"}, "value"}, {Name{"xmlns", "tag"}, "ns4"}}},
CharData([]byte("\n ")), CharData([]byte("\n ")),
StartElement{Name{"ns2", "inner"}, nil}, StartElement{Name{"ns2", "inner"}, []Attr{}},
EndElement{Name{"ns2", "inner"}}, EndElement{Name{"ns2", "inner"}},
CharData([]byte("\n ")), CharData([]byte("\n ")),
EndElement{Name{"ns2", "outer"}}, EndElement{Name{"ns2", "outer"}},
CharData([]byte("\n ")), CharData([]byte("\n ")),
StartElement{Name{"ns3", "name"}, nil}, StartElement{Name{"ns3", "name"}, []Attr{}},
CharData([]byte("\n ")), CharData([]byte("\n ")),
CharData([]byte("Some text here.")), CharData([]byte("Some text here.")),
CharData([]byte("\n ")), CharData([]byte("\n ")),
@ -105,7 +104,7 @@ var rawTokensAltEncoding = []Token{
CharData([]byte("\n")), CharData([]byte("\n")),
ProcInst{"xml", []byte(`version="1.0" encoding="x-testing-uppercase"`)}, ProcInst{"xml", []byte(`version="1.0" encoding="x-testing-uppercase"`)},
CharData([]byte("\n")), CharData([]byte("\n")),
StartElement{Name{"", "tag"}, nil}, StartElement{Name{"", "tag"}, []Attr{}},
CharData([]byte("value")), CharData([]byte("value")),
EndElement{Name{"", "tag"}}, EndElement{Name{"", "tag"}},
} }
@ -205,7 +204,7 @@ func (d *downCaser) ReadByte() (c byte, err error) {
func (d *downCaser) Read(p []byte) (int, error) { func (d *downCaser) Read(p []byte) (int, error) {
d.t.Fatalf("unexpected Read call on downCaser reader") d.t.Fatalf("unexpected Read call on downCaser reader")
return 0, os.EINVAL panic("unreachable")
} }
func TestRawTokenAltEncoding(t *testing.T) { func TestRawTokenAltEncoding(t *testing.T) {

View File

@ -105,9 +105,9 @@ func (w *Watcher) AddWatch(path string, flags uint32) error {
watchEntry.flags |= flags watchEntry.flags |= flags
flags |= syscall.IN_MASK_ADD flags |= syscall.IN_MASK_ADD
} }
wd, errno := syscall.InotifyAddWatch(w.fd, path, flags) wd, err := syscall.InotifyAddWatch(w.fd, path, flags)
if wd == -1 { if err != nil {
return &os.PathError{"inotify_add_watch", path, os.Errno(errno)} return &os.PathError{"inotify_add_watch", path, err}
} }
if !found { if !found {
@ -139,14 +139,10 @@ func (w *Watcher) RemoveWatch(path string) error {
// readEvents reads from the inotify file descriptor, converts the // readEvents reads from the inotify file descriptor, converts the
// received events into Event objects and sends them via the Event channel // received events into Event objects and sends them via the Event channel
func (w *Watcher) readEvents() { func (w *Watcher) readEvents() {
var ( var buf [syscall.SizeofInotifyEvent * 4096]byte
buf [syscall.SizeofInotifyEvent * 4096]byte // Buffer for a maximum of 4096 raw events
n int // Number of bytes read with read()
errno int // Syscall errno
)
for { for {
n, errno = syscall.Read(w.fd, buf[0:]) n, err := syscall.Read(w.fd, buf[0:])
// See if there is a message on the "done" channel // See if there is a message on the "done" channel
var done bool var done bool
select { select {
@ -156,16 +152,16 @@ func (w *Watcher) readEvents() {
// If EOF or a "done" message is received // If EOF or a "done" message is received
if n == 0 || done { if n == 0 || done {
errno := syscall.Close(w.fd) err := syscall.Close(w.fd)
if errno == -1 { if err != nil {
w.Error <- os.NewSyscallError("close", errno) w.Error <- os.NewSyscallError("close", err)
} }
close(w.Event) close(w.Event)
close(w.Error) close(w.Error)
return return
} }
if n < 0 { if n < 0 {
w.Error <- os.NewSyscallError("read", errno) w.Error <- os.NewSyscallError("read", err)
continue continue
} }
if n < syscall.SizeofInotifyEvent { if n < syscall.SizeofInotifyEvent {

View File

@ -14,6 +14,21 @@ import (
"strconv" "strconv"
) )
// subsetTypeArgs takes a slice of arguments from callers of the sql
// package and converts them into a slice of the driver package's
// "subset types".
func subsetTypeArgs(args []interface{}) ([]interface{}, error) {
out := make([]interface{}, len(args))
for n, arg := range args {
var err error
out[n], err = driver.DefaultParameterConverter.ConvertValue(arg)
if err != nil {
return nil, fmt.Errorf("sql: converting argument #%d's type: %v", n+1, err)
}
}
return out, nil
}
// convertAssign copies to dest the value in src, converting it if possible. // convertAssign copies to dest the value in src, converting it if possible.
// An error is returned if the copy would result in loss of information. // An error is returned if the copy would result in loss of information.
// dest should be a pointer type. // dest should be a pointer type.

View File

@ -36,19 +36,22 @@ type Driver interface {
Open(name string) (Conn, error) Open(name string) (Conn, error)
} }
// Execer is an optional interface that may be implemented by a Driver // ErrSkip may be returned by some optional interfaces' methods to
// or a Conn. // indicate at runtime that the fast path is unavailable and the sql
// // package should continue as if the optional interface was not
// If a Driver does not implement Execer, the sql package's DB.Exec // implemented. ErrSkip is only supported where explicitly
// method first obtains a free connection from its free pool or from // documented.
// the driver's Open method. Execer should only be implemented by var ErrSkip = errors.New("driver: skip fast-path; continue as if unimplemented")
// drivers that can provide a more efficient implementation.
// Execer is an optional interface that may be implemented by a Conn.
// //
// If a Conn does not implement Execer, the db package's DB.Exec will // If a Conn does not implement Execer, the db package's DB.Exec will
// first prepare a query, execute the statement, and then close the // first prepare a query, execute the statement, and then close the
// statement. // statement.
// //
// All arguments are of a subset type as defined in the package docs. // All arguments are of a subset type as defined in the package docs.
//
// Exec may return ErrSkip.
type Execer interface { type Execer interface {
Exec(query string, args []interface{}) (Result, error) Exec(query string, args []interface{}) (Result, error)
} }
@ -94,6 +97,9 @@ type Stmt interface {
Close() error Close() error
// NumInput returns the number of placeholder parameters. // NumInput returns the number of placeholder parameters.
// -1 means the driver doesn't know how to count the number of
// placeholders, so we won't sanity check input here and instead let the
// driver deal with errors.
NumInput() int NumInput() int
// Exec executes a query that doesn't return rows, such // Exec executes a query that doesn't return rows, such
@ -135,6 +141,8 @@ type Rows interface {
// The dest slice may be populated with only with values // The dest slice may be populated with only with values
// of subset types defined above, but excluding string. // of subset types defined above, but excluding string.
// All string values must be converted to []byte. // All string values must be converted to []byte.
//
// Next should return io.EOF when there are no more rows.
Next(dest []interface{}) error Next(dest []interface{}) error
} }

View File

@ -195,6 +195,29 @@ func (c *fakeConn) Close() error {
return nil return nil
} }
func checkSubsetTypes(args []interface{}) error {
for n, arg := range args {
switch arg.(type) {
case int64, float64, bool, nil, []byte, string:
default:
return fmt.Errorf("fakedb_test: invalid argument #%d: %v, type %T", n+1, arg, arg)
}
}
return nil
}
func (c *fakeConn) Exec(query string, args []interface{}) (driver.Result, error) {
// This is an optional interface, but it's implemented here
// just to check that all the args of of the proper types.
// ErrSkip is returned so the caller acts as if we didn't
// implement this at all.
err := checkSubsetTypes(args)
if err != nil {
return nil, err
}
return nil, driver.ErrSkip
}
func errf(msg string, args ...interface{}) error { func errf(msg string, args ...interface{}) error {
return errors.New("fakedb: " + fmt.Sprintf(msg, args...)) return errors.New("fakedb: " + fmt.Sprintf(msg, args...))
} }
@ -323,6 +346,11 @@ func (s *fakeStmt) Close() error {
} }
func (s *fakeStmt) Exec(args []interface{}) (driver.Result, error) { func (s *fakeStmt) Exec(args []interface{}) (driver.Result, error) {
err := checkSubsetTypes(args)
if err != nil {
return nil, err
}
db := s.c.db db := s.c.db
switch s.cmd { switch s.cmd {
case "WIPE": case "WIPE":
@ -377,6 +405,11 @@ func (s *fakeStmt) execInsert(args []interface{}) (driver.Result, error) {
} }
func (s *fakeStmt) Query(args []interface{}) (driver.Rows, error) { func (s *fakeStmt) Query(args []interface{}) (driver.Rows, error) {
err := checkSubsetTypes(args)
if err != nil {
return nil, err
}
db := s.c.db db := s.c.db
if len(args) != s.placeholders { if len(args) != s.placeholders {
panic("error in pkg db; should only get here if size is correct") panic("error in pkg db; should only get here if size is correct")

View File

@ -88,8 +88,9 @@ type DB struct {
driver driver.Driver driver driver.Driver
dsn string dsn string
mu sync.Mutex mu sync.Mutex // protects freeConn and closed
freeConn []driver.Conn freeConn []driver.Conn
closed bool
} }
// Open opens a database specified by its database driver name and a // Open opens a database specified by its database driver name and a
@ -106,6 +107,22 @@ func Open(driverName, dataSourceName string) (*DB, error) {
return &DB{driver: driver, dsn: dataSourceName}, nil return &DB{driver: driver, dsn: dataSourceName}, nil
} }
// Close closes the database, releasing any open resources.
func (db *DB) Close() error {
db.mu.Lock()
defer db.mu.Unlock()
var err error
for _, c := range db.freeConn {
err1 := c.Close()
if err1 != nil {
err = err1
}
}
db.freeConn = nil
db.closed = true
return err
}
func (db *DB) maxIdleConns() int { func (db *DB) maxIdleConns() int {
const defaultMaxIdleConns = 2 const defaultMaxIdleConns = 2
// TODO(bradfitz): ask driver, if supported, for its default preference // TODO(bradfitz): ask driver, if supported, for its default preference
@ -116,6 +133,9 @@ func (db *DB) maxIdleConns() int {
// conn returns a newly-opened or cached driver.Conn // conn returns a newly-opened or cached driver.Conn
func (db *DB) conn() (driver.Conn, error) { func (db *DB) conn() (driver.Conn, error) {
db.mu.Lock() db.mu.Lock()
if db.closed {
return nil, errors.New("sql: database is closed")
}
if n := len(db.freeConn); n > 0 { if n := len(db.freeConn); n > 0 {
conn := db.freeConn[n-1] conn := db.freeConn[n-1]
db.freeConn = db.freeConn[:n-1] db.freeConn = db.freeConn[:n-1]
@ -140,11 +160,13 @@ func (db *DB) connIfFree(wanted driver.Conn) (conn driver.Conn, ok bool) {
} }
func (db *DB) putConn(c driver.Conn) { func (db *DB) putConn(c driver.Conn) {
if n := len(db.freeConn); n < db.maxIdleConns() { db.mu.Lock()
defer db.mu.Unlock()
if n := len(db.freeConn); !db.closed && n < db.maxIdleConns() {
db.freeConn = append(db.freeConn, c) db.freeConn = append(db.freeConn, c)
return return
} }
db.closeConn(c) db.closeConn(c) // TODO(bradfitz): release lock before calling this?
} }
func (db *DB) closeConn(c driver.Conn) { func (db *DB) closeConn(c driver.Conn) {
@ -180,17 +202,11 @@ func (db *DB) Prepare(query string) (*Stmt, error) {
// Exec executes a query without returning any rows. // Exec executes a query without returning any rows.
func (db *DB) Exec(query string, args ...interface{}) (Result, error) { func (db *DB) Exec(query string, args ...interface{}) (Result, error) {
// Optional fast path, if the driver implements driver.Execer. sargs, err := subsetTypeArgs(args)
if execer, ok := db.driver.(driver.Execer); ok {
resi, err := execer.Exec(query, args)
if err != nil { if err != nil {
return nil, err return nil, err
} }
return result{resi}, nil
}
// If the driver does not implement driver.Execer, we need
// a connection.
ci, err := db.conn() ci, err := db.conn()
if err != nil { if err != nil {
return nil, err return nil, err
@ -198,19 +214,22 @@ func (db *DB) Exec(query string, args ...interface{}) (Result, error) {
defer db.putConn(ci) defer db.putConn(ci)
if execer, ok := ci.(driver.Execer); ok { if execer, ok := ci.(driver.Execer); ok {
resi, err := execer.Exec(query, args) resi, err := execer.Exec(query, sargs)
if err != driver.ErrSkip {
if err != nil { if err != nil {
return nil, err return nil, err
} }
return result{resi}, nil return result{resi}, nil
} }
}
sti, err := ci.Prepare(query) sti, err := ci.Prepare(query)
if err != nil { if err != nil {
return nil, err return nil, err
} }
defer sti.Close() defer sti.Close()
resi, err := sti.Exec(args)
resi, err := sti.Exec(sargs)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -386,7 +405,13 @@ func (tx *Tx) Exec(query string, args ...interface{}) (Result, error) {
return nil, err return nil, err
} }
defer sti.Close() defer sti.Close()
resi, err := sti.Exec(args)
sargs, err := subsetTypeArgs(args)
if err != nil {
return nil, err
}
resi, err := sti.Exec(sargs)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -449,7 +474,10 @@ func (s *Stmt) Exec(args ...interface{}) (Result, error) {
} }
defer releaseConn() defer releaseConn()
if want := si.NumInput(); len(args) != want { // -1 means the driver doesn't know how to count the number of
// placeholders, so we won't sanity check input here and instead let the
// driver deal with errors.
if want := si.NumInput(); want != -1 && len(args) != want {
return nil, fmt.Errorf("db: expected %d arguments, got %d", want, len(args)) return nil, fmt.Errorf("db: expected %d arguments, got %d", want, len(args))
} }
@ -545,10 +573,18 @@ func (s *Stmt) Query(args ...interface{}) (*Rows, error) {
if err != nil { if err != nil {
return nil, err return nil, err
} }
if len(args) != si.NumInput() {
// -1 means the driver doesn't know how to count the number of
// placeholders, so we won't sanity check input here and instead let the
// driver deal with errors.
if want := si.NumInput(); want != -1 && len(args) != want {
return nil, fmt.Errorf("db: statement expects %d inputs; got %d", si.NumInput(), len(args)) return nil, fmt.Errorf("db: statement expects %d inputs; got %d", si.NumInput(), len(args))
} }
rowsi, err := si.Query(args) sargs, err := subsetTypeArgs(args)
if err != nil {
return nil, err
}
rowsi, err := si.Query(sargs)
if err != nil { if err != nil {
s.db.putConn(ci) s.db.putConn(ci)
return nil, err return nil, err

View File

@ -34,8 +34,16 @@ func exec(t *testing.T, db *DB, query string, args ...interface{}) {
} }
} }
func closeDB(t *testing.T, db *DB) {
err := db.Close()
if err != nil {
t.Fatalf("error closing DB: %v", err)
}
}
func TestQuery(t *testing.T) { func TestQuery(t *testing.T) {
db := newTestDB(t, "people") db := newTestDB(t, "people")
defer closeDB(t, db)
var name string var name string
var age int var age int
@ -69,6 +77,7 @@ func TestQuery(t *testing.T) {
func TestStatementQueryRow(t *testing.T) { func TestStatementQueryRow(t *testing.T) {
db := newTestDB(t, "people") db := newTestDB(t, "people")
defer closeDB(t, db)
stmt, err := db.Prepare("SELECT|people|age|name=?") stmt, err := db.Prepare("SELECT|people|age|name=?")
if err != nil { if err != nil {
t.Fatalf("Prepare: %v", err) t.Fatalf("Prepare: %v", err)
@ -94,6 +103,7 @@ func TestStatementQueryRow(t *testing.T) {
// just a test of fakedb itself // just a test of fakedb itself
func TestBogusPreboundParameters(t *testing.T) { func TestBogusPreboundParameters(t *testing.T) {
db := newTestDB(t, "foo") db := newTestDB(t, "foo")
defer closeDB(t, db)
exec(t, db, "CREATE|t1|name=string,age=int32,dead=bool") exec(t, db, "CREATE|t1|name=string,age=int32,dead=bool")
_, err := db.Prepare("INSERT|t1|name=?,age=bogusconversion") _, err := db.Prepare("INSERT|t1|name=?,age=bogusconversion")
if err == nil { if err == nil {
@ -106,6 +116,7 @@ func TestBogusPreboundParameters(t *testing.T) {
func TestDb(t *testing.T) { func TestDb(t *testing.T) {
db := newTestDB(t, "foo") db := newTestDB(t, "foo")
defer closeDB(t, db)
exec(t, db, "CREATE|t1|name=string,age=int32,dead=bool") exec(t, db, "CREATE|t1|name=string,age=int32,dead=bool")
stmt, err := db.Prepare("INSERT|t1|name=?,age=?") stmt, err := db.Prepare("INSERT|t1|name=?,age=?")
if err != nil { if err != nil {

View File

@ -0,0 +1,88 @@
// Copyright 2011 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package ssh
import (
"crypto/aes"
"crypto/cipher"
"crypto/rc4"
)
// streamDump is used to dump the initial keystream for stream ciphers. It is a
// a write-only buffer, and not intended for reading so do not require a mutex.
var streamDump [512]byte
// noneCipher implements cipher.Stream and provides no encryption. It is used
// by the transport before the first key-exchange.
type noneCipher struct{}
func (c noneCipher) XORKeyStream(dst, src []byte) {
copy(dst, src)
}
func newAESCTR(key, iv []byte) (cipher.Stream, error) {
c, err := aes.NewCipher(key)
if err != nil {
return nil, err
}
return cipher.NewCTR(c, iv), nil
}
func newRC4(key, iv []byte) (cipher.Stream, error) {
return rc4.NewCipher(key)
}
type cipherMode struct {
keySize int
ivSize int
skip int
createFn func(key, iv []byte) (cipher.Stream, error)
}
func (c *cipherMode) createCipher(key, iv []byte) (cipher.Stream, error) {
if len(key) < c.keySize {
panic("ssh: key length too small for cipher")
}
if len(iv) < c.ivSize {
panic("ssh: iv too small for cipher")
}
stream, err := c.createFn(key[:c.keySize], iv[:c.ivSize])
if err != nil {
return nil, err
}
for remainingToDump := c.skip; remainingToDump > 0; {
dumpThisTime := remainingToDump
if dumpThisTime > len(streamDump) {
dumpThisTime = len(streamDump)
}
stream.XORKeyStream(streamDump[:dumpThisTime], streamDump[:dumpThisTime])
remainingToDump -= dumpThisTime
}
return stream, nil
}
// Specifies a default set of ciphers and a preference order. This is based on
// OpenSSH's default client preference order, minus algorithms that are not
// implemented.
var DefaultCipherOrder = []string{
"aes128-ctr", "aes192-ctr", "aes256-ctr",
"arcfour256", "arcfour128",
}
var cipherModes = map[string]*cipherMode{
// Ciphers from RFC4344, which introduced many CTR-based ciphers. Algorithms
// are defined in the order specified in the RFC.
"aes128-ctr": &cipherMode{16, aes.BlockSize, 0, newAESCTR},
"aes192-ctr": &cipherMode{24, aes.BlockSize, 0, newAESCTR},
"aes256-ctr": &cipherMode{32, aes.BlockSize, 0, newAESCTR},
// Ciphers from RFC4345, which introduces security-improved arcfour ciphers.
// They are defined in the order specified in the RFC.
"arcfour128": &cipherMode{16, 0, 1536, newRC4},
"arcfour256": &cipherMode{32, 0, 1536, newRC4},
}

View File

@ -0,0 +1,62 @@
// Copyright 2011 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package ssh
import (
"bytes"
"testing"
)
// TestCipherReversal tests that each cipher factory produces ciphers that can
// encrypt and decrypt some data successfully.
func TestCipherReversal(t *testing.T) {
testData := []byte("abcdefghijklmnopqrstuvwxyz012345")
testKey := []byte("AbCdEfGhIjKlMnOpQrStUvWxYz012345")
testIv := []byte("sdflkjhsadflkjhasdflkjhsadfklhsa")
cryptBuffer := make([]byte, 32)
for name, cipherMode := range cipherModes {
encrypter, err := cipherMode.createCipher(testKey, testIv)
if err != nil {
t.Errorf("failed to create encrypter for %q: %s", name, err)
continue
}
decrypter, err := cipherMode.createCipher(testKey, testIv)
if err != nil {
t.Errorf("failed to create decrypter for %q: %s", name, err)
continue
}
copy(cryptBuffer, testData)
encrypter.XORKeyStream(cryptBuffer, cryptBuffer)
if name == "none" {
if !bytes.Equal(cryptBuffer, testData) {
t.Errorf("encryption made change with 'none' cipher")
continue
}
} else {
if bytes.Equal(cryptBuffer, testData) {
t.Errorf("encryption made no change with %q", name)
continue
}
}
decrypter.XORKeyStream(cryptBuffer, cryptBuffer)
if !bytes.Equal(cryptBuffer, testData) {
t.Errorf("decrypted bytes not equal to input with %q", name)
continue
}
}
}
func TestDefaultCiphersExist(t *testing.T) {
for _, cipherAlgo := range DefaultCipherOrder {
if _, ok := cipherModes[cipherAlgo]; !ok {
t.Errorf("default cipher %q is unknown", cipherAlgo)
}
}
}

View File

@ -35,10 +35,6 @@ func Client(c net.Conn, config *ClientConfig) (*ClientConn, error) {
conn.Close() conn.Close()
return nil, err return nil, err
} }
if err := conn.authenticate(); err != nil {
conn.Close()
return nil, err
}
go conn.mainLoop() go conn.mainLoop()
return conn, nil return conn, nil
} }
@ -64,8 +60,8 @@ func (c *ClientConn) handshake() error {
clientKexInit := kexInitMsg{ clientKexInit := kexInitMsg{
KexAlgos: supportedKexAlgos, KexAlgos: supportedKexAlgos,
ServerHostKeyAlgos: supportedHostKeyAlgos, ServerHostKeyAlgos: supportedHostKeyAlgos,
CiphersClientServer: supportedCiphers, CiphersClientServer: c.config.Crypto.ciphers(),
CiphersServerClient: supportedCiphers, CiphersServerClient: c.config.Crypto.ciphers(),
MACsClientServer: supportedMACs, MACsClientServer: supportedMACs,
MACsServerClient: supportedMACs, MACsServerClient: supportedMACs,
CompressionClientServer: supportedCompressions, CompressionClientServer: supportedCompressions,
@ -128,7 +124,10 @@ func (c *ClientConn) handshake() error {
if packet[0] != msgNewKeys { if packet[0] != msgNewKeys {
return UnexpectedMessageError{msgNewKeys, packet[0]} return UnexpectedMessageError{msgNewKeys, packet[0]}
} }
return c.transport.reader.setupKeys(serverKeys, K, H, H, hashFunc) if err := c.transport.reader.setupKeys(serverKeys, K, H, H, hashFunc); err != nil {
return err
}
return c.authenticate(H)
} }
// kexDH performs Diffie-Hellman key agreement on a ClientConn. The // kexDH performs Diffie-Hellman key agreement on a ClientConn. The
@ -195,6 +194,7 @@ func (c *ClientConn) openChan(typ string) (*clientChan, error) {
switch msg := (<-ch.msg).(type) { switch msg := (<-ch.msg).(type) {
case *channelOpenConfirmMsg: case *channelOpenConfirmMsg:
ch.peersId = msg.MyId ch.peersId = msg.MyId
ch.win <- int(msg.MyWindow)
case *channelOpenFailureMsg: case *channelOpenFailureMsg:
c.chanlist.remove(ch.id) c.chanlist.remove(ch.id)
return nil, errors.New(msg.Message) return nil, errors.New(msg.Message)
@ -301,6 +301,9 @@ type ClientConfig struct {
// A slice of ClientAuth methods. Only the first instance // A slice of ClientAuth methods. Only the first instance
// of a particular RFC 4252 method will be used during authentication. // of a particular RFC 4252 method will be used during authentication.
Auth []ClientAuth Auth []ClientAuth
// Cryptographic-related configuration.
Crypto CryptoConfig
} }
func (c *ClientConfig) rand() io.Reader { func (c *ClientConfig) rand() io.Reader {

View File

@ -6,10 +6,11 @@ package ssh
import ( import (
"errors" "errors"
"io"
) )
// authenticate authenticates with the remote server. See RFC 4252. // authenticate authenticates with the remote server. See RFC 4252.
func (c *ClientConn) authenticate() error { func (c *ClientConn) authenticate(session []byte) error {
// initiate user auth session // initiate user auth session
if err := c.writePacket(marshal(msgServiceRequest, serviceRequestMsg{serviceUserAuth})); err != nil { if err := c.writePacket(marshal(msgServiceRequest, serviceRequestMsg{serviceUserAuth})); err != nil {
return err return err
@ -26,7 +27,7 @@ func (c *ClientConn) authenticate() error {
// then any untried methods suggested by the server. // then any untried methods suggested by the server.
tried, remain := make(map[string]bool), make(map[string]bool) tried, remain := make(map[string]bool), make(map[string]bool)
for auth := ClientAuth(new(noneAuth)); auth != nil; { for auth := ClientAuth(new(noneAuth)); auth != nil; {
ok, methods, err := auth.auth(c.config.User, c.transport) ok, methods, err := auth.auth(session, c.config.User, c.transport, c.config.rand())
if err != nil { if err != nil {
return err return err
} }
@ -60,7 +61,7 @@ type ClientAuth interface {
// Returns true if authentication is successful. // Returns true if authentication is successful.
// If authentication is not successful, a []string of alternative // If authentication is not successful, a []string of alternative
// method names is returned. // method names is returned.
auth(user string, t *transport) (bool, []string, error) auth(session []byte, user string, t *transport, rand io.Reader) (bool, []string, error)
// method returns the RFC 4252 method name. // method returns the RFC 4252 method name.
method() string method() string
@ -69,7 +70,7 @@ type ClientAuth interface {
// "none" authentication, RFC 4252 section 5.2. // "none" authentication, RFC 4252 section 5.2.
type noneAuth int type noneAuth int
func (n *noneAuth) auth(user string, t *transport) (bool, []string, error) { func (n *noneAuth) auth(session []byte, user string, t *transport, rand io.Reader) (bool, []string, error) {
if err := t.writePacket(marshal(msgUserAuthRequest, userAuthRequestMsg{ if err := t.writePacket(marshal(msgUserAuthRequest, userAuthRequestMsg{
User: user, User: user,
Service: serviceSSH, Service: serviceSSH,
@ -102,7 +103,7 @@ type passwordAuth struct {
ClientPassword ClientPassword
} }
func (p *passwordAuth) auth(user string, t *transport) (bool, []string, error) { func (p *passwordAuth) auth(session []byte, user string, t *transport, rand io.Reader) (bool, []string, error) {
type passwordAuthMsg struct { type passwordAuthMsg struct {
User string User string
Service string Service string
@ -155,3 +156,140 @@ type ClientPassword interface {
func ClientAuthPassword(impl ClientPassword) ClientAuth { func ClientAuthPassword(impl ClientPassword) ClientAuth {
return &passwordAuth{impl} return &passwordAuth{impl}
} }
// ClientKeyring implements access to a client key ring.
type ClientKeyring interface {
// Key returns the i'th rsa.Publickey or dsa.Publickey, or nil if
// no key exists at i.
Key(i int) (key interface{}, err error)
// Sign returns a signature of the given data using the i'th key
// and the supplied random source.
Sign(i int, rand io.Reader, data []byte) (sig []byte, err error)
}
// "publickey" authentication, RFC 4252 Section 7.
type publickeyAuth struct {
ClientKeyring
}
func (p *publickeyAuth) auth(session []byte, user string, t *transport, rand io.Reader) (bool, []string, error) {
type publickeyAuthMsg struct {
User string
Service string
Method string
// HasSig indicates to the reciver packet that the auth request is signed and
// should be used for authentication of the request.
HasSig bool
Algoname string
Pubkey string
// Sig is defined as []byte so marshal will exclude it during the query phase
Sig []byte `ssh:"rest"`
}
// Authentication is performed in two stages. The first stage sends an
// enquiry to test if each key is acceptable to the remote. The second
// stage attempts to authenticate with the valid keys obtained in the
// first stage.
var index int
// a map of public keys to their index in the keyring
validKeys := make(map[int]interface{})
for {
key, err := p.Key(index)
if err != nil {
return false, nil, err
}
if key == nil {
// no more keys in the keyring
break
}
pubkey := serializePublickey(key)
algoname := algoName(key)
msg := publickeyAuthMsg{
User: user,
Service: serviceSSH,
Method: p.method(),
HasSig: false,
Algoname: algoname,
Pubkey: string(pubkey),
}
if err := t.writePacket(marshal(msgUserAuthRequest, msg)); err != nil {
return false, nil, err
}
packet, err := t.readPacket()
if err != nil {
return false, nil, err
}
switch packet[0] {
case msgUserAuthPubKeyOk:
msg := decode(packet).(*userAuthPubKeyOkMsg)
if msg.Algo != algoname || msg.PubKey != string(pubkey) {
continue
}
validKeys[index] = key
case msgUserAuthFailure:
default:
return false, nil, UnexpectedMessageError{msgUserAuthSuccess, packet[0]}
}
index++
}
// methods that may continue if this auth is not successful.
var methods []string
for i, key := range validKeys {
pubkey := serializePublickey(key)
algoname := algoName(key)
sign, err := p.Sign(i, rand, buildDataSignedForAuth(session, userAuthRequestMsg{
User: user,
Service: serviceSSH,
Method: p.method(),
}, []byte(algoname), pubkey))
if err != nil {
return false, nil, err
}
// manually wrap the serialized signature in a string
s := serializeSignature(algoname, sign)
sig := make([]byte, stringLength(s))
marshalString(sig, s)
msg := publickeyAuthMsg{
User: user,
Service: serviceSSH,
Method: p.method(),
HasSig: true,
Algoname: algoname,
Pubkey: string(pubkey),
Sig: sig,
}
p := marshal(msgUserAuthRequest, msg)
if err := t.writePacket(p); err != nil {
return false, nil, err
}
packet, err := t.readPacket()
if err != nil {
return false, nil, err
}
switch packet[0] {
case msgUserAuthSuccess:
return true, nil, nil
case msgUserAuthFailure:
msg := decode(packet).(*userAuthFailureMsg)
methods = msg.Methods
continue
case msgDisconnect:
return false, nil, io.EOF
default:
return false, nil, UnexpectedMessageError{msgUserAuthSuccess, packet[0]}
}
}
return false, methods, nil
}
func (p *publickeyAuth) method() string {
return "publickey"
}
// ClientAuthPublickey returns a ClientAuth using public key authentication.
func ClientAuthPublickey(impl ClientKeyring) ClientAuth {
return &publickeyAuth{impl}
}

View File

@ -0,0 +1,248 @@
// Copyright 2011 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package ssh
import (
"bytes"
"crypto"
"crypto/rand"
"crypto/rsa"
"crypto/x509"
"encoding/pem"
"errors"
"io"
"io/ioutil"
"testing"
)
const _pem = `-----BEGIN RSA PRIVATE KEY-----
MIIEpAIBAAKCAQEA19lGVsTqIT5iiNYRgnoY1CwkbETW5cq+Rzk5v/kTlf31XpSU
70HVWkbTERECjaYdXM2gGcbb+sxpq6GtXf1M3kVomycqhxwhPv4Cr6Xp4WT/jkFx
9z+FFzpeodGJWjOH6L2H5uX1Cvr9EDdQp9t9/J32/qBFntY8GwoUI/y/1MSTmMiF
tupdMODN064vd3gyMKTwrlQ8tZM6aYuyOPsutLlUY7M5x5FwMDYvnPDSeyT/Iw0z
s3B+NCyqeeMd2T7YzQFnRATj0M7rM5LoSs7DVqVriOEABssFyLj31PboaoLhOKgc
qoM9khkNzr7FHVvi+DhYM2jD0DwvqZLN6NmnLwIDAQABAoIBAQCGVj+kuSFOV1lT
+IclQYA6bM6uY5mroqcSBNegVxCNhWU03BxlW//BE9tA/+kq53vWylMeN9mpGZea
riEMIh25KFGWXqXlOOioH8bkMsqA8S7sBmc7jljyv+0toQ9vCCtJ+sueNPhxQQxH
D2YvUjfzBQ04I9+wn30BByDJ1QA/FoPsunxIOUCcRBE/7jxuLYcpR+JvEF68yYIh
atXRld4W4in7T65YDR8jK1Uj9XAcNeDYNpT/M6oFLx1aPIlkG86aCWRO19S1jLPT
b1ZAKHHxPMCVkSYW0RqvIgLXQOR62D0Zne6/2wtzJkk5UCjkSQ2z7ZzJpMkWgDgN
ifCULFPBAoGBAPoMZ5q1w+zB+knXUD33n1J+niN6TZHJulpf2w5zsW+m2K6Zn62M
MXndXlVAHtk6p02q9kxHdgov34Uo8VpuNjbS1+abGFTI8NZgFo+bsDxJdItemwC4
KJ7L1iz39hRN/ZylMRLz5uTYRGddCkeIHhiG2h7zohH/MaYzUacXEEy3AoGBANz8
e/msleB+iXC0cXKwds26N4hyMdAFE5qAqJXvV3S2W8JZnmU+sS7vPAWMYPlERPk1
D8Q2eXqdPIkAWBhrx4RxD7rNc5qFNcQWEhCIxC9fccluH1y5g2M+4jpMX2CT8Uv+
3z+NoJ5uDTXZTnLCfoZzgZ4nCZVZ+6iU5U1+YXFJAoGBANLPpIV920n/nJmmquMj
orI1R/QXR9Cy56cMC65agezlGOfTYxk5Cfl5Ve+/2IJCfgzwJyjWUsFx7RviEeGw
64o7JoUom1HX+5xxdHPsyZ96OoTJ5RqtKKoApnhRMamau0fWydH1yeOEJd+TRHhc
XStGfhz8QNa1dVFvENczja1vAoGABGWhsd4VPVpHMc7lUvrf4kgKQtTC2PjA4xoc
QJ96hf/642sVE76jl+N6tkGMzGjnVm4P2j+bOy1VvwQavKGoXqJBRd5Apppv727g
/SM7hBXKFc/zH80xKBBgP/i1DR7kdjakCoeu4ngeGywvu2jTS6mQsqzkK+yWbUxJ
I7mYBsECgYB/KNXlTEpXtz/kwWCHFSYA8U74l7zZbVD8ul0e56JDK+lLcJ0tJffk
gqnBycHj6AhEycjda75cs+0zybZvN4x65KZHOGW/O/7OAWEcZP5TPb3zf9ned3Hl
NsZoFj52ponUM6+99A2CmezFCN16c4mbA//luWF+k3VVqR6BpkrhKw==
-----END RSA PRIVATE KEY-----`
// reused internally by tests
var serverConfig = new(ServerConfig)
func init() {
if err := serverConfig.SetRSAPrivateKey([]byte(_pem)); err != nil {
panic("unable to set private key: " + err.Error())
}
}
// keychain implements the ClientPublickey interface
type keychain struct {
keys []*rsa.PrivateKey
}
func (k *keychain) Key(i int) (interface{}, error) {
if i < 0 || i >= len(k.keys) {
return nil, nil
}
return k.keys[i].PublicKey, nil
}
func (k *keychain) Sign(i int, rand io.Reader, data []byte) (sig []byte, err error) {
hashFunc := crypto.SHA1
h := hashFunc.New()
h.Write(data)
digest := h.Sum()
return rsa.SignPKCS1v15(rand, k.keys[i], hashFunc, digest)
}
func (k *keychain) loadPEM(file string) error {
buf, err := ioutil.ReadFile(file)
if err != nil {
return err
}
block, _ := pem.Decode(buf)
if block == nil {
return errors.New("ssh: no key found")
}
r, err := x509.ParsePKCS1PrivateKey(block.Bytes)
if err != nil {
return err
}
k.keys = append(k.keys, r)
return nil
}
var pkey *rsa.PrivateKey
func init() {
var err error
pkey, err = rsa.GenerateKey(rand.Reader, 512)
if err != nil {
panic("unable to generate public key")
}
}
func TestClientAuthPublickey(t *testing.T) {
k := new(keychain)
k.keys = append(k.keys, pkey)
serverConfig.PubKeyCallback = func(user, algo string, pubkey []byte) bool {
expected := []byte(serializePublickey(k.keys[0].PublicKey))
algoname := algoName(k.keys[0].PublicKey)
return user == "testuser" && algo == algoname && bytes.Equal(pubkey, expected)
}
serverConfig.PasswordCallback = nil
l, err := Listen("tcp", "127.0.0.1:0", serverConfig)
if err != nil {
t.Fatalf("unable to listen: %s", err)
}
defer l.Close()
done := make(chan bool, 1)
go func() {
c, err := l.Accept()
if err != nil {
t.Fatal(err)
}
defer c.Close()
if err := c.Handshake(); err != nil {
t.Error(err)
}
done <- true
}()
config := &ClientConfig{
User: "testuser",
Auth: []ClientAuth{
ClientAuthPublickey(k),
},
}
c, err := Dial("tcp", l.Addr().String(), config)
if err != nil {
t.Fatalf("unable to dial remote side: %s", err)
}
defer c.Close()
<-done
}
// password implements the ClientPassword interface
type password string
func (p password) Password(user string) (string, error) {
return string(p), nil
}
func TestClientAuthPassword(t *testing.T) {
pw := password("tiger")
serverConfig.PasswordCallback = func(user, pass string) bool {
return user == "testuser" && pass == string(pw)
}
serverConfig.PubKeyCallback = nil
l, err := Listen("tcp", "127.0.0.1:0", serverConfig)
if err != nil {
t.Fatalf("unable to listen: %s", err)
}
defer l.Close()
done := make(chan bool)
go func() {
c, err := l.Accept()
if err != nil {
t.Fatal(err)
}
if err := c.Handshake(); err != nil {
t.Error(err)
}
defer c.Close()
done <- true
}()
config := &ClientConfig{
User: "testuser",
Auth: []ClientAuth{
ClientAuthPassword(pw),
},
}
c, err := Dial("tcp", l.Addr().String(), config)
if err != nil {
t.Fatalf("unable to dial remote side: %s", err)
}
defer c.Close()
<-done
}
func TestClientAuthPasswordAndPublickey(t *testing.T) {
pw := password("tiger")
serverConfig.PasswordCallback = func(user, pass string) bool {
return user == "testuser" && pass == string(pw)
}
k := new(keychain)
k.keys = append(k.keys, pkey)
serverConfig.PubKeyCallback = func(user, algo string, pubkey []byte) bool {
expected := []byte(serializePublickey(k.keys[0].PublicKey))
algoname := algoName(k.keys[0].PublicKey)
return user == "testuser" && algo == algoname && bytes.Equal(pubkey, expected)
}
l, err := Listen("tcp", "127.0.0.1:0", serverConfig)
if err != nil {
t.Fatalf("unable to listen: %s", err)
}
defer l.Close()
done := make(chan bool)
go func() {
c, err := l.Accept()
if err != nil {
t.Fatal(err)
}
if err := c.Handshake(); err != nil {
t.Error(err)
}
defer c.Close()
done <- true
}()
wrongPw := password("wrong")
config := &ClientConfig{
User: "testuser",
Auth: []ClientAuth{
ClientAuthPassword(wrongPw),
ClientAuthPublickey(k),
},
}
c, err := Dial("tcp", l.Addr().String(), config)
if err != nil {
t.Fatalf("unable to dial remote side: %s", err)
}
defer c.Close()
<-done
}

View File

@ -0,0 +1,61 @@
// Copyright 2011 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package ssh
// ClientConn functional tests.
// These tests require a running ssh server listening on port 22
// on the local host. Functional tests will be skipped unless
// -ssh.user and -ssh.pass must be passed to gotest.
import (
"flag"
"testing"
)
var (
sshuser = flag.String("ssh.user", "", "ssh username")
sshpass = flag.String("ssh.pass", "", "ssh password")
sshprivkey = flag.String("ssh.privkey", "", "ssh privkey file")
)
func TestFuncPasswordAuth(t *testing.T) {
if *sshuser == "" {
t.Log("ssh.user not defined, skipping test")
return
}
config := &ClientConfig{
User: *sshuser,
Auth: []ClientAuth{
ClientAuthPassword(password(*sshpass)),
},
}
conn, err := Dial("tcp", "localhost:22", config)
if err != nil {
t.Fatalf("Unable to connect: %s", err)
}
defer conn.Close()
}
func TestFuncPublickeyAuth(t *testing.T) {
if *sshuser == "" {
t.Log("ssh.user not defined, skipping test")
return
}
kc := new(keychain)
if err := kc.loadPEM(*sshprivkey); err != nil {
t.Fatalf("unable to load private key: %s", err)
}
config := &ClientConfig{
User: *sshuser,
Auth: []ClientAuth{
ClientAuthPublickey(kc),
},
}
conn, err := Dial("tcp", "localhost:22", config)
if err != nil {
t.Fatalf("unable to connect: %s", err)
}
defer conn.Close()
}

View File

@ -5,6 +5,8 @@
package ssh package ssh
import ( import (
"crypto/dsa"
"crypto/rsa"
"math/big" "math/big"
"strconv" "strconv"
"sync" "sync"
@ -14,7 +16,6 @@ import (
const ( const (
kexAlgoDH14SHA1 = "diffie-hellman-group14-sha1" kexAlgoDH14SHA1 = "diffie-hellman-group14-sha1"
hostAlgoRSA = "ssh-rsa" hostAlgoRSA = "ssh-rsa"
cipherAES128CTR = "aes128-ctr"
macSHA196 = "hmac-sha1-96" macSHA196 = "hmac-sha1-96"
compressionNone = "none" compressionNone = "none"
serviceUserAuth = "ssh-userauth" serviceUserAuth = "ssh-userauth"
@ -23,7 +24,6 @@ const (
var supportedKexAlgos = []string{kexAlgoDH14SHA1} var supportedKexAlgos = []string{kexAlgoDH14SHA1}
var supportedHostKeyAlgos = []string{hostAlgoRSA} var supportedHostKeyAlgos = []string{hostAlgoRSA}
var supportedCiphers = []string{cipherAES128CTR}
var supportedMACs = []string{macSHA196} var supportedMACs = []string{macSHA196}
var supportedCompressions = []string{compressionNone} var supportedCompressions = []string{compressionNone}
@ -127,3 +127,100 @@ func findAgreedAlgorithms(transport *transport, clientKexInit, serverKexInit *ke
ok = true ok = true
return return
} }
// Cryptographic configuration common to both ServerConfig and ClientConfig.
type CryptoConfig struct {
// The allowed cipher algorithms. If unspecified then DefaultCipherOrder is
// used.
Ciphers []string
}
func (c *CryptoConfig) ciphers() []string {
if c.Ciphers == nil {
return DefaultCipherOrder
}
return c.Ciphers
}
// serialize a signed slice according to RFC 4254 6.6.
func serializeSignature(algoname string, sig []byte) []byte {
length := stringLength([]byte(algoname))
length += stringLength(sig)
ret := make([]byte, length)
r := marshalString(ret, []byte(algoname))
r = marshalString(r, sig)
return ret
}
// serialize an rsa.PublicKey or dsa.PublicKey according to RFC 4253 6.6.
func serializePublickey(key interface{}) []byte {
algoname := algoName(key)
switch key := key.(type) {
case rsa.PublicKey:
e := new(big.Int).SetInt64(int64(key.E))
length := stringLength([]byte(algoname))
length += intLength(e)
length += intLength(key.N)
ret := make([]byte, length)
r := marshalString(ret, []byte(algoname))
r = marshalInt(r, e)
marshalInt(r, key.N)
return ret
case dsa.PublicKey:
length := stringLength([]byte(algoname))
length += intLength(key.P)
length += intLength(key.Q)
length += intLength(key.G)
length += intLength(key.Y)
ret := make([]byte, length)
r := marshalString(ret, []byte(algoname))
r = marshalInt(r, key.P)
r = marshalInt(r, key.Q)
r = marshalInt(r, key.G)
marshalInt(r, key.Y)
return ret
}
panic("unexpected key type")
}
func algoName(key interface{}) string {
switch key.(type) {
case rsa.PublicKey:
return "ssh-rsa"
case dsa.PublicKey:
return "ssh-dss"
}
panic("unexpected key type")
}
// buildDataSignedForAuth returns the data that is signed in order to prove
// posession of a private key. See RFC 4252, section 7.
func buildDataSignedForAuth(sessionId []byte, req userAuthRequestMsg, algo, pubKey []byte) []byte {
user := []byte(req.User)
service := []byte(req.Service)
method := []byte(req.Method)
length := stringLength(sessionId)
length += 1
length += stringLength(user)
length += stringLength(service)
length += stringLength(method)
length += 1
length += stringLength(algo)
length += stringLength(pubKey)
ret := make([]byte, length)
r := marshalString(ret, sessionId)
r[0] = msgUserAuthRequest
r = r[1:]
r = marshalString(r, user)
r = marshalString(r, service)
r = marshalString(r, method)
r[0] = 1
r = r[1:]
r = marshalString(r, algo)
r = marshalString(r, pubKey)
return ret
}

View File

@ -392,7 +392,10 @@ func parseString(in []byte) (out, rest []byte, ok bool) {
return return
} }
var comma = []byte{','} var (
comma = []byte{','}
emptyNameList = []string{}
)
func parseNameList(in []byte) (out []string, rest []byte, ok bool) { func parseNameList(in []byte) (out []string, rest []byte, ok bool) {
contents, rest, ok := parseString(in) contents, rest, ok := parseString(in)
@ -400,6 +403,7 @@ func parseNameList(in []byte) (out []string, rest []byte, ok bool) {
return return
} }
if len(contents) == 0 { if len(contents) == 0 {
out = emptyNameList
return return
} }
parts := bytes.Split(contents, comma) parts := bytes.Split(contents, comma)
@ -444,8 +448,6 @@ func parseUint32(in []byte) (out uint32, rest []byte, ok bool) {
return return
} }
const maxPacketSize = 36000
func nameListLength(namelist []string) int { func nameListLength(namelist []string) int {
length := 4 /* uint32 length prefix */ length := 4 /* uint32 length prefix */
for i, name := range namelist { for i, name := range namelist {

View File

@ -40,6 +40,9 @@ type ServerConfig struct {
// key authentication. It must return true iff the given public key is // key authentication. It must return true iff the given public key is
// valid for the given user. // valid for the given user.
PubKeyCallback func(user, algo string, pubkey []byte) bool PubKeyCallback func(user, algo string, pubkey []byte) bool
// Cryptographic-related configuration.
Crypto CryptoConfig
} }
func (c *ServerConfig) rand() io.Reader { func (c *ServerConfig) rand() io.Reader {
@ -221,7 +224,7 @@ func (s *ServerConn) kexDH(group *dhGroup, hashFunc crypto.Hash, magics *handsha
return nil, nil, errors.New("internal error") return nil, nil, errors.New("internal error")
} }
serializedSig := serializeRSASignature(sig) serializedSig := serializeSignature(hostAlgoRSA, sig)
kexDHReply := kexDHReplyMsg{ kexDHReply := kexDHReplyMsg{
HostKey: serializedHostKey, HostKey: serializedHostKey,
@ -234,50 +237,9 @@ func (s *ServerConn) kexDH(group *dhGroup, hashFunc crypto.Hash, magics *handsha
return return
} }
func serializeRSASignature(sig []byte) []byte {
length := stringLength([]byte(hostAlgoRSA))
length += stringLength(sig)
ret := make([]byte, length)
r := marshalString(ret, []byte(hostAlgoRSA))
r = marshalString(r, sig)
return ret
}
// serverVersion is the fixed identification string that Server will use. // serverVersion is the fixed identification string that Server will use.
var serverVersion = []byte("SSH-2.0-Go\r\n") var serverVersion = []byte("SSH-2.0-Go\r\n")
// buildDataSignedForAuth returns the data that is signed in order to prove
// posession of a private key. See RFC 4252, section 7.
func buildDataSignedForAuth(sessionId []byte, req userAuthRequestMsg, algo, pubKey []byte) []byte {
user := []byte(req.User)
service := []byte(req.Service)
method := []byte(req.Method)
length := stringLength(sessionId)
length += 1
length += stringLength(user)
length += stringLength(service)
length += stringLength(method)
length += 1
length += stringLength(algo)
length += stringLength(pubKey)
ret := make([]byte, length)
r := marshalString(ret, sessionId)
r[0] = msgUserAuthRequest
r = r[1:]
r = marshalString(r, user)
r = marshalString(r, service)
r = marshalString(r, method)
r[0] = 1
r = r[1:]
r = marshalString(r, algo)
r = marshalString(r, pubKey)
return ret
}
// Handshake performs an SSH transport and client authentication on the given ServerConn. // Handshake performs an SSH transport and client authentication on the given ServerConn.
func (s *ServerConn) Handshake() error { func (s *ServerConn) Handshake() error {
var magics handshakeMagics var magics handshakeMagics
@ -298,8 +260,8 @@ func (s *ServerConn) Handshake() error {
serverKexInit := kexInitMsg{ serverKexInit := kexInitMsg{
KexAlgos: supportedKexAlgos, KexAlgos: supportedKexAlgos,
ServerHostKeyAlgos: supportedHostKeyAlgos, ServerHostKeyAlgos: supportedHostKeyAlgos,
CiphersClientServer: supportedCiphers, CiphersClientServer: s.config.Crypto.ciphers(),
CiphersServerClient: supportedCiphers, CiphersServerClient: s.config.Crypto.ciphers(),
MACsClientServer: supportedMACs, MACsClientServer: supportedMACs,
MACsServerClient: supportedMACs, MACsServerClient: supportedMACs,
CompressionClientServer: supportedCompressions, CompressionClientServer: supportedCompressions,
@ -364,7 +326,9 @@ func (s *ServerConn) Handshake() error {
if packet[0] != msgNewKeys { if packet[0] != msgNewKeys {
return UnexpectedMessageError{msgNewKeys, packet[0]} return UnexpectedMessageError{msgNewKeys, packet[0]}
} }
s.transport.reader.setupKeys(clientKeys, K, H, H, hashFunc) if err = s.transport.reader.setupKeys(clientKeys, K, H, H, hashFunc); err != nil {
return err
}
if packet, err = s.readPacket(); err != nil { if packet, err = s.readPacket(); err != nil {
return err return err
} }

146
libgo/go/exp/ssh/tcpip.go Normal file
View File

@ -0,0 +1,146 @@
// Copyright 2011 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package ssh
import (
"errors"
"io"
"net"
)
// Dial initiates a connection to the addr from the remote host.
// addr is resolved using net.ResolveTCPAddr before connection.
// This could allow an observer to observe the DNS name of the
// remote host. Consider using ssh.DialTCP to avoid this.
func (c *ClientConn) Dial(n, addr string) (net.Conn, error) {
raddr, err := net.ResolveTCPAddr(n, addr)
if err != nil {
return nil, err
}
return c.DialTCP(n, nil, raddr)
}
// DialTCP connects to the remote address raddr on the network net,
// which must be "tcp", "tcp4", or "tcp6". If laddr is not nil, it is used
// as the local address for the connection.
func (c *ClientConn) DialTCP(n string, laddr, raddr *net.TCPAddr) (net.Conn, error) {
if laddr == nil {
laddr = &net.TCPAddr{
IP: net.IPv4zero,
Port: 0,
}
}
ch, err := c.dial(laddr.IP.String(), laddr.Port, raddr.IP.String(), raddr.Port)
if err != nil {
return nil, err
}
return &tcpchanconn{
tcpchan: ch,
laddr: laddr,
raddr: raddr,
}, nil
}
// dial opens a direct-tcpip connection to the remote server. laddr and raddr are passed as
// strings and are expected to be resolveable at the remote end.
func (c *ClientConn) dial(laddr string, lport int, raddr string, rport int) (*tcpchan, error) {
// RFC 4254 7.2
type channelOpenDirectMsg struct {
ChanType string
PeersId uint32
PeersWindow uint32
MaxPacketSize uint32
raddr string
rport uint32
laddr string
lport uint32
}
ch := c.newChan(c.transport)
if err := c.writePacket(marshal(msgChannelOpen, channelOpenDirectMsg{
ChanType: "direct-tcpip",
PeersId: ch.id,
PeersWindow: 1 << 14,
MaxPacketSize: 1 << 15, // RFC 4253 6.1
raddr: raddr,
rport: uint32(rport),
laddr: laddr,
lport: uint32(lport),
})); err != nil {
c.chanlist.remove(ch.id)
return nil, err
}
// wait for response
switch msg := (<-ch.msg).(type) {
case *channelOpenConfirmMsg:
ch.peersId = msg.MyId
ch.win <- int(msg.MyWindow)
case *channelOpenFailureMsg:
c.chanlist.remove(ch.id)
return nil, errors.New("ssh: error opening remote TCP connection: " + msg.Message)
default:
c.chanlist.remove(ch.id)
return nil, errors.New("ssh: unexpected packet")
}
return &tcpchan{
clientChan: ch,
Reader: &chanReader{
packetWriter: ch,
id: ch.id,
data: ch.data,
},
Writer: &chanWriter{
packetWriter: ch,
id: ch.id,
win: ch.win,
},
}, nil
}
type tcpchan struct {
*clientChan // the backing channel
io.Reader
io.Writer
}
// tcpchanconn fulfills the net.Conn interface without
// the tcpchan having to hold laddr or raddr directly.
type tcpchanconn struct {
*tcpchan
laddr, raddr net.Addr
}
// LocalAddr returns the local network address.
func (t *tcpchanconn) LocalAddr() net.Addr {
return t.laddr
}
// RemoteAddr returns the remote network address.
func (t *tcpchanconn) RemoteAddr() net.Addr {
return t.raddr
}
// SetTimeout sets the read and write deadlines associated
// with the connection.
func (t *tcpchanconn) SetTimeout(nsec int64) error {
if err := t.SetReadTimeout(nsec); err != nil {
return err
}
return t.SetWriteTimeout(nsec)
}
// SetReadTimeout sets the time (in nanoseconds) that
// Read will wait for data before returning an error with Timeout() == true.
// Setting nsec == 0 (the default) disables the deadline.
func (t *tcpchanconn) SetReadTimeout(nsec int64) error {
return errors.New("ssh: tcpchan: timeout not supported")
}
// SetWriteTimeout sets the time (in nanoseconds) that
// Write will wait to send its data before returning an error with Timeout() == true.
// Setting nsec == 0 (the default) disables the deadline.
// Even if write times out, it may return n > 0, indicating that
// some of the data was successfully written.
func (t *tcpchanconn) SetWriteTimeout(nsec int64) error {
return errors.New("ssh: tcpchan: timeout not supported")
}

View File

@ -7,7 +7,6 @@ package ssh
import ( import (
"bufio" "bufio"
"crypto" "crypto"
"crypto/aes"
"crypto/cipher" "crypto/cipher"
"crypto/hmac" "crypto/hmac"
"crypto/subtle" "crypto/subtle"
@ -19,7 +18,10 @@ import (
) )
const ( const (
paddingMultiple = 16 // TODO(dfc) does this need to be configurable? packetSizeMultiple = 16 // TODO(huin) this should be determined by the cipher.
minPacketSize = 16
maxPacketSize = 36000
minPaddingSize = 4 // TODO(huin) should this be configurable?
) )
// filteredConn reduces the set of methods exposed when embeddeding // filteredConn reduces the set of methods exposed when embeddeding
@ -61,7 +63,6 @@ type reader struct {
type writer struct { type writer struct {
*sync.Mutex // protects writer.Writer from concurrent writes *sync.Mutex // protects writer.Writer from concurrent writes
*bufio.Writer *bufio.Writer
paddingMultiple int
rand io.Reader rand io.Reader
common common
} }
@ -82,14 +83,11 @@ type common struct {
func (r *reader) readOnePacket() ([]byte, error) { func (r *reader) readOnePacket() ([]byte, error) {
var lengthBytes = make([]byte, 5) var lengthBytes = make([]byte, 5)
var macSize uint32 var macSize uint32
if _, err := io.ReadFull(r, lengthBytes); err != nil { if _, err := io.ReadFull(r, lengthBytes); err != nil {
return nil, err return nil, err
} }
if r.cipher != nil {
r.cipher.XORKeyStream(lengthBytes, lengthBytes) r.cipher.XORKeyStream(lengthBytes, lengthBytes)
}
if r.mac != nil { if r.mac != nil {
r.mac.Reset() r.mac.Reset()
@ -153,9 +151,9 @@ func (w *writer) writePacket(packet []byte) error {
w.Mutex.Lock() w.Mutex.Lock()
defer w.Mutex.Unlock() defer w.Mutex.Unlock()
paddingLength := paddingMultiple - (5+len(packet))%paddingMultiple paddingLength := packetSizeMultiple - (5+len(packet))%packetSizeMultiple
if paddingLength < 4 { if paddingLength < 4 {
paddingLength += paddingMultiple paddingLength += packetSizeMultiple
} }
length := len(packet) + 1 + paddingLength length := len(packet) + 1 + paddingLength
@ -188,11 +186,9 @@ func (w *writer) writePacket(packet []byte) error {
// TODO(dfc) lengthBytes, packet and padding should be // TODO(dfc) lengthBytes, packet and padding should be
// subslices of a single buffer // subslices of a single buffer
if w.cipher != nil {
w.cipher.XORKeyStream(lengthBytes, lengthBytes) w.cipher.XORKeyStream(lengthBytes, lengthBytes)
w.cipher.XORKeyStream(packet, packet) w.cipher.XORKeyStream(packet, packet)
w.cipher.XORKeyStream(padding, padding) w.cipher.XORKeyStream(padding, padding)
}
if _, err := w.Write(lengthBytes); err != nil { if _, err := w.Write(lengthBytes); err != nil {
return err return err
@ -227,11 +223,17 @@ func newTransport(conn net.Conn, rand io.Reader) *transport {
return &transport{ return &transport{
reader: reader{ reader: reader{
Reader: bufio.NewReader(conn), Reader: bufio.NewReader(conn),
common: common{
cipher: noneCipher{},
},
}, },
writer: writer{ writer: writer{
Writer: bufio.NewWriter(conn), Writer: bufio.NewWriter(conn),
rand: rand, rand: rand,
Mutex: new(sync.Mutex), Mutex: new(sync.Mutex),
common: common{
cipher: noneCipher{},
},
}, },
filteredConn: conn, filteredConn: conn,
} }
@ -249,29 +251,32 @@ var (
clientKeys = direction{[]byte{'A'}, []byte{'C'}, []byte{'E'}} clientKeys = direction{[]byte{'A'}, []byte{'C'}, []byte{'E'}}
) )
// setupKeys sets the cipher and MAC keys from K, H and sessionId, as // setupKeys sets the cipher and MAC keys from kex.K, kex.H and sessionId, as
// described in RFC 4253, section 6.4. direction should either be serverKeys // described in RFC 4253, section 6.4. direction should either be serverKeys
// (to setup server->client keys) or clientKeys (for client->server keys). // (to setup server->client keys) or clientKeys (for client->server keys).
func (c *common) setupKeys(d direction, K, H, sessionId []byte, hashFunc crypto.Hash) error { func (c *common) setupKeys(d direction, K, H, sessionId []byte, hashFunc crypto.Hash) error {
h := hashFunc.New() cipherMode := cipherModes[c.cipherAlgo]
blockSize := 16
keySize := 16
macKeySize := 20 macKeySize := 20
iv := make([]byte, blockSize) iv := make([]byte, cipherMode.ivSize)
key := make([]byte, keySize) key := make([]byte, cipherMode.keySize)
macKey := make([]byte, macKeySize) macKey := make([]byte, macKeySize)
h := hashFunc.New()
generateKeyMaterial(iv, d.ivTag, K, H, sessionId, h) generateKeyMaterial(iv, d.ivTag, K, H, sessionId, h)
generateKeyMaterial(key, d.keyTag, K, H, sessionId, h) generateKeyMaterial(key, d.keyTag, K, H, sessionId, h)
generateKeyMaterial(macKey, d.macKeyTag, K, H, sessionId, h) generateKeyMaterial(macKey, d.macKeyTag, K, H, sessionId, h)
c.mac = truncatingMAC{12, hmac.NewSHA1(macKey)} c.mac = truncatingMAC{12, hmac.NewSHA1(macKey)}
aes, err := aes.NewCipher(key)
cipher, err := cipherMode.createCipher(key, iv)
if err != nil { if err != nil {
return err return err
} }
c.cipher = cipher.NewCTR(aes, iv)
c.cipher = cipher
return nil return nil
} }

View File

@ -1,356 +0,0 @@
// Copyright 2011 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package terminal
import "io"
// Shell contains the state for running a VT100 terminal that is capable of
// reading lines of input.
type Shell struct {
c io.ReadWriter
prompt string
// line is the current line being entered.
line []byte
// pos is the logical position of the cursor in line
pos int
// cursorX contains the current X value of the cursor where the left
// edge is 0. cursorY contains the row number where the first row of
// the current line is 0.
cursorX, cursorY int
// maxLine is the greatest value of cursorY so far.
maxLine int
termWidth, termHeight int
// outBuf contains the terminal data to be sent.
outBuf []byte
// remainder contains the remainder of any partial key sequences after
// a read. It aliases into inBuf.
remainder []byte
inBuf [256]byte
}
// NewShell runs a VT100 terminal on the given ReadWriter. If the ReadWriter is
// a local terminal, that terminal must first have been put into raw mode.
// prompt is a string that is written at the start of each input line (i.e.
// "> ").
func NewShell(c io.ReadWriter, prompt string) *Shell {
return &Shell{
c: c,
prompt: prompt,
termWidth: 80,
termHeight: 24,
}
}
const (
keyCtrlD = 4
keyEnter = '\r'
keyEscape = 27
keyBackspace = 127
keyUnknown = 256 + iota
keyUp
keyDown
keyLeft
keyRight
keyAltLeft
keyAltRight
)
// bytesToKey tries to parse a key sequence from b. If successful, it returns
// the key and the remainder of the input. Otherwise it returns -1.
func bytesToKey(b []byte) (int, []byte) {
if len(b) == 0 {
return -1, nil
}
if b[0] != keyEscape {
return int(b[0]), b[1:]
}
if len(b) >= 3 && b[0] == keyEscape && b[1] == '[' {
switch b[2] {
case 'A':
return keyUp, b[3:]
case 'B':
return keyDown, b[3:]
case 'C':
return keyRight, b[3:]
case 'D':
return keyLeft, b[3:]
}
}
if len(b) >= 6 && b[0] == keyEscape && b[1] == '[' && b[2] == '1' && b[3] == ';' && b[4] == '3' {
switch b[5] {
case 'C':
return keyAltRight, b[6:]
case 'D':
return keyAltLeft, b[6:]
}
}
// If we get here then we have a key that we don't recognise, or a
// partial sequence. It's not clear how one should find the end of a
// sequence without knowing them all, but it seems that [a-zA-Z] only
// appears at the end of a sequence.
for i, c := range b[0:] {
if c >= 'a' && c <= 'z' || c >= 'A' && c <= 'Z' {
return keyUnknown, b[i+1:]
}
}
return -1, b
}
// queue appends data to the end of ss.outBuf
func (ss *Shell) queue(data []byte) {
if len(ss.outBuf)+len(data) > cap(ss.outBuf) {
newOutBuf := make([]byte, len(ss.outBuf), 2*(len(ss.outBuf)+len(data)))
copy(newOutBuf, ss.outBuf)
ss.outBuf = newOutBuf
}
oldLen := len(ss.outBuf)
ss.outBuf = ss.outBuf[:len(ss.outBuf)+len(data)]
copy(ss.outBuf[oldLen:], data)
}
var eraseUnderCursor = []byte{' ', keyEscape, '[', 'D'}
func isPrintable(key int) bool {
return key >= 32 && key < 127
}
// moveCursorToPos appends data to ss.outBuf which will move the cursor to the
// given, logical position in the text.
func (ss *Shell) moveCursorToPos(pos int) {
x := len(ss.prompt) + pos
y := x / ss.termWidth
x = x % ss.termWidth
up := 0
if y < ss.cursorY {
up = ss.cursorY - y
}
down := 0
if y > ss.cursorY {
down = y - ss.cursorY
}
left := 0
if x < ss.cursorX {
left = ss.cursorX - x
}
right := 0
if x > ss.cursorX {
right = x - ss.cursorX
}
movement := make([]byte, 3*(up+down+left+right))
m := movement
for i := 0; i < up; i++ {
m[0] = keyEscape
m[1] = '['
m[2] = 'A'
m = m[3:]
}
for i := 0; i < down; i++ {
m[0] = keyEscape
m[1] = '['
m[2] = 'B'
m = m[3:]
}
for i := 0; i < left; i++ {
m[0] = keyEscape
m[1] = '['
m[2] = 'D'
m = m[3:]
}
for i := 0; i < right; i++ {
m[0] = keyEscape
m[1] = '['
m[2] = 'C'
m = m[3:]
}
ss.cursorX = x
ss.cursorY = y
ss.queue(movement)
}
const maxLineLength = 4096
// handleKey processes the given key and, optionally, returns a line of text
// that the user has entered.
func (ss *Shell) handleKey(key int) (line string, ok bool) {
switch key {
case keyBackspace:
if ss.pos == 0 {
return
}
ss.pos--
copy(ss.line[ss.pos:], ss.line[1+ss.pos:])
ss.line = ss.line[:len(ss.line)-1]
ss.writeLine(ss.line[ss.pos:])
ss.moveCursorToPos(ss.pos)
ss.queue(eraseUnderCursor)
case keyAltLeft:
// move left by a word.
if ss.pos == 0 {
return
}
ss.pos--
for ss.pos > 0 {
if ss.line[ss.pos] != ' ' {
break
}
ss.pos--
}
for ss.pos > 0 {
if ss.line[ss.pos] == ' ' {
ss.pos++
break
}
ss.pos--
}
ss.moveCursorToPos(ss.pos)
case keyAltRight:
// move right by a word.
for ss.pos < len(ss.line) {
if ss.line[ss.pos] == ' ' {
break
}
ss.pos++
}
for ss.pos < len(ss.line) {
if ss.line[ss.pos] != ' ' {
break
}
ss.pos++
}
ss.moveCursorToPos(ss.pos)
case keyLeft:
if ss.pos == 0 {
return
}
ss.pos--
ss.moveCursorToPos(ss.pos)
case keyRight:
if ss.pos == len(ss.line) {
return
}
ss.pos++
ss.moveCursorToPos(ss.pos)
case keyEnter:
ss.moveCursorToPos(len(ss.line))
ss.queue([]byte("\r\n"))
line = string(ss.line)
ok = true
ss.line = ss.line[:0]
ss.pos = 0
ss.cursorX = 0
ss.cursorY = 0
ss.maxLine = 0
default:
if !isPrintable(key) {
return
}
if len(ss.line) == maxLineLength {
return
}
if len(ss.line) == cap(ss.line) {
newLine := make([]byte, len(ss.line), 2*(1+len(ss.line)))
copy(newLine, ss.line)
ss.line = newLine
}
ss.line = ss.line[:len(ss.line)+1]
copy(ss.line[ss.pos+1:], ss.line[ss.pos:])
ss.line[ss.pos] = byte(key)
ss.writeLine(ss.line[ss.pos:])
ss.pos++
ss.moveCursorToPos(ss.pos)
}
return
}
func (ss *Shell) writeLine(line []byte) {
for len(line) != 0 {
if ss.cursorX == ss.termWidth {
ss.queue([]byte("\r\n"))
ss.cursorX = 0
ss.cursorY++
if ss.cursorY > ss.maxLine {
ss.maxLine = ss.cursorY
}
}
remainingOnLine := ss.termWidth - ss.cursorX
todo := len(line)
if todo > remainingOnLine {
todo = remainingOnLine
}
ss.queue(line[:todo])
ss.cursorX += todo
line = line[todo:]
}
}
func (ss *Shell) Write(buf []byte) (n int, err error) {
return ss.c.Write(buf)
}
// ReadLine returns a line of input from the terminal.
func (ss *Shell) ReadLine() (line string, err error) {
ss.writeLine([]byte(ss.prompt))
ss.c.Write(ss.outBuf)
ss.outBuf = ss.outBuf[:0]
for {
// ss.remainder is a slice at the beginning of ss.inBuf
// containing a partial key sequence
readBuf := ss.inBuf[len(ss.remainder):]
var n int
n, err = ss.c.Read(readBuf)
if err != nil {
return
}
if err == nil {
ss.remainder = ss.inBuf[:n+len(ss.remainder)]
rest := ss.remainder
lineOk := false
for !lineOk {
var key int
key, rest = bytesToKey(rest)
if key < 0 {
break
}
if key == keyCtrlD {
return "", io.EOF
}
line, lineOk = ss.handleKey(key)
}
if len(rest) > 0 {
n := copy(ss.inBuf[:], rest)
ss.remainder = ss.inBuf[:n]
} else {
ss.remainder = nil
}
ss.c.Write(ss.outBuf)
ss.outBuf = ss.outBuf[:0]
if lineOk {
return
}
continue
}
}
panic("unreachable")
}

View File

@ -2,102 +2,361 @@
// Use of this source code is governed by a BSD-style // Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file. // license that can be found in the LICENSE file.
// Package terminal provides support functions for dealing with terminals, as
// commonly found on UNIX systems.
//
// Putting a terminal into raw mode is the most common requirement:
//
// oldState, err := terminal.MakeRaw(0)
// if err != nil {
// panic(err.String())
// }
// defer terminal.Restore(0, oldState)
package terminal package terminal
import ( import "io"
"io"
"os" // Terminal contains the state for running a VT100 terminal that is capable of
"syscall" // reading lines of input.
type Terminal struct {
c io.ReadWriter
prompt string
// line is the current line being entered.
line []byte
// pos is the logical position of the cursor in line
pos int
// cursorX contains the current X value of the cursor where the left
// edge is 0. cursorY contains the row number where the first row of
// the current line is 0.
cursorX, cursorY int
// maxLine is the greatest value of cursorY so far.
maxLine int
termWidth, termHeight int
// outBuf contains the terminal data to be sent.
outBuf []byte
// remainder contains the remainder of any partial key sequences after
// a read. It aliases into inBuf.
remainder []byte
inBuf [256]byte
}
// NewTerminal runs a VT100 terminal on the given ReadWriter. If the ReadWriter is
// a local terminal, that terminal must first have been put into raw mode.
// prompt is a string that is written at the start of each input line (i.e.
// "> ").
func NewTerminal(c io.ReadWriter, prompt string) *Terminal {
return &Terminal{
c: c,
prompt: prompt,
termWidth: 80,
termHeight: 24,
}
}
const (
keyCtrlD = 4
keyEnter = '\r'
keyEscape = 27
keyBackspace = 127
keyUnknown = 256 + iota
keyUp
keyDown
keyLeft
keyRight
keyAltLeft
keyAltRight
) )
// State contains the state of a terminal. // bytesToKey tries to parse a key sequence from b. If successful, it returns
type State struct { // the key and the remainder of the input. Otherwise it returns -1.
termios syscall.Termios func bytesToKey(b []byte) (int, []byte) {
if len(b) == 0 {
return -1, nil
} }
// IsTerminal returns true if the given file descriptor is a terminal. if b[0] != keyEscape {
func IsTerminal(fd int) bool { return int(b[0]), b[1:]
var termios syscall.Termios
e := syscall.Tcgetattr(fd, &termios)
return e == 0
} }
// MakeRaw put the terminal connected to the given file descriptor into raw if len(b) >= 3 && b[0] == keyEscape && b[1] == '[' {
// mode and returns the previous state of the terminal so that it can be switch b[2] {
// restored. case 'A':
func MakeRaw(fd int) (*State, error) { return keyUp, b[3:]
var oldState State case 'B':
if e := syscall.Tcgetattr(fd, &oldState.termios); e != 0 { return keyDown, b[3:]
return nil, os.Errno(e) case 'C':
return keyRight, b[3:]
case 'D':
return keyLeft, b[3:]
}
} }
newState := oldState.termios if len(b) >= 6 && b[0] == keyEscape && b[1] == '[' && b[2] == '1' && b[3] == ';' && b[4] == '3' {
newState.Iflag &^= syscall.ISTRIP | syscall.INLCR | syscall.ICRNL | syscall.IGNCR | syscall.IXON | syscall.IXOFF switch b[5] {
newState.Lflag &^= syscall.ECHO | syscall.ICANON | syscall.ISIG case 'C':
if e := syscall.Tcsetattr(fd, syscall.TCSANOW, &newState); e != 0 { return keyAltRight, b[6:]
return nil, os.Errno(e) case 'D':
return keyAltLeft, b[6:]
}
} }
return &oldState, nil // If we get here then we have a key that we don't recognise, or a
// partial sequence. It's not clear how one should find the end of a
// sequence without knowing them all, but it seems that [a-zA-Z] only
// appears at the end of a sequence.
for i, c := range b[0:] {
if c >= 'a' && c <= 'z' || c >= 'A' && c <= 'Z' {
return keyUnknown, b[i+1:]
}
} }
// Restore restores the terminal connected to the given file descriptor to a return -1, b
// previous state.
func Restore(fd int, state *State) error {
e := syscall.Tcsetattr(fd, syscall.TCSANOW, &state.termios)
return os.Errno(e)
} }
// ReadPassword reads a line of input from a terminal without local echo. This // queue appends data to the end of t.outBuf
// is commonly used for inputting passwords and other sensitive data. The slice func (t *Terminal) queue(data []byte) {
// returned does not include the \n. if len(t.outBuf)+len(data) > cap(t.outBuf) {
func ReadPassword(fd int) ([]byte, error) { newOutBuf := make([]byte, len(t.outBuf), 2*(len(t.outBuf)+len(data)))
var oldState syscall.Termios copy(newOutBuf, t.outBuf)
if e := syscall.Tcgetattr(fd, &oldState); e != 0 { t.outBuf = newOutBuf
return nil, os.Errno(e)
} }
newState := oldState oldLen := len(t.outBuf)
newState.Lflag &^= syscall.ECHO t.outBuf = t.outBuf[:len(t.outBuf)+len(data)]
if e := syscall.Tcsetattr(fd, syscall.TCSANOW, &newState); e != 0 { copy(t.outBuf[oldLen:], data)
return nil, os.Errno(e)
} }
defer func() { var eraseUnderCursor = []byte{' ', keyEscape, '[', 'D'}
syscall.Tcsetattr(fd, syscall.TCSANOW, &oldState)
}() func isPrintable(key int) bool {
return key >= 32 && key < 127
}
// moveCursorToPos appends data to t.outBuf which will move the cursor to the
// given, logical position in the text.
func (t *Terminal) moveCursorToPos(pos int) {
x := len(t.prompt) + pos
y := x / t.termWidth
x = x % t.termWidth
up := 0
if y < t.cursorY {
up = t.cursorY - y
}
down := 0
if y > t.cursorY {
down = y - t.cursorY
}
left := 0
if x < t.cursorX {
left = t.cursorX - x
}
right := 0
if x > t.cursorX {
right = x - t.cursorX
}
movement := make([]byte, 3*(up+down+left+right))
m := movement
for i := 0; i < up; i++ {
m[0] = keyEscape
m[1] = '['
m[2] = 'A'
m = m[3:]
}
for i := 0; i < down; i++ {
m[0] = keyEscape
m[1] = '['
m[2] = 'B'
m = m[3:]
}
for i := 0; i < left; i++ {
m[0] = keyEscape
m[1] = '['
m[2] = 'D'
m = m[3:]
}
for i := 0; i < right; i++ {
m[0] = keyEscape
m[1] = '['
m[2] = 'C'
m = m[3:]
}
t.cursorX = x
t.cursorY = y
t.queue(movement)
}
const maxLineLength = 4096
// handleKey processes the given key and, optionally, returns a line of text
// that the user has entered.
func (t *Terminal) handleKey(key int) (line string, ok bool) {
switch key {
case keyBackspace:
if t.pos == 0 {
return
}
t.pos--
copy(t.line[t.pos:], t.line[1+t.pos:])
t.line = t.line[:len(t.line)-1]
t.writeLine(t.line[t.pos:])
t.moveCursorToPos(t.pos)
t.queue(eraseUnderCursor)
case keyAltLeft:
// move left by a word.
if t.pos == 0 {
return
}
t.pos--
for t.pos > 0 {
if t.line[t.pos] != ' ' {
break
}
t.pos--
}
for t.pos > 0 {
if t.line[t.pos] == ' ' {
t.pos++
break
}
t.pos--
}
t.moveCursorToPos(t.pos)
case keyAltRight:
// move right by a word.
for t.pos < len(t.line) {
if t.line[t.pos] == ' ' {
break
}
t.pos++
}
for t.pos < len(t.line) {
if t.line[t.pos] != ' ' {
break
}
t.pos++
}
t.moveCursorToPos(t.pos)
case keyLeft:
if t.pos == 0 {
return
}
t.pos--
t.moveCursorToPos(t.pos)
case keyRight:
if t.pos == len(t.line) {
return
}
t.pos++
t.moveCursorToPos(t.pos)
case keyEnter:
t.moveCursorToPos(len(t.line))
t.queue([]byte("\r\n"))
line = string(t.line)
ok = true
t.line = t.line[:0]
t.pos = 0
t.cursorX = 0
t.cursorY = 0
t.maxLine = 0
default:
if !isPrintable(key) {
return
}
if len(t.line) == maxLineLength {
return
}
if len(t.line) == cap(t.line) {
newLine := make([]byte, len(t.line), 2*(1+len(t.line)))
copy(newLine, t.line)
t.line = newLine
}
t.line = t.line[:len(t.line)+1]
copy(t.line[t.pos+1:], t.line[t.pos:])
t.line[t.pos] = byte(key)
t.writeLine(t.line[t.pos:])
t.pos++
t.moveCursorToPos(t.pos)
}
return
}
func (t *Terminal) writeLine(line []byte) {
for len(line) != 0 {
if t.cursorX == t.termWidth {
t.queue([]byte("\r\n"))
t.cursorX = 0
t.cursorY++
if t.cursorY > t.maxLine {
t.maxLine = t.cursorY
}
}
remainingOnLine := t.termWidth - t.cursorX
todo := len(line)
if todo > remainingOnLine {
todo = remainingOnLine
}
t.queue(line[:todo])
t.cursorX += todo
line = line[todo:]
}
}
func (t *Terminal) Write(buf []byte) (n int, err error) {
return t.c.Write(buf)
}
// ReadLine returns a line of input from the terminal.
func (t *Terminal) ReadLine() (line string, err error) {
if t.cursorX == 0 {
t.writeLine([]byte(t.prompt))
t.c.Write(t.outBuf)
t.outBuf = t.outBuf[:0]
}
var buf [16]byte
var ret []byte
for { for {
n, errno := syscall.Read(fd, buf[:]) // t.remainder is a slice at the beginning of t.inBuf
if errno != 0 { // containing a partial key sequence
return nil, os.Errno(errno) readBuf := t.inBuf[len(t.remainder):]
} var n int
if n == 0 { n, err = t.c.Read(readBuf)
if len(ret) == 0 { if err != nil {
return nil, io.EOF return
}
break
}
if buf[n-1] == '\n' {
n--
}
ret = append(ret, buf[:n]...)
if n < len(buf) {
break
}
} }
return ret, nil if err == nil {
t.remainder = t.inBuf[:n+len(t.remainder)]
rest := t.remainder
lineOk := false
for !lineOk {
var key int
key, rest = bytesToKey(rest)
if key < 0 {
break
}
if key == keyCtrlD {
return "", io.EOF
}
line, lineOk = t.handleKey(key)
}
if len(rest) > 0 {
n := copy(t.inBuf[:], rest)
t.remainder = t.inBuf[:n]
} else {
t.remainder = nil
}
t.c.Write(t.outBuf)
t.outBuf = t.outBuf[:0]
if lineOk {
return
}
continue
}
}
panic("unreachable")
}
func (t *Terminal) SetSize(width, height int) {
t.termWidth, t.termHeight = width, height
} }

View File

@ -41,7 +41,7 @@ func (c *MockTerminal) Write(data []byte) (n int, err error) {
func TestClose(t *testing.T) { func TestClose(t *testing.T) {
c := &MockTerminal{} c := &MockTerminal{}
ss := NewShell(c, "> ") ss := NewTerminal(c, "> ")
line, err := ss.ReadLine() line, err := ss.ReadLine()
if line != "" { if line != "" {
t.Errorf("Expected empty line but got: %s", line) t.Errorf("Expected empty line but got: %s", line)
@ -95,7 +95,7 @@ func TestKeyPresses(t *testing.T) {
toSend: []byte(test.in), toSend: []byte(test.in),
bytesPerRead: j, bytesPerRead: j,
} }
ss := NewShell(c, "> ") ss := NewTerminal(c, "> ")
line, err := ss.ReadLine() line, err := ss.ReadLine()
if line != test.line { if line != test.line {
t.Errorf("Line resulting from test %d (%d bytes per read) was '%s', expected '%s'", i, j, line, test.line) t.Errorf("Line resulting from test %d (%d bytes per read) was '%s', expected '%s'", i, j, line, test.line)

View File

@ -0,0 +1,102 @@
// Copyright 2011 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
// Package terminal provides support functions for dealing with terminals, as
// commonly found on UNIX systems.
//
// Putting a terminal into raw mode is the most common requirement:
//
// oldState, err := terminal.MakeRaw(0)
// if err != nil {
// panic(err.String())
// }
// defer terminal.Restore(0, oldState)
package terminal
import (
"io"
"syscall"
)
// State contains the state of a terminal.
type State struct {
termios syscall.Termios
}
// IsTerminal returns true if the given file descriptor is a terminal.
func IsTerminal(fd int) bool {
var termios syscall.Termios
err := syscall.Tcgetattr(fd, &termios)
return err == nil
}
// MakeRaw put the terminal connected to the given file descriptor into raw
// mode and returns the previous state of the terminal so that it can be
// restored.
func MakeRaw(fd int) (*State, error) {
var oldState State
if err := syscall.Tcgetattr(fd, &oldState.termios); err != nil {
return nil, err
}
newState := oldState.termios
newState.Iflag &^= syscall.ISTRIP | syscall.INLCR | syscall.ICRNL | syscall.IGNCR | syscall.IXON | syscall.IXOFF
newState.Lflag &^= syscall.ECHO | syscall.ICANON | syscall.ISIG
if err := syscall.Tcsetattr(fd, syscall.TCSANOW, &newState); err != nil {
return nil, err
}
return &oldState, nil
}
// Restore restores the terminal connected to the given file descriptor to a
// previous state.
func Restore(fd int, state *State) error {
err := syscall.Tcsetattr(fd, syscall.TCSANOW, &state.termios)
return err
}
// ReadPassword reads a line of input from a terminal without local echo. This
// is commonly used for inputting passwords and other sensitive data. The slice
// returned does not include the \n.
func ReadPassword(fd int) ([]byte, error) {
var oldState syscall.Termios
if err := syscall.Tcgetattr(fd, &oldState); err != nil {
return nil, err
}
newState := oldState
newState.Lflag &^= syscall.ECHO
if err := syscall.Tcsetattr(fd, syscall.TCSANOW, &newState); err != nil {
return nil, err
}
defer func() {
syscall.Tcsetattr(fd, syscall.TCSANOW, &oldState)
}()
var buf [16]byte
var ret []byte
for {
n, err := syscall.Read(fd, buf[:])
if err != nil {
return nil, err
}
if n == 0 {
if len(ret) == 0 {
return nil, io.EOF
}
break
}
if buf[n-1] == '\n' {
n--
}
ret = append(ret, buf[:n]...)
if n < len(buf) {
break
}
}
return ret, nil
}

View File

@ -357,6 +357,10 @@ var fmttests = []struct {
{"%#v", map[string]B{"a": {1, 2}}, `map[string] fmt_test.B{"a":fmt_test.B{I:1, j:2}}`}, {"%#v", map[string]B{"a": {1, 2}}, `map[string] fmt_test.B{"a":fmt_test.B{I:1, j:2}}`},
{"%#v", []string{"a", "b"}, `[]string{"a", "b"}`}, {"%#v", []string{"a", "b"}, `[]string{"a", "b"}`},
{"%#v", SI{}, `fmt_test.SI{I:interface {}(nil)}`}, {"%#v", SI{}, `fmt_test.SI{I:interface {}(nil)}`},
{"%#v", []int(nil), `[]int(nil)`},
{"%#v", []int{}, `[]int{}`},
{"%#v", map[int]byte(nil), `map[int] uint8(nil)`},
{"%#v", map[int]byte{}, `map[int] uint8{}`},
// slices with other formats // slices with other formats
{"%#x", []int{1, 2, 15}, `[0x1 0x2 0xf]`}, {"%#x", []int{1, 2, 15}, `[0x1 0x2 0xf]`},

View File

@ -795,6 +795,10 @@ BigSwitch:
case reflect.Map: case reflect.Map:
if goSyntax { if goSyntax {
p.buf.WriteString(f.Type().String()) p.buf.WriteString(f.Type().String())
if f.IsNil() {
p.buf.WriteString("(nil)")
break
}
p.buf.WriteByte('{') p.buf.WriteByte('{')
} else { } else {
p.buf.Write(mapBytes) p.buf.Write(mapBytes)
@ -873,6 +877,10 @@ BigSwitch:
} }
if goSyntax { if goSyntax {
p.buf.WriteString(value.Type().String()) p.buf.WriteString(value.Type().String())
if f.IsNil() {
p.buf.WriteString("(nil)")
break
}
p.buf.WriteByte('{') p.buf.WriteByte('{')
} else { } else {
p.buf.WriteByte('[') p.buf.WriteByte('[')

View File

@ -324,7 +324,7 @@ var x, y Xs
var z IntString var z IntString
var multiTests = []ScanfMultiTest{ var multiTests = []ScanfMultiTest{
{"", "", nil, nil, ""}, {"", "", []interface{}{}, []interface{}{}, ""},
{"%d", "23", args(&i), args(23), ""}, {"%d", "23", args(&i), args(23), ""},
{"%2s%3s", "22333", args(&s, &t), args("22", "333"), ""}, {"%2s%3s", "22333", args(&s, &t), args("22", "333"), ""},
{"%2d%3d", "44555", args(&i, &j), args(44, 555), ""}, {"%2d%3d", "44555", args(&i, &j), args(44, 555), ""},
@ -378,7 +378,7 @@ func testScan(name string, t *testing.T, scan func(r io.Reader, a ...interface{}
} }
val := v.Interface() val := v.Interface()
if !reflect.DeepEqual(val, test.out) { if !reflect.DeepEqual(val, test.out) {
t.Errorf("%s scanning %q: expected %v got %v, type %T", name, test.text, test.out, val, val) t.Errorf("%s scanning %q: expected %#v got %#v, type %T", name, test.text, test.out, val, val)
} }
} }
} }
@ -417,7 +417,7 @@ func TestScanf(t *testing.T) {
} }
val := v.Interface() val := v.Interface()
if !reflect.DeepEqual(val, test.out) { if !reflect.DeepEqual(val, test.out) {
t.Errorf("scanning (%q, %q): expected %v got %v, type %T", test.format, test.text, test.out, val, val) t.Errorf("scanning (%q, %q): expected %#v got %#v, type %T", test.format, test.text, test.out, val, val)
} }
} }
} }
@ -520,7 +520,7 @@ func testScanfMulti(name string, t *testing.T) {
} }
result := resultVal.Interface() result := resultVal.Interface()
if !reflect.DeepEqual(result, test.out) { if !reflect.DeepEqual(result, test.out) {
t.Errorf("scanning (%q, %q): expected %v got %v", test.format, test.text, test.out, result) t.Errorf("scanning (%q, %q): expected %#v got %#v", test.format, test.text, test.out, result)
} }
} }
} }

View File

@ -412,29 +412,29 @@ func (x *ChanType) End() token.Pos { return x.Value.End() }
// exprNode() ensures that only expression/type nodes can be // exprNode() ensures that only expression/type nodes can be
// assigned to an ExprNode. // assigned to an ExprNode.
// //
func (x *BadExpr) exprNode() {} func (*BadExpr) exprNode() {}
func (x *Ident) exprNode() {} func (*Ident) exprNode() {}
func (x *Ellipsis) exprNode() {} func (*Ellipsis) exprNode() {}
func (x *BasicLit) exprNode() {} func (*BasicLit) exprNode() {}
func (x *FuncLit) exprNode() {} func (*FuncLit) exprNode() {}
func (x *CompositeLit) exprNode() {} func (*CompositeLit) exprNode() {}
func (x *ParenExpr) exprNode() {} func (*ParenExpr) exprNode() {}
func (x *SelectorExpr) exprNode() {} func (*SelectorExpr) exprNode() {}
func (x *IndexExpr) exprNode() {} func (*IndexExpr) exprNode() {}
func (x *SliceExpr) exprNode() {} func (*SliceExpr) exprNode() {}
func (x *TypeAssertExpr) exprNode() {} func (*TypeAssertExpr) exprNode() {}
func (x *CallExpr) exprNode() {} func (*CallExpr) exprNode() {}
func (x *StarExpr) exprNode() {} func (*StarExpr) exprNode() {}
func (x *UnaryExpr) exprNode() {} func (*UnaryExpr) exprNode() {}
func (x *BinaryExpr) exprNode() {} func (*BinaryExpr) exprNode() {}
func (x *KeyValueExpr) exprNode() {} func (*KeyValueExpr) exprNode() {}
func (x *ArrayType) exprNode() {} func (*ArrayType) exprNode() {}
func (x *StructType) exprNode() {} func (*StructType) exprNode() {}
func (x *FuncType) exprNode() {} func (*FuncType) exprNode() {}
func (x *InterfaceType) exprNode() {} func (*InterfaceType) exprNode() {}
func (x *MapType) exprNode() {} func (*MapType) exprNode() {}
func (x *ChanType) exprNode() {} func (*ChanType) exprNode() {}
// ---------------------------------------------------------------------------- // ----------------------------------------------------------------------------
// Convenience functions for Idents // Convenience functions for Idents
@ -711,27 +711,27 @@ func (s *RangeStmt) End() token.Pos { return s.Body.End() }
// stmtNode() ensures that only statement nodes can be // stmtNode() ensures that only statement nodes can be
// assigned to a StmtNode. // assigned to a StmtNode.
// //
func (s *BadStmt) stmtNode() {} func (*BadStmt) stmtNode() {}
func (s *DeclStmt) stmtNode() {} func (*DeclStmt) stmtNode() {}
func (s *EmptyStmt) stmtNode() {} func (*EmptyStmt) stmtNode() {}
func (s *LabeledStmt) stmtNode() {} func (*LabeledStmt) stmtNode() {}
func (s *ExprStmt) stmtNode() {} func (*ExprStmt) stmtNode() {}
func (s *SendStmt) stmtNode() {} func (*SendStmt) stmtNode() {}
func (s *IncDecStmt) stmtNode() {} func (*IncDecStmt) stmtNode() {}
func (s *AssignStmt) stmtNode() {} func (*AssignStmt) stmtNode() {}
func (s *GoStmt) stmtNode() {} func (*GoStmt) stmtNode() {}
func (s *DeferStmt) stmtNode() {} func (*DeferStmt) stmtNode() {}
func (s *ReturnStmt) stmtNode() {} func (*ReturnStmt) stmtNode() {}
func (s *BranchStmt) stmtNode() {} func (*BranchStmt) stmtNode() {}
func (s *BlockStmt) stmtNode() {} func (*BlockStmt) stmtNode() {}
func (s *IfStmt) stmtNode() {} func (*IfStmt) stmtNode() {}
func (s *CaseClause) stmtNode() {} func (*CaseClause) stmtNode() {}
func (s *SwitchStmt) stmtNode() {} func (*SwitchStmt) stmtNode() {}
func (s *TypeSwitchStmt) stmtNode() {} func (*TypeSwitchStmt) stmtNode() {}
func (s *CommClause) stmtNode() {} func (*CommClause) stmtNode() {}
func (s *SelectStmt) stmtNode() {} func (*SelectStmt) stmtNode() {}
func (s *ForStmt) stmtNode() {} func (*ForStmt) stmtNode() {}
func (s *RangeStmt) stmtNode() {} func (*RangeStmt) stmtNode() {}
// ---------------------------------------------------------------------------- // ----------------------------------------------------------------------------
// Declarations // Declarations
@ -807,9 +807,9 @@ func (s *TypeSpec) End() token.Pos { return s.Type.End() }
// specNode() ensures that only spec nodes can be // specNode() ensures that only spec nodes can be
// assigned to a Spec. // assigned to a Spec.
// //
func (s *ImportSpec) specNode() {} func (*ImportSpec) specNode() {}
func (s *ValueSpec) specNode() {} func (*ValueSpec) specNode() {}
func (s *TypeSpec) specNode() {} func (*TypeSpec) specNode() {}
// A declaration is represented by one of the following declaration nodes. // A declaration is represented by one of the following declaration nodes.
// //
@ -875,9 +875,9 @@ func (d *FuncDecl) End() token.Pos {
// declNode() ensures that only declaration nodes can be // declNode() ensures that only declaration nodes can be
// assigned to a DeclNode. // assigned to a DeclNode.
// //
func (d *BadDecl) declNode() {} func (*BadDecl) declNode() {}
func (d *GenDecl) declNode() {} func (*GenDecl) declNode() {}
func (d *FuncDecl) declNode() {} func (*FuncDecl) declNode() {}
// ---------------------------------------------------------------------------- // ----------------------------------------------------------------------------
// Files and packages // Files and packages

View File

@ -24,7 +24,7 @@ func exportFilter(name string) bool {
// it returns false otherwise. // it returns false otherwise.
// //
func FileExports(src *File) bool { func FileExports(src *File) bool {
return FilterFile(src, exportFilter) return filterFile(src, exportFilter, true)
} }
// PackageExports trims the AST for a Go package in place such that // PackageExports trims the AST for a Go package in place such that
@ -35,7 +35,7 @@ func FileExports(src *File) bool {
// it returns false otherwise. // it returns false otherwise.
// //
func PackageExports(pkg *Package) bool { func PackageExports(pkg *Package) bool {
return FilterPackage(pkg, exportFilter) return filterPackage(pkg, exportFilter, true)
} }
// ---------------------------------------------------------------------------- // ----------------------------------------------------------------------------
@ -72,7 +72,7 @@ func fieldName(x Expr) *Ident {
return nil return nil
} }
func filterFieldList(fields *FieldList, filter Filter) (removedFields bool) { func filterFieldList(fields *FieldList, filter Filter, export bool) (removedFields bool) {
if fields == nil { if fields == nil {
return false return false
} }
@ -93,8 +93,8 @@ func filterFieldList(fields *FieldList, filter Filter) (removedFields bool) {
keepField = len(f.Names) > 0 keepField = len(f.Names) > 0
} }
if keepField { if keepField {
if filter == exportFilter { if export {
filterType(f.Type, filter) filterType(f.Type, filter, export)
} }
list[j] = f list[j] = f
j++ j++
@ -107,84 +107,84 @@ func filterFieldList(fields *FieldList, filter Filter) (removedFields bool) {
return return
} }
func filterParamList(fields *FieldList, filter Filter) bool { func filterParamList(fields *FieldList, filter Filter, export bool) bool {
if fields == nil { if fields == nil {
return false return false
} }
var b bool var b bool
for _, f := range fields.List { for _, f := range fields.List {
if filterType(f.Type, filter) { if filterType(f.Type, filter, export) {
b = true b = true
} }
} }
return b return b
} }
func filterType(typ Expr, f Filter) bool { func filterType(typ Expr, f Filter, export bool) bool {
switch t := typ.(type) { switch t := typ.(type) {
case *Ident: case *Ident:
return f(t.Name) return f(t.Name)
case *ParenExpr: case *ParenExpr:
return filterType(t.X, f) return filterType(t.X, f, export)
case *ArrayType: case *ArrayType:
return filterType(t.Elt, f) return filterType(t.Elt, f, export)
case *StructType: case *StructType:
if filterFieldList(t.Fields, f) { if filterFieldList(t.Fields, f, export) {
t.Incomplete = true t.Incomplete = true
} }
return len(t.Fields.List) > 0 return len(t.Fields.List) > 0
case *FuncType: case *FuncType:
b1 := filterParamList(t.Params, f) b1 := filterParamList(t.Params, f, export)
b2 := filterParamList(t.Results, f) b2 := filterParamList(t.Results, f, export)
return b1 || b2 return b1 || b2
case *InterfaceType: case *InterfaceType:
if filterFieldList(t.Methods, f) { if filterFieldList(t.Methods, f, export) {
t.Incomplete = true t.Incomplete = true
} }
return len(t.Methods.List) > 0 return len(t.Methods.List) > 0
case *MapType: case *MapType:
b1 := filterType(t.Key, f) b1 := filterType(t.Key, f, export)
b2 := filterType(t.Value, f) b2 := filterType(t.Value, f, export)
return b1 || b2 return b1 || b2
case *ChanType: case *ChanType:
return filterType(t.Value, f) return filterType(t.Value, f, export)
} }
return false return false
} }
func filterSpec(spec Spec, f Filter) bool { func filterSpec(spec Spec, f Filter, export bool) bool {
switch s := spec.(type) { switch s := spec.(type) {
case *ValueSpec: case *ValueSpec:
s.Names = filterIdentList(s.Names, f) s.Names = filterIdentList(s.Names, f)
if len(s.Names) > 0 { if len(s.Names) > 0 {
if f == exportFilter { if export {
filterType(s.Type, f) filterType(s.Type, f, export)
} }
return true return true
} }
case *TypeSpec: case *TypeSpec:
if f(s.Name.Name) { if f(s.Name.Name) {
if f == exportFilter { if export {
filterType(s.Type, f) filterType(s.Type, f, export)
} }
return true return true
} }
if f != exportFilter { if !export {
// For general filtering (not just exports), // For general filtering (not just exports),
// filter type even if name is not filtered // filter type even if name is not filtered
// out. // out.
// If the type contains filtered elements, // If the type contains filtered elements,
// keep the declaration. // keep the declaration.
return filterType(s.Type, f) return filterType(s.Type, f, export)
} }
} }
return false return false
} }
func filterSpecList(list []Spec, f Filter) []Spec { func filterSpecList(list []Spec, f Filter, export bool) []Spec {
j := 0 j := 0
for _, s := range list { for _, s := range list {
if filterSpec(s, f) { if filterSpec(s, f, export) {
list[j] = s list[j] = s
j++ j++
} }
@ -200,9 +200,13 @@ func filterSpecList(list []Spec, f Filter) []Spec {
// filtering; it returns false otherwise. // filtering; it returns false otherwise.
// //
func FilterDecl(decl Decl, f Filter) bool { func FilterDecl(decl Decl, f Filter) bool {
return filterDecl(decl, f, false)
}
func filterDecl(decl Decl, f Filter, export bool) bool {
switch d := decl.(type) { switch d := decl.(type) {
case *GenDecl: case *GenDecl:
d.Specs = filterSpecList(d.Specs, f) d.Specs = filterSpecList(d.Specs, f, export)
return len(d.Specs) > 0 return len(d.Specs) > 0
case *FuncDecl: case *FuncDecl:
return f(d.Name.Name) return f(d.Name.Name)
@ -221,9 +225,13 @@ func FilterDecl(decl Decl, f Filter) bool {
// left after filtering; it returns false otherwise. // left after filtering; it returns false otherwise.
// //
func FilterFile(src *File, f Filter) bool { func FilterFile(src *File, f Filter) bool {
return filterFile(src, f, false)
}
func filterFile(src *File, f Filter, export bool) bool {
j := 0 j := 0
for _, d := range src.Decls { for _, d := range src.Decls {
if FilterDecl(d, f) { if filterDecl(d, f, export) {
src.Decls[j] = d src.Decls[j] = d
j++ j++
} }
@ -244,9 +252,13 @@ func FilterFile(src *File, f Filter) bool {
// left after filtering; it returns false otherwise. // left after filtering; it returns false otherwise.
// //
func FilterPackage(pkg *Package, f Filter) bool { func FilterPackage(pkg *Package, f Filter) bool {
return filterPackage(pkg, f, false)
}
func filterPackage(pkg *Package, f Filter, export bool) bool {
hasDecls := false hasDecls := false
for _, src := range pkg.Files { for _, src := range pkg.Files {
if FilterFile(src, f) { if filterFile(src, f, export) {
hasDecls = true hasDecls = true
} }
} }

View File

@ -40,6 +40,7 @@ var buildPkgs = []struct {
GoFiles: []string{"main.go"}, GoFiles: []string{"main.go"},
Package: "main", Package: "main",
Imports: []string{"go/build/pkgtest"}, Imports: []string{"go/build/pkgtest"},
TestImports: []string{},
}, },
}, },
{ {
@ -48,6 +49,7 @@ var buildPkgs = []struct {
CgoFiles: []string{"cgotest.go"}, CgoFiles: []string{"cgotest.go"},
CFiles: []string{"cgotest.c"}, CFiles: []string{"cgotest.c"},
Imports: []string{"C", "unsafe"}, Imports: []string{"C", "unsafe"},
TestImports: []string{},
Package: "cgotest", Package: "cgotest",
}, },
}, },

View File

@ -13,6 +13,8 @@ import (
"io" "io"
"os" "os"
"path/filepath" "path/filepath"
"strconv"
"strings"
"text/tabwriter" "text/tabwriter"
) )
@ -244,6 +246,8 @@ func (p *printer) writeItem(pos token.Position, data string) {
p.last = p.pos p.last = p.pos
} }
const linePrefix = "//line "
// writeCommentPrefix writes the whitespace before a comment. // writeCommentPrefix writes the whitespace before a comment.
// If there is any pending whitespace, it consumes as much of // If there is any pending whitespace, it consumes as much of
// it as is likely to help position the comment nicely. // it as is likely to help position the comment nicely.
@ -252,7 +256,7 @@ func (p *printer) writeItem(pos token.Position, data string) {
// a group of comments (or nil), and isKeyword indicates if the // a group of comments (or nil), and isKeyword indicates if the
// next item is a keyword. // next item is a keyword.
// //
func (p *printer) writeCommentPrefix(pos, next token.Position, prev *ast.Comment, isKeyword bool) { func (p *printer) writeCommentPrefix(pos, next token.Position, prev, comment *ast.Comment, isKeyword bool) {
if p.written == 0 { if p.written == 0 {
// the comment is the first item to be printed - don't write any whitespace // the comment is the first item to be printed - don't write any whitespace
return return
@ -337,6 +341,13 @@ func (p *printer) writeCommentPrefix(pos, next token.Position, prev *ast.Comment
} }
p.writeWhitespace(j) p.writeWhitespace(j)
} }
// turn off indent if we're about to print a line directive.
indent := p.indent
if strings.HasPrefix(comment.Text, linePrefix) {
p.indent = 0
}
// use formfeeds to break columns before a comment; // use formfeeds to break columns before a comment;
// this is analogous to using formfeeds to separate // this is analogous to using formfeeds to separate
// individual lines of /*-style comments - but make // individual lines of /*-style comments - but make
@ -347,6 +358,7 @@ func (p *printer) writeCommentPrefix(pos, next token.Position, prev *ast.Comment
n = 1 n = 1
} }
p.writeNewlines(n, true) p.writeNewlines(n, true)
p.indent = indent
} }
} }
@ -526,6 +538,26 @@ func stripCommonPrefix(lines [][]byte) {
func (p *printer) writeComment(comment *ast.Comment) { func (p *printer) writeComment(comment *ast.Comment) {
text := comment.Text text := comment.Text
if strings.HasPrefix(text, linePrefix) {
pos := strings.TrimSpace(text[len(linePrefix):])
i := strings.LastIndex(pos, ":")
if i >= 0 {
// The line directive we are about to print changed
// the Filename and Line number used by go/token
// as it was reading the input originally.
// In order to match the original input, we have to
// update our own idea of the file and line number
// accordingly, after printing the directive.
file := pos[:i]
line, _ := strconv.Atoi(string(pos[i+1:]))
defer func() {
p.pos.Filename = string(file)
p.pos.Line = line
p.pos.Column = 1
}()
}
}
// shortcut common case of //-style comments // shortcut common case of //-style comments
if text[1] == '/' { if text[1] == '/' {
p.writeItem(p.fset.Position(comment.Pos()), p.escape(text)) p.writeItem(p.fset.Position(comment.Pos()), p.escape(text))
@ -599,7 +631,7 @@ func (p *printer) intersperseComments(next token.Position, tok token.Token) (dro
var last *ast.Comment var last *ast.Comment
for ; p.commentBefore(next); p.cindex++ { for ; p.commentBefore(next); p.cindex++ {
for _, c := range p.comments[p.cindex].List { for _, c := range p.comments[p.cindex].List {
p.writeCommentPrefix(p.fset.Position(c.Pos()), next, last, tok.IsKeyword()) p.writeCommentPrefix(p.fset.Position(c.Pos()), next, last, c, tok.IsKeyword())
p.writeComment(c) p.writeComment(c)
last = c last = c
} }

View File

@ -37,7 +37,7 @@ lower-cased, and attributes are collected into a []Attribute. For example:
for { for {
if z.Next() == html.ErrorToken { if z.Next() == html.ErrorToken {
// Returning io.EOF indicates success. // Returning io.EOF indicates success.
return z.Error() return z.Err()
} }
emitToken(z.Token()) emitToken(z.Token())
} }
@ -51,7 +51,7 @@ call to Next. For example, to extract an HTML page's anchor text:
tt := z.Next() tt := z.Next()
switch tt { switch tt {
case ErrorToken: case ErrorToken:
return z.Error() return z.Err()
case TextToken: case TextToken:
if depth > 0 { if depth > 0 {
// emitBytes should copy the []byte it receives, // emitBytes should copy the []byte it receives,

File diff suppressed because it is too large Load Diff

View File

@ -133,8 +133,8 @@ func TestParser(t *testing.T) {
n int n int
}{ }{
// TODO(nigeltao): Process all the test cases from all the .dat files. // TODO(nigeltao): Process all the test cases from all the .dat files.
{"tests1.dat", 92}, {"tests1.dat", -1},
{"tests2.dat", 0}, {"tests2.dat", 43},
{"tests3.dat", 0}, {"tests3.dat", 0},
} }
for _, tf := range testFiles { for _, tf := range testFiles {
@ -213,4 +213,8 @@ var renderTestBlacklist = map[string]bool{
// More cases of <a> being reparented: // More cases of <a> being reparented:
`<a href="blah">aba<table><a href="foo">br<tr><td></td></tr>x</table>aoe`: true, `<a href="blah">aba<table><a href="foo">br<tr><td></td></tr>x</table>aoe`: true,
`<a><table><a></table><p><a><div><a>`: true, `<a><table><a></table><p><a><div><a>`: true,
`<a><table><td><a><table></table><a></tr><a></table><a>`: true,
// A <plaintext> element is reparented, putting it before a table.
// A <plaintext> element can't have anything after it in HTML.
`<table><plaintext><td>`: true,
} }

View File

@ -52,7 +52,19 @@ func Render(w io.Writer, n *Node) error {
return buf.Flush() return buf.Flush()
} }
// plaintextAbort is returned from render1 when a <plaintext> element
// has been rendered. No more end tags should be rendered after that.
var plaintextAbort = errors.New("html: internal error (plaintext abort)")
func render(w writer, n *Node) error { func render(w writer, n *Node) error {
err := render1(w, n)
if err == plaintextAbort {
err = nil
}
return err
}
func render1(w writer, n *Node) error {
// Render non-element nodes; these are the easy cases. // Render non-element nodes; these are the easy cases.
switch n.Type { switch n.Type {
case ErrorNode: case ErrorNode:
@ -61,7 +73,7 @@ func render(w writer, n *Node) error {
return escape(w, n.Data) return escape(w, n.Data)
case DocumentNode: case DocumentNode:
for _, c := range n.Child { for _, c := range n.Child {
if err := render(w, c); err != nil { if err := render1(w, c); err != nil {
return err return err
} }
} }
@ -128,7 +140,7 @@ func render(w writer, n *Node) error {
// Render any child nodes. // Render any child nodes.
switch n.Data { switch n.Data {
case "noembed", "noframes", "noscript", "script", "style": case "noembed", "noframes", "noscript", "plaintext", "script", "style":
for _, c := range n.Child { for _, c := range n.Child {
if c.Type != TextNode { if c.Type != TextNode {
return fmt.Errorf("html: raw text element <%s> has non-text child node", n.Data) return fmt.Errorf("html: raw text element <%s> has non-text child node", n.Data)
@ -137,18 +149,23 @@ func render(w writer, n *Node) error {
return err return err
} }
} }
if n.Data == "plaintext" {
// Don't render anything else. <plaintext> must be the
// last element in the file, with no closing tag.
return plaintextAbort
}
case "textarea", "title": case "textarea", "title":
for _, c := range n.Child { for _, c := range n.Child {
if c.Type != TextNode { if c.Type != TextNode {
return fmt.Errorf("html: RCDATA element <%s> has non-text child node", n.Data) return fmt.Errorf("html: RCDATA element <%s> has non-text child node", n.Data)
} }
if err := render(w, c); err != nil { if err := render1(w, c); err != nil {
return err return err
} }
} }
default: default:
for _, c := range n.Child { for _, c := range n.Child {
if err := render(w, c); err != nil { if err := render1(w, c); err != nil {
return err return err
} }
} }

View File

@ -6,6 +6,7 @@ package template
import ( import (
"fmt" "fmt"
"reflect"
) )
// Strings of content from a trusted source. // Strings of content from a trusted source.
@ -70,10 +71,25 @@ const (
contentTypeUnsafe contentTypeUnsafe
) )
// indirect returns the value, after dereferencing as many times
// as necessary to reach the base type (or nil).
func indirect(a interface{}) interface{} {
if t := reflect.TypeOf(a); t.Kind() != reflect.Ptr {
// Avoid creating a reflect.Value if it's not a pointer.
return a
}
v := reflect.ValueOf(a)
for v.Kind() == reflect.Ptr && !v.IsNil() {
v = v.Elem()
}
return v.Interface()
}
// stringify converts its arguments to a string and the type of the content. // stringify converts its arguments to a string and the type of the content.
// All pointers are dereferenced, as in the text/template package.
func stringify(args ...interface{}) (string, contentType) { func stringify(args ...interface{}) (string, contentType) {
if len(args) == 1 { if len(args) == 1 {
switch s := args[0].(type) { switch s := indirect(args[0]).(type) {
case string: case string:
return s, contentTypePlain return s, contentTypePlain
case CSS: case CSS:
@ -90,5 +106,8 @@ func stringify(args ...interface{}) (string, contentType) {
return string(s), contentTypeURL return string(s), contentTypeURL
} }
} }
for i, arg := range args {
args[i] = indirect(arg)
}
return fmt.Sprint(args...), contentTypePlain return fmt.Sprint(args...), contentTypePlain
} }

View File

@ -28,7 +28,7 @@ func (x *goodMarshaler) MarshalJSON() ([]byte, error) {
} }
func TestEscape(t *testing.T) { func TestEscape(t *testing.T) {
var data = struct { data := struct {
F, T bool F, T bool
C, G, H string C, G, H string
A, E []string A, E []string
@ -50,6 +50,7 @@ func TestEscape(t *testing.T) {
Z: nil, Z: nil,
W: HTML(`&iexcl;<b class="foo">Hello</b>, <textarea>O'World</textarea>!`), W: HTML(`&iexcl;<b class="foo">Hello</b>, <textarea>O'World</textarea>!`),
} }
pdata := &data
tests := []struct { tests := []struct {
name string name string
@ -668,6 +669,15 @@ func TestEscape(t *testing.T) {
t.Errorf("%s: escaped output: want\n\t%q\ngot\n\t%q", test.name, w, g) t.Errorf("%s: escaped output: want\n\t%q\ngot\n\t%q", test.name, w, g)
continue continue
} }
b.Reset()
if err := tmpl.Execute(b, pdata); err != nil {
t.Errorf("%s: template execution failed for pointer: %s", test.name, err)
continue
}
if w, g := test.output, b.String(); w != g {
t.Errorf("%s: escaped output for pointer: want\n\t%q\ngot\n\t%q", test.name, w, g)
continue
}
} }
} }
@ -1605,6 +1615,29 @@ func TestRedundantFuncs(t *testing.T) {
} }
} }
func TestIndirectPrint(t *testing.T) {
a := 3
ap := &a
b := "hello"
bp := &b
bpp := &bp
tmpl := Must(New("t").Parse(`{{.}}`))
var buf bytes.Buffer
err := tmpl.Execute(&buf, ap)
if err != nil {
t.Errorf("Unexpected error: %s", err)
} else if buf.String() != "3" {
t.Errorf(`Expected "3"; got %q`, buf.String())
}
buf.Reset()
err = tmpl.Execute(&buf, bpp)
if err != nil {
t.Errorf("Unexpected error: %s", err)
} else if buf.String() != "hello" {
t.Errorf(`Expected "hello"; got %q`, buf.String())
}
}
func BenchmarkEscapedExecute(b *testing.B) { func BenchmarkEscapedExecute(b *testing.B) {
tmpl := Must(New("t").Parse(`<a onclick="alert('{{.}}')">{{.}}</a>`)) tmpl := Must(New("t").Parse(`<a onclick="alert('{{.}}')">{{.}}</a>`))
var buf bytes.Buffer var buf bytes.Buffer

View File

@ -8,6 +8,7 @@ import (
"bytes" "bytes"
"encoding/json" "encoding/json"
"fmt" "fmt"
"reflect"
"strings" "strings"
"unicode/utf8" "unicode/utf8"
) )
@ -117,12 +118,24 @@ var regexpPrecederKeywords = map[string]bool{
"void": true, "void": true,
} }
var jsonMarshalType = reflect.TypeOf((*json.Marshaler)(nil)).Elem()
// indirectToJSONMarshaler returns the value, after dereferencing as many times
// as necessary to reach the base type (or nil) or an implementation of json.Marshal.
func indirectToJSONMarshaler(a interface{}) interface{} {
v := reflect.ValueOf(a)
for !v.Type().Implements(jsonMarshalType) && v.Kind() == reflect.Ptr && !v.IsNil() {
v = v.Elem()
}
return v.Interface()
}
// jsValEscaper escapes its inputs to a JS Expression (section 11.14) that has // jsValEscaper escapes its inputs to a JS Expression (section 11.14) that has
// nether side-effects nor free variables outside (NaN, Infinity). // neither side-effects nor free variables outside (NaN, Infinity).
func jsValEscaper(args ...interface{}) string { func jsValEscaper(args ...interface{}) string {
var a interface{} var a interface{}
if len(args) == 1 { if len(args) == 1 {
a = args[0] a = indirectToJSONMarshaler(args[0])
switch t := a.(type) { switch t := a.(type) {
case JS: case JS:
return string(t) return string(t)
@ -135,6 +148,9 @@ func jsValEscaper(args ...interface{}) string {
a = t.String() a = t.String()
} }
} else { } else {
for i, arg := range args {
args[i] = indirectToJSONMarshaler(arg)
}
a = fmt.Sprint(args...) a = fmt.Sprint(args...)
} }
// TODO: detect cycles before calling Marshal which loops infinitely on // TODO: detect cycles before calling Marshal which loops infinitely on

View File

@ -401,14 +401,14 @@ func (z *Tokenizer) readStartTag() TokenType {
break break
} }
} }
// Any "<noembed>", "<noframes>", "<noscript>", "<script>", "<style>", // Any "<noembed>", "<noframes>", "<noscript>", "<plaintext", "<script>", "<style>",
// "<textarea>" or "<title>" tag flags the tokenizer's next token as raw. // "<textarea>" or "<title>" tag flags the tokenizer's next token as raw.
// The tag name lengths of these special cases ranges in [5, 8]. // The tag name lengths of these special cases ranges in [5, 9].
if x := z.data.end - z.data.start; 5 <= x && x <= 8 { if x := z.data.end - z.data.start; 5 <= x && x <= 9 {
switch z.buf[z.data.start] { switch z.buf[z.data.start] {
case 'n', 's', 't', 'N', 'S', 'T': case 'n', 'p', 's', 't', 'N', 'P', 'S', 'T':
switch s := strings.ToLower(string(z.buf[z.data.start:z.data.end])); s { switch s := strings.ToLower(string(z.buf[z.data.start:z.data.end])); s {
case "noembed", "noframes", "noscript", "script", "style", "textarea", "title": case "noembed", "noframes", "noscript", "plaintext", "script", "style", "textarea", "title":
z.rawTag = s z.rawTag = s
} }
} }
@ -551,10 +551,20 @@ func (z *Tokenizer) Next() TokenType {
z.data.start = z.raw.end z.data.start = z.raw.end
z.data.end = z.raw.end z.data.end = z.raw.end
if z.rawTag != "" { if z.rawTag != "" {
if z.rawTag == "plaintext" {
// Read everything up to EOF.
for z.err == nil {
z.readByte()
}
z.textIsRaw = true
} else {
z.readRawOrRCDATA() z.readRawOrRCDATA()
}
if z.data.end > z.data.start {
z.tt = TextToken z.tt = TextToken
return z.tt return z.tt
} }
}
z.textIsRaw = false z.textIsRaw = false
loop: loop:

View File

@ -4,10 +4,7 @@
package tiff package tiff
import ( import "io"
"io"
"os"
)
// buffer buffers an io.Reader to satisfy io.ReaderAt. // buffer buffers an io.Reader to satisfy io.ReaderAt.
type buffer struct { type buffer struct {
@ -19,7 +16,7 @@ func (b *buffer) ReadAt(p []byte, off int64) (int, error) {
o := int(off) o := int(off)
end := o + len(p) end := o + len(p)
if int64(end) != off+int64(len(p)) { if int64(end) != off+int64(len(p)) {
return 0, os.EINVAL return 0, io.ErrUnexpectedEOF
} }
m := len(b.buf) m := len(b.buf)

View File

@ -8,6 +8,7 @@ import (
"os" "os"
"path/filepath" "path/filepath"
"strconv" "strconv"
"time"
) )
// Random number state, accessed without lock; racy but harmless. // Random number state, accessed without lock; racy but harmless.
@ -17,8 +18,7 @@ import (
var rand uint32 var rand uint32
func reseed() uint32 { func reseed() uint32 {
sec, nsec, _ := os.Time() return uint32(time.Nanoseconds() + int64(os.Getpid()))
return uint32(sec*1e9 + nsec + int64(os.Getpid()))
} }
func nextSuffix() string { func nextSuffix() string {

View File

@ -8,6 +8,7 @@
package syslog package syslog
import ( import (
"errors"
"fmt" "fmt"
"log" "log"
"net" "net"
@ -75,7 +76,7 @@ func Dial(network, raddr string, priority Priority, prefix string) (w *Writer, e
// Write sends a log message to the syslog daemon. // Write sends a log message to the syslog daemon.
func (w *Writer) Write(b []byte) (int, error) { func (w *Writer) Write(b []byte) (int, error) {
if w.priority > LOG_DEBUG || w.priority < LOG_EMERG { if w.priority > LOG_DEBUG || w.priority < LOG_EMERG {
return 0, os.EINVAL return 0, errors.New("log/syslog: invalid priority")
} }
return w.conn.writeBytes(w.priority, w.prefix, b) return w.conn.writeBytes(w.priority, w.prefix, b)
} }

View File

@ -176,7 +176,7 @@ func (z *Int) Quo(x, y *Int) *Int {
// If y == 0, a division-by-zero run-time panic occurs. // If y == 0, a division-by-zero run-time panic occurs.
// Rem implements truncated modulus (like Go); see QuoRem for more details. // Rem implements truncated modulus (like Go); see QuoRem for more details.
func (z *Int) Rem(x, y *Int) *Int { func (z *Int) Rem(x, y *Int) *Int {
_, z.abs = nat{}.div(z.abs, x.abs, y.abs) _, z.abs = nat(nil).div(z.abs, x.abs, y.abs)
z.neg = len(z.abs) > 0 && x.neg // 0 has no sign z.neg = len(z.abs) > 0 && x.neg // 0 has no sign
return z return z
} }
@ -678,14 +678,14 @@ func (z *Int) Bit(i int) uint {
panic("negative bit index") panic("negative bit index")
} }
if z.neg { if z.neg {
t := nat{}.sub(z.abs, natOne) t := nat(nil).sub(z.abs, natOne)
return t.bit(uint(i)) ^ 1 return t.bit(uint(i)) ^ 1
} }
return z.abs.bit(uint(i)) return z.abs.bit(uint(i))
} }
// SetBit sets the i'th bit of z to bit and returns z. // SetBit sets z to x, with x's i'th bit set to b (0 or 1).
// That is, if bit is 1 SetBit sets z = x | (1 << i); // That is, if bit is 1 SetBit sets z = x | (1 << i);
// if bit is 0 it sets z = x &^ (1 << i). If bit is not 0 or 1, // if bit is 0 it sets z = x &^ (1 << i). If bit is not 0 or 1,
// SetBit will panic. // SetBit will panic.
@ -710,8 +710,8 @@ func (z *Int) And(x, y *Int) *Int {
if x.neg == y.neg { if x.neg == y.neg {
if x.neg { if x.neg {
// (-x) & (-y) == ^(x-1) & ^(y-1) == ^((x-1) | (y-1)) == -(((x-1) | (y-1)) + 1) // (-x) & (-y) == ^(x-1) & ^(y-1) == ^((x-1) | (y-1)) == -(((x-1) | (y-1)) + 1)
x1 := nat{}.sub(x.abs, natOne) x1 := nat(nil).sub(x.abs, natOne)
y1 := nat{}.sub(y.abs, natOne) y1 := nat(nil).sub(y.abs, natOne)
z.abs = z.abs.add(z.abs.or(x1, y1), natOne) z.abs = z.abs.add(z.abs.or(x1, y1), natOne)
z.neg = true // z cannot be zero if x and y are negative z.neg = true // z cannot be zero if x and y are negative
return z return z
@ -729,7 +729,7 @@ func (z *Int) And(x, y *Int) *Int {
} }
// x & (-y) == x & ^(y-1) == x &^ (y-1) // x & (-y) == x & ^(y-1) == x &^ (y-1)
y1 := nat{}.sub(y.abs, natOne) y1 := nat(nil).sub(y.abs, natOne)
z.abs = z.abs.andNot(x.abs, y1) z.abs = z.abs.andNot(x.abs, y1)
z.neg = false z.neg = false
return z return z
@ -740,8 +740,8 @@ func (z *Int) AndNot(x, y *Int) *Int {
if x.neg == y.neg { if x.neg == y.neg {
if x.neg { if x.neg {
// (-x) &^ (-y) == ^(x-1) &^ ^(y-1) == ^(x-1) & (y-1) == (y-1) &^ (x-1) // (-x) &^ (-y) == ^(x-1) &^ ^(y-1) == ^(x-1) & (y-1) == (y-1) &^ (x-1)
x1 := nat{}.sub(x.abs, natOne) x1 := nat(nil).sub(x.abs, natOne)
y1 := nat{}.sub(y.abs, natOne) y1 := nat(nil).sub(y.abs, natOne)
z.abs = z.abs.andNot(y1, x1) z.abs = z.abs.andNot(y1, x1)
z.neg = false z.neg = false
return z return z
@ -755,14 +755,14 @@ func (z *Int) AndNot(x, y *Int) *Int {
if x.neg { if x.neg {
// (-x) &^ y == ^(x-1) &^ y == ^(x-1) & ^y == ^((x-1) | y) == -(((x-1) | y) + 1) // (-x) &^ y == ^(x-1) &^ y == ^(x-1) & ^y == ^((x-1) | y) == -(((x-1) | y) + 1)
x1 := nat{}.sub(x.abs, natOne) x1 := nat(nil).sub(x.abs, natOne)
z.abs = z.abs.add(z.abs.or(x1, y.abs), natOne) z.abs = z.abs.add(z.abs.or(x1, y.abs), natOne)
z.neg = true // z cannot be zero if x is negative and y is positive z.neg = true // z cannot be zero if x is negative and y is positive
return z return z
} }
// x &^ (-y) == x &^ ^(y-1) == x & (y-1) // x &^ (-y) == x &^ ^(y-1) == x & (y-1)
y1 := nat{}.add(y.abs, natOne) y1 := nat(nil).add(y.abs, natOne)
z.abs = z.abs.and(x.abs, y1) z.abs = z.abs.and(x.abs, y1)
z.neg = false z.neg = false
return z return z
@ -773,8 +773,8 @@ func (z *Int) Or(x, y *Int) *Int {
if x.neg == y.neg { if x.neg == y.neg {
if x.neg { if x.neg {
// (-x) | (-y) == ^(x-1) | ^(y-1) == ^((x-1) & (y-1)) == -(((x-1) & (y-1)) + 1) // (-x) | (-y) == ^(x-1) | ^(y-1) == ^((x-1) & (y-1)) == -(((x-1) & (y-1)) + 1)
x1 := nat{}.sub(x.abs, natOne) x1 := nat(nil).sub(x.abs, natOne)
y1 := nat{}.sub(y.abs, natOne) y1 := nat(nil).sub(y.abs, natOne)
z.abs = z.abs.add(z.abs.and(x1, y1), natOne) z.abs = z.abs.add(z.abs.and(x1, y1), natOne)
z.neg = true // z cannot be zero if x and y are negative z.neg = true // z cannot be zero if x and y are negative
return z return z
@ -792,7 +792,7 @@ func (z *Int) Or(x, y *Int) *Int {
} }
// x | (-y) == x | ^(y-1) == ^((y-1) &^ x) == -(^((y-1) &^ x) + 1) // x | (-y) == x | ^(y-1) == ^((y-1) &^ x) == -(^((y-1) &^ x) + 1)
y1 := nat{}.sub(y.abs, natOne) y1 := nat(nil).sub(y.abs, natOne)
z.abs = z.abs.add(z.abs.andNot(y1, x.abs), natOne) z.abs = z.abs.add(z.abs.andNot(y1, x.abs), natOne)
z.neg = true // z cannot be zero if one of x or y is negative z.neg = true // z cannot be zero if one of x or y is negative
return z return z
@ -803,8 +803,8 @@ func (z *Int) Xor(x, y *Int) *Int {
if x.neg == y.neg { if x.neg == y.neg {
if x.neg { if x.neg {
// (-x) ^ (-y) == ^(x-1) ^ ^(y-1) == (x-1) ^ (y-1) // (-x) ^ (-y) == ^(x-1) ^ ^(y-1) == (x-1) ^ (y-1)
x1 := nat{}.sub(x.abs, natOne) x1 := nat(nil).sub(x.abs, natOne)
y1 := nat{}.sub(y.abs, natOne) y1 := nat(nil).sub(y.abs, natOne)
z.abs = z.abs.xor(x1, y1) z.abs = z.abs.xor(x1, y1)
z.neg = false z.neg = false
return z return z
@ -822,7 +822,7 @@ func (z *Int) Xor(x, y *Int) *Int {
} }
// x ^ (-y) == x ^ ^(y-1) == ^(x ^ (y-1)) == -((x ^ (y-1)) + 1) // x ^ (-y) == x ^ ^(y-1) == ^(x ^ (y-1)) == -((x ^ (y-1)) + 1)
y1 := nat{}.sub(y.abs, natOne) y1 := nat(nil).sub(y.abs, natOne)
z.abs = z.abs.add(z.abs.xor(x.abs, y1), natOne) z.abs = z.abs.add(z.abs.xor(x.abs, y1), natOne)
z.neg = true // z cannot be zero if only one of x or y is negative z.neg = true // z cannot be zero if only one of x or y is negative
return z return z

View File

@ -447,10 +447,10 @@ func (z nat) mulRange(a, b uint64) nat {
case a == b: case a == b:
return z.setUint64(a) return z.setUint64(a)
case a+1 == b: case a+1 == b:
return z.mul(nat{}.setUint64(a), nat{}.setUint64(b)) return z.mul(nat(nil).setUint64(a), nat(nil).setUint64(b))
} }
m := (a + b) / 2 m := (a + b) / 2
return z.mul(nat{}.mulRange(a, m), nat{}.mulRange(m+1, b)) return z.mul(nat(nil).mulRange(a, m), nat(nil).mulRange(m+1, b))
} }
// q = (x-r)/y, with 0 <= r < y // q = (x-r)/y, with 0 <= r < y
@ -785,7 +785,7 @@ func (x nat) string(charset string) string {
} }
// preserve x, create local copy for use in repeated divisions // preserve x, create local copy for use in repeated divisions
q := nat{}.set(x) q := nat(nil).set(x)
var r Word var r Word
// convert // convert
@ -1191,11 +1191,11 @@ func (n nat) probablyPrime(reps int) bool {
return false return false
} }
nm1 := nat{}.sub(n, natOne) nm1 := nat(nil).sub(n, natOne)
// 1<<k * q = nm1; // 1<<k * q = nm1;
q, k := nm1.powersOfTwoDecompose() q, k := nm1.powersOfTwoDecompose()
nm3 := nat{}.sub(nm1, natTwo) nm3 := nat(nil).sub(nm1, natTwo)
rand := rand.New(rand.NewSource(int64(n[0]))) rand := rand.New(rand.NewSource(int64(n[0])))
var x, y, quotient nat var x, y, quotient nat

View File

@ -16,9 +16,9 @@ var cmpTests = []struct {
r int r int
}{ }{
{nil, nil, 0}, {nil, nil, 0},
{nil, nat{}, 0}, {nil, nat(nil), 0},
{nat{}, nil, 0}, {nat(nil), nil, 0},
{nat{}, nat{}, 0}, {nat(nil), nat(nil), 0},
{nat{0}, nat{0}, 0}, {nat{0}, nat{0}, 0},
{nat{0}, nat{1}, -1}, {nat{0}, nat{1}, -1},
{nat{1}, nat{0}, 1}, {nat{1}, nat{0}, 1},
@ -67,7 +67,7 @@ var prodNN = []argNN{
func TestSet(t *testing.T) { func TestSet(t *testing.T) {
for _, a := range sumNN { for _, a := range sumNN {
z := nat{}.set(a.z) z := nat(nil).set(a.z)
if z.cmp(a.z) != 0 { if z.cmp(a.z) != 0 {
t.Errorf("got z = %v; want %v", z, a.z) t.Errorf("got z = %v; want %v", z, a.z)
} }
@ -129,7 +129,7 @@ var mulRangesN = []struct {
func TestMulRangeN(t *testing.T) { func TestMulRangeN(t *testing.T) {
for i, r := range mulRangesN { for i, r := range mulRangesN {
prod := nat{}.mulRange(r.a, r.b).decimalString() prod := nat(nil).mulRange(r.a, r.b).decimalString()
if prod != r.prod { if prod != r.prod {
t.Errorf("#%d: got %s; want %s", i, prod, r.prod) t.Errorf("#%d: got %s; want %s", i, prod, r.prod)
} }
@ -175,7 +175,7 @@ func toString(x nat, charset string) string {
s := make([]byte, i) s := make([]byte, i)
// don't destroy x // don't destroy x
q := nat{}.set(x) q := nat(nil).set(x)
// convert // convert
for len(q) > 0 { for len(q) > 0 {
@ -212,7 +212,7 @@ func TestString(t *testing.T) {
t.Errorf("string%+v\n\tgot s = %s; want %s", a, s, a.s) t.Errorf("string%+v\n\tgot s = %s; want %s", a, s, a.s)
} }
x, b, err := nat{}.scan(strings.NewReader(a.s), len(a.c)) x, b, err := nat(nil).scan(strings.NewReader(a.s), len(a.c))
if x.cmp(a.x) != 0 { if x.cmp(a.x) != 0 {
t.Errorf("scan%+v\n\tgot z = %v; want %v", a, x, a.x) t.Errorf("scan%+v\n\tgot z = %v; want %v", a, x, a.x)
} }
@ -271,7 +271,7 @@ var natScanTests = []struct {
func TestScanBase(t *testing.T) { func TestScanBase(t *testing.T) {
for _, a := range natScanTests { for _, a := range natScanTests {
r := strings.NewReader(a.s) r := strings.NewReader(a.s)
x, b, err := nat{}.scan(r, a.base) x, b, err := nat(nil).scan(r, a.base)
if err == nil && !a.ok { if err == nil && !a.ok {
t.Errorf("scan%+v\n\texpected error", a) t.Errorf("scan%+v\n\texpected error", a)
} }
@ -651,17 +651,17 @@ var expNNTests = []struct {
func TestExpNN(t *testing.T) { func TestExpNN(t *testing.T) {
for i, test := range expNNTests { for i, test := range expNNTests {
x, _, _ := nat{}.scan(strings.NewReader(test.x), 0) x, _, _ := nat(nil).scan(strings.NewReader(test.x), 0)
y, _, _ := nat{}.scan(strings.NewReader(test.y), 0) y, _, _ := nat(nil).scan(strings.NewReader(test.y), 0)
out, _, _ := nat{}.scan(strings.NewReader(test.out), 0) out, _, _ := nat(nil).scan(strings.NewReader(test.out), 0)
var m nat var m nat
if len(test.m) > 0 { if len(test.m) > 0 {
m, _, _ = nat{}.scan(strings.NewReader(test.m), 0) m, _, _ = nat(nil).scan(strings.NewReader(test.m), 0)
} }
z := nat{}.expNN(x, y, m) z := nat(nil).expNN(x, y, m)
if z.cmp(out) != 0 { if z.cmp(out) != 0 {
t.Errorf("#%d got %v want %v", i, z, out) t.Errorf("#%d got %v want %v", i, z, out)
} }

View File

@ -33,7 +33,7 @@ func (z *Rat) SetFrac(a, b *Int) *Rat {
panic("division by zero") panic("division by zero")
} }
if &z.a == b || alias(z.a.abs, babs) { if &z.a == b || alias(z.a.abs, babs) {
babs = nat{}.set(babs) // make a copy babs = nat(nil).set(babs) // make a copy
} }
z.a.abs = z.a.abs.set(a.abs) z.a.abs = z.a.abs.set(a.abs)
z.b = z.b.set(babs) z.b = z.b.set(babs)
@ -315,7 +315,7 @@ func (z *Rat) SetString(s string) (*Rat, bool) {
if _, ok := z.a.SetString(s, 10); !ok { if _, ok := z.a.SetString(s, 10); !ok {
return nil, false return nil, false
} }
powTen := nat{}.expNN(natTen, exp.abs, nil) powTen := nat(nil).expNN(natTen, exp.abs, nil)
if exp.neg { if exp.neg {
z.b = powTen z.b = powTen
z.norm() z.norm()
@ -357,23 +357,23 @@ func (z *Rat) FloatString(prec int) string {
} }
// z.b != 0 // z.b != 0
q, r := nat{}.div(nat{}, z.a.abs, z.b) q, r := nat(nil).div(nat(nil), z.a.abs, z.b)
p := natOne p := natOne
if prec > 0 { if prec > 0 {
p = nat{}.expNN(natTen, nat{}.setUint64(uint64(prec)), nil) p = nat(nil).expNN(natTen, nat(nil).setUint64(uint64(prec)), nil)
} }
r = r.mul(r, p) r = r.mul(r, p)
r, r2 := r.div(nat{}, r, z.b) r, r2 := r.div(nat(nil), r, z.b)
// see if we need to round up // see if we need to round up
r2 = r2.add(r2, r2) r2 = r2.add(r2, r2)
if z.b.cmp(r2) <= 0 { if z.b.cmp(r2) <= 0 {
r = r.add(r, natOne) r = r.add(r, natOne)
if r.cmp(p) >= 0 { if r.cmp(p) >= 0 {
q = nat{}.add(q, natOne) q = nat(nil).add(q, natOne)
r = nat{}.sub(r, p) r = nat(nil).sub(r, p)
} }
} }

View File

@ -63,7 +63,7 @@ package math
// Stephen L. Moshier // Stephen L. Moshier
// moshier@na-net.ornl.gov // moshier@na-net.ornl.gov
var _P = [...]float64{ var _gamP = [...]float64{
1.60119522476751861407e-04, 1.60119522476751861407e-04,
1.19135147006586384913e-03, 1.19135147006586384913e-03,
1.04213797561761569935e-02, 1.04213797561761569935e-02,
@ -72,7 +72,7 @@ var _P = [...]float64{
4.94214826801497100753e-01, 4.94214826801497100753e-01,
9.99999999999999996796e-01, 9.99999999999999996796e-01,
} }
var _Q = [...]float64{ var _gamQ = [...]float64{
-2.31581873324120129819e-05, -2.31581873324120129819e-05,
5.39605580493303397842e-04, 5.39605580493303397842e-04,
-4.45641913851797240494e-03, -4.45641913851797240494e-03,
@ -82,7 +82,7 @@ var _Q = [...]float64{
7.14304917030273074085e-02, 7.14304917030273074085e-02,
1.00000000000000000320e+00, 1.00000000000000000320e+00,
} }
var _S = [...]float64{ var _gamS = [...]float64{
7.87311395793093628397e-04, 7.87311395793093628397e-04,
-2.29549961613378126380e-04, -2.29549961613378126380e-04,
-2.68132617805781232825e-03, -2.68132617805781232825e-03,
@ -98,7 +98,7 @@ func stirling(x float64) float64 {
MaxStirling = 143.01608 MaxStirling = 143.01608
) )
w := 1 / x w := 1 / x
w = 1 + w*((((_S[0]*w+_S[1])*w+_S[2])*w+_S[3])*w+_S[4]) w = 1 + w*((((_gamS[0]*w+_gamS[1])*w+_gamS[2])*w+_gamS[3])*w+_gamS[4])
y := Exp(x) y := Exp(x)
if x > MaxStirling { // avoid Pow() overflow if x > MaxStirling { // avoid Pow() overflow
v := Pow(x, 0.5*x-0.25) v := Pow(x, 0.5*x-0.25)
@ -176,8 +176,8 @@ func Gamma(x float64) float64 {
} }
x = x - 2 x = x - 2
p = (((((x*_P[0]+_P[1])*x+_P[2])*x+_P[3])*x+_P[4])*x+_P[5])*x + _P[6] p = (((((x*_gamP[0]+_gamP[1])*x+_gamP[2])*x+_gamP[3])*x+_gamP[4])*x+_gamP[5])*x + _gamP[6]
q = ((((((x*_Q[0]+_Q[1])*x+_Q[2])*x+_Q[3])*x+_Q[4])*x+_Q[5])*x+_Q[6])*x + _Q[7] q = ((((((x*_gamQ[0]+_gamQ[1])*x+_gamQ[2])*x+_gamQ[3])*x+_gamQ[4])*x+_gamQ[5])*x+_gamQ[6])*x + _gamQ[7]
return z * p / q return z * p / q
small: small:

View File

@ -88,6 +88,81 @@ package math
// //
// //
var _lgamA = [...]float64{
7.72156649015328655494e-02, // 0x3FB3C467E37DB0C8
3.22467033424113591611e-01, // 0x3FD4A34CC4A60FAD
6.73523010531292681824e-02, // 0x3FB13E001A5562A7
2.05808084325167332806e-02, // 0x3F951322AC92547B
7.38555086081402883957e-03, // 0x3F7E404FB68FEFE8
2.89051383673415629091e-03, // 0x3F67ADD8CCB7926B
1.19270763183362067845e-03, // 0x3F538A94116F3F5D
5.10069792153511336608e-04, // 0x3F40B6C689B99C00
2.20862790713908385557e-04, // 0x3F2CF2ECED10E54D
1.08011567247583939954e-04, // 0x3F1C5088987DFB07
2.52144565451257326939e-05, // 0x3EFA7074428CFA52
4.48640949618915160150e-05, // 0x3F07858E90A45837
}
var _lgamR = [...]float64{
1.0, // placeholder
1.39200533467621045958e+00, // 0x3FF645A762C4AB74
7.21935547567138069525e-01, // 0x3FE71A1893D3DCDC
1.71933865632803078993e-01, // 0x3FC601EDCCFBDF27
1.86459191715652901344e-02, // 0x3F9317EA742ED475
7.77942496381893596434e-04, // 0x3F497DDACA41A95B
7.32668430744625636189e-06, // 0x3EDEBAF7A5B38140
}
var _lgamS = [...]float64{
-7.72156649015328655494e-02, // 0xBFB3C467E37DB0C8
2.14982415960608852501e-01, // 0x3FCB848B36E20878
3.25778796408930981787e-01, // 0x3FD4D98F4F139F59
1.46350472652464452805e-01, // 0x3FC2BB9CBEE5F2F7
2.66422703033638609560e-02, // 0x3F9B481C7E939961
1.84028451407337715652e-03, // 0x3F5E26B67368F239
3.19475326584100867617e-05, // 0x3F00BFECDD17E945
}
var _lgamT = [...]float64{
4.83836122723810047042e-01, // 0x3FDEF72BC8EE38A2
-1.47587722994593911752e-01, // 0xBFC2E4278DC6C509
6.46249402391333854778e-02, // 0x3FB08B4294D5419B
-3.27885410759859649565e-02, // 0xBFA0C9A8DF35B713
1.79706750811820387126e-02, // 0x3F9266E7970AF9EC
-1.03142241298341437450e-02, // 0xBF851F9FBA91EC6A
6.10053870246291332635e-03, // 0x3F78FCE0E370E344
-3.68452016781138256760e-03, // 0xBF6E2EFFB3E914D7
2.25964780900612472250e-03, // 0x3F6282D32E15C915
-1.40346469989232843813e-03, // 0xBF56FE8EBF2D1AF1
8.81081882437654011382e-04, // 0x3F4CDF0CEF61A8E9
-5.38595305356740546715e-04, // 0xBF41A6109C73E0EC
3.15632070903625950361e-04, // 0x3F34AF6D6C0EBBF7
-3.12754168375120860518e-04, // 0xBF347F24ECC38C38
3.35529192635519073543e-04, // 0x3F35FD3EE8C2D3F4
}
var _lgamU = [...]float64{
-7.72156649015328655494e-02, // 0xBFB3C467E37DB0C8
6.32827064025093366517e-01, // 0x3FE4401E8B005DFF
1.45492250137234768737e+00, // 0x3FF7475CD119BD6F
9.77717527963372745603e-01, // 0x3FEF497644EA8450
2.28963728064692451092e-01, // 0x3FCD4EAEF6010924
1.33810918536787660377e-02, // 0x3F8B678BBF2BAB09
}
var _lgamV = [...]float64{
1.0,
2.45597793713041134822e+00, // 0x4003A5D7C2BD619C
2.12848976379893395361e+00, // 0x40010725A42B18F5
7.69285150456672783825e-01, // 0x3FE89DFBE45050AF
1.04222645593369134254e-01, // 0x3FBAAE55D6537C88
3.21709242282423911810e-03, // 0x3F6A5ABB57D0CF61
}
var _lgamW = [...]float64{
4.18938533204672725052e-01, // 0x3FDACFE390C97D69
8.33333333333329678849e-02, // 0x3FB555555555553B
-2.77777777728775536470e-03, // 0xBF66C16C16B02E5C
7.93650558643019558500e-04, // 0x3F4A019F98CF38B6
-5.95187557450339963135e-04, // 0xBF4380CB8C0FE741
8.36339918996282139126e-04, // 0x3F4B67BA4CDAD5D1
-1.63092934096575273989e-03, // 0xBF5AB89D0B9E43E4
}
// Lgamma returns the natural logarithm and sign (-1 or +1) of Gamma(x). // Lgamma returns the natural logarithm and sign (-1 or +1) of Gamma(x).
// //
// Special cases are: // Special cases are:
@ -103,68 +178,10 @@ func Lgamma(x float64) (lgamma float64, sign int) {
Two53 = 1 << 53 // 0x4340000000000000 ~9.0072e+15 Two53 = 1 << 53 // 0x4340000000000000 ~9.0072e+15
Two58 = 1 << 58 // 0x4390000000000000 ~2.8823e+17 Two58 = 1 << 58 // 0x4390000000000000 ~2.8823e+17
Tiny = 1.0 / (1 << 70) // 0x3b90000000000000 ~8.47033e-22 Tiny = 1.0 / (1 << 70) // 0x3b90000000000000 ~8.47033e-22
A0 = 7.72156649015328655494e-02 // 0x3FB3C467E37DB0C8
A1 = 3.22467033424113591611e-01 // 0x3FD4A34CC4A60FAD
A2 = 6.73523010531292681824e-02 // 0x3FB13E001A5562A7
A3 = 2.05808084325167332806e-02 // 0x3F951322AC92547B
A4 = 7.38555086081402883957e-03 // 0x3F7E404FB68FEFE8
A5 = 2.89051383673415629091e-03 // 0x3F67ADD8CCB7926B
A6 = 1.19270763183362067845e-03 // 0x3F538A94116F3F5D
A7 = 5.10069792153511336608e-04 // 0x3F40B6C689B99C00
A8 = 2.20862790713908385557e-04 // 0x3F2CF2ECED10E54D
A9 = 1.08011567247583939954e-04 // 0x3F1C5088987DFB07
A10 = 2.52144565451257326939e-05 // 0x3EFA7074428CFA52
A11 = 4.48640949618915160150e-05 // 0x3F07858E90A45837
Tc = 1.46163214496836224576e+00 // 0x3FF762D86356BE3F Tc = 1.46163214496836224576e+00 // 0x3FF762D86356BE3F
Tf = -1.21486290535849611461e-01 // 0xBFBF19B9BCC38A42 Tf = -1.21486290535849611461e-01 // 0xBFBF19B9BCC38A42
// Tt = -(tail of Tf) // Tt = -(tail of Tf)
Tt = -3.63867699703950536541e-18 // 0xBC50C7CAA48A971F Tt = -3.63867699703950536541e-18 // 0xBC50C7CAA48A971F
T0 = 4.83836122723810047042e-01 // 0x3FDEF72BC8EE38A2
T1 = -1.47587722994593911752e-01 // 0xBFC2E4278DC6C509
T2 = 6.46249402391333854778e-02 // 0x3FB08B4294D5419B
T3 = -3.27885410759859649565e-02 // 0xBFA0C9A8DF35B713
T4 = 1.79706750811820387126e-02 // 0x3F9266E7970AF9EC
T5 = -1.03142241298341437450e-02 // 0xBF851F9FBA91EC6A
T6 = 6.10053870246291332635e-03 // 0x3F78FCE0E370E344
T7 = -3.68452016781138256760e-03 // 0xBF6E2EFFB3E914D7
T8 = 2.25964780900612472250e-03 // 0x3F6282D32E15C915
T9 = -1.40346469989232843813e-03 // 0xBF56FE8EBF2D1AF1
T10 = 8.81081882437654011382e-04 // 0x3F4CDF0CEF61A8E9
T11 = -5.38595305356740546715e-04 // 0xBF41A6109C73E0EC
T12 = 3.15632070903625950361e-04 // 0x3F34AF6D6C0EBBF7
T13 = -3.12754168375120860518e-04 // 0xBF347F24ECC38C38
T14 = 3.35529192635519073543e-04 // 0x3F35FD3EE8C2D3F4
U0 = -7.72156649015328655494e-02 // 0xBFB3C467E37DB0C8
U1 = 6.32827064025093366517e-01 // 0x3FE4401E8B005DFF
U2 = 1.45492250137234768737e+00 // 0x3FF7475CD119BD6F
U3 = 9.77717527963372745603e-01 // 0x3FEF497644EA8450
U4 = 2.28963728064692451092e-01 // 0x3FCD4EAEF6010924
U5 = 1.33810918536787660377e-02 // 0x3F8B678BBF2BAB09
V1 = 2.45597793713041134822e+00 // 0x4003A5D7C2BD619C
V2 = 2.12848976379893395361e+00 // 0x40010725A42B18F5
V3 = 7.69285150456672783825e-01 // 0x3FE89DFBE45050AF
V4 = 1.04222645593369134254e-01 // 0x3FBAAE55D6537C88
V5 = 3.21709242282423911810e-03 // 0x3F6A5ABB57D0CF61
S0 = -7.72156649015328655494e-02 // 0xBFB3C467E37DB0C8
S1 = 2.14982415960608852501e-01 // 0x3FCB848B36E20878
S2 = 3.25778796408930981787e-01 // 0x3FD4D98F4F139F59
S3 = 1.46350472652464452805e-01 // 0x3FC2BB9CBEE5F2F7
S4 = 2.66422703033638609560e-02 // 0x3F9B481C7E939961
S5 = 1.84028451407337715652e-03 // 0x3F5E26B67368F239
S6 = 3.19475326584100867617e-05 // 0x3F00BFECDD17E945
R1 = 1.39200533467621045958e+00 // 0x3FF645A762C4AB74
R2 = 7.21935547567138069525e-01 // 0x3FE71A1893D3DCDC
R3 = 1.71933865632803078993e-01 // 0x3FC601EDCCFBDF27
R4 = 1.86459191715652901344e-02 // 0x3F9317EA742ED475
R5 = 7.77942496381893596434e-04 // 0x3F497DDACA41A95B
R6 = 7.32668430744625636189e-06 // 0x3EDEBAF7A5B38140
W0 = 4.18938533204672725052e-01 // 0x3FDACFE390C97D69
W1 = 8.33333333333329678849e-02 // 0x3FB555555555553B
W2 = -2.77777777728775536470e-03 // 0xBF66C16C16B02E5C
W3 = 7.93650558643019558500e-04 // 0x3F4A019F98CF38B6
W4 = -5.95187557450339963135e-04 // 0xBF4380CB8C0FE741
W5 = 8.36339918996282139126e-04 // 0x3F4B67BA4CDAD5D1
W6 = -1.63092934096575273989e-03 // 0xBF5AB89D0B9E43E4
) )
// TODO(rsc): Remove manual inlining of IsNaN, IsInf // TODO(rsc): Remove manual inlining of IsNaN, IsInf
// when compiler does it for us // when compiler does it for us
@ -249,28 +266,28 @@ func Lgamma(x float64) (lgamma float64, sign int) {
switch i { switch i {
case 0: case 0:
z := y * y z := y * y
p1 := A0 + z*(A2+z*(A4+z*(A6+z*(A8+z*A10)))) p1 := _lgamA[0] + z*(_lgamA[2]+z*(_lgamA[4]+z*(_lgamA[6]+z*(_lgamA[8]+z*_lgamA[10]))))
p2 := z * (A1 + z*(A3+z*(A5+z*(A7+z*(A9+z*A11))))) p2 := z * (_lgamA[1] + z*(+_lgamA[3]+z*(_lgamA[5]+z*(_lgamA[7]+z*(_lgamA[9]+z*_lgamA[11])))))
p := y*p1 + p2 p := y*p1 + p2
lgamma += (p - 0.5*y) lgamma += (p - 0.5*y)
case 1: case 1:
z := y * y z := y * y
w := z * y w := z * y
p1 := T0 + w*(T3+w*(T6+w*(T9+w*T12))) // parallel comp p1 := _lgamT[0] + w*(_lgamT[3]+w*(_lgamT[6]+w*(_lgamT[9]+w*_lgamT[12]))) // parallel comp
p2 := T1 + w*(T4+w*(T7+w*(T10+w*T13))) p2 := _lgamT[1] + w*(_lgamT[4]+w*(_lgamT[7]+w*(_lgamT[10]+w*_lgamT[13])))
p3 := T2 + w*(T5+w*(T8+w*(T11+w*T14))) p3 := _lgamT[2] + w*(_lgamT[5]+w*(_lgamT[8]+w*(_lgamT[11]+w*_lgamT[14])))
p := z*p1 - (Tt - w*(p2+y*p3)) p := z*p1 - (Tt - w*(p2+y*p3))
lgamma += (Tf + p) lgamma += (Tf + p)
case 2: case 2:
p1 := y * (U0 + y*(U1+y*(U2+y*(U3+y*(U4+y*U5))))) p1 := y * (_lgamU[0] + y*(_lgamU[1]+y*(_lgamU[2]+y*(_lgamU[3]+y*(_lgamU[4]+y*_lgamU[5])))))
p2 := 1 + y*(V1+y*(V2+y*(V3+y*(V4+y*V5)))) p2 := 1 + y*(_lgamV[1]+y*(_lgamV[2]+y*(_lgamV[3]+y*(_lgamV[4]+y*_lgamV[5]))))
lgamma += (-0.5*y + p1/p2) lgamma += (-0.5*y + p1/p2)
} }
case x < 8: // 2 <= x < 8 case x < 8: // 2 <= x < 8
i := int(x) i := int(x)
y := x - float64(i) y := x - float64(i)
p := y * (S0 + y*(S1+y*(S2+y*(S3+y*(S4+y*(S5+y*S6)))))) p := y * (_lgamS[0] + y*(_lgamS[1]+y*(_lgamS[2]+y*(_lgamS[3]+y*(_lgamS[4]+y*(_lgamS[5]+y*_lgamS[6]))))))
q := 1 + y*(R1+y*(R2+y*(R3+y*(R4+y*(R5+y*R6))))) q := 1 + y*(_lgamR[1]+y*(_lgamR[2]+y*(_lgamR[3]+y*(_lgamR[4]+y*(_lgamR[5]+y*_lgamR[6])))))
lgamma = 0.5*y + p/q lgamma = 0.5*y + p/q
z := 1.0 // Lgamma(1+s) = Log(s) + Lgamma(s) z := 1.0 // Lgamma(1+s) = Log(s) + Lgamma(s)
switch i { switch i {
@ -294,7 +311,7 @@ func Lgamma(x float64) (lgamma float64, sign int) {
t := Log(x) t := Log(x)
z := 1 / x z := 1 / x
y := z * z y := z * z
w := W0 + z*(W1+y*(W2+y*(W3+y*(W4+y*(W5+y*W6))))) w := _lgamW[0] + z*(_lgamW[1]+y*(_lgamW[2]+y*(_lgamW[3]+y*(_lgamW[4]+y*(_lgamW[5]+y*_lgamW[6])))))
lgamma = (x-0.5)*(t-1) + w lgamma = (x-0.5)*(t-1) + w
default: // 2**58 <= x <= Inf default: // 2**58 <= x <= Inf
lgamma = x * (Log(x) - 1) lgamma = x * (Log(x) - 1)

View File

@ -160,7 +160,7 @@ type sliceReaderAt []byte
func (r sliceReaderAt) ReadAt(b []byte, off int64) (int, error) { func (r sliceReaderAt) ReadAt(b []byte, off int64) (int, error) {
if int(off) >= len(r) || off < 0 { if int(off) >= len(r) || off < 0 {
return 0, os.EINVAL return 0, io.ErrUnexpectedEOF
} }
n := copy(b, r[int(off):]) n := copy(b, r[int(off):])
return n, nil return n, nil

View File

@ -6,19 +6,11 @@
package mime package mime
import ( import (
"bufio"
"fmt" "fmt"
"os"
"strings" "strings"
"sync" "sync"
) )
var typeFiles = []string{
"/etc/mime.types",
"/etc/apache2/mime.types",
"/etc/apache/mime.types",
}
var mimeTypes = map[string]string{ var mimeTypes = map[string]string{
".css": "text/css; charset=utf-8", ".css": "text/css; charset=utf-8",
".gif": "image/gif", ".gif": "image/gif",
@ -33,46 +25,13 @@ var mimeTypes = map[string]string{
var mimeLock sync.RWMutex var mimeLock sync.RWMutex
func loadMimeFile(filename string) {
f, err := os.Open(filename)
if err != nil {
return
}
reader := bufio.NewReader(f)
for {
line, err := reader.ReadString('\n')
if err != nil {
f.Close()
return
}
fields := strings.Fields(line)
if len(fields) <= 1 || fields[0][0] == '#' {
continue
}
mimeType := fields[0]
for _, ext := range fields[1:] {
if ext[0] == '#' {
break
}
setExtensionType("."+ext, mimeType)
}
}
}
func initMime() {
for _, filename := range typeFiles {
loadMimeFile(filename)
}
}
var once sync.Once var once sync.Once
// TypeByExtension returns the MIME type associated with the file extension ext. // TypeByExtension returns the MIME type associated with the file extension ext.
// The extension ext should begin with a leading dot, as in ".html". // The extension ext should begin with a leading dot, as in ".html".
// When ext has no associated type, TypeByExtension returns "". // When ext has no associated type, TypeByExtension returns "".
// //
// The built-in table is small but is is augmented by the local // The built-in table is small but on unix it is augmented by the local
// system's mime.types file(s) if available under one or more of these // system's mime.types file(s) if available under one or more of these
// names: // names:
// //
@ -80,6 +39,8 @@ var once sync.Once
// /etc/apache2/mime.types // /etc/apache2/mime.types
// /etc/apache/mime.types // /etc/apache/mime.types
// //
// Windows system mime types are extracted from registry.
//
// Text types have the charset parameter set to "utf-8" by default. // Text types have the charset parameter set to "utf-8" by default.
func TypeByExtension(ext string) string { func TypeByExtension(ext string) string {
once.Do(initMime) once.Do(initMime)

View File

@ -6,15 +6,9 @@ package mime
import "testing" import "testing"
var typeTests = map[string]string{ var typeTests = initMimeForTests()
".t1": "application/test",
".t2": "text/test; charset=utf-8",
".png": "image/png",
}
func TestTypeByExtension(t *testing.T) { func TestTypeByExtension(t *testing.T) {
typeFiles = []string{"test.types"}
for ext, want := range typeTests { for ext, want := range typeTests {
val := TypeByExtension(ext) val := TypeByExtension(ext)
if val != want { if val != want {

View File

@ -0,0 +1,59 @@
// Copyright 2010 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package mime
import (
"bufio"
"os"
"strings"
)
var typeFiles = []string{
"/etc/mime.types",
"/etc/apache2/mime.types",
"/etc/apache/mime.types",
}
func loadMimeFile(filename string) {
f, err := os.Open(filename)
if err != nil {
return
}
reader := bufio.NewReader(f)
for {
line, err := reader.ReadString('\n')
if err != nil {
f.Close()
return
}
fields := strings.Fields(line)
if len(fields) <= 1 || fields[0][0] == '#' {
continue
}
mimeType := fields[0]
for _, ext := range fields[1:] {
if ext[0] == '#' {
break
}
setExtensionType("."+ext, mimeType)
}
}
}
func initMime() {
for _, filename := range typeFiles {
loadMimeFile(filename)
}
}
func initMimeForTests() map[string]string {
typeFiles = []string{"test.types"}
return map[string]string{
".t1": "application/test",
".t2": "text/test; charset=utf-8",
".png": "image/png",
}
}

View File

@ -0,0 +1,61 @@
// Copyright 2010 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package mime
import (
"syscall"
"unsafe"
)
func initMime() {
var root syscall.Handle
if syscall.RegOpenKeyEx(syscall.HKEY_CLASSES_ROOT, syscall.StringToUTF16Ptr(`\`),
0, syscall.KEY_READ, &root) != 0 {
return
}
defer syscall.RegCloseKey(root)
var count uint32
if syscall.RegQueryInfoKey(root, nil, nil, nil, &count, nil, nil, nil, nil, nil, nil, nil) != 0 {
return
}
var buf [1 << 10]uint16
for i := uint32(0); i < count; i++ {
n := uint32(len(buf))
if syscall.RegEnumKeyEx(root, i, &buf[0], &n, nil, nil, nil, nil) != 0 {
continue
}
ext := syscall.UTF16ToString(buf[:])
if len(ext) < 2 || ext[0] != '.' { // looking for extensions only
continue
}
var h syscall.Handle
if syscall.RegOpenKeyEx(
syscall.HKEY_CLASSES_ROOT, syscall.StringToUTF16Ptr(`\`+ext),
0, syscall.KEY_READ, &h) != 0 {
continue
}
var typ uint32
n = uint32(len(buf) * 2) // api expects array of bytes, not uint16
if syscall.RegQueryValueEx(
h, syscall.StringToUTF16Ptr("Content Type"),
nil, &typ, (*byte)(unsafe.Pointer(&buf[0])), &n) != 0 {
syscall.RegCloseKey(h)
continue
}
syscall.RegCloseKey(h)
if typ != syscall.REG_SZ { // null terminated strings only
continue
}
mimeType := syscall.UTF16ToString(buf[:])
setExtensionType(ext, mimeType)
}
}
func initMimeForTests() map[string]string {
return map[string]string{
".bmp": "image/bmp",
".png": "image/png",
}
}

View File

@ -109,7 +109,7 @@ func cgoLookupIPCNAME(name string) (addrs []IP, cname string, err error, complet
if gerrno == syscall.EAI_NONAME { if gerrno == syscall.EAI_NONAME {
str = noSuchHost str = noSuchHost
} else if gerrno == syscall.EAI_SYSTEM { } else if gerrno == syscall.EAI_SYSTEM {
str = syscall.Errstr(syscall.GetErrno()) str = syscall.GetErrno().Error()
} else { } else {
str = bytePtrToString(libc_gai_strerror(gerrno)) str = bytePtrToString(libc_gai_strerror(gerrno))
} }

View File

@ -278,8 +278,8 @@ func startServer() {
func newFD(fd, family, proto int, net string) (f *netFD, err error) { func newFD(fd, family, proto int, net string) (f *netFD, err error) {
onceStartServer.Do(startServer) onceStartServer.Do(startServer)
if e := syscall.SetNonblock(fd, true); e != 0 { if e := syscall.SetNonblock(fd, true); e != nil {
return nil, os.Errno(e) return nil, e
} }
f = &netFD{ f = &netFD{
sysfd: fd, sysfd: fd,
@ -306,19 +306,19 @@ func (fd *netFD) setAddr(laddr, raddr Addr) {
} }
func (fd *netFD) connect(ra syscall.Sockaddr) (err error) { func (fd *netFD) connect(ra syscall.Sockaddr) (err error) {
e := syscall.Connect(fd.sysfd, ra) err = syscall.Connect(fd.sysfd, ra)
if e == syscall.EINPROGRESS { if err == syscall.EINPROGRESS {
var errno int
pollserver.WaitWrite(fd) pollserver.WaitWrite(fd)
e, errno = syscall.GetsockoptInt(fd.sysfd, syscall.SOL_SOCKET, syscall.SO_ERROR) var e int
if errno != 0 { e, err = syscall.GetsockoptInt(fd.sysfd, syscall.SOL_SOCKET, syscall.SO_ERROR)
return os.NewSyscallError("getsockopt", errno) if err != nil {
} return os.NewSyscallError("getsockopt", err)
} }
if e != 0 { if e != 0 {
return os.Errno(e) err = syscall.Errno(e)
} }
return nil }
return err
} }
// Add a reference to this fd. // Add a reference to this fd.
@ -362,9 +362,9 @@ func (fd *netFD) shutdown(how int) error {
if fd == nil || fd.sysfile == nil { if fd == nil || fd.sysfile == nil {
return os.EINVAL return os.EINVAL
} }
errno := syscall.Shutdown(fd.sysfd, how) err := syscall.Shutdown(fd.sysfd, how)
if errno != 0 { if err != nil {
return &OpError{"shutdown", fd.net, fd.laddr, os.Errno(errno)} return &OpError{"shutdown", fd.net, fd.laddr, err}
} }
return nil return nil
} }
@ -377,6 +377,14 @@ func (fd *netFD) CloseWrite() error {
return fd.shutdown(syscall.SHUT_WR) return fd.shutdown(syscall.SHUT_WR)
} }
type timeoutError struct{}
func (e *timeoutError) Error() string { return "i/o timeout" }
func (e *timeoutError) Timeout() bool { return true }
func (e *timeoutError) Temporary() bool { return true }
var errTimeout error = &timeoutError{}
func (fd *netFD) Read(p []byte) (n int, err error) { func (fd *netFD) Read(p []byte) (n int, err error) {
if fd == nil { if fd == nil {
return 0, os.EINVAL return 0, os.EINVAL
@ -393,24 +401,24 @@ func (fd *netFD) Read(p []byte) (n int, err error) {
} else { } else {
fd.rdeadline = 0 fd.rdeadline = 0
} }
var oserr error
for { for {
var errno int n, err = syscall.Read(fd.sysfile.Fd(), p)
n, errno = syscall.Read(fd.sysfile.Fd(), p) if err == syscall.EAGAIN {
if errno == syscall.EAGAIN && fd.rdeadline >= 0 { if fd.rdeadline >= 0 {
pollserver.WaitRead(fd) pollserver.WaitRead(fd)
continue continue
} }
if errno != 0 { err = errTimeout
}
if err != nil {
n = 0 n = 0
oserr = os.Errno(errno) } else if n == 0 && err == nil && fd.proto != syscall.SOCK_DGRAM {
} else if n == 0 && errno == 0 && fd.proto != syscall.SOCK_DGRAM {
err = io.EOF err = io.EOF
} }
break break
} }
if oserr != nil { if err != nil && err != io.EOF {
err = &OpError{"read", fd.net, fd.raddr, oserr} err = &OpError{"read", fd.net, fd.raddr, err}
} }
return return
} }
@ -428,22 +436,22 @@ func (fd *netFD) ReadFrom(p []byte) (n int, sa syscall.Sockaddr, err error) {
} else { } else {
fd.rdeadline = 0 fd.rdeadline = 0
} }
var oserr error
for { for {
var errno int n, sa, err = syscall.Recvfrom(fd.sysfd, p, 0)
n, sa, errno = syscall.Recvfrom(fd.sysfd, p, 0) if err == syscall.EAGAIN {
if errno == syscall.EAGAIN && fd.rdeadline >= 0 { if fd.rdeadline >= 0 {
pollserver.WaitRead(fd) pollserver.WaitRead(fd)
continue continue
} }
if errno != 0 { err = errTimeout
}
if err != nil {
n = 0 n = 0
oserr = os.Errno(errno)
} }
break break
} }
if oserr != nil { if err != nil {
err = &OpError{"read", fd.net, fd.laddr, oserr} err = &OpError{"read", fd.net, fd.laddr, err}
} }
return return
} }
@ -461,24 +469,22 @@ func (fd *netFD) ReadMsg(p []byte, oob []byte) (n, oobn, flags int, sa syscall.S
} else { } else {
fd.rdeadline = 0 fd.rdeadline = 0
} }
var oserr error
for { for {
var errno int n, oobn, flags, sa, err = syscall.Recvmsg(fd.sysfd, p, oob, 0)
n, oobn, flags, sa, errno = syscall.Recvmsg(fd.sysfd, p, oob, 0) if err == syscall.EAGAIN {
if errno == syscall.EAGAIN && fd.rdeadline >= 0 { if fd.rdeadline >= 0 {
pollserver.WaitRead(fd) pollserver.WaitRead(fd)
continue continue
} }
if errno != 0 { err = errTimeout
oserr = os.Errno(errno)
} }
if n == 0 { if err == nil && n == 0 {
oserr = io.EOF err = io.EOF
} }
break break
} }
if oserr != nil { if err != nil && err != io.EOF {
err = &OpError{"read", fd.net, fd.laddr, oserr} err = &OpError{"read", fd.net, fd.laddr, err}
return return
} }
return return
@ -501,32 +507,34 @@ func (fd *netFD) Write(p []byte) (n int, err error) {
fd.wdeadline = 0 fd.wdeadline = 0
} }
nn := 0 nn := 0
var oserr error
for { for {
n, errno := syscall.Write(fd.sysfile.Fd(), p[nn:]) var n int
n, err = syscall.Write(fd.sysfile.Fd(), p[nn:])
if n > 0 { if n > 0 {
nn += n nn += n
} }
if nn == len(p) { if nn == len(p) {
break break
} }
if errno == syscall.EAGAIN && fd.wdeadline >= 0 { if err == syscall.EAGAIN {
if fd.wdeadline >= 0 {
pollserver.WaitWrite(fd) pollserver.WaitWrite(fd)
continue continue
} }
if errno != 0 { err = errTimeout
}
if err != nil {
n = 0 n = 0
oserr = os.Errno(errno)
break break
} }
if n == 0 { if n == 0 {
oserr = io.ErrUnexpectedEOF err = io.ErrUnexpectedEOF
break break
} }
} }
if oserr != nil { if err != nil {
err = &OpError{"write", fd.net, fd.raddr, oserr} err = &OpError{"write", fd.net, fd.raddr, err}
} }
return nn, err return nn, err
} }
@ -544,22 +552,21 @@ func (fd *netFD) WriteTo(p []byte, sa syscall.Sockaddr) (n int, err error) {
} else { } else {
fd.wdeadline = 0 fd.wdeadline = 0
} }
var oserr error
for { for {
errno := syscall.Sendto(fd.sysfd, p, 0, sa) err = syscall.Sendto(fd.sysfd, p, 0, sa)
if errno == syscall.EAGAIN && fd.wdeadline >= 0 { if err == syscall.EAGAIN {
if fd.wdeadline >= 0 {
pollserver.WaitWrite(fd) pollserver.WaitWrite(fd)
continue continue
} }
if errno != 0 { err = errTimeout
oserr = os.Errno(errno)
} }
break break
} }
if oserr == nil { if err == nil {
n = len(p) n = len(p)
} else { } else {
err = &OpError{"write", fd.net, fd.raddr, oserr} err = &OpError{"write", fd.net, fd.raddr, err}
} }
return return
} }
@ -577,24 +584,22 @@ func (fd *netFD) WriteMsg(p []byte, oob []byte, sa syscall.Sockaddr) (n int, oob
} else { } else {
fd.wdeadline = 0 fd.wdeadline = 0
} }
var oserr error
for { for {
var errno int err = syscall.Sendmsg(fd.sysfd, p, oob, sa, 0)
errno = syscall.Sendmsg(fd.sysfd, p, oob, sa, 0) if err == syscall.EAGAIN {
if errno == syscall.EAGAIN && fd.wdeadline >= 0 { if fd.wdeadline >= 0 {
pollserver.WaitWrite(fd) pollserver.WaitWrite(fd)
continue continue
} }
if errno != 0 { err = errTimeout
oserr = os.Errno(errno)
} }
break break
} }
if oserr == nil { if err == nil {
n = len(p) n = len(p)
oobn = len(oob) oobn = len(oob)
} else { } else {
err = &OpError{"write", fd.net, fd.raddr, oserr} err = &OpError{"write", fd.net, fd.raddr, err}
} }
return return
} }
@ -615,25 +620,26 @@ func (fd *netFD) accept(toAddr func(syscall.Sockaddr) Addr) (nfd *netFD, err err
// See ../syscall/exec.go for description of ForkLock. // See ../syscall/exec.go for description of ForkLock.
// It is okay to hold the lock across syscall.Accept // It is okay to hold the lock across syscall.Accept
// because we have put fd.sysfd into non-blocking mode. // because we have put fd.sysfd into non-blocking mode.
syscall.ForkLock.RLock() var s int
var s, e int
var rsa syscall.Sockaddr var rsa syscall.Sockaddr
for { for {
if fd.closing { if fd.closing {
syscall.ForkLock.RUnlock()
return nil, os.EINVAL return nil, os.EINVAL
} }
s, rsa, e = syscall.Accept(fd.sysfd)
if e != syscall.EAGAIN || fd.rdeadline < 0 {
break
}
syscall.ForkLock.RUnlock()
pollserver.WaitRead(fd)
syscall.ForkLock.RLock() syscall.ForkLock.RLock()
} s, rsa, err = syscall.Accept(fd.sysfd)
if e != 0 { if err != nil {
syscall.ForkLock.RUnlock() syscall.ForkLock.RUnlock()
return nil, &OpError{"accept", fd.net, fd.laddr, os.Errno(e)} if err == syscall.EAGAIN {
if fd.rdeadline >= 0 {
pollserver.WaitRead(fd)
continue
}
err = errTimeout
}
return nil, &OpError{"accept", fd.net, fd.laddr, err}
}
break
} }
syscall.CloseOnExec(s) syscall.CloseOnExec(s)
syscall.ForkLock.RUnlock() syscall.ForkLock.RUnlock()
@ -648,19 +654,19 @@ func (fd *netFD) accept(toAddr func(syscall.Sockaddr) Addr) (nfd *netFD, err err
} }
func (fd *netFD) dup() (f *os.File, err error) { func (fd *netFD) dup() (f *os.File, err error) {
ns, e := syscall.Dup(fd.sysfd) ns, err := syscall.Dup(fd.sysfd)
if e != 0 { if err != nil {
return nil, &OpError{"dup", fd.net, fd.laddr, os.Errno(e)} return nil, &OpError{"dup", fd.net, fd.laddr, err}
} }
// We want blocking mode for the new fd, hence the double negative. // We want blocking mode for the new fd, hence the double negative.
if e = syscall.SetNonblock(ns, false); e != 0 { if err = syscall.SetNonblock(ns, false); err != nil {
return nil, &OpError{"setnonblock", fd.net, fd.laddr, os.Errno(e)} return nil, &OpError{"setnonblock", fd.net, fd.laddr, err}
} }
return os.NewFile(ns, fd.sysfile.Name()), nil return os.NewFile(ns, fd.sysfile.Name()), nil
} }
func closesocket(s int) (errno int) { func closesocket(s int) error {
return syscall.Close(s) return syscall.Close(s)
} }

View File

@ -35,12 +35,12 @@ type pollster struct {
func newpollster() (p *pollster, err error) { func newpollster() (p *pollster, err error) {
p = new(pollster) p = new(pollster)
var e int var e error
// The arg to epoll_create is a hint to the kernel // The arg to epoll_create is a hint to the kernel
// about the number of FDs we will care about. // about the number of FDs we will care about.
// We don't know, and since 2.6.8 the kernel ignores it anyhow. // We don't know, and since 2.6.8 the kernel ignores it anyhow.
if p.epfd, e = syscall.EpollCreate(16); e != 0 { if p.epfd, e = syscall.EpollCreate(16); e != nil {
return nil, os.NewSyscallError("epoll_create", e) return nil, os.NewSyscallError("epoll_create", e)
} }
p.events = make(map[int]uint32) p.events = make(map[int]uint32)
@ -68,7 +68,7 @@ func (p *pollster) AddFD(fd int, mode int, repeat bool) (bool, error) {
} else { } else {
op = syscall.EPOLL_CTL_ADD op = syscall.EPOLL_CTL_ADD
} }
if e := syscall.EpollCtl(p.epfd, op, fd, &p.ctlEvent); e != 0 { if e := syscall.EpollCtl(p.epfd, op, fd, &p.ctlEvent); e != nil {
return false, os.NewSyscallError("epoll_ctl", e) return false, os.NewSyscallError("epoll_ctl", e)
} }
p.events[fd] = p.ctlEvent.Events p.events[fd] = p.ctlEvent.Events
@ -97,13 +97,13 @@ func (p *pollster) StopWaiting(fd int, bits uint) {
if int32(events)&^syscall.EPOLLONESHOT != 0 { if int32(events)&^syscall.EPOLLONESHOT != 0 {
p.ctlEvent.Fd = int32(fd) p.ctlEvent.Fd = int32(fd)
p.ctlEvent.Events = events p.ctlEvent.Events = events
if e := syscall.EpollCtl(p.epfd, syscall.EPOLL_CTL_MOD, fd, &p.ctlEvent); e != 0 { if e := syscall.EpollCtl(p.epfd, syscall.EPOLL_CTL_MOD, fd, &p.ctlEvent); e != nil {
print("Epoll modify fd=", fd, ": ", os.Errno(e).Error(), "\n") print("Epoll modify fd=", fd, ": ", e.Error(), "\n")
} }
p.events[fd] = events p.events[fd] = events
} else { } else {
if e := syscall.EpollCtl(p.epfd, syscall.EPOLL_CTL_DEL, fd, nil); e != 0 { if e := syscall.EpollCtl(p.epfd, syscall.EPOLL_CTL_DEL, fd, nil); e != nil {
print("Epoll delete fd=", fd, ": ", os.Errno(e).Error(), "\n") print("Epoll delete fd=", fd, ": ", e.Error(), "\n")
} }
delete(p.events, fd) delete(p.events, fd)
} }
@ -141,7 +141,7 @@ func (p *pollster) WaitFD(s *pollServer, nsec int64) (fd int, mode int, err erro
n, e := syscall.EpollWait(p.epfd, p.waitEventBuf[0:], msec) n, e := syscall.EpollWait(p.epfd, p.waitEventBuf[0:], msec)
s.Lock() s.Lock()
if e != 0 { if e != nil {
if e == syscall.EAGAIN || e == syscall.EINTR { if e == syscall.EAGAIN || e == syscall.EINTR {
continue continue
} }

View File

@ -23,9 +23,8 @@ type pollster struct {
func newpollster() (p *pollster, err error) { func newpollster() (p *pollster, err error) {
p = new(pollster) p = new(pollster)
var e int if p.kq, err = syscall.Kqueue(); err != nil {
if p.kq, e = syscall.Kqueue(); e != 0 { return nil, os.NewSyscallError("kqueue", err)
return nil, os.NewSyscallError("kqueue", e)
} }
p.events = p.eventbuf[0:0] p.events = p.eventbuf[0:0]
return p, nil return p, nil
@ -50,14 +49,14 @@ func (p *pollster) AddFD(fd int, mode int, repeat bool) (bool, error) {
syscall.SetKevent(ev, fd, kmode, flags) syscall.SetKevent(ev, fd, kmode, flags)
n, e := syscall.Kevent(p.kq, p.kbuf[:], nil, nil) n, e := syscall.Kevent(p.kq, p.kbuf[:], nil, nil)
if e != 0 { if e != nil {
return false, os.NewSyscallError("kevent", e) return false, os.NewSyscallError("kevent", e)
} }
if n != 1 || (ev.Flags&syscall.EV_ERROR) == 0 || int(ev.Ident) != fd || int(ev.Filter) != kmode { if n != 1 || (ev.Flags&syscall.EV_ERROR) == 0 || int(ev.Ident) != fd || int(ev.Filter) != kmode {
return false, os.NewSyscallError("kqueue phase error", e) return false, os.NewSyscallError("kqueue phase error", e)
} }
if ev.Data != 0 { if ev.Data != 0 {
return false, os.Errno(int(ev.Data)) return false, syscall.Errno(int(ev.Data))
} }
return false, nil return false, nil
} }
@ -91,7 +90,7 @@ func (p *pollster) WaitFD(s *pollServer, nsec int64) (fd int, mode int, err erro
nn, e := syscall.Kevent(p.kq, nil, p.eventbuf[:], t) nn, e := syscall.Kevent(p.kq, nil, p.eventbuf[:], t)
s.Lock() s.Lock()
if e != 0 { if e != nil {
if e == syscall.EINTR { if e == syscall.EINTR {
continue continue
} }

View File

@ -26,11 +26,11 @@ func init() {
var d syscall.WSAData var d syscall.WSAData
e := syscall.WSAStartup(uint32(0x202), &d) e := syscall.WSAStartup(uint32(0x202), &d)
if e != 0 { if e != 0 {
initErr = os.NewSyscallError("WSAStartup", e) initErr = os.NewSyscallError("WSAStartup", syscall.Errno(e))
} }
} }
func closesocket(s syscall.Handle) (errno int) { func closesocket(s syscall.Handle) (err error) {
return syscall.Closesocket(s) return syscall.Closesocket(s)
} }
@ -38,13 +38,13 @@ func closesocket(s syscall.Handle) (errno int) {
type anOpIface interface { type anOpIface interface {
Op() *anOp Op() *anOp
Name() string Name() string
Submit() (errno int) Submit() (err error)
} }
// IO completion result parameters. // IO completion result parameters.
type ioResult struct { type ioResult struct {
qty uint32 qty uint32
err int err error
} }
// anOp implements functionality common to all io operations. // anOp implements functionality common to all io operations.
@ -54,7 +54,7 @@ type anOp struct {
o syscall.Overlapped o syscall.Overlapped
resultc chan ioResult resultc chan ioResult
errnoc chan int errnoc chan error
fd *netFD fd *netFD
} }
@ -71,7 +71,7 @@ func (o *anOp) Init(fd *netFD, mode int) {
} }
o.resultc = fd.resultc[i] o.resultc = fd.resultc[i]
if fd.errnoc[i] == nil { if fd.errnoc[i] == nil {
fd.errnoc[i] = make(chan int) fd.errnoc[i] = make(chan error)
} }
o.errnoc = fd.errnoc[i] o.errnoc = fd.errnoc[i]
} }
@ -111,14 +111,14 @@ func (s *resultSrv) Run() {
for { for {
r.err = syscall.GetQueuedCompletionStatus(s.iocp, &(r.qty), &key, &o, syscall.INFINITE) r.err = syscall.GetQueuedCompletionStatus(s.iocp, &(r.qty), &key, &o, syscall.INFINITE)
switch { switch {
case r.err == 0: case r.err == nil:
// Dequeued successfully completed io packet. // Dequeued successfully completed io packet.
case r.err == syscall.WAIT_TIMEOUT && o == nil: case r.err == syscall.Errno(syscall.WAIT_TIMEOUT) && o == nil:
// Wait has timed out (should not happen now, but might be used in the future). // Wait has timed out (should not happen now, but might be used in the future).
panic("GetQueuedCompletionStatus timed out") panic("GetQueuedCompletionStatus timed out")
case o == nil: case o == nil:
// Failed to dequeue anything -> report the error. // Failed to dequeue anything -> report the error.
panic("GetQueuedCompletionStatus failed " + syscall.Errstr(r.err)) panic("GetQueuedCompletionStatus failed " + r.err.Error())
default: default:
// Dequeued failed io packet. // Dequeued failed io packet.
} }
@ -153,7 +153,7 @@ func (s *ioSrv) ProcessRemoteIO() {
// inline, or, if timeouts are employed, passes the request onto // inline, or, if timeouts are employed, passes the request onto
// a special goroutine and waits for completion or cancels request. // a special goroutine and waits for completion or cancels request.
func (s *ioSrv) ExecIO(oi anOpIface, deadline_delta int64) (n int, err error) { func (s *ioSrv) ExecIO(oi anOpIface, deadline_delta int64) (n int, err error) {
var e int var e error
o := oi.Op() o := oi.Op()
if deadline_delta > 0 { if deadline_delta > 0 {
// Send request to a special dedicated thread, // Send request to a special dedicated thread,
@ -164,12 +164,12 @@ func (s *ioSrv) ExecIO(oi anOpIface, deadline_delta int64) (n int, err error) {
e = oi.Submit() e = oi.Submit()
} }
switch e { switch e {
case 0: case nil:
// IO completed immediately, but we need to get our completion message anyway. // IO completed immediately, but we need to get our completion message anyway.
case syscall.ERROR_IO_PENDING: case syscall.ERROR_IO_PENDING:
// IO started, and we have to wait for its completion. // IO started, and we have to wait for its completion.
default: default:
return 0, &OpError{oi.Name(), o.fd.net, o.fd.laddr, os.Errno(e)} return 0, &OpError{oi.Name(), o.fd.net, o.fd.laddr, e}
} }
// Wait for our request to complete. // Wait for our request to complete.
var r ioResult var r ioResult
@ -187,8 +187,8 @@ func (s *ioSrv) ExecIO(oi anOpIface, deadline_delta int64) (n int, err error) {
} else { } else {
r = <-o.resultc r = <-o.resultc
} }
if r.err != 0 { if r.err != nil {
err = &OpError{oi.Name(), o.fd.net, o.fd.laddr, os.Errno(r.err)} err = &OpError{oi.Name(), o.fd.net, o.fd.laddr, r.err}
} }
return int(r.qty), err return int(r.qty), err
} }
@ -200,10 +200,10 @@ var onceStartServer sync.Once
func startServer() { func startServer() {
resultsrv = new(resultSrv) resultsrv = new(resultSrv)
var errno int var err error
resultsrv.iocp, errno = syscall.CreateIoCompletionPort(syscall.InvalidHandle, 0, 0, 1) resultsrv.iocp, err = syscall.CreateIoCompletionPort(syscall.InvalidHandle, 0, 0, 1)
if errno != 0 { if err != nil {
panic("CreateIoCompletionPort failed " + syscall.Errstr(errno)) panic("CreateIoCompletionPort: " + err.Error())
} }
go resultsrv.Run() go resultsrv.Run()
@ -228,7 +228,7 @@ type netFD struct {
laddr Addr laddr Addr
raddr Addr raddr Addr
resultc [2]chan ioResult // read/write completion results resultc [2]chan ioResult // read/write completion results
errnoc [2]chan int // read/write submit or cancel operation errors errnoc [2]chan error // read/write submit or cancel operation errors
// owned by client // owned by client
rdeadline_delta int64 rdeadline_delta int64
@ -256,8 +256,8 @@ func newFD(fd syscall.Handle, family, proto int, net string) (f *netFD, err erro
} }
onceStartServer.Do(startServer) onceStartServer.Do(startServer)
// Associate our socket with resultsrv.iocp. // Associate our socket with resultsrv.iocp.
if _, e := syscall.CreateIoCompletionPort(syscall.Handle(fd), resultsrv.iocp, 0, 0); e != 0 { if _, e := syscall.CreateIoCompletionPort(syscall.Handle(fd), resultsrv.iocp, 0, 0); e != nil {
return nil, os.Errno(e) return nil, e
} }
return allocFD(fd, family, proto, net), nil return allocFD(fd, family, proto, net), nil
} }
@ -268,11 +268,7 @@ func (fd *netFD) setAddr(laddr, raddr Addr) {
} }
func (fd *netFD) connect(ra syscall.Sockaddr) (err error) { func (fd *netFD) connect(ra syscall.Sockaddr) (err error) {
e := syscall.Connect(fd.sysfd, ra) return syscall.Connect(fd.sysfd, ra)
if e != 0 {
return os.Errno(e)
}
return nil
} }
// Add a reference to this fd. // Add a reference to this fd.
@ -317,9 +313,9 @@ func (fd *netFD) shutdown(how int) error {
if fd == nil || fd.sysfd == syscall.InvalidHandle { if fd == nil || fd.sysfd == syscall.InvalidHandle {
return os.EINVAL return os.EINVAL
} }
errno := syscall.Shutdown(fd.sysfd, how) err := syscall.Shutdown(fd.sysfd, how)
if errno != 0 { if err != nil {
return &OpError{"shutdown", fd.net, fd.laddr, os.Errno(errno)} return &OpError{"shutdown", fd.net, fd.laddr, err}
} }
return nil return nil
} }
@ -338,7 +334,7 @@ type readOp struct {
bufOp bufOp
} }
func (o *readOp) Submit() (errno int) { func (o *readOp) Submit() (err error) {
var d, f uint32 var d, f uint32
return syscall.WSARecv(syscall.Handle(o.fd.sysfd), &o.buf, 1, &d, &f, &o.o, nil) return syscall.WSARecv(syscall.Handle(o.fd.sysfd), &o.buf, 1, &d, &f, &o.o, nil)
} }
@ -375,7 +371,7 @@ type readFromOp struct {
rsan int32 rsan int32
} }
func (o *readFromOp) Submit() (errno int) { func (o *readFromOp) Submit() (err error) {
var d, f uint32 var d, f uint32
return syscall.WSARecvFrom(o.fd.sysfd, &o.buf, 1, &d, &f, &o.rsa, &o.rsan, &o.o, nil) return syscall.WSARecvFrom(o.fd.sysfd, &o.buf, 1, &d, &f, &o.rsa, &o.rsan, &o.o, nil)
} }
@ -415,7 +411,7 @@ type writeOp struct {
bufOp bufOp
} }
func (o *writeOp) Submit() (errno int) { func (o *writeOp) Submit() (err error) {
var d uint32 var d uint32
return syscall.WSASend(o.fd.sysfd, &o.buf, 1, &d, 0, &o.o, nil) return syscall.WSASend(o.fd.sysfd, &o.buf, 1, &d, 0, &o.o, nil)
} }
@ -447,7 +443,7 @@ type writeToOp struct {
sa syscall.Sockaddr sa syscall.Sockaddr
} }
func (o *writeToOp) Submit() (errno int) { func (o *writeToOp) Submit() (err error) {
var d uint32 var d uint32
return syscall.WSASendto(o.fd.sysfd, &o.buf, 1, &d, 0, o.sa, &o.o, nil) return syscall.WSASendto(o.fd.sysfd, &o.buf, 1, &d, 0, o.sa, &o.o, nil)
} }
@ -484,7 +480,7 @@ type acceptOp struct {
attrs [2]syscall.RawSockaddrAny // space for local and remote address only attrs [2]syscall.RawSockaddrAny // space for local and remote address only
} }
func (o *acceptOp) Submit() (errno int) { func (o *acceptOp) Submit() (err error) {
var d uint32 var d uint32
l := uint32(unsafe.Sizeof(o.attrs[0])) l := uint32(unsafe.Sizeof(o.attrs[0]))
return syscall.AcceptEx(o.fd.sysfd, o.newsock, return syscall.AcceptEx(o.fd.sysfd, o.newsock,
@ -506,17 +502,17 @@ func (fd *netFD) accept(toAddr func(syscall.Sockaddr) Addr) (nfd *netFD, err err
// See ../syscall/exec.go for description of ForkLock. // See ../syscall/exec.go for description of ForkLock.
syscall.ForkLock.RLock() syscall.ForkLock.RLock()
s, e := syscall.Socket(fd.family, fd.proto, 0) s, e := syscall.Socket(fd.family, fd.proto, 0)
if e != 0 { if e != nil {
syscall.ForkLock.RUnlock() syscall.ForkLock.RUnlock()
return nil, os.Errno(e) return nil, e
} }
syscall.CloseOnExec(s) syscall.CloseOnExec(s)
syscall.ForkLock.RUnlock() syscall.ForkLock.RUnlock()
// Associate our new socket with IOCP. // Associate our new socket with IOCP.
onceStartServer.Do(startServer) onceStartServer.Do(startServer)
if _, e = syscall.CreateIoCompletionPort(s, resultsrv.iocp, 0, 0); e != 0 { if _, e = syscall.CreateIoCompletionPort(s, resultsrv.iocp, 0, 0); e != nil {
return nil, &OpError{"CreateIoCompletionPort", fd.net, fd.laddr, os.Errno(e)} return nil, &OpError{"CreateIoCompletionPort", fd.net, fd.laddr, e}
} }
// Submit accept request. // Submit accept request.
@ -531,9 +527,9 @@ func (fd *netFD) accept(toAddr func(syscall.Sockaddr) Addr) (nfd *netFD, err err
// Inherit properties of the listening socket. // Inherit properties of the listening socket.
e = syscall.Setsockopt(s, syscall.SOL_SOCKET, syscall.SO_UPDATE_ACCEPT_CONTEXT, (*byte)(unsafe.Pointer(&fd.sysfd)), int32(unsafe.Sizeof(fd.sysfd))) e = syscall.Setsockopt(s, syscall.SOL_SOCKET, syscall.SO_UPDATE_ACCEPT_CONTEXT, (*byte)(unsafe.Pointer(&fd.sysfd)), int32(unsafe.Sizeof(fd.sysfd)))
if e != 0 { if e != nil {
closesocket(s) closesocket(s)
return nil, err return nil, e
} }
// Get local and peer addr out of AcceptEx buffer. // Get local and peer addr out of AcceptEx buffer.

View File

@ -13,12 +13,12 @@ import (
func newFileFD(f *os.File) (nfd *netFD, err error) { func newFileFD(f *os.File) (nfd *netFD, err error) {
fd, errno := syscall.Dup(f.Fd()) fd, errno := syscall.Dup(f.Fd())
if errno != 0 { if errno != nil {
return nil, os.NewSyscallError("dup", errno) return nil, os.NewSyscallError("dup", errno)
} }
proto, errno := syscall.GetsockoptInt(fd, syscall.SOL_SOCKET, syscall.SO_TYPE) proto, errno := syscall.GetsockoptInt(fd, syscall.SOL_SOCKET, syscall.SO_TYPE)
if errno != 0 { if errno != nil {
return nil, os.NewSyscallError("getsockopt", errno) return nil, os.NewSyscallError("getsockopt", errno)
} }

View File

@ -7,8 +7,8 @@
package net package net
import ( import (
"os"
"sync" "sync"
"time"
) )
const cacheMaxAge = int64(300) // 5 minutes. const cacheMaxAge = int64(300) // 5 minutes.
@ -26,7 +26,7 @@ var hosts struct {
} }
func readHosts() { func readHosts() {
now, _, _ := os.Time() now := time.Seconds()
hp := hostsPath hp := hostsPath
if len(hosts.byName) == 0 || hosts.time+cacheMaxAge <= now || hosts.path != hp { if len(hosts.byName) == 0 || hosts.time+cacheMaxAge <= now || hosts.path != hp {
hs := make(map[string][]string) hs := make(map[string][]string)
@ -51,7 +51,7 @@ func readHosts() {
} }
} }
// Update the data cache. // Update the data cache.
hosts.time, _, _ = os.Time() hosts.time = time.Seconds()
hosts.path = hp hosts.path = hp
hosts.byName = hs hosts.byName = hs
hosts.byAddr = is hosts.byAddr = is

View File

@ -363,7 +363,7 @@ func TestCopyError(t *testing.T) {
} }
conn.Close() conn.Close()
if tries := 0; childRunning() { tries := 0
for tries < 15 && childRunning() { for tries < 15 && childRunning() {
time.Sleep(50e6 * int64(tries)) time.Sleep(50e6 * int64(tries))
tries++ tries++
@ -372,7 +372,6 @@ func TestCopyError(t *testing.T) {
t.Fatalf("post-conn.Close, expected child to be gone") t.Fatalf("post-conn.Close, expected child to be gone")
} }
} }
}
func TestDirUnix(t *testing.T) { func TestDirUnix(t *testing.T) {
if skipTest(t) || runtime.GOOS == "windows" { if skipTest(t) || runtime.GOOS == "windows" {

View File

@ -2,20 +2,137 @@
// Use of this source code is governed by a BSD-style // Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file. // license that can be found in the LICENSE file.
// The wire protocol for HTTP's "chunked" Transfer-Encoding.
// This code is duplicated in httputil/chunked.go.
// Please make any changes in both files.
package http package http
import ( import (
"bufio" "bufio"
"bytes"
"errors"
"io" "io"
"strconv" "strconv"
) )
const maxLineLength = 4096 // assumed <= bufio.defaultBufSize
var ErrLineTooLong = errors.New("header line too long")
// newChunkedReader returns a new chunkedReader that translates the data read from r
// out of HTTP "chunked" format before returning it.
// The chunkedReader returns io.EOF when the final 0-length chunk is read.
//
// newChunkedReader is not needed by normal applications. The http package
// automatically decodes chunking when reading response bodies.
func newChunkedReader(r io.Reader) io.Reader {
br, ok := r.(*bufio.Reader)
if !ok {
br = bufio.NewReader(r)
}
return &chunkedReader{r: br}
}
type chunkedReader struct {
r *bufio.Reader
n uint64 // unread bytes in chunk
err error
}
func (cr *chunkedReader) beginChunk() {
// chunk-size CRLF
var line string
line, cr.err = readLine(cr.r)
if cr.err != nil {
return
}
cr.n, cr.err = strconv.Btoui64(line, 16)
if cr.err != nil {
return
}
if cr.n == 0 {
cr.err = io.EOF
}
}
func (cr *chunkedReader) Read(b []uint8) (n int, err error) {
if cr.err != nil {
return 0, cr.err
}
if cr.n == 0 {
cr.beginChunk()
if cr.err != nil {
return 0, cr.err
}
}
if uint64(len(b)) > cr.n {
b = b[0:cr.n]
}
n, cr.err = cr.r.Read(b)
cr.n -= uint64(n)
if cr.n == 0 && cr.err == nil {
// end of chunk (CRLF)
b := make([]byte, 2)
if _, cr.err = io.ReadFull(cr.r, b); cr.err == nil {
if b[0] != '\r' || b[1] != '\n' {
cr.err = errors.New("malformed chunked encoding")
}
}
}
return n, cr.err
}
// Read a line of bytes (up to \n) from b.
// Give up if the line exceeds maxLineLength.
// The returned bytes are a pointer into storage in
// the bufio, so they are only valid until the next bufio read.
func readLineBytes(b *bufio.Reader) (p []byte, err error) {
if p, err = b.ReadSlice('\n'); err != nil {
// We always know when EOF is coming.
// If the caller asked for a line, there should be a line.
if err == io.EOF {
err = io.ErrUnexpectedEOF
} else if err == bufio.ErrBufferFull {
err = ErrLineTooLong
}
return nil, err
}
if len(p) >= maxLineLength {
return nil, ErrLineTooLong
}
// Chop off trailing white space.
p = bytes.TrimRight(p, " \r\t\n")
return p, nil
}
// readLineBytes, but convert the bytes into a string.
func readLine(b *bufio.Reader) (s string, err error) {
p, e := readLineBytes(b)
if e != nil {
return "", e
}
return string(p), nil
}
// newChunkedWriter returns a new chunkedWriter that translates writes into HTTP
// "chunked" format before writing them to w. Closing the returned chunkedWriter
// sends the final 0-length chunk that marks the end of the stream.
//
// newChunkedWriter is not needed by normal applications. The http
// package adds chunking automatically if handlers don't set a
// Content-Length header. Using newChunkedWriter inside a handler
// would result in double chunking or chunking with a Content-Length
// length, both of which are wrong.
func newChunkedWriter(w io.Writer) io.WriteCloser { func newChunkedWriter(w io.Writer) io.WriteCloser {
return &chunkedWriter{w} return &chunkedWriter{w}
} }
// Writing to ChunkedWriter translates to writing in HTTP chunked Transfer // Writing to chunkedWriter translates to writing in HTTP chunked Transfer
// Encoding wire format to the underlying Wire writer. // Encoding wire format to the underlying Wire chunkedWriter.
type chunkedWriter struct { type chunkedWriter struct {
Wire io.Writer Wire io.Writer
} }
@ -51,7 +168,3 @@ func (cw *chunkedWriter) Close() error {
_, err := io.WriteString(cw.Wire, "0\r\n") _, err := io.WriteString(cw.Wire, "0\r\n")
return err return err
} }
func newChunkedReader(r *bufio.Reader) io.Reader {
return &chunkedReader{r: r}
}

View File

@ -0,0 +1,39 @@
// Copyright 2011 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
// This code is duplicated in httputil/chunked_test.go.
// Please make any changes in both files.
package http
import (
"bytes"
"io/ioutil"
"testing"
)
func TestChunk(t *testing.T) {
var b bytes.Buffer
w := newChunkedWriter(&b)
const chunk1 = "hello, "
const chunk2 = "world! 0123456789abcdef"
w.Write([]byte(chunk1))
w.Write([]byte(chunk2))
w.Close()
if g, e := b.String(), "7\r\nhello, \r\n17\r\nworld! 0123456789abcdef\r\n0\r\n"; g != e {
t.Fatalf("chunk writer wrote %q; want %q", g, e)
}
r := newChunkedReader(&b)
data, err := ioutil.ReadAll(r)
if err != nil {
t.Logf(`data: "%s"`, data)
t.Fatalf("ReadAll from reader: %v", err)
}
if g, e := string(data), chunk1+chunk2; g != e {
t.Errorf("chunk reader read %q; want %q", g, e)
}
}

View File

@ -26,6 +26,31 @@ var robotsTxtHandler = HandlerFunc(func(w ResponseWriter, r *Request) {
fmt.Fprintf(w, "User-agent: go\nDisallow: /something/") fmt.Fprintf(w, "User-agent: go\nDisallow: /something/")
}) })
// pedanticReadAll works like ioutil.ReadAll but additionally
// verifies that r obeys the documented io.Reader contract.
func pedanticReadAll(r io.Reader) (b []byte, err error) {
var bufa [64]byte
buf := bufa[:]
for {
n, err := r.Read(buf)
if n == 0 && err == nil {
return nil, fmt.Errorf("Read: n=0 with err=nil")
}
b = append(b, buf[:n]...)
if err == io.EOF {
n, err := r.Read(buf)
if n != 0 || err != io.EOF {
return nil, fmt.Errorf("Read: n=%d err=%#v after EOF", n, err)
}
return b, nil
}
if err != nil {
return b, err
}
}
panic("unreachable")
}
func TestClient(t *testing.T) { func TestClient(t *testing.T) {
ts := httptest.NewServer(robotsTxtHandler) ts := httptest.NewServer(robotsTxtHandler)
defer ts.Close() defer ts.Close()
@ -33,7 +58,7 @@ func TestClient(t *testing.T) {
r, err := Get(ts.URL) r, err := Get(ts.URL)
var b []byte var b []byte
if err == nil { if err == nil {
b, err = ioutil.ReadAll(r.Body) b, err = pedanticReadAll(r.Body)
r.Body.Close() r.Body.Close()
} }
if err != nil { if err != nil {

View File

@ -7,6 +7,7 @@ package fcgi
// This file implements FastCGI from the perspective of a child process. // This file implements FastCGI from the perspective of a child process.
import ( import (
"errors"
"fmt" "fmt"
"io" "io"
"net" "net"
@ -125,49 +126,61 @@ func (r *response) Close() error {
type child struct { type child struct {
conn *conn conn *conn
handler http.Handler handler http.Handler
requests map[uint16]*request // keyed by request ID
} }
func newChild(rwc net.Conn, handler http.Handler) *child { func newChild(rwc io.ReadWriteCloser, handler http.Handler) *child {
return &child{newConn(rwc), handler} return &child{
conn: newConn(rwc),
handler: handler,
requests: make(map[uint16]*request),
}
} }
func (c *child) serve() { func (c *child) serve() {
requests := map[uint16]*request{}
defer c.conn.Close() defer c.conn.Close()
var rec record var rec record
var br beginRequest
for { for {
if err := rec.read(c.conn.rwc); err != nil { if err := rec.read(c.conn.rwc); err != nil {
return return
} }
if err := c.handleRecord(&rec); err != nil {
return
}
}
}
req, ok := requests[rec.h.Id] var errCloseConn = errors.New("fcgi: connection should be closed")
func (c *child) handleRecord(rec *record) error {
req, ok := c.requests[rec.h.Id]
if !ok && rec.h.Type != typeBeginRequest && rec.h.Type != typeGetValues { if !ok && rec.h.Type != typeBeginRequest && rec.h.Type != typeGetValues {
// The spec says to ignore unknown request IDs. // The spec says to ignore unknown request IDs.
continue return nil
} }
if ok && rec.h.Type == typeBeginRequest { if ok && rec.h.Type == typeBeginRequest {
// The server is trying to begin a request with the same ID // The server is trying to begin a request with the same ID
// as an in-progress request. This is an error. // as an in-progress request. This is an error.
return return errors.New("fcgi: received ID that is already in-flight")
} }
switch rec.h.Type { switch rec.h.Type {
case typeBeginRequest: case typeBeginRequest:
var br beginRequest
if err := br.read(rec.content()); err != nil { if err := br.read(rec.content()); err != nil {
return return err
} }
if br.role != roleResponder { if br.role != roleResponder {
c.conn.writeEndRequest(rec.h.Id, 0, statusUnknownRole) c.conn.writeEndRequest(rec.h.Id, 0, statusUnknownRole)
break return nil
} }
requests[rec.h.Id] = newRequest(rec.h.Id, br.flags) c.requests[rec.h.Id] = newRequest(rec.h.Id, br.flags)
case typeParams: case typeParams:
// NOTE(eds): Technically a key-value pair can straddle the boundary // NOTE(eds): Technically a key-value pair can straddle the boundary
// between two packets. We buffer until we've received all parameters. // between two packets. We buffer until we've received all parameters.
if len(rec.content()) > 0 { if len(rec.content()) > 0 {
req.rawParams = append(req.rawParams, rec.content()...) req.rawParams = append(req.rawParams, rec.content()...)
break return nil
} }
req.parseParams() req.parseParams()
case typeStdin: case typeStdin:
@ -190,22 +203,22 @@ func (c *child) serve() {
} }
case typeGetValues: case typeGetValues:
values := map[string]string{"FCGI_MPXS_CONNS": "1"} values := map[string]string{"FCGI_MPXS_CONNS": "1"}
c.conn.writePairs(0, typeGetValuesResult, values) c.conn.writePairs(typeGetValuesResult, 0, values)
case typeData: case typeData:
// If the filter role is implemented, read the data stream here. // If the filter role is implemented, read the data stream here.
case typeAbortRequest: case typeAbortRequest:
delete(requests, rec.h.Id) delete(c.requests, rec.h.Id)
c.conn.writeEndRequest(rec.h.Id, 0, statusRequestComplete) c.conn.writeEndRequest(rec.h.Id, 0, statusRequestComplete)
if !req.keepConn { if !req.keepConn {
// connection will close upon return // connection will close upon return
return return errCloseConn
} }
default: default:
b := make([]byte, 8) b := make([]byte, 8)
b[0] = rec.h.Type b[0] = byte(rec.h.Type)
c.conn.writeRecord(typeUnknownType, 0, b) c.conn.writeRecord(typeUnknownType, 0, b)
} }
} return nil
} }
func (c *child) serveRequest(req *request, body io.ReadCloser) { func (c *child) serveRequest(req *request, body io.ReadCloser) {

View File

@ -19,19 +19,22 @@ import (
"sync" "sync"
) )
// recType is a record type, as defined by
// http://www.fastcgi.com/devkit/doc/fcgi-spec.html#S8
type recType uint8
const ( const (
// Packet Types typeBeginRequest recType = 1
typeBeginRequest = iota + 1 typeAbortRequest recType = 2
typeAbortRequest typeEndRequest recType = 3
typeEndRequest typeParams recType = 4
typeParams typeStdin recType = 5
typeStdin typeStdout recType = 6
typeStdout typeStderr recType = 7
typeStderr typeData recType = 8
typeData typeGetValues recType = 9
typeGetValues typeGetValuesResult recType = 10
typeGetValuesResult typeUnknownType recType = 11
typeUnknownType
) )
// keep the connection between web-server and responder open after request // keep the connection between web-server and responder open after request
@ -59,7 +62,7 @@ const headerLen = 8
type header struct { type header struct {
Version uint8 Version uint8
Type uint8 Type recType
Id uint16 Id uint16
ContentLength uint16 ContentLength uint16
PaddingLength uint8 PaddingLength uint8
@ -85,7 +88,7 @@ func (br *beginRequest) read(content []byte) error {
// not synchronized because we don't care what the contents are // not synchronized because we don't care what the contents are
var pad [maxPad]byte var pad [maxPad]byte
func (h *header) init(recType uint8, reqId uint16, contentLength int) { func (h *header) init(recType recType, reqId uint16, contentLength int) {
h.Version = 1 h.Version = 1
h.Type = recType h.Type = recType
h.Id = reqId h.Id = reqId
@ -137,7 +140,7 @@ func (r *record) content() []byte {
} }
// writeRecord writes and sends a single record. // writeRecord writes and sends a single record.
func (c *conn) writeRecord(recType uint8, reqId uint16, b []byte) error { func (c *conn) writeRecord(recType recType, reqId uint16, b []byte) error {
c.mutex.Lock() c.mutex.Lock()
defer c.mutex.Unlock() defer c.mutex.Unlock()
c.buf.Reset() c.buf.Reset()
@ -167,12 +170,12 @@ func (c *conn) writeEndRequest(reqId uint16, appStatus int, protocolStatus uint8
return c.writeRecord(typeEndRequest, reqId, b) return c.writeRecord(typeEndRequest, reqId, b)
} }
func (c *conn) writePairs(recType uint8, reqId uint16, pairs map[string]string) error { func (c *conn) writePairs(recType recType, reqId uint16, pairs map[string]string) error {
w := newWriter(c, recType, reqId) w := newWriter(c, recType, reqId)
b := make([]byte, 8) b := make([]byte, 8)
for k, v := range pairs { for k, v := range pairs {
n := encodeSize(b, uint32(len(k))) n := encodeSize(b, uint32(len(k)))
n += encodeSize(b[n:], uint32(len(k))) n += encodeSize(b[n:], uint32(len(v)))
if _, err := w.Write(b[:n]); err != nil { if _, err := w.Write(b[:n]); err != nil {
return err return err
} }
@ -235,7 +238,7 @@ func (w *bufWriter) Close() error {
return w.closer.Close() return w.closer.Close()
} }
func newWriter(c *conn, recType uint8, reqId uint16) *bufWriter { func newWriter(c *conn, recType recType, reqId uint16) *bufWriter {
s := &streamWriter{c: c, recType: recType, reqId: reqId} s := &streamWriter{c: c, recType: recType, reqId: reqId}
w, _ := bufio.NewWriterSize(s, maxWrite) w, _ := bufio.NewWriterSize(s, maxWrite)
return &bufWriter{s, w} return &bufWriter{s, w}
@ -245,7 +248,7 @@ func newWriter(c *conn, recType uint8, reqId uint16) *bufWriter {
// It only writes maxWrite bytes at a time. // It only writes maxWrite bytes at a time.
type streamWriter struct { type streamWriter struct {
c *conn c *conn
recType uint8 recType recType
reqId uint16 reqId uint16
} }

View File

@ -6,6 +6,7 @@ package fcgi
import ( import (
"bytes" "bytes"
"errors"
"io" "io"
"testing" "testing"
) )
@ -40,25 +41,25 @@ func TestSize(t *testing.T) {
var streamTests = []struct { var streamTests = []struct {
desc string desc string
recType uint8 recType recType
reqId uint16 reqId uint16
content []byte content []byte
raw []byte raw []byte
}{ }{
{"single record", typeStdout, 1, nil, {"single record", typeStdout, 1, nil,
[]byte{1, typeStdout, 0, 1, 0, 0, 0, 0}, []byte{1, byte(typeStdout), 0, 1, 0, 0, 0, 0},
}, },
// this data will have to be split into two records // this data will have to be split into two records
{"two records", typeStdin, 300, make([]byte, 66000), {"two records", typeStdin, 300, make([]byte, 66000),
bytes.Join([][]byte{ bytes.Join([][]byte{
// header for the first record // header for the first record
{1, typeStdin, 0x01, 0x2C, 0xFF, 0xFF, 1, 0}, {1, byte(typeStdin), 0x01, 0x2C, 0xFF, 0xFF, 1, 0},
make([]byte, 65536), make([]byte, 65536),
// header for the second // header for the second
{1, typeStdin, 0x01, 0x2C, 0x01, 0xD1, 7, 0}, {1, byte(typeStdin), 0x01, 0x2C, 0x01, 0xD1, 7, 0},
make([]byte, 472), make([]byte, 472),
// header for the empty record // header for the empty record
{1, typeStdin, 0x01, 0x2C, 0, 0, 0, 0}, {1, byte(typeStdin), 0x01, 0x2C, 0, 0, 0, 0},
}, },
nil), nil),
}, },
@ -111,3 +112,39 @@ outer:
} }
} }
} }
type writeOnlyConn struct {
buf []byte
}
func (c *writeOnlyConn) Write(p []byte) (int, error) {
c.buf = append(c.buf, p...)
return len(p), nil
}
func (c *writeOnlyConn) Read(p []byte) (int, error) {
return 0, errors.New("conn is write-only")
}
func (c *writeOnlyConn) Close() error {
return nil
}
func TestGetValues(t *testing.T) {
var rec record
rec.h.Type = typeGetValues
wc := new(writeOnlyConn)
c := newChild(wc, nil)
err := c.handleRecord(&rec)
if err != nil {
t.Fatalf("handleRecord: %v", err)
}
const want = "\x01\n\x00\x00\x00\x12\x06\x00" +
"\x0f\x01FCGI_MPXS_CONNS1" +
"\x00\x00\x00\x00\x00\x00\x01\n\x00\x00\x00\x00\x00\x00"
if got := string(wc.buf); got != want {
t.Errorf(" got: %q\nwant: %q\n", got, want)
}
}

View File

@ -22,13 +22,19 @@ import (
// A Dir implements http.FileSystem using the native file // A Dir implements http.FileSystem using the native file
// system restricted to a specific directory tree. // system restricted to a specific directory tree.
//
// An empty Dir is treated as ".".
type Dir string type Dir string
func (d Dir) Open(name string) (File, error) { func (d Dir) Open(name string) (File, error) {
if filepath.Separator != '/' && strings.IndexRune(name, filepath.Separator) >= 0 { if filepath.Separator != '/' && strings.IndexRune(name, filepath.Separator) >= 0 {
return nil, errors.New("http: invalid character in file path") return nil, errors.New("http: invalid character in file path")
} }
f, err := os.Open(filepath.Join(string(d), filepath.FromSlash(path.Clean("/"+name)))) dir := string(d)
if dir == "" {
dir = "."
}
f, err := os.Open(filepath.Join(dir, filepath.FromSlash(path.Clean("/"+name))))
if err != nil { if err != nil {
return nil, err return nil, err
} }

View File

@ -208,6 +208,20 @@ func TestDirJoin(t *testing.T) {
test(Dir("/etc/hosts"), "../") test(Dir("/etc/hosts"), "../")
} }
func TestEmptyDirOpenCWD(t *testing.T) {
test := func(d Dir) {
name := "fs_test.go"
f, err := d.Open(name)
if err != nil {
t.Fatalf("open of %s: %v", name, err)
}
defer f.Close()
}
test(Dir(""))
test(Dir("."))
test(Dir("./"))
}
func TestServeFileContentType(t *testing.T) { func TestServeFileContentType(t *testing.T) {
const ctype = "icecream/chocolate" const ctype = "icecream/chocolate"
override := false override := false
@ -247,6 +261,20 @@ func TestServeFileMimeType(t *testing.T) {
} }
} }
func TestServeFileFromCWD(t *testing.T) {
ts := httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) {
ServeFile(w, r, "fs_test.go")
}))
defer ts.Close()
r, err := Get(ts.URL)
if err != nil {
t.Fatal(err)
}
if r.StatusCode != 200 {
t.Fatalf("expected 200 OK, got %s", r.Status)
}
}
func TestServeFileWithContentEncoding(t *testing.T) { func TestServeFileWithContentEncoding(t *testing.T) {
ts := httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) { ts := httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) {
w.Header().Set("Content-Encoding", "foo") w.Header().Set("Content-Encoding", "foo")

View File

@ -2,18 +2,126 @@
// Use of this source code is governed by a BSD-style // Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file. // license that can be found in the LICENSE file.
// The wire protocol for HTTP's "chunked" Transfer-Encoding.
// This code is a duplicate of ../chunked.go with these edits:
// s/newChunked/NewChunked/g
// s/package http/package httputil/
// Please make any changes in both files.
package httputil package httputil
import ( import (
"bufio" "bufio"
"bytes"
"errors"
"io" "io"
"net/http"
"strconv" "strconv"
"strings"
) )
// NewChunkedWriter returns a new writer that translates writes into HTTP const maxLineLength = 4096 // assumed <= bufio.defaultBufSize
// "chunked" format before writing them to w. Closing the returned writer
var ErrLineTooLong = errors.New("header line too long")
// NewChunkedReader returns a new chunkedReader that translates the data read from r
// out of HTTP "chunked" format before returning it.
// The chunkedReader returns io.EOF when the final 0-length chunk is read.
//
// NewChunkedReader is not needed by normal applications. The http package
// automatically decodes chunking when reading response bodies.
func NewChunkedReader(r io.Reader) io.Reader {
br, ok := r.(*bufio.Reader)
if !ok {
br = bufio.NewReader(r)
}
return &chunkedReader{r: br}
}
type chunkedReader struct {
r *bufio.Reader
n uint64 // unread bytes in chunk
err error
}
func (cr *chunkedReader) beginChunk() {
// chunk-size CRLF
var line string
line, cr.err = readLine(cr.r)
if cr.err != nil {
return
}
cr.n, cr.err = strconv.Btoui64(line, 16)
if cr.err != nil {
return
}
if cr.n == 0 {
cr.err = io.EOF
}
}
func (cr *chunkedReader) Read(b []uint8) (n int, err error) {
if cr.err != nil {
return 0, cr.err
}
if cr.n == 0 {
cr.beginChunk()
if cr.err != nil {
return 0, cr.err
}
}
if uint64(len(b)) > cr.n {
b = b[0:cr.n]
}
n, cr.err = cr.r.Read(b)
cr.n -= uint64(n)
if cr.n == 0 && cr.err == nil {
// end of chunk (CRLF)
b := make([]byte, 2)
if _, cr.err = io.ReadFull(cr.r, b); cr.err == nil {
if b[0] != '\r' || b[1] != '\n' {
cr.err = errors.New("malformed chunked encoding")
}
}
}
return n, cr.err
}
// Read a line of bytes (up to \n) from b.
// Give up if the line exceeds maxLineLength.
// The returned bytes are a pointer into storage in
// the bufio, so they are only valid until the next bufio read.
func readLineBytes(b *bufio.Reader) (p []byte, err error) {
if p, err = b.ReadSlice('\n'); err != nil {
// We always know when EOF is coming.
// If the caller asked for a line, there should be a line.
if err == io.EOF {
err = io.ErrUnexpectedEOF
} else if err == bufio.ErrBufferFull {
err = ErrLineTooLong
}
return nil, err
}
if len(p) >= maxLineLength {
return nil, ErrLineTooLong
}
// Chop off trailing white space.
p = bytes.TrimRight(p, " \r\t\n")
return p, nil
}
// readLineBytes, but convert the bytes into a string.
func readLine(b *bufio.Reader) (s string, err error) {
p, e := readLineBytes(b)
if e != nil {
return "", e
}
return string(p), nil
}
// NewChunkedWriter returns a new chunkedWriter that translates writes into HTTP
// "chunked" format before writing them to w. Closing the returned chunkedWriter
// sends the final 0-length chunk that marks the end of the stream. // sends the final 0-length chunk that marks the end of the stream.
// //
// NewChunkedWriter is not needed by normal applications. The http // NewChunkedWriter is not needed by normal applications. The http
@ -25,8 +133,8 @@ func NewChunkedWriter(w io.Writer) io.WriteCloser {
return &chunkedWriter{w} return &chunkedWriter{w}
} }
// Writing to ChunkedWriter translates to writing in HTTP chunked Transfer // Writing to chunkedWriter translates to writing in HTTP chunked Transfer
// Encoding wire format to the underlying Wire writer. // Encoding wire format to the underlying Wire chunkedWriter.
type chunkedWriter struct { type chunkedWriter struct {
Wire io.Writer Wire io.Writer
} }
@ -62,23 +170,3 @@ func (cw *chunkedWriter) Close() error {
_, err := io.WriteString(cw.Wire, "0\r\n") _, err := io.WriteString(cw.Wire, "0\r\n")
return err return err
} }
// NewChunkedReader returns a new reader that translates the data read from r
// out of HTTP "chunked" format before returning it.
// The reader returns io.EOF when the final 0-length chunk is read.
//
// NewChunkedReader is not needed by normal applications. The http package
// automatically decodes chunking when reading response bodies.
func NewChunkedReader(r io.Reader) io.Reader {
// This is a bit of a hack so we don't have to copy chunkedReader into
// httputil. It's a bit more complex than chunkedWriter, which is copied
// above.
req, err := http.ReadRequest(bufio.NewReader(io.MultiReader(
strings.NewReader("POST / HTTP/1.1\r\nTransfer-Encoding: chunked\r\n\r\n"),
r,
strings.NewReader("\r\n"))))
if err != nil {
panic("bad fake request: " + err.Error())
}
return req.Body
}

View File

@ -2,6 +2,11 @@
// Use of this source code is governed by a BSD-style // Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file. // license that can be found in the LICENSE file.
// This code is a duplicate of ../chunked_test.go with these edits:
// s/newChunked/NewChunked/g
// s/package http/package httputil/
// Please make any changes in both files.
package httputil package httputil
import ( import (
@ -27,7 +32,8 @@ func TestChunk(t *testing.T) {
r := NewChunkedReader(&b) r := NewChunkedReader(&b)
data, err := ioutil.ReadAll(r) data, err := ioutil.ReadAll(r)
if err != nil { if err != nil {
t.Fatalf("ReadAll from NewChunkedReader: %v", err) t.Logf(`data: "%s"`, data)
t.Fatalf("ReadAll from reader: %v", err)
} }
if g, e := string(data), chunk1+chunk2; g != e { if g, e := string(data), chunk1+chunk2; g != e {
t.Errorf("chunk reader read %q; want %q", g, e) t.Errorf("chunk reader read %q; want %q", g, e)

View File

@ -22,6 +22,10 @@ var (
ErrPipeline = &http.ProtocolError{"pipeline error"} ErrPipeline = &http.ProtocolError{"pipeline error"}
) )
// This is an API usage error - the local side is closed.
// ErrPersistEOF (above) reports that the remote side is closed.
var errClosed = errors.New("i/o operation on closed connection")
// A ServerConn reads requests and sends responses over an underlying // A ServerConn reads requests and sends responses over an underlying
// connection, until the HTTP keepalive logic commands an end. ServerConn // connection, until the HTTP keepalive logic commands an end. ServerConn
// also allows hijacking the underlying connection by calling Hijack // also allows hijacking the underlying connection by calling Hijack
@ -108,7 +112,7 @@ func (sc *ServerConn) Read() (req *http.Request, err error) {
} }
if sc.r == nil { // connection closed by user in the meantime if sc.r == nil { // connection closed by user in the meantime
defer sc.lk.Unlock() defer sc.lk.Unlock()
return nil, os.EBADF return nil, errClosed
} }
r := sc.r r := sc.r
lastbody := sc.lastbody lastbody := sc.lastbody
@ -313,7 +317,7 @@ func (cc *ClientConn) Write(req *http.Request) (err error) {
} }
if cc.c == nil { // connection closed by user in the meantime if cc.c == nil { // connection closed by user in the meantime
defer cc.lk.Unlock() defer cc.lk.Unlock()
return os.EBADF return errClosed
} }
c := cc.c c := cc.c
if req.Close { if req.Close {
@ -369,7 +373,7 @@ func (cc *ClientConn) Read(req *http.Request) (resp *http.Response, err error) {
} }
if cc.r == nil { // connection closed by user in the meantime if cc.r == nil { // connection closed by user in the meantime
defer cc.lk.Unlock() defer cc.lk.Unlock()
return nil, os.EBADF return nil, errClosed
} }
r := cc.r r := cc.r
lastbody := cc.lastbody lastbody := cc.lastbody

View File

@ -70,7 +70,6 @@ var reqTests = []reqTest{
Close: false, Close: false,
ContentLength: 7, ContentLength: 7,
Host: "www.techcrunch.com", Host: "www.techcrunch.com",
Form: url.Values{},
}, },
"abcdef\n", "abcdef\n",
@ -94,10 +93,10 @@ var reqTests = []reqTest{
Proto: "HTTP/1.1", Proto: "HTTP/1.1",
ProtoMajor: 1, ProtoMajor: 1,
ProtoMinor: 1, ProtoMinor: 1,
Header: Header{},
Close: false, Close: false,
ContentLength: 0, ContentLength: 0,
Host: "foo.com", Host: "foo.com",
Form: url.Values{},
}, },
noBody, noBody,
@ -131,7 +130,6 @@ var reqTests = []reqTest{
Close: false, Close: false,
ContentLength: 0, ContentLength: 0,
Host: "test", Host: "test",
Form: url.Values{},
}, },
noBody, noBody,
@ -180,9 +178,9 @@ var reqTests = []reqTest{
Proto: "HTTP/1.1", Proto: "HTTP/1.1",
ProtoMajor: 1, ProtoMajor: 1,
ProtoMinor: 1, ProtoMinor: 1,
Header: Header{},
ContentLength: -1, ContentLength: -1,
Host: "foo.com", Host: "foo.com",
Form: url.Values{},
}, },
"foobar", "foobar",

View File

@ -19,12 +19,10 @@ import (
"mime/multipart" "mime/multipart"
"net/textproto" "net/textproto"
"net/url" "net/url"
"strconv"
"strings" "strings"
) )
const ( const (
maxLineLength = 4096 // assumed <= bufio.defaultBufSize
maxValueLength = 4096 maxValueLength = 4096
maxHeaderLines = 1024 maxHeaderLines = 1024
chunkSize = 4 << 10 // 4 KB chunks chunkSize = 4 << 10 // 4 KB chunks
@ -43,7 +41,6 @@ type ProtocolError struct {
func (err *ProtocolError) Error() string { return err.ErrorString } func (err *ProtocolError) Error() string { return err.ErrorString }
var ( var (
ErrLineTooLong = &ProtocolError{"header line too long"}
ErrHeaderTooLong = &ProtocolError{"header too long"} ErrHeaderTooLong = &ProtocolError{"header too long"}
ErrShortBody = &ProtocolError{"entity body too short"} ErrShortBody = &ProtocolError{"entity body too short"}
ErrNotSupported = &ProtocolError{"feature not supported"} ErrNotSupported = &ProtocolError{"feature not supported"}
@ -375,44 +372,6 @@ func (req *Request) write(w io.Writer, usingProxy bool, extraHeaders Header) err
return nil return nil
} }
// Read a line of bytes (up to \n) from b.
// Give up if the line exceeds maxLineLength.
// The returned bytes are a pointer into storage in
// the bufio, so they are only valid until the next bufio read.
func readLineBytes(b *bufio.Reader) (p []byte, err error) {
if p, err = b.ReadSlice('\n'); err != nil {
// We always know when EOF is coming.
// If the caller asked for a line, there should be a line.
if err == io.EOF {
err = io.ErrUnexpectedEOF
} else if err == bufio.ErrBufferFull {
err = ErrLineTooLong
}
return nil, err
}
if len(p) >= maxLineLength {
return nil, ErrLineTooLong
}
// Chop off trailing white space.
var i int
for i = len(p); i > 0; i-- {
if c := p[i-1]; c != ' ' && c != '\r' && c != '\t' && c != '\n' {
break
}
}
return p[0:i], nil
}
// readLineBytes, but convert the bytes into a string.
func readLine(b *bufio.Reader) (s string, err error) {
p, e := readLineBytes(b)
if e != nil {
return "", e
}
return string(p), nil
}
// Convert decimal at s[i:len(s)] to integer, // Convert decimal at s[i:len(s)] to integer,
// returning value, string position where the digits stopped, // returning value, string position where the digits stopped,
// and whether there was a valid number (digits, not too big). // and whether there was a valid number (digits, not too big).
@ -448,55 +407,6 @@ func ParseHTTPVersion(vers string) (major, minor int, ok bool) {
return major, minor, true return major, minor, true
} }
type chunkedReader struct {
r *bufio.Reader
n uint64 // unread bytes in chunk
err error
}
func (cr *chunkedReader) beginChunk() {
// chunk-size CRLF
var line string
line, cr.err = readLine(cr.r)
if cr.err != nil {
return
}
cr.n, cr.err = strconv.Btoui64(line, 16)
if cr.err != nil {
return
}
if cr.n == 0 {
cr.err = io.EOF
}
}
func (cr *chunkedReader) Read(b []uint8) (n int, err error) {
if cr.err != nil {
return 0, cr.err
}
if cr.n == 0 {
cr.beginChunk()
if cr.err != nil {
return 0, cr.err
}
}
if uint64(len(b)) > cr.n {
b = b[0:cr.n]
}
n, cr.err = cr.r.Read(b)
cr.n -= uint64(n)
if cr.n == 0 && cr.err == nil {
// end of chunk (CRLF)
b := make([]byte, 2)
if _, cr.err = io.ReadFull(cr.r, b); cr.err == nil {
if b[0] != '\r' || b[1] != '\n' {
cr.err = errors.New("malformed chunked encoding")
}
}
}
return n, cr.err
}
// NewRequest returns a new Request given a method, URL, and optional body. // NewRequest returns a new Request given a method, URL, and optional body.
func NewRequest(method, urlStr string, body io.Reader) (*Request, error) { func NewRequest(method, urlStr string, body io.Reader) (*Request, error) {
u, err := url.Parse(urlStr) u, err := url.Parse(urlStr)

View File

@ -65,6 +65,7 @@ var respTests = []respTest{
Proto: "HTTP/1.1", Proto: "HTTP/1.1",
ProtoMajor: 1, ProtoMajor: 1,
ProtoMinor: 1, ProtoMinor: 1,
Header: Header{},
Request: dummyReq("GET"), Request: dummyReq("GET"),
Close: true, Close: true,
ContentLength: -1, ContentLength: -1,
@ -85,6 +86,7 @@ var respTests = []respTest{
Proto: "HTTP/1.1", Proto: "HTTP/1.1",
ProtoMajor: 1, ProtoMajor: 1,
ProtoMinor: 1, ProtoMinor: 1,
Header: Header{},
Request: dummyReq("GET"), Request: dummyReq("GET"),
Close: false, Close: false,
ContentLength: 0, ContentLength: 0,
@ -315,7 +317,7 @@ func TestReadResponseCloseInMiddle(t *testing.T) {
} }
var wr io.Writer = &buf var wr io.Writer = &buf
if test.chunked { if test.chunked {
wr = &chunkedWriter{wr} wr = newChunkedWriter(wr)
} }
if test.compressed { if test.compressed {
buf.WriteString("Content-Encoding: gzip\r\n") buf.WriteString("Content-Encoding: gzip\r\n")

View File

@ -1077,6 +1077,31 @@ func TestClientWriteShutdown(t *testing.T) {
} }
} }
// Tests that chunked server responses that write 1 byte at a time are
// buffered before chunk headers are added, not after chunk headers.
func TestServerBufferedChunking(t *testing.T) {
if true {
t.Logf("Skipping known broken test; see Issue 2357")
return
}
conn := new(testConn)
conn.readBuf.Write([]byte("GET / HTTP/1.1\r\n\r\n"))
done := make(chan bool)
ls := &oneConnListener{conn}
go Serve(ls, HandlerFunc(func(rw ResponseWriter, req *Request) {
defer close(done)
rw.Header().Set("Content-Type", "text/plain") // prevent sniffing, which buffers
rw.Write([]byte{'x'})
rw.Write([]byte{'y'})
rw.Write([]byte{'z'})
}))
<-done
if !bytes.HasSuffix(conn.writeBuf.Bytes(), []byte("\r\n\r\n3\r\nxyz\r\n0\r\n\r\n")) {
t.Errorf("response didn't end with a single 3 byte 'xyz' chunk; got:\n%q",
conn.writeBuf.Bytes())
}
}
// goTimeout runs f, failing t if f takes more than ns to complete. // goTimeout runs f, failing t if f takes more than ns to complete.
func goTimeout(t *testing.T, ns int64, f func()) { func goTimeout(t *testing.T, ns int64, f func()) {
ch := make(chan bool, 2) ch := make(chan bool, 2)
@ -1120,7 +1145,7 @@ func TestAcceptMaxFds(t *testing.T) {
ln := &errorListener{[]error{ ln := &errorListener{[]error{
&net.OpError{ &net.OpError{
Op: "accept", Op: "accept",
Err: os.Errno(syscall.EMFILE), Err: syscall.EMFILE,
}}} }}}
err := Serve(ln, HandlerFunc(HandlerFunc(func(ResponseWriter, *Request) {}))) err := Serve(ln, HandlerFunc(HandlerFunc(func(ResponseWriter, *Request) {})))
if err != io.EOF { if err != io.EOF {

View File

@ -149,11 +149,13 @@ type writerOnly struct {
} }
func (w *response) ReadFrom(src io.Reader) (n int64, err error) { func (w *response) ReadFrom(src io.Reader) (n int64, err error) {
// Flush before checking w.chunking, as Flush will call // Call WriteHeader before checking w.chunking if it hasn't
// WriteHeader if it hasn't been called yet, and WriteHeader // been called yet, since WriteHeader is what sets w.chunking.
// is what sets w.chunking. if !w.wroteHeader {
w.Flush() w.WriteHeader(StatusOK)
}
if !w.chunking && w.bodyAllowed() && !w.needSniff { if !w.chunking && w.bodyAllowed() && !w.needSniff {
w.Flush()
if rf, ok := w.conn.rwc.(io.ReaderFrom); ok { if rf, ok := w.conn.rwc.(io.ReaderFrom); ok {
n, err = rf.ReadFrom(src) n, err = rf.ReadFrom(src)
w.written += n w.written += n

View File

@ -6,6 +6,7 @@ package http_test
import ( import (
"bytes" "bytes"
"io"
"io/ioutil" "io/ioutil"
"log" "log"
. "net/http" . "net/http"
@ -79,3 +80,35 @@ func TestServerContentType(t *testing.T) {
resp.Body.Close() resp.Body.Close()
} }
} }
func TestContentTypeWithCopy(t *testing.T) {
const (
input = "\n<html>\n\t<head>\n"
expected = "text/html; charset=utf-8"
)
ts := httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) {
// Use io.Copy from a bytes.Buffer to trigger ReadFrom.
buf := bytes.NewBuffer([]byte(input))
n, err := io.Copy(w, buf)
if int(n) != len(input) || err != nil {
t.Errorf("io.Copy(w, %q) = %v, %v want %d, nil", input, n, err, len(input))
}
}))
defer ts.Close()
resp, err := Get(ts.URL)
if err != nil {
t.Fatalf("Get: %v", err)
}
if ct := resp.Header.Get("Content-Type"); ct != expected {
t.Errorf("Content-Type = %q, want %q", ct, expected)
}
data, err := ioutil.ReadAll(resp.Body)
if err != nil {
t.Errorf("reading body: %v", err)
} else if !bytes.Equal(data, []byte(input)) {
t.Errorf("data is %q, want %q", data, input)
}
resp.Body.Close()
}

View File

@ -537,7 +537,9 @@ func (b *body) Read(p []byte) (n int, err error) {
// Read the final trailer once we hit EOF. // Read the final trailer once we hit EOF.
if err == io.EOF && b.hdr != nil { if err == io.EOF && b.hdr != nil {
err = b.readTrailer() if e := b.readTrailer(); e != nil {
err = e
}
b.hdr = nil b.hdr = nil
} }
return n, err return n, err

Some files were not shown because too many files have changed in this diff Show More