Note
Go to the end to download the full example code.
Optimisation Landscape
…
…
import numpy as np
import plotly.graph_objects as go
import matplotlib.pyplot as plt
import torch
import plotly.colors as pc
import colorsys
from examples.util import show
def subdue_color(rgb_color_str, sat_scale=0.7, light_scale=1):
    """Return a re-saturated / re-lit version of an ``'rgb(r, g, b)'`` colour.

    The colour is converted to HLS, its saturation and lightness are
    multiplied by the given factors (each capped at 1.0), and the result is
    converted back to RGB.

    Parameters
    ----------
    rgb_color_str : str
        Colour string such as ``'rgb(0, 60, 170)'``. Case and surrounding
        whitespace are ignored; components are integers in 0-255.
    sat_scale : float
        Multiplier applied to the HLS saturation (default 0.7).
    light_scale : float
        Multiplier applied to the HLS lightness (default 1, i.e. unchanged).

    Returns
    -------
    str
        The adjusted colour formatted as ``'rgb(r, g, b)'``.
    """
    body = rgb_color_str.strip().lower().replace('rgb(', '').replace(')', '')
    r, g, b = (int(part.strip()) / 255.0 for part in body.split(','))
    hue, lightness, saturation = colorsys.rgb_to_hls(r, g, b)
    saturation = min(saturation * sat_scale, 1.0)
    lightness = min(lightness * light_scale, 1.0)
    channels = colorsys.hls_to_rgb(hue, lightness, saturation)
    # Round (rather than truncate) back to 0-255 so a channel that comes back
    # from the HLS round trip as e.g. 59.9999 maps to 60, not 59. Clamp in
    # case unusual scale factors push a channel outside [0, 1].
    r_new, g_new, b_new = (min(255, max(0, round(c * 255))) for c in channels)
    return f'rgb({r_new}, {g_new}, {b_new})'
# Plotly's Jet colour list, re-expressed as an explicit (position, colour)
# colourscale whose stops are evenly spaced on [0, 1] and desaturated.
jet = pc.sequential.Jet
subdued_jet = [
    (position / (len(jet) - 1), subdue_color(colour))
    for position, colour in enumerate(jet)
]
# Set to True for a quick matplotlib line plot of the landscape along y = 0
# instead of the 3-D Plotly figure.
plot_2D = False
def func(xy):
    """Evaluate the optimisation landscape at ``xy``.

    The landscape is a linear blend, weighted by ``y / (2*pi)``, of two 1-D
    profiles in ``x``, plus an additive term that depends only on ``y``.

    Parameters
    ----------
    xy : indexable pair
        ``xy[0]`` is ``x`` and ``xy[1]`` is ``y``. Both must be torch tensors
        (``torch.sin`` is applied to them); any matching shapes broadcast
        elementwise, e.g. a tuple of meshgrid tensors or the transpose of
        an ``(n, 2)`` path tensor.

    Returns
    -------
    torch.Tensor
        Landscape height with the broadcast shape of ``x`` and ``y``.
    """
    x, y = xy[0], xy[1]
    # Blend weight: 0 at y = 0, 1 at y = 2*pi
    weight = y / (2 * torch.pi)
    # Profile in x that dominates as y -> 2*pi
    profile_high_y = (
        - 3 * torch.sin(0.5 * x + torch.pi / 2)
        - torch.sin(x)
        - 2 * torch.sin(1.5 * x)
    )
    # Profile in x that dominates as y -> 0
    profile_low_y = - 1 - 2 * torch.sin(0.5 * x + torch.pi / 2)
    # y-only term. An alternative variant used
    # -0.2*sin(1.5*y) - 0.05*sin(2.5*y) instead of -0.1*sin(1.5*y).
    y_term = 2 * (- torch.sin(0.5 * y) - 0.1 * torch.sin(1.5 * y))
    return weight * profile_high_y + (1 - weight) * profile_low_y + y_term
