#!/usr/bin/env python3
# -*- coding: utf-8 -*-

from PIL import Image                    #Run 'module load pillow' (on INL's HPC), and make sure to load Python version 3.5
import numpy as np
import math
from pyflann import *                    #Follow the instructions in the 'Solid Texture Synthesis Code User Guide.docx'
                                         #to obtain the FLANN library from GitHub, compile it, and use it on Falcon via an interactive session
from random import random
from matplotlib import pyplot as plt
import matplotlib.image as mpimg		
import timeit
import inputs
import time
import os
import os.path
from os.path import *
from pathlib import Path
import itertools

from scipy import *
from scipy.sparse import *
from scipy.sparse.linalg import spsolve
from numpy.linalg import solve, norm
from numpy.random import random
from numpy import array

from multiprocessing import Process, Pipe, JoinableQueue
import multiprocessing as mp
from multiprocessing.pool import ThreadPool
import concurrent.futures
	
#Note the following format to get the pixel value at the (x,y) position of a plane with the origin defined at the top left corner of that plane
#Conceptually/visually, the x coordinate is horizontal and y coordinate is vertical, as usual
#When using the PIL library: 
#value = image.getpixel((x,y))
#pix = image.load() -----> value = pix[x,y]
#When using matplotlib, the order is reversed: pix[y][x]
#For all other arrays, the indices are also reversed. Ex: B = get_binary(image) -----> value = B[y][x]
	
#*****FUNCTIONS****************************************************************************************************************************************************
	
#Crop the exemplar image to a square whose side length is divisible by 2**(num_levels-1),
#so that it can be halved num_levels-1 times without leftover pixels
def resize_exemplar(I, num_levels):
	"""Return I cropped, from its top-left corner, to a downsampling-compatible square.

	The side of the square is the largest multiple of 2**(num_levels-1) that does
	not exceed the shorter side of I.

	I          -- the exemplar image (a PIL image: needs .size and .crop(box))
	num_levels -- number of resolution levels; the exemplar will later be halved
	              num_levels-1 times
	"""
	block_side = 2**(num_levels-1)             #side, in full-resolution pixels, of one pixel at the coarsest level
	shortest_side = min(I.size[0], I.size[1])  #cropping to a square keeps the shorter side
	#Round down to a whole number of coarse pixels (math.floor already returns an int in Python 3)
	side = int(math.floor(shortest_side/block_side))*block_side
	return I.crop((0, 0, side, side))
	
#Convert an image holding two phases into a binary matrix:
#pixels equal to phases[0] become 0 and every other pixel becomes 1
def split_images(I, plane, phases):
	"""Return a nested list M with M[y][x] in {0, 1} for every pixel of I.

	A pixel maps to 0 when it equals phases[0], either as a single-band value or
	as the grey RGB triple (phases[0], phases[0], phases[0]).  Every other pixel,
	whether it matches phases[1] or neither phase, maps to 1.

	plane is accepted for interface compatibility but is not used here.
	"""
	first_phase = phases[0]
	first_phase_rgb = (first_phase, first_phase, first_phase)

	def to_label(value):
		if value == first_phase or value == first_phase_rgb:
			return 0
		#phases[1] and any unrecognised value are both labelled 1
		return 1

	width, height = I.size[0], I.size[1]
	#Outer list is indexed by y (row), inner by x (column): M[y][x]
	return [[to_label(I.getpixel((x, y))) for x in range(width)] for y in range(height)]

def get_binary(I):
	"""Convert a black-and-white PIL image into a binary matrix.

	Returns a nested list B with B[y][x] == 0 where the pixel is black and 1
	everywhere else.  A pixel counts as black when its value is 0 (single-band
	images) or (0, 0, 0) (RGB images).  When the image has an alpha band
	(e.g. 32-bit RGBA or LA images), the alpha component is ignored, so an
	opaque black pixel such as (0, 0, 0, 255) is still classified as black.
	"""
	bands = I.getbands()
	#Pixels of images with an alpha band carry alpha as their last component;
	#drop it so 24-bit and 32-bit images are classified the same way
	has_alpha = len(bands) > 1 and bands[-1] in ('A', 'a')

	Bin = []
	for y in range(I.size[1]):
		row = []
		for x in range(I.size[0]):
			val = I.getpixel((x, y))
			if has_alpha:
				val = val[:-1]
				if len(val) == 1:          #e.g. 'LA' images: keep only the luminance value
					val = val[0]
			row.append(0 if (val == 0 or val == (0, 0, 0)) else 1)
		Bin.append(row)
	return Bin
	
#Colour key used when exemplar matrices are written to disk as images
_EXEMPLAR_COLOURS = {
	0: (0, 0, 0),          #black
	1: (255, 255, 255),    #white
	2: (255, 0, 0),        #red: 2x2 block with mixed values in a trinary downsample
	3: (0, 0, 255),        #blue: boundary pixel of a non-periodic exemplar
}
_EXEMPLAR_FALLBACK_COLOUR = (0, 255, 0)    #green: any other value


def _mark_exemplar_boundary(M):
	"""Set every pixel on the outer ring of the square matrix M to 3, in place."""
	last = len(M) - 1
	for i in range(last + 1):
		M[0][i] = 3
		M[last][i] = 3
		M[i][0] = 3
		M[i][last] = 3


def _save_exemplar_image(M, filename):
	"""Write the square matrix M to `filename` as an RGB image using _EXEMPLAR_COLOURS."""
	side = len(M)
	im = Image.new("RGB", (side, side))
	pix = im.load()                            #PIL pixel access is pix[x, y]
	for y in range(side):
		for x in range(side):
			pix[x, y] = _EXEMPLAR_COLOURS.get(M[y][x], _EXEMPLAR_FALLBACK_COLOUR)
	im.save(filename)


def _binary_downsample(M):
	"""Halve the square matrix M by majority vote over each 2x2 block.

	A block whose average is below 0.5 becomes 0 and one above 0.5 becomes 1.
	An exact tie is resolved at random with np.random.random().
	"""
	half = int(len(M)/2)
	out = [[0 for i in range(half)] for j in range(half)]
	for y in range(half):                      #(y, x) indexes the downsampled matrix;
		for x in range(half):                  #(2y, 2x) is the top-left pixel of its block in M
			top, bottom = M[2*y], M[2*y+1]
			avg = (top[2*x] + bottom[2*x] + top[2*x+1] + bottom[2*x+1])/4
			if avg < 0.5:
				out[y][x] = 0
			elif avg > 0.5:
				out[y][x] = 1
			else:
				out[y][x] = 0 if np.random.random() > 0.5 else 1
	return out


def _trinary_downsample(M):
	"""Halve the square matrix M: a 2x2 block whose four values agree keeps that value, otherwise it becomes 2."""
	half = int(len(M)/2)
	out = []
	for y in range(half):
		top, bottom = M[2*y], M[2*y+1]
		row = []
		for x in range(half):
			v = top[2*x]
			if v == bottom[2*x+1] and v == bottom[2*x] and v == top[2*x+1]:
				row.append(v)
			else:
				row.append(2)
		out.append(row)
	return out


def build_exemplars(I, plane, phases):#, val1, val2, phase):
	"""Build the multi-resolution exemplar pyramid for one plane and register it.

	Level 1 is the binary matrix of I, from split_images() when the global
	`multiphase` flag is set and from get_binary() otherwise.  Each further
	level, up to the global `num_levels`, halves the previous one according to
	the global `downsample_mode`:
	  'binary'  -- majority vote over each 2x2 block (ties broken at random)
	  'trinary' -- a block keeps its value if all four pixels agree, otherwise 2
	Any other mode keeps only level 1.

	When the global `is_exemplar_periodic` is false, the outer ring of pixels at
	every level is set to 3.  When `save_exemplars` is set, every level is also
	written to a PNG file, with phase-specific names in multiphase mode.

	The resulting list (index 0 = full resolution) is appended to both the
	global `Exemplars` and `Exemplar_set_single` lists; nothing is returned.
	"""
	tic = timeit.default_timer()

	def report(level):
		#Print progress; the elapsed time is measured from the start of this call
		toc = timeit.default_timer()
		if multiphase:
			print('Finished building exemplar: plane %s of %s, level %s of %s, phases %s and %s. Time elapsed: %0.1f s.' % (plane+1,3,level,num_levels,phases[0],phases[1],(toc-tic)))
		else:
			print('Finished building exemplar: plane %s of %s, level %s of %s. Time elapsed: %0.1f s.' % (plane+1,3,level,num_levels,(toc-tic)))

	#Level 1: the full-resolution binary matrix of the cropped exemplar
	if multiphase:
		B1 = split_images(I, plane, phases)
	else:
		B1 = get_binary(I)

	full_exemplar_size = len(B1)
	print('The shape of B1')
	print(np.shape(B1))
	print(full_exemplar_size)

	#If the exemplar is not periodic, mark its boundary pixels with 3 (drawn in blue)
	if not is_exemplar_periodic:
		_mark_exemplar_boundary(B1)

	Exemplars_this_plane = [B1]

	if save_exemplars:
		if multiphase:
			_save_exemplar_image(B1, 'exemplar_%s_1_%s_%s.png' % (plane,phases[0],phases[1]))
		else:
			_save_exemplar_image(B1, 'exemplar_%s_1.png' % plane)
	report(1)

	if downsample_mode == 'binary':
		downsample = _binary_downsample
	elif downsample_mode == 'trinary':
		downsample = _trinary_downsample
	else:
		downsample = None                      #unrecognised mode: only level 1 is kept

	if downsample is not None:
		#Exemplars_this_plane[k-1] holds the exemplar at resolution level k (k = 1 is the finest)
		for k in range(2, num_levels+1):
			if downsample_mode == 'binary':
				np.random.seed()                #reseed from fresh entropy before breaking ties at this level
			current = downsample(Exemplars_this_plane[-1])

			#The boundary is taken from the matrix itself, so its size is always an integer
			if not is_exemplar_periodic:
				_mark_exemplar_boundary(current)

			if save_exemplars:
				#Phase-specific names keep multiphase runs from overwriting each other's images
				if multiphase:
					_save_exemplar_image(current, 'exemplar_%s_%s_phases_%s_%s.png' % (plane,k,phases[0],phases[1]))
				else:
					_save_exemplar_image(current, 'exemplar_%s_%s.png' % (plane,k))

			Exemplars_this_plane.append(current)
			report(k)

	Exemplars.append(Exemplars_this_plane)
	Exemplar_set_single.append(Exemplars_this_plane)
			
