Note
Go to the end to download the full example code.
Hierarchical Clustering
Simple brute-force hierarchical clustering that stores all possible hierarchies in a trellis hypergraph.
Creating the Trellis Graph
from __future__ import annotations
from typing import Iterable
import numpy as np
class Node:
    """
    A node of the trellis graph. It stands for one cluster of a hierarchy, i.e. the subset of data points
    selected by its boolean mask.
    """

    def __init__(self,
                 mask: np.ndarray,
                 incoming_edges: Iterable[Edge] = (),
                 outgoing_edges: Iterable[Edge] = (),
                 value=None):
        # boolean mask marking which data points belong to this cluster
        self.mask = mask
        # the mask rendered as a string of '0'/'1' characters; used as the node's key in HyperGraph.all_nodes
        self.mask_str = "".join("1" if flag else "0" for flag in mask)
        self.incoming_edges = set(incoming_edges)  # hyper-edges that produce this node
        self.outgoing_edges = set(outgoing_edges)  # hyper-edges that consume this node
        self.value = value

    def __repr__(self):
        fields = ", ".join(f"{name}: {getattr(self, name)}" for name in ("mask_str", "value"))
        return f"{self.__class__.__name__}({fields})"
class Edge:
    """
    A hyper-edge of the trellis graph: it merges a set of incoming (child) nodes into a single outgoing
    (parent) node.

    incoming_nodes: the nodes merged by this edge; any iterable is accepted (it is materialised exactly once)
    outgoing_node: the node produced by the merge
    self_add: if True, register the edge with its nodes, i.e. add it to the ``outgoing_edges`` of every
        incoming node and to the ``incoming_edges`` of the outgoing node
    """

    def __init__(self, incoming_nodes: Iterable[Node], outgoing_node: Node, self_add=True):
        # materialise once; iterating the original argument again would find one-shot iterators
        # (e.g. generators) already exhausted and silently skip registration
        self.incoming_nodes = frozenset(incoming_nodes)
        self.outgoing_node = outgoing_node  # the single outgoing node
        if self_add:
            for n in self.incoming_nodes:
                n.outgoing_edges.add(self)
            outgoing_node.incoming_edges.add(self)
