//
//  WAPaintCanvas.metal
//  WhatsApp
//
//  Created by Kuan Yong on 7/1/16.
//  Copyright © 2016 WhatsApp. All rights reserved.
//

#include <metal_stdlib>
#include <metal_texture>

using namespace metal;

// Vertex-to-fragment payload for the layer quad pipeline
// (layerQuadVertex -> layerQuadFragment).
struct VertexInOut {
    float4 position [[ position ]];              // clip-space position emitted by the vertex stage
    float2 texCoord [[ user(texturecoord) ]];    // texture coordinate, interpolated for the fragment stage
};

// Per-vertex record for colored, textured quads.
// Written to buffer(4) by convertPointToVertexPositionSizeColor and
// convertPointToVertexPositionSizeColorAngle (6 records per input point), and
// read back from buffer(4) by brushPickingVertex, basicBrushVertex and stampVertex.
// Layout is 24 bytes with no internal padding.
// NOTE(review): if the host also allocates or inspects this buffer, its struct
// must match this layout exactly — confirm on the CPU side.
struct VertexInPositionTextureColor {
    float2 position; // alignment: 8
    float2 texCoord; // alignment: 8
    half4 color; // alignment: 8
};

// Vertex-stage output and fragment-stage [[ stage_in ]] for the picking,
// basic brush and stamp pipelines. In this file it is never stored in a device
// buffer, so the alignment notes below are informational only.
struct VertexOutPositionTextureColor {
    float4 position [[ position ]]; // alignment: 16
    float2 texCoord; // alignment: 8
    half4 color; // alignment: 8
};

// Per-vertex record for untinted textured quads (pixelate brush).
// Written to buffer(4) by convertPointToVertexPositionSize (6 records per input
// point) and read back from buffer(4) by pixelateBrushVertex. Layout is 16 bytes.
// NOTE(review): any host-side mirror of this struct must keep the same layout.
struct VertexInPositionTexture {
    float2 position; // alignment: 8
    float2 texCoord; // alignment: 8
};

// Vertex-stage output and fragment-stage [[ stage_in ]] for the pixelate brush
// (pixelateBrushVertex -> pixelateBrushFragment). Not stored in a device buffer
// in this file.
struct VertexOutPositionTexture {
    float4 position [[ position ]]; // alignment: 16
    float2 texCoord; // alignment: 8
};

// Brush / stamp textures: bilinear filtering with linear blending between mip
// levels. Lookups outside [0, 1] return transparent black.
constexpr sampler texSampler(coord::normalized,
                             mip_filter::linear,
                             filter::linear,
                             address::clamp_to_zero);

// Layer quads: bilinear, no mipmapping, edge texels repeated outside [0, 1]
// (clamp_to_edge is the MSL default, written out here for clarity).
constexpr sampler layerSampler(coord::normalized,
                               mip_filter::none,
                               filter::linear,
                               address::clamp_to_edge);

// Clean plate for the pixelate brush: point sampling keeps each block a solid
// color; no mipmapping; clamp_to_edge (the MSL default).
constexpr sampler cleanPlateSampler(coord::normalized,
                                    mip_filter::none,
                                    filter::nearest,
                                    address::clamp_to_edge);

// Pass-through vertex stage for a layer quad: lifts the 2D position to a
// homogeneous clip-space position (z = 0, w = 1) and forwards the matching
// texture coordinate from the parallel texCoord array.
vertex VertexInOut layerQuadVertex(const device float2 *position [[ buffer(0) ]],
                                   const device float2 *texCoord [[ buffer(1) ]],
                                   uint v_id [[ vertex_id ]] ) {
    const float2 xy = position[v_id];
    const float2 uv = texCoord[v_id];

    VertexInOut result;
    result.texCoord = uv;
    result.position = float4(xy, 0.0f, 1.0f);
    return result;
}

// Returns the layer texture sampled (via layerSampler) at the interpolated
// texture coordinate, without any further modification.
fragment half4 layerQuadFragment(VertexInOut inFrag [[ stage_in ]],
                                 texture2d<half> tex2D [[ texture(0) ]]) {
    return tex2D.sample(layerSampler, inFrag.texCoord);
}

