'''
    Copyright (C) 2021 - 2023 Akaneyu

    This program is free software: you can redistribute it and/or modify
    it under the terms of the GNU General Public License as published by
    the Free Software Foundation, either version 3 of the License, or
    (at your option) any later version.

    This program is distributed in the hope that it will be useful,
    but WITHOUT ANY WARRANTY; without even the implied warranty of
    MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
    GNU General Public License for more details.

    You should have received a copy of the GNU General Public License
    along with this program.  If not, see <http://www.gnu.org/licenses/>.
'''

import os
import bpy
import bpy.utils.previews
import blf
import numpy as np
from . ui_renderer import UIRenderer as UIRenderer
from . import utils

class AreaSession:
    """Editor state kept separately for each image editor area."""

    def __init__(self):
        # committed selection in image pixels: [[x1, y1], [x2, y2]]
        self.selection = None
        # selection rectangle in region coordinates (two corner points)
        self.selection_region = None
        self.selecting = False
        # set while a layer move / rotate / scale is in progress (shows the LMB/RMB hint)
        self.layer_moving = False
        self.layer_rotating = False
        self.layer_scaling = False
        # when True, on_layer_placement_changed skips rebuilding the layer nodes
        self.prevent_layer_update_event = False
        # image and size seen by the previous draw; used to drop or crop the selection
        # when the displayed image changes. The size fields must exist from the start
        # because the draw handler compares against them.
        self.prev_image = None
        self.prev_image_width = 0
        self.prev_image_height = 0

class Session:
    """State shared by the whole add-on (icons, caches, clipboard, per-area sessions)."""

    def __init__(self):
        # preview collection filled by load_icons()
        self.icons = None
        self.keymaps = []
        self.ui_renderer = None

        # pixel snapshot taken by cache_image() and its optional HSL form
        self.cached_image_pixels = None
        self.cached_image_hsl = None
        self.cached_layer_location = None

        # copied pixels together with the image / layer settings they came with
        self.copied_image_pixels = None
        self.copied_image_settings = None
        self.copied_layer_settings = None

        # AreaSession objects keyed by area
        self.areas = {}

def get_session():
    """Return the add-on wide Session object."""
    return session

def get_area_session(context):
    """Return the AreaSession for ``context.area``, creating and storing it on first use."""
    registry = session.areas
    area = context.area

    existing = registry.get(area, None)
    if existing:
        return existing

    created = AreaSession()
    registry[area] = created
    return created

def draw_handler():
    """Draw the selection frame, the image layers and the transform hint.

    Reads everything from ``bpy.context``. The drawing order is: selection frame,
    then layers, then the selection/image bookkeeping, then the info text.
    """
    context = bpy.context
    area_session = get_area_session(context)

    img = context.area.spaces.active.image
    width, height = (img.size[0], img.size[1]) if img else (0, 0)

    _draw_selection_frame(context, area_session, width, height)

    # layer positions are normalised by the image size, so an image without pixel
    # data (size 0) cannot be mapped into the region
    if img and width > 0 and height > 0:
        _draw_layers(context, img, width, height)

    _sync_selection_with_image(context, area_session, img, width, height)

    if area_session.layer_moving \
            or area_session.layer_rotating \
            or area_session.layer_scaling:
        _draw_info_text(context, "LMB: Perform\n" + "RMB: Cancel")

def _ensure_ui_renderer():
    """Return the shared UIRenderer, creating it on first use."""
    if not session.ui_renderer:
        session.ui_renderer = UIRenderer()
    return session.ui_renderer

def _draw_selection_frame(context, area_session, width, height):
    """Draw the committed selection (image pixels) or else the region-space rectangle."""
    view2d = context.region.view2d

    if area_session.selection:
        # a pixel selection can only be mapped when the image has a usable size
        if width <= 0 or height <= 0:
            return
        (x1, y1), (x2, y2) = area_session.selection
        region_pos1 = view2d.view_to_region(x1 / width, y1 / height, clip=False)
        region_pos2 = view2d.view_to_region(x2 / width, y2 / height, clip=False)
    elif area_session.selection_region:
        region_pos1, region_pos2 = area_session.selection_region
    else:
        return

    region_size = [region_pos2[0] - region_pos1[0],
            region_pos2[1] - region_pos1[1]]
    _ensure_ui_renderer().render_selection_frame(region_pos1, region_size)

