libgo: Update to weekly.2012-01-15.

From-SVN: r183539
Ian Lance Taylor 2012-01-25 20:56:26 +00:00
parent 3be18e47c3
commit df1304ee03
192 changed files with 7870 additions and 3929 deletions

View File

@@ -1,4 +1,4 @@
-4a8268927758
+354b17404643
The first line of this file holds the Mercurial revision number of the
last merge done from the master library sources.

View File

@@ -188,7 +188,7 @@ toolexeclibgocryptoopenpgpdir = $(toolexeclibgocryptodir)/openpgp
toolexeclibgocryptoopenpgp_DATA = \
crypto/openpgp/armor.gox \
crypto/openpgp/elgamal.gox \
-crypto/openpgp/error.gox \
+crypto/openpgp/errors.gox \
crypto/openpgp/packet.gox \
crypto/openpgp/s2k.gox

@@ -235,6 +235,7 @@ toolexeclibgoexp_DATA = \
exp/ebnf.gox \
$(exp_inotify_gox) \
exp/norm.gox \
+exp/proxy.gox \
exp/spdy.gox \
exp/sql.gox \
exp/ssh.gox \

@@ -669,17 +670,25 @@ endif # !LIBGO_IS_RTEMS
if LIBGO_IS_LINUX
go_net_cgo_file = go/net/cgo_linux.go
go_net_sock_file = go/net/sock_linux.go
+go_net_sockopt_file = go/net/sockopt_linux.go
+go_net_sockoptip_file = go/net/sockoptip_linux.go
else
if LIBGO_IS_IRIX
go_net_cgo_file = go/net/cgo_linux.go
go_net_sock_file = go/net/sock_linux.go
+go_net_sockopt_file = go/net/sockopt_linux.go
+go_net_sockoptip_file = go/net/sockoptip_linux.go
else
if LIBGO_IS_SOLARIS
go_net_cgo_file = go/net/cgo_linux.go
go_net_sock_file = go/net/sock_linux.go
+go_net_sockopt_file = go/net/sockopt_linux.go
+go_net_sockoptip_file = go/net/sockoptip_linux.go
else
go_net_cgo_file = go/net/cgo_bsd.go
go_net_sock_file = go/net/sock_bsd.go
+go_net_sockopt_file = go/net/sockopt_bsd.go
+go_net_sockoptip_file = go/net/sockoptip_bsd.go
endif
endif
endif

@@ -728,6 +737,10 @@ go_net_files = \
$(go_net_sendfile_file) \
go/net/sock.go \
$(go_net_sock_file) \
+go/net/sockopt.go \
+$(go_net_sockopt_file) \
+go/net/sockoptip.go \
+$(go_net_sockoptip_file) \
go/net/tcpsock.go \
go/net/tcpsock_posix.go \
go/net/udpsock.go \

@@ -890,8 +903,7 @@ go_syslog_c_files = \
go_testing_files = \
go/testing/benchmark.go \
go/testing/example.go \
-go/testing/testing.go \
-go/testing/wrapper.go
+go/testing/testing.go
go_time_files = \
go/time/format.go \

@@ -1061,8 +1073,8 @@ go_crypto_openpgp_armor_files = \
go/crypto/openpgp/armor/encode.go
go_crypto_openpgp_elgamal_files = \
go/crypto/openpgp/elgamal/elgamal.go
-go_crypto_openpgp_error_files = \
-go/crypto/openpgp/error/error.go
+go_crypto_openpgp_errors_files = \
+go/crypto/openpgp/errors/errors.go
go_crypto_openpgp_packet_files = \
go/crypto/openpgp/packet/compressed.go \
go/crypto/openpgp/packet/encrypted_key.go \

@@ -1142,6 +1154,7 @@ go_encoding_pem_files = \
go_encoding_xml_files = \
go/encoding/xml/marshal.go \
go/encoding/xml/read.go \
+go/encoding/xml/typeinfo.go \
go/encoding/xml/xml.go
go_exp_ebnf_files = \

@@ -1157,6 +1170,11 @@ go_exp_norm_files = \
go/exp/norm/readwriter.go \
go/exp/norm/tables.go \
go/exp/norm/trie.go
+go_exp_proxy_files = \
+go/exp/proxy/direct.go \
+go/exp/proxy/per_host.go \
+go/exp/proxy/proxy.go \
+go/exp/proxy/socks5.go
go_exp_spdy_files = \
go/exp/spdy/read.go \
go/exp/spdy/types.go \

@@ -1173,7 +1191,7 @@ go_exp_ssh_files = \
go/exp/ssh/doc.go \
go/exp/ssh/messages.go \
go/exp/ssh/server.go \
-go/exp/ssh/server_shell.go \
+go/exp/ssh/server_terminal.go \
go/exp/ssh/session.go \
go/exp/ssh/tcpip.go \
go/exp/ssh/transport.go

@@ -1210,7 +1228,8 @@ go_go_doc_files = \
go/go/doc/doc.go \
go/go/doc/example.go \
go/go/doc/exports.go \
-go/go/doc/filter.go
+go/go/doc/filter.go \
+go/go/doc/reader.go
go_go_parser_files = \
go/go/parser/interface.go \
go/go/parser/parser.go

@@ -1461,8 +1480,15 @@ endif
# Define ForkExec and Exec.
if LIBGO_IS_RTEMS
syscall_exec_file = go/syscall/exec_stubs.go
+syscall_exec_os_file =
+else
+if LIBGO_IS_LINUX
+syscall_exec_file = go/syscall/exec_unix.go
+syscall_exec_os_file = go/syscall/exec_linux.go
else
syscall_exec_file = go/syscall/exec_unix.go
+syscall_exec_os_file = go/syscall/exec_bsd.go
+endif
endif
# Define Wait4.

@@ -1573,6 +1599,7 @@ go_base_syscall_files = \
go/syscall/syscall.go \
$(syscall_syscall_file) \
$(syscall_exec_file) \
+$(syscall_exec_os_file) \
$(syscall_wait_file) \
$(syscall_sleep_file) \
$(syscall_errstr_file) \

@@ -1720,7 +1747,7 @@ libgo_go_objs = \
crypto/xtea.lo \
crypto/openpgp/armor.lo \
crypto/openpgp/elgamal.lo \
-crypto/openpgp/error.lo \
+crypto/openpgp/errors.lo \
crypto/openpgp/packet.lo \
crypto/openpgp/s2k.lo \
crypto/x509/pkix.lo \

@@ -1743,6 +1770,7 @@ libgo_go_objs = \
encoding/xml.lo \
exp/ebnf.lo \
exp/norm.lo \
+exp/proxy.lo \
exp/spdy.lo \
exp/sql.lo \
exp/ssh.lo \

@@ -2578,15 +2606,15 @@ crypto/openpgp/elgamal/check: $(CHECK_DEPS)
@$(CHECK)
.PHONY: crypto/openpgp/elgamal/check
-@go_include@ crypto/openpgp/error.lo.dep
+@go_include@ crypto/openpgp/errors.lo.dep
-crypto/openpgp/error.lo.dep: $(go_crypto_openpgp_error_files)
+crypto/openpgp/errors.lo.dep: $(go_crypto_openpgp_errors_files)
$(BUILDDEPS)
-crypto/openpgp/error.lo: $(go_crypto_openpgp_error_files)
+crypto/openpgp/errors.lo: $(go_crypto_openpgp_errors_files)
$(BUILDPACKAGE)
-crypto/openpgp/error/check: $(CHECK_DEPS)
+crypto/openpgp/errors/check: $(CHECK_DEPS)
-@$(MKDIR_P) crypto/openpgp/error
+@$(MKDIR_P) crypto/openpgp/errors
@$(CHECK)
-.PHONY: crypto/openpgp/error/check
+.PHONY: crypto/openpgp/errors/check
@go_include@ crypto/openpgp/packet.lo.dep
crypto/openpgp/packet.lo.dep: $(go_crypto_openpgp_packet_files)

@@ -2808,6 +2836,16 @@ exp/norm/check: $(CHECK_DEPS)
@$(CHECK)
.PHONY: exp/norm/check
+@go_include@ exp/proxy.lo.dep
+exp/proxy.lo.dep: $(go_exp_proxy_files)
+$(BUILDDEPS)
+exp/proxy.lo: $(go_exp_proxy_files)
+$(BUILDPACKAGE)
+exp/proxy/check: $(CHECK_DEPS)
+@$(MKDIR_P) exp/proxy
+@$(CHECK)
+.PHONY: exp/proxy/check
@go_include@ exp/spdy.lo.dep
exp/spdy.lo.dep: $(go_exp_spdy_files)
$(BUILDDEPS)

@@ -3622,7 +3660,7 @@ crypto/openpgp/armor.gox: crypto/openpgp/armor.lo
$(BUILDGOX)
crypto/openpgp/elgamal.gox: crypto/openpgp/elgamal.lo
$(BUILDGOX)
-crypto/openpgp/error.gox: crypto/openpgp/error.lo
+crypto/openpgp/errors.gox: crypto/openpgp/errors.lo
$(BUILDGOX)
crypto/openpgp/packet.gox: crypto/openpgp/packet.lo
$(BUILDGOX)

@@ -3674,6 +3712,8 @@ exp/inotify.gox: exp/inotify.lo
$(BUILDGOX)
exp/norm.gox: exp/norm.lo
$(BUILDGOX)
+exp/proxy.gox: exp/proxy.lo
+$(BUILDGOX)
exp/spdy.gox: exp/spdy.lo
$(BUILDGOX)
exp/sql.gox: exp/sql.lo

@@ -3920,6 +3960,7 @@ TEST_PACKAGES = \
exp/ebnf/check \
$(exp_inotify_check) \
exp/norm/check \
+exp/proxy/check \
exp/spdy/check \
exp/sql/check \
exp/ssh/check \

View File

@@ -153,33 +153,34 @@ am__DEPENDENCIES_2 = bufio/bufio.lo bytes/bytes.lo bytes/index.lo \
crypto/sha256.lo crypto/sha512.lo crypto/subtle.lo \
crypto/tls.lo crypto/twofish.lo crypto/x509.lo crypto/xtea.lo \
crypto/openpgp/armor.lo crypto/openpgp/elgamal.lo \
-crypto/openpgp/error.lo crypto/openpgp/packet.lo \
+crypto/openpgp/errors.lo crypto/openpgp/packet.lo \
crypto/openpgp/s2k.lo crypto/x509/pkix.lo debug/dwarf.lo \
debug/elf.lo debug/gosym.lo debug/macho.lo debug/pe.lo \
encoding/ascii85.lo encoding/asn1.lo encoding/base32.lo \
encoding/base64.lo encoding/binary.lo encoding/csv.lo \
encoding/git85.lo encoding/gob.lo encoding/hex.lo \
encoding/json.lo encoding/pem.lo encoding/xml.lo exp/ebnf.lo \
-exp/norm.lo exp/spdy.lo exp/sql.lo exp/ssh.lo exp/terminal.lo \
-exp/types.lo exp/sql/driver.lo html/template.lo go/ast.lo \
-go/build.lo go/doc.lo go/parser.lo go/printer.lo go/scanner.lo \
-go/token.lo hash/adler32.lo hash/crc32.lo hash/crc64.lo \
-hash/fnv.lo net/http/cgi.lo net/http/fcgi.lo \
-net/http/httptest.lo net/http/httputil.lo net/http/pprof.lo \
-image/bmp.lo image/color.lo image/draw.lo image/gif.lo \
-image/jpeg.lo image/png.lo image/tiff.lo index/suffixarray.lo \
-io/ioutil.lo log/syslog.lo log/syslog/syslog_c.lo math/big.lo \
-math/cmplx.lo math/rand.lo mime/mime.lo mime/multipart.lo \
-net/dict.lo net/http.lo net/mail.lo net/rpc.lo net/smtp.lo \
-net/textproto.lo net/url.lo old/netchan.lo old/regexp.lo \
-old/template.lo $(am__DEPENDENCIES_1) os/user.lo os/signal.lo \
-path/filepath.lo regexp/syntax.lo net/rpc/jsonrpc.lo \
-runtime/debug.lo runtime/pprof.lo sync/atomic.lo \
-sync/atomic_c.lo syscall/syscall.lo syscall/errno.lo \
-syscall/wait.lo text/scanner.lo text/tabwriter.lo \
-text/template.lo text/template/parse.lo testing/testing.lo \
-testing/iotest.lo testing/quick.lo testing/script.lo \
-unicode/utf16.lo unicode/utf8.lo
+exp/norm.lo exp/proxy.lo exp/spdy.lo exp/sql.lo exp/ssh.lo \
+exp/terminal.lo exp/types.lo exp/sql/driver.lo \
+html/template.lo go/ast.lo go/build.lo go/doc.lo go/parser.lo \
+go/printer.lo go/scanner.lo go/token.lo hash/adler32.lo \
+hash/crc32.lo hash/crc64.lo hash/fnv.lo net/http/cgi.lo \
+net/http/fcgi.lo net/http/httptest.lo net/http/httputil.lo \
+net/http/pprof.lo image/bmp.lo image/color.lo image/draw.lo \
+image/gif.lo image/jpeg.lo image/png.lo image/tiff.lo \
+index/suffixarray.lo io/ioutil.lo log/syslog.lo \
+log/syslog/syslog_c.lo math/big.lo math/cmplx.lo math/rand.lo \
+mime/mime.lo mime/multipart.lo net/dict.lo net/http.lo \
+net/mail.lo net/rpc.lo net/smtp.lo net/textproto.lo net/url.lo \
+old/netchan.lo old/regexp.lo old/template.lo \
+$(am__DEPENDENCIES_1) os/user.lo os/signal.lo path/filepath.lo \
+regexp/syntax.lo net/rpc/jsonrpc.lo runtime/debug.lo \
+runtime/pprof.lo sync/atomic.lo sync/atomic_c.lo \
+syscall/syscall.lo syscall/errno.lo syscall/wait.lo \
+text/scanner.lo text/tabwriter.lo text/template.lo \
+text/template/parse.lo testing/testing.lo testing/iotest.lo \
+testing/quick.lo testing/script.lo unicode/utf16.lo \
+unicode/utf8.lo
libgo_la_DEPENDENCIES = $(am__DEPENDENCIES_2) $(am__DEPENDENCIES_1) \
$(am__DEPENDENCIES_1) $(am__DEPENDENCIES_1) \
$(am__DEPENDENCIES_1)

@@ -652,7 +653,7 @@ toolexeclibgocryptoopenpgpdir = $(toolexeclibgocryptodir)/openpgp
toolexeclibgocryptoopenpgp_DATA = \
crypto/openpgp/armor.gox \
crypto/openpgp/elgamal.gox \
-crypto/openpgp/error.gox \
+crypto/openpgp/errors.gox \
crypto/openpgp/packet.gox \
crypto/openpgp/s2k.gox

@@ -692,6 +693,7 @@ toolexeclibgoexp_DATA = \
exp/ebnf.gox \
$(exp_inotify_gox) \
exp/norm.gox \
+exp/proxy.gox \
exp/spdy.gox \
exp/sql.gox \
exp/ssh.gox \

@@ -1049,6 +1051,14 @@ go_mime_files = \
@LIBGO_IS_IRIX_FALSE@@LIBGO_IS_LINUX_FALSE@@LIBGO_IS_SOLARIS_TRUE@go_net_sock_file = go/net/sock_linux.go
@LIBGO_IS_IRIX_TRUE@@LIBGO_IS_LINUX_FALSE@go_net_sock_file = go/net/sock_linux.go
@LIBGO_IS_LINUX_TRUE@go_net_sock_file = go/net/sock_linux.go
+@LIBGO_IS_IRIX_FALSE@@LIBGO_IS_LINUX_FALSE@@LIBGO_IS_SOLARIS_FALSE@go_net_sockopt_file = go/net/sockopt_bsd.go
+@LIBGO_IS_IRIX_FALSE@@LIBGO_IS_LINUX_FALSE@@LIBGO_IS_SOLARIS_TRUE@go_net_sockopt_file = go/net/sockopt_linux.go
+@LIBGO_IS_IRIX_TRUE@@LIBGO_IS_LINUX_FALSE@go_net_sockopt_file = go/net/sockopt_linux.go
+@LIBGO_IS_LINUX_TRUE@go_net_sockopt_file = go/net/sockopt_linux.go
+@LIBGO_IS_IRIX_FALSE@@LIBGO_IS_LINUX_FALSE@@LIBGO_IS_SOLARIS_FALSE@go_net_sockoptip_file = go/net/sockoptip_bsd.go
+@LIBGO_IS_IRIX_FALSE@@LIBGO_IS_LINUX_FALSE@@LIBGO_IS_SOLARIS_TRUE@go_net_sockoptip_file = go/net/sockoptip_linux.go
+@LIBGO_IS_IRIX_TRUE@@LIBGO_IS_LINUX_FALSE@go_net_sockoptip_file = go/net/sockoptip_linux.go
+@LIBGO_IS_LINUX_TRUE@go_net_sockoptip_file = go/net/sockoptip_linux.go
@LIBGO_IS_LINUX_FALSE@go_net_sendfile_file = go/net/sendfile_stub.go
@LIBGO_IS_LINUX_TRUE@go_net_sendfile_file = go/net/sendfile_linux.go
@LIBGO_IS_LINUX_FALSE@@LIBGO_IS_NETBSD_FALSE@go_net_interface_file = go/net/interface_stub.go

@@ -1082,6 +1092,10 @@ go_net_files = \
$(go_net_sendfile_file) \
go/net/sock.go \
$(go_net_sock_file) \
+go/net/sockopt.go \
+$(go_net_sockopt_file) \
+go/net/sockoptip.go \
+$(go_net_sockoptip_file) \
go/net/tcpsock.go \
go/net/tcpsock_posix.go \
go/net/udpsock.go \

@@ -1197,8 +1211,7 @@ go_syslog_c_files = \
go_testing_files = \
go/testing/benchmark.go \
go/testing/example.go \
-go/testing/testing.go \
-go/testing/wrapper.go
+go/testing/testing.go
go_time_files = \
go/time/format.go \

@@ -1394,8 +1407,8 @@ go_crypto_openpgp_armor_files = \
go_crypto_openpgp_elgamal_files = \
go/crypto/openpgp/elgamal/elgamal.go
-go_crypto_openpgp_error_files = \
-go/crypto/openpgp/error/error.go
+go_crypto_openpgp_errors_files = \
+go/crypto/openpgp/errors/errors.go
go_crypto_openpgp_packet_files = \
go/crypto/openpgp/packet/compressed.go \

@@ -1492,6 +1505,7 @@ go_encoding_pem_files = \
go_encoding_xml_files = \
go/encoding/xml/marshal.go \
go/encoding/xml/read.go \
+go/encoding/xml/typeinfo.go \
go/encoding/xml/xml.go
go_exp_ebnf_files = \

@@ -1510,6 +1524,12 @@ go_exp_norm_files = \
go/exp/norm/tables.go \
go/exp/norm/trie.go
+go_exp_proxy_files = \
+go/exp/proxy/direct.go \
+go/exp/proxy/per_host.go \
+go/exp/proxy/proxy.go \
+go/exp/proxy/socks5.go
go_exp_spdy_files = \
go/exp/spdy/read.go \
go/exp/spdy/types.go \

@@ -1528,7 +1548,7 @@ go_exp_ssh_files = \
go/exp/ssh/doc.go \
go/exp/ssh/messages.go \
go/exp/ssh/server.go \
-go/exp/ssh/server_shell.go \
+go/exp/ssh/server_terminal.go \
go/exp/ssh/session.go \
go/exp/ssh/tcpip.go \
go/exp/ssh/transport.go

@@ -1569,7 +1589,8 @@ go_go_doc_files = \
go/go/doc/doc.go \
go/go/doc/example.go \
go/go/doc/exports.go \
-go/go/doc/filter.go
+go/go/doc/filter.go \
+go/go/doc/reader.go
go_go_parser_files = \
go/go/parser/interface.go \

@@ -1840,10 +1861,14 @@ go_unicode_utf8_files = \
# Define Syscall and Syscall6.
@LIBGO_IS_RTEMS_TRUE@syscall_syscall_file = go/syscall/syscall_stubs.go
-@LIBGO_IS_RTEMS_FALSE@syscall_exec_file = go/syscall/exec_unix.go
+@LIBGO_IS_LINUX_FALSE@@LIBGO_IS_RTEMS_FALSE@syscall_exec_file = go/syscall/exec_unix.go
+@LIBGO_IS_LINUX_TRUE@@LIBGO_IS_RTEMS_FALSE@syscall_exec_file = go/syscall/exec_unix.go
# Define ForkExec and Exec.
@LIBGO_IS_RTEMS_TRUE@syscall_exec_file = go/syscall/exec_stubs.go
+@LIBGO_IS_LINUX_FALSE@@LIBGO_IS_RTEMS_FALSE@syscall_exec_os_file = go/syscall/exec_bsd.go
+@LIBGO_IS_LINUX_TRUE@@LIBGO_IS_RTEMS_FALSE@syscall_exec_os_file = go/syscall/exec_linux.go
+@LIBGO_IS_RTEMS_TRUE@syscall_exec_os_file =
@HAVE_WAIT4_FALSE@@LIBGO_IS_RTEMS_FALSE@syscall_wait_file = go/syscall/libcall_waitpid.go
@HAVE_WAIT4_TRUE@@LIBGO_IS_RTEMS_FALSE@syscall_wait_file = go/syscall/libcall_wait4.go

@@ -1901,6 +1926,7 @@ go_base_syscall_files = \
go/syscall/syscall.go \
$(syscall_syscall_file) \
$(syscall_exec_file) \
+$(syscall_exec_os_file) \
$(syscall_wait_file) \
$(syscall_sleep_file) \
$(syscall_errstr_file) \

@@ -1995,7 +2021,7 @@ libgo_go_objs = \
crypto/xtea.lo \
crypto/openpgp/armor.lo \
crypto/openpgp/elgamal.lo \
-crypto/openpgp/error.lo \
+crypto/openpgp/errors.lo \
crypto/openpgp/packet.lo \
crypto/openpgp/s2k.lo \
crypto/x509/pkix.lo \

@@ -2018,6 +2044,7 @@ libgo_go_objs = \
encoding/xml.lo \
exp/ebnf.lo \
exp/norm.lo \
+exp/proxy.lo \
exp/spdy.lo \
exp/sql.lo \
exp/ssh.lo \

@@ -2286,6 +2313,7 @@ TEST_PACKAGES = \
exp/ebnf/check \
$(exp_inotify_check) \
exp/norm/check \
+exp/proxy/check \
exp/spdy/check \
exp/sql/check \
exp/ssh/check \

@@ -5162,15 +5190,15 @@ crypto/openpgp/elgamal/check: $(CHECK_DEPS)
@$(CHECK)
.PHONY: crypto/openpgp/elgamal/check
-@go_include@ crypto/openpgp/error.lo.dep
+@go_include@ crypto/openpgp/errors.lo.dep
-crypto/openpgp/error.lo.dep: $(go_crypto_openpgp_error_files)
+crypto/openpgp/errors.lo.dep: $(go_crypto_openpgp_errors_files)
$(BUILDDEPS)
-crypto/openpgp/error.lo: $(go_crypto_openpgp_error_files)
+crypto/openpgp/errors.lo: $(go_crypto_openpgp_errors_files)
$(BUILDPACKAGE)
-crypto/openpgp/error/check: $(CHECK_DEPS)
+crypto/openpgp/errors/check: $(CHECK_DEPS)
-@$(MKDIR_P) crypto/openpgp/error
+@$(MKDIR_P) crypto/openpgp/errors
@$(CHECK)
-.PHONY: crypto/openpgp/error/check
+.PHONY: crypto/openpgp/errors/check
@go_include@ crypto/openpgp/packet.lo.dep
crypto/openpgp/packet.lo.dep: $(go_crypto_openpgp_packet_files)

@@ -5392,6 +5420,16 @@ exp/norm/check: $(CHECK_DEPS)
@$(CHECK)
.PHONY: exp/norm/check
+@go_include@ exp/proxy.lo.dep
+exp/proxy.lo.dep: $(go_exp_proxy_files)
+$(BUILDDEPS)
+exp/proxy.lo: $(go_exp_proxy_files)
+$(BUILDPACKAGE)
+exp/proxy/check: $(CHECK_DEPS)
+@$(MKDIR_P) exp/proxy
+@$(CHECK)
+.PHONY: exp/proxy/check
@go_include@ exp/spdy.lo.dep
exp/spdy.lo.dep: $(go_exp_spdy_files)
$(BUILDDEPS)

@@ -6201,7 +6239,7 @@ crypto/openpgp/armor.gox: crypto/openpgp/armor.lo
$(BUILDGOX)
crypto/openpgp/elgamal.gox: crypto/openpgp/elgamal.lo
$(BUILDGOX)
-crypto/openpgp/error.gox: crypto/openpgp/error.lo
+crypto/openpgp/errors.gox: crypto/openpgp/errors.lo
$(BUILDGOX)
crypto/openpgp/packet.gox: crypto/openpgp/packet.lo
$(BUILDGOX)

@@ -6253,6 +6291,8 @@ exp/inotify.gox: exp/inotify.lo
$(BUILDGOX)
exp/norm.gox: exp/norm.lo
$(BUILDGOX)
+exp/proxy.gox: exp/proxy.lo
+$(BUILDGOX)
exp/spdy.gox: exp/spdy.lo
$(BUILDGOX)
exp/sql.gox: exp/sql.lo

View File

@@ -74,6 +74,9 @@
/* Define to 1 if you have the <sys/mman.h> header file. */
#undef HAVE_SYS_MMAN_H
+/* Define to 1 if you have the <sys/prctl.h> header file. */
+#undef HAVE_SYS_PRCTL_H
/* Define to 1 if you have the <sys/ptrace.h> header file. */
#undef HAVE_SYS_PTRACE_H

libgo/configure (vendored)
View File

@@ -14505,7 +14505,7 @@ no)
;;
esac
-for ac_header in sys/mman.h syscall.h sys/epoll.h sys/ptrace.h sys/syscall.h sys/user.h sys/utsname.h sys/select.h sys/socket.h net/if.h
+for ac_header in sys/mman.h syscall.h sys/epoll.h sys/ptrace.h sys/syscall.h sys/user.h sys/utsname.h sys/select.h sys/socket.h net/if.h sys/prctl.h
do :
as_ac_Header=`$as_echo "ac_cv_header_$ac_header" | $as_tr_sh`
ac_fn_c_check_header_mongrel "$LINENO" "$ac_header" "$as_ac_Header" "$ac_includes_default"

View File

@@ -451,7 +451,7 @@ no)
;;
esac
-AC_CHECK_HEADERS(sys/mman.h syscall.h sys/epoll.h sys/ptrace.h sys/syscall.h sys/user.h sys/utsname.h sys/select.h sys/socket.h net/if.h)
+AC_CHECK_HEADERS(sys/mman.h syscall.h sys/epoll.h sys/ptrace.h sys/syscall.h sys/user.h sys/utsname.h sys/select.h sys/socket.h net/if.h sys/prctl.h)
AC_CHECK_HEADERS([linux/filter.h linux/netlink.h linux/rtnetlink.h], [], [],
[#ifdef HAVE_SYS_SOCKET_H

View File

@@ -97,8 +97,7 @@ func (b *Buffer) grow(n int) int {
func (b *Buffer) Write(p []byte) (n int, err error) {
b.lastRead = opInvalid
m := b.grow(len(p))
-copy(b.buf[m:], p)
-return len(p), nil
+return copy(b.buf[m:], p), nil
}
// WriteString appends the contents of s to the buffer. The return

@@ -200,13 +199,16 @@ func (b *Buffer) WriteRune(r rune) (n int, err error) {
// Read reads the next len(p) bytes from the buffer or until the buffer
// is drained. The return value n is the number of bytes read. If the
-// buffer has no data to return, err is io.EOF even if len(p) is zero;
+// buffer has no data to return, err is io.EOF (unless len(p) is zero);
// otherwise it is nil.
func (b *Buffer) Read(p []byte) (n int, err error) {
b.lastRead = opInvalid
if b.off >= len(b.buf) {
// Buffer is empty, reset to recover space.
b.Truncate(0)
+if len(p) == 0 {
+return
+}
return 0, io.EOF
}
n = copy(p, b.buf[b.off:])
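For context, the Write change above just returns the count from copy (which is len(p), since grow already reserved room), and Read now reports (0, nil) rather than io.EOF when the destination slice is empty. A minimal sketch of the resulting behavior, using only the standard bytes package:

package main

import (
	"bytes"
	"fmt"
)

func main() {
	var b bytes.Buffer

	// Reading into a zero-length slice from a drained buffer no longer
	// reports io.EOF; n is 0 and err is nil.
	n, err := b.Read(make([]byte, 0))
	fmt.Println(n, err) // 0 <nil>

	// Write still reports the number of bytes appended.
	n, err = b.Write([]byte("hello"))
	fmt.Println(n, err) // 5 <nil>
}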

View File

@@ -373,3 +373,16 @@ func TestReadBytes(t *testing.T) {
}
}
}
+// Was a bug: used to give EOF reading empty slice at EOF.
+func TestReadEmptyAtEOF(t *testing.T) {
+b := new(Buffer)
+slice := make([]byte, 0)
+n, err := b.Read(slice)
+if err != nil {
+t.Errorf("read error: %v", err)
+}
+if n != 0 {
+t.Errorf("wrong count; got %d want 0", n)
+}
+}

View File

@@ -9,7 +9,7 @@ package armor
import (
"bufio"
"bytes"
-error_ "crypto/openpgp/error"
+"crypto/openpgp/errors"
"encoding/base64"
"io"
)

@@ -35,7 +35,7 @@ type Block struct {
oReader openpgpReader
}
-var ArmorCorrupt error = error_.StructuralError("armor invalid")
+var ArmorCorrupt error = errors.StructuralError("armor invalid")
const crc24Init = 0xb704ce
const crc24Poly = 0x1864cfb

View File

@@ -2,8 +2,8 @@
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
-// Package error contains common error types for the OpenPGP packages.
-package error
+// Package errors contains common error types for the OpenPGP packages.
+package errors
import (
"strconv"

View File

@@ -7,8 +7,9 @@ package openpgp
import (
"crypto"
"crypto/openpgp/armor"
-error_ "crypto/openpgp/error"
+"crypto/openpgp/errors"
"crypto/openpgp/packet"
+"crypto/rand"
"crypto/rsa"
"io"
"time"

@@ -181,13 +182,13 @@ func (el EntityList) DecryptionKeys() (keys []Key) {
func ReadArmoredKeyRing(r io.Reader) (EntityList, error) {
block, err := armor.Decode(r)
if err == io.EOF {
-return nil, error_.InvalidArgumentError("no armored data found")
+return nil, errors.InvalidArgumentError("no armored data found")
}
if err != nil {
return nil, err
}
if block.Type != PublicKeyType && block.Type != PrivateKeyType {
-return nil, error_.InvalidArgumentError("expected public or private key block, got: " + block.Type)
+return nil, errors.InvalidArgumentError("expected public or private key block, got: " + block.Type)
}
return ReadKeyRing(block.Body)

@@ -203,7 +204,7 @@ func ReadKeyRing(r io.Reader) (el EntityList, err error) {
var e *Entity
e, err = readEntity(packets)
if err != nil {
-if _, ok := err.(error_.UnsupportedError); ok {
+if _, ok := err.(errors.UnsupportedError); ok {
lastUnsupportedError = err
err = readToNextPublicKey(packets)
}

@@ -235,7 +236,7 @@ func readToNextPublicKey(packets *packet.Reader) (err error) {
if err == io.EOF {
return
} else if err != nil {
-if _, ok := err.(error_.UnsupportedError); ok {
+if _, ok := err.(errors.UnsupportedError); ok {
err = nil
continue
}

@@ -266,14 +267,14 @@ func readEntity(packets *packet.Reader) (*Entity, error) {
if e.PrimaryKey, ok = p.(*packet.PublicKey); !ok {
if e.PrivateKey, ok = p.(*packet.PrivateKey); !ok {
packets.Unread(p)
-return nil, error_.StructuralError("first packet was not a public/private key")
+return nil, errors.StructuralError("first packet was not a public/private key")
} else {
e.PrimaryKey = &e.PrivateKey.PublicKey
}
}
if !e.PrimaryKey.PubKeyAlgo.CanSign() {
-return nil, error_.StructuralError("primary key cannot be used for signatures")
+return nil, errors.StructuralError("primary key cannot be used for signatures")
}
var current *Identity

@@ -303,12 +304,12 @@ EachPacket:
sig, ok := p.(*packet.Signature)
if !ok {
-return nil, error_.StructuralError("user ID packet not followed by self-signature")
+return nil, errors.StructuralError("user ID packet not followed by self-signature")
}
if (sig.SigType == packet.SigTypePositiveCert || sig.SigType == packet.SigTypeGenericCert) && sig.IssuerKeyId != nil && *sig.IssuerKeyId == e.PrimaryKey.KeyId {
if err = e.PrimaryKey.VerifyUserIdSignature(pkt.Id, sig); err != nil {
-return nil, error_.StructuralError("user ID self-signature invalid: " + err.Error())
+return nil, errors.StructuralError("user ID self-signature invalid: " + err.Error())
}
current.SelfSignature = sig
break

@@ -317,7 +318,7 @@ EachPacket:
}
case *packet.Signature:
if current == nil {
-return nil, error_.StructuralError("signature packet found before user id packet")
+return nil, errors.StructuralError("signature packet found before user id packet")
}
current.Signatures = append(current.Signatures, pkt)
case *packet.PrivateKey:

@@ -344,7 +345,7 @@ EachPacket:
}
if len(e.Identities) == 0 {
-return nil, error_.StructuralError("entity without any identities")
+return nil, errors.StructuralError("entity without any identities")
}
return e, nil

@@ -359,19 +360,19 @@ func addSubkey(e *Entity, packets *packet.Reader, pub *packet.PublicKey, priv *p
return io.ErrUnexpectedEOF
}
if err != nil {
-return error_.StructuralError("subkey signature invalid: " + err.Error())
+return errors.StructuralError("subkey signature invalid: " + err.Error())
}
var ok bool
subKey.Sig, ok = p.(*packet.Signature)
if !ok {
-return error_.StructuralError("subkey packet not followed by signature")
+return errors.StructuralError("subkey packet not followed by signature")
}
if subKey.Sig.SigType != packet.SigTypeSubkeyBinding {
-return error_.StructuralError("subkey signature with wrong type")
+return errors.StructuralError("subkey signature with wrong type")
}
err = e.PrimaryKey.VerifyKeySignature(subKey.PublicKey, subKey.Sig)
if err != nil {
-return error_.StructuralError("subkey signature invalid: " + err.Error())
+return errors.StructuralError("subkey signature invalid: " + err.Error())
}
e.Subkeys = append(e.Subkeys, subKey)
return nil

@@ -385,7 +386,7 @@ const defaultRSAKeyBits = 2048
func NewEntity(rand io.Reader, currentTime time.Time, name, comment, email string) (*Entity, error) {
uid := packet.NewUserId(name, comment, email)
if uid == nil {
-return nil, error_.InvalidArgumentError("user id field contained invalid characters")
+return nil, errors.InvalidArgumentError("user id field contained invalid characters")
}
signingPriv, err := rsa.GenerateKey(rand, defaultRSAKeyBits)
if err != nil {

@@ -397,8 +398,8 @@ func NewEntity(rand io.Reader, currentTime time.Time, name, comment, email strin
}
e := &Entity{
-PrimaryKey: packet.NewRSAPublicKey(currentTime, &signingPriv.PublicKey, false /* not a subkey */ ),
-PrivateKey: packet.NewRSAPrivateKey(currentTime, signingPriv, false /* not a subkey */ ),
+PrimaryKey: packet.NewRSAPublicKey(currentTime, &signingPriv.PublicKey),
+PrivateKey: packet.NewRSAPrivateKey(currentTime, signingPriv),
Identities: make(map[string]*Identity),
}
isPrimaryId := true

@@ -420,8 +421,8 @@ func NewEntity(rand io.Reader, currentTime time.Time, name, comment, email strin
e.Subkeys = make([]Subkey, 1)
e.Subkeys[0] = Subkey{
-PublicKey: packet.NewRSAPublicKey(currentTime, &encryptingPriv.PublicKey, true /* is a subkey */ ),
-PrivateKey: packet.NewRSAPrivateKey(currentTime, encryptingPriv, true /* is a subkey */ ),
+PublicKey: packet.NewRSAPublicKey(currentTime, &encryptingPriv.PublicKey),
+PrivateKey: packet.NewRSAPrivateKey(currentTime, encryptingPriv),
Sig: &packet.Signature{
CreationTime: currentTime,
SigType: packet.SigTypeSubkeyBinding,

@@ -433,6 +434,8 @@ func NewEntity(rand io.Reader, currentTime time.Time, name, comment, email strin
IssuerKeyId: &e.PrimaryKey.KeyId,
},
}
+e.Subkeys[0].PublicKey.IsSubkey = true
+e.Subkeys[0].PrivateKey.IsSubkey = true
return e, nil
}

@@ -450,7 +453,7 @@ func (e *Entity) SerializePrivate(w io.Writer) (err error) {
if err != nil {
return
}
-err = ident.SelfSignature.SignUserId(ident.UserId.Id, e.PrimaryKey, e.PrivateKey)
+err = ident.SelfSignature.SignUserId(rand.Reader, ident.UserId.Id, e.PrimaryKey, e.PrivateKey)
if err != nil {
return
}

@@ -464,7 +467,7 @@ func (e *Entity) SerializePrivate(w io.Writer) (err error) {
if err != nil {
return
}
-err = subkey.Sig.SignKey(subkey.PublicKey, e.PrivateKey)
+err = subkey.Sig.SignKey(rand.Reader, subkey.PublicKey, e.PrivateKey)
if err != nil {
return
}

@@ -518,14 +521,14 @@ func (e *Entity) Serialize(w io.Writer) error {
// necessary.
func (e *Entity) SignIdentity(identity string, signer *Entity) error {
if signer.PrivateKey == nil {
-return error_.InvalidArgumentError("signing Entity must have a private key")
+return errors.InvalidArgumentError("signing Entity must have a private key")
}
if signer.PrivateKey.Encrypted {
-return error_.InvalidArgumentError("signing Entity's private key must be decrypted")
+return errors.InvalidArgumentError("signing Entity's private key must be decrypted")
}
ident, ok := e.Identities[identity]
if !ok {
-return error_.InvalidArgumentError("given identity string not found in Entity")
+return errors.InvalidArgumentError("given identity string not found in Entity")
}
sig := &packet.Signature{

@@ -535,7 +538,7 @@ func (e *Entity) SignIdentity(identity string, signer *Entity) error {
CreationTime: time.Now(),
IssuerKeyId: &signer.PrivateKey.KeyId,
}
-if err := sig.SignKey(e.PrimaryKey, signer.PrivateKey); err != nil {
+if err := sig.SignKey(rand.Reader, e.PrimaryKey, signer.PrivateKey); err != nil {
return err
}
ident.Signatures = append(ident.Signatures, sig)
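The keys.go changes above track an API change in the packet package: Signature.SignUserId and Signature.SignKey now take the entropy source as an explicit first argument, which is why crypto/rand is imported here. A minimal sketch of the new call shape (the helper name is made up for illustration; Entity, Identity, and Subkey are the openpgp types used above, with the import paths of this weekly):

package example

import (
	"crypto/openpgp"
	"crypto/rand"
)

// resignAll re-issues an entity's self-signatures with the updated API,
// passing rand.Reader explicitly to SignUserId and SignKey.
func resignAll(e *openpgp.Entity) error {
	for _, ident := range e.Identities {
		err := ident.SelfSignature.SignUserId(rand.Reader, ident.UserId.Id, e.PrimaryKey, e.PrivateKey)
		if err != nil {
			return err
		}
	}
	for _, sk := range e.Subkeys {
		if err := sk.Sig.SignKey(rand.Reader, sk.PublicKey, e.PrivateKey); err != nil {
			return err
		}
	}
	return nil
}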

View File

@@ -7,7 +7,7 @@ package packet
import (
"compress/flate"
"compress/zlib"
-error_ "crypto/openpgp/error"
+"crypto/openpgp/errors"
"io"
"strconv"
)

@@ -31,7 +31,7 @@ func (c *Compressed) parse(r io.Reader) error {
case 2:
c.Body, err = zlib.NewReader(r)
default:
-err = error_.UnsupportedError("unknown compression algorithm: " + strconv.Itoa(int(buf[0])))
+err = errors.UnsupportedError("unknown compression algorithm: " + strconv.Itoa(int(buf[0])))
}
return err

View File

@@ -6,7 +6,7 @@ package packet
import (
"crypto/openpgp/elgamal"
-error_ "crypto/openpgp/error"
+"crypto/openpgp/errors"
"crypto/rand"
"crypto/rsa"
"encoding/binary"

@@ -35,7 +35,7 @@ func (e *EncryptedKey) parse(r io.Reader) (err error) {
return
}
if buf[0] != encryptedKeyVersion {
-return error_.UnsupportedError("unknown EncryptedKey version " + strconv.Itoa(int(buf[0])))
+return errors.UnsupportedError("unknown EncryptedKey version " + strconv.Itoa(int(buf[0])))
}
e.KeyId = binary.BigEndian.Uint64(buf[1:9])
e.Algo = PublicKeyAlgorithm(buf[9])

@@ -77,7 +77,7 @@ func (e *EncryptedKey) Decrypt(priv *PrivateKey) error {
c2 := new(big.Int).SetBytes(e.encryptedMPI2)
b, err = elgamal.Decrypt(priv.PrivateKey.(*elgamal.PrivateKey), c1, c2)
default:
-err = error_.InvalidArgumentError("cannot decrypted encrypted session key with private key of type " + strconv.Itoa(int(priv.PubKeyAlgo)))
+err = errors.InvalidArgumentError("cannot decrypted encrypted session key with private key of type " + strconv.Itoa(int(priv.PubKeyAlgo)))
}
if err != nil {

@@ -89,7 +89,7 @@ func (e *EncryptedKey) Decrypt(priv *PrivateKey) error {
expectedChecksum := uint16(b[len(b)-2])<<8 | uint16(b[len(b)-1])
checksum := checksumKeyMaterial(e.Key)
if checksum != expectedChecksum {
-return error_.StructuralError("EncryptedKey checksum incorrect")
+return errors.StructuralError("EncryptedKey checksum incorrect")
}
return nil

@@ -116,16 +116,16 @@ func SerializeEncryptedKey(w io.Writer, rand io.Reader, pub *PublicKey, cipherFu
case PubKeyAlgoElGamal:
return serializeEncryptedKeyElGamal(w, rand, buf, pub.PublicKey.(*elgamal.PublicKey), keyBlock)
case PubKeyAlgoDSA, PubKeyAlgoRSASignOnly:
-return error_.InvalidArgumentError("cannot encrypt to public key of type " + strconv.Itoa(int(pub.PubKeyAlgo)))
+return errors.InvalidArgumentError("cannot encrypt to public key of type " + strconv.Itoa(int(pub.PubKeyAlgo)))
}
-return error_.UnsupportedError("encrypting a key to public key of type " + strconv.Itoa(int(pub.PubKeyAlgo)))
+return errors.UnsupportedError("encrypting a key to public key of type " + strconv.Itoa(int(pub.PubKeyAlgo)))
}
func serializeEncryptedKeyRSA(w io.Writer, rand io.Reader, header [10]byte, pub *rsa.PublicKey, keyBlock []byte) error {
cipherText, err := rsa.EncryptPKCS1v15(rand, pub, keyBlock)
if err != nil {
-return error_.InvalidArgumentError("RSA encryption failed: " + err.Error())
+return errors.InvalidArgumentError("RSA encryption failed: " + err.Error())
}
packetLen := 10 /* header length */ + 2 /* mpi size */ + len(cipherText)

@@ -144,7 +144,7 @@ func serializeEncryptedKeyRSA(w io.Writer, rand io.Reader, header [10]byte, pub
func serializeEncryptedKeyElGamal(w io.Writer, rand io.Reader, header [10]byte, pub *elgamal.PublicKey, keyBlock []byte) error {
c1, c2, err := elgamal.Encrypt(rand, pub, keyBlock)
if err != nil {
-return error_.InvalidArgumentError("ElGamal encryption failed: " + err.Error())
+return errors.InvalidArgumentError("ElGamal encryption failed: " + err.Error())
}
packetLen := 10 /* header length */

View File

@@ -6,7 +6,7 @@ package packet
import (
"crypto"
-error_ "crypto/openpgp/error"
+"crypto/openpgp/errors"
"crypto/openpgp/s2k"
"encoding/binary"
"io"

@@ -33,13 +33,13 @@ func (ops *OnePassSignature) parse(r io.Reader) (err error) {
return
}
if buf[0] != onePassSignatureVersion {
-err = error_.UnsupportedError("one-pass-signature packet version " + strconv.Itoa(int(buf[0])))
+err = errors.UnsupportedError("one-pass-signature packet version " + strconv.Itoa(int(buf[0])))
}
var ok bool
ops.Hash, ok = s2k.HashIdToHash(buf[2])
if !ok {
-return error_.UnsupportedError("hash function: " + strconv.Itoa(int(buf[2])))
+return errors.UnsupportedError("hash function: " + strconv.Itoa(int(buf[2])))
}
ops.SigType = SignatureType(buf[1])

@@ -57,7 +57,7 @@ func (ops *OnePassSignature) Serialize(w io.Writer) error {
var ok bool
buf[2], ok = s2k.HashToHashId(ops.Hash)
if !ok {
-return error_.UnsupportedError("hash type: " + strconv.Itoa(int(ops.Hash)))
+return errors.UnsupportedError("hash type: " + strconv.Itoa(int(ops.Hash)))
}
buf[3] = uint8(ops.PubKeyAlgo)
binary.BigEndian.PutUint64(buf[4:12], ops.KeyId)

View File

@@ -10,7 +10,7 @@ import (
"crypto/aes"
"crypto/cast5"
"crypto/cipher"
-error_ "crypto/openpgp/error"
+"crypto/openpgp/errors"
"io"
"math/big"
)

@@ -162,7 +162,7 @@ func readHeader(r io.Reader) (tag packetType, length int64, contents io.Reader,
return
}
if buf[0]&0x80 == 0 {
-err = error_.StructuralError("tag byte does not have MSB set")
+err = errors.StructuralError("tag byte does not have MSB set")
return
}
if buf[0]&0x40 == 0 {

@@ -337,7 +337,7 @@ func Read(r io.Reader) (p Packet, err error) {
se.MDC = true
p = se
default:
-err = error_.UnknownPacketTypeError(tag)
+err = errors.UnknownPacketTypeError(tag)
}
if p != nil {
err = p.parse(contents)

View File

@@ -6,7 +6,7 @@ package packet
import (
"bytes"
-error_ "crypto/openpgp/error"
+"crypto/openpgp/errors"
"encoding/hex"
"fmt"
"io"

@@ -152,7 +152,7 @@ func TestReadHeader(t *testing.T) {
for i, test := range readHeaderTests {
tag, length, contents, err := readHeader(readerFromHex(test.hexInput))
if test.structuralError {
-if _, ok := err.(error_.StructuralError); ok {
+if _, ok := err.(errors.StructuralError); ok {
continue
}
t.Errorf("%d: expected StructuralError, got:%s", i, err)

View File

@@ -9,7 +9,7 @@ import (
"crypto/cipher"
"crypto/dsa"
"crypto/openpgp/elgamal"
-error_ "crypto/openpgp/error"
+"crypto/openpgp/errors"
"crypto/openpgp/s2k"
"crypto/rsa"
"crypto/sha1"

@@ -28,14 +28,21 @@ type PrivateKey struct {
encryptedData []byte
cipher CipherFunction
s2k func(out, in []byte)
-PrivateKey interface{} // An *rsa.PrivateKey.
+PrivateKey interface{} // An *rsa.PrivateKey or *dsa.PrivateKey.
sha1Checksum bool
iv []byte
}
-func NewRSAPrivateKey(currentTime time.Time, priv *rsa.PrivateKey, isSubkey bool) *PrivateKey {
+func NewRSAPrivateKey(currentTime time.Time, priv *rsa.PrivateKey) *PrivateKey {
pk := new(PrivateKey)
-pk.PublicKey = *NewRSAPublicKey(currentTime, &priv.PublicKey, isSubkey)
+pk.PublicKey = *NewRSAPublicKey(currentTime, &priv.PublicKey)
+pk.PrivateKey = priv
+return pk
+}
+func NewDSAPrivateKey(currentTime time.Time, priv *dsa.PrivateKey) *PrivateKey {
+pk := new(PrivateKey)
+pk.PublicKey = *NewDSAPublicKey(currentTime, &priv.PublicKey)
pk.PrivateKey = priv
return pk
}

@@ -72,13 +79,13 @@ func (pk *PrivateKey) parse(r io.Reader) (err error) {
pk.sha1Checksum = true
}
default:
-return error_.UnsupportedError("deprecated s2k function in private key")
+return errors.UnsupportedError("deprecated s2k function in private key")
}
if pk.Encrypted {
blockSize := pk.cipher.blockSize()
if blockSize == 0 {
-return error_.UnsupportedError("unsupported cipher in private key: " + strconv.Itoa(int(pk.cipher)))
+return errors.UnsupportedError("unsupported cipher in private key: " + strconv.Itoa(int(pk.cipher)))
}
pk.iv = make([]byte, blockSize)
_, err = readFull(r, pk.iv)

@@ -121,8 +128,10 @@ func (pk *PrivateKey) Serialize(w io.Writer) (err error) {
switch priv := pk.PrivateKey.(type) {
case *rsa.PrivateKey:
err = serializeRSAPrivateKey(privateKeyBuf, priv)
+case *dsa.PrivateKey:
+err = serializeDSAPrivateKey(privateKeyBuf, priv)
default:
-err = error_.InvalidArgumentError("non-RSA private key")
+err = errors.InvalidArgumentError("unknown private key type")
}
if err != nil {
return

@@ -172,6 +181,10 @@ func serializeRSAPrivateKey(w io.Writer, priv *rsa.PrivateKey) error {
return writeBig(w, priv.Precomputed.Qinv)
}
+func serializeDSAPrivateKey(w io.Writer, priv *dsa.PrivateKey) error {
+return writeBig(w, priv.X)
+}
// Decrypt decrypts an encrypted private key using a passphrase.
func (pk *PrivateKey) Decrypt(passphrase []byte) error {
if !pk.Encrypted {

@@ -188,18 +201,18 @@ func (pk *PrivateKey) Decrypt(passphrase []byte) error {
if pk.sha1Checksum {
if len(data) < sha1.Size {
-return error_.StructuralError("truncated private key data")
+return errors.StructuralError("truncated private key data")
}
h := sha1.New()
h.Write(data[:len(data)-sha1.Size])
sum := h.Sum(nil)
if !bytes.Equal(sum, data[len(data)-sha1.Size:]) {
-return error_.StructuralError("private key checksum failure")
+return errors.StructuralError("private key checksum failure")
}
data = data[:len(data)-sha1.Size]
} else {
if len(data) < 2 {
-return error_.StructuralError("truncated private key data")
+return errors.StructuralError("truncated private key data")
}
var sum uint16
for i := 0; i < len(data)-2; i++ {

@@ -207,7 +220,7 @@ func (pk *PrivateKey) Decrypt(passphrase []byte) error {
}
if data[len(data)-2] != uint8(sum>>8) ||
data[len(data)-1] != uint8(sum) {
-return error_.StructuralError("private key checksum failure")
+return errors.StructuralError("private key checksum failure")
}
data = data[:len(data)-2]
}
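The constructor changes above (mirrored in public_key.go below) drop the isSubkey flag: callers now mark a key as a subkey through the IsSubkey field, and DSA keys get NewDSAPrivateKey/NewDSAPublicKey plus their own serialization path. A minimal sketch of the new calls, against this weekly's crypto/openpgp/packet API (the helper name is hypothetical):

package example

import (
	"crypto/openpgp/packet"
	"crypto/rand"
	"crypto/rsa"
	"time"
)

func newSubkey() (*packet.PrivateKey, error) {
	rsaPriv, err := rsa.GenerateKey(rand.Reader, 1024)
	if err != nil {
		return nil, err
	}
	// Previously: packet.NewRSAPrivateKey(time.Now(), rsaPriv, true /* subkey */)
	pk := packet.NewRSAPrivateKey(time.Now(), rsaPriv)
	pk.IsSubkey = true
	return pk, nil
}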

View File

@@ -7,7 +7,7 @@ package packet
import (
"crypto/dsa"
"crypto/openpgp/elgamal"
-error_ "crypto/openpgp/error"
+"crypto/openpgp/errors"
"crypto/rsa"
"crypto/sha1"
"encoding/binary"

@@ -39,12 +39,11 @@ func fromBig(n *big.Int) parsedMPI {
}
// NewRSAPublicKey returns a PublicKey that wraps the given rsa.PublicKey.
-func NewRSAPublicKey(creationTime time.Time, pub *rsa.PublicKey, isSubkey bool) *PublicKey {
+func NewRSAPublicKey(creationTime time.Time, pub *rsa.PublicKey) *PublicKey {
pk := &PublicKey{
CreationTime: creationTime,
PubKeyAlgo: PubKeyAlgoRSA,
PublicKey: pub,
-IsSubkey: isSubkey,
n: fromBig(pub.N),
e: fromBig(big.NewInt(int64(pub.E))),
}

@@ -53,6 +52,22 @@ func NewRSAPublicKey(creationTime time.Time, pub *rsa.PublicKey, isSubkey bool)
return pk
}
+// NewDSAPublicKey returns a PublicKey that wraps the given rsa.PublicKey.
+func NewDSAPublicKey(creationTime time.Time, pub *dsa.PublicKey) *PublicKey {
+pk := &PublicKey{
+CreationTime: creationTime,
+PubKeyAlgo: PubKeyAlgoDSA,
+PublicKey: pub,
+p: fromBig(pub.P),
+q: fromBig(pub.Q),
+g: fromBig(pub.G),
+y: fromBig(pub.Y),
+}
+pk.setFingerPrintAndKeyId()
+return pk
+}
func (pk *PublicKey) parse(r io.Reader) (err error) {
// RFC 4880, section 5.5.2
var buf [6]byte

@@ -61,7 +76,7 @@ func (pk *PublicKey) parse(r io.Reader) (err error) {
return
}
if buf[0] != 4 {
-return error_.UnsupportedError("public key version")
+return errors.UnsupportedError("public key version")
}
pk.CreationTime = time.Unix(int64(uint32(buf[1])<<24|uint32(buf[2])<<16|uint32(buf[3])<<8|uint32(buf[4])), 0)
pk.PubKeyAlgo = PublicKeyAlgorithm(buf[5])

@@ -73,7 +88,7 @@ func (pk *PublicKey) parse(r io.Reader) (err error) {
case PubKeyAlgoElGamal:
err = pk.parseElGamal(r)
default:
-err = error_.UnsupportedError("public key type: " + strconv.Itoa(int(pk.PubKeyAlgo)))
+err = errors.UnsupportedError("public key type: " + strconv.Itoa(int(pk.PubKeyAlgo)))
}
if err != nil {
return

@@ -105,7 +120,7 @@ func (pk *PublicKey) parseRSA(r io.Reader) (err error) {
}
if len(pk.e.bytes) > 3 {
-err = error_.UnsupportedError("large public exponent")
+err = errors.UnsupportedError("large public exponent")
return
}
rsa := &rsa.PublicKey{

@@ -255,7 +270,7 @@ func (pk *PublicKey) serializeWithoutHeaders(w io.Writer) (err error) {
case PubKeyAlgoElGamal:
return writeMPIs(w, pk.p, pk.g, pk.y)
}
-return error_.InvalidArgumentError("bad public-key algorithm")
+return errors.InvalidArgumentError("bad public-key algorithm")
}
// CanSign returns true iff this public key can generate signatures

@@ -267,18 +282,18 @@ func (pk *PublicKey) CanSign() bool {
// public key, of the data hashed into signed. signed is mutated by this call.
func (pk *PublicKey) VerifySignature(signed hash.Hash, sig *Signature) (err error) {
if !pk.CanSign() {
-return error_.InvalidArgumentError("public key cannot generate signatures")
+return errors.InvalidArgumentError("public key cannot generate signatures")
}
signed.Write(sig.HashSuffix)
hashBytes := signed.Sum(nil)
if hashBytes[0] != sig.HashTag[0] || hashBytes[1] != sig.HashTag[1] {
-return error_.SignatureError("hash tag doesn't match")
+return errors.SignatureError("hash tag doesn't match")
}
if pk.PubKeyAlgo != sig.PubKeyAlgo {
-return error_.InvalidArgumentError("public key and signature use different algorithms")
+return errors.InvalidArgumentError("public key and signature use different algorithms")
}
switch pk.PubKeyAlgo {

@@ -286,13 +301,18 @@ func (pk *PublicKey) VerifySignature(signed hash.Hash, sig *Signature) (err erro
rsaPublicKey, _ := pk.PublicKey.(*rsa.PublicKey)
err = rsa.VerifyPKCS1v15(rsaPublicKey, sig.Hash, hashBytes, sig.RSASignature.bytes)
if err != nil {
-return error_.SignatureError("RSA verification failure")
+return errors.SignatureError("RSA verification failure")
}
return nil
case PubKeyAlgoDSA:
dsaPublicKey, _ := pk.PublicKey.(*dsa.PublicKey)
+// Need to truncate hashBytes to match FIPS 186-3 section 4.6.
+subgroupSize := (dsaPublicKey.Q.BitLen() + 7) / 8
+if len(hashBytes) > subgroupSize {
+hashBytes = hashBytes[:subgroupSize]
+}
if !dsa.Verify(dsaPublicKey, hashBytes, new(big.Int).SetBytes(sig.DSASigR.bytes), new(big.Int).SetBytes(sig.DSASigS.bytes)) {
-return error_.SignatureError("DSA verification failure")
+return errors.SignatureError("DSA verification failure")
}
return nil
default:

@@ -306,7 +326,7 @@ func (pk *PublicKey) VerifySignature(signed hash.Hash, sig *Signature) (err erro
func keySignatureHash(pk, signed *PublicKey, sig *Signature) (h hash.Hash, err error) {
h = sig.Hash.New()
if h == nil {
-return nil, error_.UnsupportedError("hash function")
+return nil, errors.UnsupportedError("hash function")
}
// RFC 4880, section 5.2.4

@@ -332,7 +352,7 @@ func (pk *PublicKey) VerifyKeySignature(signed *PublicKey, sig *Signature) (err
func userIdSignatureHash(id string, pk *PublicKey, sig *Signature) (h hash.Hash, err error) {
h = sig.Hash.New()
if h == nil {
-return nil, error_.UnsupportedError("hash function")
+return nil, errors.UnsupportedError("hash function")
}
// RFC 4880, section 5.2.4

View File

@ -5,7 +5,7 @@
package packet package packet
import ( import (
error_ "crypto/openpgp/error" "crypto/openpgp/errors"
"io" "io"
) )
@ -34,7 +34,7 @@ func (r *Reader) Next() (p Packet, err error) {
r.readers = r.readers[:len(r.readers)-1] r.readers = r.readers[:len(r.readers)-1]
continue continue
} }
if _, ok := err.(error_.UnknownPacketTypeError); !ok { if _, ok := err.(errors.UnknownPacketTypeError); !ok {
return nil, err return nil, err
} }
} }
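Reader.Next above skips packets it does not understand but still fails on real errors, telling the two apart with a type assertion on errors.UnknownPacketTypeError. Callers elsewhere in this change use the same style with the other concrete types exported by crypto/openpgp/errors; a small hedged sketch, where err is a hypothetical error from a packet-level operation and the strings are illustrative only:

// Sketch: classifying crypto/openpgp/errors values by type, as callers in
// this change do.
func classify(err error) string {
	switch err.(type) {
	case errors.UnknownPacketTypeError:
		return "unknown packet type (safe to skip, as Reader.Next does)"
	case errors.UnsupportedError:
		return "recognised but unsupported construct (e.g. unknown cipher)"
	case errors.StructuralError:
		return "malformed OpenPGP data"
	default:
		return "other error (plain I/O failure, io.EOF, ...)"
	}
}
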

View File

@ -7,9 +7,8 @@ package packet
import ( import (
"crypto" "crypto"
"crypto/dsa" "crypto/dsa"
error_ "crypto/openpgp/error" "crypto/openpgp/errors"
"crypto/openpgp/s2k" "crypto/openpgp/s2k"
"crypto/rand"
"crypto/rsa" "crypto/rsa"
"encoding/binary" "encoding/binary"
"hash" "hash"
@ -61,7 +60,7 @@ func (sig *Signature) parse(r io.Reader) (err error) {
return return
} }
if buf[0] != 4 { if buf[0] != 4 {
err = error_.UnsupportedError("signature packet version " + strconv.Itoa(int(buf[0]))) err = errors.UnsupportedError("signature packet version " + strconv.Itoa(int(buf[0])))
return return
} }
@ -74,14 +73,14 @@ func (sig *Signature) parse(r io.Reader) (err error) {
switch sig.PubKeyAlgo { switch sig.PubKeyAlgo {
case PubKeyAlgoRSA, PubKeyAlgoRSASignOnly, PubKeyAlgoDSA: case PubKeyAlgoRSA, PubKeyAlgoRSASignOnly, PubKeyAlgoDSA:
default: default:
err = error_.UnsupportedError("public key algorithm " + strconv.Itoa(int(sig.PubKeyAlgo))) err = errors.UnsupportedError("public key algorithm " + strconv.Itoa(int(sig.PubKeyAlgo)))
return return
} }
var ok bool var ok bool
sig.Hash, ok = s2k.HashIdToHash(buf[2]) sig.Hash, ok = s2k.HashIdToHash(buf[2])
if !ok { if !ok {
return error_.UnsupportedError("hash function " + strconv.Itoa(int(buf[2]))) return errors.UnsupportedError("hash function " + strconv.Itoa(int(buf[2])))
} }
hashedSubpacketsLength := int(buf[3])<<8 | int(buf[4]) hashedSubpacketsLength := int(buf[3])<<8 | int(buf[4])
@ -153,7 +152,7 @@ func parseSignatureSubpackets(sig *Signature, subpackets []byte, isHashed bool)
} }
if sig.CreationTime.IsZero() { if sig.CreationTime.IsZero() {
err = error_.StructuralError("no creation time in signature") err = errors.StructuralError("no creation time in signature")
} }
return return
@ -164,7 +163,7 @@ type signatureSubpacketType uint8
const ( const (
creationTimeSubpacket signatureSubpacketType = 2 creationTimeSubpacket signatureSubpacketType = 2
signatureExpirationSubpacket signatureSubpacketType = 3 signatureExpirationSubpacket signatureSubpacketType = 3
keyExpirySubpacket signatureSubpacketType = 9 keyExpirationSubpacket signatureSubpacketType = 9
prefSymmetricAlgosSubpacket signatureSubpacketType = 11 prefSymmetricAlgosSubpacket signatureSubpacketType = 11
issuerSubpacket signatureSubpacketType = 16 issuerSubpacket signatureSubpacketType = 16
prefHashAlgosSubpacket signatureSubpacketType = 21 prefHashAlgosSubpacket signatureSubpacketType = 21
@ -207,7 +206,7 @@ func parseSignatureSubpacket(sig *Signature, subpacket []byte, isHashed bool) (r
rest = subpacket[length:] rest = subpacket[length:]
subpacket = subpacket[:length] subpacket = subpacket[:length]
if len(subpacket) == 0 { if len(subpacket) == 0 {
err = error_.StructuralError("zero length signature subpacket") err = errors.StructuralError("zero length signature subpacket")
return return
} }
packetType = signatureSubpacketType(subpacket[0] & 0x7f) packetType = signatureSubpacketType(subpacket[0] & 0x7f)
@ -217,37 +216,33 @@ func parseSignatureSubpacket(sig *Signature, subpacket []byte, isHashed bool) (r
switch packetType { switch packetType {
case creationTimeSubpacket: case creationTimeSubpacket:
if !isHashed { if !isHashed {
err = error_.StructuralError("signature creation time in non-hashed area") err = errors.StructuralError("signature creation time in non-hashed area")
return return
} }
if len(subpacket) != 4 { if len(subpacket) != 4 {
err = error_.StructuralError("signature creation time not four bytes") err = errors.StructuralError("signature creation time not four bytes")
return return
} }
t := binary.BigEndian.Uint32(subpacket) t := binary.BigEndian.Uint32(subpacket)
if t == 0 { sig.CreationTime = time.Unix(int64(t), 0)
sig.CreationTime = time.Time{}
} else {
sig.CreationTime = time.Unix(int64(t), 0)
}
case signatureExpirationSubpacket: case signatureExpirationSubpacket:
// Signature expiration time, section 5.2.3.10 // Signature expiration time, section 5.2.3.10
if !isHashed { if !isHashed {
return return
} }
if len(subpacket) != 4 { if len(subpacket) != 4 {
err = error_.StructuralError("expiration subpacket with bad length") err = errors.StructuralError("expiration subpacket with bad length")
return return
} }
sig.SigLifetimeSecs = new(uint32) sig.SigLifetimeSecs = new(uint32)
*sig.SigLifetimeSecs = binary.BigEndian.Uint32(subpacket) *sig.SigLifetimeSecs = binary.BigEndian.Uint32(subpacket)
case keyExpirySubpacket: case keyExpirationSubpacket:
// Key expiration time, section 5.2.3.6 // Key expiration time, section 5.2.3.6
if !isHashed { if !isHashed {
return return
} }
if len(subpacket) != 4 { if len(subpacket) != 4 {
err = error_.StructuralError("key expiration subpacket with bad length") err = errors.StructuralError("key expiration subpacket with bad length")
return return
} }
sig.KeyLifetimeSecs = new(uint32) sig.KeyLifetimeSecs = new(uint32)
@ -262,7 +257,7 @@ func parseSignatureSubpacket(sig *Signature, subpacket []byte, isHashed bool) (r
case issuerSubpacket: case issuerSubpacket:
// Issuer, section 5.2.3.5 // Issuer, section 5.2.3.5
if len(subpacket) != 8 { if len(subpacket) != 8 {
err = error_.StructuralError("issuer subpacket with bad length") err = errors.StructuralError("issuer subpacket with bad length")
return return
} }
sig.IssuerKeyId = new(uint64) sig.IssuerKeyId = new(uint64)
@ -287,7 +282,7 @@ func parseSignatureSubpacket(sig *Signature, subpacket []byte, isHashed bool) (r
return return
} }
if len(subpacket) != 1 { if len(subpacket) != 1 {
err = error_.StructuralError("primary user id subpacket with bad length") err = errors.StructuralError("primary user id subpacket with bad length")
return return
} }
sig.IsPrimaryId = new(bool) sig.IsPrimaryId = new(bool)
@ -300,7 +295,7 @@ func parseSignatureSubpacket(sig *Signature, subpacket []byte, isHashed bool) (r
return return
} }
if len(subpacket) == 0 { if len(subpacket) == 0 {
err = error_.StructuralError("empty key flags subpacket") err = errors.StructuralError("empty key flags subpacket")
return return
} }
sig.FlagsValid = true sig.FlagsValid = true
@ -319,14 +314,14 @@ func parseSignatureSubpacket(sig *Signature, subpacket []byte, isHashed bool) (r
default: default:
if isCritical { if isCritical {
err = error_.UnsupportedError("unknown critical signature subpacket type " + strconv.Itoa(int(packetType))) err = errors.UnsupportedError("unknown critical signature subpacket type " + strconv.Itoa(int(packetType)))
return return
} }
} }
return return
Truncated: Truncated:
err = error_.StructuralError("signature subpacket truncated") err = errors.StructuralError("signature subpacket truncated")
return return
} }
@ -401,7 +396,7 @@ func (sig *Signature) buildHashSuffix() (err error) {
sig.HashSuffix[3], ok = s2k.HashToHashId(sig.Hash) sig.HashSuffix[3], ok = s2k.HashToHashId(sig.Hash)
if !ok { if !ok {
sig.HashSuffix = nil sig.HashSuffix = nil
return error_.InvalidArgumentError("hash cannot be represented in OpenPGP: " + strconv.Itoa(int(sig.Hash))) return errors.InvalidArgumentError("hash cannot be represented in OpenPGP: " + strconv.Itoa(int(sig.Hash)))
} }
sig.HashSuffix[4] = byte(hashedSubpacketsLen >> 8) sig.HashSuffix[4] = byte(hashedSubpacketsLen >> 8)
sig.HashSuffix[5] = byte(hashedSubpacketsLen) sig.HashSuffix[5] = byte(hashedSubpacketsLen)
@ -431,7 +426,7 @@ func (sig *Signature) signPrepareHash(h hash.Hash) (digest []byte, err error) {
// Sign signs a message with a private key. The hash, h, must contain // Sign signs a message with a private key. The hash, h, must contain
// the hash of the message to be signed and will be mutated by this function. // the hash of the message to be signed and will be mutated by this function.
// On success, the signature is stored in sig. Call Serialize to write it out. // On success, the signature is stored in sig. Call Serialize to write it out.
func (sig *Signature) Sign(h hash.Hash, priv *PrivateKey) (err error) { func (sig *Signature) Sign(rand io.Reader, h hash.Hash, priv *PrivateKey) (err error) {
sig.outSubpackets = sig.buildSubpackets() sig.outSubpackets = sig.buildSubpackets()
digest, err := sig.signPrepareHash(h) digest, err := sig.signPrepareHash(h)
if err != nil { if err != nil {
@ -440,10 +435,17 @@ func (sig *Signature) Sign(h hash.Hash, priv *PrivateKey) (err error) {
switch priv.PubKeyAlgo { switch priv.PubKeyAlgo {
case PubKeyAlgoRSA, PubKeyAlgoRSASignOnly: case PubKeyAlgoRSA, PubKeyAlgoRSASignOnly:
sig.RSASignature.bytes, err = rsa.SignPKCS1v15(rand.Reader, priv.PrivateKey.(*rsa.PrivateKey), sig.Hash, digest) sig.RSASignature.bytes, err = rsa.SignPKCS1v15(rand, priv.PrivateKey.(*rsa.PrivateKey), sig.Hash, digest)
sig.RSASignature.bitLength = uint16(8 * len(sig.RSASignature.bytes)) sig.RSASignature.bitLength = uint16(8 * len(sig.RSASignature.bytes))
case PubKeyAlgoDSA: case PubKeyAlgoDSA:
r, s, err := dsa.Sign(rand.Reader, priv.PrivateKey.(*dsa.PrivateKey), digest) dsaPriv := priv.PrivateKey.(*dsa.PrivateKey)
// Need to truncate hashBytes to match FIPS 186-3 section 4.6.
subgroupSize := (dsaPriv.Q.BitLen() + 7) / 8
if len(digest) > subgroupSize {
digest = digest[:subgroupSize]
}
r, s, err := dsa.Sign(rand, dsaPriv, digest)
if err == nil { if err == nil {
sig.DSASigR.bytes = r.Bytes() sig.DSASigR.bytes = r.Bytes()
sig.DSASigR.bitLength = uint16(8 * len(sig.DSASigR.bytes)) sig.DSASigR.bitLength = uint16(8 * len(sig.DSASigR.bytes))
@ -451,7 +453,7 @@ func (sig *Signature) Sign(h hash.Hash, priv *PrivateKey) (err error) {
sig.DSASigS.bitLength = uint16(8 * len(sig.DSASigS.bytes)) sig.DSASigS.bitLength = uint16(8 * len(sig.DSASigS.bytes))
} }
default: default:
err = error_.UnsupportedError("public key algorithm: " + strconv.Itoa(int(sig.PubKeyAlgo))) err = errors.UnsupportedError("public key algorithm: " + strconv.Itoa(int(sig.PubKeyAlgo)))
} }
return return
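Sign, and below it SignUserId and SignKey, now take the source of randomness as an explicit io.Reader instead of importing crypto/rand themselves; callers such as detachSign later in this change pass crypto/rand's Reader. A hedged sketch of the calling pattern, where sig, h, priv and w are hypothetical values prepared elsewhere:

// Sketch of the new calling convention; assumes imports "crypto/rand", "hash"
// and "io", and that sig's type, hash and subpacket fields are already set.
func signAndWrite(sig *Signature, h hash.Hash, priv *PrivateKey, w io.Writer) error {
	if err := sig.Sign(rand.Reader, h, priv); err != nil {
		return err
	}
	return sig.Serialize(w)
}
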
@ -460,22 +462,22 @@ func (sig *Signature) Sign(h hash.Hash, priv *PrivateKey) (err error) {
// SignUserId computes a signature from priv, asserting that pub is a valid // SignUserId computes a signature from priv, asserting that pub is a valid
// key for the identity id. On success, the signature is stored in sig. Call // key for the identity id. On success, the signature is stored in sig. Call
// Serialize to write it out. // Serialize to write it out.
func (sig *Signature) SignUserId(id string, pub *PublicKey, priv *PrivateKey) error { func (sig *Signature) SignUserId(rand io.Reader, id string, pub *PublicKey, priv *PrivateKey) error {
h, err := userIdSignatureHash(id, pub, sig) h, err := userIdSignatureHash(id, pub, sig)
if err != nil { if err != nil {
return nil return nil
} }
return sig.Sign(h, priv) return sig.Sign(rand, h, priv)
} }
// SignKey computes a signature from priv, asserting that pub is a subkey. On // SignKey computes a signature from priv, asserting that pub is a subkey. On
// success, the signature is stored in sig. Call Serialize to write it out. // success, the signature is stored in sig. Call Serialize to write it out.
func (sig *Signature) SignKey(pub *PublicKey, priv *PrivateKey) error { func (sig *Signature) SignKey(rand io.Reader, pub *PublicKey, priv *PrivateKey) error {
h, err := keySignatureHash(&priv.PublicKey, pub, sig) h, err := keySignatureHash(&priv.PublicKey, pub, sig)
if err != nil { if err != nil {
return err return err
} }
return sig.Sign(h, priv) return sig.Sign(rand, h, priv)
} }
// Serialize marshals sig to w. SignRSA or SignDSA must have been called first. // Serialize marshals sig to w. SignRSA or SignDSA must have been called first.
@ -484,7 +486,7 @@ func (sig *Signature) Serialize(w io.Writer) (err error) {
sig.outSubpackets = sig.rawSubpackets sig.outSubpackets = sig.rawSubpackets
} }
if sig.RSASignature.bytes == nil && sig.DSASigR.bytes == nil { if sig.RSASignature.bytes == nil && sig.DSASigR.bytes == nil {
return error_.InvalidArgumentError("Signature: need to call SignRSA or SignDSA before Serialize") return errors.InvalidArgumentError("Signature: need to call SignRSA or SignDSA before Serialize")
} }
sigLength := 0 sigLength := 0
@ -556,5 +558,54 @@ func (sig *Signature) buildSubpackets() (subpackets []outputSubpacket) {
subpackets = append(subpackets, outputSubpacket{true, issuerSubpacket, false, keyId}) subpackets = append(subpackets, outputSubpacket{true, issuerSubpacket, false, keyId})
} }
if sig.SigLifetimeSecs != nil && *sig.SigLifetimeSecs != 0 {
sigLifetime := make([]byte, 4)
binary.BigEndian.PutUint32(sigLifetime, *sig.SigLifetimeSecs)
subpackets = append(subpackets, outputSubpacket{true, signatureExpirationSubpacket, true, sigLifetime})
}
// Key flags may only appear in self-signatures or certification signatures.
if sig.FlagsValid {
var flags byte
if sig.FlagCertify {
flags |= 1
}
if sig.FlagSign {
flags |= 2
}
if sig.FlagEncryptCommunications {
flags |= 4
}
if sig.FlagEncryptStorage {
flags |= 8
}
subpackets = append(subpackets, outputSubpacket{true, keyFlagsSubpacket, false, []byte{flags}})
}
// The following subpackets may only appear in self-signatures
if sig.KeyLifetimeSecs != nil && *sig.KeyLifetimeSecs != 0 {
keyLifetime := make([]byte, 4)
binary.BigEndian.PutUint32(keyLifetime, *sig.KeyLifetimeSecs)
subpackets = append(subpackets, outputSubpacket{true, keyExpirationSubpacket, true, keyLifetime})
}
if sig.IsPrimaryId != nil && *sig.IsPrimaryId {
subpackets = append(subpackets, outputSubpacket{true, primaryUserIdSubpacket, false, []byte{1}})
}
if len(sig.PreferredSymmetric) > 0 {
subpackets = append(subpackets, outputSubpacket{true, prefSymmetricAlgosSubpacket, false, sig.PreferredSymmetric})
}
if len(sig.PreferredHash) > 0 {
subpackets = append(subpackets, outputSubpacket{true, prefHashAlgosSubpacket, false, sig.PreferredHash})
}
if len(sig.PreferredCompression) > 0 {
subpackets = append(subpackets, outputSubpacket{true, prefCompressionSubpacket, false, sig.PreferredCompression})
}
return return
} }
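buildSubpackets packs the four key-usage booleans into the low bits of a single key-flags byte (0x01 certify, 0x02 sign, 0x04 encrypt communications, 0x08 encrypt storage), matching RFC 4880 section 5.2.3.21, and only emits the expiration and preference subpackets when they are non-empty. A sketch of reading such a flags byte back, where flags is a hypothetical byte taken from a key-flags subpacket:

// Sketch: decoding the key-flags byte written above (RFC 4880, section 5.2.3.21).
func decodeKeyFlags(flags byte) (certify, sign, encryptComms, encryptStorage bool) {
	return flags&0x01 != 0, flags&0x02 != 0, flags&0x04 != 0, flags&0x08 != 0
}
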

View File

@ -7,7 +7,7 @@ package packet
import ( import (
"bytes" "bytes"
"crypto/cipher" "crypto/cipher"
error_ "crypto/openpgp/error" "crypto/openpgp/errors"
"crypto/openpgp/s2k" "crypto/openpgp/s2k"
"io" "io"
"strconv" "strconv"
@ -37,12 +37,12 @@ func (ske *SymmetricKeyEncrypted) parse(r io.Reader) (err error) {
return return
} }
if buf[0] != symmetricKeyEncryptedVersion { if buf[0] != symmetricKeyEncryptedVersion {
return error_.UnsupportedError("SymmetricKeyEncrypted version") return errors.UnsupportedError("SymmetricKeyEncrypted version")
} }
ske.CipherFunc = CipherFunction(buf[1]) ske.CipherFunc = CipherFunction(buf[1])
if ske.CipherFunc.KeySize() == 0 { if ske.CipherFunc.KeySize() == 0 {
return error_.UnsupportedError("unknown cipher: " + strconv.Itoa(int(buf[1]))) return errors.UnsupportedError("unknown cipher: " + strconv.Itoa(int(buf[1])))
} }
ske.s2k, err = s2k.Parse(r) ske.s2k, err = s2k.Parse(r)
@ -60,7 +60,7 @@ func (ske *SymmetricKeyEncrypted) parse(r io.Reader) (err error) {
err = nil err = nil
if n != 0 { if n != 0 {
if n == maxSessionKeySizeInBytes { if n == maxSessionKeySizeInBytes {
return error_.UnsupportedError("oversized encrypted session key") return errors.UnsupportedError("oversized encrypted session key")
} }
ske.encryptedKey = encryptedKey[:n] ske.encryptedKey = encryptedKey[:n]
} }
@ -89,13 +89,13 @@ func (ske *SymmetricKeyEncrypted) Decrypt(passphrase []byte) error {
c.XORKeyStream(ske.encryptedKey, ske.encryptedKey) c.XORKeyStream(ske.encryptedKey, ske.encryptedKey)
ske.CipherFunc = CipherFunction(ske.encryptedKey[0]) ske.CipherFunc = CipherFunction(ske.encryptedKey[0])
if ske.CipherFunc.blockSize() == 0 { if ske.CipherFunc.blockSize() == 0 {
return error_.UnsupportedError("unknown cipher: " + strconv.Itoa(int(ske.CipherFunc))) return errors.UnsupportedError("unknown cipher: " + strconv.Itoa(int(ske.CipherFunc)))
} }
ske.CipherFunc = CipherFunction(ske.encryptedKey[0]) ske.CipherFunc = CipherFunction(ske.encryptedKey[0])
ske.Key = ske.encryptedKey[1:] ske.Key = ske.encryptedKey[1:]
if len(ske.Key)%ske.CipherFunc.blockSize() != 0 { if len(ske.Key)%ske.CipherFunc.blockSize() != 0 {
ske.Key = nil ske.Key = nil
return error_.StructuralError("length of decrypted key not a multiple of block size") return errors.StructuralError("length of decrypted key not a multiple of block size")
} }
} }
@ -110,7 +110,7 @@ func (ske *SymmetricKeyEncrypted) Decrypt(passphrase []byte) error {
func SerializeSymmetricKeyEncrypted(w io.Writer, rand io.Reader, passphrase []byte, cipherFunc CipherFunction) (key []byte, err error) { func SerializeSymmetricKeyEncrypted(w io.Writer, rand io.Reader, passphrase []byte, cipherFunc CipherFunction) (key []byte, err error) {
keySize := cipherFunc.KeySize() keySize := cipherFunc.KeySize()
if keySize == 0 { if keySize == 0 {
return nil, error_.UnsupportedError("unknown cipher: " + strconv.Itoa(int(cipherFunc))) return nil, errors.UnsupportedError("unknown cipher: " + strconv.Itoa(int(cipherFunc)))
} }
s2kBuf := new(bytes.Buffer) s2kBuf := new(bytes.Buffer)

View File

@ -6,8 +6,7 @@ package packet
import ( import (
"crypto/cipher" "crypto/cipher"
error_ "crypto/openpgp/error" "crypto/openpgp/errors"
"crypto/rand"
"crypto/sha1" "crypto/sha1"
"crypto/subtle" "crypto/subtle"
"hash" "hash"
@ -35,7 +34,7 @@ func (se *SymmetricallyEncrypted) parse(r io.Reader) error {
return err return err
} }
if buf[0] != symmetricallyEncryptedVersion { if buf[0] != symmetricallyEncryptedVersion {
return error_.UnsupportedError("unknown SymmetricallyEncrypted version") return errors.UnsupportedError("unknown SymmetricallyEncrypted version")
} }
} }
se.contents = r se.contents = r
@ -48,10 +47,10 @@ func (se *SymmetricallyEncrypted) parse(r io.Reader) error {
func (se *SymmetricallyEncrypted) Decrypt(c CipherFunction, key []byte) (io.ReadCloser, error) { func (se *SymmetricallyEncrypted) Decrypt(c CipherFunction, key []byte) (io.ReadCloser, error) {
keySize := c.KeySize() keySize := c.KeySize()
if keySize == 0 { if keySize == 0 {
return nil, error_.UnsupportedError("unknown cipher: " + strconv.Itoa(int(c))) return nil, errors.UnsupportedError("unknown cipher: " + strconv.Itoa(int(c)))
} }
if len(key) != keySize { if len(key) != keySize {
return nil, error_.InvalidArgumentError("SymmetricallyEncrypted: incorrect key length") return nil, errors.InvalidArgumentError("SymmetricallyEncrypted: incorrect key length")
} }
if se.prefix == nil { if se.prefix == nil {
@ -61,7 +60,7 @@ func (se *SymmetricallyEncrypted) Decrypt(c CipherFunction, key []byte) (io.Read
return nil, err return nil, err
} }
} else if len(se.prefix) != c.blockSize()+2 { } else if len(se.prefix) != c.blockSize()+2 {
return nil, error_.InvalidArgumentError("can't try ciphers with different block lengths") return nil, errors.InvalidArgumentError("can't try ciphers with different block lengths")
} }
ocfbResync := cipher.OCFBResync ocfbResync := cipher.OCFBResync
@ -72,7 +71,7 @@ func (se *SymmetricallyEncrypted) Decrypt(c CipherFunction, key []byte) (io.Read
s := cipher.NewOCFBDecrypter(c.new(key), se.prefix, ocfbResync) s := cipher.NewOCFBDecrypter(c.new(key), se.prefix, ocfbResync)
if s == nil { if s == nil {
return nil, error_.KeyIncorrectError return nil, errors.KeyIncorrectError
} }
plaintext := cipher.StreamReader{S: s, R: se.contents} plaintext := cipher.StreamReader{S: s, R: se.contents}
@ -181,7 +180,7 @@ const mdcPacketTagByte = byte(0x80) | 0x40 | 19
func (ser *seMDCReader) Close() error { func (ser *seMDCReader) Close() error {
if ser.error { if ser.error {
return error_.SignatureError("error during reading") return errors.SignatureError("error during reading")
} }
for !ser.eof { for !ser.eof {
@ -192,18 +191,18 @@ func (ser *seMDCReader) Close() error {
break break
} }
if err != nil { if err != nil {
return error_.SignatureError("error during reading") return errors.SignatureError("error during reading")
} }
} }
if ser.trailer[0] != mdcPacketTagByte || ser.trailer[1] != sha1.Size { if ser.trailer[0] != mdcPacketTagByte || ser.trailer[1] != sha1.Size {
return error_.SignatureError("MDC packet not found") return errors.SignatureError("MDC packet not found")
} }
ser.h.Write(ser.trailer[:2]) ser.h.Write(ser.trailer[:2])
final := ser.h.Sum(nil) final := ser.h.Sum(nil)
if subtle.ConstantTimeCompare(final, ser.trailer[2:]) != 1 { if subtle.ConstantTimeCompare(final, ser.trailer[2:]) != 1 {
return error_.SignatureError("hash mismatch") return errors.SignatureError("hash mismatch")
} }
return nil return nil
} }
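seMDCReader.Close above requires the stream to end with an MDC packet (tag byte 0x80|0x40|19 followed by a sha1.Size length) and compares the trailing SHA-1 against the running hash with subtle.ConstantTimeCompare, so the comparison time does not depend on how many bytes matched. The same pattern in isolation, with hypothetical got and want slices:

// Sketch: constant-time comparison of two equal-length digests.
// Assumes import "crypto/subtle".
func digestsEqual(got, want []byte) bool {
	return subtle.ConstantTimeCompare(got, want) == 1
}
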
@ -253,9 +252,9 @@ func (c noOpCloser) Close() error {
// SerializeSymmetricallyEncrypted serializes a symmetrically encrypted packet // SerializeSymmetricallyEncrypted serializes a symmetrically encrypted packet
// to w and returns a WriteCloser to which the to-be-encrypted packets can be // to w and returns a WriteCloser to which the to-be-encrypted packets can be
// written. // written.
func SerializeSymmetricallyEncrypted(w io.Writer, c CipherFunction, key []byte) (contents io.WriteCloser, err error) { func SerializeSymmetricallyEncrypted(w io.Writer, rand io.Reader, c CipherFunction, key []byte) (contents io.WriteCloser, err error) {
if c.KeySize() != len(key) { if c.KeySize() != len(key) {
return nil, error_.InvalidArgumentError("SymmetricallyEncrypted.Serialize: bad key length") return nil, errors.InvalidArgumentError("SymmetricallyEncrypted.Serialize: bad key length")
} }
writeCloser := noOpCloser{w} writeCloser := noOpCloser{w}
ciphertext, err := serializeStreamHeader(writeCloser, packetTypeSymmetricallyEncryptedMDC) ciphertext, err := serializeStreamHeader(writeCloser, packetTypeSymmetricallyEncryptedMDC)
@ -271,7 +270,7 @@ func SerializeSymmetricallyEncrypted(w io.Writer, c CipherFunction, key []byte)
block := c.new(key) block := c.new(key)
blockSize := block.BlockSize() blockSize := block.BlockSize()
iv := make([]byte, blockSize) iv := make([]byte, blockSize)
_, err = rand.Reader.Read(iv) _, err = rand.Read(iv)
if err != nil { if err != nil {
return return
} }

View File

@ -6,7 +6,8 @@ package packet
import ( import (
"bytes" "bytes"
error_ "crypto/openpgp/error" "crypto/openpgp/errors"
"crypto/rand"
"crypto/sha1" "crypto/sha1"
"encoding/hex" "encoding/hex"
"io" "io"
@ -70,7 +71,7 @@ func testMDCReader(t *testing.T) {
err = mdcReader.Close() err = mdcReader.Close()
if err == nil { if err == nil {
t.Error("corruption: no error") t.Error("corruption: no error")
} else if _, ok := err.(*error_.SignatureError); !ok { } else if _, ok := err.(*errors.SignatureError); !ok {
t.Errorf("corruption: expected SignatureError, got: %s", err) t.Errorf("corruption: expected SignatureError, got: %s", err)
} }
} }
@ -82,7 +83,7 @@ func TestSerialize(t *testing.T) {
c := CipherAES128 c := CipherAES128
key := make([]byte, c.KeySize()) key := make([]byte, c.KeySize())
w, err := SerializeSymmetricallyEncrypted(buf, c, key) w, err := SerializeSymmetricallyEncrypted(buf, rand.Reader, c, key)
if err != nil { if err != nil {
t.Errorf("error from SerializeSymmetricallyEncrypted: %s", err) t.Errorf("error from SerializeSymmetricallyEncrypted: %s", err)
return return

View File

@ -8,7 +8,7 @@ package openpgp
import ( import (
"crypto" "crypto"
"crypto/openpgp/armor" "crypto/openpgp/armor"
error_ "crypto/openpgp/error" "crypto/openpgp/errors"
"crypto/openpgp/packet" "crypto/openpgp/packet"
_ "crypto/sha256" _ "crypto/sha256"
"hash" "hash"
@ -27,7 +27,7 @@ func readArmored(r io.Reader, expectedType string) (body io.Reader, err error) {
} }
if block.Type != expectedType { if block.Type != expectedType {
return nil, error_.InvalidArgumentError("expected '" + expectedType + "', got: " + block.Type) return nil, errors.InvalidArgumentError("expected '" + expectedType + "', got: " + block.Type)
} }
return block.Body, nil return block.Body, nil
@ -130,7 +130,7 @@ ParsePackets:
case *packet.Compressed, *packet.LiteralData, *packet.OnePassSignature: case *packet.Compressed, *packet.LiteralData, *packet.OnePassSignature:
// This message isn't encrypted. // This message isn't encrypted.
if len(symKeys) != 0 || len(pubKeys) != 0 { if len(symKeys) != 0 || len(pubKeys) != 0 {
return nil, error_.StructuralError("key material not followed by encrypted message") return nil, errors.StructuralError("key material not followed by encrypted message")
} }
packets.Unread(p) packets.Unread(p)
return readSignedMessage(packets, nil, keyring) return readSignedMessage(packets, nil, keyring)
@ -161,7 +161,7 @@ FindKey:
continue continue
} }
decrypted, err = se.Decrypt(pk.encryptedKey.CipherFunc, pk.encryptedKey.Key) decrypted, err = se.Decrypt(pk.encryptedKey.CipherFunc, pk.encryptedKey.Key)
if err != nil && err != error_.KeyIncorrectError { if err != nil && err != errors.KeyIncorrectError {
return nil, err return nil, err
} }
if decrypted != nil { if decrypted != nil {
@ -179,11 +179,11 @@ FindKey:
} }
if len(candidates) == 0 && len(symKeys) == 0 { if len(candidates) == 0 && len(symKeys) == 0 {
return nil, error_.KeyIncorrectError return nil, errors.KeyIncorrectError
} }
if prompt == nil { if prompt == nil {
return nil, error_.KeyIncorrectError return nil, errors.KeyIncorrectError
} }
passphrase, err := prompt(candidates, len(symKeys) != 0) passphrase, err := prompt(candidates, len(symKeys) != 0)
@ -197,7 +197,7 @@ FindKey:
err = s.Decrypt(passphrase) err = s.Decrypt(passphrase)
if err == nil && !s.Encrypted { if err == nil && !s.Encrypted {
decrypted, err = se.Decrypt(s.CipherFunc, s.Key) decrypted, err = se.Decrypt(s.CipherFunc, s.Key)
if err != nil && err != error_.KeyIncorrectError { if err != nil && err != errors.KeyIncorrectError {
return nil, err return nil, err
} }
if decrypted != nil { if decrypted != nil {
@ -237,7 +237,7 @@ FindLiteralData:
packets.Push(p.Body) packets.Push(p.Body)
case *packet.OnePassSignature: case *packet.OnePassSignature:
if !p.IsLast { if !p.IsLast {
return nil, error_.UnsupportedError("nested signatures") return nil, errors.UnsupportedError("nested signatures")
} }
h, wrappedHash, err = hashForSignature(p.Hash, p.SigType) h, wrappedHash, err = hashForSignature(p.Hash, p.SigType)
@ -281,7 +281,7 @@ FindLiteralData:
func hashForSignature(hashId crypto.Hash, sigType packet.SignatureType) (hash.Hash, hash.Hash, error) { func hashForSignature(hashId crypto.Hash, sigType packet.SignatureType) (hash.Hash, hash.Hash, error) {
h := hashId.New() h := hashId.New()
if h == nil { if h == nil {
return nil, nil, error_.UnsupportedError("hash not available: " + strconv.Itoa(int(hashId))) return nil, nil, errors.UnsupportedError("hash not available: " + strconv.Itoa(int(hashId)))
} }
switch sigType { switch sigType {
@ -291,7 +291,7 @@ func hashForSignature(hashId crypto.Hash, sigType packet.SignatureType) (hash.Ha
return h, NewCanonicalTextHash(h), nil return h, NewCanonicalTextHash(h), nil
} }
return nil, nil, error_.UnsupportedError("unsupported signature type: " + strconv.Itoa(int(sigType))) return nil, nil, errors.UnsupportedError("unsupported signature type: " + strconv.Itoa(int(sigType)))
} }
// checkReader wraps an io.Reader from a LiteralData packet. When it sees EOF // checkReader wraps an io.Reader from a LiteralData packet. When it sees EOF
@ -333,7 +333,7 @@ func (scr *signatureCheckReader) Read(buf []byte) (n int, err error) {
var ok bool var ok bool
if scr.md.Signature, ok = p.(*packet.Signature); !ok { if scr.md.Signature, ok = p.(*packet.Signature); !ok {
scr.md.SignatureError = error_.StructuralError("LiteralData not followed by Signature") scr.md.SignatureError = errors.StructuralError("LiteralData not followed by Signature")
return return
} }
@ -363,16 +363,16 @@ func CheckDetachedSignature(keyring KeyRing, signed, signature io.Reader) (signe
sig, ok := p.(*packet.Signature) sig, ok := p.(*packet.Signature)
if !ok { if !ok {
return nil, error_.StructuralError("non signature packet found") return nil, errors.StructuralError("non signature packet found")
} }
if sig.IssuerKeyId == nil { if sig.IssuerKeyId == nil {
return nil, error_.StructuralError("signature doesn't have an issuer") return nil, errors.StructuralError("signature doesn't have an issuer")
} }
keys := keyring.KeysById(*sig.IssuerKeyId) keys := keyring.KeysById(*sig.IssuerKeyId)
if len(keys) == 0 { if len(keys) == 0 {
return nil, error_.UnknownIssuerError return nil, errors.UnknownIssuerError
} }
h, wrappedHash, err := hashForSignature(sig.Hash, sig.SigType) h, wrappedHash, err := hashForSignature(sig.Hash, sig.SigType)
@ -399,7 +399,7 @@ func CheckDetachedSignature(keyring KeyRing, signed, signature io.Reader) (signe
return return
} }
return nil, error_.UnknownIssuerError return nil, errors.UnknownIssuerError
} }
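CheckDetachedSignature reads signature packets, looks the issuer key id up in the keyring and verifies it against the hash of the signed data, returning errors.UnknownIssuerError when no key matches. A hedged caller sketch (the signer return value, abbreviated in the hunk header above, identifies the entity whose key verified):

// Sketch of a caller; keyring, signedFile and sigFile are hypothetical values.
// Assumes import "io".
func verifyDetached(keyring KeyRing, signedFile, sigFile io.Reader) error {
	_, err := CheckDetachedSignature(keyring, signedFile, sigFile)
	// errors.UnknownIssuerError here means no key in keyring made this signature.
	return err
}
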
// CheckArmoredDetachedSignature performs the same actions as // CheckArmoredDetachedSignature performs the same actions as

View File

@ -6,7 +6,8 @@ package openpgp
import ( import (
"bytes" "bytes"
error_ "crypto/openpgp/error" "crypto/openpgp/errors"
_ "crypto/sha512"
"encoding/hex" "encoding/hex"
"io" "io"
"io/ioutil" "io/ioutil"
@ -77,6 +78,15 @@ func TestReadDSAKey(t *testing.T) {
} }
} }
func TestDSAHashTruncatation(t *testing.T) {
// dsaKeyWithSHA512 was generated with GnuPG and --cert-digest-algo
// SHA512 in order to require DSA hash truncation to verify correctly.
_, err := ReadKeyRing(readerFromHex(dsaKeyWithSHA512))
if err != nil {
t.Error(err)
}
}
func TestGetKeyById(t *testing.T) { func TestGetKeyById(t *testing.T) {
kring, _ := ReadKeyRing(readerFromHex(testKeys1And2Hex)) kring, _ := ReadKeyRing(readerFromHex(testKeys1And2Hex))
@ -151,18 +161,18 @@ func TestSignedEncryptedMessage(t *testing.T) {
prompt := func(keys []Key, symmetric bool) ([]byte, error) { prompt := func(keys []Key, symmetric bool) ([]byte, error) {
if symmetric { if symmetric {
t.Errorf("prompt: message was marked as symmetrically encrypted") t.Errorf("prompt: message was marked as symmetrically encrypted")
return nil, error_.KeyIncorrectError return nil, errors.KeyIncorrectError
} }
if len(keys) == 0 { if len(keys) == 0 {
t.Error("prompt: no keys requested") t.Error("prompt: no keys requested")
return nil, error_.KeyIncorrectError return nil, errors.KeyIncorrectError
} }
err := keys[0].PrivateKey.Decrypt([]byte("passphrase")) err := keys[0].PrivateKey.Decrypt([]byte("passphrase"))
if err != nil { if err != nil {
t.Errorf("prompt: error decrypting key: %s", err) t.Errorf("prompt: error decrypting key: %s", err)
return nil, error_.KeyIncorrectError return nil, errors.KeyIncorrectError
} }
return nil, nil return nil, nil
@ -286,7 +296,7 @@ func TestReadingArmoredPrivateKey(t *testing.T) {
func TestNoArmoredData(t *testing.T) { func TestNoArmoredData(t *testing.T) {
_, err := ReadArmoredKeyRing(bytes.NewBufferString("foo")) _, err := ReadArmoredKeyRing(bytes.NewBufferString("foo"))
if _, ok := err.(error_.InvalidArgumentError); !ok { if _, ok := err.(errors.InvalidArgumentError); !ok {
t.Errorf("error was not an InvalidArgumentError: %s", err) t.Errorf("error was not an InvalidArgumentError: %s", err)
} }
} }
@ -358,3 +368,5 @@ AHcVnXjtxrULkQFGbGvhKURLvS9WnzD/m1K2zzwxzkPTzT9/Yf06O6Mal5AdugPL
VrM0m72/jnpKo04= VrM0m72/jnpKo04=
=zNCn =zNCn
-----END PGP PRIVATE KEY BLOCK-----` -----END PGP PRIVATE KEY BLOCK-----`
const dsaKeyWithSHA512 = `9901a2044f04b07f110400db244efecc7316553ee08d179972aab87bb1214de7692593fcf5b6feb1c80fba268722dd464748539b85b81d574cd2d7ad0ca2444de4d849b8756bad7768c486c83a824f9bba4af773d11742bdfb4ac3b89ef8cc9452d4aad31a37e4b630d33927bff68e879284a1672659b8b298222fc68f370f3e24dccacc4a862442b9438b00a0ea444a24088dc23e26df7daf8f43cba3bffc4fe703fe3d6cd7fdca199d54ed8ae501c30e3ec7871ea9cdd4cf63cfe6fc82281d70a5b8bb493f922cd99fba5f088935596af087c8d818d5ec4d0b9afa7f070b3d7c1dd32a84fca08d8280b4890c8da1dde334de8e3cad8450eed2a4a4fcc2db7b8e5528b869a74a7f0189e11ef097ef1253582348de072bb07a9fa8ab838e993cef0ee203ff49298723e2d1f549b00559f886cd417a41692ce58d0ac1307dc71d85a8af21b0cf6eaa14baf2922d3a70389bedf17cc514ba0febbd107675a372fe84b90162a9e88b14d4b1c6be855b96b33fb198c46f058568817780435b6936167ebb3724b680f32bf27382ada2e37a879b3d9de2abe0c3f399350afd1ad438883f4791e2e3b4184453412068617368207472756e636174696f6e207465737488620413110a002205024f04b07f021b03060b090807030206150802090a0b0416020301021e01021780000a0910ef20e0cefca131581318009e2bf3bf047a44d75a9bacd00161ee04d435522397009a03a60d51bd8a568c6c021c8d7cf1be8d990d6417b0020003`

View File

@ -8,7 +8,7 @@ package s2k
import ( import (
"crypto" "crypto"
error_ "crypto/openpgp/error" "crypto/openpgp/errors"
"hash" "hash"
"io" "io"
"strconv" "strconv"
@ -89,11 +89,11 @@ func Parse(r io.Reader) (f func(out, in []byte), err error) {
hash, ok := HashIdToHash(buf[1]) hash, ok := HashIdToHash(buf[1])
if !ok { if !ok {
return nil, error_.UnsupportedError("hash for S2K function: " + strconv.Itoa(int(buf[1]))) return nil, errors.UnsupportedError("hash for S2K function: " + strconv.Itoa(int(buf[1])))
} }
h := hash.New() h := hash.New()
if h == nil { if h == nil {
return nil, error_.UnsupportedError("hash not available: " + strconv.Itoa(int(hash))) return nil, errors.UnsupportedError("hash not available: " + strconv.Itoa(int(hash)))
} }
switch buf[0] { switch buf[0] {
@ -123,7 +123,7 @@ func Parse(r io.Reader) (f func(out, in []byte), err error) {
return f, nil return f, nil
} }
return nil, error_.UnsupportedError("S2K function") return nil, errors.UnsupportedError("S2K function")
} }
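Parse reads a string-to-key specifier and returns a closure that stretches a passphrase into key material using the hash the specifier selects; the callers above pass the destination key buffer as out and the passphrase as in. A minimal usage sketch with hypothetical r, passphrase and keySize values:

// Sketch: deriving a session key from a passphrase via a parsed S2K specifier.
// Assumes import "io"; r is positioned at the S2K bytes.
func deriveKey(r io.Reader, passphrase []byte, keySize int) ([]byte, error) {
	f, err := Parse(r)
	if err != nil {
		return nil, err
	}
	key := make([]byte, keySize)
	f(key, passphrase) // stretch the passphrase with the specifier's hash and salt
	return key, nil
}
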
// Serialize salts and stretches the given passphrase and writes the resulting // Serialize salts and stretches the given passphrase and writes the resulting

View File

@ -7,7 +7,7 @@ package openpgp
import ( import (
"crypto" "crypto"
"crypto/openpgp/armor" "crypto/openpgp/armor"
error_ "crypto/openpgp/error" "crypto/openpgp/errors"
"crypto/openpgp/packet" "crypto/openpgp/packet"
"crypto/openpgp/s2k" "crypto/openpgp/s2k"
"crypto/rand" "crypto/rand"
@ -58,10 +58,10 @@ func armoredDetachSign(w io.Writer, signer *Entity, message io.Reader, sigType p
func detachSign(w io.Writer, signer *Entity, message io.Reader, sigType packet.SignatureType) (err error) { func detachSign(w io.Writer, signer *Entity, message io.Reader, sigType packet.SignatureType) (err error) {
if signer.PrivateKey == nil { if signer.PrivateKey == nil {
return error_.InvalidArgumentError("signing key doesn't have a private key") return errors.InvalidArgumentError("signing key doesn't have a private key")
} }
if signer.PrivateKey.Encrypted { if signer.PrivateKey.Encrypted {
return error_.InvalidArgumentError("signing key is encrypted") return errors.InvalidArgumentError("signing key is encrypted")
} }
sig := new(packet.Signature) sig := new(packet.Signature)
@ -77,7 +77,7 @@ func detachSign(w io.Writer, signer *Entity, message io.Reader, sigType packet.S
} }
io.Copy(wrappedHash, message) io.Copy(wrappedHash, message)
err = sig.Sign(h, signer.PrivateKey) err = sig.Sign(rand.Reader, h, signer.PrivateKey)
if err != nil { if err != nil {
return return
} }
@ -111,7 +111,7 @@ func SymmetricallyEncrypt(ciphertext io.Writer, passphrase []byte, hints *FileHi
if err != nil { if err != nil {
return return
} }
w, err := packet.SerializeSymmetricallyEncrypted(ciphertext, packet.CipherAES128, key) w, err := packet.SerializeSymmetricallyEncrypted(ciphertext, rand.Reader, packet.CipherAES128, key)
if err != nil { if err != nil {
return return
} }
@ -156,7 +156,7 @@ func Encrypt(ciphertext io.Writer, to []*Entity, signed *Entity, hints *FileHint
if signed != nil { if signed != nil {
signer = signed.signingKey().PrivateKey signer = signed.signingKey().PrivateKey
if signer == nil || signer.Encrypted { if signer == nil || signer.Encrypted {
return nil, error_.InvalidArgumentError("signing key must be decrypted") return nil, errors.InvalidArgumentError("signing key must be decrypted")
} }
} }
@ -183,7 +183,7 @@ func Encrypt(ciphertext io.Writer, to []*Entity, signed *Entity, hints *FileHint
for i := range to { for i := range to {
encryptKeys[i] = to[i].encryptionKey() encryptKeys[i] = to[i].encryptionKey()
if encryptKeys[i].PublicKey == nil { if encryptKeys[i].PublicKey == nil {
return nil, error_.InvalidArgumentError("cannot encrypt a message to key id " + strconv.FormatUint(to[i].PrimaryKey.KeyId, 16) + " because it has no encryption keys") return nil, errors.InvalidArgumentError("cannot encrypt a message to key id " + strconv.FormatUint(to[i].PrimaryKey.KeyId, 16) + " because it has no encryption keys")
} }
sig := to[i].primaryIdentity().SelfSignature sig := to[i].primaryIdentity().SelfSignature
@ -201,7 +201,7 @@ func Encrypt(ciphertext io.Writer, to []*Entity, signed *Entity, hints *FileHint
} }
if len(candidateCiphers) == 0 || len(candidateHashes) == 0 { if len(candidateCiphers) == 0 || len(candidateHashes) == 0 {
return nil, error_.InvalidArgumentError("cannot encrypt because recipient set shares no common algorithms") return nil, errors.InvalidArgumentError("cannot encrypt because recipient set shares no common algorithms")
} }
cipher := packet.CipherFunction(candidateCiphers[0]) cipher := packet.CipherFunction(candidateCiphers[0])
@ -217,7 +217,7 @@ func Encrypt(ciphertext io.Writer, to []*Entity, signed *Entity, hints *FileHint
} }
} }
encryptedData, err := packet.SerializeSymmetricallyEncrypted(ciphertext, cipher, symKey) encryptedData, err := packet.SerializeSymmetricallyEncrypted(ciphertext, rand.Reader, cipher, symKey)
if err != nil { if err != nil {
return return
} }
@ -287,7 +287,7 @@ func (s signatureWriter) Close() error {
IssuerKeyId: &s.signer.KeyId, IssuerKeyId: &s.signer.KeyId,
} }
if err := sig.Sign(s.h, s.signer); err != nil { if err := sig.Sign(rand.Reader, s.h, s.signer); err != nil {
return err return err
} }
if err := s.literalData.Close(); err != nil { if err := s.literalData.Close(); err != nil {

View File

@ -222,7 +222,7 @@ func TestEncryption(t *testing.T) {
if test.isSigned { if test.isSigned {
if md.SignatureError != nil { if md.SignatureError != nil {
t.Errorf("#%d: signature error: %s", i, err) t.Errorf("#%d: signature error: %s", i, md.SignatureError)
} }
if md.Signature == nil { if md.Signature == nil {
t.Error("signature missing") t.Error("signature missing")

View File

@ -111,6 +111,18 @@ type ConnectionState struct {
VerifiedChains [][]*x509.Certificate VerifiedChains [][]*x509.Certificate
} }
// ClientAuthType declares the policy the server will follow for
// TLS Client Authentication.
type ClientAuthType int
const (
NoClientCert ClientAuthType = iota
RequestClientCert
RequireAnyClientCert
VerifyClientCertIfGiven
RequireAndVerifyClientCert
)
// A Config structure is used to configure a TLS client or server. After one // A Config structure is used to configure a TLS client or server. After one
// has been passed to a TLS function it must not be modified. // has been passed to a TLS function it must not be modified.
type Config struct { type Config struct {
@ -120,7 +132,7 @@ type Config struct {
Rand io.Reader Rand io.Reader
// Time returns the current time as the number of seconds since the epoch. // Time returns the current time as the number of seconds since the epoch.
// If Time is nil, TLS uses the system time.Seconds. // If Time is nil, TLS uses time.Now.
Time func() time.Time Time func() time.Time
// Certificates contains one or more certificate chains // Certificates contains one or more certificate chains
@ -148,11 +160,14 @@ type Config struct {
// hosting. // hosting.
ServerName string ServerName string
// AuthenticateClient controls whether a server will request a certificate // ClientAuth determines the server's policy for
// from the client. It does not require that the client send a // TLS Client Authentication. The default is NoClientCert.
// certificate nor does it require that the certificate sent be ClientAuth ClientAuthType
// anything more than self-signed.
AuthenticateClient bool // ClientCAs defines the set of root certificate authorities
// that servers use if required to verify a client certificate
// by the policy in ClientAuth.
ClientCAs *x509.CertPool
// InsecureSkipVerify controls whether a client verifies the // InsecureSkipVerify controls whether a client verifies the
// server's certificate chain and host name. // server's certificate chain and host name.
@ -259,6 +274,11 @@ type Certificate struct {
// OCSPStaple contains an optional OCSP response which will be served // OCSPStaple contains an optional OCSP response which will be served
// to clients that request it. // to clients that request it.
OCSPStaple []byte OCSPStaple []byte
// Leaf is the parsed form of the leaf certificate, which may be
// initialized using x509.ParseCertificate to reduce per-handshake
// processing for TLS clients doing client authentication. If nil, the
// leaf certificate will be parsed as needed.
Leaf *x509.Certificate
} }
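With ClientAuth and ClientCAs a server can demand and verify client certificates, and a client doing client authentication can pre-parse Certificate.Leaf to avoid reparsing its certificate on every handshake. A hedged configuration sketch from the point of view of a user of the package, where caPEM and cert are hypothetical values loaded elsewhere:

// Sketch of a server configuration using the new fields.
// Assumes imports "crypto/tls" and "crypto/x509".
func serverConfig(caPEM []byte, cert tls.Certificate) *tls.Config {
	pool := x509.NewCertPool()
	pool.AppendCertsFromPEM(caPEM) // ok result ignored for brevity
	return &tls.Config{
		Certificates: []tls.Certificate{cert},
		ClientAuth:   tls.RequireAndVerifyClientCert, // one of the new ClientAuthType values
		ClientCAs:    pool,                           // roots used to verify the client's chain
	}
}

On the client side, cert.Leaf may be filled in once with x509.ParseCertificate(cert.Certificate[0]), as the Leaf field's comment above suggests.
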
// A TLS record. // A TLS record.

View File

@ -31,7 +31,7 @@ func main() {
return return
} }
now := time.Seconds() now := time.Now()
template := x509.Certificate{ template := x509.Certificate{
SerialNumber: new(big.Int).SetInt64(0), SerialNumber: new(big.Int).SetInt64(0),
@ -39,8 +39,8 @@ func main() {
CommonName: *hostName, CommonName: *hostName,
Organization: []string{"Acme Co"}, Organization: []string{"Acme Co"},
}, },
NotBefore: time.SecondsToUTC(now - 300), NotBefore: now.Add(-5 * time.Minute).UTC(),
NotAfter: time.SecondsToUTC(now + 60*60*24*365), // valid for 1 year. NotAfter: now.AddDate(1, 0, 0).UTC(), // valid for 1 year.
SubjectKeyId: []byte{1, 2, 3, 4}, SubjectKeyId: []byte{1, 2, 3, 4},
KeyUsage: x509.KeyUsageKeyEncipherment | x509.KeyUsageDigitalSignature, KeyUsage: x509.KeyUsageKeyEncipherment | x509.KeyUsageDigitalSignature,

View File

@ -5,12 +5,14 @@
package tls package tls
import ( import (
"bytes"
"crypto" "crypto"
"crypto/rsa" "crypto/rsa"
"crypto/subtle" "crypto/subtle"
"crypto/x509" "crypto/x509"
"errors" "errors"
"io" "io"
"strconv"
) )
func (c *Conn) clientHandshake() error { func (c *Conn) clientHandshake() error {
@ -162,10 +164,23 @@ func (c *Conn) clientHandshake() error {
} }
} }
transmitCert := false var certToSend *Certificate
certReq, ok := msg.(*certificateRequestMsg) certReq, ok := msg.(*certificateRequestMsg)
if ok { if ok {
// We only accept certificates with RSA keys. // RFC 4346 on the certificateAuthorities field:
// A list of the distinguished names of acceptable certificate
// authorities. These distinguished names may specify a desired
// distinguished name for a root CA or for a subordinate CA;
// thus, this message can be used to describe both known roots
// and a desired authorization space. If the
// certificate_authorities list is empty then the client MAY
// send any certificate of the appropriate
// ClientCertificateType, unless there is some external
// arrangement to the contrary.
finishedHash.Write(certReq.marshal())
// For now, we only know how to sign challenges with RSA
rsaAvail := false rsaAvail := false
for _, certType := range certReq.certificateTypes { for _, certType := range certReq.certificateTypes {
if certType == certTypeRSASign { if certType == certTypeRSASign {
@ -174,23 +189,41 @@ func (c *Conn) clientHandshake() error {
} }
} }
// For now, only send a certificate back if the server gives us an // We need to search our list of client certs for one
// empty list of certificateAuthorities. // where SignatureAlgorithm is RSA and the Issuer is in
// // certReq.certificateAuthorities
// RFC 4346 on the certificateAuthorities field: findCert:
// A list of the distinguished names of acceptable certificate for i, cert := range c.config.Certificates {
// authorities. These distinguished names may specify a desired if !rsaAvail {
// distinguished name for a root CA or for a subordinate CA; thus, continue
// this message can be used to describe both known roots and a }
// desired authorization space. If the certificate_authorities
// list is empty then the client MAY send any certificate of the
// appropriate ClientCertificateType, unless there is some
// external arrangement to the contrary.
if rsaAvail && len(certReq.certificateAuthorities) == 0 {
transmitCert = true
}
finishedHash.Write(certReq.marshal()) leaf := cert.Leaf
if leaf == nil {
if leaf, err = x509.ParseCertificate(cert.Certificate[0]); err != nil {
c.sendAlert(alertInternalError)
return errors.New("tls: failed to parse client certificate #" + strconv.Itoa(i) + ": " + err.Error())
}
}
if leaf.PublicKeyAlgorithm != x509.RSA {
continue
}
if len(certReq.certificateAuthorities) == 0 {
// they gave us an empty list, so just take the
// first RSA cert from c.config.Certificates
certToSend = &cert
break
}
for _, ca := range certReq.certificateAuthorities {
if bytes.Equal(leaf.RawIssuer, ca) {
certToSend = &cert
break findCert
}
}
}
msg, err = c.readHandshake() msg, err = c.readHandshake()
if err != nil { if err != nil {
@ -204,17 +237,9 @@ func (c *Conn) clientHandshake() error {
} }
finishedHash.Write(shd.marshal()) finishedHash.Write(shd.marshal())
var cert *x509.Certificate if certToSend != nil {
if transmitCert {
certMsg = new(certificateMsg) certMsg = new(certificateMsg)
if len(c.config.Certificates) > 0 { certMsg.certificates = certToSend.Certificate
cert, err = x509.ParseCertificate(c.config.Certificates[0].Certificate[0])
if err == nil && cert.PublicKeyAlgorithm == x509.RSA {
certMsg.certificates = c.config.Certificates[0].Certificate
} else {
cert = nil
}
}
finishedHash.Write(certMsg.marshal()) finishedHash.Write(certMsg.marshal())
c.writeRecord(recordTypeHandshake, certMsg.marshal()) c.writeRecord(recordTypeHandshake, certMsg.marshal())
} }
@ -229,7 +254,7 @@ func (c *Conn) clientHandshake() error {
c.writeRecord(recordTypeHandshake, ckx.marshal()) c.writeRecord(recordTypeHandshake, ckx.marshal())
} }
if cert != nil { if certToSend != nil {
certVerify := new(certificateVerifyMsg) certVerify := new(certificateVerifyMsg)
digest := make([]byte, 0, 36) digest := make([]byte, 0, 36)
digest = finishedHash.serverMD5.Sum(digest) digest = finishedHash.serverMD5.Sum(digest)

View File

@ -881,9 +881,11 @@ func (m *certificateRequestMsg) marshal() (x []byte) {
// See http://tools.ietf.org/html/rfc4346#section-7.4.4 // See http://tools.ietf.org/html/rfc4346#section-7.4.4
length := 1 + len(m.certificateTypes) + 2 length := 1 + len(m.certificateTypes) + 2
casLength := 0
for _, ca := range m.certificateAuthorities { for _, ca := range m.certificateAuthorities {
length += 2 + len(ca) casLength += 2 + len(ca)
} }
length += casLength
x = make([]byte, 4+length) x = make([]byte, 4+length)
x[0] = typeCertificateRequest x[0] = typeCertificateRequest
@ -895,10 +897,8 @@ func (m *certificateRequestMsg) marshal() (x []byte) {
copy(x[5:], m.certificateTypes) copy(x[5:], m.certificateTypes)
y := x[5+len(m.certificateTypes):] y := x[5+len(m.certificateTypes):]
y[0] = uint8(casLength >> 8)
numCA := len(m.certificateAuthorities) y[1] = uint8(casLength)
y[0] = uint8(numCA >> 8)
y[1] = uint8(numCA)
y = y[2:] y = y[2:]
for _, ca := range m.certificateAuthorities { for _, ca := range m.certificateAuthorities {
y[0] = uint8(len(ca) >> 8) y[0] = uint8(len(ca) >> 8)
@ -909,7 +909,6 @@ func (m *certificateRequestMsg) marshal() (x []byte) {
} }
m.raw = x m.raw = x
return return
} }
@ -937,31 +936,34 @@ func (m *certificateRequestMsg) unmarshal(data []byte) bool {
} }
data = data[numCertTypes:] data = data[numCertTypes:]
if len(data) < 2 { if len(data) < 2 {
return false return false
} }
casLength := uint16(data[0])<<8 | uint16(data[1])
numCAs := uint16(data[0])<<16 | uint16(data[1])
data = data[2:] data = data[2:]
if len(data) < int(casLength) {
m.certificateAuthorities = make([][]byte, numCAs) return false
for i := uint16(0); i < numCAs; i++ {
if len(data) < 2 {
return false
}
caLen := uint16(data[0])<<16 | uint16(data[1])
data = data[2:]
if len(data) < int(caLen) {
return false
}
ca := make([]byte, caLen)
copy(ca, data)
m.certificateAuthorities[i] = ca
data = data[caLen:]
} }
cas := make([]byte, casLength)
copy(cas, data)
data = data[casLength:]
m.certificateAuthorities = nil
for len(cas) > 0 {
if len(cas) < 2 {
return false
}
caLen := uint16(cas[0])<<8 | uint16(cas[1])
cas = cas[2:]
if len(cas) < int(caLen) {
return false
}
m.certificateAuthorities = append(m.certificateAuthorities, cas[:caLen])
cas = cas[caLen:]
}
if len(data) > 0 { if len(data) > 0 {
return false return false
} }
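The marshal and unmarshal changes above align certificateRequestMsg with RFC 4346 section 7.4.4: the certificate_authorities vector is prefixed by its total length in bytes (the old marshal code wrote the number of names), and each DER-encoded DistinguishedName inside carries its own two-byte length. A sketch of that layout, where cas is a hypothetical [][]byte of DER-encoded names:

// Sketch: encoding certificate_authorities per RFC 4346, section 7.4.4.
func marshalCAs(cas [][]byte) []byte {
	total := 0
	for _, ca := range cas {
		total += 2 + len(ca)
	}
	out := make([]byte, 0, 2+total)
	out = append(out, byte(total>>8), byte(total)) // length of the whole list, in bytes
	for _, ca := range cas {
		out = append(out, byte(len(ca)>>8), byte(len(ca))) // per-name length
		out = append(out, ca...)
	}
	return out
}
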

View File

@ -150,14 +150,19 @@ FindCipherSuite:
c.writeRecord(recordTypeHandshake, skx.marshal()) c.writeRecord(recordTypeHandshake, skx.marshal())
} }
if config.AuthenticateClient { if config.ClientAuth >= RequestClientCert {
// Request a client certificate // Request a client certificate
certReq := new(certificateRequestMsg) certReq := new(certificateRequestMsg)
certReq.certificateTypes = []byte{certTypeRSASign} certReq.certificateTypes = []byte{certTypeRSASign}
// An empty list of certificateAuthorities signals to // An empty list of certificateAuthorities signals to
// the client that it may send any certificate in response // the client that it may send any certificate in response
// to our request. // to our request. When we know the CAs we trust, then
// we can send them down, so that the client can choose
// an appropriate certificate to give to us.
if config.ClientCAs != nil {
certReq.certificateAuthorities = config.ClientCAs.Subjects()
}
finishedHash.Write(certReq.marshal()) finishedHash.Write(certReq.marshal())
c.writeRecord(recordTypeHandshake, certReq.marshal()) c.writeRecord(recordTypeHandshake, certReq.marshal())
} }
@ -166,52 +171,87 @@ FindCipherSuite:
finishedHash.Write(helloDone.marshal()) finishedHash.Write(helloDone.marshal())
c.writeRecord(recordTypeHandshake, helloDone.marshal()) c.writeRecord(recordTypeHandshake, helloDone.marshal())
var pub *rsa.PublicKey var pub *rsa.PublicKey // public key for client auth, if any
if config.AuthenticateClient {
// Get client certificate
msg, err = c.readHandshake()
if err != nil {
return err
}
certMsg, ok = msg.(*certificateMsg)
if !ok {
return c.sendAlert(alertUnexpectedMessage)
}
finishedHash.Write(certMsg.marshal())
certs := make([]*x509.Certificate, len(certMsg.certificates))
for i, asn1Data := range certMsg.certificates {
cert, err := x509.ParseCertificate(asn1Data)
if err != nil {
c.sendAlert(alertBadCertificate)
return errors.New("could not parse client's certificate: " + err.Error())
}
certs[i] = cert
}
// TODO(agl): do better validation of certs: max path length, name restrictions etc.
for i := 1; i < len(certs); i++ {
if err := certs[i-1].CheckSignatureFrom(certs[i]); err != nil {
c.sendAlert(alertBadCertificate)
return errors.New("could not validate certificate signature: " + err.Error())
}
}
if len(certs) > 0 {
key, ok := certs[0].PublicKey.(*rsa.PublicKey)
if !ok {
return c.sendAlert(alertUnsupportedCertificate)
}
pub = key
c.peerCertificates = certs
}
}
// Get client key exchange
msg, err = c.readHandshake() msg, err = c.readHandshake()
if err != nil { if err != nil {
return err return err
} }
// If we requested a client certificate, then the client must send a
// certificate message, even if it's empty.
if config.ClientAuth >= RequestClientCert {
if certMsg, ok = msg.(*certificateMsg); !ok {
return c.sendAlert(alertHandshakeFailure)
}
finishedHash.Write(certMsg.marshal())
if len(certMsg.certificates) == 0 {
// The client didn't actually send a certificate
switch config.ClientAuth {
case RequireAnyClientCert, RequireAndVerifyClientCert:
c.sendAlert(alertBadCertificate)
return errors.New("tls: client didn't provide a certificate")
}
}
certs := make([]*x509.Certificate, len(certMsg.certificates))
for i, asn1Data := range certMsg.certificates {
if certs[i], err = x509.ParseCertificate(asn1Data); err != nil {
c.sendAlert(alertBadCertificate)
return errors.New("tls: failed to parse client certificate: " + err.Error())
}
}
if c.config.ClientAuth >= VerifyClientCertIfGiven && len(certs) > 0 {
opts := x509.VerifyOptions{
Roots: c.config.ClientCAs,
CurrentTime: c.config.time(),
Intermediates: x509.NewCertPool(),
}
for i, cert := range certs {
if i == 0 {
continue
}
opts.Intermediates.AddCert(cert)
}
chains, err := certs[0].Verify(opts)
if err != nil {
c.sendAlert(alertBadCertificate)
return errors.New("tls: failed to verify client's certificate: " + err.Error())
}
ok := false
for _, ku := range certs[0].ExtKeyUsage {
if ku == x509.ExtKeyUsageClientAuth {
ok = true
break
}
}
if !ok {
c.sendAlert(alertHandshakeFailure)
return errors.New("tls: client's certificate's extended key usage doesn't permit it to be used for client authentication")
}
c.verifiedChains = chains
}
if len(certs) > 0 {
if pub, ok = certs[0].PublicKey.(*rsa.PublicKey); !ok {
return c.sendAlert(alertUnsupportedCertificate)
}
c.peerCertificates = certs
}
msg, err = c.readHandshake()
if err != nil {
return err
}
}
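A client needs no new configuration to satisfy this: populating Config.Certificates is enough, and the updated clientHandshake earlier in this change selects an RSA certificate whose issuer appears in the server's certificate_authorities list (or the first RSA certificate when that list is empty). A hedged client sketch with hypothetical addr and cert values:

// Sketch of a client offering a certificate for client authentication.
// Assumes import "crypto/tls".
func dialWithClientCert(addr string, cert tls.Certificate) (*tls.Conn, error) {
	return tls.Dial("tcp", addr, &tls.Config{
		Certificates: []tls.Certificate{cert},
	})
}
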
// Get client key exchange
ckx, ok := msg.(*clientKeyExchangeMsg) ckx, ok := msg.(*clientKeyExchangeMsg)
if !ok { if !ok {
return c.sendAlert(alertUnexpectedMessage) return c.sendAlert(alertUnexpectedMessage)

View File

@ -7,9 +7,12 @@ package tls
import ( import (
"bytes" "bytes"
"crypto/rsa" "crypto/rsa"
"crypto/x509"
"encoding/hex" "encoding/hex"
"encoding/pem"
"flag" "flag"
"io" "io"
"log"
"math/big" "math/big"
"net" "net"
"strconv" "strconv"
@ -109,16 +112,18 @@ func TestClose(t *testing.T) {
} }
} }
func testServerScript(t *testing.T, name string, serverScript [][]byte, config *Config) { func testServerScript(t *testing.T, name string, serverScript [][]byte, config *Config, peers []*x509.Certificate) {
c, s := net.Pipe() c, s := net.Pipe()
srv := Server(s, config) srv := Server(s, config)
pchan := make(chan []*x509.Certificate, 1)
go func() { go func() {
srv.Write([]byte("hello, world\n")) srv.Write([]byte("hello, world\n"))
srv.Close() srv.Close()
s.Close() s.Close()
st := srv.ConnectionState()
pchan <- st.PeerCertificates
}() }()
defer c.Close()
for i, b := range serverScript { for i, b := range serverScript {
if i%2 == 0 { if i%2 == 0 {
c.Write(b) c.Write(b)
@ -133,34 +138,66 @@ func testServerScript(t *testing.T, name string, serverScript [][]byte, config *
t.Fatalf("%s #%d: mismatch on read: got:%x want:%x", name, i, bb, b) t.Fatalf("%s #%d: mismatch on read: got:%x want:%x", name, i, bb, b)
} }
} }
c.Close()
if peers != nil {
gotpeers := <-pchan
if len(peers) == len(gotpeers) {
for i, _ := range peers {
if !peers[i].Equal(gotpeers[i]) {
t.Fatalf("%s: mismatch on peer cert %d", name, i)
}
}
} else {
t.Fatalf("%s: mismatch on peer list length: %d (wanted) != %d (got)", name, len(peers), len(gotpeers))
}
}
} }
func TestHandshakeServerRC4(t *testing.T) { func TestHandshakeServerRC4(t *testing.T) {
testServerScript(t, "RC4", rc4ServerScript, testConfig) testServerScript(t, "RC4", rc4ServerScript, testConfig, nil)
} }
func TestHandshakeServer3DES(t *testing.T) { func TestHandshakeServer3DES(t *testing.T) {
des3Config := new(Config) des3Config := new(Config)
*des3Config = *testConfig *des3Config = *testConfig
des3Config.CipherSuites = []uint16{TLS_RSA_WITH_3DES_EDE_CBC_SHA} des3Config.CipherSuites = []uint16{TLS_RSA_WITH_3DES_EDE_CBC_SHA}
testServerScript(t, "3DES", des3ServerScript, des3Config) testServerScript(t, "3DES", des3ServerScript, des3Config, nil)
} }
func TestHandshakeServerAES(t *testing.T) { func TestHandshakeServerAES(t *testing.T) {
aesConfig := new(Config) aesConfig := new(Config)
*aesConfig = *testConfig *aesConfig = *testConfig
aesConfig.CipherSuites = []uint16{TLS_RSA_WITH_AES_128_CBC_SHA} aesConfig.CipherSuites = []uint16{TLS_RSA_WITH_AES_128_CBC_SHA}
testServerScript(t, "AES", aesServerScript, aesConfig) testServerScript(t, "AES", aesServerScript, aesConfig, nil)
} }
func TestHandshakeServerSSLv3(t *testing.T) { func TestHandshakeServerSSLv3(t *testing.T) {
testServerScript(t, "SSLv3", sslv3ServerScript, testConfig) testServerScript(t, "SSLv3", sslv3ServerScript, testConfig, nil)
}
type clientauthTest struct {
name string
clientauth ClientAuthType
peers []*x509.Certificate
script [][]byte
}
func TestClientAuth(t *testing.T) {
for _, cat := range clientauthTests {
t.Log("running", cat.name)
cfg := new(Config)
*cfg = *testConfig
cfg.ClientAuth = cat.clientauth
testServerScript(t, cat.name, cat.script, cfg, cat.peers)
}
} }
var serve = flag.Bool("serve", false, "run a TLS server on :10443") var serve = flag.Bool("serve", false, "run a TLS server on :10443")
var testCipherSuites = flag.String("ciphersuites", var testCipherSuites = flag.String("ciphersuites",
"0x"+strconv.FormatInt(int64(TLS_RSA_WITH_RC4_128_SHA), 16), "0x"+strconv.FormatInt(int64(TLS_RSA_WITH_RC4_128_SHA), 16),
"cipher suites to accept in serving mode") "cipher suites to accept in serving mode")
var testClientAuth = flag.Int("clientauth", 0, "value for tls.Config.ClientAuth")
func TestRunServer(t *testing.T) { func TestRunServer(t *testing.T) {
if !*serve { if !*serve {
@ -177,6 +214,8 @@ func TestRunServer(t *testing.T) {
testConfig.CipherSuites[i] = uint16(suite) testConfig.CipherSuites[i] = uint16(suite)
} }
testConfig.ClientAuth = ClientAuthType(*testClientAuth)
l, err := Listen("tcp", ":10443", testConfig) l, err := Listen("tcp", ":10443", testConfig)
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
@ -185,13 +224,23 @@ func TestRunServer(t *testing.T) {
for { for {
c, err := l.Accept() c, err := l.Accept()
if err != nil { if err != nil {
log.Printf("error from TLS handshake: %s", err)
break break
} }
_, err = c.Write([]byte("hello, world\n")) _, err = c.Write([]byte("hello, world\n"))
if err != nil { if err != nil {
t.Errorf("error from TLS: %s", err) log.Printf("error from TLS: %s", err)
break continue
} }
st := c.(*Conn).ConnectionState()
if len(st.PeerCertificates) > 0 {
log.Print("Handling request from client ", st.PeerCertificates[0].Subject.CommonName)
} else {
log.Print("Handling request from anon client")
}
c.Close() c.Close()
} }
} }
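The client-auth plumbing added above is easiest to see from a standalone server. The following is a minimal, untested sketch (not part of this patch) of a server that requests a client certificate and reads it back through ConnectionState, as TestRunServer now does; the certificate paths and the port are placeholders.

package main

import (
	"crypto/tls"
	"log"
)

func main() {
	// Placeholder key pair; any server certificate/key in PEM form will do.
	cert, err := tls.LoadX509KeyPair("cert.pem", "key.pem")
	if err != nil {
		log.Fatal(err)
	}
	config := &tls.Config{
		Certificates: []tls.Certificate{cert},
		ClientAuth:   tls.RequestClientCert, // ask for, but do not require, a client cert
	}
	l, err := tls.Listen("tcp", ":10443", config)
	if err != nil {
		log.Fatal(err)
	}
	for {
		c, err := l.Accept()
		if err != nil {
			log.Printf("accept: %v", err)
			continue
		}
		tc := c.(*tls.Conn)
		if err := tc.Handshake(); err != nil {
			log.Printf("handshake: %v", err)
			c.Close()
			continue
		}
		if certs := tc.ConnectionState().PeerCertificates; len(certs) > 0 {
			log.Print("client cert CN: ", certs[0].Subject.CommonName)
		} else {
			log.Print("anonymous client")
		}
		c.Close()
	}
}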
@ -221,6 +270,18 @@ var testPrivateKey = &rsa.PrivateKey{
}, },
} }
func loadPEMCert(in string) *x509.Certificate {
block, _ := pem.Decode([]byte(in))
if block.Type == "CERTIFICATE" && len(block.Headers) == 0 {
cert, err := x509.ParseCertificate(block.Bytes)
if err == nil {
return cert
}
panic("error parsing cert")
}
panic("error parsing PEM")
}
// Script of interaction with gnutls implementation. // Script of interaction with gnutls implementation.
// The values for this test are obtained by building and running in server mode: // The values for this test are obtained by building and running in server mode:
// % gotest -test.run "TestRunServer" -serve // % gotest -test.run "TestRunServer" -serve
@ -229,23 +290,22 @@ var testPrivateKey = &rsa.PrivateKey{
// % python parse-gnutls-cli-debug-log.py < /tmp/log // % python parse-gnutls-cli-debug-log.py < /tmp/log
var rc4ServerScript = [][]byte{ var rc4ServerScript = [][]byte{
{ {
0x16, 0x03, 0x02, 0x00, 0x7f, 0x01, 0x00, 0x00, 0x16, 0x03, 0x02, 0x00, 0x7a, 0x01, 0x00, 0x00,
0x7b, 0x03, 0x02, 0x4d, 0x08, 0x1f, 0x5a, 0x7a, 0x76, 0x03, 0x02, 0x4e, 0xdd, 0xe6, 0xa5, 0xf7,
0x0a, 0x92, 0x2f, 0xf0, 0x73, 0x16, 0x3a, 0x88, 0x00, 0x36, 0xf7, 0x83, 0xec, 0x93, 0x7c, 0xd2,
0x14, 0x85, 0x4c, 0x98, 0x15, 0x7b, 0x65, 0xe0, 0x4d, 0xe7, 0x7b, 0xf5, 0x4c, 0xf7, 0xe3, 0x86,
0x78, 0xd0, 0xed, 0xd0, 0xf3, 0x65, 0x20, 0xeb, 0xe8, 0xec, 0x3b, 0xbd, 0x2c, 0x9a, 0x3f, 0x57,
0x80, 0xd1, 0x0b, 0x00, 0x00, 0x34, 0x00, 0x33, 0xf0, 0xa4, 0xd4, 0x00, 0x00, 0x34, 0x00, 0x33,
0x00, 0x45, 0x00, 0x39, 0x00, 0x88, 0x00, 0x16, 0x00, 0x45, 0x00, 0x39, 0x00, 0x88, 0x00, 0x16,
0x00, 0x32, 0x00, 0x44, 0x00, 0x38, 0x00, 0x87, 0x00, 0x32, 0x00, 0x44, 0x00, 0x38, 0x00, 0x87,
0x00, 0x13, 0x00, 0x66, 0x00, 0x90, 0x00, 0x91, 0x00, 0x13, 0x00, 0x66, 0x00, 0x90, 0x00, 0x91,
0x00, 0x8f, 0x00, 0x8e, 0x00, 0x2f, 0x00, 0x41, 0x00, 0x8f, 0x00, 0x8e, 0x00, 0x2f, 0x00, 0x41,
0x00, 0x35, 0x00, 0x84, 0x00, 0x0a, 0x00, 0x05, 0x00, 0x35, 0x00, 0x84, 0x00, 0x0a, 0x00, 0x05,
0x00, 0x04, 0x00, 0x8c, 0x00, 0x8d, 0x00, 0x8b, 0x00, 0x04, 0x00, 0x8c, 0x00, 0x8d, 0x00, 0x8b,
0x00, 0x8a, 0x01, 0x00, 0x00, 0x1e, 0x00, 0x09, 0x00, 0x8a, 0x01, 0x00, 0x00, 0x19, 0x00, 0x09,
0x00, 0x03, 0x02, 0x00, 0x01, 0x00, 0x00, 0x00, 0x00, 0x03, 0x02, 0x00, 0x01, 0x00, 0x00, 0x00,
0x0e, 0x00, 0x0c, 0x00, 0x00, 0x09, 0x6c, 0x6f, 0x0e, 0x00, 0x0c, 0x00, 0x00, 0x09, 0x6c, 0x6f,
0x63, 0x61, 0x6c, 0x68, 0x6f, 0x73, 0x74, 0xff, 0x63, 0x61, 0x6c, 0x68, 0x6f, 0x73, 0x74,
0x01, 0x00, 0x01, 0x00,
}, },
{ {
@ -349,38 +409,46 @@ var rc4ServerScript = [][]byte{
{ {
0x16, 0x03, 0x01, 0x00, 0x86, 0x10, 0x00, 0x00, 0x16, 0x03, 0x01, 0x00, 0x86, 0x10, 0x00, 0x00,
0x82, 0x00, 0x80, 0x3c, 0x13, 0xd7, 0x12, 0xc1, 0x82, 0x00, 0x80, 0x39, 0xe2, 0x0f, 0x49, 0xa0,
0x6a, 0xf0, 0x3f, 0x8c, 0xa1, 0x35, 0x5d, 0xc5, 0xe6, 0xe4, 0x3b, 0x0c, 0x5f, 0xce, 0x39, 0x97,
0x89, 0x1e, 0x9e, 0xcd, 0x32, 0xc7, 0x9e, 0xe6, 0x6c, 0xb6, 0x41, 0xd9, 0xe1, 0x52, 0x8f, 0x43,
0xae, 0xd5, 0xf1, 0xbf, 0x70, 0xd7, 0xa9, 0xef, 0xb3, 0xc6, 0x4f, 0x9a, 0xe2, 0x1e, 0xb9, 0x3b,
0x2c, 0x4c, 0xf4, 0x22, 0xbc, 0x17, 0x17, 0xaa, 0xe3, 0x72, 0x17, 0x68, 0xb2, 0x0d, 0x7b, 0x71,
0x05, 0xf3, 0x9f, 0x80, 0xf2, 0xe9, 0x82, 0x2f, 0x33, 0x96, 0x5c, 0xf9, 0xfe, 0x18, 0x8f, 0x2f,
0x2a, 0x15, 0x54, 0x0d, 0x16, 0x0e, 0x77, 0x4c, 0x2b, 0x82, 0xec, 0x03, 0xf2, 0x16, 0xa8, 0xf8,
0x28, 0x3c, 0x03, 0x2d, 0x2d, 0xd7, 0xc8, 0x64, 0x39, 0xf9, 0xbb, 0x5a, 0xd3, 0x0c, 0xc1, 0x2a,
0xd9, 0x59, 0x4b, 0x1c, 0xf4, 0xde, 0xff, 0x2f, 0x52, 0xa1, 0x90, 0x20, 0x6b, 0x24, 0xc9, 0x55,
0xbc, 0x94, 0xaf, 0x18, 0x26, 0x37, 0xce, 0x4f, 0xee, 0x05, 0xd8, 0xb3, 0x43, 0x58, 0xf6, 0x7f,
0x84, 0x74, 0x2e, 0x45, 0x66, 0x7c, 0x0c, 0x54, 0x68, 0x2d, 0xb3, 0xd1, 0x1b, 0x30, 0xaa, 0xdf,
0x46, 0x36, 0x5f, 0x65, 0x21, 0x7b, 0x83, 0x8c, 0xfc, 0x85, 0xf1, 0xab, 0x14, 0x51, 0x91, 0x78,
0x6d, 0x76, 0xcd, 0x0d, 0x9f, 0xda, 0x1c, 0xa4, 0x29, 0x35, 0x65, 0xe0, 0x9c, 0xf6, 0xb7, 0x35,
0x6e, 0xfe, 0xb1, 0xf7, 0x09, 0x0d, 0xfb, 0x74, 0x33, 0xdb, 0x28, 0x93, 0x4d, 0x86, 0xbc, 0xfe,
0x66, 0x34, 0x99, 0x89, 0x7f, 0x5f, 0x77, 0x87, 0xaa, 0xd1, 0xc0, 0x2e, 0x4d, 0xec, 0xa2, 0x98,
0x4a, 0x66, 0x4b, 0xa9, 0x59, 0x57, 0xe3, 0x56, 0xca, 0x08, 0xb2, 0x91, 0x14, 0xde, 0x97, 0x3a,
0x0d, 0xdd, 0xd8, 0x14, 0x03, 0x01, 0x00, 0x01, 0xc4, 0x6b, 0x49, 0x14, 0x03, 0x01, 0x00, 0x01,
0x01, 0x16, 0x03, 0x01, 0x00, 0x24, 0xc0, 0x4e, 0x01, 0x16, 0x03, 0x01, 0x00, 0x24, 0x7a, 0xcb,
0xd3, 0x0f, 0xb5, 0xc0, 0x57, 0xa6, 0x18, 0x80, 0x3b, 0x0e, 0xbb, 0x7a, 0x56, 0x39, 0xaf, 0x83,
0x80, 0x6b, 0x49, 0xfe, 0xbd, 0x3a, 0x7a, 0x2c, 0xae, 0xfd, 0x25, 0xfd, 0x64, 0xb4, 0x0c, 0x0c,
0xef, 0x70, 0xb5, 0x1c, 0xd2, 0xdf, 0x5f, 0x78, 0x17, 0x46, 0x54, 0x2c, 0x6a, 0x07, 0x83, 0xc6,
0x5a, 0xd8, 0x4f, 0xa0, 0x95, 0xb4, 0xb3, 0xb5, 0x46, 0x08, 0x0b, 0xcd, 0x15, 0x53, 0xef, 0x40,
0xaa, 0x3b, 0x4e, 0x56,
}, },
{ {
0x14, 0x03, 0x01, 0x00, 0x01, 0x01, 0x16, 0x03, 0x14, 0x03, 0x01, 0x00, 0x01, 0x01, 0x16, 0x03,
0x01, 0x00, 0x24, 0x9d, 0xc9, 0xda, 0xdf, 0xeb, 0x01, 0x00, 0x24, 0xd3, 0x72, 0xeb, 0x29, 0xb9,
0xc8, 0xdb, 0xf8, 0x94, 0xa5, 0xef, 0xd5, 0xfc, 0x15, 0x29, 0xb5, 0xe5, 0xb7, 0xef, 0x5c, 0xb2,
0x89, 0x01, 0x64, 0x30, 0x77, 0x5a, 0x18, 0x4b, 0x9d, 0xf6, 0xc8, 0x47, 0xd6, 0xa0, 0x84, 0xf0,
0x16, 0x79, 0x9c, 0xf6, 0xf5, 0x09, 0x22, 0x12, 0x8c, 0xcb, 0xe6, 0xbe, 0xbc, 0xfb, 0x38, 0x90,
0x4c, 0x3e, 0xa8, 0x8e, 0x91, 0xa5, 0x24, 0x89, 0x60, 0xa2, 0xe8, 0xaa, 0xb3, 0x12, 0x17,
0x03, 0x01, 0x00, 0x21, 0x67, 0x4a, 0x3d, 0x31,
0x6c, 0x5a, 0x1c, 0xf9, 0x6e, 0xf1, 0xd8, 0x12,
0x0e, 0xb9, 0xfd, 0xfc, 0x66, 0x91, 0xd1, 0x1d,
0x6e, 0xe4, 0x55, 0xdd, 0x11, 0xb9, 0xb8, 0xa2,
0x65, 0xa1, 0x95, 0x64, 0x1c, 0x15, 0x03, 0x01,
0x00, 0x16, 0x9b, 0xa0, 0x24, 0xe3, 0xcb, 0xae,
0xad, 0x51, 0xb3, 0x63, 0x59, 0x78, 0x49, 0x24,
0x06, 0x6e, 0xee, 0x7a, 0xd7, 0x74, 0x53, 0x04,
}, },
} }
@ -878,3 +946,625 @@ var sslv3ServerScript = [][]byte{
0xaf, 0xd3, 0xb7, 0xa3, 0xcc, 0x4a, 0x1d, 0x2e, 0xaf, 0xd3, 0xb7, 0xa3, 0xcc, 0x4a, 0x1d, 0x2e,
}, },
} }
var clientauthTests = []clientauthTest{
// Server doesn't ask for cert
// gotest -test.run "TestRunServer" -serve -clientauth 0
// gnutls-cli --insecure --debug 100 -p 10443 localhost 2>&1 |
// python parse-gnutls-cli-debug-log.py
{"NoClientCert", NoClientCert, nil,
[][]byte{{
0x16, 0x03, 0x02, 0x00, 0x7a, 0x01, 0x00, 0x00,
0x76, 0x03, 0x02, 0x4e, 0xe0, 0x92, 0x5d, 0xcd,
0xfe, 0x0c, 0x69, 0xd4, 0x7d, 0x8e, 0xa6, 0x88,
0xde, 0x72, 0x04, 0x29, 0x6a, 0x4a, 0x16, 0x23,
0xd7, 0x8f, 0xbc, 0xfa, 0x80, 0x73, 0x2e, 0x12,
0xb7, 0x0b, 0x39, 0x00, 0x00, 0x34, 0x00, 0x33,
0x00, 0x45, 0x00, 0x39, 0x00, 0x88, 0x00, 0x16,
0x00, 0x32, 0x00, 0x44, 0x00, 0x38, 0x00, 0x87,
0x00, 0x13, 0x00, 0x66, 0x00, 0x90, 0x00, 0x91,
0x00, 0x8f, 0x00, 0x8e, 0x00, 0x2f, 0x00, 0x41,
0x00, 0x35, 0x00, 0x84, 0x00, 0x0a, 0x00, 0x05,
0x00, 0x04, 0x00, 0x8c, 0x00, 0x8d, 0x00, 0x8b,
0x00, 0x8a, 0x01, 0x00, 0x00, 0x19, 0x00, 0x09,
0x00, 0x03, 0x02, 0x00, 0x01, 0x00, 0x00, 0x00,
0x0e, 0x00, 0x0c, 0x00, 0x00, 0x09, 0x6c, 0x6f,
0x63, 0x61, 0x6c, 0x68, 0x6f, 0x73, 0x74,
},
{
0x16, 0x03, 0x01, 0x00, 0x2a, 0x02, 0x00, 0x00,
0x26, 0x03, 0x01, 0x00, 0x00, 0x00, 0x00, 0x00,
0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00,
0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00,
0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00,
0x00, 0x00, 0x00, 0x00, 0x00, 0x05, 0x00, 0x16,
0x03, 0x01, 0x02, 0xbe, 0x0b, 0x00, 0x02, 0xba,
0x00, 0x02, 0xb7, 0x00, 0x02, 0xb4, 0x30, 0x82,
0x02, 0xb0, 0x30, 0x82, 0x02, 0x19, 0xa0, 0x03,
0x02, 0x01, 0x02, 0x02, 0x09, 0x00, 0x85, 0xb0,
0xbb, 0xa4, 0x8a, 0x7f, 0xb8, 0xca, 0x30, 0x0d,
0x06, 0x09, 0x2a, 0x86, 0x48, 0x86, 0xf7, 0x0d,
0x01, 0x01, 0x05, 0x05, 0x00, 0x30, 0x45, 0x31,
0x0b, 0x30, 0x09, 0x06, 0x03, 0x55, 0x04, 0x06,
0x13, 0x02, 0x41, 0x55, 0x31, 0x13, 0x30, 0x11,
0x06, 0x03, 0x55, 0x04, 0x08, 0x13, 0x0a, 0x53,
0x6f, 0x6d, 0x65, 0x2d, 0x53, 0x74, 0x61, 0x74,
0x65, 0x31, 0x21, 0x30, 0x1f, 0x06, 0x03, 0x55,
0x04, 0x0a, 0x13, 0x18, 0x49, 0x6e, 0x74, 0x65,
0x72, 0x6e, 0x65, 0x74, 0x20, 0x57, 0x69, 0x64,
0x67, 0x69, 0x74, 0x73, 0x20, 0x50, 0x74, 0x79,
0x20, 0x4c, 0x74, 0x64, 0x30, 0x1e, 0x17, 0x0d,
0x31, 0x30, 0x30, 0x34, 0x32, 0x34, 0x30, 0x39,
0x30, 0x39, 0x33, 0x38, 0x5a, 0x17, 0x0d, 0x31,
0x31, 0x30, 0x34, 0x32, 0x34, 0x30, 0x39, 0x30,
0x39, 0x33, 0x38, 0x5a, 0x30, 0x45, 0x31, 0x0b,
0x30, 0x09, 0x06, 0x03, 0x55, 0x04, 0x06, 0x13,
0x02, 0x41, 0x55, 0x31, 0x13, 0x30, 0x11, 0x06,
0x03, 0x55, 0x04, 0x08, 0x13, 0x0a, 0x53, 0x6f,
0x6d, 0x65, 0x2d, 0x53, 0x74, 0x61, 0x74, 0x65,
0x31, 0x21, 0x30, 0x1f, 0x06, 0x03, 0x55, 0x04,
0x0a, 0x13, 0x18, 0x49, 0x6e, 0x74, 0x65, 0x72,
0x6e, 0x65, 0x74, 0x20, 0x57, 0x69, 0x64, 0x67,
0x69, 0x74, 0x73, 0x20, 0x50, 0x74, 0x79, 0x20,
0x4c, 0x74, 0x64, 0x30, 0x81, 0x9f, 0x30, 0x0d,
0x06, 0x09, 0x2a, 0x86, 0x48, 0x86, 0xf7, 0x0d,
0x01, 0x01, 0x01, 0x05, 0x00, 0x03, 0x81, 0x8d,
0x00, 0x30, 0x81, 0x89, 0x02, 0x81, 0x81, 0x00,
0xbb, 0x79, 0xd6, 0xf5, 0x17, 0xb5, 0xe5, 0xbf,
0x46, 0x10, 0xd0, 0xdc, 0x69, 0xbe, 0xe6, 0x2b,
0x07, 0x43, 0x5a, 0xd0, 0x03, 0x2d, 0x8a, 0x7a,
0x43, 0x85, 0xb7, 0x14, 0x52, 0xe7, 0xa5, 0x65,
0x4c, 0x2c, 0x78, 0xb8, 0x23, 0x8c, 0xb5, 0xb4,
0x82, 0xe5, 0xde, 0x1f, 0x95, 0x3b, 0x7e, 0x62,
0xa5, 0x2c, 0xa5, 0x33, 0xd6, 0xfe, 0x12, 0x5c,
0x7a, 0x56, 0xfc, 0xf5, 0x06, 0xbf, 0xfa, 0x58,
0x7b, 0x26, 0x3f, 0xb5, 0xcd, 0x04, 0xd3, 0xd0,
0xc9, 0x21, 0x96, 0x4a, 0xc7, 0xf4, 0x54, 0x9f,
0x5a, 0xbf, 0xef, 0x42, 0x71, 0x00, 0xfe, 0x18,
0x99, 0x07, 0x7f, 0x7e, 0x88, 0x7d, 0x7d, 0xf1,
0x04, 0x39, 0xc4, 0xa2, 0x2e, 0xdb, 0x51, 0xc9,
0x7c, 0xe3, 0xc0, 0x4c, 0x3b, 0x32, 0x66, 0x01,
0xcf, 0xaf, 0xb1, 0x1d, 0xb8, 0x71, 0x9a, 0x1d,
0xdb, 0xdb, 0x89, 0x6b, 0xae, 0xda, 0x2d, 0x79,
0x02, 0x03, 0x01, 0x00, 0x01, 0xa3, 0x81, 0xa7,
0x30, 0x81, 0xa4, 0x30, 0x1d, 0x06, 0x03, 0x55,
0x1d, 0x0e, 0x04, 0x16, 0x04, 0x14, 0xb1, 0xad,
0xe2, 0x85, 0x5a, 0xcf, 0xcb, 0x28, 0xdb, 0x69,
0xce, 0x23, 0x69, 0xde, 0xd3, 0x26, 0x8e, 0x18,
0x88, 0x39, 0x30, 0x75, 0x06, 0x03, 0x55, 0x1d,
0x23, 0x04, 0x6e, 0x30, 0x6c, 0x80, 0x14, 0xb1,
0xad, 0xe2, 0x85, 0x5a, 0xcf, 0xcb, 0x28, 0xdb,
0x69, 0xce, 0x23, 0x69, 0xde, 0xd3, 0x26, 0x8e,
0x18, 0x88, 0x39, 0xa1, 0x49, 0xa4, 0x47, 0x30,
0x45, 0x31, 0x0b, 0x30, 0x09, 0x06, 0x03, 0x55,
0x04, 0x06, 0x13, 0x02, 0x41, 0x55, 0x31, 0x13,
0x30, 0x11, 0x06, 0x03, 0x55, 0x04, 0x08, 0x13,
0x0a, 0x53, 0x6f, 0x6d, 0x65, 0x2d, 0x53, 0x74,
0x61, 0x74, 0x65, 0x31, 0x21, 0x30, 0x1f, 0x06,
0x03, 0x55, 0x04, 0x0a, 0x13, 0x18, 0x49, 0x6e,
0x74, 0x65, 0x72, 0x6e, 0x65, 0x74, 0x20, 0x57,
0x69, 0x64, 0x67, 0x69, 0x74, 0x73, 0x20, 0x50,
0x74, 0x79, 0x20, 0x4c, 0x74, 0x64, 0x82, 0x09,
0x00, 0x85, 0xb0, 0xbb, 0xa4, 0x8a, 0x7f, 0xb8,
0xca, 0x30, 0x0c, 0x06, 0x03, 0x55, 0x1d, 0x13,
0x04, 0x05, 0x30, 0x03, 0x01, 0x01, 0xff, 0x30,
0x0d, 0x06, 0x09, 0x2a, 0x86, 0x48, 0x86, 0xf7,
0x0d, 0x01, 0x01, 0x05, 0x05, 0x00, 0x03, 0x81,
0x81, 0x00, 0x08, 0x6c, 0x45, 0x24, 0xc7, 0x6b,
0xb1, 0x59, 0xab, 0x0c, 0x52, 0xcc, 0xf2, 0xb0,
0x14, 0xd7, 0x87, 0x9d, 0x7a, 0x64, 0x75, 0xb5,
0x5a, 0x95, 0x66, 0xe4, 0xc5, 0x2b, 0x8e, 0xae,
0x12, 0x66, 0x1f, 0xeb, 0x4f, 0x38, 0xb3, 0x6e,
0x60, 0xd3, 0x92, 0xfd, 0xf7, 0x41, 0x08, 0xb5,
0x25, 0x13, 0xb1, 0x18, 0x7a, 0x24, 0xfb, 0x30,
0x1d, 0xba, 0xed, 0x98, 0xb9, 0x17, 0xec, 0xe7,
0xd7, 0x31, 0x59, 0xdb, 0x95, 0xd3, 0x1d, 0x78,
0xea, 0x50, 0x56, 0x5c, 0xd5, 0x82, 0x5a, 0x2d,
0x5a, 0x5f, 0x33, 0xc4, 0xb6, 0xd8, 0xc9, 0x75,
0x90, 0x96, 0x8c, 0x0f, 0x52, 0x98, 0xb5, 0xcd,
0x98, 0x1f, 0x89, 0x20, 0x5f, 0xf2, 0xa0, 0x1c,
0xa3, 0x1b, 0x96, 0x94, 0xdd, 0xa9, 0xfd, 0x57,
0xe9, 0x70, 0xe8, 0x26, 0x6d, 0x71, 0x99, 0x9b,
0x26, 0x6e, 0x38, 0x50, 0x29, 0x6c, 0x90, 0xa7,
0xbd, 0xd9, 0x16, 0x03, 0x01, 0x00, 0x04, 0x0e,
0x00, 0x00, 0x00,
},
{
0x16, 0x03, 0x01, 0x00, 0x86, 0x10, 0x00, 0x00,
0x82, 0x00, 0x80, 0x10, 0xe1, 0x00, 0x3d, 0x0a,
0x6b, 0x02, 0x7f, 0x97, 0xde, 0xfb, 0x65, 0x46,
0x1a, 0x50, 0x4e, 0x34, 0x9a, 0xae, 0x14, 0x7e,
0xec, 0xef, 0x85, 0x15, 0x3b, 0x39, 0xc2, 0x45,
0x04, 0x40, 0x92, 0x71, 0xd6, 0x7e, 0xf6, 0xfd,
0x4d, 0x84, 0xf7, 0xc4, 0x77, 0x99, 0x3d, 0xe2,
0xc3, 0x8d, 0xb0, 0x4c, 0x74, 0xc8, 0x51, 0xec,
0xb2, 0xe8, 0x6b, 0xa1, 0xd2, 0x4d, 0xd8, 0x61,
0x92, 0x7a, 0x24, 0x57, 0x44, 0x4f, 0xa2, 0x1e,
0x74, 0x0b, 0x06, 0x4b, 0x80, 0x34, 0x8b, 0xfe,
0xc2, 0x0e, 0xc1, 0xcd, 0xab, 0x0c, 0x3f, 0x54,
0xe2, 0x44, 0xe9, 0x6c, 0x2b, 0xba, 0x7b, 0x64,
0xf1, 0x93, 0x65, 0x75, 0xf2, 0x35, 0xff, 0x27,
0x03, 0xd5, 0x64, 0xe6, 0x8e, 0xe7, 0x7b, 0x56,
0xb6, 0x61, 0x73, 0xeb, 0xa2, 0xdc, 0xa4, 0x6e,
0x52, 0xac, 0xbc, 0xba, 0x11, 0xa3, 0xd2, 0x61,
0x4a, 0xe0, 0xbb, 0x14, 0x03, 0x01, 0x00, 0x01,
0x01, 0x16, 0x03, 0x01, 0x00, 0x24, 0xd2, 0x5a,
0x0c, 0x2a, 0x27, 0x96, 0xba, 0xa9, 0x67, 0xd2,
0x51, 0x68, 0x32, 0x68, 0x22, 0x1f, 0xb9, 0x27,
0x79, 0x59, 0x28, 0xdf, 0x38, 0x1f, 0x92, 0x21,
0x5d, 0x0f, 0xf4, 0xc0, 0xee, 0xb7, 0x10, 0x5a,
0xa9, 0x45,
},
{
0x14, 0x03, 0x01, 0x00, 0x01, 0x01, 0x16, 0x03,
0x01, 0x00, 0x24, 0x13, 0x6f, 0x6c, 0x71, 0x83,
0x59, 0xcf, 0x32, 0x72, 0xe9, 0xce, 0xcc, 0x7a,
0x6c, 0xf0, 0x72, 0x39, 0x16, 0xae, 0x40, 0x61,
0xfa, 0x92, 0x4c, 0xe7, 0xf2, 0x1a, 0xd7, 0x0c,
0x84, 0x76, 0x6c, 0xe9, 0x11, 0x43, 0x19, 0x17,
0x03, 0x01, 0x00, 0x21, 0xc0, 0xa2, 0x13, 0x28,
0x94, 0x8c, 0x5c, 0xd6, 0x79, 0xb9, 0xfe, 0xae,
0x45, 0x4b, 0xc0, 0x7c, 0xae, 0x2d, 0xb4, 0x0d,
0x31, 0xc4, 0xad, 0x22, 0xd7, 0x1e, 0x99, 0x1c,
0x4c, 0x69, 0xab, 0x42, 0x61, 0x15, 0x03, 0x01,
0x00, 0x16, 0xe1, 0x0c, 0x67, 0xf3, 0xf4, 0xb9,
0x8e, 0x81, 0x8e, 0x01, 0xb8, 0xa0, 0x69, 0x8c,
0x03, 0x11, 0x43, 0x3e, 0xee, 0xb7, 0x4d, 0x69,
}}},
// Server asks for cert with empty CA list, client doesn't give it.
// gotest -test.run "TestRunServer" -serve -clientauth 1
// gnutls-cli --insecure --debug 100 -p 10443 localhost
{"RequestClientCert, none given", RequestClientCert, nil,
[][]byte{{
0x16, 0x03, 0x02, 0x00, 0x7a, 0x01, 0x00, 0x00,
0x76, 0x03, 0x02, 0x4e, 0xe0, 0x93, 0xe2, 0x47,
0x06, 0xa0, 0x61, 0x0c, 0x51, 0xdd, 0xf0, 0xef,
0xf4, 0x30, 0x72, 0xe1, 0xa6, 0x50, 0x68, 0x82,
0x3c, 0xfb, 0xcb, 0x72, 0x5e, 0x73, 0x9d, 0xda,
0x27, 0x35, 0x72, 0x00, 0x00, 0x34, 0x00, 0x33,
0x00, 0x45, 0x00, 0x39, 0x00, 0x88, 0x00, 0x16,
0x00, 0x32, 0x00, 0x44, 0x00, 0x38, 0x00, 0x87,
0x00, 0x13, 0x00, 0x66, 0x00, 0x90, 0x00, 0x91,
0x00, 0x8f, 0x00, 0x8e, 0x00, 0x2f, 0x00, 0x41,
0x00, 0x35, 0x00, 0x84, 0x00, 0x0a, 0x00, 0x05,
0x00, 0x04, 0x00, 0x8c, 0x00, 0x8d, 0x00, 0x8b,
0x00, 0x8a, 0x01, 0x00, 0x00, 0x19, 0x00, 0x09,
0x00, 0x03, 0x02, 0x00, 0x01, 0x00, 0x00, 0x00,
0x0e, 0x00, 0x0c, 0x00, 0x00, 0x09, 0x6c, 0x6f,
0x63, 0x61, 0x6c, 0x68, 0x6f, 0x73, 0x74,
},
{
0x16, 0x03, 0x01, 0x00, 0x2a, 0x02, 0x00, 0x00,
0x26, 0x03, 0x01, 0x00, 0x00, 0x00, 0x00, 0x00,
0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00,
0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00,
0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00,
0x00, 0x00, 0x00, 0x00, 0x00, 0x05, 0x00, 0x16,
0x03, 0x01, 0x02, 0xbe, 0x0b, 0x00, 0x02, 0xba,
0x00, 0x02, 0xb7, 0x00, 0x02, 0xb4, 0x30, 0x82,
0x02, 0xb0, 0x30, 0x82, 0x02, 0x19, 0xa0, 0x03,
0x02, 0x01, 0x02, 0x02, 0x09, 0x00, 0x85, 0xb0,
0xbb, 0xa4, 0x8a, 0x7f, 0xb8, 0xca, 0x30, 0x0d,
0x06, 0x09, 0x2a, 0x86, 0x48, 0x86, 0xf7, 0x0d,
0x01, 0x01, 0x05, 0x05, 0x00, 0x30, 0x45, 0x31,
0x0b, 0x30, 0x09, 0x06, 0x03, 0x55, 0x04, 0x06,
0x13, 0x02, 0x41, 0x55, 0x31, 0x13, 0x30, 0x11,
0x06, 0x03, 0x55, 0x04, 0x08, 0x13, 0x0a, 0x53,
0x6f, 0x6d, 0x65, 0x2d, 0x53, 0x74, 0x61, 0x74,
0x65, 0x31, 0x21, 0x30, 0x1f, 0x06, 0x03, 0x55,
0x04, 0x0a, 0x13, 0x18, 0x49, 0x6e, 0x74, 0x65,
0x72, 0x6e, 0x65, 0x74, 0x20, 0x57, 0x69, 0x64,
0x67, 0x69, 0x74, 0x73, 0x20, 0x50, 0x74, 0x79,
0x20, 0x4c, 0x74, 0x64, 0x30, 0x1e, 0x17, 0x0d,
0x31, 0x30, 0x30, 0x34, 0x32, 0x34, 0x30, 0x39,
0x30, 0x39, 0x33, 0x38, 0x5a, 0x17, 0x0d, 0x31,
0x31, 0x30, 0x34, 0x32, 0x34, 0x30, 0x39, 0x30,
0x39, 0x33, 0x38, 0x5a, 0x30, 0x45, 0x31, 0x0b,
0x30, 0x09, 0x06, 0x03, 0x55, 0x04, 0x06, 0x13,
0x02, 0x41, 0x55, 0x31, 0x13, 0x30, 0x11, 0x06,
0x03, 0x55, 0x04, 0x08, 0x13, 0x0a, 0x53, 0x6f,
0x6d, 0x65, 0x2d, 0x53, 0x74, 0x61, 0x74, 0x65,
0x31, 0x21, 0x30, 0x1f, 0x06, 0x03, 0x55, 0x04,
0x0a, 0x13, 0x18, 0x49, 0x6e, 0x74, 0x65, 0x72,
0x6e, 0x65, 0x74, 0x20, 0x57, 0x69, 0x64, 0x67,
0x69, 0x74, 0x73, 0x20, 0x50, 0x74, 0x79, 0x20,
0x4c, 0x74, 0x64, 0x30, 0x81, 0x9f, 0x30, 0x0d,
0x06, 0x09, 0x2a, 0x86, 0x48, 0x86, 0xf7, 0x0d,
0x01, 0x01, 0x01, 0x05, 0x00, 0x03, 0x81, 0x8d,
0x00, 0x30, 0x81, 0x89, 0x02, 0x81, 0x81, 0x00,
0xbb, 0x79, 0xd6, 0xf5, 0x17, 0xb5, 0xe5, 0xbf,
0x46, 0x10, 0xd0, 0xdc, 0x69, 0xbe, 0xe6, 0x2b,
0x07, 0x43, 0x5a, 0xd0, 0x03, 0x2d, 0x8a, 0x7a,
0x43, 0x85, 0xb7, 0x14, 0x52, 0xe7, 0xa5, 0x65,
0x4c, 0x2c, 0x78, 0xb8, 0x23, 0x8c, 0xb5, 0xb4,
0x82, 0xe5, 0xde, 0x1f, 0x95, 0x3b, 0x7e, 0x62,
0xa5, 0x2c, 0xa5, 0x33, 0xd6, 0xfe, 0x12, 0x5c,
0x7a, 0x56, 0xfc, 0xf5, 0x06, 0xbf, 0xfa, 0x58,
0x7b, 0x26, 0x3f, 0xb5, 0xcd, 0x04, 0xd3, 0xd0,
0xc9, 0x21, 0x96, 0x4a, 0xc7, 0xf4, 0x54, 0x9f,
0x5a, 0xbf, 0xef, 0x42, 0x71, 0x00, 0xfe, 0x18,
0x99, 0x07, 0x7f, 0x7e, 0x88, 0x7d, 0x7d, 0xf1,
0x04, 0x39, 0xc4, 0xa2, 0x2e, 0xdb, 0x51, 0xc9,
0x7c, 0xe3, 0xc0, 0x4c, 0x3b, 0x32, 0x66, 0x01,
0xcf, 0xaf, 0xb1, 0x1d, 0xb8, 0x71, 0x9a, 0x1d,
0xdb, 0xdb, 0x89, 0x6b, 0xae, 0xda, 0x2d, 0x79,
0x02, 0x03, 0x01, 0x00, 0x01, 0xa3, 0x81, 0xa7,
0x30, 0x81, 0xa4, 0x30, 0x1d, 0x06, 0x03, 0x55,
0x1d, 0x0e, 0x04, 0x16, 0x04, 0x14, 0xb1, 0xad,
0xe2, 0x85, 0x5a, 0xcf, 0xcb, 0x28, 0xdb, 0x69,
0xce, 0x23, 0x69, 0xde, 0xd3, 0x26, 0x8e, 0x18,
0x88, 0x39, 0x30, 0x75, 0x06, 0x03, 0x55, 0x1d,
0x23, 0x04, 0x6e, 0x30, 0x6c, 0x80, 0x14, 0xb1,
0xad, 0xe2, 0x85, 0x5a, 0xcf, 0xcb, 0x28, 0xdb,
0x69, 0xce, 0x23, 0x69, 0xde, 0xd3, 0x26, 0x8e,
0x18, 0x88, 0x39, 0xa1, 0x49, 0xa4, 0x47, 0x30,
0x45, 0x31, 0x0b, 0x30, 0x09, 0x06, 0x03, 0x55,
0x04, 0x06, 0x13, 0x02, 0x41, 0x55, 0x31, 0x13,
0x30, 0x11, 0x06, 0x03, 0x55, 0x04, 0x08, 0x13,
0x0a, 0x53, 0x6f, 0x6d, 0x65, 0x2d, 0x53, 0x74,
0x61, 0x74, 0x65, 0x31, 0x21, 0x30, 0x1f, 0x06,
0x03, 0x55, 0x04, 0x0a, 0x13, 0x18, 0x49, 0x6e,
0x74, 0x65, 0x72, 0x6e, 0x65, 0x74, 0x20, 0x57,
0x69, 0x64, 0x67, 0x69, 0x74, 0x73, 0x20, 0x50,
0x74, 0x79, 0x20, 0x4c, 0x74, 0x64, 0x82, 0x09,
0x00, 0x85, 0xb0, 0xbb, 0xa4, 0x8a, 0x7f, 0xb8,
0xca, 0x30, 0x0c, 0x06, 0x03, 0x55, 0x1d, 0x13,
0x04, 0x05, 0x30, 0x03, 0x01, 0x01, 0xff, 0x30,
0x0d, 0x06, 0x09, 0x2a, 0x86, 0x48, 0x86, 0xf7,
0x0d, 0x01, 0x01, 0x05, 0x05, 0x00, 0x03, 0x81,
0x81, 0x00, 0x08, 0x6c, 0x45, 0x24, 0xc7, 0x6b,
0xb1, 0x59, 0xab, 0x0c, 0x52, 0xcc, 0xf2, 0xb0,
0x14, 0xd7, 0x87, 0x9d, 0x7a, 0x64, 0x75, 0xb5,
0x5a, 0x95, 0x66, 0xe4, 0xc5, 0x2b, 0x8e, 0xae,
0x12, 0x66, 0x1f, 0xeb, 0x4f, 0x38, 0xb3, 0x6e,
0x60, 0xd3, 0x92, 0xfd, 0xf7, 0x41, 0x08, 0xb5,
0x25, 0x13, 0xb1, 0x18, 0x7a, 0x24, 0xfb, 0x30,
0x1d, 0xba, 0xed, 0x98, 0xb9, 0x17, 0xec, 0xe7,
0xd7, 0x31, 0x59, 0xdb, 0x95, 0xd3, 0x1d, 0x78,
0xea, 0x50, 0x56, 0x5c, 0xd5, 0x82, 0x5a, 0x2d,
0x5a, 0x5f, 0x33, 0xc4, 0xb6, 0xd8, 0xc9, 0x75,
0x90, 0x96, 0x8c, 0x0f, 0x52, 0x98, 0xb5, 0xcd,
0x98, 0x1f, 0x89, 0x20, 0x5f, 0xf2, 0xa0, 0x1c,
0xa3, 0x1b, 0x96, 0x94, 0xdd, 0xa9, 0xfd, 0x57,
0xe9, 0x70, 0xe8, 0x26, 0x6d, 0x71, 0x99, 0x9b,
0x26, 0x6e, 0x38, 0x50, 0x29, 0x6c, 0x90, 0xa7,
0xbd, 0xd9, 0x16, 0x03, 0x01, 0x00, 0x08, 0x0d,
0x00, 0x00, 0x04, 0x01, 0x01, 0x00, 0x00, 0x16,
0x03, 0x01, 0x00, 0x04, 0x0e, 0x00, 0x00, 0x00,
},
{
0x16, 0x03, 0x01, 0x00, 0x07, 0x0b, 0x00, 0x00,
0x03, 0x00, 0x00, 0x00, 0x16, 0x03, 0x01, 0x00,
0x86, 0x10, 0x00, 0x00, 0x82, 0x00, 0x80, 0x64,
0x28, 0xb9, 0x3f, 0x48, 0xaf, 0x06, 0x22, 0x39,
0x56, 0xd8, 0x6f, 0x63, 0x5d, 0x03, 0x48, 0x63,
0x01, 0x13, 0xa2, 0xd6, 0x76, 0xc0, 0xab, 0xda,
0x25, 0x30, 0x75, 0x6c, 0xaa, 0xb4, 0xdc, 0x35,
0x72, 0xdc, 0xf2, 0x43, 0xe4, 0x1d, 0x82, 0xfb,
0x6c, 0x64, 0xe2, 0xa7, 0x8f, 0x32, 0x67, 0x6b,
0xcd, 0xd2, 0xb2, 0x36, 0x94, 0xbc, 0x6f, 0x46,
0x79, 0x29, 0x42, 0xe3, 0x1a, 0xbf, 0xfb, 0x41,
0xd5, 0xe3, 0xb4, 0x2a, 0xf6, 0x95, 0x6f, 0x0c,
0x87, 0xb9, 0x03, 0x18, 0xa1, 0xea, 0x4a, 0xe2,
0x2e, 0x0f, 0x50, 0x00, 0xc1, 0xe8, 0x8c, 0xc8,
0xa2, 0xf6, 0xa4, 0x05, 0xf4, 0x38, 0x3e, 0xd9,
0x6e, 0x63, 0x96, 0x0c, 0x34, 0x73, 0x90, 0x03,
0x55, 0xa6, 0x34, 0xb0, 0x5e, 0x8c, 0x48, 0x40,
0x25, 0x45, 0x84, 0xa6, 0x21, 0x3f, 0x81, 0x97,
0xa7, 0x11, 0x09, 0x14, 0x95, 0xa5, 0xe5, 0x14,
0x03, 0x01, 0x00, 0x01, 0x01, 0x16, 0x03, 0x01,
0x00, 0x24, 0x16, 0xaa, 0x01, 0x2c, 0xa8, 0xc1,
0x28, 0xaf, 0x35, 0xc1, 0xc1, 0xf3, 0x0a, 0x25,
0x66, 0x6e, 0x27, 0x11, 0xa3, 0xa4, 0xd9, 0xe9,
0xea, 0x15, 0x09, 0x9d, 0x28, 0xe3, 0x5b, 0x2b,
0xa6, 0x25, 0xa7, 0x14, 0x24, 0x3a,
},
{
0x14, 0x03, 0x01, 0x00, 0x01, 0x01, 0x16, 0x03,
0x01, 0x00, 0x24, 0x9a, 0xa8, 0xd6, 0x77, 0x46,
0x45, 0x68, 0x9d, 0x5d, 0xa9, 0x68, 0x03, 0xe5,
0xaf, 0xe8, 0xc8, 0x21, 0xc5, 0xc6, 0xc1, 0x50,
0xe0, 0xd8, 0x52, 0xce, 0xa3, 0x4f, 0x2d, 0xf4,
0xe3, 0xa7, 0x7d, 0x35, 0x80, 0x84, 0x12, 0x17,
0x03, 0x01, 0x00, 0x21, 0x8a, 0x82, 0x0c, 0x54,
0x1b, 0xeb, 0x77, 0x90, 0x2c, 0x3e, 0xbc, 0xf0,
0x23, 0xcc, 0xa8, 0x9f, 0x25, 0x08, 0x12, 0xed,
0x43, 0xf1, 0xf9, 0x06, 0xad, 0xa9, 0x4b, 0x97,
0x82, 0xb7, 0xc4, 0x0b, 0x4c, 0x15, 0x03, 0x01,
0x00, 0x16, 0x05, 0x2d, 0x9d, 0x45, 0x03, 0xb7,
0xc2, 0xd1, 0xb5, 0x1a, 0x43, 0xcf, 0x1a, 0x37,
0xf4, 0x70, 0xcc, 0xb4, 0xed, 0x07, 0x76, 0x3a,
}}},
// Server asks for cert with empty CA list, client gives one.
// gotest -test.run "TestRunServer" -serve -clientauth 1
// gnutls-cli --insecure --debug 100 -p 10443 localhost
{"RequestClientCert, client gives it", RequestClientCert,
[]*x509.Certificate{clicert},
[][]byte{{
0x16, 0x03, 0x02, 0x00, 0x7a, 0x01, 0x00, 0x00,
0x76, 0x03, 0x02, 0x4e, 0xe7, 0x44, 0xda, 0x58,
0x7d, 0x46, 0x4a, 0x48, 0x97, 0x9f, 0xe5, 0x91,
0x11, 0x64, 0xa7, 0x1e, 0x4d, 0xb7, 0xfe, 0x9b,
0xc6, 0x63, 0xf8, 0xa4, 0xb5, 0x0b, 0x18, 0xb5,
0xbd, 0x19, 0xb3, 0x00, 0x00, 0x34, 0x00, 0x33,
0x00, 0x45, 0x00, 0x39, 0x00, 0x88, 0x00, 0x16,
0x00, 0x32, 0x00, 0x44, 0x00, 0x38, 0x00, 0x87,
0x00, 0x13, 0x00, 0x66, 0x00, 0x90, 0x00, 0x91,
0x00, 0x8f, 0x00, 0x8e, 0x00, 0x2f, 0x00, 0x41,
0x00, 0x35, 0x00, 0x84, 0x00, 0x0a, 0x00, 0x05,
0x00, 0x04, 0x00, 0x8c, 0x00, 0x8d, 0x00, 0x8b,
0x00, 0x8a, 0x01, 0x00, 0x00, 0x19, 0x00, 0x09,
0x00, 0x03, 0x02, 0x00, 0x01, 0x00, 0x00, 0x00,
0x0e, 0x00, 0x0c, 0x00, 0x00, 0x09, 0x6c, 0x6f,
0x63, 0x61, 0x6c, 0x68, 0x6f, 0x73, 0x74,
},
{
0x16, 0x03, 0x01, 0x00, 0x2a, 0x02, 0x00, 0x00,
0x26, 0x03, 0x01, 0x00, 0x00, 0x00, 0x00, 0x00,
0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00,
0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00,
0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00,
0x00, 0x00, 0x00, 0x00, 0x00, 0x05, 0x00, 0x16,
0x03, 0x01, 0x02, 0xbe, 0x0b, 0x00, 0x02, 0xba,
0x00, 0x02, 0xb7, 0x00, 0x02, 0xb4, 0x30, 0x82,
0x02, 0xb0, 0x30, 0x82, 0x02, 0x19, 0xa0, 0x03,
0x02, 0x01, 0x02, 0x02, 0x09, 0x00, 0x85, 0xb0,
0xbb, 0xa4, 0x8a, 0x7f, 0xb8, 0xca, 0x30, 0x0d,
0x06, 0x09, 0x2a, 0x86, 0x48, 0x86, 0xf7, 0x0d,
0x01, 0x01, 0x05, 0x05, 0x00, 0x30, 0x45, 0x31,
0x0b, 0x30, 0x09, 0x06, 0x03, 0x55, 0x04, 0x06,
0x13, 0x02, 0x41, 0x55, 0x31, 0x13, 0x30, 0x11,
0x06, 0x03, 0x55, 0x04, 0x08, 0x13, 0x0a, 0x53,
0x6f, 0x6d, 0x65, 0x2d, 0x53, 0x74, 0x61, 0x74,
0x65, 0x31, 0x21, 0x30, 0x1f, 0x06, 0x03, 0x55,
0x04, 0x0a, 0x13, 0x18, 0x49, 0x6e, 0x74, 0x65,
0x72, 0x6e, 0x65, 0x74, 0x20, 0x57, 0x69, 0x64,
0x67, 0x69, 0x74, 0x73, 0x20, 0x50, 0x74, 0x79,
0x20, 0x4c, 0x74, 0x64, 0x30, 0x1e, 0x17, 0x0d,
0x31, 0x30, 0x30, 0x34, 0x32, 0x34, 0x30, 0x39,
0x30, 0x39, 0x33, 0x38, 0x5a, 0x17, 0x0d, 0x31,
0x31, 0x30, 0x34, 0x32, 0x34, 0x30, 0x39, 0x30,
0x39, 0x33, 0x38, 0x5a, 0x30, 0x45, 0x31, 0x0b,
0x30, 0x09, 0x06, 0x03, 0x55, 0x04, 0x06, 0x13,
0x02, 0x41, 0x55, 0x31, 0x13, 0x30, 0x11, 0x06,
0x03, 0x55, 0x04, 0x08, 0x13, 0x0a, 0x53, 0x6f,
0x6d, 0x65, 0x2d, 0x53, 0x74, 0x61, 0x74, 0x65,
0x31, 0x21, 0x30, 0x1f, 0x06, 0x03, 0x55, 0x04,
0x0a, 0x13, 0x18, 0x49, 0x6e, 0x74, 0x65, 0x72,
0x6e, 0x65, 0x74, 0x20, 0x57, 0x69, 0x64, 0x67,
0x69, 0x74, 0x73, 0x20, 0x50, 0x74, 0x79, 0x20,
0x4c, 0x74, 0x64, 0x30, 0x81, 0x9f, 0x30, 0x0d,
0x06, 0x09, 0x2a, 0x86, 0x48, 0x86, 0xf7, 0x0d,
0x01, 0x01, 0x01, 0x05, 0x00, 0x03, 0x81, 0x8d,
0x00, 0x30, 0x81, 0x89, 0x02, 0x81, 0x81, 0x00,
0xbb, 0x79, 0xd6, 0xf5, 0x17, 0xb5, 0xe5, 0xbf,
0x46, 0x10, 0xd0, 0xdc, 0x69, 0xbe, 0xe6, 0x2b,
0x07, 0x43, 0x5a, 0xd0, 0x03, 0x2d, 0x8a, 0x7a,
0x43, 0x85, 0xb7, 0x14, 0x52, 0xe7, 0xa5, 0x65,
0x4c, 0x2c, 0x78, 0xb8, 0x23, 0x8c, 0xb5, 0xb4,
0x82, 0xe5, 0xde, 0x1f, 0x95, 0x3b, 0x7e, 0x62,
0xa5, 0x2c, 0xa5, 0x33, 0xd6, 0xfe, 0x12, 0x5c,
0x7a, 0x56, 0xfc, 0xf5, 0x06, 0xbf, 0xfa, 0x58,
0x7b, 0x26, 0x3f, 0xb5, 0xcd, 0x04, 0xd3, 0xd0,
0xc9, 0x21, 0x96, 0x4a, 0xc7, 0xf4, 0x54, 0x9f,
0x5a, 0xbf, 0xef, 0x42, 0x71, 0x00, 0xfe, 0x18,
0x99, 0x07, 0x7f, 0x7e, 0x88, 0x7d, 0x7d, 0xf1,
0x04, 0x39, 0xc4, 0xa2, 0x2e, 0xdb, 0x51, 0xc9,
0x7c, 0xe3, 0xc0, 0x4c, 0x3b, 0x32, 0x66, 0x01,
0xcf, 0xaf, 0xb1, 0x1d, 0xb8, 0x71, 0x9a, 0x1d,
0xdb, 0xdb, 0x89, 0x6b, 0xae, 0xda, 0x2d, 0x79,
0x02, 0x03, 0x01, 0x00, 0x01, 0xa3, 0x81, 0xa7,
0x30, 0x81, 0xa4, 0x30, 0x1d, 0x06, 0x03, 0x55,
0x1d, 0x0e, 0x04, 0x16, 0x04, 0x14, 0xb1, 0xad,
0xe2, 0x85, 0x5a, 0xcf, 0xcb, 0x28, 0xdb, 0x69,
0xce, 0x23, 0x69, 0xde, 0xd3, 0x26, 0x8e, 0x18,
0x88, 0x39, 0x30, 0x75, 0x06, 0x03, 0x55, 0x1d,
0x23, 0x04, 0x6e, 0x30, 0x6c, 0x80, 0x14, 0xb1,
0xad, 0xe2, 0x85, 0x5a, 0xcf, 0xcb, 0x28, 0xdb,
0x69, 0xce, 0x23, 0x69, 0xde, 0xd3, 0x26, 0x8e,
0x18, 0x88, 0x39, 0xa1, 0x49, 0xa4, 0x47, 0x30,
0x45, 0x31, 0x0b, 0x30, 0x09, 0x06, 0x03, 0x55,
0x04, 0x06, 0x13, 0x02, 0x41, 0x55, 0x31, 0x13,
0x30, 0x11, 0x06, 0x03, 0x55, 0x04, 0x08, 0x13,
0x0a, 0x53, 0x6f, 0x6d, 0x65, 0x2d, 0x53, 0x74,
0x61, 0x74, 0x65, 0x31, 0x21, 0x30, 0x1f, 0x06,
0x03, 0x55, 0x04, 0x0a, 0x13, 0x18, 0x49, 0x6e,
0x74, 0x65, 0x72, 0x6e, 0x65, 0x74, 0x20, 0x57,
0x69, 0x64, 0x67, 0x69, 0x74, 0x73, 0x20, 0x50,
0x74, 0x79, 0x20, 0x4c, 0x74, 0x64, 0x82, 0x09,
0x00, 0x85, 0xb0, 0xbb, 0xa4, 0x8a, 0x7f, 0xb8,
0xca, 0x30, 0x0c, 0x06, 0x03, 0x55, 0x1d, 0x13,
0x04, 0x05, 0x30, 0x03, 0x01, 0x01, 0xff, 0x30,
0x0d, 0x06, 0x09, 0x2a, 0x86, 0x48, 0x86, 0xf7,
0x0d, 0x01, 0x01, 0x05, 0x05, 0x00, 0x03, 0x81,
0x81, 0x00, 0x08, 0x6c, 0x45, 0x24, 0xc7, 0x6b,
0xb1, 0x59, 0xab, 0x0c, 0x52, 0xcc, 0xf2, 0xb0,
0x14, 0xd7, 0x87, 0x9d, 0x7a, 0x64, 0x75, 0xb5,
0x5a, 0x95, 0x66, 0xe4, 0xc5, 0x2b, 0x8e, 0xae,
0x12, 0x66, 0x1f, 0xeb, 0x4f, 0x38, 0xb3, 0x6e,
0x60, 0xd3, 0x92, 0xfd, 0xf7, 0x41, 0x08, 0xb5,
0x25, 0x13, 0xb1, 0x18, 0x7a, 0x24, 0xfb, 0x30,
0x1d, 0xba, 0xed, 0x98, 0xb9, 0x17, 0xec, 0xe7,
0xd7, 0x31, 0x59, 0xdb, 0x95, 0xd3, 0x1d, 0x78,
0xea, 0x50, 0x56, 0x5c, 0xd5, 0x82, 0x5a, 0x2d,
0x5a, 0x5f, 0x33, 0xc4, 0xb6, 0xd8, 0xc9, 0x75,
0x90, 0x96, 0x8c, 0x0f, 0x52, 0x98, 0xb5, 0xcd,
0x98, 0x1f, 0x89, 0x20, 0x5f, 0xf2, 0xa0, 0x1c,
0xa3, 0x1b, 0x96, 0x94, 0xdd, 0xa9, 0xfd, 0x57,
0xe9, 0x70, 0xe8, 0x26, 0x6d, 0x71, 0x99, 0x9b,
0x26, 0x6e, 0x38, 0x50, 0x29, 0x6c, 0x90, 0xa7,
0xbd, 0xd9, 0x16, 0x03, 0x01, 0x00, 0x08, 0x0d,
0x00, 0x00, 0x04, 0x01, 0x01, 0x00, 0x00, 0x16,
0x03, 0x01, 0x00, 0x04, 0x0e, 0x00, 0x00, 0x00,
},
{
0x16, 0x03, 0x01, 0x01, 0xfb, 0x0b, 0x00, 0x01,
0xf7, 0x00, 0x01, 0xf4, 0x00, 0x01, 0xf1, 0x30,
0x82, 0x01, 0xed, 0x30, 0x82, 0x01, 0x58, 0xa0,
0x03, 0x02, 0x01, 0x02, 0x02, 0x01, 0x00, 0x30,
0x0b, 0x06, 0x09, 0x2a, 0x86, 0x48, 0x86, 0xf7,
0x0d, 0x01, 0x01, 0x05, 0x30, 0x26, 0x31, 0x10,
0x30, 0x0e, 0x06, 0x03, 0x55, 0x04, 0x0a, 0x13,
0x07, 0x41, 0x63, 0x6d, 0x65, 0x20, 0x43, 0x6f,
0x31, 0x12, 0x30, 0x10, 0x06, 0x03, 0x55, 0x04,
0x03, 0x13, 0x09, 0x31, 0x32, 0x37, 0x2e, 0x30,
0x2e, 0x30, 0x2e, 0x31, 0x30, 0x1e, 0x17, 0x0d,
0x31, 0x31, 0x31, 0x32, 0x30, 0x38, 0x30, 0x37,
0x35, 0x35, 0x31, 0x32, 0x5a, 0x17, 0x0d, 0x31,
0x32, 0x31, 0x32, 0x30, 0x37, 0x30, 0x38, 0x30,
0x30, 0x31, 0x32, 0x5a, 0x30, 0x26, 0x31, 0x10,
0x30, 0x0e, 0x06, 0x03, 0x55, 0x04, 0x0a, 0x13,
0x07, 0x41, 0x63, 0x6d, 0x65, 0x20, 0x43, 0x6f,
0x31, 0x12, 0x30, 0x10, 0x06, 0x03, 0x55, 0x04,
0x03, 0x13, 0x09, 0x31, 0x32, 0x37, 0x2e, 0x30,
0x2e, 0x30, 0x2e, 0x31, 0x30, 0x81, 0x9c, 0x30,
0x0b, 0x06, 0x09, 0x2a, 0x86, 0x48, 0x86, 0xf7,
0x0d, 0x01, 0x01, 0x01, 0x03, 0x81, 0x8c, 0x00,
0x30, 0x81, 0x88, 0x02, 0x81, 0x80, 0x4e, 0xd0,
0x7b, 0x31, 0xe3, 0x82, 0x64, 0xd9, 0x59, 0xc0,
0xc2, 0x87, 0xa4, 0x5e, 0x1e, 0x8b, 0x73, 0x33,
0xc7, 0x63, 0x53, 0xdf, 0x66, 0x92, 0x06, 0x84,
0xf6, 0x64, 0xd5, 0x8f, 0xe4, 0x36, 0xa7, 0x1d,
0x2b, 0xe8, 0xb3, 0x20, 0x36, 0x45, 0x23, 0xb5,
0xe3, 0x95, 0xae, 0xed, 0xe0, 0xf5, 0x20, 0x9c,
0x8d, 0x95, 0xdf, 0x7f, 0x5a, 0x12, 0xef, 0x87,
0xe4, 0x5b, 0x68, 0xe4, 0xe9, 0x0e, 0x74, 0xec,
0x04, 0x8a, 0x7f, 0xde, 0x93, 0x27, 0xc4, 0x01,
0x19, 0x7a, 0xbd, 0xf2, 0xdc, 0x3d, 0x14, 0xab,
0xd0, 0x54, 0xca, 0x21, 0x0c, 0xd0, 0x4d, 0x6e,
0x87, 0x2e, 0x5c, 0xc5, 0xd2, 0xbb, 0x4d, 0x4b,
0x4f, 0xce, 0xb6, 0x2c, 0xf7, 0x7e, 0x88, 0xec,
0x7c, 0xd7, 0x02, 0x91, 0x74, 0xa6, 0x1e, 0x0c,
0x1a, 0xda, 0xe3, 0x4a, 0x5a, 0x2e, 0xde, 0x13,
0x9c, 0x4c, 0x40, 0x88, 0x59, 0x93, 0x02, 0x03,
0x01, 0x00, 0x01, 0xa3, 0x32, 0x30, 0x30, 0x30,
0x0e, 0x06, 0x03, 0x55, 0x1d, 0x0f, 0x01, 0x01,
0xff, 0x04, 0x04, 0x03, 0x02, 0x00, 0xa0, 0x30,
0x0d, 0x06, 0x03, 0x55, 0x1d, 0x0e, 0x04, 0x06,
0x04, 0x04, 0x01, 0x02, 0x03, 0x04, 0x30, 0x0f,
0x06, 0x03, 0x55, 0x1d, 0x23, 0x04, 0x08, 0x30,
0x06, 0x80, 0x04, 0x01, 0x02, 0x03, 0x04, 0x30,
0x0b, 0x06, 0x09, 0x2a, 0x86, 0x48, 0x86, 0xf7,
0x0d, 0x01, 0x01, 0x05, 0x03, 0x81, 0x81, 0x00,
0x36, 0x1f, 0xb3, 0x7a, 0x0c, 0x75, 0xc9, 0x6e,
0x37, 0x46, 0x61, 0x2b, 0xd5, 0xbd, 0xc0, 0xa7,
0x4b, 0xcc, 0x46, 0x9a, 0x81, 0x58, 0x7c, 0x85,
0x79, 0x29, 0xc8, 0xc8, 0xc6, 0x67, 0xdd, 0x32,
0x56, 0x45, 0x2b, 0x75, 0xb6, 0xe9, 0x24, 0xa9,
0x50, 0x9a, 0xbe, 0x1f, 0x5a, 0xfa, 0x1a, 0x15,
0xd9, 0xcc, 0x55, 0x95, 0x72, 0x16, 0x83, 0xb9,
0xc2, 0xb6, 0x8f, 0xfd, 0x88, 0x8c, 0x38, 0x84,
0x1d, 0xab, 0x5d, 0x92, 0x31, 0x13, 0x4f, 0xfd,
0x83, 0x3b, 0xc6, 0x9d, 0xf1, 0x11, 0x62, 0xb6,
0x8b, 0xec, 0xab, 0x67, 0xbe, 0xc8, 0x64, 0xb0,
0x11, 0x50, 0x46, 0x58, 0x17, 0x6b, 0x99, 0x1c,
0xd3, 0x1d, 0xfc, 0x06, 0xf1, 0x0e, 0xe5, 0x96,
0xa8, 0x0c, 0xf9, 0x78, 0x20, 0xb7, 0x44, 0x18,
0x51, 0x8d, 0x10, 0x7e, 0x4f, 0x94, 0x67, 0xdf,
0xa3, 0x4e, 0x70, 0x73, 0x8e, 0x90, 0x91, 0x85,
0x16, 0x03, 0x01, 0x00, 0x86, 0x10, 0x00, 0x00,
0x82, 0x00, 0x80, 0xa7, 0x2f, 0xed, 0xfa, 0xc2,
0xbd, 0x46, 0xa1, 0xf2, 0x69, 0xc5, 0x1d, 0xa1,
0x34, 0xd6, 0xd0, 0x84, 0xf5, 0x5d, 0x8c, 0x82,
0x8d, 0x98, 0x82, 0x9c, 0xd9, 0x07, 0xe0, 0xf7,
0x55, 0x49, 0x4d, 0xa1, 0x48, 0x59, 0x02, 0xd3,
0x84, 0x37, 0xaf, 0x01, 0xb3, 0x3a, 0xf4, 0xed,
0x99, 0xbe, 0x67, 0x36, 0x19, 0x55, 0xf3, 0xf9,
0xcb, 0x94, 0xe5, 0x7b, 0x8b, 0x77, 0xf2, 0x5f,
0x4c, 0xfe, 0x01, 0x1f, 0x7b, 0xd7, 0x23, 0x49,
0x0c, 0xcb, 0x6c, 0xb0, 0xe7, 0x77, 0xd6, 0xcf,
0xa8, 0x7d, 0xdb, 0xa7, 0x14, 0xe2, 0xf5, 0xf3,
0xff, 0xba, 0x23, 0xd2, 0x9a, 0x36, 0x14, 0x60,
0x2a, 0x91, 0x5d, 0x2b, 0x35, 0x3b, 0xb6, 0xdd,
0xcb, 0x6b, 0xdc, 0x18, 0xdc, 0x33, 0xb8, 0xb3,
0xc7, 0x27, 0x7e, 0xfc, 0xd2, 0xf7, 0x97, 0x90,
0x5e, 0x17, 0xac, 0x14, 0x8e, 0x0f, 0xca, 0xb5,
0x6f, 0xc9, 0x2d, 0x16, 0x03, 0x01, 0x00, 0x86,
0x0f, 0x00, 0x00, 0x82, 0x00, 0x80, 0x44, 0x7f,
0xa2, 0x59, 0x60, 0x0b, 0x5a, 0xc4, 0xaf, 0x1e,
0x60, 0xa5, 0x24, 0xea, 0xc1, 0xc3, 0x22, 0x21,
0x6b, 0x22, 0x8b, 0x2a, 0x11, 0x82, 0x68, 0x7d,
0xb9, 0xdd, 0x9c, 0x27, 0x4c, 0xc2, 0xc8, 0xa2,
0x8b, 0x6b, 0x77, 0x8d, 0x3a, 0x2b, 0x8d, 0x2f,
0x6a, 0x2b, 0x43, 0xd2, 0xd1, 0xc6, 0x41, 0x79,
0xa2, 0x4f, 0x2b, 0xc2, 0xf7, 0xb2, 0x10, 0xad,
0xa6, 0x01, 0x51, 0x51, 0x25, 0xe7, 0x58, 0x7a,
0xcf, 0x3b, 0xc4, 0x29, 0xb5, 0xe5, 0xa7, 0x83,
0xe6, 0xcb, 0x1e, 0xf3, 0x02, 0x0f, 0x53, 0x3b,
0xb5, 0x39, 0xef, 0x9c, 0x42, 0xe0, 0xa6, 0x9b,
0x2b, 0xdd, 0x60, 0xae, 0x0a, 0x73, 0x35, 0xbe,
0x26, 0x10, 0x1b, 0xe9, 0xe9, 0x61, 0xab, 0x20,
0xa5, 0x48, 0xc6, 0x60, 0xa6, 0x50, 0x3c, 0xfb,
0xa7, 0xca, 0xb0, 0x80, 0x95, 0x1e, 0xce, 0xc7,
0xbb, 0x68, 0x44, 0xdc, 0x0e, 0x0e, 0x14, 0x03,
0x01, 0x00, 0x01, 0x01, 0x16, 0x03, 0x01, 0x00,
0x24, 0xb6, 0xcd, 0x0c, 0x78, 0xfd, 0xd6, 0xff,
0xbe, 0x97, 0xd5, 0x0a, 0x7d, 0x4f, 0xa1, 0x03,
0x78, 0xc8, 0x61, 0x6f, 0xf2, 0x4b, 0xa8, 0x56,
0x4f, 0x3c, 0xa2, 0xd9, 0xd0, 0x20, 0x13, 0x1b,
0x8b, 0x36, 0xb7, 0x33, 0x9c,
},
{
0x14, 0x03, 0x01, 0x00, 0x01, 0x01, 0x16, 0x03,
0x01, 0x00, 0x24, 0xa3, 0x43, 0x94, 0xe7, 0xdf,
0xb6, 0xc3, 0x03, 0x9f, 0xc1, 0x59, 0x0c, 0xc3,
0x13, 0xae, 0xed, 0xcf, 0xff, 0xf1, 0x80, 0xf3,
0x13, 0x63, 0x1c, 0xf0, 0xca, 0xad, 0x9e, 0x71,
0x46, 0x5f, 0x6b, 0xeb, 0x10, 0x3f, 0xe3, 0x17,
0x03, 0x01, 0x00, 0x21, 0xe9, 0x80, 0x95, 0x6e,
0x05, 0x55, 0x2f, 0xed, 0x4d, 0xde, 0x17, 0x3a,
0x32, 0x9b, 0x2a, 0x74, 0x30, 0x4f, 0xe0, 0x9f,
0x4e, 0xd3, 0x06, 0xbd, 0x3a, 0x43, 0x75, 0x8b,
0x5b, 0x9a, 0xd8, 0x2e, 0x56, 0x15, 0x03, 0x01,
0x00, 0x16, 0x53, 0xf5, 0xff, 0xe0, 0xa1, 0x6c,
0x33, 0xf4, 0x4e, 0x89, 0x68, 0xe1, 0xf7, 0x61,
0x13, 0xb3, 0x12, 0xa1, 0x8e, 0x5a, 0x7a, 0x02,
}}},
}
// cert.pem and key.pem were generated with generate_cert.go
// Thus, they have no ExtKeyUsage fields and trigger an error
// when verification is turned on.
var clicert = loadPEMCert(`
-----BEGIN CERTIFICATE-----
MIIB7TCCAVigAwIBAgIBADALBgkqhkiG9w0BAQUwJjEQMA4GA1UEChMHQWNtZSBD
bzESMBAGA1UEAxMJMTI3LjAuMC4xMB4XDTExMTIwODA3NTUxMloXDTEyMTIwNzA4
MDAxMlowJjEQMA4GA1UEChMHQWNtZSBDbzESMBAGA1UEAxMJMTI3LjAuMC4xMIGc
MAsGCSqGSIb3DQEBAQOBjAAwgYgCgYBO0Hsx44Jk2VnAwoekXh6LczPHY1PfZpIG
hPZk1Y/kNqcdK+izIDZFI7Xjla7t4PUgnI2V339aEu+H5Fto5OkOdOwEin/ekyfE
ARl6vfLcPRSr0FTKIQzQTW6HLlzF0rtNS0/Otiz3fojsfNcCkXSmHgwa2uNKWi7e
E5xMQIhZkwIDAQABozIwMDAOBgNVHQ8BAf8EBAMCAKAwDQYDVR0OBAYEBAECAwQw
DwYDVR0jBAgwBoAEAQIDBDALBgkqhkiG9w0BAQUDgYEANh+zegx1yW43RmEr1b3A
p0vMRpqBWHyFeSnIyMZn3TJWRSt1tukkqVCavh9a+hoV2cxVlXIWg7nCto/9iIw4
hB2rXZIxE0/9gzvGnfERYraL7KtnvshksBFQRlgXa5kc0x38BvEO5ZaoDPl4ILdE
GFGNEH5PlGffo05wc46QkYU=
-----END CERTIFICATE-----
`)
/* corresponding key.pem for cert.pem is:
-----BEGIN RSA PRIVATE KEY-----
MIICXAIBAAKBgE7QezHjgmTZWcDCh6ReHotzM8djU99mkgaE9mTVj+Q2px0r6LMg
NkUjteOVru3g9SCcjZXff1oS74fkW2jk6Q507ASKf96TJ8QBGXq98tw9FKvQVMoh
DNBNbocuXMXSu01LT862LPd+iOx81wKRdKYeDBra40paLt4TnExAiFmTAgMBAAEC
gYBxvXd8yNteFTns8A/2yomEMC4yeosJJSpp1CsN3BJ7g8/qTnrVPxBy+RU+qr63
t2WquaOu/cr5P8iEsa6lk20tf8pjKLNXeX0b1RTzK8rJLbS7nGzP3tvOhL096VtQ
dAo4ROEaro0TzYpHmpciSvxVIeEIAAdFDObDJPKqcJAxyQJBAJizfYgK8Gzx9fsx
hxp+VteCbVPg2euASH5Yv3K5LukRdKoSzHE2grUVQgN/LafC0eZibRanxHegYSr7
7qaswKUCQQCEIWor/X4XTMdVj3Oj+vpiw75y/S9gh682+myZL+d/02IEkwnB098P
RkKVpenBHyrGg0oeN5La7URILWKj7CPXAkBKo6F+d+phNjwIFoN1Xb/RA32w/D1I
saG9sF+UEhRt9AxUfW/U/tIQ9V0ZHHcSg1XaCM5Nvp934brdKdvTOKnJAkBD5h/3
Rybatlvg/fzBEaJFyq09zhngkxlZOUtBVTqzl17RVvY2orgH02U4HbCHy4phxOn7
qTdQRYlHRftgnWK1AkANibn9PRYJ7mJyJ9Dyj2QeNcSkSTzrt0tPvUMf4+meJymN
1Ntu5+S1DLLzfxlaljWG6ylW6DNxujCyuXIV2rvAMAA=
-----END RSA PRIVATE KEY-----
*/

View File

@ -120,7 +120,7 @@ func Dial(network, addr string, config *Config) (*Conn, error) {
// LoadX509KeyPair reads and parses a public/private key pair from a pair of // LoadX509KeyPair reads and parses a public/private key pair from a pair of
// files. The files must contain PEM encoded data. // files. The files must contain PEM encoded data.
func LoadX509KeyPair(certFile string, keyFile string) (cert Certificate, err error) { func LoadX509KeyPair(certFile, keyFile string) (cert Certificate, err error) {
certPEMBlock, err := ioutil.ReadFile(certFile) certPEMBlock, err := ioutil.ReadFile(certFile)
if err != nil { if err != nil {
return return
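For context on the signature cleanup above, here is a hypothetical client-side use of LoadX509KeyPair together with Dial; the file names, address, and InsecureSkipVerify setting are test-only placeholders, not part of the change.

package main

import (
	"crypto/tls"
	"log"
)

func main() {
	cert, err := tls.LoadX509KeyPair("client-cert.pem", "client-key.pem")
	if err != nil {
		log.Fatal(err)
	}
	conn, err := tls.Dial("tcp", "localhost:10443", &tls.Config{
		Certificates:       []tls.Certificate{cert},
		InsecureSkipVerify: true, // test setup only: skip server cert verification
	})
	if err != nil {
		log.Fatal(err)
	}
	defer conn.Close()
	if _, err := conn.Write([]byte("hello\n")); err != nil {
		log.Fatal(err)
	}
}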

View File

@ -101,3 +101,13 @@ func (s *CertPool) AppendCertsFromPEM(pemCerts []byte) (ok bool) {
return return
} }
// Subjects returns a list of the DER-encoded subjects of
// all of the certificates in the pool.
func (s *CertPool) Subjects() (res [][]byte) {
res = make([][]byte, len(s.certs))
for i, c := range s.certs {
res[i] = c.RawSubject
}
return
}
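A small, assumed usage sketch of the new Subjects method (the file name is a placeholder): a pool is filled from a PEM bundle and the raw DER subjects are listed, which is, for example, what a TLS server can advertise in its CertificateRequest message.

package main

import (
	"crypto/x509"
	"fmt"
	"io/ioutil"
	"log"
)

func main() {
	pool := x509.NewCertPool()
	pemData, err := ioutil.ReadFile("ca.pem") // placeholder CA bundle
	if err != nil {
		log.Fatal(err)
	}
	if !pool.AppendCertsFromPEM(pemData) {
		log.Fatal("no certificates parsed")
	}
	for _, der := range pool.Subjects() {
		fmt.Printf("subject: %x\n", der) // DER-encoded RDN sequence
	}
}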

View File

@ -7,14 +7,14 @@ package gosym
import ( import (
"debug/elf" "debug/elf"
"os" "os"
"syscall" "runtime"
"testing" "testing"
) )
func dotest() bool { func dotest() bool {
// For now, only works on ELF platforms. // For now, only works on ELF platforms.
// TODO: convert to work with new go tool // TODO: convert to work with new go tool
return false && syscall.OS == "linux" && os.Getenv("GOARCH") == "amd64" return false && runtime.GOOS == "linux" && runtime.GOARCH == "amd64"
} }
func getTable(t *testing.T) *Table { func getTable(t *testing.T) *Table {

View File

@ -786,7 +786,8 @@ func setDefaultValue(v reflect.Value, params fieldParameters) (ok bool) {
// Because Unmarshal uses the reflect package, the structs // Because Unmarshal uses the reflect package, the structs
// being written to must use upper case field names. // being written to must use upper case field names.
// //
// An ASN.1 INTEGER can be written to an int, int32 or int64. // An ASN.1 INTEGER can be written to an int, int32, int64,
// or *big.Int (from the math/big package).
// If the encoded value does not fit in the Go type, // If the encoded value does not fit in the Go type,
// Unmarshal returns a parse error. // Unmarshal returns a parse error.
// //
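A hedged illustration of the documentation change above, reusing the DER bytes from the new test case: an ASN.1 INTEGER unmarshalled into a *big.Int field.

package main

import (
	"encoding/asn1"
	"fmt"
	"log"
	"math/big"
)

type bigValue struct {
	X *big.Int
}

func main() {
	// SEQUENCE { INTEGER 0x123456 }, the same encoding used in the new test.
	der := []byte{0x30, 0x05, 0x02, 0x03, 0x12, 0x34, 0x56}
	var v bigValue
	if _, err := asn1.Unmarshal(der, &v); err != nil {
		log.Fatal(err)
	}
	fmt.Println(v.X) // 1193046 (0x123456)
}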

View File

@ -6,6 +6,7 @@ package asn1
import ( import (
"bytes" "bytes"
"math/big"
"reflect" "reflect"
"testing" "testing"
"time" "time"
@ -351,6 +352,10 @@ type TestElementsAfterString struct {
A, B int A, B int
} }
type TestBigInt struct {
X *big.Int
}
var unmarshalTestData = []struct { var unmarshalTestData = []struct {
in []byte in []byte
out interface{} out interface{}
@ -369,6 +374,7 @@ var unmarshalTestData = []struct {
{[]byte{0x01, 0x01, 0x00}, newBool(false)}, {[]byte{0x01, 0x01, 0x00}, newBool(false)},
{[]byte{0x01, 0x01, 0x01}, newBool(true)}, {[]byte{0x01, 0x01, 0x01}, newBool(true)},
{[]byte{0x30, 0x0b, 0x13, 0x03, 0x66, 0x6f, 0x6f, 0x02, 0x01, 0x22, 0x02, 0x01, 0x33}, &TestElementsAfterString{"foo", 0x22, 0x33}}, {[]byte{0x30, 0x0b, 0x13, 0x03, 0x66, 0x6f, 0x6f, 0x02, 0x01, 0x22, 0x02, 0x01, 0x33}, &TestElementsAfterString{"foo", 0x22, 0x33}},
{[]byte{0x30, 0x05, 0x02, 0x03, 0x12, 0x34, 0x56}, &TestBigInt{big.NewInt(0x123456)}},
} }
func TestUnmarshal(t *testing.T) { func TestUnmarshal(t *testing.T) {

View File

@ -7,6 +7,7 @@ package asn1
import ( import (
"bytes" "bytes"
"encoding/hex" "encoding/hex"
"math/big"
"testing" "testing"
"time" "time"
) )
@ -20,6 +21,10 @@ type twoIntStruct struct {
B int B int
} }
type bigIntStruct struct {
A *big.Int
}
type nestedStruct struct { type nestedStruct struct {
A intStruct A intStruct
} }
@ -65,6 +70,7 @@ var marshalTests = []marshalTest{
{-128, "020180"}, {-128, "020180"},
{-129, "0202ff7f"}, {-129, "0202ff7f"},
{intStruct{64}, "3003020140"}, {intStruct{64}, "3003020140"},
{bigIntStruct{big.NewInt(0x123456)}, "30050203123456"},
{twoIntStruct{64, 65}, "3006020140020141"}, {twoIntStruct{64, 65}, "3006020140020141"},
{nestedStruct{intStruct{127}}, "3005300302017f"}, {nestedStruct{intStruct{127}}, "3005300302017f"},
{[]byte{1, 2, 3}, "0403010203"}, {[]byte{1, 2, 3}, "0403010203"},

View File

@ -1039,9 +1039,9 @@ func (dec *Decoder) compatibleType(fr reflect.Type, fw typeId, inProgress map[re
// Extract and compare element types. // Extract and compare element types.
var sw *sliceType var sw *sliceType
if tt, ok := builtinIdToType[fw]; ok { if tt, ok := builtinIdToType[fw]; ok {
sw = tt.(*sliceType) sw, _ = tt.(*sliceType)
} else { } else if wire != nil {
sw = dec.wireType[fw].SliceT sw = wire.SliceT
} }
elem := userType(t.Elem()).base elem := userType(t.Elem()).base
return sw != nil && dec.compatibleType(elem, sw.Elem, inProgress) return sw != nil && dec.compatibleType(elem, sw.Elem, inProgress)

View File

@ -678,3 +678,11 @@ func TestUnexportedChan(t *testing.T) {
t.Fatalf("error encoding unexported channel: %s", err) t.Fatalf("error encoding unexported channel: %s", err)
} }
} }
func TestSliceIncompatibility(t *testing.T) {
var in = []byte{1, 2, 3}
var out []int
if err := encAndDec(in, &out); err == nil {
t.Error("expected compatibility error")
}
}
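Outside the package, the same incompatibility looks like the sketch below (assuming the usual encoder/decoder round trip): decoding a transmitted []byte into a []int should now surface an error through the safer type lookup above rather than crashing.

package main

import (
	"bytes"
	"encoding/gob"
	"fmt"
	"log"
)

func main() {
	var buf bytes.Buffer
	if err := gob.NewEncoder(&buf).Encode([]byte{1, 2, 3}); err != nil {
		log.Fatal(err)
	}
	var out []int
	err := gob.NewDecoder(&buf).Decode(&out)
	fmt.Println(err) // expect a type-compatibility error, not a panic
}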

View File

@ -10,6 +10,7 @@ package json
import ( import (
"encoding/base64" "encoding/base64"
"errors" "errors"
"fmt"
"reflect" "reflect"
"runtime" "runtime"
"strconv" "strconv"
@ -538,7 +539,7 @@ func (d *decodeState) object(v reflect.Value) {
// Read value. // Read value.
if destring { if destring {
d.value(reflect.ValueOf(&d.tempstr)) d.value(reflect.ValueOf(&d.tempstr))
d.literalStore([]byte(d.tempstr), subv) d.literalStore([]byte(d.tempstr), subv, true)
} else { } else {
d.value(subv) d.value(subv)
} }
@ -571,11 +572,15 @@ func (d *decodeState) literal(v reflect.Value) {
d.off-- d.off--
d.scan.undo(op) d.scan.undo(op)
d.literalStore(d.data[start:d.off], v) d.literalStore(d.data[start:d.off], v, false)
} }
// literalStore decodes a literal stored in item into v. // literalStore decodes a literal stored in item into v.
func (d *decodeState) literalStore(item []byte, v reflect.Value) { //
// fromQuoted indicates whether this literal came from unwrapping a
// string from the ",string" struct tag option. This is used only to
// produce more helpful error messages.
func (d *decodeState) literalStore(item []byte, v reflect.Value, fromQuoted bool) {
// Check for unmarshaler. // Check for unmarshaler.
wantptr := item[0] == 'n' // null wantptr := item[0] == 'n' // null
unmarshaler, pv := d.indirect(v, wantptr) unmarshaler, pv := d.indirect(v, wantptr)
@ -601,7 +606,11 @@ func (d *decodeState) literalStore(item []byte, v reflect.Value) {
value := c == 't' value := c == 't'
switch v.Kind() { switch v.Kind() {
default: default:
d.saveError(&UnmarshalTypeError{"bool", v.Type()}) if fromQuoted {
d.saveError(fmt.Errorf("json: invalid use of ,string struct tag, trying to unmarshal %q into %v", item, v.Type()))
} else {
d.saveError(&UnmarshalTypeError{"bool", v.Type()})
}
case reflect.Bool: case reflect.Bool:
v.SetBool(value) v.SetBool(value)
case reflect.Interface: case reflect.Interface:
@ -611,7 +620,11 @@ func (d *decodeState) literalStore(item []byte, v reflect.Value) {
case '"': // string case '"': // string
s, ok := unquoteBytes(item) s, ok := unquoteBytes(item)
if !ok { if !ok {
d.error(errPhase) if fromQuoted {
d.error(fmt.Errorf("json: invalid use of ,string struct tag, trying to unmarshal %q into %v", item, v.Type()))
} else {
d.error(errPhase)
}
} }
switch v.Kind() { switch v.Kind() {
default: default:
@ -636,12 +649,20 @@ func (d *decodeState) literalStore(item []byte, v reflect.Value) {
default: // number default: // number
if c != '-' && (c < '0' || c > '9') { if c != '-' && (c < '0' || c > '9') {
d.error(errPhase) if fromQuoted {
d.error(fmt.Errorf("json: invalid use of ,string struct tag, trying to unmarshal %q into %v", item, v.Type()))
} else {
d.error(errPhase)
}
} }
s := string(item) s := string(item)
switch v.Kind() { switch v.Kind() {
default: default:
d.error(&UnmarshalTypeError{"number", v.Type()}) if fromQuoted {
d.error(fmt.Errorf("json: invalid use of ,string struct tag, trying to unmarshal %q into %v", item, v.Type()))
} else {
d.error(&UnmarshalTypeError{"number", v.Type()})
}
case reflect.Interface: case reflect.Interface:
n, err := strconv.ParseFloat(s, 64) n, err := strconv.ParseFloat(s, 64)
if err != nil { if err != nil {

View File

@ -258,13 +258,10 @@ type wrongStringTest struct {
in, err string in, err string
} }
// TODO(bradfitz): as part of Issue 2331, fix these tests' expected
// error values to be helpful, rather than the confusing messages they
// are now.
var wrongStringTests = []wrongStringTest{ var wrongStringTests = []wrongStringTest{
{`{"result":"x"}`, "JSON decoder out of sync - data changing underfoot?"}, {`{"result":"x"}`, `json: invalid use of ,string struct tag, trying to unmarshal "x" into string`},
{`{"result":"foo"}`, "json: cannot unmarshal bool into Go value of type string"}, {`{"result":"foo"}`, `json: invalid use of ,string struct tag, trying to unmarshal "foo" into string`},
{`{"result":"123"}`, "json: cannot unmarshal number into Go value of type string"}, {`{"result":"123"}`, `json: invalid use of ,string struct tag, trying to unmarshal "123" into string`},
} }
// If people misuse the ,string modifier, the error message should be // If people misuse the ,string modifier, the error message should be

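To make the improved messages concrete, here is a hypothetical misuse of the ",string" option on a string field; the struct and input are illustrative, and the wording shown is only expected to match the test strings above approximately.

package main

import (
	"encoding/json"
	"fmt"
)

type result struct {
	// ",string" expects a number or bool wrapped in a JSON string
	// (e.g. {"Result":"123"} into an int field); on a string field it is a misuse.
	Result string `json:",string"`
}

func main() {
	var r result
	err := json.Unmarshal([]byte(`{"Result":"x"}`), &r)
	fmt.Println(err)
	// roughly: json: invalid use of ,string struct tag, trying to unmarshal "x" into string
}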
View File

@ -12,6 +12,7 @@ package json
import ( import (
"bytes" "bytes"
"encoding/base64" "encoding/base64"
"math"
"reflect" "reflect"
"runtime" "runtime"
"sort" "sort"
@ -170,6 +171,15 @@ func (e *UnsupportedTypeError) Error() string {
return "json: unsupported type: " + e.Type.String() return "json: unsupported type: " + e.Type.String()
} }
type UnsupportedValueError struct {
Value reflect.Value
Str string
}
func (e *UnsupportedValueError) Error() string {
return "json: unsupported value: " + e.Str
}
type InvalidUTF8Error struct { type InvalidUTF8Error struct {
S string S string
} }
@ -290,7 +300,11 @@ func (e *encodeState) reflectValueQuoted(v reflect.Value, quoted bool) {
e.Write(b) e.Write(b)
} }
case reflect.Float32, reflect.Float64: case reflect.Float32, reflect.Float64:
b := strconv.AppendFloat(e.scratch[:0], v.Float(), 'g', -1, v.Type().Bits()) f := v.Float()
if math.IsInf(f, 0) || math.IsNaN(f) {
e.error(&UnsupportedValueError{v, strconv.FormatFloat(f, 'g', -1, v.Type().Bits())})
}
b := strconv.AppendFloat(e.scratch[:0], f, 'g', -1, v.Type().Bits())
if quoted { if quoted {
writeString(e, string(b)) writeString(e, string(b))
} else { } else {
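As a reading of the change above (a sketch that mirrors the new test in the next file): Marshal now fails with an UnsupportedValueError for NaN and the infinities, which have no JSON representation, instead of emitting invalid output.

package main

import (
	"encoding/json"
	"fmt"
	"math"
)

type sample struct {
	Ratio float64
}

func main() {
	_, err := json.Marshal(sample{Ratio: math.Inf(1)})
	fmt.Println(err) // expect a *json.UnsupportedValueError for +Inf
}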

View File

@ -6,6 +6,7 @@ package json
import ( import (
"bytes" "bytes"
"math"
"reflect" "reflect"
"testing" "testing"
) )
@ -107,3 +108,21 @@ func TestEncodeRenamedByteSlice(t *testing.T) {
t.Errorf(" got %s want %s", result, expect) t.Errorf(" got %s want %s", result, expect)
} }
} }
var unsupportedValues = []interface{}{
math.NaN(),
math.Inf(-1),
math.Inf(1),
}
func TestUnsupportedValues(t *testing.T) {
for _, v := range unsupportedValues {
if _, err := Marshal(v); err != nil {
if _, ok := err.(*UnsupportedValueError); !ok {
t.Errorf("for %v, got %T want UnsupportedValueError", v, err)
}
} else {
t.Errorf("for %v, expected error", v)
}
}
}

View File

@ -5,6 +5,7 @@
package xml package xml
var atomValue = &Feed{ var atomValue = &Feed{
XMLName: Name{"http://www.w3.org/2005/Atom", "feed"},
Title: "Example Feed", Title: "Example Feed",
Link: []Link{{Href: "http://example.org/"}}, Link: []Link{{Href: "http://example.org/"}},
Updated: ParseTime("2003-12-13T18:30:02Z"), Updated: ParseTime("2003-12-13T18:30:02Z"),
@ -24,19 +25,19 @@ var atomValue = &Feed{
var atomXml = `` + var atomXml = `` +
`<feed xmlns="http://www.w3.org/2005/Atom">` + `<feed xmlns="http://www.w3.org/2005/Atom">` +
`<Title>Example Feed</Title>` + `<title>Example Feed</title>` +
`<Id>urn:uuid:60a76c80-d399-11d9-b93C-0003939e0af6</Id>` + `<id>urn:uuid:60a76c80-d399-11d9-b93C-0003939e0af6</id>` +
`<Link href="http://example.org/"></Link>` + `<link href="http://example.org/"></link>` +
`<Updated>2003-12-13T18:30:02Z</Updated>` + `<updated>2003-12-13T18:30:02Z</updated>` +
`<Author><Name>John Doe</Name><URI></URI><Email></Email></Author>` + `<author><name>John Doe</name><uri></uri><email></email></author>` +
`<Entry>` + `<entry>` +
`<Title>Atom-Powered Robots Run Amok</Title>` + `<title>Atom-Powered Robots Run Amok</title>` +
`<Id>urn:uuid:1225c695-cfb8-4ebb-aaaa-80da344efa6a</Id>` + `<id>urn:uuid:1225c695-cfb8-4ebb-aaaa-80da344efa6a</id>` +
`<Link href="http://example.org/2003/12/13/atom03"></Link>` + `<link href="http://example.org/2003/12/13/atom03"></link>` +
`<Updated>2003-12-13T18:30:02Z</Updated>` + `<updated>2003-12-13T18:30:02Z</updated>` +
`<Author><Name></Name><URI></URI><Email></Email></Author>` + `<author><name></name><uri></uri><email></email></author>` +
`<Summary>Some text.</Summary>` + `<summary>Some text.</summary>` +
`</Entry>` + `</entry>` +
`</feed>` `</feed>`
func ParseTime(str string) Time { func ParseTime(str string) Time {

View File

@ -1,124 +0,0 @@
// Copyright 2010 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package xml
import "testing"
type C struct {
Name string
Open bool
}
type A struct {
XMLName Name `xml:"http://domain a"`
C
B B
FieldA string
}
type B struct {
XMLName Name `xml:"b"`
C
FieldB string
}
const _1a = `
<?xml version="1.0" encoding="UTF-8"?>
<a xmlns="http://domain">
<name>KmlFile</name>
<open>1</open>
<b>
<name>Absolute</name>
<open>0</open>
<fieldb>bar</fieldb>
</b>
<fielda>foo</fielda>
</a>
`
// Tests that embedded structs are marshalled.
func TestEmbedded1(t *testing.T) {
var a A
if e := Unmarshal(StringReader(_1a), &a); e != nil {
t.Fatalf("Unmarshal: %s", e)
}
if a.FieldA != "foo" {
t.Fatalf("Unmarshal: expected 'foo' but found '%s'", a.FieldA)
}
if a.Name != "KmlFile" {
t.Fatalf("Unmarshal: expected 'KmlFile' but found '%s'", a.Name)
}
if !a.Open {
t.Fatal("Unmarshal: expected 'true' but found otherwise")
}
if a.B.FieldB != "bar" {
t.Fatalf("Unmarshal: expected 'bar' but found '%s'", a.B.FieldB)
}
if a.B.Name != "Absolute" {
t.Fatalf("Unmarshal: expected 'Absolute' but found '%s'", a.B.Name)
}
if a.B.Open {
t.Fatal("Unmarshal: expected 'false' but found otherwise")
}
}
type A2 struct {
XMLName Name `xml:"http://domain a"`
XY string
Xy string
}
const _2a = `
<?xml version="1.0" encoding="UTF-8"?>
<a xmlns="http://domain">
<xy>foo</xy>
</a>
`
// Tests that conflicting field names get excluded.
func TestEmbedded2(t *testing.T) {
var a A2
if e := Unmarshal(StringReader(_2a), &a); e != nil {
t.Fatalf("Unmarshal: %s", e)
}
if a.XY != "" {
t.Fatalf("Unmarshal: expected empty string but found '%s'", a.XY)
}
if a.Xy != "" {
t.Fatalf("Unmarshal: expected empty string but found '%s'", a.Xy)
}
}
type A3 struct {
XMLName Name `xml:"http://domain a"`
xy string
}
// Tests that private fields are not set.
func TestEmbedded3(t *testing.T) {
var a A3
if e := Unmarshal(StringReader(_2a), &a); e != nil {
t.Fatalf("Unmarshal: %s", e)
}
if a.xy != "" {
t.Fatalf("Unmarshal: expected empty string but found '%s'", a.xy)
}
}
type A4 struct {
XMLName Name `xml:"http://domain a"`
Any string
}
// Tests that private fields are not set.
func TestEmbedded4(t *testing.T) {
var a A4
if e := Unmarshal(StringReader(_2a), &a); e != nil {
t.Fatalf("Unmarshal: %s", e)
}
if a.Any != "foo" {
t.Fatalf("Unmarshal: expected 'foo' but found '%s'", a.Any)
}
}

View File

@ -6,6 +6,8 @@ package xml
import ( import (
"bufio" "bufio"
"bytes"
"fmt"
"io" "io"
"reflect" "reflect"
"strconv" "strconv"
@ -42,20 +44,26 @@ type printer struct {
// elements containing the data. // elements containing the data.
// //
// The name for the XML elements is taken from, in order of preference: // The name for the XML elements is taken from, in order of preference:
// - the tag on an XMLName field, if the data is a struct // - the tag on the XMLName field, if the data is a struct
// - the value of an XMLName field of type xml.Name // - the value of the XMLName field of type xml.Name
// - the tag of the struct field used to obtain the data // - the tag of the struct field used to obtain the data
// - the name of the struct field used to obtain the data // - the name of the struct field used to obtain the data
// - the name '???'. // - the name of the marshalled type
// //
// The XML element for a struct contains marshalled elements for each of the // The XML element for a struct contains marshalled elements for each of the
// exported fields of the struct, with these exceptions: // exported fields of the struct, with these exceptions:
// - the XMLName field, described above, is omitted. // - the XMLName field, described above, is omitted.
// - a field with tag "attr" becomes an attribute in the XML element. // - a field with tag "name,attr" becomes an attribute with
// - a field with tag "chardata" is written as character data, // the given name in the XML element.
// not as an XML element. // - a field with tag ",attr" becomes an attribute with the
// - a field with tag "innerxml" is written verbatim, // field name in the XML element.
// not subject to the usual marshalling procedure. // - a field with tag ",chardata" is written as character data,
// not as an XML element.
// - a field with tag ",innerxml" is written verbatim, not subject
// to the usual marshalling procedure.
// - a field with tag ",comment" is written as an XML comment, not
// subject to the usual marshalling procedure. It must not contain
// the "--" string within it.
// //
// If a field uses a tag "a>b>c", then the element c will be nested inside // If a field uses a tag "a>b>c", then the element c will be nested inside
// parent elements a and b. Fields that appear next to each other that name // parent elements a and b. Fields that appear next to each other that name
@ -63,17 +71,18 @@ type printer struct {
// //
// type Result struct { // type Result struct {
// XMLName xml.Name `xml:"result"` // XMLName xml.Name `xml:"result"`
// Id int `xml:"id,attr"`
// FirstName string `xml:"person>name>first"` // FirstName string `xml:"person>name>first"`
// LastName string `xml:"person>name>last"` // LastName string `xml:"person>name>last"`
// Age int `xml:"person>age"` // Age int `xml:"person>age"`
// } // }
// //
// xml.Marshal(w, &Result{FirstName: "John", LastName: "Doe", Age: 42}) // xml.Marshal(w, &Result{Id: 13, FirstName: "John", LastName: "Doe", Age: 42})
// //
// would be marshalled as: // would be marshalled as:
// //
// <result> // <result>
// <person> // <person id="13">
// <name> // <name>
// <first>John</first> // <first>John</first>
// <last>Doe</last> // <last>Doe</last>
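A hedged sketch of the new tag forms described above ("name,attr", ",comment" and nested "a>b>c" elements); the struct, values, and expected output are illustrative only and do not come from this patch.

package main

import (
	"encoding/xml"
	"log"
	"os"
)

type Note struct {
	XMLName xml.Name `xml:"note"`
	Author  string   `xml:"author,attr"`
	Remark  string   `xml:",comment"`
	Body    string   `xml:"content>body"`
}

func main() {
	v := &Note{Author: "anon", Remark: "draft", Body: "hello"}
	// At this point in the tree Marshal still writes to an io.Writer.
	if err := xml.Marshal(os.Stdout, v); err != nil {
		log.Fatal(err)
	}
	// roughly: <note author="anon"><!--draft--><content><body>hello</body></content></note>
}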
@ -85,12 +94,12 @@ type printer struct {
// Marshal will return an error if asked to marshal a channel, function, or map. // Marshal will return an error if asked to marshal a channel, function, or map.
func Marshal(w io.Writer, v interface{}) (err error) { func Marshal(w io.Writer, v interface{}) (err error) {
p := &printer{bufio.NewWriter(w)} p := &printer{bufio.NewWriter(w)}
err = p.marshalValue(reflect.ValueOf(v), "???") err = p.marshalValue(reflect.ValueOf(v), nil)
p.Flush() p.Flush()
return err return err
} }
func (p *printer) marshalValue(val reflect.Value, name string) error { func (p *printer) marshalValue(val reflect.Value, finfo *fieldInfo) error {
if !val.IsValid() { if !val.IsValid() {
return nil return nil
} }
@ -115,58 +124,75 @@ func (p *printer) marshalValue(val reflect.Value, name string) error {
if val.IsNil() { if val.IsNil() {
return nil return nil
} }
return p.marshalValue(val.Elem(), name) return p.marshalValue(val.Elem(), finfo)
} }
// Slices and arrays iterate over the elements. They do not have an enclosing tag. // Slices and arrays iterate over the elements. They do not have an enclosing tag.
if (kind == reflect.Slice || kind == reflect.Array) && typ.Elem().Kind() != reflect.Uint8 { if (kind == reflect.Slice || kind == reflect.Array) && typ.Elem().Kind() != reflect.Uint8 {
for i, n := 0, val.Len(); i < n; i++ { for i, n := 0, val.Len(); i < n; i++ {
if err := p.marshalValue(val.Index(i), name); err != nil { if err := p.marshalValue(val.Index(i), finfo); err != nil {
return err return err
} }
} }
return nil return nil
} }
// Find XML name tinfo, err := getTypeInfo(typ)
xmlns := "" if err != nil {
if kind == reflect.Struct { return err
if f, ok := typ.FieldByName("XMLName"); ok { }
if tag := f.Tag.Get("xml"); tag != "" {
if i := strings.Index(tag, " "); i >= 0 { // Precedence for the XML element name is:
xmlns, name = tag[:i], tag[i+1:] // 1. XMLName field in underlying struct;
} else { // 2. field name/tag in the struct field; and
name = tag // 3. type name
} var xmlns, name string
} else if v, ok := val.FieldByIndex(f.Index).Interface().(Name); ok && v.Local != "" { if tinfo.xmlname != nil {
xmlns, name = v.Space, v.Local xmlname := tinfo.xmlname
} if xmlname.name != "" {
xmlns, name = xmlname.xmlns, xmlname.name
} else if v, ok := val.FieldByIndex(xmlname.idx).Interface().(Name); ok && v.Local != "" {
xmlns, name = v.Space, v.Local
}
}
if name == "" && finfo != nil {
xmlns, name = finfo.xmlns, finfo.name
}
if name == "" {
name = typ.Name()
if name == "" {
return &UnsupportedTypeError{typ}
} }
} }
p.WriteByte('<') p.WriteByte('<')
p.WriteString(name) p.WriteString(name)
// Attributes if xmlns != "" {
if kind == reflect.Struct { p.WriteString(` xmlns="`)
if len(xmlns) > 0 { // TODO: EscapeString, to avoid the allocation.
p.WriteString(` xmlns="`) Escape(p, []byte(xmlns))
Escape(p, []byte(xmlns)) p.WriteByte('"')
p.WriteByte('"') }
}
for i, n := 0, typ.NumField(); i < n; i++ { // Attributes
if f := typ.Field(i); f.PkgPath == "" && f.Tag.Get("xml") == "attr" { for i := range tinfo.fields {
if f.Type.Kind() == reflect.String { finfo := &tinfo.fields[i]
if str := val.Field(i).String(); str != "" { if finfo.flags&fAttr == 0 {
p.WriteByte(' ') continue
p.WriteString(strings.ToLower(f.Name)) }
p.WriteString(`="`) var str string
Escape(p, []byte(str)) if fv := val.FieldByIndex(finfo.idx); fv.Kind() == reflect.String {
p.WriteByte('"') str = fv.String()
} } else {
} str = fmt.Sprint(fv.Interface())
} }
if str != "" {
p.WriteByte(' ')
p.WriteString(finfo.name)
p.WriteString(`="`)
Escape(p, []byte(str))
p.WriteByte('"')
} }
} }
p.WriteByte('>') p.WriteByte('>')
@ -194,58 +220,9 @@ func (p *printer) marshalValue(val reflect.Value, name string) error {
bytes := val.Interface().([]byte) bytes := val.Interface().([]byte)
Escape(p, bytes) Escape(p, bytes)
case reflect.Struct: case reflect.Struct:
s := parentStack{printer: p} if err := p.marshalStruct(tinfo, val); err != nil {
for i, n := 0, val.NumField(); i < n; i++ { return err
if f := typ.Field(i); f.Name != "XMLName" && f.PkgPath == "" {
name := f.Name
vf := val.Field(i)
switch tag := f.Tag.Get("xml"); tag {
case "":
s.trim(nil)
case "chardata":
if tk := f.Type.Kind(); tk == reflect.String {
Escape(p, []byte(vf.String()))
} else if tk == reflect.Slice {
if elem, ok := vf.Interface().([]byte); ok {
Escape(p, elem)
}
}
continue
case "innerxml":
iface := vf.Interface()
switch raw := iface.(type) {
case []byte:
p.Write(raw)
continue
case string:
p.WriteString(raw)
continue
}
case "attr":
continue
default:
parents := strings.Split(tag, ">")
if len(parents) == 1 {
parents, name = nil, tag
} else {
parents, name = parents[:len(parents)-1], parents[len(parents)-1]
if parents[0] == "" {
parents[0] = f.Name
}
}
s.trim(parents)
if !(vf.Kind() == reflect.Ptr || vf.Kind() == reflect.Interface) || !vf.IsNil() {
s.push(parents[len(s.stack):])
}
}
if err := p.marshalValue(vf, name); err != nil {
return err
}
}
} }
s.trim(nil)
default: default:
return &UnsupportedTypeError{typ} return &UnsupportedTypeError{typ}
} }
@ -258,6 +235,94 @@ func (p *printer) marshalValue(val reflect.Value, name string) error {
return nil return nil
} }
var ddBytes = []byte("--")
func (p *printer) marshalStruct(tinfo *typeInfo, val reflect.Value) error {
s := parentStack{printer: p}
for i := range tinfo.fields {
finfo := &tinfo.fields[i]
if finfo.flags&(fAttr|fAny) != 0 {
continue
}
vf := val.FieldByIndex(finfo.idx)
switch finfo.flags & fMode {
case fCharData:
switch vf.Kind() {
case reflect.String:
Escape(p, []byte(vf.String()))
case reflect.Slice:
if elem, ok := vf.Interface().([]byte); ok {
Escape(p, elem)
}
}
continue
case fComment:
k := vf.Kind()
if !(k == reflect.String || k == reflect.Slice && vf.Type().Elem().Kind() == reflect.Uint8) {
return fmt.Errorf("xml: bad type for comment field of %s", val.Type())
}
if vf.Len() == 0 {
continue
}
p.WriteString("<!--")
dashDash := false
dashLast := false
switch k {
case reflect.String:
s := vf.String()
dashDash = strings.Index(s, "--") >= 0
dashLast = s[len(s)-1] == '-'
if !dashDash {
p.WriteString(s)
}
case reflect.Slice:
b := vf.Bytes()
dashDash = bytes.Index(b, ddBytes) >= 0
dashLast = b[len(b)-1] == '-'
if !dashDash {
p.Write(b)
}
default:
panic("can't happen")
}
if dashDash {
return fmt.Errorf(`xml: comments must not contain "--"`)
}
if dashLast {
// "--->" is invalid grammar. Make it "- -->"
p.WriteByte(' ')
}
p.WriteString("-->")
continue
case fInnerXml:
iface := vf.Interface()
switch raw := iface.(type) {
case []byte:
p.Write(raw)
continue
case string:
p.WriteString(raw)
continue
}
case fElement:
s.trim(finfo.parents)
if len(finfo.parents) > len(s.stack) {
if vf.Kind() != reflect.Ptr && vf.Kind() != reflect.Interface || !vf.IsNil() {
s.push(finfo.parents[len(s.stack):])
}
}
}
if err := p.marshalValue(vf, finfo); err != nil {
return err
}
}
s.trim(nil)
return nil
}
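The fComment branch above is what enforces the comment grammar: a body containing "--" is rejected outright, and a body ending in "-" is padded with a space so the closing "-->" stays well-formed. A rough sketch of the observable behaviour, reusing the Port and Domain test types and the expected strings from marshal_test.go below (the helper name is hypothetical):

	func demoCommentMarshalling() { // hypothetical helper, illustration only
		buf := bytes.NewBuffer(nil)
		Marshal(buf, &Port{Number: "443", Comment: "https"})
		// buf.String() == `<port><!--https-->443</port>`

		buf.Reset()
		Marshal(buf, &Port{Number: "443", Comment: "add space-"})
		// The trailing '-' is padded: `<port><!--add space- -->443</port>`

		err := Marshal(bytes.NewBuffer(nil), &Domain{Comment: []byte("f--bar")})
		_ = err // err.Error() == `xml: comments must not contain "--"`
	}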
type parentStack struct { type parentStack struct {
*printer *printer
stack []string stack []string


@ -25,10 +25,10 @@ type Passenger struct {
} }
type Ship struct { type Ship struct {
XMLName Name `xml:"spaceship"` XMLName struct{} `xml:"spaceship"`
Name string `xml:"attr"` Name string `xml:"name,attr"`
Pilot string `xml:"attr"` Pilot string `xml:"pilot,attr"`
Drive DriveType `xml:"drive"` Drive DriveType `xml:"drive"`
Age uint `xml:"age"` Age uint `xml:"age"`
Passenger []*Passenger `xml:"passenger"` Passenger []*Passenger `xml:"passenger"`
@ -44,48 +44,50 @@ func (rx RawXML) MarshalXML() ([]byte, error) {
type NamedType string type NamedType string
type Port struct { type Port struct {
XMLName Name `xml:"port"` XMLName struct{} `xml:"port"`
Type string `xml:"attr"` Type string `xml:"type,attr"`
Number string `xml:"chardata"` Comment string `xml:",comment"`
Number string `xml:",chardata"`
} }
type Domain struct { type Domain struct {
XMLName Name `xml:"domain"` XMLName struct{} `xml:"domain"`
Country string `xml:"attr"` Country string `xml:",attr"`
Name []byte `xml:"chardata"` Name []byte `xml:",chardata"`
Comment []byte `xml:",comment"`
} }
type Book struct { type Book struct {
XMLName Name `xml:"book"` XMLName struct{} `xml:"book"`
Title string `xml:"chardata"` Title string `xml:",chardata"`
} }
type SecretAgent struct { type SecretAgent struct {
XMLName Name `xml:"agent"` XMLName struct{} `xml:"agent"`
Handle string `xml:"attr"` Handle string `xml:"handle,attr"`
Identity string Identity string
Obfuscate string `xml:"innerxml"` Obfuscate string `xml:",innerxml"`
} }
type NestedItems struct { type NestedItems struct {
XMLName Name `xml:"result"` XMLName struct{} `xml:"result"`
Items []string `xml:">item"` Items []string `xml:">item"`
Item1 []string `xml:"Items>item1"` Item1 []string `xml:"Items>item1"`
} }
type NestedOrder struct { type NestedOrder struct {
XMLName Name `xml:"result"` XMLName struct{} `xml:"result"`
Field1 string `xml:"parent>c"` Field1 string `xml:"parent>c"`
Field2 string `xml:"parent>b"` Field2 string `xml:"parent>b"`
Field3 string `xml:"parent>a"` Field3 string `xml:"parent>a"`
} }
type MixedNested struct { type MixedNested struct {
XMLName Name `xml:"result"` XMLName struct{} `xml:"result"`
A string `xml:"parent1>a"` A string `xml:"parent1>a"`
B string `xml:"b"` B string `xml:"b"`
C string `xml:"parent1>parent2>c"` C string `xml:"parent1>parent2>c"`
D string `xml:"parent1>d"` D string `xml:"parent1>d"`
} }
type NilTest struct { type NilTest struct {
@ -95,62 +97,165 @@ type NilTest struct {
} }
type Service struct { type Service struct {
XMLName Name `xml:"service"` XMLName struct{} `xml:"service"`
Domain *Domain `xml:"host>domain"` Domain *Domain `xml:"host>domain"`
Port *Port `xml:"host>port"` Port *Port `xml:"host>port"`
Extra1 interface{} Extra1 interface{}
Extra2 interface{} `xml:"host>extra2"` Extra2 interface{} `xml:"host>extra2"`
} }
var nilStruct *Ship var nilStruct *Ship
type EmbedA struct {
EmbedC
EmbedB EmbedB
FieldA string
}
type EmbedB struct {
FieldB string
EmbedC
}
type EmbedC struct {
FieldA1 string `xml:"FieldA>A1"`
FieldA2 string `xml:"FieldA>A2"`
FieldB string
FieldC string
}
type NameCasing struct {
XMLName struct{} `xml:"casing"`
Xy string
XY string
XyA string `xml:"Xy,attr"`
XYA string `xml:"XY,attr"`
}
type NamePrecedence struct {
XMLName Name `xml:"Parent"`
FromTag XMLNameWithoutTag `xml:"InTag"`
FromNameVal XMLNameWithoutTag
FromNameTag XMLNameWithTag
InFieldName string
}
type XMLNameWithTag struct {
XMLName Name `xml:"InXMLNameTag"`
Value string ",chardata"
}
type XMLNameWithoutTag struct {
XMLName Name
Value string ",chardata"
}
type AttrTest struct {
Int int `xml:",attr"`
Lower int `xml:"int,attr"`
Float float64 `xml:",attr"`
Uint8 uint8 `xml:",attr"`
Bool bool `xml:",attr"`
Str string `xml:",attr"`
}
type AnyTest struct {
XMLName struct{} `xml:"a"`
Nested string `xml:"nested>value"`
AnyField AnyHolder `xml:",any"`
}
type AnyHolder struct {
XMLName Name
XML string `xml:",innerxml"`
}
type RecurseA struct {
A string
B *RecurseB
}
type RecurseB struct {
A *RecurseA
B string
}
type Plain struct {
V interface{}
}
// Unless explicitly stated as such (or *Plain), all of the
// tests below are two-way tests. When introducing new tests,
// please try to make them two-way as well to ensure that
// marshalling and unmarshalling are as symmetrical as feasible.
var marshalTests = []struct { var marshalTests = []struct {
Value interface{} Value interface{}
ExpectXML string ExpectXML string
MarshalOnly bool
UnmarshalOnly bool
}{ }{
// Test nil marshals to nothing // Test nil marshals to nothing
{Value: nil, ExpectXML: ``}, {Value: nil, ExpectXML: ``, MarshalOnly: true},
{Value: nilStruct, ExpectXML: ``}, {Value: nilStruct, ExpectXML: ``, MarshalOnly: true},
// Test value types (no tag name, so ???) // Test value types
{Value: true, ExpectXML: `<???>true</???>`}, {Value: &Plain{true}, ExpectXML: `<Plain><V>true</V></Plain>`},
{Value: int(42), ExpectXML: `<???>42</???>`}, {Value: &Plain{false}, ExpectXML: `<Plain><V>false</V></Plain>`},
{Value: int8(42), ExpectXML: `<???>42</???>`}, {Value: &Plain{int(42)}, ExpectXML: `<Plain><V>42</V></Plain>`},
{Value: int16(42), ExpectXML: `<???>42</???>`}, {Value: &Plain{int8(42)}, ExpectXML: `<Plain><V>42</V></Plain>`},
{Value: int32(42), ExpectXML: `<???>42</???>`}, {Value: &Plain{int16(42)}, ExpectXML: `<Plain><V>42</V></Plain>`},
{Value: uint(42), ExpectXML: `<???>42</???>`}, {Value: &Plain{int32(42)}, ExpectXML: `<Plain><V>42</V></Plain>`},
{Value: uint8(42), ExpectXML: `<???>42</???>`}, {Value: &Plain{uint(42)}, ExpectXML: `<Plain><V>42</V></Plain>`},
{Value: uint16(42), ExpectXML: `<???>42</???>`}, {Value: &Plain{uint8(42)}, ExpectXML: `<Plain><V>42</V></Plain>`},
{Value: uint32(42), ExpectXML: `<???>42</???>`}, {Value: &Plain{uint16(42)}, ExpectXML: `<Plain><V>42</V></Plain>`},
{Value: float32(1.25), ExpectXML: `<???>1.25</???>`}, {Value: &Plain{uint32(42)}, ExpectXML: `<Plain><V>42</V></Plain>`},
{Value: float64(1.25), ExpectXML: `<???>1.25</???>`}, {Value: &Plain{float32(1.25)}, ExpectXML: `<Plain><V>1.25</V></Plain>`},
{Value: uintptr(0xFFDD), ExpectXML: `<???>65501</???>`}, {Value: &Plain{float64(1.25)}, ExpectXML: `<Plain><V>1.25</V></Plain>`},
{Value: "gopher", ExpectXML: `<???>gopher</???>`}, {Value: &Plain{uintptr(0xFFDD)}, ExpectXML: `<Plain><V>65501</V></Plain>`},
{Value: []byte("gopher"), ExpectXML: `<???>gopher</???>`}, {Value: &Plain{"gopher"}, ExpectXML: `<Plain><V>gopher</V></Plain>`},
{Value: "</>", ExpectXML: `<???>&lt;/&gt;</???>`}, {Value: &Plain{[]byte("gopher")}, ExpectXML: `<Plain><V>gopher</V></Plain>`},
{Value: []byte("</>"), ExpectXML: `<???>&lt;/&gt;</???>`}, {Value: &Plain{"</>"}, ExpectXML: `<Plain><V>&lt;/&gt;</V></Plain>`},
{Value: [3]byte{'<', '/', '>'}, ExpectXML: `<???>&lt;/&gt;</???>`}, {Value: &Plain{[]byte("</>")}, ExpectXML: `<Plain><V>&lt;/&gt;</V></Plain>`},
{Value: NamedType("potato"), ExpectXML: `<???>potato</???>`}, {Value: &Plain{[3]byte{'<', '/', '>'}}, ExpectXML: `<Plain><V>&lt;/&gt;</V></Plain>`},
{Value: []int{1, 2, 3}, ExpectXML: `<???>1</???><???>2</???><???>3</???>`}, {Value: &Plain{NamedType("potato")}, ExpectXML: `<Plain><V>potato</V></Plain>`},
{Value: [3]int{1, 2, 3}, ExpectXML: `<???>1</???><???>2</???><???>3</???>`}, {Value: &Plain{[]int{1, 2, 3}}, ExpectXML: `<Plain><V>1</V><V>2</V><V>3</V></Plain>`},
{Value: &Plain{[3]int{1, 2, 3}}, ExpectXML: `<Plain><V>1</V><V>2</V><V>3</V></Plain>`},
// Test innerxml // Test innerxml
{Value: RawXML("</>"), ExpectXML: `</>`},
{ {
Value: &SecretAgent{ Value: &SecretAgent{
Handle: "007", Handle: "007",
Identity: "James Bond", Identity: "James Bond",
Obfuscate: "<redacted/>", Obfuscate: "<redacted/>",
}, },
//ExpectXML: `<agent handle="007"><redacted/></agent>`, ExpectXML: `<agent handle="007"><Identity>James Bond</Identity><redacted/></agent>`,
ExpectXML: `<agent handle="007"><Identity>James Bond</Identity><redacted/></agent>`, MarshalOnly: true,
},
{
Value: &SecretAgent{
Handle: "007",
Identity: "James Bond",
Obfuscate: "<Identity>James Bond</Identity><redacted/>",
},
ExpectXML: `<agent handle="007"><Identity>James Bond</Identity><redacted/></agent>`,
UnmarshalOnly: true,
},
// Test marshaller interface
{
Value: RawXML("</>"),
ExpectXML: `</>`,
MarshalOnly: true,
}, },
// Test structs // Test structs
{Value: &Port{Type: "ssl", Number: "443"}, ExpectXML: `<port type="ssl">443</port>`}, {Value: &Port{Type: "ssl", Number: "443"}, ExpectXML: `<port type="ssl">443</port>`},
{Value: &Port{Number: "443"}, ExpectXML: `<port>443</port>`}, {Value: &Port{Number: "443"}, ExpectXML: `<port>443</port>`},
{Value: &Port{Type: "<unix>"}, ExpectXML: `<port type="&lt;unix&gt;"></port>`}, {Value: &Port{Type: "<unix>"}, ExpectXML: `<port type="&lt;unix&gt;"></port>`},
{Value: &Port{Number: "443", Comment: "https"}, ExpectXML: `<port><!--https-->443</port>`},
{Value: &Port{Number: "443", Comment: "add space-"}, ExpectXML: `<port><!--add space- -->443</port>`, MarshalOnly: true},
{Value: &Domain{Name: []byte("google.com&friends")}, ExpectXML: `<domain>google.com&amp;friends</domain>`}, {Value: &Domain{Name: []byte("google.com&friends")}, ExpectXML: `<domain>google.com&amp;friends</domain>`},
{Value: &Domain{Name: []byte("google.com"), Comment: []byte(" &friends ")}, ExpectXML: `<domain>google.com<!-- &friends --></domain>`},
{Value: &Book{Title: "Pride & Prejudice"}, ExpectXML: `<book>Pride &amp; Prejudice</book>`}, {Value: &Book{Title: "Pride & Prejudice"}, ExpectXML: `<book>Pride &amp; Prejudice</book>`},
{Value: atomValue, ExpectXML: atomXml}, {Value: atomValue, ExpectXML: atomXml},
{ {
@ -203,16 +308,25 @@ var marshalTests = []struct {
`</passenger>` + `</passenger>` +
`</spaceship>`, `</spaceship>`,
}, },
// Test a>b // Test a>b
{ {
Value: NestedItems{Items: []string{}, Item1: []string{}}, Value: &NestedItems{Items: nil, Item1: nil},
ExpectXML: `<result>` + ExpectXML: `<result>` +
`<Items>` + `<Items>` +
`</Items>` + `</Items>` +
`</result>`, `</result>`,
}, },
{ {
Value: NestedItems{Items: []string{}, Item1: []string{"A"}}, Value: &NestedItems{Items: []string{}, Item1: []string{}},
ExpectXML: `<result>` +
`<Items>` +
`</Items>` +
`</result>`,
MarshalOnly: true,
},
{
Value: &NestedItems{Items: nil, Item1: []string{"A"}},
ExpectXML: `<result>` + ExpectXML: `<result>` +
`<Items>` + `<Items>` +
`<item1>A</item1>` + `<item1>A</item1>` +
@ -220,7 +334,7 @@ var marshalTests = []struct {
`</result>`, `</result>`,
}, },
{ {
Value: NestedItems{Items: []string{"A", "B"}, Item1: []string{}}, Value: &NestedItems{Items: []string{"A", "B"}, Item1: nil},
ExpectXML: `<result>` + ExpectXML: `<result>` +
`<Items>` + `<Items>` +
`<item>A</item>` + `<item>A</item>` +
@ -229,7 +343,7 @@ var marshalTests = []struct {
`</result>`, `</result>`,
}, },
{ {
Value: NestedItems{Items: []string{"A", "B"}, Item1: []string{"C"}}, Value: &NestedItems{Items: []string{"A", "B"}, Item1: []string{"C"}},
ExpectXML: `<result>` + ExpectXML: `<result>` +
`<Items>` + `<Items>` +
`<item>A</item>` + `<item>A</item>` +
@ -239,7 +353,7 @@ var marshalTests = []struct {
`</result>`, `</result>`,
}, },
{ {
Value: NestedOrder{Field1: "C", Field2: "B", Field3: "A"}, Value: &NestedOrder{Field1: "C", Field2: "B", Field3: "A"},
ExpectXML: `<result>` + ExpectXML: `<result>` +
`<parent>` + `<parent>` +
`<c>C</c>` + `<c>C</c>` +
@ -249,16 +363,17 @@ var marshalTests = []struct {
`</result>`, `</result>`,
}, },
{ {
Value: NilTest{A: "A", B: nil, C: "C"}, Value: &NilTest{A: "A", B: nil, C: "C"},
ExpectXML: `<???>` + ExpectXML: `<NilTest>` +
`<parent1>` + `<parent1>` +
`<parent2><a>A</a></parent2>` + `<parent2><a>A</a></parent2>` +
`<parent2><c>C</c></parent2>` + `<parent2><c>C</c></parent2>` +
`</parent1>` + `</parent1>` +
`</???>`, `</NilTest>`,
MarshalOnly: true, // Uses interface{}
}, },
{ {
Value: MixedNested{A: "A", B: "B", C: "C", D: "D"}, Value: &MixedNested{A: "A", B: "B", C: "C", D: "D"},
ExpectXML: `<result>` + ExpectXML: `<result>` +
`<parent1><a>A</a></parent1>` + `<parent1><a>A</a></parent1>` +
`<b>B</b>` + `<b>B</b>` +
@ -269,32 +384,154 @@ var marshalTests = []struct {
`</result>`, `</result>`,
}, },
{ {
Value: Service{Port: &Port{Number: "80"}}, Value: &Service{Port: &Port{Number: "80"}},
ExpectXML: `<service><host><port>80</port></host></service>`, ExpectXML: `<service><host><port>80</port></host></service>`,
}, },
{ {
Value: Service{}, Value: &Service{},
ExpectXML: `<service></service>`, ExpectXML: `<service></service>`,
}, },
{ {
Value: Service{Port: &Port{Number: "80"}, Extra1: "A", Extra2: "B"}, Value: &Service{Port: &Port{Number: "80"}, Extra1: "A", Extra2: "B"},
ExpectXML: `<service>` + ExpectXML: `<service>` +
`<host><port>80</port></host>` + `<host><port>80</port></host>` +
`<Extra1>A</Extra1>` + `<Extra1>A</Extra1>` +
`<host><extra2>B</extra2></host>` + `<host><extra2>B</extra2></host>` +
`</service>`, `</service>`,
MarshalOnly: true,
}, },
{ {
Value: Service{Port: &Port{Number: "80"}, Extra2: "example"}, Value: &Service{Port: &Port{Number: "80"}, Extra2: "example"},
ExpectXML: `<service>` + ExpectXML: `<service>` +
`<host><port>80</port></host>` + `<host><port>80</port></host>` +
`<host><extra2>example</extra2></host>` + `<host><extra2>example</extra2></host>` +
`</service>`, `</service>`,
MarshalOnly: true,
},
// Test struct embedding
{
Value: &EmbedA{
EmbedC: EmbedC{
FieldA1: "", // Shadowed by A.A
FieldA2: "", // Shadowed by A.A
FieldB: "A.C.B",
FieldC: "A.C.C",
},
EmbedB: EmbedB{
FieldB: "A.B.B",
EmbedC: EmbedC{
FieldA1: "A.B.C.A1",
FieldA2: "A.B.C.A2",
FieldB: "", // Shadowed by A.B.B
FieldC: "A.B.C.C",
},
},
FieldA: "A.A",
},
ExpectXML: `<EmbedA>` +
`<FieldB>A.C.B</FieldB>` +
`<FieldC>A.C.C</FieldC>` +
`<EmbedB>` +
`<FieldB>A.B.B</FieldB>` +
`<FieldA>` +
`<A1>A.B.C.A1</A1>` +
`<A2>A.B.C.A2</A2>` +
`</FieldA>` +
`<FieldC>A.B.C.C</FieldC>` +
`</EmbedB>` +
`<FieldA>A.A</FieldA>` +
`</EmbedA>`,
},
// Test that name casing matters
{
Value: &NameCasing{Xy: "mixed", XY: "upper", XyA: "mixedA", XYA: "upperA"},
ExpectXML: `<casing Xy="mixedA" XY="upperA"><Xy>mixed</Xy><XY>upper</XY></casing>`,
},
// Test the order in which the XML element name is chosen
{
Value: &NamePrecedence{
FromTag: XMLNameWithoutTag{Value: "A"},
FromNameVal: XMLNameWithoutTag{XMLName: Name{Local: "InXMLName"}, Value: "B"},
FromNameTag: XMLNameWithTag{Value: "C"},
InFieldName: "D",
},
ExpectXML: `<Parent>` +
`<InTag><Value>A</Value></InTag>` +
`<InXMLName><Value>B</Value></InXMLName>` +
`<InXMLNameTag><Value>C</Value></InXMLNameTag>` +
`<InFieldName>D</InFieldName>` +
`</Parent>`,
MarshalOnly: true,
},
{
Value: &NamePrecedence{
XMLName: Name{Local: "Parent"},
FromTag: XMLNameWithoutTag{XMLName: Name{Local: "InTag"}, Value: "A"},
FromNameVal: XMLNameWithoutTag{XMLName: Name{Local: "FromNameVal"}, Value: "B"},
FromNameTag: XMLNameWithTag{XMLName: Name{Local: "InXMLNameTag"}, Value: "C"},
InFieldName: "D",
},
ExpectXML: `<Parent>` +
`<InTag><Value>A</Value></InTag>` +
`<FromNameVal><Value>B</Value></FromNameVal>` +
`<InXMLNameTag><Value>C</Value></InXMLNameTag>` +
`<InFieldName>D</InFieldName>` +
`</Parent>`,
UnmarshalOnly: true,
},
// Test attributes
{
Value: &AttrTest{
Int: 8,
Lower: 9,
Float: 23.5,
Uint8: 255,
Bool: true,
Str: "s",
},
ExpectXML: `<AttrTest Int="8" int="9" Float="23.5" Uint8="255" Bool="true" Str="s"></AttrTest>`,
},
// Test ",any"
{
ExpectXML: `<a><nested><value>known</value></nested><other><sub>unknown</sub></other></a>`,
Value: &AnyTest{
Nested: "known",
AnyField: AnyHolder{
XMLName: Name{Local: "other"},
XML: "<sub>unknown</sub>",
},
},
UnmarshalOnly: true,
},
{
Value: &AnyTest{Nested: "known", AnyField: AnyHolder{XML: "<unknown/>"}},
ExpectXML: `<a><nested><value>known</value></nested></a>`,
MarshalOnly: true,
},
// Test recursive types.
{
Value: &RecurseA{
A: "a1",
B: &RecurseB{
A: &RecurseA{"a2", nil},
B: "b1",
},
},
ExpectXML: `<RecurseA><A>a1</A><B><A><A>a2</A></A><B>b1</B></B></RecurseA>`,
}, },
} }
func TestMarshal(t *testing.T) { func TestMarshal(t *testing.T) {
for idx, test := range marshalTests { for idx, test := range marshalTests {
if test.UnmarshalOnly {
continue
}
buf := bytes.NewBuffer(nil) buf := bytes.NewBuffer(nil)
err := Marshal(buf, test.Value) err := Marshal(buf, test.Value)
if err != nil { if err != nil {
@ -303,9 +540,9 @@ func TestMarshal(t *testing.T) {
} }
if got, want := buf.String(), test.ExpectXML; got != want { if got, want := buf.String(), test.ExpectXML; got != want {
if strings.Contains(want, "\n") { if strings.Contains(want, "\n") {
t.Errorf("#%d: marshal(%#v) - GOT:\n%s\nWANT:\n%s", idx, test.Value, got, want) t.Errorf("#%d: marshal(%#v):\nHAVE:\n%s\nWANT:\n%s", idx, test.Value, got, want)
} else { } else {
t.Errorf("#%d: marshal(%#v) = %#q want %#q", idx, test.Value, got, want) t.Errorf("#%d: marshal(%#v):\nhave %#q\nwant %#q", idx, test.Value, got, want)
} }
} }
} }
@ -334,6 +571,10 @@ var marshalErrorTests = []struct {
Err: "xml: unsupported type: map[*xml.Ship]bool", Err: "xml: unsupported type: map[*xml.Ship]bool",
Kind: reflect.Map, Kind: reflect.Map,
}, },
{
Value: &Domain{Comment: []byte("f--bar")},
Err: `xml: comments must not contain "--"`,
},
} }
func TestMarshalErrors(t *testing.T) { func TestMarshalErrors(t *testing.T) {
@ -341,10 +582,12 @@ func TestMarshalErrors(t *testing.T) {
buf := bytes.NewBuffer(nil) buf := bytes.NewBuffer(nil)
err := Marshal(buf, test.Value) err := Marshal(buf, test.Value)
if err == nil || err.Error() != test.Err { if err == nil || err.Error() != test.Err {
t.Errorf("#%d: marshal(%#v) = [error] %q, want %q", idx, test.Value, err, test.Err) t.Errorf("#%d: marshal(%#v) = [error] %v, want %v", idx, test.Value, err, test.Err)
} }
if kind := err.(*UnsupportedTypeError).Type.Kind(); kind != test.Kind { if test.Kind != reflect.Invalid {
t.Errorf("#%d: marshal(%#v) = [error kind] %s, want %s", idx, test.Value, kind, test.Kind) if kind := err.(*UnsupportedTypeError).Type.Kind(); kind != test.Kind {
t.Errorf("#%d: marshal(%#v) = [error kind] %s, want %s", idx, test.Value, kind, test.Kind)
}
} }
} }
} }
@ -352,39 +595,20 @@ func TestMarshalErrors(t *testing.T) {
// Do invertibility testing on the various structures that we test // Do invertibility testing on the various structures that we test
func TestUnmarshal(t *testing.T) { func TestUnmarshal(t *testing.T) {
for i, test := range marshalTests { for i, test := range marshalTests {
// Skip the nil pointers if test.MarshalOnly {
if i <= 1 { continue
continue }
} if _, ok := test.Value.(*Plain); ok {
var dest interface{}
switch test.Value.(type) {
case *Ship, Ship:
dest = &Ship{}
case *Port, Port:
dest = &Port{}
case *Domain, Domain:
dest = &Domain{}
case *Feed, Feed:
dest = &Feed{}
default:
continue continue
} }
vt := reflect.TypeOf(test.Value)
dest := reflect.New(vt.Elem()).Interface()
buffer := bytes.NewBufferString(test.ExpectXML) buffer := bytes.NewBufferString(test.ExpectXML)
err := Unmarshal(buffer, dest) err := Unmarshal(buffer, dest)
// Don't compare XMLNames
switch fix := dest.(type) { switch fix := dest.(type) {
case *Ship:
fix.XMLName = Name{}
case *Port:
fix.XMLName = Name{}
case *Domain:
fix.XMLName = Name{}
case *Feed: case *Feed:
fix.XMLName = Name{}
fix.Author.InnerXML = "" fix.Author.InnerXML = ""
for i := range fix.Entry { for i := range fix.Entry {
fix.Entry[i].Author.InnerXML = "" fix.Entry[i].Author.InnerXML = ""
@ -394,30 +618,23 @@ func TestUnmarshal(t *testing.T) {
if err != nil { if err != nil {
t.Errorf("#%d: unexpected error: %#v", i, err) t.Errorf("#%d: unexpected error: %#v", i, err)
} else if got, want := dest, test.Value; !reflect.DeepEqual(got, want) { } else if got, want := dest, test.Value; !reflect.DeepEqual(got, want) {
t.Errorf("#%d: unmarshal(%q) = %#v, want %#v", i, test.ExpectXML, got, want) t.Errorf("#%d: unmarshal(%q):\nhave %#v\nwant %#v", i, test.ExpectXML, got, want)
} }
} }
} }
func BenchmarkMarshal(b *testing.B) { func BenchmarkMarshal(b *testing.B) {
idx := len(marshalTests) - 1
test := marshalTests[idx]
buf := bytes.NewBuffer(nil) buf := bytes.NewBuffer(nil)
for i := 0; i < b.N; i++ { for i := 0; i < b.N; i++ {
Marshal(buf, test.Value) Marshal(buf, atomValue)
buf.Truncate(0) buf.Truncate(0)
} }
} }
func BenchmarkUnmarshal(b *testing.B) { func BenchmarkUnmarshal(b *testing.B) {
idx := len(marshalTests) - 1 xml := []byte(atomXml)
test := marshalTests[idx]
sm := &Ship{}
xml := []byte(test.ExpectXML)
for i := 0; i < b.N; i++ { for i := 0; i < b.N; i++ {
buffer := bytes.NewBuffer(xml) buffer := bytes.NewBuffer(xml)
Unmarshal(buffer, sm) Unmarshal(buffer, &Feed{})
} }
} }


@ -7,13 +7,10 @@ package xml
import ( import (
"bytes" "bytes"
"errors" "errors"
"fmt"
"io" "io"
"reflect" "reflect"
"strconv" "strconv"
"strings" "strings"
"unicode"
"unicode/utf8"
) )
// BUG(rsc): Mapping between XML elements and data structures is inherently flawed: // BUG(rsc): Mapping between XML elements and data structures is inherently flawed:
@ -31,7 +28,7 @@ import (
// For example, given these definitions: // For example, given these definitions:
// //
// type Email struct { // type Email struct {
// Where string `xml:"attr"` // Where string `xml:",attr"`
// Addr string // Addr string
// } // }
// //
@ -64,7 +61,8 @@ import (
// //
// via Unmarshal(r, &result) is equivalent to assigning // via Unmarshal(r, &result) is equivalent to assigning
// //
// r = Result{xml.Name{"", "result"}, // r = Result{
// xml.Name{Local: "result"},
// "Grace R. Emlin", // name // "Grace R. Emlin", // name
// "phone", // no phone given // "phone", // no phone given
// []Email{ // []Email{
@ -87,9 +85,9 @@ import (
// In the rules, the tag of a field refers to the value associated with the // In the rules, the tag of a field refers to the value associated with the
// key 'xml' in the struct field's tag (see the example above). // key 'xml' in the struct field's tag (see the example above).
// //
// * If the struct has a field of type []byte or string with tag "innerxml", // * If the struct has a field of type []byte or string with tag
// Unmarshal accumulates the raw XML nested inside the element // ",innerxml", Unmarshal accumulates the raw XML nested inside the
// in that field. The rest of the rules still apply. // element in that field. The rest of the rules still apply.
// //
// * If the struct has a field named XMLName of type xml.Name, // * If the struct has a field named XMLName of type xml.Name,
// Unmarshal records the element name in that field. // Unmarshal records the element name in that field.
@ -100,8 +98,9 @@ import (
// returns an error. // returns an error.
// //
// * If the XML element has an attribute whose name matches a // * If the XML element has an attribute whose name matches a
// struct field of type string with tag "attr", Unmarshal records // struct field name with an associated tag containing ",attr" or
// the attribute value in that field. // the explicit name in a struct field tag of the form "name,attr",
// Unmarshal records the attribute value in that field.
// //
// * If the XML element contains character data, that data is // * If the XML element contains character data, that data is
// accumulated in the first struct field that has tag "chardata". // accumulated in the first struct field that has tag "chardata".
@ -109,23 +108,30 @@ import (
// If there is no such field, the character data is discarded. // If there is no such field, the character data is discarded.
// //
// * If the XML element contains comments, they are accumulated in // * If the XML element contains comments, they are accumulated in
// the first struct field that has tag "comments". The struct // the first struct field that has tag ",comments". The struct
// field may have type []byte or string. If there is no such // field may have type []byte or string. If there is no such
// field, the comments are discarded. // field, the comments are discarded.
// //
// * If the XML element contains a sub-element whose name matches // * If the XML element contains a sub-element whose name matches
// the prefix of a tag formatted as "a>b>c", unmarshal // the prefix of a tag formatted as "a" or "a>b>c", unmarshal
// will descend into the XML structure looking for elements with the // will descend into the XML structure looking for elements with the
// given names, and will map the innermost elements to that struct field. // given names, and will map the innermost elements to that struct
// A tag starting with ">" is equivalent to one starting // field. A tag starting with ">" is equivalent to one starting
// with the field name followed by ">". // with the field name followed by ">".
// //
// * If the XML element contains a sub-element whose name // * If the XML element contains a sub-element whose name matches
// matches a field whose tag is neither "attr" nor "chardata", // a struct field's XMLName tag and the struct field has no
// Unmarshal maps the sub-element to that struct field. // explicit name tag as per the previous rule, unmarshal maps
// Otherwise, if the struct has a field named Any, unmarshal // the sub-element to that struct field.
//
// * If the XML element contains a sub-element whose name matches a
// field without any mode flags (",attr", ",chardata", etc), Unmarshal
// maps the sub-element to that struct field. // maps the sub-element to that struct field.
// //
// * If the XML element contains a sub-element that hasn't matched any
// of the above rules and the struct has a field with tag ",any",
// unmarshal maps the sub-element to that struct field.
//
// Unmarshal maps an XML element to a string or []byte by saving the // Unmarshal maps an XML element to a string or []byte by saving the
// concatenation of that element's character data in the string or // concatenation of that element's character data in the string or
// []byte. // []byte.
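Taken together, the rules above can be exercised from a single struct. The type below is purely illustrative and does not appear in the patch (AnyHolder is the helper type added to marshal_test.go in this same change); it shows one field per tag form, written as it would sit inside the package's own tests:

	type inboxMessage struct { // hypothetical, illustration only
		XMLName Name      `xml:"message"`        // element must be <message>
		From    string    `xml:"from,attr"`      // explicit attribute name
		Subject string    `xml:"header>subject"` // nested a>b element path
		Body    string    `xml:",chardata"`      // character data of <message>
		Raw     string    `xml:",innerxml"`      // verbatim inner XML
		Extra   AnyHolder `xml:",any"`           // mop-up for unmatched sub-elements
	}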
@ -169,18 +175,6 @@ type UnmarshalError string
func (e UnmarshalError) Error() string { return string(e) } func (e UnmarshalError) Error() string { return string(e) }
// A TagPathError represents an error in the unmarshalling process
// caused by the use of field tags with conflicting paths.
type TagPathError struct {
Struct reflect.Type
Field1, Tag1 string
Field2, Tag2 string
}
func (e *TagPathError) Error() string {
return fmt.Sprintf("%s field %q with tag %q conflicts with field %q with tag %q", e.Struct, e.Field1, e.Tag1, e.Field2, e.Tag2)
}
// The Parser's Unmarshal method is like xml.Unmarshal // The Parser's Unmarshal method is like xml.Unmarshal
// except that it can be passed a pointer to the initial start element, // except that it can be passed a pointer to the initial start element,
// useful when a client reads some raw XML tokens itself // useful when a client reads some raw XML tokens itself
@ -195,26 +189,6 @@ func (p *Parser) Unmarshal(val interface{}, start *StartElement) error {
return p.unmarshal(v.Elem(), start) return p.unmarshal(v.Elem(), start)
} }
// fieldName strips invalid characters from an XML name
// to create a valid Go struct name. It also converts the
// name to lower case letters.
func fieldName(original string) string {
var i int
//remove leading underscores, without exhausting all characters
for i = 0; i < len(original)-1 && original[i] == '_'; i++ {
}
return strings.Map(
func(x rune) rune {
if x == '_' || unicode.IsDigit(x) || unicode.IsLetter(x) {
return unicode.ToLower(x)
}
return -1
},
original[i:])
}
// Unmarshal a single XML element into val. // Unmarshal a single XML element into val.
func (p *Parser) unmarshal(val reflect.Value, start *StartElement) error { func (p *Parser) unmarshal(val reflect.Value, start *StartElement) error {
// Find start element if we need it. // Find start element if we need it.
@ -246,15 +220,22 @@ func (p *Parser) unmarshal(val reflect.Value, start *StartElement) error {
saveXML reflect.Value saveXML reflect.Value
saveXMLIndex int saveXMLIndex int
saveXMLData []byte saveXMLData []byte
saveAny reflect.Value
sv reflect.Value sv reflect.Value
styp reflect.Type tinfo *typeInfo
fieldPaths map[string]pathInfo err error
) )
switch v := val; v.Kind() { switch v := val; v.Kind() {
default: default:
return errors.New("unknown type " + v.Type().String()) return errors.New("unknown type " + v.Type().String())
case reflect.Interface:
// TODO: For now, simply ignore the field. In the near
// future we may choose to unmarshal the start
// element on it, if not nil.
return p.Skip()
case reflect.Slice: case reflect.Slice:
typ := v.Type() typ := v.Type()
if typ.Elem().Kind() == reflect.Uint8 { if typ.Elem().Kind() == reflect.Uint8 {
@ -288,75 +269,69 @@ func (p *Parser) unmarshal(val reflect.Value, start *StartElement) error {
saveData = v saveData = v
case reflect.Struct: case reflect.Struct:
if _, ok := v.Interface().(Name); ok {
v.Set(reflect.ValueOf(start.Name))
break
}
sv = v sv = v
typ := sv.Type() typ := sv.Type()
styp = typ tinfo, err = getTypeInfo(typ)
// Assign name. if err != nil {
if f, ok := typ.FieldByName("XMLName"); ok { return err
// Validate element name. }
if tag := f.Tag.Get("xml"); tag != "" {
ns := ""
i := strings.LastIndex(tag, " ")
if i >= 0 {
ns, tag = tag[0:i], tag[i+1:]
}
if tag != start.Name.Local {
return UnmarshalError("expected element type <" + tag + "> but have <" + start.Name.Local + ">")
}
if ns != "" && ns != start.Name.Space {
e := "expected element <" + tag + "> in name space " + ns + " but have "
if start.Name.Space == "" {
e += "no name space"
} else {
e += start.Name.Space
}
return UnmarshalError(e)
}
}
// Save // Validate and assign element name.
v := sv.FieldByIndex(f.Index) if tinfo.xmlname != nil {
if _, ok := v.Interface().(Name); ok { finfo := tinfo.xmlname
v.Set(reflect.ValueOf(start.Name)) if finfo.name != "" && finfo.name != start.Name.Local {
return UnmarshalError("expected element type <" + finfo.name + "> but have <" + start.Name.Local + ">")
}
if finfo.xmlns != "" && finfo.xmlns != start.Name.Space {
e := "expected element <" + finfo.name + "> in name space " + finfo.xmlns + " but have "
if start.Name.Space == "" {
e += "no name space"
} else {
e += start.Name.Space
}
return UnmarshalError(e)
}
fv := sv.FieldByIndex(finfo.idx)
if _, ok := fv.Interface().(Name); ok {
fv.Set(reflect.ValueOf(start.Name))
} }
} }
// Assign attributes. // Assign attributes.
// Also, determine whether we need to save character data or comments. // Also, determine whether we need to save character data or comments.
for i, n := 0, typ.NumField(); i < n; i++ { for i := range tinfo.fields {
f := typ.Field(i) finfo := &tinfo.fields[i]
switch f.Tag.Get("xml") { switch finfo.flags & fMode {
case "attr": case fAttr:
strv := sv.FieldByIndex(f.Index) strv := sv.FieldByIndex(finfo.idx)
// Look for attribute. // Look for attribute.
val := "" val := ""
k := strings.ToLower(f.Name)
for _, a := range start.Attr { for _, a := range start.Attr {
if fieldName(a.Name.Local) == k { if a.Name.Local == finfo.name {
val = a.Value val = a.Value
break break
} }
} }
copyValue(strv, []byte(val)) copyValue(strv, []byte(val))
case "comment": case fCharData:
if !saveComment.IsValid() {
saveComment = sv.FieldByIndex(f.Index)
}
case "chardata":
if !saveData.IsValid() { if !saveData.IsValid() {
saveData = sv.FieldByIndex(f.Index) saveData = sv.FieldByIndex(finfo.idx)
} }
case "innerxml": case fComment:
if !saveComment.IsValid() {
saveComment = sv.FieldByIndex(finfo.idx)
}
case fAny:
if !saveAny.IsValid() {
saveAny = sv.FieldByIndex(finfo.idx)
}
case fInnerXml:
if !saveXML.IsValid() { if !saveXML.IsValid() {
saveXML = sv.FieldByIndex(f.Index) saveXML = sv.FieldByIndex(finfo.idx)
if p.saved == nil { if p.saved == nil {
saveXMLIndex = 0 saveXMLIndex = 0
p.saved = new(bytes.Buffer) p.saved = new(bytes.Buffer)
@ -364,24 +339,6 @@ func (p *Parser) unmarshal(val reflect.Value, start *StartElement) error {
saveXMLIndex = p.savedOffset() saveXMLIndex = p.savedOffset()
} }
} }
default:
if tag := f.Tag.Get("xml"); strings.Contains(tag, ">") {
if fieldPaths == nil {
fieldPaths = make(map[string]pathInfo)
}
path := strings.ToLower(tag)
if strings.HasPrefix(tag, ">") {
path = strings.ToLower(f.Name) + path
}
if strings.HasSuffix(tag, ">") {
path = path[:len(path)-1]
}
err := addFieldPath(sv, fieldPaths, path, f.Index)
if err != nil {
return err
}
}
} }
} }
} }
@ -400,44 +357,23 @@ Loop:
} }
switch t := tok.(type) { switch t := tok.(type) {
case StartElement: case StartElement:
// Sub-element. consumed := false
// Look up by tag name.
if sv.IsValid() { if sv.IsValid() {
k := fieldName(t.Name.Local) consumed, err = p.unmarshalPath(tinfo, sv, nil, &t)
if err != nil {
if fieldPaths != nil { return err
if _, found := fieldPaths[k]; found {
if err := p.unmarshalPaths(sv, fieldPaths, k, &t); err != nil {
return err
}
continue Loop
}
} }
if !consumed && saveAny.IsValid() {
match := func(s string) bool { consumed = true
// check if the name matches ignoring case if err := p.unmarshal(saveAny, &t); err != nil {
if strings.ToLower(s) != k {
return false
}
// now check that it's public
c, _ := utf8.DecodeRuneInString(s)
return unicode.IsUpper(c)
}
f, found := styp.FieldByNameFunc(match)
if !found { // fall back to mop-up field named "Any"
f, found = styp.FieldByName("Any")
}
if found {
if err := p.unmarshal(sv.FieldByIndex(f.Index), &t); err != nil {
return err return err
} }
continue Loop
} }
} }
// Not saving sub-element but still have to skip over it. if !consumed {
if err := p.Skip(); err != nil { if err := p.Skip(); err != nil {
return err return err
}
} }
case EndElement: case EndElement:
@ -503,10 +439,10 @@ func copyValue(dst reflect.Value, src []byte) (err error) {
return err == nil return err == nil
} }
// Save accumulated data and comments // Save accumulated data.
switch t := dst; t.Kind() { switch t := dst; t.Kind() {
case reflect.Invalid: case reflect.Invalid:
// Probably a comment, handled below // Probably a comment.
default: default:
return errors.New("cannot happen: unknown type " + t.Type().String()) return errors.New("cannot happen: unknown type " + t.Type().String())
case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64: case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
@ -538,70 +474,66 @@ func copyValue(dst reflect.Value, src []byte) (err error) {
return nil return nil
} }
type pathInfo struct { // unmarshalPath walks down an XML structure looking for wanted
fieldIdx []int // paths, and calls unmarshal on them.
complete bool // The consumed result tells whether XML elements have been consumed
} // from the Parser until start's matching end element, or if it's
// still untouched because start is uninteresting for sv's fields.
func (p *Parser) unmarshalPath(tinfo *typeInfo, sv reflect.Value, parents []string, start *StartElement) (consumed bool, err error) {
recurse := false
Loop:
for i := range tinfo.fields {
finfo := &tinfo.fields[i]
if finfo.flags&fElement == 0 || len(finfo.parents) < len(parents) {
continue
}
for j := range parents {
if parents[j] != finfo.parents[j] {
continue Loop
}
}
if len(finfo.parents) == len(parents) && finfo.name == start.Name.Local {
// It's a perfect match, unmarshal the field.
return true, p.unmarshal(sv.FieldByIndex(finfo.idx), start)
}
if len(finfo.parents) > len(parents) && finfo.parents[len(parents)] == start.Name.Local {
// It's a prefix for the field. Break and recurse
// since it's not ok for one field path to be itself
// the prefix for another field path.
recurse = true
// addFieldPath takes an element path such as "a>b>c" and fills the // We can reuse the same slice as long as we
// paths map with all paths leading to it ("a", "a>b", and "a>b>c"). // don't try to append to it.
// It is okay for paths to share a common, shorter prefix but not ok parents = finfo.parents[:len(parents)+1]
// for one path to itself be a prefix of another.
func addFieldPath(sv reflect.Value, paths map[string]pathInfo, path string, fieldIdx []int) error {
if info, found := paths[path]; found {
return tagError(sv, info.fieldIdx, fieldIdx)
}
paths[path] = pathInfo{fieldIdx, true}
for {
i := strings.LastIndex(path, ">")
if i < 0 {
break break
} }
path = path[:i]
if info, found := paths[path]; found {
if info.complete {
return tagError(sv, info.fieldIdx, fieldIdx)
}
} else {
paths[path] = pathInfo{fieldIdx, false}
}
} }
return nil if !recurse {
// We have no business with this element.
} return false, nil
func tagError(sv reflect.Value, idx1 []int, idx2 []int) error {
t := sv.Type()
f1 := t.FieldByIndex(idx1)
f2 := t.FieldByIndex(idx2)
return &TagPathError{t, f1.Name, f1.Tag.Get("xml"), f2.Name, f2.Tag.Get("xml")}
}
// unmarshalPaths walks down an XML structure looking for
// wanted paths, and calls unmarshal on them.
func (p *Parser) unmarshalPaths(sv reflect.Value, paths map[string]pathInfo, path string, start *StartElement) error {
if info, _ := paths[path]; info.complete {
return p.unmarshal(sv.FieldByIndex(info.fieldIdx), start)
} }
// The element is not a perfect match for any field, but one
// or more fields have the path to this element as a parent
// prefix. Recurse and attempt to match these.
for { for {
tok, err := p.Token() var tok Token
tok, err = p.Token()
if err != nil { if err != nil {
return err return true, err
} }
switch t := tok.(type) { switch t := tok.(type) {
case StartElement: case StartElement:
k := path + ">" + fieldName(t.Name.Local) consumed2, err := p.unmarshalPath(tinfo, sv, parents, &t)
if _, found := paths[k]; found { if err != nil {
if err := p.unmarshalPaths(sv, paths, k, &t); err != nil { return true, err
return err
}
continue
} }
if err := p.Skip(); err != nil { if !consumed2 {
return err if err := p.Skip(); err != nil {
return true, err
}
} }
case EndElement: case EndElement:
return nil return true, nil
} }
} }
panic("unreachable") panic("unreachable")
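As a rough usage sketch of the parent-path matching above, based on the PathTestC type and the pathTestString fixture added to read_test.go in this same change (the exact expectations live in the pathTests table there; the helper name is hypothetical):

	func demoUnmarshalPaths() { // hypothetical helper, illustration only
		var out PathTestC
		if err := Unmarshal(strings.NewReader(pathTestString), &out); err != nil {
			panic(err)
		}
		// "Items>Item1>Value" collects every <Value> under <Items><Item1>, so
		// out.Values1 should come back as {"A", "C", "D"} and out.Values2 as
		// {"B"}, with out.Before == "1" and out.After == "2".
	}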


@ -6,6 +6,7 @@ package xml
import ( import (
"reflect" "reflect"
"strings"
"testing" "testing"
) )
@ -13,7 +14,7 @@ import (
func TestUnmarshalFeed(t *testing.T) { func TestUnmarshalFeed(t *testing.T) {
var f Feed var f Feed
if err := Unmarshal(StringReader(atomFeedString), &f); err != nil { if err := Unmarshal(strings.NewReader(atomFeedString), &f); err != nil {
t.Fatalf("Unmarshal: %s", err) t.Fatalf("Unmarshal: %s", err)
} }
if !reflect.DeepEqual(f, atomFeed) { if !reflect.DeepEqual(f, atomFeed) {
@ -24,8 +25,8 @@ func TestUnmarshalFeed(t *testing.T) {
// hget http://codereview.appspot.com/rss/mine/rsc // hget http://codereview.appspot.com/rss/mine/rsc
const atomFeedString = ` const atomFeedString = `
<?xml version="1.0" encoding="utf-8"?> <?xml version="1.0" encoding="utf-8"?>
<feed xmlns="http://www.w3.org/2005/Atom" xml:lang="en-us"><title>Code Review - My issues</title><link href="http://codereview.appspot.com/" rel="alternate"></link><li-nk href="http://codereview.appspot.com/rss/mine/rsc" rel="self"></li-nk><id>http://codereview.appspot.com/</id><updated>2009-10-04T01:35:58+00:00</updated><author><name>rietveld&lt;&gt;</name></author><entry><title>rietveld: an attempt at pubsubhubbub <feed xmlns="http://www.w3.org/2005/Atom" xml:lang="en-us"><title>Code Review - My issues</title><link href="http://codereview.appspot.com/" rel="alternate"></link><link href="http://codereview.appspot.com/rss/mine/rsc" rel="self"></link><id>http://codereview.appspot.com/</id><updated>2009-10-04T01:35:58+00:00</updated><author><name>rietveld&lt;&gt;</name></author><entry><title>rietveld: an attempt at pubsubhubbub
</title><link hre-f="http://codereview.appspot.com/126085" rel="alternate"></link><updated>2009-10-04T01:35:58+00:00</updated><author><name>email-address-removed</name></author><id>urn:md5:134d9179c41f806be79b3a5f7877d19a</id><summary type="html"> </title><link href="http://codereview.appspot.com/126085" rel="alternate"></link><updated>2009-10-04T01:35:58+00:00</updated><author><name>email-address-removed</name></author><id>urn:md5:134d9179c41f806be79b3a5f7877d19a</id><summary type="html">
An attempt at adding pubsubhubbub support to Rietveld. An attempt at adding pubsubhubbub support to Rietveld.
http://code.google.com/p/pubsubhubbub http://code.google.com/p/pubsubhubbub
http://code.google.com/p/rietveld/issues/detail?id=155 http://code.google.com/p/rietveld/issues/detail?id=155
@ -78,39 +79,39 @@ not being used from outside intra_region_diff.py.
</summary></entry></feed> ` </summary></entry></feed> `
type Feed struct { type Feed struct {
XMLName Name `xml:"http://www.w3.org/2005/Atom feed"` XMLName Name `xml:"http://www.w3.org/2005/Atom feed"`
Title string Title string `xml:"title"`
Id string Id string `xml:"id"`
Link []Link Link []Link `xml:"link"`
Updated Time Updated Time `xml:"updated"`
Author Person Author Person `xml:"author"`
Entry []Entry Entry []Entry `xml:"entry"`
} }
type Entry struct { type Entry struct {
Title string Title string `xml:"title"`
Id string Id string `xml:"id"`
Link []Link Link []Link `xml:"link"`
Updated Time Updated Time `xml:"updated"`
Author Person Author Person `xml:"author"`
Summary Text Summary Text `xml:"summary"`
} }
type Link struct { type Link struct {
Rel string `xml:"attr"` Rel string `xml:"rel,attr"`
Href string `xml:"attr"` Href string `xml:"href,attr"`
} }
type Person struct { type Person struct {
Name string Name string `xml:"name"`
URI string URI string `xml:"uri"`
Email string Email string `xml:"email"`
InnerXML string `xml:"innerxml"` InnerXML string `xml:",innerxml"`
} }
type Text struct { type Text struct {
Type string `xml:"attr"` Type string `xml:"type,attr"`
Body string `xml:"chardata"` Body string `xml:",chardata"`
} }
type Time string type Time string
@ -213,44 +214,26 @@ not being used from outside intra_region_diff.py.
}, },
} }
type FieldNameTest struct {
in, out string
}
var FieldNameTests = []FieldNameTest{
{"Profile-Image", "profileimage"},
{"_score", "score"},
}
func TestFieldName(t *testing.T) {
for _, tt := range FieldNameTests {
a := fieldName(tt.in)
if a != tt.out {
t.Fatalf("have %#v\nwant %#v\n\n", a, tt.out)
}
}
}
const pathTestString = ` const pathTestString = `
<result> <Result>
<before>1</before> <Before>1</Before>
<items> <Items>
<item1> <Item1>
<value>A</value> <Value>A</Value>
</item1> </Item1>
<item2> <Item2>
<value>B</value> <Value>B</Value>
</item2> </Item2>
<Item1> <Item1>
<Value>C</Value> <Value>C</Value>
<Value>D</Value> <Value>D</Value>
</Item1> </Item1>
<_> <_>
<value>E</value> <Value>E</Value>
</_> </_>
</items> </Items>
<after>2</after> <After>2</After>
</result> </Result>
` `
type PathTestItem struct { type PathTestItem struct {
@ -258,18 +241,18 @@ type PathTestItem struct {
} }
type PathTestA struct { type PathTestA struct {
Items []PathTestItem `xml:">item1"` Items []PathTestItem `xml:">Item1"`
Before, After string Before, After string
} }
type PathTestB struct { type PathTestB struct {
Other []PathTestItem `xml:"items>Item1"` Other []PathTestItem `xml:"Items>Item1"`
Before, After string Before, After string
} }
type PathTestC struct { type PathTestC struct {
Values1 []string `xml:"items>item1>value"` Values1 []string `xml:"Items>Item1>Value"`
Values2 []string `xml:"items>item2>value"` Values2 []string `xml:"Items>Item2>Value"`
Before, After string Before, After string
} }
@ -278,12 +261,12 @@ type PathTestSet struct {
} }
type PathTestD struct { type PathTestD struct {
Other PathTestSet `xml:"items>"` Other PathTestSet `xml:"Items"`
Before, After string Before, After string
} }
type PathTestE struct { type PathTestE struct {
Underline string `xml:"items>_>value"` Underline string `xml:"Items>_>Value"`
Before, After string Before, After string
} }
@ -298,7 +281,7 @@ var pathTests = []interface{}{
func TestUnmarshalPaths(t *testing.T) { func TestUnmarshalPaths(t *testing.T) {
for _, pt := range pathTests { for _, pt := range pathTests {
v := reflect.New(reflect.TypeOf(pt).Elem()).Interface() v := reflect.New(reflect.TypeOf(pt).Elem()).Interface()
if err := Unmarshal(StringReader(pathTestString), v); err != nil { if err := Unmarshal(strings.NewReader(pathTestString), v); err != nil {
t.Fatalf("Unmarshal: %s", err) t.Fatalf("Unmarshal: %s", err)
} }
if !reflect.DeepEqual(v, pt) { if !reflect.DeepEqual(v, pt) {
@ -310,7 +293,7 @@ func TestUnmarshalPaths(t *testing.T) {
type BadPathTestA struct { type BadPathTestA struct {
First string `xml:"items>item1"` First string `xml:"items>item1"`
Other string `xml:"items>item2"` Other string `xml:"items>item2"`
Second string `xml:"items>"` Second string `xml:"items"`
} }
type BadPathTestB struct { type BadPathTestB struct {
@ -319,81 +302,55 @@ type BadPathTestB struct {
Second string `xml:"items>item1>value"` Second string `xml:"items>item1>value"`
} }
type BadPathTestC struct {
First string
Second string `xml:"First"`
}
type BadPathTestD struct {
BadPathEmbeddedA
BadPathEmbeddedB
}
type BadPathEmbeddedA struct {
First string
}
type BadPathEmbeddedB struct {
Second string `xml:"First"`
}
var badPathTests = []struct { var badPathTests = []struct {
v, e interface{} v, e interface{}
}{ }{
{&BadPathTestA{}, &TagPathError{reflect.TypeOf(BadPathTestA{}), "First", "items>item1", "Second", "items>"}}, {&BadPathTestA{}, &TagPathError{reflect.TypeOf(BadPathTestA{}), "First", "items>item1", "Second", "items"}},
{&BadPathTestB{}, &TagPathError{reflect.TypeOf(BadPathTestB{}), "First", "items>item1", "Second", "items>item1>value"}}, {&BadPathTestB{}, &TagPathError{reflect.TypeOf(BadPathTestB{}), "First", "items>item1", "Second", "items>item1>value"}},
{&BadPathTestC{}, &TagPathError{reflect.TypeOf(BadPathTestC{}), "First", "", "Second", "First"}},
{&BadPathTestD{}, &TagPathError{reflect.TypeOf(BadPathTestD{}), "First", "", "Second", "First"}},
} }
func TestUnmarshalBadPaths(t *testing.T) { func TestUnmarshalBadPaths(t *testing.T) {
for _, tt := range badPathTests { for _, tt := range badPathTests {
err := Unmarshal(StringReader(pathTestString), tt.v) err := Unmarshal(strings.NewReader(pathTestString), tt.v)
if !reflect.DeepEqual(err, tt.e) { if !reflect.DeepEqual(err, tt.e) {
t.Fatalf("Unmarshal with %#v didn't fail properly: %#v", tt.v, err) t.Fatalf("Unmarshal with %#v didn't fail properly:\nhave %#v,\nwant %#v", tt.v, err, tt.e)
} }
} }
} }
func TestUnmarshalAttrs(t *testing.T) {
var f AttrTest
if err := Unmarshal(StringReader(attrString), &f); err != nil {
t.Fatalf("Unmarshal: %s", err)
}
if !reflect.DeepEqual(f, attrStruct) {
t.Fatalf("have %#v\nwant %#v", f, attrStruct)
}
}
type AttrTest struct {
Test1 Test1
Test2 Test2
}
type Test1 struct {
Int int `xml:"attr"`
Float float64 `xml:"attr"`
Uint8 uint8 `xml:"attr"`
}
type Test2 struct {
Bool bool `xml:"attr"`
}
const attrString = `
<?xml version="1.0" charset="utf-8"?>
<attrtest>
<test1 int="8" float="23.5" uint8="255"/>
<test2 bool="true"/>
</attrtest>
`
var attrStruct = AttrTest{
Test1: Test1{
Int: 8,
Float: 23.5,
Uint8: 255,
},
Test2: Test2{
Bool: true,
},
}
// test data for TestUnmarshalWithoutNameType
const OK = "OK" const OK = "OK"
const withoutNameTypeData = ` const withoutNameTypeData = `
<?xml version="1.0" charset="utf-8"?> <?xml version="1.0" charset="utf-8"?>
<Test3 attr="OK" />` <Test3 Attr="OK" />`
type TestThree struct { type TestThree struct {
XMLName bool `xml:"Test3"` // XMLName field without an xml.Name type XMLName Name `xml:"Test3"`
Attr string `xml:"attr"` Attr string `xml:",attr"`
} }
func TestUnmarshalWithoutNameType(t *testing.T) { func TestUnmarshalWithoutNameType(t *testing.T) {
var x TestThree var x TestThree
if err := Unmarshal(StringReader(withoutNameTypeData), &x); err != nil { if err := Unmarshal(strings.NewReader(withoutNameTypeData), &x); err != nil {
t.Fatalf("Unmarshal: %s", err) t.Fatalf("Unmarshal: %s", err)
} }
if x.Attr != OK { if x.Attr != OK {


@ -0,0 +1,321 @@
// Copyright 2011 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package xml
import (
"fmt"
"reflect"
"strings"
"sync"
)
// typeInfo holds details for the xml representation of a type.
type typeInfo struct {
xmlname *fieldInfo
fields []fieldInfo
}
// fieldInfo holds details for the xml representation of a single field.
type fieldInfo struct {
idx []int
name string
xmlns string
flags fieldFlags
parents []string
}
type fieldFlags int
const (
fElement fieldFlags = 1 << iota
fAttr
fCharData
fInnerXml
fComment
fAny
// TODO:
//fIgnore
//fOmitEmpty
fMode = fElement | fAttr | fCharData | fInnerXml | fComment | fAny
)
var tinfoMap = make(map[reflect.Type]*typeInfo)
var tinfoLock sync.RWMutex
// getTypeInfo returns the typeInfo structure with details necessary
// for marshalling and unmarshalling typ.
func getTypeInfo(typ reflect.Type) (*typeInfo, error) {
tinfoLock.RLock()
tinfo, ok := tinfoMap[typ]
tinfoLock.RUnlock()
if ok {
return tinfo, nil
}
tinfo = &typeInfo{}
if typ.Kind() == reflect.Struct {
n := typ.NumField()
for i := 0; i < n; i++ {
f := typ.Field(i)
if f.PkgPath != "" {
continue // Private field
}
// For embedded structs, embed its fields.
if f.Anonymous {
if f.Type.Kind() != reflect.Struct {
continue
}
inner, err := getTypeInfo(f.Type)
if err != nil {
return nil, err
}
for _, finfo := range inner.fields {
finfo.idx = append([]int{i}, finfo.idx...)
if err := addFieldInfo(typ, tinfo, &finfo); err != nil {
return nil, err
}
}
continue
}
finfo, err := structFieldInfo(typ, &f)
if err != nil {
return nil, err
}
if f.Name == "XMLName" {
tinfo.xmlname = finfo
continue
}
// Add the field if it doesn't conflict with other fields.
if err := addFieldInfo(typ, tinfo, finfo); err != nil {
return nil, err
}
}
}
tinfoLock.Lock()
tinfoMap[typ] = tinfo
tinfoLock.Unlock()
return tinfo, nil
}
// structFieldInfo builds and returns a fieldInfo for f.
func structFieldInfo(typ reflect.Type, f *reflect.StructField) (*fieldInfo, error) {
finfo := &fieldInfo{idx: f.Index}
// Split the tag from the xml namespace if necessary.
tag := f.Tag.Get("xml")
if i := strings.Index(tag, " "); i >= 0 {
finfo.xmlns, tag = tag[:i], tag[i+1:]
}
// Parse flags.
tokens := strings.Split(tag, ",")
if len(tokens) == 1 {
finfo.flags = fElement
} else {
tag = tokens[0]
for _, flag := range tokens[1:] {
switch flag {
case "attr":
finfo.flags |= fAttr
case "chardata":
finfo.flags |= fCharData
case "innerxml":
finfo.flags |= fInnerXml
case "comment":
finfo.flags |= fComment
case "any":
finfo.flags |= fAny
}
}
// Validate the flags used.
switch mode := finfo.flags & fMode; mode {
case 0:
finfo.flags |= fElement
case fAttr, fCharData, fInnerXml, fComment, fAny:
if f.Name != "XMLName" && (tag == "" || mode == fAttr) {
break
}
fallthrough
default:
// This will also catch multiple modes in a single field.
return nil, fmt.Errorf("xml: invalid tag in field %s of type %s: %q",
f.Name, typ, f.Tag.Get("xml"))
}
}
// Use of xmlns without a name is not allowed.
if finfo.xmlns != "" && tag == "" {
return nil, fmt.Errorf("xml: namespace without name in field %s of type %s: %q",
f.Name, typ, f.Tag.Get("xml"))
}
if f.Name == "XMLName" {
// The XMLName field records the XML element name. Don't
// process it as usual because its name should default to
// empty rather than to the field name.
finfo.name = tag
return finfo, nil
}
if tag == "" {
// If the name part of the tag is completely empty, get
// default from XMLName of underlying struct if feasible,
// or field name otherwise.
if xmlname := lookupXMLName(f.Type); xmlname != nil {
finfo.xmlns, finfo.name = xmlname.xmlns, xmlname.name
} else {
finfo.name = f.Name
}
return finfo, nil
}
// Prepare field name and parents.
tokens = strings.Split(tag, ">")
if tokens[0] == "" {
tokens[0] = f.Name
}
if tokens[len(tokens)-1] == "" {
return nil, fmt.Errorf("xml: trailing '>' in field %s of type %s", f.Name, typ)
}
finfo.name = tokens[len(tokens)-1]
if len(tokens) > 1 {
finfo.parents = tokens[:len(tokens)-1]
}
// If the field type has an XMLName field, the names must match
// so that the behavior of both marshalling and unmarshalling
// is straightforward and unambiguous.
if finfo.flags&fElement != 0 {
ftyp := f.Type
xmlname := lookupXMLName(ftyp)
if xmlname != nil && xmlname.name != finfo.name {
return nil, fmt.Errorf("xml: name %q in tag of %s.%s conflicts with name %q in %s.XMLName",
finfo.name, typ, f.Name, xmlname.name, ftyp)
}
}
return finfo, nil
}
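A few hypothetical tag strings and the fieldInfo each would produce under the parsing above (names only; idx omitted), kept as comments so they read like the surrounding file:

	// Hypothetical examples of how structFieldInfo decomposes tags:
	//
	//	`xml:"name,attr"`  -> flags fAttr, name "name"
	//	`xml:",chardata"`  -> flags fCharData, name taken from the Go field name
	//	`xml:"a>b>c"`      -> flags fElement, parents {"a", "b"}, name "c"
	//	`xml:">item"`      -> flags fElement, parents {field name}, name "item"
	//	`xml:"urn:x leaf"` -> flags fElement, xmlns "urn:x", name "leaf"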
// lookupXMLName returns the fieldInfo for typ's XMLName field
// in case it exists and has a valid xml field tag, otherwise
// it returns nil.
func lookupXMLName(typ reflect.Type) (xmlname *fieldInfo) {
for typ.Kind() == reflect.Ptr {
typ = typ.Elem()
}
if typ.Kind() != reflect.Struct {
return nil
}
for i, n := 0, typ.NumField(); i < n; i++ {
f := typ.Field(i)
if f.Name != "XMLName" {
continue
}
finfo, err := structFieldInfo(typ, &f)
if finfo.name != "" && err == nil {
return finfo
}
// Also consider errors as a non-existent field tag
// and let getTypeInfo itself report the error.
break
}
return nil
}
func min(a, b int) int {
if a <= b {
return a
}
return b
}
// addFieldInfo adds finfo to tinfo.fields if there are no
// conflicts, or if conflicts arise from previous fields that were
// obtained from deeper embedded structures than finfo. In the latter
// case, the conflicting entries are dropped.
// A conflict occurs when the path (parent + name) to a field is
// itself a prefix of another path, or when two paths match exactly.
// It is okay for field paths to share a common, shorter prefix.
func addFieldInfo(typ reflect.Type, tinfo *typeInfo, newf *fieldInfo) error {
var conflicts []int
Loop:
// First, find all conflicts. Most working code will have none.
for i := range tinfo.fields {
oldf := &tinfo.fields[i]
if oldf.flags&fMode != newf.flags&fMode {
continue
}
minl := min(len(newf.parents), len(oldf.parents))
for p := 0; p < minl; p++ {
if oldf.parents[p] != newf.parents[p] {
continue Loop
}
}
if len(oldf.parents) > len(newf.parents) {
if oldf.parents[len(newf.parents)] == newf.name {
conflicts = append(conflicts, i)
}
} else if len(oldf.parents) < len(newf.parents) {
if newf.parents[len(oldf.parents)] == oldf.name {
conflicts = append(conflicts, i)
}
} else {
if newf.name == oldf.name {
conflicts = append(conflicts, i)
}
}
}
// Without conflicts, add the new field and return.
if conflicts == nil {
tinfo.fields = append(tinfo.fields, *newf)
return nil
}
// If any conflict is shallower, ignore the new field.
// This matches the Go field resolution on embedding.
for _, i := range conflicts {
if len(tinfo.fields[i].idx) < len(newf.idx) {
return nil
}
}
// Otherwise, if any of them is at the same depth level, it's an error.
for _, i := range conflicts {
oldf := &tinfo.fields[i]
if len(oldf.idx) == len(newf.idx) {
f1 := typ.FieldByIndex(oldf.idx)
f2 := typ.FieldByIndex(newf.idx)
return &TagPathError{typ, f1.Name, f1.Tag.Get("xml"), f2.Name, f2.Tag.Get("xml")}
}
}
// Otherwise, the new field is shallower, and thus takes precedence,
// so drop the conflicting fields from tinfo and append the new one.
for c := len(conflicts) - 1; c >= 0; c-- {
i := conflicts[c]
copy(tinfo.fields[i:], tinfo.fields[i+1:])
tinfo.fields = tinfo.fields[:len(tinfo.fields)-1]
}
tinfo.fields = append(tinfo.fields, *newf)
return nil
}
// A TagPathError represents an error in the unmarshalling process
// caused by the use of field tags with conflicting paths.
type TagPathError struct {
Struct reflect.Type
Field1, Tag1 string
Field2, Tag2 string
}
func (e *TagPathError) Error() string {
return fmt.Sprintf("%s field %q with tag %q conflicts with field %q with tag %q", e.Struct, e.Field1, e.Tag1, e.Field2, e.Tag2)
}

View File

@ -154,36 +154,8 @@ var xmlInput = []string{
"<t>cdata]]></t>", "<t>cdata]]></t>",
} }
type stringReader struct {
s string
off int
}
func (r *stringReader) Read(b []byte) (n int, err error) {
if r.off >= len(r.s) {
return 0, io.EOF
}
for r.off < len(r.s) && n < len(b) {
b[n] = r.s[r.off]
n++
r.off++
}
return
}
func (r *stringReader) ReadByte() (b byte, err error) {
if r.off >= len(r.s) {
return 0, io.EOF
}
b = r.s[r.off]
r.off++
return
}
func StringReader(s string) io.Reader { return &stringReader{s, 0} }
func TestRawToken(t *testing.T) { func TestRawToken(t *testing.T) {
p := NewParser(StringReader(testInput)) p := NewParser(strings.NewReader(testInput))
testRawToken(t, p, rawTokens) testRawToken(t, p, rawTokens)
} }
@ -207,7 +179,7 @@ func (d *downCaser) Read(p []byte) (int, error) {
func TestRawTokenAltEncoding(t *testing.T) { func TestRawTokenAltEncoding(t *testing.T) {
sawEncoding := "" sawEncoding := ""
p := NewParser(StringReader(testInputAltEncoding)) p := NewParser(strings.NewReader(testInputAltEncoding))
p.CharsetReader = func(charset string, input io.Reader) (io.Reader, error) { p.CharsetReader = func(charset string, input io.Reader) (io.Reader, error) {
sawEncoding = charset sawEncoding = charset
if charset != "x-testing-uppercase" { if charset != "x-testing-uppercase" {
@ -219,7 +191,7 @@ func TestRawTokenAltEncoding(t *testing.T) {
} }
func TestRawTokenAltEncodingNoConverter(t *testing.T) { func TestRawTokenAltEncodingNoConverter(t *testing.T) {
p := NewParser(StringReader(testInputAltEncoding)) p := NewParser(strings.NewReader(testInputAltEncoding))
token, err := p.RawToken() token, err := p.RawToken()
if token == nil { if token == nil {
t.Fatalf("expected a token on first RawToken call") t.Fatalf("expected a token on first RawToken call")
@ -286,7 +258,7 @@ var nestedDirectivesTokens = []Token{
} }
func TestNestedDirectives(t *testing.T) { func TestNestedDirectives(t *testing.T) {
p := NewParser(StringReader(nestedDirectivesInput)) p := NewParser(strings.NewReader(nestedDirectivesInput))
for i, want := range nestedDirectivesTokens { for i, want := range nestedDirectivesTokens {
have, err := p.Token() have, err := p.Token()
@ -300,7 +272,7 @@ func TestNestedDirectives(t *testing.T) {
} }
func TestToken(t *testing.T) { func TestToken(t *testing.T) {
p := NewParser(StringReader(testInput)) p := NewParser(strings.NewReader(testInput))
for i, want := range cookedTokens { for i, want := range cookedTokens {
have, err := p.Token() have, err := p.Token()
@ -315,7 +287,7 @@ func TestToken(t *testing.T) {
func TestSyntax(t *testing.T) { func TestSyntax(t *testing.T) {
for i := range xmlInput { for i := range xmlInput {
p := NewParser(StringReader(xmlInput[i])) p := NewParser(strings.NewReader(xmlInput[i]))
var err error var err error
for _, err = p.Token(); err == nil; _, err = p.Token() { for _, err = p.Token(); err == nil; _, err = p.Token() {
} }
@ -372,26 +344,26 @@ var all = allScalars{
var sixteen = "16" var sixteen = "16"
const testScalarsInput = `<allscalars> const testScalarsInput = `<allscalars>
<true1>true</true1> <True1>true</True1>
<true2>1</true2> <True2>1</True2>
<false1>false</false1> <False1>false</False1>
<false2>0</false2> <False2>0</False2>
<int>1</int> <Int>1</Int>
<int8>-2</int8> <Int8>-2</Int8>
<int16>3</int16> <Int16>3</Int16>
<int32>-4</int32> <Int32>-4</Int32>
<int64>5</int64> <Int64>5</Int64>
<uint>6</uint> <Uint>6</Uint>
<uint8>7</uint8> <Uint8>7</Uint8>
<uint16>8</uint16> <Uint16>8</Uint16>
<uint32>9</uint32> <Uint32>9</Uint32>
<uint64>10</uint64> <Uint64>10</Uint64>
<uintptr>11</uintptr> <Uintptr>11</Uintptr>
<float>12.0</float> <Float>12.0</Float>
<float32>13.0</float32> <Float32>13.0</Float32>
<float64>14.0</float64> <Float64>14.0</Float64>
<string>15</string> <String>15</String>
<ptrstring>16</ptrstring> <PtrString>16</PtrString>
</allscalars>` </allscalars>`
func TestAllScalars(t *testing.T) { func TestAllScalars(t *testing.T) {
@ -412,7 +384,7 @@ type item struct {
} }
func TestIssue569(t *testing.T) { func TestIssue569(t *testing.T) {
data := `<item><field_a>abcd</field_a></item>` data := `<item><Field_a>abcd</Field_a></item>`
var i item var i item
buf := bytes.NewBufferString(data) buf := bytes.NewBufferString(data)
err := Unmarshal(buf, &i) err := Unmarshal(buf, &i)
@ -424,7 +396,7 @@ func TestIssue569(t *testing.T) {
func TestUnquotedAttrs(t *testing.T) { func TestUnquotedAttrs(t *testing.T) {
data := "<tag attr=azAZ09:-_\t>" data := "<tag attr=azAZ09:-_\t>"
p := NewParser(StringReader(data)) p := NewParser(strings.NewReader(data))
p.Strict = false p.Strict = false
token, err := p.Token() token, err := p.Token()
if _, ok := err.(*SyntaxError); ok { if _, ok := err.(*SyntaxError); ok {
@ -450,7 +422,7 @@ func TestValuelessAttrs(t *testing.T) {
{"<input checked />", "input", "checked"}, {"<input checked />", "input", "checked"},
} }
for _, test := range tests { for _, test := range tests {
p := NewParser(StringReader(test[0])) p := NewParser(strings.NewReader(test[0]))
p.Strict = false p.Strict = false
token, err := p.Token() token, err := p.Token()
if _, ok := err.(*SyntaxError); ok { if _, ok := err.(*SyntaxError); ok {
@ -500,7 +472,7 @@ func TestCopyTokenStartElement(t *testing.T) {
func TestSyntaxErrorLineNum(t *testing.T) { func TestSyntaxErrorLineNum(t *testing.T) {
testInput := "<P>Foo<P>\n\n<P>Bar</>\n" testInput := "<P>Foo<P>\n\n<P>Bar</>\n"
p := NewParser(StringReader(testInput)) p := NewParser(strings.NewReader(testInput))
var err error var err error
for _, err = p.Token(); err == nil; _, err = p.Token() { for _, err = p.Token(); err == nil; _, err = p.Token() {
} }
@ -515,7 +487,7 @@ func TestSyntaxErrorLineNum(t *testing.T) {
func TestTrailingRawToken(t *testing.T) { func TestTrailingRawToken(t *testing.T) {
input := `<FOO></FOO> ` input := `<FOO></FOO> `
p := NewParser(StringReader(input)) p := NewParser(strings.NewReader(input))
var err error var err error
for _, err = p.RawToken(); err == nil; _, err = p.RawToken() { for _, err = p.RawToken(); err == nil; _, err = p.RawToken() {
} }
@ -526,7 +498,7 @@ func TestTrailingRawToken(t *testing.T) {
func TestTrailingToken(t *testing.T) { func TestTrailingToken(t *testing.T) {
input := `<FOO></FOO> ` input := `<FOO></FOO> `
p := NewParser(StringReader(input)) p := NewParser(strings.NewReader(input))
var err error var err error
for _, err = p.Token(); err == nil; _, err = p.Token() { for _, err = p.Token(); err == nil; _, err = p.Token() {
} }
@ -537,7 +509,7 @@ func TestTrailingToken(t *testing.T) {
func TestEntityInsideCDATA(t *testing.T) { func TestEntityInsideCDATA(t *testing.T) {
input := `<test><![CDATA[ &val=foo ]]></test>` input := `<test><![CDATA[ &val=foo ]]></test>`
p := NewParser(StringReader(input)) p := NewParser(strings.NewReader(input))
var err error var err error
for _, err = p.Token(); err == nil; _, err = p.Token() { for _, err = p.Token(); err == nil; _, err = p.Token() {
} }
@ -569,7 +541,7 @@ var characterTests = []struct {
func TestDisallowedCharacters(t *testing.T) { func TestDisallowedCharacters(t *testing.T) {
for i, tt := range characterTests { for i, tt := range characterTests {
p := NewParser(StringReader(tt.in)) p := NewParser(strings.NewReader(tt.in))
var err error var err error
for err == nil { for err == nil {

View File

@ -8,7 +8,7 @@ import "unicode/utf8"
type input interface { type input interface {
skipASCII(p int) int skipASCII(p int) int
skipNonStarter() int skipNonStarter(p int) int
appendSlice(buf []byte, s, e int) []byte appendSlice(buf []byte, s, e int) []byte
copySlice(buf []byte, s, e int) copySlice(buf []byte, s, e int)
charinfo(p int) (uint16, int) charinfo(p int) (uint16, int)
@ -25,8 +25,7 @@ func (s inputString) skipASCII(p int) int {
return p return p
} }
func (s inputString) skipNonStarter() int { func (s inputString) skipNonStarter(p int) int {
p := 0
for ; p < len(s) && !utf8.RuneStart(s[p]); p++ { for ; p < len(s) && !utf8.RuneStart(s[p]); p++ {
} }
return p return p
@ -71,8 +70,7 @@ func (s inputBytes) skipASCII(p int) int {
return p return p
} }
func (s inputBytes) skipNonStarter() int { func (s inputBytes) skipNonStarter(p int) int {
p := 0
for ; p < len(s) && !utf8.RuneStart(s[p]); p++ { for ; p < len(s) && !utf8.RuneStart(s[p]); p++ {
} }
return p return p

View File

@ -34,24 +34,28 @@ const (
// Bytes returns f(b). May return b if f(b) = b. // Bytes returns f(b). May return b if f(b) = b.
func (f Form) Bytes(b []byte) []byte { func (f Form) Bytes(b []byte) []byte {
n := f.QuickSpan(b) rb := reorderBuffer{}
rb.init(f, b)
n := quickSpan(&rb, 0)
if n == len(b) { if n == len(b) {
return b return b
} }
out := make([]byte, n, len(b)) out := make([]byte, n, len(b))
copy(out, b[0:n]) copy(out, b[0:n])
return f.Append(out, b[n:]...) return doAppend(&rb, out, n)
} }
// String returns f(s). // String returns f(s).
func (f Form) String(s string) string { func (f Form) String(s string) string {
n := f.QuickSpanString(s) rb := reorderBuffer{}
rb.initString(f, s)
n := quickSpan(&rb, 0)
if n == len(s) { if n == len(s) {
return s return s
} }
out := make([]byte, 0, len(s)) out := make([]byte, n, len(s))
copy(out, s[0:n]) copy(out, s[0:n])
return string(f.AppendString(out, s[n:])) return string(doAppend(&rb, out, n))
} }
// IsNormal returns true if b == f(b). // IsNormal returns true if b == f(b).
@ -122,23 +126,27 @@ func (f Form) IsNormalString(s string) bool {
// patchTail fixes a case where a rune may be incorrectly normalized // patchTail fixes a case where a rune may be incorrectly normalized
// if it is followed by illegal continuation bytes. It returns the // if it is followed by illegal continuation bytes. It returns the
// patched buffer and the number of trailing continuation bytes that // patched buffer and whether there were trailing continuation bytes.
// have been dropped. func patchTail(rb *reorderBuffer, buf []byte) ([]byte, bool) {
func patchTail(rb *reorderBuffer, buf []byte) ([]byte, int) {
info, p := lastRuneStart(&rb.f, buf) info, p := lastRuneStart(&rb.f, buf)
if p == -1 || info.size == 0 { if p == -1 || info.size == 0 {
return buf, 0 return buf, false
} }
end := p + int(info.size) end := p + int(info.size)
extra := len(buf) - end extra := len(buf) - end
if extra > 0 { if extra > 0 {
// Potentially allocating memory. However, this only
// happens with ill-formed UTF-8.
x := make([]byte, 0)
x = append(x, buf[len(buf)-extra:]...)
buf = decomposeToLastBoundary(rb, buf[:end]) buf = decomposeToLastBoundary(rb, buf[:end])
if rb.f.composing { if rb.f.composing {
rb.compose() rb.compose()
} }
return rb.flush(buf), extra buf = rb.flush(buf)
return append(buf, x...), true
} }
return buf, 0 return buf, false
} }
func appendQuick(rb *reorderBuffer, dst []byte, i int) ([]byte, int) { func appendQuick(rb *reorderBuffer, dst []byte, i int) ([]byte, int) {
@ -157,23 +165,23 @@ func (f Form) Append(out []byte, src ...byte) []byte {
} }
rb := reorderBuffer{} rb := reorderBuffer{}
rb.init(f, src) rb.init(f, src)
return doAppend(&rb, out) return doAppend(&rb, out, 0)
} }
func doAppend(rb *reorderBuffer, out []byte) []byte { func doAppend(rb *reorderBuffer, out []byte, p int) []byte {
src, n := rb.src, rb.nsrc src, n := rb.src, rb.nsrc
doMerge := len(out) > 0 doMerge := len(out) > 0
p := 0 if q := src.skipNonStarter(p); q > p {
if p = src.skipNonStarter(); p > 0 {
// Move leading non-starters to destination. // Move leading non-starters to destination.
out = src.appendSlice(out, 0, p) out = src.appendSlice(out, p, q)
buf, ndropped := patchTail(rb, out) buf, endsInError := patchTail(rb, out)
if ndropped > 0 { if endsInError {
out = src.appendSlice(buf, p-ndropped, p) out = buf
doMerge = false // no need to merge, ends with illegal UTF-8 doMerge = false // no need to merge, ends with illegal UTF-8
} else { } else {
out = decomposeToLastBoundary(rb, buf) // force decomposition out = decomposeToLastBoundary(rb, buf) // force decomposition
} }
p = q
} }
fd := &rb.f fd := &rb.f
if doMerge { if doMerge {
@ -217,7 +225,7 @@ func (f Form) AppendString(out []byte, src string) []byte {
} }
rb := reorderBuffer{} rb := reorderBuffer{}
rb.initString(f, src) rb.initString(f, src)
return doAppend(&rb, out) return doAppend(&rb, out, 0)
} }
// QuickSpan returns a boundary n such that b[0:n] == f(b[0:n]). // QuickSpan returns a boundary n such that b[0:n] == f(b[0:n]).
@ -225,7 +233,8 @@ func (f Form) AppendString(out []byte, src string) []byte {
func (f Form) QuickSpan(b []byte) int { func (f Form) QuickSpan(b []byte) int {
rb := reorderBuffer{} rb := reorderBuffer{}
rb.init(f, b) rb.init(f, b)
return quickSpan(&rb, 0) n := quickSpan(&rb, 0)
return n
} }
func quickSpan(rb *reorderBuffer, i int) int { func quickSpan(rb *reorderBuffer, i int) int {
@ -301,7 +310,7 @@ func (f Form) FirstBoundary(b []byte) int {
func firstBoundary(rb *reorderBuffer) int { func firstBoundary(rb *reorderBuffer) int {
src, nsrc := rb.src, rb.nsrc src, nsrc := rb.src, rb.nsrc
i := src.skipNonStarter() i := src.skipNonStarter(0)
if i >= nsrc { if i >= nsrc {
return -1 return -1
} }
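A minimal usage sketch for the Bytes/String entry points reworked above, assuming the exp/norm import path used in this tree; the QuickSpan result is printed rather than asserted.

package main

import (
	"fmt"

	"exp/norm" // import path in this tree; the package later left exp
)

func main() {
	// "e" followed by U+0301 COMBINING ACUTE ACCENT; NFC composes the pair
	// into the single rune U+00E9, exercising the String path above.
	decomposed := "e\u0301"
	fmt.Println(norm.NFC.String(decomposed) == "\u00e9") // true
	// Length of the prefix of the input that is already in NFC form.
	fmt.Println(norm.NFC.QuickSpan([]byte(decomposed)))
}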

View File

@ -253,7 +253,7 @@ var quickSpanNFDTests = []PositionTest{
{"\u0316\u0300cd", 6, ""}, {"\u0316\u0300cd", 6, ""},
{"\u043E\u0308b", 5, ""}, {"\u043E\u0308b", 5, ""},
// incorrectly ordered combining characters // incorrectly ordered combining characters
{"ab\u0300\u0316", 1, ""}, // TODO(mpvl): we could skip 'b' as well. {"ab\u0300\u0316", 1, ""}, // TODO: we could skip 'b' as well.
{"ab\u0300\u0316cd", 1, ""}, {"ab\u0300\u0316cd", 1, ""},
// Hangul // Hangul
{"같은", 0, ""}, {"같은", 0, ""},
@ -465,6 +465,7 @@ var appendTests = []AppendTest{
{"\u0300", "\xFC\x80\x80\x80\x80\x80\u0300", "\u0300\xFC\x80\x80\x80\x80\x80\u0300"}, {"\u0300", "\xFC\x80\x80\x80\x80\x80\u0300", "\u0300\xFC\x80\x80\x80\x80\x80\u0300"},
{"\xF8\x80\x80\x80\x80\u0300", "\u0300", "\xF8\x80\x80\x80\x80\u0300\u0300"}, {"\xF8\x80\x80\x80\x80\u0300", "\u0300", "\xF8\x80\x80\x80\x80\u0300\u0300"},
{"\xFC\x80\x80\x80\x80\x80\u0300", "\u0300", "\xFC\x80\x80\x80\x80\x80\u0300\u0300"}, {"\xFC\x80\x80\x80\x80\x80\u0300", "\u0300", "\xFC\x80\x80\x80\x80\x80\u0300\u0300"},
{"\xF8\x80\x80\x80", "\x80\u0300\u0300", "\xF8\x80\x80\x80\x80\u0300\u0300"},
} }
func appendF(f Form, out []byte, s string) []byte { func appendF(f Form, out []byte, s string) []byte {
@ -475,9 +476,23 @@ func appendStringF(f Form, out []byte, s string) []byte {
return f.AppendString(out, s) return f.AppendString(out, s)
} }
func bytesF(f Form, out []byte, s string) []byte {
buf := []byte{}
buf = append(buf, out...)
buf = append(buf, s...)
return f.Bytes(buf)
}
func stringF(f Form, out []byte, s string) []byte {
outs := string(out) + s
return []byte(f.String(outs))
}
func TestAppend(t *testing.T) { func TestAppend(t *testing.T) {
runAppendTests(t, "TestAppend", NFKC, appendF, appendTests) runAppendTests(t, "TestAppend", NFKC, appendF, appendTests)
runAppendTests(t, "TestAppendString", NFKC, appendStringF, appendTests) runAppendTests(t, "TestAppendString", NFKC, appendStringF, appendTests)
runAppendTests(t, "TestBytes", NFKC, bytesF, appendTests)
runAppendTests(t, "TestString", NFKC, stringF, appendTests)
} }
func doFormBenchmark(b *testing.B, f Form, s string) { func doFormBenchmark(b *testing.B, f Form, s string) {

View File

@ -27,7 +27,7 @@ func (w *normWriter) Write(data []byte) (n int, err error) {
} }
w.rb.src = inputBytes(data[:m]) w.rb.src = inputBytes(data[:m])
w.rb.nsrc = m w.rb.nsrc = m
w.buf = doAppend(&w.rb, w.buf) w.buf = doAppend(&w.rb, w.buf, 0)
data = data[m:] data = data[m:]
n += m n += m
@ -101,7 +101,7 @@ func (r *normReader) Read(p []byte) (int, error) {
r.rb.src = inputBytes(r.inbuf[0:n]) r.rb.src = inputBytes(r.inbuf[0:n])
r.rb.nsrc, r.err = n, err r.rb.nsrc, r.err = n, err
if n > 0 { if n > 0 {
r.outbuf = doAppend(&r.rb, r.outbuf) r.outbuf = doAppend(&r.rb, r.outbuf, 0)
} }
if err == io.EOF { if err == io.EOF {
r.lastBoundary = len(r.outbuf) r.lastBoundary = len(r.outbuf)

View File

@ -0,0 +1,18 @@
// Copyright 2011 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package proxy
import (
"net"
)
type direct struct{}
// Direct is a direct proxy: one that makes network connections directly.
var Direct = direct{}
func (direct) Dial(network, addr string) (net.Conn, error) {
return net.Dial(network, addr)
}

View File

@ -0,0 +1,140 @@
// Copyright 2011 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package proxy
import (
"net"
"strings"
)
// A PerHost directs connections to a default Dialer unless the hostname
// requested matches one of a number of exceptions.
type PerHost struct {
def, bypass Dialer
bypassNetworks []*net.IPNet
bypassIPs []net.IP
bypassZones []string
bypassHosts []string
}
// NewPerHost returns a PerHost Dialer that directs connections to either
// defaultDialer or bypass, depending on whether the connection matches one of
// the configured rules.
func NewPerHost(defaultDialer, bypass Dialer) *PerHost {
return &PerHost{
def: defaultDialer,
bypass: bypass,
}
}
// Dial connects to the address addr on the network net through either
// defaultDialer or bypass.
func (p *PerHost) Dial(network, addr string) (c net.Conn, err error) {
host, _, err := net.SplitHostPort(addr)
if err != nil {
return nil, err
}
return p.dialerForRequest(host).Dial(network, addr)
}
func (p *PerHost) dialerForRequest(host string) Dialer {
if ip := net.ParseIP(host); ip != nil {
for _, net := range p.bypassNetworks {
if net.Contains(ip) {
return p.bypass
}
}
for _, bypassIP := range p.bypassIPs {
if bypassIP.Equal(ip) {
return p.bypass
}
}
return p.def
}
for _, zone := range p.bypassZones {
if strings.HasSuffix(host, zone) {
return p.bypass
}
if host == zone[1:] {
// For a zone "example.com", we match "example.com"
// too.
return p.bypass
}
}
for _, bypassHost := range p.bypassHosts {
if bypassHost == host {
return p.bypass
}
}
return p.def
}
// AddFromString parses a string that contains comma-separated values
// specifying hosts that should use the bypass proxy. Each value is either an
// IP address, a CIDR range, a zone (*.example.com) or a hostname
// (localhost). A best effort is made to parse the string and errors are
// ignored.
func (p *PerHost) AddFromString(s string) {
hosts := strings.Split(s, ",")
for _, host := range hosts {
host = strings.TrimSpace(host)
if len(host) == 0 {
continue
}
if strings.Contains(host, "/") {
// We assume that it's a CIDR address like 127.0.0.0/8
if _, net, err := net.ParseCIDR(host); err == nil {
p.AddNetwork(net)
}
continue
}
if ip := net.ParseIP(host); ip != nil {
p.AddIP(ip)
continue
}
if strings.HasPrefix(host, "*.") {
p.AddZone(host[1:])
continue
}
p.AddHost(host)
}
}
// AddIP specifies an IP address that will use the bypass proxy. Note that
// this will only take effect if a literal IP address is dialed. A connection
// to a named host will never match an IP.
func (p *PerHost) AddIP(ip net.IP) {
p.bypassIPs = append(p.bypassIPs, ip)
}
// AddNetwork specifies an IP range that will use the bypass proxy. Note that this
// will only take effect if a literal IP address is dialed. A connection to a
// named host will never match.
func (p *PerHost) AddNetwork(net *net.IPNet) {
p.bypassNetworks = append(p.bypassNetworks, net)
}
// AddZone specifies a DNS suffix that will use the bypass proxy. A zone of
// "example.com" matches "example.com" and all of its subdomains.
func (p *PerHost) AddZone(zone string) {
if strings.HasSuffix(zone, ".") {
zone = zone[:len(zone)-1]
}
if !strings.HasPrefix(zone, ".") {
zone = "." + zone
}
p.bypassZones = append(p.bypassZones, zone)
}
// AddHost specifies a hostname that will use the bypass proxy.
func (p *PerHost) AddHost(host string) {
if strings.HasSuffix(host, ".") {
host = host[:len(host)-1]
}
p.bypassHosts = append(p.bypassHosts, host)
}
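A small usage sketch for PerHost with the rule syntax AddFromString accepts above. The recording dialer is a local stand-in (in the spirit of the tests' recordingProxy), so nothing actually dials out; the hostnames are illustrative.

package main

import (
	"errors"
	"fmt"
	"net"

	"exp/proxy" // import path in this tree
)

// recorder implements the Dialer interface and just records the addresses
// routed to it.
type recorder struct {
	name  string
	addrs []string
}

func (r *recorder) Dial(network, addr string) (net.Conn, error) {
	r.addrs = append(r.addrs, addr)
	return nil, errors.New(r.name)
}

func main() {
	def := &recorder{name: "default"}
	bypass := &recorder{name: "bypass"}

	p := proxy.NewPerHost(def, bypass)
	p.AddFromString("localhost,*.corp.example.com,10.0.0.0/8")

	for _, addr := range []string{"example.com:80", "db.corp.example.com:5432", "10.1.2.3:22"} {
		p.Dial("tcp", addr)
	}
	fmt.Println("default:", def.addrs) // [example.com:80]
	fmt.Println("bypass:", bypass.addrs) // [db.corp.example.com:5432 10.1.2.3:22]
}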

View File

@ -0,0 +1,55 @@
// Copyright 2011 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package proxy
import (
"errors"
"net"
"reflect"
"testing"
)
type recordingProxy struct {
addrs []string
}
func (r *recordingProxy) Dial(network, addr string) (net.Conn, error) {
r.addrs = append(r.addrs, addr)
return nil, errors.New("recordingProxy")
}
func TestPerHost(t *testing.T) {
var def, bypass recordingProxy
perHost := NewPerHost(&def, &bypass)
perHost.AddFromString("localhost,*.zone,127.0.0.1,10.0.0.1/8,1000::/16")
expectedDef := []string{
"example.com:123",
"1.2.3.4:123",
"[1001::]:123",
}
expectedBypass := []string{
"localhost:123",
"zone:123",
"foo.zone:123",
"127.0.0.1:123",
"10.1.2.3:123",
"[1000::]:123",
}
for _, addr := range expectedDef {
perHost.Dial("tcp", addr)
}
for _, addr := range expectedBypass {
perHost.Dial("tcp", addr)
}
if !reflect.DeepEqual(expectedDef, def.addrs) {
t.Errorf("Hosts which went to the default proxy didn't match. Got %v, want %v", def.addrs, expectedDef)
}
if !reflect.DeepEqual(expectedBypass, bypass.addrs) {
t.Errorf("Hosts which went to the bypass proxy didn't match. Got %v, want %v", bypass.addrs, expectedBypass)
}
}

View File

@ -0,0 +1,98 @@
// Copyright 2011 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
// Package proxy provides support for a variety of protocols to proxy network
// data.
package proxy
import (
"errors"
"net"
"net/url"
"os"
"strings"
)
// A Dialer is a means to establish a connection.
type Dialer interface {
// Dial connects to the given address via the proxy.
Dial(network, addr string) (c net.Conn, err error)
}
// Auth contains authentication parameters that specific Dialers may require.
type Auth struct {
User, Password string
}
// FromEnvironment returns the dialer specified by the proxy-related variables in
// the environment.
func FromEnvironment() Dialer {
allProxy := os.Getenv("all_proxy")
if len(allProxy) == 0 {
return Direct
}
proxyURL, err := url.Parse(allProxy)
if err != nil {
return Direct
}
proxy, err := FromURL(proxyURL, Direct)
if err != nil {
return Direct
}
noProxy := os.Getenv("no_proxy")
if len(noProxy) == 0 {
return proxy
}
perHost := NewPerHost(proxy, Direct)
perHost.AddFromString(noProxy)
return perHost
}
// proxySchemes is a map from URL schemes to a function that creates a Dialer
// from a URL with such a scheme.
var proxySchemes map[string]func(*url.URL, Dialer) (Dialer, error)
// RegisterDialerType takes a URL scheme and a function to generate Dialers from
// a URL with that scheme and a forwarding Dialer. Registered schemes are used
// by FromURL.
func RegisterDialerType(scheme string, f func(*url.URL, Dialer) (Dialer, error)) {
if proxySchemes == nil {
proxySchemes = make(map[string]func(*url.URL, Dialer) (Dialer, error))
}
proxySchemes[scheme] = f
}
// FromURL returns a Dialer given a URL specification and an underlying
// Dialer for it to make network requests.
func FromURL(u *url.URL, forward Dialer) (Dialer, error) {
var auth *Auth
if len(u.RawUserinfo) > 0 {
auth = new(Auth)
parts := strings.SplitN(u.RawUserinfo, ":", 2)
if len(parts) == 1 {
auth.User = parts[0]
} else if len(parts) >= 2 {
auth.User = parts[0]
auth.Password = parts[1]
}
}
switch u.Scheme {
case "socks5":
return SOCKS5("tcp", u.Host, auth, forward)
}
// If the scheme doesn't match any of the built-in schemes, see if it
// was registered by another package.
if proxySchemes != nil {
if f, ok := proxySchemes[u.Scheme]; ok {
return f(u, forward)
}
}
return nil, errors.New("proxy: unknown scheme: " + u.Scheme)
}
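A short sketch of the environment-driven entry point above. The proxy URL and no_proxy list are placeholders, and the import path follows this tree.

package main

import (
	"fmt"
	"os"

	"exp/proxy" // import path in this tree
)

func main() {
	// all_proxy selects the forwarding dialer; no_proxy layers PerHost bypass
	// rules on top, exactly as FromEnvironment wires them together above.
	os.Setenv("all_proxy", "socks5://127.0.0.1:1080")
	os.Setenv("no_proxy", "localhost,*.internal.example.com")

	d := proxy.FromEnvironment()
	fmt.Printf("%T\n", d) // expected: a *proxy.PerHost wrapping the SOCKS5 dialer
}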

View File

@ -0,0 +1,50 @@
// Copyright 2011 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package proxy
import (
"net"
"net/url"
"testing"
)
type testDialer struct {
network, addr string
}
func (t *testDialer) Dial(network, addr string) (net.Conn, error) {
t.network = network
t.addr = addr
return nil, t
}
func (t *testDialer) Error() string {
return "testDialer " + t.network + " " + t.addr
}
func TestFromURL(t *testing.T) {
u, err := url.Parse("socks5://user:password@1.2.3.4:5678")
if err != nil {
t.Fatalf("failed to parse URL: %s", err)
}
tp := &testDialer{}
proxy, err := FromURL(u, tp)
if err != nil {
t.Fatalf("FromURL failed: %s", err)
}
conn, err := proxy.Dial("tcp", "example.com:80")
if conn != nil {
t.Error("Dial unexpected didn't return an error")
}
if tp, ok := err.(*testDialer); ok {
if tp.network != "tcp" || tp.addr != "1.2.3.4:5678" {
t.Errorf("Dialer connected to wrong host. Wanted 1.2.3.4:5678, got: %v", tp)
}
} else {
t.Errorf("Unexpected error from Dial: %s", err)
}
}

View File

@ -0,0 +1,207 @@
// Copyright 2011 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package proxy
import (
"errors"
"io"
"net"
"strconv"
)
// SOCKS5 returns a Dialer that makes SOCKSv5 connections to the given address
// with an optional username and password. See RFC 1928.
func SOCKS5(network, addr string, auth *Auth, forward Dialer) (Dialer, error) {
s := &socks5{
network: network,
addr: addr,
forward: forward,
}
if auth != nil {
s.user = auth.User
s.password = auth.Password
}
return s, nil
}
type socks5 struct {
user, password string
network, addr string
forward Dialer
}
const socks5Version = 5
const (
socks5AuthNone = 0
socks5AuthPassword = 2
)
const socks5Connect = 1
const (
socks5IP4 = 1
socks5Domain = 3
socks5IP6 = 4
)
var socks5Errors = []string{
"",
"general failure",
"connection forbidden",
"network unreachable",
"host unreachable",
"connection refused",
"TTL expired",
"command not supported",
"address type not supported",
}
// Dial connects to the address addr on the network net via the SOCKS5 proxy.
func (s *socks5) Dial(network, addr string) (net.Conn, error) {
switch network {
case "tcp", "tcp6", "tcp4":
break
default:
return nil, errors.New("proxy: no support for SOCKS5 proxy connections of type " + network)
}
conn, err := s.forward.Dial(s.network, s.addr)
if err != nil {
return nil, err
}
closeConn := &conn
defer func() {
if closeConn != nil {
(*closeConn).Close()
}
}()
host, portStr, err := net.SplitHostPort(addr)
if err != nil {
return nil, err
}
port, err := strconv.Atoi(portStr)
if err != nil {
return nil, errors.New("proxy: failed to parse port number: " + portStr)
}
if port < 1 || port > 0xffff {
return nil, errors.New("proxy: port number out of range: " + portStr)
}
// the size here is just an estimate
buf := make([]byte, 0, 6+len(host))
buf = append(buf, socks5Version)
if len(s.user) > 0 && len(s.user) < 256 && len(s.password) < 256 {
buf = append(buf, 2, /* num auth methods */ socks5AuthNone, socks5AuthPassword)
} else {
buf = append(buf, 1, /* num auth methods */ socks5AuthNone)
}
if _, err = conn.Write(buf); err != nil {
return nil, errors.New("proxy: failed to write greeting to SOCKS5 proxy at " + s.addr + ": " + err.Error())
}
if _, err = io.ReadFull(conn, buf[:2]); err != nil {
return nil, errors.New("proxy: failed to read greeting from SOCKS5 proxy at " + s.addr + ": " + err.Error())
}
if buf[0] != 5 {
return nil, errors.New("proxy: SOCKS5 proxy at " + s.addr + " has unexpected version " + strconv.Itoa(int(buf[0])))
}
if buf[1] == 0xff {
return nil, errors.New("proxy: SOCKS5 proxy at " + s.addr + " requires authentication")
}
if buf[1] == socks5AuthPassword {
buf = buf[:0]
buf = append(buf, socks5Version)
buf = append(buf, uint8(len(s.user)))
buf = append(buf, s.user...)
buf = append(buf, uint8(len(s.password)))
buf = append(buf, s.password...)
if _, err = conn.Write(buf); err != nil {
return nil, errors.New("proxy: failed to write authentication request to SOCKS5 proxy at " + s.addr + ": " + err.Error())
}
if _, err = io.ReadFull(conn, buf[:2]); err != nil {
return nil, errors.New("proxy: failed to read authentication reply from SOCKS5 proxy at " + s.addr + ": " + err.Error())
}
if buf[1] != 0 {
return nil, errors.New("proxy: SOCKS5 proxy at " + s.addr + " rejected username/password")
}
}
buf = buf[:0]
buf = append(buf, socks5Version, socks5Connect, 0 /* reserved */ )
if ip := net.ParseIP(host); ip != nil {
if len(ip) == 4 {
buf = append(buf, socks5IP4)
} else {
buf = append(buf, socks5IP6)
}
buf = append(buf, []byte(ip)...)
} else {
buf = append(buf, socks5Domain)
buf = append(buf, byte(len(host)))
buf = append(buf, host...)
}
buf = append(buf, byte(port>>8), byte(port))
if _, err = conn.Write(buf); err != nil {
return nil, errors.New("proxy: failed to write connect request to SOCKS5 proxy at " + s.addr + ": " + err.Error())
}
if _, err = io.ReadFull(conn, buf[:4]); err != nil {
return nil, errors.New("proxy: failed to read connect reply from SOCKS5 proxy at " + s.addr + ": " + err.Error())
}
failure := "unknown error"
if int(buf[1]) < len(socks5Errors) {
failure = socks5Errors[buf[1]]
}
if len(failure) > 0 {
return nil, errors.New("proxy: SOCKS5 proxy at " + s.addr + " failed to connect: " + failure)
}
bytesToDiscard := 0
switch buf[3] {
case socks5IP4:
bytesToDiscard = 4
case socks5IP6:
bytesToDiscard = 16
case socks5Domain:
_, err := io.ReadFull(conn, buf[:1])
if err != nil {
return nil, errors.New("proxy: failed to read domain length from SOCKS5 proxy at " + s.addr + ": " + err.Error())
}
bytesToDiscard = int(buf[0])
default:
return nil, errors.New("proxy: got unknown address type " + strconv.Itoa(int(buf[3])) + " from SOCKS5 proxy at " + s.addr)
}
if cap(buf) < bytesToDiscard {
buf = make([]byte, bytesToDiscard)
} else {
buf = buf[:bytesToDiscard]
}
if _, err = io.ReadFull(conn, buf); err != nil {
return nil, errors.New("proxy: failed to read address from SOCKS5 proxy at " + s.addr + ": " + err.Error())
}
// Also need to discard the port number
if _, err = io.ReadFull(conn, buf[:2]); err != nil {
return nil, errors.New("proxy: failed to read port from SOCKS5 proxy at " + s.addr + ": " + err.Error())
}
closeConn = nil
return conn, nil
}
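And a sketch of constructing the SOCKS5 dialer above directly. The proxy address and credentials are placeholders, and Dial will fail unless a SOCKS5 server is actually listening there.

package main

import (
	"fmt"

	"exp/proxy" // import path in this tree
)

func main() {
	auth := &proxy.Auth{User: "user", Password: "secret"}
	d, err := proxy.SOCKS5("tcp", "127.0.0.1:1080", auth, proxy.Direct)
	if err != nil {
		fmt.Println("setup:", err)
		return
	}
	// Dial performs the greeting/auth/connect exchange implemented above.
	conn, err := d.Dial("tcp", "example.com:80")
	fmt.Println(conn, err)
}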

View File

@ -8,8 +8,11 @@ import (
"fmt" "fmt"
"reflect" "reflect"
"testing" "testing"
"time"
) )
var someTime = time.Unix(123, 0)
type conversionTest struct { type conversionTest struct {
s, d interface{} // source and destination s, d interface{} // source and destination
@ -19,6 +22,7 @@ type conversionTest struct {
wantstr string wantstr string
wantf32 float32 wantf32 float32
wantf64 float64 wantf64 float64
wanttime time.Time
wantbool bool // used if d is of type *bool wantbool bool // used if d is of type *bool
wanterr string wanterr string
} }
@ -35,12 +39,14 @@ var (
scanbool bool scanbool bool
scanf32 float32 scanf32 float32
scanf64 float64 scanf64 float64
scantime time.Time
) )
var conversionTests = []conversionTest{ var conversionTests = []conversionTest{
// Exact conversions (destination pointer type matches source type) // Exact conversions (destination pointer type matches source type)
{s: "foo", d: &scanstr, wantstr: "foo"}, {s: "foo", d: &scanstr, wantstr: "foo"},
{s: 123, d: &scanint, wantint: 123}, {s: 123, d: &scanint, wantint: 123},
{s: someTime, d: &scantime, wanttime: someTime},
// To strings // To strings
{s: []byte("byteslice"), d: &scanstr, wantstr: "byteslice"}, {s: []byte("byteslice"), d: &scanstr, wantstr: "byteslice"},
@ -106,6 +112,10 @@ func float32Value(ptr interface{}) float32 {
return *(ptr.(*float32)) return *(ptr.(*float32))
} }
func timeValue(ptr interface{}) time.Time {
return *(ptr.(*time.Time))
}
func TestConversions(t *testing.T) { func TestConversions(t *testing.T) {
for n, ct := range conversionTests { for n, ct := range conversionTests {
err := convertAssign(ct.d, ct.s) err := convertAssign(ct.d, ct.s)
@ -138,6 +148,9 @@ func TestConversions(t *testing.T) {
if bp, boolTest := ct.d.(*bool); boolTest && *bp != ct.wantbool && ct.wanterr == "" { if bp, boolTest := ct.d.(*bool); boolTest && *bp != ct.wantbool && ct.wanterr == "" {
errf("want bool %v, got %v", ct.wantbool, *bp) errf("want bool %v, got %v", ct.wantbool, *bp)
} }
if !ct.wanttime.IsZero() && !ct.wanttime.Equal(timeValue(ct.d)) {
errf("want time %v, got %v", ct.wanttime, timeValue(ct.d))
}
} }
} }

View File

@ -16,6 +16,7 @@
// nil // nil
// []byte // []byte
// string [*] everywhere except from Rows.Next. // string [*] everywhere except from Rows.Next.
// time.Time
// //
package driver package driver

View File

@ -8,6 +8,7 @@ import (
"fmt" "fmt"
"reflect" "reflect"
"strconv" "strconv"
"time"
) )
// ValueConverter is the interface providing the ConvertValue method. // ValueConverter is the interface providing the ConvertValue method.
@ -39,7 +40,7 @@ type ValueConverter interface {
// 1 is true // 1 is true
// 0 is false, // 0 is false,
// other integers are an error // other integers are an error
// - for strings and []byte, same rules as strconv.Atob // - for strings and []byte, same rules as strconv.ParseBool
// - all other types are an error // - all other types are an error
var Bool boolType var Bool boolType
@ -143,9 +144,10 @@ func (stringType) ConvertValue(v interface{}) (interface{}, error) {
// bool // bool
// nil // nil
// []byte // []byte
// time.Time
// string // string
// //
// This is the ame list as IsScanSubsetType, with the addition of // This is the same list as IsScanSubsetType, with the addition of
// string. // string.
func IsParameterSubsetType(v interface{}) bool { func IsParameterSubsetType(v interface{}) bool {
if IsScanSubsetType(v) { if IsScanSubsetType(v) {
@ -165,6 +167,7 @@ func IsParameterSubsetType(v interface{}) bool {
// bool // bool
// nil // nil
// []byte // []byte
// time.Time
// //
// This is the same list as IsParameterSubsetType, without string. // This is the same list as IsParameterSubsetType, without string.
func IsScanSubsetType(v interface{}) bool { func IsScanSubsetType(v interface{}) bool {
@ -172,7 +175,7 @@ func IsScanSubsetType(v interface{}) bool {
return true return true
} }
switch v.(type) { switch v.(type) {
case int64, float64, []byte, bool: case int64, float64, []byte, bool, time.Time:
return true return true
} }
return false return false
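A minimal sketch of the subset-type predicates with the time.Time addition above, assuming the exp/sql/driver import path used in this tree.

package main

import (
	"fmt"
	"time"

	"exp/sql/driver" // import path in this tree; later database/sql/driver
)

func main() {
	fmt.Println(driver.IsScanSubsetType(time.Now()))     // true: newly part of the scan subset
	fmt.Println(driver.IsScanSubsetType("hello"))        // false: string is parameter-only
	fmt.Println(driver.IsParameterSubsetType("hello"))   // true
	fmt.Println(driver.IsParameterSubsetType(int64(42))) // true
}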

View File

@ -7,6 +7,7 @@ package driver
import ( import (
"reflect" "reflect"
"testing" "testing"
"time"
) )
type valueConverterTest struct { type valueConverterTest struct {
@ -16,6 +17,8 @@ type valueConverterTest struct {
err string err string
} }
var now = time.Now()
var valueConverterTests = []valueConverterTest{ var valueConverterTests = []valueConverterTest{
{Bool, "true", true, ""}, {Bool, "true", true, ""},
{Bool, "True", true, ""}, {Bool, "True", true, ""},
@ -33,6 +36,7 @@ var valueConverterTests = []valueConverterTest{
{Bool, uint16(0), false, ""}, {Bool, uint16(0), false, ""},
{c: Bool, in: "foo", err: "sql/driver: couldn't convert \"foo\" into type bool"}, {c: Bool, in: "foo", err: "sql/driver: couldn't convert \"foo\" into type bool"},
{c: Bool, in: 2, err: "sql/driver: couldn't convert 2 into type bool"}, {c: Bool, in: 2, err: "sql/driver: couldn't convert 2 into type bool"},
{DefaultParameterConverter, now, now, ""},
} }
func TestValueConverters(t *testing.T) { func TestValueConverters(t *testing.T) {

View File

@ -12,6 +12,7 @@ import (
"strconv" "strconv"
"strings" "strings"
"sync" "sync"
"time"
"exp/sql/driver" "exp/sql/driver"
) )
@ -77,6 +78,17 @@ type fakeConn struct {
db *fakeDB // where to return ourselves to db *fakeDB // where to return ourselves to
currTx *fakeTx currTx *fakeTx
// Stats for tests:
mu sync.Mutex
stmtsMade int
stmtsClosed int
}
func (c *fakeConn) incrStat(v *int) {
c.mu.Lock()
*v++
c.mu.Unlock()
} }
type fakeTx struct { type fakeTx struct {
@ -110,25 +122,34 @@ func init() {
// Supports dsn forms: // Supports dsn forms:
// <dbname> // <dbname>
// <dbname>;wipe // <dbname>;<opts> (no currently supported options)
func (d *fakeDriver) Open(dsn string) (driver.Conn, error) { func (d *fakeDriver) Open(dsn string) (driver.Conn, error) {
d.mu.Lock()
defer d.mu.Unlock()
d.openCount++
if d.dbs == nil {
d.dbs = make(map[string]*fakeDB)
}
parts := strings.Split(dsn, ";") parts := strings.Split(dsn, ";")
if len(parts) < 1 { if len(parts) < 1 {
return nil, errors.New("fakedb: no database name") return nil, errors.New("fakedb: no database name")
} }
name := parts[0] name := parts[0]
db := d.getDB(name)
d.mu.Lock()
d.openCount++
d.mu.Unlock()
return &fakeConn{db: db}, nil
}
func (d *fakeDriver) getDB(name string) *fakeDB {
d.mu.Lock()
defer d.mu.Unlock()
if d.dbs == nil {
d.dbs = make(map[string]*fakeDB)
}
db, ok := d.dbs[name] db, ok := d.dbs[name]
if !ok { if !ok {
db = &fakeDB{name: name} db = &fakeDB{name: name}
d.dbs[name] = db d.dbs[name] = db
} }
return &fakeConn{db: db}, nil return db
} }
func (db *fakeDB) wipe() { func (db *fakeDB) wipe() {
@ -200,7 +221,7 @@ func (c *fakeConn) Close() error {
func checkSubsetTypes(args []interface{}) error { func checkSubsetTypes(args []interface{}) error {
for n, arg := range args { for n, arg := range args {
switch arg.(type) { switch arg.(type) {
case int64, float64, bool, nil, []byte, string: case int64, float64, bool, nil, []byte, string, time.Time:
default: default:
return fmt.Errorf("fakedb_test: invalid argument #%d: %v, type %T", n+1, arg, arg) return fmt.Errorf("fakedb_test: invalid argument #%d: %v, type %T", n+1, arg, arg)
} }
@ -297,6 +318,8 @@ func (c *fakeConn) prepareInsert(stmt *fakeStmt, parts []string) (driver.Stmt, e
switch ctype { switch ctype {
case "string": case "string":
subsetVal = []byte(value) subsetVal = []byte(value)
case "blob":
subsetVal = []byte(value)
case "int32": case "int32":
i, err := strconv.Atoi(value) i, err := strconv.Atoi(value)
if err != nil { if err != nil {
@ -327,6 +350,7 @@ func (c *fakeConn) Prepare(query string) (driver.Stmt, error) {
cmd := parts[0] cmd := parts[0]
parts = parts[1:] parts = parts[1:]
stmt := &fakeStmt{q: query, c: c, cmd: cmd} stmt := &fakeStmt{q: query, c: c, cmd: cmd}
c.incrStat(&c.stmtsMade)
switch cmd { switch cmd {
case "WIPE": case "WIPE":
// Nothing // Nothing
@ -347,7 +371,10 @@ func (s *fakeStmt) ColumnConverter(idx int) driver.ValueConverter {
} }
func (s *fakeStmt) Close() error { func (s *fakeStmt) Close() error {
s.closed = true if !s.closed {
s.c.incrStat(&s.c.stmtsClosed)
s.closed = true
}
return nil return nil
} }
@ -501,9 +528,19 @@ type rowsCursor struct {
pos int pos int
rows []*row rows []*row
closed bool closed bool
// a clone of slices to give out to clients, indexed by the
// original slice's first byte address. We clone them
// just so we're able to corrupt them on close.
bytesClone map[*byte][]byte
} }
func (rc *rowsCursor) Close() error { func (rc *rowsCursor) Close() error {
if !rc.closed {
for _, bs := range rc.bytesClone {
bs[0] = 255 // first byte corrupted
}
}
rc.closed = true rc.closed = true
return nil return nil
} }
@ -528,6 +565,19 @@ func (rc *rowsCursor) Next(dest []interface{}) error {
// for ease of drivers, and to prevent drivers from // for ease of drivers, and to prevent drivers from
// messing up conversions or doing them differently. // messing up conversions or doing them differently.
dest[i] = v dest[i] = v
if bs, ok := v.([]byte); ok {
if rc.bytesClone == nil {
rc.bytesClone = make(map[*byte][]byte)
}
clone, ok := rc.bytesClone[&bs[0]]
if !ok {
clone = make([]byte, len(bs))
copy(clone, bs)
rc.bytesClone[&bs[0]] = clone
}
dest[i] = clone
}
} }
return nil return nil
} }
@ -540,6 +590,8 @@ func converterForType(typ string) driver.ValueConverter {
return driver.Int32 return driver.Int32
case "string": case "string":
return driver.String return driver.String
case "datetime":
return driver.DefaultParameterConverter
} }
panic("invalid fakedb column type of " + typ) panic("invalid fakedb column type of " + typ)
} }

View File

@ -243,8 +243,13 @@ func (db *DB) Query(query string, args ...interface{}) (*Rows, error) {
if err != nil { if err != nil {
return nil, err return nil, err
} }
defer stmt.Close() rows, err := stmt.Query(args...)
return stmt.Query(args...) if err != nil {
stmt.Close()
return nil, err
}
rows.closeStmt = stmt
return rows, nil
} }
// QueryRow executes a query that is expected to return at most one row. // QueryRow executes a query that is expected to return at most one row.
@ -549,8 +554,8 @@ func (s *Stmt) Exec(args ...interface{}) (Result, error) {
// statement, a function to call to release the connection, and a // statement, a function to call to release the connection, and a
// statement bound to that connection. // statement bound to that connection.
func (s *Stmt) connStmt() (ci driver.Conn, releaseConn func(), si driver.Stmt, err error) { func (s *Stmt) connStmt() (ci driver.Conn, releaseConn func(), si driver.Stmt, err error) {
if s.stickyErr != nil { if err = s.stickyErr; err != nil {
return nil, nil, nil, s.stickyErr return
} }
s.mu.Lock() s.mu.Lock()
if s.closed { if s.closed {
@ -706,9 +711,10 @@ type Rows struct {
releaseConn func() releaseConn func()
rowsi driver.Rows rowsi driver.Rows
closed bool closed bool
lastcols []interface{} lastcols []interface{}
lasterr error lasterr error
closeStmt *Stmt // if non-nil, statement to Close on close
} }
// Next prepares the next result row for reading with the Scan method. // Next prepares the next result row for reading with the Scan method.
@ -726,6 +732,9 @@ func (rs *Rows) Next() bool {
rs.lastcols = make([]interface{}, len(rs.rowsi.Columns())) rs.lastcols = make([]interface{}, len(rs.rowsi.Columns()))
} }
rs.lasterr = rs.rowsi.Next(rs.lastcols) rs.lasterr = rs.rowsi.Next(rs.lastcols)
if rs.lasterr == io.EOF {
rs.Close()
}
return rs.lasterr == nil return rs.lasterr == nil
} }
@ -786,6 +795,9 @@ func (rs *Rows) Close() error {
rs.closed = true rs.closed = true
err := rs.rowsi.Close() err := rs.rowsi.Close()
rs.releaseConn() rs.releaseConn()
if rs.closeStmt != nil {
rs.closeStmt.Close()
}
return err return err
} }
@ -800,10 +812,6 @@ type Row struct {
// pointed at by dest. If more than one row matches the query, // pointed at by dest. If more than one row matches the query,
// Scan uses the first row and discards the rest. If no row matches // Scan uses the first row and discards the rest. If no row matches
// the query, Scan returns ErrNoRows. // the query, Scan returns ErrNoRows.
//
// If dest contains pointers to []byte, the slices should not be
// modified and should only be considered valid until the next call to
// Next or Scan.
func (r *Row) Scan(dest ...interface{}) error { func (r *Row) Scan(dest ...interface{}) error {
if r.err != nil { if r.err != nil {
return r.err return r.err
@ -812,7 +820,33 @@ func (r *Row) Scan(dest ...interface{}) error {
if !r.rows.Next() { if !r.rows.Next() {
return ErrNoRows return ErrNoRows
} }
return r.rows.Scan(dest...) err := r.rows.Scan(dest...)
if err != nil {
return err
}
// TODO(bradfitz): for now we need to defensively clone all
// []byte that the driver returned, since we're about to close
// the Rows in our defer, when we return from this function.
// the contract with the driver.Next(...) interface is that it
// can return slices into read-only temporary memory that's
// only valid until the next Scan/Close. But the TODO is that
// for a lot of drivers, this copy will be unnecessary. We
// should provide an optional interface for drivers to
// implement to say, "don't worry, the []bytes that I return
// from Next will not be modified again." (for instance, if
// they were obtained from the network anyway) But for now we
// don't care.
for _, dp := range dest {
b, ok := dp.(*[]byte)
if !ok {
continue
}
clone := make([]byte, len(*b))
copy(clone, *b)
*b = clone
}
return nil
} }
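And a sketch of the QueryRow behaviour changed above: bytes scanned into a *[]byte are now a defensive copy, so they remain valid after the row's connection and statement are released. Table and column names are illustrative.

package main

import (
	"exp/sql" // import path in this tree; later database/sql
)

// loadPhoto returns a []byte that the caller owns outright thanks to the
// clone performed in Row.Scan above.
func loadPhoto(db *sql.DB, name string) ([]byte, error) {
	var photo []byte
	err := db.QueryRow("SELECT photo FROM people WHERE name = ?", name).Scan(&photo)
	if err != nil {
		return nil, err
	}
	return photo, nil
}

func main() {}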
// A Result summarizes an executed SQL command. // A Result summarizes an executed SQL command.

View File

@ -8,10 +8,15 @@ import (
"reflect" "reflect"
"strings" "strings"
"testing" "testing"
"time"
) )
const fakeDBName = "foo"
var chrisBirthday = time.Unix(123456789, 0)
func newTestDB(t *testing.T, name string) *DB { func newTestDB(t *testing.T, name string) *DB {
db, err := Open("test", "foo") db, err := Open("test", fakeDBName)
if err != nil { if err != nil {
t.Fatalf("Open: %v", err) t.Fatalf("Open: %v", err)
} }
@ -19,10 +24,10 @@ func newTestDB(t *testing.T, name string) *DB {
t.Fatalf("exec wipe: %v", err) t.Fatalf("exec wipe: %v", err)
} }
if name == "people" { if name == "people" {
exec(t, db, "CREATE|people|name=string,age=int32,dead=bool") exec(t, db, "CREATE|people|name=string,age=int32,photo=blob,dead=bool,bdate=datetime")
exec(t, db, "INSERT|people|name=Alice,age=?", 1) exec(t, db, "INSERT|people|name=Alice,age=?,photo=APHOTO", 1)
exec(t, db, "INSERT|people|name=Bob,age=?", 2) exec(t, db, "INSERT|people|name=Bob,age=?,photo=BPHOTO", 2)
exec(t, db, "INSERT|people|name=Chris,age=?", 3) exec(t, db, "INSERT|people|name=Chris,age=?,photo=CPHOTO,bdate=?", 3, chrisBirthday)
} }
return db return db
} }
@ -73,6 +78,12 @@ func TestQuery(t *testing.T) {
if !reflect.DeepEqual(got, want) { if !reflect.DeepEqual(got, want) {
t.Logf(" got: %#v\nwant: %#v", got, want) t.Logf(" got: %#v\nwant: %#v", got, want)
} }
// And verify that the final rows.Next() call, which hit EOF,
// also closed the rows connection.
if n := len(db.freeConn); n != 1 {
t.Errorf("free conns after query hitting EOF = %d; want 1", n)
}
} }
func TestRowsColumns(t *testing.T) { func TestRowsColumns(t *testing.T) {
@ -97,12 +108,18 @@ func TestQueryRow(t *testing.T) {
defer closeDB(t, db) defer closeDB(t, db)
var name string var name string
var age int var age int
var birthday time.Time
err := db.QueryRow("SELECT|people|age,name|age=?", 3).Scan(&age) err := db.QueryRow("SELECT|people|age,name|age=?", 3).Scan(&age)
if err == nil || !strings.Contains(err.Error(), "expected 2 destination arguments") { if err == nil || !strings.Contains(err.Error(), "expected 2 destination arguments") {
t.Errorf("expected error from wrong number of arguments; actually got: %v", err) t.Errorf("expected error from wrong number of arguments; actually got: %v", err)
} }
err = db.QueryRow("SELECT|people|bdate|age=?", 3).Scan(&birthday)
if err != nil || !birthday.Equal(chrisBirthday) {
t.Errorf("chris birthday = %v, err = %v; want %v", birthday, err, chrisBirthday)
}
err = db.QueryRow("SELECT|people|age,name|age=?", 2).Scan(&age, &name) err = db.QueryRow("SELECT|people|age,name|age=?", 2).Scan(&age, &name)
if err != nil { if err != nil {
t.Fatalf("age QueryRow+Scan: %v", err) t.Fatalf("age QueryRow+Scan: %v", err)
@ -124,6 +141,16 @@ func TestQueryRow(t *testing.T) {
if age != 1 { if age != 1 {
t.Errorf("expected age 1, got %d", age) t.Errorf("expected age 1, got %d", age)
} }
var photo []byte
err = db.QueryRow("SELECT|people|photo|name=?", "Alice").Scan(&photo)
if err != nil {
t.Fatalf("photo QueryRow+Scan: %v", err)
}
want := []byte("APHOTO")
if !reflect.DeepEqual(photo, want) {
t.Errorf("photo = %q; want %q", photo, want)
}
} }
func TestStatementErrorAfterClose(t *testing.T) { func TestStatementErrorAfterClose(t *testing.T) {
@ -258,3 +285,21 @@ func TestIssue2542Deadlock(t *testing.T) {
} }
} }
} }
func TestQueryRowClosingStmt(t *testing.T) {
db := newTestDB(t, "people")
defer closeDB(t, db)
var name string
var age int
err := db.QueryRow("SELECT|people|age,name|age=?", 3).Scan(&age, &name)
if err != nil {
t.Fatal(err)
}
if len(db.freeConn) != 1 {
t.Fatalf("expected 1 free conn")
}
fakeConn := db.freeConn[0].(*fakeConn)
if made, closed := fakeConn.stmtsMade, fakeConn.stmtsClosed; made != closed {
t.Logf("statement close mismatch: made %d, closed %d", made, closed)
}
}

View File

@ -420,27 +420,37 @@ type chanWriter struct {
} }
// Write writes data to the remote process's standard input. // Write writes data to the remote process's standard input.
func (w *chanWriter) Write(data []byte) (n int, err error) { func (w *chanWriter) Write(data []byte) (written int, err error) {
for { for len(data) > 0 {
if w.rwin == 0 { for w.rwin < 1 {
win, ok := <-w.win win, ok := <-w.win
if !ok { if !ok {
return 0, io.EOF return 0, io.EOF
} }
w.rwin += win w.rwin += win
continue
} }
n := min(len(data), w.rwin)
peersId := w.clientChan.peersId peersId := w.clientChan.peersId
n = len(data) packet := []byte{
packet := make([]byte, 0, 9+n) msgChannelData,
packet = append(packet, msgChannelData, byte(peersId >> 24), byte(peersId >> 16), byte(peersId >> 8), byte(peersId),
byte(peersId>>24), byte(peersId>>16), byte(peersId>>8), byte(peersId), byte(n >> 24), byte(n >> 16), byte(n >> 8), byte(n),
byte(n>>24), byte(n>>16), byte(n>>8), byte(n)) }
err = w.clientChan.writePacket(append(packet, data...)) if err = w.clientChan.writePacket(append(packet, data[:n]...)); err != nil {
break
}
data = data[n:]
w.rwin -= n w.rwin -= n
return written += n
} }
panic("unreachable") return
}
func min(a, b int) int {
if a < b {
return a
}
return b
} }
func (w *chanWriter) Close() error { func (w *chanWriter) Close() error {

View File

@ -14,7 +14,7 @@ others.
An SSH server is represented by a ServerConfig, which holds certificate An SSH server is represented by a ServerConfig, which holds certificate
details and handles authentication of ServerConns. details and handles authentication of ServerConns.
config := new(ServerConfig) config := new(ssh.ServerConfig)
config.PubKeyCallback = pubKeyAuth config.PubKeyCallback = pubKeyAuth
config.PasswordCallback = passwordAuth config.PasswordCallback = passwordAuth
@ -34,8 +34,7 @@ Once a ServerConfig has been configured, connections can be accepted.
if err != nil { if err != nil {
panic("failed to accept incoming connection") panic("failed to accept incoming connection")
} }
err = sConn.Handshake(conn) if err := sConn.Handshake(conn); err != nil {
if err != nil {
panic("failed to handshake") panic("failed to handshake")
} }
@ -60,16 +59,20 @@ the case of a shell, the type is "session" and ServerShell may be used to
present a simple terminal interface. present a simple terminal interface.
if channel.ChannelType() != "session" { if channel.ChannelType() != "session" {
c.Reject(UnknownChannelType, "unknown channel type") channel.Reject(UnknownChannelType, "unknown channel type")
return return
} }
channel.Accept() channel.Accept()
shell := NewServerShell(channel, "> ") term := terminal.NewTerminal(channel, "> ")
serverTerm := &ssh.ServerTerminal{
Term: term,
Channel: channel,
}
go func() { go func() {
defer channel.Close() defer channel.Close()
for { for {
line, err := shell.ReadLine() line, err := serverTerm.ReadLine()
if err != nil { if err != nil {
break break
} }
@ -78,8 +81,27 @@ present a simple terminal interface.
return return
}() }()
To authenticate with the remote server you must pass at least one implementation of
ClientAuth via the Auth field in ClientConfig.
// password implements the ClientPassword interface
type password string
func (p password) Password(user string) (string, error) {
return string(p), nil
}
config := &ssh.ClientConfig {
User: "username",
Auth: []ClientAuth {
// ClientAuthPassword wraps a ClientPassword implementation
// in a type that implements ClientAuth.
ClientAuthPassword(password("yourpassword")),
}
}
An SSH client is represented with a ClientConn. Currently only the "password" An SSH client is represented with a ClientConn. Currently only the "password"
authentication method is supported. authentication method is supported.
config := &ClientConfig{ config := &ClientConfig{
User: "username", User: "username",
@ -87,19 +109,19 @@ authentication method is supported.
} }
client, err := Dial("yourserver.com:22", config) client, err := Dial("yourserver.com:22", config)
Each ClientConn can support multiple interactive sessions, represented by a Session. Each ClientConn can support multiple interactive sessions, represented by a Session.
session, err := client.NewSession() session, err := client.NewSession()
Once a Session is created, you can execute a single command on the remote side Once a Session is created, you can execute a single command on the remote side
using the Run method. using the Exec method.
b := bytes.NewBuffer()
session.Stdin = b
if err := session.Run("/usr/bin/whoami"); err != nil { if err := session.Run("/usr/bin/whoami"); err != nil {
panic("Failed to exec: " + err.String()) panic("Failed to exec: " + err.String())
} }
reader := bufio.NewReader(session.Stdin) fmt.Println(bytes.String())
line, _, _ := reader.ReadLine()
fmt.Println(line)
session.Close() session.Close()
*/ */
package ssh package ssh

View File

@ -1,398 +0,0 @@
// Copyright 2011 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package ssh
import "io"
// ServerShell contains the state for running a VT100 terminal that is capable
// of reading lines of input.
type ServerShell struct {
c Channel
prompt string
// line is the current line being entered.
line []byte
// pos is the logical position of the cursor in line
pos int
// cursorX contains the current X value of the cursor where the left
// edge is 0. cursorY contains the row number where the first row of
// the current line is 0.
cursorX, cursorY int
// maxLine is the greatest value of cursorY so far.
maxLine int
termWidth, termHeight int
// outBuf contains the terminal data to be sent.
outBuf []byte
// remainder contains the remainder of any partial key sequences after
// a read. It aliases into inBuf.
remainder []byte
inBuf [256]byte
}
// NewServerShell runs a VT100 terminal on the given channel. prompt is a
// string that is written at the start of each input line. For example: "> ".
func NewServerShell(c Channel, prompt string) *ServerShell {
return &ServerShell{
c: c,
prompt: prompt,
termWidth: 80,
termHeight: 24,
}
}
const (
keyCtrlD = 4
keyEnter = '\r'
keyEscape = 27
keyBackspace = 127
keyUnknown = 256 + iota
keyUp
keyDown
keyLeft
keyRight
keyAltLeft
keyAltRight
)
// bytesToKey tries to parse a key sequence from b. If successful, it returns
// the key and the remainder of the input. Otherwise it returns -1.
func bytesToKey(b []byte) (int, []byte) {
if len(b) == 0 {
return -1, nil
}
if b[0] != keyEscape {
return int(b[0]), b[1:]
}
if len(b) >= 3 && b[0] == keyEscape && b[1] == '[' {
switch b[2] {
case 'A':
return keyUp, b[3:]
case 'B':
return keyDown, b[3:]
case 'C':
return keyRight, b[3:]
case 'D':
return keyLeft, b[3:]
}
}
if len(b) >= 6 && b[0] == keyEscape && b[1] == '[' && b[2] == '1' && b[3] == ';' && b[4] == '3' {
switch b[5] {
case 'C':
return keyAltRight, b[6:]
case 'D':
return keyAltLeft, b[6:]
}
}
// If we get here then we have a key that we don't recognise, or a
// partial sequence. It's not clear how one should find the end of a
// sequence without knowing them all, but it seems that [a-zA-Z] only
// appears at the end of a sequence.
for i, c := range b[0:] {
if c >= 'a' && c <= 'z' || c >= 'A' && c <= 'Z' {
return keyUnknown, b[i+1:]
}
}
return -1, b
}
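A hand-traced example of the parsing above, using the key constants and CSI sequences handled by the switch:

	// bytesToKey([]byte{'a', 27, '[', 'D'}) returns 'a' with {27, '[', 'D'} left over;
	// bytesToKey([]byte{27, '[', 'D'}) then returns keyLeft with nothing left over.
	// A bare, incomplete sequence such as {27, '['} returns -1 so the caller can
	// wait for more input.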
// queue appends data to the end of ss.outBuf
func (ss *ServerShell) queue(data []byte) {
if len(ss.outBuf)+len(data) > cap(ss.outBuf) {
newOutBuf := make([]byte, len(ss.outBuf), 2*(len(ss.outBuf)+len(data)))
copy(newOutBuf, ss.outBuf)
ss.outBuf = newOutBuf
}
oldLen := len(ss.outBuf)
ss.outBuf = ss.outBuf[:len(ss.outBuf)+len(data)]
copy(ss.outBuf[oldLen:], data)
}
var eraseUnderCursor = []byte{' ', keyEscape, '[', 'D'}
func isPrintable(key int) bool {
return key >= 32 && key < 127
}
// moveCursorToPos appends data to ss.outBuf which will move the cursor to the
// given, logical position in the text.
func (ss *ServerShell) moveCursorToPos(pos int) {
x := len(ss.prompt) + pos
y := x / ss.termWidth
x = x % ss.termWidth
up := 0
if y < ss.cursorY {
up = ss.cursorY - y
}
down := 0
if y > ss.cursorY {
down = y - ss.cursorY
}
left := 0
if x < ss.cursorX {
left = ss.cursorX - x
}
right := 0
if x > ss.cursorX {
right = x - ss.cursorX
}
movement := make([]byte, 3*(up+down+left+right))
m := movement
for i := 0; i < up; i++ {
m[0] = keyEscape
m[1] = '['
m[2] = 'A'
m = m[3:]
}
for i := 0; i < down; i++ {
m[0] = keyEscape
m[1] = '['
m[2] = 'B'
m = m[3:]
}
for i := 0; i < left; i++ {
m[0] = keyEscape
m[1] = '['
m[2] = 'D'
m = m[3:]
}
for i := 0; i < right; i++ {
m[0] = keyEscape
m[1] = '['
m[2] = 'C'
m = m[3:]
}
ss.cursorX = x
ss.cursorY = y
ss.queue(movement)
}
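A worked example of the coordinate arithmetic above, under the defaults set by NewServerShell (termWidth 80); the numbers are illustrative:

	// With prompt "> " (length 2) and logical position 100:
	//   x = (2 + 100) % 80 = 22,   y = (2 + 100) / 80 = 1
	// so the cursor lands in column 22 of the second row of the current line,
	// reached by queueing the CSI movement sequences (ESC '[' 'A'/'B'/'C'/'D')
	// relative to the current cursorX/cursorY.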
const maxLineLength = 4096
// handleKey processes the given key and, optionally, returns a line of text
// that the user has entered.
func (ss *ServerShell) handleKey(key int) (line string, ok bool) {
switch key {
case keyBackspace:
if ss.pos == 0 {
return
}
ss.pos--
copy(ss.line[ss.pos:], ss.line[1+ss.pos:])
ss.line = ss.line[:len(ss.line)-1]
ss.writeLine(ss.line[ss.pos:])
ss.moveCursorToPos(ss.pos)
ss.queue(eraseUnderCursor)
case keyAltLeft:
// move left by a word.
if ss.pos == 0 {
return
}
ss.pos--
for ss.pos > 0 {
if ss.line[ss.pos] != ' ' {
break
}
ss.pos--
}
for ss.pos > 0 {
if ss.line[ss.pos] == ' ' {
ss.pos++
break
}
ss.pos--
}
ss.moveCursorToPos(ss.pos)
case keyAltRight:
// move right by a word.
for ss.pos < len(ss.line) {
if ss.line[ss.pos] == ' ' {
break
}
ss.pos++
}
for ss.pos < len(ss.line) {
if ss.line[ss.pos] != ' ' {
break
}
ss.pos++
}
ss.moveCursorToPos(ss.pos)
case keyLeft:
if ss.pos == 0 {
return
}
ss.pos--
ss.moveCursorToPos(ss.pos)
case keyRight:
if ss.pos == len(ss.line) {
return
}
ss.pos++
ss.moveCursorToPos(ss.pos)
case keyEnter:
ss.moveCursorToPos(len(ss.line))
ss.queue([]byte("\r\n"))
line = string(ss.line)
ok = true
ss.line = ss.line[:0]
ss.pos = 0
ss.cursorX = 0
ss.cursorY = 0
ss.maxLine = 0
default:
if !isPrintable(key) {
return
}
if len(ss.line) == maxLineLength {
return
}
if len(ss.line) == cap(ss.line) {
newLine := make([]byte, len(ss.line), 2*(1+len(ss.line)))
copy(newLine, ss.line)
ss.line = newLine
}
ss.line = ss.line[:len(ss.line)+1]
copy(ss.line[ss.pos+1:], ss.line[ss.pos:])
ss.line[ss.pos] = byte(key)
ss.writeLine(ss.line[ss.pos:])
ss.pos++
ss.moveCursorToPos(ss.pos)
}
return
}
func (ss *ServerShell) writeLine(line []byte) {
for len(line) != 0 {
if ss.cursorX == ss.termWidth {
ss.queue([]byte("\r\n"))
ss.cursorX = 0
ss.cursorY++
if ss.cursorY > ss.maxLine {
ss.maxLine = ss.cursorY
}
}
remainingOnLine := ss.termWidth - ss.cursorX
todo := len(line)
if todo > remainingOnLine {
todo = remainingOnLine
}
ss.queue(line[:todo])
ss.cursorX += todo
line = line[todo:]
}
}
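For example, the wrapping logic above behaves like this (a hand-traced case, not taken from the original sources):

	// With termWidth 80 and cursorX 78, writeLine([]byte("hello")) queues
	// "he", then "\r\n" to wrap onto the next row, then "llo", leaving
	// cursorX == 3 and cursorY one larger than before.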
// parsePtyRequest parses the payload of the pty-req message and extracts the
// dimensions of the terminal. See RFC 4254, section 6.2.
func parsePtyRequest(s []byte) (width, height int, ok bool) {
_, s, ok = parseString(s)
if !ok {
return
}
width32, s, ok := parseUint32(s)
if !ok {
return
}
height32, _, ok := parseUint32(s)
width = int(width32)
height = int(height32)
if width < 1 {
ok = false
}
if height < 1 {
ok = false
}
return
}
func (ss *ServerShell) Write(buf []byte) (n int, err error) {
return ss.c.Write(buf)
}
// ReadLine returns a line of input from the terminal.
func (ss *ServerShell) ReadLine() (line string, err error) {
ss.writeLine([]byte(ss.prompt))
ss.c.Write(ss.outBuf)
ss.outBuf = ss.outBuf[:0]
for {
// ss.remainder is a slice at the beginning of ss.inBuf
// containing a partial key sequence
readBuf := ss.inBuf[len(ss.remainder):]
var n int
n, err = ss.c.Read(readBuf)
if err == nil {
ss.remainder = ss.inBuf[:n+len(ss.remainder)]
rest := ss.remainder
lineOk := false
for !lineOk {
var key int
key, rest = bytesToKey(rest)
if key < 0 {
break
}
if key == keyCtrlD {
return "", io.EOF
}
line, lineOk = ss.handleKey(key)
}
if len(rest) > 0 {
n := copy(ss.inBuf[:], rest)
ss.remainder = ss.inBuf[:n]
} else {
ss.remainder = nil
}
ss.c.Write(ss.outBuf)
ss.outBuf = ss.outBuf[:0]
if lineOk {
return
}
continue
}
if req, ok := err.(ChannelRequest); ok {
ok := false
switch req.Request {
case "pty-req":
ss.termWidth, ss.termHeight, ok = parsePtyRequest(req.Payload)
if !ok {
ss.termWidth = 80
ss.termHeight = 24
}
case "shell":
ok = true
if len(req.Payload) > 0 {
// We don't accept any commands, only the default shell.
ok = false
}
case "env":
ok = true
}
if req.WantReply {
ss.c.AckRequest(ok)
}
} else {
return "", err
}
}
panic("unreachable")
}

View File

@ -1,134 +0,0 @@
// Copyright 2011 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package ssh
import (
"io"
"testing"
)
type MockChannel struct {
toSend []byte
bytesPerRead int
received []byte
}
func (c *MockChannel) Accept() error {
return nil
}
func (c *MockChannel) Reject(RejectionReason, string) error {
return nil
}
func (c *MockChannel) Read(data []byte) (n int, err error) {
n = len(data)
if n == 0 {
return
}
if n > len(c.toSend) {
n = len(c.toSend)
}
if n == 0 {
return 0, io.EOF
}
if c.bytesPerRead > 0 && n > c.bytesPerRead {
n = c.bytesPerRead
}
copy(data, c.toSend[:n])
c.toSend = c.toSend[n:]
return
}
func (c *MockChannel) Write(data []byte) (n int, err error) {
c.received = append(c.received, data...)
return len(data), nil
}
func (c *MockChannel) Close() error {
return nil
}
func (c *MockChannel) AckRequest(ok bool) error {
return nil
}
func (c *MockChannel) ChannelType() string {
return ""
}
func (c *MockChannel) ExtraData() []byte {
return nil
}
func TestClose(t *testing.T) {
c := &MockChannel{}
ss := NewServerShell(c, "> ")
line, err := ss.ReadLine()
if line != "" {
t.Errorf("Expected empty line but got: %s", line)
}
if err != io.EOF {
t.Errorf("Error should have been EOF but got: %s", err)
}
}
var keyPressTests = []struct {
in string
line string
err error
}{
{
"",
"",
io.EOF,
},
{
"\r",
"",
nil,
},
{
"foo\r",
"foo",
nil,
},
{
"a\x1b[Cb\r", // right
"ab",
nil,
},
{
"a\x1b[Db\r", // left
"ba",
nil,
},
{
"a\177b\r", // backspace
"b",
nil,
},
}
func TestKeyPresses(t *testing.T) {
for i, test := range keyPressTests {
for j := 0; j < len(test.in); j++ {
c := &MockChannel{
toSend: []byte(test.in),
bytesPerRead: j,
}
ss := NewServerShell(c, "> ")
line, err := ss.ReadLine()
if line != test.line {
t.Errorf("Line resulting from test %d (%d bytes per read) was '%s', expected '%s'", i, j, line, test.line)
break
}
if err != test.err {
t.Errorf("Error resulting from test %d (%d bytes per read) was '%v', expected '%v'", i, j, err, test.err)
break
}
}
}
}

View File

@ -0,0 +1,81 @@
// Copyright 2011 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package ssh
// A Terminal is capable of parsing and generating virtual terminal
// data from an SSH client.
type Terminal interface {
ReadLine() (line string, err error)
SetSize(x, y int)
Write([]byte) (int, error)
}
// ServerTerminal contains the state for running a terminal that is capable of
// reading lines of input.
type ServerTerminal struct {
Term Terminal
Channel Channel
}
// parsePtyRequest parses the payload of the pty-req message and extracts the
// dimensions of the terminal. See RFC 4254, section 6.2.
func parsePtyRequest(s []byte) (width, height int, ok bool) {
_, s, ok = parseString(s)
if !ok {
return
}
width32, s, ok := parseUint32(s)
if !ok {
return
}
height32, _, ok := parseUint32(s)
width = int(width32)
height = int(height32)
if width < 1 {
ok = false
}
if height < 1 {
ok = false
}
return
}
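For reference, RFC 4254 section 6.2 lays out the pty-req payload that this function consumes; only the two character-cell dimensions are extracted here.

	//     string    TERM environment variable value (e.g. "vt100")
	//     uint32    terminal width, characters
	//     uint32    terminal height, rows
	//     uint32    terminal width, pixels
	//     uint32    terminal height, pixels
	//     string    encoded terminal modes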
func (ss *ServerTerminal) Write(buf []byte) (n int, err error) {
return ss.Term.Write(buf)
}
// ReadLine returns a line of input from the terminal.
func (ss *ServerTerminal) ReadLine() (line string, err error) {
for {
if line, err = ss.Term.ReadLine(); err == nil {
return
}
req, ok := err.(ChannelRequest)
if !ok {
return
}
ok = false
switch req.Request {
case "pty-req":
var width, height int
width, height, ok = parsePtyRequest(req.Payload)
ss.Term.SetSize(width, height)
case "shell":
ok = true
if len(req.Payload) > 0 {
// We don't accept any commands, only the default shell.
ok = false
}
case "env":
ok = true
}
if req.WantReply {
ss.Channel.AckRequest(ok)
}
}
panic("unreachable")
}
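A minimal server-side sketch of the new type, mirroring the newServerShell helper added to session_test.go below; channel is assumed to be an accepted ssh Channel and the prompt is illustrative.

	term := terminal.NewTerminal(channel, "> ")
	st := &ServerTerminal{
		Term:    term,
		Channel: channel,
	}
	for {
		line, err := st.ReadLine() // pty-req, shell and env requests are handled internally
		if err != nil {
			break
		}
		_ = line // interpret the command here
	}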

View File

@ -8,6 +8,7 @@ package ssh
import ( import (
"bytes" "bytes"
"exp/terminal"
"io" "io"
"testing" "testing"
) )
@ -290,24 +291,32 @@ type exitSignalMsg struct {
Lang string Lang string
} }
func newServerShell(ch *channel, prompt string) *ServerTerminal {
term := terminal.NewTerminal(ch, prompt)
return &ServerTerminal{
Term: term,
Channel: ch,
}
}
func exitStatusZeroHandler(ch *channel) { func exitStatusZeroHandler(ch *channel) {
defer ch.Close() defer ch.Close()
// this string is returned to stdout // this string is returned to stdout
shell := NewServerShell(ch, "> ") shell := newServerShell(ch, "> ")
shell.ReadLine() shell.ReadLine()
sendStatus(0, ch) sendStatus(0, ch)
} }
func exitStatusNonZeroHandler(ch *channel) { func exitStatusNonZeroHandler(ch *channel) {
defer ch.Close() defer ch.Close()
shell := NewServerShell(ch, "> ") shell := newServerShell(ch, "> ")
shell.ReadLine() shell.ReadLine()
sendStatus(15, ch) sendStatus(15, ch)
} }
func exitSignalAndStatusHandler(ch *channel) { func exitSignalAndStatusHandler(ch *channel) {
defer ch.Close() defer ch.Close()
shell := NewServerShell(ch, "> ") shell := newServerShell(ch, "> ")
shell.ReadLine() shell.ReadLine()
sendStatus(15, ch) sendStatus(15, ch)
sendSignal("TERM", ch) sendSignal("TERM", ch)
@ -315,28 +324,28 @@ func exitSignalAndStatusHandler(ch *channel) {
func exitSignalHandler(ch *channel) { func exitSignalHandler(ch *channel) {
defer ch.Close() defer ch.Close()
shell := NewServerShell(ch, "> ") shell := newServerShell(ch, "> ")
shell.ReadLine() shell.ReadLine()
sendSignal("TERM", ch) sendSignal("TERM", ch)
} }
func exitSignalUnknownHandler(ch *channel) { func exitSignalUnknownHandler(ch *channel) {
defer ch.Close() defer ch.Close()
shell := NewServerShell(ch, "> ") shell := newServerShell(ch, "> ")
shell.ReadLine() shell.ReadLine()
sendSignal("SYS", ch) sendSignal("SYS", ch)
} }
func exitWithoutSignalOrStatus(ch *channel) { func exitWithoutSignalOrStatus(ch *channel) {
defer ch.Close() defer ch.Close()
shell := NewServerShell(ch, "> ") shell := newServerShell(ch, "> ")
shell.ReadLine() shell.ReadLine()
} }
func shellHandler(ch *channel) { func shellHandler(ch *channel) {
defer ch.Close() defer ch.Close()
// this string is returned to stdout // this string is returned to stdout
shell := NewServerShell(ch, "golang") shell := newServerShell(ch, "golang")
shell.ReadLine() shell.ReadLine()
sendStatus(0, ch) sendStatus(0, ch)
} }

View File

@ -117,9 +117,7 @@ func (r *reader) readOnePacket() ([]byte, error) {
return nil, err return nil, err
} }
mac := packet[length-1:] mac := packet[length-1:]
if r.cipher != nil { r.cipher.XORKeyStream(packet, packet[:length-1])
r.cipher.XORKeyStream(packet, packet[:length-1])
}
if r.mac != nil { if r.mac != nil {
r.mac.Write(packet[:length-1]) r.mac.Write(packet[:length-1])

View File

@ -2,8 +2,6 @@
// Use of this source code is governed by a BSD-style // Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file. // license that can be found in the LICENSE file.
// +build linux
package terminal package terminal
import ( import (
@ -463,6 +461,31 @@ func (t *Terminal) readLine() (line string, err error) {
} }
for { for {
rest := t.remainder
lineOk := false
for !lineOk {
var key int
key, rest = bytesToKey(rest)
if key < 0 {
break
}
if key == keyCtrlD {
return "", io.EOF
}
line, lineOk = t.handleKey(key)
}
if len(rest) > 0 {
n := copy(t.inBuf[:], rest)
t.remainder = t.inBuf[:n]
} else {
t.remainder = nil
}
t.c.Write(t.outBuf)
t.outBuf = t.outBuf[:0]
if lineOk {
return
}
// t.remainder is a slice at the beginning of t.inBuf // t.remainder is a slice at the beginning of t.inBuf
// containing a partial key sequence // containing a partial key sequence
readBuf := t.inBuf[len(t.remainder):] readBuf := t.inBuf[len(t.remainder):]
@ -476,38 +499,19 @@ func (t *Terminal) readLine() (line string, err error) {
return return
} }
if err == nil { t.remainder = t.inBuf[:n+len(t.remainder)]
t.remainder = t.inBuf[:n+len(t.remainder)]
rest := t.remainder
lineOk := false
for !lineOk {
var key int
key, rest = bytesToKey(rest)
if key < 0 {
break
}
if key == keyCtrlD {
return "", io.EOF
}
line, lineOk = t.handleKey(key)
}
if len(rest) > 0 {
n := copy(t.inBuf[:], rest)
t.remainder = t.inBuf[:n]
} else {
t.remainder = nil
}
t.c.Write(t.outBuf)
t.outBuf = t.outBuf[:0]
if lineOk {
return
}
continue
}
} }
panic("unreachable") panic("unreachable")
} }
// SetPrompt sets the prompt to be used when reading subsequent lines.
func (t *Terminal) SetPrompt(prompt string) {
t.lock.Lock()
defer t.lock.Unlock()
t.prompt = prompt
}
func (t *Terminal) SetSize(width, height int) { func (t *Terminal) SetSize(width, height int) {
t.lock.Lock() t.lock.Lock()
defer t.lock.Unlock() defer t.lock.Unlock()

View File

@ -2,8 +2,6 @@
// Use of this source code is governed by a BSD-style // Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file. // license that can be found in the LICENSE file.
// +build linux
package terminal package terminal
import ( import (

View File

@ -111,7 +111,7 @@ func expectedErrors(t *testing.T, testname string, files map[string]*ast.File) m
// set otherwise the position information returned here will // set otherwise the position information returned here will
// not match the position information collected by the parser // not match the position information collected by the parser
s.Init(getFile(filename), src, nil, scanner.ScanComments) s.Init(getFile(filename), src, nil, scanner.ScanComments)
var prev token.Pos // position of last non-comment token var prev token.Pos // position of last non-comment, non-semicolon token
scanFile: scanFile:
for { for {
@ -124,6 +124,12 @@ func expectedErrors(t *testing.T, testname string, files map[string]*ast.File) m
if len(s) == 2 { if len(s) == 2 {
errors[prev] = string(s[1]) errors[prev] = string(s[1])
} }
case token.SEMICOLON:
// ignore automatically inserted semicolon
if lit == "\n" {
break
}
fallthrough
default: default:
prev = pos prev = pos
} }

View File

@ -20,6 +20,7 @@ func define(kind ast.ObjKind, name string) *ast.Object {
if scope.Insert(obj) != nil { if scope.Insert(obj) != nil {
panic("types internal error: double declaration") panic("types internal error: double declaration")
} }
obj.Decl = scope
return obj return obj
} }

View File

@ -65,12 +65,13 @@ import (
"os" "os"
"sort" "sort"
"strconv" "strconv"
"time"
) )
// ErrHelp is the error returned if the flag -help is invoked but no such flag is defined. // ErrHelp is the error returned if the flag -help is invoked but no such flag is defined.
var ErrHelp = errors.New("flag: help requested") var ErrHelp = errors.New("flag: help requested")
// -- Bool Value // -- bool Value
type boolValue bool type boolValue bool
func newBoolValue(val bool, p *bool) *boolValue { func newBoolValue(val bool, p *bool) *boolValue {
@ -78,15 +79,15 @@ func newBoolValue(val bool, p *bool) *boolValue {
return (*boolValue)(p) return (*boolValue)(p)
} }
func (b *boolValue) Set(s string) bool { func (b *boolValue) Set(s string) error {
v, err := strconv.ParseBool(s) v, err := strconv.ParseBool(s)
*b = boolValue(v) *b = boolValue(v)
return err == nil return err
} }
func (b *boolValue) String() string { return fmt.Sprintf("%v", *b) } func (b *boolValue) String() string { return fmt.Sprintf("%v", *b) }
// -- Int Value // -- int Value
type intValue int type intValue int
func newIntValue(val int, p *int) *intValue { func newIntValue(val int, p *int) *intValue {
@ -94,15 +95,15 @@ func newIntValue(val int, p *int) *intValue {
return (*intValue)(p) return (*intValue)(p)
} }
func (i *intValue) Set(s string) bool { func (i *intValue) Set(s string) error {
v, err := strconv.ParseInt(s, 0, 64) v, err := strconv.ParseInt(s, 0, 64)
*i = intValue(v) *i = intValue(v)
return err == nil return err
} }
func (i *intValue) String() string { return fmt.Sprintf("%v", *i) } func (i *intValue) String() string { return fmt.Sprintf("%v", *i) }
// -- Int64 Value // -- int64 Value
type int64Value int64 type int64Value int64
func newInt64Value(val int64, p *int64) *int64Value { func newInt64Value(val int64, p *int64) *int64Value {
@ -110,15 +111,15 @@ func newInt64Value(val int64, p *int64) *int64Value {
return (*int64Value)(p) return (*int64Value)(p)
} }
func (i *int64Value) Set(s string) bool { func (i *int64Value) Set(s string) error {
v, err := strconv.ParseInt(s, 0, 64) v, err := strconv.ParseInt(s, 0, 64)
*i = int64Value(v) *i = int64Value(v)
return err == nil return err
} }
func (i *int64Value) String() string { return fmt.Sprintf("%v", *i) } func (i *int64Value) String() string { return fmt.Sprintf("%v", *i) }
// -- Uint Value // -- uint Value
type uintValue uint type uintValue uint
func newUintValue(val uint, p *uint) *uintValue { func newUintValue(val uint, p *uint) *uintValue {
@ -126,10 +127,10 @@ func newUintValue(val uint, p *uint) *uintValue {
return (*uintValue)(p) return (*uintValue)(p)
} }
func (i *uintValue) Set(s string) bool { func (i *uintValue) Set(s string) error {
v, err := strconv.ParseUint(s, 0, 64) v, err := strconv.ParseUint(s, 0, 64)
*i = uintValue(v) *i = uintValue(v)
return err == nil return err
} }
func (i *uintValue) String() string { return fmt.Sprintf("%v", *i) } func (i *uintValue) String() string { return fmt.Sprintf("%v", *i) }
@ -142,10 +143,10 @@ func newUint64Value(val uint64, p *uint64) *uint64Value {
return (*uint64Value)(p) return (*uint64Value)(p)
} }
func (i *uint64Value) Set(s string) bool { func (i *uint64Value) Set(s string) error {
v, err := strconv.ParseUint(s, 0, 64) v, err := strconv.ParseUint(s, 0, 64)
*i = uint64Value(v) *i = uint64Value(v)
return err == nil return err
} }
func (i *uint64Value) String() string { return fmt.Sprintf("%v", *i) } func (i *uint64Value) String() string { return fmt.Sprintf("%v", *i) }
@ -158,14 +159,14 @@ func newStringValue(val string, p *string) *stringValue {
return (*stringValue)(p) return (*stringValue)(p)
} }
func (s *stringValue) Set(val string) bool { func (s *stringValue) Set(val string) error {
*s = stringValue(val) *s = stringValue(val)
return true return nil
} }
func (s *stringValue) String() string { return fmt.Sprintf("%s", *s) } func (s *stringValue) String() string { return fmt.Sprintf("%s", *s) }
// -- Float64 Value // -- float64 Value
type float64Value float64 type float64Value float64
func newFloat64Value(val float64, p *float64) *float64Value { func newFloat64Value(val float64, p *float64) *float64Value {
@ -173,19 +174,35 @@ func newFloat64Value(val float64, p *float64) *float64Value {
return (*float64Value)(p) return (*float64Value)(p)
} }
func (f *float64Value) Set(s string) bool { func (f *float64Value) Set(s string) error {
v, err := strconv.ParseFloat(s, 64) v, err := strconv.ParseFloat(s, 64)
*f = float64Value(v) *f = float64Value(v)
return err == nil return err
} }
func (f *float64Value) String() string { return fmt.Sprintf("%v", *f) } func (f *float64Value) String() string { return fmt.Sprintf("%v", *f) }
// -- time.Duration Value
type durationValue time.Duration
func newDurationValue(val time.Duration, p *time.Duration) *durationValue {
*p = val
return (*durationValue)(p)
}
func (d *durationValue) Set(s string) error {
v, err := time.ParseDuration(s)
*d = durationValue(v)
return err
}
func (d *durationValue) String() string { return (*time.Duration)(d).String() }
// Value is the interface to the dynamic value stored in a flag. // Value is the interface to the dynamic value stored in a flag.
// (The default value is represented as a string.) // (The default value is represented as a string.)
type Value interface { type Value interface {
String() string String() string
Set(string) bool Set(string) error
} }
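With Set now returning an error, a user-defined Value reports bad input directly instead of answering with a bool. A small sketch under that contract; celsiusFlag is a hypothetical type, not part of the package:

	type celsiusFlag struct{ degrees float64 }

	func (c *celsiusFlag) String() string { return strconv.FormatFloat(c.degrees, 'f', -1, 64) }

	func (c *celsiusFlag) Set(s string) error {
		v, err := strconv.ParseFloat(s, 64)
		if err != nil {
			return fmt.Errorf("invalid temperature %q", s)
		}
		c.degrees = v
		return nil
	}

	// Registered like any other flag:
	//	flag.Var(&celsiusFlag{}, "temp", "temperature in degrees Celsius")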
// ErrorHandling defines how to handle flag parsing errors. // ErrorHandling defines how to handle flag parsing errors.
@ -276,27 +293,25 @@ func Lookup(name string) *Flag {
return commandLine.formal[name] return commandLine.formal[name]
} }
// Set sets the value of the named flag. It returns true if the set succeeded; false if // Set sets the value of the named flag.
// there is no such flag defined. func (f *FlagSet) Set(name, value string) error {
func (f *FlagSet) Set(name, value string) bool {
flag, ok := f.formal[name] flag, ok := f.formal[name]
if !ok { if !ok {
return false return fmt.Errorf("no such flag -%v", name)
} }
ok = flag.Value.Set(value) err := flag.Value.Set(value)
if !ok { if err != nil {
return false return err
} }
if f.actual == nil { if f.actual == nil {
f.actual = make(map[string]*Flag) f.actual = make(map[string]*Flag)
} }
f.actual[name] = flag f.actual[name] = flag
return true return nil
} }
// Set sets the value of the named command-line flag. It returns true if the // Set sets the value of the named command-line flag.
// set succeeded; false if there is no such flag defined. func Set(name, value string) error {
func Set(name, value string) bool {
return commandLine.Set(name, value) return commandLine.Set(name, value)
} }
@ -543,12 +558,38 @@ func (f *FlagSet) Float64(name string, value float64, usage string) *float64 {
return p return p
} }
// Float64 defines an int flag with specified name, default value, and usage string. // Float64 defines a float64 flag with specified name, default value, and usage string.
// The return value is the address of a float64 variable that stores the value of the flag. // The return value is the address of a float64 variable that stores the value of the flag.
func Float64(name string, value float64, usage string) *float64 { func Float64(name string, value float64, usage string) *float64 {
return commandLine.Float64(name, value, usage) return commandLine.Float64(name, value, usage)
} }
// DurationVar defines a time.Duration flag with specified name, default value, and usage string.
// The argument p points to a time.Duration variable in which to store the value of the flag.
func (f *FlagSet) DurationVar(p *time.Duration, name string, value time.Duration, usage string) {
f.Var(newDurationValue(value, p), name, usage)
}
// DurationVar defines a time.Duration flag with specified name, default value, and usage string.
// The argument p points to a time.Duration variable in which to store the value of the flag.
func DurationVar(p *time.Duration, name string, value time.Duration, usage string) {
commandLine.Var(newDurationValue(value, p), name, usage)
}
// Duration defines a time.Duration flag with specified name, default value, and usage string.
// The return value is the address of a time.Duration variable that stores the value of the flag.
func (f *FlagSet) Duration(name string, value time.Duration, usage string) *time.Duration {
p := new(time.Duration)
f.DurationVar(p, name, value, usage)
return p
}
// Duration defines a time.Duration flag with specified name, default value, and usage string.
// The return value is the address of a time.Duration variable that stores the value of the flag.
func Duration(name string, value time.Duration, usage string) *time.Duration {
return commandLine.Duration(name, value, usage)
}
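Client code uses the new Duration flag like the other typed flags; values are parsed with time.ParseDuration, so command lines such as -timeout 2m30s work. The flag name and default below are illustrative:

	timeout := flag.Duration("timeout", 30*time.Second, "how long to wait for the server")
	flag.Parse()
	fmt.Println("will wait for", *timeout)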
// Var defines a flag with the specified name and usage string. The type and // Var defines a flag with the specified name and usage string. The type and
// value of the flag are represented by the first argument, of type Value, which // value of the flag are represented by the first argument, of type Value, which
// typically holds a user-defined implementation of Value. For instance, the // typically holds a user-defined implementation of Value. For instance, the
@ -645,8 +686,8 @@ func (f *FlagSet) parseOne() (bool, error) {
} }
if fv, ok := flag.Value.(*boolValue); ok { // special case: doesn't need an arg if fv, ok := flag.Value.(*boolValue); ok { // special case: doesn't need an arg
if has_value { if has_value {
if !fv.Set(value) { if err := fv.Set(value); err != nil {
f.failf("invalid boolean value %q for flag: -%s", value, name) f.failf("invalid boolean value %q for -%s: %v", value, name, err)
} }
} else { } else {
fv.Set("true") fv.Set("true")
@ -661,9 +702,8 @@ func (f *FlagSet) parseOne() (bool, error) {
if !has_value { if !has_value {
return false, f.failf("flag needs an argument: -%s", name) return false, f.failf("flag needs an argument: -%s", name)
} }
ok = flag.Value.Set(value) if err := flag.Value.Set(value); err != nil {
if !ok { return false, f.failf("invalid value %q for flag -%s: %v", value, name, err)
return false, f.failf("invalid value %q for flag: -%s", value, name)
} }
} }
if f.actual == nil { if f.actual == nil {

View File

@ -10,16 +10,18 @@ import (
"os" "os"
"sort" "sort"
"testing" "testing"
"time"
) )
var ( var (
test_bool = Bool("test_bool", false, "bool value") test_bool = Bool("test_bool", false, "bool value")
test_int = Int("test_int", 0, "int value") test_int = Int("test_int", 0, "int value")
test_int64 = Int64("test_int64", 0, "int64 value") test_int64 = Int64("test_int64", 0, "int64 value")
test_uint = Uint("test_uint", 0, "uint value") test_uint = Uint("test_uint", 0, "uint value")
test_uint64 = Uint64("test_uint64", 0, "uint64 value") test_uint64 = Uint64("test_uint64", 0, "uint64 value")
test_string = String("test_string", "0", "string value") test_string = String("test_string", "0", "string value")
test_float64 = Float64("test_float64", 0, "float64 value") test_float64 = Float64("test_float64", 0, "float64 value")
test_duration = Duration("test_duration", 0, "time.Duration value")
) )
func boolString(s string) string { func boolString(s string) string {
@ -41,6 +43,8 @@ func TestEverything(t *testing.T) {
ok = true ok = true
case f.Name == "test_bool" && f.Value.String() == boolString(desired): case f.Name == "test_bool" && f.Value.String() == boolString(desired):
ok = true ok = true
case f.Name == "test_duration" && f.Value.String() == desired+"s":
ok = true
} }
if !ok { if !ok {
t.Error("Visit: bad value", f.Value.String(), "for", f.Name) t.Error("Visit: bad value", f.Value.String(), "for", f.Name)
@ -48,7 +52,7 @@ func TestEverything(t *testing.T) {
} }
} }
VisitAll(visitor) VisitAll(visitor)
if len(m) != 7 { if len(m) != 8 {
t.Error("VisitAll misses some flags") t.Error("VisitAll misses some flags")
for k, v := range m { for k, v := range m {
t.Log(k, *v) t.Log(k, *v)
@ -70,9 +74,10 @@ func TestEverything(t *testing.T) {
Set("test_uint64", "1") Set("test_uint64", "1")
Set("test_string", "1") Set("test_string", "1")
Set("test_float64", "1") Set("test_float64", "1")
Set("test_duration", "1s")
desired = "1" desired = "1"
Visit(visitor) Visit(visitor)
if len(m) != 7 { if len(m) != 8 {
t.Error("Visit fails after set") t.Error("Visit fails after set")
for k, v := range m { for k, v := range m {
t.Log(k, *v) t.Log(k, *v)
@ -109,6 +114,7 @@ func testParse(f *FlagSet, t *testing.T) {
uint64Flag := f.Uint64("uint64", 0, "uint64 value") uint64Flag := f.Uint64("uint64", 0, "uint64 value")
stringFlag := f.String("string", "0", "string value") stringFlag := f.String("string", "0", "string value")
float64Flag := f.Float64("float64", 0, "float64 value") float64Flag := f.Float64("float64", 0, "float64 value")
durationFlag := f.Duration("duration", 5*time.Second, "time.Duration value")
extra := "one-extra-argument" extra := "one-extra-argument"
args := []string{ args := []string{
"-bool", "-bool",
@ -119,6 +125,7 @@ func testParse(f *FlagSet, t *testing.T) {
"--uint64", "25", "--uint64", "25",
"-string", "hello", "-string", "hello",
"-float64", "2718e28", "-float64", "2718e28",
"-duration", "2m",
extra, extra,
} }
if err := f.Parse(args); err != nil { if err := f.Parse(args); err != nil {
@ -151,6 +158,9 @@ func testParse(f *FlagSet, t *testing.T) {
if *float64Flag != 2718e28 { if *float64Flag != 2718e28 {
t.Error("float64 flag should be 2718e28, is ", *float64Flag) t.Error("float64 flag should be 2718e28, is ", *float64Flag)
} }
if *durationFlag != 2*time.Minute {
t.Error("duration flag should be 2m, is ", *durationFlag)
}
if len(f.Args()) != 1 { if len(f.Args()) != 1 {
t.Error("expected one argument, got", len(f.Args())) t.Error("expected one argument, got", len(f.Args()))
} else if f.Args()[0] != extra { } else if f.Args()[0] != extra {
@ -174,9 +184,9 @@ func (f *flagVar) String() string {
return fmt.Sprint([]string(*f)) return fmt.Sprint([]string(*f))
} }
func (f *flagVar) Set(value string) bool { func (f *flagVar) Set(value string) error {
*f = append(*f, value) *f = append(*f, value)
return true return nil
} }
func TestUserDefined(t *testing.T) { func TestUserDefined(t *testing.T) {

View File

@ -30,8 +30,9 @@
%X base 16, with upper-case letters for A-F %X base 16, with upper-case letters for A-F
%U Unicode format: U+1234; same as "U+%04X" %U Unicode format: U+1234; same as "U+%04X"
Floating-point and complex constituents: Floating-point and complex constituents:
%b decimalless scientific notation with exponent a power %b decimalless scientific notation with exponent a power of two,
of two, in the manner of strconv.Ftoa32, e.g. -123456p-78 in the manner of strconv.FormatFloat with the 'b' format,
e.g. -123456p-78
%e scientific notation, e.g. -1234.456e+78 %e scientific notation, e.g. -1234.456e+78
%E scientific notation, e.g. -1234.456E+78 %E scientific notation, e.g. -1234.456E+78
%f decimal point but no exponent, e.g. 123.456 %f decimal point but no exponent, e.g. 123.456

View File

@ -517,7 +517,7 @@ var mallocTest = []struct {
{1, `Sprintf("xxx")`, func() { Sprintf("xxx") }}, {1, `Sprintf("xxx")`, func() { Sprintf("xxx") }},
{1, `Sprintf("%x")`, func() { Sprintf("%x", 7) }}, {1, `Sprintf("%x")`, func() { Sprintf("%x", 7) }},
{2, `Sprintf("%s")`, func() { Sprintf("%s", "hello") }}, {2, `Sprintf("%s")`, func() { Sprintf("%s", "hello") }},
{1, `Sprintf("%x %x")`, func() { Sprintf("%x", 7, 112) }}, {1, `Sprintf("%x %x")`, func() { Sprintf("%x %x", 7, 112) }},
{1, `Sprintf("%g")`, func() { Sprintf("%g", 3.14159) }}, {1, `Sprintf("%g")`, func() { Sprintf("%g", 3.14159) }},
{0, `Fprintf(buf, "%x %x %x")`, func() { mallocBuf.Reset(); Fprintf(&mallocBuf, "%x %x %x", 7, 8, 9) }}, {0, `Fprintf(buf, "%x %x %x")`, func() { mallocBuf.Reset(); Fprintf(&mallocBuf, "%x %x %x", 7, 8, 9) }},
{1, `Fprintf(buf, "%s")`, func() { mallocBuf.Reset(); Fprintf(&mallocBuf, "%s", "hello") }}, {1, `Fprintf(buf, "%s")`, func() { mallocBuf.Reset(); Fprintf(&mallocBuf, "%s", "hello") }},

View File

@ -9,6 +9,7 @@ package ast
import ( import (
"go/token" "go/token"
"strings"
"unicode" "unicode"
"unicode/utf8" "unicode/utf8"
) )
@ -76,6 +77,74 @@ type CommentGroup struct {
func (g *CommentGroup) Pos() token.Pos { return g.List[0].Pos() } func (g *CommentGroup) Pos() token.Pos { return g.List[0].Pos() }
func (g *CommentGroup) End() token.Pos { return g.List[len(g.List)-1].End() } func (g *CommentGroup) End() token.Pos { return g.List[len(g.List)-1].End() }
func isWhitespace(ch byte) bool { return ch == ' ' || ch == '\t' || ch == '\n' || ch == '\r' }
func stripTrailingWhitespace(s string) string {
i := len(s)
for i > 0 && isWhitespace(s[i-1]) {
i--
}
return s[0:i]
}
// Text returns the text of the comment,
// with the comment markers - //, /*, and */ - removed.
func (g *CommentGroup) Text() string {
if g == nil {
return ""
}
comments := make([]string, len(g.List))
for i, c := range g.List {
comments[i] = string(c.Text)
}
lines := make([]string, 0, 10) // most comments are less than 10 lines
for _, c := range comments {
// Remove comment markers.
// The parser has given us exactly the comment text.
switch c[1] {
case '/':
//-style comment
c = c[2:]
// Remove leading space after //, if there is one.
// TODO(gri) This appears to be necessary in isolated
// cases (bignum.RatFromString) - why?
if len(c) > 0 && c[0] == ' ' {
c = c[1:]
}
case '*':
/*-style comment */
c = c[2 : len(c)-2]
}
// Split on newlines.
cl := strings.Split(c, "\n")
// Walk lines, stripping trailing white space and adding to list.
for _, l := range cl {
lines = append(lines, stripTrailingWhitespace(l))
}
}
// Remove leading blank lines; convert runs of
// interior blank lines to a single blank line.
n := 0
for _, line := range lines {
if line != "" || n > 0 && lines[n-1] != "" {
lines[n] = line
n++
}
}
lines = lines[0:n]
// Add final "" entry to get trailing newline from Join.
if n > 0 && lines[n-1] != "" {
lines = append(lines, "")
}
return strings.Join(lines, "\n")
}
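A worked example of the behaviour implemented above: given a group of three line comments, Text strips the // markers and the single leading space, keeps the interior blank line, and appends a trailing newline.

	// Input comment group:
	//	// Foo is a thing.
	//	//
	//	// It is quite a thing.
	// Result of Text(): "Foo is a thing.\n\nIt is quite a thing.\n"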
// ---------------------------------------------------------------------------- // ----------------------------------------------------------------------------
// Expressions and types // Expressions and types

View File

@ -4,7 +4,10 @@
package ast package ast
import "go/token" import (
"go/token"
"sort"
)
// ---------------------------------------------------------------------------- // ----------------------------------------------------------------------------
// Export filtering // Export filtering
@ -20,7 +23,7 @@ func exportFilter(name string) bool {
// body) are removed. Non-exported fields and methods of exported types are // body) are removed. Non-exported fields and methods of exported types are
// stripped. The File.Comments list is not changed. // stripped. The File.Comments list is not changed.
// //
// FileExports returns true if there are exported declarationa; // FileExports returns true if there are exported declarations;
// it returns false otherwise. // it returns false otherwise.
// //
func FileExports(src *File) bool { func FileExports(src *File) bool {
@ -291,29 +294,35 @@ var separator = &Comment{noPos, "//"}
// //
func MergePackageFiles(pkg *Package, mode MergeMode) *File { func MergePackageFiles(pkg *Package, mode MergeMode) *File {
// Count the number of package docs, comments and declarations across // Count the number of package docs, comments and declarations across
// all package files. // all package files. Also, compute sorted list of filenames, so that
// subsequent iterations can always iterate in the same order.
ndocs := 0 ndocs := 0
ncomments := 0 ncomments := 0
ndecls := 0 ndecls := 0
for _, f := range pkg.Files { filenames := make([]string, len(pkg.Files))
i := 0
for filename, f := range pkg.Files {
filenames[i] = filename
i++
if f.Doc != nil { if f.Doc != nil {
ndocs += len(f.Doc.List) + 1 // +1 for separator ndocs += len(f.Doc.List) + 1 // +1 for separator
} }
ncomments += len(f.Comments) ncomments += len(f.Comments)
ndecls += len(f.Decls) ndecls += len(f.Decls)
} }
sort.Strings(filenames)
// Collect package comments from all package files into a single // Collect package comments from all package files into a single
// CommentGroup - the collected package documentation. The order // CommentGroup - the collected package documentation. In general
// is unspecified. In general there should be only one file with // there should be only one file with a package comment; but it's
// a package comment; but it's better to collect extra comments // better to collect extra comments than drop them on the floor.
// than drop them on the floor.
var doc *CommentGroup var doc *CommentGroup
var pos token.Pos var pos token.Pos
if ndocs > 0 { if ndocs > 0 {
list := make([]*Comment, ndocs-1) // -1: no separator before first group list := make([]*Comment, ndocs-1) // -1: no separator before first group
i := 0 i := 0
for _, f := range pkg.Files { for _, filename := range filenames {
f := pkg.Files[filename]
if f.Doc != nil { if f.Doc != nil {
if i > 0 { if i > 0 {
// not the first group - add separator // not the first group - add separator
@ -342,7 +351,8 @@ func MergePackageFiles(pkg *Package, mode MergeMode) *File {
funcs := make(map[string]int) // map of global function name -> decls index funcs := make(map[string]int) // map of global function name -> decls index
i := 0 // current index i := 0 // current index
n := 0 // number of filtered entries n := 0 // number of filtered entries
for _, f := range pkg.Files { for _, filename := range filenames {
f := pkg.Files[filename]
for _, d := range f.Decls { for _, d := range f.Decls {
if mode&FilterFuncDuplicates != 0 { if mode&FilterFuncDuplicates != 0 {
// A language entity may be declared multiple // A language entity may be declared multiple
@ -398,7 +408,8 @@ func MergePackageFiles(pkg *Package, mode MergeMode) *File {
var imports []*ImportSpec var imports []*ImportSpec
if mode&FilterImportDuplicates != 0 { if mode&FilterImportDuplicates != 0 {
seen := make(map[string]bool) seen := make(map[string]bool)
for _, f := range pkg.Files { for _, filename := range filenames {
f := pkg.Files[filename]
for _, imp := range f.Imports { for _, imp := range f.Imports {
if path := imp.Path.Value; !seen[path] { if path := imp.Path.Value; !seen[path] {
// TODO: consider handling cases where: // TODO: consider handling cases where:

View File

@ -36,7 +36,7 @@ func NotNilFilter(_ string, v reflect.Value) bool {
// struct fields for which f(fieldname, fieldvalue) is true are // struct fields for which f(fieldname, fieldvalue) is true are
// are printed; all others are filtered from the output. // are printed; all others are filtered from the output.
// //
func Fprint(w io.Writer, fset *token.FileSet, x interface{}, f FieldFilter) (n int, err error) { func Fprint(w io.Writer, fset *token.FileSet, x interface{}, f FieldFilter) (err error) {
// setup printer // setup printer
p := printer{ p := printer{
output: w, output: w,
@ -48,7 +48,6 @@ func Fprint(w io.Writer, fset *token.FileSet, x interface{}, f FieldFilter) (n i
// install error handler // install error handler
defer func() { defer func() {
n = p.written
if e := recover(); e != nil { if e := recover(); e != nil {
err = e.(localError).err // re-panics if it's not a localError err = e.(localError).err // re-panics if it's not a localError
} }
@ -67,19 +66,18 @@ func Fprint(w io.Writer, fset *token.FileSet, x interface{}, f FieldFilter) (n i
// Print prints x to standard output, skipping nil fields. // Print prints x to standard output, skipping nil fields.
// Print(fset, x) is the same as Fprint(os.Stdout, fset, x, NotNilFilter). // Print(fset, x) is the same as Fprint(os.Stdout, fset, x, NotNilFilter).
func Print(fset *token.FileSet, x interface{}) (int, error) { func Print(fset *token.FileSet, x interface{}) error {
return Fprint(os.Stdout, fset, x, NotNilFilter) return Fprint(os.Stdout, fset, x, NotNilFilter)
} }
type printer struct { type printer struct {
output io.Writer output io.Writer
fset *token.FileSet fset *token.FileSet
filter FieldFilter filter FieldFilter
ptrmap map[interface{}]int // *T -> line number ptrmap map[interface{}]int // *T -> line number
written int // number of bytes written to output indent int // current indentation level
indent int // current indentation level last byte // the last byte processed by Write
last byte // the last byte processed by Write line int // current line number
line int // current line number
} }
var indent = []byte(". ") var indent = []byte(". ")
@ -122,9 +120,7 @@ type localError struct {
// printf is a convenience wrapper that takes care of print errors. // printf is a convenience wrapper that takes care of print errors.
func (p *printer) printf(format string, args ...interface{}) { func (p *printer) printf(format string, args ...interface{}) {
n, err := fmt.Fprintf(p, format, args...) if _, err := fmt.Fprintf(p, format, args...); err != nil {
p.written += n
if err != nil {
panic(localError{err}) panic(localError{err})
} }
} }

View File

@ -66,7 +66,7 @@ func TestPrint(t *testing.T) {
var buf bytes.Buffer var buf bytes.Buffer
for _, test := range tests { for _, test := range tests {
buf.Reset() buf.Reset()
if _, err := Fprint(&buf, nil, test.x, nil); err != nil { if err := Fprint(&buf, nil, test.x, nil); err != nil {
t.Errorf("Fprint failed: %s", err) t.Errorf("Fprint failed: %s", err)
} }
if s, ts := trim(buf.String()), trim(test.s); s != ts { if s, ts := trim(buf.String()), trim(test.s); s != ts {

View File

@ -80,7 +80,7 @@ func (s *Scope) String() string {
type Object struct { type Object struct {
Kind ObjKind Kind ObjKind
Name string // declared name Name string // declared name
Decl interface{} // corresponding Field, XxxSpec, FuncDecl, LabeledStmt, or AssignStmt; or nil Decl interface{} // corresponding Field, XxxSpec, FuncDecl, LabeledStmt, AssignStmt, Scope; or nil
Data interface{} // object-specific data; or nil Data interface{} // object-specific data; or nil
Type interface{} // place holder for type information; may be nil Type interface{} // place holder for type information; may be nil
} }
@ -131,6 +131,8 @@ func (obj *Object) Pos() token.Pos {
return ident.Pos() return ident.Pos()
} }
} }
case *Scope:
// predeclared object - nothing to do for now
} }
return token.NoPos return token.NoPos
} }

View File

@ -396,8 +396,7 @@ func (b *build) cgo(cgofiles, cgocfiles []string) (outGo, outObj []string) {
Output: output, Output: output,
}) })
outGo = append(outGo, gofiles...) outGo = append(outGo, gofiles...)
exportH := filepath.Join(b.path, "_cgo_export.h") b.script.addIntermediate(defunC, b.obj+"_cgo_export.h", b.obj+"_cgo_flags")
b.script.addIntermediate(defunC, exportH, b.obj+"_cgo_flags")
b.script.addIntermediate(cfiles...) b.script.addIntermediate(cfiles...)
// cc _cgo_defun.c // cc _cgo_defun.c

View File

@ -9,7 +9,6 @@ import (
"errors" "errors"
"fmt" "fmt"
"go/ast" "go/ast"
"go/doc"
"go/parser" "go/parser"
"go/token" "go/token"
"io/ioutil" "io/ioutil"
@ -412,7 +411,7 @@ func (ctxt *Context) shouldBuild(content []byte) bool {
// TODO(rsc): This duplicates code in cgo. // TODO(rsc): This duplicates code in cgo.
// Once the dust settles, remove this code from cgo. // Once the dust settles, remove this code from cgo.
func (ctxt *Context) saveCgo(filename string, di *DirInfo, cg *ast.CommentGroup) error { func (ctxt *Context) saveCgo(filename string, di *DirInfo, cg *ast.CommentGroup) error {
text := doc.CommentText(cg) text := cg.Text()
for _, line := range strings.Split(text, "\n") { for _, line := range strings.Split(text, "\n") {
orig := line orig := line
@ -476,7 +475,7 @@ func (ctxt *Context) saveCgo(filename string, di *DirInfo, cg *ast.CommentGroup)
return nil return nil
} }
var safeBytes = []byte("+-.,/0123456789=ABCDEFGHIJKLMNOPQRSTUVWXYZ_abcdefghijklmnopqrstuvwxyz") var safeBytes = []byte("+-.,/0123456789=ABCDEFGHIJKLMNOPQRSTUVWXYZ_abcdefghijklmnopqrstuvwxyz:")
func safeName(s string) bool { func safeName(s string) bool {
if s == "" { if s == "" {

View File

@ -157,6 +157,7 @@ func init() {
Path = []*Tree{t} Path = []*Tree{t}
} }
Loop:
for _, p := range filepath.SplitList(os.Getenv("GOPATH")) { for _, p := range filepath.SplitList(os.Getenv("GOPATH")) {
if p == "" { if p == "" {
continue continue
@ -166,6 +167,21 @@ func init() {
log.Printf("invalid GOPATH %q: %v", p, err) log.Printf("invalid GOPATH %q: %v", p, err)
continue continue
} }
// Check for dupes.
// TODO(alexbrainman): make this correct under windows (case insensitive).
for _, t2 := range Path {
if t2.Path != t.Path {
continue
}
if t2.Goroot {
log.Printf("GOPATH is the same as GOROOT: %q", t.Path)
} else {
log.Printf("duplicate GOPATH entry: %q", t.Path)
}
continue Loop
}
Path = append(Path, t) Path = append(Path, t)
gcImportArgs = append(gcImportArgs, "-I", t.PkgDir()) gcImportArgs = append(gcImportArgs, "-I", t.PkgDir())
ldImportArgs = append(ldImportArgs, "-L", t.PkgDir()) ldImportArgs = append(ldImportArgs, "-L", t.PkgDir())

View File

@ -7,7 +7,6 @@
package doc package doc
import ( import (
"go/ast"
"io" "io"
"regexp" "regexp"
"strings" "strings"
@ -16,74 +15,6 @@ import (
"unicode/utf8" "unicode/utf8"
) )
func isWhitespace(ch byte) bool { return ch == ' ' || ch == '\t' || ch == '\n' || ch == '\r' }
func stripTrailingWhitespace(s string) string {
i := len(s)
for i > 0 && isWhitespace(s[i-1]) {
i--
}
return s[0:i]
}
// CommentText returns the text of comment,
// with the comment markers - //, /*, and */ - removed.
func CommentText(comment *ast.CommentGroup) string {
if comment == nil {
return ""
}
comments := make([]string, len(comment.List))
for i, c := range comment.List {
comments[i] = string(c.Text)
}
lines := make([]string, 0, 10) // most comments are less than 10 lines
for _, c := range comments {
// Remove comment markers.
// The parser has given us exactly the comment text.
switch c[1] {
case '/':
//-style comment
c = c[2:]
// Remove leading space after //, if there is one.
// TODO(gri) This appears to be necessary in isolated
// cases (bignum.RatFromString) - why?
if len(c) > 0 && c[0] == ' ' {
c = c[1:]
}
case '*':
/*-style comment */
c = c[2 : len(c)-2]
}
// Split on newlines.
cl := strings.Split(c, "\n")
// Walk lines, stripping trailing white space and adding to list.
for _, l := range cl {
lines = append(lines, stripTrailingWhitespace(l))
}
}
// Remove leading blank lines; convert runs of
// interior blank lines to a single blank line.
n := 0
for _, line := range lines {
if line != "" || n > 0 && lines[n-1] != "" {
lines[n] = line
n++
}
}
lines = lines[0:n]
// Add final "" entry to get trailing newline from Join.
if n > 0 && lines[n-1] != "" {
lines = append(lines, "")
}
return strings.Join(lines, "\n")
}
var ( var (
ldquo = []byte("&ldquo;") ldquo = []byte("&ldquo;")
rdquo = []byte("&rdquo;") rdquo = []byte("&rdquo;")
@ -422,12 +353,10 @@ func ToText(w io.Writer, text string, indent, preIndent string, width int) {
width: width, width: width,
indent: indent, indent: indent,
} }
for i, b := range blocks(text) { for _, b := range blocks(text) {
switch b.op { switch b.op {
case opPara: case opPara:
if i > 0 { // l.write will add leading newline if required
w.Write(nl)
}
for _, line := range b.lines { for _, line := range b.lines {
l.write(line) l.write(line)
} }

View File

@ -7,673 +7,94 @@ package doc
import ( import (
"go/ast" "go/ast"
"go/token"
"regexp"
"sort" "sort"
) )
// ---------------------------------------------------------------------------- // Package is the documentation for an entire package.
// Collection of documentation info type Package struct {
Doc string
// embeddedType describes the type of an anonymous field. Name string
// ImportPath string
type embeddedType struct { Imports []string // TODO(gri) this field is not computed at the moment
typ *typeInfo // the corresponding base type Filenames []string
ptr bool // if set, the anonymous field type is a pointer Consts []*Value
Types []*Type
Vars []*Value
Funcs []*Func
Bugs []string
} }
type typeInfo struct { // Value is the documentation for a (possibly grouped) var or const declaration.
// len(decl.Specs) == 1, and the element type is *ast.TypeSpec type Value struct {
// if the type declaration hasn't been seen yet, decl is nil
decl *ast.GenDecl
embedded []embeddedType
forward *TypeDoc // forward link to processed type documentation
// declarations associated with the type
values []*ast.GenDecl // consts and vars
factories map[string]*ast.FuncDecl
methods map[string]*ast.FuncDecl
}
func (info *typeInfo) addEmbeddedType(embedded *typeInfo, isPtr bool) {
info.embedded = append(info.embedded, embeddedType{embedded, isPtr})
}
// docReader accumulates documentation for a single package.
// It modifies the AST: Comments (declaration documentation)
// that have been collected by the DocReader are set to nil
// in the respective AST nodes so that they are not printed
// twice (once when printing the documentation and once when
// printing the corresponding AST node).
//
type docReader struct {
doc *ast.CommentGroup // package documentation, if any
pkgName string
values []*ast.GenDecl // consts and vars
types map[string]*typeInfo
embedded map[string]*typeInfo // embedded types, possibly not exported
funcs map[string]*ast.FuncDecl
bugs []*ast.CommentGroup
}
func (doc *docReader) init(pkgName string) {
doc.pkgName = pkgName
doc.types = make(map[string]*typeInfo)
doc.embedded = make(map[string]*typeInfo)
doc.funcs = make(map[string]*ast.FuncDecl)
}
func (doc *docReader) addDoc(comments *ast.CommentGroup) {
if doc.doc == nil {
// common case: just one package comment
doc.doc = comments
return
}
// More than one package comment: Usually there will be only
// one file with a package comment, but it's better to collect
// all comments than drop them on the floor.
blankComment := &ast.Comment{token.NoPos, "//"}
list := append(doc.doc.List, blankComment)
doc.doc.List = append(list, comments.List...)
}
func (doc *docReader) lookupTypeInfo(name string) *typeInfo {
if name == "" || name == "_" {
return nil // no type docs for anonymous types
}
if info, found := doc.types[name]; found {
return info
}
// type wasn't found - add one without declaration
info := &typeInfo{
factories: make(map[string]*ast.FuncDecl),
methods: make(map[string]*ast.FuncDecl),
}
doc.types[name] = info
return info
}
func baseTypeName(typ ast.Expr, allTypes bool) string {
switch t := typ.(type) {
case *ast.Ident:
// if the type is not exported, the effect to
// a client is as if there were no type name
if t.IsExported() || allTypes {
return t.Name
}
case *ast.StarExpr:
return baseTypeName(t.X, allTypes)
}
return ""
}
func (doc *docReader) addValue(decl *ast.GenDecl) {
// determine if decl should be associated with a type
// Heuristic: For each typed entry, determine the type name, if any.
// If there is exactly one type name that is sufficiently
// frequent, associate the decl with the respective type.
domName := ""
domFreq := 0
prev := ""
for _, s := range decl.Specs {
if v, ok := s.(*ast.ValueSpec); ok {
name := ""
switch {
case v.Type != nil:
// a type is present; determine its name
name = baseTypeName(v.Type, false)
case decl.Tok == token.CONST:
// no type is present but we have a constant declaration;
// use the previous type name (w/o more type information
// we cannot handle the case of unnamed variables with
// initializer expressions except for some trivial cases)
name = prev
}
if name != "" {
// entry has a named type
if domName != "" && domName != name {
// more than one type name - do not associate
// with any type
domName = ""
break
}
domName = name
domFreq++
}
prev = name
}
}
// determine values list
const threshold = 0.75
values := &doc.values
if domName != "" && domFreq >= int(float64(len(decl.Specs))*threshold) {
// typed entries are sufficiently frequent
typ := doc.lookupTypeInfo(domName)
if typ != nil {
values = &typ.values // associate with that type
}
}
*values = append(*values, decl)
}
// Helper function to set the table entry for function f. Makes sure that
// at least one f with associated documentation is stored in table, if there
// are multiple f's with the same name.
func setFunc(table map[string]*ast.FuncDecl, f *ast.FuncDecl) {
name := f.Name.Name
if g, exists := table[name]; exists && g.Doc != nil {
// a function with the same name has already been registered;
// since it has documentation, assume f is simply another
// implementation and ignore it
// TODO(gri) consider collecting all functions, or at least
// all comments
return
}
// function doesn't exist or has no documentation; use f
table[name] = f
}
func (doc *docReader) addFunc(fun *ast.FuncDecl) {
// strip function body
fun.Body = nil
// determine if it should be associated with a type
if fun.Recv != nil {
// method
typ := doc.lookupTypeInfo(baseTypeName(fun.Recv.List[0].Type, false))
if typ != nil {
// exported receiver type
setFunc(typ.methods, fun)
}
// otherwise don't show the method
// TODO(gri): There may be exported methods of non-exported types
// that can be called because of exported values (consts, vars, or
// function results) of that type. Could determine if that is the
// case and then show those methods in an appropriate section.
return
}
// perhaps a factory function
// determine result type, if any
if fun.Type.Results.NumFields() >= 1 {
res := fun.Type.Results.List[0]
if len(res.Names) <= 1 {
// exactly one (named or anonymous) result associated
// with the first type in result signature (there may
// be more than one result)
tname := baseTypeName(res.Type, false)
typ := doc.lookupTypeInfo(tname)
if typ != nil {
// named and exported result type
setFunc(typ.factories, fun)
return
}
}
}
// ordinary function
setFunc(doc.funcs, fun)
}
func (doc *docReader) addDecl(decl ast.Decl) {
switch d := decl.(type) {
case *ast.GenDecl:
if len(d.Specs) > 0 {
switch d.Tok {
case token.CONST, token.VAR:
// constants and variables are always handled as a group
doc.addValue(d)
case token.TYPE:
// types are handled individually
for _, spec := range d.Specs {
tspec := spec.(*ast.TypeSpec)
// add the type to the documentation
info := doc.lookupTypeInfo(tspec.Name.Name)
if info == nil {
continue // no name - ignore the type
}
// Make a (fake) GenDecl node for this TypeSpec
// (we need to do this here - as opposed to just
// for printing - so we don't lose the GenDecl
// documentation). Since a new GenDecl node is
// created, there's no need to nil out d.Doc.
//
// TODO(gri): Consider just collecting the TypeSpec
// node (and copy in the GenDecl.doc if there is no
// doc in the TypeSpec - this is currently done in
// makeTypeDocs below). Simpler data structures, but
// would lose GenDecl documentation if the TypeSpec
// has documentation as well.
fake := &ast.GenDecl{d.Doc, d.Pos(), token.TYPE, token.NoPos,
[]ast.Spec{tspec}, token.NoPos}
// A type should be added at most once, so info.decl
// should be nil - if it isn't, simply overwrite it.
info.decl = fake
// Look for anonymous fields that might contribute methods.
var fields *ast.FieldList
switch typ := spec.(*ast.TypeSpec).Type.(type) {
case *ast.StructType:
fields = typ.Fields
case *ast.InterfaceType:
fields = typ.Methods
}
if fields != nil {
for _, field := range fields.List {
if len(field.Names) == 0 {
// anonymous field - add corresponding type
// to the info and collect it in doc
name := baseTypeName(field.Type, true)
if embedded := doc.lookupTypeInfo(name); embedded != nil {
_, ptr := field.Type.(*ast.StarExpr)
info.addEmbeddedType(embedded, ptr)
}
}
}
}
}
}
}
case *ast.FuncDecl:
doc.addFunc(d)
}
}
func copyCommentList(list []*ast.Comment) []*ast.Comment {
return append([]*ast.Comment(nil), list...)
}
var (
bug_markers = regexp.MustCompile("^/[/*][ \t]*BUG\\(.*\\):[ \t]*") // BUG(uid):
bug_content = regexp.MustCompile("[^ \n\r\t]+") // at least one non-whitespace char
)
// addFile adds the AST for a source file to the docReader.
// Adding the same AST multiple times is a no-op.
//
func (doc *docReader) addFile(src *ast.File) {
// add package documentation
if src.Doc != nil {
doc.addDoc(src.Doc)
src.Doc = nil // doc consumed - remove from ast.File node
}
// add all declarations
for _, decl := range src.Decls {
doc.addDecl(decl)
}
// collect BUG(...) comments
for _, c := range src.Comments {
text := c.List[0].Text
if m := bug_markers.FindStringIndex(text); m != nil {
// found a BUG comment; maybe empty
if btxt := text[m[1]:]; bug_content.MatchString(btxt) {
// non-empty BUG comment; collect comment without BUG prefix
list := copyCommentList(c.List)
list[0].Text = text[m[1]:]
doc.bugs = append(doc.bugs, &ast.CommentGroup{list})
}
}
}
src.Comments = nil // consumed unassociated comments - remove from ast.File node
}
func NewPackageDoc(pkg *ast.Package, importpath string, exportsOnly bool) *PackageDoc {
var r docReader
r.init(pkg.Name)
filenames := make([]string, len(pkg.Files))
i := 0
for filename, f := range pkg.Files {
if exportsOnly {
r.fileExports(f)
}
r.addFile(f)
filenames[i] = filename
i++
}
return r.newDoc(importpath, filenames)
}
// ----------------------------------------------------------------------------
// Conversion to external representation
// ValueDoc is the documentation for a group of declared
// values, either vars or consts.
//
type ValueDoc struct {
Doc string Doc string
Names []string // var or const names in declaration order
Decl *ast.GenDecl Decl *ast.GenDecl
order int order int
} }
type sortValueDoc []*ValueDoc type Method struct {
*Func
func (p sortValueDoc) Len() int { return len(p) } // TODO(gri) The following fields are not set at the moment.
func (p sortValueDoc) Swap(i, j int) { p[i], p[j] = p[j], p[i] } Recv *Type // original receiver base type
Level int // embedding level; 0 means Func is not embedded
func declName(d *ast.GenDecl) string {
if len(d.Specs) != 1 {
return ""
}
switch v := d.Specs[0].(type) {
case *ast.ValueSpec:
return v.Names[0].Name
case *ast.TypeSpec:
return v.Name.Name
}
return ""
} }
func (p sortValueDoc) Less(i, j int) bool { // Type is the documentation for type declaration.
// sort by name type Type struct {
// pull blocks (name = "") up to top Doc string
// in original order Name string
if ni, nj := declName(p[i].Decl), declName(p[j].Decl); ni != nj { Type *ast.TypeSpec
return ni < nj Decl *ast.GenDecl
} Consts []*Value // sorted list of constants of (mostly) this type
return p[i].order < p[j].order Vars []*Value // sorted list of variables of (mostly) this type
Funcs []*Func // sorted list of functions returning this type
Methods []*Method // sorted list of methods (including embedded ones) of this type
methods []*Func // top-level methods only
embedded methodSet // embedded methods only
order int
} }
func makeValueDocs(list []*ast.GenDecl, tok token.Token) []*ValueDoc { // Func is the documentation for a func declaration.
d := make([]*ValueDoc, len(list)) // big enough in any case type Func struct {
n := 0
for i, decl := range list {
if decl.Tok == tok {
d[n] = &ValueDoc{CommentText(decl.Doc), decl, i}
n++
decl.Doc = nil // doc consumed - removed from AST
}
}
d = d[0:n]
sort.Sort(sortValueDoc(d))
return d
}
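// Illustrative sketch (assumed input): a grouped declaration such as
//
//	// Limits used by the scanner.
//	const (
//		MaxIdentLen = 64
//		MaxErrors   = 10
//	)
//
// yields one ValueDoc whose Doc is "Limits used by the scanner.\n" and whose
// Decl is the whole const block; makeValueDocs keeps only declarations whose
// token matches tok and sorts the result, pulling multi-spec blocks (which
// have no single name) to the top in their original order.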
// FuncDoc is the documentation for a func declaration,
// either a top-level function or a method function.
//
type FuncDoc struct {
Doc string Doc string
Recv ast.Expr // TODO(rsc): Would like string here
Name string Name string
// TODO(gri) remove Recv once we switch to new implementation
Recv ast.Expr // TODO(rsc): Would like string here
Decl *ast.FuncDecl Decl *ast.FuncDecl
} }
type sortFuncDoc []*FuncDoc // Mode values control the operation of New.
type Mode int
func (p sortFuncDoc) Len() int { return len(p) } const (
func (p sortFuncDoc) Swap(i, j int) { p[i], p[j] = p[j], p[i] } // extract documentation for all package-level declarations,
func (p sortFuncDoc) Less(i, j int) bool { return p[i].Name < p[j].Name } // not just exported ones
AllDecls Mode = 1 << iota
)
func makeFuncDocs(m map[string]*ast.FuncDecl) []*FuncDoc { // New computes the package documentation for the given package.
d := make([]*FuncDoc, len(m)) func New(pkg *ast.Package, importpath string, mode Mode) *Package {
var r docReader
r.init(pkg.Name, mode)
filenames := make([]string, len(pkg.Files))
// sort package files before reading them so that the
// result is the same on different machines (32/64bit)
i := 0 i := 0
for _, f := range m { for filename := range pkg.Files {
doc := new(FuncDoc) filenames[i] = filename
doc.Doc = CommentText(f.Doc)
f.Doc = nil // doc consumed - remove from ast.FuncDecl node
if f.Recv != nil {
doc.Recv = f.Recv.List[0].Type
}
doc.Name = f.Name.Name
doc.Decl = f
d[i] = doc
i++ i++
} }
sort.Sort(sortFuncDoc(d))
return d
}
type methodSet map[string]*FuncDoc
func (mset methodSet) add(m *FuncDoc) {
if mset[m.Name] == nil {
mset[m.Name] = m
}
}
func (mset methodSet) sortedList() []*FuncDoc {
list := make([]*FuncDoc, len(mset))
i := 0
for _, m := range mset {
list[i] = m
i++
}
sort.Sort(sortFuncDoc(list))
return list
}
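// Illustrative note (names assumed): add keeps only the first method stored
// under a given name, so
//
//	mset := make(methodSet)
//	mset.add(&FuncDoc{Name: "String"}) // kept
//	mset.add(&FuncDoc{Name: "String"}) // ignored, name already present
//
// which is what lets shallowly embedded methods shadow more deeply nested
// ones in collectEmbeddedMethods below.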
// TypeDoc is the documentation for a declared type.
// Consts and Vars are sorted lists of constants and variables of (mostly) that type.
// Factories is a sorted list of factory functions that return that type.
// Methods is a sorted list of method functions on that type.
type TypeDoc struct {
Doc string
Type *ast.TypeSpec
Consts []*ValueDoc
Vars []*ValueDoc
Factories []*FuncDoc
methods []*FuncDoc // top-level methods only
embedded methodSet // embedded methods only
Methods []*FuncDoc // all methods including embedded ones
Decl *ast.GenDecl
order int
}
type sortTypeDoc []*TypeDoc
func (p sortTypeDoc) Len() int { return len(p) }
func (p sortTypeDoc) Swap(i, j int) { p[i], p[j] = p[j], p[i] }
func (p sortTypeDoc) Less(i, j int) bool {
// sort by name
// pull blocks (name = "") up to top
// in original order
if ni, nj := p[i].Type.Name.Name, p[j].Type.Name.Name; ni != nj {
return ni < nj
}
return p[i].order < p[j].order
}
// NOTE(rsc): This would appear not to be correct for type ( )
// blocks, but the doc extractor above has split them into
// individual declarations.
func (doc *docReader) makeTypeDocs(m map[string]*typeInfo) []*TypeDoc {
// TODO(gri) Consider computing the embedded method information
// before calling makeTypeDocs. Then this function can
// be single-phased again. Also, it might simplify some
// of the logic.
//
// phase 1: associate collected declarations with TypeDocs
list := make([]*TypeDoc, len(m))
i := 0
for _, old := range m {
// all typeInfos should have a declaration associated with
// them after processing an entire package - be conservative
// and check
if decl := old.decl; decl != nil {
typespec := decl.Specs[0].(*ast.TypeSpec)
t := new(TypeDoc)
doc := typespec.Doc
typespec.Doc = nil // doc consumed - remove from ast.TypeSpec node
if doc == nil {
// no doc associated with the spec, use the declaration doc, if any
doc = decl.Doc
}
decl.Doc = nil // doc consumed - remove from ast.Decl node
t.Doc = CommentText(doc)
t.Type = typespec
t.Consts = makeValueDocs(old.values, token.CONST)
t.Vars = makeValueDocs(old.values, token.VAR)
t.Factories = makeFuncDocs(old.factories)
t.methods = makeFuncDocs(old.methods)
// The list of embedded types' methods is computed from the list
// of embedded types, some of which may not have been processed
// yet (i.e., their forward link is nil) - do this in a 2nd phase.
// The final list of methods can only be computed after that -
// do this in a 3rd phase.
t.Decl = old.decl
t.order = i
old.forward = t // old has been processed
list[i] = t
i++
} else {
// no corresponding type declaration found - move any associated
// values, factory functions, and methods back to the top-level
// so that they are not lost (this should only happen if a package
// file containing the explicit type declaration is missing or if
// an unqualified type name was used after a "." import)
// 1) move values
doc.values = append(doc.values, old.values...)
// 2) move factory functions
for name, f := range old.factories {
doc.funcs[name] = f
}
// 3) move methods
for name, f := range old.methods {
// don't overwrite functions with the same name
if _, found := doc.funcs[name]; !found {
doc.funcs[name] = f
}
}
}
}
list = list[0:i] // some types may have been ignored
// phase 2: collect embedded methods for each processed typeInfo
for _, old := range m {
if t := old.forward; t != nil {
// old has been processed into t; collect embedded
// methods for t from the list of processed embedded
// types in old (and thus for which the methods are known)
typ := t.Type
if _, ok := typ.Type.(*ast.StructType); ok {
// struct
t.embedded = make(methodSet)
collectEmbeddedMethods(t.embedded, old, typ.Name.Name)
} else {
// interface
// TODO(gri) fix this
}
}
}
// phase 3: compute final method set for each TypeDoc
for _, d := range list {
if len(d.embedded) > 0 {
// there are embedded methods - exclude
// the ones with names conflicting with
// non-embedded methods
mset := make(methodSet)
// top-level methods have priority
for _, m := range d.methods {
mset.add(m)
}
// add non-conflicting embedded methods
for _, m := range d.embedded {
mset.add(m)
}
d.Methods = mset.sortedList()
} else {
// no embedded methods
d.Methods = d.methods
}
}
sort.Sort(sortTypeDoc(list))
return list
}
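// Illustrative note for the fallback in phase 1 (assumed example): if a file
// declares
//
//	func (u undeclared) Do() {}
//
// but no matching "type undeclared ..." declaration is found in the package
// (for instance because the file holding it is missing, or the name came in
// via a "." import), the method is moved into doc.funcs and is later emitted
// as a plain top-level function instead of being dropped.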
// collectEmbeddedMethods collects into mset the embedded methods from all
// processed embedded types found in info. It considers
// embedded types at the most shallow level first so that more
// deeply nested embedded methods with conflicting names are
// excluded.
//
func collectEmbeddedMethods(mset methodSet, info *typeInfo, recvTypeName string) {
for _, e := range info.embedded {
if e.typ.forward != nil { // == e was processed
for _, m := range e.typ.forward.methods {
mset.add(customizeRecv(m, e.ptr, recvTypeName))
}
collectEmbeddedMethods(mset, e.typ, recvTypeName)
}
}
}
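// Illustrative sketch of the shallow-first rule (all types are assumed
// examples):
//
//	type Inner struct{}
//	func (Inner) String() string { return "inner" }
//
//	type Middle struct{ Inner }
//	func (Middle) String() string { return "middle" }
//
//	type Outer struct{ Middle }
//
// For Outer, Middle's methods are added to mset before the recursive call
// descends into Inner, so methodSet.add drops Inner.String and only the
// shallower method survives, with its receiver rewritten to Outer by
// customizeRecv.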
func customizeRecv(m *FuncDoc, embeddedIsPtr bool, recvTypeName string) *FuncDoc {
if m == nil || m.Decl == nil || m.Decl.Recv == nil || len(m.Decl.Recv.List) != 1 {
return m // shouldn't happen, but be safe
}
// copy existing receiver field and set new type
// TODO(gri) is receiver type computation correct?
// what about deeply nested embeddings?
newField := *m.Decl.Recv.List[0]
_, origRecvIsPtr := newField.Type.(*ast.StarExpr)
var typ ast.Expr = ast.NewIdent(recvTypeName)
if embeddedIsPtr || origRecvIsPtr {
typ = &ast.StarExpr{token.NoPos, typ}
}
newField.Type = typ
// copy existing receiver field list and set new receiver field
newFieldList := *m.Decl.Recv
newFieldList.List = []*ast.Field{&newField}
// copy existing function declaration and set new receiver field list
newFuncDecl := *m.Decl
newFuncDecl.Recv = &newFieldList
// copy existing function documentation and set new declaration
newM := *m
newM.Decl = &newFuncDecl
newM.Recv = typ
return &newM
}
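// Illustrative sketch (assumed types): given
//
//	type Base struct{}
//	func (b *Base) Close() error { return nil }
//
//	type File struct{ Base } // Base embedded by value
//
// customizeRecv rewrites the displayed receiver of Close from *Base to *File:
// the original receiver is a pointer, so the new receiver stays a pointer
// even though the embedding itself is by value.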
func makeBugDocs(list []*ast.CommentGroup) []string {
d := make([]string, len(list))
for i, g := range list {
d[i] = CommentText(g)
}
return d
}
// PackageDoc is the documentation for an entire package.
//
type PackageDoc struct {
PackageName string
ImportPath string
Filenames []string
Doc string
Consts []*ValueDoc
Types []*TypeDoc
Vars []*ValueDoc
Funcs []*FuncDoc
Bugs []string
}
// newDoc returns the accumulated documentation for the package.
//
func (doc *docReader) newDoc(importpath string, filenames []string) *PackageDoc {
p := new(PackageDoc)
p.PackageName = doc.pkgName
p.ImportPath = importpath
sort.Strings(filenames) sort.Strings(filenames)
p.Filenames = filenames
p.Doc = CommentText(doc.doc) // process files in sorted order
// makeTypeDocs may extend the list of doc.values and for _, filename := range filenames {
// doc.funcs and thus must be called before any other f := pkg.Files[filename]
// function consuming those lists if mode&AllDecls == 0 {
p.Types = doc.makeTypeDocs(doc.types) r.fileExports(f)
p.Consts = makeValueDocs(doc.values, token.CONST) }
p.Vars = makeValueDocs(doc.values, token.VAR) r.addFile(f)
p.Funcs = makeFuncDocs(doc.funcs) }
p.Bugs = makeBugDocs(doc.bugs) return r.newDoc(importpath, filenames)
return p
} }

137
libgo/go/go/doc/doc_test.go Normal file
View File

@ -0,0 +1,137 @@
// Copyright 2012 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package doc
import (
"bytes"
"fmt"
"go/ast"
"go/parser"
"go/token"
"testing"
"text/template"
)
type sources map[string]string // filename -> file contents
type testCase struct {
name string
importPath string
mode Mode
srcs sources
doc string
}
var tests = make(map[string]*testCase)
// To register a new test case, use the pattern:
//
// var _ = register(&testCase{ ... })
//
// (The result value of register is always 0 and only present to enable the pattern.)
//
func register(test *testCase) int {
if _, found := tests[test.name]; found {
panic(fmt.Sprintf("registration failed: test case %q already exists", test.name))
}
tests[test.name] = test
return 0
}
func runTest(t *testing.T, test *testCase) {
// create AST
fset := token.NewFileSet()
var pkg ast.Package
pkg.Files = make(map[string]*ast.File)
for filename, src := range test.srcs {
file, err := parser.ParseFile(fset, filename, src, parser.ParseComments)
if err != nil {
t.Errorf("test %s: %v", test.name, err)
return
}
switch {
case pkg.Name == "":
pkg.Name = file.Name.Name
case pkg.Name != file.Name.Name:
t.Errorf("test %s: different package names in test files", test.name)
return
}
pkg.Files[filename] = file
}
doc := New(&pkg, test.importPath, test.mode).String()
if doc != test.doc {
//TODO(gri) Enable this once the sorting issue of comments is fixed
//t.Errorf("test %s\n\tgot : %s\n\twant: %s", test.name, doc, test.doc)
}
}
func Test(t *testing.T) {
for _, test := range tests {
runTest(t, test)
}
}
// ----------------------------------------------------------------------------
// Printing support
func (pkg *Package) String() string {
var buf bytes.Buffer
docText.Execute(&buf, pkg) // ignore error - test will fail w/ incorrect output
return buf.String()
}
// TODO(gri) complete template
var docText = template.Must(template.New("docText").Parse(
`
PACKAGE {{.Name}}
DOC {{printf "%q" .Doc}}
IMPORTPATH {{.ImportPath}}
FILENAMES {{.Filenames}}
`))
// ----------------------------------------------------------------------------
// Test cases
// Test that all package comments and bugs are collected,
// and that the importPath is correctly set.
//
var _ = register(&testCase{
name: "p",
importPath: "p",
srcs: sources{
"p1.go": "// comment 1\npackage p\n//BUG(uid): bug1",
"p0.go": "// comment 0\npackage p\n// BUG(uid): bug0",
},
doc: `
PACKAGE p
DOC "comment 0\n\ncomment 1\n"
IMPORTPATH p
FILENAMES [p0.go p1.go]
`,
})
// Test basic functionality.
//
var _ = register(&testCase{
name: "p1",
importPath: "p",
srcs: sources{
"p.go": `
package p
import "a"
const pi = 3.14 // pi
type T struct{} // T
var V T // v
func F(x int) int {} // F
`,
},
doc: `
PACKAGE p
DOC ""
IMPORTPATH p
FILENAMES [p.go]
`,
})

View File

@ -35,7 +35,7 @@ func Examples(pkg *ast.Package) []*Example {
examples = append(examples, &Example{ examples = append(examples, &Example{
Name: name[len("Example"):], Name: name[len("Example"):],
Body: &printer.CommentedNode{f.Body, src.Comments}, Body: &printer.CommentedNode{f.Body, src.Comments},
Output: CommentText(f.Doc), Output: f.Doc.Text(),
}) })
} }
} }

View File

@ -33,7 +33,7 @@ func baseName(x ast.Expr) *ast.Ident {
return nil return nil
} }
func (doc *docReader) filterFieldList(fields *ast.FieldList) (removedFields bool) { func (doc *docReader) filterFieldList(tinfo *typeInfo, fields *ast.FieldList) (removedFields bool) {
if fields == nil { if fields == nil {
return false return false
} }
@ -44,7 +44,18 @@ func (doc *docReader) filterFieldList(fields *ast.FieldList) (removedFields bool
if len(f.Names) == 0 { if len(f.Names) == 0 {
// anonymous field // anonymous field
name := baseName(f.Type) name := baseName(f.Type)
keepField = name != nil && name.IsExported() if name != nil && name.IsExported() {
// we keep the field - in this case doc.addDecl
// will take care of adding the embedded type
keepField = true
} else if tinfo != nil {
// we don't keep the field - add it as an embedded
// type so we won't lose its methods, if any
if embedded := doc.lookupTypeInfo(name.Name); embedded != nil {
_, ptr := f.Type.(*ast.StarExpr)
tinfo.addEmbeddedType(embedded, ptr)
}
}
} else { } else {
n := len(f.Names) n := len(f.Names)
f.Names = filterIdentList(f.Names) f.Names = filterIdentList(f.Names)
@ -54,7 +65,7 @@ func (doc *docReader) filterFieldList(fields *ast.FieldList) (removedFields bool
keepField = len(f.Names) > 0 keepField = len(f.Names) > 0
} }
if keepField { if keepField {
doc.filterType(f.Type) doc.filterType(nil, f.Type)
list[j] = f list[j] = f
j++ j++
} }
@ -72,23 +83,23 @@ func (doc *docReader) filterParamList(fields *ast.FieldList) bool {
} }
var b bool var b bool
for _, f := range fields.List { for _, f := range fields.List {
if doc.filterType(f.Type) { if doc.filterType(nil, f.Type) {
b = true b = true
} }
} }
return b return b
} }
func (doc *docReader) filterType(typ ast.Expr) bool { func (doc *docReader) filterType(tinfo *typeInfo, typ ast.Expr) bool {
switch t := typ.(type) { switch t := typ.(type) {
case *ast.Ident: case *ast.Ident:
return ast.IsExported(t.Name) return ast.IsExported(t.Name)
case *ast.ParenExpr: case *ast.ParenExpr:
return doc.filterType(t.X) return doc.filterType(nil, t.X)
case *ast.ArrayType: case *ast.ArrayType:
return doc.filterType(t.Elt) return doc.filterType(nil, t.Elt)
case *ast.StructType: case *ast.StructType:
if doc.filterFieldList(t.Fields) { if doc.filterFieldList(tinfo, t.Fields) {
t.Incomplete = true t.Incomplete = true
} }
return len(t.Fields.List) > 0 return len(t.Fields.List) > 0
@ -97,16 +108,16 @@ func (doc *docReader) filterType(typ ast.Expr) bool {
b2 := doc.filterParamList(t.Results) b2 := doc.filterParamList(t.Results)
return b1 || b2 return b1 || b2
case *ast.InterfaceType: case *ast.InterfaceType:
if doc.filterFieldList(t.Methods) { if doc.filterFieldList(tinfo, t.Methods) {
t.Incomplete = true t.Incomplete = true
} }
return len(t.Methods.List) > 0 return len(t.Methods.List) > 0
case *ast.MapType: case *ast.MapType:
b1 := doc.filterType(t.Key) b1 := doc.filterType(nil, t.Key)
b2 := doc.filterType(t.Value) b2 := doc.filterType(nil, t.Value)
return b1 || b2 return b1 || b2
case *ast.ChanType: case *ast.ChanType:
return doc.filterType(t.Value) return doc.filterType(nil, t.Value)
} }
return false return false
} }
@ -116,12 +127,12 @@ func (doc *docReader) filterSpec(spec ast.Spec) bool {
case *ast.ValueSpec: case *ast.ValueSpec:
s.Names = filterIdentList(s.Names) s.Names = filterIdentList(s.Names)
if len(s.Names) > 0 { if len(s.Names) > 0 {
doc.filterType(s.Type) doc.filterType(nil, s.Type)
return true return true
} }
case *ast.TypeSpec: case *ast.TypeSpec:
if ast.IsExported(s.Name.Name) { if ast.IsExported(s.Name.Name) {
doc.filterType(s.Type) doc.filterType(doc.lookupTypeInfo(s.Name.Name), s.Type)
return true return true
} }
} }

View File

@ -49,7 +49,7 @@ func matchDecl(d *ast.GenDecl, f Filter) bool {
return false return false
} }
func filterValueDocs(a []*ValueDoc, f Filter) []*ValueDoc { func filterValues(a []*Value, f Filter) []*Value {
w := 0 w := 0
for _, vd := range a { for _, vd := range a {
if matchDecl(vd.Decl, f) { if matchDecl(vd.Decl, f) {
@ -60,7 +60,7 @@ func filterValueDocs(a []*ValueDoc, f Filter) []*ValueDoc {
return a[0:w] return a[0:w]
} }
func filterFuncDocs(a []*FuncDoc, f Filter) []*FuncDoc { func filterFuncs(a []*Func, f Filter) []*Func {
w := 0 w := 0
for _, fd := range a { for _, fd := range a {
if f(fd.Name) { if f(fd.Name) {
@ -71,7 +71,18 @@ func filterFuncDocs(a []*FuncDoc, f Filter) []*FuncDoc {
return a[0:w] return a[0:w]
} }
func filterTypeDocs(a []*TypeDoc, f Filter) []*TypeDoc { func filterMethods(a []*Method, f Filter) []*Method {
w := 0
for _, md := range a {
if f(md.Name) {
a[w] = md
w++
}
}
return a[0:w]
}
func filterTypes(a []*Type, f Filter) []*Type {
w := 0 w := 0
for _, td := range a { for _, td := range a {
n := 0 // number of matches n := 0 // number of matches
@ -79,11 +90,11 @@ func filterTypeDocs(a []*TypeDoc, f Filter) []*TypeDoc {
n = 1 n = 1
} else { } else {
// type name doesn't match, but we may have matching consts, vars, factories or methods // type name doesn't match, but we may have matching consts, vars, factories or methods
td.Consts = filterValueDocs(td.Consts, f) td.Consts = filterValues(td.Consts, f)
td.Vars = filterValueDocs(td.Vars, f) td.Vars = filterValues(td.Vars, f)
td.Factories = filterFuncDocs(td.Factories, f) td.Funcs = filterFuncs(td.Funcs, f)
td.Methods = filterFuncDocs(td.Methods, f) td.Methods = filterMethods(td.Methods, f)
n += len(td.Consts) + len(td.Vars) + len(td.Factories) + len(td.Methods) n += len(td.Consts) + len(td.Vars) + len(td.Funcs) + len(td.Methods)
} }
if n > 0 { if n > 0 {
a[w] = td a[w] = td
@ -96,10 +107,10 @@ func filterTypeDocs(a []*TypeDoc, f Filter) []*TypeDoc {
// Filter eliminates documentation for names that don't pass through the filter f. // Filter eliminates documentation for names that don't pass through the filter f.
// TODO: Recognize "Type.Method" as a name. // TODO: Recognize "Type.Method" as a name.
// //
func (p *PackageDoc) Filter(f Filter) { func (p *Package) Filter(f Filter) {
p.Consts = filterValueDocs(p.Consts, f) p.Consts = filterValues(p.Consts, f)
p.Vars = filterValueDocs(p.Vars, f) p.Vars = filterValues(p.Vars, f)
p.Types = filterTypeDocs(p.Types, f) p.Types = filterTypes(p.Types, f)
p.Funcs = filterFuncDocs(p.Funcs, f) p.Funcs = filterFuncs(p.Funcs, f)
p.Doc = "" // don't show top-level package doc p.Doc = "" // don't show top-level package doc
} }

Some files were not shown because too many files have changed in this diff.