/*
 * BSD 3-Clause License
 *
 * Copyright (c) 2016, Nicolas Weber, Sandra C. Amend / GCC / TU-Darmstadt
 * All rights reserved.
 * 
 * Redistribution and use in source and binary forms, with or without
 * modification, are permitted provided that the following conditions are met:
    * Redistributions of source code must retain the above copyright
      notice, this list of conditions and the following disclaimer.
    * Redistributions in binary form must reproduce the above copyright
      notice, this list of conditions and the following disclaimer in the
      documentation and/or other materials provided with the distribution.
    * Neither the name of the <organization> nor the
      names of its contributors may be used to endorse or promote products
      derived from this software without specific prior written permission.
 * 
 * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND
 * ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
 * WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
 * DISCLAIMED. IN NO EVENT SHALL <COPYRIGHT HOLDER> BE LIABLE FOR ANY
 * DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
 * (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
 * LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND
 * ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
 * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
 * SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
 *
 * HLSL/Unity version Copyright (c) 2024, Chelsea "Feilen" Jaggi and VRChat Inc
 * The following is also licensed under the BSD 3-Clause License.
 *
 * This is a HLSL/Unity Compute shader implementation of the original cuda code 
 * available at https://github.com/mergian/dpid, adjusted to be used as an 
 * AssetPostprocessor (in-SDK) and a fully GPU-side texture updater (in-game) 
 * for generating mipmaps in Unity. The algorithm is intended to intentionally 
 * over-emphasize 'perceptually relevant' details, avoiding the need for 
 * post-sharpening typical in traditional downscaling algorithms.
 */
// TODO: 'quality' mode which does a gaussian filter of the original image, rather than box.
// Original paper says this makes little difference

// Two kernels share the Downsampling() body below. The keyword after each kernel name
// (GUIDANCE_SAMPLING / DPID_SAMPLING) is what the #if blocks in this file test for.
#pragma kernel KernelGuidance GUIDANCE_SAMPLING
#pragma kernel KernelDownsampling DPID_SAMPLING

// Constants
// Threads per group (numthreads x). Downsampling() derives one output pixel per group from it.
#define THREADS 64
// 1: the texels of a patch are strided across the group's threads and summed with WarpReduce().
// 0: every thread loops over the whole patch on its own.
#define USE_WARP_REDUCTION 1
// Accumulated weights below this value are treated as "no contribution" by normalize().
#define FLT_EPS 1.192092896e-07
// Debug switches:
//   DEBUG_GUIDANCE - the DPID kernel writes the guidance texel instead of its own result.
//   DEBUG_INVERT   - inverts the RGB of the result before it is written.
//   DEBUG_BOXBLUR  - not referenced by the code visible in this file.
//#define DEBUG_GUIDANCE 1
//#define DEBUG_INVERT 1
//#define DEBUG_BOXBLUR 1

// Buffers
// Source image that is being downscaled.
Texture2D<unorm float4> _Input;

// OpenGL (ES) and Vulkan don't like RWTexture2D<> to be read/write, they bind them as write-only
// (presumably why the guidance image is a separate read-only binding for the DPID kernel).
#if defined(DPID_SAMPLING)
Texture2D<unorm float4> _Guidance;
#endif
// Destination texture; Downsampling() writes texel (PX, PY) from thread 0 of each group.
RWTexture2D<unorm float4> _Output;
// For summation: one slot per thread of the group, combined by WarpReduce()
groupshared float4 colorResultBuffer[THREADS]; // Per-thread partial sums of weighted colour
groupshared float factorResultBuffer[THREADS]; // Per-thread partial sums of the colour weights
groupshared float factorAlphaResultBuffer[THREADS]; // Per-thread partial sums of the alpha weights

// Output size in pixels; used for bounds checks on output and guidance coordinates.
uint oWidth;
uint oHeight;
// Input size in pixels; the patch of each output pixel is clamped to it.
uint iWidth;
uint iHeight;
// Patch size: how many input pixels one output pixel covers per axis (output pixel PX
// starts at PX * pWidth in input space).
float pWidth;
float pHeight;
// Exponent applied to the colour distance in applyLambda(). 0 makes every distance weight 1,
// 1 uses the distance unchanged.
float lambda;
// Desired behavior for alpha handling is that, for example, a pure-white-on-black linear RGB star will end up the same as a pure alpha-on-transparent star
bool premultiplyAlpha;
// When set, guidance texels are converted with SRGBToLinear() on read and the result is
// converted with LinearToSRGB() before it is written.
bool sRGB;
// When set, the blue and alpha channels of every input texel are forced to 1 before weighting.
bool normalMap;

