Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
28 changes: 26 additions & 2 deletions src/dstack/_internal/server/background/tasks/process_instances.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,7 @@
from dstack._internal.core.models.backends.base import BackendType
from dstack._internal.core.models.fleets import InstanceGroupPlacement
from dstack._internal.core.models.instances import (
HealthStatus,
InstanceAvailability,
InstanceOfferWithAvailability,
InstanceRuntime,
Expand Down Expand Up @@ -75,6 +76,7 @@
InstanceHealthResponse,
)
from dstack._internal.server.services import backends as backends_services
from dstack._internal.server.services import events
from dstack._internal.server.services.fleets import (
fleet_model_to_fleet,
get_create_instance_offers,
Expand Down Expand Up @@ -759,8 +761,8 @@ async def _check_instance(session: AsyncSession, instance: InstanceModel) -> Non
)
session.add(health_check_model)

instance.health = health_status
instance.unreachable = not instance_check.reachable
_set_health(session, instance, health_status)
_set_unreachable(session, instance, unreachable=not instance_check.reachable)

if instance_check.reachable:
instance.termination_deadline = None
Expand Down Expand Up @@ -1093,6 +1095,28 @@ async def _terminate(session: AsyncSession, instance: InstanceModel) -> None:
switch_instance_status(session, instance, InstanceStatus.TERMINATED)


def _set_health(session: AsyncSession, instance: InstanceModel, health: HealthStatus) -> None:
if instance.health != health:
events.emit(
session,
f"Instance health changed {instance.health.upper()} -> {health.upper()}",
actor=events.SystemActor(),
targets=[events.Target.from_model(instance)],
)
instance.health = health


def _set_unreachable(session: AsyncSession, instance: InstanceModel, unreachable: bool) -> None:
if instance.unreachable != unreachable:
events.emit(
session,
"Instance became unreachable" if unreachable else "Instance became reachable",
actor=events.SystemActor(),
targets=[events.Target.from_model(instance)],
)
instance.unreachable = unreachable


def _next_termination_retry_at(instance: InstanceModel) -> datetime.datetime:
assert instance.last_termination_retry_at is not None
return instance.last_termination_retry_at + TERMINATION_RETRY_TIMEOUT
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -51,9 +51,9 @@
UserModel,
)
from dstack._internal.server.schemas.runner import GPUDevice, TaskStatus
from dstack._internal.server.services import events, services
from dstack._internal.server.services import files as files_services
from dstack._internal.server.services import logs as logs_services
from dstack._internal.server.services import services
from dstack._internal.server.services.instances import get_instance_ssh_private_keys
from dstack._internal.server.services.jobs import (
find_job,
Expand Down Expand Up @@ -355,7 +355,7 @@ async def _process_running_job(session: AsyncSession, job_model: JobModel):
)

