Plot Tree from TMap

Plot a tree encoded in a binary TMap

import numpy as np
import matplotlib.pyplot as plt
from triangularmap import ArrayTMap, TMap

# Print and plot only when this file is executed as a script; importing it
# (e.g. to reuse the functions below) then produces no output.
is_main = (__name__ == "__main__")

Function to generate a random tree:

def get_random_tree(tmap, _offset=0, _n=None):
    """
    Generate a random binary tree encoded in a binary TMap.

    Every node of the tree is marked by setting the TMap entry of its span
    ``[start, end]`` to 1. The root spans the whole TMap. Each node spanning
    more than one position is split at a random position into a left and a
    right child. Spans of length 1 are the leaves.

    :param tmap: TMap with all zeros/False; it is modified in place
    :param _offset: for internal recursion only: offset of the sub-tree
    :param _n: for internal recursion only: span length of the sub-tree
    :raises ValueError: if a non-zero offset is given for the initial call or
        the span length is smaller than 1
    """
    # get span length from tmap if not provided
    if _n is None:
        _n = tmap.n
        if _offset != 0:
            raise ValueError("Non-zero offset for initial call")
    if _n < 1:
        raise ValueError(f"Span length must be at least 1, got {_n}")
    start = _offset
    end = _offset + _n
    # Every node marks itself. This is required for a single-leaf tree, whose
    # root has no parent that could mark it.
    tmap[start, end] = 1
    if _n == 1:
        return
    # split strictly inside the span so that both children are non-empty
    split = np.random.randint(start + 1, end)
    get_random_tree(tmap, _offset=start, _n=split - start)  # left sub-tree
    get_random_tree(tmap, _offset=split, _n=end - split)  # right sub-tree

Function to extract branches from a tree:

def get_branches(tmap, _start=None, _end=None):
    """
    Return a list of branches from the tree encoded in the binary TMap.

    The tree is traversed in pre-order. For each internal node, the branch to
    its left child and the branch to its right child are listed first. These
    are followed by the branches of the left sub-tree and then those of the
    right sub-tree. A tree consisting of a single leaf has no branches.

    :param tmap: Binary TMap with encoded tree
    :param _start: for internal recursion only: start of the sub-tree
    :param _end: for internal recursion only: end of the sub-tree
    :return: list of [[(parent_start, parent_end), (child_start, child_end)], ...] branches
    :raises ValueError: if only one of ``_start`` and ``_end`` is provided
    :raises RuntimeError: if a node spanning more than one position has no
        marked left or right child
    """
    if (_start is None) != (_end is None):
        raise ValueError("Either start and end have to be BOTH provided or none of them.")
    if _start is None:
        _start = 0
        _end = tmap.n
    branch_list = []
    # a leaf (span of length 1) has no children and hence no branches
    if _end - _start <= 1:
        return branch_list
    # Left child: the longest marked span that starts at _start and ends
    # before _end. Shorter marked spans starting at _start belong to deeper
    # left descendants.
    for left_end in range(_end - 1, _start, -1):
        if tmap[_start, left_end]:
            break
    else:
        raise RuntimeError("Could not find left child")
    # Right child: the longest marked span that ends at _end and starts
    # after _start.
    for right_start in range(_start + 1, _end):
        if tmap[right_start, _end]:
            break
    else:
        raise RuntimeError("Could not find right child")
    branch_list.append([(_start, _end), (_start, left_end)])
    branch_list.append([(_start, _end), (right_start, _end)])
    # recurse into both sub-trees (leaf spans contribute no branches)
    branch_list += get_branches(tmap, _start=_start, _end=left_end)
    branch_list += get_branches(tmap, _start=right_start, _end=_end)
    return branch_list

Function to plot a tree:

def plot_tree(tmap, label_tmap=None, ax=None,
              node_kwargs=(), branch_kwargs=(), label_kwargs=(),
              x_y_from_start_end=None):
    """
    Plot the nodes and branches from a tree encoded in a binary TMap.

    :param tmap: Binary TMap with tree
    :param label_tmap: optional TMap with node labels. For every node ``(start, end)`` of the tree,
     ``label_tmap[start, end]`` is drawn as an annotation at the node's position.
    :param ax: axis to plot to or None (default) to create a new figure with a single axis
    :param node_kwargs: key-word arguments passed to the scatter() plot call for nodes
    :param branch_kwargs: key-word arguments passed to the plot() call for the branches
    :param label_kwargs: key-word arguments passed to the annotate() call for the labels
    :param x_y_from_start_end: function to map (start, end) parts to (x, y) coordinates; default is
     ``x = (start + end) / 2`` and ``y = end - start``.
    :return: the axis that was plotted to
    """
    if ax is None:
        _, ax = plt.subplots(1, 1)
    # function to map (start, end) parts to (x, y) coordinates
    if x_y_from_start_end is None:
        def x_y_from_start_end(start, end):
            return (start + end) / 2, end - start
    # plot nodes: every marked (start, end) span of the TMap is a node
    x, y = [], []
    for start in range(tmap.n):
        for end in range(start + 1, tmap.n + 1):
            if tmap[start, end]:
                x_, y_ = x_y_from_start_end(start, end)
                x.append(x_)
                y.append(y_)
                if label_tmap is not None:
                    ax.annotate(label_tmap[start, end], (x_, y_), **dict(label_kwargs))
    ax.scatter(x=x, y=y, **dict(node_kwargs))
    # Plot all branches with a single plot() call. Consecutive branches are
    # separated by a (NaN, NaN) point so that matplotlib does not connect them.
    lines = []
    for (p_start, p_end), (c_start, c_end) in get_branches(tmap):
        if lines:
            lines.append((np.nan, np.nan))
        lines.append(x_y_from_start_end(p_start, p_end))
        lines.append(x_y_from_start_end(c_start, c_end))
    if lines:  # nothing to draw if there are no branches
        ax.plot(*np.array(lines).T, **dict(branch_kwargs))
    return ax

