Note
Go to the end to download the full example code.
Plot Tree from TMap
Plot a tree encoded in a binary TMap
import numpy as np
import matplotlib.pyplot as plt
from triangularmap import ArrayTMap, TMap
# True only when this file is executed directly; printing and plotting below are
# guarded by this flag so the helper functions can be imported without side effects
is_main = "__main__" == __name__
Function to generate a random tree:
def get_random_tree(tmap, _offset=0, _n=None):
    """
    Generate a random binary tree and encode it in a binary TMap.

    Every node of the tree -- the root, all internal nodes, and all single-element
    leaves -- is marked by writing 1 into the corresponding ``tmap[start, end]`` cell.
    Cells that are not tree nodes are left untouched. Split points are drawn with
    ``np.random.randint``, so seeding numpy's global RNG makes the result reproducible.

    :param tmap: TMap with all zeros/False; it is modified in place
    :param _offset: for internal recursion only: offset of the sub-tree
    :param _n: for internal recursion only: span length of the sub-tree
    :raises ValueError: if a non-zero ``_offset`` is passed without ``_n`` (initial call),
        or if the span length is smaller than 1
    """
    # get span length from tmap if not provided (i.e. this is the initial call)
    if _n is None:
        _n = tmap.n
        if _offset != 0:
            raise ValueError("Non-zero offset for initial call")
    if _n < 1:
        raise ValueError(f"Cannot generate a tree over a span of length {_n}")
    start = _offset
    end = _offset + _n
    # Each call marks its own node. Besides marking all children via recursion,
    # this also marks the root of a single-leaf tree (n == 1), which would
    # otherwise never be written.
    tmap[start, end] = 1
    if _n == 1:
        return
    # randint draws from [start + 1, end), so both children are non-empty
    split = np.random.randint(start + 1, end)
    get_random_tree(tmap, _offset=start, _n=split - start)  # left sub-tree
    get_random_tree(tmap, _offset=split, _n=end - split)  # right sub-tree
Function to extract branches from a tree:
def get_branches(tmap, _start=None, _end=None):
    """
    Return a list of branches from the tree encoded in the binary TMap.

    Each branch connects a parent node to one of its two children. For every
    internal node, the branch to its left child is listed first, then the branch
    to its right child, followed by all branches of the left sub-tree and then
    all branches of the right sub-tree.

    :param tmap: Binary TMap with encoded tree
    :param _start: for internal recursion only: start of the sub-tree
    :param _end: for internal recursion only: end of the sub-tree
    :return: list of [[(parent_start, parent_end), (child_start, child_end)], ...] branches;
        empty if the (sub-)tree is a single leaf
    :raises ValueError: if only one of ``_start`` and ``_end`` is provided
    :raises RuntimeError: if an internal node's left or right child is not marked in ``tmap``
    """
    # exactly one of the two being None is invalid
    if (_start is None) != (_end is None):
        raise ValueError("Either start and end have to be BOTH provided or none of them.")
    if _start is None:
        _start = 0
        _end = tmap.n
    # a single leaf has no children and hence no branches
    if _end - _start <= 1:
        return []
    branch_list = []
    # left child: the longest marked span starting at _start and ending before _end
    left_start = _start
    for left_end in range(_end - 1, _start, -1):
        if tmap[left_start, left_end]:
            break
    else:
        raise RuntimeError("Could not find left child")
    branch_list.append([(_start, _end), (left_start, left_end)])
    # right child: the longest marked span ending at _end and starting after _start
    right_end = _end
    for right_start in range(_start + 1, _end):
        if tmap[right_start, right_end]:
            break
    else:
        raise RuntimeError("Could not find right child")
    branch_list.append([(_start, _end), (right_start, right_end)])
    # recurse; leaf children contribute no branches (see the check above)
    branch_list += get_branches(tmap, _start=left_start, _end=left_end)
    branch_list += get_branches(tmap, _start=right_start, _end=right_end)
    return branch_list