if success:
job_model.disconnected_at = None
_reset_disconnected_at(session, job_model)
else:
if job_model.termination_reason:
logger.warning(
Expand All @@ -368,8 +368,7 @@ async def _process_running_job(session: AsyncSession, job_model: JobModel):
# job will be terminated and instance will be emptied by process_terminating_jobs
else:
# No job_model.termination_reason set means ssh connection failed
if job_model.disconnected_at is None:
job_model.disconnected_at = common_utils.get_current_datetime()
_set_disconnected_at_now(session, job_model)
if _should_terminate_job_due_to_disconnect(job_model):
# TODO: Replace with JobTerminationReason.INSTANCE_UNREACHABLE for on-demand.
job_model.termination_reason = JobTerminationReason.INTERRUPTED_BY_NO_CAPACITY
Expand Down Expand Up @@ -933,6 +932,28 @@ def _should_terminate_due_to_low_gpu_util(min_util: int, gpus_util: Iterable[Ite
return False


def _set_disconnected_at_now(session: AsyncSession, job_model: JobModel) -> None:
if job_model.disconnected_at is None:
job_model.disconnected_at = common_utils.get_current_datetime()
events.emit(
session,
"Job became unreachable",
actor=events.SystemActor(),
targets=[events.Target.from_model(job_model)],
)


def _reset_disconnected_at(session: AsyncSession, job_model: JobModel) -> None:
if job_model.disconnected_at is not None:
job_model.disconnected_at = None
events.emit(
session,
"Job became reachable",
actor=events.SystemActor(),
targets=[events.Target.from_model(job_model)],
)


def _get_cluster_info(
jobs: List[Job],
replica_num: int,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -945,7 +945,7 @@ def _create_instance_model_for_job(
created_at=common_utils.get_current_datetime(),
started_at=common_utils.get_current_datetime(),
status=InstanceStatus.PROVISIONING,
unreachable=False,
unreachable=True,
job_provisioning_data=job_provisioning_data.json(),
offer=offer.json(),
termination_policy=termination_policy,
Expand Down
4 changes: 2 additions & 2 deletions src/dstack/_internal/server/services/instances.py
Original file line number Diff line number Diff line change
Expand Up @@ -605,7 +605,7 @@ def create_instance_model(
project=project,
created_at=common_utils.get_current_datetime(),
status=InstanceStatus.PENDING,
unreachable=False,
unreachable=True,
profile=profile.json(),
requirements=requirements.json(),
instance_configuration=instance_config.json(),
Expand Down Expand Up @@ -680,7 +680,7 @@ async def create_ssh_instance_model(
created_at=common_utils.get_current_datetime(),
started_at=common_utils.get_current_datetime(),
status=InstanceStatus.PENDING,
unreachable=False,
unreachable=True,
job_provisioning_data=remote.json(),
remote_connection_info=remote_connection_info.json(),
offer=offer.json(),
Expand Down
7 changes: 7 additions & 0 deletions src/dstack/_internal/server/testing/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
from uuid import UUID

import gpuhunt
from sqlalchemy import select
from sqlalchemy.ext.asyncio import AsyncSession

from dstack._internal.core.backends.base.compute import (
Expand Down Expand Up @@ -90,6 +91,7 @@
BackendModel,
ComputeGroupModel,
DecryptedString,
EventModel,
FileArchiveModel,
FleetModel,
GatewayComputeModel,
Expand Down Expand Up @@ -1111,6 +1113,11 @@ async def create_secret(
return secret_model


async def list_events(session: AsyncSession) -> list[EventModel]:
res = await session.execute(select(EventModel).order_by(EventModel.recorded_at, EventModel.id))
return list(res.scalars().all())


def get_private_key_string() -> str:
return """
-----BEGIN RSA PRIVATE KEY-----
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,14 +3,12 @@

import pytest
from freezegun import freeze_time
from sqlalchemy import select
from sqlalchemy.ext.asyncio import AsyncSession

from dstack._internal.server import settings
from dstack._internal.server.background.tasks.process_events import delete_events
from dstack._internal.server.models import EventModel
from dstack._internal.server.services import events
from dstack._internal.server.testing.common import create_user
from dstack._internal.server.testing.common import create_user, list_events


@pytest.mark.asyncio
Expand All @@ -27,8 +25,7 @@ async def test_deletes_old_events(test_db, session: AsyncSession) -> None:
)
await session.commit()

res = await session.execute(select(EventModel))
all_events = res.scalars().all()
all_events = await list_events(session)
assert len(all_events) == 10

with (
Expand All @@ -37,8 +34,7 @@ async def test_deletes_old_events(test_db, session: AsyncSession) -> None:
):
await delete_events()

res = await session.execute(select(EventModel).order_by(EventModel.recorded_at))
remaining_events = res.scalars().all()
remaining_events = await list_events(session)
assert len(remaining_events) == 5
assert [e.message for e in remaining_events] == [
"Event 5",
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -77,6 +77,7 @@
get_job_provisioning_data,
get_placement_group_provisioning_data,
get_remote_connection_info,
list_events,
)
from dstack._internal.utils.common import get_current_datetime

Expand Down Expand Up @@ -324,10 +325,13 @@ async def test_check_shim_process_ureachable_state(
healthcheck.assert_called()

await session.refresh(instance)
events = await list_events(session)

assert instance is not None
assert instance.status == InstanceStatus.IDLE
assert not instance.unreachable
assert len(events) == 1
assert events[0].message == "Instance became reachable"

@pytest.mark.asyncio
@pytest.mark.parametrize("health_status", [HealthStatus.HEALTHY, HealthStatus.FAILURE])
Expand All @@ -351,12 +355,15 @@ async def test_check_shim_switch_to_unreachable_state(
await process_instances()

await session.refresh(instance)
events = await list_events(session)

assert instance is not None
assert instance.status == InstanceStatus.IDLE
assert instance.unreachable
# Should keep the previous status
assert instance.health == health_status
assert len(events) == 1
assert events[0].message == "Instance became unreachable"

@pytest.mark.asyncio
@pytest.mark.parametrize("test_db", ["sqlite", "postgres"], indirect=True)
Expand Down Expand Up @@ -384,11 +391,14 @@ async def test_check_shim_check_instance_health(self, test_db, session: AsyncSes
await process_instances()

await session.refresh(instance)
events = await list_events(session)

assert instance is not None
assert instance.status == InstanceStatus.IDLE
assert not instance.unreachable
assert instance.health == HealthStatus.WARNING
assert len(events) == 1
assert events[0].message == "Instance health changed HEALTHY -> WARNING"

res = await session.execute(select(InstanceHealthCheckModel))
health_check = res.scalars().one()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,7 @@
get_job_runtime_data,
get_run_spec,
get_volume_configuration,
list_events,
)
from dstack._internal.utils.common import get_current_datetime

Expand Down Expand Up @@ -515,9 +516,12 @@ async def test_pulling_shim_failed(self, test_db, session: AsyncSession):
await process_running_jobs()
assert SSHTunnelMock.call_count == 3
await session.refresh(job)
events = await list_events(session)
assert job is not None
assert job.disconnected_at is not None
assert job.status == JobStatus.PULLING
assert len(events) == 1
assert events[0].message == "Job became unreachable"
with (
patch("dstack._internal.server.services.runner.ssh.SSHTunnel") as SSHTunnelMock,
patch("dstack._internal.server.services.runner.ssh.time.sleep"),
Expand Down
6 changes: 3 additions & 3 deletions src/tests/_internal/server/routers/test_fleets.py
Original file line number Diff line number Diff line change
Expand Up @@ -399,7 +399,7 @@ async def test_creates_fleet(self, test_db, session: AsyncSession, client: Async
"job_name": None,
"hostname": None,
"status": "pending",
"unreachable": False,
"unreachable": True,
"health_status": "healthy",
"termination_reason": None,
"termination_reason_message": None,
Expand Down Expand Up @@ -535,7 +535,7 @@ async def test_creates_ssh_fleet(self, test_db, session: AsyncSession, client: A
"job_name": None,
"hostname": "1.1.1.1",
"status": "pending",
"unreachable": False,
"unreachable": True,
"health_status": "healthy",
"termination_reason": None,
"termination_reason_message": None,
Expand Down Expand Up @@ -741,7 +741,7 @@ async def test_updates_ssh_fleet(self, test_db, session: AsyncSession, client: A
"job_name": None,
"hostname": "10.0.0.101",
"status": "pending",
"unreachable": False,
"unreachable": True,
"health_status": "healthy",
"termination_reason": None,
"termination_reason_message": None,
Expand Down
10 changes: 4 additions & 6 deletions src/tests/_internal/server/services/test_instances.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
import uuid

import pytest
from sqlalchemy import select
from sqlalchemy.ext.asyncio import AsyncSession

import dstack._internal.server.services.instances as instances_services
Expand All @@ -15,13 +14,14 @@
Resources,
)
from dstack._internal.core.models.profiles import Profile
from dstack._internal.server.models import EventModel, InstanceModel
from dstack._internal.server.models import InstanceModel
from dstack._internal.server.testing.common import (
create_instance,
create_project,
create_user,
get_volume,
get_volume_configuration,
list_events,
)
from dstack._internal.utils.common import get_current_datetime

Expand All @@ -41,8 +41,7 @@ async def test_includes_termination_reason_in_event_messages_only_once(
instances_services.switch_instance_status(session, instance, InstanceStatus.TERMINATING)
instances_services.switch_instance_status(session, instance, InstanceStatus.TERMINATED)

res = await session.execute(select(EventModel))
events = res.scalars().all()
events = await list_events(session)
assert len(events) == 2
assert {e.message for e in events} == {
"Instance status changed PENDING -> TERMINATING. Termination reason: ERROR (Some err)",
Expand All @@ -63,8 +62,7 @@ async def test_includes_termination_reason_in_event_message_when_switching_direc
instance.termination_reason_message = "Some err"
instances_services.switch_instance_status(session, instance, InstanceStatus.TERMINATED)

res = await session.execute(select(EventModel))
events = res.scalars().all()
events = await list_events(session)
assert len(events) == 1
assert events[0].message == (
"Instance status changed PENDING -> TERMINATED. Termination reason: ERROR (Some err)"
Expand Down