//-------------------------------------------------------------------------------------------------------------------------------------------------------------
//
// Copyright 2023 Apple Inc.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//
//-------------------------------------------------------------------------------------------------------------------------------------------------------------


//-------------------------------------------------------------------------------------------------------------------------------------------------------------
// This project includes a precompiled version of this file.
// To modify it, you need to build the DirectX Shader Compiler (DXC);
// see the README for instructions.
//-------------------------------------------------------------------------------------------------------------------------------------------------------------


// Per-dispatch parameters for MainCS, read through the constant buffer `frameData`.
// Field order/types must stay in sync with the host-side definition of this buffer.
struct FrameData
{
    float anim;       // Animation time/phase; MainCS passes it through cos() to drive the zoom.
    uint gridWidth;   // Output grid width; MainCS divides thread x-indices by it (presumably outTex's width — confirm on host).
    uint gridHeight;  // Output grid height; MainCS divides thread y-indices by it (presumably outTex's height — confirm on host).
};

// Resources used by the compute entry point MainCS.
// NOTE(review): cameraData (graphics resources further down) is also declared at
// register(b0, space0); this presumably works only because the compute and graphics
// entry points are compiled separately and each references just one of them — confirm.
ConstantBuffer<FrameData> frameData : register(b0, space0);
// Single-channel output image; MainCS writes one value in [0, 1] per texel.
RWTexture2D<float> outTex : register(u0, space0);

// Compute entry point: renders an animated, zooming view of the Mandelbrot set into outTex.
// Each thread handles the texel at `index` (its SV_DispatchThreadID) and writes a grayscale
// value in [0, 1] derived from the escape-iteration count.
[numthreads(32, 32, 1)]
void MainCS(uint2 index : SV_DispatchThreadID)
{
    uint2 gridSize = uint2(frameData.gridWidth, frameData.gridHeight);

    // Threads are launched in whole 32x32 groups. If the host rounds the dispatch up, the
    // threads past the grid edge would run the full iteration loop only to write outside
    // outTex, so skip them. This also avoids dividing by zero below for an empty grid.
    if (any(index >= gridSize))
    {
        return;
    }

    float anim = frameData.anim;

    const float kAnimationFrequency = 1;
    const float kAnimationSpeed = 4;
    // Midpoint and half-amplitude of the oscillating zoom base (the "Low" name notwithstanding,
    // kAnimationScaleLow is the centre of the range, not its lower bound).
    const float kAnimationScaleLow = 0.62;
    const float kAnimationScale = 0.38;

    const float2 kMandelbrotPixelOffset = {-0.2, -0.35};
    const float2 kMandelbrotOrigin = {-1.2, -0.32};
    const float2 kMandelbrotScale = {2.2, 2.0};

    // Map time to a zoom base in [kAnimationScaleLow - kAnimationScale,
    // kAnimationScaleLow + kAnimationScale] = [0.24, 1.0].
    float zoom = kAnimationScaleLow + kAnimationScale * cos(kAnimationFrequency * anim);
    // Raise to the 4th power so the zoom sweeps a much wider range (~[0.0033, 1]) and appears
    // to move faster. The base is always >= 0.24, so pow() never sees a negative operand.
    zoom = pow(zoom, kAnimationSpeed);

    // Map the texel to a point c = (x0, y0) in the complex plane.
    float x0 = zoom * kMandelbrotScale.x * ((float)index.x / gridSize.x + kMandelbrotPixelOffset.x) + kMandelbrotOrigin.x;
    float y0 = zoom * kMandelbrotScale.y * ((float)index.y / gridSize.y + kMandelbrotPixelOffset.y) + kMandelbrotOrigin.y;

    // Escape-time iteration z <- z^2 + c starting from z = 0, stopping once |z|^2 > 4
    // (i.e. |z| > 2) or the iteration budget is exhausted.
    const uint kMaxIteration = 1000;
    float x = 0.0;
    float y = 0.0;
    uint iteration = 0;
    while (x * x + y * y <= 4 && iteration < kMaxIteration)
    {
        float xtmp = x * x - y * y + x0;
        y = 2 * x * y + y0;
        x = xtmp;
        iteration += 1;
    }

    // Convert the iteration count to a smoothly cycling intensity in [0, 1].
    float color = 0.5 + 0.5 * cos(3.0 + iteration * 0.15);
    outTex[index] = color;
}

// Interpolants passed from MainVS to MainFS.
struct v2f
{
    float4 position : SV_Position;  // Vertex position after instance, camera-world and perspective transforms.
    float3 normal : NORMAL;         // Normal after the instance and camera normal transforms (renormalized in MainFS).
    half3 color : COLOR;            // Per-instance colour (rgb), in half precision.
    float2 texcoord : TEXCOORD;     // Texture coordinate used to sample `tex`.
};