def _draw_layers(context, img, width, height):
    """Draw the layers of ``img`` from the last index to index 0, framing the selected one."""
    renderer = _ensure_ui_renderer()
    view2d = context.region.view2d

    img_props = img.imageeditorplus_properties
    selected_layer_index = img_props.selected_layer_index

    for i, layer in reversed(list(enumerate(img_props.layers))):
        layer_img = bpy.data.images.get(layer.name, None)
        if not layer_img:
            continue

        layer_width, layer_height = layer_img.size[0], layer_img.size[1]
        layer_pos = layer.location

        # layer y is measured from the image top, view y from the bottom
        region_pos1 = view2d.view_to_region(
                layer_pos[0] / width,
                1.0 - (layer_pos[1] + layer_height) / height, clip=False)
        region_pos2 = view2d.view_to_region(
                (layer_pos[0] + layer_width) / width,
                1.0 - layer_pos[1] / height, clip=False)

        region_size = [region_pos2[0] - region_pos1[0],
                region_pos2[1] - region_pos1[1]]

        if not layer.hide:
            renderer.render_image(layer_img, region_pos1,
                region_size, layer.rotation, layer.scale)

        if i == selected_layer_index:
            renderer.render_selection_frame(
                    region_pos1, region_size, layer.rotation, layer.scale)

def _sync_selection_with_image(context, area_session, img, width, height):
    """Release the selection if the image changed, crop it if the image was resized.

    The current image and size are recorded on every draw (not only while a
    selection exists); otherwise a selection made after switching images would be
    compared against a stale image and cancelled.
    """
    if area_session.selection or area_session.selection_region:
        if img != area_session.prev_image:
            cancel_selection(context)
        elif width != getattr(area_session, 'prev_image_width', None) \
                or height != getattr(area_session, 'prev_image_height', None):
            crop_selection(context)

    area_session.prev_image = img
    area_session.prev_image_width = width
    area_session.prev_image_height = height

def _draw_info_text(context, info_text):
    """Draw word-wrapped help text near the top-left corner of the area."""
    area_height = context.area.height

    blf.enable(0, blf.WORD_WRAP)
    blf.word_wrap(0, 100)

    blf.position(0, 30, area_height - 70, 0)
    # the dpi argument was deprecated in Blender 3.4 and removed in 4.0
    if bpy.app.version >= (3, 4, 0):
        blf.size(0, 14)
    else:
        blf.size(0, 14, 72)
    blf.draw(0, info_text)

    blf.disable(0, blf.WORD_WRAP)

def get_curve_node():
    """Return the RGB Curves node in the 'imageeditorplus' node group.

    Both the node group and the node are created on demand.
    """
    group_name = 'imageeditorplus'
    node_group = bpy.data.node_groups.get(group_name) \
            or bpy.data.node_groups.new(group_name, 'ShaderNodeTree')

    nodes = node_group.nodes
    for node in nodes:
        if node.bl_idname == 'ShaderNodeRGBCurve':
            return node

    return nodes.new('ShaderNodeRGBCurve')

def get_curve_mapping():
    """Return the curve mapping owned by the shared RGB Curves node."""
    curve_node = get_curve_node()
    return curve_node.mapping

def reset_curve_mapping():
    """Reset every curve of the shared mapping to the straight line (0, 0)-(1, 1)."""
    mapping = get_curve_mapping()

    for curve in mapping.curves:
        points = curve.points

        # drop every point beyond the first two (range() is empty when there are <= 2)
        for _ in range(len(points) - 2):
            points.remove(points[2])

        for point, value in ((points[0], 0), (points[1], 1.0)):
            point.location[0] = value
            point.location[1] = value
            point.select = False

    mapping.update()

def get_active_layer(context):
    """Return the selected layer of the displayed image, or None.

    None is returned when no image is displayed or when ``selected_layer_index``
    is not a valid index. Any negative value is treated as "no layer", not just -1.
    The property has no ``min``, and a negative index would otherwise pick a layer
    from the end of the collection.
    """
    img = context.area.spaces.active.image
    if not img:
        return None

    img_props = img.imageeditorplus_properties
    layers = img_props.layers
    selected_layer_index = img_props.selected_layer_index

    if not 0 <= selected_layer_index < len(layers):
        return None

    return layers[selected_layer_index]

def get_target_image(context):
    """Return the image to edit: the active layer's image, otherwise the displayed image."""
    active_layer = get_active_layer(context)
    if not active_layer:
        return context.area.spaces.active.image

    return bpy.data.images.get(active_layer.name, None)

