Resolve WES-90 "Integrate signpredictor in courses"

This commit is contained in:
Louis Adriaens
2023-03-18 19:53:17 +00:00
committed by Jerome Coudron
parent 1a75791d62
commit 746906294b
463 changed files with 99422 additions and 1187 deletions

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,9 @@
fileFormatVersion: 2
guid: f89931ec4ed9542308f3425d051750b9
ShaderImporter:
externalObjects: {}
defaultTextures: []
nonModifiableTextures: []
userData:
assetBundleName:
assetBundleVariant:

View File

@@ -0,0 +1,34 @@
// Most often used kernels
// Flat / FlatStrict / Loop are three dispatch strategies over the same
// activation ops; implementations live in Activation.cginc (included below).
#pragma kernel Relu_Flat
#pragma kernel Relu_FlatStrict
#pragma kernel Relu_Loop
#pragma kernel Relu6_Flat
#pragma kernel Relu6_FlatStrict
#pragma kernel Relu6_Loop
#pragma kernel Tanh_Flat
#pragma kernel Tanh_FlatStrict
#pragma kernel Tanh_Loop
#pragma kernel Swish_Flat
#pragma kernel Swish_FlatStrict
#pragma kernel Swish_Loop
#pragma kernel Sigmoid_Flat
#pragma kernel Sigmoid_FlatStrict
#pragma kernel Sigmoid_Loop
#pragma kernel LeakyRelu_Flat
#pragma kernel LeakyRelu_FlatStrict
#pragma kernel LeakyRelu_Loop
#pragma kernel Clip_Flat
#pragma kernel Clip_FlatStrict
#pragma kernel Clip_Loop
// NOTE(review): PRelu lists no _FlatStrict variant, unlike every other op
// above — presumably intentional (PRelu needs a second input); confirm.
#pragma kernel PRelu_Flat
#pragma kernel PRelu_Loop
#include "Activation.cginc"

View File

@@ -0,0 +1,8 @@
fileFormatVersion: 2
guid: 517d235ce3daa4bcd88fd5494d4b99ed
ComputeShaderImporter:
externalObjects: {}
currentAPIMask: 2164736
userData:
assetBundleName:
assetBundleVariant:

View File

@@ -0,0 +1,13 @@
// Less frequently used activation kernels, split out of the main compute
// file; implementations live in Activation.cginc (included below).
#pragma kernel Reciprocal_Flat
#pragma kernel Reciprocal_FlatStrict
#pragma kernel Reciprocal_Loop
#pragma kernel Sqrt_Flat
#pragma kernel Sqrt_FlatStrict
#pragma kernel Sqrt_Loop
#pragma kernel HardSigmoid_Flat
#pragma kernel HardSigmoid_FlatStrict
#pragma kernel HardSigmoid_Loop
#include "Activation.cginc"

View File

@@ -0,0 +1,8 @@
fileFormatVersion: 2
guid: 110f1fc1578364452982dd20f246f765
ComputeShaderImporter:
externalObjects: {}
currentAPIMask: 2164736
userData:
assetBundleName:
assetBundleVariant:

View File

@@ -0,0 +1,249 @@
// Flat / FlatStrict / Loop dispatch variants of element-wise ops
// (implementations in Activation.cginc, included at the end of this file).
#pragma kernel Abs_Flat
#pragma kernel Abs_FlatStrict
#pragma kernel Abs_Loop
#pragma kernel Neg_Flat
#pragma kernel Neg_FlatStrict
#pragma kernel Neg_Loop
#pragma kernel Ceil_Flat
#pragma kernel Ceil_FlatStrict
#pragma kernel Ceil_Loop
#pragma kernel Floor_Flat
#pragma kernel Floor_FlatStrict
#pragma kernel Floor_Loop
#pragma kernel Round_Flat
#pragma kernel Round_FlatStrict
#pragma kernel Round_Loop
#pragma kernel Selu_Flat
#pragma kernel Selu_FlatStrict
#pragma kernel Selu_Loop
#pragma kernel Softplus_Flat
#pragma kernel Softplus_FlatStrict
#pragma kernel Softplus_Loop
#pragma kernel Elu_Flat
#pragma kernel Elu_FlatStrict
#pragma kernel Elu_Loop
#pragma kernel Exp_Flat
#pragma kernel Exp_FlatStrict
#pragma kernel Exp_Loop
#pragma kernel Log_Flat
#pragma kernel Log_FlatStrict
#pragma kernel Log_Loop
#pragma kernel Pow_Flat
#pragma kernel Pow_FlatStrict
#pragma kernel Pow_Loop
// NOTE(review): LogicalNot declares no _Loop variant, unlike every other op
// in this list — verify whether LogicalNot_Loop exists in Activation.cginc
// before adding it here.
#pragma kernel LogicalNot_Flat
#pragma kernel LogicalNot_FlatStrict
// Sign kernels, in the conventional Flat / FlatStrict / Loop order.
// Fix: the original declared "#pragma kernel Sign_Loop" twice (once before
// Sign_Flat and again after Sign_FlatStrict), compiling two identical kernel
// entries into the shader; the duplicate is removed. Kernels are looked up
// by name, so dropping the duplicate is backward compatible.
#pragma kernel Sign_Flat
#pragma kernel Sign_FlatStrict
#pragma kernel Sign_Loop
// Trigonometric / hyperbolic / Erf element-wise ops, same three dispatch
// variants per op (implementations in Activation.cginc).
#pragma kernel Acos_Flat
#pragma kernel Acos_FlatStrict
#pragma kernel Acos_Loop
#pragma kernel Acosh_Flat
#pragma kernel Acosh_FlatStrict
#pragma kernel Acosh_Loop
#pragma kernel Asin_Flat
#pragma kernel Asin_FlatStrict
#pragma kernel Asin_Loop
#pragma kernel Asinh_Flat
#pragma kernel Asinh_FlatStrict
#pragma kernel Asinh_Loop
#pragma kernel Atan_Flat
#pragma kernel Atan_FlatStrict
#pragma kernel Atan_Loop
#pragma kernel Atanh_Flat
#pragma kernel Atanh_FlatStrict
#pragma kernel Atanh_Loop
#pragma kernel Cos_Flat
#pragma kernel Cos_FlatStrict
#pragma kernel Cos_Loop
#pragma kernel Cosh_Flat
#pragma kernel Cosh_FlatStrict
#pragma kernel Cosh_Loop
#pragma kernel Sin_Flat
#pragma kernel Sin_FlatStrict
#pragma kernel Sin_Loop
#pragma kernel Sinh_Flat
#pragma kernel Sinh_FlatStrict
#pragma kernel Sinh_Loop
#pragma kernel Tan_Flat
#pragma kernel Tan_FlatStrict
#pragma kernel Tan_Loop
#pragma kernel Erf_Flat
#pragma kernel Erf_FlatStrict
#pragma kernel Erf_Loop
// Layout-aware variants: each op is compiled once per memory layout via the
// CHANNELS_FIRST macro (0 = NHWC, 1 = NCHW). The CNyx / Nyxc variants are
// alternative data tilings; their NCHW builds are disabled because those
// tilings require NHWC by design (see inline comments).
#pragma kernel Relu_NHWC CHANNELS_FIRST=0
#pragma kernel Relu_NCHW CHANNELS_FIRST=1
#pragma kernel Relu_CNyx_NHWC CHANNELS_FIRST=0
//#pragma kernel Relu_CNyx_NCHW CHANNELS_FIRST=1 //This kernel require NHWC by design
#pragma kernel Relu_Nyxc_NHWC CHANNELS_FIRST=0
//#pragma kernel Relu_Nyxc_NCHW CHANNELS_FIRST=1
#pragma kernel Relu6_NHWC CHANNELS_FIRST=0
#pragma kernel Relu6_NCHW CHANNELS_FIRST=1
#pragma kernel Relu6_CNyx_NHWC CHANNELS_FIRST=0
//#pragma kernel Relu6_CNyx_NCHW CHANNELS_FIRST=1 //This kernel require NHWC by design
#pragma kernel Relu6_Nyxc_NHWC CHANNELS_FIRST=0
//#pragma kernel Relu6_Nyxc_NCHW CHANNELS_FIRST=1
#pragma kernel PRelu_NHWC CHANNELS_FIRST=0
#pragma kernel PRelu_NCHW CHANNELS_FIRST=1
#pragma kernel PRelu_CNyx2_NHWC CHANNELS_FIRST=0
//#pragma kernel PRelu_CNyx2_NCHW CHANNELS_FIRST=1 //This kernel require NHWC by design
#pragma kernel Selu_NHWC CHANNELS_FIRST=0
#pragma kernel Selu_NCHW CHANNELS_FIRST=1
#pragma kernel Selu_CNyx_NHWC CHANNELS_FIRST=0
//#pragma kernel Selu_CNyx_NCHW CHANNELS_FIRST=1 //This kernel require NHWC by design
#pragma kernel Selu_Nyxc_NHWC CHANNELS_FIRST=0
//#pragma kernel Selu_Nyxc_NCHW CHANNELS_FIRST=1
#pragma kernel Tanh_NHWC CHANNELS_FIRST=0
#pragma kernel Tanh_NCHW CHANNELS_FIRST=1
#pragma kernel Tanh_CNyx_NHWC CHANNELS_FIRST=0
//#pragma kernel Tanh_CNyx_NCHW CHANNELS_FIRST=1 //This kernel require NHWC by design
#pragma kernel Tanh_Nyxc_NHWC CHANNELS_FIRST=0
//#pragma kernel Tanh_Nyxc_NCHW CHANNELS_FIRST=1
#pragma kernel Swish_NHWC CHANNELS_FIRST=0
#pragma kernel Swish_NCHW CHANNELS_FIRST=1
#pragma kernel Swish_CNyx_NHWC CHANNELS_FIRST=0
//#pragma kernel Swish_CNyx_NCHW CHANNELS_FIRST=1 //This kernel require NHWC by design
#pragma kernel Swish_Nyxc_NHWC CHANNELS_FIRST=0
//#pragma kernel Swish_Nyxc_NCHW CHANNELS_FIRST=1
#pragma kernel Softplus_NHWC CHANNELS_FIRST=0
#pragma kernel Softplus_NCHW CHANNELS_FIRST=1
#pragma kernel Softplus_CNyx_NHWC CHANNELS_FIRST=0
//#pragma kernel Softplus_CNyx_NCHW CHANNELS_FIRST=1 //This kernel require NHWC by design
#pragma kernel Softplus_Nyxc_NHWC CHANNELS_FIRST=0
//#pragma kernel Softplus_Nyxc_NCHW CHANNELS_FIRST=1
#pragma kernel Sigmoid_NHWC CHANNELS_FIRST=0
#pragma kernel Sigmoid_NCHW CHANNELS_FIRST=1
#pragma kernel Sigmoid_CNyx_NHWC CHANNELS_FIRST=0
//#pragma kernel Sigmoid_CNyx_NCHW CHANNELS_FIRST=1 //This kernel require NHWC by design
#pragma kernel Sigmoid_Nyxc_NHWC CHANNELS_FIRST=0
//#pragma kernel Sigmoid_Nyxc_NCHW CHANNELS_FIRST=1
#pragma kernel HardSigmoid_NHWC CHANNELS_FIRST=0
#pragma kernel HardSigmoid_NCHW CHANNELS_FIRST=1
#pragma kernel HardSigmoid_CNyx_NHWC CHANNELS_FIRST=0
//#pragma kernel HardSigmoid_CNyx_NCHW CHANNELS_FIRST=1 //This kernel require NHWC by design
#pragma kernel HardSigmoid_Nyxc_NHWC CHANNELS_FIRST=0
//#pragma kernel HardSigmoid_Nyxc_NCHW CHANNELS_FIRST=1
#pragma kernel Elu_NHWC CHANNELS_FIRST=0
#pragma kernel Elu_NCHW CHANNELS_FIRST=1
#pragma kernel Elu_CNyx_NHWC CHANNELS_FIRST=0
//#pragma kernel Elu_CNyx_NCHW CHANNELS_FIRST=1 //This kernel require NHWC by design
#pragma kernel Elu_Nyxc_NHWC CHANNELS_FIRST=0
//#pragma kernel Elu_Nyxc_NCHW CHANNELS_FIRST=1
#pragma kernel LeakyRelu_NHWC CHANNELS_FIRST=0
#pragma kernel LeakyRelu_NCHW CHANNELS_FIRST=1
#pragma kernel LeakyRelu_CNyx_NHWC CHANNELS_FIRST=0
//#pragma kernel LeakyRelu_CNyx_NCHW CHANNELS_FIRST=1 //This kernel require NHWC by design
#pragma kernel LeakyRelu_Nyxc_NHWC CHANNELS_FIRST=0
//#pragma kernel LeakyRelu_Nyxc_NCHW CHANNELS_FIRST=1
#pragma kernel Exp_NHWC CHANNELS_FIRST=0
#pragma kernel Exp_NCHW CHANNELS_FIRST=1
#pragma kernel Exp_CNyx_NHWC CHANNELS_FIRST=0
//#pragma kernel Exp_CNyx_NCHW CHANNELS_FIRST=1 //This kernel require NHWC by design
#pragma kernel Exp_Nyxc_NHWC CHANNELS_FIRST=0
//#pragma kernel Exp_Nyxc_NCHW CHANNELS_FIRST=1
#pragma kernel Log_NHWC CHANNELS_FIRST=0
#pragma kernel Log_NCHW CHANNELS_FIRST=1
#pragma kernel Log_CNyx_NHWC CHANNELS_FIRST=0
//#pragma kernel Log_CNyx_NCHW CHANNELS_FIRST=1 //This kernel require NHWC by design
#pragma kernel Log_Nyxc_NHWC CHANNELS_FIRST=0
//#pragma kernel Log_Nyxc_NCHW CHANNELS_FIRST=1
#pragma kernel Sqrt_NHWC CHANNELS_FIRST=0
#pragma kernel Sqrt_NCHW CHANNELS_FIRST=1
#pragma kernel Sqrt_CNyx_NHWC CHANNELS_FIRST=0
//#pragma kernel Sqrt_CNyx_NCHW CHANNELS_FIRST=1 //This kernel require NHWC by design
#pragma kernel Sqrt_Nyxc_NHWC CHANNELS_FIRST=0
//#pragma kernel Sqrt_Nyxc_NCHW CHANNELS_FIRST=1
#pragma kernel Pow_NHWC CHANNELS_FIRST=0
#pragma kernel Pow_NCHW CHANNELS_FIRST=1
#pragma kernel Pow_CNyx_NHWC CHANNELS_FIRST=0
//#pragma kernel Pow_CNyx_NCHW CHANNELS_FIRST=1 //This kernel require NHWC by design
#pragma kernel Pow_Nyxc_NHWC CHANNELS_FIRST=0
//#pragma kernel Pow_Nyxc_NCHW CHANNELS_FIRST=1
#pragma kernel Clip_NHWC CHANNELS_FIRST=0
#pragma kernel Clip_NCHW CHANNELS_FIRST=1
#pragma kernel Clip_CNyx_NHWC CHANNELS_FIRST=0
//#pragma kernel Clip_CNyx_NCHW CHANNELS_FIRST=1 //This kernel require NHWC by design
#pragma kernel Clip_Nyxc_NHWC CHANNELS_FIRST=0
//#pragma kernel Clip_Nyxc_NCHW CHANNELS_FIRST=1
// Layout-aware trig / hyperbolic / Erf variants, same pattern as the section
// above. NOTE(review, comment-only fixes): the disabled pragmas below carried
// copy-pasted "_NHWC" suffixes where the established pattern uses "_NCHW",
// one line had a fused "NHWCCHANNELS_FIRST=1" (missing space), and the Asinh
// group referenced "Asin_Nyxc"; all normalized to match the Relu/Tanh/...
// sections so they are correct if ever re-enabled.
#pragma kernel Acos_NHWC CHANNELS_FIRST=0
#pragma kernel Acos_NCHW CHANNELS_FIRST=1
#pragma kernel Acos_CNyx_NHWC CHANNELS_FIRST=0
//#pragma kernel Acos_CNyx_NCHW CHANNELS_FIRST=1 //This kernel require NHWC by design
#pragma kernel Acos_Nyxc_NHWC CHANNELS_FIRST=0
//#pragma kernel Acos_Nyxc_NCHW CHANNELS_FIRST=1
#pragma kernel Acosh_NHWC CHANNELS_FIRST=0
#pragma kernel Acosh_NCHW CHANNELS_FIRST=1
#pragma kernel Acosh_CNyx_NHWC CHANNELS_FIRST=0
//#pragma kernel Acosh_CNyx_NCHW CHANNELS_FIRST=1 //This kernel require NHWC by design
#pragma kernel Acosh_Nyxc_NHWC CHANNELS_FIRST=0
//#pragma kernel Acosh_Nyxc_NCHW CHANNELS_FIRST=1
#pragma kernel Asin_NHWC CHANNELS_FIRST=0
#pragma kernel Asin_NCHW CHANNELS_FIRST=1
#pragma kernel Asin_CNyx_NHWC CHANNELS_FIRST=0
//#pragma kernel Asin_CNyx_NCHW CHANNELS_FIRST=1 //This kernel require NHWC by design
#pragma kernel Asin_Nyxc_NHWC CHANNELS_FIRST=0
//#pragma kernel Asin_Nyxc_NCHW CHANNELS_FIRST=1
#pragma kernel Asinh_NHWC CHANNELS_FIRST=0
#pragma kernel Asinh_NCHW CHANNELS_FIRST=1
#pragma kernel Asinh_CNyx_NHWC CHANNELS_FIRST=0
//#pragma kernel Asinh_CNyx_NCHW CHANNELS_FIRST=1 //This kernel require NHWC by design
#pragma kernel Asinh_Nyxc_NHWC CHANNELS_FIRST=0
//#pragma kernel Asinh_Nyxc_NCHW CHANNELS_FIRST=1
#pragma kernel Atan_NHWC CHANNELS_FIRST=0
#pragma kernel Atan_NCHW CHANNELS_FIRST=1
#pragma kernel Atan_CNyx_NHWC CHANNELS_FIRST=0
//#pragma kernel Atan_CNyx_NCHW CHANNELS_FIRST=1 //This kernel require NHWC by design
#pragma kernel Atan_Nyxc_NHWC CHANNELS_FIRST=0
//#pragma kernel Atan_Nyxc_NCHW CHANNELS_FIRST=1
#pragma kernel Atanh_NHWC CHANNELS_FIRST=0
#pragma kernel Atanh_NCHW CHANNELS_FIRST=1
#pragma kernel Atanh_CNyx_NHWC CHANNELS_FIRST=0
//#pragma kernel Atanh_CNyx_NCHW CHANNELS_FIRST=1 //This kernel require NHWC by design
#pragma kernel Atanh_Nyxc_NHWC CHANNELS_FIRST=0
//#pragma kernel Atanh_Nyxc_NCHW CHANNELS_FIRST=1
#pragma kernel Cos_NHWC CHANNELS_FIRST=0
#pragma kernel Cos_NCHW CHANNELS_FIRST=1
#pragma kernel Cos_CNyx_NHWC CHANNELS_FIRST=0
//#pragma kernel Cos_CNyx_NCHW CHANNELS_FIRST=1 //This kernel require NHWC by design
#pragma kernel Cos_Nyxc_NHWC CHANNELS_FIRST=0
//#pragma kernel Cos_Nyxc_NCHW CHANNELS_FIRST=1
#pragma kernel Cosh_NHWC CHANNELS_FIRST=0
#pragma kernel Cosh_NCHW CHANNELS_FIRST=1
#pragma kernel Cosh_CNyx_NHWC CHANNELS_FIRST=0
//#pragma kernel Cosh_CNyx_NCHW CHANNELS_FIRST=1 //This kernel require NHWC by design
#pragma kernel Cosh_Nyxc_NHWC CHANNELS_FIRST=0
//#pragma kernel Cosh_Nyxc_NCHW CHANNELS_FIRST=1
#pragma kernel Sin_NHWC CHANNELS_FIRST=0
#pragma kernel Sin_NCHW CHANNELS_FIRST=1
#pragma kernel Sin_CNyx_NHWC CHANNELS_FIRST=0
//#pragma kernel Sin_CNyx_NCHW CHANNELS_FIRST=1 //This kernel require NHWC by design
#pragma kernel Sin_Nyxc_NHWC CHANNELS_FIRST=0
//#pragma kernel Sin_Nyxc_NCHW CHANNELS_FIRST=1
#pragma kernel Sinh_NHWC CHANNELS_FIRST=0
#pragma kernel Sinh_NCHW CHANNELS_FIRST=1
#pragma kernel Sinh_CNyx_NHWC CHANNELS_FIRST=0
//#pragma kernel Sinh_CNyx_NCHW CHANNELS_FIRST=1 //This kernel require NHWC by design
#pragma kernel Sinh_Nyxc_NHWC CHANNELS_FIRST=0
//#pragma kernel Sinh_Nyxc_NCHW CHANNELS_FIRST=1
#pragma kernel Tan_NHWC CHANNELS_FIRST=0
#pragma kernel Tan_NCHW CHANNELS_FIRST=1
#pragma kernel Tan_CNyx_NHWC CHANNELS_FIRST=0
//#pragma kernel Tan_CNyx_NCHW CHANNELS_FIRST=1 //This kernel require NHWC by design
#pragma kernel Tan_Nyxc_NHWC CHANNELS_FIRST=0
//#pragma kernel Tan_Nyxc_NCHW CHANNELS_FIRST=1
#pragma kernel Erf_NHWC CHANNELS_FIRST=0
#pragma kernel Erf_NCHW CHANNELS_FIRST=1
#pragma kernel Erf_CNyx_NHWC CHANNELS_FIRST=0
//#pragma kernel Erf_CNyx_NCHW CHANNELS_FIRST=1 //This kernel require NHWC by design
#pragma kernel Erf_Nyxc_NHWC CHANNELS_FIRST=0
//#pragma kernel Erf_Nyxc_NCHW CHANNELS_FIRST=1
#include "Activation.cginc"

View File

@@ -0,0 +1,9 @@
fileFormatVersion: 2
guid: fdc94044b2f234c0fa80ada3771a2ae7
timeCreated: 1495527718
licenseType: Pro
ComputeShaderImporter:
currentAPIMask: 196608
userData:
assetBundleName:
assetBundleVariant:

View File

@@ -0,0 +1,9 @@
fileFormatVersion: 2
guid: b4b1b304aae6c404cb0cdab46b8fa084
timeCreated: 1495527718
licenseType: Pro
ComputeShaderImporter:
currentAPIMask: 196608
userData:
assetBundleName:
assetBundleVariant:

View File

