import bpy
import sys
import json
import math
from enum import Enum
from functools import wraps

from operator import truediv
from pathlib import Path
from collections import defaultdict
from mathutils import Vector, Matrix


def traced(func):
    """Decorator that prints every call of ``func`` after it has returned.

    The printed line contains the function name, the positional arguments,
    the keyword arguments and the returned value. The wrapped function's
    return value is passed through untouched.
    """

    def call_and_report(*args, **kwargs):
        # ``args``, ``kwargs`` and ``result`` are named in the output through
        # the ``=`` f-string specifier, so these local names are significant.
        result = func(*args, **kwargs)
        print(f"{func.__name__}: {args=} {kwargs=} {result=}")
        return result

    # Copy name/docstring metadata from ``func`` onto the wrapper.
    return wraps(func)(call_and_report)


def bounding_box(objects, world_transform=True):
    """Compute one axis-aligned box enclosing the bound boxes of ``objects``.

    Args:
        objects: Iterable of objects exposing ``bound_box`` (a sequence of
            corner coordinates) and ``matrix_world``.
        world_transform: When true, every corner is multiplied by the owning
            object's ``matrix_world`` before it is measured; otherwise the
            corners are used as stored in ``bound_box``.

    Returns:
        A two-element list ``[minimum, maximum]`` of ``Vector`` values. Both
        are the zero vector when ``objects`` yields nothing.
    """
    lows = [0, 0, 0]
    highs = [0, 0, 0]
    seeded = False

    for ob in objects:
        if world_transform:
            corners = [ob.matrix_world @ Vector(raw) for raw in ob.bound_box]
        else:
            corners = [Vector(raw) for raw in ob.bound_box]

        if not seeded:
            # Start the running extremes from an actual corner so the zero
            # placeholders above never leak into the result.
            lows = [corners[0].x, corners[0].y, corners[0].z]
            highs = [corners[0].x, corners[0].y, corners[0].z]
            seeded = True

        for corner in corners:
            for axis, value in enumerate((corner.x, corner.y, corner.z)):
                # min/max keep the current extreme unless ``value`` is
                # strictly smaller/larger.
                lows[axis] = min(lows[axis], value)
                highs[axis] = max(highs[axis], value)

    return [Vector(tuple(lows)), Vector(tuple(highs))]


def bounding_dimension(bbBox):
    """Return the ``(x, y, z)`` extents of a ``[minimum, maximum]`` box.

    Args:
        bbBox: Sequence whose first item is the minimum corner and whose
            second item is the maximum corner; both expose ``x``, ``y``, ``z``.

    Returns:
        A 3-tuple with the per-axis difference ``maximum - minimum``.
    """
    upper = bbBox[1]
    lower = bbBox[0]
    size_x = upper.x - lower.x
    size_y = upper.y - lower.y
    size_z = upper.z - lower.z
    return size_x, size_y, size_z


def bounding_center(bbBox):
    """Return the midpoint of a ``[minimum, maximum]`` box.

    Args:
        bbBox: Sequence whose first two items are the minimum and maximum
            corner of the box.

    Returns:
        The minimum corner moved by half of the min-to-max difference.
    """
    lower, upper = bbBox[0], bbBox[1]
    half_diagonal = 0.5 * (upper - lower)
    return lower + half_diagonal


def save_scene(outPath):
    """Save the current Blender file at ``outPath`` with a ``.blend`` suffix.

    Args:
        outPath: Destination path; its suffix (if any) is replaced by
            ``.blend`` before saving.
    """
    blend_file = Path(outPath).with_suffix(".blend")
    bpy.ops.wm.save_as_mainfile(filepath=str(blend_file))


