...

Source file src/net/http/httptest/recorder.go

     1	// Copyright 2011 The Go Authors. All rights reserved.
     2	// Use of this source code is governed by a BSD-style
     3	// license that can be found in the LICENSE file.
     4	
     5	package httptest
     6	
     7	import (
     8		"bytes"
     9		"fmt"
    10		"io/ioutil"
    11		"net/http"
    12		"strconv"
    13		"strings"
    14	
    15		"golang.org/x/net/http/httpguts"
    16	)
    17	
// ResponseRecorder is an implementation of http.ResponseWriter that
// records its mutations for later inspection in tests.
//
// NewRecorder returns a recorder with a non-nil HeaderMap and Body and a
// Code of 200. The zero value also works: HeaderMap is allocated on
// demand, and with a nil Body all written data is discarded.
type ResponseRecorder struct {
	// Code is the HTTP response code set by WriteHeader.
	//
	// Note that if a Handler never calls WriteHeader or Write,
	// this might end up being 0, rather than the implicit
	// http.StatusOK. To get the implicit value, use the Result
	// method.
	//
	// Only the first WriteHeader call takes effect. The first Write,
	// WriteString or Flush performs an implicit WriteHeader(200).
	Code int

	// HeaderMap contains the headers explicitly set by the Handler.
	// It is an internal detail.
	//
	// Deprecated: HeaderMap exists for historical compatibility
	// and should not be used. To access the headers returned by a handler,
	// use the Response.Header map as returned by the Result method.
	HeaderMap http.Header

	// Body is the buffer to which the Handler's Write calls are sent.
	// If nil, the Writes are silently discarded.
	Body *bytes.Buffer

	// Flushed is whether the Handler called Flush.
	Flushed bool

	result      *http.Response // cache of Result's return value
	snapHeader  http.Header    // copy of HeaderMap taken when the header is written, or by Result
	wroteHeader bool           // set by WriteHeader; later header writes are no-ops
}
    48	
    49	// NewRecorder returns an initialized ResponseRecorder.
    50	func NewRecorder() *ResponseRecorder {
    51		return &ResponseRecorder{
    52			HeaderMap: make(http.Header),
    53			Body:      new(bytes.Buffer),
    54			Code:      200,
    55		}
    56	}
    57	