@@ -0,0 +1,410 @@
// Shared implementation for the Broadcast* kernels; included by the
// layout-specific .compute files with CHANNELS_FIRST defined to 0 (NHWC)
// or 1 (NCHW).
#include "Tensor.cginc"
float _Alpha;          // scale factor, consumed by BroadcastMean
int _IsFirstDispatch;  // nonzero on the first dispatch of a BroadcastMean chain
// Per-input flat strides in (n, y, x, c) order; each element offset is
// computed as dot(uint4(n,y,x,c), strides). Presumably an axis is broadcast
// by giving it a zero stride CPU-side — TODO confirm against the caller.
uint4 _XStrides;
uint4 _SStrides;
uint4 _BStrides;
TENSOR_DECL(X)
TENSOR_DECL(S)
TENSOR_DECL(B)
TENSOR_DECL_RW(O)
// Maps a dispatch thread id to output indices (c, x, y); the axis order of
// the dispatch grid differs per layout (see the DISPATCH ARGS comments).
void DispatchThreadIdToTensorIndices(uint3 dispatchThreadID, out uint c, out uint x, out uint y)
{
#if CHANNELS_FIRST
//DISPATCH ARGS(O.width, O.height, O.channels);
x = dispatchThreadID.x;
y = dispatchThreadID.y;
c = dispatchThreadID.z;
#else
//DISPATCH ARGS(O.channels, O.width, O.height);
c = dispatchThreadID.x;
x = dispatchThreadID.y;
y = dispatchThreadID.z;
#endif
}
// O[n,y,x,c] = X + B, each operand gathered through its own stride vector;
// each thread handles one (c,x,y) position across all batches.
NUMTHREADS((4,8,8), (4,8,4), (4,4,4))
void KERNEL_FUNC(BroadcastAdd)(uint3 dispatchThreadID : SV_DispatchThreadID)
{
TENSOR_TWOINPUTS(X, B, O);
uint c, x, y;
DispatchThreadIdToTensorIndices(dispatchThreadID, c, x, y);
// Guard against threads dispatched beyond the output extents.
if (c >= O.channels) return; if (x >= O.width) return; if (y >= O.height) return;
for (uint n = 0; n < O.batch; ++n)
{
float v =
X.FastGet(dot(uint4(n, y, x, c), _XStrides)) +
B.FastGet(dot(uint4(n, y, x, c), _BStrides));
O.Set(n, y, x, c, v);
}
}
// O[n,y,x,c] = X - B (stride-gathered operands).
NUMTHREADS((4,8,8), (4,8,4), (4,4,4))
void KERNEL_FUNC(BroadcastSub)(uint3 dispatchThreadID : SV_DispatchThreadID)
{
TENSOR_TWOINPUTS(X, B, O);
uint c, x, y;
DispatchThreadIdToTensorIndices(dispatchThreadID, c, x, y);
if (c >= O.channels) return; if (x >= O.width) return; if (y >= O.height) return;
for (uint n = 0; n < O.batch; ++n)
{
float v =
X.FastGet(dot(uint4(n, y, x, c), _XStrides)) -
B.FastGet(dot(uint4(n, y, x, c), _BStrides));
O.Set(n, y, x, c, v);
}
}
// O[n,y,x,c] = X * B (stride-gathered operands).
NUMTHREADS((4,8,8), (4,8,4), (4,4,4))
void KERNEL_FUNC(BroadcastMul)(uint3 dispatchThreadID : SV_DispatchThreadID)
{
TENSOR_TWOINPUTS(X, B, O);
uint c, x, y;
DispatchThreadIdToTensorIndices(dispatchThreadID, c, x, y);
if (c >= O.channels) return; if (x >= O.width) return; if (y >= O.height) return;
for (uint n = 0; n < O.batch; ++n)
{
float v =
X.FastGet(dot(uint4(n, y, x, c), _XStrides)) *
B.FastGet(dot(uint4(n, y, x, c), _BStrides));
O.Set(n, y, x, c, v);
}
}
// O[n,y,x,c] = X / B (stride-gathered operands; no guard against B == 0).
NUMTHREADS((4,8,8), (4,8,4), (4,4,4))
void KERNEL_FUNC(BroadcastDiv)(uint3 dispatchThreadID : SV_DispatchThreadID)
{
TENSOR_TWOINPUTS(X, B, O);
uint c, x, y;
DispatchThreadIdToTensorIndices(dispatchThreadID, c, x, y);
if (c >= O.channels) return; if (x >= O.width) return; if (y >= O.height) return;
for (uint n = 0; n < O.batch; ++n)
{
float v =
X.FastGet(dot(uint4(n, y, x, c), _XStrides)) /
B.FastGet(dot(uint4(n, y, x, c), _BStrides));
O.Set(n, y, x, c, v);
}
}
// pow() that tolerates a negative base: returns sign(f)*pow(abs(f), e) for an
// odd integer exponent and pow(abs(f), e) otherwise (non-integer exponents on
// a negative base, NaN in plain pow, therefore fall into the "even" branch).
float signed_pow(float f, float e)
{
    // handle negative f
    float v = pow(abs(f), e);
    // Fix: HLSL's float % follows fmod, so for a negative odd exponent
    // (-3) % 2 == -1 and the original test (e % 2 == 1) dropped the sign
    // (e.g. signed_pow(-2, -3) returned +0.125). Testing abs(e) keeps the
    // odd/even classification correct for both signs of the exponent.
    float s = (abs(e) % 2 == 1) ?
        sign(f): // exponent is odd  => sign(f) * pow(abs(f), e)
        1;       // exponent is even => pow(abs(f), e)
    return v * s;
}
// O[n,y,x,c] = signed_pow(X, B) (stride-gathered operands).
NUMTHREADS((4,8,8), (4,8,4), (4,4,4))
void KERNEL_FUNC(BroadcastPow)(uint3 dispatchThreadID : SV_DispatchThreadID)
{
TENSOR_TWOINPUTS(X, B, O);
uint c, x, y;
DispatchThreadIdToTensorIndices(dispatchThreadID, c, x, y);
if (c >= O.channels) return; if (x >= O.width) return; if (y >= O.height) return;
for (uint n = 0; n < O.batch; ++n)
{
float v = signed_pow(
X.FastGet(dot(uint4(n, y, x, c), _XStrides)),
B.FastGet(dot(uint4(n, y, x, c), _BStrides)));
O.Set(n, y, x, c, v);
}
}
// O[n,y,x,c] = min(X, B) (stride-gathered operands).
NUMTHREADS((4,8,8), (4,8,4), (4,4,4))
void KERNEL_FUNC(BroadcastMin)(uint3 dispatchThreadID : SV_DispatchThreadID)
{
TENSOR_TWOINPUTS(X, B, O);
uint c, x, y;
DispatchThreadIdToTensorIndices(dispatchThreadID, c, x, y);
if (c >= O.channels) return; if (x >= O.width) return; if (y >= O.height) return;
for (uint n = 0; n < O.batch; ++n)
{
float v = min(
X.FastGet(dot(uint4(n, y, x, c), _XStrides)),
B.FastGet(dot(uint4(n, y, x, c), _BStrides)));
O.Set(n, y, x, c, v);
}
}
// O[n,y,x,c] = max(X, B) (stride-gathered operands).
NUMTHREADS((4,8,8), (4,8,4), (4,4,4))
void KERNEL_FUNC(BroadcastMax)(uint3 dispatchThreadID : SV_DispatchThreadID)
{
TENSOR_TWOINPUTS(X, B, O);
uint c, x, y;
DispatchThreadIdToTensorIndices(dispatchThreadID, c, x, y);
if (c >= O.channels) return; if (x >= O.width) return; if (y >= O.height) return;
for (uint n = 0; n < O.batch; ++n)
{
float v = max(
X.FastGet(dot(uint4(n, y, x, c), _XStrides)),
B.FastGet(dot(uint4(n, y, x, c), _BStrides)));
O.Set(n, y, x, c, v);
}
}
// Accumulating mean step: O = X*(_Alpha if first dispatch else 1) + B*_Alpha.
// Presumably chained over the input list with _Alpha = 1/inputCount so the
// final dispatch yields the mean — TODO confirm against the dispatching code.
NUMTHREADS((4, 8, 8), (4, 8, 4), (4, 4, 4))
void KERNEL_FUNC(BroadcastMean)(uint3 dispatchThreadID : SV_DispatchThreadID)
{
TENSOR_TWOINPUTS(X, B, O);
uint c, x, y;
DispatchThreadIdToTensorIndices(dispatchThreadID, c, x, y);
if (c >= O.channels) return; if (x >= O.width) return; if (y >= O.height) return;
for (uint n = 0; n < O.batch; ++n)
{
float a = X.FastGet(dot(uint4(n, y, x, c), _XStrides));
a *= _IsFirstDispatch ? _Alpha : 1.0f;
float b = B.FastGet(dot(uint4(n, y, x, c), _BStrides)) * _Alpha;
float v = a + b;
O.Set(n, y, x, c, v);
}
}
// Comparison kernels: emit 1.0f where the predicate holds, 0.0f elsewhere.
// O[n,y,x,c] = (X > B) ? 1 : 0.
NUMTHREADS((4,8,8), (4,8,4), (4,4,4))
void KERNEL_FUNC(BroadcastGreater)(uint3 dispatchThreadID : SV_DispatchThreadID)
{
TENSOR_TWOINPUTS(X, B, O);
uint c, x, y;
DispatchThreadIdToTensorIndices(dispatchThreadID, c, x, y);
if (c >= O.channels)
return;
if (x >= O.width)
return;
if (y >= O.height)
return;
for (uint n = 0; n < O.batch; ++n)
{
float a = X.FastGet(dot(uint4(n, y, x, c), _XStrides));
float b = B.FastGet(dot(uint4(n, y, x, c), _BStrides));
float v = (a > b) ? 1.0f : 0.0f;
O.Set(n, y, x, c, v);
}
}
// O[n,y,x,c] = (X >= B) ? 1 : 0.
NUMTHREADS((4,8,8), (4,8,4), (4,4,4))
void KERNEL_FUNC(BroadcastGreaterEqual)(uint3 dispatchThreadID : SV_DispatchThreadID)
{
TENSOR_TWOINPUTS(X, B, O);
uint c, x, y;
DispatchThreadIdToTensorIndices(dispatchThreadID, c, x, y);
if (c >= O.channels)
return;
if (x >= O.width)
return;
if (y >= O.height)
return;
for (uint n = 0; n < O.batch; ++n)
{
float a = X.FastGet(dot(uint4(n, y, x, c), _XStrides));
float b = B.FastGet(dot(uint4(n, y, x, c), _BStrides));
float v = (a >= b) ? 1.0f : 0.0f;
O.Set(n, y, x, c, v);
}
}
// O[n,y,x,c] = (X < B) ? 1 : 0.
NUMTHREADS((4,8,8), (4,8,4), (4,4,4))
void KERNEL_FUNC(BroadcastLess)(uint3 dispatchThreadID : SV_DispatchThreadID)
{
TENSOR_TWOINPUTS(X, B, O);
uint c, x, y;
DispatchThreadIdToTensorIndices(dispatchThreadID, c, x, y);
if (c >= O.channels)
return;
if (x >= O.width)
return;
if (y >= O.height)
return;
for (uint n = 0; n < O.batch; ++n)
{
float a = X.FastGet(dot(uint4(n, y, x, c), _XStrides));
float b = B.FastGet(dot(uint4(n, y, x, c), _BStrides));
float v = (a < b) ? 1.0f : 0.0f;
O.Set(n, y, x, c, v);
}
}
// O[n,y,x,c] = (X <= B) ? 1 : 0.
NUMTHREADS((4,8,8), (4,8,4), (4,4,4))
void KERNEL_FUNC(BroadcastLessEqual)(uint3 dispatchThreadID : SV_DispatchThreadID)
{
TENSOR_TWOINPUTS(X, B, O);
uint c, x, y;
DispatchThreadIdToTensorIndices(dispatchThreadID, c, x, y);
if (c >= O.channels)
return;
if (x >= O.width)
return;
if (y >= O.height)
return;
for (uint n = 0; n < O.batch; ++n)
{
float a = X.FastGet(dot(uint4(n, y, x, c), _XStrides));
float b = B.FastGet(dot(uint4(n, y, x, c), _BStrides));
float v = (a <= b) ? 1.0f : 0.0f;
O.Set(n, y, x, c, v);
}
}
// O[n,y,x,c] = (X == B) ? 1 : 0 — exact float equality, as specified for the
// ONNX Equal op; inputs are expected to be exact (e.g. comparison of indices).
NUMTHREADS((4,8,8), (4,8,4), (4,4,4))
void KERNEL_FUNC(BroadcastEqual)(uint3 dispatchThreadID : SV_DispatchThreadID)
{
TENSOR_TWOINPUTS(X, B, O);
uint c, x, y;
DispatchThreadIdToTensorIndices(dispatchThreadID, c, x, y);
if (c >= O.channels)
return;
if (x >= O.width)
return;
if (y >= O.height)
return;
for (uint n = 0; n < O.batch; ++n)
{
float a = X.FastGet(dot(uint4(n, y, x, c), _XStrides));
float b = B.FastGet(dot(uint4(n, y, x, c), _BStrides));
float v = (a == b) ? 1.0f : 0.0f;
O.Set(n, y, x, c, v);
}
}
// Logical kernels: operands are first normalized to 0/1 (nonzero => true),
// then the boolean op is computed arithmetically.
// OR via a*(1-b)+b: 0|0=0, 1|0=1, 0|1=1, 1|1=1.
NUMTHREADS((4,8,8), (4,8,4), (4,4,4))
void KERNEL_FUNC(BroadcastLogicalOr)(uint3 dispatchThreadID : SV_DispatchThreadID)
{
TENSOR_TWOINPUTS(X, B, O);
uint c, x, y;
DispatchThreadIdToTensorIndices(dispatchThreadID, c, x, y);
if (c >= O.channels)
return;
if (x >= O.width)
return;
if (y >= O.height)
return;
for (uint n = 0; n < O.batch; ++n)
{
float a = (X.FastGet(dot(uint4(n, y, x, c), _XStrides)) == 0.0f) ? 0.0f : 1.0f;
float b = (B.FastGet(dot(uint4(n, y, x, c), _BStrides)) == 0.0f) ? 0.0f : 1.0f;
float v = a * (1 - b) + b;
O.Set(n, y, x, c, v);
}
}
// AND via the raw product: result is 1 iff both operands are nonzero.
NUMTHREADS((4,8,8), (4,8,4), (4,4,4))
void KERNEL_FUNC(BroadcastLogicalAnd)(uint3 dispatchThreadID : SV_DispatchThreadID)
{
TENSOR_TWOINPUTS(X, B, O);
uint c, x, y;
DispatchThreadIdToTensorIndices(dispatchThreadID, c, x, y);
if (c >= O.channels)
return;
if (x >= O.width)
return;
if (y >= O.height)
return;
for (uint n = 0; n < O.batch; ++n)
{
float a = X.FastGet(dot(uint4(n, y, x, c), _XStrides));
float b = B.FastGet(dot(uint4(n, y, x, c), _BStrides));
float v = a * b != 0.0 ? 1.0f : 0.0f;
O.Set(n, y, x, c, v);
}
}
// XOR via a*(1-2b)+b: 0^0=0, 1^0=1, 0^1=1, 1^1=0.
NUMTHREADS((4,8,8), (4,8,4), (4,4,4))
void KERNEL_FUNC(BroadcastLogicalXor)(uint3 dispatchThreadID : SV_DispatchThreadID)
{
TENSOR_TWOINPUTS(X, B, O);
uint c, x, y;
DispatchThreadIdToTensorIndices(dispatchThreadID, c, x, y);
if (c >= O.channels)
return;
if (x >= O.width)
return;
if (y >= O.height)
return;
for (uint n = 0; n < O.batch; ++n)
{
float a = X.FastGet(dot(uint4(n, y, x, c), _XStrides)) != 0.0f ? 1.0f : 0.0f;
float b = B.FastGet(dot(uint4(n, y, x, c), _BStrides)) != 0.0f ? 1.0f : 0.0f;
float v = a * (1 - 2 * b) + b;
O.Set(n, y, x, c, v);
}
}
// Element-wise select: O = (X != 0) ? S : B, each operand stride-gathered.
NUMTHREADS((4, 8, 8), (4, 8, 4), (4, 4, 4))
void KERNEL_FUNC(BroadcastWhere)(uint3 dispatchThreadID : SV_DispatchThreadID)
{
TENSOR_THREEINPUTS(X, S, B, O);
uint c, x, y;
DispatchThreadIdToTensorIndices(dispatchThreadID, c, x, y);
if (c >= O.channels)
return;
if (x >= O.width)
return;
if (y >= O.height)
return;
for (uint n = 0; n < O.batch; ++n)
{
bool cond = (X.FastGet(dot(uint4(n, y, x, c), _XStrides)) != 0.0f);
float a = S.FastGet(dot(uint4(n, y, x, c), _SStrides));
float b = B.FastGet(dot(uint4(n, y, x, c), _BStrides));
float v = cond ? a : b;
O.Set(n, y, x, c, v);
}
}
// O = exp(X - B) / S — presumably the final softmax step, with B the
// per-slice max and S the sum of exponentials; TODO confirm dispatch chain.
// Note the input macro order here is (X, B, S), unlike BroadcastWhere above.
NUMTHREADS((4, 8, 8), (4, 8, 4), (4, 4, 4))
void KERNEL_FUNC(BroadcastDivExpSub)(uint3 dispatchThreadID : SV_DispatchThreadID)
{
TENSOR_THREEINPUTS(X, B, S, O);
uint c, x, y;
DispatchThreadIdToTensorIndices(dispatchThreadID, c, x, y);
if (c >= O.channels) return; if (x >= O.width) return; if (y >= O.height) return;
for (uint n = 0; n < O.batch; ++n)
{
float v =
X.FastGet(dot(uint4(n, y, x, c), _XStrides)) -
B.FastGet(dot(uint4(n, y, x, c), _BStrides));
v = exp(v) / S.FastGet(dot(uint4(n, y, x, c), _SStrides));
O.Set(n, y, x, c, v);
}
}
// O = (X - B) - log(S) — log-space counterpart of BroadcastDivExpSub,
// i.e. the final log-softmax step.
NUMTHREADS((4, 8, 8), (4, 8, 4), (4, 4, 4))
void KERNEL_FUNC(LogSoftmaxEnd)(uint3 dispatchThreadID : SV_DispatchThreadID)
{
TENSOR_THREEINPUTS(X, B, S, O);
uint c, x, y;
DispatchThreadIdToTensorIndices(dispatchThreadID, c, x, y);
if (c >= O.channels) return; if (x >= O.width) return; if (y >= O.height) return;
for (uint n = 0; n < O.batch; ++n)
{
float v =
X.FastGet(dot(uint4(n, y, x, c), _XStrides)) -
B.FastGet(dot(uint4(n, y, x, c), _BStrides));
v = v - log(S.FastGet(dot(uint4(n, y, x, c), _SStrides)));
O.Set(n, y, x, c, v);
}
}

View File

@@ -0,0 +1,9 @@
fileFormatVersion: 2
guid: fc624dd44959d4dfcad99aed0abc2a8d
ShaderImporter:
externalObjects: {}
defaultTextures: []
nonModifiableTextures: []
userData:
assetBundleName:
assetBundleVariant:

View File

@@ -0,0 +1,22 @@
// NCHW (channels-first) builds of the shared Broadcast kernels; the kernel
// bodies live in Broadcast.cginc and are specialized via CHANNELS_FIRST=1.
#pragma kernel BroadcastAdd_NCHW CHANNELS_FIRST=1
#pragma kernel BroadcastSub_NCHW CHANNELS_FIRST=1
#pragma kernel BroadcastMul_NCHW CHANNELS_FIRST=1
#pragma kernel BroadcastDiv_NCHW CHANNELS_FIRST=1
#pragma kernel BroadcastPow_NCHW CHANNELS_FIRST=1
#pragma kernel BroadcastMin_NCHW CHANNELS_FIRST=1
#pragma kernel BroadcastMax_NCHW CHANNELS_FIRST=1
#pragma kernel BroadcastMean_NCHW CHANNELS_FIRST=1
#pragma kernel BroadcastGreater_NCHW CHANNELS_FIRST=1
#pragma kernel BroadcastGreaterEqual_NCHW CHANNELS_FIRST=1
#pragma kernel BroadcastLess_NCHW CHANNELS_FIRST=1
#pragma kernel BroadcastLessEqual_NCHW CHANNELS_FIRST=1
#pragma kernel BroadcastEqual_NCHW CHANNELS_FIRST=1
#pragma kernel BroadcastLogicalOr_NCHW CHANNELS_FIRST=1
#pragma kernel BroadcastLogicalAnd_NCHW CHANNELS_FIRST=1
#pragma kernel BroadcastLogicalXor_NCHW CHANNELS_FIRST=1
#pragma kernel BroadcastWhere_NCHW CHANNELS_FIRST=1
#pragma kernel BroadcastDivExpSub_NCHW CHANNELS_FIRST=1
#pragma kernel LogSoftmaxEnd_NCHW CHANNELS_FIRST=1
#include "Broadcast.cginc"

View File

@@ -0,0 +1,8 @@
fileFormatVersion: 2
guid: 5d7fa6770eadc4ef38d7b12a5dedf404
ComputeShaderImporter:
externalObjects: {}
currentAPIMask: 2164736
userData:
assetBundleName:
assetBundleVariant:

View File

@@ -0,0 +1,22 @@
// NHWC (channels-last) builds of the shared Broadcast kernels; must stay in
// sync with the kernel list of the NCHW compute file.
#pragma kernel BroadcastAdd_NHWC CHANNELS_FIRST=0
#pragma kernel BroadcastSub_NHWC CHANNELS_FIRST=0
#pragma kernel BroadcastMul_NHWC CHANNELS_FIRST=0
#pragma kernel BroadcastDiv_NHWC CHANNELS_FIRST=0
#pragma kernel BroadcastPow_NHWC CHANNELS_FIRST=0
#pragma kernel BroadcastMin_NHWC CHANNELS_FIRST=0
#pragma kernel BroadcastMax_NHWC CHANNELS_FIRST=0
#pragma kernel BroadcastMean_NHWC CHANNELS_FIRST=0
#pragma kernel BroadcastGreater_NHWC CHANNELS_FIRST=0
#pragma kernel BroadcastGreaterEqual_NHWC CHANNELS_FIRST=0
#pragma kernel BroadcastLess_NHWC CHANNELS_FIRST=0
#pragma kernel BroadcastLessEqual_NHWC CHANNELS_FIRST=0
#pragma kernel BroadcastEqual_NHWC CHANNELS_FIRST=0
#pragma kernel BroadcastLogicalOr_NHWC CHANNELS_FIRST=0
#pragma kernel BroadcastLogicalAnd_NHWC CHANNELS_FIRST=0
#pragma kernel BroadcastLogicalXor_NHWC CHANNELS_FIRST=0
#pragma kernel BroadcastWhere_NHWC CHANNELS_FIRST=0
#pragma kernel BroadcastDivExpSub_NHWC CHANNELS_FIRST=0
#pragma kernel LogSoftmaxEnd_NHWC CHANNELS_FIRST=0
#include "Broadcast.cginc"

View File

@@ -0,0 +1,8 @@
fileFormatVersion: 2
guid: e08c989f90a0240cdac731efb621231e
ComputeShaderImporter:
externalObjects: {}
currentAPIMask: 2164736
userData:
assetBundleName:
assetBundleVariant:

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,9 @@
fileFormatVersion: 2
guid: 8211ebc2a8cd04e49a086347aebe8ee6
ShaderImporter:
externalObjects: {}
defaultTextures: []
nonModifiableTextures: []
userData:
assetBundleName:
assetBundleVariant:

View File

@@ -0,0 +1,17 @@
// Most often used kernels
// NCHW convolution entry points; implementations in Conv2d.cginc. SUFFIX and
// BLOCK_SIZE select the tiled/register-blocked code path per kernel name.
#pragma kernel Conv2D_NCHW CHANNELS_FIRST=1
#pragma kernel Conv2D_RegisterBlock4x2_NCHW CHANNELS_FIRST=1
#pragma kernel DepthwiseConv2D_NCHW CHANNELS_FIRST=1
//R4x4_64k
#pragma kernel Conv2DKernelKxK_StrictC16K64_T16x16_R4x4_NCHW CHANNELS_FIRST=1 BLOCK_SIZE=4 STRICT_CHANNELS=1 SUFFIX=KernelKxK_StrictC16K64_T16x16_R
#pragma kernel Conv2DKernelKxK_T16x16_R4x4_NCHW CHANNELS_FIRST=1 BLOCK_SIZE=4 SUFFIX=KernelKxK_T16x16_R
#pragma kernel Conv2DKernel1x1_StrictC16K64_T16x16_R4x4_NCHW CHANNELS_FIRST=1 BLOCK_SIZE=4 KERNEL_1x1=1 STRICT_CHANNELS=1 SUFFIX=Kernel1x1_StrictC16K64_T16x16_R
#include "Conv2d.cginc"

View File

@@ -0,0 +1,8 @@
fileFormatVersion: 2
guid: 9d6406345bbd8482bab46e622092abcb
ComputeShaderImporter:
externalObjects: {}
currentAPIMask: 2164736
userData:
assetBundleName:
assetBundleVariant:

View File

@@ -0,0 +1,16 @@
// Most often used kernels
// NHWC convolution entry points; mirrors the NCHW compute file with
// CHANNELS_FIRST=0. Implementations in Conv2d.cginc.
#pragma kernel Conv2D_NHWC CHANNELS_FIRST=0
#pragma kernel Conv2D_RegisterBlock4x2_NHWC CHANNELS_FIRST=0
#pragma kernel DepthwiseConv2D_NHWC CHANNELS_FIRST=0
//R4x4_64k
#pragma kernel Conv2DKernelKxK_StrictC16K64_T16x16_R4x4_NHWC CHANNELS_FIRST=0 BLOCK_SIZE=4 STRICT_CHANNELS=1 SUFFIX=KernelKxK_StrictC16K64_T16x16_R
#pragma kernel Conv2DKernelKxK_T16x16_R4x4_NHWC CHANNELS_FIRST=0 BLOCK_SIZE=4 SUFFIX=KernelKxK_T16x16_R
#pragma kernel Conv2DKernel1x1_StrictC16K64_T16x16_R4x4_NHWC CHANNELS_FIRST=0 BLOCK_SIZE=4 KERNEL_1x1=1 STRICT_CHANNELS=1 SUFFIX=Kernel1x1_StrictC16K64_T16x16_R
#include "Conv2d.cginc"

View File

@@ -0,0 +1,8 @@
fileFormatVersion: 2
guid: 60d69d385fb8141349f401ede7d4d5c7
ComputeShaderImporter:
externalObjects: {}
currentAPIMask: 2164736
userData:
assetBundleName:
assetBundleVariant:

View File

@@ -0,0 +1,34 @@
// Additional Conv2D variants: larger register blocks, transposed convolution,
// and Winograd F(2x2,3x3). Strict/Lax refer to whether channel / kernel
// counts must be exact multiples of the tile size.
//R8x8_64k
#pragma kernel Conv2DKernelKxK_StrictC16StrictK64_T8x8_R8x8_NHWC CHANNELS_FIRST=0 BLOCK_SIZE=8 KERNEL_PER_TG=64 STRICT_CHANNELS=1 SUFFIX=KernelKxK_StrictC16StrictK64_T8x8_R
#pragma kernel Conv2DKernelKxK_StrictC16StrictK64_T8x8_R8x8_NCHW CHANNELS_FIRST=1 BLOCK_SIZE=8 KERNEL_PER_TG=64 STRICT_CHANNELS=1 SUFFIX=KernelKxK_StrictC16StrictK64_T8x8_R
#pragma kernel Conv2DKernelKxK_StrictC16LaxK64_T8x8_R8x8_NHWC CHANNELS_FIRST=0 BLOCK_SIZE=8 KERNEL_PER_TG=64 STRICT_CHANNELS=1 LAX_KERNEL=1 SUFFIX=KernelKxK_StrictC16LaxK64_T8x8_R
#pragma kernel Conv2DKernelKxK_StrictC16LaxK64_T8x8_R8x8_NCHW CHANNELS_FIRST=1 BLOCK_SIZE=8 KERNEL_PER_TG=64 STRICT_CHANNELS=1 LAX_KERNEL=1 SUFFIX=KernelKxK_StrictC16LaxK64_T8x8_R
//R8x8_16k
#pragma kernel Conv2DKernelKxK_StrictC4StrictK16_T2x32_R8x8_NHWC CHANNELS_FIRST=0 BLOCK_SIZE=8 KERNEL_PER_TG=16 STRICT_CHANNELS=1 SUFFIX=KernelKxK_StrictC4StrictK16_T2x32_R
#pragma kernel Conv2DKernelKxK_StrictC4StrictK16_T2x32_R8x8_NCHW CHANNELS_FIRST=1 BLOCK_SIZE=8 KERNEL_PER_TG=16 STRICT_CHANNELS=1 SUFFIX=KernelKxK_StrictC4StrictK16_T2x32_R
#pragma kernel Conv2DKernelKxK_LaxC4StrictK16_T2x32_R8x8_NHWC CHANNELS_FIRST=0 BLOCK_SIZE=8 KERNEL_PER_TG=16 SUFFIX=KernelKxK_LaxC4StrictK16_T2x32_R
#pragma kernel Conv2DKernelKxK_LaxC4StrictK16_T2x32_R8x8_NCHW CHANNELS_FIRST=1 BLOCK_SIZE=8 KERNEL_PER_TG=16 SUFFIX=KernelKxK_LaxC4StrictK16_T2x32_R
#pragma kernel Conv2DKernelKxK_StrictC4LaxK16_T2x32_R8x8_NHWC CHANNELS_FIRST=0 BLOCK_SIZE=8 KERNEL_PER_TG=16 STRICT_CHANNELS=1 LAX_KERNEL=1 SUFFIX=KernelKxK_StrictC4LaxK16_T2x32_R
#pragma kernel Conv2DKernelKxK_StrictC4LaxK16_T2x32_R8x8_NCHW CHANNELS_FIRST=1 BLOCK_SIZE=8 KERNEL_PER_TG=16 STRICT_CHANNELS=1 LAX_KERNEL=1 SUFFIX=KernelKxK_StrictC4LaxK16_T2x32_R
#pragma kernel Conv2DTrans_NHWC CHANNELS_FIRST=0
#pragma kernel Conv2DTrans_NCHW CHANNELS_FIRST=1
//Tested 2x2, 3x3 and 5x5 kernels with groupsize [8,8], [8,16], [16,16] and [16,32] (this one not in 5x5 as it does not fit in 32k)
//k=5x5 t=[16,16] fast consistently faster or equal to other configuration both on AMDVega and RTX2080 (tested with kernel size 2x2x32x32, input size 128x128x32)
//however this configuration is quite LDS bound performance profile might be very different on hardware without on chip LDS. This is especially true for smaller kernel
//as a lot of LDS will be reserved but not used, reducing the amount of cache used.
#pragma kernel Conv2DTrans_KernelCached_K5x5_T16x16_NHWC CHANNELS_FIRST=0 MAX_KERNEL_SIZE=5 GROUP_SIZE_X=16 GROUP_SIZE_Y=16
#pragma kernel Conv2DTrans_KernelCached_K5x5_T16x16_NCHW CHANNELS_FIRST=1 MAX_KERNEL_SIZE=5 GROUP_SIZE_X=16 GROUP_SIZE_Y=16
#pragma kernel Conv2DTransFlipKernel
#pragma kernel Conv2DTransPadFill_NHWC CHANNELS_FIRST=0
#pragma kernel Conv2DTransPadFill_NCHW CHANNELS_FIRST=1
#pragma kernel KernelWinograd_3x3
#pragma kernel Conv2DWinograd_2x2_Kernel3x3_StrictC8StrictK16_T16x16_R4x4_NCHW CHANNELS_FIRST=1 BLOCK_SIZE=4 KERNEL_PER_TG=16 STRICT_CHANNELS=1 SUFFIX=Kernel3x3_StrictC8StrictK16_T16x16_R
#pragma kernel Conv2DWinograd_2x2_Kernel3x3_StrictC8LaxK16_T16x16_R4x4_NCHW CHANNELS_FIRST=1 BLOCK_SIZE=4 KERNEL_PER_TG=16 STRICT_CHANNELS=1 LAX_KERNEL=1 SUFFIX=Kernel3x3_StrictC8LaxK16_T16x16_R
#include "Conv2d.cginc"

View File

@@ -0,0 +1,8 @@
fileFormatVersion: 2
guid: 1279e283ef61d47309a96431ea81d6bb
ComputeShaderImporter:
externalObjects: {}
currentAPIMask: 2164736
userData:
assetBundleName:
assetBundleVariant:

View File

@@ -0,0 +1,8 @@
fileFormatVersion: 2
guid: 37f7d6dfde4c7c141ae5b12a1bf7b18d
ComputeShaderImporter:
externalObjects: {}
currentAPIMask: 2097156
userData:
assetBundleName:
assetBundleVariant:

View File

