From 27c8c29a23a900321cf28483aabc03fc0137096c Mon Sep 17 00:00:00 2001 From: Leland Garofalo Date: Tue, 25 Jul 2023 08:46:54 -0700 Subject: [PATCH] Testing retry and timeout for signing ops (#366) * Testing retry and timeout for signing ops * bugfixes * Adjust use of context for retry/timeout --------- Co-authored-by: Leland Garofalo --- cmd/gokeyless/gokeyless.go | 13 ++++++ server/server.go | 81 +++++++++++++++++++++++++++++++++++--- tests/common_test.go | 11 +++++- 3 files changed, 98 insertions(+), 7 deletions(-) diff --git a/cmd/gokeyless/gokeyless.go b/cmd/gokeyless/gokeyless.go index 95e0df28..74fe0daf 100644 --- a/cmd/gokeyless/gokeyless.go +++ b/cmd/gokeyless/gokeyless.go @@ -64,6 +64,9 @@ type Config struct { TracingEnabled bool `yaml:"tracing_enabled" mapstructure:"tracing_enabled"` TracingAddress string `yaml:"tracing_address" mapstructure:"tracing_address"` TracingSampleRate float64 `yaml:"tracing_sample_rate" mapstructure:"tracing_sample_rate"` // between 0 and 1 + + SignTimeout string `yaml:"sign_timeout" mapstructure:"sign_timeout"` + SignRetryCount int `yaml:"sign_retry_count" mapstructure:"sign_retry_count"` } // PrivateKeyStoreConfig defines a key store. 
@@ -309,6 +312,16 @@ func runMain() error { } cfg := server.DefaultServeConfig() + if config.SignTimeout != "" { + signTimeoutDuration, err := time.ParseDuration(config.SignTimeout) + if err != nil { + log.Fatalf("failed to parse signTimeout: %s", err) + } + cfg = cfg.WithSignTimeout(signTimeoutDuration) + } + if config.SignRetryCount > 0 { + cfg = cfg.WithSignRetryCount(config.SignRetryCount) + } s, err := server.NewServerFromFile(cfg, config.CertFile, config.KeyFile, config.CACertFile) if err != nil { return fmt.Errorf("cannot start server: %w", err) diff --git a/server/server.go b/server/server.go index 040f7200..7878c1e6 100644 --- a/server/server.go +++ b/server/server.go @@ -51,6 +51,9 @@ type Server struct { listeners map[net.Listener]map[net.Conn]struct{} shutdown bool mtx sync.Mutex + + signTimeout time.Duration + signRetryCount int } // NewServer prepares a TLS server capable of receiving connections from keyless clients. @@ -73,6 +76,8 @@ func NewServer(config *ServeConfig, cert tls.Certificate, keylessCA *x509.CertPo dispatcher: rpc.NewServer(), limitedDispatcher: rpc.NewServer(), listeners: make(map[net.Listener]map[net.Conn]struct{}), + signTimeout: config.signTimeout, + signRetryCount: config.signRetryCount, } return s, nil @@ -448,19 +453,56 @@ func (s *Server) unlimitedDo(pkt *protocol.Packet, connName string) response { return makeErrResponse(pkt, protocol.ErrKeyNotFound) } - signSpan, _ := opentracing.StartSpanFromContext(ctx, "execute.Sign") + signSpan, ctx := opentracing.StartSpanFromContext(ctx, "execute.Sign") defer signSpan.Finish() var sig []byte - sig, err = key.Sign(rand.Reader, pkt.Operation.Payload, opts) - if err != nil { - tracing.LogError(span, err) - log.Errorf("Connection %v: %s: Signing error: %v\n", connName, protocol.ErrCrypto, err) - return makeErrResponse(pkt, protocol.ErrCrypto) + // By default, we only try the request once, unless retry count is configured + for attempts := 1 + s.signRetryCount; attempts > 0; attempts-- { + 
var err error + // If signTimeout is not set, the value will be zero + if s.signTimeout == 0 { + sig, err = key.Sign(rand.Reader, pkt.Operation.Payload, opts) + } else { + ch := make(chan signWithTimeoutWrapper, 1) + ctxTimeout, cancel := context.WithTimeout(ctx, s.signTimeout) + defer cancel() + + go signWithTimeout(ctxTimeout, ch, key, rand.Reader, pkt.Operation.Payload, opts) + select { + case <-ctxTimeout.Done(): + sig = nil + err = ctxTimeout.Err() + case result := <-ch: + sig = result.sig + err = result.error + } + } + if err != nil { + if attempts > 1 { + log.Debugf("Connection %v: failed sign attempt: %s, %d attempt(s) left", connName, err, attempts-1) + continue + } else { + tracing.LogError(span, err) + log.Errorf("Connection %v: %s: Signing error: %v\n", connName, protocol.ErrCrypto, err) + return makeErrResponse(pkt, protocol.ErrCrypto) + } + } + break } return makeRespondResponse(pkt, sig) } +type signWithTimeoutWrapper struct { + sig []byte + error error +} + +func signWithTimeout(ctx context.Context, ch chan signWithTimeoutWrapper, key crypto.Signer, rand io.Reader, digest []byte, opts crypto.SignerOpts) { + sig, err := key.Sign(rand, digest, opts) + ch <- signWithTimeoutWrapper{sig, err} +} + func (s *Server) limitedDo(pkt *protocol.Packet, connName string) response { spanCtx, err := tracing.SpanContextFromBinary(pkt.Operation.JaegerSpan) if err != nil { @@ -697,6 +739,8 @@ type ServeConfig struct { tcpTimeout, unixTimeout time.Duration isLimited func(state tls.ConnectionState) (bool, error) customOpFunc CustomOpFunction + signTimeout time.Duration + signRetryCount int } const ( @@ -718,6 +762,8 @@ func DefaultServeConfig() *ServeConfig { unixTimeout: defaultUnixTimeout, maxConnPendingRequests: 1024, isLimited: func(state tls.ConnectionState) (bool, error) { return false, nil }, + signTimeout: 0, + signRetryCount: 0, } } @@ -757,6 +803,29 @@ func (s *ServeConfig) WithIsLimited(f func(state tls.ConnectionState) (bool, err return s } +// 
WithSignTimeout specifies the sign operation timeout. This timeout is used to enforce a +// max execution time for a single sign operation +func (s *ServeConfig) WithSignTimeout(timeout time.Duration) *ServeConfig { + s.signTimeout = timeout + return s +} + +// SignTimeout returns the sign operation timeout +func (s *ServeConfig) SignTimeout() time.Duration { + return s.signTimeout +} + +// WithSignRetryCount specifies the number of retries to allow for failed sign operations +func (s *ServeConfig) WithSignRetryCount(signRetryCount int) *ServeConfig { + s.signRetryCount = signRetryCount + return s +} + +// SignRetryCount returns the count of retries allowed for sign operations +func (s *ServeConfig) SignRetryCount() int { + return s.signRetryCount +} + // CustomOpFunction is the signature for custom opcode functions. // // If it returns a non-nil error which implements protocol.Error, the server diff --git a/tests/common_test.go b/tests/common_test.go index 3b47133b..bd6f76cc 100644 --- a/tests/common_test.go +++ b/tests/common_test.go @@ -64,6 +64,9 @@ type IntegrationTestSuite struct { ecdsaKey *client.PrivateKey ed25519Key *client.PrivateKey remote client.Remote + + retryCount int + timeout time.Duration } func fixedCurrentTime() time.Time { @@ -148,6 +151,11 @@ func (s *IntegrationTestSuite) NewRemoteSignerByPubKeyFile(filepath string) (cry func TestSuite(t *testing.T) { s := &IntegrationTestSuite{} suite.Run(t, s) + s2 := &IntegrationTestSuite{ + timeout: time.Second, + retryCount: 3, + } + suite.Run(t, s2) } // SetupTest sets up a compatible server and client for use by tests. 
@@ -160,7 +168,8 @@ func (s *IntegrationTestSuite) SetupTest() { atomic.StoreUint32(&client.TestDisableConnectionPool, 1) var err error - s.server, err = server.NewServerFromFile(nil, serverCert, serverKey, keylessCA) + config := server.DefaultServeConfig().WithSignTimeout(s.timeout).WithSignRetryCount(s.retryCount) + s.server, err = server.NewServerFromFile(config, serverCert, serverKey, keylessCA) require.NoError(err) s.server.TLSConfig().Time = fixedCurrentTime