diff --git a/internal/validation/suite.go b/internal/validation/suite.go
index c4e04e1..31aab6d 100644
--- a/internal/validation/suite.go
+++ b/internal/validation/suite.go
@@ -119,6 +119,16 @@ func RunInstanceLifecycleValidation(t *testing.T, config ProviderConfig) {
 			require.NoError(t, err, "ValidateInstanceImage should pass")
 		})
 
+		t.Run("ValidateFirewallBlocksPort", func(t *testing.T) {
+			err := v1.ValidateFirewallBlocksPort(ctx, client, instance, ssh.GetTestPrivateKey(), v1.DefaultFirewallTestPort)
+			require.NoError(t, err, "ValidateFirewallBlocksPort should pass - non-allowed port should be blocked")
+		})
+
+		t.Run("ValidateDockerFirewallBlocksPort", func(t *testing.T) {
+			err := v1.ValidateDockerFirewallBlocksPort(ctx, client, instance, ssh.GetTestPrivateKey(), v1.DefaultFirewallTestPort)
+			require.NoError(t, err, "ValidateDockerFirewallBlocksPort should pass - docker port should be blocked by iptables")
+		})
+
 		if capabilities.IsCapable(v1.CapabilityStopStartInstance) && instance.Stoppable {
 			t.Run("ValidateStopStartInstance", func(t *testing.T) {
 				err := v1.ValidateStopStartInstance(ctx, client, instance)
@@ -235,6 +245,119 @@ func RunNetworkValidation(t *testing.T, config ProviderConfig, opts NetworkValid
 	})
 }
 
+// FirewallValidationOpts configures RunFirewallValidation.
+type FirewallValidationOpts struct {
+	// TestPort is the port to test firewall blocking on (should NOT be in allowed ingress)
+	TestPort int
+	// TestDockerFirewall enables docker firewall validation (requires Docker on instance)
+	TestDockerFirewall bool
+}
+
+// RunFirewallValidation creates an instance, checks that a non-allowed port is
+// blocked (for a plain listener and, optionally, a Docker-published port) while
+// the SSH port stays reachable, and terminates the instance afterwards.
+func RunFirewallValidation(t *testing.T, config ProviderConfig, opts FirewallValidationOpts) {
+	if testing.Short() {
+		t.Skip("Skipping validation tests in short mode")
+	}
+
+	ctx, cancel := context.WithTimeout(context.Background(), 20*time.Minute)
+	defer cancel()
+
+	client, err := config.Credential.MakeClient(ctx, config.Location)
+	if err != nil {
+		t.Fatalf("Failed to create client for %s: %v", config.Credential.GetCloudProviderID(), err)
+	}
+
+	types, err := client.GetInstanceTypes(ctx, v1.GetInstanceTypeArgs{
+		ArchitectureFilter: &v1.ArchitectureFilter{
+			IncludeArchitectures: []v1.Architecture{v1.ArchitectureX86_64},
+		},
+	})
+	require.NoError(t, err)
+	require.NotEmpty(t, types, "Should have instance types")
+
+	// Find an available instance type
+	attrs := v1.CreateInstanceAttrs{}
+	selectedType := v1.InstanceType{}
+	for _, typ := range types {
+		if typ.IsAvailable {
+			attrs.InstanceType = typ.Type
+			attrs.Location = typ.Location
+			attrs.PublicKey = ssh.GetTestPublicKey()
+			selectedType = typ
+			break
+		}
+	}
+	require.NotEmpty(t, attrs.InstanceType, "Should find available instance type")
+
+	// Create instance for firewall testing
+	instance, err := v1.ValidateCreateInstance(ctx, client, attrs, selectedType)
+	terminated := false
+	if instance != nil {
+		// Register cleanup before asserting: ValidateCreateInstance can return a
+		// created instance together with a validation error. The cloud ID is
+		// captured by value because instance is reassigned below, and a fresh
+		// context is used so cleanup still runs after ctx has expired.
+		cloudID := instance.CloudID
+		defer func() {
+			if terminated {
+				return
+			}
+			cleanupCtx, cleanupCancel := context.WithTimeout(context.Background(), 5*time.Minute)
+			defer cleanupCancel()
+			_ = client.TerminateInstance(cleanupCtx, cloudID)
+		}()
+	}
+	require.NoError(t, err, "ValidateCreateInstance should pass")
+	require.NotNil(t, instance)
+
+	// Wait for instance to be running and SSH accessible
+	t.Run("ValidateSSHAccessible", func(t *testing.T) {
+		err := v1.ValidateInstanceSSHAccessible(ctx, client, instance, ssh.GetTestPrivateKey())
+		require.NoError(t, err, "ValidateSSHAccessible should pass")
+	})
+
+	// Refresh instance data
+	instance, err = client.GetInstance(ctx, instance.CloudID)
+	require.NoError(t, err)
+	require.NotNil(t, instance)
+
+	testPort := opts.TestPort
+	if testPort == 0 {
+		testPort = v1.DefaultFirewallTestPort
+	}
+
+	// Test that regular server on 0.0.0.0 is blocked
+	t.Run("ValidateFirewallBlocksPort", func(t *testing.T) {
+		err := v1.ValidateFirewallBlocksPort(ctx, client, instance, ssh.GetTestPrivateKey(), testPort)
+		require.NoError(t, err, "ValidateFirewallBlocksPort should pass - port should be blocked")
+	})
+
+	// Test that Docker container on 0.0.0.0 is blocked (if enabled)
+	if opts.TestDockerFirewall {
+		t.Run("ValidateDockerFirewallBlocksPort", func(t *testing.T) {
+			err := v1.ValidateDockerFirewallBlocksPort(ctx, client, instance, ssh.GetTestPrivateKey(), testPort)
+			require.NoError(t, err, "ValidateDockerFirewallBlocksPort should pass - docker port should be blocked")
+		})
+	}
+
+	// Test that SSH port is accessible (sanity check)
+	t.Run("ValidateSSHPortAccessible", func(t *testing.T) {
+		err := v1.ValidateFirewallAllowsPort(ctx, client, instance, ssh.GetTestPrivateKey(), instance.SSHPort)
+		require.NoError(t, err, "ValidateFirewallAllowsPort should pass for SSH port")
+	})
+
+	// Terminate instance
+	t.Run("ValidateTerminateInstance", func(t *testing.T) {
+		err := v1.ValidateTerminateInstance(ctx, client, instance)
+		require.NoError(t, err, "ValidateTerminateInstance should pass")
+		// Only reached on success (require.NoError aborts the subtest otherwise),
+		// so the deferred cleanup is skipped exactly when termination worked.
+		terminated = true
+	})
+}
+
 type KubernetesValidationOpts struct {
 	Name  string
 	RefID string
diff --git a/v1/instance.go b/v1/instance.go
index d56a5e0..0a9b7b2 100644
--- a/v1/instance.go
+++ b/v1/instance.go
@@ -2,14 +2,10 @@ package v1
 
 import (
 	"context"
-	"errors"
 	"fmt"
 	"time"
 
 	"github.com/alecthomas/units"
-	"github.com/brevdev/cloud/internal/collections"
-	"github.com/brevdev/cloud/internal/ssh"
-	"github.com/google/uuid"
 )
 
 type CloudInstanceReader interface {
@@ -28,112 +24,11 @@ type CloudCreateTerminateInstance interface {
 	CloudInstanceReader
 }
 
-func ValidateCreateInstance(ctx context.Context, client CloudCreateTerminateInstance, attrs CreateInstanceAttrs, selectedType InstanceType) (*Instance, error) { //nolint:gocyclo // ok
-	t0 := time.Now().Add(-time.Minute)
-	attrs.RefID = uuid.New().String()
-	name, err := makeDebuggableName(attrs.Name)
-	if err != nil {
-		return nil, err
-	}
-	attrs.Name = name
-	i, err := client.CreateInstance(ctx, attrs)
-	if err != nil {
-		return nil, err
-	}
-	var validationErr error
-	t1 := time.Now().Add(1 * time.Minute)
-	diff := t1.Sub(t0)
-	if diff > 3*time.Minute {
-		validationErr = errors.Join(validationErr, fmt.Errorf("create instance took too long: %s", diff))
-	}
-	if i.CreatedAt.Before(t0) {
-		validationErr = errors.Join(validationErr, fmt.Errorf("createdAt is before t0: %s", i.CreatedAt))
-	}
-	if i.CreatedAt.After(t1) {
-		validationErr = errors.Join(validationErr, fmt.Errorf("createdAt is after t1: %s", i.CreatedAt))
-	}
-	if i.Name != name {
-		fmt.Printf("name mismatch: %s != %s, input name does not mean return name will be stable\n", i.Name, name)
-	}
-	if i.RefID != attrs.RefID {
-		validationErr = errors.Join(validationErr, fmt.Errorf("refID mismatch: %s != %s", i.RefID, attrs.RefID))
-	}
-	if attrs.Location != "" && attrs.Location != i.Location {
-		validationErr = errors.Join(validationErr, fmt.Errorf("location mismatch: %s != %s", attrs.Location, i.Location))
-	}
-	if attrs.SubLocation != "" && attrs.SubLocation != i.SubLocation {
-		validationErr = errors.Join(validationErr, fmt.Errorf("subLocation mismatch: %s != %s", attrs.SubLocation, i.SubLocation))
-	}
-	if attrs.InstanceType != "" && attrs.InstanceType != i.InstanceType {
-		validationErr = errors.Join(validationErr, fmt.Errorf("instanceType mismatch: %s != %s", attrs.InstanceType, i.InstanceType))
-	}
-	if selectedType.ID != "" && selectedType.ID != i.InstanceTypeID {
-		validationErr = errors.Join(validationErr, fmt.Errorf("instanceTypeID mismatch: %s != %s", selectedType.ID, i.InstanceTypeID))
-	}
-
-	return i, validationErr
-}
-
-func ValidateListCreatedInstance(ctx context.Context, client CloudCreateTerminateInstance, i *Instance) error {
-	ins, err := client.ListInstances(ctx, ListInstancesArgs{
-		Locations: []string{i.Location},
-	})
-	if err != nil {
-		return err
-	}
-	var validationErr error
-	if len(ins) == 0 {
-		validationErr = errors.Join(validationErr, fmt.Errorf("no instances found"))
-	}
-	foundInstance := collections.Find(ins, func(inst Instance) bool {
-		return inst.CloudID == i.CloudID
-	})
-	if foundInstance == nil {
-		validationErr = errors.Join(validationErr, fmt.Errorf("instance not found: %s", i.CloudID))
-		return validationErr
-	}
-	if foundInstance.Location != i.Location { //nolint:gocritic // fine
-		validationErr = errors.Join(validationErr, fmt.Errorf("location mismatch: %s != %s", foundInstance.Location, i.Location))
-	} else if foundInstance.RefID == "" {
-		validationErr = errors.Join(validationErr, fmt.Errorf("refID is empty"))
-	} else if foundInstance.RefID != i.RefID {
-		validationErr = errors.Join(validationErr, fmt.Errorf("refID mismatch: %s != %s", foundInstance.RefID, i.RefID))
-	} else if foundInstance.CloudCredRefID == "" {
-		validationErr = errors.Join(validationErr, fmt.Errorf("cloudCredRefID is empty"))
-	} else if foundInstance.CloudCredRefID != i.CloudCredRefID {
-		validationErr = errors.Join(validationErr, fmt.Errorf("cloudCredRefID mismatch: %s != %s", foundInstance.CloudCredRefID, i.CloudCredRefID))
-	}
-	return validationErr
-}
-
-func ValidateTerminateInstance(ctx context.Context, client CloudCreateTerminateInstance, instance *Instance) error {
-	err := client.TerminateInstance(ctx, instance.CloudID)
-	if err != nil {
-		return err
-	}
-	// TODO wait for instance to go into terminating state
-	return nil
-}
-
 type CloudStopStartInstance interface {
 	StopInstance(ctx context.Context, instanceID CloudProviderInstanceID) error
 	StartInstance(ctx context.Context, instanceID CloudProviderInstanceID) error
 }
 
-func ValidateStopStartInstance(ctx context.Context, client CloudStopStartInstance, instance *Instance) error {
-	err := client.StopInstance(ctx, instance.CloudID)
-	if err != nil {
-		return err
-	}
-	// TODO wait for stopped
-	err = client.StartInstance(ctx, instance.CloudID)
-	if err != nil {
-		return err
-	}
-	// TODO wait for running
-	return nil
-}
-
 type CloudRebootInstance interface {
 	RebootInstance(ctx context.Context, instanceID CloudProviderInstanceID) error
 }
@@ -152,40 +47,6 @@ type UpdateHandler interface {
 	MergeInstanceTypeForUpdate(currIt InstanceType, newIt InstanceType) InstanceType
 }
 
-func ValidateMergeInstanceForUpdate(client UpdateHandler, currInst Instance, newInst Instance) error {
-	mergedInst := client.MergeInstanceForUpdate(currInst, newInst)
-
-	var validationErr error
-	if currInst.Name != mergedInst.Name {
-		validationErr = errors.Join(validationErr, fmt.Errorf("name mismatch: %s != %s", currInst.Name, mergedInst.Name))
-	}
-	if currInst.RefID != mergedInst.RefID {
-		validationErr = errors.Join(validationErr, fmt.Errorf("refID mismatch: %s != %s", currInst.RefID, mergedInst.RefID))
-	}
-	if currInst.Location != mergedInst.Location {
-		validationErr = errors.Join(validationErr, fmt.Errorf("location mismatch: %s != %s", currInst.Location, newInst.Location))
-	}
-	if currInst.SubLocation != mergedInst.SubLocation {
-		validationErr = errors.Join(validationErr, fmt.Errorf("subLocation mismatch: %s != %s", currInst.SubLocation, mergedInst.SubLocation))
-	}
-	if currInst.InstanceType != "" && currInst.InstanceType != mergedInst.InstanceType {
-		validationErr = errors.Join(validationErr, fmt.Errorf("instanceType mismatch: %s != %s", currInst.InstanceType, mergedInst.InstanceType))
-	}
-	if currInst.InstanceTypeID != "" && currInst.InstanceTypeID != mergedInst.InstanceTypeID {
-		validationErr = errors.Join(validationErr, fmt.Errorf("instanceTypeID mismatch: %s != %s", currInst.InstanceTypeID, mergedInst.InstanceTypeID))
-	}
-	if currInst.CloudCredRefID != mergedInst.CloudCredRefID {
-		validationErr = errors.Join(validationErr, fmt.Errorf("cloudCredRefID mismatch: %s != %s", currInst.CloudCredRefID, mergedInst.CloudCredRefID))
-	}
-	if currInst.VolumeType != "" && currInst.VolumeType != mergedInst.VolumeType {
-		validationErr = errors.Join(validationErr, fmt.Errorf("volumeType mismatch: %s != %s", currInst.VolumeType, mergedInst.VolumeType))
-	}
-	if currInst.Spot != mergedInst.Spot {
-		validationErr = errors.Join(validationErr, fmt.Errorf("spot mismatch: %v != %v", currInst.Spot, mergedInst.Spot))
-	}
-	return validationErr
-}
-
 type Instance struct {
 	Name  string
 	RefID string
@@ -308,39 +169,3 @@ func makeDebuggableName(name string) (string, error) {
 }
 
 const RunningSSHTimeout = 10 * time.Minute
-
-func ValidateInstanceSSHAccessible(ctx context.Context, client CloudInstanceReader, instance *Instance, privateKey string) error {
-	var err error
-	instance, err = WaitForInstanceLifecycleStatus(ctx, client, instance, LifecycleStatusRunning, PendingToRunningTimeout)
-	if err != nil {
-		return err
-	}
-	sshUser := instance.SSHUser
-	sshPort := instance.SSHPort
-	publicIP := instance.PublicIP
-	// Validate that we have the required SSH connection details
-	if sshUser == "" {
-		return fmt.Errorf("SSH user is not set for instance %s", instance.CloudID)
-	}
-	if sshPort == 0 {
-		return fmt.Errorf("SSH port is not set for instance %s", instance.CloudID)
-	}
-	if publicIP == "" {
-		return fmt.Errorf("public IP is not available for instance %s", instance.CloudID)
-	}
-
-	err = ssh.WaitForSSH(ctx, ssh.ConnectionConfig{
-		User:     sshUser,
-		HostPort: fmt.Sprintf("%s:%d", publicIP, sshPort),
-		PrivKey:  privateKey,
-	}, ssh.WaitForSSHOptions{
-		Timeout: RunningSSHTimeout,
-	})
-	if err != nil {
-		return err
-	}
-
-	fmt.Printf("SSH connection validated successfully for %s@%s:%d\n", sshUser, publicIP, sshPort)
-
-	return nil
-}
diff --git a/v1/instance_validation.go b/v1/instance_validation.go
new file mode 100644
index 0000000..b2f16ec
--- /dev/null
+++ b/v1/instance_validation.go
@@ -0,0 +1,195 @@
+package v1
+
+import (
+	"context"
+	"errors"
+	"fmt"
+	"time"
+
+	"github.com/brevdev/cloud/internal/collections"
+	"github.com/brevdev/cloud/internal/ssh"
+	"github.com/google/uuid"
+)
+
+// ValidateCreateInstance creates an instance from attrs (with a fresh RefID and a
+// debuggable name) and checks the returned instance against the request.
+// The created instance is returned even when validation fails, so callers must
+// still terminate it.
+func ValidateCreateInstance(ctx context.Context, client CloudCreateTerminateInstance, attrs CreateInstanceAttrs, selectedType InstanceType) (*Instance, error) { //nolint:gocyclo // ok
+	t0 := time.Now().Add(-time.Minute)
+	attrs.RefID = uuid.New().String()
+	name, err := makeDebuggableName(attrs.Name)
+	if err != nil {
+		return nil, err
+	}
+	attrs.Name = name
+	i, err := client.CreateInstance(ctx, attrs)
+	if err != nil {
+		return nil, err
+	}
+	var validationErr error
+	t1 := time.Now().Add(1 * time.Minute)
+	diff := t1.Sub(t0)
+	if diff > 3*time.Minute {
+		validationErr = errors.Join(validationErr, fmt.Errorf("create instance took too long: %s", diff))
+	}
+	if i.CreatedAt.Before(t0) {
+		validationErr = errors.Join(validationErr, fmt.Errorf("createdAt is before t0: %s", i.CreatedAt))
+	}
+	if i.CreatedAt.After(t1) {
+		validationErr = errors.Join(validationErr, fmt.Errorf("createdAt is after t1: %s", i.CreatedAt))
+	}
+	if i.Name != name {
+		fmt.Printf("name mismatch: %s != %s, input name does not mean return name will be stable\n", i.Name, name)
+	}
+	if i.RefID != attrs.RefID {
+		validationErr = errors.Join(validationErr, fmt.Errorf("refID mismatch: %s != %s", i.RefID, attrs.RefID))
+	}
+	if attrs.Location != "" && attrs.Location != i.Location {
+		validationErr = errors.Join(validationErr, fmt.Errorf("location mismatch: %s != %s", attrs.Location, i.Location))
+	}
+	if attrs.SubLocation != "" && attrs.SubLocation != i.SubLocation {
+		validationErr = errors.Join(validationErr, fmt.Errorf("subLocation mismatch: %s != %s", attrs.SubLocation, i.SubLocation))
+	}
+	if attrs.InstanceType != "" && attrs.InstanceType != i.InstanceType {
+		validationErr = errors.Join(validationErr, fmt.Errorf("instanceType mismatch: %s != %s", attrs.InstanceType, i.InstanceType))
+	}
+	if selectedType.ID != "" && selectedType.ID != i.InstanceTypeID {
+		validationErr = errors.Join(validationErr, fmt.Errorf("instanceTypeID mismatch: %s != %s", selectedType.ID, i.InstanceTypeID))
+	}
+
+	return i, validationErr
+}
+
+// ValidateListCreatedInstance lists instances in i.Location and checks that i is
+// present with matching location, RefID and CloudCredRefID.
+func ValidateListCreatedInstance(ctx context.Context, client CloudCreateTerminateInstance, i *Instance) error {
+	ins, err := client.ListInstances(ctx, ListInstancesArgs{
+		Locations: []string{i.Location},
+	})
+	if err != nil {
+		return err
+	}
+	var validationErr error
+	if len(ins) == 0 {
+		validationErr = errors.Join(validationErr, fmt.Errorf("no instances found"))
+	}
+	foundInstance := collections.Find(ins, func(inst Instance) bool {
+		return inst.CloudID == i.CloudID
+	})
+	if foundInstance == nil {
+		validationErr = errors.Join(validationErr, fmt.Errorf("instance not found: %s", i.CloudID))
+		return validationErr
+	}
+	if foundInstance.Location != i.Location { //nolint:gocritic // fine
+		validationErr = errors.Join(validationErr, fmt.Errorf("location mismatch: %s != %s", foundInstance.Location, i.Location))
+	} else if foundInstance.RefID == "" {
+		validationErr = errors.Join(validationErr, fmt.Errorf("refID is empty"))
+	} else if foundInstance.RefID != i.RefID {
+		validationErr = errors.Join(validationErr, fmt.Errorf("refID mismatch: %s != %s", foundInstance.RefID, i.RefID))
+	} else if foundInstance.CloudCredRefID == "" {
+		validationErr = errors.Join(validationErr, fmt.Errorf("cloudCredRefID is empty"))
+	} else if foundInstance.CloudCredRefID != i.CloudCredRefID {
+		validationErr = errors.Join(validationErr, fmt.Errorf("cloudCredRefID mismatch: %s != %s", foundInstance.CloudCredRefID, i.CloudCredRefID))
+	}
+	return validationErr
+}
+
+// ValidateTerminateInstance terminates the instance; it does not wait for the state change.
+func ValidateTerminateInstance(ctx context.Context, client CloudCreateTerminateInstance, instance *Instance) error {
+	err := client.TerminateInstance(ctx, instance.CloudID)
+	if err != nil {
+		return err
+	}
+	// TODO wait for instance to go into terminating state
+	return nil
+}
+
+// ValidateStopStartInstance stops and then starts the instance without waiting for either transition.
+func ValidateStopStartInstance(ctx context.Context, client CloudStopStartInstance, instance *Instance) error {
+	err := client.StopInstance(ctx, instance.CloudID)
+	if err != nil {
+		return err
+	}
+	// TODO wait for stopped
+	err = client.StartInstance(ctx, instance.CloudID)
+	if err != nil {
+		return err
+	}
+	// TODO wait for running
+	return nil
+}
+
+// ValidateMergeInstanceForUpdate checks that MergeInstanceForUpdate keeps the
+// current instance's identifying fields in the merged result.
+func ValidateMergeInstanceForUpdate(client UpdateHandler, currInst Instance, newInst Instance) error {
+	mergedInst := client.MergeInstanceForUpdate(currInst, newInst)
+
+	var validationErr error
+	if currInst.Name != mergedInst.Name {
+		validationErr = errors.Join(validationErr, fmt.Errorf("name mismatch: %s != %s", currInst.Name, mergedInst.Name))
+	}
+	if currInst.RefID != mergedInst.RefID {
+		validationErr = errors.Join(validationErr, fmt.Errorf("refID mismatch: %s != %s", currInst.RefID, mergedInst.RefID))
+	}
+	if currInst.Location != mergedInst.Location {
+		validationErr = errors.Join(validationErr, fmt.Errorf("location mismatch: %s != %s", currInst.Location, mergedInst.Location))
+	}
+	if currInst.SubLocation != mergedInst.SubLocation {
+		validationErr = errors.Join(validationErr, fmt.Errorf("subLocation mismatch: %s != %s", currInst.SubLocation, mergedInst.SubLocation))
+	}
+	if currInst.InstanceType != "" && currInst.InstanceType != mergedInst.InstanceType {
+		validationErr = errors.Join(validationErr, fmt.Errorf("instanceType mismatch: %s != %s", currInst.InstanceType, mergedInst.InstanceType))
+	}
+	if currInst.InstanceTypeID != "" && currInst.InstanceTypeID != mergedInst.InstanceTypeID {
+		validationErr = errors.Join(validationErr, fmt.Errorf("instanceTypeID mismatch: %s != %s", currInst.InstanceTypeID, mergedInst.InstanceTypeID))
+	}
+	if currInst.CloudCredRefID != mergedInst.CloudCredRefID {
+		validationErr = errors.Join(validationErr, fmt.Errorf("cloudCredRefID mismatch: %s != %s", currInst.CloudCredRefID, mergedInst.CloudCredRefID))
+	}
+	if currInst.VolumeType != "" && currInst.VolumeType != mergedInst.VolumeType {
+		validationErr = errors.Join(validationErr, fmt.Errorf("volumeType mismatch: %s != %s", currInst.VolumeType, mergedInst.VolumeType))
+	}
+	if currInst.Spot != mergedInst.Spot {
+		validationErr = errors.Join(validationErr, fmt.Errorf("spot mismatch: %v != %v", currInst.Spot, mergedInst.Spot))
+	}
+	return validationErr
+}
+
+// ValidateInstanceSSHAccessible waits for the instance to be running and then
+// for an SSH connection with privateKey to succeed.
+func ValidateInstanceSSHAccessible(ctx context.Context, client CloudInstanceReader, instance *Instance, privateKey string) error {
+	var err error
+	instance, err = WaitForInstanceLifecycleStatus(ctx, client, instance, LifecycleStatusRunning, PendingToRunningTimeout)
+	if err != nil {
+		return err
+	}
+	sshUser := instance.SSHUser
+	sshPort := instance.SSHPort
+	publicIP := instance.PublicIP
+	// Validate that we have the required SSH connection details
+	if sshUser == "" {
+		return fmt.Errorf("SSH user is not set for instance %s", instance.CloudID)
+	}
+	if sshPort == 0 {
+		return fmt.Errorf("SSH port is not set for instance %s", instance.CloudID)
+	}
+	if publicIP == "" {
+		return fmt.Errorf("public IP is not available for instance %s", instance.CloudID)
+	}
+
+	err = ssh.WaitForSSH(ctx, ssh.ConnectionConfig{
+		User:     sshUser,
+		HostPort: fmt.Sprintf("%s:%d", publicIP, sshPort),
+		PrivKey:  privateKey,
+	}, ssh.WaitForSSHOptions{
+		Timeout: RunningSSHTimeout,
+	})
+	if err != nil {
+		return err
+	}
+
+	fmt.Printf("SSH connection validated successfully for %s@%s:%d\n", sshUser, publicIP, sshPort)
+
+	return nil
+}
diff --git a/v1/networking_validation.go b/v1/networking_validation.go
new file mode 100644
index 0000000..97c2c55
--- /dev/null
+++ b/v1/networking_validation.go
@@ -0,0 +1,309 @@
+package v1
+
+import (
+	"context"
+	"errors"
+	"fmt"
+	"net"
+	"net/http"
+	"strings"
+	"time"
+
+	"github.com/brevdev/cloud/internal/ssh"
+)
+
+const (
+	// DefaultFirewallTestPort is the port used for testing firewall rules
+	// This port should NOT be in the allowed ingress rules
+	DefaultFirewallTestPort = 9999
+
+	// FirewallTestTimeout is the timeout for testing port accessibility
+	FirewallTestTimeout = 10 * time.Second
+
+	// PortConnectionTimeout is the timeout for a single connection attempt
+	PortConnectionTimeout = 5 * time.Second
+)
+
+// connectForFirewallTest waits for the instance to reach the running state,
+// checks that it exposes a public IP, and opens an SSH session to it.
+// It returns the refreshed instance and the SSH client; the caller must close the client.
+func connectForFirewallTest(ctx context.Context, client CloudInstanceReader, instance *Instance, privateKey string) (*Instance, *ssh.Client, error) {
+	instance, err := WaitForInstanceLifecycleStatus(ctx, client, instance, LifecycleStatusRunning, PendingToRunningTimeout)
+	if err != nil {
+		return nil, nil, fmt.Errorf("failed to wait for instance running: %w", err)
+	}
+	if instance.PublicIP == "" {
+		return nil, nil, fmt.Errorf("public IP is not available for instance %s", instance.CloudID)
+	}
+
+	sshClient, err := ssh.ConnectToHost(ctx, ssh.ConnectionConfig{
+		User:     instance.SSHUser,
+		HostPort: fmt.Sprintf("%s:%d", instance.PublicIP, instance.SSHPort),
+		PrivKey:  privateKey,
+	})
+	if err != nil {
+		return nil, nil, fmt.Errorf("failed to SSH into instance: %w", err)
+	}
+	return instance, sshClient, nil
+}
+
+// sleepOrDone pauses for d, returning early with ctx.Err() if ctx is done first.
+func sleepOrDone(ctx context.Context, d time.Duration) error {
+	timer := time.NewTimer(d)
+	defer timer.Stop()
+	select {
+	case <-ctx.Done():
+		return ctx.Err()
+	case <-timer.C:
+		return nil
+	}
+}
+
+// killTestListenerCmd returns a shell command that stops the nc test listener on port.
+// The bracketed pattern keeps pkill -f from matching the shell that runs this command.
+func killTestListenerCmd(port int) string {
+	return fmt.Sprintf("pkill -f '[n]c -l -p %d' || true", port)
+}
+
+// ValidateFirewallBlocksPort validates that a port is NOT accessible from outside the instance.
+// This is used to verify that firewall rules (UFW, iptables) are working correctly.
+// It starts a throwaway listener on the instance over SSH, dials the port from
+// this process, and returns an error if the connection succeeds.
+func ValidateFirewallBlocksPort(ctx context.Context, client CloudInstanceReader, instance *Instance, privateKey string, port int) error {
+	instance, sshClient, err := connectForFirewallTest(ctx, client, instance, privateKey)
+	if err != nil {
+		return err
+	}
+	defer func() { _ = sshClient.Close() }()
+
+	// Start a simple HTTP server on 0.0.0.0:port in the background
+	// The server will respond with "OK" to any request
+	startServerCmd := fmt.Sprintf(
+		"nohup sh -c 'echo -e \"HTTP/1.1 200 OK\\r\\nContent-Length: 2\\r\\n\\r\\nOK\" | nc -l -p %d' > /dev/null 2>&1 &",
+		port,
+	)
+	_, _, err = sshClient.RunCommand(ctx, startServerCmd)
+	if err != nil {
+		return fmt.Errorf("failed to start test server on instance: %w", err)
+	}
+	// Best-effort cleanup on every return path, including a failed check.
+	// Deferred after sshClient.Close, so it runs while the session is still open.
+	defer func() { _, _, _ = sshClient.RunCommand(ctx, killTestListenerCmd(port)) }()
+
+	// Give the server a moment to start
+	if err := sleepOrDone(ctx, 500*time.Millisecond); err != nil {
+		return err
+	}
+
+	// Now try to connect to the port from outside - this should FAIL
+	return checkPortBlocked(ctx, instance.PublicIP, port)
+}
+
+// ValidateDockerFirewallBlocksPort validates that a Docker container listening on 0.0.0.0
+// is NOT accessible from outside the instance due to DOCKER-USER iptables rules.
+// It never modifies the instance's firewall: the rules under test are only read for logging.
+func ValidateDockerFirewallBlocksPort(ctx context.Context, client CloudInstanceReader, instance *Instance, privateKey string, port int) error {
+	instance, sshClient, err := connectForFirewallTest(ctx, client, instance, privateKey)
+	if err != nil {
+		return err
+	}
+	defer func() { _ = sshClient.Close() }()
+	publicIP := instance.PublicIP
+
+	// Check if Docker is available
+	_, _, err = sshClient.RunCommand(ctx, "docker --version")
+	if err != nil {
+		return fmt.Errorf("docker is not available on instance: %w", err)
+	}
+
+	// Start a Docker container with a simple HTTP server
+	// Using nginx as it's commonly available
+	containerName := fmt.Sprintf("firewall-test-%d", port)
+	startDockerCmd := fmt.Sprintf(
+		"docker run -d --rm --name %s -p %d:%d nginx:alpine",
+		containerName, port, 80,
+	)
+	_, stderr, err := sshClient.RunCommand(ctx, startDockerCmd)
+	if err != nil {
+		return fmt.Errorf("failed to start docker container: %w, stderr: %s", err, stderr)
+	}
+	// Clean up on every return path: stop and remove the container
+	defer func() {
+		_, _, _ = sshClient.RunCommand(ctx, fmt.Sprintf("docker rm -f %s || true", containerName))
+	}()
+
+	// Wait for container to be running and service to be ready
+	if err := waitForDockerService(ctx, sshClient, containerName, port); err != nil {
+		return err
+	}
+
+	// Debug aid: log the DOCKER-USER chain that is expected to block the port.
+	// This is read-only; flushing the chain here would disable the very
+	// protection under test and leave the instance exposed.
+	iptablesOut, _, _ := sshClient.RunCommand(ctx, "sudo iptables -L DOCKER-USER -n -v 2>&1 || echo 'DOCKER-USER chain not found'")
+	fmt.Printf("Instance %s DOCKER-USER iptables rules:\n%s\n", publicIP, iptablesOut)
+
+	// Now try to connect to the port from outside - this should FAIL
+	fmt.Printf("Testing external connectivity to %s:%d\n", publicIP, port)
+	return checkPortBlocked(ctx, publicIP, port)
+}
+
+// waitForDockerService waits for a Docker container's service to be ready and responding
+func waitForDockerService(ctx context.Context, sshClient *ssh.Client, containerName string, port int) error {
+	for i := 0; i < 30; i++ { // Try for up to 30 seconds
+		if err := sleepOrDone(ctx, 1*time.Second); err != nil {
+			return err
+		}
+
+		// Check container is running
+		checkContainerCmd := fmt.Sprintf("docker ps --filter name=%s --format '{{.Names}}'", containerName)
+		stdout, _, err := sshClient.RunCommand(ctx, checkContainerCmd)
+		if err != nil || strings.TrimSpace(stdout) == "" {
+			continue
+		}
+
+		// Check if the published port answers on the instance itself (via localhost)
+		checkPortCmd := fmt.Sprintf("curl -s -o /dev/null -w '%%{http_code}' --connect-timeout 2 http://localhost:%d/ || echo 'failed'", port)
+		stdout, _, err = sshClient.RunCommand(ctx, checkPortCmd)
+		// On failure curl can still print "000" before the fallback echo appends
+		// "failed", so only a bare three-character status other than "000" counts.
+		status := strings.TrimSpace(stdout)
+		if err == nil && len(status) == 3 && status != "000" {
+			fmt.Printf("Docker container ready after %d seconds, curl returned: %s\n", i+1, status)
+			return nil
+		}
+	}
+	return fmt.Errorf("docker container service did not become ready within 30 seconds")
+}
+
+// checkPortBlocked verifies that a port is NOT accessible from outside
+// Returns nil if the port is blocked (expected), returns an error if the port is accessible
+// or if ctx ends before the check completes (a cancelled dial proves nothing).
+func checkPortBlocked(ctx context.Context, host string, port int) error {
+	addr := fmt.Sprintf("%s:%d", host, port)
+	dialer := net.Dialer{Timeout: PortConnectionTimeout}
+
+	// Try multiple times to be sure - if ANY connection succeeds, the port is accessible
+	for attempt := 1; attempt <= 3; attempt++ {
+		attemptCtx, cancel := context.WithTimeout(ctx, PortConnectionTimeout)
+		conn, err := dialer.DialContext(attemptCtx, "tcp", addr)
+		cancel()
+
+		if err == nil {
+			// Connection succeeded - port is accessible, which is a problem
+			_ = conn.Close()
+			return fmt.Errorf("port %d is accessible from outside but should be blocked by firewall (attempt %d succeeded)", port, attempt)
+		}
+		if ctxErr := ctx.Err(); ctxErr != nil {
+			return fmt.Errorf("could not verify that port %d is blocked: %w", port, ctxErr)
+		}
+		fmt.Printf("checkPortBlocked attempt %d: connection to %s failed (expected): %v\n", attempt, addr, err)
+	}
+
+	// All attempts failed to connect - port is blocked as expected
+	fmt.Printf("checkPortBlocked: confirmed port %d is blocked after 3 attempts\n", port)
+	return nil
+}
+
+// checkPortAccessible verifies that a port IS accessible from outside
+// Returns nil if the port is accessible, returns an error if the port is blocked
+func checkPortAccessible(ctx context.Context, host string, port int) error {
+	ctx, cancel := context.WithTimeout(ctx, PortConnectionTimeout)
+	defer cancel()
+
+	addr := fmt.Sprintf("%s:%d", host, port)
+
+	dialer := net.Dialer{Timeout: PortConnectionTimeout}
+	conn, err := dialer.DialContext(ctx, "tcp", addr)
+	if err != nil {
+		return fmt.Errorf("port %d is not accessible: %w", port, err)
+	}
+	_ = conn.Close()
+	return nil
+}
+
+// ValidateFirewallAllowsPort validates that a port IS accessible from outside the instance
+// when it's in the allowed ingress rules.
+// For the instance's own SSH port no test listener is started, since that port is already served.
+func ValidateFirewallAllowsPort(ctx context.Context, client CloudInstanceReader, instance *Instance, privateKey string, port int) error {
+	instance, sshClient, err := connectForFirewallTest(ctx, client, instance, privateKey)
+	if err != nil {
+		return err
+	}
+	defer func() { _ = sshClient.Close() }()
+
+	// The SSH port is already bound by the SSH server we are connected through;
+	// a second listener could not bind it and the retry loop below would spin.
+	if port != instance.SSHPort {
+		// Start a simple HTTP server on 0.0.0.0:port
+		startServerCmd := fmt.Sprintf(
+			"nohup sh -c 'while true; do echo -e \"HTTP/1.1 200 OK\\r\\nContent-Length: 2\\r\\n\\r\\nOK\" | nc -l -p %d; done' > /dev/null 2>&1 &",
+			port,
+		)
+		_, _, err = sshClient.RunCommand(ctx, startServerCmd)
+		if err != nil {
+			return fmt.Errorf("failed to start test server on instance: %w", err)
+		}
+		// Best-effort cleanup on every return path.
+		defer func() { _, _, _ = sshClient.RunCommand(ctx, killTestListenerCmd(port)) }()
+
+		if err := sleepOrDone(ctx, 500*time.Millisecond); err != nil {
+			return err
+		}
+	}
+
+	// Try to connect - this should succeed for allowed ports
+	if err := checkPortAccessible(ctx, instance.PublicIP, port); err != nil {
+		return fmt.Errorf("allowed port %d is not accessible: %w", port, err)
+	}
+	return nil
+}
+
+// ValidateFirewallRules validates that firewall rules are working correctly by:
+// 1. Checking that a non-allowed port is blocked
+// 2. Checking that an allowed port (SSH) is accessible
+func ValidateFirewallRules(ctx context.Context, client CloudInstanceReader, instance *Instance, privateKey string) error {
+	var validationErr error
+
+	// Validate that SSH port is accessible (should be allowed)
+	err := checkPortAccessible(ctx, instance.PublicIP, instance.SSHPort)
+	if err != nil {
+		validationErr = errors.Join(validationErr, fmt.Errorf("SSH port should be accessible: %w", err))
+	}
+
+	// Validate that a non-standard port is blocked
+	err = ValidateFirewallBlocksPort(ctx, client, instance, privateKey, DefaultFirewallTestPort)
+	if err != nil {
+		validationErr = errors.Join(validationErr, fmt.Errorf("firewall should block port %d: %w", DefaultFirewallTestPort, err))
+	}
+
+	return validationErr
+}
+
+// ValidateHTTPPortBlocked validates that an HTTP port is not accessible via HTTP request
+// It returns an error if a response is received, or if ctx ends before the request completes.
+func ValidateHTTPPortBlocked(ctx context.Context, host string, port int) error {
+	reqCtx, cancel := context.WithTimeout(ctx, PortConnectionTimeout)
+	defer cancel()
+
+	url := fmt.Sprintf("http://%s:%d", host, port)
+	req, err := http.NewRequestWithContext(reqCtx, http.MethodGet, url, nil)
+	if err != nil {
+		return fmt.Errorf("failed to create request: %w", err)
+	}
+
+	client := &http.Client{Timeout: PortConnectionTimeout}
+	resp, err := client.Do(req)
+	if err != nil {
+		// A request aborted by the caller's context says nothing about the firewall
+		if ctxErr := ctx.Err(); ctxErr != nil {
+			return fmt.Errorf("could not verify that HTTP port %d is blocked: %w", port, ctxErr)
+		}
+		// Connection failed - expected behavior for blocked port
+		return nil
+	}
+	defer func() { _ = resp.Body.Close() }()
+
+	// If we got a response, the port is accessible - this is a problem
+	return fmt.Errorf("HTTP port %d is accessible (status: %d) but should be blocked", port, resp.StatusCode)
+}