'''
    Copyright (C) 2021 - 2024 Akaneyu

    This program is free software: you can redistribute it and/or modify
    it under the terms of the GNU General Public License as published by
    the Free Software Foundation, either version 3 of the License, or
    (at your option) any later version.

    This program is distributed in the hope that it will be useful,
    but WITHOUT ANY WARRANTY; without even the implied warranty of
    MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
    GNU General Public License for more details.

    You should have received a copy of the GNU General Public License
    along with this program.  If not, see <http://www.gnu.org/licenses/>.
'''

import numpy as np

def read_pixels_from_image(img):
    '''Copy an image's flat pixel buffer into a (height, width, 4) float32 array.

    img: object exposing ``size`` (width first, then height) and a
        ``pixels`` sequence with ``foreach_get`` -- presumably a Blender
        ``bpy.types.Image``; TODO confirm against callers.

    Returns a new float32 array of shape (height, width, 4).
    '''
    width = img.size[0]
    height = img.size[1]

    # foreach_get fills a pre-allocated buffer in one bulk copy.
    buffer = np.empty(len(img.pixels), dtype=np.float32)
    img.pixels.foreach_get(buffer)

    return buffer.reshape((height, width, 4))

def write_pixels_to_image(img, pixels):
    '''Write a pixel array back into an image's flat pixel buffer.

    img: object exposing ``pixels.foreach_set`` and a ``preview`` attribute
        -- presumably a Blender ``bpy.types.Image``; TODO confirm.
    pixels: array-like; it is flattened before being handed to foreach_set.

    If the image has a truthy ``preview``, its ``reload()`` is called.
    Returns None.
    '''
    flat_pixels = np.reshape(pixels, -1)
    img.pixels.foreach_set(flat_pixels)

    preview = img.preview
    if preview:
        preview.reload()

def rgb_to_hsl(rgb):
    '''Convert an RGB(A) pixel array to HSL.

    rgb: array indexed as [row, column, channel]; only channels 0-2
        (red, green, blue) are read, so an alpha channel is ignored.
        Values presumably lie in [0, 1] -- TODO confirm for HDR input.

    Returns an array with the same rows/columns and three channels
    (hue, saturation, lightness). Hue is normalised to [0, 1).
    Achromatic pixels (red == green == blue) get hue 0, matching
    ``colorsys.rgb_to_hls``.
    '''
    red = rgb[:, :, 0]
    green = rgb[:, :, 1]
    blue = rgb[:, :, 2]

    max_chan = np.maximum(np.maximum(red, green), blue)
    min_chan = np.minimum(np.minimum(red, green), blue)
    chan_sum = max_chan + min_chan    # not named `sum`: avoid shadowing the builtin
    light = chan_sum / 2.0

    diff = max_chan - min_chan

    # 1 - |2L - 1| equals chan_sum for L <= 0.5 and 2 - chan_sum above;
    # it is 0 only for pure black/white, where diff is 0 as well.
    sat_denom = 1.0 - np.abs(chan_sum - 1.0)
    sat_denom_safe = np.where(sat_denom == 0, 1.0, sat_denom)    # for div by 0

    sat = diff / sat_denom_safe

    diff_safe = np.where(diff == 0, 1.0, diff)   # for div by 0

    # Hue sector from the dominant channel. When channels tie for the
    # maximum on a colored pixel, the competing formulas give the same value.
    hue_0 = np.zeros((rgb.shape[0], rgb.shape[1]))
    hue_0 = np.where(np.equal(max_chan, red), (green - blue) / diff_safe, hue_0)
    hue_0 = np.where(np.equal(max_chan, green), 2.0 + (blue - red) / diff_safe, hue_0)
    hue_0 = np.where(np.equal(max_chan, blue), 4.0 + (red - green) / diff_safe, hue_0)

    # Hue is undefined for gray pixels; without this they would inherit the
    # blue formula above (hue 2/3). Use 0 like colorsys does.
    hue_0 = np.where(diff == 0, 0.0, hue_0)

    hue = (hue_0 / 6.0) % 1.0

    return np.dstack((hue, sat, light))

def hsl_to_rgb(hsl):
    '''Convert an HSL pixel array to RGB.

    hsl: array indexed as [row, column, channel] with channels
        (hue, saturation, lightness), the layout produced by rgb_to_hsl.
        Hue is periodic and is wrapped into [0, 1) first, so 1.0 or
        negative/shifted hues map to the equivalent angle instead of
        producing an out-of-range sector index.

    Returns an array with three channels (red, green, blue).
    '''
    hue = hsl[:, :, 0] % 1.0
    sat = hsl[:, :, 1]
    light = hsl[:, :, 2]

    # chroma
    c = (1.0 - np.abs(light * 2.0 - 1.0)) * sat

    index_f = hue * 6.0
    # hue % 1.0 can still round up to exactly 1.0 (e.g. for a tiny negative
    # hue), giving sector 6. Clamp to sector 5: there index_f == 6 makes
    # x == 0, i.e. the same (c, 0, 0) as sector 0 at hue 0.
    index = np.clip(index_f.astype(int), 0, 5)

    # second-largest component within the sector
    x = c * (1.0 - np.abs(index_f % 2.0 - 1.0))

    zeros = np.zeros((index.shape[0], index.shape[1]))
    red_0 = index.choose((c, x, zeros, zeros, x, c))
    green_0 = index.choose((x, c, c, x, zeros, zeros))
    blue_0 = index.choose((zeros, zeros, x, c, c, x))

    # shift all channels to match the requested lightness
    m = light - c / 2.0

    return np.dstack((red_0 + m, green_0 + m, blue_0 + m))

