// Source file: src/net/http/httputil/reverseproxy.go
1
2
3
4
5
6
7 package httputil
8
9 import (
10 "context"
11 "fmt"
12 "io"
13 "log"
14 "net"
15 "net/http"
16 "net/url"
17 "strings"
18 "sync"
19 "time"
20
21 "golang.org/x/net/http/httpguts"
22 )
23
24
25
26
27 type ReverseProxy struct {
28
29
30
31
32
33
34 Director func(*http.Request)
35
36
37
38 Transport http.RoundTripper
39
40
41
42
43
44
45
46
47
48
49
50 FlushInterval time.Duration
51
52
53
54
55 ErrorLog *log.Logger
56
57
58
59
60 BufferPool BufferPool
61
62
63
64
65
66
67
68
69
70
71 ModifyResponse func(*http.Response) error
72
73
74
75
76
77
78 ErrorHandler func(http.ResponseWriter, *http.Request, error)
79 }
80
81
82
// BufferPool is a source of reusable byte slices for copying HTTP response
// bodies. Get hands out a buffer; Put returns it once the copy is done.
type BufferPool interface {
	// Get returns a buffer to copy with. An empty slice makes the caller
	// allocate its own buffer instead.
	Get() []byte

	// Put gives back a buffer previously obtained from Get.
	Put([]byte)
}
87
// singleJoiningSlash concatenates a and b so that exactly one "/" separates
// them at the join point. It drops one slash when a ends with "/" and b
// starts with "/". It inserts one when neither side supplies it. Other
// slashes are left untouched.
func singleJoiningSlash(a, b string) string {
	endsWithSlash := strings.HasSuffix(a, "/")
	startsWithSlash := strings.HasPrefix(b, "/")

	if endsWithSlash && startsWithSlash {
		// Both sides contribute a slash; keep only a's.
		return a + strings.TrimPrefix(b, "/")
	}
	if endsWithSlash || startsWithSlash {
		// Exactly one slash is already present.
		return a + b
	}
	return a + "/" + b
}
99
100
101
102
103
104
105
106
107 func NewSingleHostReverseProxy(target *url.URL) *ReverseProxy {
108 targetQuery := target.RawQuery
109 director := func(req *http.Request) {
110 req.URL.Scheme = target.Scheme
111 req.URL.Host = target.Host
112 req.URL.Path = singleJoiningSlash(target.Path, req.URL.Path)
113 if targetQuery == "" || req.URL.RawQuery == "" {
114 req.URL.RawQuery = targetQuery + req.URL.RawQuery
115 } else {
116 req.URL.RawQuery = targetQuery + "&" + req.URL.RawQuery
117 }
118 if _, ok := req.Header["User-Agent"]; !ok {
119
120 req.Header.Set("User-Agent", "")
121 }
122 }
123 return &ReverseProxy{Director: director}
124 }
125
// copyHeader appends every value of every key in src to dst. It keeps the
// values dst already holds and canonicalizes keys through Header.Add.
func copyHeader(dst, src http.Header) {
	for key, values := range src {
		for i := range values {
			dst.Add(key, values[i])
		}
	}
}
133
134
135
136
137
138
// hopHeaders lists hop-by-hop headers. They describe a single connection
// rather than the end-to-end message. ServeHTTP strips them from the
// outgoing request (keeping "Te: trailers") and from the backend's
// response. Names are in canonical form, hence "Te" rather than "TE".
var hopHeaders = []string{
	"Connection",
	"Proxy-Connection", // non-standard header
	"Keep-Alive",
	"Proxy-Authenticate",
	"Proxy-Authorization",
	"Te", // canonicalized version of "TE"
	"Trailer",
	"Transfer-Encoding",
	"Upgrade",
}
150
151 func (p *ReverseProxy) defaultErrorHandler(rw http.ResponseWriter, req *http.Request, err error) {
152 p.logf("http: proxy error: %v", err)
153 rw.WriteHeader(http.StatusBadGateway)
154 }
155
156 func (p *ReverseProxy) getErrorHandler() func(http.ResponseWriter, *http.Request, error) {
157 if p.ErrorHandler != nil {
158 return p.ErrorHandler
159 }
160 return p.defaultErrorHandler
161 }
162
163
164
165 func (p *ReverseProxy) modifyResponse(rw http.ResponseWriter, res *http.Response, req *http.Request) bool {
166 if p.ModifyResponse == nil {
167 return true
168 }
169 if err := p.ModifyResponse(res); err != nil {
170 res.Body.Close()
171 p.getErrorHandler()(rw, req, err)
172 return false
173 }
174 return true
175 }
176
177 func (p *ReverseProxy) ServeHTTP(rw http.ResponseWriter, req *http.Request) {
178 transport := p.Transport
179 if transport == nil {
180 transport = http.DefaultTransport
181 }
182
183 ctx := req.Context()
184 if cn, ok := rw.(http.CloseNotifier); ok {
185 var cancel context.CancelFunc
186 ctx, cancel = context.WithCancel(ctx)
187 defer cancel()
188 notifyChan := cn.CloseNotify()
189 go func() {
190 select {
191 case <-notifyChan:
192 cancel()
193 case <-ctx.Done():
194 }
195 }()
196 }
197
198 outreq := req.Clone(ctx)
199 if req.ContentLength == 0 {
200 outreq.Body = nil
201 }
202 if outreq.Header == nil {
203 outreq.Header = make(http.Header)
204 }
205
206 p.Director(outreq)
207 outreq.Close = false
208
209 reqUpType := upgradeType(outreq.Header)
210 removeConnectionHeaders(outreq.Header)
211
212
213
214
215 for _, h := range hopHeaders {
216 hv := outreq.Header.Get(h)
217 if hv == "" {
218 continue
219 }
220 if h == "Te" && hv == "trailers" {
221
222
223
224
225
226
227 continue
228 }
229 outreq.Header.Del(h)
230 }
231
232
233
234 if reqUpType != "" {
235 outreq.Header.Set("Connection", "Upgrade")
236 outreq.Header.Set("Upgrade", reqUpType)
237 }
238
239 if clientIP, _, err := net.SplitHostPort(req.RemoteAddr); err == nil {
240
241
242
243 if prior, ok := outreq.Header["X-Forwarded-For"]; ok {
244 clientIP = strings.Join(prior, ", ") + ", " + clientIP
245 }
246 outreq.Header.Set("X-Forwarded-For", clientIP)
247 }
248
249 res, err := transport.RoundTrip(outreq)
250 if err != nil {
251 p.getErrorHandler()(rw, outreq, err)
252 return
253 }
254
255
256 if res.StatusCode == http.StatusSwitchingProtocols {
257 if !p.modifyResponse(rw, res, outreq) {
258 return
259 }
260 p.handleUpgradeResponse(rw, outreq, res)
261 return
262 }
263
264 removeConnectionHeaders(res.Header)
265
266 for _, h := range hopHeaders {
267 res.Header.Del(h)
268 }
269
270 if !p.modifyResponse(rw, res, outreq) {
271 return
272 }
273
274 copyHeader(rw.Header(), res.Header)
275
276
277
278 announcedTrailers := len(res.Trailer)
279 if announcedTrailers > 0 {
280 trailerKeys := make([]string, 0, len(res.Trailer))
281 for k := range res.Trailer {
282 trailerKeys = append(trailerKeys, k)
283 }
284 rw.Header().Add("Trailer", strings.Join(trailerKeys, ", "))
285 }
286
287 rw.WriteHeader(res.StatusCode)
288
289 err = p.copyResponse(rw, res.Body, p.flushInterval(req, res))
290 if err != nil {
291 defer res.Body.Close()
292
293
294
295 if !shouldPanicOnCopyError(req) {
296 p.logf("suppressing panic for copyResponse error in test; copy error: %v", err)
297 return
298 }
299 panic(http.ErrAbortHandler)
300 }
301 res.Body.Close()
302
303 if len(res.Trailer) > 0 {
304
305
306
307 if fl, ok := rw.(http.Flusher); ok {
308 fl.Flush()
309 }
310 }
311
312 if len(res.Trailer) == announcedTrailers {
313 copyHeader(rw.Header(), res.Trailer)
314 return
315 }
316
317 for k, vv := range res.Trailer {
318 k = http.TrailerPrefix + k
319 for _, v := range vv {
320 rw.Header().Add(k, v)
321 }
322 }
323 }
324
// inOurTests makes shouldPanicOnCopyError always report true.
// NOTE(review): presumably set by this package's own tests (not visible
// here) so they can observe the abort panic outside an http.Server.
var inOurTests bool