Generate a random tree and print it as a TMap:

tmap = ArrayTMap(10, value=0)
get_random_tree(tmap)
if is_main:
    # mark each node of the tree with "o" and leave all other cells empty
    node_markers = TMap(["o" if is_node else "" for is_node in tmap.arr])
    print(node_markers.pretty(crosses=True))
          ╳
         ╳o╳
        ╳ ╳ ╳
       ╳ ╳ ╳ ╳
      ╳ ╳ ╳ ╳ ╳
     ╳ ╳ ╳ ╳ ╳ ╳
    ╳o╳ ╳ ╳ ╳ ╳o╳
   ╳ ╳ ╳ ╳ ╳ ╳ ╳ ╳
  ╳ ╳ ╳o╳ ╳ ╳ ╳ ╳o╳
 ╳o╳ ╳ ╳o╳ ╳o╳ ╳ ╳o╳
╳o╳o╳o╳o╳o╳o╳o╳o╳o╳o╳

Extract the branches from this tree and print them as a TMap. Each branch appends a ``v`` to its parent node and a ``^`` to its child node:

branches = get_branches(tmap)
# Collect the branch markers in a separate TMap instead of re-binding ``tmap``.
# This way ``tmap`` keeps holding the binary tree for later use, rather than
# silently becoming a TMap of marker strings.
branch_tmap = ArrayTMap(tmap.n, value="", dtype=object)
for (parent_start, parent_end), (child_start, child_end) in branches:
    branch_tmap[parent_start, parent_end] += "v"  # one "v" per branch leaving the parent
    branch_tmap[child_start, child_end] += "^"  # one "^" for the branch reaching the child
if is_main:
    print(branch_tmap.pretty(crosses=True))
                    ╳
                   ╱ ╲
                  ╳ vv╳
                 ╱ ╲ ╱ ╲
                ╳   ╳   ╳
               ╱ ╲ ╱ ╲ ╱ ╲
              ╳   ╳   ╳   ╳
             ╱ ╲ ╱ ╲ ╱ ╲ ╱ ╲
            ╳   ╳   ╳   ╳   ╳
           ╱ ╲ ╱ ╲ ╱ ╲ ╱ ╲ ╱ ╲
          ╳   ╳   ╳   ╳   ╳   ╳
         ╱ ╲ ╱ ╲ ╱ ╲ ╱ ╲ ╱ ╲ ╱ ╲
        ╳^vv╳   ╳   ╳   ╳   ╳^vv╳
       ╱ ╲ ╱ ╲ ╱ ╲ ╱ ╲ ╱ ╲ ╱ ╲ ╱ ╲
      ╳   ╳   ╳   ╳   ╳   ╳   ╳   ╳
     ╱ ╲ ╱ ╲ ╱ ╲ ╱ ╲ ╱ ╲ ╱ ╲ ╱ ╲ ╱ ╲
    ╳   ╳   ╳^vv╳   ╳   ╳   ╳   ╳^vv╳
   ╱ ╲ ╱ ╲ ╱ ╲ ╱ ╲ ╱ ╲ ╱ ╲ ╱ ╲ ╱ ╲ ╱ ╲
  ╳^vv╳   ╳   ╳^vv╳   ╳^vv╳   ╳   ╳^vv╳
 ╱ ╲ ╱ ╲ ╱ ╲ ╱ ╲ ╱ ╲ ╱ ╲ ╱ ╲ ╱ ╲ ╱ ╲ ╱ ╲
╳  ^╳  ^╳  ^╳  ^╳  ^╳  ^╳  ^╳  ^╳  ^╳  ^╳

Plot the tree:

if is_main:
    # draw the tree into the (single) axis of a fresh figure
    fig = plt.figure()
    plot_tree(tmap, ax=fig.gca())
    plt.show()
plot tree from tmap

Total running time of the script: (0 minutes 0.046 seconds)

Gallery generated by Sphinx-Gallery