Hierarchical Clustering

Simple brute-force hierarchical clustering that stores all possible hierarchies in a trellis hypergraph.

Creating the Trellis Graph

from __future__ import annotations
from typing import Iterable

import numpy as np


class Node:
    """
    A cluster in the trellis graph, identified by the data points it covers.
    """

    def __init__(self,
                 mask: np.ndarray,
                 incoming_edges: Iterable[Edge] = (),
                 outgoing_edges: Iterable[Edge] = (),
                 value=None):
        # binary mask: entry i is truthy iff data point i is covered by this node
        self.mask = mask
        # the mask rendered as a string of "0"/"1" characters
        self.mask_str = "".join("1" if covered else "0" for covered in mask)
        # the given edges are copied into fresh sets owned by this node
        self.incoming_edges = set(incoming_edges)
        self.outgoing_edges = set(outgoing_edges)
        # payload attached to this node
        self.value = value

    def __repr__(self):
        shown = ", ".join(f"{name}: {getattr(self, name)}" for name in ("mask_str", "value"))
        return f"{self.__class__.__name__}({shown})"


class Edge:
    """
    A hyper-edge connecting nodes in the trellis graph.

    The edge leads from a set of incoming nodes to one single outgoing node.
    """

    def __init__(self, incoming_nodes: Iterable[Node], outgoing_node: Node, self_add=True):
        """
        Create the edge and optionally register it with the nodes it connects.

        ``incoming_nodes`` may be any iterable, including a one-shot iterator such as a generator. If
        ``self_add`` is true, the edge adds itself to ``outgoing_edges`` of every incoming node and to
        ``incoming_edges`` of ``outgoing_node``.
        """
        self.incoming_nodes = frozenset(incoming_nodes)  # set of incoming nodes
        self.outgoing_node = outgoing_node  # the single outgoing node
        if self_add:
            # iterate the stored frozenset rather than the argument: a one-shot iterator has already been
            # exhausted by the frozenset() call above and would silently register nothing
            for n in self.incoming_nodes:
                n.outgoing_edges.add(self)
            outgoing_node.incoming_edges.add(self)


