import bpy, nodeitems_utils, mathutils
import math, random, time, colorsys
from . import Node, colorize_square, TextureSetGetter, Image, DATA_TextureSet, SOCKET_TextureSet, SOCKET_Map, MapGetter, DATA_GETTER_NODE_Image, DATA_Image, \
	DATA_Map, SOCKET_Image
from .. import my_globals
from ..utils import définirPixel
from typing import Tuple, List


class Get_Image(bpy.types.Node, Node, DATA_GETTER_NODE_Image):
	"""Node that exposes an existing Blender image, picked by the user, as image data.

	The node has no input socket; its only source is the `img` pointer property.
	"""
	bl_idname = 'sc_node_2oeepxhcdeyiputfhnig'
	bl_label = 'Get image'
	# Blender image selected in the node UI; unset until the user picks one.
	img: bpy.props.PointerProperty(name='Image', type=bpy.types.Image, update=Node.prop_updated)
	at_least_one_input_socket_required = False

	def sc_init(self, context):
		"""Create the single image output socket."""
		self.create_output(SOCKET_Image, is_new_data_output=False)

	def are_all_inputs_correct(self):
		"""Return True when an image has been selected on the node."""
		return self.img is not None

	def sc_draw_buttons(self, context, layout):
		"""Draw the image selector, without a label."""
		layout.prop(self, 'img', text='')

	@Node.get_data_first
	def get_images(self, *args, **kwargs):
		"""Return the selected image wrapped in a one-element list of `DATA_Image`.

		Raises:
			ValueError: if no image is selected on the node.
		"""
		if not self.img:
			self.print('No image specified')
			raise ValueError('No image specified')
		image = DATA_Image(self.img)
		self.print(f'Image "{self.img.name}" found')
		return [image]


