From f98cdbf2def71d05509387e27d42a9f1a769c448 Mon Sep 17 00:00:00 2001 From: Dimitri Yatsenko Date: Fri, 23 Jan 2026 12:20:06 -0600 Subject: [PATCH 01/17] feat(diagram): add direction, Mermaid output, and schema grouping Bug fixes: - Fix isdigit() missing parentheses in _make_graph - Fix nested list creation in _make_graph - Remove dead code in make_dot - Fix invalid color code for Part tier - Replace eval() with safe _resolve_class() method New features: - Add direction parameter ("TB", "LR", "BT", "RL") for layout control - Add make_mermaid() method for web-friendly diagram output - Add group_by_schema parameter to cluster nodes by database schema - Update save() to support .mmd/.mermaid file extensions Co-Authored-By: Claude Opus 4.5 --- src/datajoint/diagram.py | 230 ++++++++++++++++++++++++++++++++++---- src/datajoint/settings.py | 6 + 2 files changed, 216 insertions(+), 20 deletions(-) diff --git a/src/datajoint/diagram.py b/src/datajoint/diagram.py index b06686025..def151207 100644 --- a/src/datajoint/diagram.py +++ b/src/datajoint/diagram.py @@ -16,6 +16,7 @@ from .dependencies import topo_sort from .errors import DataJointError +from .settings import config from .table import Table, lookup_class_name from .user_tables import Computed, Imported, Lookup, Manual, Part, _AliasNode, _get_tier @@ -90,6 +91,12 @@ class Diagram(nx.DiGraph): ----- ``diagram + 1 - 1`` may differ from ``diagram - 1 + 1``. Only tables loaded in the connection are displayed. + + Layout direction is controlled via ``dj.config.display.diagram_direction`` + (default ``"TB"``). Use ``dj.config.override()`` to change temporarily:: + + with dj.config.override(display_diagram_direction="LR"): + dj.Diagram(schema).draw() """ def __init__(self, source, context=None) -> None: @@ -286,18 +293,42 @@ def _make_graph(self) -> nx.DiGraph: gaps = set(nx.algorithms.boundary.node_boundary(self, self.nodes_to_show)).intersection( nx.algorithms.boundary.node_boundary(nx.DiGraph(self).reverse(), self.nodes_to_show) ) - nodes = self.nodes_to_show.union(a for a in gaps if a.isdigit) + nodes = self.nodes_to_show.union(a for a in gaps if a.isdigit()) # construct subgraph and rename nodes to class names graph = nx.DiGraph(nx.DiGraph(self).subgraph(nodes)) nx.set_node_attributes(graph, name="node_type", values={n: _get_tier(n) for n in graph}) # relabel nodes to class names mapping = {node: lookup_class_name(node, self.context) or node for node in graph.nodes()} - new_names = [mapping.values()] + new_names = list(mapping.values()) if len(new_names) > len(set(new_names)): raise DataJointError("Some classes have identical names. The Diagram cannot be plotted.") nx.relabel_nodes(graph, mapping, copy=False) return graph + def _resolve_class(self, name: str): + """ + Safely resolve a table class from a dotted name without eval(). + + Parameters + ---------- + name : str + Dotted class name like "MyTable" or "Module.MyTable". + + Returns + ------- + type or None + The table class if found, otherwise None. + """ + parts = name.split(".") + obj = self.context.get(parts[0]) + for part in parts[1:]: + if obj is None: + return None + obj = getattr(obj, part, None) + if obj is not None and isinstance(obj, type) and issubclass(obj, Table): + return obj + return None + @staticmethod def _encapsulate_edge_attributes(graph: nx.DiGraph) -> None: """ @@ -330,9 +361,39 @@ def _encapsulate_node_names(graph: nx.DiGraph) -> None: copy=False, ) - def make_dot(self): + def make_dot(self, group_by_schema: bool = False): + """ + Generate a pydot graph object. + + Parameters + ---------- + group_by_schema : bool, optional + If True, group nodes into clusters by their database schema. + Default False. + + Returns + ------- + pydot.Dot + The graph object ready for rendering. + + Notes + ----- + Layout direction is controlled via ``dj.config.display.diagram_direction``. + """ + direction = config.display.diagram_direction graph = self._make_graph() - graph.nodes() + + # Build schema mapping if grouping is requested + schema_map = {} + if group_by_schema: + for full_name in self.nodes_to_show: + # Extract schema from full table name like `schema`.`table` or "schema"."table" + parts = full_name.replace('"', '`').split('`') + if len(parts) >= 2: + schema_name = parts[1] # schema is between first pair of backticks + # Find the class name for this full_name + class_name = lookup_class_name(full_name, self.context) or full_name + schema_map[class_name] = schema_name scale = 1.2 # scaling factor for fonts and boxes label_props = { # http://matplotlib.org/examples/color/named_colors.html @@ -386,7 +447,7 @@ def make_dot(self): ), Part: dict( shape="plaintext", - color="#0000000", + color="#00000000", fontcolor="black", fontsize=round(scale * 8), size=0.1 * scale, @@ -398,6 +459,7 @@ def make_dot(self): self._encapsulate_node_names(graph) self._encapsulate_edge_attributes(graph) dot = nx.drawing.nx_pydot.to_pydot(graph) + dot.set_rankdir(direction) for node in dot.get_nodes(): node.set_shape("circle") name = node.get_name().strip('"') @@ -409,9 +471,8 @@ def make_dot(self): node.set_fixedsize("shape" if props["fixed"] else False) node.set_width(props["size"]) node.set_height(props["size"]) - if name.split(".")[0] in self.context: - cls = eval(name, self.context) - assert issubclass(cls, Table) + cls = self._resolve_class(name) + if cls is not None: description = cls().describe(context=self.context).split("\n") description = ( ("-" * 30 if q.startswith("---") else (q.replace("->", "→") if "->" in q else q.split(":")[0])) @@ -437,34 +498,148 @@ def make_dot(self): edge.set_arrowhead("none") edge.set_penwidth(0.75 if props["multi"] else 2) + # Group nodes into schema clusters if requested + if group_by_schema and schema_map: + import pydot + + # Group nodes by schema + schemas = {} + for node in list(dot.get_nodes()): + name = node.get_name().strip('"') + schema_name = schema_map.get(name) + if schema_name: + if schema_name not in schemas: + schemas[schema_name] = [] + schemas[schema_name].append(node) + + # Create clusters for each schema + for schema_name, nodes in schemas.items(): + cluster = pydot.Cluster( + f"cluster_{schema_name}", + label=schema_name, + style="dashed", + color="gray", + fontcolor="gray", + ) + for node in nodes: + cluster.add_node(node) + dot.add_subgraph(cluster) + return dot - def make_svg(self): + def make_svg(self, group_by_schema: bool = False): from IPython.display import SVG - return SVG(self.make_dot().create_svg()) + return SVG(self.make_dot(group_by_schema=group_by_schema).create_svg()) - def make_png(self): - return io.BytesIO(self.make_dot().create_png()) + def make_png(self, group_by_schema: bool = False): + return io.BytesIO(self.make_dot(group_by_schema=group_by_schema).create_png()) - def make_image(self): + def make_image(self, group_by_schema: bool = False): if plot_active: - return plt.imread(self.make_png()) + return plt.imread(self.make_png(group_by_schema=group_by_schema)) else: raise DataJointError("pyplot was not imported") + def make_mermaid(self) -> str: + """ + Generate Mermaid diagram syntax. + + Produces a flowchart in Mermaid syntax that can be rendered in + Markdown documentation, GitHub, or https://mermaid.live. + + Returns + ------- + str + Mermaid flowchart syntax. + + Notes + ----- + Layout direction is controlled via ``dj.config.display.diagram_direction``. + + Examples + -------- + >>> print(dj.Diagram(schema).make_mermaid()) + flowchart TB + Mouse[Mouse]:::manual + Session[Session]:::manual + Neuron([Neuron]):::computed + Mouse --> Session + Session --> Neuron + """ + graph = self._make_graph() + direction = config.display.diagram_direction + + lines = [f"flowchart {direction}"] + + # Define class styles matching Graphviz colors + lines.append(" classDef manual fill:#90EE90,stroke:#006400") + lines.append(" classDef lookup fill:#D3D3D3,stroke:#696969") + lines.append(" classDef computed fill:#FFB6C1,stroke:#8B0000") + lines.append(" classDef imported fill:#ADD8E6,stroke:#00008B") + lines.append(" classDef part fill:#FFFFFF,stroke:#000000") + lines.append("") + + # Shape mapping: Manual=box, Computed/Imported=stadium, Lookup/Part=box + shape_map = { + Manual: ("[", "]"), # box + Lookup: ("[", "]"), # box + Computed: ("([", "])"), # stadium/pill + Imported: ("([", "])"), # stadium/pill + Part: ("[", "]"), # box + _AliasNode: ("((", "))"), # circle + None: ("((", "))"), # circle + } + + tier_class = { + Manual: "manual", + Lookup: "lookup", + Computed: "computed", + Imported: "imported", + Part: "part", + _AliasNode: "", + None: "", + } + + # Add nodes + for node, data in graph.nodes(data=True): + tier = data.get("node_type") + left, right = shape_map.get(tier, ("[", "]")) + cls = tier_class.get(tier, "") + # Mermaid node IDs can't have dots, replace with underscores + safe_id = node.replace(".", "_").replace(" ", "_") + class_suffix = f":::{cls}" if cls else "" + lines.append(f" {safe_id}{left}{node}{right}{class_suffix}") + + lines.append("") + + # Add edges + for src, dest, data in graph.edges(data=True): + safe_src = src.replace(".", "_").replace(" ", "_") + safe_dest = dest.replace(".", "_").replace(" ", "_") + # Solid arrow for primary FK, dotted for non-primary + style = "-->" if data.get("primary") else "-.->" + lines.append(f" {safe_src} {style} {safe_dest}") + + return "\n".join(lines) + def _repr_svg_(self): return self.make_svg()._repr_svg_() - def draw(self): + def draw(self, group_by_schema: bool = False): if plot_active: - plt.imshow(self.make_image()) + plt.imshow(self.make_image(group_by_schema=group_by_schema)) plt.gca().axis("off") plt.show() else: raise DataJointError("pyplot was not imported") - def save(self, filename: str, format: str | None = None) -> None: + def save( + self, + filename: str, + format: str | None = None, + group_by_schema: bool = False, + ) -> None: """ Save diagram to file. @@ -473,24 +648,39 @@ def save(self, filename: str, format: str | None = None) -> None: filename : str Output filename. format : str, optional - File format (``'png'`` or ``'svg'``). Inferred from extension if None. + File format (``'png'``, ``'svg'``, or ``'mermaid'``). + Inferred from extension if None. + group_by_schema : bool, optional + If True, group nodes into clusters by their database schema. + Default False. Only applies to png and svg formats. Raises ------ DataJointError If format is unsupported. + + Notes + ----- + Layout direction is controlled via ``dj.config.display.diagram_direction``. """ if format is None: if filename.lower().endswith(".png"): format = "png" elif filename.lower().endswith(".svg"): format = "svg" + elif filename.lower().endswith((".mmd", ".mermaid")): + format = "mermaid" + if format is None: + raise DataJointError("Could not infer format from filename. Specify format explicitly.") if format.lower() == "png": with open(filename, "wb") as f: - f.write(self.make_png().getbuffer().tobytes()) + f.write(self.make_png(group_by_schema=group_by_schema).getbuffer().tobytes()) elif format.lower() == "svg": with open(filename, "w") as f: - f.write(self.make_svg().data) + f.write(self.make_svg(group_by_schema=group_by_schema).data) + elif format.lower() == "mermaid": + with open(filename, "w") as f: + f.write(self.make_mermaid()) else: raise DataJointError("Unsupported file format") diff --git a/src/datajoint/settings.py b/src/datajoint/settings.py index ca57a00c6..445aaf54e 100644 --- a/src/datajoint/settings.py +++ b/src/datajoint/settings.py @@ -66,6 +66,7 @@ "database.backend": "DJ_BACKEND", "database.port": "DJ_PORT", "loglevel": "DJ_LOG_LEVEL", + "display.diagram_direction": "DJ_DIAGRAM_DIRECTION", } Role = Enum("Role", "manual lookup imported computed job") @@ -221,6 +222,11 @@ class DisplaySettings(BaseSettings): limit: int = 12 width: int = 14 show_tuple_count: bool = True + diagram_direction: Literal["TB", "LR"] = Field( + default="TB", + validation_alias="DJ_DIAGRAM_DIRECTION", + description="Default diagram layout direction: 'TB' (top-to-bottom) or 'LR' (left-to-right)", + ) class StoresSettings(BaseSettings): From 0dd5a69cc6ae5700d39306edda2613596dd38c37 Mon Sep 17 00:00:00 2001 From: Dimitri Yatsenko Date: Fri, 23 Jan 2026 13:47:01 -0600 Subject: [PATCH 02/17] feat: always group diagram nodes by schema with module labels - Remove group_by_schema parameter (always enabled) - Show Python module name as cluster label when available - Assign alias nodes (orange dots) to child table's schema - Add schema grouping (subgraphs) to Mermaid output Co-Authored-By: Claude Opus 4.5 --- src/datajoint/diagram.py | 149 +++++++++++++++++++++++++-------------- 1 file changed, 98 insertions(+), 51 deletions(-) diff --git a/src/datajoint/diagram.py b/src/datajoint/diagram.py index def151207..148bbdcfd 100644 --- a/src/datajoint/diagram.py +++ b/src/datajoint/diagram.py @@ -361,16 +361,10 @@ def _encapsulate_node_names(graph: nx.DiGraph) -> None: copy=False, ) - def make_dot(self, group_by_schema: bool = False): + def make_dot(self): """ Generate a pydot graph object. - Parameters - ---------- - group_by_schema : bool, optional - If True, group nodes into clusters by their database schema. - Default False. - Returns ------- pydot.Dot @@ -379,21 +373,39 @@ def make_dot(self, group_by_schema: bool = False): Notes ----- Layout direction is controlled via ``dj.config.display.diagram_direction``. + Tables are grouped by schema, with the Python module name shown as the + group label when available. """ direction = config.display.diagram_direction graph = self._make_graph() - # Build schema mapping if grouping is requested - schema_map = {} - if group_by_schema: - for full_name in self.nodes_to_show: - # Extract schema from full table name like `schema`.`table` or "schema"."table" - parts = full_name.replace('"', '`').split('`') - if len(parts) >= 2: - schema_name = parts[1] # schema is between first pair of backticks - # Find the class name for this full_name - class_name = lookup_class_name(full_name, self.context) or full_name - schema_map[class_name] = schema_name + # Build schema mapping: class_name -> (schema_name, module_name) + # Group by database schema, but label with Python module name when available + schema_map = {} # class_name -> schema_name + module_map = {} # schema_name -> module_name (for cluster labels) + + for full_name in self.nodes_to_show: + # Extract schema from full table name like `schema`.`table` or "schema"."table" + parts = full_name.replace('"', '`').split('`') + if len(parts) >= 2: + schema_name = parts[1] # schema is between first pair of backticks + class_name = lookup_class_name(full_name, self.context) or full_name + schema_map[class_name] = schema_name + + # Try to get Python module name for the cluster label + if schema_name not in module_map: + cls = self._resolve_class(class_name) + if cls is not None and hasattr(cls, "__module__"): + # Use the last part of the module path (e.g., "my_pipeline" from "package.my_pipeline") + module_map[schema_name] = cls.__module__.split(".")[-1] + + # Assign alias nodes (orange dots) to the same schema as their child table + for node, data in graph.nodes(data=True): + if data.get("node_type") is _AliasNode: + # Find the child (successor) - the table that declares the renamed FK + successors = list(graph.successors(node)) + if successors and successors[0] in schema_map: + schema_map[node] = schema_map[successors[0]] scale = 1.2 # scaling factor for fonts and boxes label_props = { # http://matplotlib.org/examples/color/named_colors.html @@ -498,8 +510,8 @@ def make_dot(self, group_by_schema: bool = False): edge.set_arrowhead("none") edge.set_penwidth(0.75 if props["multi"] else 2) - # Group nodes into schema clusters if requested - if group_by_schema and schema_map: + # Group nodes into schema clusters (always on) + if schema_map: import pydot # Group nodes by schema @@ -513,10 +525,12 @@ def make_dot(self, group_by_schema: bool = False): schemas[schema_name].append(node) # Create clusters for each schema + # Use Python module name as label when available, otherwise database schema name for schema_name, nodes in schemas.items(): + label = module_map.get(schema_name, schema_name) cluster = pydot.Cluster( f"cluster_{schema_name}", - label=schema_name, + label=label, style="dashed", color="gray", fontcolor="gray", @@ -527,17 +541,17 @@ def make_dot(self, group_by_schema: bool = False): return dot - def make_svg(self, group_by_schema: bool = False): + def make_svg(self): from IPython.display import SVG - return SVG(self.make_dot(group_by_schema=group_by_schema).create_svg()) + return SVG(self.make_dot().create_svg()) - def make_png(self, group_by_schema: bool = False): - return io.BytesIO(self.make_dot(group_by_schema=group_by_schema).create_png()) + def make_png(self): + return io.BytesIO(self.make_dot().create_png()) - def make_image(self, group_by_schema: bool = False): + def make_image(self): if plot_active: - return plt.imread(self.make_png(group_by_schema=group_by_schema)) + return plt.imread(self.make_png()) else: raise DataJointError("pyplot was not imported") @@ -556,20 +570,47 @@ def make_mermaid(self) -> str: Notes ----- Layout direction is controlled via ``dj.config.display.diagram_direction``. + Tables are grouped by schema using Mermaid subgraphs, with the Python + module name shown as the group label when available. Examples -------- >>> print(dj.Diagram(schema).make_mermaid()) flowchart TB - Mouse[Mouse]:::manual - Session[Session]:::manual - Neuron([Neuron]):::computed + subgraph my_pipeline + Mouse[Mouse]:::manual + Session[Session]:::manual + Neuron([Neuron]):::computed + end Mouse --> Session Session --> Neuron """ graph = self._make_graph() direction = config.display.diagram_direction + # Build schema mapping for grouping + schema_map = {} # class_name -> schema_name + module_map = {} # schema_name -> module_name (for subgraph labels) + + for full_name in self.nodes_to_show: + parts = full_name.replace('"', '`').split('`') + if len(parts) >= 2: + schema_name = parts[1] + class_name = lookup_class_name(full_name, self.context) or full_name + schema_map[class_name] = schema_name + + if schema_name not in module_map: + cls = self._resolve_class(class_name) + if cls is not None and hasattr(cls, "__module__"): + module_map[schema_name] = cls.__module__.split(".")[-1] + + # Assign alias nodes to the same schema as their child table + for node, data in graph.nodes(data=True): + if data.get("node_type") is _AliasNode: + successors = list(graph.successors(node)) + if successors and successors[0] in schema_map: + schema_map[node] = schema_map[successors[0]] + lines = [f"flowchart {direction}"] # Define class styles matching Graphviz colors @@ -601,15 +642,27 @@ def make_mermaid(self) -> str: None: "", } - # Add nodes + # Group nodes by schema into subgraphs + schemas = {} for node, data in graph.nodes(data=True): - tier = data.get("node_type") - left, right = shape_map.get(tier, ("[", "]")) - cls = tier_class.get(tier, "") - # Mermaid node IDs can't have dots, replace with underscores - safe_id = node.replace(".", "_").replace(" ", "_") - class_suffix = f":::{cls}" if cls else "" - lines.append(f" {safe_id}{left}{node}{right}{class_suffix}") + schema_name = schema_map.get(node) + if schema_name: + if schema_name not in schemas: + schemas[schema_name] = [] + schemas[schema_name].append((node, data)) + + # Add nodes grouped by schema subgraphs + for schema_name, nodes in schemas.items(): + label = module_map.get(schema_name, schema_name) + lines.append(f" subgraph {label}") + for node, data in nodes: + tier = data.get("node_type") + left, right = shape_map.get(tier, ("[", "]")) + cls = tier_class.get(tier, "") + safe_id = node.replace(".", "_").replace(" ", "_") + class_suffix = f":::{cls}" if cls else "" + lines.append(f" {safe_id}{left}{node}{right}{class_suffix}") + lines.append(" end") lines.append("") @@ -626,20 +679,15 @@ def make_mermaid(self) -> str: def _repr_svg_(self): return self.make_svg()._repr_svg_() - def draw(self, group_by_schema: bool = False): + def draw(self): if plot_active: - plt.imshow(self.make_image(group_by_schema=group_by_schema)) + plt.imshow(self.make_image()) plt.gca().axis("off") plt.show() else: raise DataJointError("pyplot was not imported") - def save( - self, - filename: str, - format: str | None = None, - group_by_schema: bool = False, - ) -> None: + def save(self, filename: str, format: str | None = None) -> None: """ Save diagram to file. @@ -650,9 +698,6 @@ def save( format : str, optional File format (``'png'``, ``'svg'``, or ``'mermaid'``). Inferred from extension if None. - group_by_schema : bool, optional - If True, group nodes into clusters by their database schema. - Default False. Only applies to png and svg formats. Raises ------ @@ -662,6 +707,8 @@ def save( Notes ----- Layout direction is controlled via ``dj.config.display.diagram_direction``. + Tables are grouped by schema, with the Python module name shown as the + group label when available. """ if format is None: if filename.lower().endswith(".png"): @@ -674,10 +721,10 @@ def save( raise DataJointError("Could not infer format from filename. Specify format explicitly.") if format.lower() == "png": with open(filename, "wb") as f: - f.write(self.make_png(group_by_schema=group_by_schema).getbuffer().tobytes()) + f.write(self.make_png().getbuffer().tobytes()) elif format.lower() == "svg": with open(filename, "w") as f: - f.write(self.make_svg(group_by_schema=group_by_schema).data) + f.write(self.make_svg().data) elif format.lower() == "mermaid": with open(filename, "w") as f: f.write(self.make_mermaid()) From 903e6b2cb0b7910323b1cc9bc3af64d299ef86a8 Mon Sep 17 00:00:00 2001 From: Dimitri Yatsenko Date: Fri, 23 Jan 2026 13:51:43 -0600 Subject: [PATCH 03/17] chore: bump version to 2.1.0a6 --- src/datajoint/version.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/datajoint/version.py b/src/datajoint/version.py index 2ffb3afa8..535dd4134 100644 --- a/src/datajoint/version.py +++ b/src/datajoint/version.py @@ -1,4 +1,4 @@ # version bump auto managed by Github Actions: # label_prs.yaml(prep), release.yaml(bump), post_release.yaml(edit) # manually set this version will be eventually overwritten by the above actions -__version__ = "2.1.0a5" +__version__ = "2.1.0a6" From d41b75f1117f51cca0600f95ef5f966aa7832f35 Mon Sep 17 00:00:00 2001 From: Dimitri Yatsenko Date: Fri, 23 Jan 2026 13:58:35 -0600 Subject: [PATCH 04/17] feat: improve schema grouping labels with fallback logic - Collect all module names per schema, not just the first - Use Python module name as label if 1:1 mapping with schema - Fall back to database schema name if multiple modules - Strip module prefix from class names when it matches cluster label Co-Authored-By: Claude Opus 4.5 --- src/datajoint/diagram.py | 71 +++++++++++++++++++++++++++++----------- 1 file changed, 52 insertions(+), 19 deletions(-) diff --git a/src/datajoint/diagram.py b/src/datajoint/diagram.py index 148bbdcfd..72f79f3fd 100644 --- a/src/datajoint/diagram.py +++ b/src/datajoint/diagram.py @@ -379,10 +379,10 @@ def make_dot(self): direction = config.display.diagram_direction graph = self._make_graph() - # Build schema mapping: class_name -> (schema_name, module_name) - # Group by database schema, but label with Python module name when available + # Build schema mapping: class_name -> schema_name + # Group by database schema, label with Python module name if 1:1 mapping schema_map = {} # class_name -> schema_name - module_map = {} # schema_name -> module_name (for cluster labels) + schema_modules = {} # schema_name -> set of module names for full_name in self.nodes_to_show: # Extract schema from full table name like `schema`.`table` or "schema"."table" @@ -392,12 +392,21 @@ def make_dot(self): class_name = lookup_class_name(full_name, self.context) or full_name schema_map[class_name] = schema_name - # Try to get Python module name for the cluster label - if schema_name not in module_map: - cls = self._resolve_class(class_name) - if cls is not None and hasattr(cls, "__module__"): - # Use the last part of the module path (e.g., "my_pipeline" from "package.my_pipeline") - module_map[schema_name] = cls.__module__.split(".")[-1] + # Collect all module names for this schema + if schema_name not in schema_modules: + schema_modules[schema_name] = set() + cls = self._resolve_class(class_name) + if cls is not None and hasattr(cls, "__module__"): + module_name = cls.__module__.split(".")[-1] + schema_modules[schema_name].add(module_name) + + # Determine cluster labels: use module name if 1:1, else database schema name + cluster_labels = {} # schema_name -> label + for schema_name, modules in schema_modules.items(): + if len(modules) == 1: + cluster_labels[schema_name] = next(iter(modules)) + else: + cluster_labels[schema_name] = schema_name # Assign alias nodes (orange dots) to the same schema as their child table for node, data in graph.nodes(data=True): @@ -492,7 +501,14 @@ def make_dot(self): if not q.startswith("#") ) node.set_tooltip(" ".join(description)) - node.set_label("<" + name + ">" if node.get("distinguished") == "True" else name) + # Strip module prefix from label if it matches the cluster label + display_name = name + schema_name = schema_map.get(name) + if schema_name and "." in name: + prefix = name.rsplit(".", 1)[0] + if prefix == cluster_labels.get(schema_name): + display_name = name.rsplit(".", 1)[1] + node.set_label("<" + display_name + ">" if node.get("distinguished") == "True" else display_name) node.set_color(props["color"]) node.set_style("filled") @@ -525,9 +541,9 @@ def make_dot(self): schemas[schema_name].append(node) # Create clusters for each schema - # Use Python module name as label when available, otherwise database schema name + # Use Python module name if 1:1 mapping, otherwise database schema name for schema_name, nodes in schemas.items(): - label = module_map.get(schema_name, schema_name) + label = cluster_labels.get(schema_name, schema_name) cluster = pydot.Cluster( f"cluster_{schema_name}", label=label, @@ -590,7 +606,7 @@ def make_mermaid(self) -> str: # Build schema mapping for grouping schema_map = {} # class_name -> schema_name - module_map = {} # schema_name -> module_name (for subgraph labels) + schema_modules = {} # schema_name -> set of module names for full_name in self.nodes_to_show: parts = full_name.replace('"', '`').split('`') @@ -599,10 +615,21 @@ def make_mermaid(self) -> str: class_name = lookup_class_name(full_name, self.context) or full_name schema_map[class_name] = schema_name - if schema_name not in module_map: - cls = self._resolve_class(class_name) - if cls is not None and hasattr(cls, "__module__"): - module_map[schema_name] = cls.__module__.split(".")[-1] + # Collect all module names for this schema + if schema_name not in schema_modules: + schema_modules[schema_name] = set() + cls = self._resolve_class(class_name) + if cls is not None and hasattr(cls, "__module__"): + module_name = cls.__module__.split(".")[-1] + schema_modules[schema_name].add(module_name) + + # Determine cluster labels: use module name if 1:1, else database schema name + cluster_labels = {} + for schema_name, modules in schema_modules.items(): + if len(modules) == 1: + cluster_labels[schema_name] = next(iter(modules)) + else: + cluster_labels[schema_name] = schema_name # Assign alias nodes to the same schema as their child table for node, data in graph.nodes(data=True): @@ -653,15 +680,21 @@ def make_mermaid(self) -> str: # Add nodes grouped by schema subgraphs for schema_name, nodes in schemas.items(): - label = module_map.get(schema_name, schema_name) + label = cluster_labels.get(schema_name, schema_name) lines.append(f" subgraph {label}") for node, data in nodes: tier = data.get("node_type") left, right = shape_map.get(tier, ("[", "]")) cls = tier_class.get(tier, "") safe_id = node.replace(".", "_").replace(" ", "_") + # Strip module prefix from display name if it matches the cluster label + display_name = node + if "." in node: + prefix = node.rsplit(".", 1)[0] + if prefix == label: + display_name = node.rsplit(".", 1)[1] class_suffix = f":::{cls}" if cls else "" - lines.append(f" {safe_id}{left}{node}{right}{class_suffix}") + lines.append(f" {safe_id}{left}{display_name}{right}{class_suffix}") lines.append(" end") lines.append("") From 80489fc48a01f1557e763d7b626c6fcd631c6d6f Mon Sep 17 00:00:00 2001 From: Dimitri Yatsenko Date: Fri, 23 Jan 2026 14:07:05 -0600 Subject: [PATCH 05/17] feat: add collapse() method for high-level pipeline views - Add collapse() method to mark diagrams for collapsing when combined - Collapsed schemas appear as single nodes showing table count - "Expanded wins" - nodes in non-collapsed diagrams stay expanded - Works with both Graphviz and Mermaid output - Use box3d shape for collapsed nodes in Graphviz Example: dj.Diagram(schema1) + dj.Diagram(schema2).collapse() Co-Authored-By: Claude Opus 4.5 --- src/datajoint/diagram.py | 279 ++++++++++++++++++++++++++++++++++----- 1 file changed, 245 insertions(+), 34 deletions(-) diff --git a/src/datajoint/diagram.py b/src/datajoint/diagram.py index 72f79f3fd..7f08bd44e 100644 --- a/src/datajoint/diagram.py +++ b/src/datajoint/diagram.py @@ -103,6 +103,8 @@ def __init__(self, source, context=None) -> None: if isinstance(source, Diagram): # copy constructor self.nodes_to_show = set(source.nodes_to_show) + self._explicit_nodes = set(source._explicit_nodes) + self._is_collapsed = source._is_collapsed self.context = source.context super().__init__(source) return @@ -130,6 +132,8 @@ def __init__(self, source, context=None) -> None: # Enumerate nodes from all the items in the list self.nodes_to_show = set() + self._explicit_nodes = set() # nodes that should never be collapsed + self._is_collapsed = False # whether this diagram's nodes should be collapsed when combined try: self.nodes_to_show.add(source.full_table_name) except AttributeError: @@ -181,6 +185,31 @@ def is_part(part, master): self.nodes_to_show.update(n for n in self.nodes() if any(is_part(n, m) for m in self.nodes_to_show)) return self + def collapse(self) -> "Diagram": + """ + Mark this diagram for collapsing when combined with other diagrams. + + When a collapsed diagram is added to a non-collapsed diagram, its nodes + are shown as a single collapsed node per schema, unless they also appear + in the non-collapsed diagram (expanded wins). + + Returns + ------- + Diagram + A copy of this diagram marked for collapsing. + + Examples + -------- + >>> # Show schema1 expanded, schema2 collapsed into single nodes + >>> dj.Diagram(schema1) + dj.Diagram(schema2).collapse() + + >>> # Explicitly expand one table from schema2 + >>> dj.Diagram(schema1) + dj.Diagram(TableFromSchema2) + dj.Diagram(schema2).collapse() + """ + result = Diagram(self) + result._is_collapsed = True + return result + def __add__(self, arg) -> "Diagram": """ Union or downstream expansion. @@ -195,21 +224,36 @@ def __add__(self, arg) -> "Diagram": Diagram Combined or expanded diagram. """ - self = Diagram(self) # copy + result = Diagram(self) # copy try: - self.nodes_to_show.update(arg.nodes_to_show) + result.nodes_to_show.update(arg.nodes_to_show) + # Handle collapse: nodes from non-collapsed diagrams are explicit (expanded) + if not self._is_collapsed: + result._explicit_nodes.update(self.nodes_to_show) + else: + result._explicit_nodes.update(self._explicit_nodes) + if not arg._is_collapsed: + result._explicit_nodes.update(arg.nodes_to_show) + else: + result._explicit_nodes.update(arg._explicit_nodes) + # Result is not collapsed (it's a combination) + result._is_collapsed = False except AttributeError: try: - self.nodes_to_show.add(arg.full_table_name) + result.nodes_to_show.add(arg.full_table_name) + result._explicit_nodes.add(arg.full_table_name) except AttributeError: for i in range(arg): - new = nx.algorithms.boundary.node_boundary(self, self.nodes_to_show) + new = nx.algorithms.boundary.node_boundary(result, result.nodes_to_show) if not new: break # add nodes referenced by aliased nodes - new.update(nx.algorithms.boundary.node_boundary(self, (a for a in new if a.isdigit()))) - self.nodes_to_show.update(new) - return self + new.update(nx.algorithms.boundary.node_boundary(result, (a for a in new if a.isdigit()))) + result.nodes_to_show.update(new) + # Expanded nodes from + N expansion are explicit + if not self._is_collapsed: + result._explicit_nodes = result.nodes_to_show.copy() + return result def __sub__(self, arg) -> "Diagram": """ @@ -305,6 +349,131 @@ def _make_graph(self) -> nx.DiGraph: nx.relabel_nodes(graph, mapping, copy=False) return graph + def _apply_collapse(self, graph: nx.DiGraph) -> tuple[nx.DiGraph, dict[str, str]]: + """ + Apply collapse logic to the graph. + + Nodes in nodes_to_show but not in _explicit_nodes are collapsed into + single schema nodes. + + Parameters + ---------- + graph : nx.DiGraph + The graph from _make_graph(). + + Returns + ------- + tuple[nx.DiGraph, dict[str, str]] + Modified graph and mapping of collapsed schema labels to their table count. + """ + if not self._explicit_nodes or self._explicit_nodes == self.nodes_to_show: + # No collapse needed + return graph, {} + + # Map full_table_names to class_names + full_to_class = { + node: lookup_class_name(node, self.context) or node + for node in self.nodes_to_show + } + class_to_full = {v: k for k, v in full_to_class.items()} + + # Identify explicit class names (should be expanded) + explicit_class_names = { + full_to_class.get(node, node) for node in self._explicit_nodes + } + + # Identify nodes to collapse (class names) + nodes_to_collapse = set(graph.nodes()) - explicit_class_names + + if not nodes_to_collapse: + return graph, {} + + # Group collapsed nodes by schema + collapsed_by_schema = {} # schema_name -> list of class_names + for class_name in nodes_to_collapse: + full_name = class_to_full.get(class_name) + if full_name: + parts = full_name.replace('"', '`').split('`') + if len(parts) >= 2: + schema_name = parts[1] + if schema_name not in collapsed_by_schema: + collapsed_by_schema[schema_name] = [] + collapsed_by_schema[schema_name].append(class_name) + + if not collapsed_by_schema: + return graph, {} + + # Determine labels for collapsed schemas + schema_modules = {} + for schema_name, class_names in collapsed_by_schema.items(): + schema_modules[schema_name] = set() + for class_name in class_names: + cls = self._resolve_class(class_name) + if cls is not None and hasattr(cls, "__module__"): + module_name = cls.__module__.split(".")[-1] + schema_modules[schema_name].add(module_name) + + collapsed_labels = {} # schema_name -> label + collapsed_counts = {} # label -> count of tables + for schema_name, modules in schema_modules.items(): + if len(modules) == 1: + label = next(iter(modules)) + else: + label = schema_name + collapsed_labels[schema_name] = label + collapsed_counts[label] = len(collapsed_by_schema[schema_name]) + + # Create new graph with collapsed nodes + new_graph = nx.DiGraph() + + # Map old node names to new names (collapsed nodes -> schema label) + node_mapping = {} + for node in graph.nodes(): + full_name = class_to_full.get(node) + if full_name: + parts = full_name.replace('"', '`').split('`') + if len(parts) >= 2 and node in nodes_to_collapse: + schema_name = parts[1] + node_mapping[node] = collapsed_labels[schema_name] + else: + node_mapping[node] = node + else: + # Alias nodes - check if they should be collapsed + # An alias node should be collapsed if ALL its neighbors are collapsed + neighbors = set(graph.predecessors(node)) | set(graph.successors(node)) + if neighbors and neighbors <= nodes_to_collapse: + # Get schema from first neighbor + neighbor = next(iter(neighbors)) + full_name = class_to_full.get(neighbor) + if full_name: + parts = full_name.replace('"', '`').split('`') + if len(parts) >= 2: + schema_name = parts[1] + node_mapping[node] = collapsed_labels[schema_name] + continue + node_mapping[node] = node + + # Add nodes + added_collapsed = set() + for old_node, new_node in node_mapping.items(): + if new_node in collapsed_counts: + # This is a collapsed schema node + if new_node not in added_collapsed: + new_graph.add_node(new_node, node_type=None, collapsed=True, + table_count=collapsed_counts[new_node]) + added_collapsed.add(new_node) + else: + new_graph.add_node(new_node, **graph.nodes[old_node]) + + # Add edges (avoiding self-loops and duplicates) + for src, dest, data in graph.edges(data=True): + new_src = node_mapping[src] + new_dest = node_mapping[dest] + if new_src != new_dest and not new_graph.has_edge(new_src, new_dest): + new_graph.add_edge(new_src, new_dest, **data) + + return new_graph, collapsed_counts + def _resolve_class(self, name: str): """ Safely resolve a table class from a dotted name without eval(). @@ -379,6 +548,9 @@ def make_dot(self): direction = config.display.diagram_direction graph = self._make_graph() + # Apply collapse logic if needed + graph, collapsed_counts = self._apply_collapse(graph) + # Build schema mapping: class_name -> schema_name # Group by database schema, label with Python module name if 1:1 mapping schema_map = {} # class_name -> schema_name @@ -474,8 +646,22 @@ def make_dot(self): size=0.1 * scale, fixed=False, ), + "collapsed": dict( + shape="box3d", + color="#80808060", + fontcolor="#404040", + fontsize=round(scale * 10), + size=0.5 * scale, + fixed=False, + ), } - node_props = {node: label_props[d["node_type"]] for node, d in dict(graph.nodes(data=True)).items()} + # Build node_props, handling collapsed nodes specially + node_props = {} + for node, d in graph.nodes(data=True): + if d.get("collapsed"): + node_props[node] = label_props["collapsed"] + else: + node_props[node] = label_props[d["node_type"]] self._encapsulate_node_names(graph) self._encapsulate_edge_attributes(graph) @@ -492,23 +678,32 @@ def make_dot(self): node.set_fixedsize("shape" if props["fixed"] else False) node.set_width(props["size"]) node.set_height(props["size"]) - cls = self._resolve_class(name) - if cls is not None: - description = cls().describe(context=self.context).split("\n") - description = ( - ("-" * 30 if q.startswith("---") else (q.replace("->", "→") if "->" in q else q.split(":")[0])) - for q in description - if not q.startswith("#") - ) - node.set_tooltip(" ".join(description)) - # Strip module prefix from label if it matches the cluster label - display_name = name - schema_name = schema_map.get(name) - if schema_name and "." in name: - prefix = name.rsplit(".", 1)[0] - if prefix == cluster_labels.get(schema_name): - display_name = name.rsplit(".", 1)[1] - node.set_label("<" + display_name + ">" if node.get("distinguished") == "True" else display_name) + + # Handle collapsed nodes specially + node_data = graph.nodes.get(f'"{name}"', {}) + if node_data.get("collapsed"): + table_count = node_data.get("table_count", 0) + label = f"{name}\\n({table_count} tables)" if table_count != 1 else f"{name}\\n(1 table)" + node.set_label(label) + node.set_tooltip(f"Collapsed schema: {table_count} tables") + else: + cls = self._resolve_class(name) + if cls is not None: + description = cls().describe(context=self.context).split("\n") + description = ( + ("-" * 30 if q.startswith("---") else (q.replace("->", "→") if "->" in q else q.split(":")[0])) + for q in description + if not q.startswith("#") + ) + node.set_tooltip(" ".join(description)) + # Strip module prefix from label if it matches the cluster label + display_name = name + schema_name = schema_map.get(name) + if schema_name and "." in name: + prefix = name.rsplit(".", 1)[0] + if prefix == cluster_labels.get(schema_name): + display_name = name.rsplit(".", 1)[1] + node.set_label("<" + display_name + ">" if node.get("distinguished") == "True" else display_name) node.set_color(props["color"]) node.set_style("filled") @@ -520,11 +715,12 @@ def make_dot(self): if props is None: raise DataJointError("Could not find edge with source '{}' and destination '{}'".format(src, dest)) edge.set_color("#00000040") - edge.set_style("solid" if props["primary"] else "dashed") - master_part = graph.nodes[dest]["node_type"] is Part and dest.startswith(src + ".") + edge.set_style("solid" if props.get("primary") else "dashed") + dest_node_type = graph.nodes[dest].get("node_type") + master_part = dest_node_type is Part and dest.startswith(src + ".") edge.set_weight(3 if master_part else 1) edge.set_arrowhead("none") - edge.set_penwidth(0.75 if props["multi"] else 2) + edge.set_penwidth(0.75 if props.get("multi") else 2) # Group nodes into schema clusters (always on) if schema_map: @@ -604,6 +800,9 @@ def make_mermaid(self) -> str: graph = self._make_graph() direction = config.display.diagram_direction + # Apply collapse logic if needed + graph, collapsed_counts = self._apply_collapse(graph) + # Build schema mapping for grouping schema_map = {} # class_name -> schema_name schema_modules = {} # schema_name -> set of module names @@ -646,6 +845,7 @@ def make_mermaid(self) -> str: lines.append(" classDef computed fill:#FFB6C1,stroke:#8B0000") lines.append(" classDef imported fill:#ADD8E6,stroke:#00008B") lines.append(" classDef part fill:#FFFFFF,stroke:#000000") + lines.append(" classDef collapsed fill:#808080,stroke:#404040") lines.append("") # Shape mapping: Manual=box, Computed/Imported=stadium, Lookup/Part=box @@ -669,14 +869,25 @@ def make_mermaid(self) -> str: None: "", } - # Group nodes by schema into subgraphs + # Group nodes by schema into subgraphs (only non-collapsed nodes) schemas = {} + collapsed_nodes = [] for node, data in graph.nodes(data=True): - schema_name = schema_map.get(node) - if schema_name: - if schema_name not in schemas: - schemas[schema_name] = [] - schemas[schema_name].append((node, data)) + if data.get("collapsed"): + collapsed_nodes.append((node, data)) + else: + schema_name = schema_map.get(node) + if schema_name: + if schema_name not in schemas: + schemas[schema_name] = [] + schemas[schema_name].append((node, data)) + + # Add collapsed nodes (not in subgraphs) + for node, data in collapsed_nodes: + safe_id = node.replace(".", "_").replace(" ", "_") + table_count = data.get("table_count", 0) + count_text = f"{table_count} tables" if table_count != 1 else "1 table" + lines.append(f" {safe_id}[[\"{node}
({count_text})\"]]:::collapsed") # Add nodes grouped by schema subgraphs for schema_name, nodes in schemas.items(): From 3292a068f9339877f1aa649df1e73e338c51b864 Mon Sep 17 00:00:00 2001 From: Dimitri Yatsenko Date: Fri, 23 Jan 2026 15:05:22 -0600 Subject: [PATCH 06/17] fix: properly merge diagrams from different schemas When combining diagrams from different schemas using +, the underlying networkx graphs and contexts are now properly merged. This fixes issues where cross-schema references would fail to render. Changes: - __add__: Merge nodes, edges, and contexts from both diagrams - _make_graph: Filter nodes_to_show to only include valid nodes - _apply_collapse: Use validated node sets to prevent KeyError Co-Authored-By: Claude Opus 4.5 --- src/datajoint/diagram.py | 25 ++++++++++++++++++------- 1 file changed, 18 insertions(+), 7 deletions(-) diff --git a/src/datajoint/diagram.py b/src/datajoint/diagram.py index 7f08bd44e..d44fd970d 100644 --- a/src/datajoint/diagram.py +++ b/src/datajoint/diagram.py @@ -226,7 +226,12 @@ def __add__(self, arg) -> "Diagram": """ result = Diagram(self) # copy try: + # Merge nodes and edges from the other diagram + result.add_nodes_from(arg.nodes(data=True)) + result.add_edges_from(arg.edges(data=True)) result.nodes_to_show.update(arg.nodes_to_show) + # Merge contexts for class name lookups + result.context = {**result.context, **arg.context} # Handle collapse: nodes from non-collapsed diagrams are explicit (expanded) if not self._is_collapsed: result._explicit_nodes.update(self.nodes_to_show) @@ -326,7 +331,9 @@ def _make_graph(self) -> nx.DiGraph: """ # mark "distinguished" tables, i.e. those that introduce new primary key # attributes - for name in self.nodes_to_show: + # Filter nodes_to_show to only include nodes that exist in the graph + valid_nodes = self.nodes_to_show.intersection(set(self.nodes())) + for name in valid_nodes: foreign_attributes = set( attr for p in self.in_edges(name, data=True) for attr in p[2]["attr_map"] if p[2]["primary"] ) @@ -334,10 +341,10 @@ def _make_graph(self) -> nx.DiGraph: "primary_key" in self.nodes[name] and foreign_attributes < self.nodes[name]["primary_key"] ) # include aliased nodes that are sandwiched between two displayed nodes - gaps = set(nx.algorithms.boundary.node_boundary(self, self.nodes_to_show)).intersection( - nx.algorithms.boundary.node_boundary(nx.DiGraph(self).reverse(), self.nodes_to_show) + gaps = set(nx.algorithms.boundary.node_boundary(self, valid_nodes)).intersection( + nx.algorithms.boundary.node_boundary(nx.DiGraph(self).reverse(), valid_nodes) ) - nodes = self.nodes_to_show.union(a for a in gaps if a.isdigit()) + nodes = valid_nodes.union(a for a in gaps if a.isdigit()) # construct subgraph and rename nodes to class names graph = nx.DiGraph(nx.DiGraph(self).subgraph(nodes)) nx.set_node_attributes(graph, name="node_type", values={n: _get_tier(n) for n in graph}) @@ -366,20 +373,24 @@ def _apply_collapse(self, graph: nx.DiGraph) -> tuple[nx.DiGraph, dict[str, str] tuple[nx.DiGraph, dict[str, str]] Modified graph and mapping of collapsed schema labels to their table count. """ - if not self._explicit_nodes or self._explicit_nodes == self.nodes_to_show: + # Filter to valid nodes (those that exist in the underlying graph) + valid_nodes = self.nodes_to_show.intersection(set(self.nodes())) + valid_explicit = self._explicit_nodes.intersection(set(self.nodes())) + + if not valid_explicit or valid_explicit == valid_nodes: # No collapse needed return graph, {} # Map full_table_names to class_names full_to_class = { node: lookup_class_name(node, self.context) or node - for node in self.nodes_to_show + for node in valid_nodes } class_to_full = {v: k for k, v in full_to_class.items()} # Identify explicit class names (should be expanded) explicit_class_names = { - full_to_class.get(node, node) for node in self._explicit_nodes + full_to_class.get(node, node) for node in valid_explicit } # Identify nodes to collapse (class names) From c3c4c0f0ec6fa94a526039b8f647fe38d1367cea Mon Sep 17 00:00:00 2001 From: Dimitri Yatsenko Date: Fri, 23 Jan 2026 15:41:50 -0600 Subject: [PATCH 07/17] fix: diagram improvements for collapse and display - Disambiguate cluster labels when multiple schemas share same module name (e.g., all defined in __main__) - adds schema name to label - Fix Computed node shape to use same size as Imported (ellipse, not small circle) - Merge nodes, edges, and contexts when combining diagrams from different schemas Co-Authored-By: Claude Opus 4.5 --- src/datajoint/diagram.py | 15 +++++++++++++-- 1 file changed, 13 insertions(+), 2 deletions(-) diff --git a/src/datajoint/diagram.py b/src/datajoint/diagram.py index d44fd970d..06d7270d5 100644 --- a/src/datajoint/diagram.py +++ b/src/datajoint/diagram.py @@ -591,6 +591,17 @@ def make_dot(self): else: cluster_labels[schema_name] = schema_name + # Disambiguate labels if multiple schemas share the same module name + # (e.g., all defined in __main__ in a notebook) + label_counts = {} + for label in cluster_labels.values(): + label_counts[label] = label_counts.get(label, 0) + 1 + + for schema_name, label in cluster_labels.items(): + if label_counts[label] > 1: + # Multiple schemas share this module name - add schema name + cluster_labels[schema_name] = f"{label} ({schema_name})" + # Assign alias nodes (orange dots) to the same schema as their child table for node, data in graph.nodes(data=True): if data.get("node_type") is _AliasNode: @@ -638,8 +649,8 @@ def make_dot(self): color="#FF000020", fontcolor="#7F0000A0", fontsize=round(scale * 10), - size=0.3 * scale, - fixed=True, + size=0.4 * scale, + fixed=False, ), Imported: dict( shape="ellipse", From ba1a237e9e960e84efe9f7e4b3a5438a66e703ba Mon Sep 17 00:00:00 2001 From: Dimitri Yatsenko Date: Fri, 23 Jan 2026 16:11:43 -0600 Subject: [PATCH 08/17] fix: collapse chaining for multiple collapsed diagrams Fixed bug where A.collapse() + B.collapse() + C.collapse() only collapsed the last diagram. The issue was: 1. _apply_collapse returned early when _explicit_nodes was empty 2. Combined diagrams lost track of which nodes came from collapsed sources Changes: - Remove early return when _explicit_nodes is empty - Track explicit nodes properly through chained + operations - Fresh non-collapsed diagrams add all nodes to explicit - Combined diagrams only add their existing explicit nodes Co-Authored-By: Claude Opus 4.5 --- src/datajoint/diagram.py | 31 ++++++++++++++++++++----------- 1 file changed, 20 insertions(+), 11 deletions(-) diff --git a/src/datajoint/diagram.py b/src/datajoint/diagram.py index 06d7270d5..d42dd80ce 100644 --- a/src/datajoint/diagram.py +++ b/src/datajoint/diagram.py @@ -232,17 +232,26 @@ def __add__(self, arg) -> "Diagram": result.nodes_to_show.update(arg.nodes_to_show) # Merge contexts for class name lookups result.context = {**result.context, **arg.context} - # Handle collapse: nodes from non-collapsed diagrams are explicit (expanded) - if not self._is_collapsed: + # Handle collapse: track which nodes should be explicit (expanded) + # - Always preserve existing _explicit_nodes from both sides + # - For a fresh (non-combined) non-collapsed diagram, add all its nodes to explicit + # - A fresh diagram has empty _explicit_nodes and _is_collapsed=False + # This ensures "expanded wins" and chained collapsed diagrams stay collapsed + result._explicit_nodes = set() + # Add self's explicit nodes + result._explicit_nodes.update(self._explicit_nodes) + # If self is a fresh non-collapsed diagram (not combined, not marked collapsed), + # treat all its nodes as explicit + if not self._is_collapsed and not self._explicit_nodes: result._explicit_nodes.update(self.nodes_to_show) - else: - result._explicit_nodes.update(self._explicit_nodes) - if not arg._is_collapsed: + # Add arg's explicit nodes + result._explicit_nodes.update(arg._explicit_nodes) + # If arg is a fresh non-collapsed diagram, treat all its nodes as explicit + if not arg._is_collapsed and not arg._explicit_nodes: result._explicit_nodes.update(arg.nodes_to_show) - else: - result._explicit_nodes.update(arg._explicit_nodes) - # Result is not collapsed (it's a combination) - result._is_collapsed = False + # Result is "collapsed" if BOTH operands were collapsed (no explicit nodes added) + # This allows chained collapsed diagrams to stay collapsed: A.collapse() + B.collapse() + C.collapse() + result._is_collapsed = self._is_collapsed and arg._is_collapsed except AttributeError: try: result.nodes_to_show.add(arg.full_table_name) @@ -377,8 +386,8 @@ def _apply_collapse(self, graph: nx.DiGraph) -> tuple[nx.DiGraph, dict[str, str] valid_nodes = self.nodes_to_show.intersection(set(self.nodes())) valid_explicit = self._explicit_nodes.intersection(set(self.nodes())) - if not valid_explicit or valid_explicit == valid_nodes: - # No collapse needed + if valid_explicit == valid_nodes: + # All nodes are explicit (expanded) - no collapse needed return graph, {} # Map full_table_names to class_names From 26264d4d41b994a01331c61e4da2ee9383961851 Mon Sep 17 00:00:00 2001 From: Dimitri Yatsenko Date: Fri, 23 Jan 2026 16:16:58 -0600 Subject: [PATCH 09/17] fix: reset alias node counter on dependencies clear Fixed bug where combining diagrams created duplicate alias nodes (orange dots for renamed FKs). The issue was that _node_alias_count wasn't reset when clear() was called, so each load() created new IDs. Now Person + Marriage shows 2 alias nodes instead of 4. Co-Authored-By: Claude Opus 4.5 --- src/datajoint/dependencies.py | 1 + 1 file changed, 1 insertion(+) diff --git a/src/datajoint/dependencies.py b/src/datajoint/dependencies.py index 45a5a643e..83162a112 100644 --- a/src/datajoint/dependencies.py +++ b/src/datajoint/dependencies.py @@ -137,6 +137,7 @@ def __init__(self, connection=None) -> None: def clear(self) -> None: """Clear the graph and reset loaded state.""" self._loaded = False + self._node_alias_count = itertools.count() # reset alias IDs for consistency super().clear() def load(self, force: bool = True) -> None: From 09cf50d84f8c75945c361b7b2ea307b9920fcdb8 Mon Sep 17 00:00:00 2001 From: Dimitri Yatsenko Date: Fri, 23 Jan 2026 16:22:20 -0600 Subject: [PATCH 10/17] fix: don't collapse fresh diagrams that were never combined A fresh dj.Diagram(schema) was incorrectly collapsing because _explicit_nodes was empty. Now we check both _explicit_nodes and _is_collapsed to determine if collapse should be applied: - Fresh diagram (_explicit_nodes empty, _is_collapsed=False): no collapse - Combined collapsed (_explicit_nodes empty, _is_collapsed=True): collapse all - Mixed combination: collapse only non-explicit nodes Co-Authored-By: Claude Opus 4.5 --- src/datajoint/diagram.py | 9 +++++++++ 1 file changed, 9 insertions(+) diff --git a/src/datajoint/diagram.py b/src/datajoint/diagram.py index d42dd80ce..932cf9b85 100644 --- a/src/datajoint/diagram.py +++ b/src/datajoint/diagram.py @@ -386,6 +386,15 @@ def _apply_collapse(self, graph: nx.DiGraph) -> tuple[nx.DiGraph, dict[str, str] valid_nodes = self.nodes_to_show.intersection(set(self.nodes())) valid_explicit = self._explicit_nodes.intersection(set(self.nodes())) + # Determine if collapse should be applied: + # - If _explicit_nodes is empty AND _is_collapsed is False, this is a fresh + # diagram that was never combined with collapsed diagrams → no collapse + # - If _explicit_nodes is empty AND _is_collapsed is True, this is the result + # of combining only collapsed diagrams → collapse all nodes + # - If _explicit_nodes equals valid_nodes, all nodes are explicit → no collapse + if not valid_explicit and not self._is_collapsed: + # Fresh diagram, never combined with collapsed diagrams + return graph, {} if valid_explicit == valid_nodes: # All nodes are explicit (expanded) - no collapse needed return graph, {} From 77ebfb5743bb022dd4383e1935794fbac45bb6fe Mon Sep 17 00:00:00 2001 From: Dimitri Yatsenko Date: Fri, 23 Jan 2026 16:58:44 -0600 Subject: [PATCH 11/17] fix: use database schema name for collapsed nodes when module is ambiguous When multiple schemas share the same Python module name (e.g., __main__ in notebooks), collapsed nodes now use the database schema name instead. This makes it clear which schema is collapsed when tables from different schemas are mixed in the same diagram. Co-Authored-By: Claude Opus 4.5 --- src/datajoint/diagram.py | 38 +++++++++++++++++++++++++++++++++++--- 1 file changed, 35 insertions(+), 3 deletions(-) diff --git a/src/datajoint/diagram.py b/src/datajoint/diagram.py index 932cf9b85..74f8639a1 100644 --- a/src/datajoint/diagram.py +++ b/src/datajoint/diagram.py @@ -442,15 +442,47 @@ def _apply_collapse(self, graph: nx.DiGraph) -> tuple[nx.DiGraph, dict[str, str] module_name = cls.__module__.split(".")[-1] schema_modules[schema_name].add(module_name) + # Collect module names for ALL schemas in the diagram (not just collapsed) + all_schema_modules = {} # schema_name -> module_name + for node in graph.nodes(): + full_name = class_to_full.get(node) + if full_name: + parts = full_name.replace('"', '`').split('`') + if len(parts) >= 2: + db_schema = parts[1] + cls = self._resolve_class(node) + if cls is not None and hasattr(cls, "__module__"): + module_name = cls.__module__.split(".")[-1] + all_schema_modules[db_schema] = module_name + + # Check which module names are shared by multiple schemas + module_to_schemas = {} + for db_schema, module_name in all_schema_modules.items(): + if module_name not in module_to_schemas: + module_to_schemas[module_name] = [] + module_to_schemas[module_name].append(db_schema) + + ambiguous_modules = {m for m, schemas in module_to_schemas.items() if len(schemas) > 1} + + # Determine labels for collapsed schemas collapsed_labels = {} # schema_name -> label - collapsed_counts = {} # label -> count of tables for schema_name, modules in schema_modules.items(): if len(modules) == 1: - label = next(iter(modules)) + module_name = next(iter(modules)) + # Use database schema name if module is ambiguous + if module_name in ambiguous_modules: + label = schema_name + else: + label = module_name else: label = schema_name collapsed_labels[schema_name] = label - collapsed_counts[label] = len(collapsed_by_schema[schema_name]) + + # Build counts using final labels + collapsed_counts = {} # label -> count of tables + for schema_name, class_names in collapsed_by_schema.items(): + label = collapsed_labels[schema_name] + collapsed_counts[label] = len(class_names) # Create new graph with collapsed nodes new_graph = nx.DiGraph() From de00340f6ea58cdac30f2220fa60aef15d297582 Mon Sep 17 00:00:00 2001 From: Dimitri Yatsenko Date: Fri, 23 Jan 2026 17:19:11 -0600 Subject: [PATCH 12/17] fix: place collapsed nodes inside schema clusters for proper layout Collapsed nodes now include schema_name attribute and are added to schema_map so they appear inside the cluster with other tables from the same schema. This fixes the visual layout so collapsed middle layers appear between top and bottom tables, maintaining DAG flow. Co-Authored-By: Claude Opus 4.5 --- src/datajoint/diagram.py | 12 +++++++++++- 1 file changed, 11 insertions(+), 1 deletion(-) diff --git a/src/datajoint/diagram.py b/src/datajoint/diagram.py index 74f8639a1..b1728c861 100644 --- a/src/datajoint/diagram.py +++ b/src/datajoint/diagram.py @@ -514,14 +514,19 @@ def _apply_collapse(self, graph: nx.DiGraph) -> tuple[nx.DiGraph, dict[str, str] continue node_mapping[node] = node + # Build reverse mapping: label -> schema_name + label_to_schema = {label: schema for schema, label in collapsed_labels.items()} + # Add nodes added_collapsed = set() for old_node, new_node in node_mapping.items(): if new_node in collapsed_counts: # This is a collapsed schema node if new_node not in added_collapsed: + schema_name = label_to_schema.get(new_node, new_node) new_graph.add_node(new_node, node_type=None, collapsed=True, - table_count=collapsed_counts[new_node]) + table_count=collapsed_counts[new_node], + schema_name=schema_name) added_collapsed.add(new_node) else: new_graph.add_node(new_node, **graph.nodes[old_node]) @@ -660,6 +665,11 @@ def make_dot(self): if successors and successors[0] in schema_map: schema_map[node] = schema_map[successors[0]] + # Assign collapsed nodes to their schema so they appear in the cluster + for node, data in graph.nodes(data=True): + if data.get("collapsed") and data.get("schema_name"): + schema_map[node] = data["schema_name"] + scale = 1.2 # scaling factor for fonts and boxes label_props = { # http://matplotlib.org/examples/color/named_colors.html None: dict( From 4d6b7acad2d94533471b70ab1cb5ec7d1bb33932 Mon Sep 17 00:00:00 2001 From: Dimitri Yatsenko Date: Fri, 23 Jan 2026 17:24:34 -0600 Subject: [PATCH 13/17] feat: change default diagram direction from TB to LR Left-to-right layout is more natural for pipeline visualization, matching the typical data flow representation in documentation. Co-Authored-By: Claude Opus 4.5 --- src/datajoint/settings.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/datajoint/settings.py b/src/datajoint/settings.py index 445aaf54e..ddd1b487a 100644 --- a/src/datajoint/settings.py +++ b/src/datajoint/settings.py @@ -223,7 +223,7 @@ class DisplaySettings(BaseSettings): width: int = 14 show_tuple_count: bool = True diagram_direction: Literal["TB", "LR"] = Field( - default="TB", + default="LR", validation_alias="DJ_DIAGRAM_DIRECTION", description="Default diagram layout direction: 'TB' (top-to-bottom) or 'LR' (left-to-right)", ) From 9e87106258ad05f7f14c094ff93febdb1972231b Mon Sep 17 00:00:00 2001 From: Dimitri Yatsenko Date: Fri, 23 Jan 2026 17:39:34 -0600 Subject: [PATCH 14/17] fix: collapsed nodes show only table count, not redundant name Since collapsed nodes are now inside clusters that display the schema/module name, the node label only needs to show "(N tables)". Co-Authored-By: Claude Opus 4.5 --- src/datajoint/diagram.py | 52 ++++++++++++++++++++-------------------- 1 file changed, 26 insertions(+), 26 deletions(-) diff --git a/src/datajoint/diagram.py b/src/datajoint/diagram.py index b1728c861..326b154d1 100644 --- a/src/datajoint/diagram.py +++ b/src/datajoint/diagram.py @@ -765,7 +765,7 @@ def make_dot(self): node_data = graph.nodes.get(f'"{name}"', {}) if node_data.get("collapsed"): table_count = node_data.get("table_count", 0) - label = f"{name}\\n({table_count} tables)" if table_count != 1 else f"{name}\\n(1 table)" + label = f"({table_count} tables)" if table_count != 1 else "(1 table)" node.set_label(label) node.set_tooltip(f"Collapsed schema: {table_count} tables") else: @@ -951,43 +951,43 @@ def make_mermaid(self) -> str: None: "", } - # Group nodes by schema into subgraphs (only non-collapsed nodes) + # Group nodes by schema into subgraphs (including collapsed nodes) schemas = {} - collapsed_nodes = [] for node, data in graph.nodes(data=True): if data.get("collapsed"): - collapsed_nodes.append((node, data)) + # Collapsed nodes use their schema_name attribute + schema_name = data.get("schema_name") else: schema_name = schema_map.get(node) - if schema_name: - if schema_name not in schemas: - schemas[schema_name] = [] - schemas[schema_name].append((node, data)) - - # Add collapsed nodes (not in subgraphs) - for node, data in collapsed_nodes: - safe_id = node.replace(".", "_").replace(" ", "_") - table_count = data.get("table_count", 0) - count_text = f"{table_count} tables" if table_count != 1 else "1 table" - lines.append(f" {safe_id}[[\"{node}
({count_text})\"]]:::collapsed") + if schema_name: + if schema_name not in schemas: + schemas[schema_name] = [] + schemas[schema_name].append((node, data)) # Add nodes grouped by schema subgraphs for schema_name, nodes in schemas.items(): label = cluster_labels.get(schema_name, schema_name) lines.append(f" subgraph {label}") for node, data in nodes: - tier = data.get("node_type") - left, right = shape_map.get(tier, ("[", "]")) - cls = tier_class.get(tier, "") safe_id = node.replace(".", "_").replace(" ", "_") - # Strip module prefix from display name if it matches the cluster label - display_name = node - if "." in node: - prefix = node.rsplit(".", 1)[0] - if prefix == label: - display_name = node.rsplit(".", 1)[1] - class_suffix = f":::{cls}" if cls else "" - lines.append(f" {safe_id}{left}{display_name}{right}{class_suffix}") + if data.get("collapsed"): + # Collapsed node - show only table count + table_count = data.get("table_count", 0) + count_text = f"{table_count} tables" if table_count != 1 else "1 table" + lines.append(f" {safe_id}[[\"({count_text})\"]]:::collapsed") + else: + # Regular node + tier = data.get("node_type") + left, right = shape_map.get(tier, ("[", "]")) + cls = tier_class.get(tier, "") + # Strip module prefix from display name if it matches the cluster label + display_name = node + if "." in node: + prefix = node.rsplit(".", 1)[0] + if prefix == label: + display_name = node.rsplit(".", 1)[1] + class_suffix = f":::{cls}" if cls else "" + lines.append(f" {safe_id}{left}{display_name}{right}{class_suffix}") lines.append(" end") lines.append("") From d5bdf51ebafbd3f1feed1dc5aaf1b4c0cb19f104 Mon Sep 17 00:00:00 2001 From: Dimitri Yatsenko Date: Fri, 23 Jan 2026 17:54:15 -0600 Subject: [PATCH 15/17] refactor: simplify collapse logic to use single _expanded_nodes set Replace complex _explicit_nodes + _is_collapsed with simpler design: - Fresh diagrams: all nodes expanded - collapse(): clears _expanded_nodes - + operator: union of _expanded_nodes (expanded wins) Bump version to 2.1.0a7 Co-Authored-By: Claude Opus 4.5 --- src/datajoint/diagram.py | 82 ++++++++++++++-------------------------- src/datajoint/version.py | 2 +- 2 files changed, 29 insertions(+), 55 deletions(-) diff --git a/src/datajoint/diagram.py b/src/datajoint/diagram.py index 326b154d1..eb59e728e 100644 --- a/src/datajoint/diagram.py +++ b/src/datajoint/diagram.py @@ -103,8 +103,7 @@ def __init__(self, source, context=None) -> None: if isinstance(source, Diagram): # copy constructor self.nodes_to_show = set(source.nodes_to_show) - self._explicit_nodes = set(source._explicit_nodes) - self._is_collapsed = source._is_collapsed + self._expanded_nodes = set(source._expanded_nodes) self.context = source.context super().__init__(source) return @@ -132,8 +131,6 @@ def __init__(self, source, context=None) -> None: # Enumerate nodes from all the items in the list self.nodes_to_show = set() - self._explicit_nodes = set() # nodes that should never be collapsed - self._is_collapsed = False # whether this diagram's nodes should be collapsed when combined try: self.nodes_to_show.add(source.full_table_name) except AttributeError: @@ -148,6 +145,8 @@ def __init__(self, source, context=None) -> None: # Handle both MySQL backticks and PostgreSQL double quotes if node.startswith("`%s`" % database) or node.startswith('"%s"' % database): self.nodes_to_show.add(node) + # All nodes start as expanded + self._expanded_nodes = set(self.nodes_to_show) @classmethod def from_sequence(cls, sequence) -> "Diagram": @@ -187,27 +186,30 @@ def is_part(part, master): def collapse(self) -> "Diagram": """ - Mark this diagram for collapsing when combined with other diagrams. + Mark all nodes in this diagram as collapsed. - When a collapsed diagram is added to a non-collapsed diagram, its nodes - are shown as a single collapsed node per schema, unless they also appear - in the non-collapsed diagram (expanded wins). + Collapsed nodes are shown as a single node per schema. When combined + with other diagrams using ``+``, expanded nodes win: if a node is + expanded in either operand, it remains expanded in the result. Returns ------- Diagram - A copy of this diagram marked for collapsing. + A copy of this diagram with all nodes collapsed. Examples -------- >>> # Show schema1 expanded, schema2 collapsed into single nodes >>> dj.Diagram(schema1) + dj.Diagram(schema2).collapse() - >>> # Explicitly expand one table from schema2 - >>> dj.Diagram(schema1) + dj.Diagram(TableFromSchema2) + dj.Diagram(schema2).collapse() + >>> # Collapse all three schemas together + >>> (dj.Diagram(schema1) + dj.Diagram(schema2) + dj.Diagram(schema3)).collapse() + + >>> # Expand one table from collapsed schema + >>> dj.Diagram(schema).collapse() + dj.Diagram(SingleTable) """ result = Diagram(self) - result._is_collapsed = True + result._expanded_nodes = set() # All nodes collapsed return result def __add__(self, arg) -> "Diagram": @@ -232,30 +234,12 @@ def __add__(self, arg) -> "Diagram": result.nodes_to_show.update(arg.nodes_to_show) # Merge contexts for class name lookups result.context = {**result.context, **arg.context} - # Handle collapse: track which nodes should be explicit (expanded) - # - Always preserve existing _explicit_nodes from both sides - # - For a fresh (non-combined) non-collapsed diagram, add all its nodes to explicit - # - A fresh diagram has empty _explicit_nodes and _is_collapsed=False - # This ensures "expanded wins" and chained collapsed diagrams stay collapsed - result._explicit_nodes = set() - # Add self's explicit nodes - result._explicit_nodes.update(self._explicit_nodes) - # If self is a fresh non-collapsed diagram (not combined, not marked collapsed), - # treat all its nodes as explicit - if not self._is_collapsed and not self._explicit_nodes: - result._explicit_nodes.update(self.nodes_to_show) - # Add arg's explicit nodes - result._explicit_nodes.update(arg._explicit_nodes) - # If arg is a fresh non-collapsed diagram, treat all its nodes as explicit - if not arg._is_collapsed and not arg._explicit_nodes: - result._explicit_nodes.update(arg.nodes_to_show) - # Result is "collapsed" if BOTH operands were collapsed (no explicit nodes added) - # This allows chained collapsed diagrams to stay collapsed: A.collapse() + B.collapse() + C.collapse() - result._is_collapsed = self._is_collapsed and arg._is_collapsed + # Expanded wins: union of expanded nodes from both operands + result._expanded_nodes = self._expanded_nodes | arg._expanded_nodes except AttributeError: try: result.nodes_to_show.add(arg.full_table_name) - result._explicit_nodes.add(arg.full_table_name) + result._expanded_nodes.add(arg.full_table_name) except AttributeError: for i in range(arg): new = nx.algorithms.boundary.node_boundary(result, result.nodes_to_show) @@ -264,9 +248,8 @@ def __add__(self, arg) -> "Diagram": # add nodes referenced by aliased nodes new.update(nx.algorithms.boundary.node_boundary(result, (a for a in new if a.isdigit()))) result.nodes_to_show.update(new) - # Expanded nodes from + N expansion are explicit - if not self._is_collapsed: - result._explicit_nodes = result.nodes_to_show.copy() + # New nodes from expansion are expanded + result._expanded_nodes = result._expanded_nodes | result.nodes_to_show return result def __sub__(self, arg) -> "Diagram": @@ -369,7 +352,7 @@ def _apply_collapse(self, graph: nx.DiGraph) -> tuple[nx.DiGraph, dict[str, str] """ Apply collapse logic to the graph. - Nodes in nodes_to_show but not in _explicit_nodes are collapsed into + Nodes in nodes_to_show but not in _expanded_nodes are collapsed into single schema nodes. Parameters @@ -384,19 +367,10 @@ def _apply_collapse(self, graph: nx.DiGraph) -> tuple[nx.DiGraph, dict[str, str] """ # Filter to valid nodes (those that exist in the underlying graph) valid_nodes = self.nodes_to_show.intersection(set(self.nodes())) - valid_explicit = self._explicit_nodes.intersection(set(self.nodes())) - - # Determine if collapse should be applied: - # - If _explicit_nodes is empty AND _is_collapsed is False, this is a fresh - # diagram that was never combined with collapsed diagrams → no collapse - # - If _explicit_nodes is empty AND _is_collapsed is True, this is the result - # of combining only collapsed diagrams → collapse all nodes - # - If _explicit_nodes equals valid_nodes, all nodes are explicit → no collapse - if not valid_explicit and not self._is_collapsed: - # Fresh diagram, never combined with collapsed diagrams - return graph, {} - if valid_explicit == valid_nodes: - # All nodes are explicit (expanded) - no collapse needed + valid_expanded = self._expanded_nodes.intersection(set(self.nodes())) + + # If all nodes are expanded, no collapse needed + if valid_expanded >= valid_nodes: return graph, {} # Map full_table_names to class_names @@ -406,13 +380,13 @@ def _apply_collapse(self, graph: nx.DiGraph) -> tuple[nx.DiGraph, dict[str, str] } class_to_full = {v: k for k, v in full_to_class.items()} - # Identify explicit class names (should be expanded) - explicit_class_names = { - full_to_class.get(node, node) for node in valid_explicit + # Identify expanded class names + expanded_class_names = { + full_to_class.get(node, node) for node in valid_expanded } # Identify nodes to collapse (class names) - nodes_to_collapse = set(graph.nodes()) - explicit_class_names + nodes_to_collapse = set(graph.nodes()) - expanded_class_names if not nodes_to_collapse: return graph, {} diff --git a/src/datajoint/version.py b/src/datajoint/version.py index 535dd4134..f19a270de 100644 --- a/src/datajoint/version.py +++ b/src/datajoint/version.py @@ -1,4 +1,4 @@ # version bump auto managed by Github Actions: # label_prs.yaml(prep), release.yaml(bump), post_release.yaml(edit) # manually set this version will be eventually overwritten by the above actions -__version__ = "2.1.0a6" +__version__ = "2.1.0a7" From e59eeb30ee395daa11431c0643af08e31c3c8592 Mon Sep 17 00:00:00 2001 From: Dimitri Yatsenko Date: Fri, 23 Jan 2026 21:49:05 -0600 Subject: [PATCH 16/17] fix: break long line in diagram.py to pass lint Co-Authored-By: Claude Opus 4.5 --- src/datajoint/diagram.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/src/datajoint/diagram.py b/src/datajoint/diagram.py index eb59e728e..59a971765 100644 --- a/src/datajoint/diagram.py +++ b/src/datajoint/diagram.py @@ -747,7 +747,10 @@ def make_dot(self): if cls is not None: description = cls().describe(context=self.context).split("\n") description = ( - ("-" * 30 if q.startswith("---") else (q.replace("->", "→") if "->" in q else q.split(":")[0])) + ( + "-" * 30 if q.startswith("---") + else (q.replace("->", "→") if "->" in q else q.split(":")[0]) + ) for q in description if not q.startswith("#") ) From 810ceee0ad82fcb6acb99cf2ae987a9751a3e888 Mon Sep 17 00:00:00 2001 From: Dimitri Yatsenko Date: Fri, 23 Jan 2026 21:51:07 -0600 Subject: [PATCH 17/17] style: apply ruff format to diagram.py Co-Authored-By: Claude Opus 4.5 --- src/datajoint/diagram.py | 50 ++++++++++++++++++++-------------------- 1 file changed, 25 insertions(+), 25 deletions(-) diff --git a/src/datajoint/diagram.py b/src/datajoint/diagram.py index 59a971765..48e18fd0d 100644 --- a/src/datajoint/diagram.py +++ b/src/datajoint/diagram.py @@ -374,16 +374,11 @@ def _apply_collapse(self, graph: nx.DiGraph) -> tuple[nx.DiGraph, dict[str, str] return graph, {} # Map full_table_names to class_names - full_to_class = { - node: lookup_class_name(node, self.context) or node - for node in valid_nodes - } + full_to_class = {node: lookup_class_name(node, self.context) or node for node in valid_nodes} class_to_full = {v: k for k, v in full_to_class.items()} # Identify expanded class names - expanded_class_names = { - full_to_class.get(node, node) for node in valid_expanded - } + expanded_class_names = {full_to_class.get(node, node) for node in valid_expanded} # Identify nodes to collapse (class names) nodes_to_collapse = set(graph.nodes()) - expanded_class_names @@ -396,7 +391,7 @@ def _apply_collapse(self, graph: nx.DiGraph) -> tuple[nx.DiGraph, dict[str, str] for class_name in nodes_to_collapse: full_name = class_to_full.get(class_name) if full_name: - parts = full_name.replace('"', '`').split('`') + parts = full_name.replace('"', "`").split("`") if len(parts) >= 2: schema_name = parts[1] if schema_name not in collapsed_by_schema: @@ -421,7 +416,7 @@ def _apply_collapse(self, graph: nx.DiGraph) -> tuple[nx.DiGraph, dict[str, str] for node in graph.nodes(): full_name = class_to_full.get(node) if full_name: - parts = full_name.replace('"', '`').split('`') + parts = full_name.replace('"', "`").split("`") if len(parts) >= 2: db_schema = parts[1] cls = self._resolve_class(node) @@ -466,7 +461,7 @@ def _apply_collapse(self, graph: nx.DiGraph) -> tuple[nx.DiGraph, dict[str, str] for node in graph.nodes(): full_name = class_to_full.get(node) if full_name: - parts = full_name.replace('"', '`').split('`') + parts = full_name.replace('"', "`").split("`") if len(parts) >= 2 and node in nodes_to_collapse: schema_name = parts[1] node_mapping[node] = collapsed_labels[schema_name] @@ -481,7 +476,7 @@ def _apply_collapse(self, graph: nx.DiGraph) -> tuple[nx.DiGraph, dict[str, str] neighbor = next(iter(neighbors)) full_name = class_to_full.get(neighbor) if full_name: - parts = full_name.replace('"', '`').split('`') + parts = full_name.replace('"', "`").split("`") if len(parts) >= 2: schema_name = parts[1] node_mapping[node] = collapsed_labels[schema_name] @@ -498,9 +493,13 @@ def _apply_collapse(self, graph: nx.DiGraph) -> tuple[nx.DiGraph, dict[str, str] # This is a collapsed schema node if new_node not in added_collapsed: schema_name = label_to_schema.get(new_node, new_node) - new_graph.add_node(new_node, node_type=None, collapsed=True, - table_count=collapsed_counts[new_node], - schema_name=schema_name) + new_graph.add_node( + new_node, + node_type=None, + collapsed=True, + table_count=collapsed_counts[new_node], + schema_name=schema_name, + ) added_collapsed.add(new_node) else: new_graph.add_node(new_node, **graph.nodes[old_node]) @@ -598,7 +597,7 @@ def make_dot(self): for full_name in self.nodes_to_show: # Extract schema from full table name like `schema`.`table` or "schema"."table" - parts = full_name.replace('"', '`').split('`') + parts = full_name.replace('"', "`").split("`") if len(parts) >= 2: schema_name = parts[1] # schema is between first pair of backticks class_name = lookup_class_name(full_name, self.context) or full_name @@ -748,7 +747,8 @@ def make_dot(self): description = cls().describe(context=self.context).split("\n") description = ( ( - "-" * 30 if q.startswith("---") + "-" * 30 + if q.startswith("---") else (q.replace("->", "→") if "->" in q else q.split(":")[0]) ) for q in description @@ -867,7 +867,7 @@ def make_mermaid(self) -> str: schema_modules = {} # schema_name -> set of module names for full_name in self.nodes_to_show: - parts = full_name.replace('"', '`').split('`') + parts = full_name.replace('"', "`").split("`") if len(parts) >= 2: schema_name = parts[1] class_name = lookup_class_name(full_name, self.context) or full_name @@ -909,13 +909,13 @@ def make_mermaid(self) -> str: # Shape mapping: Manual=box, Computed/Imported=stadium, Lookup/Part=box shape_map = { - Manual: ("[", "]"), # box - Lookup: ("[", "]"), # box - Computed: ("([", "])"), # stadium/pill - Imported: ("([", "])"), # stadium/pill - Part: ("[", "]"), # box - _AliasNode: ("((", "))"), # circle - None: ("((", "))"), # circle + Manual: ("[", "]"), # box + Lookup: ("[", "]"), # box + Computed: ("([", "])"), # stadium/pill + Imported: ("([", "])"), # stadium/pill + Part: ("[", "]"), # box + _AliasNode: ("((", "))"), # circle + None: ("((", "))"), # circle } tier_class = { @@ -951,7 +951,7 @@ def make_mermaid(self) -> str: # Collapsed node - show only table count table_count = data.get("table_count", 0) count_text = f"{table_count} tables" if table_count != 1 else "1 table" - lines.append(f" {safe_id}[[\"({count_text})\"]]:::collapsed") + lines.append(f' {safe_id}[["({count_text})"]]:::collapsed') else: # Regular node tier = data.get("node_type")