class HyperGraph:
    """
    Brute-force trellis hypergraph. Every cluster (non-empty subset of the data points) becomes a node, and
    every way of merging two disjoint clusters into a larger one becomes a hyper-edge. Each unordered pair
    of children is stored once.
    """

    def __init__(self, edge_type=Edge, node_type=Node):
        # all nodes in the hyper graph as dict-of-dicts storing nodes per "level" (i.e. number of non-zero mask
        # entries) and using their string mask as ID: {1: {node.mask_str: node, ...}, 2: {...}, ...}
        self.all_nodes = {}
        self.all_edges = []
        # factories used by cluster(), so subclasses of Edge/Node can be plugged in
        self.Edge = edge_type
        self.Node = node_type

    def cluster(self, data):
        """
        Build the complete trellis for ``data``, replacing any previous result.

        data: a sized sequence of data points. Each point becomes a leaf node with the point as its value;
            a merged node gets ``left.value + right.value`` as value, so values must support ``+``.
        """
        # reset all nodes and edges
        self.all_nodes = {1: {}}
        self.all_edges = []
        n_points = len(data)
        # create leaf nodes corresponding to data points
        for idx, d in enumerate(data):
            mask = np.zeros(n_points, dtype=bool)
            mask[idx] = True
            node = self.Node(mask=mask, value=d)
            self.all_nodes[1][node.mask_str] = node
        # create nodes level by level
        for level in range(2, n_points + 1):
            level_nodes = self.all_nodes[level] = {}
            # A merge is an unordered pair of children, so splits with a left part larger than the right part
            # would only recreate mirrored duplicates of edges built earlier; stop at level // 2.
            for split in range(1, level // 2 + 1):
                other = level - split
                for left_node in self.all_nodes[split].values():
                    for right_node in self.all_nodes[other].values():
                        if split == other and left_node.mask_str <= right_node.mask_str:
                            # both children come from the same level: keep one ordering of each pair
                            # (this also excludes combining a node with itself)
                            continue
                        if np.any(np.logical_and(left_node.mask, right_node.mask)):
                            # don't combine nodes with overlapping masks
                            continue
                        # create new combined node, or reuse the existing one with the same mask
                        mask = np.logical_or(left_node.mask, right_node.mask)
                        new_node = self.Node(mask=mask, value=left_node.value + right_node.value)
                        new_node = level_nodes.setdefault(new_node.mask_str, new_node)
                        # create hyper edge connecting nodes
                        self.all_edges.append(self.Edge(incoming_nodes=[left_node, right_node],
                                                        outgoing_node=new_node,
                                                        self_add=True))
g = HyperGraph()
# cluster 4 points; use e.g. "ABCDEFGHI" for 9 points (every non-empty subset of the points becomes a node)
g.cluster("ABCD")
# list the nodes of every level of the trellis
for lvl, lvl_nodes in g.all_nodes.items():
    print(f"LEVEL {lvl}")
    for node in lvl_nodes.values():
        print(node)
        # to also show the parents, print e.outgoing_node for every e in node.outgoing_edges
LEVEL 1
Node(mask_str: 1000, value: A)
Node(mask_str: 0100, value: B)
Node(mask_str: 0010, value: C)
Node(mask_str: 0001, value: D)
LEVEL 2
Node(mask_str: 1100, value: AB)
Node(mask_str: 1010, value: AC)
Node(mask_str: 1001, value: AD)
Node(mask_str: 0110, value: BC)
Node(mask_str: 0101, value: BD)
Node(mask_str: 0011, value: CD)
LEVEL 3
Node(mask_str: 1110, value: ABC)
Node(mask_str: 1101, value: ABD)
Node(mask_str: 1011, value: ACD)
Node(mask_str: 0111, value: BCD)
LEVEL 4
Node(mask_str: 1111, value: ABCD)
Plotting the Trellis Graph
To plot the trellis graph, we first assign locations to the nodes on the first level (the original points):
# place the leaf nodes (the original points) on a roughly square grid in the x-y-plane, at height z = 1
node_location_dict = {}
node_labels_dict = {}
n_nodes = len(g.all_nodes[1])
square_edge_len = round(np.sqrt(n_nodes))  # number of grid columns
for position, leaf in enumerate(g.all_nodes[1].values()):
    row, col = divmod(position, square_edge_len)
    # scale the grid up so that the points are spread out
    xyz = np.array([col, row, 0.]) * (n_nodes / square_edge_len)
    xyz[2] = 1  # leaves sit on level 1
    node_location_dict[leaf.mask_str] = xyz
    node_labels_dict[leaf.mask_str] = leaf.value
Higher-level nodes get the mean x/y-value of all their children and their level (the number of data points they cover) as z-value.
# Place every higher-level node at the mean x/y position of its children, taken over all of its incoming
# edges (a child is counted once per edge it appears in), and at height z = level.
for lvl, lvl_nodes in g.all_nodes.items():
    if lvl == 1:
        continue  # leaf positions are already assigned
    for parent in lvl_nodes.values():
        child_positions = [node_location_dict[child.mask_str]
                           for edge in parent.incoming_edges
                           for child in edge.incoming_nodes]
        xyz = sum(child_positions, np.zeros(3)) / len(child_positions)
        xyz[2] = lvl
        node_location_dict[parent.mask_str] = xyz
        node_labels_dict[parent.mask_str] = parent.value
# collect all positions and labels (level by level) for plotting
node_locations = np.array(list(node_location_dict.values()))
node_labels = list(node_labels_dict.values())
Hyper-edges get the mean x/y-value of all connected nodes and a z-value halfway between their highest incoming node and their single outgoing node.
# Each hyper-edge is placed at the mean x/y position of all its nodes (the outgoing node and the incoming
# nodes), with z halfway between its highest incoming node and its outgoing node. Every node is joined to
# its edge by a line segment.
edge_locations = []
connections = []
stop = [np.nan, np.nan, np.nan]  # NaN point that interrupts the line plot between segments
for e in g.all_edges:
    out_xyz = node_location_dict[e.outgoing_node.mask_str]
    in_xyzs = [node_location_dict[n.mask_str] for n in e.incoming_nodes]
    # compute the edge location up front (np.mean returns a fresh array, so no stored location is mutated)
    edge_xyz = np.mean([out_xyz, *in_xyzs], axis=0)
    max_z = max((p[2] for p in in_xyzs), default=-np.inf)
    edge_xyz[2] = (max_z + out_xyz[2]) / 2
    edge_locations.append(edge_xyz)
    # one segment from the outgoing node, then one from each incoming node, to the edge
    for node_xyz in (out_xyz, *in_xyzs):
        connections += [node_xyz, edge_xyz, stop]
# reshape so that results without any edge (e.g. a single data point) still have shape (0, 3)
edge_locations = np.array(edge_locations, dtype=float).reshape(-1, 3)
connections = np.array(connections, dtype=float).reshape(-1, 3)
Nodes, hyper-edges, and connections are plotted as 3D scatter plots.
import plotly.graph_objects as go  # ``plotly.graph_objs`` is the legacy alias of this module

data = [
    # nodes (hovering shows the cluster's value)
    go.Scatter3d(x=node_locations[:, 0], y=node_locations[:, 1], z=node_locations[:, 2],
                 text=node_labels, hovertemplate='%{text}<extra></extra>',
                 mode='markers', name="nodes", marker=dict(symbol='circle', size=5)),
    # hyper-edges
    go.Scatter3d(x=edge_locations[:, 0], y=edge_locations[:, 1], z=edge_locations[:, 2],
                 mode='markers', name="edges", marker=dict(symbol='square', size=1)),
    # connections between nodes and hyper-edges (the NaN rows in ``connections`` split the line)
    go.Scatter3d(x=connections[:, 0], y=connections[:, 1], z=connections[:, 2],
                 mode='lines', hoverinfo='skip', name="connections", opacity=0.1,
                 # a colour string is required: plotly reads a tuple/list here as per-point colour values
                 line=dict(width=2, color='rgb(0, 0, 0)')),
]
# create figure
fig = go.Figure(data=data,
                layout=go.Layout(scene=dict(xaxis_title='X', yaxis_title='Y', zaxis_title='Z',
                                            xaxis_visible=False, yaxis_visible=False, zaxis_visible=False)))
# fig.show()  # uncomment to open the figure when running as a plain script
fig
Total running time of the script: (0 minutes 0.024 seconds)