def import_gltf(filepath):
    """Import a glTF file and collapse it into a single, unparented object.

    The import is expected to leave a root object named ``main`` selected.
    When that root has more than one child, a join is run with the first
    child as the active object. The first child is then detached from the
    root, the root is removed from ``bpy.data.objects`` and the child is
    returned.

    Args:
        filepath: Path handed to ``bpy.ops.import_scene.gltf``.

    Returns:
        The first child of ``main`` (after the optional join).

    Raises:
        SystemExit: With status 1 when no selected object is named ``main``.
        AssertionError: When ``main`` has no children.
    """
    print(f"Loading in scene {filepath}")

    bpy.ops.import_scene.gltf(filepath=filepath)

    # Assumption: the import produces a root object called 'main'.
    main = next(
        (obj for obj in bpy.context.selected_objects if obj.name == "main"), None
    )
    if main is None:
        print("Could not find main object in the imported gltf")
        sys.exit(1)

    assert len(main.children) > 0, "Imported gltf doesnt have any children"

    if len(main.children) > 1:
        # TODO Maybe filter for mesh type?
        if hasattr(bpy.context, "temp_override"):
            # Newer Blender versions no longer accept a context-override dict
            # as positional operator argument; ``temp_override`` is the
            # supported replacement. The same two members are overridden as
            # in the legacy branch below.
            with bpy.context.temp_override(
                active_object=main.children[0],
                selected_objects=list(main.children),
            ):
                bpy.ops.object.join()
        else:
            # Legacy path for Blender versions without ``temp_override``.
            ctx = bpy.context.copy()
            ctx["active_object"] = main.children[0]
            ctx["selected_objects"] = main.children
            bpy.ops.object.join(ctx)

    child = main.children[0]
    child.parent = None
    # NOTE(review): if ``child.matrix_world`` already contains the root's
    # transform, this multiplication applies it a second time — confirm.
    child.matrix_world = main.matrix_world @ child.matrix_world
    bpy.data.objects.remove(main)
    return child


def scale_to_unit(obj):
    """Rescale ``obj`` uniformly by the reciprocal of its smallest extent.

    The extent is taken from ``bounding_box([obj])`` (world transform
    applied) and the smallest of the three dimensions is used.

    Args:
        obj: Object whose ``scale`` is multiplied in place.

    Returns:
        The same ``obj``, to allow chaining.

    Raises:
        ZeroDivisionError: When the smallest bounding-box dimension is zero.
    """
    extents = bounding_dimension(bounding_box([obj]))
    shortest = min(extents)
    factor = 1.0 / shortest
    obj.scale *= Vector.Fill(3, factor)
    return obj


def bb_plane(obj, side=(0, 0, -1)):
    """Return the world-space corners of one face of ``obj``'s bound box.

    Args:
        obj: Object exposing ``bound_box`` and ``matrix_world``.
        side: Axis selector. The first non-zero component picks the axis; a
            negative value selects the face at the local minimum, a positive
            value the face at the local maximum.

    Returns:
        The four selected corners, each multiplied by ``obj.matrix_world``.

    Raises:
        IndexError: When every component of ``side`` is zero.
        AssertionError: When the number of matching corners is not four.
    """
    local_corners = [Vector(raw) for raw in obj.bound_box]
    lower, upper = bounding_box([obj], world_transform=False)
    axis, sign = [(i, s) for i, s in enumerate(side) if s != 0][0]

    # Assumption: the untransformed bound_box is axis aligned, so corners on
    # the requested face share the exact min/max coordinate on ``axis``.
    def on_requested_face(corner):
        value = corner[axis]
        if sign < 0:
            return value == lower[axis]
        return sign > 0 and value == upper[axis]

    plane_coords = [
        obj.matrix_world @ corner
        for corner in local_corners
        if on_requested_face(corner)
    ]

    assert len(plane_coords) == 4, f"Expected to find 4 plane coords {plane_coords=}"
    return plane_coords


def plane_normal(plane):
    """Return the unit normal spanned by the first three points of ``plane``.

    Args:
        plane: Sequence of at least three points supporting subtraction;
            the differences must offer ``cross`` and ``normalized``.

    Returns:
        The normalized cross product of ``plane[1] - plane[0]`` and
        ``plane[2] - plane[0]``.
    """
    first_edge = plane[1] - plane[0]
    second_edge = plane[2] - plane[0]
    normal = first_edge.cross(second_edge)
    return normal.normalized()


def plane_center(plane):
    """Return the midpoint between ``plane[0]`` and the corner farthest from it.

    For the four corners of a rectangle this is the midpoint of a diagonal.

    Args:
        plane: Sequence of points (``Vector``-like) describing the plane.

    Returns:
        The midpoint as a vector.

    Raises:
        AssertionError: When no point lies at a non-zero distance from
            ``plane[0]`` (including an empty ``plane``).
    """
    farthest2 = 0
    center = Vector()
    for corner in plane:
        dist2 = (plane[0] - corner).length_squared
        if dist2 > farthest2:
            center = 0.5 * (corner - plane[0]) + plane[0]
            farthest2 = dist2
    # Validate through the distance, not through ``center`` itself: a center
    # that happens to lie exactly on the world origin is a valid result and
    # must not be rejected.
    assert farthest2 > 0, "Could not find a center"
    return center


