Skip to content

Instantly share code, notes, and snippets.

@seantalts
Forked from dmichael/httpclient.go
Last active August 12, 2022 14:07
Show Gist options
  • Save seantalts/11266762 to your computer and use it in GitHub Desktop.
Save seantalts/11266762 to your computer and use it in GitHub Desktop.
package httptimeout
import (
"net/http"
"time"
"fmt"
)
// TimeoutTransport wraps an embedded http.Transport and bounds every round
// trip by RoundTripTimeout; *TimeoutTransport implements http.RoundTripper.
//
// The embedded Transport is configured as usual (Dial, pooling, ...) and all
// of its methods are promoted. RoundTripTimeout has no usable default: when it
// is left at zero, requests effectively always time out.
type TimeoutTransport struct {
	http.Transport

	// RoundTripTimeout caps how long RoundTrip waits for the transport to
	// return a response.
	RoundTripTimeout time.Duration
}
type respAndErr struct {
resp *http.Response
err error
}
// netTimeoutError is the error TimeoutTransport.RoundTrip returns when
// RoundTripTimeout elapses. The embedded error supplies the message.
//
// It implements Timeout and Temporary, so together with the embedded error it
// satisfies the net.Error interface: callers can detect it with a net.Error
// type assertion or errors.As, and a *url.Error wrapping it (as returned by
// http.Client.Do) reports Timeout() == true.
type netTimeoutError struct {
	error
}

// Timeout reports that this error is a timeout.
func (ne netTimeoutError) Timeout() bool { return true }

