//-------------------------------------------------------------------------------------------------------------------------------------------------------------
//
// Copyright 2023 Apple Inc.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//
//-------------------------------------------------------------------------------------------------------------------------------------------------------------


#include "ShaderPipelineBuilder.hpp"

#include "ShaderCompiler.hpp"
#include <simd/simd.h>

// Builds the render pipeline used to draw the cube.
//
// Root signature layout (root parameter index : binding):
//   0 : CBV b0 (space 0), vertex stage
//   1 : SRV t0 (space 0), vertex stage
//   2 : SRV t1 (space 0), vertex stage
//   3 : descriptor table holding one SRV at t2 (space 0), pixel stage
// plus one static sampler (linear filtering, wrap addressing) at s0 (space 1), pixel stage.
//
// Shaders are loaded from "<shaderSearchPath>/v.dxil" (MainVS) and
// "<shaderSearchPath>/f.dxil" (MainFS). The returned pipeline state is not
// released here, so the caller takes ownership of it. Traps if the pipeline
// state cannot be created.
MTL::RenderPipelineState* shader_pipeline::newCubePipeline(const std::string& shaderSearchPath, MTL::Device* pDevice)
{
    // A single SRV range (t2, space 0), referenced by root parameter 3.
    IRDescriptorRange1 textureRanges[] = {
        {
            .RangeType = IRDescriptorRangeTypeSRV,
            .NumDescriptors = 1,
            .BaseShaderRegister = 2,
            .RegisterSpace = 0,
            .Flags = IRDescriptorRangeFlagDataVolatile,
            .OffsetInDescriptorsFromTableStart = 0
        }
    };

    IRRootParameter1 params[] = {
        {
            .ParameterType = IRRootParameterTypeCBV,
            .Descriptor = {
                .ShaderRegister = 0,
                .RegisterSpace = 0,
                .Flags = IRRootDescriptorFlagNone,
            },
            .ShaderVisibility = IRShaderVisibilityVertex
        },
        {
            .ParameterType = IRRootParameterTypeSRV,
            .Descriptor = {
                .ShaderRegister = 0,
                .RegisterSpace = 0,
                .Flags = IRRootDescriptorFlagNone,
            },
            .ShaderVisibility = IRShaderVisibilityVertex
        },
        {
            .ParameterType = IRRootParameterTypeSRV,
            .Descriptor = {
                .ShaderRegister = 1,
                .RegisterSpace = 0,
                .Flags = IRRootDescriptorFlagNone,
            },
            .ShaderVisibility = IRShaderVisibilityVertex
        },
        {
            .ParameterType = IRRootParameterTypeDescriptorTable,
            .DescriptorTable = {
                .pDescriptorRanges = textureRanges,
                .NumDescriptorRanges = sizeof( textureRanges ) / sizeof( IRDescriptorRange1 )
            },
            .ShaderVisibility = IRShaderVisibilityPixel
        },
    };

    IRStaticSamplerDescriptor samps[] = {
        {
            .Filter           = IRFilterMaximumMinMagMipLinear,
            .AddressU         = IRTextureAddressModeWrap,
            .AddressV         = IRTextureAddressModeWrap,
            .AddressW         = IRTextureAddressModeWrap,
            .MipLODBias       = 0,
            .MaxLOD           = std::numeric_limits<float>::max(),
            .ShaderRegister   = 0,
            .RegisterSpace    = 1,
            .ShaderVisibility = IRShaderVisibilityPixel
        }
    };

    IRVersionedRootSignatureDescriptor rootSigDesc;
    memset(&rootSigDesc, 0x0, sizeof(IRVersionedRootSignatureDescriptor));
    rootSigDesc.version                    = IRRootSignatureVersion_1_1;
    rootSigDesc.desc_1_1.NumParameters     = sizeof(params) / sizeof(IRRootParameter1);
    rootSigDesc.desc_1_1.pParameters       = params;
    rootSigDesc.desc_1_1.pStaticSamplers   = samps;
    rootSigDesc.desc_1_1.NumStaticSamplers = sizeof(samps) / sizeof(IRStaticSamplerDescriptor);

    IRError* pRootSigError    = nullptr;
    IRRootSignature* pRootSig = IRRootSignatureCreateFromDescriptor(&rootSigDesc, &pRootSigError);
    if (!pRootSig && pRootSigError)
    {
        // Surface the converter's diagnostic and free the error object instead of leaking it.
        printf("%s\n", static_cast<const char*>(IRErrorGetPayload(pRootSigError)));
        IRErrorDestroy(pRootSigError);
    }
    assert(pRootSig);

    std::string vertexLibraryPath = shaderSearchPath + "/v.dxil";
    MTL::Library* pVertexLib      = newLibraryFromDXIL(vertexLibraryPath, IRShaderStageVertex, "MainVS", pRootSig, pDevice);
    assert(pVertexLib);

    std::string fragmentLibraryPath = shaderSearchPath + "/f.dxil";
    MTL::Library* pFragmentLib      = newLibraryFromDXIL(fragmentLibraryPath, IRShaderStageFragment, "MainFS", pRootSig, pDevice);
    assert(pFragmentLib);

    MTL::Function* pVertexFn = pVertexLib->newFunction(MTLSTR("MainVS"));
    MTL::Function* pFragFn   = pFragmentLib->newFunction(MTLSTR("MainFS"));
    assert(pVertexFn);
    assert(pFragFn);

    MTL::RenderPipelineDescriptor* pDesc = MTL::RenderPipelineDescriptor::alloc()->init();
    pDesc->setVertexFunction(pVertexFn);
    pDesc->setFragmentFunction(pFragFn);
    pDesc->colorAttachments()->object(0)->setPixelFormat(MTL::PixelFormat::PixelFormatBGRA8Unorm_sRGB);
    pDesc->setDepthAttachmentPixelFormat(MTL::PixelFormat::PixelFormatDepth16Unorm);

    NS::Error* pError              = nullptr;
    MTL::RenderPipelineState* pPSO = pDevice->newRenderPipelineState(pDesc, &pError);
    if (!pPSO)
    {
        printf("%s\n", pError->localizedDescription()->utf8String());
        __builtin_trap();
    }

    // Release intermediate objects; only the pipeline state is handed to the caller.
    pFragFn->release();
    pVertexFn->release();
    pDesc->release();

    pFragmentLib->release();
    pVertexLib->release();

    IRRootSignatureDestroy(pRootSig);

    return pPSO;
}

