from openpiv import tools, pyprocess, validation, filters, scaling
import numpy as np
import pandas as pd
from scipy.interpolate import griddata
from matplotlib import pyplot as plt
from matplotlib.colors import hsv_to_rgb
from tqdm import tqdm
from skimage import feature
from skimage.draw import line

def get_noise_threshold(sig2noise, show_hist):
    """Derive a signal-to-noise cut-off from the gaps in the value histogram.

    The flattened values are binned into a 100-bin histogram. The left edge
    of the last empty bin is rounded to two decimals and scaled by 1.01;
    that number is returned as the threshold.

    Parameters
    ----------
    sig2noise : numpy.ndarray
        Signal-to-noise values; flattened before binning.
    show_hist : bool
        When truthy, the same histogram is also plotted with ``plt.hist``.

    Returns
    -------
    numpy.float64
        The derived threshold.

    Raises
    ------
    IndexError
        If every histogram bin is populated, i.e. there is no gap to place
        the threshold in.
    """
    samples = sig2noise.flatten()
    if show_hist:
        plt.hist(samples, bins=100)

    counts, bin_edges = np.histogram(samples, bins=100)
    empty_bins = np.flatnonzero(counts == 0)
    if empty_bins.size == 0:
        # Same exception type as the bare [-1] lookup would raise, but with
        # a message that explains what went wrong.
        raise IndexError(
            "sig2noise histogram has no empty bin; cannot derive a noise threshold"
        )

    last_empty = empty_bins[-1]
    # The 1 % margin pushes the threshold slightly past the bin edge.
    return np.around(bin_edges[last_empty], decimals=2) * 1.01

def reconstruct_rod(r0, flowfield_array, dz, method):
    """Trace one pathline through a stack of 2-D flow fields.

    Starting at ``r0``, the in-plane displacement rate (u, v) is interpolated
    at the current (x, y) position in every slice; the point is then moved by
    ``rate * dz`` in x and y and by ``dz`` in z.

    Parameters
    ----------
    r0 : sequence
        Start position ``(x, y, z)``.
    flowfield_array : numpy.ndarray
        Indexed as ``[point, column, slice]`` with columns ordered
        x, y, z, u, v, w. Only the x, y, u and v columns are read.
    dz : float
        Step between consecutive slices.
    method : str
        Interpolation method forwarded to ``scipy.interpolate.griddata``.

    Returns
    -------
    pos_3D_vector : numpy.ndarray
        Array of shape ``(n_slices + 1, 3)``; row 0 is ``r0`` and each
        following row is the position after one more slice. The array is
        2-D even when ``flowfield_array`` has no slices.
    displacement_vector : numpy.ndarray
        Always an empty ``(0, 3)`` array; per-step displacements are not
        recorded. Returned so the two-value interface stays intact.
    """
    start = np.asarray(r0)
    x, y, z = start[0], start[1], start[2]

    # Collect the rows in a list and stack once at the end instead of
    # re-allocating the whole path with np.vstack on every step.
    path = [start]

    for slice_index in range(flowfield_array.shape[2]):
        # flowfield array columns: x, y, z, u, v, w
        points = flowfield_array[:, 0:2, slice_index]  # x, y grid to interpolate from
        values = flowfield_array[:, 3:5, slice_index]  # u, v sampled on that grid

        displacement_rate = griddata(points, values, (x, y), method=method)
        x = x + displacement_rate[0] * dz
        y = y + displacement_rate[1] * dz
        z = z + dz

        path.append([x, y, z])

    pos_3D_vector = np.vstack(path)
    displacement_vector = np.empty((0, 3))
    return pos_3D_vector, displacement_vector