def align_by_planes(obj, target, side=(0, 0, -1), centered=False):
    """Rotate and translate ``obj`` so a bound-box face lines up with ``target``'s.

    The face selected by ``side`` (see ``bb_plane``) is taken from both
    objects. ``obj`` is rotated about the center of its face so the face
    normals agree, then moved along the target normal until the faces are
    coplanar. With ``centered`` the move instead makes the midpoints between
    the selected face and its opposite face coincide along the target normal.

    Args:
        obj: Object whose ``matrix_world`` is modified.
        target: Reference object; it is only read.
        side: Axis selector passed to ``bb_plane``.
        centered: Center ``obj`` between the target's opposing faces instead
            of making the selected faces coplanar.

    Returns:
        A copy of ``obj.matrix_world`` taken before any modification.
    """
    bpy.context.view_layer.update()
    obj_plane = bb_plane(obj, side)
    target_plane = bb_plane(target, side)

    o_normal = plane_normal(obj_plane)
    t_normal = plane_normal(target_plane)
    o_center = plane_center(obj_plane)

    rotation_axis = o_normal.cross(t_normal)
    if rotation_axis.length_squared < 0.0001:
        # Set to zero because normalization would stretch a near-zero vector
        # to length 1.
        rotation_axis = Vector()
    rotation_axis.normalize()
    # Clamp on both sides: rounding can push the dot product of two unit
    # vectors slightly outside [-1, 1], which makes math.acos raise.
    cos_angle = max(-1.0, min(1.0, o_normal.dot(t_normal)))
    rotation_angle = math.acos(cos_angle)
    print(f"{math.degrees(rotation_angle)=} {rotation_axis=} {o_normal=} {t_normal=}")

    if rotation_axis.length_squared < 0.0001 and math.degrees(rotation_angle) > 179:
        # Normals are anti-parallel, so the cross product gives no axis.
        # Use X unless the normal is (anti-)parallel to X, in which case a
        # rotation about X would not flip it; use Y then.
        x_alignment = abs(o_normal.dot(Vector((1.0, 0.0, 0.0))))
        new_axis = (1.0, 0.0, 0.0) if x_alignment < 0.999 else (0.0, 1.0, 0.0)
        rotation_axis = Vector(new_axis)
        rot_plane = bb_plane(obj, new_axis)
        o_center = plane_center(rot_plane)
        print(f"{new_axis=} {o_center=} {rot_plane=}")

    print(f"{math.degrees(rotation_angle)=} {rotation_axis=} {o_normal=} {t_normal=}")

    obj_mat_org = obj.matrix_world.copy()
    # Rotate about ``o_center``: move it to the origin, rotate, move back.
    align_mat = (
        Matrix.Translation(o_center)
        @ Matrix.Rotation(rotation_angle, 4, rotation_axis)
        @ Matrix.Translation(-1.0 * o_center)
    )
    obj.matrix_world = align_mat @ obj.matrix_world
    bpy.context.view_layer.update()

    # Re-evaluate the plane after the rotation.
    obj_plane = bb_plane(obj, side)
    o_normal = plane_normal(obj_plane)
    # Distance between the two faces, measured along the target normal.
    offset = (obj_plane[0] - target_plane[0]).dot(t_normal) * t_normal

    if centered:
        opposite_side = tuple((-1.0 * s for s in side))

        target_opposite = bb_plane(target, opposite_side)
        object_opposite = bb_plane(obj, opposite_side)
        # Points halfway between the selected face and its opposite face.
        target_axis_center = (
            0.5 * (target_opposite[0] - target_plane[0]).dot(t_normal) * t_normal
            + target_plane[0]
        )
        object_axis_center = (
            0.5 * (object_opposite[0] - obj_plane[0]).dot(o_normal) * o_normal
            + obj_plane[0]
        )
        offset = (object_axis_center - target_axis_center).dot(t_normal) * t_normal

    obj.matrix_world = Matrix.Translation(-1.0 * offset) @ obj.matrix_world
    bpy.context.view_layer.update()

    return obj_mat_org


def do_physics_simulation(obj, colliderType, stopFrame):
    """Turn ``obj`` into an active rigid body, bake the caches and jump to the end.

    ``obj`` becomes the only selected and the active object, its origin is
    reset via ``bpy.ops.object.origin_set`` and a rigid body is added. The
    scene's end frame is set to ``stopFrame``, all point caches are freed and
    re-baked, and the scene is moved to its end frame.

    Args:
        obj: Object to simulate.
        colliderType: Value assigned to ``obj.rigid_body.collision_shape``.
        stopFrame: Value assigned to the scene's ``frame_end``.
    """
    # Use ``bpy.context`` directly: the module-level ``context`` name only
    # exists when this file is run as a script, so relying on it would raise
    # NameError for any other caller.
    bpy.ops.object.select_all(action="DESELECT")
    obj.select_set(True)
    bpy.context.view_layer.objects.active = obj
    bpy.ops.object.origin_set()
    bpy.ops.rigidbody.object_add()
    obj.rigid_body.type = "ACTIVE"
    obj.rigid_body.enabled = True
    obj.rigid_body.collision_shape = colliderType
    scene = bpy.context.scene
    scene.frame_end = stopFrame
    bpy.ops.ptcache.free_bake_all()
    bpy.ops.ptcache.bake_all()
    scene.frame_set(scene.frame_end)