// Builds the simple grass render pipeline from "<shaderSearchPath>/grass.v.dxil"
// (MainVS) and "<shaderSearchPath>/grass.f.dxil" (MainFS).
//
// Root signature: a single CBV at b0 (space 0) visible to the vertex stage.
// Vertex input: one float4 attribute at kIRStageInAttributeStartIndex, read
// per vertex from buffer 0 with a stride of sizeof(simd::float4).
//
// The returned pipeline state is not released here (the caller owns it). On
// failure the error is printed and an assert fires; with asserts disabled the
// function returns nullptr.
MTL::RenderPipelineState* shader_pipeline::newSimpleGrassPipeline(const std::string& shaderSearchPath, MTL::Device* pDevice)
{
    // Build the root signature:

    IRRootParameter1 params[] = {
        { .ParameterType = IRRootParameterTypeCBV,
            .Descriptor  = {
                 .ShaderRegister = 0,
                 .RegisterSpace  = 0,
                 .Flags          = IRRootDescriptorFlagNone },
            .ShaderVisibility = IRShaderVisibilityVertex }
    };

    IRVersionedRootSignatureDescriptor rootSigDesc;
    // Zero the whole descriptor first. Without this, every field not assigned below
    // (e.g. pStaticSamplers) would hold indeterminate stack data that the converter reads.
    memset(&rootSigDesc, 0x0, sizeof(IRVersionedRootSignatureDescriptor));
    rootSigDesc.version                    = IRRootSignatureVersion_1_1;
    rootSigDesc.desc_1_1.NumParameters     = sizeof(params) / sizeof(params[0]);
    rootSigDesc.desc_1_1.pParameters       = params;
    rootSigDesc.desc_1_1.pStaticSamplers   = nullptr;
    rootSigDesc.desc_1_1.NumStaticSamplers = 0;

    IRError* pRootSigError          = nullptr;
    IRRootSignature* pRootSignature = IRRootSignatureCreateFromDescriptor(&rootSigDesc, &pRootSigError);
    if (!pRootSignature)
    {
        printf("%s\n", (const char*)IRErrorGetPayload(pRootSigError));
        IRErrorDestroy(pRootSigError);
        assert(pRootSignature);
    }

    // Compile the shader from DXIL to Metal IR.

    std::string vertexLibraryPath = shaderSearchPath + "/grass.v.dxil";
    MTL::Library* pVertexLibrary  = newLibraryFromDXIL(vertexLibraryPath, IRShaderStageVertex, "MainVS", pRootSignature, pDevice);
    assert(pVertexLibrary);

    std::string fragmentLibraryPath = shaderSearchPath + "/grass.f.dxil";
    MTL::Library* pFragmentLibrary  = newLibraryFromDXIL(fragmentLibraryPath, IRShaderStageFragment, "MainFS", pRootSignature, pDevice);
    assert(pFragmentLibrary);

    // Build the pipeline:

    MTL::VertexDescriptor* pVertexDesc = MTL::VertexDescriptor::alloc()->init();
    auto attrib                        = pVertexDesc->attributes()->object(kIRStageInAttributeStartIndex);
    attrib->setFormat(MTL::VertexFormatFloat4);
    attrib->setOffset(0);
    attrib->setBufferIndex(0);

    auto layout = pVertexDesc->layouts()->object(0);
    layout->setStride(sizeof(simd::float4));
    layout->setStepRate(1);
    layout->setStepFunction(MTL::VertexStepFunctionPerVertex);

    //#-code-listing(buildPipelineStates)
    MTL::Function* pVertexFn   = pVertexLibrary->newFunction(MTLSTR("MainVS"));
    MTL::Function* pFragmentFn = pFragmentLibrary->newFunction(MTLSTR("MainFS"));
    assert(pVertexFn);
    assert(pFragmentFn);

    MTL::RenderPipelineDescriptor* pPipelineDesc = MTL::RenderPipelineDescriptor::alloc()->init();
    pPipelineDesc->setVertexDescriptor(pVertexDesc);
    pPipelineDesc->setVertexFunction(pVertexFn);
    pPipelineDesc->setFragmentFunction(pFragmentFn);
    pPipelineDesc->colorAttachments()->object(0)->setPixelFormat(MTL::PixelFormatBGRA8Unorm_sRGB);
    pPipelineDesc->setDepthAttachmentPixelFormat(MTL::PixelFormat::PixelFormatDepth16Unorm);

    NS::Error* pMtlError                     = nullptr;
    MTL::RenderPipelineState* pPipelineState = pDevice->newRenderPipelineState(pPipelineDesc, &pMtlError);
    //#-end-code-listing

    if (!pPipelineState)
    {
        printf("%s\n", pMtlError->localizedDescription()->utf8String());
        assert(pPipelineState);
    }

    // Release resources:

    pPipelineDesc->release();
    pVertexFn->release();
    pFragmentFn->release();
    pVertexDesc->release();
    pFragmentLibrary->release();
    pVertexLibrary->release();

    IRRootSignatureDestroy(pRootSignature);

    return pPipelineState;
}

