//-------------------------------------------------------------------------------------------------------------------------------------------------------------
//
// Copyright 2023 Apple Inc.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//
//-------------------------------------------------------------------------------------------------------------------------------------------------------------


#include "ShaderCompiler.hpp"
#include <cstdio>
#include <fstream>
#include <vector>

// Loads a compiled DXIL file from disk and wraps its bytes in an IRObject.
//
// Returns nullptr if the file cannot be opened, its size cannot be determined,
// it is empty, or reading it fails. The returned object must be released by the
// caller with IRObjectDestroy().
static IRObject* newDXILObject(const std::string& dxilPath)
{
    // DXIL is binary data: open in binary mode so no newline translation can occur,
    // and open positioned at the end so tellg() reports the file size directly.
    std::ifstream in(dxilPath, std::ios::binary | std::ios::ate);
    if (!in)
    {
        return nullptr;
    }

    const std::streamoff fileSize = in.tellg();
    if (fileSize <= 0)
    {
        // tellg() yields -1 on failure; an empty file has no bytecode to convert.
        return nullptr;
    }
    in.seekg(0, std::ios::beg);

    std::vector<uint8_t> dxilBytecode(static_cast<size_t>(fileSize));
    if (!in.read(reinterpret_cast<char*>(dxilBytecode.data()), fileSize))
    {
        // Short or failed read: do not hand a partially filled buffer to the converter.
        return nullptr;
    }

    // IRBytecodeOwnershipCopy asks IRConverter to take its own copy of the bytes,
    // so the local buffer can be released when this function returns.
    return IRObjectCreateFromDXIL(dxilBytecode.data(), dxilBytecode.size(), IRBytecodeOwnershipCopy);
}

// Synthesizes a Metal library containing a stage-in function for a vertex shader.
//
// The function is built by IRMetalLibSynthesizeStageInFunction from the vertex
// input layout (inputLayoutDescriptor) and the vertex stage's reflection data
// (pVertexStageReflection), then loaded on pDevice.
// Returns nullptr, after printing a diagnostic, if synthesis or library creation
// fails. The returned library comes from a metal-cpp `new*` call, so the caller
// is expected to release it.
MTL::Library* newStageInLibrary(const IRVersionedInputLayoutDescriptor& inputLayoutDescriptor,
    IRShaderReflection* pVertexStageReflection,
    MTL::Device* pDevice)
{
    MTL::Library* pStageInLib = nullptr;

    IRCompiler* pCompiler              = IRCompilerCreate();
    IRMetalLibBinary* pStageInMetallib = IRMetalLibBinaryCreate();

    if (IRMetalLibSynthesizeStageInFunction(pCompiler, pVertexStageReflection, &inputLayoutDescriptor, pStageInMetallib))
    {
        NS::Error* pError = nullptr;
        pStageInLib       = pDevice->newLibrary(IRMetalLibGetBytecodeData(pStageInMetallib), &pError);
        if (!pStageInLib)
        {
            // Report instead of silently dropping the error Metal produced.
            printf("Error creating stage-in library: %s\n",
                pError ? pError->localizedDescription()->utf8String() : "unknown error");
        }
    }
    else
    {
        printf("Error synthesizing stage-in function\n");
    }

    // The bytecode data was only needed to create the library; release the
    // converter objects on every path.
    IRMetalLibBinaryDestroy(pStageInMetallib);
    IRCompilerDestroy(pCompiler);

    return pStageInLib;
}