Function to plot a tree:
def plot_tree(tmap, label_tmap=None, ax=None,
              node_kwargs=(), branch_kwargs=(), label_kwargs=(),
              x_y_from_start_end=None):
    """
    Plot the nodes and branches from a tree encoded in a binary TMap.

    :param tmap: Binary TMap with tree
    :param label_tmap: optional object indexable like ``tmap`` (e.g. a TMap); for every node,
        ``label_tmap[start, end]`` is drawn as an annotation at the node's position.
        None (default) draws no labels.
    :param ax: axis to plot to or None (default) to create a new figure with a single axis
    :param node_kwargs: key-word arguments passed to the scatter() plot call for nodes
        (a dict or an iterable of (key, value) pairs)
    :param branch_kwargs: key-word arguments passed to the plot() call for the branches
        (a dict or an iterable of (key, value) pairs)
    :param label_kwargs: key-word arguments passed to the annotate() call for each label
        (a dict or an iterable of (key, value) pairs)
    :param x_y_from_start_end: function to map (start, end) parts to (x, y) coordinates; default is
        ``x = (start + end) / 2`` and ``y = end - start``.
    :return: the axis that was plotted to
    """
    if ax is None:
        _, ax = plt.subplots(1, 1)
    # function to map (start, end) parts to (x, y) coordinates
    if x_y_from_start_end is None:
        def x_y_from_start_end(start, end):
            return (start + end) / 2, end - start
    node_kwargs = dict(node_kwargs)
    branch_kwargs = dict(branch_kwargs)
    label_kwargs = dict(label_kwargs)
    # plot nodes (and their labels, if requested)
    x, y = [], []
    for start in range(tmap.n):
        for end in range(start + 1, tmap.n + 1):
            if tmap[start, end]:
                x_, y_ = x_y_from_start_end(start, end)
                x.append(x_)
                y.append(y_)
                if label_tmap is not None:
                    ax.annotate(label_tmap[start, end], (x_, y_), **label_kwargs)
    ax.scatter(x=x, y=y, **node_kwargs)
    # plot all branches with a single plot() call; the (NaN, NaN) points
    # inserted between branches break the line into separate segments
    lines = []
    for (p_start, p_end), (c_start, c_end) in get_branches(tmap):
        if lines:
            lines.append((np.nan, np.nan))
        lines.append(x_y_from_start_end(p_start, p_end))
        lines.append(x_y_from_start_end(c_start, c_end))
    # a tree without branches (single node) has nothing to draw here
    if lines:
        ax.plot(*np.array(lines).T, **branch_kwargs)
    return ax
Generate a random tree and print it as a TMap:
# build a random binary tree over a span of length 10, encoded as 0/1 values
tmap = ArrayTMap(10, value=0)
get_random_tree(tmap)
if is_main:
    # show every tree node as "o" and every other cell as empty
    node_markers = ["o" if cell else "" for cell in tmap.arr]
    print(TMap(node_markers).pretty(crosses=True))
╳
╳o╳
╳ ╳ ╳
╳ ╳ ╳ ╳
╳ ╳ ╳ ╳ ╳
╳ ╳ ╳ ╳ ╳ ╳
╳o╳ ╳ ╳ ╳ ╳o╳
╳ ╳ ╳ ╳ ╳ ╳ ╳ ╳
╳ ╳ ╳o╳ ╳ ╳ ╳ ╳o╳
╳o╳ ╳ ╳o╳ ╳o╳ ╳ ╳o╳
╳o╳o╳o╳o╳o╳o╳o╳o╳o╳o╳
Extract the branches from this tree and print them as a TMap:
branches = get_branches(tmap)
# Use a separate TMap for the string markers so that ``tmap`` keeps holding the
# binary tree, which is what ``plot_tree`` expects when the tree is plotted below.
branch_tmap = ArrayTMap(tmap.n, value="", dtype=object)
for (parent_start, parent_end), (child_start, child_end) in branches:
    branch_tmap[parent_start, parent_end] += "v"  # one "v" per outgoing branch
    branch_tmap[child_start, child_end] += "^"  # one "^" for the incoming branch
if is_main:
    print(branch_tmap.pretty(crosses=True))
╳
╱ ╲
╳ vv╳
╱ ╲ ╱ ╲
╳ ╳ ╳
╱ ╲ ╱ ╲ ╱ ╲
╳ ╳ ╳ ╳
╱ ╲ ╱ ╲ ╱ ╲ ╱ ╲
╳ ╳ ╳ ╳ ╳
╱ ╲ ╱ ╲ ╱ ╲ ╱ ╲ ╱ ╲
╳ ╳ ╳ ╳ ╳ ╳
╱ ╲ ╱ ╲ ╱ ╲ ╱ ╲ ╱ ╲ ╱ ╲
╳^vv╳ ╳ ╳ ╳ ╳^vv╳
╱ ╲ ╱ ╲ ╱ ╲ ╱ ╲ ╱ ╲ ╱ ╲ ╱ ╲
╳ ╳ ╳ ╳ ╳ ╳ ╳ ╳
╱ ╲ ╱ ╲ ╱ ╲ ╱ ╲ ╱ ╲ ╱ ╲ ╱ ╲ ╱ ╲
╳ ╳ ╳^vv╳ ╳ ╳ ╳ ╳^vv╳
╱ ╲ ╱ ╲ ╱ ╲ ╱ ╲ ╱ ╲ ╱ ╲ ╱ ╲ ╱ ╲ ╱ ╲
╳^vv╳ ╳ ╳^vv╳ ╳^vv╳ ╳ ╳^vv╳
╱ ╲ ╱ ╲ ╱ ╲ ╱ ╲ ╱ ╲ ╱ ╲ ╱ ╲ ╱ ╲ ╱ ╲ ╱ ╲
╳ ^╳ ^╳ ^╳ ^╳ ^╳ ^╳ ^╳ ^╳ ^╳ ^╳
Plot the tree:
if is_main:
    # draw the tree into the current axis of a fresh figure, then display it
    plt.figure()
    tree_ax = plt.gca()
    plot_tree(tmap, ax=tree_ax)
    plt.show()

Total running time of the script: (0 minutes 0.046 seconds)