// shouldPanicOnCopyError reports whether ServeHTTP should abort with
// panic(http.ErrAbortHandler) after failing to copy a response body. The
// answer is true under this package's tests, or when req's context carries
// http.ServerContextKey (i.e. it was delivered by an http.Server). In every
// other case the error is only logged.
func shouldPanicOnCopyError(req *http.Request) bool {
	if inOurTests {
		return true
	}
	return req.Context().Value(http.ServerContextKey) != nil
}
346
347
348
// removeConnectionHeaders deletes every header named in h's Connection
// values. A single Connection value may list several comma-separated
// names; surrounding whitespace is trimmed and empty entries are skipped.
func removeConnectionHeaders(h http.Header) {
	for _, line := range h["Connection"] {
		for _, name := range strings.Split(line, ",") {
			name = strings.TrimSpace(name)
			if name == "" {
				continue
			}
			h.Del(name)
		}
	}
}
358
359
360
361 func (p *ReverseProxy) flushInterval(req *http.Request, res *http.Response) time.Duration {
362 resCT := res.Header.Get("Content-Type")
363
364
365
366 if resCT == "text/event-stream" {
367 return -1
368 }
369
370
371 return p.FlushInterval
372 }
373
374 func (p *ReverseProxy) copyResponse(dst io.Writer, src io.Reader, flushInterval time.Duration) error {
375 if flushInterval != 0 {
376 if wf, ok := dst.(writeFlusher); ok {
377 mlw := &maxLatencyWriter{
378 dst: wf,
379 latency: flushInterval,
380 }
381 defer mlw.stop()
382
383
384 mlw.flushPending = true
385 mlw.t = time.AfterFunc(flushInterval, mlw.delayedFlush)
386
387 dst = mlw
388 }
389 }
390
391 var buf []byte
392 if p.BufferPool != nil {
393 buf = p.BufferPool.Get()
394 defer p.BufferPool.Put(buf)
395 }
396 _, err := p.copyBuffer(dst, src, buf)
397 return err
398 }
399
400
401
402 func (p *ReverseProxy) copyBuffer(dst io.Writer, src io.Reader, buf []byte) (int64, error) {
403 if len(buf) == 0 {
404 buf = make([]byte, 32*1024)
405 }
406 var written int64
407 for {
408 nr, rerr := src.Read(buf)
409 if rerr != nil && rerr != io.EOF && rerr != context.Canceled {
410 p.logf("httputil: ReverseProxy read error during body copy: %v", rerr)
411 }
412 if nr > 0 {
413 nw, werr := dst.Write(buf[:nr])
414 if nw > 0 {
415 written += int64(nw)
416 }
417 if werr != nil {
418 return written, werr
419 }
420 if nr != nw {
421 return written, io.ErrShortWrite
422 }
423 }
424 if rerr != nil {
425 if rerr == io.EOF {
426 rerr = nil
427 }
428 return written, rerr
429 }
430 }
431 }
432
433 func (p *ReverseProxy) logf(format string, args ...interface{}) {
434 if p.ErrorLog != nil {
435 p.ErrorLog.Printf(format, args...)
436 } else {
437 log.Printf(format, args...)
438 }
439 }
440
// writeFlusher is an io.Writer whose buffered output can also be flushed on
// demand through http.Flusher.
type writeFlusher interface {
	io.Writer
	http.Flusher
}