// Running weighted sum used by add() / normalize().
// Field order matters: instances are created with positional initialisers
// ({ color, factor, factorAlpha }).
struct ColorFactor
{
    // Sum of the weighted colours added so far (the averaged colour after normalize()).
    float4 color;
    // Sum of the colour weights; divides color (rgb only when premultiplyAlpha is set).
    float factor;
    // Sum of the alpha weights; divides color.a when premultiplyAlpha is set.
    float factorAlpha;
};

// Helper functions
// Turns the weighted sum in `var` into a weighted average, in place.
//   var           - accumulator filled by add(); on return both weight sums are reset to 1.
//   fallbackColor - used for any channel group whose weight sum is below FLT_EPS, i.e. when
//                   nothing meaningful was accumulated and a division would be unstable.
inline void normalize(inout ColorFactor var, float4 fallbackColor)
{
    const bool colorIsEmpty = var.factor < FLT_EPS;

    if (premultiplyAlpha)
    {
        // add() weighted RGB by (factorColor * alpha) and alpha by factorAlpha, so the two
        // groups have their own weight sums and are normalised independently.
        if (colorIsEmpty)
        {
            var.color.rgb = fallbackColor.rgb;
        }
        else
        {
            var.color.rgb /= var.factor;
        }

        const bool alphaIsEmpty = var.factorAlpha < FLT_EPS;
        if (alphaIsEmpty)
        {
            var.color.a = fallbackColor.a;
        }
        else
        {
            var.color.a /= var.factorAlpha;
        }
    }
    else
    {
        // All four channels share the colour weight sum.
        if (colorIsEmpty)
        {
            var.color = fallbackColor;
        }
        else
        {
            var.color /= var.factor;
        }
    }

    // The accumulator now holds an average, so its weights become the neutral 1.
    var.factor = 1.0f;
    var.factorAlpha = 1.0f;
}

// Encodes linear RGB with the sRGB transfer curve, per component.
// Values at or below 0.0031308 use the linear segment (x * 12.92), everything else the
// power segment. The selection is done with a 0/1 mask through lerp() instead of a branch.
inline float3 LinearToSRGB(float3 color)
{
    // 1 where color <= 0.0031308, otherwise 0.
    float3 linearSegmentMask = step(color, 0.0031308);

    // abs() keeps pow() away from negative bases.
    float3 powerSegment = 1.055 * pow(abs(color), 1.0 / 2.4) - 0.055;
    float3 linearSegment = color * 12.92;

    return lerp(powerSegment, linearSegment, linearSegmentMask);
}

// Decodes sRGB-encoded RGB to linear, per component.
// Values at or below 0.04045 use the linear segment (x / 12.92), everything else the
// power segment. The selection is done with a 0/1 mask through lerp() instead of a branch.
inline float3 SRGBToLinear(float3 color)
{
    // 1 where color <= 0.04045, otherwise 0.
    float3 linearSegmentMask = step(color, 0.04045);

    float3 powerSegment = pow((color + 0.055) / 1.055, 2.4);
    float3 linearSegment = color / 12.92;

    return lerp(powerSegment, linearSegment, linearSegmentMask);
}

