...

Source file src/net/http/httptest/server.go

     1	// Copyright 2011 The Go Authors. All rights reserved.
     2	// Use of this source code is governed by a BSD-style
     3	// license that can be found in the LICENSE file.
     4	
     5	// Implementation of Server
     6	
     7	package httptest
     8	
     9	import (
    10		"crypto/tls"
    11		"crypto/x509"
    12		"flag"
    13		"fmt"
    14		"log"
    15		"net"
    16		"net/http"
    17		"net/http/internal"
    18		"os"
    19		"strings"
    20		"sync"
    21		"time"
    22	)
    23	
    24	// A Server is an HTTP server listening on a system-chosen port on the
    25	// local loopback interface, for use in end-to-end HTTP tests.
    26	type Server struct {
    27		URL      string // base URL of form http://ipaddr:port with no trailing slash
    28		Listener net.Listener
    29	
    30		// TLS is the optional TLS configuration, populated with a new config
    31		// after TLS is started. If set on an unstarted server before StartTLS
    32		// is called, existing fields are copied into the new config.
    33		TLS *tls.Config
    34	
    35		// Config may be changed after calling NewUnstartedServer and
    36		// before Start or StartTLS.
    37		Config *http.Server
    38	
    39		// certificate is a parsed version of the TLS config certificate, if present.
    40		certificate *x509.Certificate
    41	
    42		// wg counts the number of outstanding HTTP requests on this server.
    43		// Close blocks until all requests are finished.
    44		wg sync.WaitGroup
    45	
    46		mu     sync.Mutex // guards closed and conns
    47		closed bool
    48		conns  map[net.Conn]http.ConnState // except terminal states
    49	
    50		// client is configured for use with the server.
    51		// Its transport is automatically closed when Close is called.
    52		client *http.Client
    53	}
    54	
    55	func newLocalListener() net.Listener {
    56		if serveFlag != "" {
    57			l, err := net.Listen("tcp", serveFlag)
    58			if err != nil {
    59				panic(fmt.Sprintf("httptest: failed to listen on %v: %v", serveFlag, err))
    60			}
    61			return l
    62		}
    63		l, err := net.Listen("tcp", "127.0.0.1:0")
    64		if err != nil {
    65			if l, err = net.Listen("tcp6", "[::1]:0"); err != nil {
    66				panic(fmt.Sprintf("httptest: failed to listen on a port: %v", err))
    67			}
    68		}
    69		return l
    70	}
    71	
    72	// When debugging a particular http server-based test,
    73	// this flag lets you run
    74	//	go test -run=BrokenTest -httptest.serve=127.0.0.1:8000
    75	// to start the broken server so you can interact with it manually.
    76	// We only register this flag if it looks like the caller knows about it
    77	// and is trying to use it as we don't want to pollute flags and this
    78	// isn't really part of our API. Don't depend on this.
    79	var serveFlag string
    80	
    81	func init() {
    82		if strSliceContainsPrefix(os.Args, "-httptest.serve=") || strSliceContainsPrefix(os.Args, "--httptest.serve=") {
    83			flag.StringVar(&serveFlag, "httptest.serve", "", "if non-empty, httptest.NewServer serves on this address and blocks.")
    84		}
    85	}
    86	
    87	func strSliceContainsPrefix(v []string, pre string) bool {
    88		for _, s := range v {
    89			if strings.HasPrefix(s, pre) {
    90				return true
    91			}
    92		}
    93		return false
    94	}
    95	
    96	// NewServer starts and returns a new Server.
    97	// The caller should call Close when finished, to shut it down.
    98	func NewServer(handler http.Handler) *Server {
    99		ts := NewUnstartedServer(handler)
   100		ts.Start()
   101		return ts
   102	}
   103	
   104	// NewUnstartedServer returns a new Server but doesn't start it.
   105	//
   106	// After changing its configuration, the caller should call Start or
   107	// StartTLS.
   108	//
   109	// The caller should call Close when finished, to shut it down.
   110	func NewUnstartedServer(handler http.Handler) *Server {
   111		return &Server{
   112			Listener: newLocalListener(),
   113			Config:   &http.Server{Handler: handler},
   114		}
   115	}
   116	
   117	// Start starts a server from NewUnstartedServer.
   118	func (s *Server) Start() {
   119		if s.URL != "" {
   120			panic("Server already started")
   121		}
   122		if s.client == nil {
   123			s.client = &http.Client{Transport: &http.Transport{}}
   124		}
   125		s.URL = "http://" + s.Listener.Addr().String()
   126		s.wrap()
   127		s.goServe()
   128		if serveFlag != "" {
   129			fmt.Fprintln(os.Stderr, "httptest: serving on", s.URL)
   130			select {}
   131		}
   132	}
   133	
   134	// StartTLS starts TLS on a server from NewUnstartedServer.
   135	func (s *Server) StartTLS() {
   136		if s.URL != "" {
   137			panic("Server already started")
   138		}
   139		if s.client == nil {
   140			s.client = &http.Client{Transport: &http.Transport{}}
   141		}
   142		cert, err := tls.X509KeyPair(internal.LocalhostCert, internal.LocalhostKey)
   143		if err != nil {
   144			panic(fmt.Sprintf("httptest: NewTLSServer: %v", err))
   145		}
   146	
   147		existingConfig := s.TLS
   148		if existingConfig != nil {
   149			s.TLS = existingConfig.Clone()
   150		} else {
   151			s.TLS = new(tls.Config)
   152		}
   153		if s.TLS.NextProtos == nil {
   154			s.TLS.NextProtos = []string{"http/1.1"}
   155		}
   156		if len(s.TLS.Certificates) == 0 {
   157			s.TLS.Certificates = []tls.Certificate{cert}
   158		}
   159		s.certificate, err = x509.ParseCertificate(s.TLS.Certificates[0].Certificate[0])
   160		if err != nil {
   161			panic(fmt.Sprintf("httptest: NewTLSServer: %v", err))
   162		}
   163		certpool := x509.NewCertPool()
   164		certpool.AddCert(s.certificate)
   165		s.client.Transport = &http.Transport{
   166			TLSClientConfig: &tls.Config{
   167				RootCAs: certpool,
   168			},
   169		}
   170		s.Listener = tls.NewListener(s.Listener, s.TLS)
   171		s.URL = "https://" + s.Listener.Addr().String()
   172		s.wrap()
   173		s.goServe()
   174	}
   175	
   176	// NewTLSServer starts and returns a new Server using TLS.
   177	// The caller should call Close when finished, to shut it down.
   178	func NewTLSServer(handler http.Handler) *Server {
   179		ts := NewUnstartedServer(handler)
   180		ts.StartTLS()
   181		return ts
   182	}
   183	
   184	type closeIdleTransport interface {
   185		CloseIdleConnections()
   186	}
   187	
   188	// Close shuts down the server and blocks until all outstanding
   189	// requests on this server have completed.
   190	func (s *Server) Close() {
   191		s.mu.Lock()
   192		if !s.closed {
   193			s.closed = true
   194			s.Listener.Close()
   195			s.Config.SetKeepAlivesEnabled(false)
   196			for c, st := range s.conns {
   197				// Force-close any idle connections (those between
   198				// requests) and new connections (those which connected
   199				// but never sent a request). StateNew connections are
   200				// super rare and have only been seen (in
   201				// previously-flaky tests) in the case of
   202				// socket-late-binding races from the http Client
   203				// dialing this server and then getting an idle
   204				// connection before the dial completed. There is thus
   205				// a connected connection in StateNew with no
   206				// associated Request. We only close StateIdle and
   207				// StateNew because they're not doing anything. It's
   208				// possible StateNew is about to do something in a few
   209				// milliseconds, but a previous CL to check again in a
   210				// few milliseconds wasn't liked (early versions of
   211				// https://golang.org/cl/15151) so now we just
   212				// forcefully close StateNew. The docs for Server.Close say
   213				// we wait for "outstanding requests", so we don't close things
   214				// in StateActive.
   215				if st == http.StateIdle || st == http.StateNew {
   216					s.closeConn(c)
   217				}
   218			}
   219			// If this server doesn't shut down in 5 seconds, tell the user why.
   220			t := time.AfterFunc(5*time.Second, s.logCloseHangDebugInfo)
   221			defer t.Stop()
   222		}
   223		s.mu.Unlock()
   224	
   225		// Not part of httptest.Server's correctness, but assume most
   226		// users of httptest.Server will be using the standard
   227		// transport, so help them out and close any idle connections for them.
   228		if t, ok := http.DefaultTransport.(closeIdleTransport); ok {
   229			t.CloseIdleConnections()
   230		}
   231	
   232		// Also close the client idle connections.
   233		if s.client != nil {
   234			if t, ok := s.client.Transport.(closeIdleTransport); ok {
   235				t.CloseIdleConnections()
   236			}
   237		}
   238	
   239		s.wg.Wait()
   240	}
   241	
   242	func (s *Server) logCloseHangDebugInfo() {
   243		s.mu.Lock()
   244		defer s.mu.Unlock()
   245		var buf strings.Builder
   246		buf.WriteString("httptest.Server blocked in Close after 5 seconds, waiting for connections:\n")
   247		for c, st := range s.conns {
   248			fmt.Fprintf(&buf, "  %T %p %v in state %v\n", c, c, c.RemoteAddr(), st)
   249		}
   250		log.Print(buf.String())
   251	}
   252	
   253	// CloseClientConnections closes any open HTTP connections to the test Server.
   254	func (s *Server) CloseClientConnections() {
   255		s.mu.Lock()
   256		nconn := len(s.conns)
   257		ch := make(chan struct{}, nconn)
   258		for c := range s.conns {
   259			go s.closeConnChan(c, ch)
   260		}
   261		s.mu.Unlock()
   262	
   263		// Wait for outstanding closes to finish.
   264		//
   265		// Out of paranoia for making a late change in Go 1.6, we
   266		// bound how long this can wait, since golang.org/issue/14291
   267		// isn't fully understood yet. At least this should only be used
   268		// in tests.
   269		timer := time.NewTimer(5 * time.Second)
   270		defer timer.Stop()
   271		for i := 0; i < nconn; i++ {
   272			select {
   273			case <-ch:
   274			case <-timer.C:
   275				// Too slow. Give up.
   276				return
   277			}
   278		}
   279	}
   280	
   281	// Certificate returns the certificate used by the server, or nil if
   282	// the server doesn't use TLS.
   283	func (s *Server) Certificate() *x509.Certificate {
   284		return s.certificate
   285	}
   286	
   287	// Client returns an HTTP client configured for making requests to the server.
   288	// It is configured to trust the server's TLS test certificate and will
   289	// close its idle connections on Server.Close.
   290	func (s *Server) Client() *http.Client {
   291		return s.client
   292	}
   293	
   294	func (s *Server) goServe() {
   295		s.wg.Add(1)
   296		go func() {
   297			defer s.wg.Done()
   298			s.Config.Serve(s.Listener)
   299		}()
   300	}
   301	
   302	// wrap installs the connection state-tracking hook to know which
   303	// connections are idle.
   304	func (s *Server) wrap() {
   305		oldHook := s.Config.ConnState
   306		s.Config.ConnState = func(c net.Conn, cs http.ConnState) {
   307			s.mu.Lock()
   308			defer s.mu.Unlock()
   309			switch cs {
   310			case http.StateNew:
   311				s.wg.Add(1)
   312				if _, exists := s.conns[c]; exists {
   313					panic("invalid state transition")
   314				}
   315				if s.conns == nil {
   316					s.conns = make(map[net.Conn]http.ConnState)
   317				}
   318				s.conns[c] = cs
   319				if s.closed {
   320					// Probably just a socket-late-binding dial from
   321					// the default transport that lost the race (and
   322					// thus this connection is now idle and will
   323					// never be used).
   324					s.closeConn(c)
   325				}
   326			case http.StateActive:
   327				if oldState, ok := s.conns[c]; ok {
   328					if oldState != http.StateNew && oldState != http.StateIdle {
   329						panic("invalid state transition")
   330					}
   331					s.conns[c] = cs
   332				}
   333			case http.StateIdle:
   334				if oldState, ok := s.conns[c]; ok {
   335					if oldState != http.StateActive {
   336						panic("invalid state transition")
   337					}
   338					s.conns[c] = cs
   339				}
   340				if s.closed {
   341					s.closeConn(c)
   342				}
   343			case http.StateHijacked, http.StateClosed:
   344				s.forgetConn(c)
   345			}
   346			if oldHook != nil {
   347				oldHook(c, cs)
   348			}
   349		}
   350	}
   351	
   352	// closeConn closes c.
   353	// s.mu must be held.
   354	func (s *Server) closeConn(c net.Conn) { s.closeConnChan(c, nil) }
   355	
   356	// closeConnChan is like closeConn, but takes an optional channel to receive a value
   357	// when the goroutine closing c is done.
   358	func (s *Server) closeConnChan(c net.Conn, done chan<- struct{}) {
   359		c.Close()
   360		if done != nil {
   361			done <- struct{}{}
   362		}
   363	}
   364	
   365	// forgetConn removes c from the set of tracked conns and decrements it from the
   366	// waitgroup, unless it was previously removed.
   367	// s.mu must be held.
   368	func (s *Server) forgetConn(c net.Conn) {
   369		if _, ok := s.conns[c]; ok {
   370			delete(s.conns, c)
   371			s.wg.Done()
   372		}
   373	}
   374	

View as plain text