Skip to content
This repository was archived by the owner on Sep 30, 2024. It is now read-only.

Commit ded7187

Browse files
committed
feature/internal/grpc: retry: vendor go-grpc-middleware testing/testpb package
1 parent b2e550c commit ded7187

9 files changed

+1758
-0
lines changed

Diff for: internal/grpc/retry/testpb/BUILD.bazel

+39
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,39 @@
1+
load("@io_bazel_rules_go//go:def.bzl", "go_library")
2+
load("//dev:go_defs.bzl", "go_test")
3+
4+
go_library(
5+
name = "testpb",
6+
srcs = [
7+
"interceptor_suite.go",
8+
"pingservice.go",
9+
"test.manual_validator.pb.go",
10+
"test.pb.go",
11+
"test_grpc.pb.go",
12+
],
13+
importpath = "github.com/sourcegraph/sourcegraph/internal/grpc/retry/testpb",
14+
visibility = ["//:__subpackages__"],
15+
deps = [
16+
"@com_github_stretchr_testify//require",
17+
"@com_github_stretchr_testify//suite",
18+
"@org_golang_google_grpc//:go_default_library",
19+
"@org_golang_google_grpc//codes",
20+
"@org_golang_google_grpc//credentials",
21+
"@org_golang_google_grpc//credentials/insecure",
22+
"@org_golang_google_grpc//status",
23+
"@org_golang_google_protobuf//reflect/protoreflect",
24+
"@org_golang_google_protobuf//runtime/protoimpl",
25+
],
26+
)
27+
28+
go_test(
29+
name = "testpb_test",
30+
srcs = ["pingservice_test.go"],
31+
embed = [":testpb"],
32+
deps = [
33+
"@com_github_stretchr_testify//require",
34+
"@org_golang_google_grpc//:go_default_library",
35+
"@org_golang_google_grpc//codes",
36+
"@org_golang_google_grpc//credentials/insecure",
37+
"@org_golang_google_grpc//status",
38+
],
39+
)

Diff for: internal/grpc/retry/testpb/interceptor_suite.go