class MapToBlenderImageNode(bpy.types.Node, Node, DATA_GETTER_NODE_Image):
	"""Node that draws up to five input maps into a square Blender image.

	Input sockets, in order: a map written to R, G and B at once, then one
	optional map per red / green / blue / alpha channel. A per-channel map
	overrides the value coming from the RGB map. Channels without any map
	default to 0 for colors and 1 for alpha.
	"""
	bl_idname = 'sc_node_ffxli2xzx5km3oysmh7b'
	bl_label = 'Map to Blender image'

	image_name: bpy.props.StringProperty(
		name='Image name',
		description='Name of the ouput image to draw into. If the image exists, it will be modified, otherwise it will be created',
		default='SceneCity map')

	image_resolution: bpy.props.IntProperty(
		name='Image resolution',
		description='Resolution of the ouput image to draw into',
		default=512,
		min=4)

	normalize: bpy.props.BoolProperty(
		name='Normalize',
		description='Draw values in the 0 to 1 color range, so that something is always visible no matter the real underlying map values',
		default=True)

	def sc_init(self, context):
		"""Set the node width and create the five map inputs and the image output."""
		self.width = 300
		self.create_input(SOCKET_Map, is_required=False, label='Map -> RGB')
		self.create_input(SOCKET_Map, is_required=False, label='Map -> Red channel')
		self.create_input(SOCKET_Map, is_required=False, label='Map -> Green channel')
		self.create_input(SOCKET_Map, is_required=False, label='Map -> Blue channel')
		self.create_input(SOCKET_Map, is_required=False, label='Map -> Alpha channel')
		self.create_output(SOCKET_Image, is_new_data_output=False)

	class SC_OT_Map_to_blender_image(bpy.types.Operator):
		"""Operator behind the node button: makes the source node draw its image."""
		bl_idname = 'sc.map_to_blender_image'
		bl_description = 'Draw map into image'
		bl_label = 'Create / update image'
		# Python expression that evaluates to the node owning the button.
		source_node_path: bpy.props.StringProperty()

		def execute(self, context):
			"""Resolve the source node from its path and run `get_images` on it."""
			# NOTE(review): eval() executes whatever string is stored in
			# source_node_path. The path is presumably generated internally by
			# Node.create_operator - confirm it can never hold untrusted input.
			source_node: MapToBlenderImageNode = eval(self.source_node_path)
			source_node.get_images()
			return {'FINISHED'}

	class _ChannelSamples:
		"""Values sampled from one input map, with their running minimum and maximum."""
		__slots__ = ('source_map', 'values', 'min_value', 'max_value', 'spread')

		def __init__(self, source_map, resolution):
			self.source_map = source_map
			# values[i][j] holds the sample taken at column i, row j.
			self.values = [[0] * resolution for _ in range(resolution)]
			self.min_value = math.inf
			self.max_value = -math.inf
			self.spread = 0

		def sample(self, i, j, x_percent, y_percent):
			"""Read the map at (x_percent, y_percent), store it at [i][j], update min / max."""
			value = self.source_map.get_value(x_percent, y_percent)
			# Two independent tests: a sample can be both the new maximum and the
			# new minimum (always true for the very first one). Chaining them with
			# `elif` would leave the minimum wrong, possibly still at +inf.
			if value > self.max_value:
				self.max_value = value
			if value < self.min_value:
				self.min_value = value
			self.values[i][j] = value

		def finish_sampling(self):
			"""Compute the min-to-max spread once every sample has been taken."""
			self.spread = self.max_value - self.min_value

		def output(self, i, j, normalize):
			"""Return the stored sample, rescaled to the 0..1 range when requested.

			A zero spread (constant map) cannot be rescaled, so the raw value
			is returned in that case.
			"""
			value = self.values[i][j]
			if normalize and self.spread != 0:
				return (value - self.min_value) / self.spread
			return value

	def sc_draw_buttons(self, context, layout):
		"""Draw the node properties and the create / update button."""
		layout.prop(self, 'image_name', icon='FILE_IMAGE')
		layout.prop(self, 'image_resolution')
		layout.prop(self, 'normalize')
		self.create_operator(layout, MapToBlenderImageNode.SC_OT_Map_to_blender_image)

	def _get_images_necessary_data(self, *args, **kwargs):
		"""Return the maps connected to the five inputs as (rgb, red, green, blue, alpha).

		NOTE(review): an unlinked socket presumably yields a falsy value from
		`get_input_map()` - `get_images` relies on that; confirm.
		"""
		input_socket_rgb: SOCKET_Map = self.inputs[0]
		input_socket_r: SOCKET_Map = self.inputs[1]
		input_socket_g: SOCKET_Map = self.inputs[2]
		input_socket_b: SOCKET_Map = self.inputs[3]
		input_socket_a: SOCKET_Map = self.inputs[4]
		return (
			input_socket_rgb.get_input_map(),
			input_socket_r.get_input_map(),
			input_socket_g.get_input_map(),
			input_socket_b.get_input_map(),
			input_socket_a.get_input_map())

	@Node.get_data_first
	def get_images(self, input_maps: Tuple[DATA_Map, ...], *args, **kwargs):
		"""Draw the input maps into the Blender image named `image_name`.

		The image is reused (and rescaled) when it already exists, otherwise it
		is created. Work happens in two passes over a resolution x resolution
		grid: first every map is sampled while tracking its min / max, then the
		pixels are written, optionally normalized per map.

		Args:
			input_maps: the (rgb, red, green, blue, alpha) maps; falsy entries
				are treated as "no map". Presumably supplied by the
				`Node.get_data_first` decorator from
				`_get_images_necessary_data` - confirm.

		Returns:
			A one-element list holding the `DATA_Image` wrapping the Blender image.
		"""
		# Property reads are loop-invariant: fetch them once.
		resolution = self.image_resolution
		normalize = self.normalize

		startTime = time.time()  # start counting time here, don't count previous work needed

		# get or create new image
		try:
			image = bpy.data.images[self.image_name]
			image.scale(resolution, resolution)
		except Exception:
			# Best effort: a missing image, or one that cannot be rescaled, is
			# replaced by a freshly created one.
			image = bpy.data.images.new(
				self.image_name,
				resolution,
				resolution,
				alpha=True,
				float_buffer=True)
			image.name = self.image_name  # force name

		# One sample store per connected map, None for the others.
		make_channel = MapToBlenderImageNode._ChannelSamples
		rgb, red, green, blue, alpha = (
			make_channel(input_maps[k], resolution) if input_maps[k] else None
			for k in range(5))
		connected = [channel for channel in (rgb, red, green, blue, alpha) if channel is not None]

		# first pass: store values, and get min and max
		last_progress_display_time = -math.inf
		for i in range(resolution):
			x_percent = i / resolution
			if time.time() >= last_progress_display_time + .2:
				self.afficher_barre_progression(i / resolution / 2, time.time() - startTime)
				last_progress_display_time = time.time()
			for j in range(resolution):
				y_percent = j / resolution
				for channel in connected:
					channel.sample(i, j, x_percent, y_percent)
		for channel in connected:
			channel.finish_sampling()

		# second pass: draw
		pixels = [0] * (resolution ** 2 * 4)
		for i in range(resolution):
			if time.time() >= last_progress_display_time + .2:
				self.afficher_barre_progression(0.5 + i / resolution / 2, time.time() - startTime)
				last_progress_display_time = time.time()
			for j in range(resolution):
				if rgb is not None:
					r = g = b = rgb.output(i, j, normalize)
				else:
					r = g = b = 0
				a = 1
				# Per-channel maps take precedence over the RGB map.
				if red is not None:
					r = red.output(i, j, normalize)
				if green is not None:
					g = green.output(i, j, normalize)
				if blue is not None:
					b = blue.output(i, j, normalize)
				if alpha is not None:
					a = alpha.output(i, j, normalize)
				définirPixel(pixels, resolution, i, j, (r, g, b, a))

		image.pixels = pixels
		self.afficher_barre_progression(1, time.time() - startTime)
		image.update()
		return [DATA_Image(image)]


