Source file src/net/http/httptest/server.go
1
2
3
4
5
6
7 package httptest
8
9 import (
10 "crypto/tls"
11 "crypto/x509"
12 "flag"
13 "fmt"
14 "log"
15 "net"
16 "net/http"
17 "net/http/internal"
18 "os"
19 "strings"
20 "sync"
21 "time"
22 )
23
24
25
// A Server is an HTTP server created by NewServer, NewUnstartedServer or
// NewTLSServer. It listens on a loopback address with a system-chosen port,
// or on the -httptest.serve address when that flag is set.
type Server struct {
	// URL is the base URL of the running server: "http://" or "https://"
	// followed by the listener's address. It is empty until Start or
	// StartTLS runs, and both methods panic if it is already set.
	URL string

	// Listener is where the server accepts connections. StartTLS replaces
	// it with a TLS listener wrapping the original.
	Listener net.Listener

	// TLS is the optional TLS configuration. StartTLS clones it (or starts
	// from an empty tls.Config), fills in defaults, and stores the clone
	// back in this field.
	TLS *tls.Config

	// Config is the underlying http.Server. It may be adjusted after
	// NewUnstartedServer and before Start/StartTLS; those methods install a
	// ConnState hook that chains to any hook already present.
	Config *http.Server

	// certificate is the parsed leaf certificate; only StartTLS sets it.
	certificate *x509.Certificate

	// wg counts the goroutine running Serve plus every connection that is
	// currently tracked in conns. Close blocks until it drops to zero.
	wg sync.WaitGroup

	// mu guards closed and conns.
	mu     sync.Mutex
	closed bool                        // set once by Close
	conns  map[net.Conn]http.ConnState // last known state per connection; nil until first use

	// client is built by Start or StartTLS and handed out by Client.
	client *http.Client
}
54
55 func newLocalListener() net.Listener {
56 if serveFlag != "" {
57 l, err := net.Listen("tcp", serveFlag)
58 if err != nil {
59 panic(fmt.Sprintf("httptest: failed to listen on %v: %v", serveFlag, err))
60 }
61 return l
62 }
63 l, err := net.Listen("tcp", "127.0.0.1:0")
64 if err != nil {
65 if l, err = net.Listen("tcp6", "[::1]:0"); err != nil {
66 panic(fmt.Sprintf("httptest: failed to listen on a port: %v", err))
67 }
68 }
69 return l
70 }
71
72
73
74
75
76
77
78
79 var serveFlag string
80
81 func init() {
82 if strSliceContainsPrefix(os.Args, "-httptest.serve=") || strSliceContainsPrefix(os.Args, "--httptest.serve=") {
83 flag.StringVar(&serveFlag, "httptest.serve", "", "if non-empty, httptest.NewServer serves on this address and blocks.")
84 }
85 }
86
// strSliceContainsPrefix reports whether at least one element of v begins
// with pre. It returns false for an empty or nil slice.
func strSliceContainsPrefix(v []string, pre string) bool {
	for i := 0; i < len(v); i++ {
		if strings.HasPrefix(v[i], pre) {
			return true
		}
	}
	return false
}
95
96
97
98 func NewServer(handler http.Handler) *Server {
99 ts := NewUnstartedServer(handler)
100 ts.Start()
101 return ts
102 }
103
104
105
106
107
108
109
110 func NewUnstartedServer(handler http.Handler) *Server {
111 return &Server{
112 Listener: newLocalListener(),
113 Config: &http.Server{Handler: handler},
114 }
115 }
116
117
118 func (s *Server) Start() {
119 if s.URL != "" {
120 panic("Server already started")
121 }
122 if s.client == nil {
123 s.client = &http.Client{Transport: &http.Transport{}}
124 }
125 s.URL = "http://" + s.Listener.Addr().String()
126 s.wrap()
127 s.goServe()
128 if serveFlag != "" {
129 fmt.Fprintln(os.Stderr, "httptest: serving on", s.URL)
130 select {}
131 }
132 }
133
134
135 func (s *Server) StartTLS() {
136 if s.URL != "" {
137 panic("Server already started")
138 }
139 if s.client == nil {
140 s.client = &http.Client{Transport: &http.Transport{}}
141 }
142 cert, err := tls.X509KeyPair(internal.LocalhostCert, internal.LocalhostKey)
143 if err != nil {
144 panic(fmt.Sprintf("httptest: NewTLSServer: %v", err))
145 }
146
147 existingConfig := s.TLS
148 if existingConfig != nil {
149 s.TLS = existingConfig.Clone()
150 } else {
151 s.TLS = new(tls.Config)
152 }
153 if s.TLS.NextProtos == nil {
154 s.TLS.NextProtos = []string{"http/1.1"}
155 }
156 if len(s.TLS.Certificates) == 0 {
157 s.TLS.Certificates = []tls.Certificate{cert}
158 }
159 s.certificate, err = x509.ParseCertificate(s.TLS.Certificates[0].Certificate[0])
160 if err != nil {
161 panic(fmt.Sprintf("httptest: NewTLSServer: %v", err))
162 }
163 certpool := x509.NewCertPool()
164 certpool.AddCert(s.certificate)
165 s.client.Transport = &http.Transport{
166 TLSClientConfig: &tls.Config{
167 RootCAs: certpool,
168 },
169 }
170 s.Listener = tls.NewListener(s.Listener, s.TLS)
171 s.URL = "https://" + s.Listener.Addr().String()
172 s.wrap()
173 s.goServe()
174 }
175
176
177
178 func NewTLSServer(handler http.Handler) *Server {
179 ts := NewUnstartedServer(handler)
180 ts.StartTLS()
181 return ts
182 }
183
// closeIdleTransport is the single method Close needs from a RoundTripper in
// order to drop pooled idle client connections. *http.Transport implements
// it; Close checks for it with a type assertion.
type closeIdleTransport interface {
	CloseIdleConnections()
}
187
188
189
190 func (s *Server) Close() {
191 s.mu.Lock()
192 if !s.closed {
193 s.closed = true
194 s.Listener.Close()
195 s.Config.SetKeepAlivesEnabled(false)
196 for c, st := range s.conns {
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215 if st == http.StateIdle || st == http.StateNew {
216 s.closeConn(c)
217 }
218 }
219
220 t := time.AfterFunc(5*time.Second, s.logCloseHangDebugInfo)
221 defer t.Stop()
222 }
223 s.mu.Unlock()
224
225
226
227
228 if t, ok := http.DefaultTransport.(closeIdleTransport); ok {
229 t.CloseIdleConnections()
230 }
231
232
233 if s.client != nil {
234 if t, ok := s.client.Transport.(closeIdleTransport); ok {
235 t.CloseIdleConnections()
236 }
237 }
238
239 s.wg.Wait()
240 }
241
242 func (s *Server) logCloseHangDebugInfo() {
243 s.mu.Lock()
244 defer s.mu.Unlock()
245 var buf strings.Builder
246 buf.WriteString("httptest.Server blocked in Close after 5 seconds, waiting for connections:\n")
247 for c, st := range s.conns {
248 fmt.Fprintf(&buf, " %T %p %v in state %v\n", c, c, c.RemoteAddr(), st)
249 }
250 log.Print(buf.String())
251 }
252
253
254 func (s *Server) CloseClientConnections() {
255 s.mu.Lock()
256 nconn := len(s.conns)
257 ch := make(chan struct{}, nconn)
258 for c := range s.conns {
259 go s.closeConnChan(c, ch)
260 }
261 s.mu.Unlock()
262
263
264
265
266
267
268
269 timer := time.NewTimer(5 * time.Second)
270 defer timer.Stop()
271 for i := 0; i < nconn; i++ {
272 select {
273 case <-ch:
274 case <-timer.C:
275
276 return
277 }
278 }
279 }
280
281
282
283 func (s *Server) Certificate() *x509.Certificate {
284 return s.certificate
285 }
286
287
288
289
290 func (s *Server) Client() *http.Client {
291 return s.client
292 }
293
294 func (s *Server) goServe() {
295 s.wg.Add(1)
296 go func() {
297 defer s.wg.Done()
298 s.Config.Serve(s.Listener)
299 }()
300 }
301
302
303
304 func (s *Server) wrap() {
305 oldHook := s.Config.ConnState
306 s.Config.ConnState = func(c net.Conn, cs http.ConnState) {
307 s.mu.Lock()
308 defer s.mu.Unlock()
309 switch cs {
310 case http.StateNew:
311 s.wg.Add(1)
312 if _, exists := s.conns[c]; exists {
313 panic("invalid state transition")
314 }
315 if s.conns == nil {
316 s.conns = make(map[net.Conn]http.ConnState)
317 }
318 s.conns[c] = cs
319 if s.closed {
320
321
322
323
324 s.closeConn(c)
325 }
326 case http.StateActive:
327 if oldState, ok := s.conns[c]; ok {
328 if oldState != http.StateNew && oldState != http.StateIdle {
329 panic("invalid state transition")
330 }
331 s.conns[c] = cs
332 }
333 case http.StateIdle:
334 if oldState, ok := s.conns[c]; ok {
335 if oldState != http.StateActive {
336 panic("invalid state transition")
337 }
338 s.conns[c] = cs
339 }
340 if s.closed {
341 s.closeConn(c)
342 }
343 case http.StateHijacked, http.StateClosed:
344 s.forgetConn(c)
345 }
346 if oldHook != nil {
347 oldHook(c, cs)
348 }
349 }
350 }
351
352
353
354 func (s *Server) closeConn(c net.Conn) { s.closeConnChan(c, nil) }
355
356
357
358 func (s *Server) closeConnChan(c net.Conn, done chan<- struct{}) {
359 c.Close()
360 if done != nil {
361 done <- struct{}{}
362 }
363 }
364
365
366
367
368 func (s *Server) forgetConn(c net.Conn) {
369 if _, ok := s.conns[c]; ok {
370 delete(s.conns, c)
371 s.wg.Done()
372 }
373 }
374
View as plain text