//-------------------------------------------------------------------------------------------------------------------------------------------------------------
//
// Copyright 2023 Apple Inc.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//
//-------------------------------------------------------------------------------------------------------------------------------------------------------------


#include "MeshUtils.hpp"

#include <simd/simd.h>

// Per-vertex record used by newCubeMesh(): three float4 attributes stored back to back.
// The cube's vertex table initializes these members positionally, so their order matters.
// (newScreenQuad() uses its own, smaller local vertex struct rather than this one.)
struct VertexData
{
    simd::float4 position; // the cube table writes corners with w = 1
    simd::float4 normal;   // the cube table writes face normals with w = 0
    simd::float4 texcoord; // only x and y carry data in the cube table; z and w are 0
};

// Builds an origin-centred cube with edge length `size`.
//
// Each of the six faces gets its own four vertices (24 in total) so that every
// face carries its own normal and a full 0..1 texture-coordinate range. Faces are
// indexed as two triangles each (36 uint16 indices).
//
// size    - edge length of the cube; the corners sit at +/- size/2 on every axis.
// pDevice - device used to create the two shared-storage buffers.
//
// Returns an IndexedMesh referring to the newly created vertex and index buffers.
// The buffers are not released here.
IndexedMesh mesh_utils::newCubeMesh(float size, MTL::Device* pDevice)
{
    // Half extent: distance from the centre to each face.
    const float h = size * 0.5f;

    // Outward face normals (w = 0).
    const simd::float4 nFront  { 0.f, 0.f, 1.f, 0.f };
    const simd::float4 nRight  { 1.f, 0.f, 0.f, 0.f };
    const simd::float4 nBack   { 0.f, 0.f, -1.f, 0.f };
    const simd::float4 nLeft   { -1.f, 0.f, 0.f, 0.f };
    const simd::float4 nTop    { 0.f, 1.f, 0.f, 0.f };
    const simd::float4 nBottom { 0.f, -1.f, 0.f, 0.f };

    // The four texture coordinates every face cycles through, in vertex order.
    const simd::float4 uv0 { 0.f, 1.f, 0.f, 0.f };
    const simd::float4 uv1 { 1.f, 1.f, 0.f, 0.f };
    const simd::float4 uv2 { 1.f, 0.f, 0.f, 0.f };
    const simd::float4 uv3 { 0.f, 0.f, 0.f, 0.f };

    const VertexData cubeVertices[] = {
        // front (+z)
        { { -h, -h, +h, 1.f }, nFront, uv0 },
        { { +h, -h, +h, 1.f }, nFront, uv1 },
        { { +h, +h, +h, 1.f }, nFront, uv2 },
        { { -h, +h, +h, 1.f }, nFront, uv3 },

        // right (+x)
        { { +h, -h, +h, 1.f }, nRight, uv0 },
        { { +h, -h, -h, 1.f }, nRight, uv1 },
        { { +h, +h, -h, 1.f }, nRight, uv2 },
        { { +h, +h, +h, 1.f }, nRight, uv3 },

        // back (-z)
        { { +h, -h, -h, 1.f }, nBack, uv0 },
        { { -h, -h, -h, 1.f }, nBack, uv1 },
        { { -h, +h, -h, 1.f }, nBack, uv2 },
        { { +h, +h, -h, 1.f }, nBack, uv3 },

        // left (-x)
        { { -h, -h, -h, 1.f }, nLeft, uv0 },
        { { -h, -h, +h, 1.f }, nLeft, uv1 },
        { { -h, +h, +h, 1.f }, nLeft, uv2 },
        { { -h, +h, -h, 1.f }, nLeft, uv3 },

        // top (+y)
        { { -h, +h, +h, 1.f }, nTop, uv0 },
        { { +h, +h, +h, 1.f }, nTop, uv1 },
        { { +h, +h, -h, 1.f }, nTop, uv2 },
        { { -h, +h, -h, 1.f }, nTop, uv3 },

        // bottom (-y)
        { { -h, -h, -h, 1.f }, nBottom, uv0 },
        { { +h, -h, -h, 1.f }, nBottom, uv1 },
        { { +h, -h, +h, 1.f }, nBottom, uv2 },
        { { -h, -h, +h, 1.f }, nBottom, uv3 }
    };

    // Two triangles per face, sharing the face's first and third vertex.
    const uint16_t cubeIndices[] = {
        /* front  */  0,  1,  2,  2,  3,  0,
        /* right  */  4,  5,  6,  6,  7,  4,
        /* back   */  8,  9, 10, 10, 11,  8,
        /* left   */ 12, 13, 14, 14, 15, 12,
        /* top    */ 16, 17, 18, 18, 19, 16,
        /* bottom */ 20, 21, 22, 22, 23, 20,
    };

    constexpr size_t kIndexCount = sizeof(cubeIndices) / sizeof(cubeIndices[0]);

    // Create each buffer and fill it straight away from the tables above.
    MTL::Buffer* pVertexBuffer = pDevice->newBuffer(sizeof(cubeVertices), MTL::ResourceStorageModeShared);
    memcpy(pVertexBuffer->contents(), cubeVertices, sizeof(cubeVertices));

    MTL::Buffer* pIndexBuffer = pDevice->newBuffer(sizeof(cubeIndices), MTL::ResourceStorageModeShared);
    memcpy(pIndexBuffer->contents(), cubeIndices, sizeof(cubeIndices));

    return IndexedMesh {
        .pVertices  = pVertexBuffer,
        .pIndices   = pIndexBuffer,
        .numIndices = kIndexCount,
        .indexType  = MTL::IndexTypeUInt16,
        .winding    = MTL::WindingCounterClockwise
    };
}