// Expands input point v_id (center pos_in, full size size_in, color color_in)
// into an axis-aligned quad and writes six records to v_out[6*v_id .. 6*v_id+5]
// in the order c00, c00, c01, c10, c11, c11. Here cXY means x = min/max (0/1)
// and y = min/max (0/1). The duplicated first and last records look intended
// for joining consecutive quads into one triangle strip via degenerate
// triangles — NOTE(review): confirm against the draw call.
// Texture coordinates span textureCoord0 (min corner) .. textureCoord1 (max corner).
vertex void
convertPointToVertexPositionSizeColor(const device float2 *pos_in [[ buffer(0) ]],
                                      const device float2 *size_in [[ buffer(1) ]],
                                      const device float4 *color_in [[ buffer(2) ]],
                                      constant float2 &textureCoord0 [[ buffer(7) ]],
                                      constant float2 &textureCoord1 [[ buffer(8) ]],
                                      device VertexInPositionTextureColor *v_out [[ buffer(4) ]],
                                      uint v_id [[ vertex_id ]] ) {
    const float2 center = pos_in[v_id];
    const float2 halfExtent = size_in[v_id] * 0.5f;
    const float2 lo = center - halfExtent;
    const float2 hi = center + halfExtent;
    const half4 tint = half4(color_in[v_id]);

    VertexInPositionTextureColor c00, c01, c10, c11;
    c00.position = lo;
    c00.texCoord = textureCoord0;
    c00.color    = tint;

    c01.position = float2(lo.x, hi.y);
    c01.texCoord = float2(textureCoord0.x, textureCoord1.y);
    c01.color    = tint;

    c10.position = float2(hi.x, lo.y);
    c10.texCoord = float2(textureCoord1.x, textureCoord0.y);
    c10.color    = tint;

    c11.position = hi;
    c11.texCoord = textureCoord1;
    c11.color    = tint;

    device VertexInPositionTextureColor *strip = v_out + v_id * 6;
    strip[0] = c00;
    strip[1] = c00;
    strip[2] = c01;
    strip[3] = c10;
    strip[4] = c11;
    strip[5] = c11;
}

// Untinted variant of convertPointToVertexPositionSizeColor. It expands input
// point v_id (center pos_in, full size size_in) into an axis-aligned quad and
// writes six records to v_out[6*v_id .. 6*v_id+5] in the order
// c00, c00, c01, c10, c11, c11, where cXY picks min/max (0/1) per axis.
// Texture coordinates span textureCoord0 (min corner) .. textureCoord1 (max corner).
vertex void
convertPointToVertexPositionSize(const device float2 *pos_in [[ buffer(0) ]],
                                 const device float2 *size_in [[ buffer(1) ]],
                                 device VertexInPositionTexture *v_out [[ buffer(4) ]],
                                 constant float2 &textureCoord0 [[ buffer(7) ]],
                                 constant float2 &textureCoord1 [[ buffer(8) ]],
                                 uint v_id [[ vertex_id ]] ) {
    const float2 center = pos_in[v_id];
    const float2 halfExtent = size_in[v_id] * 0.5f;
    const float2 lo = center - halfExtent;
    const float2 hi = center + halfExtent;

    VertexInPositionTexture c00, c01, c10, c11;
    c00.position = lo;
    c00.texCoord = textureCoord0;

    c01.position = float2(lo.x, hi.y);
    c01.texCoord = float2(textureCoord0.x, textureCoord1.y);

    c10.position = float2(hi.x, lo.y);
    c10.texCoord = float2(textureCoord1.x, textureCoord0.y);

    c11.position = hi;
    c11.texCoord = textureCoord1;

    device VertexInPositionTexture *strip = v_out + v_id * 6;
    strip[0] = c00;
    strip[1] = c00;
    strip[2] = c01;
    strip[3] = c10;
    strip[4] = c11;
    strip[5] = c11;
}

