From ca2c623002f911e8a9f4dbb8d2b6ec35de43d6e7 Mon Sep 17 00:00:00 2001 From: Alex Gittings Date: Thu, 22 Jan 2026 15:09:08 +0000 Subject: [PATCH 1/2] Add allow_upsert option to LineDelimitedJSONImporter and load function; enhance handling of RelatedNode inputs --- infrahub_sdk/ctl/importer.py | 4 ++ infrahub_sdk/graphql/renderers.py | 7 +++ infrahub_sdk/node/related_node.py | 4 ++ infrahub_sdk/transfer/importer/json.py | 4 +- tests/integration/test_export_import.py | 8 ++- tests/unit/sdk/graphql/test_renderer.py | 30 ++++++++- tests/unit/sdk/test_node.py | 84 +++++++++++++++++++++++++ 7 files changed, 137 insertions(+), 4 deletions(-) diff --git a/infrahub_sdk/ctl/importer.py b/infrahub_sdk/ctl/importer.py index 420c6d75..1838ee03 100644 --- a/infrahub_sdk/ctl/importer.py +++ b/infrahub_sdk/ctl/importer.py @@ -32,6 +32,9 @@ def load( envvar="INFRAHUB_MAX_CONCURRENT_EXECUTION", ), timeout: int = typer.Option(60, help="Timeout in sec", envvar="INFRAHUB_TIMEOUT"), + allow_upsert: bool = typer.Option( + False, help="Use Upsert mutations instead of Create. Use when objects may already exist." + ), ) -> None: """Import nodes and their relationships into the database.""" console = Console() @@ -45,6 +48,7 @@ def load( InfrahubSchemaTopologicalSorter(), continue_on_error=continue_on_error, console=Console() if not quiet else None, + allow_upsert=allow_upsert, ) try: aiorun(importer.import_data(import_directory=directory, branch=branch)) diff --git a/infrahub_sdk/graphql/renderers.py b/infrahub_sdk/graphql/renderers.py index 91b77526..ecb7cc3c 100644 --- a/infrahub_sdk/graphql/renderers.py +++ b/infrahub_sdk/graphql/renderers.py @@ -85,6 +85,13 @@ def convert_to_graphql_as_string(value: Any, convert_enum: bool = False) -> str: + " }" ) + # Defensive check: if value looks like a RelatedNode (has _generate_input_data method), + # extract its id to avoid serializing the object repr + if hasattr(value, "_generate_input_data") and hasattr(value, "id"): + node_id = getattr(value, "id", None) + if node_id is not None: + return convert_to_graphql_as_string(value=node_id, convert_enum=convert_enum) + return str(value) diff --git a/infrahub_sdk/node/related_node.py b/infrahub_sdk/node/related_node.py index 5b46a8f7..977a34e4 100644 --- a/infrahub_sdk/node/related_node.py +++ b/infrahub_sdk/node/related_node.py @@ -49,6 +49,10 @@ def __init__(self, branch: str, schema: RelationshipSchemaAPI, data: Any | dict, setattr(self, prop, None) self._relationship_metadata = None + elif isinstance(data, RelatedNodeBase): + # Handle when value is already a RelatedNode - extract its identifying data + data = {"id": data.id, "hfid": data.hfid, "__typename": data.typename} + elif isinstance(data, list): data = {"hfid": data} elif not isinstance(data, dict): diff --git a/infrahub_sdk/transfer/importer/json.py b/infrahub_sdk/transfer/importer/json.py index 9c0b7ab9..15f93380 100644 --- a/infrahub_sdk/transfer/importer/json.py +++ b/infrahub_sdk/transfer/importer/json.py @@ -31,11 +31,13 @@ def __init__( topological_sorter: InfrahubSchemaTopologicalSorter, continue_on_error: bool = False, console: Console | None = None, + allow_upsert: bool = False, ) -> None: self.client = client self.topological_sorter = topological_sorter self.continue_on_error = continue_on_error self.console = console + self.allow_upsert = allow_upsert self.all_nodes: dict[str, InfrahubNode] = {} self.schemas_by_kind: Mapping[str, NodeSchema] = {} # Map relationship schema by attribute of a node kind e.g. {"MyNodeKind": {"MyRelationship": RelationshipSchema}} @@ -88,7 +90,7 @@ async def import_data(self, import_directory: Path, branch: str) -> None: if not schema_import_nodes: continue for node in schema_import_nodes: - save_batch.add(task=node.create, node=node, allow_upsert=True) + save_batch.add(task=node.create, node=node, allow_upsert=self.allow_upsert) await self.execute_batches([save_batch], "Creating and/or updating nodes") diff --git a/tests/integration/test_export_import.py b/tests/integration/test_export_import.py index a138728f..794c003a 100644 --- a/tests/integration/test_export_import.py +++ b/tests/integration/test_export_import.py @@ -172,7 +172,9 @@ async def test_step05_import_initial_dataset_with_existing_data( nodes = await client.all(kind=kind) counters[kind] = len(nodes) - importer = LineDelimitedJSONImporter(client=client, topological_sorter=InfrahubSchemaTopologicalSorter()) + importer = LineDelimitedJSONImporter( + client=client, topological_sorter=InfrahubSchemaTopologicalSorter(), allow_upsert=True + ) await importer.import_data(import_directory=temporary_directory, branch="main") for kind in (TESTING_PERSON, TESTING_CAR, TESTING_MANUFACTURER): @@ -311,7 +313,9 @@ async def test_step03_import_initial_dataset_with_existing_data( await node.tags.fetch() relationship_count_before += len(node.tags.peers) - importer = LineDelimitedJSONImporter(client=client, topological_sorter=InfrahubSchemaTopologicalSorter()) + importer = LineDelimitedJSONImporter( + client=client, topological_sorter=InfrahubSchemaTopologicalSorter(), allow_upsert=True + ) await importer.import_data(import_directory=temporary_directory, branch="main") for kind in (TESTING_CAR, TESTING_MANUFACTURER): diff --git a/tests/unit/sdk/graphql/test_renderer.py b/tests/unit/sdk/graphql/test_renderer.py index 642d688f..98d9c2fc 100644 --- a/tests/unit/sdk/graphql/test_renderer.py +++ b/tests/unit/sdk/graphql/test_renderer.py @@ -1,6 +1,6 @@ from typing import Any -from infrahub_sdk.graphql.renderers import render_input_block, render_query_block +from infrahub_sdk.graphql.renderers import convert_to_graphql_as_string, render_input_block, render_query_block def test_render_query_block(query_data_no_filter: dict[str, Any]) -> None: @@ -145,3 +145,31 @@ def test_render_input_block(input_data_01: dict[str, Any]) -> None: " }", ] assert lines == expected_lines + + +class RelatedNodeLikeObject: + """Mock object that looks like a RelatedNode (has _generate_input_data and id).""" + + def __init__(self, node_id: str) -> None: + self._id = node_id + + @property + def id(self) -> str: + return self._id + + def _generate_input_data(self) -> dict[str, Any]: + return {"id": self._id} + + +def test_convert_to_graphql_as_string_handles_related_node_like_object() -> None: + """Test that convert_to_graphql_as_string handles objects with _generate_input_data and id.""" + # This tests the defensive check added to handle cases where a RelatedNode-like + # object somehow gets passed to convert_to_graphql_as_string without being + # converted to a dict first + related_node_like = RelatedNodeLikeObject("test-uuid-789") + + result = convert_to_graphql_as_string(related_node_like) + + # Should extract the id and convert it properly, not produce the object repr + assert result == '"test-uuid-789"' + assert "RelatedNodeLikeObject" not in result diff --git a/tests/unit/sdk/test_node.py b/tests/unit/sdk/test_node.py index 74434a92..9abea03f 100644 --- a/tests/unit/sdk/test_node.py +++ b/tests/unit/sdk/test_node.py @@ -3271,3 +3271,87 @@ def test_relationship_manager_generate_query_data_without_include_metadata() -> assert "count" in data assert "edges" in data assert "node" in data["edges"] + + +class TestRelatedNodeAsData: + """Test that RelatedNodeBase correctly handles another RelatedNode as input data.""" + + def test_related_node_extracts_id_from_another_related_node(self, location_schema: NodeSchemaAPI) -> None: + """When passing a RelatedNode as data, the id should be extracted correctly.""" + # First, create a RelatedNode with a string ID + original_related_node = RelatedNodeBase( + branch="main", + schema=location_schema.relationships[0], + data={"id": "original-uuid-123", "__typename": "BuiltinTag"}, + ) + assert original_related_node.id == "original-uuid-123" + + # Now create another RelatedNode passing the first one as data + # This simulates what happens when doing: node.parent = another_related_node + new_related_node = RelatedNodeBase( + branch="main", + schema=location_schema.relationships[0], + data=original_related_node, + ) + + # The new RelatedNode should have extracted the ID from the original + assert new_related_node.id == "original-uuid-123" + assert isinstance(new_related_node.id, str) + + def test_related_node_generate_input_data_returns_string_id(self, location_schema: NodeSchemaAPI) -> None: + """_generate_input_data should return string id, not a RelatedNode object.""" + # Create a RelatedNode with a string ID + original_related_node = RelatedNodeBase( + branch="main", + schema=location_schema.relationships[0], + data={"id": "original-uuid-456", "__typename": "BuiltinTag"}, + ) + + # Create another RelatedNode passing the first one as data + new_related_node = RelatedNodeBase( + branch="main", + schema=location_schema.relationships[0], + data=original_related_node, + ) + + # _generate_input_data should return a dict with string id + input_data = new_related_node._generate_input_data() + assert input_data == {"id": "original-uuid-456"} + assert isinstance(input_data["id"], str) + + def test_related_node_extracts_hfid_from_another_related_node(self, location_schema: NodeSchemaAPI) -> None: + """When passing a RelatedNode with hfid as data, the hfid should be extracted correctly.""" + # Create a RelatedNode with an hfid + original_related_node = RelatedNodeBase( + branch="main", + schema=location_schema.relationships[0], + data={"hfid": ["Namespace", "Name"], "__typename": "BuiltinTag"}, + ) + assert original_related_node.hfid == ["Namespace", "Name"] + + # Create another RelatedNode passing the first one as data + new_related_node = RelatedNodeBase( + branch="main", + schema=location_schema.relationships[0], + data=original_related_node, + ) + + # The new RelatedNode should have extracted the hfid from the original + assert new_related_node.hfid == ["Namespace", "Name"] + + def test_related_node_extracts_typename_from_another_related_node(self, location_schema: NodeSchemaAPI) -> None: + """When passing a RelatedNode as data, the typename should be extracted correctly.""" + original_related_node = RelatedNodeBase( + branch="main", + schema=location_schema.relationships[0], + data={"id": "test-id", "__typename": "BuiltinTag"}, + ) + assert original_related_node.typename == "BuiltinTag" + + new_related_node = RelatedNodeBase( + branch="main", + schema=location_schema.relationships[0], + data=original_related_node, + ) + + assert new_related_node.typename == "BuiltinTag" From ce6441e99b66702802e8057c1684737c74f668d9 Mon Sep 17 00:00:00 2001 From: Alex Gittings Date: Fri, 23 Jan 2026 10:06:40 +0000 Subject: [PATCH 2/2] Support optional fields --- infrahub_sdk/transfer/exporter/json.py | 23 +++++++++-------------- 1 file changed, 9 insertions(+), 14 deletions(-) diff --git a/infrahub_sdk/transfer/exporter/json.py b/infrahub_sdk/transfer/exporter/json.py index 077ee8fa..5b1d26ad 100644 --- a/infrahub_sdk/transfer/exporter/json.py +++ b/infrahub_sdk/transfer/exporter/json.py @@ -34,31 +34,26 @@ def wrapped_task_output(self, start: str, end: str = "[green]done") -> Generator if self.console: self.console.print(f"{end}") - def identify_many_to_many_relationships( - self, node_schema_map: dict[str, MainSchemaTypesAPI] - ) -> dict[tuple[str, str], str]: - # Identify many to many relationships by src/dst couples + def identify_many_relationships(self, node_schema_map: dict[str, MainSchemaTypesAPI]) -> dict[tuple[str, str], str]: + # Identify many relationships (both one-way and bidirectional many-to-many) many_relationship_identifiers: dict[tuple[str, str], str] = {} for node_schema in node_schema_map.values(): for relationship in node_schema.relationships: if ( relationship.cardinality != "many" - or not relationship.optional or not relationship.identifier or relationship.peer not in node_schema_map ): continue - for peer_relationship in node_schema_map[relationship.peer].relationships: - if peer_relationship.cardinality != "many" or peer_relationship.peer != node_schema.kind: - continue - forward = many_relationship_identifiers.get((node_schema.kind, relationship.peer)) - backward = many_relationship_identifiers.get((relationship.peer, node_schema.kind)) + forward = many_relationship_identifiers.get((node_schema.kind, relationship.peer)) + backward = many_relationship_identifiers.get((relationship.peer, node_schema.kind)) - # Record the relationship only if it's not known in one way or another - if not forward and not backward: - many_relationship_identifiers[node_schema.kind, relationship.peer] = relationship.identifier + # Record the relationship only if it's not known in one way or another + # This avoids duplicating bidirectional many-to-many relationships + if not forward and not backward: + many_relationship_identifiers[node_schema.kind, relationship.peer] = relationship.identifier return many_relationship_identifiers @@ -69,7 +64,7 @@ async def retrieve_many_to_many_relationships( page_number = 1 page_size = 50 - many_relationship_identifiers = list(self.identify_many_to_many_relationships(node_schema_map).values()) + many_relationship_identifiers = list(self.identify_many_relationships(node_schema_map).values()) many_relationships: list[dict[str, Any]] = [] if not many_relationship_identifiers: