]> gcc.gnu.org Git - gcc.git/commitdiff
libgo: Update to weekly.2012-01-15.
authorIan Lance Taylor <ian@gcc.gnu.org>
Wed, 25 Jan 2012 20:56:26 +0000 (20:56 +0000)
committerIan Lance Taylor <ian@gcc.gnu.org>
Wed, 25 Jan 2012 20:56:26 +0000 (20:56 +0000)
From-SVN: r183539

192 files changed:
libgo/MERGE
libgo/Makefile.am
libgo/Makefile.in
libgo/config.h.in
libgo/configure
libgo/configure.ac
libgo/go/bytes/buffer.go
libgo/go/bytes/buffer_test.go
libgo/go/crypto/openpgp/armor/armor.go
libgo/go/crypto/openpgp/errors/errors.go [moved from libgo/go/crypto/openpgp/error/error.go with 94% similarity]
libgo/go/crypto/openpgp/keys.go
libgo/go/crypto/openpgp/packet/compressed.go
libgo/go/crypto/openpgp/packet/encrypted_key.go
libgo/go/crypto/openpgp/packet/one_pass_signature.go
libgo/go/crypto/openpgp/packet/packet.go
libgo/go/crypto/openpgp/packet/packet_test.go
libgo/go/crypto/openpgp/packet/private_key.go
libgo/go/crypto/openpgp/packet/public_key.go
libgo/go/crypto/openpgp/packet/reader.go
libgo/go/crypto/openpgp/packet/signature.go
libgo/go/crypto/openpgp/packet/symmetric_key_encrypted.go
libgo/go/crypto/openpgp/packet/symmetrically_encrypted.go
libgo/go/crypto/openpgp/packet/symmetrically_encrypted_test.go
libgo/go/crypto/openpgp/read.go
libgo/go/crypto/openpgp/read_test.go
libgo/go/crypto/openpgp/s2k/s2k.go
libgo/go/crypto/openpgp/write.go
libgo/go/crypto/openpgp/write_test.go
libgo/go/crypto/tls/common.go
libgo/go/crypto/tls/generate_cert.go
libgo/go/crypto/tls/handshake_client.go
libgo/go/crypto/tls/handshake_messages.go
libgo/go/crypto/tls/handshake_server.go
libgo/go/crypto/tls/handshake_server_test.go
libgo/go/crypto/tls/tls.go
libgo/go/crypto/x509/cert_pool.go
libgo/go/debug/gosym/pclntab_test.go
libgo/go/encoding/asn1/asn1.go
libgo/go/encoding/asn1/asn1_test.go
libgo/go/encoding/asn1/marshal_test.go
libgo/go/encoding/gob/decode.go
libgo/go/encoding/gob/encoder_test.go
libgo/go/encoding/json/decode.go
libgo/go/encoding/json/decode_test.go
libgo/go/encoding/json/encode.go
libgo/go/encoding/json/encode_test.go
libgo/go/encoding/xml/atom_test.go
libgo/go/encoding/xml/embed_test.go [deleted file]
libgo/go/encoding/xml/marshal.go
libgo/go/encoding/xml/marshal_test.go
libgo/go/encoding/xml/read.go
libgo/go/encoding/xml/read_test.go
libgo/go/encoding/xml/typeinfo.go [new file with mode: 0644]
libgo/go/encoding/xml/xml_test.go
libgo/go/exp/norm/input.go
libgo/go/exp/norm/normalize.go
libgo/go/exp/norm/normalize_test.go
libgo/go/exp/norm/readwriter.go
libgo/go/exp/proxy/direct.go [new file with mode: 0644]
libgo/go/exp/proxy/per_host.go [new file with mode: 0644]
libgo/go/exp/proxy/per_host_test.go [new file with mode: 0644]
libgo/go/exp/proxy/proxy.go [new file with mode: 0644]
libgo/go/exp/proxy/proxy_test.go [new file with mode: 0644]
libgo/go/exp/proxy/socks5.go [new file with mode: 0644]
libgo/go/exp/sql/convert_test.go
libgo/go/exp/sql/driver/driver.go
libgo/go/exp/sql/driver/types.go
libgo/go/exp/sql/driver/types_test.go
libgo/go/exp/sql/fakedb_test.go
libgo/go/exp/sql/sql.go
libgo/go/exp/sql/sql_test.go
libgo/go/exp/ssh/client.go
libgo/go/exp/ssh/doc.go
libgo/go/exp/ssh/server_shell.go [deleted file]
libgo/go/exp/ssh/server_shell_test.go [deleted file]
libgo/go/exp/ssh/server_terminal.go [new file with mode: 0644]
libgo/go/exp/ssh/session_test.go
libgo/go/exp/ssh/transport.go
libgo/go/exp/terminal/terminal.go
libgo/go/exp/terminal/terminal_test.go
libgo/go/exp/types/check_test.go
libgo/go/exp/types/universe.go
libgo/go/flag/flag.go
libgo/go/flag/flag_test.go
libgo/go/fmt/doc.go
libgo/go/fmt/fmt_test.go
libgo/go/go/ast/ast.go
libgo/go/go/ast/filter.go
libgo/go/go/ast/print.go
libgo/go/go/ast/print_test.go
libgo/go/go/ast/scope.go
libgo/go/go/build/build.go
libgo/go/go/build/dir.go
libgo/go/go/build/path.go
libgo/go/go/doc/comment.go
libgo/go/go/doc/doc.go
libgo/go/go/doc/doc_test.go [new file with mode: 0644]
libgo/go/go/doc/example.go
libgo/go/go/doc/exports.go
libgo/go/go/doc/filter.go
libgo/go/go/doc/headscan.go
libgo/go/go/doc/reader.go [new file with mode: 0644]
libgo/go/go/parser/interface.go
libgo/go/go/parser/parser.go
libgo/go/go/parser/parser_test.go
libgo/go/go/printer/printer.go
libgo/go/go/printer/printer_test.go
libgo/go/go/scanner/scanner.go
libgo/go/go/scanner/scanner_test.go
libgo/go/go/token/token.go
libgo/go/html/foreign.go
libgo/go/html/node.go
libgo/go/html/parse.go
libgo/go/html/parse_test.go
libgo/go/html/render.go
libgo/go/html/template/escape_test.go
libgo/go/html/token.go
libgo/go/image/names.go
libgo/go/log/syslog/syslog.go
libgo/go/math/all_test.go
libgo/go/math/big/nat.go
libgo/go/net/file_test.go
libgo/go/net/http/cookie.go
libgo/go/net/http/request.go
libgo/go/net/interface.go
libgo/go/net/iprawsock_plan9.go
libgo/go/net/iprawsock_posix.go
libgo/go/net/ipsock_posix.go
libgo/go/net/multicast_test.go
libgo/go/net/rpc/server_test.go
libgo/go/net/server_test.go
libgo/go/net/sock.go
libgo/go/net/sock_bsd.go
libgo/go/net/sock_linux.go
libgo/go/net/sock_windows.go
libgo/go/net/sockopt.go [new file with mode: 0644]
libgo/go/net/sockopt_bsd.go [new file with mode: 0644]
libgo/go/net/sockopt_linux.go [new file with mode: 0644]
libgo/go/net/sockopt_windows.go [new file with mode: 0644]
libgo/go/net/sockoptip.go [new file with mode: 0644]
libgo/go/net/sockoptip_bsd.go [new file with mode: 0644]
libgo/go/net/sockoptip_darwin.go [new file with mode: 0644]
libgo/go/net/sockoptip_freebsd.go [new file with mode: 0644]
libgo/go/net/sockoptip_linux.go [new file with mode: 0644]
libgo/go/net/sockoptip_openbsd.go [new file with mode: 0644]
libgo/go/net/sockoptip_windows.go [new file with mode: 0644]
libgo/go/net/tcpsock_posix.go
libgo/go/net/textproto/reader.go
libgo/go/net/textproto/reader_test.go
libgo/go/net/udpsock_posix.go
libgo/go/net/unicast_test.go [new file with mode: 0644]
libgo/go/net/unixsock_posix.go
libgo/go/os/env_test.go
libgo/go/os/os_test.go
libgo/go/os/os_unix_test.go
libgo/go/os/path_test.go
libgo/go/os/stat_openbsd.go
libgo/go/os/types.go
libgo/go/runtime/debug.go
libgo/go/runtime/extern.go
libgo/go/sort/sort.go
libgo/go/strconv/extfloat.go
libgo/go/strconv/fp_test.go
libgo/go/strconv/ftoa.go
libgo/go/strconv/ftoa_test.go
libgo/go/strconv/quote.go
libgo/go/strconv/quote_test.go
libgo/go/syscall/env_unix.go
libgo/go/syscall/exec_bsd.go [new file with mode: 0644]
libgo/go/syscall/exec_linux.go [new file with mode: 0644]
libgo/go/syscall/exec_unix.go
libgo/go/syscall/socket.go
libgo/go/syscall/syscall_unix.go
libgo/go/testing/benchmark.go
libgo/go/testing/testing.go
libgo/go/testing/wrapper.go [deleted file]
libgo/go/text/template/doc.go
libgo/go/text/template/exec.go
libgo/go/text/template/exec_test.go
libgo/go/time/format.go
libgo/go/time/sleep.go
libgo/go/time/tick.go
libgo/go/time/time_test.go
libgo/mksysinfo.sh
libgo/runtime/malloc.goc
libgo/runtime/malloc.h
libgo/runtime/mgc0.c
libgo/runtime/runtime.c
libgo/runtime/runtime.h
libgo/runtime/runtime1.goc
libgo/runtime/thread-linux.c
libgo/testsuite/gotest

index 96fb7f66498e0cd28d80c714a48bbd242a2c2302..b72962fecbebc3d7d24359010e5b586e916bb56b 100644 (file)
@@ -1,4 +1,4 @@
-4a8268927758
+354b17404643
 
 The first line of this file holds the Mercurial revision number of the
 last merge done from the master library sources.
index 348a1cae8d2c4b570654372909ce34b08c32d9a1..770a849e7446c979d4b573e21398735444699f3d 100644 (file)
@@ -188,7 +188,7 @@ toolexeclibgocryptoopenpgpdir = $(toolexeclibgocryptodir)/openpgp
 toolexeclibgocryptoopenpgp_DATA = \
        crypto/openpgp/armor.gox \
        crypto/openpgp/elgamal.gox \
-       crypto/openpgp/error.gox \
+       crypto/openpgp/errors.gox \
        crypto/openpgp/packet.gox \
        crypto/openpgp/s2k.gox
 
@@ -235,6 +235,7 @@ toolexeclibgoexp_DATA = \
        exp/ebnf.gox \
        $(exp_inotify_gox) \
        exp/norm.gox \
+       exp/proxy.gox \
        exp/spdy.gox \
        exp/sql.gox \
        exp/ssh.gox \
@@ -669,17 +670,25 @@ endif # !LIBGO_IS_RTEMS
 if LIBGO_IS_LINUX
 go_net_cgo_file = go/net/cgo_linux.go
 go_net_sock_file = go/net/sock_linux.go
+go_net_sockopt_file = go/net/sockopt_linux.go
+go_net_sockoptip_file = go/net/sockoptip_linux.go
 else
 if LIBGO_IS_IRIX
 go_net_cgo_file = go/net/cgo_linux.go
 go_net_sock_file = go/net/sock_linux.go
+go_net_sockopt_file = go/net/sockopt_linux.go
+go_net_sockoptip_file = go/net/sockoptip_linux.go
 else
 if LIBGO_IS_SOLARIS
 go_net_cgo_file = go/net/cgo_linux.go
 go_net_sock_file = go/net/sock_linux.go
+go_net_sockopt_file = go/net/sockopt_linux.go
+go_net_sockoptip_file = go/net/sockoptip_linux.go
 else
 go_net_cgo_file = go/net/cgo_bsd.go
 go_net_sock_file = go/net/sock_bsd.go
+go_net_sockopt_file = go/net/sockopt_bsd.go
+go_net_sockoptip_file = go/net/sockoptip_bsd.go
 endif
 endif
 endif
@@ -728,6 +737,10 @@ go_net_files = \
        $(go_net_sendfile_file) \
        go/net/sock.go \
        $(go_net_sock_file) \
+       go/net/sockopt.go \
+       $(go_net_sockopt_file) \
+       go/net/sockoptip.go \
+       $(go_net_sockoptip_file) \
        go/net/tcpsock.go \
        go/net/tcpsock_posix.go \
        go/net/udpsock.go \
@@ -890,8 +903,7 @@ go_syslog_c_files = \
 go_testing_files = \
        go/testing/benchmark.go \
        go/testing/example.go \
-       go/testing/testing.go \
-       go/testing/wrapper.go
+       go/testing/testing.go
 
 go_time_files = \
        go/time/format.go \
@@ -1061,8 +1073,8 @@ go_crypto_openpgp_armor_files = \
        go/crypto/openpgp/armor/encode.go
 go_crypto_openpgp_elgamal_files = \
        go/crypto/openpgp/elgamal/elgamal.go
-go_crypto_openpgp_error_files = \
-       go/crypto/openpgp/error/error.go
+go_crypto_openpgp_errors_files = \
+       go/crypto/openpgp/errors/errors.go
 go_crypto_openpgp_packet_files = \
        go/crypto/openpgp/packet/compressed.go \
        go/crypto/openpgp/packet/encrypted_key.go \
@@ -1142,6 +1154,7 @@ go_encoding_pem_files = \
 go_encoding_xml_files = \
        go/encoding/xml/marshal.go \
        go/encoding/xml/read.go \
+       go/encoding/xml/typeinfo.go \
        go/encoding/xml/xml.go
 
 go_exp_ebnf_files = \
@@ -1157,6 +1170,11 @@ go_exp_norm_files = \
        go/exp/norm/readwriter.go \
        go/exp/norm/tables.go \
        go/exp/norm/trie.go
+go_exp_proxy_files = \
+       go/exp/proxy/direct.go \
+       go/exp/proxy/per_host.go \
+       go/exp/proxy/proxy.go \
+       go/exp/proxy/socks5.go
 go_exp_spdy_files = \
        go/exp/spdy/read.go \
        go/exp/spdy/types.go \
@@ -1173,7 +1191,7 @@ go_exp_ssh_files = \
        go/exp/ssh/doc.go \
        go/exp/ssh/messages.go \
        go/exp/ssh/server.go \
-       go/exp/ssh/server_shell.go \
+       go/exp/ssh/server_terminal.go \
        go/exp/ssh/session.go \
        go/exp/ssh/tcpip.go \
        go/exp/ssh/transport.go
@@ -1210,7 +1228,8 @@ go_go_doc_files = \
        go/go/doc/doc.go \
        go/go/doc/example.go \
        go/go/doc/exports.go \
-       go/go/doc/filter.go
+       go/go/doc/filter.go \
+       go/go/doc/reader.go
 go_go_parser_files = \
        go/go/parser/interface.go \
        go/go/parser/parser.go
@@ -1461,8 +1480,15 @@ endif
 # Define ForkExec and Exec.
 if LIBGO_IS_RTEMS
 syscall_exec_file = go/syscall/exec_stubs.go
+syscall_exec_os_file =
+else
+if LIBGO_IS_LINUX
+syscall_exec_file = go/syscall/exec_unix.go
+syscall_exec_os_file = go/syscall/exec_linux.go
 else
 syscall_exec_file = go/syscall/exec_unix.go
+syscall_exec_os_file = go/syscall/exec_bsd.go
+endif
 endif
 
 # Define Wait4.
@@ -1573,6 +1599,7 @@ go_base_syscall_files = \
        go/syscall/syscall.go \
        $(syscall_syscall_file) \
        $(syscall_exec_file) \
+       $(syscall_exec_os_file) \
        $(syscall_wait_file) \
        $(syscall_sleep_file) \
        $(syscall_errstr_file) \
@@ -1720,7 +1747,7 @@ libgo_go_objs = \
        crypto/xtea.lo \
        crypto/openpgp/armor.lo \
        crypto/openpgp/elgamal.lo \
-       crypto/openpgp/error.lo \
+       crypto/openpgp/errors.lo \
        crypto/openpgp/packet.lo \
        crypto/openpgp/s2k.lo \
        crypto/x509/pkix.lo \
@@ -1743,6 +1770,7 @@ libgo_go_objs = \
        encoding/xml.lo \
        exp/ebnf.lo \
        exp/norm.lo \
+       exp/proxy.lo \
        exp/spdy.lo \
        exp/sql.lo \
        exp/ssh.lo \
@@ -2578,15 +2606,15 @@ crypto/openpgp/elgamal/check: $(CHECK_DEPS)
        @$(CHECK)
 .PHONY: crypto/openpgp/elgamal/check
 
-@go_include@ crypto/openpgp/error.lo.dep
-crypto/openpgp/error.lo.dep: $(go_crypto_openpgp_error_files)
+@go_include@ crypto/openpgp/errors.lo.dep
+crypto/openpgp/errors.lo.dep: $(go_crypto_openpgp_errors_files)
        $(BUILDDEPS)
-crypto/openpgp/error.lo: $(go_crypto_openpgp_error_files)
+crypto/openpgp/errors.lo: $(go_crypto_openpgp_errors_files)
        $(BUILDPACKAGE)
-crypto/openpgp/error/check: $(CHECK_DEPS)
-       @$(MKDIR_P) crypto/openpgp/error
+crypto/openpgp/errors/check: $(CHECK_DEPS)
+       @$(MKDIR_P) crypto/openpgp/errors
        @$(CHECK)
-.PHONY: crypto/openpgp/error/check
+.PHONY: crypto/openpgp/errors/check
 
 @go_include@ crypto/openpgp/packet.lo.dep
 crypto/openpgp/packet.lo.dep: $(go_crypto_openpgp_packet_files)
@@ -2808,6 +2836,16 @@ exp/norm/check: $(CHECK_DEPS)
        @$(CHECK)
 .PHONY: exp/norm/check
 
+@go_include@ exp/proxy.lo.dep
+exp/proxy.lo.dep: $(go_exp_proxy_files)
+       $(BUILDDEPS)
+exp/proxy.lo: $(go_exp_proxy_files)
+       $(BUILDPACKAGE)
+exp/proxy/check: $(CHECK_DEPS)
+       @$(MKDIR_P) exp/proxy
+       @$(CHECK)
+.PHONY: exp/proxy/check
+
 @go_include@ exp/spdy.lo.dep
 exp/spdy.lo.dep: $(go_exp_spdy_files)
        $(BUILDDEPS)
@@ -3622,7 +3660,7 @@ crypto/openpgp/armor.gox: crypto/openpgp/armor.lo
        $(BUILDGOX)
 crypto/openpgp/elgamal.gox: crypto/openpgp/elgamal.lo
        $(BUILDGOX)
-crypto/openpgp/error.gox: crypto/openpgp/error.lo
+crypto/openpgp/errors.gox: crypto/openpgp/errors.lo
        $(BUILDGOX)
 crypto/openpgp/packet.gox: crypto/openpgp/packet.lo
        $(BUILDGOX)
@@ -3674,6 +3712,8 @@ exp/inotify.gox: exp/inotify.lo
        $(BUILDGOX)
 exp/norm.gox: exp/norm.lo
        $(BUILDGOX)
+exp/proxy.gox: exp/proxy.lo
+       $(BUILDGOX)
 exp/spdy.gox: exp/spdy.lo
        $(BUILDGOX)
 exp/sql.gox: exp/sql.lo
@@ -3920,6 +3960,7 @@ TEST_PACKAGES = \
        exp/ebnf/check \
        $(exp_inotify_check) \
        exp/norm/check \
+       exp/proxy/check \
        exp/spdy/check \
        exp/sql/check \
        exp/ssh/check \
index 6bf18475628d9d50994ab43f89a6e7f26c0addd3..b82bf422e5566a43df4efe8cc2873ef5e4d0a830 100644 (file)
@@ -153,33 +153,34 @@ am__DEPENDENCIES_2 = bufio/bufio.lo bytes/bytes.lo bytes/index.lo \
        crypto/sha256.lo crypto/sha512.lo crypto/subtle.lo \
        crypto/tls.lo crypto/twofish.lo crypto/x509.lo crypto/xtea.lo \
        crypto/openpgp/armor.lo crypto/openpgp/elgamal.lo \
-       crypto/openpgp/error.lo crypto/openpgp/packet.lo \
+       crypto/openpgp/errors.lo crypto/openpgp/packet.lo \
        crypto/openpgp/s2k.lo crypto/x509/pkix.lo debug/dwarf.lo \
        debug/elf.lo debug/gosym.lo debug/macho.lo debug/pe.lo \
        encoding/ascii85.lo encoding/asn1.lo encoding/base32.lo \
        encoding/base64.lo encoding/binary.lo encoding/csv.lo \
        encoding/git85.lo encoding/gob.lo encoding/hex.lo \
        encoding/json.lo encoding/pem.lo encoding/xml.lo exp/ebnf.lo \
-       exp/norm.lo exp/spdy.lo exp/sql.lo exp/ssh.lo exp/terminal.lo \
-       exp/types.lo exp/sql/driver.lo html/template.lo go/ast.lo \
-       go/build.lo go/doc.lo go/parser.lo go/printer.lo go/scanner.lo \
-       go/token.lo hash/adler32.lo hash/crc32.lo hash/crc64.lo \
-       hash/fnv.lo net/http/cgi.lo net/http/fcgi.lo \
-       net/http/httptest.lo net/http/httputil.lo net/http/pprof.lo \
-       image/bmp.lo image/color.lo image/draw.lo image/gif.lo \
-       image/jpeg.lo image/png.lo image/tiff.lo index/suffixarray.lo \
-       io/ioutil.lo log/syslog.lo log/syslog/syslog_c.lo math/big.lo \
-       math/cmplx.lo math/rand.lo mime/mime.lo mime/multipart.lo \
-       net/dict.lo net/http.lo net/mail.lo net/rpc.lo net/smtp.lo \
-       net/textproto.lo net/url.lo old/netchan.lo old/regexp.lo \
-       old/template.lo $(am__DEPENDENCIES_1) os/user.lo os/signal.lo \
-       path/filepath.lo regexp/syntax.lo net/rpc/jsonrpc.lo \
-       runtime/debug.lo runtime/pprof.lo sync/atomic.lo \
-       sync/atomic_c.lo syscall/syscall.lo syscall/errno.lo \
-       syscall/wait.lo text/scanner.lo text/tabwriter.lo \
-       text/template.lo text/template/parse.lo testing/testing.lo \
-       testing/iotest.lo testing/quick.lo testing/script.lo \
-       unicode/utf16.lo unicode/utf8.lo
+       exp/norm.lo exp/proxy.lo exp/spdy.lo exp/sql.lo exp/ssh.lo \
+       exp/terminal.lo exp/types.lo exp/sql/driver.lo \
+       html/template.lo go/ast.lo go/build.lo go/doc.lo go/parser.lo \
+       go/printer.lo go/scanner.lo go/token.lo hash/adler32.lo \
+       hash/crc32.lo hash/crc64.lo hash/fnv.lo net/http/cgi.lo \
+       net/http/fcgi.lo net/http/httptest.lo net/http/httputil.lo \
+       net/http/pprof.lo image/bmp.lo image/color.lo image/draw.lo \
+       image/gif.lo image/jpeg.lo image/png.lo image/tiff.lo \
+       index/suffixarray.lo io/ioutil.lo log/syslog.lo \
+       log/syslog/syslog_c.lo math/big.lo math/cmplx.lo math/rand.lo \
+       mime/mime.lo mime/multipart.lo net/dict.lo net/http.lo \
+       net/mail.lo net/rpc.lo net/smtp.lo net/textproto.lo net/url.lo \
+       old/netchan.lo old/regexp.lo old/template.lo \
+       $(am__DEPENDENCIES_1) os/user.lo os/signal.lo path/filepath.lo \
+       regexp/syntax.lo net/rpc/jsonrpc.lo runtime/debug.lo \
+       runtime/pprof.lo sync/atomic.lo sync/atomic_c.lo \
+       syscall/syscall.lo syscall/errno.lo syscall/wait.lo \
+       text/scanner.lo text/tabwriter.lo text/template.lo \
+       text/template/parse.lo testing/testing.lo testing/iotest.lo \
+       testing/quick.lo testing/script.lo unicode/utf16.lo \
+       unicode/utf8.lo
 libgo_la_DEPENDENCIES = $(am__DEPENDENCIES_2) $(am__DEPENDENCIES_1) \
        $(am__DEPENDENCIES_1) $(am__DEPENDENCIES_1) \
        $(am__DEPENDENCIES_1)
@@ -652,7 +653,7 @@ toolexeclibgocryptoopenpgpdir = $(toolexeclibgocryptodir)/openpgp
 toolexeclibgocryptoopenpgp_DATA = \
        crypto/openpgp/armor.gox \
        crypto/openpgp/elgamal.gox \
-       crypto/openpgp/error.gox \
+       crypto/openpgp/errors.gox \
        crypto/openpgp/packet.gox \
        crypto/openpgp/s2k.gox
 
@@ -692,6 +693,7 @@ toolexeclibgoexp_DATA = \
        exp/ebnf.gox \
        $(exp_inotify_gox) \
        exp/norm.gox \
+       exp/proxy.gox \
        exp/spdy.gox \
        exp/sql.gox \
        exp/ssh.gox \
@@ -1049,6 +1051,14 @@ go_mime_files = \
 @LIBGO_IS_IRIX_FALSE@@LIBGO_IS_LINUX_FALSE@@LIBGO_IS_SOLARIS_TRUE@go_net_sock_file = go/net/sock_linux.go
 @LIBGO_IS_IRIX_TRUE@@LIBGO_IS_LINUX_FALSE@go_net_sock_file = go/net/sock_linux.go
 @LIBGO_IS_LINUX_TRUE@go_net_sock_file = go/net/sock_linux.go
+@LIBGO_IS_IRIX_FALSE@@LIBGO_IS_LINUX_FALSE@@LIBGO_IS_SOLARIS_FALSE@go_net_sockopt_file = go/net/sockopt_bsd.go
+@LIBGO_IS_IRIX_FALSE@@LIBGO_IS_LINUX_FALSE@@LIBGO_IS_SOLARIS_TRUE@go_net_sockopt_file = go/net/sockopt_linux.go
+@LIBGO_IS_IRIX_TRUE@@LIBGO_IS_LINUX_FALSE@go_net_sockopt_file = go/net/sockopt_linux.go
+@LIBGO_IS_LINUX_TRUE@go_net_sockopt_file = go/net/sockopt_linux.go
+@LIBGO_IS_IRIX_FALSE@@LIBGO_IS_LINUX_FALSE@@LIBGO_IS_SOLARIS_FALSE@go_net_sockoptip_file = go/net/sockoptip_bsd.go
+@LIBGO_IS_IRIX_FALSE@@LIBGO_IS_LINUX_FALSE@@LIBGO_IS_SOLARIS_TRUE@go_net_sockoptip_file = go/net/sockoptip_linux.go
+@LIBGO_IS_IRIX_TRUE@@LIBGO_IS_LINUX_FALSE@go_net_sockoptip_file = go/net/sockoptip_linux.go
+@LIBGO_IS_LINUX_TRUE@go_net_sockoptip_file = go/net/sockoptip_linux.go
 @LIBGO_IS_LINUX_FALSE@go_net_sendfile_file = go/net/sendfile_stub.go
 @LIBGO_IS_LINUX_TRUE@go_net_sendfile_file = go/net/sendfile_linux.go
 @LIBGO_IS_LINUX_FALSE@@LIBGO_IS_NETBSD_FALSE@go_net_interface_file = go/net/interface_stub.go
@@ -1082,6 +1092,10 @@ go_net_files = \
        $(go_net_sendfile_file) \
        go/net/sock.go \
        $(go_net_sock_file) \
+       go/net/sockopt.go \
+       $(go_net_sockopt_file) \
+       go/net/sockoptip.go \
+       $(go_net_sockoptip_file) \
        go/net/tcpsock.go \
        go/net/tcpsock_posix.go \
        go/net/udpsock.go \
@@ -1197,8 +1211,7 @@ go_syslog_c_files = \
 go_testing_files = \
        go/testing/benchmark.go \
        go/testing/example.go \
-       go/testing/testing.go \
-       go/testing/wrapper.go
+       go/testing/testing.go
 
 go_time_files = \
        go/time/format.go \
@@ -1394,8 +1407,8 @@ go_crypto_openpgp_armor_files = \
 go_crypto_openpgp_elgamal_files = \
        go/crypto/openpgp/elgamal/elgamal.go
 
-go_crypto_openpgp_error_files = \
-       go/crypto/openpgp/error/error.go
+go_crypto_openpgp_errors_files = \
+       go/crypto/openpgp/errors/errors.go
 
 go_crypto_openpgp_packet_files = \
        go/crypto/openpgp/packet/compressed.go \
@@ -1492,6 +1505,7 @@ go_encoding_pem_files = \
 go_encoding_xml_files = \
        go/encoding/xml/marshal.go \
        go/encoding/xml/read.go \
+       go/encoding/xml/typeinfo.go \
        go/encoding/xml/xml.go
 
 go_exp_ebnf_files = \
@@ -1510,6 +1524,12 @@ go_exp_norm_files = \
        go/exp/norm/tables.go \
        go/exp/norm/trie.go
 
+go_exp_proxy_files = \
+       go/exp/proxy/direct.go \
+       go/exp/proxy/per_host.go \
+       go/exp/proxy/proxy.go \
+       go/exp/proxy/socks5.go
+
 go_exp_spdy_files = \
        go/exp/spdy/read.go \
        go/exp/spdy/types.go \
@@ -1528,7 +1548,7 @@ go_exp_ssh_files = \
        go/exp/ssh/doc.go \
        go/exp/ssh/messages.go \
        go/exp/ssh/server.go \
-       go/exp/ssh/server_shell.go \
+       go/exp/ssh/server_terminal.go \
        go/exp/ssh/session.go \
        go/exp/ssh/tcpip.go \
        go/exp/ssh/transport.go
@@ -1569,7 +1589,8 @@ go_go_doc_files = \
        go/go/doc/doc.go \
        go/go/doc/example.go \
        go/go/doc/exports.go \
-       go/go/doc/filter.go
+       go/go/doc/filter.go \
+       go/go/doc/reader.go
 
 go_go_parser_files = \
        go/go/parser/interface.go \
@@ -1840,10 +1861,14 @@ go_unicode_utf8_files = \
 
 # Define Syscall and Syscall6.
 @LIBGO_IS_RTEMS_TRUE@syscall_syscall_file = go/syscall/syscall_stubs.go
-@LIBGO_IS_RTEMS_FALSE@syscall_exec_file = go/syscall/exec_unix.go
+@LIBGO_IS_LINUX_FALSE@@LIBGO_IS_RTEMS_FALSE@syscall_exec_file = go/syscall/exec_unix.go
+@LIBGO_IS_LINUX_TRUE@@LIBGO_IS_RTEMS_FALSE@syscall_exec_file = go/syscall/exec_unix.go
 
 # Define ForkExec and Exec.
 @LIBGO_IS_RTEMS_TRUE@syscall_exec_file = go/syscall/exec_stubs.go
+@LIBGO_IS_LINUX_FALSE@@LIBGO_IS_RTEMS_FALSE@syscall_exec_os_file = go/syscall/exec_bsd.go
+@LIBGO_IS_LINUX_TRUE@@LIBGO_IS_RTEMS_FALSE@syscall_exec_os_file = go/syscall/exec_linux.go
+@LIBGO_IS_RTEMS_TRUE@syscall_exec_os_file = 
 @HAVE_WAIT4_FALSE@@LIBGO_IS_RTEMS_FALSE@syscall_wait_file = go/syscall/libcall_waitpid.go
 @HAVE_WAIT4_TRUE@@LIBGO_IS_RTEMS_FALSE@syscall_wait_file = go/syscall/libcall_wait4.go
 
@@ -1901,6 +1926,7 @@ go_base_syscall_files = \
        go/syscall/syscall.go \
        $(syscall_syscall_file) \
        $(syscall_exec_file) \
+       $(syscall_exec_os_file) \
        $(syscall_wait_file) \
        $(syscall_sleep_file) \
        $(syscall_errstr_file) \
@@ -1995,7 +2021,7 @@ libgo_go_objs = \
        crypto/xtea.lo \
        crypto/openpgp/armor.lo \
        crypto/openpgp/elgamal.lo \
-       crypto/openpgp/error.lo \
+       crypto/openpgp/errors.lo \
        crypto/openpgp/packet.lo \
        crypto/openpgp/s2k.lo \
        crypto/x509/pkix.lo \
@@ -2018,6 +2044,7 @@ libgo_go_objs = \
        encoding/xml.lo \
        exp/ebnf.lo \
        exp/norm.lo \
+       exp/proxy.lo \
        exp/spdy.lo \
        exp/sql.lo \
        exp/ssh.lo \
@@ -2286,6 +2313,7 @@ TEST_PACKAGES = \
        exp/ebnf/check \
        $(exp_inotify_check) \
        exp/norm/check \
+       exp/proxy/check \
        exp/spdy/check \
        exp/sql/check \
        exp/ssh/check \
@@ -5162,15 +5190,15 @@ crypto/openpgp/elgamal/check: $(CHECK_DEPS)
        @$(CHECK)
 .PHONY: crypto/openpgp/elgamal/check
 
-@go_include@ crypto/openpgp/error.lo.dep
-crypto/openpgp/error.lo.dep: $(go_crypto_openpgp_error_files)
+@go_include@ crypto/openpgp/errors.lo.dep
+crypto/openpgp/errors.lo.dep: $(go_crypto_openpgp_errors_files)
        $(BUILDDEPS)
-crypto/openpgp/error.lo: $(go_crypto_openpgp_error_files)
+crypto/openpgp/errors.lo: $(go_crypto_openpgp_errors_files)
        $(BUILDPACKAGE)
-crypto/openpgp/error/check: $(CHECK_DEPS)
-       @$(MKDIR_P) crypto/openpgp/error
+crypto/openpgp/errors/check: $(CHECK_DEPS)
+       @$(MKDIR_P) crypto/openpgp/errors
        @$(CHECK)
-.PHONY: crypto/openpgp/error/check
+.PHONY: crypto/openpgp/errors/check
 
 @go_include@ crypto/openpgp/packet.lo.dep
 crypto/openpgp/packet.lo.dep: $(go_crypto_openpgp_packet_files)
@@ -5392,6 +5420,16 @@ exp/norm/check: $(CHECK_DEPS)
        @$(CHECK)
 .PHONY: exp/norm/check
 
+@go_include@ exp/proxy.lo.dep
+exp/proxy.lo.dep: $(go_exp_proxy_files)
+       $(BUILDDEPS)
+exp/proxy.lo: $(go_exp_proxy_files)
+       $(BUILDPACKAGE)
+exp/proxy/check: $(CHECK_DEPS)
+       @$(MKDIR_P) exp/proxy
+       @$(CHECK)
+.PHONY: exp/proxy/check
+
 @go_include@ exp/spdy.lo.dep
 exp/spdy.lo.dep: $(go_exp_spdy_files)
        $(BUILDDEPS)
@@ -6201,7 +6239,7 @@ crypto/openpgp/armor.gox: crypto/openpgp/armor.lo
        $(BUILDGOX)
 crypto/openpgp/elgamal.gox: crypto/openpgp/elgamal.lo
        $(BUILDGOX)
-crypto/openpgp/error.gox: crypto/openpgp/error.lo
+crypto/openpgp/errors.gox: crypto/openpgp/errors.lo
        $(BUILDGOX)
 crypto/openpgp/packet.gox: crypto/openpgp/packet.lo
        $(BUILDGOX)
@@ -6253,6 +6291,8 @@ exp/inotify.gox: exp/inotify.lo
        $(BUILDGOX)
 exp/norm.gox: exp/norm.lo
        $(BUILDGOX)
+exp/proxy.gox: exp/proxy.lo
+       $(BUILDGOX)
 exp/spdy.gox: exp/spdy.lo
        $(BUILDGOX)
 exp/sql.gox: exp/sql.lo
index f30af59816811f9f6e577b5c982e57e1b018f222..e4a2569a0ae7626b85361656d2e84bde231a4439 100644 (file)
@@ -74,6 +74,9 @@
 /* Define to 1 if you have the <sys/mman.h> header file. */
 #undef HAVE_SYS_MMAN_H
 
+/* Define to 1 if you have the <sys/prctl.h> header file. */
+#undef HAVE_SYS_PRCTL_H
+
 /* Define to 1 if you have the <sys/ptrace.h> header file. */
 #undef HAVE_SYS_PTRACE_H
 
index 8c8fe38bc8e92dfafbf4e9ee66ec57f21d2708d8..5ebed8017001e83d32835d2990637f1d9fd97576 100755 (executable)
@@ -14505,7 +14505,7 @@ no)
   ;;
 esac
 
-for ac_header in sys/mman.h syscall.h sys/epoll.h sys/ptrace.h sys/syscall.h sys/user.h sys/utsname.h sys/select.h sys/socket.h net/if.h
+for ac_header in sys/mman.h syscall.h sys/epoll.h sys/ptrace.h sys/syscall.h sys/user.h sys/utsname.h sys/select.h sys/socket.h net/if.h sys/prctl.h
 do :
   as_ac_Header=`$as_echo "ac_cv_header_$ac_header" | $as_tr_sh`
 ac_fn_c_check_header_mongrel "$LINENO" "$ac_header" "$as_ac_Header" "$ac_includes_default"
index cd6b1a9ac82463a8e69d6bccb9ed975c7d78485f..9795332d9a081f1c7c66d3de5e23d90f09bad5a5 100644 (file)
@@ -451,7 +451,7 @@ no)
   ;;
 esac
 
-AC_CHECK_HEADERS(sys/mman.h syscall.h sys/epoll.h sys/ptrace.h sys/syscall.h sys/user.h sys/utsname.h sys/select.h sys/socket.h net/if.h)
+AC_CHECK_HEADERS(sys/mman.h syscall.h sys/epoll.h sys/ptrace.h sys/syscall.h sys/user.h sys/utsname.h sys/select.h sys/socket.h net/if.h sys/prctl.h)
 
 AC_CHECK_HEADERS([linux/filter.h linux/netlink.h linux/rtnetlink.h], [], [],
 [#ifdef HAVE_SYS_SOCKET_H
index e66ac026e5beb8232a71dae4c10d6077220dc74a..77757af1d804fc501693ff416ac92c65ce425284 100644 (file)
@@ -97,8 +97,7 @@ func (b *Buffer) grow(n int) int {
 func (b *Buffer) Write(p []byte) (n int, err error) {
        b.lastRead = opInvalid
        m := b.grow(len(p))
-       copy(b.buf[m:], p)
-       return len(p), nil
+       return copy(b.buf[m:], p), nil
 }
 
 // WriteString appends the contents of s to the buffer.  The return
@@ -200,13 +199,16 @@ func (b *Buffer) WriteRune(r rune) (n int, err error) {
 
 // Read reads the next len(p) bytes from the buffer or until the buffer
 // is drained.  The return value n is the number of bytes read.  If the
-// buffer has no data to return, err is io.EOF even if len(p) is zero;
+// buffer has no data to return, err is io.EOF (unless len(p) is zero);
 // otherwise it is nil.
 func (b *Buffer) Read(p []byte) (n int, err error) {
        b.lastRead = opInvalid
        if b.off >= len(b.buf) {
                // Buffer is empty, reset to recover space.
                b.Truncate(0)
+               if len(p) == 0 {
+                       return
+               }
                return 0, io.EOF
        }
        n = copy(p, b.buf[b.off:])
index adb93302a541ebcea1cc9ea6b7f9df1e055e11cb..d0af11f104b9d1706a70121c2523bf946cf5ef22 100644 (file)
@@ -373,3 +373,16 @@ func TestReadBytes(t *testing.T) {
                }
        }
 }
+
+// Was a bug: used to give EOF reading empty slice at EOF.
+func TestReadEmptyAtEOF(t *testing.T) {
+       b := new(Buffer)
+       slice := make([]byte, 0)
+       n, err := b.Read(slice)
+       if err != nil {
+               t.Errorf("read error: %v", err)
+       }
+       if n != 0 {
+               t.Errorf("wrong count; got %d want 0", n)
+       }
+}
index 3bbb5dc351a148ff9d81bedf073c94cd1d6de512..96957ab1b483e31210583776abbb4e116db185e3 100644 (file)
@@ -9,7 +9,7 @@ package armor
 import (
        "bufio"
        "bytes"
-       error_ "crypto/openpgp/error"
+       "crypto/openpgp/errors"
        "encoding/base64"
        "io"
 )
@@ -35,7 +35,7 @@ type Block struct {
        oReader openpgpReader
 }
 
-var ArmorCorrupt error = error_.StructuralError("armor invalid")
+var ArmorCorrupt error = errors.StructuralError("armor invalid")
 
 const crc24Init = 0xb704ce
 const crc24Poly = 0x1864cfb
similarity index 94%
rename from libgo/go/crypto/openpgp/error/error.go
rename to libgo/go/crypto/openpgp/errors/errors.go
index ceeb05419488b2274625d222ba7f151c28a8dc62..c434b764c9b2a621814be42ef537ce8aa5a44b63 100644 (file)
@@ -2,8 +2,8 @@
 // Use of this source code is governed by a BSD-style
 // license that can be found in the LICENSE file.
 
-// Package error contains common error types for the OpenPGP packages.
-package error
+// Package errors contains common error types for the OpenPGP packages.
+package errors
 
 import (
        "strconv"
index 74e7d239e0882b2fc1ed1797c2b0befe173e5f42..624a5ea8a769c6481458ef0a2baec21a5874db2f 100644 (file)
@@ -7,8 +7,9 @@ package openpgp
 import (
        "crypto"
        "crypto/openpgp/armor"
-       error_ "crypto/openpgp/error"
+       "crypto/openpgp/errors"
        "crypto/openpgp/packet"
+       "crypto/rand"
        "crypto/rsa"
        "io"
        "time"
@@ -181,13 +182,13 @@ func (el EntityList) DecryptionKeys() (keys []Key) {
 func ReadArmoredKeyRing(r io.Reader) (EntityList, error) {
        block, err := armor.Decode(r)
        if err == io.EOF {
-               return nil, error_.InvalidArgumentError("no armored data found")
+               return nil, errors.InvalidArgumentError("no armored data found")
        }
        if err != nil {
                return nil, err
        }
        if block.Type != PublicKeyType && block.Type != PrivateKeyType {
-               return nil, error_.InvalidArgumentError("expected public or private key block, got: " + block.Type)
+               return nil, errors.InvalidArgumentError("expected public or private key block, got: " + block.Type)
        }
 
        return ReadKeyRing(block.Body)
@@ -203,7 +204,7 @@ func ReadKeyRing(r io.Reader) (el EntityList, err error) {
                var e *Entity
                e, err = readEntity(packets)
                if err != nil {
-                       if _, ok := err.(error_.UnsupportedError); ok {
+                       if _, ok := err.(errors.UnsupportedError); ok {
                                lastUnsupportedError = err
                                err = readToNextPublicKey(packets)
                        }
@@ -235,7 +236,7 @@ func readToNextPublicKey(packets *packet.Reader) (err error) {
                if err == io.EOF {
                        return
                } else if err != nil {
-                       if _, ok := err.(error_.UnsupportedError); ok {
+                       if _, ok := err.(errors.UnsupportedError); ok {
                                err = nil
                                continue
                        }
@@ -266,14 +267,14 @@ func readEntity(packets *packet.Reader) (*Entity, error) {
        if e.PrimaryKey, ok = p.(*packet.PublicKey); !ok {
                if e.PrivateKey, ok = p.(*packet.PrivateKey); !ok {
                        packets.Unread(p)
-                       return nil, error_.StructuralError("first packet was not a public/private key")
+                       return nil, errors.StructuralError("first packet was not a public/private key")
                } else {
                        e.PrimaryKey = &e.PrivateKey.PublicKey
                }
        }
 
        if !e.PrimaryKey.PubKeyAlgo.CanSign() {
-               return nil, error_.StructuralError("primary key cannot be used for signatures")
+               return nil, errors.StructuralError("primary key cannot be used for signatures")
        }
 
        var current *Identity
@@ -303,12 +304,12 @@ EachPacket:
 
                                sig, ok := p.(*packet.Signature)
                                if !ok {
-                                       return nil, error_.StructuralError("user ID packet not followed by self-signature")
+                                       return nil, errors.StructuralError("user ID packet not followed by self-signature")
                                }
 
                                if (sig.SigType == packet.SigTypePositiveCert || sig.SigType == packet.SigTypeGenericCert) && sig.IssuerKeyId != nil && *sig.IssuerKeyId == e.PrimaryKey.KeyId {
                                        if err = e.PrimaryKey.VerifyUserIdSignature(pkt.Id, sig); err != nil {
-                                               return nil, error_.StructuralError("user ID self-signature invalid: " + err.Error())
+                                               return nil, errors.StructuralError("user ID self-signature invalid: " + err.Error())
                                        }
                                        current.SelfSignature = sig
                                        break
@@ -317,7 +318,7 @@ EachPacket:
                        }
                case *packet.Signature:
                        if current == nil {
-                               return nil, error_.StructuralError("signature packet found before user id packet")
+                               return nil, errors.StructuralError("signature packet found before user id packet")
                        }
                        current.Signatures = append(current.Signatures, pkt)
                case *packet.PrivateKey:
@@ -344,7 +345,7 @@ EachPacket:
        }
 
        if len(e.Identities) == 0 {
-               return nil, error_.StructuralError("entity without any identities")
+               return nil, errors.StructuralError("entity without any identities")
        }
 
        return e, nil
@@ -359,19 +360,19 @@ func addSubkey(e *Entity, packets *packet.Reader, pub *packet.PublicKey, priv *p
                return io.ErrUnexpectedEOF
        }
        if err != nil {
-               return error_.StructuralError("subkey signature invalid: " + err.Error())
+               return errors.StructuralError("subkey signature invalid: " + err.Error())
        }
        var ok bool
        subKey.Sig, ok = p.(*packet.Signature)
        if !ok {
-               return error_.StructuralError("subkey packet not followed by signature")
+               return errors.StructuralError("subkey packet not followed by signature")
        }
        if subKey.Sig.SigType != packet.SigTypeSubkeyBinding {
-               return error_.StructuralError("subkey signature with wrong type")
+               return errors.StructuralError("subkey signature with wrong type")
        }
        err = e.PrimaryKey.VerifyKeySignature(subKey.PublicKey, subKey.Sig)
        if err != nil {
-               return error_.StructuralError("subkey signature invalid: " + err.Error())
+               return errors.StructuralError("subkey signature invalid: " + err.Error())
        }
        e.Subkeys = append(e.Subkeys, subKey)
        return nil
@@ -385,7 +386,7 @@ const defaultRSAKeyBits = 2048
 func NewEntity(rand io.Reader, currentTime time.Time, name, comment, email string) (*Entity, error) {
        uid := packet.NewUserId(name, comment, email)
        if uid == nil {
-               return nil, error_.InvalidArgumentError("user id field contained invalid characters")
+               return nil, errors.InvalidArgumentError("user id field contained invalid characters")
        }
        signingPriv, err := rsa.GenerateKey(rand, defaultRSAKeyBits)
        if err != nil {
@@ -397,8 +398,8 @@ func NewEntity(rand io.Reader, currentTime time.Time, name, comment, email strin
        }
 
        e := &Entity{
-               PrimaryKey: packet.NewRSAPublicKey(currentTime, &signingPriv.PublicKey, false /* not a subkey */ ),
-               PrivateKey: packet.NewRSAPrivateKey(currentTime, signingPriv, false /* not a subkey */ ),
+               PrimaryKey: packet.NewRSAPublicKey(currentTime, &signingPriv.PublicKey),
+               PrivateKey: packet.NewRSAPrivateKey(currentTime, signingPriv),
                Identities: make(map[string]*Identity),
        }
        isPrimaryId := true
@@ -420,8 +421,8 @@ func NewEntity(rand io.Reader, currentTime time.Time, name, comment, email strin
 
        e.Subkeys = make([]Subkey, 1)
        e.Subkeys[0] = Subkey{
-               PublicKey:  packet.NewRSAPublicKey(currentTime, &encryptingPriv.PublicKey, true /* is a subkey */ ),
-               PrivateKey: packet.NewRSAPrivateKey(currentTime, encryptingPriv, true /* is a subkey */ ),
+               PublicKey:  packet.NewRSAPublicKey(currentTime, &encryptingPriv.PublicKey),
+               PrivateKey: packet.NewRSAPrivateKey(currentTime, encryptingPriv),
                Sig: &packet.Signature{
                        CreationTime:              currentTime,
                        SigType:                   packet.SigTypeSubkeyBinding,
@@ -433,6 +434,8 @@ func NewEntity(rand io.Reader, currentTime time.Time, name, comment, email strin
                        IssuerKeyId:               &e.PrimaryKey.KeyId,
                },
        }
+       e.Subkeys[0].PublicKey.IsSubkey = true
+       e.Subkeys[0].PrivateKey.IsSubkey = true
 
        return e, nil
 }
@@ -450,7 +453,7 @@ func (e *Entity) SerializePrivate(w io.Writer) (err error) {
                if err != nil {
                        return
                }
-               err = ident.SelfSignature.SignUserId(ident.UserId.Id, e.PrimaryKey, e.PrivateKey)
+               err = ident.SelfSignature.SignUserId(rand.Reader, ident.UserId.Id, e.PrimaryKey, e.PrivateKey)
                if err != nil {
                        return
                }
@@ -464,7 +467,7 @@ func (e *Entity) SerializePrivate(w io.Writer) (err error) {
                if err != nil {
                        return
                }
-               err = subkey.Sig.SignKey(subkey.PublicKey, e.PrivateKey)
+               err = subkey.Sig.SignKey(rand.Reader, subkey.PublicKey, e.PrivateKey)
                if err != nil {
                        return
                }
@@ -518,14 +521,14 @@ func (e *Entity) Serialize(w io.Writer) error {
 // necessary.
 func (e *Entity) SignIdentity(identity string, signer *Entity) error {
        if signer.PrivateKey == nil {
-               return error_.InvalidArgumentError("signing Entity must have a private key")
+               return errors.InvalidArgumentError("signing Entity must have a private key")
        }
        if signer.PrivateKey.Encrypted {
-               return error_.InvalidArgumentError("signing Entity's private key must be decrypted")
+               return errors.InvalidArgumentError("signing Entity's private key must be decrypted")
        }
        ident, ok := e.Identities[identity]
        if !ok {
-               return error_.InvalidArgumentError("given identity string not found in Entity")
+               return errors.InvalidArgumentError("given identity string not found in Entity")
        }
 
        sig := &packet.Signature{
@@ -535,7 +538,7 @@ func (e *Entity) SignIdentity(identity string, signer *Entity) error {
                CreationTime: time.Now(),
                IssuerKeyId:  &signer.PrivateKey.KeyId,
        }
-       if err := sig.SignKey(e.PrimaryKey, signer.PrivateKey); err != nil {
+       if err := sig.SignKey(rand.Reader, e.PrimaryKey, signer.PrivateKey); err != nil {
                return err
        }
        ident.Signatures = append(ident.Signatures, sig)
index f80d798cfe6bfa16ffcb460ee3ae52d046caf062..36736e34a0edcd7a153fc8c38a2000a77ca5cd47 100644 (file)
@@ -7,7 +7,7 @@ package packet
 import (
        "compress/flate"
        "compress/zlib"
-       error_ "crypto/openpgp/error"
+       "crypto/openpgp/errors"
        "io"
        "strconv"
 )
@@ -31,7 +31,7 @@ func (c *Compressed) parse(r io.Reader) error {
        case 2:
                c.Body, err = zlib.NewReader(r)
        default:
-               err = error_.UnsupportedError("unknown compression algorithm: " + strconv.Itoa(int(buf[0])))
+               err = errors.UnsupportedError("unknown compression algorithm: " + strconv.Itoa(int(buf[0])))
        }
 
        return err
index b24fa3a3fd3f0f05dcf4a98a2f51be2b82c6d743..479a643935ed1aae47333ae08616c0447ac99f5f 100644 (file)
@@ -6,7 +6,7 @@ package packet
 
 import (
        "crypto/openpgp/elgamal"
-       error_ "crypto/openpgp/error"
+       "crypto/openpgp/errors"
        "crypto/rand"
        "crypto/rsa"
        "encoding/binary"
@@ -35,7 +35,7 @@ func (e *EncryptedKey) parse(r io.Reader) (err error) {
                return
        }
        if buf[0] != encryptedKeyVersion {
-               return error_.UnsupportedError("unknown EncryptedKey version " + strconv.Itoa(int(buf[0])))
+               return errors.UnsupportedError("unknown EncryptedKey version " + strconv.Itoa(int(buf[0])))
        }
        e.KeyId = binary.BigEndian.Uint64(buf[1:9])
        e.Algo = PublicKeyAlgorithm(buf[9])
@@ -77,7 +77,7 @@ func (e *EncryptedKey) Decrypt(priv *PrivateKey) error {
                c2 := new(big.Int).SetBytes(e.encryptedMPI2)
                b, err = elgamal.Decrypt(priv.PrivateKey.(*elgamal.PrivateKey), c1, c2)
        default:
-               err = error_.InvalidArgumentError("cannot decrypted encrypted session key with private key of type " + strconv.Itoa(int(priv.PubKeyAlgo)))
+               err = errors.InvalidArgumentError("cannot decrypted encrypted session key with private key of type " + strconv.Itoa(int(priv.PubKeyAlgo)))
        }
 
        if err != nil {
@@ -89,7 +89,7 @@ func (e *EncryptedKey) Decrypt(priv *PrivateKey) error {
        expectedChecksum := uint16(b[len(b)-2])<<8 | uint16(b[len(b)-1])
        checksum := checksumKeyMaterial(e.Key)
        if checksum != expectedChecksum {
-               return error_.StructuralError("EncryptedKey checksum incorrect")
+               return errors.StructuralError("EncryptedKey checksum incorrect")
        }
 
        return nil
@@ -116,16 +116,16 @@ func SerializeEncryptedKey(w io.Writer, rand io.Reader, pub *PublicKey, cipherFu
        case PubKeyAlgoElGamal:
                return serializeEncryptedKeyElGamal(w, rand, buf, pub.PublicKey.(*elgamal.PublicKey), keyBlock)
        case PubKeyAlgoDSA, PubKeyAlgoRSASignOnly:
-               return error_.InvalidArgumentError("cannot encrypt to public key of type " + strconv.Itoa(int(pub.PubKeyAlgo)))
+               return errors.InvalidArgumentError("cannot encrypt to public key of type " + strconv.Itoa(int(pub.PubKeyAlgo)))
        }
 
-       return error_.UnsupportedError("encrypting a key to public key of type " + strconv.Itoa(int(pub.PubKeyAlgo)))
+       return errors.UnsupportedError("encrypting a key to public key of type " + strconv.Itoa(int(pub.PubKeyAlgo)))
 }
 
 func serializeEncryptedKeyRSA(w io.Writer, rand io.Reader, header [10]byte, pub *rsa.PublicKey, keyBlock []byte) error {
        cipherText, err := rsa.EncryptPKCS1v15(rand, pub, keyBlock)
        if err != nil {
-               return error_.InvalidArgumentError("RSA encryption failed: " + err.Error())
+               return errors.InvalidArgumentError("RSA encryption failed: " + err.Error())
        }
 
        packetLen := 10 /* header length */ + 2 /* mpi size */ + len(cipherText)
@@ -144,7 +144,7 @@ func serializeEncryptedKeyRSA(w io.Writer, rand io.Reader, header [10]byte, pub
 func serializeEncryptedKeyElGamal(w io.Writer, rand io.Reader, header [10]byte, pub *elgamal.PublicKey, keyBlock []byte) error {
        c1, c2, err := elgamal.Encrypt(rand, pub, keyBlock)
        if err != nil {
-               return error_.InvalidArgumentError("ElGamal encryption failed: " + err.Error())
+               return errors.InvalidArgumentError("ElGamal encryption failed: " + err.Error())
        }
 
        packetLen := 10 /* header length */
index 13e6aa5aff8d5fa974daf98559d4bd9ff3cc7fc0..822cfe9b8f64ac559336d4a2771fb7078ffce96f 100644 (file)
@@ -6,7 +6,7 @@ package packet
 
 import (
        "crypto"
-       error_ "crypto/openpgp/error"
+       "crypto/openpgp/errors"
        "crypto/openpgp/s2k"
        "encoding/binary"
        "io"
@@ -33,13 +33,13 @@ func (ops *OnePassSignature) parse(r io.Reader) (err error) {
                return
        }
        if buf[0] != onePassSignatureVersion {
-               err = error_.UnsupportedError("one-pass-signature packet version " + strconv.Itoa(int(buf[0])))
+               err = errors.UnsupportedError("one-pass-signature packet version " + strconv.Itoa(int(buf[0])))
        }
 
        var ok bool
        ops.Hash, ok = s2k.HashIdToHash(buf[2])
        if !ok {
-               return error_.UnsupportedError("hash function: " + strconv.Itoa(int(buf[2])))
+               return errors.UnsupportedError("hash function: " + strconv.Itoa(int(buf[2])))
        }
 
        ops.SigType = SignatureType(buf[1])
@@ -57,7 +57,7 @@ func (ops *OnePassSignature) Serialize(w io.Writer) error {
        var ok bool
        buf[2], ok = s2k.HashToHashId(ops.Hash)
        if !ok {
-               return error_.UnsupportedError("hash type: " + strconv.Itoa(int(ops.Hash)))
+               return errors.UnsupportedError("hash type: " + strconv.Itoa(int(ops.Hash)))
        }
        buf[3] = uint8(ops.PubKeyAlgo)
        binary.BigEndian.PutUint64(buf[4:12], ops.KeyId)
index 778df15c0bd17bd39c509c4abb4410b560982b6e..f7c1964fd4ca795226019f359e996d22ce2f20f8 100644 (file)
@@ -10,7 +10,7 @@ import (
        "crypto/aes"
        "crypto/cast5"
        "crypto/cipher"
-       error_ "crypto/openpgp/error"
+       "crypto/openpgp/errors"
        "io"
        "math/big"
 )
@@ -162,7 +162,7 @@ func readHeader(r io.Reader) (tag packetType, length int64, contents io.Reader,
                return
        }
        if buf[0]&0x80 == 0 {
-               err = error_.StructuralError("tag byte does not have MSB set")
+               err = errors.StructuralError("tag byte does not have MSB set")
                return
        }
        if buf[0]&0x40 == 0 {
@@ -337,7 +337,7 @@ func Read(r io.Reader) (p Packet, err error) {
                se.MDC = true
                p = se
        default:
-               err = error_.UnknownPacketTypeError(tag)
+               err = errors.UnknownPacketTypeError(tag)
        }
        if p != nil {
                err = p.parse(contents)
index 53266413c86c0f2ce28d53d5ce1cec6c0d0a2a98..e4b86914192c7201f8e117bc13398e2e352741e2 100644 (file)
@@ -6,7 +6,7 @@ package packet
 
 import (
        "bytes"
-       error_ "crypto/openpgp/error"
+       "crypto/openpgp/errors"
        "encoding/hex"
        "fmt"
        "io"
@@ -152,7 +152,7 @@ func TestReadHeader(t *testing.T) {
        for i, test := range readHeaderTests {
                tag, length, contents, err := readHeader(readerFromHex(test.hexInput))
                if test.structuralError {
-                       if _, ok := err.(error_.StructuralError); ok {
+                       if _, ok := err.(errors.StructuralError); ok {
                                continue
                        }
                        t.Errorf("%d: expected StructuralError, got:%s", i, err)
index d67e968861758a183fa60784b05ff4cb2a18f87e..5a90d0625fada59b0f237eab48508431b1c8421f 100644 (file)
@@ -9,7 +9,7 @@ import (
        "crypto/cipher"
        "crypto/dsa"
        "crypto/openpgp/elgamal"
-       error_ "crypto/openpgp/error"
+       "crypto/openpgp/errors"
        "crypto/openpgp/s2k"
        "crypto/rsa"
        "crypto/sha1"
@@ -28,14 +28,21 @@ type PrivateKey struct {
        encryptedData []byte
        cipher        CipherFunction
        s2k           func(out, in []byte)
-       PrivateKey    interface{} // An *rsa.PrivateKey.
+       PrivateKey    interface{} // An *rsa.PrivateKey or *dsa.PrivateKey.
        sha1Checksum  bool
        iv            []byte
 }
 
-func NewRSAPrivateKey(currentTime time.Time, priv *rsa.PrivateKey, isSubkey bool) *PrivateKey {
+func NewRSAPrivateKey(currentTime time.Time, priv *rsa.PrivateKey) *PrivateKey {
        pk := new(PrivateKey)
-       pk.PublicKey = *NewRSAPublicKey(currentTime, &priv.PublicKey, isSubkey)
+       pk.PublicKey = *NewRSAPublicKey(currentTime, &priv.PublicKey)
+       pk.PrivateKey = priv
+       return pk
+}
+
+func NewDSAPrivateKey(currentTime time.Time, priv *dsa.PrivateKey) *PrivateKey {
+       pk := new(PrivateKey)
+       pk.PublicKey = *NewDSAPublicKey(currentTime, &priv.PublicKey)
        pk.PrivateKey = priv
        return pk
 }
@@ -72,13 +79,13 @@ func (pk *PrivateKey) parse(r io.Reader) (err error) {
                        pk.sha1Checksum = true
                }
        default:
-               return error_.UnsupportedError("deprecated s2k function in private key")
+               return errors.UnsupportedError("deprecated s2k function in private key")
        }
 
        if pk.Encrypted {
                blockSize := pk.cipher.blockSize()
                if blockSize == 0 {
-                       return error_.UnsupportedError("unsupported cipher in private key: " + strconv.Itoa(int(pk.cipher)))
+                       return errors.UnsupportedError("unsupported cipher in private key: " + strconv.Itoa(int(pk.cipher)))
                }
                pk.iv = make([]byte, blockSize)
                _, err = readFull(r, pk.iv)
@@ -121,8 +128,10 @@ func (pk *PrivateKey) Serialize(w io.Writer) (err error) {
        switch priv := pk.PrivateKey.(type) {
        case *rsa.PrivateKey:
                err = serializeRSAPrivateKey(privateKeyBuf, priv)
+       case *dsa.PrivateKey:
+               err = serializeDSAPrivateKey(privateKeyBuf, priv)
        default:
-               err = error_.InvalidArgumentError("non-RSA private key")
+               err = errors.InvalidArgumentError("unknown private key type")
        }
        if err != nil {
                return
@@ -172,6 +181,10 @@ func serializeRSAPrivateKey(w io.Writer, priv *rsa.PrivateKey) error {
        return writeBig(w, priv.Precomputed.Qinv)
 }
 
+func serializeDSAPrivateKey(w io.Writer, priv *dsa.PrivateKey) error {
+       return writeBig(w, priv.X)
+}
+
 // Decrypt decrypts an encrypted private key using a passphrase.
 func (pk *PrivateKey) Decrypt(passphrase []byte) error {
        if !pk.Encrypted {
@@ -188,18 +201,18 @@ func (pk *PrivateKey) Decrypt(passphrase []byte) error {
 
        if pk.sha1Checksum {
                if len(data) < sha1.Size {
-                       return error_.StructuralError("truncated private key data")
+                       return errors.StructuralError("truncated private key data")
                }
                h := sha1.New()
                h.Write(data[:len(data)-sha1.Size])
                sum := h.Sum(nil)
                if !bytes.Equal(sum, data[len(data)-sha1.Size:]) {
-                       return error_.StructuralError("private key checksum failure")
+                       return errors.StructuralError("private key checksum failure")
                }
                data = data[:len(data)-sha1.Size]
        } else {
                if len(data) < 2 {
-                       return error_.StructuralError("truncated private key data")
+                       return errors.StructuralError("truncated private key data")
                }
                var sum uint16
                for i := 0; i < len(data)-2; i++ {
@@ -207,7 +220,7 @@ func (pk *PrivateKey) Decrypt(passphrase []byte) error {
                }
                if data[len(data)-2] != uint8(sum>>8) ||
                        data[len(data)-1] != uint8(sum) {
-                       return error_.StructuralError("private key checksum failure")
+                       return errors.StructuralError("private key checksum failure")
                }
                data = data[:len(data)-2]
        }
index 9aa30e0c15f0c85ce67ca9e8bba6c30a2ae7ba8c..ba178b519ebdaa15c0b709c65ad7d92d66961818 100644 (file)
@@ -7,7 +7,7 @@ package packet
 import (
        "crypto/dsa"
        "crypto/openpgp/elgamal"
-       error_ "crypto/openpgp/error"
+       "crypto/openpgp/errors"
        "crypto/rsa"
        "crypto/sha1"
        "encoding/binary"
@@ -39,12 +39,11 @@ func fromBig(n *big.Int) parsedMPI {
 }
 
 // NewRSAPublicKey returns a PublicKey that wraps the given rsa.PublicKey.
-func NewRSAPublicKey(creationTime time.Time, pub *rsa.PublicKey, isSubkey bool) *PublicKey {
+func NewRSAPublicKey(creationTime time.Time, pub *rsa.PublicKey) *PublicKey {
        pk := &PublicKey{
                CreationTime: creationTime,
                PubKeyAlgo:   PubKeyAlgoRSA,
                PublicKey:    pub,
-               IsSubkey:     isSubkey,
                n:            fromBig(pub.N),
                e:            fromBig(big.NewInt(int64(pub.E))),
        }
@@ -53,6 +52,22 @@ func NewRSAPublicKey(creationTime time.Time, pub *rsa.PublicKey, isSubkey bool)
        return pk
 }
 
+// NewDSAPublicKey returns a PublicKey that wraps the given rsa.PublicKey.
+func NewDSAPublicKey(creationTime time.Time, pub *dsa.PublicKey) *PublicKey {
+       pk := &PublicKey{
+               CreationTime: creationTime,
+               PubKeyAlgo:   PubKeyAlgoDSA,
+               PublicKey:    pub,
+               p:            fromBig(pub.P),
+               q:            fromBig(pub.Q),
+               g:            fromBig(pub.G),
+               y:            fromBig(pub.Y),
+       }
+
+       pk.setFingerPrintAndKeyId()
+       return pk
+}
+
 func (pk *PublicKey) parse(r io.Reader) (err error) {
        // RFC 4880, section 5.5.2
        var buf [6]byte
@@ -61,7 +76,7 @@ func (pk *PublicKey) parse(r io.Reader) (err error) {
                return
        }
        if buf[0] != 4 {
-               return error_.UnsupportedError("public key version")
+               return errors.UnsupportedError("public key version")
        }
        pk.CreationTime = time.Unix(int64(uint32(buf[1])<<24|uint32(buf[2])<<16|uint32(buf[3])<<8|uint32(buf[4])), 0)
        pk.PubKeyAlgo = PublicKeyAlgorithm(buf[5])
@@ -73,7 +88,7 @@ func (pk *PublicKey) parse(r io.Reader) (err error) {
        case PubKeyAlgoElGamal:
                err = pk.parseElGamal(r)
        default:
-               err = error_.UnsupportedError("public key type: " + strconv.Itoa(int(pk.PubKeyAlgo)))
+               err = errors.UnsupportedError("public key type: " + strconv.Itoa(int(pk.PubKeyAlgo)))
        }
        if err != nil {
                return
@@ -105,7 +120,7 @@ func (pk *PublicKey) parseRSA(r io.Reader) (err error) {
        }
 
        if len(pk.e.bytes) > 3 {
-               err = error_.UnsupportedError("large public exponent")
+               err = errors.UnsupportedError("large public exponent")
                return
        }
        rsa := &rsa.PublicKey{
@@ -255,7 +270,7 @@ func (pk *PublicKey) serializeWithoutHeaders(w io.Writer) (err error) {
        case PubKeyAlgoElGamal:
                return writeMPIs(w, pk.p, pk.g, pk.y)
        }
-       return error_.InvalidArgumentError("bad public-key algorithm")
+       return errors.InvalidArgumentError("bad public-key algorithm")
 }
 
 // CanSign returns true iff this public key can generate signatures
@@ -267,18 +282,18 @@ func (pk *PublicKey) CanSign() bool {
 // public key, of the data hashed into signed. signed is mutated by this call.
 func (pk *PublicKey) VerifySignature(signed hash.Hash, sig *Signature) (err error) {
        if !pk.CanSign() {
-               return error_.InvalidArgumentError("public key cannot generate signatures")
+               return errors.InvalidArgumentError("public key cannot generate signatures")
        }
 
        signed.Write(sig.HashSuffix)
        hashBytes := signed.Sum(nil)
 
        if hashBytes[0] != sig.HashTag[0] || hashBytes[1] != sig.HashTag[1] {
-               return error_.SignatureError("hash tag doesn't match")
+               return errors.SignatureError("hash tag doesn't match")
        }
 
        if pk.PubKeyAlgo != sig.PubKeyAlgo {
-               return error_.InvalidArgumentError("public key and signature use different algorithms")
+               return errors.InvalidArgumentError("public key and signature use different algorithms")
        }
 
        switch pk.PubKeyAlgo {
@@ -286,13 +301,18 @@ func (pk *PublicKey) VerifySignature(signed hash.Hash, sig *Signature) (err erro
                rsaPublicKey, _ := pk.PublicKey.(*rsa.PublicKey)
                err = rsa.VerifyPKCS1v15(rsaPublicKey, sig.Hash, hashBytes, sig.RSASignature.bytes)
                if err != nil {
-                       return error_.SignatureError("RSA verification failure")
+                       return errors.SignatureError("RSA verification failure")
                }
                return nil
        case PubKeyAlgoDSA:
                dsaPublicKey, _ := pk.PublicKey.(*dsa.PublicKey)
+               // Need to truncate hashBytes to match FIPS 186-3 section 4.6.
+               subgroupSize := (dsaPublicKey.Q.BitLen() + 7) / 8
+               if len(hashBytes) > subgroupSize {
+                       hashBytes = hashBytes[:subgroupSize]
+               }
                if !dsa.Verify(dsaPublicKey, hashBytes, new(big.Int).SetBytes(sig.DSASigR.bytes), new(big.Int).SetBytes(sig.DSASigS.bytes)) {
-                       return error_.SignatureError("DSA verification failure")
+                       return errors.SignatureError("DSA verification failure")
                }
                return nil
        default:
@@ -306,7 +326,7 @@ func (pk *PublicKey) VerifySignature(signed hash.Hash, sig *Signature) (err erro
 func keySignatureHash(pk, signed *PublicKey, sig *Signature) (h hash.Hash, err error) {
        h = sig.Hash.New()
        if h == nil {
-               return nil, error_.UnsupportedError("hash function")
+               return nil, errors.UnsupportedError("hash function")
        }
 
        // RFC 4880, section 5.2.4
@@ -332,7 +352,7 @@ func (pk *PublicKey) VerifyKeySignature(signed *PublicKey, sig *Signature) (err
 func userIdSignatureHash(id string, pk *PublicKey, sig *Signature) (h hash.Hash, err error) {
        h = sig.Hash.New()
        if h == nil {
-               return nil, error_.UnsupportedError("hash function")
+               return nil, errors.UnsupportedError("hash function")
        }
 
        // RFC 4880, section 5.2.4
index e3d733cb02192e24c194c9e18a8a9ecb9b390825..1a3e8e231338528d728ceff26301d422331c6486 100644 (file)
@@ -5,7 +5,7 @@
 package packet
 
 import (
-       error_ "crypto/openpgp/error"
+       "crypto/openpgp/errors"
        "io"
 )
 
@@ -34,7 +34,7 @@ func (r *Reader) Next() (p Packet, err error) {
                        r.readers = r.readers[:len(r.readers)-1]
                        continue
                }
-               if _, ok := err.(error_.UnknownPacketTypeError); !ok {
+               if _, ok := err.(errors.UnknownPacketTypeError); !ok {
                        return nil, err
                }
        }
index 1cdc1ee0f0c798e9dc62dc3dfec5e94b3669071e..c3ffb3a6fb9e50397bb35133de60edc87170dd7e 100644 (file)
@@ -7,9 +7,8 @@ package packet
 import (
        "crypto"
        "crypto/dsa"
-       error_ "crypto/openpgp/error"
+       "crypto/openpgp/errors"
        "crypto/openpgp/s2k"
-       "crypto/rand"
        "crypto/rsa"
        "encoding/binary"
        "hash"
@@ -61,7 +60,7 @@ func (sig *Signature) parse(r io.Reader) (err error) {
                return
        }
        if buf[0] != 4 {
-               err = error_.UnsupportedError("signature packet version " + strconv.Itoa(int(buf[0])))
+               err = errors.UnsupportedError("signature packet version " + strconv.Itoa(int(buf[0])))
                return
        }
 
@@ -74,14 +73,14 @@ func (sig *Signature) parse(r io.Reader) (err error) {
        switch sig.PubKeyAlgo {
        case PubKeyAlgoRSA, PubKeyAlgoRSASignOnly, PubKeyAlgoDSA:
        default:
-               err = error_.UnsupportedError("public key algorithm " + strconv.Itoa(int(sig.PubKeyAlgo)))
+               err = errors.UnsupportedError("public key algorithm " + strconv.Itoa(int(sig.PubKeyAlgo)))
                return
        }
 
        var ok bool
        sig.Hash, ok = s2k.HashIdToHash(buf[2])
        if !ok {
-               return error_.UnsupportedError("hash function " + strconv.Itoa(int(buf[2])))
+               return errors.UnsupportedError("hash function " + strconv.Itoa(int(buf[2])))
        }
 
        hashedSubpacketsLength := int(buf[3])<<8 | int(buf[4])
@@ -153,7 +152,7 @@ func parseSignatureSubpackets(sig *Signature, subpackets []byte, isHashed bool)
        }
 
        if sig.CreationTime.IsZero() {
-               err = error_.StructuralError("no creation time in signature")
+               err = errors.StructuralError("no creation time in signature")
        }
 
        return
@@ -164,7 +163,7 @@ type signatureSubpacketType uint8
 const (
        creationTimeSubpacket        signatureSubpacketType = 2
        signatureExpirationSubpacket signatureSubpacketType = 3
-       keyExpirySubpacket           signatureSubpacketType = 9
+       keyExpirationSubpacket       signatureSubpacketType = 9
        prefSymmetricAlgosSubpacket  signatureSubpacketType = 11
        issuerSubpacket              signatureSubpacketType = 16
        prefHashAlgosSubpacket       signatureSubpacketType = 21
@@ -207,7 +206,7 @@ func parseSignatureSubpacket(sig *Signature, subpacket []byte, isHashed bool) (r
        rest = subpacket[length:]
        subpacket = subpacket[:length]
        if len(subpacket) == 0 {
-               err = error_.StructuralError("zero length signature subpacket")
+               err = errors.StructuralError("zero length signature subpacket")
                return
        }
        packetType = signatureSubpacketType(subpacket[0] & 0x7f)
@@ -217,37 +216,33 @@ func parseSignatureSubpacket(sig *Signature, subpacket []byte, isHashed bool) (r
        switch packetType {
        case creationTimeSubpacket:
                if !isHashed {
-                       err = error_.StructuralError("signature creation time in non-hashed area")
+                       err = errors.StructuralError("signature creation time in non-hashed area")
                        return
                }
                if len(subpacket) != 4 {
-                       err = error_.StructuralError("signature creation time not four bytes")
+                       err = errors.StructuralError("signature creation time not four bytes")
                        return
                }
                t := binary.BigEndian.Uint32(subpacket)
-               if t == 0 {
-                       sig.CreationTime = time.Time{}
-               } else {
-                       sig.CreationTime = time.Unix(int64(t), 0)
-               }
+               sig.CreationTime = time.Unix(int64(t), 0)
        case signatureExpirationSubpacket:
                // Signature expiration time, section 5.2.3.10
                if !isHashed {
                        return
                }
                if len(subpacket) != 4 {
-                       err = error_.StructuralError("expiration subpacket with bad length")
+                       err = errors.StructuralError("expiration subpacket with bad length")
                        return
                }
                sig.SigLifetimeSecs = new(uint32)
                *sig.SigLifetimeSecs = binary.BigEndian.Uint32(subpacket)
-       case keyExpirySubpacket:
+       case keyExpirationSubpacket:
                // Key expiration time, section 5.2.3.6
                if !isHashed {
                        return
                }
                if len(subpacket) != 4 {
-                       err = error_.StructuralError("key expiration subpacket with bad length")
+                       err = errors.StructuralError("key expiration subpacket with bad length")
                        return
                }
                sig.KeyLifetimeSecs = new(uint32)
@@ -262,7 +257,7 @@ func parseSignatureSubpacket(sig *Signature, subpacket []byte, isHashed bool) (r
        case issuerSubpacket:
                // Issuer, section 5.2.3.5
                if len(subpacket) != 8 {
-                       err = error_.StructuralError("issuer subpacket with bad length")
+                       err = errors.StructuralError("issuer subpacket with bad length")
                        return
                }
                sig.IssuerKeyId = new(uint64)
@@ -287,7 +282,7 @@ func parseSignatureSubpacket(sig *Signature, subpacket []byte, isHashed bool) (r
                        return
                }
                if len(subpacket) != 1 {
-                       err = error_.StructuralError("primary user id subpacket with bad length")
+                       err = errors.StructuralError("primary user id subpacket with bad length")
                        return
                }
                sig.IsPrimaryId = new(bool)
@@ -300,7 +295,7 @@ func parseSignatureSubpacket(sig *Signature, subpacket []byte, isHashed bool) (r
                        return
                }
                if len(subpacket) == 0 {
-                       err = error_.StructuralError("empty key flags subpacket")
+                       err = errors.StructuralError("empty key flags subpacket")
                        return
                }
                sig.FlagsValid = true
@@ -319,14 +314,14 @@ func parseSignatureSubpacket(sig *Signature, subpacket []byte, isHashed bool) (r
 
        default:
                if isCritical {
-                       err = error_.UnsupportedError("unknown critical signature subpacket type " + strconv.Itoa(int(packetType)))
+                       err = errors.UnsupportedError("unknown critical signature subpacket type " + strconv.Itoa(int(packetType)))
                        return
                }
        }
        return
 
 Truncated:
-       err = error_.StructuralError("signature subpacket truncated")
+       err = errors.StructuralError("signature subpacket truncated")
        return
 }
 
@@ -401,7 +396,7 @@ func (sig *Signature) buildHashSuffix() (err error) {
        sig.HashSuffix[3], ok = s2k.HashToHashId(sig.Hash)
        if !ok {
                sig.HashSuffix = nil
-               return error_.InvalidArgumentError("hash cannot be represented in OpenPGP: " + strconv.Itoa(int(sig.Hash)))
+               return errors.InvalidArgumentError("hash cannot be represented in OpenPGP: " + strconv.Itoa(int(sig.Hash)))
        }
        sig.HashSuffix[4] = byte(hashedSubpacketsLen >> 8)
        sig.HashSuffix[5] = byte(hashedSubpacketsLen)
@@ -431,7 +426,7 @@ func (sig *Signature) signPrepareHash(h hash.Hash) (digest []byte, err error) {
 // Sign signs a message with a private key. The hash, h, must contain
 // the hash of the message to be signed and will be mutated by this function.
 // On success, the signature is stored in sig. Call Serialize to write it out.
-func (sig *Signature) Sign(h hash.Hash, priv *PrivateKey) (err error) {
+func (sig *Signature) Sign(rand io.Reader, h hash.Hash, priv *PrivateKey) (err error) {
        sig.outSubpackets = sig.buildSubpackets()
        digest, err := sig.signPrepareHash(h)
        if err != nil {
@@ -440,10 +435,17 @@ func (sig *Signature) Sign(h hash.Hash, priv *PrivateKey) (err error) {
 
        switch priv.PubKeyAlgo {
        case PubKeyAlgoRSA, PubKeyAlgoRSASignOnly:
-               sig.RSASignature.bytes, err = rsa.SignPKCS1v15(rand.Reader, priv.PrivateKey.(*rsa.PrivateKey), sig.Hash, digest)
+               sig.RSASignature.bytes, err = rsa.SignPKCS1v15(rand, priv.PrivateKey.(*rsa.PrivateKey), sig.Hash, digest)
                sig.RSASignature.bitLength = uint16(8 * len(sig.RSASignature.bytes))
        case PubKeyAlgoDSA:
-               r, s, err := dsa.Sign(rand.Reader, priv.PrivateKey.(*dsa.PrivateKey), digest)
+               dsaPriv := priv.PrivateKey.(*dsa.PrivateKey)
+
+               // Need to truncate hashBytes to match FIPS 186-3 section 4.6.
+               subgroupSize := (dsaPriv.Q.BitLen() + 7) / 8
+               if len(digest) > subgroupSize {
+                       digest = digest[:subgroupSize]
+               }
+               r, s, err := dsa.Sign(rand, dsaPriv, digest)
                if err == nil {
                        sig.DSASigR.bytes = r.Bytes()
                        sig.DSASigR.bitLength = uint16(8 * len(sig.DSASigR.bytes))
@@ -451,7 +453,7 @@ func (sig *Signature) Sign(h hash.Hash, priv *PrivateKey) (err error) {
                        sig.DSASigS.bitLength = uint16(8 * len(sig.DSASigS.bytes))
                }
        default:
-               err = error_.UnsupportedError("public key algorithm: " + strconv.Itoa(int(sig.PubKeyAlgo)))
+               err = errors.UnsupportedError("public key algorithm: " + strconv.Itoa(int(sig.PubKeyAlgo)))
        }
 
        return
@@ -460,22 +462,22 @@ func (sig *Signature) Sign(h hash.Hash, priv *PrivateKey) (err error) {
 // SignUserId computes a signature from priv, asserting that pub is a valid
 // key for the identity id.  On success, the signature is stored in sig. Call
 // Serialize to write it out.
-func (sig *Signature) SignUserId(id string, pub *PublicKey, priv *PrivateKey) error {
+func (sig *Signature) SignUserId(rand io.Reader, id string, pub *PublicKey, priv *PrivateKey) error {
        h, err := userIdSignatureHash(id, pub, sig)
        if err != nil {
                return nil
        }
-       return sig.Sign(h, priv)
+       return sig.Sign(rand, h, priv)
 }
 
 // SignKey computes a signature from priv, asserting that pub is a subkey.  On
 // success, the signature is stored in sig. Call Serialize to write it out.
-func (sig *Signature) SignKey(pub *PublicKey, priv *PrivateKey) error {
+func (sig *Signature) SignKey(rand io.Reader, pub *PublicKey, priv *PrivateKey) error {
        h, err := keySignatureHash(&priv.PublicKey, pub, sig)
        if err != nil {
                return err
        }
-       return sig.Sign(h, priv)
+       return sig.Sign(rand, h, priv)
 }
 
 // Serialize marshals sig to w. SignRSA or SignDSA must have been called first.
@@ -484,7 +486,7 @@ func (sig *Signature) Serialize(w io.Writer) (err error) {
                sig.outSubpackets = sig.rawSubpackets
        }
        if sig.RSASignature.bytes == nil && sig.DSASigR.bytes == nil {
-               return error_.InvalidArgumentError("Signature: need to call SignRSA or SignDSA before Serialize")
+               return errors.InvalidArgumentError("Signature: need to call SignRSA or SignDSA before Serialize")
        }
 
        sigLength := 0
@@ -556,5 +558,54 @@ func (sig *Signature) buildSubpackets() (subpackets []outputSubpacket) {
                subpackets = append(subpackets, outputSubpacket{true, issuerSubpacket, false, keyId})
        }
 
+       if sig.SigLifetimeSecs != nil && *sig.SigLifetimeSecs != 0 {
+               sigLifetime := make([]byte, 4)
+               binary.BigEndian.PutUint32(sigLifetime, *sig.SigLifetimeSecs)
+               subpackets = append(subpackets, outputSubpacket{true, signatureExpirationSubpacket, true, sigLifetime})
+       }
+
+       // Key flags may only appear in self-signatures or certification signatures.
+
+       if sig.FlagsValid {
+               var flags byte
+               if sig.FlagCertify {
+                       flags |= 1
+               }
+               if sig.FlagSign {
+                       flags |= 2
+               }
+               if sig.FlagEncryptCommunications {
+                       flags |= 4
+               }
+               if sig.FlagEncryptStorage {
+                       flags |= 8
+               }
+               subpackets = append(subpackets, outputSubpacket{true, keyFlagsSubpacket, false, []byte{flags}})
+       }
+
+       // The following subpackets may only appear in self-signatures
+
+       if sig.KeyLifetimeSecs != nil && *sig.KeyLifetimeSecs != 0 {
+               keyLifetime := make([]byte, 4)
+               binary.BigEndian.PutUint32(keyLifetime, *sig.KeyLifetimeSecs)
+               subpackets = append(subpackets, outputSubpacket{true, keyExpirationSubpacket, true, keyLifetime})
+       }
+
+       if sig.IsPrimaryId != nil && *sig.IsPrimaryId {
+               subpackets = append(subpackets, outputSubpacket{true, primaryUserIdSubpacket, false, []byte{1}})
+       }
+
+       if len(sig.PreferredSymmetric) > 0 {
+               subpackets = append(subpackets, outputSubpacket{true, prefSymmetricAlgosSubpacket, false, sig.PreferredSymmetric})
+       }
+
+       if len(sig.PreferredHash) > 0 {
+               subpackets = append(subpackets, outputSubpacket{true, prefHashAlgosSubpacket, false, sig.PreferredHash})
+       }
+
+       if len(sig.PreferredCompression) > 0 {
+               subpackets = append(subpackets, outputSubpacket{true, prefCompressionSubpacket, false, sig.PreferredCompression})
+       }
+
        return
 }
index 76d5151379a9684c1935e492bef0e17b53219583..94e0705040112f95339a8f57ef1d7b33172f9061 100644 (file)
@@ -7,7 +7,7 @@ package packet
 import (
        "bytes"
        "crypto/cipher"
-       error_ "crypto/openpgp/error"
+       "crypto/openpgp/errors"
        "crypto/openpgp/s2k"
        "io"
        "strconv"
@@ -37,12 +37,12 @@ func (ske *SymmetricKeyEncrypted) parse(r io.Reader) (err error) {
                return
        }
        if buf[0] != symmetricKeyEncryptedVersion {
-               return error_.UnsupportedError("SymmetricKeyEncrypted version")
+               return errors.UnsupportedError("SymmetricKeyEncrypted version")
        }
        ske.CipherFunc = CipherFunction(buf[1])
 
        if ske.CipherFunc.KeySize() == 0 {
-               return error_.UnsupportedError("unknown cipher: " + strconv.Itoa(int(buf[1])))
+               return errors.UnsupportedError("unknown cipher: " + strconv.Itoa(int(buf[1])))
        }
 
        ske.s2k, err = s2k.Parse(r)
@@ -60,7 +60,7 @@ func (ske *SymmetricKeyEncrypted) parse(r io.Reader) (err error) {
        err = nil
        if n != 0 {
                if n == maxSessionKeySizeInBytes {
-                       return error_.UnsupportedError("oversized encrypted session key")
+                       return errors.UnsupportedError("oversized encrypted session key")
                }
                ske.encryptedKey = encryptedKey[:n]
        }
@@ -89,13 +89,13 @@ func (ske *SymmetricKeyEncrypted) Decrypt(passphrase []byte) error {
                c.XORKeyStream(ske.encryptedKey, ske.encryptedKey)
                ske.CipherFunc = CipherFunction(ske.encryptedKey[0])
                if ske.CipherFunc.blockSize() == 0 {
-                       return error_.UnsupportedError("unknown cipher: " + strconv.Itoa(int(ske.CipherFunc)))
+                       return errors.UnsupportedError("unknown cipher: " + strconv.Itoa(int(ske.CipherFunc)))
                }
                ske.CipherFunc = CipherFunction(ske.encryptedKey[0])
                ske.Key = ske.encryptedKey[1:]
                if len(ske.Key)%ske.CipherFunc.blockSize() != 0 {
                        ske.Key = nil
-                       return error_.StructuralError("length of decrypted key not a multiple of block size")
+                       return errors.StructuralError("length of decrypted key not a multiple of block size")
                }
        }
 
@@ -110,7 +110,7 @@ func (ske *SymmetricKeyEncrypted) Decrypt(passphrase []byte) error {
 func SerializeSymmetricKeyEncrypted(w io.Writer, rand io.Reader, passphrase []byte, cipherFunc CipherFunction) (key []byte, err error) {
        keySize := cipherFunc.KeySize()
        if keySize == 0 {
-               return nil, error_.UnsupportedError("unknown cipher: " + strconv.Itoa(int(cipherFunc)))
+               return nil, errors.UnsupportedError("unknown cipher: " + strconv.Itoa(int(cipherFunc)))
        }
 
        s2kBuf := new(bytes.Buffer)
index dff776e3eb2f0ceb9dfd881c0f11df1558af42cc..e99a23b9fb205ecc8f50152708b9b80286fe002b 100644 (file)
@@ -6,8 +6,7 @@ package packet
 
 import (
        "crypto/cipher"
-       error_ "crypto/openpgp/error"
-       "crypto/rand"
+       "crypto/openpgp/errors"
        "crypto/sha1"
        "crypto/subtle"
        "hash"
@@ -35,7 +34,7 @@ func (se *SymmetricallyEncrypted) parse(r io.Reader) error {
                        return err
                }
                if buf[0] != symmetricallyEncryptedVersion {
-                       return error_.UnsupportedError("unknown SymmetricallyEncrypted version")
+                       return errors.UnsupportedError("unknown SymmetricallyEncrypted version")
                }
        }
        se.contents = r
@@ -48,10 +47,10 @@ func (se *SymmetricallyEncrypted) parse(r io.Reader) error {
 func (se *SymmetricallyEncrypted) Decrypt(c CipherFunction, key []byte) (io.ReadCloser, error) {
        keySize := c.KeySize()
        if keySize == 0 {
-               return nil, error_.UnsupportedError("unknown cipher: " + strconv.Itoa(int(c)))
+               return nil, errors.UnsupportedError("unknown cipher: " + strconv.Itoa(int(c)))
        }
        if len(key) != keySize {
-               return nil, error_.InvalidArgumentError("SymmetricallyEncrypted: incorrect key length")
+               return nil, errors.InvalidArgumentError("SymmetricallyEncrypted: incorrect key length")
        }
 
        if se.prefix == nil {
@@ -61,7 +60,7 @@ func (se *SymmetricallyEncrypted) Decrypt(c CipherFunction, key []byte) (io.Read
                        return nil, err
                }
        } else if len(se.prefix) != c.blockSize()+2 {
-               return nil, error_.InvalidArgumentError("can't try ciphers with different block lengths")
+               return nil, errors.InvalidArgumentError("can't try ciphers with different block lengths")
        }
 
        ocfbResync := cipher.OCFBResync
@@ -72,7 +71,7 @@ func (se *SymmetricallyEncrypted) Decrypt(c CipherFunction, key []byte) (io.Read
 
        s := cipher.NewOCFBDecrypter(c.new(key), se.prefix, ocfbResync)
        if s == nil {
-               return nil, error_.KeyIncorrectError
+               return nil, errors.KeyIncorrectError
        }
 
        plaintext := cipher.StreamReader{S: s, R: se.contents}
@@ -181,7 +180,7 @@ const mdcPacketTagByte = byte(0x80) | 0x40 | 19
 
 func (ser *seMDCReader) Close() error {
        if ser.error {
-               return error_.SignatureError("error during reading")
+               return errors.SignatureError("error during reading")
        }
 
        for !ser.eof {
@@ -192,18 +191,18 @@ func (ser *seMDCReader) Close() error {
                        break
                }
                if err != nil {
-                       return error_.SignatureError("error during reading")
+                       return errors.SignatureError("error during reading")
                }
        }
 
        if ser.trailer[0] != mdcPacketTagByte || ser.trailer[1] != sha1.Size {
-               return error_.SignatureError("MDC packet not found")
+               return errors.SignatureError("MDC packet not found")
        }
        ser.h.Write(ser.trailer[:2])
 
        final := ser.h.Sum(nil)
        if subtle.ConstantTimeCompare(final, ser.trailer[2:]) != 1 {
-               return error_.SignatureError("hash mismatch")
+               return errors.SignatureError("hash mismatch")
        }
        return nil
 }
@@ -253,9 +252,9 @@ func (c noOpCloser) Close() error {
 // SerializeSymmetricallyEncrypted serializes a symmetrically encrypted packet
 // to w and returns a WriteCloser to which the to-be-encrypted packets can be
 // written.
-func SerializeSymmetricallyEncrypted(w io.Writer, c CipherFunction, key []byte) (contents io.WriteCloser, err error) {
+func SerializeSymmetricallyEncrypted(w io.Writer, rand io.Reader, c CipherFunction, key []byte) (contents io.WriteCloser, err error) {
        if c.KeySize() != len(key) {
-               return nil, error_.InvalidArgumentError("SymmetricallyEncrypted.Serialize: bad key length")
+               return nil, errors.InvalidArgumentError("SymmetricallyEncrypted.Serialize: bad key length")
        }
        writeCloser := noOpCloser{w}
        ciphertext, err := serializeStreamHeader(writeCloser, packetTypeSymmetricallyEncryptedMDC)
@@ -271,7 +270,7 @@ func SerializeSymmetricallyEncrypted(w io.Writer, c CipherFunction, key []byte)
        block := c.new(key)
        blockSize := block.BlockSize()
        iv := make([]byte, blockSize)
-       _, err = rand.Reader.Read(iv)
+       _, err = rand.Read(iv)
        if err != nil {
                return
        }
index 8eee9713983e38bca19daf10498bfe92cb369f07..f7d133d0bbeeb6fbe467c77b7cd4a344fc87afb5 100644 (file)
@@ -6,7 +6,8 @@ package packet
 
 import (
        "bytes"
-       error_ "crypto/openpgp/error"
+       "crypto/openpgp/errors"
+       "crypto/rand"
        "crypto/sha1"
        "encoding/hex"
        "io"
@@ -70,7 +71,7 @@ func testMDCReader(t *testing.T) {
        err = mdcReader.Close()
        if err == nil {
                t.Error("corruption: no error")
-       } else if _, ok := err.(*error_.SignatureError); !ok {
+       } else if _, ok := err.(*errors.SignatureError); !ok {
                t.Errorf("corruption: expected SignatureError, got: %s", err)
        }
 }
@@ -82,7 +83,7 @@ func TestSerialize(t *testing.T) {
        c := CipherAES128
        key := make([]byte, c.KeySize())
 
-       w, err := SerializeSymmetricallyEncrypted(buf, c, key)
+       w, err := SerializeSymmetricallyEncrypted(buf, rand.Reader, c, key)
        if err != nil {
                t.Errorf("error from SerializeSymmetricallyEncrypted: %s", err)
                return
index 76fb1ead9f01a0d22792d306dc3049c6ba348723..1d2343470412309278305b952cbec34b818d836f 100644 (file)
@@ -8,7 +8,7 @@ package openpgp
 import (
        "crypto"
        "crypto/openpgp/armor"
-       error_ "crypto/openpgp/error"
+       "crypto/openpgp/errors"
        "crypto/openpgp/packet"
        _ "crypto/sha256"
        "hash"
@@ -27,7 +27,7 @@ func readArmored(r io.Reader, expectedType string) (body io.Reader, err error) {
        }
 
        if block.Type != expectedType {
-               return nil, error_.InvalidArgumentError("expected '" + expectedType + "', got: " + block.Type)
+               return nil, errors.InvalidArgumentError("expected '" + expectedType + "', got: " + block.Type)
        }
 
        return block.Body, nil
@@ -130,7 +130,7 @@ ParsePackets:
                case *packet.Compressed, *packet.LiteralData, *packet.OnePassSignature:
                        // This message isn't encrypted.
                        if len(symKeys) != 0 || len(pubKeys) != 0 {
-                               return nil, error_.StructuralError("key material not followed by encrypted message")
+                               return nil, errors.StructuralError("key material not followed by encrypted message")
                        }
                        packets.Unread(p)
                        return readSignedMessage(packets, nil, keyring)
@@ -161,7 +161,7 @@ FindKey:
                                        continue
                                }
                                decrypted, err = se.Decrypt(pk.encryptedKey.CipherFunc, pk.encryptedKey.Key)
-                               if err != nil && err != error_.KeyIncorrectError {
+                               if err != nil && err != errors.KeyIncorrectError {
                                        return nil, err
                                }
                                if decrypted != nil {
@@ -179,11 +179,11 @@ FindKey:
                }
 
                if len(candidates) == 0 && len(symKeys) == 0 {
-                       return nil, error_.KeyIncorrectError
+                       return nil, errors.KeyIncorrectError
                }
 
                if prompt == nil {
-                       return nil, error_.KeyIncorrectError
+                       return nil, errors.KeyIncorrectError
                }
 
                passphrase, err := prompt(candidates, len(symKeys) != 0)
@@ -197,7 +197,7 @@ FindKey:
                                err = s.Decrypt(passphrase)
                                if err == nil && !s.Encrypted {
                                        decrypted, err = se.Decrypt(s.CipherFunc, s.Key)
-                                       if err != nil && err != error_.KeyIncorrectError {
+                                       if err != nil && err != errors.KeyIncorrectError {
                                                return nil, err
                                        }
                                        if decrypted != nil {
@@ -237,7 +237,7 @@ FindLiteralData:
                        packets.Push(p.Body)
                case *packet.OnePassSignature:
                        if !p.IsLast {
-                               return nil, error_.UnsupportedError("nested signatures")
+                               return nil, errors.UnsupportedError("nested signatures")
                        }
 
                        h, wrappedHash, err = hashForSignature(p.Hash, p.SigType)
@@ -281,7 +281,7 @@ FindLiteralData:
 func hashForSignature(hashId crypto.Hash, sigType packet.SignatureType) (hash.Hash, hash.Hash, error) {
        h := hashId.New()
        if h == nil {
-               return nil, nil, error_.UnsupportedError("hash not available: " + strconv.Itoa(int(hashId)))
+               return nil, nil, errors.UnsupportedError("hash not available: " + strconv.Itoa(int(hashId)))
        }
 
        switch sigType {
@@ -291,7 +291,7 @@ func hashForSignature(hashId crypto.Hash, sigType packet.SignatureType) (hash.Ha
                return h, NewCanonicalTextHash(h), nil
        }
 
-       return nil, nil, error_.UnsupportedError("unsupported signature type: " + strconv.Itoa(int(sigType)))
+       return nil, nil, errors.UnsupportedError("unsupported signature type: " + strconv.Itoa(int(sigType)))
 }
 
 // checkReader wraps an io.Reader from a LiteralData packet. When it sees EOF
@@ -333,7 +333,7 @@ func (scr *signatureCheckReader) Read(buf []byte) (n int, err error) {
 
                var ok bool
                if scr.md.Signature, ok = p.(*packet.Signature); !ok {
-                       scr.md.SignatureError = error_.StructuralError("LiteralData not followed by Signature")
+                       scr.md.SignatureError = errors.StructuralError("LiteralData not followed by Signature")
                        return
                }
 
@@ -363,16 +363,16 @@ func CheckDetachedSignature(keyring KeyRing, signed, signature io.Reader) (signe
 
        sig, ok := p.(*packet.Signature)
        if !ok {
-               return nil, error_.StructuralError("non signature packet found")
+               return nil, errors.StructuralError("non signature packet found")
        }
 
        if sig.IssuerKeyId == nil {
-               return nil, error_.StructuralError("signature doesn't have an issuer")
+               return nil, errors.StructuralError("signature doesn't have an issuer")
        }
 
        keys := keyring.KeysById(*sig.IssuerKeyId)
        if len(keys) == 0 {
-               return nil, error_.UnknownIssuerError
+               return nil, errors.UnknownIssuerError
        }
 
        h, wrappedHash, err := hashForSignature(sig.Hash, sig.SigType)
@@ -399,7 +399,7 @@ func CheckDetachedSignature(keyring KeyRing, signed, signature io.Reader) (signe
                return
        }
 
-       return nil, error_.UnknownIssuerError
+       return nil, errors.UnknownIssuerError
 }
 
 // CheckArmoredDetachedSignature performs the same actions as
index e8a6bf5992e820ff0838316623aededa0415f82a..d1ecad3817958cff067ffc8ff60771c2da05ec78 100644 (file)
@@ -6,7 +6,8 @@ package openpgp
 
 import (
        "bytes"
-       error_ "crypto/openpgp/error"
+       "crypto/openpgp/errors"
+       _ "crypto/sha512"
        "encoding/hex"
        "io"
        "io/ioutil"
@@ -77,6 +78,15 @@ func TestReadDSAKey(t *testing.T) {
        }
 }
 
+func TestDSAHashTruncatation(t *testing.T) {
+       // dsaKeyWithSHA512 was generated with GnuPG and --cert-digest-algo
+       // SHA512 in order to require DSA hash truncation to verify correctly.
+       _, err := ReadKeyRing(readerFromHex(dsaKeyWithSHA512))
+       if err != nil {
+               t.Error(err)
+       }
+}
+
 func TestGetKeyById(t *testing.T) {
        kring, _ := ReadKeyRing(readerFromHex(testKeys1And2Hex))
 
@@ -151,18 +161,18 @@ func TestSignedEncryptedMessage(t *testing.T) {
                prompt := func(keys []Key, symmetric bool) ([]byte, error) {
                        if symmetric {
                                t.Errorf("prompt: message was marked as symmetrically encrypted")
-                               return nil, error_.KeyIncorrectError
+                               return nil, errors.KeyIncorrectError
                        }
 
                        if len(keys) == 0 {
                                t.Error("prompt: no keys requested")
-                               return nil, error_.KeyIncorrectError
+                               return nil, errors.KeyIncorrectError
                        }
 
                        err := keys[0].PrivateKey.Decrypt([]byte("passphrase"))
                        if err != nil {
                                t.Errorf("prompt: error decrypting key: %s", err)
-                               return nil, error_.KeyIncorrectError
+                               return nil, errors.KeyIncorrectError
                        }
 
                        return nil, nil
@@ -286,7 +296,7 @@ func TestReadingArmoredPrivateKey(t *testing.T) {
 
 func TestNoArmoredData(t *testing.T) {
        _, err := ReadArmoredKeyRing(bytes.NewBufferString("foo"))
-       if _, ok := err.(error_.InvalidArgumentError); !ok {
+       if _, ok := err.(errors.InvalidArgumentError); !ok {
                t.Errorf("error was not an InvalidArgumentError: %s", err)
        }
 }
@@ -358,3 +368,5 @@ AHcVnXjtxrULkQFGbGvhKURLvS9WnzD/m1K2zzwxzkPTzT9/Yf06O6Mal5AdugPL
 VrM0m72/jnpKo04=
 =zNCn
 -----END PGP PRIVATE KEY BLOCK-----`
+
+const dsaKeyWithSHA512 = `9901a2044f04b07f110400db244efecc7316553ee08d179972aab87bb1214de7692593fcf5b6feb1c80fba268722dd464748539b85b81d574cd2d7ad0ca2444de4d849b8756bad7768c486c83a824f9bba4af773d11742bdfb4ac3b89ef8cc9452d4aad31a37e4b630d33927bff68e879284a1672659b8b298222fc68f370f3e24dccacc4a862442b9438b00a0ea444a24088dc23e26df7daf8f43cba3bffc4fe703fe3d6cd7fdca199d54ed8ae501c30e3ec7871ea9cdd4cf63cfe6fc82281d70a5b8bb493f922cd99fba5f088935596af087c8d818d5ec4d0b9afa7f070b3d7c1dd32a84fca08d8280b4890c8da1dde334de8e3cad8450eed2a4a4fcc2db7b8e5528b869a74a7f0189e11ef097ef1253582348de072bb07a9fa8ab838e993cef0ee203ff49298723e2d1f549b00559f886cd417a41692ce58d0ac1307dc71d85a8af21b0cf6eaa14baf2922d3a70389bedf17cc514ba0febbd107675a372fe84b90162a9e88b14d4b1c6be855b96b33fb198c46f058568817780435b6936167ebb3724b680f32bf27382ada2e37a879b3d9de2abe0c3f399350afd1ad438883f4791e2e3b4184453412068617368207472756e636174696f6e207465737488620413110a002205024f04b07f021b03060b090807030206150802090a0b0416020301021e01021780000a0910ef20e0cefca131581318009e2bf3bf047a44d75a9bacd00161ee04d435522397009a03a60d51bd8a568c6c021c8d7cf1be8d990d6417b0020003`
index 8bc0bb320bb675d9e6f33e0009aa6bf469c7935e..39479a1f1c6c10b887ef64a7e9b8112b8cd7b83e 100644 (file)
@@ -8,7 +8,7 @@ package s2k
 
 import (
        "crypto"
-       error_ "crypto/openpgp/error"
+       "crypto/openpgp/errors"
        "hash"
        "io"
        "strconv"
@@ -89,11 +89,11 @@ func Parse(r io.Reader) (f func(out, in []byte), err error) {
 
        hash, ok := HashIdToHash(buf[1])
        if !ok {
-               return nil, error_.UnsupportedError("hash for S2K function: " + strconv.Itoa(int(buf[1])))
+               return nil, errors.UnsupportedError("hash for S2K function: " + strconv.Itoa(int(buf[1])))
        }
        h := hash.New()
        if h == nil {
-               return nil, error_.UnsupportedError("hash not available: " + strconv.Itoa(int(hash)))
+               return nil, errors.UnsupportedError("hash not available: " + strconv.Itoa(int(hash)))
        }
 
        switch buf[0] {
@@ -123,7 +123,7 @@ func Parse(r io.Reader) (f func(out, in []byte), err error) {
                return f, nil
        }
 
-       return nil, error_.UnsupportedError("S2K function")
+       return nil, errors.UnsupportedError("S2K function")
 }
 
 // Serialize salts and stretches the given passphrase and writes the resulting
index bdee57d767c4977e225ae960cdd98b43530bdd42..73daa11312119cd2f8be27a1dcc083ef6bcf29bc 100644 (file)
@@ -7,7 +7,7 @@ package openpgp
 import (
        "crypto"
        "crypto/openpgp/armor"
-       error_ "crypto/openpgp/error"
+       "crypto/openpgp/errors"
        "crypto/openpgp/packet"
        "crypto/openpgp/s2k"
        "crypto/rand"
@@ -58,10 +58,10 @@ func armoredDetachSign(w io.Writer, signer *Entity, message io.Reader, sigType p
 
 func detachSign(w io.Writer, signer *Entity, message io.Reader, sigType packet.SignatureType) (err error) {
        if signer.PrivateKey == nil {
-               return error_.InvalidArgumentError("signing key doesn't have a private key")
+               return errors.InvalidArgumentError("signing key doesn't have a private key")
        }
        if signer.PrivateKey.Encrypted {
-               return error_.InvalidArgumentError("signing key is encrypted")
+               return errors.InvalidArgumentError("signing key is encrypted")
        }
 
        sig := new(packet.Signature)
@@ -77,7 +77,7 @@ func detachSign(w io.Writer, signer *Entity, message io.Reader, sigType packet.S
        }
        io.Copy(wrappedHash, message)
 
-       err = sig.Sign(h, signer.PrivateKey)
+       err = sig.Sign(rand.Reader, h, signer.PrivateKey)
        if err != nil {
                return
        }
@@ -111,7 +111,7 @@ func SymmetricallyEncrypt(ciphertext io.Writer, passphrase []byte, hints *FileHi
        if err != nil {
                return
        }
-       w, err := packet.SerializeSymmetricallyEncrypted(ciphertext, packet.CipherAES128, key)
+       w, err := packet.SerializeSymmetricallyEncrypted(ciphertext, rand.Reader, packet.CipherAES128, key)
        if err != nil {
                return
        }
@@ -156,7 +156,7 @@ func Encrypt(ciphertext io.Writer, to []*Entity, signed *Entity, hints *FileHint
        if signed != nil {
                signer = signed.signingKey().PrivateKey
                if signer == nil || signer.Encrypted {
-                       return nil, error_.InvalidArgumentError("signing key must be decrypted")
+                       return nil, errors.InvalidArgumentError("signing key must be decrypted")
                }
        }
 
@@ -183,7 +183,7 @@ func Encrypt(ciphertext io.Writer, to []*Entity, signed *Entity, hints *FileHint
        for i := range to {
                encryptKeys[i] = to[i].encryptionKey()
                if encryptKeys[i].PublicKey == nil {
-                       return nil, error_.InvalidArgumentError("cannot encrypt a message to key id " + strconv.FormatUint(to[i].PrimaryKey.KeyId, 16) + " because it has no encryption keys")
+                       return nil, errors.InvalidArgumentError("cannot encrypt a message to key id " + strconv.FormatUint(to[i].PrimaryKey.KeyId, 16) + " because it has no encryption keys")
                }
 
                sig := to[i].primaryIdentity().SelfSignature
@@ -201,7 +201,7 @@ func Encrypt(ciphertext io.Writer, to []*Entity, signed *Entity, hints *FileHint
        }
 
        if len(candidateCiphers) == 0 || len(candidateHashes) == 0 {
-               return nil, error_.InvalidArgumentError("cannot encrypt because recipient set shares no common algorithms")
+               return nil, errors.InvalidArgumentError("cannot encrypt because recipient set shares no common algorithms")
        }
 
        cipher := packet.CipherFunction(candidateCiphers[0])
@@ -217,7 +217,7 @@ func Encrypt(ciphertext io.Writer, to []*Entity, signed *Entity, hints *FileHint
                }
        }
 
-       encryptedData, err := packet.SerializeSymmetricallyEncrypted(ciphertext, cipher, symKey)
+       encryptedData, err := packet.SerializeSymmetricallyEncrypted(ciphertext, rand.Reader, cipher, symKey)
        if err != nil {
                return
        }
@@ -287,7 +287,7 @@ func (s signatureWriter) Close() error {
                IssuerKeyId:  &s.signer.KeyId,
        }
 
-       if err := sig.Sign(s.h, s.signer); err != nil {
+       if err := sig.Sign(rand.Reader, s.h, s.signer); err != nil {
                return err
        }
        if err := s.literalData.Close(); err != nil {
index 02fa5b75bff626300b2476ab5c1ab8cd1d83dd30..7df02e7bd134198d3c7d829387c7e2561813b21e 100644 (file)
@@ -222,7 +222,7 @@ func TestEncryption(t *testing.T) {
 
                if test.isSigned {
                        if md.SignatureError != nil {
-                               t.Errorf("#%d: signature error: %s", i, err)
+                               t.Errorf("#%d: signature error: %s", i, md.SignatureError)
                        }
                        if md.Signature == nil {
                                t.Error("signature missing")
index a461ad951b011447f6bf4ed43f88a5dc1ad8548a..25f7a920cd380e066ef84a98456113477e964ed5 100644 (file)
@@ -111,6 +111,18 @@ type ConnectionState struct {
        VerifiedChains [][]*x509.Certificate
 }
 
+// ClientAuthType declares the policy the server will follow for
+// TLS Client Authentication.
+type ClientAuthType int
+
+const (
+       NoClientCert ClientAuthType = iota
+       RequestClientCert
+       RequireAnyClientCert
+       VerifyClientCertIfGiven
+       RequireAndVerifyClientCert
+)
+
 // A Config structure is used to configure a TLS client or server. After one
 // has been passed to a TLS function it must not be modified.
 type Config struct {
@@ -120,7 +132,7 @@ type Config struct {
        Rand io.Reader
 
        // Time returns the current time as the number of seconds since the epoch.
-       // If Time is nil, TLS uses the system time.Seconds.
+       // If Time is nil, TLS uses time.Now.
        Time func() time.Time
 
        // Certificates contains one or more certificate chains
@@ -148,11 +160,14 @@ type Config struct {
        // hosting.
        ServerName string
 
-       // AuthenticateClient controls whether a server will request a certificate
-       // from the client. It does not require that the client send a
-       // certificate nor does it require that the certificate sent be
-       // anything more than self-signed.
-       AuthenticateClient bool
+       // ClientAuth determines the server's policy for
+       // TLS Client Authentication. The default is NoClientCert.
+       ClientAuth ClientAuthType
+
+       // ClientCAs defines the set of root certificate authorities
+       // that servers use if required to verify a client certificate
+       // by the policy in ClientAuth.
+       ClientCAs *x509.CertPool
 
        // InsecureSkipVerify controls whether a client verifies the
        // server's certificate chain and host name.
@@ -259,6 +274,11 @@ type Certificate struct {
        // OCSPStaple contains an optional OCSP response which will be served
        // to clients that request it.
        OCSPStaple []byte
+       // Leaf is the parsed form of the leaf certificate, which may be
+       // initialized using x509.ParseCertificate to reduce per-handshake
+       // processing for TLS clients doing client authentication. If nil, the
+       // leaf certificate will be parsed as needed.
+       Leaf *x509.Certificate
 }
 
 // A TLS record.
index c4463ff48f8bd096fa7736dd8341969a5d4be817..7c0718b82ab5765576d7d4778839afa1e964417a 100644 (file)
@@ -31,7 +31,7 @@ func main() {
                return
        }
 
-       now := time.Seconds()
+       now := time.Now()
 
        template := x509.Certificate{
                SerialNumber: new(big.Int).SetInt64(0),
@@ -39,8 +39,8 @@ func main() {
                        CommonName:   *hostName,
                        Organization: []string{"Acme Co"},
                },
-               NotBefore: time.SecondsToUTC(now - 300),
-               NotAfter:  time.SecondsToUTC(now + 60*60*24*365), // valid for 1 year.
+               NotBefore: now.Add(-5 * time.Minute).UTC(),
+               NotAfter:  now.AddDate(1, 0, 0).UTC(), // valid for 1 year.
 
                SubjectKeyId: []byte{1, 2, 3, 4},
                KeyUsage:     x509.KeyUsageKeyEncipherment | x509.KeyUsageDigitalSignature,
index 73648002bd58bd74db619cb74a2fd89ac7cc322e..632ceea9c1a2adbdfb7261cd373666770d519c6c 100644 (file)
@@ -5,12 +5,14 @@
 package tls
 
 import (
+       "bytes"
        "crypto"
        "crypto/rsa"
        "crypto/subtle"
        "crypto/x509"
        "errors"
        "io"
+       "strconv"
 )
 
 func (c *Conn) clientHandshake() error {
@@ -162,10 +164,23 @@ func (c *Conn) clientHandshake() error {
                }
        }
 
-       transmitCert := false
+       var certToSend *Certificate
        certReq, ok := msg.(*certificateRequestMsg)
        if ok {
-               // We only accept certificates with RSA keys.
+               // RFC 4346 on the certificateAuthorities field:
+               // A list of the distinguished names of acceptable certificate
+               // authorities. These distinguished names may specify a desired
+               // distinguished name for a root CA or for a subordinate CA;
+               // thus, this message can be used to describe both known roots
+               // and a desired authorization space. If the
+               // certificate_authorities list is empty then the client MAY
+               // send any certificate of the appropriate
+               // ClientCertificateType, unless there is some external
+               // arrangement to the contrary.
+
+               finishedHash.Write(certReq.marshal())
+
+               // For now, we only know how to sign challenges with RSA
                rsaAvail := false
                for _, certType := range certReq.certificateTypes {
                        if certType == certTypeRSASign {
@@ -174,23 +189,41 @@ func (c *Conn) clientHandshake() error {
                        }
                }
 
-               // For now, only send a certificate back if the server gives us an
-               // empty list of certificateAuthorities.
-               //
-               // RFC 4346 on the certificateAuthorities field:
-               // A list of the distinguished names of acceptable certificate
-               // authorities.  These distinguished names may specify a desired
-               // distinguished name for a root CA or for a subordinate CA; thus,
-               // this message can be used to describe both known roots and a
-               // desired authorization space.  If the certificate_authorities
-               // list is empty then the client MAY send any certificate of the
-               // appropriate ClientCertificateType, unless there is some
-               // external arrangement to the contrary.
-               if rsaAvail && len(certReq.certificateAuthorities) == 0 {
-                       transmitCert = true
-               }
+               // We need to search our list of client certs for one
+               // where SignatureAlgorithm is RSA and the Issuer is in
+               // certReq.certificateAuthorities
+       findCert:
+               for i, cert := range c.config.Certificates {
+                       if !rsaAvail {
+                               continue
+                       }
 
-               finishedHash.Write(certReq.marshal())
+                       leaf := cert.Leaf
+                       if leaf == nil {
+                               if leaf, err = x509.ParseCertificate(cert.Certificate[0]); err != nil {
+                                       c.sendAlert(alertInternalError)
+                                       return errors.New("tls: failed to parse client certificate #" + strconv.Itoa(i) + ": " + err.Error())
+                               }
+                       }
+
+                       if leaf.PublicKeyAlgorithm != x509.RSA {
+                               continue
+                       }
+
+                       if len(certReq.certificateAuthorities) == 0 {
+                               // they gave us an empty list, so just take the
+                               // first RSA cert from c.config.Certificates
+                               certToSend = &cert
+                               break
+                       }
+
+                       for _, ca := range certReq.certificateAuthorities {
+                               if bytes.Equal(leaf.RawIssuer, ca) {
+                                       certToSend = &cert
+                                       break findCert
+                               }
+                       }
+               }
 
                msg, err = c.readHandshake()
                if err != nil {
@@ -204,17 +237,9 @@ func (c *Conn) clientHandshake() error {
        }
        finishedHash.Write(shd.marshal())
 
-       var cert *x509.Certificate
-       if transmitCert {
+       if certToSend != nil {
                certMsg = new(certificateMsg)
-               if len(c.config.Certificates) > 0 {
-                       cert, err = x509.ParseCertificate(c.config.Certificates[0].Certificate[0])
-                       if err == nil && cert.PublicKeyAlgorithm == x509.RSA {
-                               certMsg.certificates = c.config.Certificates[0].Certificate
-                       } else {
-                               cert = nil
-                       }
-               }
+               certMsg.certificates = certToSend.Certificate
                finishedHash.Write(certMsg.marshal())
                c.writeRecord(recordTypeHandshake, certMsg.marshal())
        }
@@ -229,7 +254,7 @@ func (c *Conn) clientHandshake() error {
                c.writeRecord(recordTypeHandshake, ckx.marshal())
        }
 
-       if cert != nil {
+       if certToSend != nil {
                certVerify := new(certificateVerifyMsg)
                digest := make([]byte, 0, 36)
                digest = finishedHash.serverMD5.Sum(digest)
index 5438e749ce8b8713a7955617a4444e3947835e83..e1517cc794ff21c8eb474ab094a7e9d8e1f5b718 100644 (file)
@@ -881,9 +881,11 @@ func (m *certificateRequestMsg) marshal() (x []byte) {
 
        // See http://tools.ietf.org/html/rfc4346#section-7.4.4
        length := 1 + len(m.certificateTypes) + 2
+       casLength := 0
        for _, ca := range m.certificateAuthorities {
-               length += 2 + len(ca)
+               casLength += 2 + len(ca)
        }
+       length += casLength
 
        x = make([]byte, 4+length)
        x[0] = typeCertificateRequest
@@ -895,10 +897,8 @@ func (m *certificateRequestMsg) marshal() (x []byte) {
 
        copy(x[5:], m.certificateTypes)
        y := x[5+len(m.certificateTypes):]
-
-       numCA := len(m.certificateAuthorities)
-       y[0] = uint8(numCA >> 8)
-       y[1] = uint8(numCA)
+       y[0] = uint8(casLength >> 8)
+       y[1] = uint8(casLength)
        y = y[2:]
        for _, ca := range m.certificateAuthorities {
                y[0] = uint8(len(ca) >> 8)
@@ -909,7 +909,6 @@ func (m *certificateRequestMsg) marshal() (x []byte) {
        }
 
        m.raw = x
-
        return
 }
 
@@ -937,31 +936,34 @@ func (m *certificateRequestMsg) unmarshal(data []byte) bool {
        }
 
        data = data[numCertTypes:]
+
        if len(data) < 2 {
                return false
        }
-
-       numCAs := uint16(data[0])<<16 | uint16(data[1])
+       casLength := uint16(data[0])<<8 | uint16(data[1])
        data = data[2:]
+       if len(data) < int(casLength) {
+               return false
+       }
+       cas := make([]byte, casLength)
+       copy(cas, data)
+       data = data[casLength:]
 
-       m.certificateAuthorities = make([][]byte, numCAs)
-       for i := uint16(0); i < numCAs; i++ {
-               if len(data) < 2 {
+       m.certificateAuthorities = nil
+       for len(cas) > 0 {
+               if len(cas) < 2 {
                        return false
                }
-               caLen := uint16(data[0])<<16 | uint16(data[1])
+               caLen := uint16(cas[0])<<8 | uint16(cas[1])
+               cas = cas[2:]
 
-               data = data[2:]
-               if len(data) < int(caLen) {
+               if len(cas) < int(caLen) {
                        return false
                }
 
-               ca := make([]byte, caLen)
-               copy(ca, data)
-               m.certificateAuthorities[i] = ca
-               data = data[caLen:]
+               m.certificateAuthorities = append(m.certificateAuthorities, cas[:caLen])
+               cas = cas[caLen:]
        }
-
        if len(data) > 0 {
                return false
        }
index 89c000dd6e9b04fed616564dc864843b842d9e92..fb53767f3e08c50e6cb857e3a6a60387a30a4a3d 100644 (file)
@@ -150,14 +150,19 @@ FindCipherSuite:
                c.writeRecord(recordTypeHandshake, skx.marshal())
        }
 
-       if config.AuthenticateClient {
+       if config.ClientAuth >= RequestClientCert {
                // Request a client certificate
                certReq := new(certificateRequestMsg)
                certReq.certificateTypes = []byte{certTypeRSASign}
+
                // An empty list of certificateAuthorities signals to
                // the client that it may send any certificate in response
-               // to our request.
-
+               // to our request. When we know the CAs we trust, then
+               // we can send them down, so that the client can choose
+               // an appropriate certificate to give to us.
+               if config.ClientCAs != nil {
+                       certReq.certificateAuthorities = config.ClientCAs.Subjects()
+               }
                finishedHash.Write(certReq.marshal())
                c.writeRecord(recordTypeHandshake, certReq.marshal())
        }
@@ -166,52 +171,87 @@ FindCipherSuite:
        finishedHash.Write(helloDone.marshal())
        c.writeRecord(recordTypeHandshake, helloDone.marshal())
 
-       var pub *rsa.PublicKey
-       if config.AuthenticateClient {
-               // Get client certificate
-               msg, err = c.readHandshake()
-               if err != nil {
-                       return err
-               }
-               certMsg, ok = msg.(*certificateMsg)
-               if !ok {
-                       return c.sendAlert(alertUnexpectedMessage)
+       var pub *rsa.PublicKey // public key for client auth, if any
+
+       msg, err = c.readHandshake()
+       if err != nil {
+               return err
+       }
+
+       // If we requested a client certificate, then the client must send a
+       // certificate message, even if it's empty.
+       if config.ClientAuth >= RequestClientCert {
+               if certMsg, ok = msg.(*certificateMsg); !ok {
+                       return c.sendAlert(alertHandshakeFailure)
                }
                finishedHash.Write(certMsg.marshal())
 
+               if len(certMsg.certificates) == 0 {
+                       // The client didn't actually send a certificate
+                       switch config.ClientAuth {
+                       case RequireAnyClientCert, RequireAndVerifyClientCert:
+                               c.sendAlert(alertBadCertificate)
+                               return errors.New("tls: client didn't provide a certificate")
+                       }
+               }
+
                certs := make([]*x509.Certificate, len(certMsg.certificates))
                for i, asn1Data := range certMsg.certificates {
-                       cert, err := x509.ParseCertificate(asn1Data)
-                       if err != nil {
+                       if certs[i], err = x509.ParseCertificate(asn1Data); err != nil {
                                c.sendAlert(alertBadCertificate)
-                               return errors.New("could not parse client's certificate: " + err.Error())
+                               return errors.New("tls: failed to parse client certificate: " + err.Error())
                        }
-                       certs[i] = cert
                }
 
-               // TODO(agl): do better validation of certs: max path length, name restrictions etc.
-               for i := 1; i < len(certs); i++ {
-                       if err := certs[i-1].CheckSignatureFrom(certs[i]); err != nil {
+               if c.config.ClientAuth >= VerifyClientCertIfGiven && len(certs) > 0 {
+                       opts := x509.VerifyOptions{
+                               Roots:         c.config.ClientCAs,
+                               CurrentTime:   c.config.time(),
+                               Intermediates: x509.NewCertPool(),
+                       }
+
+                       for i, cert := range certs {
+                               if i == 0 {
+                                       continue
+                               }
+                               opts.Intermediates.AddCert(cert)
+                       }
+
+                       chains, err := certs[0].Verify(opts)
+                       if err != nil {
                                c.sendAlert(alertBadCertificate)
-                               return errors.New("could not validate certificate signature: " + err.Error())
+                               return errors.New("tls: failed to verify client's certificate: " + err.Error())
                        }
+
+                       ok := false
+                       for _, ku := range certs[0].ExtKeyUsage {
+                               if ku == x509.ExtKeyUsageClientAuth {
+                                       ok = true
+                                       break
+                               }
+                       }
+                       if !ok {
+                               c.sendAlert(alertHandshakeFailure)
+                               return errors.New("tls: client's certificate's extended key usage doesn't permit it to be used for client authentication")
+                       }
+
+                       c.verifiedChains = chains
                }
 
                if len(certs) > 0 {
-                       key, ok := certs[0].PublicKey.(*rsa.PublicKey)
-                       if !ok {
+                       if pub, ok = certs[0].PublicKey.(*rsa.PublicKey); !ok {
                                return c.sendAlert(alertUnsupportedCertificate)
                        }
-                       pub = key
                        c.peerCertificates = certs
                }
+
+               msg, err = c.readHandshake()
+               if err != nil {
+                       return err
+               }
        }
 
        // Get client key exchange
-       msg, err = c.readHandshake()
-       if err != nil {
-               return err
-       }
        ckx, ok := msg.(*clientKeyExchangeMsg)
        if !ok {
                return c.sendAlert(alertUnexpectedMessage)
index d98e13decf677596ee8bcf77c2a75ad533ef9922..4bff5327e2c7e242e558fd16655bde4189efa30b 100644 (file)
@@ -7,9 +7,12 @@ package tls
 import (
        "bytes"
        "crypto/rsa"
+       "crypto/x509"
        "encoding/hex"
+       "encoding/pem"
        "flag"
        "io"
+       "log"
        "math/big"
        "net"
        "strconv"
@@ -109,16 +112,18 @@ func TestClose(t *testing.T) {
        }
 }
 
-func testServerScript(t *testing.T, name string, serverScript [][]byte, config *Config) {
+func testServerScript(t *testing.T, name string, serverScript [][]byte, config *Config, peers []*x509.Certificate) {
        c, s := net.Pipe()
        srv := Server(s, config)
+       pchan := make(chan []*x509.Certificate, 1)
        go func() {
                srv.Write([]byte("hello, world\n"))
                srv.Close()
                s.Close()
+               st := srv.ConnectionState()
+               pchan <- st.PeerCertificates
        }()
 
-       defer c.Close()
        for i, b := range serverScript {
                if i%2 == 0 {
                        c.Write(b)
@@ -133,34 +138,66 @@ func testServerScript(t *testing.T, name string, serverScript [][]byte, config *
                        t.Fatalf("%s #%d: mismatch on read: got:%x want:%x", name, i, bb, b)
                }
        }
+       c.Close()
+
+       if peers != nil {
+               gotpeers := <-pchan
+               if len(peers) == len(gotpeers) {
+                       for i, _ := range peers {
+                               if !peers[i].Equal(gotpeers[i]) {
+                                       t.Fatalf("%s: mismatch on peer cert %d", name, i)
+                               }
+                       }
+               } else {
+                       t.Fatalf("%s: mismatch on peer list length: %d (wanted) != %d (got)", name, len(peers), len(gotpeers))
+               }
+       }
 }
 
 func TestHandshakeServerRC4(t *testing.T) {
-       testServerScript(t, "RC4", rc4ServerScript, testConfig)
+       testServerScript(t, "RC4", rc4ServerScript, testConfig, nil)
 }
 
 func TestHandshakeServer3DES(t *testing.T) {
        des3Config := new(Config)
        *des3Config = *testConfig
        des3Config.CipherSuites = []uint16{TLS_RSA_WITH_3DES_EDE_CBC_SHA}
-       testServerScript(t, "3DES", des3ServerScript, des3Config)
+       testServerScript(t, "3DES", des3ServerScript, des3Config, nil)
 }
 
 func TestHandshakeServerAES(t *testing.T) {
        aesConfig := new(Config)
        *aesConfig = *testConfig
        aesConfig.CipherSuites = []uint16{TLS_RSA_WITH_AES_128_CBC_SHA}
-       testServerScript(t, "AES", aesServerScript, aesConfig)
+       testServerScript(t, "AES", aesServerScript, aesConfig, nil)
 }
 
 func TestHandshakeServerSSLv3(t *testing.T) {
-       testServerScript(t, "SSLv3", sslv3ServerScript, testConfig)
+       testServerScript(t, "SSLv3", sslv3ServerScript, testConfig, nil)
+}
+
+type clientauthTest struct {
+       name       string
+       clientauth ClientAuthType
+       peers      []*x509.Certificate
+       script     [][]byte
+}
+
+func TestClientAuth(t *testing.T) {
+       for _, cat := range clientauthTests {
+               t.Log("running", cat.name)
+               cfg := new(Config)
+               *cfg = *testConfig
+               cfg.ClientAuth = cat.clientauth
+               testServerScript(t, cat.name, cat.script, cfg, cat.peers)
+       }
 }
 
 var serve = flag.Bool("serve", false, "run a TLS server on :10443")
 var testCipherSuites = flag.String("ciphersuites",
        "0x"+strconv.FormatInt(int64(TLS_RSA_WITH_RC4_128_SHA), 16),
        "cipher suites to accept in serving mode")
+var testClientAuth = flag.Int("clientauth", 0, "value for tls.Config.ClientAuth")
 
 func TestRunServer(t *testing.T) {
        if !*serve {
@@ -177,6 +214,8 @@ func TestRunServer(t *testing.T) {
                testConfig.CipherSuites[i] = uint16(suite)
        }
 
+       testConfig.ClientAuth = ClientAuthType(*testClientAuth)
+
        l, err := Listen("tcp", ":10443", testConfig)
        if err != nil {
                t.Fatal(err)
@@ -185,13 +224,23 @@ func TestRunServer(t *testing.T) {
        for {
                c, err := l.Accept()
                if err != nil {
+                       log.Printf("error from TLS handshake: %s", err)
                        break
                }
+
                _, err = c.Write([]byte("hello, world\n"))
                if err != nil {
-                       t.Errorf("error from TLS: %s", err)
-                       break
+                       log.Printf("error from TLS: %s", err)
+                       continue
                }
+
+               st := c.(*Conn).ConnectionState()
+               if len(st.PeerCertificates) > 0 {
+                       log.Print("Handling request from client ", st.PeerCertificates[0].Subject.CommonName)
+               } else {
+                       log.Print("Handling request from anon client")
+               }
+
                c.Close()
        }
 }
@@ -221,6 +270,18 @@ var testPrivateKey = &rsa.PrivateKey{
        },
 }
 
+func loadPEMCert(in string) *x509.Certificate {
+       block, _ := pem.Decode([]byte(in))
+       if block.Type == "CERTIFICATE" && len(block.Headers) == 0 {
+               cert, err := x509.ParseCertificate(block.Bytes)
+               if err == nil {
+                       return cert
+               }
+               panic("error parsing cert")
+       }
+       panic("error parsing PEM")
+}
+
 // Script of interaction with gnutls implementation.
 // The values for this test are obtained by building and running in server mode:
 //   % gotest -test.run "TestRunServer" -serve
@@ -229,23 +290,22 @@ var testPrivateKey = &rsa.PrivateKey{
 //   % python parse-gnutls-cli-debug-log.py < /tmp/log
 var rc4ServerScript = [][]byte{
        {
-               0x16, 0x03, 0x02, 0x00, 0x7f, 0x01, 0x00, 0x00,
-               0x7b, 0x03, 0x02, 0x4d, 0x08, 0x1f, 0x5a, 0x7a,
-               0x0a, 0x92, 0x2f, 0xf0, 0x73, 0x16, 0x3a, 0x88,
-               0x14, 0x85, 0x4c, 0x98, 0x15, 0x7b, 0x65, 0xe0,
-               0x78, 0xd0, 0xed, 0xd0, 0xf3, 0x65, 0x20, 0xeb,
-               0x80, 0xd1, 0x0b, 0x00, 0x00, 0x34, 0x00, 0x33,
+               0x16, 0x03, 0x02, 0x00, 0x7a, 0x01, 0x00, 0x00,
+               0x76, 0x03, 0x02, 0x4e, 0xdd, 0xe6, 0xa5, 0xf7,
+               0x00, 0x36, 0xf7, 0x83, 0xec, 0x93, 0x7c, 0xd2,
+               0x4d, 0xe7, 0x7b, 0xf5, 0x4c, 0xf7, 0xe3, 0x86,
+               0xe8, 0xec, 0x3b, 0xbd, 0x2c, 0x9a, 0x3f, 0x57,
+               0xf0, 0xa4, 0xd4, 0x00, 0x00, 0x34, 0x00, 0x33,
                0x00, 0x45, 0x00, 0x39, 0x00, 0x88, 0x00, 0x16,
                0x00, 0x32, 0x00, 0x44, 0x00, 0x38, 0x00, 0x87,
                0x00, 0x13, 0x00, 0x66, 0x00, 0x90, 0x00, 0x91,
                0x00, 0x8f, 0x00, 0x8e, 0x00, 0x2f, 0x00, 0x41,
                0x00, 0x35, 0x00, 0x84, 0x00, 0x0a, 0x00, 0x05,
                0x00, 0x04, 0x00, 0x8c, 0x00, 0x8d, 0x00, 0x8b,
-               0x00, 0x8a, 0x01, 0x00, 0x00, 0x1e, 0x00, 0x09,
+               0x00, 0x8a, 0x01, 0x00, 0x00, 0x19, 0x00, 0x09,
                0x00, 0x03, 0x02, 0x00, 0x01, 0x00, 0x00, 0x00,
                0x0e, 0x00, 0x0c, 0x00, 0x00, 0x09, 0x6c, 0x6f,
-               0x63, 0x61, 0x6c, 0x68, 0x6f, 0x73, 0x74, 0xff,
-               0x01, 0x00, 0x01, 0x00,
+               0x63, 0x61, 0x6c, 0x68, 0x6f, 0x73, 0x74,
        },
 
        {
@@ -349,38 +409,46 @@ var rc4ServerScript = [][]byte{
 
        {
                0x16, 0x03, 0x01, 0x00, 0x86, 0x10, 0x00, 0x00,
-               0x82, 0x00, 0x80, 0x3c, 0x13, 0xd7, 0x12, 0xc1,
-               0x6a, 0xf0, 0x3f, 0x8c, 0xa1, 0x35, 0x5d, 0xc5,
-               0x89, 0x1e, 0x9e, 0xcd, 0x32, 0xc7, 0x9e, 0xe6,
-               0xae, 0xd5, 0xf1, 0xbf, 0x70, 0xd7, 0xa9, 0xef,
-               0x2c, 0x4c, 0xf4, 0x22, 0xbc, 0x17, 0x17, 0xaa,
-               0x05, 0xf3, 0x9f, 0x80, 0xf2, 0xe9, 0x82, 0x2f,
-               0x2a, 0x15, 0x54, 0x0d, 0x16, 0x0e, 0x77, 0x4c,
-               0x28, 0x3c, 0x03, 0x2d, 0x2d, 0xd7, 0xc8, 0x64,
-               0xd9, 0x59, 0x4b, 0x1c, 0xf4, 0xde, 0xff, 0x2f,
-               0xbc, 0x94, 0xaf, 0x18, 0x26, 0x37, 0xce, 0x4f,
-               0x84, 0x74, 0x2e, 0x45, 0x66, 0x7c, 0x0c, 0x54,
-               0x46, 0x36, 0x5f, 0x65, 0x21, 0x7b, 0x83, 0x8c,
-               0x6d, 0x76, 0xcd, 0x0d, 0x9f, 0xda, 0x1c, 0xa4,
-               0x6e, 0xfe, 0xb1, 0xf7, 0x09, 0x0d, 0xfb, 0x74,
-               0x66, 0x34, 0x99, 0x89, 0x7f, 0x5f, 0x77, 0x87,
-               0x4a, 0x66, 0x4b, 0xa9, 0x59, 0x57, 0xe3, 0x56,
-               0x0d, 0xdd, 0xd8, 0x14, 0x03, 0x01, 0x00, 0x01,
-               0x01, 0x16, 0x03, 0x01, 0x00, 0x24, 0xc0, 0x4e,
-               0xd3, 0x0f, 0xb5, 0xc0, 0x57, 0xa6, 0x18, 0x80,
-               0x80, 0x6b, 0x49, 0xfe, 0xbd, 0x3a, 0x7a, 0x2c,
-               0xef, 0x70, 0xb5, 0x1c, 0xd2, 0xdf, 0x5f, 0x78,
-               0x5a, 0xd8, 0x4f, 0xa0, 0x95, 0xb4, 0xb3, 0xb5,
-               0xaa, 0x3b,
+               0x82, 0x00, 0x80, 0x39, 0xe2, 0x0f, 0x49, 0xa0,
+               0xe6, 0xe4, 0x3b, 0x0c, 0x5f, 0xce, 0x39, 0x97,
+               0x6c, 0xb6, 0x41, 0xd9, 0xe1, 0x52, 0x8f, 0x43,
+               0xb3, 0xc6, 0x4f, 0x9a, 0xe2, 0x1e, 0xb9, 0x3b,
+               0xe3, 0x72, 0x17, 0x68, 0xb2, 0x0d, 0x7b, 0x71,
+               0x33, 0x96, 0x5c, 0xf9, 0xfe, 0x18, 0x8f, 0x2f,
+               0x2b, 0x82, 0xec, 0x03, 0xf2, 0x16, 0xa8, 0xf8,
+               0x39, 0xf9, 0xbb, 0x5a, 0xd3, 0x0c, 0xc1, 0x2a,
+               0x52, 0xa1, 0x90, 0x20, 0x6b, 0x24, 0xc9, 0x55,
+               0xee, 0x05, 0xd8, 0xb3, 0x43, 0x58, 0xf6, 0x7f,
+               0x68, 0x2d, 0xb3, 0xd1, 0x1b, 0x30, 0xaa, 0xdf,
+               0xfc, 0x85, 0xf1, 0xab, 0x14, 0x51, 0x91, 0x78,
+               0x29, 0x35, 0x65, 0xe0, 0x9c, 0xf6, 0xb7, 0x35,
+               0x33, 0xdb, 0x28, 0x93, 0x4d, 0x86, 0xbc, 0xfe,
+               0xaa, 0xd1, 0xc0, 0x2e, 0x4d, 0xec, 0xa2, 0x98,
+               0xca, 0x08, 0xb2, 0x91, 0x14, 0xde, 0x97, 0x3a,
+               0xc4, 0x6b, 0x49, 0x14, 0x03, 0x01, 0x00, 0x01,
+               0x01, 0x16, 0x03, 0x01, 0x00, 0x24, 0x7a, 0xcb,
+               0x3b, 0x0e, 0xbb, 0x7a, 0x56, 0x39, 0xaf, 0x83,
+               0xae, 0xfd, 0x25, 0xfd, 0x64, 0xb4, 0x0c, 0x0c,
+               0x17, 0x46, 0x54, 0x2c, 0x6a, 0x07, 0x83, 0xc6,
+               0x46, 0x08, 0x0b, 0xcd, 0x15, 0x53, 0xef, 0x40,
+               0x4e, 0x56,
        },
 
        {
                0x14, 0x03, 0x01, 0x00, 0x01, 0x01, 0x16, 0x03,
-               0x01, 0x00, 0x24, 0x9d, 0xc9, 0xda, 0xdf, 0xeb,
-               0xc8, 0xdb, 0xf8, 0x94, 0xa5, 0xef, 0xd5, 0xfc,
-               0x89, 0x01, 0x64, 0x30, 0x77, 0x5a, 0x18, 0x4b,
-               0x16, 0x79, 0x9c, 0xf6, 0xf5, 0x09, 0x22, 0x12,
-               0x4c, 0x3e, 0xa8, 0x8e, 0x91, 0xa5, 0x24,
+               0x01, 0x00, 0x24, 0xd3, 0x72, 0xeb, 0x29, 0xb9,
+               0x15, 0x29, 0xb5, 0xe5, 0xb7, 0xef, 0x5c, 0xb2,
+               0x9d, 0xf6, 0xc8, 0x47, 0xd6, 0xa0, 0x84, 0xf0,
+               0x8c, 0xcb, 0xe6, 0xbe, 0xbc, 0xfb, 0x38, 0x90,
+               0x89, 0x60, 0xa2, 0xe8, 0xaa, 0xb3, 0x12, 0x17,
+               0x03, 0x01, 0x00, 0x21, 0x67, 0x4a, 0x3d, 0x31,
+               0x6c, 0x5a, 0x1c, 0xf9, 0x6e, 0xf1, 0xd8, 0x12,
+               0x0e, 0xb9, 0xfd, 0xfc, 0x66, 0x91, 0xd1, 0x1d,
+               0x6e, 0xe4, 0x55, 0xdd, 0x11, 0xb9, 0xb8, 0xa2,
+               0x65, 0xa1, 0x95, 0x64, 0x1c, 0x15, 0x03, 0x01,
+               0x00, 0x16, 0x9b, 0xa0, 0x24, 0xe3, 0xcb, 0xae,
+               0xad, 0x51, 0xb3, 0x63, 0x59, 0x78, 0x49, 0x24,
+               0x06, 0x6e, 0xee, 0x7a, 0xd7, 0x74, 0x53, 0x04,
        },
 }
 
@@ -878,3 +946,625 @@ var sslv3ServerScript = [][]byte{
                0xaf, 0xd3, 0xb7, 0xa3, 0xcc, 0x4a, 0x1d, 0x2e,
        },
 }
+
+var clientauthTests = []clientauthTest{
+       // Server doesn't asks for cert
+       // gotest -test.run "TestRunServer" -serve -clientauth 0
+       // gnutls-cli --insecure --debug 100 -p 10443 localhost 2>&1 |
+       //   python parse-gnutls-cli-debug-log.py
+       {"NoClientCert", NoClientCert, nil,
+               [][]byte{{
+                       0x16, 0x03, 0x02, 0x00, 0x7a, 0x01, 0x00, 0x00,
+                       0x76, 0x03, 0x02, 0x4e, 0xe0, 0x92, 0x5d, 0xcd,
+                       0xfe, 0x0c, 0x69, 0xd4, 0x7d, 0x8e, 0xa6, 0x88,
+                       0xde, 0x72, 0x04, 0x29, 0x6a, 0x4a, 0x16, 0x23,
+                       0xd7, 0x8f, 0xbc, 0xfa, 0x80, 0x73, 0x2e, 0x12,
+                       0xb7, 0x0b, 0x39, 0x00, 0x00, 0x34, 0x00, 0x33,
+                       0x00, 0x45, 0x00, 0x39, 0x00, 0x88, 0x00, 0x16,
+                       0x00, 0x32, 0x00, 0x44, 0x00, 0x38, 0x00, 0x87,
+                       0x00, 0x13, 0x00, 0x66, 0x00, 0x90, 0x00, 0x91,
+                       0x00, 0x8f, 0x00, 0x8e, 0x00, 0x2f, 0x00, 0x41,
+                       0x00, 0x35, 0x00, 0x84, 0x00, 0x0a, 0x00, 0x05,
+                       0x00, 0x04, 0x00, 0x8c, 0x00, 0x8d, 0x00, 0x8b,
+                       0x00, 0x8a, 0x01, 0x00, 0x00, 0x19, 0x00, 0x09,
+                       0x00, 0x03, 0x02, 0x00, 0x01, 0x00, 0x00, 0x00,
+                       0x0e, 0x00, 0x0c, 0x00, 0x00, 0x09, 0x6c, 0x6f,
+                       0x63, 0x61, 0x6c, 0x68, 0x6f, 0x73, 0x74,
+               },
+
+                       {
+                               0x16, 0x03, 0x01, 0x00, 0x2a, 0x02, 0x00, 0x00,
+                               0x26, 0x03, 0x01, 0x00, 0x00, 0x00, 0x00, 0x00,
+                               0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00,
+                               0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00,
+                               0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00,
+                               0x00, 0x00, 0x00, 0x00, 0x00, 0x05, 0x00, 0x16,
+                               0x03, 0x01, 0x02, 0xbe, 0x0b, 0x00, 0x02, 0xba,
+                               0x00, 0x02, 0xb7, 0x00, 0x02, 0xb4, 0x30, 0x82,
+                               0x02, 0xb0, 0x30, 0x82, 0x02, 0x19, 0xa0, 0x03,
+                               0x02, 0x01, 0x02, 0x02, 0x09, 0x00, 0x85, 0xb0,
+                               0xbb, 0xa4, 0x8a, 0x7f, 0xb8, 0xca, 0x30, 0x0d,
+                               0x06, 0x09, 0x2a, 0x86, 0x48, 0x86, 0xf7, 0x0d,
+                               0x01, 0x01, 0x05, 0x05, 0x00, 0x30, 0x45, 0x31,
+                               0x0b, 0x30, 0x09, 0x06, 0x03, 0x55, 0x04, 0x06,
+                               0x13, 0x02, 0x41, 0x55, 0x31, 0x13, 0x30, 0x11,
+                               0x06, 0x03, 0x55, 0x04, 0x08, 0x13, 0x0a, 0x53,
+                               0x6f, 0x6d, 0x65, 0x2d, 0x53, 0x74, 0x61, 0x74,
+                               0x65, 0x31, 0x21, 0x30, 0x1f, 0x06, 0x03, 0x55,
+                               0x04, 0x0a, 0x13, 0x18, 0x49, 0x6e, 0x74, 0x65,
+                               0x72, 0x6e, 0x65, 0x74, 0x20, 0x57, 0x69, 0x64,
+                               0x67, 0x69, 0x74, 0x73, 0x20, 0x50, 0x74, 0x79,
+                               0x20, 0x4c, 0x74, 0x64, 0x30, 0x1e, 0x17, 0x0d,
+                               0x31, 0x30, 0x30, 0x34, 0x32, 0x34, 0x30, 0x39,
+                               0x30, 0x39, 0x33, 0x38, 0x5a, 0x17, 0x0d, 0x31,
+                               0x31, 0x30, 0x34, 0x32, 0x34, 0x30, 0x39, 0x30,
+                               0x39, 0x33, 0x38, 0x5a, 0x30, 0x45, 0x31, 0x0b,
+                               0x30, 0x09, 0x06, 0x03, 0x55, 0x04, 0x06, 0x13,
+                               0x02, 0x41, 0x55, 0x31, 0x13, 0x30, 0x11, 0x06,
+                               0x03, 0x55, 0x04, 0x08, 0x13, 0x0a, 0x53, 0x6f,
+                               0x6d, 0x65, 0x2d, 0x53, 0x74, 0x61, 0x74, 0x65,
+                               0x31, 0x21, 0x30, 0x1f, 0x06, 0x03, 0x55, 0x04,
+                               0x0a, 0x13, 0x18, 0x49, 0x6e, 0x74, 0x65, 0x72,
+                               0x6e, 0x65, 0x74, 0x20, 0x57, 0x69, 0x64, 0x67,
+                               0x69, 0x74, 0x73, 0x20, 0x50, 0x74, 0x79, 0x20,
+                               0x4c, 0x74, 0x64, 0x30, 0x81, 0x9f, 0x30, 0x0d,
+                               0x06, 0x09, 0x2a, 0x86, 0x48, 0x86, 0xf7, 0x0d,
+                               0x01, 0x01, 0x01, 0x05, 0x00, 0x03, 0x81, 0x8d,
+                               0x00, 0x30, 0x81, 0x89, 0x02, 0x81, 0x81, 0x00,
+                               0xbb, 0x79, 0xd6, 0xf5, 0x17, 0xb5, 0xe5, 0xbf,
+                               0x46, 0x10, 0xd0, 0xdc, 0x69, 0xbe, 0xe6, 0x2b,
+                               0x07, 0x43, 0x5a, 0xd0, 0x03, 0x2d, 0x8a, 0x7a,
+                               0x43, 0x85, 0xb7, 0x14, 0x52, 0xe7, 0xa5, 0x65,
+                               0x4c, 0x2c, 0x78, 0xb8, 0x23, 0x8c, 0xb5, 0xb4,
+                               0x82, 0xe5, 0xde, 0x1f, 0x95, 0x3b, 0x7e, 0x62,
+                               0xa5, 0x2c, 0xa5, 0x33, 0xd6, 0xfe, 0x12, 0x5c,
+                               0x7a, 0x56, 0xfc, 0xf5, 0x06, 0xbf, 0xfa, 0x58,
+                               0x7b, 0x26, 0x3f, 0xb5, 0xcd, 0x04, 0xd3, 0xd0,
+                               0xc9, 0x21, 0x96, 0x4a, 0xc7, 0xf4, 0x54, 0x9f,
+                               0x5a, 0xbf, 0xef, 0x42, 0x71, 0x00, 0xfe, 0x18,
+                               0x99, 0x07, 0x7f, 0x7e, 0x88, 0x7d, 0x7d, 0xf1,
+                               0x04, 0x39, 0xc4, 0xa2, 0x2e, 0xdb, 0x51, 0xc9,
+                               0x7c, 0xe3, 0xc0, 0x4c, 0x3b, 0x32, 0x66, 0x01,
+                               0xcf, 0xaf, 0xb1, 0x1d, 0xb8, 0x71, 0x9a, 0x1d,
+                               0xdb, 0xdb, 0x89, 0x6b, 0xae, 0xda, 0x2d, 0x79,
+                               0x02, 0x03, 0x01, 0x00, 0x01, 0xa3, 0x81, 0xa7,
+                               0x30, 0x81, 0xa4, 0x30, 0x1d, 0x06, 0x03, 0x55,
+                               0x1d, 0x0e, 0x04, 0x16, 0x04, 0x14, 0xb1, 0xad,
+                               0xe2, 0x85, 0x5a, 0xcf, 0xcb, 0x28, 0xdb, 0x69,
+                               0xce, 0x23, 0x69, 0xde, 0xd3, 0x26, 0x8e, 0x18,
+                               0x88, 0x39, 0x30, 0x75, 0x06, 0x03, 0x55, 0x1d,
+                               0x23, 0x04, 0x6e, 0x30, 0x6c, 0x80, 0x14, 0xb1,
+                               0xad, 0xe2, 0x85, 0x5a, 0xcf, 0xcb, 0x28, 0xdb,
+                               0x69, 0xce, 0x23, 0x69, 0xde, 0xd3, 0x26, 0x8e,
+                               0x18, 0x88, 0x39, 0xa1, 0x49, 0xa4, 0x47, 0x30,
+                               0x45, 0x31, 0x0b, 0x30, 0x09, 0x06, 0x03, 0x55,
+                               0x04, 0x06, 0x13, 0x02, 0x41, 0x55, 0x31, 0x13,
+                               0x30, 0x11, 0x06, 0x03, 0x55, 0x04, 0x08, 0x13,
+                               0x0a, 0x53, 0x6f, 0x6d, 0x65, 0x2d, 0x53, 0x74,
+                               0x61, 0x74, 0x65, 0x31, 0x21, 0x30, 0x1f, 0x06,
+                               0x03, 0x55, 0x04, 0x0a, 0x13, 0x18, 0x49, 0x6e,
+                               0x74, 0x65, 0x72, 0x6e, 0x65, 0x74, 0x20, 0x57,
+                               0x69, 0x64, 0x67, 0x69, 0x74, 0x73, 0x20, 0x50,
+                               0x74, 0x79, 0x20, 0x4c, 0x74, 0x64, 0x82, 0x09,
+                               0x00, 0x85, 0xb0, 0xbb, 0xa4, 0x8a, 0x7f, 0xb8,
+                               0xca, 0x30, 0x0c, 0x06, 0x03, 0x55, 0x1d, 0x13,
+                               0x04, 0x05, 0x30, 0x03, 0x01, 0x01, 0xff, 0x30,
+                               0x0d, 0x06, 0x09, 0x2a, 0x86, 0x48, 0x86, 0xf7,
+                               0x0d, 0x01, 0x01, 0x05, 0x05, 0x00, 0x03, 0x81,
+                               0x81, 0x00, 0x08, 0x6c, 0x45, 0x24, 0xc7, 0x6b,
+                               0xb1, 0x59, 0xab, 0x0c, 0x52, 0xcc, 0xf2, 0xb0,
+                               0x14, 0xd7, 0x87, 0x9d, 0x7a, 0x64, 0x75, 0xb5,
+                               0x5a, 0x95, 0x66, 0xe4, 0xc5, 0x2b, 0x8e, 0xae,
+                               0x12, 0x66, 0x1f, 0xeb, 0x4f, 0x38, 0xb3, 0x6e,
+                               0x60, 0xd3, 0x92, 0xfd, 0xf7, 0x41, 0x08, 0xb5,
+                               0x25, 0x13, 0xb1, 0x18, 0x7a, 0x24, 0xfb, 0x30,
+                               0x1d, 0xba, 0xed, 0x98, 0xb9, 0x17, 0xec, 0xe7,
+                               0xd7, 0x31, 0x59, 0xdb, 0x95, 0xd3, 0x1d, 0x78,
+                               0xea, 0x50, 0x56, 0x5c, 0xd5, 0x82, 0x5a, 0x2d,
+                               0x5a, 0x5f, 0x33, 0xc4, 0xb6, 0xd8, 0xc9, 0x75,
+                               0x90, 0x96, 0x8c, 0x0f, 0x52, 0x98, 0xb5, 0xcd,
+                               0x98, 0x1f, 0x89, 0x20, 0x5f, 0xf2, 0xa0, 0x1c,
+                               0xa3, 0x1b, 0x96, 0x94, 0xdd, 0xa9, 0xfd, 0x57,
+                               0xe9, 0x70, 0xe8, 0x26, 0x6d, 0x71, 0x99, 0x9b,
+                               0x26, 0x6e, 0x38, 0x50, 0x29, 0x6c, 0x90, 0xa7,
+                               0xbd, 0xd9, 0x16, 0x03, 0x01, 0x00, 0x04, 0x0e,
+                               0x00, 0x00, 0x00,
+                       },
+
+                       {
+                               0x16, 0x03, 0x01, 0x00, 0x86, 0x10, 0x00, 0x00,
+                               0x82, 0x00, 0x80, 0x10, 0xe1, 0x00, 0x3d, 0x0a,
+                               0x6b, 0x02, 0x7f, 0x97, 0xde, 0xfb, 0x65, 0x46,
+                               0x1a, 0x50, 0x4e, 0x34, 0x9a, 0xae, 0x14, 0x7e,
+                               0xec, 0xef, 0x85, 0x15, 0x3b, 0x39, 0xc2, 0x45,
+                               0x04, 0x40, 0x92, 0x71, 0xd6, 0x7e, 0xf6, 0xfd,
+                               0x4d, 0x84, 0xf7, 0xc4, 0x77, 0x99, 0x3d, 0xe2,
+                               0xc3, 0x8d, 0xb0, 0x4c, 0x74, 0xc8, 0x51, 0xec,
+                               0xb2, 0xe8, 0x6b, 0xa1, 0xd2, 0x4d, 0xd8, 0x61,
+                               0x92, 0x7a, 0x24, 0x57, 0x44, 0x4f, 0xa2, 0x1e,
+                               0x74, 0x0b, 0x06, 0x4b, 0x80, 0x34, 0x8b, 0xfe,
+                               0xc2, 0x0e, 0xc1, 0xcd, 0xab, 0x0c, 0x3f, 0x54,
+                               0xe2, 0x44, 0xe9, 0x6c, 0x2b, 0xba, 0x7b, 0x64,
+                               0xf1, 0x93, 0x65, 0x75, 0xf2, 0x35, 0xff, 0x27,
+                               0x03, 0xd5, 0x64, 0xe6, 0x8e, 0xe7, 0x7b, 0x56,
+                               0xb6, 0x61, 0x73, 0xeb, 0xa2, 0xdc, 0xa4, 0x6e,
+                               0x52, 0xac, 0xbc, 0xba, 0x11, 0xa3, 0xd2, 0x61,
+                               0x4a, 0xe0, 0xbb, 0x14, 0x03, 0x01, 0x00, 0x01,
+                               0x01, 0x16, 0x03, 0x01, 0x00, 0x24, 0xd2, 0x5a,
+                               0x0c, 0x2a, 0x27, 0x96, 0xba, 0xa9, 0x67, 0xd2,
+                               0x51, 0x68, 0x32, 0x68, 0x22, 0x1f, 0xb9, 0x27,
+                               0x79, 0x59, 0x28, 0xdf, 0x38, 0x1f, 0x92, 0x21,
+                               0x5d, 0x0f, 0xf4, 0xc0, 0xee, 0xb7, 0x10, 0x5a,
+                               0xa9, 0x45,
+                       },
+
+                       {
+                               0x14, 0x03, 0x01, 0x00, 0x01, 0x01, 0x16, 0x03,
+                               0x01, 0x00, 0x24, 0x13, 0x6f, 0x6c, 0x71, 0x83,
+                               0x59, 0xcf, 0x32, 0x72, 0xe9, 0xce, 0xcc, 0x7a,
+                               0x6c, 0xf0, 0x72, 0x39, 0x16, 0xae, 0x40, 0x61,
+                               0xfa, 0x92, 0x4c, 0xe7, 0xf2, 0x1a, 0xd7, 0x0c,
+                               0x84, 0x76, 0x6c, 0xe9, 0x11, 0x43, 0x19, 0x17,
+                               0x03, 0x01, 0x00, 0x21, 0xc0, 0xa2, 0x13, 0x28,
+                               0x94, 0x8c, 0x5c, 0xd6, 0x79, 0xb9, 0xfe, 0xae,
+                               0x45, 0x4b, 0xc0, 0x7c, 0xae, 0x2d, 0xb4, 0x0d,
+                               0x31, 0xc4, 0xad, 0x22, 0xd7, 0x1e, 0x99, 0x1c,
+                               0x4c, 0x69, 0xab, 0x42, 0x61, 0x15, 0x03, 0x01,
+                               0x00, 0x16, 0xe1, 0x0c, 0x67, 0xf3, 0xf4, 0xb9,
+                               0x8e, 0x81, 0x8e, 0x01, 0xb8, 0xa0, 0x69, 0x8c,
+                               0x03, 0x11, 0x43, 0x3e, 0xee, 0xb7, 0x4d, 0x69,
+                       }}},
+       // Server asks for cert with empty CA list, client doesn't give it.
+       // gotest -test.run "TestRunServer" -serve -clientauth 1
+       // gnutls-cli --insecure --debug 100 -p 10443 localhost
+       {"RequestClientCert, none given", RequestClientCert, nil,
+               [][]byte{{
+                       0x16, 0x03, 0x02, 0x00, 0x7a, 0x01, 0x00, 0x00,
+                       0x76, 0x03, 0x02, 0x4e, 0xe0, 0x93, 0xe2, 0x47,
+                       0x06, 0xa0, 0x61, 0x0c, 0x51, 0xdd, 0xf0, 0xef,
+                       0xf4, 0x30, 0x72, 0xe1, 0xa6, 0x50, 0x68, 0x82,
+                       0x3c, 0xfb, 0xcb, 0x72, 0x5e, 0x73, 0x9d, 0xda,
+                       0x27, 0x35, 0x72, 0x00, 0x00, 0x34, 0x00, 0x33,
+                       0x00, 0x45, 0x00, 0x39, 0x00, 0x88, 0x00, 0x16,
+                       0x00, 0x32, 0x00, 0x44, 0x00, 0x38, 0x00, 0x87,
+                       0x00, 0x13, 0x00, 0x66, 0x00, 0x90, 0x00, 0x91,
+                       0x00, 0x8f, 0x00, 0x8e, 0x00, 0x2f, 0x00, 0x41,
+                       0x00, 0x35, 0x00, 0x84, 0x00, 0x0a, 0x00, 0x05,
+                       0x00, 0x04, 0x00, 0x8c, 0x00, 0x8d, 0x00, 0x8b,
+                       0x00, 0x8a, 0x01, 0x00, 0x00, 0x19, 0x00, 0x09,
+                       0x00, 0x03, 0x02, 0x00, 0x01, 0x00, 0x00, 0x00,
+                       0x0e, 0x00, 0x0c, 0x00, 0x00, 0x09, 0x6c, 0x6f,
+                       0x63, 0x61, 0x6c, 0x68, 0x6f, 0x73, 0x74,
+               },
+
+                       {
+                               0x16, 0x03, 0x01, 0x00, 0x2a, 0x02, 0x00, 0x00,
+                               0x26, 0x03, 0x01, 0x00, 0x00, 0x00, 0x00, 0x00,
+                               0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00,
+                               0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00,
+                               0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00,
+                               0x00, 0x00, 0x00, 0x00, 0x00, 0x05, 0x00, 0x16,
+                               0x03, 0x01, 0x02, 0xbe, 0x0b, 0x00, 0x02, 0xba,
+                               0x00, 0x02, 0xb7, 0x00, 0x02, 0xb4, 0x30, 0x82,
+                               0x02, 0xb0, 0x30, 0x82, 0x02, 0x19, 0xa0, 0x03,
+                               0x02, 0x01, 0x02, 0x02, 0x09, 0x00, 0x85, 0xb0,
+                               0xbb, 0xa4, 0x8a, 0x7f, 0xb8, 0xca, 0x30, 0x0d,
+                               0x06, 0x09, 0x2a, 0x86, 0x48, 0x86, 0xf7, 0x0d,
+                               0x01, 0x01, 0x05, 0x05, 0x00, 0x30, 0x45, 0x31,
+                               0x0b, 0x30, 0x09, 0x06, 0x03, 0x55, 0x04, 0x06,
+                               0x13, 0x02, 0x41, 0x55, 0x31, 0x13, 0x30, 0x11,
+                               0x06, 0x03, 0x55, 0x04, 0x08, 0x13, 0x0a, 0x53,
+                               0x6f, 0x6d, 0x65, 0x2d, 0x53, 0x74, 0x61, 0x74,
+                               0x65, 0x31, 0x21, 0x30, 0x1f, 0x06, 0x03, 0x55,
+                               0x04, 0x0a, 0x13, 0x18, 0x49, 0x6e, 0x74, 0x65,
+                               0x72, 0x6e, 0x65, 0x74, 0x20, 0x57, 0x69, 0x64,
+                               0x67, 0x69, 0x74, 0x73, 0x20, 0x50, 0x74, 0x79,
+                               0x20, 0x4c, 0x74, 0x64, 0x30, 0x1e, 0x17, 0x0d,
+                               0x31, 0x30, 0x30, 0x34, 0x32, 0x34, 0x30, 0x39,
+                               0x30, 0x39, 0x33, 0x38, 0x5a, 0x17, 0x0d, 0x31,
+                               0x31, 0x30, 0x34, 0x32, 0x34, 0x30, 0x39, 0x30,
+                               0x39, 0x33, 0x38, 0x5a, 0x30, 0x45, 0x31, 0x0b,
+                               0x30, 0x09, 0x06, 0x03, 0x55, 0x04, 0x06, 0x13,
+                               0x02, 0x41, 0x55, 0x31, 0x13, 0x30, 0x11, 0x06,
+                               0x03, 0x55, 0x04, 0x08, 0x13, 0x0a, 0x53, 0x6f,
+                               0x6d, 0x65, 0x2d, 0x53, 0x74, 0x61, 0x74, 0x65,
+                               0x31, 0x21, 0x30, 0x1f, 0x06, 0x03, 0x55, 0x04,
+                               0x0a, 0x13, 0x18, 0x49, 0x6e, 0x74, 0x65, 0x72,
+                               0x6e, 0x65, 0x74, 0x20, 0x57, 0x69, 0x64, 0x67,
+                               0x69, 0x74, 0x73, 0x20, 0x50, 0x74, 0x79, 0x20,
+                               0x4c, 0x74, 0x64, 0x30, 0x81, 0x9f, 0x30, 0x0d,
+                               0x06, 0x09, 0x2a, 0x86, 0x48, 0x86, 0xf7, 0x0d,
+                               0x01, 0x01, 0x01, 0x05, 0x00, 0x03, 0x81, 0x8d,
+                               0x00, 0x30, 0x81, 0x89, 0x02, 0x81, 0x81, 0x00,
+                               0xbb, 0x79, 0xd6, 0xf5, 0x17, 0xb5, 0xe5, 0xbf,
+                               0x46, 0x10, 0xd0, 0xdc, 0x69, 0xbe, 0xe6, 0x2b,
+                               0x07, 0x43, 0x5a, 0xd0, 0x03, 0x2d, 0x8a, 0x7a,
+                               0x43, 0x85, 0xb7, 0x14, 0x52, 0xe7, 0xa5, 0x65,
+                               0x4c, 0x2c, 0x78, 0xb8, 0x23, 0x8c, 0xb5, 0xb4,
+                               0x82, 0xe5, 0xde, 0x1f, 0x95, 0x3b, 0x7e, 0x62,
+                               0xa5, 0x2c, 0xa5, 0x33, 0xd6, 0xfe, 0x12, 0x5c,
+                               0x7a, 0x56, 0xfc, 0xf5, 0x06, 0xbf, 0xfa, 0x58,
+                               0x7b, 0x26, 0x3f, 0xb5, 0xcd, 0x04, 0xd3, 0xd0,
+                               0xc9, 0x21, 0x96, 0x4a, 0xc7, 0xf4, 0x54, 0x9f,
+                               0x5a, 0xbf, 0xef, 0x42, 0x71, 0x00, 0xfe, 0x18,
+                               0x99, 0x07, 0x7f, 0x7e, 0x88, 0x7d, 0x7d, 0xf1,
+                               0x04, 0x39, 0xc4, 0xa2, 0x2e, 0xdb, 0x51, 0xc9,
+                               0x7c, 0xe3, 0xc0, 0x4c, 0x3b, 0x32, 0x66, 0x01,
+                               0xcf, 0xaf, 0xb1, 0x1d, 0xb8, 0x71, 0x9a, 0x1d,
+                               0xdb, 0xdb, 0x89, 0x6b, 0xae, 0xda, 0x2d, 0x79,
+                               0x02, 0x03, 0x01, 0x00, 0x01, 0xa3, 0x81, 0xa7,
+                               0x30, 0x81, 0xa4, 0x30, 0x1d, 0x06, 0x03, 0x55,
+                               0x1d, 0x0e, 0x04, 0x16, 0x04, 0x14, 0xb1, 0xad,
+                               0xe2, 0x85, 0x5a, 0xcf, 0xcb, 0x28, 0xdb, 0x69,
+                               0xce, 0x23, 0x69, 0xde, 0xd3, 0x26, 0x8e, 0x18,
+                               0x88, 0x39, 0x30, 0x75, 0x06, 0x03, 0x55, 0x1d,
+                               0x23, 0x04, 0x6e, 0x30, 0x6c, 0x80, 0x14, 0xb1,
+                               0xad, 0xe2, 0x85, 0x5a, 0xcf, 0xcb, 0x28, 0xdb,
+                               0x69, 0xce, 0x23, 0x69, 0xde, 0xd3, 0x26, 0x8e,
+                               0x18, 0x88, 0x39, 0xa1, 0x49, 0xa4, 0x47, 0x30,
+                               0x45, 0x31, 0x0b, 0x30, 0x09, 0x06, 0x03, 0x55,
+                               0x04, 0x06, 0x13, 0x02, 0x41, 0x55, 0x31, 0x13,
+                               0x30, 0x11, 0x06, 0x03, 0x55, 0x04, 0x08, 0x13,
+                               0x0a, 0x53, 0x6f, 0x6d, 0x65, 0x2d, 0x53, 0x74,
+                               0x61, 0x74, 0x65, 0x31, 0x21, 0x30, 0x1f, 0x06,
+                               0x03, 0x55, 0x04, 0x0a, 0x13, 0x18, 0x49, 0x6e,
+                               0x74, 0x65, 0x72, 0x6e, 0x65, 0x74, 0x20, 0x57,
+                               0x69, 0x64, 0x67, 0x69, 0x74, 0x73, 0x20, 0x50,
+                               0x74, 0x79, 0x20, 0x4c, 0x74, 0x64, 0x82, 0x09,
+                               0x00, 0x85, 0xb0, 0xbb, 0xa4, 0x8a, 0x7f, 0xb8,
+                               0xca, 0x30, 0x0c, 0x06, 0x03, 0x55, 0x1d, 0x13,
+                               0x04, 0x05, 0x30, 0x03, 0x01, 0x01, 0xff, 0x30,
+                               0x0d, 0x06, 0x09, 0x2a, 0x86, 0x48, 0x86, 0xf7,
+                               0x0d, 0x01, 0x01, 0x05, 0x05, 0x00, 0x03, 0x81,
+                               0x81, 0x00, 0x08, 0x6c, 0x45, 0x24, 0xc7, 0x6b,
+                               0xb1, 0x59, 0xab, 0x0c, 0x52, 0xcc, 0xf2, 0xb0,
+                               0x14, 0xd7, 0x87, 0x9d, 0x7a, 0x64, 0x75, 0xb5,
+                               0x5a, 0x95, 0x66, 0xe4, 0xc5, 0x2b, 0x8e, 0xae,
+                               0x12, 0x66, 0x1f, 0xeb, 0x4f, 0x38, 0xb3, 0x6e,
+                               0x60, 0xd3, 0x92, 0xfd, 0xf7, 0x41, 0x08, 0xb5,
+                               0x25, 0x13, 0xb1, 0x18, 0x7a, 0x24, 0xfb, 0x30,
+                               0x1d, 0xba, 0xed, 0x98, 0xb9, 0x17, 0xec, 0xe7,
+                               0xd7, 0x31, 0x59, 0xdb, 0x95, 0xd3, 0x1d, 0x78,
+                               0xea, 0x50, 0x56, 0x5c, 0xd5, 0x82, 0x5a, 0x2d,
+                               0x5a, 0x5f, 0x33, 0xc4, 0xb6, 0xd8, 0xc9, 0x75,
+                               0x90, 0x96, 0x8c, 0x0f, 0x52, 0x98, 0xb5, 0xcd,
+                               0x98, 0x1f, 0x89, 0x20, 0x5f, 0xf2, 0xa0, 0x1c,
+                               0xa3, 0x1b, 0x96, 0x94, 0xdd, 0xa9, 0xfd, 0x57,
+                               0xe9, 0x70, 0xe8, 0x26, 0x6d, 0x71, 0x99, 0x9b,
+                               0x26, 0x6e, 0x38, 0x50, 0x29, 0x6c, 0x90, 0xa7,
+                               0xbd, 0xd9, 0x16, 0x03, 0x01, 0x00, 0x08, 0x0d,
+                               0x00, 0x00, 0x04, 0x01, 0x01, 0x00, 0x00, 0x16,
+                               0x03, 0x01, 0x00, 0x04, 0x0e, 0x00, 0x00, 0x00,
+                       },
+
+                       {
+                               0x16, 0x03, 0x01, 0x00, 0x07, 0x0b, 0x00, 0x00,
+                               0x03, 0x00, 0x00, 0x00, 0x16, 0x03, 0x01, 0x00,
+                               0x86, 0x10, 0x00, 0x00, 0x82, 0x00, 0x80, 0x64,
+                               0x28, 0xb9, 0x3f, 0x48, 0xaf, 0x06, 0x22, 0x39,
+                               0x56, 0xd8, 0x6f, 0x63, 0x5d, 0x03, 0x48, 0x63,
+                               0x01, 0x13, 0xa2, 0xd6, 0x76, 0xc0, 0xab, 0xda,
+                               0x25, 0x30, 0x75, 0x6c, 0xaa, 0xb4, 0xdc, 0x35,
+                               0x72, 0xdc, 0xf2, 0x43, 0xe4, 0x1d, 0x82, 0xfb,
+                               0x6c, 0x64, 0xe2, 0xa7, 0x8f, 0x32, 0x67, 0x6b,
+                               0xcd, 0xd2, 0xb2, 0x36, 0x94, 0xbc, 0x6f, 0x46,
+                               0x79, 0x29, 0x42, 0xe3, 0x1a, 0xbf, 0xfb, 0x41,
+                               0xd5, 0xe3, 0xb4, 0x2a, 0xf6, 0x95, 0x6f, 0x0c,
+                               0x87, 0xb9, 0x03, 0x18, 0xa1, 0xea, 0x4a, 0xe2,
+                               0x2e, 0x0f, 0x50, 0x00, 0xc1, 0xe8, 0x8c, 0xc8,
+                               0xa2, 0xf6, 0xa4, 0x05, 0xf4, 0x38, 0x3e, 0xd9,
+                               0x6e, 0x63, 0x96, 0x0c, 0x34, 0x73, 0x90, 0x03,
+                               0x55, 0xa6, 0x34, 0xb0, 0x5e, 0x8c, 0x48, 0x40,
+                               0x25, 0x45, 0x84, 0xa6, 0x21, 0x3f, 0x81, 0x97,
+                               0xa7, 0x11, 0x09, 0x14, 0x95, 0xa5, 0xe5, 0x14,
+                               0x03, 0x01, 0x00, 0x01, 0x01, 0x16, 0x03, 0x01,
+                               0x00, 0x24, 0x16, 0xaa, 0x01, 0x2c, 0xa8, 0xc1,
+                               0x28, 0xaf, 0x35, 0xc1, 0xc1, 0xf3, 0x0a, 0x25,
+                               0x66, 0x6e, 0x27, 0x11, 0xa3, 0xa4, 0xd9, 0xe9,
+                               0xea, 0x15, 0x09, 0x9d, 0x28, 0xe3, 0x5b, 0x2b,
+                               0xa6, 0x25, 0xa7, 0x14, 0x24, 0x3a,
+                       },
+
+                       {
+                               0x14, 0x03, 0x01, 0x00, 0x01, 0x01, 0x16, 0x03,
+                               0x01, 0x00, 0x24, 0x9a, 0xa8, 0xd6, 0x77, 0x46,
+                               0x45, 0x68, 0x9d, 0x5d, 0xa9, 0x68, 0x03, 0xe5,
+                               0xaf, 0xe8, 0xc8, 0x21, 0xc5, 0xc6, 0xc1, 0x50,
+                               0xe0, 0xd8, 0x52, 0xce, 0xa3, 0x4f, 0x2d, 0xf4,
+                               0xe3, 0xa7, 0x7d, 0x35, 0x80, 0x84, 0x12, 0x17,
+                               0x03, 0x01, 0x00, 0x21, 0x8a, 0x82, 0x0c, 0x54,
+                               0x1b, 0xeb, 0x77, 0x90, 0x2c, 0x3e, 0xbc, 0xf0,
+                               0x23, 0xcc, 0xa8, 0x9f, 0x25, 0x08, 0x12, 0xed,
+                               0x43, 0xf1, 0xf9, 0x06, 0xad, 0xa9, 0x4b, 0x97,
+                               0x82, 0xb7, 0xc4, 0x0b, 0x4c, 0x15, 0x03, 0x01,
+                               0x00, 0x16, 0x05, 0x2d, 0x9d, 0x45, 0x03, 0xb7,
+                               0xc2, 0xd1, 0xb5, 0x1a, 0x43, 0xcf, 0x1a, 0x37,
+                               0xf4, 0x70, 0xcc, 0xb4, 0xed, 0x07, 0x76, 0x3a,
+                       }}},
+       // Server asks for cert with empty CA list, client gives one
+       // gotest -test.run "TestRunServer" -serve -clientauth 1
+       // gnutls-cli --insecure --debug 100 -p 10443 localhost
+       {"RequestClientCert, client gives it", RequestClientCert,
+               []*x509.Certificate{clicert},
+               [][]byte{{
+                       0x16, 0x03, 0x02, 0x00, 0x7a, 0x01, 0x00, 0x00,
+                       0x76, 0x03, 0x02, 0x4e, 0xe7, 0x44, 0xda, 0x58,
+                       0x7d, 0x46, 0x4a, 0x48, 0x97, 0x9f, 0xe5, 0x91,
+                       0x11, 0x64, 0xa7, 0x1e, 0x4d, 0xb7, 0xfe, 0x9b,
+                       0xc6, 0x63, 0xf8, 0xa4, 0xb5, 0x0b, 0x18, 0xb5,
+                       0xbd, 0x19, 0xb3, 0x00, 0x00, 0x34, 0x00, 0x33,
+                       0x00, 0x45, 0x00, 0x39, 0x00, 0x88, 0x00, 0x16,
+                       0x00, 0x32, 0x00, 0x44, 0x00, 0x38, 0x00, 0x87,
+                       0x00, 0x13, 0x00, 0x66, 0x00, 0x90, 0x00, 0x91,
+                       0x00, 0x8f, 0x00, 0x8e, 0x00, 0x2f, 0x00, 0x41,
+                       0x00, 0x35, 0x00, 0x84, 0x00, 0x0a, 0x00, 0x05,
+                       0x00, 0x04, 0x00, 0x8c, 0x00, 0x8d, 0x00, 0x8b,
+                       0x00, 0x8a, 0x01, 0x00, 0x00, 0x19, 0x00, 0x09,
+                       0x00, 0x03, 0x02, 0x00, 0x01, 0x00, 0x00, 0x00,
+                       0x0e, 0x00, 0x0c, 0x00, 0x00, 0x09, 0x6c, 0x6f,
+                       0x63, 0x61, 0x6c, 0x68, 0x6f, 0x73, 0x74,
+               },
+
+                       {
+                               0x16, 0x03, 0x01, 0x00, 0x2a, 0x02, 0x00, 0x00,
+                               0x26, 0x03, 0x01, 0x00, 0x00, 0x00, 0x00, 0x00,
+                               0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00,
+                               0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00,
+                               0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00,
+                               0x00, 0x00, 0x00, 0x00, 0x00, 0x05, 0x00, 0x16,
+                               0x03, 0x01, 0x02, 0xbe, 0x0b, 0x00, 0x02, 0xba,
+                               0x00, 0x02, 0xb7, 0x00, 0x02, 0xb4, 0x30, 0x82,
+                               0x02, 0xb0, 0x30, 0x82, 0x02, 0x19, 0xa0, 0x03,
+                               0x02, 0x01, 0x02, 0x02, 0x09, 0x00, 0x85, 0xb0,
+                               0xbb, 0xa4, 0x8a, 0x7f, 0xb8, 0xca, 0x30, 0x0d,
+                               0x06, 0x09, 0x2a, 0x86, 0x48, 0x86, 0xf7, 0x0d,
+                               0x01, 0x01, 0x05, 0x05, 0x00, 0x30, 0x45, 0x31,
+                               0x0b, 0x30, 0x09, 0x06, 0x03, 0x55, 0x04, 0x06,
+                               0x13, 0x02, 0x41, 0x55, 0x31, 0x13, 0x30, 0x11,
+                               0x06, 0x03, 0x55, 0x04, 0x08, 0x13, 0x0a, 0x53,
+                               0x6f, 0x6d, 0x65, 0x2d, 0x53, 0x74, 0x61, 0x74,
+                               0x65, 0x31, 0x21, 0x30, 0x1f, 0x06, 0x03, 0x55,
+                               0x04, 0x0a, 0x13, 0x18, 0x49, 0x6e, 0x74, 0x65,
+                               0x72, 0x6e, 0x65, 0x74, 0x20, 0x57, 0x69, 0x64,
+                               0x67, 0x69, 0x74, 0x73, 0x20, 0x50, 0x74, 0x79,
+                               0x20, 0x4c, 0x74, 0x64, 0x30, 0x1e, 0x17, 0x0d,
+                               0x31, 0x30, 0x30, 0x34, 0x32, 0x34, 0x30, 0x39,
+                               0x30, 0x39, 0x33, 0x38, 0x5a, 0x17, 0x0d, 0x31,
+                               0x31, 0x30, 0x34, 0x32, 0x34, 0x30, 0x39, 0x30,
+                               0x39, 0x33, 0x38, 0x5a, 0x30, 0x45, 0x31, 0x0b,
+                               0x30, 0x09, 0x06, 0x03, 0x55, 0x04, 0x06, 0x13,
+                               0x02, 0x41, 0x55, 0x31, 0x13, 0x30, 0x11, 0x06,
+                               0x03, 0x55, 0x04, 0x08, 0x13, 0x0a, 0x53, 0x6f,
+                               0x6d, 0x65, 0x2d, 0x53, 0x74, 0x61, 0x74, 0x65,
+                               0x31, 0x21, 0x30, 0x1f, 0x06, 0x03, 0x55, 0x04,
+                               0x0a, 0x13, 0x18, 0x49, 0x6e, 0x74, 0x65, 0x72,
+                               0x6e, 0x65, 0x74, 0x20, 0x57, 0x69, 0x64, 0x67,
+                               0x69, 0x74, 0x73, 0x20, 0x50, 0x74, 0x79, 0x20,
+                               0x4c, 0x74, 0x64, 0x30, 0x81, 0x9f, 0x30, 0x0d,
+                               0x06, 0x09, 0x2a, 0x86, 0x48, 0x86, 0xf7, 0x0d,
+                               0x01, 0x01, 0x01, 0x05, 0x00, 0x03, 0x81, 0x8d,
+                               0x00, 0x30, 0x81, 0x89, 0x02, 0x81, 0x81, 0x00,
+                               0xbb, 0x79, 0xd6, 0xf5, 0x17, 0xb5, 0xe5, 0xbf,
+                               0x46, 0x10, 0xd0, 0xdc, 0x69, 0xbe, 0xe6, 0x2b,
+                               0x07, 0x43, 0x5a, 0xd0, 0x03, 0x2d, 0x8a, 0x7a,
+                               0x43, 0x85, 0xb7, 0x14, 0x52, 0xe7, 0xa5, 0x65,
+                               0x4c, 0x2c, 0x78, 0xb8, 0x23, 0x8c, 0xb5, 0xb4,
+                               0x82, 0xe5, 0xde, 0x1f, 0x95, 0x3b, 0x7e, 0x62,
+                               0xa5, 0x2c, 0xa5, 0x33, 0xd6, 0xfe, 0x12, 0x5c,
+                               0x7a, 0x56, 0xfc, 0xf5, 0x06, 0xbf, 0xfa, 0x58,
+                               0x7b, 0x26, 0x3f, 0xb5, 0xcd, 0x04, 0xd3, 0xd0,
+                               0xc9, 0x21, 0x96, 0x4a, 0xc7, 0xf4, 0x54, 0x9f,
+                               0x5a, 0xbf, 0xef, 0x42, 0x71, 0x00, 0xfe, 0x18,
+                               0x99, 0x07, 0x7f, 0x7e, 0x88, 0x7d, 0x7d, 0xf1,
+                               0x04, 0x39, 0xc4, 0xa2, 0x2e, 0xdb, 0x51, 0xc9,
+                               0x7c, 0xe3, 0xc0, 0x4c, 0x3b, 0x32, 0x66, 0x01,
+                               0xcf, 0xaf, 0xb1, 0x1d, 0xb8, 0x71, 0x9a, 0x1d,
+                               0xdb, 0xdb, 0x89, 0x6b, 0xae, 0xda, 0x2d, 0x79,
+                               0x02, 0x03, 0x01, 0x00, 0x01, 0xa3, 0x81, 0xa7,
+                               0x30, 0x81, 0xa4, 0x30, 0x1d, 0x06, 0x03, 0x55,
+                               0x1d, 0x0e, 0x04, 0x16, 0x04, 0x14, 0xb1, 0xad,
+                               0xe2, 0x85, 0x5a, 0xcf, 0xcb, 0x28, 0xdb, 0x69,
+                               0xce, 0x23, 0x69, 0xde, 0xd3, 0x26, 0x8e, 0x18,
+                               0x88, 0x39, 0x30, 0x75, 0x06, 0x03, 0x55, 0x1d,
+                               0x23, 0x04, 0x6e, 0x30, 0x6c, 0x80, 0x14, 0xb1,
+                               0xad, 0xe2, 0x85, 0x5a, 0xcf, 0xcb, 0x28, 0xdb,
+                               0x69, 0xce, 0x23, 0x69, 0xde, 0xd3, 0x26, 0x8e,
+                               0x18, 0x88, 0x39, 0xa1, 0x49, 0xa4, 0x47, 0x30,
+                               0x45, 0x31, 0x0b, 0x30, 0x09, 0x06, 0x03, 0x55,
+                               0x04, 0x06, 0x13, 0x02, 0x41, 0x55, 0x31, 0x13,
+                               0x30, 0x11, 0x06, 0x03, 0x55, 0x04, 0x08, 0x13,
+                               0x0a, 0x53, 0x6f, 0x6d, 0x65, 0x2d, 0x53, 0x74,
+                               0x61, 0x74, 0x65, 0x31, 0x21, 0x30, 0x1f, 0x06,
+                               0x03, 0x55, 0x04, 0x0a, 0x13, 0x18, 0x49, 0x6e,
+                               0x74, 0x65, 0x72, 0x6e, 0x65, 0x74, 0x20, 0x57,
+                               0x69, 0x64, 0x67, 0x69, 0x74, 0x73, 0x20, 0x50,
+                               0x74, 0x79, 0x20, 0x4c, 0x74, 0x64, 0x82, 0x09,
+                               0x00, 0x85, 0xb0, 0xbb, 0xa4, 0x8a, 0x7f, 0xb8,
+                               0xca, 0x30, 0x0c, 0x06, 0x03, 0x55, 0x1d, 0x13,
+                               0x04, 0x05, 0x30, 0x03, 0x01, 0x01, 0xff, 0x30,
+                               0x0d, 0x06, 0x09, 0x2a, 0x86, 0x48, 0x86, 0xf7,
+                               0x0d, 0x01, 0x01, 0x05, 0x05, 0x00, 0x03, 0x81,
+                               0x81, 0x00, 0x08, 0x6c, 0x45, 0x24, 0xc7, 0x6b,
+                               0xb1, 0x59, 0xab, 0x0c, 0x52, 0xcc, 0xf2, 0xb0,
+                               0x14, 0xd7, 0x87, 0x9d, 0x7a, 0x64, 0x75, 0xb5,
+                               0x5a, 0x95, 0x66, 0xe4, 0xc5, 0x2b, 0x8e, 0xae,
+                               0x12, 0x66, 0x1f, 0xeb, 0x4f, 0x38, 0xb3, 0x6e,
+                               0x60, 0xd3, 0x92, 0xfd, 0xf7, 0x41, 0x08, 0xb5,
+                               0x25, 0x13, 0xb1, 0x18, 0x7a, 0x24, 0xfb, 0x30,
+                               0x1d, 0xba, 0xed, 0x98, 0xb9, 0x17, 0xec, 0xe7,
+                               0xd7, 0x31, 0x59, 0xdb, 0x95, 0xd3, 0x1d, 0x78,
+                               0xea, 0x50, 0x56, 0x5c, 0xd5, 0x82, 0x5a, 0x2d,
+                               0x5a, 0x5f, 0x33, 0xc4, 0xb6, 0xd8, 0xc9, 0x75,
+                               0x90, 0x96, 0x8c, 0x0f, 0x52, 0x98, 0xb5, 0xcd,
+                               0x98, 0x1f, 0x89, 0x20, 0x5f, 0xf2, 0xa0, 0x1c,
+                               0xa3, 0x1b, 0x96, 0x94, 0xdd, 0xa9, 0xfd, 0x57,
+                               0xe9, 0x70, 0xe8, 0x26, 0x6d, 0x71, 0x99, 0x9b,
+                               0x26, 0x6e, 0x38, 0x50, 0x29, 0x6c, 0x90, 0xa7,
+                               0xbd, 0xd9, 0x16, 0x03, 0x01, 0x00, 0x08, 0x0d,
+                               0x00, 0x00, 0x04, 0x01, 0x01, 0x00, 0x00, 0x16,
+                               0x03, 0x01, 0x00, 0x04, 0x0e, 0x00, 0x00, 0x00,
+                       },
+
+                       {
+                               0x16, 0x03, 0x01, 0x01, 0xfb, 0x0b, 0x00, 0x01,
+                               0xf7, 0x00, 0x01, 0xf4, 0x00, 0x01, 0xf1, 0x30,
+                               0x82, 0x01, 0xed, 0x30, 0x82, 0x01, 0x58, 0xa0,
+                               0x03, 0x02, 0x01, 0x02, 0x02, 0x01, 0x00, 0x30,
+                               0x0b, 0x06, 0x09, 0x2a, 0x86, 0x48, 0x86, 0xf7,
+                               0x0d, 0x01, 0x01, 0x05, 0x30, 0x26, 0x31, 0x10,
+                               0x30, 0x0e, 0x06, 0x03, 0x55, 0x04, 0x0a, 0x13,
+                               0x07, 0x41, 0x63, 0x6d, 0x65, 0x20, 0x43, 0x6f,
+                               0x31, 0x12, 0x30, 0x10, 0x06, 0x03, 0x55, 0x04,
+                               0x03, 0x13, 0x09, 0x31, 0x32, 0x37, 0x2e, 0x30,
+                               0x2e, 0x30, 0x2e, 0x31, 0x30, 0x1e, 0x17, 0x0d,
+                               0x31, 0x31, 0x31, 0x32, 0x30, 0x38, 0x30, 0x37,
+                               0x35, 0x35, 0x31, 0x32, 0x5a, 0x17, 0x0d, 0x31,
+                               0x32, 0x31, 0x32, 0x30, 0x37, 0x30, 0x38, 0x30,
+                               0x30, 0x31, 0x32, 0x5a, 0x30, 0x26, 0x31, 0x10,
+                               0x30, 0x0e, 0x06, 0x03, 0x55, 0x04, 0x0a, 0x13,
+                               0x07, 0x41, 0x63, 0x6d, 0x65, 0x20, 0x43, 0x6f,
+                               0x31, 0x12, 0x30, 0x10, 0x06, 0x03, 0x55, 0x04,
+                               0x03, 0x13, 0x09, 0x31, 0x32, 0x37, 0x2e, 0x30,
+                               0x2e, 0x30, 0x2e, 0x31, 0x30, 0x81, 0x9c, 0x30,
+                               0x0b, 0x06, 0x09, 0x2a, 0x86, 0x48, 0x86, 0xf7,
+                               0x0d, 0x01, 0x01, 0x01, 0x03, 0x81, 0x8c, 0x00,
+                               0x30, 0x81, 0x88, 0x02, 0x81, 0x80, 0x4e, 0xd0,
+                               0x7b, 0x31, 0xe3, 0x82, 0x64, 0xd9, 0x59, 0xc0,
+                               0xc2, 0x87, 0xa4, 0x5e, 0x1e, 0x8b, 0x73, 0x33,
+                               0xc7, 0x63, 0x53, 0xdf, 0x66, 0x92, 0x06, 0x84,
+                               0xf6, 0x64, 0xd5, 0x8f, 0xe4, 0x36, 0xa7, 0x1d,
+                               0x2b, 0xe8, 0xb3, 0x20, 0x36, 0x45, 0x23, 0xb5,
+                               0xe3, 0x95, 0xae, 0xed, 0xe0, 0xf5, 0x20, 0x9c,
+                               0x8d, 0x95, 0xdf, 0x7f, 0x5a, 0x12, 0xef, 0x87,
+                               0xe4, 0x5b, 0x68, 0xe4, 0xe9, 0x0e, 0x74, 0xec,
+                               0x04, 0x8a, 0x7f, 0xde, 0x93, 0x27, 0xc4, 0x01,
+                               0x19, 0x7a, 0xbd, 0xf2, 0xdc, 0x3d, 0x14, 0xab,
+                               0xd0, 0x54, 0xca, 0x21, 0x0c, 0xd0, 0x4d, 0x6e,
+                               0x87, 0x2e, 0x5c, 0xc5, 0xd2, 0xbb, 0x4d, 0x4b,
+                               0x4f, 0xce, 0xb6, 0x2c, 0xf7, 0x7e, 0x88, 0xec,
+                               0x7c, 0xd7, 0x02, 0x91, 0x74, 0xa6, 0x1e, 0x0c,
+                               0x1a, 0xda, 0xe3, 0x4a, 0x5a, 0x2e, 0xde, 0x13,
+                               0x9c, 0x4c, 0x40, 0x88, 0x59, 0x93, 0x02, 0x03,
+                               0x01, 0x00, 0x01, 0xa3, 0x32, 0x30, 0x30, 0x30,
+                               0x0e, 0x06, 0x03, 0x55, 0x1d, 0x0f, 0x01, 0x01,
+                               0xff, 0x04, 0x04, 0x03, 0x02, 0x00, 0xa0, 0x30,
+                               0x0d, 0x06, 0x03, 0x55, 0x1d, 0x0e, 0x04, 0x06,
+                               0x04, 0x04, 0x01, 0x02, 0x03, 0x04, 0x30, 0x0f,
+                               0x06, 0x03, 0x55, 0x1d, 0x23, 0x04, 0x08, 0x30,
+                               0x06, 0x80, 0x04, 0x01, 0x02, 0x03, 0x04, 0x30,
+                               0x0b, 0x06, 0x09, 0x2a, 0x86, 0x48, 0x86, 0xf7,
+                               0x0d, 0x01, 0x01, 0x05, 0x03, 0x81, 0x81, 0x00,
+                               0x36, 0x1f, 0xb3, 0x7a, 0x0c, 0x75, 0xc9, 0x6e,
+                               0x37, 0x46, 0x61, 0x2b, 0xd5, 0xbd, 0xc0, 0xa7,
+                               0x4b, 0xcc, 0x46, 0x9a, 0x81, 0x58, 0x7c, 0x85,
+                               0x79, 0x29, 0xc8, 0xc8, 0xc6, 0x67, 0xdd, 0x32,
+                               0x56, 0x45, 0x2b, 0x75, 0xb6, 0xe9, 0x24, 0xa9,
+                               0x50, 0x9a, 0xbe, 0x1f, 0x5a, 0xfa, 0x1a, 0x15,
+                               0xd9, 0xcc, 0x55, 0x95, 0x72, 0x16, 0x83, 0xb9,
+                               0xc2, 0xb6, 0x8f, 0xfd, 0x88, 0x8c, 0x38, 0x84,
+                               0x1d, 0xab, 0x5d, 0x92, 0x31, 0x13, 0x4f, 0xfd,
+                               0x83, 0x3b, 0xc6, 0x9d, 0xf1, 0x11, 0x62, 0xb6,
+                               0x8b, 0xec, 0xab, 0x67, 0xbe, 0xc8, 0x64, 0xb0,
+                               0x11, 0x50, 0x46, 0x58, 0x17, 0x6b, 0x99, 0x1c,
+                               0xd3, 0x1d, 0xfc, 0x06, 0xf1, 0x0e, 0xe5, 0x96,
+                               0xa8, 0x0c, 0xf9, 0x78, 0x20, 0xb7, 0x44, 0x18,
+                               0x51, 0x8d, 0x10, 0x7e, 0x4f, 0x94, 0x67, 0xdf,
+                               0xa3, 0x4e, 0x70, 0x73, 0x8e, 0x90, 0x91, 0x85,
+                               0x16, 0x03, 0x01, 0x00, 0x86, 0x10, 0x00, 0x00,
+                               0x82, 0x00, 0x80, 0xa7, 0x2f, 0xed, 0xfa, 0xc2,
+                               0xbd, 0x46, 0xa1, 0xf2, 0x69, 0xc5, 0x1d, 0xa1,
+                               0x34, 0xd6, 0xd0, 0x84, 0xf5, 0x5d, 0x8c, 0x82,
+                               0x8d, 0x98, 0x82, 0x9c, 0xd9, 0x07, 0xe0, 0xf7,
+                               0x55, 0x49, 0x4d, 0xa1, 0x48, 0x59, 0x02, 0xd3,
+                               0x84, 0x37, 0xaf, 0x01, 0xb3, 0x3a, 0xf4, 0xed,
+                               0x99, 0xbe, 0x67, 0x36, 0x19, 0x55, 0xf3, 0xf9,
+                               0xcb, 0x94, 0xe5, 0x7b, 0x8b, 0x77, 0xf2, 0x5f,
+                               0x4c, 0xfe, 0x01, 0x1f, 0x7b, 0xd7, 0x23, 0x49,
+                               0x0c, 0xcb, 0x6c, 0xb0, 0xe7, 0x77, 0xd6, 0xcf,
+                               0xa8, 0x7d, 0xdb, 0xa7, 0x14, 0xe2, 0xf5, 0xf3,
+                               0xff, 0xba, 0x23, 0xd2, 0x9a, 0x36, 0x14, 0x60,
+                               0x2a, 0x91, 0x5d, 0x2b, 0x35, 0x3b, 0xb6, 0xdd,
+                               0xcb, 0x6b, 0xdc, 0x18, 0xdc, 0x33, 0xb8, 0xb3,
+                               0xc7, 0x27, 0x7e, 0xfc, 0xd2, 0xf7, 0x97, 0x90,
+                               0x5e, 0x17, 0xac, 0x14, 0x8e, 0x0f, 0xca, 0xb5,
+                               0x6f, 0xc9, 0x2d, 0x16, 0x03, 0x01, 0x00, 0x86,
+                               0x0f, 0x00, 0x00, 0x82, 0x00, 0x80, 0x44, 0x7f,
+                               0xa2, 0x59, 0x60, 0x0b, 0x5a, 0xc4, 0xaf, 0x1e,
+                               0x60, 0xa5, 0x24, 0xea, 0xc1, 0xc3, 0x22, 0x21,
+                               0x6b, 0x22, 0x8b, 0x2a, 0x11, 0x82, 0x68, 0x7d,
+                               0xb9, 0xdd, 0x9c, 0x27, 0x4c, 0xc2, 0xc8, 0xa2,
+                               0x8b, 0x6b, 0x77, 0x8d, 0x3a, 0x2b, 0x8d, 0x2f,
+                               0x6a, 0x2b, 0x43, 0xd2, 0xd1, 0xc6, 0x41, 0x79,
+                               0xa2, 0x4f, 0x2b, 0xc2, 0xf7, 0xb2, 0x10, 0xad,
+                               0xa6, 0x01, 0x51, 0x51, 0x25, 0xe7, 0x58, 0x7a,
+                               0xcf, 0x3b, 0xc4, 0x29, 0xb5, 0xe5, 0xa7, 0x83,
+                               0xe6, 0xcb, 0x1e, 0xf3, 0x02, 0x0f, 0x53, 0x3b,
+                               0xb5, 0x39, 0xef, 0x9c, 0x42, 0xe0, 0xa6, 0x9b,
+                               0x2b, 0xdd, 0x60, 0xae, 0x0a, 0x73, 0x35, 0xbe,
+                               0x26, 0x10, 0x1b, 0xe9, 0xe9, 0x61, 0xab, 0x20,
+                               0xa5, 0x48, 0xc6, 0x60, 0xa6, 0x50, 0x3c, 0xfb,
+                               0xa7, 0xca, 0xb0, 0x80, 0x95, 0x1e, 0xce, 0xc7,
+                               0xbb, 0x68, 0x44, 0xdc, 0x0e, 0x0e, 0x14, 0x03,
+                               0x01, 0x00, 0x01, 0x01, 0x16, 0x03, 0x01, 0x00,
+                               0x24, 0xb6, 0xcd, 0x0c, 0x78, 0xfd, 0xd6, 0xff,
+                               0xbe, 0x97, 0xd5, 0x0a, 0x7d, 0x4f, 0xa1, 0x03,
+                               0x78, 0xc8, 0x61, 0x6f, 0xf2, 0x4b, 0xa8, 0x56,
+                               0x4f, 0x3c, 0xa2, 0xd9, 0xd0, 0x20, 0x13, 0x1b,
+                               0x8b, 0x36, 0xb7, 0x33, 0x9c,
+                       },
+
+                       {
+                               0x14, 0x03, 0x01, 0x00, 0x01, 0x01, 0x16, 0x03,
+                               0x01, 0x00, 0x24, 0xa3, 0x43, 0x94, 0xe7, 0xdf,
+                               0xb6, 0xc3, 0x03, 0x9f, 0xc1, 0x59, 0x0c, 0xc3,
+                               0x13, 0xae, 0xed, 0xcf, 0xff, 0xf1, 0x80, 0xf3,
+                               0x13, 0x63, 0x1c, 0xf0, 0xca, 0xad, 0x9e, 0x71,
+                               0x46, 0x5f, 0x6b, 0xeb, 0x10, 0x3f, 0xe3, 0x17,
+                               0x03, 0x01, 0x00, 0x21, 0xe9, 0x80, 0x95, 0x6e,
+                               0x05, 0x55, 0x2f, 0xed, 0x4d, 0xde, 0x17, 0x3a,
+                               0x32, 0x9b, 0x2a, 0x74, 0x30, 0x4f, 0xe0, 0x9f,
+                               0x4e, 0xd3, 0x06, 0xbd, 0x3a, 0x43, 0x75, 0x8b,
+                               0x5b, 0x9a, 0xd8, 0x2e, 0x56, 0x15, 0x03, 0x01,
+                               0x00, 0x16, 0x53, 0xf5, 0xff, 0xe0, 0xa1, 0x6c,
+                               0x33, 0xf4, 0x4e, 0x89, 0x68, 0xe1, 0xf7, 0x61,
+                               0x13, 0xb3, 0x12, 0xa1, 0x8e, 0x5a, 0x7a, 0x02,
+                       }}},
+}
+
+// cert.pem and key.pem were generated with generate_cert.go
+// Thus, they have no ExtKeyUsage fields and trigger an error
+// when verification is turned on.
+
+var clicert = loadPEMCert(`
+-----BEGIN CERTIFICATE-----
+MIIB7TCCAVigAwIBAgIBADALBgkqhkiG9w0BAQUwJjEQMA4GA1UEChMHQWNtZSBD
+bzESMBAGA1UEAxMJMTI3LjAuMC4xMB4XDTExMTIwODA3NTUxMloXDTEyMTIwNzA4
+MDAxMlowJjEQMA4GA1UEChMHQWNtZSBDbzESMBAGA1UEAxMJMTI3LjAuMC4xMIGc
+MAsGCSqGSIb3DQEBAQOBjAAwgYgCgYBO0Hsx44Jk2VnAwoekXh6LczPHY1PfZpIG
+hPZk1Y/kNqcdK+izIDZFI7Xjla7t4PUgnI2V339aEu+H5Fto5OkOdOwEin/ekyfE
+ARl6vfLcPRSr0FTKIQzQTW6HLlzF0rtNS0/Otiz3fojsfNcCkXSmHgwa2uNKWi7e
+E5xMQIhZkwIDAQABozIwMDAOBgNVHQ8BAf8EBAMCAKAwDQYDVR0OBAYEBAECAwQw
+DwYDVR0jBAgwBoAEAQIDBDALBgkqhkiG9w0BAQUDgYEANh+zegx1yW43RmEr1b3A
+p0vMRpqBWHyFeSnIyMZn3TJWRSt1tukkqVCavh9a+hoV2cxVlXIWg7nCto/9iIw4
+hB2rXZIxE0/9gzvGnfERYraL7KtnvshksBFQRlgXa5kc0x38BvEO5ZaoDPl4ILdE
+GFGNEH5PlGffo05wc46QkYU=
+-----END CERTIFICATE-----
+`)
+
+/* corresponding key.pem for cert.pem is:
+-----BEGIN RSA PRIVATE KEY-----
+MIICXAIBAAKBgE7QezHjgmTZWcDCh6ReHotzM8djU99mkgaE9mTVj+Q2px0r6LMg
+NkUjteOVru3g9SCcjZXff1oS74fkW2jk6Q507ASKf96TJ8QBGXq98tw9FKvQVMoh
+DNBNbocuXMXSu01LT862LPd+iOx81wKRdKYeDBra40paLt4TnExAiFmTAgMBAAEC
+gYBxvXd8yNteFTns8A/2yomEMC4yeosJJSpp1CsN3BJ7g8/qTnrVPxBy+RU+qr63
+t2WquaOu/cr5P8iEsa6lk20tf8pjKLNXeX0b1RTzK8rJLbS7nGzP3tvOhL096VtQ
+dAo4ROEaro0TzYpHmpciSvxVIeEIAAdFDObDJPKqcJAxyQJBAJizfYgK8Gzx9fsx
+hxp+VteCbVPg2euASH5Yv3K5LukRdKoSzHE2grUVQgN/LafC0eZibRanxHegYSr7
+7qaswKUCQQCEIWor/X4XTMdVj3Oj+vpiw75y/S9gh682+myZL+d/02IEkwnB098P
+RkKVpenBHyrGg0oeN5La7URILWKj7CPXAkBKo6F+d+phNjwIFoN1Xb/RA32w/D1I
+saG9sF+UEhRt9AxUfW/U/tIQ9V0ZHHcSg1XaCM5Nvp934brdKdvTOKnJAkBD5h/3
+Rybatlvg/fzBEaJFyq09zhngkxlZOUtBVTqzl17RVvY2orgH02U4HbCHy4phxOn7
+qTdQRYlHRftgnWK1AkANibn9PRYJ7mJyJ9Dyj2QeNcSkSTzrt0tPvUMf4+meJymN
+1Ntu5+S1DLLzfxlaljWG6ylW6DNxujCyuXIV2rvAMAA=
+-----END RSA PRIVATE KEY-----
+*/
index 79ab50231293045e6b0eed4f0efd88818163fb24..28e93a0be69c5165006c89be8f5058004107b43e 100644 (file)
@@ -120,7 +120,7 @@ func Dial(network, addr string, config *Config) (*Conn, error) {
 
 // LoadX509KeyPair reads and parses a public/private key pair from a pair of
 // files. The files must contain PEM encoded data.
-func LoadX509KeyPair(certFile string, keyFile string) (cert Certificate, err error) {
+func LoadX509KeyPair(certFile, keyFile string) (cert Certificate, err error) {
        certPEMBlock, err := ioutil.ReadFile(certFile)
        if err != nil {
                return
index 5a0a87678e37336d238c23f222e89e1bbba05b45..616a0b3c1e8570b5d2b236f0ddf32f5ffb8d34ba 100644 (file)
@@ -101,3 +101,13 @@ func (s *CertPool) AppendCertsFromPEM(pemCerts []byte) (ok bool) {
 
        return
 }
+
+// Subjects returns a list of the DER-encoded subjects of
+// all of the certificates in the pool. 
+func (s *CertPool) Subjects() (res [][]byte) {
+       res = make([][]byte, len(s.certs))
+       for i, c := range s.certs {
+               res[i] = c.RawSubject
+       }
+       return
+}
index e5c29889b7f8d5510d658cb5a761ecff70a412e0..b90181bdc64103ba86d4c1961a6acc412067fbb8 100644 (file)
@@ -7,14 +7,14 @@ package gosym
 import (
        "debug/elf"
        "os"
-       "syscall"
+       "runtime"
        "testing"
 )
 
 func dotest() bool {
        // For now, only works on ELF platforms.
        // TODO: convert to work with new go tool
-       return false && syscall.OS == "linux" && os.Getenv("GOARCH") == "amd64"
+       return false && runtime.GOOS == "linux" && runtime.GOARCH == "amd64"
 }
 
 func getTable(t *testing.T) *Table {
index 22a0dde0da43d8e1d28a575d8429d3dbeaecbe62..4d1ae38c4edaa1ce4203da112dd81aa287d92b71 100644 (file)
@@ -786,7 +786,8 @@ func setDefaultValue(v reflect.Value, params fieldParameters) (ok bool) {
 // Because Unmarshal uses the reflect package, the structs
 // being written to must use upper case field names.
 //
-// An ASN.1 INTEGER can be written to an int, int32 or int64.
+// An ASN.1 INTEGER can be written to an int, int32, int64,
+// or *big.Int (from the math/big package).
 // If the encoded value does not fit in the Go type,
 // Unmarshal returns a parse error.
 //
index 09f94139f906218200b69f4e6cceef1c0c5647bf..92c9eb62d2c6b373bbdb94e233cf07254e7abc6d 100644 (file)
@@ -6,6 +6,7 @@ package asn1
 
 import (
        "bytes"
+       "math/big"
        "reflect"
        "testing"
        "time"
@@ -351,6 +352,10 @@ type TestElementsAfterString struct {
        A, B int
 }
 
+type TestBigInt struct {
+       X *big.Int
+}
+
 var unmarshalTestData = []struct {
        in  []byte
        out interface{}
@@ -369,6 +374,7 @@ var unmarshalTestData = []struct {
        {[]byte{0x01, 0x01, 0x00}, newBool(false)},
        {[]byte{0x01, 0x01, 0x01}, newBool(true)},
        {[]byte{0x30, 0x0b, 0x13, 0x03, 0x66, 0x6f, 0x6f, 0x02, 0x01, 0x22, 0x02, 0x01, 0x33}, &TestElementsAfterString{"foo", 0x22, 0x33}},
+       {[]byte{0x30, 0x05, 0x02, 0x03, 0x12, 0x34, 0x56}, &TestBigInt{big.NewInt(0x123456)}},
 }
 
 func TestUnmarshal(t *testing.T) {
index d05b5d8d4e92e398bd359b4258410d8f00a56abf..a7447f978127c3362c1cf4691f63d6a63abf84d6 100644 (file)
@@ -7,6 +7,7 @@ package asn1
 import (
        "bytes"
        "encoding/hex"
+       "math/big"
        "testing"
        "time"
 )
@@ -20,6 +21,10 @@ type twoIntStruct struct {
        B int
 }
 
+type bigIntStruct struct {
+       A *big.Int
+}
+
 type nestedStruct struct {
        A intStruct
 }
@@ -65,6 +70,7 @@ var marshalTests = []marshalTest{
        {-128, "020180"},
        {-129, "0202ff7f"},
        {intStruct{64}, "3003020140"},
+       {bigIntStruct{big.NewInt(0x123456)}, "30050203123456"},
        {twoIntStruct{64, 65}, "3006020140020141"},
        {nestedStruct{intStruct{127}}, "3005300302017f"},
        {[]byte{1, 2, 3}, "0403010203"},
index ba1f2eb813078b823926d82a17d01e37dda492e9..4d1325d176c61cc0455047559e36725c17cfb1ed 100644 (file)
@@ -1039,9 +1039,9 @@ func (dec *Decoder) compatibleType(fr reflect.Type, fw typeId, inProgress map[re
                // Extract and compare element types.
                var sw *sliceType
                if tt, ok := builtinIdToType[fw]; ok {
-                       sw = tt.(*sliceType)
-               } else {
-                       sw = dec.wireType[fw].SliceT
+                       sw, _ = tt.(*sliceType)
+               } else if wire != nil {
+                       sw = wire.SliceT
                }
                elem := userType(t.Elem()).base
                return sw != nil && dec.compatibleType(elem, sw.Elem, inProgress)
index cd1500d0772546afdad8427c3bb9b3cf5807a0b3..7a30f9107e636638e6d666fce46792d5bb639be2 100644 (file)
@@ -678,3 +678,11 @@ func TestUnexportedChan(t *testing.T) {
                t.Fatalf("error encoding unexported channel: %s", err)
        }
 }
+
+func TestSliceIncompatibility(t *testing.T) {
+       var in = []byte{1, 2, 3}
+       var out []int
+       if err := encAndDec(in, &out); err == nil {
+               t.Error("expected compatibility error")
+       }
+}
index 8287b330034ab01da4768e63c02e2672978daf09..87076b53dc06578d4b21b24293a957f1eb8d9566 100644 (file)
@@ -10,6 +10,7 @@ package json
 import (
        "encoding/base64"
        "errors"
+       "fmt"
        "reflect"
        "runtime"
        "strconv"
@@ -538,7 +539,7 @@ func (d *decodeState) object(v reflect.Value) {
                // Read value.
                if destring {
                        d.value(reflect.ValueOf(&d.tempstr))
-                       d.literalStore([]byte(d.tempstr), subv)
+                       d.literalStore([]byte(d.tempstr), subv, true)
                } else {
                        d.value(subv)
                }
@@ -571,11 +572,15 @@ func (d *decodeState) literal(v reflect.Value) {
        d.off--
        d.scan.undo(op)
 
-       d.literalStore(d.data[start:d.off], v)
+       d.literalStore(d.data[start:d.off], v, false)
 }
 
 // literalStore decodes a literal stored in item into v.
-func (d *decodeState) literalStore(item []byte, v reflect.Value) {
+//
+// fromQuoted indicates whether this literal came from unwrapping a
+// string from the ",string" struct tag option. this is used only to
+// produce more helpful error messages.
+func (d *decodeState) literalStore(item []byte, v reflect.Value, fromQuoted bool) {
        // Check for unmarshaler.
        wantptr := item[0] == 'n' // null
        unmarshaler, pv := d.indirect(v, wantptr)
@@ -601,7 +606,11 @@ func (d *decodeState) literalStore(item []byte, v reflect.Value) {
                value := c == 't'
                switch v.Kind() {
                default:
-                       d.saveError(&UnmarshalTypeError{"bool", v.Type()})
+                       if fromQuoted {
+                               d.saveError(fmt.Errorf("json: invalid use of ,string struct tag, trying to unmarshal %q into %v", item, v.Type()))
+                       } else {
+                               d.saveError(&UnmarshalTypeError{"bool", v.Type()})
+                       }
                case reflect.Bool:
                        v.SetBool(value)
                case reflect.Interface:
@@ -611,7 +620,11 @@ func (d *decodeState) literalStore(item []byte, v reflect.Value) {
        case '"': // string
                s, ok := unquoteBytes(item)
                if !ok {
-                       d.error(errPhase)
+                       if fromQuoted {
+                               d.error(fmt.Errorf("json: invalid use of ,string struct tag, trying to unmarshal %q into %v", item, v.Type()))
+                       } else {
+                               d.error(errPhase)
+                       }
                }
                switch v.Kind() {
                default:
@@ -636,12 +649,20 @@ func (d *decodeState) literalStore(item []byte, v reflect.Value) {
 
        default: // number
                if c != '-' && (c < '0' || c > '9') {
-                       d.error(errPhase)
+                       if fromQuoted {
+                               d.error(fmt.Errorf("json: invalid use of ,string struct tag, trying to unmarshal %q into %v", item, v.Type()))
+                       } else {
+                               d.error(errPhase)
+                       }
                }
                s := string(item)
                switch v.Kind() {
                default:
-                       d.error(&UnmarshalTypeError{"number", v.Type()})
+                       if fromQuoted {
+                               d.error(fmt.Errorf("json: invalid use of ,string struct tag, trying to unmarshal %q into %v", item, v.Type()))
+                       } else {
+                               d.error(&UnmarshalTypeError{"number", v.Type()})
+                       }
                case reflect.Interface:
                        n, err := strconv.ParseFloat(s, 64)
                        if err != nil {
index 05c8a064a422c8bba383aa1683ebaa72a0ea4793..cc3103f032fb26c1c23cdfac08cbd7ca3c026eeb 100644 (file)
@@ -258,13 +258,10 @@ type wrongStringTest struct {
        in, err string
 }
 
-// TODO(bradfitz): as part of Issue 2331, fix these tests' expected
-// error values to be helpful, rather than the confusing messages they
-// are now.
 var wrongStringTests = []wrongStringTest{
-       {`{"result":"x"}`, "JSON decoder out of sync - data changing underfoot?"},
-       {`{"result":"foo"}`, "json: cannot unmarshal bool into Go value of type string"},
-       {`{"result":"123"}`, "json: cannot unmarshal number into Go value of type string"},
+       {`{"result":"x"}`, `json: invalid use of ,string struct tag, trying to unmarshal "x" into string`},
+       {`{"result":"foo"}`, `json: invalid use of ,string struct tag, trying to unmarshal "foo" into string`},
+       {`{"result":"123"}`, `json: invalid use of ,string struct tag, trying to unmarshal "123" into string`},
 }
 
 // If people misuse the ,string modifier, the error message should be
index 3d2f4fc316ea8e837995399d637851d3832ac7c3..033da2d0ade140aff6389f0e886c99645f4cc203 100644 (file)
@@ -12,6 +12,7 @@ package json
 import (
        "bytes"
        "encoding/base64"
+       "math"
        "reflect"
        "runtime"
        "sort"
@@ -170,6 +171,15 @@ func (e *UnsupportedTypeError) Error() string {
        return "json: unsupported type: " + e.Type.String()
 }
 
+type UnsupportedValueError struct {
+       Value reflect.Value
+       Str   string
+}
+
+func (e *UnsupportedValueError) Error() string {
+       return "json: unsupported value: " + e.Str
+}
+
 type InvalidUTF8Error struct {
        S string
 }
@@ -290,7 +300,11 @@ func (e *encodeState) reflectValueQuoted(v reflect.Value, quoted bool) {
                        e.Write(b)
                }
        case reflect.Float32, reflect.Float64:
-               b := strconv.AppendFloat(e.scratch[:0], v.Float(), 'g', -1, v.Type().Bits())
+               f := v.Float()
+               if math.IsInf(f, 0) || math.IsNaN(f) {
+                       e.error(&UnsupportedValueError{v, strconv.FormatFloat(f, 'g', -1, v.Type().Bits())})
+               }
+               b := strconv.AppendFloat(e.scratch[:0], f, 'g', -1, v.Type().Bits())
                if quoted {
                        writeString(e, string(b))
                } else {
index 9366589f252e71b710cd5a4e2bab7cca8a2c733c..0e39559a463405ccbe164cd24794737dc2f37481 100644 (file)
@@ -6,6 +6,7 @@ package json
 
 import (
        "bytes"
+       "math"
        "reflect"
        "testing"
 )
@@ -107,3 +108,21 @@ func TestEncodeRenamedByteSlice(t *testing.T) {
                t.Errorf(" got %s want %s", result, expect)
        }
 }
+
+var unsupportedValues = []interface{}{
+       math.NaN(),
+       math.Inf(-1),
+       math.Inf(1),
+}
+
+func TestUnsupportedValues(t *testing.T) {
+       for _, v := range unsupportedValues {
+               if _, err := Marshal(v); err != nil {
+                       if _, ok := err.(*UnsupportedValueError); !ok {
+                               t.Errorf("for %v, got %T want UnsupportedValueError", v, err)
+                       }
+               } else {
+                       t.Errorf("for %v, expected error", v)
+               }
+       }
+}
index d365510bf583e1e20376a43ea38ecbb3dadadf59..8d003aade0770fb85fcbd2757d0791910fce217a 100644 (file)
@@ -5,6 +5,7 @@
 package xml
 
 var atomValue = &Feed{
+       XMLName: Name{"http://www.w3.org/2005/Atom", "feed"},
        Title:   "Example Feed",
        Link:    []Link{{Href: "http://example.org/"}},
        Updated: ParseTime("2003-12-13T18:30:02Z"),
@@ -24,19 +25,19 @@ var atomValue = &Feed{
 
 var atomXml = `` +
        `<feed xmlns="http://www.w3.org/2005/Atom">` +
-       `<Title>Example Feed</Title>` +
-       `<Id>urn:uuid:60a76c80-d399-11d9-b93C-0003939e0af6</Id>` +
-       `<Link href="http://example.org/"></Link>` +
-       `<Updated>2003-12-13T18:30:02Z</Updated>` +
-       `<Author><Name>John Doe</Name><URI></URI><Email></Email></Author>` +
-       `<Entry>` +
-       `<Title>Atom-Powered Robots Run Amok</Title>` +
-       `<Id>urn:uuid:1225c695-cfb8-4ebb-aaaa-80da344efa6a</Id>` +
-       `<Link href="http://example.org/2003/12/13/atom03"></Link>` +
-       `<Updated>2003-12-13T18:30:02Z</Updated>` +
-       `<Author><Name></Name><URI></URI><Email></Email></Author>` +
-       `<Summary>Some text.</Summary>` +
-       `</Entry>` +
+       `<title>Example Feed</title>` +
+       `<id>urn:uuid:60a76c80-d399-11d9-b93C-0003939e0af6</id>` +
+       `<link href="http://example.org/"></link>` +
+       `<updated>2003-12-13T18:30:02Z</updated>` +
+       `<author><name>John Doe</name><uri></uri><email></email></author>` +
+       `<entry>` +
+       `<title>Atom-Powered Robots Run Amok</title>` +
+       `<id>urn:uuid:1225c695-cfb8-4ebb-aaaa-80da344efa6a</id>` +
+       `<link href="http://example.org/2003/12/13/atom03"></link>` +
+       `<updated>2003-12-13T18:30:02Z</updated>` +
+       `<author><name></name><uri></uri><email></email></author>` +
+       `<summary>Some text.</summary>` +
+       `</entry>` +
        `</feed>`
 
 func ParseTime(str string) Time {
diff --git a/libgo/go/encoding/xml/embed_test.go b/libgo/go/encoding/xml/embed_test.go
deleted file mode 100644 (file)
index ec7f478..0000000
+++ /dev/null
@@ -1,124 +0,0 @@
-// Copyright 2010 The Go Authors.  All rights reserved.
-// Use of this source code is governed by a BSD-style
-// license that can be found in the LICENSE file.
-
-package xml
-
-import "testing"
-
-type C struct {
-       Name string
-       Open bool
-}
-
-type A struct {
-       XMLName Name `xml:"http://domain a"`
-       C
-       B      B
-       FieldA string
-}
-
-type B struct {
-       XMLName Name `xml:"b"`
-       C
-       FieldB string
-}
-
-const _1a = `
-<?xml version="1.0" encoding="UTF-8"?>
-<a xmlns="http://domain">
-  <name>KmlFile</name>
-  <open>1</open>
-  <b>
-    <name>Absolute</name>
-    <open>0</open>
-    <fieldb>bar</fieldb>
-  </b>
-  <fielda>foo</fielda>
-</a>
-`
-
-// Tests that embedded structs are marshalled.
-func TestEmbedded1(t *testing.T) {
-       var a A
-       if e := Unmarshal(StringReader(_1a), &a); e != nil {
-               t.Fatalf("Unmarshal: %s", e)
-       }
-       if a.FieldA != "foo" {
-               t.Fatalf("Unmarshal: expected 'foo' but found '%s'", a.FieldA)
-       }
-       if a.Name != "KmlFile" {
-               t.Fatalf("Unmarshal: expected 'KmlFile' but found '%s'", a.Name)
-       }
-       if !a.Open {
-               t.Fatal("Unmarshal: expected 'true' but found otherwise")
-       }
-       if a.B.FieldB != "bar" {
-               t.Fatalf("Unmarshal: expected 'bar' but found '%s'", a.B.FieldB)
-       }
-       if a.B.Name != "Absolute" {
-               t.Fatalf("Unmarshal: expected 'Absolute' but found '%s'", a.B.Name)
-       }
-       if a.B.Open {
-               t.Fatal("Unmarshal: expected 'false' but found otherwise")
-       }
-}
-
-type A2 struct {
-       XMLName Name `xml:"http://domain a"`
-       XY      string
-       Xy      string
-}
-
-const _2a = `
-<?xml version="1.0" encoding="UTF-8"?>
-<a xmlns="http://domain">
-  <xy>foo</xy>
-</a>
-`
-
-// Tests that conflicting field names get excluded.
-func TestEmbedded2(t *testing.T) {
-       var a A2
-       if e := Unmarshal(StringReader(_2a), &a); e != nil {
-               t.Fatalf("Unmarshal: %s", e)
-       }
-       if a.XY != "" {
-               t.Fatalf("Unmarshal: expected empty string but found '%s'", a.XY)
-       }
-       if a.Xy != "" {
-               t.Fatalf("Unmarshal: expected empty string but found '%s'", a.Xy)
-       }
-}
-
-type A3 struct {
-       XMLName Name `xml:"http://domain a"`
-       xy      string
-}
-
-// Tests that private fields are not set.
-func TestEmbedded3(t *testing.T) {
-       var a A3
-       if e := Unmarshal(StringReader(_2a), &a); e != nil {
-               t.Fatalf("Unmarshal: %s", e)
-       }
-       if a.xy != "" {
-               t.Fatalf("Unmarshal: expected empty string but found '%s'", a.xy)
-       }
-}
-
-type A4 struct {
-       XMLName Name `xml:"http://domain a"`
-       Any     string
-}
-
-// Tests that private fields are not set.
-func TestEmbedded4(t *testing.T) {
-       var a A4
-       if e := Unmarshal(StringReader(_2a), &a); e != nil {
-               t.Fatalf("Unmarshal: %s", e)
-       }
-       if a.Any != "foo" {
-               t.Fatalf("Unmarshal: expected 'foo' but found '%s'", a.Any)
-       }
-}
index e94fdbc531f29f0f43025279df7e34175cd45e85..d25ee30a72b0b82387313cdcb41ac779615ba61d 100644 (file)
@@ -6,6 +6,8 @@ package xml
 
 import (
        "bufio"
+       "bytes"
+       "fmt"
        "io"
        "reflect"
        "strconv"
@@ -42,20 +44,26 @@ type printer struct {
 // elements containing the data.
 //
 // The name for the XML elements is taken from, in order of preference:
-//     - the tag on an XMLName field, if the data is a struct
-//     - the value of an XMLName field of type xml.Name
+//     - the tag on the XMLName field, if the data is a struct
+//     - the value of the XMLName field of type xml.Name
 //     - the tag of the struct field used to obtain the data
 //     - the name of the struct field used to obtain the data
-//     - the name '???'.
+//     - the name of the marshalled type
 //
 // The XML element for a struct contains marshalled elements for each of the
 // exported fields of the struct, with these exceptions:
 //     - the XMLName field, described above, is omitted.
-//     - a field with tag "attr" becomes an attribute in the XML element.
-//     - a field with tag "chardata" is written as character data,
-//        not as an XML element.
-//     - a field with tag "innerxml" is written verbatim,
-//        not subject to the usual marshalling procedure.
+//     - a field with tag "name,attr" becomes an attribute with
+//       the given name in the XML element.
+//     - a field with tag ",attr" becomes an attribute with the
+//       field name in the in the XML element.
+//     - a field with tag ",chardata" is written as character data,
+//       not as an XML element.
+//     - a field with tag ",innerxml" is written verbatim, not subject
+//       to the usual marshalling procedure.
+//     - a field with tag ",comment" is written as an XML comment, not
+//       subject to the usual marshalling procedure. It must not contain
+//       the "--" string within it.
 //
 // If a field uses a tag "a>b>c", then the element c will be nested inside
 // parent elements a and b.  Fields that appear next to each other that name
@@ -63,17 +71,18 @@ type printer struct {
 //
 //     type Result struct {
 //             XMLName   xml.Name `xml:"result"`
+//             Id        int      `xml:"id,attr"`
 //             FirstName string   `xml:"person>name>first"`
 //             LastName  string   `xml:"person>name>last"`
 //             Age       int      `xml:"person>age"`
 //     }
 //
-//     xml.Marshal(w, &Result{FirstName: "John", LastName: "Doe", Age: 42})
+//     xml.Marshal(w, &Result{Id: 13, FirstName: "John", LastName: "Doe", Age: 42})
 //
 // would be marshalled as:
 //
 //     <result>
-//             <person>
+//             <person id="13">
 //                     <name>
 //                             <first>John</first>
 //                             <last>Doe</last>
@@ -85,12 +94,12 @@ type printer struct {
 // Marshal will return an error if asked to marshal a channel, function, or map.
 func Marshal(w io.Writer, v interface{}) (err error) {
        p := &printer{bufio.NewWriter(w)}
-       err = p.marshalValue(reflect.ValueOf(v), "???")
+       err = p.marshalValue(reflect.ValueOf(v), nil)
        p.Flush()
        return err
 }
 
-func (p *printer) marshalValue(val reflect.Value, name string) error {
+func (p *printer) marshalValue(val reflect.Value, finfo *fieldInfo) error {
        if !val.IsValid() {
                return nil
        }
@@ -115,58 +124,75 @@ func (p *printer) marshalValue(val reflect.Value, name string) error {
                if val.IsNil() {
                        return nil
                }
-               return p.marshalValue(val.Elem(), name)
+               return p.marshalValue(val.Elem(), finfo)
        }
 
        // Slices and arrays iterate over the elements. They do not have an enclosing tag.
        if (kind == reflect.Slice || kind == reflect.Array) && typ.Elem().Kind() != reflect.Uint8 {
                for i, n := 0, val.Len(); i < n; i++ {
-                       if err := p.marshalValue(val.Index(i), name); err != nil {
+                       if err := p.marshalValue(val.Index(i), finfo); err != nil {
                                return err
                        }
                }
                return nil
        }
 
-       // Find XML name
-       xmlns := ""
-       if kind == reflect.Struct {
-               if f, ok := typ.FieldByName("XMLName"); ok {
-                       if tag := f.Tag.Get("xml"); tag != "" {
-                               if i := strings.Index(tag, " "); i >= 0 {
-                                       xmlns, name = tag[:i], tag[i+1:]
-                               } else {
-                                       name = tag
-                               }
-                       } else if v, ok := val.FieldByIndex(f.Index).Interface().(Name); ok && v.Local != "" {
-                               xmlns, name = v.Space, v.Local
-                       }
+       tinfo, err := getTypeInfo(typ)
+       if err != nil {
+               return err
+       }
+
+       // Precedence for the XML element name is:
+       // 1. XMLName field in underlying struct;
+       // 2. field name/tag in the struct field; and
+       // 3. type name
+       var xmlns, name string
+       if tinfo.xmlname != nil {
+               xmlname := tinfo.xmlname
+               if xmlname.name != "" {
+                       xmlns, name = xmlname.xmlns, xmlname.name
+               } else if v, ok := val.FieldByIndex(xmlname.idx).Interface().(Name); ok && v.Local != "" {
+                       xmlns, name = v.Space, v.Local
+               }
+       }
+       if name == "" && finfo != nil {
+               xmlns, name = finfo.xmlns, finfo.name
+       }
+       if name == "" {
+               name = typ.Name()
+               if name == "" {
+                       return &UnsupportedTypeError{typ}
                }
        }
 
        p.WriteByte('<')
        p.WriteString(name)
 
+       if xmlns != "" {
+               p.WriteString(` xmlns="`)
+               // TODO: EscapeString, to avoid the allocation.
+               Escape(p, []byte(xmlns))
+               p.WriteByte('"')
+       }
+
        // Attributes
-       if kind == reflect.Struct {
-               if len(xmlns) > 0 {
-                       p.WriteString(` xmlns="`)
-                       Escape(p, []byte(xmlns))
-                       p.WriteByte('"')
+       for i := range tinfo.fields {
+               finfo := &tinfo.fields[i]
+               if finfo.flags&fAttr == 0 {
+                       continue
                }
-
-               for i, n := 0, typ.NumField(); i < n; i++ {
-                       if f := typ.Field(i); f.PkgPath == "" && f.Tag.Get("xml") == "attr" {
-                               if f.Type.Kind() == reflect.String {
-                                       if str := val.Field(i).String(); str != "" {
-                                               p.WriteByte(' ')
-                                               p.WriteString(strings.ToLower(f.Name))
-                                               p.WriteString(`="`)
-                                               Escape(p, []byte(str))
-                                               p.WriteByte('"')
-                                       }
-                               }
-                       }
+               var str string
+               if fv := val.FieldByIndex(finfo.idx); fv.Kind() == reflect.String {
+                       str = fv.String()
+               } else {
+                       str = fmt.Sprint(fv.Interface())
+               }
+               if str != "" {
+                       p.WriteByte(' ')
+                       p.WriteString(finfo.name)
+                       p.WriteString(`="`)
+                       Escape(p, []byte(str))
+                       p.WriteByte('"')
                }
        }
        p.WriteByte('>')
@@ -194,58 +220,9 @@ func (p *printer) marshalValue(val reflect.Value, name string) error {
                bytes := val.Interface().([]byte)
                Escape(p, bytes)
        case reflect.Struct:
-               s := parentStack{printer: p}
-               for i, n := 0, val.NumField(); i < n; i++ {
-                       if f := typ.Field(i); f.Name != "XMLName" && f.PkgPath == "" {
-                               name := f.Name
-                               vf := val.Field(i)
-                               switch tag := f.Tag.Get("xml"); tag {
-                               case "":
-                                       s.trim(nil)
-                               case "chardata":
-                                       if tk := f.Type.Kind(); tk == reflect.String {
-                                               Escape(p, []byte(vf.String()))
-                                       } else if tk == reflect.Slice {
-                                               if elem, ok := vf.Interface().([]byte); ok {
-                                                       Escape(p, elem)
-                                               }
-                                       }
-                                       continue
-                               case "innerxml":
-                                       iface := vf.Interface()
-                                       switch raw := iface.(type) {
-                                       case []byte:
-                                               p.Write(raw)
-                                               continue
-                                       case string:
-                                               p.WriteString(raw)
-                                               continue
-                                       }
-                               case "attr":
-                                       continue
-                               default:
-                                       parents := strings.Split(tag, ">")
-                                       if len(parents) == 1 {
-                                               parents, name = nil, tag
-                                       } else {
-                                               parents, name = parents[:len(parents)-1], parents[len(parents)-1]
-                                               if parents[0] == "" {
-                                                       parents[0] = f.Name
-                                               }
-                                       }
-
-                                       s.trim(parents)
-                                       if !(vf.Kind() == reflect.Ptr || vf.Kind() == reflect.Interface) || !vf.IsNil() {
-                                               s.push(parents[len(s.stack):])
-                                       }
-                               }
-
-                               if err := p.marshalValue(vf, name); err != nil {
-                                       return err
-                               }
-                       }
+               if err := p.marshalStruct(tinfo, val); err != nil {
+                       return err
                }
-               s.trim(nil)
        default:
                return &UnsupportedTypeError{typ}
        }
@@ -258,6 +235,94 @@ func (p *printer) marshalValue(val reflect.Value, name string) error {
        return nil
 }
 
+var ddBytes = []byte("--")
+
+func (p *printer) marshalStruct(tinfo *typeInfo, val reflect.Value) error {
+       s := parentStack{printer: p}
+       for i := range tinfo.fields {
+               finfo := &tinfo.fields[i]
+               if finfo.flags&(fAttr|fAny) != 0 {
+                       continue
+               }
+               vf := val.FieldByIndex(finfo.idx)
+               switch finfo.flags & fMode {
+               case fCharData:
+                       switch vf.Kind() {
+                       case reflect.String:
+                               Escape(p, []byte(vf.String()))
+                       case reflect.Slice:
+                               if elem, ok := vf.Interface().([]byte); ok {
+                                       Escape(p, elem)
+                               }
+                       }
+                       continue
+
+               case fComment:
+                       k := vf.Kind()
+                       if !(k == reflect.String || k == reflect.Slice && vf.Type().Elem().Kind() == reflect.Uint8) {
+                               return fmt.Errorf("xml: bad type for comment field of %s", val.Type())
+                       }
+                       if vf.Len() == 0 {
+                               continue
+                       }
+                       p.WriteString("<!--")
+                       dashDash := false
+                       dashLast := false
+                       switch k {
+                       case reflect.String:
+                               s := vf.String()
+                               dashDash = strings.Index(s, "--") >= 0
+                               dashLast = s[len(s)-1] == '-'
+                               if !dashDash {
+                                       p.WriteString(s)
+                               }
+                       case reflect.Slice:
+                               b := vf.Bytes()
+                               dashDash = bytes.Index(b, ddBytes) >= 0
+                               dashLast = b[len(b)-1] == '-'
+                               if !dashDash {
+                                       p.Write(b)
+                               }
+                       default:
+                               panic("can't happen")
+                       }
+                       if dashDash {
+                               return fmt.Errorf(`xml: comments must not contain "--"`)
+                       }
+                       if dashLast {
+                               // "--->" is invalid grammar. Make it "- -->"
+                               p.WriteByte(' ')
+                       }
+                       p.WriteString("-->")
+                       continue
+
+               case fInnerXml:
+                       iface := vf.Interface()
+                       switch raw := iface.(type) {
+                       case []byte:
+                               p.Write(raw)
+                               continue
+                       case string:
+                               p.WriteString(raw)
+                               continue
+                       }
+
+               case fElement:
+                       s.trim(finfo.parents)
+                       if len(finfo.parents) > len(s.stack) {
+                               if vf.Kind() != reflect.Ptr && vf.Kind() != reflect.Interface || !vf.IsNil() {
+                                       s.push(finfo.parents[len(s.stack):])
+                               }
+                       }
+               }
+               if err := p.marshalValue(vf, finfo); err != nil {
+                       return err
+               }
+       }
+       s.trim(nil)
+       return nil
+}
+
 type parentStack struct {
        *printer
        stack []string
index 6a241694baf85651ad1082c7f354975fa599e4d0..bec53761e1a31ea25dedaddc01627354065e1386 100644 (file)
@@ -25,10 +25,10 @@ type Passenger struct {
 }
 
 type Ship struct {
-       XMLName Name `xml:"spaceship"`
+       XMLName struct{} `xml:"spaceship"`
 
-       Name      string       `xml:"attr"`
-       Pilot     string       `xml:"attr"`
+       Name      string       `xml:"name,attr"`
+       Pilot     string       `xml:"pilot,attr"`
        Drive     DriveType    `xml:"drive"`
        Age       uint         `xml:"age"`
        Passenger []*Passenger `xml:"passenger"`
@@ -44,48 +44,50 @@ func (rx RawXML) MarshalXML() ([]byte, error) {
 type NamedType string
 
 type Port struct {
-       XMLName Name   `xml:"port"`
-       Type    string `xml:"attr"`
-       Number  string `xml:"chardata"`
+       XMLName struct{} `xml:"port"`
+       Type    string   `xml:"type,attr"`
+       Comment string   `xml:",comment"`
+       Number  string   `xml:",chardata"`
 }
 
 type Domain struct {
-       XMLName Name   `xml:"domain"`
-       Country string `xml:"attr"`
-       Name    []byte `xml:"chardata"`
+       XMLName struct{} `xml:"domain"`
+       Country string   `xml:",attr"`
+       Name    []byte   `xml:",chardata"`
+       Comment []byte   `xml:",comment"`
 }
 
 type Book struct {
-       XMLName Name   `xml:"book"`
-       Title   string `xml:"chardata"`
+       XMLName struct{} `xml:"book"`
+       Title   string   `xml:",chardata"`
 }
 
 type SecretAgent struct {
-       XMLName   Name   `xml:"agent"`
-       Handle    string `xml:"attr"`
+       XMLName   struct{} `xml:"agent"`
+       Handle    string   `xml:"handle,attr"`
        Identity  string
-       Obfuscate string `xml:"innerxml"`
+       Obfuscate string `xml:",innerxml"`
 }
 
 type NestedItems struct {
-       XMLName Name     `xml:"result"`
+       XMLName struct{} `xml:"result"`
        Items   []string `xml:">item"`
        Item1   []string `xml:"Items>item1"`
 }
 
 type NestedOrder struct {
-       XMLName Name   `xml:"result"`
-       Field1  string `xml:"parent>c"`
-       Field2  string `xml:"parent>b"`
-       Field3  string `xml:"parent>a"`
+       XMLName struct{} `xml:"result"`
+       Field1  string   `xml:"parent>c"`
+       Field2  string   `xml:"parent>b"`
+       Field3  string   `xml:"parent>a"`
 }
 
 type MixedNested struct {
-       XMLName Name   `xml:"result"`
-       A       string `xml:"parent1>a"`
-       B       string `xml:"b"`
-       C       string `xml:"parent1>parent2>c"`
-       D       string `xml:"parent1>d"`
+       XMLName struct{} `xml:"result"`
+       A       string   `xml:"parent1>a"`
+       B       string   `xml:"b"`
+       C       string   `xml:"parent1>parent2>c"`
+       D       string   `xml:"parent1>d"`
 }
 
 type NilTest struct {
@@ -95,62 +97,165 @@ type NilTest struct {
 }
 
 type Service struct {
-       XMLName Name    `xml:"service"`
-       Domain  *Domain `xml:"host>domain"`
-       Port    *Port   `xml:"host>port"`
+       XMLName struct{} `xml:"service"`
+       Domain  *Domain  `xml:"host>domain"`
+       Port    *Port    `xml:"host>port"`
        Extra1  interface{}
        Extra2  interface{} `xml:"host>extra2"`
 }
 
 var nilStruct *Ship
 
+type EmbedA struct {
+       EmbedC
+       EmbedB EmbedB
+       FieldA string
+}
+
+type EmbedB struct {
+       FieldB string
+       EmbedC
+}
+
+type EmbedC struct {
+       FieldA1 string `xml:"FieldA>A1"`
+       FieldA2 string `xml:"FieldA>A2"`
+       FieldB  string
+       FieldC  string
+}
+
+type NameCasing struct {
+       XMLName struct{} `xml:"casing"`
+       Xy      string
+       XY      string
+       XyA     string `xml:"Xy,attr"`
+       XYA     string `xml:"XY,attr"`
+}
+
+type NamePrecedence struct {
+       XMLName     Name              `xml:"Parent"`
+       FromTag     XMLNameWithoutTag `xml:"InTag"`
+       FromNameVal XMLNameWithoutTag
+       FromNameTag XMLNameWithTag
+       InFieldName string
+}
+
+type XMLNameWithTag struct {
+       XMLName Name   `xml:"InXMLNameTag"`
+       Value   string ",chardata"
+}
+
+type XMLNameWithoutTag struct {
+       XMLName Name
+       Value   string ",chardata"
+}
+
+type AttrTest struct {
+       Int   int     `xml:",attr"`
+       Lower int     `xml:"int,attr"`
+       Float float64 `xml:",attr"`
+       Uint8 uint8   `xml:",attr"`
+       Bool  bool    `xml:",attr"`
+       Str   string  `xml:",attr"`
+}
+
+type AnyTest struct {
+       XMLName  struct{}  `xml:"a"`
+       Nested   string    `xml:"nested>value"`
+       AnyField AnyHolder `xml:",any"`
+}
+
+type AnyHolder struct {
+       XMLName Name
+       XML     string `xml:",innerxml"`
+}
+
+type RecurseA struct {
+       A string
+       B *RecurseB
+}
+
+type RecurseB struct {
+       A *RecurseA
+       B string
+}
+
+type Plain struct {
+       V interface{}
+}
+
+// Unless explicitly stated as such (or *Plain), all of the
+// tests below are two-way tests. When introducing new tests,
+// please try to make them two-way as well to ensure that
+// marshalling and unmarshalling are as symmetrical as feasible.
 var marshalTests = []struct {
-       Value     interface{}
-       ExpectXML string
+       Value         interface{}
+       ExpectXML     string
+       MarshalOnly   bool
+       UnmarshalOnly bool
 }{
        // Test nil marshals to nothing
-       {Value: nil, ExpectXML: ``},
-       {Value: nilStruct, ExpectXML: ``},
-
-       // Test value types (no tag name, so ???)
-       {Value: true, ExpectXML: `<???>true</???>`},
-       {Value: int(42), ExpectXML: `<???>42</???>`},
-       {Value: int8(42), ExpectXML: `<???>42</???>`},
-       {Value: int16(42), ExpectXML: `<???>42</???>`},
-       {Value: int32(42), ExpectXML: `<???>42</???>`},
-       {Value: uint(42), ExpectXML: `<???>42</???>`},
-       {Value: uint8(42), ExpectXML: `<???>42</???>`},
-       {Value: uint16(42), ExpectXML: `<???>42</???>`},
-       {Value: uint32(42), ExpectXML: `<???>42</???>`},
-       {Value: float32(1.25), ExpectXML: `<???>1.25</???>`},
-       {Value: float64(1.25), ExpectXML: `<???>1.25</???>`},
-       {Value: uintptr(0xFFDD), ExpectXML: `<???>65501</???>`},
-       {Value: "gopher", ExpectXML: `<???>gopher</???>`},
-       {Value: []byte("gopher"), ExpectXML: `<???>gopher</???>`},
-       {Value: "</>", ExpectXML: `<???>&lt;/&gt;</???>`},
-       {Value: []byte("</>"), ExpectXML: `<???>&lt;/&gt;</???>`},
-       {Value: [3]byte{'<', '/', '>'}, ExpectXML: `<???>&lt;/&gt;</???>`},
-       {Value: NamedType("potato"), ExpectXML: `<???>potato</???>`},
-       {Value: []int{1, 2, 3}, ExpectXML: `<???>1</???><???>2</???><???>3</???>`},
-       {Value: [3]int{1, 2, 3}, ExpectXML: `<???>1</???><???>2</???><???>3</???>`},
+       {Value: nil, ExpectXML: ``, MarshalOnly: true},
+       {Value: nilStruct, ExpectXML: ``, MarshalOnly: true},
+
+       // Test value types
+       {Value: &Plain{true}, ExpectXML: `<Plain><V>true</V></Plain>`},
+       {Value: &Plain{false}, ExpectXML: `<Plain><V>false</V></Plain>`},
+       {Value: &Plain{int(42)}, ExpectXML: `<Plain><V>42</V></Plain>`},
+       {Value: &Plain{int8(42)}, ExpectXML: `<Plain><V>42</V></Plain>`},
+       {Value: &Plain{int16(42)}, ExpectXML: `<Plain><V>42</V></Plain>`},
+       {Value: &Plain{int32(42)}, ExpectXML: `<Plain><V>42</V></Plain>`},
+       {Value: &Plain{uint(42)}, ExpectXML: `<Plain><V>42</V></Plain>`},
+       {Value: &Plain{uint8(42)}, ExpectXML: `<Plain><V>42</V></Plain>`},
+       {Value: &Plain{uint16(42)}, ExpectXML: `<Plain><V>42</V></Plain>`},
+       {Value: &Plain{uint32(42)}, ExpectXML: `<Plain><V>42</V></Plain>`},
+       {Value: &Plain{float32(1.25)}, ExpectXML: `<Plain><V>1.25</V></Plain>`},
+       {Value: &Plain{float64(1.25)}, ExpectXML: `<Plain><V>1.25</V></Plain>`},
+       {Value: &Plain{uintptr(0xFFDD)}, ExpectXML: `<Plain><V>65501</V></Plain>`},
+       {Value: &Plain{"gopher"}, ExpectXML: `<Plain><V>gopher</V></Plain>`},
+       {Value: &Plain{[]byte("gopher")}, ExpectXML: `<Plain><V>gopher</V></Plain>`},
+       {Value: &Plain{"</>"}, ExpectXML: `<Plain><V>&lt;/&gt;</V></Plain>`},
+       {Value: &Plain{[]byte("</>")}, ExpectXML: `<Plain><V>&lt;/&gt;</V></Plain>`},
+       {Value: &Plain{[3]byte{'<', '/', '>'}}, ExpectXML: `<Plain><V>&lt;/&gt;</V></Plain>`},
+       {Value: &Plain{NamedType("potato")}, ExpectXML: `<Plain><V>potato</V></Plain>`},
+       {Value: &Plain{[]int{1, 2, 3}}, ExpectXML: `<Plain><V>1</V><V>2</V><V>3</V></Plain>`},
+       {Value: &Plain{[3]int{1, 2, 3}}, ExpectXML: `<Plain><V>1</V><V>2</V><V>3</V></Plain>`},
 
        // Test innerxml
-       {Value: RawXML("</>"), ExpectXML: `</>`},
        {
                Value: &SecretAgent{
                        Handle:    "007",
                        Identity:  "James Bond",
                        Obfuscate: "<redacted/>",
                },
-               //ExpectXML: `<agent handle="007"><redacted/></agent>`,
-               ExpectXML: `<agent handle="007"><Identity>James Bond</Identity><redacted/></agent>`,
+               ExpectXML:   `<agent handle="007"><Identity>James Bond</Identity><redacted/></agent>`,
+               MarshalOnly: true,
+       },
+       {
+               Value: &SecretAgent{
+                       Handle:    "007",
+                       Identity:  "James Bond",
+                       Obfuscate: "<Identity>James Bond</Identity><redacted/>",
+               },
+               ExpectXML:     `<agent handle="007"><Identity>James Bond</Identity><redacted/></agent>`,
+               UnmarshalOnly: true,
+       },
+
+       // Test marshaller interface
+       {
+               Value:       RawXML("</>"),
+               ExpectXML:   `</>`,
+               MarshalOnly: true,
        },
 
        // Test structs
        {Value: &Port{Type: "ssl", Number: "443"}, ExpectXML: `<port type="ssl">443</port>`},
        {Value: &Port{Number: "443"}, ExpectXML: `<port>443</port>`},
        {Value: &Port{Type: "<unix>"}, ExpectXML: `<port type="&lt;unix&gt;"></port>`},
+       {Value: &Port{Number: "443", Comment: "https"}, ExpectXML: `<port><!--https-->443</port>`},
+       {Value: &Port{Number: "443", Comment: "add space-"}, ExpectXML: `<port><!--add space- -->443</port>`, MarshalOnly: true},
        {Value: &Domain{Name: []byte("google.com&friends")}, ExpectXML: `<domain>google.com&amp;friends</domain>`},
+       {Value: &Domain{Name: []byte("google.com"), Comment: []byte(" &friends ")}, ExpectXML: `<domain>google.com<!-- &friends --></domain>`},
        {Value: &Book{Title: "Pride & Prejudice"}, ExpectXML: `<book>Pride &amp; Prejudice</book>`},
        {Value: atomValue, ExpectXML: atomXml},
        {
@@ -203,16 +308,25 @@ var marshalTests = []struct {
                        `</passenger>` +
                        `</spaceship>`,
        },
+
        // Test a>b
        {
-               Value: NestedItems{Items: []string{}, Item1: []string{}},
+               Value: &NestedItems{Items: nil, Item1: nil},
+               ExpectXML: `<result>` +
+                       `<Items>` +
+                       `</Items>` +
+                       `</result>`,
+       },
+       {
+               Value: &NestedItems{Items: []string{}, Item1: []string{}},
                ExpectXML: `<result>` +
                        `<Items>` +
                        `</Items>` +
                        `</result>`,
+               MarshalOnly: true,
        },
        {
-               Value: NestedItems{Items: []string{}, Item1: []string{"A"}},
+               Value: &NestedItems{Items: nil, Item1: []string{"A"}},
                ExpectXML: `<result>` +
                        `<Items>` +
                        `<item1>A</item1>` +
@@ -220,7 +334,7 @@ var marshalTests = []struct {
                        `</result>`,
        },
        {
-               Value: NestedItems{Items: []string{"A", "B"}, Item1: []string{}},
+               Value: &NestedItems{Items: []string{"A", "B"}, Item1: nil},
                ExpectXML: `<result>` +
                        `<Items>` +
                        `<item>A</item>` +
@@ -229,7 +343,7 @@ var marshalTests = []struct {
                        `</result>`,
        },
        {
-               Value: NestedItems{Items: []string{"A", "B"}, Item1: []string{"C"}},
+               Value: &NestedItems{Items: []string{"A", "B"}, Item1: []string{"C"}},
                ExpectXML: `<result>` +
                        `<Items>` +
                        `<item>A</item>` +
@@ -239,7 +353,7 @@ var marshalTests = []struct {
                        `</result>`,
        },
        {
-               Value: NestedOrder{Field1: "C", Field2: "B", Field3: "A"},
+               Value: &NestedOrder{Field1: "C", Field2: "B", Field3: "A"},
                ExpectXML: `<result>` +
                        `<parent>` +
                        `<c>C</c>` +
@@ -249,16 +363,17 @@ var marshalTests = []struct {
                        `</result>`,
        },
        {
-               Value: NilTest{A: "A", B: nil, C: "C"},
-               ExpectXML: `<???>` +
+               Value: &NilTest{A: "A", B: nil, C: "C"},
+               ExpectXML: `<NilTest>` +
                        `<parent1>` +
                        `<parent2><a>A</a></parent2>` +
                        `<parent2><c>C</c></parent2>` +
                        `</parent1>` +
-                       `</???>`,
+                       `</NilTest>`,
+               MarshalOnly: true, // Uses interface{}
        },
        {
-               Value: MixedNested{A: "A", B: "B", C: "C", D: "D"},
+               Value: &MixedNested{A: "A", B: "B", C: "C", D: "D"},
                ExpectXML: `<result>` +
                        `<parent1><a>A</a></parent1>` +
                        `<b>B</b>` +
@@ -269,32 +384,154 @@ var marshalTests = []struct {
                        `</result>`,
        },
        {
-               Value:     Service{Port: &Port{Number: "80"}},
+               Value:     &Service{Port: &Port{Number: "80"}},
                ExpectXML: `<service><host><port>80</port></host></service>`,
        },
        {
-               Value:     Service{},
+               Value:     &Service{},
                ExpectXML: `<service></service>`,
        },
        {
-               Value: Service{Port: &Port{Number: "80"}, Extra1: "A", Extra2: "B"},
+               Value: &Service{Port: &Port{Number: "80"}, Extra1: "A", Extra2: "B"},
                ExpectXML: `<service>` +
                        `<host><port>80</port></host>` +
                        `<Extra1>A</Extra1>` +
                        `<host><extra2>B</extra2></host>` +
                        `</service>`,
+               MarshalOnly: true,
        },
        {
-               Value: Service{Port: &Port{Number: "80"}, Extra2: "example"},
+               Value: &Service{Port: &Port{Number: "80"}, Extra2: "example"},
                ExpectXML: `<service>` +
                        `<host><port>80</port></host>` +
                        `<host><extra2>example</extra2></host>` +
                        `</service>`,
+               MarshalOnly: true,
+       },
+
+       // Test struct embedding
+       {
+               Value: &EmbedA{
+                       EmbedC: EmbedC{
+                               FieldA1: "", // Shadowed by A.A
+                               FieldA2: "", // Shadowed by A.A
+                               FieldB:  "A.C.B",
+                               FieldC:  "A.C.C",
+                       },
+                       EmbedB: EmbedB{
+                               FieldB: "A.B.B",
+                               EmbedC: EmbedC{
+                                       FieldA1: "A.B.C.A1",
+                                       FieldA2: "A.B.C.A2",
+                                       FieldB:  "", // Shadowed by A.B.B
+                                       FieldC:  "A.B.C.C",
+                               },
+                       },
+                       FieldA: "A.A",
+               },
+               ExpectXML: `<EmbedA>` +
+                       `<FieldB>A.C.B</FieldB>` +
+                       `<FieldC>A.C.C</FieldC>` +
+                       `<EmbedB>` +
+                       `<FieldB>A.B.B</FieldB>` +
+                       `<FieldA>` +
+                       `<A1>A.B.C.A1</A1>` +
+                       `<A2>A.B.C.A2</A2>` +
+                       `</FieldA>` +
+                       `<FieldC>A.B.C.C</FieldC>` +
+                       `</EmbedB>` +
+                       `<FieldA>A.A</FieldA>` +
+                       `</EmbedA>`,
+       },
+
+       // Test that name casing matters
+       {
+               Value:     &NameCasing{Xy: "mixed", XY: "upper", XyA: "mixedA", XYA: "upperA"},
+               ExpectXML: `<casing Xy="mixedA" XY="upperA"><Xy>mixed</Xy><XY>upper</XY></casing>`,
+       },
+
+       // Test the order in which the XML element name is chosen
+       {
+               Value: &NamePrecedence{
+                       FromTag:     XMLNameWithoutTag{Value: "A"},
+                       FromNameVal: XMLNameWithoutTag{XMLName: Name{Local: "InXMLName"}, Value: "B"},
+                       FromNameTag: XMLNameWithTag{Value: "C"},
+                       InFieldName: "D",
+               },
+               ExpectXML: `<Parent>` +
+                       `<InTag><Value>A</Value></InTag>` +
+                       `<InXMLName><Value>B</Value></InXMLName>` +
+                       `<InXMLNameTag><Value>C</Value></InXMLNameTag>` +
+                       `<InFieldName>D</InFieldName>` +
+                       `</Parent>`,
+               MarshalOnly: true,
+       },
+       {
+               Value: &NamePrecedence{
+                       XMLName:     Name{Local: "Parent"},
+                       FromTag:     XMLNameWithoutTag{XMLName: Name{Local: "InTag"}, Value: "A"},
+                       FromNameVal: XMLNameWithoutTag{XMLName: Name{Local: "FromNameVal"}, Value: "B"},
+                       FromNameTag: XMLNameWithTag{XMLName: Name{Local: "InXMLNameTag"}, Value: "C"},
+                       InFieldName: "D",
+               },
+               ExpectXML: `<Parent>` +
+                       `<InTag><Value>A</Value></InTag>` +
+                       `<FromNameVal><Value>B</Value></FromNameVal>` +
+                       `<InXMLNameTag><Value>C</Value></InXMLNameTag>` +
+                       `<InFieldName>D</InFieldName>` +
+                       `</Parent>`,
+               UnmarshalOnly: true,
+       },
+
+       // Test attributes
+       {
+               Value: &AttrTest{
+                       Int:   8,
+                       Lower: 9,
+                       Float: 23.5,
+                       Uint8: 255,
+                       Bool:  true,
+                       Str:   "s",
+               },
+               ExpectXML: `<AttrTest Int="8" int="9" Float="23.5" Uint8="255" Bool="true" Str="s"></AttrTest>`,
+       },
+
+       // Test ",any"
+       {
+               ExpectXML: `<a><nested><value>known</value></nested><other><sub>unknown</sub></other></a>`,
+               Value: &AnyTest{
+                       Nested: "known",
+                       AnyField: AnyHolder{
+                               XMLName: Name{Local: "other"},
+                               XML:     "<sub>unknown</sub>",
+                       },
+               },
+               UnmarshalOnly: true,
+       },
+       {
+               Value:       &AnyTest{Nested: "known", AnyField: AnyHolder{XML: "<unknown/>"}},
+               ExpectXML:   `<a><nested><value>known</value></nested></a>`,
+               MarshalOnly: true,
+       },
+
+       // Test recursive types.
+       {
+               Value: &RecurseA{
+                       A: "a1",
+                       B: &RecurseB{
+                               A: &RecurseA{"a2", nil},
+                               B: "b1",
+                       },
+               },
+               ExpectXML: `<RecurseA><A>a1</A><B><A><A>a2</A></A><B>b1</B></B></RecurseA>`,
        },
 }
 
 func TestMarshal(t *testing.T) {
        for idx, test := range marshalTests {
+               if test.UnmarshalOnly {
+                       continue
+               }
                buf := bytes.NewBuffer(nil)
                err := Marshal(buf, test.Value)
                if err != nil {
@@ -303,9 +540,9 @@ func TestMarshal(t *testing.T) {
                }
                if got, want := buf.String(), test.ExpectXML; got != want {
                        if strings.Contains(want, "\n") {
-                               t.Errorf("#%d: marshal(%#v) - GOT:\n%s\nWANT:\n%s", idx, test.Value, got, want)
+                               t.Errorf("#%d: marshal(%#v):\nHAVE:\n%s\nWANT:\n%s", idx, test.Value, got, want)
                        } else {
-                               t.Errorf("#%d: marshal(%#v) = %#q want %#q", idx, test.Value, got, want)
+                               t.Errorf("#%d: marshal(%#v):\nhave %#q\nwant %#q", idx, test.Value, got, want)
                        }
                }
        }
@@ -334,6 +571,10 @@ var marshalErrorTests = []struct {
                Err:   "xml: unsupported type: map[*xml.Ship]bool",
                Kind:  reflect.Map,
        },
+       {
+               Value: &Domain{Comment: []byte("f--bar")},
+               Err:   `xml: comments must not contain "--"`,
+       },
 }
 
 func TestMarshalErrors(t *testing.T) {
@@ -341,10 +582,12 @@ func TestMarshalErrors(t *testing.T) {
                buf := bytes.NewBuffer(nil)
                err := Marshal(buf, test.Value)
                if err == nil || err.Error() != test.Err {
-                       t.Errorf("#%d: marshal(%#v) = [error] %q, want %q", idx, test.Value, err, test.Err)
+                       t.Errorf("#%d: marshal(%#v) = [error] %v, want %v", idx, test.Value, err, test.Err)
                }
-               if kind := err.(*UnsupportedTypeError).Type.Kind(); kind != test.Kind {
-                       t.Errorf("#%d: marshal(%#v) = [error kind] %s, want %s", idx, test.Value, kind, test.Kind)
+               if test.Kind != reflect.Invalid {
+                       if kind := err.(*UnsupportedTypeError).Type.Kind(); kind != test.Kind {
+                               t.Errorf("#%d: marshal(%#v) = [error kind] %s, want %s", idx, test.Value, kind, test.Kind)
+                       }
                }
        }
 }
@@ -352,39 +595,20 @@ func TestMarshalErrors(t *testing.T) {
 // Do invertibility testing on the various structures that we test
 func TestUnmarshal(t *testing.T) {
        for i, test := range marshalTests {
-               // Skip the nil pointers
-               if i <= 1 {
+               if test.MarshalOnly {
                        continue
                }
-
-               var dest interface{}
-
-               switch test.Value.(type) {
-               case *Ship, Ship:
-                       dest = &Ship{}
-               case *Port, Port:
-                       dest = &Port{}
-               case *Domain, Domain:
-                       dest = &Domain{}
-               case *Feed, Feed:
-                       dest = &Feed{}
-               default:
+               if _, ok := test.Value.(*Plain); ok {
                        continue
                }
 
+               vt := reflect.TypeOf(test.Value)
+               dest := reflect.New(vt.Elem()).Interface()
                buffer := bytes.NewBufferString(test.ExpectXML)
                err := Unmarshal(buffer, dest)
 
-               // Don't compare XMLNames
                switch fix := dest.(type) {
-               case *Ship:
-                       fix.XMLName = Name{}
-               case *Port:
-                       fix.XMLName = Name{}
-               case *Domain:
-                       fix.XMLName = Name{}
                case *Feed:
-                       fix.XMLName = Name{}
                        fix.Author.InnerXML = ""
                        for i := range fix.Entry {
                                fix.Entry[i].Author.InnerXML = ""
@@ -394,30 +618,23 @@ func TestUnmarshal(t *testing.T) {
                if err != nil {
                        t.Errorf("#%d: unexpected error: %#v", i, err)
                } else if got, want := dest, test.Value; !reflect.DeepEqual(got, want) {
-                       t.Errorf("#%d: unmarshal(%q) = %#v, want %#v", i, test.ExpectXML, got, want)
+                       t.Errorf("#%d: unmarshal(%q):\nhave %#v\nwant %#v", i, test.ExpectXML, got, want)
                }
        }
 }
 
 func BenchmarkMarshal(b *testing.B) {
-       idx := len(marshalTests) - 1
-       test := marshalTests[idx]
-
        buf := bytes.NewBuffer(nil)
        for i := 0; i < b.N; i++ {
-               Marshal(buf, test.Value)
+               Marshal(buf, atomValue)
                buf.Truncate(0)
        }
 }
 
 func BenchmarkUnmarshal(b *testing.B) {
-       idx := len(marshalTests) - 1
-       test := marshalTests[idx]
-       sm := &Ship{}
-       xml := []byte(test.ExpectXML)
-
+       xml := []byte(atomXml)
        for i := 0; i < b.N; i++ {
                buffer := bytes.NewBuffer(xml)
-               Unmarshal(buffer, sm)
+               Unmarshal(buffer, &Feed{})
        }
 }
index 6dd36541000820c0ec16aeec89ac3a21864fc3cd..dde68de3e7839c35fedce20c45a5ca2ad3883c16 100644 (file)
@@ -7,13 +7,10 @@ package xml
 import (
        "bytes"
        "errors"
-       "fmt"
        "io"
        "reflect"
        "strconv"
        "strings"
-       "unicode"
-       "unicode/utf8"
 )
 
 // BUG(rsc): Mapping between XML elements and data structures is inherently flawed:
@@ -31,7 +28,7 @@ import (
 // For example, given these definitions:
 //
 //     type Email struct {
-//             Where string `xml:"attr"`
+//             Where string `xml:",attr"`
 //             Addr  string
 //     }
 //
@@ -64,7 +61,8 @@ import (
 //
 // via Unmarshal(r, &result) is equivalent to assigning
 //
-//     r = Result{xml.Name{"", "result"},
+//     r = Result{
+//             xml.Name{Local: "result"},
 //             "Grace R. Emlin", // name
 //             "phone",          // no phone given
 //             []Email{
@@ -87,9 +85,9 @@ import (
 // In the rules, the tag of a field refers to the value associated with the
 // key 'xml' in the struct field's tag (see the example above).
 //
-//   * If the struct has a field of type []byte or string with tag "innerxml",
-//      Unmarshal accumulates the raw XML nested inside the element
-//      in that field.  The rest of the rules still apply.
+//   * If the struct has a field of type []byte or string with tag
+//      ",innerxml", Unmarshal accumulates the raw XML nested inside the
+//      element in that field.  The rest of the rules still apply.
 //
 //   * If the struct has a field named XMLName of type xml.Name,
 //      Unmarshal records the element name in that field.
@@ -100,8 +98,9 @@ import (
 //      returns an error.
 //
 //   * If the XML element has an attribute whose name matches a
-//      struct field of type string with tag "attr", Unmarshal records
-//      the attribute value in that field.
+//      struct field name with an associated tag containing ",attr" or
+//      the explicit name in a struct field tag of the form "name,attr",
+//      Unmarshal records the attribute value in that field.
 //
 //   * If the XML element contains character data, that data is
 //      accumulated in the first struct field that has tag "chardata".
@@ -109,23 +108,30 @@ import (
 //      If there is no such field, the character data is discarded.
 //
 //   * If the XML element contains comments, they are accumulated in
-//      the first struct field that has tag "comments".  The struct
+//      the first struct field that has tag ",comments".  The struct
 //      field may have type []byte or string.  If there is no such
 //      field, the comments are discarded.
 //
 //   * If the XML element contains a sub-element whose name matches
-//      the prefix of a tag formatted as "a>b>c", unmarshal
+//      the prefix of a tag formatted as "a" or "a>b>c", unmarshal
 //      will descend into the XML structure looking for elements with the
-//      given names, and will map the innermost elements to that struct field.
-//      A tag starting with ">" is equivalent to one starting
+//      given names, and will map the innermost elements to that struct
+//      field. A tag starting with ">" is equivalent to one starting
 //      with the field name followed by ">".
 //
-//   * If the XML element contains a sub-element whose name
-//      matches a field whose tag is neither "attr" nor "chardata",
-//      Unmarshal maps the sub-element to that struct field.
-//      Otherwise, if the struct has a field named Any, unmarshal
+//   * If the XML element contains a sub-element whose name matches
+//      a struct field's XMLName tag and the struct field has no
+//      explicit name tag as per the previous rule, unmarshal maps
+//      the sub-element to that struct field.
+//
+//   * If the XML element contains a sub-element whose name matches a
+//      field without any mode flags (",attr", ",chardata", etc), Unmarshal
 //      maps the sub-element to that struct field.
 //
+//   * If the XML element contains a sub-element that hasn't matched any
+//      of the above rules and the struct has a field with tag ",any",
+//      unmarshal maps the sub-element to that struct field.
+//
 // Unmarshal maps an XML element to a string or []byte by saving the
 // concatenation of that element's character data in the string or
 // []byte.
@@ -169,18 +175,6 @@ type UnmarshalError string
 
 func (e UnmarshalError) Error() string { return string(e) }
 
-// A TagPathError represents an error in the unmarshalling process
-// caused by the use of field tags with conflicting paths.
-type TagPathError struct {
-       Struct       reflect.Type
-       Field1, Tag1 string
-       Field2, Tag2 string
-}
-
-func (e *TagPathError) Error() string {
-       return fmt.Sprintf("%s field %q with tag %q conflicts with field %q with tag %q", e.Struct, e.Field1, e.Tag1, e.Field2, e.Tag2)
-}
-
 // The Parser's Unmarshal method is like xml.Unmarshal
 // except that it can be passed a pointer to the initial start element,
 // useful when a client reads some raw XML tokens itself
@@ -195,26 +189,6 @@ func (p *Parser) Unmarshal(val interface{}, start *StartElement) error {
        return p.unmarshal(v.Elem(), start)
 }
 
-// fieldName strips invalid characters from an XML name
-// to create a valid Go struct name.  It also converts the
-// name to lower case letters.
-func fieldName(original string) string {
-
-       var i int
-       //remove leading underscores, without exhausting all characters
-       for i = 0; i < len(original)-1 && original[i] == '_'; i++ {
-       }
-
-       return strings.Map(
-               func(x rune) rune {
-                       if x == '_' || unicode.IsDigit(x) || unicode.IsLetter(x) {
-                               return unicode.ToLower(x)
-                       }
-                       return -1
-               },
-               original[i:])
-}
-
 // Unmarshal a single XML element into val.
 func (p *Parser) unmarshal(val reflect.Value, start *StartElement) error {
        // Find start element if we need it.
@@ -246,15 +220,22 @@ func (p *Parser) unmarshal(val reflect.Value, start *StartElement) error {
                saveXML      reflect.Value
                saveXMLIndex int
                saveXMLData  []byte
+               saveAny      reflect.Value
                sv           reflect.Value
-               styp         reflect.Type
-               fieldPaths   map[string]pathInfo
+               tinfo        *typeInfo
+               err          error
        )
 
        switch v := val; v.Kind() {
        default:
                return errors.New("unknown type " + v.Type().String())
 
+       case reflect.Interface:
+               // TODO: For now, simply ignore the field. In the near
+               //       future we may choose to unmarshal the start
+               //       element on it, if not nil.
+               return p.Skip()
+
        case reflect.Slice:
                typ := v.Type()
                if typ.Elem().Kind() == reflect.Uint8 {
@@ -288,75 +269,69 @@ func (p *Parser) unmarshal(val reflect.Value, start *StartElement) error {
                saveData = v
 
        case reflect.Struct:
-               if _, ok := v.Interface().(Name); ok {
-                       v.Set(reflect.ValueOf(start.Name))
-                       break
-               }
-
                sv = v
                typ := sv.Type()
-               styp = typ
-               // Assign name.
-               if f, ok := typ.FieldByName("XMLName"); ok {
-                       // Validate element name.
-                       if tag := f.Tag.Get("xml"); tag != "" {
-                               ns := ""
-                               i := strings.LastIndex(tag, " ")
-                               if i >= 0 {
-                                       ns, tag = tag[0:i], tag[i+1:]
-                               }
-                               if tag != start.Name.Local {
-                                       return UnmarshalError("expected element type <" + tag + "> but have <" + start.Name.Local + ">")
-                               }
-                               if ns != "" && ns != start.Name.Space {
-                                       e := "expected element <" + tag + "> in name space " + ns + " but have "
-                                       if start.Name.Space == "" {
-                                               e += "no name space"
-                                       } else {
-                                               e += start.Name.Space
-                                       }
-                                       return UnmarshalError(e)
+               tinfo, err = getTypeInfo(typ)
+               if err != nil {
+                       return err
+               }
+
+               // Validate and assign element name.
+               if tinfo.xmlname != nil {
+                       finfo := tinfo.xmlname
+                       if finfo.name != "" && finfo.name != start.Name.Local {
+                               return UnmarshalError("expected element type <" + finfo.name + "> but have <" + start.Name.Local + ">")
+                       }
+                       if finfo.xmlns != "" && finfo.xmlns != start.Name.Space {
+                               e := "expected element <" + finfo.name + "> in name space " + finfo.xmlns + " but have "
+                               if start.Name.Space == "" {
+                                       e += "no name space"
+                               } else {
+                                       e += start.Name.Space
                                }
+                               return UnmarshalError(e)
                        }
-
-                       // Save
-                       v := sv.FieldByIndex(f.Index)
-                       if _, ok := v.Interface().(Name); ok {
-                               v.Set(reflect.ValueOf(start.Name))
+                       fv := sv.FieldByIndex(finfo.idx)
+                       if _, ok := fv.Interface().(Name); ok {
+                               fv.Set(reflect.ValueOf(start.Name))
                        }
                }
 
                // Assign attributes.
                // Also, determine whether we need to save character data or comments.
-               for i, n := 0, typ.NumField(); i < n; i++ {
-                       f := typ.Field(i)
-                       switch f.Tag.Get("xml") {
-                       case "attr":
-                               strv := sv.FieldByIndex(f.Index)
+               for i := range tinfo.fields {
+                       finfo := &tinfo.fields[i]
+                       switch finfo.flags & fMode {
+                       case fAttr:
+                               strv := sv.FieldByIndex(finfo.idx)
                                // Look for attribute.
                                val := ""
-                               k := strings.ToLower(f.Name)
                                for _, a := range start.Attr {
-                                       if fieldName(a.Name.Local) == k {
+                                       if a.Name.Local == finfo.name {
                                                val = a.Value
                                                break
                                        }
                                }
                                copyValue(strv, []byte(val))
 
-                       case "comment":
+                       case fCharData:
+                               if !saveData.IsValid() {
+                                       saveData = sv.FieldByIndex(finfo.idx)
+                               }
+
+                       case fComment:
                                if !saveComment.IsValid() {
-                                       saveComment = sv.FieldByIndex(f.Index)
+                                       saveComment = sv.FieldByIndex(finfo.idx)
                                }
 
-                       case "chardata":
-                               if !saveData.IsValid() {
-                                       saveData = sv.FieldByIndex(f.Index)
+                       case fAny:
+                               if !saveAny.IsValid() {
+                                       saveAny = sv.FieldByIndex(finfo.idx)
                                }
 
-                       case "innerxml":
+                       case fInnerXml:
                                if !saveXML.IsValid() {
-                                       saveXML = sv.FieldByIndex(f.Index)
+                                       saveXML = sv.FieldByIndex(finfo.idx)
                                        if p.saved == nil {
                                                saveXMLIndex = 0
                                                p.saved = new(bytes.Buffer)
@@ -364,24 +339,6 @@ func (p *Parser) unmarshal(val reflect.Value, start *StartElement) error {
                                                saveXMLIndex = p.savedOffset()
                                        }
                                }
-
-                       default:
-                               if tag := f.Tag.Get("xml"); strings.Contains(tag, ">") {
-                                       if fieldPaths == nil {
-                                               fieldPaths = make(map[string]pathInfo)
-                                       }
-                                       path := strings.ToLower(tag)
-                                       if strings.HasPrefix(tag, ">") {
-                                               path = strings.ToLower(f.Name) + path
-                                       }
-                                       if strings.HasSuffix(tag, ">") {
-                                               path = path[:len(path)-1]
-                                       }
-                                       err := addFieldPath(sv, fieldPaths, path, f.Index)
-                                       if err != nil {
-                                               return err
-                                       }
-                               }
                        }
                }
        }
@@ -400,44 +357,23 @@ Loop:
                }
                switch t := tok.(type) {
                case StartElement:
-                       // Sub-element.
-                       // Look up by tag name.
+                       consumed := false
                        if sv.IsValid() {
-                               k := fieldName(t.Name.Local)
-
-                               if fieldPaths != nil {
-                                       if _, found := fieldPaths[k]; found {
-                                               if err := p.unmarshalPaths(sv, fieldPaths, k, &t); err != nil {
-                                                       return err
-                                               }
-                                               continue Loop
-                                       }
-                               }
-
-                               match := func(s string) bool {
-                                       // check if the name matches ignoring case
-                                       if strings.ToLower(s) != k {
-                                               return false
-                                       }
-                                       // now check that it's public
-                                       c, _ := utf8.DecodeRuneInString(s)
-                                       return unicode.IsUpper(c)
-                               }
-
-                               f, found := styp.FieldByNameFunc(match)
-                               if !found { // fall back to mop-up field named "Any"
-                                       f, found = styp.FieldByName("Any")
+                               consumed, err = p.unmarshalPath(tinfo, sv, nil, &t)
+                               if err != nil {
+                                       return err
                                }
-                               if found {
-                                       if err := p.unmarshal(sv.FieldByIndex(f.Index), &t); err != nil {
+                               if !consumed && saveAny.IsValid() {
+                                       consumed = true
+                                       if err := p.unmarshal(saveAny, &t); err != nil {
                                                return err
                                        }
-                                       continue Loop
                                }
                        }
-                       // Not saving sub-element but still have to skip over it.
-                       if err := p.Skip(); err != nil {
-                               return err
+                       if !consumed {
+                               if err := p.Skip(); err != nil {
+                                       return err
+                               }
                        }
 
                case EndElement:
@@ -503,10 +439,10 @@ func copyValue(dst reflect.Value, src []byte) (err error) {
                return err == nil
        }
 
-       // Save accumulated data and comments
+       // Save accumulated data.
        switch t := dst; t.Kind() {
        case reflect.Invalid:
-               // Probably a comment, handled below
+               // Probably a comment.
        default:
                return errors.New("cannot happen: unknown type " + t.Type().String())
        case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
@@ -538,70 +474,66 @@ func copyValue(dst reflect.Value, src []byte) (err error) {
        return nil
 }
 
-type pathInfo struct {
-       fieldIdx []int
-       complete bool
-}
-
-// addFieldPath takes an element path such as "a>b>c" and fills the
-// paths map with all paths leading to it ("a", "a>b", and "a>b>c").
-// It is okay for paths to share a common, shorter prefix but not ok
-// for one path to itself be a prefix of another.
-func addFieldPath(sv reflect.Value, paths map[string]pathInfo, path string, fieldIdx []int) error {
-       if info, found := paths[path]; found {
-               return tagError(sv, info.fieldIdx, fieldIdx)
-       }
-       paths[path] = pathInfo{fieldIdx, true}
-       for {
-               i := strings.LastIndex(path, ">")
-               if i < 0 {
-                       break
+// unmarshalPath walks down an XML structure looking for wanted
+// paths, and calls unmarshal on them.
+// The consumed result tells whether XML elements have been consumed
+// from the Parser until start's matching end element, or if it's
+// still untouched because start is uninteresting for sv's fields.
+func (p *Parser) unmarshalPath(tinfo *typeInfo, sv reflect.Value, parents []string, start *StartElement) (consumed bool, err error) {
+       recurse := false
+Loop:
+       for i := range tinfo.fields {
+               finfo := &tinfo.fields[i]
+               if finfo.flags&fElement == 0 || len(finfo.parents) < len(parents) {
+                       continue
                }
-               path = path[:i]
-               if info, found := paths[path]; found {
-                       if info.complete {
-                               return tagError(sv, info.fieldIdx, fieldIdx)
+               for j := range parents {
+                       if parents[j] != finfo.parents[j] {
+                               continue Loop
                        }
-               } else {
-                       paths[path] = pathInfo{fieldIdx, false}
+               }
+               if len(finfo.parents) == len(parents) && finfo.name == start.Name.Local {
+                       // It's a perfect match, unmarshal the field.
+                       return true, p.unmarshal(sv.FieldByIndex(finfo.idx), start)
+               }
+               if len(finfo.parents) > len(parents) && finfo.parents[len(parents)] == start.Name.Local {
+                       // It's a prefix for the field. Break and recurse
+                       // since it's not ok for one field path to be itself
+                       // the prefix for another field path.
+                       recurse = true
+
+                       // We can reuse the same slice as long as we
+                       // don't try to append to it.
+                       parents = finfo.parents[:len(parents)+1]
+                       break
                }
        }
-       return nil
-
-}
-
-func tagError(sv reflect.Value, idx1 []int, idx2 []int) error {
-       t := sv.Type()
-       f1 := t.FieldByIndex(idx1)
-       f2 := t.FieldByIndex(idx2)
-       return &TagPathError{t, f1.Name, f1.Tag.Get("xml"), f2.Name, f2.Tag.Get("xml")}
-}
-
-// unmarshalPaths walks down an XML structure looking for
-// wanted paths, and calls unmarshal on them.
-func (p *Parser) unmarshalPaths(sv reflect.Value, paths map[string]pathInfo, path string, start *StartElement) error {
-       if info, _ := paths[path]; info.complete {
-               return p.unmarshal(sv.FieldByIndex(info.fieldIdx), start)
+       if !recurse {
+               // We have no business with this element.
+               return false, nil
        }
+       // The element is not a perfect match for any field, but one
+       // or more fields have the path to this element as a parent
+       // prefix. Recurse and attempt to match these.
        for {
-               tok, err := p.Token()
+               var tok Token
+               tok, err = p.Token()
                if err != nil {
-                       return err
+                       return true, err
                }
                switch t := tok.(type) {
                case StartElement:
-                       k := path + ">" + fieldName(t.Name.Local)
-                       if _, found := paths[k]; found {
-                               if err := p.unmarshalPaths(sv, paths, k, &t); err != nil {
-                                       return err
-                               }
-                               continue
+                       consumed2, err := p.unmarshalPath(tinfo, sv, parents, &t)
+                       if err != nil {
+                               return true, err
                        }
-                       if err := p.Skip(); err != nil {
-                               return err
+                       if !consumed2 {
+                               if err := p.Skip(); err != nil {
+                                       return true, err
+                               }
                        }
                case EndElement:
-                       return nil
+                       return true, nil
                }
        }
        panic("unreachable")
index fbb7fd5d2f276e2e75868921cc415ea956a3cb69..ff61bd7e1c51eee836fc2c75db1aa79e6d65151a 100644 (file)
@@ -6,6 +6,7 @@ package xml
 
 import (
        "reflect"
+       "strings"
        "testing"
 )
 
@@ -13,7 +14,7 @@ import (
 
 func TestUnmarshalFeed(t *testing.T) {
        var f Feed
-       if err := Unmarshal(StringReader(atomFeedString), &f); err != nil {
+       if err := Unmarshal(strings.NewReader(atomFeedString), &f); err != nil {
                t.Fatalf("Unmarshal: %s", err)
        }
        if !reflect.DeepEqual(f, atomFeed) {
@@ -24,8 +25,8 @@ func TestUnmarshalFeed(t *testing.T) {
 // hget http://codereview.appspot.com/rss/mine/rsc
 const atomFeedString = `
 <?xml version="1.0" encoding="utf-8"?>
-<feed xmlns="http://www.w3.org/2005/Atom" xml:lang="en-us"><title>Code Review - My issues</title><link href="http://codereview.appspot.com/" rel="alternate"></link><li-nk href="http://codereview.appspot.com/rss/mine/rsc" rel="self"></li-nk><id>http://codereview.appspot.com/</id><updated>2009-10-04T01:35:58+00:00</updated><author><name>rietveld&lt;&gt;</name></author><entry><title>rietveld: an attempt at pubsubhubbub
-</title><link hre-f="http://codereview.appspot.com/126085" rel="alternate"></link><updated>2009-10-04T01:35:58+00:00</updated><author><name>email-address-removed</name></author><id>urn:md5:134d9179c41f806be79b3a5f7877d19a</id><summary type="html">
+<feed xmlns="http://www.w3.org/2005/Atom" xml:lang="en-us"><title>Code Review - My issues</title><link href="http://codereview.appspot.com/" rel="alternate"></link><link href="http://codereview.appspot.com/rss/mine/rsc" rel="self"></link><id>http://codereview.appspot.com/</id><updated>2009-10-04T01:35:58+00:00</updated><author><name>rietveld&lt;&gt;</name></author><entry><title>rietveld: an attempt at pubsubhubbub
+</title><link href="http://codereview.appspot.com/126085" rel="alternate"></link><updated>2009-10-04T01:35:58+00:00</updated><author><name>email-address-removed</name></author><id>urn:md5:134d9179c41f806be79b3a5f7877d19a</id><summary type="html">
   An attempt at adding pubsubhubbub support to Rietveld.
 http://code.google.com/p/pubsubhubbub
 http://code.google.com/p/rietveld/issues/detail?id=155
@@ -78,39 +79,39 @@ not being used from outside intra_region_diff.py.
 </summary></entry></feed>         `
 
 type Feed struct {
-       XMLName Name `xml:"http://www.w3.org/2005/Atom feed"`
-       Title   string
-       Id      string
-       Link    []Link
-       Updated Time
-       Author  Person
-       Entry   []Entry
+       XMLName Name    `xml:"http://www.w3.org/2005/Atom feed"`
+       Title   string  `xml:"title"`
+       Id      string  `xml:"id"`
+       Link    []Link  `xml:"link"`
+       Updated Time    `xml:"updated"`
+       Author  Person  `xml:"author"`
+       Entry   []Entry `xml:"entry"`
 }
 
 type Entry struct {
-       Title   string
-       Id      string
-       Link    []Link
-       Updated Time
-       Author  Person
-       Summary Text
+       Title   string `xml:"title"`
+       Id      string `xml:"id"`
+       Link    []Link `xml:"link"`
+       Updated Time   `xml:"updated"`
+       Author  Person `xml:"author"`
+       Summary Text   `xml:"summary"`
 }
 
 type Link struct {
-       Rel  string `xml:"attr"`
-       Href string `xml:"attr"`
+       Rel  string `xml:"rel,attr"`
+       Href string `xml:"href,attr"`
 }
 
 type Person struct {
-       Name     string
-       URI      string
-       Email    string
-       InnerXML string `xml:"innerxml"`
+       Name     string `xml:"name"`
+       URI      string `xml:"uri"`
+       Email    string `xml:"email"`
+       InnerXML string `xml:",innerxml"`
 }
 
 type Text struct {
-       Type string `xml:"attr"`
-       Body string `xml:"chardata"`
+       Type string `xml:"type,attr"`
+       Body string `xml:",chardata"`
 }
 
 type Time string
@@ -213,44 +214,26 @@ not being used from outside intra_region_diff.py.
        },
 }
 
-type FieldNameTest struct {
-       in, out string
-}
-
-var FieldNameTests = []FieldNameTest{
-       {"Profile-Image", "profileimage"},
-       {"_score", "score"},
-}
-
-func TestFieldName(t *testing.T) {
-       for _, tt := range FieldNameTests {
-               a := fieldName(tt.in)
-               if a != tt.out {
-                       t.Fatalf("have %#v\nwant %#v\n\n", a, tt.out)
-               }
-       }
-}
-
 const pathTestString = `
-<result>
-    <before>1</before>
-    <items>
-        <item1>
-            <value>A</value>
-        </item1>
-        <item2>
-            <value>B</value>
-        </item2>
+<Result>
+    <Before>1</Before>
+    <Items>
+        <Item1>
+            <Value>A</Value>
+        </Item1>
+        <Item2>
+            <Value>B</Value>
+        </Item2>
         <Item1>
             <Value>C</Value>
             <Value>D</Value>
         </Item1>
         <_>
-            <value>E</value>
+            <Value>E</Value>
         </_>
-    </items>
-    <after>2</after>
-</result>
+    </Items>
+    <After>2</After>
+</Result>
 `
 
 type PathTestItem struct {
@@ -258,18 +241,18 @@ type PathTestItem struct {
 }
 
 type PathTestA struct {
-       Items         []PathTestItem `xml:">item1"`
+       Items         []PathTestItem `xml:">Item1"`
        Before, After string
 }
 
 type PathTestB struct {
-       Other         []PathTestItem `xml:"items>Item1"`
+       Other         []PathTestItem `xml:"Items>Item1"`
        Before, After string
 }
 
 type PathTestC struct {
-       Values1       []string `xml:"items>item1>value"`
-       Values2       []string `xml:"items>item2>value"`
+       Values1       []string `xml:"Items>Item1>Value"`
+       Values2       []string `xml:"Items>Item2>Value"`
        Before, After string
 }
 
@@ -278,12 +261,12 @@ type PathTestSet struct {
 }
 
 type PathTestD struct {
-       Other         PathTestSet `xml:"items>"`
+       Other         PathTestSet `xml:"Items"`
        Before, After string
 }
 
 type PathTestE struct {
-       Underline     string `xml:"items>_>value"`
+       Underline     string `xml:"Items>_>Value"`
        Before, After string
 }
 
@@ -298,7 +281,7 @@ var pathTests = []interface{}{
 func TestUnmarshalPaths(t *testing.T) {
        for _, pt := range pathTests {
                v := reflect.New(reflect.TypeOf(pt).Elem()).Interface()
-               if err := Unmarshal(StringReader(pathTestString), v); err != nil {
+               if err := Unmarshal(strings.NewReader(pathTestString), v); err != nil {
                        t.Fatalf("Unmarshal: %s", err)
                }
                if !reflect.DeepEqual(v, pt) {
@@ -310,7 +293,7 @@ func TestUnmarshalPaths(t *testing.T) {
 type BadPathTestA struct {
        First  string `xml:"items>item1"`
        Other  string `xml:"items>item2"`
-       Second string `xml:"items>"`
+       Second string `xml:"items"`
 }
 
 type BadPathTestB struct {
@@ -319,81 +302,55 @@ type BadPathTestB struct {
        Second string `xml:"items>item1>value"`
 }
 
+type BadPathTestC struct {
+       First  string
+       Second string `xml:"First"`
+}
+
+type BadPathTestD struct {
+       BadPathEmbeddedA
+       BadPathEmbeddedB
+}
+
+type BadPathEmbeddedA struct {
+       First string
+}
+
+type BadPathEmbeddedB struct {
+       Second string `xml:"First"`
+}
+
 var badPathTests = []struct {
        v, e interface{}
 }{
-       {&BadPathTestA{}, &TagPathError{reflect.TypeOf(BadPathTestA{}), "First", "items>item1", "Second", "items>"}},
+       {&BadPathTestA{}, &TagPathError{reflect.TypeOf(BadPathTestA{}), "First", "items>item1", "Second", "items"}},
        {&BadPathTestB{}, &TagPathError{reflect.TypeOf(BadPathTestB{}), "First", "items>item1", "Second", "items>item1>value"}},
+       {&BadPathTestC{}, &TagPathError{reflect.TypeOf(BadPathTestC{}), "First", "", "Second", "First"}},
+       {&BadPathTestD{}, &TagPathError{reflect.TypeOf(BadPathTestD{}), "First", "", "Second", "First"}},
 }
 
 func TestUnmarshalBadPaths(t *testing.T) {
        for _, tt := range badPathTests {
-               err := Unmarshal(StringReader(pathTestString), tt.v)
+               err := Unmarshal(strings.NewReader(pathTestString), tt.v)
                if !reflect.DeepEqual(err, tt.e) {
-                       t.Fatalf("Unmarshal with %#v didn't fail properly: %#v", tt.v, err)
+                       t.Fatalf("Unmarshal with %#v didn't fail properly:\nhave %#v,\nwant %#v", tt.v, err, tt.e)
                }
        }
 }
 
-func TestUnmarshalAttrs(t *testing.T) {
-       var f AttrTest
-       if err := Unmarshal(StringReader(attrString), &f); err != nil {
-               t.Fatalf("Unmarshal: %s", err)
-       }
-       if !reflect.DeepEqual(f, attrStruct) {
-               t.Fatalf("have %#v\nwant %#v", f, attrStruct)
-       }
-}
-
-type AttrTest struct {
-       Test1 Test1
-       Test2 Test2
-}
-
-type Test1 struct {
-       Int   int     `xml:"attr"`
-       Float float64 `xml:"attr"`
-       Uint8 uint8   `xml:"attr"`
-}
-
-type Test2 struct {
-       Bool bool `xml:"attr"`
-}
-
-const attrString = `
-<?xml version="1.0" charset="utf-8"?>
-<attrtest>
-  <test1 int="8" float="23.5" uint8="255"/>
-  <test2 bool="true"/>
-</attrtest>
-`
-
-var attrStruct = AttrTest{
-       Test1: Test1{
-               Int:   8,
-               Float: 23.5,
-               Uint8: 255,
-       },
-       Test2: Test2{
-               Bool: true,
-       },
-}
-
-// test data for TestUnmarshalWithoutNameType
-
 const OK = "OK"
 const withoutNameTypeData = `
 <?xml version="1.0" charset="utf-8"?>
-<Test3 attr="OK" />`
+<Test3 Attr="OK" />`
 
 type TestThree struct {
-       XMLName bool   `xml:"Test3"` // XMLName field without an xml.Name type 
-       Attr    string `xml:"attr"`
+       XMLName Name   `xml:"Test3"`
+       Attr    string `xml:",attr"`
 }
 
 func TestUnmarshalWithoutNameType(t *testing.T) {
        var x TestThree
-       if err := Unmarshal(StringReader(withoutNameTypeData), &x); err != nil {
+       if err := Unmarshal(strings.NewReader(withoutNameTypeData), &x); err != nil {
                t.Fatalf("Unmarshal: %s", err)
        }
        if x.Attr != OK {
diff --git a/libgo/go/encoding/xml/typeinfo.go b/libgo/go/encoding/xml/typeinfo.go
new file mode 100644 (file)
index 0000000..8f79c4e
--- /dev/null
@@ -0,0 +1,321 @@
+// Copyright 2011 The Go Authors.  All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+package xml
+
+import (
+       "fmt"
+       "reflect"
+       "strings"
+       "sync"
+)
+
+// typeInfo holds details for the xml representation of a type.
+type typeInfo struct {
+       xmlname *fieldInfo
+       fields  []fieldInfo
+}
+
+// fieldInfo holds details for the xml representation of a single field.
+type fieldInfo struct {
+       idx     []int
+       name    string
+       xmlns   string
+       flags   fieldFlags
+       parents []string
+}
+
+type fieldFlags int
+
+const (
+       fElement fieldFlags = 1 << iota
+       fAttr
+       fCharData
+       fInnerXml
+       fComment
+       fAny
+
+       // TODO:
+       //fIgnore
+       //fOmitEmpty
+
+       fMode = fElement | fAttr | fCharData | fInnerXml | fComment | fAny
+)
+
+var tinfoMap = make(map[reflect.Type]*typeInfo)
+var tinfoLock sync.RWMutex
+
+// getTypeInfo returns the typeInfo structure with details necessary
+// for marshalling and unmarshalling typ.
+func getTypeInfo(typ reflect.Type) (*typeInfo, error) {
+       tinfoLock.RLock()
+       tinfo, ok := tinfoMap[typ]
+       tinfoLock.RUnlock()
+       if ok {
+               return tinfo, nil
+       }
+       tinfo = &typeInfo{}
+       if typ.Kind() == reflect.Struct {
+               n := typ.NumField()
+               for i := 0; i < n; i++ {
+                       f := typ.Field(i)
+                       if f.PkgPath != "" {
+                               continue // Private field
+                       }
+
+                       // For embedded structs, embed its fields.
+                       if f.Anonymous {
+                               if f.Type.Kind() != reflect.Struct {
+                                       continue
+                               }
+                               inner, err := getTypeInfo(f.Type)
+                               if err != nil {
+                                       return nil, err
+                               }
+                               for _, finfo := range inner.fields {
+                                       finfo.idx = append([]int{i}, finfo.idx...)
+                                       if err := addFieldInfo(typ, tinfo, &finfo); err != nil {
+                                               return nil, err
+                                       }
+                               }
+                               continue
+                       }
+
+                       finfo, err := structFieldInfo(typ, &f)
+                       if err != nil {
+                               return nil, err
+                       }
+
+                       if f.Name == "XMLName" {
+                               tinfo.xmlname = finfo
+                               continue
+                       }
+
+                       // Add the field if it doesn't conflict with other fields.
+                       if err := addFieldInfo(typ, tinfo, finfo); err != nil {
+                               return nil, err
+                       }
+               }
+       }
+       tinfoLock.Lock()
+       tinfoMap[typ] = tinfo
+       tinfoLock.Unlock()
+       return tinfo, nil
+}
+
+// structFieldInfo builds and returns a fieldInfo for f.
+func structFieldInfo(typ reflect.Type, f *reflect.StructField) (*fieldInfo, error) {
+       finfo := &fieldInfo{idx: f.Index}
+
+       // Split the tag from the xml namespace if necessary.
+       tag := f.Tag.Get("xml")
+       if i := strings.Index(tag, " "); i >= 0 {
+               finfo.xmlns, tag = tag[:i], tag[i+1:]
+       }
+
+       // Parse flags.
+       tokens := strings.Split(tag, ",")
+       if len(tokens) == 1 {
+               finfo.flags = fElement
+       } else {
+               tag = tokens[0]
+               for _, flag := range tokens[1:] {
+                       switch flag {
+                       case "attr":
+                               finfo.flags |= fAttr
+                       case "chardata":
+                               finfo.flags |= fCharData
+                       case "innerxml":
+                               finfo.flags |= fInnerXml
+                       case "comment":
+                               finfo.flags |= fComment
+                       case "any":
+                               finfo.flags |= fAny
+                       }
+               }
+
+               // Validate the flags used.
+               switch mode := finfo.flags & fMode; mode {
+               case 0:
+                       finfo.flags |= fElement
+               case fAttr, fCharData, fInnerXml, fComment, fAny:
+                       if f.Name != "XMLName" && (tag == "" || mode == fAttr) {
+                               break
+                       }
+                       fallthrough
+               default:
+                       // This will also catch multiple modes in a single field.
+                       return nil, fmt.Errorf("xml: invalid tag in field %s of type %s: %q",
+                               f.Name, typ, f.Tag.Get("xml"))
+               }
+       }
+
+       // Use of xmlns without a name is not allowed.
+       if finfo.xmlns != "" && tag == "" {
+               return nil, fmt.Errorf("xml: namespace without name in field %s of type %s: %q",
+                       f.Name, typ, f.Tag.Get("xml"))
+       }
+
+       if f.Name == "XMLName" {
+               // The XMLName field records the XML element name. Don't
+               // process it as usual because its name should default to
+               // empty rather than to the field name.
+               finfo.name = tag
+               return finfo, nil
+       }
+
+       if tag == "" {
+               // If the name part of the tag is completely empty, get
+               // default from XMLName of underlying struct if feasible,
+               // or field name otherwise.
+               if xmlname := lookupXMLName(f.Type); xmlname != nil {
+                       finfo.xmlns, finfo.name = xmlname.xmlns, xmlname.name
+               } else {
+                       finfo.name = f.Name
+               }
+               return finfo, nil
+       }
+
+       // Prepare field name and parents.
+       tokens = strings.Split(tag, ">")
+       if tokens[0] == "" {
+               tokens[0] = f.Name
+       }
+       if tokens[len(tokens)-1] == "" {
+               return nil, fmt.Errorf("xml: trailing '>' in field %s of type %s", f.Name, typ)
+       }
+       finfo.name = tokens[len(tokens)-1]
+       if len(tokens) > 1 {
+               finfo.parents = tokens[:len(tokens)-1]
+       }
+
+       // If the field type has an XMLName field, the names must match
+       // so that the behavior of both marshalling and unmarshalling
+       // is straighforward and unambiguous.
+       if finfo.flags&fElement != 0 {
+               ftyp := f.Type
+               xmlname := lookupXMLName(ftyp)
+               if xmlname != nil && xmlname.name != finfo.name {
+                       return nil, fmt.Errorf("xml: name %q in tag of %s.%s conflicts with name %q in %s.XMLName",
+                               finfo.name, typ, f.Name, xmlname.name, ftyp)
+               }
+       }
+       return finfo, nil
+}
+
+// lookupXMLName returns the fieldInfo for typ's XMLName field
+// in case it exists and has a valid xml field tag, otherwise
+// it returns nil.
+func lookupXMLName(typ reflect.Type) (xmlname *fieldInfo) {
+       for typ.Kind() == reflect.Ptr {
+               typ = typ.Elem()
+       }
+       if typ.Kind() != reflect.Struct {
+               return nil
+       }
+       for i, n := 0, typ.NumField(); i < n; i++ {
+               f := typ.Field(i)
+               if f.Name != "XMLName" {
+                       continue
+               }
+               finfo, err := structFieldInfo(typ, &f)
+               if finfo.name != "" && err == nil {
+                       return finfo
+               }
+               // Also consider errors as a non-existent field tag
+               // and let getTypeInfo itself report the error.
+               break
+       }
+       return nil
+}
+
+func min(a, b int) int {
+       if a <= b {
+               return a
+       }
+       return b
+}
+
+// addFieldInfo adds finfo to tinfo.fields if there are no
+// conflicts, or if conflicts arise from previous fields that were
+// obtained from deeper embedded structures than finfo. In the latter
+// case, the conflicting entries are dropped.
+// A conflict occurs when the path (parent + name) to a field is
+// itself a prefix of another path, or when two paths match exactly.
+// It is okay for field paths to share a common, shorter prefix.
+func addFieldInfo(typ reflect.Type, tinfo *typeInfo, newf *fieldInfo) error {
+       var conflicts []int
+Loop:
+       // First, figure all conflicts. Most working code will have none.
+       for i := range tinfo.fields {
+               oldf := &tinfo.fields[i]
+               if oldf.flags&fMode != newf.flags&fMode {
+                       continue
+               }
+               minl := min(len(newf.parents), len(oldf.parents))
+               for p := 0; p < minl; p++ {
+                       if oldf.parents[p] != newf.parents[p] {
+                               continue Loop
+                       }
+               }
+               if len(oldf.parents) > len(newf.parents) {
+                       if oldf.parents[len(newf.parents)] == newf.name {
+                               conflicts = append(conflicts, i)
+                       }
+               } else if len(oldf.parents) < len(newf.parents) {
+                       if newf.parents[len(oldf.parents)] == oldf.name {
+                               conflicts = append(conflicts, i)
+                       }
+               } else {
+                       if newf.name == oldf.name {
+                               conflicts = append(conflicts, i)
+                       }
+               }
+       }
+       // Without conflicts, add the new field and return.
+       if conflicts == nil {
+               tinfo.fields = append(tinfo.fields, *newf)
+               return nil
+       }
+
+       // If any conflict is shallower, ignore the new field.
+       // This matches the Go field resolution on embedding.
+       for _, i := range conflicts {
+               if len(tinfo.fields[i].idx) < len(newf.idx) {
+                       return nil
+               }
+       }
+
+       // Otherwise, if any of them is at the same depth level, it's an error.
+       for _, i := range conflicts {
+               oldf := &tinfo.fields[i]
+               if len(oldf.idx) == len(newf.idx) {
+                       f1 := typ.FieldByIndex(oldf.idx)
+                       f2 := typ.FieldByIndex(newf.idx)
+                       return &TagPathError{typ, f1.Name, f1.Tag.Get("xml"), f2.Name, f2.Tag.Get("xml")}
+               }
+       }
+
+       // Otherwise, the new field is shallower, and thus takes precedence,
+       // so drop the conflicting fields from tinfo and append the new one.
+       for c := len(conflicts) - 1; c >= 0; c-- {
+               i := conflicts[c]
+               copy(tinfo.fields[i:], tinfo.fields[i+1:])
+               tinfo.fields = tinfo.fields[:len(tinfo.fields)-1]
+       }
+       tinfo.fields = append(tinfo.fields, *newf)
+       return nil
+}
+
+// A TagPathError represents an error in the unmarshalling process
+// caused by the use of field tags with conflicting paths.
+type TagPathError struct {
+       Struct       reflect.Type
+       Field1, Tag1 string
+       Field2, Tag2 string
+}
+
+func (e *TagPathError) Error() string {
+       return fmt.Sprintf("%s field %q with tag %q conflicts with field %q with tag %q", e.Struct, e.Field1, e.Tag1, e.Field2, e.Tag2)
+}
index 25ffc917dcb135c3e8068405af0510d6a2530ee7..524d4dda4f4f8aabfd9c749b711ee34fe4a5cfd5 100644 (file)
@@ -154,36 +154,8 @@ var xmlInput = []string{
        "<t>cdata]]></t>",
 }
 
-type stringReader struct {
-       s   string
-       off int
-}
-
-func (r *stringReader) Read(b []byte) (n int, err error) {
-       if r.off >= len(r.s) {
-               return 0, io.EOF
-       }
-       for r.off < len(r.s) && n < len(b) {
-               b[n] = r.s[r.off]
-               n++
-               r.off++
-       }
-       return
-}
-
-func (r *stringReader) ReadByte() (b byte, err error) {
-       if r.off >= len(r.s) {
-               return 0, io.EOF
-       }
-       b = r.s[r.off]
-       r.off++
-       return
-}
-
-func StringReader(s string) io.Reader { return &stringReader{s, 0} }
-
 func TestRawToken(t *testing.T) {
-       p := NewParser(StringReader(testInput))
+       p := NewParser(strings.NewReader(testInput))
        testRawToken(t, p, rawTokens)
 }
 
@@ -207,7 +179,7 @@ func (d *downCaser) Read(p []byte) (int, error) {
 
 func TestRawTokenAltEncoding(t *testing.T) {
        sawEncoding := ""
-       p := NewParser(StringReader(testInputAltEncoding))
+       p := NewParser(strings.NewReader(testInputAltEncoding))
        p.CharsetReader = func(charset string, input io.Reader) (io.Reader, error) {
                sawEncoding = charset
                if charset != "x-testing-uppercase" {
@@ -219,7 +191,7 @@ func TestRawTokenAltEncoding(t *testing.T) {
 }
 
 func TestRawTokenAltEncodingNoConverter(t *testing.T) {
-       p := NewParser(StringReader(testInputAltEncoding))
+       p := NewParser(strings.NewReader(testInputAltEncoding))
        token, err := p.RawToken()
        if token == nil {
                t.Fatalf("expected a token on first RawToken call")
@@ -286,7 +258,7 @@ var nestedDirectivesTokens = []Token{
 }
 
 func TestNestedDirectives(t *testing.T) {
-       p := NewParser(StringReader(nestedDirectivesInput))
+       p := NewParser(strings.NewReader(nestedDirectivesInput))
 
        for i, want := range nestedDirectivesTokens {
                have, err := p.Token()
@@ -300,7 +272,7 @@ func TestNestedDirectives(t *testing.T) {
 }
 
 func TestToken(t *testing.T) {
-       p := NewParser(StringReader(testInput))
+       p := NewParser(strings.NewReader(testInput))
 
        for i, want := range cookedTokens {
                have, err := p.Token()
@@ -315,7 +287,7 @@ func TestToken(t *testing.T) {
 
 func TestSyntax(t *testing.T) {
        for i := range xmlInput {
-               p := NewParser(StringReader(xmlInput[i]))
+               p := NewParser(strings.NewReader(xmlInput[i]))
                var err error
                for _, err = p.Token(); err == nil; _, err = p.Token() {
                }
@@ -372,26 +344,26 @@ var all = allScalars{
 var sixteen = "16"
 
 const testScalarsInput = `<allscalars>
-       <true1>true</true1>
-       <true2>1</true2>
-       <false1>false</false1>
-       <false2>0</false2>
-       <int>1</int>
-       <int8>-2</int8>
-       <int16>3</int16>
-       <int32>-4</int32>
-       <int64>5</int64>
-       <uint>6</uint>
-       <uint8>7</uint8>
-       <uint16>8</uint16>
-       <uint32>9</uint32>
-       <uint64>10</uint64>
-       <uintptr>11</uintptr>
-       <float>12.0</float>
-       <float32>13.0</float32>
-       <float64>14.0</float64>
-       <string>15</string>
-       <ptrstring>16</ptrstring>
+       <True1>true</True1>
+       <True2>1</True2>
+       <False1>false</False1>
+       <False2>0</False2>
+       <Int>1</Int>
+       <Int8>-2</Int8>
+       <Int16>3</Int16>
+       <Int32>-4</Int32>
+       <Int64>5</Int64>
+       <Uint>6</Uint>
+       <Uint8>7</Uint8>
+       <Uint16>8</Uint16>
+       <Uint32>9</Uint32>
+       <Uint64>10</Uint64>
+       <Uintptr>11</Uintptr>
+       <Float>12.0</Float>
+       <Float32>13.0</Float32>
+       <Float64>14.0</Float64>
+       <String>15</String>
+       <PtrString>16</PtrString>
 </allscalars>`
 
 func TestAllScalars(t *testing.T) {
@@ -412,7 +384,7 @@ type item struct {
 }
 
 func TestIssue569(t *testing.T) {
-       data := `<item><field_a>abcd</field_a></item>`
+       data := `<item><Field_a>abcd</Field_a></item>`
        var i item
        buf := bytes.NewBufferString(data)
        err := Unmarshal(buf, &i)
@@ -424,7 +396,7 @@ func TestIssue569(t *testing.T) {
 
 func TestUnquotedAttrs(t *testing.T) {
        data := "<tag attr=azAZ09:-_\t>"
-       p := NewParser(StringReader(data))
+       p := NewParser(strings.NewReader(data))
        p.Strict = false
        token, err := p.Token()
        if _, ok := err.(*SyntaxError); ok {
@@ -450,7 +422,7 @@ func TestValuelessAttrs(t *testing.T) {
                {"<input checked />", "input", "checked"},
        }
        for _, test := range tests {
-               p := NewParser(StringReader(test[0]))
+               p := NewParser(strings.NewReader(test[0]))
                p.Strict = false
                token, err := p.Token()
                if _, ok := err.(*SyntaxError); ok {
@@ -500,7 +472,7 @@ func TestCopyTokenStartElement(t *testing.T) {
 
 func TestSyntaxErrorLineNum(t *testing.T) {
        testInput := "<P>Foo<P>\n\n<P>Bar</>\n"
-       p := NewParser(StringReader(testInput))
+       p := NewParser(strings.NewReader(testInput))
        var err error
        for _, err = p.Token(); err == nil; _, err = p.Token() {
        }
@@ -515,7 +487,7 @@ func TestSyntaxErrorLineNum(t *testing.T) {
 
 func TestTrailingRawToken(t *testing.T) {
        input := `<FOO></FOO>  `
-       p := NewParser(StringReader(input))
+       p := NewParser(strings.NewReader(input))
        var err error
        for _, err = p.RawToken(); err == nil; _, err = p.RawToken() {
        }
@@ -526,7 +498,7 @@ func TestTrailingRawToken(t *testing.T) {
 
 func TestTrailingToken(t *testing.T) {
        input := `<FOO></FOO>  `
-       p := NewParser(StringReader(input))
+       p := NewParser(strings.NewReader(input))
        var err error
        for _, err = p.Token(); err == nil; _, err = p.Token() {
        }
@@ -537,7 +509,7 @@ func TestTrailingToken(t *testing.T) {
 
 func TestEntityInsideCDATA(t *testing.T) {
        input := `<test><![CDATA[ &val=foo ]]></test>`
-       p := NewParser(StringReader(input))
+       p := NewParser(strings.NewReader(input))
        var err error
        for _, err = p.Token(); err == nil; _, err = p.Token() {
        }
@@ -569,7 +541,7 @@ var characterTests = []struct {
 func TestDisallowedCharacters(t *testing.T) {
 
        for i, tt := range characterTests {
-               p := NewParser(StringReader(tt.in))
+               p := NewParser(strings.NewReader(tt.in))
                var err error
 
                for err == nil {
index ce159e9050cea671ecc1953a8f12207c90623760..42e6f1b79471f8c590edc025e80fd8642f6d9d6f 100644 (file)
@@ -8,7 +8,7 @@ import "unicode/utf8"
 
 type input interface {
        skipASCII(p int) int
-       skipNonStarter() int
+       skipNonStarter(p int) int
        appendSlice(buf []byte, s, e int) []byte
        copySlice(buf []byte, s, e int)
        charinfo(p int) (uint16, int)
@@ -25,8 +25,7 @@ func (s inputString) skipASCII(p int) int {
        return p
 }
 
-func (s inputString) skipNonStarter() int {
-       p := 0
+func (s inputString) skipNonStarter(p int) int {
        for ; p < len(s) && !utf8.RuneStart(s[p]); p++ {
        }
        return p
@@ -71,8 +70,7 @@ func (s inputBytes) skipASCII(p int) int {
        return p
 }
 
-func (s inputBytes) skipNonStarter() int {
-       p := 0
+func (s inputBytes) skipNonStarter(p int) int {
        for ; p < len(s) && !utf8.RuneStart(s[p]); p++ {
        }
        return p
index 25bb28d517fb6f1d8f5d9e921a066db0ae5f2624..3bd40470d5c351788b89368c8ccb6c3dbe916044 100644 (file)
@@ -34,24 +34,28 @@ const (
 
 // Bytes returns f(b). May return b if f(b) = b.
 func (f Form) Bytes(b []byte) []byte {
-       n := f.QuickSpan(b)
+       rb := reorderBuffer{}
+       rb.init(f, b)
+       n := quickSpan(&rb, 0)
        if n == len(b) {
                return b
        }
        out := make([]byte, n, len(b))
        copy(out, b[0:n])
-       return f.Append(out, b[n:]...)
+       return doAppend(&rb, out, n)
 }
 
 // String returns f(s).
 func (f Form) String(s string) string {
-       n := f.QuickSpanString(s)
+       rb := reorderBuffer{}
+       rb.initString(f, s)
+       n := quickSpan(&rb, 0)
        if n == len(s) {
                return s
        }
-       out := make([]byte, 0, len(s))
+       out := make([]byte, n, len(s))
        copy(out, s[0:n])
-       return string(f.AppendString(out, s[n:]))
+       return string(doAppend(&rb, out, n))
 }
 
 // IsNormal returns true if b == f(b).
@@ -122,23 +126,27 @@ func (f Form) IsNormalString(s string) bool {
 
 // patchTail fixes a case where a rune may be incorrectly normalized
 // if it is followed by illegal continuation bytes. It returns the
-// patched buffer and the number of trailing continuation bytes that
-// have been dropped.
-func patchTail(rb *reorderBuffer, buf []byte) ([]byte, int) {
+// patched buffer and whether there were trailing continuation bytes.
+func patchTail(rb *reorderBuffer, buf []byte) ([]byte, bool) {
        info, p := lastRuneStart(&rb.f, buf)
        if p == -1 || info.size == 0 {
-               return buf, 0
+               return buf, false
        }
        end := p + int(info.size)
        extra := len(buf) - end
        if extra > 0 {
+               // Potentially allocating memory. However, this only
+               // happens with ill-formed UTF-8.
+               x := make([]byte, 0)
+               x = append(x, buf[len(buf)-extra:]...)
                buf = decomposeToLastBoundary(rb, buf[:end])
                if rb.f.composing {
                        rb.compose()
                }
-               return rb.flush(buf), extra
+               buf = rb.flush(buf)
+               return append(buf, x...), true
        }
-       return buf, 0
+       return buf, false
 }
 
 func appendQuick(rb *reorderBuffer, dst []byte, i int) ([]byte, int) {
@@ -157,23 +165,23 @@ func (f Form) Append(out []byte, src ...byte) []byte {
        }
        rb := reorderBuffer{}
        rb.init(f, src)
-       return doAppend(&rb, out)
+       return doAppend(&rb, out, 0)
 }
 
-func doAppend(rb *reorderBuffer, out []byte) []byte {
+func doAppend(rb *reorderBuffer, out []byte, p int) []byte {
        src, n := rb.src, rb.nsrc
        doMerge := len(out) > 0
-       p := 0
-       if p = src.skipNonStarter(); p > 0 {
+       if q := src.skipNonStarter(p); q > p {
                // Move leading non-starters to destination.
-               out = src.appendSlice(out, 0, p)
-               buf, ndropped := patchTail(rb, out)
-               if ndropped > 0 {
-                       out = src.appendSlice(buf, p-ndropped, p)
+               out = src.appendSlice(out, p, q)
+               buf, endsInError := patchTail(rb, out)
+               if endsInError {
+                       out = buf
                        doMerge = false // no need to merge, ends with illegal UTF-8
                } else {
                        out = decomposeToLastBoundary(rb, buf) // force decomposition
                }
+               p = q
        }
        fd := &rb.f
        if doMerge {
@@ -217,7 +225,7 @@ func (f Form) AppendString(out []byte, src string) []byte {
        }
        rb := reorderBuffer{}
        rb.initString(f, src)
-       return doAppend(&rb, out)
+       return doAppend(&rb, out, 0)
 }
 
 // QuickSpan returns a boundary n such that b[0:n] == f(b[0:n]).
@@ -225,7 +233,8 @@ func (f Form) AppendString(out []byte, src string) []byte {
 func (f Form) QuickSpan(b []byte) int {
        rb := reorderBuffer{}
        rb.init(f, b)
-       return quickSpan(&rb, 0)
+       n := quickSpan(&rb, 0)
+       return n
 }
 
 func quickSpan(rb *reorderBuffer, i int) int {
@@ -301,7 +310,7 @@ func (f Form) FirstBoundary(b []byte) int {
 
 func firstBoundary(rb *reorderBuffer) int {
        src, nsrc := rb.src, rb.nsrc
-       i := src.skipNonStarter()
+       i := src.skipNonStarter(0)
        if i >= nsrc {
                return -1
        }
index 6bd5292d3fb81820d1b24e191ac954bd6bbe80ee..2e0c1f17120e7f11a473069362603ee734044feb 100644 (file)
@@ -253,7 +253,7 @@ var quickSpanNFDTests = []PositionTest{
        {"\u0316\u0300cd", 6, ""},
        {"\u043E\u0308b", 5, ""},
        // incorrectly ordered combining characters
-       {"ab\u0300\u0316", 1, ""}, // TODO(mpvl): we could skip 'b' as well.
+       {"ab\u0300\u0316", 1, ""}, // TODO: we could skip 'b' as well.
        {"ab\u0300\u0316cd", 1, ""},
        // Hangul
        {"같은", 0, ""},
@@ -465,6 +465,7 @@ var appendTests = []AppendTest{
        {"\u0300", "\xFC\x80\x80\x80\x80\x80\u0300", "\u0300\xFC\x80\x80\x80\x80\x80\u0300"},
        {"\xF8\x80\x80\x80\x80\u0300", "\u0300", "\xF8\x80\x80\x80\x80\u0300\u0300"},
        {"\xFC\x80\x80\x80\x80\x80\u0300", "\u0300", "\xFC\x80\x80\x80\x80\x80\u0300\u0300"},
+       {"\xF8\x80\x80\x80", "\x80\u0300\u0300", "\xF8\x80\x80\x80\x80\u0300\u0300"},
 }
 
 func appendF(f Form, out []byte, s string) []byte {
@@ -475,9 +476,23 @@ func appendStringF(f Form, out []byte, s string) []byte {
        return f.AppendString(out, s)
 }
 
+func bytesF(f Form, out []byte, s string) []byte {
+       buf := []byte{}
+       buf = append(buf, out...)
+       buf = append(buf, s...)
+       return f.Bytes(buf)
+}
+
+func stringF(f Form, out []byte, s string) []byte {
+       outs := string(out) + s
+       return []byte(f.String(outs))
+}
+
 func TestAppend(t *testing.T) {
        runAppendTests(t, "TestAppend", NFKC, appendF, appendTests)
        runAppendTests(t, "TestAppendString", NFKC, appendStringF, appendTests)
+       runAppendTests(t, "TestBytes", NFKC, bytesF, appendTests)
+       runAppendTests(t, "TestString", NFKC, stringF, appendTests)
 }
 
 func doFormBenchmark(b *testing.B, f Form, s string) {
index ee58abd22de1f9e2798c1ef3ddb9a662e7bed40f..2682894de0b1d19781cdd44184eb2be2c4bb5356 100644 (file)
@@ -27,7 +27,7 @@ func (w *normWriter) Write(data []byte) (n int, err error) {
                }
                w.rb.src = inputBytes(data[:m])
                w.rb.nsrc = m
-               w.buf = doAppend(&w.rb, w.buf)
+               w.buf = doAppend(&w.rb, w.buf, 0)
                data = data[m:]
                n += m
 
@@ -101,7 +101,7 @@ func (r *normReader) Read(p []byte) (int, error) {
                r.rb.src = inputBytes(r.inbuf[0:n])
                r.rb.nsrc, r.err = n, err
                if n > 0 {
-                       r.outbuf = doAppend(&r.rb, r.outbuf)
+                       r.outbuf = doAppend(&r.rb, r.outbuf, 0)
                }
                if err == io.EOF {
                        r.lastBoundary = len(r.outbuf)
diff --git a/libgo/go/exp/proxy/direct.go b/libgo/go/exp/proxy/direct.go
new file mode 100644 (file)
index 0000000..4c5ad88
--- /dev/null
@@ -0,0 +1,18 @@
+// Copyright 2011 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+package proxy
+
+import (
+       "net"
+)
+
+type direct struct{}
+
+// Direct is a direct proxy: one that makes network connections directly.
+var Direct = direct{}
+
+func (direct) Dial(network, addr string) (net.Conn, error) {
+       return net.Dial(network, addr)
+}
diff --git a/libgo/go/exp/proxy/per_host.go b/libgo/go/exp/proxy/per_host.go
new file mode 100644 (file)
index 0000000..397ef57
--- /dev/null
@@ -0,0 +1,140 @@
+// Copyright 2011 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+package proxy
+
+import (
+       "net"
+       "strings"
+)
+
+// A PerHost directs connections to a default Dailer unless the hostname
+// requested matches one of a number of exceptions.
+type PerHost struct {
+       def, bypass Dialer
+
+       bypassNetworks []*net.IPNet
+       bypassIPs      []net.IP
+       bypassZones    []string
+       bypassHosts    []string
+}
+
+// NewPerHost returns a PerHost Dialer that directs connections to either
+// defaultDialer or bypass, depending on whether the connection matches one of
+// the configured rules.
+func NewPerHost(defaultDialer, bypass Dialer) *PerHost {
+       return &PerHost{
+               def:    defaultDialer,
+               bypass: bypass,
+       }
+}
+
+// Dial connects to the address addr on the network net through either
+// defaultDialer or bypass.
+func (p *PerHost) Dial(network, addr string) (c net.Conn, err error) {
+       host, _, err := net.SplitHostPort(addr)
+       if err != nil {
+               return nil, err
+       }
+
+       return p.dialerForRequest(host).Dial(network, addr)
+}
+
+func (p *PerHost) dialerForRequest(host string) Dialer {
+       if ip := net.ParseIP(host); ip != nil {
+               for _, net := range p.bypassNetworks {
+                       if net.Contains(ip) {
+                               return p.bypass
+                       }
+               }
+               for _, bypassIP := range p.bypassIPs {
+                       if bypassIP.Equal(ip) {
+                               return p.bypass
+                       }
+               }
+               return p.def
+       }
+
+       for _, zone := range p.bypassZones {
+               if strings.HasSuffix(host, zone) {
+                       return p.bypass
+               }
+               if host == zone[1:] {
+                       // For a zone "example.com", we match "example.com"
+                       // too.
+                       return p.bypass
+               }
+       }
+       for _, bypassHost := range p.bypassHosts {
+               if bypassHost == host {
+                       return p.bypass
+               }
+       }
+       return p.def
+}
+
+// AddFromString parses a string that contains comma-separated values
+// specifing hosts that should use the bypass proxy. Each value is either an
+// IP address, a CIDR range, a zone (*.example.com) or a hostname
+// (localhost). A best effort is made to parse the string and errors are
+// ignored.
+func (p *PerHost) AddFromString(s string) {
+       hosts := strings.Split(s, ",")
+       for _, host := range hosts {
+               host = strings.TrimSpace(host)
+               if len(host) == 0 {
+                       continue
+               }
+               if strings.Contains(host, "/") {
+                       // We assume that it's a CIDR address like 127.0.0.0/8
+                       if _, net, err := net.ParseCIDR(host); err == nil {
+                               p.AddNetwork(net)
+                       }
+                       continue
+               }
+               if ip := net.ParseIP(host); ip != nil {
+                       p.AddIP(ip)
+                       continue
+               }
+               if strings.HasPrefix(host, "*.") {
+                       p.AddZone(host[1:])
+                       continue
+               }
+               p.AddHost(host)
+       }
+}
+
+// AddIP specifies an IP address that will use the bypass proxy. Note that
+// this will only take effect if a literal IP address is dialed. A connection
+// to a named host will never match an IP.
+func (p *PerHost) AddIP(ip net.IP) {
+       p.bypassIPs = append(p.bypassIPs, ip)
+}
+
+// AddIP specifies an IP range that will use the bypass proxy. Note that this
+// will only take effect if a literal IP address is dialed. A connection to a
+// named host will never match.
+func (p *PerHost) AddNetwork(net *net.IPNet) {
+       p.bypassNetworks = append(p.bypassNetworks, net)
+}
+
+// AddZone specifies a DNS suffix that will use the bypass proxy. A zone of
+// "example.com" matches "example.com" and all of its subdomains.
+func (p *PerHost) AddZone(zone string) {
+       if strings.HasSuffix(zone, ".") {
+               zone = zone[:len(zone)-1]
+       }
+       if !strings.HasPrefix(zone, ".") {
+               zone = "." + zone
+       }
+       p.bypassZones = append(p.bypassZones, zone)
+}
+
+// AddHost specifies a hostname that will use the bypass proxy.
+func (p *PerHost) AddHost(host string) {
+       if strings.HasSuffix(host, ".") {
+               host = host[:len(host)-1]
+       }
+       p.bypassHosts = append(p.bypassHosts, host)
+}
diff --git a/libgo/go/exp/proxy/per_host_test.go b/libgo/go/exp/proxy/per_host_test.go
new file mode 100644 (file)
index 0000000..a7d8095
--- /dev/null
@@ -0,0 +1,55 @@
+// Copyright 2011 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+package proxy
+
+import (
+       "errors"
+       "net"
+       "reflect"
+       "testing"
+)
+
+type recordingProxy struct {
+       addrs []string
+}
+
+func (r *recordingProxy) Dial(network, addr string) (net.Conn, error) {
+       r.addrs = append(r.addrs, addr)
+       return nil, errors.New("recordingProxy")
+}
+
+func TestPerHost(t *testing.T) {
+       var def, bypass recordingProxy
+       perHost := NewPerHost(&def, &bypass)
+       perHost.AddFromString("localhost,*.zone,127.0.0.1,10.0.0.1/8,1000::/16")
+
+       expectedDef := []string{
+               "example.com:123",
+               "1.2.3.4:123",
+               "[1001::]:123",
+       }
+       expectedBypass := []string{
+               "localhost:123",
+               "zone:123",
+               "foo.zone:123",
+               "127.0.0.1:123",
+               "10.1.2.3:123",
+               "[1000::]:123",
+       }
+
+       for _, addr := range expectedDef {
+               perHost.Dial("tcp", addr)
+       }
+       for _, addr := range expectedBypass {
+               perHost.Dial("tcp", addr)
+       }
+
+       if !reflect.DeepEqual(expectedDef, def.addrs) {
+               t.Errorf("Hosts which went to the default proxy didn't match. Got %v, want %v", def.addrs, expectedDef)
+       }
+       if !reflect.DeepEqual(expectedBypass, bypass.addrs) {
+               t.Errorf("Hosts which went to the bypass proxy didn't match. Got %v, want %v", bypass.addrs, expectedBypass)
+       }
+}
diff --git a/libgo/go/exp/proxy/proxy.go b/libgo/go/exp/proxy/proxy.go
new file mode 100644 (file)
index 0000000..ccd3d1d
--- /dev/null
@@ -0,0 +1,98 @@
+// Copyright 2011 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+// Package proxy provides support for a variety of protocols to proxy network
+// data.
+package proxy
+
+import (
+       "errors"
+       "net"
+       "net/url"
+       "os"
+       "strings"
+)
+
+// A Dialer is a means to establish a connection.
+type Dialer interface {
+       // Dial connects to the given address via the proxy.
+       Dial(network, addr string) (c net.Conn, err error)
+}
+
+// Auth contains authentication parameters that specific Dialers may require.
+type Auth struct {
+       User, Password string
+}
+
+// DefaultDialer returns the dialer specified by the proxy related variables in
+// the environment.
+func FromEnvironment() Dialer {
+       allProxy := os.Getenv("all_proxy")
+       if len(allProxy) == 0 {
+               return Direct
+       }
+
+       proxyURL, err := url.Parse(allProxy)
+       if err != nil {
+               return Direct
+       }
+       proxy, err := FromURL(proxyURL, Direct)
+       if err != nil {
+               return Direct
+       }
+
+       noProxy := os.Getenv("no_proxy")
+       if len(noProxy) == 0 {
+               return proxy
+       }
+
+       perHost := NewPerHost(proxy, Direct)
+       perHost.AddFromString(noProxy)
+       return perHost
+}
+
+// proxySchemes is a map from URL schemes to a function that creates a Dialer
+// from a URL with such a scheme.
+var proxySchemes map[string]func(*url.URL, Dialer) (Dialer, error)
+
+// RegisterDialerType takes a URL scheme and a function to generate Dialers from
+// a URL with that scheme and a forwarding Dialer. Registered schemes are used
+// by FromURL.
+func RegisterDialerType(scheme string, f func(*url.URL, Dialer) (Dialer, error)) {
+       if proxySchemes == nil {
+               proxySchemes = make(map[string]func(*url.URL, Dialer) (Dialer, error))
+       }
+       proxySchemes[scheme] = f
+}
+
+// FromURL returns a Dialer given a URL specification and an underlying
+// Dialer for it to make network requests.
+func FromURL(u *url.URL, forward Dialer) (Dialer, error) {
+       var auth *Auth
+       if len(u.RawUserinfo) > 0 {
+               auth = new(Auth)
+               parts := strings.SplitN(u.RawUserinfo, ":", 1)
+               if len(parts) == 1 {
+                       auth.User = parts[0]
+               } else if len(parts) >= 2 {
+                       auth.User = parts[0]
+                       auth.Password = parts[1]
+               }
+       }
+
+       switch u.Scheme {
+       case "socks5":
+               return SOCKS5("tcp", u.Host, auth, forward)
+       }
+
+       // If the scheme doesn't match any of the built-in schemes, see if it
+       // was registered by another package.
+       if proxySchemes != nil {
+               if f, ok := proxySchemes[u.Scheme]; ok {
+                       return f(u, forward)
+               }
+       }
+
+       return nil, errors.New("proxy: unknown scheme: " + u.Scheme)
+}
diff --git a/libgo/go/exp/proxy/proxy_test.go b/libgo/go/exp/proxy/proxy_test.go
new file mode 100644 (file)
index 0000000..4078bc7
--- /dev/null
@@ -0,0 +1,50 @@
+// Copyright 2011 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+package proxy
+
+import (
+       "net"
+       "net/url"
+       "testing"
+)
+
+type testDialer struct {
+       network, addr string
+}
+
+func (t *testDialer) Dial(network, addr string) (net.Conn, error) {
+       t.network = network
+       t.addr = addr
+       return nil, t
+}
+
+func (t *testDialer) Error() string {
+       return "testDialer " + t.network + " " + t.addr
+}
+
+func TestFromURL(t *testing.T) {
+       u, err := url.Parse("socks5://user:password@1.2.3.4:5678")
+       if err != nil {
+               t.Fatalf("failed to parse URL: %s", err)
+       }
+
+       tp := &testDialer{}
+       proxy, err := FromURL(u, tp)
+       if err != nil {
+               t.Fatalf("FromURL failed: %s", err)
+       }
+
+       conn, err := proxy.Dial("tcp", "example.com:80")
+       if conn != nil {
+               t.Error("Dial unexpected didn't return an error")
+       }
+       if tp, ok := err.(*testDialer); ok {
+               if tp.network != "tcp" || tp.addr != "1.2.3.4:5678" {
+                       t.Errorf("Dialer connected to wrong host. Wanted 1.2.3.4:5678, got: %v", tp)
+               }
+       } else {
+               t.Errorf("Unexpected error from Dial: %s", err)
+       }
+}
diff --git a/libgo/go/exp/proxy/socks5.go b/libgo/go/exp/proxy/socks5.go
new file mode 100644 (file)
index 0000000..466e135
--- /dev/null
@@ -0,0 +1,207 @@
+// Copyright 2011 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+package proxy
+
+import (
+       "errors"
+       "io"
+       "net"
+       "strconv"
+)
+
+// SOCKS5 returns a Dialer that makes SOCKSv5 connections to the given address
+// with an optional username and password. See RFC 1928.
+func SOCKS5(network, addr string, auth *Auth, forward Dialer) (Dialer, error) {
+       s := &socks5{
+               network: network,
+               addr:    addr,
+               forward: forward,
+       }
+       if auth != nil {
+               s.user = auth.User
+               s.password = auth.Password
+       }
+
+       return s, nil
+}
+
+type socks5 struct {
+       user, password string
+       network, addr  string
+       forward        Dialer
+}
+
+const socks5Version = 5
+
+const (
+       socks5AuthNone     = 0
+       socks5AuthPassword = 2
+)
+
+const socks5Connect = 1
+
+const (
+       socks5IP4    = 1
+       socks5Domain = 3
+       socks5IP6    = 4
+)
+
+var socks5Errors = []string{
+       "",
+       "general failure",
+       "connection forbidden",
+       "network unreachable",
+       "host unreachable",
+       "connection refused",
+       "TTL expired",
+       "command not supported",
+       "address type not supported",
+}
+
+// Dial connects to the address addr on the network net via the SOCKS5 proxy.
+func (s *socks5) Dial(network, addr string) (net.Conn, error) {
+       switch network {
+       case "tcp", "tcp6", "tcp4":
+               break
+       default:
+               return nil, errors.New("proxy: no support for SOCKS5 proxy connections of type " + network)
+       }
+
+       conn, err := s.forward.Dial(s.network, s.addr)
+       if err != nil {
+               return nil, err
+       }
+       closeConn := &conn
+       defer func() {
+               if closeConn != nil {
+                       (*closeConn).Close()
+               }
+       }()
+
+       host, portStr, err := net.SplitHostPort(addr)
+       if err != nil {
+               return nil, err
+       }
+
+       port, err := strconv.Atoi(portStr)
+       if err != nil {
+               return nil, errors.New("proxy: failed to parse port number: " + portStr)
+       }
+       if port < 1 || port > 0xffff {
+               return nil, errors.New("proxy: port number out of range: " + portStr)
+       }
+
+       // the size here is just an estimate
+       buf := make([]byte, 0, 6+len(host))
+
+       buf = append(buf, socks5Version)
+       if len(s.user) > 0 && len(s.user) < 256 && len(s.password) < 256 {
+               buf = append(buf, 2, /* num auth methods */ socks5AuthNone, socks5AuthPassword)
+       } else {
+               buf = append(buf, 1, /* num auth methods */ socks5AuthNone)
+       }
+
+       if _, err = conn.Write(buf); err != nil {
+               return nil, errors.New("proxy: failed to write greeting to SOCKS5 proxy at " + s.addr + ": " + err.Error())
+       }
+
+       if _, err = io.ReadFull(conn, buf[:2]); err != nil {
+               return nil, errors.New("proxy: failed to read greeting from SOCKS5 proxy at " + s.addr + ": " + err.Error())
+       }
+       if buf[0] != 5 {
+               return nil, errors.New("proxy: SOCKS5 proxy at " + s.addr + " has unexpected version " + strconv.Itoa(int(buf[0])))
+       }
+       if buf[1] == 0xff {
+               return nil, errors.New("proxy: SOCKS5 proxy at " + s.addr + " requires authentication")
+       }
+
+       if buf[1] == socks5AuthPassword {
+               buf = buf[:0]
+               buf = append(buf, socks5Version)
+               buf = append(buf, uint8(len(s.user)))
+               buf = append(buf, s.user...)
+               buf = append(buf, uint8(len(s.password)))
+               buf = append(buf, s.password...)
+
+               if _, err = conn.Write(buf); err != nil {
+                       return nil, errors.New("proxy: failed to write authentication request to SOCKS5 proxy at " + s.addr + ": " + err.Error())
+               }
+
+               if _, err = io.ReadFull(conn, buf[:2]); err != nil {
+                       return nil, errors.New("proxy: failed to read authentication reply from SOCKS5 proxy at " + s.addr + ": " + err.Error())
+               }
+
+               if buf[1] != 0 {
+                       return nil, errors.New("proxy: SOCKS5 proxy at " + s.addr + " rejected username/password")
+               }
+       }
+
+       buf = buf[:0]
+       buf = append(buf, socks5Version, socks5Connect, 0 /* reserved */ )
+
+       if ip := net.ParseIP(host); ip != nil {
+               if len(ip) == 4 {
+                       buf = append(buf, socks5IP4)
+               } else {
+                       buf = append(buf, socks5IP6)
+               }
+               buf = append(buf, []byte(ip)...)
+       } else {
+               buf = append(buf, socks5Domain)
+               buf = append(buf, byte(len(host)))
+               buf = append(buf, host...)
+       }
+       buf = append(buf, byte(port>>8), byte(port))
+
+       if _, err = conn.Write(buf); err != nil {
+               return nil, errors.New("proxy: failed to write connect request to SOCKS5 proxy at " + s.addr + ": " + err.Error())
+       }
+
+       if _, err = io.ReadFull(conn, buf[:4]); err != nil {
+               return nil, errors.New("proxy: failed to read connect reply from SOCKS5 proxy at " + s.addr + ": " + err.Error())
+       }
+
+       failure := "unknown error"
+       if int(buf[1]) < len(socks5Errors) {
+               failure = socks5Errors[buf[1]]
+       }
+
+       if len(failure) > 0 {
+               return nil, errors.New("proxy: SOCKS5 proxy at " + s.addr + " failed to connect: " + failure)
+       }
+
+       bytesToDiscard := 0
+       switch buf[3] {
+       case socks5IP4:
+               bytesToDiscard = 4
+       case socks5IP6:
+               bytesToDiscard = 16
+       case socks5Domain:
+               _, err := io.ReadFull(conn, buf[:1])
+               if err != nil {
+                       return nil, errors.New("proxy: failed to read domain length from SOCKS5 proxy at " + s.addr + ": " + err.Error())
+               }
+               bytesToDiscard = int(buf[0])
+       default:
+               return nil, errors.New("proxy: got unknown address type " + strconv.Itoa(int(buf[3])) + " from SOCKS5 proxy at " + s.addr)
+       }
+
+       if cap(buf) < bytesToDiscard {
+               buf = make([]byte, bytesToDiscard)
+       } else {
+               buf = buf[:bytesToDiscard]
+       }
+       if _, err = io.ReadFull(conn, buf); err != nil {
+               return nil, errors.New("proxy: failed to read address from SOCKS5 proxy at " + s.addr + ": " + err.Error())
+       }
+
+       // Also need to discard the port number
+       if _, err = io.ReadFull(conn, buf[:2]); err != nil {
+               return nil, errors.New("proxy: failed to read port from SOCKS5 proxy at " + s.addr + ": " + err.Error())
+       }
+
+       closeConn = nil
+       return conn, nil
+}
index bed09ffb29de944b34f2ba02568aff068f51c935..702ba4399d5445bb87494c8b6cbaa63e94dbc2ec 100644 (file)
@@ -8,8 +8,11 @@ import (
        "fmt"
        "reflect"
        "testing"
+       "time"
 )
 
+var someTime = time.Unix(123, 0)
+
 type conversionTest struct {
        s, d interface{} // source and destination
 
@@ -19,6 +22,7 @@ type conversionTest struct {
        wantstr  string
        wantf32  float32
        wantf64  float64
+       wanttime time.Time
        wantbool bool // used if d is of type *bool
        wanterr  string
 }
@@ -35,12 +39,14 @@ var (
        scanbool   bool
        scanf32    float32
        scanf64    float64
+       scantime   time.Time
 )
 
 var conversionTests = []conversionTest{
        // Exact conversions (destination pointer type matches source type)
        {s: "foo", d: &scanstr, wantstr: "foo"},
        {s: 123, d: &scanint, wantint: 123},
+       {s: someTime, d: &scantime, wanttime: someTime},
 
        // To strings
        {s: []byte("byteslice"), d: &scanstr, wantstr: "byteslice"},
@@ -106,6 +112,10 @@ func float32Value(ptr interface{}) float32 {
        return *(ptr.(*float32))
 }
 
+func timeValue(ptr interface{}) time.Time {
+       return *(ptr.(*time.Time))
+}
+
 func TestConversions(t *testing.T) {
        for n, ct := range conversionTests {
                err := convertAssign(ct.d, ct.s)
@@ -138,6 +148,9 @@ func TestConversions(t *testing.T) {
                if bp, boolTest := ct.d.(*bool); boolTest && *bp != ct.wantbool && ct.wanterr == "" {
                        errf("want bool %v, got %v", ct.wantbool, *bp)
                }
+               if !ct.wanttime.IsZero() && !ct.wanttime.Equal(timeValue(ct.d)) {
+                       errf("want time %v, got %v", ct.wanttime, timeValue(ct.d))
+               }
        }
 }
 
index f0bcca291065d17c058b1cca1ba27d2ea5f5ad03..0cd2562d6822d360b58fc663a66da32df04ba861 100644 (file)
@@ -16,6 +16,7 @@
 //   nil
 //   []byte
 //   string   [*] everywhere except from Rows.Next.
+//   time.Time
 //
 package driver
 
index 086b529c84f1d74525ea891e6314f2039e41a5eb..d6ba641cb269e7bce0e79a05f16edd805953e8e5 100644 (file)
@@ -8,6 +8,7 @@ import (
        "fmt"
        "reflect"
        "strconv"
+       "time"
 )
 
 // ValueConverter is the interface providing the ConvertValue method.
@@ -39,7 +40,7 @@ type ValueConverter interface {
 //       1 is true
 //       0 is false,
 //       other integers are an error
-//  - for strings and []byte, same rules as strconv.Atob
+//  - for strings and []byte, same rules as strconv.ParseBool
 //  - all other types are an error
 var Bool boolType
 
@@ -143,9 +144,10 @@ func (stringType) ConvertValue(v interface{}) (interface{}, error) {
 //   bool
 //   nil
 //   []byte
+//   time.Time
 //   string
 //
-// This is the ame list as IsScanSubsetType, with the addition of
+// This is the same list as IsScanSubsetType, with the addition of
 // string.
 func IsParameterSubsetType(v interface{}) bool {
        if IsScanSubsetType(v) {
@@ -165,6 +167,7 @@ func IsParameterSubsetType(v interface{}) bool {
 //   bool
 //   nil
 //   []byte
+//   time.Time
 //
 // This is the same list as IsParameterSubsetType, without string.
 func IsScanSubsetType(v interface{}) bool {
@@ -172,7 +175,7 @@ func IsScanSubsetType(v interface{}) bool {
                return true
        }
        switch v.(type) {
-       case int64, float64, []byte, bool:
+       case int64, float64, []byte, bool, time.Time:
                return true
        }
        return false
index 4b049e26e5131b64870ccc99f7ad551af592ce54..966bc6b45877eaf303e4bc5ec40aaff477637221 100644 (file)
@@ -7,6 +7,7 @@ package driver
 import (
        "reflect"
        "testing"
+       "time"
 )
 
 type valueConverterTest struct {
@@ -16,6 +17,8 @@ type valueConverterTest struct {
        err string
 }
 
+var now = time.Now()
+
 var valueConverterTests = []valueConverterTest{
        {Bool, "true", true, ""},
        {Bool, "True", true, ""},
@@ -33,6 +36,7 @@ var valueConverterTests = []valueConverterTest{
        {Bool, uint16(0), false, ""},
        {c: Bool, in: "foo", err: "sql/driver: couldn't convert \"foo\" into type bool"},
        {c: Bool, in: 2, err: "sql/driver: couldn't convert 2 into type bool"},
+       {DefaultParameterConverter, now, now, ""},
 }
 
 func TestValueConverters(t *testing.T) {
index 2474a86f644c46fc5d5988fad918fd7bcf82c9b3..70aa68c1385cd0fa1b843421d0d5e8aac682e89b 100644 (file)
@@ -12,6 +12,7 @@ import (
        "strconv"
        "strings"
        "sync"
+       "time"
 
        "exp/sql/driver"
 )
@@ -77,6 +78,17 @@ type fakeConn struct {
        db *fakeDB // where to return ourselves to
 
        currTx *fakeTx
+
+       // Stats for tests:
+       mu          sync.Mutex
+       stmtsMade   int
+       stmtsClosed int
+}
+
+func (c *fakeConn) incrStat(v *int) {
+       c.mu.Lock()
+       *v++
+       c.mu.Unlock()
 }
 
 type fakeTx struct {
@@ -110,25 +122,34 @@ func init() {
 
 // Supports dsn forms:
 //    <dbname>
-//    <dbname>;wipe
+//    <dbname>;<opts>  (no currently supported options)
 func (d *fakeDriver) Open(dsn string) (driver.Conn, error) {
-       d.mu.Lock()
-       defer d.mu.Unlock()
-       d.openCount++
-       if d.dbs == nil {
-               d.dbs = make(map[string]*fakeDB)
-       }
        parts := strings.Split(dsn, ";")
        if len(parts) < 1 {
                return nil, errors.New("fakedb: no database name")
        }
        name := parts[0]
+
+       db := d.getDB(name)
+
+       d.mu.Lock()
+       d.openCount++
+       d.mu.Unlock()
+       return &fakeConn{db: db}, nil
+}
+
+func (d *fakeDriver) getDB(name string) *fakeDB {
+       d.mu.Lock()
+       defer d.mu.Unlock()
+       if d.dbs == nil {
+               d.dbs = make(map[string]*fakeDB)
+       }
        db, ok := d.dbs[name]
        if !ok {
                db = &fakeDB{name: name}
                d.dbs[name] = db
        }
-       return &fakeConn{db: db}, nil
+       return db
 }
 
 func (db *fakeDB) wipe() {
@@ -200,7 +221,7 @@ func (c *fakeConn) Close() error {
 func checkSubsetTypes(args []interface{}) error {
        for n, arg := range args {
                switch arg.(type) {
-               case int64, float64, bool, nil, []byte, string:
+               case int64, float64, bool, nil, []byte, string, time.Time:
                default:
                        return fmt.Errorf("fakedb_test: invalid argument #%d: %v, type %T", n+1, arg, arg)
                }
@@ -297,6 +318,8 @@ func (c *fakeConn) prepareInsert(stmt *fakeStmt, parts []string) (driver.Stmt, e
                        switch ctype {
                        case "string":
                                subsetVal = []byte(value)
+                       case "blob":
+                               subsetVal = []byte(value)
                        case "int32":
                                i, err := strconv.Atoi(value)
                                if err != nil {
@@ -327,6 +350,7 @@ func (c *fakeConn) Prepare(query string) (driver.Stmt, error) {
        cmd := parts[0]
        parts = parts[1:]
        stmt := &fakeStmt{q: query, c: c, cmd: cmd}
+       c.incrStat(&c.stmtsMade)
        switch cmd {
        case "WIPE":
                // Nothing
@@ -347,7 +371,10 @@ func (s *fakeStmt) ColumnConverter(idx int) driver.ValueConverter {
 }
 
 func (s *fakeStmt) Close() error {
-       s.closed = true
+       if !s.closed {
+               s.c.incrStat(&s.c.stmtsClosed)
+               s.closed = true
+       }
        return nil
 }
 
@@ -501,9 +528,19 @@ type rowsCursor struct {
        pos    int
        rows   []*row
        closed bool
+
+       // a clone of slices to give out to clients, indexed by the
+       // the original slice's first byte address.  we clone them
+       // just so we're able to corrupt them on close.
+       bytesClone map[*byte][]byte
 }
 
 func (rc *rowsCursor) Close() error {
+       if !rc.closed {
+               for _, bs := range rc.bytesClone {
+                       bs[0] = 255 // first byte corrupted
+               }
+       }
        rc.closed = true
        return nil
 }
@@ -528,6 +565,19 @@ func (rc *rowsCursor) Next(dest []interface{}) error {
                // for ease of drivers, and to prevent drivers from
                // messing up conversions or doing them differently.
                dest[i] = v
+
+               if bs, ok := v.([]byte); ok {
+                       if rc.bytesClone == nil {
+                               rc.bytesClone = make(map[*byte][]byte)
+                       }
+                       clone, ok := rc.bytesClone[&bs[0]]
+                       if !ok {
+                               clone = make([]byte, len(bs))
+                               copy(clone, bs)
+                               rc.bytesClone[&bs[0]] = clone
+                       }
+                       dest[i] = clone
+               }
        }
        return nil
 }
@@ -540,6 +590,8 @@ func converterForType(typ string) driver.ValueConverter {
                return driver.Int32
        case "string":
                return driver.String
+       case "datetime":
+               return driver.DefaultParameterConverter
        }
        panic("invalid fakedb column type of " + typ)
 }
index 937982cdbe6f9b3661ea963047f0efe9d92fd828..4e68c3ee0952b6e778a0facf650862212164cbb9 100644 (file)
@@ -243,8 +243,13 @@ func (db *DB) Query(query string, args ...interface{}) (*Rows, error) {
        if err != nil {
                return nil, err
        }
-       defer stmt.Close()
-       return stmt.Query(args...)
+       rows, err := stmt.Query(args...)
+       if err != nil {
+               stmt.Close()
+               return nil, err
+       }
+       rows.closeStmt = stmt
+       return rows, nil
 }
 
 // QueryRow executes a query that is expected to return at most one row.
@@ -549,8 +554,8 @@ func (s *Stmt) Exec(args ...interface{}) (Result, error) {
 // statement, a function to call to release the connection, and a
 // statement bound to that connection.
 func (s *Stmt) connStmt() (ci driver.Conn, releaseConn func(), si driver.Stmt, err error) {
-       if s.stickyErr != nil {
-               return nil, nil, nil, s.stickyErr
+       if err = s.stickyErr; err != nil {
+               return
        }
        s.mu.Lock()
        if s.closed {
@@ -706,9 +711,10 @@ type Rows struct {
        releaseConn func()
        rowsi       driver.Rows
 
-       closed   bool
-       lastcols []interface{}
-       lasterr  error
+       closed    bool
+       lastcols  []interface{}
+       lasterr   error
+       closeStmt *Stmt // if non-nil, statement to Close on close
 }
 
 // Next prepares the next result row for reading with the Scan method.
@@ -726,6 +732,9 @@ func (rs *Rows) Next() bool {
                rs.lastcols = make([]interface{}, len(rs.rowsi.Columns()))
        }
        rs.lasterr = rs.rowsi.Next(rs.lastcols)
+       if rs.lasterr == io.EOF {
+               rs.Close()
+       }
        return rs.lasterr == nil
 }
 
@@ -786,6 +795,9 @@ func (rs *Rows) Close() error {
        rs.closed = true
        err := rs.rowsi.Close()
        rs.releaseConn()
+       if rs.closeStmt != nil {
+               rs.closeStmt.Close()
+       }
        return err
 }
 
@@ -800,10 +812,6 @@ type Row struct {
 // pointed at by dest.  If more than one row matches the query,
 // Scan uses the first row and discards the rest.  If no row matches
 // the query, Scan returns ErrNoRows.
-//
-// If dest contains pointers to []byte, the slices should not be
-// modified and should only be considered valid until the next call to
-// Next or Scan.
 func (r *Row) Scan(dest ...interface{}) error {
        if r.err != nil {
                return r.err
@@ -812,7 +820,33 @@ func (r *Row) Scan(dest ...interface{}) error {
        if !r.rows.Next() {
                return ErrNoRows
        }
-       return r.rows.Scan(dest...)
+       err := r.rows.Scan(dest...)
+       if err != nil {
+               return err
+       }
+
+       // TODO(bradfitz): for now we need to defensively clone all
+       // []byte that the driver returned, since we're about to close
+       // the Rows in our defer, when we return from this function.
+       // the contract with the driver.Next(...) interface is that it
+       // can return slices into read-only temporary memory that's
+       // only valid until the next Scan/Close.  But the TODO is that
+       // for a lot of drivers, this copy will be unnecessary.  We
+       // should provide an optional interface for drivers to
+       // implement to say, "don't worry, the []bytes that I return
+       // from Next will not be modified again." (for instance, if
+       // they were obtained from the network anyway) But for now we
+       // don't care.
+       for _, dp := range dest {
+               b, ok := dp.(*[]byte)
+               if !ok {
+                       continue
+               }
+               clone := make([]byte, len(*b))
+               copy(clone, *b)
+               *b = clone
+       }
+       return nil
 }
 
 // A Result summarizes an executed SQL command.
index 5307a235ddf13fc0eb84b09eb7da9e3358ed0f3a..3f98a8cd9f288ccd9fdc65da3b33c75419afc7b1 100644 (file)
@@ -8,10 +8,15 @@ import (
        "reflect"
        "strings"
        "testing"
+       "time"
 )
 
+const fakeDBName = "foo"
+
+var chrisBirthday = time.Unix(123456789, 0)
+
 func newTestDB(t *testing.T, name string) *DB {
-       db, err := Open("test", "foo")
+       db, err := Open("test", fakeDBName)
        if err != nil {
                t.Fatalf("Open: %v", err)
        }
@@ -19,10 +24,10 @@ func newTestDB(t *testing.T, name string) *DB {
                t.Fatalf("exec wipe: %v", err)
        }
        if name == "people" {
-               exec(t, db, "CREATE|people|name=string,age=int32,dead=bool")
-               exec(t, db, "INSERT|people|name=Alice,age=?", 1)
-               exec(t, db, "INSERT|people|name=Bob,age=?", 2)
-               exec(t, db, "INSERT|people|name=Chris,age=?", 3)
+               exec(t, db, "CREATE|people|name=string,age=int32,photo=blob,dead=bool,bdate=datetime")
+               exec(t, db, "INSERT|people|name=Alice,age=?,photo=APHOTO", 1)
+               exec(t, db, "INSERT|people|name=Bob,age=?,photo=BPHOTO", 2)
+               exec(t, db, "INSERT|people|name=Chris,age=?,photo=CPHOTO,bdate=?", 3, chrisBirthday)
        }
        return db
 }
@@ -73,6 +78,12 @@ func TestQuery(t *testing.T) {
        if !reflect.DeepEqual(got, want) {
                t.Logf(" got: %#v\nwant: %#v", got, want)
        }
+
+       // And verify that the final rows.Next() call, which hit EOF,
+       // also closed the rows connection.
+       if n := len(db.freeConn); n != 1 {
+               t.Errorf("free conns after query hitting EOF = %d; want 1", n)
+       }
 }
 
 func TestRowsColumns(t *testing.T) {
@@ -97,12 +108,18 @@ func TestQueryRow(t *testing.T) {
        defer closeDB(t, db)
        var name string
        var age int
+       var birthday time.Time
 
        err := db.QueryRow("SELECT|people|age,name|age=?", 3).Scan(&age)
        if err == nil || !strings.Contains(err.Error(), "expected 2 destination arguments") {
                t.Errorf("expected error from wrong number of arguments; actually got: %v", err)
        }
 
+       err = db.QueryRow("SELECT|people|bdate|age=?", 3).Scan(&birthday)
+       if err != nil || !birthday.Equal(chrisBirthday) {
+               t.Errorf("chris birthday = %v, err = %v; want %v", birthday, err, chrisBirthday)
+       }
+
        err = db.QueryRow("SELECT|people|age,name|age=?", 2).Scan(&age, &name)
        if err != nil {
                t.Fatalf("age QueryRow+Scan: %v", err)
@@ -124,6 +141,16 @@ func TestQueryRow(t *testing.T) {
        if age != 1 {
                t.Errorf("expected age 1, got %d", age)
        }
+
+       var photo []byte
+       err = db.QueryRow("SELECT|people|photo|name=?", "Alice").Scan(&photo)
+       if err != nil {
+               t.Fatalf("photo QueryRow+Scan: %v", err)
+       }
+       want := []byte("APHOTO")
+       if !reflect.DeepEqual(photo, want) {
+               t.Errorf("photo = %q; want %q", photo, want)
+       }
 }
 
 func TestStatementErrorAfterClose(t *testing.T) {
@@ -258,3 +285,21 @@ func TestIssue2542Deadlock(t *testing.T) {
                }
        }
 }
+
+func TestQueryRowClosingStmt(t *testing.T) {
+       db := newTestDB(t, "people")
+       defer closeDB(t, db)
+       var name string
+       var age int
+       err := db.QueryRow("SELECT|people|age,name|age=?", 3).Scan(&age, &name)
+       if err != nil {
+               t.Fatal(err)
+       }
+       if len(db.freeConn) != 1 {
+               t.Fatalf("expected 1 free conn")
+       }
+       fakeConn := db.freeConn[0].(*fakeConn)
+       if made, closed := fakeConn.stmtsMade, fakeConn.stmtsClosed; made != closed {
+               t.Logf("statement close mismatch: made %d, closed %d", made, closed)
+       }
+}
index 7c862078b7e95a46cdfd19149a6063b31c0c5663..8df81457bf54269bf13b522f1bd85c9de02c5d83 100644 (file)
@@ -420,27 +420,37 @@ type chanWriter struct {
 }
 
 // Write writes data to the remote process's standard input.
-func (w *chanWriter) Write(data []byte) (n int, err error) {
-       for {
-               if w.rwin == 0 {
+func (w *chanWriter) Write(data []byte) (written int, err error) {
+       for len(data) > 0 {
+               for w.rwin < 1 {
                        win, ok := <-w.win
                        if !ok {
                                return 0, io.EOF
                        }
                        w.rwin += win
-                       continue
                }
+               n := min(len(data), w.rwin)
                peersId := w.clientChan.peersId
-               n = len(data)
-               packet := make([]byte, 0, 9+n)
-               packet = append(packet, msgChannelData,
-                       byte(peersId>>24), byte(peersId>>16), byte(peersId>>8), byte(peersId),
-                       byte(n>>24), byte(n>>16), byte(n>>8), byte(n))
-               err = w.clientChan.writePacket(append(packet, data...))
+               packet := []byte{
+                       msgChannelData,
+                       byte(peersId >> 24), byte(peersId >> 16), byte(peersId >> 8), byte(peersId),
+                       byte(n >> 24), byte(n >> 16), byte(n >> 8), byte(n),
+               }
+               if err = w.clientChan.writePacket(append(packet, data[:n]...)); err != nil {
+                       break
+               }
+               data = data[n:]
                w.rwin -= n
-               return
+               written += n
        }
-       panic("unreachable")
+       return
+}
+
+func min(a, b int) int {
+       if a < b {
+               return a
+       }
+       return b
 }
 
 func (w *chanWriter) Close() error {
index 480f877191a1b705115e4097e7aea8764cea2cef..e7deb5ec168831488a3c9600c77cf1b86ad02c41 100644 (file)
@@ -14,7 +14,7 @@ others.
 An SSH server is represented by a ServerConfig, which holds certificate
 details and handles authentication of ServerConns.
 
-       config := new(ServerConfig)
+       config := new(ssh.ServerConfig)
        config.PubKeyCallback = pubKeyAuth
        config.PasswordCallback = passwordAuth
 
@@ -34,8 +34,7 @@ Once a ServerConfig has been configured, connections can be accepted.
        if err != nil {
                panic("failed to accept incoming connection")
        }
-       err = sConn.Handshake(conn)
-       if err != nil {
+       if err := sConn.Handshake(conn); err != nil {
                panic("failed to handshake")
        }
 
@@ -60,16 +59,20 @@ the case of a shell, the type is "session" and ServerShell may be used to
 present a simple terminal interface.
 
        if channel.ChannelType() != "session" {
-               c.Reject(UnknownChannelType, "unknown channel type")
+               channel.Reject(UnknownChannelType, "unknown channel type")
                return
        }
        channel.Accept()
 
-       shell := NewServerShell(channel, "> ")
+       term := terminal.NewTerminal(channel, "> ")
+       serverTerm := &ssh.ServerTerminal{
+               Term: term,
+               Channel: channel,
+       }
        go func() {
                defer channel.Close()
                for {
-                       line, err := shell.ReadLine()
+                       line, err := serverTerm.ReadLine()
                        if err != nil {
                                break
                        }
@@ -78,8 +81,27 @@ present a simple terminal interface.
                return
        }()
 
+To authenticate with the remote server you must pass at least one implementation of 
+ClientAuth via the Auth field in ClientConfig.
+
+       // password implements the ClientPassword interface
+       type password string
+
+       func (p password) Password(user string) (string, error) {
+               return string(p), nil
+       }
+
+       config := &ssh.ClientConfig {
+               User: "username",
+               Auth: []ClientAuth {
+                       // ClientAuthPassword wraps a ClientPassword implementation
+                       // in a type that implements ClientAuth.
+                       ClientAuthPassword(password("yourpassword")),
+               }
+       }
+
 An SSH client is represented with a ClientConn. Currently only the "password"
-authentication method is supported. 
+authentication method is supported.
 
        config := &ClientConfig{
                User: "username",
@@ -87,19 +109,19 @@ authentication method is supported.
        }
        client, err := Dial("yourserver.com:22", config)
 
-Each ClientConn can support multiple interactive sessions, represented by a Session. 
+Each ClientConn can support multiple interactive sessions, represented by a Session.
 
        session, err := client.NewSession()
 
-Once a Session is created, you can execute a single command on the remote side 
-using the Run method.
+Once a Session is created, you can execute a single command on the remote side
+using the Exec method.
 
+       b := bytes.NewBuffer()
+       session.Stdin = b
        if err := session.Run("/usr/bin/whoami"); err != nil {
                panic("Failed to exec: " + err.String())
        }
-       reader := bufio.NewReader(session.Stdin)
-       line, _, _ := reader.ReadLine()
-       fmt.Println(line)
+       fmt.Println(bytes.String())
        session.Close()
 */
 package ssh
diff --git a/libgo/go/exp/ssh/server_shell.go b/libgo/go/exp/ssh/server_shell.go
deleted file mode 100644 (file)
index 5243d0e..0000000
+++ /dev/null
@@ -1,398 +0,0 @@
-// Copyright 2011 The Go Authors. All rights reserved.
-// Use of this source code is governed by a BSD-style
-// license that can be found in the LICENSE file.
-
-package ssh
-
-import "io"
-
-// ServerShell contains the state for running a VT100 terminal that is capable
-// of reading lines of input.
-type ServerShell struct {
-       c      Channel
-       prompt string
-
-       // line is the current line being entered.
-       line []byte
-       // pos is the logical position of the cursor in line
-       pos int
-
-       // cursorX contains the current X value of the cursor where the left
-       // edge is 0. cursorY contains the row number where the first row of
-       // the current line is 0.
-       cursorX, cursorY int
-       // maxLine is the greatest value of cursorY so far.
-       maxLine int
-
-       termWidth, termHeight int
-
-       // outBuf contains the terminal data to be sent.
-       outBuf []byte
-       // remainder contains the remainder of any partial key sequences after
-       // a read. It aliases into inBuf.
-       remainder []byte
-       inBuf     [256]byte
-}
-
-// NewServerShell runs a VT100 terminal on the given channel. prompt is a
-// string that is written at the start of each input line. For example: "> ".
-func NewServerShell(c Channel, prompt string) *ServerShell {
-       return &ServerShell{
-               c:          c,
-               prompt:     prompt,
-               termWidth:  80,
-               termHeight: 24,
-       }
-}
-
-const (
-       keyCtrlD     = 4
-       keyEnter     = '\r'
-       keyEscape    = 27
-       keyBackspace = 127
-       keyUnknown   = 256 + iota
-       keyUp
-       keyDown
-       keyLeft
-       keyRight
-       keyAltLeft
-       keyAltRight
-)
-
-// bytesToKey tries to parse a key sequence from b. If successful, it returns
-// the key and the remainder of the input. Otherwise it returns -1.
-func bytesToKey(b []byte) (int, []byte) {
-       if len(b) == 0 {
-               return -1, nil
-       }
-
-       if b[0] != keyEscape {
-               return int(b[0]), b[1:]
-       }
-
-       if len(b) >= 3 && b[0] == keyEscape && b[1] == '[' {
-               switch b[2] {
-               case 'A':
-                       return keyUp, b[3:]
-               case 'B':
-                       return keyDown, b[3:]
-               case 'C':
-                       return keyRight, b[3:]
-               case 'D':
-                       return keyLeft, b[3:]
-               }
-       }
-
-       if len(b) >= 6 && b[0] == keyEscape && b[1] == '[' && b[2] == '1' && b[3] == ';' && b[4] == '3' {
-               switch b[5] {
-               case 'C':
-                       return keyAltRight, b[6:]
-               case 'D':
-                       return keyAltLeft, b[6:]
-               }
-       }
-
-       // If we get here then we have a key that we don't recognise, or a
-       // partial sequence. It's not clear how one should find the end of a
-       // sequence without knowing them all, but it seems that [a-zA-Z] only
-       // appears at the end of a sequence.
-       for i, c := range b[0:] {
-               if c >= 'a' && c <= 'z' || c >= 'A' && c <= 'Z' {
-                       return keyUnknown, b[i+1:]
-               }
-       }
-
-       return -1, b
-}
-
-// queue appends data to the end of ss.outBuf
-func (ss *ServerShell) queue(data []byte) {
-       if len(ss.outBuf)+len(data) > cap(ss.outBuf) {
-               newOutBuf := make([]byte, len(ss.outBuf), 2*(len(ss.outBuf)+len(data)))
-               copy(newOutBuf, ss.outBuf)
-               ss.outBuf = newOutBuf
-       }
-
-       oldLen := len(ss.outBuf)
-       ss.outBuf = ss.outBuf[:len(ss.outBuf)+len(data)]
-       copy(ss.outBuf[oldLen:], data)
-}
-
-var eraseUnderCursor = []byte{' ', keyEscape, '[', 'D'}
-
-func isPrintable(key int) bool {
-       return key >= 32 && key < 127
-}
-
-// moveCursorToPos appends data to ss.outBuf which will move the cursor to the
-// given, logical position in the text.
-func (ss *ServerShell) moveCursorToPos(pos int) {
-       x := len(ss.prompt) + pos
-       y := x / ss.termWidth
-       x = x % ss.termWidth
-
-       up := 0
-       if y < ss.cursorY {
-               up = ss.cursorY - y
-       }
-
-       down := 0
-       if y > ss.cursorY {
-               down = y - ss.cursorY
-       }
-
-       left := 0
-       if x < ss.cursorX {
-               left = ss.cursorX - x
-       }
-
-       right := 0
-       if x > ss.cursorX {
-               right = x - ss.cursorX
-       }
-
-       movement := make([]byte, 3*(up+down+left+right))
-       m := movement
-       for i := 0; i < up; i++ {
-               m[0] = keyEscape
-               m[1] = '['
-               m[2] = 'A'
-               m = m[3:]
-       }
-       for i := 0; i < down; i++ {
-               m[0] = keyEscape
-               m[1] = '['
-               m[2] = 'B'
-               m = m[3:]
-       }
-       for i := 0; i < left; i++ {
-               m[0] = keyEscape
-               m[1] = '['
-               m[2] = 'D'
-               m = m[3:]
-       }
-       for i := 0; i < right; i++ {
-               m[0] = keyEscape
-               m[1] = '['
-               m[2] = 'C'
-               m = m[3:]
-       }
-
-       ss.cursorX = x
-       ss.cursorY = y
-       ss.queue(movement)
-}
-
-const maxLineLength = 4096
-
-// handleKey processes the given key and, optionally, returns a line of text
-// that the user has entered.
-func (ss *ServerShell) handleKey(key int) (line string, ok bool) {
-       switch key {
-       case keyBackspace:
-               if ss.pos == 0 {
-                       return
-               }
-               ss.pos--
-
-               copy(ss.line[ss.pos:], ss.line[1+ss.pos:])
-               ss.line = ss.line[:len(ss.line)-1]
-               ss.writeLine(ss.line[ss.pos:])
-               ss.moveCursorToPos(ss.pos)
-               ss.queue(eraseUnderCursor)
-       case keyAltLeft:
-               // move left by a word.
-               if ss.pos == 0 {
-                       return
-               }
-               ss.pos--
-               for ss.pos > 0 {
-                       if ss.line[ss.pos] != ' ' {
-                               break
-                       }
-                       ss.pos--
-               }
-               for ss.pos > 0 {
-                       if ss.line[ss.pos] == ' ' {
-                               ss.pos++
-                               break
-                       }
-                       ss.pos--
-               }
-               ss.moveCursorToPos(ss.pos)
-       case keyAltRight:
-               // move right by a word.
-               for ss.pos < len(ss.line) {
-                       if ss.line[ss.pos] == ' ' {
-                               break
-                       }
-                       ss.pos++
-               }
-               for ss.pos < len(ss.line) {
-                       if ss.line[ss.pos] != ' ' {
-                               break
-                       }
-                       ss.pos++
-               }
-               ss.moveCursorToPos(ss.pos)
-       case keyLeft:
-               if ss.pos == 0 {
-                       return
-               }
-               ss.pos--
-               ss.moveCursorToPos(ss.pos)
-       case keyRight:
-               if ss.pos == len(ss.line) {
-                       return
-               }
-               ss.pos++
-               ss.moveCursorToPos(ss.pos)
-       case keyEnter:
-               ss.moveCursorToPos(len(ss.line))
-               ss.queue([]byte("\r\n"))
-               line = string(ss.line)
-               ok = true
-               ss.line = ss.line[:0]
-               ss.pos = 0
-               ss.cursorX = 0
-               ss.cursorY = 0
-               ss.maxLine = 0
-       default:
-               if !isPrintable(key) {
-                       return
-               }
-               if len(ss.line) == maxLineLength {
-                       return
-               }
-               if len(ss.line) == cap(ss.line) {
-                       newLine := make([]byte, len(ss.line), 2*(1+len(ss.line)))
-                       copy(newLine, ss.line)
-                       ss.line = newLine
-               }
-               ss.line = ss.line[:len(ss.line)+1]
-               copy(ss.line[ss.pos+1:], ss.line[ss.pos:])
-               ss.line[ss.pos] = byte(key)
-               ss.writeLine(ss.line[ss.pos:])
-               ss.pos++
-               ss.moveCursorToPos(ss.pos)
-       }
-       return
-}
-
-func (ss *ServerShell) writeLine(line []byte) {
-       for len(line) != 0 {
-               if ss.cursorX == ss.termWidth {
-                       ss.queue([]byte("\r\n"))
-                       ss.cursorX = 0
-                       ss.cursorY++
-                       if ss.cursorY > ss.maxLine {
-                               ss.maxLine = ss.cursorY
-                       }
-               }
-
-               remainingOnLine := ss.termWidth - ss.cursorX
-               todo := len(line)
-               if todo > remainingOnLine {
-                       todo = remainingOnLine
-               }
-               ss.queue(line[:todo])
-               ss.cursorX += todo
-               line = line[todo:]
-       }
-}
-
-// parsePtyRequest parses the payload of the pty-req message and extracts the
-// dimensions of the terminal. See RFC 4254, section 6.2.
-func parsePtyRequest(s []byte) (width, height int, ok bool) {
-       _, s, ok = parseString(s)
-       if !ok {
-               return
-       }
-       width32, s, ok := parseUint32(s)
-       if !ok {
-               return
-       }
-       height32, _, ok := parseUint32(s)
-       width = int(width32)
-       height = int(height32)
-       if width < 1 {
-               ok = false
-       }
-       if height < 1 {
-               ok = false
-       }
-       return
-}
-
-func (ss *ServerShell) Write(buf []byte) (n int, err error) {
-       return ss.c.Write(buf)
-}
-
-// ReadLine returns a line of input from the terminal.
-func (ss *ServerShell) ReadLine() (line string, err error) {
-       ss.writeLine([]byte(ss.prompt))
-       ss.c.Write(ss.outBuf)
-       ss.outBuf = ss.outBuf[:0]
-
-       for {
-               // ss.remainder is a slice at the beginning of ss.inBuf
-               // containing a partial key sequence
-               readBuf := ss.inBuf[len(ss.remainder):]
-               var n int
-               n, err = ss.c.Read(readBuf)
-               if err == nil {
-                       ss.remainder = ss.inBuf[:n+len(ss.remainder)]
-                       rest := ss.remainder
-                       lineOk := false
-                       for !lineOk {
-                               var key int
-                               key, rest = bytesToKey(rest)
-                               if key < 0 {
-                                       break
-                               }
-                               if key == keyCtrlD {
-                                       return "", io.EOF
-                               }
-                               line, lineOk = ss.handleKey(key)
-                       }
-                       if len(rest) > 0 {
-                               n := copy(ss.inBuf[:], rest)
-                               ss.remainder = ss.inBuf[:n]
-                       } else {
-                               ss.remainder = nil
-                       }
-                       ss.c.Write(ss.outBuf)
-                       ss.outBuf = ss.outBuf[:0]
-                       if lineOk {
-                               return
-                       }
-                       continue
-               }
-
-               if req, ok := err.(ChannelRequest); ok {
-                       ok := false
-                       switch req.Request {
-                       case "pty-req":
-                               ss.termWidth, ss.termHeight, ok = parsePtyRequest(req.Payload)
-                               if !ok {
-                                       ss.termWidth = 80
-                                       ss.termHeight = 24
-                               }
-                       case "shell":
-                               ok = true
-                               if len(req.Payload) > 0 {
-                                       // We don't accept any commands, only the default shell.
-                                       ok = false
-                               }
-                       case "env":
-                               ok = true
-                       }
-                       if req.WantReply {
-                               ss.c.AckRequest(ok)
-                       }
-               } else {
-                       return "", err
-               }
-       }
-       panic("unreachable")
-}
diff --git a/libgo/go/exp/ssh/server_shell_test.go b/libgo/go/exp/ssh/server_shell_test.go
deleted file mode 100644 (file)
index aa69ef7..0000000
+++ /dev/null
@@ -1,134 +0,0 @@
-// Copyright 2011 The Go Authors. All rights reserved.
-// Use of this source code is governed by a BSD-style
-// license that can be found in the LICENSE file.
-
-package ssh
-
-import (
-       "io"
-       "testing"
-)
-
-type MockChannel struct {
-       toSend       []byte
-       bytesPerRead int
-       received     []byte
-}
-
-func (c *MockChannel) Accept() error {
-       return nil
-}
-
-func (c *MockChannel) Reject(RejectionReason, string) error {
-       return nil
-}
-
-func (c *MockChannel) Read(data []byte) (n int, err error) {
-       n = len(data)
-       if n == 0 {
-               return
-       }
-       if n > len(c.toSend) {
-               n = len(c.toSend)
-       }
-       if n == 0 {
-               return 0, io.EOF
-       }
-       if c.bytesPerRead > 0 && n > c.bytesPerRead {
-               n = c.bytesPerRead
-       }
-       copy(data, c.toSend[:n])
-       c.toSend = c.toSend[n:]
-       return
-}
-
-func (c *MockChannel) Write(data []byte) (n int, err error) {
-       c.received = append(c.received, data...)
-       return len(data), nil
-}
-
-func (c *MockChannel) Close() error {
-       return nil
-}
-
-func (c *MockChannel) AckRequest(ok bool) error {
-       return nil
-}
-
-func (c *MockChannel) ChannelType() string {
-       return ""
-}
-
-func (c *MockChannel) ExtraData() []byte {
-       return nil
-}
-
-func TestClose(t *testing.T) {
-       c := &MockChannel{}
-       ss := NewServerShell(c, "> ")
-       line, err := ss.ReadLine()
-       if line != "" {
-               t.Errorf("Expected empty line but got: %s", line)
-       }
-       if err != io.EOF {
-               t.Errorf("Error should have been EOF but got: %s", err)
-       }
-}
-
-var keyPressTests = []struct {
-       in   string
-       line string
-       err  error
-}{
-       {
-               "",
-               "",
-               io.EOF,
-       },
-       {
-               "\r",
-               "",
-               nil,
-       },
-       {
-               "foo\r",
-               "foo",
-               nil,
-       },
-       {
-               "a\x1b[Cb\r", // right
-               "ab",
-               nil,
-       },
-       {
-               "a\x1b[Db\r", // left
-               "ba",
-               nil,
-       },
-       {
-               "a\177b\r", // backspace
-               "b",
-               nil,
-       },
-}
-
-func TestKeyPresses(t *testing.T) {
-       for i, test := range keyPressTests {
-               for j := 0; j < len(test.in); j++ {
-                       c := &MockChannel{
-                               toSend:       []byte(test.in),
-                               bytesPerRead: j,
-                       }
-                       ss := NewServerShell(c, "> ")
-                       line, err := ss.ReadLine()
-                       if line != test.line {
-                               t.Errorf("Line resulting from test %d (%d bytes per read) was '%s', expected '%s'", i, j, line, test.line)
-                               break
-                       }
-                       if err != test.err {
-                               t.Errorf("Error resulting from test %d (%d bytes per read) was '%v', expected '%v'", i, j, err, test.err)
-                               break
-                       }
-               }
-       }
-}
diff --git a/libgo/go/exp/ssh/server_terminal.go b/libgo/go/exp/ssh/server_terminal.go
new file mode 100644 (file)
index 0000000..708a915
--- /dev/null
@@ -0,0 +1,81 @@
+// Copyright 2011 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+package ssh
+
+// A Terminal is capable of parsing and generating virtual terminal
+// data from an SSH client.
+type Terminal interface {
+       ReadLine() (line string, err error)
+       SetSize(x, y int)
+       Write([]byte) (int, error)
+}
+
+// ServerTerminal contains the state for running a terminal that is capable of
+// reading lines of input.
+type ServerTerminal struct {
+       Term    Terminal
+       Channel Channel
+}
+
+// parsePtyRequest parses the payload of the pty-req message and extracts the
+// dimensions of the terminal. See RFC 4254, section 6.2.
+func parsePtyRequest(s []byte) (width, height int, ok bool) {
+       _, s, ok = parseString(s)
+       if !ok {
+               return
+       }
+       width32, s, ok := parseUint32(s)
+       if !ok {
+               return
+       }
+       height32, _, ok := parseUint32(s)
+       width = int(width32)
+       height = int(height32)
+       if width < 1 {
+               ok = false
+       }
+       if height < 1 {
+               ok = false
+       }
+       return
+}
+
+func (ss *ServerTerminal) Write(buf []byte) (n int, err error) {
+       return ss.Term.Write(buf)
+}
+
+// ReadLine returns a line of input from the terminal.
+func (ss *ServerTerminal) ReadLine() (line string, err error) {
+       for {
+               if line, err = ss.Term.ReadLine(); err == nil {
+                       return
+               }
+
+               req, ok := err.(ChannelRequest)
+               if !ok {
+                       return
+               }
+
+               ok = false
+               switch req.Request {
+               case "pty-req":
+                       var width, height int
+                       width, height, ok = parsePtyRequest(req.Payload)
+                       ss.Term.SetSize(width, height)
+               case "shell":
+                       ok = true
+                       if len(req.Payload) > 0 {
+                               // We don't accept any commands, only the default shell.
+                               ok = false
+                       }
+               case "env":
+                       ok = true
+               }
+               if req.WantReply {
+                       ss.Channel.AckRequest(ok)
+               }
+       }
+       panic("unreachable")
+}
index 2882620b0ba3c89c54953429c8e36cb45ca512e9..4a3d22bee04f905be3cdf581f7a8a12955892c8b 100644 (file)
@@ -8,6 +8,7 @@ package ssh
 
 import (
        "bytes"
+       "exp/terminal"
        "io"
        "testing"
 )
@@ -290,24 +291,32 @@ type exitSignalMsg struct {
        Lang       string
 }
 
+func newServerShell(ch *channel, prompt string) *ServerTerminal {
+       term := terminal.NewTerminal(ch, prompt)
+       return &ServerTerminal{
+               Term:    term,
+               Channel: ch,
+       }
+}
+
 func exitStatusZeroHandler(ch *channel) {
        defer ch.Close()
        // this string is returned to stdout
-       shell := NewServerShell(ch, "> ")
+       shell := newServerShell(ch, "> ")
        shell.ReadLine()
        sendStatus(0, ch)
 }
 
 func exitStatusNonZeroHandler(ch *channel) {
        defer ch.Close()
-       shell := NewServerShell(ch, "> ")
+       shell := newServerShell(ch, "> ")
        shell.ReadLine()
        sendStatus(15, ch)
 }
 
 func exitSignalAndStatusHandler(ch *channel) {
        defer ch.Close()
-       shell := NewServerShell(ch, "> ")
+       shell := newServerShell(ch, "> ")
        shell.ReadLine()
        sendStatus(15, ch)
        sendSignal("TERM", ch)
@@ -315,28 +324,28 @@ func exitSignalAndStatusHandler(ch *channel) {
 
 func exitSignalHandler(ch *channel) {
        defer ch.Close()
-       shell := NewServerShell(ch, "> ")
+       shell := newServerShell(ch, "> ")
        shell.ReadLine()
        sendSignal("TERM", ch)
 }
 
 func exitSignalUnknownHandler(ch *channel) {
        defer ch.Close()
-       shell := NewServerShell(ch, "> ")
+       shell := newServerShell(ch, "> ")
        shell.ReadLine()
        sendSignal("SYS", ch)
 }
 
 func exitWithoutSignalOrStatus(ch *channel) {
        defer ch.Close()
-       shell := NewServerShell(ch, "> ")
+       shell := newServerShell(ch, "> ")
        shell.ReadLine()
 }
 
 func shellHandler(ch *channel) {
        defer ch.Close()
        // this string is returned to stdout
-       shell := NewServerShell(ch, "golang")
+       shell := newServerShell(ch, "golang")
        shell.ReadLine()
        sendStatus(0, ch)
 }
index bcd073e7ce6ec11451df55aa479fa316d0e01ee1..2e7c955a12de6d9abb3b6f8c1b0cfa802a1440e3 100644 (file)
@@ -117,9 +117,7 @@ func (r *reader) readOnePacket() ([]byte, error) {
                return nil, err
        }
        mac := packet[length-1:]
-       if r.cipher != nil {
-               r.cipher.XORKeyStream(packet, packet[:length-1])
-       }
+       r.cipher.XORKeyStream(packet, packet[:length-1])
 
        if r.mac != nil {
                r.mac.Write(packet[:length-1])
index 809e88cacfa3f448a391ca857ec18a10fef7733c..c3ba5bde2ee6e69b01b7f9e315c1d17bdef7f3ea 100644 (file)
@@ -2,8 +2,6 @@
 // Use of this source code is governed by a BSD-style
 // license that can be found in the LICENSE file.
 
-// +build linux
-
 package terminal
 
 import (
@@ -463,6 +461,31 @@ func (t *Terminal) readLine() (line string, err error) {
        }
 
        for {
+               rest := t.remainder
+               lineOk := false
+               for !lineOk {
+                       var key int
+                       key, rest = bytesToKey(rest)
+                       if key < 0 {
+                               break
+                       }
+                       if key == keyCtrlD {
+                               return "", io.EOF
+                       }
+                       line, lineOk = t.handleKey(key)
+               }
+               if len(rest) > 0 {
+                       n := copy(t.inBuf[:], rest)
+                       t.remainder = t.inBuf[:n]
+               } else {
+                       t.remainder = nil
+               }
+               t.c.Write(t.outBuf)
+               t.outBuf = t.outBuf[:0]
+               if lineOk {
+                       return
+               }
+
                // t.remainder is a slice at the beginning of t.inBuf
                // containing a partial key sequence
                readBuf := t.inBuf[len(t.remainder):]
@@ -476,38 +499,19 @@ func (t *Terminal) readLine() (line string, err error) {
                        return
                }
 
-               if err == nil {
-                       t.remainder = t.inBuf[:n+len(t.remainder)]
-                       rest := t.remainder
-                       lineOk := false
-                       for !lineOk {
-                               var key int
-                               key, rest = bytesToKey(rest)
-                               if key < 0 {
-                                       break
-                               }
-                               if key == keyCtrlD {
-                                       return "", io.EOF
-                               }
-                               line, lineOk = t.handleKey(key)
-                       }
-                       if len(rest) > 0 {
-                               n := copy(t.inBuf[:], rest)
-                               t.remainder = t.inBuf[:n]
-                       } else {
-                               t.remainder = nil
-                       }
-                       t.c.Write(t.outBuf)
-                       t.outBuf = t.outBuf[:0]
-                       if lineOk {
-                               return
-                       }
-                       continue
-               }
+               t.remainder = t.inBuf[:n+len(t.remainder)]
        }
        panic("unreachable")
 }
 
+// SetPrompt sets the prompt to be used when reading subsequent lines.
+func (t *Terminal) SetPrompt(prompt string) {
+       t.lock.Lock()
+       defer t.lock.Unlock()
+
+       t.prompt = prompt
+}
+
 func (t *Terminal) SetSize(width, height int) {
        t.lock.Lock()
        defer t.lock.Unlock()
index 75628f695e989e41fe5f6e8d48b24955f975dcfe..a2197210e2a8d4da5720b2c797afccd21156011d 100644 (file)
@@ -2,8 +2,6 @@
 // Use of this source code is governed by a BSD-style
 // license that can be found in the LICENSE file.
 
-// +build linux
-
 package terminal
 
 import (
index 35535ea406f0c40c566bbd04c63ccfdb6fd6f43d..ea9218ff51d67fbbe27f902de61b844cc0669a28 100644 (file)
@@ -111,7 +111,7 @@ func expectedErrors(t *testing.T, testname string, files map[string]*ast.File) m
                // set otherwise the position information returned here will
                // not match the position information collected by the parser
                s.Init(getFile(filename), src, nil, scanner.ScanComments)
-               var prev token.Pos // position of last non-comment token
+               var prev token.Pos // position of last non-comment, non-semicolon token
 
        scanFile:
                for {
@@ -124,6 +124,12 @@ func expectedErrors(t *testing.T, testname string, files map[string]*ast.File) m
                                if len(s) == 2 {
                                        errors[prev] = string(s[1])
                                }
+                       case token.SEMICOLON:
+                               // ignore automatically inserted semicolon
+                               if lit == "\n" {
+                                       break
+                               }
+                               fallthrough
                        default:
                                prev = pos
                        }
index 780b82625f5f8208296f526074e7273a04c83858..46cff31bce84e8e833483d5e7cbbd96d61bf9c20 100644 (file)
@@ -20,6 +20,7 @@ func define(kind ast.ObjKind, name string) *ast.Object {
        if scope.Insert(obj) != nil {
                panic("types internal error: double declaration")
        }
+       obj.Decl = scope
        return obj
 }
 
index 406ea77799d4358f47ac5183a85073b95a78d50f..964f5541b86c2f53d76c36aee21d3e355562b7ee 100644 (file)
@@ -65,12 +65,13 @@ import (
        "os"
        "sort"
        "strconv"
+       "time"
 )
 
 // ErrHelp is the error returned if the flag -help is invoked but no such flag is defined.
 var ErrHelp = errors.New("flag: help requested")
 
-// -- Bool Value
+// -- bool Value
 type boolValue bool
 
 func newBoolValue(val bool, p *bool) *boolValue {
@@ -78,15 +79,15 @@ func newBoolValue(val bool, p *bool) *boolValue {
        return (*boolValue)(p)
 }
 
-func (b *boolValue) Set(s string) bool {
+func (b *boolValue) Set(s string) error {
        v, err := strconv.ParseBool(s)
        *b = boolValue(v)
-       return err == nil
+       return err
 }
 
 func (b *boolValue) String() string { return fmt.Sprintf("%v", *b) }
 
-// -- Int Value
+// -- int Value
 type intValue int
 
 func newIntValue(val int, p *int) *intValue {
@@ -94,15 +95,15 @@ func newIntValue(val int, p *int) *intValue {
        return (*intValue)(p)
 }
 
-func (i *intValue) Set(s string) bool {
+func (i *intValue) Set(s string) error {
        v, err := strconv.ParseInt(s, 0, 64)
        *i = intValue(v)
-       return err == nil
+       return err
 }
 
 func (i *intValue) String() string { return fmt.Sprintf("%v", *i) }
 
-// -- Int64 Value
+// -- int64 Value
 type int64Value int64
 
 func newInt64Value(val int64, p *int64) *int64Value {
@@ -110,15 +111,15 @@ func newInt64Value(val int64, p *int64) *int64Value {
        return (*int64Value)(p)
 }
 
-func (i *int64Value) Set(s string) bool {
+func (i *int64Value) Set(s string) error {
        v, err := strconv.ParseInt(s, 0, 64)
        *i = int64Value(v)
-       return err == nil
+       return err
 }
 
 func (i *int64Value) String() string { return fmt.Sprintf("%v", *i) }
 
-// -- Uint Value
+// -- uint Value
 type uintValue uint
 
 func newUintValue(val uint, p *uint) *uintValue {
@@ -126,10 +127,10 @@ func newUintValue(val uint, p *uint) *uintValue {
        return (*uintValue)(p)
 }
 
-func (i *uintValue) Set(s string) bool {
+func (i *uintValue) Set(s string) error {
        v, err := strconv.ParseUint(s, 0, 64)
        *i = uintValue(v)
-       return err == nil
+       return err
 }
 
 func (i *uintValue) String() string { return fmt.Sprintf("%v", *i) }
@@ -142,10 +143,10 @@ func newUint64Value(val uint64, p *uint64) *uint64Value {
        return (*uint64Value)(p)
 }
 
-func (i *uint64Value) Set(s string) bool {
+func (i *uint64Value) Set(s string) error {
        v, err := strconv.ParseUint(s, 0, 64)
        *i = uint64Value(v)
-       return err == nil
+       return err
 }
 
 func (i *uint64Value) String() string { return fmt.Sprintf("%v", *i) }
@@ -158,14 +159,14 @@ func newStringValue(val string, p *string) *stringValue {
        return (*stringValue)(p)
 }
 
-func (s *stringValue) Set(val string) bool {
+func (s *stringValue) Set(val string) error {
        *s = stringValue(val)
-       return true
+       return nil
 }
 
 func (s *stringValue) String() string { return fmt.Sprintf("%s", *s) }
 
-// -- Float64 Value
+// -- float64 Value
 type float64Value float64
 
 func newFloat64Value(val float64, p *float64) *float64Value {
@@ -173,19 +174,35 @@ func newFloat64Value(val float64, p *float64) *float64Value {
        return (*float64Value)(p)
 }
 
-func (f *float64Value) Set(s string) bool {
+func (f *float64Value) Set(s string) error {
        v, err := strconv.ParseFloat(s, 64)
        *f = float64Value(v)
-       return err == nil
+       return err
 }
 
 func (f *float64Value) String() string { return fmt.Sprintf("%v", *f) }
 
+// -- time.Duration Value
+type durationValue time.Duration
+
+func newDurationValue(val time.Duration, p *time.Duration) *durationValue {
+       *p = val
+       return (*durationValue)(p)
+}
+
+func (d *durationValue) Set(s string) error {
+       v, err := time.ParseDuration(s)
+       *d = durationValue(v)
+       return err
+}
+
+func (d *durationValue) String() string { return (*time.Duration)(d).String() }
+
 // Value is the interface to the dynamic value stored in a flag.
 // (The default value is represented as a string.)
 type Value interface {
        String() string
-       Set(string) bool
+       Set(string) error
 }
 
 // ErrorHandling defines how to handle flag parsing errors.
@@ -276,27 +293,25 @@ func Lookup(name string) *Flag {
        return commandLine.formal[name]
 }
 
-// Set sets the value of the named flag.  It returns true if the set succeeded; false if
-// there is no such flag defined.
-func (f *FlagSet) Set(name, value string) bool {
+// Set sets the value of the named flag.
+func (f *FlagSet) Set(name, value string) error {
        flag, ok := f.formal[name]
        if !ok {
-               return false
+               return fmt.Errorf("no such flag -%v", name)
        }
-       ok = flag.Value.Set(value)
-       if !ok {
-               return false
+       err := flag.Value.Set(value)
+       if err != nil {
+               return err
        }
        if f.actual == nil {
                f.actual = make(map[string]*Flag)
        }
        f.actual[name] = flag
-       return true
+       return nil
 }
 
-// Set sets the value of the named command-line flag. It returns true if the
-// set succeeded; false if there is no such flag defined.
-func Set(name, value string) bool {
+// Set sets the value of the named command-line flag.
+func Set(name, value string) error {
        return commandLine.Set(name, value)
 }
 
@@ -543,12 +558,38 @@ func (f *FlagSet) Float64(name string, value float64, usage string) *float64 {
        return p
 }
 
-// Float64 defines an int flag with specified name, default value, and usage string.
+// Float64 defines a float64 flag with specified name, default value, and usage string.
 // The return value is the address of a float64 variable that stores the value of the flag.
 func Float64(name string, value float64, usage string) *float64 {
        return commandLine.Float64(name, value, usage)
 }
 
+// DurationVar defines a time.Duration flag with specified name, default value, and usage string.
+// The argument p points to a time.Duration variable in which to store the value of the flag.
+func (f *FlagSet) DurationVar(p *time.Duration, name string, value time.Duration, usage string) {
+       f.Var(newDurationValue(value, p), name, usage)
+}
+
+// DurationVar defines a time.Duration flag with specified name, default value, and usage string.
+// The argument p points to a time.Duration variable in which to store the value of the flag.
+func DurationVar(p *time.Duration, name string, value time.Duration, usage string) {
+       commandLine.Var(newDurationValue(value, p), name, usage)
+}
+
+// Duration defines a time.Duration flag with specified name, default value, and usage string.
+// The return value is the address of a time.Duration variable that stores the value of the flag.
+func (f *FlagSet) Duration(name string, value time.Duration, usage string) *time.Duration {
+       p := new(time.Duration)
+       f.DurationVar(p, name, value, usage)
+       return p
+}
+
+// Duration defines a time.Duration flag with specified name, default value, and usage string.
+// The return value is the address of a time.Duration variable that stores the value of the flag.
+func Duration(name string, value time.Duration, usage string) *time.Duration {
+       return commandLine.Duration(name, value, usage)
+}
+
 // Var defines a flag with the specified name and usage string. The type and
 // value of the flag are represented by the first argument, of type Value, which
 // typically holds a user-defined implementation of Value. For instance, the
@@ -645,8 +686,8 @@ func (f *FlagSet) parseOne() (bool, error) {
        }
        if fv, ok := flag.Value.(*boolValue); ok { // special case: doesn't need an arg
                if has_value {
-                       if !fv.Set(value) {
-                               f.failf("invalid boolean value %q for flag: -%s", value, name)
+                       if err := fv.Set(value); err != nil {
+                               f.failf("invalid boolean value %q for  -%s: %v", value, name, err)
                        }
                } else {
                        fv.Set("true")
@@ -661,9 +702,8 @@ func (f *FlagSet) parseOne() (bool, error) {
                if !has_value {
                        return false, f.failf("flag needs an argument: -%s", name)
                }
-               ok = flag.Value.Set(value)
-               if !ok {
-                       return false, f.failf("invalid value %q for flag: -%s", value, name)
+               if err := flag.Value.Set(value); err != nil {
+                       return false, f.failf("invalid value %q for flag -%s: %v", value, name, err)
                }
        }
        if f.actual == nil {
index f13531669c1914dcd9ae9027c042ac01e4b76df2..698c15f2c58f681c34befbf4481411af0211a8a7 100644 (file)
@@ -10,16 +10,18 @@ import (
        "os"
        "sort"
        "testing"
+       "time"
 )
 
 var (
-       test_bool    = Bool("test_bool", false, "bool value")
-       test_int     = Int("test_int", 0, "int value")
-       test_int64   = Int64("test_int64", 0, "int64 value")
-       test_uint    = Uint("test_uint", 0, "uint value")
-       test_uint64  = Uint64("test_uint64", 0, "uint64 value")
-       test_string  = String("test_string", "0", "string value")
-       test_float64 = Float64("test_float64", 0, "float64 value")
+       test_bool     = Bool("test_bool", false, "bool value")
+       test_int      = Int("test_int", 0, "int value")
+       test_int64    = Int64("test_int64", 0, "int64 value")
+       test_uint     = Uint("test_uint", 0, "uint value")
+       test_uint64   = Uint64("test_uint64", 0, "uint64 value")
+       test_string   = String("test_string", "0", "string value")
+       test_float64  = Float64("test_float64", 0, "float64 value")
+       test_duration = Duration("test_duration", 0, "time.Duration value")
 )
 
 func boolString(s string) string {
@@ -41,6 +43,8 @@ func TestEverything(t *testing.T) {
                                ok = true
                        case f.Name == "test_bool" && f.Value.String() == boolString(desired):
                                ok = true
+                       case f.Name == "test_duration" && f.Value.String() == desired+"s":
+                               ok = true
                        }
                        if !ok {
                                t.Error("Visit: bad value", f.Value.String(), "for", f.Name)
@@ -48,7 +52,7 @@ func TestEverything(t *testing.T) {
                }
        }
        VisitAll(visitor)
-       if len(m) != 7 {
+       if len(m) != 8 {
                t.Error("VisitAll misses some flags")
                for k, v := range m {
                        t.Log(k, *v)
@@ -70,9 +74,10 @@ func TestEverything(t *testing.T) {
        Set("test_uint64", "1")
        Set("test_string", "1")
        Set("test_float64", "1")
+       Set("test_duration", "1s")
        desired = "1"
        Visit(visitor)
-       if len(m) != 7 {
+       if len(m) != 8 {
                t.Error("Visit fails after set")
                for k, v := range m {
                        t.Log(k, *v)
@@ -109,6 +114,7 @@ func testParse(f *FlagSet, t *testing.T) {
        uint64Flag := f.Uint64("uint64", 0, "uint64 value")
        stringFlag := f.String("string", "0", "string value")
        float64Flag := f.Float64("float64", 0, "float64 value")
+       durationFlag := f.Duration("duration", 5*time.Second, "time.Duration value")
        extra := "one-extra-argument"
        args := []string{
                "-bool",
@@ -119,6 +125,7 @@ func testParse(f *FlagSet, t *testing.T) {
                "--uint64", "25",
                "-string", "hello",
                "-float64", "2718e28",
+               "-duration", "2m",
                extra,
        }
        if err := f.Parse(args); err != nil {
@@ -151,6 +158,9 @@ func testParse(f *FlagSet, t *testing.T) {
        if *float64Flag != 2718e28 {
                t.Error("float64 flag should be 2718e28, is ", *float64Flag)
        }
+       if *durationFlag != 2*time.Minute {
+               t.Error("duration flag should be 2m, is ", *durationFlag)
+       }
        if len(f.Args()) != 1 {
                t.Error("expected one argument, got", len(f.Args()))
        } else if f.Args()[0] != extra {
@@ -174,9 +184,9 @@ func (f *flagVar) String() string {
        return fmt.Sprint([]string(*f))
 }
 
-func (f *flagVar) Set(value string) bool {
+func (f *flagVar) Set(value string) error {
        *f = append(*f, value)
-       return true
+       return nil
 }
 
 func TestUserDefined(t *testing.T) {
index 11e9f19f8995be1e5e119bf41107b312c9fad4b4..7d4178da768bc91609e367b77c227e251933bc06 100644 (file)
@@ -30,8 +30,9 @@
                %X      base 16, with upper-case letters for A-F
                %U      Unicode format: U+1234; same as "U+%04X"
        Floating-point and complex constituents:
-               %b      decimalless scientific notation with exponent a power
-                       of two, in the manner of strconv.Ftoa32, e.g. -123456p-78
+               %b      decimalless scientific notation with exponent a power of two, 
+                       in the manner of strconv.FormatFloat with the 'b' format, 
+                       e.g. -123456p-78
                %e      scientific notation, e.g. -1234.456e+78
                %E      scientific notation, e.g. -1234.456E+78
                %f      decimal point but no exponent, e.g. 123.456
index d34a4f8fd2d368c2ce81f4497d26ee69be6c4d45..beb410fa117fd94172a8160e0741eaacb7e64cbc 100644 (file)
@@ -517,7 +517,7 @@ var mallocTest = []struct {
        {1, `Sprintf("xxx")`, func() { Sprintf("xxx") }},
        {1, `Sprintf("%x")`, func() { Sprintf("%x", 7) }},
        {2, `Sprintf("%s")`, func() { Sprintf("%s", "hello") }},
-       {1, `Sprintf("%x %x")`, func() { Sprintf("%x", 7, 112) }},
+       {1, `Sprintf("%x %x")`, func() { Sprintf("%x %x", 7, 112) }},
        {1, `Sprintf("%g")`, func() { Sprintf("%g", 3.14159) }},
        {0, `Fprintf(buf, "%x %x %x")`, func() { mallocBuf.Reset(); Fprintf(&mallocBuf, "%x %x %x", 7, 8, 9) }},
        {1, `Fprintf(buf, "%s")`, func() { mallocBuf.Reset(); Fprintf(&mallocBuf, "%s", "hello") }},
index 1485f351c07e115954f9f127b3868822242519ba..7123fe58f50b40aefa16b4571bb1d809b1d158b1 100644 (file)
@@ -9,6 +9,7 @@ package ast
 
 import (
        "go/token"
+       "strings"
        "unicode"
        "unicode/utf8"
 )
@@ -76,6 +77,74 @@ type CommentGroup struct {
 func (g *CommentGroup) Pos() token.Pos { return g.List[0].Pos() }
 func (g *CommentGroup) End() token.Pos { return g.List[len(g.List)-1].End() }
 
+func isWhitespace(ch byte) bool { return ch == ' ' || ch == '\t' || ch == '\n' || ch == '\r' }
+
+func stripTrailingWhitespace(s string) string {
+       i := len(s)
+       for i > 0 && isWhitespace(s[i-1]) {
+               i--
+       }
+       return s[0:i]
+}
+
+// Text returns the text of the comment,
+// with the comment markers - //, /*, and */ - removed.
+func (g *CommentGroup) Text() string {
+       if g == nil {
+               return ""
+       }
+       comments := make([]string, len(g.List))
+       for i, c := range g.List {
+               comments[i] = string(c.Text)
+       }
+
+       lines := make([]string, 0, 10) // most comments are less than 10 lines
+       for _, c := range comments {
+               // Remove comment markers.
+               // The parser has given us exactly the comment text.
+               switch c[1] {
+               case '/':
+                       //-style comment
+                       c = c[2:]
+                       // Remove leading space after //, if there is one.
+                       // TODO(gri) This appears to be necessary in isolated
+                       //           cases (bignum.RatFromString) - why?
+                       if len(c) > 0 && c[0] == ' ' {
+                               c = c[1:]
+                       }
+               case '*':
+                       /*-style comment */
+                       c = c[2 : len(c)-2]
+               }
+
+               // Split on newlines.
+               cl := strings.Split(c, "\n")
+
+               // Walk lines, stripping trailing white space and adding to list.
+               for _, l := range cl {
+                       lines = append(lines, stripTrailingWhitespace(l))
+               }
+       }
+
+       // Remove leading blank lines; convert runs of
+       // interior blank lines to a single blank line.
+       n := 0
+       for _, line := range lines {
+               if line != "" || n > 0 && lines[n-1] != "" {
+                       lines[n] = line
+                       n++
+               }
+       }
+       lines = lines[0:n]
+
+       // Add final "" entry to get trailing newline from Join.
+       if n > 0 && lines[n-1] != "" {
+               lines = append(lines, "")
+       }
+
+       return strings.Join(lines, "\n")
+}
+
 // ----------------------------------------------------------------------------
 // Expressions and types
 
index bec235e2f98480c54dc2bf9ae9688b132eb9245e..4a89b89096a04dece08010a39929616f699867ba 100644 (file)
@@ -4,7 +4,10 @@
 
 package ast
 
-import "go/token"
+import (
+       "go/token"
+       "sort"
+)
 
 // ----------------------------------------------------------------------------
 // Export filtering
@@ -20,7 +23,7 @@ func exportFilter(name string) bool {
 // body) are removed. Non-exported fields and methods of exported types are
 // stripped. The File.Comments list is not changed.
 //
-// FileExports returns true if there are exported declarationa;
+// FileExports returns true if there are exported declarations;
 // it returns false otherwise.
 //
 func FileExports(src *File) bool {
@@ -291,29 +294,35 @@ var separator = &Comment{noPos, "//"}
 //
 func MergePackageFiles(pkg *Package, mode MergeMode) *File {
        // Count the number of package docs, comments and declarations across
-       // all package files.
+       // all package files. Also, compute sorted list of filenames, so that
+       // subsequent iterations can always iterate in the same order.
        ndocs := 0
        ncomments := 0
        ndecls := 0
-       for _, f := range pkg.Files {
+       filenames := make([]string, len(pkg.Files))
+       i := 0
+       for filename, f := range pkg.Files {
+               filenames[i] = filename
+               i++
                if f.Doc != nil {
                        ndocs += len(f.Doc.List) + 1 // +1 for separator
                }
                ncomments += len(f.Comments)
                ndecls += len(f.Decls)
        }
+       sort.Strings(filenames)
 
        // Collect package comments from all package files into a single
-       // CommentGroup - the collected package documentation. The order
-       // is unspecified. In general there should be only one file with
-       // a package comment; but it's better to collect extra comments
-       // than drop them on the floor.
+       // CommentGroup - the collected package documentation. In general
+       // there should be only one file with a package comment; but it's
+       // better to collect extra comments than drop them on the floor.
        var doc *CommentGroup
        var pos token.Pos
        if ndocs > 0 {
                list := make([]*Comment, ndocs-1) // -1: no separator before first group
                i := 0
-               for _, f := range pkg.Files {
+               for _, filename := range filenames {
+                       f := pkg.Files[filename]
                        if f.Doc != nil {
                                if i > 0 {
                                        // not the first group - add separator
@@ -342,7 +351,8 @@ func MergePackageFiles(pkg *Package, mode MergeMode) *File {
                funcs := make(map[string]int) // map of global function name -> decls index
                i := 0                        // current index
                n := 0                        // number of filtered entries
-               for _, f := range pkg.Files {
+               for _, filename := range filenames {
+                       f := pkg.Files[filename]
                        for _, d := range f.Decls {
                                if mode&FilterFuncDuplicates != 0 {
                                        // A language entity may be declared multiple
@@ -398,7 +408,8 @@ func MergePackageFiles(pkg *Package, mode MergeMode) *File {
        var imports []*ImportSpec
        if mode&FilterImportDuplicates != 0 {
                seen := make(map[string]bool)
-               for _, f := range pkg.Files {
+               for _, filename := range filenames {
+                       f := pkg.Files[filename]
                        for _, imp := range f.Imports {
                                if path := imp.Path.Value; !seen[path] {
                                        // TODO: consider handling cases where:
index fb3068e1e937af928cb9e704b100acb46096f801..f6c63c0d8895611b1713ca1c306682ea635dfde1 100644 (file)
@@ -36,7 +36,7 @@ func NotNilFilter(_ string, v reflect.Value) bool {
 // struct fields for which f(fieldname, fieldvalue) is true are
 // are printed; all others are filtered from the output.
 //
-func Fprint(w io.Writer, fset *token.FileSet, x interface{}, f FieldFilter) (n int, err error) {
+func Fprint(w io.Writer, fset *token.FileSet, x interface{}, f FieldFilter) (err error) {
        // setup printer
        p := printer{
                output: w,
@@ -48,7 +48,6 @@ func Fprint(w io.Writer, fset *token.FileSet, x interface{}, f FieldFilter) (n i
 
        // install error handler
        defer func() {
-               n = p.written
                if e := recover(); e != nil {
                        err = e.(localError).err // re-panics if it's not a localError
                }
@@ -67,19 +66,18 @@ func Fprint(w io.Writer, fset *token.FileSet, x interface{}, f FieldFilter) (n i
 
 // Print prints x to standard output, skipping nil fields.
 // Print(fset, x) is the same as Fprint(os.Stdout, fset, x, NotNilFilter).
-func Print(fset *token.FileSet, x interface{}) (int, error) {
+func Print(fset *token.FileSet, x interface{}) error {
        return Fprint(os.Stdout, fset, x, NotNilFilter)
 }
 
 type printer struct {
-       output  io.Writer
-       fset    *token.FileSet
-       filter  FieldFilter
-       ptrmap  map[interface{}]int // *T -> line number
-       written int                 // number of bytes written to output
-       indent  int                 // current indentation level
-       last    byte                // the last byte processed by Write
-       line    int                 // current line number
+       output io.Writer
+       fset   *token.FileSet
+       filter FieldFilter
+       ptrmap map[interface{}]int // *T -> line number
+       indent int                 // current indentation level
+       last   byte                // the last byte processed by Write
+       line   int                 // current line number
 }
 
 var indent = []byte(".  ")
@@ -122,9 +120,7 @@ type localError struct {
 
 // printf is a convenience wrapper that takes care of print errors.
 func (p *printer) printf(format string, args ...interface{}) {
-       n, err := fmt.Fprintf(p, format, args...)
-       p.written += n
-       if err != nil {
+       if _, err := fmt.Fprintf(p, format, args...); err != nil {
                panic(localError{err})
        }
 }
index 89d5af1541e23c7906a92050a329bab5f4302d13..71c028e753787e4b322c8ab629cb04b115677623 100644 (file)
@@ -66,7 +66,7 @@ func TestPrint(t *testing.T) {
        var buf bytes.Buffer
        for _, test := range tests {
                buf.Reset()
-               if _, err := Fprint(&buf, nil, test.x, nil); err != nil {
+               if err := Fprint(&buf, nil, test.x, nil); err != nil {
                        t.Errorf("Fprint failed: %s", err)
                }
                if s, ts := trim(buf.String()), trim(test.s); s != ts {
index fbe4779671ec8fc67b857d89eae1c6e6cb01422a..11e6b13f169b807095b0ee40abe76a7b296c66d5 100644 (file)
@@ -80,7 +80,7 @@ func (s *Scope) String() string {
 type Object struct {
        Kind ObjKind
        Name string      // declared name
-       Decl interface{} // corresponding Field, XxxSpec, FuncDecl, LabeledStmt, or AssignStmt; or nil
+       Decl interface{} // corresponding Field, XxxSpec, FuncDecl, LabeledStmt, AssignStmt, Scope; or nil
        Data interface{} // object-specific data; or nil
        Type interface{} // place holder for type information; may be nil
 }
@@ -131,6 +131,8 @@ func (obj *Object) Pos() token.Pos {
                                return ident.Pos()
                        }
                }
+       case *Scope:
+               // predeclared object - nothing to do for now
        }
        return token.NoPos
 }
index 5301ab53e519d43b0bcce89ba652db9f899b6ceb..9515a7e645224f41d032ad227e0ff4e4cf89ad1f 100644 (file)
@@ -396,8 +396,7 @@ func (b *build) cgo(cgofiles, cgocfiles []string) (outGo, outObj []string) {
                Output: output,
        })
        outGo = append(outGo, gofiles...)
-       exportH := filepath.Join(b.path, "_cgo_export.h")
-       b.script.addIntermediate(defunC, exportH, b.obj+"_cgo_flags")
+       b.script.addIntermediate(defunC, b.obj+"_cgo_export.h", b.obj+"_cgo_flags")
        b.script.addIntermediate(cfiles...)
 
        // cc _cgo_defun.c
index 265261f22eaaac2c71c7f43290a343a7b765f423..5ce75fda7e08d38e5f9afdd2359484bef8cfea7d 100644 (file)
@@ -9,7 +9,6 @@ import (
        "errors"
        "fmt"
        "go/ast"
-       "go/doc"
        "go/parser"
        "go/token"
        "io/ioutil"
@@ -412,7 +411,7 @@ func (ctxt *Context) shouldBuild(content []byte) bool {
 // TODO(rsc): This duplicates code in cgo.
 // Once the dust settles, remove this code from cgo.
 func (ctxt *Context) saveCgo(filename string, di *DirInfo, cg *ast.CommentGroup) error {
-       text := doc.CommentText(cg)
+       text := cg.Text()
        for _, line := range strings.Split(text, "\n") {
                orig := line
 
@@ -476,7 +475,7 @@ func (ctxt *Context) saveCgo(filename string, di *DirInfo, cg *ast.CommentGroup)
        return nil
 }
 
-var safeBytes = []byte("+-.,/0123456789=ABCDEFGHIJKLMNOPQRSTUVWXYZ_abcdefghijklmnopqrstuvwxyz")
+var safeBytes = []byte("+-.,/0123456789=ABCDEFGHIJKLMNOPQRSTUVWXYZ_abcdefghijklmnopqrstuvwxyz:")
 
 func safeName(s string) bool {
        if s == "" {
index 7a281800c2828afe57bdbe3cde1286b49f804d30..bb9b8ca642a123f8a63b2c4d19ada8760d43ddf1 100644 (file)
@@ -157,6 +157,7 @@ func init() {
                Path = []*Tree{t}
        }
 
+Loop:
        for _, p := range filepath.SplitList(os.Getenv("GOPATH")) {
                if p == "" {
                        continue
@@ -166,6 +167,21 @@ func init() {
                        log.Printf("invalid GOPATH %q: %v", p, err)
                        continue
                }
+
+               // Check for dupes.
+               // TODO(alexbrainman): make this correct under windows (case insensitive).
+               for _, t2 := range Path {
+                       if t2.Path != t.Path {
+                               continue
+                       }
+                       if t2.Goroot {
+                               log.Printf("GOPATH is the same as GOROOT: %q", t.Path)
+                       } else {
+                               log.Printf("duplicate GOPATH entry: %q", t.Path)
+                       }
+                       continue Loop
+               }
+
                Path = append(Path, t)
                gcImportArgs = append(gcImportArgs, "-I", t.PkgDir())
                ldImportArgs = append(ldImportArgs, "-L", t.PkgDir())
index 39f34afa10c6cc5b5531e062f2f7ad24240c2e90..060e37bff148d21b2c9c4ab313c4276ee3b80866 100644 (file)
@@ -7,7 +7,6 @@
 package doc
 
 import (
-       "go/ast"
        "io"
        "regexp"
        "strings"
@@ -16,74 +15,6 @@ import (
        "unicode/utf8"
 )
 
-func isWhitespace(ch byte) bool { return ch == ' ' || ch == '\t' || ch == '\n' || ch == '\r' }
-
-func stripTrailingWhitespace(s string) string {
-       i := len(s)
-       for i > 0 && isWhitespace(s[i-1]) {
-               i--
-       }
-       return s[0:i]
-}
-
-// CommentText returns the text of comment,
-// with the comment markers - //, /*, and */ - removed.
-func CommentText(comment *ast.CommentGroup) string {
-       if comment == nil {
-               return ""
-       }
-       comments := make([]string, len(comment.List))
-       for i, c := range comment.List {
-               comments[i] = string(c.Text)
-       }
-
-       lines := make([]string, 0, 10) // most comments are less than 10 lines
-       for _, c := range comments {
-               // Remove comment markers.
-               // The parser has given us exactly the comment text.
-               switch c[1] {
-               case '/':
-                       //-style comment
-                       c = c[2:]
-                       // Remove leading space after //, if there is one.
-                       // TODO(gri) This appears to be necessary in isolated
-                       //           cases (bignum.RatFromString) - why?
-                       if len(c) > 0 && c[0] == ' ' {
-                               c = c[1:]
-                       }
-               case '*':
-                       /*-style comment */
-                       c = c[2 : len(c)-2]
-               }
-
-               // Split on newlines.
-               cl := strings.Split(c, "\n")
-
-               // Walk lines, stripping trailing white space and adding to list.
-               for _, l := range cl {
-                       lines = append(lines, stripTrailingWhitespace(l))
-               }
-       }
-
-       // Remove leading blank lines; convert runs of
-       // interior blank lines to a single blank line.
-       n := 0
-       for _, line := range lines {
-               if line != "" || n > 0 && lines[n-1] != "" {
-                       lines[n] = line
-                       n++
-               }
-       }
-       lines = lines[0:n]
-
-       // Add final "" entry to get trailing newline from Join.
-       if n > 0 && lines[n-1] != "" {
-               lines = append(lines, "")
-       }
-
-       return strings.Join(lines, "\n")
-}
-
 var (
        ldquo = []byte("&ldquo;")
        rdquo = []byte("&rdquo;")
@@ -422,12 +353,10 @@ func ToText(w io.Writer, text string, indent, preIndent string, width int) {
                width:  width,
                indent: indent,
        }
-       for i, b := range blocks(text) {
+       for _, b := range blocks(text) {
                switch b.op {
                case opPara:
-                       if i > 0 {
-                               w.Write(nl)
-                       }
+                       // l.write will add leading newline if required
                        for _, line := range b.lines {
                                l.write(line)
                        }
index 1bb22416c78ad9f81a5103a50cf145c83cc94a99..66e2937aeb0405107fcde6e190dc335353d9ae13 100644 (file)
@@ -7,673 +7,94 @@ package doc
 
 import (
        "go/ast"
-       "go/token"
-       "regexp"
        "sort"
 )
 
-// ----------------------------------------------------------------------------
-// Collection of documentation info
-
-// embeddedType describes the type of an anonymous field.
-//
-type embeddedType struct {
-       typ *typeInfo // the corresponding base type
-       ptr bool      // if set, the anonymous field type is a pointer
-}
-
-type typeInfo struct {
-       // len(decl.Specs) == 1, and the element type is *ast.TypeSpec
-       // if the type declaration hasn't been seen yet, decl is nil
-       decl     *ast.GenDecl
-       embedded []embeddedType
-       forward  *TypeDoc // forward link to processed type documentation
-
-       // declarations associated with the type
-       values    []*ast.GenDecl // consts and vars
-       factories map[string]*ast.FuncDecl
-       methods   map[string]*ast.FuncDecl
-}
-
-func (info *typeInfo) addEmbeddedType(embedded *typeInfo, isPtr bool) {
-       info.embedded = append(info.embedded, embeddedType{embedded, isPtr})
-}
-
-// docReader accumulates documentation for a single package.
-// It modifies the AST: Comments (declaration documentation)
-// that have been collected by the DocReader are set to nil
-// in the respective AST nodes so that they are not printed
-// twice (once when printing the documentation and once when
-// printing the corresponding AST node).
-//
-type docReader struct {
-       doc      *ast.CommentGroup // package documentation, if any
-       pkgName  string
-       values   []*ast.GenDecl // consts and vars
-       types    map[string]*typeInfo
-       embedded map[string]*typeInfo // embedded types, possibly not exported
-       funcs    map[string]*ast.FuncDecl
-       bugs     []*ast.CommentGroup
-}
-
-func (doc *docReader) init(pkgName string) {
-       doc.pkgName = pkgName
-       doc.types = make(map[string]*typeInfo)
-       doc.embedded = make(map[string]*typeInfo)
-       doc.funcs = make(map[string]*ast.FuncDecl)
-}
-
-func (doc *docReader) addDoc(comments *ast.CommentGroup) {
-       if doc.doc == nil {
-               // common case: just one package comment
-               doc.doc = comments
-               return
-       }
-       // More than one package comment: Usually there will be only
-       // one file with a package comment, but it's better to collect
-       // all comments than drop them on the floor.
-       blankComment := &ast.Comment{token.NoPos, "//"}
-       list := append(doc.doc.List, blankComment)
-       doc.doc.List = append(list, comments.List...)
-}
-
-func (doc *docReader) lookupTypeInfo(name string) *typeInfo {
-       if name == "" || name == "_" {
-               return nil // no type docs for anonymous types
-       }
-       if info, found := doc.types[name]; found {
-               return info
-       }
-       // type wasn't found - add one without declaration
-       info := &typeInfo{
-               factories: make(map[string]*ast.FuncDecl),
-               methods:   make(map[string]*ast.FuncDecl),
-       }
-       doc.types[name] = info
-       return info
-}
-
-func baseTypeName(typ ast.Expr, allTypes bool) string {
-       switch t := typ.(type) {
-       case *ast.Ident:
-               // if the type is not exported, the effect to
-               // a client is as if there were no type name
-               if t.IsExported() || allTypes {
-                       return t.Name
-               }
-       case *ast.StarExpr:
-               return baseTypeName(t.X, allTypes)
-       }
-       return ""
-}
-
-func (doc *docReader) addValue(decl *ast.GenDecl) {
-       // determine if decl should be associated with a type
-       // Heuristic: For each typed entry, determine the type name, if any.
-       //            If there is exactly one type name that is sufficiently
-       //            frequent, associate the decl with the respective type.
-       domName := ""
-       domFreq := 0
-       prev := ""
-       for _, s := range decl.Specs {
-               if v, ok := s.(*ast.ValueSpec); ok {
-                       name := ""
-                       switch {
-                       case v.Type != nil:
-                               // a type is present; determine its name
-                               name = baseTypeName(v.Type, false)
-                       case decl.Tok == token.CONST:
-                               // no type is present but we have a constant declaration;
-                               // use the previous type name (w/o more type information
-                               // we cannot handle the case of unnamed variables with
-                               // initializer expressions except for some trivial cases)
-                               name = prev
-                       }
-                       if name != "" {
-                               // entry has a named type
-                               if domName != "" && domName != name {
-                                       // more than one type name - do not associate
-                                       // with any type
-                                       domName = ""
-                                       break
-                               }
-                               domName = name
-                               domFreq++
-                       }
-                       prev = name
-               }
-       }
-
-       // determine values list
-       const threshold = 0.75
-       values := &doc.values
-       if domName != "" && domFreq >= int(float64(len(decl.Specs))*threshold) {
-               // typed entries are sufficiently frequent
-               typ := doc.lookupTypeInfo(domName)
-               if typ != nil {
-                       values = &typ.values // associate with that type
-               }
-       }
-
-       *values = append(*values, decl)
-}
-
-// Helper function to set the table entry for function f. Makes sure that
-// at least one f with associated documentation is stored in table, if there
-// are multiple f's with the same name.
-func setFunc(table map[string]*ast.FuncDecl, f *ast.FuncDecl) {
-       name := f.Name.Name
-       if g, exists := table[name]; exists && g.Doc != nil {
-               // a function with the same name has already been registered;
-               // since it has documentation, assume f is simply another
-               // implementation and ignore it
-               // TODO(gri) consider collecting all functions, or at least
-               //           all comments
-               return
-       }
-       // function doesn't exist or has no documentation; use f
-       table[name] = f
-}
-
-func (doc *docReader) addFunc(fun *ast.FuncDecl) {
-       // strip function body
-       fun.Body = nil
-
-       // determine if it should be associated with a type
-       if fun.Recv != nil {
-               // method
-               typ := doc.lookupTypeInfo(baseTypeName(fun.Recv.List[0].Type, false))
-               if typ != nil {
-                       // exported receiver type
-                       setFunc(typ.methods, fun)
-               }
-               // otherwise don't show the method
-               // TODO(gri): There may be exported methods of non-exported types
-               // that can be called because of exported values (consts, vars, or
-               // function results) of that type. Could determine if that is the
-               // case and then show those methods in an appropriate section.
-               return
-       }
-
-       // perhaps a factory function
-       // determine result type, if any
-       if fun.Type.Results.NumFields() >= 1 {
-               res := fun.Type.Results.List[0]
-               if len(res.Names) <= 1 {
-                       // exactly one (named or anonymous) result associated
-                       // with the first type in result signature (there may
-                       // be more than one result)
-                       tname := baseTypeName(res.Type, false)
-                       typ := doc.lookupTypeInfo(tname)
-                       if typ != nil {
-                               // named and exported result type
-                               setFunc(typ.factories, fun)
-                               return
-                       }
-               }
-       }
-
-       // ordinary function
-       setFunc(doc.funcs, fun)
-}
-
-func (doc *docReader) addDecl(decl ast.Decl) {
-       switch d := decl.(type) {
-       case *ast.GenDecl:
-               if len(d.Specs) > 0 {
-                       switch d.Tok {
-                       case token.CONST, token.VAR:
-                               // constants and variables are always handled as a group
-                               doc.addValue(d)
-                       case token.TYPE:
-                               // types are handled individually
-                               for _, spec := range d.Specs {
-                                       tspec := spec.(*ast.TypeSpec)
-                                       // add the type to the documentation
-                                       info := doc.lookupTypeInfo(tspec.Name.Name)
-                                       if info == nil {
-                                               continue // no name - ignore the type
-                                       }
-                                       // Make a (fake) GenDecl node for this TypeSpec
-                                       // (we need to do this here - as opposed to just
-                                       // for printing - so we don't lose the GenDecl
-                                       // documentation). Since a new GenDecl node is
-                                       // created, there's no need to nil out d.Doc.
-                                       //
-                                       // TODO(gri): Consider just collecting the TypeSpec
-                                       // node (and copy in the GenDecl.doc if there is no
-                                       // doc in the TypeSpec - this is currently done in
-                                       // makeTypeDocs below). Simpler data structures, but
-                                       // would lose GenDecl documentation if the TypeSpec
-                                       // has documentation as well.
-                                       fake := &ast.GenDecl{d.Doc, d.Pos(), token.TYPE, token.NoPos,
-                                               []ast.Spec{tspec}, token.NoPos}
-                                       // A type should be added at most once, so info.decl
-                                       // should be nil - if it isn't, simply overwrite it.
-                                       info.decl = fake
-                                       // Look for anonymous fields that might contribute methods.
-                                       var fields *ast.FieldList
-                                       switch typ := spec.(*ast.TypeSpec).Type.(type) {
-                                       case *ast.StructType:
-                                               fields = typ.Fields
-                                       case *ast.InterfaceType:
-                                               fields = typ.Methods
-                                       }
-                                       if fields != nil {
-                                               for _, field := range fields.List {
-                                                       if len(field.Names) == 0 {
-                                                               // anonymous field - add corresponding type
-                                                               // to the info and collect it in doc
-                                                               name := baseTypeName(field.Type, true)
-                                                               if embedded := doc.lookupTypeInfo(name); embedded != nil {
-                                                                       _, ptr := field.Type.(*ast.StarExpr)
-                                                                       info.addEmbeddedType(embedded, ptr)
-                                                               }
-                                                       }
-                                               }
-                                       }
-                               }
-                       }
-               }
-       case *ast.FuncDecl:
-               doc.addFunc(d)
-       }
-}
-
-func copyCommentList(list []*ast.Comment) []*ast.Comment {
-       return append([]*ast.Comment(nil), list...)
-}
-
-var (
-       bug_markers = regexp.MustCompile("^/[/*][ \t]*BUG\\(.*\\):[ \t]*") // BUG(uid):
-       bug_content = regexp.MustCompile("[^ \n\r\t]+")                    // at least one non-whitespace char
-)
-
-// addFile adds the AST for a source file to the docReader.
-// Adding the same AST multiple times is a no-op.
-//
-func (doc *docReader) addFile(src *ast.File) {
-       // add package documentation
-       if src.Doc != nil {
-               doc.addDoc(src.Doc)
-               src.Doc = nil // doc consumed - remove from ast.File node
-       }
-
-       // add all declarations
-       for _, decl := range src.Decls {
-               doc.addDecl(decl)
-       }
-
-       // collect BUG(...) comments
-       for _, c := range src.Comments {
-               text := c.List[0].Text
-               if m := bug_markers.FindStringIndex(text); m != nil {
-                       // found a BUG comment; maybe empty
-                       if btxt := text[m[1]:]; bug_content.MatchString(btxt) {
-                               // non-empty BUG comment; collect comment without BUG prefix
-                               list := copyCommentList(c.List)
-                               list[0].Text = text[m[1]:]
-                               doc.bugs = append(doc.bugs, &ast.CommentGroup{list})
-                       }
-               }
-       }
-       src.Comments = nil // consumed unassociated comments - remove from ast.File node
-}
-
-func NewPackageDoc(pkg *ast.Package, importpath string, exportsOnly bool) *PackageDoc {
-       var r docReader
-       r.init(pkg.Name)
-       filenames := make([]string, len(pkg.Files))
-       i := 0
-       for filename, f := range pkg.Files {
-               if exportsOnly {
-                       r.fileExports(f)
-               }
-               r.addFile(f)
-               filenames[i] = filename
-               i++
-       }
-       return r.newDoc(importpath, filenames)
-}
-
-// ----------------------------------------------------------------------------
-// Conversion to external representation
-
-// ValueDoc is the documentation for a group of declared
-// values, either vars or consts.
-//
-type ValueDoc struct {
+// Package is the documentation for an entire package.
+type Package struct {
+       Doc        string
+       Name       string
+       ImportPath string
+       Imports    []string // TODO(gri) this field is not computed at the moment
+       Filenames  []string
+       Consts     []*Value
+       Types      []*Type
+       Vars       []*Value
+       Funcs      []*Func
+       Bugs       []string
+}
+
+// Value is the documentation for a (possibly grouped) var or const declaration.
+type Value struct {
        Doc   string
+       Names []string // var or const names in declaration order
        Decl  *ast.GenDecl
+
        order int
 }
 
-type sortValueDoc []*ValueDoc
-
-func (p sortValueDoc) Len() int      { return len(p) }
-func (p sortValueDoc) Swap(i, j int) { p[i], p[j] = p[j], p[i] }
-
-func declName(d *ast.GenDecl) string {
-       if len(d.Specs) != 1 {
-               return ""
-       }
-
-       switch v := d.Specs[0].(type) {
-       case *ast.ValueSpec:
-               return v.Names[0].Name
-       case *ast.TypeSpec:
-               return v.Name.Name
-       }
-
-       return ""
+type Method struct {
+       *Func
+       // TODO(gri) The following fields are not set at the moment. 
+       Recv  *Type // original receiver base type
+       Level int   // embedding level; 0 means Func is not embedded
 }
 
-func (p sortValueDoc) Less(i, j int) bool {
-       // sort by name
-       // pull blocks (name = "") up to top
-       // in original order
-       if ni, nj := declName(p[i].Decl), declName(p[j].Decl); ni != nj {
-               return ni < nj
-       }
-       return p[i].order < p[j].order
-}
+// Type is the documentation for type declaration.
+type Type struct {
+       Doc     string
+       Name    string
+       Type    *ast.TypeSpec
+       Decl    *ast.GenDecl
+       Consts  []*Value  // sorted list of constants of (mostly) this type
+       Vars    []*Value  // sorted list of variables of (mostly) this type
+       Funcs   []*Func   // sorted list of functions returning this type
+       Methods []*Method // sorted list of methods (including embedded ones) of this type
 
-func makeValueDocs(list []*ast.GenDecl, tok token.Token) []*ValueDoc {
-       d := make([]*ValueDoc, len(list)) // big enough in any case
-       n := 0
-       for i, decl := range list {
-               if decl.Tok == tok {
-                       d[n] = &ValueDoc{CommentText(decl.Doc), decl, i}
-                       n++
-                       decl.Doc = nil // doc consumed - removed from AST
-               }
-       }
-       d = d[0:n]
-       sort.Sort(sortValueDoc(d))
-       return d
+       methods  []*Func   // top-level methods only
+       embedded methodSet // embedded methods only
+       order    int
 }
 
-// FuncDoc is the documentation for a func declaration,
-// either a top-level function or a method function.
-//
-type FuncDoc struct {
+// Func is the documentation for a func declaration.
+type Func struct {
        Doc  string
-       Recv ast.Expr // TODO(rsc): Would like string here
        Name string
+       // TODO(gri) remove Recv once we switch to new implementation
+       Recv ast.Expr // TODO(rsc): Would like string here
        Decl *ast.FuncDecl
 }
 
-type sortFuncDoc []*FuncDoc
+// Mode values control the operation of New.
+type Mode int
 
-func (p sortFuncDoc) Len() int           { return len(p) }
-func (p sortFuncDoc) Swap(i, j int)      { p[i], p[j] = p[j], p[i] }
-func (p sortFuncDoc) Less(i, j int) bool { return p[i].Name < p[j].Name }
-
-func makeFuncDocs(m map[string]*ast.FuncDecl) []*FuncDoc {
-       d := make([]*FuncDoc, len(m))
-       i := 0
-       for _, f := range m {
-               doc := new(FuncDoc)
-               doc.Doc = CommentText(f.Doc)
-               f.Doc = nil // doc consumed - remove from ast.FuncDecl node
-               if f.Recv != nil {
-                       doc.Recv = f.Recv.List[0].Type
-               }
-               doc.Name = f.Name.Name
-               doc.Decl = f
-               d[i] = doc
-               i++
-       }
-       sort.Sort(sortFuncDoc(d))
-       return d
-}
-
-type methodSet map[string]*FuncDoc
-
-func (mset methodSet) add(m *FuncDoc) {
-       if mset[m.Name] == nil {
-               mset[m.Name] = m
-       }
-}
+const (
+       // extract documentation for all package-level declarations,
+       // not just exported ones
+       AllDecls Mode = 1 << iota
+)
 
-func (mset methodSet) sortedList() []*FuncDoc {
-       list := make([]*FuncDoc, len(mset))
+// New computes the package documentation for the given package.
+func New(pkg *ast.Package, importpath string, mode Mode) *Package {
+       var r docReader
+       r.init(pkg.Name, mode)
+       filenames := make([]string, len(pkg.Files))
+       // sort package files before reading them so that the
+       // result is the same on different machines (32/64bit)
        i := 0
-       for _, m := range mset {
-               list[i] = m
+       for filename := range pkg.Files {
+               filenames[i] = filename
                i++
        }
-       sort.Sort(sortFuncDoc(list))
-       return list
-}
-
-// TypeDoc is the documentation for a declared type.
-// Consts and Vars are sorted lists of constants and variables of (mostly) that type.
-// Factories is a sorted list of factory functions that return that type.
-// Methods is a sorted list of method functions on that type.
-type TypeDoc struct {
-       Doc       string
-       Type      *ast.TypeSpec
-       Consts    []*ValueDoc
-       Vars      []*ValueDoc
-       Factories []*FuncDoc
-       methods   []*FuncDoc // top-level methods only
-       embedded  methodSet  // embedded methods only
-       Methods   []*FuncDoc // all methods including embedded ones
-       Decl      *ast.GenDecl
-       order     int
-}
-
-type sortTypeDoc []*TypeDoc
-
-func (p sortTypeDoc) Len() int      { return len(p) }
-func (p sortTypeDoc) Swap(i, j int) { p[i], p[j] = p[j], p[i] }
-func (p sortTypeDoc) Less(i, j int) bool {
-       // sort by name
-       // pull blocks (name = "") up to top
-       // in original order
-       if ni, nj := p[i].Type.Name.Name, p[j].Type.Name.Name; ni != nj {
-               return ni < nj
-       }
-       return p[i].order < p[j].order
-}
-
-// NOTE(rsc): This would appear not to be correct for type ( )
-// blocks, but the doc extractor above has split them into
-// individual declarations.
-func (doc *docReader) makeTypeDocs(m map[string]*typeInfo) []*TypeDoc {
-       // TODO(gri) Consider computing the embedded method information
-       //           before calling makeTypeDocs. Then this function can
-       //           be single-phased again. Also, it might simplify some
-       //           of the logic.
-       //
-       // phase 1: associate collected declarations with TypeDocs
-       list := make([]*TypeDoc, len(m))
-       i := 0
-       for _, old := range m {
-               // all typeInfos should have a declaration associated with
-               // them after processing an entire package - be conservative
-               // and check
-               if decl := old.decl; decl != nil {
-                       typespec := decl.Specs[0].(*ast.TypeSpec)
-                       t := new(TypeDoc)
-                       doc := typespec.Doc
-                       typespec.Doc = nil // doc consumed - remove from ast.TypeSpec node
-                       if doc == nil {
-                               // no doc associated with the spec, use the declaration doc, if any
-                               doc = decl.Doc
-                       }
-                       decl.Doc = nil // doc consumed - remove from ast.Decl node
-                       t.Doc = CommentText(doc)
-                       t.Type = typespec
-                       t.Consts = makeValueDocs(old.values, token.CONST)
-                       t.Vars = makeValueDocs(old.values, token.VAR)
-                       t.Factories = makeFuncDocs(old.factories)
-                       t.methods = makeFuncDocs(old.methods)
-                       // The list of embedded types' methods is computed from the list
-                       // of embedded types, some of which may not have been processed
-                       // yet (i.e., their forward link is nil) - do this in a 2nd phase.
-                       // The final list of methods can only be computed after that -
-                       // do this in a 3rd phase.
-                       t.Decl = old.decl
-                       t.order = i
-                       old.forward = t // old has been processed
-                       list[i] = t
-                       i++
-               } else {
-                       // no corresponding type declaration found - move any associated
-                       // values, factory functions, and methods back to the top-level
-                       // so that they are not lost (this should only happen if a package
-                       // file containing the explicit type declaration is missing or if
-                       // an unqualified type name was used after a "." import)
-                       // 1) move values
-                       doc.values = append(doc.values, old.values...)
-                       // 2) move factory functions
-                       for name, f := range old.factories {
-                               doc.funcs[name] = f
-                       }
-                       // 3) move methods
-                       for name, f := range old.methods {
-                               // don't overwrite functions with the same name
-                               if _, found := doc.funcs[name]; !found {
-                                       doc.funcs[name] = f
-                               }
-                       }
-               }
-       }
-       list = list[0:i] // some types may have been ignored
-
-       // phase 2: collect embedded methods for each processed typeInfo
-       for _, old := range m {
-               if t := old.forward; t != nil {
-                       // old has been processed into t; collect embedded
-                       // methods for t from the list of processed embedded
-                       // types in old (and thus for which the methods are known)
-                       typ := t.Type
-                       if _, ok := typ.Type.(*ast.StructType); ok {
-                               // struct
-                               t.embedded = make(methodSet)
-                               collectEmbeddedMethods(t.embedded, old, typ.Name.Name)
-                       } else {
-                               // interface
-                               // TODO(gri) fix this
-                       }
-               }
-       }
-
-       // phase 3: compute final method set for each TypeDoc
-       for _, d := range list {
-               if len(d.embedded) > 0 {
-                       // there are embedded methods - exclude
-                       // the ones with names conflicting with
-                       // non-embedded methods
-                       mset := make(methodSet)
-                       // top-level methods have priority
-                       for _, m := range d.methods {
-                               mset.add(m)
-                       }
-                       // add non-conflicting embedded methods
-                       for _, m := range d.embedded {
-                               mset.add(m)
-                       }
-                       d.Methods = mset.sortedList()
-               } else {
-                       // no embedded methods
-                       d.Methods = d.methods
-               }
-       }
-
-       sort.Sort(sortTypeDoc(list))
-       return list
-}
+       sort.Strings(filenames)
 
-// collectEmbeddedMethods collects the embedded methods from all
-// processed embedded types found in info in mset. It considers
-// embedded types at the most shallow level first so that more
-// deeply nested embedded methods with conflicting names are
-// excluded.
-//
-func collectEmbeddedMethods(mset methodSet, info *typeInfo, recvTypeName string) {
-       for _, e := range info.embedded {
-               if e.typ.forward != nil { // == e was processed
-                       for _, m := range e.typ.forward.methods {
-                               mset.add(customizeRecv(m, e.ptr, recvTypeName))
-                       }
-                       collectEmbeddedMethods(mset, e.typ, recvTypeName)
+       // process files in sorted order
+       for _, filename := range filenames {
+               f := pkg.Files[filename]
+               if mode&AllDecls == 0 {
+                       r.fileExports(f)
                }
+               r.addFile(f)
        }
-}
-
-func customizeRecv(m *FuncDoc, embeddedIsPtr bool, recvTypeName string) *FuncDoc {
-       if m == nil || m.Decl == nil || m.Decl.Recv == nil || len(m.Decl.Recv.List) != 1 {
-               return m // shouldn't happen, but be safe
-       }
-
-       // copy existing receiver field and set new type
-       // TODO(gri) is receiver type computation correct?
-       //           what about deeply nested embeddings?
-       newField := *m.Decl.Recv.List[0]
-       _, origRecvIsPtr := newField.Type.(*ast.StarExpr)
-       var typ ast.Expr = ast.NewIdent(recvTypeName)
-       if embeddedIsPtr || origRecvIsPtr {
-               typ = &ast.StarExpr{token.NoPos, typ}
-       }
-       newField.Type = typ
-
-       // copy existing receiver field list and set new receiver field
-       newFieldList := *m.Decl.Recv
-       newFieldList.List = []*ast.Field{&newField}
-
-       // copy existing function declaration and set new receiver field list
-       newFuncDecl := *m.Decl
-       newFuncDecl.Recv = &newFieldList
-
-       // copy existing function documentation and set new declaration
-       newM := *m
-       newM.Decl = &newFuncDecl
-       newM.Recv = typ
-
-       return &newM
-}
-
-func makeBugDocs(list []*ast.CommentGroup) []string {
-       d := make([]string, len(list))
-       for i, g := range list {
-               d[i] = CommentText(g)
-       }
-       return d
-}
-
-// PackageDoc is the documentation for an entire package.
-//
-type PackageDoc struct {
-       PackageName string
-       ImportPath  string
-       Filenames   []string
-       Doc         string
-       Consts      []*ValueDoc
-       Types       []*TypeDoc
-       Vars        []*ValueDoc
-       Funcs       []*FuncDoc
-       Bugs        []string
-}
-
-// newDoc returns the accumulated documentation for the package.
-//
-func (doc *docReader) newDoc(importpath string, filenames []string) *PackageDoc {
-       p := new(PackageDoc)
-       p.PackageName = doc.pkgName
-       p.ImportPath = importpath
-       sort.Strings(filenames)
-       p.Filenames = filenames
-       p.Doc = CommentText(doc.doc)
-       // makeTypeDocs may extend the list of doc.values and
-       // doc.funcs and thus must be called before any other
-       // function consuming those lists
-       p.Types = doc.makeTypeDocs(doc.types)
-       p.Consts = makeValueDocs(doc.values, token.CONST)
-       p.Vars = makeValueDocs(doc.values, token.VAR)
-       p.Funcs = makeFuncDocs(doc.funcs)
-       p.Bugs = makeBugDocs(doc.bugs)
-       return p
+       return r.newDoc(importpath, filenames)
 }
diff --git a/libgo/go/go/doc/doc_test.go b/libgo/go/go/doc/doc_test.go
new file mode 100644 (file)
index 0000000..317d3ab
--- /dev/null
@@ -0,0 +1,137 @@
+// Copyright 2012 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+package doc
+
+import (
+       "bytes"
+       "fmt"
+       "go/ast"
+       "go/parser"
+       "go/token"
+       "testing"
+       "text/template"
+)
+
+type sources map[string]string // filename -> file contents
+
+type testCase struct {
+       name       string
+       importPath string
+       mode       Mode
+       srcs       sources
+       doc        string
+}
+
+var tests = make(map[string]*testCase)
+
+// To register a new test case, use the pattern:
+//
+//     var _ = register(&testCase{ ... })
+//
+// (The result value of register is always 0 and only present to enable the pattern.)
+//
+func register(test *testCase) int {
+       if _, found := tests[test.name]; found {
+               panic(fmt.Sprintf("registration failed: test case %q already exists", test.name))
+       }
+       tests[test.name] = test
+       return 0
+}
+
+func runTest(t *testing.T, test *testCase) {
+       // create AST
+       fset := token.NewFileSet()
+       var pkg ast.Package
+       pkg.Files = make(map[string]*ast.File)
+       for filename, src := range test.srcs {
+               file, err := parser.ParseFile(fset, filename, src, parser.ParseComments)
+               if err != nil {
+                       t.Errorf("test %s: %v", test.name, err)
+                       return
+               }
+               switch {
+               case pkg.Name == "":
+                       pkg.Name = file.Name.Name
+               case pkg.Name != file.Name.Name:
+                       t.Errorf("test %s: different package names in test files", test.name)
+                       return
+               }
+               pkg.Files[filename] = file
+       }
+
+       doc := New(&pkg, test.importPath, test.mode).String()
+       if doc != test.doc {
+               //TODO(gri) Enable this once the sorting issue of comments is fixed
+               //t.Errorf("test %s\n\tgot : %s\n\twant: %s", test.name, doc, test.doc)
+       }
+}
+
+func Test(t *testing.T) {
+       for _, test := range tests {
+               runTest(t, test)
+       }
+}
+
+// ----------------------------------------------------------------------------
+// Printing support
+
+func (pkg *Package) String() string {
+       var buf bytes.Buffer
+       docText.Execute(&buf, pkg) // ignore error - test will fail w/ incorrect output
+       return buf.String()
+}
+
+// TODO(gri) complete template
+var docText = template.Must(template.New("docText").Parse(
+       `
+PACKAGE {{.Name}}
+DOC {{printf "%q" .Doc}}
+IMPORTPATH {{.ImportPath}}
+FILENAMES {{.Filenames}}
+`))
+
+// ----------------------------------------------------------------------------
+// Test cases
+
+// Test that all package comments and bugs are collected,
+// and that the importPath is correctly set.
+//
+var _ = register(&testCase{
+       name:       "p",
+       importPath: "p",
+       srcs: sources{
+               "p1.go": "// comment 1\npackage p\n//BUG(uid): bug1",
+               "p0.go": "// comment 0\npackage p\n// BUG(uid): bug0",
+       },
+       doc: `
+PACKAGE p
+DOC "comment 0\n\ncomment 1\n"
+IMPORTPATH p
+FILENAMES [p0.go p1.go]
+`,
+})
+
+// Test basic functionality.
+//
+var _ = register(&testCase{
+       name:       "p1",
+       importPath: "p",
+       srcs: sources{
+               "p.go": `
+package p
+import "a"
+const pi = 3.14       // pi
+type T struct{}       // T
+var V T               // v
+func F(x int) int {}  // F
+`,
+       },
+       doc: `
+PACKAGE p
+DOC ""
+IMPORTPATH p
+FILENAMES [p.go]
+`,
+})
index 1bdf4e27e177109977190c68c13847b9a022ee1f..7c59bf9bd60a80b5c9e67227b72671db858cb530 100644 (file)
@@ -35,7 +35,7 @@ func Examples(pkg *ast.Package) []*Example {
                        examples = append(examples, &Example{
                                Name:   name[len("Example"):],
                                Body:   &printer.CommentedNode{f.Body, src.Comments},
-                               Output: CommentText(f.Doc),
+                               Output: f.Doc.Text(),
                        })
                }
        }
index 9cd186a9c7a81cefee3cf68e8bdad6eeac01c292..994bf503b55b1f1062ca21aa918b5a3c13945914 100644 (file)
@@ -33,7 +33,7 @@ func baseName(x ast.Expr) *ast.Ident {
        return nil
 }
 
-func (doc *docReader) filterFieldList(fields *ast.FieldList) (removedFields bool) {
+func (doc *docReader) filterFieldList(tinfo *typeInfo, fields *ast.FieldList) (removedFields bool) {
        if fields == nil {
                return false
        }
@@ -44,7 +44,18 @@ func (doc *docReader) filterFieldList(fields *ast.FieldList) (removedFields bool
                if len(f.Names) == 0 {
                        // anonymous field
                        name := baseName(f.Type)
-                       keepField = name != nil && name.IsExported()
+                       if name != nil && name.IsExported() {
+                               // we keep the field - in this case doc.addDecl
+                               // will take care of adding the embedded type
+                               keepField = true
+                       } else if tinfo != nil {
+                               // we don't keep the field - add it as an embedded
+                               // type so we won't loose its methods, if any
+                               if embedded := doc.lookupTypeInfo(name.Name); embedded != nil {
+                                       _, ptr := f.Type.(*ast.StarExpr)
+                                       tinfo.addEmbeddedType(embedded, ptr)
+                               }
+                       }
                } else {
                        n := len(f.Names)
                        f.Names = filterIdentList(f.Names)
@@ -54,7 +65,7 @@ func (doc *docReader) filterFieldList(fields *ast.FieldList) (removedFields bool
                        keepField = len(f.Names) > 0
                }
                if keepField {
-                       doc.filterType(f.Type)
+                       doc.filterType(nil, f.Type)
                        list[j] = f
                        j++
                }
@@ -72,23 +83,23 @@ func (doc *docReader) filterParamList(fields *ast.FieldList) bool {
        }
        var b bool
        for _, f := range fields.List {
-               if doc.filterType(f.Type) {
+               if doc.filterType(nil, f.Type) {
                        b = true
                }
        }
        return b
 }
 
-func (doc *docReader) filterType(typ ast.Expr) bool {
+func (doc *docReader) filterType(tinfo *typeInfo, typ ast.Expr) bool {
        switch t := typ.(type) {
        case *ast.Ident:
                return ast.IsExported(t.Name)
        case *ast.ParenExpr:
-               return doc.filterType(t.X)
+               return doc.filterType(nil, t.X)
        case *ast.ArrayType:
-               return doc.filterType(t.Elt)
+               return doc.filterType(nil, t.Elt)
        case *ast.StructType:
-               if doc.filterFieldList(t.Fields) {
+               if doc.filterFieldList(tinfo, t.Fields) {
                        t.Incomplete = true
                }
                return len(t.Fields.List) > 0
@@ -97,16 +108,16 @@ func (doc *docReader) filterType(typ ast.Expr) bool {
                b2 := doc.filterParamList(t.Results)
                return b1 || b2
        case *ast.InterfaceType:
-               if doc.filterFieldList(t.Methods) {
+               if doc.filterFieldList(tinfo, t.Methods) {
                        t.Incomplete = true
                }
                return len(t.Methods.List) > 0
        case *ast.MapType:
-               b1 := doc.filterType(t.Key)
-               b2 := doc.filterType(t.Value)
+               b1 := doc.filterType(nil, t.Key)
+               b2 := doc.filterType(nil, t.Value)
                return b1 || b2
        case *ast.ChanType:
-               return doc.filterType(t.Value)
+               return doc.filterType(nil, t.Value)
        }
        return false
 }
@@ -116,12 +127,12 @@ func (doc *docReader) filterSpec(spec ast.Spec) bool {
        case *ast.ValueSpec:
                s.Names = filterIdentList(s.Names)
                if len(s.Names) > 0 {
-                       doc.filterType(s.Type)
+                       doc.filterType(nil, s.Type)
                        return true
                }
        case *ast.TypeSpec:
                if ast.IsExported(s.Name.Name) {
-                       doc.filterType(s.Type)
+                       doc.filterType(doc.lookupTypeInfo(s.Name.Name), s.Type)
                        return true
                }
        }
index 71c2ebb68bd14e0676594e079e15cb212cd11096..fe2d39b8802d06fb73ae4f0fb73170c4668f8b0b 100644 (file)
@@ -49,7 +49,7 @@ func matchDecl(d *ast.GenDecl, f Filter) bool {
        return false
 }
 
-func filterValueDocs(a []*ValueDoc, f Filter) []*ValueDoc {
+func filterValues(a []*Value, f Filter) []*Value {
        w := 0
        for _, vd := range a {
                if matchDecl(vd.Decl, f) {
@@ -60,7 +60,7 @@ func filterValueDocs(a []*ValueDoc, f Filter) []*ValueDoc {
        return a[0:w]
 }
 
-func filterFuncDocs(a []*FuncDoc, f Filter) []*FuncDoc {
+func filterFuncs(a []*Func, f Filter) []*Func {
        w := 0
        for _, fd := range a {
                if f(fd.Name) {
@@ -71,7 +71,18 @@ func filterFuncDocs(a []*FuncDoc, f Filter) []*FuncDoc {
        return a[0:w]
 }
 
-func filterTypeDocs(a []*TypeDoc, f Filter) []*TypeDoc {
+func filterMethods(a []*Method, f Filter) []*Method {
+       w := 0
+       for _, md := range a {
+               if f(md.Name) {
+                       a[w] = md
+                       w++
+               }
+       }
+       return a[0:w]
+}
+
+func filterTypes(a []*Type, f Filter) []*Type {
        w := 0
        for _, td := range a {
                n := 0 // number of matches
@@ -79,11 +90,11 @@ func filterTypeDocs(a []*TypeDoc, f Filter) []*TypeDoc {
                        n = 1
                } else {
                        // type name doesn't match, but we may have matching consts, vars, factories or methods
-                       td.Consts = filterValueDocs(td.Consts, f)
-                       td.Vars = filterValueDocs(td.Vars, f)
-                       td.Factories = filterFuncDocs(td.Factories, f)
-                       td.Methods = filterFuncDocs(td.Methods, f)
-                       n += len(td.Consts) + len(td.Vars) + len(td.Factories) + len(td.Methods)
+                       td.Consts = filterValues(td.Consts, f)
+                       td.Vars = filterValues(td.Vars, f)
+                       td.Funcs = filterFuncs(td.Funcs, f)
+                       td.Methods = filterMethods(td.Methods, f)
+                       n += len(td.Consts) + len(td.Vars) + len(td.Funcs) + len(td.Methods)
                }
                if n > 0 {
                        a[w] = td
@@ -96,10 +107,10 @@ func filterTypeDocs(a []*TypeDoc, f Filter) []*TypeDoc {
 // Filter eliminates documentation for names that don't pass through the filter f.
 // TODO: Recognize "Type.Method" as a name.
 //
-func (p *PackageDoc) Filter(f Filter) {
-       p.Consts = filterValueDocs(p.Consts, f)
-       p.Vars = filterValueDocs(p.Vars, f)
-       p.Types = filterTypeDocs(p.Types, f)
-       p.Funcs = filterFuncDocs(p.Funcs, f)
+func (p *Package) Filter(f Filter) {
+       p.Consts = filterValues(p.Consts, f)
+       p.Vars = filterValues(p.Vars, f)
+       p.Types = filterTypes(p.Types, f)
+       p.Funcs = filterFuncs(p.Funcs, f)
        p.Doc = "" // don't show top-level package doc
 }
index 838223be745f9eb1492172ab5ac4666977e3d6c4..37486b126fd1b69062bb048875cb45cf1580558f 100644 (file)
@@ -77,7 +77,7 @@ func main() {
                        return nil
                }
                for _, pkg := range pkgs {
-                       d := doc.NewPackageDoc(pkg, path)
+                       d := doc.New(pkg, path, doc.Mode(0))
                        list := appendHeadings(nil, d.Doc)
                        for _, d := range d.Consts {
                                list = appendHeadings(list, d.Doc)
diff --git a/libgo/go/go/doc/reader.go b/libgo/go/go/doc/reader.go
new file mode 100644 (file)
index 0000000..939dd89
--- /dev/null
@@ -0,0 +1,669 @@
+// Copyright 2009 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+package doc
+
+import (
+       "go/ast"
+       "go/token"
+       "regexp"
+       "sort"
+)
+
+// ----------------------------------------------------------------------------
+// Collection of documentation info
+
+// embeddedType describes the type of an anonymous field.
+//
+type embeddedType struct {
+       typ *typeInfo // the corresponding base type
+       ptr bool      // if set, the anonymous field type is a pointer
+}
+
+type typeInfo struct {
+       name     string // base type name
+       isStruct bool
+       // len(decl.Specs) == 1, and the element type is *ast.TypeSpec
+       // if the type declaration hasn't been seen yet, decl is nil
+       decl     *ast.GenDecl
+       embedded []embeddedType
+       forward  *Type // forward link to processed type documentation
+
+       // declarations associated with the type
+       values    []*ast.GenDecl // consts and vars
+       factories map[string]*ast.FuncDecl
+       methods   map[string]*ast.FuncDecl
+}
+
+func (info *typeInfo) exported() bool {
+       return ast.IsExported(info.name)
+}
+
+func (info *typeInfo) addEmbeddedType(embedded *typeInfo, isPtr bool) {
+       info.embedded = append(info.embedded, embeddedType{embedded, isPtr})
+}
+
+// docReader accumulates documentation for a single package.
+// It modifies the AST: Comments (declaration documentation)
+// that have been collected by the DocReader are set to nil
+// in the respective AST nodes so that they are not printed
+// twice (once when printing the documentation and once when
+// printing the corresponding AST node).
+//
+type docReader struct {
+       doc      *ast.CommentGroup // package documentation, if any
+       pkgName  string
+       mode     Mode
+       values   []*ast.GenDecl // consts and vars
+       types    map[string]*typeInfo
+       embedded map[string]*typeInfo // embedded types, possibly not exported
+       funcs    map[string]*ast.FuncDecl
+       bugs     []*ast.CommentGroup
+}
+
+func (doc *docReader) init(pkgName string, mode Mode) {
+       doc.pkgName = pkgName
+       doc.mode = mode
+       doc.types = make(map[string]*typeInfo)
+       doc.embedded = make(map[string]*typeInfo)
+       doc.funcs = make(map[string]*ast.FuncDecl)
+}
+
+func (doc *docReader) addDoc(comments *ast.CommentGroup) {
+       if doc.doc == nil {
+               // common case: just one package comment
+               doc.doc = comments
+               return
+       }
+       // More than one package comment: Usually there will be only
+       // one file with a package comment, but it's better to collect
+       // all comments than drop them on the floor.
+       blankComment := &ast.Comment{token.NoPos, "//"}
+       list := append(doc.doc.List, blankComment)
+       doc.doc.List = append(list, comments.List...)
+}
+
+func (doc *docReader) lookupTypeInfo(name string) *typeInfo {
+       if name == "" || name == "_" {
+               return nil // no type docs for anonymous types
+       }
+       if info, found := doc.types[name]; found {
+               return info
+       }
+       // type wasn't found - add one without declaration
+       info := &typeInfo{
+               name:      name,
+               factories: make(map[string]*ast.FuncDecl),
+               methods:   make(map[string]*ast.FuncDecl),
+       }
+       doc.types[name] = info
+       return info
+}
+
+func baseTypeName(typ ast.Expr, allTypes bool) string {
+       switch t := typ.(type) {
+       case *ast.Ident:
+               // if the type is not exported, the effect to
+               // a client is as if there were no type name
+               if t.IsExported() || allTypes {
+                       return t.Name
+               }
+       case *ast.StarExpr:
+               return baseTypeName(t.X, allTypes)
+       }
+       return ""
+}
+
+func (doc *docReader) addValue(decl *ast.GenDecl) {
+       // determine if decl should be associated with a type
+       // Heuristic: For each typed entry, determine the type name, if any.
+       //            If there is exactly one type name that is sufficiently
+       //            frequent, associate the decl with the respective type.
+       domName := ""
+       domFreq := 0
+       prev := ""
+       for _, s := range decl.Specs {
+               if v, ok := s.(*ast.ValueSpec); ok {
+                       name := ""
+                       switch {
+                       case v.Type != nil:
+                               // a type is present; determine its name
+                               name = baseTypeName(v.Type, false)
+                       case decl.Tok == token.CONST:
+                               // no type is present but we have a constant declaration;
+                               // use the previous type name (w/o more type information
+                               // we cannot handle the case of unnamed variables with
+                               // initializer expressions except for some trivial cases)
+                               name = prev
+                       }
+                       if name != "" {
+                               // entry has a named type
+                               if domName != "" && domName != name {
+                                       // more than one type name - do not associate
+                                       // with any type
+                                       domName = ""
+                                       break
+                               }
+                               domName = name
+                               domFreq++
+                       }
+                       prev = name
+               }
+       }
+
+       // determine values list
+       const threshold = 0.75
+       values := &doc.values
+       if domName != "" && domFreq >= int(float64(len(decl.Specs))*threshold) {
+               // typed entries are sufficiently frequent
+               typ := doc.lookupTypeInfo(domName)
+               if typ != nil {
+                       values = &typ.values // associate with that type
+               }
+       }
+
+       *values = append(*values, decl)
+}
+
+// Helper function to set the table entry for function f. Makes sure that
+// at least one f with associated documentation is stored in table, if there
+// are multiple f's with the same name.
+func setFunc(table map[string]*ast.FuncDecl, f *ast.FuncDecl) {
+       name := f.Name.Name
+       if g, exists := table[name]; exists && g.Doc != nil {
+               // a function with the same name has already been registered;
+               // since it has documentation, assume f is simply another
+               // implementation and ignore it
+               // TODO(gri) consider collecting all functions, or at least
+               //           all comments
+               return
+       }
+       // function doesn't exist or has no documentation; use f
+       table[name] = f
+}
+
+func (doc *docReader) addFunc(fun *ast.FuncDecl) {
+       // strip function body
+       fun.Body = nil
+
+       // determine if it should be associated with a type
+       if fun.Recv != nil {
+               // method
+               recvTypeName := baseTypeName(fun.Recv.List[0].Type, true /* exported or not */ )
+               var typ *typeInfo
+               if ast.IsExported(recvTypeName) {
+                       // exported recv type: if not found, add it to doc.types
+                       typ = doc.lookupTypeInfo(recvTypeName)
+               } else {
+                       // unexported recv type: if not found, do not add it
+                       // (unexported embedded types are added before this
+                       // phase, so if the type doesn't exist yet, we don't
+                       // care about this method)
+                       typ = doc.types[recvTypeName]
+               }
+               if typ != nil {
+                       // exported receiver type
+                       // associate method with the type
+                       // (if the type is not exported, it may be embedded
+                       // somewhere so we need to collect the method anyway)
+                       setFunc(typ.methods, fun)
+               }
+               // otherwise don't show the method
+               // TODO(gri): There may be exported methods of non-exported types
+               // that can be called because of exported values (consts, vars, or
+               // function results) of that type. Could determine if that is the
+               // case and then show those methods in an appropriate section.
+               return
+       }
+
+       // perhaps a factory function
+       // determine result type, if any
+       if fun.Type.Results.NumFields() >= 1 {
+               res := fun.Type.Results.List[0]
+               if len(res.Names) <= 1 {
+                       // exactly one (named or anonymous) result associated
+                       // with the first type in result signature (there may
+                       // be more than one result)
+                       tname := baseTypeName(res.Type, false)
+                       typ := doc.lookupTypeInfo(tname)
+                       if typ != nil {
+                               // named and exported result type
+                               setFunc(typ.factories, fun)
+                               return
+                       }
+               }
+       }
+
+       // ordinary function
+       setFunc(doc.funcs, fun)
+}
+
+func (doc *docReader) addDecl(decl ast.Decl) {
+       switch d := decl.(type) {
+       case *ast.GenDecl:
+               if len(d.Specs) > 0 {
+                       switch d.Tok {
+                       case token.CONST, token.VAR:
+                               // constants and variables are always handled as a group
+                               doc.addValue(d)
+                       case token.TYPE:
+                               // types are handled individually
+                               for _, spec := range d.Specs {
+                                       tspec := spec.(*ast.TypeSpec)
+                                       // add the type to the documentation
+                                       info := doc.lookupTypeInfo(tspec.Name.Name)
+                                       if info == nil {
+                                               continue // no name - ignore the type
+                                       }
+                                       // Make a (fake) GenDecl node for this TypeSpec
+                                       // (we need to do this here - as opposed to just
+                                       // for printing - so we don't lose the GenDecl
+                                       // documentation). Since a new GenDecl node is
+                                       // created, there's no need to nil out d.Doc.
+                                       //
+                                       // TODO(gri): Consider just collecting the TypeSpec
+                                       // node (and copy in the GenDecl.doc if there is no
+                                       // doc in the TypeSpec - this is currently done in
+                                       // makeTypes below). Simpler data structures, but
+                                       // would lose GenDecl documentation if the TypeSpec
+                                       // has documentation as well.
+                                       fake := &ast.GenDecl{d.Doc, d.Pos(), token.TYPE, token.NoPos,
+                                               []ast.Spec{tspec}, token.NoPos}
+                                       // A type should be added at most once, so info.decl
+                                       // should be nil - if it isn't, simply overwrite it.
+                                       info.decl = fake
+                                       // Look for anonymous fields that might contribute methods.
+                                       var fields *ast.FieldList
+                                       switch typ := spec.(*ast.TypeSpec).Type.(type) {
+                                       case *ast.StructType:
+                                               fields = typ.Fields
+                                               info.isStruct = true
+                                       case *ast.InterfaceType:
+                                               fields = typ.Methods
+                                       }
+                                       if fields != nil {
+                                               for _, field := range fields.List {
+                                                       if len(field.Names) == 0 {
+                                                               // anonymous field - add corresponding type
+                                                               // to the info and collect it in doc
+                                                               name := baseTypeName(field.Type, true)
+                                                               if embedded := doc.lookupTypeInfo(name); embedded != nil {
+                                                                       _, ptr := field.Type.(*ast.StarExpr)
+                                                                       info.addEmbeddedType(embedded, ptr)
+                                                               }
+                                                       }
+                                               }
+                                       }
+                               }
+                       }
+               }
+       case *ast.FuncDecl:
+               doc.addFunc(d)
+       }
+}
+
+func copyCommentList(list []*ast.Comment) []*ast.Comment {
+       return append([]*ast.Comment(nil), list...)
+}
+
+var (
+       bug_markers = regexp.MustCompile("^/[/*][ \t]*BUG\\(.*\\):[ \t]*") // BUG(uid):
+       bug_content = regexp.MustCompile("[^ \n\r\t]+")                    // at least one non-whitespace char
+)
+
+// addFile adds the AST for a source file to the docReader.
+// Adding the same AST multiple times is a no-op.
+//
+func (doc *docReader) addFile(src *ast.File) {
+       // add package documentation
+       if src.Doc != nil {
+               doc.addDoc(src.Doc)
+               src.Doc = nil // doc consumed - remove from ast.File node
+       }
+
+       // add all declarations
+       for _, decl := range src.Decls {
+               doc.addDecl(decl)
+       }
+
+       // collect BUG(...) comments
+       for _, c := range src.Comments {
+               text := c.List[0].Text
+               if m := bug_markers.FindStringIndex(text); m != nil {
+                       // found a BUG comment; maybe empty
+                       if btxt := text[m[1]:]; bug_content.MatchString(btxt) {
+                               // non-empty BUG comment; collect comment without BUG prefix
+                               list := copyCommentList(c.List)
+                               list[0].Text = text[m[1]:]
+                               doc.bugs = append(doc.bugs, &ast.CommentGroup{list})
+                       }
+               }
+       }
+       src.Comments = nil // consumed unassociated comments - remove from ast.File node
+}
+
+// ----------------------------------------------------------------------------
+// Conversion to external representation
+
+type sortValue []*Value
+
+func (p sortValue) Len() int      { return len(p) }
+func (p sortValue) Swap(i, j int) { p[i], p[j] = p[j], p[i] }
+
+func declName(d *ast.GenDecl) string {
+       if len(d.Specs) != 1 {
+               return ""
+       }
+
+       switch v := d.Specs[0].(type) {
+       case *ast.ValueSpec:
+               return v.Names[0].Name
+       case *ast.TypeSpec:
+               return v.Name.Name
+       }
+
+       return ""
+}
+
+func (p sortValue) Less(i, j int) bool {
+       // sort by name
+       // pull blocks (name = "") up to top
+       // in original order
+       if ni, nj := declName(p[i].Decl), declName(p[j].Decl); ni != nj {
+               return ni < nj
+       }
+       return p[i].order < p[j].order
+}
+
+func specNames(specs []ast.Spec) []string {
+       names := make([]string, len(specs)) // reasonable estimate
+       for _, s := range specs {
+               // should always be an *ast.ValueSpec, but be careful
+               if s, ok := s.(*ast.ValueSpec); ok {
+                       for _, ident := range s.Names {
+                               names = append(names, ident.Name)
+                       }
+               }
+       }
+       return names
+}
+
+func makeValues(list []*ast.GenDecl, tok token.Token) []*Value {
+       d := make([]*Value, len(list)) // big enough in any case
+       n := 0
+       for i, decl := range list {
+               if decl.Tok == tok {
+                       d[n] = &Value{decl.Doc.Text(), specNames(decl.Specs), decl, i}
+                       n++
+                       decl.Doc = nil // doc consumed - removed from AST
+               }
+       }
+       d = d[0:n]
+       sort.Sort(sortValue(d))
+       return d
+}
+
+type sortFunc []*Func
+
+func (p sortFunc) Len() int           { return len(p) }
+func (p sortFunc) Swap(i, j int)      { p[i], p[j] = p[j], p[i] }
+func (p sortFunc) Less(i, j int) bool { return p[i].Name < p[j].Name }
+
+func makeFuncs(m map[string]*ast.FuncDecl) []*Func {
+       d := make([]*Func, len(m))
+       i := 0
+       for _, f := range m {
+               doc := new(Func)
+               doc.Doc = f.Doc.Text()
+               f.Doc = nil // doc consumed - remove from ast.FuncDecl node
+               if f.Recv != nil {
+                       doc.Recv = f.Recv.List[0].Type
+               }
+               doc.Name = f.Name.Name
+               doc.Decl = f
+               d[i] = doc
+               i++
+       }
+       sort.Sort(sortFunc(d))
+       return d
+}
+
+type methodSet map[string]*Func
+
+func (mset methodSet) add(m *Func) {
+       if mset[m.Name] == nil {
+               mset[m.Name] = m
+       }
+}
+
+type sortMethod []*Method
+
+func (p sortMethod) Len() int           { return len(p) }
+func (p sortMethod) Swap(i, j int)      { p[i], p[j] = p[j], p[i] }
+func (p sortMethod) Less(i, j int) bool { return p[i].Func.Name < p[j].Func.Name }
+
+func (mset methodSet) sortedList() []*Method {
+       list := make([]*Method, len(mset))
+       i := 0
+       for _, m := range mset {
+               list[i] = &Method{Func: m}
+               i++
+       }
+       sort.Sort(sortMethod(list))
+       return list
+}
+
+type sortType []*Type
+
+func (p sortType) Len() int      { return len(p) }
+func (p sortType) Swap(i, j int) { p[i], p[j] = p[j], p[i] }
+func (p sortType) Less(i, j int) bool {
+       // sort by name
+       // pull blocks (name = "") up to top
+       // in original order
+       if ni, nj := p[i].Type.Name.Name, p[j].Type.Name.Name; ni != nj {
+               return ni < nj
+       }
+       return p[i].order < p[j].order
+}
+
+// NOTE(rsc): This would appear not to be correct for type ( )
+// blocks, but the doc extractor above has split them into
+// individual declarations.
+func (doc *docReader) makeTypes(m map[string]*typeInfo) []*Type {
+       // TODO(gri) Consider computing the embedded method information
+       //           before calling makeTypes. Then this function can
+       //           be single-phased again. Also, it might simplify some
+       //           of the logic.
+       //
+       // phase 1: associate collected declarations with Types
+       list := make([]*Type, len(m))
+       i := 0
+       for _, old := range m {
+               // old typeInfos may not have a declaration associated with them
+               // if they are not exported but embedded, or because the package
+               // is incomplete.
+               if decl := old.decl; decl != nil || !old.exported() {
+                       // process the type even if not exported so that we have
+                       // its methods in case they are embedded somewhere
+                       t := new(Type)
+                       if decl != nil {
+                               typespec := decl.Specs[0].(*ast.TypeSpec)
+                               doc := typespec.Doc
+                               typespec.Doc = nil // doc consumed - remove from ast.TypeSpec node
+                               if doc == nil {
+                                       // no doc associated with the spec, use the declaration doc, if any
+                                       doc = decl.Doc
+                               }
+                               decl.Doc = nil // doc consumed - remove from ast.Decl node
+                               t.Doc = doc.Text()
+                               t.Type = typespec
+                       }
+                       t.Consts = makeValues(old.values, token.CONST)
+                       t.Vars = makeValues(old.values, token.VAR)
+                       t.Funcs = makeFuncs(old.factories)
+                       t.methods = makeFuncs(old.methods)
+                       // The list of embedded types' methods is computed from the list
+                       // of embedded types, some of which may not have been processed
+                       // yet (i.e., their forward link is nil) - do this in a 2nd phase.
+                       // The final list of methods can only be computed after that -
+                       // do this in a 3rd phase.
+                       t.Decl = old.decl
+                       t.order = i
+                       old.forward = t // old has been processed
+                       // only add the type to the final type list if it
+                       // is exported or if we want to see all types
+                       if old.exported() || doc.mode&AllDecls != 0 {
+                               list[i] = t
+                               i++
+                       }
+               } else {
+                       // no corresponding type declaration found - move any associated
+                       // values, factory functions, and methods back to the top-level
+                       // so that they are not lost (this should only happen if a package
+                       // file containing the explicit type declaration is missing or if
+                       // an unqualified type name was used after a "." import)
+                       // 1) move values
+                       doc.values = append(doc.values, old.values...)
+                       // 2) move factory functions
+                       for name, f := range old.factories {
+                               doc.funcs[name] = f
+                       }
+                       // 3) move methods
+                       for name, f := range old.methods {
+                               // don't overwrite functions with the same name
+                               if _, found := doc.funcs[name]; !found {
+                                       doc.funcs[name] = f
+                               }
+                       }
+               }
+       }
+       list = list[0:i] // some types may have been ignored
+
+       // phase 2: collect embedded methods for each processed typeInfo
+       for _, old := range m {
+               if t := old.forward; t != nil {
+                       // old has been processed into t; collect embedded
+                       // methods for t from the list of processed embedded
+                       // types in old (and thus for which the methods are known)
+                       if old.isStruct {
+                               // struct
+                               t.embedded = make(methodSet)
+                               collectEmbeddedMethods(t.embedded, old, old.name, false)
+                       } else {
+                               // interface
+                               // TODO(gri) fix this
+                       }
+               }
+       }
+
+       // phase 3: compute final method set for each Type
+       for _, d := range list {
+               if len(d.embedded) > 0 {
+                       // there are embedded methods - exclude
+                       // the ones with names conflicting with
+                       // non-embedded methods
+                       mset := make(methodSet)
+                       // top-level methods have priority
+                       for _, m := range d.methods {
+                               mset.add(m)
+                       }
+                       // add non-conflicting embedded methods
+                       for _, m := range d.embedded {
+                               mset.add(m)
+                       }
+                       d.Methods = mset.sortedList()
+               } else {
+                       // no embedded methods - convert into a Method list
+                       d.Methods = make([]*Method, len(d.methods))
+                       for i, m := range d.methods {
+                               d.Methods[i] = &Method{Func: m}
+                       }
+               }
+       }
+
+       sort.Sort(sortType(list))
+       return list
+}
+
+// collectEmbeddedMethods collects the embedded methods from all
+// processed embedded types found in info in mset. It considers
+// embedded types at the most shallow level first so that more
+// deeply nested embedded methods with conflicting names are
+// excluded.
+//
+func collectEmbeddedMethods(mset methodSet, info *typeInfo, recvTypeName string, embeddedIsPtr bool) {
+       for _, e := range info.embedded {
+               if e.typ.forward != nil { // == e was processed
+                       // Once an embedded type was embedded as a pointer type
+                       // all embedded types in those types are treated like
+                       // pointer types for the purpose of the receiver type
+                       // computation; i.e., embeddedIsPtr is sticky for this
+                       // embedding hierarchy.
+                       thisEmbeddedIsPtr := embeddedIsPtr || e.ptr
+                       for _, m := range e.typ.forward.methods {
+                               mset.add(customizeRecv(m, thisEmbeddedIsPtr, recvTypeName))
+                       }
+                       collectEmbeddedMethods(mset, e.typ, recvTypeName, thisEmbeddedIsPtr)
+               }
+       }
+}
+
+func customizeRecv(m *Func, embeddedIsPtr bool, recvTypeName string) *Func {
+       if m == nil || m.Decl == nil || m.Decl.Recv == nil || len(m.Decl.Recv.List) != 1 {
+               return m // shouldn't happen, but be safe
+       }
+
+       // copy existing receiver field and set new type
+       newField := *m.Decl.Recv.List[0]
+       _, origRecvIsPtr := newField.Type.(*ast.StarExpr)
+       var typ ast.Expr = ast.NewIdent(recvTypeName)
+       if !embeddedIsPtr && origRecvIsPtr {
+               typ = &ast.StarExpr{token.NoPos, typ}
+       }
+       newField.Type = typ
+
+       // copy existing receiver field list and set new receiver field
+       newFieldList := *m.Decl.Recv
+       newFieldList.List = []*ast.Field{&newField}
+
+       // copy existing function declaration and set new receiver field list
+       newFuncDecl := *m.Decl
+       newFuncDecl.Recv = &newFieldList
+
+       // copy existing function documentation and set new declaration
+       newM := *m
+       newM.Decl = &newFuncDecl
+       newM.Recv = typ
+
+       return &newM
+}
+
+func makeBugs(list []*ast.CommentGroup) []string {
+       d := make([]string, len(list))
+       for i, g := range list {
+               d[i] = g.Text()
+       }
+       return d
+}
+
+// newDoc returns the accumulated documentation for the package.
+//
+func (doc *docReader) newDoc(importpath string, filenames []string) *Package {
+       p := new(Package)
+       p.Name = doc.pkgName
+       p.ImportPath = importpath
+       sort.Strings(filenames)
+       p.Filenames = filenames
+       p.Doc = doc.doc.Text()
+       // makeTypes may extend the list of doc.values and
+       // doc.funcs and thus must be called before any other
+       // function consuming those lists
+       p.Types = doc.makeTypes(doc.types)
+       p.Consts = makeValues(doc.values, token.CONST)
+       p.Vars = makeValues(doc.values, token.VAR)
+       p.Funcs = makeFuncs(doc.funcs)
+       p.Bugs = makeBugs(doc.bugs)
+       return p
+}
index be11f461c3b33b7351811c60335aad41c7761030..2ce3df8df714750db66d8e785fd18bda450c0351 100644 (file)
@@ -10,7 +10,6 @@ import (
        "bytes"
        "errors"
        "go/ast"
-       "go/scanner"
        "go/token"
        "io"
        "io/ioutil"
@@ -36,86 +35,28 @@ func readSource(filename string, src interface{}) ([]byte, error) {
                        }
                case io.Reader:
                        var buf bytes.Buffer
-                       _, err := io.Copy(&buf, s)
-                       if err != nil {
+                       if _, err := io.Copy(&buf, s); err != nil {
                                return nil, err
                        }
                        return buf.Bytes(), nil
-               default:
-                       return nil, errors.New("invalid source")
                }
+               return nil, errors.New("invalid source")
        }
-
        return ioutil.ReadFile(filename)
 }
 
-func (p *parser) errors() error {
-       mode := scanner.Sorted
-       if p.mode&SpuriousErrors == 0 {
-               mode = scanner.NoMultiples
-       }
-       return p.GetError(mode)
-}
-
-// ParseExpr parses a Go expression and returns the corresponding
-// AST node. The fset, filename, and src arguments have the same interpretation
-// as for ParseFile. If there is an error, the result expression
-// may be nil or contain a partial AST.
-//
-func ParseExpr(fset *token.FileSet, filename string, src interface{}) (ast.Expr, error) {
-       data, err := readSource(filename, src)
-       if err != nil {
-               return nil, err
-       }
-
-       var p parser
-       p.init(fset, filename, data, 0)
-       x := p.parseRhs()
-       if p.tok == token.SEMICOLON {
-               p.next() // consume automatically inserted semicolon, if any
-       }
-       p.expect(token.EOF)
-
-       return x, p.errors()
-}
-
-// ParseStmtList parses a list of Go statements and returns the list
-// of corresponding AST nodes. The fset, filename, and src arguments have the same
-// interpretation as for ParseFile. If there is an error, the node
-// list may be nil or contain partial ASTs.
+// The mode parameter to the Parse* functions is a set of flags (or 0).
+// They control the amount of source code parsed and other optional
+// parser functionality.
 //
-func ParseStmtList(fset *token.FileSet, filename string, src interface{}) ([]ast.Stmt, error) {
-       data, err := readSource(filename, src)
-       if err != nil {
-               return nil, err
-       }
-
-       var p parser
-       p.init(fset, filename, data, 0)
-       list := p.parseStmtList()
-       p.expect(token.EOF)
-
-       return list, p.errors()
-}
-
-// ParseDeclList parses a list of Go declarations and returns the list
-// of corresponding AST nodes. The fset, filename, and src arguments have the same
-// interpretation as for ParseFile. If there is an error, the node
-// list may be nil or contain partial ASTs.
-//
-func ParseDeclList(fset *token.FileSet, filename string, src interface{}) ([]ast.Decl, error) {
-       data, err := readSource(filename, src)
-       if err != nil {
-               return nil, err
-       }
-
-       var p parser
-       p.init(fset, filename, data, 0)
-       list := p.parseDeclList()
-       p.expect(token.EOF)
-
-       return list, p.errors()
-}
+const (
+       PackageClauseOnly uint = 1 << iota // parsing stops after package clause
+       ImportsOnly                        // parsing stops after import declarations
+       ParseComments                      // parse comments and add them to AST
+       Trace                              // print a trace of parsed productions
+       DeclarationErrors                  // report declaration errors
+       SpuriousErrors                     // report all (not just the first) errors per line
+)
 
 // ParseFile parses the source code of a single Go source file and returns
 // the corresponding ast.File node. The source code may be provided via
@@ -124,7 +65,6 @@ func ParseDeclList(fset *token.FileSet, filename string, src interface{}) ([]ast
 // If src != nil, ParseFile parses the source from src and the filename is
 // only used when recording position information. The type of the argument
 // for the src parameter must be string, []byte, or io.Reader.
-//
 // If src == nil, ParseFile parses the file specified by filename.
 //
 // The mode parameter controls the amount of source text parsed and other
@@ -133,49 +73,18 @@ func ParseDeclList(fset *token.FileSet, filename string, src interface{}) ([]ast
 //
 // If the source couldn't be read, the returned AST is nil and the error
 // indicates the specific failure. If the source was read but syntax
-// errors were found, the result is a partial AST (with ast.BadX nodes
+// errors were found, the result is a partial AST (with ast.Bad* nodes
 // representing the fragments of erroneous source code). Multiple errors
 // are returned via a scanner.ErrorList which is sorted by file position.
 //
 func ParseFile(fset *token.FileSet, filename string, src interface{}, mode uint) (*ast.File, error) {
-       data, err := readSource(filename, src)
+       text, err := readSource(filename, src)
        if err != nil {
                return nil, err
        }
-
        var p parser
-       p.init(fset, filename, data, mode)
-       file := p.parseFile() // parseFile reads to EOF
-
-       return file, p.errors()
-}
-
-// ParseFiles calls ParseFile for each file in the filenames list and returns
-// a map of package name -> package AST with all the packages found. The mode
-// bits are passed to ParseFile unchanged. Position information is recorded
-// in the file set fset.
-//
-// Files with parse errors are ignored. In this case the map of packages may
-// be incomplete (missing packages and/or incomplete packages) and the first
-// error encountered is returned.
-//
-func ParseFiles(fset *token.FileSet, filenames []string, mode uint) (pkgs map[string]*ast.Package, first error) {
-       pkgs = make(map[string]*ast.Package)
-       for _, filename := range filenames {
-               if src, err := ParseFile(fset, filename, nil, mode); err == nil {
-                       name := src.Name.Name
-                       pkg, found := pkgs[name]
-                       if !found {
-                               // TODO(gri) Use NewPackage here; reconsider ParseFiles API.
-                               pkg = &ast.Package{name, nil, nil, make(map[string]*ast.File)}
-                               pkgs[name] = pkg
-                       }
-                       pkg.Files[filename] = src
-               } else if first == nil {
-                       first = err
-               }
-       }
-       return
+       p.init(fset, filename, text, mode)
+       return p.parseFile(), p.errors()
 }
 
 // ParseDir calls ParseFile for the files in the directory specified by path and
@@ -186,9 +95,9 @@ func ParseFiles(fset *token.FileSet, filenames []string, mode uint) (pkgs map[st
 //
 // If the directory couldn't be read, a nil map and the respective error are
 // returned. If a parse error occurred, a non-nil but incomplete map and the
-// error are returned.
+// first error encountered are returned.
 //
-func ParseDir(fset *token.FileSet, path string, filter func(os.FileInfo) bool, mode uint) (map[string]*ast.Package, error) {
+func ParseDir(fset *token.FileSet, path string, filter func(os.FileInfo) bool, mode uint) (pkgs map[string]*ast.Package, first error) {
        fd, err := os.Open(path)
        if err != nil {
                return nil, err
@@ -200,15 +109,36 @@ func ParseDir(fset *token.FileSet, path string, filter func(os.FileInfo) bool, m
                return nil, err
        }
 
-       filenames := make([]string, len(list))
-       n := 0
+       pkgs = make(map[string]*ast.Package)
        for _, d := range list {
                if filter == nil || filter(d) {
-                       filenames[n] = filepath.Join(path, d.Name())
-                       n++
+                       filename := filepath.Join(path, d.Name())
+                       if src, err := ParseFile(fset, filename, nil, mode); err == nil {
+                               name := src.Name.Name
+                               pkg, found := pkgs[name]
+                               if !found {
+                                       pkg = &ast.Package{name, nil, nil, make(map[string]*ast.File)}
+                                       pkgs[name] = pkg
+                               }
+                               pkg.Files[filename] = src
+                       } else if first == nil {
+                               first = err
+                       }
                }
        }
-       filenames = filenames[0:n]
 
-       return ParseFiles(fset, filenames, mode)
+       return
+}
+
+// ParseExpr is a convenience function for obtaining the AST of an expression x.
+// The position information recorded in the AST is undefined.
+// 
+func ParseExpr(x string) (ast.Expr, error) {
+       // parse x within the context of a complete package for correct scopes;
+       // use //line directive for correct positions in error messages
+       file, err := ParseFile(token.NewFileSet(), "", "package p;func _(){_=\n//line :1\n"+x+";}", 0)
+       if err != nil {
+               return nil, err
+       }
+       return file.Decls[0].(*ast.FuncDecl).Body.List[0].(*ast.AssignStmt).Rhs[0], nil
 }
index 9fbed2d2ca35de31d605caf2336722ce22225d80..d90f5775df45eff5c6b51c10bffdb62a8e583f5d 100644 (file)
@@ -16,19 +16,6 @@ import (
        "go/token"
 )
 
-// The mode parameter to the Parse* functions is a set of flags (or 0).
-// They control the amount of source code parsed and other optional
-// parser functionality.
-//
-const (
-       PackageClauseOnly uint = 1 << iota // parsing stops after package clause
-       ImportsOnly                        // parsing stops after import declarations
-       ParseComments                      // parse comments and add them to AST
-       Trace                              // print a trace of parsed productions
-       DeclarationErrors                  // report declaration errors
-       SpuriousErrors                     // report all (not just the first) errors per line
-)
-
 // The parser structure holds the parser's internal state.
 type parser struct {
        file *token.File
@@ -65,18 +52,13 @@ type parser struct {
        targetStack [][]*ast.Ident // stack of unresolved labels
 }
 
-// scannerMode returns the scanner mode bits given the parser's mode bits.
-func scannerMode(mode uint) uint {
-       var m uint = scanner.InsertSemis
-       if mode&ParseComments != 0 {
-               m |= scanner.ScanComments
-       }
-       return m
-}
-
 func (p *parser) init(fset *token.FileSet, filename string, src []byte, mode uint) {
        p.file = fset.AddFile(filename, fset.Base(), len(src))
-       p.scanner.Init(p.file, src, p, scannerMode(mode))
+       var m uint
+       if mode&ParseComments != 0 {
+               m = scanner.ScanComments
+       }
+       p.scanner.Init(p.file, src, p, m)
 
        p.mode = mode
        p.trace = mode&Trace != 0 // for convenience (p.trace is used frequently)
@@ -92,6 +74,14 @@ func (p *parser) init(fset *token.FileSet, filename string, src []byte, mode uin
        p.openLabelScope()
 }
 
+func (p *parser) errors() error {
+       m := scanner.Sorted
+       if p.mode&SpuriousErrors == 0 {
+               m = scanner.NoMultiples
+       }
+       return p.GetError(m)
+}
+
 // ----------------------------------------------------------------------------
 // Scoping support
 
@@ -2109,18 +2099,6 @@ func (p *parser) parseDecl() ast.Decl {
        return p.parseGenDecl(p.tok, f)
 }
 
-func (p *parser) parseDeclList() (list []ast.Decl) {
-       if p.trace {
-               defer un(trace(p, "DeclList"))
-       }
-
-       for p.tok != token.EOF {
-               list = append(list, p.parseDecl())
-       }
-
-       return
-}
-
 // ----------------------------------------------------------------------------
 // Source files
 
index f602db8896db5e02e6fe962a98d0f3a656f21fd4..a3ee8525de29b931e73594c9570f3b91684e2139 100644 (file)
@@ -54,7 +54,7 @@ func TestParseIllegalInputs(t *testing.T) {
        }
 }
 
-var validPrograms = []interface{}{
+var validPrograms = []string{
        "package p\n",
        `package p;`,
        `package p; import "fmt"; func f() { fmt.Println("Hello, World!") };`,
@@ -136,6 +136,32 @@ func TestParse4(t *testing.T) {
        }
 }
 
+func TestParseExpr(t *testing.T) {
+       // just kicking the tires:
+       // a valid expression
+       src := "a + b"
+       x, err := ParseExpr(src)
+       if err != nil {
+               t.Errorf("ParseExpr(%s): %v", src, err)
+       }
+       // sanity check
+       if _, ok := x.(*ast.BinaryExpr); !ok {
+               t.Errorf("ParseExpr(%s): got %T, expected *ast.BinaryExpr", src, x)
+       }
+
+       // an invalid expression
+       src = "a + *"
+       _, err = ParseExpr(src)
+       if err == nil {
+               t.Errorf("ParseExpr(%s): %v", src, err)
+       }
+
+       // it must not crash
+       for _, src := range validPrograms {
+               ParseExpr(src)
+       }
+}
+
 func TestColonEqualsScope(t *testing.T) {
        f, err := ParseFile(fset, "", `package p; func f() { x, y, z := x, y, z }`, 0)
        if err != nil {
index a78cfc65fccb372399048449b35c282900599d13..c720f2e665c479b263edd5c399a2573ac6b2e523 100644 (file)
@@ -773,8 +773,13 @@ func (p *printer) print(args ...interface{}) {
                                next = p.fset.Position(x) // accurate position of next item
                        }
                        tok = p.lastTok
+               case string:
+                       // incorrect AST - print error message
+                       data = x
+                       isLit = true
+                       tok = token.STRING
                default:
-                       fmt.Fprintf(os.Stderr, "print: unsupported argument type %T\n", f)
+                       fmt.Fprintf(os.Stderr, "print: unsupported argument %v (%T)\n", f, f)
                        panic("go/printer type")
                }
                p.lastTok = tok
index 45477d40f6e99ced3300a880c3c4e72754ceed34..525fcc1595f7ef096b51f3457892365fafb88e99 100644 (file)
@@ -204,3 +204,18 @@ func init() {
                panic("got " + s + ", want " + name)
        }
 }
+
+// Verify that the printer doesn't crash if the AST contains BadXXX nodes.
+func TestBadNodes(t *testing.T) {
+       const src = "package p\n("
+       const res = "package p\nBadDecl\n"
+       f, err := parser.ParseFile(fset, "", src, parser.ParseComments)
+       if err == nil {
+               t.Errorf("expected illegal program")
+       }
+       var buf bytes.Buffer
+       Fprint(&buf, fset, f)
+       if buf.String() != res {
+               t.Errorf("got %q, expected %q", buf.String(), res)
+       }
+}
index 7fb0104e450a51d74df116cd347a137e0ccdbbb1..59a796574f60bb21051984dcbd51c984c531bfa5 100644 (file)
@@ -90,8 +90,8 @@ func (S *Scanner) next() {
 // They control scanner behavior.
 //
 const (
-       ScanComments = 1 << iota // return comments as COMMENT tokens
-       InsertSemis              // automatically insert semicolons
+       ScanComments    = 1 << iota // return comments as COMMENT tokens
+       dontInsertSemis             // do not automatically insert semicolons - for testing only
 )
 
 // Init prepares the scanner S to tokenize the text src by setting the
@@ -104,7 +104,7 @@ const (
 // Calls to Scan will use the error handler err if they encounter a
 // syntax error and err is not nil. Also, for each error encountered,
 // the Scanner field ErrorCount is incremented by one. The mode parameter
-// determines how comments, illegal characters, and semicolons are handled.
+// determines how comments are handled.
 //
 // Note that Init may call err if there is an error in the first character
 // of the file.
@@ -157,7 +157,7 @@ func (S *Scanner) interpretLineComment(text []byte) {
        }
 }
 
-func (S *Scanner) scanComment() {
+func (S *Scanner) scanComment() string {
        // initial '/' already consumed; S.ch == '/' || S.ch == '*'
        offs := S.offset - 1 // position of initial '/'
 
@@ -171,7 +171,7 @@ func (S *Scanner) scanComment() {
                        // comment starts at the beginning of the current line
                        S.interpretLineComment(S.src[offs:S.offset])
                }
-               return
+               goto exit
        }
 
        /*-style comment */
@@ -181,11 +181,14 @@ func (S *Scanner) scanComment() {
                S.next()
                if ch == '*' && S.ch == '/' {
                        S.next()
-                       return
+                       goto exit
                }
        }
 
        S.error(offs, "comment not terminated")
+
+exit:
+       return string(S.src[offs:S.offset])
 }
 
 func (S *Scanner) findLineEnd() bool {
@@ -240,12 +243,12 @@ func isDigit(ch rune) bool {
        return '0' <= ch && ch <= '9' || ch >= 0x80 && unicode.IsDigit(ch)
 }
 
-func (S *Scanner) scanIdentifier() token.Token {
+func (S *Scanner) scanIdentifier() string {
        offs := S.offset
        for isLetter(S.ch) || isDigit(S.ch) {
                S.next()
        }
-       return token.Lookup(S.src[offs:S.offset])
+       return string(S.src[offs:S.offset])
 }
 
 func digitVal(ch rune) int {
@@ -266,11 +269,13 @@ func (S *Scanner) scanMantissa(base int) {
        }
 }
 
-func (S *Scanner) scanNumber(seenDecimalPoint bool) token.Token {
+func (S *Scanner) scanNumber(seenDecimalPoint bool) (token.Token, string) {
        // digitVal(S.ch) < 10
+       offs := S.offset
        tok := token.INT
 
        if seenDecimalPoint {
+               offs--
                tok = token.FLOAT
                S.scanMantissa(10)
                goto exponent
@@ -334,7 +339,7 @@ exponent:
        }
 
 exit:
-       return tok
+       return tok, string(S.src[offs:S.offset])
 }
 
 func (S *Scanner) scanEscape(quote rune) {
@@ -381,7 +386,7 @@ func (S *Scanner) scanEscape(quote rune) {
        }
 }
 
-func (S *Scanner) scanChar() {
+func (S *Scanner) scanChar() string {
        // '\'' opening already consumed
        offs := S.offset - 1
 
@@ -405,9 +410,11 @@ func (S *Scanner) scanChar() {
        if n != 1 {
                S.error(offs, "illegal character literal")
        }
+
+       return string(S.src[offs:S.offset])
 }
 
-func (S *Scanner) scanString() {
+func (S *Scanner) scanString() string {
        // '"' opening already consumed
        offs := S.offset - 1
 
@@ -424,12 +431,27 @@ func (S *Scanner) scanString() {
        }
 
        S.next()
+
+       return string(S.src[offs:S.offset])
+}
+
+func stripCR(b []byte) []byte {
+       c := make([]byte, len(b))
+       i := 0
+       for _, ch := range b {
+               if ch != '\r' {
+                       c[i] = ch
+                       i++
+               }
+       }
+       return c[:i]
 }
 
-func (S *Scanner) scanRawString() (hasCR bool) {
+func (S *Scanner) scanRawString() string {
        // '`' opening already consumed
        offs := S.offset - 1
 
+       hasCR := false
        for S.ch != '`' {
                ch := S.ch
                S.next()
@@ -443,7 +465,13 @@ func (S *Scanner) scanRawString() (hasCR bool) {
        }
 
        S.next()
-       return
+
+       lit := S.src[offs:S.offset]
+       if hasCR {
+               lit = stripCR(lit)
+       }
+
+       return string(lit)
 }
 
 func (S *Scanner) skipWhitespace() {
@@ -494,27 +522,24 @@ func (S *Scanner) switch4(tok0, tok1 token.Token, ch2 rune, tok2, tok3 token.Tok
        return tok0
 }
 
-func stripCR(b []byte) []byte {
-       c := make([]byte, len(b))
-       i := 0
-       for _, ch := range b {
-               if ch != '\r' {
-                       c[i] = ch
-                       i++
-               }
-       }
-       return c[:i]
-}
-
-// Scan scans the next token and returns the token position,
-// the token, and the literal string corresponding to the
-// token. The source end is indicated by token.EOF.
+// Scan scans the next token and returns the token position, the token,
+// and its literal string if applicable. The source end is indicated by
+// token.EOF.
+//
+// If the returned token is a literal (token.IDENT, token.INT, token.FLOAT,
+// token.IMAG, token.CHAR, token.STRING) or token.COMMENT, the literal string
+// has the corresponding value.
 //
 // If the returned token is token.SEMICOLON, the corresponding
 // literal string is ";" if the semicolon was present in the source,
 // and "\n" if the semicolon was inserted because of a newline or
 // at EOF.
 //
+// If the returned token is token.ILLEGAL, the literal string is the
+// offending character.
+//
+// In all other cases, Scan returns an empty literal string.
+//
 // For more tolerant parsing, Scan will return a valid token if
 // possible even if a syntax error was encountered. Thus, even
 // if the resulting token sequence contains no illegal tokens,
@@ -526,34 +551,33 @@ func stripCR(b []byte) []byte {
 // set with Init. Token positions are relative to that file
 // and thus relative to the file set.
 //
-func (S *Scanner) Scan() (token.Pos, token.Token, string) {
+func (S *Scanner) Scan() (pos token.Pos, tok token.Token, lit string) {
 scanAgain:
        S.skipWhitespace()
 
        // current token start
-       insertSemi := false
-       offs := S.offset
-       tok := token.ILLEGAL
-       hasCR := false
+       pos = S.file.Pos(S.offset)
 
        // determine token value
+       insertSemi := false
        switch ch := S.ch; {
        case isLetter(ch):
-               tok = S.scanIdentifier()
+               lit = S.scanIdentifier()
+               tok = token.Lookup(lit)
                switch tok {
                case token.IDENT, token.BREAK, token.CONTINUE, token.FALLTHROUGH, token.RETURN:
                        insertSemi = true
                }
        case digitVal(ch) < 10:
                insertSemi = true
-               tok = S.scanNumber(false)
+               tok, lit = S.scanNumber(false)
        default:
                S.next() // always make progress
                switch ch {
                case -1:
                        if S.insertSemi {
                                S.insertSemi = false // EOF consumed
-                               return S.file.Pos(offs), token.SEMICOLON, "\n"
+                               return pos, token.SEMICOLON, "\n"
                        }
                        tok = token.EOF
                case '\n':
@@ -561,25 +585,25 @@ scanAgain:
                        // set in the first place and exited early
                        // from S.skipWhitespace()
                        S.insertSemi = false // newline consumed
-                       return S.file.Pos(offs), token.SEMICOLON, "\n"
+                       return pos, token.SEMICOLON, "\n"
                case '"':
                        insertSemi = true
                        tok = token.STRING
-                       S.scanString()
+                       lit = S.scanString()
                case '\'':
                        insertSemi = true
                        tok = token.CHAR
-                       S.scanChar()
+                       lit = S.scanChar()
                case '`':
                        insertSemi = true
                        tok = token.STRING
-                       hasCR = S.scanRawString()
+                       lit = S.scanRawString()
                case ':':
                        tok = S.switch2(token.COLON, token.DEFINE)
                case '.':
                        if digitVal(S.ch) < 10 {
                                insertSemi = true
-                               tok = S.scanNumber(true)
+                               tok, lit = S.scanNumber(true)
                        } else if S.ch == '.' {
                                S.next()
                                if S.ch == '.' {
@@ -593,6 +617,7 @@ scanAgain:
                        tok = token.COMMA
                case ';':
                        tok = token.SEMICOLON
+                       lit = ";"
                case '(':
                        tok = token.LPAREN
                case ')':
@@ -626,12 +651,12 @@ scanAgain:
                                if S.insertSemi && S.findLineEnd() {
                                        // reset position to the beginning of the comment
                                        S.ch = '/'
-                                       S.offset = offs
-                                       S.rdOffset = offs + 1
+                                       S.offset = S.file.Offset(pos)
+                                       S.rdOffset = S.offset + 1
                                        S.insertSemi = false // newline consumed
-                                       return S.file.Pos(offs), token.SEMICOLON, "\n"
+                                       return pos, token.SEMICOLON, "\n"
                                }
-                               S.scanComment()
+                               lit = S.scanComment()
                                if S.mode&ScanComments == 0 {
                                        // skip comment
                                        S.insertSemi = false // newline consumed
@@ -668,21 +693,15 @@ scanAgain:
                case '|':
                        tok = S.switch3(token.OR, token.OR_ASSIGN, '|', token.LOR)
                default:
-                       S.error(offs, fmt.Sprintf("illegal character %#U", ch))
+                       S.error(S.file.Offset(pos), fmt.Sprintf("illegal character %#U", ch))
                        insertSemi = S.insertSemi // preserve insertSemi info
+                       tok = token.ILLEGAL
+                       lit = string(ch)
                }
        }
-
-       if S.mode&InsertSemis != 0 {
+       if S.mode&dontInsertSemis == 0 {
                S.insertSemi = insertSemi
        }
 
-       // TODO(gri): The scanner API should change such that the literal string
-       //            is only valid if an actual literal was scanned. This will
-       //            permit a more efficient implementation.
-       lit := S.src[offs:S.offset]
-       if hasCR {
-               lit = stripCR(lit)
-       }
-       return S.file.Pos(offs), tok, string(lit)
+       return
 }
index dc8ab2a748a08a7826ccef1dd103ea1110a13779..2e4dd4fff638860b79aaf2a1e8a86930a014060f 100644 (file)
@@ -177,6 +177,15 @@ var tokens = [...]elt{
 
 const whitespace = "  \t  \n\n\n" // to separate tokens
 
+var source = func() []byte {
+       var src []byte
+       for _, t := range tokens {
+               src = append(src, t.lit...)
+               src = append(src, whitespace...)
+       }
+       return src
+}()
+
 type testErrorHandler struct {
        t *testing.T
 }
@@ -214,20 +223,20 @@ func checkPos(t *testing.T, lit string, p token.Pos, expected token.Position) {
 // Verify that calling Scan() provides the correct results.
 func TestScan(t *testing.T) {
        // make source
-       var src string
-       for _, e := range tokens {
-               src += e.lit + whitespace
-       }
-       src_linecount := newlineCount(src)
+       src_linecount := newlineCount(string(source))
        whitespace_linecount := newlineCount(whitespace)
 
        // verify scan
        var s Scanner
-       s.Init(fset.AddFile("", fset.Base(), len(src)), []byte(src), &testErrorHandler{t}, ScanComments)
+       s.Init(fset.AddFile("", fset.Base(), len(source)), source, &testErrorHandler{t}, ScanComments|dontInsertSemis)
        index := 0
        epos := token.Position{"", 0, 1, 1} // expected position
        for {
                pos, tok, lit := s.Scan()
+               if lit == "" {
+                       // no literal value for non-literal tokens
+                       lit = tok.String()
+               }
                e := elt{token.EOF, "", special}
                if index < len(tokens) {
                        e = tokens[index]
@@ -430,14 +439,14 @@ var lines = []string{
 
 func TestSemis(t *testing.T) {
        for _, line := range lines {
-               checkSemi(t, line, InsertSemis)
-               checkSemi(t, line, InsertSemis|ScanComments)
+               checkSemi(t, line, 0)
+               checkSemi(t, line, ScanComments)
 
                // if the input ended in newlines, the input must tokenize the
                // same with or without those newlines
                for i := len(line) - 1; i >= 0 && line[i] == '\n'; i-- {
-                       checkSemi(t, line[0:i], InsertSemis)
-                       checkSemi(t, line[0:i], InsertSemis|ScanComments)
+                       checkSemi(t, line[0:i], 0)
+                       checkSemi(t, line[0:i], ScanComments)
                }
        }
 }
@@ -492,7 +501,7 @@ func TestLineComments(t *testing.T) {
        // verify scan
        var S Scanner
        file := fset.AddFile(filepath.Join("dir", "TestLineComments"), fset.Base(), len(src))
-       S.Init(file, []byte(src), nil, 0)
+       S.Init(file, []byte(src), nil, dontInsertSemis)
        for _, s := range segs {
                p, _, lit := S.Scan()
                pos := file.Position(p)
@@ -511,7 +520,7 @@ func TestInit(t *testing.T) {
        // 1st init
        src1 := "if true { }"
        f1 := fset.AddFile("src1", fset.Base(), len(src1))
-       s.Init(f1, []byte(src1), nil, 0)
+       s.Init(f1, []byte(src1), nil, dontInsertSemis)
        if f1.Size() != len(src1) {
                t.Errorf("bad file size: got %d, expected %d", f1.Size(), len(src1))
        }
@@ -525,7 +534,7 @@ func TestInit(t *testing.T) {
        // 2nd init
        src2 := "go true { ]"
        f2 := fset.AddFile("src2", fset.Base(), len(src2))
-       s.Init(f2, []byte(src2), nil, 0)
+       s.Init(f2, []byte(src2), nil, dontInsertSemis)
        if f2.Size() != len(src2) {
                t.Errorf("bad file size: got %d, expected %d", f2.Size(), len(src2))
        }
@@ -551,7 +560,7 @@ func TestStdErrorHander(t *testing.T) {
 
        v := new(ErrorVector)
        var s Scanner
-       s.Init(fset.AddFile("File1", fset.Base(), len(src)), []byte(src), v, 0)
+       s.Init(fset.AddFile("File1", fset.Base(), len(src)), []byte(src), v, dontInsertSemis)
        for {
                if _, tok, _ := s.Scan(); tok == token.EOF {
                        break
@@ -596,7 +605,7 @@ func (h *errorCollector) Error(pos token.Position, msg string) {
 func checkError(t *testing.T, src string, tok token.Token, pos int, err string) {
        var s Scanner
        var h errorCollector
-       s.Init(fset.AddFile("", fset.Base(), len(src)), []byte(src), &h, ScanComments)
+       s.Init(fset.AddFile("", fset.Base(), len(src)), []byte(src), &h, ScanComments|dontInsertSemis)
        _, tok0, _ := s.Scan()
        _, tok1, _ := s.Scan()
        if tok0 != tok {
@@ -659,3 +668,20 @@ func TestScanErrors(t *testing.T) {
                checkError(t, e.src, e.tok, e.pos, e.err)
        }
 }
+
+func BenchmarkScan(b *testing.B) {
+       b.StopTimer()
+       fset := token.NewFileSet()
+       file := fset.AddFile("", fset.Base(), len(source))
+       var s Scanner
+       b.StartTimer()
+       for i := b.N - 1; i >= 0; i-- {
+               s.Init(file, source, nil, ScanComments)
+               for {
+                       _, tok, _ := s.Scan()
+                       if tok == token.EOF {
+                               break
+                       }
+               }
+       }
+}
index 557374052c91ad96d010e1b3973eb7b5632fab0c..84b6314d57af1c8a623fec7b5cd155ca32318c1b 100644 (file)
@@ -283,10 +283,8 @@ func init() {
 
 // Lookup maps an identifier to its keyword token or IDENT (if not a keyword).
 //
-func Lookup(ident []byte) Token {
-       // TODO Maps with []byte key are illegal because []byte does not
-       //      support == . Should find a more efficient solution eventually.
-       if tok, is_keyword := keywords[string(ident)]; is_keyword {
+func Lookup(ident string) Token {
+       if tok, is_keyword := keywords[ident]; is_keyword {
                return tok
        }
        return IDENT
@@ -295,16 +293,16 @@ func Lookup(ident []byte) Token {
 // Predicates
 
 // IsLiteral returns true for tokens corresponding to identifiers
-// and basic type literals; returns false otherwise.
+// and basic type literals; it returns false otherwise.
 //
 func (tok Token) IsLiteral() bool { return literal_beg < tok && tok < literal_end }
 
 // IsOperator returns true for tokens corresponding to operators and
-// delimiters; returns false otherwise.
+// delimiters; it returns false otherwise.
 //
 func (tok Token) IsOperator() bool { return operator_beg < tok && tok < operator_end }
 
 // IsKeyword returns true for tokens corresponding to keywords;
-// returns false otherwise.
+// it returns false otherwise.
 //
 func (tok Token) IsKeyword() bool { return keyword_beg < tok && tok < keyword_end }
index 0f9b4ad560d3268ab60e1d2f76d834580329daff..3ba81ce4d6f1655b32a0b6bae1f5a7d97b389f9e 100644 (file)
@@ -4,6 +4,42 @@
 
 package html
 
+import (
+       "strings"
+)
+
+func adjustForeignAttributes(aa []Attribute) {
+       for i, a := range aa {
+               if a.Key == "" || a.Key[0] != 'x' {
+                       continue
+               }
+               switch a.Key {
+               case "xlink:actuate", "xlink:arcrole", "xlink:href", "xlink:role", "xlink:show",
+                       "xlink:title", "xlink:type", "xml:base", "xml:lang", "xml:space", "xmlns:xlink":
+                       j := strings.Index(a.Key, ":")
+                       aa[i].Namespace = a.Key[:j]
+                       aa[i].Key = a.Key[j+1:]
+               }
+       }
+}
+
+func htmlIntegrationPoint(n *Node) bool {
+       if n.Type != ElementNode {
+               return false
+       }
+       switch n.Namespace {
+       case "math":
+               // TODO: annotation-xml elements whose start tags have "text/html" or
+               // "application/xhtml+xml" encodings.
+       case "svg":
+               switch n.Data {
+               case "desc", "foreignObject", "title":
+                       return true
+               }
+       }
+       return false
+}
+
 // Section 12.2.5.5.
 var breakout = map[string]bool{
        "b":          true,
@@ -53,4 +89,44 @@ var breakout = map[string]bool{
        "var":        true,
 }
 
-// TODO: add look-up tables for MathML and SVG adjustments.
+// Section 12.2.5.5.
+var svgTagNameAdjustments = map[string]string{
+       "altglyph":            "altGlyph",
+       "altglyphdef":         "altGlyphDef",
+       "altglyphitem":        "altGlyphItem",
+       "animatecolor":        "animateColor",
+       "animatemotion":       "animateMotion",
+       "animatetransform":    "animateTransform",
+       "clippath":            "clipPath",
+       "feblend":             "feBlend",
+       "fecolormatrix":       "feColorMatrix",
+       "fecomponenttransfer": "feComponentTransfer",
+       "fecomposite":         "feComposite",
+       "feconvolvematrix":    "feConvolveMatrix",
+       "fediffuselighting":   "feDiffuseLighting",
+       "fedisplacementmap":   "feDisplacementMap",
+       "fedistantlight":      "feDistantLight",
+       "feflood":             "feFlood",
+       "fefunca":             "feFuncA",
+       "fefuncb":             "feFuncB",
+       "fefuncg":             "feFuncG",
+       "fefuncr":             "feFuncR",
+       "fegaussianblur":      "feGaussianBlur",
+       "feimage":             "feImage",
+       "femerge":             "feMerge",
+       "femergenode":         "feMergeNode",
+       "femorphology":        "feMorphology",
+       "feoffset":            "feOffset",
+       "fepointlight":        "fePointLight",
+       "fespecularlighting":  "feSpecularLighting",
+       "fespotlight":         "feSpotLight",
+       "fetile":              "feTile",
+       "feturbulence":        "feTurbulence",
+       "foreignobject":       "foreignObject",
+       "glyphref":            "glyphRef",
+       "lineargradient":      "linearGradient",
+       "radialgradient":      "radialGradient",
+       "textpath":            "textPath",
+}
+
+// TODO: add look-up tables for MathML and SVG attribute adjustments.
index 4ba3f5fb627117893f1ace65646d172db34d7819..83f17308b18a10de4f0162bdd7621a87a57a2eb1 100644 (file)
@@ -26,6 +26,10 @@ var scopeMarker = Node{Type: scopeMarkerNode}
 // content for text) and are part of a tree of Nodes. Element nodes may also
 // have a Namespace and contain a slice of Attributes. Data is unescaped, so
 // that it looks like "a<b" rather than "a&lt;b".
+//
+// An empty Namespace implies a "http://www.w3.org/1999/xhtml" namespace.
+// Similarly, "math" is short for "http://www.w3.org/1998/Math/MathML", and
+// "svg" is short for "http://www.w3.org/2000/svg".
 type Node struct {
        Parent    *Node
        Child     []*Node
index 6962e643932cae1a27a559697393ca33400de8fb..43c04727ab83d7adfdbbab3372d54d1c9e2427b5 100644 (file)
@@ -51,58 +51,87 @@ func (p *parser) top() *Node {
        return p.doc
 }
 
-// stopTags for use in popUntil. These come from section 12.2.3.2.
+// Stop tags for use in popUntil. These come from section 12.2.3.2.
 var (
-       defaultScopeStopTags  = []string{"applet", "caption", "html", "table", "td", "th", "marquee", "object"}
-       listItemScopeStopTags = []string{"applet", "caption", "html", "table", "td", "th", "marquee", "object", "ol", "ul"}
-       buttonScopeStopTags   = []string{"applet", "caption", "html", "table", "td", "th", "marquee", "object", "button"}
-       tableScopeStopTags    = []string{"html", "table"}
+       defaultScopeStopTags = map[string][]string{
+               "":     {"applet", "caption", "html", "table", "td", "th", "marquee", "object"},
+               "math": {"annotation-xml", "mi", "mn", "mo", "ms", "mtext"},
+               "svg":  {"desc", "foreignObject", "title"},
+       }
 )
 
-// stopTags for use in clearStackToContext.
-var (
-       tableRowContextStopTags = []string{"tr", "html"}
+type scope int
+
+const (
+       defaultScope scope = iota
+       listItemScope
+       buttonScope
+       tableScope
+       tableRowScope
 )
 
 // popUntil pops the stack of open elements at the highest element whose tag
-// is in matchTags, provided there is no higher element in stopTags. It returns
-// whether or not there was such an element. If there was not, popUntil leaves
-// the stack unchanged.
+// is in matchTags, provided there is no higher element in the scope's stop
+// tags (as defined in section 12.2.3.2). It returns whether or not there was
+// such an element. If there was not, popUntil leaves the stack unchanged.
 //
-// For example, if the stack was:
+// For example, the set of stop tags for table scope is: "html", "table". If
+// the stack was:
 // ["html", "body", "font", "table", "b", "i", "u"]
-// then popUntil([]string{"html, "table"}, "font") would return false, but
-// popUntil([]string{"html, "table"}, "i") would return true and the resultant
-// stack would be:
+// then popUntil(tableScope, "font") would return false, but
+// popUntil(tableScope, "i") would return true and the stack would become:
 // ["html", "body", "font", "table", "b"]
 //
-// If an element's tag is in both stopTags and matchTags, then the stack will
-// be popped and the function returns true (provided, of course, there was no
-// higher element in the stack that was also in stopTags). For example,
-// popUntil([]string{"html, "table"}, "table") would return true and leave:
+// If an element's tag is in both the stop tags and matchTags, then the stack
+// will be popped and the function returns true (provided, of course, there was
+// no higher element in the stack that was also in the stop tags). For example,
+// popUntil(tableScope, "table") returns true and leaves:
 // ["html", "body", "font"]
-func (p *parser) popUntil(stopTags []string, matchTags ...string) bool {
-       if i := p.indexOfElementInScope(stopTags, matchTags...); i != -1 {
+func (p *parser) popUntil(s scope, matchTags ...string) bool {
+       if i := p.indexOfElementInScope(s, matchTags...); i != -1 {
                p.oe = p.oe[:i]
                return true
        }
        return false
 }
 
-// indexOfElementInScope returns the index in p.oe of the highest element
-// whose tag is in matchTags that is in scope according to stopTags.
-// If no matching element is in scope, it returns -1.
-func (p *parser) indexOfElementInScope(stopTags []string, matchTags ...string) int {
+// indexOfElementInScope returns the index in p.oe of the highest element whose
+// tag is in matchTags that is in scope. If no matching element is in scope, it
+// returns -1.
+func (p *parser) indexOfElementInScope(s scope, matchTags ...string) int {
        for i := len(p.oe) - 1; i >= 0; i-- {
                tag := p.oe[i].Data
-               for _, t := range matchTags {
-                       if t == tag {
-                               return i
+               if p.oe[i].Namespace == "" {
+                       for _, t := range matchTags {
+                               if t == tag {
+                                       return i
+                               }
+                       }
+                       switch s {
+                       case defaultScope:
+                               // No-op.
+                       case listItemScope:
+                               if tag == "ol" || tag == "ul" {
+                                       return -1
+                               }
+                       case buttonScope:
+                               if tag == "button" {
+                                       return -1
+                               }
+                       case tableScope:
+                               if tag == "html" || tag == "table" {
+                                       return -1
+                               }
+                       default:
+                               panic("unreachable")
                        }
                }
-               for _, t := range stopTags {
-                       if t == tag {
-                               return -1
+               switch s {
+               case defaultScope, listItemScope, buttonScope:
+                       for _, t := range defaultScopeStopTags[p.oe[i].Namespace] {
+                               if t == tag {
+                                       return -1
+                               }
                        }
                }
        }
@@ -111,8 +140,30 @@ func (p *parser) indexOfElementInScope(stopTags []string, matchTags ...string) i
 
 // elementInScope is like popUntil, except that it doesn't modify the stack of
 // open elements.
-func (p *parser) elementInScope(stopTags []string, matchTags ...string) bool {
-       return p.indexOfElementInScope(stopTags, matchTags...) != -1
+func (p *parser) elementInScope(s scope, matchTags ...string) bool {
+       return p.indexOfElementInScope(s, matchTags...) != -1
+}
+
+// clearStackToContext pops elements off the stack of open elements until a
+// scope-defined element is found.
+func (p *parser) clearStackToContext(s scope) {
+       for i := len(p.oe) - 1; i >= 0; i-- {
+               tag := p.oe[i].Data
+               switch s {
+               case tableScope:
+                       if tag == "html" || tag == "table" {
+                               p.oe = p.oe[:i+1]
+                               return
+                       }
+               case tableRowScope:
+                       if tag == "html" || tag == "tr" {
+                               p.oe = p.oe[:i+1]
+                               return
+                       }
+               default:
+                       panic("unreachable")
+               }
+       }
 }
 
 // addChild adds a child node n to the top element, and pushes n onto the stack
@@ -192,10 +243,9 @@ func (p *parser) addText(text string) {
 // addElement calls addChild with an element node.
 func (p *parser) addElement(tag string, attr []Attribute) {
        p.addChild(&Node{
-               Type:      ElementNode,
-               Data:      tag,
-               Namespace: p.top().Namespace,
-               Attr:      attr,
+               Type: ElementNode,
+               Data: tag,
+               Attr: attr,
        })
 }
 
@@ -624,10 +674,10 @@ func inBodyIM(p *parser) bool {
                case "html":
                        copyAttributes(p.oe[0], p.tok)
                case "address", "article", "aside", "blockquote", "center", "details", "dir", "div", "dl", "fieldset", "figcaption", "figure", "footer", "header", "hgroup", "menu", "nav", "ol", "p", "section", "summary", "ul":
-                       p.popUntil(buttonScopeStopTags, "p")
+                       p.popUntil(buttonScope, "p")
                        p.addElement(p.tok.Data, p.tok.Attr)
                case "h1", "h2", "h3", "h4", "h5", "h6":
-                       p.popUntil(buttonScopeStopTags, "p")
+                       p.popUntil(buttonScope, "p")
                        switch n := p.top(); n.Data {
                        case "h1", "h2", "h3", "h4", "h5", "h6":
                                p.oe.pop()
@@ -649,7 +699,7 @@ func inBodyIM(p *parser) bool {
                        p.addFormattingElement(p.tok.Data, p.tok.Attr)
                case "nobr":
                        p.reconstructActiveFormattingElements()
-                       if p.elementInScope(defaultScopeStopTags, "nobr") {
+                       if p.elementInScope(defaultScope, "nobr") {
                                p.inBodyEndTagFormatting("nobr")
                                p.reconstructActiveFormattingElements()
                        }
@@ -667,14 +717,14 @@ func inBodyIM(p *parser) bool {
                        p.framesetOK = false
                case "table":
                        if !p.quirks {
-                               p.popUntil(buttonScopeStopTags, "p")
+                               p.popUntil(buttonScope, "p")
                        }
                        p.addElement(p.tok.Data, p.tok.Attr)
                        p.framesetOK = false
                        p.im = inTableIM
                        return true
                case "hr":
-                       p.popUntil(buttonScopeStopTags, "p")
+                       p.popUntil(buttonScope, "p")
                        p.addElement(p.tok.Data, p.tok.Attr)
                        p.oe.pop()
                        p.acknowledgeSelfClosingTag()
@@ -683,12 +733,11 @@ func inBodyIM(p *parser) bool {
                        p.reconstructActiveFormattingElements()
                        p.addElement(p.tok.Data, p.tok.Attr)
                        p.framesetOK = false
-                       // TODO: detect <select> inside a table.
                        p.im = inSelectIM
                        return true
                case "form":
                        if p.form == nil {
-                               p.popUntil(buttonScopeStopTags, "p")
+                               p.popUntil(buttonScope, "p")
                                p.addElement(p.tok.Data, p.tok.Attr)
                                p.form = p.top()
                        }
@@ -698,7 +747,7 @@ func inBodyIM(p *parser) bool {
                                node := p.oe[i]
                                switch node.Data {
                                case "li":
-                                       p.popUntil(listItemScopeStopTags, "li")
+                                       p.popUntil(listItemScope, "li")
                                case "address", "div", "p":
                                        continue
                                default:
@@ -708,7 +757,7 @@ func inBodyIM(p *parser) bool {
                                }
                                break
                        }
-                       p.popUntil(buttonScopeStopTags, "p")
+                       p.popUntil(buttonScope, "p")
                        p.addElement(p.tok.Data, p.tok.Attr)
                case "dd", "dt":
                        p.framesetOK = false
@@ -726,13 +775,13 @@ func inBodyIM(p *parser) bool {
                                }
                                break
                        }
-                       p.popUntil(buttonScopeStopTags, "p")
+                       p.popUntil(buttonScope, "p")
                        p.addElement(p.tok.Data, p.tok.Attr)
                case "plaintext":
-                       p.popUntil(buttonScopeStopTags, "p")
+                       p.popUntil(buttonScope, "p")
                        p.addElement(p.tok.Data, p.tok.Attr)
                case "button":
-                       p.popUntil(defaultScopeStopTags, "button")
+                       p.popUntil(defaultScope, "button")
                        p.reconstructActiveFormattingElements()
                        p.addElement(p.tok.Data, p.tok.Attr)
                        p.framesetOK = false
@@ -750,6 +799,19 @@ func inBodyIM(p *parser) bool {
                                        copyAttributes(body, p.tok)
                                }
                        }
+               case "frameset":
+                       if !p.framesetOK || len(p.oe) < 2 || p.oe[1].Data != "body" {
+                               // Ignore the token.
+                               return true
+                       }
+                       body := p.oe[1]
+                       if body.Parent != nil {
+                               body.Parent.Remove(body)
+                       }
+                       p.oe = p.oe[:1]
+                       p.addElement(p.tok.Data, p.tok.Attr)
+                       p.im = inFramesetIM
+                       return true
                case "base", "basefont", "bgsound", "command", "link", "meta", "noframes", "script", "style", "title":
                        return inHeadIM(p)
                case "image":
@@ -776,7 +838,7 @@ func inBodyIM(p *parser) bool {
                                }
                        }
                        p.acknowledgeSelfClosingTag()
-                       p.popUntil(buttonScopeStopTags, "p")
+                       p.popUntil(buttonScope, "p")
                        p.addElement("form", nil)
                        p.form = p.top()
                        if action != "" {
@@ -794,23 +856,20 @@ func inBodyIM(p *parser) bool {
                        p.oe.pop()
                        p.form = nil
                case "xmp":
-                       p.popUntil(buttonScopeStopTags, "p")
+                       p.popUntil(buttonScope, "p")
                        p.reconstructActiveFormattingElements()
                        p.framesetOK = false
                        p.addElement(p.tok.Data, p.tok.Attr)
                case "math", "svg":
                        p.reconstructActiveFormattingElements()
-                       namespace := ""
                        if p.tok.Data == "math" {
                                // TODO: adjust MathML attributes.
-                               namespace = "mathml"
                        } else {
                                // TODO: adjust SVG attributes.
-                               namespace = "svg"
                        }
-                       // TODO: adjust foreign attributes.
+                       adjustForeignAttributes(p.tok.Attr)
                        p.addElement(p.tok.Data, p.tok.Attr)
-                       p.top().Namespace = namespace
+                       p.top().Namespace = p.tok.Data
                        return true
                case "caption", "col", "colgroup", "frame", "head", "tbody", "td", "tfoot", "th", "thead", "tr":
                        // Ignore the token.
@@ -825,16 +884,16 @@ func inBodyIM(p *parser) bool {
                        p.im = afterBodyIM
                        return true
                case "p":
-                       if !p.elementInScope(buttonScopeStopTags, "p") {
+                       if !p.elementInScope(buttonScope, "p") {
                                p.addElement("p", nil)
                        }
-                       p.popUntil(buttonScopeStopTags, "p")
+                       p.popUntil(buttonScope, "p")
                case "a", "b", "big", "code", "em", "font", "i", "nobr", "s", "small", "strike", "strong", "tt", "u":
                        p.inBodyEndTagFormatting(p.tok.Data)
                case "address", "article", "aside", "blockquote", "button", "center", "details", "dir", "div", "dl", "fieldset", "figcaption", "figure", "footer", "header", "hgroup", "listing", "menu", "nav", "ol", "pre", "section", "summary", "ul":
-                       p.popUntil(defaultScopeStopTags, p.tok.Data)
+                       p.popUntil(defaultScope, p.tok.Data)
                case "applet", "marquee", "object":
-                       if p.popUntil(defaultScopeStopTags, p.tok.Data) {
+                       if p.popUntil(defaultScope, p.tok.Data) {
                                p.clearActiveFormattingElements()
                        }
                case "br":
@@ -883,7 +942,7 @@ func (p *parser) inBodyEndTagFormatting(tag string) {
                        p.afe.remove(formattingElement)
                        return
                }
-               if !p.elementInScope(defaultScopeStopTags, tag) {
+               if !p.elementInScope(defaultScope, tag) {
                        // Ignore the tag.
                        return
                }
@@ -1017,45 +1076,56 @@ func inTableIM(p *parser) bool {
        case StartTagToken:
                switch p.tok.Data {
                case "caption":
-                       p.clearStackToContext(tableScopeStopTags)
+                       p.clearStackToContext(tableScope)
                        p.afe = append(p.afe, &scopeMarker)
                        p.addElement(p.tok.Data, p.tok.Attr)
                        p.im = inCaptionIM
                        return true
                case "tbody", "tfoot", "thead":
-                       p.clearStackToContext(tableScopeStopTags)
+                       p.clearStackToContext(tableScope)
                        p.addElement(p.tok.Data, p.tok.Attr)
                        p.im = inTableBodyIM
                        return true
                case "td", "th", "tr":
-                       p.clearStackToContext(tableScopeStopTags)
+                       p.clearStackToContext(tableScope)
                        p.addElement("tbody", nil)
                        p.im = inTableBodyIM
                        return false
                case "table":
-                       if p.popUntil(tableScopeStopTags, "table") {
+                       if p.popUntil(tableScope, "table") {
                                p.resetInsertionMode()
                                return false
                        }
                        // Ignore the token.
                        return true
                case "colgroup":
-                       p.clearStackToContext(tableScopeStopTags)
+                       p.clearStackToContext(tableScope)
                        p.addElement(p.tok.Data, p.tok.Attr)
                        p.im = inColumnGroupIM
                        return true
                case "col":
-                       p.clearStackToContext(tableScopeStopTags)
+                       p.clearStackToContext(tableScope)
                        p.addElement("colgroup", p.tok.Attr)
                        p.im = inColumnGroupIM
                        return false
+               case "select":
+                       p.reconstructActiveFormattingElements()
+                       switch p.top().Data {
+                       case "table", "tbody", "tfoot", "thead", "tr":
+                               p.fosterParenting = true
+                       }
+                       p.addElement(p.tok.Data, p.tok.Attr)
+                       p.fosterParenting = false
+                       p.framesetOK = false
+                       p.im = inSelectInTableIM
+                       return true
                default:
                        // TODO.
                }
        case EndTagToken:
                switch p.tok.Data {
                case "table":
-                       if p.popUntil(tableScopeStopTags, "table") {
+                       if p.popUntil(tableScope, "table") {
                                p.resetInsertionMode()
                                return true
                        }
@@ -1082,26 +1152,13 @@ func inTableIM(p *parser) bool {
        return inBodyIM(p)
 }
 
-// clearStackToContext pops elements off the stack of open elements
-// until an element listed in stopTags is found.
-func (p *parser) clearStackToContext(stopTags []string) {
-       for i := len(p.oe) - 1; i >= 0; i-- {
-               for _, tag := range stopTags {
-                       if p.oe[i].Data == tag {
-                               p.oe = p.oe[:i+1]
-                               return
-                       }
-               }
-       }
-}
-
 // Section 12.2.5.4.11.
 func inCaptionIM(p *parser) bool {
        switch p.tok.Type {
        case StartTagToken:
                switch p.tok.Data {
                case "caption", "col", "colgroup", "tbody", "td", "tfoot", "thead", "tr":
-                       if p.popUntil(tableScopeStopTags, "caption") {
+                       if p.popUntil(tableScope, "caption") {
                                p.clearActiveFormattingElements()
                                p.im = inTableIM
                                return false
@@ -1109,17 +1166,23 @@ func inCaptionIM(p *parser) bool {
                                // Ignore the token.
                                return true
                        }
+               case "select":
+                       p.reconstructActiveFormattingElements()
+                       p.addElement(p.tok.Data, p.tok.Attr)
+                       p.framesetOK = false
+                       p.im = inSelectInTableIM
+                       return true
                }
        case EndTagToken:
                switch p.tok.Data {
                case "caption":
-                       if p.popUntil(tableScopeStopTags, "caption") {
+                       if p.popUntil(tableScope, "caption") {
                                p.clearActiveFormattingElements()
                                p.im = inTableIM
                        }
                        return true
                case "table":
-                       if p.popUntil(tableScopeStopTags, "caption") {
+                       if p.popUntil(tableScope, "caption") {
                                p.clearActiveFormattingElements()
                                p.im = inTableIM
                                return false
@@ -1203,7 +1266,7 @@ func inTableBodyIM(p *parser) bool {
                        data = "tr"
                        consumed = false
                case "caption", "col", "colgroup", "tbody", "tfoot", "thead":
-                       if !p.popUntil(tableScopeStopTags, "tbody", "thead", "tfoot") {
+                       if !p.popUntil(tableScope, "tbody", "thead", "tfoot") {
                                // Ignore the token.
                                return true
                        }
@@ -1215,7 +1278,7 @@ func inTableBodyIM(p *parser) bool {
        case EndTagToken:
                switch p.tok.Data {
                case "table":
-                       if p.popUntil(tableScopeStopTags, "tbody", "thead", "tfoot") {
+                       if p.popUntil(tableScope, "tbody", "thead", "tfoot") {
                                p.im = inTableIM
                                return false
                        }
@@ -1251,13 +1314,13 @@ func inRowIM(p *parser) bool {
        case StartTagToken:
                switch p.tok.Data {
                case "td", "th":
-                       p.clearStackToContext(tableRowContextStopTags)
+                       p.clearStackToContext(tableRowScope)
                        p.addElement(p.tok.Data, p.tok.Attr)
                        p.afe = append(p.afe, &scopeMarker)
                        p.im = inCellIM
                        return true
                case "caption", "col", "colgroup", "tbody", "tfoot", "thead", "tr":
-                       if p.popUntil(tableScopeStopTags, "tr") {
+                       if p.popUntil(tableScope, "tr") {
                                p.im = inTableBodyIM
                                return false
                        }
@@ -1269,14 +1332,14 @@ func inRowIM(p *parser) bool {
        case EndTagToken:
                switch p.tok.Data {
                case "tr":
-                       if p.popUntil(tableScopeStopTags, "tr") {
+                       if p.popUntil(tableScope, "tr") {
                                p.im = inTableBodyIM
                                return true
                        }
                        // Ignore the token.
                        return true
                case "table":
-                       if p.popUntil(tableScopeStopTags, "tr") {
+                       if p.popUntil(tableScope, "tr") {
                                p.im = inTableBodyIM
                                return false
                        }
@@ -1311,11 +1374,17 @@ func inCellIM(p *parser) bool {
                case "caption", "col", "colgroup", "tbody", "td", "tfoot", "th", "thead", "tr":
                        // TODO: check for "td" or "th" in table scope.
                        closeTheCellAndReprocess = true
+               case "select":
+                       p.reconstructActiveFormattingElements()
+                       p.addElement(p.tok.Data, p.tok.Attr)
+                       p.framesetOK = false
+                       p.im = inSelectInTableIM
+                       return true
                }
        case EndTagToken:
                switch p.tok.Data {
                case "td", "th":
-                       if !p.popUntil(tableScopeStopTags, p.tok.Data) {
+                       if !p.popUntil(tableScope, p.tok.Data) {
                                // Ignore the token.
                                return true
                        }
@@ -1336,7 +1405,7 @@ func inCellIM(p *parser) bool {
                return true
        }
        if closeTheCellAndReprocess {
-               if p.popUntil(tableScopeStopTags, "td") || p.popUntil(tableScopeStopTags, "th") {
+               if p.popUntil(tableScope, "td") || p.popUntil(tableScope, "th") {
                        p.clearActiveFormattingElements()
                        p.im = inRowIM
                        return false
@@ -1405,21 +1474,40 @@ func inSelectIM(p *parser) bool {
                })
        }
        if endSelect {
-               for i := len(p.oe) - 1; i >= 0; i-- {
-                       switch p.oe[i].Data {
-                       case "select":
-                               p.oe = p.oe[:i]
-                               p.resetInsertionMode()
-                               return true
-                       case "option", "optgroup":
-                               continue
-                       default:
+               p.endSelect()
+       }
+       return true
+}
+
+// Section 12.2.5.4.17.
+func inSelectInTableIM(p *parser) bool {
+       switch p.tok.Type {
+       case StartTagToken, EndTagToken:
+               switch p.tok.Data {
+               case "caption", "table", "tbody", "tfoot", "thead", "tr", "td", "th":
+                       if p.tok.Type == StartTagToken || p.elementInScope(tableScope, p.tok.Data) {
+                               p.endSelect()
+                               return false
+                       } else {
                                // Ignore the token.
                                return true
                        }
                }
        }
-       return true
+       return inSelectIM(p)
+}
+
+func (p *parser) endSelect() {
+       for i := len(p.oe) - 1; i >= 0; i-- {
+               switch p.oe[i].Data {
+               case "option", "optgroup":
+                       continue
+               case "select":
+                       p.oe = p.oe[:i]
+                       p.resetInsertionMode()
+               }
+               return
+       }
 }
 
 // Section 12.2.5.4.18.
@@ -1618,6 +1706,11 @@ func parseForeignContent(p *parser) bool {
                        Data: p.tok.Data,
                })
        case StartTagToken:
+               if htmlIntegrationPoint(p.top()) {
+                       inBodyIM(p)
+                       p.resetInsertionMode()
+                       return true
+               }
                if breakout[p.tok.Data] {
                        for i := len(p.oe) - 1; i >= 0; i-- {
                                // TODO: HTML, MathML integration points.
@@ -1629,16 +1722,22 @@ func parseForeignContent(p *parser) bool {
                        return false
                }
                switch p.top().Namespace {
-               case "mathml":
+               case "math":
                        // TODO: adjust MathML attributes.
                case "svg":
-                       // TODO: adjust SVG tag names.
+                       // Adjust SVG tag names. The tokenizer lower-cases tag names, but
+                       // SVG wants e.g. "foreignObject" with a capital second "O".
+                       if x := svgTagNameAdjustments[p.tok.Data]; x != "" {
+                               p.tok.Data = x
+                       }
                        // TODO: adjust SVG attributes.
                default:
                        panic("html: bad parser state: unexpected namespace")
                }
-               // TODO: adjust foreign attributes.
+               adjustForeignAttributes(p.tok.Attr)
+               namespace := p.top().Namespace
                p.addElement(p.tok.Data, p.tok.Attr)
+               p.top().Namespace = namespace
        case EndTagToken:
                for i := len(p.oe) - 1; i >= 0; i-- {
                        if p.oe[i].Namespace == "" {
index 015b5838f0b50b3a0e8f33e39c545dc68988a60c..c929c257727f2e27fbdef26bd2645d18b3b7083a 100644 (file)
@@ -103,10 +103,21 @@ func dumpLevel(w io.Writer, n *Node, level int) error {
                } else {
                        fmt.Fprintf(w, "<%s>", n.Data)
                }
-               for _, a := range n.Attr {
+               attr := n.Attr
+               if len(attr) == 2 && attr[0].Namespace == "xml" && attr[1].Namespace == "xlink" {
+                       // Some of the test cases in tests10.dat change the order of adjusted
+                       // foreign attributes, but that behavior is not in the spec, and could
+                       // simply be an implementation detail of html5lib's python map ordering.
+                       attr[0], attr[1] = attr[1], attr[0]
+               }
+               for _, a := range attr {
                        io.WriteString(w, "\n")
                        dumpIndent(w, level+1)
-                       fmt.Fprintf(w, `%s="%s"`, a.Key, a.Val)
+                       if a.Namespace != "" {
+                               fmt.Fprintf(w, `%s %s="%s"`, a.Namespace, a.Key, a.Val)
+                       } else {
+                               fmt.Fprintf(w, `%s="%s"`, a.Key, a.Val)
+                       }
                }
        case TextNode:
                fmt.Fprintf(w, `"%s"`, n.Data)
@@ -172,8 +183,8 @@ func TestParser(t *testing.T) {
                {"tests3.dat", -1},
                {"tests4.dat", -1},
                {"tests5.dat", -1},
-               {"tests6.dat", 47},
-               {"tests10.dat", 16},
+               {"tests6.dat", -1},
+               {"tests10.dat", 33},
        }
        for _, tf := range testFiles {
                f, err := os.Open("testdata/webkit/" + tf.filename)
index 20751938d9d4922039af1964251c255845509c7b..07859faa7dd833df71f5410bdf67e044ec3aa437 100644 (file)
@@ -149,6 +149,14 @@ func render1(w writer, n *Node) error {
                if err := w.WriteByte(' '); err != nil {
                        return err
                }
+               if a.Namespace != "" {
+                       if _, err := w.WriteString(a.Namespace); err != nil {
+                               return err
+                       }
+                       if err := w.WriteByte(':'); err != nil {
+                               return err
+                       }
+               }
                if _, err := w.WriteString(a.Key); err != nil {
                        return err
                }
index a57f9826b5b5a4e75a5214cc1bf039ed04228082..2ce1fb566a59d079647993c14a5664ccdf94743e 100644 (file)
@@ -302,7 +302,7 @@ func TestEscape(t *testing.T) {
                },
                {
                        "styleObfuscatedExpressionBlocked",
-                       `<p style="width: {{"  e\78preS\0Sio/**/n(alert(1337))"}}">`,
+                       `<p style="width: {{"  e\\78preS\x00Sio/**/n(alert(1337))"}}">`,
                        `<p style="width: ZgotmplZ">`,
                },
                {
@@ -312,7 +312,7 @@ func TestEscape(t *testing.T) {
                },
                {
                        "styleObfuscatedMozBindingBlocked",
-                       `<p style="{{"  -mo\7a-B\0I/**/nding(alert(1337))"}}: ...">`,
+                       `<p style="{{"  -mo\\7a-B\x00I/**/nding(alert(1337))"}}: ...">`,
                        `<p style="ZgotmplZ: ...">`,
                },
                {
index 69af96840c272b1c8bfa9313b174315d24f1d9ee..5a385a1b5c544df4fd762e0e6e7c926d4a825b4e 100644 (file)
@@ -52,11 +52,14 @@ func (t TokenType) String() string {
        return "Invalid(" + strconv.Itoa(int(t)) + ")"
 }
 
-// An Attribute is an attribute key-value pair. Key is alphabetic (and hence
+// An Attribute is an attribute namespace-key-value triple. Namespace is
+// non-empty for foreign attributes like xlink, Key is alphabetic (and hence
 // does not contain escapable characters like '&', '<' or '>'), and Val is
 // unescaped (it looks like "a<b" rather than "a&lt;b").
+//
+// Namespace is only used by the parser, not the tokenizer.
 type Attribute struct {
-       Key, Val string
+       Namespace, Key, Val string
 }
 
 // A Token consists of a TokenType and some Data (tag name for start and end
@@ -756,7 +759,7 @@ func (z *Tokenizer) Token() Token {
                for moreAttr {
                        var key, val []byte
                        key, val, moreAttr = z.TagAttr()
-                       attr = append(attr, Attribute{string(key), string(val)})
+                       attr = append(attr, Attribute{"", string(key), string(val)})
                }
                t.Data = string(name)
                t.Attr = attr
index a7d1a5798315d35b0cce25004b35695acf123485..b830f88e1c454cb760aedd0b954b888447a8408a 100644 (file)
@@ -51,25 +51,25 @@ func NewUniform(c color.Color) *Uniform {
        return &Uniform{c}
 }
 
-// A Tiled is an infinite-sized Image that repeats another Image in both
-// directions. Tiled{i, p}.At(x, y) will equal i.At(x+p.X, y+p.Y) for all
+// Repeated is an infinite-sized Image that repeats another Image in both
+// directions. Repeated{i, p}.At(x, y) will equal i.At(x+p.X, y+p.Y) for all
 // points {x+p.X, y+p.Y} within i's Bounds.
-type Tiled struct {
+type Repeated struct {
        I      Image
        Offset Point
 }
 
-func (t *Tiled) ColorModel() color.Model {
-       return t.I.ColorModel()
+func (r *Repeated) ColorModel() color.Model {
+       return r.I.ColorModel()
 }
 
-func (t *Tiled) Bounds() Rectangle { return Rectangle{Point{-1e9, -1e9}, Point{1e9, 1e9}} }
+func (r *Repeated) Bounds() Rectangle { return Rectangle{Point{-1e9, -1e9}, Point{1e9, 1e9}} }
 
-func (t *Tiled) At(x, y int) color.Color {
-       p := Point{x, y}.Add(t.Offset).Mod(t.I.Bounds())
-       return t.I.At(p.X, p.Y)
+func (r *Repeated) At(x, y int) color.Color {
+       p := Point{x, y}.Add(r.Offset).Mod(r.I.Bounds())
+       return r.I.At(p.X, p.Y)
 }
 
-func NewTiled(i Image, offset Point) *Tiled {
-       return &Tiled{i, offset}
+func NewRepeated(i Image, offset Point) *Repeated {
+       return &Repeated{i, offset}
 }
index 914391af80d65132c8e9cf2151b676aaa20c5518..aef63480f161682b9066a432337ca8a54164bad2 100644 (file)
@@ -93,13 +93,19 @@ func (w *Writer) Emerg(m string) (err error) {
        return err
 }
 
+// Alert logs a message using the LOG_ALERT priority.
+func (w *Writer) Alert(m string) (err error) {
+       _, err = w.writeString(LOG_ALERT, m)
+       return err
+}
+
 // Crit logs a message using the LOG_CRIT priority.
 func (w *Writer) Crit(m string) (err error) {
        _, err = w.writeString(LOG_CRIT, m)
        return err
 }
 
-// ERR logs a message using the LOG_ERR priority.
+// Err logs a message using the LOG_ERR priority.
 func (w *Writer) Err(m string) (err error) {
        _, err = w.writeString(LOG_ERR, m)
        return err
index 101c8dd85b4ac6f755fef9481738f68fef08804d..ed66a42fb0088b851b3ac29ca5e5233a705506f8 100644 (file)
@@ -2214,8 +2214,8 @@ func TestLogb(t *testing.T) {
                }
        }
        for i := 0; i < len(vffrexpBC); i++ {
-               if e := Logb(vffrexpBC[i]); !alike(logbBC[i], e) {
-                       t.Errorf("Ilogb(%g) = %g, want %g", vffrexpBC[i], e, logbBC[i])
+               if f := Logb(vffrexpBC[i]); !alike(logbBC[i], f) {
+                       t.Errorf("Logb(%g) = %g, want %g", vffrexpBC[i], f, logbBC[i])
                }
        }
 }
@@ -2536,7 +2536,7 @@ func TestLargeTan(t *testing.T) {
 }
 
 // Check that math constants are accepted by compiler
-// and have right value (assumes strconv.Atof works).
+// and have right value (assumes strconv.ParseFloat works).
 // http://code.google.com/p/go/issues/detail?id=201
 
 type floatTest struct {
index 69681ae2d640b825ba3b1bcbf3be370efea7b38f..16f6ce9ba1bc6c1dd1ee053ce02a8c556861e728 100644 (file)
@@ -715,13 +715,13 @@ func (x nat) decimalString() string {
 
 // string converts x to a string using digits from a charset; a digit with
 // value d is represented by charset[d]. The conversion base is determined
-// by len(charset), which must be >= 2.
+// by len(charset), which must be >= 2 and <= 256.
 func (x nat) string(charset string) string {
        b := Word(len(charset))
 
        // special cases
        switch {
-       case b < 2 || MaxBase < b:
+       case b < 2 || MaxBase > 256:
                panic("illegal base")
        case len(x) == 0:
                return string(charset[0])
@@ -773,49 +773,59 @@ func (x nat) string(charset string) string {
                        w >>= shift
                        nbits -= shift
                }
+
        } else {
-               // determine "big base" as in 10^19 for 19 decimal digits in a 64 bit Word
-               bb := Word(1) // big base is b**ndigits
-               ndigits := 0  // number of base b digits
+               // determine "big base"; i.e., the largest possible value bb
+               // that is a power of base b and still fits into a Word
+               // (as in 10^19 for 19 decimal digits in a 64bit Word)
+               bb := b      // big base is b**ndigits
+               ndigits := 1 // number of base b digits
                for max := Word(_M / b); bb <= max; bb *= b {
                        ndigits++ // maximize ndigits where bb = b**ndigits, bb <= _M
                }
 
                // construct table of successive squares of bb*leafSize to use in subdivisions
+               // result (table != nil) <=> (len(x) > leafSize > 0)
                table := divisors(len(x), b, ndigits, bb)
 
-               // preserve x, create local copy for use in divisions
+               // preserve x, create local copy for use by convertWords
                q := nat(nil).set(x)
 
-               // convert q to string s in base b with index of MSD indicated by return value
-               i = q.convertWords(0, i, s, charset, b, ndigits, bb, table)
+               // convert q to string s in base b
+               q.convertWords(s, charset, b, ndigits, bb, table)
+
+               // strip leading zeros
+               // (x != 0; thus s must contain at least one non-zero digit
+               // and the loop will terminate)
+               i = 0
+               for zero := charset[0]; s[i] == zero; {
+                       i++
+               }
        }
 
        return string(s[i:])
 }
 
-// Convert words of q to base b digits in s directly using iterated nat/Word divison to extract
-// low-order Words and indirectly by recursive subdivision and nat/nat division by tabulated 
-// divisors. 
+// Convert words of q to base b digits in s. If q is large, it is recursively "split in half"
+// by nat/nat division using tabulated divisors. Otherwise, it is converted iteratively using
+// repeated nat/Word divison.
 //
-// The direct method processes n Words by n divW() calls, each of which visits every Word in the 
+// The iterative method processes n Words by n divW() calls, each of which visits every Word in the 
 // incrementally shortened q for a total of n + (n-1) + (n-2) ... + 2 + 1, or n(n+1)/2 divW()'s. 
-// Indirect conversion divides q by its approximate square root, yielding two parts, each half 
-// the size of q. Using the direct method on both halves means 2 * (n/2)(n/2 + 1)/2 divW()'s plus 
-// the expensive long div(). Asymptotically, the ratio is favorable at 1/2 the divW()'s, and is 
-// made better by splitting the subblocks recursively. Best is to split blocks until one more 
+// Recursive conversion divides q by its approximate square root, yielding two parts, each half 
+// the size of q. Using the iterative method on both halves means 2 * (n/2)(n/2 + 1)/2 divW()'s
+// plus the expensive long div(). Asymptotically, the ratio is favorable at 1/2 the divW()'s, and
+// is made better by splitting the subblocks recursively. Best is to split blocks until one more 
 // split would take longer (because of the nat/nat div()) than the twice as many divW()'s of the 
-// direct approach. This threshold is represented by leafSize. Benchmarking of leafSize in the 
+// iterative approach. This threshold is represented by leafSize. Benchmarking of leafSize in the 
 // range 2..64 shows that values of 8 and 16 work well, with a 4x speedup at medium lengths and 
 // ~30x for 20000 digits. Use nat_test.go's BenchmarkLeafSize tests to optimize leafSize for 
 // specfic hardware.
 //
-// lo and hi index character array s. conversion starts with the LSD at hi and moves down toward
-// the MSD, which will be at s[0] or s[1]. lo == 0 signals span includes the most significant word.
-//
-func (q nat) convertWords(lo, hi int, s []byte, charset string, b Word, ndigits int, bb Word, table []divisor) int {
-       // indirect conversion: split larger blocks to reduce quadratic expense of iterated nat/W division
-       if leafSize > 0 && len(q) > leafSize && table != nil {
+func (q nat) convertWords(s []byte, charset string, b Word, ndigits int, bb Word, table []divisor) {
+       // split larger blocks recursively
+       if table != nil {
+               // len(q) > leafSize > 0
                var r nat
                index := len(table) - 1
                for len(q) > leafSize {
@@ -835,72 +845,52 @@ func (q nat) convertWords(lo, hi int, s []byte, charset string, b Word, ndigits
                        // split q into the two digit number (q'*bbb + r) to form independent subblocks
                        q, r = q.div(r, q, table[index].bbb)
 
-                       // convert subblocks and collect results in s[lo:partition] and s[partition:hi]
-                       partition := hi - table[index].ndigits
-                       r.convertWords(partition, hi, s, charset, b, ndigits, bb, table[0:index])
-                       hi = partition // i.e., q.convertWords(lo, partition, s, charset, b, ndigits, bb, table[0:index+1])
+                       // convert subblocks and collect results in s[:h] and s[h:]
+                       h := len(s) - table[index].ndigits
+                       r.convertWords(s[h:], charset, b, ndigits, bb, table[0:index])
+                       s = s[:h] // == q.convertWords(s, charset, b, ndigits, bb, table[0:index+1])
                }
-       } // having split any large blocks now process the remaining small block
+       }
 
-       // direct conversion: process smaller blocks monolithically to avoid overhead of nat/nat division
+       // having split any large blocks now process the remaining (small) block iteratively
+       i := len(s)
        var r Word
-       if b == 10 { // hard-coding for 10 here speeds this up by 1.25x (allows mod as mul vs div)
+       if b == 10 {
+               // hard-coding for 10 here speeds this up by 1.25x (allows for / and % by constants)
                for len(q) > 0 {
                        // extract least significant, base bb "digit"
                        q, r = q.divW(q, bb)
-                       if lo == 0 && len(q) == 0 {
-                               // skip leading zeros in most-significant group of digits
-                               for j := 0; j < ndigits && r != 0; j++ {
-                                       hi--
-                                       t := r / 10
-                                       s[hi] = charset[r-(t<<3+t<<1)] // 8*t + 2*t = 10*t; r - 10*int(r/10) = r mod 10
-                                       r = t
-                               }
-                       } else {
-                               for j := 0; j < ndigits && hi > lo; j++ {
-                                       hi--
-                                       t := r / 10
-                                       s[hi] = charset[r-(t<<3+t<<1)] // 8*t + 2*t = 10*t; r - 10*int(r/10) = r mod 10
-                                       r = t
-                               }
+                       for j := 0; j < ndigits && i > 0; j++ {
+                               i--
+                               // avoid % computation since r%10 == r - int(r/10)*10;
+                               // this appears to be faster for BenchmarkString10000Base10
+                               // and smaller strings (but a bit slower for larger ones)
+                               t := r / 10
+                               s[i] = charset[r-t<<3-t-t] // TODO(gri) replace w/ t*10 once compiler produces better code
+                               r = t
                        }
                }
        } else {
                for len(q) > 0 {
-                       // extract least significant group of digits
+                       // extract least significant, base bb "digit"
                        q, r = q.divW(q, bb)
-                       if lo == 0 && len(q) == 0 {
-                               // skip leading zeros in most-significant group of digits
-                               for j := 0; j < ndigits && r != 0; j++ {
-                                       hi--
-                                       s[hi] = charset[r%b]
-                                       r = r / b
-                               }
-                       } else {
-                               for j := 0; j < ndigits && hi > lo; j++ {
-                                       hi--
-                                       s[hi] = charset[r%b]
-                                       r = r / b
-                               }
+                       for j := 0; j < ndigits && i > 0; j++ {
+                               i--
+                               s[i] = charset[r%b]
+                               r /= b
                        }
                }
        }
 
-       // prepend high-order zeroes when q has been normalized to a short number of Words.
-       // however, do not prepend zeroes when converting the most dignificant digits.
-       if lo != 0 { // if not MSD
-               zero := charset[0]
-               for hi > lo { // while need more leading zeroes
-                       hi--
-                       s[hi] = zero
-               }
+       // prepend high-order zeroes
+       zero := charset[0]
+       for i > 0 { // while need more leading zeroes
+               i--
+               s[i] = zero
        }
-
-       // return index of most significant output digit in s[] (stored in lowest index)
-       return hi
 }
 
-// Split blocks greater than leafSize Words (or set to 0 to disable indirect conversion)
+// Split blocks greater than leafSize Words (or set to 0 to disable recursive conversion)
 // Benchmark and configure leafSize using: gotest -test.bench="Leaf"
 //   8 and 16 effective on 3.0 GHz Xeon "Clovertown" CPU (128 byte cache lines)
 //   8 and 16 effective on 2.66 GHz Core 2 Duo "Penryn" CPU
@@ -912,26 +902,30 @@ type divisor struct {
        ndigits int // digit length of divisor in terms of output base digits
 }
 
-const maxCache = 64               // maximum number of divisors in a single table
-var cacheBase10 [maxCache]divisor // cached divisors for base 10
-var cacheLock sync.Mutex          // defense against concurrent table extensions
+var cacheBase10 [64]divisor // cached divisors for base 10
+var cacheLock sync.Mutex    // protects cacheBase10
+
+// expWW computes x**y
+func (z nat) expWW(x, y Word) nat {
+       return z.expNN(nat(nil).setWord(x), nat(nil).setWord(y), nil)
+}
 
 // construct table of powers of bb*leafSize to use in subdivisions
 func divisors(m int, b Word, ndigits int, bb Word) []divisor {
-       // only build table when indirect conversion is enabled and x is large
+       // only compute table when recursive conversion is enabled and x is large
        if leafSize == 0 || m <= leafSize {
                return nil
        }
 
        // determine k where (bb**leafSize)**(2**k) >= sqrt(x)
        k := 1
-       for words := leafSize; words < m>>1 && k < maxCache; words <<= 1 {
+       for words := leafSize; words < m>>1 && k < len(cacheBase10); words <<= 1 {
                k++
        }
 
        // create new table of divisors or extend and reuse existing table as appropriate
-       var cached bool
        var table []divisor
+       var cached bool
        switch b {
        case 10:
                table = cacheBase10[0:k] // reuse old table for this conversion
@@ -946,28 +940,27 @@ func divisors(m int, b Word, ndigits int, bb Word) []divisor {
                        cacheLock.Lock() // begin critical section
                }
 
-               var i int
+               // add new entries as needed
                var larger nat
-               for i < k && table[i].ndigits != 0 { // skip existing entries
-                       i++
-               }
-               for ; i < k; i++ { // add new entries
-                       if i == 0 {
-                               table[i].bbb = nat(nil).expWW(bb, Word(leafSize))
-                               table[i].ndigits = ndigits * leafSize
-                       } else {
-                               table[i].bbb = nat(nil).mul(table[i-1].bbb, table[i-1].bbb)
-                               table[i].ndigits = 2 * table[i-1].ndigits
-                       }
+               for i := 0; i < k; i++ {
+                       if table[i].ndigits == 0 {
+                               if i == 0 {
+                                       table[i].bbb = nat(nil).expWW(bb, Word(leafSize))
+                                       table[i].ndigits = ndigits * leafSize
+                               } else {
+                                       table[i].bbb = nat(nil).mul(table[i-1].bbb, table[i-1].bbb)
+                                       table[i].ndigits = 2 * table[i-1].ndigits
+                               }
 
-                       // optimization: exploit aggregated extra bits in macro blocks
-                       larger = nat(nil).set(table[i].bbb)
-                       for mulAddVWW(larger, larger, b, 0) == 0 {
-                               table[i].bbb = table[i].bbb.set(larger)
-                               table[i].ndigits++
-                       }
+                               // optimization: exploit aggregated extra bits in macro blocks
+                               larger = nat(nil).set(table[i].bbb)
+                               for mulAddVWW(larger, larger, b, 0) == 0 {
+                                       table[i].bbb = table[i].bbb.set(larger)
+                                       table[i].ndigits++
+                               }
 
-                       table[i].nbits = table[i].bbb.bitLen()
+                               table[i].nbits = table[i].bbb.bitLen()
+                       }
                }
 
                if cached {
@@ -1295,11 +1288,6 @@ func (z nat) expNN(x, y, m nat) nat {
        return z.norm()
 }
 
-// calculate x**y for Word arguments y and y
-func (z nat) expWW(x, y Word) nat {
-       return z.expNN(nat(nil).setWord(x), nat(nil).setWord(y), nil)
-}
-
 // probablyPrime performs reps Miller-Rabin tests to check whether n is prime.
 // If it returns true, n is prime with probability 1 - 1/4^reps.
 // If it returns false, n is not prime.
index 7867fa8df36a368a9a6723fa9f93912f81aff685..868388efa1cd7d6b7249226d31418f63d9169855 100644 (file)
@@ -8,7 +8,6 @@ import (
        "os"
        "reflect"
        "runtime"
-       "syscall"
        "testing"
 )
 
@@ -67,7 +66,7 @@ func TestFileListener(t *testing.T) {
                testFileListener(t, "tcp", "127.0.0.1")
                testFileListener(t, "tcp", "[::ffff:127.0.0.1]")
        }
-       if syscall.OS == "linux" {
+       if runtime.GOOS == "linux" {
                testFileListener(t, "unix", "@gotest/net")
                testFileListener(t, "unixpacket", "@gotest/net")
        }
@@ -132,7 +131,7 @@ func TestFilePacketConn(t *testing.T) {
        if supportsIPv6 && supportsIPv4map {
                testFilePacketConnDial(t, "udp", "[::ffff:127.0.0.1]:12345")
        }
-       if syscall.OS == "linux" {
+       if runtime.GOOS == "linux" {
                testFilePacketConnListen(t, "unixgram", "@gotest1/net")
        }
 }
index cad852242e2e520d6db5defe39d5758f0cae6b0f..2e30bbff1777e1ac0e7872e3d4b3d5e8a8179bb1 100644 (file)
@@ -96,7 +96,7 @@ func readSetCookies(h Header) []*Cookie {
                                continue
                        case "max-age":
                                secs, err := strconv.Atoi(val)
-                               if err != nil || secs < 0 || secs != 0 && val[0] == '0' {
+                               if err != nil || secs != 0 && val[0] == '0' {
                                        break
                                }
                                if secs <= 0 {
index 66178490e37b80dfb8e427ab30eafe39b2539d91..260301005ebb13374b7ef082abe4f0aa81308855 100644 (file)
@@ -368,8 +368,8 @@ func (req *Request) write(w io.Writer, usingProxy bool, extraHeaders Header) err
        if err != nil {
                return err
        }
-       bw.Flush()
-       return nil
+
+       return bw.Flush()
 }
 
 // Convert decimal at s[i:len(s)] to integer,
index 95486a6301ecb846a3ec9665c40724f99aa1fa21..5e7b352ed50012943a323d6c5847d7d281953e2a 100644 (file)
@@ -12,6 +12,14 @@ import (
        "fmt"
 )
 
+var (
+       errInvalidInterface         = errors.New("net: invalid interface")
+       errInvalidInterfaceIndex    = errors.New("net: invalid interface index")
+       errInvalidInterfaceName     = errors.New("net: invalid interface name")
+       errNoSuchInterface          = errors.New("net: no such interface")
+       errNoSuchMulticastInterface = errors.New("net: no such multicast interface")
+)
+
 // A HardwareAddr represents a physical hardware address.
 type HardwareAddr []byte
 
@@ -131,7 +139,7 @@ func (f Flags) String() string {
 // Addrs returns interface addresses for a specific interface.
 func (ifi *Interface) Addrs() ([]Addr, error) {
        if ifi == nil {
-               return nil, errors.New("net: invalid interface")
+               return nil, errInvalidInterface
        }
        return interfaceAddrTable(ifi.Index)
 }
@@ -140,7 +148,7 @@ func (ifi *Interface) Addrs() ([]Addr, error) {
 // a specific interface.
 func (ifi *Interface) MulticastAddrs() ([]Addr, error) {
        if ifi == nil {
-               return nil, errors.New("net: invalid interface")
+               return nil, errInvalidInterface
        }
        return interfaceMulticastAddrTable(ifi.Index)
 }
@@ -159,7 +167,7 @@ func InterfaceAddrs() ([]Addr, error) {
 // InterfaceByIndex returns the interface specified by index.
 func InterfaceByIndex(index int) (*Interface, error) {
        if index <= 0 {
-               return nil, errors.New("net: invalid interface index")
+               return nil, errInvalidInterfaceIndex
        }
        ift, err := interfaceTable(index)
        if err != nil {
@@ -168,13 +176,13 @@ func InterfaceByIndex(index int) (*Interface, error) {
        for _, ifi := range ift {
                return &ifi, nil
        }
-       return nil, errors.New("net: no such interface")
+       return nil, errNoSuchInterface
 }
 
 // InterfaceByName returns the interface specified by name.
 func InterfaceByName(name string) (*Interface, error) {
        if name == "" {
-               return nil, errors.New("net: invalid interface name")
+               return nil, errInvalidInterfaceName
        }
        ift, err := interfaceTable(0)
        if err != nil {
@@ -185,5 +193,5 @@ func InterfaceByName(name string) (*Interface, error) {
                        return &ifi, nil
                }
        }
-       return nil, errors.New("net: no such interface")
+       return nil, errNoSuchInterface
 }
index 7e4bc56faca8e8c7733c83b8cade44443bf9a423..3fd9dce05e47d2741a2c9880bae2dea56e9b411f 100644 (file)
@@ -84,8 +84,8 @@ func splitNetProto(netProto string) (net string, proto int, err error) {
        return
 }
 
-// DialIP connects to the remote address raddr on the network net,
-// which must be "ip", "ip4", or "ip6".
+// DialIP connects to the remote address raddr on the network protocol netProto,
+// which must be "ip", "ip4", or "ip6" followed by a colon and a protocol number or name.
 func DialIP(netProto string, laddr, raddr *IPAddr) (c *IPConn, err error) {
        return nil, os.EPLAN9
 }
index 7bb4c7dc0d3d4d4a8ed86d35afa2f7cb14bb325b..103c4f6a92555de9a45caeac69e7feb8f673476a 100644 (file)
@@ -224,8 +224,8 @@ func splitNetProto(netProto string) (net string, proto int, err error) {
        return net, proto, nil
 }
 
-// DialIP connects to the remote address raddr on the network net,
-// which must be "ip", "ip4", or "ip6".
+// DialIP connects to the remote address raddr on the network protocol netProto,
+// which must be "ip", "ip4", or "ip6" followed by a colon and a protocol number or name.
 func DialIP(netProto string, laddr, raddr *IPAddr) (c *IPConn, err error) {
        net, proto, err := splitNetProto(netProto)
        if err != nil {
@@ -260,7 +260,7 @@ func ListenIP(netProto string, laddr *IPAddr) (c *IPConn, err error) {
        default:
                return nil, UnknownNetworkError(net)
        }
-       fd, e := internetSocket(net, laddr.toAddr(), nil, syscall.SOCK_RAW, proto, "dial", sockaddrToIP)
+       fd, e := internetSocket(net, laddr.toAddr(), nil, syscall.SOCK_RAW, proto, "listen", sockaddrToIP)
        if e != nil {
                return nil, e
        }
index d141c050b23042b4e8ae5611401966fd3e6636a7..45fe0d9640b224b749c56af091ce0068911b0732 100644 (file)
@@ -91,11 +91,6 @@ func favoriteAddrFamily(net string, raddr, laddr sockaddr, mode string) int {
        return syscall.AF_INET6
 }
 
-// TODO(rsc): if syscall.OS == "linux", we're supposed to read
-// /proc/sys/net/core/somaxconn,
-// to take advantage of kernels that have raised the limit.
-func listenBacklog() int { return syscall.SOMAXCONN }
-
 // Internet sockets (TCP, UDP)
 
 // A sockaddr represents a TCP or UDP network address that can
index a66250c844b43fe6763449b6016e8192c89d98ef..183d5a8abaa837c9a136525098442614709896ad 100644 (file)
@@ -13,7 +13,7 @@ import (
 
 var multicast = flag.Bool("multicast", false, "enable multicast tests")
 
-var joinAndLeaveGroupUDPTests = []struct {
+var multicastUDPTests = []struct {
        net   string
        laddr IP
        gaddr IP
@@ -32,8 +32,8 @@ var joinAndLeaveGroupUDPTests = []struct {
        {"udp6", IPv6unspecified, ParseIP("ff0e::114"), (FlagUp | FlagLoopback), true},
 }
 
-func TestJoinAndLeaveGroupUDP(t *testing.T) {
-       if runtime.GOOS == "windows" {
+func TestMulticastUDP(t *testing.T) {
+       if runtime.GOOS == "plan9" || runtime.GOOS == "windows" {
                return
        }
        if !*multicast {
@@ -41,7 +41,7 @@ func TestJoinAndLeaveGroupUDP(t *testing.T) {
                return
        }
 
-       for _, tt := range joinAndLeaveGroupUDPTests {
+       for _, tt := range multicastUDPTests {
                var (
                        ifi   *Interface
                        found bool
@@ -51,7 +51,7 @@ func TestJoinAndLeaveGroupUDP(t *testing.T) {
                }
                ift, err := Interfaces()
                if err != nil {
-                       t.Fatalf("Interfaces() failed: %v", err)
+                       t.Fatalf("Interfaces failed: %v", err)
                }
                for _, x := range ift {
                        if x.Flags&tt.flags == tt.flags {
@@ -65,15 +65,20 @@ func TestJoinAndLeaveGroupUDP(t *testing.T) {
                }
                c, err := ListenUDP(tt.net, &UDPAddr{IP: tt.laddr})
                if err != nil {
-                       t.Fatal(err)
+                       t.Fatalf("ListenUDP failed: %v", err)
                }
                defer c.Close()
                if err := c.JoinGroup(ifi, tt.gaddr); err != nil {
-                       t.Fatal(err)
+                       t.Fatalf("JoinGroup failed: %v", err)
+               }
+               if !tt.ipv6 {
+                       testIPv4MulticastSocketOptions(t, c.fd, ifi)
+               } else {
+                       testIPv6MulticastSocketOptions(t, c.fd, ifi)
                }
                ifmat, err := ifi.MulticastAddrs()
                if err != nil {
-                       t.Fatalf("MulticastAddrs() failed: %v", err)
+                       t.Fatalf("MulticastAddrs failed: %v", err)
                }
                for _, ifma := range ifmat {
                        if ifma.(*IPAddr).IP.Equal(tt.gaddr) {
@@ -85,7 +90,114 @@ func TestJoinAndLeaveGroupUDP(t *testing.T) {
                        t.Fatalf("%q not found in RIB", tt.gaddr.String())
                }
                if err := c.LeaveGroup(ifi, tt.gaddr); err != nil {
-                       t.Fatal(err)
+                       t.Fatalf("LeaveGroup failed: %v", err)
+               }
+       }
+}
+
+func TestSimpleMulticastUDP(t *testing.T) {
+       if runtime.GOOS == "plan9" {
+               return
+       }
+       if !*multicast {
+               t.Logf("test disabled; use --multicast to enable")
+               return
+       }
+
+       for _, tt := range multicastUDPTests {
+               var ifi *Interface
+               if tt.ipv6 {
+                       continue
+               }
+               tt.flags = FlagUp | FlagMulticast
+               ift, err := Interfaces()
+               if err != nil {
+                       t.Fatalf("Interfaces failed: %v", err)
+               }
+               for _, x := range ift {
+                       if x.Flags&tt.flags == tt.flags {
+                               ifi = &x
+                               break
+                       }
+               }
+               if ifi == nil {
+                       t.Logf("an appropriate multicast interface not found")
+                       return
+               }
+               c, err := ListenUDP(tt.net, &UDPAddr{IP: tt.laddr})
+               if err != nil {
+                       t.Fatalf("ListenUDP failed: %v", err)
+               }
+               defer c.Close()
+               if err := c.JoinGroup(ifi, tt.gaddr); err != nil {
+                       t.Fatalf("JoinGroup failed: %v", err)
+               }
+               if err := c.LeaveGroup(ifi, tt.gaddr); err != nil {
+                       t.Fatalf("LeaveGroup failed: %v", err)
                }
        }
 }
+
+func testIPv4MulticastSocketOptions(t *testing.T, fd *netFD, ifi *Interface) {
+       ifmc, err := ipv4MulticastInterface(fd)
+       if err != nil {
+               t.Fatalf("ipv4MulticastInterface failed: %v", err)
+       }
+       t.Logf("IPv4 multicast interface: %v", ifmc)
+       err = setIPv4MulticastInterface(fd, ifi)
+       if err != nil {
+               t.Fatalf("setIPv4MulticastInterface failed: %v", err)
+       }
+
+       ttl, err := ipv4MulticastTTL(fd)
+       if err != nil {
+               t.Fatalf("ipv4MulticastTTL failed: %v", err)
+       }
+       t.Logf("IPv4 multicast TTL: %v", ttl)
+       err = setIPv4MulticastTTL(fd, 1)
+       if err != nil {
+               t.Fatalf("setIPv4MulticastTTL failed: %v", err)
+       }
+
+       loop, err := ipv4MulticastLoopback(fd)
+       if err != nil {
+               t.Fatalf("ipv4MulticastLoopback failed: %v", err)
+       }
+       t.Logf("IPv4 multicast loopback: %v", loop)
+       err = setIPv4MulticastLoopback(fd, false)
+       if err != nil {
+               t.Fatalf("setIPv4MulticastLoopback failed: %v", err)
+       }
+}
+
+func testIPv6MulticastSocketOptions(t *testing.T, fd *netFD, ifi *Interface) {
+       ifmc, err := ipv6MulticastInterface(fd)
+       if err != nil {
+               t.Fatalf("ipv6MulticastInterface failed: %v", err)
+       }
+       t.Logf("IPv6 multicast interface: %v", ifmc)
+       err = setIPv6MulticastInterface(fd, ifi)
+       if err != nil {
+               t.Fatalf("setIPv6MulticastInterface failed: %v", err)
+       }
+
+       hoplim, err := ipv6MulticastHopLimit(fd)
+       if err != nil {
+               t.Fatalf("ipv6MulticastHopLimit failed: %v", err)
+       }
+       t.Logf("IPv6 multicast hop limit: %v", hoplim)
+       err = setIPv6MulticastHopLimit(fd, 1)
+       if err != nil {
+               t.Fatalf("setIPv6MulticastHopLimit failed: %v", err)
+       }
+
+       loop, err := ipv6MulticastLoopback(fd)
+       if err != nil {
+               t.Fatalf("ipv6MulticastLoopback failed: %v", err)
+       }
+       t.Logf("IPv6 multicast loopback: %v", loop)
+       err = setIPv6MulticastLoopback(fd, false)
+       if err != nil {
+               t.Fatalf("setIPv6MulticastLoopback failed: %v", err)
+       }
+}
index c1845fa50737bb2f2a72321230b60f20fd58334a..ae688c0f8ca82e0ff3d76d1e60d61b6b525d09c7 100644 (file)
@@ -498,8 +498,7 @@ func benchmarkEndToEnd(dial func() (*Client, error), b *testing.B) {
        once.Do(startServer)
        client, err := dial()
        if err != nil {
-               fmt.Println("error dialing", err)
-               return
+               b.Fatal("error dialing:", err)
        }
 
        // Synchronous calls
@@ -534,7 +533,7 @@ func benchmarkEndToEndAsync(dial func() (*Client, error), b *testing.B) {
        once.Do(startServer)
        client, err := dial()
        if err != nil {
-               b.Fatalf("error dialing:", err)
+               b.Fatal("error dialing:", err)
        }
 
        // Asynchronous calls
index 7d17ccd53c37b660337770d6fe4b431e234cf6ce..5475d3874fe6f1e8fe12c29f4d41e68bc150d428 100644 (file)
@@ -10,7 +10,6 @@ import (
        "os"
        "runtime"
        "strings"
-       "syscall"
        "testing"
 )
 
@@ -115,7 +114,7 @@ func doTest(t *testing.T, network, listenaddr, dialaddr string) {
 }
 
 func TestTCPServer(t *testing.T) {
-       if syscall.OS != "openbsd" {
+       if runtime.GOOS != "openbsd" {
                doTest(t, "tcp", "", "127.0.0.1")
        }
        doTest(t, "tcp", "0.0.0.0", "127.0.0.1")
@@ -155,7 +154,7 @@ func TestUnixServer(t *testing.T) {
        os.Remove("/tmp/gotest.net")
        doTest(t, "unix", "/tmp/gotest.net", "/tmp/gotest.net")
        os.Remove("/tmp/gotest.net")
-       if syscall.OS == "linux" {
+       if runtime.GOOS == "linux" {
                doTest(t, "unixpacket", "/tmp/gotest.net", "/tmp/gotest.net")
                os.Remove("/tmp/gotest.net")
                // Test abstract unix domain socket, a Linux-ism
@@ -237,7 +236,7 @@ func TestUnixDatagramServer(t *testing.T) {
                doTestPacket(t, "unixgram", "/tmp/gotest1.net", "/tmp/gotest1.net", isEmpty)
                os.Remove("/tmp/gotest1.net")
                os.Remove("/tmp/gotest1.net.local")
-               if syscall.OS == "linux" {
+               if runtime.GOOS == "linux" {
                        // Test abstract unix domain socket, a Linux-ism
                        doTestPacket(t, "unixgram", "@gotest1/net", "@gotest1/net", isEmpty)
                }
index dc073927eb4211ee5b80d2fb690fb4b966181fe6..881c922a25f31f3cb7c3ca2ca1f526fa3f713ef0 100644 (file)
@@ -10,18 +10,11 @@ package net
 
 import (
        "io"
-       "os"
        "reflect"
        "syscall"
 )
 
-// Boolean to int.
-func boolint(b bool) int {
-       if b {
-               return 1
-       }
-       return 0
-}
+var listenerBacklog = maxListenerBacklog()
 
 // Generic socket creation.
 func socket(net string, f, p, t int, la, ra syscall.Sockaddr, toAddr func(syscall.Sockaddr) Addr) (fd *netFD, err error) {
@@ -35,7 +28,7 @@ func socket(net string, f, p, t int, la, ra syscall.Sockaddr, toAddr func(syscal
        syscall.CloseOnExec(s)
        syscall.ForkLock.RUnlock()
 
-       setKernelSpecificSockopt(s, f)
+       setDefaultSockopts(s, f, p)
 
        if la != nil {
                e = syscall.Bind(s, la)
@@ -67,83 +60,6 @@ func socket(net string, f, p, t int, la, ra syscall.Sockaddr, toAddr func(syscal
        return fd, nil
 }
 
-func setsockoptInt(fd *netFD, level, opt int, value int) error {
-       return os.NewSyscallError("setsockopt", syscall.SetsockoptInt(fd.sysfd, level, opt, value))
-}
-
-func setsockoptNsec(fd *netFD, level, opt int, nsec int64) error {
-       var tv = syscall.NsecToTimeval(nsec)
-       return os.NewSyscallError("setsockopt", syscall.SetsockoptTimeval(fd.sysfd, level, opt, &tv))
-}
-
-func setReadBuffer(fd *netFD, bytes int) error {
-       fd.incref()
-       defer fd.decref()
-       return setsockoptInt(fd, syscall.SOL_SOCKET, syscall.SO_RCVBUF, bytes)
-}
-
-func setWriteBuffer(fd *netFD, bytes int) error {
-       fd.incref()
-       defer fd.decref()
-       return setsockoptInt(fd, syscall.SOL_SOCKET, syscall.SO_SNDBUF, bytes)
-}
-
-func setReadTimeout(fd *netFD, nsec int64) error {
-       fd.rdeadline_delta = nsec
-       return nil
-}
-
-func setWriteTimeout(fd *netFD, nsec int64) error {
-       fd.wdeadline_delta = nsec
-       return nil
-}
-
-func setTimeout(fd *netFD, nsec int64) error {
-       if e := setReadTimeout(fd, nsec); e != nil {
-               return e
-       }
-       return setWriteTimeout(fd, nsec)
-}
-
-func setReuseAddr(fd *netFD, reuse bool) error {
-       fd.incref()
-       defer fd.decref()
-       return setsockoptInt(fd, syscall.SOL_SOCKET, syscall.SO_REUSEADDR, boolint(reuse))
-}
-
-func setDontRoute(fd *netFD, dontroute bool) error {
-       fd.incref()
-       defer fd.decref()
-       return setsockoptInt(fd, syscall.SOL_SOCKET, syscall.SO_DONTROUTE, boolint(dontroute))
-}
-
-func setKeepAlive(fd *netFD, keepalive bool) error {
-       fd.incref()
-       defer fd.decref()
-       return setsockoptInt(fd, syscall.SOL_SOCKET, syscall.SO_KEEPALIVE, boolint(keepalive))
-}
-
-func setNoDelay(fd *netFD, noDelay bool) error {
-       fd.incref()
-       defer fd.decref()
-       return setsockoptInt(fd, syscall.IPPROTO_TCP, syscall.TCP_NODELAY, boolint(noDelay))
-}
-
-func setLinger(fd *netFD, sec int) error {
-       var l syscall.Linger
-       if sec >= 0 {
-               l.Onoff = 1
-               l.Linger = int32(sec)
-       } else {
-               l.Onoff = 0
-               l.Linger = 0
-       }
-       fd.incref()
-       defer fd.decref()
-       e := syscall.SetsockoptLinger(fd.sysfd, syscall.SOL_SOCKET, syscall.SO_LINGER, &l)
-       return os.NewSyscallError("setsockopt", e)
-}
-
 type UnknownSocketError struct {
        sa syscall.Sockaddr
 }
index 816e4fc3f741cfa97d4ab34905dff9f60712d5de..630a91ed9f670d06f35658264b2d248b2f3c9848 100644 (file)
@@ -1,4 +1,4 @@
-// Copyright 2011 The Go Authors.  All rights reserved.
+// Copyright 2009 The Go Authors.  All rights reserved.
 // Use of this source code is governed by a BSD-style
 // license that can be found in the LICENSE file.
 
@@ -9,25 +9,25 @@
 package net
 
 import (
+       "runtime"
        "syscall"
 )
 
-func setKernelSpecificSockopt(s, f int) {
-       // Allow reuse of recently-used addresses.
-       syscall.SetsockoptInt(s, syscall.SOL_SOCKET, syscall.SO_REUSEADDR, 1)
-
-       // Allow reuse of recently-used ports.
-       // This option is supported only in descendants of 4.4BSD,
-       // to make an effective multicast application and an application
-       // that requires quick draw possible.
-       syscall.SetsockoptInt(s, syscall.SOL_SOCKET, syscall.SO_REUSEPORT, 1)
-
-       // Allow broadcast.
-       syscall.SetsockoptInt(s, syscall.SOL_SOCKET, syscall.SO_BROADCAST, 1)
-
-       if f == syscall.AF_INET6 {
-               // using ip, tcp, udp, etc.
-               // allow both protocols even if the OS default is otherwise.
-               syscall.SetsockoptInt(s, syscall.IPPROTO_IPV6, syscall.IPV6_V6ONLY, 0)
+func maxListenerBacklog() int {
+       var (
+               n   uint32
+               err error
+       )
+       switch runtime.GOOS {
+       case "darwin", "freebsd":
+               n, err = syscall.SysctlUint32("kern.ipc.somaxconn")
+       case "netbsd":
+               // NOTE: NetBSD has no somaxconn-like kernel state so far
+       case "openbsd":
+               n, err = syscall.SysctlUint32("kern.somaxconn")
+       }
+       if n == 0 || err != nil {
+               return syscall.SOMAXCONN
        }
+       return int(n)
 }
index ec31e803b6f29b9b2ac589d09b37b128d657a983..2cbc34f24b38384482642ac59e397cd542d5463f 100644 (file)
@@ -1,4 +1,4 @@
-// Copyright 2011 The Go Authors.  All rights reserved.
+// Copyright 2009 The Go Authors.  All rights reserved.
 // Use of this source code is governed by a BSD-style
 // license that can be found in the LICENSE file.
 
@@ -6,20 +6,22 @@
 
 package net
 
-import (
-       "syscall"
-)
+import "syscall"
 
-func setKernelSpecificSockopt(s, f int) {
-       // Allow reuse of recently-used addresses.
-       syscall.SetsockoptInt(s, syscall.SOL_SOCKET, syscall.SO_REUSEADDR, 1)
-
-       // Allow broadcast.
-       syscall.SetsockoptInt(s, syscall.SOL_SOCKET, syscall.SO_BROADCAST, 1)
-
-       if f == syscall.AF_INET6 {
-               // using ip, tcp, udp, etc.
-               // allow both protocols even if the OS default is otherwise.
-               syscall.SetsockoptInt(s, syscall.IPPROTO_IPV6, syscall.IPV6_V6ONLY, 0)
+func maxListenerBacklog() int {
+       fd, err := open("/proc/sys/net/core/somaxconn")
+       if err != nil {
+               return syscall.SOMAXCONN
+       }
+       defer fd.close()
+       l, ok := fd.readLine()
+       if !ok {
+               return syscall.SOMAXCONN
+       }
+       f := getFields(l)
+       n, _, ok := dtoi(f[0], 0)
+       if n == 0 || !ok {
+               return syscall.SOMAXCONN
        }
+       return n
 }
index 9b9cd9e368bf5dc3dd3d897d1433c81e082ee949..2d803de1fc180da2e93adaa87adfaec4aa93cb11 100644 (file)
@@ -1,4 +1,4 @@
-// Copyright 2011 The Go Authors.  All rights reserved.
+// Copyright 2009 The Go Authors.  All rights reserved.
 // Use of this source code is governed by a BSD-style
 // license that can be found in the LICENSE file.
 
@@ -6,24 +6,9 @@
 
 package net
 
-import (
-       "syscall"
-)
+import "syscall"
 
-func setKernelSpecificSockopt(s syscall.Handle, f int) {
-       // Windows will reuse recently-used addresses by default.
-       // SO_REUSEADDR should not be used here, as it allows
-       // a socket to forcibly bind to a port in use by another socket.
-       // This could lead to a non-deterministic behavior, where
-       // connection requests over the port cannot be guaranteed
-       // to be handled by the correct socket.
-
-       // Allow broadcast.
-       syscall.SetsockoptInt(s, syscall.SOL_SOCKET, syscall.SO_BROADCAST, 1)
-
-       if f == syscall.AF_INET6 {
-               // using ip, tcp, udp, etc.
-               // allow both protocols even if the OS default is otherwise.
-               syscall.SetsockoptInt(s, syscall.IPPROTO_IPV6, syscall.IPV6_V6ONLY, 0)
-       }
+func maxListenerBacklog() int {
+       // TODO: Implement this
+       return syscall.SOMAXCONN
 }
diff --git a/libgo/go/net/sockopt.go b/libgo/go/net/sockopt.go
new file mode 100644 (file)
index 0000000..7fa1052
--- /dev/null
@@ -0,0 +1,171 @@
+// Copyright 2009 The Go Authors.  All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+// +build darwin freebsd linux netbsd openbsd windows
+
+// Socket options
+
+package net
+
+import (
+       "bytes"
+       "os"
+       "syscall"
+)
+
+// Boolean to int.
+func boolint(b bool) int {
+       if b {
+               return 1
+       }
+       return 0
+}
+
+func ipv4AddrToInterface(ip IP) (*Interface, error) {
+       ift, err := Interfaces()
+       if err != nil {
+               return nil, err
+       }
+       for _, ifi := range ift {
+               ifat, err := ifi.Addrs()
+               if err != nil {
+                       return nil, err
+               }
+               for _, ifa := range ifat {
+                       switch v := ifa.(type) {
+                       case *IPAddr:
+                               if ip.Equal(v.IP) {
+                                       return &ifi, nil
+                               }
+                       case *IPNet:
+                               if ip.Equal(v.IP) {
+                                       return &ifi, nil
+                               }
+                       }
+               }
+       }
+       if ip.Equal(IPv4zero) {
+               return nil, nil
+       }
+       return nil, errNoSuchInterface
+}
+
+func interfaceToIPv4Addr(ifi *Interface) (IP, error) {
+       if ifi == nil {
+               return IPv4zero, nil
+       }
+       ifat, err := ifi.Addrs()
+       if err != nil {
+               return nil, err
+       }
+       for _, ifa := range ifat {
+               switch v := ifa.(type) {
+               case *IPAddr:
+                       if v.IP.To4() != nil {
+                               return v.IP, nil
+                       }
+               case *IPNet:
+                       if v.IP.To4() != nil {
+                               return v.IP, nil
+                       }
+               }
+       }
+       return nil, errNoSuchInterface
+}
+
+func setIPv4MreqToInterface(mreq *syscall.IPMreq, ifi *Interface) error {
+       if ifi == nil {
+               return nil
+       }
+       ifat, err := ifi.Addrs()
+       if err != nil {
+               return err
+       }
+       for _, ifa := range ifat {
+               switch v := ifa.(type) {
+               case *IPAddr:
+                       if a := v.IP.To4(); a != nil {
+                               copy(mreq.Interface[:], a)
+                               goto done
+                       }
+               case *IPNet:
+                       if a := v.IP.To4(); a != nil {
+                               copy(mreq.Interface[:], a)
+                               goto done
+                       }
+               }
+       }
+done:
+       if bytes.Equal(mreq.Multiaddr[:], IPv4zero.To4()) {
+               return errNoSuchMulticastInterface
+       }
+       return nil
+}
+
+func setReadBuffer(fd *netFD, bytes int) error {
+       fd.incref()
+       defer fd.decref()
+       return os.NewSyscallError("setsockopt", syscall.SetsockoptInt(fd.sysfd, syscall.SOL_SOCKET, syscall.SO_RCVBUF, bytes))
+}
+
+func setWriteBuffer(fd *netFD, bytes int) error {
+       fd.incref()
+       defer fd.decref()
+       return os.NewSyscallError("setsockopt", syscall.SetsockoptInt(fd.sysfd, syscall.SOL_SOCKET, syscall.SO_SNDBUF, bytes))
+}
+
+func setReadTimeout(fd *netFD, nsec int64) error {
+       fd.rdeadline_delta = nsec
+       return nil
+}
+
+func setWriteTimeout(fd *netFD, nsec int64) error {
+       fd.wdeadline_delta = nsec
+       return nil
+}
+
+func setTimeout(fd *netFD, nsec int64) error {
+       if e := setReadTimeout(fd, nsec); e != nil {
+               return e
+       }
+       return setWriteTimeout(fd, nsec)
+}
+
+func setReuseAddr(fd *netFD, reuse bool) error {
+       fd.incref()
+       defer fd.decref()
+       return os.NewSyscallError("setsockopt", syscall.SetsockoptInt(fd.sysfd, syscall.SOL_SOCKET, syscall.SO_REUSEADDR, boolint(reuse)))
+}
+
+func setDontRoute(fd *netFD, dontroute bool) error {
+       fd.incref()
+       defer fd.decref()
+       return os.NewSyscallError("setsockopt", syscall.SetsockoptInt(fd.sysfd, syscall.SOL_SOCKET, syscall.SO_DONTROUTE, boolint(dontroute)))
+}
+
+func setKeepAlive(fd *netFD, keepalive bool) error {
+       fd.incref()
+       defer fd.decref()
+       return os.NewSyscallError("setsockopt", syscall.SetsockoptInt(fd.sysfd, syscall.SOL_SOCKET, syscall.SO_KEEPALIVE, boolint(keepalive)))
+}
+
+func setNoDelay(fd *netFD, noDelay bool) error {
+       fd.incref()
+       defer fd.decref()
+       return os.NewSyscallError("setsockopt", syscall.SetsockoptInt(fd.sysfd, syscall.IPPROTO_TCP, syscall.TCP_NODELAY, boolint(noDelay)))
+}
+
+func setLinger(fd *netFD, sec int) error {
+       var l syscall.Linger
+       if sec >= 0 {
+               l.Onoff = 1
+               l.Linger = int32(sec)
+       } else {
+               l.Onoff = 0
+               l.Linger = 0
+       }
+       fd.incref()
+       defer fd.decref()
+       return os.NewSyscallError("setsockopt", syscall.SetsockoptLinger(fd.sysfd, syscall.SOL_SOCKET, syscall.SO_LINGER, &l))
+}
diff --git a/libgo/go/net/sockopt_bsd.go b/libgo/go/net/sockopt_bsd.go
new file mode 100644 (file)
index 0000000..e99fb41
--- /dev/null
@@ -0,0 +1,44 @@
+// Copyright 2011 The Go Authors.  All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+// +build darwin freebsd netbsd openbsd
+
+// Socket options for BSD variants
+
+package net
+
+import (
+       "syscall"
+)
+
+func setDefaultSockopts(s, f, p int) {
+       switch f {
+       case syscall.AF_INET6:
+               // Allow both IP versions even if the OS default is otherwise.
+               syscall.SetsockoptInt(s, syscall.IPPROTO_IPV6, syscall.IPV6_V6ONLY, 0)
+       }
+
+       if f == syscall.AF_UNIX || p == syscall.IPPROTO_TCP {
+               // Allow reuse of recently-used addresses.
+               syscall.SetsockoptInt(s, syscall.SOL_SOCKET, syscall.SO_REUSEADDR, 1)
+
+               // Allow reuse of recently-used ports.
+               // This option is supported only in descendants of 4.4BSD,
+               // to make an effective multicast application and an application
+               // that requires quick draw possible.
+               syscall.SetsockoptInt(s, syscall.SOL_SOCKET, syscall.SO_REUSEPORT, 1)
+       }
+
+       // Allow broadcast.
+       syscall.SetsockoptInt(s, syscall.SOL_SOCKET, syscall.SO_BROADCAST, 1)
+}
+
+func setDefaultMulticastSockopts(fd *netFD) {
+       fd.incref()
+       defer fd.decref()
+       // Allow multicast UDP and raw IP datagram sockets to listen
+       // concurrently across multiple listeners.
+       syscall.SetsockoptInt(fd.sysfd, syscall.SOL_SOCKET, syscall.SO_REUSEADDR, 1)
+       syscall.SetsockoptInt(fd.sysfd, syscall.SOL_SOCKET, syscall.SO_REUSEPORT, 1)
+}
diff --git a/libgo/go/net/sockopt_linux.go b/libgo/go/net/sockopt_linux.go
new file mode 100644 (file)
index 0000000..5158384
--- /dev/null
@@ -0,0 +1,36 @@
+// Copyright 2011 The Go Authors.  All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+// Socket options for Linux
+
+package net
+
+import (
+       "syscall"
+)
+
+func setDefaultSockopts(s, f, p int) {
+       switch f {
+       case syscall.AF_INET6:
+               // Allow both IP versions even if the OS default is otherwise.
+               syscall.SetsockoptInt(s, syscall.IPPROTO_IPV6, syscall.IPV6_V6ONLY, 0)
+       }
+
+       if f == syscall.AF_UNIX || p == syscall.IPPROTO_TCP {
+               // Allow reuse of recently-used addresses.
+               syscall.SetsockoptInt(s, syscall.SOL_SOCKET, syscall.SO_REUSEADDR, 1)
+       }
+
+       // Allow broadcast.
+       syscall.SetsockoptInt(s, syscall.SOL_SOCKET, syscall.SO_BROADCAST, 1)
+
+}
+
+func setDefaultMulticastSockopts(fd *netFD) {
+       fd.incref()
+       defer fd.decref()
+       // Allow multicast UDP and raw IP datagram sockets to listen
+       // concurrently across multiple listeners.
+       syscall.SetsockoptInt(fd.sysfd, syscall.SOL_SOCKET, syscall.SO_REUSEADDR, 1)
+}
diff --git a/libgo/go/net/sockopt_windows.go b/libgo/go/net/sockopt_windows.go
new file mode 100644 (file)
index 0000000..485c14a
--- /dev/null
@@ -0,0 +1,38 @@
+// Copyright 2011 The Go Authors.  All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+// Socket options for Windows
+
+package net
+
+import (
+       "syscall"
+)
+
+func setDefaultSockopts(s syscall.Handle, f, p int) {
+       switch f {
+       case syscall.AF_INET6:
+               // Allow both IP versions even if the OS default is otherwise.
+               syscall.SetsockoptInt(s, syscall.IPPROTO_IPV6, syscall.IPV6_V6ONLY, 0)
+       }
+
+       // Windows will reuse recently-used addresses by default.
+       // SO_REUSEADDR should not be used here, as it allows
+       // a socket to forcibly bind to a port in use by another socket.
+       // This could lead to a non-deterministic behavior, where
+       // connection requests over the port cannot be guaranteed
+       // to be handled by the correct socket.
+
+       // Allow broadcast.
+       syscall.SetsockoptInt(s, syscall.SOL_SOCKET, syscall.SO_BROADCAST, 1)
+
+}
+
+func setDefaultMulticastSockopts(fd *netFD) {
+       fd.incref()
+       defer fd.decref()
+       // Allow multicast UDP and raw IP datagram sockets to listen
+       // concurrently across multiple listeners.
+       syscall.SetsockoptInt(fd.sysfd, syscall.SOL_SOCKET, syscall.SO_REUSEADDR, 1)
+}
diff --git a/libgo/go/net/sockoptip.go b/libgo/go/net/sockoptip.go
new file mode 100644 (file)
index 0000000..90b6f75
--- /dev/null
@@ -0,0 +1,187 @@
+// Copyright 2011 The Go Authors.  All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+// +build darwin freebsd linux netbsd openbsd windows
+
+// IP-level socket options
+
+package net
+
+import (
+       "os"
+       "syscall"
+)
+
+func ipv4TOS(fd *netFD) (int, error) {
+       fd.incref()
+       defer fd.decref()
+       v, err := syscall.GetsockoptInt(fd.sysfd, syscall.IPPROTO_IP, syscall.IP_TOS)
+       if err != nil {
+               return -1, os.NewSyscallError("getsockopt", err)
+       }
+       return v, nil
+}
+
+func setIPv4TOS(fd *netFD, v int) error {
+       fd.incref()
+       defer fd.decref()
+       err := syscall.SetsockoptInt(fd.sysfd, syscall.IPPROTO_IP, syscall.IP_TOS, v)
+       if err != nil {
+               return os.NewSyscallError("setsockopt", err)
+       }
+       return nil
+}
+
+func ipv4TTL(fd *netFD) (int, error) {
+       fd.incref()
+       defer fd.decref()
+       v, err := syscall.GetsockoptInt(fd.sysfd, syscall.IPPROTO_IP, syscall.IP_TTL)
+       if err != nil {
+               return -1, os.NewSyscallError("getsockopt", err)
+       }
+       return v, nil
+}
+
+func setIPv4TTL(fd *netFD, v int) error {
+       fd.incref()
+       defer fd.decref()
+       err := syscall.SetsockoptInt(fd.sysfd, syscall.IPPROTO_IP, syscall.IP_TTL, v)
+       if err != nil {
+               return os.NewSyscallError("setsockopt", err)
+       }
+       return nil
+}
+
+func joinIPv4Group(fd *netFD, ifi *Interface, ip IP) error {
+       mreq := &syscall.IPMreq{Multiaddr: [4]byte{ip[0], ip[1], ip[2], ip[3]}}
+       if err := setIPv4MreqToInterface(mreq, ifi); err != nil {
+               return err
+       }
+       fd.incref()
+       defer fd.decref()
+       return os.NewSyscallError("setsockopt", syscall.SetsockoptIPMreq(fd.sysfd, syscall.IPPROTO_IP, syscall.IP_ADD_MEMBERSHIP, mreq))
+}
+
+func leaveIPv4Group(fd *netFD, ifi *Interface, ip IP) error {
+       mreq := &syscall.IPMreq{Multiaddr: [4]byte{ip[0], ip[1], ip[2], ip[3]}}
+       if err := setIPv4MreqToInterface(mreq, ifi); err != nil {
+               return err
+       }
+       fd.incref()
+       defer fd.decref()
+       return os.NewSyscallError("setsockopt", syscall.SetsockoptIPMreq(fd.sysfd, syscall.IPPROTO_IP, syscall.IP_DROP_MEMBERSHIP, mreq))
+}
+
+func ipv6HopLimit(fd *netFD) (int, error) {
+       fd.incref()
+       defer fd.decref()
+       v, err := syscall.GetsockoptInt(fd.sysfd, syscall.IPPROTO_IPV6, syscall.IPV6_UNICAST_HOPS)
+       if err != nil {
+               return -1, os.NewSyscallError("getsockopt", err)
+       }
+       return v, nil
+}
+
+func setIPv6HopLimit(fd *netFD, v int) error {
+       fd.incref()
+       defer fd.decref()
+       err := syscall.SetsockoptInt(fd.sysfd, syscall.IPPROTO_IPV6, syscall.IPV6_UNICAST_HOPS, v)
+       if err != nil {
+               return os.NewSyscallError("setsockopt", err)
+       }
+       return nil
+}
+
+func ipv6MulticastInterface(fd *netFD) (*Interface, error) {
+       fd.incref()
+       defer fd.decref()
+       v, err := syscall.GetsockoptInt(fd.sysfd, syscall.IPPROTO_IPV6, syscall.IPV6_MULTICAST_IF)
+       if err != nil {
+               return nil, os.NewSyscallError("getsockopt", err)
+       }
+       if v == 0 {
+               return nil, nil
+       }
+       ifi, err := InterfaceByIndex(v)
+       if err != nil {
+               return nil, err
+       }
+       return ifi, nil
+}
+
+func setIPv6MulticastInterface(fd *netFD, ifi *Interface) error {
+       var v int
+       if ifi != nil {
+               v = ifi.Index
+       }
+       fd.incref()
+       defer fd.decref()
+       err := syscall.SetsockoptInt(fd.sysfd, syscall.IPPROTO_IPV6, syscall.IPV6_MULTICAST_IF, v)
+       if err != nil {
+               return os.NewSyscallError("setsockopt", err)
+       }
+       return nil
+}
+
+func ipv6MulticastHopLimit(fd *netFD) (int, error) {
+       fd.incref()
+       defer fd.decref()
+       v, err := syscall.GetsockoptInt(fd.sysfd, syscall.IPPROTO_IPV6, syscall.IPV6_MULTICAST_HOPS)
+       if err != nil {
+               return -1, os.NewSyscallError("getsockopt", err)
+       }
+       return v, nil
+}
+
+func setIPv6MulticastHopLimit(fd *netFD, v int) error {
+       fd.incref()
+       defer fd.decref()
+       err := syscall.SetsockoptInt(fd.sysfd, syscall.IPPROTO_IPV6, syscall.IPV6_MULTICAST_HOPS, v)
+       if err != nil {
+               return os.NewSyscallError("setsockopt", err)
+       }
+       return nil
+}
+
+func ipv6MulticastLoopback(fd *netFD) (bool, error) {
+       fd.incref()
+       defer fd.decref()
+       v, err := syscall.GetsockoptInt(fd.sysfd, syscall.IPPROTO_IPV6, syscall.IPV6_MULTICAST_LOOP)
+       if err != nil {
+               return false, os.NewSyscallError("getsockopt", err)
+       }
+       return v == 1, nil
+}
+
+func setIPv6MulticastLoopback(fd *netFD, v bool) error {
+       fd.incref()
+       defer fd.decref()
+       err := syscall.SetsockoptInt(fd.sysfd, syscall.IPPROTO_IPV6, syscall.IPV6_MULTICAST_LOOP, boolint(v))
+       if err != nil {
+               return os.NewSyscallError("setsockopt", err)
+       }
+       return nil
+}
+
+func joinIPv6Group(fd *netFD, ifi *Interface, ip IP) error {
+       mreq := &syscall.IPv6Mreq{}
+       copy(mreq.Multiaddr[:], ip)
+       if ifi != nil {
+               mreq.Interface = uint32(ifi.Index)
+       }
+       fd.incref()
+       defer fd.decref()
+       return os.NewSyscallError("setsockopt", syscall.SetsockoptIPv6Mreq(fd.sysfd, syscall.IPPROTO_IPV6, syscall.IPV6_JOIN_GROUP, mreq))
+}
+
+func leaveIPv6Group(fd *netFD, ifi *Interface, ip IP) error {
+       mreq := &syscall.IPv6Mreq{}
+       copy(mreq.Multiaddr[:], ip)
+       if ifi != nil {
+               mreq.Interface = uint32(ifi.Index)
+       }
+       fd.incref()
+       defer fd.decref()
+       return os.NewSyscallError("setsockopt", syscall.SetsockoptIPv6Mreq(fd.sysfd, syscall.IPPROTO_IPV6, syscall.IPV6_LEAVE_GROUP, mreq))
+}
diff --git a/libgo/go/net/sockoptip_bsd.go b/libgo/go/net/sockoptip_bsd.go
new file mode 100644 (file)
index 0000000..5f7dff2
--- /dev/null
@@ -0,0 +1,54 @@
+// Copyright 2011 The Go Authors.  All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+// +build darwin freebsd netbsd openbsd
+
+// IP-level socket options for BSD variants
+
+package net
+
+import (
+       "os"
+       "syscall"
+)
+
+func ipv4MulticastTTL(fd *netFD) (int, error) {
+       fd.incref()
+       defer fd.decref()
+       v, err := syscall.GetsockoptByte(fd.sysfd, syscall.IPPROTO_IP, syscall.IP_MULTICAST_TTL)
+       if err != nil {
+               return -1, os.NewSyscallError("getsockopt", err)
+       }
+       return int(v), nil
+}
+
+func setIPv4MulticastTTL(fd *netFD, v int) error {
+       fd.incref()
+       defer fd.decref()
+       err := syscall.SetsockoptByte(fd.sysfd, syscall.IPPROTO_IP, syscall.IP_MULTICAST_TTL, byte(v))
+       if err != nil {
+               return os.NewSyscallError("setsockopt", err)
+       }
+       return nil
+}
+
+func ipv6TrafficClass(fd *netFD) (int, error) {
+       fd.incref()
+       defer fd.decref()
+       v, err := syscall.GetsockoptInt(fd.sysfd, syscall.IPPROTO_IPV6, syscall.IPV6_TCLASS)
+       if err != nil {
+               return -1, os.NewSyscallError("getsockopt", err)
+       }
+       return v, nil
+}
+
+func setIPv6TrafficClass(fd *netFD, v int) error {
+       fd.incref()
+       defer fd.decref()
+       err := syscall.SetsockoptInt(fd.sysfd, syscall.IPPROTO_IPV6, syscall.IPV6_TCLASS, v)
+       if err != nil {
+               return os.NewSyscallError("setsockopt", err)
+       }
+       return nil
+}
diff --git a/libgo/go/net/sockoptip_darwin.go b/libgo/go/net/sockoptip_darwin.go
new file mode 100644 (file)
index 0000000..dedfd6f
--- /dev/null
@@ -0,0 +1,78 @@
+// Copyright 2011 The Go Authors.  All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+// IP-level socket options for Darwin
+
+package net
+
+import (
+       "os"
+       "syscall"
+)
+
+func ipv4MulticastInterface(fd *netFD) (*Interface, error) {
+       fd.incref()
+       defer fd.decref()
+       a, err := syscall.GetsockoptInet4Addr(fd.sysfd, syscall.IPPROTO_IP, syscall.IP_MULTICAST_IF)
+       if err != nil {
+               return nil, os.NewSyscallError("getsockopt", err)
+       }
+       return ipv4AddrToInterface(IPv4(a[0], a[1], a[2], a[3]))
+}
+
+func setIPv4MulticastInterface(fd *netFD, ifi *Interface) error {
+       ip, err := interfaceToIPv4Addr(ifi)
+       if err != nil {
+               return os.NewSyscallError("setsockopt", err)
+       }
+       var x [4]byte
+       copy(x[:], ip.To4())
+       fd.incref()
+       defer fd.decref()
+       err = syscall.SetsockoptInet4Addr(fd.sysfd, syscall.IPPROTO_IP, syscall.IP_MULTICAST_IF, x)
+       if err != nil {
+               return os.NewSyscallError("setsockopt", err)
+       }
+       return nil
+}
+
+func ipv4MulticastLoopback(fd *netFD) (bool, error) {
+       fd.incref()
+       defer fd.decref()
+       v, err := syscall.GetsockoptInt(fd.sysfd, syscall.IPPROTO_IP, syscall.IP_MULTICAST_LOOP)
+       if err != nil {
+               return false, os.NewSyscallError("getsockopt", err)
+       }
+       return v == 1, nil
+}
+
+func setIPv4MulticastLoopback(fd *netFD, v bool) error {
+       fd.incref()
+       defer fd.decref()
+       err := syscall.SetsockoptInt(fd.sysfd, syscall.IPPROTO_IP, syscall.IP_MULTICAST_LOOP, boolint(v))
+       if err != nil {
+               return os.NewSyscallError("setsockopt", err)
+       }
+       return nil
+}
+
+func ipv4ReceiveInterface(fd *netFD) (bool, error) {
+       fd.incref()
+       defer fd.decref()
+       v, err := syscall.GetsockoptInt(fd.sysfd, syscall.IPPROTO_IP, syscall.IP_RECVIF)
+       if err != nil {
+               return false, os.NewSyscallError("getsockopt", err)
+       }
+       return v == 1, nil
+}
+
+func setIPv4ReceiveInterface(fd *netFD, v bool) error {
+       fd.incref()
+       defer fd.decref()
+       err := syscall.SetsockoptInt(fd.sysfd, syscall.IPPROTO_IP, syscall.IP_RECVIF, boolint(v))
+       if err != nil {
+               return os.NewSyscallError("setsockopt", err)
+       }
+       return nil
+}
diff --git a/libgo/go/net/sockoptip_freebsd.go b/libgo/go/net/sockoptip_freebsd.go
new file mode 100644 (file)
index 0000000..55f7b1a
--- /dev/null
@@ -0,0 +1,80 @@
+// Copyright 2011 The Go Authors.  All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+// IP-level socket options for FreeBSD
+
+package net
+
+import (
+       "os"
+       "syscall"
+)
+
+func ipv4MulticastInterface(fd *netFD) (*Interface, error) {
+       fd.incref()
+       defer fd.decref()
+       mreq, err := syscall.GetsockoptIPMreqn(fd.sysfd, syscall.IPPROTO_IP, syscall.IP_MULTICAST_IF)
+       if err != nil {
+               return nil, os.NewSyscallError("getsockopt", err)
+       }
+       if int(mreq.Ifindex) == 0 {
+               return nil, nil
+       }
+       return InterfaceByIndex(int(mreq.Ifindex))
+}
+
+func setIPv4MulticastInterface(fd *netFD, ifi *Interface) error {
+       var v int32
+       if ifi != nil {
+               v = int32(ifi.Index)
+       }
+       mreq := &syscall.IPMreqn{Ifindex: v}
+       fd.incref()
+       defer fd.decref()
+       err := syscall.SetsockoptIPMreqn(fd.sysfd, syscall.IPPROTO_IP, syscall.IP_MULTICAST_IF, mreq)
+       if err != nil {
+               return os.NewSyscallError("setsockopt", err)
+       }
+       return nil
+}
+
+func ipv4MulticastLoopback(fd *netFD) (bool, error) {
+       fd.incref()
+       defer fd.decref()
+       v, err := syscall.GetsockoptInt(fd.sysfd, syscall.IPPROTO_IP, syscall.IP_MULTICAST_LOOP)
+       if err != nil {
+               return false, os.NewSyscallError("getsockopt", err)
+       }
+       return v == 1, nil
+}
+
+func setIPv4MulticastLoopback(fd *netFD, v bool) error {
+       fd.incref()
+       defer fd.decref()
+       err := syscall.SetsockoptInt(fd.sysfd, syscall.IPPROTO_IP, syscall.IP_MULTICAST_LOOP, boolint(v))
+       if err != nil {
+               return os.NewSyscallError("setsockopt", err)
+       }
+       return nil
+}
+
+func ipv4ReceiveInterface(fd *netFD) (bool, error) {
+       fd.incref()
+       defer fd.decref()
+       v, err := syscall.GetsockoptInt(fd.sysfd, syscall.IPPROTO_IP, syscall.IP_RECVIF)
+       if err != nil {
+               return false, os.NewSyscallError("getsockopt", err)
+       }
+       return v == 1, nil
+}
+
+func setIPv4ReceiveInterface(fd *netFD, v bool) error {
+       fd.incref()
+       defer fd.decref()
+       err := syscall.SetsockoptInt(fd.sysfd, syscall.IPPROTO_IP, syscall.IP_RECVIF, boolint(v))
+       if err != nil {
+               return os.NewSyscallError("setsockopt", err)
+       }
+       return nil
+}
diff --git a/libgo/go/net/sockoptip_linux.go b/libgo/go/net/sockoptip_linux.go
new file mode 100644 (file)
index 0000000..360f8de
--- /dev/null
@@ -0,0 +1,120 @@
+// Copyright 2011 The Go Authors.  All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+// IP-level socket options for Linux
+
+package net
+
+import (
+       "os"
+       "syscall"
+)
+
+func ipv4MulticastInterface(fd *netFD) (*Interface, error) {
+       fd.incref()
+       defer fd.decref()
+       mreq, err := syscall.GetsockoptIPMreqn(fd.sysfd, syscall.IPPROTO_IP, syscall.IP_MULTICAST_IF)
+       if err != nil {
+               return nil, os.NewSyscallError("getsockopt", err)
+       }
+       if int(mreq.Ifindex) == 0 {
+               return nil, nil
+       }
+       return InterfaceByIndex(int(mreq.Ifindex))
+}
+
+func setIPv4MulticastInterface(fd *netFD, ifi *Interface) error {
+       var v int32
+       if ifi != nil {
+               v = int32(ifi.Index)
+       }
+       mreq := &syscall.IPMreqn{Ifindex: v}
+       fd.incref()
+       defer fd.decref()
+       err := syscall.SetsockoptIPMreqn(fd.sysfd, syscall.IPPROTO_IP, syscall.IP_MULTICAST_IF, mreq)
+       if err != nil {
+               return os.NewSyscallError("setsockopt", err)
+       }
+       return nil
+}
+
+func ipv4MulticastTTL(fd *netFD) (int, error) {
+       fd.incref()
+       defer fd.decref()
+       v, err := syscall.GetsockoptInt(fd.sysfd, syscall.IPPROTO_IP, syscall.IP_MULTICAST_TTL)
+       if err != nil {
+               return -1, os.NewSyscallError("getsockopt", err)
+       }
+       return v, nil
+}
+
+func setIPv4MulticastTTL(fd *netFD, v int) error {
+       fd.incref()
+       defer fd.decref()
+       err := syscall.SetsockoptInt(fd.sysfd, syscall.IPPROTO_IP, syscall.IP_MULTICAST_TTL, v)
+       if err != nil {
+               return os.NewSyscallError("setsockopt", err)
+       }
+       return nil
+}
+
+func ipv4MulticastLoopback(fd *netFD) (bool, error) {
+       fd.incref()
+       defer fd.decref()
+       v, err := syscall.GetsockoptInt(fd.sysfd, syscall.IPPROTO_IP, syscall.IP_MULTICAST_LOOP)
+       if err != nil {
+               return false, os.NewSyscallError("getsockopt", err)
+       }
+       return v == 1, nil
+}
+
+func setIPv4MulticastLoopback(fd *netFD, v bool) error {
+       fd.incref()
+       defer fd.decref()
+       err := syscall.SetsockoptInt(fd.sysfd, syscall.IPPROTO_IP, syscall.IP_MULTICAST_LOOP, boolint(v))
+       if err != nil {
+               return os.NewSyscallError("setsockopt", err)
+       }
+       return nil
+}
+
+func ipv4ReceiveInterface(fd *netFD) (bool, error) {
+       fd.incref()
+       defer fd.decref()
+       v, err := syscall.GetsockoptInt(fd.sysfd, syscall.IPPROTO_IP, syscall.IP_PKTINFO)
+       if err != nil {
+               return false, os.NewSyscallError("getsockopt", err)
+       }
+       return v == 1, nil
+}
+
+func setIPv4ReceiveInterface(fd *netFD, v bool) error {
+       fd.incref()
+       defer fd.decref()
+       err := syscall.SetsockoptInt(fd.sysfd, syscall.IPPROTO_IP, syscall.IP_PKTINFO, boolint(v))
+       if err != nil {
+               return os.NewSyscallError("setsockopt", err)
+       }
+       return nil
+}
+
+func ipv6TrafficClass(fd *netFD) (int, error) {
+       fd.incref()
+       defer fd.decref()
+       v, err := syscall.GetsockoptInt(fd.sysfd, syscall.IPPROTO_IPV6, syscall.IPV6_TCLASS)
+       if err != nil {
+               return -1, os.NewSyscallError("getsockopt", err)
+       }
+       return v, nil
+}
+
+func setIPv6TrafficClass(fd *netFD, v int) error {
+       fd.incref()
+       defer fd.decref()
+       err := syscall.SetsockoptInt(fd.sysfd, syscall.IPPROTO_IPV6, syscall.IPV6_TCLASS, v)
+       if err != nil {
+               return os.NewSyscallError("setsockopt", err)
+       }
+       return nil
+}
diff --git a/libgo/go/net/sockoptip_openbsd.go b/libgo/go/net/sockoptip_openbsd.go
new file mode 100644 (file)
index 0000000..89b8e45
--- /dev/null
@@ -0,0 +1,78 @@
+// Copyright 2011 The Go Authors.  All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+// IP-level socket options for OpenBSD
+
+package net
+
+import (
+       "os"
+       "syscall"
+)
+
+func ipv4MulticastInterface(fd *netFD) (*Interface, error) {
+       fd.incref()
+       defer fd.decref()
+       a, err := syscall.GetsockoptInet4Addr(fd.sysfd, syscall.IPPROTO_IP, syscall.IP_MULTICAST_IF)
+       if err != nil {
+               return nil, os.NewSyscallError("getsockopt", err)
+       }
+       return ipv4AddrToInterface(IPv4(a[0], a[1], a[2], a[3]))
+}
+
+func setIPv4MulticastInterface(fd *netFD, ifi *Interface) error {
+       ip, err := interfaceToIPv4Addr(ifi)
+       if err != nil {
+               return os.NewSyscallError("setsockopt", err)
+       }
+       var x [4]byte
+       copy(x[:], ip.To4())
+       fd.incref()
+       defer fd.decref()
+       err = syscall.SetsockoptInet4Addr(fd.sysfd, syscall.IPPROTO_IP, syscall.IP_MULTICAST_IF, x)
+       if err != nil {
+               return os.NewSyscallError("setsockopt", err)
+       }
+       return nil
+}
+
+func ipv4MulticastLoopback(fd *netFD) (bool, error) {
+       fd.incref()
+       defer fd.decref()
+       v, err := syscall.GetsockoptByte(fd.sysfd, syscall.IPPROTO_IP, syscall.IP_MULTICAST_LOOP)
+       if err != nil {
+               return false, os.NewSyscallError("getsockopt", err)
+       }
+       return v == 1, nil
+}
+
+func setIPv4MulticastLoopback(fd *netFD, v bool) error {
+       fd.incref()
+       defer fd.decref()
+       err := syscall.SetsockoptByte(fd.sysfd, syscall.IPPROTO_IP, syscall.IP_MULTICAST_LOOP, byte(boolint(v)))
+       if err != nil {
+               return os.NewSyscallError("setsockopt", err)
+       }
+       return nil
+}
+
+func ipv4ReceiveInterface(fd *netFD) (bool, error) {
+       fd.incref()
+       defer fd.decref()
+       v, err := syscall.GetsockoptInt(fd.sysfd, syscall.IPPROTO_IP, syscall.IP_RECVIF)
+       if err != nil {
+               return false, os.NewSyscallError("getsockopt", err)
+       }
+       return v == 1, nil
+}
+
+func setIPv4ReceiveInterface(fd *netFD, v bool) error {
+       fd.incref()
+       defer fd.decref()
+       err := syscall.SetsockoptInt(fd.sysfd, syscall.IPPROTO_IP, syscall.IP_RECVIF, boolint(v))
+       if err != nil {
+               return os.NewSyscallError("setsockopt", err)
+       }
+       return nil
+}
diff --git a/libgo/go/net/sockoptip_windows.go b/libgo/go/net/sockoptip_windows.go
new file mode 100644 (file)
index 0000000..3320e76
--- /dev/null
@@ -0,0 +1,61 @@
+// Copyright 2011 The Go Authors.  All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+// IP-level socket options for Windows
+
+package net
+
+import (
+       "syscall"
+)
+
+func ipv4MulticastInterface(fd *netFD) (*Interface, error) {
+       // TODO: Implement this
+       return nil, syscall.EWINDOWS
+}
+
+func setIPv4MulticastInterface(fd *netFD, ifi *Interface) error {
+       // TODO: Implement this
+       return syscall.EWINDOWS
+}
+
+func ipv4MulticastTTL(fd *netFD) (int, error) {
+       // TODO: Implement this
+       return -1, syscall.EWINDOWS
+}
+
+func setIPv4MulticastTTL(fd *netFD, v int) error {
+       // TODO: Implement this
+       return syscall.EWINDOWS
+}
+
+func ipv4MulticastLoopback(fd *netFD) (bool, error) {
+       // TODO: Implement this
+       return false, syscall.EWINDOWS
+}
+
+func setIPv4MulticastLoopback(fd *netFD, v bool) error {
+       // TODO: Implement this
+       return syscall.EWINDOWS
+}
+
+func ipv4ReceiveInterface(fd *netFD) (bool, error) {
+       // TODO: Implement this
+       return false, syscall.EWINDOWS
+}
+
+func setIPv4ReceiveInterface(fd *netFD, v bool) error {
+       // TODO: Implement this
+       return syscall.EWINDOWS
+}
+
+func ipv6TrafficClass(fd *netFD) (int, error) {
+       // TODO: Implement this
+       return 0, syscall.EWINDOWS
+}
+
+func setIPv6TrafficClass(fd *netFD, v int) error {
+       // TODO: Implement this
+       return syscall.EWINDOWS
+}
index a7c09c73ed5bd89ac7f49160a4f9a3834dfd545e..a492e614e35419ba95aca0e1cec32e2f8b876481 100644 (file)
@@ -249,10 +249,10 @@ func ListenTCP(net string, laddr *TCPAddr) (l *TCPListener, err error) {
        if err != nil {
                return nil, err
        }
-       errno := syscall.Listen(fd.sysfd, listenBacklog())
-       if errno != nil {
+       err = syscall.Listen(fd.sysfd, listenerBacklog)
+       if err != nil {
                closesocket(fd.sysfd)
-               return nil, &OpError{"listen", "tcp", laddr, errno}
+               return nil, &OpError{"listen", "tcp", laddr, err}
        }
        l = new(TCPListener)
        l.fd = fd
index 793c6c2c83e5fb9b8ef1ec3bb7a6b72b533711bf..862cd536c467e237ec6f56d56195ea5fa255444c 100644 (file)
@@ -22,6 +22,7 @@ import (
 type Reader struct {
        R   *bufio.Reader
        dot *dotReader
+       buf []byte // a re-usable buffer for readContinuedLineSlice
 }
 
 // NewReader returns a new Reader reading from r.
@@ -121,74 +122,44 @@ func (r *Reader) readContinuedLineSlice() ([]byte, error) {
        // Read the first line.
        line, err := r.readLineSlice()
        if err != nil {
-               return line, err
+               return nil, err
        }
        if len(line) == 0 { // blank line - no continuation
                return line, nil
        }
-       line = trim(line)
 
-       copied := false
-       if r.R.Buffered() < 1 {
-               // ReadByte will flush the buffer; make a copy of the slice.
-               copied = true
-               line = append([]byte(nil), line...)
-       }
-
-       // Look for a continuation line.
-       c, err := r.R.ReadByte()
-       if err != nil {
-               // Delay err until we read the byte next time.
-               return line, nil
-       }
-       if c != ' ' && c != '\t' {
-               // Not a continuation.
-               r.R.UnreadByte()
-               return line, nil
-       }
-
-       if !copied {
-               // The next readLineSlice will invalidate the previous one.
-               line = append(make([]byte, 0, len(line)*2), line...)
-       }
+       // ReadByte or the next readLineSlice will flush the read buffer;
+       // copy the slice into buf.
+       r.buf = append(r.buf[:0], trim(line)...)
 
        // Read continuation lines.
-       for {
-               // Consume leading spaces; one already gone.
-               for {
-                       c, err = r.R.ReadByte()
-                       if err != nil {
-                               break
-                       }
-                       if c != ' ' && c != '\t' {
-                               r.R.UnreadByte()
-                               break
-                       }
-               }
-               var cont []byte
-               cont, err = r.readLineSlice()
-               cont = trim(cont)
-               line = append(line, ' ')
-               line = append(line, cont...)
+       for r.skipSpace() > 0 {
+               line, err := r.readLineSlice()
                if err != nil {
                        break
                }
+               r.buf = append(r.buf, ' ')
+               r.buf = append(r.buf, line...)
+       }
+       return r.buf, nil
+}
 
-               // Check for leading space on next line.
-               if c, err = r.R.ReadByte(); err != nil {
+// skipSpace skips R over all spaces and returns the number of bytes skipped.
+func (r *Reader) skipSpace() int {
+       n := 0
+       for {
+               c, err := r.R.ReadByte()
+               if err != nil {
+                       // Bufio will keep err until next read.
                        break
                }
                if c != ' ' && c != '\t' {
                        r.R.UnreadByte()
                        break
                }
+               n++
        }
-
-       // Delay error until next call.
-       if len(line) > 0 {
-               err = nil
-       }
-       return line, err
+       return n
 }
 
 func (r *Reader) readCodeLine(expectCode int) (code int, continued bool, message string, err error) {
index 0460c1c8deeb5557c07524ebed9a39f1ee9e7dd3..4d036914801f9b8fb30fe51fa16be01fead1dbaa 100644 (file)
@@ -138,6 +138,15 @@ func TestReadMIMEHeader(t *testing.T) {
        }
 }
 
+func TestReadMIMEHeaderSingle(t *testing.T) {
+       r := reader("Foo: bar\n\n")
+       m, err := r.ReadMIMEHeader()
+       want := MIMEHeader{"Foo": {"bar"}}
+       if !reflect.DeepEqual(m, want) || err != nil {
+               t.Fatalf("ReadMIMEHeader: %v, %v; want %v", m, err, want)
+       }
+}
+
 func TestLargeReadMIMEHeader(t *testing.T) {
        data := make([]byte, 16*1024)
        for i := 0; i < len(data); i++ {
index 6bb15714e2b0efdc5f0ef224c4875140c4838d34..d0bdb14755e4348772444095a8a3b6f079a7155c 100644 (file)
@@ -9,7 +9,6 @@
 package net
 
 import (
-       "bytes"
        "os"
        "syscall"
 )
@@ -233,7 +232,7 @@ func ListenUDP(net string, laddr *UDPAddr) (c *UDPConn, err error) {
        if laddr == nil {
                return nil, &OpError{"listen", "udp", nil, errMissingAddress}
        }
-       fd, e := internetSocket(net, laddr.toAddr(), nil, syscall.SOCK_DGRAM, 0, "dial", sockaddrToUDP)
+       fd, e := internetSocket(net, laddr.toAddr(), nil, syscall.SOCK_DGRAM, 0, "listen", sockaddrToUDP)
        if e != nil {
                return nil, e
        }
@@ -252,6 +251,7 @@ func (c *UDPConn) JoinGroup(ifi *Interface, addr IP) error {
        if !c.ok() {
                return os.EINVAL
        }
+       setDefaultMulticastSockopts(c.fd)
        ip := addr.To4()
        if ip != nil {
                return joinIPv4GroupUDP(c, ifi, ip)
@@ -272,66 +272,32 @@ func (c *UDPConn) LeaveGroup(ifi *Interface, addr IP) error {
 }
 
 func joinIPv4GroupUDP(c *UDPConn, ifi *Interface, ip IP) error {
-       mreq := &syscall.IPMreq{Multiaddr: [4]byte{ip[0], ip[1], ip[2], ip[3]}}
-       if err := setIPv4InterfaceToJoin(mreq, ifi); err != nil {
-               return &OpError{"joinipv4group", "udp", &IPAddr{ip}, err}
-       }
-       if err := os.NewSyscallError("setsockopt", syscall.SetsockoptIPMreq(c.fd.sysfd, syscall.IPPROTO_IP, syscall.IP_ADD_MEMBERSHIP, mreq)); err != nil {
+       err := joinIPv4Group(c.fd, ifi, ip)
+       if err != nil {
                return &OpError{"joinipv4group", "udp", &IPAddr{ip}, err}
        }
        return nil
 }
 
 func leaveIPv4GroupUDP(c *UDPConn, ifi *Interface, ip IP) error {
-       mreq := &syscall.IPMreq{Multiaddr: [4]byte{ip[0], ip[1], ip[2], ip[3]}}
-       if err := setIPv4InterfaceToJoin(mreq, ifi); err != nil {
-               return &OpError{"leaveipv4group", "udp", &IPAddr{ip}, err}
-       }
-       if err := os.NewSyscallError("setsockopt", syscall.SetsockoptIPMreq(c.fd.sysfd, syscall.IPPROTO_IP, syscall.IP_DROP_MEMBERSHIP, mreq)); err != nil {
-               return &OpError{"leaveipv4group", "udp", &IPAddr{ip}, err}
-       }
-       return nil
-}
-
-func setIPv4InterfaceToJoin(mreq *syscall.IPMreq, ifi *Interface) error {
-       if ifi == nil {
-               return nil
-       }
-       ifat, err := ifi.Addrs()
+       err := leaveIPv4Group(c.fd, ifi, ip)
        if err != nil {
-               return err
-       }
-       for _, ifa := range ifat {
-               if x := ifa.(*IPAddr).IP.To4(); x != nil {
-                       copy(mreq.Interface[:], x)
-                       break
-               }
-       }
-       if bytes.Equal(mreq.Multiaddr[:], IPv4zero) {
-               return os.EINVAL
+               return &OpError{"leaveipv4group", "udp", &IPAddr{ip}, err}
        }
        return nil
 }
 
 func joinIPv6GroupUDP(c *UDPConn, ifi *Interface, ip IP) error {
-       mreq := &syscall.IPv6Mreq{}
-       copy(mreq.Multiaddr[:], ip)
-       if ifi != nil {
-               mreq.Interface = uint32(ifi.Index)
-       }
-       if err := os.NewSyscallError("setsockopt", syscall.SetsockoptIPv6Mreq(c.fd.sysfd, syscall.IPPROTO_IPV6, syscall.IPV6_JOIN_GROUP, mreq)); err != nil {
+       err := joinIPv6Group(c.fd, ifi, ip)
+       if err != nil {
                return &OpError{"joinipv6group", "udp", &IPAddr{ip}, err}
        }
        return nil
 }
 
 func leaveIPv6GroupUDP(c *UDPConn, ifi *Interface, ip IP) error {
-       mreq := &syscall.IPv6Mreq{}
-       copy(mreq.Multiaddr[:], ip)
-       if ifi != nil {
-               mreq.Interface = uint32(ifi.Index)
-       }
-       if err := os.NewSyscallError("setsockopt", syscall.SetsockoptIPv6Mreq(c.fd.sysfd, syscall.IPPROTO_IPV6, syscall.IPV6_LEAVE_GROUP, mreq)); err != nil {
+       err := leaveIPv6Group(c.fd, ifi, ip)
+       if err != nil {
                return &OpError{"leaveipv6group", "udp", &IPAddr{ip}, err}
        }
        return nil
diff --git a/libgo/go/net/unicast_test.go b/libgo/go/net/unicast_test.go
new file mode 100644 (file)
index 0000000..6ed6f59
--- /dev/null
@@ -0,0 +1,99 @@
+// Copyright 2011 The Go Authors.  All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+package net
+
+import (
+       "runtime"
+       "testing"
+)
+
+var unicastTests = []struct {
+       net    string
+       laddr  string
+       ipv6   bool
+       packet bool
+}{
+       {"tcp4", "127.0.0.1:0", false, false},
+       {"tcp6", "[::1]:0", true, false},
+       {"udp4", "127.0.0.1:0", false, true},
+       {"udp6", "[::1]:0", true, true},
+}
+
+func TestUnicastTCPAndUDP(t *testing.T) {
+       if runtime.GOOS == "plan9" || runtime.GOOS == "windows" {
+               return
+       }
+
+       for _, tt := range unicastTests {
+               if tt.ipv6 && !supportsIPv6 {
+                       continue
+               }
+               var fd *netFD
+               if !tt.packet {
+                       c, err := Listen(tt.net, tt.laddr)
+                       if err != nil {
+                               t.Fatalf("Listen failed: %v", err)
+                       }
+                       defer c.Close()
+                       fd = c.(*TCPListener).fd
+               } else {
+                       c, err := ListenPacket(tt.net, tt.laddr)
+                       if err != nil {
+                               t.Fatalf("ListenPacket failed: %v", err)
+                       }
+                       defer c.Close()
+                       fd = c.(*UDPConn).fd
+               }
+               if !tt.ipv6 {
+                       testIPv4UnicastSocketOptions(t, fd)
+               } else {
+                       testIPv6UnicastSocketOptions(t, fd)
+               }
+       }
+}
+
+func testIPv4UnicastSocketOptions(t *testing.T, fd *netFD) {
+       tos, err := ipv4TOS(fd)
+       if err != nil {
+               t.Fatalf("ipv4TOS failed: %v", err)
+       }
+       t.Logf("IPv4 TOS: %v", tos)
+       err = setIPv4TOS(fd, 1)
+       if err != nil {
+               t.Fatalf("setIPv4TOS failed: %v", err)
+       }
+
+       ttl, err := ipv4TTL(fd)
+       if err != nil {
+               t.Fatalf("ipv4TTL failed: %v", err)
+       }
+       t.Logf("IPv4 TTL: %v", ttl)
+       err = setIPv4TTL(fd, 1)
+       if err != nil {
+               t.Fatalf("setIPv4TTL failed: %v", err)
+       }
+}
+
+func testIPv6UnicastSocketOptions(t *testing.T, fd *netFD) {
+       tos, err := ipv6TrafficClass(fd)
+       if err != nil {
+               t.Fatalf("ipv6TrafficClass failed: %v", err)
+       }
+       t.Logf("IPv6 TrafficClass: %v", tos)
+       err = setIPv6TrafficClass(fd, 1)
+       if err != nil {
+               t.Fatalf("setIPv6TrafficClass failed: %v", err)
+       }
+
+       hoplim, err := ipv6HopLimit(fd)
+       if err != nil {
+               t.Fatalf("ipv6HopLimit failed: %v", err)
+       }
+       t.Logf("IPv6 HopLimit: %v", hoplim)
+       err = setIPv6HopLimit(fd, 1)
+       if err != nil {
+               t.Fatalf("setIPv6HopLimit failed: %v", err)
+       }
+}
index 10632c1412e16558bebd92fad91b2580a307e06e..00ee0164f2e64ad14511b258fcddd124f226db2b 100644 (file)
@@ -315,7 +315,7 @@ type UnixListener struct {
 
 // ListenUnix announces on the Unix domain socket laddr and returns a Unix listener.
 // Net must be "unix" (stream sockets).
-func ListenUnix(net string, laddr *UnixAddr) (l *UnixListener, err error) {
+func ListenUnix(net string, laddr *UnixAddr) (*UnixListener, error) {
        if net != "unix" && net != "unixgram" && net != "unixpacket" {
                return nil, UnknownNetworkError(net)
        }
@@ -326,10 +326,10 @@ func ListenUnix(net string, laddr *UnixAddr) (l *UnixListener, err error) {
        if err != nil {
                return nil, err
        }
-       e1 := syscall.Listen(fd.sysfd, 8) // listenBacklog());
-       if e1 != nil {
+       err = syscall.Listen(fd.sysfd, listenerBacklog)
+       if err != nil {
                closesocket(fd.sysfd)
-               return nil, &OpError{Op: "listen", Net: "unix", Addr: laddr, Err: e1}
+               return nil, &OpError{Op: "listen", Net: "unix", Addr: laddr, Err: err}
        }
        return &UnixListener{fd, laddr.Name}, nil
 }
index 04ff390727cf9f09608b4a2ef158cd4a8049c1c8..991fa4d057802a26750d0a43b52652f9b5140ec6 100644 (file)
@@ -6,6 +6,7 @@ package os_test
 
 import (
        . "os"
+       "reflect"
        "testing"
 )
 
@@ -57,3 +58,13 @@ func TestExpand(t *testing.T) {
                }
        }
 }
+
+func TestConsistentEnviron(t *testing.T) {
+       e0 := Environ()
+       for i := 0; i < 10; i++ {
+               e1 := Environ()
+               if !reflect.DeepEqual(e0, e1) {
+                       t.Fatalf("environment changed")
+               }
+       }
+}
index 9a6099080341e4e874fda6c2941fb186abb1f8bd..59f2bb061503f547d85f80704b128f69a8d388cd 100644 (file)
@@ -11,8 +11,8 @@ import (
        "io/ioutil"
        . "os"
        "path/filepath"
+       "runtime"
        "strings"
-       "syscall"
        "testing"
        "time"
 )
@@ -33,7 +33,7 @@ type sysDir struct {
 }
 
 var sysdir = func() (sd *sysDir) {
-       switch syscall.OS {
+       switch runtime.GOOS {
        case "windows":
                sd = &sysDir{
                        Getenv("SystemRoot") + "\\system32\\drivers\\etc",
@@ -87,7 +87,7 @@ func size(name string, t *testing.T) int64 {
 }
 
 func equal(name1, name2 string) (r bool) {
-       switch syscall.OS {
+       switch runtime.GOOS {
        case "windows":
                r = strings.ToLower(name1) == strings.ToLower(name2)
        default:
@@ -101,7 +101,7 @@ func newFile(testName string, t *testing.T) (f *File) {
        // On Unix, override $TMPDIR in case the user
        // has it set to an NFS-mounted directory.
        dir := ""
-       if syscall.OS != "windows" {
+       if runtime.GOOS != "windows" {
                dir = "/tmp"
        }
        f, err := ioutil.TempFile(dir, "_Go_"+testName)
@@ -276,7 +276,7 @@ func smallReaddirnames(file *File, length int, t *testing.T) []string {
 func TestReaddirnamesOneAtATime(t *testing.T) {
        // big directory that doesn't change often.
        dir := "/usr/bin"
-       switch syscall.OS {
+       switch runtime.GOOS {
        case "windows":
                dir = Getenv("SystemRoot") + "\\system32"
        case "plan9":
@@ -380,7 +380,7 @@ func TestReaddirNValues(t *testing.T) {
 
 func TestHardLink(t *testing.T) {
        // Hardlinks are not supported under windows or Plan 9.
-       if syscall.OS == "windows" || syscall.OS == "plan9" {
+       if runtime.GOOS == "windows" || runtime.GOOS == "plan9" {
                return
        }
        from, to := "hardlinktestfrom", "hardlinktestto"
@@ -413,7 +413,7 @@ func TestHardLink(t *testing.T) {
 
 func TestSymLink(t *testing.T) {
        // Symlinks are not supported under windows or Plan 9.
-       if syscall.OS == "windows" || syscall.OS == "plan9" {
+       if runtime.GOOS == "windows" || runtime.GOOS == "plan9" {
                return
        }
        from, to := "symlinktestfrom", "symlinktestto"
@@ -475,7 +475,7 @@ func TestSymLink(t *testing.T) {
 
 func TestLongSymlink(t *testing.T) {
        // Symlinks are not supported under windows or Plan 9.
-       if syscall.OS == "windows" || syscall.OS == "plan9" {
+       if runtime.GOOS == "windows" || runtime.GOOS == "plan9" {
                return
        }
        s := "0123456789abcdef"
@@ -545,7 +545,7 @@ func exec(t *testing.T, dir, cmd string, args []string, expect string) {
 func TestStartProcess(t *testing.T) {
        var dir, cmd, le string
        var args []string
-       if syscall.OS == "windows" {
+       if runtime.GOOS == "windows" {
                le = "\r\n"
                cmd = Getenv("COMSPEC")
                dir = Getenv("SystemRoot")
@@ -576,7 +576,7 @@ func checkMode(t *testing.T, path string, mode FileMode) {
 
 func TestChmod(t *testing.T) {
        // Chmod is not supported under windows.
-       if syscall.OS == "windows" {
+       if runtime.GOOS == "windows" {
                return
        }
        f := newFile("TestChmod", t)
@@ -678,7 +678,7 @@ func TestChtimes(t *testing.T) {
        */
        pat := Atime(postStat)
        pmt := postStat.ModTime()
-       if !pat.Before(at) && syscall.OS != "plan9" {
+       if !pat.Before(at) && runtime.GOOS != "plan9" {
                t.Errorf("AccessTime didn't go backwards; was=%d, after=%d", at, pat)
        }
 
@@ -689,7 +689,7 @@ func TestChtimes(t *testing.T) {
 
 func TestChdirAndGetwd(t *testing.T) {
        // TODO(brainman): file.Chdir() is not implemented on windows.
-       if syscall.OS == "windows" {
+       if runtime.GOOS == "windows" {
                return
        }
        fd, err := Open(".")
@@ -700,7 +700,7 @@ func TestChdirAndGetwd(t *testing.T) {
        // (unlike, say, /var, /etc, and /tmp).
        dirs := []string{"/", "/usr/bin"}
        // /usr/bin does not usually exist on Plan 9.
-       if syscall.OS == "plan9" {
+       if runtime.GOOS == "plan9" {
                dirs = []string{"/", "/usr"}
        }
        for mode := 0; mode < 2; mode++ {
@@ -828,7 +828,7 @@ func TestOpenError(t *testing.T) {
                        t.Errorf("Open(%q, %d) returns error of %T type; want *PathError", tt.path, tt.mode, err)
                }
                if perr.Err != tt.error {
-                       if syscall.OS == "plan9" {
+                       if runtime.GOOS == "plan9" {
                                syscallErrStr := perr.Err.Error()
                                expectedErrStr := strings.Replace(tt.error.Error(), "file ", "", 1)
                                if !strings.HasSuffix(syscallErrStr, expectedErrStr) {
@@ -886,7 +886,7 @@ func run(t *testing.T, cmd []string) string {
 func TestHostname(t *testing.T) {
        // There is no other way to fetch hostname on windows, but via winapi.
        // On Plan 9 it is can be taken from #c/sysname as Hostname() does.
-       if syscall.OS == "windows" || syscall.OS == "plan9" {
+       if runtime.GOOS == "windows" || runtime.GOOS == "plan9" {
                return
        }
 
index 1f800d78ccaee956ec3ece4b5338b18d67db17b0..1bdcd748bc00991179f02e3beb2a7ec596af4b5e 100644 (file)
@@ -8,6 +8,7 @@ package os_test
 
 import (
        . "os"
+       "runtime"
        "syscall"
        "testing"
 )
@@ -29,7 +30,7 @@ func checkUidGid(t *testing.T, path string, uid, gid int) {
 func TestChown(t *testing.T) {
        // Chown is not supported under windows or Plan 9.
        // Plan9 provides a native ChownPlan9 version instead.
-       if syscall.OS == "windows" || syscall.OS == "plan9" {
+       if runtime.GOOS == "windows" || runtime.GOOS == "plan9" {
                return
        }
        // Use TempDir() to make sure we're on a local file system,
index 89d66c29ef92f53ef1cafca22e9d8106c7a46a41..18634ba410ede3d23f8f4707181e6b5a99eadb1e 100644 (file)
@@ -8,7 +8,6 @@ import (
        . "os"
        "path/filepath"
        "runtime"
-       "syscall"
        "testing"
 )
 
@@ -63,7 +62,7 @@ func TestMkdirAll(t *testing.T) {
                t.Fatalf("MkdirAll %q returned wrong error path: %q not %q", ffpath, filepath.Clean(perr.Path), filepath.Clean(fpath))
        }
 
-       if syscall.OS == "windows" {
+       if runtime.GOOS == "windows" {
                path := `_test\_TestMkdirAll_\dir\.\dir2\`
                err := MkdirAll(path, 0777)
                if err != nil {
@@ -117,7 +116,7 @@ func TestRemoveAll(t *testing.T) {
 
        // Determine if we should run the following test.
        testit := true
-       if syscall.OS == "windows" {
+       if runtime.GOOS == "windows" {
                // Chmod is not supported under windows.
                testit = false
        } else {
index 66189a6b9baaa323ae807eb5904c70a5d45b1d07..b0a569e24cd1c85a9e5c48f67982f2b77af2e876 100644 (file)
@@ -24,8 +24,10 @@ func fileInfoFromStat(st *syscall.Stat_t, name string) FileInfo {
        }
        fs.mode = FileMode(st.Mode & 0777)
        switch st.Mode & syscall.S_IFMT {
-       case syscall.S_IFBLK, syscall.S_IFCHR:
+       case syscall.S_IFBLK:
                fs.mode |= ModeDevice
+       case syscall.S_IFCHR:
+               fs.mode |= ModeDevice | ModeCharDevice
        case syscall.S_IFDIR:
                fs.mode |= ModeDir
        case syscall.S_IFIFO:
index 2638153ddbe1d1fa6ebad5f852d1cbcbdd61a938..bf009805fd730dec8af0a26d7a87974434666d14 100644 (file)
@@ -30,19 +30,23 @@ type FileMode uint32
 
 // The defined file mode bits are the most significant bits of the FileMode.
 // The nine least-significant bits are the standard Unix rwxrwxrwx permissions.
+// The values of these bits should be considered part of the public API and
+// may be used in wire protocols or disk representations: they must not be
+// changed, although new bits might be added.
 const (
        // The single letters are the abbreviations
        // used by the String method's formatting.
-       ModeDir       FileMode = 1 << (32 - 1 - iota) // d: is a directory
-       ModeAppend                                    // a: append-only
-       ModeExclusive                                 // l: exclusive use
-       ModeTemporary                                 // t: temporary file (not backed up)
-       ModeSymlink                                   // L: symbolic link
-       ModeDevice                                    // D: device file
-       ModeNamedPipe                                 // p: named pipe (FIFO)
-       ModeSocket                                    // S: Unix domain socket
-       ModeSetuid                                    // u: setuid
-       ModeSetgid                                    // g: setgid
+       ModeDir        FileMode = 1 << (32 - 1 - iota) // d: is a directory
+       ModeAppend                                     // a: append-only
+       ModeExclusive                                  // l: exclusive use
+       ModeTemporary                                  // t: temporary file (not backed up)
+       ModeSymlink                                    // L: symbolic link
+       ModeDevice                                     // D: device file
+       ModeNamedPipe                                  // p: named pipe (FIFO)
+       ModeSocket                                     // S: Unix domain socket
+       ModeSetuid                                     // u: setuid
+       ModeSetgid                                     // g: setgid
+       ModeCharDevice                                 // c: Unix character device, when ModeDevice is set
 
        // Mask for the type bits. For regular files, none will be set.
        ModeType = ModeDir | ModeSymlink | ModeNamedPipe | ModeSocket | ModeDevice
@@ -51,7 +55,7 @@ const (
 )
 
 func (m FileMode) String() string {
-       const str = "daltLDpSug"
+       const str = "daltLDpSugc"
        var buf [20]byte
        w := 0
        for i, c := range str {
index 124370384c095ebfc46f5059ef89a50913021b01..c2b90566a9989e6f20b7e859e7b8dac01380afad 100644 (file)
@@ -19,6 +19,7 @@ func UnlockOSThread()
 // GOMAXPROCS sets the maximum number of CPUs that can be executing
 // simultaneously and returns the previous setting.  If n < 1, it does not
 // change the current setting.
+// The number of logical CPUs on the local machine can be queried with NumCPU.
 // This call will go away when the scheduler improves.
 func GOMAXPROCS(n int) int
 
index 7c986daee632951b1cffceca70f7e9a9f10860d8..25c7470aab16900284e62b84417e0502f92961a4 100644 (file)
@@ -19,8 +19,8 @@ func Gosched()
 func Goexit()
 
 // Caller reports file and line number information about function invocations on
-// the calling goroutine's stack.  The argument skip is the number of stack frames to
-// ascend, with 0 identifying the the caller of Caller.  The return values report the
+// the calling goroutine's stack.  The argument skip is the number of stack frames
+// to ascend, with 0 identifying the caller of Caller.  The return values report the
 // program counter, file name, and line number within the file of the corresponding
 // call.  The boolean ok is false if it was not possible to recover the information.
 func Caller(skip int) (pc uintptr, file string, line int, ok bool)
@@ -59,54 +59,18 @@ func (f *Func) Entry() uintptr { return f.entry }
 // The result will not be accurate if pc is not a program
 // counter within f.
 func (f *Func) FileLine(pc uintptr) (file string, line int) {
-       // NOTE(rsc): If you edit this function, also edit
-       // symtab.c:/^funcline.  That function also has the
-       // comments explaining the logic.
-       targetpc := pc
-
-       var pcQuant uintptr = 1
-       if GOARCH == "arm" {
-               pcQuant = 4
-       }
-
-       p := f.pcln
-       pc = f.pc0
-       line = int(f.ln0)
-       i := 0
-       //print("FileLine start pc=", pc, " targetpc=", targetpc, " line=", line,
-       //      " tab=", p, " ", p[0], " quant=", pcQuant, " GOARCH=", GOARCH, "\n")
-       for {
-               for i < len(p) && p[i] > 128 {
-                       pc += pcQuant * uintptr(p[i]-128)
-                       i++
-               }
-               //print("pc<", pc, " targetpc=", targetpc, " line=", line, "\n")
-               if pc > targetpc || i >= len(p) {
-                       break
-               }
-               if p[i] == 0 {
-                       if i+5 > len(p) {
-                               break
-                       }
-                       line += int(p[i+1]<<24) | int(p[i+2]<<16) | int(p[i+3]<<8) | int(p[i+4])
-                       i += 5
-               } else if p[i] <= 64 {
-                       line += int(p[i])
-                       i++
-               } else {
-                       line -= int(p[i] - 64)
-                       i++
-               }
-               //print("pc=", pc, " targetpc=", targetpc, " line=", line, "\n")
-               pc += pcQuant
-       }
-       file = f.src
-       return
+       return funcline_go(f, pc)
 }
 
+// implemented in symtab.c
+func funcline_go(*Func, uintptr) (string, int)
+
 // mid returns the current os thread (m) id.
 func mid() uint32
 
+// NumCPU returns the number of logical CPUs on the local machine.
+func NumCPU() int
+
 // Semacquire waits until *s > 0 and then atomically decrements it.
 // It is intended as a simple sleep primitive for use by the synchronization
 // library and should not be used directly.
index 4aa4ca6d7da78e25bdec3b83f1a4695b487d8dfc..31da3c83d0d6162e70bfc353dac7dc985f94e93d 100644 (file)
@@ -191,7 +191,7 @@ func Sort(data Interface) {
                maxDepth++
        }
        maxDepth *= 2
-       quickSort(data, 0, data.Len(), maxDepth)
+       quickSort(data, 0, n, maxDepth)
 }
 
 func IsSorted(data Interface) bool {
index 980052a778bc45813c807a2b7719f51673e4cfce..64ab84f45549914b8e73df1a1b21beee88ace0df 100644 (file)
@@ -191,6 +191,36 @@ func (f *extFloat) Assign(x float64) {
        f.exp -= 64
 }
 
+// AssignComputeBounds sets f to the value of x and returns
+// lower, upper such that any number in the closed interval
+// [lower, upper] is converted back to x.
+func (f *extFloat) AssignComputeBounds(x float64) (lower, upper extFloat) {
+       // Special cases.
+       bits := math.Float64bits(x)
+       flt := &float64info
+       neg := bits>>(flt.expbits+flt.mantbits) != 0
+       expBiased := int(bits>>flt.mantbits) & (1<<flt.expbits - 1)
+       mant := bits & (uint64(1)<<flt.mantbits - 1)
+
+       if expBiased == 0 {
+               // denormalized.
+               f.mant = mant
+               f.exp = 1 + flt.bias - int(flt.mantbits)
+       } else {
+               f.mant = mant | 1<<flt.mantbits
+               f.exp = expBiased + flt.bias - int(flt.mantbits)
+       }
+       f.neg = neg
+
+       upper = extFloat{mant: 2*f.mant + 1, exp: f.exp - 1, neg: f.neg}
+       if mant != 0 || expBiased == 1 {
+               lower = extFloat{mant: 2*f.mant - 1, exp: f.exp - 1, neg: f.neg}
+       } else {
+               lower = extFloat{mant: 4*f.mant - 1, exp: f.exp - 2, neg: f.neg}
+       }
+       return
+}
+
 // Normalize normalizes f so that the highest bit of the mantissa is
 // set, and returns the number by which the mantissa was left-shifted.
 func (f *extFloat) Normalize() uint {
@@ -309,3 +339,163 @@ func (f *extFloat) AssignDecimal(d *decimal) (ok bool) {
        }
        return true
 }
+
+// Frexp10 is an analogue of math.Frexp for decimal powers. It scales
+// f by an approximate power of ten 10^-exp, and returns exp10, so
+// that f*10^exp10 has the same value as the old f, up to an ulp,
+// as well as the index of 10^-exp in the powersOfTen table.
+// The arguments expMin and expMax constrain the final value of the
+// binary exponent of f.
+func (f *extFloat) frexp10(expMin, expMax int) (exp10, index int) {
+       // it is illegal to call this function with a too restrictive exponent range.
+       if expMax-expMin <= 25 {
+               panic("strconv: invalid exponent range")
+       }
+       // Find power of ten such that x * 10^n has a binary exponent
+       // between expMin and expMax
+       approxExp10 := -(f.exp + 100) * 28 / 93 // log(10)/log(2) is close to 93/28.
+       i := (approxExp10 - firstPowerOfTen) / stepPowerOfTen
+Loop:
+       for {
+               exp := f.exp + powersOfTen[i].exp + 64
+               switch {
+               case exp < expMin:
+                       i++
+               case exp > expMax:
+                       i--
+               default:
+                       break Loop
+               }
+       }
+       // Apply the desired decimal shift on f. It will have exponent
+       // in the desired range. This is multiplication by 10^-exp10.
+       f.Multiply(powersOfTen[i])
+
+       return -(firstPowerOfTen + i*stepPowerOfTen), i
+}
+
+// frexp10Many applies a common shift by a power of ten to a, b, c.
+func frexp10Many(expMin, expMax int, a, b, c *extFloat) (exp10 int) {
+       exp10, i := c.frexp10(expMin, expMax)
+       a.Multiply(powersOfTen[i])
+       b.Multiply(powersOfTen[i])
+       return
+}
+
+// ShortestDecimal stores in d the shortest decimal representation of f
+// which belongs to the open interval (lower, upper), where f is supposed
+// to lie. It returns false whenever the result is unsure. The implementation
+// uses the Grisu3 algorithm.
+func (f *extFloat) ShortestDecimal(d *decimal, lower, upper *extFloat) bool {
+       if f.mant == 0 {
+               d.d[0] = '0'
+               d.nd = 1
+               d.dp = 0
+               d.neg = f.neg
+       }
+       const minExp = -60
+       const maxExp = -32
+       upper.Normalize()
+       // Uniformize exponents.
+       if f.exp > upper.exp {
+               f.mant <<= uint(f.exp - upper.exp)
+               f.exp = upper.exp
+       }
+       if lower.exp > upper.exp {
+               lower.mant <<= uint(lower.exp - upper.exp)
+               lower.exp = upper.exp
+       }
+
+       exp10 := frexp10Many(minExp, maxExp, lower, f, upper)
+       // Take a safety margin due to rounding in frexp10Many, but we lose precision.
+       upper.mant++
+       lower.mant--
+
+       // The shortest representation of f is either rounded up or down, but
+       // in any case, it is a truncation of upper.
+       shift := uint(-upper.exp)
+       integer := uint32(upper.mant >> shift)
+       fraction := upper.mant - (uint64(integer) << shift)
+
+       // How far we can go down from upper until the result is wrong.
+       allowance := upper.mant - lower.mant
+       // How far we should go to get a very precise result.
+       targetDiff := upper.mant - f.mant
+
+       // Count integral digits: there are at most 10.
+       var integerDigits int
+       for i, pow := range uint64pow10 {
+               if uint64(integer) >= pow {
+                       integerDigits = i + 1
+               }
+       }
+       for i := 0; i < integerDigits; i++ {
+               pow := uint64pow10[integerDigits-i-1]
+               digit := integer / uint32(pow)
+               d.d[i] = byte(digit + '0')
+               integer -= digit * uint32(pow)
+               // evaluate whether we should stop.
+               if currentDiff := uint64(integer)<<shift + fraction; currentDiff < allowance {
+                       d.nd = i + 1
+                       d.dp = integerDigits + exp10
+                       d.neg = f.neg
+                       // Sometimes allowance is so large the last digit might need to be
+                       // decremented to get closer to f.
+                       return adjustLastDigit(d, currentDiff, targetDiff, allowance, pow<<shift, 2)
+               }
+       }
+       d.nd = integerDigits
+       d.dp = d.nd + exp10
+       d.neg = f.neg
+
+       // Compute digits of the fractional part. At each step fraction does not
+       // overflow. The choice of minExp implies that fraction is less than 2^60.
+       var digit int
+       multiplier := uint64(1)
+       for {
+               fraction *= 10
+               multiplier *= 10
+               digit = int(fraction >> shift)
+               d.d[d.nd] = byte(digit + '0')
+               d.nd++
+               fraction -= uint64(digit) << shift
+               if fraction < allowance*multiplier {
+                       // We are in the admissible range. Note that if allowance is about to
+                       // overflow, that is, allowance > 2^64/10, the condition is automatically
+                       // true due to the limited range of fraction.
+                       return adjustLastDigit(d,
+                               fraction, targetDiff*multiplier, allowance*multiplier,
+                               1<<shift, multiplier*2)
+               }
+       }
+       return false
+}
+
+// adjustLastDigit modifies d = x-currentDiff*ε, to get closest to 
+// d = x-targetDiff*ε, without becoming smaller than x-maxDiff*ε.
+// It assumes that a decimal digit is worth ulpDecimal*ε, and that
+// all data is known with a error estimate of ulpBinary*ε.
+func adjustLastDigit(d *decimal, currentDiff, targetDiff, maxDiff, ulpDecimal, ulpBinary uint64) bool {
+       if ulpDecimal < 2*ulpBinary {
+               // Appromixation is too wide.
+               return false
+       }
+       for currentDiff+ulpDecimal/2+ulpBinary < targetDiff {
+               d.d[d.nd-1]--
+               currentDiff += ulpDecimal
+       }
+       if currentDiff+ulpDecimal <= targetDiff+ulpDecimal/2+ulpBinary {
+               // we have two choices, and don't know what to do.
+               return false
+       }
+       if currentDiff < ulpBinary || currentDiff > maxDiff-ulpBinary {
+               // we went too far
+               return false
+       }
+       if d.nd == 1 && d.d[0] == '0' {
+               // the number has actually reached zero.
+               d.nd = 0
+               d.dp = 0
+       }
+       return true
+}
index 47877e373aaecc68c53a8897dabc6a1019999710..171defa4417ddfd09ff813a888ae81f4f38046db 100644 (file)
@@ -26,8 +26,8 @@ func pow2(i int) float64 {
        return pow2(i/2) * pow2(i-i/2)
 }
 
-// Wrapper around strconv.Atof64.  Handles dddddp+ddd (binary exponent)
-// itself, passes the rest on to strconv.Atof64.
+// Wrapper around strconv.ParseFloat(x, 64).  Handles dddddp+ddd (binary exponent)
+// itself, passes the rest on to strconv.ParseFloat.
 func myatof64(s string) (f float64, ok bool) {
        a := strings.SplitN(s, "p", 2)
        if len(a) == 2 {
@@ -70,8 +70,8 @@ func myatof64(s string) (f float64, ok bool) {
        return f1, true
 }
 
-// Wrapper around strconv.Atof32.  Handles dddddp+ddd (binary exponent)
-// itself, passes the rest on to strconv.Atof32.
+// Wrapper around strconv.ParseFloat(x, 32).  Handles dddddp+ddd (binary exponent)
+// itself, passes the rest on to strconv.ParseFloat.
 func myatof32(s string) (f float32, ok bool) {
        a := strings.SplitN(s, "p", 2)
        if len(a) == 2 {
index f4434fd51753c8a96e0ecee3406f09af2c8ae063..8eefbee79f21dca076a54779b15e45a7c4b13709 100644 (file)
@@ -98,29 +98,43 @@ func genericFtoa(dst []byte, val float64, fmt byte, prec, bitSize int) []byte {
                return fmtB(dst, neg, mant, exp, flt)
        }
 
-       // Create exact decimal representation.
-       // The shift is exp - flt.mantbits because mant is a 1-bit integer
-       // followed by a flt.mantbits fraction, and we are treating it as
-       // a 1+flt.mantbits-bit integer.
-       d := new(decimal)
-       d.Assign(mant)
-       d.Shift(exp - int(flt.mantbits))
-
-       // Round appropriately.
        // Negative precision means "only as much as needed to be exact."
-       shortest := false
-       if prec < 0 {
-               shortest = true
-               roundShortest(d, mant, exp, flt)
-               switch fmt {
-               case 'e', 'E':
-                       prec = d.nd - 1
-               case 'f':
-                       prec = max(d.nd-d.dp, 0)
-               case 'g', 'G':
-                       prec = d.nd
+       shortest := prec < 0
+
+       d := new(decimal)
+       if shortest {
+               ok := false
+               if optimize && bitSize == 64 {
+                       // Try Grisu3 algorithm.
+                       f := new(extFloat)
+                       lower, upper := f.AssignComputeBounds(val)
+                       ok = f.ShortestDecimal(d, &lower, &upper)
+               }
+               if !ok {
+                       // Create exact decimal representation.
+                       // The shift is exp - flt.mantbits because mant is a 1-bit integer
+                       // followed by a flt.mantbits fraction, and we are treating it as
+                       // a 1+flt.mantbits-bit integer.
+                       d.Assign(mant)
+                       d.Shift(exp - int(flt.mantbits))
+                       roundShortest(d, mant, exp, flt)
+               }
+               // Precision for shortest representation mode.
+               if prec < 0 {
+                       switch fmt {
+                       case 'e', 'E':
+                               prec = d.nd - 1
+                       case 'f':
+                               prec = max(d.nd-d.dp, 0)
+                       case 'g', 'G':
+                               prec = d.nd
+                       }
                }
        } else {
+               // Create exact decimal representation.
+               d.Assign(mant)
+               d.Shift(exp - int(flt.mantbits))
+               // Round appropriately.
                switch fmt {
                case 'e', 'E':
                        d.Round(prec + 1)
@@ -178,15 +192,26 @@ func roundShortest(d *decimal, mant uint64, exp int, flt *floatInfo) {
                return
        }
 
-       // TODO(rsc): Unless exp == minexp, if the number of digits in d
-       // is less than 17, it seems likely that it would be
-       // the shortest possible number already.  So maybe we can
-       // bail out without doing the extra multiprecision math here.
-
        // Compute upper and lower such that any decimal number
        // between upper and lower (possibly inclusive)
        // will round to the original floating point number.
 
+       // We may see at once that the number is already shortest.
+       //
+       // Suppose d is not denormal, so that 2^exp <= d < 10^dp.
+       // The closest shorter number is at least 10^(dp-nd) away.
+       // The lower/upper bounds computed below are at distance
+       // at most 2^(exp-mantbits).
+       //
+       // So the number is already shortest if 10^(dp-nd) > 2^(exp-mantbits),
+       // or equivalently log2(10)*(dp-nd) > exp-mantbits.
+       // It is true if 332/100*(dp-nd) >= exp-mantbits (log2(10) > 3.32).
+       minexp := flt.bias + 1 // minimum possible exponent
+       if exp > minexp && 332*(d.dp-d.nd) >= 100*(exp-int(flt.mantbits)) {
+               // The number is already shortest.
+               return
+       }
+
        // d = mant << (exp - mantbits)
        // Next highest floating point number is mant+1 << exp-mantbits.
        // Our upper bound is halfway inbetween, mant*2+1 << exp-mantbits-1.
@@ -200,7 +225,6 @@ func roundShortest(d *decimal, mant uint64, exp int, flt *floatInfo) {
        // in which case the next lowest is mant*2-1 << exp-mantbits-1.
        // Either way, call it mantlo << explo-mantbits.
        // Our lower bound is halfway inbetween, mantlo*2+1 << explo-mantbits-1.
-       minexp := flt.bias + 1 // minimum possible exponent
        var mantlo uint64
        var explo int
        if mant > 1<<flt.mantbits || exp == minexp {
@@ -241,7 +265,7 @@ func roundShortest(d *decimal, mant uint64, exp int, flt *floatInfo) {
 
                // Okay to round up if upper has a different digit and
                // either upper is inclusive or upper is bigger than the result of rounding up.
-               okup := m != u && (inclusive || i+1 < upper.nd)
+               okup := m != u && (inclusive || m+1 < u || i+1 < upper.nd)
 
                // If it's okay to do either, then round to the nearest one.
                // If it's okay to do only one, do it.
index c69f8c2466d0e915446520cd7584188bd9eedfb5..ee7b7c431e7fa413e2c48fe0dea27c3b72a7dc52 100644 (file)
@@ -6,6 +6,7 @@ package strconv_test
 
 import (
        "math"
+       "math/rand"
        . "strconv"
        "testing"
 )
@@ -123,6 +124,10 @@ var ftoatests = []ftoaTest{
        {2.2250738585072012e-308, 'g', -1, "2.2250738585072014e-308"},
        // http://www.exploringbinary.com/php-hangs-on-numeric-value-2-2250738585072011e-308/
        {2.2250738585072011e-308, 'g', -1, "2.225073858507201e-308"},
+
+       // Issue 2625.
+       {383260575764816448, 'f', 0, "383260575764816448"},
+       {383260575764816448, 'g', -1, "3.8326057576481645e+17"},
 }
 
 func TestFtoa(t *testing.T) {
@@ -149,6 +154,25 @@ func TestFtoa(t *testing.T) {
        }
 }
 
+func TestFtoaRandom(t *testing.T) {
+       N := int(1e4)
+       if testing.Short() {
+               N = 100
+       }
+       t.Logf("testing %d random numbers with fast and slow FormatFloat", N)
+       for i := 0; i < N; i++ {
+               bits := uint64(rand.Uint32())<<32 | uint64(rand.Uint32())
+               x := math.Float64frombits(bits)
+               shortFast := FormatFloat(x, 'g', -1, 64)
+               SetOptimize(false)
+               shortSlow := FormatFloat(x, 'g', -1, 64)
+               SetOptimize(true)
+               if shortSlow != shortFast {
+                       t.Errorf("%b printed as %s, want %s", x, shortFast, shortSlow)
+               }
+       }
+}
+
 /* This test relies on escape analysis which gccgo does not yet do.
 
 func TestAppendFloatDoesntAllocate(t *testing.T) {
@@ -188,6 +212,12 @@ func BenchmarkFormatFloatExp(b *testing.B) {
        }
 }
 
+func BenchmarkFormatFloatNegExp(b *testing.B) {
+       for i := 0; i < b.N; i++ {
+               FormatFloat(-5.11e-95, 'g', -1, 64)
+       }
+}
+
 func BenchmarkFormatFloatBig(b *testing.B) {
        for i := 0; i < b.N; i++ {
                FormatFloat(123456789123456789123456789, 'g', -1, 64)
@@ -215,6 +245,13 @@ func BenchmarkAppendFloatExp(b *testing.B) {
        }
 }
 
+func BenchmarkAppendFloatNegExp(b *testing.B) {
+       dst := make([]byte, 0, 30)
+       for i := 0; i < b.N; i++ {
+               AppendFloat(dst, -5.11e-95, 'g', -1, 64)
+       }
+}
+
 func BenchmarkAppendFloatBig(b *testing.B) {
        dst := make([]byte, 0, 30)
        for i := 0; i < b.N; i++ {
index edba62954be4e03bc6e2a41f3e69b89d07529608..61dbcae70f4b2892d14492a69e07c26dda0e0687 100644 (file)
@@ -260,6 +260,7 @@ func UnquoteChar(s string, quote byte) (value rune, multibyte bool, tail string,
                for j := 0; j < 2; j++ { // one digit already; two more
                        x := rune(s[j]) - '0'
                        if x < 0 || x > 7 {
+                               err = ErrSyntax
                                return
                        }
                        v = (v << 3) | x
index 419943d83c751133b9d577227398605530c7f5ee..3f544c43cd55cbd0f931233b1ad345d5e578f0d5 100644 (file)
@@ -191,7 +191,13 @@ var misquoted = []string{
        `"'`,
        `b"`,
        `"\"`,
+       `"\9"`,
+       `"\19"`,
+       `"\129"`,
        `'\'`,
+       `'\9'`,
+       `'\19'`,
+       `'\129'`,
        `'ab'`,
        `"\x1!"`,
        `"\U12345678"`,
index 3ba0fb1b098d30be0b72c3730299a382ce825cfa..c1a02135f4f1f0709ce43e198036aa006662acdb 100644 (file)
@@ -10,26 +10,40 @@ package syscall
 
 import "sync"
 
-var env map[string]string
-var envOnce sync.Once
-var Envs []string // provided by runtime
+var (
+       // envOnce guards initialization by copyenv, which populates env.
+       envOnce sync.Once
 
+       // envLock guards env and envs.
+       envLock sync.RWMutex
+
+       // env maps from an environment variable to its first occurrence in envs.
+       env map[string]int
+
+       // envs is provided by the runtime. elements are expected to be
+       // of the form "key=value".
+       Envs []string
+)
+
+// setenv_c is provided by the runtime, but is a no-op if cgo isn't
+// loaded.
 func setenv_c(k, v string)
 
 func copyenv() {
-       env = make(map[string]string)
-       for _, s := range Envs {
+       env = make(map[string]int)
+       for i, s := range Envs {
                for j := 0; j < len(s); j++ {
                        if s[j] == '=' {
-                               env[s[0:j]] = s[j+1:]
+                               key := s[:j]
+                               if _, ok := env[key]; !ok {
+                                       env[key] = i
+                               }
                                break
                        }
                }
        }
 }
 
-var envLock sync.RWMutex
-
 func Getenv(key string) (value string, found bool) {
        envOnce.Do(copyenv)
        if len(key) == 0 {
@@ -39,11 +53,17 @@ func Getenv(key string) (value string, found bool) {
        envLock.RLock()
        defer envLock.RUnlock()
 
-       v, ok := env[key]
+       i, ok := env[key]
        if !ok {
                return "", false
        }
-       return v, true
+       s := Envs[i]
+       for i := 0; i < len(s); i++ {
+               if s[i] == '=' {
+                       return s[i+1:], true
+               }
+       }
+       return "", false
 }
 
 func Setenv(key, value string) error {
@@ -55,8 +75,16 @@ func Setenv(key, value string) error {
        envLock.Lock()
        defer envLock.Unlock()
 
-       env[key] = value
-       setenv_c(key, value) // is a no-op if cgo isn't loaded
+       i, ok := env[key]
+       kv := key + "=" + value
+       if ok {
+               Envs[i] = kv
+       } else {
+               i = len(Envs)
+               Envs = append(Envs, kv)
+       }
+       env[key] = i
+       setenv_c(key, value)
        return nil
 }
 
@@ -66,8 +94,8 @@ func Clearenv() {
        envLock.Lock()
        defer envLock.Unlock()
 
-       env = make(map[string]string)
-
+       env = make(map[string]int)
+       Envs = []string{}
        // TODO(bradfitz): pass through to C
 }
 
@@ -75,11 +103,7 @@ func Environ() []string {
        envOnce.Do(copyenv)
        envLock.RLock()
        defer envLock.RUnlock()
-       a := make([]string, len(env))
-       i := 0
-       for k, v := range env {
-               a[i] = k + "=" + v
-               i++
-       }
+       a := make([]string, len(Envs))
+       copy(a, Envs)
        return a
 }
diff --git a/libgo/go/syscall/exec_bsd.go b/libgo/go/syscall/exec_bsd.go
new file mode 100644 (file)
index 0000000..7baa3af
--- /dev/null
@@ -0,0 +1,227 @@
+// Copyright 2011 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+// +build darwin freebsd netbsd openbsd
+
+package syscall
+
+import (
+       "runtime"
+       "unsafe"
+)
+
+type SysProcAttr struct {
+       Chroot     string      // Chroot.
+       Credential *Credential // Credential.
+       Ptrace     bool        // Enable tracing.
+       Setsid     bool        // Create session.
+       Setpgid    bool        // Set process group ID to new pid (SYSV setpgrp)
+       Setctty    bool        // Set controlling terminal to fd 0
+       Noctty     bool        // Detach fd 0 from controlling terminal
+}
+
+// Fork, dup fd onto 0..len(fd), and exec(argv0, argvv, envv) in child.
+// If a dup or exec fails, write the errno error to pipe.
+// (Pipe is close-on-exec so if exec succeeds, it will be closed.)
+// In the child, this function must not acquire any locks, because
+// they might have been locked at the time of the fork.  This means
+// no rescheduling, no malloc calls, and no new stack segments.
+// The calls to RawSyscall are okay because they are assembly
+// functions that do not grow the stack.
+func forkAndExecInChild(argv0 *byte, argv, envv []*byte, chroot, dir *byte, attr *ProcAttr, sys *SysProcAttr, pipe int) (pid int, err Errno) {
+       // Declare all variables at top in case any
+       // declarations require heap allocation (e.g., err1).
+       var (
+               r1     Pid_t
+               err1   Errno
+               nextfd int
+               i      int
+       )
+
+       // guard against side effects of shuffling fds below.
+       fd := append([]int(nil), attr.Files...)
+
+       // About to call fork.
+       // No more allocation or calls of non-assembly functions.
+       r1, err1 = raw_fork()
+       if err1 != 0 {
+               return 0, err1
+       }
+
+       if r1 != 0 {
+               // parent; return PID
+               return int(r1), 0
+       }
+
+       // Fork succeeded, now in child.
+
+       // Enable tracing if requested.
+       if sys.Ptrace {
+               err1 = raw_trace(_PTRACE_TRACEME, 0, nil, nil)
+               if err1 != 0 {
+                       goto childerror
+               }
+       }
+
+       // Session ID
+       if sys.Setsid {
+               err1 = raw_setsid()
+               if err1 != 0 {
+                       goto childerror
+               }
+       }
+
+       // Set process group
+       if sys.Setpgid {
+               err1 = raw_setpgid(0, 0)
+               if err1 != 0 {
+                       goto childerror
+               }
+       }
+
+       // Chroot
+       if chroot != nil {
+               err1 = raw_chroot(chroot)
+               if err1 != 0 {
+                       goto childerror
+               }
+       }
+
+       // User and groups
+       if cred := sys.Credential; cred != nil {
+               ngroups := len(cred.Groups)
+               if ngroups == 0 {
+                       err2 := setgroups(0, nil)
+                       if err2 == nil {
+                               err1 = 0
+                       } else {
+                               err1 = err2.(Errno)
+                       }
+               } else {
+                       groups := make([]Gid_t, ngroups)
+                       for i, v := range cred.Groups {
+                               groups[i] = Gid_t(v)
+                       }
+                       err2 := setgroups(ngroups, &groups[0])
+                       if err2 == nil {
+                               err1 = 0
+                       } else {
+                               err1 = err2.(Errno)
+                       }
+               }
+               if err1 != 0 {
+                       goto childerror
+               }
+               err2 := Setgid(int(cred.Gid))
+               if err2 != nil {
+                       err1 = err2.(Errno)
+                       goto childerror
+               }
+               err2 = Setuid(int(cred.Uid))
+               if err2 != nil {
+                       err1 = err2.(Errno)
+                       goto childerror
+               }
+       }
+
+       // Chdir
+       if dir != nil {
+               err1 = raw_chdir(dir)
+               if err1 != 0 {
+                       goto childerror
+               }
+       }
+
+       // Pass 1: look for fd[i] < i and move those up above len(fd)
+       // so that pass 2 won't stomp on an fd it needs later.
+       nextfd = int(len(fd))
+       if pipe < nextfd {
+               _, err2 := Dup2(pipe, nextfd)
+               if err2 != nil {
+                       err1 = err2.(Errno)
+                       goto childerror
+               }
+               raw_fcntl(nextfd, F_SETFD, FD_CLOEXEC)
+               pipe = nextfd
+               nextfd++
+       }
+       for i = 0; i < len(fd); i++ {
+               if fd[i] >= 0 && fd[i] < int(i) {
+                       _, err2 := Dup2(fd[i], nextfd)
+                       if err2 != nil {
+                               err1 = err2.(Errno)
+                               goto childerror
+                       }
+                       raw_fcntl(nextfd, F_SETFD, F_CLOEXEC)
+                       fd[i] = nextfd
+                       nextfd++
+                       if nextfd == pipe { // don't stomp on pipe
+                               nextfd++
+                       }
+               }
+       }
+
+       // Pass 2: dup fd[i] down onto i.
+       for i = 0; i < len(fd); i++ {
+               if fd[i] == -1 {
+                       raw_close(i)
+                       continue
+               }
+               if fd[i] == int(i) {
+                       // dup2(i, i) won't clear close-on-exec flag on Linux,
+                       // probably not elsewhere either.
+                       _, err1 = raw_fcntl(fd[i], F_SETFD, 0)
+                       if err1 != 0 {
+                               goto childerror
+                       }
+                       continue
+               }
+               // The new fd is created NOT close-on-exec,
+               // which is exactly what we want.
+               _, err2 := Dup2(fd[i], i)
+               if err1 != 0 {
+                       err1 = err2.(Errno)
+                       goto childerror
+               }
+       }
+
+       // By convention, we don't close-on-exec the fds we are
+       // started with, so if len(fd) < 3, close 0, 1, 2 as needed.
+       // Programs that know they inherit fds >= 3 will need
+       // to set them close-on-exec.
+       for i = len(fd); i < 3; i++ {
+               raw_close(i)
+       }
+
+       // Detach fd 0 from tty
+       if sys.Noctty {
+               _, err1 = raw_ioctl(0, IOTCNOTTY, 0)
+               if err1 != 0 {
+                       goto childerror
+               }
+       }
+
+       // Make fd 0 the tty
+       if sys.Setctty {
+               _, err1 = raw_ioctl(TIOCSCTTY, 0)
+               if err1 != 0 {
+                       goto childerror
+               }
+       }
+
+       // Time to exec.
+       err1 = raw_execve(argv0, &argv[0], &envv[0])
+
+childerror:
+       // send error code on pipe
+       raw_write(pipe, (*byte)(unsafe.Pointer(&err1)), int(unsafe.Sizeof(err1)))
+       for {
+               raw_exit(253)
+       }
+
+       // Calling panic is not actually safe,
+       // but the for loop above won't break
+       // and this shuts up the compiler.
+       panic("unreached")
+}
diff --git a/libgo/go/syscall/exec_linux.go b/libgo/go/syscall/exec_linux.go
new file mode 100644 (file)
index 0000000..98dbeb2
--- /dev/null
@@ -0,0 +1,251 @@
+// Copyright 2011 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+// +build linux
+
+package syscall
+
+import (
+       "unsafe"
+)
+
+//sysnb        raw_prctl(option int, arg2 int, arg3 int, arg4 int, arg5 int) (ret int, err Errno)
+//prctl(option int, arg2 _C_long, arg3 _C_long, arg4 _C_long, arg5 _C_long) int
+
+type SysProcAttr struct {
+       Chroot     string      // Chroot.
+       Credential *Credential // Credential.
+       Ptrace     bool        // Enable tracing.
+       Setsid     bool        // Create session.
+       Setpgid    bool        // Set process group ID to new pid (SYSV setpgrp)
+       Setctty    bool        // Set controlling terminal to fd 0
+       Noctty     bool        // Detach fd 0 from controlling terminal
+       Pdeathsig  int         // Signal that the process will get when its parent dies (Linux only)
+}
+
+// Fork, dup fd onto 0..len(fd), and exec(argv0, argvv, envv) in child.
+// If a dup or exec fails, write the errno error to pipe.
+// (Pipe is close-on-exec so if exec succeeds, it will be closed.)
+// In the child, this function must not acquire any locks, because
+// they might have been locked at the time of the fork.  This means
+// no rescheduling, no malloc calls, and no new stack segments.
+// The calls to RawSyscall are okay because they are assembly
+// functions that do not grow the stack.
+func forkAndExecInChild(argv0 *byte, argv, envv []*byte, chroot, dir *byte, attr *ProcAttr, sys *SysProcAttr, pipe int) (pid int, err Errno) {
+       // Declare all variables at top in case any
+       // declarations require heap allocation (e.g., err1).
+       var (
+               r1     Pid_t
+               err1   Errno
+               nextfd int
+               i      int
+       )
+
+       // guard against side effects of shuffling fds below.
+       fd := append([]int(nil), attr.Files...)
+
+       // About to call fork.
+       // No more allocation or calls of non-assembly functions.
+       r1, err1 = raw_fork()
+       if err1 != 0 {
+               return 0, err1
+       }
+
+       if r1 != 0 {
+               // parent; return PID
+               return int(r1), 0
+       }
+
+       // Fork succeeded, now in child.
+
+       // Parent death signal
+       if sys.Pdeathsig != 0 {
+               _, err1 = raw_prctl(PR_SET_PDEATHSIG, sys.Pdeathsig, 0, 0, 0)
+               if err1 != 0 {
+                       goto childerror
+               }
+
+               // Signal self if parent is already dead. This might cause a
+               // duplicate signal in rare cases, but it won't matter when
+               // using SIGKILL.
+               ppid := Getppid()
+               if ppid == 1 {
+                       pid = Getpid()
+                       err2 := Kill(pid, sys.Pdeathsig)
+                       if err2 != nil {
+                               err1 = err2.(Errno)
+                               goto childerror
+                       }
+               }
+       }
+
+       // Enable tracing if requested.
+       if sys.Ptrace {
+               err1 = raw_ptrace(_PTRACE_TRACEME, 0, nil, nil)
+               if err1 != 0 {
+                       goto childerror
+               }
+       }
+
+       // Session ID
+       if sys.Setsid {
+               err1 = raw_setsid()
+               if err1 != 0 {
+                       goto childerror
+               }
+       }
+
+       // Set process group
+       if sys.Setpgid {
+               err1 = raw_setpgid(0, 0)
+               if err1 != 0 {
+                       goto childerror
+               }
+       }
+
+       // Chroot
+       if chroot != nil {
+               err1 = raw_chroot(chroot)
+               if err1 != 0 {
+                       goto childerror
+               }
+       }
+
+       // User and groups
+       if cred := sys.Credential; cred != nil {
+               ngroups := len(cred.Groups)
+               if ngroups == 0 {
+                       err2 := setgroups(0, nil)
+                       if err2 == nil {
+                               err1 = 0
+                       } else {
+                               err1 = err2.(Errno)
+                       }
+               } else {
+                       groups := make([]Gid_t, ngroups)
+                       for i, v := range cred.Groups {
+                               groups[i] = Gid_t(v)
+                       }
+                       err2 := setgroups(ngroups, &groups[0])
+                       if err2 == nil {
+                               err1 = 0
+                       } else {
+                               err1 = err2.(Errno)
+                       }
+               }
+               if err1 != 0 {
+                       goto childerror
+               }
+               err2 := Setgid(int(cred.Gid))
+               if err2 != nil {
+                       err1 = err2.(Errno)
+                       goto childerror
+               }
+               err2 = Setuid(int(cred.Uid))
+               if err2 != nil {
+                       err1 = err2.(Errno)
+                       goto childerror
+               }
+       }
+
+       // Chdir
+       if dir != nil {
+               err1 = raw_chdir(dir)
+               if err1 != 0 {
+                       goto childerror
+               }
+       }
+
+       // Pass 1: look for fd[i] < i and move those up above len(fd)
+       // so that pass 2 won't stomp on an fd it needs later.
+       nextfd = int(len(fd))
+       if pipe < nextfd {
+               _, err2 := Dup2(pipe, nextfd)
+               if err2 != nil {
+                       err1 = err2.(Errno)
+                       goto childerror
+               }
+               raw_fcntl(nextfd, F_SETFD, FD_CLOEXEC)
+               pipe = nextfd
+               nextfd++
+       }
+       for i = 0; i < len(fd); i++ {
+               if fd[i] >= 0 && fd[i] < int(i) {
+                       _, err2 := Dup2(fd[i], nextfd)
+                       if err2 != nil {
+                               err1 = err2.(Errno)
+                               goto childerror
+                       }
+                       raw_fcntl(nextfd, F_SETFD, FD_CLOEXEC)
+                       fd[i] = nextfd
+                       nextfd++
+                       if nextfd == pipe { // don't stomp on pipe
+                               nextfd++
+                       }
+               }
+       }
+
+       // Pass 2: dup fd[i] down onto i.
+       for i = 0; i < len(fd); i++ {
+               if fd[i] == -1 {
+                       raw_close(i)
+                       continue
+               }
+               if fd[i] == int(i) {
+                       // dup2(i, i) won't clear close-on-exec flag on Linux,
+                       // probably not elsewhere either.
+                       _, err1 = raw_fcntl(fd[i], F_SETFD, 0)
+                       if err1 != 0 {
+                               goto childerror
+                       }
+                       continue
+               }
+               // The new fd is created NOT close-on-exec,
+               // which is exactly what we want.
+               _, err2 := Dup2(fd[i], i);
+               if err2 != nil {
+                       err1 = err2.(Errno)
+                       goto childerror
+               }
+       }
+
+       // By convention, we don't close-on-exec the fds we are
+       // started with, so if len(fd) < 3, close 0, 1, 2 as needed.
+       // Programs that know they inherit fds >= 3 will need
+       // to set them close-on-exec.
+       for i = len(fd); i < 3; i++ {
+               raw_close(i)
+       }
+
+       // Detach fd 0 from tty
+       if sys.Noctty {
+               _, err1 = raw_ioctl(0, TIOCNOTTY, 0)
+               if err1 != 0 {
+                       goto childerror
+               }
+       }
+
+       // Make fd 0 the tty
+       if sys.Setctty {
+               _, err1 = raw_ioctl(0, TIOCSCTTY, 0)
+               if err1 != 0 {
+                       goto childerror
+               }
+       }
+
+       // Time to exec.
+       err1 = raw_execve(argv0, &argv[0], &envv[0])
+
+childerror:
+       // send error code on pipe
+       raw_write(pipe, (*byte)(unsafe.Pointer(&err1)), int(unsafe.Sizeof(err1)))
+       for {
+               raw_exit(253)
+       }
+
+       // Calling panic is not actually safe,
+       // but the for loop above won't break
+       // and this shuts up the compiler.
+       panic("unreached")
+}
index 0cd37c4a0b529c48c1c0cc369b54643937d956e9..131ebaae87c03e9ba158fa717173da2bd1dfe4cb 100644 (file)
@@ -9,6 +9,7 @@
 package syscall
 
 import (
+       "runtime"
        "sync"
        "unsafe"
 )
@@ -126,211 +127,6 @@ func SetNonblock(fd int, nonblocking bool) (err error) {
        return err
 }
 
-// Fork, dup fd onto 0..len(fd), and exec(argv0, argvv, envv) in child.
-// If a dup or exec fails, write the errno error to pipe.
-// (Pipe is close-on-exec so if exec succeeds, it will be closed.)
-// In the child, this function must not acquire any locks, because
-// they might have been locked at the time of the fork.  This means
-// no rescheduling, no malloc calls, and no new stack segments.
-// The calls to RawSyscall are okay because they are assembly
-// functions that do not grow the stack.
-func forkAndExecInChild(argv0 *byte, argv, envv []*byte, chroot, dir *byte, attr *ProcAttr, sys *SysProcAttr, pipe int) (pid int, err Errno) {
-       // Declare all variables at top in case any
-       // declarations require heap allocation (e.g., err1).
-       var (
-               r1 Pid_t
-               err1 Errno
-               nextfd int
-               i int
-       )
-
-       // guard against side effects of shuffling fds below.
-       fd := append([]int(nil), attr.Files...)
-
-       // About to call fork.
-       // No more allocation or calls of non-assembly functions.
-       r1, err1 = raw_fork()
-       if err1 != 0 {
-               return 0, err1
-       }
-
-       if r1 != 0 {
-               // parent; return PID
-               return int(r1), 0
-       }
-
-       // Fork succeeded, now in child.
-
-       // Enable tracing if requested.
-       if sys.Ptrace {
-               err1 = raw_ptrace(_PTRACE_TRACEME, 0, nil, nil)
-               if err1 != 0 {
-                       goto childerror
-               }
-       }
-
-       // Session ID
-       if sys.Setsid {
-               err1 = raw_setsid()
-               if err1 != 0 {
-                       goto childerror
-               }
-       }
-
-       // Set process group
-       if sys.Setpgid {
-               err1 = raw_setpgid(0, 0)
-               if err1 != 0 {
-                       goto childerror
-               }
-       }
-
-       // Chroot
-       if chroot != nil {
-               err1 = raw_chroot(chroot)
-               if err1 != 0 {
-                       goto childerror
-               }
-       }
-
-       // User and groups
-       if cred := sys.Credential; cred != nil {
-               ngroups := len(cred.Groups)
-               if ngroups == 0 {
-                       err2 := setgroups(0, nil)
-                       if err2 == nil {
-                               err1 = 0
-                       } else {
-                               err1 = err2.(Errno)
-                       }
-               } else {
-                       groups := make([]Gid_t, ngroups)
-                       for i, v := range cred.Groups {
-                               groups[i] = Gid_t(v)
-                       }
-                       err2 := setgroups(ngroups, &groups[0])
-                       if err2 == nil {
-                               err1 = 0
-                       } else {
-                               err1 = err2.(Errno)
-                       }
-               }
-               if err1 != 0 {
-                       goto childerror
-               }
-               err2 := Setgid(int(cred.Gid))
-               if err2 != nil {
-                       err1 = err2.(Errno)
-                       goto childerror
-               }
-               err2 = Setuid(int(cred.Uid))
-               if err2 != nil {
-                       err1 = err2.(Errno)
-                       goto childerror
-               }
-       }
-
-       // Chdir
-       if dir != nil {
-               err1 = raw_chdir(dir)
-               if err1 != 0 {
-                       goto childerror
-               }
-       }
-
-       // Pass 1: look for fd[i] < i and move those up above len(fd)
-       // so that pass 2 won't stomp on an fd it needs later.
-       nextfd = int(len(fd))
-       if pipe < nextfd {
-               _, err2 := Dup2(pipe, nextfd)
-               if err2 != nil {
-                       err1 = err2.(Errno)
-                       goto childerror
-               }
-               raw_fcntl(nextfd, F_SETFD, FD_CLOEXEC)
-               pipe = nextfd
-               nextfd++
-       }
-       for i = 0; i < len(fd); i++ {
-               if fd[i] >= 0 && fd[i] < int(i) {
-                       _, err2 := Dup2(fd[i], nextfd)
-                       if err2 != nil {
-                               err1 = err2.(Errno)
-                               goto childerror
-                       }
-                       raw_fcntl(nextfd, F_SETFD, FD_CLOEXEC)
-                       fd[i] = nextfd
-                       nextfd++
-                       if nextfd == pipe { // don't stomp on pipe
-                               nextfd++
-                       }
-               }
-       }
-
-       // Pass 2: dup fd[i] down onto i.
-       for i = 0; i < len(fd); i++ {
-               if fd[i] == -1 {
-                       raw_close(i)
-                       continue
-               }
-               if fd[i] == int(i) {
-                       // Dup2(i, i) won't clear close-on-exec flag on
-                       // GNU/Linux, probably not elsewhere either.
-                       _, err1 = raw_fcntl(fd[i], F_SETFD, 0)
-                       if err1 != 0 {
-                               goto childerror
-                       }
-                       continue
-               }
-               // The new fd is created NOT close-on-exec,
-               // which is exactly what we want.
-               _, err2 := Dup2(fd[i], i)
-               if err2 != nil {
-                       err1 = err2.(Errno)
-                       goto childerror
-               }
-       }
-
-       // By convention, we don't close-on-exec the fds we are
-       // started with, so if len(fd) < 3, close 0, 1, 2 as needed.
-       // Programs that know they inherit fds >= 3 will need
-       // to set them close-on-exec.
-       for i = len(fd); i < 3; i++ {
-               raw_close(i)
-       }
-
-       // Detach fd 0 from tty
-       if sys.Noctty {
-               _, err1 = raw_ioctl(0, TIOCNOTTY, 0)
-               if err1 != 0 {
-                       goto childerror
-               }
-       }
-
-       // Make fd 0 the tty
-       if sys.Setctty {
-               _, err1 = raw_ioctl(0, TIOCSCTTY, 0)
-               if err1 != 0 {
-                       goto childerror
-               }
-       }
-
-       // Time to exec.
-       err1 = raw_execve(argv0, &argv[0], &envv[0])
-
-childerror:
-       // send error code on pipe
-       raw_write(pipe, (*byte)(unsafe.Pointer(&err1)), int(unsafe.Sizeof(err1)))
-       for {
-               raw_exit(253)
-       }
-
-       // Calling panic is not actually safe,
-       // but the for loop above won't break
-       // and this shuts up the compiler.
-       panic("unreached")
-}
-
 // Credential holds user and group identities to be assumed
 // by a child process started by StartProcess.
 type Credential struct {
@@ -348,16 +144,6 @@ type ProcAttr struct {
        Sys   *SysProcAttr
 }
 
-type SysProcAttr struct {
-       Chroot     string      // Chroot.
-       Credential *Credential // Credential.
-       Ptrace     bool        // Enable tracing.
-       Setsid     bool        // Create session.
-       Setpgid    bool        // Set process group ID to new pid (SYSV setpgrp)
-       Setctty    bool        // Set controlling terminal to fd 0
-       Noctty     bool        // Detach fd 0 from controlling terminal
-}
-
 var zeroProcAttr ProcAttr
 var zeroSysProcAttr SysProcAttr
 
@@ -383,7 +169,7 @@ func forkExec(argv0 string, argv []string, attr *ProcAttr) (pid int, err error)
        argvp := StringSlicePtr(argv)
        envvp := StringSlicePtr(attr.Env)
 
-       if OS == "freebsd" && len(argv[0]) > len(argv0) {
+       if runtime.GOOS == "freebsd" && len(argv[0]) > len(argv0) {
                argvp[0] = argv0p
        }
 
index 005fd843486d15e351ab6bddc39d8c3293ae8e72..517b5b9408d50735063bc4598cccd5cbd6b5ab78 100644 (file)
@@ -237,8 +237,6 @@ func GetsockoptIPMreq(fd, level, opt int) (*IPMreq, error) {
        return &value, err
 }
 
-/* FIXME: mksysinfo needs to support IPMreqn.
-
 func GetsockoptIPMreqn(fd, level, opt int) (*IPMreqn, error) {
        var value IPMreqn
        vallen := Socklen_t(SizeofIPMreqn)
@@ -246,10 +244,6 @@ func GetsockoptIPMreqn(fd, level, opt int) (*IPMreqn, error) {
        return &value, err
 }
 
-*/
-
-/* FIXME: mksysinfo needs to support IPv6Mreq.
-
 func GetsockoptIPv6Mreq(fd, level, opt int) (*IPv6Mreq, error) {
        var value IPv6Mreq
        vallen := Socklen_t(SizeofIPv6Mreq)
@@ -257,8 +251,6 @@ func GetsockoptIPv6Mreq(fd, level, opt int) (*IPv6Mreq, error) {
        return &value, err
 }
 
-*/
-
 //sys  setsockopt(s int, level int, name int, val *byte, vallen Socklen_t) (err error)
 //setsockopt(s int, level int, optname int, val *byte, vallen Socklen_t) int
 
@@ -288,14 +280,10 @@ func SetsockoptIPMreq(fd, level, opt int, mreq *IPMreq) (err error) {
        return setsockopt(fd, level, opt, (*byte)(unsafe.Pointer(mreq)), Socklen_t(unsafe.Sizeof(*mreq)))
 }
 
-/* FIXME: mksysinfo needs to support IMPreqn.
-
 func SetsockoptIPMreqn(fd, level, opt int, mreq *IPMreqn) (err error) {
        return setsockopt(fd, level, opt, (*byte)(unsafe.Pointer(mreq)), Socklen_t(unsafe.Sizeof(*mreq)))
 }
 
-*/
-
 func SetsockoptIPv6Mreq(fd, level, opt int, mreq *IPv6Mreq) (err error) {
        return setsockopt(fd, level, opt, (*byte)(unsafe.Pointer(mreq)), Socklen_t(unsafe.Sizeof(*mreq)))
 }
index ba109f63ac1b23afe5e7f99e86ca102b70e7c7ca..fb8986ce8497f2439390406441999ada8bed249b 100644 (file)
@@ -7,6 +7,7 @@
 package syscall
 
 import (
+       "runtime"
        "sync"
        "unsafe"
 )
@@ -20,6 +21,8 @@ var (
 func c_syscall32(trap int32, a1, a2, a3, a4, a5, a6 int32) int32 __asm__ ("syscall");
 func c_syscall64(trap int64, a1, a2, a3, a4, a5, a6 int64) int64 __asm__ ("syscall");
 
+const darwinAMD64 = runtime.GOOS == "darwin" && runtime.GOARCH == "amd64"
+
 // Do a system call.  We look at the size of uintptr to see how to pass
 // the arguments, so that we don't pass a 64-bit value when the function
 // expects a 32-bit one.
index 4ce637082ca660d3441ab484b0d02359849b6310..0bf567b7c4d898fa6e159a15e1c198a6a7a12e64 100644 (file)
@@ -142,6 +142,13 @@ func (b *B) run() BenchmarkResult {
 func (b *B) launch() {
        // Run the benchmark for a single iteration in case it's expensive.
        n := 1
+
+       // Signal that we're done whether we return normally
+       // or by FailNow's runtime.Goexit.
+       defer func() {
+               b.signal <- b
+       }()
+
        b.runN(n)
        // Run the benchmark for at least the specified amount of time.
        d := time.Duration(*benchTime * float64(time.Second))
@@ -162,7 +169,6 @@ func (b *B) launch() {
                b.runN(n)
        }
        b.result = BenchmarkResult{b.N, b.duration, b.bytes}
-       b.signal <- b
 }
 
 // The results of a benchmark run.
index 16890e0b3fa2dc37fed36af4a2d411c5c7e572f7..cfe212dc1d78ff16e7fdc03c3b2afb97111cda21 100644 (file)
@@ -63,7 +63,7 @@ var (
        memProfile     = flag.String("test.memprofile", "", "write a memory profile to the named file after execution")
        memProfileRate = flag.Int("test.memprofilerate", 0, "if >=0, sets runtime.MemProfileRate")
        cpuProfile     = flag.String("test.cpuprofile", "", "write a cpu profile to the named file during execution")
-       timeout        = flag.Int64("test.timeout", 0, "if > 0, sets time limit for tests in seconds")
+       timeout        = flag.Duration("test.timeout", 0, "if positive, sets an aggregate time limit for all tests")
        cpuListStr     = flag.String("test.cpu", "", "comma-separated list of number of CPUs to use for each test")
        parallel       = flag.Int("test.parallel", runtime.GOMAXPROCS(0), "maximum test parallelism")
 
@@ -90,7 +90,7 @@ func Short() bool {
 // If addFileLine is true, it also prefixes the string with the file and line of the call site.
 func decorate(s string, addFileLine bool) string {
        if addFileLine {
-               _, file, line, ok := runtime.Caller(4) // decorate + log + public function.
+               _, file, line, ok := runtime.Caller(3) // decorate + log + public function.
                if ok {
                        // Truncate file name at last file name separator.
                        if index := strings.LastIndex(file, "/"); index >= 0 {
@@ -136,9 +136,27 @@ func (c *common) Failed() bool { return c.failed }
 // FailNow marks the function as having failed and stops its execution.
 // Execution will continue at the next Test.
 func (c *common) FailNow() {
-       c.duration = time.Now().Sub(c.start)
        c.Fail()
-       c.signal <- c.self
+
+       // Calling runtime.Goexit will exit the goroutine, which
+       // will run the deferred functions in this goroutine,
+       // which will eventually run the deferred lines in tRunner,
+       // which will signal to the test loop that this test is done.
+       //
+       // A previous version of this code said:
+       //
+       //      c.duration = ...
+       //      c.signal <- c.self
+       //      runtime.Goexit()
+       //
+       // This previous version duplicated code (those lines are in
+       // tRunner no matter what), but worse the goroutine teardown
+       // implicit in runtime.Goexit was not guaranteed to complete
+       // before the test exited.  If a test deferred an important cleanup
+       // function (like removing temporary files), there was no guarantee
+       // it would run on a test failure.  Because we send on c.signal during
+       // a top-of-stack deferred function now, we know that the send
+       // only happens after any other stacked defers have completed.
        runtime.Goexit()
 }
 
@@ -195,9 +213,17 @@ type InternalTest struct {
 
 func tRunner(t *T, test *InternalTest) {
        t.start = time.Now()
+
+       // When this goroutine is done, either because test.F(t)
+       // returned normally or because a test failure triggered 
+       // a call to runtime.Goexit, record the duration and send
+       // a signal saying that the test is done.
+       defer func() {
+               t.duration = time.Now().Sub(t.start)
+               t.signal <- t
+       }()
+
        test.F(t)
-       t.duration = time.Now().Sub(t.start)
-       t.signal <- t
 }
 
 // An internal function but exported because it is cross-package; part of the implementation
@@ -346,7 +372,7 @@ var timer *time.Timer
 // startAlarm starts an alarm if requested.
 func startAlarm() {
        if *timeout > 0 {
-               timer = time.AfterFunc(time.Duration(*timeout)*time.Second, alarm)
+               timer = time.AfterFunc(*timeout, alarm)
        }
 }
 
diff --git a/libgo/go/testing/wrapper.go b/libgo/go/testing/wrapper.go
deleted file mode 100644 (file)
index 2bef9df..0000000
+++ /dev/null
@@ -1,105 +0,0 @@
-// Copyright 2009 The Go Authors. All rights reserved.
-// Use of this source code is governed by a BSD-style
-// license that can be found in the LICENSE file.
-
-// This file contains wrappers so t.Errorf etc. have documentation.
-// TODO: delete when godoc shows exported methods for unexported embedded fields.
-// TODO: need to change the argument to runtime.Caller in testing.go from 4 to 3 at that point.
-
-package testing
-
-// Fail marks the function as having failed but continues execution.
-func (b *B) Fail() {
-       b.common.Fail()
-}
-
-// Failed returns whether the function has failed.
-func (b *B) Failed() bool {
-       return b.common.Failed()
-}
-
-// FailNow marks the function as having failed and stops its execution.
-// Execution will continue at the next Test.
-func (b *B) FailNow() {
-       b.common.FailNow()
-}
-
-// Log formats its arguments using default formatting, analogous to Println(),
-// and records the text in the error log.
-func (b *B) Log(args ...interface{}) {
-       b.common.Log(args...)
-}
-
-// Logf formats its arguments according to the format, analogous to Printf(),
-// and records the text in the error log.
-func (b *B) Logf(format string, args ...interface{}) {
-       b.common.Logf(format, args...)
-}
-
-// Error is equivalent to Log() followed by Fail().
-func (b *B) Error(args ...interface{}) {
-       b.common.Error(args...)
-}
-
-// Errorf is equivalent to Logf() followed by Fail().
-func (b *B) Errorf(format string, args ...interface{}) {
-       b.common.Errorf(format, args...)
-}
-
-// Fatal is equivalent to Log() followed by FailNow().
-func (b *B) Fatal(args ...interface{}) {
-       b.common.Fatal(args...)
-}
-
-// Fatalf is equivalent to Logf() followed by FailNow().
-func (b *B) Fatalf(format string, args ...interface{}) {
-       b.common.Fatalf(format, args...)
-}
-
-// Fail marks the function as having failed but continues execution.
-func (t *T) Fail() {
-       t.common.Fail()
-}
-
-// Failed returns whether the function has failed.
-func (t *T) Failed() bool {
-       return t.common.Failed()
-}
-
-// FailNow marks the function as having failed and stops its execution.
-// Execution will continue at the next Test.
-func (t *T) FailNow() {
-       t.common.FailNow()
-}
-
-// Log formats its arguments using default formatting, analogous to Println(),
-// and records the text in the error log.
-func (t *T) Log(args ...interface{}) {
-       t.common.Log(args...)
-}
-
-// Logf formats its arguments according to the format, analogous to Printf(),
-// and records the text in the error log.
-func (t *T) Logf(format string, args ...interface{}) {
-       t.common.Logf(format, args...)
-}
-
-// Error is equivalent to Log() followed by Fail().
-func (t *T) Error(args ...interface{}) {
-       t.common.Error(args...)
-}
-
-// Errorf is equivalent to Logf() followed by Fail().
-func (t *T) Errorf(format string, args ...interface{}) {
-       t.common.Errorf(format, args...)
-}
-
-// Fatal is equivalent to Log() followed by FailNow().
-func (t *T) Fatal(args ...interface{}) {
-       t.common.Fatal(args...)
-}
-
-// Fatalf is equivalent to Logf() followed by FailNow().
-func (t *T) Fatalf(format string, args ...interface{}) {
-       t.common.Fatalf(format, args...)
-}
index 4208d53a0a47eabba94ea87820cb14db9f6767d1..3be1ec44e697102f6d4dc799c1b3552aff3469fd 100644 (file)
@@ -50,7 +50,9 @@ data, defined in detail below.
                The value of the pipeline must be an array, slice, or map. If
                the value of the pipeline has length zero, nothing is output;
                otherwise, dot is set to the successive elements of the array,
-               slice, or map and T1 is executed.
+               slice, or map and T1 is executed. If the value is a map and the
+               keys are of basic type with a defined order ("comparable"), the
+               elements will be visited in sorted key order.
 
        {{range pipeline}} T1 {{else}} T0 {{end}}
                The value of the pipeline must be an array, slice, or map. If
index acb88afee3684afa1e1bfb4222277ac7664d2285..973189a8a62c97de524bab028c4e730f0af023dd 100644 (file)
@@ -9,6 +9,7 @@ import (
        "io"
        "reflect"
        "runtime"
+       "sort"
        "strings"
        "text/template/parse"
 )
@@ -78,10 +79,14 @@ func (s *state) error(err error) {
 func errRecover(errp *error) {
        e := recover()
        if e != nil {
-               if _, ok := e.(runtime.Error); ok {
+               switch err := e.(type) {
+               case runtime.Error:
+                       panic(e)
+               case error:
+                       *errp = err
+               default:
                        panic(e)
                }
-               *errp = e.(error)
        }
 }
 
@@ -230,7 +235,7 @@ func (s *state) walkRange(dot reflect.Value, r *parse.RangeNode) {
                if val.Len() == 0 {
                        break
                }
-               for _, key := range val.MapKeys() {
+               for _, key := range sortKeys(val.MapKeys()) {
                        oneIteration(key, val.MapIndex(key))
                }
                return
@@ -672,3 +677,44 @@ func (s *state) printValue(n parse.Node, v reflect.Value) {
        }
        fmt.Fprint(s.wr, v.Interface())
 }
+
+// Types to help sort the keys in a map for reproducible output.
+
+type rvs []reflect.Value
+
+func (x rvs) Len() int      { return len(x) }
+func (x rvs) Swap(i, j int) { x[i], x[j] = x[j], x[i] }
+
+type rvInts struct{ rvs }
+
+func (x rvInts) Less(i, j int) bool { return x.rvs[i].Int() < x.rvs[j].Int() }
+
+type rvUints struct{ rvs }
+
+func (x rvUints) Less(i, j int) bool { return x.rvs[i].Uint() < x.rvs[j].Uint() }
+
+type rvFloats struct{ rvs }
+
+func (x rvFloats) Less(i, j int) bool { return x.rvs[i].Float() < x.rvs[j].Float() }
+
+type rvStrings struct{ rvs }
+
+func (x rvStrings) Less(i, j int) bool { return x.rvs[i].String() < x.rvs[j].String() }
+
+// sortKeys sorts (if it can) the slice of reflect.Values, which is a slice of map keys.
+func sortKeys(v []reflect.Value) []reflect.Value {
+       if len(v) <= 1 {
+               return v
+       }
+       switch v[0].Kind() {
+       case reflect.Float32, reflect.Float64:
+               sort.Sort(rvFloats{v})
+       case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
+               sort.Sort(rvInts{v})
+       case reflect.String:
+               sort.Sort(rvStrings{v})
+       case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Uintptr:
+               sort.Sort(rvUints{v})
+       }
+       return v
+}
index e33988b86c0930a1b60971c3ecbe596587f363af..2070cefde73475d567fc307e0c4b3582876bdde6 100644 (file)
@@ -11,7 +11,6 @@ import (
        "fmt"
        "os"
        "reflect"
-       "sort"
        "strings"
        "testing"
 )
@@ -169,18 +168,6 @@ func (t *T) MAdd(a int, b []int) []int {
        return v
 }
 
-// MSort is used to sort map keys for stable output. (Nice trick!)
-func (t *T) MSort(m map[string]int) []string {
-       keys := make([]string, len(m))
-       i := 0
-       for k := range m {
-               keys[i] = k
-               i++
-       }
-       sort.Strings(keys)
-       return keys
-}
-
 // EPERM returns a value and an error according to its argument.
 func (t *T) EPERM(error bool) (bool, error) {
        if error {
@@ -410,9 +397,9 @@ var execTests = []execTest{
        {"range empty else", "{{range .SIEmpty}}-{{.}}-{{else}}EMPTY{{end}}", "EMPTY", tVal, true},
        {"range []bool", "{{range .SB}}-{{.}}-{{end}}", "-true--false-", tVal, true},
        {"range []int method", "{{range .SI | .MAdd .I}}-{{.}}-{{end}}", "-20--21--22-", tVal, true},
-       {"range map", "{{range .MSI | .MSort}}-{{.}}-{{end}}", "-one--three--two-", tVal, true},
+       {"range map", "{{range .MSI}}-{{.}}-{{end}}", "-1--3--2-", tVal, true},
        {"range empty map no else", "{{range .MSIEmpty}}-{{.}}-{{end}}", "", tVal, true},
-       {"range map else", "{{range .MSI | .MSort}}-{{.}}-{{else}}EMPTY{{end}}", "-one--three--two-", tVal, true},
+       {"range map else", "{{range .MSI}}-{{.}}-{{else}}EMPTY{{end}}", "-1--3--2-", tVal, true},
        {"range empty map else", "{{range .MSIEmpty}}-{{.}}-{{else}}EMPTY{{end}}", "EMPTY", tVal, true},
        {"range empty interface", "{{range .Empty3}}-{{.}}-{{else}}EMPTY{{end}}", "-7--8-", tVal, true},
        {"range empty nil", "{{range .Empty0}}-{{.}}-{{end}}", "", tVal, true},
index 082a51a1621ddac36123841b8c5e7e40eb8b6249..bd02b4867201c6ea22b6058cb2ca61a4174300a2 100644 (file)
@@ -283,25 +283,16 @@ var atoiError = errors.New("time: invalid number")
 
 // Duplicates functionality in strconv, but avoids dependency.
 func atoi(s string) (x int, err error) {
-       i := 0
-       if len(s) > 0 && s[0] == '-' {
-               i++
+       neg := false
+       if s != "" && s[0] == '-' {
+               neg = true
+               s = s[1:]
        }
-       if i >= len(s) {
+       x, rem, err := leadingInt(s)
+       if err != nil || rem != "" {
                return 0, atoiError
        }
-       for ; i < len(s); i++ {
-               c := s[i]
-               if c < '0' || c > '9' {
-                       return 0, atoiError
-               }
-               if x >= (1<<31-10)/10 {
-                       // will overflow
-                       return 0, atoiError
-               }
-               x = x*10 + int(c) - '0'
-       }
-       if s[0] == '-' {
+       if neg {
                x = -x
        }
        return x, nil
@@ -344,10 +335,6 @@ func (b *buffer) WriteString(s string) {
        *b = append(*b, s...)
 }
 
-func (b *buffer) WriteByte(c byte) {
-       *b = append(*b, c)
-}
-
 func (b *buffer) String() string {
        return string([]byte(*b))
 }
@@ -893,3 +880,126 @@ func parseNanoseconds(value string, nbytes int) (ns int, rangeErrString string,
        }
        return
 }
+
+var errLeadingInt = errors.New("time: bad [0-9]*") // never printed
+
+// leadingInt consumes the leading [0-9]* from s.
+func leadingInt(s string) (x int, rem string, err error) {
+       i := 0
+       for ; i < len(s); i++ {
+               c := s[i]
+               if c < '0' || c > '9' {
+                       break
+               }
+               if x >= (1<<31-10)/10 {
+                       // overflow
+                       return 0, "", errLeadingInt
+               }
+               x = x*10 + int(c) - '0'
+       }
+       return x, s[i:], nil
+}
+
+var unitMap = map[string]float64{
+       "ns": float64(Nanosecond),
+       "us": float64(Microsecond),
+       "µs": float64(Microsecond), // U+00B5 = micro symbol
+       "μs": float64(Microsecond), // U+03BC = Greek letter mu
+       "ms": float64(Millisecond),
+       "s":  float64(Second),
+       "m":  float64(Minute),
+       "h":  float64(Hour),
+}
+
+// ParseDuration parses a duration string.
+// A duration string is a possibly signed sequence of
+// decimal numbers, each with optional fraction and a unit suffix,
+// such as "300ms", "-1.5h" or "2h45m".
+// Valid time units are "ns", "us" (or "µs"), "ms", "s", "m", "h".
+func ParseDuration(s string) (Duration, error) {
+       // [-+]?([0-9]*(\.[0-9]*)?[a-z]+)+
+       orig := s
+       f := float64(0)
+       neg := false
+
+       // Consume [-+]?
+       if s != "" {
+               c := s[0]
+               if c == '-' || c == '+' {
+                       neg = c == '-'
+                       s = s[1:]
+               }
+       }
+       // Special case: if all that is left is "0", this is zero.
+       if s == "0" {
+               return 0, nil
+       }
+       if s == "" {
+               return 0, errors.New("time: invalid duration " + orig)
+       }
+       for s != "" {
+               g := float64(0) // this element of the sequence
+
+               var x int
+               var err error
+
+               // The next character must be [0-9.]
+               if !(s[0] == '.' || ('0' <= s[0] && s[0] <= '9')) {
+                       return 0, errors.New("time: invalid duration " + orig)
+               }
+               // Consume [0-9]*
+               pl := len(s)
+               x, s, err = leadingInt(s)
+               if err != nil {
+                       return 0, errors.New("time: invalid duration " + orig)
+               }
+               g = float64(x)
+               pre := pl != len(s) // whether we consumed anything before a period
+
+               // Consume (\.[0-9]*)?
+               post := false
+               if s != "" && s[0] == '.' {
+                       s = s[1:]
+                       pl := len(s)
+                       x, s, err = leadingInt(s)
+                       if err != nil {
+                               return 0, errors.New("time: invalid duration " + orig)
+                       }
+                       scale := 1
+                       for n := pl - len(s); n > 0; n-- {
+                               scale *= 10
+                       }
+                       g += float64(x) / float64(scale)
+                       post = pl != len(s)
+               }
+               if !pre && !post {
+                       // no digits (e.g. ".s" or "-.s")
+                       return 0, errors.New("time: invalid duration " + orig)
+               }
+
+               // Consume unit.
+               i := 0
+               for ; i < len(s); i++ {
+                       c := s[i]
+                       if c == '.' || ('0' <= c && c <= '9') {
+                               break
+                       }
+               }
+               if i == 0 {
+                       return 0, errors.New("time: missing unit in duration " + orig)
+               }
+               u := s[:i]
+               s = s[i:]
+               unit, ok := unitMap[u]
+               if !ok {
+                       return 0, errors.New("time: unknown unit " + u + " in duration " + orig)
+               }
+
+               f += g * unit
+       }
+
+       if neg {
+               f = -f
+       }
+       return Duration(f), nil
+}
index b4680db2387e9dd008326949e15ad5aa4511d3eb..27820b0eaa76ffbb258c5429d957e03b2604929a 100644 (file)
@@ -41,7 +41,7 @@ func (t *Timer) Stop() (ok bool) {
 }
 
 // NewTimer creates a new Timer that will send
-// the current time on its channel after at least ns nanoseconds.
+// the current time on its channel after at least duration d.
 func NewTimer(d Duration) *Timer {
        c := make(chan Time, 1)
        t := &Timer{
@@ -70,7 +70,7 @@ func sendTime(now int64, c interface{}) {
 
 // After waits for the duration to elapse and then sends the current time
 // on the returned channel.
-// It is equivalent to NewTimer(ns).C.
+// It is equivalent to NewTimer(d).C.
 func After(d Duration) <-chan Time {
        return NewTimer(d).C
 }
index 4440c2207b33715ce54d087051b10ba5547c6121..8c6b9bc3b2a9a84aedb03d8015f5b51983194744 100644 (file)
@@ -14,7 +14,7 @@ type Ticker struct {
 }
 
 // NewTicker returns a new Ticker containing a channel that will send the
-// time, in nanoseconds, with a period specified by the duration argument.
+// time with a period specified by the duration argument.
 // It adjusts the intervals or drops ticks to make up for slow receivers.
 // The duration d must be greater than zero; if not, NewTicker will panic.
 func NewTicker(d Duration) *Ticker {
index 484ae4266a31e3d0a190cdde724f327350c28f29..cdc1c39c5f535e62380d35a167fd8e8c88b5fcc9 100644 (file)
@@ -8,6 +8,7 @@ import (
        "bytes"
        "encoding/gob"
        "encoding/json"
+       "math/rand"
        "strconv"
        "strings"
        "testing"
@@ -816,6 +817,82 @@ func TestNotJSONEncodableTime(t *testing.T) {
        }
 }
 
+var parseDurationTests = []struct {
+       in   string
+       ok   bool
+       want Duration
+}{
+       // simple
+       {"0", true, 0},
+       {"5s", true, 5 * Second},
+       {"30s", true, 30 * Second},
+       {"1478s", true, 1478 * Second},
+       // sign
+       {"-5s", true, -5 * Second},
+       {"+5s", true, 5 * Second},
+       {"-0", true, 0},
+       {"+0", true, 0},
+       // decimal
+       {"5.0s", true, 5 * Second},
+       {"5.6s", true, 5*Second + 600*Millisecond},
+       {"5.s", true, 5 * Second},
+       {".5s", true, 500 * Millisecond},
+       {"1.0s", true, 1 * Second},
+       {"1.00s", true, 1 * Second},
+       {"1.004s", true, 1*Second + 4*Millisecond},
+       {"1.0040s", true, 1*Second + 4*Millisecond},
+       {"100.00100s", true, 100*Second + 1*Millisecond},
+       // different units
+       {"10ns", true, 10 * Nanosecond},
+       {"11us", true, 11 * Microsecond},
+       {"12µs", true, 12 * Microsecond}, // U+00B5
+       {"12μs", true, 12 * Microsecond}, // U+03BC
+       {"13ms", true, 13 * Millisecond},
+       {"14s", true, 14 * Second},
+       {"15m", true, 15 * Minute},
+       {"16h", true, 16 * Hour},
+       // composite durations
+       {"3h30m", true, 3*Hour + 30*Minute},
+       {"10.5s4m", true, 4*Minute + 10*Second + 500*Millisecond},
+       {"-2m3.4s", true, -(2*Minute + 3*Second + 400*Millisecond)},
+       {"1h2m3s4ms5us6ns", true, 1*Hour + 2*Minute + 3*Second + 4*Millisecond + 5*Microsecond + 6*Nanosecond},
+       {"39h9m14.425s", true, 39*Hour + 9*Minute + 14*Second + 425*Millisecond},
+
+       // errors
+       {"", false, 0},
+       {"3", false, 0},
+       {"-", false, 0},
+       {"s", false, 0},
+       {".", false, 0},
+       {"-.", false, 0},
+       {".s", false, 0},
+       {"+.s", false, 0},
+}
+
+func TestParseDuration(t *testing.T) {
+       for _, tc := range parseDurationTests {
+               d, err := ParseDuration(tc.in)
+               if tc.ok && (err != nil || d != tc.want) {
+                       t.Errorf("ParseDuration(%q) = %v, %v, want %v, nil", tc.in, d, err, tc.want)
+               } else if !tc.ok && err == nil {
+                       t.Errorf("ParseDuration(%q) = _, nil, want _, non-nil", tc.in)
+               }
+       }
+}
+
+func TestParseDurationRoundTrip(t *testing.T) {
+       for i := 0; i < 100; i++ {
+               // Resolutions finer than milliseconds will result in
+               // imprecise round-trips.
+               d0 := Duration(rand.Int31()) * Millisecond
+               s := d0.String()
+               d1, err := ParseDuration(s)
+               if err != nil || d0 != d1 {
+                       t.Errorf("round-trip failed: %d => %q => %d, %v", d0, s, d1, err)
+               }
+       }
+}
+
 func BenchmarkNow(b *testing.B) {
        for i := 0; i < b.N; i++ {
                Now()
index a8c70f07c6f13b88a6f8f417f305763404fe6ff9..80474566eeb5d218508bfa0788c729776f31b143 100755 (executable)
@@ -55,6 +55,9 @@ cat > sysinfo.c <<EOF
 #if defined(HAVE_SYS_MMAN_H)
 #include <sys/mman.h>
 #endif
+#if defined(HAVE_SYS_PRCTL_H)
+#include <sys/prctl.h>
+#endif
 #if defined(HAVE_SYS_PTRACE_H)
 #include <sys/ptrace.h>
 #endif
@@ -210,6 +213,10 @@ if ! grep '^const EPOLL_CLOEXEC' ${OUT} >/dev/null 2>&1; then
   echo "const EPOLL_CLOEXEC = 02000000" >> ${OUT}
 fi
 
+# Prctl constants.
+grep '^const _PR_' gen-sysinfo.go |
+  sed -e 's/^\(const \)_\(PR_[^= ]*\)\(.*\)$/\1\2 = _\2/' >> ${OUT}
+
 # Ptrace constants.
 grep '^const _PTRACE' gen-sysinfo.go |
   sed -e 's/^\(const \)_\(PTRACE[^= ]*\)\(.*\)$/\1\2 = _\2/' >> ${OUT}
@@ -505,6 +512,26 @@ if ! grep 'type IPv6Mreq ' ${OUT} >/dev/null 2>&1; then
   echo 'type IPv6Mreq struct { Multiaddr [16]byte; Interface uint32; }' >> ${OUT}
 fi
 
+# The size of the ipv6_mreq struct.
+echo 'var SizeofIPv6Mreq = int(unsafe.Sizeof(IPv6Mreq{}))' >> ${OUT}
+
+# The ip_mreqn struct.
+grep '^type _ip_mreqn ' gen-sysinfo.go | \
+    sed -e 's/_ip_mreqn/IPMreqn/' \
+      -e 's/imr_multiaddr/Multiaddr/' \
+      -e 's/imr_address/Address/' \
+      -e 's/imr_ifindex/Ifindex/' \
+      -e 's/_in_addr/[4]byte/g' \
+    >> ${OUT}
+
+# We need IPMreq to compile the net package.
+if ! grep 'type IPMreqn ' ${OUT} >/dev/null 2>&1; then
+  echo 'type IPMreqn struct { Multiaddr [4]byte; Interface [4]byte; Ifindex int32 }' >> ${OUT}
+fi
+
+# The size of the ip_mreqn struct.
+echo 'var SizeofIPMreqn = int(unsafe.Sizeof(IPMreqn{}))' >> ${OUT}
+
 # Try to guess the type to use for fd_set.
 fd_set=`grep '^type _fd_set ' gen-sysinfo.go || true`
 fds_bits_type="_C_long"
index f5321856eacd947ea27f3e12ed9dfb302aa55578..9ad9eda8350e0454f68e4a1eaf6b54614b3c0e47 100644 (file)
@@ -129,7 +129,7 @@ __go_free(void *v)
        if(v == nil)
                return;
        
-       // If you change this also change mgc0.c:/^sweepspan,
+       // If you change this also change mgc0.c:/^sweep,
        // which has a copy of the guts of free.
 
        m = runtime_m();
index da0c0f85766ae83fa084bf8dbc93f25f68d25a0e..aa7d9ff3ae256ac2c00d76e1ee6a7f9fc5ed980b 100644 (file)
@@ -123,10 +123,9 @@ enum
 
        // Max number of threads to run garbage collection.
        // 2, 3, and 4 are all plausible maximums depending
-       // on the hardware details of the machine.  The second
-       // proc is the one that helps the most (after the first),
-       // so start with just 2 for now.
-       MaxGcproc = 2,
+       // on the hardware details of the machine.  The garbage
+       // collector scales well to 4 cpus.
+       MaxGcproc = 4,
 };
 
 // A generic linked list of blocks.  (Typically the block is bigger than sizeof(MLink).)
index c4ab1454c5bf8ec5e223220d682b3d1de1c2e7d8..26633ab1f1879224d04d198cc3395c5f87073a0b 100644 (file)
@@ -62,9 +62,6 @@ enum {
 #define bitMask (bitBlockBoundary | bitAllocated | bitMarked | bitSpecial)
 
 // TODO: Make these per-M.
-static uint64 nlookup;
-static uint64 nsizelookup;
-static uint64 naddrlookup;
 static uint64 nhandoff;
 
 static int32 gctrace;
@@ -218,8 +215,6 @@ scanblock(byte *b, int64 n)
 
                        // Otherwise consult span table to find beginning.
                        // (Manually inlined copy of MHeap_LookupMaybe.)
-                       nlookup++;
-                       naddrlookup++;
                        k = (uintptr)obj>>PageShift;
                        x = k;
                        if(sizeof(void*) == 8)
@@ -307,49 +302,8 @@ scanblock(byte *b, int64 n)
                b = *--wp;
                nobj--;
 
-               // Figure out n = size of b.  Start by loading bits for b.
-               off = (uintptr*)b - (uintptr*)arena_start;
-               bitp = (uintptr*)arena_start - off/wordsPerBitmapWord - 1;
-               shift = off % wordsPerBitmapWord;
-               xbits = *bitp;
-               bits = xbits >> shift;
-
-               // Might be small; look for nearby block boundary.
-               // A block boundary is marked by either bitBlockBoundary
-               // or bitAllocated being set (see notes near their definition).
-               enum {
-                       boundary = bitBlockBoundary|bitAllocated
-               };
-               // Look for a block boundary both after and before b
-               // in the same bitmap word.
-               //
-               // A block boundary j words after b is indicated by
-               //      bits>>j & boundary
-               // assuming shift+j < bitShift.  (If shift+j >= bitShift then
-               // we'll be bleeding other bit types like bitMarked into our test.)
-               // Instead of inserting the conditional shift+j < bitShift into the loop,
-               // we can let j range from 1 to bitShift as long as we first
-               // apply a mask to keep only the bits corresponding
-               // to shift+j < bitShift aka j < bitShift-shift.
-               bits &= (boundary<<(bitShift-shift)) - boundary;
-
-               // A block boundary j words before b is indicated by
-               //      xbits>>(shift-j) & boundary
-               // (assuming shift >= j).  There is no cleverness here
-               // avoid the test, because when j gets too large the shift
-               // turns negative, which is undefined in C.
-
-               for(j=1; j<bitShift; j++) {
-                       if(((bits>>j)&boundary) != 0 || (shift>=j && ((xbits>>(shift-j))&boundary) != 0)) {
-                               n = j*PtrSize;
-                               goto scan;
-                       }
-               }
-
-               // Fall back to asking span about size class.
+               // Ask span about size class.
                // (Manually inlined copy of MHeap_Lookup.)
-               nlookup++;
-               nsizelookup++;
                x = (uintptr)b>>PageShift;
                if(sizeof(void*) == 8)
                        x -= (uintptr)arena_start>>PageShift;
@@ -358,7 +312,6 @@ scanblock(byte *b, int64 n)
                        n = s->npages<<PageShift;
                else
                        n = runtime_class_to_size[s->sizeclass];
-       scan:;
        }
 }
 
@@ -1018,9 +971,6 @@ runtime_gc(int32 force)
        }
 
        t0 = runtime_nanotime();
-       nlookup = 0;
-       nsizelookup = 0;
-       naddrlookup = 0;
        nhandoff = 0;
 
        m->gcing = 1;
@@ -1085,11 +1035,11 @@ runtime_gc(int32 force)
                runtime_printf("pause %llu\n", (unsigned long long)t3-t0);
 
        if(gctrace) {
-               runtime_printf("gc%d: %llu+%llu+%llu ms %llu -> %llu MB %llu -> %llu (%llu-%llu) objects %llu pointer lookups (%llu size, %llu addr) %llu handoff\n",
-                       mstats.numgc, (unsigned long long)(t1-t0)/1000000, (unsigned long long)(t2-t1)/1000000, (unsigned long long)(t3-t2)/1000000,
+               runtime_printf("gc%d(%d): %llu+%llu+%llu ms %llu -> %llu MB %llu -> %llu (%llu-%llu) objects %llu handoff\n",
+                       mstats.numgc, work.nproc, (unsigned long long)(t1-t0)/1000000, (unsigned long long)(t2-t1)/1000000, (unsigned long long)(t3-t2)/1000000,
                        (unsigned long long)heap0>>20, (unsigned long long)heap1>>20, (unsigned long long)obj0, (unsigned long long)obj1,
-                       (unsigned long long)mstats.nmalloc, (unsigned long long)mstats.nfree,
-                       (unsigned long long)nlookup, (unsigned long long)nsizelookup, (unsigned long long)naddrlookup, (unsigned long long) nhandoff);
+                       (unsigned long long) mstats.nmalloc, (unsigned long long)mstats.nfree,
+                       (unsigned long long) nhandoff);
        }
 
        runtime_semrelease(&gcsema);
index ec96f5b615f711247bf2107cb52e8b3dc8df828c..922fa20448d2acc20b9f0fdccb0d80e032479df1 100644 (file)
@@ -115,7 +115,7 @@ runtime_goargs(void)
 }
 
 void
-runtime_goenvs(void)
+runtime_goenvs_unix(void)
 {
        String *s;
        int32 i, n;
@@ -183,3 +183,22 @@ runtime_fastrand1(void)
        m->fastrand = x;
        return x;
 }
+
+struct funcline_go_return
+{
+  String retfile;
+  int32 retline;
+};
+
+struct funcline_go_return
+runtime_funcline_go(void *f, uintptr targetpc)
+  __asm__("libgo_runtime.runtime.funcline_go");
+
+struct funcline_go_return
+runtime_funcline_go(void *f __attribute__((unused)),
+                   uintptr targetpc __attribute__((unused)))
+{
+  struct funcline_go_return ret;
+  runtime_memclr(&ret, sizeof ret);
+  return ret;
+}
index 94113b8db8397fd497e8dcecf94748d26f2f4086..253c49b21f376773392e42868c798c7970477dc6 100644 (file)
@@ -266,6 +266,7 @@ void        runtime_args(int32, byte**);
 void   runtime_osinit();
 void   runtime_goargs(void);
 void   runtime_goenvs(void);
+void   runtime_goenvs_unix(void);
 void   runtime_throw(const char*) __attribute__ ((noreturn));
 void   runtime_panicstring(const char*) __attribute__ ((noreturn));
 void*  runtime_mal(uintptr);
index 4cd98041717f3137352e74a42a6ec9d4ec5c000c..fd8918ed57733d3ba4e154d0cf8d7327f861da4b 100644 (file)
@@ -8,3 +8,7 @@ package runtime
 func GOMAXPROCS(n int32) (ret int32) {
        ret = runtime_gomaxprocsfunc(n);
 }
+
+func NumCPU() (ret int32) {
+       ret = runtime_ncpu;
+}
index a0ee36006500c0ddc1739fff603f613dca02c443..8dd5fc4b481150a62909a86ba3c452fc9284f8b2 100644 (file)
@@ -103,3 +103,9 @@ runtime_osinit(void)
 {
        runtime_ncpu = getproccount();
 }
+
+void
+runtime_goenvs(void)
+{
+       runtime_goenvs_unix();
+}
index b414b160ed13fe7b9c5a4bdb954a8a1443490f4a..3511c52a0716c316b6a14090bd7e2bf01da02778 100755 (executable)
@@ -273,8 +273,18 @@ esac
 
 # Split $gofiles into external gofiles (those in *_test packages)
 # and internal ones (those in the main package).
-xgofiles=$(echo $(grep '^package[      ]' $gofiles /dev/null | grep ':.*_test' | sed 's/:.*//'))
-gofiles=$(echo $(grep '^package[       ]' $gofiles /dev/null | grep -v ':.*_test' | sed 's/:.*//'))
+for f in $gofiles; do
+    package=`grep '^package[   ]' $f | sed 1q`
+    case "$package" in
+    *_test)
+       xgofiles="$xgofiles $f"
+       ;;
+    *)
+       ngofiles="$ngofiles $f"
+       ;;
+    esac
+done
+gofiles=$ngofiles
 
 # External $O file
 xofile=""
@@ -413,9 +423,9 @@ xno)
        ${GL} *.o ${GOLIBS}
 
        if test "$trace" = "true"; then
-           echo ./a.out -test.short -test.timeout=$timeout "$@"
+           echo ./a.out -test.short -test.timeout=${timeout}s "$@"
        fi
-       ./a.out -test.short -test.timeout=$timeout "$@" &
+       ./a.out -test.short -test.timeout=${timeout}s "$@" &
        pid=$!
        (sleep `expr $timeout + 10`
            echo > gotest-timeout
This page took 0.692356 seconds and 5 git commands to generate.