// Rotated variant of convertPointToVertexPositionSizeColor. Input point v_id is
// expanded into a quad rotated by angle_in[v_id] (radians, passed to sin/cos)
// around its center. Six records are written to v_out[6*v_id .. 6*v_id+5] in
// the order c00, c00, c01, c10, c11, c11.
// `ratio` scales the rotated axes anisotropically (y *= ratio for the width
// axis, x /= ratio for the height axis) — presumably the viewport aspect ratio,
// so the quad is not sheared in non-square normalized coordinates; verify with
// the caller.
vertex void
convertPointToVertexPositionSizeColorAngle(const device float2 *pos_in [[ buffer(0) ]],
                                           const device float2 *size_in [[ buffer(1) ]],
                                           const device float4 *color_in [[ buffer(2) ]],
                                           const device float *angle_in [[ buffer(3) ]],
                                           constant float &ratio [[ buffer(6) ]],
                                           constant float2 &textureCoord0 [[ buffer(7) ]],
                                           constant float2 &textureCoord1 [[ buffer(8) ]],
                                           device VertexInPositionTextureColor *v_out [[ buffer(4) ]],
                                           uint v_id [[ vertex_id ]] ) {
    const float2 center = pos_in[v_id];
    const float2 halfExtent = size_in[v_id] * 0.5f;
    const half4 tint = half4(color_in[v_id]);

    const float theta = angle_in[v_id];
    const float s = sin(theta);
    const float c = cos(theta);

    // Rotated half-axes: axisU along the quad's width, axisV along its height.
    const float2 axisU = float2(c * halfExtent.x, s * ratio * halfExtent.x);
    const float2 axisV = float2(-s / ratio * halfExtent.y, c * halfExtent.y);

    const float2 minusU = center - axisU;
    const float2 plusU  = center + axisU;

    VertexInPositionTextureColor c00, c01, c10, c11;
    c00.position = minusU - axisV;
    c00.texCoord = textureCoord0;
    c00.color    = tint;

    c01.position = minusU + axisV;
    c01.texCoord = float2(textureCoord0.x, textureCoord1.y);
    c01.color    = tint;

    c10.position = plusU - axisV;
    c10.texCoord = float2(textureCoord1.x, textureCoord0.y);
    c10.color    = tint;

    c11.position = plusU + axisV;
    c11.texCoord = textureCoord1;
    c11.color    = tint;

    device VertexInPositionTextureColor *strip = v_out + v_id * 6;
    strip[0] = c00;
    strip[1] = c00;
    strip[2] = c01;
    strip[3] = c10;
    strip[4] = c11;
    strip[5] = c11;
}

// Vertex stage for brush picking. It applies the 3x3 homogeneous 2D transform
// `mat` to the record's position and maps the result (x, y, w) to clip space
// with z = 0. Texture coordinate and color are forwarded unchanged.
vertex VertexOutPositionTextureColor
brushPickingVertex(const device VertexInPositionTextureColor *v_in [[ buffer(4) ]],
                   constant float3x3 &mat [[ buffer(5) ]],
                   uint v_id [[ vertex_id ]]) {
    const VertexInPositionTextureColor src = v_in[v_id];
    const float3 xyw = mat * float3(src.position, 1.0f);

    VertexOutPositionTextureColor result;
    result.position = float4(xyw.x, xyw.y, 0.0f, xyw.z);
    result.color    = src.color;
    result.texCoord = src.texCoord;
    return result;
}

// Fragment stage for brush picking. It outputs the interpolated vertex color,
// scaled by a texture-coverage factor and clamped to [0, 1].
fragment half4
brushPickingFragment(VertexOutPositionTextureColor inFrag [[ stage_in ]],
                     texture2d<float> tex2D [[ texture(0) ]]) {
    const float texAlpha = tex2D.sample(texSampler, inFrag.texCoord).a;
    // Amplify coverage so even barely-visible texels (alpha >= 0.001) count as
    // fully covered. This makes the faint, nearly transparent parts of a
    // brush pickable.
    const half coverage = half(min(1.0f, 1000.0f * texAlpha));
    return saturate(coverage * inFrag.color);
}

// Vertex stage for the basic brush. The record's 2D position goes through the
// homogeneous transform `mat`; the product's third component becomes clip-space
// w and z is fixed at 0. Texture coordinate and color pass through unchanged.
vertex VertexOutPositionTextureColor
basicBrushVertex(const device VertexInPositionTextureColor *v_in [[ buffer(4) ]],
                 constant float3x3 &mat [[ buffer(5) ]],
                 uint v_id [[ vertex_id ]]) {
    const VertexInPositionTextureColor record = v_in[v_id];
    const float3 transformed = mat * float3(record.position, 1.0f);

    VertexOutPositionTextureColor out;
    out.texCoord = record.texCoord;
    out.color    = record.color;
    out.position = float4(transformed.xy, 0.0f, transformed.z);
    return out;
}