// Describes the stage-in vertex layout used by the geometry and tessellation
// pipelines in this file: a single per-vertex "POSITION" element (semantic
// index 0) of four 32-bit floats at byte offset 0 of input slot 0.
static IRVersionedInputLayoutDescriptor createVertexInputLayoutDescriptor()
{
    IRVersionedInputLayoutDescriptor inputDesc;
    // Zero the whole descriptor so the unused semantic-name and element slots are
    // null/zero rather than indeterminate stack contents in the returned copy.
    memset(&inputDesc, 0x0, sizeof(inputDesc));
    inputDesc.version                   = IRInputLayoutDescriptorVersion_1;
    inputDesc.desc_1_0.numElements      = 1;
    inputDesc.desc_1_0.semanticNames[0] = "POSITION";

    // Fill element 0 in place. This avoids a C99 compound literal, which is only a
    // compiler extension in C++.
    IRInputElementDescriptor1& position = inputDesc.desc_1_0.inputElementDescs[0];
    position.semanticIndex              = 0;
    position.format                     = IRFormatR32G32B32A32Float;
    position.inputSlot                  = 0;
    position.alignedByteOffset          = 0;
    position.inputSlotClass             = IRInputClassificationPerVertexData;
    position.instanceDataStepRate       = 0; /* must be 0 for per-vertex data */

    return inputDesc;
}