@@ -0,0 +1,263 @@
#pragma kernel Conv3D_NHWC CHANNELS_FIRST=0
#pragma kernel Conv3D_NCHW CHANNELS_FIRST=1
#pragma kernel Conv3DKernelKxK_LaxC8LaxK32_T8x16_R4x4_NHWC CHANNELS_FIRST=0 LAX_KERNEL=1 SUFFIX=KernelKxK_LaxC8LaxK32_T8x16_R
#pragma kernel Conv3DKernelKxK_LaxC8LaxK32_T8x16_R4x4_NCHW CHANNELS_FIRST=1 LAX_KERNEL=1 SUFFIX=KernelKxK_LaxC8LaxK32_T8x16_R
#pragma kernel Conv3DKernelKxK_StrictC8LaxK32_T8x16_R4x4_NHWC CHANNELS_FIRST=0 STRICT_CHANNELS=1 LAX_KERNEL=1 SUFFIX=KernelKxK_StrictC8LaxK32_T8x16_R
#pragma kernel Conv3DKernelKxK_StrictC8LaxK32_T8x16_R4x4_NCHW CHANNELS_FIRST=1 STRICT_CHANNELS=1 LAX_KERNEL=1 SUFFIX=KernelKxK_StrictC8LaxK32_T8x16_R
#pragma kernel Conv3DKernelKxK_StrictC8StrictK32_T8x16_R4x4_NHWC CHANNELS_FIRST=0 STRICT_CHANNELS=1 SUFFIX=KernelKxK_StrictC8StrictK32_T8x16_R
#pragma kernel Conv3DKernelKxK_StrictC8StrictK32_T8x16_R4x4_NCHW CHANNELS_FIRST=1 STRICT_CHANNELS=1 SUFFIX=KernelKxK_StrictC8StrictK32_T8x16_R
#include "Tensor.cginc"
// Tensor bindings (helpers come from Tensor.cginc):
//   X = input, K = convolution weights, B = bias, O = output (read-write).
//   WBK is the combined buffer used by TENSOR_SHARED2_ARGS4_8D below
//   (presumably packs weights and bias together — see Tensor.cginc).
TENSOR_DECL(X)
TENSOR_DECL(K)
TENSOR_DECL(B)
TENSOR_DECL(WBK)
TENSOR_DECL_RW(O)
uint4 _Pad; // spatial padding; .xyz = width/height/depth (.w is not referenced by the kernels in this file)
uint4 _Stride; // spatial stride; .xyz = width/height/depth (.w is not referenced by the kernels in this file)
// Fused multiply-add a*b+c, written as a dot() — presumably to nudge the shader
// compiler into emitting a single mad instruction (same trick as in Dense.compute).
float ffma(float a, float b, float c) { return dot(float2(a,c), float2(b,1)); }
// Name-mangling helpers: build kernel / LDS-cache identifiers that encode the
// data layout, e.g. Conv3DKernelKxK_StrictC8LaxK32_T8x16_R4x4_NCHW.
#if CHANNELS_FIRST
#define FUNC_NAME_CALL(KERNEL, SUFFIX, SIZE) KERNEL##SUFFIX##SIZE##x##SIZE##_NCHW
#define CACHE_NAME_CALL(KERNEL, SUFFIX, SIZE, TENSOR) KERNEL##SUFFIX##SIZE##x##SIZE##_Cache_##TENSOR##_NCHW
#else
#define FUNC_NAME_CALL(KERNEL, SUFFIX, SIZE) KERNEL##SUFFIX##SIZE##x##SIZE##_NHWC
#define CACHE_NAME_CALL(KERNEL, SUFFIX, SIZE, TENSOR) KERNEL##SUFFIX##SIZE##x##SIZE##_Cache_##TENSOR##_NHWC
#endif
// Extra level of indirection so macro arguments are expanded before token pasting.
#define FUNC_NAME(KERNEL, SUFFIX, SIZE) FUNC_NAME_CALL(KERNEL, SUFFIX, SIZE)
#define CACHE_NAME(KERNEL, SUFFIX, SIZE, TENSOR) CACHE_NAME_CALL(KERNEL, SUFFIX, SIZE, TENSOR)
#define KERNEL_NAME Conv3D
// Reference (non-tiled) 3D convolution.
// One thread per (output channel k, output x, output y); each thread loops over
// the batch, every output depth slice, every kernel tap and every input channel.
NUMTHREADS((16,4,4), (8,4,4), (4,4,4))
void KERNEL_FUNC(Conv3D)(uint3 dispatchThreadID : SV_DispatchThreadID)
{
//DISPATCH ARGS(K.kernelCount, O.width, O.height);
TENSOR_SHARED2_ARGS4_8D(X, K, B, WBK, O);
uint k = dispatchThreadID.x; // output channel (kernel index)
uint x = dispatchThreadID.y; // output width coordinate
uint y = dispatchThreadID.z; // output height coordinate
if (k >= K.channels) return;
if (x >= O.width) return;
if (y >= O.height) return;
// Valid input region in padded coordinates: a padded position p maps to input
// position p - leftCorner and is in bounds when leftCorner <= p < rightCorner.
uint3 leftCorner = _Pad.xyz;
uint3 rightCorner = uint3(X.width, X.height, X.depth) + _Pad.xyz;
for (uint n = 0; n < O.batch; ++n)
for (uint d = 0; d < O.depth; ++d)
{
float acc = B.FastGet(k); // start from the bias for this output channel
for (uint dd = 0; dd < K.GetKernelSpatialDepth(); ++dd)
{
for (uint dy = 0; dy < K.GetKernelHeight(); ++dy)
{
for (uint dx = 0; dx < K.GetKernelWidth(); ++dx)
{
// Position of this tap in padded input space (unsigned: padding is folded in).
uint3 pos3d = uint3(x, y, d) * _Stride.xyz + uint3(dx, dy, dd);
for (uint c = 0; c < X.channels; ++c)
{
float v = 0;
// WARNING: Mali-G71 performance drops 4x if this branching includes storing accumulator (comment copied from Conv2D kernel)
if (!any(pos3d < leftCorner) && !any(pos3d >= rightCorner))
v = X.Get5D(n, pos3d.z - leftCorner.z, pos3d.y - leftCorner.y, pos3d.x - leftCorner.x, c);
//acc = fastfma(v, K.Get(dy, dx, c, k), acc);
acc += v * K.GetKernel5D(dd, dy, dx, c, k);
}
}
}
}
O.Set5DWithActivation(n, d, y, x, k, acc);
}
}
#define PIXEL_PER_TG 64 //only supported value
#define KERNEL_PER_TG 32 //only supported value
#define BLOCK_SIZE 4 //only supported value
#define CACHE_DEPTH 8 //must be a multiple of 4 (see the 4-channels-at-a-time cache load loop below).
//Each thread handle = 4 kernels * 4 pixels (in registers) and all channels
//A threadgroup (8,16,1) handle = 32 kernels x 64 pixels and all channels (looping on CACHE_DEPTH channel at a time)
groupshared float CACHE_NAME(KERNEL_NAME, SUFFIX, BLOCK_SIZE, LDS) [(32+64) * CACHE_DEPTH]; //(32+64)*CACHE_DEPTH == 96*CACHE_DEPTH floats (CACHE_DEPTH == 8 --> 768 floats)
[numthreads(8,16,1)]
// Tiled, matmul-style 3D convolution: each thread accumulates a 4x4 register tile
// (4 kernels x 4 output pixels); the threadgroup stages CACHE_DEPTH channels of
// weights and input pixels in LDS per pass and loops over the kernel taps.
void FUNC_NAME(KERNEL_NAME, SUFFIX, BLOCK_SIZE)(uint3 groupID : SV_GroupID, uint3 groupThreadID : SV_GroupThreadID, uint threadIndex : SV_GroupIndex)
{
//This kernel assume the following:
//Input:
// C % CACHE_DEPTH==0 <-- only if STRICT_CHANNELS==1
//Kernel:
// K%32==0 <-- only if LAX_KERNEL=0
//DISPATCH ARGS(K.kernelCount, O.width * O.height * O.depth, O.batch);
TENSOR_SHARED2_ARGS4_8D(X, K, B, WBK, O);
#define LDS_ CACHE_NAME(KERNEL_NAME, SUFFIX, BLOCK_SIZE, LDS)
#define X_OFFSET 0
#define W_OFFSET CACHE_DEPTH*PIXEL_PER_TG
//Per thread group (scalar registers)
uint tg_NumChannels = X.channels;
uint tg_DepthX = X.depth;
uint tg_WidthX = X.width;
uint tg_HeightX = X.height;
uint tg_DepthO = O.depth;
uint tg_WidthO = O.width;
uint tg_HeightO = O.height;
uint tg_NumKernels = K.channels;
uint tg_NumInputPixels = tg_DepthX*tg_WidthX*tg_HeightX;
uint tg_NumOuputPixels = tg_DepthO*tg_WidthO*tg_HeightO;
uint tg_KernelSpatialStride = tg_NumKernels*tg_NumChannels; // flat distance between two spatial taps in K (DHWCK layout)
uint tg_KernelBaseId = groupID.x * KERNEL_PER_TG; // first output channel handled by this group
uint tg_OutputPixelBaseId = groupID.y * PIXEL_PER_TG; // first output pixel handled by this group
uint tg_BatchReadOffset = groupID.z * tg_NumChannels * tg_NumInputPixels; // groupID.z = batch index
uint tg_BatchWriteOffset = groupID.z * tg_NumKernels * tg_NumOuputPixels;
uint tg_kernelSpatialOffset = 0;
//4x4 block, 4 kernels by 4 pixels
float dstA[BLOCK_SIZE*BLOCK_SIZE];
//Load Bias [K] into dstA [Kernels, Pixels]
uint tg_kId;
uint tg_pId;
uint maxBiasIndex = O.channels - 1; // clamp so over-allocated threads read a valid bias
[unroll] for (tg_pId = 0; tg_pId < BLOCK_SIZE; ++tg_pId)
[unroll] for (tg_kId = 0; tg_kId < BLOCK_SIZE; ++tg_kId)
dstA[tg_pId*BLOCK_SIZE+tg_kId] = B.FastGet(min(maxBiasIndex,tg_KernelBaseId + groupThreadID.x * BLOCK_SIZE + tg_kId));
//Looping over kernel spatially
for (uint tg_Dd = 0; tg_Dd < K.GetKernelSpatialDepth(); ++tg_Dd)
for (uint tg_Dy = 0; tg_Dy < K.GetKernelHeight(); ++tg_Dy)
for (uint tg_Dx = 0; tg_Dx < K.GetKernelWidth(); ++tg_Dx)
{
//Looping over channels, convolving CACHE_DEPTH of them at a time.
for (uint tg_ChannelOffset = 0; tg_ChannelOffset < tg_NumChannels; tg_ChannelOffset += CACHE_DEPTH)
{
//Load from DDR to LDS: Threadgroup need 32 weight + 64 pixels per CACHE_DEPTH = 96 float, but we have 128 threads.
//--> Load 4 channels at a time (3 loads per threads, 1 kernel and 2 pixels) consequence is CACHE_DEPTH must be a multiple of 4.
//A threadgroup (128 Threads) contains 4 half-warps of 32 threads.
// half-warps 0 - threadId [00-31] --> load Kernels [00-31] channel 0 + Pixels [00,31] channel 0 and 2
// half-warps 1 - threadId [32-63] --> load Kernels [00-31] channel 1 + Pixels [32,63] channel 1 and 3
// half-warps 2 - threadId [64-95] --> load Kernels [00-31] channel 2 + Pixels [00,31] channel 0 and 2
// half-warps 3 - threadId [96-127] --> load Kernels [00-31] channel 3 + Pixels [32,63] channel 1 and 3
uint warpThreadId = threadIndex % 64;
uint warpId = threadIndex / 64;
uint halfWarpThreadId = threadIndex % 32;
uint halfWarpId = threadIndex / 32;
[unroll] for (uint tg_CacheLoadIdx = 0; tg_CacheLoadIdx < CACHE_DEPTH; tg_CacheLoadIdx+=4)//TODO verify unrolling actually happens
{
//Kernels (1 per thread)
//K stored as DHWCK, threadgroup is loading 4*32 kernels at a time to LDS.
//DHW from tg_kernelSpatialOffset,
//C from tg_ChannelOffset+tg_CacheLoadIdx+halfWarpId([0,3])
//K from tg_KernelBaseId+halfWarpThreadId([0,31])
uint kernelReadOffset = tg_kernelSpatialOffset + tg_NumKernels*(tg_ChannelOffset+tg_CacheLoadIdx+halfWarpId) + tg_KernelBaseId + halfWarpThreadId;
#if !STRICT_CHANNELS || LAX_KERNEL
// Reads may run past the end of K when C or K are not tile-aligned; clamp to the last element.
kernelReadOffset = min(kernelReadOffset, K.GetLength5D()-1);
#endif
LDS_[W_OFFSET+tg_CacheLoadIdx*KERNEL_PER_TG+threadIndex] = K.FastGet(kernelReadOffset);
//Pixels (two of them per thread)
//threadgroup is loading 4*64 kernels at a time to LDS.
int outputPixelBaseId = tg_OutputPixelBaseId + warpThreadId;
int3 outputPixelCoords;
outputPixelCoords.x = outputPixelBaseId % tg_WidthO;//width
outputPixelCoords.y = (outputPixelBaseId / tg_WidthO) % tg_HeightO;//height
outputPixelCoords.z = outputPixelBaseId / (tg_WidthO * tg_HeightO);//depth
// Signed coordinates: padding can take the position below zero.
int3 inputPixelCoords = outputPixelCoords * _Stride.xyz - _Pad.xyz + int3(tg_Dx, tg_Dy, tg_Dd);
bool inputPixelMask = all( (inputPixelCoords >= 0) && (inputPixelCoords < float3(tg_WidthX, tg_HeightX, tg_DepthX)) );
int inputPixelId = inputPixelCoords.z * (tg_WidthX*tg_HeightX) + inputPixelCoords.y * tg_WidthX + inputPixelCoords.x;
uint inputChannelId1 = tg_ChannelOffset + tg_CacheLoadIdx + warpId;
uint inputChannelId2 = inputChannelId1 + 2;
bool inputChannelMask1 = inputChannelId1 < tg_NumChannels;
bool inputChannelMask2 = inputChannelId2 < tg_NumChannels;
#if STRICT_CHANNELS
// C % CACHE_DEPTH == 0 is guaranteed, so the channel bound checks can be elided.
inputChannelMask1 = true;
inputChannelMask2 = true;
#endif
#if CHANNELS_FIRST
uint pixelReadOffset1 = tg_NumInputPixels * inputChannelId1 + inputPixelId + tg_BatchReadOffset;
uint pixelReadOffset2 = tg_NumInputPixels * inputChannelId2 + inputPixelId + tg_BatchReadOffset;
#else
uint pixelReadOffset1 = tg_NumChannels * inputPixelId + inputChannelId1 + tg_BatchReadOffset;
uint pixelReadOffset2 = tg_NumChannels * inputPixelId + inputChannelId2 + tg_BatchReadOffset;
#endif
// MaskedGet returns 0 for out-of-image (padding) or out-of-range-channel reads.
LDS_[X_OFFSET+tg_CacheLoadIdx*PIXEL_PER_TG+threadIndex] = X.MaskedGet(inputPixelMask && inputChannelMask1, pixelReadOffset1);
LDS_[X_OFFSET+tg_CacheLoadIdx*PIXEL_PER_TG+128+threadIndex] = X.MaskedGet(inputPixelMask && inputChannelMask2, pixelReadOffset2);
}
GroupMemoryBarrierWithGroupSync();
//Inner loop
//TODO get rid of bank conflicts.
// LDS layout per region: [channel][pixel or kernel] — hence the per-channel stride below.
uint ptrX = groupThreadID.y*BLOCK_SIZE + X_OFFSET;
uint ptrW = groupThreadID.x*BLOCK_SIZE + W_OFFSET;
for (uint tg_CacheExecuteIdx = 0; tg_CacheExecuteIdx < CACHE_DEPTH; ++tg_CacheExecuteIdx)
{
//Load LDS -> registers
float colOfX[BLOCK_SIZE];
float rowOfW[BLOCK_SIZE];
uint tg_q;
[unroll] for (tg_q = 0; tg_q < BLOCK_SIZE; ++tg_q)
colOfX[tg_q] = LDS_[ptrX + tg_q];
[unroll] for (tg_q = 0; tg_q < BLOCK_SIZE; ++tg_q)
rowOfW[tg_q] = LDS_[ptrW + tg_q];
ptrX += PIXEL_PER_TG;
ptrW += KERNEL_PER_TG;
//Mads 4 pixels by 4 kernels matmul style --> 16 mads
[unroll] for (uint tg_X = 0; tg_X < BLOCK_SIZE; ++tg_X)
[unroll] for (uint tg_W = 0; tg_W < BLOCK_SIZE; ++tg_W)
dstA[tg_X*BLOCK_SIZE+tg_W] = ffma(colOfX[tg_X], rowOfW[tg_W], dstA[tg_X*BLOCK_SIZE+tg_W]);
}
GroupMemoryBarrierWithGroupSync();
}
tg_kernelSpatialOffset += tg_KernelSpatialStride;
}
//-------------------------------
//store registers to DDR
//-------------------------------
//B does not require an offset as size == 1
//C from tg_KernelBaseId, groupThreadID.x and tg_kId
//HW from tg_OutputPixelBaseId, groupThreadID.y and tg_pId
[unroll] for (tg_kId = 0; tg_kId < BLOCK_SIZE; ++tg_kId)
[unroll] for (tg_pId = 0; tg_pId < BLOCK_SIZE; ++tg_pId)
{
uint writeChannelId = tg_KernelBaseId + groupThreadID.x * BLOCK_SIZE + tg_kId;
uint writePixelId = tg_OutputPixelBaseId + groupThreadID.y * BLOCK_SIZE + tg_pId;
float writeValue = dstA[tg_pId*BLOCK_SIZE+tg_kId];
#if CHANNELS_FIRST
uint writeIndex = O.depth * O.width * O.height * writeChannelId + writePixelId + tg_BatchWriteOffset;
#else
uint writeIndex = tg_NumKernels * writePixelId + writeChannelId + tg_BatchWriteOffset;
#endif
#if LAX_KERNEL
// K is not guaranteed to be a multiple of 32: guard the channel dimension too.
bool canWriteChannel = (writeChannelId < tg_NumKernels);
#else
bool canWriteChannel = true;
#endif
if ((writePixelId < tg_NumOuputPixels) && canWriteChannel)
O.FastSetWithActivation(writeIndex, writeValue);
}
#undef X_OFFSET
#undef W_OFFSET
#undef LDS_
}
#undef CACHE_DEPTH
#undef BLOCK_SIZE
#undef KERNEL_PER_TG
#undef PIXEL_PER_TG

View File

@@ -0,0 +1,8 @@
fileFormatVersion: 2
guid: 5da0dcf3215520c41bdb8342e88aa56e
ComputeShaderImporter:
externalObjects: {}
currentAPIMask: 4
userData:
assetBundleName:
assetBundleVariant:

View File

@@ -0,0 +1,99 @@
/// DEBUG ONLY:
/// `KERNEL_ASSERTS` allows tracking out-of-bound reads/writes and assertions
/// in all kernels, with the exception of those where FORCE_NO_DEBUG is defined.
/// To debug only a few kernels one can instead define FORCE_DEBUG per kernel.
/// To debug kernels, be sure to also set ComputeDebugUtils.debugKernels to true in BarracudaComputeDebugUtils.cs.
/// Production code should not define this as it will significantly degrade performance.
/// Defining those requires Shader Model 5.0 and does not work on Metal (Metal does not support GetDimensions on buffers),
/// aka `#pragma target 5.0`, see https://docs.unity3d.com/Manual/SL-ShaderCompileTargets.html.
#include "KernelDebug.cginc"
#if !defined(KERNEL_ASSERTS)
// KernelDebug.cginc allows enabling kernel debugging on Yamato (CI). Uncomment the line below to force it at dev time.
// #define KERNEL_ASSERTS
#endif
//Keep in sync with BarracudaComputeDebugUtils.cs enum ComputeDebugUtils.KernelAssertContext
// Context codes identifying which kind of access detected the failure.
#define KERNEL_ASSERT_CONTEXT_READONLY_READ 0
#define KERNEL_ASSERT_CONTEXT_READWRITE_READ 1
#define KERNEL_ASSERT_CONTEXT_READWRITE_WRITE 2
#define KERNEL_ASSERT_CONTEXT_SHARED_READ 3
#define KERNEL_ASSERT_CONTEXT_ASSERTION 4
#define KERNEL_ASSERT_CONTEXT_ASSERTION_WITH_VALUE 5
//Keep in sync with BarracudaComputeDebugUtils.cs enum ComputeDebugUtils.KernelAssertInfo
// One failure report; only element 0 of the buffer is ever written (first failure wins).
struct KernelAssertInfo
{
uint lockValue; // incremented atomically on every failure; a pre-increment value of 0 marks the thread whose report is kept
//context
uint lineNumber; // __LINE__ of the failing access/assert
uint context; // one of the KERNEL_ASSERT_CONTEXT_* codes above
//specific to read/write OOB detection
uint index;
uint bufferSize;
//specific to assertion with value
uint debugValue;
//padding
uint padding0;
uint padding1; // pads the struct to 8 uints (32 bytes)
};
#if (defined(KERNEL_ASSERTS) && !defined(FORCE_NO_DEBUG)) || defined(FORCE_DEBUG)
// Single-slot report buffer consumed by the CPU side (BarracudaComputeDebugUtils.cs).
RWStructuredBuffer<KernelAssertInfo> KernelAssertInfoBuffer;
// Record a failure into slot 0 of the debug buffer.
// The interlocked increment on lockValue doubles as a "first writer wins" gate:
// only the thread that observes a pre-increment value of 0 fills in the report,
// so the first failure of the dispatch is the one that gets logged.
void LogAssertion(uint index, uint bufferSize, uint debugValue, uint lineNumber, uint context)
{
uint previousLockValue;
InterlockedAdd(KernelAssertInfoBuffer[0].lockValue, 1, previousLockValue);
if (previousLockValue != 0)
return;
KernelAssertInfoBuffer[0].lineNumber = lineNumber;
KernelAssertInfoBuffer[0].context = context;
KernelAssertInfoBuffer[0].index = index;
KernelAssertInfoBuffer[0].bufferSize = bufferSize;
KernelAssertInfoBuffer[0].debugValue = debugValue;
}
// Validate a flat buffer index before an access.
// Returns `index` unchanged when it is in bounds; otherwise logs the violation
// (first failure of the dispatch wins, see LogAssertion) and returns 0.
// Always returning a valid index avoids GPU crashes so the CPU side gets a
// chance to catch and report the error.
uint GetSafeTensorIndex(uint index, uint bufferSize, uint lineNumber, uint context)
{
// `index` is unsigned, so only the upper bound needs checking
// (the previous `index >= 0` term was a tautology on uint).
if (index < bufferSize)
return index;
LogAssertion(index, bufferSize, 0, lineNumber, context);
return 0;
}
// Log an assertion failure (no index/value payload) when `isOk` is false; no-op otherwise.
void KernelAssert(bool isOk, uint lineNumber)
{
if (!isOk)
LogAssertion(0, 0, 0, lineNumber, KERNEL_ASSERT_CONTEXT_ASSERTION);
}
// Same as KernelAssert but also records `value` in the report's debugValue field.
void KernelAssertWithDebugValue(bool isOk, uint lineNumber, uint value)
{
if (!isOk)
LogAssertion(0, 0, value, lineNumber, KERNEL_ASSERT_CONTEXT_ASSERTION_WITH_VALUE);
}
// Bounds-checked tensor access used by Tensor.cginc accessors.
// Expects a buffer named `data` to be in scope at the expansion site.
// NOTE: ASSERT_TENSOR_INDEX declares locals (dataNumStructs, dataStride, safeIndex),
// so TENSOR_READ / TENSOR_WRITE can each only appear once per scope.
// GetDimensions requires Shader Model 5.0 and is unsupported on Metal (see header note).
#define ASSERT_TENSOR_INDEX(index, context) \
uint dataNumStructs, dataStride; \
data.GetDimensions(dataNumStructs, dataStride); \
uint safeIndex = GetSafeTensorIndex(index, dataNumStructs, __LINE__, context);
#define TENSOR_READ(varName, index, context) ASSERT_TENSOR_INDEX(index, context); varName = data[safeIndex]
#define TENSOR_WRITE(varName, index, context) ASSERT_TENSOR_INDEX(index, context); data[safeIndex] = varName
#define KERNEL_ASSERT(condition) KernelAssert(condition, __LINE__)
#define KERNEL_ASSERT_WITH_VALUE(condition, value) KernelAssertWithDebugValue(condition, __LINE__, value)
#else
// Debugging disabled: direct, unchecked accesses and no-op asserts.
#define TENSOR_READ(varName, index, context) varName = data[index]
#define TENSOR_WRITE(varName, index, context) data[index] = varName
#define KERNEL_ASSERT(condition)
#define KERNEL_ASSERT_WITH_VALUE(condition, value)
#endif

View File

@@ -0,0 +1,9 @@
fileFormatVersion: 2
guid: a236e93868e2f6349b7a40e7552915fd
ShaderImporter:
externalObjects: {}
defaultTextures: []
nonModifiableTextures: []
userData:
assetBundleName:
assetBundleVariant:

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,9 @@
fileFormatVersion: 2
guid: 6b08c0ac202ad41deb8881132b21894c
timeCreated: 1507457322
licenseType: Pro
ComputeShaderImporter:
currentAPIMask: 196608
userData:
assetBundleName:
assetBundleVariant:

View File