def cache_image(img, need_hsl=False):
    """Snapshot the pixels of ``img`` (and optionally their HSL form) into the session.

    Args:
        img: image to read via utils.read_pixels_from_image.
        need_hsl: also compute and cache utils.rgb_to_hsl(pixels).

    Returns:
        (pixels, hsl). ``hsl`` is None unless ``need_hsl`` is true. The pixel array
        returned here is the same object that is cached.
    """
    pixels = utils.read_pixels_from_image(img)
    session.cached_image_pixels = pixels

    hsl = utils.rgb_to_hsl(pixels) if need_hsl else None
    # Always overwrite the HSL slot. Otherwise get_image_cache() could return an
    # HSL snapshot taken from a previously cached image.
    session.cached_image_hsl = hsl

    return pixels, hsl

def get_image_cache():
    """Return copies of the cached (pixels, hsl); entries that are not cached are None."""
    def copy_or_none(array):
        return None if array is None else array.copy()

    return (copy_or_none(session.cached_image_pixels),
            copy_or_none(session.cached_image_hsl))

def revert_image_cache(img):
    """Write the cached pixels back into ``img``; does nothing when no pixels are cached."""
    cached_pixels = session.cached_image_pixels
    if cached_pixels is not None:
        utils.write_pixels_to_image(img, cached_pixels)

def clear_image_cache():
    """Forget the cached pixel and HSL snapshots."""
    session.cached_image_pixels = session.cached_image_hsl = None

def convert_selection(context):
    """Turn the region-space selection rectangle into a pixel selection.

    The corners are sorted, scaled by the image size and rounded. The result is
    stored as ``selection`` and then clamped by crop_selection(). Nothing happens
    when there is no image or no region rectangle.
    """
    area_session = get_area_session(context)

    img = context.area.spaces.active.image
    if not img:
        return

    img_w, img_h = img.size[0], img.size[1]

    corners = area_session.selection_region
    if not corners:
        return

    view2d = context.region.view2d
    ax, ay = view2d.region_to_view(*corners[0])
    bx, by = view2d.region_to_view(*corners[1])

    left, right = sorted((ax, bx))
    bottom, top = sorted((ay, by))

    area_session.selection = [
        [round(left * img_w), round(bottom * img_h)],
        [round(right * img_w), round(top * img_h)],
    ]

    crop_selection(context)

def crop_selection(context):
    """Clamp the pixel selection to the image bounds and keep it at least 1 px per axis."""
    area_session = get_area_session(context)

    img = context.area.spaces.active.image
    if not img:
        return

    if not area_session.selection:
        return

    def fit_span(lo, hi, limit):
        # clamp both ends into [0, limit]
        lo = max(min(lo, limit), 0)
        hi = max(min(hi, limit), 0)
        # widen an empty span, growing towards the far edge when there is room
        if hi - lo <= 0:
            if hi < limit:
                hi = hi + 1
            else:
                lo = lo - 1
        return lo, hi

    [x1, y1], [x2, y2] = area_session.selection
    x1, x2 = fit_span(x1, x2, img.size[0])
    y1, y2 = fit_span(y1, y2, img.size[1])

    area_session.selection = [[x1, y1], [x2, y2]]

def cancel_selection(context):
    """Discard both the pixel selection and the region-space rectangle of the area."""
    area_session = get_area_session(context)
    area_session.selection = area_session.selection_region = None

def get_selection(context):
    """Return the pixel selection of the current area (None when nothing is selected)."""
    return get_area_session(context).selection

def get_target_selection(context):
    """Return the pixel selection when it applies to the edit target, else None.

    The selection applies only when no layer is the target, i.e. when
    ``selected_layer_index`` is out of range. Any negative index counts as "no layer"
    (consistent with get_active_layer), since the property has no ``min``.
    """
    selection = get_area_session(context).selection
    if not selection:
        return None

    img = context.area.spaces.active.image
    if not img:
        return selection

    img_props = img.imageeditorplus_properties
    selected_layer_index = img_props.selected_layer_index

    if 0 <= selected_layer_index < len(img_props.layers):
        # a layer is being edited; the selection does not apply to it
        return None

    return selection

def refresh_image(context):
    """Make the editor pick up pixel changes of the displayed image.

    Also calls ``update_pasted_layer_nodes`` on the window manager's
    ``imagelayersnode_api`` when it is present with VERSION >= (1, 1, 0).
    """
    wm = context.window_manager

    space = context.area.spaces.active
    img = space.image
    if not img:
        return

    # re-assigning the same image presumably forces the editor to refresh its display
    space.image = img
    img.update()

    if hasattr(wm, 'imagelayersnode_api') and wm.imagelayersnode_api.VERSION >= (1, 1, 0):
        wm.imagelayersnode_api.update_pasted_layer_nodes(img)

