Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: Add support for header in rule matching #967

Open
wants to merge 9 commits into
base: master
Choose a base branch
from
406 changes: 189 additions & 217 deletions .schema/version.schema.json

Large diffs are not rendered by default.

2 changes: 1 addition & 1 deletion api/decision.go
Original file line number Diff line number Diff line change
@@ -83,7 +83,7 @@ func (h *DecisionHandler) decisions(w http.ResponseWriter, r *http.Request) {
fields["subject"] = sess.Subject
}

rl, err := h.r.RuleMatcher().Match(r.Context(), r.Method, r.URL, rule.ProtocolHTTP)
rl, err := h.r.RuleMatcher().Match(r.Context(), r.Method, r.URL, r.Header, rule.ProtocolHTTP)
if err != nil {
h.r.Logger().WithError(err).
WithFields(fields).
2 changes: 1 addition & 1 deletion api/decision_test.go
Original file line number Diff line number Diff line change
@@ -352,7 +352,7 @@ func (*decisionHandlerRegistryMock) Logger() *logrusx.Logger {
return logrusx.New("", "")
}

func (m *decisionHandlerRegistryMock) Match(ctx context.Context, method string, u *url.URL, _ rule.Protocol) (*rule.Rule, error) {
func (m *decisionHandlerRegistryMock) Match(ctx context.Context, method string, u *url.URL, _ http.Header, _ rule.Protocol) (*rule.Rule, error) {
args := m.Called(ctx, method, u)
return args.Get(0).(*rule.Rule), args.Error(1)
}
4 changes: 2 additions & 2 deletions middleware/grpc_middleware.go
Original file line number Diff line number Diff line change
@@ -89,7 +89,7 @@ func (m *middleware) unaryInterceptor(ctx context.Context, req interface{}, info

log.Debug("matching HTTP request build from gRPC")

r, err := m.RuleMatcher().Match(traceCtx, httpReq.Method, httpReq.URL, rule.ProtocolGRPC)
r, err := m.RuleMatcher().Match(traceCtx, httpReq.Method, httpReq.URL, httpReq.Header, rule.ProtocolGRPC)
if err != nil {
log.WithError(err).Warn("could not find a matching rule")
span.SetAttributes(attribute.String("oathkeeper.verdict", "denied"))
@@ -138,7 +138,7 @@ func (m *middleware) streamInterceptor(

log.Debug("matching HTTP request build from gRPC")

r, err := m.RuleMatcher().Match(ctx, httpReq.Method, httpReq.URL, rule.ProtocolGRPC)
r, err := m.RuleMatcher().Match(ctx, httpReq.Method, httpReq.URL, httpReq.Header, rule.ProtocolGRPC)
if err != nil {
log.WithError(err).Warn("could not find a matching rule")
span.SetAttributes(attribute.String("oathkeeper.verdict", "denied"))
2 changes: 1 addition & 1 deletion proxy/proxy.go
Original file line number Diff line number Diff line change
@@ -108,7 +108,7 @@ func (d *Proxy) RoundTrip(r *http.Request) (*http.Response, error) {

func (d *Proxy) Rewrite(r *httputil.ProxyRequest) {
EnrichRequestedURL(r)
rl, err := d.r.RuleMatcher().Match(r.Out.Context(), r.Out.Method, r.Out.URL, rule.ProtocolHTTP)
rl, err := d.r.RuleMatcher().Match(r.Out.Context(), r.Out.Method, r.Out.URL, r.Out.Header, rule.ProtocolHTTP)
if err != nil {
*r.Out = *r.Out.WithContext(context.WithValue(r.Out.Context(), director, err))
return
3 changes: 2 additions & 1 deletion rule/matcher.go
Original file line number Diff line number Diff line change
@@ -5,14 +5,15 @@ package rule

import (
"context"
"net/http"
"net/url"
)

type (
Protocol int

Matcher interface {
Match(ctx context.Context, method string, u *url.URL, protocol Protocol) (*Rule, error)
Match(ctx context.Context, method string, u *url.URL, headers http.Header, protocol Protocol) (*Rule, error)
}
)

88 changes: 61 additions & 27 deletions rule/matcher_test.go
Original file line number Diff line number Diff line change
@@ -6,6 +6,7 @@ package rule
import (
"context"
"fmt"
"net/http"
"net/url"
"testing"

@@ -49,6 +50,15 @@ var testRules = []Rule{
Mutators: []Handler{{Handler: "id_token"}},
Upstream: Upstream{URL: "http://localhost:3333/", StripPath: "/foo", PreserveHost: false},
},
{
ID: "foo4",
Match: &Match{URL: "https://localhost:343/<baz|bar>", Methods: []string{"PATCH"}, Headers: http.Header{"Content-Type": {"application/some-app.v2+json"}}},
Description: "Patch users rule for version 2",
Authorizer: Handler{Handler: "deny"},
Authenticators: []Handler{{Handler: "oauth2_introspection"}},
Mutators: []Handler{{Handler: "id_token"}},
Upstream: Upstream{URL: "http://localhost:3333/", StripPath: "/foo", PreserveHost: false},
},
{
ID: "grpc1",
Match: &MatchGRPC{Authority: "<baz|bar>.example.com", FullMethod: "grpc.api/Call"},
@@ -88,6 +98,15 @@ var testRulesGlob = []Rule{
Mutators: []Handler{{Handler: "id_token"}},
Upstream: Upstream{URL: "http://localhost:3333/", StripPath: "/foo", PreserveHost: false},
},
{
ID: "foo4",
Match: &Match{URL: "https://localhost:343/<{baz*,bar*}>", Methods: []string{"PATCH"}, Headers: http.Header{"Content-Type": {"application/some-app.v2+json"}}},
Description: "Patch users rule with version 2",
Authorizer: Handler{Handler: "deny"},
Authenticators: []Handler{{Handler: "oauth2_introspection"}},
Mutators: []Handler{{Handler: "id_token"}},
Upstream: Upstream{URL: "http://localhost:3333/", StripPath: "/foo", PreserveHost: false},
},
{
ID: "grpc1",
Match: &MatchGRPC{Authority: "<{baz*,bar*}>.example.com", FullMethod: "grpc.api/Call"},
@@ -97,6 +116,15 @@ var testRulesGlob = []Rule{
Mutators: []Handler{{Handler: "id_token", Config: []byte(`{"issuer":"anything"}`)}},
Upstream: Upstream{URL: "http://bar.example.com/", PreserveHost: false},
},
{
ID: "grpc2",
Match: &MatchGRPC{Authority: "<{baz*,bar*}>.example.com", FullMethod: "grpc.api/CallWithHeader", Headers: http.Header{"Content-Type": {"application/some-app.v2+json"}}},
Description: "gRPC Rule with version 2",
Authorizer: Handler{Handler: "allow", Config: []byte(`{"type":"any"}`)},
Authenticators: []Handler{{Handler: "anonymous", Config: []byte(`{"name":"anonymous1"}`)}},
Mutators: []Handler{{Handler: "id_token", Config: []byte(`{"issuer":"anything"}`)}},
Upstream: Upstream{URL: "http://bar.example.com/", PreserveHost: false},
},
}

func TestMatcher(t *testing.T) {
@@ -105,8 +133,8 @@ func TestMatcher(t *testing.T) {
Repository
}

var testMatcher = func(t *testing.T, matcher Matcher, method string, url string, protocol Protocol, expectErr bool, expect *Rule) {
r, err := matcher.Match(context.Background(), method, mustParseURL(t, url), protocol)
var testMatcher = func(t *testing.T, matcher Matcher, method string, url string, headers http.Header, protocol Protocol, expectErr bool, expect *Rule) {
r, err := matcher.Match(context.Background(), method, mustParseURL(t, url), headers, protocol)
if expectErr {
require.Error(t, err)
} else {
@@ -121,63 +149,67 @@ func TestMatcher(t *testing.T) {
} {
t.Run(fmt.Sprintf("regexp matcher=%s", name), func(t *testing.T) {
t.Run("case=empty", func(t *testing.T) {
testMatcher(t, matcher, "GET", "https://localhost:34/baz", ProtocolHTTP, true, nil)
testMatcher(t, matcher, "POST", "https://localhost:1234/foo", ProtocolHTTP, true, nil)
testMatcher(t, matcher, "DELETE", "https://localhost:1234/foo", ProtocolHTTP, true, nil)
testMatcher(t, matcher, "GET", "https://localhost:34/baz", http.Header{}, ProtocolHTTP, true, nil)
testMatcher(t, matcher, "POST", "https://localhost:1234/foo", http.Header{}, ProtocolHTTP, true, nil)
testMatcher(t, matcher, "DELETE", "https://localhost:1234/foo", http.Header{}, ProtocolHTTP, true, nil)
})

require.NoError(t, matcher.Set(context.Background(), testRules))

t.Run("case=created", func(t *testing.T) {
testMatcher(t, matcher, "GET", "https://localhost:34/baz", ProtocolHTTP, false, &testRules[1])
testMatcher(t, matcher, "GET", "https://localhost:34/baz", ProtocolGRPC, true, nil)
testMatcher(t, matcher, "POST", "https://localhost:1234/foo", ProtocolHTTP, false, &testRules[0])
testMatcher(t, matcher, "POST", "https://localhost:1234/foo", ProtocolGRPC, true, nil)
testMatcher(t, matcher, "DELETE", "https://localhost:1234/foo", ProtocolHTTP, true, nil)
testMatcher(t, matcher, "POST", "grpc://bar.example.com/grpc.api/Call", ProtocolGRPC, false, &testRules[3])
testMatcher(t, matcher, "GET", "https://localhost:34/baz", http.Header{}, ProtocolHTTP, false, &testRules[1])
testMatcher(t, matcher, "GET", "https://localhost:34/baz", http.Header{}, ProtocolGRPC, true, nil)
testMatcher(t, matcher, "POST", "https://localhost:1234/foo", http.Header{}, ProtocolHTTP, false, &testRules[0])
testMatcher(t, matcher, "POST", "https://localhost:1234/foo", http.Header{}, ProtocolGRPC, true, nil)
testMatcher(t, matcher, "DELETE", "https://localhost:1234/foo", http.Header{}, ProtocolHTTP, true, nil)
testMatcher(t, matcher, "POST", "grpc://bar.example.com/grpc.api/Call", http.Header{}, ProtocolGRPC, false, &testRules[4])
})

t.Run("case=cache", func(t *testing.T) {
r, err := matcher.Match(context.Background(), "GET", mustParseURL(t, "https://localhost:34/baz"), ProtocolHTTP)
r, err := matcher.Match(context.Background(), "GET", mustParseURL(t, "https://localhost:34/baz"), http.Header{}, ProtocolHTTP)
require.NoError(t, err)
got, err := matcher.Get(context.Background(), r.ID)
require.NoError(t, err)
assert.NotEmpty(t, got.matchingEngine.Checksum())
})

t.Run("case=nil url", func(t *testing.T) {
_, err := matcher.Match(context.Background(), "GET", nil, ProtocolHTTP)
_, err := matcher.Match(context.Background(), "GET", nil, http.Header{}, ProtocolHTTP)
require.Error(t, err)
})

require.NoError(t, matcher.Set(context.Background(), testRules[1:]))

t.Run("case=updated", func(t *testing.T) {
testMatcher(t, matcher, "GET", "https://localhost:34/baz", ProtocolHTTP, false, &testRules[1])
testMatcher(t, matcher, "POST", "https://localhost:1234/foo", ProtocolHTTP, true, nil)
testMatcher(t, matcher, "DELETE", "https://localhost:1234/foo", ProtocolHTTP, true, nil)
testMatcher(t, matcher, "GET", "https://localhost:34/baz", http.Header{}, ProtocolHTTP, false, &testRules[1])
testMatcher(t, matcher, "POST", "https://localhost:1234/foo", http.Header{}, ProtocolHTTP, true, nil)
testMatcher(t, matcher, "DELETE", "https://localhost:1234/foo", http.Header{}, ProtocolHTTP, true, nil)
testMatcher(t, matcher, "PATCH", "https://localhost:343/bar", http.Header{"Content-Type": {"application/some-app.v1+json"}}, ProtocolHTTP, true, nil)
testMatcher(t, matcher, "PATCH", "https://localhost:343/bar", http.Header{"Content-Type": {"application/some-app.v2+json"}}, ProtocolHTTP, false, &testRules[3])
})
})
t.Run(fmt.Sprintf("glob matcher=%s", name), func(t *testing.T) {
require.NoError(t, matcher.SetMatchingStrategy(context.Background(), configuration.Glob))
require.NoError(t, matcher.Set(context.Background(), []Rule{}))
t.Run("case=empty", func(t *testing.T) {
testMatcher(t, matcher, "GET", "https://localhost:34/baz", ProtocolHTTP, true, nil)
testMatcher(t, matcher, "POST", "https://localhost:1234/foo", ProtocolHTTP, true, nil)
testMatcher(t, matcher, "DELETE", "https://localhost:1234/foo", ProtocolHTTP, true, nil)
testMatcher(t, matcher, "GET", "https://localhost:34/baz", http.Header{}, ProtocolHTTP, true, nil)
testMatcher(t, matcher, "POST", "https://localhost:1234/foo", http.Header{}, ProtocolHTTP, true, nil)
testMatcher(t, matcher, "DELETE", "https://localhost:1234/foo", http.Header{}, ProtocolHTTP, true, nil)
})

require.NoError(t, matcher.Set(context.Background(), testRulesGlob))

t.Run("case=created", func(t *testing.T) {
testMatcher(t, matcher, "GET", "https://localhost:34/baz", ProtocolHTTP, false, &testRulesGlob[1])
testMatcher(t, matcher, "POST", "https://localhost:1234/foo", ProtocolHTTP, false, &testRulesGlob[0])
testMatcher(t, matcher, "DELETE", "https://localhost:1234/foo", ProtocolHTTP, true, nil)
testMatcher(t, matcher, "POST", "grpc://bar.example.com/grpc.api/Call", ProtocolGRPC, false, &testRulesGlob[3])
testMatcher(t, matcher, "GET", "https://localhost:34/baz", http.Header{}, ProtocolHTTP, false, &testRulesGlob[1])
testMatcher(t, matcher, "POST", "https://localhost:1234/foo", http.Header{}, ProtocolHTTP, false, &testRulesGlob[0])
testMatcher(t, matcher, "DELETE", "https://localhost:1234/foo", http.Header{}, ProtocolHTTP, true, nil)
testMatcher(t, matcher, "POST", "grpc://bar.example.com/grpc.api/Call", http.Header{}, ProtocolGRPC, false, &testRulesGlob[4])
testMatcher(t, matcher, "POST", "grpc://bar.example.com/grpc.api/CallWithHeader", http.Header{"Content-Type": []string{"application/some-app.v1+json"}}, ProtocolGRPC, true, nil)
testMatcher(t, matcher, "POST", "grpc://bar.example.com/grpc.api/CallWithHeader", http.Header{"Content-Type": []string{"application/some-app.v2+json"}}, ProtocolGRPC, false, &testRulesGlob[5])
})

t.Run("case=cache", func(t *testing.T) {
r, err := matcher.Match(context.Background(), "GET", mustParseURL(t, "https://localhost:34/baz"), ProtocolHTTP)
r, err := matcher.Match(context.Background(), "GET", mustParseURL(t, "https://localhost:34/baz"), http.Header{}, ProtocolHTTP)
require.NoError(t, err)
got, err := matcher.Get(context.Background(), r.ID)
require.NoError(t, err)
@@ -187,9 +219,11 @@ func TestMatcher(t *testing.T) {
require.NoError(t, matcher.Set(context.Background(), testRulesGlob[1:]))

t.Run("case=updated", func(t *testing.T) {
testMatcher(t, matcher, "GET", "https://localhost:34/baz", ProtocolHTTP, false, &testRulesGlob[1])
testMatcher(t, matcher, "POST", "https://localhost:1234/foo", ProtocolHTTP, true, nil)
testMatcher(t, matcher, "DELETE", "https://localhost:1234/foo", ProtocolHTTP, true, nil)
testMatcher(t, matcher, "GET", "https://localhost:34/baz", http.Header{}, ProtocolHTTP, false, &testRulesGlob[1])
testMatcher(t, matcher, "POST", "https://localhost:1234/foo", http.Header{}, ProtocolHTTP, true, nil)
testMatcher(t, matcher, "DELETE", "https://localhost:1234/foo", http.Header{}, ProtocolHTTP, true, nil)
testMatcher(t, matcher, "PATCH", "https://localhost:343/bar", http.Header{"Content-Type": []string{"application/some-app.v1+json"}}, ProtocolHTTP, true, nil)
testMatcher(t, matcher, "PATCH", "https://localhost:343/bar", http.Header{"Content-Type": []string{"application/some-app.v2+json"}}, ProtocolHTTP, false, &testRulesGlob[3])
})
})
}
6 changes: 3 additions & 3 deletions rule/repository_memory.go
Original file line number Diff line number Diff line change
@@ -110,7 +110,7 @@ func (m *RepositoryMemory) Set(ctx context.Context, rules []Rule) error {
return nil
}

func (m *RepositoryMemory) Match(ctx context.Context, method string, u *url.URL, protocol Protocol) (*Rule, error) {
func (m *RepositoryMemory) Match(ctx context.Context, method string, u *url.URL, headers http.Header, protocol Protocol) (*Rule, error) {
if u == nil {
return nil, errors.WithStack(errors.New("nil URL provided"))
}
@@ -121,15 +121,15 @@ func (m *RepositoryMemory) Match(ctx context.Context, method string, u *url.URL,
var rules []*Rule
for k := range m.rules {
r := &m.rules[k]
if matched, err := r.IsMatching(m.matchingStrategy, method, u, protocol); err != nil {
if matched, err := r.IsMatching(m.matchingStrategy, method, u, headers, protocol); err != nil {
return nil, errors.WithStack(err)
} else if matched {
rules = append(rules, r)
}
}
for k := range m.invalidRules {
r := &m.invalidRules[k]
if matched, err := r.IsMatching(m.matchingStrategy, method, u, protocol); err != nil {
if matched, err := r.IsMatching(m.matchingStrategy, method, u, headers, protocol); err != nil {
return nil, errors.WithStack(err)
} else if matched {
rules = append(rules, r)
Loading