// Sum up everything within our one group.
//
// Adds up `value`, `factor` and `factorAlpha` over all THREADS threads of the thread group
// with a tree reduction in group-shared memory; on return every thread holds the group total.
//   indexInGroup - this thread's index inside the group, 0 .. THREADS - 1.
// Requirements:
//   * every thread of the group must call this (it contains group barriers);
//   * THREADS must be a power of two, because the stride is halved each round.
void WarpReduce(inout float4 value, inout float factor, inout float factorAlpha, uint indexInGroup)
{
    // Publish this thread's partial sums in group-shared memory
    colorResultBuffer[indexInGroup] = value;
    factorResultBuffer[indexInGroup] = factor;
    factorAlphaResultBuffer[indexInGroup] = factorAlpha;

    // Synchronize to ensure all values are written
    GroupMemoryBarrierWithGroupSync();

    // Each round, the lower `stride` threads fold in the slot `stride` above their own
    for (uint stride = THREADS / 2; stride > 0; stride /= 2)
    {
        // Only these threads have a partner slot inside the buffers. Reading
        // [indexInGroup + stride] from the other threads would index past THREADS - 1.
        const bool foldsThisRound = indexInGroup < stride;

        // A coyote informed me that lanes can get out-of-order reads on some hardware,
        // so the partner values are fetched into registers before anything is written.
        float4 c = float4(0, 0, 0, 0);
        float f = 0.0f;
        float fa = 0.0f;
        if (foldsThisRound)
        {
            c = colorResultBuffer[indexInGroup + stride];
            f = factorResultBuffer[indexInGroup + stride];
            fa = factorAlphaResultBuffer[indexInGroup + stride];
        }

        // Barriers stay outside the conditionals so that every thread reaches them
        GroupMemoryBarrierWithGroupSync();

        if (foldsThisRound)
        {
            colorResultBuffer[indexInGroup] += c;
            factorResultBuffer[indexInGroup] += f;
            factorAlphaResultBuffer[indexInGroup] += fa;
        }

        GroupMemoryBarrierWithGroupSync();
    }

    // Slot 0 now holds the sum over the whole group
    value = colorResultBuffer[0];
    factor = factorResultBuffer[0];
    factorAlpha = factorAlphaResultBuffer[0];
}

// Accumulates one weighted sample into `var`.
//   color       - the sample.
//   factorColor - weight of the sample's colour.
//   factorAlpha - weight of the sample's alpha; always added to var.factorAlpha.
// With premultiplyAlpha the RGB weight is additionally scaled by the sample's alpha, and the
// alpha channel is accumulated with its own weight. Otherwise all four channels share
// factorColor.
inline void add(inout ColorFactor var, float4 color, float factorColor, float factorAlpha)
{
    var.factorAlpha += factorAlpha;

    if (premultiplyAlpha)
    {
        float alphaScaledWeight = factorColor * color.a;
        var.color.rgb += color.rgb * alphaScaledWeight;
        var.factor += alphaScaledWeight;
        var.color.a += factorAlpha * color.a;
    }
    else
    {
        var.color += color * factorColor;
        var.factor += factorColor;
    }
}

// Raises a distance to the power lambdaValue.
// The exponents 0 and 1 are answered directly (1 and dist respectively) without calling pow().
inline float applyLambda(float dist, float lambdaValue)
{
    float weight = 1.0f; // result for lambdaValue == 0

    if (lambdaValue == 1.0f)
    {
        weight = dist;
    }
    else if (lambdaValue != 0.0f)
    {
        weight = pow(dist, lambdaValue);
    }

    return weight;
}

// Coverage weight of input texel (x, y), which spans [x, x + 1) x [y, y + 1).
//   max_f - lower corner tested against the texel's lower edges.
//   min_f - upper corner tested against the texel's upper edges.
// Starts at 1 and is scaled down, per edge, by the part of the texel that sticks out past
// the corresponding corner.
inline float contribution(float2 max_f, float2 min_f, uint x, uint y)
{
    float texelLeft = (float)x;
    float texelTop = (float)y;
    float texelRight = texelLeft + 1.0f;
    float texelBottom = texelTop + 1.0f;

    float coverage = 1.0f;

    if (texelLeft < max_f.x)
    {
        coverage *= 1.0f - (max_f.x - texelLeft);
    }
    if (texelRight > min_f.x)
    {
        coverage *= 1.0f - (texelRight - min_f.x);
    }
    if (texelTop < max_f.y)
    {
        coverage *= 1.0f - (max_f.y - texelTop);
    }
    if (texelBottom > min_f.y)
    {
        coverage *= 1.0f - (texelBottom - min_f.y);
    }

    return coverage;
}

// Euclidean distance between two RGBA values, divided by 2.
// 2 = sqrt(4) is the largest distance two 4-component values can have when every component
// lies in [0, 1], so such inputs give a result in [0, 1].
inline float distance(float4 avg, float4 color)
{
    return length(avg - color) / 2.0f;
}

// Euclidean distance between two RGB values, divided by sqrt(3).
// sqrt(3) is the largest distance two 3-component values can have when every component
// lies in [0, 1], so such inputs give a result in [0, 1].
inline float distance(float3 avg, float3 color)
{
    return length(avg - color) / sqrt(3.0f);
}

// Distance between two scalar values (used for the alpha channel): the absolute difference.
inline float distance(float avg, float color)
{
    float delta = avg - color;
    return abs(delta);
}