// Fragment stage for the basic brush: the brush texture texel is modulated
// (component-wise) by the interpolated vertex color, then clamped to [0, 1].
fragment half4
basicBrushFragment(VertexOutPositionTextureColor inFrag [[ stage_in ]],
                   texture2d<half> tex2D [[ texture(0) ]]) {
    const half4 texel  = tex2D.sample(texSampler, inFrag.texCoord);
    const half4 tinted = texel * inFrag.color;
    return saturate(tinted);
}

// Vertex stage for the pixelate brush. The record's 2D position is transformed
// by `mat` as a homogeneous (x, y, 1) vector into clip-space (x, y, 0, w). The
// texture coordinate is forwarded unchanged.
vertex VertexOutPositionTexture
pixelateBrushVertex(const device VertexInPositionTexture *v_in [[ buffer(4) ]],
                    constant float3x3 &mat [[ buffer(5) ]],
                    uint v_id [[ vertex_id ]]) {
    const VertexInPositionTexture record = v_in[v_id];
    const float3 xyw = mat * float3(record.position, 1.0f);

    VertexOutPositionTexture out;
    out.texCoord = record.texCoord;
    out.position = float4(xyw.x, xyw.y, 0.0f, xyw.z);
    return out;
}

// Fragment stage for the pixelate (mosaic) brush.
// The fragment's window position is snapped to a grid of square blocks, and the
// clean-plate texture is sampled there with point filtering. The result is
// then modulated by the brush texture and clamped to [0, 1].
//   windowSize - divisor for [[ position ]]; presumably the render-target size
//                in pixels (confirm with the caller).
//   width      - block size as a fraction of windowSize.x. The vertical step is
//                width * (windowSize.x / windowSize.y), so blocks span the same
//                number of pixels in both directions. A non-positive width
//                disables quantization instead of dividing by zero.
fragment half4
pixelateBrushFragment(VertexOutPositionTexture inFrag [[ stage_in ]],
                      constant float2 &windowSize [[ buffer(0) ]],
                      constant float &width [[ buffer(1) ]],
                      texture2d<half> brushTex2D [[ texture(0) ]],
                      texture2d<half> cleanPlateTex2D [[ texture(1) ]]) {
    // Normalize the window-space position and flip Y (1 - y). Presumably this
    // matches the clean plate's orientation; verify against how it is uploaded.
    // Use a float literal: this expression is entirely float math.
    float2 position = float2(inFrag.position.x / windowSize.x,
                             1.0f - inFrag.position.y / windowSize.y);

    // Snap to the block grid. Guard width: dividing by a zero step would produce
    // inf/NaN sample coordinates and undefined clean-plate lookups.
    if (width > 0.0f) {
        const float aspectRatio = windowSize.x / windowSize.y;
        const float2 stepSize = float2(width, aspectRatio * width);
        position = round(position / stepSize) * stepSize;
    }

    // Quantized clean-plate color, modulated by the brush texture (brush shape mask).
    half4 src_color = cleanPlateTex2D.sample(cleanPlateSampler, position);
    src_color = src_color * brushTex2D.sample(texSampler, inFrag.texCoord);
    return saturate(src_color);
}

// Vertex stage for stamps. It maps the record's 2D position through the
// homogeneous transform `mat` into clip space (z = 0, w = third component of
// the product). Texture coordinate and color are forwarded unchanged.
vertex VertexOutPositionTextureColor
stampVertex(const device VertexInPositionTextureColor *v_in [[ buffer(4) ]],
            constant float3x3 &mat [[ buffer(5) ]],
            uint v_id [[ vertex_id ]]) {
    const VertexInPositionTextureColor record = v_in[v_id];
    const float3 xyw = mat * float3(record.position, 1.0f);

    VertexOutPositionTextureColor out;
    out.position = float4(xyw.x, xyw.y, 0.0f, xyw.z);
    out.texCoord = record.texCoord;
    out.color    = record.color;
    return out;
}

// Fragment stage for stamps. The stamp texture is scaled uniformly by the
// vertex alpha (vertex RGB is ignored) and clamped to [0, 1].
fragment half4
stampFragment(VertexOutPositionTextureColor inFrag [[ stage_in ]],
              texture2d<half> tex2D [[ texture(0) ]]) {
    // The vertex alpha drops below 1.0 only while a stamp is dimmed because it
    // is eligible to be deleted.
    const half fade = inFrag.color.a;
    const half4 texel = tex2D.sample(texSampler, inFrag.texCoord);
    return saturate(texel * fade);
}
