mirror of
https://github.com/Mobile-Robotics-W20-Team-9/polex.git
synced 2025-09-09 21:53:15 +00:00
Add code.
This commit is contained in:
81
poles/ndshow.py
Normal file
81
poles/ndshow.py
Normal file
@@ -0,0 +1,81 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
import matplotlib.pyplot as plt
|
||||
from matplotlib.widgets import Slider
|
||||
import numpy as np
|
||||
import torch
|
||||
|
||||
|
||||
def _take(data, index):
|
||||
slice = data
|
||||
for i in range(len(index)):
|
||||
slice = np.take(slice, min(slice.shape[0]-1, index[i]), axis=0)
|
||||
|
||||
return slice
|
||||
|
||||
|
||||
def matshow(data, matnames=[], dimnames=[]):
|
||||
if not isinstance(data, list):
|
||||
data = [data]
|
||||
|
||||
for i in range(len(data)):
|
||||
if type(data[i]) is torch.Tensor:
|
||||
data[i] = data[i].numpy()
|
||||
|
||||
ndim = max([d.ndim for d in data])
|
||||
|
||||
for i in range(len(data)):
|
||||
while data[i].ndim < ndim:
|
||||
data[i] = np.expand_dims(data[i], axis=0)
|
||||
|
||||
shape = []
|
||||
for dim in range(ndim - 2):
|
||||
shape.append(max(d.shape[dim] for d in data))
|
||||
|
||||
figure, axes = plt.subplots(
|
||||
1, len(data), sharex=True, sharey=True, squeeze=False)
|
||||
|
||||
for i in range(len(data)):
|
||||
axes[0,i].imshow(_take(data[i], [0] * (ndim-2)),
|
||||
vmin=np.amin(data[i]), vmax=np.amax(data[i]),
|
||||
interpolation=None, origin='lower',
|
||||
extent=[0.0, data[i].shape[-1], 0.0, data[i].shape[-2]])
|
||||
|
||||
for i in range(min(len(data), len(matnames))):
|
||||
axes[0,i].set_title(matnames[i])
|
||||
|
||||
sliders = []
|
||||
updatefuncs = []
|
||||
bottom = np.linspace(0.0, 0.1, ndim)[1:-1]
|
||||
for i in range(len(shape)):
|
||||
sliderax = plt.axes([0.2, bottom[i], 0.6, 0.02],
|
||||
facecolor='lightgoldenrodyellow')
|
||||
|
||||
if i < len(dimnames):
|
||||
label = dimnames[i]
|
||||
else:
|
||||
label = 'Axis {}'.format(i)
|
||||
sliders.append(Slider(sliderax, label=label,
|
||||
valmin=0, valmax=shape[i]-1, valinit=0, valstep=1))
|
||||
|
||||
def update(val):
|
||||
indices = [int(slider.val) for slider in sliders]
|
||||
for j in range(axes.size):
|
||||
axes[0,j].images[0].set_array(_take(data[j], indices))
|
||||
figure.canvas.draw_idle()
|
||||
updatefuncs.append(update)
|
||||
|
||||
sliders[i].on_changed(updatefuncs[i])
|
||||
|
||||
plt.show()
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
matshow(np.random.rand(100, 15, 25, 64, 64))
|
||||
matshow(np.random.rand(100, 15, 25, 64, 64),
|
||||
matnames=['Matrix A', 'Matrix B'],
|
||||
dimnames=['depth', 'height', 'width'])
|
||||
matshow([np.random.rand(64, 64, 64), np.random.rand(10, 10, 64)])
|
||||
matshow([np.random.rand(9, 11), np.random.rand(12, 12, 6)],
|
||||
matnames=['9x11', '12x12x6'])
|
||||
matshow(np.random.rand(100, 100, 100))
|
Reference in New Issue
Block a user