'''
    Copyright (C) 2020 - 2024 Akaneyu

    This program is free software: you can redistribute it and/or modify
    it under the terms of the GNU General Public License as published by
    the Free Software Foundation, either version 3 of the License, or
    (at your option) any later version.

    This program is distributed in the hope that it will be useful,
    but WITHOUT ANY WARRANTY; without even the implied warranty of
    MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
    GNU General Public License for more details.

    You should have received a copy of the GNU General Public License
    along with this program.  If not, see <http://www.gnu.org/licenses/>.
'''

import sys
import math
import gpu
from gpu_extras.batch import batch_for_shader
from mathutils import Matrix, Vector
import numpy as np

# GLSL vertex shader for flat-color geometry: transforms each 2D `pos`
# attribute by the ModelViewProjectionMatrix uniform (z = 0, w = 1).
default_vertex_shader = '''
uniform mat4 ModelViewProjectionMatrix;

in vec2 pos;

void main()
{
    gl_Position = ModelViewProjectionMatrix * vec4(pos, 0, 1.0);
}
'''

# GLSL fragment shader paired with default_vertex_shader: every fragment is
# written with the single `color` uniform.
default_fragment_shader = '''
uniform vec4 color;

out vec4 fragColor;

void main()
{
    fragColor = color;
}
'''

# GLSL vertex shader for dashed lines: transforms `pos` like the default
# shader and forwards the per-vertex `arcLength` (distance along the line,
# supplied by the caller) to the fragment stage, where it is interpolated.
dotted_line_vertex_shader = '''
uniform mat4 ModelViewProjectionMatrix;

in vec2 pos;
in float arcLength;

out float arcLengthOut;

void main()
{
    arcLengthOut = arcLength;

    gl_Position = ModelViewProjectionMatrix * vec4(pos, 0, 1.0);
}
'''

# GLSL fragment shader for dashed lines: picks color1 where
# sin((arcLength + offset) * scale) <= 0.5 and color2 elsewhere, so `scale`
# controls the dash frequency and `offset` shifts the pattern along the line.
dotted_line_fragment_shader = '''
uniform float scale;
uniform float offset;
uniform vec4 color1;
uniform vec4 color2;

in float arcLengthOut;

out vec4 fragColor;

void main()
{
    if (step(sin((arcLengthOut + offset) * scale), 0.5) == 1) {
        fragColor = color1;
    } else {
        fragColor = color2;
    }
}
'''

# GLSL vertex shader for textured quads: transforms `pos` and passes the
# `texCoord` attribute through to the fragment stage.
image_vertex_shader = '''
uniform mat4 ModelViewProjectionMatrix;

in vec2 pos;
in vec2 texCoord;

out vec2 texCoordOut;

void main()
{
    gl_Position = ModelViewProjectionMatrix * vec4(pos, 0, 1.0);
    texCoordOut = texCoord;
}
'''

# GLSL fragment shader for textured quads: outputs the `image` sampler's
# color at the interpolated texture coordinate, unmodified.
image_fragment_shader = '''
uniform sampler2D image;

in vec2 texCoordOut;

out vec4 fragColor;

void main()
{
    fragColor = texture(image, texCoordOut);
}
'''

def make_scale_matrix(scale):
    """Return a 4x4 mathutils Matrix that scales X by scale[0] and Y by scale[1].

    Z and W stay unscaled (1.0); only the first two items of ``scale`` are read.
    """
    diagonal = (scale[0], scale[1], 1.0, 1.0)
    return Matrix([
        [diagonal[row] if row == col else 0 for col in range(4)]
        for row in range(4)
    ])