def optimise(func, init_xy, lr=0.1, max_iter=1000, tol=1e-3, min_step=1e-2):
    """Run plain SGD on ``func`` from ``init_xy`` and return a thinned path.

    Parameters
    ----------
    func : callable
        Maps the 1-D parameter tensor to a scalar tensor loss.
    init_xy : sequence of float or torch.Tensor
        Starting point.
    lr : float
        SGD learning rate.
    max_iter : int
        Maximum number of optimiser steps.
    tol : float
        Stop as soon as one step moves the point by less than this
        (Euclidean norm).
    min_step : float
        Intermediate iterates closer than this to the previously kept point
        are dropped, to declutter plots.

    Returns
    -------
    torch.Tensor
        Detached tensor of shape ``(n_points, len(init_xy))``. The first row
        is always the starting point. If any step was taken, the last row is
        always the optimiser's final iterate, so the path has at least two
        points.
    """
    # as_tensor + detach + clone accepts lists and tensors alike (torch.tensor
    # on an existing tensor warns) and never aliases the caller's data.
    xy = torch.as_tensor(init_xy, dtype=torch.float32).detach().clone()
    xy.requires_grad_(True)
    optimizer = torch.optim.SGD([xy], lr=lr)
    path = [xy.detach().clone()]
    for _ in range(max_iter):
        optimizer.zero_grad()
        loss = func(xy)
        loss.backward()
        prev_xy = xy.detach().clone()
        optimizer.step()
        path.append(xy.detach().clone())
        # Convergence check on the length of the step just taken (on detached
        # copies, so no autograd graph is built).
        if torch.norm(path[-1] - prev_xy).item() < tol:
            break
    path = torch.stack(path)

    # Thin the path: keep a point only once it is at least `min_step` away
    # from the previously kept one.
    keep = [0]
    for i in range(1, len(path)):
        if torch.norm(path[i] - path[keep[-1]]) >= min_step:
            keep.append(i)
    # Always end on the final iterate so the plotted path (and any arrow
    # drawn from its last two points) shows where the optimiser stopped.
    # If it was not kept, it lies within `min_step` of the last kept point,
    # so it replaces that point, unless the start is the only kept point.
    last = len(path) - 1
    if keep[-1] != last:
        if len(keep) > 1:
            keep[-1] = last
        else:
            keep.append(last)
    return path[keep]
# get arrow at end of path
def path_arrow(path_xy, path_z, color='black', **kwargs):
    """Build a single-colour Plotly cone pointing along the last path segment.

    The cone is anchored with its tip at the final point and points in the
    direction from the second-to-last point to the last one.

    Parameters
    ----------
    path_xy : array-like
        2-D array indexed as ``path_xy[i, 0]`` (x) and ``path_xy[i, 1]`` (y).
        It needs at least two rows.
    path_z : array-like
        Heights matching ``path_xy`` row for row (at least two entries).
    color : str
        Colour applied to the cone through a constant colourscale.
    **kwargs
        Extra ``go.Cone`` properties. These override the defaults below,
        including ``colorscale``.

    Returns
    -------
    plotly.graph_objects.Cone
    """
    start = np.array([path_xy[-2, 0], path_xy[-2, 1], path_z[-2]])
    end = np.array([path_xy[-1, 0], path_xy[-1, 1], path_z[-1]])
    direction = end - start
    # Merge defaults with caller overrides. dict(**defaults, **kwargs) would
    # raise TypeError on any duplicate key instead of overriding it.
    cone_kwargs = {
        "sizemode": "absolute",
        "sizeref": 0.5,
        "anchor": "tip",
        "showscale": False,
        "colorscale": [[0, color], [1, color]],
        **kwargs,
    }
    return go.Cone(
        x=[end[0]], y=[end[1]], z=[end[2]],
        u=[direction[0]], v=[direction[1]], w=[direction[2]],
        **cone_kwargs,
    )
# x/y/z limits: x and y span [0, 2*pi]; z_lim is the z-axis display range
x_lim = 0, 2 * torch.pi
y_lim = 0, 2 * torch.pi
z_lim = -10, 3
# Fine grid on which the landscape surface is evaluated
x_fine = torch.linspace(*x_lim, 100)
y_fine = torch.linspace(*y_lim, 100)
if plot_2D:
    # 1-D slice of the landscape along y = 0
    x_fine_grid = x_fine
    y_fine_grid = torch.zeros_like(y_fine)
else:
    # indexing="ij" is torch's default layout (grid[i, j] <-> (x_i, y_j)).
    # Passing it explicitly silences torch's warning about the argument
    # becoming required.
    x_fine_grid, y_fine_grid = torch.meshgrid(x_fine, y_fine, indexing="ij")