def generate_flowfield(frame_a, frame_b, params):
    """Run PIV on an image pair and return the validated, scaled vector field.

    Pipeline: extended-search-area cross-correlation, signal-to-noise
    validation, local-mean replacement of the flagged vectors, uniform
    scaling, and a final coordinate transform.

    Parameters
    ----------
    frame_a, frame_b : numpy.ndarray
        The two images; both are cast to ``uint32`` before correlation.
    params : dict
        Reads the keys ``"winsize"``, ``"overlap"``, ``"dz"`` (passed as the
        ``dt`` of the correlation), ``"searchsize"``, ``"noise_threshold"``,
        ``"max_iter"``, ``"kernel_size"`` and ``"scaling_factor"``.

    Returns
    -------
    tuple
        ``(x, y, u, v, invalid_mask)`` where ``invalid_mask`` is the result
        of the signal-to-noise validation.
    """
    raw_u, raw_v, s2n = pyprocess.extended_search_area_piv(
        frame_a.astype(np.uint32),
        frame_b.astype(np.uint32),
        window_size=params["winsize"],
        overlap=params["overlap"],
        dt=params["dz"],
        search_area_size=params["searchsize"],
        sig2noise_method='peak2peak',
    )

    grid_x, grid_y = pyprocess.get_coordinates(
        image_size=frame_a.shape,
        search_area_size=params["searchsize"],
        overlap=params["overlap"],
    )

    # NOTE: an earlier variant derived the threshold automatically with
    # get_noise_threshold when params["noise_threshold"] was 0.0. That path
    # is disabled; the configured value is always used as-is.
    threshold = params["noise_threshold"]
    invalid_mask = validation.sig2noise_val(s2n, threshold=threshold)

    filled_u, filled_v = filters.replace_outliers(
        raw_u,
        raw_v,
        invalid_mask,
        method='localmean',
        max_iter=params["max_iter"],
        kernel_size=params["kernel_size"],
    )

    # Original note on the scaling factor: 1 pixels/0.35 um
    grid_x, grid_y, scaled_u, scaled_v = scaling.uniform(
        grid_x,
        grid_y,
        filled_u,
        filled_v,
        scaling_factor=params["scaling_factor"],
    )

    # 0,0 shall be bottom left, positive rotation rate is counterclockwise
    grid_x, grid_y, scaled_u, scaled_v = tools.transform_coordinates(
        grid_x, grid_y, scaled_u, scaled_v
    )

    return grid_x, grid_y, scaled_u, scaled_v, invalid_mask


def get_boundaries(flowfield_array):
    """Return the (min, max) extent of x, y and z in the first slice.

    Only slice index 0 along the last axis is inspected; columns 0, 1 and 2
    of that slice are treated as x, y and z.

    Returns
    -------
    tuple
        ``((xmin, xmax), (ymin, ymax), (zmin, zmax))``.
    """
    extents = []
    for column in range(3):
        coords = flowfield_array[:, column, 0]
        extents.append((coords.min(), coords.max()))
    return tuple(extents)

def grid_reconstruct(grid_bounds, flowfield_array, params):
    """Reconstruct one pathline per start point of a regular interior grid.

    Parameters
    ----------
    grid_bounds : sequence
        ``((xmin, xmax), (ymin, ymax), (zmin, zmax))``; ``zmin`` is the start
        height of every pathline.
    flowfield_array : numpy.ndarray
        Passed through unchanged to ``reconstruct_rod``.
    params : dict
        Reads ``"gridnums"`` (number of start points along x and y),
        ``"dz"`` and ``"interp_method"``.

    Returns
    -------
    rod_pathlines_array : list
        One position array per start point, x varying slowest.
    displacement_array : list
        Always empty; displacements are not collected.
    """
    n_along_x, n_along_y = params["gridnums"][0], params["gridnums"][1]
    x_bounds, y_bounds = grid_bounds[0], grid_bounds[1]
    z_start = grid_bounds[2][0]

    # Two extra samples per axis: the first and last lie on the edge of the
    # domain and are dropped, leaving only interior start points.
    interior_x = np.linspace(x_bounds[0], x_bounds[1], n_along_x + 2)[1:-1]
    interior_y = np.linspace(y_bounds[0], y_bounds[1], n_along_y + 2)[1:-1]
    start_points = [(x0, y0, z_start) for x0 in interior_x for y0 in interior_y]

    rod_pathlines_array = []
    displacement_array = []

    with tqdm(total=n_along_x * n_along_y) as progress:
        for start in start_points:
            pathline, _ = reconstruct_rod(
                start, flowfield_array, params["dz"], params["interp_method"]
            )
            rod_pathlines_array.append(pathline)
            progress.update(1)

    print("All points completed.")
    return rod_pathlines_array, displacement_array