// maxLatencyWriter wraps a writeFlusher so that written data is flushed no
// later than latency after it is written, or after every Write when latency
// is negative.
type maxLatencyWriter struct {
	dst     writeFlusher
	latency time.Duration // non-zero; negative means flush after every write

	mu           sync.Mutex // guards t, flushPending, and calls into dst
	t            *time.Timer
	flushPending bool
}

// Write forwards p to dst. It then either flushes right away (negative
// latency) or, when no flush is pending, schedules one latency from now.
func (m *maxLatencyWriter) Write(p []byte) (n int, err error) {
	m.mu.Lock()
	defer m.mu.Unlock()

	n, err = m.dst.Write(p)
	switch {
	case m.latency < 0:
		m.dst.Flush()
	case m.flushPending:
		// The already-scheduled flush will cover this data too.
	default:
		if m.t != nil {
			m.t.Reset(m.latency)
		} else {
			m.t = time.AfterFunc(m.latency, m.delayedFlush)
		}
		m.flushPending = true
	}
	return n, err
}

// delayedFlush is the timer callback. It flushes dst only if a flush is
// still pending; stop may have cancelled it after the timer already fired.
func (m *maxLatencyWriter) delayedFlush() {
	m.mu.Lock()
	defer m.mu.Unlock()
	if m.flushPending {
		m.dst.Flush()
		m.flushPending = false
	}
}

