diff --git a/cmd/newrelic-infra/newrelic-infra.go b/cmd/newrelic-infra/newrelic-infra.go index ec47366ed..2cc57d081 100644 --- a/cmd/newrelic-infra/newrelic-infra.go +++ b/cmd/newrelic-infra/newrelic-infra.go @@ -418,7 +418,7 @@ func initializeAgentAndRun(c *config.Config, logFwCfg config.LogForward) error { // This should never happen, as the correct format is checked during NormalizeConfig. aslog.WithError(err).Error("invalid startup_connection_timeout value, cannot run status server") } else { - rep := status.NewReporter(agt.Context.Ctx, rlog, c.StatusEndpoints, timeoutD, transport, agt.Context.AgentIdnOrEmpty, agt.Context.EntityKey, c.License, userAgent) + rep := status.NewReporter(agt.Context.Ctx, rlog, c.StatusEndpoints, c.HealthEndpoint, timeoutD, transport, agt.Context.AgentIdnOrEmpty, agt.Context.EntityKey, c.License, userAgent) apiSrv, err := httpapi.NewServer(rep, integrationEmitter) if c.HTTPServerEnabled { diff --git a/internal/agent/status/status.go b/internal/agent/status/status.go index 2e2f0c305..03c5484b3 100644 --- a/internal/agent/status/status.go +++ b/internal/agent/status/status.go @@ -21,6 +21,7 @@ const ( // Report agent status report. It contains: // - checks: // - backend endpoints reachability statuses +// - backend communication healthiness // // - configuration // fields will be empty when ReportErrors() report no errors. @@ -31,6 +32,7 @@ type Report struct { type ChecksReport struct { Endpoints []EndpointReport `json:"endpoints,omitempty"` + Health HealthReport `json:"health,omitempty"` } // ConfigReport configuration used for status report. @@ -45,6 +47,12 @@ type EndpointReport struct { Error string `json:"error,omitempty"` } +// HealthReport represents the backend communication healthiness status. +type HealthReport struct { + Healthy bool `json:"healthy"` + Error string `json:"error,omitempty"` +} + // ReportEntity agent entity report. type ReportEntity struct { GUID string `json:"guid"` @@ -59,12 +67,15 @@ type Reporter interface { ReportErrors() (Report, error) // ReportEntity agent entity report. ReportEntity() (ReportEntity, error) + // ReportHealth agent healthy report. + ReportHealth() HealthReport } type nrReporter struct { ctx context.Context log log.Entry endpoints []string // NR backend URLs + healthEndpoint string // NR command backend URL to check communication healthiness license string userAgent string idProvide id.Provide @@ -119,8 +130,19 @@ func (r *nrReporter) report(onlyErrors bool) (report Report, err error) { }(ep) } + hReportC := make(chan HealthReport, 1) + + wg.Add(1) + + go func() { + hReportC <- r.getHealth(agentID) + + wg.Done() + }() + wg.Wait() close(eReportsC) + close(hReportC) var errored bool var eReports []EndpointReport @@ -132,16 +154,17 @@ func (r *nrReporter) report(onlyErrors bool) (report Report, err error) { errored = true } } + hreport := <-hReportC if !onlyErrors || errored { if report.Checks == nil { report.Checks = &ChecksReport{} } report.Checks.Endpoints = eReports + report.Checks.Health = hreport report.Config = &ConfigReport{ ReachabilityTimeout: r.timeout.String(), } - } return @@ -154,11 +177,41 @@ func (r *nrReporter) ReportEntity() (re ReportEntity, err error) { }, nil } +func (r *nrReporter) ReportHealth() HealthReport { + agentID := r.idProvide().ID.String() + + return r.getHealth(agentID) +} + +// Make a http req to the command api to validate the ingest key is valid and connectivity is ok. +func (r *nrReporter) getHealth(agentID string) HealthReport { + health, err := backendhttp.CheckEndpointHealthiness( + r.ctx, + r.healthEndpoint, + r.license, + r.userAgent, + agentID, + r.timeout, + r.transport, + ) + + healthReport := HealthReport{ + Healthy: health, + Error: "", + } + if err != nil { + healthReport.Error = err.Error() + } + + return healthReport +} + // NewReporter creates a new status reporter. func NewReporter( ctx context.Context, l log.Entry, backendEndpoints []string, + healthEndpoint string, timeout time.Duration, transport http.RoundTripper, agentIDProvide id.Provide, @@ -166,11 +219,11 @@ func NewReporter( license, userAgent string, ) Reporter { - return &nrReporter{ ctx: ctx, log: l, endpoints: backendEndpoints, + healthEndpoint: healthEndpoint, license: license, userAgent: userAgent, idProvide: agentIDProvide, diff --git a/internal/agent/status/status_test.go b/internal/agent/status/status_test.go index 67af89cab..c6fea6fcf 100644 --- a/internal/agent/status/status_test.go +++ b/internal/agent/status/status_test.go @@ -1,5 +1,7 @@ // Copyright 2021 New Relic Corporation. All rights reserved. // SPDX-License-Identifier: Apache-2.0 + +//nolint:exhaustruct,noctx package status import ( @@ -9,6 +11,8 @@ import ( "testing" "time" + http2 "github.com/newrelic/infrastructure-agent/pkg/backend/http" + "github.com/newrelic/infrastructure-agent/pkg/entity" "github.com/newrelic/infrastructure-agent/pkg/log" "github.com/stretchr/testify/assert" @@ -17,45 +21,81 @@ import ( func TestNewReporter_Report(t *testing.T) { serverOk := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - w.WriteHeader(200) + w.WriteHeader(http.StatusOK) })) defer serverOk.Close() + serverTimeout := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { time.Sleep(10 * time.Second) })) defer serverTimeout.Close() + serverUnauthorized := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusUnauthorized) + })) + defer serverUnauthorized.Close() + assert.Eventually(t, func() bool { res, err := serverOk.Client().Get(serverOk.URL) - return err == nil && res.StatusCode == 200 + + return err == nil && res.StatusCode == http.StatusOK }, time.Second, 10*time.Millisecond) endpointsOk := []string{serverOk.URL} + healthEndpointOK := serverOk.URL endpointsTimeout := []string{serverTimeout.URL} + healthEndpointTimeout := serverTimeout.URL endpointsMixed := []string{serverOk.URL, serverTimeout.URL} + healthEndpointUnauthorized := serverUnauthorized.URL - expectReportOk := Report{Checks: &ChecksReport{Endpoints: []EndpointReport{{ - URL: serverOk.URL, - Reachable: true, - }}}} - expectReportTimeout := Report{Checks: &ChecksReport{Endpoints: []EndpointReport{{ - URL: serverTimeout.URL, - Reachable: false, - Error: endpointTimeoutMsg, // substring is enough, it'll assert via "string contains" - }}}} - expectReportMixed := Report{Checks: &ChecksReport{Endpoints: []EndpointReport{ - { - URL: serverOk.URL, - Reachable: true, + expectReportOk := Report{Checks: &ChecksReport{ + Endpoints: []EndpointReport{ + { + URL: serverOk.URL, + Reachable: true, + Error: "", + }, }, - { - URL: serverTimeout.URL, - Reachable: false, - Error: endpointTimeoutMsg, + Health: HealthReport{ + Healthy: true, + Error: "", }, - }}} + }, Config: nil} + + expectReportTimeout := Report{Checks: &ChecksReport{ + Endpoints: []EndpointReport{ + { + URL: serverTimeout.URL, + Reachable: false, + Error: endpointTimeoutMsg, // substring is enough, it'll assert via "string contains" + }, + }, + Health: HealthReport{ + Healthy: false, + Error: "context deadline exceeded", + }, + }, Config: nil} + + expectReportMixed := Report{Checks: &ChecksReport{ + Endpoints: []EndpointReport{ + { + URL: serverOk.URL, + Reachable: true, + Error: "", + }, + { + URL: serverTimeout.URL, + Reachable: false, + Error: endpointTimeoutMsg, + }, + }, + Health: HealthReport{ + Healthy: false, + Error: http2.ErrUnexepectedResponseCode.Error(), + }, + }, Config: nil} timeout := 10 * time.Millisecond transport := &http.Transport{} @@ -66,19 +106,20 @@ func TestNewReporter_Report(t *testing.T) { return "" } tests := []struct { - name string - endpoints []string - want Report - wantErr bool + name string + endpoints []string + healthEndpoint string + want Report + wantErr bool }{ - {"connectivity ok", endpointsOk, expectReportOk, false}, - {"connectivity timedout", endpointsTimeout, expectReportTimeout, false}, - {"connectivities ok and timeout", endpointsMixed, expectReportMixed, false}, + {"connectivity ok", endpointsOk, healthEndpointOK, expectReportOk, false}, + {"connectivity timedout", endpointsTimeout, healthEndpointTimeout, expectReportTimeout, false}, + {"connectivities ok and timeout and unhealthy", endpointsMixed, healthEndpointUnauthorized, expectReportMixed, false}, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { l := log.WithComponent(tt.name) - r := NewReporter(context.Background(), l, tt.endpoints, timeout, transport, emptyIDProvide, emptyEntityKeyProvider, "user-agent", "agent-key") + r := NewReporter(context.Background(), l, tt.endpoints, tt.healthEndpoint, timeout, transport, emptyIDProvide, emptyEntityKeyProvider, "user-agent", "agent-key") got, err := r.Report() @@ -103,21 +144,25 @@ func TestNewReporter_Report(t *testing.T) { assert.Equal(t, expectedEndpoint.Reachable, gotEndpoint.Reachable) assert.Contains(t, gotEndpoint.Error, expectedEndpoint.Error) } + assert.Equal(t, tt.want.Checks.Health.Healthy, got.Checks.Health.Healthy) + assert.Contains(t, got.Checks.Health.Error, tt.want.Checks.Health.Error) }) } } func TestNewReporter_ReportErrors(t *testing.T) { serverOk := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - w.WriteHeader(200) + w.WriteHeader(http.StatusOK) })) defer serverOk.Close() + serverTimeout := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { time.Sleep(10 * time.Second) })) defer serverTimeout.Close() endpointsOk := []string{serverOk.URL} + healthEndpointOK := serverOk.URL endpointsTimeout := []string{serverTimeout.URL} endpointsMixed := []string{serverOk.URL, serverTimeout.URL} @@ -156,7 +201,7 @@ func TestNewReporter_ReportErrors(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { l := log.WithComponent(tt.name) - r := NewReporter(context.Background(), l, tt.endpoints, timeout, transport, emptyIDProvide, emptyEntityKeyProvider, "user-agent", "agent-key") + r := NewReporter(context.Background(), l, tt.endpoints, healthEndpointOK, timeout, transport, emptyIDProvide, emptyEntityKeyProvider, "user-agent", "agent-key") got, err := r.ReportErrors() @@ -205,6 +250,7 @@ func TestNewReporter_ReportEntity(t *testing.T) { {"foo guid", "foo", "", ReportEntity{GUID: "foo"}, false}, {"foo guid bar key", "foo", "bar", ReportEntity{GUID: "foo", Key: "bar"}, false}, } + for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { idProvide := func() entity.Identity { @@ -216,7 +262,7 @@ func TestNewReporter_ReportEntity(t *testing.T) { entityKeyProvider := func() string { return tt.entityKey } - r := NewReporter(context.Background(), l, []string{}, timeout, transport, idProvide, entityKeyProvider, "user-agent", "agent-key") + r := NewReporter(context.Background(), l, []string{}, "", timeout, transport, idProvide, entityKeyProvider, "user-agent", "agent-key") got, err := r.ReportEntity() @@ -230,3 +276,81 @@ func TestNewReporter_ReportEntity(t *testing.T) { }) } } + +//nolint:paralleltest +func TestNewReporter_ReportHealth(t *testing.T) { + serverOk := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusOK) + })) + defer serverOk.Close() + + serverTimeout := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + time.Sleep(100 * time.Millisecond) + })) + defer serverTimeout.Close() + + serverUnauthorized := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusUnauthorized) + })) + defer serverUnauthorized.Close() + + assert.Eventually(t, + func() bool { + res, err := serverOk.Client().Get(serverOk.URL) + defer func() { + _ = res.Body.Close() + }() + + return err == nil && res.StatusCode == 200 + }, + time.Second, 10*time.Millisecond) + + healthEndpointOK := serverOk.URL + healthEndpointTimeout := serverTimeout.URL + healthEndpointUnauthorized := serverUnauthorized.URL + + expectReportOk := HealthReport{ + Healthy: true, + Error: "", + } + + expectReportTimeout := HealthReport{ + Healthy: false, + Error: "context deadline exceeded", + } + + expectReportUnauthorized := HealthReport{ + Healthy: false, + Error: http2.ErrUnexepectedResponseCode.Error(), + } + + timeout := 10 * time.Millisecond + transport := &http.Transport{} + emptyIDProvide := func() entity.Identity { + return entity.EmptyIdentity + } + emptyEntityKeyProvider := func() string { + return "" + } + tests := []struct { + name string + healthEndpoint string + want HealthReport + }{ + {"connectivity ok", healthEndpointOK, expectReportOk}, + {"connectivity timedout", healthEndpointTimeout, expectReportTimeout}, + {"unhealthy", healthEndpointUnauthorized, expectReportUnauthorized}, + } + + for _, testCase := range tests { + t.Run(testCase.name, func(t *testing.T) { + l := log.WithComponent(testCase.name) + r := NewReporter(context.Background(), l, nil, testCase.healthEndpoint, timeout, transport, emptyIDProvide, emptyEntityKeyProvider, "user-agent", "agent-key") + + got := r.ReportHealth() + + assert.Equal(t, testCase.want.Healthy, got.Healthy) + assert.Contains(t, got.Error, testCase.want.Error) + }) + } +} diff --git a/internal/httpapi/httpapi.go b/internal/httpapi/httpapi.go index af0922afd..2ccf20c90 100644 --- a/internal/httpapi/httpapi.go +++ b/internal/httpapi/httpapi.go @@ -30,6 +30,7 @@ const ( statusOnlyErrorsAPIPath = "/v1/status/errors" statusEntityAPIPath = "/v1/status/entity" statusAPIPathReady = "/v1/status/ready" + statusHealthAPIPath = "/v1/status/health" ingestAPIPath = "/v1/data" ingestAPIPathReady = "/v1/data/ready" readinessProbeRetryBackoff = 100 * time.Millisecond @@ -174,6 +175,7 @@ func (s *Server) serveStatus(_ context.Context) error { router.GET(statusEntityAPIPath, s.handleEntity) router.GET(statusAPIPath, s.handle(false)) router.GET(statusOnlyErrorsAPIPath, s.handle(true)) + router.GET(statusHealthAPIPath, s.handleHealth) // local only API err := http.ListenAndServe(s.Status.address, router) statusServerErr <- err @@ -344,10 +346,34 @@ func (s *Server) handle(onlyErrors bool) func(http.ResponseWriter, *http.Request } } -func (s *Server) handleReady(w http.ResponseWriter, r *http.Request, ps httprouter.Params) { +func (s *Server) handleReady(w http.ResponseWriter, _ *http.Request, _ httprouter.Params) { w.WriteHeader(http.StatusOK) } +func (s *Server) handleHealth(writer http.ResponseWriter, _ *http.Request, _ httprouter.Params) { + health := s.reporter.ReportHealth() + + body, err := json.Marshal(health) + if err != nil { + writer.WriteHeader(http.StatusInternalServerError) + s.logger.WithError(err).Warn("couldn't encode Status report") + + return + } + + if !health.Healthy { + writer.WriteHeader(http.StatusInternalServerError) + } + + _, err = writer.Write(body) + if err != nil { + s.logger.Warn("cannot write entity response, error: " + err.Error()) + writer.WriteHeader(http.StatusInternalServerError) + + return + } +} + func (s *Server) handleEntity(w http.ResponseWriter, r *http.Request, ps httprouter.Params) { re, err := s.reporter.ReportEntity() if err != nil { @@ -374,7 +400,6 @@ func (s *Server) handleEntity(w http.ResponseWriter, r *http.Request, ps httprou w.WriteHeader(http.StatusInternalServerError) return } - w.WriteHeader(http.StatusOK) } func (s *Server) handleIngest(w http.ResponseWriter, r *http.Request, ps httprouter.Params) { diff --git a/internal/httpapi/httpapi_test.go b/internal/httpapi/httpapi_test.go index e90012602..22f45a6fe 100644 --- a/internal/httpapi/httpapi_test.go +++ b/internal/httpapi/httpapi_test.go @@ -1,5 +1,7 @@ // Copyright 2021 NewServer Relic Corporation. All rights reserved. // SPDX-License-Identifier: Apache-2.0 + +//nolint:exhaustruct,noctx package httpapi import ( @@ -43,15 +45,16 @@ func TestHTTPAPITestSuite(t *testing.T) { func (suite *HTTPAPITestSuite) TestServe_Status() { // Given a running HTTP endpoint port, err := networkHelpers.TCPPort() - require.NoError(suite.T(), err) + suite.Require().NoError(err) serverOk := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - w.WriteHeader(200) + w.WriteHeader(http.StatusOK) })) defer serverOk.Close() // And a status reporter monitoring it endpoints := []string{serverOk.URL} + healthEndpoint := serverOk.URL logger := log.WithComponent(suite.T().Name()) timeout := 100 * time.Millisecond transport := &http.Transport{} @@ -63,7 +66,7 @@ func (suite *HTTPAPITestSuite) TestServe_Status() { } ctx, cancel := context.WithCancel(context.Background()) defer cancel() - r := status.NewReporter(ctx, logger, endpoints, timeout, transport, emptyIDProvide, emptyEntityKeyProvider, "user-agent", "agent-key") + r := status.NewReporter(ctx, logger, endpoints, healthEndpoint, timeout, transport, emptyIDProvide, emptyEntityKeyProvider, "user-agent", "agent-key") // When agent status API server is ready em := &testemit.RecordEmitter{} @@ -95,6 +98,9 @@ func (suite *HTTPAPITestSuite) TestServe_Status() { assert.Empty(suite.T(), e.Error) assert.True(suite.T(), e.Reachable) assert.Equal(suite.T(), serverOk.URL, e.URL) + h := gotReport.Checks.Health + suite.Require().True(h.Healthy) + suite.Require().Empty(h.Error) } func (suite *HTTPAPITestSuite) TestServe_OnlyErrors() { @@ -103,7 +109,7 @@ func (suite *HTTPAPITestSuite) TestServe_OnlyErrors() { require.NoError(suite.T(), err) serverOk := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - w.WriteHeader(200) + w.WriteHeader(http.StatusOK) })) defer serverOk.Close() serverTimeout := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { @@ -113,6 +119,7 @@ func (suite *HTTPAPITestSuite) TestServe_OnlyErrors() { // And a status reporter monitoring these endpoints endpoints := []string{serverOk.URL, serverTimeout.URL} + healthEndpoint := serverOk.URL logger := log.WithComponent(suite.T().Name()) timeout := 100 * time.Millisecond transport := &http.Transport{} @@ -124,7 +131,7 @@ func (suite *HTTPAPITestSuite) TestServe_OnlyErrors() { } ctx, cancel := context.WithCancel(context.Background()) defer cancel() - r := status.NewReporter(ctx, logger, endpoints, timeout, transport, emptyIDProvide, emptyEntityKeyProvider, "user-agent", "agent-key") + r := status.NewReporter(ctx, logger, endpoints, healthEndpoint, timeout, transport, emptyIDProvide, emptyEntityKeyProvider, "user-agent", "agent-key") // When agent status API server is ready em := &testemit.RecordEmitter{} @@ -191,7 +198,7 @@ func (suite *HTTPAPITestSuite) TestServe_Entity() { port, err := networkHelpers.TCPPort() require.NoError(t, err) - r := status.NewReporter(ctx, logger, []string{}, timeout, transport, tt.idProvide, emptyEntityKeyProvider, "user-agent", "agent-key") + r := status.NewReporter(ctx, logger, []string{}, "", timeout, transport, tt.idProvide, emptyEntityKeyProvider, "user-agent", "agent-key") // When agent status API server is ready em := &testemit.RecordEmitter{} s, err := NewServer(r, em) @@ -224,6 +231,75 @@ func (suite *HTTPAPITestSuite) TestServe_Entity() { } } +func (suite *HTTPAPITestSuite) TestServe_Health() { + // Given a running HTTP endpoint + port, err := networkHelpers.TCPPort() + suite.Require().NoError(err) + var requestsDone int + + serverOk := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if requestsDone > 0 { + w.WriteHeader(http.StatusUnauthorized) + } + w.WriteHeader(http.StatusOK) + requestsDone++ + })) + defer serverOk.Close() + + // And a status reporter monitoring it + logger := log.WithComponent(suite.T().Name()) + timeout := 100 * time.Millisecond + transport := &http.Transport{} + emptyIDProvide := func() entity.Identity { + return entity.EmptyIdentity + } + emptyEntityKeyProvider := func() string { + return "" + } + + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + r := status.NewReporter(ctx, logger, []string{}, serverOk.URL, timeout, transport, emptyIDProvide, emptyEntityKeyProvider, "user-agent", "agent-key") + + // When agent status API server is ready + em := &testemit.RecordEmitter{} + server, err := NewServer(r, em) + suite.Require().NoError(err) + server.Status.Enable("localhost", port) + + go server.Serve(ctx) + + server.waitUntilReady() + + tests := []struct { + name string + healthy bool + statusCode int + }{ + {"healthy", true, http.StatusOK}, + {"unhealthy", false, http.StatusInternalServerError}, + } + for _, testCase := range tests { + suite.T().Run(testCase.name, func(t *testing.T) { + // And a request to the status API is sent + req, err := http.NewRequest(http.MethodGet, fmt.Sprintf("http://localhost:%d%s", port, statusHealthAPIPath), nil) + suite.Require().NoError(err) + client := http.Client{} + + res, err := client.Do(req) + suite.Require().NoError(err) + defer res.Body.Close() + + suite.Require().Equal(testCase.statusCode, res.StatusCode) + + var gotReport status.HealthReport + _ = json.NewDecoder(res.Body).Decode(&gotReport) + suite.Require().Equal(testCase.healthy, gotReport.Healthy) + }) + } +} + func (suite *HTTPAPITestSuite) TestServe_IngestData() { port, err := networkHelpers.TCPPort() require.NoError(suite.T(), err) @@ -463,3 +539,7 @@ func (r *noopReporter) ReportErrors() (status.Report, error) { func (r *noopReporter) ReportEntity() (re status.ReportEntity, err error) { return status.ReportEntity{}, nil } + +func (r *noopReporter) ReportHealth() status.HealthReport { + return status.HealthReport{} +} diff --git a/pkg/backend/http/http_client.go b/pkg/backend/http/http_client.go index 06193079c..067aa179a 100644 --- a/pkg/backend/http/http_client.go +++ b/pkg/backend/http/http_client.go @@ -1,10 +1,13 @@ // Copyright 2020 New Relic Corporation. All rights reserved. // SPDX-License-Identifier: Apache-2.0 + +//nolint:wrapcheck package http import ( "context" "crypto/x509" + "errors" "fmt" "io/ioutil" "net" @@ -19,6 +22,8 @@ import ( "github.com/sirupsen/logrus" ) +var ErrUnexepectedResponseCode = errors.New("endpoint returned and unexpected response code") + func GetHttpClient( httpTimeout time.Duration, transport http.RoundTripper, @@ -85,27 +90,33 @@ var NullHttpClient = func(req *http.Request) (res *http.Response, err error) { return } -func CheckEndpointReachability(ctx context.Context, l log.Entry, endpointURL, license, userAgent, agentID string, timeout time.Duration, transport http.RoundTripper) (timedOut bool, err error) { - var request *http.Request - if request, err = http.NewRequest("HEAD", endpointURL, nil); err != nil { - return false, fmt.Errorf("unable to prepare availability request: %v, error: %s", request, err) - } +func CheckEndpointReachability( + ctx context.Context, + logger log.Entry, + endpointURL string, + license string, + userAgent string, + agentID string, + timeout time.Duration, + transport http.RoundTripper, +) (bool, error) { + var timedOut bool - request = request.WithContext(ctx) - request.Header.Set("Content-Type", "application/json") - request.Header.Set("User-Agent", userAgent) - request.Header.Set(LicenseHeader, license) - request.Header.Set(EntityKeyHeader, agentID) + request, err := buildRequest(ctx, endpointURL, "HEAD", userAgent, license, agentID) + if err != nil { + return false, err + } client := GetHttpClient(timeout, transport) // all status codes are acceptable as request has been replied by the endpoint - if _, err = client.Do(request); err != nil { + resp, err := client.Do(request) + if err != nil { if e2, ok := err.(net.Error); ok && (e2.Timeout() || e2.Temporary()) { timedOut = true } if _, ok := err.(*url.Error); ok { - l.WithError(err). + logger.WithError(err). WithField("userAgent", userAgent). WithField("timeout", timeout). WithField("url", endpointURL). @@ -114,5 +125,56 @@ func CheckEndpointReachability(ctx context.Context, l log.Entry, endpointURL, li } } - return + if resp != nil { + _ = resp.Body.Close() + } + + return timedOut, err +} + +func CheckEndpointHealthiness( + ctx context.Context, + endpointURL string, + license string, + userAgent string, + agentID string, + timeout time.Duration, + transport http.RoundTripper, +) (bool, error) { + request, err := buildRequest(ctx, endpointURL, "GET", userAgent, license, agentID) + if err != nil { + return false, err + } + + client := GetHttpClient(timeout, transport) + + resp, err := client.Do(request) + if err != nil { + return false, err + } + + defer func() { + _ = resp.Body.Close() + }() + + if resp.StatusCode != http.StatusOK && resp.StatusCode != http.StatusAccepted && resp.StatusCode != http.StatusNoContent { + return false, fmt.Errorf("%w, status_code: %d", ErrUnexepectedResponseCode, resp.StatusCode) + } + + return true, nil +} + +func buildRequest(ctx context.Context, endpointURL, method, userAgent, license, agentID string) (*http.Request, error) { + request, err := http.NewRequest(method, endpointURL, nil) + if err != nil { + return nil, fmt.Errorf("unable to prepare availability request: %v, error: %w", request, err) + } + + request = request.WithContext(ctx) + request.Header.Set("Content-Type", "application/json") + request.Header.Set("User-Agent", userAgent) + request.Header.Set(LicenseHeader, license) + request.Header.Set(EntityKeyHeader, agentID) + + return request, nil } diff --git a/pkg/config/config.go b/pkg/config/config.go index a322022c1..902e4566e 100644 --- a/pkg/config/config.go +++ b/pkg/config/config.go @@ -979,11 +979,16 @@ type Config struct { // Public: Yes StatusServerPort int `yaml:"status_server_port" envconfig:"status_server_port"` - // StatusServerPort Set the port for status server. + // StatusEndpoints Status endpoints to check reachability. // Default: IdentityURL, CommandChannelURL, MetricsIngestURL, InventoryIngestURL // Public: Yes StatusEndpoints []string `yaml:"status_endpoints" envconfig:"status_endpoints"` + // HealthEndpoint to check backend connection healthiness. + // Default: CommandChannelURL + // Public: Yes + HealthEndpoint string `envconfig:"health_endpoint" yaml:"health_endpoint"` + // AppDataDir This option is only for Windows. It defines the path to store data in a different path than the // program files directory. // - %AppDir%/data: used for storing the delta data. @@ -2166,6 +2171,10 @@ func NormalizeConfig(cfg *Config, cfgMetadata config_loader.YAMLMetadata) (err e } } + if cfg.HealthEndpoint == "" { + cfg.HealthEndpoint = cfg.CommandChannelURL + cfg.CommandChannelEndpoint + } + // MetricsIngestEndpoint default value defined in NewConfig nlog.WithField("MetricsIngestEndpoint", cfg.MetricsIngestEndpoint). Debug("Metrics ingest endpoint.")