// Builds a flat, square floor mesh tessellated into a regular grid.
//
// The floor lies in the plane y = -0.5, spans `size` units along x (centred on 0)
// and `size` units along z (centred on z = -10). It is split into divs x divs
// cells, each emitted as two triangles with 32-bit indices.
//
// size    - edge length of the square.
// divs    - number of cells per side. A value of 0 is treated as 1: the grid
//           coordinates are computed as i / divs, so 0 would yield NaN positions
//           and an empty (zero-length) index buffer request.
// pDevice - device used to create the two shared-storage buffers.
//
// Returns an IndexedMesh (clockwise winding, uint32 indices) whose vertex buffer
// holds one simd::float4 position per grid point, row by row. The buffers are not
// released here.
IndexedMesh mesh_utils::newHorizontalQuad(float size, uint32_t divs, MTL::Device* pDevice)
{
    // Guard against a zero divisor / empty mesh; one cell is the coarsest valid grid.
    const uint32_t cells       = (divs == 0) ? 1u : divs;
    const uint32_t vertsPerRow = cells + 1;

    const float s = size * 0.5f;

    // Extents of the floor. Rows run from zStart to zEnd, columns from xStart to xEnd.
    const float floorY = -0.5f;
    const float xStart = -s;
    const float xEnd   = +s;
    const float zStart = +s - 10;
    const float zEnd   = -s - 10;

    // Grid points, stored row-major: index = row * vertsPerRow + column.
    std::vector<simd::float4> tessellatedFloorVertices;
    tessellatedFloorVertices.reserve(static_cast<size_t>(vertsPerRow) * vertsPerRow);
    for (uint32_t z = 0; z <= cells; ++z)
    {
        const float dz = z / static_cast<float>(cells);
        const float pz = simd::lerp(zStart, zEnd, dz);

        for (uint32_t x = 0; x <= cells; ++x)
        {
            const float dx = x / static_cast<float>(cells);
            const float px = simd::lerp(xStart, xEnd, dx);
            tessellatedFloorVertices.emplace_back(simd::float4 { px, floorY, pz, 1.f });
        }
    }

    // Two triangles per cell. `corner` is the cell's first grid point and
    // `nextRow` is the point directly after it in the following row.
    std::vector<uint32_t> tessellatedFloorIndices;
    tessellatedFloorIndices.reserve(static_cast<size_t>(cells) * cells * 6);
    for (uint32_t row = 0; row < cells; ++row)
    {
        for (uint32_t col = 0; col < cells; ++col)
        {
            const uint32_t corner  = row * vertsPerRow + col;
            const uint32_t nextRow = corner + vertsPerRow;

            tessellatedFloorIndices.emplace_back(corner);
            tessellatedFloorIndices.emplace_back(nextRow);
            tessellatedFloorIndices.emplace_back(nextRow + 1);

            tessellatedFloorIndices.emplace_back(corner);
            tessellatedFloorIndices.emplace_back(nextRow + 1);
            tessellatedFloorIndices.emplace_back(corner + 1);
        }
    }

    IndexedMesh horizontalQuad;

    horizontalQuad.winding    = MTL::WindingClockwise;
    horizontalQuad.indexType  = MTL::IndexTypeUInt32;
    horizontalQuad.numIndices = static_cast<uint32_t>(tessellatedFloorIndices.size());

    const size_t vertexBytes = sizeof(simd::float4) * tessellatedFloorVertices.size();
    horizontalQuad.pVertices = pDevice->newBuffer(vertexBytes, MTL::ResourceStorageModeShared);
    memcpy(horizontalQuad.pVertices->contents(), tessellatedFloorVertices.data(), vertexBytes);

    const size_t indexBytes = sizeof(uint32_t) * tessellatedFloorIndices.size();
    horizontalQuad.pIndices = pDevice->newBuffer(indexBytes, MTL::ResourceStorageModeShared);
    memcpy(horizontalQuad.pIndices->contents(), tessellatedFloorIndices.data(), indexBytes);

    return horizontalQuad;
}