// Same as add(), but first decodes the sample's RGB from sRGB to linear when `sRGB` is set.
// The same weight is used for both colour and alpha.
inline void addBoxBlur(inout ColorFactor var, float4 color, float factor)
{
    float4 sample = color;
    if (sRGB)
    {
        sample.rgb = SRGBToLinear(sample.rgb);
    }

    add(var, sample, factor, factor);
}

#if defined(DPID_SAMPLING)
// Weighted 3x3 average of the guidance image around texel (PX, PY).
// Weights are
//     1 2 1
//     2 4 2
//     1 2 1
// Neighbours to the left of column 0, above row 0, or at/after oWidth / oHeight are skipped;
// normalize() then divides by the weights that were actually used. The centre texel and the
// texels directly above/below are sampled without a horizontal bounds check.
float4 BoxBlurSample(uint PX, uint PY)
{
    const float cornerWeight = 1.0;
    const float edgeWeight = 2.0;
    const float centerWeight = 4.0;

    // Which neighbouring columns / rows exist
    const bool hasLeft = PX > 0;
    const bool hasRight = (PX + 1) < oWidth;
    const bool hasTop = PY > 0;
    const bool hasBottom = (PY + 1) < oHeight;

    ColorFactor blurred = { float4(0, 0, 0, 0), 0.0f, 0.0f };

    // Row above
    if (hasTop)
    {
        if (hasLeft)
        {
            addBoxBlur(blurred, _Guidance[int2(PX - 1, PY - 1)], cornerWeight);
        }

        addBoxBlur(blurred, _Guidance[int2(PX, PY - 1)], edgeWeight);

        if (hasRight)
        {
            addBoxBlur(blurred, _Guidance[int2(PX + 1, PY - 1)], cornerWeight);
        }
    }

    // Own row
    if (hasLeft)
    {
        addBoxBlur(blurred, _Guidance[int2(PX - 1, PY)], edgeWeight);
    }

    addBoxBlur(blurred, _Guidance[int2(PX, PY)], centerWeight);

    if (hasRight)
    {
        addBoxBlur(blurred, _Guidance[int2(PX + 1, PY)], edgeWeight);
    }

    // Row below
    if (hasBottom)
    {
        if (hasLeft)
        {
            addBoxBlur(blurred, _Guidance[int2(PX - 1, PY + 1)], cornerWeight);
        }

        addBoxBlur(blurred, _Guidance[int2(PX, PY + 1)], edgeWeight);

        if (hasRight)
        {
            addBoxBlur(blurred, _Guidance[int2(PX + 1, PY + 1)], cornerWeight);
        }
    }

    // If a weight sum is (near) zero the accumulated sum itself is kept as the result.
    normalize(blurred, blurred.color);
    return blurred.color;
}
#endif

