mirror of git://gcc.gnu.org/git/gcc.git
parent
6e456f4cf4
commit
ab61e9c4da
|
@ -1,4 +1,4 @@
|
||||||
2f4482b89a6b
|
b4a91b693374
|
||||||
|
|
||||||
The first line of this file holds the Mercurial revision number of the
|
The first line of this file holds the Mercurial revision number of the
|
||||||
last merge done from the master library sources.
|
last merge done from the master library sources.
|
||||||
|
|
|
@ -648,7 +648,8 @@ go_math_files = \
|
||||||
go_mime_files = \
|
go_mime_files = \
|
||||||
go/mime/grammar.go \
|
go/mime/grammar.go \
|
||||||
go/mime/mediatype.go \
|
go/mime/mediatype.go \
|
||||||
go/mime/type.go
|
go/mime/type.go \
|
||||||
|
go/mime/type_unix.go
|
||||||
|
|
||||||
if LIBGO_IS_RTEMS
|
if LIBGO_IS_RTEMS
|
||||||
go_net_fd_os_file = go/net/fd_select.go
|
go_net_fd_os_file = go/net/fd_select.go
|
||||||
|
@ -770,7 +771,6 @@ go_os_files = \
|
||||||
$(go_os_dir_file) \
|
$(go_os_dir_file) \
|
||||||
go/os/dir.go \
|
go/os/dir.go \
|
||||||
go/os/env.go \
|
go/os/env.go \
|
||||||
go/os/env_unix.go \
|
|
||||||
go/os/error.go \
|
go/os/error.go \
|
||||||
go/os/error_posix.go \
|
go/os/error_posix.go \
|
||||||
go/os/exec.go \
|
go/os/exec.go \
|
||||||
|
@ -1156,6 +1156,7 @@ go_exp_sql_files = \
|
||||||
go/exp/sql/sql.go
|
go/exp/sql/sql.go
|
||||||
go_exp_ssh_files = \
|
go_exp_ssh_files = \
|
||||||
go/exp/ssh/channel.go \
|
go/exp/ssh/channel.go \
|
||||||
|
go/exp/ssh/cipher.go \
|
||||||
go/exp/ssh/client.go \
|
go/exp/ssh/client.go \
|
||||||
go/exp/ssh/client_auth.go \
|
go/exp/ssh/client_auth.go \
|
||||||
go/exp/ssh/common.go \
|
go/exp/ssh/common.go \
|
||||||
|
@ -1164,10 +1165,11 @@ go_exp_ssh_files = \
|
||||||
go/exp/ssh/server.go \
|
go/exp/ssh/server.go \
|
||||||
go/exp/ssh/server_shell.go \
|
go/exp/ssh/server_shell.go \
|
||||||
go/exp/ssh/session.go \
|
go/exp/ssh/session.go \
|
||||||
|
go/exp/ssh/tcpip.go \
|
||||||
go/exp/ssh/transport.go
|
go/exp/ssh/transport.go
|
||||||
go_exp_terminal_files = \
|
go_exp_terminal_files = \
|
||||||
go/exp/terminal/shell.go \
|
go/exp/terminal/terminal.go \
|
||||||
go/exp/terminal/terminal.go
|
go/exp/terminal/util.go
|
||||||
go_exp_types_files = \
|
go_exp_types_files = \
|
||||||
go/exp/types/check.go \
|
go/exp/types/check.go \
|
||||||
go/exp/types/const.go \
|
go/exp/types/const.go \
|
||||||
|
@ -1546,6 +1548,7 @@ syscall_netlink_file =
|
||||||
endif
|
endif
|
||||||
|
|
||||||
go_base_syscall_files = \
|
go_base_syscall_files = \
|
||||||
|
go/syscall/env_unix.go \
|
||||||
go/syscall/libcall_support.go \
|
go/syscall/libcall_support.go \
|
||||||
go/syscall/libcall_posix.go \
|
go/syscall/libcall_posix.go \
|
||||||
go/syscall/socket.go \
|
go/syscall/socket.go \
|
||||||
|
|
|
@ -1032,7 +1032,8 @@ go_math_files = \
|
||||||
go_mime_files = \
|
go_mime_files = \
|
||||||
go/mime/grammar.go \
|
go/mime/grammar.go \
|
||||||
go/mime/mediatype.go \
|
go/mime/mediatype.go \
|
||||||
go/mime/type.go
|
go/mime/type.go \
|
||||||
|
go/mime/type_unix.go
|
||||||
|
|
||||||
# By default use select with pipes. Most systems should have
|
# By default use select with pipes. Most systems should have
|
||||||
# something better.
|
# something better.
|
||||||
|
@ -1103,7 +1104,6 @@ go_os_files = \
|
||||||
$(go_os_dir_file) \
|
$(go_os_dir_file) \
|
||||||
go/os/dir.go \
|
go/os/dir.go \
|
||||||
go/os/env.go \
|
go/os/env.go \
|
||||||
go/os/env_unix.go \
|
|
||||||
go/os/error.go \
|
go/os/error.go \
|
||||||
go/os/error_posix.go \
|
go/os/error_posix.go \
|
||||||
go/os/exec.go \
|
go/os/exec.go \
|
||||||
|
@ -1521,6 +1521,7 @@ go_exp_sql_files = \
|
||||||
|
|
||||||
go_exp_ssh_files = \
|
go_exp_ssh_files = \
|
||||||
go/exp/ssh/channel.go \
|
go/exp/ssh/channel.go \
|
||||||
|
go/exp/ssh/cipher.go \
|
||||||
go/exp/ssh/client.go \
|
go/exp/ssh/client.go \
|
||||||
go/exp/ssh/client_auth.go \
|
go/exp/ssh/client_auth.go \
|
||||||
go/exp/ssh/common.go \
|
go/exp/ssh/common.go \
|
||||||
|
@ -1529,11 +1530,12 @@ go_exp_ssh_files = \
|
||||||
go/exp/ssh/server.go \
|
go/exp/ssh/server.go \
|
||||||
go/exp/ssh/server_shell.go \
|
go/exp/ssh/server_shell.go \
|
||||||
go/exp/ssh/session.go \
|
go/exp/ssh/session.go \
|
||||||
|
go/exp/ssh/tcpip.go \
|
||||||
go/exp/ssh/transport.go
|
go/exp/ssh/transport.go
|
||||||
|
|
||||||
go_exp_terminal_files = \
|
go_exp_terminal_files = \
|
||||||
go/exp/terminal/shell.go \
|
go/exp/terminal/terminal.go \
|
||||||
go/exp/terminal/terminal.go
|
go/exp/terminal/util.go
|
||||||
|
|
||||||
go_exp_types_files = \
|
go_exp_types_files = \
|
||||||
go/exp/types/check.go \
|
go/exp/types/check.go \
|
||||||
|
@ -1890,6 +1892,7 @@ go_unicode_utf8_files = \
|
||||||
# Support for netlink sockets and messages.
|
# Support for netlink sockets and messages.
|
||||||
@LIBGO_IS_LINUX_TRUE@syscall_netlink_file = go/syscall/netlink_linux.go
|
@LIBGO_IS_LINUX_TRUE@syscall_netlink_file = go/syscall/netlink_linux.go
|
||||||
go_base_syscall_files = \
|
go_base_syscall_files = \
|
||||||
|
go/syscall/env_unix.go \
|
||||||
go/syscall/libcall_support.go \
|
go/syscall/libcall_support.go \
|
||||||
go/syscall/libcall_posix.go \
|
go/syscall/libcall_posix.go \
|
||||||
go/syscall/socket.go \
|
go/syscall/socket.go \
|
||||||
|
|
|
@ -10,7 +10,6 @@ import (
|
||||||
"fmt"
|
"fmt"
|
||||||
"io"
|
"io"
|
||||||
"io/ioutil"
|
"io/ioutil"
|
||||||
"os"
|
|
||||||
"strings"
|
"strings"
|
||||||
"testing"
|
"testing"
|
||||||
"testing/iotest"
|
"testing/iotest"
|
||||||
|
@ -425,9 +424,9 @@ var errorWriterTests = []errorWriterTest{
|
||||||
{0, 1, nil, io.ErrShortWrite},
|
{0, 1, nil, io.ErrShortWrite},
|
||||||
{1, 2, nil, io.ErrShortWrite},
|
{1, 2, nil, io.ErrShortWrite},
|
||||||
{1, 1, nil, nil},
|
{1, 1, nil, nil},
|
||||||
{0, 1, os.EPIPE, os.EPIPE},
|
{0, 1, io.ErrClosedPipe, io.ErrClosedPipe},
|
||||||
{1, 2, os.EPIPE, os.EPIPE},
|
{1, 2, io.ErrClosedPipe, io.ErrClosedPipe},
|
||||||
{1, 1, os.EPIPE, os.EPIPE},
|
{1, 1, io.ErrClosedPipe, io.ErrClosedPipe},
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestWriteErrors(t *testing.T) {
|
func TestWriteErrors(t *testing.T) {
|
||||||
|
|
|
@ -91,6 +91,11 @@ type rune rune
|
||||||
// invocation.
|
// invocation.
|
||||||
type Type int
|
type Type int
|
||||||
|
|
||||||
|
// Type1 is here for the purposes of documentation only. It is a stand-in
|
||||||
|
// for any Go type, but represents the same type for any given function
|
||||||
|
// invocation.
|
||||||
|
type Type1 int
|
||||||
|
|
||||||
// IntegerType is here for the purposes of documentation only. It is a stand-in
|
// IntegerType is here for the purposes of documentation only. It is a stand-in
|
||||||
// for any integer type: int, uint, int8 etc.
|
// for any integer type: int, uint, int8 etc.
|
||||||
type IntegerType int
|
type IntegerType int
|
||||||
|
@ -119,6 +124,11 @@ func append(slice []Type, elems ...Type) []Type
|
||||||
// len(src) and len(dst).
|
// len(src) and len(dst).
|
||||||
func copy(dst, src []Type) int
|
func copy(dst, src []Type) int
|
||||||
|
|
||||||
|
// The delete built-in function deletes the element with the specified key
|
||||||
|
// (m[key]) from the map. If there is no such element, delete is a no-op.
|
||||||
|
// If m is nil, delete panics.
|
||||||
|
func delete(m map[Type]Type1, key Type)
|
||||||
|
|
||||||
// The len built-in function returns the length of v, according to its type:
|
// The len built-in function returns the length of v, according to its type:
|
||||||
// Array: the number of elements in v.
|
// Array: the number of elements in v.
|
||||||
// Pointer to array: the number of elements in *v (even if v is nil).
|
// Pointer to array: the number of elements in *v (even if v is nil).
|
||||||
|
@ -171,7 +181,7 @@ func complex(r, i FloatType) ComplexType
|
||||||
// The return value will be floating point type corresponding to the type of c.
|
// The return value will be floating point type corresponding to the type of c.
|
||||||
func real(c ComplexType) FloatType
|
func real(c ComplexType) FloatType
|
||||||
|
|
||||||
// The imaginary built-in function returns the imaginary part of the complex
|
// The imag built-in function returns the imaginary part of the complex
|
||||||
// number c. The return value will be floating point type corresponding to
|
// number c. The return value will be floating point type corresponding to
|
||||||
// the type of c.
|
// the type of c.
|
||||||
func imag(c ComplexType) FloatType
|
func imag(c ComplexType) FloatType
|
||||||
|
|
|
@ -662,48 +662,49 @@ func TestRunes(t *testing.T) {
|
||||||
}
|
}
|
||||||
|
|
||||||
type TrimTest struct {
|
type TrimTest struct {
|
||||||
f func([]byte, string) []byte
|
f string
|
||||||
in, cutset, out string
|
in, cutset, out string
|
||||||
}
|
}
|
||||||
|
|
||||||
var trimTests = []TrimTest{
|
var trimTests = []TrimTest{
|
||||||
{Trim, "abba", "a", "bb"},
|
{"Trim", "abba", "a", "bb"},
|
||||||
{Trim, "abba", "ab", ""},
|
{"Trim", "abba", "ab", ""},
|
||||||
{TrimLeft, "abba", "ab", ""},
|
{"TrimLeft", "abba", "ab", ""},
|
||||||
{TrimRight, "abba", "ab", ""},
|
{"TrimRight", "abba", "ab", ""},
|
||||||
{TrimLeft, "abba", "a", "bba"},
|
{"TrimLeft", "abba", "a", "bba"},
|
||||||
{TrimRight, "abba", "a", "abb"},
|
{"TrimRight", "abba", "a", "abb"},
|
||||||
{Trim, "<tag>", "<>", "tag"},
|
{"Trim", "<tag>", "<>", "tag"},
|
||||||
{Trim, "* listitem", " *", "listitem"},
|
{"Trim", "* listitem", " *", "listitem"},
|
||||||
{Trim, `"quote"`, `"`, "quote"},
|
{"Trim", `"quote"`, `"`, "quote"},
|
||||||
{Trim, "\u2C6F\u2C6F\u0250\u0250\u2C6F\u2C6F", "\u2C6F", "\u0250\u0250"},
|
{"Trim", "\u2C6F\u2C6F\u0250\u0250\u2C6F\u2C6F", "\u2C6F", "\u0250\u0250"},
|
||||||
//empty string tests
|
//empty string tests
|
||||||
{Trim, "abba", "", "abba"},
|
{"Trim", "abba", "", "abba"},
|
||||||
{Trim, "", "123", ""},
|
{"Trim", "", "123", ""},
|
||||||
{Trim, "", "", ""},
|
{"Trim", "", "", ""},
|
||||||
{TrimLeft, "abba", "", "abba"},
|
{"TrimLeft", "abba", "", "abba"},
|
||||||
{TrimLeft, "", "123", ""},
|
{"TrimLeft", "", "123", ""},
|
||||||
{TrimLeft, "", "", ""},
|
{"TrimLeft", "", "", ""},
|
||||||
{TrimRight, "abba", "", "abba"},
|
{"TrimRight", "abba", "", "abba"},
|
||||||
{TrimRight, "", "123", ""},
|
{"TrimRight", "", "123", ""},
|
||||||
{TrimRight, "", "", ""},
|
{"TrimRight", "", "", ""},
|
||||||
{TrimRight, "☺\xc0", "☺", "☺\xc0"},
|
{"TrimRight", "☺\xc0", "☺", "☺\xc0"},
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestTrim(t *testing.T) {
|
func TestTrim(t *testing.T) {
|
||||||
for _, tc := range trimTests {
|
for _, tc := range trimTests {
|
||||||
actual := string(tc.f([]byte(tc.in), tc.cutset))
|
name := tc.f
|
||||||
var name string
|
var f func([]byte, string) []byte
|
||||||
switch tc.f {
|
switch name {
|
||||||
case Trim:
|
case "Trim":
|
||||||
name = "Trim"
|
f = Trim
|
||||||
case TrimLeft:
|
case "TrimLeft":
|
||||||
name = "TrimLeft"
|
f = TrimLeft
|
||||||
case TrimRight:
|
case "TrimRight":
|
||||||
name = "TrimRight"
|
f = TrimRight
|
||||||
default:
|
default:
|
||||||
t.Error("Undefined trim function")
|
t.Error("Undefined trim function %s", name)
|
||||||
}
|
}
|
||||||
|
actual := string(f([]byte(tc.in), tc.cutset))
|
||||||
if actual != tc.out {
|
if actual != tc.out {
|
||||||
t.Errorf("%s(%q, %q) = %q; want %q", name, tc.in, tc.cutset, actual, tc.out)
|
t.Errorf("%s(%q, %q) = %q; want %q", name, tc.in, tc.cutset, actual, tc.out)
|
||||||
}
|
}
|
||||||
|
|
|
@ -19,7 +19,6 @@ import (
|
||||||
"errors"
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
"io"
|
"io"
|
||||||
"os"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
// Order specifies the bit ordering in an LZW data stream.
|
// Order specifies the bit ordering in an LZW data stream.
|
||||||
|
@ -212,8 +211,10 @@ func (d *decoder) flush() {
|
||||||
d.o = 0
|
d.o = 0
|
||||||
}
|
}
|
||||||
|
|
||||||
|
var errClosed = errors.New("compress/lzw: reader/writer is closed")
|
||||||
|
|
||||||
func (d *decoder) Close() error {
|
func (d *decoder) Close() error {
|
||||||
d.err = os.EINVAL // in case any Reads come along
|
d.err = errClosed // in case any Reads come along
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -9,7 +9,6 @@ import (
|
||||||
"errors"
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
"io"
|
"io"
|
||||||
"os"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
// A writer is a buffered, flushable writer.
|
// A writer is a buffered, flushable writer.
|
||||||
|
@ -49,8 +48,9 @@ const (
|
||||||
type encoder struct {
|
type encoder struct {
|
||||||
// w is the writer that compressed bytes are written to.
|
// w is the writer that compressed bytes are written to.
|
||||||
w writer
|
w writer
|
||||||
// write, bits, nBits and width are the state for converting a code stream
|
// order, write, bits, nBits and width are the state for
|
||||||
// into a byte stream.
|
// converting a code stream into a byte stream.
|
||||||
|
order Order
|
||||||
write func(*encoder, uint32) error
|
write func(*encoder, uint32) error
|
||||||
bits uint32
|
bits uint32
|
||||||
nBits uint
|
nBits uint
|
||||||
|
@ -64,7 +64,7 @@ type encoder struct {
|
||||||
// call. It is equal to invalidCode if there was no such call.
|
// call. It is equal to invalidCode if there was no such call.
|
||||||
savedCode uint32
|
savedCode uint32
|
||||||
// err is the first error encountered during writing. Closing the encoder
|
// err is the first error encountered during writing. Closing the encoder
|
||||||
// will make any future Write calls return os.EINVAL.
|
// will make any future Write calls return errClosed
|
||||||
err error
|
err error
|
||||||
// table is the hash table from 20-bit keys to 12-bit values. Each table
|
// table is the hash table from 20-bit keys to 12-bit values. Each table
|
||||||
// entry contains key<<12|val and collisions resolve by linear probing.
|
// entry contains key<<12|val and collisions resolve by linear probing.
|
||||||
|
@ -191,13 +191,13 @@ loop:
|
||||||
// flush e's underlying writer.
|
// flush e's underlying writer.
|
||||||
func (e *encoder) Close() error {
|
func (e *encoder) Close() error {
|
||||||
if e.err != nil {
|
if e.err != nil {
|
||||||
if e.err == os.EINVAL {
|
if e.err == errClosed {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
return e.err
|
return e.err
|
||||||
}
|
}
|
||||||
// Make any future calls to Write return os.EINVAL.
|
// Make any future calls to Write return errClosed.
|
||||||
e.err = os.EINVAL
|
e.err = errClosed
|
||||||
// Write the savedCode if valid.
|
// Write the savedCode if valid.
|
||||||
if e.savedCode != invalidCode {
|
if e.savedCode != invalidCode {
|
||||||
if err := e.write(e, e.savedCode); err != nil {
|
if err := e.write(e, e.savedCode); err != nil {
|
||||||
|
@ -214,7 +214,7 @@ func (e *encoder) Close() error {
|
||||||
}
|
}
|
||||||
// Write the final bits.
|
// Write the final bits.
|
||||||
if e.nBits > 0 {
|
if e.nBits > 0 {
|
||||||
if e.write == (*encoder).writeMSB {
|
if e.order == MSB {
|
||||||
e.bits >>= 24
|
e.bits >>= 24
|
||||||
}
|
}
|
||||||
if err := e.w.WriteByte(uint8(e.bits)); err != nil {
|
if err := e.w.WriteByte(uint8(e.bits)); err != nil {
|
||||||
|
@ -250,6 +250,7 @@ func NewWriter(w io.Writer, order Order, litWidth int) io.WriteCloser {
|
||||||
lw := uint(litWidth)
|
lw := uint(litWidth)
|
||||||
return &encoder{
|
return &encoder{
|
||||||
w: bw,
|
w: bw,
|
||||||
|
order: order,
|
||||||
write: write,
|
write: write,
|
||||||
width: 1 + lw,
|
width: 1 + lw,
|
||||||
litWidth: lw,
|
litWidth: lw,
|
||||||
|
|
|
@ -50,10 +50,6 @@ func testFile(t *testing.T, fn string, order Order, litWidth int) {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
_, err1 := lzww.Write(b[:n])
|
_, err1 := lzww.Write(b[:n])
|
||||||
if err1 == os.EPIPE {
|
|
||||||
// Fail, but do not report the error, as some other (presumably reportable) error broke the pipe.
|
|
||||||
return
|
|
||||||
}
|
|
||||||
if err1 != nil {
|
if err1 != nil {
|
||||||
t.Errorf("%s (order=%d litWidth=%d): %v", fn, order, litWidth, err1)
|
t.Errorf("%s (order=%d litWidth=%d): %v", fn, order, litWidth, err1)
|
||||||
return
|
return
|
||||||
|
|
|
@ -59,10 +59,6 @@ func testLevelDict(t *testing.T, fn string, b0 []byte, level int, d string) {
|
||||||
}
|
}
|
||||||
defer zlibw.Close()
|
defer zlibw.Close()
|
||||||
_, err = zlibw.Write(b0)
|
_, err = zlibw.Write(b0)
|
||||||
if err == os.EPIPE {
|
|
||||||
// Fail, but do not report the error, as some other (presumably reported) error broke the pipe.
|
|
||||||
return
|
|
||||||
}
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Errorf("%s (level=%d, dict=%q): %v", fn, level, d, err)
|
t.Errorf("%s (level=%d, dict=%q): %v", fn, level, d, err)
|
||||||
return
|
return
|
||||||
|
|
|
@ -41,7 +41,7 @@ func NewCipher(key []byte) (*Cipher, error) {
|
||||||
}
|
}
|
||||||
|
|
||||||
// BlockSize returns the AES block size, 16 bytes.
|
// BlockSize returns the AES block size, 16 bytes.
|
||||||
// It is necessary to satisfy the Cipher interface in the
|
// It is necessary to satisfy the Block interface in the
|
||||||
// package "crypto/cipher".
|
// package "crypto/cipher".
|
||||||
func (c *Cipher) BlockSize() int { return BlockSize }
|
func (c *Cipher) BlockSize() int { return BlockSize }
|
||||||
|
|
||||||
|
|
|
@ -54,7 +54,7 @@ func NewSaltedCipher(key, salt []byte) (*Cipher, error) {
|
||||||
}
|
}
|
||||||
|
|
||||||
// BlockSize returns the Blowfish block size, 8 bytes.
|
// BlockSize returns the Blowfish block size, 8 bytes.
|
||||||
// It is necessary to satisfy the Cipher interface in the
|
// It is necessary to satisfy the Block interface in the
|
||||||
// package "crypto/cipher".
|
// package "crypto/cipher".
|
||||||
func (c *Cipher) BlockSize() int { return BlockSize }
|
func (c *Cipher) BlockSize() int { return BlockSize }
|
||||||
|
|
||||||
|
|
|
@ -28,16 +28,16 @@ func (r *rngReader) Read(b []byte) (n int, err error) {
|
||||||
if r.prov == 0 {
|
if r.prov == 0 {
|
||||||
const provType = syscall.PROV_RSA_FULL
|
const provType = syscall.PROV_RSA_FULL
|
||||||
const flags = syscall.CRYPT_VERIFYCONTEXT | syscall.CRYPT_SILENT
|
const flags = syscall.CRYPT_VERIFYCONTEXT | syscall.CRYPT_SILENT
|
||||||
errno := syscall.CryptAcquireContext(&r.prov, nil, nil, provType, flags)
|
err := syscall.CryptAcquireContext(&r.prov, nil, nil, provType, flags)
|
||||||
if errno != 0 {
|
if err != nil {
|
||||||
r.mu.Unlock()
|
r.mu.Unlock()
|
||||||
return 0, os.NewSyscallError("CryptAcquireContext", errno)
|
return 0, os.NewSyscallError("CryptAcquireContext", err)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
r.mu.Unlock()
|
r.mu.Unlock()
|
||||||
errno := syscall.CryptGenRandom(r.prov, uint32(len(b)), &b[0])
|
err = syscall.CryptGenRandom(r.prov, uint32(len(b)), &b[0])
|
||||||
if errno != 0 {
|
if err != nil {
|
||||||
return 0, os.NewSyscallError("CryptGenRandom", errno)
|
return 0, os.NewSyscallError("CryptGenRandom", err)
|
||||||
}
|
}
|
||||||
return len(b), nil
|
return len(b), nil
|
||||||
}
|
}
|
||||||
|
|
|
@ -5,16 +5,16 @@
|
||||||
package rand
|
package rand
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"errors"
|
||||||
"io"
|
"io"
|
||||||
"math/big"
|
"math/big"
|
||||||
"os"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
// Prime returns a number, p, of the given size, such that p is prime
|
// Prime returns a number, p, of the given size, such that p is prime
|
||||||
// with high probability.
|
// with high probability.
|
||||||
func Prime(rand io.Reader, bits int) (p *big.Int, err error) {
|
func Prime(rand io.Reader, bits int) (p *big.Int, err error) {
|
||||||
if bits < 1 {
|
if bits < 1 {
|
||||||
err = os.EINVAL
|
err = errors.New("crypto/rand: prime size must be positive")
|
||||||
}
|
}
|
||||||
|
|
||||||
b := uint(bits % 8)
|
b := uint(bits % 8)
|
||||||
|
|
|
@ -93,7 +93,8 @@ func (c *Conn) SetTimeout(nsec int64) error {
|
||||||
}
|
}
|
||||||
|
|
||||||
// SetReadTimeout sets the time (in nanoseconds) that
|
// SetReadTimeout sets the time (in nanoseconds) that
|
||||||
// Read will wait for data before returning os.EAGAIN.
|
// Read will wait for data before returning a net.Error
|
||||||
|
// with Timeout() == true.
|
||||||
// Setting nsec == 0 (the default) disables the deadline.
|
// Setting nsec == 0 (the default) disables the deadline.
|
||||||
func (c *Conn) SetReadTimeout(nsec int64) error {
|
func (c *Conn) SetReadTimeout(nsec int64) error {
|
||||||
return c.conn.SetReadTimeout(nsec)
|
return c.conn.SetReadTimeout(nsec)
|
||||||
|
@ -737,7 +738,7 @@ func (c *Conn) Write(b []byte) (n int, err error) {
|
||||||
return c.writeRecord(recordTypeApplicationData, b)
|
return c.writeRecord(recordTypeApplicationData, b)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Read can be made to time out and return err == os.EAGAIN
|
// Read can be made to time out and return a net.Error with Timeout() == true
|
||||||
// after a fixed time limit; see SetTimeout and SetReadTimeout.
|
// after a fixed time limit; see SetTimeout and SetReadTimeout.
|
||||||
func (c *Conn) Read(b []byte) (n int, err error) {
|
func (c *Conn) Read(b []byte) (n int, err error) {
|
||||||
if err = c.Handshake(); err != nil {
|
if err = c.Handshake(); err != nil {
|
||||||
|
|
|
@ -4,6 +4,8 @@
|
||||||
|
|
||||||
package tls
|
package tls
|
||||||
|
|
||||||
|
import "bytes"
|
||||||
|
|
||||||
type clientHelloMsg struct {
|
type clientHelloMsg struct {
|
||||||
raw []byte
|
raw []byte
|
||||||
vers uint16
|
vers uint16
|
||||||
|
@ -18,6 +20,25 @@ type clientHelloMsg struct {
|
||||||
supportedPoints []uint8
|
supportedPoints []uint8
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (m *clientHelloMsg) equal(i interface{}) bool {
|
||||||
|
m1, ok := i.(*clientHelloMsg)
|
||||||
|
if !ok {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
return bytes.Equal(m.raw, m1.raw) &&
|
||||||
|
m.vers == m1.vers &&
|
||||||
|
bytes.Equal(m.random, m1.random) &&
|
||||||
|
bytes.Equal(m.sessionId, m1.sessionId) &&
|
||||||
|
eqUint16s(m.cipherSuites, m1.cipherSuites) &&
|
||||||
|
bytes.Equal(m.compressionMethods, m1.compressionMethods) &&
|
||||||
|
m.nextProtoNeg == m1.nextProtoNeg &&
|
||||||
|
m.serverName == m1.serverName &&
|
||||||
|
m.ocspStapling == m1.ocspStapling &&
|
||||||
|
eqUint16s(m.supportedCurves, m1.supportedCurves) &&
|
||||||
|
bytes.Equal(m.supportedPoints, m1.supportedPoints)
|
||||||
|
}
|
||||||
|
|
||||||
func (m *clientHelloMsg) marshal() []byte {
|
func (m *clientHelloMsg) marshal() []byte {
|
||||||
if m.raw != nil {
|
if m.raw != nil {
|
||||||
return m.raw
|
return m.raw
|
||||||
|
@ -309,6 +330,23 @@ type serverHelloMsg struct {
|
||||||
ocspStapling bool
|
ocspStapling bool
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (m *serverHelloMsg) equal(i interface{}) bool {
|
||||||
|
m1, ok := i.(*serverHelloMsg)
|
||||||
|
if !ok {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
return bytes.Equal(m.raw, m1.raw) &&
|
||||||
|
m.vers == m1.vers &&
|
||||||
|
bytes.Equal(m.random, m1.random) &&
|
||||||
|
bytes.Equal(m.sessionId, m1.sessionId) &&
|
||||||
|
m.cipherSuite == m1.cipherSuite &&
|
||||||
|
m.compressionMethod == m1.compressionMethod &&
|
||||||
|
m.nextProtoNeg == m1.nextProtoNeg &&
|
||||||
|
eqStrings(m.nextProtos, m1.nextProtos) &&
|
||||||
|
m.ocspStapling == m1.ocspStapling
|
||||||
|
}
|
||||||
|
|
||||||
func (m *serverHelloMsg) marshal() []byte {
|
func (m *serverHelloMsg) marshal() []byte {
|
||||||
if m.raw != nil {
|
if m.raw != nil {
|
||||||
return m.raw
|
return m.raw
|
||||||
|
@ -463,6 +501,16 @@ type certificateMsg struct {
|
||||||
certificates [][]byte
|
certificates [][]byte
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (m *certificateMsg) equal(i interface{}) bool {
|
||||||
|
m1, ok := i.(*certificateMsg)
|
||||||
|
if !ok {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
return bytes.Equal(m.raw, m1.raw) &&
|
||||||
|
eqByteSlices(m.certificates, m1.certificates)
|
||||||
|
}
|
||||||
|
|
||||||
func (m *certificateMsg) marshal() (x []byte) {
|
func (m *certificateMsg) marshal() (x []byte) {
|
||||||
if m.raw != nil {
|
if m.raw != nil {
|
||||||
return m.raw
|
return m.raw
|
||||||
|
@ -540,6 +588,16 @@ type serverKeyExchangeMsg struct {
|
||||||
key []byte
|
key []byte
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (m *serverKeyExchangeMsg) equal(i interface{}) bool {
|
||||||
|
m1, ok := i.(*serverKeyExchangeMsg)
|
||||||
|
if !ok {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
return bytes.Equal(m.raw, m1.raw) &&
|
||||||
|
bytes.Equal(m.key, m1.key)
|
||||||
|
}
|
||||||
|
|
||||||
func (m *serverKeyExchangeMsg) marshal() []byte {
|
func (m *serverKeyExchangeMsg) marshal() []byte {
|
||||||
if m.raw != nil {
|
if m.raw != nil {
|
||||||
return m.raw
|
return m.raw
|
||||||
|
@ -571,6 +629,17 @@ type certificateStatusMsg struct {
|
||||||
response []byte
|
response []byte
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (m *certificateStatusMsg) equal(i interface{}) bool {
|
||||||
|
m1, ok := i.(*certificateStatusMsg)
|
||||||
|
if !ok {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
return bytes.Equal(m.raw, m1.raw) &&
|
||||||
|
m.statusType == m1.statusType &&
|
||||||
|
bytes.Equal(m.response, m1.response)
|
||||||
|
}
|
||||||
|
|
||||||
func (m *certificateStatusMsg) marshal() []byte {
|
func (m *certificateStatusMsg) marshal() []byte {
|
||||||
if m.raw != nil {
|
if m.raw != nil {
|
||||||
return m.raw
|
return m.raw
|
||||||
|
@ -622,6 +691,11 @@ func (m *certificateStatusMsg) unmarshal(data []byte) bool {
|
||||||
|
|
||||||
type serverHelloDoneMsg struct{}
|
type serverHelloDoneMsg struct{}
|
||||||
|
|
||||||
|
func (m *serverHelloDoneMsg) equal(i interface{}) bool {
|
||||||
|
_, ok := i.(*serverHelloDoneMsg)
|
||||||
|
return ok
|
||||||
|
}
|
||||||
|
|
||||||
func (m *serverHelloDoneMsg) marshal() []byte {
|
func (m *serverHelloDoneMsg) marshal() []byte {
|
||||||
x := make([]byte, 4)
|
x := make([]byte, 4)
|
||||||
x[0] = typeServerHelloDone
|
x[0] = typeServerHelloDone
|
||||||
|
@ -637,6 +711,16 @@ type clientKeyExchangeMsg struct {
|
||||||
ciphertext []byte
|
ciphertext []byte
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (m *clientKeyExchangeMsg) equal(i interface{}) bool {
|
||||||
|
m1, ok := i.(*clientKeyExchangeMsg)
|
||||||
|
if !ok {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
return bytes.Equal(m.raw, m1.raw) &&
|
||||||
|
bytes.Equal(m.ciphertext, m1.ciphertext)
|
||||||
|
}
|
||||||
|
|
||||||
func (m *clientKeyExchangeMsg) marshal() []byte {
|
func (m *clientKeyExchangeMsg) marshal() []byte {
|
||||||
if m.raw != nil {
|
if m.raw != nil {
|
||||||
return m.raw
|
return m.raw
|
||||||
|
@ -671,6 +755,16 @@ type finishedMsg struct {
|
||||||
verifyData []byte
|
verifyData []byte
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (m *finishedMsg) equal(i interface{}) bool {
|
||||||
|
m1, ok := i.(*finishedMsg)
|
||||||
|
if !ok {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
return bytes.Equal(m.raw, m1.raw) &&
|
||||||
|
bytes.Equal(m.verifyData, m1.verifyData)
|
||||||
|
}
|
||||||
|
|
||||||
func (m *finishedMsg) marshal() (x []byte) {
|
func (m *finishedMsg) marshal() (x []byte) {
|
||||||
if m.raw != nil {
|
if m.raw != nil {
|
||||||
return m.raw
|
return m.raw
|
||||||
|
@ -698,6 +792,16 @@ type nextProtoMsg struct {
|
||||||
proto string
|
proto string
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (m *nextProtoMsg) equal(i interface{}) bool {
|
||||||
|
m1, ok := i.(*nextProtoMsg)
|
||||||
|
if !ok {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
return bytes.Equal(m.raw, m1.raw) &&
|
||||||
|
m.proto == m1.proto
|
||||||
|
}
|
||||||
|
|
||||||
func (m *nextProtoMsg) marshal() []byte {
|
func (m *nextProtoMsg) marshal() []byte {
|
||||||
if m.raw != nil {
|
if m.raw != nil {
|
||||||
return m.raw
|
return m.raw
|
||||||
|
@ -759,6 +863,17 @@ type certificateRequestMsg struct {
|
||||||
certificateAuthorities [][]byte
|
certificateAuthorities [][]byte
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (m *certificateRequestMsg) equal(i interface{}) bool {
|
||||||
|
m1, ok := i.(*certificateRequestMsg)
|
||||||
|
if !ok {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
return bytes.Equal(m.raw, m1.raw) &&
|
||||||
|
bytes.Equal(m.certificateTypes, m1.certificateTypes) &&
|
||||||
|
eqByteSlices(m.certificateAuthorities, m1.certificateAuthorities)
|
||||||
|
}
|
||||||
|
|
||||||
func (m *certificateRequestMsg) marshal() (x []byte) {
|
func (m *certificateRequestMsg) marshal() (x []byte) {
|
||||||
if m.raw != nil {
|
if m.raw != nil {
|
||||||
return m.raw
|
return m.raw
|
||||||
|
@ -859,6 +974,16 @@ type certificateVerifyMsg struct {
|
||||||
signature []byte
|
signature []byte
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (m *certificateVerifyMsg) equal(i interface{}) bool {
|
||||||
|
m1, ok := i.(*certificateVerifyMsg)
|
||||||
|
if !ok {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
return bytes.Equal(m.raw, m1.raw) &&
|
||||||
|
bytes.Equal(m.signature, m1.signature)
|
||||||
|
}
|
||||||
|
|
||||||
func (m *certificateVerifyMsg) marshal() (x []byte) {
|
func (m *certificateVerifyMsg) marshal() (x []byte) {
|
||||||
if m.raw != nil {
|
if m.raw != nil {
|
||||||
return m.raw
|
return m.raw
|
||||||
|
@ -902,3 +1027,39 @@ func (m *certificateVerifyMsg) unmarshal(data []byte) bool {
|
||||||
|
|
||||||
return true
|
return true
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func eqUint16s(x, y []uint16) bool {
|
||||||
|
if len(x) != len(y) {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
for i, v := range x {
|
||||||
|
if y[i] != v {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
|
||||||
|
func eqStrings(x, y []string) bool {
|
||||||
|
if len(x) != len(y) {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
for i, v := range x {
|
||||||
|
if y[i] != v {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
|
||||||
|
func eqByteSlices(x, y [][]byte) bool {
|
||||||
|
if len(x) != len(y) {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
for i, v := range x {
|
||||||
|
if !bytes.Equal(v, y[i]) {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
|
|
@ -27,10 +27,12 @@ var tests = []interface{}{
|
||||||
type testMessage interface {
|
type testMessage interface {
|
||||||
marshal() []byte
|
marshal() []byte
|
||||||
unmarshal([]byte) bool
|
unmarshal([]byte) bool
|
||||||
|
equal(interface{}) bool
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestMarshalUnmarshal(t *testing.T) {
|
func TestMarshalUnmarshal(t *testing.T) {
|
||||||
rand := rand.New(rand.NewSource(0))
|
rand := rand.New(rand.NewSource(0))
|
||||||
|
|
||||||
for i, iface := range tests {
|
for i, iface := range tests {
|
||||||
ty := reflect.ValueOf(iface).Type()
|
ty := reflect.ValueOf(iface).Type()
|
||||||
|
|
||||||
|
@ -54,7 +56,7 @@ func TestMarshalUnmarshal(t *testing.T) {
|
||||||
}
|
}
|
||||||
m2.marshal() // to fill any marshal cache in the message
|
m2.marshal() // to fill any marshal cache in the message
|
||||||
|
|
||||||
if !reflect.DeepEqual(m1, m2) {
|
if !m1.equal(m2) {
|
||||||
t.Errorf("#%d got:%#v want:%#v %x", i, m2, m1, marshaled)
|
t.Errorf("#%d got:%#v want:%#v %x", i, m2, m1, marshaled)
|
||||||
break
|
break
|
||||||
}
|
}
|
||||||
|
|
|
@ -12,8 +12,8 @@ import (
|
||||||
)
|
)
|
||||||
|
|
||||||
func loadStore(roots *x509.CertPool, name string) {
|
func loadStore(roots *x509.CertPool, name string) {
|
||||||
store, errno := syscall.CertOpenSystemStore(syscall.InvalidHandle, syscall.StringToUTF16Ptr(name))
|
store, err := syscall.CertOpenSystemStore(syscall.InvalidHandle, syscall.StringToUTF16Ptr(name))
|
||||||
if errno != 0 {
|
if err != nil {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -44,7 +44,7 @@ func NewCipher(key []byte) (*Cipher, error) {
|
||||||
}
|
}
|
||||||
|
|
||||||
// BlockSize returns the XTEA block size, 8 bytes.
|
// BlockSize returns the XTEA block size, 8 bytes.
|
||||||
// It is necessary to satisfy the Cipher interface in the
|
// It is necessary to satisfy the Block interface in the
|
||||||
// package "crypto/cipher".
|
// package "crypto/cipher".
|
||||||
func (c *Cipher) BlockSize() int { return BlockSize }
|
func (c *Cipher) BlockSize() int { return BlockSize }
|
||||||
|
|
||||||
|
|
|
@ -0,0 +1,157 @@
|
||||||
|
// Copyright 2011 The Go Authors. All rights reserved.
|
||||||
|
// Use of this source code is governed by a BSD-style
|
||||||
|
// license that can be found in the LICENSE file.
|
||||||
|
|
||||||
|
// Large data benchmark.
|
||||||
|
// The JSON data is a summary of agl's changes in the
|
||||||
|
// go, webkit, and chromium open source projects.
|
||||||
|
// We benchmark converting between the JSON form
|
||||||
|
// and in-memory data structures.
|
||||||
|
|
||||||
|
package json
|
||||||
|
|
||||||
|
import (
|
||||||
|
"bytes"
|
||||||
|
"compress/gzip"
|
||||||
|
"io/ioutil"
|
||||||
|
"os"
|
||||||
|
"testing"
|
||||||
|
)
|
||||||
|
|
||||||
|
type codeResponse struct {
|
||||||
|
Tree *codeNode `json:"tree"`
|
||||||
|
Username string `json:"username"`
|
||||||
|
}
|
||||||
|
|
||||||
|
type codeNode struct {
|
||||||
|
Name string `json:"name"`
|
||||||
|
Kids []*codeNode `json:"kids"`
|
||||||
|
CLWeight float64 `json:"cl_weight"`
|
||||||
|
Touches int `json:"touches"`
|
||||||
|
MinT int64 `json:"min_t"`
|
||||||
|
MaxT int64 `json:"max_t"`
|
||||||
|
MeanT int64 `json:"mean_t"`
|
||||||
|
}
|
||||||
|
|
||||||
|
var codeJSON []byte
|
||||||
|
var codeStruct codeResponse
|
||||||
|
|
||||||
|
func codeInit() {
|
||||||
|
f, err := os.Open("testdata/code.json.gz")
|
||||||
|
if err != nil {
|
||||||
|
panic(err)
|
||||||
|
}
|
||||||
|
defer f.Close()
|
||||||
|
gz, err := gzip.NewReader(f)
|
||||||
|
if err != nil {
|
||||||
|
panic(err)
|
||||||
|
}
|
||||||
|
data, err := ioutil.ReadAll(gz)
|
||||||
|
if err != nil {
|
||||||
|
panic(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
codeJSON = data
|
||||||
|
|
||||||
|
if err := Unmarshal(codeJSON, &codeStruct); err != nil {
|
||||||
|
panic("unmarshal code.json: " + err.Error())
|
||||||
|
}
|
||||||
|
|
||||||
|
if data, err = Marshal(&codeStruct); err != nil {
|
||||||
|
panic("marshal code.json: " + err.Error())
|
||||||
|
}
|
||||||
|
|
||||||
|
if !bytes.Equal(data, codeJSON) {
|
||||||
|
println("different lengths", len(data), len(codeJSON))
|
||||||
|
for i := 0; i < len(data) && i < len(codeJSON); i++ {
|
||||||
|
if data[i] != codeJSON[i] {
|
||||||
|
println("re-marshal: changed at byte", i)
|
||||||
|
println("orig: ", string(codeJSON[i-10:i+10]))
|
||||||
|
println("new: ", string(data[i-10:i+10]))
|
||||||
|
break
|
||||||
|
}
|
||||||
|
}
|
||||||
|
panic("re-marshal code.json: different result")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func BenchmarkCodeEncoder(b *testing.B) {
|
||||||
|
if codeJSON == nil {
|
||||||
|
b.StopTimer()
|
||||||
|
codeInit()
|
||||||
|
b.StartTimer()
|
||||||
|
}
|
||||||
|
enc := NewEncoder(ioutil.Discard)
|
||||||
|
for i := 0; i < b.N; i++ {
|
||||||
|
if err := enc.Encode(&codeStruct); err != nil {
|
||||||
|
panic(err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
b.SetBytes(int64(len(codeJSON)))
|
||||||
|
}
|
||||||
|
|
||||||
|
func BenchmarkCodeMarshal(b *testing.B) {
|
||||||
|
if codeJSON == nil {
|
||||||
|
b.StopTimer()
|
||||||
|
codeInit()
|
||||||
|
b.StartTimer()
|
||||||
|
}
|
||||||
|
for i := 0; i < b.N; i++ {
|
||||||
|
if _, err := Marshal(&codeStruct); err != nil {
|
||||||
|
panic(err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
b.SetBytes(int64(len(codeJSON)))
|
||||||
|
}
|
||||||
|
|
||||||
|
func BenchmarkCodeDecoder(b *testing.B) {
|
||||||
|
if codeJSON == nil {
|
||||||
|
b.StopTimer()
|
||||||
|
codeInit()
|
||||||
|
b.StartTimer()
|
||||||
|
}
|
||||||
|
var buf bytes.Buffer
|
||||||
|
dec := NewDecoder(&buf)
|
||||||
|
var r codeResponse
|
||||||
|
for i := 0; i < b.N; i++ {
|
||||||
|
buf.Write(codeJSON)
|
||||||
|
// hide EOF
|
||||||
|
buf.WriteByte('\n')
|
||||||
|
buf.WriteByte('\n')
|
||||||
|
buf.WriteByte('\n')
|
||||||
|
if err := dec.Decode(&r); err != nil {
|
||||||
|
panic(err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
b.SetBytes(int64(len(codeJSON)))
|
||||||
|
}
|
||||||
|
|
||||||
|
func BenchmarkCodeUnmarshal(b *testing.B) {
|
||||||
|
if codeJSON == nil {
|
||||||
|
b.StopTimer()
|
||||||
|
codeInit()
|
||||||
|
b.StartTimer()
|
||||||
|
}
|
||||||
|
for i := 0; i < b.N; i++ {
|
||||||
|
var r codeResponse
|
||||||
|
if err := Unmarshal(codeJSON, &r); err != nil {
|
||||||
|
panic(err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
b.SetBytes(int64(len(codeJSON)))
|
||||||
|
}
|
||||||
|
|
||||||
|
func BenchmarkCodeUnmarshalReuse(b *testing.B) {
|
||||||
|
if codeJSON == nil {
|
||||||
|
b.StopTimer()
|
||||||
|
codeInit()
|
||||||
|
b.StartTimer()
|
||||||
|
}
|
||||||
|
var r codeResponse
|
||||||
|
for i := 0; i < b.N; i++ {
|
||||||
|
if err := Unmarshal(codeJSON, &r); err != nil {
|
||||||
|
panic(err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
b.SetBytes(int64(len(codeJSON)))
|
||||||
|
}
|
|
@ -227,7 +227,7 @@ func (d *decodeState) value(v reflect.Value) {
|
||||||
// d.scan thinks we're still at the beginning of the item.
|
// d.scan thinks we're still at the beginning of the item.
|
||||||
// Feed in an empty string - the shortest, simplest value -
|
// Feed in an empty string - the shortest, simplest value -
|
||||||
// so that it knows we got to the end of the value.
|
// so that it knows we got to the end of the value.
|
||||||
if d.scan.step == stateRedo {
|
if d.scan.redo {
|
||||||
panic("redo")
|
panic("redo")
|
||||||
}
|
}
|
||||||
d.scan.step(&d.scan, '"')
|
d.scan.step(&d.scan, '"')
|
||||||
|
@ -381,6 +381,7 @@ func (d *decodeState) array(v reflect.Value) {
|
||||||
d.error(errPhase)
|
d.error(errPhase)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
if i < av.Len() {
|
if i < av.Len() {
|
||||||
if !sv.IsValid() {
|
if !sv.IsValid() {
|
||||||
// Array. Zero the rest.
|
// Array. Zero the rest.
|
||||||
|
@ -392,6 +393,9 @@ func (d *decodeState) array(v reflect.Value) {
|
||||||
sv.SetLen(i)
|
sv.SetLen(i)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
if i == 0 && av.Kind() == reflect.Slice && sv.IsNil() {
|
||||||
|
sv.Set(reflect.MakeSlice(sv.Type(), 0, 0))
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// object consumes an object from d.data[d.off-1:], decoding into the value v.
|
// object consumes an object from d.data[d.off-1:], decoding into the value v.
|
||||||
|
|
|
@ -80,6 +80,9 @@ type scanner struct {
|
||||||
// on a 64-bit Mac Mini, and it's nicer to read.
|
// on a 64-bit Mac Mini, and it's nicer to read.
|
||||||
step func(*scanner, int) int
|
step func(*scanner, int) int
|
||||||
|
|
||||||
|
// Reached end of top-level value.
|
||||||
|
endTop bool
|
||||||
|
|
||||||
// Stack of what we're in the middle of - array values, object keys, object values.
|
// Stack of what we're in the middle of - array values, object keys, object values.
|
||||||
parseState []int
|
parseState []int
|
||||||
|
|
||||||
|
@ -87,6 +90,7 @@ type scanner struct {
|
||||||
err error
|
err error
|
||||||
|
|
||||||
// 1-byte redo (see undo method)
|
// 1-byte redo (see undo method)
|
||||||
|
redo bool
|
||||||
redoCode int
|
redoCode int
|
||||||
redoState func(*scanner, int) int
|
redoState func(*scanner, int) int
|
||||||
|
|
||||||
|
@ -135,6 +139,8 @@ func (s *scanner) reset() {
|
||||||
s.step = stateBeginValue
|
s.step = stateBeginValue
|
||||||
s.parseState = s.parseState[0:0]
|
s.parseState = s.parseState[0:0]
|
||||||
s.err = nil
|
s.err = nil
|
||||||
|
s.redo = false
|
||||||
|
s.endTop = false
|
||||||
}
|
}
|
||||||
|
|
||||||
// eof tells the scanner that the end of input has been reached.
|
// eof tells the scanner that the end of input has been reached.
|
||||||
|
@ -143,11 +149,11 @@ func (s *scanner) eof() int {
|
||||||
if s.err != nil {
|
if s.err != nil {
|
||||||
return scanError
|
return scanError
|
||||||
}
|
}
|
||||||
if s.step == stateEndTop {
|
if s.endTop {
|
||||||
return scanEnd
|
return scanEnd
|
||||||
}
|
}
|
||||||
s.step(s, ' ')
|
s.step(s, ' ')
|
||||||
if s.step == stateEndTop {
|
if s.endTop {
|
||||||
return scanEnd
|
return scanEnd
|
||||||
}
|
}
|
||||||
if s.err == nil {
|
if s.err == nil {
|
||||||
|
@ -166,8 +172,10 @@ func (s *scanner) pushParseState(p int) {
|
||||||
func (s *scanner) popParseState() {
|
func (s *scanner) popParseState() {
|
||||||
n := len(s.parseState) - 1
|
n := len(s.parseState) - 1
|
||||||
s.parseState = s.parseState[0:n]
|
s.parseState = s.parseState[0:n]
|
||||||
|
s.redo = false
|
||||||
if n == 0 {
|
if n == 0 {
|
||||||
s.step = stateEndTop
|
s.step = stateEndTop
|
||||||
|
s.endTop = true
|
||||||
} else {
|
} else {
|
||||||
s.step = stateEndValue
|
s.step = stateEndValue
|
||||||
}
|
}
|
||||||
|
@ -269,6 +277,7 @@ func stateEndValue(s *scanner, c int) int {
|
||||||
if n == 0 {
|
if n == 0 {
|
||||||
// Completed top-level before the current byte.
|
// Completed top-level before the current byte.
|
||||||
s.step = stateEndTop
|
s.step = stateEndTop
|
||||||
|
s.endTop = true
|
||||||
return stateEndTop(s, c)
|
return stateEndTop(s, c)
|
||||||
}
|
}
|
||||||
if c <= ' ' && (c == ' ' || c == '\t' || c == '\r' || c == '\n') {
|
if c <= ' ' && (c == ' ' || c == '\t' || c == '\r' || c == '\n') {
|
||||||
|
@ -606,16 +615,18 @@ func quoteChar(c int) string {
|
||||||
// undo causes the scanner to return scanCode from the next state transition.
|
// undo causes the scanner to return scanCode from the next state transition.
|
||||||
// This gives callers a simple 1-byte undo mechanism.
|
// This gives callers a simple 1-byte undo mechanism.
|
||||||
func (s *scanner) undo(scanCode int) {
|
func (s *scanner) undo(scanCode int) {
|
||||||
if s.step == stateRedo {
|
if s.redo {
|
||||||
panic("invalid use of scanner")
|
panic("json: invalid use of scanner")
|
||||||
}
|
}
|
||||||
s.redoCode = scanCode
|
s.redoCode = scanCode
|
||||||
s.redoState = s.step
|
s.redoState = s.step
|
||||||
s.step = stateRedo
|
s.step = stateRedo
|
||||||
|
s.redo = true
|
||||||
}
|
}
|
||||||
|
|
||||||
// stateRedo helps implement the scanner's 1-byte undo.
|
// stateRedo helps implement the scanner's 1-byte undo.
|
||||||
func stateRedo(s *scanner, c int) int {
|
func stateRedo(s *scanner, c int) int {
|
||||||
|
s.redo = false
|
||||||
s.step = s.redoState
|
s.step = s.redoState
|
||||||
return s.redoCode
|
return s.redoCode
|
||||||
}
|
}
|
||||||
|
|
|
@ -186,11 +186,12 @@ func TestNextValueBig(t *testing.T) {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
var benchScan scanner
|
||||||
|
|
||||||
func BenchmarkSkipValue(b *testing.B) {
|
func BenchmarkSkipValue(b *testing.B) {
|
||||||
initBig()
|
initBig()
|
||||||
var scan scanner
|
|
||||||
for i := 0; i < b.N; i++ {
|
for i := 0; i < b.N; i++ {
|
||||||
nextValue(jsonBig, &scan)
|
nextValue(jsonBig, &benchScan)
|
||||||
}
|
}
|
||||||
b.SetBytes(int64(len(jsonBig)))
|
b.SetBytes(int64(len(jsonBig)))
|
||||||
}
|
}
|
||||||
|
|
|
@ -7,7 +7,6 @@ package xml
|
||||||
import (
|
import (
|
||||||
"bytes"
|
"bytes"
|
||||||
"io"
|
"io"
|
||||||
"os"
|
|
||||||
"reflect"
|
"reflect"
|
||||||
"strings"
|
"strings"
|
||||||
"testing"
|
"testing"
|
||||||
|
@ -43,17 +42,17 @@ var rawTokens = []Token{
|
||||||
CharData([]byte("World <>'\" 白鵬翔")),
|
CharData([]byte("World <>'\" 白鵬翔")),
|
||||||
EndElement{Name{"", "hello"}},
|
EndElement{Name{"", "hello"}},
|
||||||
CharData([]byte("\n ")),
|
CharData([]byte("\n ")),
|
||||||
StartElement{Name{"", "goodbye"}, nil},
|
StartElement{Name{"", "goodbye"}, []Attr{}},
|
||||||
EndElement{Name{"", "goodbye"}},
|
EndElement{Name{"", "goodbye"}},
|
||||||
CharData([]byte("\n ")),
|
CharData([]byte("\n ")),
|
||||||
StartElement{Name{"", "outer"}, []Attr{{Name{"foo", "attr"}, "value"}, {Name{"xmlns", "tag"}, "ns4"}}},
|
StartElement{Name{"", "outer"}, []Attr{{Name{"foo", "attr"}, "value"}, {Name{"xmlns", "tag"}, "ns4"}}},
|
||||||
CharData([]byte("\n ")),
|
CharData([]byte("\n ")),
|
||||||
StartElement{Name{"", "inner"}, nil},
|
StartElement{Name{"", "inner"}, []Attr{}},
|
||||||
EndElement{Name{"", "inner"}},
|
EndElement{Name{"", "inner"}},
|
||||||
CharData([]byte("\n ")),
|
CharData([]byte("\n ")),
|
||||||
EndElement{Name{"", "outer"}},
|
EndElement{Name{"", "outer"}},
|
||||||
CharData([]byte("\n ")),
|
CharData([]byte("\n ")),
|
||||||
StartElement{Name{"tag", "name"}, nil},
|
StartElement{Name{"tag", "name"}, []Attr{}},
|
||||||
CharData([]byte("\n ")),
|
CharData([]byte("\n ")),
|
||||||
CharData([]byte("Some text here.")),
|
CharData([]byte("Some text here.")),
|
||||||
CharData([]byte("\n ")),
|
CharData([]byte("\n ")),
|
||||||
|
@ -77,17 +76,17 @@ var cookedTokens = []Token{
|
||||||
CharData([]byte("World <>'\" 白鵬翔")),
|
CharData([]byte("World <>'\" 白鵬翔")),
|
||||||
EndElement{Name{"ns2", "hello"}},
|
EndElement{Name{"ns2", "hello"}},
|
||||||
CharData([]byte("\n ")),
|
CharData([]byte("\n ")),
|
||||||
StartElement{Name{"ns2", "goodbye"}, nil},
|
StartElement{Name{"ns2", "goodbye"}, []Attr{}},
|
||||||
EndElement{Name{"ns2", "goodbye"}},
|
EndElement{Name{"ns2", "goodbye"}},
|
||||||
CharData([]byte("\n ")),
|
CharData([]byte("\n ")),
|
||||||
StartElement{Name{"ns2", "outer"}, []Attr{{Name{"ns1", "attr"}, "value"}, {Name{"xmlns", "tag"}, "ns4"}}},
|
StartElement{Name{"ns2", "outer"}, []Attr{{Name{"ns1", "attr"}, "value"}, {Name{"xmlns", "tag"}, "ns4"}}},
|
||||||
CharData([]byte("\n ")),
|
CharData([]byte("\n ")),
|
||||||
StartElement{Name{"ns2", "inner"}, nil},
|
StartElement{Name{"ns2", "inner"}, []Attr{}},
|
||||||
EndElement{Name{"ns2", "inner"}},
|
EndElement{Name{"ns2", "inner"}},
|
||||||
CharData([]byte("\n ")),
|
CharData([]byte("\n ")),
|
||||||
EndElement{Name{"ns2", "outer"}},
|
EndElement{Name{"ns2", "outer"}},
|
||||||
CharData([]byte("\n ")),
|
CharData([]byte("\n ")),
|
||||||
StartElement{Name{"ns3", "name"}, nil},
|
StartElement{Name{"ns3", "name"}, []Attr{}},
|
||||||
CharData([]byte("\n ")),
|
CharData([]byte("\n ")),
|
||||||
CharData([]byte("Some text here.")),
|
CharData([]byte("Some text here.")),
|
||||||
CharData([]byte("\n ")),
|
CharData([]byte("\n ")),
|
||||||
|
@ -105,7 +104,7 @@ var rawTokensAltEncoding = []Token{
|
||||||
CharData([]byte("\n")),
|
CharData([]byte("\n")),
|
||||||
ProcInst{"xml", []byte(`version="1.0" encoding="x-testing-uppercase"`)},
|
ProcInst{"xml", []byte(`version="1.0" encoding="x-testing-uppercase"`)},
|
||||||
CharData([]byte("\n")),
|
CharData([]byte("\n")),
|
||||||
StartElement{Name{"", "tag"}, nil},
|
StartElement{Name{"", "tag"}, []Attr{}},
|
||||||
CharData([]byte("value")),
|
CharData([]byte("value")),
|
||||||
EndElement{Name{"", "tag"}},
|
EndElement{Name{"", "tag"}},
|
||||||
}
|
}
|
||||||
|
@ -205,7 +204,7 @@ func (d *downCaser) ReadByte() (c byte, err error) {
|
||||||
|
|
||||||
func (d *downCaser) Read(p []byte) (int, error) {
|
func (d *downCaser) Read(p []byte) (int, error) {
|
||||||
d.t.Fatalf("unexpected Read call on downCaser reader")
|
d.t.Fatalf("unexpected Read call on downCaser reader")
|
||||||
return 0, os.EINVAL
|
panic("unreachable")
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestRawTokenAltEncoding(t *testing.T) {
|
func TestRawTokenAltEncoding(t *testing.T) {
|
||||||
|
|
|
@ -105,9 +105,9 @@ func (w *Watcher) AddWatch(path string, flags uint32) error {
|
||||||
watchEntry.flags |= flags
|
watchEntry.flags |= flags
|
||||||
flags |= syscall.IN_MASK_ADD
|
flags |= syscall.IN_MASK_ADD
|
||||||
}
|
}
|
||||||
wd, errno := syscall.InotifyAddWatch(w.fd, path, flags)
|
wd, err := syscall.InotifyAddWatch(w.fd, path, flags)
|
||||||
if wd == -1 {
|
if err != nil {
|
||||||
return &os.PathError{"inotify_add_watch", path, os.Errno(errno)}
|
return &os.PathError{"inotify_add_watch", path, err}
|
||||||
}
|
}
|
||||||
|
|
||||||
if !found {
|
if !found {
|
||||||
|
@ -139,14 +139,10 @@ func (w *Watcher) RemoveWatch(path string) error {
|
||||||
// readEvents reads from the inotify file descriptor, converts the
|
// readEvents reads from the inotify file descriptor, converts the
|
||||||
// received events into Event objects and sends them via the Event channel
|
// received events into Event objects and sends them via the Event channel
|
||||||
func (w *Watcher) readEvents() {
|
func (w *Watcher) readEvents() {
|
||||||
var (
|
var buf [syscall.SizeofInotifyEvent * 4096]byte
|
||||||
buf [syscall.SizeofInotifyEvent * 4096]byte // Buffer for a maximum of 4096 raw events
|
|
||||||
n int // Number of bytes read with read()
|
|
||||||
errno int // Syscall errno
|
|
||||||
)
|
|
||||||
|
|
||||||
for {
|
for {
|
||||||
n, errno = syscall.Read(w.fd, buf[0:])
|
n, err := syscall.Read(w.fd, buf[0:])
|
||||||
// See if there is a message on the "done" channel
|
// See if there is a message on the "done" channel
|
||||||
var done bool
|
var done bool
|
||||||
select {
|
select {
|
||||||
|
@ -156,16 +152,16 @@ func (w *Watcher) readEvents() {
|
||||||
|
|
||||||
// If EOF or a "done" message is received
|
// If EOF or a "done" message is received
|
||||||
if n == 0 || done {
|
if n == 0 || done {
|
||||||
errno := syscall.Close(w.fd)
|
err := syscall.Close(w.fd)
|
||||||
if errno == -1 {
|
if err != nil {
|
||||||
w.Error <- os.NewSyscallError("close", errno)
|
w.Error <- os.NewSyscallError("close", err)
|
||||||
}
|
}
|
||||||
close(w.Event)
|
close(w.Event)
|
||||||
close(w.Error)
|
close(w.Error)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
if n < 0 {
|
if n < 0 {
|
||||||
w.Error <- os.NewSyscallError("read", errno)
|
w.Error <- os.NewSyscallError("read", err)
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
if n < syscall.SizeofInotifyEvent {
|
if n < syscall.SizeofInotifyEvent {
|
||||||
|
|
|
@ -14,6 +14,21 @@ import (
|
||||||
"strconv"
|
"strconv"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
// subsetTypeArgs takes a slice of arguments from callers of the sql
|
||||||
|
// package and converts them into a slice of the driver package's
|
||||||
|
// "subset types".
|
||||||
|
func subsetTypeArgs(args []interface{}) ([]interface{}, error) {
|
||||||
|
out := make([]interface{}, len(args))
|
||||||
|
for n, arg := range args {
|
||||||
|
var err error
|
||||||
|
out[n], err = driver.DefaultParameterConverter.ConvertValue(arg)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("sql: converting argument #%d's type: %v", n+1, err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return out, nil
|
||||||
|
}
|
||||||
|
|
||||||
// convertAssign copies to dest the value in src, converting it if possible.
|
// convertAssign copies to dest the value in src, converting it if possible.
|
||||||
// An error is returned if the copy would result in loss of information.
|
// An error is returned if the copy would result in loss of information.
|
||||||
// dest should be a pointer type.
|
// dest should be a pointer type.
|
||||||
|
|
|
@ -36,19 +36,22 @@ type Driver interface {
|
||||||
Open(name string) (Conn, error)
|
Open(name string) (Conn, error)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Execer is an optional interface that may be implemented by a Driver
|
// ErrSkip may be returned by some optional interfaces' methods to
|
||||||
// or a Conn.
|
// indicate at runtime that the fast path is unavailable and the sql
|
||||||
//
|
// package should continue as if the optional interface was not
|
||||||
// If a Driver does not implement Execer, the sql package's DB.Exec
|
// implemented. ErrSkip is only supported where explicitly
|
||||||
// method first obtains a free connection from its free pool or from
|
// documented.
|
||||||
// the driver's Open method. Execer should only be implemented by
|
var ErrSkip = errors.New("driver: skip fast-path; continue as if unimplemented")
|
||||||
// drivers that can provide a more efficient implementation.
|
|
||||||
|
// Execer is an optional interface that may be implemented by a Conn.
|
||||||
//
|
//
|
||||||
// If a Conn does not implement Execer, the db package's DB.Exec will
|
// If a Conn does not implement Execer, the db package's DB.Exec will
|
||||||
// first prepare a query, execute the statement, and then close the
|
// first prepare a query, execute the statement, and then close the
|
||||||
// statement.
|
// statement.
|
||||||
//
|
//
|
||||||
// All arguments are of a subset type as defined in the package docs.
|
// All arguments are of a subset type as defined in the package docs.
|
||||||
|
//
|
||||||
|
// Exec may return ErrSkip.
|
||||||
type Execer interface {
|
type Execer interface {
|
||||||
Exec(query string, args []interface{}) (Result, error)
|
Exec(query string, args []interface{}) (Result, error)
|
||||||
}
|
}
|
||||||
|
@ -94,6 +97,9 @@ type Stmt interface {
|
||||||
Close() error
|
Close() error
|
||||||
|
|
||||||
// NumInput returns the number of placeholder parameters.
|
// NumInput returns the number of placeholder parameters.
|
||||||
|
// -1 means the driver doesn't know how to count the number of
|
||||||
|
// placeholders, so we won't sanity check input here and instead let the
|
||||||
|
// driver deal with errors.
|
||||||
NumInput() int
|
NumInput() int
|
||||||
|
|
||||||
// Exec executes a query that doesn't return rows, such
|
// Exec executes a query that doesn't return rows, such
|
||||||
|
@ -135,6 +141,8 @@ type Rows interface {
|
||||||
// The dest slice may be populated with only with values
|
// The dest slice may be populated with only with values
|
||||||
// of subset types defined above, but excluding string.
|
// of subset types defined above, but excluding string.
|
||||||
// All string values must be converted to []byte.
|
// All string values must be converted to []byte.
|
||||||
|
//
|
||||||
|
// Next should return io.EOF when there are no more rows.
|
||||||
Next(dest []interface{}) error
|
Next(dest []interface{}) error
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -195,6 +195,29 @@ func (c *fakeConn) Close() error {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func checkSubsetTypes(args []interface{}) error {
|
||||||
|
for n, arg := range args {
|
||||||
|
switch arg.(type) {
|
||||||
|
case int64, float64, bool, nil, []byte, string:
|
||||||
|
default:
|
||||||
|
return fmt.Errorf("fakedb_test: invalid argument #%d: %v, type %T", n+1, arg, arg)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *fakeConn) Exec(query string, args []interface{}) (driver.Result, error) {
|
||||||
|
// This is an optional interface, but it's implemented here
|
||||||
|
// just to check that all the args of of the proper types.
|
||||||
|
// ErrSkip is returned so the caller acts as if we didn't
|
||||||
|
// implement this at all.
|
||||||
|
err := checkSubsetTypes(args)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
return nil, driver.ErrSkip
|
||||||
|
}
|
||||||
|
|
||||||
func errf(msg string, args ...interface{}) error {
|
func errf(msg string, args ...interface{}) error {
|
||||||
return errors.New("fakedb: " + fmt.Sprintf(msg, args...))
|
return errors.New("fakedb: " + fmt.Sprintf(msg, args...))
|
||||||
}
|
}
|
||||||
|
@ -323,6 +346,11 @@ func (s *fakeStmt) Close() error {
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *fakeStmt) Exec(args []interface{}) (driver.Result, error) {
|
func (s *fakeStmt) Exec(args []interface{}) (driver.Result, error) {
|
||||||
|
err := checkSubsetTypes(args)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
db := s.c.db
|
db := s.c.db
|
||||||
switch s.cmd {
|
switch s.cmd {
|
||||||
case "WIPE":
|
case "WIPE":
|
||||||
|
@ -377,6 +405,11 @@ func (s *fakeStmt) execInsert(args []interface{}) (driver.Result, error) {
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *fakeStmt) Query(args []interface{}) (driver.Rows, error) {
|
func (s *fakeStmt) Query(args []interface{}) (driver.Rows, error) {
|
||||||
|
err := checkSubsetTypes(args)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
db := s.c.db
|
db := s.c.db
|
||||||
if len(args) != s.placeholders {
|
if len(args) != s.placeholders {
|
||||||
panic("error in pkg db; should only get here if size is correct")
|
panic("error in pkg db; should only get here if size is correct")
|
||||||
|
|
|
@ -88,8 +88,9 @@ type DB struct {
|
||||||
driver driver.Driver
|
driver driver.Driver
|
||||||
dsn string
|
dsn string
|
||||||
|
|
||||||
mu sync.Mutex
|
mu sync.Mutex // protects freeConn and closed
|
||||||
freeConn []driver.Conn
|
freeConn []driver.Conn
|
||||||
|
closed bool
|
||||||
}
|
}
|
||||||
|
|
||||||
// Open opens a database specified by its database driver name and a
|
// Open opens a database specified by its database driver name and a
|
||||||
|
@ -106,6 +107,22 @@ func Open(driverName, dataSourceName string) (*DB, error) {
|
||||||
return &DB{driver: driver, dsn: dataSourceName}, nil
|
return &DB{driver: driver, dsn: dataSourceName}, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Close closes the database, releasing any open resources.
|
||||||
|
func (db *DB) Close() error {
|
||||||
|
db.mu.Lock()
|
||||||
|
defer db.mu.Unlock()
|
||||||
|
var err error
|
||||||
|
for _, c := range db.freeConn {
|
||||||
|
err1 := c.Close()
|
||||||
|
if err1 != nil {
|
||||||
|
err = err1
|
||||||
|
}
|
||||||
|
}
|
||||||
|
db.freeConn = nil
|
||||||
|
db.closed = true
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
func (db *DB) maxIdleConns() int {
|
func (db *DB) maxIdleConns() int {
|
||||||
const defaultMaxIdleConns = 2
|
const defaultMaxIdleConns = 2
|
||||||
// TODO(bradfitz): ask driver, if supported, for its default preference
|
// TODO(bradfitz): ask driver, if supported, for its default preference
|
||||||
|
@ -116,6 +133,9 @@ func (db *DB) maxIdleConns() int {
|
||||||
// conn returns a newly-opened or cached driver.Conn
|
// conn returns a newly-opened or cached driver.Conn
|
||||||
func (db *DB) conn() (driver.Conn, error) {
|
func (db *DB) conn() (driver.Conn, error) {
|
||||||
db.mu.Lock()
|
db.mu.Lock()
|
||||||
|
if db.closed {
|
||||||
|
return nil, errors.New("sql: database is closed")
|
||||||
|
}
|
||||||
if n := len(db.freeConn); n > 0 {
|
if n := len(db.freeConn); n > 0 {
|
||||||
conn := db.freeConn[n-1]
|
conn := db.freeConn[n-1]
|
||||||
db.freeConn = db.freeConn[:n-1]
|
db.freeConn = db.freeConn[:n-1]
|
||||||
|
@ -140,11 +160,13 @@ func (db *DB) connIfFree(wanted driver.Conn) (conn driver.Conn, ok bool) {
|
||||||
}
|
}
|
||||||
|
|
||||||
func (db *DB) putConn(c driver.Conn) {
|
func (db *DB) putConn(c driver.Conn) {
|
||||||
if n := len(db.freeConn); n < db.maxIdleConns() {
|
db.mu.Lock()
|
||||||
|
defer db.mu.Unlock()
|
||||||
|
if n := len(db.freeConn); !db.closed && n < db.maxIdleConns() {
|
||||||
db.freeConn = append(db.freeConn, c)
|
db.freeConn = append(db.freeConn, c)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
db.closeConn(c)
|
db.closeConn(c) // TODO(bradfitz): release lock before calling this?
|
||||||
}
|
}
|
||||||
|
|
||||||
func (db *DB) closeConn(c driver.Conn) {
|
func (db *DB) closeConn(c driver.Conn) {
|
||||||
|
@ -180,17 +202,11 @@ func (db *DB) Prepare(query string) (*Stmt, error) {
|
||||||
|
|
||||||
// Exec executes a query without returning any rows.
|
// Exec executes a query without returning any rows.
|
||||||
func (db *DB) Exec(query string, args ...interface{}) (Result, error) {
|
func (db *DB) Exec(query string, args ...interface{}) (Result, error) {
|
||||||
// Optional fast path, if the driver implements driver.Execer.
|
sargs, err := subsetTypeArgs(args)
|
||||||
if execer, ok := db.driver.(driver.Execer); ok {
|
|
||||||
resi, err := execer.Exec(query, args)
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
return result{resi}, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// If the driver does not implement driver.Execer, we need
|
|
||||||
// a connection.
|
|
||||||
ci, err := db.conn()
|
ci, err := db.conn()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
|
@ -198,19 +214,22 @@ func (db *DB) Exec(query string, args ...interface{}) (Result, error) {
|
||||||
defer db.putConn(ci)
|
defer db.putConn(ci)
|
||||||
|
|
||||||
if execer, ok := ci.(driver.Execer); ok {
|
if execer, ok := ci.(driver.Execer); ok {
|
||||||
resi, err := execer.Exec(query, args)
|
resi, err := execer.Exec(query, sargs)
|
||||||
|
if err != driver.ErrSkip {
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
return result{resi}, nil
|
return result{resi}, nil
|
||||||
}
|
}
|
||||||
|
}
|
||||||
|
|
||||||
sti, err := ci.Prepare(query)
|
sti, err := ci.Prepare(query)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
defer sti.Close()
|
defer sti.Close()
|
||||||
resi, err := sti.Exec(args)
|
|
||||||
|
resi, err := sti.Exec(sargs)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
@ -386,7 +405,13 @@ func (tx *Tx) Exec(query string, args ...interface{}) (Result, error) {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
defer sti.Close()
|
defer sti.Close()
|
||||||
resi, err := sti.Exec(args)
|
|
||||||
|
sargs, err := subsetTypeArgs(args)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
resi, err := sti.Exec(sargs)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
@ -449,7 +474,10 @@ func (s *Stmt) Exec(args ...interface{}) (Result, error) {
|
||||||
}
|
}
|
||||||
defer releaseConn()
|
defer releaseConn()
|
||||||
|
|
||||||
if want := si.NumInput(); len(args) != want {
|
// -1 means the driver doesn't know how to count the number of
|
||||||
|
// placeholders, so we won't sanity check input here and instead let the
|
||||||
|
// driver deal with errors.
|
||||||
|
if want := si.NumInput(); want != -1 && len(args) != want {
|
||||||
return nil, fmt.Errorf("db: expected %d arguments, got %d", want, len(args))
|
return nil, fmt.Errorf("db: expected %d arguments, got %d", want, len(args))
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -545,10 +573,18 @@ func (s *Stmt) Query(args ...interface{}) (*Rows, error) {
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
if len(args) != si.NumInput() {
|
|
||||||
|
// -1 means the driver doesn't know how to count the number of
|
||||||
|
// placeholders, so we won't sanity check input here and instead let the
|
||||||
|
// driver deal with errors.
|
||||||
|
if want := si.NumInput(); want != -1 && len(args) != want {
|
||||||
return nil, fmt.Errorf("db: statement expects %d inputs; got %d", si.NumInput(), len(args))
|
return nil, fmt.Errorf("db: statement expects %d inputs; got %d", si.NumInput(), len(args))
|
||||||
}
|
}
|
||||||
rowsi, err := si.Query(args)
|
sargs, err := subsetTypeArgs(args)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
rowsi, err := si.Query(sargs)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
s.db.putConn(ci)
|
s.db.putConn(ci)
|
||||||
return nil, err
|
return nil, err
|
||||||
|
|
|
@ -34,8 +34,16 @@ func exec(t *testing.T, db *DB, query string, args ...interface{}) {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func closeDB(t *testing.T, db *DB) {
|
||||||
|
err := db.Close()
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("error closing DB: %v", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
func TestQuery(t *testing.T) {
|
func TestQuery(t *testing.T) {
|
||||||
db := newTestDB(t, "people")
|
db := newTestDB(t, "people")
|
||||||
|
defer closeDB(t, db)
|
||||||
var name string
|
var name string
|
||||||
var age int
|
var age int
|
||||||
|
|
||||||
|
@ -69,6 +77,7 @@ func TestQuery(t *testing.T) {
|
||||||
|
|
||||||
func TestStatementQueryRow(t *testing.T) {
|
func TestStatementQueryRow(t *testing.T) {
|
||||||
db := newTestDB(t, "people")
|
db := newTestDB(t, "people")
|
||||||
|
defer closeDB(t, db)
|
||||||
stmt, err := db.Prepare("SELECT|people|age|name=?")
|
stmt, err := db.Prepare("SELECT|people|age|name=?")
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("Prepare: %v", err)
|
t.Fatalf("Prepare: %v", err)
|
||||||
|
@ -94,6 +103,7 @@ func TestStatementQueryRow(t *testing.T) {
|
||||||
// just a test of fakedb itself
|
// just a test of fakedb itself
|
||||||
func TestBogusPreboundParameters(t *testing.T) {
|
func TestBogusPreboundParameters(t *testing.T) {
|
||||||
db := newTestDB(t, "foo")
|
db := newTestDB(t, "foo")
|
||||||
|
defer closeDB(t, db)
|
||||||
exec(t, db, "CREATE|t1|name=string,age=int32,dead=bool")
|
exec(t, db, "CREATE|t1|name=string,age=int32,dead=bool")
|
||||||
_, err := db.Prepare("INSERT|t1|name=?,age=bogusconversion")
|
_, err := db.Prepare("INSERT|t1|name=?,age=bogusconversion")
|
||||||
if err == nil {
|
if err == nil {
|
||||||
|
@ -106,6 +116,7 @@ func TestBogusPreboundParameters(t *testing.T) {
|
||||||
|
|
||||||
func TestDb(t *testing.T) {
|
func TestDb(t *testing.T) {
|
||||||
db := newTestDB(t, "foo")
|
db := newTestDB(t, "foo")
|
||||||
|
defer closeDB(t, db)
|
||||||
exec(t, db, "CREATE|t1|name=string,age=int32,dead=bool")
|
exec(t, db, "CREATE|t1|name=string,age=int32,dead=bool")
|
||||||
stmt, err := db.Prepare("INSERT|t1|name=?,age=?")
|
stmt, err := db.Prepare("INSERT|t1|name=?,age=?")
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
|
|
@ -0,0 +1,88 @@
|
||||||
|
// Copyright 2011 The Go Authors. All rights reserved.
|
||||||
|
// Use of this source code is governed by a BSD-style
|
||||||
|
// license that can be found in the LICENSE file.
|
||||||
|
|
||||||
|
package ssh
|
||||||
|
|
||||||
|
import (
|
||||||
|
"crypto/aes"
|
||||||
|
"crypto/cipher"
|
||||||
|
"crypto/rc4"
|
||||||
|
)
|
||||||
|
|
||||||
|
// streamDump is used to dump the initial keystream for stream ciphers. It is a
|
||||||
|
// a write-only buffer, and not intended for reading so do not require a mutex.
|
||||||
|
var streamDump [512]byte
|
||||||
|
|
||||||
|
// noneCipher implements cipher.Stream and provides no encryption. It is used
|
||||||
|
// by the transport before the first key-exchange.
|
||||||
|
type noneCipher struct{}
|
||||||
|
|
||||||
|
func (c noneCipher) XORKeyStream(dst, src []byte) {
|
||||||
|
copy(dst, src)
|
||||||
|
}
|
||||||
|
|
||||||
|
func newAESCTR(key, iv []byte) (cipher.Stream, error) {
|
||||||
|
c, err := aes.NewCipher(key)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
return cipher.NewCTR(c, iv), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func newRC4(key, iv []byte) (cipher.Stream, error) {
|
||||||
|
return rc4.NewCipher(key)
|
||||||
|
}
|
||||||
|
|
||||||
|
type cipherMode struct {
|
||||||
|
keySize int
|
||||||
|
ivSize int
|
||||||
|
skip int
|
||||||
|
createFn func(key, iv []byte) (cipher.Stream, error)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *cipherMode) createCipher(key, iv []byte) (cipher.Stream, error) {
|
||||||
|
if len(key) < c.keySize {
|
||||||
|
panic("ssh: key length too small for cipher")
|
||||||
|
}
|
||||||
|
if len(iv) < c.ivSize {
|
||||||
|
panic("ssh: iv too small for cipher")
|
||||||
|
}
|
||||||
|
|
||||||
|
stream, err := c.createFn(key[:c.keySize], iv[:c.ivSize])
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
for remainingToDump := c.skip; remainingToDump > 0; {
|
||||||
|
dumpThisTime := remainingToDump
|
||||||
|
if dumpThisTime > len(streamDump) {
|
||||||
|
dumpThisTime = len(streamDump)
|
||||||
|
}
|
||||||
|
stream.XORKeyStream(streamDump[:dumpThisTime], streamDump[:dumpThisTime])
|
||||||
|
remainingToDump -= dumpThisTime
|
||||||
|
}
|
||||||
|
|
||||||
|
return stream, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Specifies a default set of ciphers and a preference order. This is based on
|
||||||
|
// OpenSSH's default client preference order, minus algorithms that are not
|
||||||
|
// implemented.
|
||||||
|
var DefaultCipherOrder = []string{
|
||||||
|
"aes128-ctr", "aes192-ctr", "aes256-ctr",
|
||||||
|
"arcfour256", "arcfour128",
|
||||||
|
}
|
||||||
|
|
||||||
|
var cipherModes = map[string]*cipherMode{
|
||||||
|
// Ciphers from RFC4344, which introduced many CTR-based ciphers. Algorithms
|
||||||
|
// are defined in the order specified in the RFC.
|
||||||
|
"aes128-ctr": &cipherMode{16, aes.BlockSize, 0, newAESCTR},
|
||||||
|
"aes192-ctr": &cipherMode{24, aes.BlockSize, 0, newAESCTR},
|
||||||
|
"aes256-ctr": &cipherMode{32, aes.BlockSize, 0, newAESCTR},
|
||||||
|
|
||||||
|
// Ciphers from RFC4345, which introduces security-improved arcfour ciphers.
|
||||||
|
// They are defined in the order specified in the RFC.
|
||||||
|
"arcfour128": &cipherMode{16, 0, 1536, newRC4},
|
||||||
|
"arcfour256": &cipherMode{32, 0, 1536, newRC4},
|
||||||
|
}
|
|
@ -0,0 +1,62 @@
|
||||||
|
// Copyright 2011 The Go Authors. All rights reserved.
|
||||||
|
// Use of this source code is governed by a BSD-style
|
||||||
|
// license that can be found in the LICENSE file.
|
||||||
|
|
||||||
|
package ssh
|
||||||
|
|
||||||
|
import (
|
||||||
|
"bytes"
|
||||||
|
"testing"
|
||||||
|
)
|
||||||
|
|
||||||
|
// TestCipherReversal tests that each cipher factory produces ciphers that can
|
||||||
|
// encrypt and decrypt some data successfully.
|
||||||
|
func TestCipherReversal(t *testing.T) {
|
||||||
|
testData := []byte("abcdefghijklmnopqrstuvwxyz012345")
|
||||||
|
testKey := []byte("AbCdEfGhIjKlMnOpQrStUvWxYz012345")
|
||||||
|
testIv := []byte("sdflkjhsadflkjhasdflkjhsadfklhsa")
|
||||||
|
|
||||||
|
cryptBuffer := make([]byte, 32)
|
||||||
|
|
||||||
|
for name, cipherMode := range cipherModes {
|
||||||
|
encrypter, err := cipherMode.createCipher(testKey, testIv)
|
||||||
|
if err != nil {
|
||||||
|
t.Errorf("failed to create encrypter for %q: %s", name, err)
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
decrypter, err := cipherMode.createCipher(testKey, testIv)
|
||||||
|
if err != nil {
|
||||||
|
t.Errorf("failed to create decrypter for %q: %s", name, err)
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
copy(cryptBuffer, testData)
|
||||||
|
|
||||||
|
encrypter.XORKeyStream(cryptBuffer, cryptBuffer)
|
||||||
|
if name == "none" {
|
||||||
|
if !bytes.Equal(cryptBuffer, testData) {
|
||||||
|
t.Errorf("encryption made change with 'none' cipher")
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
if bytes.Equal(cryptBuffer, testData) {
|
||||||
|
t.Errorf("encryption made no change with %q", name)
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
decrypter.XORKeyStream(cryptBuffer, cryptBuffer)
|
||||||
|
if !bytes.Equal(cryptBuffer, testData) {
|
||||||
|
t.Errorf("decrypted bytes not equal to input with %q", name)
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestDefaultCiphersExist(t *testing.T) {
|
||||||
|
for _, cipherAlgo := range DefaultCipherOrder {
|
||||||
|
if _, ok := cipherModes[cipherAlgo]; !ok {
|
||||||
|
t.Errorf("default cipher %q is unknown", cipherAlgo)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
|
@ -35,10 +35,6 @@ func Client(c net.Conn, config *ClientConfig) (*ClientConn, error) {
|
||||||
conn.Close()
|
conn.Close()
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
if err := conn.authenticate(); err != nil {
|
|
||||||
conn.Close()
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
go conn.mainLoop()
|
go conn.mainLoop()
|
||||||
return conn, nil
|
return conn, nil
|
||||||
}
|
}
|
||||||
|
@ -64,8 +60,8 @@ func (c *ClientConn) handshake() error {
|
||||||
clientKexInit := kexInitMsg{
|
clientKexInit := kexInitMsg{
|
||||||
KexAlgos: supportedKexAlgos,
|
KexAlgos: supportedKexAlgos,
|
||||||
ServerHostKeyAlgos: supportedHostKeyAlgos,
|
ServerHostKeyAlgos: supportedHostKeyAlgos,
|
||||||
CiphersClientServer: supportedCiphers,
|
CiphersClientServer: c.config.Crypto.ciphers(),
|
||||||
CiphersServerClient: supportedCiphers,
|
CiphersServerClient: c.config.Crypto.ciphers(),
|
||||||
MACsClientServer: supportedMACs,
|
MACsClientServer: supportedMACs,
|
||||||
MACsServerClient: supportedMACs,
|
MACsServerClient: supportedMACs,
|
||||||
CompressionClientServer: supportedCompressions,
|
CompressionClientServer: supportedCompressions,
|
||||||
|
@ -128,7 +124,10 @@ func (c *ClientConn) handshake() error {
|
||||||
if packet[0] != msgNewKeys {
|
if packet[0] != msgNewKeys {
|
||||||
return UnexpectedMessageError{msgNewKeys, packet[0]}
|
return UnexpectedMessageError{msgNewKeys, packet[0]}
|
||||||
}
|
}
|
||||||
return c.transport.reader.setupKeys(serverKeys, K, H, H, hashFunc)
|
if err := c.transport.reader.setupKeys(serverKeys, K, H, H, hashFunc); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
return c.authenticate(H)
|
||||||
}
|
}
|
||||||
|
|
||||||
// kexDH performs Diffie-Hellman key agreement on a ClientConn. The
|
// kexDH performs Diffie-Hellman key agreement on a ClientConn. The
|
||||||
|
@ -195,6 +194,7 @@ func (c *ClientConn) openChan(typ string) (*clientChan, error) {
|
||||||
switch msg := (<-ch.msg).(type) {
|
switch msg := (<-ch.msg).(type) {
|
||||||
case *channelOpenConfirmMsg:
|
case *channelOpenConfirmMsg:
|
||||||
ch.peersId = msg.MyId
|
ch.peersId = msg.MyId
|
||||||
|
ch.win <- int(msg.MyWindow)
|
||||||
case *channelOpenFailureMsg:
|
case *channelOpenFailureMsg:
|
||||||
c.chanlist.remove(ch.id)
|
c.chanlist.remove(ch.id)
|
||||||
return nil, errors.New(msg.Message)
|
return nil, errors.New(msg.Message)
|
||||||
|
@ -301,6 +301,9 @@ type ClientConfig struct {
|
||||||
// A slice of ClientAuth methods. Only the first instance
|
// A slice of ClientAuth methods. Only the first instance
|
||||||
// of a particular RFC 4252 method will be used during authentication.
|
// of a particular RFC 4252 method will be used during authentication.
|
||||||
Auth []ClientAuth
|
Auth []ClientAuth
|
||||||
|
|
||||||
|
// Cryptographic-related configuration.
|
||||||
|
Crypto CryptoConfig
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *ClientConfig) rand() io.Reader {
|
func (c *ClientConfig) rand() io.Reader {
|
||||||
|
|
|
@ -6,10 +6,11 @@ package ssh
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"errors"
|
"errors"
|
||||||
|
"io"
|
||||||
)
|
)
|
||||||
|
|
||||||
// authenticate authenticates with the remote server. See RFC 4252.
|
// authenticate authenticates with the remote server. See RFC 4252.
|
||||||
func (c *ClientConn) authenticate() error {
|
func (c *ClientConn) authenticate(session []byte) error {
|
||||||
// initiate user auth session
|
// initiate user auth session
|
||||||
if err := c.writePacket(marshal(msgServiceRequest, serviceRequestMsg{serviceUserAuth})); err != nil {
|
if err := c.writePacket(marshal(msgServiceRequest, serviceRequestMsg{serviceUserAuth})); err != nil {
|
||||||
return err
|
return err
|
||||||
|
@ -26,7 +27,7 @@ func (c *ClientConn) authenticate() error {
|
||||||
// then any untried methods suggested by the server.
|
// then any untried methods suggested by the server.
|
||||||
tried, remain := make(map[string]bool), make(map[string]bool)
|
tried, remain := make(map[string]bool), make(map[string]bool)
|
||||||
for auth := ClientAuth(new(noneAuth)); auth != nil; {
|
for auth := ClientAuth(new(noneAuth)); auth != nil; {
|
||||||
ok, methods, err := auth.auth(c.config.User, c.transport)
|
ok, methods, err := auth.auth(session, c.config.User, c.transport, c.config.rand())
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
@ -60,7 +61,7 @@ type ClientAuth interface {
|
||||||
// Returns true if authentication is successful.
|
// Returns true if authentication is successful.
|
||||||
// If authentication is not successful, a []string of alternative
|
// If authentication is not successful, a []string of alternative
|
||||||
// method names is returned.
|
// method names is returned.
|
||||||
auth(user string, t *transport) (bool, []string, error)
|
auth(session []byte, user string, t *transport, rand io.Reader) (bool, []string, error)
|
||||||
|
|
||||||
// method returns the RFC 4252 method name.
|
// method returns the RFC 4252 method name.
|
||||||
method() string
|
method() string
|
||||||
|
@ -69,7 +70,7 @@ type ClientAuth interface {
|
||||||
// "none" authentication, RFC 4252 section 5.2.
|
// "none" authentication, RFC 4252 section 5.2.
|
||||||
type noneAuth int
|
type noneAuth int
|
||||||
|
|
||||||
func (n *noneAuth) auth(user string, t *transport) (bool, []string, error) {
|
func (n *noneAuth) auth(session []byte, user string, t *transport, rand io.Reader) (bool, []string, error) {
|
||||||
if err := t.writePacket(marshal(msgUserAuthRequest, userAuthRequestMsg{
|
if err := t.writePacket(marshal(msgUserAuthRequest, userAuthRequestMsg{
|
||||||
User: user,
|
User: user,
|
||||||
Service: serviceSSH,
|
Service: serviceSSH,
|
||||||
|
@ -102,7 +103,7 @@ type passwordAuth struct {
|
||||||
ClientPassword
|
ClientPassword
|
||||||
}
|
}
|
||||||
|
|
||||||
func (p *passwordAuth) auth(user string, t *transport) (bool, []string, error) {
|
func (p *passwordAuth) auth(session []byte, user string, t *transport, rand io.Reader) (bool, []string, error) {
|
||||||
type passwordAuthMsg struct {
|
type passwordAuthMsg struct {
|
||||||
User string
|
User string
|
||||||
Service string
|
Service string
|
||||||
|
@ -155,3 +156,140 @@ type ClientPassword interface {
|
||||||
func ClientAuthPassword(impl ClientPassword) ClientAuth {
|
func ClientAuthPassword(impl ClientPassword) ClientAuth {
|
||||||
return &passwordAuth{impl}
|
return &passwordAuth{impl}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// ClientKeyring implements access to a client key ring.
|
||||||
|
type ClientKeyring interface {
|
||||||
|
// Key returns the i'th rsa.Publickey or dsa.Publickey, or nil if
|
||||||
|
// no key exists at i.
|
||||||
|
Key(i int) (key interface{}, err error)
|
||||||
|
|
||||||
|
// Sign returns a signature of the given data using the i'th key
|
||||||
|
// and the supplied random source.
|
||||||
|
Sign(i int, rand io.Reader, data []byte) (sig []byte, err error)
|
||||||
|
}
|
||||||
|
|
||||||
|
// "publickey" authentication, RFC 4252 Section 7.
|
||||||
|
type publickeyAuth struct {
|
||||||
|
ClientKeyring
|
||||||
|
}
|
||||||
|
|
||||||
|
func (p *publickeyAuth) auth(session []byte, user string, t *transport, rand io.Reader) (bool, []string, error) {
|
||||||
|
type publickeyAuthMsg struct {
|
||||||
|
User string
|
||||||
|
Service string
|
||||||
|
Method string
|
||||||
|
// HasSig indicates to the reciver packet that the auth request is signed and
|
||||||
|
// should be used for authentication of the request.
|
||||||
|
HasSig bool
|
||||||
|
Algoname string
|
||||||
|
Pubkey string
|
||||||
|
// Sig is defined as []byte so marshal will exclude it during the query phase
|
||||||
|
Sig []byte `ssh:"rest"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// Authentication is performed in two stages. The first stage sends an
|
||||||
|
// enquiry to test if each key is acceptable to the remote. The second
|
||||||
|
// stage attempts to authenticate with the valid keys obtained in the
|
||||||
|
// first stage.
|
||||||
|
|
||||||
|
var index int
|
||||||
|
// a map of public keys to their index in the keyring
|
||||||
|
validKeys := make(map[int]interface{})
|
||||||
|
for {
|
||||||
|
key, err := p.Key(index)
|
||||||
|
if err != nil {
|
||||||
|
return false, nil, err
|
||||||
|
}
|
||||||
|
if key == nil {
|
||||||
|
// no more keys in the keyring
|
||||||
|
break
|
||||||
|
}
|
||||||
|
pubkey := serializePublickey(key)
|
||||||
|
algoname := algoName(key)
|
||||||
|
msg := publickeyAuthMsg{
|
||||||
|
User: user,
|
||||||
|
Service: serviceSSH,
|
||||||
|
Method: p.method(),
|
||||||
|
HasSig: false,
|
||||||
|
Algoname: algoname,
|
||||||
|
Pubkey: string(pubkey),
|
||||||
|
}
|
||||||
|
if err := t.writePacket(marshal(msgUserAuthRequest, msg)); err != nil {
|
||||||
|
return false, nil, err
|
||||||
|
}
|
||||||
|
packet, err := t.readPacket()
|
||||||
|
if err != nil {
|
||||||
|
return false, nil, err
|
||||||
|
}
|
||||||
|
switch packet[0] {
|
||||||
|
case msgUserAuthPubKeyOk:
|
||||||
|
msg := decode(packet).(*userAuthPubKeyOkMsg)
|
||||||
|
if msg.Algo != algoname || msg.PubKey != string(pubkey) {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
validKeys[index] = key
|
||||||
|
case msgUserAuthFailure:
|
||||||
|
default:
|
||||||
|
return false, nil, UnexpectedMessageError{msgUserAuthSuccess, packet[0]}
|
||||||
|
}
|
||||||
|
index++
|
||||||
|
}
|
||||||
|
|
||||||
|
// methods that may continue if this auth is not successful.
|
||||||
|
var methods []string
|
||||||
|
for i, key := range validKeys {
|
||||||
|
pubkey := serializePublickey(key)
|
||||||
|
algoname := algoName(key)
|
||||||
|
sign, err := p.Sign(i, rand, buildDataSignedForAuth(session, userAuthRequestMsg{
|
||||||
|
User: user,
|
||||||
|
Service: serviceSSH,
|
||||||
|
Method: p.method(),
|
||||||
|
}, []byte(algoname), pubkey))
|
||||||
|
if err != nil {
|
||||||
|
return false, nil, err
|
||||||
|
}
|
||||||
|
// manually wrap the serialized signature in a string
|
||||||
|
s := serializeSignature(algoname, sign)
|
||||||
|
sig := make([]byte, stringLength(s))
|
||||||
|
marshalString(sig, s)
|
||||||
|
msg := publickeyAuthMsg{
|
||||||
|
User: user,
|
||||||
|
Service: serviceSSH,
|
||||||
|
Method: p.method(),
|
||||||
|
HasSig: true,
|
||||||
|
Algoname: algoname,
|
||||||
|
Pubkey: string(pubkey),
|
||||||
|
Sig: sig,
|
||||||
|
}
|
||||||
|
p := marshal(msgUserAuthRequest, msg)
|
||||||
|
if err := t.writePacket(p); err != nil {
|
||||||
|
return false, nil, err
|
||||||
|
}
|
||||||
|
packet, err := t.readPacket()
|
||||||
|
if err != nil {
|
||||||
|
return false, nil, err
|
||||||
|
}
|
||||||
|
switch packet[0] {
|
||||||
|
case msgUserAuthSuccess:
|
||||||
|
return true, nil, nil
|
||||||
|
case msgUserAuthFailure:
|
||||||
|
msg := decode(packet).(*userAuthFailureMsg)
|
||||||
|
methods = msg.Methods
|
||||||
|
continue
|
||||||
|
case msgDisconnect:
|
||||||
|
return false, nil, io.EOF
|
||||||
|
default:
|
||||||
|
return false, nil, UnexpectedMessageError{msgUserAuthSuccess, packet[0]}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return false, methods, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (p *publickeyAuth) method() string {
|
||||||
|
return "publickey"
|
||||||
|
}
|
||||||
|
|
||||||
|
// ClientAuthPublickey returns a ClientAuth using public key authentication.
|
||||||
|
func ClientAuthPublickey(impl ClientKeyring) ClientAuth {
|
||||||
|
return &publickeyAuth{impl}
|
||||||
|
}
|
||||||
|
|
|
@ -0,0 +1,248 @@
|
||||||
|
// Copyright 2011 The Go Authors. All rights reserved.
|
||||||
|
// Use of this source code is governed by a BSD-style
|
||||||
|
// license that can be found in the LICENSE file.
|
||||||
|
|
||||||
|
package ssh
|
||||||
|
|
||||||
|
import (
|
||||||
|
"bytes"
|
||||||
|
"crypto"
|
||||||
|
"crypto/rand"
|
||||||
|
"crypto/rsa"
|
||||||
|
"crypto/x509"
|
||||||
|
"encoding/pem"
|
||||||
|
"errors"
|
||||||
|
"io"
|
||||||
|
"io/ioutil"
|
||||||
|
"testing"
|
||||||
|
)
|
||||||
|
|
||||||
|
const _pem = `-----BEGIN RSA PRIVATE KEY-----
|
||||||
|
MIIEpAIBAAKCAQEA19lGVsTqIT5iiNYRgnoY1CwkbETW5cq+Rzk5v/kTlf31XpSU
|
||||||
|
70HVWkbTERECjaYdXM2gGcbb+sxpq6GtXf1M3kVomycqhxwhPv4Cr6Xp4WT/jkFx
|
||||||
|
9z+FFzpeodGJWjOH6L2H5uX1Cvr9EDdQp9t9/J32/qBFntY8GwoUI/y/1MSTmMiF
|
||||||
|
tupdMODN064vd3gyMKTwrlQ8tZM6aYuyOPsutLlUY7M5x5FwMDYvnPDSeyT/Iw0z
|
||||||
|
s3B+NCyqeeMd2T7YzQFnRATj0M7rM5LoSs7DVqVriOEABssFyLj31PboaoLhOKgc
|
||||||
|
qoM9khkNzr7FHVvi+DhYM2jD0DwvqZLN6NmnLwIDAQABAoIBAQCGVj+kuSFOV1lT
|
||||||
|
+IclQYA6bM6uY5mroqcSBNegVxCNhWU03BxlW//BE9tA/+kq53vWylMeN9mpGZea
|
||||||
|
riEMIh25KFGWXqXlOOioH8bkMsqA8S7sBmc7jljyv+0toQ9vCCtJ+sueNPhxQQxH
|
||||||
|
D2YvUjfzBQ04I9+wn30BByDJ1QA/FoPsunxIOUCcRBE/7jxuLYcpR+JvEF68yYIh
|
||||||
|
atXRld4W4in7T65YDR8jK1Uj9XAcNeDYNpT/M6oFLx1aPIlkG86aCWRO19S1jLPT
|
||||||
|
b1ZAKHHxPMCVkSYW0RqvIgLXQOR62D0Zne6/2wtzJkk5UCjkSQ2z7ZzJpMkWgDgN
|
||||||
|
ifCULFPBAoGBAPoMZ5q1w+zB+knXUD33n1J+niN6TZHJulpf2w5zsW+m2K6Zn62M
|
||||||
|
MXndXlVAHtk6p02q9kxHdgov34Uo8VpuNjbS1+abGFTI8NZgFo+bsDxJdItemwC4
|
||||||
|
KJ7L1iz39hRN/ZylMRLz5uTYRGddCkeIHhiG2h7zohH/MaYzUacXEEy3AoGBANz8
|
||||||
|
e/msleB+iXC0cXKwds26N4hyMdAFE5qAqJXvV3S2W8JZnmU+sS7vPAWMYPlERPk1
|
||||||
|
D8Q2eXqdPIkAWBhrx4RxD7rNc5qFNcQWEhCIxC9fccluH1y5g2M+4jpMX2CT8Uv+
|
||||||
|
3z+NoJ5uDTXZTnLCfoZzgZ4nCZVZ+6iU5U1+YXFJAoGBANLPpIV920n/nJmmquMj
|
||||||
|
orI1R/QXR9Cy56cMC65agezlGOfTYxk5Cfl5Ve+/2IJCfgzwJyjWUsFx7RviEeGw
|
||||||
|
64o7JoUom1HX+5xxdHPsyZ96OoTJ5RqtKKoApnhRMamau0fWydH1yeOEJd+TRHhc
|
||||||
|
XStGfhz8QNa1dVFvENczja1vAoGABGWhsd4VPVpHMc7lUvrf4kgKQtTC2PjA4xoc
|
||||||
|
QJ96hf/642sVE76jl+N6tkGMzGjnVm4P2j+bOy1VvwQavKGoXqJBRd5Apppv727g
|
||||||
|
/SM7hBXKFc/zH80xKBBgP/i1DR7kdjakCoeu4ngeGywvu2jTS6mQsqzkK+yWbUxJ
|
||||||
|
I7mYBsECgYB/KNXlTEpXtz/kwWCHFSYA8U74l7zZbVD8ul0e56JDK+lLcJ0tJffk
|
||||||
|
gqnBycHj6AhEycjda75cs+0zybZvN4x65KZHOGW/O/7OAWEcZP5TPb3zf9ned3Hl
|
||||||
|
NsZoFj52ponUM6+99A2CmezFCN16c4mbA//luWF+k3VVqR6BpkrhKw==
|
||||||
|
-----END RSA PRIVATE KEY-----`
|
||||||
|
|
||||||
|
// reused internally by tests
|
||||||
|
var serverConfig = new(ServerConfig)
|
||||||
|
|
||||||
|
func init() {
|
||||||
|
if err := serverConfig.SetRSAPrivateKey([]byte(_pem)); err != nil {
|
||||||
|
panic("unable to set private key: " + err.Error())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// keychain implements the ClientPublickey interface
|
||||||
|
type keychain struct {
|
||||||
|
keys []*rsa.PrivateKey
|
||||||
|
}
|
||||||
|
|
||||||
|
func (k *keychain) Key(i int) (interface{}, error) {
|
||||||
|
if i < 0 || i >= len(k.keys) {
|
||||||
|
return nil, nil
|
||||||
|
}
|
||||||
|
return k.keys[i].PublicKey, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (k *keychain) Sign(i int, rand io.Reader, data []byte) (sig []byte, err error) {
|
||||||
|
hashFunc := crypto.SHA1
|
||||||
|
h := hashFunc.New()
|
||||||
|
h.Write(data)
|
||||||
|
digest := h.Sum()
|
||||||
|
return rsa.SignPKCS1v15(rand, k.keys[i], hashFunc, digest)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (k *keychain) loadPEM(file string) error {
|
||||||
|
buf, err := ioutil.ReadFile(file)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
block, _ := pem.Decode(buf)
|
||||||
|
if block == nil {
|
||||||
|
return errors.New("ssh: no key found")
|
||||||
|
}
|
||||||
|
r, err := x509.ParsePKCS1PrivateKey(block.Bytes)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
k.keys = append(k.keys, r)
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
var pkey *rsa.PrivateKey
|
||||||
|
|
||||||
|
func init() {
|
||||||
|
var err error
|
||||||
|
pkey, err = rsa.GenerateKey(rand.Reader, 512)
|
||||||
|
if err != nil {
|
||||||
|
panic("unable to generate public key")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestClientAuthPublickey(t *testing.T) {
|
||||||
|
k := new(keychain)
|
||||||
|
k.keys = append(k.keys, pkey)
|
||||||
|
|
||||||
|
serverConfig.PubKeyCallback = func(user, algo string, pubkey []byte) bool {
|
||||||
|
expected := []byte(serializePublickey(k.keys[0].PublicKey))
|
||||||
|
algoname := algoName(k.keys[0].PublicKey)
|
||||||
|
return user == "testuser" && algo == algoname && bytes.Equal(pubkey, expected)
|
||||||
|
}
|
||||||
|
serverConfig.PasswordCallback = nil
|
||||||
|
|
||||||
|
l, err := Listen("tcp", "127.0.0.1:0", serverConfig)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("unable to listen: %s", err)
|
||||||
|
}
|
||||||
|
defer l.Close()
|
||||||
|
|
||||||
|
done := make(chan bool, 1)
|
||||||
|
go func() {
|
||||||
|
c, err := l.Accept()
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
defer c.Close()
|
||||||
|
if err := c.Handshake(); err != nil {
|
||||||
|
t.Error(err)
|
||||||
|
}
|
||||||
|
done <- true
|
||||||
|
}()
|
||||||
|
|
||||||
|
config := &ClientConfig{
|
||||||
|
User: "testuser",
|
||||||
|
Auth: []ClientAuth{
|
||||||
|
ClientAuthPublickey(k),
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
c, err := Dial("tcp", l.Addr().String(), config)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("unable to dial remote side: %s", err)
|
||||||
|
}
|
||||||
|
defer c.Close()
|
||||||
|
<-done
|
||||||
|
}
|
||||||
|
|
||||||
|
// password implements the ClientPassword interface
|
||||||
|
type password string
|
||||||
|
|
||||||
|
func (p password) Password(user string) (string, error) {
|
||||||
|
return string(p), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestClientAuthPassword(t *testing.T) {
|
||||||
|
pw := password("tiger")
|
||||||
|
|
||||||
|
serverConfig.PasswordCallback = func(user, pass string) bool {
|
||||||
|
return user == "testuser" && pass == string(pw)
|
||||||
|
}
|
||||||
|
serverConfig.PubKeyCallback = nil
|
||||||
|
|
||||||
|
l, err := Listen("tcp", "127.0.0.1:0", serverConfig)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("unable to listen: %s", err)
|
||||||
|
}
|
||||||
|
defer l.Close()
|
||||||
|
|
||||||
|
done := make(chan bool)
|
||||||
|
go func() {
|
||||||
|
c, err := l.Accept()
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
if err := c.Handshake(); err != nil {
|
||||||
|
t.Error(err)
|
||||||
|
}
|
||||||
|
defer c.Close()
|
||||||
|
done <- true
|
||||||
|
}()
|
||||||
|
|
||||||
|
config := &ClientConfig{
|
||||||
|
User: "testuser",
|
||||||
|
Auth: []ClientAuth{
|
||||||
|
ClientAuthPassword(pw),
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
c, err := Dial("tcp", l.Addr().String(), config)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("unable to dial remote side: %s", err)
|
||||||
|
}
|
||||||
|
defer c.Close()
|
||||||
|
<-done
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestClientAuthPasswordAndPublickey(t *testing.T) {
|
||||||
|
pw := password("tiger")
|
||||||
|
|
||||||
|
serverConfig.PasswordCallback = func(user, pass string) bool {
|
||||||
|
return user == "testuser" && pass == string(pw)
|
||||||
|
}
|
||||||
|
|
||||||
|
k := new(keychain)
|
||||||
|
k.keys = append(k.keys, pkey)
|
||||||
|
|
||||||
|
serverConfig.PubKeyCallback = func(user, algo string, pubkey []byte) bool {
|
||||||
|
expected := []byte(serializePublickey(k.keys[0].PublicKey))
|
||||||
|
algoname := algoName(k.keys[0].PublicKey)
|
||||||
|
return user == "testuser" && algo == algoname && bytes.Equal(pubkey, expected)
|
||||||
|
}
|
||||||
|
|
||||||
|
l, err := Listen("tcp", "127.0.0.1:0", serverConfig)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("unable to listen: %s", err)
|
||||||
|
}
|
||||||
|
defer l.Close()
|
||||||
|
|
||||||
|
done := make(chan bool)
|
||||||
|
go func() {
|
||||||
|
c, err := l.Accept()
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
if err := c.Handshake(); err != nil {
|
||||||
|
t.Error(err)
|
||||||
|
}
|
||||||
|
defer c.Close()
|
||||||
|
done <- true
|
||||||
|
}()
|
||||||
|
|
||||||
|
wrongPw := password("wrong")
|
||||||
|
config := &ClientConfig{
|
||||||
|
User: "testuser",
|
||||||
|
Auth: []ClientAuth{
|
||||||
|
ClientAuthPassword(wrongPw),
|
||||||
|
ClientAuthPublickey(k),
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
c, err := Dial("tcp", l.Addr().String(), config)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("unable to dial remote side: %s", err)
|
||||||
|
}
|
||||||
|
defer c.Close()
|
||||||
|
<-done
|
||||||
|
}
|
|
@ -0,0 +1,61 @@
|
||||||
|
// Copyright 2011 The Go Authors. All rights reserved.
|
||||||
|
// Use of this source code is governed by a BSD-style
|
||||||
|
// license that can be found in the LICENSE file.
|
||||||
|
|
||||||
|
package ssh
|
||||||
|
|
||||||
|
// ClientConn functional tests.
|
||||||
|
// These tests require a running ssh server listening on port 22
|
||||||
|
// on the local host. Functional tests will be skipped unless
|
||||||
|
// -ssh.user and -ssh.pass must be passed to gotest.
|
||||||
|
|
||||||
|
import (
|
||||||
|
"flag"
|
||||||
|
"testing"
|
||||||
|
)
|
||||||
|
|
||||||
|
var (
|
||||||
|
sshuser = flag.String("ssh.user", "", "ssh username")
|
||||||
|
sshpass = flag.String("ssh.pass", "", "ssh password")
|
||||||
|
sshprivkey = flag.String("ssh.privkey", "", "ssh privkey file")
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestFuncPasswordAuth(t *testing.T) {
|
||||||
|
if *sshuser == "" {
|
||||||
|
t.Log("ssh.user not defined, skipping test")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
config := &ClientConfig{
|
||||||
|
User: *sshuser,
|
||||||
|
Auth: []ClientAuth{
|
||||||
|
ClientAuthPassword(password(*sshpass)),
|
||||||
|
},
|
||||||
|
}
|
||||||
|
conn, err := Dial("tcp", "localhost:22", config)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Unable to connect: %s", err)
|
||||||
|
}
|
||||||
|
defer conn.Close()
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestFuncPublickeyAuth(t *testing.T) {
|
||||||
|
if *sshuser == "" {
|
||||||
|
t.Log("ssh.user not defined, skipping test")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
kc := new(keychain)
|
||||||
|
if err := kc.loadPEM(*sshprivkey); err != nil {
|
||||||
|
t.Fatalf("unable to load private key: %s", err)
|
||||||
|
}
|
||||||
|
config := &ClientConfig{
|
||||||
|
User: *sshuser,
|
||||||
|
Auth: []ClientAuth{
|
||||||
|
ClientAuthPublickey(kc),
|
||||||
|
},
|
||||||
|
}
|
||||||
|
conn, err := Dial("tcp", "localhost:22", config)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("unable to connect: %s", err)
|
||||||
|
}
|
||||||
|
defer conn.Close()
|
||||||
|
}
|
|
@ -5,6 +5,8 @@
|
||||||
package ssh
|
package ssh
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"crypto/dsa"
|
||||||
|
"crypto/rsa"
|
||||||
"math/big"
|
"math/big"
|
||||||
"strconv"
|
"strconv"
|
||||||
"sync"
|
"sync"
|
||||||
|
@ -14,7 +16,6 @@ import (
|
||||||
const (
|
const (
|
||||||
kexAlgoDH14SHA1 = "diffie-hellman-group14-sha1"
|
kexAlgoDH14SHA1 = "diffie-hellman-group14-sha1"
|
||||||
hostAlgoRSA = "ssh-rsa"
|
hostAlgoRSA = "ssh-rsa"
|
||||||
cipherAES128CTR = "aes128-ctr"
|
|
||||||
macSHA196 = "hmac-sha1-96"
|
macSHA196 = "hmac-sha1-96"
|
||||||
compressionNone = "none"
|
compressionNone = "none"
|
||||||
serviceUserAuth = "ssh-userauth"
|
serviceUserAuth = "ssh-userauth"
|
||||||
|
@ -23,7 +24,6 @@ const (
|
||||||
|
|
||||||
var supportedKexAlgos = []string{kexAlgoDH14SHA1}
|
var supportedKexAlgos = []string{kexAlgoDH14SHA1}
|
||||||
var supportedHostKeyAlgos = []string{hostAlgoRSA}
|
var supportedHostKeyAlgos = []string{hostAlgoRSA}
|
||||||
var supportedCiphers = []string{cipherAES128CTR}
|
|
||||||
var supportedMACs = []string{macSHA196}
|
var supportedMACs = []string{macSHA196}
|
||||||
var supportedCompressions = []string{compressionNone}
|
var supportedCompressions = []string{compressionNone}
|
||||||
|
|
||||||
|
@ -127,3 +127,100 @@ func findAgreedAlgorithms(transport *transport, clientKexInit, serverKexInit *ke
|
||||||
ok = true
|
ok = true
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Cryptographic configuration common to both ServerConfig and ClientConfig.
|
||||||
|
type CryptoConfig struct {
|
||||||
|
// The allowed cipher algorithms. If unspecified then DefaultCipherOrder is
|
||||||
|
// used.
|
||||||
|
Ciphers []string
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *CryptoConfig) ciphers() []string {
|
||||||
|
if c.Ciphers == nil {
|
||||||
|
return DefaultCipherOrder
|
||||||
|
}
|
||||||
|
return c.Ciphers
|
||||||
|
}
|
||||||
|
|
||||||
|
// serialize a signed slice according to RFC 4254 6.6.
|
||||||
|
func serializeSignature(algoname string, sig []byte) []byte {
|
||||||
|
length := stringLength([]byte(algoname))
|
||||||
|
length += stringLength(sig)
|
||||||
|
|
||||||
|
ret := make([]byte, length)
|
||||||
|
r := marshalString(ret, []byte(algoname))
|
||||||
|
r = marshalString(r, sig)
|
||||||
|
|
||||||
|
return ret
|
||||||
|
}
|
||||||
|
|
||||||
|
// serialize an rsa.PublicKey or dsa.PublicKey according to RFC 4253 6.6.
|
||||||
|
func serializePublickey(key interface{}) []byte {
|
||||||
|
algoname := algoName(key)
|
||||||
|
switch key := key.(type) {
|
||||||
|
case rsa.PublicKey:
|
||||||
|
e := new(big.Int).SetInt64(int64(key.E))
|
||||||
|
length := stringLength([]byte(algoname))
|
||||||
|
length += intLength(e)
|
||||||
|
length += intLength(key.N)
|
||||||
|
ret := make([]byte, length)
|
||||||
|
r := marshalString(ret, []byte(algoname))
|
||||||
|
r = marshalInt(r, e)
|
||||||
|
marshalInt(r, key.N)
|
||||||
|
return ret
|
||||||
|
case dsa.PublicKey:
|
||||||
|
length := stringLength([]byte(algoname))
|
||||||
|
length += intLength(key.P)
|
||||||
|
length += intLength(key.Q)
|
||||||
|
length += intLength(key.G)
|
||||||
|
length += intLength(key.Y)
|
||||||
|
ret := make([]byte, length)
|
||||||
|
r := marshalString(ret, []byte(algoname))
|
||||||
|
r = marshalInt(r, key.P)
|
||||||
|
r = marshalInt(r, key.Q)
|
||||||
|
r = marshalInt(r, key.G)
|
||||||
|
marshalInt(r, key.Y)
|
||||||
|
return ret
|
||||||
|
}
|
||||||
|
panic("unexpected key type")
|
||||||
|
}
|
||||||
|
|
||||||
|
func algoName(key interface{}) string {
|
||||||
|
switch key.(type) {
|
||||||
|
case rsa.PublicKey:
|
||||||
|
return "ssh-rsa"
|
||||||
|
case dsa.PublicKey:
|
||||||
|
return "ssh-dss"
|
||||||
|
}
|
||||||
|
panic("unexpected key type")
|
||||||
|
}
|
||||||
|
|
||||||
|
// buildDataSignedForAuth returns the data that is signed in order to prove
|
||||||
|
// posession of a private key. See RFC 4252, section 7.
|
||||||
|
func buildDataSignedForAuth(sessionId []byte, req userAuthRequestMsg, algo, pubKey []byte) []byte {
|
||||||
|
user := []byte(req.User)
|
||||||
|
service := []byte(req.Service)
|
||||||
|
method := []byte(req.Method)
|
||||||
|
|
||||||
|
length := stringLength(sessionId)
|
||||||
|
length += 1
|
||||||
|
length += stringLength(user)
|
||||||
|
length += stringLength(service)
|
||||||
|
length += stringLength(method)
|
||||||
|
length += 1
|
||||||
|
length += stringLength(algo)
|
||||||
|
length += stringLength(pubKey)
|
||||||
|
|
||||||
|
ret := make([]byte, length)
|
||||||
|
r := marshalString(ret, sessionId)
|
||||||
|
r[0] = msgUserAuthRequest
|
||||||
|
r = r[1:]
|
||||||
|
r = marshalString(r, user)
|
||||||
|
r = marshalString(r, service)
|
||||||
|
r = marshalString(r, method)
|
||||||
|
r[0] = 1
|
||||||
|
r = r[1:]
|
||||||
|
r = marshalString(r, algo)
|
||||||
|
r = marshalString(r, pubKey)
|
||||||
|
return ret
|
||||||
|
}
|
||||||
|
|
|
@ -392,7 +392,10 @@ func parseString(in []byte) (out, rest []byte, ok bool) {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
var comma = []byte{','}
|
var (
|
||||||
|
comma = []byte{','}
|
||||||
|
emptyNameList = []string{}
|
||||||
|
)
|
||||||
|
|
||||||
func parseNameList(in []byte) (out []string, rest []byte, ok bool) {
|
func parseNameList(in []byte) (out []string, rest []byte, ok bool) {
|
||||||
contents, rest, ok := parseString(in)
|
contents, rest, ok := parseString(in)
|
||||||
|
@ -400,6 +403,7 @@ func parseNameList(in []byte) (out []string, rest []byte, ok bool) {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
if len(contents) == 0 {
|
if len(contents) == 0 {
|
||||||
|
out = emptyNameList
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
parts := bytes.Split(contents, comma)
|
parts := bytes.Split(contents, comma)
|
||||||
|
@ -444,8 +448,6 @@ func parseUint32(in []byte) (out uint32, rest []byte, ok bool) {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
const maxPacketSize = 36000
|
|
||||||
|
|
||||||
func nameListLength(namelist []string) int {
|
func nameListLength(namelist []string) int {
|
||||||
length := 4 /* uint32 length prefix */
|
length := 4 /* uint32 length prefix */
|
||||||
for i, name := range namelist {
|
for i, name := range namelist {
|
||||||
|
|
|
@ -40,6 +40,9 @@ type ServerConfig struct {
|
||||||
// key authentication. It must return true iff the given public key is
|
// key authentication. It must return true iff the given public key is
|
||||||
// valid for the given user.
|
// valid for the given user.
|
||||||
PubKeyCallback func(user, algo string, pubkey []byte) bool
|
PubKeyCallback func(user, algo string, pubkey []byte) bool
|
||||||
|
|
||||||
|
// Cryptographic-related configuration.
|
||||||
|
Crypto CryptoConfig
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *ServerConfig) rand() io.Reader {
|
func (c *ServerConfig) rand() io.Reader {
|
||||||
|
@ -221,7 +224,7 @@ func (s *ServerConn) kexDH(group *dhGroup, hashFunc crypto.Hash, magics *handsha
|
||||||
return nil, nil, errors.New("internal error")
|
return nil, nil, errors.New("internal error")
|
||||||
}
|
}
|
||||||
|
|
||||||
serializedSig := serializeRSASignature(sig)
|
serializedSig := serializeSignature(hostAlgoRSA, sig)
|
||||||
|
|
||||||
kexDHReply := kexDHReplyMsg{
|
kexDHReply := kexDHReplyMsg{
|
||||||
HostKey: serializedHostKey,
|
HostKey: serializedHostKey,
|
||||||
|
@ -234,50 +237,9 @@ func (s *ServerConn) kexDH(group *dhGroup, hashFunc crypto.Hash, magics *handsha
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
func serializeRSASignature(sig []byte) []byte {
|
|
||||||
length := stringLength([]byte(hostAlgoRSA))
|
|
||||||
length += stringLength(sig)
|
|
||||||
|
|
||||||
ret := make([]byte, length)
|
|
||||||
r := marshalString(ret, []byte(hostAlgoRSA))
|
|
||||||
r = marshalString(r, sig)
|
|
||||||
|
|
||||||
return ret
|
|
||||||
}
|
|
||||||
|
|
||||||
// serverVersion is the fixed identification string that Server will use.
|
// serverVersion is the fixed identification string that Server will use.
|
||||||
var serverVersion = []byte("SSH-2.0-Go\r\n")
|
var serverVersion = []byte("SSH-2.0-Go\r\n")
|
||||||
|
|
||||||
// buildDataSignedForAuth returns the data that is signed in order to prove
|
|
||||||
// posession of a private key. See RFC 4252, section 7.
|
|
||||||
func buildDataSignedForAuth(sessionId []byte, req userAuthRequestMsg, algo, pubKey []byte) []byte {
|
|
||||||
user := []byte(req.User)
|
|
||||||
service := []byte(req.Service)
|
|
||||||
method := []byte(req.Method)
|
|
||||||
|
|
||||||
length := stringLength(sessionId)
|
|
||||||
length += 1
|
|
||||||
length += stringLength(user)
|
|
||||||
length += stringLength(service)
|
|
||||||
length += stringLength(method)
|
|
||||||
length += 1
|
|
||||||
length += stringLength(algo)
|
|
||||||
length += stringLength(pubKey)
|
|
||||||
|
|
||||||
ret := make([]byte, length)
|
|
||||||
r := marshalString(ret, sessionId)
|
|
||||||
r[0] = msgUserAuthRequest
|
|
||||||
r = r[1:]
|
|
||||||
r = marshalString(r, user)
|
|
||||||
r = marshalString(r, service)
|
|
||||||
r = marshalString(r, method)
|
|
||||||
r[0] = 1
|
|
||||||
r = r[1:]
|
|
||||||
r = marshalString(r, algo)
|
|
||||||
r = marshalString(r, pubKey)
|
|
||||||
return ret
|
|
||||||
}
|
|
||||||
|
|
||||||
// Handshake performs an SSH transport and client authentication on the given ServerConn.
|
// Handshake performs an SSH transport and client authentication on the given ServerConn.
|
||||||
func (s *ServerConn) Handshake() error {
|
func (s *ServerConn) Handshake() error {
|
||||||
var magics handshakeMagics
|
var magics handshakeMagics
|
||||||
|
@ -298,8 +260,8 @@ func (s *ServerConn) Handshake() error {
|
||||||
serverKexInit := kexInitMsg{
|
serverKexInit := kexInitMsg{
|
||||||
KexAlgos: supportedKexAlgos,
|
KexAlgos: supportedKexAlgos,
|
||||||
ServerHostKeyAlgos: supportedHostKeyAlgos,
|
ServerHostKeyAlgos: supportedHostKeyAlgos,
|
||||||
CiphersClientServer: supportedCiphers,
|
CiphersClientServer: s.config.Crypto.ciphers(),
|
||||||
CiphersServerClient: supportedCiphers,
|
CiphersServerClient: s.config.Crypto.ciphers(),
|
||||||
MACsClientServer: supportedMACs,
|
MACsClientServer: supportedMACs,
|
||||||
MACsServerClient: supportedMACs,
|
MACsServerClient: supportedMACs,
|
||||||
CompressionClientServer: supportedCompressions,
|
CompressionClientServer: supportedCompressions,
|
||||||
|
@ -364,7 +326,9 @@ func (s *ServerConn) Handshake() error {
|
||||||
if packet[0] != msgNewKeys {
|
if packet[0] != msgNewKeys {
|
||||||
return UnexpectedMessageError{msgNewKeys, packet[0]}
|
return UnexpectedMessageError{msgNewKeys, packet[0]}
|
||||||
}
|
}
|
||||||
s.transport.reader.setupKeys(clientKeys, K, H, H, hashFunc)
|
if err = s.transport.reader.setupKeys(clientKeys, K, H, H, hashFunc); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
if packet, err = s.readPacket(); err != nil {
|
if packet, err = s.readPacket(); err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
|
@ -0,0 +1,146 @@
|
||||||
|
// Copyright 2011 The Go Authors. All rights reserved.
|
||||||
|
// Use of this source code is governed by a BSD-style
|
||||||
|
// license that can be found in the LICENSE file.
|
||||||
|
|
||||||
|
package ssh
|
||||||
|
|
||||||
|
import (
|
||||||
|
"errors"
|
||||||
|
"io"
|
||||||
|
"net"
|
||||||
|
)
|
||||||
|
// Dial initiates a connection to the addr from the remote host.
|
||||||
|
// addr is resolved using net.ResolveTCPAddr before connection.
|
||||||
|
// This could allow an observer to observe the DNS name of the
|
||||||
|
// remote host. Consider using ssh.DialTCP to avoid this.
|
||||||
|
func (c *ClientConn) Dial(n, addr string) (net.Conn, error) {
|
||||||
|
raddr, err := net.ResolveTCPAddr(n, addr)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
return c.DialTCP(n, nil, raddr)
|
||||||
|
}
|
||||||
|
|
||||||
|
// DialTCP connects to the remote address raddr on the network net,
|
||||||
|
// which must be "tcp", "tcp4", or "tcp6". If laddr is not nil, it is used
|
||||||
|
// as the local address for the connection.
|
||||||
|
func (c *ClientConn) DialTCP(n string, laddr, raddr *net.TCPAddr) (net.Conn, error) {
|
||||||
|
if laddr == nil {
|
||||||
|
laddr = &net.TCPAddr{
|
||||||
|
IP: net.IPv4zero,
|
||||||
|
Port: 0,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
ch, err := c.dial(laddr.IP.String(), laddr.Port, raddr.IP.String(), raddr.Port)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
return &tcpchanconn{
|
||||||
|
tcpchan: ch,
|
||||||
|
laddr: laddr,
|
||||||
|
raddr: raddr,
|
||||||
|
}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// dial opens a direct-tcpip connection to the remote server. laddr and raddr are passed as
|
||||||
|
// strings and are expected to be resolveable at the remote end.
|
||||||
|
func (c *ClientConn) dial(laddr string, lport int, raddr string, rport int) (*tcpchan, error) {
|
||||||
|
// RFC 4254 7.2
|
||||||
|
type channelOpenDirectMsg struct {
|
||||||
|
ChanType string
|
||||||
|
PeersId uint32
|
||||||
|
PeersWindow uint32
|
||||||
|
MaxPacketSize uint32
|
||||||
|
raddr string
|
||||||
|
rport uint32
|
||||||
|
laddr string
|
||||||
|
lport uint32
|
||||||
|
}
|
||||||
|
ch := c.newChan(c.transport)
|
||||||
|
if err := c.writePacket(marshal(msgChannelOpen, channelOpenDirectMsg{
|
||||||
|
ChanType: "direct-tcpip",
|
||||||
|
PeersId: ch.id,
|
||||||
|
PeersWindow: 1 << 14,
|
||||||
|
MaxPacketSize: 1 << 15, // RFC 4253 6.1
|
||||||
|
raddr: raddr,
|
||||||
|
rport: uint32(rport),
|
||||||
|
laddr: laddr,
|
||||||
|
lport: uint32(lport),
|
||||||
|
})); err != nil {
|
||||||
|
c.chanlist.remove(ch.id)
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
// wait for response
|
||||||
|
switch msg := (<-ch.msg).(type) {
|
||||||
|
case *channelOpenConfirmMsg:
|
||||||
|
ch.peersId = msg.MyId
|
||||||
|
ch.win <- int(msg.MyWindow)
|
||||||
|
case *channelOpenFailureMsg:
|
||||||
|
c.chanlist.remove(ch.id)
|
||||||
|
return nil, errors.New("ssh: error opening remote TCP connection: " + msg.Message)
|
||||||
|
default:
|
||||||
|
c.chanlist.remove(ch.id)
|
||||||
|
return nil, errors.New("ssh: unexpected packet")
|
||||||
|
}
|
||||||
|
return &tcpchan{
|
||||||
|
clientChan: ch,
|
||||||
|
Reader: &chanReader{
|
||||||
|
packetWriter: ch,
|
||||||
|
id: ch.id,
|
||||||
|
data: ch.data,
|
||||||
|
},
|
||||||
|
Writer: &chanWriter{
|
||||||
|
packetWriter: ch,
|
||||||
|
id: ch.id,
|
||||||
|
win: ch.win,
|
||||||
|
},
|
||||||
|
}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
type tcpchan struct {
|
||||||
|
*clientChan // the backing channel
|
||||||
|
io.Reader
|
||||||
|
io.Writer
|
||||||
|
}
|
||||||
|
|
||||||
|
// tcpchanconn fulfills the net.Conn interface without
|
||||||
|
// the tcpchan having to hold laddr or raddr directly.
|
||||||
|
type tcpchanconn struct {
|
||||||
|
*tcpchan
|
||||||
|
laddr, raddr net.Addr
|
||||||
|
}
|
||||||
|
|
||||||
|
// LocalAddr returns the local network address.
|
||||||
|
func (t *tcpchanconn) LocalAddr() net.Addr {
|
||||||
|
return t.laddr
|
||||||
|
}
|
||||||
|
|
||||||
|
// RemoteAddr returns the remote network address.
|
||||||
|
func (t *tcpchanconn) RemoteAddr() net.Addr {
|
||||||
|
return t.raddr
|
||||||
|
}
|
||||||
|
|
||||||
|
// SetTimeout sets the read and write deadlines associated
|
||||||
|
// with the connection.
|
||||||
|
func (t *tcpchanconn) SetTimeout(nsec int64) error {
|
||||||
|
if err := t.SetReadTimeout(nsec); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
return t.SetWriteTimeout(nsec)
|
||||||
|
}
|
||||||
|
|
||||||
|
// SetReadTimeout sets the time (in nanoseconds) that
|
||||||
|
// Read will wait for data before returning an error with Timeout() == true.
|
||||||
|
// Setting nsec == 0 (the default) disables the deadline.
|
||||||
|
func (t *tcpchanconn) SetReadTimeout(nsec int64) error {
|
||||||
|
return errors.New("ssh: tcpchan: timeout not supported")
|
||||||
|
}
|
||||||
|
|
||||||
|
// SetWriteTimeout sets the time (in nanoseconds) that
|
||||||
|
// Write will wait to send its data before returning an error with Timeout() == true.
|
||||||
|
// Setting nsec == 0 (the default) disables the deadline.
|
||||||
|
// Even if write times out, it may return n > 0, indicating that
|
||||||
|
// some of the data was successfully written.
|
||||||
|
func (t *tcpchanconn) SetWriteTimeout(nsec int64) error {
|
||||||
|
return errors.New("ssh: tcpchan: timeout not supported")
|
||||||
|
}
|
|
@ -7,7 +7,6 @@ package ssh
|
||||||
import (
|
import (
|
||||||
"bufio"
|
"bufio"
|
||||||
"crypto"
|
"crypto"
|
||||||
"crypto/aes"
|
|
||||||
"crypto/cipher"
|
"crypto/cipher"
|
||||||
"crypto/hmac"
|
"crypto/hmac"
|
||||||
"crypto/subtle"
|
"crypto/subtle"
|
||||||
|
@ -19,7 +18,10 @@ import (
|
||||||
)
|
)
|
||||||
|
|
||||||
const (
|
const (
|
||||||
paddingMultiple = 16 // TODO(dfc) does this need to be configurable?
|
packetSizeMultiple = 16 // TODO(huin) this should be determined by the cipher.
|
||||||
|
minPacketSize = 16
|
||||||
|
maxPacketSize = 36000
|
||||||
|
minPaddingSize = 4 // TODO(huin) should this be configurable?
|
||||||
)
|
)
|
||||||
|
|
||||||
// filteredConn reduces the set of methods exposed when embeddeding
|
// filteredConn reduces the set of methods exposed when embeddeding
|
||||||
|
@ -61,7 +63,6 @@ type reader struct {
|
||||||
type writer struct {
|
type writer struct {
|
||||||
*sync.Mutex // protects writer.Writer from concurrent writes
|
*sync.Mutex // protects writer.Writer from concurrent writes
|
||||||
*bufio.Writer
|
*bufio.Writer
|
||||||
paddingMultiple int
|
|
||||||
rand io.Reader
|
rand io.Reader
|
||||||
common
|
common
|
||||||
}
|
}
|
||||||
|
@ -82,14 +83,11 @@ type common struct {
|
||||||
func (r *reader) readOnePacket() ([]byte, error) {
|
func (r *reader) readOnePacket() ([]byte, error) {
|
||||||
var lengthBytes = make([]byte, 5)
|
var lengthBytes = make([]byte, 5)
|
||||||
var macSize uint32
|
var macSize uint32
|
||||||
|
|
||||||
if _, err := io.ReadFull(r, lengthBytes); err != nil {
|
if _, err := io.ReadFull(r, lengthBytes); err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
if r.cipher != nil {
|
|
||||||
r.cipher.XORKeyStream(lengthBytes, lengthBytes)
|
r.cipher.XORKeyStream(lengthBytes, lengthBytes)
|
||||||
}
|
|
||||||
|
|
||||||
if r.mac != nil {
|
if r.mac != nil {
|
||||||
r.mac.Reset()
|
r.mac.Reset()
|
||||||
|
@ -153,9 +151,9 @@ func (w *writer) writePacket(packet []byte) error {
|
||||||
w.Mutex.Lock()
|
w.Mutex.Lock()
|
||||||
defer w.Mutex.Unlock()
|
defer w.Mutex.Unlock()
|
||||||
|
|
||||||
paddingLength := paddingMultiple - (5+len(packet))%paddingMultiple
|
paddingLength := packetSizeMultiple - (5+len(packet))%packetSizeMultiple
|
||||||
if paddingLength < 4 {
|
if paddingLength < 4 {
|
||||||
paddingLength += paddingMultiple
|
paddingLength += packetSizeMultiple
|
||||||
}
|
}
|
||||||
|
|
||||||
length := len(packet) + 1 + paddingLength
|
length := len(packet) + 1 + paddingLength
|
||||||
|
@ -188,11 +186,9 @@ func (w *writer) writePacket(packet []byte) error {
|
||||||
|
|
||||||
// TODO(dfc) lengthBytes, packet and padding should be
|
// TODO(dfc) lengthBytes, packet and padding should be
|
||||||
// subslices of a single buffer
|
// subslices of a single buffer
|
||||||
if w.cipher != nil {
|
|
||||||
w.cipher.XORKeyStream(lengthBytes, lengthBytes)
|
w.cipher.XORKeyStream(lengthBytes, lengthBytes)
|
||||||
w.cipher.XORKeyStream(packet, packet)
|
w.cipher.XORKeyStream(packet, packet)
|
||||||
w.cipher.XORKeyStream(padding, padding)
|
w.cipher.XORKeyStream(padding, padding)
|
||||||
}
|
|
||||||
|
|
||||||
if _, err := w.Write(lengthBytes); err != nil {
|
if _, err := w.Write(lengthBytes); err != nil {
|
||||||
return err
|
return err
|
||||||
|
@ -227,11 +223,17 @@ func newTransport(conn net.Conn, rand io.Reader) *transport {
|
||||||
return &transport{
|
return &transport{
|
||||||
reader: reader{
|
reader: reader{
|
||||||
Reader: bufio.NewReader(conn),
|
Reader: bufio.NewReader(conn),
|
||||||
|
common: common{
|
||||||
|
cipher: noneCipher{},
|
||||||
|
},
|
||||||
},
|
},
|
||||||
writer: writer{
|
writer: writer{
|
||||||
Writer: bufio.NewWriter(conn),
|
Writer: bufio.NewWriter(conn),
|
||||||
rand: rand,
|
rand: rand,
|
||||||
Mutex: new(sync.Mutex),
|
Mutex: new(sync.Mutex),
|
||||||
|
common: common{
|
||||||
|
cipher: noneCipher{},
|
||||||
|
},
|
||||||
},
|
},
|
||||||
filteredConn: conn,
|
filteredConn: conn,
|
||||||
}
|
}
|
||||||
|
@ -249,29 +251,32 @@ var (
|
||||||
clientKeys = direction{[]byte{'A'}, []byte{'C'}, []byte{'E'}}
|
clientKeys = direction{[]byte{'A'}, []byte{'C'}, []byte{'E'}}
|
||||||
)
|
)
|
||||||
|
|
||||||
// setupKeys sets the cipher and MAC keys from K, H and sessionId, as
|
// setupKeys sets the cipher and MAC keys from kex.K, kex.H and sessionId, as
|
||||||
// described in RFC 4253, section 6.4. direction should either be serverKeys
|
// described in RFC 4253, section 6.4. direction should either be serverKeys
|
||||||
// (to setup server->client keys) or clientKeys (for client->server keys).
|
// (to setup server->client keys) or clientKeys (for client->server keys).
|
||||||
func (c *common) setupKeys(d direction, K, H, sessionId []byte, hashFunc crypto.Hash) error {
|
func (c *common) setupKeys(d direction, K, H, sessionId []byte, hashFunc crypto.Hash) error {
|
||||||
h := hashFunc.New()
|
cipherMode := cipherModes[c.cipherAlgo]
|
||||||
|
|
||||||
blockSize := 16
|
|
||||||
keySize := 16
|
|
||||||
macKeySize := 20
|
macKeySize := 20
|
||||||
|
|
||||||
iv := make([]byte, blockSize)
|
iv := make([]byte, cipherMode.ivSize)
|
||||||
key := make([]byte, keySize)
|
key := make([]byte, cipherMode.keySize)
|
||||||
macKey := make([]byte, macKeySize)
|
macKey := make([]byte, macKeySize)
|
||||||
|
|
||||||
|
h := hashFunc.New()
|
||||||
generateKeyMaterial(iv, d.ivTag, K, H, sessionId, h)
|
generateKeyMaterial(iv, d.ivTag, K, H, sessionId, h)
|
||||||
generateKeyMaterial(key, d.keyTag, K, H, sessionId, h)
|
generateKeyMaterial(key, d.keyTag, K, H, sessionId, h)
|
||||||
generateKeyMaterial(macKey, d.macKeyTag, K, H, sessionId, h)
|
generateKeyMaterial(macKey, d.macKeyTag, K, H, sessionId, h)
|
||||||
|
|
||||||
c.mac = truncatingMAC{12, hmac.NewSHA1(macKey)}
|
c.mac = truncatingMAC{12, hmac.NewSHA1(macKey)}
|
||||||
aes, err := aes.NewCipher(key)
|
|
||||||
|
cipher, err := cipherMode.createCipher(key, iv)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
c.cipher = cipher.NewCTR(aes, iv)
|
|
||||||
|
c.cipher = cipher
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -1,356 +0,0 @@
|
||||||
// Copyright 2011 The Go Authors. All rights reserved.
|
|
||||||
// Use of this source code is governed by a BSD-style
|
|
||||||
// license that can be found in the LICENSE file.
|
|
||||||
|
|
||||||
package terminal
|
|
||||||
|
|
||||||
import "io"
|
|
||||||
|
|
||||||
// Shell contains the state for running a VT100 terminal that is capable of
|
|
||||||
// reading lines of input.
|
|
||||||
type Shell struct {
|
|
||||||
c io.ReadWriter
|
|
||||||
prompt string
|
|
||||||
|
|
||||||
// line is the current line being entered.
|
|
||||||
line []byte
|
|
||||||
// pos is the logical position of the cursor in line
|
|
||||||
pos int
|
|
||||||
|
|
||||||
// cursorX contains the current X value of the cursor where the left
|
|
||||||
// edge is 0. cursorY contains the row number where the first row of
|
|
||||||
// the current line is 0.
|
|
||||||
cursorX, cursorY int
|
|
||||||
// maxLine is the greatest value of cursorY so far.
|
|
||||||
maxLine int
|
|
||||||
|
|
||||||
termWidth, termHeight int
|
|
||||||
|
|
||||||
// outBuf contains the terminal data to be sent.
|
|
||||||
outBuf []byte
|
|
||||||
// remainder contains the remainder of any partial key sequences after
|
|
||||||
// a read. It aliases into inBuf.
|
|
||||||
remainder []byte
|
|
||||||
inBuf [256]byte
|
|
||||||
}
|
|
||||||
|
|
||||||
// NewShell runs a VT100 terminal on the given ReadWriter. If the ReadWriter is
|
|
||||||
// a local terminal, that terminal must first have been put into raw mode.
|
|
||||||
// prompt is a string that is written at the start of each input line (i.e.
|
|
||||||
// "> ").
|
|
||||||
func NewShell(c io.ReadWriter, prompt string) *Shell {
|
|
||||||
return &Shell{
|
|
||||||
c: c,
|
|
||||||
prompt: prompt,
|
|
||||||
termWidth: 80,
|
|
||||||
termHeight: 24,
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
const (
|
|
||||||
keyCtrlD = 4
|
|
||||||
keyEnter = '\r'
|
|
||||||
keyEscape = 27
|
|
||||||
keyBackspace = 127
|
|
||||||
keyUnknown = 256 + iota
|
|
||||||
keyUp
|
|
||||||
keyDown
|
|
||||||
keyLeft
|
|
||||||
keyRight
|
|
||||||
keyAltLeft
|
|
||||||
keyAltRight
|
|
||||||
)
|
|
||||||
|
|
||||||
// bytesToKey tries to parse a key sequence from b. If successful, it returns
|
|
||||||
// the key and the remainder of the input. Otherwise it returns -1.
|
|
||||||
func bytesToKey(b []byte) (int, []byte) {
|
|
||||||
if len(b) == 0 {
|
|
||||||
return -1, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
if b[0] != keyEscape {
|
|
||||||
return int(b[0]), b[1:]
|
|
||||||
}
|
|
||||||
|
|
||||||
if len(b) >= 3 && b[0] == keyEscape && b[1] == '[' {
|
|
||||||
switch b[2] {
|
|
||||||
case 'A':
|
|
||||||
return keyUp, b[3:]
|
|
||||||
case 'B':
|
|
||||||
return keyDown, b[3:]
|
|
||||||
case 'C':
|
|
||||||
return keyRight, b[3:]
|
|
||||||
case 'D':
|
|
||||||
return keyLeft, b[3:]
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
if len(b) >= 6 && b[0] == keyEscape && b[1] == '[' && b[2] == '1' && b[3] == ';' && b[4] == '3' {
|
|
||||||
switch b[5] {
|
|
||||||
case 'C':
|
|
||||||
return keyAltRight, b[6:]
|
|
||||||
case 'D':
|
|
||||||
return keyAltLeft, b[6:]
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// If we get here then we have a key that we don't recognise, or a
|
|
||||||
// partial sequence. It's not clear how one should find the end of a
|
|
||||||
// sequence without knowing them all, but it seems that [a-zA-Z] only
|
|
||||||
// appears at the end of a sequence.
|
|
||||||
for i, c := range b[0:] {
|
|
||||||
if c >= 'a' && c <= 'z' || c >= 'A' && c <= 'Z' {
|
|
||||||
return keyUnknown, b[i+1:]
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
return -1, b
|
|
||||||
}
|
|
||||||
|
|
||||||
// queue appends data to the end of ss.outBuf
|
|
||||||
func (ss *Shell) queue(data []byte) {
|
|
||||||
if len(ss.outBuf)+len(data) > cap(ss.outBuf) {
|
|
||||||
newOutBuf := make([]byte, len(ss.outBuf), 2*(len(ss.outBuf)+len(data)))
|
|
||||||
copy(newOutBuf, ss.outBuf)
|
|
||||||
ss.outBuf = newOutBuf
|
|
||||||
}
|
|
||||||
|
|
||||||
oldLen := len(ss.outBuf)
|
|
||||||
ss.outBuf = ss.outBuf[:len(ss.outBuf)+len(data)]
|
|
||||||
copy(ss.outBuf[oldLen:], data)
|
|
||||||
}
|
|
||||||
|
|
||||||
var eraseUnderCursor = []byte{' ', keyEscape, '[', 'D'}
|
|
||||||
|
|
||||||
func isPrintable(key int) bool {
|
|
||||||
return key >= 32 && key < 127
|
|
||||||
}
|
|
||||||
|
|
||||||
// moveCursorToPos appends data to ss.outBuf which will move the cursor to the
|
|
||||||
// given, logical position in the text.
|
|
||||||
func (ss *Shell) moveCursorToPos(pos int) {
|
|
||||||
x := len(ss.prompt) + pos
|
|
||||||
y := x / ss.termWidth
|
|
||||||
x = x % ss.termWidth
|
|
||||||
|
|
||||||
up := 0
|
|
||||||
if y < ss.cursorY {
|
|
||||||
up = ss.cursorY - y
|
|
||||||
}
|
|
||||||
|
|
||||||
down := 0
|
|
||||||
if y > ss.cursorY {
|
|
||||||
down = y - ss.cursorY
|
|
||||||
}
|
|
||||||
|
|
||||||
left := 0
|
|
||||||
if x < ss.cursorX {
|
|
||||||
left = ss.cursorX - x
|
|
||||||
}
|
|
||||||
|
|
||||||
right := 0
|
|
||||||
if x > ss.cursorX {
|
|
||||||
right = x - ss.cursorX
|
|
||||||
}
|
|
||||||
|
|
||||||
movement := make([]byte, 3*(up+down+left+right))
|
|
||||||
m := movement
|
|
||||||
for i := 0; i < up; i++ {
|
|
||||||
m[0] = keyEscape
|
|
||||||
m[1] = '['
|
|
||||||
m[2] = 'A'
|
|
||||||
m = m[3:]
|
|
||||||
}
|
|
||||||
for i := 0; i < down; i++ {
|
|
||||||
m[0] = keyEscape
|
|
||||||
m[1] = '['
|
|
||||||
m[2] = 'B'
|
|
||||||
m = m[3:]
|
|
||||||
}
|
|
||||||
for i := 0; i < left; i++ {
|
|
||||||
m[0] = keyEscape
|
|
||||||
m[1] = '['
|
|
||||||
m[2] = 'D'
|
|
||||||
m = m[3:]
|
|
||||||
}
|
|
||||||
for i := 0; i < right; i++ {
|
|
||||||
m[0] = keyEscape
|
|
||||||
m[1] = '['
|
|
||||||
m[2] = 'C'
|
|
||||||
m = m[3:]
|
|
||||||
}
|
|
||||||
|
|
||||||
ss.cursorX = x
|
|
||||||
ss.cursorY = y
|
|
||||||
ss.queue(movement)
|
|
||||||
}
|
|
||||||
|
|
||||||
const maxLineLength = 4096
|
|
||||||
|
|
||||||
// handleKey processes the given key and, optionally, returns a line of text
|
|
||||||
// that the user has entered.
|
|
||||||
func (ss *Shell) handleKey(key int) (line string, ok bool) {
|
|
||||||
switch key {
|
|
||||||
case keyBackspace:
|
|
||||||
if ss.pos == 0 {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
ss.pos--
|
|
||||||
|
|
||||||
copy(ss.line[ss.pos:], ss.line[1+ss.pos:])
|
|
||||||
ss.line = ss.line[:len(ss.line)-1]
|
|
||||||
ss.writeLine(ss.line[ss.pos:])
|
|
||||||
ss.moveCursorToPos(ss.pos)
|
|
||||||
ss.queue(eraseUnderCursor)
|
|
||||||
case keyAltLeft:
|
|
||||||
// move left by a word.
|
|
||||||
if ss.pos == 0 {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
ss.pos--
|
|
||||||
for ss.pos > 0 {
|
|
||||||
if ss.line[ss.pos] != ' ' {
|
|
||||||
break
|
|
||||||
}
|
|
||||||
ss.pos--
|
|
||||||
}
|
|
||||||
for ss.pos > 0 {
|
|
||||||
if ss.line[ss.pos] == ' ' {
|
|
||||||
ss.pos++
|
|
||||||
break
|
|
||||||
}
|
|
||||||
ss.pos--
|
|
||||||
}
|
|
||||||
ss.moveCursorToPos(ss.pos)
|
|
||||||
case keyAltRight:
|
|
||||||
// move right by a word.
|
|
||||||
for ss.pos < len(ss.line) {
|
|
||||||
if ss.line[ss.pos] == ' ' {
|
|
||||||
break
|
|
||||||
}
|
|
||||||
ss.pos++
|
|
||||||
}
|
|
||||||
for ss.pos < len(ss.line) {
|
|
||||||
if ss.line[ss.pos] != ' ' {
|
|
||||||
break
|
|
||||||
}
|
|
||||||
ss.pos++
|
|
||||||
}
|
|
||||||
ss.moveCursorToPos(ss.pos)
|
|
||||||
case keyLeft:
|
|
||||||
if ss.pos == 0 {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
ss.pos--
|
|
||||||
ss.moveCursorToPos(ss.pos)
|
|
||||||
case keyRight:
|
|
||||||
if ss.pos == len(ss.line) {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
ss.pos++
|
|
||||||
ss.moveCursorToPos(ss.pos)
|
|
||||||
case keyEnter:
|
|
||||||
ss.moveCursorToPos(len(ss.line))
|
|
||||||
ss.queue([]byte("\r\n"))
|
|
||||||
line = string(ss.line)
|
|
||||||
ok = true
|
|
||||||
ss.line = ss.line[:0]
|
|
||||||
ss.pos = 0
|
|
||||||
ss.cursorX = 0
|
|
||||||
ss.cursorY = 0
|
|
||||||
ss.maxLine = 0
|
|
||||||
default:
|
|
||||||
if !isPrintable(key) {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
if len(ss.line) == maxLineLength {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
if len(ss.line) == cap(ss.line) {
|
|
||||||
newLine := make([]byte, len(ss.line), 2*(1+len(ss.line)))
|
|
||||||
copy(newLine, ss.line)
|
|
||||||
ss.line = newLine
|
|
||||||
}
|
|
||||||
ss.line = ss.line[:len(ss.line)+1]
|
|
||||||
copy(ss.line[ss.pos+1:], ss.line[ss.pos:])
|
|
||||||
ss.line[ss.pos] = byte(key)
|
|
||||||
ss.writeLine(ss.line[ss.pos:])
|
|
||||||
ss.pos++
|
|
||||||
ss.moveCursorToPos(ss.pos)
|
|
||||||
}
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
func (ss *Shell) writeLine(line []byte) {
|
|
||||||
for len(line) != 0 {
|
|
||||||
if ss.cursorX == ss.termWidth {
|
|
||||||
ss.queue([]byte("\r\n"))
|
|
||||||
ss.cursorX = 0
|
|
||||||
ss.cursorY++
|
|
||||||
if ss.cursorY > ss.maxLine {
|
|
||||||
ss.maxLine = ss.cursorY
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
remainingOnLine := ss.termWidth - ss.cursorX
|
|
||||||
todo := len(line)
|
|
||||||
if todo > remainingOnLine {
|
|
||||||
todo = remainingOnLine
|
|
||||||
}
|
|
||||||
ss.queue(line[:todo])
|
|
||||||
ss.cursorX += todo
|
|
||||||
line = line[todo:]
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func (ss *Shell) Write(buf []byte) (n int, err error) {
|
|
||||||
return ss.c.Write(buf)
|
|
||||||
}
|
|
||||||
|
|
||||||
// ReadLine returns a line of input from the terminal.
|
|
||||||
func (ss *Shell) ReadLine() (line string, err error) {
|
|
||||||
ss.writeLine([]byte(ss.prompt))
|
|
||||||
ss.c.Write(ss.outBuf)
|
|
||||||
ss.outBuf = ss.outBuf[:0]
|
|
||||||
|
|
||||||
for {
|
|
||||||
// ss.remainder is a slice at the beginning of ss.inBuf
|
|
||||||
// containing a partial key sequence
|
|
||||||
readBuf := ss.inBuf[len(ss.remainder):]
|
|
||||||
var n int
|
|
||||||
n, err = ss.c.Read(readBuf)
|
|
||||||
if err != nil {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
if err == nil {
|
|
||||||
ss.remainder = ss.inBuf[:n+len(ss.remainder)]
|
|
||||||
rest := ss.remainder
|
|
||||||
lineOk := false
|
|
||||||
for !lineOk {
|
|
||||||
var key int
|
|
||||||
key, rest = bytesToKey(rest)
|
|
||||||
if key < 0 {
|
|
||||||
break
|
|
||||||
}
|
|
||||||
if key == keyCtrlD {
|
|
||||||
return "", io.EOF
|
|
||||||
}
|
|
||||||
line, lineOk = ss.handleKey(key)
|
|
||||||
}
|
|
||||||
if len(rest) > 0 {
|
|
||||||
n := copy(ss.inBuf[:], rest)
|
|
||||||
ss.remainder = ss.inBuf[:n]
|
|
||||||
} else {
|
|
||||||
ss.remainder = nil
|
|
||||||
}
|
|
||||||
ss.c.Write(ss.outBuf)
|
|
||||||
ss.outBuf = ss.outBuf[:0]
|
|
||||||
if lineOk {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
}
|
|
||||||
panic("unreachable")
|
|
||||||
}
|
|
|
@ -2,102 +2,361 @@
|
||||||
// Use of this source code is governed by a BSD-style
|
// Use of this source code is governed by a BSD-style
|
||||||
// license that can be found in the LICENSE file.
|
// license that can be found in the LICENSE file.
|
||||||
|
|
||||||
// Package terminal provides support functions for dealing with terminals, as
|
|
||||||
// commonly found on UNIX systems.
|
|
||||||
//
|
|
||||||
// Putting a terminal into raw mode is the most common requirement:
|
|
||||||
//
|
|
||||||
// oldState, err := terminal.MakeRaw(0)
|
|
||||||
// if err != nil {
|
|
||||||
// panic(err.String())
|
|
||||||
// }
|
|
||||||
// defer terminal.Restore(0, oldState)
|
|
||||||
package terminal
|
package terminal
|
||||||
|
|
||||||
import (
|
import "io"
|
||||||
"io"
|
|
||||||
"os"
|
// Terminal contains the state for running a VT100 terminal that is capable of
|
||||||
"syscall"
|
// reading lines of input.
|
||||||
|
type Terminal struct {
|
||||||
|
c io.ReadWriter
|
||||||
|
prompt string
|
||||||
|
|
||||||
|
// line is the current line being entered.
|
||||||
|
line []byte
|
||||||
|
// pos is the logical position of the cursor in line
|
||||||
|
pos int
|
||||||
|
|
||||||
|
// cursorX contains the current X value of the cursor where the left
|
||||||
|
// edge is 0. cursorY contains the row number where the first row of
|
||||||
|
// the current line is 0.
|
||||||
|
cursorX, cursorY int
|
||||||
|
// maxLine is the greatest value of cursorY so far.
|
||||||
|
maxLine int
|
||||||
|
|
||||||
|
termWidth, termHeight int
|
||||||
|
|
||||||
|
// outBuf contains the terminal data to be sent.
|
||||||
|
outBuf []byte
|
||||||
|
// remainder contains the remainder of any partial key sequences after
|
||||||
|
// a read. It aliases into inBuf.
|
||||||
|
remainder []byte
|
||||||
|
inBuf [256]byte
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewTerminal runs a VT100 terminal on the given ReadWriter. If the ReadWriter is
|
||||||
|
// a local terminal, that terminal must first have been put into raw mode.
|
||||||
|
// prompt is a string that is written at the start of each input line (i.e.
|
||||||
|
// "> ").
|
||||||
|
func NewTerminal(c io.ReadWriter, prompt string) *Terminal {
|
||||||
|
return &Terminal{
|
||||||
|
c: c,
|
||||||
|
prompt: prompt,
|
||||||
|
termWidth: 80,
|
||||||
|
termHeight: 24,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
const (
|
||||||
|
keyCtrlD = 4
|
||||||
|
keyEnter = '\r'
|
||||||
|
keyEscape = 27
|
||||||
|
keyBackspace = 127
|
||||||
|
keyUnknown = 256 + iota
|
||||||
|
keyUp
|
||||||
|
keyDown
|
||||||
|
keyLeft
|
||||||
|
keyRight
|
||||||
|
keyAltLeft
|
||||||
|
keyAltRight
|
||||||
)
|
)
|
||||||
|
|
||||||
// State contains the state of a terminal.
|
// bytesToKey tries to parse a key sequence from b. If successful, it returns
|
||||||
type State struct {
|
// the key and the remainder of the input. Otherwise it returns -1.
|
||||||
termios syscall.Termios
|
func bytesToKey(b []byte) (int, []byte) {
|
||||||
|
if len(b) == 0 {
|
||||||
|
return -1, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// IsTerminal returns true if the given file descriptor is a terminal.
|
if b[0] != keyEscape {
|
||||||
func IsTerminal(fd int) bool {
|
return int(b[0]), b[1:]
|
||||||
var termios syscall.Termios
|
|
||||||
e := syscall.Tcgetattr(fd, &termios)
|
|
||||||
return e == 0
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// MakeRaw put the terminal connected to the given file descriptor into raw
|
if len(b) >= 3 && b[0] == keyEscape && b[1] == '[' {
|
||||||
// mode and returns the previous state of the terminal so that it can be
|
switch b[2] {
|
||||||
// restored.
|
case 'A':
|
||||||
func MakeRaw(fd int) (*State, error) {
|
return keyUp, b[3:]
|
||||||
var oldState State
|
case 'B':
|
||||||
if e := syscall.Tcgetattr(fd, &oldState.termios); e != 0 {
|
return keyDown, b[3:]
|
||||||
return nil, os.Errno(e)
|
case 'C':
|
||||||
|
return keyRight, b[3:]
|
||||||
|
case 'D':
|
||||||
|
return keyLeft, b[3:]
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
newState := oldState.termios
|
if len(b) >= 6 && b[0] == keyEscape && b[1] == '[' && b[2] == '1' && b[3] == ';' && b[4] == '3' {
|
||||||
newState.Iflag &^= syscall.ISTRIP | syscall.INLCR | syscall.ICRNL | syscall.IGNCR | syscall.IXON | syscall.IXOFF
|
switch b[5] {
|
||||||
newState.Lflag &^= syscall.ECHO | syscall.ICANON | syscall.ISIG
|
case 'C':
|
||||||
if e := syscall.Tcsetattr(fd, syscall.TCSANOW, &newState); e != 0 {
|
return keyAltRight, b[6:]
|
||||||
return nil, os.Errno(e)
|
case 'D':
|
||||||
|
return keyAltLeft, b[6:]
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
return &oldState, nil
|
// If we get here then we have a key that we don't recognise, or a
|
||||||
|
// partial sequence. It's not clear how one should find the end of a
|
||||||
|
// sequence without knowing them all, but it seems that [a-zA-Z] only
|
||||||
|
// appears at the end of a sequence.
|
||||||
|
for i, c := range b[0:] {
|
||||||
|
if c >= 'a' && c <= 'z' || c >= 'A' && c <= 'Z' {
|
||||||
|
return keyUnknown, b[i+1:]
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// Restore restores the terminal connected to the given file descriptor to a
|
return -1, b
|
||||||
// previous state.
|
|
||||||
func Restore(fd int, state *State) error {
|
|
||||||
e := syscall.Tcsetattr(fd, syscall.TCSANOW, &state.termios)
|
|
||||||
return os.Errno(e)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// ReadPassword reads a line of input from a terminal without local echo. This
|
// queue appends data to the end of t.outBuf
|
||||||
// is commonly used for inputting passwords and other sensitive data. The slice
|
func (t *Terminal) queue(data []byte) {
|
||||||
// returned does not include the \n.
|
if len(t.outBuf)+len(data) > cap(t.outBuf) {
|
||||||
func ReadPassword(fd int) ([]byte, error) {
|
newOutBuf := make([]byte, len(t.outBuf), 2*(len(t.outBuf)+len(data)))
|
||||||
var oldState syscall.Termios
|
copy(newOutBuf, t.outBuf)
|
||||||
if e := syscall.Tcgetattr(fd, &oldState); e != 0 {
|
t.outBuf = newOutBuf
|
||||||
return nil, os.Errno(e)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
newState := oldState
|
oldLen := len(t.outBuf)
|
||||||
newState.Lflag &^= syscall.ECHO
|
t.outBuf = t.outBuf[:len(t.outBuf)+len(data)]
|
||||||
if e := syscall.Tcsetattr(fd, syscall.TCSANOW, &newState); e != 0 {
|
copy(t.outBuf[oldLen:], data)
|
||||||
return nil, os.Errno(e)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
defer func() {
|
var eraseUnderCursor = []byte{' ', keyEscape, '[', 'D'}
|
||||||
syscall.Tcsetattr(fd, syscall.TCSANOW, &oldState)
|
|
||||||
}()
|
func isPrintable(key int) bool {
|
||||||
|
return key >= 32 && key < 127
|
||||||
|
}
|
||||||
|
|
||||||
|
// moveCursorToPos appends data to t.outBuf which will move the cursor to the
|
||||||
|
// given, logical position in the text.
|
||||||
|
func (t *Terminal) moveCursorToPos(pos int) {
|
||||||
|
x := len(t.prompt) + pos
|
||||||
|
y := x / t.termWidth
|
||||||
|
x = x % t.termWidth
|
||||||
|
|
||||||
|
up := 0
|
||||||
|
if y < t.cursorY {
|
||||||
|
up = t.cursorY - y
|
||||||
|
}
|
||||||
|
|
||||||
|
down := 0
|
||||||
|
if y > t.cursorY {
|
||||||
|
down = y - t.cursorY
|
||||||
|
}
|
||||||
|
|
||||||
|
left := 0
|
||||||
|
if x < t.cursorX {
|
||||||
|
left = t.cursorX - x
|
||||||
|
}
|
||||||
|
|
||||||
|
right := 0
|
||||||
|
if x > t.cursorX {
|
||||||
|
right = x - t.cursorX
|
||||||
|
}
|
||||||
|
|
||||||
|
movement := make([]byte, 3*(up+down+left+right))
|
||||||
|
m := movement
|
||||||
|
for i := 0; i < up; i++ {
|
||||||
|
m[0] = keyEscape
|
||||||
|
m[1] = '['
|
||||||
|
m[2] = 'A'
|
||||||
|
m = m[3:]
|
||||||
|
}
|
||||||
|
for i := 0; i < down; i++ {
|
||||||
|
m[0] = keyEscape
|
||||||
|
m[1] = '['
|
||||||
|
m[2] = 'B'
|
||||||
|
m = m[3:]
|
||||||
|
}
|
||||||
|
for i := 0; i < left; i++ {
|
||||||
|
m[0] = keyEscape
|
||||||
|
m[1] = '['
|
||||||
|
m[2] = 'D'
|
||||||
|
m = m[3:]
|
||||||
|
}
|
||||||
|
for i := 0; i < right; i++ {
|
||||||
|
m[0] = keyEscape
|
||||||
|
m[1] = '['
|
||||||
|
m[2] = 'C'
|
||||||
|
m = m[3:]
|
||||||
|
}
|
||||||
|
|
||||||
|
t.cursorX = x
|
||||||
|
t.cursorY = y
|
||||||
|
t.queue(movement)
|
||||||
|
}
|
||||||
|
|
||||||
|
const maxLineLength = 4096
|
||||||
|
|
||||||
|
// handleKey processes the given key and, optionally, returns a line of text
|
||||||
|
// that the user has entered.
|
||||||
|
func (t *Terminal) handleKey(key int) (line string, ok bool) {
|
||||||
|
switch key {
|
||||||
|
case keyBackspace:
|
||||||
|
if t.pos == 0 {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
t.pos--
|
||||||
|
|
||||||
|
copy(t.line[t.pos:], t.line[1+t.pos:])
|
||||||
|
t.line = t.line[:len(t.line)-1]
|
||||||
|
t.writeLine(t.line[t.pos:])
|
||||||
|
t.moveCursorToPos(t.pos)
|
||||||
|
t.queue(eraseUnderCursor)
|
||||||
|
case keyAltLeft:
|
||||||
|
// move left by a word.
|
||||||
|
if t.pos == 0 {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
t.pos--
|
||||||
|
for t.pos > 0 {
|
||||||
|
if t.line[t.pos] != ' ' {
|
||||||
|
break
|
||||||
|
}
|
||||||
|
t.pos--
|
||||||
|
}
|
||||||
|
for t.pos > 0 {
|
||||||
|
if t.line[t.pos] == ' ' {
|
||||||
|
t.pos++
|
||||||
|
break
|
||||||
|
}
|
||||||
|
t.pos--
|
||||||
|
}
|
||||||
|
t.moveCursorToPos(t.pos)
|
||||||
|
case keyAltRight:
|
||||||
|
// move right by a word.
|
||||||
|
for t.pos < len(t.line) {
|
||||||
|
if t.line[t.pos] == ' ' {
|
||||||
|
break
|
||||||
|
}
|
||||||
|
t.pos++
|
||||||
|
}
|
||||||
|
for t.pos < len(t.line) {
|
||||||
|
if t.line[t.pos] != ' ' {
|
||||||
|
break
|
||||||
|
}
|
||||||
|
t.pos++
|
||||||
|
}
|
||||||
|
t.moveCursorToPos(t.pos)
|
||||||
|
case keyLeft:
|
||||||
|
if t.pos == 0 {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
t.pos--
|
||||||
|
t.moveCursorToPos(t.pos)
|
||||||
|
case keyRight:
|
||||||
|
if t.pos == len(t.line) {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
t.pos++
|
||||||
|
t.moveCursorToPos(t.pos)
|
||||||
|
case keyEnter:
|
||||||
|
t.moveCursorToPos(len(t.line))
|
||||||
|
t.queue([]byte("\r\n"))
|
||||||
|
line = string(t.line)
|
||||||
|
ok = true
|
||||||
|
t.line = t.line[:0]
|
||||||
|
t.pos = 0
|
||||||
|
t.cursorX = 0
|
||||||
|
t.cursorY = 0
|
||||||
|
t.maxLine = 0
|
||||||
|
default:
|
||||||
|
if !isPrintable(key) {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if len(t.line) == maxLineLength {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if len(t.line) == cap(t.line) {
|
||||||
|
newLine := make([]byte, len(t.line), 2*(1+len(t.line)))
|
||||||
|
copy(newLine, t.line)
|
||||||
|
t.line = newLine
|
||||||
|
}
|
||||||
|
t.line = t.line[:len(t.line)+1]
|
||||||
|
copy(t.line[t.pos+1:], t.line[t.pos:])
|
||||||
|
t.line[t.pos] = byte(key)
|
||||||
|
t.writeLine(t.line[t.pos:])
|
||||||
|
t.pos++
|
||||||
|
t.moveCursorToPos(t.pos)
|
||||||
|
}
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
func (t *Terminal) writeLine(line []byte) {
|
||||||
|
for len(line) != 0 {
|
||||||
|
if t.cursorX == t.termWidth {
|
||||||
|
t.queue([]byte("\r\n"))
|
||||||
|
t.cursorX = 0
|
||||||
|
t.cursorY++
|
||||||
|
if t.cursorY > t.maxLine {
|
||||||
|
t.maxLine = t.cursorY
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
remainingOnLine := t.termWidth - t.cursorX
|
||||||
|
todo := len(line)
|
||||||
|
if todo > remainingOnLine {
|
||||||
|
todo = remainingOnLine
|
||||||
|
}
|
||||||
|
t.queue(line[:todo])
|
||||||
|
t.cursorX += todo
|
||||||
|
line = line[todo:]
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (t *Terminal) Write(buf []byte) (n int, err error) {
|
||||||
|
return t.c.Write(buf)
|
||||||
|
}
|
||||||
|
|
||||||
|
// ReadLine returns a line of input from the terminal.
|
||||||
|
func (t *Terminal) ReadLine() (line string, err error) {
|
||||||
|
if t.cursorX == 0 {
|
||||||
|
t.writeLine([]byte(t.prompt))
|
||||||
|
t.c.Write(t.outBuf)
|
||||||
|
t.outBuf = t.outBuf[:0]
|
||||||
|
}
|
||||||
|
|
||||||
var buf [16]byte
|
|
||||||
var ret []byte
|
|
||||||
for {
|
for {
|
||||||
n, errno := syscall.Read(fd, buf[:])
|
// t.remainder is a slice at the beginning of t.inBuf
|
||||||
if errno != 0 {
|
// containing a partial key sequence
|
||||||
return nil, os.Errno(errno)
|
readBuf := t.inBuf[len(t.remainder):]
|
||||||
}
|
var n int
|
||||||
if n == 0 {
|
n, err = t.c.Read(readBuf)
|
||||||
if len(ret) == 0 {
|
if err != nil {
|
||||||
return nil, io.EOF
|
return
|
||||||
}
|
|
||||||
break
|
|
||||||
}
|
|
||||||
if buf[n-1] == '\n' {
|
|
||||||
n--
|
|
||||||
}
|
|
||||||
ret = append(ret, buf[:n]...)
|
|
||||||
if n < len(buf) {
|
|
||||||
break
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
return ret, nil
|
if err == nil {
|
||||||
|
t.remainder = t.inBuf[:n+len(t.remainder)]
|
||||||
|
rest := t.remainder
|
||||||
|
lineOk := false
|
||||||
|
for !lineOk {
|
||||||
|
var key int
|
||||||
|
key, rest = bytesToKey(rest)
|
||||||
|
if key < 0 {
|
||||||
|
break
|
||||||
|
}
|
||||||
|
if key == keyCtrlD {
|
||||||
|
return "", io.EOF
|
||||||
|
}
|
||||||
|
line, lineOk = t.handleKey(key)
|
||||||
|
}
|
||||||
|
if len(rest) > 0 {
|
||||||
|
n := copy(t.inBuf[:], rest)
|
||||||
|
t.remainder = t.inBuf[:n]
|
||||||
|
} else {
|
||||||
|
t.remainder = nil
|
||||||
|
}
|
||||||
|
t.c.Write(t.outBuf)
|
||||||
|
t.outBuf = t.outBuf[:0]
|
||||||
|
if lineOk {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
}
|
||||||
|
panic("unreachable")
|
||||||
|
}
|
||||||
|
|
||||||
|
func (t *Terminal) SetSize(width, height int) {
|
||||||
|
t.termWidth, t.termHeight = width, height
|
||||||
}
|
}
|
||||||
|
|
|
@ -41,7 +41,7 @@ func (c *MockTerminal) Write(data []byte) (n int, err error) {
|
||||||
|
|
||||||
func TestClose(t *testing.T) {
|
func TestClose(t *testing.T) {
|
||||||
c := &MockTerminal{}
|
c := &MockTerminal{}
|
||||||
ss := NewShell(c, "> ")
|
ss := NewTerminal(c, "> ")
|
||||||
line, err := ss.ReadLine()
|
line, err := ss.ReadLine()
|
||||||
if line != "" {
|
if line != "" {
|
||||||
t.Errorf("Expected empty line but got: %s", line)
|
t.Errorf("Expected empty line but got: %s", line)
|
||||||
|
@ -95,7 +95,7 @@ func TestKeyPresses(t *testing.T) {
|
||||||
toSend: []byte(test.in),
|
toSend: []byte(test.in),
|
||||||
bytesPerRead: j,
|
bytesPerRead: j,
|
||||||
}
|
}
|
||||||
ss := NewShell(c, "> ")
|
ss := NewTerminal(c, "> ")
|
||||||
line, err := ss.ReadLine()
|
line, err := ss.ReadLine()
|
||||||
if line != test.line {
|
if line != test.line {
|
||||||
t.Errorf("Line resulting from test %d (%d bytes per read) was '%s', expected '%s'", i, j, line, test.line)
|
t.Errorf("Line resulting from test %d (%d bytes per read) was '%s', expected '%s'", i, j, line, test.line)
|
|
@ -0,0 +1,102 @@
|
||||||
|
// Copyright 2011 The Go Authors. All rights reserved.
|
||||||
|
// Use of this source code is governed by a BSD-style
|
||||||
|
// license that can be found in the LICENSE file.
|
||||||
|
|
||||||
|
// Package terminal provides support functions for dealing with terminals, as
|
||||||
|
// commonly found on UNIX systems.
|
||||||
|
//
|
||||||
|
// Putting a terminal into raw mode is the most common requirement:
|
||||||
|
//
|
||||||
|
// oldState, err := terminal.MakeRaw(0)
|
||||||
|
// if err != nil {
|
||||||
|
// panic(err.String())
|
||||||
|
// }
|
||||||
|
// defer terminal.Restore(0, oldState)
|
||||||
|
package terminal
|
||||||
|
|
||||||
|
import (
|
||||||
|
"io"
|
||||||
|
"syscall"
|
||||||
|
)
|
||||||
|
|
||||||
|
// State contains the state of a terminal.
|
||||||
|
type State struct {
|
||||||
|
termios syscall.Termios
|
||||||
|
}
|
||||||
|
|
||||||
|
// IsTerminal returns true if the given file descriptor is a terminal.
|
||||||
|
func IsTerminal(fd int) bool {
|
||||||
|
var termios syscall.Termios
|
||||||
|
err := syscall.Tcgetattr(fd, &termios)
|
||||||
|
return err == nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// MakeRaw put the terminal connected to the given file descriptor into raw
|
||||||
|
// mode and returns the previous state of the terminal so that it can be
|
||||||
|
// restored.
|
||||||
|
func MakeRaw(fd int) (*State, error) {
|
||||||
|
var oldState State
|
||||||
|
if err := syscall.Tcgetattr(fd, &oldState.termios); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
newState := oldState.termios
|
||||||
|
newState.Iflag &^= syscall.ISTRIP | syscall.INLCR | syscall.ICRNL | syscall.IGNCR | syscall.IXON | syscall.IXOFF
|
||||||
|
newState.Lflag &^= syscall.ECHO | syscall.ICANON | syscall.ISIG
|
||||||
|
if err := syscall.Tcsetattr(fd, syscall.TCSANOW, &newState); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
return &oldState, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Restore restores the terminal connected to the given file descriptor to a
|
||||||
|
// previous state.
|
||||||
|
func Restore(fd int, state *State) error {
|
||||||
|
err := syscall.Tcsetattr(fd, syscall.TCSANOW, &state.termios)
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
// ReadPassword reads a line of input from a terminal without local echo. This
|
||||||
|
// is commonly used for inputting passwords and other sensitive data. The slice
|
||||||
|
// returned does not include the \n.
|
||||||
|
func ReadPassword(fd int) ([]byte, error) {
|
||||||
|
var oldState syscall.Termios
|
||||||
|
if err := syscall.Tcgetattr(fd, &oldState); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
newState := oldState
|
||||||
|
newState.Lflag &^= syscall.ECHO
|
||||||
|
if err := syscall.Tcsetattr(fd, syscall.TCSANOW, &newState); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
defer func() {
|
||||||
|
syscall.Tcsetattr(fd, syscall.TCSANOW, &oldState)
|
||||||
|
}()
|
||||||
|
|
||||||
|
var buf [16]byte
|
||||||
|
var ret []byte
|
||||||
|
for {
|
||||||
|
n, err := syscall.Read(fd, buf[:])
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
if n == 0 {
|
||||||
|
if len(ret) == 0 {
|
||||||
|
return nil, io.EOF
|
||||||
|
}
|
||||||
|
break
|
||||||
|
}
|
||||||
|
if buf[n-1] == '\n' {
|
||||||
|
n--
|
||||||
|
}
|
||||||
|
ret = append(ret, buf[:n]...)
|
||||||
|
if n < len(buf) {
|
||||||
|
break
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return ret, nil
|
||||||
|
}
|
|
@ -357,6 +357,10 @@ var fmttests = []struct {
|
||||||
{"%#v", map[string]B{"a": {1, 2}}, `map[string] fmt_test.B{"a":fmt_test.B{I:1, j:2}}`},
|
{"%#v", map[string]B{"a": {1, 2}}, `map[string] fmt_test.B{"a":fmt_test.B{I:1, j:2}}`},
|
||||||
{"%#v", []string{"a", "b"}, `[]string{"a", "b"}`},
|
{"%#v", []string{"a", "b"}, `[]string{"a", "b"}`},
|
||||||
{"%#v", SI{}, `fmt_test.SI{I:interface {}(nil)}`},
|
{"%#v", SI{}, `fmt_test.SI{I:interface {}(nil)}`},
|
||||||
|
{"%#v", []int(nil), `[]int(nil)`},
|
||||||
|
{"%#v", []int{}, `[]int{}`},
|
||||||
|
{"%#v", map[int]byte(nil), `map[int] uint8(nil)`},
|
||||||
|
{"%#v", map[int]byte{}, `map[int] uint8{}`},
|
||||||
|
|
||||||
// slices with other formats
|
// slices with other formats
|
||||||
{"%#x", []int{1, 2, 15}, `[0x1 0x2 0xf]`},
|
{"%#x", []int{1, 2, 15}, `[0x1 0x2 0xf]`},
|
||||||
|
|
|
@ -795,6 +795,10 @@ BigSwitch:
|
||||||
case reflect.Map:
|
case reflect.Map:
|
||||||
if goSyntax {
|
if goSyntax {
|
||||||
p.buf.WriteString(f.Type().String())
|
p.buf.WriteString(f.Type().String())
|
||||||
|
if f.IsNil() {
|
||||||
|
p.buf.WriteString("(nil)")
|
||||||
|
break
|
||||||
|
}
|
||||||
p.buf.WriteByte('{')
|
p.buf.WriteByte('{')
|
||||||
} else {
|
} else {
|
||||||
p.buf.Write(mapBytes)
|
p.buf.Write(mapBytes)
|
||||||
|
@ -873,6 +877,10 @@ BigSwitch:
|
||||||
}
|
}
|
||||||
if goSyntax {
|
if goSyntax {
|
||||||
p.buf.WriteString(value.Type().String())
|
p.buf.WriteString(value.Type().String())
|
||||||
|
if f.IsNil() {
|
||||||
|
p.buf.WriteString("(nil)")
|
||||||
|
break
|
||||||
|
}
|
||||||
p.buf.WriteByte('{')
|
p.buf.WriteByte('{')
|
||||||
} else {
|
} else {
|
||||||
p.buf.WriteByte('[')
|
p.buf.WriteByte('[')
|
||||||
|
|
|
@ -324,7 +324,7 @@ var x, y Xs
|
||||||
var z IntString
|
var z IntString
|
||||||
|
|
||||||
var multiTests = []ScanfMultiTest{
|
var multiTests = []ScanfMultiTest{
|
||||||
{"", "", nil, nil, ""},
|
{"", "", []interface{}{}, []interface{}{}, ""},
|
||||||
{"%d", "23", args(&i), args(23), ""},
|
{"%d", "23", args(&i), args(23), ""},
|
||||||
{"%2s%3s", "22333", args(&s, &t), args("22", "333"), ""},
|
{"%2s%3s", "22333", args(&s, &t), args("22", "333"), ""},
|
||||||
{"%2d%3d", "44555", args(&i, &j), args(44, 555), ""},
|
{"%2d%3d", "44555", args(&i, &j), args(44, 555), ""},
|
||||||
|
@ -378,7 +378,7 @@ func testScan(name string, t *testing.T, scan func(r io.Reader, a ...interface{}
|
||||||
}
|
}
|
||||||
val := v.Interface()
|
val := v.Interface()
|
||||||
if !reflect.DeepEqual(val, test.out) {
|
if !reflect.DeepEqual(val, test.out) {
|
||||||
t.Errorf("%s scanning %q: expected %v got %v, type %T", name, test.text, test.out, val, val)
|
t.Errorf("%s scanning %q: expected %#v got %#v, type %T", name, test.text, test.out, val, val)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -417,7 +417,7 @@ func TestScanf(t *testing.T) {
|
||||||
}
|
}
|
||||||
val := v.Interface()
|
val := v.Interface()
|
||||||
if !reflect.DeepEqual(val, test.out) {
|
if !reflect.DeepEqual(val, test.out) {
|
||||||
t.Errorf("scanning (%q, %q): expected %v got %v, type %T", test.format, test.text, test.out, val, val)
|
t.Errorf("scanning (%q, %q): expected %#v got %#v, type %T", test.format, test.text, test.out, val, val)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -520,7 +520,7 @@ func testScanfMulti(name string, t *testing.T) {
|
||||||
}
|
}
|
||||||
result := resultVal.Interface()
|
result := resultVal.Interface()
|
||||||
if !reflect.DeepEqual(result, test.out) {
|
if !reflect.DeepEqual(result, test.out) {
|
||||||
t.Errorf("scanning (%q, %q): expected %v got %v", test.format, test.text, test.out, result)
|
t.Errorf("scanning (%q, %q): expected %#v got %#v", test.format, test.text, test.out, result)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -412,29 +412,29 @@ func (x *ChanType) End() token.Pos { return x.Value.End() }
|
||||||
// exprNode() ensures that only expression/type nodes can be
|
// exprNode() ensures that only expression/type nodes can be
|
||||||
// assigned to an ExprNode.
|
// assigned to an ExprNode.
|
||||||
//
|
//
|
||||||
func (x *BadExpr) exprNode() {}
|
func (*BadExpr) exprNode() {}
|
||||||
func (x *Ident) exprNode() {}
|
func (*Ident) exprNode() {}
|
||||||
func (x *Ellipsis) exprNode() {}
|
func (*Ellipsis) exprNode() {}
|
||||||
func (x *BasicLit) exprNode() {}
|
func (*BasicLit) exprNode() {}
|
||||||
func (x *FuncLit) exprNode() {}
|
func (*FuncLit) exprNode() {}
|
||||||
func (x *CompositeLit) exprNode() {}
|
func (*CompositeLit) exprNode() {}
|
||||||
func (x *ParenExpr) exprNode() {}
|
func (*ParenExpr) exprNode() {}
|
||||||
func (x *SelectorExpr) exprNode() {}
|
func (*SelectorExpr) exprNode() {}
|
||||||
func (x *IndexExpr) exprNode() {}
|
func (*IndexExpr) exprNode() {}
|
||||||
func (x *SliceExpr) exprNode() {}
|
func (*SliceExpr) exprNode() {}
|
||||||
func (x *TypeAssertExpr) exprNode() {}
|
func (*TypeAssertExpr) exprNode() {}
|
||||||
func (x *CallExpr) exprNode() {}
|
func (*CallExpr) exprNode() {}
|
||||||
func (x *StarExpr) exprNode() {}
|
func (*StarExpr) exprNode() {}
|
||||||
func (x *UnaryExpr) exprNode() {}
|
func (*UnaryExpr) exprNode() {}
|
||||||
func (x *BinaryExpr) exprNode() {}
|
func (*BinaryExpr) exprNode() {}
|
||||||
func (x *KeyValueExpr) exprNode() {}
|
func (*KeyValueExpr) exprNode() {}
|
||||||
|
|
||||||
func (x *ArrayType) exprNode() {}
|
func (*ArrayType) exprNode() {}
|
||||||
func (x *StructType) exprNode() {}
|
func (*StructType) exprNode() {}
|
||||||
func (x *FuncType) exprNode() {}
|
func (*FuncType) exprNode() {}
|
||||||
func (x *InterfaceType) exprNode() {}
|
func (*InterfaceType) exprNode() {}
|
||||||
func (x *MapType) exprNode() {}
|
func (*MapType) exprNode() {}
|
||||||
func (x *ChanType) exprNode() {}
|
func (*ChanType) exprNode() {}
|
||||||
|
|
||||||
// ----------------------------------------------------------------------------
|
// ----------------------------------------------------------------------------
|
||||||
// Convenience functions for Idents
|
// Convenience functions for Idents
|
||||||
|
@ -711,27 +711,27 @@ func (s *RangeStmt) End() token.Pos { return s.Body.End() }
|
||||||
// stmtNode() ensures that only statement nodes can be
|
// stmtNode() ensures that only statement nodes can be
|
||||||
// assigned to a StmtNode.
|
// assigned to a StmtNode.
|
||||||
//
|
//
|
||||||
func (s *BadStmt) stmtNode() {}
|
func (*BadStmt) stmtNode() {}
|
||||||
func (s *DeclStmt) stmtNode() {}
|
func (*DeclStmt) stmtNode() {}
|
||||||
func (s *EmptyStmt) stmtNode() {}
|
func (*EmptyStmt) stmtNode() {}
|
||||||
func (s *LabeledStmt) stmtNode() {}
|
func (*LabeledStmt) stmtNode() {}
|
||||||
func (s *ExprStmt) stmtNode() {}
|
func (*ExprStmt) stmtNode() {}
|
||||||
func (s *SendStmt) stmtNode() {}
|
func (*SendStmt) stmtNode() {}
|
||||||
func (s *IncDecStmt) stmtNode() {}
|
func (*IncDecStmt) stmtNode() {}
|
||||||
func (s *AssignStmt) stmtNode() {}
|
func (*AssignStmt) stmtNode() {}
|
||||||
func (s *GoStmt) stmtNode() {}
|
func (*GoStmt) stmtNode() {}
|
||||||
func (s *DeferStmt) stmtNode() {}
|
func (*DeferStmt) stmtNode() {}
|
||||||
func (s *ReturnStmt) stmtNode() {}
|
func (*ReturnStmt) stmtNode() {}
|
||||||
func (s *BranchStmt) stmtNode() {}
|
func (*BranchStmt) stmtNode() {}
|
||||||
func (s *BlockStmt) stmtNode() {}
|
func (*BlockStmt) stmtNode() {}
|
||||||
func (s *IfStmt) stmtNode() {}
|
func (*IfStmt) stmtNode() {}
|
||||||
func (s *CaseClause) stmtNode() {}
|
func (*CaseClause) stmtNode() {}
|
||||||
func (s *SwitchStmt) stmtNode() {}
|
func (*SwitchStmt) stmtNode() {}
|
||||||
func (s *TypeSwitchStmt) stmtNode() {}
|
func (*TypeSwitchStmt) stmtNode() {}
|
||||||
func (s *CommClause) stmtNode() {}
|
func (*CommClause) stmtNode() {}
|
||||||
func (s *SelectStmt) stmtNode() {}
|
func (*SelectStmt) stmtNode() {}
|
||||||
func (s *ForStmt) stmtNode() {}
|
func (*ForStmt) stmtNode() {}
|
||||||
func (s *RangeStmt) stmtNode() {}
|
func (*RangeStmt) stmtNode() {}
|
||||||
|
|
||||||
// ----------------------------------------------------------------------------
|
// ----------------------------------------------------------------------------
|
||||||
// Declarations
|
// Declarations
|
||||||
|
@ -807,9 +807,9 @@ func (s *TypeSpec) End() token.Pos { return s.Type.End() }
|
||||||
// specNode() ensures that only spec nodes can be
|
// specNode() ensures that only spec nodes can be
|
||||||
// assigned to a Spec.
|
// assigned to a Spec.
|
||||||
//
|
//
|
||||||
func (s *ImportSpec) specNode() {}
|
func (*ImportSpec) specNode() {}
|
||||||
func (s *ValueSpec) specNode() {}
|
func (*ValueSpec) specNode() {}
|
||||||
func (s *TypeSpec) specNode() {}
|
func (*TypeSpec) specNode() {}
|
||||||
|
|
||||||
// A declaration is represented by one of the following declaration nodes.
|
// A declaration is represented by one of the following declaration nodes.
|
||||||
//
|
//
|
||||||
|
@ -875,9 +875,9 @@ func (d *FuncDecl) End() token.Pos {
|
||||||
// declNode() ensures that only declaration nodes can be
|
// declNode() ensures that only declaration nodes can be
|
||||||
// assigned to a DeclNode.
|
// assigned to a DeclNode.
|
||||||
//
|
//
|
||||||
func (d *BadDecl) declNode() {}
|
func (*BadDecl) declNode() {}
|
||||||
func (d *GenDecl) declNode() {}
|
func (*GenDecl) declNode() {}
|
||||||
func (d *FuncDecl) declNode() {}
|
func (*FuncDecl) declNode() {}
|
||||||
|
|
||||||
// ----------------------------------------------------------------------------
|
// ----------------------------------------------------------------------------
|
||||||
// Files and packages
|
// Files and packages
|
||||||
|
|
|
@ -24,7 +24,7 @@ func exportFilter(name string) bool {
|
||||||
// it returns false otherwise.
|
// it returns false otherwise.
|
||||||
//
|
//
|
||||||
func FileExports(src *File) bool {
|
func FileExports(src *File) bool {
|
||||||
return FilterFile(src, exportFilter)
|
return filterFile(src, exportFilter, true)
|
||||||
}
|
}
|
||||||
|
|
||||||
// PackageExports trims the AST for a Go package in place such that
|
// PackageExports trims the AST for a Go package in place such that
|
||||||
|
@ -35,7 +35,7 @@ func FileExports(src *File) bool {
|
||||||
// it returns false otherwise.
|
// it returns false otherwise.
|
||||||
//
|
//
|
||||||
func PackageExports(pkg *Package) bool {
|
func PackageExports(pkg *Package) bool {
|
||||||
return FilterPackage(pkg, exportFilter)
|
return filterPackage(pkg, exportFilter, true)
|
||||||
}
|
}
|
||||||
|
|
||||||
// ----------------------------------------------------------------------------
|
// ----------------------------------------------------------------------------
|
||||||
|
@ -72,7 +72,7 @@ func fieldName(x Expr) *Ident {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func filterFieldList(fields *FieldList, filter Filter) (removedFields bool) {
|
func filterFieldList(fields *FieldList, filter Filter, export bool) (removedFields bool) {
|
||||||
if fields == nil {
|
if fields == nil {
|
||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
|
@ -93,8 +93,8 @@ func filterFieldList(fields *FieldList, filter Filter) (removedFields bool) {
|
||||||
keepField = len(f.Names) > 0
|
keepField = len(f.Names) > 0
|
||||||
}
|
}
|
||||||
if keepField {
|
if keepField {
|
||||||
if filter == exportFilter {
|
if export {
|
||||||
filterType(f.Type, filter)
|
filterType(f.Type, filter, export)
|
||||||
}
|
}
|
||||||
list[j] = f
|
list[j] = f
|
||||||
j++
|
j++
|
||||||
|
@ -107,84 +107,84 @@ func filterFieldList(fields *FieldList, filter Filter) (removedFields bool) {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
func filterParamList(fields *FieldList, filter Filter) bool {
|
func filterParamList(fields *FieldList, filter Filter, export bool) bool {
|
||||||
if fields == nil {
|
if fields == nil {
|
||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
var b bool
|
var b bool
|
||||||
for _, f := range fields.List {
|
for _, f := range fields.List {
|
||||||
if filterType(f.Type, filter) {
|
if filterType(f.Type, filter, export) {
|
||||||
b = true
|
b = true
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
return b
|
return b
|
||||||
}
|
}
|
||||||
|
|
||||||
func filterType(typ Expr, f Filter) bool {
|
func filterType(typ Expr, f Filter, export bool) bool {
|
||||||
switch t := typ.(type) {
|
switch t := typ.(type) {
|
||||||
case *Ident:
|
case *Ident:
|
||||||
return f(t.Name)
|
return f(t.Name)
|
||||||
case *ParenExpr:
|
case *ParenExpr:
|
||||||
return filterType(t.X, f)
|
return filterType(t.X, f, export)
|
||||||
case *ArrayType:
|
case *ArrayType:
|
||||||
return filterType(t.Elt, f)
|
return filterType(t.Elt, f, export)
|
||||||
case *StructType:
|
case *StructType:
|
||||||
if filterFieldList(t.Fields, f) {
|
if filterFieldList(t.Fields, f, export) {
|
||||||
t.Incomplete = true
|
t.Incomplete = true
|
||||||
}
|
}
|
||||||
return len(t.Fields.List) > 0
|
return len(t.Fields.List) > 0
|
||||||
case *FuncType:
|
case *FuncType:
|
||||||
b1 := filterParamList(t.Params, f)
|
b1 := filterParamList(t.Params, f, export)
|
||||||
b2 := filterParamList(t.Results, f)
|
b2 := filterParamList(t.Results, f, export)
|
||||||
return b1 || b2
|
return b1 || b2
|
||||||
case *InterfaceType:
|
case *InterfaceType:
|
||||||
if filterFieldList(t.Methods, f) {
|
if filterFieldList(t.Methods, f, export) {
|
||||||
t.Incomplete = true
|
t.Incomplete = true
|
||||||
}
|
}
|
||||||
return len(t.Methods.List) > 0
|
return len(t.Methods.List) > 0
|
||||||
case *MapType:
|
case *MapType:
|
||||||
b1 := filterType(t.Key, f)
|
b1 := filterType(t.Key, f, export)
|
||||||
b2 := filterType(t.Value, f)
|
b2 := filterType(t.Value, f, export)
|
||||||
return b1 || b2
|
return b1 || b2
|
||||||
case *ChanType:
|
case *ChanType:
|
||||||
return filterType(t.Value, f)
|
return filterType(t.Value, f, export)
|
||||||
}
|
}
|
||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
|
|
||||||
func filterSpec(spec Spec, f Filter) bool {
|
func filterSpec(spec Spec, f Filter, export bool) bool {
|
||||||
switch s := spec.(type) {
|
switch s := spec.(type) {
|
||||||
case *ValueSpec:
|
case *ValueSpec:
|
||||||
s.Names = filterIdentList(s.Names, f)
|
s.Names = filterIdentList(s.Names, f)
|
||||||
if len(s.Names) > 0 {
|
if len(s.Names) > 0 {
|
||||||
if f == exportFilter {
|
if export {
|
||||||
filterType(s.Type, f)
|
filterType(s.Type, f, export)
|
||||||
}
|
}
|
||||||
return true
|
return true
|
||||||
}
|
}
|
||||||
case *TypeSpec:
|
case *TypeSpec:
|
||||||
if f(s.Name.Name) {
|
if f(s.Name.Name) {
|
||||||
if f == exportFilter {
|
if export {
|
||||||
filterType(s.Type, f)
|
filterType(s.Type, f, export)
|
||||||
}
|
}
|
||||||
return true
|
return true
|
||||||
}
|
}
|
||||||
if f != exportFilter {
|
if !export {
|
||||||
// For general filtering (not just exports),
|
// For general filtering (not just exports),
|
||||||
// filter type even if name is not filtered
|
// filter type even if name is not filtered
|
||||||
// out.
|
// out.
|
||||||
// If the type contains filtered elements,
|
// If the type contains filtered elements,
|
||||||
// keep the declaration.
|
// keep the declaration.
|
||||||
return filterType(s.Type, f)
|
return filterType(s.Type, f, export)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
|
|
||||||
func filterSpecList(list []Spec, f Filter) []Spec {
|
func filterSpecList(list []Spec, f Filter, export bool) []Spec {
|
||||||
j := 0
|
j := 0
|
||||||
for _, s := range list {
|
for _, s := range list {
|
||||||
if filterSpec(s, f) {
|
if filterSpec(s, f, export) {
|
||||||
list[j] = s
|
list[j] = s
|
||||||
j++
|
j++
|
||||||
}
|
}
|
||||||
|
@ -200,9 +200,13 @@ func filterSpecList(list []Spec, f Filter) []Spec {
|
||||||
// filtering; it returns false otherwise.
|
// filtering; it returns false otherwise.
|
||||||
//
|
//
|
||||||
func FilterDecl(decl Decl, f Filter) bool {
|
func FilterDecl(decl Decl, f Filter) bool {
|
||||||
|
return filterDecl(decl, f, false)
|
||||||
|
}
|
||||||
|
|
||||||
|
func filterDecl(decl Decl, f Filter, export bool) bool {
|
||||||
switch d := decl.(type) {
|
switch d := decl.(type) {
|
||||||
case *GenDecl:
|
case *GenDecl:
|
||||||
d.Specs = filterSpecList(d.Specs, f)
|
d.Specs = filterSpecList(d.Specs, f, export)
|
||||||
return len(d.Specs) > 0
|
return len(d.Specs) > 0
|
||||||
case *FuncDecl:
|
case *FuncDecl:
|
||||||
return f(d.Name.Name)
|
return f(d.Name.Name)
|
||||||
|
@ -221,9 +225,13 @@ func FilterDecl(decl Decl, f Filter) bool {
|
||||||
// left after filtering; it returns false otherwise.
|
// left after filtering; it returns false otherwise.
|
||||||
//
|
//
|
||||||
func FilterFile(src *File, f Filter) bool {
|
func FilterFile(src *File, f Filter) bool {
|
||||||
|
return filterFile(src, f, false)
|
||||||
|
}
|
||||||
|
|
||||||
|
func filterFile(src *File, f Filter, export bool) bool {
|
||||||
j := 0
|
j := 0
|
||||||
for _, d := range src.Decls {
|
for _, d := range src.Decls {
|
||||||
if FilterDecl(d, f) {
|
if filterDecl(d, f, export) {
|
||||||
src.Decls[j] = d
|
src.Decls[j] = d
|
||||||
j++
|
j++
|
||||||
}
|
}
|
||||||
|
@ -244,9 +252,13 @@ func FilterFile(src *File, f Filter) bool {
|
||||||
// left after filtering; it returns false otherwise.
|
// left after filtering; it returns false otherwise.
|
||||||
//
|
//
|
||||||
func FilterPackage(pkg *Package, f Filter) bool {
|
func FilterPackage(pkg *Package, f Filter) bool {
|
||||||
|
return filterPackage(pkg, f, false)
|
||||||
|
}
|
||||||
|
|
||||||
|
func filterPackage(pkg *Package, f Filter, export bool) bool {
|
||||||
hasDecls := false
|
hasDecls := false
|
||||||
for _, src := range pkg.Files {
|
for _, src := range pkg.Files {
|
||||||
if FilterFile(src, f) {
|
if filterFile(src, f, export) {
|
||||||
hasDecls = true
|
hasDecls = true
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -40,6 +40,7 @@ var buildPkgs = []struct {
|
||||||
GoFiles: []string{"main.go"},
|
GoFiles: []string{"main.go"},
|
||||||
Package: "main",
|
Package: "main",
|
||||||
Imports: []string{"go/build/pkgtest"},
|
Imports: []string{"go/build/pkgtest"},
|
||||||
|
TestImports: []string{},
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
|
@ -48,6 +49,7 @@ var buildPkgs = []struct {
|
||||||
CgoFiles: []string{"cgotest.go"},
|
CgoFiles: []string{"cgotest.go"},
|
||||||
CFiles: []string{"cgotest.c"},
|
CFiles: []string{"cgotest.c"},
|
||||||
Imports: []string{"C", "unsafe"},
|
Imports: []string{"C", "unsafe"},
|
||||||
|
TestImports: []string{},
|
||||||
Package: "cgotest",
|
Package: "cgotest",
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
|
|
|
@ -13,6 +13,8 @@ import (
|
||||||
"io"
|
"io"
|
||||||
"os"
|
"os"
|
||||||
"path/filepath"
|
"path/filepath"
|
||||||
|
"strconv"
|
||||||
|
"strings"
|
||||||
"text/tabwriter"
|
"text/tabwriter"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -244,6 +246,8 @@ func (p *printer) writeItem(pos token.Position, data string) {
|
||||||
p.last = p.pos
|
p.last = p.pos
|
||||||
}
|
}
|
||||||
|
|
||||||
|
const linePrefix = "//line "
|
||||||
|
|
||||||
// writeCommentPrefix writes the whitespace before a comment.
|
// writeCommentPrefix writes the whitespace before a comment.
|
||||||
// If there is any pending whitespace, it consumes as much of
|
// If there is any pending whitespace, it consumes as much of
|
||||||
// it as is likely to help position the comment nicely.
|
// it as is likely to help position the comment nicely.
|
||||||
|
@ -252,7 +256,7 @@ func (p *printer) writeItem(pos token.Position, data string) {
|
||||||
// a group of comments (or nil), and isKeyword indicates if the
|
// a group of comments (or nil), and isKeyword indicates if the
|
||||||
// next item is a keyword.
|
// next item is a keyword.
|
||||||
//
|
//
|
||||||
func (p *printer) writeCommentPrefix(pos, next token.Position, prev *ast.Comment, isKeyword bool) {
|
func (p *printer) writeCommentPrefix(pos, next token.Position, prev, comment *ast.Comment, isKeyword bool) {
|
||||||
if p.written == 0 {
|
if p.written == 0 {
|
||||||
// the comment is the first item to be printed - don't write any whitespace
|
// the comment is the first item to be printed - don't write any whitespace
|
||||||
return
|
return
|
||||||
|
@ -337,6 +341,13 @@ func (p *printer) writeCommentPrefix(pos, next token.Position, prev *ast.Comment
|
||||||
}
|
}
|
||||||
p.writeWhitespace(j)
|
p.writeWhitespace(j)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// turn off indent if we're about to print a line directive.
|
||||||
|
indent := p.indent
|
||||||
|
if strings.HasPrefix(comment.Text, linePrefix) {
|
||||||
|
p.indent = 0
|
||||||
|
}
|
||||||
|
|
||||||
// use formfeeds to break columns before a comment;
|
// use formfeeds to break columns before a comment;
|
||||||
// this is analogous to using formfeeds to separate
|
// this is analogous to using formfeeds to separate
|
||||||
// individual lines of /*-style comments - but make
|
// individual lines of /*-style comments - but make
|
||||||
|
@ -347,6 +358,7 @@ func (p *printer) writeCommentPrefix(pos, next token.Position, prev *ast.Comment
|
||||||
n = 1
|
n = 1
|
||||||
}
|
}
|
||||||
p.writeNewlines(n, true)
|
p.writeNewlines(n, true)
|
||||||
|
p.indent = indent
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -526,6 +538,26 @@ func stripCommonPrefix(lines [][]byte) {
|
||||||
func (p *printer) writeComment(comment *ast.Comment) {
|
func (p *printer) writeComment(comment *ast.Comment) {
|
||||||
text := comment.Text
|
text := comment.Text
|
||||||
|
|
||||||
|
if strings.HasPrefix(text, linePrefix) {
|
||||||
|
pos := strings.TrimSpace(text[len(linePrefix):])
|
||||||
|
i := strings.LastIndex(pos, ":")
|
||||||
|
if i >= 0 {
|
||||||
|
// The line directive we are about to print changed
|
||||||
|
// the Filename and Line number used by go/token
|
||||||
|
// as it was reading the input originally.
|
||||||
|
// In order to match the original input, we have to
|
||||||
|
// update our own idea of the file and line number
|
||||||
|
// accordingly, after printing the directive.
|
||||||
|
file := pos[:i]
|
||||||
|
line, _ := strconv.Atoi(string(pos[i+1:]))
|
||||||
|
defer func() {
|
||||||
|
p.pos.Filename = string(file)
|
||||||
|
p.pos.Line = line
|
||||||
|
p.pos.Column = 1
|
||||||
|
}()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
// shortcut common case of //-style comments
|
// shortcut common case of //-style comments
|
||||||
if text[1] == '/' {
|
if text[1] == '/' {
|
||||||
p.writeItem(p.fset.Position(comment.Pos()), p.escape(text))
|
p.writeItem(p.fset.Position(comment.Pos()), p.escape(text))
|
||||||
|
@ -599,7 +631,7 @@ func (p *printer) intersperseComments(next token.Position, tok token.Token) (dro
|
||||||
var last *ast.Comment
|
var last *ast.Comment
|
||||||
for ; p.commentBefore(next); p.cindex++ {
|
for ; p.commentBefore(next); p.cindex++ {
|
||||||
for _, c := range p.comments[p.cindex].List {
|
for _, c := range p.comments[p.cindex].List {
|
||||||
p.writeCommentPrefix(p.fset.Position(c.Pos()), next, last, tok.IsKeyword())
|
p.writeCommentPrefix(p.fset.Position(c.Pos()), next, last, c, tok.IsKeyword())
|
||||||
p.writeComment(c)
|
p.writeComment(c)
|
||||||
last = c
|
last = c
|
||||||
}
|
}
|
||||||
|
|
|
@ -37,7 +37,7 @@ lower-cased, and attributes are collected into a []Attribute. For example:
|
||||||
for {
|
for {
|
||||||
if z.Next() == html.ErrorToken {
|
if z.Next() == html.ErrorToken {
|
||||||
// Returning io.EOF indicates success.
|
// Returning io.EOF indicates success.
|
||||||
return z.Error()
|
return z.Err()
|
||||||
}
|
}
|
||||||
emitToken(z.Token())
|
emitToken(z.Token())
|
||||||
}
|
}
|
||||||
|
@ -51,7 +51,7 @@ call to Next. For example, to extract an HTML page's anchor text:
|
||||||
tt := z.Next()
|
tt := z.Next()
|
||||||
switch tt {
|
switch tt {
|
||||||
case ErrorToken:
|
case ErrorToken:
|
||||||
return z.Error()
|
return z.Err()
|
||||||
case TextToken:
|
case TextToken:
|
||||||
if depth > 0 {
|
if depth > 0 {
|
||||||
// emitBytes should copy the []byte it receives,
|
// emitBytes should copy the []byte it receives,
|
||||||
|
|
File diff suppressed because it is too large
Load Diff
|
@ -133,8 +133,8 @@ func TestParser(t *testing.T) {
|
||||||
n int
|
n int
|
||||||
}{
|
}{
|
||||||
// TODO(nigeltao): Process all the test cases from all the .dat files.
|
// TODO(nigeltao): Process all the test cases from all the .dat files.
|
||||||
{"tests1.dat", 92},
|
{"tests1.dat", -1},
|
||||||
{"tests2.dat", 0},
|
{"tests2.dat", 43},
|
||||||
{"tests3.dat", 0},
|
{"tests3.dat", 0},
|
||||||
}
|
}
|
||||||
for _, tf := range testFiles {
|
for _, tf := range testFiles {
|
||||||
|
@ -213,4 +213,8 @@ var renderTestBlacklist = map[string]bool{
|
||||||
// More cases of <a> being reparented:
|
// More cases of <a> being reparented:
|
||||||
`<a href="blah">aba<table><a href="foo">br<tr><td></td></tr>x</table>aoe`: true,
|
`<a href="blah">aba<table><a href="foo">br<tr><td></td></tr>x</table>aoe`: true,
|
||||||
`<a><table><a></table><p><a><div><a>`: true,
|
`<a><table><a></table><p><a><div><a>`: true,
|
||||||
|
`<a><table><td><a><table></table><a></tr><a></table><a>`: true,
|
||||||
|
// A <plaintext> element is reparented, putting it before a table.
|
||||||
|
// A <plaintext> element can't have anything after it in HTML.
|
||||||
|
`<table><plaintext><td>`: true,
|
||||||
}
|
}
|
||||||
|
|
|
@ -52,7 +52,19 @@ func Render(w io.Writer, n *Node) error {
|
||||||
return buf.Flush()
|
return buf.Flush()
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// plaintextAbort is returned from render1 when a <plaintext> element
|
||||||
|
// has been rendered. No more end tags should be rendered after that.
|
||||||
|
var plaintextAbort = errors.New("html: internal error (plaintext abort)")
|
||||||
|
|
||||||
func render(w writer, n *Node) error {
|
func render(w writer, n *Node) error {
|
||||||
|
err := render1(w, n)
|
||||||
|
if err == plaintextAbort {
|
||||||
|
err = nil
|
||||||
|
}
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
func render1(w writer, n *Node) error {
|
||||||
// Render non-element nodes; these are the easy cases.
|
// Render non-element nodes; these are the easy cases.
|
||||||
switch n.Type {
|
switch n.Type {
|
||||||
case ErrorNode:
|
case ErrorNode:
|
||||||
|
@ -61,7 +73,7 @@ func render(w writer, n *Node) error {
|
||||||
return escape(w, n.Data)
|
return escape(w, n.Data)
|
||||||
case DocumentNode:
|
case DocumentNode:
|
||||||
for _, c := range n.Child {
|
for _, c := range n.Child {
|
||||||
if err := render(w, c); err != nil {
|
if err := render1(w, c); err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -128,7 +140,7 @@ func render(w writer, n *Node) error {
|
||||||
|
|
||||||
// Render any child nodes.
|
// Render any child nodes.
|
||||||
switch n.Data {
|
switch n.Data {
|
||||||
case "noembed", "noframes", "noscript", "script", "style":
|
case "noembed", "noframes", "noscript", "plaintext", "script", "style":
|
||||||
for _, c := range n.Child {
|
for _, c := range n.Child {
|
||||||
if c.Type != TextNode {
|
if c.Type != TextNode {
|
||||||
return fmt.Errorf("html: raw text element <%s> has non-text child node", n.Data)
|
return fmt.Errorf("html: raw text element <%s> has non-text child node", n.Data)
|
||||||
|
@ -137,18 +149,23 @@ func render(w writer, n *Node) error {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
if n.Data == "plaintext" {
|
||||||
|
// Don't render anything else. <plaintext> must be the
|
||||||
|
// last element in the file, with no closing tag.
|
||||||
|
return plaintextAbort
|
||||||
|
}
|
||||||
case "textarea", "title":
|
case "textarea", "title":
|
||||||
for _, c := range n.Child {
|
for _, c := range n.Child {
|
||||||
if c.Type != TextNode {
|
if c.Type != TextNode {
|
||||||
return fmt.Errorf("html: RCDATA element <%s> has non-text child node", n.Data)
|
return fmt.Errorf("html: RCDATA element <%s> has non-text child node", n.Data)
|
||||||
}
|
}
|
||||||
if err := render(w, c); err != nil {
|
if err := render1(w, c); err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
default:
|
default:
|
||||||
for _, c := range n.Child {
|
for _, c := range n.Child {
|
||||||
if err := render(w, c); err != nil {
|
if err := render1(w, c); err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -6,6 +6,7 @@ package template
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"fmt"
|
"fmt"
|
||||||
|
"reflect"
|
||||||
)
|
)
|
||||||
|
|
||||||
// Strings of content from a trusted source.
|
// Strings of content from a trusted source.
|
||||||
|
@ -70,10 +71,25 @@ const (
|
||||||
contentTypeUnsafe
|
contentTypeUnsafe
|
||||||
)
|
)
|
||||||
|
|
||||||
|
// indirect returns the value, after dereferencing as many times
|
||||||
|
// as necessary to reach the base type (or nil).
|
||||||
|
func indirect(a interface{}) interface{} {
|
||||||
|
if t := reflect.TypeOf(a); t.Kind() != reflect.Ptr {
|
||||||
|
// Avoid creating a reflect.Value if it's not a pointer.
|
||||||
|
return a
|
||||||
|
}
|
||||||
|
v := reflect.ValueOf(a)
|
||||||
|
for v.Kind() == reflect.Ptr && !v.IsNil() {
|
||||||
|
v = v.Elem()
|
||||||
|
}
|
||||||
|
return v.Interface()
|
||||||
|
}
|
||||||
|
|
||||||
// stringify converts its arguments to a string and the type of the content.
|
// stringify converts its arguments to a string and the type of the content.
|
||||||
|
// All pointers are dereferenced, as in the text/template package.
|
||||||
func stringify(args ...interface{}) (string, contentType) {
|
func stringify(args ...interface{}) (string, contentType) {
|
||||||
if len(args) == 1 {
|
if len(args) == 1 {
|
||||||
switch s := args[0].(type) {
|
switch s := indirect(args[0]).(type) {
|
||||||
case string:
|
case string:
|
||||||
return s, contentTypePlain
|
return s, contentTypePlain
|
||||||
case CSS:
|
case CSS:
|
||||||
|
@ -90,5 +106,8 @@ func stringify(args ...interface{}) (string, contentType) {
|
||||||
return string(s), contentTypeURL
|
return string(s), contentTypeURL
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
for i, arg := range args {
|
||||||
|
args[i] = indirect(arg)
|
||||||
|
}
|
||||||
return fmt.Sprint(args...), contentTypePlain
|
return fmt.Sprint(args...), contentTypePlain
|
||||||
}
|
}
|
||||||
|
|
|
@ -28,7 +28,7 @@ func (x *goodMarshaler) MarshalJSON() ([]byte, error) {
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestEscape(t *testing.T) {
|
func TestEscape(t *testing.T) {
|
||||||
var data = struct {
|
data := struct {
|
||||||
F, T bool
|
F, T bool
|
||||||
C, G, H string
|
C, G, H string
|
||||||
A, E []string
|
A, E []string
|
||||||
|
@ -50,6 +50,7 @@ func TestEscape(t *testing.T) {
|
||||||
Z: nil,
|
Z: nil,
|
||||||
W: HTML(`¡<b class="foo">Hello</b>, <textarea>O'World</textarea>!`),
|
W: HTML(`¡<b class="foo">Hello</b>, <textarea>O'World</textarea>!`),
|
||||||
}
|
}
|
||||||
|
pdata := &data
|
||||||
|
|
||||||
tests := []struct {
|
tests := []struct {
|
||||||
name string
|
name string
|
||||||
|
@ -668,6 +669,15 @@ func TestEscape(t *testing.T) {
|
||||||
t.Errorf("%s: escaped output: want\n\t%q\ngot\n\t%q", test.name, w, g)
|
t.Errorf("%s: escaped output: want\n\t%q\ngot\n\t%q", test.name, w, g)
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
|
b.Reset()
|
||||||
|
if err := tmpl.Execute(b, pdata); err != nil {
|
||||||
|
t.Errorf("%s: template execution failed for pointer: %s", test.name, err)
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
if w, g := test.output, b.String(); w != g {
|
||||||
|
t.Errorf("%s: escaped output for pointer: want\n\t%q\ngot\n\t%q", test.name, w, g)
|
||||||
|
continue
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -1605,6 +1615,29 @@ func TestRedundantFuncs(t *testing.T) {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestIndirectPrint(t *testing.T) {
|
||||||
|
a := 3
|
||||||
|
ap := &a
|
||||||
|
b := "hello"
|
||||||
|
bp := &b
|
||||||
|
bpp := &bp
|
||||||
|
tmpl := Must(New("t").Parse(`{{.}}`))
|
||||||
|
var buf bytes.Buffer
|
||||||
|
err := tmpl.Execute(&buf, ap)
|
||||||
|
if err != nil {
|
||||||
|
t.Errorf("Unexpected error: %s", err)
|
||||||
|
} else if buf.String() != "3" {
|
||||||
|
t.Errorf(`Expected "3"; got %q`, buf.String())
|
||||||
|
}
|
||||||
|
buf.Reset()
|
||||||
|
err = tmpl.Execute(&buf, bpp)
|
||||||
|
if err != nil {
|
||||||
|
t.Errorf("Unexpected error: %s", err)
|
||||||
|
} else if buf.String() != "hello" {
|
||||||
|
t.Errorf(`Expected "hello"; got %q`, buf.String())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
func BenchmarkEscapedExecute(b *testing.B) {
|
func BenchmarkEscapedExecute(b *testing.B) {
|
||||||
tmpl := Must(New("t").Parse(`<a onclick="alert('{{.}}')">{{.}}</a>`))
|
tmpl := Must(New("t").Parse(`<a onclick="alert('{{.}}')">{{.}}</a>`))
|
||||||
var buf bytes.Buffer
|
var buf bytes.Buffer
|
||||||
|
|
|
@ -8,6 +8,7 @@ import (
|
||||||
"bytes"
|
"bytes"
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
"fmt"
|
"fmt"
|
||||||
|
"reflect"
|
||||||
"strings"
|
"strings"
|
||||||
"unicode/utf8"
|
"unicode/utf8"
|
||||||
)
|
)
|
||||||
|
@ -117,12 +118,24 @@ var regexpPrecederKeywords = map[string]bool{
|
||||||
"void": true,
|
"void": true,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
var jsonMarshalType = reflect.TypeOf((*json.Marshaler)(nil)).Elem()
|
||||||
|
|
||||||
|
// indirectToJSONMarshaler returns the value, after dereferencing as many times
|
||||||
|
// as necessary to reach the base type (or nil) or an implementation of json.Marshal.
|
||||||
|
func indirectToJSONMarshaler(a interface{}) interface{} {
|
||||||
|
v := reflect.ValueOf(a)
|
||||||
|
for !v.Type().Implements(jsonMarshalType) && v.Kind() == reflect.Ptr && !v.IsNil() {
|
||||||
|
v = v.Elem()
|
||||||
|
}
|
||||||
|
return v.Interface()
|
||||||
|
}
|
||||||
|
|
||||||
// jsValEscaper escapes its inputs to a JS Expression (section 11.14) that has
|
// jsValEscaper escapes its inputs to a JS Expression (section 11.14) that has
|
||||||
// nether side-effects nor free variables outside (NaN, Infinity).
|
// neither side-effects nor free variables outside (NaN, Infinity).
|
||||||
func jsValEscaper(args ...interface{}) string {
|
func jsValEscaper(args ...interface{}) string {
|
||||||
var a interface{}
|
var a interface{}
|
||||||
if len(args) == 1 {
|
if len(args) == 1 {
|
||||||
a = args[0]
|
a = indirectToJSONMarshaler(args[0])
|
||||||
switch t := a.(type) {
|
switch t := a.(type) {
|
||||||
case JS:
|
case JS:
|
||||||
return string(t)
|
return string(t)
|
||||||
|
@ -135,6 +148,9 @@ func jsValEscaper(args ...interface{}) string {
|
||||||
a = t.String()
|
a = t.String()
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
|
for i, arg := range args {
|
||||||
|
args[i] = indirectToJSONMarshaler(arg)
|
||||||
|
}
|
||||||
a = fmt.Sprint(args...)
|
a = fmt.Sprint(args...)
|
||||||
}
|
}
|
||||||
// TODO: detect cycles before calling Marshal which loops infinitely on
|
// TODO: detect cycles before calling Marshal which loops infinitely on
|
||||||
|
|
|
@ -401,14 +401,14 @@ func (z *Tokenizer) readStartTag() TokenType {
|
||||||
break
|
break
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
// Any "<noembed>", "<noframes>", "<noscript>", "<script>", "<style>",
|
// Any "<noembed>", "<noframes>", "<noscript>", "<plaintext", "<script>", "<style>",
|
||||||
// "<textarea>" or "<title>" tag flags the tokenizer's next token as raw.
|
// "<textarea>" or "<title>" tag flags the tokenizer's next token as raw.
|
||||||
// The tag name lengths of these special cases ranges in [5, 8].
|
// The tag name lengths of these special cases ranges in [5, 9].
|
||||||
if x := z.data.end - z.data.start; 5 <= x && x <= 8 {
|
if x := z.data.end - z.data.start; 5 <= x && x <= 9 {
|
||||||
switch z.buf[z.data.start] {
|
switch z.buf[z.data.start] {
|
||||||
case 'n', 's', 't', 'N', 'S', 'T':
|
case 'n', 'p', 's', 't', 'N', 'P', 'S', 'T':
|
||||||
switch s := strings.ToLower(string(z.buf[z.data.start:z.data.end])); s {
|
switch s := strings.ToLower(string(z.buf[z.data.start:z.data.end])); s {
|
||||||
case "noembed", "noframes", "noscript", "script", "style", "textarea", "title":
|
case "noembed", "noframes", "noscript", "plaintext", "script", "style", "textarea", "title":
|
||||||
z.rawTag = s
|
z.rawTag = s
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -551,10 +551,20 @@ func (z *Tokenizer) Next() TokenType {
|
||||||
z.data.start = z.raw.end
|
z.data.start = z.raw.end
|
||||||
z.data.end = z.raw.end
|
z.data.end = z.raw.end
|
||||||
if z.rawTag != "" {
|
if z.rawTag != "" {
|
||||||
|
if z.rawTag == "plaintext" {
|
||||||
|
// Read everything up to EOF.
|
||||||
|
for z.err == nil {
|
||||||
|
z.readByte()
|
||||||
|
}
|
||||||
|
z.textIsRaw = true
|
||||||
|
} else {
|
||||||
z.readRawOrRCDATA()
|
z.readRawOrRCDATA()
|
||||||
|
}
|
||||||
|
if z.data.end > z.data.start {
|
||||||
z.tt = TextToken
|
z.tt = TextToken
|
||||||
return z.tt
|
return z.tt
|
||||||
}
|
}
|
||||||
|
}
|
||||||
z.textIsRaw = false
|
z.textIsRaw = false
|
||||||
|
|
||||||
loop:
|
loop:
|
||||||
|
|
|
@ -4,10 +4,7 @@
|
||||||
|
|
||||||
package tiff
|
package tiff
|
||||||
|
|
||||||
import (
|
import "io"
|
||||||
"io"
|
|
||||||
"os"
|
|
||||||
)
|
|
||||||
|
|
||||||
// buffer buffers an io.Reader to satisfy io.ReaderAt.
|
// buffer buffers an io.Reader to satisfy io.ReaderAt.
|
||||||
type buffer struct {
|
type buffer struct {
|
||||||
|
@ -19,7 +16,7 @@ func (b *buffer) ReadAt(p []byte, off int64) (int, error) {
|
||||||
o := int(off)
|
o := int(off)
|
||||||
end := o + len(p)
|
end := o + len(p)
|
||||||
if int64(end) != off+int64(len(p)) {
|
if int64(end) != off+int64(len(p)) {
|
||||||
return 0, os.EINVAL
|
return 0, io.ErrUnexpectedEOF
|
||||||
}
|
}
|
||||||
|
|
||||||
m := len(b.buf)
|
m := len(b.buf)
|
||||||
|
|
|
@ -8,6 +8,7 @@ import (
|
||||||
"os"
|
"os"
|
||||||
"path/filepath"
|
"path/filepath"
|
||||||
"strconv"
|
"strconv"
|
||||||
|
"time"
|
||||||
)
|
)
|
||||||
|
|
||||||
// Random number state, accessed without lock; racy but harmless.
|
// Random number state, accessed without lock; racy but harmless.
|
||||||
|
@ -17,8 +18,7 @@ import (
|
||||||
var rand uint32
|
var rand uint32
|
||||||
|
|
||||||
func reseed() uint32 {
|
func reseed() uint32 {
|
||||||
sec, nsec, _ := os.Time()
|
return uint32(time.Nanoseconds() + int64(os.Getpid()))
|
||||||
return uint32(sec*1e9 + nsec + int64(os.Getpid()))
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func nextSuffix() string {
|
func nextSuffix() string {
|
||||||
|
|
|
@ -8,6 +8,7 @@
|
||||||
package syslog
|
package syslog
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
"log"
|
"log"
|
||||||
"net"
|
"net"
|
||||||
|
@ -75,7 +76,7 @@ func Dial(network, raddr string, priority Priority, prefix string) (w *Writer, e
|
||||||
// Write sends a log message to the syslog daemon.
|
// Write sends a log message to the syslog daemon.
|
||||||
func (w *Writer) Write(b []byte) (int, error) {
|
func (w *Writer) Write(b []byte) (int, error) {
|
||||||
if w.priority > LOG_DEBUG || w.priority < LOG_EMERG {
|
if w.priority > LOG_DEBUG || w.priority < LOG_EMERG {
|
||||||
return 0, os.EINVAL
|
return 0, errors.New("log/syslog: invalid priority")
|
||||||
}
|
}
|
||||||
return w.conn.writeBytes(w.priority, w.prefix, b)
|
return w.conn.writeBytes(w.priority, w.prefix, b)
|
||||||
}
|
}
|
||||||
|
|
|
@ -176,7 +176,7 @@ func (z *Int) Quo(x, y *Int) *Int {
|
||||||
// If y == 0, a division-by-zero run-time panic occurs.
|
// If y == 0, a division-by-zero run-time panic occurs.
|
||||||
// Rem implements truncated modulus (like Go); see QuoRem for more details.
|
// Rem implements truncated modulus (like Go); see QuoRem for more details.
|
||||||
func (z *Int) Rem(x, y *Int) *Int {
|
func (z *Int) Rem(x, y *Int) *Int {
|
||||||
_, z.abs = nat{}.div(z.abs, x.abs, y.abs)
|
_, z.abs = nat(nil).div(z.abs, x.abs, y.abs)
|
||||||
z.neg = len(z.abs) > 0 && x.neg // 0 has no sign
|
z.neg = len(z.abs) > 0 && x.neg // 0 has no sign
|
||||||
return z
|
return z
|
||||||
}
|
}
|
||||||
|
@ -678,14 +678,14 @@ func (z *Int) Bit(i int) uint {
|
||||||
panic("negative bit index")
|
panic("negative bit index")
|
||||||
}
|
}
|
||||||
if z.neg {
|
if z.neg {
|
||||||
t := nat{}.sub(z.abs, natOne)
|
t := nat(nil).sub(z.abs, natOne)
|
||||||
return t.bit(uint(i)) ^ 1
|
return t.bit(uint(i)) ^ 1
|
||||||
}
|
}
|
||||||
|
|
||||||
return z.abs.bit(uint(i))
|
return z.abs.bit(uint(i))
|
||||||
}
|
}
|
||||||
|
|
||||||
// SetBit sets the i'th bit of z to bit and returns z.
|
// SetBit sets z to x, with x's i'th bit set to b (0 or 1).
|
||||||
// That is, if bit is 1 SetBit sets z = x | (1 << i);
|
// That is, if bit is 1 SetBit sets z = x | (1 << i);
|
||||||
// if bit is 0 it sets z = x &^ (1 << i). If bit is not 0 or 1,
|
// if bit is 0 it sets z = x &^ (1 << i). If bit is not 0 or 1,
|
||||||
// SetBit will panic.
|
// SetBit will panic.
|
||||||
|
@ -710,8 +710,8 @@ func (z *Int) And(x, y *Int) *Int {
|
||||||
if x.neg == y.neg {
|
if x.neg == y.neg {
|
||||||
if x.neg {
|
if x.neg {
|
||||||
// (-x) & (-y) == ^(x-1) & ^(y-1) == ^((x-1) | (y-1)) == -(((x-1) | (y-1)) + 1)
|
// (-x) & (-y) == ^(x-1) & ^(y-1) == ^((x-1) | (y-1)) == -(((x-1) | (y-1)) + 1)
|
||||||
x1 := nat{}.sub(x.abs, natOne)
|
x1 := nat(nil).sub(x.abs, natOne)
|
||||||
y1 := nat{}.sub(y.abs, natOne)
|
y1 := nat(nil).sub(y.abs, natOne)
|
||||||
z.abs = z.abs.add(z.abs.or(x1, y1), natOne)
|
z.abs = z.abs.add(z.abs.or(x1, y1), natOne)
|
||||||
z.neg = true // z cannot be zero if x and y are negative
|
z.neg = true // z cannot be zero if x and y are negative
|
||||||
return z
|
return z
|
||||||
|
@ -729,7 +729,7 @@ func (z *Int) And(x, y *Int) *Int {
|
||||||
}
|
}
|
||||||
|
|
||||||
// x & (-y) == x & ^(y-1) == x &^ (y-1)
|
// x & (-y) == x & ^(y-1) == x &^ (y-1)
|
||||||
y1 := nat{}.sub(y.abs, natOne)
|
y1 := nat(nil).sub(y.abs, natOne)
|
||||||
z.abs = z.abs.andNot(x.abs, y1)
|
z.abs = z.abs.andNot(x.abs, y1)
|
||||||
z.neg = false
|
z.neg = false
|
||||||
return z
|
return z
|
||||||
|
@ -740,8 +740,8 @@ func (z *Int) AndNot(x, y *Int) *Int {
|
||||||
if x.neg == y.neg {
|
if x.neg == y.neg {
|
||||||
if x.neg {
|
if x.neg {
|
||||||
// (-x) &^ (-y) == ^(x-1) &^ ^(y-1) == ^(x-1) & (y-1) == (y-1) &^ (x-1)
|
// (-x) &^ (-y) == ^(x-1) &^ ^(y-1) == ^(x-1) & (y-1) == (y-1) &^ (x-1)
|
||||||
x1 := nat{}.sub(x.abs, natOne)
|
x1 := nat(nil).sub(x.abs, natOne)
|
||||||
y1 := nat{}.sub(y.abs, natOne)
|
y1 := nat(nil).sub(y.abs, natOne)
|
||||||
z.abs = z.abs.andNot(y1, x1)
|
z.abs = z.abs.andNot(y1, x1)
|
||||||
z.neg = false
|
z.neg = false
|
||||||
return z
|
return z
|
||||||
|
@ -755,14 +755,14 @@ func (z *Int) AndNot(x, y *Int) *Int {
|
||||||
|
|
||||||
if x.neg {
|
if x.neg {
|
||||||
// (-x) &^ y == ^(x-1) &^ y == ^(x-1) & ^y == ^((x-1) | y) == -(((x-1) | y) + 1)
|
// (-x) &^ y == ^(x-1) &^ y == ^(x-1) & ^y == ^((x-1) | y) == -(((x-1) | y) + 1)
|
||||||
x1 := nat{}.sub(x.abs, natOne)
|
x1 := nat(nil).sub(x.abs, natOne)
|
||||||
z.abs = z.abs.add(z.abs.or(x1, y.abs), natOne)
|
z.abs = z.abs.add(z.abs.or(x1, y.abs), natOne)
|
||||||
z.neg = true // z cannot be zero if x is negative and y is positive
|
z.neg = true // z cannot be zero if x is negative and y is positive
|
||||||
return z
|
return z
|
||||||
}
|
}
|
||||||
|
|
||||||
// x &^ (-y) == x &^ ^(y-1) == x & (y-1)
|
// x &^ (-y) == x &^ ^(y-1) == x & (y-1)
|
||||||
y1 := nat{}.add(y.abs, natOne)
|
y1 := nat(nil).add(y.abs, natOne)
|
||||||
z.abs = z.abs.and(x.abs, y1)
|
z.abs = z.abs.and(x.abs, y1)
|
||||||
z.neg = false
|
z.neg = false
|
||||||
return z
|
return z
|
||||||
|
@ -773,8 +773,8 @@ func (z *Int) Or(x, y *Int) *Int {
|
||||||
if x.neg == y.neg {
|
if x.neg == y.neg {
|
||||||
if x.neg {
|
if x.neg {
|
||||||
// (-x) | (-y) == ^(x-1) | ^(y-1) == ^((x-1) & (y-1)) == -(((x-1) & (y-1)) + 1)
|
// (-x) | (-y) == ^(x-1) | ^(y-1) == ^((x-1) & (y-1)) == -(((x-1) & (y-1)) + 1)
|
||||||
x1 := nat{}.sub(x.abs, natOne)
|
x1 := nat(nil).sub(x.abs, natOne)
|
||||||
y1 := nat{}.sub(y.abs, natOne)
|
y1 := nat(nil).sub(y.abs, natOne)
|
||||||
z.abs = z.abs.add(z.abs.and(x1, y1), natOne)
|
z.abs = z.abs.add(z.abs.and(x1, y1), natOne)
|
||||||
z.neg = true // z cannot be zero if x and y are negative
|
z.neg = true // z cannot be zero if x and y are negative
|
||||||
return z
|
return z
|
||||||
|
@ -792,7 +792,7 @@ func (z *Int) Or(x, y *Int) *Int {
|
||||||
}
|
}
|
||||||
|
|
||||||
// x | (-y) == x | ^(y-1) == ^((y-1) &^ x) == -(^((y-1) &^ x) + 1)
|
// x | (-y) == x | ^(y-1) == ^((y-1) &^ x) == -(^((y-1) &^ x) + 1)
|
||||||
y1 := nat{}.sub(y.abs, natOne)
|
y1 := nat(nil).sub(y.abs, natOne)
|
||||||
z.abs = z.abs.add(z.abs.andNot(y1, x.abs), natOne)
|
z.abs = z.abs.add(z.abs.andNot(y1, x.abs), natOne)
|
||||||
z.neg = true // z cannot be zero if one of x or y is negative
|
z.neg = true // z cannot be zero if one of x or y is negative
|
||||||
return z
|
return z
|
||||||
|
@ -803,8 +803,8 @@ func (z *Int) Xor(x, y *Int) *Int {
|
||||||
if x.neg == y.neg {
|
if x.neg == y.neg {
|
||||||
if x.neg {
|
if x.neg {
|
||||||
// (-x) ^ (-y) == ^(x-1) ^ ^(y-1) == (x-1) ^ (y-1)
|
// (-x) ^ (-y) == ^(x-1) ^ ^(y-1) == (x-1) ^ (y-1)
|
||||||
x1 := nat{}.sub(x.abs, natOne)
|
x1 := nat(nil).sub(x.abs, natOne)
|
||||||
y1 := nat{}.sub(y.abs, natOne)
|
y1 := nat(nil).sub(y.abs, natOne)
|
||||||
z.abs = z.abs.xor(x1, y1)
|
z.abs = z.abs.xor(x1, y1)
|
||||||
z.neg = false
|
z.neg = false
|
||||||
return z
|
return z
|
||||||
|
@ -822,7 +822,7 @@ func (z *Int) Xor(x, y *Int) *Int {
|
||||||
}
|
}
|
||||||
|
|
||||||
// x ^ (-y) == x ^ ^(y-1) == ^(x ^ (y-1)) == -((x ^ (y-1)) + 1)
|
// x ^ (-y) == x ^ ^(y-1) == ^(x ^ (y-1)) == -((x ^ (y-1)) + 1)
|
||||||
y1 := nat{}.sub(y.abs, natOne)
|
y1 := nat(nil).sub(y.abs, natOne)
|
||||||
z.abs = z.abs.add(z.abs.xor(x.abs, y1), natOne)
|
z.abs = z.abs.add(z.abs.xor(x.abs, y1), natOne)
|
||||||
z.neg = true // z cannot be zero if only one of x or y is negative
|
z.neg = true // z cannot be zero if only one of x or y is negative
|
||||||
return z
|
return z
|
||||||
|
|
|
@ -447,10 +447,10 @@ func (z nat) mulRange(a, b uint64) nat {
|
||||||
case a == b:
|
case a == b:
|
||||||
return z.setUint64(a)
|
return z.setUint64(a)
|
||||||
case a+1 == b:
|
case a+1 == b:
|
||||||
return z.mul(nat{}.setUint64(a), nat{}.setUint64(b))
|
return z.mul(nat(nil).setUint64(a), nat(nil).setUint64(b))
|
||||||
}
|
}
|
||||||
m := (a + b) / 2
|
m := (a + b) / 2
|
||||||
return z.mul(nat{}.mulRange(a, m), nat{}.mulRange(m+1, b))
|
return z.mul(nat(nil).mulRange(a, m), nat(nil).mulRange(m+1, b))
|
||||||
}
|
}
|
||||||
|
|
||||||
// q = (x-r)/y, with 0 <= r < y
|
// q = (x-r)/y, with 0 <= r < y
|
||||||
|
@ -785,7 +785,7 @@ func (x nat) string(charset string) string {
|
||||||
}
|
}
|
||||||
|
|
||||||
// preserve x, create local copy for use in repeated divisions
|
// preserve x, create local copy for use in repeated divisions
|
||||||
q := nat{}.set(x)
|
q := nat(nil).set(x)
|
||||||
var r Word
|
var r Word
|
||||||
|
|
||||||
// convert
|
// convert
|
||||||
|
@ -1191,11 +1191,11 @@ func (n nat) probablyPrime(reps int) bool {
|
||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
|
|
||||||
nm1 := nat{}.sub(n, natOne)
|
nm1 := nat(nil).sub(n, natOne)
|
||||||
// 1<<k * q = nm1;
|
// 1<<k * q = nm1;
|
||||||
q, k := nm1.powersOfTwoDecompose()
|
q, k := nm1.powersOfTwoDecompose()
|
||||||
|
|
||||||
nm3 := nat{}.sub(nm1, natTwo)
|
nm3 := nat(nil).sub(nm1, natTwo)
|
||||||
rand := rand.New(rand.NewSource(int64(n[0])))
|
rand := rand.New(rand.NewSource(int64(n[0])))
|
||||||
|
|
||||||
var x, y, quotient nat
|
var x, y, quotient nat
|
||||||
|
|
|
@ -16,9 +16,9 @@ var cmpTests = []struct {
|
||||||
r int
|
r int
|
||||||
}{
|
}{
|
||||||
{nil, nil, 0},
|
{nil, nil, 0},
|
||||||
{nil, nat{}, 0},
|
{nil, nat(nil), 0},
|
||||||
{nat{}, nil, 0},
|
{nat(nil), nil, 0},
|
||||||
{nat{}, nat{}, 0},
|
{nat(nil), nat(nil), 0},
|
||||||
{nat{0}, nat{0}, 0},
|
{nat{0}, nat{0}, 0},
|
||||||
{nat{0}, nat{1}, -1},
|
{nat{0}, nat{1}, -1},
|
||||||
{nat{1}, nat{0}, 1},
|
{nat{1}, nat{0}, 1},
|
||||||
|
@ -67,7 +67,7 @@ var prodNN = []argNN{
|
||||||
|
|
||||||
func TestSet(t *testing.T) {
|
func TestSet(t *testing.T) {
|
||||||
for _, a := range sumNN {
|
for _, a := range sumNN {
|
||||||
z := nat{}.set(a.z)
|
z := nat(nil).set(a.z)
|
||||||
if z.cmp(a.z) != 0 {
|
if z.cmp(a.z) != 0 {
|
||||||
t.Errorf("got z = %v; want %v", z, a.z)
|
t.Errorf("got z = %v; want %v", z, a.z)
|
||||||
}
|
}
|
||||||
|
@ -129,7 +129,7 @@ var mulRangesN = []struct {
|
||||||
|
|
||||||
func TestMulRangeN(t *testing.T) {
|
func TestMulRangeN(t *testing.T) {
|
||||||
for i, r := range mulRangesN {
|
for i, r := range mulRangesN {
|
||||||
prod := nat{}.mulRange(r.a, r.b).decimalString()
|
prod := nat(nil).mulRange(r.a, r.b).decimalString()
|
||||||
if prod != r.prod {
|
if prod != r.prod {
|
||||||
t.Errorf("#%d: got %s; want %s", i, prod, r.prod)
|
t.Errorf("#%d: got %s; want %s", i, prod, r.prod)
|
||||||
}
|
}
|
||||||
|
@ -175,7 +175,7 @@ func toString(x nat, charset string) string {
|
||||||
s := make([]byte, i)
|
s := make([]byte, i)
|
||||||
|
|
||||||
// don't destroy x
|
// don't destroy x
|
||||||
q := nat{}.set(x)
|
q := nat(nil).set(x)
|
||||||
|
|
||||||
// convert
|
// convert
|
||||||
for len(q) > 0 {
|
for len(q) > 0 {
|
||||||
|
@ -212,7 +212,7 @@ func TestString(t *testing.T) {
|
||||||
t.Errorf("string%+v\n\tgot s = %s; want %s", a, s, a.s)
|
t.Errorf("string%+v\n\tgot s = %s; want %s", a, s, a.s)
|
||||||
}
|
}
|
||||||
|
|
||||||
x, b, err := nat{}.scan(strings.NewReader(a.s), len(a.c))
|
x, b, err := nat(nil).scan(strings.NewReader(a.s), len(a.c))
|
||||||
if x.cmp(a.x) != 0 {
|
if x.cmp(a.x) != 0 {
|
||||||
t.Errorf("scan%+v\n\tgot z = %v; want %v", a, x, a.x)
|
t.Errorf("scan%+v\n\tgot z = %v; want %v", a, x, a.x)
|
||||||
}
|
}
|
||||||
|
@ -271,7 +271,7 @@ var natScanTests = []struct {
|
||||||
func TestScanBase(t *testing.T) {
|
func TestScanBase(t *testing.T) {
|
||||||
for _, a := range natScanTests {
|
for _, a := range natScanTests {
|
||||||
r := strings.NewReader(a.s)
|
r := strings.NewReader(a.s)
|
||||||
x, b, err := nat{}.scan(r, a.base)
|
x, b, err := nat(nil).scan(r, a.base)
|
||||||
if err == nil && !a.ok {
|
if err == nil && !a.ok {
|
||||||
t.Errorf("scan%+v\n\texpected error", a)
|
t.Errorf("scan%+v\n\texpected error", a)
|
||||||
}
|
}
|
||||||
|
@ -651,17 +651,17 @@ var expNNTests = []struct {
|
||||||
|
|
||||||
func TestExpNN(t *testing.T) {
|
func TestExpNN(t *testing.T) {
|
||||||
for i, test := range expNNTests {
|
for i, test := range expNNTests {
|
||||||
x, _, _ := nat{}.scan(strings.NewReader(test.x), 0)
|
x, _, _ := nat(nil).scan(strings.NewReader(test.x), 0)
|
||||||
y, _, _ := nat{}.scan(strings.NewReader(test.y), 0)
|
y, _, _ := nat(nil).scan(strings.NewReader(test.y), 0)
|
||||||
out, _, _ := nat{}.scan(strings.NewReader(test.out), 0)
|
out, _, _ := nat(nil).scan(strings.NewReader(test.out), 0)
|
||||||
|
|
||||||
var m nat
|
var m nat
|
||||||
|
|
||||||
if len(test.m) > 0 {
|
if len(test.m) > 0 {
|
||||||
m, _, _ = nat{}.scan(strings.NewReader(test.m), 0)
|
m, _, _ = nat(nil).scan(strings.NewReader(test.m), 0)
|
||||||
}
|
}
|
||||||
|
|
||||||
z := nat{}.expNN(x, y, m)
|
z := nat(nil).expNN(x, y, m)
|
||||||
if z.cmp(out) != 0 {
|
if z.cmp(out) != 0 {
|
||||||
t.Errorf("#%d got %v want %v", i, z, out)
|
t.Errorf("#%d got %v want %v", i, z, out)
|
||||||
}
|
}
|
||||||
|
|
|
@ -33,7 +33,7 @@ func (z *Rat) SetFrac(a, b *Int) *Rat {
|
||||||
panic("division by zero")
|
panic("division by zero")
|
||||||
}
|
}
|
||||||
if &z.a == b || alias(z.a.abs, babs) {
|
if &z.a == b || alias(z.a.abs, babs) {
|
||||||
babs = nat{}.set(babs) // make a copy
|
babs = nat(nil).set(babs) // make a copy
|
||||||
}
|
}
|
||||||
z.a.abs = z.a.abs.set(a.abs)
|
z.a.abs = z.a.abs.set(a.abs)
|
||||||
z.b = z.b.set(babs)
|
z.b = z.b.set(babs)
|
||||||
|
@ -315,7 +315,7 @@ func (z *Rat) SetString(s string) (*Rat, bool) {
|
||||||
if _, ok := z.a.SetString(s, 10); !ok {
|
if _, ok := z.a.SetString(s, 10); !ok {
|
||||||
return nil, false
|
return nil, false
|
||||||
}
|
}
|
||||||
powTen := nat{}.expNN(natTen, exp.abs, nil)
|
powTen := nat(nil).expNN(natTen, exp.abs, nil)
|
||||||
if exp.neg {
|
if exp.neg {
|
||||||
z.b = powTen
|
z.b = powTen
|
||||||
z.norm()
|
z.norm()
|
||||||
|
@ -357,23 +357,23 @@ func (z *Rat) FloatString(prec int) string {
|
||||||
}
|
}
|
||||||
// z.b != 0
|
// z.b != 0
|
||||||
|
|
||||||
q, r := nat{}.div(nat{}, z.a.abs, z.b)
|
q, r := nat(nil).div(nat(nil), z.a.abs, z.b)
|
||||||
|
|
||||||
p := natOne
|
p := natOne
|
||||||
if prec > 0 {
|
if prec > 0 {
|
||||||
p = nat{}.expNN(natTen, nat{}.setUint64(uint64(prec)), nil)
|
p = nat(nil).expNN(natTen, nat(nil).setUint64(uint64(prec)), nil)
|
||||||
}
|
}
|
||||||
|
|
||||||
r = r.mul(r, p)
|
r = r.mul(r, p)
|
||||||
r, r2 := r.div(nat{}, r, z.b)
|
r, r2 := r.div(nat(nil), r, z.b)
|
||||||
|
|
||||||
// see if we need to round up
|
// see if we need to round up
|
||||||
r2 = r2.add(r2, r2)
|
r2 = r2.add(r2, r2)
|
||||||
if z.b.cmp(r2) <= 0 {
|
if z.b.cmp(r2) <= 0 {
|
||||||
r = r.add(r, natOne)
|
r = r.add(r, natOne)
|
||||||
if r.cmp(p) >= 0 {
|
if r.cmp(p) >= 0 {
|
||||||
q = nat{}.add(q, natOne)
|
q = nat(nil).add(q, natOne)
|
||||||
r = nat{}.sub(r, p)
|
r = nat(nil).sub(r, p)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -63,7 +63,7 @@ package math
|
||||||
// Stephen L. Moshier
|
// Stephen L. Moshier
|
||||||
// moshier@na-net.ornl.gov
|
// moshier@na-net.ornl.gov
|
||||||
|
|
||||||
var _P = [...]float64{
|
var _gamP = [...]float64{
|
||||||
1.60119522476751861407e-04,
|
1.60119522476751861407e-04,
|
||||||
1.19135147006586384913e-03,
|
1.19135147006586384913e-03,
|
||||||
1.04213797561761569935e-02,
|
1.04213797561761569935e-02,
|
||||||
|
@ -72,7 +72,7 @@ var _P = [...]float64{
|
||||||
4.94214826801497100753e-01,
|
4.94214826801497100753e-01,
|
||||||
9.99999999999999996796e-01,
|
9.99999999999999996796e-01,
|
||||||
}
|
}
|
||||||
var _Q = [...]float64{
|
var _gamQ = [...]float64{
|
||||||
-2.31581873324120129819e-05,
|
-2.31581873324120129819e-05,
|
||||||
5.39605580493303397842e-04,
|
5.39605580493303397842e-04,
|
||||||
-4.45641913851797240494e-03,
|
-4.45641913851797240494e-03,
|
||||||
|
@ -82,7 +82,7 @@ var _Q = [...]float64{
|
||||||
7.14304917030273074085e-02,
|
7.14304917030273074085e-02,
|
||||||
1.00000000000000000320e+00,
|
1.00000000000000000320e+00,
|
||||||
}
|
}
|
||||||
var _S = [...]float64{
|
var _gamS = [...]float64{
|
||||||
7.87311395793093628397e-04,
|
7.87311395793093628397e-04,
|
||||||
-2.29549961613378126380e-04,
|
-2.29549961613378126380e-04,
|
||||||
-2.68132617805781232825e-03,
|
-2.68132617805781232825e-03,
|
||||||
|
@ -98,7 +98,7 @@ func stirling(x float64) float64 {
|
||||||
MaxStirling = 143.01608
|
MaxStirling = 143.01608
|
||||||
)
|
)
|
||||||
w := 1 / x
|
w := 1 / x
|
||||||
w = 1 + w*((((_S[0]*w+_S[1])*w+_S[2])*w+_S[3])*w+_S[4])
|
w = 1 + w*((((_gamS[0]*w+_gamS[1])*w+_gamS[2])*w+_gamS[3])*w+_gamS[4])
|
||||||
y := Exp(x)
|
y := Exp(x)
|
||||||
if x > MaxStirling { // avoid Pow() overflow
|
if x > MaxStirling { // avoid Pow() overflow
|
||||||
v := Pow(x, 0.5*x-0.25)
|
v := Pow(x, 0.5*x-0.25)
|
||||||
|
@ -176,8 +176,8 @@ func Gamma(x float64) float64 {
|
||||||
}
|
}
|
||||||
|
|
||||||
x = x - 2
|
x = x - 2
|
||||||
p = (((((x*_P[0]+_P[1])*x+_P[2])*x+_P[3])*x+_P[4])*x+_P[5])*x + _P[6]
|
p = (((((x*_gamP[0]+_gamP[1])*x+_gamP[2])*x+_gamP[3])*x+_gamP[4])*x+_gamP[5])*x + _gamP[6]
|
||||||
q = ((((((x*_Q[0]+_Q[1])*x+_Q[2])*x+_Q[3])*x+_Q[4])*x+_Q[5])*x+_Q[6])*x + _Q[7]
|
q = ((((((x*_gamQ[0]+_gamQ[1])*x+_gamQ[2])*x+_gamQ[3])*x+_gamQ[4])*x+_gamQ[5])*x+_gamQ[6])*x + _gamQ[7]
|
||||||
return z * p / q
|
return z * p / q
|
||||||
|
|
||||||
small:
|
small:
|
||||||
|
|
|
@ -88,6 +88,81 @@ package math
|
||||||
//
|
//
|
||||||
//
|
//
|
||||||
|
|
||||||
|
var _lgamA = [...]float64{
|
||||||
|
7.72156649015328655494e-02, // 0x3FB3C467E37DB0C8
|
||||||
|
3.22467033424113591611e-01, // 0x3FD4A34CC4A60FAD
|
||||||
|
6.73523010531292681824e-02, // 0x3FB13E001A5562A7
|
||||||
|
2.05808084325167332806e-02, // 0x3F951322AC92547B
|
||||||
|
7.38555086081402883957e-03, // 0x3F7E404FB68FEFE8
|
||||||
|
2.89051383673415629091e-03, // 0x3F67ADD8CCB7926B
|
||||||
|
1.19270763183362067845e-03, // 0x3F538A94116F3F5D
|
||||||
|
5.10069792153511336608e-04, // 0x3F40B6C689B99C00
|
||||||
|
2.20862790713908385557e-04, // 0x3F2CF2ECED10E54D
|
||||||
|
1.08011567247583939954e-04, // 0x3F1C5088987DFB07
|
||||||
|
2.52144565451257326939e-05, // 0x3EFA7074428CFA52
|
||||||
|
4.48640949618915160150e-05, // 0x3F07858E90A45837
|
||||||
|
}
|
||||||
|
var _lgamR = [...]float64{
|
||||||
|
1.0, // placeholder
|
||||||
|
1.39200533467621045958e+00, // 0x3FF645A762C4AB74
|
||||||
|
7.21935547567138069525e-01, // 0x3FE71A1893D3DCDC
|
||||||
|
1.71933865632803078993e-01, // 0x3FC601EDCCFBDF27
|
||||||
|
1.86459191715652901344e-02, // 0x3F9317EA742ED475
|
||||||
|
7.77942496381893596434e-04, // 0x3F497DDACA41A95B
|
||||||
|
7.32668430744625636189e-06, // 0x3EDEBAF7A5B38140
|
||||||
|
}
|
||||||
|
var _lgamS = [...]float64{
|
||||||
|
-7.72156649015328655494e-02, // 0xBFB3C467E37DB0C8
|
||||||
|
2.14982415960608852501e-01, // 0x3FCB848B36E20878
|
||||||
|
3.25778796408930981787e-01, // 0x3FD4D98F4F139F59
|
||||||
|
1.46350472652464452805e-01, // 0x3FC2BB9CBEE5F2F7
|
||||||
|
2.66422703033638609560e-02, // 0x3F9B481C7E939961
|
||||||
|
1.84028451407337715652e-03, // 0x3F5E26B67368F239
|
||||||
|
3.19475326584100867617e-05, // 0x3F00BFECDD17E945
|
||||||
|
}
|
||||||
|
var _lgamT = [...]float64{
|
||||||
|
4.83836122723810047042e-01, // 0x3FDEF72BC8EE38A2
|
||||||
|
-1.47587722994593911752e-01, // 0xBFC2E4278DC6C509
|
||||||
|
6.46249402391333854778e-02, // 0x3FB08B4294D5419B
|
||||||
|
-3.27885410759859649565e-02, // 0xBFA0C9A8DF35B713
|
||||||
|
1.79706750811820387126e-02, // 0x3F9266E7970AF9EC
|
||||||
|
-1.03142241298341437450e-02, // 0xBF851F9FBA91EC6A
|
||||||
|
6.10053870246291332635e-03, // 0x3F78FCE0E370E344
|
||||||
|
-3.68452016781138256760e-03, // 0xBF6E2EFFB3E914D7
|
||||||
|
2.25964780900612472250e-03, // 0x3F6282D32E15C915
|
||||||
|
-1.40346469989232843813e-03, // 0xBF56FE8EBF2D1AF1
|
||||||
|
8.81081882437654011382e-04, // 0x3F4CDF0CEF61A8E9
|
||||||
|
-5.38595305356740546715e-04, // 0xBF41A6109C73E0EC
|
||||||
|
3.15632070903625950361e-04, // 0x3F34AF6D6C0EBBF7
|
||||||
|
-3.12754168375120860518e-04, // 0xBF347F24ECC38C38
|
||||||
|
3.35529192635519073543e-04, // 0x3F35FD3EE8C2D3F4
|
||||||
|
}
|
||||||
|
var _lgamU = [...]float64{
|
||||||
|
-7.72156649015328655494e-02, // 0xBFB3C467E37DB0C8
|
||||||
|
6.32827064025093366517e-01, // 0x3FE4401E8B005DFF
|
||||||
|
1.45492250137234768737e+00, // 0x3FF7475CD119BD6F
|
||||||
|
9.77717527963372745603e-01, // 0x3FEF497644EA8450
|
||||||
|
2.28963728064692451092e-01, // 0x3FCD4EAEF6010924
|
||||||
|
1.33810918536787660377e-02, // 0x3F8B678BBF2BAB09
|
||||||
|
}
|
||||||
|
var _lgamV = [...]float64{
|
||||||
|
1.0,
|
||||||
|
2.45597793713041134822e+00, // 0x4003A5D7C2BD619C
|
||||||
|
2.12848976379893395361e+00, // 0x40010725A42B18F5
|
||||||
|
7.69285150456672783825e-01, // 0x3FE89DFBE45050AF
|
||||||
|
1.04222645593369134254e-01, // 0x3FBAAE55D6537C88
|
||||||
|
3.21709242282423911810e-03, // 0x3F6A5ABB57D0CF61
|
||||||
|
}
|
||||||
|
var _lgamW = [...]float64{
|
||||||
|
4.18938533204672725052e-01, // 0x3FDACFE390C97D69
|
||||||
|
8.33333333333329678849e-02, // 0x3FB555555555553B
|
||||||
|
-2.77777777728775536470e-03, // 0xBF66C16C16B02E5C
|
||||||
|
7.93650558643019558500e-04, // 0x3F4A019F98CF38B6
|
||||||
|
-5.95187557450339963135e-04, // 0xBF4380CB8C0FE741
|
||||||
|
8.36339918996282139126e-04, // 0x3F4B67BA4CDAD5D1
|
||||||
|
-1.63092934096575273989e-03, // 0xBF5AB89D0B9E43E4
|
||||||
|
}
|
||||||
|
|
||||||
// Lgamma returns the natural logarithm and sign (-1 or +1) of Gamma(x).
|
// Lgamma returns the natural logarithm and sign (-1 or +1) of Gamma(x).
|
||||||
//
|
//
|
||||||
// Special cases are:
|
// Special cases are:
|
||||||
|
@ -103,68 +178,10 @@ func Lgamma(x float64) (lgamma float64, sign int) {
|
||||||
Two53 = 1 << 53 // 0x4340000000000000 ~9.0072e+15
|
Two53 = 1 << 53 // 0x4340000000000000 ~9.0072e+15
|
||||||
Two58 = 1 << 58 // 0x4390000000000000 ~2.8823e+17
|
Two58 = 1 << 58 // 0x4390000000000000 ~2.8823e+17
|
||||||
Tiny = 1.0 / (1 << 70) // 0x3b90000000000000 ~8.47033e-22
|
Tiny = 1.0 / (1 << 70) // 0x3b90000000000000 ~8.47033e-22
|
||||||
A0 = 7.72156649015328655494e-02 // 0x3FB3C467E37DB0C8
|
|
||||||
A1 = 3.22467033424113591611e-01 // 0x3FD4A34CC4A60FAD
|
|
||||||
A2 = 6.73523010531292681824e-02 // 0x3FB13E001A5562A7
|
|
||||||
A3 = 2.05808084325167332806e-02 // 0x3F951322AC92547B
|
|
||||||
A4 = 7.38555086081402883957e-03 // 0x3F7E404FB68FEFE8
|
|
||||||
A5 = 2.89051383673415629091e-03 // 0x3F67ADD8CCB7926B
|
|
||||||
A6 = 1.19270763183362067845e-03 // 0x3F538A94116F3F5D
|
|
||||||
A7 = 5.10069792153511336608e-04 // 0x3F40B6C689B99C00
|
|
||||||
A8 = 2.20862790713908385557e-04 // 0x3F2CF2ECED10E54D
|
|
||||||
A9 = 1.08011567247583939954e-04 // 0x3F1C5088987DFB07
|
|
||||||
A10 = 2.52144565451257326939e-05 // 0x3EFA7074428CFA52
|
|
||||||
A11 = 4.48640949618915160150e-05 // 0x3F07858E90A45837
|
|
||||||
Tc = 1.46163214496836224576e+00 // 0x3FF762D86356BE3F
|
Tc = 1.46163214496836224576e+00 // 0x3FF762D86356BE3F
|
||||||
Tf = -1.21486290535849611461e-01 // 0xBFBF19B9BCC38A42
|
Tf = -1.21486290535849611461e-01 // 0xBFBF19B9BCC38A42
|
||||||
// Tt = -(tail of Tf)
|
// Tt = -(tail of Tf)
|
||||||
Tt = -3.63867699703950536541e-18 // 0xBC50C7CAA48A971F
|
Tt = -3.63867699703950536541e-18 // 0xBC50C7CAA48A971F
|
||||||
T0 = 4.83836122723810047042e-01 // 0x3FDEF72BC8EE38A2
|
|
||||||
T1 = -1.47587722994593911752e-01 // 0xBFC2E4278DC6C509
|
|
||||||
T2 = 6.46249402391333854778e-02 // 0x3FB08B4294D5419B
|
|
||||||
T3 = -3.27885410759859649565e-02 // 0xBFA0C9A8DF35B713
|
|
||||||
T4 = 1.79706750811820387126e-02 // 0x3F9266E7970AF9EC
|
|
||||||
T5 = -1.03142241298341437450e-02 // 0xBF851F9FBA91EC6A
|
|
||||||
T6 = 6.10053870246291332635e-03 // 0x3F78FCE0E370E344
|
|
||||||
T7 = -3.68452016781138256760e-03 // 0xBF6E2EFFB3E914D7
|
|
||||||
T8 = 2.25964780900612472250e-03 // 0x3F6282D32E15C915
|
|
||||||
T9 = -1.40346469989232843813e-03 // 0xBF56FE8EBF2D1AF1
|
|
||||||
T10 = 8.81081882437654011382e-04 // 0x3F4CDF0CEF61A8E9
|
|
||||||
T11 = -5.38595305356740546715e-04 // 0xBF41A6109C73E0EC
|
|
||||||
T12 = 3.15632070903625950361e-04 // 0x3F34AF6D6C0EBBF7
|
|
||||||
T13 = -3.12754168375120860518e-04 // 0xBF347F24ECC38C38
|
|
||||||
T14 = 3.35529192635519073543e-04 // 0x3F35FD3EE8C2D3F4
|
|
||||||
U0 = -7.72156649015328655494e-02 // 0xBFB3C467E37DB0C8
|
|
||||||
U1 = 6.32827064025093366517e-01 // 0x3FE4401E8B005DFF
|
|
||||||
U2 = 1.45492250137234768737e+00 // 0x3FF7475CD119BD6F
|
|
||||||
U3 = 9.77717527963372745603e-01 // 0x3FEF497644EA8450
|
|
||||||
U4 = 2.28963728064692451092e-01 // 0x3FCD4EAEF6010924
|
|
||||||
U5 = 1.33810918536787660377e-02 // 0x3F8B678BBF2BAB09
|
|
||||||
V1 = 2.45597793713041134822e+00 // 0x4003A5D7C2BD619C
|
|
||||||
V2 = 2.12848976379893395361e+00 // 0x40010725A42B18F5
|
|
||||||
V3 = 7.69285150456672783825e-01 // 0x3FE89DFBE45050AF
|
|
||||||
V4 = 1.04222645593369134254e-01 // 0x3FBAAE55D6537C88
|
|
||||||
V5 = 3.21709242282423911810e-03 // 0x3F6A5ABB57D0CF61
|
|
||||||
S0 = -7.72156649015328655494e-02 // 0xBFB3C467E37DB0C8
|
|
||||||
S1 = 2.14982415960608852501e-01 // 0x3FCB848B36E20878
|
|
||||||
S2 = 3.25778796408930981787e-01 // 0x3FD4D98F4F139F59
|
|
||||||
S3 = 1.46350472652464452805e-01 // 0x3FC2BB9CBEE5F2F7
|
|
||||||
S4 = 2.66422703033638609560e-02 // 0x3F9B481C7E939961
|
|
||||||
S5 = 1.84028451407337715652e-03 // 0x3F5E26B67368F239
|
|
||||||
S6 = 3.19475326584100867617e-05 // 0x3F00BFECDD17E945
|
|
||||||
R1 = 1.39200533467621045958e+00 // 0x3FF645A762C4AB74
|
|
||||||
R2 = 7.21935547567138069525e-01 // 0x3FE71A1893D3DCDC
|
|
||||||
R3 = 1.71933865632803078993e-01 // 0x3FC601EDCCFBDF27
|
|
||||||
R4 = 1.86459191715652901344e-02 // 0x3F9317EA742ED475
|
|
||||||
R5 = 7.77942496381893596434e-04 // 0x3F497DDACA41A95B
|
|
||||||
R6 = 7.32668430744625636189e-06 // 0x3EDEBAF7A5B38140
|
|
||||||
W0 = 4.18938533204672725052e-01 // 0x3FDACFE390C97D69
|
|
||||||
W1 = 8.33333333333329678849e-02 // 0x3FB555555555553B
|
|
||||||
W2 = -2.77777777728775536470e-03 // 0xBF66C16C16B02E5C
|
|
||||||
W3 = 7.93650558643019558500e-04 // 0x3F4A019F98CF38B6
|
|
||||||
W4 = -5.95187557450339963135e-04 // 0xBF4380CB8C0FE741
|
|
||||||
W5 = 8.36339918996282139126e-04 // 0x3F4B67BA4CDAD5D1
|
|
||||||
W6 = -1.63092934096575273989e-03 // 0xBF5AB89D0B9E43E4
|
|
||||||
)
|
)
|
||||||
// TODO(rsc): Remove manual inlining of IsNaN, IsInf
|
// TODO(rsc): Remove manual inlining of IsNaN, IsInf
|
||||||
// when compiler does it for us
|
// when compiler does it for us
|
||||||
|
@ -249,28 +266,28 @@ func Lgamma(x float64) (lgamma float64, sign int) {
|
||||||
switch i {
|
switch i {
|
||||||
case 0:
|
case 0:
|
||||||
z := y * y
|
z := y * y
|
||||||
p1 := A0 + z*(A2+z*(A4+z*(A6+z*(A8+z*A10))))
|
p1 := _lgamA[0] + z*(_lgamA[2]+z*(_lgamA[4]+z*(_lgamA[6]+z*(_lgamA[8]+z*_lgamA[10]))))
|
||||||
p2 := z * (A1 + z*(A3+z*(A5+z*(A7+z*(A9+z*A11)))))
|
p2 := z * (_lgamA[1] + z*(+_lgamA[3]+z*(_lgamA[5]+z*(_lgamA[7]+z*(_lgamA[9]+z*_lgamA[11])))))
|
||||||
p := y*p1 + p2
|
p := y*p1 + p2
|
||||||
lgamma += (p - 0.5*y)
|
lgamma += (p - 0.5*y)
|
||||||
case 1:
|
case 1:
|
||||||
z := y * y
|
z := y * y
|
||||||
w := z * y
|
w := z * y
|
||||||
p1 := T0 + w*(T3+w*(T6+w*(T9+w*T12))) // parallel comp
|
p1 := _lgamT[0] + w*(_lgamT[3]+w*(_lgamT[6]+w*(_lgamT[9]+w*_lgamT[12]))) // parallel comp
|
||||||
p2 := T1 + w*(T4+w*(T7+w*(T10+w*T13)))
|
p2 := _lgamT[1] + w*(_lgamT[4]+w*(_lgamT[7]+w*(_lgamT[10]+w*_lgamT[13])))
|
||||||
p3 := T2 + w*(T5+w*(T8+w*(T11+w*T14)))
|
p3 := _lgamT[2] + w*(_lgamT[5]+w*(_lgamT[8]+w*(_lgamT[11]+w*_lgamT[14])))
|
||||||
p := z*p1 - (Tt - w*(p2+y*p3))
|
p := z*p1 - (Tt - w*(p2+y*p3))
|
||||||
lgamma += (Tf + p)
|
lgamma += (Tf + p)
|
||||||
case 2:
|
case 2:
|
||||||
p1 := y * (U0 + y*(U1+y*(U2+y*(U3+y*(U4+y*U5)))))
|
p1 := y * (_lgamU[0] + y*(_lgamU[1]+y*(_lgamU[2]+y*(_lgamU[3]+y*(_lgamU[4]+y*_lgamU[5])))))
|
||||||
p2 := 1 + y*(V1+y*(V2+y*(V3+y*(V4+y*V5))))
|
p2 := 1 + y*(_lgamV[1]+y*(_lgamV[2]+y*(_lgamV[3]+y*(_lgamV[4]+y*_lgamV[5]))))
|
||||||
lgamma += (-0.5*y + p1/p2)
|
lgamma += (-0.5*y + p1/p2)
|
||||||
}
|
}
|
||||||
case x < 8: // 2 <= x < 8
|
case x < 8: // 2 <= x < 8
|
||||||
i := int(x)
|
i := int(x)
|
||||||
y := x - float64(i)
|
y := x - float64(i)
|
||||||
p := y * (S0 + y*(S1+y*(S2+y*(S3+y*(S4+y*(S5+y*S6))))))
|
p := y * (_lgamS[0] + y*(_lgamS[1]+y*(_lgamS[2]+y*(_lgamS[3]+y*(_lgamS[4]+y*(_lgamS[5]+y*_lgamS[6]))))))
|
||||||
q := 1 + y*(R1+y*(R2+y*(R3+y*(R4+y*(R5+y*R6)))))
|
q := 1 + y*(_lgamR[1]+y*(_lgamR[2]+y*(_lgamR[3]+y*(_lgamR[4]+y*(_lgamR[5]+y*_lgamR[6])))))
|
||||||
lgamma = 0.5*y + p/q
|
lgamma = 0.5*y + p/q
|
||||||
z := 1.0 // Lgamma(1+s) = Log(s) + Lgamma(s)
|
z := 1.0 // Lgamma(1+s) = Log(s) + Lgamma(s)
|
||||||
switch i {
|
switch i {
|
||||||
|
@ -294,7 +311,7 @@ func Lgamma(x float64) (lgamma float64, sign int) {
|
||||||
t := Log(x)
|
t := Log(x)
|
||||||
z := 1 / x
|
z := 1 / x
|
||||||
y := z * z
|
y := z * z
|
||||||
w := W0 + z*(W1+y*(W2+y*(W3+y*(W4+y*(W5+y*W6)))))
|
w := _lgamW[0] + z*(_lgamW[1]+y*(_lgamW[2]+y*(_lgamW[3]+y*(_lgamW[4]+y*(_lgamW[5]+y*_lgamW[6])))))
|
||||||
lgamma = (x-0.5)*(t-1) + w
|
lgamma = (x-0.5)*(t-1) + w
|
||||||
default: // 2**58 <= x <= Inf
|
default: // 2**58 <= x <= Inf
|
||||||
lgamma = x * (Log(x) - 1)
|
lgamma = x * (Log(x) - 1)
|
||||||
|
|
|
@ -160,7 +160,7 @@ type sliceReaderAt []byte
|
||||||
|
|
||||||
func (r sliceReaderAt) ReadAt(b []byte, off int64) (int, error) {
|
func (r sliceReaderAt) ReadAt(b []byte, off int64) (int, error) {
|
||||||
if int(off) >= len(r) || off < 0 {
|
if int(off) >= len(r) || off < 0 {
|
||||||
return 0, os.EINVAL
|
return 0, io.ErrUnexpectedEOF
|
||||||
}
|
}
|
||||||
n := copy(b, r[int(off):])
|
n := copy(b, r[int(off):])
|
||||||
return n, nil
|
return n, nil
|
||||||
|
|
|
@ -6,19 +6,11 @@
|
||||||
package mime
|
package mime
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"bufio"
|
|
||||||
"fmt"
|
"fmt"
|
||||||
"os"
|
|
||||||
"strings"
|
"strings"
|
||||||
"sync"
|
"sync"
|
||||||
)
|
)
|
||||||
|
|
||||||
var typeFiles = []string{
|
|
||||||
"/etc/mime.types",
|
|
||||||
"/etc/apache2/mime.types",
|
|
||||||
"/etc/apache/mime.types",
|
|
||||||
}
|
|
||||||
|
|
||||||
var mimeTypes = map[string]string{
|
var mimeTypes = map[string]string{
|
||||||
".css": "text/css; charset=utf-8",
|
".css": "text/css; charset=utf-8",
|
||||||
".gif": "image/gif",
|
".gif": "image/gif",
|
||||||
|
@ -33,46 +25,13 @@ var mimeTypes = map[string]string{
|
||||||
|
|
||||||
var mimeLock sync.RWMutex
|
var mimeLock sync.RWMutex
|
||||||
|
|
||||||
func loadMimeFile(filename string) {
|
|
||||||
f, err := os.Open(filename)
|
|
||||||
if err != nil {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
reader := bufio.NewReader(f)
|
|
||||||
for {
|
|
||||||
line, err := reader.ReadString('\n')
|
|
||||||
if err != nil {
|
|
||||||
f.Close()
|
|
||||||
return
|
|
||||||
}
|
|
||||||
fields := strings.Fields(line)
|
|
||||||
if len(fields) <= 1 || fields[0][0] == '#' {
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
mimeType := fields[0]
|
|
||||||
for _, ext := range fields[1:] {
|
|
||||||
if ext[0] == '#' {
|
|
||||||
break
|
|
||||||
}
|
|
||||||
setExtensionType("."+ext, mimeType)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func initMime() {
|
|
||||||
for _, filename := range typeFiles {
|
|
||||||
loadMimeFile(filename)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
var once sync.Once
|
var once sync.Once
|
||||||
|
|
||||||
// TypeByExtension returns the MIME type associated with the file extension ext.
|
// TypeByExtension returns the MIME type associated with the file extension ext.
|
||||||
// The extension ext should begin with a leading dot, as in ".html".
|
// The extension ext should begin with a leading dot, as in ".html".
|
||||||
// When ext has no associated type, TypeByExtension returns "".
|
// When ext has no associated type, TypeByExtension returns "".
|
||||||
//
|
//
|
||||||
// The built-in table is small but is is augmented by the local
|
// The built-in table is small but on unix it is augmented by the local
|
||||||
// system's mime.types file(s) if available under one or more of these
|
// system's mime.types file(s) if available under one or more of these
|
||||||
// names:
|
// names:
|
||||||
//
|
//
|
||||||
|
@ -80,6 +39,8 @@ var once sync.Once
|
||||||
// /etc/apache2/mime.types
|
// /etc/apache2/mime.types
|
||||||
// /etc/apache/mime.types
|
// /etc/apache/mime.types
|
||||||
//
|
//
|
||||||
|
// Windows system mime types are extracted from registry.
|
||||||
|
//
|
||||||
// Text types have the charset parameter set to "utf-8" by default.
|
// Text types have the charset parameter set to "utf-8" by default.
|
||||||
func TypeByExtension(ext string) string {
|
func TypeByExtension(ext string) string {
|
||||||
once.Do(initMime)
|
once.Do(initMime)
|
||||||
|
|
|
@ -6,15 +6,9 @@ package mime
|
||||||
|
|
||||||
import "testing"
|
import "testing"
|
||||||
|
|
||||||
var typeTests = map[string]string{
|
var typeTests = initMimeForTests()
|
||||||
".t1": "application/test",
|
|
||||||
".t2": "text/test; charset=utf-8",
|
|
||||||
".png": "image/png",
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestTypeByExtension(t *testing.T) {
|
func TestTypeByExtension(t *testing.T) {
|
||||||
typeFiles = []string{"test.types"}
|
|
||||||
|
|
||||||
for ext, want := range typeTests {
|
for ext, want := range typeTests {
|
||||||
val := TypeByExtension(ext)
|
val := TypeByExtension(ext)
|
||||||
if val != want {
|
if val != want {
|
||||||
|
|
|
@ -0,0 +1,59 @@
|
||||||
|
// Copyright 2010 The Go Authors. All rights reserved.
|
||||||
|
// Use of this source code is governed by a BSD-style
|
||||||
|
// license that can be found in the LICENSE file.
|
||||||
|
|
||||||
|
package mime
|
||||||
|
|
||||||
|
import (
|
||||||
|
"bufio"
|
||||||
|
"os"
|
||||||
|
"strings"
|
||||||
|
)
|
||||||
|
|
||||||
|
var typeFiles = []string{
|
||||||
|
"/etc/mime.types",
|
||||||
|
"/etc/apache2/mime.types",
|
||||||
|
"/etc/apache/mime.types",
|
||||||
|
}
|
||||||
|
|
||||||
|
func loadMimeFile(filename string) {
|
||||||
|
f, err := os.Open(filename)
|
||||||
|
if err != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
reader := bufio.NewReader(f)
|
||||||
|
for {
|
||||||
|
line, err := reader.ReadString('\n')
|
||||||
|
if err != nil {
|
||||||
|
f.Close()
|
||||||
|
return
|
||||||
|
}
|
||||||
|
fields := strings.Fields(line)
|
||||||
|
if len(fields) <= 1 || fields[0][0] == '#' {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
mimeType := fields[0]
|
||||||
|
for _, ext := range fields[1:] {
|
||||||
|
if ext[0] == '#' {
|
||||||
|
break
|
||||||
|
}
|
||||||
|
setExtensionType("."+ext, mimeType)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func initMime() {
|
||||||
|
for _, filename := range typeFiles {
|
||||||
|
loadMimeFile(filename)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func initMimeForTests() map[string]string {
|
||||||
|
typeFiles = []string{"test.types"}
|
||||||
|
return map[string]string{
|
||||||
|
".t1": "application/test",
|
||||||
|
".t2": "text/test; charset=utf-8",
|
||||||
|
".png": "image/png",
|
||||||
|
}
|
||||||
|
}
|
|
@ -0,0 +1,61 @@
|
||||||
|
// Copyright 2010 The Go Authors. All rights reserved.
|
||||||
|
// Use of this source code is governed by a BSD-style
|
||||||
|
// license that can be found in the LICENSE file.
|
||||||
|
|
||||||
|
package mime
|
||||||
|
|
||||||
|
import (
|
||||||
|
"syscall"
|
||||||
|
"unsafe"
|
||||||
|
)
|
||||||
|
|
||||||
|
func initMime() {
|
||||||
|
var root syscall.Handle
|
||||||
|
if syscall.RegOpenKeyEx(syscall.HKEY_CLASSES_ROOT, syscall.StringToUTF16Ptr(`\`),
|
||||||
|
0, syscall.KEY_READ, &root) != 0 {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
defer syscall.RegCloseKey(root)
|
||||||
|
var count uint32
|
||||||
|
if syscall.RegQueryInfoKey(root, nil, nil, nil, &count, nil, nil, nil, nil, nil, nil, nil) != 0 {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
var buf [1 << 10]uint16
|
||||||
|
for i := uint32(0); i < count; i++ {
|
||||||
|
n := uint32(len(buf))
|
||||||
|
if syscall.RegEnumKeyEx(root, i, &buf[0], &n, nil, nil, nil, nil) != 0 {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
ext := syscall.UTF16ToString(buf[:])
|
||||||
|
if len(ext) < 2 || ext[0] != '.' { // looking for extensions only
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
var h syscall.Handle
|
||||||
|
if syscall.RegOpenKeyEx(
|
||||||
|
syscall.HKEY_CLASSES_ROOT, syscall.StringToUTF16Ptr(`\`+ext),
|
||||||
|
0, syscall.KEY_READ, &h) != 0 {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
var typ uint32
|
||||||
|
n = uint32(len(buf) * 2) // api expects array of bytes, not uint16
|
||||||
|
if syscall.RegQueryValueEx(
|
||||||
|
h, syscall.StringToUTF16Ptr("Content Type"),
|
||||||
|
nil, &typ, (*byte)(unsafe.Pointer(&buf[0])), &n) != 0 {
|
||||||
|
syscall.RegCloseKey(h)
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
syscall.RegCloseKey(h)
|
||||||
|
if typ != syscall.REG_SZ { // null terminated strings only
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
mimeType := syscall.UTF16ToString(buf[:])
|
||||||
|
setExtensionType(ext, mimeType)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func initMimeForTests() map[string]string {
|
||||||
|
return map[string]string{
|
||||||
|
".bmp": "image/bmp",
|
||||||
|
".png": "image/png",
|
||||||
|
}
|
||||||
|
}
|
|
@ -109,7 +109,7 @@ func cgoLookupIPCNAME(name string) (addrs []IP, cname string, err error, complet
|
||||||
if gerrno == syscall.EAI_NONAME {
|
if gerrno == syscall.EAI_NONAME {
|
||||||
str = noSuchHost
|
str = noSuchHost
|
||||||
} else if gerrno == syscall.EAI_SYSTEM {
|
} else if gerrno == syscall.EAI_SYSTEM {
|
||||||
str = syscall.Errstr(syscall.GetErrno())
|
str = syscall.GetErrno().Error()
|
||||||
} else {
|
} else {
|
||||||
str = bytePtrToString(libc_gai_strerror(gerrno))
|
str = bytePtrToString(libc_gai_strerror(gerrno))
|
||||||
}
|
}
|
||||||
|
|
|
@ -278,8 +278,8 @@ func startServer() {
|
||||||
|
|
||||||
func newFD(fd, family, proto int, net string) (f *netFD, err error) {
|
func newFD(fd, family, proto int, net string) (f *netFD, err error) {
|
||||||
onceStartServer.Do(startServer)
|
onceStartServer.Do(startServer)
|
||||||
if e := syscall.SetNonblock(fd, true); e != 0 {
|
if e := syscall.SetNonblock(fd, true); e != nil {
|
||||||
return nil, os.Errno(e)
|
return nil, e
|
||||||
}
|
}
|
||||||
f = &netFD{
|
f = &netFD{
|
||||||
sysfd: fd,
|
sysfd: fd,
|
||||||
|
@ -306,19 +306,19 @@ func (fd *netFD) setAddr(laddr, raddr Addr) {
|
||||||
}
|
}
|
||||||
|
|
||||||
func (fd *netFD) connect(ra syscall.Sockaddr) (err error) {
|
func (fd *netFD) connect(ra syscall.Sockaddr) (err error) {
|
||||||
e := syscall.Connect(fd.sysfd, ra)
|
err = syscall.Connect(fd.sysfd, ra)
|
||||||
if e == syscall.EINPROGRESS {
|
if err == syscall.EINPROGRESS {
|
||||||
var errno int
|
|
||||||
pollserver.WaitWrite(fd)
|
pollserver.WaitWrite(fd)
|
||||||
e, errno = syscall.GetsockoptInt(fd.sysfd, syscall.SOL_SOCKET, syscall.SO_ERROR)
|
var e int
|
||||||
if errno != 0 {
|
e, err = syscall.GetsockoptInt(fd.sysfd, syscall.SOL_SOCKET, syscall.SO_ERROR)
|
||||||
return os.NewSyscallError("getsockopt", errno)
|
if err != nil {
|
||||||
}
|
return os.NewSyscallError("getsockopt", err)
|
||||||
}
|
}
|
||||||
if e != 0 {
|
if e != 0 {
|
||||||
return os.Errno(e)
|
err = syscall.Errno(e)
|
||||||
}
|
}
|
||||||
return nil
|
}
|
||||||
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
// Add a reference to this fd.
|
// Add a reference to this fd.
|
||||||
|
@ -362,9 +362,9 @@ func (fd *netFD) shutdown(how int) error {
|
||||||
if fd == nil || fd.sysfile == nil {
|
if fd == nil || fd.sysfile == nil {
|
||||||
return os.EINVAL
|
return os.EINVAL
|
||||||
}
|
}
|
||||||
errno := syscall.Shutdown(fd.sysfd, how)
|
err := syscall.Shutdown(fd.sysfd, how)
|
||||||
if errno != 0 {
|
if err != nil {
|
||||||
return &OpError{"shutdown", fd.net, fd.laddr, os.Errno(errno)}
|
return &OpError{"shutdown", fd.net, fd.laddr, err}
|
||||||
}
|
}
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
@ -377,6 +377,14 @@ func (fd *netFD) CloseWrite() error {
|
||||||
return fd.shutdown(syscall.SHUT_WR)
|
return fd.shutdown(syscall.SHUT_WR)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
type timeoutError struct{}
|
||||||
|
|
||||||
|
func (e *timeoutError) Error() string { return "i/o timeout" }
|
||||||
|
func (e *timeoutError) Timeout() bool { return true }
|
||||||
|
func (e *timeoutError) Temporary() bool { return true }
|
||||||
|
|
||||||
|
var errTimeout error = &timeoutError{}
|
||||||
|
|
||||||
func (fd *netFD) Read(p []byte) (n int, err error) {
|
func (fd *netFD) Read(p []byte) (n int, err error) {
|
||||||
if fd == nil {
|
if fd == nil {
|
||||||
return 0, os.EINVAL
|
return 0, os.EINVAL
|
||||||
|
@ -393,24 +401,24 @@ func (fd *netFD) Read(p []byte) (n int, err error) {
|
||||||
} else {
|
} else {
|
||||||
fd.rdeadline = 0
|
fd.rdeadline = 0
|
||||||
}
|
}
|
||||||
var oserr error
|
|
||||||
for {
|
for {
|
||||||
var errno int
|
n, err = syscall.Read(fd.sysfile.Fd(), p)
|
||||||
n, errno = syscall.Read(fd.sysfile.Fd(), p)
|
if err == syscall.EAGAIN {
|
||||||
if errno == syscall.EAGAIN && fd.rdeadline >= 0 {
|
if fd.rdeadline >= 0 {
|
||||||
pollserver.WaitRead(fd)
|
pollserver.WaitRead(fd)
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
if errno != 0 {
|
err = errTimeout
|
||||||
|
}
|
||||||
|
if err != nil {
|
||||||
n = 0
|
n = 0
|
||||||
oserr = os.Errno(errno)
|
} else if n == 0 && err == nil && fd.proto != syscall.SOCK_DGRAM {
|
||||||
} else if n == 0 && errno == 0 && fd.proto != syscall.SOCK_DGRAM {
|
|
||||||
err = io.EOF
|
err = io.EOF
|
||||||
}
|
}
|
||||||
break
|
break
|
||||||
}
|
}
|
||||||
if oserr != nil {
|
if err != nil && err != io.EOF {
|
||||||
err = &OpError{"read", fd.net, fd.raddr, oserr}
|
err = &OpError{"read", fd.net, fd.raddr, err}
|
||||||
}
|
}
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
@ -428,22 +436,22 @@ func (fd *netFD) ReadFrom(p []byte) (n int, sa syscall.Sockaddr, err error) {
|
||||||
} else {
|
} else {
|
||||||
fd.rdeadline = 0
|
fd.rdeadline = 0
|
||||||
}
|
}
|
||||||
var oserr error
|
|
||||||
for {
|
for {
|
||||||
var errno int
|
n, sa, err = syscall.Recvfrom(fd.sysfd, p, 0)
|
||||||
n, sa, errno = syscall.Recvfrom(fd.sysfd, p, 0)
|
if err == syscall.EAGAIN {
|
||||||
if errno == syscall.EAGAIN && fd.rdeadline >= 0 {
|
if fd.rdeadline >= 0 {
|
||||||
pollserver.WaitRead(fd)
|
pollserver.WaitRead(fd)
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
if errno != 0 {
|
err = errTimeout
|
||||||
|
}
|
||||||
|
if err != nil {
|
||||||
n = 0
|
n = 0
|
||||||
oserr = os.Errno(errno)
|
|
||||||
}
|
}
|
||||||
break
|
break
|
||||||
}
|
}
|
||||||
if oserr != nil {
|
if err != nil {
|
||||||
err = &OpError{"read", fd.net, fd.laddr, oserr}
|
err = &OpError{"read", fd.net, fd.laddr, err}
|
||||||
}
|
}
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
@ -461,24 +469,22 @@ func (fd *netFD) ReadMsg(p []byte, oob []byte) (n, oobn, flags int, sa syscall.S
|
||||||
} else {
|
} else {
|
||||||
fd.rdeadline = 0
|
fd.rdeadline = 0
|
||||||
}
|
}
|
||||||
var oserr error
|
|
||||||
for {
|
for {
|
||||||
var errno int
|
n, oobn, flags, sa, err = syscall.Recvmsg(fd.sysfd, p, oob, 0)
|
||||||
n, oobn, flags, sa, errno = syscall.Recvmsg(fd.sysfd, p, oob, 0)
|
if err == syscall.EAGAIN {
|
||||||
if errno == syscall.EAGAIN && fd.rdeadline >= 0 {
|
if fd.rdeadline >= 0 {
|
||||||
pollserver.WaitRead(fd)
|
pollserver.WaitRead(fd)
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
if errno != 0 {
|
err = errTimeout
|
||||||
oserr = os.Errno(errno)
|
|
||||||
}
|
}
|
||||||
if n == 0 {
|
if err == nil && n == 0 {
|
||||||
oserr = io.EOF
|
err = io.EOF
|
||||||
}
|
}
|
||||||
break
|
break
|
||||||
}
|
}
|
||||||
if oserr != nil {
|
if err != nil && err != io.EOF {
|
||||||
err = &OpError{"read", fd.net, fd.laddr, oserr}
|
err = &OpError{"read", fd.net, fd.laddr, err}
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
return
|
return
|
||||||
|
@ -501,32 +507,34 @@ func (fd *netFD) Write(p []byte) (n int, err error) {
|
||||||
fd.wdeadline = 0
|
fd.wdeadline = 0
|
||||||
}
|
}
|
||||||
nn := 0
|
nn := 0
|
||||||
var oserr error
|
|
||||||
|
|
||||||
for {
|
for {
|
||||||
n, errno := syscall.Write(fd.sysfile.Fd(), p[nn:])
|
var n int
|
||||||
|
n, err = syscall.Write(fd.sysfile.Fd(), p[nn:])
|
||||||
if n > 0 {
|
if n > 0 {
|
||||||
nn += n
|
nn += n
|
||||||
}
|
}
|
||||||
if nn == len(p) {
|
if nn == len(p) {
|
||||||
break
|
break
|
||||||
}
|
}
|
||||||
if errno == syscall.EAGAIN && fd.wdeadline >= 0 {
|
if err == syscall.EAGAIN {
|
||||||
|
if fd.wdeadline >= 0 {
|
||||||
pollserver.WaitWrite(fd)
|
pollserver.WaitWrite(fd)
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
if errno != 0 {
|
err = errTimeout
|
||||||
|
}
|
||||||
|
if err != nil {
|
||||||
n = 0
|
n = 0
|
||||||
oserr = os.Errno(errno)
|
|
||||||
break
|
break
|
||||||
}
|
}
|
||||||
if n == 0 {
|
if n == 0 {
|
||||||
oserr = io.ErrUnexpectedEOF
|
err = io.ErrUnexpectedEOF
|
||||||
break
|
break
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
if oserr != nil {
|
if err != nil {
|
||||||
err = &OpError{"write", fd.net, fd.raddr, oserr}
|
err = &OpError{"write", fd.net, fd.raddr, err}
|
||||||
}
|
}
|
||||||
return nn, err
|
return nn, err
|
||||||
}
|
}
|
||||||
|
@ -544,22 +552,21 @@ func (fd *netFD) WriteTo(p []byte, sa syscall.Sockaddr) (n int, err error) {
|
||||||
} else {
|
} else {
|
||||||
fd.wdeadline = 0
|
fd.wdeadline = 0
|
||||||
}
|
}
|
||||||
var oserr error
|
|
||||||
for {
|
for {
|
||||||
errno := syscall.Sendto(fd.sysfd, p, 0, sa)
|
err = syscall.Sendto(fd.sysfd, p, 0, sa)
|
||||||
if errno == syscall.EAGAIN && fd.wdeadline >= 0 {
|
if err == syscall.EAGAIN {
|
||||||
|
if fd.wdeadline >= 0 {
|
||||||
pollserver.WaitWrite(fd)
|
pollserver.WaitWrite(fd)
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
if errno != 0 {
|
err = errTimeout
|
||||||
oserr = os.Errno(errno)
|
|
||||||
}
|
}
|
||||||
break
|
break
|
||||||
}
|
}
|
||||||
if oserr == nil {
|
if err == nil {
|
||||||
n = len(p)
|
n = len(p)
|
||||||
} else {
|
} else {
|
||||||
err = &OpError{"write", fd.net, fd.raddr, oserr}
|
err = &OpError{"write", fd.net, fd.raddr, err}
|
||||||
}
|
}
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
@ -577,24 +584,22 @@ func (fd *netFD) WriteMsg(p []byte, oob []byte, sa syscall.Sockaddr) (n int, oob
|
||||||
} else {
|
} else {
|
||||||
fd.wdeadline = 0
|
fd.wdeadline = 0
|
||||||
}
|
}
|
||||||
var oserr error
|
|
||||||
for {
|
for {
|
||||||
var errno int
|
err = syscall.Sendmsg(fd.sysfd, p, oob, sa, 0)
|
||||||
errno = syscall.Sendmsg(fd.sysfd, p, oob, sa, 0)
|
if err == syscall.EAGAIN {
|
||||||
if errno == syscall.EAGAIN && fd.wdeadline >= 0 {
|
if fd.wdeadline >= 0 {
|
||||||
pollserver.WaitWrite(fd)
|
pollserver.WaitWrite(fd)
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
if errno != 0 {
|
err = errTimeout
|
||||||
oserr = os.Errno(errno)
|
|
||||||
}
|
}
|
||||||
break
|
break
|
||||||
}
|
}
|
||||||
if oserr == nil {
|
if err == nil {
|
||||||
n = len(p)
|
n = len(p)
|
||||||
oobn = len(oob)
|
oobn = len(oob)
|
||||||
} else {
|
} else {
|
||||||
err = &OpError{"write", fd.net, fd.raddr, oserr}
|
err = &OpError{"write", fd.net, fd.raddr, err}
|
||||||
}
|
}
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
@ -615,25 +620,26 @@ func (fd *netFD) accept(toAddr func(syscall.Sockaddr) Addr) (nfd *netFD, err err
|
||||||
// See ../syscall/exec.go for description of ForkLock.
|
// See ../syscall/exec.go for description of ForkLock.
|
||||||
// It is okay to hold the lock across syscall.Accept
|
// It is okay to hold the lock across syscall.Accept
|
||||||
// because we have put fd.sysfd into non-blocking mode.
|
// because we have put fd.sysfd into non-blocking mode.
|
||||||
syscall.ForkLock.RLock()
|
var s int
|
||||||
var s, e int
|
|
||||||
var rsa syscall.Sockaddr
|
var rsa syscall.Sockaddr
|
||||||
for {
|
for {
|
||||||
if fd.closing {
|
if fd.closing {
|
||||||
syscall.ForkLock.RUnlock()
|
|
||||||
return nil, os.EINVAL
|
return nil, os.EINVAL
|
||||||
}
|
}
|
||||||
s, rsa, e = syscall.Accept(fd.sysfd)
|
|
||||||
if e != syscall.EAGAIN || fd.rdeadline < 0 {
|
|
||||||
break
|
|
||||||
}
|
|
||||||
syscall.ForkLock.RUnlock()
|
|
||||||
pollserver.WaitRead(fd)
|
|
||||||
syscall.ForkLock.RLock()
|
syscall.ForkLock.RLock()
|
||||||
}
|
s, rsa, err = syscall.Accept(fd.sysfd)
|
||||||
if e != 0 {
|
if err != nil {
|
||||||
syscall.ForkLock.RUnlock()
|
syscall.ForkLock.RUnlock()
|
||||||
return nil, &OpError{"accept", fd.net, fd.laddr, os.Errno(e)}
|
if err == syscall.EAGAIN {
|
||||||
|
if fd.rdeadline >= 0 {
|
||||||
|
pollserver.WaitRead(fd)
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
err = errTimeout
|
||||||
|
}
|
||||||
|
return nil, &OpError{"accept", fd.net, fd.laddr, err}
|
||||||
|
}
|
||||||
|
break
|
||||||
}
|
}
|
||||||
syscall.CloseOnExec(s)
|
syscall.CloseOnExec(s)
|
||||||
syscall.ForkLock.RUnlock()
|
syscall.ForkLock.RUnlock()
|
||||||
|
@ -648,19 +654,19 @@ func (fd *netFD) accept(toAddr func(syscall.Sockaddr) Addr) (nfd *netFD, err err
|
||||||
}
|
}
|
||||||
|
|
||||||
func (fd *netFD) dup() (f *os.File, err error) {
|
func (fd *netFD) dup() (f *os.File, err error) {
|
||||||
ns, e := syscall.Dup(fd.sysfd)
|
ns, err := syscall.Dup(fd.sysfd)
|
||||||
if e != 0 {
|
if err != nil {
|
||||||
return nil, &OpError{"dup", fd.net, fd.laddr, os.Errno(e)}
|
return nil, &OpError{"dup", fd.net, fd.laddr, err}
|
||||||
}
|
}
|
||||||
|
|
||||||
// We want blocking mode for the new fd, hence the double negative.
|
// We want blocking mode for the new fd, hence the double negative.
|
||||||
if e = syscall.SetNonblock(ns, false); e != 0 {
|
if err = syscall.SetNonblock(ns, false); err != nil {
|
||||||
return nil, &OpError{"setnonblock", fd.net, fd.laddr, os.Errno(e)}
|
return nil, &OpError{"setnonblock", fd.net, fd.laddr, err}
|
||||||
}
|
}
|
||||||
|
|
||||||
return os.NewFile(ns, fd.sysfile.Name()), nil
|
return os.NewFile(ns, fd.sysfile.Name()), nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func closesocket(s int) (errno int) {
|
func closesocket(s int) error {
|
||||||
return syscall.Close(s)
|
return syscall.Close(s)
|
||||||
}
|
}
|
||||||
|
|
|
@ -35,12 +35,12 @@ type pollster struct {
|
||||||
|
|
||||||
func newpollster() (p *pollster, err error) {
|
func newpollster() (p *pollster, err error) {
|
||||||
p = new(pollster)
|
p = new(pollster)
|
||||||
var e int
|
var e error
|
||||||
|
|
||||||
// The arg to epoll_create is a hint to the kernel
|
// The arg to epoll_create is a hint to the kernel
|
||||||
// about the number of FDs we will care about.
|
// about the number of FDs we will care about.
|
||||||
// We don't know, and since 2.6.8 the kernel ignores it anyhow.
|
// We don't know, and since 2.6.8 the kernel ignores it anyhow.
|
||||||
if p.epfd, e = syscall.EpollCreate(16); e != 0 {
|
if p.epfd, e = syscall.EpollCreate(16); e != nil {
|
||||||
return nil, os.NewSyscallError("epoll_create", e)
|
return nil, os.NewSyscallError("epoll_create", e)
|
||||||
}
|
}
|
||||||
p.events = make(map[int]uint32)
|
p.events = make(map[int]uint32)
|
||||||
|
@ -68,7 +68,7 @@ func (p *pollster) AddFD(fd int, mode int, repeat bool) (bool, error) {
|
||||||
} else {
|
} else {
|
||||||
op = syscall.EPOLL_CTL_ADD
|
op = syscall.EPOLL_CTL_ADD
|
||||||
}
|
}
|
||||||
if e := syscall.EpollCtl(p.epfd, op, fd, &p.ctlEvent); e != 0 {
|
if e := syscall.EpollCtl(p.epfd, op, fd, &p.ctlEvent); e != nil {
|
||||||
return false, os.NewSyscallError("epoll_ctl", e)
|
return false, os.NewSyscallError("epoll_ctl", e)
|
||||||
}
|
}
|
||||||
p.events[fd] = p.ctlEvent.Events
|
p.events[fd] = p.ctlEvent.Events
|
||||||
|
@ -97,13 +97,13 @@ func (p *pollster) StopWaiting(fd int, bits uint) {
|
||||||
if int32(events)&^syscall.EPOLLONESHOT != 0 {
|
if int32(events)&^syscall.EPOLLONESHOT != 0 {
|
||||||
p.ctlEvent.Fd = int32(fd)
|
p.ctlEvent.Fd = int32(fd)
|
||||||
p.ctlEvent.Events = events
|
p.ctlEvent.Events = events
|
||||||
if e := syscall.EpollCtl(p.epfd, syscall.EPOLL_CTL_MOD, fd, &p.ctlEvent); e != 0 {
|
if e := syscall.EpollCtl(p.epfd, syscall.EPOLL_CTL_MOD, fd, &p.ctlEvent); e != nil {
|
||||||
print("Epoll modify fd=", fd, ": ", os.Errno(e).Error(), "\n")
|
print("Epoll modify fd=", fd, ": ", e.Error(), "\n")
|
||||||
}
|
}
|
||||||
p.events[fd] = events
|
p.events[fd] = events
|
||||||
} else {
|
} else {
|
||||||
if e := syscall.EpollCtl(p.epfd, syscall.EPOLL_CTL_DEL, fd, nil); e != 0 {
|
if e := syscall.EpollCtl(p.epfd, syscall.EPOLL_CTL_DEL, fd, nil); e != nil {
|
||||||
print("Epoll delete fd=", fd, ": ", os.Errno(e).Error(), "\n")
|
print("Epoll delete fd=", fd, ": ", e.Error(), "\n")
|
||||||
}
|
}
|
||||||
delete(p.events, fd)
|
delete(p.events, fd)
|
||||||
}
|
}
|
||||||
|
@ -141,7 +141,7 @@ func (p *pollster) WaitFD(s *pollServer, nsec int64) (fd int, mode int, err erro
|
||||||
n, e := syscall.EpollWait(p.epfd, p.waitEventBuf[0:], msec)
|
n, e := syscall.EpollWait(p.epfd, p.waitEventBuf[0:], msec)
|
||||||
s.Lock()
|
s.Lock()
|
||||||
|
|
||||||
if e != 0 {
|
if e != nil {
|
||||||
if e == syscall.EAGAIN || e == syscall.EINTR {
|
if e == syscall.EAGAIN || e == syscall.EINTR {
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
|
|
|
@ -23,9 +23,8 @@ type pollster struct {
|
||||||
|
|
||||||
func newpollster() (p *pollster, err error) {
|
func newpollster() (p *pollster, err error) {
|
||||||
p = new(pollster)
|
p = new(pollster)
|
||||||
var e int
|
if p.kq, err = syscall.Kqueue(); err != nil {
|
||||||
if p.kq, e = syscall.Kqueue(); e != 0 {
|
return nil, os.NewSyscallError("kqueue", err)
|
||||||
return nil, os.NewSyscallError("kqueue", e)
|
|
||||||
}
|
}
|
||||||
p.events = p.eventbuf[0:0]
|
p.events = p.eventbuf[0:0]
|
||||||
return p, nil
|
return p, nil
|
||||||
|
@ -50,14 +49,14 @@ func (p *pollster) AddFD(fd int, mode int, repeat bool) (bool, error) {
|
||||||
syscall.SetKevent(ev, fd, kmode, flags)
|
syscall.SetKevent(ev, fd, kmode, flags)
|
||||||
|
|
||||||
n, e := syscall.Kevent(p.kq, p.kbuf[:], nil, nil)
|
n, e := syscall.Kevent(p.kq, p.kbuf[:], nil, nil)
|
||||||
if e != 0 {
|
if e != nil {
|
||||||
return false, os.NewSyscallError("kevent", e)
|
return false, os.NewSyscallError("kevent", e)
|
||||||
}
|
}
|
||||||
if n != 1 || (ev.Flags&syscall.EV_ERROR) == 0 || int(ev.Ident) != fd || int(ev.Filter) != kmode {
|
if n != 1 || (ev.Flags&syscall.EV_ERROR) == 0 || int(ev.Ident) != fd || int(ev.Filter) != kmode {
|
||||||
return false, os.NewSyscallError("kqueue phase error", e)
|
return false, os.NewSyscallError("kqueue phase error", e)
|
||||||
}
|
}
|
||||||
if ev.Data != 0 {
|
if ev.Data != 0 {
|
||||||
return false, os.Errno(int(ev.Data))
|
return false, syscall.Errno(int(ev.Data))
|
||||||
}
|
}
|
||||||
return false, nil
|
return false, nil
|
||||||
}
|
}
|
||||||
|
@ -91,7 +90,7 @@ func (p *pollster) WaitFD(s *pollServer, nsec int64) (fd int, mode int, err erro
|
||||||
nn, e := syscall.Kevent(p.kq, nil, p.eventbuf[:], t)
|
nn, e := syscall.Kevent(p.kq, nil, p.eventbuf[:], t)
|
||||||
s.Lock()
|
s.Lock()
|
||||||
|
|
||||||
if e != 0 {
|
if e != nil {
|
||||||
if e == syscall.EINTR {
|
if e == syscall.EINTR {
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
|
|
|
@ -26,11 +26,11 @@ func init() {
|
||||||
var d syscall.WSAData
|
var d syscall.WSAData
|
||||||
e := syscall.WSAStartup(uint32(0x202), &d)
|
e := syscall.WSAStartup(uint32(0x202), &d)
|
||||||
if e != 0 {
|
if e != 0 {
|
||||||
initErr = os.NewSyscallError("WSAStartup", e)
|
initErr = os.NewSyscallError("WSAStartup", syscall.Errno(e))
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func closesocket(s syscall.Handle) (errno int) {
|
func closesocket(s syscall.Handle) (err error) {
|
||||||
return syscall.Closesocket(s)
|
return syscall.Closesocket(s)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -38,13 +38,13 @@ func closesocket(s syscall.Handle) (errno int) {
|
||||||
type anOpIface interface {
|
type anOpIface interface {
|
||||||
Op() *anOp
|
Op() *anOp
|
||||||
Name() string
|
Name() string
|
||||||
Submit() (errno int)
|
Submit() (err error)
|
||||||
}
|
}
|
||||||
|
|
||||||
// IO completion result parameters.
|
// IO completion result parameters.
|
||||||
type ioResult struct {
|
type ioResult struct {
|
||||||
qty uint32
|
qty uint32
|
||||||
err int
|
err error
|
||||||
}
|
}
|
||||||
|
|
||||||
// anOp implements functionality common to all io operations.
|
// anOp implements functionality common to all io operations.
|
||||||
|
@ -54,7 +54,7 @@ type anOp struct {
|
||||||
o syscall.Overlapped
|
o syscall.Overlapped
|
||||||
|
|
||||||
resultc chan ioResult
|
resultc chan ioResult
|
||||||
errnoc chan int
|
errnoc chan error
|
||||||
fd *netFD
|
fd *netFD
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -71,7 +71,7 @@ func (o *anOp) Init(fd *netFD, mode int) {
|
||||||
}
|
}
|
||||||
o.resultc = fd.resultc[i]
|
o.resultc = fd.resultc[i]
|
||||||
if fd.errnoc[i] == nil {
|
if fd.errnoc[i] == nil {
|
||||||
fd.errnoc[i] = make(chan int)
|
fd.errnoc[i] = make(chan error)
|
||||||
}
|
}
|
||||||
o.errnoc = fd.errnoc[i]
|
o.errnoc = fd.errnoc[i]
|
||||||
}
|
}
|
||||||
|
@ -111,14 +111,14 @@ func (s *resultSrv) Run() {
|
||||||
for {
|
for {
|
||||||
r.err = syscall.GetQueuedCompletionStatus(s.iocp, &(r.qty), &key, &o, syscall.INFINITE)
|
r.err = syscall.GetQueuedCompletionStatus(s.iocp, &(r.qty), &key, &o, syscall.INFINITE)
|
||||||
switch {
|
switch {
|
||||||
case r.err == 0:
|
case r.err == nil:
|
||||||
// Dequeued successfully completed io packet.
|
// Dequeued successfully completed io packet.
|
||||||
case r.err == syscall.WAIT_TIMEOUT && o == nil:
|
case r.err == syscall.Errno(syscall.WAIT_TIMEOUT) && o == nil:
|
||||||
// Wait has timed out (should not happen now, but might be used in the future).
|
// Wait has timed out (should not happen now, but might be used in the future).
|
||||||
panic("GetQueuedCompletionStatus timed out")
|
panic("GetQueuedCompletionStatus timed out")
|
||||||
case o == nil:
|
case o == nil:
|
||||||
// Failed to dequeue anything -> report the error.
|
// Failed to dequeue anything -> report the error.
|
||||||
panic("GetQueuedCompletionStatus failed " + syscall.Errstr(r.err))
|
panic("GetQueuedCompletionStatus failed " + r.err.Error())
|
||||||
default:
|
default:
|
||||||
// Dequeued failed io packet.
|
// Dequeued failed io packet.
|
||||||
}
|
}
|
||||||
|
@ -153,7 +153,7 @@ func (s *ioSrv) ProcessRemoteIO() {
|
||||||
// inline, or, if timeouts are employed, passes the request onto
|
// inline, or, if timeouts are employed, passes the request onto
|
||||||
// a special goroutine and waits for completion or cancels request.
|
// a special goroutine and waits for completion or cancels request.
|
||||||
func (s *ioSrv) ExecIO(oi anOpIface, deadline_delta int64) (n int, err error) {
|
func (s *ioSrv) ExecIO(oi anOpIface, deadline_delta int64) (n int, err error) {
|
||||||
var e int
|
var e error
|
||||||
o := oi.Op()
|
o := oi.Op()
|
||||||
if deadline_delta > 0 {
|
if deadline_delta > 0 {
|
||||||
// Send request to a special dedicated thread,
|
// Send request to a special dedicated thread,
|
||||||
|
@ -164,12 +164,12 @@ func (s *ioSrv) ExecIO(oi anOpIface, deadline_delta int64) (n int, err error) {
|
||||||
e = oi.Submit()
|
e = oi.Submit()
|
||||||
}
|
}
|
||||||
switch e {
|
switch e {
|
||||||
case 0:
|
case nil:
|
||||||
// IO completed immediately, but we need to get our completion message anyway.
|
// IO completed immediately, but we need to get our completion message anyway.
|
||||||
case syscall.ERROR_IO_PENDING:
|
case syscall.ERROR_IO_PENDING:
|
||||||
// IO started, and we have to wait for its completion.
|
// IO started, and we have to wait for its completion.
|
||||||
default:
|
default:
|
||||||
return 0, &OpError{oi.Name(), o.fd.net, o.fd.laddr, os.Errno(e)}
|
return 0, &OpError{oi.Name(), o.fd.net, o.fd.laddr, e}
|
||||||
}
|
}
|
||||||
// Wait for our request to complete.
|
// Wait for our request to complete.
|
||||||
var r ioResult
|
var r ioResult
|
||||||
|
@ -187,8 +187,8 @@ func (s *ioSrv) ExecIO(oi anOpIface, deadline_delta int64) (n int, err error) {
|
||||||
} else {
|
} else {
|
||||||
r = <-o.resultc
|
r = <-o.resultc
|
||||||
}
|
}
|
||||||
if r.err != 0 {
|
if r.err != nil {
|
||||||
err = &OpError{oi.Name(), o.fd.net, o.fd.laddr, os.Errno(r.err)}
|
err = &OpError{oi.Name(), o.fd.net, o.fd.laddr, r.err}
|
||||||
}
|
}
|
||||||
return int(r.qty), err
|
return int(r.qty), err
|
||||||
}
|
}
|
||||||
|
@ -200,10 +200,10 @@ var onceStartServer sync.Once
|
||||||
|
|
||||||
func startServer() {
|
func startServer() {
|
||||||
resultsrv = new(resultSrv)
|
resultsrv = new(resultSrv)
|
||||||
var errno int
|
var err error
|
||||||
resultsrv.iocp, errno = syscall.CreateIoCompletionPort(syscall.InvalidHandle, 0, 0, 1)
|
resultsrv.iocp, err = syscall.CreateIoCompletionPort(syscall.InvalidHandle, 0, 0, 1)
|
||||||
if errno != 0 {
|
if err != nil {
|
||||||
panic("CreateIoCompletionPort failed " + syscall.Errstr(errno))
|
panic("CreateIoCompletionPort: " + err.Error())
|
||||||
}
|
}
|
||||||
go resultsrv.Run()
|
go resultsrv.Run()
|
||||||
|
|
||||||
|
@ -228,7 +228,7 @@ type netFD struct {
|
||||||
laddr Addr
|
laddr Addr
|
||||||
raddr Addr
|
raddr Addr
|
||||||
resultc [2]chan ioResult // read/write completion results
|
resultc [2]chan ioResult // read/write completion results
|
||||||
errnoc [2]chan int // read/write submit or cancel operation errors
|
errnoc [2]chan error // read/write submit or cancel operation errors
|
||||||
|
|
||||||
// owned by client
|
// owned by client
|
||||||
rdeadline_delta int64
|
rdeadline_delta int64
|
||||||
|
@ -256,8 +256,8 @@ func newFD(fd syscall.Handle, family, proto int, net string) (f *netFD, err erro
|
||||||
}
|
}
|
||||||
onceStartServer.Do(startServer)
|
onceStartServer.Do(startServer)
|
||||||
// Associate our socket with resultsrv.iocp.
|
// Associate our socket with resultsrv.iocp.
|
||||||
if _, e := syscall.CreateIoCompletionPort(syscall.Handle(fd), resultsrv.iocp, 0, 0); e != 0 {
|
if _, e := syscall.CreateIoCompletionPort(syscall.Handle(fd), resultsrv.iocp, 0, 0); e != nil {
|
||||||
return nil, os.Errno(e)
|
return nil, e
|
||||||
}
|
}
|
||||||
return allocFD(fd, family, proto, net), nil
|
return allocFD(fd, family, proto, net), nil
|
||||||
}
|
}
|
||||||
|
@ -268,11 +268,7 @@ func (fd *netFD) setAddr(laddr, raddr Addr) {
|
||||||
}
|
}
|
||||||
|
|
||||||
func (fd *netFD) connect(ra syscall.Sockaddr) (err error) {
|
func (fd *netFD) connect(ra syscall.Sockaddr) (err error) {
|
||||||
e := syscall.Connect(fd.sysfd, ra)
|
return syscall.Connect(fd.sysfd, ra)
|
||||||
if e != 0 {
|
|
||||||
return os.Errno(e)
|
|
||||||
}
|
|
||||||
return nil
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// Add a reference to this fd.
|
// Add a reference to this fd.
|
||||||
|
@ -317,9 +313,9 @@ func (fd *netFD) shutdown(how int) error {
|
||||||
if fd == nil || fd.sysfd == syscall.InvalidHandle {
|
if fd == nil || fd.sysfd == syscall.InvalidHandle {
|
||||||
return os.EINVAL
|
return os.EINVAL
|
||||||
}
|
}
|
||||||
errno := syscall.Shutdown(fd.sysfd, how)
|
err := syscall.Shutdown(fd.sysfd, how)
|
||||||
if errno != 0 {
|
if err != nil {
|
||||||
return &OpError{"shutdown", fd.net, fd.laddr, os.Errno(errno)}
|
return &OpError{"shutdown", fd.net, fd.laddr, err}
|
||||||
}
|
}
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
@ -338,7 +334,7 @@ type readOp struct {
|
||||||
bufOp
|
bufOp
|
||||||
}
|
}
|
||||||
|
|
||||||
func (o *readOp) Submit() (errno int) {
|
func (o *readOp) Submit() (err error) {
|
||||||
var d, f uint32
|
var d, f uint32
|
||||||
return syscall.WSARecv(syscall.Handle(o.fd.sysfd), &o.buf, 1, &d, &f, &o.o, nil)
|
return syscall.WSARecv(syscall.Handle(o.fd.sysfd), &o.buf, 1, &d, &f, &o.o, nil)
|
||||||
}
|
}
|
||||||
|
@ -375,7 +371,7 @@ type readFromOp struct {
|
||||||
rsan int32
|
rsan int32
|
||||||
}
|
}
|
||||||
|
|
||||||
func (o *readFromOp) Submit() (errno int) {
|
func (o *readFromOp) Submit() (err error) {
|
||||||
var d, f uint32
|
var d, f uint32
|
||||||
return syscall.WSARecvFrom(o.fd.sysfd, &o.buf, 1, &d, &f, &o.rsa, &o.rsan, &o.o, nil)
|
return syscall.WSARecvFrom(o.fd.sysfd, &o.buf, 1, &d, &f, &o.rsa, &o.rsan, &o.o, nil)
|
||||||
}
|
}
|
||||||
|
@ -415,7 +411,7 @@ type writeOp struct {
|
||||||
bufOp
|
bufOp
|
||||||
}
|
}
|
||||||
|
|
||||||
func (o *writeOp) Submit() (errno int) {
|
func (o *writeOp) Submit() (err error) {
|
||||||
var d uint32
|
var d uint32
|
||||||
return syscall.WSASend(o.fd.sysfd, &o.buf, 1, &d, 0, &o.o, nil)
|
return syscall.WSASend(o.fd.sysfd, &o.buf, 1, &d, 0, &o.o, nil)
|
||||||
}
|
}
|
||||||
|
@ -447,7 +443,7 @@ type writeToOp struct {
|
||||||
sa syscall.Sockaddr
|
sa syscall.Sockaddr
|
||||||
}
|
}
|
||||||
|
|
||||||
func (o *writeToOp) Submit() (errno int) {
|
func (o *writeToOp) Submit() (err error) {
|
||||||
var d uint32
|
var d uint32
|
||||||
return syscall.WSASendto(o.fd.sysfd, &o.buf, 1, &d, 0, o.sa, &o.o, nil)
|
return syscall.WSASendto(o.fd.sysfd, &o.buf, 1, &d, 0, o.sa, &o.o, nil)
|
||||||
}
|
}
|
||||||
|
@ -484,7 +480,7 @@ type acceptOp struct {
|
||||||
attrs [2]syscall.RawSockaddrAny // space for local and remote address only
|
attrs [2]syscall.RawSockaddrAny // space for local and remote address only
|
||||||
}
|
}
|
||||||
|
|
||||||
func (o *acceptOp) Submit() (errno int) {
|
func (o *acceptOp) Submit() (err error) {
|
||||||
var d uint32
|
var d uint32
|
||||||
l := uint32(unsafe.Sizeof(o.attrs[0]))
|
l := uint32(unsafe.Sizeof(o.attrs[0]))
|
||||||
return syscall.AcceptEx(o.fd.sysfd, o.newsock,
|
return syscall.AcceptEx(o.fd.sysfd, o.newsock,
|
||||||
|
@ -506,17 +502,17 @@ func (fd *netFD) accept(toAddr func(syscall.Sockaddr) Addr) (nfd *netFD, err err
|
||||||
// See ../syscall/exec.go for description of ForkLock.
|
// See ../syscall/exec.go for description of ForkLock.
|
||||||
syscall.ForkLock.RLock()
|
syscall.ForkLock.RLock()
|
||||||
s, e := syscall.Socket(fd.family, fd.proto, 0)
|
s, e := syscall.Socket(fd.family, fd.proto, 0)
|
||||||
if e != 0 {
|
if e != nil {
|
||||||
syscall.ForkLock.RUnlock()
|
syscall.ForkLock.RUnlock()
|
||||||
return nil, os.Errno(e)
|
return nil, e
|
||||||
}
|
}
|
||||||
syscall.CloseOnExec(s)
|
syscall.CloseOnExec(s)
|
||||||
syscall.ForkLock.RUnlock()
|
syscall.ForkLock.RUnlock()
|
||||||
|
|
||||||
// Associate our new socket with IOCP.
|
// Associate our new socket with IOCP.
|
||||||
onceStartServer.Do(startServer)
|
onceStartServer.Do(startServer)
|
||||||
if _, e = syscall.CreateIoCompletionPort(s, resultsrv.iocp, 0, 0); e != 0 {
|
if _, e = syscall.CreateIoCompletionPort(s, resultsrv.iocp, 0, 0); e != nil {
|
||||||
return nil, &OpError{"CreateIoCompletionPort", fd.net, fd.laddr, os.Errno(e)}
|
return nil, &OpError{"CreateIoCompletionPort", fd.net, fd.laddr, e}
|
||||||
}
|
}
|
||||||
|
|
||||||
// Submit accept request.
|
// Submit accept request.
|
||||||
|
@ -531,9 +527,9 @@ func (fd *netFD) accept(toAddr func(syscall.Sockaddr) Addr) (nfd *netFD, err err
|
||||||
|
|
||||||
// Inherit properties of the listening socket.
|
// Inherit properties of the listening socket.
|
||||||
e = syscall.Setsockopt(s, syscall.SOL_SOCKET, syscall.SO_UPDATE_ACCEPT_CONTEXT, (*byte)(unsafe.Pointer(&fd.sysfd)), int32(unsafe.Sizeof(fd.sysfd)))
|
e = syscall.Setsockopt(s, syscall.SOL_SOCKET, syscall.SO_UPDATE_ACCEPT_CONTEXT, (*byte)(unsafe.Pointer(&fd.sysfd)), int32(unsafe.Sizeof(fd.sysfd)))
|
||||||
if e != 0 {
|
if e != nil {
|
||||||
closesocket(s)
|
closesocket(s)
|
||||||
return nil, err
|
return nil, e
|
||||||
}
|
}
|
||||||
|
|
||||||
// Get local and peer addr out of AcceptEx buffer.
|
// Get local and peer addr out of AcceptEx buffer.
|
||||||
|
|
|
@ -13,12 +13,12 @@ import (
|
||||||
|
|
||||||
func newFileFD(f *os.File) (nfd *netFD, err error) {
|
func newFileFD(f *os.File) (nfd *netFD, err error) {
|
||||||
fd, errno := syscall.Dup(f.Fd())
|
fd, errno := syscall.Dup(f.Fd())
|
||||||
if errno != 0 {
|
if errno != nil {
|
||||||
return nil, os.NewSyscallError("dup", errno)
|
return nil, os.NewSyscallError("dup", errno)
|
||||||
}
|
}
|
||||||
|
|
||||||
proto, errno := syscall.GetsockoptInt(fd, syscall.SOL_SOCKET, syscall.SO_TYPE)
|
proto, errno := syscall.GetsockoptInt(fd, syscall.SOL_SOCKET, syscall.SO_TYPE)
|
||||||
if errno != 0 {
|
if errno != nil {
|
||||||
return nil, os.NewSyscallError("getsockopt", errno)
|
return nil, os.NewSyscallError("getsockopt", errno)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -7,8 +7,8 @@
|
||||||
package net
|
package net
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"os"
|
|
||||||
"sync"
|
"sync"
|
||||||
|
"time"
|
||||||
)
|
)
|
||||||
|
|
||||||
const cacheMaxAge = int64(300) // 5 minutes.
|
const cacheMaxAge = int64(300) // 5 minutes.
|
||||||
|
@ -26,7 +26,7 @@ var hosts struct {
|
||||||
}
|
}
|
||||||
|
|
||||||
func readHosts() {
|
func readHosts() {
|
||||||
now, _, _ := os.Time()
|
now := time.Seconds()
|
||||||
hp := hostsPath
|
hp := hostsPath
|
||||||
if len(hosts.byName) == 0 || hosts.time+cacheMaxAge <= now || hosts.path != hp {
|
if len(hosts.byName) == 0 || hosts.time+cacheMaxAge <= now || hosts.path != hp {
|
||||||
hs := make(map[string][]string)
|
hs := make(map[string][]string)
|
||||||
|
@ -51,7 +51,7 @@ func readHosts() {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
// Update the data cache.
|
// Update the data cache.
|
||||||
hosts.time, _, _ = os.Time()
|
hosts.time = time.Seconds()
|
||||||
hosts.path = hp
|
hosts.path = hp
|
||||||
hosts.byName = hs
|
hosts.byName = hs
|
||||||
hosts.byAddr = is
|
hosts.byAddr = is
|
||||||
|
|
|
@ -363,7 +363,7 @@ func TestCopyError(t *testing.T) {
|
||||||
}
|
}
|
||||||
conn.Close()
|
conn.Close()
|
||||||
|
|
||||||
if tries := 0; childRunning() {
|
tries := 0
|
||||||
for tries < 15 && childRunning() {
|
for tries < 15 && childRunning() {
|
||||||
time.Sleep(50e6 * int64(tries))
|
time.Sleep(50e6 * int64(tries))
|
||||||
tries++
|
tries++
|
||||||
|
@ -372,7 +372,6 @@ func TestCopyError(t *testing.T) {
|
||||||
t.Fatalf("post-conn.Close, expected child to be gone")
|
t.Fatalf("post-conn.Close, expected child to be gone")
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
|
||||||
|
|
||||||
func TestDirUnix(t *testing.T) {
|
func TestDirUnix(t *testing.T) {
|
||||||
if skipTest(t) || runtime.GOOS == "windows" {
|
if skipTest(t) || runtime.GOOS == "windows" {
|
||||||
|
|
|
@ -2,20 +2,137 @@
|
||||||
// Use of this source code is governed by a BSD-style
|
// Use of this source code is governed by a BSD-style
|
||||||
// license that can be found in the LICENSE file.
|
// license that can be found in the LICENSE file.
|
||||||
|
|
||||||
|
// The wire protocol for HTTP's "chunked" Transfer-Encoding.
|
||||||
|
|
||||||
|
// This code is duplicated in httputil/chunked.go.
|
||||||
|
// Please make any changes in both files.
|
||||||
|
|
||||||
package http
|
package http
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"bufio"
|
"bufio"
|
||||||
|
"bytes"
|
||||||
|
"errors"
|
||||||
"io"
|
"io"
|
||||||
"strconv"
|
"strconv"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
const maxLineLength = 4096 // assumed <= bufio.defaultBufSize
|
||||||
|
|
||||||
|
var ErrLineTooLong = errors.New("header line too long")
|
||||||
|
|
||||||
|
// newChunkedReader returns a new chunkedReader that translates the data read from r
|
||||||
|
// out of HTTP "chunked" format before returning it.
|
||||||
|
// The chunkedReader returns io.EOF when the final 0-length chunk is read.
|
||||||
|
//
|
||||||
|
// newChunkedReader is not needed by normal applications. The http package
|
||||||
|
// automatically decodes chunking when reading response bodies.
|
||||||
|
func newChunkedReader(r io.Reader) io.Reader {
|
||||||
|
br, ok := r.(*bufio.Reader)
|
||||||
|
if !ok {
|
||||||
|
br = bufio.NewReader(r)
|
||||||
|
}
|
||||||
|
return &chunkedReader{r: br}
|
||||||
|
}
|
||||||
|
|
||||||
|
type chunkedReader struct {
|
||||||
|
r *bufio.Reader
|
||||||
|
n uint64 // unread bytes in chunk
|
||||||
|
err error
|
||||||
|
}
|
||||||
|
|
||||||
|
func (cr *chunkedReader) beginChunk() {
|
||||||
|
// chunk-size CRLF
|
||||||
|
var line string
|
||||||
|
line, cr.err = readLine(cr.r)
|
||||||
|
if cr.err != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
cr.n, cr.err = strconv.Btoui64(line, 16)
|
||||||
|
if cr.err != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if cr.n == 0 {
|
||||||
|
cr.err = io.EOF
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (cr *chunkedReader) Read(b []uint8) (n int, err error) {
|
||||||
|
if cr.err != nil {
|
||||||
|
return 0, cr.err
|
||||||
|
}
|
||||||
|
if cr.n == 0 {
|
||||||
|
cr.beginChunk()
|
||||||
|
if cr.err != nil {
|
||||||
|
return 0, cr.err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if uint64(len(b)) > cr.n {
|
||||||
|
b = b[0:cr.n]
|
||||||
|
}
|
||||||
|
n, cr.err = cr.r.Read(b)
|
||||||
|
cr.n -= uint64(n)
|
||||||
|
if cr.n == 0 && cr.err == nil {
|
||||||
|
// end of chunk (CRLF)
|
||||||
|
b := make([]byte, 2)
|
||||||
|
if _, cr.err = io.ReadFull(cr.r, b); cr.err == nil {
|
||||||
|
if b[0] != '\r' || b[1] != '\n' {
|
||||||
|
cr.err = errors.New("malformed chunked encoding")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return n, cr.err
|
||||||
|
}
|
||||||
|
|
||||||
|
// Read a line of bytes (up to \n) from b.
|
||||||
|
// Give up if the line exceeds maxLineLength.
|
||||||
|
// The returned bytes are a pointer into storage in
|
||||||
|
// the bufio, so they are only valid until the next bufio read.
|
||||||
|
func readLineBytes(b *bufio.Reader) (p []byte, err error) {
|
||||||
|
if p, err = b.ReadSlice('\n'); err != nil {
|
||||||
|
// We always know when EOF is coming.
|
||||||
|
// If the caller asked for a line, there should be a line.
|
||||||
|
if err == io.EOF {
|
||||||
|
err = io.ErrUnexpectedEOF
|
||||||
|
} else if err == bufio.ErrBufferFull {
|
||||||
|
err = ErrLineTooLong
|
||||||
|
}
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
if len(p) >= maxLineLength {
|
||||||
|
return nil, ErrLineTooLong
|
||||||
|
}
|
||||||
|
|
||||||
|
// Chop off trailing white space.
|
||||||
|
p = bytes.TrimRight(p, " \r\t\n")
|
||||||
|
|
||||||
|
return p, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// readLineBytes, but convert the bytes into a string.
|
||||||
|
func readLine(b *bufio.Reader) (s string, err error) {
|
||||||
|
p, e := readLineBytes(b)
|
||||||
|
if e != nil {
|
||||||
|
return "", e
|
||||||
|
}
|
||||||
|
return string(p), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// newChunkedWriter returns a new chunkedWriter that translates writes into HTTP
|
||||||
|
// "chunked" format before writing them to w. Closing the returned chunkedWriter
|
||||||
|
// sends the final 0-length chunk that marks the end of the stream.
|
||||||
|
//
|
||||||
|
// newChunkedWriter is not needed by normal applications. The http
|
||||||
|
// package adds chunking automatically if handlers don't set a
|
||||||
|
// Content-Length header. Using newChunkedWriter inside a handler
|
||||||
|
// would result in double chunking or chunking with a Content-Length
|
||||||
|
// length, both of which are wrong.
|
||||||
func newChunkedWriter(w io.Writer) io.WriteCloser {
|
func newChunkedWriter(w io.Writer) io.WriteCloser {
|
||||||
return &chunkedWriter{w}
|
return &chunkedWriter{w}
|
||||||
}
|
}
|
||||||
|
|
||||||
// Writing to ChunkedWriter translates to writing in HTTP chunked Transfer
|
// Writing to chunkedWriter translates to writing in HTTP chunked Transfer
|
||||||
// Encoding wire format to the underlying Wire writer.
|
// Encoding wire format to the underlying Wire chunkedWriter.
|
||||||
type chunkedWriter struct {
|
type chunkedWriter struct {
|
||||||
Wire io.Writer
|
Wire io.Writer
|
||||||
}
|
}
|
||||||
|
@ -51,7 +168,3 @@ func (cw *chunkedWriter) Close() error {
|
||||||
_, err := io.WriteString(cw.Wire, "0\r\n")
|
_, err := io.WriteString(cw.Wire, "0\r\n")
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
func newChunkedReader(r *bufio.Reader) io.Reader {
|
|
||||||
return &chunkedReader{r: r}
|
|
||||||
}
|
|
||||||
|
|
|
@ -0,0 +1,39 @@
|
||||||
|
// Copyright 2011 The Go Authors. All rights reserved.
|
||||||
|
// Use of this source code is governed by a BSD-style
|
||||||
|
// license that can be found in the LICENSE file.
|
||||||
|
|
||||||
|
// This code is duplicated in httputil/chunked_test.go.
|
||||||
|
// Please make any changes in both files.
|
||||||
|
|
||||||
|
package http
|
||||||
|
|
||||||
|
import (
|
||||||
|
"bytes"
|
||||||
|
"io/ioutil"
|
||||||
|
"testing"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestChunk(t *testing.T) {
|
||||||
|
var b bytes.Buffer
|
||||||
|
|
||||||
|
w := newChunkedWriter(&b)
|
||||||
|
const chunk1 = "hello, "
|
||||||
|
const chunk2 = "world! 0123456789abcdef"
|
||||||
|
w.Write([]byte(chunk1))
|
||||||
|
w.Write([]byte(chunk2))
|
||||||
|
w.Close()
|
||||||
|
|
||||||
|
if g, e := b.String(), "7\r\nhello, \r\n17\r\nworld! 0123456789abcdef\r\n0\r\n"; g != e {
|
||||||
|
t.Fatalf("chunk writer wrote %q; want %q", g, e)
|
||||||
|
}
|
||||||
|
|
||||||
|
r := newChunkedReader(&b)
|
||||||
|
data, err := ioutil.ReadAll(r)
|
||||||
|
if err != nil {
|
||||||
|
t.Logf(`data: "%s"`, data)
|
||||||
|
t.Fatalf("ReadAll from reader: %v", err)
|
||||||
|
}
|
||||||
|
if g, e := string(data), chunk1+chunk2; g != e {
|
||||||
|
t.Errorf("chunk reader read %q; want %q", g, e)
|
||||||
|
}
|
||||||
|
}
|
|
@ -26,6 +26,31 @@ var robotsTxtHandler = HandlerFunc(func(w ResponseWriter, r *Request) {
|
||||||
fmt.Fprintf(w, "User-agent: go\nDisallow: /something/")
|
fmt.Fprintf(w, "User-agent: go\nDisallow: /something/")
|
||||||
})
|
})
|
||||||
|
|
||||||
|
// pedanticReadAll works like ioutil.ReadAll but additionally
|
||||||
|
// verifies that r obeys the documented io.Reader contract.
|
||||||
|
func pedanticReadAll(r io.Reader) (b []byte, err error) {
|
||||||
|
var bufa [64]byte
|
||||||
|
buf := bufa[:]
|
||||||
|
for {
|
||||||
|
n, err := r.Read(buf)
|
||||||
|
if n == 0 && err == nil {
|
||||||
|
return nil, fmt.Errorf("Read: n=0 with err=nil")
|
||||||
|
}
|
||||||
|
b = append(b, buf[:n]...)
|
||||||
|
if err == io.EOF {
|
||||||
|
n, err := r.Read(buf)
|
||||||
|
if n != 0 || err != io.EOF {
|
||||||
|
return nil, fmt.Errorf("Read: n=%d err=%#v after EOF", n, err)
|
||||||
|
}
|
||||||
|
return b, nil
|
||||||
|
}
|
||||||
|
if err != nil {
|
||||||
|
return b, err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
panic("unreachable")
|
||||||
|
}
|
||||||
|
|
||||||
func TestClient(t *testing.T) {
|
func TestClient(t *testing.T) {
|
||||||
ts := httptest.NewServer(robotsTxtHandler)
|
ts := httptest.NewServer(robotsTxtHandler)
|
||||||
defer ts.Close()
|
defer ts.Close()
|
||||||
|
@ -33,7 +58,7 @@ func TestClient(t *testing.T) {
|
||||||
r, err := Get(ts.URL)
|
r, err := Get(ts.URL)
|
||||||
var b []byte
|
var b []byte
|
||||||
if err == nil {
|
if err == nil {
|
||||||
b, err = ioutil.ReadAll(r.Body)
|
b, err = pedanticReadAll(r.Body)
|
||||||
r.Body.Close()
|
r.Body.Close()
|
||||||
}
|
}
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
|
|
@ -7,6 +7,7 @@ package fcgi
|
||||||
// This file implements FastCGI from the perspective of a child process.
|
// This file implements FastCGI from the perspective of a child process.
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
"io"
|
"io"
|
||||||
"net"
|
"net"
|
||||||
|
@ -125,49 +126,61 @@ func (r *response) Close() error {
|
||||||
type child struct {
|
type child struct {
|
||||||
conn *conn
|
conn *conn
|
||||||
handler http.Handler
|
handler http.Handler
|
||||||
|
requests map[uint16]*request // keyed by request ID
|
||||||
}
|
}
|
||||||
|
|
||||||
func newChild(rwc net.Conn, handler http.Handler) *child {
|
func newChild(rwc io.ReadWriteCloser, handler http.Handler) *child {
|
||||||
return &child{newConn(rwc), handler}
|
return &child{
|
||||||
|
conn: newConn(rwc),
|
||||||
|
handler: handler,
|
||||||
|
requests: make(map[uint16]*request),
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *child) serve() {
|
func (c *child) serve() {
|
||||||
requests := map[uint16]*request{}
|
|
||||||
defer c.conn.Close()
|
defer c.conn.Close()
|
||||||
var rec record
|
var rec record
|
||||||
var br beginRequest
|
|
||||||
for {
|
for {
|
||||||
if err := rec.read(c.conn.rwc); err != nil {
|
if err := rec.read(c.conn.rwc); err != nil {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
if err := c.handleRecord(&rec); err != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
req, ok := requests[rec.h.Id]
|
var errCloseConn = errors.New("fcgi: connection should be closed")
|
||||||
|
|
||||||
|
func (c *child) handleRecord(rec *record) error {
|
||||||
|
req, ok := c.requests[rec.h.Id]
|
||||||
if !ok && rec.h.Type != typeBeginRequest && rec.h.Type != typeGetValues {
|
if !ok && rec.h.Type != typeBeginRequest && rec.h.Type != typeGetValues {
|
||||||
// The spec says to ignore unknown request IDs.
|
// The spec says to ignore unknown request IDs.
|
||||||
continue
|
return nil
|
||||||
}
|
}
|
||||||
if ok && rec.h.Type == typeBeginRequest {
|
if ok && rec.h.Type == typeBeginRequest {
|
||||||
// The server is trying to begin a request with the same ID
|
// The server is trying to begin a request with the same ID
|
||||||
// as an in-progress request. This is an error.
|
// as an in-progress request. This is an error.
|
||||||
return
|
return errors.New("fcgi: received ID that is already in-flight")
|
||||||
}
|
}
|
||||||
|
|
||||||
switch rec.h.Type {
|
switch rec.h.Type {
|
||||||
case typeBeginRequest:
|
case typeBeginRequest:
|
||||||
|
var br beginRequest
|
||||||
if err := br.read(rec.content()); err != nil {
|
if err := br.read(rec.content()); err != nil {
|
||||||
return
|
return err
|
||||||
}
|
}
|
||||||
if br.role != roleResponder {
|
if br.role != roleResponder {
|
||||||
c.conn.writeEndRequest(rec.h.Id, 0, statusUnknownRole)
|
c.conn.writeEndRequest(rec.h.Id, 0, statusUnknownRole)
|
||||||
break
|
return nil
|
||||||
}
|
}
|
||||||
requests[rec.h.Id] = newRequest(rec.h.Id, br.flags)
|
c.requests[rec.h.Id] = newRequest(rec.h.Id, br.flags)
|
||||||
case typeParams:
|
case typeParams:
|
||||||
// NOTE(eds): Technically a key-value pair can straddle the boundary
|
// NOTE(eds): Technically a key-value pair can straddle the boundary
|
||||||
// between two packets. We buffer until we've received all parameters.
|
// between two packets. We buffer until we've received all parameters.
|
||||||
if len(rec.content()) > 0 {
|
if len(rec.content()) > 0 {
|
||||||
req.rawParams = append(req.rawParams, rec.content()...)
|
req.rawParams = append(req.rawParams, rec.content()...)
|
||||||
break
|
return nil
|
||||||
}
|
}
|
||||||
req.parseParams()
|
req.parseParams()
|
||||||
case typeStdin:
|
case typeStdin:
|
||||||
|
@ -190,22 +203,22 @@ func (c *child) serve() {
|
||||||
}
|
}
|
||||||
case typeGetValues:
|
case typeGetValues:
|
||||||
values := map[string]string{"FCGI_MPXS_CONNS": "1"}
|
values := map[string]string{"FCGI_MPXS_CONNS": "1"}
|
||||||
c.conn.writePairs(0, typeGetValuesResult, values)
|
c.conn.writePairs(typeGetValuesResult, 0, values)
|
||||||
case typeData:
|
case typeData:
|
||||||
// If the filter role is implemented, read the data stream here.
|
// If the filter role is implemented, read the data stream here.
|
||||||
case typeAbortRequest:
|
case typeAbortRequest:
|
||||||
delete(requests, rec.h.Id)
|
delete(c.requests, rec.h.Id)
|
||||||
c.conn.writeEndRequest(rec.h.Id, 0, statusRequestComplete)
|
c.conn.writeEndRequest(rec.h.Id, 0, statusRequestComplete)
|
||||||
if !req.keepConn {
|
if !req.keepConn {
|
||||||
// connection will close upon return
|
// connection will close upon return
|
||||||
return
|
return errCloseConn
|
||||||
}
|
}
|
||||||
default:
|
default:
|
||||||
b := make([]byte, 8)
|
b := make([]byte, 8)
|
||||||
b[0] = rec.h.Type
|
b[0] = byte(rec.h.Type)
|
||||||
c.conn.writeRecord(typeUnknownType, 0, b)
|
c.conn.writeRecord(typeUnknownType, 0, b)
|
||||||
}
|
}
|
||||||
}
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *child) serveRequest(req *request, body io.ReadCloser) {
|
func (c *child) serveRequest(req *request, body io.ReadCloser) {
|
||||||
|
|
|
@ -19,19 +19,22 @@ import (
|
||||||
"sync"
|
"sync"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
// recType is a record type, as defined by
|
||||||
|
// http://www.fastcgi.com/devkit/doc/fcgi-spec.html#S8
|
||||||
|
type recType uint8
|
||||||
|
|
||||||
const (
|
const (
|
||||||
// Packet Types
|
typeBeginRequest recType = 1
|
||||||
typeBeginRequest = iota + 1
|
typeAbortRequest recType = 2
|
||||||
typeAbortRequest
|
typeEndRequest recType = 3
|
||||||
typeEndRequest
|
typeParams recType = 4
|
||||||
typeParams
|
typeStdin recType = 5
|
||||||
typeStdin
|
typeStdout recType = 6
|
||||||
typeStdout
|
typeStderr recType = 7
|
||||||
typeStderr
|
typeData recType = 8
|
||||||
typeData
|
typeGetValues recType = 9
|
||||||
typeGetValues
|
typeGetValuesResult recType = 10
|
||||||
typeGetValuesResult
|
typeUnknownType recType = 11
|
||||||
typeUnknownType
|
|
||||||
)
|
)
|
||||||
|
|
||||||
// keep the connection between web-server and responder open after request
|
// keep the connection between web-server and responder open after request
|
||||||
|
@ -59,7 +62,7 @@ const headerLen = 8
|
||||||
|
|
||||||
type header struct {
|
type header struct {
|
||||||
Version uint8
|
Version uint8
|
||||||
Type uint8
|
Type recType
|
||||||
Id uint16
|
Id uint16
|
||||||
ContentLength uint16
|
ContentLength uint16
|
||||||
PaddingLength uint8
|
PaddingLength uint8
|
||||||
|
@ -85,7 +88,7 @@ func (br *beginRequest) read(content []byte) error {
|
||||||
// not synchronized because we don't care what the contents are
|
// not synchronized because we don't care what the contents are
|
||||||
var pad [maxPad]byte
|
var pad [maxPad]byte
|
||||||
|
|
||||||
func (h *header) init(recType uint8, reqId uint16, contentLength int) {
|
func (h *header) init(recType recType, reqId uint16, contentLength int) {
|
||||||
h.Version = 1
|
h.Version = 1
|
||||||
h.Type = recType
|
h.Type = recType
|
||||||
h.Id = reqId
|
h.Id = reqId
|
||||||
|
@ -137,7 +140,7 @@ func (r *record) content() []byte {
|
||||||
}
|
}
|
||||||
|
|
||||||
// writeRecord writes and sends a single record.
|
// writeRecord writes and sends a single record.
|
||||||
func (c *conn) writeRecord(recType uint8, reqId uint16, b []byte) error {
|
func (c *conn) writeRecord(recType recType, reqId uint16, b []byte) error {
|
||||||
c.mutex.Lock()
|
c.mutex.Lock()
|
||||||
defer c.mutex.Unlock()
|
defer c.mutex.Unlock()
|
||||||
c.buf.Reset()
|
c.buf.Reset()
|
||||||
|
@ -167,12 +170,12 @@ func (c *conn) writeEndRequest(reqId uint16, appStatus int, protocolStatus uint8
|
||||||
return c.writeRecord(typeEndRequest, reqId, b)
|
return c.writeRecord(typeEndRequest, reqId, b)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *conn) writePairs(recType uint8, reqId uint16, pairs map[string]string) error {
|
func (c *conn) writePairs(recType recType, reqId uint16, pairs map[string]string) error {
|
||||||
w := newWriter(c, recType, reqId)
|
w := newWriter(c, recType, reqId)
|
||||||
b := make([]byte, 8)
|
b := make([]byte, 8)
|
||||||
for k, v := range pairs {
|
for k, v := range pairs {
|
||||||
n := encodeSize(b, uint32(len(k)))
|
n := encodeSize(b, uint32(len(k)))
|
||||||
n += encodeSize(b[n:], uint32(len(k)))
|
n += encodeSize(b[n:], uint32(len(v)))
|
||||||
if _, err := w.Write(b[:n]); err != nil {
|
if _, err := w.Write(b[:n]); err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
@ -235,7 +238,7 @@ func (w *bufWriter) Close() error {
|
||||||
return w.closer.Close()
|
return w.closer.Close()
|
||||||
}
|
}
|
||||||
|
|
||||||
func newWriter(c *conn, recType uint8, reqId uint16) *bufWriter {
|
func newWriter(c *conn, recType recType, reqId uint16) *bufWriter {
|
||||||
s := &streamWriter{c: c, recType: recType, reqId: reqId}
|
s := &streamWriter{c: c, recType: recType, reqId: reqId}
|
||||||
w, _ := bufio.NewWriterSize(s, maxWrite)
|
w, _ := bufio.NewWriterSize(s, maxWrite)
|
||||||
return &bufWriter{s, w}
|
return &bufWriter{s, w}
|
||||||
|
@ -245,7 +248,7 @@ func newWriter(c *conn, recType uint8, reqId uint16) *bufWriter {
|
||||||
// It only writes maxWrite bytes at a time.
|
// It only writes maxWrite bytes at a time.
|
||||||
type streamWriter struct {
|
type streamWriter struct {
|
||||||
c *conn
|
c *conn
|
||||||
recType uint8
|
recType recType
|
||||||
reqId uint16
|
reqId uint16
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -6,6 +6,7 @@ package fcgi
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"bytes"
|
"bytes"
|
||||||
|
"errors"
|
||||||
"io"
|
"io"
|
||||||
"testing"
|
"testing"
|
||||||
)
|
)
|
||||||
|
@ -40,25 +41,25 @@ func TestSize(t *testing.T) {
|
||||||
|
|
||||||
var streamTests = []struct {
|
var streamTests = []struct {
|
||||||
desc string
|
desc string
|
||||||
recType uint8
|
recType recType
|
||||||
reqId uint16
|
reqId uint16
|
||||||
content []byte
|
content []byte
|
||||||
raw []byte
|
raw []byte
|
||||||
}{
|
}{
|
||||||
{"single record", typeStdout, 1, nil,
|
{"single record", typeStdout, 1, nil,
|
||||||
[]byte{1, typeStdout, 0, 1, 0, 0, 0, 0},
|
[]byte{1, byte(typeStdout), 0, 1, 0, 0, 0, 0},
|
||||||
},
|
},
|
||||||
// this data will have to be split into two records
|
// this data will have to be split into two records
|
||||||
{"two records", typeStdin, 300, make([]byte, 66000),
|
{"two records", typeStdin, 300, make([]byte, 66000),
|
||||||
bytes.Join([][]byte{
|
bytes.Join([][]byte{
|
||||||
// header for the first record
|
// header for the first record
|
||||||
{1, typeStdin, 0x01, 0x2C, 0xFF, 0xFF, 1, 0},
|
{1, byte(typeStdin), 0x01, 0x2C, 0xFF, 0xFF, 1, 0},
|
||||||
make([]byte, 65536),
|
make([]byte, 65536),
|
||||||
// header for the second
|
// header for the second
|
||||||
{1, typeStdin, 0x01, 0x2C, 0x01, 0xD1, 7, 0},
|
{1, byte(typeStdin), 0x01, 0x2C, 0x01, 0xD1, 7, 0},
|
||||||
make([]byte, 472),
|
make([]byte, 472),
|
||||||
// header for the empty record
|
// header for the empty record
|
||||||
{1, typeStdin, 0x01, 0x2C, 0, 0, 0, 0},
|
{1, byte(typeStdin), 0x01, 0x2C, 0, 0, 0, 0},
|
||||||
},
|
},
|
||||||
nil),
|
nil),
|
||||||
},
|
},
|
||||||
|
@ -111,3 +112,39 @@ outer:
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
type writeOnlyConn struct {
|
||||||
|
buf []byte
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *writeOnlyConn) Write(p []byte) (int, error) {
|
||||||
|
c.buf = append(c.buf, p...)
|
||||||
|
return len(p), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *writeOnlyConn) Read(p []byte) (int, error) {
|
||||||
|
return 0, errors.New("conn is write-only")
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *writeOnlyConn) Close() error {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestGetValues(t *testing.T) {
|
||||||
|
var rec record
|
||||||
|
rec.h.Type = typeGetValues
|
||||||
|
|
||||||
|
wc := new(writeOnlyConn)
|
||||||
|
c := newChild(wc, nil)
|
||||||
|
err := c.handleRecord(&rec)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("handleRecord: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
const want = "\x01\n\x00\x00\x00\x12\x06\x00" +
|
||||||
|
"\x0f\x01FCGI_MPXS_CONNS1" +
|
||||||
|
"\x00\x00\x00\x00\x00\x00\x01\n\x00\x00\x00\x00\x00\x00"
|
||||||
|
if got := string(wc.buf); got != want {
|
||||||
|
t.Errorf(" got: %q\nwant: %q\n", got, want)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
|
@ -22,13 +22,19 @@ import (
|
||||||
|
|
||||||
// A Dir implements http.FileSystem using the native file
|
// A Dir implements http.FileSystem using the native file
|
||||||
// system restricted to a specific directory tree.
|
// system restricted to a specific directory tree.
|
||||||
|
//
|
||||||
|
// An empty Dir is treated as ".".
|
||||||
type Dir string
|
type Dir string
|
||||||
|
|
||||||
func (d Dir) Open(name string) (File, error) {
|
func (d Dir) Open(name string) (File, error) {
|
||||||
if filepath.Separator != '/' && strings.IndexRune(name, filepath.Separator) >= 0 {
|
if filepath.Separator != '/' && strings.IndexRune(name, filepath.Separator) >= 0 {
|
||||||
return nil, errors.New("http: invalid character in file path")
|
return nil, errors.New("http: invalid character in file path")
|
||||||
}
|
}
|
||||||
f, err := os.Open(filepath.Join(string(d), filepath.FromSlash(path.Clean("/"+name))))
|
dir := string(d)
|
||||||
|
if dir == "" {
|
||||||
|
dir = "."
|
||||||
|
}
|
||||||
|
f, err := os.Open(filepath.Join(dir, filepath.FromSlash(path.Clean("/"+name))))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
|
@ -208,6 +208,20 @@ func TestDirJoin(t *testing.T) {
|
||||||
test(Dir("/etc/hosts"), "../")
|
test(Dir("/etc/hosts"), "../")
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestEmptyDirOpenCWD(t *testing.T) {
|
||||||
|
test := func(d Dir) {
|
||||||
|
name := "fs_test.go"
|
||||||
|
f, err := d.Open(name)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("open of %s: %v", name, err)
|
||||||
|
}
|
||||||
|
defer f.Close()
|
||||||
|
}
|
||||||
|
test(Dir(""))
|
||||||
|
test(Dir("."))
|
||||||
|
test(Dir("./"))
|
||||||
|
}
|
||||||
|
|
||||||
func TestServeFileContentType(t *testing.T) {
|
func TestServeFileContentType(t *testing.T) {
|
||||||
const ctype = "icecream/chocolate"
|
const ctype = "icecream/chocolate"
|
||||||
override := false
|
override := false
|
||||||
|
@ -247,6 +261,20 @@ func TestServeFileMimeType(t *testing.T) {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestServeFileFromCWD(t *testing.T) {
|
||||||
|
ts := httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) {
|
||||||
|
ServeFile(w, r, "fs_test.go")
|
||||||
|
}))
|
||||||
|
defer ts.Close()
|
||||||
|
r, err := Get(ts.URL)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
if r.StatusCode != 200 {
|
||||||
|
t.Fatalf("expected 200 OK, got %s", r.Status)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
func TestServeFileWithContentEncoding(t *testing.T) {
|
func TestServeFileWithContentEncoding(t *testing.T) {
|
||||||
ts := httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) {
|
ts := httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) {
|
||||||
w.Header().Set("Content-Encoding", "foo")
|
w.Header().Set("Content-Encoding", "foo")
|
||||||
|
|
|
@ -2,18 +2,126 @@
|
||||||
// Use of this source code is governed by a BSD-style
|
// Use of this source code is governed by a BSD-style
|
||||||
// license that can be found in the LICENSE file.
|
// license that can be found in the LICENSE file.
|
||||||
|
|
||||||
|
// The wire protocol for HTTP's "chunked" Transfer-Encoding.
|
||||||
|
|
||||||
|
// This code is a duplicate of ../chunked.go with these edits:
|
||||||
|
// s/newChunked/NewChunked/g
|
||||||
|
// s/package http/package httputil/
|
||||||
|
// Please make any changes in both files.
|
||||||
|
|
||||||
package httputil
|
package httputil
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"bufio"
|
"bufio"
|
||||||
|
"bytes"
|
||||||
|
"errors"
|
||||||
"io"
|
"io"
|
||||||
"net/http"
|
|
||||||
"strconv"
|
"strconv"
|
||||||
"strings"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
// NewChunkedWriter returns a new writer that translates writes into HTTP
|
const maxLineLength = 4096 // assumed <= bufio.defaultBufSize
|
||||||
// "chunked" format before writing them to w. Closing the returned writer
|
|
||||||
|
var ErrLineTooLong = errors.New("header line too long")
|
||||||
|
|
||||||
|
// NewChunkedReader returns a new chunkedReader that translates the data read from r
|
||||||
|
// out of HTTP "chunked" format before returning it.
|
||||||
|
// The chunkedReader returns io.EOF when the final 0-length chunk is read.
|
||||||
|
//
|
||||||
|
// NewChunkedReader is not needed by normal applications. The http package
|
||||||
|
// automatically decodes chunking when reading response bodies.
|
||||||
|
func NewChunkedReader(r io.Reader) io.Reader {
|
||||||
|
br, ok := r.(*bufio.Reader)
|
||||||
|
if !ok {
|
||||||
|
br = bufio.NewReader(r)
|
||||||
|
}
|
||||||
|
return &chunkedReader{r: br}
|
||||||
|
}
|
||||||
|
|
||||||
|
type chunkedReader struct {
|
||||||
|
r *bufio.Reader
|
||||||
|
n uint64 // unread bytes in chunk
|
||||||
|
err error
|
||||||
|
}
|
||||||
|
|
||||||
|
func (cr *chunkedReader) beginChunk() {
|
||||||
|
// chunk-size CRLF
|
||||||
|
var line string
|
||||||
|
line, cr.err = readLine(cr.r)
|
||||||
|
if cr.err != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
cr.n, cr.err = strconv.Btoui64(line, 16)
|
||||||
|
if cr.err != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if cr.n == 0 {
|
||||||
|
cr.err = io.EOF
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (cr *chunkedReader) Read(b []uint8) (n int, err error) {
|
||||||
|
if cr.err != nil {
|
||||||
|
return 0, cr.err
|
||||||
|
}
|
||||||
|
if cr.n == 0 {
|
||||||
|
cr.beginChunk()
|
||||||
|
if cr.err != nil {
|
||||||
|
return 0, cr.err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if uint64(len(b)) > cr.n {
|
||||||
|
b = b[0:cr.n]
|
||||||
|
}
|
||||||
|
n, cr.err = cr.r.Read(b)
|
||||||
|
cr.n -= uint64(n)
|
||||||
|
if cr.n == 0 && cr.err == nil {
|
||||||
|
// end of chunk (CRLF)
|
||||||
|
b := make([]byte, 2)
|
||||||
|
if _, cr.err = io.ReadFull(cr.r, b); cr.err == nil {
|
||||||
|
if b[0] != '\r' || b[1] != '\n' {
|
||||||
|
cr.err = errors.New("malformed chunked encoding")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return n, cr.err
|
||||||
|
}
|
||||||
|
|
||||||
|
// Read a line of bytes (up to \n) from b.
|
||||||
|
// Give up if the line exceeds maxLineLength.
|
||||||
|
// The returned bytes are a pointer into storage in
|
||||||
|
// the bufio, so they are only valid until the next bufio read.
|
||||||
|
func readLineBytes(b *bufio.Reader) (p []byte, err error) {
|
||||||
|
if p, err = b.ReadSlice('\n'); err != nil {
|
||||||
|
// We always know when EOF is coming.
|
||||||
|
// If the caller asked for a line, there should be a line.
|
||||||
|
if err == io.EOF {
|
||||||
|
err = io.ErrUnexpectedEOF
|
||||||
|
} else if err == bufio.ErrBufferFull {
|
||||||
|
err = ErrLineTooLong
|
||||||
|
}
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
if len(p) >= maxLineLength {
|
||||||
|
return nil, ErrLineTooLong
|
||||||
|
}
|
||||||
|
|
||||||
|
// Chop off trailing white space.
|
||||||
|
p = bytes.TrimRight(p, " \r\t\n")
|
||||||
|
|
||||||
|
return p, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// readLineBytes, but convert the bytes into a string.
|
||||||
|
func readLine(b *bufio.Reader) (s string, err error) {
|
||||||
|
p, e := readLineBytes(b)
|
||||||
|
if e != nil {
|
||||||
|
return "", e
|
||||||
|
}
|
||||||
|
return string(p), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewChunkedWriter returns a new chunkedWriter that translates writes into HTTP
|
||||||
|
// "chunked" format before writing them to w. Closing the returned chunkedWriter
|
||||||
// sends the final 0-length chunk that marks the end of the stream.
|
// sends the final 0-length chunk that marks the end of the stream.
|
||||||
//
|
//
|
||||||
// NewChunkedWriter is not needed by normal applications. The http
|
// NewChunkedWriter is not needed by normal applications. The http
|
||||||
|
@ -25,8 +133,8 @@ func NewChunkedWriter(w io.Writer) io.WriteCloser {
|
||||||
return &chunkedWriter{w}
|
return &chunkedWriter{w}
|
||||||
}
|
}
|
||||||
|
|
||||||
// Writing to ChunkedWriter translates to writing in HTTP chunked Transfer
|
// Writing to chunkedWriter translates to writing in HTTP chunked Transfer
|
||||||
// Encoding wire format to the underlying Wire writer.
|
// Encoding wire format to the underlying Wire chunkedWriter.
|
||||||
type chunkedWriter struct {
|
type chunkedWriter struct {
|
||||||
Wire io.Writer
|
Wire io.Writer
|
||||||
}
|
}
|
||||||
|
@ -62,23 +170,3 @@ func (cw *chunkedWriter) Close() error {
|
||||||
_, err := io.WriteString(cw.Wire, "0\r\n")
|
_, err := io.WriteString(cw.Wire, "0\r\n")
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
// NewChunkedReader returns a new reader that translates the data read from r
|
|
||||||
// out of HTTP "chunked" format before returning it.
|
|
||||||
// The reader returns io.EOF when the final 0-length chunk is read.
|
|
||||||
//
|
|
||||||
// NewChunkedReader is not needed by normal applications. The http package
|
|
||||||
// automatically decodes chunking when reading response bodies.
|
|
||||||
func NewChunkedReader(r io.Reader) io.Reader {
|
|
||||||
// This is a bit of a hack so we don't have to copy chunkedReader into
|
|
||||||
// httputil. It's a bit more complex than chunkedWriter, which is copied
|
|
||||||
// above.
|
|
||||||
req, err := http.ReadRequest(bufio.NewReader(io.MultiReader(
|
|
||||||
strings.NewReader("POST / HTTP/1.1\r\nTransfer-Encoding: chunked\r\n\r\n"),
|
|
||||||
r,
|
|
||||||
strings.NewReader("\r\n"))))
|
|
||||||
if err != nil {
|
|
||||||
panic("bad fake request: " + err.Error())
|
|
||||||
}
|
|
||||||
return req.Body
|
|
||||||
}
|
|
||||||
|
|
|
@ -2,6 +2,11 @@
|
||||||
// Use of this source code is governed by a BSD-style
|
// Use of this source code is governed by a BSD-style
|
||||||
// license that can be found in the LICENSE file.
|
// license that can be found in the LICENSE file.
|
||||||
|
|
||||||
|
// This code is a duplicate of ../chunked_test.go with these edits:
|
||||||
|
// s/newChunked/NewChunked/g
|
||||||
|
// s/package http/package httputil/
|
||||||
|
// Please make any changes in both files.
|
||||||
|
|
||||||
package httputil
|
package httputil
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
@ -27,7 +32,8 @@ func TestChunk(t *testing.T) {
|
||||||
r := NewChunkedReader(&b)
|
r := NewChunkedReader(&b)
|
||||||
data, err := ioutil.ReadAll(r)
|
data, err := ioutil.ReadAll(r)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("ReadAll from NewChunkedReader: %v", err)
|
t.Logf(`data: "%s"`, data)
|
||||||
|
t.Fatalf("ReadAll from reader: %v", err)
|
||||||
}
|
}
|
||||||
if g, e := string(data), chunk1+chunk2; g != e {
|
if g, e := string(data), chunk1+chunk2; g != e {
|
||||||
t.Errorf("chunk reader read %q; want %q", g, e)
|
t.Errorf("chunk reader read %q; want %q", g, e)
|
||||||
|
|
|
@ -22,6 +22,10 @@ var (
|
||||||
ErrPipeline = &http.ProtocolError{"pipeline error"}
|
ErrPipeline = &http.ProtocolError{"pipeline error"}
|
||||||
)
|
)
|
||||||
|
|
||||||
|
// This is an API usage error - the local side is closed.
|
||||||
|
// ErrPersistEOF (above) reports that the remote side is closed.
|
||||||
|
var errClosed = errors.New("i/o operation on closed connection")
|
||||||
|
|
||||||
// A ServerConn reads requests and sends responses over an underlying
|
// A ServerConn reads requests and sends responses over an underlying
|
||||||
// connection, until the HTTP keepalive logic commands an end. ServerConn
|
// connection, until the HTTP keepalive logic commands an end. ServerConn
|
||||||
// also allows hijacking the underlying connection by calling Hijack
|
// also allows hijacking the underlying connection by calling Hijack
|
||||||
|
@ -108,7 +112,7 @@ func (sc *ServerConn) Read() (req *http.Request, err error) {
|
||||||
}
|
}
|
||||||
if sc.r == nil { // connection closed by user in the meantime
|
if sc.r == nil { // connection closed by user in the meantime
|
||||||
defer sc.lk.Unlock()
|
defer sc.lk.Unlock()
|
||||||
return nil, os.EBADF
|
return nil, errClosed
|
||||||
}
|
}
|
||||||
r := sc.r
|
r := sc.r
|
||||||
lastbody := sc.lastbody
|
lastbody := sc.lastbody
|
||||||
|
@ -313,7 +317,7 @@ func (cc *ClientConn) Write(req *http.Request) (err error) {
|
||||||
}
|
}
|
||||||
if cc.c == nil { // connection closed by user in the meantime
|
if cc.c == nil { // connection closed by user in the meantime
|
||||||
defer cc.lk.Unlock()
|
defer cc.lk.Unlock()
|
||||||
return os.EBADF
|
return errClosed
|
||||||
}
|
}
|
||||||
c := cc.c
|
c := cc.c
|
||||||
if req.Close {
|
if req.Close {
|
||||||
|
@ -369,7 +373,7 @@ func (cc *ClientConn) Read(req *http.Request) (resp *http.Response, err error) {
|
||||||
}
|
}
|
||||||
if cc.r == nil { // connection closed by user in the meantime
|
if cc.r == nil { // connection closed by user in the meantime
|
||||||
defer cc.lk.Unlock()
|
defer cc.lk.Unlock()
|
||||||
return nil, os.EBADF
|
return nil, errClosed
|
||||||
}
|
}
|
||||||
r := cc.r
|
r := cc.r
|
||||||
lastbody := cc.lastbody
|
lastbody := cc.lastbody
|
||||||
|
|
|
@ -70,7 +70,6 @@ var reqTests = []reqTest{
|
||||||
Close: false,
|
Close: false,
|
||||||
ContentLength: 7,
|
ContentLength: 7,
|
||||||
Host: "www.techcrunch.com",
|
Host: "www.techcrunch.com",
|
||||||
Form: url.Values{},
|
|
||||||
},
|
},
|
||||||
|
|
||||||
"abcdef\n",
|
"abcdef\n",
|
||||||
|
@ -94,10 +93,10 @@ var reqTests = []reqTest{
|
||||||
Proto: "HTTP/1.1",
|
Proto: "HTTP/1.1",
|
||||||
ProtoMajor: 1,
|
ProtoMajor: 1,
|
||||||
ProtoMinor: 1,
|
ProtoMinor: 1,
|
||||||
|
Header: Header{},
|
||||||
Close: false,
|
Close: false,
|
||||||
ContentLength: 0,
|
ContentLength: 0,
|
||||||
Host: "foo.com",
|
Host: "foo.com",
|
||||||
Form: url.Values{},
|
|
||||||
},
|
},
|
||||||
|
|
||||||
noBody,
|
noBody,
|
||||||
|
@ -131,7 +130,6 @@ var reqTests = []reqTest{
|
||||||
Close: false,
|
Close: false,
|
||||||
ContentLength: 0,
|
ContentLength: 0,
|
||||||
Host: "test",
|
Host: "test",
|
||||||
Form: url.Values{},
|
|
||||||
},
|
},
|
||||||
|
|
||||||
noBody,
|
noBody,
|
||||||
|
@ -180,9 +178,9 @@ var reqTests = []reqTest{
|
||||||
Proto: "HTTP/1.1",
|
Proto: "HTTP/1.1",
|
||||||
ProtoMajor: 1,
|
ProtoMajor: 1,
|
||||||
ProtoMinor: 1,
|
ProtoMinor: 1,
|
||||||
|
Header: Header{},
|
||||||
ContentLength: -1,
|
ContentLength: -1,
|
||||||
Host: "foo.com",
|
Host: "foo.com",
|
||||||
Form: url.Values{},
|
|
||||||
},
|
},
|
||||||
|
|
||||||
"foobar",
|
"foobar",
|
||||||
|
|
|
@ -19,12 +19,10 @@ import (
|
||||||
"mime/multipart"
|
"mime/multipart"
|
||||||
"net/textproto"
|
"net/textproto"
|
||||||
"net/url"
|
"net/url"
|
||||||
"strconv"
|
|
||||||
"strings"
|
"strings"
|
||||||
)
|
)
|
||||||
|
|
||||||
const (
|
const (
|
||||||
maxLineLength = 4096 // assumed <= bufio.defaultBufSize
|
|
||||||
maxValueLength = 4096
|
maxValueLength = 4096
|
||||||
maxHeaderLines = 1024
|
maxHeaderLines = 1024
|
||||||
chunkSize = 4 << 10 // 4 KB chunks
|
chunkSize = 4 << 10 // 4 KB chunks
|
||||||
|
@ -43,7 +41,6 @@ type ProtocolError struct {
|
||||||
func (err *ProtocolError) Error() string { return err.ErrorString }
|
func (err *ProtocolError) Error() string { return err.ErrorString }
|
||||||
|
|
||||||
var (
|
var (
|
||||||
ErrLineTooLong = &ProtocolError{"header line too long"}
|
|
||||||
ErrHeaderTooLong = &ProtocolError{"header too long"}
|
ErrHeaderTooLong = &ProtocolError{"header too long"}
|
||||||
ErrShortBody = &ProtocolError{"entity body too short"}
|
ErrShortBody = &ProtocolError{"entity body too short"}
|
||||||
ErrNotSupported = &ProtocolError{"feature not supported"}
|
ErrNotSupported = &ProtocolError{"feature not supported"}
|
||||||
|
@ -375,44 +372,6 @@ func (req *Request) write(w io.Writer, usingProxy bool, extraHeaders Header) err
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// Read a line of bytes (up to \n) from b.
|
|
||||||
// Give up if the line exceeds maxLineLength.
|
|
||||||
// The returned bytes are a pointer into storage in
|
|
||||||
// the bufio, so they are only valid until the next bufio read.
|
|
||||||
func readLineBytes(b *bufio.Reader) (p []byte, err error) {
|
|
||||||
if p, err = b.ReadSlice('\n'); err != nil {
|
|
||||||
// We always know when EOF is coming.
|
|
||||||
// If the caller asked for a line, there should be a line.
|
|
||||||
if err == io.EOF {
|
|
||||||
err = io.ErrUnexpectedEOF
|
|
||||||
} else if err == bufio.ErrBufferFull {
|
|
||||||
err = ErrLineTooLong
|
|
||||||
}
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
if len(p) >= maxLineLength {
|
|
||||||
return nil, ErrLineTooLong
|
|
||||||
}
|
|
||||||
|
|
||||||
// Chop off trailing white space.
|
|
||||||
var i int
|
|
||||||
for i = len(p); i > 0; i-- {
|
|
||||||
if c := p[i-1]; c != ' ' && c != '\r' && c != '\t' && c != '\n' {
|
|
||||||
break
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return p[0:i], nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// readLineBytes, but convert the bytes into a string.
|
|
||||||
func readLine(b *bufio.Reader) (s string, err error) {
|
|
||||||
p, e := readLineBytes(b)
|
|
||||||
if e != nil {
|
|
||||||
return "", e
|
|
||||||
}
|
|
||||||
return string(p), nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// Convert decimal at s[i:len(s)] to integer,
|
// Convert decimal at s[i:len(s)] to integer,
|
||||||
// returning value, string position where the digits stopped,
|
// returning value, string position where the digits stopped,
|
||||||
// and whether there was a valid number (digits, not too big).
|
// and whether there was a valid number (digits, not too big).
|
||||||
|
@ -448,55 +407,6 @@ func ParseHTTPVersion(vers string) (major, minor int, ok bool) {
|
||||||
return major, minor, true
|
return major, minor, true
|
||||||
}
|
}
|
||||||
|
|
||||||
type chunkedReader struct {
|
|
||||||
r *bufio.Reader
|
|
||||||
n uint64 // unread bytes in chunk
|
|
||||||
err error
|
|
||||||
}
|
|
||||||
|
|
||||||
func (cr *chunkedReader) beginChunk() {
|
|
||||||
// chunk-size CRLF
|
|
||||||
var line string
|
|
||||||
line, cr.err = readLine(cr.r)
|
|
||||||
if cr.err != nil {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
cr.n, cr.err = strconv.Btoui64(line, 16)
|
|
||||||
if cr.err != nil {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
if cr.n == 0 {
|
|
||||||
cr.err = io.EOF
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func (cr *chunkedReader) Read(b []uint8) (n int, err error) {
|
|
||||||
if cr.err != nil {
|
|
||||||
return 0, cr.err
|
|
||||||
}
|
|
||||||
if cr.n == 0 {
|
|
||||||
cr.beginChunk()
|
|
||||||
if cr.err != nil {
|
|
||||||
return 0, cr.err
|
|
||||||
}
|
|
||||||
}
|
|
||||||
if uint64(len(b)) > cr.n {
|
|
||||||
b = b[0:cr.n]
|
|
||||||
}
|
|
||||||
n, cr.err = cr.r.Read(b)
|
|
||||||
cr.n -= uint64(n)
|
|
||||||
if cr.n == 0 && cr.err == nil {
|
|
||||||
// end of chunk (CRLF)
|
|
||||||
b := make([]byte, 2)
|
|
||||||
if _, cr.err = io.ReadFull(cr.r, b); cr.err == nil {
|
|
||||||
if b[0] != '\r' || b[1] != '\n' {
|
|
||||||
cr.err = errors.New("malformed chunked encoding")
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return n, cr.err
|
|
||||||
}
|
|
||||||
|
|
||||||
// NewRequest returns a new Request given a method, URL, and optional body.
|
// NewRequest returns a new Request given a method, URL, and optional body.
|
||||||
func NewRequest(method, urlStr string, body io.Reader) (*Request, error) {
|
func NewRequest(method, urlStr string, body io.Reader) (*Request, error) {
|
||||||
u, err := url.Parse(urlStr)
|
u, err := url.Parse(urlStr)
|
||||||
|
|
|
@ -65,6 +65,7 @@ var respTests = []respTest{
|
||||||
Proto: "HTTP/1.1",
|
Proto: "HTTP/1.1",
|
||||||
ProtoMajor: 1,
|
ProtoMajor: 1,
|
||||||
ProtoMinor: 1,
|
ProtoMinor: 1,
|
||||||
|
Header: Header{},
|
||||||
Request: dummyReq("GET"),
|
Request: dummyReq("GET"),
|
||||||
Close: true,
|
Close: true,
|
||||||
ContentLength: -1,
|
ContentLength: -1,
|
||||||
|
@ -85,6 +86,7 @@ var respTests = []respTest{
|
||||||
Proto: "HTTP/1.1",
|
Proto: "HTTP/1.1",
|
||||||
ProtoMajor: 1,
|
ProtoMajor: 1,
|
||||||
ProtoMinor: 1,
|
ProtoMinor: 1,
|
||||||
|
Header: Header{},
|
||||||
Request: dummyReq("GET"),
|
Request: dummyReq("GET"),
|
||||||
Close: false,
|
Close: false,
|
||||||
ContentLength: 0,
|
ContentLength: 0,
|
||||||
|
@ -315,7 +317,7 @@ func TestReadResponseCloseInMiddle(t *testing.T) {
|
||||||
}
|
}
|
||||||
var wr io.Writer = &buf
|
var wr io.Writer = &buf
|
||||||
if test.chunked {
|
if test.chunked {
|
||||||
wr = &chunkedWriter{wr}
|
wr = newChunkedWriter(wr)
|
||||||
}
|
}
|
||||||
if test.compressed {
|
if test.compressed {
|
||||||
buf.WriteString("Content-Encoding: gzip\r\n")
|
buf.WriteString("Content-Encoding: gzip\r\n")
|
||||||
|
|
|
@ -1077,6 +1077,31 @@ func TestClientWriteShutdown(t *testing.T) {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Tests that chunked server responses that write 1 byte at a time are
|
||||||
|
// buffered before chunk headers are added, not after chunk headers.
|
||||||
|
func TestServerBufferedChunking(t *testing.T) {
|
||||||
|
if true {
|
||||||
|
t.Logf("Skipping known broken test; see Issue 2357")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
conn := new(testConn)
|
||||||
|
conn.readBuf.Write([]byte("GET / HTTP/1.1\r\n\r\n"))
|
||||||
|
done := make(chan bool)
|
||||||
|
ls := &oneConnListener{conn}
|
||||||
|
go Serve(ls, HandlerFunc(func(rw ResponseWriter, req *Request) {
|
||||||
|
defer close(done)
|
||||||
|
rw.Header().Set("Content-Type", "text/plain") // prevent sniffing, which buffers
|
||||||
|
rw.Write([]byte{'x'})
|
||||||
|
rw.Write([]byte{'y'})
|
||||||
|
rw.Write([]byte{'z'})
|
||||||
|
}))
|
||||||
|
<-done
|
||||||
|
if !bytes.HasSuffix(conn.writeBuf.Bytes(), []byte("\r\n\r\n3\r\nxyz\r\n0\r\n\r\n")) {
|
||||||
|
t.Errorf("response didn't end with a single 3 byte 'xyz' chunk; got:\n%q",
|
||||||
|
conn.writeBuf.Bytes())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
// goTimeout runs f, failing t if f takes more than ns to complete.
|
// goTimeout runs f, failing t if f takes more than ns to complete.
|
||||||
func goTimeout(t *testing.T, ns int64, f func()) {
|
func goTimeout(t *testing.T, ns int64, f func()) {
|
||||||
ch := make(chan bool, 2)
|
ch := make(chan bool, 2)
|
||||||
|
@ -1120,7 +1145,7 @@ func TestAcceptMaxFds(t *testing.T) {
|
||||||
ln := &errorListener{[]error{
|
ln := &errorListener{[]error{
|
||||||
&net.OpError{
|
&net.OpError{
|
||||||
Op: "accept",
|
Op: "accept",
|
||||||
Err: os.Errno(syscall.EMFILE),
|
Err: syscall.EMFILE,
|
||||||
}}}
|
}}}
|
||||||
err := Serve(ln, HandlerFunc(HandlerFunc(func(ResponseWriter, *Request) {})))
|
err := Serve(ln, HandlerFunc(HandlerFunc(func(ResponseWriter, *Request) {})))
|
||||||
if err != io.EOF {
|
if err != io.EOF {
|
||||||
|
|
|
@ -149,11 +149,13 @@ type writerOnly struct {
|
||||||
}
|
}
|
||||||
|
|
||||||
func (w *response) ReadFrom(src io.Reader) (n int64, err error) {
|
func (w *response) ReadFrom(src io.Reader) (n int64, err error) {
|
||||||
// Flush before checking w.chunking, as Flush will call
|
// Call WriteHeader before checking w.chunking if it hasn't
|
||||||
// WriteHeader if it hasn't been called yet, and WriteHeader
|
// been called yet, since WriteHeader is what sets w.chunking.
|
||||||
// is what sets w.chunking.
|
if !w.wroteHeader {
|
||||||
w.Flush()
|
w.WriteHeader(StatusOK)
|
||||||
|
}
|
||||||
if !w.chunking && w.bodyAllowed() && !w.needSniff {
|
if !w.chunking && w.bodyAllowed() && !w.needSniff {
|
||||||
|
w.Flush()
|
||||||
if rf, ok := w.conn.rwc.(io.ReaderFrom); ok {
|
if rf, ok := w.conn.rwc.(io.ReaderFrom); ok {
|
||||||
n, err = rf.ReadFrom(src)
|
n, err = rf.ReadFrom(src)
|
||||||
w.written += n
|
w.written += n
|
||||||
|
|
|
@ -6,6 +6,7 @@ package http_test
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"bytes"
|
"bytes"
|
||||||
|
"io"
|
||||||
"io/ioutil"
|
"io/ioutil"
|
||||||
"log"
|
"log"
|
||||||
. "net/http"
|
. "net/http"
|
||||||
|
@ -79,3 +80,35 @@ func TestServerContentType(t *testing.T) {
|
||||||
resp.Body.Close()
|
resp.Body.Close()
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestContentTypeWithCopy(t *testing.T) {
|
||||||
|
const (
|
||||||
|
input = "\n<html>\n\t<head>\n"
|
||||||
|
expected = "text/html; charset=utf-8"
|
||||||
|
)
|
||||||
|
|
||||||
|
ts := httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) {
|
||||||
|
// Use io.Copy from a bytes.Buffer to trigger ReadFrom.
|
||||||
|
buf := bytes.NewBuffer([]byte(input))
|
||||||
|
n, err := io.Copy(w, buf)
|
||||||
|
if int(n) != len(input) || err != nil {
|
||||||
|
t.Errorf("io.Copy(w, %q) = %v, %v want %d, nil", input, n, err, len(input))
|
||||||
|
}
|
||||||
|
}))
|
||||||
|
defer ts.Close()
|
||||||
|
|
||||||
|
resp, err := Get(ts.URL)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Get: %v", err)
|
||||||
|
}
|
||||||
|
if ct := resp.Header.Get("Content-Type"); ct != expected {
|
||||||
|
t.Errorf("Content-Type = %q, want %q", ct, expected)
|
||||||
|
}
|
||||||
|
data, err := ioutil.ReadAll(resp.Body)
|
||||||
|
if err != nil {
|
||||||
|
t.Errorf("reading body: %v", err)
|
||||||
|
} else if !bytes.Equal(data, []byte(input)) {
|
||||||
|
t.Errorf("data is %q, want %q", data, input)
|
||||||
|
}
|
||||||
|
resp.Body.Close()
|
||||||
|
}
|
||||||
|
|
|
@ -537,7 +537,9 @@ func (b *body) Read(p []byte) (n int, err error) {
|
||||||
|
|
||||||
// Read the final trailer once we hit EOF.
|
// Read the final trailer once we hit EOF.
|
||||||
if err == io.EOF && b.hdr != nil {
|
if err == io.EOF && b.hdr != nil {
|
||||||
err = b.readTrailer()
|
if e := b.readTrailer(); e != nil {
|
||||||
|
err = e
|
||||||
|
}
|
||||||
b.hdr = nil
|
b.hdr = nil
|
||||||
}
|
}
|
||||||
return n, err
|
return n, err
|
||||||
|
|
Some files were not shown because too many files have changed in this diff Show More
Loading…
Reference in New Issue