class SC_OT_TextureSet2BlImagesNodeCreateImages(bpy.types.Operator):
	"""Operator that copies every texture of a texture set into a Blender image.

	Each texture is written to the image named
	"<node prefix> - <texture name>", which is reused (and rescaled) when it
	already exists and created otherwise.
	"""
	bl_idname = 'node.texture_set_2_blender_images_node_create_image'
	bl_description = 'Draw into the image with the given name'
	bl_label = 'Create / update images'
	# Python expression that evaluates to the node owning the button.
	source_node_path: bpy.props.StringProperty()

	def execute(self, context):
		"""Write all textures of the linked texture set into Blender images.

		Returns {'CANCELLED'} with an error report when the node's
		'Texture set' input is not linked, {'FINISHED'} otherwise.
		"""
		my_globals.todel_derniere_operation_par_operator_non_standard = True

		# NOTE(review): eval() executes whatever string is stored in
		# source_node_path. The path is built by the node's draw code - confirm
		# it can never hold untrusted input.
		source_node = eval(self.source_node_path)  # type: TextureSetToBlenderImagesNode
		links = source_node.inputs['Texture set'].links
		if len(links) == 0:
			# Nothing is connected: report it instead of failing on links[0].
			self.report({'ERROR'}, 'Source texture set is needed')
			return {'CANCELLED'}
		texture_set_source_node = links[0].from_node  # type: TextureSetGetter
		texture_set = texture_set_source_node.get_texture_set()  # type: DATA_TextureSet

		startTime = time.time()
		for texture_name, image in texture_set.textures.items():
			blender_image_name = source_node.blender_images_name_prefix + ' - ' + texture_name
			width = image.resolution[0]
			height = image.resolution[1]
			try:
				blender_image = bpy.data.images[blender_image_name]
				blender_image.scale(width, height)
			except Exception:
				# Best effort: a missing image, or one that cannot be rescaled,
				# is replaced by a freshly created one.
				blender_image = bpy.data.images.new(
					blender_image_name,
					width,
					height,
					alpha=True,
					float_buffer=True)
				blender_image.name = blender_image_name

			# Flat RGBA buffer, 4 floats per pixel, rows of `width` pixels.
			pixel_data = [0] * (width * height * 4)
			# image.pixels is indexed [i][j]; i is used as the x offset inside
			# a row and j as the row number.
			for i, pixel_column in enumerate(image.pixels):
				for j, pixel in enumerate(pixel_column):
					pixel_index = j * width * 4 + i * 4
					pixel_data[pixel_index + 0] = pixel.r
					pixel_data[pixel_index + 1] = pixel.g
					pixel_data[pixel_index + 2] = pixel.b
					pixel_data[pixel_index + 3] = pixel.a
			blender_image.pixels = pixel_data
		# NOTE(review): `last_operation_time` is not declared on
		# TextureSetToBlenderImagesNode itself - presumably the `Node` base
		# provides it; confirm.
		source_node.last_operation_time = time.time() - startTime

		return {'FINISHED'}


class TextureSetToBlenderImagesNode(bpy.types.Node, Node):
	"""Node offering a button that writes a texture set into Blender images.

	The images are named after `blender_images_name_prefix`; the actual work is
	done by the `SC_OT_TextureSet2BlImagesNodeCreateImages` operator.
	"""
	bl_idname = 'sc_node_v2afgdv6fre205cylj7q'
	bl_label = 'Texture set to Blender images'

	# NOTE(review): the operator assigns `last_operation_time` on this node but
	# no such property is declared here - presumably inherited from `Node`; confirm.

	blender_images_name_prefix: bpy.props.StringProperty(
		name='Images prefix',
		description='Name prefix for the different Blender images to create (or to update if they already exist)',
		default='A texture', )

	def sc_init(self, context):
		"""Create the required texture set input socket."""
		self.create_input(SOCKET_TextureSet, is_required=True)

	def sc_draw_buttons(self, context, layout):
		"""Draw the prefix field and the create / update button.

		The button is given a Python expression locating this node, so the
		operator can find it again when executed.
		"""
		layout.prop(self, 'blender_images_name_prefix', icon='IMAGE_DATA')
		button = layout.operator(SC_OT_TextureSet2BlImagesNodeCreateImages.bl_idname)
		node_tree_name = self.id_data.name
		node_path = self.path_from_id()
		button.source_node_path = f'bpy.data.node_groups["{node_tree_name}"].{node_path}'
