diff --git a/internal/pkg/auth/storage.go b/internal/pkg/auth/storage.go index 5e857f6a7..686a0f677 100644 --- a/internal/pkg/auth/storage.go +++ b/internal/pkg/auth/storage.go @@ -27,16 +27,18 @@ const ( ) const ( - SESSION_EXPIRES_AT_UNIX authFieldKey = "session_expires_at_unix" - ACCESS_TOKEN authFieldKey = "access_token" - REFRESH_TOKEN authFieldKey = "refresh_token" - SERVICE_ACCOUNT_TOKEN authFieldKey = "service_account_token" - SERVICE_ACCOUNT_EMAIL authFieldKey = "service_account_email" - USER_EMAIL authFieldKey = "user_email" - SERVICE_ACCOUNT_KEY authFieldKey = "service_account_key" - PRIVATE_KEY authFieldKey = "private_key" - TOKEN_CUSTOM_ENDPOINT authFieldKey = "token_custom_endpoint" - IDP_TOKEN_ENDPOINT authFieldKey = "idp_token_endpoint" //nolint:gosec // linter false positive + SESSION_EXPIRES_AT_UNIX authFieldKey = "session_expires_at_unix" + ACCESS_TOKEN authFieldKey = "access_token" + REFRESH_TOKEN authFieldKey = "refresh_token" + SERVICE_ACCOUNT_TOKEN authFieldKey = "service_account_token" + SERVICE_ACCOUNT_EMAIL authFieldKey = "service_account_email" + USER_EMAIL authFieldKey = "user_email" + SERVICE_ACCOUNT_KEY authFieldKey = "service_account_key" + PRIVATE_KEY authFieldKey = "private_key" + TOKEN_CUSTOM_ENDPOINT authFieldKey = "token_custom_endpoint" + IDP_TOKEN_ENDPOINT authFieldKey = "idp_token_endpoint" //nolint:gosec // linter false positive + CACHE_ENCRYPTION_KEY authFieldKey = "cache_encryption_key" + CACHE_ENCRYPTION_KEY_AGE authFieldKey = "cache_encryption_key_age" ) const ( @@ -59,6 +61,8 @@ var authFieldKeys = []authFieldKey{ TOKEN_CUSTOM_ENDPOINT, IDP_TOKEN_ENDPOINT, authFlowType, + CACHE_ENCRYPTION_KEY, + CACHE_ENCRYPTION_KEY_AGE, } // All fields that are set when a user logs in diff --git a/internal/pkg/cache/cache.go b/internal/pkg/cache/cache.go index cf019ecb2..beaf87d12 100644 --- a/internal/pkg/cache/cache.go +++ b/internal/pkg/cache/cache.go @@ -1,26 +1,87 @@ package cache import ( + "crypto/aes" + "crypto/cipher" + "crypto/rand" + "encoding/base64" "errors" "fmt" "os" "path/filepath" "regexp" + "strconv" + "time" + + "github.com/stackitcloud/stackit-cli/internal/pkg/auth" ) var ( - cacheFolderPath string + cacheDirOverwrite string // for testing only + cacheFolderPath string + cacheEncryptionKey []byte identifierRegex = regexp.MustCompile("^[a-zA-Z0-9-]+$") ErrorInvalidCacheIdentifier = fmt.Errorf("invalid cache identifier") ) +const ( + cacheKeyMaxAge = 90 * 24 * time.Hour +) + func Init() error { - cacheDir, err := os.UserCacheDir() - if err != nil { - return fmt.Errorf("get user cache dir: %w", err) + var cacheDir string + if cacheDirOverwrite == "" { + var err error + cacheDir, err = os.UserCacheDir() + if err != nil { + return fmt.Errorf("get user cache dir: %w", err) + } + } else { + cacheDir = cacheDirOverwrite } + cacheFolderPath = filepath.Join(cacheDir, "stackit") + + // Encryption keys should only be used a limited number of times for aes-gcm. + // Thus, refresh the key periodically. This will invalidate all cached entries. + key, _ := auth.GetAuthField(auth.CACHE_ENCRYPTION_KEY) + age, _ := auth.GetAuthField(auth.CACHE_ENCRYPTION_KEY_AGE) + cacheEncryptionKey = nil + var keyAge time.Time + if age != "" { + ageSeconds, err := strconv.ParseInt(age, 10, 64) + if err == nil { + keyAge = time.Unix(ageSeconds, 0) + } + } + if key != "" && keyAge.Add(cacheKeyMaxAge).After(time.Now()) { + cacheEncryptionKey, _ = base64.StdEncoding.DecodeString(key) + // invalid key length + if len(cacheEncryptionKey) != 32 { + cacheEncryptionKey = nil + } + } + if len(cacheEncryptionKey) == 0 { + cacheEncryptionKey = make([]byte, 32) + _, err := rand.Read(cacheEncryptionKey) + if err != nil { + return fmt.Errorf("cache encryption key: %w", err) + } + key := base64.StdEncoding.EncodeToString(cacheEncryptionKey) + err = auth.SetAuthField(auth.CACHE_ENCRYPTION_KEY, key) + if err != nil { + return fmt.Errorf("save cache encryption key: %w", err) + } + err = auth.SetAuthField(auth.CACHE_ENCRYPTION_KEY_AGE, fmt.Sprint(time.Now().Unix())) + if err != nil { + return fmt.Errorf("save cache encryption key age: %w", err) + } + // cleanup old cache entries as they won't be readable anymore + if err := cleanupCache(); err != nil { + return err + } + } return nil } @@ -32,7 +93,21 @@ func GetObject(identifier string) ([]byte, error) { return nil, ErrorInvalidCacheIdentifier } - return os.ReadFile(filepath.Join(cacheFolderPath, identifier)) + data, err := os.ReadFile(filepath.Join(cacheFolderPath, identifier)) + if err != nil { + return nil, err + } + + block, err := aes.NewCipher(cacheEncryptionKey) + if err != nil { + return nil, err + } + aead, err := cipher.NewGCMWithRandomNonce(block) + if err != nil { + return nil, err + } + + return aead.Open(nil, nil, data, nil) } func PutObject(identifier string, data []byte) error { @@ -48,7 +123,17 @@ func PutObject(identifier string, data []byte) error { return err } - return os.WriteFile(filepath.Join(cacheFolderPath, identifier), data, 0o600) + block, err := aes.NewCipher(cacheEncryptionKey) + if err != nil { + return err + } + aead, err := cipher.NewGCMWithRandomNonce(block) + if err != nil { + return err + } + encrypted := aead.Seal(nil, nil, data, nil) + + return os.WriteFile(filepath.Join(cacheFolderPath, identifier), encrypted, 0o600) } func DeleteObject(identifier string) error { @@ -71,3 +156,26 @@ func validateCacheFolderPath() error { } return nil } + +func cleanupCache() error { + if err := validateCacheFolderPath(); err != nil { + return err + } + + entries, err := os.ReadDir(cacheFolderPath) + if err != nil { + if errors.Is(err, os.ErrNotExist) { + return nil + } + return err + } + + for _, entry := range entries { + name := entry.Name() + err := DeleteObject(name) + if err != nil && !errors.Is(err, ErrorInvalidCacheIdentifier) { + return err + } + } + return nil +} diff --git a/internal/pkg/cache/cache_test.go b/internal/pkg/cache/cache_test.go index cc68c6590..4ef45891b 100644 --- a/internal/pkg/cache/cache_test.go +++ b/internal/pkg/cache/cache_test.go @@ -6,10 +6,20 @@ import ( "path/filepath" "testing" + "github.com/google/go-cmp/cmp" "github.com/google/uuid" + "github.com/stackitcloud/stackit-cli/internal/pkg/auth" ) -func TestGetObject(t *testing.T) { +func overwriteCacheDir(t *testing.T) func() { + cacheDirOverwrite = t.TempDir() + return func() { + cacheDirOverwrite = "" + } +} + +func TestGetObjectErrors(t *testing.T) { + defer overwriteCacheDir(t)() if err := Init(); err != nil { t.Fatalf("cache init failed: %s", err) } @@ -17,25 +27,16 @@ func TestGetObject(t *testing.T) { tests := []struct { description string identifier string - expectFile bool expectedErr error }{ - { - description: "identifier exists", - identifier: "test-cache-get-exists", - expectFile: true, - expectedErr: nil, - }, { description: "identifier does not exist", identifier: "test-cache-get-not-exists", - expectFile: false, expectedErr: os.ErrNotExist, }, { description: "identifier is invalid", identifier: "in../../valid", - expectFile: false, expectedErr: ErrorInvalidCacheIdentifier, }, } @@ -44,17 +45,6 @@ func TestGetObject(t *testing.T) { t.Run(tt.description, func(t *testing.T) { id := tt.identifier + "-" + uuid.NewString() - // setup - if tt.expectFile { - err := os.MkdirAll(cacheFolderPath, 0o750) - if err != nil { - t.Fatalf("create cache folder: %s", err.Error()) - } - path := filepath.Join(cacheFolderPath, id) - if err := os.WriteFile(path, []byte("dummy"), 0o600); err != nil { - t.Fatalf("setup: WriteFile (%s) failed", path) - } - } // test file, err := GetObject(id) @@ -62,19 +52,14 @@ func TestGetObject(t *testing.T) { t.Fatalf("returned error (%q) does not match %q", err.Error(), tt.expectedErr.Error()) } - if tt.expectFile { - if len(file) < 1 { - t.Fatalf("expected a file but byte array is empty (len %d)", len(file)) - } - } else { - if len(file) > 0 { - t.Fatalf("didn't expect a file, but byte array is not empty (len %d)", len(file)) - } + if len(file) > 0 { + t.Fatalf("didn't expect a file, but byte array is not empty (len %d)", len(file)) } }) } } func TestPutObject(t *testing.T) { + defer overwriteCacheDir(t)() if err := Init(); err != nil { t.Fatalf("cache init failed: %s", err) } @@ -128,6 +113,10 @@ func TestPutObject(t *testing.T) { // setup if tt.existingFile { + err := os.MkdirAll(cacheFolderPath, 0o750) + if err != nil { + t.Fatalf("create cache folder: %s", err.Error()) + } if err := os.WriteFile(path, []byte("dummy"), 0o600); err != nil { t.Fatalf("setup: WriteFile (%s) failed", path) } @@ -149,6 +138,7 @@ func TestPutObject(t *testing.T) { } func TestDeleteObject(t *testing.T) { + defer overwriteCacheDir(t)() if err := Init(); err != nil { t.Fatalf("cache init failed: %s", err) } @@ -186,8 +176,11 @@ func TestDeleteObject(t *testing.T) { // setup if tt.existingFile { + if err := os.MkdirAll(cacheFolderPath, 0o700); err != nil { + t.Fatalf("setup: MkdirAll (%s) failed: %v", path, err) + } if err := os.WriteFile(path, []byte("dummy"), 0o600); err != nil { - t.Fatalf("setup: WriteFile (%s) failed", path) + t.Fatalf("setup: WriteFile (%s) failed: %v", path, err) } } // test @@ -205,3 +198,90 @@ func TestDeleteObject(t *testing.T) { }) } } + +func clearKeys(t *testing.T) { + t.Helper() + err := auth.DeleteAuthField(auth.CACHE_ENCRYPTION_KEY) + if err != nil { + t.Fatalf("delete cache encryption key: %v", err) + } + err = auth.DeleteAuthField(auth.CACHE_ENCRYPTION_KEY_AGE) + if err != nil { + t.Fatalf("delete cache encryption key age: %v", err) + } +} + +func TestWriteAndRead(t *testing.T) { + for _, tt := range []struct { + name string + clearKeys bool + }{ + { + name: "normal", + }, + { + name: "fresh keys", + clearKeys: true, + }, + } { + t.Run(tt.name, func(t *testing.T) { + defer overwriteCacheDir(t)() + if tt.clearKeys { + clearKeys(t) + } + if err := Init(); err != nil { + t.Fatalf("cache init failed: %s", err) + } + + id := "test-cycle-" + uuid.NewString() + data := []byte("test-data") + err := PutObject(id, data) + if err != nil { + t.Fatalf("putobject failed: %v", err) + } + + readData, err := GetObject(id) + if err != nil { + t.Fatalf("getobject failed: %v", err) + } + + diff := cmp.Diff(data, readData) + if diff != "" { + t.Fatalf("unexpected data diff: %v", diff) + } + }) + } +} + +func TestCacheCleanup(t *testing.T) { + defer overwriteCacheDir(t)() + if err := Init(); err != nil { + t.Fatalf("cache init failed: %s", err) + } + + id := "test-cycle-" + uuid.NewString() + data := []byte("test-data") + err := PutObject(id, data) + if err != nil { + t.Fatalf("putobject failed: %v", err) + } + + clearKeys(t) + + // initialize again to trigger cache cleanup + if err := Init(); err != nil { + t.Fatalf("cache init failed: %s", err) + } + + _, err = GetObject(id) + if !errors.Is(err, os.ErrNotExist) { + t.Fatalf("getobject failed with unexpected error: %v", err) + } +} + +func TestInit(t *testing.T) { + // test that init without cache directory overwrite works + if err := Init(); err != nil { + t.Fatalf("cache init failed: %s", err) + } +}