diff --git a/agent/agent.go b/agent/agent.go index 2eaf0b5..5f07387 100644 --- a/agent/agent.go +++ b/agent/agent.go @@ -2,6 +2,7 @@ package agent import ( "context" + "encoding/json" "fmt" "io" "log/slog" @@ -131,7 +132,7 @@ func (c *Agent) ServeHTTP(rw http.ResponseWriter, r *http.Request) { } if r.Method != http.MethodGet && r.Method != http.MethodHead { - errcode.ServeJSON(rw, errcode.ErrorCodeUnsupported) + utils.ServeError(rw, r, errcode.ErrorCodeUnsupported, 0) return } @@ -157,16 +158,16 @@ func (c *Agent) ServeHTTP(rw http.ResponseWriter, r *http.Request) { if c.authenticator != nil { t, err = c.authenticator.Authorization(r) if err != nil { - errcode.ServeJSON(rw, errcode.ErrorCodeDenied.WithMessage(err.Error())) + utils.ServeError(rw, r, errcode.ErrorCodeDenied.WithMessage(err.Error()), 0) return } } if t.Block { if t.BlockMessage != "" { - errcode.ServeJSON(rw, errcode.ErrorCodeDenied.WithMessage(t.BlockMessage)) + utils.ServeError(rw, r, errcode.ErrorCodeDenied.WithMessage(t.BlockMessage), 0) } else { - errcode.ServeJSON(rw, errcode.ErrorCodeDenied) + utils.ServeError(rw, r, errcode.ErrorCodeDenied, 0) } return } @@ -187,7 +188,7 @@ func (c *Agent) Serve(rw http.ResponseWriter, r *http.Request, info *BlobInfo, t value, ok := c.blobCache.Get(info.Blobs) if ok { if value.Error != nil { - errcode.ServeJSON(rw, value.Error) + utils.ServeError(rw, r, value.Error, 0) return ctx, true } c.serveCachedBlob(rw, r, info.Blobs, info, t, value.Size, start) @@ -204,9 +205,10 @@ func (c *Agent) Serve(rw http.ResponseWriter, r *http.Request, info *BlobInfo, t if ctx.Err() != nil { return nil } - size, err := c.cacheBlob(info) + size, sc, err := c.cacheBlob(info) if err != nil { - return err + utils.ServeError(rw, r, err, sc) + return nil } c.serveCachedBlob(rw, r, info.Blobs, info, t, size, start) return nil @@ -215,7 +217,7 @@ func (c *Agent) Serve(rw http.ResponseWriter, r *http.Request, info *BlobInfo, t if err != nil { c.logger.Warn("error response", "remoteAddr", r.RemoteAddr, "error", err) c.blobCache.PutError(info.Blobs, err) - errcode.ServeJSON(rw, err) + utils.ServeError(rw, r, err, 0) return } } @@ -237,7 +239,7 @@ func sleepDuration(ctx context.Context, size, limit float64, start time.Time) er return nil } -func (c *Agent) cacheBlob(info *BlobInfo) (int64, error) { +func (c *Agent) cacheBlob(info *BlobInfo) (int64, int, error) { ctx := context.Background() u := &url.URL{ Scheme: "https", @@ -247,13 +249,15 @@ func (c *Agent) cacheBlob(info *BlobInfo) (int64, error) { forwardReq, err := http.NewRequestWithContext(ctx, http.MethodGet, u.String(), nil) if err != nil { c.logger.Warn("failed to new request", "url", u.String(), "error", err) - return 0, err + return 0, 0, err } + forwardReq.Header.Set("Accept", "*/*") + resp, err := c.httpClient.Do(forwardReq) if err != nil { c.logger.Warn("failed to request", "url", u.String(), "error", err) - return 0, errcode.ErrorCodeUnknown + return 0, 0, errcode.ErrorCodeUnknown } defer func() { resp.Body.Close() @@ -261,14 +265,44 @@ func (c *Agent) cacheBlob(info *BlobInfo) (int64, error) { switch resp.StatusCode { case http.StatusUnauthorized, http.StatusForbidden: - return 0, errcode.ErrorCodeDenied + return 0, 0, errcode.ErrorCodeDenied + } + + switch resp.StatusCode { + case http.StatusUnauthorized, http.StatusForbidden: + c.logger.Error("upstream denied", "statusCode", resp.StatusCode, "url", u.String()) + return 0, 0, errcode.ErrorCodeDenied + } + if resp.StatusCode < http.StatusOK || + (resp.StatusCode >= http.StatusMultipleChoices && resp.StatusCode < http.StatusBadRequest) { + c.logger.Error("upstream unkown code", "statusCode", resp.StatusCode, "url", u.String()) + return 0, 0, errcode.ErrorCodeUnknown } - if resp.StatusCode < http.StatusOK || resp.StatusCode >= http.StatusMultipleChoices { - return 0, errcode.ErrorCodeUnknown.WithMessage(fmt.Sprintf("source response code %d: %s", resp.StatusCode, u.String())) + if resp.StatusCode >= http.StatusBadRequest { + body, err := io.ReadAll(io.LimitReader(resp.Body, 1024*1024)) + if err != nil { + c.logger.Error("failed to get body", "statusCode", resp.StatusCode, "url", u.String(), "error", err) + return 0, 0, errcode.ErrorCodeUnknown + } + if !json.Valid(body) { + c.logger.Error("invalid body", "statusCode", resp.StatusCode, "url", u.String(), "body", string(body)) + return 0, 0, errcode.ErrorCodeDenied + } + var retErrs errcode.Errors + err = retErrs.UnmarshalJSON(body) + if err != nil { + c.logger.Error("failed to unmarshal body", "statusCode", resp.StatusCode, "url", u.String(), "body", string(body)) + return 0, 0, errcode.ErrorCodeUnknown + } + return 0, resp.StatusCode, retErrs } - return c.cache.PutBlob(ctx, info.Blobs, resp.Body) + size, err := c.cache.PutBlob(ctx, info.Blobs, resp.Body) + if err != nil { + return 0, 0, err + } + return size, 0, nil } func (c *Agent) serveCachedBlob(rw http.ResponseWriter, r *http.Request, blob string, info *BlobInfo, t *token.Token, size int64, start time.Time) { @@ -290,7 +324,7 @@ func (c *Agent) serveCachedBlob(rw http.ResponseWriter, r *http.Request, blob st if err != nil { c.logger.Info("failed to get blob", "digest", blob, "error", err) c.blobCache.Remove(info.Blobs) - errcode.ServeJSON(rw, errcode.ErrorCodeUnknown) + utils.ServeError(rw, r, errcode.ErrorCodeUnknown, 0) return } defer data.Close() @@ -312,7 +346,7 @@ func (c *Agent) serveCachedBlob(rw http.ResponseWriter, r *http.Request, blob st if err != nil { c.logger.Info("failed to redirect blob", "digest", blob, "error", err) c.blobCache.Remove(info.Blobs) - errcode.ServeJSON(rw, errcode.ErrorCodeUnknown) + utils.ServeError(rw, r, errcode.ErrorCodeUnknown, 0) return } diff --git a/gateway/gateway.go b/gateway/gateway.go index 8a9e750..92f7784 100644 --- a/gateway/gateway.go +++ b/gateway/gateway.go @@ -173,12 +173,12 @@ func (c *Gateway) ServeHTTP(rw http.ResponseWriter, r *http.Request) { } if r.Method != http.MethodGet && r.Method != http.MethodHead { - errcode.ServeJSON(rw, errcode.ErrorCodeUnsupported) + utils.ServeError(rw, r, errcode.ErrorCodeUnsupported, 0) return } if oriPath == catalog { - errcode.ServeJSON(rw, errcode.ErrorCodeUnsupported) + utils.ServeError(rw, r, errcode.ErrorCodeUnsupported, 0) return } @@ -209,9 +209,9 @@ func (c *Gateway) ServeHTTP(rw http.ResponseWriter, r *http.Request) { } if t.Block { if t.BlockMessage != "" { - errcode.ServeJSON(rw, errcode.ErrorCodeDenied.WithMessage(t.BlockMessage)) + utils.ServeError(rw, r, errcode.ErrorCodeDenied.WithMessage(t.BlockMessage), 0) } else { - errcode.ServeJSON(rw, errcode.ErrorCodeDenied) + utils.ServeError(rw, r, errcode.ErrorCodeDenied, 0) } return } @@ -219,7 +219,7 @@ func (c *Gateway) ServeHTTP(rw http.ResponseWriter, r *http.Request) { info, ok := parseOriginPathInfo(oriPath) if !ok { - errcode.ServeJSON(rw, errcode.ErrorCodeDenied) + utils.ServeError(rw, r, errcode.ErrorCodeDenied, 0) return } @@ -238,14 +238,14 @@ func (c *Gateway) ServeHTTP(rw http.ResponseWriter, r *http.Request) { } if info.Host == "" { - errcode.ServeJSON(rw, errcode.ErrorCodeDenied) + utils.ServeError(rw, r, errcode.ErrorCodeDenied, 0) return } if r.URL.RawQuery != "" { q := r.URL.Query() if ns := q.Get("ns"); ns != "" && ns != info.Host { - errcode.ServeJSON(rw, errcode.ErrorCodeDenied) + utils.ServeError(rw, r, errcode.ErrorCodeDenied, 0) return } } @@ -286,7 +286,7 @@ func (c *Gateway) forward(rw http.ResponseWriter, r *http.Request, info *PathInf path, err := info.Path() if err != nil { c.logger.Warn("failed to get path", "error", err) - errcode.ServeJSON(rw, errcode.ErrorCodeUnknown) + utils.ServeError(rw, r, errcode.ErrorCodeUnknown, 0) return } u := &url.URL{ @@ -297,14 +297,14 @@ func (c *Gateway) forward(rw http.ResponseWriter, r *http.Request, info *PathInf forwardReq, err := http.NewRequestWithContext(r.Context(), r.Method, u.String(), nil) if err != nil { c.logger.Warn("failed to new request", "error", err) - errcode.ServeJSON(rw, errcode.ErrorCodeUnknown) + utils.ServeError(rw, r, errcode.ErrorCodeUnknown, 0) return } resp, err := c.httpClient.Do(forwardReq) if err != nil { c.logger.Warn("failed to request", "host", info.Host, "image", info.Image, "error", err) - errcode.ServeJSON(rw, errcode.ErrorCodeUnknown) + utils.ServeError(rw, r, errcode.ErrorCodeUnknown, 0) return } defer func() { @@ -314,7 +314,7 @@ func (c *Gateway) forward(rw http.ResponseWriter, r *http.Request, info *PathInf switch resp.StatusCode { case http.StatusUnauthorized, http.StatusForbidden: c.logger.Warn("origin direct response 40x", "host", info.Host, "image", info.Image, "response", dumpResponse(resp)) - errcode.ServeJSON(rw, errcode.ErrorCodeDenied) + utils.ServeError(rw, r, errcode.ErrorCodeDenied, 0) return } diff --git a/gateway/manifest.go b/gateway/manifest.go index 8b1bb35..8ef7b69 100644 --- a/gateway/manifest.go +++ b/gateway/manifest.go @@ -9,6 +9,7 @@ import ( "net/url" "strconv" + "github.com/daocloud/crproxy/internal/utils" "github.com/docker/distribution/registry/api/errcode" ) @@ -39,9 +40,11 @@ func (c *Gateway) cacheManifestResponse(rw http.ResponseWriter, r *http.Request, if ctx.Err() != nil { return errcode.ErrorCodeUnknown } - err := c.cacheManifest(info) + sc, err := c.cacheManifest(info) if err != nil { - return err + c.manifestCache.PutError(info, err, sc) + utils.ServeError(rw, r, err, sc) + return nil } if ctx.Err() != nil { return errcode.ErrorCodeUnknown @@ -57,12 +60,13 @@ func (c *Gateway) cacheManifestResponse(rw http.ResponseWriter, r *http.Request, return } - c.serveError(rw, r, info, err) + c.manifestCache.PutError(info, err, 0) + utils.ServeError(rw, r, err, 0) return } } -func (c *Gateway) cacheManifest(info *PathInfo) error { +func (c *Gateway) cacheManifest(info *PathInfo) (int, error) { ctx := context.Background() u := &url.URL{ Scheme: "https", @@ -73,30 +77,30 @@ func (c *Gateway) cacheManifest(info *PathInfo) error { if !info.IsDigestManifests { forwardReq, err := http.NewRequestWithContext(ctx, http.MethodHead, u.String(), nil) if err != nil { - return err + return 0, err } // Never trust a client's Accept !!! forwardReq.Header.Set("Accept", c.acceptsStr) resp, err := c.httpClient.Do(forwardReq) if err != nil { - return err + return 0, err } if resp.Body != nil { resp.Body.Close() } switch resp.StatusCode { case http.StatusUnauthorized, http.StatusForbidden: - return errcode.ErrorCodeDenied + return 0, errcode.ErrorCodeDenied } if resp.StatusCode < http.StatusOK || (resp.StatusCode >= http.StatusMultipleChoices && resp.StatusCode < http.StatusBadRequest) { - return errcode.ErrorCodeUnknown + return 0, errcode.ErrorCodeUnknown } digest := resp.Header.Get("Docker-Content-Digest") if digest == "" { - return errcode.ErrorCodeDenied + return 0, errcode.ErrorCodeDenied } err = c.cache.RelinkManifest(ctx, info.Host, info.Image, info.Manifests, digest) @@ -104,14 +108,14 @@ func (c *Gateway) cacheManifest(info *PathInfo) error { c.logger.Warn("failed relink manifest", "url", u.String(), "error", err) } else { c.logger.Info("relink manifest", "url", u.String()) - return nil + return 0, nil } u.Path = fmt.Sprintf("/v2/%s/manifests/%s", info.Image, digest) } forwardReq, err := http.NewRequestWithContext(ctx, http.MethodGet, u.String(), nil) if err != nil { - return err + return 0, err } // Never trust a client's Accept !!! forwardReq.Header.Set("Accept", c.acceptsStr) @@ -119,7 +123,7 @@ func (c *Gateway) cacheManifest(info *PathInfo) error { resp, err := c.httpClient.Do(forwardReq) if err != nil { c.logger.Warn("failed to request", "url", u.String(), "error", err) - return errcode.ErrorCodeUnknown + return 0, errcode.ErrorCodeUnknown } defer func() { resp.Body.Close() @@ -128,22 +132,22 @@ func (c *Gateway) cacheManifest(info *PathInfo) error { switch resp.StatusCode { case http.StatusUnauthorized, http.StatusForbidden: c.logger.Error("upstream denied", "statusCode", resp.StatusCode, "url", u.String(), "response", dumpResponse(resp)) - return errcode.ErrorCodeDenied + return 0, errcode.ErrorCodeDenied } if resp.StatusCode < http.StatusOK || (resp.StatusCode >= http.StatusMultipleChoices && resp.StatusCode < http.StatusBadRequest) { c.logger.Error("upstream unkown code", "statusCode", resp.StatusCode, "url", u.String(), "response", dumpResponse(resp)) - return errcode.ErrorCodeUnknown + return 0, errcode.ErrorCodeUnknown } body, err := io.ReadAll(io.LimitReader(resp.Body, 1024*1024)) if err != nil { c.logger.Error("failed to get body", "statusCode", resp.StatusCode, "url", u.String(), "error", err) - return errcode.ErrorCodeUnknown + return 0, errcode.ErrorCodeUnknown } if !json.Valid(body) { c.logger.Error("invalid body", "statusCode", resp.StatusCode, "url", u.String(), "body", string(body)) - return errcode.ErrorCodeDenied + return 0, errcode.ErrorCodeDenied } if resp.StatusCode >= http.StatusBadRequest { @@ -151,17 +155,16 @@ func (c *Gateway) cacheManifest(info *PathInfo) error { err = retErrs.UnmarshalJSON(body) if err != nil { c.logger.Error("failed to unmarshal body", "statusCode", resp.StatusCode, "url", u.String(), "body", string(body)) - return errcode.ErrorCodeUnknown + return 0, errcode.ErrorCodeUnknown } - err = append(errcode.Errors{errcode.ErrorCode(resp.StatusCode)}, retErrs...) - return err + return resp.StatusCode, retErrs } _, _, err = c.cache.PutManifestContent(ctx, info.Host, info.Image, info.Manifests, body) if err != nil { - return err + return 0, err } - return nil + return 0, nil } func (c *Gateway) tryFirstServeCachedManifest(rw http.ResponseWriter, r *http.Request, info *PathInfo) (done bool, fallback bool) { @@ -173,7 +176,7 @@ func (c *Gateway) tryFirstServeCachedManifest(rw http.ResponseWriter, r *http.Re return false, true } if val.Error != nil { - errcode.ServeJSON(rw, val.Error) + utils.ServeError(rw, r, val.Error, val.StatusCode) return true, false } @@ -221,36 +224,3 @@ func (c *Gateway) serveCachedManifest(rw http.ResponseWriter, r *http.Request, i return true } - -func (c *Gateway) serveError(rw http.ResponseWriter, r *http.Request, info *PathInfo, err error) error { - rw.Header().Set("Content-Type", "application/json; charset=utf-8") - var sc int - - switch errs := err.(type) { - case errcode.Errors: - if len(errs) < 1 { - break - } - - if err, ok := errs[0].(errcode.ErrorCoder); ok { - sc = err.ErrorCode().Descriptor().HTTPStatusCode - } - case errcode.ErrorCoder: - sc = errs.ErrorCode().Descriptor().HTTPStatusCode - err = errcode.Errors{err} // create an envelope. - default: - err = errcode.Errors{err} - } - - if sc == 0 { - sc = http.StatusInternalServerError - } - - rw.WriteHeader(sc) - - c.manifestCache.PutError(info, err) - - c.logger.Warn("error response", "remoteAddr", r.RemoteAddr, "error", err.Error()) - - return json.NewEncoder(rw).Encode(err) -} diff --git a/gateway/manifest_cache.go b/gateway/manifest_cache.go index ae86959..3c52b20 100644 --- a/gateway/manifest_cache.go +++ b/gateway/manifest_cache.go @@ -51,7 +51,8 @@ func (m *manifestCache) Get(info *PathInfo) (cacheValue, bool) { } if val.Error != nil { return cacheValue{ - Error: val.Error, + Error: val.Error, + StatusCode: val.StatusCode, }, true } key.Tag = val.Digest @@ -62,7 +63,8 @@ func (m *manifestCache) Get(info *PathInfo) (cacheValue, bool) { } if val.Error != nil { return cacheValue{ - Error: val.Error, + Error: val.Error, + StatusCode: val.StatusCode, }, true } @@ -73,15 +75,17 @@ func (m *manifestCache) Get(info *PathInfo) (cacheValue, bool) { }, true } -func (m *manifestCache) PutError(info *PathInfo, err error) { +func (m *manifestCache) PutError(info *PathInfo, err error, sc int) { key := manifestCacheKey(info) if !info.IsDigestManifests { m.tag.SetWithTTL(key, cacheTagValue{ - Error: err, + Error: err, + StatusCode: sc, }, m.duration) } else { m.digest.SetWithTTL(key, cacheDigestValue{ - Error: err, + Error: err, + StatusCode: sc, }, m.duration) } } @@ -112,21 +116,24 @@ type cacheKey struct { } type cacheTagValue struct { - Digest string - Error error + Digest string + Error error + StatusCode int } type cacheDigestValue struct { - MediaType string - Length string - Error error + MediaType string + Length string + Error error + StatusCode int } type cacheValue struct { - Digest string - MediaType string - Length string - Error error + Digest string + MediaType string + Length string + Error error + StatusCode int } func manifestCacheKey(info *PathInfo) cacheKey { diff --git a/gateway/manifest_cache_test.go b/gateway/manifest_cache_test.go index a22de1f..08d99ba 100644 --- a/gateway/manifest_cache_test.go +++ b/gateway/manifest_cache_test.go @@ -21,7 +21,7 @@ func TestManifestCache(t *testing.T) { } err := errors.New("test error") - cache.PutError(info, err) + cache.PutError(info, err, 0) retrievedVal, ok = cache.Get(info) if !ok || retrievedVal.Error == nil { t.Error("Expected an error to be returned") diff --git a/internal/utils/utils.go b/internal/utils/utils.go index 5533c78..170a378 100644 --- a/internal/utils/utils.go +++ b/internal/utils/utils.go @@ -1,9 +1,12 @@ package utils import ( + "encoding/json" "fmt" "net" "net/http" + + "github.com/docker/distribution/registry/api/errcode" ) func ResponseAPIBase(w http.ResponseWriter, r *http.Request) { @@ -30,3 +33,35 @@ func GetIP(str string) string { } return str } + +func ServeError(rw http.ResponseWriter, r *http.Request, err error, sc int) error { + rw.Header().Set("Content-Type", "application/json; charset=utf-8") + + switch errs := err.(type) { + case errcode.Errors: + if len(errs) < 1 { + break + } + + if err, ok := errs[0].(errcode.ErrorCoder); ok { + sc = err.ErrorCode().Descriptor().HTTPStatusCode + } + case errcode.ErrorCoder: + sc = errs.ErrorCode().Descriptor().HTTPStatusCode + err = errcode.Errors{err} // create an envelope. + default: + err = errcode.Errors{err} + } + + if sc == 0 { + if r.Context().Err() != nil { + sc = 499 + } else { + sc = http.StatusInternalServerError + } + } + + rw.WriteHeader(sc) + + return json.NewEncoder(rw).Encode(err) +}