+233
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,233 @@
1+
// Copyright (c) The go-grpc-middleware Authors.
2+
// Licensed under the Apache License 2.0.
3+
4+
package testpb
5+
6+
import (
7+
"context"
8+
"crypto/rand"
9+
"crypto/rsa"
10+
"crypto/tls"
11+
"crypto/x509"
12+
"crypto/x509/pkix"
13+
"encoding/pem"
14+
"flag"
15+
"math/big"
16+
"net"
17+
"sync"
18+
"time"
19+
20+
"github.com/stretchr/testify/require"
21+
"github.com/stretchr/testify/suite"
22+
"google.golang.org/grpc"
23+
"google.golang.org/grpc/credentials"
24+
"google.golang.org/grpc/credentials/insecure"
25+
)
26+
27+
var (
28+
flagTls = flag.Bool("use_tls", true, "whether all gRPC middleware tests should use tls")
29+
30+
certPEM []byte
31+
keyPEM []byte
32+
)
33+
34+
// InterceptorTestSuite is a testify/Suite that starts a gRPC PingService server and a client.
35+
type InterceptorTestSuite struct {
36+
suite.Suite
37+
38+
TestService TestServiceServer
39+
ServerOpts []grpc.ServerOption
40+
ClientOpts []grpc.DialOption
41+
42+
serverAddr string
43+
ServerListener net.Listener
44+
Server *grpc.Server
45+
clientConn *grpc.ClientConn
46+
Client TestServiceClient
47+
48+
restartServerWithDelayedStart chan time.Duration
49+
serverRunning chan bool
50+
51+
cancels []context.CancelFunc
52+
}
53+
54+
func (s *InterceptorTestSuite) SetupSuite() {
55+
s.restartServerWithDelayedStart = make(chan time.Duration)
56+
s.serverRunning = make(chan bool)
57+
58+
s.serverAddr = "127.0.0.1:0"
59+
var err error
60+
certPEM, keyPEM, err = generateCertAndKey([]string{"localhost", "example.com"}) // CI:LOCALHOST_OK
61+
require.NoError(s.T(), err, "unable to generate test certificate/key")
62+
63+
go func() {
64+
for {
65+
var err error
66+
s.ServerListener, err = net.Listen("tcp", s.serverAddr)
67+
s.serverAddr = s.ServerListener.Addr().String()
68+
require.NoError(s.T(), err, "must be able to allocate a port for serverListener")
69+
if *flagTls {
70+
cert, err := tls.X509KeyPair(certPEM, keyPEM)
71+
require.NoError(s.T(), err, "unable to load test TLS certificate")
72+
creds := credentials.NewServerTLSFromCert(&cert)
73+
s.ServerOpts = append(s.ServerOpts, grpc.Creds(creds))
74+
}
75+
// This is the point where we hook up the interceptor.
76+
s.Server = grpc.NewServer(s.ServerOpts...)
77+
// Create a service if the instantiator hasn't provided one.
78+
if s.TestService == nil {
79+
s.TestService = &TestPingService{}
80+
}
81+
RegisterTestServiceServer(s.Server, s.TestService)
82+
83+
w := sync.WaitGroup{}
84+
w.Add(1)
85+
go func() {
86+
_ = s.Server.Serve(s.ServerListener)
87+
w.Done()
88+
}()
89+
if s.Client == nil {
90+
s.Client = s.NewClient(s.ClientOpts...)
91+
}
92+
93+
s.serverRunning <- true
94+
95+
d := <-s.restartServerWithDelayedStart
96+
s.Server.Stop()
97+
time.Sleep(d)
98+
w.Wait()
99+
}
100+
}()
101+
102+
select {
103+
case <-s.serverRunning:
104+
case <-time.After(2 * time.Second):
105+
s.T().Fatal("server failed to start before deadline")
106+
}
107+
}
108+
109+
func (s *InterceptorTestSuite) RestartServer(delayedStart time.Duration) <-chan bool {
110+
s.restartServerWithDelayedStart <- delayedStart
111+
time.Sleep(10 * time.Millisecond)
112+
return s.serverRunning
113+
}
114+
115+
func (s *InterceptorTestSuite) NewClient(dialOpts ...grpc.DialOption) TestServiceClient {
116+
//lint:ignore SA1019 This is a vendored package, so we shouldn't be modifying it.
117+
newDialOpts := append(dialOpts, grpc.WithBlock())
118+
var err error
119+
if *flagTls {
120+
cp := x509.NewCertPool()
121+
if !cp.AppendCertsFromPEM(certPEM) {
122+
s.T().Fatal("failed to append certificate")
123+
}
124+
creds := credentials.NewTLS(&tls.Config{ServerName: "localhost", RootCAs: cp}) // CI:LOCALHOST_OK
125+
newDialOpts = append(newDialOpts, grpc.WithTransportCredentials(creds))
126+
} else {
127+
newDialOpts = append(newDialOpts, grpc.WithTransportCredentials(insecure.NewCredentials()))
128+
}
129+
ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second)
130+
defer cancel()
131+
//lint:ignore SA1019 This is a vendored package, so we shouldn't be modifying it.
132+
s.clientConn, err = grpc.DialContext(ctx, s.ServerAddr(), newDialOpts...)
133+
require.NoError(s.T(), err, "must not error on client Dial")
134+
return NewTestServiceClient(s.clientConn)
135+
}
136+
137+
func (s *InterceptorTestSuite) ServerAddr() string {
138+
return s.serverAddr
139+
}
140+
141+
type ctxTestNumber struct{}
142+
143+
var (
144+
ctxTestNumberKey = &ctxTestNumber{}
145+
zero = 0
146+
)
147+
148+
func ExtractCtxTestNumber(ctx context.Context) *int {
149+
if v, ok := ctx.Value(ctxTestNumberKey).(*int); ok {
150+
return v
151+
}
152+
return &zero
153+
}
154+
155+
// UnaryServerInterceptor returns a new unary server interceptors that adds query information logging.
156+
func UnaryServerInterceptor() grpc.UnaryServerInterceptor {
157+
return func(ctx context.Context, req any, _ *grpc.UnaryServerInfo, handler grpc.UnaryHandler) (any, error) {
158+
// newCtx := newContext(ctx, log, opts)
159+
newCtx := ctx
160+
resp, err := handler(newCtx, req)
161+
return resp, err
162+
}
163+
}
164+
165+
func (s *InterceptorTestSuite) SimpleCtx() context.Context {
166+
ctx, cancel := context.WithTimeout(context.TODO(), 2*time.Second)
167+
ctx = context.WithValue(ctx, ctxTestNumberKey, 1)
168+
s.cancels = append(s.cancels, cancel)
169+
return ctx
170+
}
171+
172+
func (s *InterceptorTestSuite) DeadlineCtx(deadline time.Time) context.Context {
173+
ctx, cancel := context.WithDeadline(context.TODO(), deadline)
174+
s.cancels = append(s.cancels, cancel)
175+
return ctx
176+
}
177+
178+
func (s *InterceptorTestSuite) TearDownSuite() {
179+
time.Sleep(10 * time.Millisecond)
180+
if s.ServerListener != nil {
181+
s.Server.GracefulStop()
182+
s.T().Logf("stopped grpc.Server at: %v", s.ServerAddr())
183+
_ = s.ServerListener.Close()
184+
}
185+
if s.clientConn != nil {
186+
_ = s.clientConn.Close()
187+
}
188+
for _, c := range s.cancels {
189+
c()
190+
}
191+
}
192+
193+
// generateCertAndKey copied from https://github.com/johanbrandhorst/certify/blob/master/issuers/vault/vault_suite_test.go#L255
194+
// with minor modifications.
195+
func generateCertAndKey(san []string) ([]byte, []byte, error) {
196+
priv, err := rsa.GenerateKey(rand.Reader, 2048)
197+
if err != nil {
198+
return nil, nil, err
199+
}
200+
notBefore := time.Now()
201+
notAfter := notBefore.Add(time.Hour)
202+
serialNumberLimit := new(big.Int).Lsh(big.NewInt(1), 128)
203+
serialNumber, err := rand.Int(rand.Reader, serialNumberLimit)
204+
if err != nil {
205+
return nil, nil, err
206+
}
207+
template := x509.Certificate{
208+
SerialNumber: serialNumber,
209+
Subject: pkix.Name{
210+
CommonName: "example.com",
211+
},
212+
NotBefore: notBefore,
213+
NotAfter: notAfter,
214+
KeyUsage: x509.KeyUsageKeyEncipherment | x509.KeyUsageDigitalSignature,
215+
ExtKeyUsage: []x509.ExtKeyUsage{x509.ExtKeyUsageServerAuth},
216+
BasicConstraintsValid: true,
217+
DNSNames: san,
218+
}
219+
derBytes, err := x509.CreateCertificate(rand.Reader, &template, &template, priv.Public(), priv)
220+
if err != nil {
221+
return nil, nil, err
222+
}
223+
certOut := pem.EncodeToMemory(&pem.Block{
224+
Type: "CERTIFICATE",
225+
Bytes: derBytes,
226+
})
227+
keyOut := pem.EncodeToMemory(&pem.Block{
228+
Type: "RSA PRIVATE KEY",
229+
Bytes: x509.MarshalPKCS1PrivateKey(priv),
230+
})
231+
232+
return certOut, keyOut, nil
233+
}