// Builds a quad covering x and y in [-1, +1] at z = 0, with texture coordinates
// whose origin is at the top-left corner (v grows downwards).
//
// pDevice - device used to create the two shared-storage buffers.
//
// Returns an IndexedMesh (counter-clockwise winding, six uint16 indices) referring
// to the newly created vertex and index buffers. The buffers are not released here.
IndexedMesh mesh_utils::newScreenQuad(MTL::Device* pDevice)
{
    // Local vertex layout for this quad only: a position plus a 2D texture coordinate.
    struct ScreenVertex
    {
        simd::float4 position;
        simd::float2 texcoord;
    };

    // The compiler pads the struct to 32 bytes, so in the buffer each vertex
    // occupies the space of two float4s even though texcoord is a float2.
    static_assert(sizeof(ScreenVertex) == 32);

    const ScreenVertex corners[] = {
        { { -1.f, +1.f, 0.f, 1.f }, { 0.f, 0.f } }, // top-left
        { { -1.f, -1.f, 0.f, 1.f }, { 0.f, 1.f } }, // bottom-left
        { { +1.f, -1.f, 0.f, 1.f }, { 1.f, 1.f } }, // bottom-right
        { { +1.f, +1.f, 0.f, 1.f }, { 1.f, 0.f } }  // top-right
    };

    // Two triangles sharing the top-left / bottom-right diagonal.
    const uint16_t triangles[] = { 0, 1, 2, 2, 3, 0 };

    constexpr uint32_t kIndexCount = sizeof(triangles) / sizeof(triangles[0]);

    MTL::Buffer* pVertexBuffer = pDevice->newBuffer(sizeof(corners), MTL::ResourceStorageModeShared);
    memcpy(pVertexBuffer->contents(), corners, sizeof(corners));

    MTL::Buffer* pIndexBuffer = pDevice->newBuffer(sizeof(triangles), MTL::ResourceStorageModeShared);
    memcpy(pIndexBuffer->contents(), triangles, sizeof(triangles));

    IndexedMesh screenQuad {
        .pVertices  = pVertexBuffer,
        .pIndices   = pIndexBuffer,
        .numIndices = kIndexCount,
        .indexType  = MTL::IndexTypeUInt16,
        .winding    = MTL::WindingCounterClockwise
    };

    return screenQuad;
}

