//-------------------------------------------------------------------------------------------------------------------------------------------------------------
//
// Copyright 2023 Apple Inc.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//
//-------------------------------------------------------------------------------------------------------------------------------------------------------------

#include "ShaderPipelineBuilder.hpp"

#include <cassert>
#include <cstring>

#include <simd/simd.h>

#include "ShaderCompiler.hpp"

/// Builds the compute pipeline state for the inline ray-tracing shader
/// (`MainCS` in `<shaderSearchPath>/inline_rt.dxil`).
///
/// Root signature (v1.1) used to convert the DXIL:
///   - parameter 0: root SRV descriptor, register 0 / space 0, data-static
///                  (labelled "Acceleration Structure")
///   - parameter 1: descriptor table with one UAV, register 0 / space 1
///                  (labelled "RW Texture")
///
/// @param shaderSearchPath Directory that contains `inline_rt.dxil`.
/// @param pDevice          Device used to create the library and pipeline.
/// @return The compute pipeline state obtained from `newComputePipelineState`
///         (a `new*` call, so the caller owns the reference and must release it).
///         Asserts on failure; with asserts disabled, nullptr is returned.
MTL::ComputePipelineState* shader_pipeline::newComputeRTPipeline(const std::string& shaderSearchPath, MTL::Device* pDevice )
{
    // Resource layout:

    IRDescriptorRange1 uavRange = {
        .RangeType                         = IRDescriptorRangeTypeUAV,
        .NumDescriptors                    = 1,
        .BaseShaderRegister                = 0,
        .RegisterSpace                     = 1,
        .Flags                             = IRDescriptorRangeFlagNone,
        .OffsetInDescriptorsFromTableStart = 0
    };

    IRRootParameter1 params[] = {
        {   /* Acceleration Structure */
            .ParameterType    = IRRootParameterTypeSRV,
            .ShaderVisibility = IRShaderVisibilityAll,
            .Descriptor       = {
                .ShaderRegister = 0,
                .RegisterSpace  = 0,
                .Flags          = IRRootDescriptorFlagDataStatic }
        },
        {
            /* RW Texture */
            .ParameterType = IRRootParameterTypeDescriptorTable,
            .ShaderVisibility = IRShaderVisibilityAll,
            .DescriptorTable = {
                .NumDescriptorRanges = 1,
                .pDescriptorRanges = &uavRange
            }
        }
    };

    // Zero the whole descriptor so that fields not assigned below are not left
    // as uninitialized stack memory.
    IRVersionedRootSignatureDescriptor rootSigDesc;
    memset(&rootSigDesc, 0x0, sizeof(IRVersionedRootSignatureDescriptor));
    rootSigDesc.version                = IRRootSignatureVersion_1_1;
    rootSigDesc.desc_1_1.NumParameters = sizeof(params) / sizeof(IRRootParameter1);
    rootSigDesc.desc_1_1.pParameters   = params;

    IRError* pRootSigError    = nullptr;
    IRRootSignature* pRootSig = IRRootSignatureCreateFromDescriptor(&rootSigDesc, &pRootSigError);
    assert(pRootSig);

    // Make MetalLib:

    std::string computeLibraryPath = shaderSearchPath + "/inline_rt.dxil";
    MTL::Library* pLib             = newLibraryFromDXIL(computeLibraryPath, IRShaderStageCompute, "MainCS", pRootSig, pDevice);
    assert(pLib);

    // Initialise the error out-parameter: the failure path below must never read
    // an indeterminate pointer if the API does not populate it.
    NS::Error* pError                      = nullptr;
    MTL::Function* pFn                     = pLib->newFunction(MTLSTR("MainCS"));
    assert(pFn);
    MTL::ComputePipelineState* pComputePSO = pDevice->newComputePipelineState(pFn, &pError);
    if (!pComputePSO)
    {
        if (pError)
        {
            __builtin_printf("%s\n", pError->localizedDescription()->utf8String());
        }
        assert(false);
    }

    pFn->release();
    pLib->release();

    IRRootSignatureDestroy(pRootSig);

    return pComputePSO;
}