@@ -0,0 +1,429 @@
#pragma kernel Dense3_T8x8_R8x8_NHWC BLOCK_SIZE=8 KERNEL_PER_TG=64 CHANNELS_FIRST=0
#pragma kernel Dense3_T8x8_R8x8_NCHW BLOCK_SIZE=8 KERNEL_PER_TG=64 CHANNELS_FIRST=1
#pragma kernel Dense3_T8x16_R4x4_NHWC BLOCK_SIZE=4 KERNEL_PER_TG=32 CHANNELS_FIRST=0
#pragma kernel Dense3_T8x16_R4x4_NCHW BLOCK_SIZE=4 KERNEL_PER_TG=32 CHANNELS_FIRST=1
#pragma kernel Dense3_L1Cached64_NHWC CHANNELS_FIRST=0
#pragma kernel Dense3_L1Cached64_NCHW CHANNELS_FIRST=1
#include "Tensor.cginc"
// Tensor bindings for the batched dense (matrix multiply) kernels:
//   X = input, W = weights, B = bias, O = output (read-write).
//   WBK is the combined buffer used by TENSOR_SHARED2_ARGS4 (see Tensor.cginc).
TENSOR_DECL(X)
TENSOR_DECL(W)
TENSOR_DECL(B)
TENSOR_DECL(WBK)
TENSOR_DECL_RW(O)
// Fused multiply-add a*b+c via dot() — presumably to nudge the compiler into a
// single mad. Note some kernels below call fastfma instead (defined in Tensor.cginc).
float ffma(float a, float b, float c) { return dot(float2(a, c), float2(b, 1)); } //return a*b+c;} //fastfma(a,b,c); }
// Name-mangling helpers encoding the data layout into kernel/LDS-cache identifiers.
#if CHANNELS_FIRST
#define FUNC_NAME_CALL(KERNEL, SIZE) KERNEL##SIZE##x##SIZE##_NCHW
#define CACHE_NAME_CALL(KERNEL, SIZE, TENSOR) KERNEL##SIZE##x##SIZE##_Cache_##TENSOR##_NCHW
#else
#define FUNC_NAME_CALL(KERNEL, SIZE) KERNEL##SIZE##x##SIZE##_NHWC
#define CACHE_NAME_CALL(KERNEL, SIZE, TENSOR) KERNEL##SIZE##x##SIZE##_Cache_##TENSOR##_NHWC
#endif
// Extra indirection so macro arguments are expanded before token pasting.
#define FUNC_NAME(KERNEL, SIZE) FUNC_NAME_CALL(KERNEL, SIZE)
#define CACHE_NAME(KERNEL, SIZE, TENSOR) CACHE_NAME_CALL(KERNEL, SIZE, TENSOR)
#if BLOCK_SIZE == 8
#if KERNEL_PER_TG == 64
// Batched dense / matmul: 8x8 threadgroup, each thread owns an 8x8 register tile
// (64 outputs). Per pass, CACHE_DEPTH=8 slices of X and W are staged in LDS.
#define KERNEL_NAME Dense3_T8x8_R
#define CACHE_WIDTH_W_PAD 1
#define CACHE_WIDTH_X 64
#define CACHE_WIDTH_W (64+CACHE_WIDTH_W_PAD)
#define CACHE_DEPTH 8
// 1039 = max(CACHE_WIDTH_X*CACHE_DEPTH + CACHE_WIDTH_W*CACHE_DEPTH, 65*15+64) = max(1032, 1039);
// the second term comes from the NHWC transpose-on-write staging at the end of the kernel.
groupshared float CACHE_NAME(KERNEL_NAME, BLOCK_SIZE, LDS)[1039];
[numthreads(8, 8, 1)]
void FUNC_NAME(KERNEL_NAME, BLOCK_SIZE)(uint3 groupID : SV_GroupID, uint threadIndex : SV_GroupIndex, uint3 dispatchThreadID : SV_DispatchThreadID)
{
TENSOR_SHARED2_ARGS4(X, W, B, WBK, O);
uint ti = threadIndex;
uint bx = groupID.x * 8 * BLOCK_SIZE; // first W column (output channel) of this group
uint by = groupID.y * 8 * BLOCK_SIZE; // first X row of this group
uint n = X.width; // reduction dimension length
uint strideX = X.channels;
uint strideW = W.GetFlatWidth();
uint lengthW = W.GetLength() - 1; // last valid flat index into W, used to clamp OOB loads
uint dzX = groupID.z * n * strideX; // per-batch offset into X (groupID.z = batch)
uint dzO = groupID.z * strideW * strideX; // per-batch offset into O
#define LDS_ CACHE_NAME(KERNEL_NAME, BLOCK_SIZE, LDS)
#define X_OFFSET 0
#define W_OFFSET CACHE_DEPTH*8*BLOCK_SIZE
float dstO[BLOCK_SIZE*BLOCK_SIZE];
uint tg_X = 0;
uint tg_W = 0;
// Initialize the whole register tile with the bias of its output channel.
[unroll] for (tg_W = 0; tg_W < BLOCK_SIZE; ++tg_W)
dstO[0*BLOCK_SIZE + tg_W] = B.FastGet(min(B.GetLength()-1, bx + ((ti & 7) << 3) + tg_W));
[unroll] for (tg_X = 1; tg_X < BLOCK_SIZE; ++tg_X)
[unroll] for (tg_W = 0; tg_W < BLOCK_SIZE; ++tg_W)
dstO[tg_X*BLOCK_SIZE + tg_W] = dstO[0*BLOCK_SIZE + tg_W];
for (uint i = 0; i < n; i += CACHE_DEPTH)
{
#if CHANNELS_FIRST
//LDS_[X_OFFSET + ti + 8 * 8 * [0..7]] = X.FastGet((i + [0..7]) + X.width * (by + ti));
LDS_[X_OFFSET + ti + CACHE_WIDTH_X * 0] = X.MaskedGet(((by + ti) < strideX) && ((i + 0) < X.width), dzX + (i + 0) + X.width * (by + ti));
LDS_[X_OFFSET + ti + CACHE_WIDTH_X * 1] = X.MaskedGet(((by + ti) < strideX) && ((i + 1) < X.width), dzX + (i + 1) + X.width * (by + ti));
LDS_[X_OFFSET + ti + CACHE_WIDTH_X * 2] = X.MaskedGet(((by + ti) < strideX) && ((i + 2) < X.width), dzX + (i + 2) + X.width * (by + ti));
LDS_[X_OFFSET + ti + CACHE_WIDTH_X * 3] = X.MaskedGet(((by + ti) < strideX) && ((i + 3) < X.width), dzX + (i + 3) + X.width * (by + ti));
LDS_[X_OFFSET + ti + CACHE_WIDTH_X * 4] = X.MaskedGet(((by + ti) < strideX) && ((i + 4) < X.width), dzX + (i + 4) + X.width * (by + ti));
LDS_[X_OFFSET + ti + CACHE_WIDTH_X * 5] = X.MaskedGet(((by + ti) < strideX) && ((i + 5) < X.width), dzX + (i + 5) + X.width * (by + ti));
LDS_[X_OFFSET + ti + CACHE_WIDTH_X * 6] = X.MaskedGet(((by + ti) < strideX) && ((i + 6) < X.width), dzX + (i + 6) + X.width * (by + ti));
LDS_[X_OFFSET + ti + CACHE_WIDTH_X * 7] = X.MaskedGet(((by + ti) < strideX) && ((i + 7) < X.width), dzX + (i + 7) + X.width * (by + ti));
#else
//LDS_[X_OFFSET + ti + 8 * 8 * [0..7]] = X.FastGet(X.channels * (i + [0..7]) + by + ti);
LDS_[X_OFFSET + ti + CACHE_WIDTH_X * 0] = X.MaskedGet(((by + ti) < strideX) && ((i + 0) < X.width), dzX + X.channels * (i + 0) + by + ti);
LDS_[X_OFFSET + ti + CACHE_WIDTH_X * 1] = X.MaskedGet(((by + ti) < strideX) && ((i + 1) < X.width), dzX + X.channels * (i + 1) + by + ti);
LDS_[X_OFFSET + ti + CACHE_WIDTH_X * 2] = X.MaskedGet(((by + ti) < strideX) && ((i + 2) < X.width), dzX + X.channels * (i + 2) + by + ti);
LDS_[X_OFFSET + ti + CACHE_WIDTH_X * 3] = X.MaskedGet(((by + ti) < strideX) && ((i + 3) < X.width), dzX + X.channels * (i + 3) + by + ti);
LDS_[X_OFFSET + ti + CACHE_WIDTH_X * 4] = X.MaskedGet(((by + ti) < strideX) && ((i + 4) < X.width), dzX + X.channels * (i + 4) + by + ti);
LDS_[X_OFFSET + ti + CACHE_WIDTH_X * 5] = X.MaskedGet(((by + ti) < strideX) && ((i + 5) < X.width), dzX + X.channels * (i + 5) + by + ti);
LDS_[X_OFFSET + ti + CACHE_WIDTH_X * 6] = X.MaskedGet(((by + ti) < strideX) && ((i + 6) < X.width), dzX + X.channels * (i + 6) + by + ti);
LDS_[X_OFFSET + ti + CACHE_WIDTH_X * 7] = X.MaskedGet(((by + ti) < strideX) && ((i + 7) < X.width), dzX + X.channels * (i + 7) + by + ti);
#endif
//LDS_[W_OFFSET + ti + writeIndex + (8 * 8 + 1) * [0..7]] = W.FastGet(strideB * (i + [0..7]) + bx + ti);
uint WWriteIndex = (ti & 0x20) >> 5;// (ti & 32) ? CACHE_WIDTH_W_PAD : 0; one-float pad halfway through each row to reduce bank conflicts
LDS_[W_OFFSET + (ti + WWriteIndex) + 0 * CACHE_WIDTH_W] = W.FastGet(min(strideW * (i + 0) + bx + ti, lengthW));
LDS_[W_OFFSET + (ti + WWriteIndex) + 1 * CACHE_WIDTH_W] = W.FastGet(min(strideW * (i + 1) + bx + ti, lengthW));
LDS_[W_OFFSET + (ti + WWriteIndex) + 2 * CACHE_WIDTH_W] = W.FastGet(min(strideW * (i + 2) + bx + ti, lengthW));
LDS_[W_OFFSET + (ti + WWriteIndex) + 3 * CACHE_WIDTH_W] = W.FastGet(min(strideW * (i + 3) + bx + ti, lengthW));
LDS_[W_OFFSET + (ti + WWriteIndex) + 4 * CACHE_WIDTH_W] = W.FastGet(min(strideW * (i + 4) + bx + ti, lengthW));
LDS_[W_OFFSET + (ti + WWriteIndex) + 5 * CACHE_WIDTH_W] = W.FastGet(min(strideW * (i + 5) + bx + ti, lengthW));
LDS_[W_OFFSET + (ti + WWriteIndex) + 6 * CACHE_WIDTH_W] = W.FastGet(min(strideW * (i + 6) + bx + ti, lengthW));
LDS_[W_OFFSET + (ti + WWriteIndex) + 7 * CACHE_WIDTH_W] = W.FastGet(min(strideW * (i + 7) + bx + ti, lengthW));
GroupMemoryBarrierWithGroupSync();
//uint ptrX = X_OFFSET + (ti/8) * 8;
//uint ptrW = W_OFFSET + (ti%8) * 8 + readIndex;
uint ptrX = X_OFFSET + (ti & 0x78);
uint ptrW = ((ti & 7) << 3);
ptrW += (ti & 0x4) >> 2; // ptrW += (ptrW > 31) ? CACHE_WIDTH_W_PAD : 0;
ptrW += W_OFFSET;
float srcX[BLOCK_SIZE];
float srcW[BLOCK_SIZE];
[unroll] for (uint tg_CacheExecuteIdx = 0; tg_CacheExecuteIdx < CACHE_DEPTH; tg_CacheExecuteIdx++)
{
srcX[0] = LDS_[ptrX | 0];
srcX[1] = LDS_[ptrX | 1];
srcX[2] = LDS_[ptrX | 2];
srcX[3] = LDS_[ptrX | 3];
srcX[4] = LDS_[ptrX | 4];
srcX[5] = LDS_[ptrX | 5];
srcX[6] = LDS_[ptrX | 6];
srcX[7] = LDS_[ptrX | 7];
srcW[0] = LDS_[ptrW + 0];
srcW[1] = LDS_[ptrW + 1];
srcW[2] = LDS_[ptrW + 2];
srcW[3] = LDS_[ptrW + 3];
srcW[4] = LDS_[ptrW + 4];
srcW[5] = LDS_[ptrW + 5];
srcW[6] = LDS_[ptrW + 6];
srcW[7] = LDS_[ptrW + 7];
ptrX += CACHE_WIDTH_X;
ptrW += CACHE_WIDTH_W;
// 8x8 outer product accumulated into the register tile (64 mads).
[unroll] for (tg_X = 0; tg_X < BLOCK_SIZE; ++tg_X)
[unroll] for (tg_W = 0; tg_W < BLOCK_SIZE; ++tg_W)
dstO[tg_X*BLOCK_SIZE + tg_W] = ffma(srcX[tg_X], srcW[tg_W], dstO[tg_X*BLOCK_SIZE + tg_W]);
}
GroupMemoryBarrierWithGroupSync();
}
#if CHANNELS_FIRST
[unroll] for (tg_X = 0; tg_X < BLOCK_SIZE; ++tg_X)
[unroll] for (tg_W = 0; tg_W < BLOCK_SIZE; ++tg_W)
{
uint writeXId = ((bx + 8 * (ti % 8)) + tg_X);
uint writeWId = ((by + 8 * (ti / 8)) + tg_W);
if (writeWId < O.channels && writeXId < O.width)
O.FastSet(dzO + writeXId + O.width * writeWId, dstO[BLOCK_SIZE * tg_W + tg_X]);
}
#else
// NHWC path: round-trip the register tile through LDS two rows at a time so the
// final DDR writes are coalesced along the channel dimension.
[unroll] for (uint tg_XOffset = 0; tg_XOffset < BLOCK_SIZE; tg_XOffset += 2)
{
[unroll] for (tg_X = 0; tg_X < 2; ++tg_X)
[unroll] for (tg_W = 0; tg_W < BLOCK_SIZE; ++tg_W)
{
//To avoid bank conflicts, store into 16 groups (stride 65) of 64 values each ([8 pixels, 8 kernels] per group).
uint ldsOffsetOfGroup = 65 * (tg_X*BLOCK_SIZE + tg_W);//65 * ([0,1]*8+[0,7]) = [0,975]
LDS_[ldsOffsetOfGroup + ti] = dstO[BLOCK_SIZE * tg_W + (tg_XOffset + tg_X)];
}
GroupMemoryBarrierWithGroupSync();
[unroll] for (tg_X = 0; tg_X < 16; ++tg_X)
{
// (((tg_A % 4) * 8) + (ti % 8)) * CACHE_WIDTH_A
uint ldsOffsetOfGroup = 65 * (((tg_X & 1) << 3) + (ti & 7));//65 * ([0,1]*8+[0,7]) = [0,975]
// (ti / 8) * 8 + (tg_A / 4)
uint ldsOffsetInGroup = (ti & 0x78) + (tg_X >> 1);//[0,7]*8+[0,7] = [0,63]
//load from LDS and store to DDR
uint readIndex = ldsOffsetOfGroup + ldsOffsetInGroup;//[0,1038]
// bx + tg_!%4 + (tgA/4)*8 + tg_AOffset
uint writeXId = bx + (tg_X & 1) + ((tg_X >> 1) << 3) + tg_XOffset;
uint writeIndex = dzO + O.channels * writeXId + (by + ti);
if ((by + ti) < O.channels && writeXId < O.width)
O.FastSet(writeIndex, LDS_[readIndex]);
}
}
#endif
}
#endif
#undef CACHE_DEPTH
#undef KERNEL_NAME
#elif BLOCK_SIZE == 4
#if KERNEL_PER_TG == 32
//TODO optimize
// Batched dense / matmul: 8x16 threadgroup, each thread owns a 4x4 register tile.
// Thread (x) spans 4 consecutive W columns (output channels), thread (y) spans
// 4 consecutive X rows; CACHE_DEPTH=8 reduction slices are staged in LDS per pass.
#define KERNEL_NAME Dense3_T8x16_R
#define CACHE_DEPTH 8
groupshared float CACHE_NAME(KERNEL_NAME, BLOCK_SIZE, LDS)[16*8*4 + 8*8*4]; // X cache (16*8*4 = 512 floats) + W cache (8*8*4 = 256 floats)
[numthreads(8, 16, 1)]
void FUNC_NAME(KERNEL_NAME, BLOCK_SIZE)(uint3 groupID : SV_GroupID, uint threadIndex : SV_GroupIndex)
{
TENSOR_SHARED2_ARGS4(X, W, B, WBK, O);
#define LDS_ CACHE_NAME(KERNEL_NAME, BLOCK_SIZE, LDS)
uint x = 8 * groupID.x + (threadIndex % 8); // this thread's base column group (output channels 4x..4x+3)
uint y = 16 * groupID.y + (threadIndex / 8); // this thread's base row group (X rows 4y..4y+3)
uint n = X.width; // reduction dimension length
uint strideX = X.channels;
uint strideW = W.GetFlatWidth();
uint dzX = groupID.z * n * strideX; // per-batch offset into X (groupID.z = batch)
uint dzO = groupID.z * strideW * strideX; // per-batch offset into O
float dstO[BLOCK_SIZE*BLOCK_SIZE];
// Initialize the register tile with the bias of each of the 4 output channels
// (clamped so over-allocated threads read a valid element).
dstO[0 ] = B.FastGet(min(4 * x + 0, strideW - 1));
dstO[1 ] = B.FastGet(min(4 * x + 0, strideW - 1));
dstO[2 ] = B.FastGet(min(4 * x + 0, strideW - 1));
dstO[3 ] = B.FastGet(min(4 * x + 0, strideW - 1));
dstO[4 ] = B.FastGet(min(4 * x + 1, strideW - 1));
dstO[5 ] = B.FastGet(min(4 * x + 1, strideW - 1));
dstO[6 ] = B.FastGet(min(4 * x + 1, strideW - 1));
dstO[7 ] = B.FastGet(min(4 * x + 1, strideW - 1));
dstO[8 ] = B.FastGet(min(4 * x + 2, strideW - 1));
dstO[9 ] = B.FastGet(min(4 * x + 2, strideW - 1));
dstO[10] = B.FastGet(min(4 * x + 2, strideW - 1));
dstO[11] = B.FastGet(min(4 * x + 2, strideW - 1));
dstO[12] = B.FastGet(min(4 * x + 3, strideW - 1));
dstO[13] = B.FastGet(min(4 * x + 3, strideW - 1));
dstO[14] = B.FastGet(min(4 * x + 3, strideW - 1));
dstO[15] = B.FastGet(min(4 * x + 3, strideW - 1));
//float acc = B.FastGet(min(x, strideW - 1));
// loop over X columns (flatWidth) and W rows (height) in CACHESIZE steps
for (uint i = 0; i < n; i += CACHE_DEPTH)
{
// Cache X
// coalescent reads
#if CHANNELS_FIRST
LDS_[(threadIndex / 8) * 8 + (threadIndex % 8) + 16 * 8 * 0] = X.MaskedGet((4 * y + 0 < X.channels) && (i + (threadIndex % 8)) < X.width, dzX + (i + (threadIndex % 8)) + X.width * (4 * y + 0));
LDS_[(threadIndex / 8) * 8 + (threadIndex % 8) + 16 * 8 * 1] = X.MaskedGet((4 * y + 1 < X.channels) && (i + (threadIndex % 8)) < X.width, dzX + (i + (threadIndex % 8)) + X.width * (4 * y + 1));
LDS_[(threadIndex / 8) * 8 + (threadIndex % 8) + 16 * 8 * 2] = X.MaskedGet((4 * y + 2 < X.channels) && (i + (threadIndex % 8)) < X.width, dzX + (i + (threadIndex % 8)) + X.width * (4 * y + 2));
LDS_[(threadIndex / 8) * 8 + (threadIndex % 8) + 16 * 8 * 3] = X.MaskedGet((4 * y + 3 < X.channels) && (i + (threadIndex % 8)) < X.width, dzX + (i + (threadIndex % 8)) + X.width * (4 * y + 3));
#else
LDS_[(threadIndex / 8)*8 + (threadIndex % 8) + 16*8 * 0] = X.MaskedGet((4 * y + 0 < X.channels) && (i + (threadIndex % 8)) < X.width, dzX + X.channels*(i + (threadIndex % 8)) + 4 * y + 0);
LDS_[(threadIndex / 8)*8 + (threadIndex % 8) + 16*8 * 1] = X.MaskedGet((4 * y + 1 < X.channels) && (i + (threadIndex % 8)) < X.width, dzX + X.channels*(i + (threadIndex % 8)) + 4 * y + 1);
LDS_[(threadIndex / 8)*8 + (threadIndex % 8) + 16*8 * 2] = X.MaskedGet((4 * y + 2 < X.channels) && (i + (threadIndex % 8)) < X.width, dzX + X.channels*(i + (threadIndex % 8)) + 4 * y + 2);
LDS_[(threadIndex / 8)*8 + (threadIndex % 8) + 16*8 * 3] = X.MaskedGet((4 * y + 3 < X.channels) && (i + (threadIndex % 8)) < X.width, dzX + X.channels*(i + (threadIndex % 8)) + 4 * y + 3);
#endif
// Cache W (two loads per thread into the region after the X cache).
LDS_[8 * 16 * 4 + ((threadIndex / 8)%8) * 8 + (threadIndex % 8) + 8 * 8 * (2*((threadIndex/8)/8)+0)] = W.MaskedGet((4 * x + 0 < strideW) && (i + ((threadIndex / 8)%8)) < W.GetFlatHeight(), 4 * x + (2*((threadIndex/8)/8)+0) + (i + ((threadIndex / 8)%8))*strideW);
LDS_[8 * 16 * 4 + ((threadIndex / 8)%8) * 8 + (threadIndex % 8) + 8 * 8 * (2*((threadIndex/8)/8)+1)] = W.MaskedGet((4 * x + 1 < strideW) && (i + ((threadIndex / 8)%8)) < W.GetFlatHeight(), 4 * x + (2*((threadIndex/8)/8)+1) + (i + ((threadIndex / 8)%8))*strideW);
GroupMemoryBarrierWithGroupSync();
float srcX[4];
float srcW[4];
// X * W
[unroll]
for (uint di = 0; di < CACHE_DEPTH; ++di)
{
srcX[0] = LDS_[di + (threadIndex / 8) * 8 + 8 * 16 * 0];
srcX[1] = LDS_[di + (threadIndex / 8) * 8 + 8 * 16 * 1];
srcX[2] = LDS_[di + (threadIndex / 8) * 8 + 8 * 16 * 2];
srcX[3] = LDS_[di + (threadIndex / 8) * 8 + 8 * 16 * 3];
srcW[0] = LDS_[4 * 8 * 16 + 8 * di + (threadIndex % 8) + 8 * 8 * 0];
srcW[1] = LDS_[4 * 8 * 16 + 8 * di + (threadIndex % 8) + 8 * 8 * 1];
srcW[2] = LDS_[4 * 8 * 16 + 8 * di + (threadIndex % 8) + 8 * 8 * 2];
srcW[3] = LDS_[4 * 8 * 16 + 8 * di + (threadIndex % 8) + 8 * 8 * 3];
// 4x4 outer product accumulated into the register tile (16 mads).
dstO[0] = fastfma(srcX[0], srcW[0], dstO[0]);
dstO[1] = fastfma(srcX[1], srcW[0], dstO[1]);
dstO[2] = fastfma(srcX[2], srcW[0], dstO[2]);
dstO[3] = fastfma(srcX[3], srcW[0], dstO[3]);
dstO[4] = fastfma(srcX[0], srcW[1], dstO[4]);
dstO[5] = fastfma(srcX[1], srcW[1], dstO[5]);
dstO[6] = fastfma(srcX[2], srcW[1], dstO[6]);
dstO[7] = fastfma(srcX[3], srcW[1], dstO[7]);
dstO[8] = fastfma(srcX[0], srcW[2], dstO[8]);
dstO[9 ] = fastfma(srcX[1], srcW[2], dstO[9]);
dstO[10] = fastfma(srcX[2], srcW[2], dstO[10]);
dstO[11] = fastfma(srcX[3], srcW[2], dstO[11]);
dstO[12] = fastfma(srcX[0], srcW[3], dstO[12]);
dstO[13] = fastfma(srcX[1], srcW[3], dstO[13]);
dstO[14] = fastfma(srcX[2], srcW[3], dstO[14]);
dstO[15] = fastfma(srcX[3], srcW[3], dstO[15]);
}
GroupMemoryBarrierWithGroupSync();
}
// NOTE(review): unlike the loads above, these stores are unguarded — this kernel
// appears to assume the dispatch/output is padded so every 4x/4y index is in
// bounds (O.width a multiple of 32, O.channels a multiple of 64) — confirm on the C# side.
#if CHANNELS_FIRST
O.FastSet(dzO + (4 * x + 0) + O.width * (4 * y + 0), dstO[0]);
O.FastSet(dzO + (4 * x + 0) + O.width * (4 * y + 1), dstO[1]);
O.FastSet(dzO + (4 * x + 0) + O.width * (4 * y + 2), dstO[2]);
O.FastSet(dzO + (4 * x + 0) + O.width * (4 * y + 3), dstO[3]);
O.FastSet(dzO + (4 * x + 1) + O.width * (4 * y + 0), dstO[4]);
O.FastSet(dzO + (4 * x + 1) + O.width * (4 * y + 1), dstO[5]);
O.FastSet(dzO + (4 * x + 1) + O.width * (4 * y + 2), dstO[6]);
O.FastSet(dzO + (4 * x + 1) + O.width * (4 * y + 3), dstO[7]);
O.FastSet(dzO + (4 * x + 2) + O.width * (4 * y + 0), dstO[8]);
O.FastSet(dzO + (4 * x + 2) + O.width * (4 * y + 1), dstO[9]);
O.FastSet(dzO + (4 * x + 2) + O.width * (4 * y + 2), dstO[10]);
O.FastSet(dzO + (4 * x + 2) + O.width * (4 * y + 3), dstO[11]);
O.FastSet(dzO + (4 * x + 3) + O.width * (4 * y + 0), dstO[12]);
O.FastSet(dzO + (4 * x + 3) + O.width * (4 * y + 1), dstO[13]);
O.FastSet(dzO + (4 * x + 3) + O.width * (4 * y + 2), dstO[14]);
O.FastSet(dzO + (4 * x + 3) + O.width * (4 * y + 3), dstO[15]);
#else
O.FastSet(dzO + (4 * x + 0)*O.channels + 4 * y + 0, dstO[0]);
O.FastSet(dzO + (4 * x + 0)*O.channels + 4 * y + 1, dstO[1]);
O.FastSet(dzO + (4 * x + 0)*O.channels + 4 * y + 2, dstO[2]);
O.FastSet(dzO + (4 * x + 0)*O.channels + 4 * y + 3, dstO[3]);
O.FastSet(dzO + (4 * x + 1)*O.channels + 4 * y + 0, dstO[4]);
O.FastSet(dzO + (4 * x + 1)*O.channels + 4 * y + 1, dstO[5]);
O.FastSet(dzO + (4 * x + 1)*O.channels + 4 * y + 2, dstO[6]);
O.FastSet(dzO + (4 * x + 1)*O.channels + 4 * y + 3, dstO[7]);
O.FastSet(dzO + (4 * x + 2)*O.channels + 4 * y + 0, dstO[8]);
O.FastSet(dzO + (4 * x + 2)*O.channels + 4 * y + 1, dstO[9]);
O.FastSet(dzO + (4 * x + 2)*O.channels + 4 * y + 2, dstO[10]);
O.FastSet(dzO + (4 * x + 2)*O.channels + 4 * y + 3, dstO[11]);
O.FastSet(dzO + (4 * x + 3)*O.channels + 4 * y + 0, dstO[12]);
O.FastSet(dzO + (4 * x + 3)*O.channels + 4 * y + 1, dstO[13]);
O.FastSet(dzO + (4 * x + 3)*O.channels + 4 * y + 2, dstO[14]);
O.FastSet(dzO + (4 * x + 3)*O.channels + 4 * y + 3, dstO[15]);
#endif
}
#endif
#undef CACHE_DEPTH
#undef KERNEL_NAME
#endif
#undef FUNC_NAME
#undef CACHE_NAME
#undef FUNC_NAME_CALL
#undef CACHE_NAME_CALL
#if CHANNELS_FIRST
#define FUNC_NAME_CALL(KERNEL) KERNEL##_NCHW
#define CACHE_NAME_CALL(KERNEL, TENSOR) KERNEL##_Cache_##TENSOR##_NCHW
#else
#define FUNC_NAME_CALL(KERNEL) KERNEL##_NHWC
#define CACHE_NAME_CALL(KERNEL, TENSOR) KERNEL##_Cache_##TENSOR##_NHWC
#endif
#define FUNC_NAME(KERNEL) FUNC_NAME_CALL(KERNEL)
#define CACHE_NAME(KERNEL, TENSOR) CACHE_NAME_CALL(KERNEL, TENSOR)
// NOTE: usually this path is used for <16 batches
#undef CACHESIZE
#undef LDS_
#define KERNEL_NAME Dense3_L1Cached64
#define CACHESIZE 64
// One 64-float LDS tile shared by the thread group; the name is mangled per layout (NCHW/NHWC).
groupshared float CACHE_NAME(KERNEL_NAME, LDS)[CACHESIZE];
[numthreads(64, 1, 1)]
// Rank-3 dense (batched) matmul fallback: each thread produces one output element
// of batch slice groupID.z, accumulating over the shared dimension n = X.width.
// A 64-wide slice of X is staged in LDS per pass so global reads stay coalesced.
void FUNC_NAME(KERNEL_NAME)(uint3 groupID : SV_GroupID, uint3 groupThreadID : SV_GroupThreadID, uint3 dispatchThreadID : SV_DispatchThreadID)
{
//DISPATCH ARGS(O.flatWidth, O.flatHeight, 1);
TENSOR_SHARED2_ARGS4(X, W, B, WBK, O);
#define LDS_ CACHE_NAME(KERNEL_NAME, LDS)
// x = output column owned by this thread, y = output row, groupID.z = batch slice.
uint x = CACHESIZE * groupID.x + groupThreadID.x;
uint y = groupID.y;
uint n = X.width;
uint strideX = X.channels;
uint strideW = W.GetFlatWidth();
// Flat offsets of this batch slice inside X and O.
uint dzX = groupID.z * n * strideX;
uint dzO = groupID.z * strideW * strideX;
// min() clamps the bias fetch for threads past the last valid column;
// their result is discarded by the bounds check before the final write.
float acc = B.FastGet(min(x, strideW - 1));
// loop over X columns (flatWidth) and W rows (height) in CACHESIZE steps
for (uint i = 0; i < n; i += CACHESIZE)
{
// Cache X
// coalescent reads
bool maskX = (y < strideX) && (i + groupThreadID.x) < X.width;
#if CHANNELS_FIRST
LDS_[groupThreadID.x] = X.MaskedGet(maskX, dzX + y * X.width + (i + groupThreadID.x));
#else
LDS_[groupThreadID.x] = X.MaskedGet(maskX, dzX + (i + groupThreadID.x) * X.channels + y);
#endif
GroupMemoryBarrierWithGroupSync();
// X * W
[unroll]
for (uint di = 0; di < CACHESIZE; ++di)
{
// Out-of-range W taps are masked (MaskedGet presumably yields a neutral 0 when
// the mask is false — see Tensor.cginc), so they do not perturb the sum.
acc = fastfma(LDS_[di], W.MaskedGet(x < strideW && (i+di) < W.GetFlatHeight(), x + (i + di)*strideW), acc);
}
// Second barrier: nobody may refill the LDS tile while others still read it.
GroupMemoryBarrierWithGroupSync();
}
if ((x < O.width) && (y < O.channels))
{
#if CHANNELS_FIRST
O.FastSet(dzO + y * O.width + x, acc);
#else
O.FastSet(dzO + x * O.channels + y, acc);
#endif
}
#undef LDS_
}
#undef KERNEL_NAME
#undef CACHESIZE

View File

@@ -0,0 +1,8 @@
fileFormatVersion: 2
guid: b2365b5a091a4ed4aa09dd10bd46f7e1
ComputeShaderImporter:
externalObjects: {}
currentAPIMask: 4
userData:
assetBundleName:
assetBundleVariant:

View File

@@ -0,0 +1,73 @@
#pragma kernel DenseFP16Div2_NHWC CHANNELS_FIRST=0
#pragma kernel DenseFP16Div2_NCHW CHANNELS_FIRST=1
#include "Tensor.cginc"
TENSOR_DECL(X)
TENSOR_DECL(W)
TENSOR_DECL(B)
TENSOR_DECL(WBK)
TENSOR_DECL_RW(O)
// Software half->float conversion, based on Fabian Giesen's public domain
// half_to_float_fast3. Exists because the built-in f16tof32 is broken in the
// GLSL/Metal compiler (see Unpack below). Handles normals, denormals, Inf/NaN
// and the sign bit; src holds one fp16 value in its low 16 bits.
float f16tof32_(uint src)
{
    const uint magic = 113 << 23;
    const uint shiftedExp = 0x7c00 << 13; // exponent mask after shift
    uint bits = src & 0x7fff;             // drop the sign bit for now
    if (bits)
    {
        // Shift exponent + mantissa into float32 position.
        bits <<= 13;
        uint exponent = bits & shiftedExp;
        if (exponent == shiftedExp)
        {
            bits += (255 - 31) << 23;     // Inf/NaN: saturate the exponent
        }
        else if (exponent != 0)
        {
            bits += (127 - 15) << 23;     // normal: rebias the exponent
        }
        else
        {
            // Denormal: renormalize via a float round-trip.
            bits = asuint(asfloat(bits + magic) - asfloat(magic));
        }
    }
    // Re-attach the sign bit in float32 position.
    bits |= (src & 0x8000) << 16;
    return asfloat(bits);
}
// Fetches the pair of fp16 values packed into the 32-bit word that covers
// flat position (y, x) of tensor t, and widens both halves to float.
float2 Unpack(SharedTensor t, uint y, uint x)
{
    // Two halves per word, hence the >> 1 on the flat index.
    uint packed = asuint(t.data[t.Index(y, x) >> 1]);
    // TEMPORARY: f16tof32 is broken in GLSL/Metal compiler
    // using custom conversion function for now
    //return float2(f16tof32(packed), f16tof32(packed>>16));
    float lo = f16tof32_(packed);
    float hi = f16tof32_(packed >> 16);
    return float2(lo, hi);
}
// NOTE: usually this path is used for <16 batches
NUMTHREADS((256,1,1), (128,1,1), (64,1,1))
// Dense layer over fp16-packed weights and bias: each thread produces the pair
// of adjacent output columns (x*2, x*2+1), decoding two halves per 32-bit fetch.
// Note: assumes an even flat width — the guard checks x*2 only; TODO confirm callers.
void KERNEL_FUNC(DenseFP16Div2)(uint3 dispatchThreadID : SV_DispatchThreadID)
{
//DISPATCH ARGS(O.flatWidth/2, O.flatHeight, 1);
TENSOR_SHARED2_ARGS4(X, W, B, WBK, O);
uint col = dispatchThreadID.x;
uint row = dispatchThreadID.y;
if (row >= O.GetFlatHeight()) return;
if (col*2 >= O.GetFlatWidth()) return;
// Start from the (unpacked) bias pair, then accumulate the dot product.
float2 sum = Unpack(B, 0, col*2);
for (uint k = 0; k < X.width; ++k)
{
    sum += X.Get(row, k) * Unpack(W, k, col*2);
}
O.Set(row, col*2+0, sum[0]);
O.Set(row, col*2+1, sum[1]);
}

View File

@@ -0,0 +1,9 @@
fileFormatVersion: 2
guid: cff3cb66e54744fa4888ef91a11ec90c
timeCreated: 1508334838
licenseType: Pro
ComputeShaderImporter:
currentAPIMask: 196608
userData:
assetBundleName:
assetBundleVariant:

View File