# Script entry point: read a JSON settings file (first argument after "--"),
# import the glTF it names, align it to the scene's `_TARGET_` object,
# optionally run a physics bake, then render one still per camera.
if __name__ == "__main__":
    argv = sys.argv
    argv = argv[argv.index("--") + 1 :]  # get all args after "--"

    # Missing keys read as None thanks to the defaultdict factory.
    with open(argv[0], "rt") as f:
        settings = defaultdict(lambda: None, json.load(f))
    inPath = settings["inputPath"]
    renderingUseDefaults = settings["renderingUseDefaults"]
    renderTimeout = settings["renderTimeout"]
    useDenoising = settings["useDenoising"]
    samples = settings["samples"]
    outPath = settings["outPath"]

    # Optional remote debugging: wait for a debugpy client on port 5678.
    # A missing debugpy package is ignored on purpose.
    if settings["debugBlender"]:
        try:
            import debugpy

            # NOTE(review): Windows-style path; presumably only meaningful on
            # Windows — confirm before using this on other platforms.
            sys.executable = sys.exec_prefix + "\\python.exe"
            debugpy.listen(5678)
            debugpy.wait_for_client()
            breakpoint()
        except ImportError:
            pass

    # The loaded scene must contain a placeholder object to align against.
    if "_TARGET_" not in bpy.data.objects:
        print("Could not find the required _TARGET_ object to replace")
        sys.exit(1)

    target = bpy.data.objects["_TARGET_"]
    # Module-level name; do_physics_simulation reads it as a global.
    context = bpy.context

    main = scale_to_unit(import_gltf(inPath))

    # One alignment pass per axis. A direction of 0 means "center on this
    # axis": the positive side is used and `centered` is switched on.
    if settings["alignDirections"]:
        xSide, ySide, zSide = settings["alignDirections"]
        align_by_planes(
            main, target, side=(1 if xSide == 0 else xSide, 0, 0), centered=(xSide == 0)
        )
        align_by_planes(
            main, target, side=(0, 1 if ySide == 0 else ySide, 0), centered=(ySide == 0)
        )
        align_by_planes(
            main, target, side=(0, 0, 1 if zSide == 0 else zSide), centered=(zSide == 0)
        )

    # The placeholder has served its purpose; delete it from the scene.
    bpy.ops.object.select_all(action="DESELECT")
    target.select_set(True)
    bpy.ops.object.delete()

    if settings["enablePhysics"]:
        do_physics_simulation(main, settings["colliderType"], settings["stopFrame"])

    dg = bpy.context.evaluated_depsgraph_get()
    dg.update()

    # Override the render settings only when defaults were not requested.
    if not renderingUseDefaults:
        bpy.context.scene.render.engine = "CYCLES"
        bpy.context.scene.cycles.samples = int(samples)
        bpy.context.scene.cycles.use_denoising = useDenoising
        bpy.context.scene.cycles.time_limit = int(renderTimeout)

    cams = [c for c in context.scene.objects if c.type == "CAMERA"]

    print(f"Found {len(cams)} cameras")

    # Autozoom frames the selection, so the imported object must be selected
    # and active. The margin falls back to 0.05 when unset or falsy.
    if settings["enableAutozoom"]:
        margin = settings["autoZoomMargin"] or 0.05
        context.view_layer.objects.active = main
        main.select_set(True)

    # Render one still per camera.
    # NOTE(review): every camera uses the same `filepath`, so later renders
    # presumably overwrite earlier ones — confirm.
    for c in cams:
        context.scene.camera = c

        context.scene.render.resolution_x = settings["width"]
        context.scene.render.resolution_y = settings["height"]
        context.scene.render.filepath = outPath

        if settings["enableAutozoom"]:
            bpy.ops.view3d.camera_to_view_selected()

            # Create a margin around the object by increasing the fov
            # (net effect: angle = atan(tan(angle) / (1 - margin))).
            # NOTE(review): tan() is applied to the full angle; a half-angle
            # form may have been intended — confirm.
            c.data.angle += (
                math.atan((1 / (1 - margin)) * math.tan(c.data.angle)) - c.data.angle
            )

        bpy.ops.render.render(write_still=True)

    if settings["saveScene"]:
        save_scene(outPath)