def apply_layer_transform(img, rot, scale):
    """Render ``img`` with rotation/scale off-screen and return (pixels, width, height).

    ``pixels`` is a float32 array of shape (height, width, 4) scaled by 1/255. It is
    then passed to utils.convert_colorspace from 'Linear' to the image colorspace
    ('Linear' for float images). That helper presumably converts in place; confirm.
    """
    if not session.ui_renderer:
        session.ui_renderer = UIRenderer()
    renderer = session.ui_renderer

    buff, width, height = renderer.render_image_offscreen(img, rot, scale)

    pixels = np.reshape(buff, (height, width, 4)).astype(np.float32) / 255.0

    # gamma correction
    target_colorspace = 'Linear' if img.is_float else img.colorspace_settings.name
    utils.convert_colorspace(pixels, 'Linear', target_colorspace)

    return pixels, width, height

def create_layer(base_img, pixels, img_settings, layer_settings):
    """Add ``pixels`` to ``base_img`` as a new, selected, top-of-list layer.

    Args:
        base_img: image that receives the layer.
        pixels: array whose shape[0] is the height and shape[1] the width
            (presumably RGBA floats; confirm against callers).
        img_settings: dict with 'is_float' and 'colorspace_name' describing the
            colorspace the pixels are in.
        layer_settings: optional dict whose 'rotation', 'scale' and 'custom_data'
            are copied onto the layer.

    The layer image is packed and given a fake user. It is centred on the base
    image, inserted at index 0 and selected, and then the layer nodes are rebuilt.
    """
    base_width, base_height = base_img.size

    target_width, target_height = pixels.shape[1], pixels.shape[0]

    layer_img_prefix = '#layer'
    layer_img_name = base_img.name + layer_img_prefix
    layer_img = bpy.data.images.new(layer_img_name, width=target_width, height=target_height,
        alpha=True, float_buffer=base_img.is_float)
    layer_img.colorspace_settings.name = base_img.colorspace_settings.name

    # gamma correction (on a copy, so the caller's array is not modified)
    pixels = pixels.copy()
    utils.convert_colorspace(pixels,
            'Linear' if img_settings['is_float']
                    else img_settings['colorspace_name'],
            'Linear' if base_img.is_float else base_img.colorspace_settings.name)

    utils.write_pixels_to_image(layer_img, pixels)

    layer_img.use_fake_user = True
    layer_img.pack()

    img_props = base_img.imageeditorplus_properties
    layers = img_props.layers
    layer = layers.add()

    layer.name = layer_img.name
    layer.location = [int((base_width - target_width) / 2.0),
            int((base_height - target_height) / 2.0)]

    # Blender may rename the new image (e.g. a ".001" suffix) or shorten an over-long
    # name. When the prefix is missing, rfind() returns -1 and slicing from it would
    # produce a meaningless suffix, so fall back to no suffix.
    prefix_pos = layer_img.name.rfind(layer_img_prefix)
    if prefix_pos >= 0:
        layer_img_postfix = layer_img.name[prefix_pos + len(layer_img_prefix):]
    else:
        layer_img_postfix = ''

    if layer_img_postfix:
        layer.label = 'Pasted Layer ' + layer_img_postfix
    else:
        layer.label = 'Pasted Layer'

    if layer_settings:
        layer.rotation = layer_settings['rotation']
        layer.scale = layer_settings['scale']
        layer.custom_data = layer_settings['custom_data']

    # move the new (last) layer to the front and select it
    layers.move(len(layers) - 1, 0)
    img_props.selected_layer_index = 0

    rebuild_image_layers_nodes(base_img)

def rebuild_image_layers_nodes(img):
    """Call ``rebuild_image_layers_nodes(img)`` on the optional ``imagelayersnode_api``.

    The call is made only when the window manager exposes that attribute with
    VERSION >= (1, 1, 0).
    """
    wm = bpy.context.window_manager

    if hasattr(wm, 'imagelayersnode_api') and wm.imagelayersnode_api.VERSION >= (1, 1, 0):
        wm.imagelayersnode_api.rebuild_image_layers_nodes(img)

def on_layer_placement_changed(self, context):
    """Update callback for layer location/rotation/scale.

    Rebuilds the layer nodes of the displayed image unless the area session has
    ``prevent_layer_update_event`` set.
    """
    if get_area_session(context).prevent_layer_update_event:
        return

    img = context.area.spaces.active.image
    if img:
        rebuild_image_layers_nodes(img)