// Builds the geometry-shader emulation pipeline for the floor from
// "<shaderSearchPath>/floor.v.dxil" (MainVS), "floor.g.dxil" (MainGS) and
// "floor.f.dxil" (MainFS), plus a synthesized stage-in function.
//
// Root signature: a CBV at b0 (space 0) visible to the vertex stage, and a CBV
// at b1 (space 0) with default visibility.
//
// Returns the pipeline state together with the IRRuntimeGeometryPipelineConfig
// read from shader reflection; that config is needed again at draw time. Traps
// if the pipeline cannot be created.
RenderPipelineStateWithConfig<IRRuntimeGeometryPipelineConfig> shader_pipeline::newGeometryPipeline(const std::string& shaderSearchPath, MTL::Device* pDevice)
{
    // Create the root signature descriptor, matching shader resource interfaces.

    IRVersionedRootSignatureDescriptor rootSigDesc;
    memset(&rootSigDesc, 0x0, sizeof(IRVersionedRootSignatureDescriptor));
    rootSigDesc.version = IRRootSignatureVersion_1_1;

    IRRootParameter1 params[] = {
        { .ParameterType = IRRootParameterTypeCBV,
            .Descriptor  = {
                 .ShaderRegister = 0,
                 .RegisterSpace  = 0,
                 .Flags          = IRRootDescriptorFlagNone,
            },
            .ShaderVisibility = IRShaderVisibilityVertex },
        { .ParameterType = IRRootParameterTypeCBV, .Descriptor = { .ShaderRegister = 1, .RegisterSpace = 0, .Flags = IRRootDescriptorFlagNone } }
    };
    rootSigDesc.desc_1_1.NumParameters = sizeof(params) / sizeof(IRRootParameter1);
    rootSigDesc.desc_1_1.pParameters   = params;

    IRError* pRootSigError    = nullptr;
    IRRootSignature* pRootSig = IRRootSignatureCreateFromDescriptor(&rootSigDesc, &pRootSigError);
    if (!pRootSig && pRootSigError)
    {
        // Report the converter's diagnostic and free the error object instead of leaking it.
        printf("%s\n", static_cast<const char*>(IRErrorGetPayload(pRootSigError)));
        IRErrorDestroy(pRootSigError);
    }
    assert(pRootSig);

    // Compile the vertex shader.

    std::string vertexLibraryPath        = shaderSearchPath + "/floor.v.dxil";
    auto [pVertexLib, pVertexReflection] = newLibraryWithReflectionFromDXIL(vertexLibraryPath, IRShaderStageVertex, "MainVS", pRootSig, pDevice, true);
    assert(pVertexLib);
    assert(pVertexReflection);

    // Generate the stage-in function.

    IRVersionedInputLayoutDescriptor inputDesc = createVertexInputLayoutDescriptor();
    auto pStageInLib                           = newStageInLibrary(inputDesc, pVertexReflection, pDevice);
    // Check for null before dereferencing it in the function-count check.
    assert(pStageInLib);
    assert(pStageInLib->functionNames()->count() > 0);

    // Compile the geometry shader.

    std::string geometryLibraryPath          = shaderSearchPath + "/floor.g.dxil";
    auto [pGeometryLib, pGeometryReflection] = newLibraryWithReflectionFromDXIL(geometryLibraryPath, IRShaderStageGeometry, "MainGS", pRootSig, pDevice, true);
    assert(pGeometryLib);
    assert(pGeometryReflection);

    // Collect the reflection data needed by the runtime at draw time.

    IRVersionedVSInfo vsInfo;
    IRVersionedGSInfo gsInfo;

    IRShaderReflectionCopyVertexInfo(pVertexReflection, IRReflectionVersion_1_0, &vsInfo);
    IRShaderReflectionCopyGeometryInfo(pGeometryReflection, IRReflectionVersion_1_0, &gsInfo);

    // Value-initialize so any config field not set below is zero, not indeterminate.
    RenderPipelineStateWithConfig<IRRuntimeGeometryPipelineConfig> geometryPipeline{};

    geometryPipeline.pipelineConfig.gsVertexSizeInBytes                    = vsInfo.info_1_0.vertex_output_size_in_bytes;
    geometryPipeline.pipelineConfig.gsMaxInputPrimitivesPerMeshThreadgroup = gsInfo.info_1_0.max_input_primitives_per_mesh_threadgroup;

    IRShaderReflectionReleaseVertexInfo(&vsInfo);
    IRShaderReflectionReleaseGeometryInfo(&gsInfo);

    // Compile the fragment shader.

    std::string fragmentLibraryPath = shaderSearchPath + "/floor.f.dxil";
    MTL::Library* pFragmentLib      = newLibraryFromDXIL(fragmentLibraryPath, IRShaderStageFragment, "MainFS", pRootSig, pDevice);
    assert(pFragmentLib);

    // Build the base pipeline descriptor, defining attachments.

    MTL::MeshRenderPipelineDescriptor* pDesc = MTL::MeshRenderPipelineDescriptor::alloc()->init();
    pDesc->colorAttachments()->object(0)->setPixelFormat(MTL::PixelFormatBGRA8Unorm_sRGB);
    pDesc->setDepthAttachmentPixelFormat(MTL::PixelFormatDepth16Unorm);

    // Prepare the geometry pipeline descriptor. It is value-initialized so that any
    // field not assigned here is zero rather than stack garbage.

    IRGeometryEmulationPipelineDescriptor geometryPipelineDesc{};
    geometryPipelineDesc.stageInLibrary         = pStageInLib;
    geometryPipelineDesc.vertexLibrary          = pVertexLib;
    geometryPipelineDesc.vertexFunctionName     = "MainVS";
    geometryPipelineDesc.geometryLibrary        = pGeometryLib;
    geometryPipelineDesc.geometryFunctionName   = "MainGS";
    geometryPipelineDesc.fragmentLibrary        = pFragmentLib;
    geometryPipelineDesc.fragmentFunctionName   = "MainFS";
    geometryPipelineDesc.basePipelineDescriptor = pDesc;
    geometryPipelineDesc.pipelineConfig         = geometryPipeline.pipelineConfig;

    NS::Error* pError                     = nullptr;
    geometryPipeline.pRenderPipelineState = IRRuntimeNewGeometryEmulationPipeline(pDevice, &geometryPipelineDesc, &pError);

    if (!geometryPipeline.pRenderPipelineState)
    {
        printf("%s\n", pError->localizedDescription()->utf8String());
        __builtin_trap();
    }

    // Free resources.

    IRRootSignatureDestroy(pRootSig);
    IRShaderReflectionDestroy(pVertexReflection);
    IRShaderReflectionDestroy(pGeometryReflection);

    pDesc->release();
    pFragmentLib->release();
    pGeometryLib->release();
    pVertexLib->release();
    pStageInLib->release();

    return geometryPipeline;
}

