#include "rt_common.hlsl"

// Constants read by the sphere intersection shader (SphereIntersection).
// Bound at register b0 in register space 16; presumably supplied through a
// local root signature given the name — TODO confirm against pipeline setup.
// The field layout presumably mirrors a CPU-side struct; keep them in sync.
struct LocalData
{
    // Scale applied to the unit radius (1.0) of the procedural sphere.
    float radiusScale;
};

ConstantBuffer<LocalData> localData : register(b0, space16);

// Hit attributes produced by SphereIntersection (passed to ReportHit) and
// consumed by SphereAnyHit / SphereClosestHit.
// NOTE(review): this struct is 16 bytes; it must fit within the pipeline's
// configured maximum attribute size (D3D12 caps attributes at 32 bytes) —
// confirm the shader config on the CPU side matches.
struct SphereAttributes
{
    // Parametric distance along the ray at which the hit was reported.
    float t;
    // Normalized vector from the sphere center to the hit point; computed
    // from world-space ray data in SphereIntersection.
    float3 normal;
};

// Analytic sphere description used as scratch data by SphereIntersection.
struct Sphere
{
    // Center of the sphere; SphereIntersection fills it with the instance's
    // world-space origin (translation of ObjectToWorld3x4()).
    float3 center;
    // Radius, in the same space as `center`.
    float radius;
};

// Returns true if the parametric distance t is a candidate hit for the
// current ray: strictly in front of the origin (t > 0, as before) and inside
// the ray's currently valid interval [RayTMin(), RayTCurrent()].
bool IsSphereHitCandidate(float t)
{
    return t > 0.0 && t >= RayTMin() && t <= RayTCurrent();
}

// Builds SphereAttributes for a hit at parametric distance t and reports it.
// Returns ReportHit's result: true if the hit was accepted, false if it was
// rejected (e.g. the any-hit shader called IgnoreHit()).
bool ReportSphereHit(float t, float3 rayOrigin, float3 rayDirection, float3 center)
{
    float3 p = rayOrigin + t * rayDirection;

    SphereAttributes attr;
    attr.t = t;
    attr.normal = normalize(p - center);

    return ReportHit(t, 0, attr);
}

// Intersection shader for a procedural sphere centered at the instance's
// world-space origin with radius localData.radiusScale. Only the translation
// of the instance transform is applied; its scale/rotation are not.
//
// Both roots of the ray/sphere quadratic are considered. The nearer valid
// root is reported first. If it is not accepted (outside the ray interval, or
// rejected by the any-hit shader's IgnoreHit()), the farther root is reported
// instead, so surfaces behind cut-outs remain visible.
[shader("intersection")]
void SphereIntersection()
{
    float3 rayOrigin = WorldRayOrigin();
    float3 rayDirection = WorldRayDirection();

    Sphere s;
    s.center = mul(ObjectToWorld3x4(), float4(0, 0, 0, 1)).xyz;
    s.radius = 1 * localData.radiusScale;

    // Analytical ray/sphere intersection using the "half b" form of the
    // quadratic: a*t^2 + 2*halfB*t + c = 0.
    float3 oc = rayOrigin - s.center;
    float a = dot(rayDirection, rayDirection);
    float halfB = dot(rayDirection, oc);
    float c = dot(oc, oc) - (s.radius * s.radius);

    float discriminant = halfB * halfB - a * c;
    if (discriminant < 0)
    {
        // The ray's line misses the sphere entirely.
        return;
    }

    // a > 0 for a non-degenerate direction, so tNear <= tFar.
    float sqrtDisc = sqrt(discriminant);
    float tNear = (-halfB - sqrtDisc) / a;
    float tFar = (-halfB + sqrtDisc) / a;

    // Nested ifs, not &&: older HLSL versions do not short-circuit, and
    // ReportHit must not be called for an out-of-range root.
    if (IsSphereHitCandidate(tNear))
    {
        if (ReportSphereHit(tNear, rayOrigin, rayDirection, s.center))
        {
            // Accepted; RayTCurrent() now equals tNear, so tFar can't win.
            return;
        }
    }

    if (IsSphereHitCandidate(tFar))
    {
        ReportSphereHit(tFar, rayOrigin, rayDirection, s.center);
    }
}

// Any-hit shader for the procedural sphere. Carves bands of constant
// normal.z out of the surface by rejecting (IgnoreHit) every hit whose
// quantized normal.z lands in the first half of an 18-step period.
[shader("anyhit")]
void SphereAnyHit(inout RayPayload payload, SphereAttributes attr)
{
    // Remap normal.z from [-1, 1] to [0, 1].
    const float zRemapped = attr.normal.z * 0.5 + 0.5;

    // Quantize into 0..360, then split into periods of 18 steps.
    // The first 9 steps of each period are cut away.
    const uint period = 36 / 2;
    const uint cutWidth = period / 2;
    const uint bucket = uint(zRemapped * 360);

    const bool insideCut = (bucket % period) < cutWidth;
    if (insideCut)
    {
        IgnoreHit();
    }
}

// Closest-hit shader for the procedural sphere: visualizes the surface normal
// by writing the absolute value of each component as RGB, with w = 1.0.
[shader("closesthit")]
void SphereClosestHit(inout RayPayload payload, SphereAttributes attr)
{
    const float3 normalMagnitude = abs(attr.normal);
    const float alpha = 1.0;
    payload.color = float4(normalMagnitude, alpha);
}