def on_layer_visible_changed(self, context):
    """Update callback for a layer's ``hide`` flag: refresh the displayed image."""
    refresh_image(context)

def on_selected_layer_index_changed(self, context):
    """Update callback for ``selected_layer_index``: picking a layer drops the selection."""
    if self.selected_layer_index == -1:
        return

    cancel_selection(context)

def load_icons():
    """Load every file of the add-on's ``icons`` folder into a preview collection.

    Each icon is keyed by its file name without extension. The collection is
    stored on ``session.icons``.
    """
    icons_dir = os.path.join(os.path.dirname(os.path.realpath(__file__)), "icons")
    preview_collection = bpy.utils.previews.new()

    for entry in os.listdir(icons_dir):
        stem, _extension = os.path.splitext(entry)
        preview_collection.load(stem, os.path.join(icons_dir, entry), 'IMAGE')

    session.icons = preview_collection

def dispose_icons():
    """Release the icon preview collection created by load_icons().

    Safe to call when the icons were never loaded or were already disposed.
    ``session.icons`` is reset so a second call does not try to remove an
    already-closed collection.
    """
    icons = session.icons
    if icons is None:
        return

    bpy.utils.previews.remove(icons)
    session.icons = None

def cleanup_scene():
    """Remove the helper 'imageeditorplus' node group from the blend data, if it exists."""
    node_groups = bpy.data.node_groups
    helper_group = node_groups.get('imageeditorplus')
    if helper_group:
        node_groups.remove(helper_group)

@bpy.app.handlers.persistent
def save_pre_handler(args):
    """Before the .blend is saved, persist every modified (dirty) image.

    The helper node group is removed first. Packed images and images without a
    file path are (re)packed; the others are saved to their file. Viewer images
    are skipped. A failure on one image is reported and does not stop the rest
    from being persisted.
    """
    cleanup_scene()

    for img in bpy.data.images:
        if img.source == 'VIEWER' or not img.is_dirty:
            continue

        try:
            if img.packed_files or not img.filepath:
                img.pack()
            else:
                img.save()
        except RuntimeError as err:
            # e.g. an unwritable file path; keep going so other images are not lost
            print("Image Editor Plus: could not save image '%s': %s" % (img.name, err))

class IMAGE_EDITOR_PLUS_WindowPropertyGroup(bpy.types.PropertyGroup):
    # Foreground / background colors: gamma-space RGB, each channel in [0, 1].
    foreground_color: bpy.props.FloatVectorProperty(
            name='Foreground Color',
            subtype='COLOR_GAMMA',
            size=3,
            min=0,
            max=1.0,
            default=(1.0, 1.0, 1.0))
    background_color: bpy.props.FloatVectorProperty(
            name='Background Color',
            subtype='COLOR_GAMMA',
            size=3,
            min=0,
            max=1.0,
            default=(0, 0, 0))

class IMAGE_EDITOR_PLUS_LayerPropertyGroup(bpy.types.PropertyGroup):
    # One layer; its pixels live in the bpy.data.images entry named by `name`.

    # pixel offset on the base image (y is measured from the image's top edge)
    location: bpy.props.IntVectorProperty(update=on_layer_placement_changed, size=2)
    rotation: bpy.props.FloatProperty(update=on_layer_placement_changed, subtype='ANGLE')
    scale: bpy.props.FloatVectorProperty(update=on_layer_placement_changed,
            size=2, default=(1.0, 1.0))
    # name shown to the user (e.g. 'Pasted Layer')
    label: bpy.props.StringProperty()
    hide: bpy.props.BoolProperty(update=on_layer_visible_changed, name='Hide')
    # free-form string data, presumably JSON given the '{}' default; confirm with users
    custom_data: bpy.props.StringProperty(default='{}')

class IMAGE_EDITOR_PLUS_ImagePropertyGroup(bpy.types.PropertyGroup):
    # Layer stack of one image. New layers are inserted at index 0, and layers are
    # drawn from the last index to index 0, so index 0 ends up on top.
    layers: bpy.props.CollectionProperty(type=IMAGE_EDITOR_PLUS_LayerPropertyGroup)
    # index into `layers`; -1 means no layer is selected
    selected_layer_index: bpy.props.IntProperty(update=on_selected_layer_index_changed)

# Add-on wide state, created when the module is imported; returned by get_session().
session = Session()