// Builds a debug render pipeline for simple meshes such as the floor, compiled
// at runtime from inline Metal Shading Language source.
//
// The vertex shader reads float4 positions from buffer 0 and transforms them by
// CameraData::perspectiveTransform, bound at buffer 16. The fragment shader
// outputs solid magenta (1, 0, 1, 1).
//
// Traps if the library fails to compile. If pipeline-state creation fails, the
// Metal error is printed and an assert fires; with asserts disabled the function
// returns nullptr.
MTL::RenderPipelineState* newDebugPipeline(MTL::Device* pDevice)
{
    // The shader is kept in a raw string literal. The previous macro-stringizing
    // trick relied on stray "\n" tokens and was easily mangled by code formatters.
    // Lines inside the raw string (including "#include") are literal text and are
    // not seen by the C++ preprocessor.
    const char* shaderSrc = R"MSL(#include <metal_stdlib>
using namespace metal;

struct CameraData {
    float4x4 perspectiveTransform;
    float4x4 worldTransform;
    float4x3 worldNormalTransform;
};

struct v2f {
    float4 pos [[position]];
};

vertex v2f MainVS(constant float4 * vertices [[buffer(0)]], constant CameraData & cameraData [[buffer(16)]], uint vid [[vertex_id]])
{
    return (v2f) { .pos = cameraData.perspectiveTransform * vertices[vid] };
}

fragment float4 MainFS(v2f vin [[stage_in]])
{
    return float4(1.0, 0.0, 1.0, 1.0);
}
)MSL";

    NS::Error* pError  = nullptr;
    MTL::Library* pLib = pDevice->newLibrary(NS::String::string(shaderSrc, NS::UTF8StringEncoding), nullptr, &pError);
    if (!pLib)
    {
        printf("Error: %s\n", pError->localizedDescription()->utf8String());
        __builtin_trap();
    }

    MTL::Function* pVfn = pLib->newFunction(MTLSTR("MainVS"));
    MTL::Function* pFfn = pLib->newFunction(MTLSTR("MainFS"));
    assert(pVfn);
    assert(pFfn);

    MTL::RenderPipelineDescriptor* pDesc = MTL::RenderPipelineDescriptor::alloc()->init();
    pDesc->setVertexFunction(pVfn);
    pDesc->setFragmentFunction(pFfn);
    pDesc->colorAttachments()->object(0)->setPixelFormat(MTL::PixelFormatBGRA8Unorm_sRGB);
    pDesc->setDepthAttachmentPixelFormat(MTL::PixelFormatDepth16Unorm);

    MTL::RenderPipelineState* pDebugPSO = pDevice->newRenderPipelineState(pDesc, &pError);
    if (!pDebugPSO)
    {
        // Report why creation failed instead of asserting silently.
        printf("Error: %s\n", pError->localizedDescription()->utf8String());
    }
    assert(pDebugPSO);

    pDesc->release();
    pVfn->release();
    pFfn->release();
    pLib->release();

    return pDebugPSO;
}

