Files
polex/poles/ndshow.py
Alexander Schaefer 17877ef198 Add code.
2019-07-25 10:28:19 +02:00

82 lines
2.4 KiB
Python

#!/usr/bin/env python
import matplotlib.pyplot as plt
from matplotlib.widgets import Slider
import numpy as np
import torch
def _take(data, index):
slice = data
for i in range(len(index)):
slice = np.take(slice, min(slice.shape[0]-1, index[i]), axis=0)
return slice
def _as_padded_arrays(data):
    """Return ``data`` as a new list of NumPy arrays that all share one rank >= 2.

    ``data`` may be a single array-like or a list/tuple of them; the caller's
    container is never modified. Torch tensors are detached and moved to the
    CPU before conversion, so tensors that require grad or live on a GPU can
    be shown too. Lower-rank arrays get leading singleton axes until every
    array has the same rank (at least 2, so 0-D/1-D inputs become images).
    """
    items = list(data) if isinstance(data, (list, tuple)) else [data]
    arrays = []
    for item in items:
        if isinstance(item, torch.Tensor):
            item = item.detach().cpu().numpy()
        arrays.append(np.asarray(item))
    if not arrays:
        return arrays
    ndim = max(2, max(a.ndim for a in arrays))
    return [a.reshape((1,) * (ndim - a.ndim) + a.shape) for a in arrays]


def matshow(data, matnames=None, dimnames=None):
    """Show one or more N-D arrays as images, with sliders for the leading axes.

    Each array gets its own subplot, and all subplots share their x and y
    axes. Each subplot displays the array's last two axes. One slider is
    created for every remaining (leading) axis. Moving any slider re-slices
    every array at the current indices of all sliders. Arrays of lower rank
    are padded with leading singleton axes. Slider positions beyond an
    array's extent are clamped to its last entry (see ``_take``).

    Args:
        data: A single array (NumPy array, torch tensor or anything
            ``np.asarray`` accepts) or a list/tuple of them.
        matnames: Optional subplot titles, matched by position. Extra names
            are ignored; subplots without a name stay untitled.
        dimnames: Optional slider labels, matched by position. Sliders
            without a name are labelled ``'Axis <i>'``.

    Raises:
        ValueError: If ``data`` is an empty list/tuple.
    """
    matnames = [] if matnames is None else matnames
    dimnames = [] if dimnames is None else dimnames

    arrays = _as_padded_arrays(data)
    if not arrays:
        raise ValueError('matshow() requires at least one array to show')
    ndim = arrays[0].ndim
    # Each slider spans the largest extent of its axis over all arrays.
    shape = [max(a.shape[axis] for a in arrays) for axis in range(ndim - 2)]

    figure, axes = plt.subplots(
        1, len(arrays), sharex=True, sharey=True, squeeze=False)
    images = []
    for ax, array in zip(axes[0], arrays):
        # Colour limits span the whole array rather than the visible slice,
        # so colours stay comparable while scrolling. The nan* variants
        # ignore NaN entries, which would otherwise make both limits NaN.
        image = ax.imshow(
            _take(array, [0] * (ndim - 2)),
            vmin=np.nanmin(array), vmax=np.nanmax(array),
            # NOTE(review): interpolation=None defers to
            # rcParams['image.interpolation']; 'none' would disable
            # interpolation entirely. Confirm which was intended.
            interpolation=None, origin='lower',
            extent=[0.0, array.shape[-1], 0.0, array.shape[-2]])
        images.append(image)
    for ax, name in zip(axes[0], matnames):
        ax.set_title(name)

    sliders = []

    def update(_val):
        # Use the positions of *all* sliders, not only the one that moved.
        indices = [int(slider.val) for slider in sliders]
        for image, array in zip(images, arrays):
            image.set_array(_take(array, indices))
        figure.canvas.draw_idle()

    # Stack the sliders evenly in the strip between y=0.0 and y=0.1.
    bottoms = np.linspace(0.0, 0.1, ndim)[1:-1]
    for axis, (size, bottom) in enumerate(zip(shape, bottoms)):
        slider_ax = plt.axes([0.2, bottom, 0.6, 0.02],
                             facecolor='lightgoldenrodyellow')
        if axis < len(dimnames):
            label = dimnames[axis]
        else:
            label = 'Axis {}'.format(axis)
        slider = Slider(slider_ax, label=label,
                        valmin=0, valmax=size - 1, valinit=0, valstep=1)
        slider.on_changed(update)
        sliders.append(slider)
    plt.show()
if __name__ == '__main__':
    # Demo gallery. Each entry is a zero-argument callable, so its (possibly
    # very large) random input is created only right before that demo runs,
    # exactly as with direct calls.
    demos = (
        # One 5-D array: three sliders labelled 'Axis 0' .. 'Axis 2'.
        lambda: matshow(np.random.rand(100, 15, 25, 64, 64)),
        # Same shape with custom labels. Only one array is passed, so only
        # the first title ('Matrix A') is used.
        lambda: matshow(np.random.rand(100, 15, 25, 64, 64),
                        matnames=['Matrix A', 'Matrix B'],
                        dimnames=['depth', 'height', 'width']),
        # Two 3-D arrays of different extents driven by one shared slider.
        lambda: matshow([np.random.rand(64, 64, 64),
                         np.random.rand(10, 10, 64)]),
        # Mixed ranks: the 2-D array is padded with a leading axis.
        lambda: matshow([np.random.rand(9, 11), np.random.rand(12, 12, 6)],
                        matnames=['9x11', '12x12x6']),
        lambda: matshow(np.random.rand(100, 100, 100)),
    )
    for run_demo in demos:
        run_demo()