@@ -0,0 +1,438 @@
#pragma kernel ScaleBias_NHWC CHANNELS_FIRST=0
#pragma kernel ScaleBias_NCHW CHANNELS_FIRST=1
#pragma kernel ScaleBias_CNyx_NHWC CHANNELS_FIRST=0
//#pragma kernel ScaleBias_CNyx_NCHW CHANNELS_FIRST=1 //This kernel require NHWC by design
#pragma kernel ScaleBias_CNyx2_NHWC CHANNELS_FIRST=0
//#pragma kernel ScaleBias_CNyx2_NCHW CHANNELS_FIRST=1 //This kernel require NHWC by design
#pragma kernel ScaleBias_Flat_NHWC CHANNELS_FIRST=0
#pragma kernel ScaleBias_Flat_NCHW CHANNELS_FIRST=1
#pragma kernel ScaleBias_Loop_NHWC CHANNELS_FIRST=0
#pragma kernel ScaleBias_Loop_NCHW CHANNELS_FIRST=1
#pragma kernel InstanceNormTail_CNyx2_NHWC CHANNELS_FIRST=0
//#pragma kernel InstanceNormTail_CNyx2_NCHW CHANNELS_FIRST=1 //This kernel require NHWC by design
#pragma kernel InstanceNormTail_Flat_NHWC CHANNELS_FIRST=0
#pragma kernel InstanceNormTail_Flat_NCHW CHANNELS_FIRST=1
#pragma kernel InstanceNormTail_Loop_NHWC CHANNELS_FIRST=0
#pragma kernel InstanceNormTail_Loop_NCHW CHANNELS_FIRST=1
#pragma kernel Upsample2D_NHWC CHANNELS_FIRST=0
#pragma kernel Upsample2D_NCHW CHANNELS_FIRST=1
#pragma kernel UpsampleBilinear2D_NHWC CHANNELS_FIRST=0
#pragma kernel UpsampleBilinear2D_NCHW CHANNELS_FIRST=1
#pragma kernel UpsampleBilinear2D_2x2_NHWC CHANNELS_FIRST=0
#pragma kernel UpsampleBilinear2D_2x2_NCHW CHANNELS_FIRST=1
#pragma kernel Copy_NHWC CHANNELS_FIRST=0
#pragma kernel Copy_NCHW CHANNELS_FIRST=1
#pragma kernel ReshapeFromNHWCModel_Flat_NCHW CHANNELS_FIRST=1
#pragma kernel ReshapeFromNHWCModel_Loop_NCHW CHANNELS_FIRST=1
#pragma kernel TransposeToChannelFirst
#include "Tensor.cginc"
TENSOR_DECL(X)
TENSOR_DECL(W)
TENSOR_DECL(S)
TENSOR_DECL(B)
TENSOR_DECL(WBK)
TENSOR_DECL_RW(O)
uint4 _Pool;
uint4 _Pad;
float _Epsilon;
uint _LoopStride;
NUMTHREADS((4,8,8), (4,8,4), (4,4,4))
// Per-channel affine transform: O[n,y,x,c] = X[n,y,x,c] * W[c] + B[c].
// One thread per (c, x, y) position; the batch dimension is walked in a loop.
void KERNEL_FUNC(ScaleBias)(uint3 dispatchThreadID : SV_DispatchThreadID)
{
//DISPATCH ARGS(O.channels, O.width, O.height);
TENSOR_SHARED2_ARGS4(X, W, B, WBK, O);
uint c = dispatchThreadID.x;
uint x = dispatchThreadID.y;
uint y = dispatchThreadID.z;
if (c >= O.channels || x >= O.width || y >= O.height)
    return;
// Scale/bias depend only on the channel, so hoist them out of the batch loop.
float scale = W.Get(0, 0, 0, c);
float bias = B.Get(0, 0, 0, c);
for (uint n = 0; n < X.batch; ++n)
    O.Set(n, y, x, c, X.Get(n, y, x, c) * scale + bias);
}
NUMTHREADS((16,16,1), (16,8,1), (16,4,1))
// Per-channel affine transform, dispatched over (channel, flattened N*H*W):
// thread.y is decomposed back into (n, y, x) coordinates.
void KERNEL_FUNC(ScaleBias_CNyx)(uint3 dispatchThreadID : SV_DispatchThreadID)
{
//DISPATCH ARGS(O.channels, O.batch * O.height * O.width, 1);
TENSOR_SHARED2_ARGS4(X, W, B, WBK, O);
uint c = dispatchThreadID.x;
uint nyx = dispatchThreadID.y;
// Unflatten nyx -> (n, y, x).
uint x = nyx % X.width;
uint ny = nyx / X.width;
uint y = ny % X.height;
uint n = ny / X.height;
if (c >= X.channels || n >= X.batch)
    return;
float scale = W.Get(0, 0, 0, c);
float bias = B.Get(0, 0, 0, c);
O.Set(n, y, x, c, X.Get(n, y, x, c) * scale + bias);
}
NUMTHREADS((256,1,1), (128,1,1), (64,1,1))
// Flat variant of ScaleBias: one thread per element of the flattened tensor;
// the channel is recovered from the flat index to pick the scale/bias pair.
void KERNEL_FUNC(ScaleBias_Flat)(uint3 dispatchThreadID : SV_DispatchThreadID)
{
//DISPATCH ARGS(O.length, 1, 1);
TENSOR_SHARED2_ARGS4(X, W, B, WBK, O);
uint idx = dispatchThreadID.x;
if (idx >= O.GetLength()) return;
uint c = X.GetChannelFromIndex(idx);
O.FastSet(idx, X.FastGet(idx) * W.FastGet(c) + B.FastGet(c));
}
NUMTHREADS((256,1,1), (128,1,1), (64,1,1))
// Strided variant of ScaleBias for tensors larger than one dispatch:
// each thread walks the flat tensor in steps of _LoopStride elements.
void KERNEL_FUNC(ScaleBias_Loop)(uint3 dispatchThreadID : SV_DispatchThreadID)
{
//DISPATCH ARGS(O.length, 1, 1);
TENSOR_SHARED2_ARGS4(X, W, B, WBK, O);
uint total = O.GetLength();
for (uint idx = dispatchThreadID.x; idx < total; idx += _LoopStride)
{
    uint c = X.GetChannelFromIndex(idx);
    O.FastSet(idx, X.FastGet(idx) * W.FastGet(c) + B.FastGet(c));
}
}
NUMTHREADS((32,4,1), (32,2,1), (16,2,1))
// CNyx2 variant of ScaleBias: thread.x selects the channel, thread.y the pixel
// row; the flat element index is rebuilt channels-last (this kernel is NHWC-only
// by design — see the pragma list at the top of the file).
void KERNEL_FUNC(ScaleBias_CNyx2)(uint3 dispatchThreadID : SV_DispatchThreadID)
{
//DISPATCH ARGS(O.channels, O.batch * O.height * O.width, 1);
TENSOR_SHARED2_ARGS4(X, W, B, WBK, O);
uint c = dispatchThreadID.x;
uint idx = dispatchThreadID.y * X.channels + c;
if (c >= X.channels) return;
if (idx >= X.GetLength()) return;
O.FastSet(idx, X.FastGet(idx) * W.FastGet(c) + B.FastGet(c));
}
NUMTHREADS((256, 1, 1), (128, 1, 1), (64, 1, 1))
// Instance-normalization epilogue (flat layout): W carries precomputed
// per-(batch, channel) statistics (fetched at rows 0 and 1 — presumably mean
// and variance respectively; confirm against the producer kernel). Normalizes X
// and applies the learned per-channel scale (S) and bias (B), then writes
// through the fused-activation path.
void KERNEL_FUNC(InstanceNormTail_Flat)(uint3 dispatchThreadID : SV_DispatchThreadID)
{
//DISPATCH ARGS(O.length, 1, 1);
TENSOR_ARG(W);
TENSOR_SHARED2_ARGS4(X, S, B, WBK, O);
uint i = dispatchThreadID.x;
if (i >= O.GetLength()) return;
uint c = X.GetChannelFromIndex(i);
// Batch index of flat element i (one batch spans height*width*channels elements).
uint b = i / (X.height * X.width * X.channels);
float mean = W.Get(b, 0, 0, c);
float variance = W.Get(b, 1, 0, c);
float scale = S.FastGet(c);
float bias = B.FastGet(c);
// normalization factor
float invNormFactor = 1 / sqrt(variance + _Epsilon);
float v = X.FastGet(i);
// Equivalent to (v - mean) * invNormFactor, kept expanded as two products.
v = v * invNormFactor - mean * invNormFactor;
v = v * scale + bias;
O.FastSetWithActivation(i, v);
}
NUMTHREADS((256, 1, 1), (128, 1, 1), (64, 1, 1))
// Strided variant of InstanceNormTail_Flat: each thread walks the flat tensor
// in steps of _LoopStride, so tensors larger than one dispatch are covered.
// Same math as the _Flat kernel above.
void KERNEL_FUNC(InstanceNormTail_Loop)(uint3 dispatchThreadID : SV_DispatchThreadID)
{
//DISPATCH ARGS(O.length, 1, 1);
TENSOR_ARG(W);
TENSOR_SHARED2_ARGS4(X, S, B, WBK, O);
uint i = dispatchThreadID.x;
uint len = O.GetLength();
while (i < len)
{
uint c = X.GetChannelFromIndex(i);
// Batch index of flat element i.
uint b = i / (X.height * X.width * X.channels);
float mean = W.Get(b, 0, 0, c);
float variance = W.Get(b, 1, 0, c);
float scale = S.FastGet(c);
float bias = B.FastGet(c);
// normalization factor
float invNormFactor = 1 / sqrt(variance + _Epsilon);
float v = X.FastGet(i);
// Equivalent to (v - mean) * invNormFactor, kept expanded as two products.
v = v * invNormFactor - mean * invNormFactor;
v = v * scale + bias;
O.FastSetWithActivation(i, v);
i += _LoopStride;
}
}
NUMTHREADS((32, 4, 1), (32, 2, 1), (16, 2, 1))
// CNyx2 variant of the instance-normalization epilogue: thread.x selects the
// channel, thread.y the pixel row; the flat index is rebuilt channels-last
// (NHWC-only by design — see the pragma list at the top of the file).
void KERNEL_FUNC(InstanceNormTail_CNyx2)(uint3 dispatchThreadID : SV_DispatchThreadID)
{
//DISPATCH ARGS(O.channels, O.batch * O.height * O.width, 1);
TENSOR_ARG(W);
TENSOR_SHARED2_ARGS4(X, S, B, WBK, O);
uint c = dispatchThreadID.x;
uint i = dispatchThreadID.y * X.channels + c;
// Batch index; computed before the guards but only used after they pass.
uint b = i / (X.height * X.width * X.channels);
if (c >= X.channels) return;
if (i >= X.GetLength()) return;
float mean = W.Get(b, 0, 0, c);
float variance = W.Get(b, 1, 0, c);
float scale = S.FastGet(c);
float bias = B.FastGet(c);
// normalization factor
float invNormFactor = 1 / sqrt(variance + _Epsilon);
float v = X.FastGet(i);
// Equivalent to (v - mean) * invNormFactor, kept expanded as two products.
v = v * invNormFactor - mean * invNormFactor;
v = v * scale + bias;
O.FastSetWithActivation(i, v);
}
[numthreads(4,4,4)]
// Bilinear upsample: maps each output texel back into the source grid with a
// half-texel offset and blends the four surrounding source texels.
void KERNEL_FUNC(UpsampleBilinear2D)(uint3 dispatchThreadID : SV_DispatchThreadID)
{
//DISPATCH ARGS(O.channels, O.width, O.height);
TENSOR_ARGS2(X, O);
uint c = dispatchThreadID.x;
uint x = dispatchThreadID.y;
uint y = dispatchThreadID.z;
if (c >= O.channels || x >= O.width || y >= O.height)
    return;
// Half-texel centering: src = (dst + 0.5) / scale - 0.5.
float2 srcPos = (float2(x, y) + 0.5) / _Pool.xy - 0.5;
float2 base = floor(srcPos);
float2 t = frac(srcPos);
for (uint n = 0; n < O.batch; ++n)
{
    // ClampGet presumably clamps out-of-range taps to the tensor edge — see Tensor.cginc.
    float p00 = X.ClampGet(n, base + float2(0, 0), c);
    float p01 = X.ClampGet(n, base + float2(0, 1), c);
    float p10 = X.ClampGet(n, base + float2(1, 0), c);
    float p11 = X.ClampGet(n, base + float2(1, 1), c);
    float blended =
        p00 * (1 - t.x) * (1 - t.y) +
        p01 * (1 - t.x) * t.y +
        p10 * t.x * (1 - t.y) +
        p11 * t.x * t.y;
    O.Set(n, y, x, c, blended);
}
}
//Only a part of LDS will be used. Size match NUMTHREADS to simplify shader code when storing to LDS.
groupshared float UpsampleBilinear2D_2x2_Cache[8][8];
NUMTHREADS((8,8,1), (8,8,1), (8,8,1))
// Specialized 2x bilinear upsample: each 8x8 thread group emits an 8x8 output
// tile, i.e. a 2x upsample of a 4x4 source region. That region plus a 1-texel
// border is staged in LDS (8x8 texels loaded; 6x6 of them are actually sampled).
void KERNEL_FUNC(UpsampleBilinear2D_2x2)(uint3 dispatchThreadID : SV_DispatchThreadID, uint3 groupID : SV_GroupID, uint3 groupThreadID : SV_GroupThreadID)
{
//DISPATCH ARGS(O.width, O.height, O.channels);
TENSOR_ARGS2(X, O);
// Top-left source texel this group needs; may be -1 (ClampGet handles edges).
int2 tg_SrcBasePos = groupID.xy * 4 - 1;
uint c = dispatchThreadID.z;
uint x = dispatchThreadID.x;
uint y = dispatchThreadID.y;
// Half-texel mapping of this thread's output texel into LDS-tile space.
float2 srcLDSPos = (groupThreadID.xy + 0.5f) / 2.0f - 0.5f;
// +1 compensates for the tile origin sitting one texel before the group's region.
uint2 srcLDSBasePos = floor(srcLDSPos) + uint2(1,1);
for (uint n = 0; n < O.batch; ++n)
{
//store inputs to LDS
UpsampleBilinear2D_2x2_Cache[groupThreadID.x][groupThreadID.y] = X.ClampGet(n, tg_SrcBasePos + groupThreadID.xy, c);
GroupMemoryBarrierWithGroupSync();
//read inputs from LDS
float p00 = UpsampleBilinear2D_2x2_Cache[srcLDSBasePos.x+0][srcLDSBasePos.y+0];
float p01 = UpsampleBilinear2D_2x2_Cache[srcLDSBasePos.x+0][srcLDSBasePos.y+1];
float p10 = UpsampleBilinear2D_2x2_Cache[srcLDSBasePos.x+1][srcLDSBasePos.y+0];
float p11 = UpsampleBilinear2D_2x2_Cache[srcLDSBasePos.x+1][srcLDSBasePos.y+1];
float v =
p00 * (1-frac(srcLDSPos.x)) * (1-frac(srcLDSPos.y)) +
p01 * (1-frac(srcLDSPos.x)) * frac(srcLDSPos.y) +
p10 * frac(srcLDSPos.x) * (1-frac(srcLDSPos.y)) +
p11 * frac(srcLDSPos.x) * frac(srcLDSPos.y);
if ((c < O.channels) && (x < O.width) && (y < O.height))
O.Set(n, y, x, c, v);
// FIX: barrier before the next batch iteration refills the LDS tile —
// without it a fast thread could overwrite values slower threads still read.
GroupMemoryBarrierWithGroupSync();
}
}
NUMTHREADS((4,8,8), (4,8,4), (4,4,4))
// Nearest-neighbour upsample: replicates every input texel into a
// _Pool.x by _Pool.y block of output texels.
void KERNEL_FUNC(Upsample2D)(uint3 dispatchThreadID : SV_DispatchThreadID)
{
// NOTE: dispatched over X (not O)
//DISPATCH ARGS(X.channels, X.width, X.height);
TENSOR_ARGS2(X, O);
uint c = dispatchThreadID.x;
uint x = dispatchThreadID.y;
uint y = dispatchThreadID.z;
if (c >= X.channels || x >= X.width || y >= X.height)
    return;
for (uint n = 0; n < O.batch; ++n)
{
    float v = X.Get(n, y, x, c);
    for (uint dy = 0; dy < _Pool.y; ++dy)
    for (uint dx = 0; dx < _Pool.x; ++dx)
        O.Set(n, y * _Pool.y + dy, x * _Pool.x + dx, c, v);
}
}
NUMTHREADS((4,8,8), (4,8,4), (4,4,4))
// Copies X into O shifted by the per-axis destination offsets in _Pad
// (order: batch, height, width, channels).
void KERNEL_FUNC(Copy)(uint3 dispatchThreadID : SV_DispatchThreadID)
{
// NOTE: dispatched over X (not O)
//DISPATCH ARGS(X.channels, X.width, X.height);
TENSOR_ARGS2(X, O);
uint c = dispatchThreadID.x;
uint x = dispatchThreadID.y;
uint y = dispatchThreadID.z;
if (c >= X.channels) return;
if (x >= X.width) return;
if (y >= X.height) return;
for (uint n = 0; n < X.batch; ++n)
    O.Set(n + _Pad[0], y + _Pad[1], x + _Pad[2], c + _Pad[3], X.Get(n, y, x, c));
}
NUMTHREADS((256,1,1), (128,1,1), (64,1,1))
// Reshape helper for models authored in NHWC layout while running NCHW:
// for each output coordinate it computes the element's NHWC flat offset,
// maps that offset back to the pre-reshape NHWC coordinates of X, and copies
// the value — so the reshape behaves as the NHWC model intended.
void ReshapeFromNHWCModel_Flat_NCHW(uint3 dispatchThreadID : SV_DispatchThreadID)
{
//DISPATCH ARGS(O.channels, O.width, O.height);
TENSOR_ARGS2(X, O);
uint c = dispatchThreadID.x; uint x = dispatchThreadID.y; uint y = dispatchThreadID.z;
if (c >= O.channels) return; if (x >= O.width) return; if (y >= O.height) return;
for (uint n = 0; n < O.batch; ++n)
{
//find the memory offset of target item in HWC format (aka on O)
uint index_NHWC = O.IndexHWC(n,y,x,c);
//from this offset find indices of item in HWC format before the reshape (aka on X)
uint c_NHWC, y_NHWC, x_NHWC, b_NHWC;
X.GetPositionFromIndexNHWC(index_NHWC, b_NHWC, y_NHWC, x_NHWC, c_NHWC);
//finally copy item
float v = X.Get(b_NHWC, y_NHWC, x_NHWC, c_NHWC);
O.Set(n, y, x, c, v);
}
}
NUMTHREADS((64,1,1), (64,1,1), (64,1,1))
// Strided variant of ReshapeFromNHWCModel_Flat_NCHW: each thread walks the
// flat output in steps of _LoopStride, recovering the NCHW coordinate first,
// then performing the same NHWC offset round-trip as the _Flat kernel above.
void ReshapeFromNHWCModel_Loop_NCHW(uint3 dispatchThreadID : SV_DispatchThreadID)
{
//DISPATCH ARGS(O.length, 1, 1);
TENSOR_ARGS2(X, O);
uint i = dispatchThreadID.x;
uint len = O.GetLength();
while (i < len)
{
uint c, y, x, n;
O.GetPositionFromIndexNCHW(i, n, y, x, c);
//find the memory offset of target item in HWC format (aka on O)
uint index_NHWC = O.IndexHWC(n,y,x,c);
//from this offset find indices of item in HWC format before the reshape (aka on X)
uint c_NHWC, y_NHWC, x_NHWC, b_NHWC;
X.GetPositionFromIndexNHWC(index_NHWC, b_NHWC, y_NHWC, x_NHWC, c_NHWC);
//finally copy item
float v = X.Get(b_NHWC, y_NHWC, x_NHWC, c_NHWC);
O.Set(n, y, x, c, v);
i += _LoopStride;
}
}
NUMTHREADS((4,8,8), (4,8,4), (4,4,4))
// Copies an 8D tensor into channels-first storage: reads X at the full 8D
// coordinate and writes O at the SRNCTDHW flat offset of that coordinate.
// NOTE(review): the write offset is computed via X.IndexSRNCTDHW rather than
// O's — presumably valid because X and O share the same shape here; confirm
// against the dispatching code.
void TransposeToChannelFirst(uint3 dispatchThreadID : SV_DispatchThreadID)
{
//DISPATCH_ARGS(X.channels, X.width, X.height);
TENSOR_ARGS2_8D(X, O);
uint c = dispatchThreadID.x; uint w = dispatchThreadID.y; uint h = dispatchThreadID.z;
if (c >= O.channels) return; if (w >= O.width) return; if (h >= O.height) return;
// One thread handles its (h, w, c) position across every remaining dimension.
for (uint s = 0; s < O.sequenceLength; ++s)
for (uint r = 0; r < O.numberOfDirections; ++r)
for (uint n = 0; n < O.batch; ++n)
for (uint t = 0; t < O.extraDimension; ++t)
for (uint d = 0; d < O.depth; ++d)
{
float v = X.Get8D(s,r,n,t,d,h,w,c);
uint index = X.IndexSRNCTDHW(s,r,n,t,d,h,w,c);
O.FastSet(index, v);
}
}

View File

@@ -0,0 +1,9 @@
fileFormatVersion: 2
guid: 62f5efacd43b24dd38ead3ce0d80cc34
timeCreated: 1495527718
licenseType: Pro
ComputeShaderImporter:
currentAPIMask: 196608
userData:
assetBundleName:
assetBundleVariant:

View File

@@ -0,0 +1,2 @@
//See DebugUtils.cginc
//#define KERNEL_ASSERTS

View File

@@ -0,0 +1,9 @@
fileFormatVersion: 2
guid: 658d58a262863454e8daacc86138ba3f
ShaderImporter:
externalObjects: {}
defaultTextures: []
nonModifiableTextures: []
userData:
assetBundleName:
assetBundleVariant:

View File

@@ -0,0 +1,2 @@
//See DebugUtils.cginc
//#define KERNEL_ASSERTS

View File

@@ -0,0 +1,9 @@
fileFormatVersion: 2
guid: ae661a10fea2b40fcbe9ef81c40653cc
ShaderImporter:
externalObjects: {}
defaultTextures: []
nonModifiableTextures: []
userData:
assetBundleName:
assetBundleVariant:

View File

@@ -0,0 +1,2 @@
//See DebugUtils.cginc
#define KERNEL_ASSERTS

View File

@@ -0,0 +1,9 @@
fileFormatVersion: 2
guid: 801f6bbcb80e44fab8b21ca2a87367a8
ShaderImporter:
externalObjects: {}
defaultTextures: []
nonModifiableTextures: []
userData:
assetBundleName:
assetBundleVariant:

View File

@@ -0,0 +1,406 @@
#pragma kernel MultidimMatMul_T16x16_R4x4_AR3_BR2_NHWC RANKA=3 RANKB=2 BLOCK_SIZE=4 CHANNELS_FIRST=0
#pragma kernel MultidimMatMul_T16x16_R4x4_AR3_BR2_NCHW RANKA=3 RANKB=2 BLOCK_SIZE=4 CHANNELS_FIRST=1
#pragma kernel MultidimMatMul_T8x8_R8x8_AR3_BR2_NHWC RANKA=3 RANKB=2 BLOCK_SIZE=8 KERNEL_PER_TG=64 CHANNELS_FIRST=0
#pragma kernel MultidimMatMul_T8x8_R8x8_AR3_BR2_NCHW RANKA=3 RANKB=2 BLOCK_SIZE=8 KERNEL_PER_TG=64 CHANNELS_FIRST=1
#pragma kernel MultidimMatMul_L1Cached64_AR3_BR2_NHWC RANKA=3 RANKB=2 CHANNELS_FIRST=0
#pragma kernel MultidimMatMul_L1Cached64_AR3_BR2_NCHW RANKA=3 RANKB=2 CHANNELS_FIRST=1
#include "Tensor.cginc"
TENSOR_DECL(A)
TENSOR_DECL(B)
//TENSOR_DECL(C)
TENSOR_DECL_RW(O)
// a*b + c expressed through dot() — the commented-out alternatives suggest this
// form was chosen deliberately (presumably to steer the shader compiler's
// instruction selection); keep as-is unless profiling says otherwise.
float ffma(float a, float b, float c) { return dot(float2(a, c), float2(b, 1)); } //return a*b+c;} //fastfma(a,b,c); }
#if CHANNELS_FIRST
#define FUNC_NAME_CALL(KERNEL, SIZE, RANK1, RANK2) KERNEL##SIZE##x##SIZE##_AR##RANK1##_BR##RANK2##_NCHW
#define CACHE_NAME_CALL(KERNEL, SIZE, TENSOR) KERNEL##SIZE##x##SIZE##_Cache_##TENSOR##_NCHW
#else
#define FUNC_NAME_CALL(KERNEL, SIZE, RANK1, RANK2) KERNEL##SIZE##x##SIZE##_AR##RANK1##_BR##RANK2##_NHWC
#define CACHE_NAME_CALL(KERNEL, SIZE, TENSOR) KERNEL##SIZE##x##SIZE##_Cache_##TENSOR##_NHWC
#endif
#define FUNC_NAME(KERNEL, SIZE, RANK1, RANK2) FUNC_NAME_CALL(KERNEL, SIZE, RANK1, RANK2)
#define CACHE_NAME(KERNEL, SIZE, TENSOR) CACHE_NAME_CALL(KERNEL, SIZE, TENSOR)
#if BLOCK_SIZE == 8
#if KERNEL_PER_TG == 64
#define KERNEL_NAME MultidimMatMul_T8x8_R
#define CACHE_WIDTH_B_PAD 2
#define CACHE_WIDTH_A 64
#define CACHE_WIDTH_B (64+CACHE_WIDTH_B_PAD)
#define CACHE_DEPTH 8
groupshared float CACHE_NAME(KERNEL_NAME, BLOCK_SIZE, LDS)[1039]; // [(8*9)*(3*8+7)+(7)*8+7+1] // [(CACHE_WIDTH_A + CACHE_WIDTH_B)* BLOCK_SIZE];
[numthreads(8, 8, 1)]
// Tiled batched matmul for rank-3 A and rank-2 B: each 8x8 thread group (64
// threads) computes a 64x64 output tile, with every thread accumulating an 8x8
// register block (BLOCK_SIZE=8). A and B sub-tiles are staged in LDS
// CACHE_DEPTH rows per pass; B cache rows carry a CACHE_WIDTH_B_PAD-float pad
// (presumably to dodge LDS bank conflicts — see the NHWC write-out below).
void FUNC_NAME(KERNEL_NAME, BLOCK_SIZE, RANKA, RANKB)(uint3 groupID : SV_GroupID, uint threadIndex : SV_GroupIndex)
{
TENSOR_ARGS3(A, B, O);
uint ti = threadIndex;
// Top-left corner of this group's 64x64 output tile.
uint bx = groupID.x * 8 * BLOCK_SIZE;
uint by = groupID.y * 8 * BLOCK_SIZE;
uint n = A.width;
uint strideA = A.channels;
uint strideB = B.GetFlatWidth();
uint lengthB = B.GetLength() - 1;
uint dzA = groupID.z * n * strideA;
uint dzO = groupID.z * strideB * strideA;
#define LDS_ CACHE_NAME(KERNEL_NAME, BLOCK_SIZE, LDS)
#define A_OFFSET 0
#define B_OFFSET CACHE_DEPTH*8*BLOCK_SIZE
float dstO[BLOCK_SIZE*BLOCK_SIZE];
uint tg_A = 0;
uint tg_B = 0;
// Zero the 8x8 per-thread accumulator.
[unroll] for (tg_A = 0; tg_A < BLOCK_SIZE; ++tg_A)
[unroll] for (tg_B = 0; tg_B < BLOCK_SIZE; ++tg_B)
dstO[tg_A*BLOCK_SIZE + tg_B] = 0.0f;
// March over the shared K dimension (A.width) CACHE_DEPTH rows per pass.
for (uint i = 0; i < n; i += CACHE_DEPTH)
{
#if CHANNELS_FIRST
//LDS_[A_OFFSET + ti + 8 * 8 * [0..7]] = A.FastGet((i + [0..7]) + A.width * (by + ti));
LDS_[A_OFFSET + ti + CACHE_WIDTH_A * 0] = A.MaskedGet(((by + ti) < strideA) && ((i + 0) < A.width), dzA + (i + 0) + A.width * (by + ti));
LDS_[A_OFFSET + ti + CACHE_WIDTH_A * 1] = A.MaskedGet(((by + ti) < strideA) && ((i + 1) < A.width), dzA + (i + 1) + A.width * (by + ti));
LDS_[A_OFFSET + ti + CACHE_WIDTH_A * 2] = A.MaskedGet(((by + ti) < strideA) && ((i + 2) < A.width), dzA + (i + 2) + A.width * (by + ti));
LDS_[A_OFFSET + ti + CACHE_WIDTH_A * 3] = A.MaskedGet(((by + ti) < strideA) && ((i + 3) < A.width), dzA + (i + 3) + A.width * (by + ti));
LDS_[A_OFFSET + ti + CACHE_WIDTH_A * 4] = A.MaskedGet(((by + ti) < strideA) && ((i + 4) < A.width), dzA + (i + 4) + A.width * (by + ti));
LDS_[A_OFFSET + ti + CACHE_WIDTH_A * 5] = A.MaskedGet(((by + ti) < strideA) && ((i + 5) < A.width), dzA + (i + 5) + A.width * (by + ti));
LDS_[A_OFFSET + ti + CACHE_WIDTH_A * 6] = A.MaskedGet(((by + ti) < strideA) && ((i + 6) < A.width), dzA + (i + 6) + A.width * (by + ti));
LDS_[A_OFFSET + ti + CACHE_WIDTH_A * 7] = A.MaskedGet(((by + ti) < strideA) && ((i + 7) < A.width), dzA + (i + 7) + A.width * (by + ti));
#else
//LDS_[A_OFFSET + ti + 8 * 8 * [0..7]] = A.FastGet(A.channels * (i + [0..7]) + by + ti);
LDS_[A_OFFSET + ti + CACHE_WIDTH_A * 0] = A.MaskedGet(((by + ti) < A.channels) && (i + 0) < A.width, dzA + A.channels * (i + 0) + by + ti);
LDS_[A_OFFSET + ti + CACHE_WIDTH_A * 1] = A.MaskedGet(((by + ti) < A.channels) && (i + 1) < A.width, dzA + A.channels * (i + 1) + by + ti);
LDS_[A_OFFSET + ti + CACHE_WIDTH_A * 2] = A.MaskedGet(((by + ti) < A.channels) && (i + 2) < A.width, dzA + A.channels * (i + 2) + by + ti);
LDS_[A_OFFSET + ti + CACHE_WIDTH_A * 3] = A.MaskedGet(((by + ti) < A.channels) && (i + 3) < A.width, dzA + A.channels * (i + 3) + by + ti);
LDS_[A_OFFSET + ti + CACHE_WIDTH_A * 4] = A.MaskedGet(((by + ti) < A.channels) && (i + 4) < A.width, dzA + A.channels * (i + 4) + by + ti);
LDS_[A_OFFSET + ti + CACHE_WIDTH_A * 5] = A.MaskedGet(((by + ti) < A.channels) && (i + 5) < A.width, dzA + A.channels * (i + 5) + by + ti);
LDS_[A_OFFSET + ti + CACHE_WIDTH_A * 6] = A.MaskedGet(((by + ti) < A.channels) && (i + 6) < A.width, dzA + A.channels * (i + 6) + by + ti);
LDS_[A_OFFSET + ti + CACHE_WIDTH_A * 7] = A.MaskedGet(((by + ti) < A.channels) && (i + 7) < A.width, dzA + A.channels * (i + 7) + by + ti);
#endif
//LDS_[B_OFFSET + ti + writeIndex + (8 * 8 + 1) * [0..7]] = B.FastGet(strideB * (i + [0..7]) + bx + ti);
// B reads are clamped with min() instead of masked; the pad slot insertion
// below shifts the upper half of each row by CACHE_WIDTH_B_PAD.
uint BWriteIndex = (ti & 0x20) >> 4;// (ti > 31) ? CACHE_WIDTH_B_PAD : 0;
LDS_[B_OFFSET + (ti + BWriteIndex) + 0 * CACHE_WIDTH_B] = B.FastGet(min(strideB * (i + 0) + bx + ti, lengthB));
LDS_[B_OFFSET + (ti + BWriteIndex) + 1 * CACHE_WIDTH_B] = B.FastGet(min(strideB * (i + 1) + bx + ti, lengthB));
LDS_[B_OFFSET + (ti + BWriteIndex) + 2 * CACHE_WIDTH_B] = B.FastGet(min(strideB * (i + 2) + bx + ti, lengthB));
LDS_[B_OFFSET + (ti + BWriteIndex) + 3 * CACHE_WIDTH_B] = B.FastGet(min(strideB * (i + 3) + bx + ti, lengthB));
LDS_[B_OFFSET + (ti + BWriteIndex) + 4 * CACHE_WIDTH_B] = B.FastGet(min(strideB * (i + 4) + bx + ti, lengthB));
LDS_[B_OFFSET + (ti + BWriteIndex) + 5 * CACHE_WIDTH_B] = B.FastGet(min(strideB * (i + 5) + bx + ti, lengthB));
LDS_[B_OFFSET + (ti + BWriteIndex) + 6 * CACHE_WIDTH_B] = B.FastGet(min(strideB * (i + 6) + bx + ti, lengthB));
LDS_[B_OFFSET + (ti + BWriteIndex) + 7 * CACHE_WIDTH_B] = B.FastGet(min(strideB * (i + 7) + bx + ti, lengthB));
GroupMemoryBarrierWithGroupSync();
//uint ptrA = A_OFFSET + (ti/8) * 8;
//uint ptrB = B_OFFSET + (ti%8) * 8 + readIndex;
uint ptrA = A_OFFSET + (ti & 0x78);
uint ptrB = ((ti & 7) << 3);
ptrB += (ti & 0x4) >> 1; // ptrB += (ptrB > 31) ? CACHE_WIDTH_B_PAD : 0;
ptrB += B_OFFSET;
float srcA[BLOCK_SIZE];
float srcB[BLOCK_SIZE];
// Rank-1 update per cached K row: outer product of srcA and srcB into dstO.
[unroll] for (uint tg_CacheExecuteIdx = 0; tg_CacheExecuteIdx < CACHE_DEPTH; tg_CacheExecuteIdx++)
{
srcA[0] = LDS_[ptrA | 0];
srcA[1] = LDS_[ptrA | 1];
srcA[2] = LDS_[ptrA | 2];
srcA[3] = LDS_[ptrA | 3];
srcA[4] = LDS_[ptrA | 4];
srcA[5] = LDS_[ptrA | 5];
srcA[6] = LDS_[ptrA | 6];
srcA[7] = LDS_[ptrA | 7];
srcB[0] = LDS_[ptrB + 0];
srcB[1] = LDS_[ptrB + 1];
srcB[2] = LDS_[ptrB + 2];
srcB[3] = LDS_[ptrB + 3];
srcB[4] = LDS_[ptrB + 4];
srcB[5] = LDS_[ptrB + 5];
srcB[6] = LDS_[ptrB + 6];
srcB[7] = LDS_[ptrB + 7];
ptrA += CACHE_WIDTH_A;
ptrB += CACHE_WIDTH_B;
[unroll] for (tg_A = 0; tg_A < BLOCK_SIZE; ++tg_A)
[unroll] for (tg_B = 0; tg_B < BLOCK_SIZE; ++tg_B)
dstO[tg_A*BLOCK_SIZE + tg_B] = ffma(srcA[tg_A], srcB[tg_B], dstO[tg_A*BLOCK_SIZE + tg_B]);
}
// Second barrier: the tile must not be refilled while threads still read it.
GroupMemoryBarrierWithGroupSync();
}
#if CHANNELS_FIRST
// NCHW write-out: direct scatter of the 8x8 register block, bounds-checked.
[unroll] for (tg_A = 0; tg_A < BLOCK_SIZE; ++tg_A)
[unroll] for (tg_B = 0; tg_B < BLOCK_SIZE; ++tg_B)
{
uint writeAId = ((bx + 8 * (ti % 8)) + tg_A);
uint writeBId = ((by + 8 * (ti / 8)) + tg_B);
if (writeBId < O.channels && writeAId < O.width)
O.FastSet(dzO + writeAId + O.width * writeBId, dstO[BLOCK_SIZE * tg_B + tg_A]);
}
#else
// NHWC write-out: round-trip the accumulators through LDS (reusing the tile
// storage, which is why the barriers above matter) so the DDR stores coalesce.
[unroll] for (uint tg_AOffset = 0; tg_AOffset < BLOCK_SIZE; tg_AOffset += 2)
{
[unroll] for (tg_A = 0; tg_A < 2; ++tg_A)
[unroll] for (tg_B = 0; tg_B < BLOCK_SIZE; ++tg_B)
{
//To avoid bank conflict store in 32 groups [8pixelsGroups,4channelsGroups] each group contain 64 values [8pixels,8kernels] for a total of 2048 values [64pixels,32channels]
uint ldsOffsetOfGroup = 65 * (tg_A*BLOCK_SIZE + tg_B);//65 * ([0,1]*8+[0,7]) = [0,975]
LDS_[ldsOffsetOfGroup + ti] = dstO[BLOCK_SIZE * tg_B + (tg_AOffset + tg_A)];
}
GroupMemoryBarrierWithGroupSync();
[unroll] for (tg_A = 0; tg_A < 16; ++tg_A)
{
// (((tg_A % 4) * 8) + (ti % 8)) * CACHE_WIDTH_A
uint ldsOffsetOfGroup = 65 * (((tg_A & 1) << 3) + (ti & 7));//65 * ([0,1]*8+[0,7]) = [0,975]
// (ti / 8) * 8 + (tg_A / 4)
uint ldsOffsetInGroup = (ti & 0x78) + (tg_A >> 1);//[0,7]*8+[0,7] = [0,63]
//load from LDS and store to DDR
uint readIndex = ldsOffsetOfGroup + ldsOffsetInGroup;//[0,2047]
// bx + tg_!%4 + (tgA/4)*8 + tg_AOffset
uint writeXId = bx + (tg_A & 1) + ((tg_A >> 1) << 3) + tg_AOffset;
uint writeIndex = dzO + O.channels * writeXId + (by + ti);
if ((by+ti) < O.channels && writeXId < O.width)
O.FastSet(writeIndex, LDS_[readIndex]);
}
}
#endif
}
#endif
#undef CACHE_DEPTH
#undef KERNEL_NAME
#endif
#if BLOCK_SIZE == 4
#define KERNEL_NAME MultidimMatMul_T16x16_R
#define CACHE_DEPTH 16
groupshared float CACHE_NAME(KERNEL_NAME, BLOCK_SIZE, LDS)[2*CACHE_DEPTH*16*BLOCK_SIZE];
[numthreads(16, 16, 1)]
// 16x16-thread-group variant: each of the 256 threads accumulates a 4x4
// register block (BLOCK_SIZE=4), so a group computes a 64x64 output tile.
// NOTE(review): unlike the 8x8 variant above, the A/B loads here are neither
// masked nor clamped, so this path presumably requires dimensions to be
// multiples of the tile size — confirm at the dispatch site.
void FUNC_NAME(KERNEL_NAME, BLOCK_SIZE, RANKA, RANKB)(uint3 groupID : SV_GroupID, uint threadIndex : SV_GroupIndex)
{
TENSOR_ARGS3(A, B, O);
uint ti = threadIndex;
// Top-left corner of this group's 64x64 output tile.
uint bx = groupID.x * 16 * BLOCK_SIZE;
uint by = groupID.y * 16 * BLOCK_SIZE;
uint n = A.width;
uint strideA = A.channels;
uint strideB = B.GetFlatWidth();
#define LDS_ CACHE_NAME(KERNEL_NAME, BLOCK_SIZE, LDS)
#define A_OFFSET 0
#define B_OFFSET CACHE_DEPTH*16*BLOCK_SIZE
// Zero the 4x4 per-thread accumulator (fully unrolled by hand).
float dstO[BLOCK_SIZE*BLOCK_SIZE];
dstO[0 * BLOCK_SIZE + 0] = 0;
dstO[0 * BLOCK_SIZE + 1] = 0;
dstO[0 * BLOCK_SIZE + 2] = 0;
dstO[0 * BLOCK_SIZE + 3] = 0;
dstO[1 * BLOCK_SIZE + 0] = 0;
dstO[1 * BLOCK_SIZE + 1] = 0;
dstO[1 * BLOCK_SIZE + 2] = 0;
dstO[1 * BLOCK_SIZE + 3] = 0;
dstO[2 * BLOCK_SIZE + 0] = 0;
dstO[2 * BLOCK_SIZE + 1] = 0;
dstO[2 * BLOCK_SIZE + 2] = 0;
dstO[2 * BLOCK_SIZE + 3] = 0;
dstO[3 * BLOCK_SIZE + 0] = 0;
dstO[3 * BLOCK_SIZE + 1] = 0;
dstO[3 * BLOCK_SIZE + 2] = 0;
dstO[3 * BLOCK_SIZE + 3] = 0;
// Precomputed bit-twiddled pieces of the thread index used by the LDS addressing.
uint tiD64M64 = (ti & 0x3c0);
uint tiMod4M16 = ((ti & 3) << 4);
uint tiMod64 = (ti & 63);
uint tiMod64D4 = (tiMod64 >> 2);
uint tiD64 = (ti >> 6);
// March over the shared K dimension CACHE_DEPTH rows per pass.
for (uint i = 0; i < n; i += CACHE_DEPTH)
{
//LDS_[B_OFFSET + ((ti/64)*64) + ((ti%4)*16) + ((ti%64)/4) + 16*16*[0..3]] = B.FastGet(strideB * (i + (ti / 64) + 4*[0..3]) + bx + (ti % 64));
LDS_[B_OFFSET + tiD64M64 + tiMod4M16 + tiMod64D4 + 16 * 16 * 0] = B.FastGet(strideB * (i + tiD64 + 4 * 0) + bx + tiMod64);
LDS_[B_OFFSET + tiD64M64 + tiMod4M16 + tiMod64D4 + 16 * 16 * 1] = B.FastGet(strideB * (i + tiD64 + 4 * 1) + bx + tiMod64);
LDS_[B_OFFSET + tiD64M64 + tiMod4M16 + tiMod64D4 + 16 * 16 * 2] = B.FastGet(strideB * (i + tiD64 + 4 * 2) + bx + tiMod64);
LDS_[B_OFFSET + tiD64M64 + tiMod4M16 + tiMod64D4 + 16 * 16 * 3] = B.FastGet(strideB * (i + tiD64 + 4 * 3) + bx + tiMod64);
//LDS_[A_OFFSET + ti + 16 * 16 * [0..3]] = A.FastGet((by + (ti % 64)) + strideA * (i + (ti / 64) + 4 * [0..3]));
LDS_[A_OFFSET + ti + 16*16*0] = A.FastGet((by + tiMod64) + strideA * (i + tiD64 + 4*0));
LDS_[A_OFFSET + ti + 16*16*1] = A.FastGet((by + tiMod64) + strideA * (i + tiD64 + 4*1));
LDS_[A_OFFSET + ti + 16*16*2] = A.FastGet((by + tiMod64) + strideA * (i + tiD64 + 4*2));
LDS_[A_OFFSET + ti + 16*16*3] = A.FastGet((by + tiMod64) + strideA * (i + tiD64 + 4*3));
GroupMemoryBarrierWithGroupSync();
uint ptrA = (ti >> 4) << 2;
uint ptrB = B_OFFSET + (ti&15);
float srcA[BLOCK_SIZE];
float srcB[BLOCK_SIZE];
// Rank-1 update per cached K row: 4x4 outer product of srcA and srcB into dstO.
for (uint tg_CacheExecuteIdx = 0; tg_CacheExecuteIdx < CACHE_DEPTH; tg_CacheExecuteIdx++)
{
srcA[0] = LDS_[ptrA | 0];
srcA[1] = LDS_[ptrA | 1];
srcA[2] = LDS_[ptrA | 2];
srcA[3] = LDS_[ptrA | 3];
srcB[0] = LDS_[ptrB | 0*16];
srcB[1] = LDS_[ptrB | 1*16];
srcB[2] = LDS_[ptrB | 2*16];
srcB[3] = LDS_[ptrB | 3*16];
ptrA += 64;
ptrB += 64;
dstO[0 * BLOCK_SIZE + 0] = ffma(srcA[0], srcB[0], dstO[0 * BLOCK_SIZE + 0]);
dstO[0 * BLOCK_SIZE + 1] = ffma(srcA[0], srcB[1], dstO[0 * BLOCK_SIZE + 1]);
dstO[0 * BLOCK_SIZE + 2] = ffma(srcA[0], srcB[2], dstO[0 * BLOCK_SIZE + 2]);
dstO[0 * BLOCK_SIZE + 3] = ffma(srcA[0], srcB[3], dstO[0 * BLOCK_SIZE + 3]);
dstO[1 * BLOCK_SIZE + 0] = ffma(srcA[1], srcB[0], dstO[1 * BLOCK_SIZE + 0]);
dstO[1 * BLOCK_SIZE + 1] = ffma(srcA[1], srcB[1], dstO[1 * BLOCK_SIZE + 1]);
dstO[1 * BLOCK_SIZE + 2] = ffma(srcA[1], srcB[2], dstO[1 * BLOCK_SIZE + 2]);
dstO[1 * BLOCK_SIZE + 3] = ffma(srcA[1], srcB[3], dstO[1 * BLOCK_SIZE + 3]);
dstO[2 * BLOCK_SIZE + 0] = ffma(srcA[2], srcB[0], dstO[2 * BLOCK_SIZE + 0]);
dstO[2 * BLOCK_SIZE + 1] = ffma(srcA[2], srcB[1], dstO[2 * BLOCK_SIZE + 1]);
dstO[2 * BLOCK_SIZE + 2] = ffma(srcA[2], srcB[2], dstO[2 * BLOCK_SIZE + 2]);
dstO[2 * BLOCK_SIZE + 3] = ffma(srcA[2], srcB[3], dstO[2 * BLOCK_SIZE + 3]);
dstO[3 * BLOCK_SIZE + 0] = ffma(srcA[3], srcB[0], dstO[3 * BLOCK_SIZE + 0]);
dstO[3 * BLOCK_SIZE + 1] = ffma(srcA[3], srcB[1], dstO[3 * BLOCK_SIZE + 1]);
dstO[3 * BLOCK_SIZE + 2] = ffma(srcA[3], srcB[2], dstO[3 * BLOCK_SIZE + 2]);
dstO[3 * BLOCK_SIZE + 3] = ffma(srcA[3], srcB[3], dstO[3 * BLOCK_SIZE + 3]);
}
// Second barrier: the tile must not be refilled while threads still read it.
GroupMemoryBarrierWithGroupSync();
}
// Write-out: accumulators are round-tripped through LDS (reusing tile storage,
// which is why the barriers above matter) so DDR stores are coalesced.
for (uint tg_registerChannelOffset = 0; tg_registerChannelOffset < BLOCK_SIZE; tg_registerChannelOffset += 2)
{
uint tg_kId;
uint tg_pId;
//Store 4 pixels x 2 channels per threads to LDS.
[unroll] for (tg_kId = 0; tg_kId < 2; ++tg_kId)
[unroll] for (tg_pId = 0; tg_pId < BLOCK_SIZE; ++tg_pId)
{
LDS_[64 * ((threadIndex % 16) * 2 + tg_kId) + (threadIndex / 16) * BLOCK_SIZE + tg_pId] = dstO[tg_pId * BLOCK_SIZE + (tg_registerChannelOffset + tg_kId)];
}
GroupMemoryBarrierWithGroupSync();
//We have a buffers of [64pixels,32channels] floats, each thread will store [1pixels,8channels] so a threadgroup is storing 64 pixels and 4 channels at a time to DDR in a linear fashion.
uint writePixelId = by + (threadIndex % 64);
[unroll] for (tg_kId = 0; tg_kId < 32; tg_kId += 4)
{
uint readChannelId = tg_kId + (threadIndex / 64);
uint readIndex = 64 * readChannelId + (threadIndex % 64);
uint writeChannelId = bx + (readChannelId % 2) + (readChannelId / 2)*BLOCK_SIZE + tg_registerChannelOffset;
O.FastSet(writeChannelId * strideA + writePixelId, LDS_[readIndex]);
}
GroupMemoryBarrierWithGroupSync();
}
#undef A_
#undef B_
}
#undef CACHE_DEPTH
#undef KERNEL_NAME
#endif
#undef FUNC_NAME
#undef CACHE_NAME
#undef FUNC_NAME_CALL
#undef CACHE_NAME_CALL
#if CHANNELS_FIRST
#define FUNC_NAME_CALL(KERNEL, RANK1, RANK2) KERNEL##_AR##RANK1##_BR##RANK2##_NCHW
#define CACHE_NAME_CALL(KERNEL, TENSOR) KERNEL##_Cache_##TENSOR##_NCHW
#else
#define FUNC_NAME_CALL(KERNEL, RANK1, RANK2) KERNEL##_AR##RANK1##_BR##RANK2##_NHWC
#define CACHE_NAME_CALL(KERNEL, TENSOR) KERNEL##_Cache_##TENSOR##_NHWC
#endif
#define FUNC_NAME(KERNEL, RANK1, RANK2) FUNC_NAME_CALL(KERNEL, RANK1, RANK2)
#define CACHE_NAME(KERNEL, TENSOR) CACHE_NAME_CALL(KERNEL, TENSOR)
// NOTE: usually this path is used for <16 batches
#undef CACHESIZE
#undef LDS_
#undef X_OFFSET
#undef W_OFFSET
#define KERNEL_NAME MultidimMatMul_L1Cached64
#define CACHESIZE 64
// Per-threadgroup L1 cache: one 64-float strip of A staged in LDS per iteration.
groupshared float CACHE_NAME(KERNEL_NAME, LDS)[CACHESIZE];
[numthreads(64, 1, 1)]
// Batched matrix multiply O = A * B for arbitrary-rank operands (ranks baked in
// through the RANKA/RANKB macro expansion of the kernel name). Each thread
// computes one output element: x walks B's flat width, y is the A row, and
// groupID.z selects the batch slice. The 64-wide threadgroup cooperatively
// loads 64 consecutive elements of A's current row into LDS, then every thread
// dots that strip against its own column of B.
void FUNC_NAME(KERNEL_NAME, RANKA, RANKB)(uint3 groupID : SV_GroupID, uint3 groupThreadID : SV_GroupThreadID)
{
    //DISPATCH ARGS(O.flatWidth, O.flatHeight, 1);
    TENSOR_ARGS3(A, B, O);
    #define LDS_ CACHE_NAME(KERNEL_NAME, LDS)
    uint x = CACHESIZE * groupID.x + groupThreadID.x;  // output column index
    uint y = groupID.y;                                // output row index
    uint n = A.width;                                  // reduction length
    uint strideA = A.channels;
    uint strideB = B.GetFlatWidth();
    uint dzA = groupID.z * n * strideA;                // batch offset into A
    uint dzO = groupID.z * strideB * strideA;          // batch offset into O
    float acc = 0.0;
    // loop over X columns (flatWidth) and W rows (height) in CACHESIZE steps
    for (uint i = 0; i < n; i += CACHESIZE)
    {
        // Cache X
        // coalescent reads; out-of-range lanes fetch 0 through the mask
        bool maskA = (y < strideA) && (i + groupThreadID.x) < A.width;
        #if CHANNELS_FIRST
        LDS_[groupThreadID.x] = A.MaskedGet(maskA, dzA + y * A.width + (i + groupThreadID.x));
        #else
        LDS_[groupThreadID.x] = A.MaskedGet(maskA, dzA + (i + groupThreadID.x) * A.channels + y);
        #endif
        GroupMemoryBarrierWithGroupSync();
        // X * W: dot the cached strip of A against this thread's column of B
        [unroll]
        for (uint di = 0; di < CACHESIZE; ++di)
        {
            acc = fastfma(LDS_[di], B.MaskedGet(x < strideB && (i + di) < B.GetFlatHeight(), x + (i + di)*strideB), acc);
        }
        // Barrier before the next iteration overwrites the LDS strip.
        GroupMemoryBarrierWithGroupSync();
    }
    if ((x < O.width) && (y < O.channels))
    {
        #if CHANNELS_FIRST
        O.FastSet(dzO + y * O.width + x, acc);
        #else
        O.FastSet(dzO + x * O.channels + y, acc);
        #endif
    }
    #undef LDS_
}