######################## For textural analysis ########################

def dominant_direction(img):
    """Return the dominant orientation of ``img`` in radians.

    The angle is computed from the image-wide means of the structure tensor
    components (requested in 'xy' order with reflecting borders).
    """
    txx, txy, tyy = feature.structure_tensor(img, order='xy', mode="reflect")
    doubled_cross = 2 * txy.mean()
    diagonal_gap = tyy.mean() - txx.mean()
    return np.arctan2(doubled_cross, diagonal_gap) / 2

def orientation_analysis(img, sigma):
    """Compute OrientationJ-style orientation, coherence and energy maps.

    Parameters
    ----------
    img : numpy.ndarray
        Input image; cast to ``float32`` before the structure tensor is
        computed.
    sigma : float
        Smoothing scale forwarded to ``feature.structure_tensor``.

    Returns
    -------
    ori : numpy.ndarray
        Per-pixel orientation, ``arctan2(2*Axy, Ayy - Axx) / 2``.
    coh : numpy.ndarray
        Per-pixel coherence, ``((l2 - l1) / (l2 + l1 + eps)) ** 2``.
    ene : numpy.ndarray
        Per-pixel energy ``sqrt(Axx + Ayy)``, divided by its maximum.
    """
    eps = 1e-20  # keeps the coherence finite where both eigenvalues are zero

    # order="xy" is given explicitly, as in dominant_direction, so that the
    # unpacking below really is (Axx, Axy, Ayy) regardless of the library
    # default.
    axx, axy, ayy = feature.structure_tensor(
        img.astype(np.float32), sigma=sigma, mode="reflect", order="xy"
    )
    # NOTE(review): structure_tensor_eigvals appears to be deprecated in
    # newer scikit-image releases in favour of structure_tensor_eigenvalues
    # - confirm against the installed version.
    l1, l2 = feature.structure_tensor_eigvals(axx, axy, ayy)
    ori = np.arctan2(2 * axy, (ayy - axx)) / 2

    coh = ((l2 - l1) / (l2 + l1 + eps)) ** 2
    ene = np.sqrt(axx + ayy)
    ene /= ene.max()

    return ori, coh, ene

def generate_meas_lines(band_img, n_px_space):
    """Sample grey values along parallel measurement lines across the image.

    Every line runs from the top row to the bottom row of ``band_img`` and
    is tilted by the angle returned by ``dominant_direction``. Lines are
    spaced ``n_px_space`` pixels apart along the top row and are restricted
    to start columns for which both ends stay inside the image.

    Parameters
    ----------
    band_img : numpy.ndarray
        2-D image to sample.
    n_px_space : int
        Horizontal spacing, in pixels, between neighbouring lines.

    Returns
    -------
    numpy.ndarray
        One ``(n_samples, 2)`` entry per line: column 0 is the distance of
        each sample from the start of its own line, column 1 the grey value
        at that sample.
    """
    th = dominant_direction(band_img)
    print("Dominant direction is {:.2f} degrees.".format(th*180/np.pi))

    x_max, y_max = band_img.shape[1] - 1, band_img.shape[0] - 1
    y_start, y_end = 0, y_max
    # Horizontal offset between the top and the bottom end of every line.
    x_shift = int((y_end - y_start) * np.tan(th))

    # Keep both ends inside the image. For a positive shift the last start
    # column is limited; for a negative shift the first start column is
    # moved right so the bottom end does not get a negative (wrapping) index.
    first_col = max(0, -x_shift)
    last_col = x_max - max(0, x_shift)

    pixel_data_array = []

    for x_top in range(first_col, last_col, n_px_space):
        rr, cc = line(y_start, x_top, y_end, x_top + x_shift)
        # Distance along the line, measured from this line's own start point
        # rather than from the image origin.
        length = np.hypot(rr - y_start, cc - x_top)
        pixel_data_array.append(
            np.hstack((length.reshape(-1, 1), band_img[rr, cc].reshape(-1, 1)))
        )

    return np.asarray(pixel_data_array)


