diff --git a/generate/go_client.tpl b/generate/go_client.tpl index 3f076aeb..cbdd14fb 100644 --- a/generate/go_client.tpl +++ b/generate/go_client.tpl @@ -2,8 +2,6 @@ package client import ( - "sync" - "connectrpc.com/connect" compress "github.com/klauspost/connect-compress/v2" @@ -22,8 +20,6 @@ type ( config *DialConfig interceptors []connect.Interceptor - - sync.Mutex } {{ range $name, $api := . -}} {{ $name | title }} interface { @@ -55,6 +51,11 @@ func New(config *DialConfig) (Client, error) { if config.Token != "" { authInterceptor := &authInterceptor{config: config} c.interceptors = append(c.interceptors, authInterceptor) + + if config.TokenRenewal != nil { + tokenRenewingInterceptor := &tokenRenewingInterceptor{config: config, client: c} + c.interceptors = append(c.interceptors, tokenRenewingInterceptor) + } } if config.Log != nil { loggingInterceptor := &loggingInterceptor{config: config} @@ -62,9 +63,6 @@ func New(config *DialConfig) (Client, error) { } c.interceptors = append(c.interceptors, config.Interceptors...) - // TODO convert to interceptor - go c.startTokenRenewal() - return c, nil } diff --git a/go.mod b/go.mod index 9ae4eb59..20f0009f 100644 --- a/go.mod +++ b/go.mod @@ -8,7 +8,7 @@ require ( connectrpc.com/connect v1.19.1 github.com/bufbuild/protocompile v0.14.1 github.com/go-task/slim-sprig/v3 v3.0.0 - github.com/golang-jwt/jwt/v5 v5.3.0 + github.com/golang-jwt/jwt/v5 v5.3.1 github.com/google/go-cmp v0.7.0 github.com/klauspost/connect-compress/v2 v2.1.1 github.com/stretchr/testify v1.11.1 @@ -19,17 +19,16 @@ require ( cel.dev/expr v0.25.1 // indirect github.com/antlr4-go/antlr/v4 v4.13.1 // indirect github.com/davecgh/go-spew v1.1.1 // indirect - github.com/google/cel-go v0.26.1 // indirect + github.com/google/cel-go v0.27.0 // indirect github.com/klauspost/compress v1.18.3 // indirect github.com/kr/pretty v0.3.1 // indirect github.com/minio/minlz v1.0.1 // indirect github.com/pmezard/go-difflib v1.0.0 // indirect - github.com/stoewer/go-strcase v1.3.1 // indirect github.com/stretchr/objx v0.5.3 // indirect golang.org/x/exp v0.0.0-20260112195511-716be5621a96 // indirect golang.org/x/text v0.33.0 // indirect - google.golang.org/genproto/googleapis/api v0.0.0-20260122232226-8e98ce8d340d // indirect - google.golang.org/genproto/googleapis/rpc v0.0.0-20260122232226-8e98ce8d340d // indirect + google.golang.org/genproto/googleapis/api v0.0.0-20260128011058-8636f8732409 // indirect + google.golang.org/genproto/googleapis/rpc v0.0.0-20260128011058-8636f8732409 // indirect gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c // indirect gopkg.in/yaml.v3 v3.0.1 // indirect ) diff --git a/go.sum b/go.sum index fb58cd0c..2b4a53bb 100644 --- a/go.sum +++ b/go.sum @@ -13,15 +13,14 @@ github.com/brianvoe/gofakeit/v6 v6.28.0/go.mod h1:Xj58BMSnFqcn/fAQeSK+/PLtC5kSb7 github.com/bufbuild/protocompile v0.14.1 h1:iA73zAf/fyljNjQKwYzUHD6AD4R8KMasmwa/FBatYVw= github.com/bufbuild/protocompile v0.14.1/go.mod h1:ppVdAIhbr2H8asPk6k4pY7t9zB1OU5DoEw9xY/FUi1c= github.com/creack/pty v1.1.9/go.mod h1:oKZEueFk5CKHvIhNR5MUki03XCEU+Q6VDXinZuGJ33E= -github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/go-task/slim-sprig/v3 v3.0.0 h1:sUs3vkvUymDpBKi3qH1YSqBQk9+9D/8M2mN1vB6EwHI= github.com/go-task/slim-sprig/v3 v3.0.0/go.mod h1:W848ghGpv3Qj3dhTPRyJypKRiqCdHZiAzKg9hl15HA8= -github.com/golang-jwt/jwt/v5 v5.3.0 h1:pv4AsKCKKZuqlgs5sUmn4x8UlGa0kEVt/puTpKx9vvo= -github.com/golang-jwt/jwt/v5 v5.3.0/go.mod h1:fxCRLWMO43lRc8nhHWY6LGqRcf+1gQWArsqaEUEa5bE= -github.com/google/cel-go v0.26.1 h1:iPbVVEdkhTX++hpe3lzSk7D3G3QSYqLGoHOcEio+UXQ= -github.com/google/cel-go v0.26.1/go.mod h1:A9O8OU9rdvrK5MQyrqfIxo1a0u4g3sF8KB6PUIaryMM= +github.com/golang-jwt/jwt/v5 v5.3.1 h1:kYf81DTWFe7t+1VvL7eS+jKFVWaUnK9cB1qbwn63YCY= +github.com/golang-jwt/jwt/v5 v5.3.1/go.mod h1:fxCRLWMO43lRc8nhHWY6LGqRcf+1gQWArsqaEUEa5bE= +github.com/google/cel-go v0.27.0 h1:e7ih85+4qVrBuqQWTW4FKSqZYokVuc3HnhH5keboFTo= +github.com/google/cel-go v0.27.0/go.mod h1:tTJ11FWqnhw5KKpnWpvW9CJC3Y9GK4EIS0WXnBbebzw= github.com/google/go-cmp v0.7.0 h1:wk8382ETsv4JYUZwIsn6YpYiWiBsYLSJiTsyBybVuN8= github.com/google/go-cmp v0.7.0/go.mod h1:pXiqmnSA92OHEEa9HXL2W4E7lf9JzCmGVUdgjX3N/iU= github.com/klauspost/compress v1.18.3 h1:9PJRvfbmTabkOX8moIpXPbMMbYN60bWImDDU7L+/6zw= @@ -44,31 +43,24 @@ github.com/rodaine/protogofakeit v0.1.1 h1:ZKouljuRM3A+TArppfBqnH8tGZHOwM/pjvtXe github.com/rodaine/protogofakeit v0.1.1/go.mod h1:pXn/AstBYMaSfc1/RqH3N82pBuxtWgejz1AlYpY1mI0= github.com/rogpeppe/go-internal v1.9.0 h1:73kH8U+JUqXU8lRuOHeVHaa/SZPifC7BkcraZVejAe8= github.com/rogpeppe/go-internal v1.9.0/go.mod h1:WtVeX8xhTBvf0smdhujwtBcq4Qrzq/fJaraNFVN+nFs= -github.com/stoewer/go-strcase v1.3.1 h1:iS0MdW+kVTxgMoE1LAZyMiYJFKlOzLooE4MxjirtkAs= -github.com/stoewer/go-strcase v1.3.1/go.mod h1:fAH5hQ5pehh+j3nZfvwdk2RgEgQjAoM8wodgtPmh1xo= -github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= -github.com/stretchr/objx v0.4.0/go.mod h1:YvHI0jy2hoMjB+UWwv71VJQ9isScKT/TqJzVSSt89Yw= -github.com/stretchr/objx v0.5.0/go.mod h1:Yh+to48EsGEfYuaHDzXPcE3xhTkx73EhmCGUpEOglKo= github.com/stretchr/objx v0.5.3 h1:jmXUvGomnU1o3W/V5h2VEradbpJDwGrzugQQvL0POH4= github.com/stretchr/objx v0.5.3/go.mod h1:rDQraq+vQZU7Fde9LOZLr8Tax6zZvy4kuNKF+QYS+U0= -github.com/stretchr/testify v1.7.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= -github.com/stretchr/testify v1.8.0/go.mod h1:yNjHg4UonilssWZ8iaSj1OCr/vHnekPRkoO+kdMU+MU= -github.com/stretchr/testify v1.8.1/go.mod h1:w2LPCIKwWwSfY2zedu0+kehJoqGctiVI29o6fzry7u4= github.com/stretchr/testify v1.11.1 h1:7s2iGBzp5EwR7/aIZr8ao5+dra3wiQyKjjFuvgVKu7U= github.com/stretchr/testify v1.11.1/go.mod h1:wZwfW3scLgRK+23gO65QZefKpKQRnfz6sD981Nm4B6U= +go.yaml.in/yaml/v3 v3.0.4 h1:tfq32ie2Jv2UxXFdLJdh3jXuOzWiL1fo0bu/FbuKpbc= +go.yaml.in/yaml/v3 v3.0.4/go.mod h1:DhzuOOF2ATzADvBadXxruRBLzYTpT36CKvDb3+aBEFg= golang.org/x/exp v0.0.0-20260112195511-716be5621a96 h1:Z/6YuSHTLOHfNFdb8zVZomZr7cqNgTJvA8+Qz75D8gU= golang.org/x/exp v0.0.0-20260112195511-716be5621a96/go.mod h1:nzimsREAkjBCIEFtHiYkrJyT+2uy9YZJB7H1k68CXZU= golang.org/x/text v0.33.0 h1:B3njUFyqtHDUI5jMn1YIr5B0IE2U0qck04r6d4KPAxE= golang.org/x/text v0.33.0/go.mod h1:LuMebE6+rBincTi9+xWTY8TztLzKHc/9C1uBCG27+q8= -google.golang.org/genproto/googleapis/api v0.0.0-20260122232226-8e98ce8d340d h1:tUKoKfdZnSjTf5LW7xpG4c6SZ3Ozisn5eumcoTuMEN4= -google.golang.org/genproto/googleapis/api v0.0.0-20260122232226-8e98ce8d340d/go.mod h1:p3MLuOwURrGBRoEyFHBT3GjUwaCQVKeNqqWxlcISGdw= -google.golang.org/genproto/googleapis/rpc v0.0.0-20260122232226-8e98ce8d340d h1:xXzuihhT3gL/ntduUZwHECzAn57E8dA6l8SOtYWdD8Q= -google.golang.org/genproto/googleapis/rpc v0.0.0-20260122232226-8e98ce8d340d/go.mod h1:j9x/tPzZkyxcgEFkiKEEGxfvyumM01BEtsW8xzOahRQ= +google.golang.org/genproto/googleapis/api v0.0.0-20260128011058-8636f8732409 h1:merA0rdPeUV3YIIfHHcH4qBkiQAc1nfCKSI7lB4cV2M= +google.golang.org/genproto/googleapis/api v0.0.0-20260128011058-8636f8732409/go.mod h1:fl8J1IvUjCilwZzQowmw2b7HQB2eAuYBabMXzWurF+I= +google.golang.org/genproto/googleapis/rpc v0.0.0-20260128011058-8636f8732409 h1:H86B94AW+VfJWDqFeEbBPhEtHzJwJfTbgE2lZa54ZAQ= +google.golang.org/genproto/googleapis/rpc v0.0.0-20260128011058-8636f8732409/go.mod h1:j9x/tPzZkyxcgEFkiKEEGxfvyumM01BEtsW8xzOahRQ= google.golang.org/protobuf v1.36.11 h1:fV6ZwhNocDyBLK0dj+fg8ektcVegBBuEolpbTQyBNVE= google.golang.org/protobuf v1.36.11/go.mod h1:HTf+CrKn2C3g5S8VImy6tdcUvCska2kB7j23XfzDpco= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c h1:Hei/4ADfdWqJk1ZMxUNpqntNwaWcugrBjAiHlqqRiVk= gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c/go.mod h1:JHkPIbrfpd72SG/EVd6muEfDQjcINNoR0C8j2r3qZ4Q= -gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA= gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= diff --git a/go/client/client-interceptors.go b/go/client/client-interceptors.go index 0d9c678e..71782fe1 100644 --- a/go/client/client-interceptors.go +++ b/go/client/client-interceptors.go @@ -2,8 +2,14 @@ package client import ( "context" + "fmt" + "log/slog" + "sync" + "sync/atomic" + "time" "connectrpc.com/connect" + apiv2models "github.com/metal-stack/api/go/metalstack/api/v2" ) // authinterceptor adds the required auth headers @@ -65,3 +71,79 @@ func (i *loggingInterceptor) WrapStreamingClient(next connect.StreamingClientFun func (i *loggingInterceptor) WrapStreamingHandler(next connect.StreamingHandlerFunc) connect.StreamingHandlerFunc { return next } + +type tokenRenewingInterceptor struct { + config *DialConfig + client *client + + renewing atomic.Bool + + sync.Mutex +} + +func (i *tokenRenewingInterceptor) WrapUnary(next connect.UnaryFunc) connect.UnaryFunc { + return connect.UnaryFunc(func(ctx context.Context, request connect.AnyRequest) (connect.AnyResponse, error) { + err := i.renewTokenIfNeeded() + if err != nil { + return nil, err + } + return next(ctx, request) + }) +} + +func (i *tokenRenewingInterceptor) WrapStreamingClient(next connect.StreamingClientFunc) connect.StreamingClientFunc { + return next +} + +func (i *tokenRenewingInterceptor) WrapStreamingHandler(next connect.StreamingHandlerFunc) connect.StreamingHandlerFunc { + return next +} + +func (i *tokenRenewingInterceptor) renewTokenIfNeeded() error { + if i.config.expiresAt.IsZero() { + return nil + } + if i.renewing.Load() { + return nil + } + if i.config.Log == nil { + i.config.Log = slog.Default() + } + + replaceBefore := i.config.expiresAt.Sub(i.config.issuedAt) / tokenRenewChecksDuringLifetime + + if time.Until(i.config.expiresAt) > replaceBefore { + return nil + } + + i.renewing.Store(true) + defer i.renewing.Store(false) + + i.config.Log.Info("call token refresh, current token expires soon", "expires", i.config.expiresAt.String()) + + i.Lock() + defer i.Unlock() + + resp, err := i.client.Apiv2().Token().Refresh(context.Background(), &apiv2models.TokenServiceRefreshRequest{}) + if err != nil { + return fmt.Errorf("unable to refresh token %w", err) + } + + i.config.Token = resp.Secret + err = i.config.parse() + if err != nil { + return fmt.Errorf("unable to parse token %w", err) + } + + if i.config.TokenRenewal.PersistTokenFn == nil { + return nil + } + + err = i.config.TokenRenewal.PersistTokenFn(i.config.Token) + if err != nil { + return fmt.Errorf("unable to persist token %w", err) + } + + i.config.Log.Info("token refreshed, new token expires in", "expires", i.config.expiresAt.String()) + return nil +} diff --git a/go/client/client.go b/go/client/client.go index acb80069..5c935922 100755 --- a/go/client/client.go +++ b/go/client/client.go @@ -2,8 +2,6 @@ package client import ( - "sync" - "connectrpc.com/connect" compress "github.com/klauspost/connect-compress/v2" @@ -22,8 +20,6 @@ type ( config *DialConfig interceptors []connect.Interceptor - - sync.Mutex } Adminv2 interface { Filesystem() adminv2connect.FilesystemServiceClient @@ -120,6 +116,11 @@ func New(config *DialConfig) (Client, error) { if config.Token != "" { authInterceptor := &authInterceptor{config: config} c.interceptors = append(c.interceptors, authInterceptor) + + if config.TokenRenewal != nil { + tokenRenewingInterceptor := &tokenRenewingInterceptor{config: config, client: c} + c.interceptors = append(c.interceptors, tokenRenewingInterceptor) + } } if config.Log != nil { loggingInterceptor := &loggingInterceptor{config: config} @@ -127,9 +128,6 @@ func New(config *DialConfig) (Client, error) { } c.interceptors = append(c.interceptors, config.Interceptors...) - // TODO convert to interceptor - go c.startTokenRenewal() - return c, nil } diff --git a/go/client/client_test.go b/go/client/client_test.go index af16a355..686439b7 100644 --- a/go/client/client_test.go +++ b/go/client/client_test.go @@ -41,7 +41,7 @@ func Test_Client(t *testing.T) { server.Close() }() - tokenString, err := generateToken(1 * time.Second) + tokenString, err := generateToken(2 * time.Second) require.NoError(t, err) c, err := client.New(&client.DialConfig{ @@ -50,6 +50,7 @@ func Test_Client(t *testing.T) { Transport: server.Client().Transport, TokenRenewal: &client.TokenRenewal{ PersistTokenFn: func(token string) error { + ts.token = token t.Log("token persisted:", token) return nil }, @@ -64,7 +65,7 @@ func Test_Client(t *testing.T) { require.False(t, ts.wasCalled) require.Equal(t, tokenString, vs.token) - time.Sleep(300 * time.Millisecond) + time.Sleep(1 * time.Second) v, err = c.Apiv2().Version().Get(t.Context(), &apiv2.VersionServiceGetRequest{}) require.NoError(t, err) require.NotNil(t, v) @@ -79,7 +80,7 @@ func Test_Client(t *testing.T) { require.Equal(t, "1.0", v.Version.Version) require.True(t, ts.wasCalled) - require.NotEqual(t, tokenString, vs.token, "token must have changed") + require.NotEqual(t, tokenString, ts.token, "token must have changed") } func generateToken(duration time.Duration) (string, error) { @@ -121,6 +122,7 @@ func (m *mockVersionService) Get(ctx context.Context, req *apiv2.VersionServiceG type mockTokenService struct { wasCalled bool + token string } // Create implements apiv2connect.TokenServiceHandler. diff --git a/go/client/conn.go b/go/client/conn.go index 38a4f2db..3c5aecb4 100644 --- a/go/client/conn.go +++ b/go/client/conn.go @@ -1,7 +1,6 @@ package client import ( - "context" "errors" "fmt" "log/slog" @@ -10,7 +9,6 @@ import ( "connectrpc.com/connect" "github.com/golang-jwt/jwt/v5" - api "github.com/metal-stack/api/go/metalstack/api/v2" ) const tokenRenewChecksDuringLifetime = 4 @@ -84,70 +82,3 @@ func (dc *DialConfig) parse() error { } return nil } - -func (c *client) startTokenRenewal() { - if c.config.TokenRenewal == nil { - return - } - if c.config.expiresAt.IsZero() { - return - } - if c.config.Log == nil { - c.config.Log = slog.Default() - } - - replaceBefore := c.config.expiresAt.Sub(c.config.issuedAt) / tokenRenewChecksDuringLifetime - - err := c.renewTokenIfNeeded(replaceBefore) - if err != nil { - c.config.Log.Error("unable to renew token", "error", err) - } - - ticker := time.NewTicker(replaceBefore) - defer ticker.Stop() - done := make(chan bool) - for { - select { - case <-done: - return - case <-ticker.C: - err := c.renewTokenIfNeeded(replaceBefore) - if err != nil { - c.config.Log.Error("unable to renew token", "error", err) - } - } - } -} - -func (c *client) renewTokenIfNeeded(replaceBefore time.Duration) error { - if time.Until(c.config.expiresAt) > replaceBefore { - return nil - } - c.config.Log.Info("call token refresh, current token expires soon", "expires", c.config.expiresAt.String()) - - c.Lock() - defer c.Unlock() - - resp, err := c.Apiv2().Token().Refresh(context.Background(), &api.TokenServiceRefreshRequest{}) - if err != nil { - return fmt.Errorf("unable to refresh token %w", err) - } - - c.config.Token = resp.Secret - err = c.config.parse() - if err != nil { - return fmt.Errorf("unable to parse token %w", err) - } - - if c.config.TokenRenewal.PersistTokenFn == nil { - return nil - } - - err = c.config.TokenRenewal.PersistTokenFn(c.config.Token) - if err != nil { - return fmt.Errorf("unable to persist token %w", err) - } - - c.config.Log.Info("token refreshed, new token expires in", "expires", c.config.expiresAt.String()) - return nil -}