Skip to content

Commit

Permalink
Enrich Dex logs with real IP and request ID (#3661)
Browse files Browse the repository at this point in the history
Signed-off-by: m.nabokikh <[email protected]>
Signed-off-by: Maksim Nabokikh <[email protected]>
Co-authored-by: Márk Sági-Kazár <[email protected]>
  • Loading branch information
nabokihms and sagikazarmark authored Aug 1, 2024
1 parent 6ceb265 commit 2256607
Show file tree
Hide file tree
Showing 11 changed files with 333 additions and 187 deletions.
42 changes: 33 additions & 9 deletions cmd/dex/config.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ import (
"fmt"
"log/slog"
"net/http"
"net/netip"
"os"
"slices"
"strings"
Expand Down Expand Up @@ -182,15 +183,38 @@ type OAuth2 struct {

// Web is the config format for the HTTP server.
type Web struct {
HTTP string `json:"http"`
HTTPS string `json:"https"`
Headers Headers `json:"headers"`
TLSCert string `json:"tlsCert"`
TLSKey string `json:"tlsKey"`
TLSMinVersion string `json:"tlsMinVersion"`
TLSMaxVersion string `json:"tlsMaxVersion"`
AllowedOrigins []string `json:"allowedOrigins"`
AllowedHeaders []string `json:"allowedHeaders"`
HTTP string `json:"http"`
HTTPS string `json:"https"`
Headers Headers `json:"headers"`
TLSCert string `json:"tlsCert"`
TLSKey string `json:"tlsKey"`
TLSMinVersion string `json:"tlsMinVersion"`
TLSMaxVersion string `json:"tlsMaxVersion"`
AllowedOrigins []string `json:"allowedOrigins"`
AllowedHeaders []string `json:"allowedHeaders"`
ClientRemoteIP ClientRemoteIP `json:"clientRemoteIP"`
}

type ClientRemoteIP struct {
Header string `json:"header"`
TrustedProxies []string `json:"trustedProxies"`
}

func (cr *ClientRemoteIP) ParseTrustedProxies() ([]netip.Prefix, error) {
if cr == nil {
return nil, nil
}

trusted := make([]netip.Prefix, 0, len(cr.TrustedProxies))
for _, cidr := range cr.TrustedProxies {
ipNet, err := netip.ParsePrefix(cidr)
if err != nil {
return nil, fmt.Errorf("failed to parse CIDR %q: %v", cidr, err)
}
trusted = append(trusted, ipNet)
}

return trusted, nil
}

type Headers struct {
Expand Down
67 changes: 67 additions & 0 deletions cmd/dex/logger.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,67 @@
package main

import (
"context"
"fmt"
"log/slog"
"os"
"strings"

"github.com/dexidp/dex/server"
)

var logFormats = []string{"json", "text"}

func newLogger(level slog.Level, format string) (*slog.Logger, error) {
var handler slog.Handler
switch strings.ToLower(format) {
case "", "text":
handler = slog.NewTextHandler(os.Stderr, &slog.HandlerOptions{
Level: level,
})
case "json":
handler = slog.NewJSONHandler(os.Stderr, &slog.HandlerOptions{
Level: level,
})
default:
return nil, fmt.Errorf("log format is not one of the supported values (%s): %s", strings.Join(logFormats, ", "), format)
}

return slog.New(newRequestContextHandler(handler)), nil
}

var _ slog.Handler = requestContextHandler{}

type requestContextHandler struct {
handler slog.Handler
}

func newRequestContextHandler(handler slog.Handler) slog.Handler {
return requestContextHandler{
handler: handler,
}
}

func (h requestContextHandler) Enabled(ctx context.Context, level slog.Level) bool {
return h.handler.Enabled(ctx, level)
}

func (h requestContextHandler) Handle(ctx context.Context, record slog.Record) error {
if v, ok := ctx.Value(server.RequestKeyRemoteIP).(string); ok {
record.AddAttrs(slog.String(string(server.RequestKeyRemoteIP), v))
}

if v, ok := ctx.Value(server.RequestKeyRequestID).(string); ok {
record.AddAttrs(slog.String(string(server.RequestKeyRequestID), v))
}

return h.handler.Handle(ctx, record)
}

func (h requestContextHandler) WithAttrs(attrs []slog.Attr) slog.Handler {
return requestContextHandler{h.handler.WithAttrs(attrs)}
}

func (h requestContextHandler) WithGroup(name string) slog.Handler {
return h.handler.WithGroup(name)
}
27 changes: 7 additions & 20 deletions cmd/dex/serve.go
Original file line number Diff line number Diff line change
Expand Up @@ -348,6 +348,13 @@ func runServe(options serveOptions) error {
}

serverConfig.RefreshTokenPolicy = refreshTokenPolicy

serverConfig.RealIPHeader = c.Web.ClientRemoteIP.Header
serverConfig.TrustedRealIPCIDRs, err = c.Web.ClientRemoteIP.ParseTrustedProxies()
if err != nil {
return fmt.Errorf("failed to parse client remote IP settings: %v", err)
}

serv, err := server.NewServer(context.Background(), serverConfig)
if err != nil {
return fmt.Errorf("failed to initialize server: %v", err)
Expand Down Expand Up @@ -528,26 +535,6 @@ func runServe(options serveOptions) error {
return nil
}

var logFormats = []string{"json", "text"}

func newLogger(level slog.Level, format string) (*slog.Logger, error) {
var handler slog.Handler
switch strings.ToLower(format) {
case "", "text":
handler = slog.NewTextHandler(os.Stderr, &slog.HandlerOptions{
Level: level,
})
case "json":
handler = slog.NewJSONHandler(os.Stderr, &slog.HandlerOptions{
Level: level,
})
default:
return nil, fmt.Errorf("log format is not one of the supported values (%s): %s", strings.Join(logFormats, ", "), format)
}

return slog.New(handler), nil
}

func applyConfigOverrides(options serveOptions, config *Config) {
if options.webHTTPAddr != "" {
config.Web.HTTP = options.webHTTPAddr
Expand Down
5 changes: 4 additions & 1 deletion examples/config-dev.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,10 @@ web:
# X-XSS-Protection: "1; mode=block"
# Content-Security-Policy: "default-src 'self'"
# Strict-Transport-Security: "max-age=31536000; includeSubDomains"

# clientRemoteIP:
# header: X-Forwarded-For
# trustedProxies:
# - 10.0.0.0/8

# Configuration for dex appearance
# frontend:
Expand Down
38 changes: 19 additions & 19 deletions server/deviceflowhandlers.go
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,7 @@ func (s *Server) handleDeviceExchange(w http.ResponseWriter, r *http.Request) {
invalidAttempt = false
}
if err := s.templates.device(r, w, s.getDeviceVerificationURI(), userCode, invalidAttempt); err != nil {
s.logger.Error("server template error", "err", err)
s.logger.ErrorContext(r.Context(), "server template error", "err", err)
s.renderError(r, w, http.StatusNotFound, "Page not found")
}
default:
Expand All @@ -64,7 +64,7 @@ func (s *Server) handleDeviceCode(w http.ResponseWriter, r *http.Request) {
case http.MethodPost:
err := r.ParseForm()
if err != nil {
s.logger.Error("could not parse Device Request body", "err", err)
s.logger.ErrorContext(r.Context(), "could not parse Device Request body", "err", err)
s.tokenErrHelper(w, errInvalidRequest, "", http.StatusNotFound)
return
}
Expand All @@ -85,7 +85,7 @@ func (s *Server) handleDeviceCode(w http.ResponseWriter, r *http.Request) {
return
}

s.logger.Info("received device request", "client_id", clientID, "scoped", scopes)
s.logger.InfoContext(r.Context(), "received device request", "client_id", clientID, "scoped", scopes)

// Make device code
deviceCode := storage.NewDeviceCode()
Expand All @@ -107,7 +107,7 @@ func (s *Server) handleDeviceCode(w http.ResponseWriter, r *http.Request) {
}

if err := s.storage.CreateDeviceRequest(ctx, deviceReq); err != nil {
s.logger.Error("failed to store device request", "err", err)
s.logger.ErrorContext(r.Context(), "failed to store device request", "err", err)
s.tokenErrHelper(w, errInvalidRequest, "", http.StatusInternalServerError)
return
}
Expand All @@ -126,14 +126,14 @@ func (s *Server) handleDeviceCode(w http.ResponseWriter, r *http.Request) {
}

if err := s.storage.CreateDeviceToken(ctx, deviceToken); err != nil {
s.logger.Error("failed to store device token", "err", err)
s.logger.ErrorContext(r.Context(), "failed to store device token", "err", err)
s.tokenErrHelper(w, errInvalidRequest, "", http.StatusInternalServerError)
return
}

u, err := url.Parse(s.issuerURL.String())
if err != nil {
s.logger.Error("could not parse issuer URL", "err", err)
s.logger.ErrorContext(r.Context(), "could not parse issuer URL", "err", err)
s.tokenErrHelper(w, errInvalidRequest, "", http.StatusInternalServerError)
return
}
Expand Down Expand Up @@ -211,7 +211,7 @@ func (s *Server) handleDeviceToken(w http.ResponseWriter, r *http.Request) {
deviceToken, err := s.storage.GetDeviceToken(deviceCode)
if err != nil {
if err != storage.ErrNotFound {
s.logger.Error("failed to get device code", "err", err)
s.logger.ErrorContext(r.Context(), "failed to get device code", "err", err)
}
s.tokenErrHelper(w, errInvalidRequest, "Invalid Device code.", http.StatusBadRequest)
return
Expand Down Expand Up @@ -241,7 +241,7 @@ func (s *Server) handleDeviceToken(w http.ResponseWriter, r *http.Request) {
}
// Update device token last request time in storage
if err := s.storage.UpdateDeviceToken(deviceCode, updater); err != nil {
s.logger.Error("failed to update device token", "err", err)
s.logger.ErrorContext(r.Context(), "failed to update device token", "err", err)
s.renderError(r, w, http.StatusInternalServerError, "")
return
}
Expand All @@ -258,7 +258,7 @@ func (s *Server) handleDeviceToken(w http.ResponseWriter, r *http.Request) {
case providedCodeVerifier != "" && codeChallengeFromStorage != "":
calculatedCodeChallenge, err := s.calculateCodeChallenge(providedCodeVerifier, deviceToken.PKCE.CodeChallengeMethod)
if err != nil {
s.logger.Error("failed to calculate code challenge", "err", err)
s.logger.ErrorContext(r.Context(), "failed to calculate code challenge", "err", err)
s.tokenErrHelper(w, errServerError, "", http.StatusInternalServerError)
return
}
Expand Down Expand Up @@ -303,7 +303,7 @@ func (s *Server) handleDeviceCallback(w http.ResponseWriter, r *http.Request) {
if err != nil || s.now().After(authCode.Expiry) {
errCode := http.StatusBadRequest
if err != nil && err != storage.ErrNotFound {
s.logger.Error("failed to get auth code", "err", err)
s.logger.ErrorContext(r.Context(), "failed to get auth code", "err", err)
errCode = http.StatusInternalServerError
}
s.renderError(r, w, errCode, "Invalid or expired auth code.")
Expand All @@ -315,7 +315,7 @@ func (s *Server) handleDeviceCallback(w http.ResponseWriter, r *http.Request) {
if err != nil || s.now().After(deviceReq.Expiry) {
errCode := http.StatusBadRequest
if err != nil && err != storage.ErrNotFound {
s.logger.Error("failed to get device code", "err", err)
s.logger.ErrorContext(r.Context(), "failed to get device code", "err", err)
errCode = http.StatusInternalServerError
}
s.renderError(r, w, errCode, "Invalid or expired user code.")
Expand All @@ -325,7 +325,7 @@ func (s *Server) handleDeviceCallback(w http.ResponseWriter, r *http.Request) {
client, err := s.storage.GetClient(deviceReq.ClientID)
if err != nil {
if err != storage.ErrNotFound {
s.logger.Error("failed to get client", "err", err)
s.logger.ErrorContext(r.Context(), "failed to get client", "err", err)
s.tokenErrHelper(w, errServerError, "", http.StatusInternalServerError)
} else {
s.tokenErrHelper(w, errInvalidClient, "Invalid client credentials.", http.StatusUnauthorized)
Expand All @@ -339,7 +339,7 @@ func (s *Server) handleDeviceCallback(w http.ResponseWriter, r *http.Request) {

resp, err := s.exchangeAuthCode(ctx, w, authCode, client)
if err != nil {
s.logger.Error("could not exchange auth code for clien", "client_id", deviceReq.ClientID, "err", err)
s.logger.ErrorContext(r.Context(), "could not exchange auth code for clien", "client_id", deviceReq.ClientID, "err", err)
s.renderError(r, w, http.StatusInternalServerError, "Failed to exchange auth code.")
return
}
Expand All @@ -349,7 +349,7 @@ func (s *Server) handleDeviceCallback(w http.ResponseWriter, r *http.Request) {
if err != nil || s.now().After(old.Expiry) {
errCode := http.StatusBadRequest
if err != nil && err != storage.ErrNotFound {
s.logger.Error("failed to get device token", "err", err)
s.logger.ErrorContext(r.Context(), "failed to get device token", "err", err)
errCode = http.StatusInternalServerError
}
s.renderError(r, w, errCode, "Invalid or expired device code.")
Expand All @@ -362,7 +362,7 @@ func (s *Server) handleDeviceCallback(w http.ResponseWriter, r *http.Request) {
}
respStr, err := json.MarshalIndent(resp, "", " ")
if err != nil {
s.logger.Error("failed to marshal device token response", "err", err)
s.logger.ErrorContext(r.Context(), "failed to marshal device token response", "err", err)
s.renderError(r, w, http.StatusInternalServerError, "")
return old, err
}
Expand All @@ -374,13 +374,13 @@ func (s *Server) handleDeviceCallback(w http.ResponseWriter, r *http.Request) {

// Update refresh token in the storage, store the token and mark as complete
if err := s.storage.UpdateDeviceToken(deviceReq.DeviceCode, updater); err != nil {
s.logger.Error("failed to update device token", "err", err)
s.logger.ErrorContext(r.Context(), "failed to update device token", "err", err)
s.renderError(r, w, http.StatusBadRequest, "")
return
}

if err := s.templates.deviceSuccess(r, w, client.Name); err != nil {
s.logger.Error("Server template error", "err", err)
s.logger.ErrorContext(r.Context(), "Server template error", "err", err)
s.renderError(r, w, http.StatusNotFound, "Page not found")
}

Expand Down Expand Up @@ -412,10 +412,10 @@ func (s *Server) verifyUserCode(w http.ResponseWriter, r *http.Request) {
deviceRequest, err := s.storage.GetDeviceRequest(userCode)
if err != nil || s.now().After(deviceRequest.Expiry) {
if err != nil && err != storage.ErrNotFound {
s.logger.Error("failed to get device request", "err", err)
s.logger.ErrorContext(r.Context(), "failed to get device request", "err", err)
}
if err := s.templates.device(r, w, s.getDeviceVerificationURI(), userCode, true); err != nil {
s.logger.Error("Server template error", "err", err)
s.logger.ErrorContext(r.Context(), "Server template error", "err", err)
s.renderError(r, w, http.StatusNotFound, "Page not found")
}
return
Expand Down
Loading

0 comments on commit 2256607

Please sign in to comment.