def straight_to_premul_alpha(pixels):
    '''Return a copy of *pixels* with RGB multiplied by alpha.

    pixels: array indexed as [row, column, channel]; channels 0-2 are
        color and channel 3 is (straight) alpha.

    Alpha is carried over unchanged; the result is a new array.
    '''
    color = pixels[:, :, :3]
    alpha = pixels[:, :, 3:]
    premultiplied = color * alpha
    return np.dstack((premultiplied, alpha))

def premul_to_straight_alpha(pixels):
    '''Return a copy of *pixels* with RGB divided back out by alpha.

    pixels: array indexed as [row, column, channel]; channels 0-2 are
        premultiplied color and channel 3 is alpha.

    Where alpha is 0 the color is divided by 1, i.e. left as it is,
    instead of dividing by zero.
    '''
    alpha = pixels[:, :, 3:]
    divisor = np.where(alpha == 0, 1.0, alpha)    # for div by 0
    straight = pixels[:, :, :3] / divisor
    return np.dstack((straight, alpha))

def gaussian_blur_core(pixels, blur_size):
    '''Apply a separable Gaussian blur to an RGBA pixel array.

    pixels: array indexed as [row, column, channel] with four channels,
        the last being straight (non-premultiplied) alpha.
    blur_size: Gaussian standard deviation in pixels. 0 yields a 1-tap
        kernel, i.e. no spatial blur.

    Returns a new straight-alpha array with the same height and width.
    The blur runs on premultiplied colors, so pixels with alpha 0
    contribute no color to their neighbours. The blurred premultiplied
    values are clipped to [0, 1] before converting back, so values above 1
    are not preserved.
    '''
    # straight => premul alpha
    pixels = straight_to_premul_alpha(pixels)

    height, width = pixels.shape[0], pixels.shape[1]

    # Only guards the division below; blur_size == 0 gives radius 0 anyway.
    blur_size_safe = 1.0 if blur_size == 0 else blur_size

    # The kernel reaches about 4 standard deviations either side of centre.
    radius = int(4 * blur_size + 0.5)
    offsets = np.arange(-radius, radius + 1)
    kernel = np.exp(-(offsets ** 2) / (2 * blur_size_safe ** 2))
    kernel = kernel / np.sum(kernel)

    # Replicate border pixels so every output pixel has a full neighbourhood.
    pixels_pad = np.pad(pixels, ((radius, radius), (radius, radius), (0, 0)), 'edge')

    # Horizontal pass: weighted sum of column-shifted views of the padded image.
    blur_h = np.zeros((pixels_pad.shape[0], width, 4))
    for offset, weight in enumerate(kernel):
        blur_h += weight * pixels_pad[:, offset:offset + width]

    # Vertical pass: weighted sum of row-shifted views of the horizontal result.
    blur_hv = np.zeros((height, width, 4))
    for offset, weight in enumerate(kernel):
        blur_hv += weight * blur_h[offset:offset + height]

    new_pixels = np.clip(blur_hv, 0, 1.0)

    # premul => straight alpha
    return premul_to_straight_alpha(new_pixels)

def convert_colorspace(pixels, src_colorspace, dest_colorspace):
    '''Convert the RGB channels of *pixels* between 'Linear' and 'sRGB', in place.

    pixels: writable array indexed as [row, column, channel]; only channels
        0-2 are modified, so alpha is left untouched.
    src_colorspace, dest_colorspace: colorspace names. Only the pairs
        ('Linear', 'sRGB') and ('sRGB', 'Linear') do anything. Identical or
        unsupported pairs leave *pixels* unchanged.

    A plain 2.2 power curve is used as an approximation of the sRGB
    transfer function. Always returns None.

    NOTE(review): a negative channel value raised to a fractional power
    becomes NaN. Confirm that inputs are non-negative.
    '''
    if src_colorspace == dest_colorspace:
        return

    conversion = (src_colorspace, dest_colorspace)
    if conversion == ('Linear', 'sRGB'):
        exponent = 1.0 / 2.2
    elif conversion == ('sRGB', 'Linear'):
        exponent = 2.2
    else:
        # unsupported conversion: leave pixels as they are
        return

    color = pixels[:, :, :3]
    pixels[:, :, 0:3] = color ** exponent