import numpy as np
import matplotlib.pyplot as plt
import os
import glob
from skimage import io
from tqdm import tqdm
import pathlib
from pathlib import Path
from PIL import Image

## Custom module
from enamel_rod_recon import *

img_filepath = "PIVframes/Substack (460-461).tif"

img_path = Path(img_filepath)
base_name = img_path.name  # file name incl. extension; used as figure label and in the PNG name


image_directory = "PIVframes"
output_directory = "PIVframes/Quiverplots"
filename = "quiver" + str(base_name) + '.png'

# Full path of the output PNG (despite its name, this is a file path, not a directory).
output_dir = os.path.join(output_directory, filename)

img = io.imread(img_filepath)  # Z, Y , X
# Fail early with a clear message; otherwise the indexing below raises a bare IndexError.
if img.ndim < 3 or img.shape[0] < 2:
    raise ValueError(
        f"Expected an image stack with at least 2 slices along axis 0, got shape {img.shape}"
    )

# First slice pair. Rows are cropped to at most img.shape[2] (the X size),
# so the frames are square whenever Y >= X.
frame_a = img[0, 0:img.shape[2], 0:img.shape[2]]
frame_b = img[1, 0:img.shape[2], 0:img.shape[2]]

winsize = 32  # pixels, PIV interrogation window size
downscale = 2  # presumably the stack's downsampling factor — confirm
voxel_size = 0.3e-6 * downscale  # meters; assumes a 0.3 um native pixel size — confirm
pixel_scale = 1  # set to 1, operate everything in pixels before converting to um
spacing = 1  # slice step of the per-pair loop; also scales params["dz"]

# Settings passed to generate_flowfield (PIV step) and the reconstruction stage.
params = {
    "noise_threshold": 1.0,  ## usually 1.0 is sufficient, set to 0.0 for auto (to be fixed later)
    "show_hist": False,
    "max_iter": 40,
    "kernel_size": winsize // 2,

    "winsize": winsize,  # pixels, interrogation window size in frame A
    # Separation between frame A and frame B, expressed as pixel_scale * spacing
    # (pixel/slice units, not seconds).
    # NOTE(review): confirm how generate_flowfield interprets this value.
    "dz": pixel_scale * spacing,
    "searchsize": winsize,  # pixels, search area size in frame B
    "overlap": int(0.75 * winsize),  # pixels, 75% overlap between adjacent windows
    "scaling_factor": 1 / pixel_scale,

    ## Reconstruction params
    "interp_method": 'linear',  ## nearest, linear, or cubic (will be slow)
}

# Flow field of the first slice pair; the invalid-vector mask is not needed here.
x, y, u, v, _ = generate_flowfield(frame_a, frame_b, params)


# Preview quiver plot for the first pair, in a figure labelled with the input file name.
preview_fig = plt.figure(num=base_name, figsize=(5, 5))
ax = preview_fig.gca()
ax.quiver(x, y, u, v, scale=1.5, color="black")
ax.set_aspect("equal", adjustable="box")
# Call ax.set_axis_off() here to hide the axes frame and ticks.

# Empty (n_vectors, 6, 0) seed so per-pair fields stack along axis 2.
# n_vectors comes from the first-pair x computed above.
n_vectors = x.reshape(-1, 1).shape[0]
flowfield_array = np.array([]).reshape(n_vectors, 6, 0)  ## initialize empty array
# Collect per-pair fields and stack once after the loop; np.dstack inside the
# loop would re-copy the whole growing array on every iteration.
flowfield_slices = []
for slice_num in tqdm(range(1, img.shape[0], spacing)):
    # Pair slice (slice_num - 1) with slice_num, using the same square crop as the first pair.
    # NOTE(review): the pair is always adjacent slices even when spacing > 1,
    # while params["dz"] scales with spacing — confirm this is intended.
    frame_a = img[slice_num - 1, 0:img.shape[2], 0:img.shape[2]]
    frame_b = img[slice_num, 0:img.shape[2], 0:img.shape[2]]
    x, y, u, v, invalid_mask = generate_flowfield(frame_a, frame_b, params)

    column_shape = x.reshape(-1, 1).shape
    # NOTE(review): z uses voxel_size (meters) while w uses params["dz"] and,
    # with pixel_scale = 1, x/y/u/v are presumably in pixels — confirm units.
    z_vect = (slice_num - 1) * voxel_size * np.ones(column_shape)
    w_vect = params["dz"] * np.ones(column_shape)

    # One row per vector, columns: x, y, z, u, v, w.
    current_flowfield = np.hstack((
        x.reshape(-1, 1), y.reshape(-1, 1), z_vect,
        u.reshape(-1, 1), v.reshape(-1, 1), w_vect,
    ))
    flowfield_slices.append(current_flowfield)

    # Save this pair's x, y, u, v to e.g. "./0000.csv" in the working directory.
    tools.save("./" + str(slice_num - 1).zfill(4) + ".csv", x, y, u, v, fmt='%.4e', delimiter=',')

# Same result as stacking inside the loop; with no pairs this stays (n_vectors, 6, 0).
flowfield_array = np.dstack([flowfield_array, *flowfield_slices])

# Optional: load a background image and plot the quiver plot over it, e.g.
#   image = plt.imread('Flow_direction.tif')
#   plt.imshow(image, cmap='gray', origin='lower')

# Save the figure as a PNG. The name embeds the full input file name, so a
# .tif input yields a ".tif.png" suffix.
# savefig does not create missing folders, so ensure the target directory exists.
os.makedirs(os.path.dirname(output_dir) or ".", exist_ok=True)
plt.savefig(output_dir)

# Report completion before plt.show(), which may block until the window is closed.
print("All slices completed.")
plt.show()