def calculate_bandwidths(pixel_data_array, min_rod_diam):
    """Measure the spacing between half-maximum crossings on each line.

    For every line the grey values (column 1) are thresholded at the
    floor-divided half of the line's maximum. Wherever the thresholded
    signal flips between neighbouring samples, the distance value (column 0)
    is recorded; consecutive differences of those distances are the widths.

    Parameters
    ----------
    pixel_data_array : numpy.ndarray
        Indexed as ``[line, sample, (distance, grey value)]``.
    min_rod_diam : float
        Widths that are not strictly greater than this are discarded.

    Returns
    -------
    numpy.ndarray
        1-D array with the remaining widths of all lines, in line order.
    """
    collected = []

    for profile in pixel_data_array:
        distance = profile[:, 0]
        grey = profile[:, 1]

        above_half = (grey - np.max(grey) // 2) > 0
        # A crossing sits between sample i and i + 1 when the flag differs.
        crossings = np.flatnonzero(above_half[:-1] != above_half[1:])
        collected.extend(np.diff(distance[crossings]))

    widths = np.asarray(collected)
    return widths[widths > min_rod_diam]


######################## For color coding ########################
def vector_to_rgb(angle, absolute, max_abs):
    """Get the rgb value for the given `angle` and the `absolute` value

    Parameters
    ----------
    angle : float
        The angle in radians
    absolute : float
        The absolute value of the gradient
    max_abs : float
        Normalisation value; ``absolute / max_abs`` is used as both the
        saturation and the value of the HSV colour. Must be non-zero.

    Returns
    -------
    array_like
        The rgb value with components in [0..1], as returned by
        ``matplotlib.colors.hsv_to_rgb``
    """
    full_turn = 2 * np.pi

    # Mirror and shift the angle to match imagej, then wrap it into one
    # turn. The modulo by a positive number never yields a negative result,
    # so no further sign correction is needed.
    hue_angle = (-angle + np.pi / 2) % full_turn

    strength = absolute / max_abs
    return hsv_to_rgb((hue_angle / full_turn, strength, strength))

def generate_color_wheel():
    """Draw the polar colour-wheel legend for ``vector_to_rgb`` and show it.

    The wheel is placed in subplot position 236 of the current figure. Each
    cell is coloured by ``vector_to_rgb`` with the polar angle as the angle
    and the radius (0 to 1) as the magnitude, normalised by 1. Returns
    ``None``.
    """
    polar_ax = plt.subplot(236, projection='polar')

    samples = 200
    theta = np.linspace(0, 2*np.pi, samples)
    radius = np.linspace(0, 1.0, samples)
    radius_grid, theta_grid = np.meshgrid(radius, theta)

    flat_theta = theta_grid.T.flatten()
    flat_radius = radius_grid.T.flatten()
    unit_scale = np.ones_like(flat_radius)

    colors = np.array([
        vector_to_rgb(ang, mag, top)
        for ang, mag, top in zip(flat_theta, flat_radius, unit_scale)
    ])
    color_grid = colors.reshape((samples, samples, 3))

    # The scalar data handed to pcolormesh is only a placeholder; the mesh
    # array is cleared right after so the explicit per-cell colours are used.
    mesh = polar_ax.pcolormesh(
        theta, radius, color_grid[:, :, 1], color=colors, shading='auto'
    )
    mesh.set_array(None)
    polar_ax.set_yticklabels([])

    plt.show()