Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
112 changes: 68 additions & 44 deletions graphconstructor/visualization/network_plots_nx.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,17 +7,32 @@
from graphconstructor import Graph


def _discrete_palette(cmap_name: str, n_fallback: int = 20):
"""
Return a list of base colors from a colormap. If the colormap is a ListedColormap
(e.g., tab20), we use its native discrete colors. Otherwise we sample `n_fallback`
evenly spaced colors from the continuous colormap.
"""
cmap_obj = colormaps.get_cmap(cmap_name)
# Try to use listed colors when available (e.g., tab10, tab20, Set3, etc.)
base = getattr(cmap_obj, "colors", None)
if base is not None:
return list(base)
# Otherwise, sample a reasonable discrete palette from a continuous map
return [cmap_obj(i / max(n_fallback - 1, 1)) for i in range(n_fallback)]


def plot_graph_by_feature(
G: Graph,
color_attribute: str = None,
attribute_type: str = "categorical",
pos=None,
cmap_name: str = "tab20",
cmap: str = "tab20",
default_color="teal",
with_labels: bool = True):
"""
Color nodes by the selected attribute stored on each node (e.g., node['cf_class']).

Parameters
----------
G : Graph
Expand All @@ -28,54 +43,58 @@ def plot_graph_by_feature(
'categorical' or 'continuous'. This will determine the legend used.
pos : dict or None
Optional positions dict; if None, uses nx.spring_layout.
cmap_name : str
cmap : str
Matplotlib colormap (e.g., 'tab20' for categorical, 'viridis' for continuous).
default_color : str
Color for nodes missing the class attribute.
with_labels : bool
Draw node labels.

Notes:
- For categorical attributes, if the number of unique categories exceeds the number
of discrete colors in the chosen colormap (e.g., >20 for 'tab20'), the palette
is *cycled* so every category still gets a color.
"""
if attribute_type not in {"categorical", "continuous"}:
raise ValueError("attribute_type must be 'categorical' or 'continuous'")

nxG = G.to_networkx()
node_list = list(nxG.nodes())

# Initialize variables for continuous colorbar
norm = None
cmap_continuous = None

if color_attribute:
node_features = [nxG.nodes[n].get(color_attribute, None) for n in node_list]

if attribute_type == "categorical":
# Stable set of unique classes (preserve first-seen order, skip None)
# Preserve first-seen order, skip None
unique_classes = [c for c in dict.fromkeys(node_features) if c is not None]
unique_classes.sort()

# Map classes -> colors

if unique_classes:
cmap = colormaps.get_cmap(cmap_name, len(unique_classes))
class_to_color = {c: cmap(i) for i, c in enumerate(unique_classes)}
base_colors = _discrete_palette(cmap) # e.g., 20 colors for tab20
k = len(base_colors)
# Cycle through palette when categories > base palette size
class_to_color = {c: base_colors[i % k] for i, c in enumerate(unique_classes)}
else:
class_to_color = {}

node_colors = [class_to_color.get(c, default_color) for c in node_features]

elif attribute_type == "continuous":
# Filter out None values to find min/max
valid_values = [v for v in node_features if v is not None]

if valid_values:
# Convert to numeric (in case they aren't already)
try:
valid_values = [float(v) for v in valid_values]
vmin, vmax = min(valid_values), max(valid_values)

# Create normalization and colormap for continuous scale
norm = Normalize(vmin=vmin, vmax=vmax)
cmap_continuous = colormaps.get_cmap(cmap_name)
cmap_continuous = colormaps.get_cmap(cmap)

# Map node values to colors
node_colors = []
for val in node_features:
Expand All @@ -86,10 +105,9 @@ def plot_graph_by_feature(
node_colors.append(default_color)
else:
node_colors.append(default_color)
unique_classes = True # Flag to indicate we have valid data

unique_classes = True # flag: we have valid data
except (ValueError, TypeError):
# Fall back to default color if conversion fails
node_colors = [default_color] * len(node_list)
unique_classes = False
else:
Expand All @@ -98,33 +116,33 @@ def plot_graph_by_feature(
else:
node_colors = default_color
unique_classes = False

# Handle edge weights
if G.weighted:
edge_weights = [d.get("weight", 1.0) for _, _, d in nxG.edges(data=True)]
if edge_weights:
max_weight = max(edge_weights)
edge_widths = [0.5 + 5.0 * (w / max_weight) for w in edge_weights]
edge_colors = [cm.gray(w/max_weight) for w in edge_weights]
edge_colors = [cm.Greys(w / max_weight) for w in edge_weights]
else:
edge_widths = 1.0
edge_colors = "gray"
else:
edge_widths = 1.0
edge_colors = "gray"

# Node sizes based on degree
degrees = dict(nxG.degree())
node_sizes = [200.0 * (1.0 + np.sqrt(degrees.get(n, 0))) for n in nxG.nodes()]

# Layout
if pos is None:
pos = nx.spring_layout(nxG, seed=42)

# Figure size
size = (len(node_list) ** 0.5)
fig, ax = plt.subplots(figsize=(size, size))

# Draw the graph
nx.draw(
nxG,
Expand All @@ -139,24 +157,30 @@ def plot_graph_by_feature(
linewidths=0.5,
font_size=8,
)

# Legend or Colorbar
if unique_classes:
if attribute_type == "categorical":
# Categorical legend with patches
handles = [Patch(facecolor=class_to_color[c], edgecolor="none", label=str(c))
for c in unique_classes]
if attribute_type == "categorical" and color_attribute:
# Recompute unique_classes in this scope if needed
node_features = [nxG.nodes[n].get(color_attribute, None) for n in node_list]
unique_classes = [c for c in dict.fromkeys(node_features) if c is not None]
if unique_classes:
base_colors = _discrete_palette(cmap)
k = len(base_colors)
class_to_color = {c: base_colors[i % k] for i, c in enumerate(unique_classes)}
handles = [
Patch(facecolor=class_to_color[c], edgecolor="none", label=str(c))
for c in unique_classes
]
ax.legend(handles=handles, title=color_attribute, loc="best", frameon=True)

elif attribute_type == "continuous" and norm is not None and cmap_continuous is not None:
# Continuous colorbar
sm = cm.ScalarMappable(cmap=cmap_continuous, norm=norm)
sm.set_array([])
cbar = plt.colorbar(sm, ax=ax, fraction=0.046, pad=0.04)
cbar.set_label(color_attribute, rotation=270, labelpad=15)


elif attribute_type == "continuous" and norm is not None and cmap_continuous is not None:
sm = cm.ScalarMappable(cmap=cmap_continuous, norm=norm)
sm.set_array([])
cbar = plt.colorbar(sm, ax=ax, fraction=0.046, pad=0.04)
cbar.set_label(color_attribute, rotation=270, labelpad=15)

ax.set_axis_off()
fig.tight_layout()
plt.show()

return fig, ax
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[tool.poetry]
name = "graphconstructor"
version = "0.2.1"
version = "0.2.2"
description = "Python library for graph construction from similarity or distance metrics"
authors = ["Florian Huber <florian.huber@hs-duesseldorf.de>"]

Expand Down
Loading