z_fine_grid = func((x_fine_grid, y_fine_grid)).detach().numpy()
if plot_2D:
    plt.plot(x_fine_grid, z_fine_grid)
    plt.show()
    # `exit` is a site-module convenience that is not guaranteed to exist
    # (e.g. under `python -S`). Raising SystemExit directly always works.
    raise SystemExit
# Run gradient descent from a fixed starting point on the landscape
init_point = [5.2, 5.5]
path = optimise(func, init_point, lr=0.1, min_step=1e-2)
# Landscape height along the path, offset by +0.1 (presumably so the line
# is drawn just above the surface rather than cutting into it)
path_z = func(path.T) + 0.1
# Convert both to NumPy arrays for Plotly
path, path_z = path.detach().numpy(), path_z.detach().numpy()
# Landscape surface with z-contours drawn using the colour map and
# projected onto the z plane. The contour spacing is one thirtieth of the
# grid's height range.
surface = go.Surface(
    x=x_fine_grid,
    y=y_fine_grid,
    z=z_fine_grid,
    colorscale=subdued_jet,
    showscale=False,
    opacity=1,
    # hidesurface=True,  # toggle to draw only the contours
    contours=dict(
        z=dict(
            show=True,
            usecolormap=True,
            project=dict(z=True),
            highlight=False,
            start=z_fine_grid.min(),
            end=z_fine_grid.max(),
            size=(z_fine_grid.max() - z_fine_grid.min()) / 30,
        )
    ),
)
# Optimisation path drawn on the surface, ending in an arrow head
plot_path = [
    go.Scatter3d(x=path[:, 0], y=path[:, 1], z=path_z,
                 mode='lines',
                 line=dict(color='black', width=5)),
    path_arrow(path, path_z, color='black'),
]
# Same path flattened onto the bottom of the z range (z_lim[0]). It is kept
# in its own variable so `path_z` still holds the on-surface heights
# instead of being overwritten with the floor value.
floor_z = np.full_like(path_z, z_lim[0])
plot_path_proj = [
    go.Scatter3d(x=path[:, 0], y=path[:, 1], z=floor_z,
                 mode='lines',
                 line=dict(color='gray', width=5)),
    path_arrow(path, floor_z, color='gray'),
]
# Coarse wire-frame lines over the surface, available as optional overlay
# traces: one line per coarse x value (running along y) and one per coarse
# y value (running along x).
x_coarse = torch.linspace(*x_lim, 20)
y_coarse = torch.linspace(*y_lim, 20)
# indexing="ij": row i of x_grid_z is the line x = x_coarse[i], and column
# j of y_grid_z is the line y = y_coarse[j]. This is torch's default layout,
# made explicit to silence its warning.
x_grid_x, x_grid_y = torch.meshgrid(x_coarse, y_fine, indexing="ij")
x_grid_z = func((x_grid_x, x_grid_y)).detach().numpy()
y_grid_x, y_grid_y = torch.meshgrid(x_fine, y_coarse, indexing="ij")
y_grid_z = func((y_grid_x, y_grid_y)).detach().numpy()
line_kwargs = dict(
    mode='lines',
    line=dict(color='black', width=5),
    showlegend=False,
)
# Pass plain NumPy arrays to Plotly rather than lists of 0-d torch tensors
x_lines = [
    go.Scatter3d(
        x=np.full(len(y_fine), x0.item()),
        y=y_fine.numpy(),
        z=z_row,
        **line_kwargs,
    )
    for x0, z_row in zip(x_coarse, x_grid_z)
]
y_lines = [
    go.Scatter3d(
        x=x_fine.numpy(),
        y=np.full(len(x_fine), y0.item()),
        z=z_col,
        **line_kwargs,
    )
    for y0, z_col in zip(y_coarse, y_grid_z.T)
]
# Combine the surface, the on-surface path and its flattened projection
traces = [surface, *plot_path, *plot_path_proj]
# traces += x_lines + y_lines  # optional wire-frame overlay
fig = go.Figure(data=traces)
# Hide all axes, fix the z display range to z_lim, and set a fixed camera
scene = dict(
    xaxis=dict(visible=False),
    yaxis=dict(visible=False),
    zaxis=dict(visible=False, range=z_lim),
    aspectratio=dict(x=1, y=1, z=0.7),
    camera=dict(eye=dict(x=-1.5, y=-1.5, z=1.2)),
)
fig.update_layout(scene=scene)
show(fig, True)
Total running time of the script: (0 minutes 0.391 seconds)