Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
24 changes: 14 additions & 10 deletions internal/pkg/auth/storage.go
Original file line number Diff line number Diff line change
Expand Up @@ -27,16 +27,18 @@ const (
)

const (
SESSION_EXPIRES_AT_UNIX authFieldKey = "session_expires_at_unix"
ACCESS_TOKEN authFieldKey = "access_token"
REFRESH_TOKEN authFieldKey = "refresh_token"
SERVICE_ACCOUNT_TOKEN authFieldKey = "service_account_token"
SERVICE_ACCOUNT_EMAIL authFieldKey = "service_account_email"
USER_EMAIL authFieldKey = "user_email"
SERVICE_ACCOUNT_KEY authFieldKey = "service_account_key"
PRIVATE_KEY authFieldKey = "private_key"
TOKEN_CUSTOM_ENDPOINT authFieldKey = "token_custom_endpoint"
IDP_TOKEN_ENDPOINT authFieldKey = "idp_token_endpoint" //nolint:gosec // linter false positive
SESSION_EXPIRES_AT_UNIX authFieldKey = "session_expires_at_unix"
ACCESS_TOKEN authFieldKey = "access_token"
REFRESH_TOKEN authFieldKey = "refresh_token"
SERVICE_ACCOUNT_TOKEN authFieldKey = "service_account_token"
SERVICE_ACCOUNT_EMAIL authFieldKey = "service_account_email"
USER_EMAIL authFieldKey = "user_email"
SERVICE_ACCOUNT_KEY authFieldKey = "service_account_key"
PRIVATE_KEY authFieldKey = "private_key"
TOKEN_CUSTOM_ENDPOINT authFieldKey = "token_custom_endpoint"
IDP_TOKEN_ENDPOINT authFieldKey = "idp_token_endpoint" //nolint:gosec // linter false positive
CACHE_ENCRYPTION_KEY authFieldKey = "cache_encryption_key"
CACHE_ENCRYPTION_KEY_AGE authFieldKey = "cache_encryption_key_age"
)

const (
Expand All @@ -59,6 +61,8 @@ var authFieldKeys = []authFieldKey{
TOKEN_CUSTOM_ENDPOINT,
IDP_TOKEN_ENDPOINT,
authFlowType,
CACHE_ENCRYPTION_KEY,
CACHE_ENCRYPTION_KEY_AGE,
}

// All fields that are set when a user logs in
Expand Down
120 changes: 114 additions & 6 deletions internal/pkg/cache/cache.go
Original file line number Diff line number Diff line change
@@ -1,26 +1,87 @@
package cache

import (
"crypto/aes"
"crypto/cipher"
"crypto/rand"
"encoding/base64"
"errors"
"fmt"
"os"
"path/filepath"
"regexp"
"strconv"
"time"

"github.com/stackitcloud/stackit-cli/internal/pkg/auth"
)

var (
cacheFolderPath string
cacheDirOverwrite string // for testing only
cacheFolderPath string
cacheEncryptionKey []byte

identifierRegex = regexp.MustCompile("^[a-zA-Z0-9-]+$")
ErrorInvalidCacheIdentifier = fmt.Errorf("invalid cache identifier")
)

const (
cacheKeyMaxAge = 90 * 24 * time.Hour
)

func Init() error {
cacheDir, err := os.UserCacheDir()
if err != nil {
return fmt.Errorf("get user cache dir: %w", err)
var cacheDir string
if cacheDirOverwrite == "" {
var err error
cacheDir, err = os.UserCacheDir()
if err != nil {
return fmt.Errorf("get user cache dir: %w", err)
}
} else {
cacheDir = cacheDirOverwrite
}

cacheFolderPath = filepath.Join(cacheDir, "stackit")

// Encryption keys should only be used a limited number of times for aes-gcm.
// Thus, refresh the key periodically. This will invalidate all cached entries.
key, _ := auth.GetAuthField(auth.CACHE_ENCRYPTION_KEY)
age, _ := auth.GetAuthField(auth.CACHE_ENCRYPTION_KEY_AGE)
cacheEncryptionKey = nil
var keyAge time.Time
if age != "" {
ageSeconds, err := strconv.ParseInt(age, 10, 64)
if err == nil {
keyAge = time.Unix(ageSeconds, 0)
}
}
if key != "" && keyAge.Add(cacheKeyMaxAge).After(time.Now()) {
cacheEncryptionKey, _ = base64.StdEncoding.DecodeString(key)
// invalid key length
if len(cacheEncryptionKey) != 32 {
cacheEncryptionKey = nil
}
}
if len(cacheEncryptionKey) == 0 {
cacheEncryptionKey = make([]byte, 32)
_, err := rand.Read(cacheEncryptionKey)
if err != nil {
return fmt.Errorf("cache encryption key: %w", err)
}
key := base64.StdEncoding.EncodeToString(cacheEncryptionKey)
err = auth.SetAuthField(auth.CACHE_ENCRYPTION_KEY, key)
if err != nil {
return fmt.Errorf("save cache encryption key: %w", err)
}
err = auth.SetAuthField(auth.CACHE_ENCRYPTION_KEY_AGE, fmt.Sprint(time.Now().Unix()))
if err != nil {
return fmt.Errorf("save cache encryption key age: %w", err)
}
// cleanup old cache entries as they won't be readable anymore
if err := cleanupCache(); err != nil {
return err
}
}
return nil
}

Expand All @@ -32,7 +93,21 @@ func GetObject(identifier string) ([]byte, error) {
return nil, ErrorInvalidCacheIdentifier
}

return os.ReadFile(filepath.Join(cacheFolderPath, identifier))
data, err := os.ReadFile(filepath.Join(cacheFolderPath, identifier))
if err != nil {
return nil, err
}

block, err := aes.NewCipher(cacheEncryptionKey)
if err != nil {
return nil, err
}
aead, err := cipher.NewGCMWithRandomNonce(block)
if err != nil {
return nil, err
}

return aead.Open(nil, nil, data, nil)
}

func PutObject(identifier string, data []byte) error {
Expand All @@ -48,7 +123,17 @@ func PutObject(identifier string, data []byte) error {
return err
}

return os.WriteFile(filepath.Join(cacheFolderPath, identifier), data, 0o600)
block, err := aes.NewCipher(cacheEncryptionKey)
if err != nil {
return err
}
aead, err := cipher.NewGCMWithRandomNonce(block)
if err != nil {
return err
}
encrypted := aead.Seal(nil, nil, data, nil)

return os.WriteFile(filepath.Join(cacheFolderPath, identifier), encrypted, 0o600)
}

func DeleteObject(identifier string) error {
Expand All @@ -71,3 +156,26 @@ func validateCacheFolderPath() error {
}
return nil
}

func cleanupCache() error {
if err := validateCacheFolderPath(); err != nil {
return err
}

entries, err := os.ReadDir(cacheFolderPath)
if err != nil {
if errors.Is(err, os.ErrNotExist) {
return nil
}
return err
}

for _, entry := range entries {
name := entry.Name()
err := DeleteObject(name)
if err != nil && !errors.Is(err, ErrorInvalidCacheIdentifier) {
return err
}
}
return nil
}
Loading
Loading