class HyperGraph:
    """
    Trellis hypergraph that stores all possible binary hierarchies over a set of data points.
    """

    def __init__(self, edge_type=Edge, node_type=Node):
        # all nodes in the hyper graph as dict-of-dicts storing nodes per "level" (i.e. number of non-zero mask
        # entries) and using their string mask as ID: {1: {node.mask_str: node, ...}, 2: {...}, ...}
        self.all_nodes = {}
        # flat list of all hyper-edges
        self.all_edges = []
        # classes used to create edges and nodes
        self.Edge = edge_type
        self.Node = node_type

    def cluster(self, data):
        """
        Build the complete trellis for ``data``, replacing any previously built graph.

        Each data point becomes a leaf node on level 1. A node on level ``k`` covers ``k`` data points and gets
        exactly one hyper-edge for every way of splitting it into two disjoint lower-level nodes. The value of
        a combined node is ``left.value + right.value`` of the first pair that produces it.

        ``data`` has to support ``len()`` and iteration; the results are stored in ``all_nodes`` and
        ``all_edges``.
        """
        n_points = len(data)

        # reset all nodes and edges
        self.all_nodes = {1: {}}
        self.all_edges = []

        # create leaf nodes corresponding to data points
        for idx, d in enumerate(data):
            mask = np.zeros(n_points, dtype=bool)
            mask[idx] = True
            node = self.Node(mask=mask, value=d)
            self.all_nodes[1][node.mask_str] = node

        # create nodes level by level
        for level in range(2, n_points + 1):
            level_nodes = {}
            self.all_nodes[level] = level_nodes
            # A hyper-edge is an unordered pair of children, so the split (level - split, split) would only
            # repeat the edges of (split, level - split): visit each split size once.
            for split in range(1, level // 2 + 1):
                symmetric = split == level - split
                left_nodes = list(self.all_nodes[split].values())
                right_nodes = left_nodes if symmetric else list(self.all_nodes[level - split].values())
                for left_idx, left_node in enumerate(left_nodes):
                    # if both halves come from the same level, pair each node only with the nodes after it;
                    # this skips the node itself and the mirrored pair
                    candidates = right_nodes[left_idx + 1:] if symmetric else right_nodes
                    for right_node in candidates:
                        if np.any(np.logical_and(left_node.mask, right_node.mask)):
                            # don't combine nodes with overlapping masks
                            continue
                        # create new combined node
                        mask = np.logical_or(left_node.mask, right_node.mask)
                        new_node = self.Node(mask=mask, value=left_node.value + right_node.value)
                        # reuse the existing node if this cluster was already reached via another split
                        new_node = level_nodes.setdefault(new_node.mask_str, new_node)
                        # create hyper edge connecting nodes
                        self.all_edges.append(self.Edge(incoming_nodes=[left_node, right_node],
                                                        outgoing_node=new_node,
                                                        self_add=True))


# build the trellis for a small example and list its nodes level by level
g = HyperGraph()
g.cluster("ABCD")  # 4 points
# g.cluster("ABCDEFGHI")  # 9 points
for level in g.all_nodes:
    nodes = g.all_nodes[level]
    print(f"LEVEL {level}")
    for n in nodes.values():
        print(n)
        # for e in n.outgoing_edges:
        #     print(f"    --> {e.outgoing_node}")
LEVEL 1
Node(mask_str: 1000, value: A)
Node(mask_str: 0100, value: B)
Node(mask_str: 0010, value: C)
Node(mask_str: 0001, value: D)
LEVEL 2
Node(mask_str: 1100, value: AB)
Node(mask_str: 1010, value: AC)
Node(mask_str: 1001, value: AD)
Node(mask_str: 0110, value: BC)
Node(mask_str: 0101, value: BD)
Node(mask_str: 0011, value: CD)
LEVEL 3
Node(mask_str: 1110, value: ABC)
Node(mask_str: 1101, value: ABD)
Node(mask_str: 1011, value: ACD)
Node(mask_str: 0111, value: BCD)
LEVEL 4
Node(mask_str: 1111, value: ABCD)

Plotting the Trellis Graph

To plot the trellis graph, we first assign locations to the nodes on the first level (the original data points).

# lay the first-level nodes out on a grid in the x-y-plane, all at height z = 1
node_location_dict = {}  # node.mask_str -> xyz location
node_labels_dict = {}  # node.mask_str -> node.value
n_nodes = len(g.all_nodes[1])
square_edge_len = round(np.sqrt(n_nodes))  # number of nodes per grid row
for idx, node in enumerate(g.all_nodes[1].values()):
    row, col = divmod(idx, square_edge_len)
    # scale the integer grid position by n_nodes / square_edge_len
    xyz = np.array([col, row, 0.]) * (n_nodes / square_edge_len)
    xyz[2] = 1  # set z coordinate
    node_location_dict[node.mask_str] = xyz
    node_labels_dict[node.mask_str] = node.value

Higher-level nodes get the mean x/y-value of all their children (taken over all incoming hyper-edges) and their level, i.e. the number of data points they cover, as z-value.

# Place all nodes above level 1. This relies on g.all_nodes being iterated in order of increasing level, so
# that the locations of all children are already known when their parent is processed.
for level, nodes in g.all_nodes.items():
    if level == 1:
        continue  # first-level nodes must already have a location
    for node in nodes.values():
        xyz = np.zeros(3)
        n_count = 0
        # sum the locations of the children of every incoming hyper-edge
        for e in node.incoming_edges:
            for child in e.incoming_nodes:
                n_count += 1
                xyz += node_location_dict[child.mask_str]
        xyz /= n_count
        xyz[2] = level  # the z coordinate is the level, not the mean of the children
        node_location_dict[node.mask_str] = xyz
        node_labels_dict[node.mask_str] = node.value
# reshape keeps the array two-dimensional (N x 3) even if there are no nodes, so column indexing still works
node_locations = np.array(list(node_location_dict.values())).reshape(-1, 3)
node_labels = list(node_labels_dict.values())

Each hyper-edge is placed at the mean x/y-value of all nodes it connects (its incoming nodes and its single outgoing node); its z-value lies halfway between its highest incoming node and the outgoing node.

edge_locations = []  # one xyz location per hyper-edge
connections = []  # line segments (start, end, stop) between nodes and hyper-edges
stop = [np.nan, np.nan, np.nan]  # to interrupt line plot
for e in g.all_edges:
    out_xyz = node_location_dict[e.outgoing_node.mask_str]
    # xyz starts as the outgoing node's location, accumulates the incoming node locations and finally becomes
    # the edge location. It is stored in `connections` by reference and only modified in place, so every
    # entry appended below ends up holding the final edge location.
    xyz = out_xyz.copy()  # the use of xyz below relies on in-place operations!
    connections += [xyz.copy(), xyz, stop]  # draw connection, then interrupt plot
    n_count = 1  # the outgoing node is already included
    max_z = -np.inf
    for n in e.incoming_nodes:
        n_count += 1
        node_xyz = node_location_dict[n.mask_str]
        xyz += node_xyz
        connections += [node_xyz, xyz, stop]
        max_z = max(max_z, node_xyz[2])
    xyz /= n_count  # mean over all connected nodes
    xyz[2] = (max_z + out_xyz[2]) / 2  # halfway between highest incoming node and outgoing node
    edge_locations.append(xyz)
# reshape keeps the arrays two-dimensional (N x 3) even if the graph has no edges (e.g. a single data point),
# so column indexing still works
edge_locations = np.array(edge_locations).reshape(-1, 3)
connections = np.array(connections).reshape(-1, 3)

Nodes, hyper-edges, and connections are plotted as 3D scatter plots

import plotly.graph_objs as go

# one trace each for the nodes, the hyper-edges, and the lines connecting them
node_trace = go.Scatter3d(x=node_locations[:, 0], y=node_locations[:, 1], z=node_locations[:, 2],
                          text=node_labels, hovertemplate='%{text}<extra></extra>',
                          mode='markers', name="nodes", marker=dict(symbol='circle', size=5))
edge_trace = go.Scatter3d(x=edge_locations[:, 0], y=edge_locations[:, 1], z=edge_locations[:, 2],
                          mode='markers', name="edges", marker=dict(symbol='square', size=1))
# NOTE(review): plotly may interpret the tuple passed as line colour as per-point numeric values rather than
# an RGB triple -- confirm that this renders as intended
connection_trace = go.Scatter3d(x=connections[:, 0], y=connections[:, 1], z=connections[:, 2],
                                mode='lines', hoverinfo='skip', name="connections", opacity=0.1,
                                line=dict(width=2, color=(0, 0, 0)))
data = [node_trace, edge_trace, connection_trace]
# create figure with all three axes hidden
scene = dict(xaxis_title='X', yaxis_title='Y', zaxis_title='Z',
             xaxis_visible=False, yaxis_visible=False, zaxis_visible=False)
fig = go.Figure(data=data, layout=go.Layout(scene=scene))
# fig.show()
fig


Total running time of the script: (0 minutes 0.024 seconds)

Gallery generated by Sphinx-Gallery