// stop cancels any pending flush and stops the timer, if one was created.
func (m *maxLatencyWriter) stop() {
	m.mu.Lock()
	defer m.mu.Unlock()
	m.flushPending = false
	if t := m.t; t != nil {
		t.Stop()
	}
}
493
494 func upgradeType(h http.Header) string {
495 if !httpguts.HeaderValuesContainsToken(h["Connection"], "Upgrade") {
496 return ""
497 }
498 return strings.ToLower(h.Get("Upgrade"))
499 }
500
501 func (p *ReverseProxy) handleUpgradeResponse(rw http.ResponseWriter, req *http.Request, res *http.Response) {
502 reqUpType := upgradeType(req.Header)
503 resUpType := upgradeType(res.Header)
504 if reqUpType != resUpType {
505 p.getErrorHandler()(rw, req, fmt.Errorf("backend tried to switch protocol %q when %q was requested", resUpType, reqUpType))
506 return
507 }
508
509 copyHeader(res.Header, rw.Header())
510
511 hj, ok := rw.(http.Hijacker)
512 if !ok {
513 p.getErrorHandler()(rw, req, fmt.Errorf("can't switch protocols using non-Hijacker ResponseWriter type %T", rw))
514 return
515 }
516 backConn, ok := res.Body.(io.ReadWriteCloser)
517 if !ok {
518 p.getErrorHandler()(rw, req, fmt.Errorf("internal error: 101 switching protocols response with non-writable body"))
519 return
520 }
521 defer backConn.Close()
522 conn, brw, err := hj.Hijack()
523 if err != nil {
524 p.getErrorHandler()(rw, req, fmt.Errorf("Hijack failed on protocol switch: %v", err))
525 return
526 }
527 defer conn.Close()
528 res.Body = nil
529 if err := res.Write(brw); err != nil {
530 p.getErrorHandler()(rw, req, fmt.Errorf("response write: %v", err))
531 return
532 }
533 if err := brw.Flush(); err != nil {
534 p.getErrorHandler()(rw, req, fmt.Errorf("response flush: %v", err))
535 return
536 }
537 errc := make(chan error, 1)
538 spc := switchProtocolCopier{user: conn, backend: backConn}
539 go spc.copyToBackend(errc)
540 go spc.copyFromBackend(errc)
541 <-errc
542 return
543 }
544
545
546
// switchProtocolCopier relays raw bytes between the client ("user") side and
// the backend side of an upgraded (101 Switching Protocols) connection.
type switchProtocolCopier struct {
	user, backend io.ReadWriter
}

// copyFromBackend copies everything read from the backend to the user. It
// sends io.Copy's result error (nil on a clean EOF) on errc.
func (c switchProtocolCopier) copyFromBackend(errc chan<- error) {
	dst, src := c.user, c.backend
	_, err := io.Copy(dst, src)
	errc <- err
}

// copyToBackend copies everything read from the user to the backend. It
// sends io.Copy's result error (nil on a clean EOF) on errc.
func (c switchProtocolCopier) copyToBackend(errc chan<- error) {
	dst, src := c.backend, c.user
	_, err := io.Copy(dst, src)
	errc <- err
}
560
View as plain text