Skip to content

Commit

Permalink
Fix
Browse files Browse the repository at this point in the history
  • Loading branch information
wzshiming committed Jan 7, 2025
1 parent 155b326 commit 496e75c
Show file tree
Hide file tree
Showing 6 changed files with 144 additions and 98 deletions.
68 changes: 51 additions & 17 deletions agent/agent.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ package agent

import (
"context"
"encoding/json"
"fmt"
"io"
"log/slog"
Expand Down Expand Up @@ -131,7 +132,7 @@ func (c *Agent) ServeHTTP(rw http.ResponseWriter, r *http.Request) {
}

if r.Method != http.MethodGet && r.Method != http.MethodHead {
errcode.ServeJSON(rw, errcode.ErrorCodeUnsupported)
utils.ServeError(rw, r, errcode.ErrorCodeUnsupported, 0)
return
}

Expand All @@ -157,16 +158,16 @@ func (c *Agent) ServeHTTP(rw http.ResponseWriter, r *http.Request) {
if c.authenticator != nil {
t, err = c.authenticator.Authorization(r)
if err != nil {
errcode.ServeJSON(rw, errcode.ErrorCodeDenied.WithMessage(err.Error()))
utils.ServeError(rw, r, errcode.ErrorCodeDenied.WithMessage(err.Error()), 0)
return
}
}

if t.Block {
if t.BlockMessage != "" {
errcode.ServeJSON(rw, errcode.ErrorCodeDenied.WithMessage(t.BlockMessage))
utils.ServeError(rw, r, errcode.ErrorCodeDenied.WithMessage(t.BlockMessage), 0)
} else {
errcode.ServeJSON(rw, errcode.ErrorCodeDenied)
utils.ServeError(rw, r, errcode.ErrorCodeDenied, 0)
}
return
}
Expand All @@ -187,7 +188,7 @@ func (c *Agent) Serve(rw http.ResponseWriter, r *http.Request, info *BlobInfo, t
value, ok := c.blobCache.Get(info.Blobs)
if ok {
if value.Error != nil {
errcode.ServeJSON(rw, value.Error)
utils.ServeError(rw, r, value.Error, 0)
return ctx, true
}
c.serveCachedBlob(rw, r, info.Blobs, info, t, value.Size, start)
Expand All @@ -204,9 +205,10 @@ func (c *Agent) Serve(rw http.ResponseWriter, r *http.Request, info *BlobInfo, t
if ctx.Err() != nil {
return nil
}
size, err := c.cacheBlob(info)
size, sc, err := c.cacheBlob(info)
if err != nil {
return err
utils.ServeError(rw, r, err, sc)
return nil
}
c.serveCachedBlob(rw, r, info.Blobs, info, t, size, start)
return nil
Expand All @@ -215,7 +217,7 @@ func (c *Agent) Serve(rw http.ResponseWriter, r *http.Request, info *BlobInfo, t
if err != nil {
c.logger.Warn("error response", "remoteAddr", r.RemoteAddr, "error", err)
c.blobCache.PutError(info.Blobs, err)
errcode.ServeJSON(rw, err)
utils.ServeError(rw, r, err, 0)
return
}
}
Expand All @@ -237,7 +239,7 @@ func sleepDuration(ctx context.Context, size, limit float64, start time.Time) er
return nil
}

func (c *Agent) cacheBlob(info *BlobInfo) (int64, error) {
func (c *Agent) cacheBlob(info *BlobInfo) (int64, int, error) {
ctx := context.Background()
u := &url.URL{
Scheme: "https",
Expand All @@ -247,28 +249,60 @@ func (c *Agent) cacheBlob(info *BlobInfo) (int64, error) {
forwardReq, err := http.NewRequestWithContext(ctx, http.MethodGet, u.String(), nil)
if err != nil {
c.logger.Warn("failed to new request", "url", u.String(), "error", err)
return 0, err
return 0, 0, err
}

forwardReq.Header.Set("Accept", "*/*")

resp, err := c.httpClient.Do(forwardReq)
if err != nil {
c.logger.Warn("failed to request", "url", u.String(), "error", err)
return 0, errcode.ErrorCodeUnknown
return 0, 0, errcode.ErrorCodeUnknown
}
defer func() {
resp.Body.Close()
}()

switch resp.StatusCode {
case http.StatusUnauthorized, http.StatusForbidden:
return 0, errcode.ErrorCodeDenied
return 0, 0, errcode.ErrorCodeDenied
}

switch resp.StatusCode {
case http.StatusUnauthorized, http.StatusForbidden:
c.logger.Error("upstream denied", "statusCode", resp.StatusCode, "url", u.String())
return 0, 0, errcode.ErrorCodeDenied
}
if resp.StatusCode < http.StatusOK ||
(resp.StatusCode >= http.StatusMultipleChoices && resp.StatusCode < http.StatusBadRequest) {
c.logger.Error("upstream unkown code", "statusCode", resp.StatusCode, "url", u.String())
return 0, 0, errcode.ErrorCodeUnknown
}

if resp.StatusCode < http.StatusOK || resp.StatusCode >= http.StatusMultipleChoices {
return 0, errcode.ErrorCodeUnknown.WithMessage(fmt.Sprintf("source response code %d: %s", resp.StatusCode, u.String()))
if resp.StatusCode >= http.StatusBadRequest {
body, err := io.ReadAll(io.LimitReader(resp.Body, 1024*1024))
if err != nil {
c.logger.Error("failed to get body", "statusCode", resp.StatusCode, "url", u.String(), "error", err)
return 0, 0, errcode.ErrorCodeUnknown
}
if !json.Valid(body) {
c.logger.Error("invalid body", "statusCode", resp.StatusCode, "url", u.String(), "body", string(body))
return 0, 0, errcode.ErrorCodeDenied
}
var retErrs errcode.Errors
err = retErrs.UnmarshalJSON(body)
if err != nil {
c.logger.Error("failed to unmarshal body", "statusCode", resp.StatusCode, "url", u.String(), "body", string(body))
return 0, 0, errcode.ErrorCodeUnknown
}
return 0, resp.StatusCode, retErrs
}

return c.cache.PutBlob(ctx, info.Blobs, resp.Body)
size, err := c.cache.PutBlob(ctx, info.Blobs, resp.Body)
if err != nil {
return 0, 0, err
}
return size, 0, nil
}

func (c *Agent) serveCachedBlob(rw http.ResponseWriter, r *http.Request, blob string, info *BlobInfo, t *token.Token, size int64, start time.Time) {
Expand All @@ -290,7 +324,7 @@ func (c *Agent) serveCachedBlob(rw http.ResponseWriter, r *http.Request, blob st
if err != nil {
c.logger.Info("failed to get blob", "digest", blob, "error", err)
c.blobCache.Remove(info.Blobs)
errcode.ServeJSON(rw, errcode.ErrorCodeUnknown)
utils.ServeError(rw, r, errcode.ErrorCodeUnknown, 0)
return
}
defer data.Close()
Expand All @@ -312,7 +346,7 @@ func (c *Agent) serveCachedBlob(rw http.ResponseWriter, r *http.Request, blob st
if err != nil {
c.logger.Info("failed to redirect blob", "digest", blob, "error", err)
c.blobCache.Remove(info.Blobs)
errcode.ServeJSON(rw, errcode.ErrorCodeUnknown)
utils.ServeError(rw, r, errcode.ErrorCodeUnknown, 0)
return
}

Expand Down
22 changes: 11 additions & 11 deletions gateway/gateway.go
Original file line number Diff line number Diff line change
Expand Up @@ -173,12 +173,12 @@ func (c *Gateway) ServeHTTP(rw http.ResponseWriter, r *http.Request) {
}

if r.Method != http.MethodGet && r.Method != http.MethodHead {
errcode.ServeJSON(rw, errcode.ErrorCodeUnsupported)
utils.ServeError(rw, r, errcode.ErrorCodeUnsupported, 0)
return
}

if oriPath == catalog {
errcode.ServeJSON(rw, errcode.ErrorCodeUnsupported)
utils.ServeError(rw, r, errcode.ErrorCodeUnsupported, 0)
return
}

Expand Down Expand Up @@ -209,17 +209,17 @@ func (c *Gateway) ServeHTTP(rw http.ResponseWriter, r *http.Request) {
}
if t.Block {
if t.BlockMessage != "" {
errcode.ServeJSON(rw, errcode.ErrorCodeDenied.WithMessage(t.BlockMessage))
utils.ServeError(rw, r, errcode.ErrorCodeDenied.WithMessage(t.BlockMessage), 0)
} else {
errcode.ServeJSON(rw, errcode.ErrorCodeDenied)
utils.ServeError(rw, r, errcode.ErrorCodeDenied, 0)
}
return
}
}

info, ok := parseOriginPathInfo(oriPath)
if !ok {
errcode.ServeJSON(rw, errcode.ErrorCodeDenied)
utils.ServeError(rw, r, errcode.ErrorCodeDenied, 0)
return
}

Expand All @@ -238,14 +238,14 @@ func (c *Gateway) ServeHTTP(rw http.ResponseWriter, r *http.Request) {
}

if info.Host == "" {
errcode.ServeJSON(rw, errcode.ErrorCodeDenied)
utils.ServeError(rw, r, errcode.ErrorCodeDenied, 0)
return
}

if r.URL.RawQuery != "" {
q := r.URL.Query()
if ns := q.Get("ns"); ns != "" && ns != info.Host {
errcode.ServeJSON(rw, errcode.ErrorCodeDenied)
utils.ServeError(rw, r, errcode.ErrorCodeDenied, 0)
return
}
}
Expand Down Expand Up @@ -286,7 +286,7 @@ func (c *Gateway) forward(rw http.ResponseWriter, r *http.Request, info *PathInf
path, err := info.Path()
if err != nil {
c.logger.Warn("failed to get path", "error", err)
errcode.ServeJSON(rw, errcode.ErrorCodeUnknown)
utils.ServeError(rw, r, errcode.ErrorCodeUnknown, 0)
return
}
u := &url.URL{
Expand All @@ -297,14 +297,14 @@ func (c *Gateway) forward(rw http.ResponseWriter, r *http.Request, info *PathInf
forwardReq, err := http.NewRequestWithContext(r.Context(), r.Method, u.String(), nil)
if err != nil {
c.logger.Warn("failed to new request", "error", err)
errcode.ServeJSON(rw, errcode.ErrorCodeUnknown)
utils.ServeError(rw, r, errcode.ErrorCodeUnknown, 0)
return
}

resp, err := c.httpClient.Do(forwardReq)
if err != nil {
c.logger.Warn("failed to request", "host", info.Host, "image", info.Image, "error", err)
errcode.ServeJSON(rw, errcode.ErrorCodeUnknown)
utils.ServeError(rw, r, errcode.ErrorCodeUnknown, 0)
return
}
defer func() {
Expand All @@ -314,7 +314,7 @@ func (c *Gateway) forward(rw http.ResponseWriter, r *http.Request, info *PathInf
switch resp.StatusCode {
case http.StatusUnauthorized, http.StatusForbidden:
c.logger.Warn("origin direct response 40x", "host", info.Host, "image", info.Image, "response", dumpResponse(resp))
errcode.ServeJSON(rw, errcode.ErrorCodeDenied)
utils.ServeError(rw, r, errcode.ErrorCodeDenied, 0)
return
}

Expand Down
Loading

0 comments on commit 496e75c

Please sign in to comment.