// DefaultRemoteAddr is the default remote address, "1.2.3.4", to return in
// RemoteAddr when no explicit remote address has been set.
//
// ResponseRecorder itself has no RemoteAddr field and does not read this
// constant.
const DefaultRemoteAddr = "1.2.3.4"
    61	
    62	// Header implements http.ResponseWriter. It returns the response
    63	// headers to mutate within a handler. To test the headers that were
    64	// written after a handler completes, use the Result method and see
    65	// the returned Response value's Header.
    66	func (rw *ResponseRecorder) Header() http.Header {
    67		m := rw.HeaderMap
    68		if m == nil {
    69			m = make(http.Header)
    70			rw.HeaderMap = m
    71		}
    72		return m
    73	}
    74	
    75	// writeHeader writes a header if it was not written yet and
    76	// detects Content-Type if needed.
    77	//
    78	// bytes or str are the beginning of the response body.
    79	// We pass both to avoid unnecessarily generate garbage
    80	// in rw.WriteString which was created for performance reasons.
    81	// Non-nil bytes win.
    82	func (rw *ResponseRecorder) writeHeader(b []byte, str string) {
    83		if rw.wroteHeader {
    84			return
    85		}
    86		if len(str) > 512 {
    87			str = str[:512]
    88		}
    89	
    90		m := rw.Header()
    91	
    92		_, hasType := m["Content-Type"]
    93		hasTE := m.Get("Transfer-Encoding") != ""
    94		if !hasType && !hasTE {
    95			if b == nil {
    96				b = []byte(str)
    97			}
    98			m.Set("Content-Type", http.DetectContentType(b))
    99		}
   100	
   101		rw.WriteHeader(200)
   102	}
   103	
   104	// Write implements http.ResponseWriter. The data in buf is written to
   105	// rw.Body, if not nil.
   106	func (rw *ResponseRecorder) Write(buf []byte) (int, error) {
   107		rw.writeHeader(buf, "")
   108		if rw.Body != nil {
   109			rw.Body.Write(buf)
   110		}
   111		return len(buf), nil
   112	}
   113	
   114	// WriteString implements io.StringWriter. The data in str is written
   115	// to rw.Body, if not nil.
   116	func (rw *ResponseRecorder) WriteString(str string) (int, error) {
   117		rw.writeHeader(nil, str)
   118		if rw.Body != nil {
   119			rw.Body.WriteString(str)
   120		}
   121		return len(str), nil
   122	}
   123	
   124	// WriteHeader implements http.ResponseWriter.
   125	func (rw *ResponseRecorder) WriteHeader(code int) {
   126		if rw.wroteHeader {
   127			return
   128		}
   129		rw.Code = code
   130		rw.wroteHeader = true
   131		if rw.HeaderMap == nil {
   132			rw.HeaderMap = make(http.Header)
   133		}
   134		rw.snapHeader = rw.HeaderMap.Clone()
   135	}
   136	
   137	// Flush implements http.Flusher. To test whether Flush was
   138	// called, see rw.Flushed.
   139	func (rw *ResponseRecorder) Flush() {
   140		if !rw.wroteHeader {
   141			rw.WriteHeader(200)
   142		}
   143		rw.Flushed = true
   144	}
   145	
   146	// Result returns the response generated by the handler.
   147	//
   148	// The returned Response will have at least its StatusCode,
   149	// Header, Body, and optionally Trailer populated.
   150	// More fields may be populated in the future, so callers should
   151	// not DeepEqual the result in tests.
   152	//
   153	// The Response.Header is a snapshot of the headers at the time of the
   154	// first write call, or at the time of this call, if the handler never
   155	// did a write.
   156	//
   157	// The Response.Body is guaranteed to be non-nil and Body.Read call is
   158	// guaranteed to not return any error other than io.EOF.
   159	//
   160	// Result must only be called after the handler has finished running.
   161	func (rw *ResponseRecorder) Result() *http.Response {
   162		if rw.result != nil {
   163			return rw.result
   164		}
   165		if rw.snapHeader == nil {
   166			rw.snapHeader = rw.HeaderMap.Clone()
   167		}
   168		res := &http.Response{
   169			Proto:      "HTTP/1.1",
   170			ProtoMajor: 1,
   171			ProtoMinor: 1,
   172			StatusCode: rw.Code,
   173			Header:     rw.snapHeader,
   174		}
   175		rw.result = res
   176		if res.StatusCode == 0 {
   177			res.StatusCode = 200
   178		}
   179		res.Status = fmt.Sprintf("%03d %s", res.StatusCode, http.StatusText(res.StatusCode))
   180		if rw.Body != nil {
   181			res.Body = ioutil.NopCloser(bytes.NewReader(rw.Body.Bytes()))
   182		} else {
   183			res.Body = http.NoBody
   184		}
   185		res.ContentLength = parseContentLength(res.Header.Get("Content-Length"))
   186	
   187		if trailers, ok := rw.snapHeader["Trailer"]; ok {
   188			res.Trailer = make(http.Header, len(trailers))
   189			for _, k := range trailers {
   190				k = http.CanonicalHeaderKey(k)
   191				if !httpguts.ValidTrailerHeader(k) {
   192					// Ignore since forbidden by RFC 7230, section 4.1.2.
   193					continue
   194				}
   195				vv, ok := rw.HeaderMap[k]
   196				if !ok {
   197					continue
   198				}
   199				vv2 := make([]string, len(vv))
   200				copy(vv2, vv)
   201				res.Trailer[k] = vv2
   202			}
   203		}
   204		for k, vv := range rw.HeaderMap {
   205			if !strings.HasPrefix(k, http.TrailerPrefix) {
   206				continue
   207			}
   208			if res.Trailer == nil {
   209				res.Trailer = make(http.Header)
   210			}
   211			for _, v := range vv {
   212				res.Trailer.Add(strings.TrimPrefix(k, http.TrailerPrefix), v)
   213			}
   214		}
   215		return res
   216	}
   217	
   218	// parseContentLength trims whitespace from s and returns -1 if no value
   219	// is set, or the value if it's >= 0.
   220	//
   221	// This a modified version of same function found in net/http/transfer.go. This
   222	// one just ignores an invalid header.
   223	func parseContentLength(cl string) int64 {
   224		cl = strings.TrimSpace(cl)
   225		if cl == "" {
   226			return -1
   227		}
   228		n, err := strconv.ParseInt(cl, 10, 64)
   229		if err != nil {
   230			return -1
   231		}
   232		return n
   233	}
   234	

View as plain text