// Dispatch is sent as Dispatch(Math.Max(param.oWidth / THREADS, 1), Math.Max(param.oHeight, 1), 1) (ignore numthreads)
// NOTE(review): the dispatch call is not part of this file - confirm the line above against the caller.
// Each thread group handles one output pixel (PX, PY) and the patch of input pixels that maps
// onto it: [PX * pWidth, (PX + 1) * pWidth) x [PY * pHeight, (PY + 1) * pHeight).
// Kernels
//
// Shared body of both kernels.
//   id   - dispatch thread id; id.x / THREADS and id.y select the output pixel.
//   gid  - group id (unused).
//   gtid - index of the thread inside its group.
// GUIDANCE_SAMPLING: every input texel is weighted by its coverage of the patch only.
// DPID_SAMPLING:     the coverage is additionally multiplied by the texel's distance to the
//                    blurred guidance colour raised to `lambda`, so texels that differ from
//                    their surroundings weigh more.
void Downsampling(uint3 id, uint3 gid, uint gtid)
{
    const uint lane = gtid;
    const uint outX = id.x / THREADS;
    const uint outY = id.y;

    // Threads whose pixel lies outside the output still run everything up to and including
    // WarpReduce, but accumulate with zero weight and leave before writing.
    const bool outsideOutput = (outX >= oWidth || outY >= oHeight);

    // Patch of this output pixel in input space, clamped to the input image
    const float2 patchLow = float2(max(outX * pWidth, 0.0f), max(outY * pHeight, 0.0f));
    const float2 patchHigh = float2(min((outX + 1) * pWidth, iWidth), min((outY + 1) * pHeight, iHeight));

    // Whole input texels touched by the patch
    const uint2 firstTexel = uint2(floor(patchLow.x), floor(patchLow.y));
    const uint2 endTexel = uint2(ceil(patchHigh.x), ceil(patchHigh.y));
    const uint2 texelSpan = endTexel - firstTexel;
    const uint texelCount = texelSpan.x * texelSpan.y;

#if defined(DPID_SAMPLING)
    float4 guide = BoxBlurSample(outX, outY);
#else
    // Magenta marker; only becomes visible if normalize() has to fall back to it
    float4 guide = float4(1, 0, 1, 1);
#endif

    ColorFactor sum = { float4(0, 0, 0, 0), 0.0f, 0.0f };
#if USE_WARP_REDUCTION
    // The group's threads take every THREADS-th texel of the patch
    for (uint n = lane; n < texelCount; n += THREADS)
#else
    for (uint n = 0; n < texelCount; n++)
#endif
    {
        // Row-major walk over the patch
        const uint x = firstTexel.x + (n % texelSpan.x);
        const uint y = firstTexel.y + (n / texelSpan.x);

        float4 texel = _Input[int2(x, y)];
        if (normalMap)
        {
            texel.ba = 1.0f;
        }

        const float coverage = contribution(patchLow, patchHigh, x, y);

#if defined(DPID_SAMPLING)
        float distColor = 0.0f;
        float distAlpha = 0.0f;
        if (premultiplyAlpha)
        {
            // Colour and alpha are weighted independently
            distColor = distance(guide.rgb, texel.rgb);
            distAlpha = distance(guide.a, texel.a);
        }
        else
        {
            distColor = distAlpha = distance(guide, texel);
        }

        float f_c = applyLambda(distColor, lambda);
        float f_a = applyLambda(distAlpha, lambda);
        f_c *= coverage;
        f_a *= coverage;
#elif defined(GUIDANCE_SAMPLING)
        float f_c = coverage, f_a = coverage;
#endif

        add(sum, texel, outsideOutput ? 0.0f : f_c, outsideOutput ? 0.0f : f_a);
    }

#if USE_WARP_REDUCTION
    // Combine the per-thread partial sums; afterwards every thread holds the group total
    WarpReduce(sum.color, sum.factor, sum.factorAlpha, lane);
#endif

    // Only thread 0 of a group whose pixel is inside the output writes the result
    if (outsideOutput || lane != 0)
    {
        return;
    }

    normalize(sum, guide);
#ifdef DEBUG_INVERT
    sum.color.xyz = 1 - sum.color.xyz;
#endif
    if (sRGB)
    {
        sum.color.xyz = LinearToSRGB(sum.color.xyz);
    }

#if defined(DEBUG_GUIDANCE) && defined(DPID_SAMPLING)
    _Output[int2(outX, outY)] = _Guidance[int2(outX, outY)];
#else
    _Output[int2(outX, outY)] = sum.color;
#endif
}

// Kernel entry points.
// Both kernels run the same Downsampling() body; which code paths it takes is decided by the
// keyword (GUIDANCE_SAMPLING / DPID_SAMPLING) the kernel is compiled with. The kernel that
// does not belong to the current keyword is kept as an empty stub so both names always
// exist. Both stubs carry the same [numthreads] attribute as the real entry points.
#if defined(GUIDANCE_SAMPLING)
// Guidance pass: coverage-weighted average of each patch.
[numthreads(THREADS, 1, 1)]
void KernelGuidance(uint3 id : SV_DispatchThreadID, uint3 gid: SV_GroupID, uint gtid: SV_GroupIndex)
{
    Downsampling(id, gid, gtid);
}
// Stub: not the active kernel for this keyword.
[numthreads(THREADS, 1, 1)]
void KernelDownsampling(uint3 id : SV_DispatchThreadID, uint3 gid: SV_GroupID, uint gtid: SV_GroupIndex) { }
#else
// DPID pass: patch average weighted by the distance to the guidance image.
[numthreads(THREADS, 1, 1)]
void KernelDownsampling(uint3 id : SV_DispatchThreadID, uint3 gid: SV_GroupID, uint gtid: SV_GroupIndex)
{
    Downsampling(id, gid, gtid);
}
// Stub: not the active kernel for this keyword.
[numthreads(THREADS, 1, 1)]
void KernelGuidance(uint3 id : SV_DispatchThreadID, uint3 gid: SV_GroupID, uint gtid: SV_GroupIndex) { }
#endif