#This function builds Neighborhoods[] (either binary or trinary)
def _periodic_neighborhoods(M, NB_size):
	"""Return the flattened NB_size x NB_size neighbourhood of every pixel of the
	square matrix M, in row-major pixel order, wrapping around the edges of M."""
	size = len(M)
	offset = int((NB_size-1)/2)                #steps from the central pixel to the edge of the neighbourhood
	steps = range(-offset, offset+1)
	return [[M[(y+dy) % size][(x+dx) % size] for dy in steps for dx in steps]
			for y in range(size) for x in range(size)]


def _interior_neighborhoods(M, NB_size):
	"""Return the flattened NB_size x NB_size neighbourhoods of those pixels of the
	square matrix M whose neighbourhoods lie entirely inside M, in row-major order."""
	offset = int((NB_size-1)/2)
	first = offset                             #smallest x/y whose neighbourhood stays inside M
	last = int(len(M) - offset - 1)            #largest x/y whose neighbourhood stays inside M
	steps = range(-offset, offset+1)
	return [[M[y+dy][x+dx] for dy in steps for dx in steps]
			for y in range(first, last+1) for x in range(first, last+1)]


def build_neighborhoods(plane,phase):
	"""Build the exemplar neighbourhood lists for one plane (and, in multiphase mode, one phase).

	For every resolution level k = 1..num_levels, each NB_size x NB_size window
	of the level-k exemplar is flattened row by row into a list.  Periodic
	exemplars contribute one window per pixel, wrapping around the edges;
	otherwise only pixels whose window fits inside the exemplar are used.

	The level-k exemplar is read from Exemplar_set[plane][phase][k-1] when the
	global `multiphase` flag is set and from Exemplars[plane][k-1] otherwise.
	The per-level lists for this plane are appended to both the global
	Exemplar_Neighborhoods and Exemplar_Neighborhoods_single lists.
	"""
	tic = timeit.default_timer()

	Exemplar_Neighborhoods_plane = []

	for k in range(1, num_levels+1):
		#NB_sizes is indexed from the coarsest level: level k uses NB_sizes[num_levels-k]
		NB_size = NB_sizes[num_levels - k]

		#Both the periodic and the non-periodic paths must read the same exemplar source
		if multiphase:
			exemplar_k = Exemplar_set[plane][phase][k-1]
		else:
			exemplar_k = Exemplars[plane][k-1]

		if is_exemplar_periodic:
			Neighborhoods_k = _periodic_neighborhoods(exemplar_k, NB_size)
		else:
			Neighborhoods_k = _interior_neighborhoods(exemplar_k, NB_size)
		Exemplar_Neighborhoods_plane.append(Neighborhoods_k)

	Exemplar_Neighborhoods.append(Exemplar_Neighborhoods_plane)
	Exemplar_Neighborhoods_single.append(Exemplar_Neighborhoods_plane)

	#Report sizes with len() rather than np.shape(): levels differ in both the number
	#and the length of their neighbourhoods, so converting the nested lists to an array
	#is ragged (a ValueError on recent NumPy) and would copy every neighbourhood
	for level, neighborhoods in enumerate(Exemplar_Neighborhoods_plane, 1):
		width = len(neighborhoods[0]) if neighborhoods else 0
		print('Level %s: %s neighborhoods of %s pixels' % (level, len(neighborhoods), width))

	toc = timeit.default_timer()

	print('Finished building neighborhoods in %0.1f s' %(toc-tic))
	