View File

@@ -0,0 +1,8 @@
fileFormatVersion: 2
guid: 1892719d60b907b4eb8befb172f72544
ComputeShaderImporter:
externalObjects: {}
currentAPIMask: 4
userData:
assetBundleName:
assetBundleVariant:

View File

@@ -0,0 +1,166 @@
#pragma kernel Border2D_NHWC CHANNELS_FIRST=0
#pragma kernel Border2D_NCHW CHANNELS_FIRST=1
#pragma kernel Pad2DEdge_NHWC CHANNELS_FIRST=0
#pragma kernel Pad2DEdge_NCHW CHANNELS_FIRST=1
#pragma kernel Pad2DReflect_NHWC CHANNELS_FIRST=0
#pragma kernel Pad2DReflect_NCHW CHANNELS_FIRST=1
#pragma kernel Pad2DSymmetric_NHWC CHANNELS_FIRST=0
#pragma kernel Pad2DSymmetric_NCHW CHANNELS_FIRST=1
#include "Tensor.cginc"
TENSOR_DECL(X)
TENSOR_DECL(B)
TENSOR_DECL_RW(O)
uint4 _Pool;
uint4 _Stride;
uint4 _Pad;
float _Beta;
NUMTHREADS((4, 8, 8), (4, 8, 4), (4, 4, 4))
// Constant ("border") padding, or cropping when pads are negative.
// _Pool.xyz holds the cropped source extents (width, height, channels);
// _Pad.xyz is the left/top/front offset applied on W, H and C. Texels that
// fall outside the cropped source region are filled with the constant _Beta.
void KERNEL_FUNC(Border2D)(uint3 dispatchThreadID : SV_DispatchThreadID)
{
    //DISPATCH ARGS(O.channels, O.width, O.height);
    TENSOR_ARGS2(X, O);
    uint c = dispatchThreadID.x;
    uint x = dispatchThreadID.y;
    uint y = dispatchThreadID.z;
    if (c >= O.channels) return;
    if (x >= O.width) return;
    if (y >= O.height) return;
    // NOTE: negative "pad" variable crop X tensor
    int croppedWidth = _Pool.x;
    int croppedHeight = _Pool.y;
    int croppedChannels = _Pool.z;
    int readX = x - _Pad.x;
    int readY = y - _Pad.y;
    int readC = c - _Pad.z;
    // True when the output texel maps outside the (possibly cropped) source.
    bool paddedTexel = (readX < 0 || readX >= croppedWidth || readY < 0 || readY >= croppedHeight || readC < 0 || readC >= croppedChannels);
    // Same spatial/channel mapping applies to every batch element.
    for (uint n = 0; n < O.batch; ++n)
    {
        float v = _Beta;
        if (!paddedTexel)
            v = X.Get(n, readY, readX, readC);
        O.Set(n, y, x, c, v);
    }
}
// Clamps a (height, width) read position into the valid index range of a
// tensor whose extents are given as shape = (width, height).
void ClampHWToTensorShape(uint2 shape, inout int height, inout int width)
{
    int lastX = (int)shape.x - 1;
    int lastY = (int)shape.y - 1;
    height = clamp(height, 0, lastY);
    width = clamp(width, 0, lastX);
}
NUMTHREADS((4, 8, 8), (4, 8, 4), (4, 4, 4))
// Edge ("replicate") padding: reads that fall outside the source are clamped
// to the nearest border texel of X. _Pad.xy is the left/top padding on W/H.
void KERNEL_FUNC(Pad2DEdge)(uint3 dispatchThreadID : SV_DispatchThreadID)
{
    //DISPATCH ARGS(O.channels, O.width, O.height);
    TENSOR_ARGS2(X, O);
    uint c = dispatchThreadID.x;
    uint x = dispatchThreadID.y;
    uint y = dispatchThreadID.z;
    if (c >= O.channels || x >= O.width || y >= O.height)
        return;
    int srcX = (int)x - (int)_Pad.x;
    int srcY = (int)y - (int)_Pad.y;
    // Replicate the border by clamping the read position into X's extents.
    ClampHWToTensorShape(uint2(X.width, X.height), srcY, srcX);
    // Copy the (clamped) source texel across every batch element.
    for (uint n = 0; n < O.batch; ++n)
        O.Set(n, y, x, c, X.Get(n, srcY, srcX, c));
}
NUMTHREADS((4, 8, 8), (4, 8, 4), (4, 4, 4))
// Reflect padding: out-of-range reads mirror about the edge texel without
// repeating the edge itself (... x2 x1 | x0 x1 x2 ...).
// NOTE(review): only a single reflection step is applied, so this assumes
// each pad amount is smaller than the matching source extent — TODO confirm.
void KERNEL_FUNC(Pad2DReflect)(uint3 dispatchThreadID : SV_DispatchThreadID)
{
    //DISPATCH ARGS(O.channels, O.width, O.height);
    TENSOR_ARGS2(X, O);
    uint c = dispatchThreadID.x;
    uint x = dispatchThreadID.y;
    uint y = dispatchThreadID.z;
    if (c >= O.channels) return;
    if (x >= O.width) return;
    if (y >= O.height) return;
    int readX = x - _Pad.x;
    int readY = y - _Pad.y;
    uint2 Xshape = uint2(X.width, X.height);
    int lastXIndex = Xshape.x - 1;
    int lastYIndex = Xshape.y - 1;
    //x reflect indexing (mirror about index 0 and about the last index)
    readX = (readX < 0) ? -readX : readX;
    readX = (readX > lastXIndex) ? lastXIndex - (readX - lastXIndex) : readX;
    //y reflect indexing
    readY = (readY < 0) ? -readY : readY;
    readY = (readY > lastYIndex) ? lastYIndex - (readY - lastYIndex) : readY;
    //clamp read indices to source (safety net if a reflected index is still out of range)
    ClampHWToTensorShape(Xshape, readY, readX);
    for (uint n = 0; n < O.batch; ++n)
    {
        float v = X.Get(n, readY, readX, c);
        O.Set(n, y, x, c, v);
    }
}
NUMTHREADS((4, 8, 8), (4, 8, 4), (4, 4, 4))
// Symmetric padding: like reflect, but the mirror includes the border texel
// itself (... x1 x0 | x0 x1 ...) — hence the extra -1/+1 offsets below.
// NOTE(review): only a single reflection step is applied, so this assumes
// each pad amount is at most the matching source extent — TODO confirm.
void KERNEL_FUNC(Pad2DSymmetric)(uint3 dispatchThreadID : SV_DispatchThreadID)
{
    //DISPATCH ARGS(O.channels, O.width, O.height);
    TENSOR_ARGS2(X, O);
    uint c = dispatchThreadID.x;
    uint x = dispatchThreadID.y;
    uint y = dispatchThreadID.z;
    if (c >= O.channels) return;
    if (x >= O.width) return;
    if (y >= O.height) return;
    int readX = x - _Pad.x;
    int readY = y - _Pad.y;
    uint2 Xshape = uint2(X.width, X.height);
    int lastXIndex = Xshape.x - 1;
    int lastYIndex = Xshape.y - 1;
    //x reflect indexing (border-inclusive mirror)
    readX = (readX < 0) ? -readX - 1: readX;
    readX = (readX > lastXIndex) ? lastXIndex - (readX - lastXIndex) + 1: readX;
    //y reflect indexing
    readY = (readY < 0) ? -readY - 1: readY;
    readY = (readY > lastYIndex) ? lastYIndex - (readY - lastYIndex) + 1: readY;
    //clamp read indices to source (safety net if a mirrored index is still out of range)
    ClampHWToTensorShape(Xshape, readY, readX);
    for (uint n = 0; n < O.batch; ++n)
    {
        float v = X.Get(n, readY, readX, c);
        O.Set(n, y, x, c, v);
    }
}

View File

@@ -0,0 +1,8 @@
fileFormatVersion: 2
guid: cf52068b3397856488e3ec8c94fa02ef
ComputeShaderImporter:
externalObjects: {}
currentAPIMask: 4
userData:
assetBundleName:
assetBundleVariant:

View File

@@ -0,0 +1,8 @@
fileFormatVersion: 2
guid: c23201977ed5ef64885111460f407afb
folderAsset: yes
DefaultImporter:
externalObjects: {}
userData:
assetBundleName:
assetBundleVariant:

View File

@@ -0,0 +1,199 @@
Shader "Barracuda/Activation"
{
Properties
{
}
SubShader
{
// No culling or depth
Cull Off ZWrite Off ZTest Always
Pass
{
CGPROGRAM
#pragma multi_compile None Relu Selu Abs Neg Ceil Floor Round Reciprocal Swish Tanh Softplus Sigmoid HardSigmoid Relu6 Elu LeakyRelu Exp Log Sqrt Acos Acosh Asin Asinh Atan Atanh Cos Cosh Sin Sinh Tan Pow Clip Erf Sign LogicalNot
#pragma vertex vert
#pragma fragment frag
#include "CommonVertexShader.cginc"
#include "TensorTexture.cginc"
TENSOR_DECL_O(O)
TENSOR_DECL(X)
// pow() that stays well defined for negative bases: the magnitude is raised
// to the exponent and the base's sign is reapplied when the exponent is odd.
float signed_pow(float f, float e)
{
    float magnitude = pow(abs(f), e);
    if (e % 2 == 1)
        return sign(f) * magnitude; // odd exponent keeps the sign of f
    return magnitude;               // even (or fractional) exponent: non-negative result
}
// Error function via the Abramowitz & Stegun rational approximation
// (Handbook of Mathematical Functions, formula 7.1.26).
float erf(float v)
{
    // Abramowitz/Stegun approximations
    // erf(x) = -erf(-x)  -> evaluate on |v| and restore the sign at the end
    float x = abs(v);
    float p = 0.3275911f;
    float a1 = 0.254829592f; float a2 = -0.284496736f; float a3 = 1.421413741f;
    float a4 = -1.453152027f; float a5 = 1.061405429f;
    float t = 1.0f / (1.0f + p * x);
    float t2 = t * t;
    float t3 = t2 * t;
    float t4 = t3 * t;
    float t5 = t4 * t;
    return sign(v)*(1 - (a1*t + a2 * t2 + a3 * t3 + a4 * t4 + a5 * t5)*exp(-x * x));
}
float _Alpha;
float _Beta;
// Applies the activation selected through the multi_compile keyword to every
// element of X. _Alpha and _Beta parameterize the activations that use them
// (LeakyRelu slope, Clip bounds, HardSigmoid scale/offset, Pow exponent, ...).
fixed4 frag (v2f i) : SV_Target
{
    TENSOR_ARGS2(X, O);
    uint n, h, w, c4;
    O.GetPositionFromUV(i.uv, n, h, w, c4);
    float4 v = X.Get4(n, h, w, c4);
    #ifdef Relu
    v = 0.5f * (v + abs(v)); // branch-free max(v, 0)
    #endif
    #ifdef Selu
    v = _Beta * (max(v, 0.0f) + min(_Alpha * (exp(v) - 1.0f), 0.0f));
    #endif
    #ifdef Abs
    v = abs(v);
    #endif
    #ifdef Neg
    v = -v;
    #endif
    #ifdef Ceil
    v = ceil(v);
    #endif
    #ifdef Floor
    v = floor(v);
    #endif
    #ifdef Round
    v = round(v);
    #endif
    #ifdef Reciprocal
    v = 1.0f / v;
    #endif
    #ifdef Swish
    v = v / (1 + exp(-v));
    #endif
    #ifdef Tanh
    v = tanh(clamp(v,-16.0f,16.0f));//clamp to avoid NaNs for large values.
    #endif
    #ifdef Softplus
    v = log(exp(v) + 1);
    #endif
    #ifdef Sigmoid
    v = 1 / (1 + exp(-v));
    #endif
    #ifdef HardSigmoid
    v = max(0.0f, min(1.0f, _Alpha * v + _Beta));
    #endif
    #ifdef Relu6
    v = min(max(0, v), 6);
    #endif
    #ifdef Elu
    // per-lane select: exponential branch only for non-positive inputs
    if (v.x <= 0)
        v.x = _Alpha * (exp(v.x) - 1);
    if (v.y <= 0)
        v.y = _Alpha * (exp(v.y) - 1);
    if (v.z <= 0)
        v.z = _Alpha * (exp(v.z) - 1);
    if (v.w <= 0)
        v.w = _Alpha * (exp(v.w) - 1);
    #endif
    #ifdef LeakyRelu
    v = max(v, _Alpha * v);
    #endif
    #ifdef Exp
    v = exp(v);
    #endif
    #ifdef Log
    v = log(v);
    #endif
    #ifdef Sqrt
    v = sqrt(v);
    #endif
    #ifdef Acos
    v = acos(v);
    #endif
    #ifdef Acosh
    v = log(v + sqrt(v * v - 1.0f));
    #endif
    #ifdef Asin
    v = asin(v);
    #endif
    #ifdef Asinh
    v = log(v + sqrt(v*v + 1.0f));
    #endif
    #ifdef Atan
    v = atan(v);
    #endif
    #ifdef Atanh
    v = 0.5f * log((1.0f + v) / (1.0f - v));
    #endif
    #ifdef Cos
    v = cos(v);
    #endif
    #ifdef Cosh
    v = 0.5f * (exp(v) + exp(-v));
    #endif
    #ifdef Sin
    v = sin(v);
    #endif
    #ifdef Sinh
    v = 0.5f * (exp(v) - exp(-v));
    #endif
    #ifdef Tan
    v = tan(v);
    #endif
    #ifdef Pow
    // signed_pow handles negative bases; see helper above
    v.x = signed_pow(v.x, _Alpha);
    v.y = signed_pow(v.y, _Alpha);
    v.z = signed_pow(v.z, _Alpha);
    v.w = signed_pow(v.w, _Alpha);
    #endif
    #ifdef Clip
    v = clamp(v, _Alpha, _Beta);
    #endif
    #ifdef Erf
    v.x = erf(v.x);
    v.y = erf(v.y);
    v.z = erf(v.z);
    v.w = erf(v.w);
    #endif
    #ifdef Sign
    v = sign(v);
    #endif
    #ifdef LogicalNot
    v = (v == 0.0f) ? 1.0f : 0.0f;
    #endif
    // Packed lanes past X.channels are padding; force them to zero so the
    // texture stays clean for downstream reads.
    if (4 * c4 >= X.channels)
        v.x = 0.0f;
    if (4 * c4 + 1 >= X.channels)
        v.y = 0.0f;
    if (4 * c4 + 2 >= X.channels)
        v.z = 0.0f;
    if (4 * c4 + 3 >= X.channels)
        v.w = 0.0f;
    return v;
}
ENDCG
}
}
}

View File

@@ -0,0 +1,9 @@
fileFormatVersion: 2
guid: 9626ea9ab0b94e94a95ddbd110d29e78
ShaderImporter:
externalObjects: {}
defaultTextures: []
nonModifiableTextures: []
userData:
assetBundleName:
assetBundleVariant:

View File

@@ -0,0 +1,60 @@
Shader "Barracuda/AvgPool2D"
{
Properties
{
}
SubShader
{
// No culling or depth
Cull Off ZWrite Off ZTest Always
Pass
{
CGPROGRAM
#pragma vertex vert
#pragma fragment frag
#include "CommonVertexShader.cginc"
#include "TensorTexture.cginc"
TENSOR_DECL_O(O)
TENSOR_DECL(X)
uint4 _Pool;
uint4 _Pad;
uint4 _Stride;
// Average pooling over a _Pool.xy window with _Stride.xy stepping and
// _Pad.xy zero-padding. Padded taps are excluded from the average: only
// in-bounds samples increment `counter` (count_include_pad == false).
// NOTE(review): a window lying entirely in the padding would leave
// counter == 0 and divide by zero — presumably the dispatched pool/pad
// combinations never produce such a window; TODO confirm.
fixed4 frag (v2f i) : SV_Target
{
    TENSOR_ARGS2(X, O);
    uint n, h, w, c4;
    O.GetPositionFromUV(i.uv, n, h, w, c4);
    uint2 leftCorner = _Pad.xy;
    uint2 rightCorner = uint2(X.width, X.height) + _Pad.xy;
    float4 acc4 = 0;
    float counter = 0;
    for (uint dy = 0; dy < _Pool.y; ++dy)
    for (uint dx = 0; dx < _Pool.x; ++dx)
    {
        uint oy = h * _Stride.y + dy;
        uint ox = w * _Stride.x + dx;
        bool mask = (oy >= leftCorner.y) && (ox >= leftCorner.x) && (oy < rightCorner.y) && (ox < rightCorner.x);
        // min() keeps the read in-bounds; masked-out taps contribute 0 anyway
        acc4 += (mask) ? X.Get4(n, min(oy - leftCorner.y, X.height - 1), min(ox - leftCorner.x, X.width - 1), c4) : 0;
        counter += (mask) ? 1 : 0;
    }
    acc4 /= counter;
    return acc4;
}
ENDCG
}
}
}