// Builds the tessellation + geometry emulation pipeline from five DXIL stages
// under shaderSearchPath: tsgs.v.dxil (MainVS), tsgs.h.dxil (MainHS),
// tsgs.d.dxil (MainDS), tsgs.g.dxil (MainGS) and tsgs.f.dxil (MainFS), plus a
// synthesized stage-in function.
//
// Root signature: three CBVs visible to all stages, at b0 and b1 in space 0 and
// at b0 in space 2.
//
// Returns the pipeline state together with the
// IRRuntimeTessellationPipelineConfig derived from shader reflection; that
// config is needed again at draw time. Traps if the pipeline cannot be created.
RenderPipelineStateWithConfig<IRRuntimeTessellationPipelineConfig> shader_pipeline::newTessellationPipeline(const std::string& shaderSearchPath, MTL::Device* pDevice)
{
    std::string vertexShaderPath = shaderSearchPath + "/tsgs.v.dxil";
    std::string hullShaderPath   = shaderSearchPath + "/tsgs.h.dxil";
    std::string domainShaderPath = shaderSearchPath + "/tsgs.d.dxil";
    std::string geomShaderPath   = shaderSearchPath + "/tsgs.g.dxil";
    std::string fragShaderPath   = shaderSearchPath + "/tsgs.f.dxil";

    // Build the root signature, describing how to bind global resources to the pipeline.

    IRRootSignature* pRootSignature = nullptr;
    {
        IRRootParameter1 params[] = {
            {
                .ParameterType = IRRootParameterTypeCBV,
                .ShaderVisibility = IRShaderVisibilityAll,
                .Descriptor = {
                    .ShaderRegister = 0,
                    .RegisterSpace = 0,
                    .Flags = IRRootDescriptorFlagNone
                }
            },
            {
                .ParameterType = IRRootParameterTypeCBV,
                .ShaderVisibility = IRShaderVisibilityAll,
                .Descriptor = {
                    .ShaderRegister = 1,
                    .RegisterSpace = 0,
                    .Flags = IRRootDescriptorFlagNone
                }
            },
            {
                .ParameterType = IRRootParameterTypeCBV,
                .ShaderVisibility = IRShaderVisibilityAll,
                .Descriptor = {
                    .ShaderRegister = 0,
                    .RegisterSpace = 2,
                    .Flags = IRRootDescriptorFlagNone
                }
            }
        };

        IRError* pError = nullptr;
        IRVersionedRootSignatureDescriptor rootSigDesc;
        // Zero the descriptor so every field not assigned explicitly below is zero
        // rather than indeterminate stack data handed to the converter.
        memset(&rootSigDesc, 0x0, sizeof(IRVersionedRootSignatureDescriptor));
        rootSigDesc.version                    = IRRootSignatureVersion_1_1;
        rootSigDesc.desc_1_1.pParameters       = params;
        rootSigDesc.desc_1_1.NumParameters     = sizeof(params) / sizeof(IRRootParameter1);
        rootSigDesc.desc_1_1.pStaticSamplers   = nullptr;
        rootSigDesc.desc_1_1.NumStaticSamplers = 0;
        pRootSignature                         = IRRootSignatureCreateFromDescriptor(&rootSigDesc, &pError);

        if (!pRootSignature && pError)
        {
            printf("%s\n", (const char*)IRErrorGetPayload(pError));
            IRErrorDestroy(pError);
        }
        assert(pRootSignature);
    }

    // Compile all shaders from DXIL to metallibs.

    auto [pVertexLib, pVertexReflection] = newLibraryWithReflectionFromDXIL(vertexShaderPath,
        IRShaderStageVertex,
        "MainVS",
        pRootSignature,
        pDevice,
        true);
    assert(pVertexLib);
    assert(pVertexReflection);

    auto [pHullLib, pHullReflection] = newLibraryWithReflectionFromDXIL(hullShaderPath,
        IRShaderStageHull,
        "MainHS",
        pRootSignature,
        pDevice,
        false);

    assert(pHullLib);
    assert(pHullReflection);

    auto [pDomainLib, pDomainReflection] = newLibraryWithReflectionFromDXIL(domainShaderPath,
        IRShaderStageDomain,
        "MainDS",
        pRootSignature,
        pDevice,
        false);
    assert(pDomainLib);
    assert(pDomainReflection);

    auto [pGeometryLib, pGeometryReflection] = newLibraryWithReflectionFromDXIL(geomShaderPath,
        IRShaderStageGeometry,
        "MainGS",
        pRootSignature,
        pDevice,
        false);

    assert(pGeometryLib);
    assert(pGeometryReflection);

    auto [pFragLib, pFragReflection] = newLibraryWithReflectionFromDXIL(fragShaderPath,
        IRShaderStageFragment,
        "MainFS",
        pRootSignature,
        pDevice,
        false);
    assert(pFragLib);
    assert(pFragReflection);

    // Synthesize a stage-in function for the object shader.

    IRVersionedInputLayoutDescriptor inputDesc = createVertexInputLayoutDescriptor();
    auto pStageInLib                           = newStageInLibrary(inputDesc, pVertexReflection, pDevice);
    assert(pStageInLib);

    // Validate that the pipeline stage inputs and outputs are compatible.
    // The validation call sits inside assert, so it is skipped when asserts are disabled.
    {

        IRVersionedGSInfo gsInfo;
        IRVersionedHSInfo hsInfo;
        IRVersionedDSInfo dsInfo;

        IRShaderReflectionCopyGeometryInfo(pGeometryReflection, IRReflectionVersion_1_0, &gsInfo);
        IRShaderReflectionCopyHullInfo(pHullReflection, IRReflectionVersion_1_0, &hsInfo);
        IRShaderReflectionCopyDomainInfo(pDomainReflection, IRReflectionVersion_1_0, &dsInfo);

        assert(IRRuntimeValidateTessellationPipeline(
            (IRRuntimeTessellatorOutputPrimitive)hsInfo.info_1_0.tessellator_output_primitive,
            (IRRuntimePrimitiveType)gsInfo.info_1_0.input_primitive,
            hsInfo.info_1_0.output_control_point_size,
            dsInfo.info_1_0.input_control_point_size,
            hsInfo.info_1_0.patch_constants_size,
            dsInfo.info_1_0.patch_constants_size,
            hsInfo.info_1_0.input_control_point_count,
            dsInfo.info_1_0.input_control_point_count));

        IRShaderReflectionReleaseDomainInfo(&dsInfo);
        IRShaderReflectionReleaseHullInfo(&hsInfo);
        IRShaderReflectionReleaseGeometryInfo(&gsInfo);
    }

    // Start building the render pipeline with its config. Value-initialize it so any
    // config field not assigned below is zero rather than indeterminate.
    RenderPipelineStateWithConfig<IRRuntimeTessellationPipelineConfig> tessellationPipeline{};

    // Store the tessellation configuration for draw-time.
    {
        IRVersionedVSInfo vsInfo;
        IRVersionedGSInfo gsInfo;
        IRVersionedHSInfo hsInfo;

        IRShaderReflectionCopyVertexInfo(pVertexReflection, IRReflectionVersion_1_0, &vsInfo);
        IRShaderReflectionCopyGeometryInfo(pGeometryReflection, IRReflectionVersion_1_0, &gsInfo);
        IRShaderReflectionCopyHullInfo(pHullReflection, IRReflectionVersion_1_0, &hsInfo);

        tessellationPipeline.pipelineConfig.outputPrimitiveType                    = (IRRuntimeTessellatorOutputPrimitive)hsInfo.info_1_0.tessellator_output_primitive;
        tessellationPipeline.pipelineConfig.vsOutputSizeInBytes                    = vsInfo.info_1_0.vertex_output_size_in_bytes;
        tessellationPipeline.pipelineConfig.gsMaxInputPrimitivesPerMeshThreadgroup = gsInfo.info_1_0.max_input_primitives_per_mesh_threadgroup;
        tessellationPipeline.pipelineConfig.hsMaxPatchesPerObjectThreadgroup       = hsInfo.info_1_0.max_patches_per_object_threadgroup;
        tessellationPipeline.pipelineConfig.hsInputControlPointCount               = hsInfo.info_1_0.input_control_point_count;
        tessellationPipeline.pipelineConfig.hsMaxObjectThreadsPerThreadgroup       = hsInfo.info_1_0.max_object_threads_per_patch;
        tessellationPipeline.pipelineConfig.hsMaxTessellationFactor                = hsInfo.info_1_0.max_tessellation_factor;
        tessellationPipeline.pipelineConfig.gsInstanceCount                        = gsInfo.info_1_0.instance_count;

        IRShaderReflectionReleaseHullInfo(&hsInfo);
        IRShaderReflectionReleaseGeometryInfo(&gsInfo);
        IRShaderReflectionReleaseVertexInfo(&vsInfo);
    }

    // Build the mesh pipeline.
    MTL::MeshRenderPipelineDescriptor* pPipelineDesc = MTL::MeshRenderPipelineDescriptor::alloc()->init();

    // Set the color and depth attachment formats:
    pPipelineDesc->colorAttachments()->object(0)->setPixelFormat(MTL::PixelFormatBGRA8Unorm_sRGB);
    pPipelineDesc->setDepthAttachmentPixelFormat(MTL::PixelFormatDepth16Unorm);

    // Value-initialized so any descriptor field not assigned here is zero.
    IRGeometryTessellationEmulationPipelineDescriptor tessellationDesc{};
    tessellationDesc.basePipelineDescriptor = pPipelineDesc;
    tessellationDesc.pipelineConfig         = tessellationPipeline.pipelineConfig;
    tessellationDesc.stageInLibrary         = pStageInLib;
    tessellationDesc.vertexLibrary          = pVertexLib;
    tessellationDesc.vertexFunctionName     = "MainVS";
    tessellationDesc.hullLibrary            = pHullLib;
    tessellationDesc.hullFunctionName       = "MainHS";
    tessellationDesc.domainLibrary          = pDomainLib;
    tessellationDesc.domainFunctionName     = "MainDS";
    tessellationDesc.geometryLibrary        = pGeometryLib;
    tessellationDesc.geometryFunctionName   = "MainGS";
    tessellationDesc.fragmentLibrary        = pFragLib;
    tessellationDesc.fragmentFunctionName   = "MainFS";

    NS::Error* pError                         = nullptr;
    tessellationPipeline.pRenderPipelineState = IRRuntimeNewGeometryTessellationEmulationPipeline(pDevice, &tessellationDesc, &pError);
    if (!tessellationPipeline.pRenderPipelineState)
    {
        printf("%s\n", pError->localizedDescription()->utf8String());
        __builtin_trap();
    }

    // Release resources.

    pPipelineDesc->release();
    pStageInLib->release();
    pFragLib->release();
    pGeometryLib->release();
    pDomainLib->release();
    pHullLib->release();
    pVertexLib->release();

    IRShaderReflectionDestroy(pFragReflection);
    IRShaderReflectionDestroy(pGeometryReflection);
    IRShaderReflectionDestroy(pDomainReflection);
    IRShaderReflectionDestroy(pHullReflection);
    IRShaderReflectionDestroy(pVertexReflection);

    IRRootSignatureDestroy(pRootSignature);

    return tessellationPipeline;
}

