diff --git a/http.go b/http.go index 7dd0401..a7acbec 100644 --- a/http.go +++ b/http.go @@ -93,7 +93,7 @@ func (d *Dialer) connect(c net.Conn, network, address string) error { return errors.New("network not implemented") } req := &http.Request{ - Method: "CONNECT", + Method: http.MethodConnect, URL: &url.URL{Opaque: address}, Host: address, Header: make(http.Header), diff --git a/service/autoproxy.go b/service/autoproxy.go index 08086de..57d494b 100644 --- a/service/autoproxy.go +++ b/service/autoproxy.go @@ -31,17 +31,30 @@ var errNoUpdateAvailable = errors.New("no update available") func parseAutoproxy(c *Client) (*proxy.PerHost, error) { accessLogger.Print("check autoproxy") - resp, err := http.Get(autoproxyURL) - if err != nil { + ch := make(chan *http.Response, 1) + go func() { + if resp, err := http.Get(autoproxyURL); err == nil { + ch <- resp + } else { + errorLogger.Debug("failed to check autoproxy without using proxy", "error", err) + } + }() + go func() { if t, ok := http.DefaultTransport.(*http.Transport); ok { + t = t.Clone() t.Proxy = http.ProxyURL(c.u) - if resp, err = http.Get(autoproxyURL); err != nil { - errorLogger.Print(err) - return nil, err + if resp, err := (&http.Client{Transport: t}).Get(autoproxyURL); err == nil { + ch <- resp + } else { + errorLogger.Debug("failed to check autoproxy using proxy", "error", err) } - } else { - return nil, err } + }() + var resp *http.Response + select { + case <-time.After(time.Minute): + return nil, errors.New("failed to check autoproxy") + case resp = <-ch: } defer resp.Body.Close() b, err := io.ReadAll(resp.Body) @@ -68,6 +81,7 @@ func parseAutoproxy(c *Client) (*proxy.PerHost, error) { } } last = string(b) + accessLogger.Print("autoproxy loaded") return perHost, nil } @@ -90,7 +104,6 @@ func initAutoproxy(c *Client) *proxy.PerHost { } return } - accessLogger.Print("autoproxy updated") c.autoproxy.PerHost = p }) return p diff --git a/service/httpproxy_test.go b/service/httpproxy_test.go index 567e81a..937bf49 100644 --- a/service/httpproxy_test.go +++ b/service/httpproxy_test.go @@ -31,9 +31,11 @@ import ( ) var testHandler = http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + header := w.Header() m := make(map[string]string) - for k := range r.Header { + for k, v := range r.Header { m[k] = r.Header.Get(k) + header.Add(k, v[0]) } b, _ := json.Marshal(m) w.Write(b) @@ -58,7 +60,7 @@ func newRequest(url string, m map[string]string) *http.Request { return req } -func do(proxy proxy.Dialer, url string, req *http.Request) (m map[string]string, err error) { +func do(proxy proxy.Dialer, url string, req *http.Request) (resp *http.Response, m map[string]string, err error) { c, err := proxy.Dial("tcp", strings.TrimPrefix(url, "http://")) if err != nil { return @@ -66,7 +68,7 @@ func do(proxy proxy.Dialer, url string, req *http.Request) (m map[string]string, if err = req.WriteProxy(c); err != nil { return } - resp, err := http.ReadResponse(bufio.NewReader(c), nil) + resp, err = http.ReadResponse(bufio.NewReader(c), nil) if err != nil { return } @@ -124,25 +126,35 @@ func createCert() (string, string, error) { func testProxy(t *testing.T, proxyPort, testURL string, m map[string]string) { req := newRequest(testURL, m) d, _ := httpproxy.NewDialer(":"+proxyPort, nil, nil, nil) - res, err := do(d, testURL, req) + resp, res, err := do(d, testURL, req) if err != nil { t.Fatal(err) } if !maps.Equal(m, res) { t.Errorf("expect %v; got %v", m, res) } + for k, v := range m { + if vv := resp.Header.Get(k); vv != v { + t.Errorf("expect %s %s; got %s", k, v, vv) + } + } u, err := url.Parse("http://localhost:" + proxyPort) if err != nil { t.Fatal(err) } d, _ = httpproxy.FromURL(u, nil) - res, err = do(d, testURL, req) + resp, res, err = do(d, testURL, req) if err != nil { t.Fatal(err) } if !maps.Equal(m, res) { t.Errorf("expect %v; got %v", m, res) } + for k, v := range m { + if vv := resp.Header.Get(k); vv != v { + t.Errorf("expect %s %s; got %s", k, v, vv) + } + } } func TestProxy(t *testing.T) { @@ -265,7 +277,7 @@ func TestAuth(t *testing.T) { if testcase.proxyAuth != nil { d.(*httpproxy.Dialer).Auth = testcase.proxyAuth } - res, err := do(d, ts.URL, req) + resp, res, err := do(d, ts.URL, req) if testcase.err != "" { if err == nil || !strings.Contains(err.Error(), testcase.err) { @@ -274,8 +286,15 @@ func TestAuth(t *testing.T) { } else { if err != nil { t.Error(i, err) - } else if !maps.Equal(m, res) { - t.Errorf("%d expect %v; got %v", i, m, res) + } else { + if !maps.Equal(m, res) { + t.Errorf("%d expect %v; got %v", i, m, res) + } + for k, v := range m { + if vv := resp.Header.Get(k); vv != v { + t.Errorf("expect %s %s; got %s", k, v, vv) + } + } } } }