Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions infrahub_sdk/ctl/importer.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,9 @@ def load(
envvar="INFRAHUB_MAX_CONCURRENT_EXECUTION",
),
timeout: int = typer.Option(60, help="Timeout in sec", envvar="INFRAHUB_TIMEOUT"),
allow_upsert: bool = typer.Option(
False, help="Use Upsert mutations instead of Create. Use when objects may already exist."
),
) -> None:
"""Import nodes and their relationships into the database."""
console = Console()
Expand All @@ -45,6 +48,7 @@ def load(
InfrahubSchemaTopologicalSorter(),
continue_on_error=continue_on_error,
console=Console() if not quiet else None,
allow_upsert=allow_upsert,
)
try:
aiorun(importer.import_data(import_directory=directory, branch=branch))
Expand Down
7 changes: 7 additions & 0 deletions infrahub_sdk/graphql/renderers.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,6 +85,13 @@ def convert_to_graphql_as_string(value: Any, convert_enum: bool = False) -> str:
+ " }"
)

# Defensive check: if value looks like a RelatedNode (has _generate_input_data method),
# extract its id to avoid serializing the object repr
if hasattr(value, "_generate_input_data") and hasattr(value, "id"):
node_id = getattr(value, "id", None)
if node_id is not None:
return convert_to_graphql_as_string(value=node_id, convert_enum=convert_enum)

return str(value)