// Temporary reports that the failure may not recur on retry. It exists so the
// type satisfies net.Error (where Temporary is deprecated); timeouts are
// conventionally treated as temporary.
func (ne netTimeoutError) Temporary() bool { return true }
// RoundTrip implements http.RoundTripper. It forwards req to the embedded
// http.Transport and waits at most t.RoundTripTimeout for it to return.
//
// The timeout only covers the wait for the embedded RoundTrip to return (i.e.
// until response headers arrive); reading resp.Body is not bounded by it.
// When the timeout fires, the request is cancelled via
// Transport.CancelRequest and a netTimeoutError is returned.
//
// RoundTripTimeout must be set: with the zero value the timer fires
// immediately, so requests effectively always time out.
//
// NOTE: Transport.CancelRequest is deprecated in favour of request contexts
// and cannot cancel HTTP/2 requests.
func (t *TimeoutTransport) RoundTrip(req *http.Request) (*http.Response, error) {
	timer := time.NewTimer(t.RoundTripTimeout)
	// Release the timer as soon as we return; time.After would keep it alive
	// until it fires (on Go versions before 1.23).
	defer timer.Stop()

	// Buffered so the worker can always deliver its result and exit, even
	// after we have stopped listening because of a timeout.
	results := make(chan respAndErr, 1)
	go func() {
		r, e := t.Transport.RoundTrip(req)
		results <- respAndErr{resp: r, err: e}
	}()

	select {
	case <-timer.C: // A round trip timeout has occurred.
		t.Transport.CancelRequest(req)
		// The underlying RoundTrip may still succeed if the response raced
		// with the cancellation. Nobody will read that response, so close its
		// body to release the connection instead of leaking it.
		go func() {
			if r := <-results; r.resp != nil {
				r.resp.Body.Close()
			}
		}()
		return nil, netTimeoutError{
			error: fmt.Errorf("timed out after %s", t.RoundTripTimeout),
		}
	case r := <-results:
		return r.resp, r.err
	}
}
package httptimeout
import (
"io"
"io/ioutil"
"net"
"net/http"
"net/http/httptest"
"testing"
"time"
)
// TestHttpTimeout checks TimeoutTransport against a local test server, using
// a counting Dial hook to observe connection reuse. It expects that:
//   - requests finishing within RoundTripTimeout share one keep-alive
//     connection;
//   - a request exceeding RoundTripTimeout fails, and the next request has to
//     dial a new connection;
//   - the pooled connection is still reused after an idle pause.
//
// The timings are tight (100ms handler, 200ms timeout, 250ms slow handler),
// so the test may be sensitive to a heavily loaded machine.
func TestHttpTimeout(t *testing.T) {
	// Use a private mux rather than http.DefaultServeMux: registering the same
	// patterns on the global mux panics when the test runs more than once
	// (e.g. go test -count=2) and leaks handlers into other tests.
	mux := http.NewServeMux()
	mux.HandleFunc("/normal", func(w http.ResponseWriter, req *http.Request) {
		// Empirically, timeouts less than these seem to be flaky
		time.Sleep(100 * time.Millisecond)
		io.WriteString(w, "ok")
	})
	mux.HandleFunc("/timeout", func(w http.ResponseWriter, req *http.Request) {
		// Longer than RoundTripTimeout below, so the client gives up first.
		time.Sleep(250 * time.Millisecond)
		io.WriteString(w, "ok")
	})
	ts := httptest.NewServer(mux)
	defer ts.Close()

	numDials := 0
	transport := &TimeoutTransport{
		Transport: http.Transport{
			Dial: func(netw, addr string) (net.Conn, error) {
				t.Logf("dial to %s://%s", netw, addr)
				numDials++ // For testing only.
				return net.Dial(netw, addr)
			},
		},
		RoundTripTimeout: time.Millisecond * 200,
	}
	// Don't leave pooled client connections open after the test.
	defer transport.CloseIdleConnections()
	client := &http.Client{Transport: transport}

	checkDials := func(want int, msg string) {
		if numDials != want {
			t.Fatalf("%s (got %d dials)", msg, numDials)
		}
	}

	addr := ts.URL
	SendTestRequest(t, client, "1st", addr, "normal")
	checkDials(1, "Should only have 1 dial at this point.")
	SendTestRequest(t, client, "2st", addr, "normal")
	checkDials(1, "Should only have 1 dial at this point.")
	SendTestRequest(t, client, "3st", addr, "timeout")
	checkDials(1, "Should only have 1 dial at this point.")
	// The timed-out request's connection was cancelled, so this one must dial.
	SendTestRequest(t, client, "4st", addr, "normal")
	checkDials(2, "Should have our 2nd dial.")
	// After an idle pause the pooled connection should still be reused.
	time.Sleep(time.Millisecond * 700)
	SendTestRequest(t, client, "5st", addr, "normal")
	checkDials(2, "Should still only have 2 dials.")
}
// SendTestRequest issues a GET for addr+"/"+path through client and checks
// the outcome according to path:
//   - "normal": the request must succeed with 200 OK and its body must be
//     readable; the body is logged.
//   - "timeout": the request must fail with an error whose Timeout method
//     reports true; the error is logged.
//
// Any other path is a caller error. id labels the request in log and failure
// messages. All failures are reported with t.Fatalf.
func SendTestRequest(t *testing.T, client *http.Client, id, addr, path string) {
	req, err := http.NewRequest("GET", addr+"/"+path, nil)
	if err != nil {
		t.Fatalf("new request failed - %s", err)
	}
	req.Header.Add("Connection", "keep-alive")

	switch path {
	case "normal":
		resp, err := client.Do(req)
		if err != nil {
			t.Fatalf("%s request failed - %s", id, err)
		}
		// Close the body even when a check below fails (Fatalf still runs
		// deferred calls), so the connection is not leaked.
		defer resp.Body.Close()
		if resp.StatusCode != http.StatusOK {
			t.Fatalf("%s request returned status %d, want %d", id, resp.StatusCode, http.StatusOK)
		}
		result, err := ioutil.ReadAll(resp.Body)
		if err != nil {
			t.Fatalf("%s response read failed - %s", id, err)
		}
		t.Logf("%s request - %s", id, result)
	case "timeout":
		resp, err := client.Do(req)
		if err == nil {
			resp.Body.Close()
			t.Fatalf("%s request not timeout", id)
		}
		// client.Do returns *url.Error, whose Timeout method reports whether
		// the wrapped error was a timeout; other failures must not pass here.
		if te, ok := err.(interface{ Timeout() bool }); !ok || !te.Timeout() {
			t.Fatalf("%s request failed without timing out - %s", id, err)
		}
		t.Logf("%s request - %s", id, err)
	default:
		t.Fatalf("%s: unknown test path %q", id, path)
	}
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment