Update Go library to r60.

From-SVN: r178910
Ian Lance Taylor 2011-09-16 15:47:21 +00:00
parent 5548ca3540
commit adb0401dac
718 changed files with 58911 additions and 30469 deletions


@ -806,7 +806,7 @@ proc go-gc-tests { } {
$status $name $status $name
} else { } else {
verbose -log $comp_output verbose -log $comp_output
fali $name fail $name
} }
file delete $ofile1 $ofile2 $output_file file delete $ofile1 $ofile2 $output_file
set runtests $hold_runtests set runtests $hold_runtests


@ -54,8 +54,8 @@ func run(t *template.Template, a interface{}, out io.Writer) {
} }
} }
type arg struct{ type arg struct {
def bool def bool
nreset int nreset int
} }
@ -135,181 +135,180 @@ func main() {
} }
` `
func parse(s string) *template.Template { func parse(name, s string) *template.Template {
t := template.New(nil) t, err := template.New(name).Parse(s)
t.SetDelims("〈", "〉") if err != nil {
if err := t.Parse(s); err != nil { panic(fmt.Sprintf("%q: %s", name, err))
panic(s)
} }
return t return t
} }
var recv = parse(` var recv = parse("recv", `
# Send n, receive it one way or another into x, check that they match. {{/* Send n, receive it one way or another into x, check that they match. */}}
c <- n c <- n
.section Maybe {{if .Maybe}}
x = <-c x = <-c
.or {{else}}
select { select {
# Blocking or non-blocking, before the receive. {{/* Blocking or non-blocking, before the receive. */}}
# The compiler implements two-case select where one is default with custom code, {{/* The compiler implements two-case select where one is default with custom code, */}}
# so test the default branch both before and after the send. {{/* so test the default branch both before and after the send. */}}
.section MaybeDefault {{if .MaybeDefault}}
default: default:
panic("nonblock") panic("nonblock")
.end {{end}}
# Receive from c. Different cases are direct, indirect, :=, interface, and map assignment. {{/* Receive from c. Different cases are direct, indirect, :=, interface, and map assignment. */}}
.section Maybe {{if .Maybe}}
case x = <-c: case x = <-c:
.or.section Maybe {{else}}{{if .Maybe}}
case *f(&x) = <-c: case *f(&x) = <-c:
.or.section Maybe {{else}}{{if .Maybe}}
case y := <-c: case y := <-c:
x = y x = y
.or.section Maybe {{else}}{{if .Maybe}}
case i = <-c: case i = <-c:
x = i.(int) x = i.(int)
.or {{else}}
case m[13] = <-c: case m[13] = <-c:
x = m[13] x = m[13]
.end.end.end.end {{end}}{{end}}{{end}}{{end}}
# Blocking or non-blocking again, after the receive. {{/* Blocking or non-blocking again, after the receive. */}}
.section MaybeDefault {{if .MaybeDefault}}
default: default:
panic("nonblock") panic("nonblock")
.end {{end}}
# Dummy send, receive to keep compiler from optimizing select. {{/* Dummy send, receive to keep compiler from optimizing select. */}}
.section Maybe {{if .Maybe}}
case dummy <- 1: case dummy <- 1:
panic("dummy send") panic("dummy send")
.end {{end}}
.section Maybe {{if .Maybe}}
case <-dummy: case <-dummy:
panic("dummy receive") panic("dummy receive")
.end {{end}}
# Nil channel send, receive to keep compiler from optimizing select. {{/* Nil channel send, receive to keep compiler from optimizing select. */}}
.section Maybe {{if .Maybe}}
case nilch <- 1: case nilch <- 1:
panic("nilch send") panic("nilch send")
.end {{end}}
.section Maybe {{if .Maybe}}
case <-nilch: case <-nilch:
panic("nilch recv") panic("nilch recv")
.end {{end}}
} }
.end {{end}}
if x != n { if x != n {
die(x) die(x)
} }
n++ n++
`) `)
var recvOrder = parse(` var recvOrder = parse("recvOrder", `
# Send n, receive it one way or another into x, check that they match. {{/* Send n, receive it one way or another into x, check that they match. */}}
# Check order of operations along the way by calling functions that check {{/* Check order of operations along the way by calling functions that check */}}
# that the argument sequence is strictly increasing. {{/* that the argument sequence is strictly increasing. */}}
order = 0 order = 0
c <- n c <- n
.section Maybe {{if .Maybe}}
# Outside of select, left-to-right rule applies. {{/* Outside of select, left-to-right rule applies. */}}
# (Inside select, assignment waits until case is chosen, {{/* (Inside select, assignment waits until case is chosen, */}}
# so right hand side happens before anything on left hand side. {{/* so right hand side happens before anything on left hand side. */}}
*fp(&x, 1) = <-fc(c, 2) *fp(&x, 1) = <-fc(c, 2)
.or.section Maybe {{else}}{{if .Maybe}}
m[fn(13, 1)] = <-fc(c, 2) m[fn(13, 1)] = <-fc(c, 2)
x = m[13] x = m[13]
.or {{else}}
select { select {
# Blocking or non-blocking, before the receive. {{/* Blocking or non-blocking, before the receive. */}}
# The compiler implements two-case select where one is default with custom code, {{/* The compiler implements two-case select where one is default with custom code, */}}
# so test the default branch both before and after the send. {{/* so test the default branch both before and after the send. */}}
.section MaybeDefault {{if .MaybeDefault}}
default: default:
panic("nonblock") panic("nonblock")
.end {{end}}
# Receive from c. Different cases are direct, indirect, :=, interface, and map assignment. {{/* Receive from c. Different cases are direct, indirect, :=, interface, and map assignment. */}}
.section Maybe {{if .Maybe}}
case *fp(&x, 100) = <-fc(c, 1): case *fp(&x, 100) = <-fc(c, 1):
.or.section Maybe {{else}}{{if .Maybe}}
case y := <-fc(c, 1): case y := <-fc(c, 1):
x = y x = y
.or.section Maybe {{else}}{{if .Maybe}}
case i = <-fc(c, 1): case i = <-fc(c, 1):
x = i.(int) x = i.(int)
.or {{else}}
case m[fn(13, 100)] = <-fc(c, 1): case m[fn(13, 100)] = <-fc(c, 1):
x = m[13] x = m[13]
.end.end.end {{end}}{{end}}{{end}}
# Blocking or non-blocking again, after the receive. {{/* Blocking or non-blocking again, after the receive. */}}
.section MaybeDefault {{if .MaybeDefault}}
default: default:
panic("nonblock") panic("nonblock")
.end {{end}}
# Dummy send, receive to keep compiler from optimizing select. {{/* Dummy send, receive to keep compiler from optimizing select. */}}
.section Maybe {{if .Maybe}}
case fc(dummy, 2) <- fn(1, 3): case fc(dummy, 2) <- fn(1, 3):
panic("dummy send") panic("dummy send")
.end {{end}}
.section Maybe {{if .Maybe}}
case <-fc(dummy, 4): case <-fc(dummy, 4):
panic("dummy receive") panic("dummy receive")
.end {{end}}
# Nil channel send, receive to keep compiler from optimizing select. {{/* Nil channel send, receive to keep compiler from optimizing select. */}}
.section Maybe {{if .Maybe}}
case fc(nilch, 5) <- fn(1, 6): case fc(nilch, 5) <- fn(1, 6):
panic("nilch send") panic("nilch send")
.end {{end}}
.section Maybe {{if .Maybe}}
case <-fc(nilch, 7): case <-fc(nilch, 7):
panic("nilch recv") panic("nilch recv")
.end {{end}}
} }
.end.end {{end}}{{end}}
if x != n { if x != n {
die(x) die(x)
} }
n++ n++
`) `)
var send = parse(` var send = parse("send", `
# Send n one way or another, receive it into x, check that they match. {{/* Send n one way or another, receive it into x, check that they match. */}}
.section Maybe {{if .Maybe}}
c <- n c <- n
.or {{else}}
select { select {
# Blocking or non-blocking, before the receive (same reason as in recv). {{/* Blocking or non-blocking, before the receive (same reason as in recv). */}}
.section MaybeDefault {{if .MaybeDefault}}
default: default:
panic("nonblock") panic("nonblock")
.end {{end}}
# Send c <- n. No real special cases here, because no values come back {{/* Send c <- n. No real special cases here, because no values come back */}}
# from the send operation. {{/* from the send operation. */}}
case c <- n: case c <- n:
# Blocking or non-blocking. {{/* Blocking or non-blocking. */}}
.section MaybeDefault {{if .MaybeDefault}}
default: default:
panic("nonblock") panic("nonblock")
.end {{end}}
# Dummy send, receive to keep compiler from optimizing select. {{/* Dummy send, receive to keep compiler from optimizing select. */}}
.section Maybe {{if .Maybe}}
case dummy <- 1: case dummy <- 1:
panic("dummy send") panic("dummy send")
.end {{end}}
.section Maybe {{if .Maybe}}
case <-dummy: case <-dummy:
panic("dummy receive") panic("dummy receive")
.end {{end}}
# Nil channel send, receive to keep compiler from optimizing select. {{/* Nil channel send, receive to keep compiler from optimizing select. */}}
.section Maybe {{if .Maybe}}
case nilch <- 1: case nilch <- 1:
panic("nilch send") panic("nilch send")
.end {{end}}
.section Maybe {{if .Maybe}}
case <-nilch: case <-nilch:
panic("nilch recv") panic("nilch recv")
.end {{end}}
} }
.end {{end}}
x = <-c x = <-c
if x != n { if x != n {
die(x) die(x)
@ -317,48 +316,48 @@ var send = parse(`
n++ n++
`) `)
var sendOrder = parse(` var sendOrder = parse("sendOrder", `
# Send n one way or another, receive it into x, check that they match. {{/* Send n one way or another, receive it into x, check that they match. */}}
# Check order of operations along the way by calling functions that check {{/* Check order of operations along the way by calling functions that check */}}
# that the argument sequence is strictly increasing. {{/* that the argument sequence is strictly increasing. */}}
order = 0 order = 0
.section Maybe {{if .Maybe}}
fc(c, 1) <- fn(n, 2) fc(c, 1) <- fn(n, 2)
.or {{else}}
select { select {
# Blocking or non-blocking, before the receive (same reason as in recv). {{/* Blocking or non-blocking, before the receive (same reason as in recv). */}}
.section MaybeDefault {{if .MaybeDefault}}
default: default:
panic("nonblock") panic("nonblock")
.end {{end}}
# Send c <- n. No real special cases here, because no values come back {{/* Send c <- n. No real special cases here, because no values come back */}}
# from the send operation. {{/* from the send operation. */}}
case fc(c, 1) <- fn(n, 2): case fc(c, 1) <- fn(n, 2):
# Blocking or non-blocking. {{/* Blocking or non-blocking. */}}
.section MaybeDefault {{if .MaybeDefault}}
default: default:
panic("nonblock") panic("nonblock")
.end {{end}}
# Dummy send, receive to keep compiler from optimizing select. {{/* Dummy send, receive to keep compiler from optimizing select. */}}
.section Maybe {{if .Maybe}}
case fc(dummy, 3) <- fn(1, 4): case fc(dummy, 3) <- fn(1, 4):
panic("dummy send") panic("dummy send")
.end {{end}}
.section Maybe {{if .Maybe}}
case <-fc(dummy, 5): case <-fc(dummy, 5):
panic("dummy receive") panic("dummy receive")
.end {{end}}
# Nil channel send, receive to keep compiler from optimizing select. {{/* Nil channel send, receive to keep compiler from optimizing select. */}}
.section Maybe {{if .Maybe}}
case fc(nilch, 6) <- fn(1, 7): case fc(nilch, 6) <- fn(1, 7):
panic("nilch send") panic("nilch send")
.end {{end}}
.section Maybe {{if .Maybe}}
case <-fc(nilch, 8): case <-fc(nilch, 8):
panic("nilch recv") panic("nilch recv")
.end {{end}}
} }
.end {{end}}
x = <-c x = <-c
if x != n { if x != n {
die(x) die(x)
@ -366,49 +365,49 @@ var sendOrder = parse(`
n++ n++
`) `)
var nonblock = parse(` var nonblock = parse("nonblock", `
x = n x = n
# Test various combinations of non-blocking operations. {{/* Test various combinations of non-blocking operations. */}}
# Receive assignments must not edit or even attempt to compute the address of the lhs. {{/* Receive assignments must not edit or even attempt to compute the address of the lhs. */}}
select { select {
.section MaybeDefault {{if .MaybeDefault}}
default: default:
.end {{end}}
.section Maybe {{if .Maybe}}
case dummy <- 1: case dummy <- 1:
panic("dummy <- 1") panic("dummy <- 1")
.end {{end}}
.section Maybe {{if .Maybe}}
case nilch <- 1: case nilch <- 1:
panic("nilch <- 1") panic("nilch <- 1")
.end {{end}}
.section Maybe {{if .Maybe}}
case <-dummy: case <-dummy:
panic("<-dummy") panic("<-dummy")
.end {{end}}
.section Maybe {{if .Maybe}}
case x = <-dummy: case x = <-dummy:
panic("<-dummy x") panic("<-dummy x")
.end {{end}}
.section Maybe {{if .Maybe}}
case **(**int)(nil) = <-dummy: case **(**int)(nil) = <-dummy:
panic("<-dummy (and didn't crash saving result!)") panic("<-dummy (and didn't crash saving result!)")
.end {{end}}
.section Maybe {{if .Maybe}}
case <-nilch: case <-nilch:
panic("<-nilch") panic("<-nilch")
.end {{end}}
.section Maybe {{if .Maybe}}
case x = <-nilch: case x = <-nilch:
panic("<-nilch x") panic("<-nilch x")
.end {{end}}
.section Maybe {{if .Maybe}}
case **(**int)(nil) = <-nilch: case **(**int)(nil) = <-nilch:
panic("<-nilch (and didn't crash saving result!)") panic("<-nilch (and didn't crash saving result!)")
.end {{end}}
.section MustDefault {{if .MustDefault}}
default: default:
.end {{end}}
} }
if x != n { if x != n {
die(x) die(x)
@ -466,7 +465,7 @@ func next() bool {
} }
// increment last choice sequence // increment last choice sequence
cp = len(choices)-1 cp = len(choices) - 1
for cp >= 0 && choices[cp].i == choices[cp].n-1 { for cp >= 0 && choices[cp].i == choices[cp].n-1 {
cp-- cp--
} }
@ -479,4 +478,3 @@ func next() bool {
cp = 0 cp = 0
return true return true
} }


@ -38,7 +38,7 @@ func Listen(x, y string) (T, string) {
} }
func (t T) Addr() os.Error { func (t T) Addr() os.Error {
return os.ErrorString("stringer") return os.NewError("stringer")
} }
func (t T) Accept() (int, string) { func (t T) Accept() (int, string) {
@ -49,4 +49,3 @@ func Dial(x, y, z string) (int, string) {
global <- 1 global <- 1
return 0, "" return 0, ""
} }


@ -18,6 +18,7 @@ var chatty = flag.Bool("v", false, "chatty")
var oldsys uint64 var oldsys uint64
func bigger() { func bigger() {
runtime.UpdateMemStats()
if st := runtime.MemStats; oldsys < st.Sys { if st := runtime.MemStats; oldsys < st.Sys {
oldsys = st.Sys oldsys = st.Sys
if *chatty { if *chatty {
@ -31,7 +32,7 @@ func bigger() {
} }
func main() { func main() {
runtime.GC() // clean up garbage from init runtime.GC() // clean up garbage from init
runtime.MemProfileRate = 0 // disable profiler runtime.MemProfileRate = 0 // disable profiler
runtime.MemStats.Alloc = 0 // ignore stacks runtime.MemStats.Alloc = 0 // ignore stacks
flag.Parse() flag.Parse()
@ -45,8 +46,10 @@ func main() {
panic("fail") panic("fail")
} }
b := runtime.Alloc(uintptr(j)) b := runtime.Alloc(uintptr(j))
runtime.UpdateMemStats()
during := runtime.MemStats.Alloc during := runtime.MemStats.Alloc
runtime.Free(b) runtime.Free(b)
runtime.UpdateMemStats()
if a := runtime.MemStats.Alloc; a != 0 { if a := runtime.MemStats.Alloc; a != 0 {
println("allocated ", j, ": wrong stats: during=", during, " after=", a, " (want 0)") println("allocated ", j, ": wrong stats: during=", during, " after=", a, " (want 0)")
panic("fail") panic("fail")


@ -42,6 +42,7 @@ func AllocAndFree(size, count int) {
if *chatty { if *chatty {
fmt.Printf("size=%d count=%d ...\n", size, count) fmt.Printf("size=%d count=%d ...\n", size, count)
} }
runtime.UpdateMemStats()
n1 := stats.Alloc n1 := stats.Alloc
for i := 0; i < count; i++ { for i := 0; i < count; i++ {
b[i] = runtime.Alloc(uintptr(size)) b[i] = runtime.Alloc(uintptr(size))
@ -50,11 +51,13 @@ func AllocAndFree(size, count int) {
println("lookup failed: got", base, n, "for", b[i]) println("lookup failed: got", base, n, "for", b[i])
panic("fail") panic("fail")
} }
if runtime.MemStats.Sys > 1e9 { runtime.UpdateMemStats()
if stats.Sys > 1e9 {
println("too much memory allocated") println("too much memory allocated")
panic("fail") panic("fail")
} }
} }
runtime.UpdateMemStats()
n2 := stats.Alloc n2 := stats.Alloc
if *chatty { if *chatty {
fmt.Printf("size=%d count=%d stats=%+v\n", size, count, *stats) fmt.Printf("size=%d count=%d stats=%+v\n", size, count, *stats)
@ -72,6 +75,7 @@ func AllocAndFree(size, count int) {
panic("fail") panic("fail")
} }
runtime.Free(b[i]) runtime.Free(b[i])
runtime.UpdateMemStats()
if stats.Alloc != uint64(alloc-n) { if stats.Alloc != uint64(alloc-n) {
println("free alloc got", stats.Alloc, "expected", alloc-n, "after free of", n) println("free alloc got", stats.Alloc, "expected", alloc-n, "after free of", n)
panic("fail") panic("fail")
@ -81,6 +85,7 @@ func AllocAndFree(size, count int) {
panic("fail") panic("fail")
} }
} }
runtime.UpdateMemStats()
n4 := stats.Alloc n4 := stats.Alloc
if *chatty { if *chatty {


@ -1,4 +1,4 @@
aea0ba6e5935 504f4e9b079c
The first line of this file holds the Mercurial revision number of the The first line of this file holds the Mercurial revision number of the
last merge done from the master library sources. last merge done from the master library sources.

File diff suppressed because it is too large

File diff suppressed because it is too large


@ -12,12 +12,24 @@
/* Define to 1 if you have the <inttypes.h> header file. */ /* Define to 1 if you have the <inttypes.h> header file. */
#undef HAVE_INTTYPES_H #undef HAVE_INTTYPES_H
/* Define to 1 if you have the <linux/filter.h> header file. */
#undef HAVE_LINUX_FILTER_H
/* Define to 1 if you have the <linux/netlink.h> header file. */
#undef HAVE_LINUX_NETLINK_H
/* Define to 1 if you have the <linux/rtnetlink.h> header file. */
#undef HAVE_LINUX_RTNETLINK_H
/* Define to 1 if you have the <memory.h> header file. */ /* Define to 1 if you have the <memory.h> header file. */
#undef HAVE_MEMORY_H #undef HAVE_MEMORY_H
/* Define to 1 if you have the `mincore' function. */ /* Define to 1 if you have the `mincore' function. */
#undef HAVE_MINCORE #undef HAVE_MINCORE
/* Define to 1 if you have the <net/if.h> header file. */
#undef HAVE_NET_IF_H
/* Define to 1 if the system has the type `off64_t'. */ /* Define to 1 if the system has the type `off64_t'. */
#undef HAVE_OFF64_T #undef HAVE_OFF64_T
@ -71,6 +83,9 @@
/* Define to 1 if you have the <sys/select.h> header file. */ /* Define to 1 if you have the <sys/select.h> header file. */
#undef HAVE_SYS_SELECT_H #undef HAVE_SYS_SELECT_H
/* Define to 1 if you have the <sys/socket.h> header file. */
#undef HAVE_SYS_SOCKET_H
/* Define to 1 if you have the <sys/stat.h> header file. */ /* Define to 1 if you have the <sys/stat.h> header file. */
#undef HAVE_SYS_STAT_H #undef HAVE_SYS_STAT_H

libgo/configure

@ -617,7 +617,6 @@ USING_SPLIT_STACK_FALSE
USING_SPLIT_STACK_TRUE USING_SPLIT_STACK_TRUE
SPLIT_STACK SPLIT_STACK
OSCFLAGS OSCFLAGS
GO_DEBUG_PROC_REGS_OS_ARCH_FILE
GO_SYSCALLS_SYSCALL_OS_ARCH_FILE GO_SYSCALLS_SYSCALL_OS_ARCH_FILE
GOARCH GOARCH
LIBGO_IS_X86_64_FALSE LIBGO_IS_X86_64_FALSE
@ -10914,7 +10913,7 @@ else
lt_dlunknown=0; lt_dlno_uscore=1; lt_dlneed_uscore=2 lt_dlunknown=0; lt_dlno_uscore=1; lt_dlneed_uscore=2
lt_status=$lt_dlunknown lt_status=$lt_dlunknown
cat > conftest.$ac_ext <<_LT_EOF cat > conftest.$ac_ext <<_LT_EOF
#line 10917 "configure" #line 10916 "configure"
#include "confdefs.h" #include "confdefs.h"
#if HAVE_DLFCN_H #if HAVE_DLFCN_H
@ -11020,7 +11019,7 @@ else
lt_dlunknown=0; lt_dlno_uscore=1; lt_dlneed_uscore=2 lt_dlunknown=0; lt_dlno_uscore=1; lt_dlneed_uscore=2
lt_status=$lt_dlunknown lt_status=$lt_dlunknown
cat > conftest.$ac_ext <<_LT_EOF cat > conftest.$ac_ext <<_LT_EOF
#line 11023 "configure" #line 11022 "configure"
#include "confdefs.h" #include "confdefs.h"
#if HAVE_DLFCN_H #if HAVE_DLFCN_H
@ -13558,12 +13557,6 @@ if test -f ${srcdir}/syscalls/syscall_${GOOS}_${GOARCH}.go; then
fi fi
GO_DEBUG_PROC_REGS_OS_ARCH_FILE=
if test -f ${srcdir}/go/debug/proc/regs_${GOOS}_${GOARCH}.go; then
GO_DEBUG_PROC_REGS_OS_ARCH_FILE=go/debug/proc/regs_${GOOS}_${GOARCH}.go
fi
case "$target" in case "$target" in
mips-sgi-irix6.5*) mips-sgi-irix6.5*)
# IRIX 6 needs _XOPEN_SOURCE=500 for the XPG5 version of struct # IRIX 6 needs _XOPEN_SOURCE=500 for the XPG5 version of struct
@ -14252,7 +14245,7 @@ no)
;; ;;
esac esac
for ac_header in sys/mman.h syscall.h sys/epoll.h sys/ptrace.h sys/syscall.h sys/user.h sys/utsname.h sys/select.h for ac_header in sys/mman.h syscall.h sys/epoll.h sys/ptrace.h sys/syscall.h sys/user.h sys/utsname.h sys/select.h sys/socket.h net/if.h
do : do :
as_ac_Header=`$as_echo "ac_cv_header_$ac_header" | $as_tr_sh` as_ac_Header=`$as_echo "ac_cv_header_$ac_header" | $as_tr_sh`
ac_fn_c_check_header_mongrel "$LINENO" "$ac_header" "$as_ac_Header" "$ac_includes_default" ac_fn_c_check_header_mongrel "$LINENO" "$ac_header" "$as_ac_Header" "$ac_includes_default"
@ -14266,6 +14259,26 @@ fi
done done
for ac_header in linux/filter.h linux/netlink.h linux/rtnetlink.h
do :
as_ac_Header=`$as_echo "ac_cv_header_$ac_header" | $as_tr_sh`
ac_fn_c_check_header_compile "$LINENO" "$ac_header" "$as_ac_Header" "#ifdef HAVE_SYS_SOCKET_H
#include <sys/socket.h>
#endif
"
eval as_val=\$$as_ac_Header
if test "x$as_val" = x""yes; then :
cat >>confdefs.h <<_ACEOF
#define `$as_echo "HAVE_$ac_header" | $as_tr_cpp` 1
_ACEOF
fi
done
if test "$ac_cv_header_sys_mman_h" = yes; then if test "$ac_cv_header_sys_mman_h" = yes; then
HAVE_SYS_MMAN_H_TRUE= HAVE_SYS_MMAN_H_TRUE=
HAVE_SYS_MMAN_H_FALSE='#' HAVE_SYS_MMAN_H_FALSE='#'


@ -255,12 +255,6 @@ if test -f ${srcdir}/syscalls/syscall_${GOOS}_${GOARCH}.go; then
fi fi
AC_SUBST(GO_SYSCALLS_SYSCALL_OS_ARCH_FILE) AC_SUBST(GO_SYSCALLS_SYSCALL_OS_ARCH_FILE)
GO_DEBUG_PROC_REGS_OS_ARCH_FILE=
if test -f ${srcdir}/go/debug/proc/regs_${GOOS}_${GOARCH}.go; then
GO_DEBUG_PROC_REGS_OS_ARCH_FILE=go/debug/proc/regs_${GOOS}_${GOARCH}.go
fi
AC_SUBST(GO_DEBUG_PROC_REGS_OS_ARCH_FILE)
dnl Some targets need special flags to build sysinfo.go. dnl Some targets need special flags to build sysinfo.go.
case "$target" in case "$target" in
mips-sgi-irix6.5*) mips-sgi-irix6.5*)
@ -431,7 +425,14 @@ no)
;; ;;
esac esac
AC_CHECK_HEADERS(sys/mman.h syscall.h sys/epoll.h sys/ptrace.h sys/syscall.h sys/user.h sys/utsname.h sys/select.h) AC_CHECK_HEADERS(sys/mman.h syscall.h sys/epoll.h sys/ptrace.h sys/syscall.h sys/user.h sys/utsname.h sys/select.h sys/socket.h net/if.h)
AC_CHECK_HEADERS([linux/filter.h linux/netlink.h linux/rtnetlink.h], [], [],
[#ifdef HAVE_SYS_SOCKET_H
#include <sys/socket.h>
#endif
])
AM_CONDITIONAL(HAVE_SYS_MMAN_H, test "$ac_cv_header_sys_mman_h" = yes) AM_CONDITIONAL(HAVE_SYS_MMAN_H, test "$ac_cv_header_sys_mman_h" = yes)
AC_CHECK_FUNCS(srandom random strerror_r strsignal wait4 mincore setenv) AC_CHECK_FUNCS(srandom random strerror_r strsignal wait4 mincore setenv)


@ -16,7 +16,7 @@ import (
) )
var ( var (
HeaderError os.Error = os.ErrorString("invalid tar header") HeaderError = os.NewError("invalid tar header")
) )
// A Reader provides sequential access to the contents of a tar archive. // A Reader provides sequential access to the contents of a tar archive.


@ -178,7 +178,6 @@ func TestPartialRead(t *testing.T) {
} }
} }
func TestIncrementalRead(t *testing.T) { func TestIncrementalRead(t *testing.T) {
test := gnuTarTest test := gnuTarTest
f, err := os.Open(test.file) f, err := os.Open(test.file)


@ -2,18 +2,10 @@
// Use of this source code is governed by a BSD-style // Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file. // license that can be found in the LICENSE file.
/*
Package zip provides support for reading ZIP archives.
See: http://www.pkware.com/documents/casestudies/APPNOTE.TXT
This package does not support ZIP64 or disk spanning.
*/
package zip package zip
import ( import (
"bufio" "bufio"
"bytes"
"compress/flate" "compress/flate"
"hash" "hash"
"hash/crc32" "hash/crc32"
@ -24,9 +16,9 @@ import (
) )
var ( var (
FormatError = os.NewError("not a valid zip file") FormatError = os.NewError("zip: not a valid zip file")
UnsupportedMethod = os.NewError("unsupported compression algorithm") UnsupportedMethod = os.NewError("zip: unsupported compression algorithm")
ChecksumError = os.NewError("checksum error") ChecksumError = os.NewError("zip: checksum error")
) )
type Reader struct { type Reader struct {
@ -44,15 +36,14 @@ type File struct {
FileHeader FileHeader
zipr io.ReaderAt zipr io.ReaderAt
zipsize int64 zipsize int64
headerOffset uint32 headerOffset int64
bodyOffset int64
} }
func (f *File) hasDataDescriptor() bool { func (f *File) hasDataDescriptor() bool {
return f.Flags&0x8 != 0 return f.Flags&0x8 != 0
} }
// OpenReader will open the Zip file specified by name and return a ReaderCloser. // OpenReader will open the Zip file specified by name and return a ReadCloser.
func OpenReader(name string) (*ReadCloser, os.Error) { func OpenReader(name string) (*ReadCloser, os.Error) {
f, err := os.Open(name) f, err := os.Open(name)
if err != nil { if err != nil {
@ -87,18 +78,33 @@ func (z *Reader) init(r io.ReaderAt, size int64) os.Error {
return err return err
} }
z.r = r z.r = r
z.File = make([]*File, end.directoryRecords) z.File = make([]*File, 0, end.directoryRecords)
z.Comment = end.comment z.Comment = end.comment
rs := io.NewSectionReader(r, 0, size) rs := io.NewSectionReader(r, 0, size)
if _, err = rs.Seek(int64(end.directoryOffset), os.SEEK_SET); err != nil { if _, err = rs.Seek(int64(end.directoryOffset), os.SEEK_SET); err != nil {
return err return err
} }
buf := bufio.NewReader(rs) buf := bufio.NewReader(rs)
for i := range z.File {
z.File[i] = &File{zipr: r, zipsize: size} // The count of files inside a zip is truncated to fit in a uint16.
if err := readDirectoryHeader(z.File[i], buf); err != nil { // Gloss over this by reading headers until we encounter
// a bad one, and then only report a FormatError or UnexpectedEOF if
// the file count modulo 65536 is incorrect.
for {
f := &File{zipr: r, zipsize: size}
err = readDirectoryHeader(f, buf)
if err == FormatError || err == io.ErrUnexpectedEOF {
break
}
if err != nil {
return err return err
} }
z.File = append(z.File, f)
}
if uint16(len(z.File)) != end.directoryRecords {
// Return the readDirectoryHeader error if we read
// the wrong number of directory entries.
return err
} }
return nil return nil
} }
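As a side note (not part of the commit), the reason the new reader code compares the count modulo 65536 is that the zip end-of-central-directory record stores the entry count in a 16-bit field, so archives with more than 65535 entries wrap around; this is the situation exercised by TestOver65kFiles later in this change. A trivial standalone sketch:

package main

import "fmt"

func main() {
	// An archive with 65578 entries reports only the low 16 bits of the
	// count, so the reader keeps appending entries until a directory
	// header fails to parse and only then checks the truncated count.
	nFiles := 1<<16 + 42
	fmt.Println(uint16(nFiles)) // prints 42, not 65578
}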
@ -109,31 +115,22 @@ func (rc *ReadCloser) Close() os.Error {
} }
// Open returns a ReadCloser that provides access to the File's contents. // Open returns a ReadCloser that provides access to the File's contents.
// It is safe to Open and Read from files concurrently.
func (f *File) Open() (rc io.ReadCloser, err os.Error) { func (f *File) Open() (rc io.ReadCloser, err os.Error) {
off := int64(f.headerOffset) bodyOffset, err := f.findBodyOffset()
if f.bodyOffset == 0 { if err != nil {
r := io.NewSectionReader(f.zipr, off, f.zipsize-off) return
if err = readFileHeader(f, r); err != nil {
return
}
if f.bodyOffset, err = r.Seek(0, os.SEEK_CUR); err != nil {
return
}
} }
size := int64(f.CompressedSize) size := int64(f.CompressedSize)
if f.hasDataDescriptor() { if size == 0 && f.hasDataDescriptor() {
if size == 0 { // permit SectionReader to see the rest of the file
// permit SectionReader to see the rest of the file size = f.zipsize - (f.headerOffset + bodyOffset)
size = f.zipsize - (off + f.bodyOffset)
} else {
size += dataDescriptorLen
}
} }
r := io.NewSectionReader(f.zipr, off+f.bodyOffset, size) r := io.NewSectionReader(f.zipr, f.headerOffset+bodyOffset, size)
switch f.Method { switch f.Method {
case 0: // store (no compression) case Store: // (no compression)
rc = ioutil.NopCloser(r) rc = ioutil.NopCloser(r)
case 8: // DEFLATE case Deflate:
rc = flate.NewReader(r) rc = flate.NewReader(r)
default: default:
err = UnsupportedMethod err = UnsupportedMethod
@ -170,90 +167,102 @@ func (r *checksumReader) Read(b []byte) (n int, err os.Error) {
func (r *checksumReader) Close() os.Error { return r.rc.Close() } func (r *checksumReader) Close() os.Error { return r.rc.Close() }
func readFileHeader(f *File, r io.Reader) (err os.Error) { func readFileHeader(f *File, r io.Reader) os.Error {
defer func() { var b [fileHeaderLen]byte
if rerr, ok := recover().(os.Error); ok { if _, err := io.ReadFull(r, b[:]); err != nil {
err = rerr return err
} }
}() c := binary.LittleEndian
var ( if sig := c.Uint32(b[:4]); sig != fileHeaderSignature {
signature uint32
filenameLength uint16
extraLength uint16
)
read(r, &signature)
if signature != fileHeaderSignature {
return FormatError return FormatError
} }
read(r, &f.ReaderVersion) f.ReaderVersion = c.Uint16(b[4:6])
read(r, &f.Flags) f.Flags = c.Uint16(b[6:8])
read(r, &f.Method) f.Method = c.Uint16(b[8:10])
read(r, &f.ModifiedTime) f.ModifiedTime = c.Uint16(b[10:12])
read(r, &f.ModifiedDate) f.ModifiedDate = c.Uint16(b[12:14])
read(r, &f.CRC32) f.CRC32 = c.Uint32(b[14:18])
read(r, &f.CompressedSize) f.CompressedSize = c.Uint32(b[18:22])
read(r, &f.UncompressedSize) f.UncompressedSize = c.Uint32(b[22:26])
read(r, &filenameLength) filenameLen := int(c.Uint16(b[26:28]))
read(r, &extraLength) extraLen := int(c.Uint16(b[28:30]))
f.Name = string(readByteSlice(r, filenameLength)) d := make([]byte, filenameLen+extraLen)
f.Extra = readByteSlice(r, extraLength) if _, err := io.ReadFull(r, d); err != nil {
return return err
}
f.Name = string(d[:filenameLen])
f.Extra = d[filenameLen:]
return nil
} }
func readDirectoryHeader(f *File, r io.Reader) (err os.Error) { // findBodyOffset does the minimum work to verify the file has a header
defer func() { // and returns the file body offset.
if rerr, ok := recover().(os.Error); ok { func (f *File) findBodyOffset() (int64, os.Error) {
err = rerr r := io.NewSectionReader(f.zipr, f.headerOffset, f.zipsize-f.headerOffset)
} var b [fileHeaderLen]byte
}() if _, err := io.ReadFull(r, b[:]); err != nil {
var ( return 0, err
signature uint32 }
filenameLength uint16 c := binary.LittleEndian
extraLength uint16 if sig := c.Uint32(b[:4]); sig != fileHeaderSignature {
commentLength uint16 return 0, FormatError
startDiskNumber uint16 // unused }
internalAttributes uint16 // unused filenameLen := int(c.Uint16(b[26:28]))
externalAttributes uint32 // unused extraLen := int(c.Uint16(b[28:30]))
) return int64(fileHeaderLen + filenameLen + extraLen), nil
read(r, &signature) }
if signature != directoryHeaderSignature {
// readDirectoryHeader attempts to read a directory header from r.
// It returns io.ErrUnexpectedEOF if it cannot read a complete header,
// and FormatError if it doesn't find a valid header signature.
func readDirectoryHeader(f *File, r io.Reader) os.Error {
var b [directoryHeaderLen]byte
if _, err := io.ReadFull(r, b[:]); err != nil {
return err
}
c := binary.LittleEndian
if sig := c.Uint32(b[:4]); sig != directoryHeaderSignature {
return FormatError return FormatError
} }
read(r, &f.CreatorVersion) f.CreatorVersion = c.Uint16(b[4:6])
read(r, &f.ReaderVersion) f.ReaderVersion = c.Uint16(b[6:8])
read(r, &f.Flags) f.Flags = c.Uint16(b[8:10])
read(r, &f.Method) f.Method = c.Uint16(b[10:12])
read(r, &f.ModifiedTime) f.ModifiedTime = c.Uint16(b[12:14])
read(r, &f.ModifiedDate) f.ModifiedDate = c.Uint16(b[14:16])
read(r, &f.CRC32) f.CRC32 = c.Uint32(b[16:20])
read(r, &f.CompressedSize) f.CompressedSize = c.Uint32(b[20:24])
read(r, &f.UncompressedSize) f.UncompressedSize = c.Uint32(b[24:28])
read(r, &filenameLength) filenameLen := int(c.Uint16(b[28:30]))
read(r, &extraLength) extraLen := int(c.Uint16(b[30:32]))
read(r, &commentLength) commentLen := int(c.Uint16(b[32:34]))
read(r, &startDiskNumber) // startDiskNumber := c.Uint16(b[34:36]) // Unused
read(r, &internalAttributes) // internalAttributes := c.Uint16(b[36:38]) // Unused
read(r, &externalAttributes) // externalAttributes := c.Uint32(b[38:42]) // Unused
read(r, &f.headerOffset) f.headerOffset = int64(c.Uint32(b[42:46]))
f.Name = string(readByteSlice(r, filenameLength)) d := make([]byte, filenameLen+extraLen+commentLen)
f.Extra = readByteSlice(r, extraLength) if _, err := io.ReadFull(r, d); err != nil {
f.Comment = string(readByteSlice(r, commentLength)) return err
return }
f.Name = string(d[:filenameLen])
f.Extra = d[filenameLen : filenameLen+extraLen]
f.Comment = string(d[filenameLen+extraLen:])
return nil
} }
func readDataDescriptor(r io.Reader, f *File) (err os.Error) { func readDataDescriptor(r io.Reader, f *File) os.Error {
defer func() { var b [dataDescriptorLen]byte
if rerr, ok := recover().(os.Error); ok { if _, err := io.ReadFull(r, b[:]); err != nil {
err = rerr return err
} }
}() c := binary.LittleEndian
read(r, &f.CRC32) f.CRC32 = c.Uint32(b[:4])
read(r, &f.CompressedSize) f.CompressedSize = c.Uint32(b[4:8])
read(r, &f.UncompressedSize) f.UncompressedSize = c.Uint32(b[8:12])
return return nil
} }
func readDirectoryEnd(r io.ReaderAt, size int64) (d *directoryEnd, err os.Error) { func readDirectoryEnd(r io.ReaderAt, size int64) (dir *directoryEnd, err os.Error) {
// look for directoryEndSignature in the last 1k, then in the last 65k // look for directoryEndSignature in the last 1k, then in the last 65k
var b []byte var b []byte
for i, bLen := range []int64{1024, 65 * 1024} { for i, bLen := range []int64{1024, 65 * 1024} {
@ -274,53 +283,29 @@ func readDirectoryEnd(r io.ReaderAt, size int64) (d *directoryEnd, err os.Error)
} }
// read header into struct // read header into struct
defer func() { c := binary.LittleEndian
if rerr, ok := recover().(os.Error); ok { d := new(directoryEnd)
err = rerr d.diskNbr = c.Uint16(b[4:6])
d = nil d.dirDiskNbr = c.Uint16(b[6:8])
} d.dirRecordsThisDisk = c.Uint16(b[8:10])
}() d.directoryRecords = c.Uint16(b[10:12])
br := bytes.NewBuffer(b[4:]) // skip over signature d.directorySize = c.Uint32(b[12:16])
d = new(directoryEnd) d.directoryOffset = c.Uint32(b[16:20])
read(br, &d.diskNbr) d.commentLen = c.Uint16(b[20:22])
read(br, &d.dirDiskNbr) d.comment = string(b[22 : 22+int(d.commentLen)])
read(br, &d.dirRecordsThisDisk)
read(br, &d.directoryRecords)
read(br, &d.directorySize)
read(br, &d.directoryOffset)
read(br, &d.commentLen)
d.comment = string(readByteSlice(br, d.commentLen))
return d, nil return d, nil
} }
func findSignatureInBlock(b []byte) int { func findSignatureInBlock(b []byte) int {
const minSize = 4 + 2 + 2 + 2 + 2 + 4 + 4 + 2 // fixed part of header for i := len(b) - directoryEndLen; i >= 0; i-- {
for i := len(b) - minSize; i >= 0; i-- {
// defined from directoryEndSignature in struct.go // defined from directoryEndSignature in struct.go
if b[i] == 'P' && b[i+1] == 'K' && b[i+2] == 0x05 && b[i+3] == 0x06 { if b[i] == 'P' && b[i+1] == 'K' && b[i+2] == 0x05 && b[i+3] == 0x06 {
// n is length of comment // n is length of comment
n := int(b[i+minSize-2]) | int(b[i+minSize-1])<<8 n := int(b[i+directoryEndLen-2]) | int(b[i+directoryEndLen-1])<<8
if n+minSize+i == len(b) { if n+directoryEndLen+i == len(b) {
return i return i
} }
} }
} }
return -1 return -1
} }
func read(r io.Reader, data interface{}) {
if err := binary.Read(r, binary.LittleEndian, data); err != nil {
panic(err)
}
}
func readByteSlice(r io.Reader, l uint16) []byte {
b := make([]byte, l)
if l == 0 {
return b
}
if _, err := io.ReadFull(r, b); err != nil {
panic(err)
}
return b
}


@ -11,6 +11,7 @@ import (
"io/ioutil" "io/ioutil"
"os" "os"
"testing" "testing"
"time"
) )
type ZipTest struct { type ZipTest struct {
@ -24,8 +25,19 @@ type ZipTestFile struct {
Name string Name string
Content []byte // if blank, will attempt to compare against File Content []byte // if blank, will attempt to compare against File
File string // name of file to compare to (relative to testdata/) File string // name of file to compare to (relative to testdata/)
Mtime string // modified time in format "mm-dd-yy hh:mm:ss"
} }
// Caution: The Mtime values found for the test files should correspond to
// the values listed with unzip -l <zipfile>. However, the values
// listed by unzip appear to be off by some hours. When creating
// fresh test files and testing them, this issue is not present.
// The test files were created in Sydney, so there might be a time
// zone issue. The time zone information does have to be encoded
// somewhere, because otherwise unzip -l could not provide a different
// time from what the archive/zip package provides, but there appears
// to be no documentation about this.
var tests = []ZipTest{ var tests = []ZipTest{
{ {
Name: "test.zip", Name: "test.zip",
@ -34,10 +46,12 @@ var tests = []ZipTest{
{ {
Name: "test.txt", Name: "test.txt",
Content: []byte("This is a test text file.\n"), Content: []byte("This is a test text file.\n"),
Mtime: "09-05-10 12:12:02",
}, },
{ {
Name: "gophercolor16x16.png", Name: "gophercolor16x16.png",
File: "gophercolor16x16.png", File: "gophercolor16x16.png",
Mtime: "09-05-10 15:52:58",
}, },
}, },
}, },
@ -45,8 +59,9 @@ var tests = []ZipTest{
Name: "r.zip", Name: "r.zip",
File: []ZipTestFile{ File: []ZipTestFile{
{ {
Name: "r/r.zip", Name: "r/r.zip",
File: "r.zip", File: "r.zip",
Mtime: "03-04-10 00:24:16",
}, },
}, },
}, },
@ -58,6 +73,7 @@ var tests = []ZipTest{
{ {
Name: "filename", Name: "filename",
Content: []byte("This is a test textfile.\n"), Content: []byte("This is a test textfile.\n"),
Mtime: "02-02-11 13:06:20",
}, },
}, },
}, },
@ -136,18 +152,36 @@ func readTestFile(t *testing.T, ft ZipTestFile, f *File) {
if f.Name != ft.Name { if f.Name != ft.Name {
t.Errorf("name=%q, want %q", f.Name, ft.Name) t.Errorf("name=%q, want %q", f.Name, ft.Name)
} }
mtime, err := time.Parse("01-02-06 15:04:05", ft.Mtime)
if err != nil {
t.Error(err)
return
}
if got, want := f.Mtime_ns()/1e9, mtime.Seconds(); got != want {
t.Errorf("%s: mtime=%s (%d); want %s (%d)", f.Name, time.SecondsToUTC(got), got, mtime, want)
}
size0 := f.UncompressedSize
var b bytes.Buffer var b bytes.Buffer
r, err := f.Open() r, err := f.Open()
if err != nil { if err != nil {
t.Error(err) t.Error(err)
return return
} }
if size1 := f.UncompressedSize; size0 != size1 {
t.Errorf("file %q changed f.UncompressedSize from %d to %d", f.Name, size0, size1)
}
_, err = io.Copy(&b, r) _, err = io.Copy(&b, r)
if err != nil { if err != nil {
t.Error(err) t.Error(err)
return return
} }
r.Close() r.Close()
var c []byte var c []byte
if len(ft.Content) != 0 { if len(ft.Content) != 0 {
c = ft.Content c = ft.Content
@ -155,10 +189,12 @@ func readTestFile(t *testing.T, ft ZipTestFile, f *File) {
t.Error(err) t.Error(err)
return return
} }
if b.Len() != len(c) { if b.Len() != len(c) {
t.Errorf("%s: len=%d, want %d", f.Name, b.Len(), len(c)) t.Errorf("%s: len=%d, want %d", f.Name, b.Len(), len(c))
return return
} }
for i, b := range b.Bytes() { for i, b := range b.Bytes() {
if b != c[i] { if b != c[i] {
t.Errorf("%s: content[%d]=%q want %q", f.Name, i, b, c[i]) t.Errorf("%s: content[%d]=%q want %q", f.Name, i, b, c[i])


@ -1,9 +1,32 @@
// Copyright 2010 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
/*
Package zip provides support for reading and writing ZIP archives.
See: http://www.pkware.com/documents/casestudies/APPNOTE.TXT
This package does not support ZIP64 or disk spanning.
*/
package zip package zip
import "os"
import "time"
// Compression methods.
const (
Store uint16 = 0
Deflate uint16 = 8
)
const ( const (
fileHeaderSignature = 0x04034b50 fileHeaderSignature = 0x04034b50
directoryHeaderSignature = 0x02014b50 directoryHeaderSignature = 0x02014b50
directoryEndSignature = 0x06054b50 directoryEndSignature = 0x06054b50
fileHeaderLen = 30 // + filename + extra
directoryHeaderLen = 46 // + filename + extra + comment
directoryEndLen = 22 // + comment
dataDescriptorLen = 12 dataDescriptorLen = 12
) )
@ -13,8 +36,8 @@ type FileHeader struct {
ReaderVersion uint16 ReaderVersion uint16
Flags uint16 Flags uint16
Method uint16 Method uint16
ModifiedTime uint16 ModifiedTime uint16 // MS-DOS time
ModifiedDate uint16 ModifiedDate uint16 // MS-DOS date
CRC32 uint32 CRC32 uint32
CompressedSize uint32 CompressedSize uint32
UncompressedSize uint32 UncompressedSize uint32
@ -32,3 +55,37 @@ type directoryEnd struct {
commentLen uint16 commentLen uint16
comment string comment string
} }
func recoverError(err *os.Error) {
if e := recover(); e != nil {
if osErr, ok := e.(os.Error); ok {
*err = osErr
return
}
panic(e)
}
}
// msDosTimeToTime converts an MS-DOS date and time into a time.Time.
// The resolution is 2s.
// See: http://msdn.microsoft.com/en-us/library/ms724247(v=VS.85).aspx
func msDosTimeToTime(dosDate, dosTime uint16) time.Time {
return time.Time{
// date bits 0-4: day of month; 5-8: month; 9-15: years since 1980
Year: int64(dosDate>>9 + 1980),
Month: int(dosDate >> 5 & 0xf),
Day: int(dosDate & 0x1f),
// time bits 0-4: second/2; 5-10: minute; 11-15: hour
Hour: int(dosTime >> 11),
Minute: int(dosTime >> 5 & 0x3f),
Second: int(dosTime & 0x1f * 2),
}
}
// Mtime_ns returns the modified time in ns since epoch.
// The resolution is 2s.
func (h *FileHeader) Mtime_ns() int64 {
t := msDosTimeToTime(h.ModifiedDate, h.ModifiedTime)
return t.Seconds() * 1e9
}
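As an aside, the MS-DOS date/time encoding that msDosTimeToTime and Mtime_ns above rely on (date: bits 0-4 day of month, 5-8 month, 9-15 years since 1980; time: bits 0-4 second/2, 5-10 minute, 11-15 hour) can be checked with a small standalone sketch. This is illustrative only, not part of the commit, and it uses the current time.Date API rather than the pre-Go1 time.Time struct literal used here; the sample values correspond to the "09-05-10 12:12:02" Mtime expected for test.txt in the reader test.

package main

import (
	"fmt"
	"time"
)

// msDosTimeToTime decodes an MS-DOS date/time pair using the bit layout
// documented in the commit: day, month, and years-since-1980 in the date
// word; seconds/2, minutes, and hours in the time word.
func msDosTimeToTime(dosDate, dosTime uint16) time.Time {
	return time.Date(
		int(dosDate>>9)+1980,       // year
		time.Month(dosDate>>5&0xf), // month
		int(dosDate&0x1f),          // day of month
		int(dosTime>>11),           // hour
		int(dosTime>>5&0x3f),       // minute
		int(dosTime&0x1f)*2,        // second, 2s resolution
		0, time.UTC)
}

func main() {
	// 0x3d25/0x6181 encode 2010-09-05 12:12:02.
	fmt.Println(msDosTimeToTime(0x3d25, 0x6181))
}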


@ -0,0 +1,244 @@
// Copyright 2011 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package zip
import (
"bufio"
"compress/flate"
"encoding/binary"
"hash"
"hash/crc32"
"io"
"os"
)
// TODO(adg): support zip file comments
// TODO(adg): support specifying deflate level
// Writer implements a zip file writer.
type Writer struct {
*countWriter
dir []*header
last *fileWriter
closed bool
}
type header struct {
*FileHeader
offset uint32
}
// NewWriter returns a new Writer writing a zip file to w.
func NewWriter(w io.Writer) *Writer {
return &Writer{countWriter: &countWriter{w: bufio.NewWriter(w)}}
}
// Close finishes writing the zip file by writing the central directory.
// It does not (and can not) close the underlying writer.
func (w *Writer) Close() (err os.Error) {
if w.last != nil && !w.last.closed {
if err = w.last.close(); err != nil {
return
}
w.last = nil
}
if w.closed {
return os.NewError("zip: writer closed twice")
}
w.closed = true
defer recoverError(&err)
// write central directory
start := w.count
for _, h := range w.dir {
write(w, uint32(directoryHeaderSignature))
write(w, h.CreatorVersion)
write(w, h.ReaderVersion)
write(w, h.Flags)
write(w, h.Method)
write(w, h.ModifiedTime)
write(w, h.ModifiedDate)
write(w, h.CRC32)
write(w, h.CompressedSize)
write(w, h.UncompressedSize)
write(w, uint16(len(h.Name)))
write(w, uint16(len(h.Extra)))
write(w, uint16(len(h.Comment)))
write(w, uint16(0)) // disk number start
write(w, uint16(0)) // internal file attributes
write(w, uint32(0)) // external file attributes
write(w, h.offset)
writeBytes(w, []byte(h.Name))
writeBytes(w, h.Extra)
writeBytes(w, []byte(h.Comment))
}
end := w.count
// write end record
write(w, uint32(directoryEndSignature))
write(w, uint16(0)) // disk number
write(w, uint16(0)) // disk number where directory starts
write(w, uint16(len(w.dir))) // number of entries this disk
write(w, uint16(len(w.dir))) // number of entries total
write(w, uint32(end-start)) // size of directory
write(w, uint32(start)) // start of directory
write(w, uint16(0)) // size of comment
return w.w.(*bufio.Writer).Flush()
}
// Create adds a file to the zip file using the provided name.
// It returns a Writer to which the file contents should be written.
// The file's contents must be written to the io.Writer before the next
// call to Create, CreateHeader, or Close.
func (w *Writer) Create(name string) (io.Writer, os.Error) {
header := &FileHeader{
Name: name,
Method: Deflate,
}
return w.CreateHeader(header)
}
// CreateHeader adds a file to the zip file using the provided FileHeader
// for the file metadata.
// It returns a Writer to which the file contents should be written.
// The file's contents must be written to the io.Writer before the next
// call to Create, CreateHeader, or Close.
func (w *Writer) CreateHeader(fh *FileHeader) (io.Writer, os.Error) {
if w.last != nil && !w.last.closed {
if err := w.last.close(); err != nil {
return nil, err
}
}
fh.Flags |= 0x8 // we will write a data descriptor
fh.CreatorVersion = 0x14
fh.ReaderVersion = 0x14
fw := &fileWriter{
zipw: w,
compCount: &countWriter{w: w},
crc32: crc32.NewIEEE(),
}
switch fh.Method {
case Store:
fw.comp = nopCloser{fw.compCount}
case Deflate:
fw.comp = flate.NewWriter(fw.compCount, 5)
default:
return nil, UnsupportedMethod
}
fw.rawCount = &countWriter{w: fw.comp}
h := &header{
FileHeader: fh,
offset: uint32(w.count),
}
w.dir = append(w.dir, h)
fw.header = h
if err := writeHeader(w, fh); err != nil {
return nil, err
}
w.last = fw
return fw, nil
}
func writeHeader(w io.Writer, h *FileHeader) (err os.Error) {
defer recoverError(&err)
write(w, uint32(fileHeaderSignature))
write(w, h.ReaderVersion)
write(w, h.Flags)
write(w, h.Method)
write(w, h.ModifiedTime)
write(w, h.ModifiedDate)
write(w, h.CRC32)
write(w, h.CompressedSize)
write(w, h.UncompressedSize)
write(w, uint16(len(h.Name)))
write(w, uint16(len(h.Extra)))
writeBytes(w, []byte(h.Name))
writeBytes(w, h.Extra)
return nil
}
type fileWriter struct {
*header
zipw io.Writer
rawCount *countWriter
comp io.WriteCloser
compCount *countWriter
crc32 hash.Hash32
closed bool
}
func (w *fileWriter) Write(p []byte) (int, os.Error) {
if w.closed {
return 0, os.NewError("zip: write to closed file")
}
w.crc32.Write(p)
return w.rawCount.Write(p)
}
func (w *fileWriter) close() (err os.Error) {
if w.closed {
return os.NewError("zip: file closed twice")
}
w.closed = true
if err = w.comp.Close(); err != nil {
return
}
// update FileHeader
fh := w.header.FileHeader
fh.CRC32 = w.crc32.Sum32()
fh.CompressedSize = uint32(w.compCount.count)
fh.UncompressedSize = uint32(w.rawCount.count)
// write data descriptor
defer recoverError(&err)
write(w.zipw, fh.CRC32)
write(w.zipw, fh.CompressedSize)
write(w.zipw, fh.UncompressedSize)
return nil
}
type countWriter struct {
w io.Writer
count int64
}
func (w *countWriter) Write(p []byte) (int, os.Error) {
n, err := w.w.Write(p)
w.count += int64(n)
return n, err
}
type nopCloser struct {
io.Writer
}
func (w nopCloser) Close() os.Error {
return nil
}
func write(w io.Writer, data interface{}) {
if err := binary.Write(w, binary.LittleEndian, data); err != nil {
panic(err)
}
}
func writeBytes(w io.Writer, b []byte) {
n, err := w.Write(b)
if err != nil {
panic(err)
}
if n != len(b) {
panic(io.ErrShortWrite)
}
}


@ -0,0 +1,73 @@
// Copyright 2011 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package zip
import (
"bytes"
"io/ioutil"
"rand"
"testing"
)
// TODO(adg): a more sophisticated test suite
const testString = "Rabbits, guinea pigs, gophers, marsupial rats, and quolls."
func TestWriter(t *testing.T) {
largeData := make([]byte, 1<<17)
for i := range largeData {
largeData[i] = byte(rand.Int())
}
// write a zip file
buf := new(bytes.Buffer)
w := NewWriter(buf)
testCreate(t, w, "foo", []byte(testString), Store)
testCreate(t, w, "bar", largeData, Deflate)
if err := w.Close(); err != nil {
t.Fatal(err)
}
// read it back
r, err := NewReader(sliceReaderAt(buf.Bytes()), int64(buf.Len()))
if err != nil {
t.Fatal(err)
}
testReadFile(t, r.File[0], []byte(testString))
testReadFile(t, r.File[1], largeData)
}
func testCreate(t *testing.T, w *Writer, name string, data []byte, method uint16) {
header := &FileHeader{
Name: name,
Method: method,
}
f, err := w.CreateHeader(header)
if err != nil {
t.Fatal(err)
}
_, err = f.Write(data)
if err != nil {
t.Fatal(err)
}
}
func testReadFile(t *testing.T, f *File, data []byte) {
rc, err := f.Open()
if err != nil {
t.Fatal("opening:", err)
}
b, err := ioutil.ReadAll(rc)
if err != nil {
t.Fatal("reading:", err)
}
err = rc.Close()
if err != nil {
t.Fatal("closing:", err)
}
if !bytes.Equal(b, data) {
t.Errorf("File contents %q, want %q", b, data)
}
}


@ -0,0 +1,57 @@
// Copyright 2011 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
// Tests that involve both reading and writing.
package zip
import (
"bytes"
"fmt"
"os"
"testing"
)
type stringReaderAt string
func (s stringReaderAt) ReadAt(p []byte, off int64) (n int, err os.Error) {
if off >= int64(len(s)) {
return 0, os.EOF
}
n = copy(p, s[off:])
return
}
func TestOver65kFiles(t *testing.T) {
if testing.Short() {
t.Logf("slow test; skipping")
return
}
buf := new(bytes.Buffer)
w := NewWriter(buf)
const nFiles = (1 << 16) + 42
for i := 0; i < nFiles; i++ {
_, err := w.Create(fmt.Sprintf("%d.dat", i))
if err != nil {
t.Fatalf("creating file %d: %v", i, err)
}
}
if err := w.Close(); err != nil {
t.Fatalf("Writer.Close: %v", err)
}
rat := stringReaderAt(buf.String())
zr, err := NewReader(rat, int64(len(rat)))
if err != nil {
t.Fatalf("NewReader: %v", err)
}
if got := len(zr.File); got != nFiles {
t.Fatalf("File contains %d files, want %d", got, nFiles)
}
for i := 0; i < nFiles; i++ {
want := fmt.Sprintf("%d.dat", i)
if zr.File[i].Name != want {
t.Fatalf("File(%d) = %q, want %q", i, zr.File[i].Name, want)
}
}
}


@ -20,6 +20,7 @@ package asn1
// everything by any means. // everything by any means.
import ( import (
"big"
"fmt" "fmt"
"os" "os"
"reflect" "reflect"
@ -88,6 +89,27 @@ func parseInt(bytes []byte) (int, os.Error) {
return int(ret64), nil return int(ret64), nil
} }
var bigOne = big.NewInt(1)
// parseBigInt treats the given bytes as a big-endian, signed integer and returns
// the result.
func parseBigInt(bytes []byte) *big.Int {
ret := new(big.Int)
if len(bytes) > 0 && bytes[0]&0x80 == 0x80 {
// This is a negative number.
notBytes := make([]byte, len(bytes))
for i := range notBytes {
notBytes[i] = ^bytes[i]
}
ret.SetBytes(notBytes)
ret.Add(ret, bigOne)
ret.Neg(ret)
return ret
}
ret.SetBytes(bytes)
return ret
}
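For illustration only, the two's-complement handling in the new parseBigInt above behaves like the following minimal sketch (written against the current math/big import path rather than the "big" path used in this tree):

package main

import (
	"fmt"
	"math/big"
)

// parseBigInt mirrors the helper added above: treat the bytes as a
// big-endian, signed (two's-complement) integer.
func parseBigInt(b []byte) *big.Int {
	ret := new(big.Int)
	if len(b) > 0 && b[0]&0x80 == 0x80 {
		// Negative number: invert every byte, add one, negate.
		not := make([]byte, len(b))
		for i := range not {
			not[i] = ^b[i]
		}
		ret.SetBytes(not)
		ret.Add(ret, big.NewInt(1))
		ret.Neg(ret)
		return ret
	}
	ret.SetBytes(b)
	return ret
}

func main() {
	fmt.Println(parseBigInt([]byte{0x7f}))       // 127
	fmt.Println(parseBigInt([]byte{0xff, 0x7f})) // -129
	fmt.Println(parseBigInt([]byte{0x80}))       // -128
}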
// BIT STRING // BIT STRING
// BitString is the structure to use when you want an ASN.1 BIT STRING type. A // BitString is the structure to use when you want an ASN.1 BIT STRING type. A
@ -127,7 +149,7 @@ func (b BitString) RightAlign() []byte {
return a return a
} }
// parseBitString parses an ASN.1 bit string from the given byte array and returns it. // parseBitString parses an ASN.1 bit string from the given byte slice and returns it.
func parseBitString(bytes []byte) (ret BitString, err os.Error) { func parseBitString(bytes []byte) (ret BitString, err os.Error) {
if len(bytes) == 0 { if len(bytes) == 0 {
err = SyntaxError{"zero length BIT STRING"} err = SyntaxError{"zero length BIT STRING"}
@ -164,9 +186,9 @@ func (oi ObjectIdentifier) Equal(other ObjectIdentifier) bool {
return true return true
} }
// parseObjectIdentifier parses an OBJECT IDENTIFER from the given bytes and // parseObjectIdentifier parses an OBJECT IDENTIFIER from the given bytes and
// returns it. An object identifer is a sequence of variable length integers // returns it. An object identifier is a sequence of variable length integers
// that are assigned in a hierarachy. // that are assigned in a hierarchy.
func parseObjectIdentifier(bytes []byte) (s []int, err os.Error) { func parseObjectIdentifier(bytes []byte) (s []int, err os.Error) {
if len(bytes) == 0 { if len(bytes) == 0 {
err = SyntaxError{"zero length OBJECT IDENTIFIER"} err = SyntaxError{"zero length OBJECT IDENTIFIER"}
@ -198,14 +220,13 @@ func parseObjectIdentifier(bytes []byte) (s []int, err os.Error) {
// An Enumerated is represented as a plain int. // An Enumerated is represented as a plain int.
type Enumerated int type Enumerated int
// FLAG // FLAG
// A Flag accepts any data and is set to true if present. // A Flag accepts any data and is set to true if present.
type Flag bool type Flag bool
// parseBase128Int parses a base-128 encoded int from the given offset in the // parseBase128Int parses a base-128 encoded int from the given offset in the
// given byte array. It returns the value and the new offset. // given byte slice. It returns the value and the new offset.
func parseBase128Int(bytes []byte, initOffset int) (ret, offset int, err os.Error) { func parseBase128Int(bytes []byte, initOffset int) (ret, offset int, err os.Error) {
offset = initOffset offset = initOffset
for shifted := 0; offset < len(bytes); shifted++ { for shifted := 0; offset < len(bytes); shifted++ {
@ -237,7 +258,7 @@ func parseUTCTime(bytes []byte) (ret *time.Time, err os.Error) {
return return
} }
// parseGeneralizedTime parses the GeneralizedTime from the given byte array // parseGeneralizedTime parses the GeneralizedTime from the given byte slice
// and returns the resulting time. // and returns the resulting time.
func parseGeneralizedTime(bytes []byte) (ret *time.Time, err os.Error) { func parseGeneralizedTime(bytes []byte) (ret *time.Time, err os.Error) {
return time.Parse("20060102150405Z0700", string(bytes)) return time.Parse("20060102150405Z0700", string(bytes))
@ -269,7 +290,7 @@ func isPrintable(b byte) bool {
b == ':' || b == ':' ||
b == '=' || b == '=' ||
b == '?' || b == '?' ||
// This is techincally not allowed in a PrintableString. // This is technically not allowed in a PrintableString.
// However, x509 certificates with wildcard strings don't // However, x509 certificates with wildcard strings don't
// always use the correct string type so we permit it. // always use the correct string type so we permit it.
b == '*' b == '*'
@ -278,7 +299,7 @@ func isPrintable(b byte) bool {
// IA5String // IA5String
// parseIA5String parses a ASN.1 IA5String (ASCII string) from the given // parseIA5String parses a ASN.1 IA5String (ASCII string) from the given
// byte array and returns it. // byte slice and returns it.
func parseIA5String(bytes []byte) (ret string, err os.Error) { func parseIA5String(bytes []byte) (ret string, err os.Error) {
for _, b := range bytes { for _, b := range bytes {
if b >= 0x80 { if b >= 0x80 {
@ -293,11 +314,19 @@ func parseIA5String(bytes []byte) (ret string, err os.Error) {
// T61String // T61String
// parseT61String parses a ASN.1 T61String (8-bit clean string) from the given // parseT61String parses a ASN.1 T61String (8-bit clean string) from the given
// byte array and returns it. // byte slice and returns it.
func parseT61String(bytes []byte) (ret string, err os.Error) { func parseT61String(bytes []byte) (ret string, err os.Error) {
return string(bytes), nil return string(bytes), nil
} }
// UTF8String
// parseUTF8String parses a ASN.1 UTF8String (raw UTF-8) from the given byte
// array and returns it.
func parseUTF8String(bytes []byte) (ret string, err os.Error) {
return string(bytes), nil
}
// A RawValue represents an undecoded ASN.1 object. // A RawValue represents an undecoded ASN.1 object.
type RawValue struct { type RawValue struct {
Class, Tag int Class, Tag int
@ -314,7 +343,7 @@ type RawContent []byte
// Tagging // Tagging
// parseTagAndLength parses an ASN.1 tag and length pair from the given offset // parseTagAndLength parses an ASN.1 tag and length pair from the given offset
// into a byte array. It returns the parsed data and the new offset. SET and // into a byte slice. It returns the parsed data and the new offset. SET and
// SET OF (tag 17) are mapped to SEQUENCE and SEQUENCE OF (tag 16) since we // SET OF (tag 17) are mapped to SEQUENCE and SEQUENCE OF (tag 16) since we
// don't distinguish between ordered and unordered objects in this code. // don't distinguish between ordered and unordered objects in this code.
func parseTagAndLength(bytes []byte, initOffset int) (ret tagAndLength, offset int, err os.Error) { func parseTagAndLength(bytes []byte, initOffset int) (ret tagAndLength, offset int, err os.Error) {
@ -371,7 +400,7 @@ func parseTagAndLength(bytes []byte, initOffset int) (ret tagAndLength, offset i
} }
// parseSequenceOf is used for SEQUENCE OF and SET OF values. It tries to parse // parseSequenceOf is used for SEQUENCE OF and SET OF values. It tries to parse
// a number of ASN.1 values from the given byte array and returns them as a // a number of ASN.1 values from the given byte slice and returns them as a
// slice of Go values of the given type. // slice of Go values of the given type.
func parseSequenceOf(bytes []byte, sliceType reflect.Type, elemType reflect.Type) (ret reflect.Value, err os.Error) { func parseSequenceOf(bytes []byte, sliceType reflect.Type, elemType reflect.Type) (ret reflect.Value, err os.Error) {
expectedTag, compoundType, ok := getUniversalType(elemType) expectedTag, compoundType, ok := getUniversalType(elemType)
@ -425,6 +454,7 @@ var (
timeType = reflect.TypeOf(&time.Time{}) timeType = reflect.TypeOf(&time.Time{})
rawValueType = reflect.TypeOf(RawValue{}) rawValueType = reflect.TypeOf(RawValue{})
rawContentsType = reflect.TypeOf(RawContent(nil)) rawContentsType = reflect.TypeOf(RawContent(nil))
bigIntType = reflect.TypeOf(new(big.Int))
) )
// invalidLength returns true iff offset + length > sliceLength, or if the // invalidLength returns true iff offset + length > sliceLength, or if the
@ -433,7 +463,7 @@ func invalidLength(offset, length, sliceLength int) bool {
return offset+length < offset || offset+length > sliceLength return offset+length < offset || offset+length > sliceLength
} }
// parseField is the main parsing function. Given a byte array and an offset // parseField is the main parsing function. Given a byte slice and an offset
// into the array, it will try to parse a suitable ASN.1 value out and store it // into the array, it will try to parse a suitable ASN.1 value out and store it
// in the given Value. // in the given Value.
func parseField(v reflect.Value, bytes []byte, initOffset int, params fieldParameters) (offset int, err os.Error) { func parseField(v reflect.Value, bytes []byte, initOffset int, params fieldParameters) (offset int, err os.Error) {
@ -550,16 +580,15 @@ func parseField(v reflect.Value, bytes []byte, initOffset int, params fieldParam
} }
} }
// Special case for strings: PrintableString and IA5String both map to // Special case for strings: all the ASN.1 string types map to the Go
// the Go type string. getUniversalType returns the tag for // type string. getUniversalType returns the tag for PrintableString
// PrintableString when it sees a string so, if we see an IA5String on // when it sees a string, so if we see a different string type on the
// the wire, we change the universal type to match. // wire, we change the universal type to match.
if universalTag == tagPrintableString && t.tag == tagIA5String { if universalTag == tagPrintableString {
universalTag = tagIA5String switch t.tag {
} case tagIA5String, tagGeneralString, tagT61String, tagUTF8String:
// Likewise for GeneralString universalTag = t.tag
if universalTag == tagPrintableString && t.tag == tagGeneralString { }
universalTag = tagGeneralString
} }
// Special case for time: UTCTime and GeneralizedTime both map to the // Special case for time: UTCTime and GeneralizedTime both map to the
@ -639,6 +668,10 @@ func parseField(v reflect.Value, bytes []byte, initOffset int, params fieldParam
case flagType: case flagType:
v.SetBool(true) v.SetBool(true)
return return
case bigIntType:
parsedInt := parseBigInt(innerBytes)
v.Set(reflect.ValueOf(parsedInt))
return
} }
switch val := v; val.Kind() { switch val := v; val.Kind() {
case reflect.Bool: case reflect.Bool:
@ -648,23 +681,21 @@ func parseField(v reflect.Value, bytes []byte, initOffset int, params fieldParam
} }
err = err1 err = err1
return return
case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64: case reflect.Int, reflect.Int32:
switch val.Type().Kind() { parsedInt, err1 := parseInt(innerBytes)
case reflect.Int: if err1 == nil {
parsedInt, err1 := parseInt(innerBytes) val.SetInt(int64(parsedInt))
if err1 == nil {
val.SetInt(int64(parsedInt))
}
err = err1
return
case reflect.Int64:
parsedInt, err1 := parseInt64(innerBytes)
if err1 == nil {
val.SetInt(parsedInt)
}
err = err1
return
} }
err = err1
return
case reflect.Int64:
parsedInt, err1 := parseInt64(innerBytes)
if err1 == nil {
val.SetInt(parsedInt)
}
err = err1
return
// TODO(dfc) Add support for the remaining integer types
case reflect.Struct: case reflect.Struct:
structType := fieldType structType := fieldType
@ -680,7 +711,7 @@ func parseField(v reflect.Value, bytes []byte, initOffset int, params fieldParam
if i == 0 && field.Type == rawContentsType { if i == 0 && field.Type == rawContentsType {
continue continue
} }
innerOffset, err = parseField(val.Field(i), innerBytes, innerOffset, parseFieldParameters(field.Tag)) innerOffset, err = parseField(val.Field(i), innerBytes, innerOffset, parseFieldParameters(field.Tag.Get("asn1")))
if err != nil { if err != nil {
return return
} }
@ -711,6 +742,8 @@ func parseField(v reflect.Value, bytes []byte, initOffset int, params fieldParam
v, err = parseIA5String(innerBytes) v, err = parseIA5String(innerBytes)
case tagT61String: case tagT61String:
v, err = parseT61String(innerBytes) v, err = parseT61String(innerBytes)
case tagUTF8String:
v, err = parseUTF8String(innerBytes)
case tagGeneralString: case tagGeneralString:
// GeneralString is specified in ISO-2022/ECMA-35, // GeneralString is specified in ISO-2022/ECMA-35,
// A brief review suggests that it includes structures // A brief review suggests that it includes structures
@ -725,7 +758,7 @@ func parseField(v reflect.Value, bytes []byte, initOffset int, params fieldParam
} }
return return
} }
err = StructuralError{"unknown Go type"} err = StructuralError{"unsupported: " + v.Type().String()}
return return
} }
@ -752,7 +785,7 @@ func setDefaultValue(v reflect.Value, params fieldParameters) (ok bool) {
// Because Unmarshal uses the reflect package, the structs // Because Unmarshal uses the reflect package, the structs
// being written to must use upper case field names. // being written to must use upper case field names.
// //
// An ASN.1 INTEGER can be written to an int or int64. // An ASN.1 INTEGER can be written to an int, int32 or int64.
// If the encoded value does not fit in the Go type, // If the encoded value does not fit in the Go type,
// Unmarshal returns a parse error. // Unmarshal returns a parse error.
// //
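A minimal caller-side sketch of what this change enables: decoding an ASN.1 INTEGER into a *big.Int field and using the new backquoted `asn1:"..."` tag syntax. The struct, its field names, and the DER bytes are invented for illustration, and the import paths are the pre-Go 1 ones used throughout this tree ("asn1", "big").

package main

import (
	"asn1" // pre-Go 1 path; later releases use "encoding/asn1"
	"big"  // pre-Go 1 path; later releases use "math/big"
	"fmt"
)

// pair is a hypothetical type: a large INTEGER that only fits a *big.Int,
// plus a PrintableString using the new `asn1:"..."` tag syntax.
type pair struct {
	N *big.Int
	S string `asn1:"printable"`
}

func main() {
	// DER for SEQUENCE { INTEGER 1000000000000, PrintableString "hi" }.
	der := []byte{
		0x30, 0x0c,
		0x02, 0x06, 0x00, 0xe8, 0xd4, 0xa5, 0x10, 0x00, // 1000000000000
		0x13, 0x02, 0x68, 0x69, // "hi"
	}
	var p pair
	if _, err := asn1.Unmarshal(der, &p); err != nil {
		fmt.Println("unmarshal failed:", err)
		return
	}
	fmt.Println(p.N, p.S) // 1000000000000 hi
}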

@ -42,6 +42,64 @@ func TestParseInt64(t *testing.T) {
} }
} }
type int32Test struct {
in []byte
ok bool
out int32
}
var int32TestData = []int32Test{
{[]byte{0x00}, true, 0},
{[]byte{0x7f}, true, 127},
{[]byte{0x00, 0x80}, true, 128},
{[]byte{0x01, 0x00}, true, 256},
{[]byte{0x80}, true, -128},
{[]byte{0xff, 0x7f}, true, -129},
{[]byte{0xff, 0xff, 0xff, 0xff}, true, -1},
{[]byte{0xff}, true, -1},
{[]byte{0x80, 0x00, 0x00, 0x00}, true, -2147483648},
{[]byte{0x80, 0x00, 0x00, 0x00, 0x00}, false, 0},
}
func TestParseInt32(t *testing.T) {
for i, test := range int32TestData {
ret, err := parseInt(test.in)
if (err == nil) != test.ok {
t.Errorf("#%d: Incorrect error result (did fail? %v, expected: %v)", i, err == nil, test.ok)
}
if test.ok && int32(ret) != test.out {
t.Errorf("#%d: Bad result: %v (expected %v)", i, ret, test.out)
}
}
}
var bigIntTests = []struct {
in []byte
base10 string
}{
{[]byte{0xff}, "-1"},
{[]byte{0x00}, "0"},
{[]byte{0x01}, "1"},
{[]byte{0x00, 0xff}, "255"},
{[]byte{0xff, 0x00}, "-256"},
{[]byte{0x01, 0x00}, "256"},
}
func TestParseBigInt(t *testing.T) {
for i, test := range bigIntTests {
ret := parseBigInt(test.in)
if ret.String() != test.base10 {
t.Errorf("#%d: bad result from %x, got %s want %s", i, test.in, ret.String(), test.base10)
}
fw := newForkableWriter()
marshalBigInt(fw, ret)
result := fw.Bytes()
if !bytes.Equal(result, test.in) {
t.Errorf("#%d: got %x from marshaling %s, want %x", i, result, ret, test.in)
}
}
}
type bitStringTest struct { type bitStringTest struct {
in []byte in []byte
ok bool ok bool
@ -148,10 +206,10 @@ type timeTest struct {
} }
var utcTestData = []timeTest{ var utcTestData = []timeTest{
{"910506164540-0700", true, &time.Time{1991, 05, 06, 16, 45, 40, 0, -7 * 60 * 60, ""}}, {"910506164540-0700", true, &time.Time{1991, 05, 06, 16, 45, 40, 0, 0, -7 * 60 * 60, ""}},
{"910506164540+0730", true, &time.Time{1991, 05, 06, 16, 45, 40, 0, 7*60*60 + 30*60, ""}}, {"910506164540+0730", true, &time.Time{1991, 05, 06, 16, 45, 40, 0, 0, 7*60*60 + 30*60, ""}},
{"910506234540Z", true, &time.Time{1991, 05, 06, 23, 45, 40, 0, 0, "UTC"}}, {"910506234540Z", true, &time.Time{1991, 05, 06, 23, 45, 40, 0, 0, 0, "UTC"}},
{"9105062345Z", true, &time.Time{1991, 05, 06, 23, 45, 0, 0, 0, "UTC"}}, {"9105062345Z", true, &time.Time{1991, 05, 06, 23, 45, 0, 0, 0, 0, "UTC"}},
{"a10506234540Z", false, nil}, {"a10506234540Z", false, nil},
{"91a506234540Z", false, nil}, {"91a506234540Z", false, nil},
{"9105a6234540Z", false, nil}, {"9105a6234540Z", false, nil},
@ -177,10 +235,10 @@ func TestUTCTime(t *testing.T) {
} }
var generalizedTimeTestData = []timeTest{ var generalizedTimeTestData = []timeTest{
{"20100102030405Z", true, &time.Time{2010, 01, 02, 03, 04, 05, 0, 0, "UTC"}}, {"20100102030405Z", true, &time.Time{2010, 01, 02, 03, 04, 05, 0, 0, 0, "UTC"}},
{"20100102030405", false, nil}, {"20100102030405", false, nil},
{"20100102030405+0607", true, &time.Time{2010, 01, 02, 03, 04, 05, 0, 6*60*60 + 7*60, ""}}, {"20100102030405+0607", true, &time.Time{2010, 01, 02, 03, 04, 05, 0, 0, 6*60*60 + 7*60, ""}},
{"20100102030405-0607", true, &time.Time{2010, 01, 02, 03, 04, 05, 0, -6*60*60 - 7*60, ""}}, {"20100102030405-0607", true, &time.Time{2010, 01, 02, 03, 04, 05, 0, 0, -6*60*60 - 7*60, ""}},
} }
func TestGeneralizedTime(t *testing.T) { func TestGeneralizedTime(t *testing.T) {
@ -272,11 +330,11 @@ type TestObjectIdentifierStruct struct {
} }
type TestContextSpecificTags struct { type TestContextSpecificTags struct {
A int "tag:1" A int `asn1:"tag:1"`
} }
type TestContextSpecificTags2 struct { type TestContextSpecificTags2 struct {
A int "explicit,tag:1" A int `asn1:"explicit,tag:1"`
B int B int
} }
@ -326,7 +384,7 @@ type Certificate struct {
} }
type TBSCertificate struct { type TBSCertificate struct {
Version int "optional,explicit,default:0,tag:0" Version int `asn1:"optional,explicit,default:0,tag:0"`
SerialNumber RawValue SerialNumber RawValue
SignatureAlgorithm AlgorithmIdentifier SignatureAlgorithm AlgorithmIdentifier
Issuer RDNSequence Issuer RDNSequence

@ -10,7 +10,7 @@ import (
"strings" "strings"
) )
// ASN.1 objects have metadata preceeding them: // ASN.1 objects have metadata preceding them:
// the tag: the type of the object // the tag: the type of the object
// a flag denoting if this object is compound or not // a flag denoting if this object is compound or not
// the class type: the namespace of the tag // the class type: the namespace of the tag
@ -25,6 +25,7 @@ const (
tagOctetString = 4 tagOctetString = 4
tagOID = 6 tagOID = 6
tagEnum = 10 tagEnum = 10
tagUTF8String = 12
tagSequence = 16 tagSequence = 16
tagSet = 17 tagSet = 17
tagPrintableString = 19 tagPrintableString = 19
@ -83,7 +84,7 @@ type fieldParameters struct {
// parseFieldParameters will parse it into a fieldParameters structure, // parseFieldParameters will parse it into a fieldParameters structure,
// ignoring unknown parts of the string. // ignoring unknown parts of the string.
func parseFieldParameters(str string) (ret fieldParameters) { func parseFieldParameters(str string) (ret fieldParameters) {
for _, part := range strings.Split(str, ",", -1) { for _, part := range strings.Split(str, ",") {
switch { switch {
case part == "optional": case part == "optional":
ret.optional = true ret.optional = true
@ -132,6 +133,8 @@ func getUniversalType(t reflect.Type) (tagNumber int, isCompound, ok bool) {
return tagUTCTime, false, true return tagUTCTime, false, true
case enumeratedType: case enumeratedType:
return tagEnum, false, true return tagEnum, false, true
case bigIntType:
return tagInteger, false, true
} }
switch t.Kind() { switch t.Kind() {
case reflect.Bool: case reflect.Bool:

@ -5,6 +5,7 @@
package asn1 package asn1
import ( import (
"big"
"bytes" "bytes"
"fmt" "fmt"
"io" "io"
@ -125,6 +126,43 @@ func int64Length(i int64) (numBytes int) {
return return
} }
func marshalBigInt(out *forkableWriter, n *big.Int) (err os.Error) {
if n.Sign() < 0 {
// A negative number has to be converted to two's-complement
// form. So we'll subtract 1 and invert. If the
// most-significant-bit isn't set then we'll need to pad the
// beginning with 0xff in order to keep the number negative.
nMinus1 := new(big.Int).Neg(n)
nMinus1.Sub(nMinus1, bigOne)
bytes := nMinus1.Bytes()
for i := range bytes {
bytes[i] ^= 0xff
}
if len(bytes) == 0 || bytes[0]&0x80 == 0 {
err = out.WriteByte(0xff)
if err != nil {
return
}
}
_, err = out.Write(bytes)
} else if n.Sign() == 0 {
// Zero is written as a single zero byte rather than no bytes.
err = out.WriteByte(0x00)
} else {
bytes := n.Bytes()
if len(bytes) > 0 && bytes[0]&0x80 != 0 {
// We'll have to pad this with 0x00 in order to stop it
// looking like a negative number.
err = out.WriteByte(0)
if err != nil {
return
}
}
_, err = out.Write(bytes)
}
return
}
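Working the negative branch by hand: for n = -129, |n| - 1 = 128 = 0x80, inverting gives 0x7f, and because its top bit is clear a 0xff byte is prepended, yielding ff 7f, the same encoding the int32 and big.Int test tables above pair with -129. For n = -1, |n| - 1 = 0 has no bytes at all, so only the 0xff pad is written.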
func marshalLength(out *forkableWriter, i int) (err os.Error) { func marshalLength(out *forkableWriter, i int) (err os.Error) {
n := lengthLength(i) n := lengthLength(i)
@ -334,6 +372,8 @@ func marshalBody(out *forkableWriter, value reflect.Value, params fieldParameter
return marshalBitString(out, value.Interface().(BitString)) return marshalBitString(out, value.Interface().(BitString))
case objectIdentifierType: case objectIdentifierType:
return marshalObjectIdentifier(out, value.Interface().(ObjectIdentifier)) return marshalObjectIdentifier(out, value.Interface().(ObjectIdentifier))
case bigIntType:
return marshalBigInt(out, value.Interface().(*big.Int))
} }
switch v := value; v.Kind() { switch v := value; v.Kind() {
@ -351,7 +391,7 @@ func marshalBody(out *forkableWriter, value reflect.Value, params fieldParameter
startingField := 0 startingField := 0
// If the first element of the structure is a non-empty // If the first element of the structure is a non-empty
// RawContents, then we don't bother serialising the rest. // RawContents, then we don't bother serializing the rest.
if t.NumField() > 0 && t.Field(0).Type == rawContentsType { if t.NumField() > 0 && t.Field(0).Type == rawContentsType {
s := v.Field(0) s := v.Field(0)
if s.Len() > 0 { if s.Len() > 0 {
@ -361,7 +401,7 @@ func marshalBody(out *forkableWriter, value reflect.Value, params fieldParameter
} }
/* The RawContents will contain the tag and /* The RawContents will contain the tag and
* length fields but we'll also be writing * length fields but we'll also be writing
* those outselves, so we strip them out of * those ourselves, so we strip them out of
* bytes */ * bytes */
_, err = out.Write(stripTagAndLength(bytes)) _, err = out.Write(stripTagAndLength(bytes))
return return
@ -373,7 +413,7 @@ func marshalBody(out *forkableWriter, value reflect.Value, params fieldParameter
for i := startingField; i < t.NumField(); i++ { for i := startingField; i < t.NumField(); i++ {
var pre *forkableWriter var pre *forkableWriter
pre, out = out.fork() pre, out = out.fork()
err = marshalField(pre, v.Field(i), parseFieldParameters(t.Field(i).Tag)) err = marshalField(pre, v.Field(i), parseFieldParameters(t.Field(i).Tag.Get("asn1")))
if err != nil { if err != nil {
return return
} }
@ -418,6 +458,10 @@ func marshalField(out *forkableWriter, v reflect.Value, params fieldParameters)
return marshalField(out, v.Elem(), params) return marshalField(out, v.Elem(), params)
} }
if params.optional && reflect.DeepEqual(v.Interface(), reflect.Zero(v.Type()).Interface()) {
return
}
if v.Type() == rawValueType { if v.Type() == rawValueType {
rv := v.Interface().(RawValue) rv := v.Interface().(RawValue)
err = marshalTagAndLength(out, tagAndLength{rv.Class, rv.Tag, len(rv.Bytes), rv.IsCompound}) err = marshalTagAndLength(out, tagAndLength{rv.Class, rv.Tag, len(rv.Bytes), rv.IsCompound})
@ -428,10 +472,6 @@ func marshalField(out *forkableWriter, v reflect.Value, params fieldParameters)
return return
} }
if params.optional && reflect.DeepEqual(v.Interface(), reflect.Zero(v.Type()).Interface()) {
return
}
tag, isCompound, ok := getUniversalType(v.Type()) tag, isCompound, ok := getUniversalType(v.Type())
if !ok { if !ok {
err = StructuralError{fmt.Sprintf("unknown Go type: %v", v.Type())} err = StructuralError{fmt.Sprintf("unknown Go type: %v", v.Type())}

@ -30,19 +30,23 @@ type rawContentsStruct struct {
} }
type implicitTagTest struct { type implicitTagTest struct {
A int "implicit,tag:5" A int `asn1:"implicit,tag:5"`
} }
type explicitTagTest struct { type explicitTagTest struct {
A int "explicit,tag:5" A int `asn1:"explicit,tag:5"`
} }
type ia5StringTest struct { type ia5StringTest struct {
A string "ia5" A string `asn1:"ia5"`
} }
type printableStringTest struct { type printableStringTest struct {
A string "printable" A string `asn1:"printable"`
}
type optionalRawValueTest struct {
A RawValue `asn1:"optional"`
} }
type testSET []int type testSET []int
@ -102,6 +106,7 @@ var marshalTests = []marshalTest{
"7878787878787878787878787878787878787878787878787878787878787878", "7878787878787878787878787878787878787878787878787878787878787878",
}, },
{ia5StringTest{"test"}, "3006160474657374"}, {ia5StringTest{"test"}, "3006160474657374"},
{optionalRawValueTest{}, "3000"},
{printableStringTest{"test"}, "3006130474657374"}, {printableStringTest{"test"}, "3006130474657374"},
{printableStringTest{"test*"}, "30071305746573742a"}, {printableStringTest{"test*"}, "30071305746573742a"},
{rawContentsStruct{nil, 64}, "3003020140"}, {rawContentsStruct{nil, 64}, "3003020140"},

@ -27,7 +27,6 @@ const (
_M2 = _B2 - 1 // half digit mask _M2 = _B2 - 1 // half digit mask
) )
// ---------------------------------------------------------------------------- // ----------------------------------------------------------------------------
// Elementary operations on words // Elementary operations on words
// //
@ -43,7 +42,6 @@ func addWW_g(x, y, c Word) (z1, z0 Word) {
return return
} }
// z1<<_W + z0 = x-y-c, with c == 0 or 1 // z1<<_W + z0 = x-y-c, with c == 0 or 1
func subWW_g(x, y, c Word) (z1, z0 Word) { func subWW_g(x, y, c Word) (z1, z0 Word) {
yc := y + c yc := y + c
@ -54,7 +52,6 @@ func subWW_g(x, y, c Word) (z1, z0 Word) {
return return
} }
// z1<<_W + z0 = x*y // z1<<_W + z0 = x*y
func mulWW(x, y Word) (z1, z0 Word) { return mulWW_g(x, y) } func mulWW(x, y Word) (z1, z0 Word) { return mulWW_g(x, y) }
// Adapted from Warren, Hacker's Delight, p. 132. // Adapted from Warren, Hacker's Delight, p. 132.
@ -73,7 +70,6 @@ func mulWW_g(x, y Word) (z1, z0 Word) {
return return
} }
// z1<<_W + z0 = x*y + c // z1<<_W + z0 = x*y + c
func mulAddWWW_g(x, y, c Word) (z1, z0 Word) { func mulAddWWW_g(x, y, c Word) (z1, z0 Word) {
z1, zz0 := mulWW(x, y) z1, zz0 := mulWW(x, y)
@ -83,7 +79,6 @@ func mulAddWWW_g(x, y, c Word) (z1, z0 Word) {
return return
} }
// Length of x in bits. // Length of x in bits.
func bitLen(x Word) (n int) { func bitLen(x Word) (n int) {
for ; x >= 0x100; x >>= 8 { for ; x >= 0x100; x >>= 8 {
@ -95,7 +90,6 @@ func bitLen(x Word) (n int) {
return return
} }
// log2 computes the integer binary logarithm of x. // log2 computes the integer binary logarithm of x.
// The result is the integer n for which 2^n <= x < 2^(n+1). // The result is the integer n for which 2^n <= x < 2^(n+1).
// If x == 0, the result is -1. // If x == 0, the result is -1.
@ -103,13 +97,11 @@ func log2(x Word) int {
return bitLen(x) - 1 return bitLen(x) - 1
} }
// Number of leading zeros in x. // Number of leading zeros in x.
func leadingZeros(x Word) uint { func leadingZeros(x Word) uint {
return uint(_W - bitLen(x)) return uint(_W - bitLen(x))
} }
// q = (u1<<_W + u0 - r)/y // q = (u1<<_W + u0 - r)/y
func divWW(x1, x0, y Word) (q, r Word) { return divWW_g(x1, x0, y) } func divWW(x1, x0, y Word) (q, r Word) { return divWW_g(x1, x0, y) }
// Adapted from Warren, Hacker's Delight, p. 152. // Adapted from Warren, Hacker's Delight, p. 152.
@ -155,7 +147,6 @@ again2:
return q1*_B2 + q0, (un21*_B2 + un0 - q0*v) >> s return q1*_B2 + q0, (un21*_B2 + un0 - q0*v) >> s
} }
func addVV(z, x, y []Word) (c Word) { return addVV_g(z, x, y) } func addVV(z, x, y []Word) (c Word) { return addVV_g(z, x, y) }
func addVV_g(z, x, y []Word) (c Word) { func addVV_g(z, x, y []Word) (c Word) {
for i := range z { for i := range z {
@ -164,7 +155,6 @@ func addVV_g(z, x, y []Word) (c Word) {
return return
} }
func subVV(z, x, y []Word) (c Word) { return subVV_g(z, x, y) } func subVV(z, x, y []Word) (c Word) { return subVV_g(z, x, y) }
func subVV_g(z, x, y []Word) (c Word) { func subVV_g(z, x, y []Word) (c Word) {
for i := range z { for i := range z {
@ -173,7 +163,6 @@ func subVV_g(z, x, y []Word) (c Word) {
return return
} }
func addVW(z, x []Word, y Word) (c Word) { return addVW_g(z, x, y) } func addVW(z, x []Word, y Word) (c Word) { return addVW_g(z, x, y) }
func addVW_g(z, x []Word, y Word) (c Word) { func addVW_g(z, x []Word, y Word) (c Word) {
c = y c = y
@ -183,7 +172,6 @@ func addVW_g(z, x []Word, y Word) (c Word) {
return return
} }
func subVW(z, x []Word, y Word) (c Word) { return subVW_g(z, x, y) } func subVW(z, x []Word, y Word) (c Word) { return subVW_g(z, x, y) }
func subVW_g(z, x []Word, y Word) (c Word) { func subVW_g(z, x []Word, y Word) (c Word) {
c = y c = y
@ -193,9 +181,8 @@ func subVW_g(z, x []Word, y Word) (c Word) {
return return
} }
func shlVU(z, x []Word, s uint) (c Word) { return shlVU_g(z, x, s) }
func shlVW(z, x []Word, s Word) (c Word) { return shlVW_g(z, x, s) } func shlVU_g(z, x []Word, s uint) (c Word) {
func shlVW_g(z, x []Word, s Word) (c Word) {
if n := len(z); n > 0 { if n := len(z); n > 0 {
ŝ := _W - s ŝ := _W - s
w1 := x[n-1] w1 := x[n-1]
@ -210,9 +197,8 @@ func shlVW_g(z, x []Word, s Word) (c Word) {
return return
} }
func shrVU(z, x []Word, s uint) (c Word) { return shrVU_g(z, x, s) }
func shrVW(z, x []Word, s Word) (c Word) { return shrVW_g(z, x, s) } func shrVU_g(z, x []Word, s uint) (c Word) {
func shrVW_g(z, x []Word, s Word) (c Word) {
if n := len(z); n > 0 { if n := len(z); n > 0 {
ŝ := _W - s ŝ := _W - s
w1 := x[0] w1 := x[0]
@ -227,7 +213,6 @@ func shrVW_g(z, x []Word, s Word) (c Word) {
return return
} }
func mulAddVWW(z, x []Word, y, r Word) (c Word) { return mulAddVWW_g(z, x, y, r) } func mulAddVWW(z, x []Word, y, r Word) (c Word) { return mulAddVWW_g(z, x, y, r) }
func mulAddVWW_g(z, x []Word, y, r Word) (c Word) { func mulAddVWW_g(z, x []Word, y, r Word) (c Word) {
c = r c = r
@ -237,7 +222,6 @@ func mulAddVWW_g(z, x []Word, y, r Word) (c Word) {
return return
} }
func addMulVVW(z, x []Word, y Word) (c Word) { return addMulVVW_g(z, x, y) } func addMulVVW(z, x []Word, y Word) (c Word) { return addMulVVW_g(z, x, y) }
func addMulVVW_g(z, x []Word, y Word) (c Word) { func addMulVVW_g(z, x []Word, y Word) (c Word) {
for i := range z { for i := range z {
@ -248,7 +232,6 @@ func addMulVVW_g(z, x []Word, y Word) (c Word) {
return return
} }
func divWVW(z []Word, xn Word, x []Word, y Word) (r Word) { return divWVW_g(z, xn, x, y) } func divWVW(z []Word, xn Word, x []Word, y Word) (r Word) { return divWVW_g(z, xn, x, y) }
func divWVW_g(z []Word, xn Word, x []Word, y Word) (r Word) { func divWVW_g(z []Word, xn Word, x []Word, y Word) (r Word) {
r = xn r = xn

@ -11,8 +11,8 @@ func addVV(z, x, y []Word) (c Word)
func subVV(z, x, y []Word) (c Word) func subVV(z, x, y []Word) (c Word)
func addVW(z, x []Word, y Word) (c Word) func addVW(z, x []Word, y Word) (c Word)
func subVW(z, x []Word, y Word) (c Word) func subVW(z, x []Word, y Word) (c Word)
func shlVW(z, x []Word, s Word) (c Word) func shlVU(z, x []Word, s uint) (c Word)
func shrVW(z, x []Word, s Word) (c Word) func shrVU(z, x []Word, s uint) (c Word)
func mulAddVWW(z, x []Word, y, r Word) (c Word) func mulAddVWW(z, x []Word, y, r Word) (c Word)
func addMulVVW(z, x []Word, y Word) (c Word) func addMulVVW(z, x []Word, y Word) (c Word)
func divWVW(z []Word, xn Word, x []Word, y Word) (r Word) func divWVW(z []Word, xn Word, x []Word, y Word) (r Word)

@ -6,7 +6,6 @@ package big
import "testing" import "testing"
type funWW func(x, y, c Word) (z1, z0 Word) type funWW func(x, y, c Word) (z1, z0 Word)
type argWW struct { type argWW struct {
x, y, c, z1, z0 Word x, y, c, z1, z0 Word
@ -26,7 +25,6 @@ var sumWW = []argWW{
{_M, _M, 1, 1, _M}, {_M, _M, 1, 1, _M},
} }
func testFunWW(t *testing.T, msg string, f funWW, a argWW) { func testFunWW(t *testing.T, msg string, f funWW, a argWW) {
z1, z0 := f(a.x, a.y, a.c) z1, z0 := f(a.x, a.y, a.c)
if z1 != a.z1 || z0 != a.z0 { if z1 != a.z1 || z0 != a.z0 {
@ -34,7 +32,6 @@ func testFunWW(t *testing.T, msg string, f funWW, a argWW) {
} }
} }
func TestFunWW(t *testing.T) { func TestFunWW(t *testing.T) {
for _, a := range sumWW { for _, a := range sumWW {
arg := a arg := a
@ -51,7 +48,6 @@ func TestFunWW(t *testing.T) {
} }
} }
type funVV func(z, x, y []Word) (c Word) type funVV func(z, x, y []Word) (c Word)
type argVV struct { type argVV struct {
z, x, y nat z, x, y nat
@ -70,7 +66,6 @@ var sumVV = []argVV{
{nat{0, 0, 0, 0}, nat{_M, 0, _M, 0}, nat{1, _M, 0, _M}, 1}, {nat{0, 0, 0, 0}, nat{_M, 0, _M, 0}, nat{1, _M, 0, _M}, 1},
} }
func testFunVV(t *testing.T, msg string, f funVV, a argVV) { func testFunVV(t *testing.T, msg string, f funVV, a argVV) {
z := make(nat, len(a.z)) z := make(nat, len(a.z))
c := f(z, a.x, a.y) c := f(z, a.x, a.y)
@ -85,7 +80,6 @@ func testFunVV(t *testing.T, msg string, f funVV, a argVV) {
} }
} }
func TestFunVV(t *testing.T) { func TestFunVV(t *testing.T) {
for _, a := range sumVV { for _, a := range sumVV {
arg := a arg := a
@ -106,7 +100,6 @@ func TestFunVV(t *testing.T) {
} }
} }
type funVW func(z, x []Word, y Word) (c Word) type funVW func(z, x []Word, y Word) (c Word)
type argVW struct { type argVW struct {
z, x nat z, x nat
@ -169,7 +162,6 @@ var rshVW = []argVW{
{nat{_M, _M, _M >> 20}, nat{_M, _M, _M}, 20, _M << (_W - 20) & _M}, {nat{_M, _M, _M >> 20}, nat{_M, _M, _M}, 20, _M << (_W - 20) & _M},
} }
func testFunVW(t *testing.T, msg string, f funVW, a argVW) { func testFunVW(t *testing.T, msg string, f funVW, a argVW) {
z := make(nat, len(a.z)) z := make(nat, len(a.z))
c := f(z, a.x, a.y) c := f(z, a.x, a.y)
@ -184,6 +176,11 @@ func testFunVW(t *testing.T, msg string, f funVW, a argVW) {
} }
} }
func makeFunVW(f func(z, x []Word, s uint) (c Word)) funVW {
return func(z, x []Word, s Word) (c Word) {
return f(z, x, uint(s))
}
}
func TestFunVW(t *testing.T) { func TestFunVW(t *testing.T) {
for _, a := range sumVW { for _, a := range sumVW {
@ -196,20 +193,23 @@ func TestFunVW(t *testing.T) {
testFunVW(t, "subVW", subVW, arg) testFunVW(t, "subVW", subVW, arg)
} }
shlVW_g := makeFunVW(shlVU_g)
shlVW := makeFunVW(shlVU)
for _, a := range lshVW { for _, a := range lshVW {
arg := a arg := a
testFunVW(t, "shlVW_g", shlVW_g, arg) testFunVW(t, "shlVU_g", shlVW_g, arg)
testFunVW(t, "shlVW", shlVW, arg) testFunVW(t, "shlVU", shlVW, arg)
} }
shrVW_g := makeFunVW(shrVU_g)
shrVW := makeFunVW(shrVU)
for _, a := range rshVW { for _, a := range rshVW {
arg := a arg := a
testFunVW(t, "shrVW_g", shrVW_g, arg) testFunVW(t, "shrVU_g", shrVW_g, arg)
testFunVW(t, "shrVW", shrVW, arg) testFunVW(t, "shrVU", shrVW, arg)
} }
} }
type funVWW func(z, x []Word, y, r Word) (c Word) type funVWW func(z, x []Word, y, r Word) (c Word)
type argVWW struct { type argVWW struct {
z, x nat z, x nat
@ -243,7 +243,6 @@ var prodVWW = []argVWW{
{nat{_M<<7&_M + 1<<6, _M, _M, _M}, nat{_M, _M, _M, _M}, 1 << 7, 1 << 6, _M >> (_W - 7)}, {nat{_M<<7&_M + 1<<6, _M, _M, _M}, nat{_M, _M, _M, _M}, 1 << 7, 1 << 6, _M >> (_W - 7)},
} }
func testFunVWW(t *testing.T, msg string, f funVWW, a argVWW) { func testFunVWW(t *testing.T, msg string, f funVWW, a argVWW) {
z := make(nat, len(a.z)) z := make(nat, len(a.z))
c := f(z, a.x, a.y, a.r) c := f(z, a.x, a.y, a.r)
@ -258,7 +257,6 @@ func testFunVWW(t *testing.T, msg string, f funVWW, a argVWW) {
} }
} }
// TODO(gri) mulAddVWW and divWVW are symmetric operations but // TODO(gri) mulAddVWW and divWVW are symmetric operations but
// their signature is not symmetric. Try to unify. // their signature is not symmetric. Try to unify.
@ -285,7 +283,6 @@ func testFunWVW(t *testing.T, msg string, f funWVW, a argWVW) {
} }
} }
func TestFunVWW(t *testing.T) { func TestFunVWW(t *testing.T) {
for _, a := range prodVWW { for _, a := range prodVWW {
arg := a arg := a
@ -300,7 +297,6 @@ func TestFunVWW(t *testing.T) {
} }
} }
var mulWWTests = []struct { var mulWWTests = []struct {
x, y Word x, y Word
q, r Word q, r Word
@ -309,7 +305,6 @@ var mulWWTests = []struct {
// 32 bit only: {0xc47dfa8c, 50911, 0x98a4, 0x998587f4}, // 32 bit only: {0xc47dfa8c, 50911, 0x98a4, 0x998587f4},
} }
func TestMulWW(t *testing.T) { func TestMulWW(t *testing.T) {
for i, test := range mulWWTests { for i, test := range mulWWTests {
q, r := mulWW_g(test.x, test.y) q, r := mulWW_g(test.x, test.y)
@ -319,7 +314,6 @@ func TestMulWW(t *testing.T) {
} }
} }
var mulAddWWWTests = []struct { var mulAddWWWTests = []struct {
x, y, c Word x, y, c Word
q, r Word q, r Word
@ -331,7 +325,6 @@ var mulAddWWWTests = []struct {
{_M, _M, _M, _M, 0}, {_M, _M, _M, _M, 0},
} }
func TestMulAddWWW(t *testing.T) { func TestMulAddWWW(t *testing.T) {
for i, test := range mulAddWWWTests { for i, test := range mulAddWWWTests {
q, r := mulAddWWW_g(test.x, test.y, test.c) q, r := mulAddWWW_g(test.x, test.y, test.c)

@ -19,10 +19,8 @@ import (
"time" "time"
) )
var calibrate = flag.Bool("calibrate", false, "run calibration test") var calibrate = flag.Bool("calibrate", false, "run calibration test")
// measure returns the time to run f // measure returns the time to run f
func measure(f func()) int64 { func measure(f func()) int64 {
const N = 100 const N = 100
@ -34,7 +32,6 @@ func measure(f func()) int64 {
return (stop - start) / N return (stop - start) / N
} }
func computeThresholds() { func computeThresholds() {
fmt.Printf("Multiplication times for varying Karatsuba thresholds\n") fmt.Printf("Multiplication times for varying Karatsuba thresholds\n")
fmt.Printf("(run repeatedly for good results)\n") fmt.Printf("(run repeatedly for good results)\n")
@ -84,7 +81,6 @@ func computeThresholds() {
} }
} }
func TestCalibrate(t *testing.T) { func TestCalibrate(t *testing.T) {
if *calibrate { if *calibrate {
computeThresholds() computeThresholds()

@ -13,13 +13,11 @@ import (
"testing" "testing"
) )
type matrix struct { type matrix struct {
n, m int n, m int
a []*Rat a []*Rat
} }
func (a *matrix) at(i, j int) *Rat { func (a *matrix) at(i, j int) *Rat {
if !(0 <= i && i < a.n && 0 <= j && j < a.m) { if !(0 <= i && i < a.n && 0 <= j && j < a.m) {
panic("index out of range") panic("index out of range")
@ -27,7 +25,6 @@ func (a *matrix) at(i, j int) *Rat {
return a.a[i*a.m+j] return a.a[i*a.m+j]
} }
func (a *matrix) set(i, j int, x *Rat) { func (a *matrix) set(i, j int, x *Rat) {
if !(0 <= i && i < a.n && 0 <= j && j < a.m) { if !(0 <= i && i < a.n && 0 <= j && j < a.m) {
panic("index out of range") panic("index out of range")
@ -35,7 +32,6 @@ func (a *matrix) set(i, j int, x *Rat) {
a.a[i*a.m+j] = x a.a[i*a.m+j] = x
} }
func newMatrix(n, m int) *matrix { func newMatrix(n, m int) *matrix {
if !(0 <= n && 0 <= m) { if !(0 <= n && 0 <= m) {
panic("illegal matrix") panic("illegal matrix")
@ -47,7 +43,6 @@ func newMatrix(n, m int) *matrix {
return a return a
} }
func newUnit(n int) *matrix { func newUnit(n int) *matrix {
a := newMatrix(n, n) a := newMatrix(n, n)
for i := 0; i < n; i++ { for i := 0; i < n; i++ {
@ -62,7 +57,6 @@ func newUnit(n int) *matrix {
return a return a
} }
func newHilbert(n int) *matrix { func newHilbert(n int) *matrix {
a := newMatrix(n, n) a := newMatrix(n, n)
for i := 0; i < n; i++ { for i := 0; i < n; i++ {
@ -73,7 +67,6 @@ func newHilbert(n int) *matrix {
return a return a
} }
func newInverseHilbert(n int) *matrix { func newInverseHilbert(n int) *matrix {
a := newMatrix(n, n) a := newMatrix(n, n)
for i := 0; i < n; i++ { for i := 0; i < n; i++ {
@ -98,7 +91,6 @@ func newInverseHilbert(n int) *matrix {
return a return a
} }
func (a *matrix) mul(b *matrix) *matrix { func (a *matrix) mul(b *matrix) *matrix {
if a.m != b.n { if a.m != b.n {
panic("illegal matrix multiply") panic("illegal matrix multiply")
@ -116,7 +108,6 @@ func (a *matrix) mul(b *matrix) *matrix {
return c return c
} }
func (a *matrix) eql(b *matrix) bool { func (a *matrix) eql(b *matrix) bool {
if a.n != b.n || a.m != b.m { if a.n != b.n || a.m != b.m {
return false return false
@ -131,7 +122,6 @@ func (a *matrix) eql(b *matrix) bool {
return true return true
} }
func (a *matrix) String() string { func (a *matrix) String() string {
s := "" s := ""
for i := 0; i < a.n; i++ { for i := 0; i < a.n; i++ {
@ -143,7 +133,6 @@ func (a *matrix) String() string {
return s return s
} }
func doHilbert(t *testing.T, n int) { func doHilbert(t *testing.T, n int) {
a := newHilbert(n) a := newHilbert(n)
b := newInverseHilbert(n) b := newInverseHilbert(n)
@ -160,12 +149,10 @@ func doHilbert(t *testing.T, n int) {
} }
} }
func TestHilbert(t *testing.T) { func TestHilbert(t *testing.T) {
doHilbert(t, 10) doHilbert(t, 10)
} }
func BenchmarkHilbert(b *testing.B) { func BenchmarkHilbert(b *testing.B) {
for i := 0; i < b.N; i++ { for i := 0; i < b.N; i++ {
doHilbert(nil, 10) doHilbert(nil, 10)

@ -8,8 +8,10 @@ package big
import ( import (
"fmt" "fmt"
"io"
"os" "os"
"rand" "rand"
"strings"
) )
// An Int represents a signed multi-precision integer. // An Int represents a signed multi-precision integer.
@ -19,10 +21,8 @@ type Int struct {
abs nat // absolute value of the integer abs nat // absolute value of the integer
} }
var intOne = &Int{false, natOne} var intOne = &Int{false, natOne}
// Sign returns: // Sign returns:
// //
// -1 if x < 0 // -1 if x < 0
@ -39,7 +39,6 @@ func (x *Int) Sign() int {
return 1 return 1
} }
// SetInt64 sets z to x and returns z. // SetInt64 sets z to x and returns z.
func (z *Int) SetInt64(x int64) *Int { func (z *Int) SetInt64(x int64) *Int {
neg := false neg := false
@ -52,13 +51,11 @@ func (z *Int) SetInt64(x int64) *Int {
return z return z
} }
// NewInt allocates and returns a new Int set to x. // NewInt allocates and returns a new Int set to x.
func NewInt(x int64) *Int { func NewInt(x int64) *Int {
return new(Int).SetInt64(x) return new(Int).SetInt64(x)
} }
// Set sets z to x and returns z. // Set sets z to x and returns z.
func (z *Int) Set(x *Int) *Int { func (z *Int) Set(x *Int) *Int {
z.abs = z.abs.set(x.abs) z.abs = z.abs.set(x.abs)
@ -66,7 +63,6 @@ func (z *Int) Set(x *Int) *Int {
return z return z
} }
// Abs sets z to |x| (the absolute value of x) and returns z. // Abs sets z to |x| (the absolute value of x) and returns z.
func (z *Int) Abs(x *Int) *Int { func (z *Int) Abs(x *Int) *Int {
z.abs = z.abs.set(x.abs) z.abs = z.abs.set(x.abs)
@ -74,7 +70,6 @@ func (z *Int) Abs(x *Int) *Int {
return z return z
} }
// Neg sets z to -x and returns z. // Neg sets z to -x and returns z.
func (z *Int) Neg(x *Int) *Int { func (z *Int) Neg(x *Int) *Int {
z.abs = z.abs.set(x.abs) z.abs = z.abs.set(x.abs)
@ -82,7 +77,6 @@ func (z *Int) Neg(x *Int) *Int {
return z return z
} }
// Add sets z to the sum x+y and returns z. // Add sets z to the sum x+y and returns z.
func (z *Int) Add(x, y *Int) *Int { func (z *Int) Add(x, y *Int) *Int {
neg := x.neg neg := x.neg
@ -104,7 +98,6 @@ func (z *Int) Add(x, y *Int) *Int {
return z return z
} }
// Sub sets z to the difference x-y and returns z. // Sub sets z to the difference x-y and returns z.
func (z *Int) Sub(x, y *Int) *Int { func (z *Int) Sub(x, y *Int) *Int {
neg := x.neg neg := x.neg
@ -126,7 +119,6 @@ func (z *Int) Sub(x, y *Int) *Int {
return z return z
} }
// Mul sets z to the product x*y and returns z. // Mul sets z to the product x*y and returns z.
func (z *Int) Mul(x, y *Int) *Int { func (z *Int) Mul(x, y *Int) *Int {
// x * y == x * y // x * y == x * y
@ -138,7 +130,6 @@ func (z *Int) Mul(x, y *Int) *Int {
return z return z
} }
// MulRange sets z to the product of all integers // MulRange sets z to the product of all integers
// in the range [a, b] inclusively and returns z. // in the range [a, b] inclusively and returns z.
// If a > b (empty range), the result is 1. // If a > b (empty range), the result is 1.
@ -162,7 +153,6 @@ func (z *Int) MulRange(a, b int64) *Int {
return z return z
} }
// Binomial sets z to the binomial coefficient of (n, k) and returns z. // Binomial sets z to the binomial coefficient of (n, k) and returns z.
func (z *Int) Binomial(n, k int64) *Int { func (z *Int) Binomial(n, k int64) *Int {
var a, b Int var a, b Int
@ -171,7 +161,6 @@ func (z *Int) Binomial(n, k int64) *Int {
return z.Quo(&a, &b) return z.Quo(&a, &b)
} }
// Quo sets z to the quotient x/y for y != 0 and returns z. // Quo sets z to the quotient x/y for y != 0 and returns z.
// If y == 0, a division-by-zero run-time panic occurs. // If y == 0, a division-by-zero run-time panic occurs.
// See QuoRem for more details. // See QuoRem for more details.
@ -181,7 +170,6 @@ func (z *Int) Quo(x, y *Int) *Int {
return z return z
} }
// Rem sets z to the remainder x%y for y != 0 and returns z. // Rem sets z to the remainder x%y for y != 0 and returns z.
// If y == 0, a division-by-zero run-time panic occurs. // If y == 0, a division-by-zero run-time panic occurs.
// See QuoRem for more details. // See QuoRem for more details.
@ -191,7 +179,6 @@ func (z *Int) Rem(x, y *Int) *Int {
return z return z
} }
// QuoRem sets z to the quotient x/y and r to the remainder x%y // QuoRem sets z to the quotient x/y and r to the remainder x%y
// and returns the pair (z, r) for y != 0. // and returns the pair (z, r) for y != 0.
// If y == 0, a division-by-zero run-time panic occurs. // If y == 0, a division-by-zero run-time panic occurs.
@ -209,7 +196,6 @@ func (z *Int) QuoRem(x, y, r *Int) (*Int, *Int) {
return z, r return z, r
} }
// Div sets z to the quotient x/y for y != 0 and returns z. // Div sets z to the quotient x/y for y != 0 and returns z.
// If y == 0, a division-by-zero run-time panic occurs. // If y == 0, a division-by-zero run-time panic occurs.
// See DivMod for more details. // See DivMod for more details.
@ -227,7 +213,6 @@ func (z *Int) Div(x, y *Int) *Int {
return z return z
} }
// Mod sets z to the modulus x%y for y != 0 and returns z. // Mod sets z to the modulus x%y for y != 0 and returns z.
// If y == 0, a division-by-zero run-time panic occurs. // If y == 0, a division-by-zero run-time panic occurs.
// See DivMod for more details. // See DivMod for more details.
@ -248,7 +233,6 @@ func (z *Int) Mod(x, y *Int) *Int {
return z return z
} }
// DivMod sets z to the quotient x div y and m to the modulus x mod y // DivMod sets z to the quotient x div y and m to the modulus x mod y
// and returns the pair (z, m) for y != 0. // and returns the pair (z, m) for y != 0.
// If y == 0, a division-by-zero run-time panic occurs. // If y == 0, a division-by-zero run-time panic occurs.
@ -281,7 +265,6 @@ func (z *Int) DivMod(x, y, m *Int) (*Int, *Int) {
return z, m return z, m
} }
// Cmp compares x and y and returns: // Cmp compares x and y and returns:
// //
// -1 if x < y // -1 if x < y
@ -307,49 +290,197 @@ func (x *Int) Cmp(y *Int) (r int) {
return return
} }
func (x *Int) String() string { func (x *Int) String() string {
s := "" switch {
if x.neg { case x == nil:
s = "-" return "<nil>"
case x.neg:
return "-" + x.abs.decimalString()
} }
return s + x.abs.string(10) return x.abs.decimalString()
} }
func charset(ch int) string {
func fmtbase(ch int) int {
switch ch { switch ch {
case 'b': case 'b':
return 2 return lowercaseDigits[0:2]
case 'o': case 'o':
return 8 return lowercaseDigits[0:8]
case 'd': case 'd', 's', 'v':
return 10 return lowercaseDigits[0:10]
case 'x': case 'x':
return 16 return lowercaseDigits[0:16]
case 'X':
return uppercaseDigits[0:16]
} }
return 10 return "" // unknown format
} }
// write count copies of text to s
func writeMultiple(s fmt.State, text string, count int) {
if len(text) > 0 {
b := []byte(text)
for ; count > 0; count-- {
s.Write(b)
}
}
}
// Format is a support routine for fmt.Formatter. It accepts // Format is a support routine for fmt.Formatter. It accepts
// the formats 'b' (binary), 'o' (octal), 'd' (decimal) and // the formats 'b' (binary), 'o' (octal), 'd' (decimal), 'x'
// 'x' (hexadecimal). // (lowercase hexadecimal), and 'X' (uppercase hexadecimal).
// Also supported are the full suite of package fmt's format
// verbs for integral types, including '+', '-', and ' '
// for sign control, '#' for leading zero in octal and for
// hexadecimal, a leading "0x" or "0X" for "%#x" and "%#X"
// respectively, specification of minimum digits precision,
// output field width, space or zero padding, and left or
// right justification.
// //
func (x *Int) Format(s fmt.State, ch int) { func (x *Int) Format(s fmt.State, ch int) {
if x == nil { cs := charset(ch)
// special cases
switch {
case cs == "":
// unknown format
fmt.Fprintf(s, "%%!%c(big.Int=%s)", ch, x.String())
return
case x == nil:
fmt.Fprint(s, "<nil>") fmt.Fprint(s, "<nil>")
return return
} }
if x.neg {
fmt.Fprint(s, "-") // determine sign character
sign := ""
switch {
case x.neg:
sign = "-"
case s.Flag('+'): // supersedes ' ' when both specified
sign = "+"
case s.Flag(' '):
sign = " "
} }
fmt.Fprint(s, x.abs.string(fmtbase(ch)))
// determine prefix characters for indicating output base
prefix := ""
if s.Flag('#') {
switch ch {
case 'o': // octal
prefix = "0"
case 'x': // hexadecimal
prefix = "0x"
case 'X':
prefix = "0X"
}
}
// determine digits with base set by len(cs) and digit characters from cs
digits := x.abs.string(cs)
// number of characters for the three classes of number padding
var left int // space characters to left of digits for right justification ("%8d")
var zeroes int // zero characters (actually cs[0]) as left-most digits ("%.8d")
var right int // space characters to right of digits for left justification ("%-8d")
// determine number padding from precision: the least number of digits to output
precision, precisionSet := s.Precision()
if precisionSet {
switch {
case len(digits) < precision:
zeroes = precision - len(digits) // count of zero padding
case digits == "0" && precision == 0:
return // print nothing if zero value (x == 0) and zero precision ("." or ".0")
}
}
// determine field pad from width: the least number of characters to output
length := len(sign) + len(prefix) + zeroes + len(digits)
if width, widthSet := s.Width(); widthSet && length < width { // pad as specified
switch d := width - length; {
case s.Flag('-'):
// pad on the right with spaces; supersedes '0' when both specified
right = d
case s.Flag('0') && !precisionSet:
// pad with zeroes unless precision also specified
zeroes = d
default:
// pad on the left with spaces
left = d
}
}
// print number as [left pad][sign][prefix][zero pad][digits][right pad]
writeMultiple(s, " ", left)
writeMultiple(s, sign, 1)
writeMultiple(s, prefix, 1)
writeMultiple(s, "0", zeroes)
writeMultiple(s, digits, 1)
writeMultiple(s, " ", right)
} }
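A short caller-side sketch of the expanded verb support; the expected outputs are taken from the format test table added later in this change, and the import path is the pre-Go 1 "big".

package main

import (
	"big" // pre-Go 1 path; later releases use "math/big"
	"fmt"
)

func main() {
	x := big.NewInt(1234)
	fmt.Printf("%d\n", x)      // 1234
	fmt.Printf("%#x\n", x)     // 0x4d2
	fmt.Printf("%X\n", x)      // 4D2
	fmt.Printf("%+06d\n", x)   // +01234
	fmt.Printf("%8.6d|\n", x)  //   001234|
	fmt.Printf("%-8.6d|\n", x) // 001234  |
}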
// scan sets z to the integer value corresponding to the longest possible prefix
// read from r representing a signed integer number in a given conversion base.
// It returns z, the actual conversion base used, and an error, if any. In the
// error case, the value of z is undefined. The syntax follows the syntax of
// integer literals in Go.
//
// The base argument must be 0 or a value from 2 through MaxBase. If the base
// is 0, the string prefix determines the actual conversion base. A prefix of
// ``0x'' or ``0X'' selects base 16; the ``0'' prefix selects base 8, and a
// ``0b'' or ``0B'' prefix selects base 2. Otherwise the selected base is 10.
//
func (z *Int) scan(r io.RuneScanner, base int) (*Int, int, os.Error) {
// determine sign
ch, _, err := r.ReadRune()
if err != nil {
return z, 0, err
}
neg := false
switch ch {
case '-':
neg = true
case '+': // nothing to do
default:
r.UnreadRune()
}
// Int64 returns the int64 representation of z. // determine mantissa
// If z cannot be represented in an int64, the result is undefined. z.abs, base, err = z.abs.scan(r, base)
if err != nil {
return z, base, err
}
z.neg = len(z.abs) > 0 && neg // 0 has no sign
return z, base, nil
}
// Scan is a support routine for fmt.Scanner; it sets z to the value of
// the scanned number. It accepts the formats 'b' (binary), 'o' (octal),
// 'd' (decimal), 'x' (lowercase hexadecimal), and 'X' (uppercase hexadecimal).
func (z *Int) Scan(s fmt.ScanState, ch int) os.Error {
s.SkipSpace() // skip leading space characters
base := 0
switch ch {
case 'b':
base = 2
case 'o':
base = 8
case 'd':
base = 10
case 'x', 'X':
base = 16
case 's', 'v':
// let scan determine the base
default:
return os.NewError("Int.Scan: invalid verb")
}
_, _, err := z.scan(s, base)
return err
}
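A minimal sketch of scanning into an Int via package fmt, which is what the new Scan method enables (pre-Go 1 import paths assumed):

package main

import (
	"big" // pre-Go 1 path; later releases use "math/big"
	"fmt"
)

func main() {
	z := new(big.Int)
	// The 'v' verb used by Sscan lets scan pick the base itself.
	if _, err := fmt.Sscan("18446744073709551617", z); err != nil {
		fmt.Println("scan failed:", err)
		return
	}
	fmt.Println(z) // 18446744073709551617, i.e. 2^64 + 1

	// An explicit verb fixes the base: %x reads hexadecimal.
	if _, err := fmt.Sscanf("1b", "%x", z); err == nil {
		fmt.Println(z) // 27
	}
}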
// Int64 returns the int64 representation of x.
// If x cannot be represented in an int64, the result is undefined.
func (x *Int) Int64() int64 { func (x *Int) Int64() int64 {
if len(x.abs) == 0 { if len(x.abs) == 0 {
return 0 return 0
@ -364,40 +495,25 @@ func (x *Int) Int64() int64 {
return v return v
} }
// SetString sets z to the value of s, interpreted in the given base, // SetString sets z to the value of s, interpreted in the given base,
// and returns z and a boolean indicating success. If SetString fails, // and returns z and a boolean indicating success. If SetString fails,
// the value of z is undefined. // the value of z is undefined.
// //
// If the base argument is 0, the string prefix determines the actual // The base argument must be 0 or a value from 2 through MaxBase. If the base
// conversion base. A prefix of ``0x'' or ``0X'' selects base 16; the // is 0, the string prefix determines the actual conversion base. A prefix of
// ``0'' prefix selects base 8, and a ``0b'' or ``0B'' prefix selects // ``0x'' or ``0X'' selects base 16; the ``0'' prefix selects base 8, and a
// base 2. Otherwise the selected base is 10. // ``0b'' or ``0B'' prefix selects base 2. Otherwise the selected base is 10.
// //
func (z *Int) SetString(s string, base int) (*Int, bool) { func (z *Int) SetString(s string, base int) (*Int, bool) {
if len(s) == 0 || base < 0 || base == 1 || 16 < base { r := strings.NewReader(s)
_, _, err := z.scan(r, base)
if err != nil {
return z, false return z, false
} }
_, _, err = r.ReadRune()
neg := s[0] == '-' return z, err == os.EOF // err == os.EOF => scan consumed all of s
if neg || s[0] == '+' {
s = s[1:]
if len(s) == 0 {
return z, false
}
}
var scanned int
z.abs, _, scanned = z.abs.scan(s, base)
if scanned != len(s) {
return z, false
}
z.neg = len(z.abs) > 0 && neg // 0 has no sign
return z, true
} }
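The prefix rules read more easily with a couple of concrete calls (a sketch; pre-Go 1 import paths assumed):

package main

import (
	"big" // pre-Go 1 path; later releases use "math/big"
	"fmt"
)

func main() {
	z := new(big.Int)
	if _, ok := z.SetString("0b1010", 0); ok { // base 0: the 0b prefix selects binary
		fmt.Println(z) // 10
	}
	if _, ok := z.SetString("-ff", 16); ok { // explicit base 16
		fmt.Println(z) // -255
	}
	if _, ok := z.SetString("12ab", 10); !ok { // trailing non-digits are not consumed, so ok is false
		fmt.Println("not a base-10 number")
	}
}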
// SetBytes interprets buf as the bytes of a big-endian unsigned // SetBytes interprets buf as the bytes of a big-endian unsigned
// integer, sets z to that value, and returns z. // integer, sets z to that value, and returns z.
func (z *Int) SetBytes(buf []byte) *Int { func (z *Int) SetBytes(buf []byte) *Int {
@ -406,21 +522,18 @@ func (z *Int) SetBytes(buf []byte) *Int {
return z return z
} }
// Bytes returns the absolute value of z as a big-endian byte slice. // Bytes returns the absolute value of z as a big-endian byte slice.
func (z *Int) Bytes() []byte { func (z *Int) Bytes() []byte {
buf := make([]byte, len(z.abs)*_S) buf := make([]byte, len(z.abs)*_S)
return buf[z.abs.bytes(buf):] return buf[z.abs.bytes(buf):]
} }
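As a quick check of the big-endian convention: SetBytes([]byte{0x01, 0x00}) yields the value 256, and Bytes() on an Int holding 256 returns the two bytes 0x01, 0x00 again.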
// BitLen returns the length of the absolute value of z in bits. // BitLen returns the length of the absolute value of z in bits.
// The bit length of 0 is 0. // The bit length of 0 is 0.
func (z *Int) BitLen() int { func (z *Int) BitLen() int {
return z.abs.bitLen() return z.abs.bitLen()
} }
// Exp sets z = x**y mod m. If m is nil, z = x**y. // Exp sets z = x**y mod m. If m is nil, z = x**y.
// See Knuth, volume 2, section 4.6.3. // See Knuth, volume 2, section 4.6.3.
func (z *Int) Exp(x, y, m *Int) *Int { func (z *Int) Exp(x, y, m *Int) *Int {
@ -441,7 +554,6 @@ func (z *Int) Exp(x, y, m *Int) *Int {
return z return z
} }
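For instance, z.Exp(big.NewInt(4), big.NewInt(13), big.NewInt(497)) leaves z = 445, since 4^13 = 67108864 = 135027*497 + 445; the same call with a nil modulus computes the full 67108864.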
// GcdInt sets d to the greatest common divisor of a and b, which must be // GcdInt sets d to the greatest common divisor of a and b, which must be
// positive numbers. // positive numbers.
// If x and y are not nil, GcdInt sets x and y such that d = a*x + b*y. // If x and y are not nil, GcdInt sets x and y such that d = a*x + b*y.
@ -500,7 +612,6 @@ func GcdInt(d, x, y, a, b *Int) {
*d = *A *d = *A
} }
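For example, GcdInt(d, x, y, big.NewInt(12), big.NewInt(18)) sets d = 6 and fills in x and y with some pair satisfying 12*x + 18*y = 6 (x = -1, y = 1 is one such pair; which pair is produced is an implementation detail).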
// ProbablyPrime performs n Miller-Rabin tests to check whether z is prime. // ProbablyPrime performs n Miller-Rabin tests to check whether z is prime.
// If it returns true, z is prime with probability 1 - 1/4^n. // If it returns true, z is prime with probability 1 - 1/4^n.
// If it returns false, z is not prime. // If it returns false, z is not prime.
@ -508,7 +619,6 @@ func ProbablyPrime(z *Int, n int) bool {
return !z.neg && z.abs.probablyPrime(n) return !z.neg && z.abs.probablyPrime(n)
} }
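For a sense of scale, n = 20 already bounds the chance of a composite passing every round at 1/4^20 = 2^-40, roughly 9.1e-13.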
// Rand sets z to a pseudo-random number in [0, n) and returns z. // Rand sets z to a pseudo-random number in [0, n) and returns z.
func (z *Int) Rand(rnd *rand.Rand, n *Int) *Int { func (z *Int) Rand(rnd *rand.Rand, n *Int) *Int {
z.neg = false z.neg = false
@ -520,7 +630,6 @@ func (z *Int) Rand(rnd *rand.Rand, n *Int) *Int {
return z return z
} }
// ModInverse sets z to the multiplicative inverse of g in the group ℤ/pℤ (where // ModInverse sets z to the multiplicative inverse of g in the group ℤ/pℤ (where
// p is a prime) and returns z. // p is a prime) and returns z.
func (z *Int) ModInverse(g, p *Int) *Int { func (z *Int) ModInverse(g, p *Int) *Int {
@ -534,7 +643,6 @@ func (z *Int) ModInverse(g, p *Int) *Int {
return z return z
} }
// Lsh sets z = x << n and returns z. // Lsh sets z = x << n and returns z.
func (z *Int) Lsh(x *Int, n uint) *Int { func (z *Int) Lsh(x *Int, n uint) *Int {
z.abs = z.abs.shl(x.abs, n) z.abs = z.abs.shl(x.abs, n)
@ -542,7 +650,6 @@ func (z *Int) Lsh(x *Int, n uint) *Int {
return z return z
} }
// Rsh sets z = x >> n and returns z. // Rsh sets z = x >> n and returns z.
func (z *Int) Rsh(x *Int, n uint) *Int { func (z *Int) Rsh(x *Int, n uint) *Int {
if x.neg { if x.neg {
@ -559,6 +666,39 @@ func (z *Int) Rsh(x *Int, n uint) *Int {
return z return z
} }
// Bit returns the value of the i'th bit of z. That is, it
// returns (z>>i)&1. The bit index i must be >= 0.
func (z *Int) Bit(i int) uint {
if i < 0 {
panic("negative bit index")
}
if z.neg {
t := nat{}.sub(z.abs, natOne)
return t.bit(uint(i)) ^ 1
}
return z.abs.bit(uint(i))
}
// SetBit sets the i'th bit of z to bit and returns z.
// That is, if bit is 1 SetBit sets z = x | (1 << i);
// if bit is 0 it sets z = x &^ (1 << i). If bit is not 0 or 1,
// SetBit will panic.
func (z *Int) SetBit(x *Int, i int, b uint) *Int {
if i < 0 {
panic("negative bit index")
}
if x.neg {
t := z.abs.sub(x.abs, natOne)
t = t.setBit(t, uint(i), b^1)
z.abs = t.add(t, natOne)
z.neg = len(z.abs) > 0
return z
}
z.abs = z.abs.setBit(x.abs, uint(i), b)
z.neg = false
return z
}
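A small sketch of the two new bit accessors, including the two's-complement view they take of negative values (pre-Go 1 import paths assumed):

package main

import (
	"big" // pre-Go 1 path; later releases use "math/big"
	"fmt"
)

func main() {
	x := big.NewInt(0)
	z := new(big.Int).SetBit(x, 3, 1) // z = x with bit 3 set, i.e. 8
	z.SetBit(z, 0, 1)                 // also set bit 0: z = 9
	fmt.Println(z, z.Bit(3), z.Bit(2)) // 9 1 0

	// Negative values are treated as infinite two's complement,
	// so every bit of -1 reads as 1.
	fmt.Println(big.NewInt(-1).Bit(100)) // 1
}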
// And sets z = x & y and returns z. // And sets z = x & y and returns z.
func (z *Int) And(x, y *Int) *Int { func (z *Int) And(x, y *Int) *Int {
@ -590,7 +730,6 @@ func (z *Int) And(x, y *Int) *Int {
return z return z
} }
// AndNot sets z = x &^ y and returns z. // AndNot sets z = x &^ y and returns z.
func (z *Int) AndNot(x, y *Int) *Int { func (z *Int) AndNot(x, y *Int) *Int {
if x.neg == y.neg { if x.neg == y.neg {
@ -624,7 +763,6 @@ func (z *Int) AndNot(x, y *Int) *Int {
return z return z
} }
// Or sets z = x | y and returns z. // Or sets z = x | y and returns z.
func (z *Int) Or(x, y *Int) *Int { func (z *Int) Or(x, y *Int) *Int {
if x.neg == y.neg { if x.neg == y.neg {
@ -655,7 +793,6 @@ func (z *Int) Or(x, y *Int) *Int {
return z return z
} }
// Xor sets z = x ^ y and returns z. // Xor sets z = x ^ y and returns z.
func (z *Int) Xor(x, y *Int) *Int { func (z *Int) Xor(x, y *Int) *Int {
if x.neg == y.neg { if x.neg == y.neg {
@ -686,7 +823,6 @@ func (z *Int) Xor(x, y *Int) *Int {
return z return z
} }
// Not sets z = ^x and returns z. // Not sets z = ^x and returns z.
func (z *Int) Not(x *Int) *Int { func (z *Int) Not(x *Int) *Int {
if x.neg { if x.neg {
@ -702,15 +838,14 @@ func (z *Int) Not(x *Int) *Int {
return z return z
} }
// Gob codec version. Permits backward-compatible changes to the encoding. // Gob codec version. Permits backward-compatible changes to the encoding.
const version byte = 1 const intGobVersion byte = 1
// GobEncode implements the gob.GobEncoder interface. // GobEncode implements the gob.GobEncoder interface.
func (z *Int) GobEncode() ([]byte, os.Error) { func (z *Int) GobEncode() ([]byte, os.Error) {
buf := make([]byte, len(z.abs)*_S+1) // extra byte for version and sign bit buf := make([]byte, 1+len(z.abs)*_S) // extra byte for version and sign bit
i := z.abs.bytes(buf) - 1 // i >= 0 i := z.abs.bytes(buf) - 1 // i >= 0
b := version << 1 // make space for sign bit b := intGobVersion << 1 // make space for sign bit
if z.neg { if z.neg {
b |= 1 b |= 1
} }
@ -718,14 +853,13 @@ func (z *Int) GobEncode() ([]byte, os.Error) {
return buf[i:], nil return buf[i:], nil
} }
// GobDecode implements the gob.GobDecoder interface. // GobDecode implements the gob.GobDecoder interface.
func (z *Int) GobDecode(buf []byte) os.Error { func (z *Int) GobDecode(buf []byte) os.Error {
if len(buf) == 0 { if len(buf) == 0 {
return os.NewError("Int.GobDecode: no data") return os.NewError("Int.GobDecode: no data")
} }
b := buf[0] b := buf[0]
if b>>1 != version { if b>>1 != intGobVersion {
return os.NewError(fmt.Sprintf("Int.GobDecode: encoding version %d not supported", b>>1)) return os.NewError(fmt.Sprintf("Int.GobDecode: encoding version %d not supported", b>>1))
} }
z.neg = b&1 != 0 z.neg = b&1 != 0

View File

@ -13,7 +13,6 @@ import (
"testing/quick" "testing/quick"
) )
func isNormalized(x *Int) bool { func isNormalized(x *Int) bool {
if len(x.abs) == 0 { if len(x.abs) == 0 {
return !x.neg return !x.neg
@ -22,13 +21,11 @@ func isNormalized(x *Int) bool {
return x.abs[len(x.abs)-1] != 0 return x.abs[len(x.abs)-1] != 0
} }
type funZZ func(z, x, y *Int) *Int type funZZ func(z, x, y *Int) *Int
type argZZ struct { type argZZ struct {
z, x, y *Int z, x, y *Int
} }
var sumZZ = []argZZ{ var sumZZ = []argZZ{
{NewInt(0), NewInt(0), NewInt(0)}, {NewInt(0), NewInt(0), NewInt(0)},
{NewInt(1), NewInt(1), NewInt(0)}, {NewInt(1), NewInt(1), NewInt(0)},
@ -38,7 +35,6 @@ var sumZZ = []argZZ{
{NewInt(-1111111110), NewInt(-123456789), NewInt(-987654321)}, {NewInt(-1111111110), NewInt(-123456789), NewInt(-987654321)},
} }
var prodZZ = []argZZ{ var prodZZ = []argZZ{
{NewInt(0), NewInt(0), NewInt(0)}, {NewInt(0), NewInt(0), NewInt(0)},
{NewInt(0), NewInt(1), NewInt(0)}, {NewInt(0), NewInt(1), NewInt(0)},
@ -47,7 +43,6 @@ var prodZZ = []argZZ{
// TODO(gri) add larger products // TODO(gri) add larger products
} }
func TestSignZ(t *testing.T) { func TestSignZ(t *testing.T) {
var zero Int var zero Int
for _, a := range sumZZ { for _, a := range sumZZ {
@ -59,7 +54,6 @@ func TestSignZ(t *testing.T) {
} }
} }
func TestSetZ(t *testing.T) { func TestSetZ(t *testing.T) {
for _, a := range sumZZ { for _, a := range sumZZ {
var z Int var z Int
@ -73,7 +67,6 @@ func TestSetZ(t *testing.T) {
} }
} }
func TestAbsZ(t *testing.T) { func TestAbsZ(t *testing.T) {
var zero Int var zero Int
for _, a := range sumZZ { for _, a := range sumZZ {
@ -90,7 +83,6 @@ func TestAbsZ(t *testing.T) {
} }
} }
func testFunZZ(t *testing.T, msg string, f funZZ, a argZZ) { func testFunZZ(t *testing.T, msg string, f funZZ, a argZZ) {
var z Int var z Int
f(&z, a.x, a.y) f(&z, a.x, a.y)
@ -102,7 +94,6 @@ func testFunZZ(t *testing.T, msg string, f funZZ, a argZZ) {
} }
} }
func TestSumZZ(t *testing.T) { func TestSumZZ(t *testing.T) {
AddZZ := func(z, x, y *Int) *Int { return z.Add(x, y) } AddZZ := func(z, x, y *Int) *Int { return z.Add(x, y) }
SubZZ := func(z, x, y *Int) *Int { return z.Sub(x, y) } SubZZ := func(z, x, y *Int) *Int { return z.Sub(x, y) }
@ -121,7 +112,6 @@ func TestSumZZ(t *testing.T) {
} }
} }
func TestProdZZ(t *testing.T) { func TestProdZZ(t *testing.T) {
MulZZ := func(z, x, y *Int) *Int { return z.Mul(x, y) } MulZZ := func(z, x, y *Int) *Int { return z.Mul(x, y) }
for _, a := range prodZZ { for _, a := range prodZZ {
@ -133,7 +123,6 @@ func TestProdZZ(t *testing.T) {
} }
} }
// mulBytes returns x*y via grade school multiplication. Both inputs // mulBytes returns x*y via grade school multiplication. Both inputs
// and the result are assumed to be in big-endian representation (to // and the result are assumed to be in big-endian representation (to
// match the semantics of Int.Bytes and Int.SetBytes). // match the semantics of Int.Bytes and Int.SetBytes).
@ -166,7 +155,6 @@ func mulBytes(x, y []byte) []byte {
return z[i:] return z[i:]
} }
func checkMul(a, b []byte) bool { func checkMul(a, b []byte) bool {
var x, y, z1 Int var x, y, z1 Int
x.SetBytes(a) x.SetBytes(a)
@ -179,14 +167,12 @@ func checkMul(a, b []byte) bool {
return z1.Cmp(&z2) == 0 return z1.Cmp(&z2) == 0
} }
func TestMul(t *testing.T) { func TestMul(t *testing.T) {
if err := quick.Check(checkMul, nil); err != nil { if err := quick.Check(checkMul, nil); err != nil {
t.Error(err) t.Error(err)
} }
} }
var mulRangesZ = []struct { var mulRangesZ = []struct {
a, b int64 a, b int64
prod string prod string
@ -212,7 +198,6 @@ var mulRangesZ = []struct {
}, },
} }
func TestMulRangeZ(t *testing.T) { func TestMulRangeZ(t *testing.T) {
var tmp Int var tmp Int
// test entirely positive ranges // test entirely positive ranges
@ -231,7 +216,6 @@ func TestMulRangeZ(t *testing.T) {
} }
} }
var stringTests = []struct { var stringTests = []struct {
in string in string
out string out string
@ -280,7 +264,6 @@ var stringTests = []struct {
{"1001010111", "1001010111", 2, 0x257, true}, {"1001010111", "1001010111", 2, 0x257, true},
} }
func format(base int) string { func format(base int) string {
switch base { switch base {
case 2: case 2:
@ -293,7 +276,6 @@ func format(base int) string {
return "%d" return "%d"
} }
func TestGetString(t *testing.T) { func TestGetString(t *testing.T) {
z := new(Int) z := new(Int)
for i, test := range stringTests { for i, test := range stringTests {
@ -316,7 +298,6 @@ func TestGetString(t *testing.T) {
} }
} }
func TestSetString(t *testing.T) { func TestSetString(t *testing.T) {
tmp := new(Int) tmp := new(Int)
for i, test := range stringTests { for i, test := range stringTests {
@ -347,6 +328,212 @@ func TestSetString(t *testing.T) {
} }
} }
var formatTests = []struct {
input string
format string
output string
}{
{"<nil>", "%x", "<nil>"},
{"<nil>", "%#x", "<nil>"},
{"<nil>", "%#y", "%!y(big.Int=<nil>)"},
{"10", "%b", "1010"},
{"10", "%o", "12"},
{"10", "%d", "10"},
{"10", "%v", "10"},
{"10", "%x", "a"},
{"10", "%X", "A"},
{"-10", "%X", "-A"},
{"10", "%y", "%!y(big.Int=10)"},
{"-10", "%y", "%!y(big.Int=-10)"},
{"10", "%#b", "1010"},
{"10", "%#o", "012"},
{"10", "%#d", "10"},
{"10", "%#v", "10"},
{"10", "%#x", "0xa"},
{"10", "%#X", "0XA"},
{"-10", "%#X", "-0XA"},
{"10", "%#y", "%!y(big.Int=10)"},
{"-10", "%#y", "%!y(big.Int=-10)"},
{"1234", "%d", "1234"},
{"1234", "%3d", "1234"},
{"1234", "%4d", "1234"},
{"-1234", "%d", "-1234"},
{"1234", "% 5d", " 1234"},
{"1234", "%+5d", "+1234"},
{"1234", "%-5d", "1234 "},
{"1234", "%x", "4d2"},
{"1234", "%X", "4D2"},
{"-1234", "%3x", "-4d2"},
{"-1234", "%4x", "-4d2"},
{"-1234", "%5x", " -4d2"},
{"-1234", "%-5x", "-4d2 "},
{"1234", "%03d", "1234"},
{"1234", "%04d", "1234"},
{"1234", "%05d", "01234"},
{"1234", "%06d", "001234"},
{"-1234", "%06d", "-01234"},
{"1234", "%+06d", "+01234"},
{"1234", "% 06d", " 01234"},
{"1234", "%-6d", "1234 "},
{"1234", "%-06d", "1234 "},
{"-1234", "%-06d", "-1234 "},
{"1234", "%.3d", "1234"},
{"1234", "%.4d", "1234"},
{"1234", "%.5d", "01234"},
{"1234", "%.6d", "001234"},
{"-1234", "%.3d", "-1234"},
{"-1234", "%.4d", "-1234"},
{"-1234", "%.5d", "-01234"},
{"-1234", "%.6d", "-001234"},
{"1234", "%8.3d", " 1234"},
{"1234", "%8.4d", " 1234"},
{"1234", "%8.5d", " 01234"},
{"1234", "%8.6d", " 001234"},
{"-1234", "%8.3d", " -1234"},
{"-1234", "%8.4d", " -1234"},
{"-1234", "%8.5d", " -01234"},
{"-1234", "%8.6d", " -001234"},
{"1234", "%+8.3d", " +1234"},
{"1234", "%+8.4d", " +1234"},
{"1234", "%+8.5d", " +01234"},
{"1234", "%+8.6d", " +001234"},
{"-1234", "%+8.3d", " -1234"},
{"-1234", "%+8.4d", " -1234"},
{"-1234", "%+8.5d", " -01234"},
{"-1234", "%+8.6d", " -001234"},
{"1234", "% 8.3d", " 1234"},
{"1234", "% 8.4d", " 1234"},
{"1234", "% 8.5d", " 01234"},
{"1234", "% 8.6d", " 001234"},
{"-1234", "% 8.3d", " -1234"},
{"-1234", "% 8.4d", " -1234"},
{"-1234", "% 8.5d", " -01234"},
{"-1234", "% 8.6d", " -001234"},
{"1234", "%.3x", "4d2"},
{"1234", "%.4x", "04d2"},
{"1234", "%.5x", "004d2"},
{"1234", "%.6x", "0004d2"},
{"-1234", "%.3x", "-4d2"},
{"-1234", "%.4x", "-04d2"},
{"-1234", "%.5x", "-004d2"},
{"-1234", "%.6x", "-0004d2"},
{"1234", "%8.3x", " 4d2"},
{"1234", "%8.4x", " 04d2"},
{"1234", "%8.5x", " 004d2"},
{"1234", "%8.6x", " 0004d2"},
{"-1234", "%8.3x", " -4d2"},
{"-1234", "%8.4x", " -04d2"},
{"-1234", "%8.5x", " -004d2"},
{"-1234", "%8.6x", " -0004d2"},
{"1234", "%+8.3x", " +4d2"},
{"1234", "%+8.4x", " +04d2"},
{"1234", "%+8.5x", " +004d2"},
{"1234", "%+8.6x", " +0004d2"},
{"-1234", "%+8.3x", " -4d2"},
{"-1234", "%+8.4x", " -04d2"},
{"-1234", "%+8.5x", " -004d2"},
{"-1234", "%+8.6x", " -0004d2"},
{"1234", "% 8.3x", " 4d2"},
{"1234", "% 8.4x", " 04d2"},
{"1234", "% 8.5x", " 004d2"},
{"1234", "% 8.6x", " 0004d2"},
{"1234", "% 8.7x", " 00004d2"},
{"1234", "% 8.8x", " 000004d2"},
{"-1234", "% 8.3x", " -4d2"},
{"-1234", "% 8.4x", " -04d2"},
{"-1234", "% 8.5x", " -004d2"},
{"-1234", "% 8.6x", " -0004d2"},
{"-1234", "% 8.7x", "-00004d2"},
{"-1234", "% 8.8x", "-000004d2"},
{"1234", "%-8.3d", "1234 "},
{"1234", "%-8.4d", "1234 "},
{"1234", "%-8.5d", "01234 "},
{"1234", "%-8.6d", "001234 "},
{"1234", "%-8.7d", "0001234 "},
{"1234", "%-8.8d", "00001234"},
{"-1234", "%-8.3d", "-1234 "},
{"-1234", "%-8.4d", "-1234 "},
{"-1234", "%-8.5d", "-01234 "},
{"-1234", "%-8.6d", "-001234 "},
{"-1234", "%-8.7d", "-0001234"},
{"-1234", "%-8.8d", "-00001234"},
{"16777215", "%b", "111111111111111111111111"}, // 2**24 - 1
{"0", "%.d", ""},
{"0", "%.0d", ""},
{"0", "%3.d", ""},
}
func TestFormat(t *testing.T) {
for i, test := range formatTests {
var x *Int
if test.input != "<nil>" {
var ok bool
x, ok = new(Int).SetString(test.input, 0)
if !ok {
t.Errorf("#%d failed reading input %s", i, test.input)
}
}
output := fmt.Sprintf(test.format, x)
if output != test.output {
t.Errorf("#%d got %q; want %q, {%q, %q, %q}", i, output, test.output, test.input, test.format, test.output)
}
}
}
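For orientation, the verbs exercised by formatTests are reachable through the public API; a minimal sketch, assuming the post-r60 import path math/big (in this tree the package is plain big):

package main

import (
	"fmt"
	"math/big"
)

func main() {
	x := big.NewInt(1234)
	// width, zero padding, precision and the '#' flag behave as in the table above
	fmt.Printf("%x %#x %05d %.6d\n", x, x, x, x)
	// Output: 4d2 0x4d2 01234 001234
}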
var scanTests = []struct {
input string
format string
output string
remaining int
}{
{"1010", "%b", "10", 0},
{"0b1010", "%v", "10", 0},
{"12", "%o", "10", 0},
{"012", "%v", "10", 0},
{"10", "%d", "10", 0},
{"10", "%v", "10", 0},
{"a", "%x", "10", 0},
{"0xa", "%v", "10", 0},
{"A", "%X", "10", 0},
{"-A", "%X", "-10", 0},
{"+0b1011001", "%v", "89", 0},
{"0xA", "%v", "10", 0},
{"0 ", "%v", "0", 1},
{"2+3", "%v", "2", 2},
{"0XABC 12", "%v", "2748", 3},
}
func TestScan(t *testing.T) {
var buf bytes.Buffer
for i, test := range scanTests {
x := new(Int)
buf.Reset()
buf.WriteString(test.input)
if _, err := fmt.Fscanf(&buf, test.format, x); err != nil {
t.Errorf("#%d error: %s", i, err.String())
}
if x.String() != test.output {
t.Errorf("#%d got %s; want %s", i, x.String(), test.output)
}
if buf.Len() != test.remaining {
t.Errorf("#%d got %d bytes remaining; want %d", i, buf.Len(), test.remaining)
}
}
}
// Examples from the Go Language Spec, section "Arithmetic operators" // Examples from the Go Language Spec, section "Arithmetic operators"
var divisionSignsTests = []struct { var divisionSignsTests = []struct {
@ -362,7 +549,6 @@ var divisionSignsTests = []struct {
{8, 4, 2, 0, 2, 0}, {8, 4, 2, 0, 2, 0},
} }
func TestDivisionSigns(t *testing.T) { func TestDivisionSigns(t *testing.T) {
for i, test := range divisionSignsTests { for i, test := range divisionSignsTests {
x := NewInt(test.x) x := NewInt(test.x)
@ -420,7 +606,6 @@ func TestDivisionSigns(t *testing.T) {
} }
} }
func checkSetBytes(b []byte) bool { func checkSetBytes(b []byte) bool {
hex1 := hex.EncodeToString(new(Int).SetBytes(b).Bytes()) hex1 := hex.EncodeToString(new(Int).SetBytes(b).Bytes())
hex2 := hex.EncodeToString(b) hex2 := hex.EncodeToString(b)
@ -436,27 +621,23 @@ func checkSetBytes(b []byte) bool {
return hex1 == hex2 return hex1 == hex2
} }
func TestSetBytes(t *testing.T) { func TestSetBytes(t *testing.T) {
if err := quick.Check(checkSetBytes, nil); err != nil { if err := quick.Check(checkSetBytes, nil); err != nil {
t.Error(err) t.Error(err)
} }
} }
func checkBytes(b []byte) bool { func checkBytes(b []byte) bool {
b2 := new(Int).SetBytes(b).Bytes() b2 := new(Int).SetBytes(b).Bytes()
return bytes.Compare(b, b2) == 0 return bytes.Compare(b, b2) == 0
} }
func TestBytes(t *testing.T) { func TestBytes(t *testing.T) {
if err := quick.Check(checkSetBytes, nil); err != nil { if err := quick.Check(checkSetBytes, nil); err != nil {
t.Error(err) t.Error(err)
} }
} }
func checkQuo(x, y []byte) bool { func checkQuo(x, y []byte) bool {
u := new(Int).SetBytes(x) u := new(Int).SetBytes(x)
v := new(Int).SetBytes(y) v := new(Int).SetBytes(y)
@ -479,7 +660,6 @@ func checkQuo(x, y []byte) bool {
return uprime.Cmp(u) == 0 return uprime.Cmp(u) == 0
} }
var quoTests = []struct { var quoTests = []struct {
x, y string x, y string
q, r string q, r string
@ -498,7 +678,6 @@ var quoTests = []struct {
}, },
} }
func TestQuo(t *testing.T) { func TestQuo(t *testing.T) {
if err := quick.Check(checkQuo, nil); err != nil { if err := quick.Check(checkQuo, nil); err != nil {
t.Error(err) t.Error(err)
@ -519,7 +698,6 @@ func TestQuo(t *testing.T) {
} }
} }
func TestQuoStepD6(t *testing.T) { func TestQuoStepD6(t *testing.T) {
// See Knuth, Volume 2, section 4.3.1, exercise 21. This code exercises // See Knuth, Volume 2, section 4.3.1, exercise 21. This code exercises
// a code path which only triggers roughly 1 in 10^19 cases. // a code path which only triggers roughly 1 in 10^19 cases.
@ -539,7 +717,6 @@ func TestQuoStepD6(t *testing.T) {
} }
} }
var bitLenTests = []struct { var bitLenTests = []struct {
in string in string
out int out int
@ -558,7 +735,6 @@ var bitLenTests = []struct {
{"-0x4000000000000000000000", 87}, {"-0x4000000000000000000000", 87},
} }
func TestBitLen(t *testing.T) { func TestBitLen(t *testing.T) {
for i, test := range bitLenTests { for i, test := range bitLenTests {
x, ok := new(Int).SetString(test.in, 0) x, ok := new(Int).SetString(test.in, 0)
@ -573,7 +749,6 @@ func TestBitLen(t *testing.T) {
} }
} }
var expTests = []struct { var expTests = []struct {
x, y, m string x, y, m string
out string out string
@ -598,7 +773,6 @@ var expTests = []struct {
}, },
} }
func TestExp(t *testing.T) { func TestExp(t *testing.T) {
for i, test := range expTests { for i, test := range expTests {
x, ok1 := new(Int).SetString(test.x, 0) x, ok1 := new(Int).SetString(test.x, 0)
@ -629,7 +803,6 @@ func TestExp(t *testing.T) {
} }
} }
func checkGcd(aBytes, bBytes []byte) bool { func checkGcd(aBytes, bBytes []byte) bool {
a := new(Int).SetBytes(aBytes) a := new(Int).SetBytes(aBytes)
b := new(Int).SetBytes(bBytes) b := new(Int).SetBytes(bBytes)
@ -646,7 +819,6 @@ func checkGcd(aBytes, bBytes []byte) bool {
return x.Cmp(d) == 0 return x.Cmp(d) == 0
} }
var gcdTests = []struct { var gcdTests = []struct {
a, b int64 a, b int64
d, x, y int64 d, x, y int64
@ -654,7 +826,6 @@ var gcdTests = []struct {
{120, 23, 1, -9, 47}, {120, 23, 1, -9, 47},
} }
func TestGcd(t *testing.T) { func TestGcd(t *testing.T) {
for i, test := range gcdTests { for i, test := range gcdTests {
a := NewInt(test.a) a := NewInt(test.a)
@ -680,7 +851,6 @@ func TestGcd(t *testing.T) {
quick.Check(checkGcd, nil) quick.Check(checkGcd, nil)
} }
var primes = []string{ var primes = []string{
"2", "2",
"3", "3",
@ -706,7 +876,6 @@ var primes = []string{
"203956878356401977405765866929034577280193993314348263094772646453283062722701277632936616063144088173312372882677123879538709400158306567338328279154499698366071906766440037074217117805690872792848149112022286332144876183376326512083574821647933992961249917319836219304274280243803104015000563790123", "203956878356401977405765866929034577280193993314348263094772646453283062722701277632936616063144088173312372882677123879538709400158306567338328279154499698366071906766440037074217117805690872792848149112022286332144876183376326512083574821647933992961249917319836219304274280243803104015000563790123",
} }
var composites = []string{ var composites = []string{
"21284175091214687912771199898307297748211672914763848041968395774954376176754", "21284175091214687912771199898307297748211672914763848041968395774954376176754",
"6084766654921918907427900243509372380954290099172559290432744450051395395951", "6084766654921918907427900243509372380954290099172559290432744450051395395951",
@ -714,7 +883,6 @@ var composites = []string{
"82793403787388584738507275144194252681", "82793403787388584738507275144194252681",
} }
func TestProbablyPrime(t *testing.T) { func TestProbablyPrime(t *testing.T) {
nreps := 20 nreps := 20
if testing.Short() { if testing.Short() {
@ -738,14 +906,12 @@ func TestProbablyPrime(t *testing.T) {
} }
} }
type intShiftTest struct { type intShiftTest struct {
in string in string
shift uint shift uint
out string out string
} }
var rshTests = []intShiftTest{ var rshTests = []intShiftTest{
{"0", 0, "0"}, {"0", 0, "0"},
{"-0", 0, "0"}, {"-0", 0, "0"},
@ -773,7 +939,6 @@ var rshTests = []intShiftTest{
{"340282366920938463463374607431768211456", 128, "1"}, {"340282366920938463463374607431768211456", 128, "1"},
} }
func TestRsh(t *testing.T) { func TestRsh(t *testing.T) {
for i, test := range rshTests { for i, test := range rshTests {
in, _ := new(Int).SetString(test.in, 10) in, _ := new(Int).SetString(test.in, 10)
@ -789,7 +954,6 @@ func TestRsh(t *testing.T) {
} }
} }
func TestRshSelf(t *testing.T) { func TestRshSelf(t *testing.T) {
for i, test := range rshTests { for i, test := range rshTests {
z, _ := new(Int).SetString(test.in, 10) z, _ := new(Int).SetString(test.in, 10)
@ -805,7 +969,6 @@ func TestRshSelf(t *testing.T) {
} }
} }
var lshTests = []intShiftTest{ var lshTests = []intShiftTest{
{"0", 0, "0"}, {"0", 0, "0"},
{"0", 1, "0"}, {"0", 1, "0"},
@ -828,7 +991,6 @@ var lshTests = []intShiftTest{
{"1", 128, "340282366920938463463374607431768211456"}, {"1", 128, "340282366920938463463374607431768211456"},
} }
func TestLsh(t *testing.T) { func TestLsh(t *testing.T) {
for i, test := range lshTests { for i, test := range lshTests {
in, _ := new(Int).SetString(test.in, 10) in, _ := new(Int).SetString(test.in, 10)
@ -844,7 +1006,6 @@ func TestLsh(t *testing.T) {
} }
} }
func TestLshSelf(t *testing.T) { func TestLshSelf(t *testing.T) {
for i, test := range lshTests { for i, test := range lshTests {
z, _ := new(Int).SetString(test.in, 10) z, _ := new(Int).SetString(test.in, 10)
@ -860,7 +1021,6 @@ func TestLshSelf(t *testing.T) {
} }
} }
func TestLshRsh(t *testing.T) { func TestLshRsh(t *testing.T) {
for i, test := range rshTests { for i, test := range rshTests {
in, _ := new(Int).SetString(test.in, 10) in, _ := new(Int).SetString(test.in, 10)
@ -888,7 +1048,6 @@ func TestLshRsh(t *testing.T) {
} }
} }
var int64Tests = []int64{ var int64Tests = []int64{
0, 0,
1, 1,
@ -902,7 +1061,6 @@ var int64Tests = []int64{
-9223372036854775808, -9223372036854775808,
} }
func TestInt64(t *testing.T) { func TestInt64(t *testing.T) {
for i, testVal := range int64Tests { for i, testVal := range int64Tests {
in := NewInt(testVal) in := NewInt(testVal)
@ -914,7 +1072,6 @@ func TestInt64(t *testing.T) {
} }
} }
var bitwiseTests = []struct { var bitwiseTests = []struct {
x, y string x, y string
and, or, xor, andNot string and, or, xor, andNot string
@ -958,7 +1115,6 @@ var bitwiseTests = []struct {
}, },
} }
type bitFun func(z, x, y *Int) *Int type bitFun func(z, x, y *Int) *Int
func testBitFun(t *testing.T, msg string, f bitFun, x, y *Int, exp string) { func testBitFun(t *testing.T, msg string, f bitFun, x, y *Int, exp string) {
@ -971,7 +1127,6 @@ func testBitFun(t *testing.T, msg string, f bitFun, x, y *Int, exp string) {
} }
} }
func testBitFunSelf(t *testing.T, msg string, f bitFun, x, y *Int, exp string) { func testBitFunSelf(t *testing.T, msg string, f bitFun, x, y *Int, exp string) {
self := new(Int) self := new(Int)
self.Set(x) self.Set(x)
@ -984,6 +1139,142 @@ func testBitFunSelf(t *testing.T, msg string, f bitFun, x, y *Int, exp string) {
} }
} }
func altBit(x *Int, i int) uint {
z := new(Int).Rsh(x, uint(i))
z = z.And(z, NewInt(1))
if z.Cmp(new(Int)) != 0 {
return 1
}
return 0
}
func altSetBit(z *Int, x *Int, i int, b uint) *Int {
one := NewInt(1)
m := one.Lsh(one, uint(i))
switch b {
case 1:
return z.Or(x, m)
case 0:
return z.AndNot(x, m)
}
panic("set bit is not 0 or 1")
}
func testBitset(t *testing.T, x *Int) {
n := x.BitLen()
z := new(Int).Set(x)
z1 := new(Int).Set(x)
for i := 0; i < n+10; i++ {
old := z.Bit(i)
old1 := altBit(z1, i)
if old != old1 {
t.Errorf("bitset: inconsistent value for Bit(%s, %d), got %v want %v", z1, i, old, old1)
}
z := new(Int).SetBit(z, i, 1)
z1 := altSetBit(new(Int), z1, i, 1)
if z.Bit(i) == 0 {
t.Errorf("bitset: bit %d of %s got 0 want 1", i, x)
}
if z.Cmp(z1) != 0 {
t.Errorf("bitset: inconsistent value after SetBit 1, got %s want %s", z, z1)
}
z.SetBit(z, i, 0)
altSetBit(z1, z1, i, 0)
if z.Bit(i) != 0 {
t.Errorf("bitset: bit %d of %s got 1 want 0", i, x)
}
if z.Cmp(z1) != 0 {
t.Errorf("bitset: inconsistent value after SetBit 0, got %s want %s", z, z1)
}
altSetBit(z1, z1, i, old)
z.SetBit(z, i, old)
if z.Cmp(z1) != 0 {
t.Errorf("bitset: inconsistent value after SetBit old, got %s want %s", z, z1)
}
}
if z.Cmp(x) != 0 {
t.Errorf("bitset: got %s want %s", z, x)
}
}
var bitsetTests = []struct {
x string
i int
b uint
}{
{"0", 0, 0},
{"0", 200, 0},
{"1", 0, 1},
{"1", 1, 0},
{"-1", 0, 1},
{"-1", 200, 1},
{"0x2000000000000000000000000000", 108, 0},
{"0x2000000000000000000000000000", 109, 1},
{"0x2000000000000000000000000000", 110, 0},
{"-0x2000000000000000000000000001", 108, 1},
{"-0x2000000000000000000000000001", 109, 0},
{"-0x2000000000000000000000000001", 110, 1},
}
func TestBitSet(t *testing.T) {
for _, test := range bitwiseTests {
x := new(Int)
x.SetString(test.x, 0)
testBitset(t, x)
x = new(Int)
x.SetString(test.y, 0)
testBitset(t, x)
}
for i, test := range bitsetTests {
x := new(Int)
x.SetString(test.x, 0)
b := x.Bit(test.i)
if b != test.b {
t.Errorf("#%d want %v got %v", i, test.b, b)
}
}
}
func BenchmarkBitset(b *testing.B) {
z := new(Int)
z.SetBit(z, 512, 1)
b.ResetTimer()
b.StartTimer()
for i := b.N - 1; i >= 0; i-- {
z.SetBit(z, i&512, 1)
}
}
func BenchmarkBitsetNeg(b *testing.B) {
z := NewInt(-1)
z.SetBit(z, 512, 0)
b.ResetTimer()
b.StartTimer()
for i := b.N - 1; i >= 0; i-- {
z.SetBit(z, i&512, 0)
}
}
func BenchmarkBitsetOrig(b *testing.B) {
z := new(Int)
altSetBit(z, z, 512, 1)
b.ResetTimer()
b.StartTimer()
for i := b.N - 1; i >= 0; i-- {
altSetBit(z, z, i&512, 1)
}
}
func BenchmarkBitsetNegOrig(b *testing.B) {
z := NewInt(-1)
altSetBit(z, z, 512, 0)
b.ResetTimer()
b.StartTimer()
for i := b.N - 1; i >= 0; i-- {
altSetBit(z, z, i&512, 0)
}
}
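Bit and SetBit, whose semantics the altBit/altSetBit reference implementations above pin down, act on the two's-complement view of the value; a small usage sketch (math/big import path assumed, plain big in this tree):

package main

import (
	"fmt"
	"math/big"
)

func main() {
	z := new(big.Int)
	z.SetBit(z, 3, 1)        // z = 8
	fmt.Println(z, z.Bit(3)) // 8 1
	m := big.NewInt(-1)      // ...1111 in two's complement
	fmt.Println(m.Bit(200))  // 1, matching the bitsetTests above
}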
func TestBitwise(t *testing.T) { func TestBitwise(t *testing.T) {
x := new(Int) x := new(Int)
@ -1003,7 +1294,6 @@ func TestBitwise(t *testing.T) {
} }
} }
var notTests = []struct { var notTests = []struct {
in string in string
out string out string
@ -1037,7 +1327,6 @@ func TestNot(t *testing.T) {
} }
} }
var modInverseTests = []struct { var modInverseTests = []struct {
element string element string
prime string prime string
@ -1062,7 +1351,7 @@ func TestModInverse(t *testing.T) {
} }
} }
// used by TestIntGobEncoding and TestRatGobEncoding
var gobEncodingTests = []string{ var gobEncodingTests = []string{
"0", "0",
"1", "1",
@ -1073,7 +1362,7 @@ var gobEncodingTests = []string{
"298472983472983471903246121093472394872319615612417471234712061", "298472983472983471903246121093472394872319615612417471234712061",
} }
func TestGobEncoding(t *testing.T) { func TestIntGobEncoding(t *testing.T) {
var medium bytes.Buffer var medium bytes.Buffer
enc := gob.NewEncoder(&medium) enc := gob.NewEncoder(&medium)
dec := gob.NewDecoder(&medium) dec := gob.NewDecoder(&medium)
@ -1081,7 +1370,8 @@ func TestGobEncoding(t *testing.T) {
for j := 0; j < 2; j++ { for j := 0; j < 2; j++ {
medium.Reset() // empty buffer for each test case (in case of failures) medium.Reset() // empty buffer for each test case (in case of failures)
stest := test stest := test
if j == 0 { if j != 0 {
// negative numbers
stest = "-" + test stest = "-" + test
} }
var tx Int var tx Int

View File

@ -18,7 +18,11 @@ package big
// These are the building blocks for the operations on signed integers // These are the building blocks for the operations on signed integers
// and rationals. // and rationals.
import "rand" import (
"io"
"os"
"rand"
)
// An unsigned integer x of the form // An unsigned integer x of the form
// //
@ -40,14 +44,12 @@ var (
natTen = nat{10} natTen = nat{10}
) )
func (z nat) clear() { func (z nat) clear() {
for i := range z { for i := range z {
z[i] = 0 z[i] = 0
} }
} }
func (z nat) norm() nat { func (z nat) norm() nat {
i := len(z) i := len(z)
for i > 0 && z[i-1] == 0 { for i > 0 && z[i-1] == 0 {
@ -56,7 +58,6 @@ func (z nat) norm() nat {
return z[0:i] return z[0:i]
} }
func (z nat) make(n int) nat { func (z nat) make(n int) nat {
if n <= cap(z) { if n <= cap(z) {
return z[0:n] // reuse z return z[0:n] // reuse z
@ -67,7 +68,6 @@ func (z nat) make(n int) nat {
return make(nat, n, n+e) return make(nat, n, n+e)
} }
func (z nat) setWord(x Word) nat { func (z nat) setWord(x Word) nat {
if x == 0 { if x == 0 {
return z.make(0) return z.make(0)
@ -77,7 +77,6 @@ func (z nat) setWord(x Word) nat {
return z return z
} }
func (z nat) setUint64(x uint64) nat { func (z nat) setUint64(x uint64) nat {
// single-digit values // single-digit values
if w := Word(x); uint64(w) == x { if w := Word(x); uint64(w) == x {
@ -100,14 +99,12 @@ func (z nat) setUint64(x uint64) nat {
return z return z
} }
func (z nat) set(x nat) nat { func (z nat) set(x nat) nat {
z = z.make(len(x)) z = z.make(len(x))
copy(z, x) copy(z, x)
return z return z
} }
func (z nat) add(x, y nat) nat { func (z nat) add(x, y nat) nat {
m := len(x) m := len(x)
n := len(y) n := len(y)
@ -134,7 +131,6 @@ func (z nat) add(x, y nat) nat {
return z.norm() return z.norm()
} }
func (z nat) sub(x, y nat) nat { func (z nat) sub(x, y nat) nat {
m := len(x) m := len(x)
n := len(y) n := len(y)
@ -163,7 +159,6 @@ func (z nat) sub(x, y nat) nat {
return z.norm() return z.norm()
} }
func (x nat) cmp(y nat) (r int) { func (x nat) cmp(y nat) (r int) {
m := len(x) m := len(x)
n := len(y) n := len(y)
@ -191,7 +186,6 @@ func (x nat) cmp(y nat) (r int) {
return return
} }
func (z nat) mulAddWW(x nat, y, r Word) nat { func (z nat) mulAddWW(x nat, y, r Word) nat {
m := len(x) m := len(x)
if m == 0 || y == 0 { if m == 0 || y == 0 {
@ -205,7 +199,6 @@ func (z nat) mulAddWW(x nat, y, r Word) nat {
return z.norm() return z.norm()
} }
// basicMul multiplies x and y and leaves the result in z. // basicMul multiplies x and y and leaves the result in z.
// The (non-normalized) result is placed in z[0 : len(x) + len(y)]. // The (non-normalized) result is placed in z[0 : len(x) + len(y)].
func basicMul(z, x, y nat) { func basicMul(z, x, y nat) {
@ -217,7 +210,6 @@ func basicMul(z, x, y nat) {
} }
} }
// Fast version of z[0:n+n>>1].add(z[0:n+n>>1], x[0:n]) w/o bounds checks. // Fast version of z[0:n+n>>1].add(z[0:n+n>>1], x[0:n]) w/o bounds checks.
// Factored out for readability - do not use outside karatsuba. // Factored out for readability - do not use outside karatsuba.
func karatsubaAdd(z, x nat, n int) { func karatsubaAdd(z, x nat, n int) {
@ -226,7 +218,6 @@ func karatsubaAdd(z, x nat, n int) {
} }
} }
// Like karatsubaAdd, but does subtract. // Like karatsubaAdd, but does subtract.
func karatsubaSub(z, x nat, n int) { func karatsubaSub(z, x nat, n int) {
if c := subVV(z[0:n], z, x); c != 0 { if c := subVV(z[0:n], z, x); c != 0 {
@ -234,7 +225,6 @@ func karatsubaSub(z, x nat, n int) {
} }
} }
// Operands that are shorter than karatsubaThreshold are multiplied using // Operands that are shorter than karatsubaThreshold are multiplied using
// "grade school" multiplication; for longer operands the Karatsuba algorithm // "grade school" multiplication; for longer operands the Karatsuba algorithm
// is used. // is used.
@ -339,13 +329,11 @@ func karatsuba(z, x, y nat) {
} }
} }
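The identity behind the Karatsuba path, written out on ordinary integers rather than nat slices (a sketch with a hypothetical helper; B is the split base):

// x*y = hx*hy*B*B + ((hx+lx)*(hy+ly) - hx*hy - lx*ly)*B + lx*ly,
// i.e. three sub-multiplications instead of four.
func karatsubaInt(x, y, B int64) int64 {
	hx, lx := x/B, x%B
	hy, ly := y/B, y%B
	p0 := lx * ly
	p2 := hx * hy
	p1 := (hx+lx)*(hy+ly) - p0 - p2
	return p2*B*B + p1*B + p0
}
// karatsubaInt(1234, 5678, 100) == 1234*5678 == 7006652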
// alias returns true if x and y share the same base array. // alias returns true if x and y share the same base array.
func alias(x, y nat) bool { func alias(x, y nat) bool {
return cap(x) > 0 && cap(y) > 0 && &x[0:cap(x)][cap(x)-1] == &y[0:cap(y)][cap(y)-1] return cap(x) > 0 && cap(y) > 0 && &x[0:cap(x)][cap(x)-1] == &y[0:cap(y)][cap(y)-1]
} }
// addAt implements z += x*(1<<(_W*i)); z must be long enough. // addAt implements z += x*(1<<(_W*i)); z must be long enough.
// (we don't use nat.add because we need z to stay the same // (we don't use nat.add because we need z to stay the same
// slice, and we don't need to normalize z after each addition) // slice, and we don't need to normalize z after each addition)
@ -360,7 +348,6 @@ func addAt(z, x nat, i int) {
} }
} }
func max(x, y int) int { func max(x, y int) int {
if x > y { if x > y {
return x return x
@ -368,7 +355,6 @@ func max(x, y int) int {
return y return y
} }
// karatsubaLen computes an approximation to the maximum k <= n such that // karatsubaLen computes an approximation to the maximum k <= n such that
// k = p<<i for a number p <= karatsubaThreshold and an i >= 0. Thus, the // k = p<<i for a number p <= karatsubaThreshold and an i >= 0. Thus, the
// result is the largest number that can be divided repeatedly by 2 before // result is the largest number that can be divided repeatedly by 2 before
@ -382,7 +368,6 @@ func karatsubaLen(n int) int {
return n << i return n << i
} }
func (z nat) mul(x, y nat) nat { func (z nat) mul(x, y nat) nat {
m := len(x) m := len(x)
n := len(y) n := len(y)
@ -450,7 +435,6 @@ func (z nat) mul(x, y nat) nat {
return z.norm() return z.norm()
} }
// mulRange computes the product of all the unsigned integers in the // mulRange computes the product of all the unsigned integers in the
// range [a, b] inclusively. If a > b (empty range), the result is 1. // range [a, b] inclusively. If a > b (empty range), the result is 1.
func (z nat) mulRange(a, b uint64) nat { func (z nat) mulRange(a, b uint64) nat {
@ -469,7 +453,6 @@ func (z nat) mulRange(a, b uint64) nat {
return z.mul(nat(nil).mulRange(a, m), nat(nil).mulRange(m+1, b)) return z.mul(nat(nil).mulRange(a, m), nat(nil).mulRange(m+1, b))
} }
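The same divide-and-conquer recurrence, reduced to plain machine integers for illustration (hypothetical helper, no overflow handling):

func mulRange64(a, b uint64) uint64 {
	switch {
	case a == 0:
		return 0 // the range contains the factor 0
	case a > b:
		return 1 // empty range
	case a == b:
		return a
	case a+1 == b:
		return a * b
	}
	m := (a + b) / 2
	return mulRange64(a, m) * mulRange64(m+1, b) // same split as above
}
// mulRange64(1, 5) == 120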
// q = (x-r)/y, with 0 <= r < y // q = (x-r)/y, with 0 <= r < y
func (z nat) divW(x nat, y Word) (q nat, r Word) { func (z nat) divW(x nat, y Word) (q nat, r Word) {
m := len(x) m := len(x)
@ -490,7 +473,6 @@ func (z nat) divW(x nat, y Word) (q nat, r Word) {
return return
} }
func (z nat) div(z2, u, v nat) (q, r nat) { func (z nat) div(z2, u, v nat) (q, r nat) {
if len(v) == 0 { if len(v) == 0 {
panic("division by zero") panic("division by zero")
@ -518,7 +500,6 @@ func (z nat) div(z2, u, v nat) (q, r nat) {
return return
} }
// q = (uIn-r)/v, with 0 <= r < y // q = (uIn-r)/v, with 0 <= r < y
// Uses z as storage for q, and u as storage for r if possible. // Uses z as storage for q, and u as storage for r if possible.
// See Knuth, Volume 2, section 4.3.1, Algorithm D. // See Knuth, Volume 2, section 4.3.1, Algorithm D.
@ -545,9 +526,14 @@ func (z nat) divLarge(u, uIn, v nat) (q, r nat) {
u.clear() u.clear()
// D1. // D1.
shift := Word(leadingZeros(v[n-1])) shift := leadingZeros(v[n-1])
shlVW(v, v, shift) if shift > 0 {
u[len(uIn)] = shlVW(u[0:len(uIn)], uIn, shift) // do not modify v, it may be used by another goroutine simultaneously
v1 := make(nat, n)
shlVU(v1, v, shift)
v = v1
}
u[len(uIn)] = shlVU(u[0:len(uIn)], uIn, shift)
// D2. // D2.
for j := m; j >= 0; j-- { for j := m; j >= 0; j-- {
@ -586,14 +572,12 @@ func (z nat) divLarge(u, uIn, v nat) (q, r nat) {
} }
q = q.norm() q = q.norm()
shrVW(u, u, shift) shrVU(u, u, shift)
shrVW(v, v, shift)
r = u.norm() r = u.norm()
return q, r return q, r
} }
// Length of x in bits. x must be normalized. // Length of x in bits. x must be normalized.
func (x nat) bitLen() int { func (x nat) bitLen() int {
if i := len(x) - 1; i >= 0 { if i := len(x) - 1; i >= 0 {
@ -602,103 +586,253 @@ func (x nat) bitLen() int {
return 0 return 0
} }
// MaxBase is the largest number base accepted for string conversions.
const MaxBase = 'z' - 'a' + 10 + 1 // = hexValue('z') + 1
func hexValue(ch byte) int {
var d byte func hexValue(ch int) Word {
d := MaxBase + 1 // illegal base
switch { switch {
case '0' <= ch && ch <= '9': case '0' <= ch && ch <= '9':
d = ch - '0' d = ch - '0'
case 'a' <= ch && ch <= 'f': case 'a' <= ch && ch <= 'z':
d = ch - 'a' + 10 d = ch - 'a' + 10
case 'A' <= ch && ch <= 'F': case 'A' <= ch && ch <= 'Z':
d = ch - 'A' + 10 d = ch - 'A' + 10
default:
return -1
} }
return int(d) return Word(d)
} }
// scan sets z to the natural number corresponding to the longest possible prefix
// read from r representing an unsigned integer in a given conversion base.
// It returns z, the actual conversion base used, and an error, if any. In the
// error case, the value of z is undefined. The syntax follows the syntax of
// unsigned integer literals in Go.
//
// The base argument must be 0 or a value from 2 through MaxBase. If the base
// is 0, the string prefix determines the actual conversion base. A prefix of
// ``0x'' or ``0X'' selects base 16; the ``0'' prefix selects base 8, and a
// ``0b'' or ``0B'' prefix selects base 2. Otherwise the selected base is 10.
//
func (z nat) scan(r io.RuneScanner, base int) (nat, int, os.Error) {
// reject illegal bases
if base < 0 || base == 1 || MaxBase < base {
return z, 0, os.NewError("illegal number base")
}
// one char look-ahead
ch, _, err := r.ReadRune()
if err != nil {
return z, 0, err
}
// scan returns the natural number corresponding to the
// longest possible prefix of s representing a natural number in a
// given conversion base, the actual conversion base used, and the
// prefix length. The syntax of natural numbers follows the syntax
// of unsigned integer literals in Go.
//
// If the base argument is 0, the string prefix determines the actual
// conversion base. A prefix of ``0x'' or ``0X'' selects base 16; the
// ``0'' prefix selects base 8, and a ``0b'' or ``0B'' prefix selects
// base 2. Otherwise the selected base is 10.
//
func (z nat) scan(s string, base int) (nat, int, int) {
// determine base if necessary // determine base if necessary
i, n := 0, len(s) b := Word(base)
if base == 0 { if base == 0 {
base = 10 b = 10
if n > 0 && s[0] == '0' { if ch == '0' {
base, i = 8, 1 switch ch, _, err = r.ReadRune(); err {
if n > 1 { case nil:
switch s[1] { b = 8
switch ch {
case 'x', 'X': case 'x', 'X':
base, i = 16, 2 b = 16
case 'b', 'B': case 'b', 'B':
base, i = 2, 2 b = 2
}
if b == 2 || b == 16 {
if ch, _, err = r.ReadRune(); err != nil {
return z, 0, err
}
}
case os.EOF:
return z, 10, nil
default:
return z, 10, err
}
}
}
// convert string
// - group as many digits d as possible together into a "super-digit" dd with "super-base" bb
// - only when bb does not fit into a word anymore, do a full number mulAddWW using bb and dd
z = z.make(0)
bb := Word(1)
dd := Word(0)
for max := _M / b; ; {
d := hexValue(ch)
if d >= b {
r.UnreadRune() // ch does not belong to number anymore
break
}
if bb <= max {
bb *= b
dd = dd*b + d
} else {
// bb * b would overflow
z = z.mulAddWW(z, bb, dd)
bb = b
dd = d
}
if ch, _, err = r.ReadRune(); err != nil {
if err != os.EOF {
return z, int(b), err
}
break
}
}
switch {
case bb > 1:
// there was at least one mantissa digit
z = z.mulAddWW(z, bb, dd)
case base == 0 && b == 8:
// there was only the octal prefix 0 (possibly followed by digits > 7);
// return base 10, not 8
return z, 10, nil
case base != 0 || b != 8:
// there was neither a mantissa digit nor the octal prefix 0
return z, int(b), os.NewError("syntax error scanning number")
}
return z.norm(), int(b), nil
}
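Through Int.SetString the prefix rules above are visible from the public API; a minimal sketch (math/big import path assumed, plain big in this tree):

package main

import (
	"fmt"
	"math/big"
)

func main() {
	// base 0: the prefix selects the conversion base, as in nat.scan above
	for _, s := range []string{"0b1010", "012", "10", "0xA"} {
		n, ok := new(big.Int).SetString(s, 0)
		fmt.Println(s, n, ok) // every one of these parses to 10
	}
}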
// Character sets for string conversion.
const (
lowercaseDigits = "0123456789abcdefghijklmnopqrstuvwxyz"
uppercaseDigits = "0123456789ABCDEFGHIJKLMNOPQRSTUVWXYZ"
)
// decimalString returns a decimal representation of x.
// It calls x.string with the charset "0123456789".
func (x nat) decimalString() string {
return x.string(lowercaseDigits[0:10])
}
// string converts x to a string using digits from a charset; a digit with
// value d is represented by charset[d]. The conversion base is determined
// by len(charset), which must be >= 2.
func (x nat) string(charset string) string {
b := Word(len(charset))
// special cases
switch {
case b < 2 || b > 256:
panic("illegal base")
case len(x) == 0:
return string(charset[0])
}
// allocate buffer for conversion
i := x.bitLen()/log2(b) + 1 // +1: round up
s := make([]byte, i)
// special case: power of two bases can avoid divisions completely
if b == b&-b {
// shift is base-b digit size in bits
shift := uint(trailingZeroBits(b)) // shift > 0 because b >= 2
mask := Word(1)<<shift - 1
w := x[0]
nbits := uint(_W) // number of unprocessed bits in w
// convert less-significant words
for k := 1; k < len(x); k++ {
// convert full digits
for nbits >= shift {
i--
s[i] = charset[w&mask]
w >>= shift
nbits -= shift
}
// convert any partial leading digit and advance to next word
if nbits == 0 {
// no partial digit remaining, just advance
w = x[k]
nbits = _W
} else {
// partial digit in current (k-1) and next (k) word
w |= x[k] << nbits
i--
s[i] = charset[w&mask]
// advance
w = x[k] >> (shift - nbits)
nbits = _W - (shift - nbits)
}
}
// convert digits of most-significant word (omit leading zeros)
for nbits >= 0 && w != 0 {
i--
s[i] = charset[w&mask]
w >>= shift
nbits -= shift
}
return string(s[i:])
}
// general case: extract groups of digits by multiprecision division
// maximize ndigits where b**ndigits < 2^_W; bb (big base) is b**ndigits
bb := Word(1)
ndigits := 0
for max := Word(_M / b); bb <= max; bb *= b {
ndigits++
}
// preserve x, create local copy for use in repeated divisions
q := nat(nil).set(x)
var r Word
// convert
if b == 10 { // hard-coding for 10 here speeds this up by 1.25x
for len(q) > 0 {
// extract least significant, base bb "digit"
q, r = q.divW(q, bb) // N.B. >82% of time is here. Optimize divW
if len(q) == 0 {
// skip leading zeros in most-significant group of digits
for j := 0; j < ndigits && r != 0; j++ {
i--
s[i] = charset[r%10]
r /= 10
}
} else {
for j := 0; j < ndigits; j++ {
i--
s[i] = charset[r%10]
r /= 10
}
}
}
} else {
for len(q) > 0 {
// extract least significant group of digits
q, r = q.divW(q, bb) // N.B. >82% of time is here. Optimize divW
if len(q) == 0 {
// skip leading zeros in most-significant group of digits
for j := 0; j < ndigits && r != 0; j++ {
i--
s[i] = charset[r%b]
r /= b
}
} else {
for j := 0; j < ndigits; j++ {
i--
s[i] = charset[r%b]
r /= b
} }
} }
} }
} }
// reject illegal bases or strings consisting only of prefix
if base < 2 || 16 < base || (base != 8 && i >= n) {
return z, 0, 0
}
// convert string
z = z.make(0)
for ; i < n; i++ {
d := hexValue(s[i])
if 0 <= d && d < base {
z = z.mulAddWW(z, Word(base), Word(d))
} else {
break
}
}
return z.norm(), base, i
}
// string converts x to a string for a given base, with 2 <= base <= 16.
// TODO(gri) in the style of the other routines, perhaps this should take
// a []byte buffer and return it
func (x nat) string(base int) string {
if base < 2 || 16 < base {
panic("illegal base")
}
if len(x) == 0 {
return "0"
}
// allocate buffer for conversion
i := x.bitLen()/log2(Word(base)) + 1 // +1: round up
s := make([]byte, i)
// don't destroy x
q := nat(nil).set(x)
// convert
for len(q) > 0 {
i--
var r Word
q, r = q.divW(q, Word(base))
s[i] = "0123456789abcdef"[r]
}
return string(s[i:]) return string(s[i:])
} }
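Both scan and string lean on the same "big base" grouping: pack as many base-b digits as fit in one machine word, so the multi-precision step (mulAddWW or divW) runs once per group rather than once per digit. A scalar sketch of that computation (hypothetical helper on uint64):

// digitsPerWord returns ndigits and bb = b**ndigits, the largest power of b
// that still fits in a 64-bit word (b must be >= 2).
func digitsPerWord(b uint64) (ndigits int, bb uint64) {
	bb = 1
	for max := ^uint64(0) / b; bb <= max; bb *= b {
		ndigits++
	}
	return
}
// digitsPerWord(10) == (19, 10000000000000000000)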
const deBruijn32 = 0x077CB531 const deBruijn32 = 0x077CB531
var deBruijn32Lookup = []byte{ var deBruijn32Lookup = []byte{
@ -721,7 +855,7 @@ var deBruijn64Lookup = []byte{
func trailingZeroBits(x Word) int { func trailingZeroBits(x Word) int {
// x & -x leaves only the right-most bit set in the word. Let k be the // x & -x leaves only the right-most bit set in the word. Let k be the
// index of that bit. Since only a single bit is set, the value is two // index of that bit. Since only a single bit is set, the value is two
// to the power of k. Multipling by a power of two is equivalent to // to the power of k. Multiplying by a power of two is equivalent to
// left shifting, in this case by k bits. The de Bruijn constant is // left shifting, in this case by k bits. The de Bruijn constant is
// such that all six bit, consecutive substrings are distinct. // such that all six bit, consecutive substrings are distinct.
// Therefore, if we have a left shifted version of this constant we can // Therefore, if we have a left shifted version of this constant we can
@ -739,7 +873,6 @@ func trailingZeroBits(x Word) int {
return 0 return 0
} }
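The de Bruijn trick described in that comment, spelled out for 32-bit words (a sketch using distinct names; the table values are the same ones used above):

const deBruijn32Sketch = 0x077CB531

var deBruijn32LookupSketch = [32]byte{
	0, 1, 28, 2, 29, 14, 24, 3, 30, 22, 20, 15, 25, 17, 4, 8,
	31, 27, 13, 23, 21, 19, 16, 7, 26, 12, 18, 6, 11, 5, 10, 9,
}

// x & -x isolates the lowest set bit 1<<k; multiplying the de Bruijn
// constant by 1<<k left-shifts it by k bits, and the top five bits of
// the product index a table that recovers k.
func trailingZeros32Sketch(x uint32) int {
	if x == 0 {
		return 0
	}
	return int(deBruijn32LookupSketch[(x&-x)*deBruijn32Sketch>>27])
}
// trailingZeros32Sketch(0xa0) == 5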
// z = x << s // z = x << s
func (z nat) shl(x nat, s uint) nat { func (z nat) shl(x nat, s uint) nat {
m := len(x) m := len(x)
@ -750,13 +883,12 @@ func (z nat) shl(x nat, s uint) nat {
n := m + int(s/_W) n := m + int(s/_W)
z = z.make(n + 1) z = z.make(n + 1)
z[n] = shlVW(z[n-m:n], x, Word(s%_W)) z[n] = shlVU(z[n-m:n], x, s%_W)
z[0 : n-m].clear() z[0 : n-m].clear()
return z.norm() return z.norm()
} }
// z = x >> s // z = x >> s
func (z nat) shr(x nat, s uint) nat { func (z nat) shr(x nat, s uint) nat {
m := len(x) m := len(x)
@ -767,11 +899,45 @@ func (z nat) shr(x nat, s uint) nat {
// n > 0 // n > 0
z = z.make(n) z = z.make(n)
shrVW(z, x[m-n:], Word(s%_W)) shrVU(z, x[m-n:], s%_W)
return z.norm() return z.norm()
} }
func (z nat) setBit(x nat, i uint, b uint) nat {
j := int(i / _W)
m := Word(1) << (i % _W)
n := len(x)
switch b {
case 0:
z = z.make(n)
copy(z, x)
if j >= n {
// no need to grow
return z
}
z[j] &^= m
return z.norm()
case 1:
if j >= n {
n = j + 1
}
z = z.make(n)
copy(z, x)
z[j] |= m
// no need to normalize
return z
}
panic("set bit is not 0 or 1")
}
func (z nat) bit(i uint) uint {
j := int(i / _W)
if j >= len(z) {
return 0
}
return uint(z[j] >> (i % _W) & 1)
}
func (z nat) and(x, y nat) nat { func (z nat) and(x, y nat) nat {
m := len(x) m := len(x)
@ -789,7 +955,6 @@ func (z nat) and(x, y nat) nat {
return z.norm() return z.norm()
} }
func (z nat) andNot(x, y nat) nat { func (z nat) andNot(x, y nat) nat {
m := len(x) m := len(x)
n := len(y) n := len(y)
@ -807,7 +972,6 @@ func (z nat) andNot(x, y nat) nat {
return z.norm() return z.norm()
} }
func (z nat) or(x, y nat) nat { func (z nat) or(x, y nat) nat {
m := len(x) m := len(x)
n := len(y) n := len(y)
@ -827,7 +991,6 @@ func (z nat) or(x, y nat) nat {
return z.norm() return z.norm()
} }
func (z nat) xor(x, y nat) nat { func (z nat) xor(x, y nat) nat {
m := len(x) m := len(x)
n := len(y) n := len(y)
@ -847,10 +1010,10 @@ func (z nat) xor(x, y nat) nat {
return z.norm() return z.norm()
} }
// greaterThan returns true iff (x1<<_W + x2) > (y1<<_W + y2) // greaterThan returns true iff (x1<<_W + x2) > (y1<<_W + y2)
func greaterThan(x1, x2, y1, y2 Word) bool { return x1 > y1 || x1 == y1 && x2 > y2 } func greaterThan(x1, x2, y1, y2 Word) bool {
return x1 > y1 || x1 == y1 && x2 > y2
}
// modW returns x % d. // modW returns x % d.
func (x nat) modW(d Word) (r Word) { func (x nat) modW(d Word) (r Word) {
@ -860,30 +1023,29 @@ func (x nat) modW(d Word) (r Word) {
return divWVW(q, 0, x, d) return divWVW(q, 0, x, d)
} }
// powersOfTwoDecompose finds q and k with x = q * 1<<k and q is odd, or q and k are 0.
// powersOfTwoDecompose finds q and k such that q * 1<<k = n and q is odd. func (x nat) powersOfTwoDecompose() (q nat, k int) {
func (n nat) powersOfTwoDecompose() (q nat, k Word) { if len(x) == 0 {
if len(n) == 0 { return x, 0
return n, 0
} }
zeroWords := 0 // One of the words must be non-zero by definition,
for n[zeroWords] == 0 { // so this loop will terminate with i < len(x), and
zeroWords++ // i is the number of 0 words.
i := 0
for x[i] == 0 {
i++
} }
// One of the words must be non-zero by invariant, therefore n := trailingZeroBits(x[i]) // x[i] != 0
// zeroWords < len(n).
x := trailingZeroBits(n[zeroWords]) q = make(nat, len(x)-i)
shrVU(q, x[i:], uint(n))
q = q.make(len(n) - zeroWords)
shrVW(q, n[zeroWords:], Word(x))
q = q.norm() q = q.norm()
k = i*_W + n
k = Word(_W*zeroWords + x)
return return
} }
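On a single machine word the decomposition is simply repeated halving; a sketch of what powersOfTwoDecompose does across word slices (hypothetical helper):

// decompose returns q odd and k such that x == q << k, for x > 0.
func decompose(x uint64) (q uint64, k int) {
	if x == 0 {
		return 0, 0
	}
	for x&1 == 0 {
		x >>= 1
		k++
	}
	return x, k
}
// decompose(40) == (5, 3), since 40 == 5 << 3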
// random creates a random integer in [0..limit), using the space in z if // random creates a random integer in [0..limit), using the space in z if
// possible. n is the bit length of limit. // possible. n is the bit length of limit.
func (z nat) random(rand *rand.Rand, limit nat, n int) nat { func (z nat) random(rand *rand.Rand, limit nat, n int) nat {
@ -914,7 +1076,6 @@ func (z nat) random(rand *rand.Rand, limit nat, n int) nat {
return z.norm() return z.norm()
} }
// If m != nil, expNN calculates x**y mod m. Otherwise it calculates x**y. It // If m != nil, expNN calculates x**y mod m. Otherwise it calculates x**y. It
// reuses the storage of z if possible. // reuses the storage of z if possible.
func (z nat) expNN(x, y, m nat) nat { func (z nat) expNN(x, y, m nat) nat {
@ -983,7 +1144,6 @@ func (z nat) expNN(x, y, m nat) nat {
return z return z
} }
// probablyPrime performs reps Miller-Rabin tests to check whether n is prime. // probablyPrime performs reps Miller-Rabin tests to check whether n is prime.
// If it returns true, n is prime with probability 1 - 1/4^reps. // If it returns true, n is prime with probability 1 - 1/4^reps.
// If it returns false, n is not prime. // If it returns false, n is not prime.
@ -1050,7 +1210,7 @@ NextRandom:
if y.cmp(natOne) == 0 || y.cmp(nm1) == 0 { if y.cmp(natOne) == 0 || y.cmp(nm1) == 0 {
continue continue
} }
for j := Word(1); j < k; j++ { for j := 1; j < k; j++ {
y = y.mul(y, y) y = y.mul(y, y)
quotient, y = quotient.div(y, y, n) quotient, y = quotient.div(y, y, n)
if y.cmp(nm1) == 0 { if y.cmp(nm1) == 0 {
@ -1066,7 +1226,6 @@ NextRandom:
return true return true
} }
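From the public side this is Int.ProbablyPrime; a usage sketch (math/big import path assumed, plain big in this tree):

package main

import (
	"fmt"
	"math/big"
)

func main() {
	// 2^61-1 is a Mersenne prime; 20 Miller-Rabin rounds bound the chance
	// of a composite being reported prime by 1/4^20.
	m61 := new(big.Int).Sub(new(big.Int).Lsh(big.NewInt(1), 61), big.NewInt(1))
	fmt.Println(m61.ProbablyPrime(20)) // true
}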
// bytes writes the value of z into buf using big-endian encoding. // bytes writes the value of z into buf using big-endian encoding.
// len(buf) must be >= len(z)*_S. The value of z is encoded in the // len(buf) must be >= len(z)*_S. The value of z is encoded in the
// slice buf[i:]. The number i of unused bytes at the beginning of // slice buf[i:]. The number i of unused bytes at the beginning of
@ -1088,7 +1247,6 @@ func (z nat) bytes(buf []byte) (i int) {
return return
} }
// setBytes interprets buf as the bytes of a big-endian unsigned // setBytes interprets buf as the bytes of a big-endian unsigned
// integer, sets z to that value, and returns z. // integer, sets z to that value, and returns z.
func (z nat) setBytes(buf []byte) nat { func (z nat) setBytes(buf []byte) nat {

View File

@ -4,7 +4,12 @@
package big package big
import "testing" import (
"fmt"
"os"
"strings"
"testing"
)
var cmpTests = []struct { var cmpTests = []struct {
x, y nat x, y nat
@ -26,7 +31,6 @@ var cmpTests = []struct {
{nat{34986, 41, 105, 1957}, nat{56, 7458, 104, 1957}, 1}, {nat{34986, 41, 105, 1957}, nat{56, 7458, 104, 1957}, 1},
} }
func TestCmp(t *testing.T) { func TestCmp(t *testing.T) {
for i, a := range cmpTests { for i, a := range cmpTests {
r := a.x.cmp(a.y) r := a.x.cmp(a.y)
@ -36,13 +40,11 @@ func TestCmp(t *testing.T) {
} }
} }
type funNN func(z, x, y nat) nat type funNN func(z, x, y nat) nat
type argNN struct { type argNN struct {
z, x, y nat z, x, y nat
} }
var sumNN = []argNN{ var sumNN = []argNN{
{}, {},
{nat{1}, nil, nat{1}}, {nat{1}, nil, nat{1}},
@ -52,7 +54,6 @@ var sumNN = []argNN{
{nat{0, 0, 0, 1}, nat{0, 0, _M}, nat{0, 0, 1}}, {nat{0, 0, 0, 1}, nat{0, 0, _M}, nat{0, 0, 1}},
} }
var prodNN = []argNN{ var prodNN = []argNN{
{}, {},
{nil, nil, nil}, {nil, nil, nil},
@ -64,7 +65,6 @@ var prodNN = []argNN{
{nat{4, 11, 20, 30, 20, 11, 4}, nat{1, 2, 3, 4}, nat{4, 3, 2, 1}}, {nat{4, 11, 20, 30, 20, 11, 4}, nat{1, 2, 3, 4}, nat{4, 3, 2, 1}},
} }
func TestSet(t *testing.T) { func TestSet(t *testing.T) {
for _, a := range sumNN { for _, a := range sumNN {
z := nat(nil).set(a.z) z := nat(nil).set(a.z)
@ -74,7 +74,6 @@ func TestSet(t *testing.T) {
} }
} }
func testFunNN(t *testing.T, msg string, f funNN, a argNN) { func testFunNN(t *testing.T, msg string, f funNN, a argNN) {
z := f(nil, a.x, a.y) z := f(nil, a.x, a.y)
if z.cmp(a.z) != 0 { if z.cmp(a.z) != 0 {
@ -82,7 +81,6 @@ func testFunNN(t *testing.T, msg string, f funNN, a argNN) {
} }
} }
func TestFunNN(t *testing.T) { func TestFunNN(t *testing.T) {
for _, a := range sumNN { for _, a := range sumNN {
arg := a arg := a
@ -107,7 +105,6 @@ func TestFunNN(t *testing.T) {
} }
} }
var mulRangesN = []struct { var mulRangesN = []struct {
a, b uint64 a, b uint64
prod string prod string
@ -130,17 +127,15 @@ var mulRangesN = []struct {
}, },
} }
func TestMulRangeN(t *testing.T) { func TestMulRangeN(t *testing.T) {
for i, r := range mulRangesN { for i, r := range mulRangesN {
prod := nat(nil).mulRange(r.a, r.b).string(10) prod := nat(nil).mulRange(r.a, r.b).decimalString()
if prod != r.prod { if prod != r.prod {
t.Errorf("#%d: got %s; want %s", i, prod, r.prod) t.Errorf("#%d: got %s; want %s", i, prod, r.prod)
} }
} }
} }
var mulArg, mulTmp nat var mulArg, mulTmp nat
func init() { func init() {
@ -151,7 +146,6 @@ func init() {
} }
} }
func benchmarkMulLoad() { func benchmarkMulLoad() {
for j := 1; j <= 10; j++ { for j := 1; j <= 10; j++ {
x := mulArg[0 : j*100] x := mulArg[0 : j*100]
@ -159,46 +153,376 @@ func benchmarkMulLoad() {
} }
} }
func BenchmarkMul(b *testing.B) { func BenchmarkMul(b *testing.B) {
for i := 0; i < b.N; i++ { for i := 0; i < b.N; i++ {
benchmarkMulLoad() benchmarkMulLoad()
} }
} }
func toString(x nat, charset string) string {
base := len(charset)
var tab = []struct { // special cases
x nat switch {
b int case base < 2:
s string panic("illegal base")
}{ case len(x) == 0:
{nil, 10, "0"}, return string(charset[0])
{nat{1}, 10, "1"}, }
{nat{10}, 10, "10"},
{nat{1234567890}, 10, "1234567890"}, // allocate buffer for conversion
i := x.bitLen()/log2(Word(base)) + 1 // +1: round up
s := make([]byte, i)
// don't destroy x
q := nat(nil).set(x)
// convert
for len(q) > 0 {
i--
var r Word
q, r = q.divW(q, Word(base))
s[i] = charset[r]
}
return string(s[i:])
} }
var strTests = []struct {
x nat // nat value to be converted
c string // conversion charset
s string // expected result
}{
{nil, "01", "0"},
{nat{1}, "01", "1"},
{nat{0xc5}, "01", "11000101"},
{nat{03271}, lowercaseDigits[0:8], "3271"},
{nat{10}, lowercaseDigits[0:10], "10"},
{nat{1234567890}, uppercaseDigits[0:10], "1234567890"},
{nat{0xdeadbeef}, lowercaseDigits[0:16], "deadbeef"},
{nat{0xdeadbeef}, uppercaseDigits[0:16], "DEADBEEF"},
{nat{0x229be7}, lowercaseDigits[0:17], "1a2b3c"},
{nat{0x309663e6}, uppercaseDigits[0:32], "O9COV6"},
}
func TestString(t *testing.T) { func TestString(t *testing.T) {
for _, a := range tab { for _, a := range strTests {
s := a.x.string(a.b) s := a.x.string(a.c)
if s != a.s { if s != a.s {
t.Errorf("string%+v\n\tgot s = %s; want %s", a, s, a.s) t.Errorf("string%+v\n\tgot s = %s; want %s", a, s, a.s)
} }
x, b, n := nat(nil).scan(a.s, a.b) x, b, err := nat(nil).scan(strings.NewReader(a.s), len(a.c))
if x.cmp(a.x) != 0 { if x.cmp(a.x) != 0 {
t.Errorf("scan%+v\n\tgot z = %v; want %v", a, x, a.x) t.Errorf("scan%+v\n\tgot z = %v; want %v", a, x, a.x)
} }
if b != a.b { if b != len(a.c) {
t.Errorf("scan%+v\n\tgot b = %d; want %d", a, b, a.b) t.Errorf("scan%+v\n\tgot b = %d; want %d", a, b, len(a.c))
} }
if n != len(a.s) { if err != nil {
t.Errorf("scan%+v\n\tgot n = %d; want %d", a, n, len(a.s)) t.Errorf("scan%+v\n\tgot error = %s", a, err)
} }
} }
} }
var natScanTests = []struct {
s string // string to be scanned
base int // input base
x nat // expected nat
b int // expected base
ok bool // expected success
next int // next character (or 0, if at EOF)
}{
// error: illegal base
{base: -1},
{base: 1},
{base: 37},
// error: no mantissa
{},
{s: "?"},
{base: 10},
{base: 36},
{s: "?", base: 10},
{s: "0x"},
{s: "345", base: 2},
// no errors
{"0", 0, nil, 10, true, 0},
{"0", 10, nil, 10, true, 0},
{"0", 36, nil, 36, true, 0},
{"1", 0, nat{1}, 10, true, 0},
{"1", 10, nat{1}, 10, true, 0},
{"0 ", 0, nil, 10, true, ' '},
{"08", 0, nil, 10, true, '8'},
{"018", 0, nat{1}, 8, true, '8'},
{"0b1", 0, nat{1}, 2, true, 0},
{"0b11000101", 0, nat{0xc5}, 2, true, 0},
{"03271", 0, nat{03271}, 8, true, 0},
{"10ab", 0, nat{10}, 10, true, 'a'},
{"1234567890", 0, nat{1234567890}, 10, true, 0},
{"xyz", 36, nat{(33*36+34)*36 + 35}, 36, true, 0},
{"xyz?", 36, nat{(33*36+34)*36 + 35}, 36, true, '?'},
{"0x", 16, nil, 16, true, 'x'},
{"0xdeadbeef", 0, nat{0xdeadbeef}, 16, true, 0},
{"0XDEADBEEF", 0, nat{0xdeadbeef}, 16, true, 0},
}
func TestScanBase(t *testing.T) {
for _, a := range natScanTests {
r := strings.NewReader(a.s)
x, b, err := nat(nil).scan(r, a.base)
if err == nil && !a.ok {
t.Errorf("scan%+v\n\texpected error", a)
}
if err != nil {
if a.ok {
t.Errorf("scan%+v\n\tgot error = %s", a, err)
}
continue
}
if x.cmp(a.x) != 0 {
t.Errorf("scan%+v\n\tgot z = %v; want %v", a, x, a.x)
}
if b != a.b {
t.Errorf("scan%+v\n\tgot b = %d; want %d", a, b, a.base)
}
next, _, err := r.ReadRune()
if err == os.EOF {
next = 0
err = nil
}
if err == nil && next != a.next {
t.Errorf("scan%+v\n\tgot next = %q; want %q", a, next, a.next)
}
}
}
var pi = "3" +
"14159265358979323846264338327950288419716939937510582097494459230781640628620899862803482534211706798214808651" +
"32823066470938446095505822317253594081284811174502841027019385211055596446229489549303819644288109756659334461" +
"28475648233786783165271201909145648566923460348610454326648213393607260249141273724587006606315588174881520920" +
"96282925409171536436789259036001133053054882046652138414695194151160943305727036575959195309218611738193261179" +
"31051185480744623799627495673518857527248912279381830119491298336733624406566430860213949463952247371907021798" +
"60943702770539217176293176752384674818467669405132000568127145263560827785771342757789609173637178721468440901" +
"22495343014654958537105079227968925892354201995611212902196086403441815981362977477130996051870721134999999837" +
"29780499510597317328160963185950244594553469083026425223082533446850352619311881710100031378387528865875332083" +
"81420617177669147303598253490428755468731159562863882353787593751957781857780532171226806613001927876611195909" +
"21642019893809525720106548586327886593615338182796823030195203530185296899577362259941389124972177528347913151" +
"55748572424541506959508295331168617278558890750983817546374649393192550604009277016711390098488240128583616035" +
"63707660104710181942955596198946767837449448255379774726847104047534646208046684259069491293313677028989152104" +
"75216205696602405803815019351125338243003558764024749647326391419927260426992279678235478163600934172164121992" +
"45863150302861829745557067498385054945885869269956909272107975093029553211653449872027559602364806654991198818" +
"34797753566369807426542527862551818417574672890977772793800081647060016145249192173217214772350141441973568548" +
"16136115735255213347574184946843852332390739414333454776241686251898356948556209921922218427255025425688767179" +
"04946016534668049886272327917860857843838279679766814541009538837863609506800642251252051173929848960841284886" +
"26945604241965285022210661186306744278622039194945047123713786960956364371917287467764657573962413890865832645" +
"99581339047802759009946576407895126946839835259570982582262052248940772671947826848260147699090264013639443745" +
"53050682034962524517493996514314298091906592509372216964615157098583874105978859597729754989301617539284681382" +
"68683868942774155991855925245953959431049972524680845987273644695848653836736222626099124608051243884390451244" +
"13654976278079771569143599770012961608944169486855584840635342207222582848864815845602850601684273945226746767" +
"88952521385225499546667278239864565961163548862305774564980355936345681743241125150760694794510965960940252288" +
"79710893145669136867228748940560101503308617928680920874760917824938589009714909675985261365549781893129784821" +
"68299894872265880485756401427047755513237964145152374623436454285844479526586782105114135473573952311342716610" +
"21359695362314429524849371871101457654035902799344037420073105785390621983874478084784896833214457138687519435" +
"06430218453191048481005370614680674919278191197939952061419663428754440643745123718192179998391015919561814675" +
"14269123974894090718649423196156794520809514655022523160388193014209376213785595663893778708303906979207734672" +
"21825625996615014215030680384477345492026054146659252014974428507325186660021324340881907104863317346496514539" +
"05796268561005508106658796998163574736384052571459102897064140110971206280439039759515677157700420337869936007" +
"23055876317635942187312514712053292819182618612586732157919841484882916447060957527069572209175671167229109816" +
"90915280173506712748583222871835209353965725121083579151369882091444210067510334671103141267111369908658516398" +
"31501970165151168517143765761835155650884909989859982387345528331635507647918535893226185489632132933089857064" +
"20467525907091548141654985946163718027098199430992448895757128289059232332609729971208443357326548938239119325" +
"97463667305836041428138830320382490375898524374417029132765618093773444030707469211201913020330380197621101100" +
"44929321516084244485963766983895228684783123552658213144957685726243344189303968642624341077322697802807318915" +
"44110104468232527162010526522721116603966655730925471105578537634668206531098965269186205647693125705863566201" +
"85581007293606598764861179104533488503461136576867532494416680396265797877185560845529654126654085306143444318" +
"58676975145661406800700237877659134401712749470420562230538994561314071127000407854733269939081454664645880797" +
"27082668306343285878569830523580893306575740679545716377525420211495576158140025012622859413021647155097925923" +
"09907965473761255176567513575178296664547791745011299614890304639947132962107340437518957359614589019389713111" +
"79042978285647503203198691514028708085990480109412147221317947647772622414254854540332157185306142288137585043" +
"06332175182979866223717215916077166925474873898665494945011465406284336639379003976926567214638530673609657120" +
"91807638327166416274888800786925602902284721040317211860820419000422966171196377921337575114959501566049631862" +
"94726547364252308177036751590673502350728354056704038674351362222477158915049530984448933309634087807693259939" +
"78054193414473774418426312986080998886874132604721569516239658645730216315981931951673538129741677294786724229" +
"24654366800980676928238280689964004824354037014163149658979409243237896907069779422362508221688957383798623001" +
"59377647165122893578601588161755782973523344604281512627203734314653197777416031990665541876397929334419521541" +
"34189948544473456738316249934191318148092777710386387734317720754565453220777092120190516609628049092636019759" +
"88281613323166636528619326686336062735676303544776280350450777235547105859548702790814356240145171806246436267" +
"94561275318134078330336254232783944975382437205835311477119926063813346776879695970309833913077109870408591337"
// Test case for BenchmarkScanPi.
func TestScanPi(t *testing.T) {
var x nat
z, _, err := x.scan(strings.NewReader(pi), 10)
if err != nil {
t.Errorf("scanning pi: %s", err)
}
if s := z.decimalString(); s != pi {
t.Errorf("scanning pi: got %s", s)
}
}
func BenchmarkScanPi(b *testing.B) {
for i := 0; i < b.N; i++ {
var x nat
x.scan(strings.NewReader(pi), 10)
}
}
const (
// 314**271
// base 2: 2249 digits
// base 8: 751 digits
// base 10: 678 digits
// base 16: 563 digits
shortBase = 314
shortExponent = 271
// 3141**2718
// base 2: 31577 digits
// base 8: 10527 digits
// base 10: 9507 digits
// base 16: 7895 digits
mediumBase = 3141
mediumExponent = 2718
// 31415**27182
// base 2: 406078 digits
// base 8: 135360 digits
// base 10: 122243 digits
// base 16: 101521 digits
longBase = 31415
longExponent = 27182
)
func BenchmarkScanShort2(b *testing.B) {
ScanHelper(b, 2, shortBase, shortExponent)
}
func BenchmarkScanShort8(b *testing.B) {
ScanHelper(b, 8, shortBase, shortExponent)
}
func BenchmarkScanShort10(b *testing.B) {
ScanHelper(b, 10, shortBase, shortExponent)
}
func BenchmarkScanShort16(b *testing.B) {
ScanHelper(b, 16, shortBase, shortExponent)
}
func BenchmarkScanMedium2(b *testing.B) {
ScanHelper(b, 2, mediumBase, mediumExponent)
}
func BenchmarkScanMedium8(b *testing.B) {
ScanHelper(b, 8, mediumBase, mediumExponent)
}
func BenchmarkScanMedium10(b *testing.B) {
ScanHelper(b, 10, mediumBase, mediumExponent)
}
func BenchmarkScanMedium16(b *testing.B) {
ScanHelper(b, 16, mediumBase, mediumExponent)
}
func BenchmarkScanLong2(b *testing.B) {
ScanHelper(b, 2, longBase, longExponent)
}
func BenchmarkScanLong8(b *testing.B) {
ScanHelper(b, 8, longBase, longExponent)
}
func BenchmarkScanLong10(b *testing.B) {
ScanHelper(b, 10, longBase, longExponent)
}
func BenchmarkScanLong16(b *testing.B) {
ScanHelper(b, 16, longBase, longExponent)
}
func ScanHelper(b *testing.B, base int, xv, yv Word) {
b.StopTimer()
var x, y, z nat
x = x.setWord(xv)
y = y.setWord(yv)
z = z.expNN(x, y, nil)
var s string
s = z.string(lowercaseDigits[0:base])
if t := toString(z, lowercaseDigits[0:base]); t != s {
panic(fmt.Sprintf("scanning: got %s; want %s", s, t))
}
b.StartTimer()
for i := 0; i < b.N; i++ {
x.scan(strings.NewReader(s), base)
}
}
func BenchmarkStringShort2(b *testing.B) {
StringHelper(b, 2, shortBase, shortExponent)
}
func BenchmarkStringShort8(b *testing.B) {
StringHelper(b, 8, shortBase, shortExponent)
}
func BenchmarkStringShort10(b *testing.B) {
StringHelper(b, 10, shortBase, shortExponent)
}
func BenchmarkStringShort16(b *testing.B) {
StringHelper(b, 16, shortBase, shortExponent)
}
func BenchmarkStringMedium2(b *testing.B) {
StringHelper(b, 2, mediumBase, mediumExponent)
}
func BenchmarkStringMedium8(b *testing.B) {
StringHelper(b, 8, mediumBase, mediumExponent)
}
func BenchmarkStringMedium10(b *testing.B) {
StringHelper(b, 10, mediumBase, mediumExponent)
}
func BenchmarkStringMedium16(b *testing.B) {
StringHelper(b, 16, mediumBase, mediumExponent)
}
func BenchmarkStringLong2(b *testing.B) {
StringHelper(b, 2, longBase, longExponent)
}
func BenchmarkStringLong8(b *testing.B) {
StringHelper(b, 8, longBase, longExponent)
}
func BenchmarkStringLong10(b *testing.B) {
StringHelper(b, 10, longBase, longExponent)
}
func BenchmarkStringLong16(b *testing.B) {
StringHelper(b, 16, longBase, longExponent)
}
func StringHelper(b *testing.B, base int, xv, yv Word) {
b.StopTimer()
var x, y, z nat
x = x.setWord(xv)
y = y.setWord(yv)
z = z.expNN(x, y, nil)
b.StartTimer()
for i := 0; i < b.N; i++ {
z.string(lowercaseDigits[0:base])
}
}
func TestLeadingZeros(t *testing.T) { func TestLeadingZeros(t *testing.T) {
var x Word = _B >> 1 var x Word = _B >> 1
@ -210,14 +534,12 @@ func TestLeadingZeros(t *testing.T) {
} }
} }
type shiftTest struct { type shiftTest struct {
in nat in nat
shift uint shift uint
out nat out nat
} }
var leftShiftTests = []shiftTest{ var leftShiftTests = []shiftTest{
{nil, 0, nil}, {nil, 0, nil},
{nil, 1, nil}, {nil, 1, nil},
@ -227,7 +549,6 @@ var leftShiftTests = []shiftTest{
{nat{1 << (_W - 1), 0}, 1, nat{0, 1}}, {nat{1 << (_W - 1), 0}, 1, nat{0, 1}},
} }
func TestShiftLeft(t *testing.T) { func TestShiftLeft(t *testing.T) {
for i, test := range leftShiftTests { for i, test := range leftShiftTests {
var z nat var z nat
@ -241,7 +562,6 @@ func TestShiftLeft(t *testing.T) {
} }
} }
var rightShiftTests = []shiftTest{ var rightShiftTests = []shiftTest{
{nil, 0, nil}, {nil, 0, nil},
{nil, 1, nil}, {nil, 1, nil},
@ -252,7 +572,6 @@ var rightShiftTests = []shiftTest{
{nat{2, 1, 1}, 1, nat{1<<(_W-1) + 1, 1 << (_W - 1)}}, {nat{2, 1, 1}, 1, nat{1<<(_W-1) + 1, 1 << (_W - 1)}},
} }
func TestShiftRight(t *testing.T) { func TestShiftRight(t *testing.T) {
for i, test := range rightShiftTests { for i, test := range rightShiftTests {
var z nat var z nat
@ -266,24 +585,20 @@ func TestShiftRight(t *testing.T) {
} }
} }
type modWTest struct { type modWTest struct {
in string in string
dividend string dividend string
out string out string
} }
var modWTests32 = []modWTest{ var modWTests32 = []modWTest{
{"23492635982634928349238759823742", "252341", "220170"}, {"23492635982634928349238759823742", "252341", "220170"},
} }
var modWTests64 = []modWTest{ var modWTests64 = []modWTest{
{"6527895462947293856291561095690465243862946", "524326975699234", "375066989628668"}, {"6527895462947293856291561095690465243862946", "524326975699234", "375066989628668"},
} }
func runModWTests(t *testing.T, tests []modWTest) { func runModWTests(t *testing.T, tests []modWTest) {
for i, test := range tests { for i, test := range tests {
in, _ := new(Int).SetString(test.in, 10) in, _ := new(Int).SetString(test.in, 10)
@ -297,7 +612,6 @@ func runModWTests(t *testing.T, tests []modWTest) {
} }
} }
func TestModW(t *testing.T) { func TestModW(t *testing.T) {
if _W >= 32 { if _W >= 32 {
runModWTests(t, modWTests32) runModWTests(t, modWTests32)
@ -307,7 +621,6 @@ func TestModW(t *testing.T) {
} }
} }
func TestTrailingZeroBits(t *testing.T) { func TestTrailingZeroBits(t *testing.T) {
var x Word var x Word
x-- x--
@ -319,7 +632,6 @@ func TestTrailingZeroBits(t *testing.T) {
} }
} }
var expNNTests = []struct { var expNNTests = []struct {
x, y, m string x, y, m string
out string out string
@ -337,17 +649,16 @@ var expNNTests = []struct {
}, },
} }
func TestExpNN(t *testing.T) { func TestExpNN(t *testing.T) {
for i, test := range expNNTests { for i, test := range expNNTests {
x, _, _ := nat(nil).scan(test.x, 0) x, _, _ := nat(nil).scan(strings.NewReader(test.x), 0)
y, _, _ := nat(nil).scan(test.y, 0) y, _, _ := nat(nil).scan(strings.NewReader(test.y), 0)
out, _, _ := nat(nil).scan(test.out, 0) out, _, _ := nat(nil).scan(strings.NewReader(test.out), 0)
var m nat var m nat
if len(test.m) > 0 { if len(test.m) > 0 {
m, _, _ = nat(nil).scan(test.m, 0) m, _, _ = nat(nil).scan(strings.NewReader(test.m), 0)
} }
z := nat(nil).expNN(x, y, m) z := nat(nil).expNN(x, y, m)

View File

@ -6,7 +6,12 @@
package big package big
import "strings" import (
"encoding/binary"
"fmt"
"os"
"strings"
)
// A Rat represents a quotient a/b of arbitrary precision. The zero value for // A Rat represents a quotient a/b of arbitrary precision. The zero value for
// a Rat, 0/0, is not a legal Rat. // a Rat, 0/0, is not a legal Rat.
@ -15,13 +20,11 @@ type Rat struct {
b nat b nat
} }
// NewRat creates a new Rat with numerator a and denominator b. // NewRat creates a new Rat with numerator a and denominator b.
func NewRat(a, b int64) *Rat { func NewRat(a, b int64) *Rat {
return new(Rat).SetFrac64(a, b) return new(Rat).SetFrac64(a, b)
} }
// SetFrac sets z to a/b and returns z. // SetFrac sets z to a/b and returns z.
func (z *Rat) SetFrac(a, b *Int) *Rat { func (z *Rat) SetFrac(a, b *Int) *Rat {
z.a.Set(a) z.a.Set(a)
@ -30,7 +33,6 @@ func (z *Rat) SetFrac(a, b *Int) *Rat {
return z.norm() return z.norm()
} }
// SetFrac64 sets z to a/b and returns z. // SetFrac64 sets z to a/b and returns z.
func (z *Rat) SetFrac64(a, b int64) *Rat { func (z *Rat) SetFrac64(a, b int64) *Rat {
z.a.SetInt64(a) z.a.SetInt64(a)
@ -42,7 +44,6 @@ func (z *Rat) SetFrac64(a, b int64) *Rat {
return z.norm() return z.norm()
} }
// SetInt sets z to x (by making a copy of x) and returns z. // SetInt sets z to x (by making a copy of x) and returns z.
func (z *Rat) SetInt(x *Int) *Rat { func (z *Rat) SetInt(x *Int) *Rat {
z.a.Set(x) z.a.Set(x)
@ -50,7 +51,6 @@ func (z *Rat) SetInt(x *Int) *Rat {
return z return z
} }
// SetInt64 sets z to x and returns z. // SetInt64 sets z to x and returns z.
func (z *Rat) SetInt64(x int64) *Rat { func (z *Rat) SetInt64(x int64) *Rat {
z.a.SetInt64(x) z.a.SetInt64(x)
@ -58,7 +58,6 @@ func (z *Rat) SetInt64(x int64) *Rat {
return z return z
} }
// Sign returns: // Sign returns:
// //
// -1 if x < 0 // -1 if x < 0
@ -69,13 +68,11 @@ func (x *Rat) Sign() int {
return x.a.Sign() return x.a.Sign()
} }
// IsInt returns true if the denominator of x is 1. // IsInt returns true if the denominator of x is 1.
func (x *Rat) IsInt() bool { func (x *Rat) IsInt() bool {
return len(x.b) == 1 && x.b[0] == 1 return len(x.b) == 1 && x.b[0] == 1
} }
// Num returns the numerator of z; it may be <= 0. // Num returns the numerator of z; it may be <= 0.
// The result is a reference to z's numerator; it // The result is a reference to z's numerator; it
// may change if a new value is assigned to z. // may change if a new value is assigned to z.
@ -83,15 +80,13 @@ func (z *Rat) Num() *Int {
return &z.a return &z.a
} }
// Denom returns the denominator of z; it is always > 0.
// Demom returns the denominator of z; it is always > 0.
// The result is a reference to z's denominator; it // The result is a reference to z's denominator; it
// may change if a new value is assigned to z. // may change if a new value is assigned to z.
func (z *Rat) Denom() *Int { func (z *Rat) Denom() *Int {
return &Int{false, z.b} return &Int{false, z.b}
} }
func gcd(x, y nat) nat { func gcd(x, y nat) nat {
// Euclidean algorithm. // Euclidean algorithm.
var a, b nat var a, b nat
@ -106,7 +101,6 @@ func gcd(x, y nat) nat {
return a return a
} }
func (z *Rat) norm() *Rat { func (z *Rat) norm() *Rat {
f := gcd(z.a.abs, z.b) f := gcd(z.a.abs, z.b)
if len(z.a.abs) == 0 { if len(z.a.abs) == 0 {
@ -122,7 +116,6 @@ func (z *Rat) norm() *Rat {
return z return z
} }
func mulNat(x *Int, y nat) *Int { func mulNat(x *Int, y nat) *Int {
var z Int var z Int
z.abs = z.abs.mul(x.abs, y) z.abs = z.abs.mul(x.abs, y)
@ -130,7 +123,6 @@ func mulNat(x *Int, y nat) *Int {
return &z return &z
} }
// Cmp compares x and y and returns: // Cmp compares x and y and returns:
// //
// -1 if x < y // -1 if x < y
@ -141,7 +133,6 @@ func (x *Rat) Cmp(y *Rat) (r int) {
return mulNat(&x.a, y.b).Cmp(mulNat(&y.a, x.b)) return mulNat(&x.a, y.b).Cmp(mulNat(&y.a, x.b))
} }
// Abs sets z to |x| (the absolute value of x) and returns z. // Abs sets z to |x| (the absolute value of x) and returns z.
func (z *Rat) Abs(x *Rat) *Rat { func (z *Rat) Abs(x *Rat) *Rat {
z.a.Abs(&x.a) z.a.Abs(&x.a)
@ -149,7 +140,6 @@ func (z *Rat) Abs(x *Rat) *Rat {
return z return z
} }
// Add sets z to the sum x+y and returns z. // Add sets z to the sum x+y and returns z.
func (z *Rat) Add(x, y *Rat) *Rat { func (z *Rat) Add(x, y *Rat) *Rat {
a1 := mulNat(&x.a, y.b) a1 := mulNat(&x.a, y.b)
@ -159,7 +149,6 @@ func (z *Rat) Add(x, y *Rat) *Rat {
return z.norm() return z.norm()
} }
// Sub sets z to the difference x-y and returns z. // Sub sets z to the difference x-y and returns z.
func (z *Rat) Sub(x, y *Rat) *Rat { func (z *Rat) Sub(x, y *Rat) *Rat {
a1 := mulNat(&x.a, y.b) a1 := mulNat(&x.a, y.b)
@ -169,7 +158,6 @@ func (z *Rat) Sub(x, y *Rat) *Rat {
return z.norm() return z.norm()
} }
// Mul sets z to the product x*y and returns z. // Mul sets z to the product x*y and returns z.
func (z *Rat) Mul(x, y *Rat) *Rat { func (z *Rat) Mul(x, y *Rat) *Rat {
z.a.Mul(&x.a, &y.a) z.a.Mul(&x.a, &y.a)
@ -177,7 +165,6 @@ func (z *Rat) Mul(x, y *Rat) *Rat {
return z.norm() return z.norm()
} }
// Quo sets z to the quotient x/y and returns z. // Quo sets z to the quotient x/y and returns z.
// If y == 0, a division-by-zero run-time panic occurs. // If y == 0, a division-by-zero run-time panic occurs.
func (z *Rat) Quo(x, y *Rat) *Rat { func (z *Rat) Quo(x, y *Rat) *Rat {
@ -192,7 +179,6 @@ func (z *Rat) Quo(x, y *Rat) *Rat {
return z.norm() return z.norm()
} }
// Neg sets z to -x (by making a copy of x if necessary) and returns z. // Neg sets z to -x (by making a copy of x if necessary) and returns z.
func (z *Rat) Neg(x *Rat) *Rat { func (z *Rat) Neg(x *Rat) *Rat {
z.a.Neg(&x.a) z.a.Neg(&x.a)
@ -200,7 +186,6 @@ func (z *Rat) Neg(x *Rat) *Rat {
return z return z
} }
// Set sets z to x (by making a copy of x if necessary) and returns z. // Set sets z to x (by making a copy of x if necessary) and returns z.
func (z *Rat) Set(x *Rat) *Rat { func (z *Rat) Set(x *Rat) *Rat {
z.a.Set(&x.a) z.a.Set(&x.a)
@ -208,6 +193,25 @@ func (z *Rat) Set(x *Rat) *Rat {
return z return z
} }
func ratTok(ch int) bool {
return strings.IndexRune("+-/0123456789.eE", ch) >= 0
}
// Scan is a support routine for fmt.Scanner. It accepts the formats
// 'e', 'E', 'f', 'F', 'g', 'G', and 'v'. All formats are equivalent.
func (z *Rat) Scan(s fmt.ScanState, ch int) os.Error {
tok, err := s.Token(true, ratTok)
if err != nil {
return err
}
if strings.IndexRune("efgEFGv", ch) < 0 {
return os.NewError("Rat.Scan: invalid verb")
}
if _, ok := z.SetString(string(tok)); !ok {
return os.NewError("Rat.Scan: invalid syntax")
}
return nil
}
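Because Scan satisfies the fmt scanning hooks, a *Rat can now be filled in directly by fmt's scan functions. A minimal usage sketch, not part of this diff, assuming "fmt" is imported and using an illustrative input string:

	x := new(Rat)
	if _, err := fmt.Sscan("2/3", x); err != nil {
		// non-nil for inputs that Rat.SetString rejects
		fmt.Println("scan failed:", err)
	}
	fmt.Println(x.RatString()) // prints 2/3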
// SetString sets z to the value of s and returns z and a boolean indicating // SetString sets z to the value of s and returns z and a boolean indicating
// success. s can be given as a fraction "a/b" or as a floating-point number // success. s can be given as a fraction "a/b" or as a floating-point number
@ -225,8 +229,8 @@ func (z *Rat) SetString(s string) (*Rat, bool) {
return z, false return z, false
} }
s = s[sep+1:] s = s[sep+1:]
var n int var err os.Error
if z.b, _, n = z.b.scan(s, 10); n != len(s) { if z.b, _, err = z.b.scan(strings.NewReader(s), 10); err != nil {
return z, false return z, false
} }
return z.norm(), true return z.norm(), true
@ -267,13 +271,11 @@ func (z *Rat) SetString(s string) (*Rat, bool) {
return z, true return z, true
} }
// String returns a string representation of z in the form "a/b" (even if b == 1). // String returns a string representation of z in the form "a/b" (even if b == 1).
func (z *Rat) String() string { func (z *Rat) String() string {
return z.a.String() + "/" + z.b.string(10) return z.a.String() + "/" + z.b.decimalString()
} }
// RatString returns a string representation of z in the form "a/b" if b != 1, // RatString returns a string representation of z in the form "a/b" if b != 1,
// and in the form "a" if b == 1. // and in the form "a" if b == 1.
func (z *Rat) RatString() string { func (z *Rat) RatString() string {
@ -283,12 +285,15 @@ func (z *Rat) RatString() string {
return z.String() return z.String()
} }
// FloatString returns a string representation of z in decimal form with prec // FloatString returns a string representation of z in decimal form with prec
// digits of precision after the decimal point and the last digit rounded. // digits of precision after the decimal point and the last digit rounded.
func (z *Rat) FloatString(prec int) string { func (z *Rat) FloatString(prec int) string {
if z.IsInt() { if z.IsInt() {
return z.a.String() s := z.a.String()
if prec > 0 {
s += "." + strings.Repeat("0", prec)
}
return s
} }
q, r := nat{}.div(nat{}, z.a.abs, z.b) q, r := nat{}.div(nat{}, z.a.abs, z.b)
@ -311,16 +316,56 @@ func (z *Rat) FloatString(prec int) string {
} }
} }
s := q.string(10) s := q.decimalString()
if z.a.neg { if z.a.neg {
s = "-" + s s = "-" + s
} }
if prec > 0 { if prec > 0 {
rs := r.string(10) rs := r.decimalString()
leadingZeros := prec - len(rs) leadingZeros := prec - len(rs)
s += "." + strings.Repeat("0", leadingZeros) + rs s += "." + strings.Repeat("0", leadingZeros) + rs
} }
return s return s
} }
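A small sketch of the padding behavior introduced above (values are illustrative): integral rationals are now padded to prec fractional digits instead of being printed bare.

	x := NewRat(1, 3)
	fmt.Println(x.FloatString(4)) // 0.3333
	y := NewRat(2, 1)
	fmt.Println(y.FloatString(3)) // 2.000 (previously just "2")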
// Gob codec version. Permits backward-compatible changes to the encoding.
const ratGobVersion byte = 1
// GobEncode implements the gob.GobEncoder interface.
func (z *Rat) GobEncode() ([]byte, os.Error) {
buf := make([]byte, 1+4+(len(z.a.abs)+len(z.b))*_S) // extra bytes for version and sign bit (1), and numerator length (4)
i := z.b.bytes(buf)
j := z.a.abs.bytes(buf[0:i])
n := i - j
if int(uint32(n)) != n {
// this should never happen
return nil, os.NewError("Rat.GobEncode: numerator too large")
}
binary.BigEndian.PutUint32(buf[j-4:j], uint32(n))
j -= 1 + 4
b := ratGobVersion << 1 // make space for sign bit
if z.a.neg {
b |= 1
}
buf[j] = b
return buf[j:], nil
}
// GobDecode implements the gob.GobDecoder interface.
func (z *Rat) GobDecode(buf []byte) os.Error {
if len(buf) == 0 {
return os.NewError("Rat.GobDecode: no data")
}
b := buf[0]
if b>>1 != ratGobVersion {
return os.NewError(fmt.Sprintf("Rat.GobDecode: encoding version %d not supported", b>>1))
}
const j = 1 + 4
i := j + binary.BigEndian.Uint32(buf[j-4:j])
z.a.neg = b&1 != 0
z.a.abs = z.a.abs.setBytes(buf[j:i])
z.b = z.b.setBytes(buf[i:])
return nil
}
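A round-trip sketch of the new gob support, mirroring the test added below; the buffer and values are illustrative and "bytes" and "gob" are assumed to be imported:

	var buf bytes.Buffer
	in := NewRat(-22, 7)
	if err := gob.NewEncoder(&buf).Encode(in); err != nil {
		// encoding failed
	}
	out := new(Rat)
	if err := gob.NewDecoder(&buf).Decode(out); err != nil {
		// decoding failed
	}
	// out.Cmp(in) == 0 after a successful round trip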

View File

@ -4,8 +4,12 @@
package big package big
import "testing" import (
"bytes"
"fmt"
"gob"
"testing"
)
var setStringTests = []struct { var setStringTests = []struct {
in, out string in, out string
@ -52,6 +56,27 @@ func TestRatSetString(t *testing.T) {
} }
} }
func TestRatScan(t *testing.T) {
var buf bytes.Buffer
for i, test := range setStringTests {
x := new(Rat)
buf.Reset()
buf.WriteString(test.in)
_, err := fmt.Fscanf(&buf, "%v", x)
if err == nil != test.ok {
if test.ok {
t.Errorf("#%d error: %s", i, err.String())
} else {
t.Errorf("#%d expected error", i)
}
continue
}
if err == nil && x.RatString() != test.out {
t.Errorf("#%d got %s want %s", i, x.RatString(), test.out)
}
}
}
var floatStringTests = []struct { var floatStringTests = []struct {
in string in string
@ -59,12 +84,13 @@ var floatStringTests = []struct {
out string out string
}{ }{
{"0", 0, "0"}, {"0", 0, "0"},
{"0", 4, "0"}, {"0", 4, "0.0000"},
{"1", 0, "1"}, {"1", 0, "1"},
{"1", 2, "1"}, {"1", 2, "1.00"},
{"-1", 0, "-1"}, {"-1", 0, "-1"},
{".25", 2, "0.25"}, {".25", 2, "0.25"},
{".25", 1, "0.3"}, {".25", 1, "0.3"},
{".25", 3, "0.250"},
{"-1/3", 3, "-0.333"}, {"-1/3", 3, "-0.333"},
{"-2/3", 4, "-0.6667"}, {"-2/3", 4, "-0.6667"},
{"0.96", 1, "1.0"}, {"0.96", 1, "1.0"},
@ -84,7 +110,6 @@ func TestFloatString(t *testing.T) {
} }
} }
func TestRatSign(t *testing.T) { func TestRatSign(t *testing.T) {
zero := NewRat(0, 1) zero := NewRat(0, 1)
for _, a := range setStringTests { for _, a := range setStringTests {
@ -98,7 +123,6 @@ func TestRatSign(t *testing.T) {
} }
} }
var ratCmpTests = []struct { var ratCmpTests = []struct {
rat1, rat2 string rat1, rat2 string
out int out int
@ -126,7 +150,6 @@ func TestRatCmp(t *testing.T) {
} }
} }
func TestIsInt(t *testing.T) { func TestIsInt(t *testing.T) {
one := NewInt(1) one := NewInt(1)
for _, a := range setStringTests { for _, a := range setStringTests {
@ -140,7 +163,6 @@ func TestIsInt(t *testing.T) {
} }
} }
func TestRatAbs(t *testing.T) { func TestRatAbs(t *testing.T) {
zero := NewRat(0, 1) zero := NewRat(0, 1)
for _, a := range setStringTests { for _, a := range setStringTests {
@ -158,7 +180,6 @@ func TestRatAbs(t *testing.T) {
} }
} }
type ratBinFun func(z, x, y *Rat) *Rat type ratBinFun func(z, x, y *Rat) *Rat
type ratBinArg struct { type ratBinArg struct {
x, y, z string x, y, z string
@ -175,7 +196,6 @@ func testRatBin(t *testing.T, i int, name string, f ratBinFun, a ratBinArg) {
} }
} }
var ratBinTests = []struct { var ratBinTests = []struct {
x, y string x, y string
sum, prod string sum, prod string
@ -232,7 +252,6 @@ func TestRatBin(t *testing.T) {
} }
} }
func TestIssue820(t *testing.T) { func TestIssue820(t *testing.T) {
x := NewRat(3, 1) x := NewRat(3, 1)
y := NewRat(2, 1) y := NewRat(2, 1)
@ -258,7 +277,6 @@ func TestIssue820(t *testing.T) {
} }
} }
var setFrac64Tests = []struct { var setFrac64Tests = []struct {
a, b int64 a, b int64
out string out string
@ -280,3 +298,35 @@ func TestRatSetFrac64Rat(t *testing.T) {
} }
} }
} }
func TestRatGobEncoding(t *testing.T) {
var medium bytes.Buffer
enc := gob.NewEncoder(&medium)
dec := gob.NewDecoder(&medium)
for i, test := range gobEncodingTests {
for j := 0; j < 4; j++ {
medium.Reset() // empty buffer for each test case (in case of failures)
stest := test
if j&1 != 0 {
// negative numbers
stest = "-" + test
}
if j%2 != 0 {
// fractions
stest = stest + "." + test
}
var tx Rat
tx.SetString(stest)
if err := enc.Encode(&tx); err != nil {
t.Errorf("#%d%c: encoding failed: %s", i, 'a'+j, err)
}
var rx Rat
if err := dec.Decode(&rx); err != nil {
t.Errorf("#%d%c: decoding failed: %s", i, 'a'+j, err)
}
if rx.Cmp(&tx) != 0 {
t.Errorf("#%d%c: transmission failed: got %s want %s", i, 'a'+j, &rx, &tx)
}
}
}
}

View File

@ -15,16 +15,17 @@ import (
"utf8" "utf8"
) )
const ( const (
defaultBufSize = 4096 defaultBufSize = 4096
) )
// Errors introduced by this package. // Errors introduced by this package.
type Error struct { type Error struct {
os.ErrorString ErrorString string
} }
func (err *Error) String() string { return err.ErrorString }
var ( var (
ErrInvalidUnreadByte os.Error = &Error{"bufio: invalid use of UnreadByte"} ErrInvalidUnreadByte os.Error = &Error{"bufio: invalid use of UnreadByte"}
ErrInvalidUnreadRune os.Error = &Error{"bufio: invalid use of UnreadRune"} ErrInvalidUnreadRune os.Error = &Error{"bufio: invalid use of UnreadRune"}
@ -40,7 +41,6 @@ func (b BufSizeError) String() string {
return "bufio: bad buffer size " + strconv.Itoa(int(b)) return "bufio: bad buffer size " + strconv.Itoa(int(b))
} }
// Buffered input. // Buffered input.
// Reader implements buffering for an io.Reader object. // Reader implements buffering for an io.Reader object.
@ -101,6 +101,12 @@ func (b *Reader) fill() {
} }
} }
func (b *Reader) readErr() os.Error {
err := b.err
b.err = nil
return err
}
// Peek returns the next n bytes without advancing the reader. The bytes stop // Peek returns the next n bytes without advancing the reader. The bytes stop
// being valid at the next read call. If Peek returns fewer than n bytes, it // being valid at the next read call. If Peek returns fewer than n bytes, it
// also returns an error explaining why the read is short. The error is // also returns an error explaining why the read is short. The error is
@ -119,7 +125,7 @@ func (b *Reader) Peek(n int) ([]byte, os.Error) {
if m > n { if m > n {
m = n m = n
} }
err := b.err err := b.readErr()
if m < n && err == nil { if m < n && err == nil {
err = ErrBufferFull err = ErrBufferFull
} }
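A usage sketch for Peek as documented above (reader contents are illustrative, "bufio" and "strings" assumed imported); after this change a pending read error is reported once via readErr and then cleared:

	br := bufio.NewReader(strings.NewReader("header:body"))
	p, err := br.Peek(7) // p == []byte("header:"); the reader has not advanced
	if err != nil {
		// fewer than 7 bytes were available; err explains why
	}
	_ = p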
@ -134,11 +140,11 @@ func (b *Reader) Peek(n int) ([]byte, os.Error) {
func (b *Reader) Read(p []byte) (n int, err os.Error) { func (b *Reader) Read(p []byte) (n int, err os.Error) {
n = len(p) n = len(p)
if n == 0 { if n == 0 {
return 0, b.err return 0, b.readErr()
} }
if b.w == b.r { if b.w == b.r {
if b.err != nil { if b.err != nil {
return 0, b.err return 0, b.readErr()
} }
if len(p) >= len(b.buf) { if len(p) >= len(b.buf) {
// Large read, empty buffer. // Large read, empty buffer.
@ -148,11 +154,11 @@ func (b *Reader) Read(p []byte) (n int, err os.Error) {
b.lastByte = int(p[n-1]) b.lastByte = int(p[n-1])
b.lastRuneSize = -1 b.lastRuneSize = -1
} }
return n, b.err return n, b.readErr()
} }
b.fill() b.fill()
if b.w == b.r { if b.w == b.r {
return 0, b.err return 0, b.readErr()
} }
} }
@ -172,7 +178,7 @@ func (b *Reader) ReadByte() (c byte, err os.Error) {
b.lastRuneSize = -1 b.lastRuneSize = -1
for b.w == b.r { for b.w == b.r {
if b.err != nil { if b.err != nil {
return 0, b.err return 0, b.readErr()
} }
b.fill() b.fill()
} }
@ -208,7 +214,7 @@ func (b *Reader) ReadRune() (rune int, size int, err os.Error) {
} }
b.lastRuneSize = -1 b.lastRuneSize = -1
if b.r == b.w { if b.r == b.w {
return 0, 0, b.err return 0, 0, b.readErr()
} }
rune, size = int(b.buf[b.r]), 1 rune, size = int(b.buf[b.r]), 1
if rune >= 0x80 { if rune >= 0x80 {
@ -260,7 +266,7 @@ func (b *Reader) ReadSlice(delim byte) (line []byte, err os.Error) {
if b.err != nil { if b.err != nil {
line := b.buf[b.r:b.w] line := b.buf[b.r:b.w]
b.r = b.w b.r = b.w
return line, b.err return line, b.readErr()
} }
n := b.Buffered() n := b.Buffered()
@ -367,7 +373,6 @@ func (b *Reader) ReadString(delim byte) (line string, err os.Error) {
return string(bytes), e return string(bytes), e
} }
// buffered output // buffered output
// Writer implements buffering for an io.Writer object. // Writer implements buffering for an io.Writer object.

View File

@ -53,11 +53,12 @@ func readBytes(buf *Reader) string {
if e == os.EOF { if e == os.EOF {
break break
} }
if e != nil { if e == nil {
b[nb] = c
nb++
} else if e != iotest.ErrTimeout {
panic("Data: " + e.String()) panic("Data: " + e.String())
} }
b[nb] = c
nb++
} }
return string(b[0:nb]) return string(b[0:nb])
} }
@ -75,7 +76,6 @@ func TestReaderSimple(t *testing.T) {
} }
} }
type readMaker struct { type readMaker struct {
name string name string
fn func(io.Reader) io.Reader fn func(io.Reader) io.Reader
@ -86,6 +86,7 @@ var readMakers = []readMaker{
{"byte", iotest.OneByteReader}, {"byte", iotest.OneByteReader},
{"half", iotest.HalfReader}, {"half", iotest.HalfReader},
{"data+err", iotest.DataErrReader}, {"data+err", iotest.DataErrReader},
{"timeout", iotest.TimeoutReader},
} }
// Call ReadString (which ends up calling everything else) // Call ReadString (which ends up calling everything else)
@ -97,7 +98,7 @@ func readLines(b *Reader) string {
if e == os.EOF { if e == os.EOF {
break break
} }
if e != nil { if e != nil && e != iotest.ErrTimeout {
panic("GetLines: " + e.String()) panic("GetLines: " + e.String())
} }
s += s1 s += s1

135
libgo/go/builtin/builtin.go Normal file
View File

@ -0,0 +1,135 @@
// Copyright 2011 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
/*
Package builtin provides documentation for Go's built-in functions.
The functions documented here are not actually in package builtin
but their descriptions here allow godoc to present documentation
for the language's special functions.
*/
package builtin
// Type is here for the purposes of documentation only. It is a stand-in
// for any Go type, but represents the same type for any given function
// invocation.
type Type int
// IntegerType is here for the purposes of documentation only. It is a stand-in
// for any integer type: int, uint, int8 etc.
type IntegerType int
// FloatType is here for the purposes of documentation only. It is a stand-in
// for either float type: float32 or float64.
type FloatType int
// ComplexType is here for the purposes of documentation only. It is a
// stand-in for either complex type: complex64 or complex128.
type ComplexType int
// The append built-in function appends elements to the end of a slice. If
// it has sufficient capacity, the destination is resliced to accommodate the
// new elements. If it does not, a new underlying array will be allocated.
// Append returns the updated slice. It is therefore necessary to store the
// result of append, often in the variable holding the slice itself:
// slice = append(slice, elem1, elem2)
// slice = append(slice, anotherSlice...)
func append(slice []Type, elems ...Type) []Type
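For illustration, a sketch of both append forms described above (values are arbitrary):

	s := []int{1, 2}
	s = append(s, 3, 4) // s == [1 2 3 4]; the result must be stored back
	t := []int{5, 6}
	s = append(s, t...) // s == [1 2 3 4 5 6]; splices in another slice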
// The copy built-in function copies elements from a source slice into a
// destination slice. (As a special case, it also will copy bytes from a
// string to a slice of bytes.) The source and destination may overlap. Copy
// returns the number of elements copied, which will be the minimum of
// len(src) and len(dst).
func copy(dst, src []Type) int
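A sketch of copy, including the string-to-byte-slice special case noted above:

	dst := make([]byte, 4)
	n := copy(dst, []byte("hello")) // n == 4 == min(len(src), len(dst))
	m := copy(dst, "hi")            // special case: copies bytes from a string
	_, _ = n, m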
// The len built-in function returns the length of v, according to its type:
// Array: the number of elements in v.
// Pointer to array: the number of elements in *v (even if v is nil).
// Slice, or map: the number of elements in v; if v is nil, len(v) is zero.
// String: the number of bytes in v.
// Channel: the number of elements queued (unread) in the channel buffer;
// if v is nil, len(v) is zero.
func len(v Type) int
// The cap built-in function returns the capacity of v, according to its type:
// Array: the number of elements in v (same as len(v)).
// Pointer to array: the number of elements in *v (same as len(v)).
// Slice: the maximum length the slice can reach when resliced;
// if v is nil, cap(v) is zero.
// Channel: the channel buffer capacity, in units of elements;
// if v is nil, cap(v) is zero.
func cap(v Type) int
// The make built-in function allocates and initializes an object of type
// slice, map, or chan (only). Like new, the first argument is a type, not a
// value. Unlike new, make's return type is the same as the type of its
// argument, not a pointer to it. The specification of the result depends on
// the type:
// Slice: The size specifies the length. The capacity of the slice is
// equal to its length. A second integer argument may be provided to
// specify a different capacity; it must be no smaller than the
// length, so make([]int, 0, 10) allocates a slice of length 0 and
// capacity 10.
// Map: An initial allocation is made according to the size but the
// resulting map has length 0. The size may be omitted, in which case
// a small starting size is allocated.
// Channel: The channel's buffer is initialized with the specified
// buffer capacity. If zero, or the size is omitted, the channel is
// unbuffered.
func make(Type, size IntegerType) Type
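A sketch covering the three make forms described above:

	s := make([]int, 0, 10)   // slice: length 0, capacity 10
	m := make(map[string]int) // map: size omitted, small initial allocation
	c := make(chan int, 5)    // channel: buffer capacity 5
	_, _, _ = s, m, c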
// The new built-in function allocates memory. The first argument is a type,
// not a value, and the value returned is a pointer to a newly
// allocated zero value of that type.
func new(Type) *Type
// The complex built-in function constructs a complex value from two
// floating-point values. The real and imaginary parts must be of the same
// size, either float32 or float64 (or assignable to them), and the return
// value will be the corresponding complex type (complex64 for float32,
// complex128 for float64).
func complex(r, i FloatType) ComplexType
// The real built-in function returns the real part of the complex number c.
// The return value will be floating point type corresponding to the type of c.
func real(c ComplexType) FloatType
// The imaginary built-in function returns the imaginary part of the complex
// number c. The return value will be floating point type corresponding to
// the type of c.
func imag(c ComplexType) FloatType
// The close built-in function closes a channel, which must be either
// bidirectional or send-only. It should be executed only by the sender,
// never the receiver, and has the effect of shutting down the channel after
// the last sent value is received. After the last value has been received
// from a closed channel c, any receive from c will succeed without
// blocking, returning the zero value for the channel element. The form
// x, ok := <-c
// will also set ok to false for a closed channel.
func close(c chan<- Type)
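A sketch of the receive behavior on a closed channel described above:

	c := make(chan int, 1)
	c <- 42
	close(c)
	v, ok := <-c // v == 42, ok == true: the buffered value is still delivered
	v, ok = <-c  // v == 0, ok == false: channel closed and drained
	_, _ = v, ok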
// The panic built-in function stops normal execution of the current
// goroutine. When a function F calls panic, normal execution of F stops
// immediately. Any functions whose execution was deferred by F are run in
// the usual way, and then F returns to its caller. To the caller G, the
// invocation of F then behaves like a call to panic, terminating G's
// execution and running any deferred functions. This continues until all
// functions in the executing goroutine have stopped, in reverse order. At
// that point, the program is terminated and the error condition is reported,
// including the value of the argument to panic. This termination sequence
// is called panicking and can be controlled by the built-in function
// recover.
func panic(v interface{})
// The recover built-in function allows a program to manage behavior of a
// panicking goroutine. Executing a call to recover inside a deferred
// function (but not any function called by it) stops the panicking sequence
// by restoring normal execution and retrieves the error value passed to the
// call of panic. If recover is called outside the deferred function it will
// not stop a panicking sequence. In this case, or when the goroutine is not
// panicking, or if the argument supplied to panic was nil, recover returns
// nil. Thus the return value from recover reports whether the goroutine is
// panicking.
func recover() interface{}
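A sketch of the panic/recover interaction described above, turning a run-time panic into an ordinary error value; the helper name and the use of os.NewError are illustrative only:

	func safeDiv(a, b int) (q int, err os.Error) {
		defer func() {
			if r := recover(); r != nil {
				err = os.NewError(fmt.Sprint("recovered: ", r))
			}
		}()
		return a / b, nil // b == 0 panics; the deferred recover converts it
	}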

View File

@ -280,7 +280,7 @@ func (b *Buffer) ReadRune() (r int, size int, err os.Error) {
// from any read operation.) // from any read operation.)
func (b *Buffer) UnreadRune() os.Error { func (b *Buffer) UnreadRune() os.Error {
if b.lastRead != opReadRune { if b.lastRead != opReadRune {
return os.ErrorString("bytes.Buffer: UnreadRune: previous operation was not ReadRune") return os.NewError("bytes.Buffer: UnreadRune: previous operation was not ReadRune")
} }
b.lastRead = opInvalid b.lastRead = opInvalid
if b.off > 0 { if b.off > 0 {
@ -295,7 +295,7 @@ func (b *Buffer) UnreadRune() os.Error {
// returns an error. // returns an error.
func (b *Buffer) UnreadByte() os.Error { func (b *Buffer) UnreadByte() os.Error {
if b.lastRead != opReadRune && b.lastRead != opRead { if b.lastRead != opReadRune && b.lastRead != opRead {
return os.ErrorString("bytes.Buffer: UnreadByte: previous operation was not a read") return os.NewError("bytes.Buffer: UnreadByte: previous operation was not a read")
} }
b.lastRead = opInvalid b.lastRead = opInvalid
if b.off > 0 { if b.off > 0 {

View File

@ -12,7 +12,6 @@ import (
"utf8" "utf8"
) )
const N = 10000 // make this bigger for a larger (and slower) test const N = 10000 // make this bigger for a larger (and slower) test
var data string // test data for write tests var data string // test data for write tests
var bytes []byte // test data; same as data but as a slice. var bytes []byte // test data; same as data but as a slice.
@ -47,7 +46,6 @@ func check(t *testing.T, testname string, buf *Buffer, s string) {
} }
} }
// Fill buf through n writes of string fus. // Fill buf through n writes of string fus.
// The initial contents of buf corresponds to the string s; // The initial contents of buf corresponds to the string s;
// the result is the final contents of buf returned as a string. // the result is the final contents of buf returned as a string.
@ -67,7 +65,6 @@ func fillString(t *testing.T, testname string, buf *Buffer, s string, n int, fus
return s return s
} }
// Fill buf through n writes of byte slice fub. // Fill buf through n writes of byte slice fub.
// The initial contents of buf corresponds to the string s; // The initial contents of buf corresponds to the string s;
// the result is the final contents of buf returned as a string. // the result is the final contents of buf returned as a string.
@ -87,19 +84,16 @@ func fillBytes(t *testing.T, testname string, buf *Buffer, s string, n int, fub
return s return s
} }
func TestNewBuffer(t *testing.T) { func TestNewBuffer(t *testing.T) {
buf := NewBuffer(bytes) buf := NewBuffer(bytes)
check(t, "NewBuffer", buf, data) check(t, "NewBuffer", buf, data)
} }
func TestNewBufferString(t *testing.T) { func TestNewBufferString(t *testing.T) {
buf := NewBufferString(data) buf := NewBufferString(data)
check(t, "NewBufferString", buf, data) check(t, "NewBufferString", buf, data)
} }
// Empty buf through repeated reads into fub. // Empty buf through repeated reads into fub.
// The initial contents of buf corresponds to the string s. // The initial contents of buf corresponds to the string s.
func empty(t *testing.T, testname string, buf *Buffer, s string, fub []byte) { func empty(t *testing.T, testname string, buf *Buffer, s string, fub []byte) {
@ -120,7 +114,6 @@ func empty(t *testing.T, testname string, buf *Buffer, s string, fub []byte) {
check(t, testname+" (empty 4)", buf, "") check(t, testname+" (empty 4)", buf, "")
} }
func TestBasicOperations(t *testing.T) { func TestBasicOperations(t *testing.T) {
var buf Buffer var buf Buffer
@ -175,7 +168,6 @@ func TestBasicOperations(t *testing.T) {
} }
} }
func TestLargeStringWrites(t *testing.T) { func TestLargeStringWrites(t *testing.T) {
var buf Buffer var buf Buffer
limit := 30 limit := 30
@ -189,7 +181,6 @@ func TestLargeStringWrites(t *testing.T) {
check(t, "TestLargeStringWrites (3)", &buf, "") check(t, "TestLargeStringWrites (3)", &buf, "")
} }
func TestLargeByteWrites(t *testing.T) { func TestLargeByteWrites(t *testing.T) {
var buf Buffer var buf Buffer
limit := 30 limit := 30
@ -203,7 +194,6 @@ func TestLargeByteWrites(t *testing.T) {
check(t, "TestLargeByteWrites (3)", &buf, "") check(t, "TestLargeByteWrites (3)", &buf, "")
} }
func TestLargeStringReads(t *testing.T) { func TestLargeStringReads(t *testing.T) {
var buf Buffer var buf Buffer
for i := 3; i < 30; i += 3 { for i := 3; i < 30; i += 3 {
@ -213,7 +203,6 @@ func TestLargeStringReads(t *testing.T) {
check(t, "TestLargeStringReads (3)", &buf, "") check(t, "TestLargeStringReads (3)", &buf, "")
} }
func TestLargeByteReads(t *testing.T) { func TestLargeByteReads(t *testing.T) {
var buf Buffer var buf Buffer
for i := 3; i < 30; i += 3 { for i := 3; i < 30; i += 3 {
@ -223,7 +212,6 @@ func TestLargeByteReads(t *testing.T) {
check(t, "TestLargeByteReads (3)", &buf, "") check(t, "TestLargeByteReads (3)", &buf, "")
} }
func TestMixedReadsAndWrites(t *testing.T) { func TestMixedReadsAndWrites(t *testing.T) {
var buf Buffer var buf Buffer
s := "" s := ""
@ -243,7 +231,6 @@ func TestMixedReadsAndWrites(t *testing.T) {
empty(t, "TestMixedReadsAndWrites (2)", &buf, s, make([]byte, buf.Len())) empty(t, "TestMixedReadsAndWrites (2)", &buf, s, make([]byte, buf.Len()))
} }
func TestNil(t *testing.T) { func TestNil(t *testing.T) {
var b *Buffer var b *Buffer
if b.String() != "<nil>" { if b.String() != "<nil>" {
@ -251,7 +238,6 @@ func TestNil(t *testing.T) {
} }
} }
func TestReadFrom(t *testing.T) { func TestReadFrom(t *testing.T) {
var buf Buffer var buf Buffer
for i := 3; i < 30; i += 3 { for i := 3; i < 30; i += 3 {
@ -262,7 +248,6 @@ func TestReadFrom(t *testing.T) {
} }
} }
func TestWriteTo(t *testing.T) { func TestWriteTo(t *testing.T) {
var buf Buffer var buf Buffer
for i := 3; i < 30; i += 3 { for i := 3; i < 30; i += 3 {
@ -273,7 +258,6 @@ func TestWriteTo(t *testing.T) {
} }
} }
func TestRuneIO(t *testing.T) { func TestRuneIO(t *testing.T) {
const NRune = 1000 const NRune = 1000
// Built a test array while we write the data // Built a test array while we write the data
@ -323,7 +307,6 @@ func TestRuneIO(t *testing.T) {
} }
} }
func TestNext(t *testing.T) { func TestNext(t *testing.T) {
b := []byte{0, 1, 2, 3, 4} b := []byte{0, 1, 2, 3, 4}
tmp := make([]byte, 5) tmp := make([]byte, 5)

View File

@ -212,24 +212,38 @@ func genSplit(s, sep []byte, sepSave, n int) [][]byte {
return a[0 : na+1] return a[0 : na+1]
} }
// Split slices s into subslices separated by sep and returns a slice of // SplitN slices s into subslices separated by sep and returns a slice of
// the subslices between those separators.
// If sep is empty, SplitN splits after each UTF-8 sequence.
// The count determines the number of subslices to return:
// n > 0: at most n subslices; the last subslice will be the unsplit remainder.
// n == 0: the result is nil (zero subslices)
// n < 0: all subslices
func SplitN(s, sep []byte, n int) [][]byte { return genSplit(s, sep, 0, n) }
// SplitAfterN slices s into subslices after each instance of sep and
// returns a slice of those subslices.
// If sep is empty, SplitAfterN splits after each UTF-8 sequence.
// The count determines the number of subslices to return:
// n > 0: at most n subslices; the last subslice will be the unsplit remainder.
// n == 0: the result is nil (zero subslices)
// n < 0: all subslices
func SplitAfterN(s, sep []byte, n int) [][]byte {
return genSplit(s, sep, len(sep), n)
}
// Split slices s into all subslices separated by sep and returns a slice of
// the subslices between those separators. // the subslices between those separators.
// If sep is empty, Split splits after each UTF-8 sequence. // If sep is empty, Split splits after each UTF-8 sequence.
// The count determines the number of subslices to return: // It is equivalent to SplitN with a count of -1.
// n > 0: at most n subslices; the last subslice will be the unsplit remainder. func Split(s, sep []byte) [][]byte { return genSplit(s, sep, 0, -1) }
// n == 0: the result is nil (zero subslices)
// n < 0: all subslices
func Split(s, sep []byte, n int) [][]byte { return genSplit(s, sep, 0, n) }
// SplitAfter slices s into subslices after each instance of sep and // SplitAfter slices s into all subslices after each instance of sep and
// returns a slice of those subslices. // returns a slice of those subslices.
// If sep is empty, Split splits after each UTF-8 sequence. // If sep is empty, SplitAfter splits after each UTF-8 sequence.
// The count determines the number of subslices to return: // It is equivalent to SplitAfterN with a count of -1.
// n > 0: at most n subslices; the last subslice will be the unsplit remainder. func SplitAfter(s, sep []byte) [][]byte {
// n == 0: the result is nil (zero subslices) return genSplit(s, sep, len(sep), -1)
// n < 0: all subslices
func SplitAfter(s, sep []byte, n int) [][]byte {
return genSplit(s, sep, len(sep), n)
} }
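A sketch of the renamed API: the counted variant is now SplitN, and Split with no count is shorthand for n = -1 (inputs are illustrative):

	parts := bytes.SplitN([]byte("a,b,c"), []byte(","), 2)
	// parts == ["a" "b,c"]: at most 2 subslices, remainder unsplit
	all := bytes.Split([]byte("a,b,c"), []byte(","))
	// all == ["a" "b" "c"]: equivalent to SplitN with n = -1
	_, _ = parts, all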
// Fields splits the array s around each instance of one or more consecutive white space // Fields splits the array s around each instance of one or more consecutive white space
@ -384,7 +398,6 @@ func ToTitleSpecial(_case unicode.SpecialCase, s []byte) []byte {
return Map(func(r int) int { return _case.ToTitle(r) }, s) return Map(func(r int) int { return _case.ToTitle(r) }, s)
} }
// isSeparator reports whether the rune could mark a word boundary. // isSeparator reports whether the rune could mark a word boundary.
// TODO: update when package unicode captures more of the properties. // TODO: update when package unicode captures more of the properties.
func isSeparator(rune int) bool { func isSeparator(rune int) bool {

View File

@ -6,6 +6,7 @@ package bytes_test
import ( import (
. "bytes" . "bytes"
"reflect"
"testing" "testing"
"unicode" "unicode"
"utf8" "utf8"
@ -315,7 +316,7 @@ var explodetests = []ExplodeTest{
func TestExplode(t *testing.T) { func TestExplode(t *testing.T) {
for _, tt := range explodetests { for _, tt := range explodetests {
a := Split([]byte(tt.s), nil, tt.n) a := SplitN([]byte(tt.s), nil, tt.n)
result := arrayOfString(a) result := arrayOfString(a)
if !eq(result, tt.a) { if !eq(result, tt.a) {
t.Errorf(`Explode("%s", %d) = %v; want %v`, tt.s, tt.n, result, tt.a) t.Errorf(`Explode("%s", %d) = %v; want %v`, tt.s, tt.n, result, tt.a)
@ -328,7 +329,6 @@ func TestExplode(t *testing.T) {
} }
} }
type SplitTest struct { type SplitTest struct {
s string s string
sep string sep string
@ -354,7 +354,7 @@ var splittests = []SplitTest{
func TestSplit(t *testing.T) { func TestSplit(t *testing.T) {
for _, tt := range splittests { for _, tt := range splittests {
a := Split([]byte(tt.s), []byte(tt.sep), tt.n) a := SplitN([]byte(tt.s), []byte(tt.sep), tt.n)
result := arrayOfString(a) result := arrayOfString(a)
if !eq(result, tt.a) { if !eq(result, tt.a) {
t.Errorf(`Split(%q, %q, %d) = %v; want %v`, tt.s, tt.sep, tt.n, result, tt.a) t.Errorf(`Split(%q, %q, %d) = %v; want %v`, tt.s, tt.sep, tt.n, result, tt.a)
@ -367,6 +367,12 @@ func TestSplit(t *testing.T) {
if string(s) != tt.s { if string(s) != tt.s {
t.Errorf(`Join(Split(%q, %q, %d), %q) = %q`, tt.s, tt.sep, tt.n, tt.sep, s) t.Errorf(`Join(Split(%q, %q, %d), %q) = %q`, tt.s, tt.sep, tt.n, tt.sep, s)
} }
if tt.n < 0 {
b := Split([]byte(tt.s), []byte(tt.sep))
if !reflect.DeepEqual(a, b) {
t.Errorf("Split disagrees withSplitN(%q, %q, %d) = %v; want %v", tt.s, tt.sep, tt.n, b, a)
}
}
} }
} }
@ -388,7 +394,7 @@ var splitaftertests = []SplitTest{
func TestSplitAfter(t *testing.T) { func TestSplitAfter(t *testing.T) {
for _, tt := range splitaftertests { for _, tt := range splitaftertests {
a := SplitAfter([]byte(tt.s), []byte(tt.sep), tt.n) a := SplitAfterN([]byte(tt.s), []byte(tt.sep), tt.n)
result := arrayOfString(a) result := arrayOfString(a)
if !eq(result, tt.a) { if !eq(result, tt.a) {
t.Errorf(`Split(%q, %q, %d) = %v; want %v`, tt.s, tt.sep, tt.n, result, tt.a) t.Errorf(`Split(%q, %q, %d) = %v; want %v`, tt.s, tt.sep, tt.n, result, tt.a)
@ -398,6 +404,12 @@ func TestSplitAfter(t *testing.T) {
if string(s) != tt.s { if string(s) != tt.s {
t.Errorf(`Join(Split(%q, %q, %d), %q) = %q`, tt.s, tt.sep, tt.n, tt.sep, s) t.Errorf(`Join(Split(%q, %q, %d), %q) = %q`, tt.s, tt.sep, tt.n, tt.sep, s)
} }
if tt.n < 0 {
b := SplitAfter([]byte(tt.s), []byte(tt.sep))
if !reflect.DeepEqual(a, b) {
t.Errorf("SplitAfter disagrees withSplitAfterN(%q, %q, %d) = %v; want %v", tt.s, tt.sep, tt.n, b, a)
}
}
} }
} }
@ -649,7 +661,6 @@ func TestRunes(t *testing.T) {
} }
} }
type TrimTest struct { type TrimTest struct {
f func([]byte, string) []byte f func([]byte, string) []byte
in, cutset, out string in, cutset, out string

View File

@ -284,7 +284,7 @@ func (bz2 *reader) readBlock() (err os.Error) {
repeat := 0 repeat := 0
repeat_power := 0 repeat_power := 0
// The `C' array (used by the inverse BWT) needs to be zero initialised. // The `C' array (used by the inverse BWT) needs to be zero initialized.
for i := range bz2.c { for i := range bz2.c {
bz2.c[i] = 0 bz2.c[i] = 0
} }
@ -330,7 +330,7 @@ func (bz2 *reader) readBlock() (err os.Error) {
if int(v) == numSymbols-1 { if int(v) == numSymbols-1 {
// This is the EOF symbol. Because it's always at the // This is the EOF symbol. Because it's always at the
// end of the move-to-front list, and nevers gets moved // end of the move-to-front list, and never gets moved
// to the front, it has this unique value. // to the front, it has this unique value.
break break
} }

View File

@ -68,7 +68,7 @@ func newHuffmanTree(lengths []uint8) (huffmanTree, os.Error) {
// each symbol (consider reflecting a tree down the middle, for // each symbol (consider reflecting a tree down the middle, for
// example). Since the code length assignments determine the // example). Since the code length assignments determine the
// efficiency of the tree, each of these trees is equally good. In // efficiency of the tree, each of these trees is equally good. In
// order to minimise the amount of information needed to build a tree // order to minimize the amount of information needed to build a tree
// bzip2 uses a canonical tree so that it can be reconstructed given // bzip2 uses a canonical tree so that it can be reconstructed given
// only the code length assignments. // only the code length assignments.

View File

@ -11,16 +11,18 @@ import (
) )
const ( const (
NoCompression = 0 NoCompression = 0
BestSpeed = 1 BestSpeed = 1
fastCompression = 3 fastCompression = 3
BestCompression = 9 BestCompression = 9
DefaultCompression = -1 DefaultCompression = -1
logMaxOffsetSize = 15 // Standard DEFLATE logWindowSize = 15
wideLogMaxOffsetSize = 22 // Wide DEFLATE windowSize = 1 << logWindowSize
minMatchLength = 3 // The smallest match that the compressor looks for windowMask = windowSize - 1
maxMatchLength = 258 // The longest match for the compressor logMaxOffsetSize = 15 // Standard DEFLATE
minOffsetSize = 1 // The shortest offset that makes any sence minMatchLength = 3 // The smallest match that the compressor looks for
maxMatchLength = 258 // The longest match for the compressor
minOffsetSize = 1 // The shortest offset that makes any sense
// The maximum number of tokens we put into a single flat block, just to // The maximum number of tokens we put into a single flat block, just to
// stop things from getting too large. // stop things from getting too large.
@ -32,22 +34,6 @@ const (
hashShift = (hashBits + minMatchLength - 1) / minMatchLength hashShift = (hashBits + minMatchLength - 1) / minMatchLength
) )
type syncPipeReader struct {
*io.PipeReader
closeChan chan bool
}
func (sr *syncPipeReader) CloseWithError(err os.Error) os.Error {
retErr := sr.PipeReader.CloseWithError(err)
sr.closeChan <- true // finish writer close
return retErr
}
type syncPipeWriter struct {
*io.PipeWriter
closeChan chan bool
}
type compressionLevel struct { type compressionLevel struct {
good, lazy, nice, chain, fastSkipHashing int good, lazy, nice, chain, fastSkipHashing int
} }
@ -68,105 +54,73 @@ var levels = []compressionLevel{
{32, 258, 258, 4096, math.MaxInt32}, {32, 258, 258, 4096, math.MaxInt32},
} }
func (sw *syncPipeWriter) Close() os.Error {
err := sw.PipeWriter.Close()
<-sw.closeChan // wait for reader close
return err
}
func syncPipe() (*syncPipeReader, *syncPipeWriter) {
r, w := io.Pipe()
sr := &syncPipeReader{r, make(chan bool, 1)}
sw := &syncPipeWriter{w, sr.closeChan}
return sr, sw
}
type compressor struct { type compressor struct {
level int compressionLevel
logWindowSize uint
w *huffmanBitWriter
r io.Reader
// (1 << logWindowSize) - 1.
windowMask int
eof bool // has eof been reached on input? w *huffmanBitWriter
sync bool // writer wants to flush
syncChan chan os.Error
// compression algorithm
fill func(*compressor, []byte) int // copy data to window
step func(*compressor) // process window
sync bool // requesting flush
// Input hash chains
// hashHead[hashValue] contains the largest inputIndex with the specified hash value // hashHead[hashValue] contains the largest inputIndex with the specified hash value
hashHead []int
// If hashHead[hashValue] is within the current window, then // If hashHead[hashValue] is within the current window, then
// hashPrev[hashHead[hashValue] & windowMask] contains the previous index // hashPrev[hashHead[hashValue] & windowMask] contains the previous index
// with the same hash value. // with the same hash value.
hashPrev []int chainHead int
hashHead []int
hashPrev []int
// If we find a match of length >= niceMatch, then we don't bother searching // input window: unprocessed data is window[index:windowEnd]
// any further. index int
niceMatch int window []byte
windowEnd int
blockStart int // window index where current tokens start
byteAvailable bool // if true, still need to process window[index-1].
// If we find a match of length >= goodMatch, we only do a half-hearted // queued output tokens: tokens[:ti]
// effort at doing lazy matching starting at the next character tokens []token
goodMatch int ti int
// The maximum number of chains we look at when finding a match // deflate state
maxChainLength int length int
offset int
// The sliding window we use for matching hash int
window []byte maxInsertIndex int
err os.Error
// The index just past the last valid character
windowEnd int
// index in "window" at which current block starts
blockStart int
} }
func (d *compressor) flush() os.Error { func (d *compressor) fillDeflate(b []byte) int {
d.w.flush() if d.index >= 2*windowSize-(minMatchLength+maxMatchLength) {
return d.w.err // shift the window by windowSize
} copy(d.window, d.window[windowSize:2*windowSize])
d.index -= windowSize
func (d *compressor) fillWindow(index int) (int, os.Error) { d.windowEnd -= windowSize
if d.sync { if d.blockStart >= windowSize {
return index, nil d.blockStart -= windowSize
}
wSize := d.windowMask + 1
if index >= wSize+wSize-(minMatchLength+maxMatchLength) {
// shift the window by wSize
copy(d.window, d.window[wSize:2*wSize])
index -= wSize
d.windowEnd -= wSize
if d.blockStart >= wSize {
d.blockStart -= wSize
} else { } else {
d.blockStart = math.MaxInt32 d.blockStart = math.MaxInt32
} }
for i, h := range d.hashHead { for i, h := range d.hashHead {
v := h - wSize v := h - windowSize
if v < -1 { if v < -1 {
v = -1 v = -1
} }
d.hashHead[i] = v d.hashHead[i] = v
} }
for i, h := range d.hashPrev { for i, h := range d.hashPrev {
v := -h - wSize v := -h - windowSize
if v < -1 { if v < -1 {
v = -1 v = -1
} }
d.hashPrev[i] = v d.hashPrev[i] = v
} }
} }
count, err := d.r.Read(d.window[d.windowEnd:]) n := copy(d.window[d.windowEnd:], b)
d.windowEnd += count d.windowEnd += n
if count == 0 && err == nil { return n
d.sync = true
}
if err == os.EOF {
d.eof = true
err = nil
}
return index, err
} }
func (d *compressor) writeBlock(tokens []token, index int, eof bool) os.Error { func (d *compressor) writeBlock(tokens []token, index int, eof bool) os.Error {
@ -194,21 +148,21 @@ func (d *compressor) findMatch(pos int, prevHead int, prevLength int, lookahead
// We quit when we get a match that's at least nice long // We quit when we get a match that's at least nice long
nice := len(win) - pos nice := len(win) - pos
if d.niceMatch < nice { if d.nice < nice {
nice = d.niceMatch nice = d.nice
} }
// If we've got a match that's good enough, only look in 1/4 the chain. // If we've got a match that's good enough, only look in 1/4 the chain.
tries := d.maxChainLength tries := d.chain
length = prevLength length = prevLength
if length >= d.goodMatch { if length >= d.good {
tries >>= 2 tries >>= 2
} }
w0 := win[pos] w0 := win[pos]
w1 := win[pos+1] w1 := win[pos+1]
wEnd := win[pos+length] wEnd := win[pos+length]
minIndex := pos - (d.windowMask + 1) minIndex := pos - windowSize
for i := prevHead; tries > 0; tries-- { for i := prevHead; tries > 0; tries-- {
if w0 == win[i] && w1 == win[i+1] && wEnd == win[i+length] { if w0 == win[i] && w1 == win[i+1] && wEnd == win[i+length] {
@ -233,7 +187,7 @@ func (d *compressor) findMatch(pos int, prevHead int, prevLength int, lookahead
// hashPrev[i & windowMask] has already been overwritten, so stop now. // hashPrev[i & windowMask] has already been overwritten, so stop now.
break break
} }
if i = d.hashPrev[i&d.windowMask]; i < minIndex || i < 0 { if i = d.hashPrev[i&windowMask]; i < minIndex || i < 0 {
break break
} }
} }
@ -248,234 +202,224 @@ func (d *compressor) writeStoredBlock(buf []byte) os.Error {
return d.w.err return d.w.err
} }
func (d *compressor) storedDeflate() os.Error { func (d *compressor) initDeflate() {
buf := make([]byte, maxStoreBlockSize) d.hashHead = make([]int, hashSize)
for { d.hashPrev = make([]int, windowSize)
n, err := d.r.Read(buf) d.window = make([]byte, 2*windowSize)
if n == 0 && err == nil { fillInts(d.hashHead, -1)
d.sync = true d.tokens = make([]token, maxFlateBlockTokens, maxFlateBlockTokens+1)
} d.length = minMatchLength - 1
if n > 0 || d.sync { d.offset = 0
if err := d.writeStoredBlock(buf[0:n]); err != nil { d.byteAvailable = false
return err d.index = 0
} d.ti = 0
if d.sync { d.hash = 0
d.syncChan <- nil d.chainHead = -1
d.sync = false
}
}
if err != nil {
if err == os.EOF {
break
}
return err
}
}
return nil
} }
func (d *compressor) doDeflate() (err os.Error) { func (d *compressor) deflate() {
// init if d.windowEnd-d.index < minMatchLength+maxMatchLength && !d.sync {
d.windowMask = 1<<d.logWindowSize - 1
d.hashHead = make([]int, hashSize)
d.hashPrev = make([]int, 1<<d.logWindowSize)
d.window = make([]byte, 2<<d.logWindowSize)
fillInts(d.hashHead, -1)
tokens := make([]token, maxFlateBlockTokens, maxFlateBlockTokens+1)
l := levels[d.level]
d.goodMatch = l.good
d.niceMatch = l.nice
d.maxChainLength = l.chain
lazyMatch := l.lazy
length := minMatchLength - 1
offset := 0
byteAvailable := false
isFastDeflate := l.fastSkipHashing != 0
index := 0
// run
if index, err = d.fillWindow(index); err != nil {
return return
} }
maxOffset := d.windowMask + 1 // (1 << logWindowSize);
// only need to change when you refill the window
windowEnd := d.windowEnd
maxInsertIndex := windowEnd - (minMatchLength - 1)
ti := 0
hash := int(0) d.maxInsertIndex = d.windowEnd - (minMatchLength - 1)
if index < maxInsertIndex { if d.index < d.maxInsertIndex {
hash = int(d.window[index])<<hashShift + int(d.window[index+1]) d.hash = int(d.window[d.index])<<hashShift + int(d.window[d.index+1])
} }
chainHead := -1
Loop: Loop:
for { for {
if index > windowEnd { if d.index > d.windowEnd {
panic("index > windowEnd") panic("index > windowEnd")
} }
lookahead := windowEnd - index lookahead := d.windowEnd - d.index
if lookahead < minMatchLength+maxMatchLength { if lookahead < minMatchLength+maxMatchLength {
if index, err = d.fillWindow(index); err != nil { if !d.sync {
return break Loop
} }
windowEnd = d.windowEnd if d.index > d.windowEnd {
if index > windowEnd {
panic("index > windowEnd") panic("index > windowEnd")
} }
maxInsertIndex = windowEnd - (minMatchLength - 1)
lookahead = windowEnd - index
if lookahead == 0 { if lookahead == 0 {
// Flush current output block if any. // Flush current output block if any.
if byteAvailable { if d.byteAvailable {
// There is still one pending token that needs to be flushed // There is still one pending token that needs to be flushed
tokens[ti] = literalToken(uint32(d.window[index-1]) & 0xFF) d.tokens[d.ti] = literalToken(uint32(d.window[d.index-1]))
ti++ d.ti++
byteAvailable = false d.byteAvailable = false
} }
if ti > 0 { if d.ti > 0 {
if err = d.writeBlock(tokens[0:ti], index, false); err != nil { if d.err = d.writeBlock(d.tokens[0:d.ti], d.index, false); d.err != nil {
return return
} }
ti = 0 d.ti = 0
}
if d.sync {
d.w.writeStoredHeader(0, false)
d.w.flush()
d.syncChan <- d.w.err
d.sync = false
}
// If this was only a sync (not at EOF) keep going.
if !d.eof {
continue
} }
break Loop break Loop
} }
} }
if index < maxInsertIndex { if d.index < d.maxInsertIndex {
// Update the hash // Update the hash
hash = (hash<<hashShift + int(d.window[index+2])) & hashMask d.hash = (d.hash<<hashShift + int(d.window[d.index+2])) & hashMask
chainHead = d.hashHead[hash] d.chainHead = d.hashHead[d.hash]
d.hashPrev[index&d.windowMask] = chainHead d.hashPrev[d.index&windowMask] = d.chainHead
d.hashHead[hash] = index d.hashHead[d.hash] = d.index
} }
prevLength := length prevLength := d.length
prevOffset := offset prevOffset := d.offset
length = minMatchLength - 1 d.length = minMatchLength - 1
offset = 0 d.offset = 0
minIndex := index - maxOffset minIndex := d.index - windowSize
if minIndex < 0 { if minIndex < 0 {
minIndex = 0 minIndex = 0
} }
if chainHead >= minIndex && if d.chainHead >= minIndex &&
(isFastDeflate && lookahead > minMatchLength-1 || (d.fastSkipHashing != 0 && lookahead > minMatchLength-1 ||
!isFastDeflate && lookahead > prevLength && prevLength < lazyMatch) { d.fastSkipHashing == 0 && lookahead > prevLength && prevLength < d.lazy) {
if newLength, newOffset, ok := d.findMatch(index, chainHead, minMatchLength-1, lookahead); ok { if newLength, newOffset, ok := d.findMatch(d.index, d.chainHead, minMatchLength-1, lookahead); ok {
length = newLength d.length = newLength
offset = newOffset d.offset = newOffset
} }
} }
if isFastDeflate && length >= minMatchLength || if d.fastSkipHashing != 0 && d.length >= minMatchLength ||
!isFastDeflate && prevLength >= minMatchLength && length <= prevLength { d.fastSkipHashing == 0 && prevLength >= minMatchLength && d.length <= prevLength {
// There was a match at the previous step, and the current match is // There was a match at the previous step, and the current match is
// not better. Output the previous match. // not better. Output the previous match.
if isFastDeflate { if d.fastSkipHashing != 0 {
tokens[ti] = matchToken(uint32(length-minMatchLength), uint32(offset-minOffsetSize)) d.tokens[d.ti] = matchToken(uint32(d.length-minMatchLength), uint32(d.offset-minOffsetSize))
} else { } else {
tokens[ti] = matchToken(uint32(prevLength-minMatchLength), uint32(prevOffset-minOffsetSize)) d.tokens[d.ti] = matchToken(uint32(prevLength-minMatchLength), uint32(prevOffset-minOffsetSize))
} }
ti++ d.ti++
// Insert in the hash table all strings up to the end of the match. // Insert in the hash table all strings up to the end of the match.
// index and index-1 are already inserted. If there is not enough // index and index-1 are already inserted. If there is not enough
// lookahead, the last two strings are not inserted into the hash // lookahead, the last two strings are not inserted into the hash
// table. // table.
if length <= l.fastSkipHashing { if d.length <= d.fastSkipHashing {
var newIndex int var newIndex int
if isFastDeflate { if d.fastSkipHashing != 0 {
newIndex = index + length newIndex = d.index + d.length
} else { } else {
newIndex = prevLength - 1 newIndex = prevLength - 1
} }
for index++; index < newIndex; index++ { for d.index++; d.index < newIndex; d.index++ {
if index < maxInsertIndex { if d.index < d.maxInsertIndex {
hash = (hash<<hashShift + int(d.window[index+2])) & hashMask d.hash = (d.hash<<hashShift + int(d.window[d.index+2])) & hashMask
// Get previous value with the same hash. // Get previous value with the same hash.
// Our chain should point to the previous value. // Our chain should point to the previous value.
d.hashPrev[index&d.windowMask] = d.hashHead[hash] d.hashPrev[d.index&windowMask] = d.hashHead[d.hash]
// Set the head of the hash chain to us. // Set the head of the hash chain to us.
d.hashHead[hash] = index d.hashHead[d.hash] = d.index
} }
} }
if !isFastDeflate { if d.fastSkipHashing == 0 {
byteAvailable = false d.byteAvailable = false
length = minMatchLength - 1 d.length = minMatchLength - 1
} }
} else { } else {
// For matches this long, we don't bother inserting each individual // For matches this long, we don't bother inserting each individual
// item into the table. // item into the table.
index += length d.index += d.length
hash = (int(d.window[index])<<hashShift + int(d.window[index+1])) d.hash = (int(d.window[d.index])<<hashShift + int(d.window[d.index+1]))
} }
if ti == maxFlateBlockTokens { if d.ti == maxFlateBlockTokens {
// The block includes the current character // The block includes the current character
if err = d.writeBlock(tokens, index, false); err != nil { if d.err = d.writeBlock(d.tokens, d.index, false); d.err != nil {
return return
} }
ti = 0 d.ti = 0
} }
} else { } else {
if isFastDeflate || byteAvailable { if d.fastSkipHashing != 0 || d.byteAvailable {
i := index - 1 i := d.index - 1
if isFastDeflate { if d.fastSkipHashing != 0 {
i = index i = d.index
} }
tokens[ti] = literalToken(uint32(d.window[i]) & 0xFF) d.tokens[d.ti] = literalToken(uint32(d.window[i]))
ti++ d.ti++
if ti == maxFlateBlockTokens { if d.ti == maxFlateBlockTokens {
if err = d.writeBlock(tokens, i+1, false); err != nil { if d.err = d.writeBlock(d.tokens, i+1, false); d.err != nil {
return return
} }
ti = 0 d.ti = 0
} }
} }
index++ d.index++
if !isFastDeflate { if d.fastSkipHashing == 0 {
byteAvailable = true d.byteAvailable = true
} }
} }
} }
return
} }
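The rewritten loop above folds the matcher state (hash, chainHead, index, length, offset) into the compressor struct, but the algorithm is unchanged: maintain a hash chain and use lazy matching, emitting the previous position's match only when the current position cannot beat it. A minimal, self-contained sketch of that decision rule, with made-up names rather than the library's internals:

package main

import "fmt"

// emitPrevious reports whether a lazy matcher should emit the match found at
// the previous byte instead of continuing from the current one. minMatch is
// DEFLATE's minimum useful match length (3).
func emitPrevious(prevLen, curLen, minMatch int) bool {
    return prevLen >= minMatch && curLen <= prevLen
}

func main() {
    fmt.Println(emitPrevious(4, 3, 3)) // true: keep the earlier, longer match
    fmt.Println(emitPrevious(3, 7, 3)) // false: the newer match is better
}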
func (d *compressor) compress(r io.Reader, w io.Writer, level int, logWindowSize uint) (err os.Error) { func (d *compressor) fillStore(b []byte) int {
d.r = r n := copy(d.window[d.windowEnd:], b)
d.windowEnd += n
return n
}
func (d *compressor) store() {
if d.windowEnd > 0 {
d.err = d.writeStoredBlock(d.window[:d.windowEnd])
}
d.windowEnd = 0
}
func (d *compressor) write(b []byte) (n int, err os.Error) {
n = len(b)
b = b[d.fill(d, b):]
for len(b) > 0 {
d.step(d)
b = b[d.fill(d, b):]
}
return n, d.err
}
func (d *compressor) syncFlush() os.Error {
d.sync = true
d.step(d)
if d.err == nil {
d.w.writeStoredHeader(0, false)
d.w.flush()
d.err = d.w.err
}
d.sync = false
return d.err
}
func (d *compressor) init(w io.Writer, level int) (err os.Error) {
d.w = newHuffmanBitWriter(w) d.w = newHuffmanBitWriter(w)
d.level = level
d.logWindowSize = logWindowSize
switch { switch {
case level == NoCompression: case level == NoCompression:
err = d.storedDeflate() d.window = make([]byte, maxStoreBlockSize)
d.fill = (*compressor).fillStore
d.step = (*compressor).store
case level == DefaultCompression: case level == DefaultCompression:
d.level = 6 level = 6
fallthrough fallthrough
case 1 <= level && level <= 9: case 1 <= level && level <= 9:
err = d.doDeflate() d.compressionLevel = levels[level]
d.initDeflate()
d.fill = (*compressor).fillDeflate
d.step = (*compressor).deflate
default: default:
return WrongValueError{"level", 0, 9, int32(level)} return WrongValueError{"level", 0, 9, int32(level)}
} }
return nil
}
if d.sync { func (d *compressor) close() os.Error {
d.syncChan <- err d.sync = true
d.sync = false d.step(d)
} if d.err != nil {
if err != nil { return d.err
return err
} }
if d.w.writeStoredHeader(0, true); d.w.err != nil { if d.w.writeStoredHeader(0, true); d.w.err != nil {
return d.w.err return d.w.err
} }
return d.flush() d.w.flush()
return d.w.err
} }
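init above turns the compressor into a small state machine: the compression level selects the fill and step method values once, and write, syncFlush and close simply keep invoking them. A stripped-down sketch of the same dispatch pattern, using toy types rather than the package's own:

package main

import "fmt"

// machine mimics the fill/step dispatch: the constructor picks the methods
// once, and callers drive them without caring which variant was chosen.
type machine struct {
    buf  []byte
    fill func(*machine, []byte) int
    step func(*machine)
}

func newMachine(level int) *machine {
    m := &machine{}
    if level == 0 {
        m.fill, m.step = (*machine).fillStore, (*machine).store
    } else {
        m.fill, m.step = (*machine).fillCompress, (*machine).compress
    }
    return m
}

func (m *machine) fillStore(b []byte) int    { m.buf = append(m.buf, b...); return len(b) }
func (m *machine) store()                    { fmt.Printf("stored %d bytes\n", len(m.buf)) }
func (m *machine) fillCompress(b []byte) int { m.buf = append(m.buf, b...); return len(b) }
func (m *machine) compress()                 { fmt.Printf("compressed %d bytes\n", len(m.buf)) }

func (m *machine) write(b []byte) {
    m.fill(m, b)
    m.step(m)
}

func main() {
    newMachine(0).write([]byte("abc")) // stored 3 bytes
    newMachine(6).write([]byte("abc")) // compressed 3 bytes
}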
// NewWriter returns a new Writer compressing // NewWriter returns a new Writer compressing
@ -486,14 +430,9 @@ func (d *compressor) compress(r io.Reader, w io.Writer, level int, logWindowSize
// compression; it only adds the necessary DEFLATE framing. // compression; it only adds the necessary DEFLATE framing.
func NewWriter(w io.Writer, level int) *Writer { func NewWriter(w io.Writer, level int) *Writer {
const logWindowSize = logMaxOffsetSize const logWindowSize = logMaxOffsetSize
var d compressor var dw Writer
d.syncChan = make(chan os.Error, 1) dw.d.init(w, level)
pr, pw := syncPipe() return &dw
go func() {
err := d.compress(pr, w, level, logWindowSize)
pr.CloseWithError(err)
}()
return &Writer{pw, &d}
} }
// NewWriterDict is like NewWriter but initializes the new // NewWriterDict is like NewWriter but initializes the new
@ -526,18 +465,13 @@ func (w *dictWriter) Write(b []byte) (n int, err os.Error) {
// A Writer takes data written to it and writes the compressed // A Writer takes data written to it and writes the compressed
// form of that data to an underlying writer (see NewWriter). // form of that data to an underlying writer (see NewWriter).
type Writer struct { type Writer struct {
w *syncPipeWriter d compressor
d *compressor
} }
// Write writes data to w, which will eventually write the // Write writes data to w, which will eventually write the
// compressed form of data to its underlying writer. // compressed form of data to its underlying writer.
func (w *Writer) Write(data []byte) (n int, err os.Error) { func (w *Writer) Write(data []byte) (n int, err os.Error) {
if len(data) == 0 { return w.d.write(data)
// no point, and nil interferes with sync
return
}
return w.w.Write(data)
} }
// Flush flushes any pending compressed data to the underlying writer. // Flush flushes any pending compressed data to the underlying writer.
@ -550,18 +484,10 @@ func (w *Writer) Write(data []byte) (n int, err os.Error) {
func (w *Writer) Flush() os.Error { func (w *Writer) Flush() os.Error {
// For more about flushing: // For more about flushing:
// http://www.bolet.org/~pornin/deflate-flush.html // http://www.bolet.org/~pornin/deflate-flush.html
if w.d.sync { return w.d.syncFlush()
panic("compress/flate: double Flush")
}
_, err := w.w.Write(nil)
err1 := <-w.d.syncChan
if err == nil {
err = err1
}
return err
} }
// Close flushes and closes the writer. // Close flushes and closes the writer.
func (w *Writer) Close() os.Error { func (w *Writer) Close() os.Error {
return w.w.Close() return w.d.close()
} }
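With the goroutine-and-pipe plumbing gone, Writer is just a thin wrapper over the embedded compressor, and the usual Write/Flush/Close sequence applies. A round-trip sketch against the current compress/flate API (note that in released Go, NewWriter also returns an error, unlike the r60 signature shown above):

package main

import (
    "bytes"
    "compress/flate"
    "fmt"
    "io"
    "log"
)

func main() {
    var buf bytes.Buffer

    zw, err := flate.NewWriter(&buf, flate.BestSpeed)
    if err != nil {
        log.Fatal(err)
    }
    zw.Write([]byte("hello, hello, hello"))
    zw.Flush() // emit a sync block so a reader can consume the data now
    zw.Close()

    zr := flate.NewReader(&buf)
    defer zr.Close()
    out, _ := io.ReadAll(zr)
    fmt.Printf("%q\n", out)
}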
View File
@ -57,7 +57,7 @@ var deflateInflateTests = []*deflateInflateTest{
&deflateInflateTest{[]byte{0x11, 0x12}}, &deflateInflateTest{[]byte{0x11, 0x12}},
&deflateInflateTest{[]byte{0x11, 0x11, 0x11, 0x11, 0x11, 0x11, 0x11, 0x11}}, &deflateInflateTest{[]byte{0x11, 0x11, 0x11, 0x11, 0x11, 0x11, 0x11, 0x11}},
&deflateInflateTest{[]byte{0x11, 0x10, 0x13, 0x41, 0x21, 0x21, 0x41, 0x13, 0x87, 0x78, 0x13}}, &deflateInflateTest{[]byte{0x11, 0x10, 0x13, 0x41, 0x21, 0x21, 0x41, 0x13, 0x87, 0x78, 0x13}},
&deflateInflateTest{getLargeDataChunk()}, &deflateInflateTest{largeDataChunk()},
} }
var reverseBitsTests = []*reverseBitsTest{ var reverseBitsTests = []*reverseBitsTest{
@ -71,23 +71,22 @@ var reverseBitsTests = []*reverseBitsTest{
&reverseBitsTest{29, 5, 23}, &reverseBitsTest{29, 5, 23},
} }
func getLargeDataChunk() []byte { func largeDataChunk() []byte {
result := make([]byte, 100000) result := make([]byte, 100000)
for i := range result { for i := range result {
result[i] = byte(int64(i) * int64(i) & 0xFF) result[i] = byte(i * i & 0xFF)
} }
return result return result
} }
func TestDeflate(t *testing.T) { func TestDeflate(t *testing.T) {
for _, h := range deflateTests { for _, h := range deflateTests {
buffer := bytes.NewBuffer(nil) var buf bytes.Buffer
w := NewWriter(buffer, h.level) w := NewWriter(&buf, h.level)
w.Write(h.in) w.Write(h.in)
w.Close() w.Close()
if bytes.Compare(buffer.Bytes(), h.out) != 0 { if !bytes.Equal(buf.Bytes(), h.out) {
t.Errorf("buffer is wrong; level = %v, buffer.Bytes() = %v, expected output = %v", t.Errorf("Deflate(%d, %x) = %x, want %x", h.level, h.in, buf.Bytes(), h.out)
h.level, buffer.Bytes(), h.out)
} }
} }
} }
@ -226,7 +225,6 @@ func testSync(t *testing.T, level int, input []byte, name string) {
} }
} }
func testToFromWithLevel(t *testing.T, level int, input []byte, name string) os.Error { func testToFromWithLevel(t *testing.T, level int, input []byte, name string) os.Error {
buffer := bytes.NewBuffer(nil) buffer := bytes.NewBuffer(nil)
w := NewWriter(buffer, level) w := NewWriter(buffer, level)
View File
@ -15,9 +15,6 @@ const (
// The largest offset code. // The largest offset code.
offsetCodeCount = 30 offsetCodeCount = 30
// The largest offset code in the extensions.
extendedOffsetCodeCount = 42
// The special code used to mark the end of a block. // The special code used to mark the end of a block.
endBlockMarker = 256 endBlockMarker = 256
@ -100,11 +97,11 @@ func newHuffmanBitWriter(w io.Writer) *huffmanBitWriter {
return &huffmanBitWriter{ return &huffmanBitWriter{
w: w, w: w,
literalFreq: make([]int32, maxLit), literalFreq: make([]int32, maxLit),
offsetFreq: make([]int32, extendedOffsetCodeCount), offsetFreq: make([]int32, offsetCodeCount),
codegen: make([]uint8, maxLit+extendedOffsetCodeCount+1), codegen: make([]uint8, maxLit+offsetCodeCount+1),
codegenFreq: make([]int32, codegenCodeCount), codegenFreq: make([]int32, codegenCodeCount),
literalEncoding: newHuffmanEncoder(maxLit), literalEncoding: newHuffmanEncoder(maxLit),
offsetEncoding: newHuffmanEncoder(extendedOffsetCodeCount), offsetEncoding: newHuffmanEncoder(offsetCodeCount),
codegenEncoding: newHuffmanEncoder(codegenCodeCount), codegenEncoding: newHuffmanEncoder(codegenCodeCount),
} }
} }
@ -185,7 +182,7 @@ func (w *huffmanBitWriter) writeBytes(bytes []byte) {
_, w.err = w.w.Write(bytes) _, w.err = w.w.Write(bytes)
} }
// RFC 1951 3.2.7 specifies a special run-length encoding for specifiying // RFC 1951 3.2.7 specifies a special run-length encoding for specifying
// the literal and offset lengths arrays (which are concatenated into a single // the literal and offset lengths arrays (which are concatenated into a single
// array). This method generates that run-length encoding. // array). This method generates that run-length encoding.
// //
@ -279,7 +276,7 @@ func (w *huffmanBitWriter) writeCode(code *huffmanEncoder, literal uint32) {
// //
// numLiterals The number of literals specified in codegen // numLiterals The number of literals specified in codegen
// numOffsets The number of offsets specified in codegen // numOffsets The number of offsets specified in codegen
// numCodegens Tne number of codegens used in codegen // numCodegens The number of codegens used in codegen
func (w *huffmanBitWriter) writeDynamicHeader(numLiterals int, numOffsets int, numCodegens int, isEof bool) { func (w *huffmanBitWriter) writeDynamicHeader(numLiterals int, numOffsets int, numCodegens int, isEof bool) {
if w.err != nil { if w.err != nil {
return return
@ -290,13 +287,7 @@ func (w *huffmanBitWriter) writeDynamicHeader(numLiterals int, numOffsets int, n
} }
w.writeBits(firstBits, 3) w.writeBits(firstBits, 3)
w.writeBits(int32(numLiterals-257), 5) w.writeBits(int32(numLiterals-257), 5)
if numOffsets > offsetCodeCount { w.writeBits(int32(numOffsets-1), 5)
// Extended version of decompressor
w.writeBits(int32(offsetCodeCount+((numOffsets-(1+offsetCodeCount))>>3)), 5)
w.writeBits(int32((numOffsets-(1+offsetCodeCount))&0x7), 3)
} else {
w.writeBits(int32(numOffsets-1), 5)
}
w.writeBits(int32(numCodegens-4), 4) w.writeBits(int32(numCodegens-4), 4)
for i := 0; i < numCodegens; i++ { for i := 0; i < numCodegens; i++ {
@ -368,24 +359,17 @@ func (w *huffmanBitWriter) writeBlock(tokens []token, eof bool, input []byte) {
tokens = tokens[0 : n+1] tokens = tokens[0 : n+1]
tokens[n] = endBlockMarker tokens[n] = endBlockMarker
totalLength := -1 // Subtract 1 for endBlock.
for _, t := range tokens { for _, t := range tokens {
switch t.typ() { switch t.typ() {
case literalType: case literalType:
w.literalFreq[t.literal()]++ w.literalFreq[t.literal()]++
totalLength++
break
case matchType: case matchType:
length := t.length() length := t.length()
offset := t.offset() offset := t.offset()
totalLength += int(length + 3)
w.literalFreq[lengthCodesStart+lengthCode(length)]++ w.literalFreq[lengthCodesStart+lengthCode(length)]++
w.offsetFreq[offsetCode(offset)]++ w.offsetFreq[offsetCode(offset)]++
break
} }
} }
w.literalEncoding.generate(w.literalFreq, 15)
w.offsetEncoding.generate(w.offsetFreq, 15)
// get the number of literals // get the number of literals
numLiterals := len(w.literalFreq) numLiterals := len(w.literalFreq)
@ -394,15 +378,25 @@ func (w *huffmanBitWriter) writeBlock(tokens []token, eof bool, input []byte) {
} }
// get the number of offsets // get the number of offsets
numOffsets := len(w.offsetFreq) numOffsets := len(w.offsetFreq)
for numOffsets > 1 && w.offsetFreq[numOffsets-1] == 0 { for numOffsets > 0 && w.offsetFreq[numOffsets-1] == 0 {
numOffsets-- numOffsets--
} }
if numOffsets == 0 {
// We haven't found a single match. If we want to go with the dynamic encoding,
// we should count at least one offset to be sure that the offset huffman tree could be encoded.
w.offsetFreq[0] = 1
numOffsets = 1
}
w.literalEncoding.generate(w.literalFreq, 15)
w.offsetEncoding.generate(w.offsetFreq, 15)
storedBytes := 0 storedBytes := 0
if input != nil { if input != nil {
storedBytes = len(input) storedBytes = len(input)
} }
var extraBits int64 var extraBits int64
var storedSize int64 var storedSize int64 = math.MaxInt64
if storedBytes <= maxStoreBlockSize && input != nil { if storedBytes <= maxStoreBlockSize && input != nil {
storedSize = int64((storedBytes + 5) * 8) storedSize = int64((storedBytes + 5) * 8)
// We only bother calculating the costs of the extra bits required by // We only bother calculating the costs of the extra bits required by
@ -417,34 +411,29 @@ func (w *huffmanBitWriter) writeBlock(tokens []token, eof bool, input []byte) {
// First four offset codes have extra size = 0. // First four offset codes have extra size = 0.
extraBits += int64(w.offsetFreq[offsetCode]) * int64(offsetExtraBits[offsetCode]) extraBits += int64(w.offsetFreq[offsetCode]) * int64(offsetExtraBits[offsetCode])
} }
} else {
storedSize = math.MaxInt32
} }
// Figure out which generates smaller code, fixed Huffman, dynamic // Figure out smallest code.
// Huffman, or just storing the data. // Fixed Huffman baseline.
var fixedSize int64 = math.MaxInt64 var size = int64(3) +
if numOffsets <= offsetCodeCount { fixedLiteralEncoding.bitLength(w.literalFreq) +
fixedSize = int64(3) + fixedOffsetEncoding.bitLength(w.offsetFreq) +
fixedLiteralEncoding.bitLength(w.literalFreq) + extraBits
fixedOffsetEncoding.bitLength(w.offsetFreq) + var literalEncoding = fixedLiteralEncoding
extraBits var offsetEncoding = fixedOffsetEncoding
}
// Dynamic Huffman?
var numCodegens int
// Generate codegen and codegenFrequencies, which indicates how to encode // Generate codegen and codegenFrequencies, which indicates how to encode
// the literalEncoding and the offsetEncoding. // the literalEncoding and the offsetEncoding.
w.generateCodegen(numLiterals, numOffsets) w.generateCodegen(numLiterals, numOffsets)
w.codegenEncoding.generate(w.codegenFreq, 7) w.codegenEncoding.generate(w.codegenFreq, 7)
numCodegens := len(w.codegenFreq) numCodegens = len(w.codegenFreq)
for numCodegens > 4 && w.codegenFreq[codegenOrder[numCodegens-1]] == 0 { for numCodegens > 4 && w.codegenFreq[codegenOrder[numCodegens-1]] == 0 {
numCodegens-- numCodegens--
} }
extensionSummand := 0
if numOffsets > offsetCodeCount {
extensionSummand = 3
}
dynamicHeader := int64(3+5+5+4+(3*numCodegens)) + dynamicHeader := int64(3+5+5+4+(3*numCodegens)) +
// Following line is an extension.
int64(extensionSummand) +
w.codegenEncoding.bitLength(w.codegenFreq) + w.codegenEncoding.bitLength(w.codegenFreq) +
int64(extraBits) + int64(extraBits) +
int64(w.codegenFreq[16]*2) + int64(w.codegenFreq[16]*2) +
@ -454,26 +443,25 @@ func (w *huffmanBitWriter) writeBlock(tokens []token, eof bool, input []byte) {
w.literalEncoding.bitLength(w.literalFreq) + w.literalEncoding.bitLength(w.literalFreq) +
w.offsetEncoding.bitLength(w.offsetFreq) w.offsetEncoding.bitLength(w.offsetFreq)
if storedSize < fixedSize && storedSize < dynamicSize { if dynamicSize < size {
w.writeStoredHeader(storedBytes, eof) size = dynamicSize
w.writeBytes(input[0:storedBytes])
return
}
var literalEncoding *huffmanEncoder
var offsetEncoding *huffmanEncoder
if fixedSize <= dynamicSize {
w.writeFixedHeader(eof)
literalEncoding = fixedLiteralEncoding
offsetEncoding = fixedOffsetEncoding
} else {
// Write the header.
w.writeDynamicHeader(numLiterals, numOffsets, numCodegens, eof)
literalEncoding = w.literalEncoding literalEncoding = w.literalEncoding
offsetEncoding = w.offsetEncoding offsetEncoding = w.offsetEncoding
} }
// Write the tokens. // Stored bytes?
if storedSize < size {
w.writeStoredHeader(storedBytes, eof)
w.writeBytes(input[0:storedBytes])
return
}
// Huffman.
if literalEncoding == fixedLiteralEncoding {
w.writeFixedHeader(eof)
} else {
w.writeDynamicHeader(numLiterals, numOffsets, numCodegens, eof)
}
for _, t := range tokens { for _, t := range tokens {
switch t.typ() { switch t.typ() {
case literalType: case literalType:
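The reworked writeBlock computes three candidate sizes, stored, fixed Huffman and dynamic Huffman, and picks the smallest, preferring the simpler encoding on ties. A small sketch of that selection with the sizes passed in as precomputed bit counts:

package main

import "fmt"

// blockKind picks the cheapest DEFLATE block encoding given estimated sizes
// in bits. Ties favor fixed over dynamic, and stored wins only when strictly
// smaller, matching the comparison order above.
func blockKind(storedBits, fixedBits, dynamicBits int64) string {
    size, kind := fixedBits, "fixed"
    if dynamicBits < size {
        size, kind = dynamicBits, "dynamic"
    }
    if storedBits < size {
        kind = "stored"
    }
    return kind
}

func main() {
    fmt.Println(blockKind(800, 700, 650)) // dynamic
    fmt.Println(blockKind(400, 700, 650)) // stored
}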
View File
@ -363,7 +363,12 @@ func (s literalNodeSorter) Less(i, j int) bool {
func (s literalNodeSorter) Swap(i, j int) { s.a[i], s.a[j] = s.a[j], s.a[i] } func (s literalNodeSorter) Swap(i, j int) { s.a[i], s.a[j] = s.a[j], s.a[i] }
func sortByFreq(a []literalNode) { func sortByFreq(a []literalNode) {
s := &literalNodeSorter{a, func(i, j int) bool { return a[i].freq < a[j].freq }} s := &literalNodeSorter{a, func(i, j int) bool {
if a[i].freq == a[j].freq {
return a[i].literal < a[j].literal
}
return a[i].freq < a[j].freq
}}
sort.Sort(s) sort.Sort(s)
} }
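The extra comparison added to sortByFreq breaks frequency ties by literal value, which makes the generated Huffman codes deterministic regardless of input order. The same idea with sort.Slice on a toy node type:

package main

import (
    "fmt"
    "sort"
)

type node struct{ literal, freq int }

func main() {
    a := []node{{'b', 2}, {'a', 2}, {'c', 1}}
    sort.Slice(a, func(i, j int) bool {
        if a[i].freq == a[j].freq {
            return a[i].literal < a[j].literal // tie-break for determinism
        }
        return a[i].freq < a[j].freq
    })
    fmt.Println(a) // [{99 1} {97 2} {98 2}]
}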
View File
@ -77,8 +77,6 @@ type huffmanDecoder struct {
// Initialize Huffman decoding tables from array of code lengths. // Initialize Huffman decoding tables from array of code lengths.
func (h *huffmanDecoder) init(bits []int) bool { func (h *huffmanDecoder) init(bits []int) bool {
// TODO(rsc): Return false sometimes.
// Count number of codes of each length, // Count number of codes of each length,
// compute min and max length. // compute min and max length.
var count [maxCodeLen + 1]int var count [maxCodeLen + 1]int
@ -197,9 +195,8 @@ type Reader interface {
// Decompress state. // Decompress state.
type decompressor struct { type decompressor struct {
// Input/output sources. // Input source.
r Reader r Reader
w io.Writer
roffset int64 roffset int64
woffset int64 woffset int64
@ -222,38 +219,79 @@ type decompressor struct {
// Temporary buffer (avoids repeated allocation). // Temporary buffer (avoids repeated allocation).
buf [4]byte buf [4]byte
// Next step in the decompression,
// and decompression state.
step func(*decompressor)
final bool
err os.Error
toRead []byte
hl, hd *huffmanDecoder
copyLen int
copyDist int
} }
func (f *decompressor) inflate() (err os.Error) { func (f *decompressor) nextBlock() {
final := false if f.final {
for err == nil && !final { if f.hw != f.hp {
for f.nb < 1+2 { f.flush((*decompressor).nextBlock)
if err = f.moreBits(); err != nil { return
return
}
} }
final = f.b&1 == 1 f.err = os.EOF
f.b >>= 1 return
typ := f.b & 3 }
f.b >>= 2 for f.nb < 1+2 {
f.nb -= 1 + 2 if f.err = f.moreBits(); f.err != nil {
switch typ { return
case 0:
err = f.dataBlock()
case 1:
// compressed, fixed Huffman tables
err = f.decodeBlock(&fixedHuffmanDecoder, nil)
case 2:
// compressed, dynamic Huffman tables
if err = f.readHuffman(); err == nil {
err = f.decodeBlock(&f.h1, &f.h2)
}
default:
// 3 is reserved.
err = CorruptInputError(f.roffset)
} }
} }
return f.final = f.b&1 == 1
f.b >>= 1
typ := f.b & 3
f.b >>= 2
f.nb -= 1 + 2
switch typ {
case 0:
f.dataBlock()
case 1:
// compressed, fixed Huffman tables
f.hl = &fixedHuffmanDecoder
f.hd = nil
f.huffmanBlock()
case 2:
// compressed, dynamic Huffman tables
if f.err = f.readHuffman(); f.err != nil {
break
}
f.hl = &f.h1
f.hd = &f.h2
f.huffmanBlock()
default:
// 3 is reserved.
f.err = CorruptInputError(f.roffset)
}
}
func (f *decompressor) Read(b []byte) (int, os.Error) {
for {
if len(f.toRead) > 0 {
n := copy(b, f.toRead)
f.toRead = f.toRead[n:]
return n, nil
}
if f.err != nil {
return 0, f.err
}
f.step(f)
}
panic("unreachable")
}
func (f *decompressor) Close() os.Error {
if f.err == os.EOF {
return nil
}
return f.err
} }
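The decompressor now runs as a pull-driven state machine: Read drains toRead and, when it is empty, calls step to produce more output or record an error (os.EOF at the end of the stream). A toy io.Reader built the same way, with hypothetical names:

package main

import (
    "fmt"
    "io"
)

// countdown produces "3 2 1 " one chunk per step, draining toRead first,
// just as decompressor.Read does above.
type countdown struct {
    n      int
    toRead []byte
    err    error
    step   func(*countdown)
}

func (c *countdown) Read(b []byte) (int, error) {
    for {
        if len(c.toRead) > 0 {
            n := copy(b, c.toRead)
            c.toRead = c.toRead[n:]
            return n, nil
        }
        if c.err != nil {
            return 0, c.err
        }
        c.step(c)
    }
}

func produce(c *countdown) {
    if c.n == 0 {
        c.err = io.EOF
        return
    }
    c.toRead = []byte(fmt.Sprintf("%d ", c.n))
    c.n--
}

func main() {
    out, _ := io.ReadAll(&countdown{n: 3, step: produce})
    fmt.Println(string(out)) // 3 2 1
}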
// RFC 1951 section 3.2.7. // RFC 1951 section 3.2.7.
@ -358,11 +396,12 @@ func (f *decompressor) readHuffman() os.Error {
// hl and hd are the Huffman states for the lit/length values // hl and hd are the Huffman states for the lit/length values
// and the distance values, respectively. If hd == nil, using the // and the distance values, respectively. If hd == nil, using the
// fixed distance encoding associated with fixed Huffman blocks. // fixed distance encoding associated with fixed Huffman blocks.
func (f *decompressor) decodeBlock(hl, hd *huffmanDecoder) os.Error { func (f *decompressor) huffmanBlock() {
for { for {
v, err := f.huffSym(hl) v, err := f.huffSym(f.hl)
if err != nil { if err != nil {
return err f.err = err
return
} }
var n uint // number of bits extra var n uint // number of bits extra
var length int var length int
@ -371,13 +410,15 @@ func (f *decompressor) decodeBlock(hl, hd *huffmanDecoder) os.Error {
f.hist[f.hp] = byte(v) f.hist[f.hp] = byte(v)
f.hp++ f.hp++
if f.hp == len(f.hist) { if f.hp == len(f.hist) {
if err = f.flush(); err != nil { // After the flush, continue this loop.
return err f.flush((*decompressor).huffmanBlock)
} return
} }
continue continue
case v == 256: case v == 256:
return nil // Done with huffman block; read next block.
f.step = (*decompressor).nextBlock
return
// otherwise, reference to older data // otherwise, reference to older data
case v < 265: case v < 265:
length = v - (257 - 3) length = v - (257 - 3)
@ -404,7 +445,8 @@ func (f *decompressor) decodeBlock(hl, hd *huffmanDecoder) os.Error {
if n > 0 { if n > 0 {
for f.nb < n { for f.nb < n {
if err = f.moreBits(); err != nil { if err = f.moreBits(); err != nil {
return err f.err = err
return
} }
} }
length += int(f.b & uint32(1<<n-1)) length += int(f.b & uint32(1<<n-1))
@ -413,18 +455,20 @@ func (f *decompressor) decodeBlock(hl, hd *huffmanDecoder) os.Error {
} }
var dist int var dist int
if hd == nil { if f.hd == nil {
for f.nb < 5 { for f.nb < 5 {
if err = f.moreBits(); err != nil { if err = f.moreBits(); err != nil {
return err f.err = err
return
} }
} }
dist = int(reverseByte[(f.b&0x1F)<<3]) dist = int(reverseByte[(f.b&0x1F)<<3])
f.b >>= 5 f.b >>= 5
f.nb -= 5 f.nb -= 5
} else { } else {
if dist, err = f.huffSym(hd); err != nil { if dist, err = f.huffSym(f.hd); err != nil {
return err f.err = err
return
} }
} }
@ -432,14 +476,16 @@ func (f *decompressor) decodeBlock(hl, hd *huffmanDecoder) os.Error {
case dist < 4: case dist < 4:
dist++ dist++
case dist >= 30: case dist >= 30:
return CorruptInputError(f.roffset) f.err = CorruptInputError(f.roffset)
return
default: default:
nb := uint(dist-2) >> 1 nb := uint(dist-2) >> 1
// have 1 bit in bottom of dist, need nb more. // have 1 bit in bottom of dist, need nb more.
extra := (dist & 1) << nb extra := (dist & 1) << nb
for f.nb < nb { for f.nb < nb {
if err = f.moreBits(); err != nil { if err = f.moreBits(); err != nil {
return err f.err = err
return
} }
} }
extra |= int(f.b & uint32(1<<nb-1)) extra |= int(f.b & uint32(1<<nb-1))
@ -450,12 +496,14 @@ func (f *decompressor) decodeBlock(hl, hd *huffmanDecoder) os.Error {
// Copy history[-dist:-dist+length] into output. // Copy history[-dist:-dist+length] into output.
if dist > len(f.hist) { if dist > len(f.hist) {
return InternalError("bad history distance") f.err = InternalError("bad history distance")
return
} }
// No check on length; encoding can be prescient. // No check on length; encoding can be prescient.
if !f.hfull && dist > f.hp { if !f.hfull && dist > f.hp {
return CorruptInputError(f.roffset) f.err = CorruptInputError(f.roffset)
return
} }
p := f.hp - dist p := f.hp - dist
@ -467,9 +515,11 @@ func (f *decompressor) decodeBlock(hl, hd *huffmanDecoder) os.Error {
f.hp++ f.hp++
p++ p++
if f.hp == len(f.hist) { if f.hp == len(f.hist) {
if err = f.flush(); err != nil { // After flush continue copying out of history.
return err f.copyLen = length - (i + 1)
} f.copyDist = dist
f.flush((*decompressor).copyHuff)
return
} }
if p == len(f.hist) { if p == len(f.hist) {
p = 0 p = 0
@ -479,8 +529,33 @@ func (f *decompressor) decodeBlock(hl, hd *huffmanDecoder) os.Error {
panic("unreached") panic("unreached")
} }
func (f *decompressor) copyHuff() {
length := f.copyLen
dist := f.copyDist
p := f.hp - dist
if p < 0 {
p += len(f.hist)
}
for i := 0; i < length; i++ {
f.hist[f.hp] = f.hist[p]
f.hp++
p++
if f.hp == len(f.hist) {
f.copyLen = length - (i + 1)
f.flush((*decompressor).copyHuff)
return
}
if p == len(f.hist) {
p = 0
}
}
// Continue processing Huffman block.
f.huffmanBlock()
}
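huffmanBlock and copyHuff copy length bytes from dist bytes back in the history window, one byte at a time so that overlapping references (dist < length) replicate data that was just written. The core of that back-reference copy, ignoring window wrap-around:

package main

import "fmt"

// copyFromHistory appends length bytes starting dist bytes back in out.
// Copying byte by byte lets an overlapping reference repeat recent output.
func copyFromHistory(out []byte, dist, length int) []byte {
    p := len(out) - dist
    for i := 0; i < length; i++ {
        out = append(out, out[p])
        p++
    }
    return out
}

func main() {
    // dist=2, length=6 expands "ab" into "abababab".
    fmt.Println(string(copyFromHistory([]byte("ab"), 2, 6)))
}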
// Copy a single uncompressed data block from input to output. // Copy a single uncompressed data block from input to output.
func (f *decompressor) dataBlock() os.Error { func (f *decompressor) dataBlock() {
// Uncompressed. // Uncompressed.
// Discard current half-byte. // Discard current half-byte.
f.nb = 0 f.nb = 0
@ -490,21 +565,30 @@ func (f *decompressor) dataBlock() os.Error {
nr, err := io.ReadFull(f.r, f.buf[0:4]) nr, err := io.ReadFull(f.r, f.buf[0:4])
f.roffset += int64(nr) f.roffset += int64(nr)
if err != nil { if err != nil {
return &ReadError{f.roffset, err} f.err = &ReadError{f.roffset, err}
return
} }
n := int(f.buf[0]) | int(f.buf[1])<<8 n := int(f.buf[0]) | int(f.buf[1])<<8
nn := int(f.buf[2]) | int(f.buf[3])<<8 nn := int(f.buf[2]) | int(f.buf[3])<<8
if uint16(nn) != uint16(^n) { if uint16(nn) != uint16(^n) {
return CorruptInputError(f.roffset) f.err = CorruptInputError(f.roffset)
return
} }
if n == 0 { if n == 0 {
// 0-length block means sync // 0-length block means sync
return f.flush() f.flush((*decompressor).nextBlock)
return
} }
// Read len bytes into history, f.copyLen = n
// writing as history fills. f.copyData()
}
func (f *decompressor) copyData() {
// Read f.copyLen bytes into history,
// pausing for reads as history fills.
n := f.copyLen
for n > 0 { for n > 0 {
m := len(f.hist) - f.hp m := len(f.hist) - f.hp
if m > n { if m > n {
@ -513,17 +597,18 @@ func (f *decompressor) dataBlock() os.Error {
m, err := io.ReadFull(f.r, f.hist[f.hp:f.hp+m]) m, err := io.ReadFull(f.r, f.hist[f.hp:f.hp+m])
f.roffset += int64(m) f.roffset += int64(m)
if err != nil { if err != nil {
return &ReadError{f.roffset, err} f.err = &ReadError{f.roffset, err}
return
} }
n -= m n -= m
f.hp += m f.hp += m
if f.hp == len(f.hist) { if f.hp == len(f.hist) {
if err = f.flush(); err != nil { f.copyLen = n
return err f.flush((*decompressor).copyData)
} return
} }
} }
return nil f.step = (*decompressor).nextBlock
} }
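dataBlock parses the stored-block header: a little-endian LEN followed by NLEN, which must be LEN's one's complement, after which copyData copies that many bytes straight into the history window. The header check in isolation:

package main

import "fmt"

// storedHeaderOK decodes LEN/NLEN from the four header bytes of a stored
// DEFLATE block and verifies that NLEN is the one's complement of LEN.
func storedHeaderOK(buf [4]byte) (int, bool) {
    n := int(buf[0]) | int(buf[1])<<8
    nn := int(buf[2]) | int(buf[3])<<8
    return n, uint16(nn) == uint16(^n)
}

func main() {
    fmt.Println(storedHeaderOK([4]byte{0x05, 0x00, 0xfa, 0xff})) // 5 true
    fmt.Println(storedHeaderOK([4]byte{0x05, 0x00, 0x00, 0x00})) // 5 false
}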
func (f *decompressor) setDict(dict []byte) { func (f *decompressor) setDict(dict []byte) {
@ -579,17 +664,8 @@ func (f *decompressor) huffSym(h *huffmanDecoder) (int, os.Error) {
} }
// Flush any buffered output to the underlying writer. // Flush any buffered output to the underlying writer.
func (f *decompressor) flush() os.Error { func (f *decompressor) flush(step func(*decompressor)) {
if f.hw == f.hp { f.toRead = f.hist[f.hw:f.hp]
return nil
}
n, err := f.w.Write(f.hist[f.hw:f.hp])
if n != f.hp-f.hw && err == nil {
err = io.ErrShortWrite
}
if err != nil {
return &WriteError{f.woffset, err}
}
f.woffset += int64(f.hp - f.hw) f.woffset += int64(f.hp - f.hw)
f.hw = f.hp f.hw = f.hp
if f.hp == len(f.hist) { if f.hp == len(f.hist) {
@ -597,7 +673,7 @@ func (f *decompressor) flush() os.Error {
f.hw = 0 f.hw = 0
f.hfull = true f.hfull = true
} }
return nil f.step = step
} }
func makeReader(r io.Reader) Reader { func makeReader(r io.Reader) Reader {
@ -607,30 +683,15 @@ func makeReader(r io.Reader) Reader {
return bufio.NewReader(r) return bufio.NewReader(r)
} }
// decompress reads DEFLATE-compressed data from r and writes
// the uncompressed data to w.
func (f *decompressor) decompress(r io.Reader, w io.Writer) os.Error {
f.r = makeReader(r)
f.w = w
f.woffset = 0
if err := f.inflate(); err != nil {
return err
}
if err := f.flush(); err != nil {
return err
}
return nil
}
// NewReader returns a new ReadCloser that can be used // NewReader returns a new ReadCloser that can be used
// to read the uncompressed version of r. It is the caller's // to read the uncompressed version of r. It is the caller's
// responsibility to call Close on the ReadCloser when // responsibility to call Close on the ReadCloser when
// finished reading. // finished reading.
func NewReader(r io.Reader) io.ReadCloser { func NewReader(r io.Reader) io.ReadCloser {
var f decompressor var f decompressor
pr, pw := io.Pipe() f.r = makeReader(r)
go func() { pw.CloseWithError(f.decompress(r, pw)) }() f.step = (*decompressor).nextBlock
return pr return &f
} }
// NewReaderDict is like NewReader but initializes the reader // NewReaderDict is like NewReader but initializes the reader
@ -641,7 +702,7 @@ func NewReader(r io.Reader) io.ReadCloser {
func NewReaderDict(r io.Reader, dict []byte) io.ReadCloser { func NewReaderDict(r io.Reader, dict []byte) io.ReadCloser {
var f decompressor var f decompressor
f.setDict(dict) f.setDict(dict)
pr, pw := io.Pipe() f.r = makeReader(r)
go func() { pw.CloseWithError(f.decompress(r, pw)) }() f.step = (*decompressor).nextBlock
return pr return &f
} }
View File
@ -36,8 +36,8 @@ func makeReader(r io.Reader) flate.Reader {
return bufio.NewReader(r) return bufio.NewReader(r)
} }
var HeaderError os.Error = os.ErrorString("invalid gzip header") var HeaderError = os.NewError("invalid gzip header")
var ChecksumError os.Error = os.ErrorString("gzip checksum error") var ChecksumError = os.NewError("gzip checksum error")
// The gzip file stores a header giving metadata about the compressed file. // The gzip file stores a header giving metadata about the compressed file.
// That header is exposed as the fields of the Compressor and Decompressor structs. // That header is exposed as the fields of the Compressor and Decompressor structs.
View File
@ -11,7 +11,7 @@ import (
) )
// pipe creates two ends of a pipe that gzip and gunzip, and runs dfunc at the // pipe creates two ends of a pipe that gzip and gunzip, and runs dfunc at the
// writer end and ifunc at the reader end. // writer end and cfunc at the reader end.
func pipe(t *testing.T, dfunc func(*Compressor), cfunc func(*Decompressor)) { func pipe(t *testing.T, dfunc func(*Compressor), cfunc func(*Decompressor)) {
piper, pipew := io.Pipe() piper, pipew := io.Pipe()
defer piper.Close() defer piper.Close()
View File
@ -32,13 +32,49 @@ const (
MSB MSB
) )
const (
maxWidth = 12
decoderInvalidCode = 0xffff
flushBuffer = 1 << maxWidth
)
// decoder is the state from which the readXxx method converts a byte // decoder is the state from which the readXxx method converts a byte
// stream into a code stream. // stream into a code stream.
type decoder struct { type decoder struct {
r io.ByteReader r io.ByteReader
bits uint32 bits uint32
nBits uint nBits uint
width uint width uint
read func(*decoder) (uint16, os.Error) // readLSB or readMSB
litWidth int // width in bits of literal codes
err os.Error
// The first 1<<litWidth codes are literal codes.
// The next two codes mean clear and EOF.
// Other valid codes are in the range [lo, hi] where lo := clear + 2,
// with the upper bound incrementing on each code seen.
// overflow is the code at which hi overflows the code width.
// last is the most recently seen code, or decoderInvalidCode.
clear, eof, hi, overflow, last uint16
// Each code c in [lo, hi] expands to two or more bytes. For c != hi:
// suffix[c] is the last of these bytes.
// prefix[c] is the code for all but the last byte.
// This code can either be a literal code or another code in [lo, c).
// The c == hi case is a special case.
suffix [1 << maxWidth]uint8
prefix [1 << maxWidth]uint16
// output is the temporary output buffer.
// Literal codes are accumulated from the start of the buffer.
// Non-literal codes decode to a sequence of suffixes that are first
// written right-to-left from the end of the buffer before being copied
// to the start of the buffer.
// It is flushed when it contains >= 1<<maxWidth bytes,
// so that there is always room to decode an entire code.
output [2 * 1 << maxWidth]byte
o int // write index into output
toRead []byte // bytes to return from Read
} }
// readLSB returns the next code for "Least Significant Bits first" data. // readLSB returns the next code for "Least Significant Bits first" data.
@ -73,119 +109,113 @@ func (d *decoder) readMSB() (uint16, os.Error) {
return code, nil return code, nil
} }
// decode decompresses bytes from r and writes them to pw. func (d *decoder) Read(b []byte) (int, os.Error) {
// read specifies how to decode bytes into codes. for {
// litWidth is the width in bits of literal codes. if len(d.toRead) > 0 {
func decode(r io.Reader, read func(*decoder) (uint16, os.Error), litWidth int, pw *io.PipeWriter) { n := copy(b, d.toRead)
br, ok := r.(io.ByteReader) d.toRead = d.toRead[n:]
if !ok { return n, nil
br = bufio.NewReader(r) }
if d.err != nil {
return 0, d.err
}
d.decode()
} }
pw.CloseWithError(decode1(pw, br, read, uint(litWidth))) panic("unreachable")
} }
func decode1(pw *io.PipeWriter, r io.ByteReader, read func(*decoder) (uint16, os.Error), litWidth uint) os.Error { // decode decompresses bytes from r and leaves them in d.toRead.
const ( // read specifies how to decode bytes into codes.
maxWidth = 12 // litWidth is the width in bits of literal codes.
invalidCode = 0xffff func (d *decoder) decode() {
)
d := decoder{r, 0, 0, 1 + litWidth}
w := bufio.NewWriter(pw)
// The first 1<<litWidth codes are literal codes.
// The next two codes mean clear and EOF.
// Other valid codes are in the range [lo, hi] where lo := clear + 2,
// with the upper bound incrementing on each code seen.
clear := uint16(1) << litWidth
eof, hi := clear+1, clear+1
// overflow is the code at which hi overflows the code width.
overflow := uint16(1) << d.width
var (
// Each code c in [lo, hi] expands to two or more bytes. For c != hi:
// suffix[c] is the last of these bytes.
// prefix[c] is the code for all but the last byte.
// This code can either be a literal code or another code in [lo, c).
// The c == hi case is a special case.
suffix [1 << maxWidth]uint8
prefix [1 << maxWidth]uint16
// buf is a scratch buffer for reconstituting the bytes that a code expands to.
// Code suffixes are written right-to-left from the end of the buffer.
buf [1 << maxWidth]byte
)
// Loop over the code stream, converting codes into decompressed bytes. // Loop over the code stream, converting codes into decompressed bytes.
last := uint16(invalidCode)
for { for {
code, err := read(&d) code, err := d.read(d)
if err != nil { if err != nil {
if err == os.EOF { if err == os.EOF {
err = io.ErrUnexpectedEOF err = io.ErrUnexpectedEOF
} }
return err d.err = err
return
} }
switch { switch {
case code < clear: case code < d.clear:
// We have a literal code. // We have a literal code.
if err := w.WriteByte(uint8(code)); err != nil { d.output[d.o] = uint8(code)
return err d.o++
} if d.last != decoderInvalidCode {
if last != invalidCode {
// Save what the hi code expands to. // Save what the hi code expands to.
suffix[hi] = uint8(code) d.suffix[d.hi] = uint8(code)
prefix[hi] = last d.prefix[d.hi] = d.last
} }
case code == clear: case code == d.clear:
d.width = 1 + litWidth d.width = 1 + uint(d.litWidth)
hi = eof d.hi = d.eof
overflow = 1 << d.width d.overflow = 1 << d.width
last = invalidCode d.last = decoderInvalidCode
continue continue
case code == eof: case code == d.eof:
return w.Flush() d.flush()
case code <= hi: d.err = os.EOF
c, i := code, len(buf)-1 return
if code == hi { case code <= d.hi:
c, i := code, len(d.output)-1
if code == d.hi {
// code == hi is a special case which expands to the last expansion // code == hi is a special case which expands to the last expansion
// followed by the head of the last expansion. To find the head, we walk // followed by the head of the last expansion. To find the head, we walk
// the prefix chain until we find a literal code. // the prefix chain until we find a literal code.
c = last c = d.last
for c >= clear { for c >= d.clear {
c = prefix[c] c = d.prefix[c]
} }
buf[i] = uint8(c) d.output[i] = uint8(c)
i-- i--
c = last c = d.last
} }
// Copy the suffix chain into buf and then write that to w. // Copy the suffix chain into output and then write that to w.
for c >= clear { for c >= d.clear {
buf[i] = suffix[c] d.output[i] = d.suffix[c]
i-- i--
c = prefix[c] c = d.prefix[c]
} }
buf[i] = uint8(c) d.output[i] = uint8(c)
if _, err := w.Write(buf[i:]); err != nil { d.o += copy(d.output[d.o:], d.output[i:])
return err if d.last != decoderInvalidCode {
}
if last != invalidCode {
// Save what the hi code expands to. // Save what the hi code expands to.
suffix[hi] = uint8(c) d.suffix[d.hi] = uint8(c)
prefix[hi] = last d.prefix[d.hi] = d.last
} }
default: default:
return os.NewError("lzw: invalid code") d.err = os.NewError("lzw: invalid code")
return
} }
last, hi = code, hi+1 d.last, d.hi = code, d.hi+1
if hi >= overflow { if d.hi >= d.overflow {
if d.width == maxWidth { if d.width == maxWidth {
last = invalidCode d.last = decoderInvalidCode
continue } else {
d.width++
d.overflow <<= 1
} }
d.width++ }
overflow <<= 1 if d.o >= flushBuffer {
d.flush()
return
} }
} }
panic("unreachable") panic("unreachable")
} }
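A non-literal LZW code is expanded by walking the prefix chain backwards and writing suffixes right-to-left, which is why the decoder fills its output buffer from the end. A toy version of that expansion, with small maps standing in for the decoder's tables:

package main

import "fmt"

// expand walks an LZW prefix chain backwards: suffixes first, then the
// literal that heads the chain, finally reversed into reading order.
func expand(code, clear int, prefix map[int]int, suffix map[int]byte) []byte {
    var rev []byte
    c := code
    for c >= clear {
        rev = append(rev, suffix[c])
        c = prefix[c]
    }
    rev = append(rev, byte(c))
    out := make([]byte, len(rev))
    for i, b := range rev {
        out[len(rev)-1-i] = b
    }
    return out
}

func main() {
    // Code 257 = 'A'+'B', code 258 = code 257 + 'C'; the clear code is 256.
    prefix := map[int]int{257: 'A', 258: 257}
    suffix := map[int]byte{257: 'B', 258: 'C'}
    fmt.Println(string(expand(258, 256, prefix, suffix))) // ABC
}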
func (d *decoder) flush() {
d.toRead = d.output[:d.o]
d.o = 0
}
func (d *decoder) Close() os.Error {
d.err = os.EINVAL // in case any Reads come along
return nil
}
// NewReader creates a new io.ReadCloser that satisfies reads by decompressing // NewReader creates a new io.ReadCloser that satisfies reads by decompressing
// the data read from r. // the data read from r.
// It is the caller's responsibility to call Close on the ReadCloser when // It is the caller's responsibility to call Close on the ReadCloser when
@ -193,21 +223,31 @@ func decode1(pw *io.PipeWriter, r io.ByteReader, read func(*decoder) (uint16, os
// The number of bits to use for literal codes, litWidth, must be in the // The number of bits to use for literal codes, litWidth, must be in the
// range [2,8] and is typically 8. // range [2,8] and is typically 8.
func NewReader(r io.Reader, order Order, litWidth int) io.ReadCloser { func NewReader(r io.Reader, order Order, litWidth int) io.ReadCloser {
pr, pw := io.Pipe() d := new(decoder)
var read func(*decoder) (uint16, os.Error)
switch order { switch order {
case LSB: case LSB:
read = (*decoder).readLSB d.read = (*decoder).readLSB
case MSB: case MSB:
read = (*decoder).readMSB d.read = (*decoder).readMSB
default: default:
pw.CloseWithError(os.NewError("lzw: unknown order")) d.err = os.NewError("lzw: unknown order")
return pr return d
} }
if litWidth < 2 || 8 < litWidth { if litWidth < 2 || 8 < litWidth {
pw.CloseWithError(fmt.Errorf("lzw: litWidth %d out of range", litWidth)) d.err = fmt.Errorf("lzw: litWidth %d out of range", litWidth)
return pr return d
} }
go decode(r, read, litWidth, pw) if br, ok := r.(io.ByteReader); ok {
return pr d.r = br
} else {
d.r = bufio.NewReader(r)
}
d.litWidth = litWidth
d.width = 1 + uint(litWidth)
d.clear = uint16(1) << uint(litWidth)
d.eof, d.hi = d.clear+1, d.clear+1
d.overflow = uint16(1) << d.width
d.last = decoderInvalidCode
return d
} }
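NewReader now returns the decoder itself as the io.ReadCloser instead of the read end of a pipe; the package-level API is unchanged. A round-trip sketch against compress/lzw as it exists today (GIF-style LSB order, 8-bit literals):

package main

import (
    "bytes"
    "compress/lzw"
    "fmt"
    "io"
)

func main() {
    var buf bytes.Buffer

    w := lzw.NewWriter(&buf, lzw.LSB, 8)
    w.Write([]byte("TOBEORNOTTOBEORTOBEORNOT"))
    w.Close()

    r := lzw.NewReader(&buf, lzw.LSB, 8)
    defer r.Close()
    out, _ := io.ReadAll(r)
    fmt.Printf("%s\n", out)
}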
View File
@ -84,7 +84,7 @@ var lzwTests = []lzwTest{
func TestReader(t *testing.T) { func TestReader(t *testing.T) {
b := bytes.NewBuffer(nil) b := bytes.NewBuffer(nil)
for _, tt := range lzwTests { for _, tt := range lzwTests {
d := strings.Split(tt.desc, ";", -1) d := strings.Split(tt.desc, ";")
var order Order var order Order
switch d[1] { switch d[1] {
case "LSB": case "LSB":
View File
@ -77,13 +77,13 @@ func testFile(t *testing.T, fn string, order Order, litWidth int) {
t.Errorf("%s (order=%d litWidth=%d): %v", fn, order, litWidth, err1) t.Errorf("%s (order=%d litWidth=%d): %v", fn, order, litWidth, err1)
return return
} }
if len(b0) != len(b1) { if len(b1) != len(b0) {
t.Errorf("%s (order=%d litWidth=%d): length mismatch %d versus %d", fn, order, litWidth, len(b0), len(b1)) t.Errorf("%s (order=%d litWidth=%d): length mismatch %d != %d", fn, order, litWidth, len(b1), len(b0))
return return
} }
for i := 0; i < len(b0); i++ { for i := 0; i < len(b0); i++ {
if b0[i] != b1[i] { if b1[i] != b0[i] {
t.Errorf("%s (order=%d litWidth=%d): mismatch at %d, 0x%02x versus 0x%02x\n", fn, order, litWidth, i, b0[i], b1[i]) t.Errorf("%s (order=%d litWidth=%d): mismatch at %d, 0x%02x != 0x%02x\n", fn, order, litWidth, i, b1[i], b0[i])
return return
} }
View File

@ -34,9 +34,9 @@ import (
const zlibDeflate = 8 const zlibDeflate = 8
var ChecksumError os.Error = os.ErrorString("zlib checksum error") var ChecksumError = os.NewError("zlib checksum error")
var HeaderError os.Error = os.ErrorString("invalid zlib header") var HeaderError = os.NewError("invalid zlib header")
var DictionaryError os.Error = os.ErrorString("invalid zlib dictionary") var DictionaryError = os.NewError("invalid zlib dictionary")
type reader struct { type reader struct {
r flate.Reader r flate.Reader
View File
@ -89,7 +89,7 @@ func NewWriterDict(w io.Writer, level int, dict []byte) (*Writer, os.Error) {
} }
} }
z.w = w z.w = w
z.compressor = flate.NewWriter(w, level) z.compressor = flate.NewWriterDict(w, level, dict)
z.digest = adler32.New() z.digest = adler32.New()
return z, nil return z, nil
} }
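The one-line fix above makes NewWriterDict actually hand the dictionary down to the flate layer. Preset dictionaries only help when writer and reader agree on them; a sketch using today's names (current Go calls the constructor NewWriterLevelDict rather than the NewWriterDict shown here):

package main

import (
    "bytes"
    "compress/zlib"
    "fmt"
    "io"
    "log"
)

func main() {
    dict := []byte("hello world")
    var buf bytes.Buffer

    zw, err := zlib.NewWriterLevelDict(&buf, zlib.BestCompression, dict)
    if err != nil {
        log.Fatal(err)
    }
    zw.Write([]byte("hello world hello world"))
    zw.Close()

    zr, err := zlib.NewReaderDict(&buf, dict)
    if err != nil {
        log.Fatal(err)
    }
    defer zr.Close()
    out, _ := io.ReadAll(zr)
    fmt.Printf("%s\n", out)
}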
View File
@ -5,6 +5,8 @@
package zlib package zlib
import ( import (
"bytes"
"fmt"
"io" "io"
"io/ioutil" "io/ioutil"
"os" "os"
@ -16,15 +18,13 @@ var filenames = []string{
"../testdata/pi.txt", "../testdata/pi.txt",
} }
var data = []string{
"test a reasonable sized string that can be compressed",
}
// Tests that compressing and then decompressing the given file at the given compression level and dictionary // Tests that compressing and then decompressing the given file at the given compression level and dictionary
// yields equivalent bytes to the original file. // yields equivalent bytes to the original file.
func testFileLevelDict(t *testing.T, fn string, level int, d string) { func testFileLevelDict(t *testing.T, fn string, level int, d string) {
// Read dictionary, if given.
var dict []byte
if d != "" {
dict = []byte(d)
}
// Read the file, as golden output. // Read the file, as golden output.
golden, err := os.Open(fn) golden, err := os.Open(fn)
if err != nil { if err != nil {
@ -32,17 +32,25 @@ func testFileLevelDict(t *testing.T, fn string, level int, d string) {
return return
} }
defer golden.Close() defer golden.Close()
b0, err0 := ioutil.ReadAll(golden)
// Read the file again, and push it through a pipe that compresses at the write end, and decompresses at the read end. if err0 != nil {
raw, err := os.Open(fn) t.Errorf("%s (level=%d, dict=%q): %v", fn, level, d, err0)
if err != nil {
t.Errorf("%s (level=%d, dict=%q): %v", fn, level, d, err)
return return
} }
testLevelDict(t, fn, b0, level, d)
}
func testLevelDict(t *testing.T, fn string, b0 []byte, level int, d string) {
// Make dictionary, if given.
var dict []byte
if d != "" {
dict = []byte(d)
}
// Push data through a pipe that compresses at the write end, and decompresses at the read end.
piper, pipew := io.Pipe() piper, pipew := io.Pipe()
defer piper.Close() defer piper.Close()
go func() { go func() {
defer raw.Close()
defer pipew.Close() defer pipew.Close()
zlibw, err := NewWriterDict(pipew, level, dict) zlibw, err := NewWriterDict(pipew, level, dict)
if err != nil { if err != nil {
@ -50,25 +58,14 @@ func testFileLevelDict(t *testing.T, fn string, level int, d string) {
return return
} }
defer zlibw.Close() defer zlibw.Close()
var b [1024]byte _, err = zlibw.Write(b0)
for { if err == os.EPIPE {
n, err0 := raw.Read(b[0:]) // Fail, but do not report the error, as some other (presumably reported) error broke the pipe.
if err0 != nil && err0 != os.EOF { return
t.Errorf("%s (level=%d, dict=%q): %v", fn, level, d, err0) }
return if err != nil {
} t.Errorf("%s (level=%d, dict=%q): %v", fn, level, d, err)
_, err1 := zlibw.Write(b[0:n]) return
if err1 == os.EPIPE {
// Fail, but do not report the error, as some other (presumably reportable) error broke the pipe.
return
}
if err1 != nil {
t.Errorf("%s (level=%d, dict=%q): %v", fn, level, d, err1)
return
}
if err0 == os.EOF {
break
}
} }
}() }()
zlibr, err := NewReaderDict(piper, dict) zlibr, err := NewReaderDict(piper, dict)
@ -78,13 +75,8 @@ func testFileLevelDict(t *testing.T, fn string, level int, d string) {
} }
defer zlibr.Close() defer zlibr.Close()
// Compare the two. // Compare the decompressed data.
b0, err0 := ioutil.ReadAll(golden)
b1, err1 := ioutil.ReadAll(zlibr) b1, err1 := ioutil.ReadAll(zlibr)
if err0 != nil {
t.Errorf("%s (level=%d, dict=%q): %v", fn, level, d, err0)
return
}
if err1 != nil { if err1 != nil {
t.Errorf("%s (level=%d, dict=%q): %v", fn, level, d, err1) t.Errorf("%s (level=%d, dict=%q): %v", fn, level, d, err1)
return return
@ -102,6 +94,18 @@ func testFileLevelDict(t *testing.T, fn string, level int, d string) {
} }
func TestWriter(t *testing.T) { func TestWriter(t *testing.T) {
for i, s := range data {
b := []byte(s)
tag := fmt.Sprintf("#%d", i)
testLevelDict(t, tag, b, DefaultCompression, "")
testLevelDict(t, tag, b, NoCompression, "")
for level := BestSpeed; level <= BestCompression; level++ {
testLevelDict(t, tag, b, level, "")
}
}
}
func TestWriterBig(t *testing.T) {
for _, fn := range filenames { for _, fn := range filenames {
testFileLevelDict(t, fn, DefaultCompression, "") testFileLevelDict(t, fn, DefaultCompression, "")
testFileLevelDict(t, fn, NoCompression, "") testFileLevelDict(t, fn, NoCompression, "")
@ -121,3 +125,20 @@ func TestWriterDict(t *testing.T) {
} }
} }
} }
func TestWriterDictIsUsed(t *testing.T) {
var input = []byte("Lorem ipsum dolor sit amet, consectetur adipisicing elit, sed do eiusmod tempor incididunt ut labore et dolore magna aliqua.")
buf := bytes.NewBuffer(nil)
compressor, err := NewWriterDict(buf, BestCompression, input)
if err != nil {
t.Errorf("error in NewWriterDict: %s", err)
return
}
compressor.Write(input)
compressor.Close()
const expectedMaxSize = 25
output := buf.Bytes()
if len(output) > expectedMaxSize {
t.Errorf("result too large (got %d, want <= %d bytes). Is the dictionary being used?", len(output), expectedMaxSize)
}
}
View File
@ -21,8 +21,7 @@ type Interface interface {
Pop() interface{} Pop() interface{}
} }
// A heaper must be initialized before any of the heap operations // A heap must be initialized before any of the heap operations
// can be used. Init is idempotent with respect to the heap invariants // can be used. Init is idempotent with respect to the heap invariants
// can be used. Init is idempotent with respect to the heap invariants // can be used. Init is idempotent with respect to the heap invariants
// and may be called whenever the heap invariants may have been invalidated. // and may be called whenever the heap invariants may have been invalidated.
// Its complexity is O(n) where n = h.Len(). // Its complexity is O(n) where n = h.Len().
@ -35,7 +34,6 @@ func Init(h Interface) {
} }
} }
// Push pushes the element x onto the heap. The complexity is // Push pushes the element x onto the heap. The complexity is
// O(log(n)) where n = h.Len(). // O(log(n)) where n = h.Len().
// //
@ -44,7 +42,6 @@ func Push(h Interface, x interface{}) {
up(h, h.Len()-1) up(h, h.Len()-1)
} }
// Pop removes the minimum element (according to Less) from the heap // Pop removes the minimum element (according to Less) from the heap
// and returns it. The complexity is O(log(n)) where n = h.Len(). // and returns it. The complexity is O(log(n)) where n = h.Len().
// Same as Remove(h, 0). // Same as Remove(h, 0).
@ -56,7 +53,6 @@ func Pop(h Interface) interface{} {
return h.Pop() return h.Pop()
} }
// Remove removes the element at index i from the heap. // Remove removes the element at index i from the heap.
// The complexity is O(log(n)) where n = h.Len(). // The complexity is O(log(n)) where n = h.Len().
// //
@ -70,7 +66,6 @@ func Remove(h Interface, i int) interface{} {
return h.Pop() return h.Pop()
} }
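These doc comments describe the heap.Interface contract: the user supplies sort.Interface plus Push and Pop, and the package's Init, Push, Pop and Remove functions maintain the invariant. The test below does this by embedding vector.Vector; with a plain slice it looks like the following (illustrative only):

package main

import (
    "container/heap"
    "fmt"
)

// intHeap satisfies heap.Interface with a slice: sort.Interface methods plus
// Push and Pop operating on the tail.
type intHeap []int

func (h intHeap) Len() int            { return len(h) }
func (h intHeap) Less(i, j int) bool  { return h[i] < h[j] }
func (h intHeap) Swap(i, j int)       { h[i], h[j] = h[j], h[i] }
func (h *intHeap) Push(x interface{}) { *h = append(*h, x.(int)) }
func (h *intHeap) Pop() interface{} {
    old := *h
    x := old[len(old)-1]
    *h = old[:len(old)-1]
    return x
}

func main() {
    h := &intHeap{5, 1, 4}
    heap.Init(h)
    heap.Push(h, 3)
    for h.Len() > 0 {
        fmt.Print(heap.Pop(h), " ") // 1 3 4 5
    }
    fmt.Println()
}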
func up(h Interface, j int) { func up(h Interface, j int) {
for { for {
i := (j - 1) / 2 // parent i := (j - 1) / 2 // parent
@ -82,7 +77,6 @@ func up(h Interface, j int) {
} }
} }
func down(h Interface, i, n int) { func down(h Interface, i, n int) {
for { for {
j1 := 2*i + 1 j1 := 2*i + 1
View File
@ -10,17 +10,14 @@ import (
. "container/heap" . "container/heap"
) )
type myHeap struct { type myHeap struct {
// A vector.Vector implements sort.Interface except for Less, // A vector.Vector implements sort.Interface except for Less,
// and it implements Push and Pop as required for heap.Interface. // and it implements Push and Pop as required for heap.Interface.
vector.Vector vector.Vector
} }
func (h *myHeap) Less(i, j int) bool { return h.At(i).(int) < h.At(j).(int) } func (h *myHeap) Less(i, j int) bool { return h.At(i).(int) < h.At(j).(int) }
func (h *myHeap) verify(t *testing.T, i int) { func (h *myHeap) verify(t *testing.T, i int) {
n := h.Len() n := h.Len()
j1 := 2*i + 1 j1 := 2*i + 1
@ -41,7 +38,6 @@ func (h *myHeap) verify(t *testing.T, i int) {
} }
} }
func TestInit0(t *testing.T) { func TestInit0(t *testing.T) {
h := new(myHeap) h := new(myHeap)
for i := 20; i > 0; i-- { for i := 20; i > 0; i-- {
@ -59,7 +55,6 @@ func TestInit0(t *testing.T) {
} }
} }
func TestInit1(t *testing.T) { func TestInit1(t *testing.T) {
h := new(myHeap) h := new(myHeap)
for i := 20; i > 0; i-- { for i := 20; i > 0; i-- {
@ -77,7 +72,6 @@ func TestInit1(t *testing.T) {
} }
} }
func Test(t *testing.T) { func Test(t *testing.T) {
h := new(myHeap) h := new(myHeap)
h.verify(t, 0) h.verify(t, 0)
@ -105,7 +99,6 @@ func Test(t *testing.T) {
} }
} }
func TestRemove0(t *testing.T) { func TestRemove0(t *testing.T) {
h := new(myHeap) h := new(myHeap)
for i := 0; i < 10; i++ { for i := 0; i < 10; i++ {
@ -123,7 +116,6 @@ func TestRemove0(t *testing.T) {
} }
} }
func TestRemove1(t *testing.T) { func TestRemove1(t *testing.T) {
h := new(myHeap) h := new(myHeap)
for i := 0; i < 10; i++ { for i := 0; i < 10; i++ {
@ -140,7 +132,6 @@ func TestRemove1(t *testing.T) {
} }
} }
func TestRemove2(t *testing.T) { func TestRemove2(t *testing.T) {
N := 10 N := 10
View File
@ -16,14 +16,12 @@ type Ring struct {
Value interface{} // for use by client; untouched by this library Value interface{} // for use by client; untouched by this library
} }
func (r *Ring) init() *Ring { func (r *Ring) init() *Ring {
r.next = r r.next = r
r.prev = r r.prev = r
return r return r
} }
// Next returns the next ring element. r must not be empty. // Next returns the next ring element. r must not be empty.
func (r *Ring) Next() *Ring { func (r *Ring) Next() *Ring {
if r.next == nil { if r.next == nil {
@ -32,7 +30,6 @@ func (r *Ring) Next() *Ring {
return r.next return r.next
} }
// Prev returns the previous ring element. r must not be empty. // Prev returns the previous ring element. r must not be empty.
func (r *Ring) Prev() *Ring { func (r *Ring) Prev() *Ring {
if r.next == nil { if r.next == nil {
@ -41,7 +38,6 @@ func (r *Ring) Prev() *Ring {
return r.prev return r.prev
} }
// Move moves n % r.Len() elements backward (n < 0) or forward (n >= 0) // Move moves n % r.Len() elements backward (n < 0) or forward (n >= 0)
// in the ring and returns that ring element. r must not be empty. // in the ring and returns that ring element. r must not be empty.
// //
@ -62,7 +58,6 @@ func (r *Ring) Move(n int) *Ring {
return r return r
} }
// New creates a ring of n elements. // New creates a ring of n elements.
func New(n int) *Ring { func New(n int) *Ring {
if n <= 0 { if n <= 0 {
@ -79,7 +74,6 @@ func New(n int) *Ring {
return r return r
} }
// Link connects ring r with ring s such that r.Next() // Link connects ring r with ring s such that r.Next()
// becomes s and returns the original value for r.Next(). // becomes s and returns the original value for r.Next().
// r must not be empty. // r must not be empty.
@ -110,7 +104,6 @@ func (r *Ring) Link(s *Ring) *Ring {
return n return n
} }
// Unlink removes n % r.Len() elements from the ring r, starting // Unlink removes n % r.Len() elements from the ring r, starting
// at r.Next(). If n % r.Len() == 0, r remains unchanged. // at r.Next(). If n % r.Len() == 0, r remains unchanged.
// The result is the removed subring. r must not be empty. // The result is the removed subring. r must not be empty.
@ -122,7 +115,6 @@ func (r *Ring) Unlink(n int) *Ring {
return r.Link(r.Move(n + 1)) return r.Link(r.Move(n + 1))
} }
// Len computes the number of elements in ring r. // Len computes the number of elements in ring r.
// It executes in time proportional to the number of elements. // It executes in time proportional to the number of elements.
// //
@ -137,7 +129,6 @@ func (r *Ring) Len() int {
return n return n
} }
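The Ring methods documented above (Next, Prev, Move, New, Link, Unlink, Len, Do) are unchanged by these hunks apart from the dropped blank lines. A short usage sketch of the package:

package main

import (
    "container/ring"
    "fmt"
)

func main() {
    r := ring.New(4)
    for i := 1; i <= 4; i++ {
        r.Value = i
        r = r.Next()
    }

    // Unlink removes elements starting at r.Next(); here the values 2 and 3.
    removed := r.Unlink(2)
    fmt.Println(removed.Len()) // 2

    r.Do(func(v interface{}) { fmt.Print(v, " ") }) // 1 4
    fmt.Println()
}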
// Do calls function f on each element of the ring, in forward order. // Do calls function f on each element of the ring, in forward order.
// The behavior of Do is undefined if f changes *r. // The behavior of Do is undefined if f changes *r.
func (r *Ring) Do(f func(interface{})) { func (r *Ring) Do(f func(interface{})) {
View File
@ -9,7 +9,6 @@ import (
"testing" "testing"
) )
// For debugging - keep around. // For debugging - keep around.
func dump(r *Ring) { func dump(r *Ring) {
if r == nil { if r == nil {
@ -24,7 +23,6 @@ func dump(r *Ring) {
fmt.Println() fmt.Println()
} }
func verify(t *testing.T, r *Ring, N int, sum int) { func verify(t *testing.T, r *Ring, N int, sum int) {
// Len // Len
n := r.Len() n := r.Len()
@ -96,7 +94,6 @@ func verify(t *testing.T, r *Ring, N int, sum int) {
} }
} }
func TestCornerCases(t *testing.T) { func TestCornerCases(t *testing.T) {
var ( var (
r0 *Ring r0 *Ring
@ -118,7 +115,6 @@ func TestCornerCases(t *testing.T) {
verify(t, &r1, 1, 0) verify(t, &r1, 1, 0)
} }
func makeN(n int) *Ring { func makeN(n int) *Ring {
r := New(n) r := New(n)
for i := 1; i <= n; i++ { for i := 1; i <= n; i++ {
@ -130,7 +126,6 @@ func makeN(n int) *Ring {
func sumN(n int) int { return (n*n + n) / 2 } func sumN(n int) int { return (n*n + n) / 2 }
func TestNew(t *testing.T) { func TestNew(t *testing.T) {
for i := 0; i < 10; i++ { for i := 0; i < 10; i++ {
r := New(i) r := New(i)
@ -142,7 +137,6 @@ func TestNew(t *testing.T) {
} }
} }
func TestLink1(t *testing.T) { func TestLink1(t *testing.T) {
r1a := makeN(1) r1a := makeN(1)
var r1b Ring var r1b Ring
@ -163,7 +157,6 @@ func TestLink1(t *testing.T) {
verify(t, r2b, 1, 0) verify(t, r2b, 1, 0)
} }
func TestLink2(t *testing.T) { func TestLink2(t *testing.T) {
var r0 *Ring var r0 *Ring
r1a := &Ring{Value: 42} r1a := &Ring{Value: 42}
@ -183,7 +176,6 @@ func TestLink2(t *testing.T) {
verify(t, r10, 12, sumN(10)+42+77) verify(t, r10, 12, sumN(10)+42+77)
} }
func TestLink3(t *testing.T) { func TestLink3(t *testing.T) {
var r Ring var r Ring
n := 1 n := 1
@ -193,7 +185,6 @@ func TestLink3(t *testing.T) {
} }
} }
func TestUnlink(t *testing.T) { func TestUnlink(t *testing.T) {
r10 := makeN(10) r10 := makeN(10)
s10 := r10.Move(6) s10 := r10.Move(6)
@ -215,7 +206,6 @@ func TestUnlink(t *testing.T) {
verify(t, r10, 9, sum10-2) verify(t, r10, 9, sum10-2)
} }
func TestLinkUnlink(t *testing.T) { func TestLinkUnlink(t *testing.T) {
for i := 1; i < 4; i++ { for i := 1; i < 4; i++ {
ri := New(i) ri := New(i)
View File
@ -6,29 +6,24 @@
// Vectors grow and shrink dynamically as necessary. // Vectors grow and shrink dynamically as necessary.
package vector package vector
// Vector is a container for numbered sequences of elements of type interface{}. // Vector is a container for numbered sequences of elements of type interface{}.
// A vector's length and capacity adjusts automatically as necessary. // A vector's length and capacity adjusts automatically as necessary.
// The zero value for Vector is an empty vector ready to use. // The zero value for Vector is an empty vector ready to use.
type Vector []interface{} type Vector []interface{}
// IntVector is a container for numbered sequences of elements of type int. // IntVector is a container for numbered sequences of elements of type int.
// A vector's length and capacity adjusts automatically as necessary. // A vector's length and capacity adjusts automatically as necessary.
// The zero value for IntVector is an empty vector ready to use. // The zero value for IntVector is an empty vector ready to use.
type IntVector []int type IntVector []int
// StringVector is a container for numbered sequences of elements of type string. // StringVector is a container for numbered sequences of elements of type string.
// A vector's length and capacity adjusts automatically as necessary. // A vector's length and capacity adjusts automatically as necessary.
// The zero value for StringVector is an empty vector ready to use. // The zero value for StringVector is an empty vector ready to use.
type StringVector []string type StringVector []string
// Initial underlying array size // Initial underlying array size
const initialSize = 8 const initialSize = 8
// Partial sort.Interface support // Partial sort.Interface support
// LessInterface provides partial support of the sort.Interface. // LessInterface provides partial support of the sort.Interface.
@ -36,16 +31,13 @@ type LessInterface interface {
Less(y interface{}) bool Less(y interface{}) bool
} }
// Less returns a boolean denoting whether the i'th element is less than the j'th element. // Less returns a boolean denoting whether the i'th element is less than the j'th element.
func (p *Vector) Less(i, j int) bool { return (*p)[i].(LessInterface).Less((*p)[j]) } func (p *Vector) Less(i, j int) bool { return (*p)[i].(LessInterface).Less((*p)[j]) }
// sort.Interface support // sort.Interface support
// Less returns a boolean denoting whether the i'th element is less than the j'th element. // Less returns a boolean denoting whether the i'th element is less than the j'th element.
func (p *IntVector) Less(i, j int) bool { return (*p)[i] < (*p)[j] } func (p *IntVector) Less(i, j int) bool { return (*p)[i] < (*p)[j] }
// Less returns a boolean denoting whether the i'th element is less than the j'th element. // Less returns a boolean denoting whether the i'th element is less than the j'th element.
func (p *StringVector) Less(i, j int) bool { return (*p)[i] < (*p)[j] } func (p *StringVector) Less(i, j int) bool { return (*p)[i] < (*p)[j] }


@ -7,7 +7,6 @@
package vector package vector
func (p *IntVector) realloc(length, capacity int) (b []int) { func (p *IntVector) realloc(length, capacity int) (b []int) {
if capacity < initialSize { if capacity < initialSize {
capacity = initialSize capacity = initialSize
@ -21,7 +20,6 @@ func (p *IntVector) realloc(length, capacity int) (b []int) {
return return
} }
// Insert n elements at position i. // Insert n elements at position i.
func (p *IntVector) Expand(i, n int) { func (p *IntVector) Expand(i, n int) {
a := *p a := *p
@ -51,11 +49,9 @@ func (p *IntVector) Expand(i, n int) {
*p = a *p = a
} }
// Insert n elements at the end of a vector. // Insert n elements at the end of a vector.
func (p *IntVector) Extend(n int) { p.Expand(len(*p), n) } func (p *IntVector) Extend(n int) { p.Expand(len(*p), n) }
// Resize changes the length and capacity of a vector. // Resize changes the length and capacity of a vector.
// If the new length is shorter than the current length, Resize discards // If the new length is shorter than the current length, Resize discards
// trailing elements. If the new length is longer than the current length, // trailing elements. If the new length is longer than the current length,
@ -80,30 +76,24 @@ func (p *IntVector) Resize(length, capacity int) *IntVector {
return p return p
} }
// Len returns the number of elements in the vector. // Len returns the number of elements in the vector.
// Same as len(*p). // Same as len(*p).
func (p *IntVector) Len() int { return len(*p) } func (p *IntVector) Len() int { return len(*p) }
// Cap returns the capacity of the vector; that is, the // Cap returns the capacity of the vector; that is, the
// maximum length the vector can grow without resizing. // maximum length the vector can grow without resizing.
// Same as cap(*p). // Same as cap(*p).
func (p *IntVector) Cap() int { return cap(*p) } func (p *IntVector) Cap() int { return cap(*p) }
// At returns the i'th element of the vector. // At returns the i'th element of the vector.
func (p *IntVector) At(i int) int { return (*p)[i] } func (p *IntVector) At(i int) int { return (*p)[i] }
// Set sets the i'th element of the vector to value x. // Set sets the i'th element of the vector to value x.
func (p *IntVector) Set(i int, x int) { (*p)[i] = x } func (p *IntVector) Set(i int, x int) { (*p)[i] = x }
// Last returns the element in the vector of highest index. // Last returns the element in the vector of highest index.
func (p *IntVector) Last() int { return (*p)[len(*p)-1] } func (p *IntVector) Last() int { return (*p)[len(*p)-1] }
// Copy makes a copy of the vector and returns it. // Copy makes a copy of the vector and returns it.
func (p *IntVector) Copy() IntVector { func (p *IntVector) Copy() IntVector {
arr := make(IntVector, len(*p)) arr := make(IntVector, len(*p))
@ -111,7 +101,6 @@ func (p *IntVector) Copy() IntVector {
return arr return arr
} }
// Insert inserts into the vector an element of value x before // Insert inserts into the vector an element of value x before
// the current element at index i. // the current element at index i.
func (p *IntVector) Insert(i int, x int) { func (p *IntVector) Insert(i int, x int) {
@ -119,7 +108,6 @@ func (p *IntVector) Insert(i int, x int) {
(*p)[i] = x (*p)[i] = x
} }
// Delete deletes the i'th element of the vector. The gap is closed so the old // Delete deletes the i'th element of the vector. The gap is closed so the old
// element at index i+1 has index i afterwards. // element at index i+1 has index i afterwards.
func (p *IntVector) Delete(i int) { func (p *IntVector) Delete(i int) {
@ -132,7 +120,6 @@ func (p *IntVector) Delete(i int) {
*p = a[0 : n-1] *p = a[0 : n-1]
} }
// InsertVector inserts into the vector the contents of the vector // InsertVector inserts into the vector the contents of the vector
// x such that the 0th element of x appears at index i after insertion. // x such that the 0th element of x appears at index i after insertion.
func (p *IntVector) InsertVector(i int, x *IntVector) { func (p *IntVector) InsertVector(i int, x *IntVector) {
@ -142,7 +129,6 @@ func (p *IntVector) InsertVector(i int, x *IntVector) {
copy((*p)[i:i+len(b)], b) copy((*p)[i:i+len(b)], b)
} }
// Cut deletes elements i through j-1, inclusive. // Cut deletes elements i through j-1, inclusive.
func (p *IntVector) Cut(i, j int) { func (p *IntVector) Cut(i, j int) {
a := *p a := *p
@ -158,7 +144,6 @@ func (p *IntVector) Cut(i, j int) {
*p = a[0:m] *p = a[0:m]
} }
// Slice returns a new sub-vector by slicing the old one to extract slice [i:j]. // Slice returns a new sub-vector by slicing the old one to extract slice [i:j].
// The elements are copied. The original vector is unchanged. // The elements are copied. The original vector is unchanged.
func (p *IntVector) Slice(i, j int) *IntVector { func (p *IntVector) Slice(i, j int) *IntVector {
@ -168,13 +153,11 @@ func (p *IntVector) Slice(i, j int) *IntVector {
return &s return &s
} }
// Convenience wrappers // Convenience wrappers
// Push appends x to the end of the vector. // Push appends x to the end of the vector.
func (p *IntVector) Push(x int) { p.Insert(len(*p), x) } func (p *IntVector) Push(x int) { p.Insert(len(*p), x) }
// Pop deletes the last element of the vector. // Pop deletes the last element of the vector.
func (p *IntVector) Pop() int { func (p *IntVector) Pop() int {
a := *p a := *p
@ -187,18 +170,15 @@ func (p *IntVector) Pop() int {
return x return x
} }
// AppendVector appends the entire vector x to the end of this vector. // AppendVector appends the entire vector x to the end of this vector.
func (p *IntVector) AppendVector(x *IntVector) { p.InsertVector(len(*p), x) } func (p *IntVector) AppendVector(x *IntVector) { p.InsertVector(len(*p), x) }
// Swap exchanges the elements at indexes i and j. // Swap exchanges the elements at indexes i and j.
func (p *IntVector) Swap(i, j int) { func (p *IntVector) Swap(i, j int) {
a := *p a := *p
a[i], a[j] = a[j], a[i] a[i], a[j] = a[j], a[i]
} }
// Do calls function f for each element of the vector, in order. // Do calls function f for each element of the vector, in order.
// The behavior of Do is undefined if f changes *p. // The behavior of Do is undefined if f changes *p.
func (p *IntVector) Do(f func(elem int)) { func (p *IntVector) Do(f func(elem int)) {

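For orientation, here is a minimal sketch of how the IntVector API shown above might be used; the import path "container/vector", the variable names and the example values are assumptions for illustration, not part of the change.

package main

import (
	"container/vector"
	"fmt"
)

func main() {
	var v vector.IntVector // the zero value is an empty vector ready to use
	for i := 0; i < 5; i++ {
		v.Push(i * i) // append to the end
	}
	v.Insert(0, -1)       // insert before index 0
	v.Delete(v.Len() - 1) // delete the last element
	v.Do(func(x int) { fmt.Println(x) }) // visit elements in order
}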

@ -9,7 +9,6 @@ package vector
import "testing" import "testing"
func TestIntZeroLen(t *testing.T) { func TestIntZeroLen(t *testing.T) {
a := new(IntVector) a := new(IntVector)
if a.Len() != 0 { if a.Len() != 0 {
@ -27,7 +26,6 @@ func TestIntZeroLen(t *testing.T) {
} }
} }
func TestIntResize(t *testing.T) { func TestIntResize(t *testing.T) {
var a IntVector var a IntVector
checkSize(t, &a, 0, 0) checkSize(t, &a, 0, 0)
@ -40,7 +38,6 @@ func TestIntResize(t *testing.T) {
checkSize(t, a.Resize(11, 100), 11, 100) checkSize(t, a.Resize(11, 100), 11, 100)
} }
func TestIntResize2(t *testing.T) { func TestIntResize2(t *testing.T) {
var a IntVector var a IntVector
checkSize(t, &a, 0, 0) checkSize(t, &a, 0, 0)
@ -62,7 +59,6 @@ func TestIntResize2(t *testing.T) {
} }
} }
func checkIntZero(t *testing.T, a *IntVector, i int) { func checkIntZero(t *testing.T, a *IntVector, i int) {
for j := 0; j < i; j++ { for j := 0; j < i; j++ {
if a.At(j) == intzero { if a.At(j) == intzero {
@ -82,7 +78,6 @@ func checkIntZero(t *testing.T, a *IntVector, i int) {
} }
} }
func TestIntTrailingElements(t *testing.T) { func TestIntTrailingElements(t *testing.T) {
var a IntVector var a IntVector
for i := 0; i < 10; i++ { for i := 0; i < 10; i++ {
@ -95,7 +90,6 @@ func TestIntTrailingElements(t *testing.T) {
checkIntZero(t, &a, 5) checkIntZero(t, &a, 5)
} }
func TestIntAccess(t *testing.T) { func TestIntAccess(t *testing.T) {
const n = 100 const n = 100
var a IntVector var a IntVector
@ -120,7 +114,6 @@ func TestIntAccess(t *testing.T) {
} }
} }
func TestIntInsertDeleteClear(t *testing.T) { func TestIntInsertDeleteClear(t *testing.T) {
const n = 100 const n = 100
var a IntVector var a IntVector
@ -207,7 +200,6 @@ func TestIntInsertDeleteClear(t *testing.T) {
} }
} }
func verify_sliceInt(t *testing.T, x *IntVector, elt, i, j int) { func verify_sliceInt(t *testing.T, x *IntVector, elt, i, j int) {
for k := i; k < j; k++ { for k := i; k < j; k++ {
if elem2IntValue(x.At(k)) != int2IntValue(elt) { if elem2IntValue(x.At(k)) != int2IntValue(elt) {
@ -223,7 +215,6 @@ func verify_sliceInt(t *testing.T, x *IntVector, elt, i, j int) {
} }
} }
func verify_patternInt(t *testing.T, x *IntVector, a, b, c int) { func verify_patternInt(t *testing.T, x *IntVector, a, b, c int) {
n := a + b + c n := a + b + c
if x.Len() != n { if x.Len() != n {
@ -237,7 +228,6 @@ func verify_patternInt(t *testing.T, x *IntVector, a, b, c int) {
verify_sliceInt(t, x, 0, a+b, n) verify_sliceInt(t, x, 0, a+b, n)
} }
func make_vectorInt(elt, len int) *IntVector { func make_vectorInt(elt, len int) *IntVector {
x := new(IntVector).Resize(len, 0) x := new(IntVector).Resize(len, 0)
for i := 0; i < len; i++ { for i := 0; i < len; i++ {
@ -246,7 +236,6 @@ func make_vectorInt(elt, len int) *IntVector {
return x return x
} }
func TestIntInsertVector(t *testing.T) { func TestIntInsertVector(t *testing.T) {
// 1 // 1
a := make_vectorInt(0, 0) a := make_vectorInt(0, 0)
@ -270,7 +259,6 @@ func TestIntInsertVector(t *testing.T) {
verify_patternInt(t, a, 8, 1000, 2) verify_patternInt(t, a, 8, 1000, 2)
} }
func TestIntDo(t *testing.T) { func TestIntDo(t *testing.T) {
const n = 25 const n = 25
const salt = 17 const salt = 17
@ -325,7 +313,6 @@ func TestIntDo(t *testing.T) {
} }
func TestIntVectorCopy(t *testing.T) { func TestIntVectorCopy(t *testing.T) {
// verify Copy() returns a copy, not simply a slice of the original vector // verify Copy() returns a copy, not simply a slice of the original vector
const Len = 10 const Len = 10


@ -4,7 +4,6 @@
package vector package vector
import ( import (
"fmt" "fmt"
"sort" "sort"
@ -17,28 +16,23 @@ var (
strzero string strzero string
) )
func int2Value(x int) int { return x } func int2Value(x int) int { return x }
func int2IntValue(x int) int { return x } func int2IntValue(x int) int { return x }
func int2StrValue(x int) string { return string(x) } func int2StrValue(x int) string { return string(x) }
func elem2Value(x interface{}) int { return x.(int) } func elem2Value(x interface{}) int { return x.(int) }
func elem2IntValue(x int) int { return x } func elem2IntValue(x int) int { return x }
func elem2StrValue(x string) string { return x } func elem2StrValue(x string) string { return x }
func intf2Value(x interface{}) int { return x.(int) } func intf2Value(x interface{}) int { return x.(int) }
func intf2IntValue(x interface{}) int { return x.(int) } func intf2IntValue(x interface{}) int { return x.(int) }
func intf2StrValue(x interface{}) string { return x.(string) } func intf2StrValue(x interface{}) string { return x.(string) }
type VectorInterface interface { type VectorInterface interface {
Len() int Len() int
Cap() int Cap() int
} }
func checkSize(t *testing.T, v VectorInterface, len, cap int) { func checkSize(t *testing.T, v VectorInterface, len, cap int) {
if v.Len() != len { if v.Len() != len {
t.Errorf("%T expected len = %d; found %d", v, len, v.Len()) t.Errorf("%T expected len = %d; found %d", v, len, v.Len())
@ -48,10 +42,8 @@ func checkSize(t *testing.T, v VectorInterface, len, cap int) {
} }
} }
func val(i int) int { return i*991 - 1234 } func val(i int) int { return i*991 - 1234 }
func TestSorting(t *testing.T) { func TestSorting(t *testing.T) {
const n = 100 const n = 100
@ -72,5 +64,4 @@ func TestSorting(t *testing.T) {
} }
} }
func tname(x interface{}) string { return fmt.Sprintf("%T: ", x) } func tname(x interface{}) string { return fmt.Sprintf("%T: ", x) }


@ -11,10 +11,8 @@ import (
"testing" "testing"
) )
const memTestN = 1000000 const memTestN = 1000000
func s(n uint64) string { func s(n uint64) string {
str := fmt.Sprintf("%d", n) str := fmt.Sprintf("%d", n)
lens := len(str) lens := len(str)
@ -31,7 +29,6 @@ func s(n uint64) string {
return strings.Join(a, " ") return strings.Join(a, " ")
} }
func TestVectorNums(t *testing.T) { func TestVectorNums(t *testing.T) {
if testing.Short() { if testing.Short() {
return return
@ -52,7 +49,6 @@ func TestVectorNums(t *testing.T) {
t.Logf("%T.Push(%#v), n = %s: Alloc/n = %.2f\n", v, c, s(memTestN), float64(n)/memTestN) t.Logf("%T.Push(%#v), n = %s: Alloc/n = %.2f\n", v, c, s(memTestN), float64(n)/memTestN)
} }
func TestIntVectorNums(t *testing.T) { func TestIntVectorNums(t *testing.T) {
if testing.Short() { if testing.Short() {
return return
@ -73,7 +69,6 @@ func TestIntVectorNums(t *testing.T) {
t.Logf("%T.Push(%#v), n = %s: Alloc/n = %.2f\n", v, c, s(memTestN), float64(n)/memTestN) t.Logf("%T.Push(%#v), n = %s: Alloc/n = %.2f\n", v, c, s(memTestN), float64(n)/memTestN)
} }
func TestStringVectorNums(t *testing.T) { func TestStringVectorNums(t *testing.T) {
if testing.Short() { if testing.Short() {
return return
@ -94,7 +89,6 @@ func TestStringVectorNums(t *testing.T) {
t.Logf("%T.Push(%#v), n = %s: Alloc/n = %.2f\n", v, c, s(memTestN), float64(n)/memTestN) t.Logf("%T.Push(%#v), n = %s: Alloc/n = %.2f\n", v, c, s(memTestN), float64(n)/memTestN)
} }
func BenchmarkVectorNums(b *testing.B) { func BenchmarkVectorNums(b *testing.B) {
c := int(0) c := int(0)
var v Vector var v Vector
@ -106,7 +100,6 @@ func BenchmarkVectorNums(b *testing.B) {
} }
} }
func BenchmarkIntVectorNums(b *testing.B) { func BenchmarkIntVectorNums(b *testing.B) {
c := int(0) c := int(0)
var v IntVector var v IntVector
@ -118,7 +111,6 @@ func BenchmarkIntVectorNums(b *testing.B) {
} }
} }
func BenchmarkStringVectorNums(b *testing.B) { func BenchmarkStringVectorNums(b *testing.B) {
c := "" c := ""
var v StringVector var v StringVector


@ -7,7 +7,6 @@
package vector package vector
func (p *StringVector) realloc(length, capacity int) (b []string) { func (p *StringVector) realloc(length, capacity int) (b []string) {
if capacity < initialSize { if capacity < initialSize {
capacity = initialSize capacity = initialSize
@ -21,7 +20,6 @@ func (p *StringVector) realloc(length, capacity int) (b []string) {
return return
} }
// Insert n elements at position i. // Insert n elements at position i.
func (p *StringVector) Expand(i, n int) { func (p *StringVector) Expand(i, n int) {
a := *p a := *p
@ -51,11 +49,9 @@ func (p *StringVector) Expand(i, n int) {
*p = a *p = a
} }
// Insert n elements at the end of a vector. // Insert n elements at the end of a vector.
func (p *StringVector) Extend(n int) { p.Expand(len(*p), n) } func (p *StringVector) Extend(n int) { p.Expand(len(*p), n) }
// Resize changes the length and capacity of a vector. // Resize changes the length and capacity of a vector.
// If the new length is shorter than the current length, Resize discards // If the new length is shorter than the current length, Resize discards
// trailing elements. If the new length is longer than the current length, // trailing elements. If the new length is longer than the current length,
@ -80,30 +76,24 @@ func (p *StringVector) Resize(length, capacity int) *StringVector {
return p return p
} }
// Len returns the number of elements in the vector. // Len returns the number of elements in the vector.
// Same as len(*p). // Same as len(*p).
func (p *StringVector) Len() int { return len(*p) } func (p *StringVector) Len() int { return len(*p) }
// Cap returns the capacity of the vector; that is, the // Cap returns the capacity of the vector; that is, the
// maximum length the vector can grow without resizing. // maximum length the vector can grow without resizing.
// Same as cap(*p). // Same as cap(*p).
func (p *StringVector) Cap() int { return cap(*p) } func (p *StringVector) Cap() int { return cap(*p) }
// At returns the i'th element of the vector. // At returns the i'th element of the vector.
func (p *StringVector) At(i int) string { return (*p)[i] } func (p *StringVector) At(i int) string { return (*p)[i] }
// Set sets the i'th element of the vector to value x. // Set sets the i'th element of the vector to value x.
func (p *StringVector) Set(i int, x string) { (*p)[i] = x } func (p *StringVector) Set(i int, x string) { (*p)[i] = x }
// Last returns the element in the vector of highest index. // Last returns the element in the vector of highest index.
func (p *StringVector) Last() string { return (*p)[len(*p)-1] } func (p *StringVector) Last() string { return (*p)[len(*p)-1] }
// Copy makes a copy of the vector and returns it. // Copy makes a copy of the vector and returns it.
func (p *StringVector) Copy() StringVector { func (p *StringVector) Copy() StringVector {
arr := make(StringVector, len(*p)) arr := make(StringVector, len(*p))
@ -111,7 +101,6 @@ func (p *StringVector) Copy() StringVector {
return arr return arr
} }
// Insert inserts into the vector an element of value x before // Insert inserts into the vector an element of value x before
// the current element at index i. // the current element at index i.
func (p *StringVector) Insert(i int, x string) { func (p *StringVector) Insert(i int, x string) {
@ -119,7 +108,6 @@ func (p *StringVector) Insert(i int, x string) {
(*p)[i] = x (*p)[i] = x
} }
// Delete deletes the i'th element of the vector. The gap is closed so the old // Delete deletes the i'th element of the vector. The gap is closed so the old
// element at index i+1 has index i afterwards. // element at index i+1 has index i afterwards.
func (p *StringVector) Delete(i int) { func (p *StringVector) Delete(i int) {
@ -132,7 +120,6 @@ func (p *StringVector) Delete(i int) {
*p = a[0 : n-1] *p = a[0 : n-1]
} }
// InsertVector inserts into the vector the contents of the vector // InsertVector inserts into the vector the contents of the vector
// x such that the 0th element of x appears at index i after insertion. // x such that the 0th element of x appears at index i after insertion.
func (p *StringVector) InsertVector(i int, x *StringVector) { func (p *StringVector) InsertVector(i int, x *StringVector) {
@ -142,7 +129,6 @@ func (p *StringVector) InsertVector(i int, x *StringVector) {
copy((*p)[i:i+len(b)], b) copy((*p)[i:i+len(b)], b)
} }
// Cut deletes elements i through j-1, inclusive. // Cut deletes elements i through j-1, inclusive.
func (p *StringVector) Cut(i, j int) { func (p *StringVector) Cut(i, j int) {
a := *p a := *p
@ -158,7 +144,6 @@ func (p *StringVector) Cut(i, j int) {
*p = a[0:m] *p = a[0:m]
} }
// Slice returns a new sub-vector by slicing the old one to extract slice [i:j]. // Slice returns a new sub-vector by slicing the old one to extract slice [i:j].
// The elements are copied. The original vector is unchanged. // The elements are copied. The original vector is unchanged.
func (p *StringVector) Slice(i, j int) *StringVector { func (p *StringVector) Slice(i, j int) *StringVector {
@ -168,13 +153,11 @@ func (p *StringVector) Slice(i, j int) *StringVector {
return &s return &s
} }
// Convenience wrappers // Convenience wrappers
// Push appends x to the end of the vector. // Push appends x to the end of the vector.
func (p *StringVector) Push(x string) { p.Insert(len(*p), x) } func (p *StringVector) Push(x string) { p.Insert(len(*p), x) }
// Pop deletes the last element of the vector. // Pop deletes the last element of the vector.
func (p *StringVector) Pop() string { func (p *StringVector) Pop() string {
a := *p a := *p
@ -187,18 +170,15 @@ func (p *StringVector) Pop() string {
return x return x
} }
// AppendVector appends the entire vector x to the end of this vector. // AppendVector appends the entire vector x to the end of this vector.
func (p *StringVector) AppendVector(x *StringVector) { p.InsertVector(len(*p), x) } func (p *StringVector) AppendVector(x *StringVector) { p.InsertVector(len(*p), x) }
// Swap exchanges the elements at indexes i and j. // Swap exchanges the elements at indexes i and j.
func (p *StringVector) Swap(i, j int) { func (p *StringVector) Swap(i, j int) {
a := *p a := *p
a[i], a[j] = a[j], a[i] a[i], a[j] = a[j], a[i]
} }
// Do calls function f for each element of the vector, in order. // Do calls function f for each element of the vector, in order.
// The behavior of Do is undefined if f changes *p. // The behavior of Do is undefined if f changes *p.
func (p *StringVector) Do(f func(elem string)) { func (p *StringVector) Do(f func(elem string)) {


@ -9,7 +9,6 @@ package vector
import "testing" import "testing"
func TestStrZeroLen(t *testing.T) { func TestStrZeroLen(t *testing.T) {
a := new(StringVector) a := new(StringVector)
if a.Len() != 0 { if a.Len() != 0 {
@ -27,7 +26,6 @@ func TestStrZeroLen(t *testing.T) {
} }
} }
func TestStrResize(t *testing.T) { func TestStrResize(t *testing.T) {
var a StringVector var a StringVector
checkSize(t, &a, 0, 0) checkSize(t, &a, 0, 0)
@ -40,7 +38,6 @@ func TestStrResize(t *testing.T) {
checkSize(t, a.Resize(11, 100), 11, 100) checkSize(t, a.Resize(11, 100), 11, 100)
} }
func TestStrResize2(t *testing.T) { func TestStrResize2(t *testing.T) {
var a StringVector var a StringVector
checkSize(t, &a, 0, 0) checkSize(t, &a, 0, 0)
@ -62,7 +59,6 @@ func TestStrResize2(t *testing.T) {
} }
} }
func checkStrZero(t *testing.T, a *StringVector, i int) { func checkStrZero(t *testing.T, a *StringVector, i int) {
for j := 0; j < i; j++ { for j := 0; j < i; j++ {
if a.At(j) == strzero { if a.At(j) == strzero {
@ -82,7 +78,6 @@ func checkStrZero(t *testing.T, a *StringVector, i int) {
} }
} }
func TestStrTrailingElements(t *testing.T) { func TestStrTrailingElements(t *testing.T) {
var a StringVector var a StringVector
for i := 0; i < 10; i++ { for i := 0; i < 10; i++ {
@ -95,7 +90,6 @@ func TestStrTrailingElements(t *testing.T) {
checkStrZero(t, &a, 5) checkStrZero(t, &a, 5)
} }
func TestStrAccess(t *testing.T) { func TestStrAccess(t *testing.T) {
const n = 100 const n = 100
var a StringVector var a StringVector
@ -120,7 +114,6 @@ func TestStrAccess(t *testing.T) {
} }
} }
func TestStrInsertDeleteClear(t *testing.T) { func TestStrInsertDeleteClear(t *testing.T) {
const n = 100 const n = 100
var a StringVector var a StringVector
@ -207,7 +200,6 @@ func TestStrInsertDeleteClear(t *testing.T) {
} }
} }
func verify_sliceStr(t *testing.T, x *StringVector, elt, i, j int) { func verify_sliceStr(t *testing.T, x *StringVector, elt, i, j int) {
for k := i; k < j; k++ { for k := i; k < j; k++ {
if elem2StrValue(x.At(k)) != int2StrValue(elt) { if elem2StrValue(x.At(k)) != int2StrValue(elt) {
@ -223,7 +215,6 @@ func verify_sliceStr(t *testing.T, x *StringVector, elt, i, j int) {
} }
} }
func verify_patternStr(t *testing.T, x *StringVector, a, b, c int) { func verify_patternStr(t *testing.T, x *StringVector, a, b, c int) {
n := a + b + c n := a + b + c
if x.Len() != n { if x.Len() != n {
@ -237,7 +228,6 @@ func verify_patternStr(t *testing.T, x *StringVector, a, b, c int) {
verify_sliceStr(t, x, 0, a+b, n) verify_sliceStr(t, x, 0, a+b, n)
} }
func make_vectorStr(elt, len int) *StringVector { func make_vectorStr(elt, len int) *StringVector {
x := new(StringVector).Resize(len, 0) x := new(StringVector).Resize(len, 0)
for i := 0; i < len; i++ { for i := 0; i < len; i++ {
@ -246,7 +236,6 @@ func make_vectorStr(elt, len int) *StringVector {
return x return x
} }
func TestStrInsertVector(t *testing.T) { func TestStrInsertVector(t *testing.T) {
// 1 // 1
a := make_vectorStr(0, 0) a := make_vectorStr(0, 0)
@ -270,7 +259,6 @@ func TestStrInsertVector(t *testing.T) {
verify_patternStr(t, a, 8, 1000, 2) verify_patternStr(t, a, 8, 1000, 2)
} }
func TestStrDo(t *testing.T) { func TestStrDo(t *testing.T) {
const n = 25 const n = 25
const salt = 17 const salt = 17
@ -325,7 +313,6 @@ func TestStrDo(t *testing.T) {
} }
func TestStrVectorCopy(t *testing.T) { func TestStrVectorCopy(t *testing.T) {
// verify Copy() returns a copy, not simply a slice of the original vector // verify Copy() returns a copy, not simply a slice of the original vector
const Len = 10 const Len = 10


@ -7,7 +7,6 @@
package vector package vector
func (p *Vector) realloc(length, capacity int) (b []interface{}) { func (p *Vector) realloc(length, capacity int) (b []interface{}) {
if capacity < initialSize { if capacity < initialSize {
capacity = initialSize capacity = initialSize
@ -21,7 +20,6 @@ func (p *Vector) realloc(length, capacity int) (b []interface{}) {
return return
} }
// Insert n elements at position i. // Insert n elements at position i.
func (p *Vector) Expand(i, n int) { func (p *Vector) Expand(i, n int) {
a := *p a := *p
@ -51,11 +49,9 @@ func (p *Vector) Expand(i, n int) {
*p = a *p = a
} }
// Insert n elements at the end of a vector. // Insert n elements at the end of a vector.
func (p *Vector) Extend(n int) { p.Expand(len(*p), n) } func (p *Vector) Extend(n int) { p.Expand(len(*p), n) }
// Resize changes the length and capacity of a vector. // Resize changes the length and capacity of a vector.
// If the new length is shorter than the current length, Resize discards // If the new length is shorter than the current length, Resize discards
// trailing elements. If the new length is longer than the current length, // trailing elements. If the new length is longer than the current length,
@ -80,30 +76,24 @@ func (p *Vector) Resize(length, capacity int) *Vector {
return p return p
} }
// Len returns the number of elements in the vector. // Len returns the number of elements in the vector.
// Same as len(*p). // Same as len(*p).
func (p *Vector) Len() int { return len(*p) } func (p *Vector) Len() int { return len(*p) }
// Cap returns the capacity of the vector; that is, the // Cap returns the capacity of the vector; that is, the
// maximum length the vector can grow without resizing. // maximum length the vector can grow without resizing.
// Same as cap(*p). // Same as cap(*p).
func (p *Vector) Cap() int { return cap(*p) } func (p *Vector) Cap() int { return cap(*p) }
// At returns the i'th element of the vector. // At returns the i'th element of the vector.
func (p *Vector) At(i int) interface{} { return (*p)[i] } func (p *Vector) At(i int) interface{} { return (*p)[i] }
// Set sets the i'th element of the vector to value x. // Set sets the i'th element of the vector to value x.
func (p *Vector) Set(i int, x interface{}) { (*p)[i] = x } func (p *Vector) Set(i int, x interface{}) { (*p)[i] = x }
// Last returns the element in the vector of highest index. // Last returns the element in the vector of highest index.
func (p *Vector) Last() interface{} { return (*p)[len(*p)-1] } func (p *Vector) Last() interface{} { return (*p)[len(*p)-1] }
// Copy makes a copy of the vector and returns it. // Copy makes a copy of the vector and returns it.
func (p *Vector) Copy() Vector { func (p *Vector) Copy() Vector {
arr := make(Vector, len(*p)) arr := make(Vector, len(*p))
@ -111,7 +101,6 @@ func (p *Vector) Copy() Vector {
return arr return arr
} }
// Insert inserts into the vector an element of value x before // Insert inserts into the vector an element of value x before
// the current element at index i. // the current element at index i.
func (p *Vector) Insert(i int, x interface{}) { func (p *Vector) Insert(i int, x interface{}) {
@ -119,7 +108,6 @@ func (p *Vector) Insert(i int, x interface{}) {
(*p)[i] = x (*p)[i] = x
} }
// Delete deletes the i'th element of the vector. The gap is closed so the old // Delete deletes the i'th element of the vector. The gap is closed so the old
// element at index i+1 has index i afterwards. // element at index i+1 has index i afterwards.
func (p *Vector) Delete(i int) { func (p *Vector) Delete(i int) {
@ -132,7 +120,6 @@ func (p *Vector) Delete(i int) {
*p = a[0 : n-1] *p = a[0 : n-1]
} }
// InsertVector inserts into the vector the contents of the vector // InsertVector inserts into the vector the contents of the vector
// x such that the 0th element of x appears at index i after insertion. // x such that the 0th element of x appears at index i after insertion.
func (p *Vector) InsertVector(i int, x *Vector) { func (p *Vector) InsertVector(i int, x *Vector) {
@ -142,7 +129,6 @@ func (p *Vector) InsertVector(i int, x *Vector) {
copy((*p)[i:i+len(b)], b) copy((*p)[i:i+len(b)], b)
} }
// Cut deletes elements i through j-1, inclusive. // Cut deletes elements i through j-1, inclusive.
func (p *Vector) Cut(i, j int) { func (p *Vector) Cut(i, j int) {
a := *p a := *p
@ -158,7 +144,6 @@ func (p *Vector) Cut(i, j int) {
*p = a[0:m] *p = a[0:m]
} }
// Slice returns a new sub-vector by slicing the old one to extract slice [i:j]. // Slice returns a new sub-vector by slicing the old one to extract slice [i:j].
// The elements are copied. The original vector is unchanged. // The elements are copied. The original vector is unchanged.
func (p *Vector) Slice(i, j int) *Vector { func (p *Vector) Slice(i, j int) *Vector {
@ -168,13 +153,11 @@ func (p *Vector) Slice(i, j int) *Vector {
return &s return &s
} }
// Convenience wrappers // Convenience wrappers
// Push appends x to the end of the vector. // Push appends x to the end of the vector.
func (p *Vector) Push(x interface{}) { p.Insert(len(*p), x) } func (p *Vector) Push(x interface{}) { p.Insert(len(*p), x) }
// Pop deletes the last element of the vector. // Pop deletes the last element of the vector.
func (p *Vector) Pop() interface{} { func (p *Vector) Pop() interface{} {
a := *p a := *p
@ -187,18 +170,15 @@ func (p *Vector) Pop() interface{} {
return x return x
} }
// AppendVector appends the entire vector x to the end of this vector. // AppendVector appends the entire vector x to the end of this vector.
func (p *Vector) AppendVector(x *Vector) { p.InsertVector(len(*p), x) } func (p *Vector) AppendVector(x *Vector) { p.InsertVector(len(*p), x) }
// Swap exchanges the elements at indexes i and j. // Swap exchanges the elements at indexes i and j.
func (p *Vector) Swap(i, j int) { func (p *Vector) Swap(i, j int) {
a := *p a := *p
a[i], a[j] = a[j], a[i] a[i], a[j] = a[j], a[i]
} }
// Do calls function f for each element of the vector, in order. // Do calls function f for each element of the vector, in order.
// The behavior of Do is undefined if f changes *p. // The behavior of Do is undefined if f changes *p.
func (p *Vector) Do(f func(elem interface{})) { func (p *Vector) Do(f func(elem interface{})) {


@ -9,7 +9,6 @@ package vector
import "testing" import "testing"
func TestZeroLen(t *testing.T) { func TestZeroLen(t *testing.T) {
a := new(Vector) a := new(Vector)
if a.Len() != 0 { if a.Len() != 0 {
@ -27,7 +26,6 @@ func TestZeroLen(t *testing.T) {
} }
} }
func TestResize(t *testing.T) { func TestResize(t *testing.T) {
var a Vector var a Vector
checkSize(t, &a, 0, 0) checkSize(t, &a, 0, 0)
@ -40,7 +38,6 @@ func TestResize(t *testing.T) {
checkSize(t, a.Resize(11, 100), 11, 100) checkSize(t, a.Resize(11, 100), 11, 100)
} }
func TestResize2(t *testing.T) { func TestResize2(t *testing.T) {
var a Vector var a Vector
checkSize(t, &a, 0, 0) checkSize(t, &a, 0, 0)
@ -62,7 +59,6 @@ func TestResize2(t *testing.T) {
} }
} }
func checkZero(t *testing.T, a *Vector, i int) { func checkZero(t *testing.T, a *Vector, i int) {
for j := 0; j < i; j++ { for j := 0; j < i; j++ {
if a.At(j) == zero { if a.At(j) == zero {
@ -82,7 +78,6 @@ func checkZero(t *testing.T, a *Vector, i int) {
} }
} }
func TestTrailingElements(t *testing.T) { func TestTrailingElements(t *testing.T) {
var a Vector var a Vector
for i := 0; i < 10; i++ { for i := 0; i < 10; i++ {
@ -95,7 +90,6 @@ func TestTrailingElements(t *testing.T) {
checkZero(t, &a, 5) checkZero(t, &a, 5)
} }
func TestAccess(t *testing.T) { func TestAccess(t *testing.T) {
const n = 100 const n = 100
var a Vector var a Vector
@ -120,7 +114,6 @@ func TestAccess(t *testing.T) {
} }
} }
func TestInsertDeleteClear(t *testing.T) { func TestInsertDeleteClear(t *testing.T) {
const n = 100 const n = 100
var a Vector var a Vector
@ -207,7 +200,6 @@ func TestInsertDeleteClear(t *testing.T) {
} }
} }
func verify_slice(t *testing.T, x *Vector, elt, i, j int) { func verify_slice(t *testing.T, x *Vector, elt, i, j int) {
for k := i; k < j; k++ { for k := i; k < j; k++ {
if elem2Value(x.At(k)) != int2Value(elt) { if elem2Value(x.At(k)) != int2Value(elt) {
@ -223,7 +215,6 @@ func verify_slice(t *testing.T, x *Vector, elt, i, j int) {
} }
} }
func verify_pattern(t *testing.T, x *Vector, a, b, c int) { func verify_pattern(t *testing.T, x *Vector, a, b, c int) {
n := a + b + c n := a + b + c
if x.Len() != n { if x.Len() != n {
@ -237,7 +228,6 @@ func verify_pattern(t *testing.T, x *Vector, a, b, c int) {
verify_slice(t, x, 0, a+b, n) verify_slice(t, x, 0, a+b, n)
} }
func make_vector(elt, len int) *Vector { func make_vector(elt, len int) *Vector {
x := new(Vector).Resize(len, 0) x := new(Vector).Resize(len, 0)
for i := 0; i < len; i++ { for i := 0; i < len; i++ {
@ -246,7 +236,6 @@ func make_vector(elt, len int) *Vector {
return x return x
} }
func TestInsertVector(t *testing.T) { func TestInsertVector(t *testing.T) {
// 1 // 1
a := make_vector(0, 0) a := make_vector(0, 0)
@ -270,7 +259,6 @@ func TestInsertVector(t *testing.T) {
verify_pattern(t, a, 8, 1000, 2) verify_pattern(t, a, 8, 1000, 2)
} }
func TestDo(t *testing.T) { func TestDo(t *testing.T) {
const n = 25 const n = 25
const salt = 17 const salt = 17
@ -325,7 +313,6 @@ func TestDo(t *testing.T) {
} }
func TestVectorCopy(t *testing.T) { func TestVectorCopy(t *testing.T) {
// verify Copy() returns a copy, not simply a slice of the original vector // verify Copy() returns a copy, not simply a slice of the original vector
const Len = 10 const Len = 10


@ -45,14 +45,14 @@ func NewCipher(key []byte) (*Cipher, os.Error) {
// BlockSize returns the AES block size, 16 bytes. // BlockSize returns the AES block size, 16 bytes.
// It is necessary to satisfy the Cipher interface in the // It is necessary to satisfy the Cipher interface in the
// package "crypto/block". // package "crypto/cipher".
func (c *Cipher) BlockSize() int { return BlockSize } func (c *Cipher) BlockSize() int { return BlockSize }
// Encrypt encrypts the 16-byte buffer src using the key k // Encrypt encrypts the 16-byte buffer src using the key k
// and stores the result in dst. // and stores the result in dst.
// Note that for amounts of data larger than a block, // Note that for amounts of data larger than a block,
// it is not safe to just call Encrypt on successive blocks; // it is not safe to just call Encrypt on successive blocks;
// instead, use an encryption mode like CBC (see crypto/block/cbc.go). // instead, use an encryption mode like CBC (see crypto/cipher/cbc.go).
func (c *Cipher) Encrypt(dst, src []byte) { encryptBlock(c.enc, dst, src) } func (c *Cipher) Encrypt(dst, src []byte) { encryptBlock(c.enc, dst, src) }
// Decrypt decrypts the 16-byte buffer src using the key k // Decrypt decrypts the 16-byte buffer src using the key k

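As the comments above note, data longer than one block should go through a mode from crypto/cipher rather than repeated Encrypt calls. A minimal sketch, assuming the cipher.NewCBCEncrypter API of this release; the function name and error handling are illustrative:

package cbcsketch

import (
	"crypto/aes"
	"crypto/cipher"
	"os"
)

// encryptCBC encrypts plaintext, whose length must already be a multiple of
// the AES block size, under key with the given one-block IV.
func encryptCBC(key, iv, plaintext []byte) ([]byte, os.Error) {
	c, err := aes.NewCipher(key) // key must be 16, 24 or 32 bytes
	if err != nil {
		return nil, err
	}
	mode := cipher.NewCBCEncrypter(c, iv)
	dst := make([]byte, len(plaintext))
	mode.CryptBlocks(dst, plaintext)
	return dst, nil
}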

@ -42,14 +42,14 @@ func NewCipher(key []byte) (*Cipher, os.Error) {
// BlockSize returns the Blowfish block size, 8 bytes. // BlockSize returns the Blowfish block size, 8 bytes.
// It is necessary to satisfy the Cipher interface in the // It is necessary to satisfy the Cipher interface in the
// package "crypto/block". // package "crypto/cipher".
func (c *Cipher) BlockSize() int { return BlockSize } func (c *Cipher) BlockSize() int { return BlockSize }
// Encrypt encrypts the 8-byte buffer src using the key k // Encrypt encrypts the 8-byte buffer src using the key k
// and stores the result in dst. // and stores the result in dst.
// Note that for amounts of data larger than a block, // Note that for amounts of data larger than a block,
// it is not safe to just call Encrypt on successive blocks; // it is not safe to just call Encrypt on successive blocks;
// instead, use an encryption mode like CBC (see crypto/block/cbc.go). // instead, use an encryption mode like CBC (see crypto/cipher/cbc.go).
func (c *Cipher) Encrypt(dst, src []byte) { func (c *Cipher) Encrypt(dst, src []byte) {
l := uint32(src[0])<<24 | uint32(src[1])<<16 | uint32(src[2])<<8 | uint32(src[3]) l := uint32(src[0])<<24 | uint32(src[1])<<16 | uint32(src[2])<<8 | uint32(src[3])
r := uint32(src[4])<<24 | uint32(src[5])<<16 | uint32(src[6])<<8 | uint32(src[7]) r := uint32(src[4])<<24 | uint32(src[5])<<16 | uint32(src[6])<<8 | uint32(src[7])


@ -20,7 +20,7 @@ type Cipher struct {
func NewCipher(key []byte) (c *Cipher, err os.Error) { func NewCipher(key []byte) (c *Cipher, err os.Error) {
if len(key) != KeySize { if len(key) != KeySize {
return nil, os.ErrorString("CAST5: keys must be 16 bytes") return nil, os.NewError("CAST5: keys must be 16 bytes")
} }
c = new(Cipher) c = new(Cipher)


@ -80,9 +80,10 @@ type ocfbDecrypter struct {
// NewOCFBDecrypter returns a Stream which decrypts data with OpenPGP's cipher // NewOCFBDecrypter returns a Stream which decrypts data with OpenPGP's cipher
// feedback mode using the given Block. Prefix must be the first blockSize + 2 // feedback mode using the given Block. Prefix must be the first blockSize + 2
// bytes of the ciphertext, where blockSize is the Block's block size. If an // bytes of the ciphertext, where blockSize is the Block's block size. If an
// incorrect key is detected then nil is returned. Resync determines if the // incorrect key is detected then nil is returned. On successful exit,
// "resynchronization step" from RFC 4880, 13.9 step 7 is performed. Different // blockSize+2 bytes of decrypted data are written into prefix. Resync
// parts of OpenPGP vary on this point. // determines if the "resynchronization step" from RFC 4880, 13.9 step 7 is
// performed. Different parts of OpenPGP vary on this point.
func NewOCFBDecrypter(block Block, prefix []byte, resync OCFBResyncOption) Stream { func NewOCFBDecrypter(block Block, prefix []byte, resync OCFBResyncOption) Stream {
blockSize := block.BlockSize() blockSize := block.BlockSize()
if len(prefix) != blockSize+2 { if len(prefix) != blockSize+2 {
@ -118,6 +119,7 @@ func NewOCFBDecrypter(block Block, prefix []byte, resync OCFBResyncOption) Strea
x.fre[1] = prefix[blockSize+1] x.fre[1] = prefix[blockSize+1]
x.outUsed = 2 x.outUsed = 2
} }
copy(prefix, prefixCopy)
return x return x
} }


@ -79,7 +79,7 @@ func GenerateParameters(params *Parameters, rand io.Reader, sizes ParameterSizes
L = 3072 L = 3072
N = 256 N = 256
default: default:
return os.ErrorString("crypto/dsa: invalid ParameterSizes") return os.NewError("crypto/dsa: invalid ParameterSizes")
} }
qBytes := make([]byte, N/8) qBytes := make([]byte, N/8)
@ -158,7 +158,7 @@ GeneratePrimes:
// PrivateKey must already be valid (see GenerateParameters). // PrivateKey must already be valid (see GenerateParameters).
func GenerateKey(priv *PrivateKey, rand io.Reader) os.Error { func GenerateKey(priv *PrivateKey, rand io.Reader) os.Error {
if priv.P == nil || priv.Q == nil || priv.G == nil { if priv.P == nil || priv.Q == nil || priv.G == nil {
return os.ErrorString("crypto/dsa: parameters not set up before generating key") return os.NewError("crypto/dsa: parameters not set up before generating key")
} }
x := new(big.Int) x := new(big.Int)

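A minimal sketch of the parameter-then-key order that these checks enforce; dsa.L1024N160 is assumed to be one of the ParameterSizes constants, and the helper function is illustrative:

package dsasketch

import (
	"crypto/dsa"
	"crypto/rand"
	"os"
)

func newDSAKey() (*dsa.PrivateKey, os.Error) {
	priv := new(dsa.PrivateKey)
	// Parameters must be generated (or otherwise filled in) before GenerateKey.
	if err := dsa.GenerateParameters(&priv.Parameters, rand.Reader, dsa.L1024N160); err != nil {
		return nil, err
	}
	if err := dsa.GenerateKey(priv, rand.Reader); err != nil {
		return nil, err
	}
	return priv, nil
}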

@ -284,7 +284,7 @@ func (curve *Curve) Marshal(x, y *big.Int) []byte {
return ret return ret
} }
// Unmarshal converts a point, serialised by Marshal, into an x, y pair. On // Unmarshal converts a point, serialized by Marshal, into an x, y pair. On
// error, x = nil. // error, x = nil.
func (curve *Curve) Unmarshal(data []byte) (x, y *big.Int) { func (curve *Curve) Unmarshal(data []byte) (x, y *big.Int) {
byteLen := (curve.BitSize + 7) >> 3 byteLen := (curve.BitSize + 7) >> 3


@ -321,8 +321,8 @@ func TestMarshal(t *testing.T) {
t.Error(err) t.Error(err)
return return
} }
serialised := p224.Marshal(x, y) serialized := p224.Marshal(x, y)
xx, yy := p224.Unmarshal(serialised) xx, yy := p224.Unmarshal(serialized)
if xx == nil { if xx == nil {
t.Error("failed to unmarshal") t.Error("failed to unmarshal")
return return


@ -190,7 +190,7 @@ func TestHMAC(t *testing.T) {
continue continue
} }
// Repetive Sum() calls should return the same value // Repetitive Sum() calls should return the same value
for k := 0; k < 2; k++ { for k := 0; k < 2; k++ {
sum := fmt.Sprintf("%x", h.Sum()) sum := fmt.Sprintf("%x", h.Sum())
if sum != tt.out { if sum != tt.out {


@ -13,6 +13,7 @@ import (
"crypto/rsa" "crypto/rsa"
_ "crypto/sha1" _ "crypto/sha1"
"crypto/x509" "crypto/x509"
"crypto/x509/pkix"
"os" "os"
"time" "time"
) )
@ -32,21 +33,8 @@ const (
ocspUnauthorized = 5 ocspUnauthorized = 5
) )
type rdnSequence []relativeDistinguishedNameSET
type relativeDistinguishedNameSET []attributeTypeAndValue
type attributeTypeAndValue struct {
Type asn1.ObjectIdentifier
Value interface{}
}
type algorithmIdentifier struct {
Algorithm asn1.ObjectIdentifier
}
type certID struct { type certID struct {
HashAlgorithm algorithmIdentifier HashAlgorithm pkix.AlgorithmIdentifier
NameHash []byte NameHash []byte
IssuerKeyHash []byte IssuerKeyHash []byte
SerialNumber asn1.RawValue SerialNumber asn1.RawValue
@ -54,7 +42,7 @@ type certID struct {
type responseASN1 struct { type responseASN1 struct {
Status asn1.Enumerated Status asn1.Enumerated
Response responseBytes "explicit,tag:0" Response responseBytes `asn1:"explicit,tag:0"`
} }
type responseBytes struct { type responseBytes struct {
@ -64,32 +52,32 @@ type responseBytes struct {
type basicResponse struct { type basicResponse struct {
TBSResponseData responseData TBSResponseData responseData
SignatureAlgorithm algorithmIdentifier SignatureAlgorithm pkix.AlgorithmIdentifier
Signature asn1.BitString Signature asn1.BitString
Certificates []asn1.RawValue "explicit,tag:0,optional" Certificates []asn1.RawValue `asn1:"explicit,tag:0,optional"`
} }
type responseData struct { type responseData struct {
Raw asn1.RawContent Raw asn1.RawContent
Version int "optional,default:1,explicit,tag:0" Version int `asn1:"optional,default:1,explicit,tag:0"`
RequestorName rdnSequence "optional,explicit,tag:1" RequestorName pkix.RDNSequence `asn1:"optional,explicit,tag:1"`
KeyHash []byte "optional,explicit,tag:2" KeyHash []byte `asn1:"optional,explicit,tag:2"`
ProducedAt *time.Time ProducedAt *time.Time
Responses []singleResponse Responses []singleResponse
} }
type singleResponse struct { type singleResponse struct {
CertID certID CertID certID
Good asn1.Flag "explicit,tag:0,optional" Good asn1.Flag `asn1:"explicit,tag:0,optional"`
Revoked revokedInfo "explicit,tag:1,optional" Revoked revokedInfo `asn1:"explicit,tag:1,optional"`
Unknown asn1.Flag "explicit,tag:2,optional" Unknown asn1.Flag `asn1:"explicit,tag:2,optional"`
ThisUpdate *time.Time ThisUpdate *time.Time
NextUpdate *time.Time "explicit,tag:0,optional" NextUpdate *time.Time `asn1:"explicit,tag:0,optional"`
} }
type revokedInfo struct { type revokedInfo struct {
RevocationTime *time.Time RevocationTime *time.Time
Reason int "explicit,tag:0,optional" Reason int `asn1:"explicit,tag:0,optional"`
} }
// This is the exposed reflection of the internal OCSP structures. // This is the exposed reflection of the internal OCSP structures.


@ -153,7 +153,7 @@ func (r *openpgpReader) Read(p []byte) (n int, err os.Error) {
// Decode reads a PGP armored block from the given Reader. It will ignore // Decode reads a PGP armored block from the given Reader. It will ignore
// leading garbage. If it doesn't find a block, it will return nil, os.EOF. The // leading garbage. If it doesn't find a block, it will return nil, os.EOF. The
// given Reader is not usable after calling this function: an arbitary amount // given Reader is not usable after calling this function: an arbitrary amount
// of data may have been read past the end of the block. // of data may have been read past the end of the block.
func Decode(in io.Reader) (p *Block, err os.Error) { func Decode(in io.Reader) (p *Block, err os.Error) {
r, _ := bufio.NewReaderSize(in, 100) r, _ := bufio.NewReaderSize(in, 100)


@ -30,7 +30,6 @@ func (r recordingHash) Size() int {
panic("shouldn't be called") panic("shouldn't be called")
} }
func testCanonicalText(t *testing.T, input, expected string) { func testCanonicalText(t *testing.T, input, expected string) {
r := recordingHash{bytes.NewBuffer(nil)} r := recordingHash{bytes.NewBuffer(nil)}
c := NewCanonicalTextHash(r) c := NewCanonicalTextHash(r)


@ -0,0 +1,122 @@
// Copyright 2011 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
// Package elgamal implements ElGamal encryption, suitable for OpenPGP,
// as specified in "A Public-Key Cryptosystem and a Signature Scheme Based on
// Discrete Logarithms," IEEE Transactions on Information Theory, v. IT-31,
// n. 4, 1985, pp. 469-472.
//
// This form of ElGamal embeds PKCS#1 v1.5 padding, which may make it
// unsuitable for other protocols. RSA should be used in preference in any
// case.
package elgamal
import (
"big"
"crypto/rand"
"crypto/subtle"
"io"
"os"
)
// PublicKey represents an ElGamal public key.
type PublicKey struct {
G, P, Y *big.Int
}
// PrivateKey represents an ElGamal private key.
type PrivateKey struct {
PublicKey
X *big.Int
}
// Encrypt encrypts the given message to the given public key. The result is a
// pair of integers. Errors can result from reading random, or because msg is
// too large to be encrypted to the public key.
func Encrypt(random io.Reader, pub *PublicKey, msg []byte) (c1, c2 *big.Int, err os.Error) {
pLen := (pub.P.BitLen() + 7) / 8
if len(msg) > pLen-11 {
err = os.NewError("elgamal: message too long")
return
}
// EM = 0x02 || PS || 0x00 || M
em := make([]byte, pLen-1)
em[0] = 2
ps, mm := em[1:len(em)-len(msg)-1], em[len(em)-len(msg):]
err = nonZeroRandomBytes(ps, random)
if err != nil {
return
}
em[len(em)-len(msg)-1] = 0
copy(mm, msg)
m := new(big.Int).SetBytes(em)
k, err := rand.Int(random, pub.P)
if err != nil {
return
}
c1 = new(big.Int).Exp(pub.G, k, pub.P)
s := new(big.Int).Exp(pub.Y, k, pub.P)
c2 = s.Mul(s, m)
c2.Mod(c2, pub.P)
return
}
// Decrypt takes two integers, resulting from an ElGamal encryption, and
// returns the plaintext of the message. An error can result only if the
// ciphertext is invalid. Users should keep in mind that this is a padding
// oracle and thus, if exposed to an adaptive chosen ciphertext attack, can
// be used to break the cryptosystem. See ``Chosen Ciphertext Attacks
// Against Protocols Based on the RSA Encryption Standard PKCS #1'', Daniel
// Bleichenbacher, Advances in Cryptology (Crypto '98),
func Decrypt(priv *PrivateKey, c1, c2 *big.Int) (msg []byte, err os.Error) {
s := new(big.Int).Exp(c1, priv.X, priv.P)
s.ModInverse(s, priv.P)
s.Mul(s, c2)
s.Mod(s, priv.P)
em := s.Bytes()
firstByteIsTwo := subtle.ConstantTimeByteEq(em[0], 2)
// The remainder of the plaintext must be a string of non-zero random
// octets, followed by a 0, followed by the message.
// lookingForIndex: 1 iff we are still looking for the zero.
// index: the offset of the first zero byte.
var lookingForIndex, index int
lookingForIndex = 1
for i := 1; i < len(em); i++ {
equals0 := subtle.ConstantTimeByteEq(em[i], 0)
index = subtle.ConstantTimeSelect(lookingForIndex&equals0, i, index)
lookingForIndex = subtle.ConstantTimeSelect(equals0, 0, lookingForIndex)
}
if firstByteIsTwo != 1 || lookingForIndex != 0 || index < 9 {
return nil, os.NewError("elgamal: decryption error")
}
return em[index+1:], nil
}
// nonZeroRandomBytes fills the given slice with non-zero random octets.
func nonZeroRandomBytes(s []byte, rand io.Reader) (err os.Error) {
_, err = io.ReadFull(rand, s)
if err != nil {
return
}
for i := 0; i < len(s); i++ {
for s[i] == 0 {
_, err = io.ReadFull(rand, s[i:i+1])
if err != nil {
return
}
}
}
return
}

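For reference, Encrypt above computes c1 = g^k mod p and c2 = m*y^k mod p for a fresh random k, where m is the PKCS#1 v1.5-padded message, and Decrypt recovers m as c2*(c1^x)^-1 mod p, which is what the Exp/ModInverse/Mul/Mod sequence implements before the padding is stripped in constant time.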

@ -0,0 +1,49 @@
// Copyright 2011 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package elgamal
import (
"big"
"bytes"
"crypto/rand"
"testing"
)
// This is the 1024-bit MODP group from RFC 5114, section 2.1:
const primeHex = "B10B8F96A080E01DDE92DE5EAE5D54EC52C99FBCFB06A3C69A6A9DCA52D23B616073E28675A23D189838EF1E2EE652C013ECB4AEA906112324975C3CD49B83BFACCBDD7D90C4BD7098488E9C219A73724EFFD6FAE5644738FAA31A4FF55BCCC0A151AF5F0DC8B4BD45BF37DF365C1A65E68CFDA76D4DA708DF1FB2BC2E4A4371"
const generatorHex = "A4D1CBD5C3FD34126765A442EFB99905F8104DD258AC507FD6406CFF14266D31266FEA1E5C41564B777E690F5504F213160217B4B01B886A5E91547F9E2749F4D7FBD7D3B9A92EE1909D0D2263F80A76A6A24C087A091F531DBF0A0169B6A28AD662A4D18E73AFA32D779D5918D08BC8858F4DCEF97C2A24855E6EEB22B3B2E5"
func fromHex(hex string) *big.Int {
n, ok := new(big.Int).SetString(hex, 16)
if !ok {
panic("failed to parse hex number")
}
return n
}
func TestEncryptDecrypt(t *testing.T) {
priv := &PrivateKey{
PublicKey: PublicKey{
G: fromHex(generatorHex),
P: fromHex(primeHex),
},
X: fromHex("42"),
}
priv.Y = new(big.Int).Exp(priv.G, priv.X, priv.P)
message := []byte("hello world")
c1, c2, err := Encrypt(rand.Reader, &priv.PublicKey, message)
if err != nil {
t.Errorf("error encrypting: %s", err)
}
message2, err := Decrypt(priv, c1, c2)
if err != nil {
t.Errorf("error decrypting: %s", err)
}
if !bytes.Equal(message2, message) {
t.Errorf("decryption failed, got: %x, want: %x", message2, message)
}
}


@ -5,11 +5,14 @@
package openpgp package openpgp
import ( import (
"crypto"
"crypto/openpgp/armor" "crypto/openpgp/armor"
"crypto/openpgp/error" "crypto/openpgp/error"
"crypto/openpgp/packet" "crypto/openpgp/packet"
"crypto/rsa"
"io" "io"
"os" "os"
"time"
) )
// PublicKeyType is the armor type for a PGP public key. // PublicKeyType is the armor type for a PGP public key.
@ -62,6 +65,78 @@ type KeyRing interface {
DecryptionKeys() []Key DecryptionKeys() []Key
} }
// primaryIdentity returns the Identity marked as primary or the first identity
// if none are so marked.
func (e *Entity) primaryIdentity() *Identity {
var firstIdentity *Identity
for _, ident := range e.Identities {
if firstIdentity == nil {
firstIdentity = ident
}
if ident.SelfSignature.IsPrimaryId != nil && *ident.SelfSignature.IsPrimaryId {
return ident
}
}
return firstIdentity
}
// encryptionKey returns the best candidate Key for encrypting a message to the
// given Entity.
func (e *Entity) encryptionKey() Key {
candidateSubkey := -1
for i, subkey := range e.Subkeys {
if subkey.Sig.FlagsValid && subkey.Sig.FlagEncryptCommunications && subkey.PublicKey.PubKeyAlgo.CanEncrypt() {
candidateSubkey = i
break
}
}
i := e.primaryIdentity()
if e.PrimaryKey.PubKeyAlgo.CanEncrypt() {
// If we don't have any candidate subkeys for encryption and
// the primary key doesn't have any usage metadata then we
// assume that the primary key is ok. Or, if the primary key is
// marked as ok to encrypt to, then we can obviously use it.
if candidateSubkey == -1 && !i.SelfSignature.FlagsValid || i.SelfSignature.FlagEncryptCommunications && i.SelfSignature.FlagsValid {
return Key{e, e.PrimaryKey, e.PrivateKey, i.SelfSignature}
}
}
if candidateSubkey != -1 {
subkey := e.Subkeys[candidateSubkey]
return Key{e, subkey.PublicKey, subkey.PrivateKey, subkey.Sig}
}
// This Entity appears to be signing only.
return Key{}
}
// signingKey return the best candidate Key for signing a message with this
// Entity.
func (e *Entity) signingKey() Key {
candidateSubkey := -1
for i, subkey := range e.Subkeys {
if subkey.Sig.FlagsValid && subkey.Sig.FlagSign && subkey.PublicKey.PubKeyAlgo.CanSign() {
candidateSubkey = i
break
}
}
i := e.primaryIdentity()
// If we have no candidate subkey then we assume that it's ok to sign
// with the primary key.
if candidateSubkey == -1 || i.SelfSignature.FlagsValid && i.SelfSignature.FlagSign {
return Key{e, e.PrimaryKey, e.PrivateKey, i.SelfSignature}
}
subkey := e.Subkeys[candidateSubkey]
return Key{e, subkey.PublicKey, subkey.PrivateKey, subkey.Sig}
}
// An EntityList contains one or more Entities. // An EntityList contains one or more Entities.
type EntityList []*Entity type EntityList []*Entity
@ -197,6 +272,10 @@ func readEntity(packets *packet.Reader) (*Entity, os.Error) {
} }
} }
if !e.PrimaryKey.PubKeyAlgo.CanSign() {
return nil, error.StructuralError("primary key cannot be used for signatures")
}
var current *Identity var current *Identity
EachPacket: EachPacket:
for { for {
@ -227,7 +306,7 @@ EachPacket:
return nil, error.StructuralError("user ID packet not followed by self-signature") return nil, error.StructuralError("user ID packet not followed by self-signature")
} }
if sig.SigType == packet.SigTypePositiveCert && sig.IssuerKeyId != nil && *sig.IssuerKeyId == e.PrimaryKey.KeyId { if (sig.SigType == packet.SigTypePositiveCert || sig.SigType == packet.SigTypeGenericCert) && sig.IssuerKeyId != nil && *sig.IssuerKeyId == e.PrimaryKey.KeyId {
if err = e.PrimaryKey.VerifyUserIdSignature(pkt.Id, sig); err != nil { if err = e.PrimaryKey.VerifyUserIdSignature(pkt.Id, sig); err != nil {
return nil, error.StructuralError("user ID self-signature invalid: " + err.String()) return nil, error.StructuralError("user ID self-signature invalid: " + err.String())
} }
@ -297,3 +376,170 @@ func addSubkey(e *Entity, packets *packet.Reader, pub *packet.PublicKey, priv *p
e.Subkeys = append(e.Subkeys, subKey) e.Subkeys = append(e.Subkeys, subKey)
return nil return nil
} }
const defaultRSAKeyBits = 2048
// NewEntity returns an Entity that contains a fresh RSA/RSA keypair with a
// single identity composed of the given full name, comment and email, any of
// which may be empty but must not contain any of "()<>\x00".
func NewEntity(rand io.Reader, currentTimeSecs int64, name, comment, email string) (*Entity, os.Error) {
uid := packet.NewUserId(name, comment, email)
if uid == nil {
return nil, error.InvalidArgumentError("user id field contained invalid characters")
}
signingPriv, err := rsa.GenerateKey(rand, defaultRSAKeyBits)
if err != nil {
return nil, err
}
encryptingPriv, err := rsa.GenerateKey(rand, defaultRSAKeyBits)
if err != nil {
return nil, err
}
t := uint32(currentTimeSecs)
e := &Entity{
PrimaryKey: packet.NewRSAPublicKey(t, &signingPriv.PublicKey, false /* not a subkey */ ),
PrivateKey: packet.NewRSAPrivateKey(t, signingPriv, false /* not a subkey */ ),
Identities: make(map[string]*Identity),
}
isPrimaryId := true
e.Identities[uid.Id] = &Identity{
Name: uid.Name,
UserId: uid,
SelfSignature: &packet.Signature{
CreationTime: t,
SigType: packet.SigTypePositiveCert,
PubKeyAlgo: packet.PubKeyAlgoRSA,
Hash: crypto.SHA256,
IsPrimaryId: &isPrimaryId,
FlagsValid: true,
FlagSign: true,
FlagCertify: true,
IssuerKeyId: &e.PrimaryKey.KeyId,
},
}
e.Subkeys = make([]Subkey, 1)
e.Subkeys[0] = Subkey{
PublicKey: packet.NewRSAPublicKey(t, &encryptingPriv.PublicKey, true /* is a subkey */ ),
PrivateKey: packet.NewRSAPrivateKey(t, encryptingPriv, true /* is a subkey */ ),
Sig: &packet.Signature{
CreationTime: t,
SigType: packet.SigTypeSubkeyBinding,
PubKeyAlgo: packet.PubKeyAlgoRSA,
Hash: crypto.SHA256,
FlagsValid: true,
FlagEncryptStorage: true,
FlagEncryptCommunications: true,
IssuerKeyId: &e.PrimaryKey.KeyId,
},
}
return e, nil
}
// SerializePrivate serializes an Entity, including private key material, to
// the given Writer. For now, it must only be used on an Entity returned from
// NewEntity.
func (e *Entity) SerializePrivate(w io.Writer) (err os.Error) {
err = e.PrivateKey.Serialize(w)
if err != nil {
return
}
for _, ident := range e.Identities {
err = ident.UserId.Serialize(w)
if err != nil {
return
}
err = ident.SelfSignature.SignUserId(ident.UserId.Id, e.PrimaryKey, e.PrivateKey)
if err != nil {
return
}
err = ident.SelfSignature.Serialize(w)
if err != nil {
return
}
}
for _, subkey := range e.Subkeys {
err = subkey.PrivateKey.Serialize(w)
if err != nil {
return
}
err = subkey.Sig.SignKey(subkey.PublicKey, e.PrivateKey)
if err != nil {
return
}
err = subkey.Sig.Serialize(w)
if err != nil {
return
}
}
return nil
}
// Serialize writes the public part of the given Entity to w. (No private
// key material will be output).
func (e *Entity) Serialize(w io.Writer) os.Error {
err := e.PrimaryKey.Serialize(w)
if err != nil {
return err
}
for _, ident := range e.Identities {
err = ident.UserId.Serialize(w)
if err != nil {
return err
}
err = ident.SelfSignature.Serialize(w)
if err != nil {
return err
}
for _, sig := range ident.Signatures {
err = sig.Serialize(w)
if err != nil {
return err
}
}
}
for _, subkey := range e.Subkeys {
err = subkey.PublicKey.Serialize(w)
if err != nil {
return err
}
err = subkey.Sig.Serialize(w)
if err != nil {
return err
}
}
return nil
}
// SignIdentity adds a signature to e, from signer, attesting that identity is
// associated with e. The provided identity must already be an element of
// e.Identities and the private key of signer must have been decrypted if
// necessary.
func (e *Entity) SignIdentity(identity string, signer *Entity) os.Error {
if signer.PrivateKey == nil {
return error.InvalidArgumentError("signing Entity must have a private key")
}
if signer.PrivateKey.Encrypted {
return error.InvalidArgumentError("signing Entity's private key must be decrypted")
}
ident, ok := e.Identities[identity]
if !ok {
return error.InvalidArgumentError("given identity string not found in Entity")
}
sig := &packet.Signature{
SigType: packet.SigTypeGenericCert,
PubKeyAlgo: signer.PrivateKey.PubKeyAlgo,
Hash: crypto.SHA256,
CreationTime: uint32(time.Seconds()),
IssuerKeyId: &signer.PrivateKey.KeyId,
}
if err := sig.SignKey(e.PrimaryKey, signer.PrivateKey); err != nil {
return err
}
ident.Signatures = append(ident.Signatures, sig)
return nil
}
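A sketch of third-party certification with SignIdentity, assuming two freshly generated entities so the signer's private key is present and unencrypted:
package main

import (
	"crypto/openpgp"
	"crypto/rand"
	"time"
)

func main() {
	alice, err := openpgp.NewEntity(rand.Reader, time.Seconds(), "Alice", "", "alice@example.com")
	if err != nil {
		panic(err)
	}
	bob, err := openpgp.NewEntity(rand.Reader, time.Seconds(), "Bob", "", "bob@example.com")
	if err != nil {
		panic(err)
	}
	// Bob certifies each of Alice's identities; the certification is appended
	// to that Identity's Signatures slice and emitted by Entity.Serialize.
	for id := range alice.Identities {
		if err := alice.SignIdentity(id, bob); err != nil {
			panic(err)
		}
	}
}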

View File

@ -5,6 +5,8 @@
package packet package packet
import ( import (
"big"
"crypto/openpgp/elgamal"
"crypto/openpgp/error" "crypto/openpgp/error"
"crypto/rand" "crypto/rand"
"crypto/rsa" "crypto/rsa"
@ -14,14 +16,17 @@ import (
"strconv" "strconv"
) )
const encryptedKeyVersion = 3
// EncryptedKey represents a public-key encrypted session key. See RFC 4880, // EncryptedKey represents a public-key encrypted session key. See RFC 4880,
// section 5.1. // section 5.1.
type EncryptedKey struct { type EncryptedKey struct {
KeyId uint64 KeyId uint64
Algo PublicKeyAlgorithm Algo PublicKeyAlgorithm
Encrypted []byte
CipherFunc CipherFunction // only valid after a successful Decrypt CipherFunc CipherFunction // only valid after a successful Decrypt
Key []byte // only valid after a successful Decrypt Key []byte // only valid after a successful Decrypt
encryptedMPI1, encryptedMPI2 []byte
} }
func (e *EncryptedKey) parse(r io.Reader) (err os.Error) { func (e *EncryptedKey) parse(r io.Reader) (err os.Error) {
@ -30,37 +35,134 @@ func (e *EncryptedKey) parse(r io.Reader) (err os.Error) {
if err != nil { if err != nil {
return return
} }
if buf[0] != 3 { if buf[0] != encryptedKeyVersion {
return error.UnsupportedError("unknown EncryptedKey version " + strconv.Itoa(int(buf[0]))) return error.UnsupportedError("unknown EncryptedKey version " + strconv.Itoa(int(buf[0])))
} }
e.KeyId = binary.BigEndian.Uint64(buf[1:9]) e.KeyId = binary.BigEndian.Uint64(buf[1:9])
e.Algo = PublicKeyAlgorithm(buf[9]) e.Algo = PublicKeyAlgorithm(buf[9])
if e.Algo == PubKeyAlgoRSA || e.Algo == PubKeyAlgoRSAEncryptOnly { switch e.Algo {
e.Encrypted, _, err = readMPI(r) case PubKeyAlgoRSA, PubKeyAlgoRSAEncryptOnly:
e.encryptedMPI1, _, err = readMPI(r)
case PubKeyAlgoElGamal:
e.encryptedMPI1, _, err = readMPI(r)
if err != nil {
return
}
e.encryptedMPI2, _, err = readMPI(r)
} }
_, err = consumeAll(r) _, err = consumeAll(r)
return return
} }
// DecryptRSA decrypts an RSA encrypted session key with the given private key. func checksumKeyMaterial(key []byte) uint16 {
func (e *EncryptedKey) DecryptRSA(priv *rsa.PrivateKey) (err os.Error) { var checksum uint16
if e.Algo != PubKeyAlgoRSA && e.Algo != PubKeyAlgoRSAEncryptOnly { for _, v := range key {
return error.InvalidArgumentError("EncryptedKey not RSA encrypted") checksum += uint16(v)
} }
b, err := rsa.DecryptPKCS1v15(rand.Reader, priv, e.Encrypted) return checksum
}
// Decrypt decrypts an encrypted session key with the given private key. The
// private key must have been decrypted first.
func (e *EncryptedKey) Decrypt(priv *PrivateKey) os.Error {
var err os.Error
var b []byte
// TODO(agl): use session key decryption routines here to avoid
// padding oracle attacks.
switch priv.PubKeyAlgo {
case PubKeyAlgoRSA, PubKeyAlgoRSAEncryptOnly:
b, err = rsa.DecryptPKCS1v15(rand.Reader, priv.PrivateKey.(*rsa.PrivateKey), e.encryptedMPI1)
case PubKeyAlgoElGamal:
c1 := new(big.Int).SetBytes(e.encryptedMPI1)
c2 := new(big.Int).SetBytes(e.encryptedMPI2)
b, err = elgamal.Decrypt(priv.PrivateKey.(*elgamal.PrivateKey), c1, c2)
default:
err = error.InvalidArgumentError("cannot decrypted encrypted session key with private key of type " + strconv.Itoa(int(priv.PubKeyAlgo)))
}
if err != nil { if err != nil {
return return err
} }
e.CipherFunc = CipherFunction(b[0]) e.CipherFunc = CipherFunction(b[0])
e.Key = b[1 : len(b)-2] e.Key = b[1 : len(b)-2]
expectedChecksum := uint16(b[len(b)-2])<<8 | uint16(b[len(b)-1]) expectedChecksum := uint16(b[len(b)-2])<<8 | uint16(b[len(b)-1])
var checksum uint16 checksum := checksumKeyMaterial(e.Key)
for _, v := range e.Key {
checksum += uint16(v)
}
if checksum != expectedChecksum { if checksum != expectedChecksum {
return error.StructuralError("EncryptedKey checksum incorrect") return error.StructuralError("EncryptedKey checksum incorrect")
} }
return return nil
}
// SerializeEncryptedKey serializes an encrypted key packet to w that contains
// key, encrypted to pub.
func SerializeEncryptedKey(w io.Writer, rand io.Reader, pub *PublicKey, cipherFunc CipherFunction, key []byte) os.Error {
var buf [10]byte
buf[0] = encryptedKeyVersion
binary.BigEndian.PutUint64(buf[1:9], pub.KeyId)
buf[9] = byte(pub.PubKeyAlgo)
keyBlock := make([]byte, 1 /* cipher type */ +len(key)+2 /* checksum */ )
keyBlock[0] = byte(cipherFunc)
copy(keyBlock[1:], key)
checksum := checksumKeyMaterial(key)
keyBlock[1+len(key)] = byte(checksum >> 8)
keyBlock[1+len(key)+1] = byte(checksum)
switch pub.PubKeyAlgo {
case PubKeyAlgoRSA, PubKeyAlgoRSAEncryptOnly:
return serializeEncryptedKeyRSA(w, rand, buf, pub.PublicKey.(*rsa.PublicKey), keyBlock)
case PubKeyAlgoElGamal:
return serializeEncryptedKeyElGamal(w, rand, buf, pub.PublicKey.(*elgamal.PublicKey), keyBlock)
case PubKeyAlgoDSA, PubKeyAlgoRSASignOnly:
return error.InvalidArgumentError("cannot encrypt to public key of type " + strconv.Itoa(int(pub.PubKeyAlgo)))
}
return error.UnsupportedError("encrypting a key to public key of type " + strconv.Itoa(int(pub.PubKeyAlgo)))
}
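A round-trip sketch for SerializeEncryptedKey and EncryptedKey.Decrypt, using the encryption subkey of a freshly generated entity; AES-128 and the 16-byte session key are arbitrary choices here:
package main

import (
	"bytes"
	"crypto/openpgp"
	"crypto/openpgp/packet"
	"crypto/rand"
	"io"
	"time"
)

func main() {
	e, err := openpgp.NewEntity(rand.Reader, time.Seconds(), "Test", "", "test@example.com")
	if err != nil {
		panic(err)
	}
	sub := e.Subkeys[0] // the RSA encryption subkey
	sessionKey := make([]byte, 16)
	if _, err := io.ReadFull(rand.Reader, sessionKey); err != nil {
		panic(err)
	}
	buf := bytes.NewBuffer(nil)
	// Encrypt the session key to the subkey's public key.
	if err := packet.SerializeEncryptedKey(buf, rand.Reader, sub.PublicKey, packet.CipherAES128, sessionKey); err != nil {
		panic(err)
	}
	// Parse the packet back and decrypt with the matching private key.
	p, err := packet.Read(buf)
	if err != nil {
		panic(err)
	}
	ek := p.(*packet.EncryptedKey)
	if err := ek.Decrypt(sub.PrivateKey); err != nil {
		panic(err)
	}
	// ek.CipherFunc is now CipherAES128 and ek.Key equals sessionKey.
}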
func serializeEncryptedKeyRSA(w io.Writer, rand io.Reader, header [10]byte, pub *rsa.PublicKey, keyBlock []byte) os.Error {
cipherText, err := rsa.EncryptPKCS1v15(rand, pub, keyBlock)
if err != nil {
return error.InvalidArgumentError("RSA encryption failed: " + err.String())
}
packetLen := 10 /* header length */ + 2 /* mpi size */ + len(cipherText)
err = serializeHeader(w, packetTypeEncryptedKey, packetLen)
if err != nil {
return err
}
_, err = w.Write(header[:])
if err != nil {
return err
}
return writeMPI(w, 8*uint16(len(cipherText)), cipherText)
}
func serializeEncryptedKeyElGamal(w io.Writer, rand io.Reader, header [10]byte, pub *elgamal.PublicKey, keyBlock []byte) os.Error {
c1, c2, err := elgamal.Encrypt(rand, pub, keyBlock)
if err != nil {
return error.InvalidArgumentError("ElGamal encryption failed: " + err.String())
}
packetLen := 10 /* header length */
packetLen += 2 /* mpi size */ + (c1.BitLen()+7)/8
packetLen += 2 /* mpi size */ + (c2.BitLen()+7)/8
err = serializeHeader(w, packetTypeEncryptedKey, packetLen)
if err != nil {
return err
}
_, err = w.Write(header[:])
if err != nil {
return err
}
err = writeBig(w, c1)
if err != nil {
return err
}
return writeBig(w, c2)
} }

View File

@ -6,6 +6,8 @@ package packet
import ( import (
"big" "big"
"bytes"
"crypto/rand"
"crypto/rsa" "crypto/rsa"
"fmt" "fmt"
"testing" "testing"
@ -19,7 +21,27 @@ func bigFromBase10(s string) *big.Int {
return b return b
} }
func TestEncryptedKey(t *testing.T) { var encryptedKeyPub = rsa.PublicKey{
E: 65537,
N: bigFromBase10("115804063926007623305902631768113868327816898845124614648849934718568541074358183759250136204762053879858102352159854352727097033322663029387610959884180306668628526686121021235757016368038585212410610742029286439607686208110250133174279811431933746643015923132833417396844716207301518956640020862630546868823"),
}
var encryptedKeyRSAPriv = &rsa.PrivateKey{
PublicKey: encryptedKeyPub,
D: bigFromBase10("32355588668219869544751561565313228297765464314098552250409557267371233892496951383426602439009993875125222579159850054973310859166139474359774543943714622292329487391199285040721944491839695981199720170366763547754915493640685849961780092241140181198779299712578774460837139360803883139311171713302987058393"),
}
var encryptedKeyPriv = &PrivateKey{
PublicKey: PublicKey{
PubKeyAlgo: PubKeyAlgoRSA,
},
PrivateKey: encryptedKeyRSAPriv,
}
func TestDecryptingEncryptedKey(t *testing.T) {
const encryptedKeyHex = "c18c032a67d68660df41c70104005789d0de26b6a50c985a02a13131ca829c413a35d0e6fa8d6842599252162808ac7439c72151c8c6183e76923fe3299301414d0c25a2f06a2257db3839e7df0ec964773f6e4c4ac7ff3b48c444237166dd46ba8ff443a5410dc670cb486672fdbe7c9dfafb75b4fea83af3a204fe2a7dfa86bd20122b4f3d2646cbeecb8f7be8"
const expectedKeyHex = "d930363f7e0308c333b9618617ea728963d8df993665ae7be1092d4926fd864b"
p, err := Read(readerFromHex(encryptedKeyHex)) p, err := Read(readerFromHex(encryptedKeyHex))
if err != nil { if err != nil {
t.Errorf("error from Read: %s", err) t.Errorf("error from Read: %s", err)
@ -36,19 +58,9 @@ func TestEncryptedKey(t *testing.T) {
return return
} }
pub := rsa.PublicKey{ err = ek.Decrypt(encryptedKeyPriv)
E: 65537,
N: bigFromBase10("115804063926007623305902631768113868327816898845124614648849934718568541074358183759250136204762053879858102352159854352727097033322663029387610959884180306668628526686121021235757016368038585212410610742029286439607686208110250133174279811431933746643015923132833417396844716207301518956640020862630546868823"),
}
priv := &rsa.PrivateKey{
PublicKey: pub,
D: bigFromBase10("32355588668219869544751561565313228297765464314098552250409557267371233892496951383426602439009993875125222579159850054973310859166139474359774543943714622292329487391199285040721944491839695981199720170366763547754915493640685849961780092241140181198779299712578774460837139360803883139311171713302987058393"),
}
err = ek.DecryptRSA(priv)
if err != nil { if err != nil {
t.Errorf("error from DecryptRSA: %s", err) t.Errorf("error from Decrypt: %s", err)
return return
} }
@ -63,5 +75,52 @@ func TestEncryptedKey(t *testing.T) {
} }
} }
const encryptedKeyHex = "c18c032a67d68660df41c70104005789d0de26b6a50c985a02a13131ca829c413a35d0e6fa8d6842599252162808ac7439c72151c8c6183e76923fe3299301414d0c25a2f06a2257db3839e7df0ec964773f6e4c4ac7ff3b48c444237166dd46ba8ff443a5410dc670cb486672fdbe7c9dfafb75b4fea83af3a204fe2a7dfa86bd20122b4f3d2646cbeecb8f7be8" func TestEncryptingEncryptedKey(t *testing.T) {
const expectedKeyHex = "d930363f7e0308c333b9618617ea728963d8df993665ae7be1092d4926fd864b" key := []byte{1, 2, 3, 4}
const expectedKeyHex = "01020304"
const keyId = 42
pub := &PublicKey{
PublicKey: &encryptedKeyPub,
KeyId: keyId,
PubKeyAlgo: PubKeyAlgoRSAEncryptOnly,
}
buf := new(bytes.Buffer)
err := SerializeEncryptedKey(buf, rand.Reader, pub, CipherAES128, key)
if err != nil {
t.Errorf("error writing encrypted key packet: %s", err)
}
p, err := Read(buf)
if err != nil {
t.Errorf("error from Read: %s", err)
return
}
ek, ok := p.(*EncryptedKey)
if !ok {
t.Errorf("didn't parse an EncryptedKey, got %#v", p)
return
}
if ek.KeyId != keyId || ek.Algo != PubKeyAlgoRSAEncryptOnly {
t.Errorf("unexpected EncryptedKey contents: %#v", ek)
return
}
err = ek.Decrypt(encryptedKeyPriv)
if err != nil {
t.Errorf("error from Decrypt: %s", err)
return
}
if ek.CipherFunc != CipherAES128 {
t.Errorf("unexpected EncryptedKey contents: %#v", ek)
return
}
keyHex := fmt.Sprintf("%x", ek.Key)
if keyHex != expectedKeyHex {
t.Errorf("bad key, got %s want %x", keyHex, expectedKeyHex)
}
}

View File

@ -51,3 +51,40 @@ func (l *LiteralData) parse(r io.Reader) (err os.Error) {
l.Body = r l.Body = r
return return
} }
// SerializeLiteral serializes a literal data packet to w and returns a
// WriteCloser to which the data itself can be written and which MUST be closed
// on completion. The fileName is truncated to 255 bytes.
func SerializeLiteral(w io.WriteCloser, isBinary bool, fileName string, time uint32) (plaintext io.WriteCloser, err os.Error) {
var buf [4]byte
buf[0] = 't'
if isBinary {
buf[0] = 'b'
}
if len(fileName) > 255 {
fileName = fileName[:255]
}
buf[1] = byte(len(fileName))
inner, err := serializeStreamHeader(w, packetTypeLiteralData)
if err != nil {
return
}
_, err = inner.Write(buf[:2])
if err != nil {
return
}
_, err = inner.Write([]byte(fileName))
if err != nil {
return
}
binary.BigEndian.PutUint32(buf[:], time)
_, err = inner.Write(buf[:])
if err != nil {
return
}
plaintext = inner
return
}
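A sketch of streaming a message body through SerializeLiteral; the nopCloser type is a local stand-in for any io.WriteCloser destination, similar to the helper the packet tests use:
package main

import (
	"bytes"
	"crypto/openpgp/packet"
	"os"
	"time"
)

// nopCloser gives a bytes.Buffer a no-op Close so it satisfies io.WriteCloser.
type nopCloser struct{ *bytes.Buffer }

func (nopCloser) Close() os.Error { return nil }

func main() {
	out := nopCloser{bytes.NewBuffer(nil)}
	// Emit the literal data header; the returned WriteCloser streams the body
	// using partial lengths and must be closed to write the final chunk.
	plaintext, err := packet.SerializeLiteral(out, true /* binary */, "hello.txt", uint32(time.Seconds()))
	if err != nil {
		panic(err)
	}
	plaintext.Write([]byte("hello, world\n"))
	plaintext.Close()
}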

View File

@ -24,6 +24,8 @@ type OnePassSignature struct {
IsLast bool IsLast bool
} }
const onePassSignatureVersion = 3
func (ops *OnePassSignature) parse(r io.Reader) (err os.Error) { func (ops *OnePassSignature) parse(r io.Reader) (err os.Error) {
var buf [13]byte var buf [13]byte
@ -31,7 +33,7 @@ func (ops *OnePassSignature) parse(r io.Reader) (err os.Error) {
if err != nil { if err != nil {
return return
} }
if buf[0] != 3 { if buf[0] != onePassSignatureVersion {
err = error.UnsupportedError("one-pass-signature packet version " + strconv.Itoa(int(buf[0]))) err = error.UnsupportedError("one-pass-signature packet version " + strconv.Itoa(int(buf[0])))
} }
@ -47,3 +49,26 @@ func (ops *OnePassSignature) parse(r io.Reader) (err os.Error) {
ops.IsLast = buf[12] != 0 ops.IsLast = buf[12] != 0
return return
} }
// Serialize marshals the given OnePassSignature to w.
func (ops *OnePassSignature) Serialize(w io.Writer) os.Error {
var buf [13]byte
buf[0] = onePassSignatureVersion
buf[1] = uint8(ops.SigType)
var ok bool
buf[2], ok = s2k.HashToHashId(ops.Hash)
if !ok {
return error.UnsupportedError("hash type: " + strconv.Itoa(int(ops.Hash)))
}
buf[3] = uint8(ops.PubKeyAlgo)
binary.BigEndian.PutUint64(buf[4:12], ops.KeyId)
if ops.IsLast {
buf[12] = 1
}
if err := serializeHeader(w, packetTypeOnePassSignature, len(buf)); err != nil {
return err
}
_, err := w.Write(buf[:])
return err
}

View File

@ -2,7 +2,7 @@
// Use of this source code is governed by a BSD-style // Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file. // license that can be found in the LICENSE file.
// Package packet implements parsing and serialisation of OpenPGP packets, as // Package packet implements parsing and serialization of OpenPGP packets, as
// specified in RFC 4880. // specified in RFC 4880.
package packet package packet
@ -92,6 +92,46 @@ func (r *partialLengthReader) Read(p []byte) (n int, err os.Error) {
return return
} }
// partialLengthWriter writes a stream of data using OpenPGP partial lengths.
// See RFC 4880, section 4.2.2.4.
type partialLengthWriter struct {
w io.WriteCloser
lengthByte [1]byte
}
func (w *partialLengthWriter) Write(p []byte) (n int, err os.Error) {
for len(p) > 0 {
for power := uint(14); power < 32; power-- {
l := 1 << power
if len(p) >= l {
w.lengthByte[0] = 224 + uint8(power)
_, err = w.w.Write(w.lengthByte[:])
if err != nil {
return
}
var m int
m, err = w.w.Write(p[:l])
n += m
if err != nil {
return
}
p = p[l:]
break
}
}
}
return
}
func (w *partialLengthWriter) Close() os.Error {
w.lengthByte[0] = 0
_, err := w.w.Write(w.lengthByte[:])
if err != nil {
return err
}
return w.w.Close()
}
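As a worked example of the encoding above: a single Write of 20000 bytes goes out as partial chunks of 16384, 2048, 1024, 512 and 32 bytes, each preceded by the one-octet length 224+power (238, 235, 234, 233 and 229 respectively), and Close then emits the single octet 0x00, a zero-length final portion that terminates the packet body.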
// A spanReader is an io.LimitReader, but it returns ErrUnexpectedEOF if the // A spanReader is an io.LimitReader, but it returns ErrUnexpectedEOF if the
// underlying Reader returns EOF before the limit has been reached. // underlying Reader returns EOF before the limit has been reached.
type spanReader struct { type spanReader struct {
@ -195,6 +235,20 @@ func serializeHeader(w io.Writer, ptype packetType, length int) (err os.Error) {
return return
} }
// serializeStreamHeader writes an OpenPGP packet header to w where the
// length of the packet is unknown. It returns an io.WriteCloser which can be
// used to write the contents of the packet. See RFC 4880, section 4.2.
func serializeStreamHeader(w io.WriteCloser, ptype packetType) (out io.WriteCloser, err os.Error) {
var buf [1]byte
buf[0] = 0x80 | 0x40 | byte(ptype)
_, err = w.Write(buf[:])
if err != nil {
return
}
out = &partialLengthWriter{w: w}
return
}
// Packet represents an OpenPGP packet. Users are expected to try casting // Packet represents an OpenPGP packet. Users are expected to try casting
// instances of this interface to specific packet types. // instances of this interface to specific packet types.
type Packet interface { type Packet interface {
@ -301,12 +355,12 @@ type SignatureType uint8
const ( const (
SigTypeBinary SignatureType = 0 SigTypeBinary SignatureType = 0
SigTypeText = 1 SigTypeText = 1
SigTypeGenericCert = 0x10 SigTypeGenericCert = 0x10
SigTypePersonaCert = 0x11 SigTypePersonaCert = 0x11
SigTypeCasualCert = 0x12 SigTypeCasualCert = 0x12
SigTypePositiveCert = 0x13 SigTypePositiveCert = 0x13
SigTypeSubkeyBinding = 0x18 SigTypeSubkeyBinding = 0x18
) )
// PublicKeyAlgorithm represents the different public key system specified for // PublicKeyAlgorithm represents the different public key system specified for
@ -318,23 +372,43 @@ const (
PubKeyAlgoRSA PublicKeyAlgorithm = 1 PubKeyAlgoRSA PublicKeyAlgorithm = 1
PubKeyAlgoRSAEncryptOnly PublicKeyAlgorithm = 2 PubKeyAlgoRSAEncryptOnly PublicKeyAlgorithm = 2
PubKeyAlgoRSASignOnly PublicKeyAlgorithm = 3 PubKeyAlgoRSASignOnly PublicKeyAlgorithm = 3
PubKeyAlgoElgamal PublicKeyAlgorithm = 16 PubKeyAlgoElGamal PublicKeyAlgorithm = 16
PubKeyAlgoDSA PublicKeyAlgorithm = 17 PubKeyAlgoDSA PublicKeyAlgorithm = 17
) )
// CanEncrypt returns true if it's possible to encrypt a message to a public
// key of the given type.
func (pka PublicKeyAlgorithm) CanEncrypt() bool {
switch pka {
case PubKeyAlgoRSA, PubKeyAlgoRSAEncryptOnly, PubKeyAlgoElGamal:
return true
}
return false
}
// CanSign returns true if it's possible for a public key of the given type to
// sign a message.
func (pka PublicKeyAlgorithm) CanSign() bool {
switch pka {
case PubKeyAlgoRSA, PubKeyAlgoRSASignOnly, PubKeyAlgoDSA:
return true
}
return false
}
// CipherFunction represents the different block ciphers specified for OpenPGP. See // CipherFunction represents the different block ciphers specified for OpenPGP. See
// http://www.iana.org/assignments/pgp-parameters/pgp-parameters.xhtml#pgp-parameters-13 // http://www.iana.org/assignments/pgp-parameters/pgp-parameters.xhtml#pgp-parameters-13
type CipherFunction uint8 type CipherFunction uint8
const ( const (
CipherCAST5 = 3 CipherCAST5 CipherFunction = 3
CipherAES128 = 7 CipherAES128 CipherFunction = 7
CipherAES192 = 8 CipherAES192 CipherFunction = 8
CipherAES256 = 9 CipherAES256 CipherFunction = 9
) )
// keySize returns the key size, in bytes, of cipher. // KeySize returns the key size, in bytes, of cipher.
func (cipher CipherFunction) keySize() int { func (cipher CipherFunction) KeySize() int {
switch cipher { switch cipher {
case CipherCAST5: case CipherCAST5:
return cast5.KeySize return cast5.KeySize
@ -386,6 +460,14 @@ func readMPI(r io.Reader) (mpi []byte, bitLength uint16, err os.Error) {
return return
} }
// mpiLength returns the length of the given *big.Int when serialized as an
// MPI.
func mpiLength(n *big.Int) (mpiLengthInBytes int) {
mpiLengthInBytes = 2 /* MPI length */
mpiLengthInBytes += (n.BitLen() + 7) / 8
return
}
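For example, a 1024-bit modulus takes 2 length octets plus (1024+7)/8 = 128 octets of magnitude, so mpiLength reports 130 bytes.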
// writeMPI serializes a big integer to w. // writeMPI serializes a big integer to w.
func writeMPI(w io.Writer, bitLength uint16, mpiBytes []byte) (err os.Error) { func writeMPI(w io.Writer, bitLength uint16, mpiBytes []byte) (err os.Error) {
_, err = w.Write([]byte{byte(bitLength >> 8), byte(bitLength)}) _, err = w.Write([]byte{byte(bitLength >> 8), byte(bitLength)})

View File

@ -210,3 +210,47 @@ func TestSerializeHeader(t *testing.T) {
} }
} }
} }
func TestPartialLengths(t *testing.T) {
buf := bytes.NewBuffer(nil)
w := new(partialLengthWriter)
w.w = noOpCloser{buf}
const maxChunkSize = 64
var b [maxChunkSize]byte
var n uint8
for l := 1; l <= maxChunkSize; l++ {
for i := 0; i < l; i++ {
b[i] = n
n++
}
m, err := w.Write(b[:l])
if m != l {
t.Errorf("short write got: %d want: %d", m, l)
}
if err != nil {
t.Errorf("error from write: %s", err)
}
}
w.Close()
want := (maxChunkSize * (maxChunkSize + 1)) / 2
copyBuf := bytes.NewBuffer(nil)
r := &partialLengthReader{buf, 0, true}
m, err := io.Copy(copyBuf, r)
if m != int64(want) {
t.Errorf("short copy got: %d want: %d", m, want)
}
if err != nil {
t.Errorf("error from copy: %s", err)
}
copyBytes := copyBuf.Bytes()
for i := 0; i < want; i++ {
if copyBytes[i] != uint8(i) {
t.Errorf("bad pattern in copy at %d", i)
break
}
}
}

View File

@ -9,6 +9,7 @@ import (
"bytes" "bytes"
"crypto/cipher" "crypto/cipher"
"crypto/dsa" "crypto/dsa"
"crypto/openpgp/elgamal"
"crypto/openpgp/error" "crypto/openpgp/error"
"crypto/openpgp/s2k" "crypto/openpgp/s2k"
"crypto/rsa" "crypto/rsa"
@ -32,6 +33,13 @@ type PrivateKey struct {
iv []byte iv []byte
} }
func NewRSAPrivateKey(currentTimeSecs uint32, priv *rsa.PrivateKey, isSubkey bool) *PrivateKey {
pk := new(PrivateKey)
pk.PublicKey = *NewRSAPublicKey(currentTimeSecs, &priv.PublicKey, isSubkey)
pk.PrivateKey = priv
return pk
}
func (pk *PrivateKey) parse(r io.Reader) (err os.Error) { func (pk *PrivateKey) parse(r io.Reader) (err os.Error) {
err = (&pk.PublicKey).parse(r) err = (&pk.PublicKey).parse(r)
if err != nil { if err != nil {
@ -91,13 +99,90 @@ func (pk *PrivateKey) parse(r io.Reader) (err os.Error) {
return return
} }
func mod64kHash(d []byte) uint16 {
h := uint16(0)
for i := 0; i < len(d); i += 2 {
v := uint16(d[i]) << 8
if i+1 < len(d) {
v += uint16(d[i+1])
}
h += v
}
return h
}
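For example, mod64kHash of the bytes {0x01, 0x02, 0x03} is 0x0102 + 0x0300 = 0x0402: consecutive byte pairs are folded into big-endian 16-bit words (a trailing odd byte fills the high half) and summed modulo 2^16 to give the two-octet checksum that Serialize appends after the private key material.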
func (pk *PrivateKey) Serialize(w io.Writer) (err os.Error) {
// TODO(agl): support encrypted private keys
buf := bytes.NewBuffer(nil)
err = pk.PublicKey.serializeWithoutHeaders(buf)
if err != nil {
return
}
buf.WriteByte(0 /* no encryption */ )
privateKeyBuf := bytes.NewBuffer(nil)
switch priv := pk.PrivateKey.(type) {
case *rsa.PrivateKey:
err = serializeRSAPrivateKey(privateKeyBuf, priv)
default:
err = error.InvalidArgumentError("non-RSA private key")
}
if err != nil {
return
}
ptype := packetTypePrivateKey
contents := buf.Bytes()
privateKeyBytes := privateKeyBuf.Bytes()
if pk.IsSubkey {
ptype = packetTypePrivateSubkey
}
err = serializeHeader(w, ptype, len(contents)+len(privateKeyBytes)+2)
if err != nil {
return
}
_, err = w.Write(contents)
if err != nil {
return
}
_, err = w.Write(privateKeyBytes)
if err != nil {
return
}
checksum := mod64kHash(privateKeyBytes)
var checksumBytes [2]byte
checksumBytes[0] = byte(checksum >> 8)
checksumBytes[1] = byte(checksum)
_, err = w.Write(checksumBytes[:])
return
}
func serializeRSAPrivateKey(w io.Writer, priv *rsa.PrivateKey) os.Error {
err := writeBig(w, priv.D)
if err != nil {
return err
}
err = writeBig(w, priv.Primes[1])
if err != nil {
return err
}
err = writeBig(w, priv.Primes[0])
if err != nil {
return err
}
return writeBig(w, priv.Precomputed.Qinv)
}
// Decrypt decrypts an encrypted private key using a passphrase. // Decrypt decrypts an encrypted private key using a passphrase.
func (pk *PrivateKey) Decrypt(passphrase []byte) os.Error { func (pk *PrivateKey) Decrypt(passphrase []byte) os.Error {
if !pk.Encrypted { if !pk.Encrypted {
return nil return nil
} }
key := make([]byte, pk.cipher.keySize()) key := make([]byte, pk.cipher.KeySize())
pk.s2k(key, passphrase) pk.s2k(key, passphrase)
block := pk.cipher.new(key) block := pk.cipher.new(key)
cfb := cipher.NewCFBDecrypter(block, pk.iv) cfb := cipher.NewCFBDecrypter(block, pk.iv)
@ -140,6 +225,8 @@ func (pk *PrivateKey) parsePrivateKey(data []byte) (err os.Error) {
return pk.parseRSAPrivateKey(data) return pk.parseRSAPrivateKey(data)
case PubKeyAlgoDSA: case PubKeyAlgoDSA:
return pk.parseDSAPrivateKey(data) return pk.parseDSAPrivateKey(data)
case PubKeyAlgoElGamal:
return pk.parseElGamalPrivateKey(data)
} }
panic("impossible") panic("impossible")
} }
@ -193,3 +280,22 @@ func (pk *PrivateKey) parseDSAPrivateKey(data []byte) (err os.Error) {
return nil return nil
} }
func (pk *PrivateKey) parseElGamalPrivateKey(data []byte) (err os.Error) {
pub := pk.PublicKey.PublicKey.(*elgamal.PublicKey)
priv := new(elgamal.PrivateKey)
priv.PublicKey = *pub
buf := bytes.NewBuffer(data)
x, _, err := readMPI(buf)
if err != nil {
return
}
priv.X = new(big.Int).SetBytes(x)
pk.PrivateKey = priv
pk.Encrypted = false
pk.encryptedData = nil
return nil
}

View File

@ -8,30 +8,50 @@ import (
"testing" "testing"
) )
var privateKeyTests = []struct {
privateKeyHex string
creationTime uint32
}{
{
privKeyRSAHex,
0x4cc349a8,
},
{
privKeyElGamalHex,
0x4df9ee1a,
},
}
func TestPrivateKeyRead(t *testing.T) { func TestPrivateKeyRead(t *testing.T) {
packet, err := Read(readerFromHex(privKeyHex)) for i, test := range privateKeyTests {
if err != nil { packet, err := Read(readerFromHex(test.privateKeyHex))
t.Error(err) if err != nil {
return t.Errorf("#%d: failed to parse: %s", i, err)
} continue
}
privKey := packet.(*PrivateKey) privKey := packet.(*PrivateKey)
if !privKey.Encrypted { if !privKey.Encrypted {
t.Error("private key isn't encrypted") t.Errorf("#%d: private key isn't encrypted", i)
return continue
} }
err = privKey.Decrypt([]byte("testing")) err = privKey.Decrypt([]byte("testing"))
if err != nil { if err != nil {
t.Error(err) t.Errorf("#%d: failed to decrypt: %s", i, err)
return continue
} }
if privKey.CreationTime != 0x4cc349a8 || privKey.Encrypted { if privKey.CreationTime != test.creationTime || privKey.Encrypted {
t.Errorf("failed to parse, got: %#v", privKey) t.Errorf("#%d: bad result, got: %#v", i, privKey)
}
} }
} }
// Generated with `gpg --export-secret-keys "Test Key 2"` // Generated with `gpg --export-secret-keys "Test Key 2"`
const privKeyHex = "9501fe044cc349a8010400b70ca0010e98c090008d45d1ee8f9113bd5861fd57b88bacb7c68658747663f1e1a3b5a98f32fda6472373c024b97359cd2efc88ff60f77751adfbf6af5e615e6a1408cfad8bf0cea30b0d5f53aa27ad59089ba9b15b7ebc2777a25d7b436144027e3bcd203909f147d0e332b240cf63d3395f5dfe0df0a6c04e8655af7eacdf0011010001fe0303024a252e7d475fd445607de39a265472aa74a9320ba2dac395faa687e9e0336aeb7e9a7397e511b5afd9dc84557c80ac0f3d4d7bfec5ae16f20d41c8c84a04552a33870b930420e230e179564f6d19bb153145e76c33ae993886c388832b0fa042ddda7f133924f3854481533e0ede31d51278c0519b29abc3bf53da673e13e3e1214b52413d179d7f66deee35cac8eacb060f78379d70ef4af8607e68131ff529439668fc39c9ce6dfef8a5ac234d234802cbfb749a26107db26406213ae5c06d4673253a3cbee1fcbae58d6ab77e38d6e2c0e7c6317c48e054edadb5a40d0d48acb44643d998139a8a66bb820be1f3f80185bc777d14b5954b60effe2448a036d565c6bc0b915fcea518acdd20ab07bc1529f561c58cd044f723109b93f6fd99f876ff891d64306b5d08f48bab59f38695e9109c4dec34013ba3153488ce070268381ba923ee1eb77125b36afcb4347ec3478c8f2735b06ef17351d872e577fa95d0c397c88c71b59629a36aec" const privKeyRSAHex = "9501fe044cc349a8010400b70ca0010e98c090008d45d1ee8f9113bd5861fd57b88bacb7c68658747663f1e1a3b5a98f32fda6472373c024b97359cd2efc88ff60f77751adfbf6af5e615e6a1408cfad8bf0cea30b0d5f53aa27ad59089ba9b15b7ebc2777a25d7b436144027e3bcd203909f147d0e332b240cf63d3395f5dfe0df0a6c04e8655af7eacdf0011010001fe0303024a252e7d475fd445607de39a265472aa74a9320ba2dac395faa687e9e0336aeb7e9a7397e511b5afd9dc84557c80ac0f3d4d7bfec5ae16f20d41c8c84a04552a33870b930420e230e179564f6d19bb153145e76c33ae993886c388832b0fa042ddda7f133924f3854481533e0ede31d51278c0519b29abc3bf53da673e13e3e1214b52413d179d7f66deee35cac8eacb060f78379d70ef4af8607e68131ff529439668fc39c9ce6dfef8a5ac234d234802cbfb749a26107db26406213ae5c06d4673253a3cbee1fcbae58d6ab77e38d6e2c0e7c6317c48e054edadb5a40d0d48acb44643d998139a8a66bb820be1f3f80185bc777d14b5954b60effe2448a036d565c6bc0b915fcea518acdd20ab07bc1529f561c58cd044f723109b93f6fd99f876ff891d64306b5d08f48bab59f38695e9109c4dec34013ba3153488ce070268381ba923ee1eb77125b36afcb4347ec3478c8f2735b06ef17351d872e577fa95d0c397c88c71b59629a36aec"
// Generated by `gpg --export-secret-keys` followed by a manual extraction of
// the ElGamal subkey from the packets.
const privKeyElGamalHex = "9d0157044df9ee1a100400eb8e136a58ec39b582629cdadf830bc64e0a94ed8103ca8bb247b27b11b46d1d25297ef4bcc3071785ba0c0bedfe89eabc5287fcc0edf81ab5896c1c8e4b20d27d79813c7aede75320b33eaeeaa586edc00fd1036c10133e6ba0ff277245d0d59d04b2b3421b7244aca5f4a8d870c6f1c1fbff9e1c26699a860b9504f35ca1d700030503fd1ededd3b840795be6d9ccbe3c51ee42e2f39233c432b831ddd9c4e72b7025a819317e47bf94f9ee316d7273b05d5fcf2999c3a681f519b1234bbfa6d359b4752bd9c3f77d6b6456cde152464763414ca130f4e91d91041432f90620fec0e6d6b5116076c2985d5aeaae13be492b9b329efcaf7ee25120159a0a30cd976b42d7afe030302dae7eb80db744d4960c4df930d57e87fe81412eaace9f900e6c839817a614ddb75ba6603b9417c33ea7b6c93967dfa2bcff3fa3c74a5ce2c962db65b03aece14c96cbd0038fc"

View File

@ -7,6 +7,7 @@ package packet
import ( import (
"big" "big"
"crypto/dsa" "crypto/dsa"
"crypto/openpgp/elgamal"
"crypto/openpgp/error" "crypto/openpgp/error"
"crypto/rsa" "crypto/rsa"
"crypto/sha1" "crypto/sha1"
@ -30,6 +31,28 @@ type PublicKey struct {
n, e, p, q, g, y parsedMPI n, e, p, q, g, y parsedMPI
} }
func fromBig(n *big.Int) parsedMPI {
return parsedMPI{
bytes: n.Bytes(),
bitLength: uint16(n.BitLen()),
}
}
// NewRSAPublicKey returns a PublicKey that wraps the given rsa.PublicKey.
func NewRSAPublicKey(creationTimeSecs uint32, pub *rsa.PublicKey, isSubkey bool) *PublicKey {
pk := &PublicKey{
CreationTime: creationTimeSecs,
PubKeyAlgo: PubKeyAlgoRSA,
PublicKey: pub,
IsSubkey: isSubkey,
n: fromBig(pub.N),
e: fromBig(big.NewInt(int64(pub.E))),
}
pk.setFingerPrintAndKeyId()
return pk
}
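A small sketch of wrapping an existing Go RSA key with NewRSAPublicKey; the 1024-bit key size is arbitrary:
package main

import (
	"crypto/openpgp/packet"
	"crypto/rand"
	"crypto/rsa"
	"time"
)

func main() {
	rsaPriv, err := rsa.GenerateKey(rand.Reader, 1024)
	if err != nil {
		panic(err)
	}
	// Wrap the public half as an OpenPGP primary key packet structure;
	// the fingerprint and 64-bit key ID are computed immediately.
	pk := packet.NewRSAPublicKey(uint32(time.Seconds()), &rsaPriv.PublicKey, false /* not a subkey */ )
	_ = pk.KeyId
	_ = pk.Fingerprint
}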
func (pk *PublicKey) parse(r io.Reader) (err os.Error) { func (pk *PublicKey) parse(r io.Reader) (err os.Error) {
// RFC 4880, section 5.5.2 // RFC 4880, section 5.5.2
var buf [6]byte var buf [6]byte
@ -47,6 +70,8 @@ func (pk *PublicKey) parse(r io.Reader) (err os.Error) {
err = pk.parseRSA(r) err = pk.parseRSA(r)
case PubKeyAlgoDSA: case PubKeyAlgoDSA:
err = pk.parseDSA(r) err = pk.parseDSA(r)
case PubKeyAlgoElGamal:
err = pk.parseElGamal(r)
default: default:
err = error.UnsupportedError("public key type: " + strconv.Itoa(int(pk.PubKeyAlgo))) err = error.UnsupportedError("public key type: " + strconv.Itoa(int(pk.PubKeyAlgo)))
} }
@ -54,14 +79,17 @@ func (pk *PublicKey) parse(r io.Reader) (err os.Error) {
return return
} }
pk.setFingerPrintAndKeyId()
return
}
func (pk *PublicKey) setFingerPrintAndKeyId() {
// RFC 4880, section 12.2 // RFC 4880, section 12.2
fingerPrint := sha1.New() fingerPrint := sha1.New()
pk.SerializeSignaturePrefix(fingerPrint) pk.SerializeSignaturePrefix(fingerPrint)
pk.Serialize(fingerPrint) pk.serializeWithoutHeaders(fingerPrint)
copy(pk.Fingerprint[:], fingerPrint.Sum()) copy(pk.Fingerprint[:], fingerPrint.Sum())
pk.KeyId = binary.BigEndian.Uint64(pk.Fingerprint[12:20]) pk.KeyId = binary.BigEndian.Uint64(pk.Fingerprint[12:20])
return
} }
// parseRSA parses RSA public key material from the given Reader. See RFC 4880, // parseRSA parses RSA public key material from the given Reader. See RFC 4880,
@ -92,7 +120,7 @@ func (pk *PublicKey) parseRSA(r io.Reader) (err os.Error) {
return return
} }
// parseRSA parses DSA public key material from the given Reader. See RFC 4880, // parseDSA parses DSA public key material from the given Reader. See RFC 4880,
// section 5.5.2. // section 5.5.2.
func (pk *PublicKey) parseDSA(r io.Reader) (err os.Error) { func (pk *PublicKey) parseDSA(r io.Reader) (err os.Error) {
pk.p.bytes, pk.p.bitLength, err = readMPI(r) pk.p.bytes, pk.p.bitLength, err = readMPI(r)
@ -121,6 +149,30 @@ func (pk *PublicKey) parseDSA(r io.Reader) (err os.Error) {
return return
} }
// parseElGamal parses ElGamal public key material from the given Reader. See
// RFC 4880, section 5.5.2.
func (pk *PublicKey) parseElGamal(r io.Reader) (err os.Error) {
pk.p.bytes, pk.p.bitLength, err = readMPI(r)
if err != nil {
return
}
pk.g.bytes, pk.g.bitLength, err = readMPI(r)
if err != nil {
return
}
pk.y.bytes, pk.y.bitLength, err = readMPI(r)
if err != nil {
return
}
elgamal := new(elgamal.PublicKey)
elgamal.P = new(big.Int).SetBytes(pk.p.bytes)
elgamal.G = new(big.Int).SetBytes(pk.g.bytes)
elgamal.Y = new(big.Int).SetBytes(pk.y.bytes)
pk.PublicKey = elgamal
return
}
// SerializeSignaturePrefix writes the prefix for this public key to the given Writer. // SerializeSignaturePrefix writes the prefix for this public key to the given Writer.
// The prefix is used when calculating a signature over this public key. See // The prefix is used when calculating a signature over this public key. See
// RFC 4880, section 5.2.4. // RFC 4880, section 5.2.4.
@ -135,6 +187,10 @@ func (pk *PublicKey) SerializeSignaturePrefix(h hash.Hash) {
pLength += 2 + uint16(len(pk.q.bytes)) pLength += 2 + uint16(len(pk.q.bytes))
pLength += 2 + uint16(len(pk.g.bytes)) pLength += 2 + uint16(len(pk.g.bytes))
pLength += 2 + uint16(len(pk.y.bytes)) pLength += 2 + uint16(len(pk.y.bytes))
case PubKeyAlgoElGamal:
pLength += 2 + uint16(len(pk.p.bytes))
pLength += 2 + uint16(len(pk.g.bytes))
pLength += 2 + uint16(len(pk.y.bytes))
default: default:
panic("unknown public key algorithm") panic("unknown public key algorithm")
} }
@ -143,9 +199,40 @@ func (pk *PublicKey) SerializeSignaturePrefix(h hash.Hash) {
return return
} }
// Serialize marshals the PublicKey to w in the form of an OpenPGP public key
// packet, not including the packet header.
func (pk *PublicKey) Serialize(w io.Writer) (err os.Error) { func (pk *PublicKey) Serialize(w io.Writer) (err os.Error) {
length := 6 // 6 byte header
switch pk.PubKeyAlgo {
case PubKeyAlgoRSA, PubKeyAlgoRSAEncryptOnly, PubKeyAlgoRSASignOnly:
length += 2 + len(pk.n.bytes)
length += 2 + len(pk.e.bytes)
case PubKeyAlgoDSA:
length += 2 + len(pk.p.bytes)
length += 2 + len(pk.q.bytes)
length += 2 + len(pk.g.bytes)
length += 2 + len(pk.y.bytes)
case PubKeyAlgoElGamal:
length += 2 + len(pk.p.bytes)
length += 2 + len(pk.g.bytes)
length += 2 + len(pk.y.bytes)
default:
panic("unknown public key algorithm")
}
packetType := packetTypePublicKey
if pk.IsSubkey {
packetType = packetTypePublicSubkey
}
err = serializeHeader(w, packetType, length)
if err != nil {
return
}
return pk.serializeWithoutHeaders(w)
}
// serializeWithoutHeaders marshals the PublicKey to w in the form of an
// OpenPGP public key packet, not including the packet header.
func (pk *PublicKey) serializeWithoutHeaders(w io.Writer) (err os.Error) {
var buf [6]byte var buf [6]byte
buf[0] = 4 buf[0] = 4
buf[1] = byte(pk.CreationTime >> 24) buf[1] = byte(pk.CreationTime >> 24)
@ -164,13 +251,15 @@ func (pk *PublicKey) Serialize(w io.Writer) (err os.Error) {
return writeMPIs(w, pk.n, pk.e) return writeMPIs(w, pk.n, pk.e)
case PubKeyAlgoDSA: case PubKeyAlgoDSA:
return writeMPIs(w, pk.p, pk.q, pk.g, pk.y) return writeMPIs(w, pk.p, pk.q, pk.g, pk.y)
case PubKeyAlgoElGamal:
return writeMPIs(w, pk.p, pk.g, pk.y)
} }
return error.InvalidArgumentError("bad public-key algorithm") return error.InvalidArgumentError("bad public-key algorithm")
} }
// CanSign returns true iff this public key can generate signatures // CanSign returns true iff this public key can generate signatures
func (pk *PublicKey) CanSign() bool { func (pk *PublicKey) CanSign() bool {
return pk.PubKeyAlgo != PubKeyAlgoRSAEncryptOnly && pk.PubKeyAlgo != PubKeyAlgoElgamal return pk.PubKeyAlgo != PubKeyAlgoRSAEncryptOnly && pk.PubKeyAlgo != PubKeyAlgoElGamal
} }
// VerifySignature returns nil iff sig is a valid signature, made by this // VerifySignature returns nil iff sig is a valid signature, made by this
@ -194,14 +283,14 @@ func (pk *PublicKey) VerifySignature(signed hash.Hash, sig *Signature) (err os.E
switch pk.PubKeyAlgo { switch pk.PubKeyAlgo {
case PubKeyAlgoRSA, PubKeyAlgoRSASignOnly: case PubKeyAlgoRSA, PubKeyAlgoRSASignOnly:
rsaPublicKey, _ := pk.PublicKey.(*rsa.PublicKey) rsaPublicKey, _ := pk.PublicKey.(*rsa.PublicKey)
err = rsa.VerifyPKCS1v15(rsaPublicKey, sig.Hash, hashBytes, sig.RSASignature) err = rsa.VerifyPKCS1v15(rsaPublicKey, sig.Hash, hashBytes, sig.RSASignature.bytes)
if err != nil { if err != nil {
return error.SignatureError("RSA verification failure") return error.SignatureError("RSA verification failure")
} }
return nil return nil
case PubKeyAlgoDSA: case PubKeyAlgoDSA:
dsaPublicKey, _ := pk.PublicKey.(*dsa.PublicKey) dsaPublicKey, _ := pk.PublicKey.(*dsa.PublicKey)
if !dsa.Verify(dsaPublicKey, hashBytes, sig.DSASigR, sig.DSASigS) { if !dsa.Verify(dsaPublicKey, hashBytes, new(big.Int).SetBytes(sig.DSASigR.bytes), new(big.Int).SetBytes(sig.DSASigS.bytes)) {
return error.SignatureError("DSA verification failure") return error.SignatureError("DSA verification failure")
} }
return nil return nil
@ -211,34 +300,43 @@ func (pk *PublicKey) VerifySignature(signed hash.Hash, sig *Signature) (err os.E
panic("unreachable") panic("unreachable")
} }
// VerifyKeySignature returns nil iff sig is a valid signature, make by this // keySignatureHash returns a Hash of the message that needs to be signed for
// public key, of the public key in signed. // pk to assert a subkey relationship to signed.
func (pk *PublicKey) VerifyKeySignature(signed *PublicKey, sig *Signature) (err os.Error) { func keySignatureHash(pk, signed *PublicKey, sig *Signature) (h hash.Hash, err os.Error) {
h := sig.Hash.New() h = sig.Hash.New()
if h == nil { if h == nil {
return error.UnsupportedError("hash function") return nil, error.UnsupportedError("hash function")
} }
// RFC 4880, section 5.2.4 // RFC 4880, section 5.2.4
pk.SerializeSignaturePrefix(h) pk.SerializeSignaturePrefix(h)
pk.Serialize(h) pk.serializeWithoutHeaders(h)
signed.SerializeSignaturePrefix(h) signed.SerializeSignaturePrefix(h)
signed.Serialize(h) signed.serializeWithoutHeaders(h)
return
}
// VerifyKeySignature returns nil iff sig is a valid signature, made by this
// public key, of signed.
func (pk *PublicKey) VerifyKeySignature(signed *PublicKey, sig *Signature) (err os.Error) {
h, err := keySignatureHash(pk, signed, sig)
if err != nil {
return err
}
return pk.VerifySignature(h, sig) return pk.VerifySignature(h, sig)
} }
// VerifyUserIdSignature returns nil iff sig is a valid signature, make by this // userIdSignatureHash returns a Hash of the message that needs to be signed
// public key, of the given user id. // to assert that pk is a valid key for id.
func (pk *PublicKey) VerifyUserIdSignature(id string, sig *Signature) (err os.Error) { func userIdSignatureHash(id string, pk *PublicKey, sig *Signature) (h hash.Hash, err os.Error) {
h := sig.Hash.New() h = sig.Hash.New()
if h == nil { if h == nil {
return error.UnsupportedError("hash function") return nil, error.UnsupportedError("hash function")
} }
// RFC 4880, section 5.2.4 // RFC 4880, section 5.2.4
pk.SerializeSignaturePrefix(h) pk.SerializeSignaturePrefix(h)
pk.Serialize(h) pk.serializeWithoutHeaders(h)
var buf [5]byte var buf [5]byte
buf[0] = 0xb4 buf[0] = 0xb4
@ -249,6 +347,16 @@ func (pk *PublicKey) VerifyUserIdSignature(id string, sig *Signature) (err os.Er
h.Write(buf[:]) h.Write(buf[:])
h.Write([]byte(id)) h.Write([]byte(id))
return
}
// VerifyUserIdSignature returns nil iff sig is a valid signature, made by this
// public key, of id.
func (pk *PublicKey) VerifyUserIdSignature(id string, sig *Signature) (err os.Error) {
h, err := userIdSignatureHash(id, pk, sig)
if err != nil {
return err
}
return pk.VerifySignature(h, sig) return pk.VerifySignature(h, sig)
} }
@ -272,7 +380,7 @@ type parsedMPI struct {
bitLength uint16 bitLength uint16
} }
// writeMPIs is a utility function for serialising several big integers to the // writeMPIs is a utility function for serializing several big integers to the
// given Writer. // given Writer.
func writeMPIs(w io.Writer, mpis ...parsedMPI) (err os.Error) { func writeMPIs(w io.Writer, mpis ...parsedMPI) (err os.Error) {
for _, mpi := range mpis { for _, mpi := range mpis {

View File

@ -28,12 +28,12 @@ func TestPublicKeyRead(t *testing.T) {
packet, err := Read(readerFromHex(test.hexData)) packet, err := Read(readerFromHex(test.hexData))
if err != nil { if err != nil {
t.Errorf("#%d: Read error: %s", i, err) t.Errorf("#%d: Read error: %s", i, err)
return continue
} }
pk, ok := packet.(*PublicKey) pk, ok := packet.(*PublicKey)
if !ok { if !ok {
t.Errorf("#%d: failed to parse, got: %#v", i, packet) t.Errorf("#%d: failed to parse, got: %#v", i, packet)
return continue
} }
if pk.PubKeyAlgo != test.pubKeyAlgo { if pk.PubKeyAlgo != test.pubKeyAlgo {
t.Errorf("#%d: bad public key algorithm got:%x want:%x", i, pk.PubKeyAlgo, test.pubKeyAlgo) t.Errorf("#%d: bad public key algorithm got:%x want:%x", i, pk.PubKeyAlgo, test.pubKeyAlgo)
@ -57,6 +57,38 @@ func TestPublicKeyRead(t *testing.T) {
} }
} }
func TestPublicKeySerialize(t *testing.T) {
for i, test := range pubKeyTests {
packet, err := Read(readerFromHex(test.hexData))
if err != nil {
t.Errorf("#%d: Read error: %s", i, err)
continue
}
pk, ok := packet.(*PublicKey)
if !ok {
t.Errorf("#%d: failed to parse, got: %#v", i, packet)
continue
}
serializeBuf := bytes.NewBuffer(nil)
err = pk.Serialize(serializeBuf)
if err != nil {
t.Errorf("#%d: failed to serialize: %s", i, err)
continue
}
packet, err = Read(serializeBuf)
if err != nil {
t.Errorf("#%d: Read error (from serialized data): %s", i, err)
continue
}
pk, ok = packet.(*PublicKey)
if !ok {
t.Errorf("#%d: failed to parse serialized data, got: %#v", i, packet)
continue
}
}
}
const rsaFingerprintHex = "5fb74b1d03b1e3cb31bc2f8aa34d7e18c20c31bb" const rsaFingerprintHex = "5fb74b1d03b1e3cb31bc2f8aa34d7e18c20c31bb"
const rsaPkDataHex = "988d044d3c5c10010400b1d13382944bd5aba23a4312968b5095d14f947f600eb478e14a6fcb16b0e0cac764884909c020bc495cfcc39a935387c661507bdb236a0612fb582cac3af9b29cc2c8c70090616c41b662f4da4c1201e195472eb7f4ae1ccbcbf9940fe21d985e379a5563dde5b9a23d35f1cfaa5790da3b79db26f23695107bfaca8e7b5bcd0011010001" const rsaPkDataHex = "988d044d3c5c10010400b1d13382944bd5aba23a4312968b5095d14f947f600eb478e14a6fcb16b0e0cac764884909c020bc495cfcc39a935387c661507bdb236a0612fb582cac3af9b29cc2c8c70090616c41b662f4da4c1201e195472eb7f4ae1ccbcbf9940fe21d985e379a5563dde5b9a23d35f1cfaa5790da3b79db26f23695107bfaca8e7b5bcd0011010001"

View File

@ -5,7 +5,6 @@
package packet package packet
import ( import (
"big"
"crypto" "crypto"
"crypto/dsa" "crypto/dsa"
"crypto/openpgp/error" "crypto/openpgp/error"
@ -32,8 +31,11 @@ type Signature struct {
HashTag [2]byte HashTag [2]byte
CreationTime uint32 // Unix epoch time CreationTime uint32 // Unix epoch time
RSASignature []byte RSASignature parsedMPI
DSASigR, DSASigS *big.Int DSASigR, DSASigS parsedMPI
// rawSubpackets contains the unparsed subpackets, in order.
rawSubpackets []outputSubpacket
// The following are optional so are nil when not included in the // The following are optional so are nil when not included in the
// signature. // signature.
@ -128,14 +130,11 @@ func (sig *Signature) parse(r io.Reader) (err os.Error) {
switch sig.PubKeyAlgo { switch sig.PubKeyAlgo {
case PubKeyAlgoRSA, PubKeyAlgoRSASignOnly: case PubKeyAlgoRSA, PubKeyAlgoRSASignOnly:
sig.RSASignature, _, err = readMPI(r) sig.RSASignature.bytes, sig.RSASignature.bitLength, err = readMPI(r)
case PubKeyAlgoDSA: case PubKeyAlgoDSA:
var rBytes, sBytes []byte sig.DSASigR.bytes, sig.DSASigR.bitLength, err = readMPI(r)
rBytes, _, err = readMPI(r)
sig.DSASigR = new(big.Int).SetBytes(rBytes)
if err == nil { if err == nil {
sBytes, _, err = readMPI(r) sig.DSASigS.bytes, sig.DSASigS.bitLength, err = readMPI(r)
sig.DSASigS = new(big.Int).SetBytes(sBytes)
} }
default: default:
panic("unreachable") panic("unreachable")
@ -177,7 +176,11 @@ const (
// parseSignatureSubpacket parses a single subpacket. len(subpacket) is >= 1. // parseSignatureSubpacket parses a single subpacket. len(subpacket) is >= 1.
func parseSignatureSubpacket(sig *Signature, subpacket []byte, isHashed bool) (rest []byte, err os.Error) { func parseSignatureSubpacket(sig *Signature, subpacket []byte, isHashed bool) (rest []byte, err os.Error) {
// RFC 4880, section 5.2.3.1 // RFC 4880, section 5.2.3.1
var length uint32 var (
length uint32
packetType signatureSubpacketType
isCritical bool
)
switch { switch {
case subpacket[0] < 192: case subpacket[0] < 192:
length = uint32(subpacket[0]) length = uint32(subpacket[0])
@ -207,10 +210,11 @@ func parseSignatureSubpacket(sig *Signature, subpacket []byte, isHashed bool) (r
err = error.StructuralError("zero length signature subpacket") err = error.StructuralError("zero length signature subpacket")
return return
} }
packetType := subpacket[0] & 0x7f packetType = signatureSubpacketType(subpacket[0] & 0x7f)
isCritial := subpacket[0]&0x80 == 0x80 isCritical = subpacket[0]&0x80 == 0x80
subpacket = subpacket[1:] subpacket = subpacket[1:]
switch signatureSubpacketType(packetType) { sig.rawSubpackets = append(sig.rawSubpackets, outputSubpacket{isHashed, packetType, isCritical, subpacket})
switch packetType {
case creationTimeSubpacket: case creationTimeSubpacket:
if !isHashed { if !isHashed {
err = error.StructuralError("signature creation time in non-hashed area") err = error.StructuralError("signature creation time in non-hashed area")
@ -309,7 +313,7 @@ func parseSignatureSubpacket(sig *Signature, subpacket []byte, isHashed bool) (r
} }
default: default:
if isCritial { if isCritical {
err = error.UnsupportedError("unknown critical signature subpacket type " + strconv.Itoa(int(packetType))) err = error.UnsupportedError("unknown critical signature subpacket type " + strconv.Itoa(int(packetType)))
return return
} }
@ -381,7 +385,6 @@ func serializeSubpackets(to []byte, subpackets []outputSubpacket, hashed bool) {
// buildHashSuffix constructs the HashSuffix member of sig in preparation for signing. // buildHashSuffix constructs the HashSuffix member of sig in preparation for signing.
func (sig *Signature) buildHashSuffix() (err os.Error) { func (sig *Signature) buildHashSuffix() (err os.Error) {
sig.outSubpackets = sig.buildSubpackets()
hashedSubpacketsLen := subpacketsLength(sig.outSubpackets, true) hashedSubpacketsLen := subpacketsLength(sig.outSubpackets, true)
var ok bool var ok bool
@ -393,7 +396,7 @@ func (sig *Signature) buildHashSuffix() (err os.Error) {
sig.HashSuffix[3], ok = s2k.HashToHashId(sig.Hash) sig.HashSuffix[3], ok = s2k.HashToHashId(sig.Hash)
if !ok { if !ok {
sig.HashSuffix = nil sig.HashSuffix = nil
return error.InvalidArgumentError("hash cannot be repesented in OpenPGP: " + strconv.Itoa(int(sig.Hash))) return error.InvalidArgumentError("hash cannot be represented in OpenPGP: " + strconv.Itoa(int(sig.Hash)))
} }
sig.HashSuffix[4] = byte(hashedSubpacketsLen >> 8) sig.HashSuffix[4] = byte(hashedSubpacketsLen >> 8)
sig.HashSuffix[5] = byte(hashedSubpacketsLen) sig.HashSuffix[5] = byte(hashedSubpacketsLen)
@ -420,45 +423,72 @@ func (sig *Signature) signPrepareHash(h hash.Hash) (digest []byte, err os.Error)
return return
} }
// SignRSA signs a message with an RSA private key. The hash, h, must contain // Sign signs a message with a private key. The hash, h, must contain
// the hash of the message to be signed and will be mutated by this function. // the hash of the message to be signed and will be mutated by this function.
// On success, the signature is stored in sig. Call Serialize to write it out. // On success, the signature is stored in sig. Call Serialize to write it out.
func (sig *Signature) SignRSA(h hash.Hash, priv *rsa.PrivateKey) (err os.Error) { func (sig *Signature) Sign(h hash.Hash, priv *PrivateKey) (err os.Error) {
sig.outSubpackets = sig.buildSubpackets()
digest, err := sig.signPrepareHash(h) digest, err := sig.signPrepareHash(h)
if err != nil { if err != nil {
return return
} }
sig.RSASignature, err = rsa.SignPKCS1v15(rand.Reader, priv, sig.Hash, digest)
switch priv.PubKeyAlgo {
case PubKeyAlgoRSA, PubKeyAlgoRSASignOnly:
sig.RSASignature.bytes, err = rsa.SignPKCS1v15(rand.Reader, priv.PrivateKey.(*rsa.PrivateKey), sig.Hash, digest)
sig.RSASignature.bitLength = uint16(8 * len(sig.RSASignature.bytes))
case PubKeyAlgoDSA:
r, s, err := dsa.Sign(rand.Reader, priv.PrivateKey.(*dsa.PrivateKey), digest)
if err == nil {
sig.DSASigR.bytes = r.Bytes()
sig.DSASigR.bitLength = uint16(8 * len(sig.DSASigR.bytes))
sig.DSASigS.bytes = s.Bytes()
sig.DSASigS.bitLength = uint16(8 * len(sig.DSASigS.bytes))
}
default:
err = error.UnsupportedError("public key algorithm: " + strconv.Itoa(int(sig.PubKeyAlgo)))
}
return return
} }
// SignDSA signs a message with a DSA private key. The hash, h, must contain // SignUserId computes a signature from priv, asserting that pub is a valid
// the hash of the message to be signed and will be mutated by this function. // key for the identity id. On success, the signature is stored in sig. Call
// On success, the signature is stored in sig. Call Serialize to write it out. // Serialize to write it out.
func (sig *Signature) SignDSA(h hash.Hash, priv *dsa.PrivateKey) (err os.Error) { func (sig *Signature) SignUserId(id string, pub *PublicKey, priv *PrivateKey) os.Error {
digest, err := sig.signPrepareHash(h) h, err := userIdSignatureHash(id, pub, sig)
if err != nil { if err != nil {
return return nil
} }
sig.DSASigR, sig.DSASigS, err = dsa.Sign(rand.Reader, priv, digest) return sig.Sign(h, priv)
return }
// SignKey computes a signature from priv, asserting that pub is a subkey. On
// success, the signature is stored in sig. Call Serialize to write it out.
func (sig *Signature) SignKey(pub *PublicKey, priv *PrivateKey) os.Error {
h, err := keySignatureHash(&priv.PublicKey, pub, sig)
if err != nil {
return err
}
return sig.Sign(h, priv)
} }
// Serialize marshals sig to w. SignRSA or SignDSA must have been called first. // Serialize marshals sig to w. SignRSA or SignDSA must have been called first.
func (sig *Signature) Serialize(w io.Writer) (err os.Error) { func (sig *Signature) Serialize(w io.Writer) (err os.Error) {
if sig.RSASignature == nil && sig.DSASigR == nil { if len(sig.outSubpackets) == 0 {
sig.outSubpackets = sig.rawSubpackets
}
if sig.RSASignature.bytes == nil && sig.DSASigR.bytes == nil {
return error.InvalidArgumentError("Signature: need to call SignRSA or SignDSA before Serialize") return error.InvalidArgumentError("Signature: need to call SignRSA or SignDSA before Serialize")
} }
sigLength := 0 sigLength := 0
switch sig.PubKeyAlgo { switch sig.PubKeyAlgo {
case PubKeyAlgoRSA, PubKeyAlgoRSASignOnly: case PubKeyAlgoRSA, PubKeyAlgoRSASignOnly:
sigLength = len(sig.RSASignature) sigLength = 2 + len(sig.RSASignature.bytes)
case PubKeyAlgoDSA: case PubKeyAlgoDSA:
sigLength = 2 /* MPI length */ sigLength = 2 + len(sig.DSASigR.bytes)
sigLength += (sig.DSASigR.BitLen() + 7) / 8 sigLength += 2 + len(sig.DSASigS.bytes)
sigLength += 2 /* MPI length */
sigLength += (sig.DSASigS.BitLen() + 7) / 8
default: default:
panic("impossible") panic("impossible")
} }
@ -466,7 +496,7 @@ func (sig *Signature) Serialize(w io.Writer) (err os.Error) {
unhashedSubpacketsLen := subpacketsLength(sig.outSubpackets, false) unhashedSubpacketsLen := subpacketsLength(sig.outSubpackets, false)
length := len(sig.HashSuffix) - 6 /* trailer not included */ + length := len(sig.HashSuffix) - 6 /* trailer not included */ +
2 /* length of unhashed subpackets */ + unhashedSubpacketsLen + 2 /* length of unhashed subpackets */ + unhashedSubpacketsLen +
2 /* hash tag */ + 2 /* length of signature MPI */ + sigLength 2 /* hash tag */ + sigLength
err = serializeHeader(w, packetTypeSignature, length) err = serializeHeader(w, packetTypeSignature, length)
if err != nil { if err != nil {
return return
@ -493,12 +523,9 @@ func (sig *Signature) Serialize(w io.Writer) (err os.Error) {
switch sig.PubKeyAlgo { switch sig.PubKeyAlgo {
case PubKeyAlgoRSA, PubKeyAlgoRSASignOnly: case PubKeyAlgoRSA, PubKeyAlgoRSASignOnly:
err = writeMPI(w, 8*uint16(len(sig.RSASignature)), sig.RSASignature) err = writeMPIs(w, sig.RSASignature)
case PubKeyAlgoDSA: case PubKeyAlgoDSA:
err = writeBig(w, sig.DSASigR) err = writeMPIs(w, sig.DSASigR, sig.DSASigS)
if err == nil {
err = writeBig(w, sig.DSASigS)
}
default: default:
panic("impossible") panic("impossible")
} }
@ -509,6 +536,7 @@ func (sig *Signature) Serialize(w io.Writer) (err os.Error) {
type outputSubpacket struct { type outputSubpacket struct {
hashed bool // true if this subpacket is in the hashed area. hashed bool // true if this subpacket is in the hashed area.
subpacketType signatureSubpacketType subpacketType signatureSubpacketType
isCritical bool
contents []byte contents []byte
} }
@ -518,12 +546,12 @@ func (sig *Signature) buildSubpackets() (subpackets []outputSubpacket) {
creationTime[1] = byte(sig.CreationTime >> 16) creationTime[1] = byte(sig.CreationTime >> 16)
creationTime[2] = byte(sig.CreationTime >> 8) creationTime[2] = byte(sig.CreationTime >> 8)
creationTime[3] = byte(sig.CreationTime) creationTime[3] = byte(sig.CreationTime)
subpackets = append(subpackets, outputSubpacket{true, creationTimeSubpacket, creationTime}) subpackets = append(subpackets, outputSubpacket{true, creationTimeSubpacket, false, creationTime})
if sig.IssuerKeyId != nil { if sig.IssuerKeyId != nil {
keyId := make([]byte, 8) keyId := make([]byte, 8)
binary.BigEndian.PutUint64(keyId, *sig.IssuerKeyId) binary.BigEndian.PutUint64(keyId, *sig.IssuerKeyId)
subpackets = append(subpackets, outputSubpacket{true, issuerSubpacket, keyId}) subpackets = append(subpackets, outputSubpacket{true, issuerSubpacket, false, keyId})
} }
return return

View File

@ -12,9 +12,7 @@ import (
) )
func TestSignatureRead(t *testing.T) { func TestSignatureRead(t *testing.T) {
signatureData, _ := hex.DecodeString(signatureDataHex) packet, err := Read(readerFromHex(signatureDataHex))
buf := bytes.NewBuffer(signatureData)
packet, err := Read(buf)
if err != nil { if err != nil {
t.Error(err) t.Error(err)
return return
@ -25,4 +23,20 @@ func TestSignatureRead(t *testing.T) {
} }
} }
const signatureDataHex = "89011c04000102000605024cb45112000a0910ab105c91af38fb158f8d07ff5596ea368c5efe015bed6e78348c0f033c931d5f2ce5db54ce7f2a7e4b4ad64db758d65a7a71773edeab7ba2a9e0908e6a94a1175edd86c1d843279f045b021a6971a72702fcbd650efc393c5474d5b59a15f96d2eaad4c4c426797e0dcca2803ef41c6ff234d403eec38f31d610c344c06f2401c262f0993b2e66cad8a81ebc4322c723e0d4ba09fe917e8777658307ad8329adacba821420741009dfe87f007759f0982275d028a392c6ed983a0d846f890b36148c7358bdb8a516007fac760261ecd06076813831a36d0459075d1befa245ae7f7fb103d92ca759e9498fe60ef8078a39a3beda510deea251ea9f0a7f0df6ef42060f20780360686f3e400e" func TestSignatureReserialize(t *testing.T) {
packet, _ := Read(readerFromHex(signatureDataHex))
sig := packet.(*Signature)
out := new(bytes.Buffer)
err := sig.Serialize(out)
if err != nil {
t.Errorf("error reserializing: %s", err)
return
}
expected, _ := hex.DecodeString(signatureDataHex)
if !bytes.Equal(expected, out.Bytes()) {
t.Errorf("output doesn't match input (got vs expected):\n%s\n%s", hex.Dump(out.Bytes()), hex.Dump(expected))
}
}
const signatureDataHex = "c2c05c04000102000605024cb45112000a0910ab105c91af38fb158f8d07ff5596ea368c5efe015bed6e78348c0f033c931d5f2ce5db54ce7f2a7e4b4ad64db758d65a7a71773edeab7ba2a9e0908e6a94a1175edd86c1d843279f045b021a6971a72702fcbd650efc393c5474d5b59a15f96d2eaad4c4c426797e0dcca2803ef41c6ff234d403eec38f31d610c344c06f2401c262f0993b2e66cad8a81ebc4322c723e0d4ba09fe917e8777658307ad8329adacba821420741009dfe87f007759f0982275d028a392c6ed983a0d846f890b36148c7358bdb8a516007fac760261ecd06076813831a36d0459075d1befa245ae7f7fb103d92ca759e9498fe60ef8078a39a3beda510deea251ea9f0a7f0df6ef42060f20780360686f3e400e"

View File

@ -5,6 +5,7 @@
package packet package packet
import ( import (
"bytes"
"crypto/cipher" "crypto/cipher"
"crypto/openpgp/error" "crypto/openpgp/error"
"crypto/openpgp/s2k" "crypto/openpgp/s2k"
@ -27,6 +28,8 @@ type SymmetricKeyEncrypted struct {
encryptedKey []byte encryptedKey []byte
} }
const symmetricKeyEncryptedVersion = 4
func (ske *SymmetricKeyEncrypted) parse(r io.Reader) (err os.Error) { func (ske *SymmetricKeyEncrypted) parse(r io.Reader) (err os.Error) {
// RFC 4880, section 5.3. // RFC 4880, section 5.3.
var buf [2]byte var buf [2]byte
@ -34,12 +37,12 @@ func (ske *SymmetricKeyEncrypted) parse(r io.Reader) (err os.Error) {
if err != nil { if err != nil {
return return
} }
if buf[0] != 4 { if buf[0] != symmetricKeyEncryptedVersion {
return error.UnsupportedError("SymmetricKeyEncrypted version") return error.UnsupportedError("SymmetricKeyEncrypted version")
} }
ske.CipherFunc = CipherFunction(buf[1]) ske.CipherFunc = CipherFunction(buf[1])
if ske.CipherFunc.keySize() == 0 { if ske.CipherFunc.KeySize() == 0 {
return error.UnsupportedError("unknown cipher: " + strconv.Itoa(int(buf[1]))) return error.UnsupportedError("unknown cipher: " + strconv.Itoa(int(buf[1])))
} }
@ -75,7 +78,7 @@ func (ske *SymmetricKeyEncrypted) Decrypt(passphrase []byte) os.Error {
return nil return nil
} }
key := make([]byte, ske.CipherFunc.keySize()) key := make([]byte, ske.CipherFunc.KeySize())
ske.s2k(key, passphrase) ske.s2k(key, passphrase)
if len(ske.encryptedKey) == 0 { if len(ske.encryptedKey) == 0 {
@ -100,3 +103,60 @@ func (ske *SymmetricKeyEncrypted) Decrypt(passphrase []byte) os.Error {
ske.Encrypted = false ske.Encrypted = false
return nil return nil
} }
// SerializeSymmetricKeyEncrypted serializes a symmetric key packet to w. The
// packet contains a random session key, encrypted by a key derived from the
// given passphrase. The session key is returned and must be passed to
// SerializeSymmetricallyEncrypted.
func SerializeSymmetricKeyEncrypted(w io.Writer, rand io.Reader, passphrase []byte, cipherFunc CipherFunction) (key []byte, err os.Error) {
keySize := cipherFunc.KeySize()
if keySize == 0 {
return nil, error.UnsupportedError("unknown cipher: " + strconv.Itoa(int(cipherFunc)))
}
s2kBuf := new(bytes.Buffer)
keyEncryptingKey := make([]byte, keySize)
// s2k.Serialize salts and stretches the passphrase, and writes the
// resulting key to keyEncryptingKey and the s2k descriptor to s2kBuf.
err = s2k.Serialize(s2kBuf, keyEncryptingKey, rand, passphrase)
if err != nil {
return
}
s2kBytes := s2kBuf.Bytes()
packetLength := 2 /* header */ + len(s2kBytes) + 1 /* cipher type */ + keySize
err = serializeHeader(w, packetTypeSymmetricKeyEncrypted, packetLength)
if err != nil {
return
}
var buf [2]byte
buf[0] = symmetricKeyEncryptedVersion
buf[1] = byte(cipherFunc)
_, err = w.Write(buf[:])
if err != nil {
return
}
_, err = w.Write(s2kBytes)
if err != nil {
return
}
sessionKey := make([]byte, keySize)
_, err = io.ReadFull(rand, sessionKey)
if err != nil {
return
}
iv := make([]byte, cipherFunc.blockSize())
c := cipher.NewCFBEncrypter(cipherFunc.new(keyEncryptingKey), iv)
encryptedCipherAndKey := make([]byte, keySize+1)
c.XORKeyStream(encryptedCipherAndKey, buf[1:])
c.XORKeyStream(encryptedCipherAndKey[1:], sessionKey)
_, err = w.Write(encryptedCipherAndKey)
if err != nil {
return
}
key = sessionKey
return
}
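
The doc comment above describes a two-step write path; here is a minimal sketch of how the two serializers are meant to compose (buffer name and passphrase are illustrative, not taken from the diff):

package main

import (
	"bytes"
	"crypto/openpgp/packet"
	"crypto/rand"
)

func main() {
	out := new(bytes.Buffer)
	// Step 1: write the symmetric-key packet and get back the session key.
	key, err := packet.SerializeSymmetricKeyEncrypted(out, rand.Reader, []byte("passphrase"), packet.CipherAES128)
	if err != nil {
		panic(err)
	}
	// Step 2: pass that key to SerializeSymmetricallyEncrypted and write the
	// to-be-encrypted packets (in real use, a literal data packet) to the
	// returned WriteCloser.
	plaintext, err := packet.SerializeSymmetricallyEncrypted(out, packet.CipherAES128, key)
	if err != nil {
		panic(err)
	}
	plaintext.Write([]byte("hello"))
	plaintext.Close()
}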

View File

@ -6,6 +6,7 @@ package packet
import ( import (
"bytes" "bytes"
"crypto/rand"
"encoding/hex" "encoding/hex"
"io/ioutil" "io/ioutil"
"os" "os"
@ -60,3 +61,41 @@ func TestSymmetricKeyEncrypted(t *testing.T) {
const symmetricallyEncryptedHex = "8c0d04030302371a0b38d884f02060c91cf97c9973b8e58e028e9501708ccfe618fb92afef7fa2d80ddadd93cf" const symmetricallyEncryptedHex = "8c0d04030302371a0b38d884f02060c91cf97c9973b8e58e028e9501708ccfe618fb92afef7fa2d80ddadd93cf"
const symmetricallyEncryptedContentsHex = "cb1062004d14c4df636f6e74656e74732e0a" const symmetricallyEncryptedContentsHex = "cb1062004d14c4df636f6e74656e74732e0a"
func TestSerializeSymmetricKeyEncrypted(t *testing.T) {
buf := bytes.NewBuffer(nil)
passphrase := []byte("testing")
cipherFunc := CipherAES128
key, err := SerializeSymmetricKeyEncrypted(buf, rand.Reader, passphrase, cipherFunc)
if err != nil {
t.Errorf("failed to serialize: %s", err)
return
}
p, err := Read(buf)
if err != nil {
t.Errorf("failed to reparse: %s", err)
return
}
ske, ok := p.(*SymmetricKeyEncrypted)
if !ok {
t.Errorf("parsed a different packet type: %#v", p)
return
}
if !ske.Encrypted {
t.Errorf("SKE not encrypted but should be")
}
if ske.CipherFunc != cipherFunc {
t.Errorf("SKE cipher function is %d (expected %d)", ske.CipherFunc, cipherFunc)
}
err = ske.Decrypt(passphrase)
if err != nil {
t.Errorf("failed to decrypt reparsed SKE: %s", err)
return
}
if !bytes.Equal(key, ske.Key) {
t.Errorf("keys don't match after Decrpyt: %x (original) vs %x (parsed)", key, ske.Key)
}
}

View File

@ -7,6 +7,7 @@ package packet
import ( import (
"crypto/cipher" "crypto/cipher"
"crypto/openpgp/error" "crypto/openpgp/error"
"crypto/rand"
"crypto/sha1" "crypto/sha1"
"crypto/subtle" "crypto/subtle"
"hash" "hash"
@ -24,6 +25,8 @@ type SymmetricallyEncrypted struct {
prefix []byte prefix []byte
} }
const symmetricallyEncryptedVersion = 1
func (se *SymmetricallyEncrypted) parse(r io.Reader) os.Error { func (se *SymmetricallyEncrypted) parse(r io.Reader) os.Error {
if se.MDC { if se.MDC {
// See RFC 4880, section 5.13. // See RFC 4880, section 5.13.
@ -32,7 +35,7 @@ func (se *SymmetricallyEncrypted) parse(r io.Reader) os.Error {
if err != nil { if err != nil {
return err return err
} }
if buf[0] != 1 { if buf[0] != symmetricallyEncryptedVersion {
return error.UnsupportedError("unknown SymmetricallyEncrypted version") return error.UnsupportedError("unknown SymmetricallyEncrypted version")
} }
} }
@ -44,7 +47,7 @@ func (se *SymmetricallyEncrypted) parse(r io.Reader) os.Error {
// packet can be read. An incorrect key can, with high probability, be detected // packet can be read. An incorrect key can, with high probability, be detected
// immediately and this will result in a KeyIncorrect error being returned. // immediately and this will result in a KeyIncorrect error being returned.
func (se *SymmetricallyEncrypted) Decrypt(c CipherFunction, key []byte) (io.ReadCloser, os.Error) { func (se *SymmetricallyEncrypted) Decrypt(c CipherFunction, key []byte) (io.ReadCloser, os.Error) {
keySize := c.keySize() keySize := c.KeySize()
if keySize == 0 { if keySize == 0 {
return nil, error.UnsupportedError("unknown cipher: " + strconv.Itoa(int(c))) return nil, error.UnsupportedError("unknown cipher: " + strconv.Itoa(int(c)))
} }
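
For the read side that the comment above alludes to, a minimal sketch (an illustrative helper in its own package, assuming only what this diff shows; the concrete error type for the incorrect-key case is not restated here):

package example

import (
	"crypto/openpgp/packet"
	"io"
	"os"
)

// decryptTo reads one packet from r, expects it to be symmetrically
// encrypted, and streams the decrypted contents to dst using candidateKey.
func decryptTo(dst io.Writer, r io.Reader, candidateKey []byte) os.Error {
	p, err := packet.Read(r)
	if err != nil {
		return err
	}
	se, ok := p.(*packet.SymmetricallyEncrypted)
	if !ok {
		return os.NewError("not a symmetrically encrypted packet")
	}
	contents, err := se.Decrypt(packet.CipherAES128, candidateKey)
	if err != nil {
		return err // with high probability, the incorrect-key case
	}
	if _, err := io.Copy(dst, contents); err != nil {
		return err
	}
	return contents.Close() // Close verifies the MDC trailer
}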
@ -174,6 +177,9 @@ func (ser *seMDCReader) Read(buf []byte) (n int, err os.Error) {
return return
} }
// This is a new-format packet tag byte for a type 19 (MDC) packet.
const mdcPacketTagByte = byte(0x80) | 0x40 | 19
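
(For reference, my reading of the constant above: 0x80 is the always-set header bit, 0x40 selects the new packet format, and 19 is the MDC packet tag, so mdcPacketTagByte works out to 0xd3, the first trailer byte that Close below checks for.)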
func (ser *seMDCReader) Close() os.Error { func (ser *seMDCReader) Close() os.Error {
if ser.error { if ser.error {
return error.SignatureError("error during reading") return error.SignatureError("error during reading")
@ -191,16 +197,95 @@ func (ser *seMDCReader) Close() os.Error {
} }
} }
// This is a new-format packet tag byte for a type 19 (MDC) packet.
const mdcPacketTagByte = byte(0x80) | 0x40 | 19
if ser.trailer[0] != mdcPacketTagByte || ser.trailer[1] != sha1.Size { if ser.trailer[0] != mdcPacketTagByte || ser.trailer[1] != sha1.Size {
return error.SignatureError("MDC packet not found") return error.SignatureError("MDC packet not found")
} }
ser.h.Write(ser.trailer[:2]) ser.h.Write(ser.trailer[:2])
final := ser.h.Sum() final := ser.h.Sum()
if subtle.ConstantTimeCompare(final, ser.trailer[2:]) == 1 { if subtle.ConstantTimeCompare(final, ser.trailer[2:]) != 1 {
return error.SignatureError("hash mismatch") return error.SignatureError("hash mismatch")
} }
return nil return nil
} }
// An seMDCWriter writes through to an io.WriteCloser while maintaining a running
// hash of the data written. On close, it emits an MDC packet containing the
// running hash.
type seMDCWriter struct {
w io.WriteCloser
h hash.Hash
}
func (w *seMDCWriter) Write(buf []byte) (n int, err os.Error) {
w.h.Write(buf)
return w.w.Write(buf)
}
func (w *seMDCWriter) Close() (err os.Error) {
var buf [mdcTrailerSize]byte
buf[0] = mdcPacketTagByte
buf[1] = sha1.Size
w.h.Write(buf[:2])
digest := w.h.Sum()
copy(buf[2:], digest)
_, err = w.w.Write(buf[:])
if err != nil {
return
}
return w.w.Close()
}
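
A quick illustration of the trailer this writer emits, written as a hypothetical in-package test (the type is unexported), assuming the usual bytes, crypto/sha1 and testing imports and that mdcTrailerSize is 2+sha1.Size as the buffer above implies:

func TestSEMDCWriterTrailer(t *testing.T) {
	buf := bytes.NewBuffer(nil)
	w := &seMDCWriter{w: noOpCloser{buf}, h: sha1.New()}
	w.Write([]byte("data"))
	w.Close()
	// buf now holds "data" followed by the MDC packet: 0xd3, 0x14, then the
	// SHA-1 of "data" plus those two trailer bytes.
	trailer := buf.Bytes()[len("data"):]
	if len(trailer) != mdcTrailerSize || trailer[0] != mdcPacketTagByte || trailer[1] != sha1.Size {
		t.Errorf("unexpected MDC trailer: %x", trailer)
	}
}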
// noOpCloser is like an ioutil.NopCloser, but for an io.Writer.
type noOpCloser struct {
w io.Writer
}
func (c noOpCloser) Write(data []byte) (n int, err os.Error) {
return c.w.Write(data)
}
func (c noOpCloser) Close() os.Error {
return nil
}
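
Its only job is interface adaptation, presumably so that serializeStreamHeader below can be handed the caller's io.Writer as an io.WriteCloser without that writer ever being closed. An illustrative one-liner:

var wc io.WriteCloser = noOpCloser{os.Stdout} // Write goes to stdout, Close is a no-op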
// SerializeSymmetricallyEncrypted serializes a symmetrically encrypted packet
// to w and returns a WriteCloser to which the to-be-encrypted packets can be
// written.
func SerializeSymmetricallyEncrypted(w io.Writer, c CipherFunction, key []byte) (contents io.WriteCloser, err os.Error) {
if c.KeySize() != len(key) {
return nil, error.InvalidArgumentError("SymmetricallyEncrypted.Serialize: bad key length")
}
writeCloser := noOpCloser{w}
ciphertext, err := serializeStreamHeader(writeCloser, packetTypeSymmetricallyEncryptedMDC)
if err != nil {
return
}
_, err = ciphertext.Write([]byte{symmetricallyEncryptedVersion})
if err != nil {
return
}
block := c.new(key)
blockSize := block.BlockSize()
iv := make([]byte, blockSize)
_, err = rand.Reader.Read(iv)
if err != nil {
return
}
s, prefix := cipher.NewOCFBEncrypter(block, iv, cipher.OCFBNoResync)
_, err = ciphertext.Write(prefix)
if err != nil {
return
}
plaintext := cipher.StreamWriter{S: s, W: ciphertext}
h := sha1.New()
h.Write(iv)
h.Write(iv[blockSize-2:])
contents = &seMDCWriter{w: plaintext, h: h}
return
}

View File

@ -9,6 +9,7 @@ import (
"crypto/openpgp/error" "crypto/openpgp/error"
"crypto/sha1" "crypto/sha1"
"encoding/hex" "encoding/hex"
"io"
"io/ioutil" "io/ioutil"
"os" "os"
"testing" "testing"
@ -76,3 +77,48 @@ func testMDCReader(t *testing.T) {
} }
const mdcPlaintextHex = "a302789c3b2d93c4e0eb9aba22283539b3203335af44a134afb800c849cb4c4de10200aff40b45d31432c80cb384299a0655966d6939dfdeed1dddf980" const mdcPlaintextHex = "a302789c3b2d93c4e0eb9aba22283539b3203335af44a134afb800c849cb4c4de10200aff40b45d31432c80cb384299a0655966d6939dfdeed1dddf980"
func TestSerialize(t *testing.T) {
buf := bytes.NewBuffer(nil)
c := CipherAES128
key := make([]byte, c.KeySize())
w, err := SerializeSymmetricallyEncrypted(buf, c, key)
if err != nil {
t.Errorf("error from SerializeSymmetricallyEncrypted: %s", err)
return
}
contents := []byte("hello world\n")
w.Write(contents)
w.Close()
p, err := Read(buf)
if err != nil {
t.Errorf("error from Read: %s", err)
return
}
se, ok := p.(*SymmetricallyEncrypted)
if !ok {
t.Errorf("didn't read a *SymmetricallyEncrypted")
return
}
r, err := se.Decrypt(c, key)
if err != nil {
t.Errorf("error from Decrypt: %s", err)
return
}
contentsCopy := bytes.NewBuffer(nil)
_, err = io.Copy(contentsCopy, r)
if err != nil {
t.Errorf("error from io.Copy: %s", err)
return
}
if !bytes.Equal(contentsCopy.Bytes(), contents) {
t.Errorf("contents not equal got: %x want: %x", contentsCopy.Bytes(), contents)
}
}

Some files were not shown because too many files have changed in this diff