Diff for: internal/grpc/retry/testpb/pingservice.go

+82
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,82 @@
1+
// Copyright (c) The go-grpc-middleware Authors.
2+
// Licensed under the Apache License 2.0.
3+
4+
/*
5+
Package `grpc_testing` provides helper functions for testing validators in this package.
6+
*/
7+
8+
package testpb
9+
10+
import (
11+
"context"
12+
"io"
13+
14+
"google.golang.org/grpc/codes"
15+
"google.golang.org/grpc/status"
16+
)
17+
18+
const (
19+
// ListResponseCount is the expected number of responses to PingList
20+
ListResponseCount = 100
21+
)
22+
23+
var TestServiceFullName = _TestService_serviceDesc.ServiceName
24+
25+
// Interface implementation assert.
26+
var _ TestServiceServer = &TestPingService{}
27+
28+
type TestPingService struct {
29+
UnimplementedTestServiceServer
30+
}
31+
32+
func (s *TestPingService) PingEmpty(_ context.Context, _ *PingEmptyRequest) (*PingEmptyResponse, error) {
33+
return &PingEmptyResponse{}, nil
34+
}
35+
36+
func (s *TestPingService) Ping(ctx context.Context, ping *PingRequest) (*PingResponse, error) {
37+
// Modify the ctx value to verify the logger sees the value updated from the initial value
38+
n := ExtractCtxTestNumber(ctx)
39+
if n != nil {
40+
*n = 42
41+
}
42+
// Send user trailers and headers.
43+
return &PingResponse{Value: ping.Value, Counter: 0}, nil
44+
}
45+
46+
func (s *TestPingService) PingError(_ context.Context, ping *PingErrorRequest) (*PingErrorResponse, error) {
47+
code := codes.Code(ping.ErrorCodeReturned)
48+
return nil, status.Error(code, "Userspace error")
49+
}
50+
51+
func (s *TestPingService) PingList(ping *PingListRequest, stream TestService_PingListServer) error {
52+
if ping.ErrorCodeReturned != 0 {
53+
return status.Error(codes.Code(ping.ErrorCodeReturned), "foobar")
54+
}
55+
56+
// Send user trailers and headers.
57+
for i := 0; i < ListResponseCount; i++ {
58+
if err := stream.Send(&PingListResponse{Value: ping.Value, Counter: int32(i)}); err != nil {
59+
return err
60+
}
61+
}
62+
return nil
63+
}
64+
65+
func (s *TestPingService) PingStream(stream TestService_PingStreamServer) error {
66+
count := 0
67+
for {
68+
ping, err := stream.Recv()
69+
if err == io.EOF {
70+
break
71+
}
72+
if err != nil {
73+
return err
74+
}
75+
if err := stream.Send(&PingStreamResponse{Value: ping.Value, Counter: int32(count)}); err != nil {
76+
return err
77+
}
78+
79+
count += 1
80+
}
81+
return nil
82+
}

0 commit comments

Comments
 (0)