diff --git a/infrahub_sdk/store.py b/infrahub_sdk/store.py index 6420495b..e8406b10 100644 --- a/infrahub_sdk/store.py +++ b/infrahub_sdk/store.py @@ -1,8 +1,11 @@ from __future__ import annotations +import inspect import warnings from typing import TYPE_CHECKING, Literal, overload +from infrahub_sdk.protocols_base import CoreNodeBase + from .exceptions import NodeInvalidError, NodeNotFoundError from .node.parsers import parse_human_friendly_id @@ -16,8 +19,15 @@ def get_schema_name(schema: type[SchemaType | SchemaTypeSync] | str | None = Non if isinstance(schema, str): return schema - if hasattr(schema, "_is_runtime_protocol") and schema._is_runtime_protocol: # type: ignore[union-attr] - return schema.__name__ # type: ignore[union-attr] + if schema is None: + return None + + if issubclass(schema, CoreNodeBase): + if inspect.iscoroutinefunction(schema.save): + return schema.__name__ + if schema.__name__[-4:] == "Sync": + return schema.__name__[:-4] + return schema.__name__ return None diff --git a/tests/unit/sdk/test_store.py b/tests/unit/sdk/test_store.py index 83644aae..eaeab57c 100644 --- a/tests/unit/sdk/test_store.py +++ b/tests/unit/sdk/test_store.py @@ -6,7 +6,8 @@ from infrahub_sdk.exceptions import NodeInvalidError, NodeNotFoundError from infrahub_sdk.node import InfrahubNode, InfrahubNodeSync -from infrahub_sdk.store import NodeStore, NodeStoreSync +from infrahub_sdk.protocols import BuiltinIPAddressSync, BuiltinIPPrefix +from infrahub_sdk.store import NodeStore, NodeStoreSync, get_schema_name if TYPE_CHECKING: from infrahub_sdk.schema import NodeSchemaAPI @@ -157,3 +158,8 @@ def test_node_store_get_with_hfid( store.get(kind="BuiltinLocation", key="anotherkey") with pytest.raises(NodeNotFoundError): store.get(key="anotherkey") + + +def test_store_get_schema_name() -> None: + assert get_schema_name(schema=BuiltinIPPrefix) == BuiltinIPPrefix.__name__ + assert get_schema_name(schema=BuiltinIPAddressSync) == BuiltinIPAddressSync.__name__[:-4]