// Compiles the DXIL file at dxilPath into a Metal library for one shader stage.
//
// entryPointName selects the entry point to convert, and pRootSig is installed as
// the global root signature used during conversion (minimum target: macOS 14.0.0).
// Returns nullptr if the file cannot be loaded, conversion to Metal IR fails, the
// metallib for shaderStage cannot be extracted, or Metal rejects it. Diagnostics
// are printed for every failure except the load. The returned library comes from
// a metal-cpp `new*` call, so the caller is expected to release it.
MTL::Library* newLibraryFromDXIL(const std::string& dxilPath,
    IRShaderStage shaderStage,
    const char* entryPointName,
    const IRRootSignature* pRootSig,
    MTL::Device* pDevice)
{
    IRObject* pDXIL = newDXILObject(dxilPath);
    if (!pDXIL)
    {
        return nullptr;
    }

    //#-code-listing(dxilToMetalIR)
    IRCompiler* pCompiler = IRCompilerCreate();
    IRCompilerSetMinimumDeploymentTarget(pCompiler, IROperatingSystem_macOS, "14.0.0");
    IRCompilerSetEntryPointName(pCompiler, entryPointName);

    IRError* pError = nullptr;
    IRCompilerSetGlobalRootSignature(pCompiler, pRootSig);

    IRObject* pAIR = IRCompilerAllocCompileAndLink(pCompiler, nullptr, pDXIL, &pError);
    //#-end-code-listing

    MTL::Library* pLib = nullptr;

    if (!pAIR)
    {
        __builtin_printf("Error compiling shader to AIR: %s\n",
            pError ? (const char*)IRErrorGetPayload(pError) : "unknown error");
        if (pError)
        {
            IRErrorDestroy(pError);
        }
    }
    else
    {
        // Make MetalLib:
        IRMetalLibBinary* pMetalLib = IRMetalLibBinaryCreate();
        if (!IRObjectGetMetalLibBinary(pAIR, shaderStage, pMetalLib))
        {
            __builtin_printf("Error getting metallib binary\n");
        }
        else
        {
            //#-code-listing(loadLibraries)
            // Hand the binary's own bytecode data straight to Metal instead of
            // copying it into an intermediate buffer first.
            dispatch_data_t metallib = IRMetalLibGetBytecodeData(pMetalLib);
            NS::Error* pMtlError     = nullptr;
            pLib                     = pDevice->newLibrary(metallib, &pMtlError);
            //#-end-code-listing

            if (!pLib)
            {
                __builtin_printf("Error creating Metal library: %s\n",
                    pMtlError ? pMtlError->localizedDescription()->utf8String() : "unknown error");
            }
        }

        // Destroy the binary on both paths; the extraction-failure path used to leak it.
        IRMetalLibBinaryDestroy(pMetalLib);
        IRObjectDestroy(pAIR);
    }

    // Cleanup shared by every path past the DXIL load.
    IRCompilerDestroy(pCompiler);
    IRObjectDestroy(pDXIL);

    return pLib;
}

std::pair<MTL::Library*, IRShaderReflection*> newLibraryWithReflectionFromDXIL(
    const std::string& dxilPath,
    IRShaderStage shaderStage,
    const char* entryPointName,
    const IRRootSignature* pRootSignature,
    MTL::Device* pDevice,
    bool enableGSEmulation)
{
    MTL::Library* pOutMetalLib = nullptr;

    // Load the DXIL file to memory.
    IRObject* pDXIL = newDXILObject(dxilPath);

    // Create the IRConverter compiler to compile DXIL to Metal IR.
    IRError* pCompError   = nil;
    IRCompiler* pCompiler = IRCompilerCreate();

    // Configure the IRConverter compiler to set the minimum deployment target and
    // enable geometry stage emulation if the caller requests it.
    IRCompilerSetMinimumDeploymentTarget(pCompiler, IROperatingSystem_macOS, "14.0.0");
    IRCompilerSetGlobalRootSignature(pCompiler, pRootSignature);
    IRCompilerSetEntryPointName(pCompiler, entryPointName);
    IRCompilerEnableGeometryAndTessellationEmulation(pCompiler, enableGSEmulation);

    // Compile DXIL to a Metal IR object.
    IRObject* pAIR = IRCompilerAllocCompileAndLink(pCompiler, nullptr, pDXIL, &pCompError);

    // Check for compilation errors.
    if (!pAIR)
    {
        printf("%s\n", (const char*)IRErrorGetPayload(pCompError));

        // Free resources.
        IRErrorDestroy(pCompError);
        IRCompilerDestroy(pCompiler);
        IRObjectDestroy(pDXIL);

        // Return a null tuple upon encountering an error.
        return { nullptr, nullptr };
    }

    {
        // Obtain the metallib from the Metal IR.
        IRMetalLibBinary* pMetallibBin = IRMetalLibBinaryCreate();
        if (IRObjectGetMetalLibBinary(pAIR, shaderStage, pMetallibBin))
        {
            dispatch_data_t metallibData = IRMetalLibGetBytecodeData(pMetallibBin);

            NS::Error* pMtlError = nil;
            pOutMetalLib         = pDevice->newLibrary(metallibData, &pMtlError);

            if (!pOutMetalLib)
            {
                printf("%s\n", pMtlError->localizedDescription()->utf8String());
            }

            IRMetalLibBinaryDestroy(pMetallibBin);
        }
    }

    IRShaderReflection* pOutReflection = nullptr;
    {
        pOutReflection = IRShaderReflectionCreate();
        IRObjectGetReflection(pAIR, shaderStage, pOutReflection);
    }

    // Free resources.
    IRObjectDestroy(pAIR);
    IRCompilerDestroy(pCompiler);
    IRObjectDestroy(pDXIL);

    // Return metallib and reflection.
    return { pOutMetalLib, pOutReflection };
}
