import glob
import os
import pathlib
import traceback
from pathlib import Path

import matplotlib.pyplot as plt
import numpy as np
from PIL import Image
from skimage import io
from tqdm import tqdm

## Custom module
from enamel_rod_recon import *

#img_filepath = "PIVframes/Lion1_Mcerv3-etcher30s_(45-46).tif"

# Batch driver: for every TIFF stack in `image_directory`, run
# `generate_flowfield` on consecutive slice pairs, save a quiver plot of the
# first pair to `output_directory`, write one CSV per processed slice pair and
# accumulate the per-slice flow fields into `flowfield_array`.
image_directory = "PIVframes"
output_directory = "PIVframes/Quiverplots"

# Settings shared by every stack (loop-invariant, so defined once).
winsize = 32
voxel_size = 1  ## 0.35e-6 m
spacing = 14
gridsize = (100, 100)

# Scan the PIV frames directory for TIFF stacks. Sorted so the processing
# order is reproducible (os.listdir order is arbitrary).
tif_filenames = sorted(
    name
    for name in os.listdir(image_directory)
    if name.lower().endswith(('.tif', '.tiff'))
)

# Created only after the input directory was listed successfully, so a missing
# input directory still raises instead of being silently created.
os.makedirs(output_directory, exist_ok=True)

for filename in tif_filenames:
    img_filepath = os.path.join(image_directory, filename)
    img_path = Path(img_filepath)
    base_name = img_path.name
    # Keeps the original "<name>.tif.png" naming so existing outputs match.
    outputname = "quiver" + str(base_name) + '.png'
    output_dir = os.path.join(output_directory, outputname)

    # Rebuilt per file so every stack starts from the same settings even if
    # generate_flowfield were to modify the dict it is given.
    params = {
        "noise_threshold": 1.03,  ## usually 1.0 is sufficient, set to 0.0 for auto (to be fixed later)
        "show_hist": True,
        "max_iter": 20,
        "kernel_size": 20,

        "winsize": winsize,  # pixels, interrogation window size in frame A
        "dz": voxel_size * spacing,  # voxel_size * spacing; also used as the w component below
        "searchsize": winsize,  # pixels, search area size in frame B
        "overlap": winsize // 2,  # pixels, 50% overlap
        "scaling_factor": 1 / voxel_size,

        # Reconstruction params
        "interp_method": 'linear',  ## nearest, linear, or cubic (will be slow)
        "gridnums": gridsize,  ## no. of grid points in (x, y) directions
    }

    fig = None
    try:
        img = io.imread(img_filepath)  # Z, Y, X
        if img.ndim != 3 or img.shape[0] < 2:
            raise ValueError(
                f"expected a 3-D stack with at least 2 slices, got shape {img.shape}"
            )

        # Both in-plane axes are sliced with the X extent (img.shape[2]),
        # exactly as before.
        # NOTE(review): looks like a square crop; confirm img.shape[1] was
        # not intended for the Y axis.
        side = img.shape[2]

        frame_a = img[0, 0:side, 0:side]
        frame_b = img[1, 0:side, 0:side]

        x, y, u, v, _ = generate_flowfield(frame_a, frame_b, params)

        # Draw through explicit figure/axes handles so the quiver plot is the
        # figure that gets saved even if generate_flowfield opens other
        # figures (e.g. with "show_hist" enabled).
        fig = plt.figure(figsize=(5, 5), num=base_name)
        ax = fig.gca()
        ax.quiver(x, y, u, v, scale=1.5, color='black')
        ax.set_aspect('equal', adjustable='box')
        ax.set_axis_off()

        # One (n_vectors, 6) array per processed slice pair; stacked once
        # after the loop instead of re-copying a growing array every pass.
        slice_flowfields = []
        for slice_num in tqdm(range(1, img.shape[0], spacing)):
            frame_a = img[slice_num - 1, 0:side, 0:side]
            frame_b = img[slice_num, 0:side, 0:side]
            x, y, u, v, invalid_mask = generate_flowfield(frame_a, frame_b, params)

            x_col = x.reshape(-1, 1)
            z_vect = (slice_num - 1) * voxel_size * np.ones(x_col.shape)
            w_vect = params["dz"] * np.ones(x_col.shape)

            # Columns: x, y, z, u, v, w
            current_flowfield = np.hstack(
                (x_col, y.reshape(-1, 1), z_vect, u.reshape(-1, 1), v.reshape(-1, 1), w_vect)
            )
            slice_flowfields.append(current_flowfield)

            # NOTE(review): the CSV name depends only on the slice index, so
            # each stack overwrites the CSVs written for the previous one.
            # Path kept unchanged in case downstream code relies on it.
            tools.save("./" + str(slice_num - 1).zfill(4) + ".csv", x, y, u, v, fmt='%.4e', delimiter=',')

        # Shape (n_vectors, 6, n_processed_slices), same as the incremental
        # np.dstack produced before.
        flowfield_array = np.dstack(slice_flowfields)

        # Load a background image and plot the quiver plot over it
        #image = plt.imread('Flow_direction.tif')
        #plt.imshow(image, cmap='gray', origin='lower')

        # Save the figure as a PNG (makes a weird .tif.png but works)
        fig.savefig(output_dir)
        plt.show()

        print("All slices completed.")
    except Exception:
        # Best-effort batch: report the failing file with the full traceback
        # and carry on with the next stack. KeyboardInterrupt/SystemExit are
        # deliberately not caught so the run can still be aborted.
        print(f"Something didn't work right while processing {img_filepath}:")
        traceback.print_exc()
    finally:
        # Release the figure so long batches do not accumulate open figures.
        if fig is not None:
            plt.close(fig)