class UIRenderer:
    """Draws UI overlays (dashed selection frames, images, info boxes) via `gpu`.

    Every public draw method restores the GPU state it changes (blend mode,
    line width, matrix stacks), even when drawing raises.
    """

    def __init__(self):
        """Compile the shaders and look up the color uniforms once."""
        self.default_shader = gpu.types.GPUShader(default_vertex_shader,
                default_fragment_shader)
        self.default_shader_u_color = self.default_shader.uniform_from_name("color")

        self.dotted_line_shader = gpu.types.GPUShader(dotted_line_vertex_shader,
                dotted_line_fragment_shader)
        self.dotted_line_shader_u_color1 = self.dotted_line_shader.uniform_from_name("color1")
        self.dotted_line_shader_u_color2 = self.dotted_line_shader.uniform_from_name("color2")

        # A custom textured-quad shader is used instead of the builtin
        # '2D_IMAGE' shader.
        self.image_shader = gpu.types.GPUShader(image_vertex_shader,
                image_fragment_shader)

    def render_selection_frame(self, pos, size, rot=0, scale=(1.0, 1.0)):
        """Draw a dashed, alpha-blended outline of a transformed rectangle.

        Args:
            pos: lower-left corner (x, y) of the untransformed rectangle, in
                the current model-view space (presumably region pixels —
                confirm with callers).
            size: (width, height) of the rectangle.
            rot: rotation around Z in radians (``Matrix.Rotation``), pivoting
                on the rectangle's center.
            scale: (sx, sy) scale, also about the rectangle's center.
        """
        width, height = size[0], size[1]

        # T <= R <= S <= centering. The transform is baked into the vertices
        # on the CPU, so the GPU matrix stack is left untouched.
        mat = Matrix.Translation([pos[0] + width / 2.0, pos[1] + height / 2.0, 0]) \
                @ Matrix.Rotation(rot, 4, 'Z') \
                @ make_scale_matrix(scale) \
                @ Matrix.Translation([-width / 2.0, -height / 2.0, 0])

        # Closed outline: the first corner is repeated at the end.
        corners = [[0, 0], [0, height], [width, height], [width, 0], [0, 0]]
        verts = np.array([(mat @ Vector(c + [0, 1]))[:2] for c in corners], 'f')

        # Cumulative distance along the outline; the fragment shader uses it
        # to alternate between the two dash colors.
        arc_lengths = [0]
        for a, b in zip(verts[:-1], verts[1:]):
            arc_lengths.append(arc_lengths[-1] + np.linalg.norm(a - b))

        batch = batch_for_shader(self.dotted_line_shader, 'LINE_STRIP',
            {"pos": verts, "arcLength": arc_lengths})

        prev_blend = gpu.state.blend_get()
        prev_line_width = gpu.state.line_width_get()
        gpu.state.blend_set('ALPHA')
        gpu.state.line_width_set(2.0)
        try:
            self.dotted_line_shader.bind()

            self.dotted_line_shader.uniform_float("scale", 0.6)
            self.dotted_line_shader.uniform_float("offset", 0)
            self.dotted_line_shader.uniform_vector_float(self.dotted_line_shader_u_color1,
                    np.array([1.0, 1.0, 1.0, 0.5], 'f'), 4)
            self.dotted_line_shader.uniform_vector_float(self.dotted_line_shader_u_color2,
                    np.array([0.0, 0.0, 0.0, 0.5], 'f'), 4)

            batch.draw(self.dotted_line_shader)
        finally:
            # Restore both pieces of state so later draws are unaffected.
            gpu.state.line_width_set(prev_line_width)
            gpu.state.blend_set(prev_blend)

    def render_image_sub(self, img, pos, size, rot, scale):
        """Draw ``img`` as a textured quad; does not change the blend state.

        The quad spans ``pos`` to ``pos + size``. It is rotated by ``rot``
        (radians, around Z) and scaled by ``scale`` about its center, using
        the GPU model-view matrix stack, which is restored on exit.
        """
        width, height = size[0], size[1]

        texture = gpu.texture.from_image(img)

        with gpu.matrix.push_pop():
            gpu.matrix.translate([pos[0] + width / 2.0, pos[1] + height / 2.0])
            gpu.matrix.multiply_matrix(
                    Matrix.Rotation(rot, 4, 'Z'))
            gpu.matrix.scale(scale)
            gpu.matrix.translate([-width / 2.0, -height / 2.0])

            batch = batch_for_shader(self.image_shader, 'TRI_FAN',
                {
                    "pos": [
                        (0, 0),
                        (width, 0),
                        size,
                        (0, height)
                    ],
                    "texCoord": [(0, 0), (1, 0), (1, 1), (0, 1)]
                })

            self.image_shader.bind()

            self.image_shader.uniform_sampler('image', texture)

            batch.draw(self.image_shader)

    def render_image(self, img, pos, size, rot=0, scale=(1.0, 1.0)):
        """Draw ``img`` with alpha blending (see ``render_image_sub``)."""
        prev_blend = gpu.state.blend_get()
        gpu.state.blend_set('ALPHA')
        try:
            self.render_image_sub(img, pos, size, rot, scale)
        finally:
            gpu.state.blend_set(prev_blend)

    def render_image_offscreen(self, img, rot=0, scale=(1.0, 1.0)):
        """Render ``img`` rotated/scaled into an offscreen buffer and read it back.

        The offscreen is sized to the bounding box of the transformed image
        (at least 1x1 px), cleared to transparent black, and the image is
        drawn centered without blending.

        Returns:
            (buff, ofs_width, ofs_height) where ``buff`` is the result of
            ``fb.read_color(..., 4, 0, 'UBYTE')``: 4 UBYTE channels per
            pixel covering ofs_width x ofs_height.
        """
        width, height = img.size[0], img.size[1]

        # Same rotate/scale/centering transform render_image_sub applies,
        # minus the final translation, so the image is centered on the origin.
        mat = Matrix.Rotation(rot, 4, 'Z') \
                @ make_scale_matrix(scale) \
                @ Matrix.Translation([-width / 2.0, -height / 2.0, 0])

        # Bounding box of the transformed image corners.
        corners = [mat @ Vector((x, y, 0, 1))
                for x, y in ((0, 0), (width, 0), (0, height), (width, height))]
        xs = [c[0] for c in corners]
        ys = [c[1] for c in corners]

        # Clamp to 1 px so a degenerate transform (e.g. zero scale) does not
        # request an empty offscreen or divide by zero below.
        ofs_width = max(1, math.ceil(max(xs) - min(xs)))
        ofs_height = max(1, math.ceil(max(ys) - min(ys)))

        ofs = gpu.types.GPUOffScreen(ofs_width, ofs_height)
        try:
            with ofs.bind():
                fb = gpu.state.active_framebuffer_get()
                fb.clear(color=(0.0, 0.0, 0.0, 0.0))

                # push_pop() only saves the model-view matrix; the projection
                # is overwritten here too, so save it as well to avoid leaking
                # an identity projection to the caller.
                with gpu.matrix.push_pop(), gpu.matrix.push_pop_projection():
                    gpu.matrix.load_projection_matrix(Matrix.Identity(4))

                    # Map offscreen pixels to NDC [-1, 1] and cancel the
                    # centering translation render_image_sub applies.
                    gpu.matrix.load_identity()
                    gpu.matrix.scale([1.0 / (ofs_width / 2.0), 1.0 / (ofs_height / 2.0)])
                    gpu.matrix.translate([-width / 2.0, -height / 2.0])

                    self.render_image_sub(img, [0, 0], [width, height], rot, scale)

                buff = fb.read_color(0, 0, ofs_width, ofs_height, 4, 0, 'UBYTE')
        finally:
            # Release the GPU offscreen even if rendering or readback fails.
            ofs.free()

        return buff, ofs_width, ofs_height

    def render_info_box(self, pos1, pos2):
        """Fill the axis-aligned rectangle with corners pos1/pos2 in 70%-alpha black."""
        verts = [
            pos1,
            (pos2[0], pos1[1]),
            (pos1[0], pos2[1]),
            pos2
        ]

        # Two triangles covering the quad.
        indices = [
            (0, 1, 2),
            (2, 1, 3)
        ]

        batch = batch_for_shader(self.default_shader, 'TRIS',
            {"pos": verts}, indices=indices)

        prev_blend = gpu.state.blend_get()
        gpu.state.blend_set('ALPHA')
        try:
            self.default_shader.bind()

            self.default_shader.uniform_vector_float(self.default_shader_u_color,
                    np.array([0, 0, 0, 0.7], 'f'), 4)

            batch.draw(self.default_shader)
        finally:
            gpu.state.blend_set(prev_blend)