| package server |
| |
| import ( |
| "encoding/base64" |
| "encoding/json" |
| "net/http" |
| "net/http/httptest" |
| "net/url" |
| "reflect" |
| "strconv" |
| "testing" |
| ) |
| |
// composeQuery returns path with response-override query parameters appended.
//
// Parameters already present in path are preserved. The following keys are
// set only when the corresponding argument carries a value:
//
//   - respStatus: decimal form of code, when code is positive.
//   - respHeader: URL-safe base64 of the JSON encoding of headers, when
//     headers is non-nil.
//   - respBody:   URL-safe base64 of body, when body is non-empty.
//
// An error is returned if path cannot be parsed as a URL or if headers cannot
// be marshalled to JSON; the returned string is empty in that case.
func composeQuery(path string, code int, headers http.Header, body []byte) (string, error) {
	target, parseErr := url.Parse(path)
	if parseErr != nil {
		return "", parseErr
	}

	encode := base64.URLEncoding.EncodeToString
	params := target.Query()

	if len(body) != 0 {
		params.Set("respBody", encode(body))
	}
	if headers != nil {
		raw, marshalErr := json.Marshal(headers)
		if marshalErr != nil {
			return "", marshalErr
		}
		params.Set("respHeader", encode(raw))
	}
	if code >= 1 {
		params.Set("respStatus", strconv.Itoa(code))
	}

	// Encode sorts by key, so the order in which parameters were set above
	// does not influence the resulting query string.
	target.RawQuery = params.Encode()
	return target.String(), nil
}
| |
| func TestResponseOverride(t *testing.T) { |
| tests := []struct { |
| name string |
| code int |
| headers http.Header |
| body []byte |
| }{ |
| {name: "code", code: 204}, |
| {name: "body", body: []byte("new body")}, |
| { |
| name: "headers", |
| headers: http.Header{ |
| "Via": []string{"Via1", "Via2"}, |
| "Content-Type": []string{"random content"}, |
| }, |
| }, |
| { |
| name: "everything", |
| code: 204, |
| body: []byte("new body"), |
| headers: http.Header{ |
| "Via": []string{"Via1", "Via2"}, |
| "Content-Type": []string{"random content"}, |
| }, |
| }, |
| } |
| |
| for _, test := range tests { |
| u, err := composeQuery("http://test.com/override", test.code, test.headers, test.body) |
| if err != nil { |
| t.Errorf("%s: composeQuery: %v", test.name, err) |
| return |
| } |
| req, err := http.NewRequest("GET", u, nil) |
| if err != nil { |
| t.Errorf("%s: http.NewRequest: %v", test.name, err) |
| return |
| } |
| w := httptest.NewRecorder() |
| defaultResponse(w, req) |
| if test.code > 0 { |
| if got, want := w.Code, test.code; got != want { |
| t.Errorf("%s: response code: got %d want %d", test.name, got, want) |
| return |
| } |
| } |
| if test.headers != nil { |
| for k, want := range test.headers { |
| got, ok := w.HeaderMap[k] |
| if !ok || !reflect.DeepEqual(got, want) { |
| t.Errorf("%s: header %s: code: got %v want %v", test.name, k, got, want) |
| return |
| } |
| } |
| } |
| if test.body != nil { |
| if got, want := string(w.Body.Bytes()), string(test.body); got != want { |
| t.Errorf("%s: body: got %s want %s", test.name, got, want) |
| return |
| } |
| } |
| } |
| } |