diff --git a/src/datajoint/dependencies.py b/src/datajoint/dependencies.py
index 45a5a643e..83162a112 100644
--- a/src/datajoint/dependencies.py
+++ b/src/datajoint/dependencies.py
@@ -137,6 +137,7 @@ def __init__(self, connection=None) -> None:
def clear(self) -> None:
"""Clear the graph and reset loaded state."""
self._loaded = False
+ self._node_alias_count = itertools.count() # reset alias IDs for consistency
super().clear()
def load(self, force: bool = True) -> None:
diff --git a/src/datajoint/diagram.py b/src/datajoint/diagram.py
index b06686025..48e18fd0d 100644
--- a/src/datajoint/diagram.py
+++ b/src/datajoint/diagram.py
@@ -16,6 +16,7 @@
from .dependencies import topo_sort
from .errors import DataJointError
+from .settings import config
from .table import Table, lookup_class_name
from .user_tables import Computed, Imported, Lookup, Manual, Part, _AliasNode, _get_tier
@@ -90,12 +91,19 @@ class Diagram(nx.DiGraph):
-----
``diagram + 1 - 1`` may differ from ``diagram - 1 + 1``.
Only tables loaded in the connection are displayed.
+
+ Layout direction is controlled via ``dj.config.display.diagram_direction``
+ (default ``"LR"``). Use ``dj.config.override()`` to change temporarily::
+
+ with dj.config.override(display_diagram_direction="TB"):
+ dj.Diagram(schema).draw()
"""
def __init__(self, source, context=None) -> None:
if isinstance(source, Diagram):
# copy constructor
self.nodes_to_show = set(source.nodes_to_show)
+ self._expanded_nodes = set(source._expanded_nodes)
self.context = source.context
super().__init__(source)
return
@@ -137,6 +145,8 @@ def __init__(self, source, context=None) -> None:
# Handle both MySQL backticks and PostgreSQL double quotes
if node.startswith("`%s`" % database) or node.startswith('"%s"' % database):
self.nodes_to_show.add(node)
+ # All nodes start as expanded
+ self._expanded_nodes = set(self.nodes_to_show)
@classmethod
def from_sequence(cls, sequence) -> "Diagram":
@@ -174,6 +184,34 @@ def is_part(part, master):
self.nodes_to_show.update(n for n in self.nodes() if any(is_part(n, m) for m in self.nodes_to_show))
return self
+ def collapse(self) -> "Diagram":
+ """
+ Mark all nodes in this diagram as collapsed.
+
+ Collapsed nodes are shown as a single node per schema. When combined
+ with other diagrams using ``+``, expanded nodes win: if a node is
+ expanded in either operand, it remains expanded in the result.
+
+ Returns
+ -------
+ Diagram
+ A copy of this diagram with all nodes collapsed.
+
+ Examples
+ --------
+ >>> # Show schema1 expanded, schema2 collapsed into single nodes
+ >>> dj.Diagram(schema1) + dj.Diagram(schema2).collapse()
+
+ >>> # Collapse all three schemas together
+ >>> (dj.Diagram(schema1) + dj.Diagram(schema2) + dj.Diagram(schema3)).collapse()
+
+ >>> # Expand one table from collapsed schema
+ >>> dj.Diagram(schema).collapse() + dj.Diagram(SingleTable)
+ """
+ result = Diagram(self)
+ result._expanded_nodes = set() # All nodes collapsed
+ return result
+
def __add__(self, arg) -> "Diagram":
"""
Union or downstream expansion.
@@ -188,21 +226,31 @@ def __add__(self, arg) -> "Diagram":
Diagram
Combined or expanded diagram.
"""
- self = Diagram(self) # copy
+ result = Diagram(self) # copy
try:
- self.nodes_to_show.update(arg.nodes_to_show)
+ # Merge nodes and edges from the other diagram
+ result.add_nodes_from(arg.nodes(data=True))
+ result.add_edges_from(arg.edges(data=True))
+ result.nodes_to_show.update(arg.nodes_to_show)
+ # Merge contexts for class name lookups
+ result.context = {**result.context, **arg.context}
+ # Expanded wins: union of expanded nodes from both operands
+ result._expanded_nodes = self._expanded_nodes | arg._expanded_nodes
except AttributeError:
try:
- self.nodes_to_show.add(arg.full_table_name)
+ result.nodes_to_show.add(arg.full_table_name)
+ result._expanded_nodes.add(arg.full_table_name)
except AttributeError:
for i in range(arg):
- new = nx.algorithms.boundary.node_boundary(self, self.nodes_to_show)
+ new = nx.algorithms.boundary.node_boundary(result, result.nodes_to_show)
if not new:
break
# add nodes referenced by aliased nodes
- new.update(nx.algorithms.boundary.node_boundary(self, (a for a in new if a.isdigit())))
- self.nodes_to_show.update(new)
- return self
+ new.update(nx.algorithms.boundary.node_boundary(result, (a for a in new if a.isdigit())))
+ result.nodes_to_show.update(new)
+ # New nodes from expansion are expanded
+ result._expanded_nodes = result._expanded_nodes | result.nodes_to_show
+ return result
def __sub__(self, arg) -> "Diagram":
"""
@@ -275,7 +323,9 @@ def _make_graph(self) -> nx.DiGraph:
"""
# mark "distinguished" tables, i.e. those that introduce new primary key
# attributes
- for name in self.nodes_to_show:
+ # Filter nodes_to_show to only include nodes that exist in the graph
+ valid_nodes = self.nodes_to_show.intersection(set(self.nodes()))
+ for name in valid_nodes:
foreign_attributes = set(
attr for p in self.in_edges(name, data=True) for attr in p[2]["attr_map"] if p[2]["primary"]
)
@@ -283,21 +333,210 @@ def _make_graph(self) -> nx.DiGraph:
"primary_key" in self.nodes[name] and foreign_attributes < self.nodes[name]["primary_key"]
)
# include aliased nodes that are sandwiched between two displayed nodes
- gaps = set(nx.algorithms.boundary.node_boundary(self, self.nodes_to_show)).intersection(
- nx.algorithms.boundary.node_boundary(nx.DiGraph(self).reverse(), self.nodes_to_show)
+ gaps = set(nx.algorithms.boundary.node_boundary(self, valid_nodes)).intersection(
+ nx.algorithms.boundary.node_boundary(nx.DiGraph(self).reverse(), valid_nodes)
)
- nodes = self.nodes_to_show.union(a for a in gaps if a.isdigit)
+ nodes = valid_nodes.union(a for a in gaps if a.isdigit())
# construct subgraph and rename nodes to class names
graph = nx.DiGraph(nx.DiGraph(self).subgraph(nodes))
nx.set_node_attributes(graph, name="node_type", values={n: _get_tier(n) for n in graph})
# relabel nodes to class names
mapping = {node: lookup_class_name(node, self.context) or node for node in graph.nodes()}
- new_names = [mapping.values()]
+ new_names = list(mapping.values())
if len(new_names) > len(set(new_names)):
raise DataJointError("Some classes have identical names. The Diagram cannot be plotted.")
nx.relabel_nodes(graph, mapping, copy=False)
return graph
+ def _apply_collapse(self, graph: nx.DiGraph) -> tuple[nx.DiGraph, dict[str, str]]:
+ """
+ Apply collapse logic to the graph.
+
+ Nodes in nodes_to_show but not in _expanded_nodes are collapsed into
+ single schema nodes.
+
+ Parameters
+ ----------
+ graph : nx.DiGraph
+ The graph from _make_graph().
+
+ Returns
+ -------
+ tuple[nx.DiGraph, dict[str, str]]
+ Modified graph and mapping of collapsed schema labels to their table count.
+ """
+ # Filter to valid nodes (those that exist in the underlying graph)
+ valid_nodes = self.nodes_to_show.intersection(set(self.nodes()))
+ valid_expanded = self._expanded_nodes.intersection(set(self.nodes()))
+
+ # If all nodes are expanded, no collapse needed
+ if valid_expanded >= valid_nodes:
+ return graph, {}
+
+ # Map full_table_names to class_names
+ full_to_class = {node: lookup_class_name(node, self.context) or node for node in valid_nodes}
+ class_to_full = {v: k for k, v in full_to_class.items()}
+
+ # Identify expanded class names
+ expanded_class_names = {full_to_class.get(node, node) for node in valid_expanded}
+
+ # Identify nodes to collapse (class names)
+ nodes_to_collapse = set(graph.nodes()) - expanded_class_names
+
+ if not nodes_to_collapse:
+ return graph, {}
+
+ # Group collapsed nodes by schema
+ collapsed_by_schema = {} # schema_name -> list of class_names
+ for class_name in nodes_to_collapse:
+ full_name = class_to_full.get(class_name)
+ if full_name:
+ parts = full_name.replace('"', "`").split("`")
+ if len(parts) >= 2:
+ schema_name = parts[1]
+ if schema_name not in collapsed_by_schema:
+ collapsed_by_schema[schema_name] = []
+ collapsed_by_schema[schema_name].append(class_name)
+
+ if not collapsed_by_schema:
+ return graph, {}
+
+ # Determine labels for collapsed schemas
+ schema_modules = {}
+ for schema_name, class_names in collapsed_by_schema.items():
+ schema_modules[schema_name] = set()
+ for class_name in class_names:
+ cls = self._resolve_class(class_name)
+ if cls is not None and hasattr(cls, "__module__"):
+ module_name = cls.__module__.split(".")[-1]
+ schema_modules[schema_name].add(module_name)
+
+ # Collect module names for ALL schemas in the diagram (not just collapsed)
+ all_schema_modules = {} # schema_name -> module_name
+ for node in graph.nodes():
+ full_name = class_to_full.get(node)
+ if full_name:
+ parts = full_name.replace('"', "`").split("`")
+ if len(parts) >= 2:
+ db_schema = parts[1]
+ cls = self._resolve_class(node)
+ if cls is not None and hasattr(cls, "__module__"):
+ module_name = cls.__module__.split(".")[-1]
+ all_schema_modules[db_schema] = module_name
+
+ # Check which module names are shared by multiple schemas
+ module_to_schemas = {}
+ for db_schema, module_name in all_schema_modules.items():
+ if module_name not in module_to_schemas:
+ module_to_schemas[module_name] = []
+ module_to_schemas[module_name].append(db_schema)
+
+ ambiguous_modules = {m for m, schemas in module_to_schemas.items() if len(schemas) > 1}
+
+ # Determine labels for collapsed schemas
+ collapsed_labels = {} # schema_name -> label
+ for schema_name, modules in schema_modules.items():
+ if len(modules) == 1:
+ module_name = next(iter(modules))
+ # Use database schema name if module is ambiguous
+ if module_name in ambiguous_modules:
+ label = schema_name
+ else:
+ label = module_name
+ else:
+ label = schema_name
+ collapsed_labels[schema_name] = label
+
+ # Build counts using final labels
+ collapsed_counts = {} # label -> count of tables
+ for schema_name, class_names in collapsed_by_schema.items():
+ label = collapsed_labels[schema_name]
+ collapsed_counts[label] = len(class_names)
+
+ # Create new graph with collapsed nodes
+ new_graph = nx.DiGraph()
+
+ # Map old node names to new names (collapsed nodes -> schema label)
+ node_mapping = {}
+ for node in graph.nodes():
+ full_name = class_to_full.get(node)
+ if full_name:
+ parts = full_name.replace('"', "`").split("`")
+ if len(parts) >= 2 and node in nodes_to_collapse:
+ schema_name = parts[1]
+ node_mapping[node] = collapsed_labels[schema_name]
+ else:
+ node_mapping[node] = node
+ else:
+ # Alias nodes - check if they should be collapsed
+ # An alias node should be collapsed if ALL its neighbors are collapsed
+ neighbors = set(graph.predecessors(node)) | set(graph.successors(node))
+ if neighbors and neighbors <= nodes_to_collapse:
+ # Get schema from first neighbor
+ neighbor = next(iter(neighbors))
+ full_name = class_to_full.get(neighbor)
+ if full_name:
+ parts = full_name.replace('"', "`").split("`")
+ if len(parts) >= 2:
+ schema_name = parts[1]
+ node_mapping[node] = collapsed_labels[schema_name]
+ continue
+ node_mapping[node] = node
+
+ # Build reverse mapping: label -> schema_name
+ label_to_schema = {label: schema for schema, label in collapsed_labels.items()}
+
+ # Add nodes
+ added_collapsed = set()
+ for old_node, new_node in node_mapping.items():
+ if new_node in collapsed_counts:
+ # This is a collapsed schema node
+ if new_node not in added_collapsed:
+ schema_name = label_to_schema.get(new_node, new_node)
+ new_graph.add_node(
+ new_node,
+ node_type=None,
+ collapsed=True,
+ table_count=collapsed_counts[new_node],
+ schema_name=schema_name,
+ )
+ added_collapsed.add(new_node)
+ else:
+ new_graph.add_node(new_node, **graph.nodes[old_node])
+
+ # Add edges (avoiding self-loops and duplicates)
+ for src, dest, data in graph.edges(data=True):
+ new_src = node_mapping[src]
+ new_dest = node_mapping[dest]
+ if new_src != new_dest and not new_graph.has_edge(new_src, new_dest):
+ new_graph.add_edge(new_src, new_dest, **data)
+
+ return new_graph, collapsed_counts
+
+ def _resolve_class(self, name: str):
+ """
+ Safely resolve a table class from a dotted name without eval().
+
+ Parameters
+ ----------
+ name : str
+ Dotted class name like "MyTable" or "Module.MyTable".
+
+ Returns
+ -------
+ type or None
+ The table class if found, otherwise None.
+ """
+ parts = name.split(".")
+ obj = self.context.get(parts[0])
+ for part in parts[1:]:
+ if obj is None:
+ return None
+ obj = getattr(obj, part, None)
+ if obj is not None and isinstance(obj, type) and issubclass(obj, Table):
+ return obj
+ return None
+
@staticmethod
def _encapsulate_edge_attributes(graph: nx.DiGraph) -> None:
"""
@@ -331,8 +570,78 @@ def _encapsulate_node_names(graph: nx.DiGraph) -> None:
)
def make_dot(self):
+ """
+ Generate a pydot graph object.
+
+ Returns
+ -------
+ pydot.Dot
+ The graph object ready for rendering.
+
+ Notes
+ -----
+ Layout direction is controlled via ``dj.config.display.diagram_direction``.
+ Tables are grouped by schema, with the Python module name shown as the
+ group label when available.
+ """
+ direction = config.display.diagram_direction
graph = self._make_graph()
- graph.nodes()
+
+ # Apply collapse logic if needed
+ graph, collapsed_counts = self._apply_collapse(graph)
+
+ # Build schema mapping: class_name -> schema_name
+ # Group by database schema, label with Python module name if 1:1 mapping
+ schema_map = {} # class_name -> schema_name
+ schema_modules = {} # schema_name -> set of module names
+
+ for full_name in self.nodes_to_show:
+ # Extract schema from full table name like `schema`.`table` or "schema"."table"
+ parts = full_name.replace('"', "`").split("`")
+ if len(parts) >= 2:
+ schema_name = parts[1] # schema is between first pair of backticks
+ class_name = lookup_class_name(full_name, self.context) or full_name
+ schema_map[class_name] = schema_name
+
+ # Collect all module names for this schema
+ if schema_name not in schema_modules:
+ schema_modules[schema_name] = set()
+ cls = self._resolve_class(class_name)
+ if cls is not None and hasattr(cls, "__module__"):
+ module_name = cls.__module__.split(".")[-1]
+ schema_modules[schema_name].add(module_name)
+
+ # Determine cluster labels: use module name if 1:1, else database schema name
+ cluster_labels = {} # schema_name -> label
+ for schema_name, modules in schema_modules.items():
+ if len(modules) == 1:
+ cluster_labels[schema_name] = next(iter(modules))
+ else:
+ cluster_labels[schema_name] = schema_name
+
+ # Disambiguate labels if multiple schemas share the same module name
+ # (e.g., all defined in __main__ in a notebook)
+ label_counts = {}
+ for label in cluster_labels.values():
+ label_counts[label] = label_counts.get(label, 0) + 1
+
+ for schema_name, label in cluster_labels.items():
+ if label_counts[label] > 1:
+ # Multiple schemas share this module name - add schema name
+ cluster_labels[schema_name] = f"{label} ({schema_name})"
+
+ # Assign alias nodes (orange dots) to the same schema as their child table
+ for node, data in graph.nodes(data=True):
+ if data.get("node_type") is _AliasNode:
+ # Find the child (successor) - the table that declares the renamed FK
+ successors = list(graph.successors(node))
+ if successors and successors[0] in schema_map:
+ schema_map[node] = schema_map[successors[0]]
+
+ # Assign collapsed nodes to their schema so they appear in the cluster
+ for node, data in graph.nodes(data=True):
+ if data.get("collapsed") and data.get("schema_name"):
+ schema_map[node] = data["schema_name"]
scale = 1.2 # scaling factor for fonts and boxes
label_props = { # http://matplotlib.org/examples/color/named_colors.html
@@ -373,8 +682,8 @@ def make_dot(self):
color="#FF000020",
fontcolor="#7F0000A0",
fontsize=round(scale * 10),
- size=0.3 * scale,
- fixed=True,
+ size=0.4 * scale,
+ fixed=False,
),
Imported: dict(
shape="ellipse",
@@ -386,18 +695,33 @@ def make_dot(self):
),
Part: dict(
shape="plaintext",
- color="#0000000",
+ color="#00000000",
fontcolor="black",
fontsize=round(scale * 8),
size=0.1 * scale,
fixed=False,
),
+ "collapsed": dict(
+ shape="box3d",
+ color="#80808060",
+ fontcolor="#404040",
+ fontsize=round(scale * 10),
+ size=0.5 * scale,
+ fixed=False,
+ ),
}
- node_props = {node: label_props[d["node_type"]] for node, d in dict(graph.nodes(data=True)).items()}
+ # Build node_props, handling collapsed nodes specially
+ node_props = {}
+ for node, d in graph.nodes(data=True):
+ if d.get("collapsed"):
+ node_props[node] = label_props["collapsed"]
+ else:
+ node_props[node] = label_props[d["node_type"]]
self._encapsulate_node_names(graph)
self._encapsulate_edge_attributes(graph)
dot = nx.drawing.nx_pydot.to_pydot(graph)
+ dot.set_rankdir(direction)
for node in dot.get_nodes():
node.set_shape("circle")
name = node.get_name().strip('"')
@@ -409,17 +733,36 @@ def make_dot(self):
node.set_fixedsize("shape" if props["fixed"] else False)
node.set_width(props["size"])
node.set_height(props["size"])
- if name.split(".")[0] in self.context:
- cls = eval(name, self.context)
- assert issubclass(cls, Table)
- description = cls().describe(context=self.context).split("\n")
- description = (
- ("-" * 30 if q.startswith("---") else (q.replace("->", "→") if "->" in q else q.split(":")[0]))
- for q in description
- if not q.startswith("#")
- )
- node.set_tooltip("\n".join(description))
- node.set_label("<" + name + ">" if node.get("distinguished") == "True" else name)
+
+ # Handle collapsed nodes specially
+ node_data = graph.nodes.get(f'"{name}"', {})
+ if node_data.get("collapsed"):
+ table_count = node_data.get("table_count", 0)
+ label = f"({table_count} tables)" if table_count != 1 else "(1 table)"
+ node.set_label(label)
+ node.set_tooltip(f"Collapsed schema: {table_count} tables")
+ else:
+ cls = self._resolve_class(name)
+ if cls is not None:
+ description = cls().describe(context=self.context).split("\n")
+ description = (
+ (
+ "-" * 30
+ if q.startswith("---")
+ else (q.replace("->", "→") if "->" in q else q.split(":")[0])
+ )
+ for q in description
+ if not q.startswith("#")
+ )
+ node.set_tooltip("\n".join(description))
+ # Strip module prefix from label if it matches the cluster label
+ display_name = name
+ schema_name = schema_map.get(name)
+ if schema_name and "." in name:
+ prefix = name.rsplit(".", 1)[0]
+ if prefix == cluster_labels.get(schema_name):
+ display_name = name.rsplit(".", 1)[1]
+ node.set_label("<" + display_name + ">" if node.get("distinguished") == "True" else display_name)
node.set_color(props["color"])
node.set_style("filled")
@@ -431,11 +774,41 @@ def make_dot(self):
if props is None:
raise DataJointError("Could not find edge with source '{}' and destination '{}'".format(src, dest))
edge.set_color("#00000040")
- edge.set_style("solid" if props["primary"] else "dashed")
- master_part = graph.nodes[dest]["node_type"] is Part and dest.startswith(src + ".")
+ edge.set_style("solid" if props.get("primary") else "dashed")
+ dest_node_type = graph.nodes[dest].get("node_type")
+ master_part = dest_node_type is Part and dest.startswith(src + ".")
edge.set_weight(3 if master_part else 1)
edge.set_arrowhead("none")
- edge.set_penwidth(0.75 if props["multi"] else 2)
+ edge.set_penwidth(0.75 if props.get("multi") else 2)
+
+ # Group nodes into schema clusters (always on)
+ if schema_map:
+ import pydot
+
+ # Group nodes by schema
+ schemas = {}
+ for node in list(dot.get_nodes()):
+ name = node.get_name().strip('"')
+ schema_name = schema_map.get(name)
+ if schema_name:
+ if schema_name not in schemas:
+ schemas[schema_name] = []
+ schemas[schema_name].append(node)
+
+ # Create clusters for each schema
+ # Use Python module name if 1:1 mapping, otherwise database schema name
+ for schema_name, nodes in schemas.items():
+ label = cluster_labels.get(schema_name, schema_name)
+ cluster = pydot.Cluster(
+ f"cluster_{schema_name}",
+ label=label,
+ style="dashed",
+ color="gray",
+ fontcolor="gray",
+ )
+ for node in nodes:
+ cluster.add_node(node)
+ dot.add_subgraph(cluster)
return dot
@@ -453,6 +826,159 @@ def make_image(self):
else:
raise DataJointError("pyplot was not imported")
+ def make_mermaid(self) -> str:
+ """
+ Generate Mermaid diagram syntax.
+
+ Produces a flowchart in Mermaid syntax that can be rendered in
+ Markdown documentation, GitHub, or https://mermaid.live.
+
+ Returns
+ -------
+ str
+ Mermaid flowchart syntax.
+
+ Notes
+ -----
+ Layout direction is controlled via ``dj.config.display.diagram_direction``.
+ Tables are grouped by schema using Mermaid subgraphs, with the Python
+ module name shown as the group label when available.
+
+ Examples
+ --------
+ >>> print(dj.Diagram(schema).make_mermaid())
+ flowchart LR
+ subgraph my_pipeline
+ Mouse[Mouse]:::manual
+ Session[Session]:::manual
+ Neuron([Neuron]):::computed
+ end
+ Mouse --> Session
+ Session --> Neuron
+ """
+ graph = self._make_graph()
+ direction = config.display.diagram_direction
+
+ # Apply collapse logic if needed
+ graph, collapsed_counts = self._apply_collapse(graph)
+
+ # Build schema mapping for grouping
+ schema_map = {} # class_name -> schema_name
+ schema_modules = {} # schema_name -> set of module names
+
+ for full_name in self.nodes_to_show:
+ parts = full_name.replace('"', "`").split("`")
+ if len(parts) >= 2:
+ schema_name = parts[1]
+ class_name = lookup_class_name(full_name, self.context) or full_name
+ schema_map[class_name] = schema_name
+
+ # Collect all module names for this schema
+ if schema_name not in schema_modules:
+ schema_modules[schema_name] = set()
+ cls = self._resolve_class(class_name)
+ if cls is not None and hasattr(cls, "__module__"):
+ module_name = cls.__module__.split(".")[-1]
+ schema_modules[schema_name].add(module_name)
+
+ # Determine cluster labels: use module name if 1:1, else database schema name
+ cluster_labels = {}
+ for schema_name, modules in schema_modules.items():
+ if len(modules) == 1:
+ cluster_labels[schema_name] = next(iter(modules))
+ else:
+ cluster_labels[schema_name] = schema_name
+
+ # Assign alias nodes to the same schema as their child table
+ for node, data in graph.nodes(data=True):
+ if data.get("node_type") is _AliasNode:
+ successors = list(graph.successors(node))
+ if successors and successors[0] in schema_map:
+ schema_map[node] = schema_map[successors[0]]
+
+ lines = [f"flowchart {direction}"]
+
+ # Define class styles matching Graphviz colors
+ lines.append(" classDef manual fill:#90EE90,stroke:#006400")
+ lines.append(" classDef lookup fill:#D3D3D3,stroke:#696969")
+ lines.append(" classDef computed fill:#FFB6C1,stroke:#8B0000")
+ lines.append(" classDef imported fill:#ADD8E6,stroke:#00008B")
+ lines.append(" classDef part fill:#FFFFFF,stroke:#000000")
+ lines.append(" classDef collapsed fill:#808080,stroke:#404040")
+ lines.append("")
+
+ # Shape mapping: Manual=box, Computed/Imported=stadium, Lookup/Part=box
+ shape_map = {
+ Manual: ("[", "]"), # box
+ Lookup: ("[", "]"), # box
+ Computed: ("([", "])"), # stadium/pill
+ Imported: ("([", "])"), # stadium/pill
+ Part: ("[", "]"), # box
+ _AliasNode: ("((", "))"), # circle
+ None: ("((", "))"), # circle
+ }
+
+ tier_class = {
+ Manual: "manual",
+ Lookup: "lookup",
+ Computed: "computed",
+ Imported: "imported",
+ Part: "part",
+ _AliasNode: "",
+ None: "",
+ }
+
+ # Group nodes by schema into subgraphs (including collapsed nodes)
+ schemas = {}
+ for node, data in graph.nodes(data=True):
+ if data.get("collapsed"):
+ # Collapsed nodes use their schema_name attribute
+ schema_name = data.get("schema_name")
+ else:
+ schema_name = schema_map.get(node)
+ if schema_name:
+ if schema_name not in schemas:
+ schemas[schema_name] = []
+ schemas[schema_name].append((node, data))
+
+ # Add nodes grouped by schema subgraphs
+ for schema_name, nodes in schemas.items():
+ label = cluster_labels.get(schema_name, schema_name)
+ lines.append(f" subgraph {label}")
+ for node, data in nodes:
+ safe_id = node.replace(".", "_").replace(" ", "_")
+ if data.get("collapsed"):
+ # Collapsed node - show only table count
+ table_count = data.get("table_count", 0)
+ count_text = f"{table_count} tables" if table_count != 1 else "1 table"
+ lines.append(f' {safe_id}[["({count_text})"]]:::collapsed')
+ else:
+ # Regular node
+ tier = data.get("node_type")
+ left, right = shape_map.get(tier, ("[", "]"))
+ cls = tier_class.get(tier, "")
+ # Strip module prefix from display name if it matches the cluster label
+ display_name = node
+ if "." in node:
+ prefix = node.rsplit(".", 1)[0]
+ if prefix == label:
+ display_name = node.rsplit(".", 1)[1]
+ class_suffix = f":::{cls}" if cls else ""
+ lines.append(f" {safe_id}{left}{display_name}{right}{class_suffix}")
+ lines.append(" end")
+
+ lines.append("")
+
+ # Add edges
+ for src, dest, data in graph.edges(data=True):
+ safe_src = src.replace(".", "_").replace(" ", "_")
+ safe_dest = dest.replace(".", "_").replace(" ", "_")
+ # Solid arrow for primary FK, dotted for non-primary
+ style = "-->" if data.get("primary") else "-.->"
+ lines.append(f" {safe_src} {style} {safe_dest}")
+
+ return "\n".join(lines)
+
def _repr_svg_(self):
return self.make_svg()._repr_svg_()
@@ -473,24 +999,38 @@ def save(self, filename: str, format: str | None = None) -> None:
filename : str
Output filename.
format : str, optional
- File format (``'png'`` or ``'svg'``). Inferred from extension if None.
+ File format (``'png'``, ``'svg'``, or ``'mermaid'``).
+ Inferred from extension if None.
Raises
------
DataJointError
If format is unsupported.
+
+ Notes
+ -----
+ Layout direction is controlled via ``dj.config.display.diagram_direction``.
+ Tables are grouped by schema, with the Python module name shown as the
+ group label when available.
"""
if format is None:
if filename.lower().endswith(".png"):
format = "png"
elif filename.lower().endswith(".svg"):
format = "svg"
+ elif filename.lower().endswith((".mmd", ".mermaid")):
+ format = "mermaid"
+ if format is None:
+ raise DataJointError("Could not infer format from filename. Specify format explicitly.")
if format.lower() == "png":
with open(filename, "wb") as f:
f.write(self.make_png().getbuffer().tobytes())
elif format.lower() == "svg":
with open(filename, "w") as f:
f.write(self.make_svg().data)
+ elif format.lower() == "mermaid":
+ with open(filename, "w") as f:
+ f.write(self.make_mermaid())
else:
raise DataJointError("Unsupported file format")
diff --git a/src/datajoint/settings.py b/src/datajoint/settings.py
index ca57a00c6..ddd1b487a 100644
--- a/src/datajoint/settings.py
+++ b/src/datajoint/settings.py
@@ -66,6 +66,7 @@
"database.backend": "DJ_BACKEND",
"database.port": "DJ_PORT",
"loglevel": "DJ_LOG_LEVEL",
+ "display.diagram_direction": "DJ_DIAGRAM_DIRECTION",
}
Role = Enum("Role", "manual lookup imported computed job")
@@ -221,6 +222,11 @@ class DisplaySettings(BaseSettings):
limit: int = 12
width: int = 14
show_tuple_count: bool = True
+ diagram_direction: Literal["TB", "LR"] = Field(
+ default="LR",
+ validation_alias="DJ_DIAGRAM_DIRECTION",
+ description="Default diagram layout direction: 'TB' (top-to-bottom) or 'LR' (left-to-right)",
+ )
class StoresSettings(BaseSettings):
diff --git a/src/datajoint/version.py b/src/datajoint/version.py
index 2ffb3afa8..f19a270de 100644
--- a/src/datajoint/version.py
+++ b/src/datajoint/version.py
@@ -1,4 +1,4 @@
# version bump auto managed by Github Actions:
# label_prs.yaml(prep), release.yaml(bump), post_release.yaml(edit)
# manually set this version will be eventually overwritten by the above actions
-__version__ = "2.1.0a5"
+__version__ = "2.1.0a7"