/// Builds the render pipeline state for the present pass (`MainVS` from
/// `present_vs.dxil`, `MainFS` from `present_fs.dxil`).
///
/// Root signature (v1.1), both parameters visible to the pixel stage only:
///   - parameter 0: descriptor table with one SRV,     register 0 / space 0
///   - parameter 1: descriptor table with one sampler, register 0 / space 1
///
/// Vertex input, read from buffer index kIRVertexBufferBindPoint, per vertex:
///   - attribute 0: float4 at offset 0
///   - attribute 1: float2 at offset 16
///   - stride: 32 bytes
///
/// Color attachment 0 is BGRA8Unorm_sRGB with source-alpha /
/// one-minus-source-alpha blending.
///
/// @param shaderSearchPath Directory that contains the two DXIL files.
/// @param pDevice          Device used to create the libraries and pipeline.
/// @return The render pipeline state obtained from `newRenderPipelineState`
///         (a `new*` call, so the caller owns the reference and must release it).
///         Asserts on failure; with asserts disabled, nullptr is returned.
MTL::RenderPipelineState* shader_pipeline::newPresentPipeline( const std::string& shaderSearchPath, MTL::Device* pDevice )
{
    std::vector<IRDescriptorRange1> srvRanges {
        IRDescriptorRange1 {
            .RangeType = IRDescriptorRangeTypeSRV,
            .NumDescriptors = 1,
            .BaseShaderRegister = 0,
            .RegisterSpace = 0,
            .Flags = IRDescriptorRangeFlagNone,
            .OffsetInDescriptorsFromTableStart = 0
        }
    };
    
    std::vector<IRDescriptorRange1> smpRanges = {
        IRDescriptorRange1 {
            .RangeType = IRDescriptorRangeTypeSampler,
            .NumDescriptors = 1,
            .BaseShaderRegister = 0,
            .RegisterSpace = 1,
            .Flags = IRDescriptorRangeFlagNone,
            .OffsetInDescriptorsFromTableStart = 0
        }
    };
    
    // Designated initializers must appear in declaration order in C++, so
    // NumDescriptorRanges comes before pDescriptorRanges.
    std::vector<IRRootParameter1> rootParams = {
        {
            .ParameterType = IRRootParameterTypeDescriptorTable,
            .ShaderVisibility = IRShaderVisibilityPixel,
            .DescriptorTable = {
                .NumDescriptorRanges = static_cast<uint32_t>(srvRanges.size()),
                .pDescriptorRanges = srvRanges.data()
            }
        },
        {
            .ParameterType = IRRootParameterTypeDescriptorTable,
            .ShaderVisibility = IRShaderVisibilityPixel,
            .DescriptorTable = {
                .NumDescriptorRanges = static_cast<uint32_t>(smpRanges.size()),
                .pDescriptorRanges = smpRanges.data()
            }
        }
    };
    
    // Zero the whole descriptor first: fields that are not assigned explicitly
    // below would otherwise be uninitialized stack memory handed to the converter.
    IRVersionedRootSignatureDescriptor rootSigDesc;
    memset(&rootSigDesc, 0x0, sizeof(IRVersionedRootSignatureDescriptor));
    rootSigDesc.version = IRRootSignatureVersion_1_1;
    rootSigDesc.desc_1_1.pParameters = rootParams.data();
    rootSigDesc.desc_1_1.NumParameters = static_cast<uint32_t>(rootParams.size());
    rootSigDesc.desc_1_1.NumStaticSamplers = 0;
    
    IRError* pIrError = nullptr;
    IRRootSignature* pRootSig = IRRootSignatureCreateFromDescriptor(&rootSigDesc, &pIrError);
    assert(pRootSig);
    
    NS::SharedPtr<MTL::Library> pVtxLib = NS::TransferPtr(newLibraryFromDXIL(shaderSearchPath + "/present_vs.dxil",
                                                                             IRShaderStageVertex,
                                                                             "MainVS",
                                                                             pRootSig,
                                                                             pDevice));
    assert(pVtxLib);
    
    NS::SharedPtr<MTL::Library> pFragLib = NS::TransferPtr(newLibraryFromDXIL(shaderSearchPath + "/present_fs.dxil",
                                                                              IRShaderStageFragment,
                                                                              "MainFS",
                                                                              pRootSig,
                                                                              pDevice));
    assert(pFragLib);
    
    NS::SharedPtr<MTL::Function> pVFn = NS::TransferPtr(pVtxLib->newFunction(MTLSTR("MainVS")));
    NS::SharedPtr<MTL::Function> pFFn = NS::TransferPtr(pFragLib->newFunction(MTLSTR("MainFS")));
    assert(pVFn);
    assert(pFFn);
    
    NS::SharedPtr<MTL::VertexDescriptor> pVtxDesc = NS::TransferPtr(MTL::VertexDescriptor::alloc()->init());
    auto pAttrib0 = pVtxDesc->attributes()->object(kIRStageInAttributeStartIndex + 0);
    auto pAttrib1 = pVtxDesc->attributes()->object(kIRStageInAttributeStartIndex + 1);
    auto pLayout  = pVtxDesc->layouts()->object(kIRVertexBufferBindPoint);
    
    pAttrib0->setFormat(MTL::VertexFormatFloat4);
    pAttrib0->setOffset(0);
    pAttrib0->setBufferIndex(kIRVertexBufferBindPoint);
    
    pAttrib1->setFormat(MTL::VertexFormatFloat2);
    pAttrib1->setOffset(sizeof(simd::float4));
    pAttrib1->setBufferIndex(kIRVertexBufferBindPoint);
    
    pLayout->setStride(sizeof(simd::float4) + sizeof(simd::float4)); // clang pads the VertexData struct to 32 bytes, so that's the stride (16+16)
    pLayout->setStepRate(1);
    pLayout->setStepFunction(MTL::VertexStepFunctionPerVertex);
    
    NS::SharedPtr<MTL::RenderPipelineDescriptor> pPsoDesc = NS::TransferPtr(MTL::RenderPipelineDescriptor::alloc()->init());
    pPsoDesc->setVertexDescriptor(pVtxDesc.get());
    pPsoDesc->setVertexFunction(pVFn.get());
    pPsoDesc->setFragmentFunction(pFFn.get());
    
    auto pColorDesc = pPsoDesc->colorAttachments()->object(0);
    pColorDesc->setPixelFormat(MTL::PixelFormatBGRA8Unorm_sRGB);
    pColorDesc->setBlendingEnabled(true);
    pColorDesc->setSourceRGBBlendFactor(MTL::BlendFactorSourceAlpha);
    pColorDesc->setSourceAlphaBlendFactor(MTL::BlendFactorSourceAlpha);
    pColorDesc->setDestinationRGBBlendFactor(MTL::BlendFactorOneMinusSourceAlpha);
    pColorDesc->setDestinationAlphaBlendFactor(MTL::BlendFactorOneMinusSourceAlpha);
    
    NS::Error* pMtlError = nullptr;
    MTL::RenderPipelineState* pPso = pDevice->newRenderPipelineState(pPsoDesc.get(), &pMtlError);
    if (!pPso)
    {
        // Report the Metal diagnostic before asserting, as the compute path does.
        if (pMtlError)
        {
            __builtin_printf("%s\n", pMtlError->localizedDescription()->utf8String());
        }
        assert(false);
    }
    
    IRRootSignatureDestroy(pRootSig);
    
    return pPso;
}
