Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
105 changes: 105 additions & 0 deletions internal/validation/suite.go
Original file line number Diff line number Diff line change
Expand Up @@ -119,6 +119,16 @@ func RunInstanceLifecycleValidation(t *testing.T, config ProviderConfig) {
require.NoError(t, err, "ValidateInstanceImage should pass")
})

t.Run("ValidateFirewallBlocksPort", func(t *testing.T) {
err := v1.ValidateFirewallBlocksPort(ctx, client, instance, ssh.GetTestPrivateKey(), v1.DefaultFirewallTestPort)
require.NoError(t, err, "ValidateFirewallBlocksPort should pass - non-allowed port should be blocked")
})

t.Run("ValidateDockerFirewallBlocksPort", func(t *testing.T) {
err := v1.ValidateDockerFirewallBlocksPort(ctx, client, instance, ssh.GetTestPrivateKey(), v1.DefaultFirewallTestPort)
require.NoError(t, err, "ValidateDockerFirewallBlocksPort should pass - docker port should be blocked by iptables")
})

if capabilities.IsCapable(v1.CapabilityStopStartInstance) && instance.Stoppable {
t.Run("ValidateStopStartInstance", func(t *testing.T) {
err := v1.ValidateStopStartInstance(ctx, client, instance)
Expand Down Expand Up @@ -235,6 +245,101 @@ func RunNetworkValidation(t *testing.T, config ProviderConfig, opts NetworkValid
})
}

type FirewallValidationOpts struct {
// TestPort is the port to test firewall blocking on (should NOT be in allowed ingress)
TestPort int
// TestDockerFirewall enables docker firewall validation (requires Docker on instance)
TestDockerFirewall bool
}

func RunFirewallValidation(t *testing.T, config ProviderConfig, opts FirewallValidationOpts) {
if testing.Short() {
t.Skip("Skipping validation tests in short mode")
}

ctx, cancel := context.WithTimeout(context.Background(), 20*time.Minute)
defer cancel()

client, err := config.Credential.MakeClient(ctx, config.Location)
if err != nil {
t.Fatalf("Failed to create client for %s: %v", config.Credential.GetCloudProviderID(), err)
}

types, err := client.GetInstanceTypes(ctx, v1.GetInstanceTypeArgs{
ArchitectureFilter: &v1.ArchitectureFilter{
IncludeArchitectures: []v1.Architecture{v1.ArchitectureX86_64},
},
})
require.NoError(t, err)
require.NotEmpty(t, types, "Should have instance types")

// Find an available instance type
attrs := v1.CreateInstanceAttrs{}
selectedType := v1.InstanceType{}
for _, typ := range types {
if typ.IsAvailable {
attrs.InstanceType = typ.Type
attrs.Location = typ.Location
attrs.PublicKey = ssh.GetTestPublicKey()
selectedType = typ
break
}
}
require.NotEmpty(t, attrs.InstanceType, "Should find available instance type")

// Create instance for firewall testing
instance, err := v1.ValidateCreateInstance(ctx, client, attrs, selectedType)
require.NoError(t, err, "ValidateCreateInstance should pass")
require.NotNil(t, instance)

defer func() {
if instance != nil {
_ = client.TerminateInstance(ctx, instance.CloudID)
}
}()

// Wait for instance to be running and SSH accessible
t.Run("ValidateSSHAccessible", func(t *testing.T) {
err := v1.ValidateInstanceSSHAccessible(ctx, client, instance, ssh.GetTestPrivateKey())
require.NoError(t, err, "ValidateSSHAccessible should pass")
})

// Refresh instance data
instance, err = client.GetInstance(ctx, instance.CloudID)
require.NoError(t, err)

testPort := opts.TestPort
if testPort == 0 {
testPort = v1.DefaultFirewallTestPort
}

// Test that regular server on 0.0.0.0 is blocked
t.Run("ValidateFirewallBlocksPort", func(t *testing.T) {
err := v1.ValidateFirewallBlocksPort(ctx, client, instance, ssh.GetTestPrivateKey(), testPort)
require.NoError(t, err, "ValidateFirewallBlocksPort should pass - port should be blocked")
})

// Test that Docker container on 0.0.0.0 is blocked (if enabled)
if opts.TestDockerFirewall {
t.Run("ValidateDockerFirewallBlocksPort", func(t *testing.T) {
err := v1.ValidateDockerFirewallBlocksPort(ctx, client, instance, ssh.GetTestPrivateKey(), testPort)
require.NoError(t, err, "ValidateDockerFirewallBlocksPort should pass - docker port should be blocked")
})
}

// Test that SSH port is accessible (sanity check)
t.Run("ValidateSSHPortAccessible", func(t *testing.T) {
err := v1.ValidateFirewallAllowsPort(ctx, client, instance, ssh.GetTestPrivateKey(), instance.SSHPort)
require.NoError(t, err, "ValidateFirewallAllowsPort should pass for SSH port")
})

// Terminate instance
t.Run("ValidateTerminateInstance", func(t *testing.T) {
err := v1.ValidateTerminateInstance(ctx, client, instance)
require.NoError(t, err, "ValidateTerminateInstance should pass")
})
}

type KubernetesValidationOpts struct {
Name string
RefID string
Expand Down
175 changes: 0 additions & 175 deletions v1/instance.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,14 +2,10 @@ package v1

import (
"context"
"errors"
"fmt"
"time"

"github.com/alecthomas/units"
"github.com/brevdev/cloud/internal/collections"
"github.com/brevdev/cloud/internal/ssh"
"github.com/google/uuid"
)

type CloudInstanceReader interface {
Expand All @@ -28,112 +24,11 @@ type CloudCreateTerminateInstance interface {
CloudInstanceReader
}

func ValidateCreateInstance(ctx context.Context, client CloudCreateTerminateInstance, attrs CreateInstanceAttrs, selectedType InstanceType) (*Instance, error) { //nolint:gocyclo // ok
t0 := time.Now().Add(-time.Minute)
attrs.RefID = uuid.New().String()
name, err := makeDebuggableName(attrs.Name)
if err != nil {
return nil, err
}
attrs.Name = name
i, err := client.CreateInstance(ctx, attrs)
if err != nil {
return nil, err
}
var validationErr error
t1 := time.Now().Add(1 * time.Minute)
diff := t1.Sub(t0)
if diff > 3*time.Minute {
validationErr = errors.Join(validationErr, fmt.Errorf("create instance took too long: %s", diff))
}
if i.CreatedAt.Before(t0) {
validationErr = errors.Join(validationErr, fmt.Errorf("createdAt is before t0: %s", i.CreatedAt))
}
if i.CreatedAt.After(t1) {
validationErr = errors.Join(validationErr, fmt.Errorf("createdAt is after t1: %s", i.CreatedAt))
}
if i.Name != name {
fmt.Printf("name mismatch: %s != %s, input name does not mean return name will be stable\n", i.Name, name)
}
if i.RefID != attrs.RefID {
validationErr = errors.Join(validationErr, fmt.Errorf("refID mismatch: %s != %s", i.RefID, attrs.RefID))
}
if attrs.Location != "" && attrs.Location != i.Location {
validationErr = errors.Join(validationErr, fmt.Errorf("location mismatch: %s != %s", attrs.Location, i.Location))
}
if attrs.SubLocation != "" && attrs.SubLocation != i.SubLocation {
validationErr = errors.Join(validationErr, fmt.Errorf("subLocation mismatch: %s != %s", attrs.SubLocation, i.SubLocation))
}
if attrs.InstanceType != "" && attrs.InstanceType != i.InstanceType {
validationErr = errors.Join(validationErr, fmt.Errorf("instanceType mismatch: %s != %s", attrs.InstanceType, i.InstanceType))
}
if selectedType.ID != "" && selectedType.ID != i.InstanceTypeID {
validationErr = errors.Join(validationErr, fmt.Errorf("instanceTypeID mismatch: %s != %s", selectedType.ID, i.InstanceTypeID))
}

return i, validationErr
}

func ValidateListCreatedInstance(ctx context.Context, client CloudCreateTerminateInstance, i *Instance) error {
ins, err := client.ListInstances(ctx, ListInstancesArgs{
Locations: []string{i.Location},
})
if err != nil {
return err
}
var validationErr error
if len(ins) == 0 {
validationErr = errors.Join(validationErr, fmt.Errorf("no instances found"))
}
foundInstance := collections.Find(ins, func(inst Instance) bool {
return inst.CloudID == i.CloudID
})
if foundInstance == nil {
validationErr = errors.Join(validationErr, fmt.Errorf("instance not found: %s", i.CloudID))
return validationErr
}
if foundInstance.Location != i.Location { //nolint:gocritic // fine
validationErr = errors.Join(validationErr, fmt.Errorf("location mismatch: %s != %s", foundInstance.Location, i.Location))
} else if foundInstance.RefID == "" {
validationErr = errors.Join(validationErr, fmt.Errorf("refID is empty"))
} else if foundInstance.RefID != i.RefID {
validationErr = errors.Join(validationErr, fmt.Errorf("refID mismatch: %s != %s", foundInstance.RefID, i.RefID))
} else if foundInstance.CloudCredRefID == "" {
validationErr = errors.Join(validationErr, fmt.Errorf("cloudCredRefID is empty"))
} else if foundInstance.CloudCredRefID != i.CloudCredRefID {
validationErr = errors.Join(validationErr, fmt.Errorf("cloudCredRefID mismatch: %s != %s", foundInstance.CloudCredRefID, i.CloudCredRefID))
}
return validationErr
}

func ValidateTerminateInstance(ctx context.Context, client CloudCreateTerminateInstance, instance *Instance) error {
err := client.TerminateInstance(ctx, instance.CloudID)
if err != nil {
return err
}
// TODO wait for instance to go into terminating state
return nil
}

type CloudStopStartInstance interface {
StopInstance(ctx context.Context, instanceID CloudProviderInstanceID) error
StartInstance(ctx context.Context, instanceID CloudProviderInstanceID) error
}

func ValidateStopStartInstance(ctx context.Context, client CloudStopStartInstance, instance *Instance) error {
err := client.StopInstance(ctx, instance.CloudID)
if err != nil {
return err
}
// TODO wait for stopped
err = client.StartInstance(ctx, instance.CloudID)
if err != nil {
return err
}
// TODO wait for running
return nil
}

type CloudRebootInstance interface {
RebootInstance(ctx context.Context, instanceID CloudProviderInstanceID) error
}
Expand All @@ -152,40 +47,6 @@ type UpdateHandler interface {
MergeInstanceTypeForUpdate(currIt InstanceType, newIt InstanceType) InstanceType
}

func ValidateMergeInstanceForUpdate(client UpdateHandler, currInst Instance, newInst Instance) error {
mergedInst := client.MergeInstanceForUpdate(currInst, newInst)

var validationErr error
if currInst.Name != mergedInst.Name {
validationErr = errors.Join(validationErr, fmt.Errorf("name mismatch: %s != %s", currInst.Name, mergedInst.Name))
}
if currInst.RefID != mergedInst.RefID {
validationErr = errors.Join(validationErr, fmt.Errorf("refID mismatch: %s != %s", currInst.RefID, mergedInst.RefID))
}
if currInst.Location != mergedInst.Location {
validationErr = errors.Join(validationErr, fmt.Errorf("location mismatch: %s != %s", currInst.Location, newInst.Location))
}
if currInst.SubLocation != mergedInst.SubLocation {
validationErr = errors.Join(validationErr, fmt.Errorf("subLocation mismatch: %s != %s", currInst.SubLocation, mergedInst.SubLocation))
}
if currInst.InstanceType != "" && currInst.InstanceType != mergedInst.InstanceType {
validationErr = errors.Join(validationErr, fmt.Errorf("instanceType mismatch: %s != %s", currInst.InstanceType, mergedInst.InstanceType))
}
if currInst.InstanceTypeID != "" && currInst.InstanceTypeID != mergedInst.InstanceTypeID {
validationErr = errors.Join(validationErr, fmt.Errorf("instanceTypeID mismatch: %s != %s", currInst.InstanceTypeID, mergedInst.InstanceTypeID))
}
if currInst.CloudCredRefID != mergedInst.CloudCredRefID {
validationErr = errors.Join(validationErr, fmt.Errorf("cloudCredRefID mismatch: %s != %s", currInst.CloudCredRefID, mergedInst.CloudCredRefID))
}
if currInst.VolumeType != "" && currInst.VolumeType != mergedInst.VolumeType {
validationErr = errors.Join(validationErr, fmt.Errorf("volumeType mismatch: %s != %s", currInst.VolumeType, mergedInst.VolumeType))
}
if currInst.Spot != mergedInst.Spot {
validationErr = errors.Join(validationErr, fmt.Errorf("spot mismatch: %v != %v", currInst.Spot, mergedInst.Spot))
}
return validationErr
}

type Instance struct {
Name string
RefID string
Expand Down Expand Up @@ -308,39 +169,3 @@ func makeDebuggableName(name string) (string, error) {
}

const RunningSSHTimeout = 10 * time.Minute

func ValidateInstanceSSHAccessible(ctx context.Context, client CloudInstanceReader, instance *Instance, privateKey string) error {
var err error
instance, err = WaitForInstanceLifecycleStatus(ctx, client, instance, LifecycleStatusRunning, PendingToRunningTimeout)
if err != nil {
return err
}
sshUser := instance.SSHUser
sshPort := instance.SSHPort
publicIP := instance.PublicIP
// Validate that we have the required SSH connection details
if sshUser == "" {
return fmt.Errorf("SSH user is not set for instance %s", instance.CloudID)
}
if sshPort == 0 {
return fmt.Errorf("SSH port is not set for instance %s", instance.CloudID)
}
if publicIP == "" {
return fmt.Errorf("public IP is not available for instance %s", instance.CloudID)
}

err = ssh.WaitForSSH(ctx, ssh.ConnectionConfig{
User: sshUser,
HostPort: fmt.Sprintf("%s:%d", publicIP, sshPort),
PrivKey: privateKey,
}, ssh.WaitForSSHOptions{
Timeout: RunningSSHTimeout,
})
if err != nil {
return err
}

fmt.Printf("SSH connection validated successfully for %s@%s:%d\n", sshUser, publicIP, sshPort)

return nil
}
Loading
Loading