Expand Down
4 changes: 4 additions & 0 deletions infrahub_sdk/node/related_node.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,10 @@ def __init__(self, branch: str, schema: RelationshipSchemaAPI, data: Any | dict,
setattr(self, prop, None)
self._relationship_metadata = None

elif isinstance(data, RelatedNodeBase):
# Handle when value is already a RelatedNode - extract its identifying data
data = {"id": data.id, "hfid": data.hfid, "__typename": data.typename}

elif isinstance(data, list):
data = {"hfid": data}
elif not isinstance(data, dict):
Expand Down
23 changes: 9 additions & 14 deletions infrahub_sdk/transfer/exporter/json.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,31 +34,26 @@ def wrapped_task_output(self, start: str, end: str = "[green]done") -> Generator
if self.console:
self.console.print(f"{end}")

def identify_many_to_many_relationships(
self, node_schema_map: dict[str, MainSchemaTypesAPI]
) -> dict[tuple[str, str], str]:
# Identify many to many relationships by src/dst couples
def identify_many_relationships(self, node_schema_map: dict[str, MainSchemaTypesAPI]) -> dict[tuple[str, str], str]:
# Identify many relationships (both one-way and bidirectional many-to-many)
many_relationship_identifiers: dict[tuple[str, str], str] = {}

for node_schema in node_schema_map.values():
for relationship in node_schema.relationships:
if (
relationship.cardinality != "many"
or not relationship.optional
or not relationship.identifier
or relationship.peer not in node_schema_map
):
continue
for peer_relationship in node_schema_map[relationship.peer].relationships:
if peer_relationship.cardinality != "many" or peer_relationship.peer != node_schema.kind:
continue

forward = many_relationship_identifiers.get((node_schema.kind, relationship.peer))
backward = many_relationship_identifiers.get((relationship.peer, node_schema.kind))
forward = many_relationship_identifiers.get((node_schema.kind, relationship.peer))
backward = many_relationship_identifiers.get((relationship.peer, node_schema.kind))

# Record the relationship only if it's not known in one way or another
if not forward and not backward:
many_relationship_identifiers[node_schema.kind, relationship.peer] = relationship.identifier
# Record the relationship only if it's not known in one way or another
# This avoids duplicating bidirectional many-to-many relationships
if not forward and not backward:
many_relationship_identifiers[node_schema.kind, relationship.peer] = relationship.identifier

return many_relationship_identifiers

Expand All @@ -69,7 +64,7 @@ async def retrieve_many_to_many_relationships(
page_number = 1
page_size = 50

many_relationship_identifiers = list(self.identify_many_to_many_relationships(node_schema_map).values())
many_relationship_identifiers = list(self.identify_many_relationships(node_schema_map).values())
many_relationships: list[dict[str, Any]] = []

if not many_relationship_identifiers:
Expand Down
4 changes: 3 additions & 1 deletion infrahub_sdk/transfer/importer/json.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,11 +31,13 @@ def __init__(
topological_sorter: InfrahubSchemaTopologicalSorter,
continue_on_error: bool = False,
console: Console | None = None,
allow_upsert: bool = False,
) -> None:
self.client = client
self.topological_sorter = topological_sorter
self.continue_on_error = continue_on_error
self.console = console
self.allow_upsert = allow_upsert
self.all_nodes: dict[str, InfrahubNode] = {}
self.schemas_by_kind: Mapping[str, NodeSchema] = {}
# Map relationship schema by attribute of a node kind e.g. {"MyNodeKind": {"MyRelationship": RelationshipSchema}}
Expand Down Expand Up @@ -88,7 +90,7 @@ async def import_data(self, import_directory: Path, branch: str) -> None:
if not schema_import_nodes:
continue
for node in schema_import_nodes:
save_batch.add(task=node.create, node=node, allow_upsert=True)
save_batch.add(task=node.create, node=node, allow_upsert=self.allow_upsert)

await self.execute_batches([save_batch], "Creating and/or updating nodes")

Expand Down
8 changes: 6 additions & 2 deletions tests/integration/test_export_import.py
Original file line number Diff line number Diff line change
Expand Up @@ -172,7 +172,9 @@ async def test_step05_import_initial_dataset_with_existing_data(
nodes = await client.all(kind=kind)
counters[kind] = len(nodes)

importer = LineDelimitedJSONImporter(client=client, topological_sorter=InfrahubSchemaTopologicalSorter())
importer = LineDelimitedJSONImporter(
client=client, topological_sorter=InfrahubSchemaTopologicalSorter(), allow_upsert=True
)
await importer.import_data(import_directory=temporary_directory, branch="main")

for kind in (TESTING_PERSON, TESTING_CAR, TESTING_MANUFACTURER):
Expand Down Expand Up @@ -311,7 +313,9 @@ async def test_step03_import_initial_dataset_with_existing_data(
await node.tags.fetch()
relationship_count_before += len(node.tags.peers)

importer = LineDelimitedJSONImporter(client=client, topological_sorter=InfrahubSchemaTopologicalSorter())
importer = LineDelimitedJSONImporter(
client=client, topological_sorter=InfrahubSchemaTopologicalSorter(), allow_upsert=True
)
await importer.import_data(import_directory=temporary_directory, branch="main")

for kind in (TESTING_CAR, TESTING_MANUFACTURER):
Expand Down
30 changes: 29 additions & 1 deletion tests/unit/sdk/graphql/test_renderer.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from typing import Any

from infrahub_sdk.graphql.renderers import render_input_block, render_query_block
from infrahub_sdk.graphql.renderers import convert_to_graphql_as_string, render_input_block, render_query_block


def test_render_query_block(query_data_no_filter: dict[str, Any]) -> None:
Expand Down Expand Up @@ -145,3 +145,31 @@ def test_render_input_block(input_data_01: dict[str, Any]) -> None:
" }",
]
assert lines == expected_lines


class RelatedNodeLikeObject:
"""Mock object that looks like a RelatedNode (has _generate_input_data and id)."""

def __init__(self, node_id: str) -> None:
self._id = node_id

@property
def id(self) -> str:
return self._id

def _generate_input_data(self) -> dict[str, Any]:
return {"id": self._id}


def test_convert_to_graphql_as_string_handles_related_node_like_object() -> None:
"""Test that convert_to_graphql_as_string handles objects with _generate_input_data and id."""
# This tests the defensive check added to handle cases where a RelatedNode-like
# object somehow gets passed to convert_to_graphql_as_string without being
# converted to a dict first
related_node_like = RelatedNodeLikeObject("test-uuid-789")

result = convert_to_graphql_as_string(related_node_like)

# Should extract the id and convert it properly, not produce the object repr
assert result == '"test-uuid-789"'
assert "RelatedNodeLikeObject" not in result
84 changes: 84 additions & 0 deletions tests/unit/sdk/test_node.py
Original file line number Diff line number Diff line change
Expand Up @@ -3271,3 +3271,87 @@ def test_relationship_manager_generate_query_data_without_include_metadata() ->
assert "count" in data
assert "edges" in data
assert "node" in data["edges"]


class TestRelatedNodeAsData:
"""Test that RelatedNodeBase correctly handles another RelatedNode as input data."""

def test_related_node_extracts_id_from_another_related_node(self, location_schema: NodeSchemaAPI) -> None:
"""When passing a RelatedNode as data, the id should be extracted correctly."""
# First, create a RelatedNode with a string ID
original_related_node = RelatedNodeBase(
branch="main",
schema=location_schema.relationships[0],
data={"id": "original-uuid-123", "__typename": "BuiltinTag"},
)
assert original_related_node.id == "original-uuid-123"

# Now create another RelatedNode passing the first one as data
# This simulates what happens when doing: node.parent = another_related_node
new_related_node = RelatedNodeBase(
branch="main",
schema=location_schema.relationships[0],
data=original_related_node,
)

# The new RelatedNode should have extracted the ID from the original
assert new_related_node.id == "original-uuid-123"
assert isinstance(new_related_node.id, str)

def test_related_node_generate_input_data_returns_string_id(self, location_schema: NodeSchemaAPI) -> None:
"""_generate_input_data should return string id, not a RelatedNode object."""
# Create a RelatedNode with a string ID
original_related_node = RelatedNodeBase(
branch="main",
schema=location_schema.relationships[0],
data={"id": "original-uuid-456", "__typename": "BuiltinTag"},
)

# Create another RelatedNode passing the first one as data
new_related_node = RelatedNodeBase(
branch="main",
schema=location_schema.relationships[0],
data=original_related_node,
)

# _generate_input_data should return a dict with string id
input_data = new_related_node._generate_input_data()
assert input_data == {"id": "original-uuid-456"}
assert isinstance(input_data["id"], str)

def test_related_node_extracts_hfid_from_another_related_node(self, location_schema: NodeSchemaAPI) -> None:
"""When passing a RelatedNode with hfid as data, the hfid should be extracted correctly."""
# Create a RelatedNode with an hfid
original_related_node = RelatedNodeBase(
branch="main",
schema=location_schema.relationships[0],
data={"hfid": ["Namespace", "Name"], "__typename": "BuiltinTag"},
)
assert original_related_node.hfid == ["Namespace", "Name"]

# Create another RelatedNode passing the first one as data
new_related_node = RelatedNodeBase(
branch="main",
schema=location_schema.relationships[0],
data=original_related_node,
)

# The new RelatedNode should have extracted the hfid from the original
assert new_related_node.hfid == ["Namespace", "Name"]

def test_related_node_extracts_typename_from_another_related_node(self, location_schema: NodeSchemaAPI) -> None:
"""When passing a RelatedNode as data, the typename should be extracted correctly."""
original_related_node = RelatedNodeBase(
branch="main",
schema=location_schema.relationships[0],
data={"id": "test-id", "__typename": "BuiltinTag"},
)
assert original_related_node.typename == "BuiltinTag"

new_related_node = RelatedNodeBase(
branch="main",
schema=location_schema.relationships[0],
data=original_related_node,
)

assert new_related_node.typename == "BuiltinTag"
Loading