#This function returns a matrix of twice the size of the original, upsampled appropriately (e.g., one red pixel becomes 4 red pixels)
def upsample(old_matrix, dim):
	"""Upsample a square (dim == 2) or cubic (dim == 3) matrix by a factor of two.

	Every value is replicated into a 2x2 (or 2x2x2) block.  Only the first
	len(old_matrix) entries along each axis are used, so the input is expected
	to be square/cubic.

	Returns a nested list for dim == 2 and a float64 numpy array for dim == 3.
	Raises ValueError for any other dim.
	"""
	size = len(old_matrix)
	if dim == 2:
		#New pixel (y, x) copies old pixel (y//2, x//2)
		return [[old_matrix[y//2][x//2] for x in range(2*size)] for y in range(2*size)]
	if dim == 3:
		if size == 0:
			return np.zeros((0, 0, 0))
		cube = np.asarray(old_matrix, dtype=float)[:size, :size, :size]
		#Repeat every voxel twice along each axis (z, y, x)
		return cube.repeat(2, axis=0).repeat(2, axis=1).repeat(2, axis=2)
	raise ValueError('upsample() supports dim == 2 or dim == 3, got %r' % (dim,))


#This function performs a 2D --> 2D reconstruction of the exemplar, anchored by the pixels given in the matrix anchored_values[] (of size full_recon_size)
def perform_3D_reconstruction(section_number, hist_queue, exemplar_queue, partial_queue, partial_locked_queue, pipes):

	#B1 = get_binary(im1)
	#full_exemplar_size = int(len(B1))
	
    
	start = int(section_number*section_size_large)         #This value refers to the y-location of recon[] where this particular section starts
        
	if section_number == num_sections-1:
		end = recon_size-1                             #This value refers to the y-location of recon[] where this particular section ends
	else:
		end = int((section_number+1)*section_size_large-1)
    
	if section_number == num_sections-1:
		section_size = section_size_small              #The size (specifically, in the y-dimension) of this section
	else:
		section_size = section_size_large

    #recon_this_section[] is a sub-section of recon[] of dimensions section_size*recon_size*recon_size
	recon_this_section = recon[:,start:end+1, :]

	weights_in_recon_this_section_padded = np.zeros((2*offset+recon_size, 2*offset+section_size, 2*offset+recon_size))
        
    #analogous sub-section of is_locked[]
	is_locked_this_section = is_locked[:,start:end+1, :]
        
    #There is a maximum number of list elements that can be passed between processors. A list may need to be chopped into N pieces and sent individually.
    #This will be done by sending (N-1) lists of length large_division, and then 1 of length small_division
	large_division = math.ceil((2*offset+recon_size)*(recon_size+2*offset)*offset/num_passes)
	small_division = (2*offset+recon_size)*(recon_size+2*offset)*offset - large_division*(num_passes-1)
        
	histogram_this_section = np.zeros((3, exemplar_size, exemplar_size))

	exemplar_histogram = np.zeros((3, exemplar_size, exemplar_size))

	#If the user wants the output of the reconstruction to be displayed on the screen
		

	global_error_list = []
	
	#Step through the number of reconstruction iterations
	for m in range(1,num_iterations+1):
		
		tic_iteration = timeit.default_timer()
		
		tic_0 = timeit.default_timer()
		
		recon_this_section_padded = np.pad(recon_this_section, ((offset,offset),(offset,offset),(offset,offset)), mode='wrap')
		
		steps = num_iterations*kk + m
		total_steps = num_levels*num_iterations
		percent_complete = steps / total_steps * 100
		#print('Here 1')
		#Ever steps_between_randomizations iterations, randomize fraction_to_randomize of the unlocked pixels
		if do_randomize:

			total_counts = []
			all_white_counts = []
			all_black_counts = []

			if m % steps_between_randomizations == 0 and m < num_iterations:

				black_count = 0
				white_count = 0
				
				for plane in range(3):
					for y1 in range(exemplar_size):
						for x1 in range(exemplar_size):
							if not multiphase:
								if Exemplars[plane][k-1][y1][x1] == 0:
									black_count += 1
								elif Exemplars[plane][k-1][y1][x1] == 1:
									white_count += 1
	
							else:
								if Exemplar_set[plane][phase][k-1][y1][x1] == 0:
									black_count += 1
								elif Exemplar_set[plane][phase][k-1][y1][x1] == 1:
									white_count += 1

					all_white_counts.append(white_count)
					all_black_counts.append(black_count)
					black_count = 0
					white_count = 0
					total_counts.append(all_white_counts[plane] + all_black_counts[plane])
				#print('counts shapes')
				print(np.shape(total_counts))
				print(np.shape(all_white_counts))
				print(np.shape(all_black_counts))

				for z in range(recon_size):
					for y in range(section_size):
						for x in range(recon_size):	
							if is_locked[z][y][x] == 0:
								change = np.random.random()
								if change < fraction_to_randomize:
									rnd = np.random.random()
									skewed_fraction = all_black_counts[kk]/total_counts[kk]
									if k == 1 or downsample_mode == 'binary':	
										if rnd < skewed_fraction:
											recon_this_section_padded[z][y][x] = 0
										else:
											recon_this_section_padded[z][y][x] = 1
									else:
										if rnd < 0.33:
											recon_this_section_padded[z][y][x] = 0
										elif rnd >= 0.66:
											recon_this_section_padded[z][y][x] = 1
										else:
											recon_this_section_padded[z][y][x] = 2
			#print('--------------------------------------')
			#print(recon[z])	
				tic_a = timeit.default_timer()	
					
		#recon_this_section_padded = np.pad(recon_this_section, offset, mode='wrap')
		#weights_in_recon_this_section = np.zeros((recon_size,section_size,recon_size))
		#weights_in_recon_padded = np.pad(weights_in_recon, offset, mode='wrap')
			#print('Here 2')
			for q in range(offset):	
				#Wrap in the x-direction
				recon_this_section_padded[:,:,(recon_size+offset+q)] =  recon_this_section_padded[:,:,(offset+q)]     
				recon_this_section_padded[:,:,(offset-q-1)] = recon_this_section_padded[:,:,(offset+recon_size-q-1)]
				#Wrap in the z-direction
				recon_this_section_padded[(recon_size+offset+q),:,:] =  recon_this_section_padded[(offset+q),:,:]     
				recon_this_section_padded[(offset-q-1),:,:] = recon_this_section_padded[(offset+recon_size-q-1),:,:] 

		tic_b = timeit.default_timer()
		

######### Exchange information between processes ##################################################################################################################################################
		if structure_type == 'cubic':
			tic0 = timeit.default_timer()
		
			upper = recon_this_section_padded[:,offset:2*offset, :].flatten()
			split_points = [i*large_division for i in range(1,num_passes)]
			upper_split = np.array_split(upper,split_points)
						
			lower = recon_this_section_padded[:,section_size:section_size+offset, :].flatten()
			lower_split = np.array_split(lower,split_points)	
			
			a1, a2 = pipes[section_number]
			b1, b2 = pipes[int((section_number+1)%num_sections)]
		
			temp_upper = []
			temp_lower = []
		
			for i in range(num_passes):
				a2.send(upper_split[i])
				b1.send(lower_split[i])
				temp_upper.append(a2.recv())
				temp_lower.append(b1.recv())
		
			temp_upper = np.array([j for i in temp_upper for j in i])
			temp_lower = np.array([j for i in temp_lower for j in i])
		
			new_upper = temp_upper.reshape((recon_size+2*offset, offset, recon_size+2*offset))
			new_lower = temp_lower.reshape((recon_size+2*offset, offset, recon_size+2*offset))
			
			recon_this_section_padded[:, 0:offset, :] = new_upper
			recon_this_section_padded[:, section_size+offset:section_size+2*offset, :] = new_lower			
			
			tic1 = timeit.default_timer()	



		#print('Here 3')

		for plane in range(3):
			filename = 'neighborhood_%s_section_%s' %(plane, section_number)
			fileObject = open(filename, 'wb')
			if plane == 0:
				Neighborhood1 = np.memmap(filename, shape=(section_size*(recon_size**2),NB_size**2))
			elif plane == 1:
				Neighborhood2 = np.memmap(filename, shape=(section_size*(recon_size**2),NB_size**2))
			elif plane == 2:
				Neighborhood3 = np.memmap(filename, shape=(section_size*(recon_size**2),NB_size**2))

		print(np.shape(Neighborhood1))
		print(np.shape(Neighborhood2))
		print(np.shape(Neighborhood3))


		counter1 = 0
		#Step through the pixels in the reconstruction
		for z in range(recon_size):
			for y in range(section_size):
				for x in range(recon_size):
					xp = x + offset
					yp = y + offset
					zp = z + offset
					#Create a list for the neighborhood centered at this pixel     	
					#Step through the pixels in this neighborhood
                    ######################################################## Code runs to issues at this step on 3 level #####################################
					Neighborhood1[counter1,:] = recon_this_section_padded[zp,(yp-offset):(yp+offset+1), (xp-offset):(xp+offset+1)].ravel() # 
					Neighborhood2[counter1,:] = recon_this_section_padded[(zp-offset):(zp+offset+1), yp, (xp-offset):(xp+offset+1)].ravel()
					Neighborhood3[counter1,:] = recon_this_section_padded[(zp-offset):(zp+offset+1), (yp-offset):(yp+offset+1), xp].ravel()
					counter1 += 1
		#print('----------------------------------')
		#print(recon_padded[z])		
		tic_c = timeit.default_timer()		
		print('Time to create neighborhood lists: %0.1f s' %(tic_c - tic_b))	
		#Perform the nearest neighbor search among the exemplar neighborhoods
		nn_index = []
		dists = []

		#filename = 'neighborhoods'
		#fileObject = open(filename, 'wb')
		#Neighborhoods = np.memmap(filename, shape=(3,section_size*(recon_size**2),NB_size**2))

		print(np.shape(Neighborhood1))
		#for plane in range(3):
		#	if plane == 0:
		#		Neighborhoods[0] = Neighborhood1
		#	elif plane == 1:
		#		Neighborhoods[1] = Neighborhood2
		#	elif plane == 2:
		#		Neighborhoods[2] = Neighborhood3
		Neighborhoods = np.stack((Neighborhood1, Neighborhood2, Neighborhood3))
		#print('Here 4')

		#filename = 'testset'
		#fileObject = open(filename, 'wb')
		#testset = np.memmap(filename, shape=(section_size*(recon_size**2),NB_size**2))

		print(np.shape(Neighborhoods))
		for plane in range(3):
			#new = Neighborhoods[plane]
			#print(np.shape(Neighborhoods[plane]))
			#Neighborhood = Neighborhoods[plane].reshape((recon_size**3,NB_size**2))
			#print(np.shape(Neighborhood))
			testset = Neighborhoods[plane].astype(np.int32)
			print(np.shape(testset))
			flann_this_plane = flann[plane]
			nn_index_temp, dists_temp = flann_this_plane.nn_index(testset, num_neighbors = neighbor_choices, checks=params[plane]["checks"])
			nn_index.append(nn_index_temp)
			dists.append(dists_temp)
					
		global_error = 0
		for j in range(recon_size*section_size*recon_size):
			for plane in range(3):
				global_error += dists[plane][j]/(recon_size**3*NB_size**2)
		global_error_list.append(global_error)
		
		histogram_this_section.fill(0)	
		#Update the exemplar histogram and weights histogram for this level and iteration
		#print(np.shape(histogram))
		for j in range(recon_size*section_size*recon_size):
			for plane in range(3):
				if neighbor_choices == 1:
					index_in_exemplar = nn_index[plane][j]
				else:
					nearest_index = np.array([abs(j-nn_index[plane][j][i]) for i in range(neighbor_choices)]).argmin()
					index_in_exemplar = nn_index[plane][j][nearest_index]
				if is_exemplar_periodic:
					x = index_in_exemplar % exemplar_size
					y = int((index_in_exemplar - x)/exemplar_size)
				else:
					a = int((NB_size-1)/2)
					b = exemplar_size - a - 1
					x = a + index_in_exemplar % (b-a+1)
					y = a + int((index_in_exemplar-x+a)/(b-a+1))
				histogram_this_section[plane][y][x] += 1
		#print('Here 5')
		for plane in range(3):
			hist_queue.put(histogram_this_section[plane])
		#hist_queue.join()
		for plane in range(3):
			#print('Here')
			exemplar_histogram[plane] = exemplar_queue.get()
			exemplar_queue.task_done()
		

		#keep_running = True
		#signal_queue.put(keep_running)
		#pipes_histogram[section_number][1].send(histogram_this_section)

		#pipes_histogram[section_number][1].close()

		for y in range(exemplar_size):
			for x in range(exemplar_size):
				for plane in range(3):
				#if use_histogram_reweighting:
					val = exemplar_histogram[plane][y][x]
					weights_in_exemplar[plane][y][x] /= (1.0 + max(0, val-(m-1)*recon_size**3*1.0/exemplar_size**2))
					
		tic_d = timeit.default_timer()	
		#print('Time to perform nearest neighborhood search and update weighted exemplar histogram: %0.1f s' %(tic_d - tic_c))
		#Zero all entries in recon_padded because now it will be used to hold votes	  
		recon_this_section_padded.fill(0)
		print(np.shape(recon_this_section))
		print(np.shape(weights_in_exemplar))	
		tic_d2 = timeit.default_timer()
		weights_in_recon_this_section_padded.fill(0.0)
		#Update the votes in this iteration of the reconstruction				
		for plane in range(3):
			for z in range(recon_size):			
				for y in range(section_size):
					for x in range(recon_size):
						index_in_recon = z*recon_size*section_size+y*recon_size+x
						if neighbor_choices == 1:
							index_in_dataset = nn_index[plane][index_in_recon]
						else:
							nearest_index = np.array([abs(index_in_recon-nn_index[plane][index_in_recon][i]) for i in range(neighbor_choices)]).argmin()
							index_in_dataset = nn_index[plane][index_in_recon][nearest_index]
						
						xp = x + offset
						yp = y + offset
						zp = z + offset
						if is_exemplar_periodic:
							x_in_exemplar = index_in_dataset % exemplar_size
							y_in_exemplar = int((index_in_dataset - x_in_exemplar)/exemplar_size)	
						else:
							a = int((NB_size-1)/2)
							b = exemplar_size - a - 1
							x_in_exemplar = a + index_in_dataset % (b-a+1)
							y_in_exemplar = a + int((index_in_dataset-x_in_exemplar+a)/(b-a+1))		
						nn = dataset[plane][index_in_dataset].reshape((NB_size,NB_size))
						weighted_nn = nn * weights_in_exemplar[plane][y_in_exemplar][x_in_exemplar]
						
						if plane == 0:
							recon_this_section_padded[zp,(yp-offset):(yp+offset+1),(xp-offset):(xp+offset+1)] += weighted_nn
							weights_in_recon_this_section_padded[zp,(yp-offset):(yp+offset+1),(xp-offset):(xp+offset+1)] += weights_in_exemplar[plane][y_in_exemplar][x_in_exemplar]

						if plane == 1:
							recon_this_section_padded[(zp-offset):(zp+offset+1), yp, (xp-offset):(xp+offset+1)] += weighted_nn
							weights_in_recon_this_section_padded[(zp-offset):(zp+offset+1), yp, (xp-offset):(xp+offset+1)] += weights_in_exemplar[plane][y_in_exemplar][x_in_exemplar]

						if plane == 2:
							recon_this_section_padded[(zp-offset):(zp+offset+1),(yp-offset):(yp+offset+1),xp] += weighted_nn
							weights_in_recon_this_section_padded[(zp-offset):(zp+offset+1),(yp-offset):(yp+offset+1), xp] += weights_in_exemplar[plane][y_in_exemplar][x_in_exemplar]
		

		tic_d3 = timeit.default_timer()
		print('Time to update votes: %0.1f s' %(tic_d3 - tic_d2))
		
		for q in range(offset):		
			#Wrap in x-direction
		
			recon_this_section_padded[:,:,(offset+q)] += recon_this_section_padded[:,:,(recon_size+offset+q)]
			recon_this_section_padded[:,:,(recon_size+offset+q)].fill(0)
			recon_this_section_padded[:,:,(offset+recon_size-q-1)] += recon_this_section_padded[:,:,(offset-q-1)]
			recon_this_section_padded[:,:,(offset-q-1)].fill(0)
		
			weights_in_recon_this_section_padded[:,:,(offset+q)] += weights_in_recon_this_section_padded[:,:,(recon_size+offset+q)] 
			weights_in_recon_this_section_padded[:,:,(recon_size+offset+q)].fill(0)
			weights_in_recon_this_section_padded[:,:,(offset+recon_size-q-1)] += weights_in_recon_this_section_padded[:,:,(offset-q-1)]
			weights_in_recon_this_section_padded[:,:,(offset-q-1)].fill(0)
			
			#Wrap in z-direction
			
			recon_this_section_padded[(offset+q),:,:] += recon_this_section_padded[(recon_size+offset+q),:,:]
			recon_this_section_padded[(recon_size+offset+q),:,:].fill(0)
			recon_this_section_padded[(offset+recon_size-q-1),:,:] += recon_this_section_padded[(offset-q-1),:,:]
			recon_this_section_padded[(offset-q-1),:,:].fill(0)
		
			weights_in_recon_this_section_padded[(offset+q),:,:] += weights_in_recon_this_section_padded[(recon_size+offset+q),:,:]
			weights_in_recon_this_section_padded[(recon_size+offset+q),:,:].fill(0)
			weights_in_recon_this_section_padded[(offset+recon_size-q-1),:,:] += weights_in_recon_this_section_padded[(offset-q-1),:,:]
			weights_in_recon_this_section_padded[(offset-q-1),:,:].fill(0)
		
		tic_e = timeit.default_timer()		

######### Exchange weighted vote sums between processes ###########################################################################################################################################
		if structure_type == 'cubic':
			#print('Got here')
			upper = recon_this_section_padded[:, 0:offset, :].flatten()
			recon_this_section_padded[:, 0:offset, :].fill(0)
			split_points = [i*large_division for i in range(1,num_passes)]
			upper_split = np.array_split(upper,split_points)
					
			lower = recon_this_section_padded[:, offset+section_size:section_size+2*offset, :].flatten()
			recon_this_section_padded[:, offset+section_size:section_size+2*offset, :].fill(0)
			lower_split = np.array_split(lower,split_points)
		
			temp_upper = []
			temp_lower = []
		
			for i in range(num_passes):
				a2.send(upper_split[i])
				b1.send(lower_split[i])
				temp_upper.append(a2.recv())
				temp_lower.append(b1.recv())
		
			temp_upper = np.array([j for i in temp_upper for j in i])
			temp_lower = np.array([j for i in temp_lower for j in i])
		
			new_upper = temp_upper.reshape((recon_size+2*offset, offset, recon_size+2*offset))
			new_lower = temp_lower.reshape((recon_size+2*offset, offset, recon_size+2*offset))
			
			recon_this_section_padded[:, offset:2*offset, :] += new_upper
			recon_this_section_padded[:, section_size:section_size+offset, :] += new_lower		
		
		######### Exchange summed weights between processes ###############################################################################################################################################
			#print('Got farther')	
			upper = weights_in_recon_this_section_padded[:, 0:offset, :].flatten()
			weights_in_recon_this_section_padded[:, 0:offset, :].fill(0)
			split_points = [i*large_division for i in range(1,num_passes)]
			upper_split = np.array_split(upper,split_points)
						
			lower = weights_in_recon_this_section_padded[:, offset+section_size:section_size+2*offset, :].flatten()
			weights_in_recon_this_section_padded[:, offset+section_size:section_size+2*offset, :].fill(0)
			lower_split = np.array_split(lower,split_points)
		
			temp_upper = []
			temp_lower = []
		
			for i in range(num_passes):
				a2.send(upper_split[i])
				b1.send(lower_split[i])
				temp_upper.append(a2.recv())
				temp_lower.append(b1.recv())
		
			temp_upper = np.array([j for i in temp_upper for j in i])
			temp_lower = np.array([j for i in temp_lower for j in i])
		
			new_upper = temp_upper.reshape((recon_size+2*offset, offset, recon_size+2*offset))
			new_lower = temp_lower.reshape((recon_size+2*offset, offset, recon_size+2*offset))
			
			weights_in_recon_this_section_padded[:, offset:2*offset, :] += new_upper
			weights_in_recon_this_section_padded[:, section_size:section_size+offset, :] += new_lower									
				
			tic9 = timeit.default_timer()		
				
		######### Tally votes #############################################################################################################################################################################
		#keep_running = True
		#signal_queue.put(keep_running)
		#print('Time to tally votes')
		recon_this_section_padded[offset:(offset+recon_size), offset:(offset+section_size), offset:(offset+recon_size)] /= weights_in_recon_this_section_padded[offset:(offset+recon_size), offset:(offset+section_size), offset:(offset+recon_size)]
		#print('--------------------------')
		#print(recon_padded[z])
		#print('------------------------------')
		#print(recon[z])
		for z in range(recon_size):
			for y in range(section_size):
				for x in range(recon_size):
					xp = x + offset
					yp = y + offset
					zp = z + offset
					avg_vote = recon_this_section_padded[zp,yp,xp]
					#print(avg_vote)
				#shade = math.floor(avg_vote*255)
				#pix_iteration[x,y] = (shade,shade,shade)
					if is_locked[z][y][x] == 0:
						if k == 1 or downsample_mode == 'binary':
							if avg_vote < 0.5:
								recon_this_section[z][y][x] = 0
							#pix_iteration[x,y] = (0,0,0) 
							elif avg_vote > 0.5:
								recon_this_section[z][y][x] = 1
							#pix_iteration[x,y] = (255,255,255) 
							else:
								rnd = np.random.random()
								if rnd < 0.5:
									recon_this_section[z][y][x] = 0
								#pix_iteration[x,y] = (0,0,0) 
								else:
									recon_this_section[z][y][x] = 1
								#pix_iteration[x,y] = (255,255,255) 								
						else:
							if avg_vote < 0.66:
								recon_this_section[z][y][x] = 0
							#pix_iteration[x,y] = (0,0,0) 
							elif avg_vote >= 1.33:
								recon_this_section[z][y][x] = 2
							#pix_iteration[x,y] = (255,0,0) 
							else:
								recon_this_section[z][y][x] = 1
							#pix_iteration[x,y] = (255,255,255)
		#print('------------------------------')
		#print(recon[z])

		if output_option == 1:
		#if show_all and output_option == 1:
			partial_recon = np.zeros((recon_size,recon_size,recon_size))
			partial_recon[:,start:end+1,:] += recon_this_section
			partial_queue.put(partial_recon)
			partial_queue.join()
		#if output_option == 1:
		#	error_queue.put(global_error)
		#	error_queue.join()

		#print('Got here too')
			#partial_queue.task_done()
			#pipes_partial_recon[section_number][1].send(partial_recon_section)
		#partial_recon = np.zeros((recon_size,recon_size,recon_size))
		#partial_recon[:,start:end+1,:] += recon_this_section
		

		#filename = 'partial_recon_section_%s' %section_number
		#fileObject = open(filename, 'wb')
		#partial_recon_section = np.memmap(filename, shape=(recon_size,recon_size,recon_size))
		#partial_recon_section[:,:,:] = partial_recon[:,:,:]

		#recon_array.append(partial_recon_section)
		#print(np.shape(recon_array))
		#filename = 'recon_array'
		#fileObject = open(filename, 'wb')
		#partial_recon_section = np.memmap(filename, shape=(recon_size,recon_size,recon_size))
		#partial_recon_section[:,:,:] = partial_recon[:,:,:]
		
		

		#pipes_partial_recon[section_number][1].close()

		

		#tic_f = timeit.default_timer()
		#print('Time to prepare the plots: %0.1f s' %(tic_f - tic_d3))
		'''print('0...a = %f' % (tic_a - tic_0))
		print('a...b = %f' % (tic_b - tic_a))
		print('b...c = %f' % (tic_c - tic_b))
		print('c...d = %f' % (tic_d - tic_c))
		print('d2...d3 = %f' % (tic_d3 - tic_d2))
		print('d3...e = %f' % (tic_e - tic_d3))
		print('e...f = %f' % (tic_f - tic_e))'''
		
		#compile_reconstruction_and_build_plots()

		#print('-------------------------------------')
		#midplane = int(math.floor(recon_size/2))
		#print(recon[midplane][y][x])
		#print('-------------------------------------')
		#print(recon[y][midplane][x])
		#print('-------------------------------------')
		#print(recon[y][x][midplane])

		
		toc_iteration = timeit.default_timer()
			
		print('Time for iteration number %s: %f' % (m, toc_iteration-tic_iteration))
		#times.append(toc_iteration-tic_iteration)
	#Update is_locked[] by changing each pixel which is not red (has value 2 in recon[]) to 1

	partial_recon = np.zeros((recon_size,recon_size,recon_size))
	partial_recon[:,start:end+1,:] += recon_this_section
	partial_queue.put(partial_recon)
	partial_queue.join()

	if use_lock:
		for z in range(recon_size):
			for y in range(section_size):
				for x in range(recon_size):
					if recon_this_section[z][y][x] != 2:
						is_locked_this_section[z][y][x] = 1
                                

	partial_is_locked = np.zeros((recon_size,recon_size,recon_size))
	partial_is_locked[:,start:end+1,:] += is_locked_this_section
	partial_locked_queue.put(partial_is_locked)
	partial_locked_queue.join()
	
	#print('The average time for each iteration was: %f' %np.mean(times, dtype=np.float64))
	#print('The maximum time was: %f' %max(times))
	#print('The minimum time was: %f' %min(times))
	#time.sleep(15)

#******************************************************************************************************************************************************************
	
import sys                               #sys.exit() is used throughout; make the dependency explicit instead of relying on star imports
import numbers                           #numbers.Real accepts int/float (and numpy scalars) in the numeric checks below

#*****GLOBAL STATE AND INPUT VALIDATION*****
#Module-level containers. They start empty here and are presumably filled by
#build_exemplars()/build_neighborhoods() (defined earlier in this file) -- TODO confirm.
Exemplars = []                           #Holds the exemplar at each resolution level (binary or trinary)
Exemplar_set = []
Exemplar_set_single = []
Exemplar_Neighborhoods = []              #Holds all neighborhoods of the exemplar (binary or trinary) at each resolution level
Exemplar_Neighborhoods_single = []
Exemplar_Neighborhoods_set = []
global_error_list = []
all_recons = []

#Every setting below is read from the user-editable 'inputs' module and checked before any work starts.
#An invalid setting prints an explanatory message and terminates the script with sys.exit().
#Type checks always run BEFORE arithmetic/comparisons so that a wrong type produces the message, not a TypeError.

image1 = inputs.image1
image2 = inputs.image2
image3 = inputs.image3
if type(image1) is not str or type(image2) is not str or type(image3) is not str:
	print('At least one of your images is an invalid entry')
	sys.exit()
if not (Path(image1).is_file() and Path(image2).is_file() and Path(image3).is_file()):
	print('At least one of your images does not exist in this directory')
	sys.exit()

structure_type = inputs.structure_type
if type(structure_type) is not str or structure_type not in ('cubic', 'plate'):
	print('Your structure type is invalid. You must enter either "cubic" or "plate"')
	sys.exit()

multiphase = inputs.multiphase
if not isinstance(multiphase, bool):
	print('You must indicate whether the exemplar is to be reconstructed with multiple phases')
	sys.exit()

is_exemplar_periodic = inputs.is_exemplar_periodic              #Select True if the exemplar should be treated as having periodic boundary conditions
if not isinstance(is_exemplar_periodic, bool):
	print('You must choose whether the exemplar is periodic, using True or False')
	sys.exit()

downsample_mode = inputs.downsample_mode				 #The options are 'trinary' and 'binary'
#Compare by value: 'is not' tested object identity and could reject an equal but non-interned string
if downsample_mode not in ('binary', 'trinary'):
	print('That is not a valid downsampling method')
	sys.exit()

save_exemplars = inputs.save_exemplars					 #Select True to save the exemplars at each resolution level as PNG images
if not isinstance(save_exemplars, bool):
	print('You must choose whether or not to save the exemplars and their downsampled versions, using True or False. I recommend True')
	sys.exit()

num_levels = inputs.num_levels           #Number of downsamplings to perform (i.e. number of resolution levels)
print('num levels')
print(num_levels) 
if type(num_levels) is not int:
	print('Please use an integer for the number of resolution levels')
	sys.exit()

if num_levels > 8:
	print('You might want to use fewer levels. The quality of the reconstruction will not appreciably change with that many, and will take more time than it is worth')
	sys.exit()

NB_sizes = inputs.NB_sizes #Side length of a square neighborhood, from lowest resolution to highest resolution
print('NB_sizes')
print(NB_sizes)
for nb_size in NB_sizes:
	if type(nb_size) is not int:
		print('Neighborhoods can only be of integer size')
		sys.exit()
#The reconstruction indexes NB_sizes[kk] for kk in range(num_levels); fail early instead of with an IndexError later
if len(NB_sizes) < num_levels:
	print('You must provide one neighborhood size per resolution level (NB_sizes needs at least num_levels entries)')
	sys.exit()

Num_sections = inputs.Num_sections
print('Num sections')
print(Num_sections)
for section_count in Num_sections:
	if type(section_count) is not int:
		print('You can only use an integer number of sections')
		sys.exit()
#For cubic structures the reconstruction indexes Num_sections[kk] for kk in range(num_levels)
if structure_type == 'cubic' and len(Num_sections) < num_levels:
	print('You must provide one number of sections per resolution level (Num_sections needs at least num_levels entries)')
	sys.exit()

neighbor_choices = inputs.neighbor_choices
print('neighbor choices')
print(neighbor_choices)
if type(neighbor_choices) is not int:
	print('You cannot choose a non-integer number of neighborhood choices')
	sys.exit()

#Full reconstruction size; it is divided by 2**(k-1) at each level, so it should be divisible by
#2**(num_levels-1). Only evenness is enforced here.
full_recon_size = inputs.full_recon_size
print('recon size')
print(full_recon_size)
if type(full_recon_size) is not int or full_recon_size % 2 != 0:
	print('The full reconstruction size must be an even integer')
	sys.exit()

num_iterations = inputs.num_iterations #Number of iterations of the reconstruction for each resolution level
print('iterations')
print(num_iterations)
if type(num_iterations) is not int:
	print('You can only have an integer number of iterations')
	sys.exit()

use_lock = inputs.use_lock                         #Set to True when non-red pixels should be locked at the end of each level
if not isinstance(use_lock, bool):
	print('You must choose whether to use the locking scheme, using True or False. I recommend False')
	sys.exit()

do_randomize = inputs.do_randomize
print('randomize?')
print(do_randomize)
if not isinstance(do_randomize, bool):
	print('You must choose whether to include the randomization condition, using True or False. I recommend True')
	sys.exit()

steps_between_randomizations = inputs.steps_between_randomizations       #Every this many iterations...
print('steps between randoms')
print(steps_between_randomizations)
if type(steps_between_randomizations) is not int:
	print('You cannot randomize at non-integer steps')
	sys.exit()

fraction_to_randomize = inputs.fraction_to_randomize             #This fraction of the unlocked pixels are randomized
print('fracton to randomize')
print(fraction_to_randomize)
if not isinstance(fraction_to_randomize, numbers.Real) or fraction_to_randomize < 0.0:
	print('Please choose a decimal between 0 and 1.0, inclusive, for the fraction of pixels to randomize')
	sys.exit()
if fraction_to_randomize > 1.0:
	print('You cannot randomize more pixels than are in the reconstruction. Please choose a decimal between 0 and 1.0, inclusive')
	sys.exit()

steps_between_plotting = inputs.steps_between_plotting               #Only display/save the reconstruction on the screen/to file this many iterations, depending on the following option...
if type(steps_between_plotting) is not int:
	print('You cannot plot at non-integer steps')
	sys.exit()

output_option = inputs.output_option                        #1 = show output on screen, 2 = print output to PNG images, otherwise = no output
#Any integer is accepted ("otherwise = no output"); a non-integer is accepted only if it equals 1 or 2
if type(output_option) is not int and output_option not in (1, 2):
	print('You have not chosen a valid output option- you must pick 1 or 2. I recommend 1')
	sys.exit()

use_histogram_reweighting = inputs.use_histogram_reweighting         #Select True to use exemplar histogram reweighting to avoid lopsided sampling of the exemplar neighborhoods
if not isinstance(use_histogram_reweighting, bool):
	print('You must choose whether to use the histogram reweighting scheme, using True or False. I highly recommend True')
	sys.exit()

flann_precision = inputs.flann_precision  #1 = nearest neighbors are exact, not approximate. This will slow down the reconstruction.
print('flann precision')
print(flann_precision)
#Check the type first (comparing a non-number with 1.0 raises TypeError); accept an int such as 1 and normalize to float
if not isinstance(flann_precision, numbers.Real) or isinstance(flann_precision, bool) or flann_precision < 0.0:
	print('Please enter a decimal between 0 and 1.0, inclusive')
	sys.exit()
flann_precision = float(flann_precision)
if flann_precision > 1.0:
	print('You cannot use a precision greater than 1.0, since 1.0 means FLANN will search for exact neighborhood matches. Choose a decimal between 0 and 1.0, inclusive')
	sys.exit()

path = inputs.full_path
if type(path) is not str:
	print('The path you entered is not valid')
	sys.exit()
if not Path(path).is_dir():
	print('The path you entered does not exist')
	sys.exit()

dir_name = inputs.new_folder_name
if type(dir_name) is not str:
	print('The folder you entered is not valid')
	sys.exit()
	
#*****START THE RECONSTRUCTION*************************************************************************************************************************************
	
#Load the three exemplars (one per plane) as 8-bit grayscale ('L') images.
#'with' closes each file handle as soon as the converted copy has been made.
with Image.open(image1) as _exemplar_file:
	img1 = _exemplar_file.convert('L')     #Original, high-resolution black-and-white image of the two-phase microstructure
with Image.open(image2) as _exemplar_file:
	img2 = _exemplar_file.convert('L')
with Image.open(image3) as _exemplar_file:
	img3 = _exemplar_file.convert('L')


if structure_type == 'cubic':
	#Crop each exemplar so that it is square and has dimensions compatible with the downsamplings,
	#then crop all three to the smallest common side length
	im1 = resize_exemplar(img1, num_levels)
	im2 = resize_exemplar(img2, num_levels)
	im3 = resize_exemplar(img3, num_levels)

	equal_size = min([im1.size[0], im1.size[1], im2.size[0], im2.size[1], im3.size[0], im3.size[1]])
	im1 = im1.crop((0,0,equal_size,equal_size))
	im2 = im2.crop((0,0,equal_size,equal_size))
	im3 = im3.crop((0,0,equal_size,equal_size))

elif structure_type == 'plate':
	#im1..im3 must exist before they are cropped below. They were previously only assigned in the
	#'cubic' branch, so this branch always failed with a NameError.
	im1 = img1
	im2 = img2
	im3 = img3

	size1 = img1.size[1]
	size2 = img2.size[1]
	size3 = img3.size[1]
	print(size2)
	print(size3)
	diff1 = size1 % size2
	diff2 = size1 % size3

	if diff1 != 0 or diff2 != 0:
		if diff1 < diff2:
			im3 = im3.crop((0,0,img3.size[0],img3.size[1]-(diff2-diff1)))
		elif diff1 > diff2:      #was 'diff2 > diff1', identical to the test above and therefore unreachable
			im2 = im2.crop((0,0,img2.size[0],img2.size[1]-(diff1-diff2)))

		diff_new = im1.size[1] % im2.size[1]
		if diff_new != 0:
			im1 = resize_exemplar(img1, num_levels)

	#Trim the two cross-section exemplars to the width of the plate exemplar
	im2 = im2.crop((0,0,im1.size[0],im2.size[1]))
	im3 = im3.crop((0,0,im1.size[0],im3.size[1]))

	#Stack copies of each cross-section vertically until it is at least as tall as im1, then cut it to im1's height.
	#scale2 is also read later as the number of sections for 'plate' structures.
	scale2 = math.ceil(im1.size[1]/im2.size[1])
	print(scale2)
	scale3 = math.ceil(im1.size[1]/im3.size[1])
	print(np.shape(im2))
	im2 = np.tile(im2,(scale2,1))
	im2 = im2[0:im1.size[1],:]
	im3 = np.tile(im3,(scale3,1))
	print(np.shape(im2))
	im3 = im3[0:im1.size[1],:]
	print(np.shape(im3))

	img2 = Image.fromarray(im2, 'L')
	img3 = Image.fromarray(im3, 'L')

	#Crop the exemplars so that they are square and have dimensions compatible with the downsamplings
	im1 = resize_exemplar(img1, num_levels)
	im2 = resize_exemplar(img2, num_levels)
	im3 = resize_exemplar(img3, num_levels)

#Save the cropped exemplars actually used for the reconstruction (im1..im3). Previously the uncropped img1..img3 were written.
im1.save('exemplar_1.png')
im2.save('exemplar_2.png')
im3.save('exemplar_3.png')

#Build the exemplar pyramids and neighborhood datasets for the three planes.
#The module-level *_single lists are rebound to fresh lists per plane; they are presumably
#appended to by build_exemplars()/build_neighborhoods() (defined earlier in this file) -- TODO confirm.
if multiphase:
	num_phases = max(len(np.unique(im1)),len(np.unique(im2)),len(np.unique(im3)))
	#Every pair of distinct gray values present in each exemplar
	two_phases_im1 = list(itertools.combinations(np.unique(im1),2))
	two_phases_im2 = list(itertools.combinations(np.unique(im2),2))
	two_phases_im3 = list(itertools.combinations(np.unique(im3),2))
	all_phases = [two_phases_im1,two_phases_im2,two_phases_im3]

	plane_images = [im1, im2, im3]
	for plane in range(3):
		Exemplar_set_single = []
		two_phases = all_phases[plane]
		if plane == 0:
			print(two_phases)
		for phase in range(len(two_phases)):
			print(phase)
			val1 = two_phases[phase][0]
			val2 = two_phases[phase][1]
			build_exemplars(plane_images[plane],plane,two_phases[phase])
			if plane == 1:
				print('Exemplars shape')
				print(np.shape(Exemplars))
		Exemplar_set.append(Exemplar_set_single)
		print(np.shape(Exemplar_set))

	group = list(range(num_phases))      #group[i] == i for every phase index

	for plane in range(3):
		Exemplar_Neighborhoods_single = []
		for phase_index in group:
			build_neighborhoods(plane,phase_index)		#Build Neighborhoods[]
		print(group)

		Exemplar_Neighborhoods_set.append(Exemplar_Neighborhoods_single)
		print('Exemplar_Neighborhoods_set shape')
		print(np.shape(Exemplar_Neighborhoods_set))

else:
	binary1 = list(itertools.combinations(np.unique(im1),2))
	build_exemplars(im1,0,binary1)                    #Build Exemplars[]
	binary2 = list(itertools.combinations(np.unique(im2),2))
	build_exemplars(im2,1,binary2)
	binary3 = list(itertools.combinations(np.unique(im3),2))
	build_exemplars(im3,2,binary3)

	num_phases = max(len(binary1),len(binary2),len(binary3))

	build_neighborhoods(0,0)	                 #Build Neighborhoods[]
	build_neighborhoods(1,1)
	build_neighborhoods(2,2)

sizes = [im1.size[0], im1.size[1], im2.size[0], im2.size[1], im3.size[0], im3.size[1]]
full_exemplar_size = min(sizes)

print('Exemplars shape')
print(np.shape(Exemplars))
print(im1.size)
print(im2.size)
print(im3.size)
if not multiphase:
	#binary1..binary3 exist only in the single-phase branch; printing them unconditionally raised NameError when multiphase was True
	print(binary1)
	print(binary2)
	print(binary3)
print(num_phases)

for phase in range(num_phases):
	for kk in range(num_levels):
		k = num_levels - kk
		recon_size = int(full_recon_size/2**(k-1))
		exemplar_size = int(full_exemplar_size/2**(k-1))
		NB_size = NB_sizes[kk]
		group = [[] for x in range(num_levels)]
		for x in range(len(group)):
			group[x] = num_phases*kk+x
		
	
		if structure_type == 'cubic':
			num_sections = Num_sections[kk]
		elif structure_type == 'plate':
			num_sections = scale2
		
		#num_sections = Num_sections[k-1]
		
		#number of steps from the central pixel of the neighborhood to the edge of the neighborhood
		offset = int((NB_size-1)/2)
	       
		if is_exemplar_periodic == False:
			n_first = 1
			n_last = recon_size-2
		else:
			n_first = 0
			n_last = recon_size-1

		#A value of 1 in is_locked[] means that a particular pixel will not be changed again for the rest of the reconstruction
	#histogram[] holds a histogram for the exemplar at this resolution tallying the number of times each pixel's neighborhood
	#is selected as a nearest neighbor during the FLANN searches through the reconstruction's neighborhoods
		if k == num_levels:
			recon = np.zeros((recon_size,recon_size,recon_size))
			is_locked = np.zeros((recon_size,recon_size,recon_size))
		else:
			recon = upsample(recon,3)
			is_locked = upsample(is_locked,3)
	#       histogram = np.zeros((3,exemplar_size,exemplar_size))
	#       histogram.fill(0)
	#weights_in_exemplar.fill(1.0)
		weights_in_exemplar = np.zeros((3,exemplar_size,exemplar_size))
		weights_in_exemplar.fill(1.0)


	    #Randomize the microstructure, but do not change the values of pixels that are locked
		if use_lock:
			for z in range(n_first, n_last+1):
				for y in range(n_first, n_last+1):
					for x in range(n_first, n_last+1):
						if is_locked[z][y][x] == 0:
							rnd = np.random.random()
							if k == 1 or downsample_mode == 'binary':
								if rnd < 0.5:
									recon[z][y][x] = 0
								else:
									recon[z][y][x] = 1
							else:
								if rnd < 0.33:
									recon[z][y][x] = 0
								elif rnd >= 0.66:
									recon[z][y][x] = 2
								else:
									recon[z][y][x] = 1

		elif use_lock == False and kk == 0:
			for z in range(n_first, n_last+1):
				for y in range(n_first, n_last+1):
					for x in range(n_first, n_last+1):
						rnd = np.random.random()
						if downsample_mode == 'binary':
							if rnd < 0.5:
								recon[z][y][x] = 0
							else:
								recon[z][y][x] = 1
						else:
							if rnd < 0.33:
								recon[z][y][x] = 0
							elif rnd >= 0.66:
								recon[z][y][x] = 2
							else:
								recon[z][y][x] = 1

	    #Build the nearest neighbor index
		flann = []
		dataset = []
		params = []
		tic000 = timeit.default_timer()

		for plane in range(3):
			flann_this_plane = FLANN()
			flann.append(flann_this_plane)
			if multiphase:
				dataset.append(np.matrix(Exemplar_Neighborhoods_set[phase][plane][k-1]).astype(np.int32))
			else:
				dataset.append(np.matrix(Exemplar_Neighborhoods[plane][k-1]).astype(np.int32))
	    
	    #A value of 1 for the target_precision means that the nearest neighbor searches are exact.
	    #The reconstruction will take longer but will be more accurate.
	    
	    #For some reason I have to remove the algorithm="autotuned" parameter to prevent this from crashing
			params.append(flann[plane].build_index(dataset[plane], target_precision=flann_precision))
		tic111 = timeit.default_timer()
		
		print('*******************************************************************************')
		print('Time to build the 3 FLANN indices: %0.1f s' % (tic111-tic000))
		print('*******************************************************************************')
		
	    #There is an upper limit to the size of the pipe buffer. Therefore, whatever is to be passed between
	    #processes must be broken into smaller pieces. The maximum list size appears to be around 33,400 elements.
		num_passes = math.ceil(recon_size*(recon_size+2*offset)*offset/33400)
		
	    #The reconstruction will be divided (vertically) into num_sections sections. The last section will have length
	    #section_size_small, while all num_section-1 others will have length section_size_large
		section_size_large = int(math.ceil(recon_size/num_sections))
		section_size_small = int(recon_size - section_size_large*(num_sections-1))
		
		if section_size_large < offset:
			max_num_sections = int(math.floor(2*recon_size/(NB_size-1)))
			print('Please decrease num_sections so that section_size_large > offset. In your case, the maximum allowable num_sections is: %d' % max_num_sections)
			sys.exit()

	    #Construct the pipes (connection objects) between processes. Each section of the reconstruction will be owned by a different process.

		if __name__ == '__main__':

			hist_queue = JoinableQueue()
			exemplar_queue = JoinableQueue()
			partial_queue = JoinableQueue()
			partial_locked_queue = JoinableQueue()
			pipes = []
			for i in range(num_sections):
				pipes.append(Pipe())

			#pool = mp.Pool(num_sections)
			#args = np.arange(num_sections)
			#pool.map(perform_3D_reconstruction, args, chunksize=1)
			#Do the reconstruction
			processes = [Process(target=perform_3D_reconstruction, args=(i, hist_queue, exemplar_queue, partial_queue, partial_locked_queue, pipes)) for i in range(num_sections)]

			for p in processes:
				p.start()

			three_exemplar_histograms = np.zeros((3,exemplar_size,exemplar_size))
			exemplar_histogram_full = np.zeros((3,exemplar_size,exemplar_size))
			recon = np.zeros((recon_size,recon_size,recon_size))
			is_locked = np.zeros((recon_size,recon_size,recon_size))

			if output_option == 1 and kk == 0:
				fig = plt.figure()
				ax1 = fig.add_subplot(3,3,1)
				ax11 = fig.add_subplot(3,3,2)
				ax111 = fig.add_subplot(3,3,3)
				ax2 = fig.add_subplot(3,3,4)
				ax22 = fig.add_subplot(3,3,5)
				ax222 = fig.add_subplot(3,3,6)
				ax3 = fig.add_subplot(3,3,7)
				ax33 = fig.add_subplot(3,3,8)
				ax333 = fig.add_subplot(3,3,9)


			for m in range(1,num_iterations+1):

				for plane in range(3):
					for i in range(num_sections):
						temp_hist = hist_queue.get()
						#hist_queue.join()
						exemplar_histogram_full[plane] += temp_hist
					three_exemplar_histograms[plane] = exemplar_histogram_full[plane]
				for plane in range(3):
					for i in range(num_sections):
						hist_queue.task_done()
					for i in range(num_sections):
						#print('Here 2')
						exemplar_queue.put(three_exemplar_histograms[plane])
						exemplar_queue.join()

				recon.fill(0)
				print(p)
				print('Checkpoint 1 %i' %kk)
				for i in range(num_sections):
					#if pipes_partial_recon[i][0].poll():
					#partial_recon_recv = pipes_partial_recon[i][0].recv()
					partial_recon_temp = partial_queue.get()
					recon += partial_recon_temp
				for i in range(num_sections):
					partial_queue.task_done()
				#global_error_this_iteration = 0
				#for i in range(num_sections):
				#	global_error_this_iteration += error_queue.get()
				#global_error_list.append(global_error_this_iteration)
				#ax2.set_ylim([0,global_error_list[0]])
				#for i in range(num_sections):
				#	error_queue.task_done



				if output_option == 1 and m % steps_between_plotting == 0:
	   
					steps = num_iterations*kk + m
					total_steps = num_levels*num_iterations
					percent_complete = steps / total_steps * 100
		        
	#--------------------------------RECONSTRUCTION PLOTS--------------------------------------------------
		                
					pix1 = [[[0.0, 0.0, 0.0] for i in range(recon_size)] for j in range(recon_size)]
					for y in range(recon_size):
						for x in range(recon_size):
							midplane = int(math.floor(recon_size/2))
							val = recon[midplane][y][x]
							if val == 0:
								pix1[y][x] = [0.0, 0.0, 0.0]
							elif val == 1:
								pix1[y][x] = [1.0, 1.0, 1.0]
							elif val == 2:
								pix1[y][x] = [1.0, 0.0, 0.0]
							elif val == 3:
								pix1[y][x] = [0.0, 0.0, 1.0]
							else:
								pix1[y][x] = [0.0, 1.0, 0.0]
					ax1.tick_params(axis='both', which='both', bottom='off', top='off', labelbottom='off', right='off', left='off', labelleft='off')
					ax1.set_title('XY Reconstruction')
	      				#ax1.set_title('Level %s of %s \nIteration %s of %s \nProgress: %0.1f %%' % ((num_levels-k+1),num_levels,m,num_iterations,percent_complete))
					fig.suptitle('Level %s of %s. Iteration %s of %s. Progress: %0.1f %%' % ((num_levels-k+1),num_levels,m,num_iterations,percent_complete))
					ax1.imshow(pix1, interpolation='nearest')
		                        
		                        
					pix2 = [[[0.0, 0.0, 0.0] for i in range(recon_size)] for j in range(recon_size)]
					for y in range(recon_size):
						for x in range(recon_size):
							midplane = int(math.floor(recon_size/2))
							val = recon[y][midplane][x]
							if val == 0:
								pix2[y][x] = [0.0, 0.0, 0.0]
							elif val == 1:
								pix2[y][x] = [1.0, 1.0, 1.0]
							elif val == 2:
								pix2[y][x] = [1.0, 0.0, 0.0]
							elif val == 3:
								pix2[y][x] = [0.0, 0.0, 1.0]
							else:
								pix2[y][x] = [0.0, 1.0, 0.0]
					ax2.tick_params(axis='both', which='both', bottom='off', top='off', labelbottom='off', right='off', left='off', labelleft='off')
					ax2.set_title('XZ Reconstruction')
					ax2.imshow(pix2, interpolation='nearest')
		            
		                        
					pix3 = [[[0.0, 0.0, 0.0] for i in range(recon_size)] for j in range(recon_size)]
					for y in range(recon_size):
						for x in range(recon_size):
							midplane = int(math.floor(recon_size/2))
							val = recon[y][x][midplane]
							if val == 0:
								pix3[y][x] = [0.0, 0.0, 0.0]
							elif val == 1:
								pix3[y][x] = [1.0, 1.0, 1.0]
							elif val == 2:
								pix3[y][x] = [1.0, 0.0, 0.0]
							elif val == 3:
								pix3[y][x] = [0.0, 0.0, 1.0]
							else:
								pix3[y][x] = [0.0, 1.0, 0.0]
					ax3.tick_params(axis='both', which='both', bottom='off', top='off', labelbottom='off', right='off', left='off', labelleft='off')
					ax3.set_title('YZ Reconstruction')
					ax3.imshow(pix3, interpolation='nearest')
		                        
		#--------------------------------HISTOGRAM PLOTS----------------------------------------------
					# Histogram panels, one per exemplar plane. Each bin of the
					# exemplar_size x exemplar_size histogram is drawn as a shade of blue
					# (darker = higher count, quantised into num_colors levels); bins with a
					# zero count are drawn white.
					num_colors = 5
					histogram_panels = (
						(ax11, three_exemplar_histograms[0], 'XY Histogram'),
						(ax22, three_exemplar_histograms[1], 'XZ Histogram'),
						(ax33, three_exemplar_histograms[2], 'YZ Histogram'),
					)
					for panel_ax, panel_histogram, panel_title in histogram_panels:
						# Largest bin count, floored at 1 so the normalisation never divides by zero.
						histogram_max = 1
						for y in range(exemplar_size):
							for x in range(exemplar_size):
								if panel_histogram[y][x] > histogram_max:
									histogram_max = panel_histogram[y][x]

						panel_pixels = [[[1, 1, 1] for i in range(exemplar_size)] for j in range(exemplar_size)]
						for y in range(exemplar_size):
							for x in range(exemplar_size):
								bin_count = panel_histogram[y][x]
								if bin_count == 0:
									continue
								# Quantise bin_count/histogram_max into a level in 1..num_colors. The
								# min() keeps the fullest bin at level num_colors; otherwise it would
								# get level num_colors+1, i.e. a negative (out-of-range) RGB component.
								level = min(math.floor(bin_count/histogram_max*num_colors), num_colors-1) + 1
								panel_pixels[y][x] = [0, 0, 1-level/num_colors]
						# Booleans, not 'off' strings: a non-empty string is truthy, so newer
						# matplotlib releases would keep the ticks and labels visible.
						panel_ax.tick_params(axis='both', which='both', bottom=False, top=False, labelbottom=False, right=False, left=False, labelleft=False)
						panel_ax.set_title(panel_title)
						panel_ax.imshow(panel_pixels, interpolation='nearest')

					# Exemplar panels: phase labels 0, 1, 2, 3 are drawn black, white, red and
					# blue; any other value is drawn green.
					# NOTE(review): in the multiphase case every plane reads
					# Exemplar_set[phase][phase], so the XZ and YZ panels repeat the XY
					# exemplar -- confirm this is intended.
					phase_colours = {0: [0.0, 0.0, 0.0], 1: [1.0, 1.0, 1.0], 2: [1.0, 0.0, 0.0], 3: [0.0, 0.0, 1.0]}
					other_colour = [0.0, 1.0, 0.0]
					exemplar_panels = ((ax111, 0, 'XY Exemplar'), (ax222, 1, 'XZ Exemplar'), (ax333, 2, 'YZ Exemplar'))
					for panel_ax, exemplar_plane, panel_title in exemplar_panels:
						panel_pixels = [[other_colour] * exemplar_size for j in range(exemplar_size)]
						for y in range(exemplar_size):
							for x in range(exemplar_size):
								if multiphase:
									val = Exemplar_set[phase][phase][k-1][y][x]
								else:
									val = Exemplars[exemplar_plane][k-1][y][x]
								panel_pixels[y][x] = phase_colours.get(val, other_colour)
						panel_ax.tick_params(axis='both', which='both', bottom=False, top=False, labelbottom=False, right=False, left=False, labelleft=False)
						panel_ax.set_title(panel_title)
						panel_ax.imshow(panel_pixels, interpolation='nearest')

					# Refresh the figure without blocking. After the final iteration of the
					# final level, retitle it, keep it on screen for 10 s, then close it.
					last_update = not (m < num_iterations or kk < num_levels-1)
					if last_update:
						ax1.set_title('Reconstruction Completed. Progress: 100 %')
					plt.show(block=False)
					plt.pause(0.001)
					plt.cla()
					if last_update:
						time.sleep(10)
						plt.close('all')
		    	                                        
	    #If the user wants the output of the reconstruction to be printed to PNG images
	    #Do so every steps_between_plotting iterations
				elif output_option == 2 and m % steps_between_plotting == 0:
		                                                
					PNG_image = Image.new("RGB", (recon_size, recon_size))
					pix_PNG_image = PNG_image.load()
					for y in range(recon_size):
						for x in range(recon_size):
							val = recon[y][x]
							if val == 0:
								pix_PNG_image[x,y] = (0,0,0)
							elif val == 1:
								pix_PNG_image[x,y] = (255,255,255)
							elif val == 2:
								pix_PNG_image[x,y] = (255,0,0)
							elif val == 3:
								pix_PNG_image[x,y] = (0,0,255)
							else:
								pix_PNG_image[x,y] = (0,255,0)
					PNG_image.save('level_%s_beginning_of_iteration_%s.png' % ((num_levels-k+1),n))

			# Rebuild recon and is_locked as the sums of the num_sections partial arrays
			# placed on the two queues (presumably one per worker section -- confirm
			# against the producer side). Each queue is fully drained before task_done()
			# is signalled, and both are drained before join(): joining a process that
			# still has items buffered in a queue can deadlock.
			recon.fill(0)
			for _ in range(num_sections):
				recon += partial_queue.get()
			for _ in range(num_sections):
				partial_queue.task_done()

			for _ in range(num_sections):
				is_locked += partial_locked_queue.get()
			for _ in range(num_sections):
				partial_locked_queue.task_done()

			for worker in processes:
				worker.join()

		# Write this level's reconstruction to disk as one PNG per z-plane.
		# 'cubic' writes the whole volume into one folder. 'plate' splits the y axis into
		# num_sections slabs of ceil(recon_size/num_sections) rows (the last slab takes
		# whatever remains) and writes one folder per slab.
		if not multiphase:
			if structure_type == 'cubic':
				path_and_name = inputs.full_path+inputs.new_folder_name
				os.makedirs(path_and_name, exist_ok=True)
				for z in range(recon_size):
					# Values are scaled by 255 before the uint8 cast (presumably 0/1 -> black/white).
					img = Image.fromarray((recon[z]*255).astype(np.uint8))
					img.save(path_and_name+'/'+'Level_%d_plane_%d.png' %(kk,z))

			elif structure_type == 'plate':
				section_size_large = int(math.ceil(recon_size/num_sections))
				for section_number in range(num_sections):
					newdir = 'plate_Recon_'+str(section_number)
					print(newdir)
					os.makedirs(newdir, exist_ok=True)

					start = section_number*section_size_large
					stop = recon_size if section_number == num_sections-1 else (section_number+1)*section_size_large
					recon_section = recon[:,start:stop,:]

					for z in range(recon_size):
						img = Image.fromarray((recon_section[z]*255).astype(np.uint8))
						img.save(newdir+'/'+'Recon_'+str(section_number)+'_Level_'+str(kk)+'_plane_'+str(z)+'.png')

		else:
			# Multiphase: recon is relabelled in place (zeros become two_phases_im1[phase][0],
			# every other value two_phases_im1[phase][1]) and the labels are written
			# unscaled. The update must stay in place because this same array object is
			# what gets appended to all_recons.
			if structure_type in ('cubic', 'plate'):
				phase_labels = two_phases_im1[phase]
				recon[...] = np.where(recon == 0, phase_labels[0], phase_labels[1])
				file_prefix = 'Phases_'+str(phase_labels[0])+'_and_'+str(phase_labels[1])

			if structure_type == 'cubic':
				newdir = 'Recon_'+str(phase)
				print(newdir)
				os.makedirs(newdir, exist_ok=True)
				for z in range(recon_size):
					img = Image.fromarray((recon[z]).astype(np.uint8))
					img.save(newdir+'/'+file_prefix+'_Level_'+str(kk)+'_plane_'+str(z)+'.png')

			elif structure_type == 'plate':
				section_size_large = int(math.ceil(recon_size/num_sections))
				for section_number in range(num_sections):
					newdir = 'Plate_'+str(section_number)
					print(newdir)
					os.makedirs(newdir, exist_ok=True)

					start = section_number*section_size_large
					stop = recon_size if section_number == num_sections-1 else (section_number+1)*section_size_large
					recon_section = recon[:,start:stop,:]

					for z in range(recon_size):
						img = Image.fromarray((recon_section[z]).astype(np.uint8))
						img.save(newdir+'/'+file_prefix+'_Level_'+str(kk)+'_plane_'+str(z)+'.png')

		# On the last level of a multiphase run, keep this phase pair's reconstruction
		# for the assembly step after the loop.
		# NOTE(review): this stores a reference, not a copy. If the same recon array is
		# reused for the next phase pair (e.g. via recon.fill(0)), earlier entries would
		# change with it -- confirm recon is reallocated per phase.
		if kk == num_levels-1 and multiphase:
			all_recons.append(recon)

if multiphase:
	print(np.shape(all_recons))

	# Disk-backed array of shape (recon_size, recon_size, recon_size, num_phases) in the
	# file 'full_recon'; channel `phase` of the last axis holds phase pair `phase`.
	# mode='w+' lets np.memmap create/overwrite the file itself, so no separate file
	# handle is opened (and leaked). uint8 is np.memmap's default dtype, as before.
	filename = 'full_recon'
	full_recon = np.memmap(filename, dtype=np.uint8, mode='w+', shape=(recon_size, recon_size, recon_size, num_phases))

	print(np.shape(full_recon))

	# Map each reconstruction from its 0 / non-0 encoding to the two labels in
	# two_phases_im1[phase]. The arrays are updated in place, so the objects held by
	# all_recons are relabelled too.
	# NOTE(review): for structure_type 'cubic'/'plate', recon was already relabelled in
	# place before being appended to all_recons, so this maps the labels a second time.
	# That is only a no-op when two_phases_im1[phase][0] == 0 -- confirm intended.
	for phase in range(num_phases):
		phase_recon = all_recons[phase]
		phase_recon[...] = np.where(phase_recon == 0, two_phases_im1[phase][0], two_phases_im1[phase][1])

	for phase in range(num_phases):
		full_recon[..., phase] = all_recons[phase]
	full_recon.flush()

print('done')
