mirror of
https://github.com/Mobile-Robotics-W20-Team-9/polex.git
synced 2025-09-07 21:23:13 +00:00
82 lines
2.4 KiB
Python
82 lines
2.4 KiB
Python
#!/usr/bin/env python
|
|
|
|
import matplotlib.pyplot as plt
|
|
from matplotlib.widgets import Slider
|
|
import numpy as np
|
|
import torch
|
|
|
|
|
|
def _take(data, index):
    """Select a sub-array of ``data`` by indexing its leading axes in turn.

    Each entry of ``index`` picks one position along the current first axis
    of the intermediate result, so ``len(index)`` leading axes are removed.
    A position beyond the end of an axis is clamped to that axis' last
    element (only the upper bound is clamped). An empty ``index`` returns
    ``data`` itself.
    """
    selected = data
    for position in index:
        # Clamp so that a shorter matrix keeps showing its last slice when
        # the requested position exceeds its extent along this axis.
        last_valid = selected.shape[0] - 1
        selected = np.take(selected, min(last_valid, position), axis=0)
    return selected
|
|
|
|
|
|
def _as_equal_rank_arrays(data):
    """Return ``data`` as a new list of NumPy arrays sharing one rank (>= 2).

    A single matrix is wrapped into a one-element list; a list or tuple is
    copied, so the caller's container is never modified. Torch tensors are
    converted to NumPy, and lower-rank arrays get leading singleton axes.
    """
    matrices = list(data) if isinstance(data, (list, tuple)) else [data]
    if not matrices:
        raise ValueError('matshow() requires at least one matrix')

    arrays = []
    for matrix in matrices:
        if isinstance(matrix, torch.Tensor):
            # detach()/cpu() make the conversion work for tensors that
            # require grad or live on an accelerator; plain .numpy() raises
            # for those.
            matrix = matrix.detach().cpu().numpy()
        else:
            # asanyarray keeps ndarray subclasses (e.g. masked arrays) intact
            # while also accepting nested sequences.
            matrix = np.asanyarray(matrix)
        arrays.append(matrix)

    # imshow needs at least two axes, so vectors/scalars are promoted too.
    ndim = max(2, max(array.ndim for array in arrays))
    for i, array in enumerate(arrays):
        while array.ndim < ndim:
            array = np.expand_dims(array, axis=0)
        arrays[i] = array
    return arrays


def matshow(data, matnames=None, dimnames=None):
    """Show one or more N-dimensional matrices side by side with sliders.

    The last two axes of every matrix are drawn as an image; each leading
    axis gets a slider that selects the displayed slice. Matrices with fewer
    axes than the largest one are padded with leading singleton axes, and a
    slider position beyond a matrix' extent shows that matrix' last slice.
    The call ends with ``plt.show()``.

    Args:
        data: A matrix, or a list/tuple of matrices. Each matrix may be a
            NumPy array, a torch tensor, or anything ``np.asanyarray``
            accepts. The caller's container is left unmodified.
        matnames: Optional titles for the subplots, in the order of ``data``.
            Surplus names are ignored; matrices without a name get no title.
        dimnames: Optional slider labels for the leading axes. Axes without
            a name are labelled ``'Axis <i>'``.

    Raises:
        ValueError: If ``data`` is an empty list or tuple.
    """
    matnames = [] if matnames is None else matnames
    dimnames = [] if dimnames is None else dimnames

    arrays = _as_equal_rank_arrays(data)
    ndim = arrays[0].ndim

    # Largest extent of every leading (slider-controlled) axis.
    shape = [max(array.shape[dim] for array in arrays)
             for dim in range(ndim - 2)]

    figure, axes = plt.subplots(
        1, len(arrays), sharex=True, sharey=True, squeeze=False)

    for axis, array in zip(axes[0], arrays):
        # Colour limits are fixed over the whole matrix so that colours stay
        # comparable while the sliders move through the slices.
        axis.imshow(_take(array, [0] * (ndim - 2)),
                    vmin=np.amin(array), vmax=np.amax(array),
                    interpolation=None, origin='lower',
                    extent=[0.0, array.shape[-1], 0.0, array.shape[-2]])

    for axis, name in zip(axes[0], matnames):
        axis.set_title(name)

    sliders = []
    # Evenly spaced vertical positions in the bottom 10% of the figure; the
    # two end points are dropped, leaving exactly ndim - 2 positions.
    bottom = np.linspace(0.0, 0.1, ndim)[1:-1]
    for i, extent in enumerate(shape):
        sliderax = plt.axes([0.2, bottom[i], 0.6, 0.02],
                            facecolor='lightgoldenrodyellow')
        label = dimnames[i] if i < len(dimnames) else 'Axis {}'.format(i)
        sliders.append(Slider(sliderax, label=label,
                              valmin=0, valmax=extent - 1, valinit=0,
                              valstep=1))

    def update(val):
        # Every slider change redraws all images from the full index tuple,
        # so one shared callback serves all sliders.
        indices = [int(slider.val) for slider in sliders]
        for axis, array in zip(axes[0], arrays):
            axis.images[0].set_array(_take(array, indices))
        figure.canvas.draw_idle()

    for slider in sliders:
        slider.on_changed(update)

    plt.show()
|
|
|
|
|
|
if __name__ == '__main__':
    # Demo viewer sessions, run one after another. Each entry lists the
    # shapes of the random matrices to show and the keyword arguments that
    # are forwarded to matshow().
    demos = (
        ([(100, 15, 25, 64, 64)], {}),
        # More matrix names than matrices: the surplus name is ignored.
        ([(100, 15, 25, 64, 64)],
         {'matnames': ['Matrix A', 'Matrix B'],
          'dimnames': ['depth', 'height', 'width']}),
        ([(64, 64, 64), (10, 10, 64)], {}),
        # Matrices of different rank shown side by side.
        ([(9, 11), (12, 12, 6)], {'matnames': ['9x11', '12x12x6']}),
        ([(100, 100, 100)], {}),
    )
    for shapes, labels in demos:
        samples = [np.random.rand(*dims) for dims in shapes]
        # A single matrix is passed bare, several matrices as a list.
        matshow(samples[0] if len(samples) == 1 else samples, **labels)
        # Release the (potentially very large) sample before creating the
        # next one.
        del samples
|