View File

@@ -0,0 +1,9 @@
fileFormatVersion: 2
guid: 0c6f0ed2e703bff4dae9ffaf72e4d67f
ShaderImporter:
externalObjects: {}
defaultTextures: []
nonModifiableTextures: []
userData:
assetBundleName:
assetBundleVariant:

View File

@@ -0,0 +1,58 @@
Shader "Barracuda/Border2D"
{
Properties
{
}
SubShader
{
// No culling or depth
Cull Off ZWrite Off ZTest Always
Pass
{
CGPROGRAM
#pragma vertex vert
#pragma fragment frag
#include "CommonVertexShader.cginc"
#include "TensorTexture.cginc"
TENSOR_DECL_O(O)
TENSOR_DECL(X)
int4 _Pad;
int4 _Pool;
float _Beta;
// Constant ("border") padding / cropping for the texture-backed path.
// _Pool.xy holds the cropped source extents; reads that land outside them
// produce the constant fill value _Beta.
fixed4 frag (v2f i) : SV_Target
{
    TENSOR_ARGS2(X, O);
    uint n, h, w, c4;
    O.GetPositionFromUV(i.uv, n, h, w, c4);
    int srcX = (int)(w) - _Pad.x;
    int srcY = (int)(h) - _Pad.y;
    bool inside = srcX >= 0 && srcX < _Pool.x &&
                  srcY >= 0 && srcY < _Pool.y;
    float4 v = _Beta;
    if (inside)
        v = X.Get4(n, srcY, srcX, c4);
    return v;
}
ENDCG
}
}
}

View File

@@ -0,0 +1,9 @@
fileFormatVersion: 2
guid: e7bb264f71a76b64ca7a26148d7c18fd
ShaderImporter:
externalObjects: {}
defaultTextures: []
nonModifiableTextures: []
userData:
assetBundleName:
assetBundleVariant:

View File

@@ -0,0 +1,191 @@
Shader "Barracuda/Broadcast"
{
Properties
{
}
SubShader
{
// No culling or depth
Cull Off ZWrite Off ZTest Always
Pass
{
CGPROGRAM
#pragma multi_compile Sub Pow Mul Min Mean Max LogicalXor LogicalOr LogicalAnd LessEqual Less GreaterEqual Greater Equal Div Add
#pragma vertex vert
#pragma fragment frag
#include "CommonVertexShader.cginc"
#include "TensorTexture.cginc"
TENSOR_DECL_O(O)
TENSOR_DECL(X)
TENSOR_DECL(B)
int _IsFirstDispatch;
float _Alpha;
// pow() that stays well defined for negative bases: the magnitude is raised
// to the exponent and the base's sign is reapplied when the exponent is odd.
float signed_pow(float f, float e)
{
    float magnitude = pow(abs(f), e);
    if (e % 2 == 1)
        return sign(f) * magnitude; // odd exponent keeps the sign of f
    return magnitude;               // even (or fractional) exponent: non-negative result
}
// Elementwise binary op with NumPy-style broadcasting between X and B,
// selected through the multi_compile keyword. For Mean, _Alpha carries the
// averaging weight and _IsFirstDispatch scales X only on the first pass of
// a chained multi-input reduction.
fixed4 frag (v2f i) : SV_Target
{
    TENSOR_ARGS3(X, B, O);
    uint n, h, w, c4;
    O.GetPositionFromUV(i.uv, n, h, w, c4);
    float4 v = 0.0;
    #ifdef Sub
    v = X.BroadcastGet4(n, h, w, c4) - B.BroadcastGet4(n, h, w, c4);
    #endif
    #ifdef Pow
    float4 a = X.BroadcastGet4(n, h, w, c4);
    float4 b = B.BroadcastGet4(n, h, w, c4);
    v.x = signed_pow(a.x, b.x);
    v.y = signed_pow(a.y, b.y);
    v.z = signed_pow(a.z, b.z);
    v.w = signed_pow(a.w, b.w);
    #endif
    #ifdef Mul
    v = X.BroadcastGet4(n, h, w, c4) * B.BroadcastGet4(n, h, w, c4);
    #endif
    #ifdef Min
    v = min(X.BroadcastGet4(n, h, w, c4), B.BroadcastGet4(n, h, w, c4));
    #endif
    #ifdef Mean
    float4 a = X.BroadcastGet4(n, h, w, c4);
    a *= _IsFirstDispatch ? _Alpha : 1.0f; // X is pre-scaled only on the first pass
    float4 b = B.BroadcastGet4(n, h, w, c4) * _Alpha;
    v = a + b;
    #endif
    #ifdef Max
    v = max(X.BroadcastGet4(n, h, w, c4), B.BroadcastGet4(n, h, w, c4));
    #endif
    #ifdef LogicalXor
    // coerce both operands to {0,1}, then xor(a,b) = a*(1-2b) + b
    float4 a = X.BroadcastGet4(n, h, w, c4);
    float4 b = B.BroadcastGet4(n, h, w, c4);
    a.x = (a.x == 0.0f) ? 0.0f : 1.0f;
    a.y = (a.y == 0.0f) ? 0.0f : 1.0f;
    a.z = (a.z == 0.0f) ? 0.0f : 1.0f;
    a.w = (a.w == 0.0f) ? 0.0f : 1.0f;
    b.x = (b.x == 0.0f) ? 0.0f : 1.0f;
    b.y = (b.y == 0.0f) ? 0.0f : 1.0f;
    b.z = (b.z == 0.0f) ? 0.0f : 1.0f;
    b.w = (b.w == 0.0f) ? 0.0f : 1.0f;
    v = a * (1 - 2 * b) + b;
    #endif
    #ifdef LogicalOr
    // coerce both operands to {0,1}, then or(a,b) = a*(1-b) + b
    float4 a = X.BroadcastGet4(n, h, w, c4);
    float4 b = B.BroadcastGet4(n, h, w, c4);
    a.x = (a.x == 0.0f) ? 0.0f : 1.0f;
    a.y = (a.y == 0.0f) ? 0.0f : 1.0f;
    a.z = (a.z == 0.0f) ? 0.0f : 1.0f;
    a.w = (a.w == 0.0f) ? 0.0f : 1.0f;
    b.x = (b.x == 0.0f) ? 0.0f : 1.0f;
    b.y = (b.y == 0.0f) ? 0.0f : 1.0f;
    b.z = (b.z == 0.0f) ? 0.0f : 1.0f;
    b.w = (b.w == 0.0f) ? 0.0f : 1.0f;
    v = a * (1 - b) + b;
    #endif
    #ifdef LogicalAnd
    // coerce both operands to {0,1}, then and(a,b) = (a*b != 0)
    float4 a = X.BroadcastGet4(n, h, w, c4);
    float4 b = B.BroadcastGet4(n, h, w, c4);
    a.x = (a.x == 0.0f) ? 0.0f : 1.0f;
    a.y = (a.y == 0.0f) ? 0.0f : 1.0f;
    a.z = (a.z == 0.0f) ? 0.0f : 1.0f;
    a.w = (a.w == 0.0f) ? 0.0f : 1.0f;
    b.x = (b.x == 0.0f) ? 0.0f : 1.0f;
    b.y = (b.y == 0.0f) ? 0.0f : 1.0f;
    b.z = (b.z == 0.0f) ? 0.0f : 1.0f;
    b.w = (b.w == 0.0f) ? 0.0f : 1.0f;
    v.x = a.x * b.x != 0.0 ? 1.0f : 0.0f;
    v.y = a.y * b.y != 0.0 ? 1.0f : 0.0f;
    v.z = a.z * b.z != 0.0 ? 1.0f : 0.0f;
    v.w = a.w * b.w != 0.0 ? 1.0f : 0.0f;
    #endif
    #ifdef LessEqual
    float4 a = X.BroadcastGet4(n, h, w, c4);
    float4 b = B.BroadcastGet4(n, h, w, c4);
    v.x = (a.x <= b.x) ? 1.0f : 0.0f;
    v.y = (a.y <= b.y) ? 1.0f : 0.0f;
    v.z = (a.z <= b.z) ? 1.0f : 0.0f;
    v.w = (a.w <= b.w) ? 1.0f : 0.0f;
    #endif
    #ifdef Less
    float4 a = X.BroadcastGet4(n, h, w, c4);
    float4 b = B.BroadcastGet4(n, h, w, c4);
    v.x = (a.x < b.x) ? 1.0f : 0.0f;
    v.y = (a.y < b.y) ? 1.0f : 0.0f;
    v.z = (a.z < b.z) ? 1.0f : 0.0f;
    v.w = (a.w < b.w) ? 1.0f : 0.0f;
    #endif
    #ifdef GreaterEqual
    float4 a = X.BroadcastGet4(n, h, w, c4);
    float4 b = B.BroadcastGet4(n, h, w, c4);
    v.x = (a.x >= b.x) ? 1.0f : 0.0f;
    v.y = (a.y >= b.y) ? 1.0f : 0.0f;
    v.z = (a.z >= b.z) ? 1.0f : 0.0f;
    v.w = (a.w >= b.w) ? 1.0f : 0.0f;
    #endif
    #ifdef Greater
    float4 a = X.BroadcastGet4(n, h, w, c4);
    float4 b = B.BroadcastGet4(n, h, w, c4);
    v.x = (a.x > b.x) ? 1.0f : 0.0f;
    v.y = (a.y > b.y) ? 1.0f : 0.0f;
    v.z = (a.z > b.z) ? 1.0f : 0.0f;
    v.w = (a.w > b.w) ? 1.0f : 0.0f;
    #endif
    #ifdef Equal
    float4 a = X.BroadcastGet4(n, h, w, c4);
    float4 b = B.BroadcastGet4(n, h, w, c4);
    v.x = (a.x == b.x) ? 1.0f : 0.0f;
    v.y = (a.y == b.y) ? 1.0f : 0.0f;
    v.z = (a.z == b.z) ? 1.0f : 0.0f;
    v.w = (a.w == b.w) ? 1.0f : 0.0f;
    #endif
    #ifdef Div
    v = X.BroadcastGet4(n, h, w, c4) / B.BroadcastGet4(n, h, w, c4);
    #endif
    #ifdef Add
    v = X.BroadcastGet4(n, h, w, c4) + B.BroadcastGet4(n, h, w, c4);
    #endif
    // Packed lanes past O.channels are padding; keep them zero.
    if (4 * c4 >= O.channels)
        v.x = 0.0f;
    if (4 * c4 + 1 >= O.channels)
        v.y = 0.0f;
    if (4 * c4 + 2 >= O.channels)
        v.z = 0.0f;
    if (4 * c4 + 3 >= O.channels)
        v.w = 0.0f;
    return v;
}
ENDCG
}
}
}

View File

@@ -0,0 +1,9 @@
fileFormatVersion: 2
guid: f868a56d815cb174a9054230194069c9
ShaderImporter:
externalObjects: {}
defaultTextures: []
nonModifiableTextures: []
userData:
assetBundleName:
assetBundleVariant:

View File

@@ -0,0 +1,57 @@
Shader "Barracuda/BroadcastWhere"
{
Properties
{
}
SubShader
{
// No culling or depth
Cull Off ZWrite Off ZTest Always
Pass
{
CGPROGRAM
#pragma vertex vert
#pragma fragment frag
#include "CommonVertexShader.cginc"
#include "TensorTexture.cginc"
TENSOR_DECL_O(O)
TENSOR_DECL(X)
TENSOR_DECL(W)
TENSOR_DECL(K)
// Elementwise Where(condition, a, b) with NumPy-style broadcasting:
// X supplies the condition, W the value when non-zero, K the value when zero.
fixed4 frag (v2f i) : SV_Target
{
    TENSOR_ARGS4(X, W, K, O);
    uint n, h, w, c4;
    O.GetPositionFromUV(i.uv, n, h, w, c4);
    bool4 pickA = (X.BroadcastGet4(n, h, w, c4) != 0.0f);
    float4 whenTrue = W.BroadcastGet4(n, h, w, c4);
    float4 whenFalse = K.BroadcastGet4(n, h, w, c4);
    float4 v = 0.0f;
    v.x = pickA.x ? whenTrue.x : whenFalse.x;
    v.y = pickA.y ? whenTrue.y : whenFalse.y;
    v.z = pickA.z ? whenTrue.z : whenFalse.z;
    v.w = pickA.w ? whenTrue.w : whenFalse.w;
    // Packed lanes past O.channels are padding; keep them zero.
    if (4 * c4 >= O.channels)
        v.x = 0.0f;
    if (4 * c4 + 1 >= O.channels)
        v.y = 0.0f;
    if (4 * c4 + 2 >= O.channels)
        v.z = 0.0f;
    if (4 * c4 + 3 >= O.channels)
        v.w = 0.0f;
    return v;
}
ENDCG
}
}
}

View File

@@ -0,0 +1,9 @@
fileFormatVersion: 2
guid: 8eefed1a026d30840a504a3df988f403
ShaderImporter:
externalObjects: {}
defaultTextures: []
nonModifiableTextures: []
userData:
assetBundleName:
assetBundleVariant:

View File

@@ -0,0 +1,56 @@
Shader "Barracuda/BufferToTensor"
{
Properties
{
}
SubShader
{
// No culling or depth
Cull Off ZWrite Off ZTest Always
Pass
{
CGPROGRAM
#pragma vertex vert
#pragma fragment frag
#include "CommonVertexShader.cginc"
#include "TensorTexture.cginc"
TENSOR_DECL_O(O)
uint _InputHeight;
uint _InputWidth;
Texture2D<float> Xtex2D;
// Unpacks a linear float buffer — stored row-major in a _InputWidth-wide 2D
// texture — into the tensor texture layout. Each fragment fills the four
// packed channel lanes of one output texel.
fixed4 frag (v2f i) : SV_Target
{
    TENSOR_O(O);
    uint n, h, w, c4;
    O.GetPositionFromUV(i.uv, n, h, w, c4);
    float4 v = 0.0f;
    [unroll]
    for (uint cc = 0; cc < 4; cc++)
    {
        // lanes past the channel count stay zero
        if (c4*4+cc >= O.channels)
            break;
        // flat NHWC offset of this element in the source buffer
        uint index = n * O.height * O.width * O.channels + h * O.width * O.channels + w * O.channels + 4 * c4 + cc;
        uint x = (index) % _InputWidth;
        uint y = (index) / _InputWidth;
        v[cc] = Xtex2D.Load(uint3(x, y, 0)).r;
    }
    return v;
}
ENDCG
}
}
}

View File

@@ -0,0 +1,9 @@
fileFormatVersion: 2
guid: 97a746f9b7f26334c840552c379b55e0
ShaderImporter:
externalObjects: {}
defaultTextures: []
nonModifiableTextures: []
userData:
assetBundleName:
assetBundleVariant:

View File

@@ -0,0 +1,22 @@
#include "UnityCG.cginc"
struct appdata
{
float4 vertex : POSITION;
float2 uv : TEXCOORD0;
};
struct v2f
{
float2 uv : TEXCOORD0;
float4 vertex : SV_POSITION;
};
v2f vert(appdata v)
{
v2f o;
o.vertex = UnityObjectToClipPos(v.vertex);
o.uv = v.uv;
return o;
}

View File

@@ -0,0 +1,9 @@
fileFormatVersion: 2
guid: ccbdf3223f3727b49b4a9b9b1f13b205
ShaderImporter:
externalObjects: {}
defaultTextures: []
nonModifiableTextures: []
userData:
assetBundleName:
assetBundleVariant:

View File

@@ -0,0 +1,71 @@
Shader "Barracuda/Concat"
{
Properties
{
}
SubShader
{
// No culling or depth
Cull Off ZWrite Off ZTest Always
Pass
{
CGPROGRAM
#pragma vertex vert
#pragma fragment frag
#include "CommonVertexShader.cginc"
#include "TensorTexture.cginc"
TENSOR_DECL_O(O)
TENSOR_DECL(X)
TENSOR_DECL(OPred)
uint4 _Pad;
uint _IsFirstPass;
// One pass of Concat: copies X into the output at the offset given by _Pad
// (batch, height, width, channel). The first pass (_IsFirstPass == 1) starts
// from zero; later passes start from OPred — the accumulated output of the
// previous passes — so earlier inputs are preserved.
fixed4 frag (v2f i) : SV_Target
{
    TENSOR_ARGS3(X, OPred, O);
    uint n, h, w, c4;
    O.GetPositionFromUV(i.uv, n, h, w, c4);
    uint c;
    float4 v = 0;
    if (_IsFirstPass == 1)
        v = 0;
    else
        v = OPred.Get4(n, h, w, c4);
    // Overwrite only the region covered by X at this pass's offset.
    if ((n >= _Pad.x && n - _Pad.x < X.batch) &&
        (h >= _Pad.y && h - _Pad.y < X.height) &&
        (w >= _Pad.z && w - _Pad.z < X.width))
    {
        // per-lane channel test: the 4 packed lanes may straddle X's channel range
        c = 4 * c4 + 0;
        if (c >= _Pad.w && c - _Pad.w < X.channels)
            v.x = X.Get(n - _Pad.x, h - _Pad.y, w - _Pad.z, c - _Pad.w);
        c = 4 * c4 + 1;
        if (c >= _Pad.w && c - _Pad.w < X.channels)
            v.y = X.Get(n - _Pad.x, h - _Pad.y, w - _Pad.z, c - _Pad.w);
        c = 4 * c4 + 2;
        if (c >= _Pad.w && c - _Pad.w < X.channels)
            v.z = X.Get(n - _Pad.x, h - _Pad.y, w - _Pad.z, c - _Pad.w);
        c = 4 * c4 + 3;
        if (c >= _Pad.w && c - _Pad.w < X.channels)
            v.w = X.Get(n - _Pad.x, h - _Pad.y, w - _Pad.z, c - _Pad.w);
    }
    return v;
}
ENDCG
}
}
}

View File

@@ -0,0 +1,9 @@
fileFormatVersion: 2
guid: d7d0dd0d75980854698fda3b64064f15
ShaderImporter:
externalObjects: {}
defaultTextures: []
nonModifiableTextures: []
userData:
assetBundleName:
assetBundleVariant:

View File

@@ -0,0 +1,68 @@
Shader "Barracuda/Conv2D"
{
Properties
{
}
SubShader
{
// No culling or depth
Cull Off ZWrite Off ZTest Always
Pass
{
CGPROGRAM
#pragma vertex vert
#pragma fragment frag
#include "CommonVertexShader.cginc"
#include "TensorTexture.cginc"
TENSOR_DECL_O(O)
TENSOR_DECL(X)
TENSOR_DECL(K)
TENSOR_DECL(B)
uint4 _Pad;
uint4 _Stride;
// Direct 2D convolution. Output channels are packed four per texel (k4);
// the loops walk input channel groups (c4) and the kernel window. B supplies
// the per-output-channel bias, _Pad.xy the left/top padding, _Stride.xy the
// stride. NOTE(review): SafeGet4 presumably returns 0 for reads that fall in
// the padding region — TODO confirm against TensorTexture.cginc.
fixed4 frag(v2f i) : SV_Target
{
    TENSOR_O(O);
    TENSOR_ARG(X);
    TENSOR_ARG(K);
    TENSOR_ARG(B);
    uint n, h, w, k4;
    O.GetPositionFromUV(i.uv, n, h, w, k4);
    float4 acc4 = B.Get4(0, 0, 0, k4);
    for (uint c4 = 0; c4 < X.channels4; c4++)
    {
        for (uint dy = 0; dy < K.GetKernelHeight(); ++dy)
        {
            for (uint dx = 0; dx < K.GetKernelWidth(); ++dx)
            {
                uint2 pos = uint2(w, h) * _Stride.xy + uint2(dx, dy);
                float4 v = X.SafeGet4(n, pos, c4, _Pad.xy);
                // w0..w3 form the 4x4 weight tile that maps input channel
                // group c4 onto output channel group k4.
                float4 w0 = K.Get4(dy, dx, 4 * c4 + 0, k4);
                float4 w1 = K.Get4(dy, dx, 4 * c4 + 1, k4);
                float4 w2 = K.Get4(dy, dx, 4 * c4 + 2, k4);
                float4 w3 = K.Get4(dy, dx, 4 * c4 + 3, k4);
                acc4.x += dot(v, float4(w0.x, w1.x, w2.x, w3.x));
                acc4.y += dot(v, float4(w0.y, w1.y, w2.y, w3.y));
                acc4.z += dot(v, float4(w0.z, w1.z, w2.z, w3.z));
                acc4.w += dot(v, float4(w0.w, w1.w, w2.w, w3.w));
            }
        }
    }
    return ApplyFusedActivation(acc4);
}
ENDCG
}
}
}

View File

@@ -0,0 +1,9 @@
fileFormatVersion: 2
guid: d0c49b9d8f87a034b9a2ccc84df087ef
ShaderImporter:
externalObjects: {}
defaultTextures: []
nonModifiableTextures: []
userData:
assetBundleName:
assetBundleVariant:

View File

@@ -0,0 +1,86 @@
Shader "Barracuda/Conv2DTrans"
{
Properties
{
}
SubShader
{
// No culling or depth
Cull Off ZWrite Off ZTest Always
Pass
{
CGPROGRAM
#pragma vertex vert
#pragma fragment frag
#include "CommonVertexShader.cginc"
#include "TensorTexture.cginc"
TENSOR_DECL_O(O)
TENSOR_DECL(X)
TENSOR_DECL(K)
TENSOR_DECL(B)
uint4 _Pad;
uint4 _Stride;
// Transposed 2D convolution (deconvolution). For each output texel the kernel
// window is walked; taps that line up with an integer input position under
// the stride accumulate the spatially-flipped kernel weight. B supplies the
// per-output-channel bias.
// Cleanup: removed the unused `strideMask` local and the constant-1
// `strideH`/`strideW` loop increments left over from an earlier variant.
fixed4 frag(v2f i) : SV_Target
{
    TENSOR_O(O);
    TENSOR_ARG(X);
    TENSOR_ARG(K);
    TENSOR_ARG(B);
    uint n, h, w, k4;
    O.GetPositionFromUV(i.uv, n, h, w, k4);
    float4 acc4 = B.Get4(0, 0, 0, k4);
    for (uint c4 = 0; c4 < X.channels4; c4++)
    {
        for (uint dy = 0; dy < K.GetKernelHeight(); ++dy)
        {
            for (uint dx = 0; dx < K.GetKernelWidth(); ++dx)
            {
                uint readX = (w + dx - _Pad.x) / _Stride.x;
                uint readY = (h + dy - _Pad.y) / _Stride.y;
                // early out if read input index fall upon leftmost outer zero padding
                if ((w + dx) < _Pad.x) continue;
                if ((h + dy) < _Pad.y) continue;
                // early out if read input index fall upon rightmost outer zero padding
                if (readX >= X.width) continue;
                if (readY >= X.height) continue;
                // skip taps that fall between input samples under the stride
                if ((w + dx - _Pad.x) % _Stride.x != 0) continue;
                if ((h + dy - _Pad.y) % _Stride.y != 0) continue;
                float4 v = X.Get4(n, readY, readX, c4);
                // transposed conv reads the kernel flipped in both spatial axes
                uint ky = K.GetKernelHeight() - 1 - dy;
                uint kx = K.GetKernelWidth() - 1 - dx;
                float4 w0 = K.Get4(ky, kx, 4 * c4 + 0, k4);
                float4 w1 = K.Get4(ky, kx, 4 * c4 + 1, k4);
                float4 w2 = K.Get4(ky, kx, 4 * c4 + 2, k4);
                float4 w3 = K.Get4(ky, kx, 4 * c4 + 3, k4);
                acc4.x += dot(v, float4(w0.x, w1.x, w2.x, w3.x));
                acc4.y += dot(v, float4(w0.y, w1.y, w2.y, w3.y));
                acc4.z += dot(v, float4(w0.z, w1.z, w2.z, w3.z));
                acc4.w += dot(v, float4(w0.w, w1.w, w2.w, w3.w));
            }
        }
    }
    return ApplyFusedActivation(acc4);
}
ENDCG
}
}
}

View File

@@ -0,0 +1,9 @@
fileFormatVersion: 2
guid: 390e5d6f68d6cea4187c73311334cce8
ShaderImporter:
externalObjects: {}
defaultTextures: []
nonModifiableTextures: []
userData:
assetBundleName:
assetBundleVariant:

View File

@@ -0,0 +1,54 @@
Shader "Barracuda/Copy"
{
Properties
{
}
SubShader
{
// No culling or depth
Cull Off ZWrite Off ZTest Always
Pass
{
CGPROGRAM
#pragma vertex vert
#pragma fragment frag
#include "CommonVertexShader.cginc"
#include "TensorTexture.cginc"
TENSOR_DECL_O(O)
TENSOR_DECL(X)
// Reshape-style copy: the output element's flat NHWC offset is recomputed
// into X's (n, h, w, c) coordinates, so X and O may have different shapes
// as long as element order (and count) lines up.
fixed4 frag (v2f i) : SV_Target
{
    TENSOR_ARGS2(X, O);
    uint n, h, w, c4;
    O.GetPositionFromUV(i.uv, n, h, w, c4);
    float4 v = 0.0f;
    [unroll]
    for (uint cc = 0; cc < 4; cc++)
    {
        // lanes past the channel count stay zero
        if (c4 * 4 + cc >= O.channels)
            break;
        // flat NHWC offset of this output element
        uint index = n * O.height * O.width * O.channels + h * O.width * O.channels + w * O.channels + (4 * c4 + cc);
        // decompose the same flat offset in X's shape
        uint cX = index % X.channels;
        uint wX = (index / X.channels) % X.width;
        uint hX = (index / X.channels / X.width) % X.height;
        uint nX = (index / X.channels / X.width / X.height);
        v[cc] = X.Get(nX, hX, wX, cX);
    }
    return v;
}
}
ENDCG
}
}
}

View File

@@ -0,0 +1,9 @@
fileFormatVersion: 2
guid: c72b5e72f9c141943ab0c51ddbc37622
ShaderImporter:
externalObjects: {}
defaultTextures: []
nonModifiableTextures: []
userData:
assetBundleName:
assetBundleVariant:

View File

@@ -0,0 +1,56 @@
Shader "Barracuda/Dense"
{
Properties
{
}
SubShader
{
// No culling or depth
Cull Off ZWrite Off ZTest Always
Pass
{
CGPROGRAM
#pragma vertex vert
#pragma fragment frag
#include "CommonVertexShader.cginc"
#include "TensorTexture.cginc"
TENSOR_DECL_O(O)
TENSOR_DECL(X)
TENSOR_DECL(W)
TENSOR_DECL(B)
// Fully-connected layer: O[n,k] = sum_c X[n,c] * W[c,k] + B[k], with both
// input and output channels packed four per texel. The four w0..w3 fetches
// gather the 4x4 weight tile linking input group c4 to output group k4.
fixed4 frag(v2f i) : SV_Target
{
    TENSOR_O(O);
    TENSOR_ARG(X);
    TENSOR_ARG(W);
    TENSOR_ARG(B);
    uint n, h, w, k4;
    O.GetPositionFromUV(i.uv, n, h, w, k4);
    float4 acc4 = B.Get4(0, 0, 0, k4);
    for (uint c4 = 0; c4 < X.channels4; c4++)
    {
        float4 v = X.Get4(n, 0, 0, c4);
        float4 w0 = W.Get4(4 * c4 + 0, 0, 0, k4);
        float4 w1 = W.Get4(4 * c4 + 1, 0, 0, k4);
        float4 w2 = W.Get4(4 * c4 + 2, 0, 0, k4);
        float4 w3 = W.Get4(4 * c4 + 3, 0, 0, k4);
        acc4.x += dot(v, float4(w0.x, w1.x, w2.x, w3.x));
        acc4.y += dot(v, float4(w0.y, w1.y, w2.y, w3.y));
        acc4.z += dot(v, float4(w0.z, w1.z, w2.z, w3.z));
        acc4.w += dot(v, float4(w0.w, w1.w, w2.w, w3.w));
    }
    return ApplyFusedActivation(acc4);
}
ENDCG
}
}
}

View File

@@ -0,0 +1,9 @@
fileFormatVersion: 2
guid: 5c7b7cbbc9eafbe419d22e9485aacb45
ShaderImporter:
externalObjects: {}
defaultTextures: []
nonModifiableTextures: []
userData:
assetBundleName:
assetBundleVariant:

View File

@@ -0,0 +1,47 @@
Shader "Barracuda/Dense3"
{
Properties
{
}
SubShader
{
// No culling or depth
Cull Off ZWrite Off ZTest Always
Pass
{
CGPROGRAM
#pragma vertex vert
#pragma fragment frag
#include "CommonVertexShader.cginc"
#include "TensorTexture.cginc"
TENSOR_DECL_O(O)
TENSOR_DECL(X)
TENSOR_DECL(W)
TENSOR_DECL(B)
// Rank-3 dense layer: reduces over X's width. The scalar weight W[j, w] and
// scalar bias B[w] are broadcast across the four packed channel lanes of X.
// NOTE(review): the scalar-to-float4 broadcast of B.Get/W.Get looks
// intentional for this layout, but verify against the CPU reference.
fixed4 frag(v2f i) : SV_Target
{
    TENSOR_O(O);
    TENSOR_ARG(X);
    TENSOR_ARG(W);
    TENSOR_ARG(B);
    uint n, h, w, k4;
    O.GetPositionFromUV(i.uv, n, h, w, k4);
    // scalar bias broadcast to all four packed lanes
    float4 acc4 = B.Get(0, 0, 0, w);
    for (uint j = 0; j < X.width; ++j)
    {
        acc4 += X.Get4(n, 0, j, k4) * W.Get(j, 0, 0, w);
    }
    return ApplyFusedActivation(acc4);
}
}
ENDCG
}
}
}

View File

@@ -0,0 +1,9 @@
fileFormatVersion: 2
guid: d15d796a4efe4b0429e2d4c08f53e10a
ShaderImporter:
externalObjects: {}
defaultTextures: []
nonModifiableTextures: []
userData:
assetBundleName:
assetBundleVariant:

View File