// Builds the compute pipeline for "MainCS" in "<shaderSearchPath>/k.dxil".
//
// Root signature: a static-data CBV at b0 (space 0) and a descriptor table with
// one UAV at u0 (space 0). Both are visible to all stages.
//
// The returned pipeline state is not released here (the caller owns it). On
// failure the Metal error is printed and an assert fires; with asserts disabled
// the function returns nullptr.
MTL::ComputePipelineState* shader_pipeline::newComputePipeline(const std::string& shaderSearchPath, MTL::Device* pDevice)
{
    // Resource layout:

    IRDescriptorRange1 uavRange = {
        .RangeType                         = IRDescriptorRangeTypeUAV,
        .NumDescriptors                    = 1,
        .BaseShaderRegister                = 0,
        .RegisterSpace                     = 0,
        .OffsetInDescriptorsFromTableStart = 0
    };

    IRRootParameter1 params[] = {
        {
            .ParameterType    = IRRootParameterTypeCBV,
            .ShaderVisibility = IRShaderVisibilityAll,
            .Descriptor       = {
                .ShaderRegister = 0,
                .RegisterSpace  = 0,
                .Flags          = IRRootDescriptorFlagDataStatic }
        },
        {
            .ParameterType = IRRootParameterTypeDescriptorTable,
            .ShaderVisibility = IRShaderVisibilityAll,
            .DescriptorTable = {
                .NumDescriptorRanges = 1,
                .pDescriptorRanges = &uavRange
            }
        }
    };

    IRVersionedRootSignatureDescriptor rootSigDesc;
    memset(&rootSigDesc, 0x0, sizeof(IRVersionedRootSignatureDescriptor));
    rootSigDesc.version                = IRRootSignatureVersion_1_1;
    rootSigDesc.desc_1_1.NumParameters = sizeof(params) / sizeof(IRRootParameter1);
    rootSigDesc.desc_1_1.pParameters   = params;

    IRError* pRootSigError    = nullptr;
    IRRootSignature* pRootSig = IRRootSignatureCreateFromDescriptor(&rootSigDesc, &pRootSigError);
    if (!pRootSig && pRootSigError)
    {
        // Report the converter's diagnostic and free the error object instead of leaking it.
        printf("%s\n", static_cast<const char*>(IRErrorGetPayload(pRootSigError)));
        IRErrorDestroy(pRootSigError);
    }
    assert(pRootSig);

    // Make MetalLib:

    std::string computeLibraryPath = shaderSearchPath + "/k.dxil";
    MTL::Library* pLib             = newLibraryFromDXIL(computeLibraryPath, IRShaderStageCompute, "MainCS", pRootSig, pDevice);
    assert(pLib);

    // Initialize to null. Otherwise a failure that did not set the error would leave
    // an indeterminate pointer to be dereferenced below.
    NS::Error* pError                      = nullptr;
    MTL::Function* pFn                     = pLib->newFunction(MTLSTR("MainCS"));
    assert(pFn);
    MTL::ComputePipelineState* pComputePSO = pDevice->newComputePipelineState(pFn, &pError);
    if (!pComputePSO)
    {
        printf("%s\n", pError->localizedDescription()->utf8String());
        assert(false);
    }

    pFn->release();
    pLib->release();

    IRRootSignatureDestroy(pRootSig);

    return pComputePSO;
}