// One element of the vertex buffer `vertexData`, fetched by SV_VertexID in MainVS.
// Every attribute is stored as float4, but MainVS only uses position fully, normal.xyz and
// texcoord.xy. NOTE(review): the float4 storage presumably mirrors the host-side vertex
// layout/padding — keep in sync with the CPU definition.
struct VertexData
{
    float4 position;
    float4 normal;
    float4 texcoord;
};

// One element of the per-instance buffer `instanceData`, fetched by SV_InstanceID in MainVS.
struct InstanceData
{
    float4x4 instanceTransform;        // Applied to the vertex position.
    float4x4 instanceNormalTransform;  // Only its upper-left 3x3 is applied to the vertex normal.
    float4 instanceColor;              // Only rgb is used (converted to half3 for the fragment stage).
};

// Camera parameters for MainVS, read through the constant buffer `cameraData`.
struct CameraData
{
    float4x4 perspectiveTransform;  // Applied last to positions (perspective * world * instance * pos).
    float4x4 worldTransform;        // Applied to positions after the instance transform.
    // MainVS multiplies this by a float3 normal and keeps .xyz of the float4 result.
    // NOTE(review): the 4th row is presumably padding so that a host-side 3x3 normal matrix
    // stored as three 16-byte columns lines up with this layout — confirm against the CPU side.
    float4x3 worldNormalTransform;
};

// Resources used by the graphics entry points MainVS/MainFS.
// NOTE(review): register(b0, space0) is also used by frameData (compute resources above);
// presumably fine only because the compute and graphics entry points are compiled separately.
ConstantBuffer<CameraData> cameraData : register(b0, space0);
StructuredBuffer<VertexData> vertexData : register(t0, space0);     // indexed by SV_VertexID
StructuredBuffer<InstanceData> instanceData : register(t1, space0); // indexed by SV_InstanceID
Texture2D<float4> tex : register(t2, space0);                       // sampled in MainFS (rgb only)
// The sampler lives in space1, unlike the other resources here.
SamplerState samp : register(s0, space1);

// Vertex entry point for instanced drawing.
// Reads vertex `vertexId` and instance `instanceId`, then fills the interpolants for MainFS:
//   position = (perspectiveTransform * worldTransform) * (instanceTransform * position)
//   normal   = (worldNormalTransform * (3x3 of instanceNormalTransform * normal.xyz)).xyz
//   texcoord = texcoord.xy
//   color    = instanceColor.rgb as half3
v2f MainVS( uint vertexId : SV_VertexID,
            uint instanceId : SV_InstanceID )
{
    VertexData vtx = vertexData[ vertexId ];
    InstanceData inst = instanceData[ instanceId ];

    // Place the vertex with its instance transform, then apply the combined camera matrix.
    // The camera matrix is formed first, matching the original evaluation order exactly.
    float4 placed = mul( inst.instanceTransform, vtx.position );
    float4x4 cameraMatrix = mul( cameraData.perspectiveTransform, cameraData.worldTransform );

    // Normals use only the rotational/scale 3x3 part of the instance normal matrix.
    float3x3 instanceNormalMatrix = (float3x3)(inst.instanceNormalTransform);
    float3 instanceNormal = mul( instanceNormalMatrix, vtx.normal.xyz );
    // worldNormalTransform is float4x3, so this product is a float4; only xyz is kept.
    float4 cameraNormal = mul( cameraData.worldNormalTransform, instanceNormal );

    v2f result;
    result.position = mul( cameraMatrix, placed );
    result.normal = cameraNormal.xyz;
    result.color = half3( inst.instanceColor.rgb );
    result.texcoord = vtx.texcoord.xy;
    return result;
}

// Fragment entry point: textured, per-instance-tinted surface with a constant 10% ambient
// term plus a Lambert diffuse term saturate(N . L) for one fixed light direction.
// Alpha is always 1.
half4 MainFS( v2f vin ) : SV_Target
{
    // Fixed light direction (1, 1, 0.8), normalized; the original note describes the light
    // as coming from the front-top-right.
    const float3 lightDir = normalize( float3( 1.0, 1.0, 0.8 ) );

    // Interpolated normals are not necessarily unit length, so renormalize per fragment.
    const float3 surfaceNormal = normalize( vin.normal );
    const half diffuse = half( saturate( dot( surfaceNormal, lightDir ) ) );

    // Narrow the sampled texel to half before tinting it with the instance colour.
    const half3 texel = tex.Sample( samp, vin.texcoord ).rgb;
    const half3 albedo = vin.color * texel;

    // Ambient (10%) plus diffuse contribution.
    return half4( albedo * 0.1 + albedo * diffuse, 1.0 );
}