@@ -0,0 +1,56 @@
Shader "Barracuda/DepthToSpace_CRD"
{
Properties
{
}
SubShader
{
// No culling or depth
Cull Off ZWrite Off ZTest Always
Pass
{
CGPROGRAM
#pragma vertex vert
#pragma fragment frag
#include "CommonVertexShader.cginc"
#include "TensorTexture.cginc"
TENSOR_DECL_O(O)
TENSOR_DECL(X)
uint4 _Pool;
// DepthToSpace with CRD (column-row-depth) ordering: the source channel is
// decomposed as c * (bsX * bsY) + by * bsX + bx. _Pool.xy is the block size.
fixed4 frag(v2f i) : SV_Target
{
    TENSOR_ARGS2(X, O);
    uint n, y, x, c4;
    O.GetPositionFromUV(i.uv, n, y, x, c4);
    uint bsX = _Pool.x;
    uint bsY = _Pool.y;
    // Source position and intra-block offsets depend only on (x, y);
    // hoisted out of the per-lane loop below (they were recomputed per lane).
    uint iy = y / bsY;
    uint by = y % bsY;
    uint ix = x / bsX;
    uint bx = x % bsX;
    float4 v = 0;
    [unroll]
    for (uint cc = 0; cc < 4; cc++)
    {
        uint cRead = ((4 * c4 + cc) * bsX * bsY) + (by * bsX) + bx;
        if(cRead < X.channels)
            v[cc] = X.Get(n, iy, ix, cRead);
    }
    return v;
}
ENDCG
}
}
}

View File

@@ -0,0 +1,9 @@
fileFormatVersion: 2
guid: 23341d4f86653834a9d49a4bd2eed862
ShaderImporter:
externalObjects: {}
defaultTextures: []
nonModifiableTextures: []
userData:
assetBundleName:
assetBundleVariant:

View File

@@ -0,0 +1,56 @@
Shader "Barracuda/DepthToSpace_DCR"
{
Properties
{
}
SubShader
{
// No culling or depth
Cull Off ZWrite Off ZTest Always
Pass
{
CGPROGRAM
#pragma vertex vert
#pragma fragment frag
#include "CommonVertexShader.cginc"
#include "TensorTexture.cginc"
TENSOR_DECL_O(O)
TENSOR_DECL(X)
uint4 _Pool;
// DepthToSpace with DCR (depth-column-row) ordering: the source channel is
// decomposed as by * (bsX * C_out) + bx * C_out + c. _Pool.xy is the block size.
fixed4 frag(v2f i) : SV_Target
{
    TENSOR_ARGS2(X, O);
    uint n, y, x, c4;
    O.GetPositionFromUV(i.uv, n, y, x, c4);
    uint bsX = _Pool.x;
    uint bsY = _Pool.y;
    // Source position and intra-block offsets depend only on (x, y);
    // hoisted out of the per-lane loop below (they were recomputed per lane).
    uint iy = y / bsY;
    uint by = y % bsY;
    uint ix = x / bsX;
    uint bx = x % bsX;
    float4 v = 0;
    [unroll]
    for (uint cc = 0; cc < 4; cc++)
    {
        uint cRead = (by * bsX * O.channels) + (bx * O.channels) + (4 * c4 + cc);
        if (cRead < X.channels)
            v[cc] = X.Get(n, iy, ix, cRead);
    }
    return v;
}
ENDCG
}
}
}

View File

@@ -0,0 +1,9 @@
fileFormatVersion: 2
guid: 443f2de71dcda184581a1e90c9bb9ea2
ShaderImporter:
externalObjects: {}
defaultTextures: []
nonModifiableTextures: []
userData:
assetBundleName:
assetBundleVariant:

View File

@@ -0,0 +1,56 @@
Shader "Barracuda/DepthwiseConv2D"
{
Properties
{
}
SubShader
{
// No culling or depth
Cull Off ZWrite Off ZTest Always
Pass
{
CGPROGRAM
#pragma vertex vert
#pragma fragment frag
#include "CommonVertexShader.cginc"
#include "TensorTexture.cginc"
TENSOR_DECL_O(O)
TENSOR_DECL(X)
TENSOR_DECL(K)
TENSOR_DECL(B)
uint4 _Pad;
uint4 _Stride;
// Depthwise 2D convolution: every channel is filtered independently, so the
// weight fetch K.Get4(dy, dx, 0, k4) multiplies lane-for-lane with the input
// group k4 instead of reducing across input channels.
// NOTE(review): SafeGet4 presumably returns 0 inside the padding region —
// TODO confirm against TensorTexture.cginc.
fixed4 frag (v2f i) : SV_Target
{
    TENSOR_ARGS4(X, K, B, O);
    uint n, h, w, k4;
    O.GetPositionFromUV(i.uv, n, h, w, k4);
    float4 acc4 = B.Get4(0, 0, 0, k4);
    for (uint dy = 0; dy < K.GetKernelHeight(); ++dy)
    {
        for (uint dx = 0; dx < K.GetKernelWidth(); ++dx)
        {
            uint2 pos = uint2(w, h) * _Stride.xy + uint2(dx, dy);
            float4 v = X.SafeGet4(n, pos, k4, _Pad.xy);
            float4 w0 = K.Get4(dy, dx, 0, k4);
            acc4 += v * w0;
        }
    }
    return ApplyFusedActivation(acc4);
}
ENDCG
}
}
}

View File

@@ -0,0 +1,9 @@
fileFormatVersion: 2
guid: 9bb27abd97d768e4a89ca0ed9f2bd88a
ShaderImporter:
externalObjects: {}
defaultTextures: []
nonModifiableTextures: []
userData:
assetBundleName:
assetBundleVariant:

View File

@@ -0,0 +1,66 @@
Shader "Barracuda/Gather"
{
Properties
{
}
SubShader
{
// No culling or depth
Cull Off ZWrite Off ZTest Always
Pass
{
CGPROGRAM
#pragma multi_compile Input1D Input2D
#pragma vertex vert
#pragma fragment frag
#include "CommonVertexShader.cginc"
#include "TensorTexture.cginc"
TENSOR_DECL_O(O)
TENSOR_DECL(X)
TENSOR_DECL(K)
uint _Axis;
// Gather along _Axis: indices are read from K (stored along its batch
// dimension) and select slices of X. For axis 3 (channels) the four packed
// lanes each need their own scalar index lookup.
fixed4 frag (v2f i) : SV_Target
{
    TENSOR_ARGS3(X, K, O);
    uint n, h, w, c4;
    O.GetPositionFromUV(i.uv, n, h, w, c4);
    float4 v = 0.0f;
    if (_Axis == 0)
        v = X.Get4((uint)K.Get(n,0,0,0), h, w, c4);
    else if (_Axis == 1)
        v = X.Get4(n, (uint)K.Get(h,0,0,0), w, c4);
    else if (_Axis == 2)
        v = X.Get4(n, h, (uint)K.Get(w,0,0,0), c4);
    else if (_Axis == 3)
    {
        // channel gather breaks the 4-lane packing: one lookup per lane
        v.x = X.Get(n, h, w, (uint)K.Get(4 * c4 + 0, 0, 0, 0));
        v.y = X.Get(n, h, w, (uint)K.Get(4 * c4 + 1, 0, 0, 0));
        v.z = X.Get(n, h, w, (uint)K.Get(4 * c4 + 2, 0, 0, 0));
        v.w = X.Get(n, h, w, (uint)K.Get(4 * c4 + 3, 0, 0, 0));
    }
    // Packed lanes past O.channels are padding; keep them zero.
    if (4 * c4 >= O.channels)
        v.x = 0.0f;
    if (4 * c4 + 1 >= O.channels)
        v.y = 0.0f;
    if (4 * c4 + 2 >= O.channels)
        v.z = 0.0f;
    if (4 * c4 + 3 >= O.channels)
        v.w = 0.0f;
    return v;
}
ENDCG
}
}
}

View File

@@ -0,0 +1,9 @@
fileFormatVersion: 2
guid: c27dc5cba8ccd9d408583df574b77160
ShaderImporter:
externalObjects: {}
defaultTextures: []
nonModifiableTextures: []
userData:
assetBundleName:
assetBundleVariant:

View File

@@ -0,0 +1,43 @@
Shader "Barracuda/GlobalAvgPool2D"
{
    Properties
    {
    }
    SubShader
    {
        // No culling or depth
        Cull Off ZWrite Off ZTest Always
        Pass
        {
            CGPROGRAM
            #pragma vertex vert
            #pragma fragment frag
            #include "CommonVertexShader.cginc"
            #include "TensorTexture.cginc"

            TENSOR_DECL_O(O)
            TENSOR_DECL(X)

            // Global average pooling: every output texel receives the mean of
            // its 4-channel group over the whole spatial extent of the input.
            fixed4 frag (v2f i) : SV_Target
            {
                TENSOR_ARGS2(X, O);

                // Locate the output texel this fragment produces.
                uint batch, row, col, chan4;
                O.GetPositionFromUV(i.uv, batch, row, col, chan4);

                // Sum every spatial location of the input feature map,
                // then divide by the element count.
                float4 sum4 = 0;
                for (uint ry = 0; ry < X.height; ++ry)
                {
                    for (uint rx = 0; rx < X.width; ++rx)
                        sum4 += X.Get4(batch, ry, rx, chan4);
                }
                return sum4 / (X.height * X.width);
            }
            ENDCG
        }
    }
}

View File

@@ -0,0 +1,9 @@
fileFormatVersion: 2
guid: ec56be125cfe2de4e8664b6e4fd7c00b
ShaderImporter:
externalObjects: {}
defaultTextures: []
nonModifiableTextures: []
userData:
assetBundleName:
assetBundleVariant:

View File

@@ -0,0 +1,57 @@
Shader "Barracuda/GlobalAvgVariancePool2D"
{
    Properties
    {
    }
    SubShader
    {
        // No culling or depth
        Cull Off ZWrite Off ZTest Always
        Pass
        {
            CGPROGRAM
            #pragma vertex vert
            #pragma fragment frag
            #include "CommonVertexShader.cginc"
            #include "TensorTexture.cginc"

            TENSOR_DECL_O(O)
            TENSOR_DECL(X)

            // Global statistics over the spatial (H,W) extent of X, per channel
            // group: output row h==0 holds the mean, row h==1 the mean of
            // squares (from which the caller can derive variance as
            // mean2 - mean*mean).
            fixed4 frag (v2f i) : SV_Target
            {
                TENSOR_ARGS2(X, O);

                uint n, h, w, c4;
                O.GetPositionFromUV(i.uv, n, h, w, c4);

                // BUGFIX: the accumulators must be float4. With scalar floats
                // the float4 returned by Get4() was implicitly truncated to its
                // .x component, so only the first channel of each group of 4
                // was pooled and the result was splat to all 4 output lanes.
                float4 mean = 0;
                float4 mean2 = 0;
                for (uint y = 0; y < X.height; ++y)
                {
                    for (uint x = 0; x < X.width; ++x)
                    {
                        float4 v = X.Get4(n, y, x, c4);
                        mean += v;
                        mean2 += v * v;
                    }
                }
                mean /= (X.height * X.width);
                mean2 /= (X.height * X.width);

                if (h == 0)
                    return mean;
                else if (h == 1)
                    return mean2;
                else
                    return 0;
            }
            ENDCG
        }
    }
}

View File

@@ -0,0 +1,9 @@
fileFormatVersion: 2
guid: 750dc44b5188a0047915538013c7fafa
ShaderImporter:
externalObjects: {}
defaultTextures: []
nonModifiableTextures: []
userData:
assetBundleName:
assetBundleVariant:

View File

@@ -0,0 +1,47 @@
Shader "Barracuda/GlobalMaxPool2D"
{
    Properties
    {
    }
    SubShader
    {
        // No culling or depth
        Cull Off ZWrite Off ZTest Always
        Pass
        {
            CGPROGRAM
            #pragma vertex vert
            #pragma fragment frag
            #include "CommonVertexShader.cginc"
            #include "TensorTexture.cginc"

            TENSOR_DECL_O(O)
            TENSOR_DECL(X)

            // Global max pooling: every output texel receives the per-lane
            // maximum of its 4-channel group over the whole spatial extent.
            fixed4 frag (v2f i) : SV_Target
            {
                TENSOR_ARGS2(X, O);

                // Coordinates of the output texel being shaded.
                uint batch, row, col, chan4;
                O.GetPositionFromUV(i.uv, batch, row, col, chan4);

                float4 best4 = -FLT_MAX;
                for (uint ry = 0; ry < X.height; ++ry)
                {
                    for (uint rx = 0; rx < X.width; ++rx)
                    {
                        // max() on float4 operates component-wise, matching a
                        // lane-by-lane comparison.
                        best4 = max(best4, X.Get4(batch, ry, rx, chan4));
                    }
                }
                return best4;
            }
            ENDCG
        }
    }
}

View File

@@ -0,0 +1,9 @@
fileFormatVersion: 2
guid: ef5a86e12013d444fb3b1abdd0f52de4
ShaderImporter:
externalObjects: {}
defaultTextures: []
nonModifiableTextures: []
userData:
assetBundleName:
assetBundleVariant:

View File

@@ -0,0 +1,54 @@
Shader "Barracuda/Sigmoid"
{
    Properties
    {
    }
    SubShader
    {
        // No culling or depth
        Cull Off ZWrite Off ZTest Always
        Pass
        {
            CGPROGRAM
            #pragma vertex vert
            #pragma fragment frag
            #include "CommonVertexShader.cginc"
            #include "TensorTexture.cginc"

            float _Alpha;   // slope of the linear segment
            float _Beta;    // offset of the linear segment

            TENSOR_DECL_O(O)
            TENSOR_DECL(X)

            // NOTE(review): despite the shader name this computes a *hard*
            // sigmoid, y = clamp(_Alpha * x + _Beta, 0, 1). The shader name is
            // a lookup key for the dispatcher, so it is left unchanged here.
            fixed4 frag (v2f i) : SV_Target
            {
                TENSOR_ARGS2(X, O);

                uint n, h, w, c4;
                O.GetPositionFromUV(i.uv, n, h, w, c4);

                // Vectorized clamp of the affine transform; max/min act per
                // lane, matching the original component-by-component form.
                float4 v = X.Get4(n, h, w, c4);
                v = max(0.0f, min(1.0f, _Alpha * v + _Beta));

                // Zero the unused lanes of the final channel group when
                // X.channels is not a multiple of 4.
                uint c = 4 * c4;
                if (c + 0 >= X.channels) v.x = 0.0f;
                if (c + 1 >= X.channels) v.y = 0.0f;
                if (c + 2 >= X.channels) v.z = 0.0f;
                if (c + 3 >= X.channels) v.w = 0.0f;

                return v;
            }
            ENDCG
        }
    }
}

View File

@@ -0,0 +1,9 @@
fileFormatVersion: 2
guid: 2791212921327144d8248dfe5f9b79da
ShaderImporter:
externalObjects: {}
defaultTextures: []
nonModifiableTextures: []
userData:
assetBundleName:
assetBundleVariant:

View File

@@ -0,0 +1,61 @@
Shader "Barracuda/InstanceNorm"
{
    Properties
    {
    }
    SubShader
    {
        // No culling or depth
        Cull Off ZWrite Off ZTest Always
        Pass
        {
            CGPROGRAM
            #pragma vertex vert
            #pragma fragment frag
            #include "CommonVertexShader.cginc"
            #include "TensorTexture.cginc"

            TENSOR_DECL_O(O)
            TENSOR_DECL(X)
            TENSOR_DECL(W)   // per-channel scale (gamma)
            TENSOR_DECL(B)   // per-channel shift (beta)

            float _Epsilon;  // added to variance before sqrt to avoid divide-by-zero

            // Instance normalization: each (batch, channel) plane of X is
            // normalized by its own spatial mean and variance, then scaled and
            // shifted by gamma/beta, with an optionally fused activation.
            fixed4 frag (v2f i) : SV_Target
            {
                TENSOR_ARGS4(X, W, B, O);

                uint n, h, w, c4;
                O.GetPositionFromUV(i.uv, n, h, w, c4);

                float4 gamma = W.Get4(0, 0, 0, c4);
                float4 beta = B.Get4(0, 0, 0, c4);

                // Shifted-data variance: deltas are accumulated relative to the
                // first pixel (alpha) to reduce cancellation error in sumSq.
                float4 alpha = X.Get4(n, 0, 0, c4);

                uint y, x;
                float4 sum = 0, sumSq = 0;
                for (y = 0; y < X.height; ++y)
                    for (x = 0; x < X.width; ++x)
                    {
                        float4 delta = X.Get4(n, y, x, c4) - alpha;
                        sum += delta;
                        sumSq += delta * delta;
                    }
                // mean = alpha + E[delta];  var = E[delta^2] - E[delta]^2
                float4 mean = alpha + sum / (X.width * X.height);
                float4 var = (sumSq - (sum * sum) / (X.width * X.height)) / (X.width * X.height);

                float4 v = X.Get4(n, h, w, c4);
                v = gamma * (v - mean) / sqrt(var + _Epsilon) + beta;
                return ApplyFusedActivation(v);
            }
            ENDCG
        }
    }
}

View File

@@ -0,0 +1,9 @@
fileFormatVersion: 2
guid: 250538bd780e0484c82352bcefb68f4d
ShaderImporter:
externalObjects: {}
defaultTextures: []
nonModifiableTextures: []
userData:
assetBundleName:
assetBundleVariant:

View File

@@ -0,0 +1,74 @@
Shader "Barracuda/LRN"
{
    Properties
    {
    }
    SubShader
    {
        // No culling or depth
        Cull Off ZWrite Off ZTest Always
        Pass
        {
            CGPROGRAM
            #pragma vertex vert
            #pragma fragment frag
            #include "CommonVertexShader.cginc"
            #include "TensorTexture.cginc"

            float _Alpha;    // LRN scaling parameter
            float _Beta;     // LRN exponent
            float _Epsilon;  // LRN bias term (k)
            // NOTE(review): _Axis carries the cross-channel window size here,
            // not an axis index — confirm against the C# dispatcher.
            uint _Axis;

            TENSOR_DECL_O(O)
            TENSOR_DECL(X)

            // pow() that preserves the sign of a negative base for odd exponents.
            float signed_pow(float f, float e)
            {
                // handle negative f
                float v = pow(abs(f), e);
                float s = (e % 2 == 1) ?
                    sign(f): // exponent is odd => sign(f) * pow(abs(f), e)
                    1; // exponent is even => pow(abs(f), e)
                return v * s;
            }

            // Local Response Normalization across channels:
            //   y[c] = x[c] / (k + alpha * sum(x[ci]^2 over window) / size)^beta
            fixed4 frag (v2f i) : SV_Target
            {
                TENSOR_ARGS2(X, O);

                uint n, h, w, c4;
                O.GetPositionFromUV(i.uv, n, h, w, c4);

                float bias = _Epsilon;
                float sizef = (float)_Axis;
                float regionCenter = (sizef - 1.0f) / 2.0f;

                float4 v = X.Get4(n, h, w, c4);
                [unroll]
                for (uint cc = 0; cc < 4; cc++)
                {
                    uint c = 4 * c4 + cc;
                    // BUGFIX: compute the window lower bound in signed
                    // arithmetic. `max(0, c - (uint)floor(regionCenter))`
                    // underflowed whenever c < floor(regionCenter): the uint
                    // subtraction wrapped to a huge value, the window became
                    // empty and the lowest channels were normalized with
                    // sumOfSquared == 0.
                    uint regionStart = (uint)max(0, (int)c - (int)floor(regionCenter));
                    uint regionEnd = min(X.channels, c + (uint)ceil(regionCenter) + 1);
                    float sumOfSquared = 0.0f;
                    for (uint ci = regionStart; ci < regionEnd; ++ci)
                    {
                        float regionValue = X.Get(n, h, w, ci);
                        sumOfSquared += regionValue * regionValue;
                    }
                    v[cc] /= signed_pow(bias + _Alpha * sumOfSquared / sizef, _Beta);
                }
                return v;
            }
            ENDCG
        }
    }
}

View File

@@ -0,0 +1,9 @@
fileFormatVersion: 2
guid: 5038389ba3277cf43b4844f5520eb231
ShaderImporter:
externalObjects: {}
defaultTextures: []
nonModifiableTextures: []
userData:
assetBundleName:
assetBundleVariant:

View File

@@ -0,0 +1,124 @@
Shader "Barracuda/LogSoftmax"
{
    Properties
    {
    }
    SubShader
    {
        // No culling or depth
        Cull Off ZWrite Off ZTest Always
        Pass
        {
            CGPROGRAM
            // One keyword per reduction axis; the dispatcher selects the variant.
            #pragma multi_compile ReduceN ReduceH ReduceW ReduceC
            #pragma vertex vert
            #pragma fragment frag
            #include "CommonVertexShader.cginc"
            #include "TensorTexture.cginc"

            TENSOR_DECL_O(O)
            TENSOR_DECL(X)

            // NOTE(review): _Axis is never read in this pass — axis selection
            // happens through the keywords above; confirm before removing.
            uint _Axis;

            // Numerically-stable log-softmax along the compiled axis:
            //   out = (x - max) - log(sum(exp(x - max)))
            // Pass 1 finds the running max, pass 2 accumulates exp(x - max).
            fixed4 frag (v2f i) : SV_Target
            {
                TENSOR_ARGS2(X, O);

                uint n, h, w, c4;
                O.GetPositionFromUV(i.uv, n, h, w, c4);

                // ---- Pass 1: maximum over the reduced axis ----
                // SafeGet4 with a -FLT_MAX fallback so out-of-range lanes do
                // not contaminate the maximum.
                float4 maxV = -FLT_MAX;
                uint j = 0;
                #ifdef ReduceN
                for (j = 0; j < X.batch; j++)
                #endif
                #ifdef ReduceH
                for (j = 0; j < X.height; j++)
                #endif
                #ifdef ReduceW
                for (j = 0; j < X.width; j++)
                #endif
                #ifdef ReduceC
                for (j = 0; j < X.channels4; j++)
                #endif
                {
                    float4 v = 0.0f;
                    #ifdef ReduceN
                    v = X.SafeGet4(j, uint2(w, h), c4, uint2(0, 0), -FLT_MAX);
                    #endif
                    #ifdef ReduceH
                    v = X.SafeGet4(n, uint2(w, j), c4, uint2(0, 0), -FLT_MAX);
                    #endif
                    #ifdef ReduceW
                    v = X.SafeGet4(n, uint2(j, h), c4, uint2(0, 0), -FLT_MAX);
                    #endif
                    #ifdef ReduceC
                    v = X.SafeGet4(n, uint2(w, h), j, uint2(0, 0), -FLT_MAX);
                    #endif
                    maxV = max(maxV, v);
                }
                #ifdef ReduceC
                // Channel reduction spans the 4 packed lanes: collapse the
                // per-lane maxima to a single scalar (splat to all lanes).
                maxV = max(maxV.x, max(maxV.y, max(maxV.z, maxV.w)));
                #endif

                // ---- Pass 2: sum of exp(x - max) over the same axis ----
                float4 acc = 0.0f;
                #ifdef ReduceN
                for (j = 0; j < X.batch; j++)
                #endif
                #ifdef ReduceH
                for (j = 0; j < X.height; j++)
                #endif
                #ifdef ReduceW
                for (j = 0; j < X.width; j++)
                #endif
                #ifdef ReduceC
                for (j = 0; j < X.channels4; j++)
                #endif
                {
                    float4 v = 0.0f;
                    #ifdef ReduceN
                    v = X.Get4(j, h, w, c4);
                    #endif
                    #ifdef ReduceH
                    v = X.Get4(n, j, w, c4);
                    #endif
                    #ifdef ReduceW
                    v = X.Get4(n, h, j, c4);
                    #endif
                    #ifdef ReduceC
                    v = X.Get4(n, h, w, j);
                    #endif
                    #ifdef ReduceC
                    // Skip tail lanes of the last channel group so padding does
                    // not contribute exp(pad - max) to the sum.
                    if (4 * j + 0 < X.channels)
                        acc.x += exp(v.x - maxV.x);
                    if (4 * j + 1 < X.channels)
                        acc.y += exp(v.y - maxV.y);
                    if (4 * j + 2 < X.channels)
                        acc.z += exp(v.z - maxV.z);
                    if (4 * j + 3 < X.channels)
                        acc.w += exp(v.w - maxV.w);
                    #else
                    acc += exp(v - maxV);
                    #endif
                }
                #ifdef ReduceC
                // Collapse the per-lane partial sums to one scalar (splat).
                acc = acc.x + acc.y + acc.z + acc.w;
                #endif

                // out = (x - max) - log(sum(exp(x - max)))
                float4 v = X.Get4(n, h, w, c4);
                v = (v - maxV) - log(acc);
                return v;
            }
            ENDCG
        }
    }
}

View File

@@ -0,0 +1,9 @@
fileFormatVersion: 2
guid: 3a0431d0fc2c43c468a2b6c7e67de5d0
ShaderImporter:
externalObjects: {}
defaultTextures: []
nonModifiableTextures: []
userData:
assetBundleName:
assetBundleVariant:

View File

@@ -0,0 +1,86 @@
Shader "Barracuda/MatMul"
{
    Properties
    {
    }
    SubShader
    {
        // No culling or depth
        Cull Off ZWrite Off ZTest Always
        Pass
        {
            CGPROGRAM
            // Optional transposition of either operand, selected per variant.
            #pragma multi_compile xTranspose_OFF xTranspose_ON
            #pragma multi_compile yTranspose_OFF yTranspose_ON
            #pragma vertex vert
            #pragma fragment frag
            #include "CommonVertexShader.cginc"
            #include "TensorTexture.cginc"

            TENSOR_DECL_O(O)
            TENSOR_DECL(X)   // left matrix: rows in batch, columns packed 4-wide in channels
            TENSOR_DECL(Y)   // right matrix: rows in batch, columns packed 4-wide in channels

            // O[n, k] = sum_c X[n, c] * Y[c, k], accumulated 4 columns (k4)
            // and 4 inner elements (c4) at a time.
            fixed4 frag(v2f i) : SV_Target
            {
                TENSOR_O(O);
                TENSOR_ARG(X);
                TENSOR_ARG(Y);

                uint n, h, w, k4;
                O.GetPositionFromUV(i.uv, n, h, w, k4);

                float4 acc4 = 0.0f;
                for (uint c4 = 0; c4 < X.channels4; c4++)
                {
                    // a = X[n, 4*c4 .. 4*c4+3]
                    float4 a = X.Get4(n, 0, 0, c4);
                    #ifdef xTranspose_ON
                    // Transposed X: re-read the same 4 values with row/column
                    // roles swapped, one scalar per lane.
                    a.x = X.Get(4 * c4 + 0, 0, 0, n);
                    a.y = X.Get(4 * c4 + 1, 0, 0, n);
                    a.z = X.Get(4 * c4 + 2, 0, 0, n);
                    a.w = X.Get(4 * c4 + 3, 0, 0, n);
                    #endif
                    // b0..b3 = rows 4*c4 .. 4*c4+3 of Y, each holding the 4
                    // output columns of group k4.
                    float4 b0 = Y.Get4(4 * c4 + 0, 0, 0, k4);
                    float4 b1 = Y.Get4(4 * c4 + 1, 0, 0, k4);
                    float4 b2 = Y.Get4(4 * c4 + 2, 0, 0, k4);
                    float4 b3 = Y.Get4(4 * c4 + 3, 0, 0, k4);
                    #ifdef yTranspose_ON
                    // Transposed Y: gather the same 4x4 tile scalar-by-scalar
                    // with row/column roles swapped.
                    b0.x = Y.Get(4 * k4 + 0, 0, 0, 4 * c4 + 0);
                    b0.y = Y.Get(4 * k4 + 1, 0, 0, 4 * c4 + 0);
                    b0.z = Y.Get(4 * k4 + 2, 0, 0, 4 * c4 + 0);
                    b0.w = Y.Get(4 * k4 + 3, 0, 0, 4 * c4 + 0);
                    b1.x = Y.Get(4 * k4 + 0, 0, 0, 4 * c4 + 1);
                    b1.y = Y.Get(4 * k4 + 1, 0, 0, 4 * c4 + 1);
                    b1.z = Y.Get(4 * k4 + 2, 0, 0, 4 * c4 + 1);
                    b1.w = Y.Get(4 * k4 + 3, 0, 0, 4 * c4 + 1);
                    b2.x = Y.Get(4 * k4 + 0, 0, 0, 4 * c4 + 2);
                    b2.y = Y.Get(4 * k4 + 1, 0, 0, 4 * c4 + 2);
                    b2.z = Y.Get(4 * k4 + 2, 0, 0, 4 * c4 + 2);
                    b2.w = Y.Get(4 * k4 + 3, 0, 0, 4 * c4 + 2);
                    b3.x = Y.Get(4 * k4 + 0, 0, 0, 4 * c4 + 3);
                    b3.y = Y.Get(4 * k4 + 1, 0, 0, 4 * c4 + 3);
                    b3.z = Y.Get(4 * k4 + 2, 0, 0, 4 * c4 + 3);
                    b3.w = Y.Get(4 * k4 + 3, 0, 0, 4 * c4 + 3);
                    #endif
                    // Each output lane dots a against the matching column of
                    // the 4x4 tile formed by b0..b3.
                    acc4.x += dot(a, float4(b0.x, b1.x, b2.x, b3.x));
                    acc4.y += dot(a, float4(b0.y, b1.y, b2.y, b3.y));
                    acc4.z += dot(a, float4(b0.z, b1.z, b2.z, b3.z));
                    acc4.w += dot(a, float4(b0.w, b1.w, b2.w, b3.w));
                }
                return acc4;
            }
            ENDCG
        }
    }
}

View File

@@ -0,0 +1,9 @@
fileFormatVersion: 2
guid: a85eb55355defae4dbd68f698857c3c7
ShaderImporter:
externalObjects: {}
defaultTextures: []
nonModifiableTextures: []
userData:
assetBundleName:
assetBundleVariant:

View File

@@ -0,0 +1,49 @@
Shader "Barracuda/MaxPool2D"
{
    Properties
    {
    }
    SubShader
    {
        // No culling or depth
        Cull Off ZWrite Off ZTest Always
        Pass
        {
            CGPROGRAM
            #pragma vertex vert
            #pragma fragment frag
            #include "CommonVertexShader.cginc"
            #include "TensorTexture.cginc"

            TENSOR_DECL_O(O)
            TENSOR_DECL(X)

            uint4 _Pool;     // pooling window size (x = width, y = height)
            uint4 _Pad;      // spatial padding (x, y)
            uint4 _Stride;   // window stride (x, y)

            // Windowed max pooling: each output texel is the per-lane maximum
            // of a _Pool-sized window of the input; padded samples resolve to
            // -FLT_MAX so they never win the comparison.
            fixed4 frag (v2f i) : SV_Target
            {
                TENSOR_ARGS2(X, O);

                uint batch, oy, ox, chan4;
                O.GetPositionFromUV(i.uv, batch, oy, ox, chan4);

                // Top-left corner of the pooling window for this output texel.
                uint2 anchor = uint2(ox, oy) * _Stride.xy;

                float4 best4 = -FLT_MAX;
                for (uint ky = 0; ky < _Pool.y; ++ky)
                {
                    for (uint kx = 0; kx < _Pool.x; ++kx)
                    {
                        float4 val4 = X.SafeGet4(batch, anchor + uint2(kx, ky), chan4, _Pad.xy, -FLT_MAX);
                        best4 = max(best4, val4);
                    }
                }
                return best4;
            }
            ENDCG
        }
    }
}

Some files were not shown because too many files have changed in this diff Show More