Files
unity-application/Packages/com.unity.barracuda/Runtime/ONNX/ONNXModelConverter.cs
Jelle De Geest c6faf3bf6f Sprint 3
2023-03-26 21:23:17 +00:00

3462 lines
171 KiB
C#

using System;
using System.Collections.Generic;
using System.IO;
using System.Linq;
using System.Runtime.CompilerServices;
using Google.Protobuf;
using Google.Protobuf.Collections;
using Onnx;
using Unity.Barracuda.Compiler.Passes;
using UnityEngine;
using UnityEngine.Assertions;
using UnityEngine.Profiling;
[assembly: InternalsVisibleTo("Unity.Barracuda.Tests")]
namespace Unity.Barracuda.ONNX
{
/// <summary>
/// ONNX model converter to Barracuda format.
/// </summary>
public class ONNXModelConverter
{
/// <summary>
/// Selects which importer pipeline <c>Convert</c> uses, plus optional behavior flags.
/// </summary>
/// <remarks>
/// The low bit chooses the base importer; the option flags live in the high bits so they can be OR-ed onto it.
/// </remarks>
[Flags]
internal enum ImportMode
{
    // No bits set: use the legacy importer.
    Legacy = 0x0,
    // Use the standard importer (followed by the layout post-processing passes).
    Standard = 0x1,

    // Additional options
    // With Standard: convert to a runnable NCHW model instead of the default NHWC one.
    KeepAsNCHW = 0x10000,
    // Presumably disables importing model metadata; the consumer is outside this view — confirm.
    SkipMetadataImport = 0x20000,
}
/// <summary>
/// Tensor data type override. Judging by the member names it forces half or float precision;
/// its consumer is outside this view — confirm before relying on specifics.
/// </summary>
[Flags]
internal enum DataTypeMode
{
    // No override.
    Default = 0,
    ForceHalf = 1 << 0,
    ForceFloat = 1 << 1
}
// Import configuration, assigned by the constructors.
private bool m_TreatErrorsAsWarnings;
private bool m_OptimizeModel = true;
private bool m_ForceArbitraryBatchSize;
private ImportMode m_ImportMode;

// Set at the start of each conversion when the model's producer name is "tf2onnx".
// Known TF2ONNX export issues (as of 1.5.4):
// - Conv ops are wrapped in Transposes unless the NCHW flag is set
//   (this looks like it is being fixed in https://github.com/onnx/tensorflow-onnx/pull/796)
// - TensorFlow appends :0 to every node name
private bool m_FixTf2OnnxExportIssues;

/// <summary>
/// Model imported event. Raised at the end of each conversion with the parsed ONNX model proto
/// as the first argument and the converted Barracuda <see cref="Model"/> as the second.
/// </summary>
public static event Action<object, Model> ModelImported;
// Global input names that get a fixed constant tensor holding the single value 1.
// The code that applies these overrides is outside this view.
private readonly Dictionary<string, ONNXTensor> m_OverrideGlobalInputs = new Dictionary<string, ONNXTensor>
{
    ["sequence_length:0"] = new ONNXTensor(new Tensor(1, 1, new[] { 1f }), new [] { 1 }),
    ["sequence_length"] = new ONNXTensor(new Tensor(1, 1, new[] { 1f }), new [] { 1 }),
};

// ONNX op types that must never be constant-folded ("baked"), even when every input is constant.
private readonly HashSet<string> m_ShouldNotBeBaked = new HashSet<string>
{
    // These ops already handle constant inputs in their own importer:
    "Constant", "Reshape", "Shape", "Slice", "Gather", "Transpose", "Squeeze", "Unsqueeze", "NonZero", "ConstantOfShape",
    // These ops are non-deterministic, so folding them would freeze a single sample:
    "RandomNormal", "RandomNormalLike", "RandomUniform", "RandomUniformLike"
};

// ONNX op types whose inputs are all in channel-first layout.
private readonly HashSet<string> m_AllInputsChannelFirst = new HashSet<string>
{
    "Concat", "Add", "Sum", "Sub", "Mul", "Div", "Pow", "Min", "Max", "Mean", "Greater", "Less", "Equal", "Or", "And", "Xor", "Where"
};

// Shortcuts into the model tensor registry.
private Dictionary<string, ONNXTensor> constantTensors => m_ModelTensors.constants;
private Dictionary<string, VariableTensor> variableTensors => m_ModelTensors.variables;

private Dictionary<string, string> lstmInputs = new Dictionary<string, string>();
private Dictionary<string, string> lstmOutputs = new Dictionary<string, string>();
private List<string> layerRequiringUpstreamPatch = new List<string>();
/// <summary>
/// Registers <paramref name="opImportAction"/> as the importer for ONNX operator type <paramref name="opType"/>
/// by adding it to <c>m_NodeImporters</c>.
/// </summary>
/// <remarks>
/// If <c>m_NodeImporters</c> is a dictionary (assumed — its declaration is not visible here), registering the
/// same op type twice throws.
/// </remarks>
private void Add(string opType, Action<ModelBuilder, ONNXNodeWrapper> opImportAction) => m_NodeImporters.Add(opType, opImportAction);
/// <summary>
/// Convert ONNX model and return Barracuda Model object.
/// </summary>
/// <remarks>
/// Opens the file read-only and streams it through a protobuf <see cref="CodedInputStream"/>.
/// Both streams are disposed before this method returns.
/// </remarks>
/// <param name="filePath">Location of the input ONNX model.</param>
/// <returns>Barracuda Model object.</returns>
public Model Convert(string filePath)
{
    using (FileStream fileStream = new FileStream(filePath, FileMode.Open, FileAccess.Read))
    {
        using (CodedInputStream codedStream = new CodedInputStream(fileStream))
        {
            return Convert(codedStream);
        }
    }
}
/// <summary>
/// Convert ONNX model and return Barracuda Model object.
/// </summary>
/// <remarks>
/// Wraps <paramref name="buffer"/> in a protobuf <see cref="CodedInputStream"/>, which is disposed before returning.
/// </remarks>
/// <param name="buffer">Memory buffer containing ONNX model.</param>
/// <returns>Barracuda Model object.</returns>
public Model Convert(byte[] buffer)
{
    using (CodedInputStream codedStream = new CodedInputStream(buffer))
    {
        return Convert(codedStream);
    }
}
// Detects known ML-Agents LSTM exports that must still go through the legacy importer.
// The legacy importer automatically split recurrent input nodes and synthesized missing output
// nodes; the standard importer no longer does that, so these specific networks are recognized
// by producer name and input/output naming patterns.
bool IsLegacyMLAgentsLSTMNetwork(ModelProto onnxModel)
{
    GraphProto graph = onnxModel.Graph;
    bool producedByTf2Onnx = onnxModel.ProducerName == "tf2onnx";
    bool producedByPyTorch = onnxModel.ProducerName == "pytorch";

    // Hallway-lstm.onnx: the legacy importer splits recurrent_in into recurrent_in_c / recurrent_in_h
    // and adds the output nodes recurrent_out_c / recurrent_out_h.
    if (producedByTf2Onnx
        && graph.Input.Any(graphInput => graphInput.Name.Contains("recurrent_in"))
        && graph.Output.Any(graphOutput => graphOutput.Name.Contains("recurrent_out")))
    {
        return true;
    }

    // The first LSTM node in the graph (if any) is shared by the two pytorch checks below.
    NodeProto lstmNode = graph.Node.FirstOrDefault(candidate => candidate.OpType == "LSTM");

    // Hallway.onnx / Hallway-no-workaround.onnx: the legacy importer splits memories into memories_c / memories_h
    // and adds the output nodes recurrent_out_<nn>_c / recurrent_out_<nn>_h.
    // Such graphs have no nodes named after the LSTM's cell and hidden outputs.
    if (producedByPyTorch
        && graph.Input.Any(graphInput => graphInput.Name.Contains("memories"))
        && lstmNode != null
        && lstmNode.Output.Count == 3
        && graph.Node.All(candidate => candidate.Name != lstmNode.Output[1])
        && graph.Node.All(candidate => candidate.Name != lstmNode.Output[2]))
    {
        return true;
    }

    // True when the node producing lstmInput (matched on its first output) is a Slice
    // that reads the "recurrent_in" graph input directly.
    bool IsFedBySliceOfRecurrentInput(string lstmInput)
    {
        NodeProto producerNode = graph.Node.FirstOrDefault(candidate => candidate.Output.FirstOrDefault() == lstmInput);
        return producerNode != null && producerNode.Input.Contains("recurrent_in") && producerNode.OpType == "Slice";
    }

    // Hallway_1_9.onnx: intended as the ML-Agents 2.0 candidate, but it has no transposes in the network.
    // It already shipped, so it keeps the legacy importer for the ML-Agents 1.x lifecycle.
    // Tell-tale sign: an LSTM input comes straight from a Slice with no Transpose in between.
    return producedByPyTorch
        && graph.Input.Any(graphInput => graphInput.Name.Contains("recurrent_in"))
        && graph.Output.Any(graphOutput => graphOutput.Name.Contains("recurrent_out"))
        && lstmNode != null
        && lstmNode.Input.Any(IsFedBySliceOfRecurrentInput);
}
/// <summary>
/// Parses a serialized ONNX <c>ModelProto</c> from <paramref name="inputStream"/> and converts it to a Barracuda <see cref="Model"/>.
/// </summary>
/// <remarks>
/// Networks detected by <c>IsLegacyMLAgentsLSTMNetwork</c> are imported with <c>ImportMode.Legacy</c> and get an
/// importer warning attached. That override lasts only for this call. The configured import mode is restored
/// afterwards, even if conversion throws, so a reused converter does not stay stuck in legacy mode.
/// With the standard importer, the result goes through <c>PreserveLayersPass</c> and then either the runnable-NCHW pass
/// (when <c>ImportMode.KeepAsNCHW</c> is set) or the runnable-NHWC pass, which honors the optimize flag.
/// <see cref="ModelImported"/> is raised with the ONNX proto and the converted model before returning.
/// </remarks>
/// <param name="inputStream">Protobuf input stream containing a serialized ONNX model.</param>
/// <returns>The converted Barracuda model.</returns>
internal Model Convert(CodedInputStream inputStream)
{
    var onnxModel = new ModelProto();
    onnxModel.MergeFrom(inputStream);

    m_FixTf2OnnxExportIssues = onnxModel.ProducerName == "tf2onnx";

    bool legacyMLAgentsLSTMNetwork = IsLegacyMLAgentsLSTMNetwork(onnxModel);

    // Remember the configured mode: the legacy override below must not leak into later Convert calls.
    ImportMode configuredImportMode = m_ImportMode;
    try
    {
        if (legacyMLAgentsLSTMNetwork)
            m_ImportMode = ImportMode.Legacy;

        if (m_ImportMode.HasFlag(ImportMode.Standard))
            UseStandardImporter();
        else
            UseLegacyImporter();

        var model = ConvertOnnxModel(onnxModel);

        if (m_ImportMode.HasFlag(ImportMode.Standard))
        {
            var preserveLayersPass = new PreserveLayersPass();
            preserveLayersPass.Run(ref model);

            if (m_ImportMode.HasFlag(ImportMode.KeepAsNCHW))
            {
                // Since our model is non-runnable due to NHWC-native ops this pass is always required
                var runnableNCHWPass = new IntermediateToRunnableNCHWPass();
                runnableNCHWPass.Run(ref model);
            }
            else
            {
                var runnableNHWCPass = new IntermediateToRunnableNHWCPass()
                {
                    Optimize = m_OptimizeModel
                };
                runnableNHWCPass.Run(ref model);
            }
        }

        if (legacyMLAgentsLSTMNetwork)
            model.Warnings.Add(new Model.ImporterWarning("model", "Using legacy importer since legacy LSTM network was detected; Support will be removed in Barracuda v2.0"));

        ModelImported?.Invoke(onnxModel, model);

        return model;
    }
    finally
    {
        m_ImportMode = configuredImportMode;
    }
}
/// <summary>
/// Constructs ONNX model converter
/// </summary>
/// <remarks>Uses the standard importer (<c>ImportMode.Standard</c>).</remarks>
/// <param name="optimizeModel">Enable/disable various model optimizations while importing model from ONNX format.</param>
/// <param name="treatErrorsAsWarnings">Treat import errors as warnings.</param>
/// <param name="forceArbitraryBatchSize">Repair model input batch size. Sometimes needed for ONNX models coming from PyTorch.</param>
public ONNXModelConverter(bool optimizeModel, bool treatErrorsAsWarnings = false, bool forceArbitraryBatchSize = true)
    : this(optimizeModel: optimizeModel,
           treatErrorsAsWarnings: treatErrorsAsWarnings,
           forceArbitraryBatchSize: forceArbitraryBatchSize,
           importMode: ImportMode.Standard)
{
}

// Internal constructor that also takes the import mode. It is reachable from this assembly and,
// through the InternalsVisibleTo attribute on this file, from Unity.Barracuda.Tests.
internal ONNXModelConverter(bool optimizeModel, bool treatErrorsAsWarnings, bool forceArbitraryBatchSize, ImportMode importMode)
{
    m_ImportMode = importMode;
    m_ForceArbitraryBatchSize = forceArbitraryBatchSize;
    m_TreatErrorsAsWarnings = treatErrorsAsWarnings;
    m_OptimizeModel = optimizeModel;
}
void UseStandardImporter()
{
m_NodeImporters.Clear();
var defaultZeroTensor = new ONNXTensor(new Tensor(1, 1, new[] { 0f }), new[] { 1 });
Add("Constant", (net, node) => {
node.UnsupportedAttribute("sparse_value");
Const(node, node.ValueAsTensor);
});
Add("ConstantOfShape", (net, node) => {
UnityEngine.Debug.Assert(node.InputCount > 0);
ONNXTensor valueTensor = node.GetOptionalTensor("value", defaultZeroTensor);
var value = valueTensor.ToBarracuda("ONNX").AsFloats()[0];
if (node.IsInput0Const)
{
var onnxShape = node.Input0Constant("ONNX").AsInts();
int onnxRank = onnxShape.Length;
onnxShape = ONNXLayout.ConvertSymbolicShapeToBarracuda(onnxShape, "ONNX");
var tensor = new Tensor(onnxShape);
tensor.Fill(value);
net.Const(node.Name, tensor, -1, onnxRank);
}
else
{
net.ConstantOfShape(node.Name, node.Input0, value);
}
});
Add("Reshape", (net, node) => {
int[] onnxShape;
if (node.InputCount == 1)
{
onnxShape = node.Shape;
if (node.IsInput0Const)
{
// reshape constant source tensor and store it as the new constant
var reshapedTensor = constantTensors[node.Input0].Reshape(onnxShape);
Const(node, reshapedTensor);
}
else
{
net.Reshape(node.Name, node.Input0, onnxShape);
Output(node, rank:onnxShape.Length);
}
}
else
{
if (node.IsInput1Const)
{
onnxShape = node.Input1Constant(onnxLayout: "ONNX", name: "shape").AsInts();
if (node.IsInput0Const)
{
// reshape constant source tensor and store it as the new constant
var reshapedTensor = constantTensors[node.Input0].Reshape(onnxShape);
Const(node, reshapedTensor);
}
else
{
net.Reshape(node.Name, node.Input0, onnxShape);
Output(node, rank:onnxShape.Length);
}
}
else
{
net.Reshape(node.Name, node.Input0, node.Input1);
}
}
});
Add("Expand", (net, node) => {
if (node.IsInput1Const)
{
var onnxShape = node.Input1Constant(onnxLayout: "C", name: "shape").AsInts();
net.Expand(node.Name, node.Input0, onnxShape);
Output(node, rank: onnxShape.Length);
}
else
{
net.Expand(node.Name, node.Input0, node.Input1);
}
});
Add("Shape", (net, node) =>
{
float[] shapeValuesAsFloats;
if (node.IsInput0Const)
{
shapeValuesAsFloats = constantTensors[node.Input0].shape.Select(x => (float)x).ToArray();
}
else
{
net.Shape(node.Name, node.Input0);
}
});
Add("Unsqueeze", (net, node) =>
{
int[] constAxes = null;
if (node.InputCount >= 2 && node.IsInput1Const)
constAxes = node.Input1Constant(onnxLayout: "ONNX", name: "axes").AsInts();
else
constAxes = node.Axes;
if (node.IsInput0Const && constAxes != null)
{
var unsqueezed = constantTensors[node.Input0].Unsqueeze(constAxes);
Const(node, unsqueezed);
}
else if (node.InputCount == 1)
{
net.Unsqueeze(node.Name, node.Input0, node.Axes);
}
else
{
net.Unsqueeze(node.Name, node.Input0, node.Input1);
}
});
Add("Squeeze", (net, node) =>
{
int[] constAxes = null;
if (node.InputCount >= 2 && node.IsInput1Const)
constAxes = node.Input1Constant(onnxLayout: "ONNX", name: "axes").AsInts();
else
constAxes = node.Axes;
if (node.IsInput0Const && constAxes != null)
{
var squeezed = constantTensors[node.Input0].Squeeze(constAxes);
Const(node, squeezed);
}
else if (node.InputCount == 1)
{
net.Squeeze(node.Name, node.Input0, node.Axes);
}
else
{
net.Squeeze(node.Name, node.Input0, node.Input1);
}
});
Add("Tile", (net, node) =>
{
// only 4D Tile support for now
net.Tile(node.Name, node.Input0, node.Input1);
});
Add("Flatten", (net, node) => {
node.UnsupportedAttribute("axis", 1); // TODO we can support it, insert transposes or if dimensions are ok, == reshape
net.Flatten(node.Name, node.Input0);
Output(node, rank:2);
});
Add("Concat", (net, node) => {
int axis = node.AxisOptional(0);
if (node.Inputs.Length == 1)
net.Identity(node.Name, node.Input0);
else
{
net.Concat(node.Name, node.Inputs, axis, true);
}
});
Add("Split", (net, node) => {
int axis = node.AxisOptional(0);
int[] splits;
try
{
splits = node.GetRequiredIntArray("split");
}
catch (OnnxLayerImportException)
{
throw new OnnxLayerImportException($"Unsupported default attribute `split` for node {node.Name} of type Split. Value is required.");
}
Assert.IsTrue(splits.Length == node.Outputs.Length);
int currentSliceStartIndex = 0;
// Convert `Split` into multiple `StridedSlice` operations.
for (int i = 0; i < splits.Length; ++i)
{
var starts = currentSliceStartIndex;
var ends = starts + splits[i];
var strides = 1;
net.StridedSlice(node.Outputs[i], node.Input0, new[] { starts }, new[] { ends }, new[] { strides }, new[] { axis });
currentSliceStartIndex += splits[i];
}
});
Add("Slice", (net, node) => {
int[] starts, ends, axes, steps;
if (node.InputCount > 1) // Slice-10
{
if (!node.IsInput1Const || !node.IsInput2Const)
{
if(node.InputCount == 5)
net.StridedSlice(node.Name, node.Input0, starts: node.Input1, ends: node.Input2, strides: node.Input4, axes: node.Input3);
else if (node.InputCount == 3)
net.StridedSlice(node.Name, node.Input0, starts: node.Input1, ends: node.Input2, strides: null, axes: null);
}
else
{
var constStarts = node.Input1Constant(onnxLayout: "ONNX", name: "starts");
var constEnds = node.Input2Constant(onnxLayout: "ONNX", name: "ends");
var defaultAxes = new Tensor(constStarts.shape, Enumerable.Range(0, constStarts.length).Select(v => (float)v).ToArray());
var constAxes = node.Input3ConstantOptional(defaultAxes, onnxLayout: "ONNX", name: "axes");
var constSteps = node.Input4ConstantOptional(constStarts.shape, 1.0f, onnxLayout: "ONNX", name: "steps");
starts = constStarts.AsInts();
ends = constEnds.AsInts();
axes = constAxes.AsInts();
steps = constSteps.AsInts();
net.StridedSlice(node.Name, node.Input0, starts: starts, ends: ends, strides: steps, axes: axes);
}
}
else // Slice-1
{
starts = node.Starts;
ends = node.Ends;
axes = node.AxesOptional(Enumerable.Range(0, starts.Length).ToArray());
steps = Enumerable.Repeat(1, starts.Length).ToArray();
net.StridedSlice(node.Name, node.Input0, starts: starts, ends: ends, strides: steps, axes: axes);
}
});
Add("Gather", (net, node) =>
{
int axis = node.AxisOptional(0);
if (node.IsInput0Const && node.IsInput1Const)
{
var indices = node.Input1Constant(onnxLayout:"ONNX", name:"indices").AsInts();
ONNXTensor gatheredTensor = constantTensors[node.Input0].Gather(axis, indices);
Const(node, gatheredTensor);
}
else
{
int input1Rank = node.Input1Rank;
if (node.IsInput1Const)
{
bool isIndicesIntValue = !node.IsInput1Array("indices");
// The original rank was cached above since our constant tensor requires a shape of rank 1 and original may have been a scalar
var indices = node.Input1Constant(onnxLayout: "ONNX", name: "indices").AsFloats();
var shape = isIndicesIntValue ? new int[] { } : new[] { indices.Length };
var constTensor = new ONNXTensor(new Tensor(new [] { indices.Length, 1, 1, 1, 1, 1, 1, 1 }, indices), shape);
Const(node.Input1, constTensor);
}
// for import conveintcy, gather with single int values and not int[] implemented with int[] followed by squeeze
if (node.Input1Rank == 0)
{
var gatherLayer = net.Gather(node.Name + "_Squeezed", node.Input0, node.Input1, axis, true);
net.Squeeze(node.Name, gatherLayer, new[] { axis });
}
else
{
net.Gather(node.Name, node.Input0, node.Input1, axis, true);
}
Output(node.Name, rank: input1Rank + node.Input0Rank - 1);
}
});
Add("ScatterND", (net, node) =>
{
string reduction = node.GetOptionalString("reduction", "none");
Layer.ScatterNDReductionMode reductionType = Layer.ScatterNDReductionMode.None;
if (reduction == "add")
reductionType = Layer.ScatterNDReductionMode.Add;
else if (reduction == "mul")
reductionType = Layer.ScatterNDReductionMode.Mul;
net.ScatterND(node.Name, node.Input0, node.Input1, node.Input2, reductionType);
});
Add("NonMaxSuppression", (net, node) =>
{
int centerPointBox = node.GetOptionalInt("center_point_box", 0);
var boxes = node.GetRequiredInput(0);
var scores = node.GetRequiredInput(1);
object maxOutputBoxesPerClass = 0f;
object iouThreshold = 0f;
object scoreThreshold = 0f;
if (node.InputCount > 4 && node.IsInput2Const && node.IsInput3Const && node.IsInput4Const
|| node.InputCount > 3 && node.IsInput2Const && node.IsInput3Const
|| node.InputCount > 2 && node.IsInput2Const)
{
// Use constant version (possibly with defaults)
maxOutputBoxesPerClass = node.Input2ConstantOptional((float)maxOutputBoxesPerClass, "ONNX", nameof(maxOutputBoxesPerClass))[0];
iouThreshold = node.Input3ConstantOptional((float)iouThreshold, "ONNX", nameof(iouThreshold))[0];
scoreThreshold = node.Input4ConstantOptional((float)scoreThreshold, "ONNX", nameof(scoreThreshold))[0];
}
else
{
// Use dynamic tensor version
maxOutputBoxesPerClass = node.Input2Optional;
iouThreshold = node.Input3Optional;
scoreThreshold = node.Input4Optional;
}
// NOTE: Intermediate NCHW -- op is implemented expecting NHWC by default, so this is non-runnable as-is
net.NonMaxSuppression(node.Name, boxes, scores, maxOutputBoxesPerClass, iouThreshold, scoreThreshold, centerPointBox);
Output(node, rank: 2);
});
Add("OneHot", (net, node) => {
node.UnsupportedAttribute("axis", -1);
var defaultOffOn = new Tensor(2, 0, new float[] {0, 1});
var depth = (int)node.Input1Constant(onnxLayout:"C", name:"depth")[0];
var offon = node.Input2ConstantOptional(defaultOffOn, onnxLayout:"C", name:"values");
net.OneHot(node.Name, node.Input0, depth, (int)offon[1], (int)offon[0]);
Output(node, features:depth, rank: node.Input0Rank + 1);
});
Add("RoiAlign", (net, node) =>
{
node.UnsupportedAttribute("mode"); // TODO support
int output_height = node.GetOptionalInt("output_height", 1);
int output_width = node.GetOptionalInt("output_width", 1);
int sampling_ratio = node.GetOptionalInt("sampling_ratio", 0);
float spatial_scale = node.GetOptionalFloat("spatial_scale", 1.0f);
net.RoiAlign(node.Name, node.Input0, node.Input1, node.Input2, output_height, output_width, sampling_ratio, spatial_scale);
});
Add("TopK", (net, node) => {
int axis = node.AxisOptional(-1);
// TopK-11 introduced these options
bool largest = node.GetOptionalInt("largest", 1) == 1;
// If sorted = false, then the output is undefined
bool sorted = node.GetOptionalInt("sorted", 1) == 1;
string kName;
if (node.InputCount > 1) // TopK-10 introduced K as an input tensor
{
kName = node.Input1;
}
else
{
// TopK-1
int k = node.GetRequiredInt("k");
kName = "Const_TopK";
var kTensor = new ONNXTensor(
data:new Tensor(new[] { 1, 1, 1, 1 }, new[] { (float)k }, kName),
onnxShape:new [] { 1 });
Const(node, kTensor);
}
Layer indices = net.TopKIndices(node.Outputs[1], node.Input0, kName, axis, largest, sorted);
Output(node.Outputs[1], rank: node.Input0Rank);
net.TopKValues(node.Outputs[0], node.Input0, indices, axis);
Output(node.Outputs[0], rank: node.Input0Rank);
});
Add("NonZero", (net, node) => {
if (node.IsInput0Const)
{
var nonZeroTensor = constantTensors[node.Input0].NonZero();
Const(node, nonZeroTensor);
}
else
{
net.NonZero(node.Name, node.Input0);
Output(node.Outputs[0], rank: 2);
}
});
Add("LSTM", (net, node) =>
{
node.UnsupportedAttribute("activation_alpha");
node.UnsupportedAttribute("activation_beta");
node.UnsupportedAttribute("activations", new[] { "Sigmoid", "Tanh", "Tanh" }); // Only Sigmoid is supported for now
node.UnsupportedAttribute("clip");
node.UnsupportedAttribute("direction", "forward"); // Only forward direction supported
node.UnsupportedAttribute("input_forget");
node.UnsupportedAttribute("layout"); // alternate layout not supported
int hiddenSize = node.GetRequiredInt("hidden_size");
string[] nodeInputs = node.Inputs;
int inputCount = nodeInputs.Length;
object W = node.Input1;
if (node.IsInput1Const)
W = node.Input1Constant(onnxLayout: "RKC", name: "W");
object R = node.Input2;
if (node.IsInput2Const)
R = node.Input2Constant(onnxLayout: "RKC", name: "R");
object B = node.Input3Optional;
if (inputCount > 3 && node.IsInput3Const)
{
B = node.Input3Constant(onnxLayout: "RC", name: "B");
}
else if (string.IsNullOrEmpty((string)B))
{
var tensor = new Tensor(new TensorShape(1, 8 * hiddenSize));
tensor.Fill(0);
B = net.Const($"Const_{node.Name}_B", tensor, rank: 2);
}
int outputCount = node.Outputs.Length;
string[] outputs = { node.Outputs[0],
outputCount > 1 ? node.Outputs[1] : null,
outputCount > 2 ? node.Outputs[2] : null };
string initialHidden = inputCount > 5 && !string.IsNullOrEmpty(nodeInputs[5]) ? node.Input5Optional : null;
string initialCell = inputCount > 6 && !string.IsNullOrEmpty(nodeInputs[6]) ? node.Input6Optional : null;
net.LSTM(node.Name, node.Input0, outputs, W, R, B, hiddenSize, initialHidden, initialCell);
Output(node.Outputs[0], rank:2); // Actually rank 4, but needs to be 2 for how we handle this layer (re-evaluate?)
if (outputCount > 1)
Output(node.Outputs[1], rank:2); // Actually rank 3, but needs to be 2 for how we handle this layer (re-evaluate?)
if (outputCount > 2)
Output(node.Outputs[2], rank:2); // Actually rank 3, but needs to be 2 for how we handle this layer (re-evaluate?)
});
// Activation ops
Add("Relu", (net, node) => { net.Relu(node.Name, node.Input0); });
Add("Softmax", (net, node) =>
{
const int defaultAxis = 1;
int axis = node.AxisOptional(defaultAxis);
net.Softmax(node.Name, node.Input0, axis, axisIs8D: true); // keep axis as is
});
Add("Tanh", (net, node) => { net.Tanh(node.Name, node.Input0); });
Add("Sqrt", (net, node) => { net.Sqrt(node.Name, node.Input0); });
Add("Sigmoid", (net, node) => { net.Sigmoid(node.Name, node.Input0); });
Add("Elu", (net, node) => { net.Elu(node.Name, node.Input0, node.AlphaOptional(1f)); });
Add("LeakyRelu",(net, node) => { net.LeakyRelu(node.Name, node.Input0, node.AlphaOptional(0.01f)); });
Add("Selu", (net, node) => { net.Selu(node.Name, node.Input0, node.AlphaOptional(1.67326f), node.GammaOptional(1.0507f)); });
Add("Swish", (net, node) => { net.Swish(node.Name, node.Input0); });
Add("PRelu", (net, node) => { net.PRelu(node.Name, node.Input0, node.Input1); });
Add("LogSoftmax", (net, node) =>
{
const int defaultAxis = 1;
int axis = node.AxisOptional(defaultAxis);
net.LogSoftmax(node.Name, node.Input0, axis, axisIs8D: true); // keep axis as is
});
// TODO: Add("Hardmax", (net, node) => { net.Hardmax(node.Name, node.Input0); node.UnsupportedAttribute("axis", 1); });
Add("Softplus", (net, node) => { net.Softplus(node.Name, node.Input0); });
// TODO: Add("Softsign", (net, node) => { net.Softsign(node.Name, node.Input0); });
Add("HardSigmoid", (net, node) => { net.HardSigmoid(node.Name, node.Input0, node.AlphaOptional(0.2f), node.BetaOptional(0.5f)); });
Add("Exp", (net, node) => { net.Exp(node.Name, node.Input0); });
Add("Log", (net, node) => { net.Log(node.Name, node.Input0); });
Add("Reciprocal", (net, node) => { net.Reciprocal(node.Name, node.Input0); });
Add("Abs", (net, node) => { net.Abs(node.Name, node.Input0); });
Add("Neg", (net, node) => { net.Neg(node.Name, node.Input0); });
Add("Ceil", (net, node) => { net.Ceil(node.Name, node.Input0); });
Add("Floor", (net, node) => { net.Floor(node.Name, node.Input0); });
Add("Round", (net, node) => { net.Round(node.Name, node.Input0); });
Add("Clip", (net, node) => {
float minValue = float.MinValue;
float maxValue = float.MaxValue;
if (node.InputCount > 1) // Clip-11
{
minValue = node.Input1ConstantOptional(minValue, onnxLayout:"C", name:"min")[0];
maxValue = node.Input2ConstantOptional(maxValue, onnxLayout:"C", name:"max")[0];
}
else
{
minValue = node.MinOptional(minValue);
maxValue = node.MaxOptional(maxValue);
}
net.Clip(node.Name, node.Input0, minValue, maxValue);
});
Add("Acos", (net, node) => { net.Acos(node.Name, node.Input0); });
Add("Acosh", (net, node) => { net.Acosh(node.Name, node.Input0); });
Add("Asin", (net, node) => { net.Asin(node.Name, node.Input0); });
Add("Asinh", (net, node) => { net.Asinh(node.Name, node.Input0); });
Add("Atan", (net, node) => { net.Atan(node.Name, node.Input0); });
Add("Atanh", (net, node) => { net.Atanh(node.Name, node.Input0); });
Add("Cos", (net, node) => { net.Cos(node.Name, node.Input0); });
Add("Cosh", (net, node) => { net.Cosh(node.Name, node.Input0); });
Add("Sin", (net, node) => { net.Sin(node.Name, node.Input0); });
Add("Sinh", (net, node) => { net.Sinh(node.Name, node.Input0); });
Add("Tan", (net, node) => { net.Tan(node.Name, node.Input0); });
Add("Erf", (net, node) => { net.Erf(node.Name, node.Input0); });
string[] GetArithmeticOpInputs(ONNXNodeWrapper node, ModelBuilder net)
{
string[] inputs = new string[node.Inputs.Length];
Array.Copy(node.Inputs, inputs, inputs.Length);
if (node.IsInput1Const)
{
string onnxLayout = "ONNX";
string constName = $"Const_{node.Input1}";
if (!constantTensors.ContainsKey(constName))
{
Tensor tensorData = node.Input1Constant(onnxLayout, node.Input1);
Layer layer = net.Const(constName, tensorData, rank: node.Input1Rank);
inputs[1] = layer.name;
Const(constName, new ONNXTensor(tensorData, tensorData.shape.ToArray()));
}
}
return inputs;
}
// Broadcast ops
Add("Add", (net, node) => { net.Add(node.Name, GetArithmeticOpInputs(node, net)); });
Add("Sum", (net, node) => { net.Add(node.Name, GetArithmeticOpInputs(node, net)); }); // Sum is implemented via Add
Add("Sub", (net, node) => { net.Sub(node.Name, GetArithmeticOpInputs(node, net)); });
Add("Mul", (net, node) => { net.Mul(node.Name, GetArithmeticOpInputs(node, net)); });
Add("Div", (net, node) => { net.Div(node.Name, GetArithmeticOpInputs(node, net)); });
Add("Pow", (net, node) => { net.Pow(node.Name, node.Inputs); });
Add("Min", (net, node) => { net.Min(node.Name, node.Inputs); });
Add("Max", (net, node) => { net.Max(node.Name, node.Inputs); });
Add("Mean", (net, node) => { net.Mean(node.Name, node.Inputs); });
// Logical ops
Add("Greater", (net, node) => { net.Greater(node.Name, node.Input0, node.Input1); });
Add("Less", (net, node) => { net.Less(node.Name, node.Input0, node.Input1); });
Add("LessOrEqual", (net, node) => { net.LessEqual(node.Name, node.Input0, node.Input1); });
Add("Equal", (net, node) => { net.Equal(node.Name, node.Input0, node.Input1); });
Add("Or", (net, node) => { net.LogicalOr(node.Name, node.Input0, node.Input1); });
Add("And", (net, node) => { net.LogicalAnd(node.Name, node.Input0, node.Input1); });
Add("Not", (net, node) => { net.LogicalNot(node.Name, node.Input0); });
Add("Sign", (net, node) => { net.Sign(node.Name, node.Input0); });
Add("Xor", (net, node) => { net.LogicalXor(node.Name, node.Input0, node.Input1); });
Add("Where", (net, node) => { net.Where(node.Name, node.Input0, node.Input1, node.Input2); });
// Padding ops
Add("MirrorPad", (net, node) =>
{
//Note: MirrorPad is not in onnx spec, it is a custom op from tensorflow implementing there own padding (aka symmetric).
node.UnsupportedAttribute("mode", "symmetric");
var value = node.GetOptionalFloat("value", 0.0f);
var autoPad = node.AutoPadMode();
// NOTE: Intermediate NCHW -- op is implemented expecting NHWC by default, so this is non-runnable as-is
if (node.InputCount == 1)
{
var pads = node.GetRequiredIntArray("pads");
net.Pad(node.Name, node.Input0, pads, value, Layer.PadMode.Symetric, Layer.AutoPad.NotSet);
}
else
net.Pad(node.Name, node.Input0, node.Input1, node.Input2Optional, Layer.PadMode.Symetric, Layer.AutoPad.NotSet);
});
Add("Pad", (net, node) =>
{
var value = node.GetOptionalFloat("value", 0.0f);
var modeType = node.PadMode();
var autoPadType = node.AutoPadMode();
// NOTE: Intermediate NCHW -- op is implemented expecting NHWC by default, so this is non-runnable as-is
if (node.InputCount == 1)
{
var pads = node.GetRequiredIntArray("pads");
net.Pad(node.Name, node.Input0, pads, value, modeType, autoPadType);
}
else
net.Pad(node.Name, node.Input0, node.Input1, node.Input2Optional, modeType, autoPadType);
});
// Pooling ops
Add("AveragePool", (net, node) => {
node.UnsupportedAttribute("ceil_mode", 0);
node.UnsupportedAttribute("count_include_pad", 0);
net.AvgPool2D(node.Name, node.Input0, node.KernelShape, node.Strides, node.Pads);
});
Add("MaxPool", (net, node) => {
node.UnsupportedAttribute("ceil_mode", 0);
node.UnsupportedAttribute("dilations", new[] {1, 1});
node.UnsupportedAttribute("storage_order", 0);
int[] strides = node.Strides;
int[] pads = node.Pads;
if (strides.Length == 1)
strides = new[] { 1, strides[0] };
UnityEngine.Debug.Assert(strides.Length == 2);
int[] kernelShape = node.KernelShape;
if (kernelShape.Length == 1)
kernelShape = new[] { kernelShape[0], 1 };
net.MaxPool2D(node.Name, node.Input0, kernelShape, strides, pads);
});
Add("GlobalAveragePool", (net, node) =>
{
// NOTE: Intermediate NCHW -- op is implemented expecting NHWC by default, so this is non-runnable as-is
net.GlobalAvgPool2D(node.Name, node.Input0);
});
Add("GlobalMaxPool", (net, node) =>
{
// NOTE: Intermediate NCHW -- op is implemented expecting NHWC by default, so this is non-runnable as-is
net.GlobalMaxPool2D(node.Name, node.Input0);
});
Add("Upsample", (net, node) =>
{
UpsampleNCHW(net, node, 1);
});
Add("Resize", (net, node) => {
var mode = node.ModeOptional("nearest");
var bilinear = IsModeBilinear(net, node, mode);
if (node.InputCount > 2) // Resize-11/13
{
node.UnsupportedAttribute("coordinate_transformation_mode", "half_pixel");
node.UnsupportedAttribute("cubic_coeff_a", -0.75f);
node.UnsupportedAttribute("exclude_outside", 0);
node.UnsupportedAttribute("extrapolation_value", 0f);
node.UnsupportedAttribute("nearest_mode", "round_prefer_floor");
// Inputs (3 - 4)
// X : T1
// roi : T2, It only takes effect when coordinate_transformation_mode is "tf_crop_and_resize"
// scales : tensor(float)
// sizes (optional) : tensor(int64)
// TODO: cropping via roi input
}
// NOTE: Intermediate NCHW -- op is implemented expecting NHWC by default and size as constants, so this is non-runnable as-is
if (node.InputCount == 4)
{
//Resize-11/13 using target size
net.Resample2D(node.Name, node.Input0, node.Input3, bilinear);
}
else
{
//Resize using scales
UpsampleNCHW(net, node, node.InputCount-1);
}
});
Add("Transpose", (net, node) =>
{
// From https://github.com/onnx/onnx/blob/master/docs/Operators.md#transpose
// By default, reverse the dimensions, otherwise permute the axes according to the values given.
if (node.IsInput0Const)
{
int inputTensorRank = constantTensors[node.Input0].rank;
var defaultPermutations = new int[inputTensorRank];
for (int i = 0; i < inputTensorRank; ++i)
defaultPermutations[i] = inputTensorRank - 1 - i;
var permutations = node.GetOptionalIntArray("perm", defaultPermutations);
var transposedTensor = constantTensors[node.Input0].Permute(permutations);
Const(node, transposedTensor);
}
else
{
var defaultPermutations = new[] { 0, 1, 2, 3, 4, 5 };
var permutations = node.GetOptionalIntArray("perm", defaultPermutations);
if (permutations.Length > 6)
throw new OnnxLayerImportException($"Transpose support up to 6 dimensions but got a permutations of rank {permutations}.");
net.Transpose(node.Name, node.Input0, permutations);
}
});
Add("DepthToSpace", (net, node) => {
net.DepthToSpace(node.Name, node.Input0, node.BlockSize, node.ModeOptional("DCR"));
});
Add("SpaceToDepth", (net, node) => {
net.SpaceToDepth(node.Name, node.Input0, node.BlockSize);
});
// Tensor ops
Add("Gemm", (net, node) => {
node.UnsupportedAttribute("alpha", 1.0f);
node.UnsupportedAttribute("beta", 1.0f);
if (node.IsInput1Const && node.IsInput2Const)
{
var weights = node.Input1Constant(node.TransBOptional() ? "KC" : "CK", name: "B");
var biases = node.Input2ConstantOptional(Bias(weights.shape), 0.0f, "C", name: "C");
var input0 = node.Input0;
int transposeA = node.GetOptionalInt("transA", 0);
if (transposeA == 1)
{
input0 = input0 + "_transpose";
net.Transpose(input0, node.Input0, new[] { 1, 0 });
}
net.Dense(node.Name, input0, weights, biases);
Output(node, features: weights.channels, rank: 2); // Gemm forces flatten of the input to rank 2
}
else
{
int transposeA = node.GetOptionalInt("transA", 0);
int transposeB = node.GetOptionalInt("transB", 0);
var input0 = node.Input0;
var input1 = node.Input1;
if (transposeA == 1)
{
input0 = input0 + "_transpose";
net.Transpose(input0, node.Input0, new[] { 1, 0 });
}
if (transposeB == 1)
{
input1 = input1 + "_transpose";
net.Transpose(input1, node.Input1, new[] { 1, 0 });
}
net.MatMul(node.Name, input0, input1);
if (node.InputCount == 3)
{
net.Add(node.Name + "_bias", new[] { node.Name, node.Input2 });
}
}
});
Add("MatMul", (net, node) => {
net.MatMul(node.Name, node.Input0, node.Input1);
Output(node, features: node.Input0Features, rank: Math.Max(node.Input0Rank, node.Input1Rank));
});
// Conv: imports ONNX Conv as DepthwiseConv2D, Conv2D or Conv3D, selected by the weight rank and the "group" attribute.
Add("Conv", (net, node) => {
// Dilations are collected as a (D, H, W) triple; all ones means no dilation is applied.
int[] dilationsDHW = new[] { 1, 1, 1 }; // @TODO trap on wrong values
int[] strides = node.Strides;
int[] pads = node.Pads;
node.IgnoredAttribute("kernel_shape", "Kernel shape is derived from K tensor weights instead");
// Ideally, we'd import kernels/biases in native ONNX layout, but we already have to transpose input since the op doesn't work natively in NCHW
var kernels = node.Input1Constant(onnxLayout: "KCHW", name: "W");
var kernelRank = node.Input1Rank;
if (kernelRank == 3) // Conv1D
{
// 1D attributes are widened to their 2D form (an extra stride of 1 and extra pads of 0) so the 2D op can run it;
// the single dilation value is placed in the last (W) slot of the triple.
dilationsDHW = node.DilatationsOptional(new[] { 1 }); // @TODO trap on wrong values
UnityEngine.Debug.Assert(dilationsDHW.Length == 1);
dilationsDHW = new[] { 1, 1, dilationsDHW[0] };
if (strides.Length == 1)
strides = new[] { strides[0], 1 };
if (pads.Length == 2)
pads = new[] { pads[0], 0, pads[1], 0 };
}
else if (kernelRank == 4) // Conv2D
{
dilationsDHW = node.DilatationsOptional(new[] { 1, 1 });
UnityEngine.Debug.Assert(dilationsDHW.Length == 2);
dilationsDHW = new[] { 1, dilationsDHW[0], dilationsDHW[1] };
}
else if (kernelRank == 5) // Conv3D
{
// Grouped 3D convolution is rejected; 3D pads/strides come from dedicated accessors.
//TODO specific error message for DepthwiseConv3D (or support it).
node.UnsupportedAttribute("group", 1);
dilationsDHW = node.DilatationsOptional(new[] { 1, 1, 1 });
UnityEngine.Debug.Assert(dilationsDHW.Length == 3);
pads = node.Pads3D;
strides = node.Strides3D;
}
else
{
// Only a warning: import continues and, below, ranks < 5 take the Conv2D path and ranks > 5 the Conv3D path.
Warn(net, node, $"Unsuported Conv kernel rank. Conv1D/2D/3 assumes rank 3/4/5 respectively, but got {kernelRank}.");
}
UnityEngine.Debug.Assert(dilationsDHW.Length == 3);
// Dilation is not passed to the conv layer; instead the kernel itself is rewritten by DilateKernel.
if (dilationsDHW[0] != 1 || dilationsDHW[1] != 1 || dilationsDHW[2] != 1)
kernels = DilateKernel(kernels, dilationsDHW); // @TODO inefficient method. Support dilatation in kernel code properly
var biases = node.Input2ConstantOptional(Bias(kernels.shape), 0.0f, onnxLayout: "C", name: "B");
// NOTE: Intermediate NCHW -- op is implemented expecting NHWC by default, so this is non-runnable as-is
// TODO assert correctly: with group == 2 or group != in_channel we don't support it
// Any group > 1 is imported as depthwise, regardless of the actual group count.
if (node.GroupOptional() > 1)
net.DepthwiseConv2D(node.Name, node.Input0, strides, pads, kernels, biases);
else
{
if (kernelRank < 5)
net.Conv2D(node.Name, node.Input0, strides, pads, kernels, biases);
else
net.Conv3D(node.Name, node.Input0, strides, pads, kernels, biases);
}
Output(node, features: kernels.channels);
});
Add("ConvTranspose", (net, node) =>
{
    // Grouped / output_shape variants are not supported; kernel_shape is inferred from the weights.
    node.UnsupportedAttribute("group", 1);
    node.UnsupportedAttribute("output_shape", new int[0]);
    node.IgnoredAttribute("kernel_shape", "Kernel shape is derived from K tensor weights instead");
    int[] convTStrides = node.Strides;
    int[] convTPads = node.Pads;
    int[] convTOutputPadding = node.OutputPadding;
    var convTKernelRank = node.Input1Rank;
    switch (convTKernelRank)
    {
        case 3: // ConvTranspose1D: widen the 1D attributes to 2D (extra stride 1, extra pads/output padding 0)
            node.UnsupportedAttribute("dilations", new[] { 1 });
            if (convTStrides.Length == 1)
                convTStrides = new[] { convTStrides[0], 1 };
            if (convTPads.Length == 2)
                convTPads = new[] { convTPads[0], 0, convTPads[1], 0 };
            if (convTOutputPadding.Length == 1)
                convTOutputPadding = new[] { convTOutputPadding[0], 0 };
            break;
        case 4: // ConvTranspose2D
            node.UnsupportedAttribute("dilations", new[] { 1, 1 });
            break;
        default:
            Warn(net, node, $"Unsuported ConvTranspose kernel rank. ConvTranspose1D/2D assumes rank 3/4 respectively, but got {convTKernelRank}.");
            break;
    }
    // Weights are imported in "CKHW" layout; the input itself stays NCHW (see NOTE below).
    var convTKernels = node.Input1Constant(onnxLayout: "CKHW", name: "W");
    var convTBiases = node.Input2ConstantOptional(Bias(convTKernels.shape), 0.0f, onnxLayout: "C", name: "B");
    // NOTE: Intermediate NCHW -- op is implemented expecting NHWC by default, so this is non-runnable as-is
    net.Conv2DTrans(node.Name, node.Input0, convTStrides, convTPads, convTOutputPadding, convTKernels, convTBiases);
    Output(node, features: convTKernels.channels);
});
Add("BatchNormalization", (net, node) =>
{
    // All four parameters are imported in "C" layout; scale/B/mean fall back to 1/0/0 shaped like "var".
    var bnVariance = node.Input4Constant(onnxLayout: "C", name: "var");
    var bnParamShape = bnVariance.shape;
    var bnScale = node.Input1ConstantOptional(bnParamShape, 1.0f, onnxLayout: "C", name: "scale");
    var bnBias = node.Input2ConstantOptional(bnParamShape, 0.0f, onnxLayout: "C", name: "B");
    var bnMean = node.Input3ConstantOptional(bnParamShape, 0.0f, onnxLayout: "C", name: "mean");
    bool bnLengthsAgree = bnVariance.length == bnScale.length
        && bnScale.length == bnBias.length
        && bnBias.length == bnMean.length;
    if (!bnLengthsAgree)
    {
        Warn(net, node, $"Number of elements in all parameters for BatchNorm must be the same." +
            $"Parameter shapes are: {bnScale.shape}, {bnBias.shape}, {bnMean.shape}, {bnVariance.shape}");
    }
    // TODO: Jeremy has one non valid onnx model with #channels > than input_channels, see if we want to support his model?
    // Fold (scale, bias, mean, variance, epsilon) into a single per-channel scale/bias pair.
    var bnFused = FuseBatchNormWeights(bnScale, bnBias, bnMean, bnVariance, node.EpsilonOptional(), bnParamShape.channels);
    // NOTE: Intermediate NCHW -- op is implemented expecting NHWC by default, so this is non-runnable as-is
    net.ScaleBias(node.Name, node.Input0, bnFused.Item1, bnFused.Item2);
});
Add("ImageScaler", (net, node) =>
{
    // ImageScaler = one uniform scale for every channel plus a per-channel bias, imported as ScaleBias.
    var perChannelBias = node.Bias;
    var uniformScale = node.ScaleOptional();
    int channelCount = perChannelBias.Length;
    var scaleTensor = new Tensor(1, channelCount);
    var biasTensor = new Tensor(1, channelCount);
    for (int channel = 0; channel < channelCount; channel++)
    {
        scaleTensor[channel] = uniformScale;
        biasTensor[channel] = perChannelBias[channel];
    }
    net.ScaleBias(node.Name, node.Input0, scaleTensor, biasTensor);
});
Add("InstanceNormalization", (net, node) =>
{
    // Scale/bias are imported in "C" layout; the bias defaults to zeros shaped like the scale.
    var inScale = node.Input1Constant(onnxLayout: "C", name: "scale");
    var inBias = node.Input2ConstantOptional(inScale.shape, 0.0f, onnxLayout: "C", name: "B");
    if (inScale.length != inBias.length)
    {
        Warn(net, node, $"Number of elements in all parameters for InstanceNorm must be the same." +
            $"Parameter shapes are: {inScale.shape}, {inBias.shape}");
    }
    int incomingFeatures = node.Input0Features;
    bool featureCountMismatch = inScale.channels != incomingFeatures && incomingFeatures > 0;
    if (featureCountMismatch)
    {
        Warn(net, node, $"Number of elements in InstanceNorm must match features from the previous layer. Was expecting {incomingFeatures}, but got {inScale.channels}.");
        // Truncate or zero-pad both parameter vectors to the incoming feature count.
        var resizedScale = inScale.ToReadOnlyArray();
        Array.Resize(ref resizedScale, incomingFeatures);
        var resizedBias = inBias.ToReadOnlyArray();
        Array.Resize(ref resizedBias, incomingFeatures);
        inScale = new Tensor(1, incomingFeatures, resizedScale);
        inBias = new Tensor(1, incomingFeatures, resizedBias);
    }
    // NOTE: Intermediate NCHW -- op is implemented expecting NHWC by default, so this is non-runnable as-is
    net.Normalization(node.Name, node.Input0, inScale, inBias, node.EpsilonOptional());
});
Add("LRN", (net, node) =>
{
    // Local response normalization; defaults: bias 1, alpha 1e-4, beta 0.75. "size" is mandatory.
    var lrnBias = node.GetOptionalFloat("bias", 1.0f);
    var lrnSize = node.GetRequiredInt("size");
    var lrnAlpha = node.AlphaOptional(0.0001f);
    var lrnBeta = node.BetaOptional(0.75f);
    net.LRN(node.Name, node.Input0, lrnAlpha, lrnBeta, lrnBias, lrnSize);
});
// Random ops: the fixed-shape variants convert the ONNX "shape" attribute, while the *Like variants take the shape from Input0.
Add("RandomNormal", (net, node) =>
{
    var normalShape = ONNXLayout.ConvertShapeToBarracuda(onnxShape: node.Shape, onnxLayout: "ONNX");
    var normalMean = node.MeanOptional();
    var normalScale = node.ScaleOptional();
    net.RandomNormal(node.Name, normalShape, normalMean, normalScale, node.Seed);
    Output(node, rank: node.Shape.Length);
});
Add("RandomNormalLike", (net, node) =>
{
    var likeMean = node.MeanOptional();
    var likeScale = node.ScaleOptional();
    net.RandomNormal(node.Name, node.Input0, likeMean, likeScale, node.Seed);
});
Add("RandomUniform", (net, node) =>
{
    var uniformHigh = node.GetOptionalFloat("high", 1.0f);
    var uniformLow = node.GetOptionalFloat("low", 0.0f);
    var uniformShape = ONNXLayout.ConvertShapeToBarracuda(onnxShape: node.Shape, onnxLayout: "ONNX");
    net.RandomUniform(node.Name, uniformShape, uniformLow, uniformHigh, node.Seed);
    Output(node, rank: node.Shape.Length);
});
Add("RandomUniformLike", (net, node) =>
{
    var likeHigh = node.GetOptionalFloat("high", 1.0f);
    var likeLow = node.GetOptionalFloat("low", 0.0f);
    net.RandomUniform(node.Name, node.Input0, likeLow, likeHigh, node.Seed);
});
Add("Multinomial", (net, node) =>
{
    var sampleCount = node.GetOptionalInt("sample_size", 1);
    net.Multinomial(node.Name, node.Input0, sampleCount, node.Seed);
});
Add("Range", (net, node) =>
{
    // Range(start, limit, delta): fold to a constant when all three inputs are known at import time,
    // otherwise emit a runtime Range layer.
    if (node.IsInput0Const && node.IsInput1Const && node.IsInput2Const)
    {
        // The third argument is the ONNX input name; each input gets its own name (previously all were "start").
        var startTensor = node.GetRequiredInputAsConstant(node.Input0, "N", "start");
        var limitTensor = node.GetRequiredInputAsConstant(node.Input1, "N", "limit");
        var deltaTensor = node.GetRequiredInputAsConstant(node.Input2, "N", "delta");
        // Each input must be a scalar (exactly one element). Expected value goes first.
        Assert.AreEqual(1, startTensor.length);
        Assert.AreEqual(1, limitTensor.length);
        Assert.AreEqual(1, deltaTensor.length);
        float start = startTensor[0];
        float limit = limitTensor[0];
        float delta = deltaTensor[0];
        var range = ONNXTensor.Range(start, limit, delta);
        Const(node, range);
    }
    else
    {
        net.Range(node.Name, node.Input0, node.Input1, node.Input2);
    }
});
// Reduce ops: every variant shares ReduceNCHW and differs only in the Barracuda layer type.
Add("ReduceMax", (net, node) =>
{
    ReduceNCHW(net, node, Layer.Type.ReduceMax);
});
Add("ReduceMean", (net, node) =>
{
    ReduceNCHW(net, node, Layer.Type.ReduceMean);
});
Add("ReduceMin", (net, node) =>
{
    ReduceNCHW(net, node, Layer.Type.ReduceMin);
});
Add("ReduceProd", (net, node) =>
{
    ReduceNCHW(net, node, Layer.Type.ReduceProd);
});
Add("ReduceSum", (net, node) =>
{
    ReduceNCHW(net, node, Layer.Type.ReduceSum);
});
// ArgMax/ArgMin first flag the "select_last_index" attribute as unsupported.
Add("ArgMax", (net, node) =>
{
    node.UnsupportedAttribute("select_last_index");
    ReduceNCHW(net, node, Layer.Type.ArgMax);
});
Add("ArgMin", (net, node) =>
{
    node.UnsupportedAttribute("select_last_index");
    ReduceNCHW(net, node, Layer.Type.ArgMin);
});
// Pass-through at inference time: each op is imported as an Identity layer forwarding its first input.
Add("Identity", (net, node) =>
{
    net.Identity(node.Name, node.Input0);
});
Add("Cast", (net, node) =>
{
    net.Identity(node.Name, node.Input0);
});
Add("Dropout", (net, node) =>
{
    net.Identity(node.Name, node.Input0);
});
}
void UseLegacyImporter()
{
// Reset the importer table; every op supported by the legacy importer is registered below via Add(...).
m_NodeImporters.Clear();
// Single-element constants used as fallbacks for optional values (e.g. the ConstantOfShape "value" attribute).
var defaultZeroTensor = new ONNXTensor(new Tensor(1, 1, new[] { 0f }), new[] { 1 });
var defaultOneTensor = new ONNXTensor(new Tensor(1, 1, new[] { 1f }), new[] { 1 });
// Rank-4 axis permutations between Barracuda's NHWC order and ONNX's NCHW order.
var toNCHW = new [] { 0, 3, 1, 2 };
var toNHWC = new [] { 0, 2, 3, 1 };
// Rank-3 variants for an ONNX tensor held in Barracuda as N1WC. {0, 3, 2, 1} is its own inverse,
// so both directions use the same permutation.
var fromN1WCtoNCH = new [] { 0, 3, 2, 1 };
var fromNCHtoN1WC = new [] { 0, 3, 2, 1 };
// TODO: setup m_NodeImporters via initializer list
// TODO: simplify code to avoid passing node.Name over and over again
Add("Constant", (net, node) =>
{
    // Dense constants only; sparse_value is flagged as unsupported.
    node.UnsupportedAttribute("sparse_value");
    var constantValue = node.ValueAsTensor;
    Const(node, constantValue);
});
Add("ConstantOfShape", (net, node) =>
{
    // Materialize a constant of the (import-time known) input shape, filled with "value" (0 when absent).
    Assert.IsTrue(node.InputCount > 0);
    var fillSource = node.GetOptionalTensor("value", defaultZeroTensor);
    var targetOnnxShape = node.Input0ConstantONNXShape(name: "input");
    var filledTensor = new Tensor(ONNXLayout.ConvertShapeToBarracuda(targetOnnxShape, onnxLayout: "?"));
    filledTensor.Fill(fillSource[0]);
    Const(node, new ONNXTensor(filledTensor, targetOnnxShape));
});
// Reshape: handles Reshape-1 (target shape from the "shape" attribute) and Reshape-5 (target shape from input 1),
// with either a constant or a runtime input tensor.
Add("Reshape", (net, node) => {
int[] onnxShape;
if (node.InputCount > 1) // Reshape-5
{
// Target shape known at import time: read it and fall through to the shared code below.
if (node.IsInput1Const)
{
onnxShape = node.Input1Constant(onnxLayout: "C", name: "shape").AsInts();
}
else
{
// Target shape only known at runtime: emit a Reshape layer that reads node.Input1, then return early.
int input0Rank = node.Input0Rank;
if (input0Rank <= 4 && variableTensors.TryGetValue(node.Input0, out VariableTensor previousOutput)
&& previousOutput.layout != VariableTensor.Layout.ChannelsLast)
{
// Output rank defaults to 4. If the shape tensor is a rank-1 model input, the value stored in its
// batch slot is taken as the output rank (presumably the shape tensor's element count — TODO confirm).
int outputRank = 4;
Model.Input input1 = net.model.inputs.Where(i => i.name == node.Input1).FirstOrDefault();
if (!input1.Equals(default))
{
if (input1.rank == 1) // shape is in the tensor
outputRank = input1.shape[TensorShape.DataBatch];
}
// For handling all reshapes correctly with dynamic shapes (e.g. rank 3) perform in NCHW layout
Layer nchwTranspose = net.Transpose($"Transpose_{node.Input0}_For_{node.Name}", node.Input0, input0Rank == 3 ? fromN1WCtoNCH : toNCHW);
Layer reshape = net.Reshape($"{node.Name}_NCHW", nchwTranspose, node.Input1);
net.Transpose(node.Name, reshape, outputRank == 3 ? fromNCHtoN1WC : toNHWC);
// NOTE(review): rank is recorded as 4 even when outputRank was resolved to 3 above — confirm intended.
Output(node, rank:4);
}
else
{
// Input already channels-last (or rank > 4): reshape directly with the runtime shape.
net.Reshape(node.Name, node.Input0, node.Input1);
}
return;
}
}
else // Reshape-1
onnxShape = node.Shape;
// From here on the target shape (onnxShape) is known at import time.
if (node.IsInput0Const)
{
// reshape constant source tensor and store it as the new constant
var reshapedTensor = constantTensors[node.Input0].Reshape(onnxShape);
Const(node, reshapedTensor);
}
else
{
Layer reshapeLayer = null;
int numDimensionContainingChannelsInformationAfterReshape = 1;
// symbolicShape is the target shape converted to Barracuda order; -1 entries are variable dimensions.
var symbolicShape = ONNXLayout.ConvertReshapeToBarracuda(onnxShape, node.Input0Rank, out numDimensionContainingChannelsInformationAfterReshape);
int variableDimension = Array.IndexOf(symbolicShape, -1);
bool containsNoVariableDimensions = variableDimension == -1;
// special case handling with inferable reshapes
// TODO: remove this once we have full shape inference
// onnx: NCW -> N1CW
// N: is unknown and H,W are inferable
if (node.Input0Rank == 3 && onnxShape.Length == 4 &&
onnxShape[0] == 0 && onnxShape[1] == 1 && onnxShape[2] == 0 && onnxShape[3] == 0)
{
// onnx: NCW -> N1CW
// barracuda: N_WC -> NCW1
net.Transpose(node.Name, node.Input0, new[] { 0, 3, 2, 1 });
Output(node, features: 1, rank: onnxShape.Length);
return;
}
if (containsNoVariableDimensions)
{
if (m_ForceArbitraryBatchSize)
symbolicShape[0] = -1; // force arbitrary batch size
// Creating any of the spatial dimensions requires to run reshape in NCHW and transpose to NHWC after it to match NCHW behavior.
if (onnxShape.Length > 2 && node.Input0Rank <= 2)
{
int[] onnxPaddedShape = onnxShape;
if (onnxShape.Length == 3) // correct NCH to NCW
onnxPaddedShape = new[] {onnxShape[0], onnxShape[1], 1, onnxShape[2]};
reshapeLayer = net.Reshape($"{node.Name}_NCHW", node.Input0, onnxPaddedShape);
reshapeLayer = net.Transpose(node.Name, reshapeLayer, toNHWC);
}
}
else if (onnxShape.Length <= 4 && node.Input0Rank <= 4
&& (onnxShape.Length == 2 || variableDimension != TensorShape.C)
&& variableTensors.TryGetValue(node.Input0, out VariableTensor previousOutput)
&& previousOutput.layout != VariableTensor.Layout.ChannelsLast)
{
// Collapsing any of the spatial dimensions requires a reshape in NCHW layout
// (target shape is padded to rank 4 with trailing 1s, or NCH -> NC1W for rank 3).
int[] onnxPaddedShape;
if (onnxShape.Length == 3) // correct NCH to NCW
onnxPaddedShape = new[] { onnxShape[0], onnxShape[1], 1, onnxShape[2] };
else
onnxPaddedShape = onnxShape.Concat(Enumerable.Repeat(1, 4 - onnxShape.Length)).ToArray();
Layer nchwTranspose = net.Transpose($"Transpose_{node.Input0}_For_{node.Name}", node.Input0, toNCHW);
reshapeLayer = net.Reshape($"{node.Name}_NCHW", nchwTranspose, onnxPaddedShape);
reshapeLayer = net.Transpose(node.Name, reshapeLayer, toNHWC);
}
// No layout-sensitive case applied: plain reshape directly in Barracuda order.
if (reshapeLayer == null)
reshapeLayer = net.Reshape(node.Name, node.Input0, symbolicShape);
// The layer's axis field records the channel-dimension count reported by ConvertReshapeToBarracuda.
reshapeLayer.axis = numDimensionContainingChannelsInformationAfterReshape;
var features = onnxShape.Length > 1 ? onnxShape[1] : -1;
Output(node, features: features, rank:onnxShape.Length);
}
});
Add("Expand", (net, node) =>
{
    // The target shape must be a constant; it is converted to Barracuda order (-1 marks a variable dimension).
    var targetOnnxShape = node.Input1Constant(onnxLayout: "C", name: "shape").AsInts();
    var expandShape = ONNXLayout.ConvertSymbolicShapeToBarracuda(targetOnnxShape, "NCHW");
    bool hasVariableDimension = Array.IndexOf(expandShape, -1) != -1;
    if (!hasVariableDimension && m_ForceArbitraryBatchSize)
    {
        expandShape[0] = -1; // force arbitrary batch size
    }
    net.Expand(node.Name, node.Input0, expandShape);
    Output(node, rank: expandShape.Length);
});
// Shape: always produces an import-time constant. For a constant input the exact shape is used; for a runtime
// input only the channel count is known, and the batch/spatial dimensions are recorded as 0 ("unknown").
Add("Shape", (net, node) =>
{
float[] shapeValuesAsFloats;
if (node.IsInput0Const)
{
shapeValuesAsFloats = constantTensors[node.Input0].shape.Select(x => (float)x).ToArray();
}
else
{
switch (node.Input0Rank)
{
default:
case 4: // NCHW
case 3: // NCW
case 2: // NC
// @TODO: dynamic implementation that would return real shape during execution of the model
//
// meanwhile at import time we assume 0 (taken from input tensor) for the spatial dimensions
// NOTE: this assumption works for common Upsample opset=9 case:
// Upsample.scales = (shape.hw * constant) / shape.hw
// however this would not work for potential (opset=10) cases like:
// Resize.size = shape.hw + constant
// stored in ONNX layout
// NOTE(review): the default label also catches ranks < 0 (if used for "unknown"), for which
// numberOfSpatialDimensions would be negative and Enumerable.Repeat would throw — confirm.
var shapeWithChannelsFirst = new[] { 0f, node.Input0Features }; // NC
var fillSpatialDimensionsWithUnknown = 0f;
var numberOfSpatialDimensions = node.Input0Rank - 2;
var shapeFollowedWithSpatialDimensions = Enumerable.Repeat(fillSpatialDimensionsWithUnknown, numberOfSpatialDimensions);
shapeValuesAsFloats = shapeWithChannelsFirst.Concat(shapeFollowedWithSpatialDimensions).ToArray();
break;
case 1: // C
shapeValuesAsFloats = new[] {(float)node.Input0Features};
break;
case 0: // scalar
shapeValuesAsFloats = new[] {0f};
break;
}
}
var shapeLength = shapeValuesAsFloats.Length;
Assert.IsTrue(shapeLength == node.Input0Rank);
// 8-entry Barracuda shape: the first entry holds the element count, the rest keep their default of 0.
var shape = new int[8];
shape[0] = shapeLength;
var shapeTensor = new ONNXTensor(
// NOTE: stored in single rank ONNX layout
// with data in the 1st dimension
// thus `shapeLength` specifies the length of the 1st dimension
data:new Tensor(shape, shapeValuesAsFloats),
onnxShape:new [] { shapeLength });
Const(node, shapeTensor);
// productOfShape remembers which tensor this shape came from (Gather uses it to rewrite into a runtime Shape layer).
Output(node, features:shapeLength, productOfShape:node.Input0);
});
Add("Unsqueeze", (net, node) =>
{
    if (node.IsInput0Const)
    {
        // Constant input: apply the Unsqueeze at import time.
        var unsqueezed = constantTensors[node.Input0].Unsqueeze(node.Axes);
        Const(node, unsqueezed);
    }
    else
    {
        // NOTE: axis=0 or 1 will require Transpose between channels and other spatial dimensions when converted to Barracuda layout.
        // As we have different layouts between ONNX and Barracuda, Unsqueeze might require actual Transpose not just Reshape!
        var features = node.Input0Features;
        var inputRank = node.Input0Rank;
        var outputRank = inputRank + 1;
        Output(node.Name, features: features, rank: outputRank);
        // ONNX pseudocode here:
        // a = Tensor [2, 10] # NC -> barracuda N11C
        // b = Unsqueeze(a, axis=0)
        // # b is now Tensor [1, 2, 10] # NCHW -> barracuda NHWC
        // ONNX is NCHW (with model-specific axis meaning) while Barracuda is strictly NHWC, so we end up with:
        // `a` would be [2, 1, 1, 10], but `b` would have to be [1, 10, 1, 2]. Note the actual data swap in channels!
        // Only the first entry of "axes" is honored here.
        int axis = node.Axes[0];
        // ONNX Unsqueeze counts negative axes from the end of the *output* tensor: the valid range is
        // [-outputRank, outputRank - 1], so e.g. -1 maps to outputRank - 1 (a new trailing axis).
        if (axis < 0)
            axis += outputRank;
        var transpose = ONNXLayout.UnSqueezeAxisPermutationForMappingONNXLayoutToBarracuda(inputRank, axis, "NCHW");
        net.Transpose(node.Name, node.Input0, transpose);
    }
});
Add("Squeeze", (net, node) =>
{
    if (node.IsInput0Const)
    {
        // Constant input: apply the Squeeze at import time.
        var squeezed = constantTensors[node.Input0].Squeeze(node.Axes);
        Const(node, squeezed);
    }
    else
    {
        var features = node.Input0Features;
        var inputRank = node.Input0Rank;
        var outputRank = inputRank - 1;
        Output(node.Name, features: features, rank: outputRank);
        // Dropping an axis can move data between Barracuda's channels-last slots, so this is a Transpose rather
        // than a plain Reshape. Only the first entry of "axes" is honored here.
        int axis = node.Axes[0];
        // ONNX Squeeze counts negative axes from the end of the *input* tensor: the valid range is
        // [-inputRank, inputRank - 1], so e.g. -1 maps to inputRank - 1 (the last axis).
        if (axis < 0)
            axis += inputRank;
        var transpose = ONNXLayout.SqueezeAxisPermutationForMappingONNXLayoutToBarracuda(inputRank, axis, "NCHW");
        net.Transpose(node.Name, node.Input0, transpose);
    }
});
Add("Flatten", (net, node) =>
{
    // Only the default axis (1) is accepted.
    node.UnsupportedAttribute("axis", 1);
    bool sourceIsChannelsLast = variableTensors.TryGetValue(node.Input0, out var flattenSource)
        && flattenSource.layout == VariableTensor.Layout.ChannelsLast;
    if (sourceIsChannelsLast)
    {
        net.Flatten(node.Name, node.Input0);
    }
    else
    {
        // Bring the input back to NCHW (NCH for rank 3) so the flattened order matches ONNX.
        // No transpose back is needed: the result is always rank 2 (NC).
        var toOnnxOrder = node.Input0Rank == 3 ? fromN1WCtoNCH : toNCHW;
        Layer onnxOrderedInput = net.Transpose($"Transpose_{node.Input0}_For_{node.Name}", node.Input0, toOnnxOrder);
        net.Flatten(node.Name, onnxOrderedInput);
    }
    Output(node, rank: 2);
});
Add("Concat", (net, node) =>
{
    int concatAxis = node.AxisOptional(0);
    if (node.Inputs.Length == 1)
    {
        // A single-input Concat is a pass-through.
        net.Identity(node.Name, node.Input0);
        return;
    }
    // TODO: write a dedicated ONNXTensor.Concat() so the output shape matches ONNX exactly; that would allow
    // folding all-constant inputs at import time:
    //   if (node.AreAllInputsConst)
    //       Const(node, ONNXTensor.Concat(node.Inputs.Select(i => constantTensors[i]).ToArray(), axis));
    concatAxis = ONNXLayout.ConvertAxisToBarracuda(concatAxis, onnxRank: node.Input0Rank, onnxLayout: "NCHW");
    net.Concat(node.Name, node.Inputs, concatAxis, true);
    // The feature count is only updated when concatenating along the feature axis (the last axis in Barracuda).
    bool alongFeatureAxis = concatAxis == -1 || concatAxis == TensorShape.C || concatAxis == node.Input0Rank - 1;
    if (alongFeatureAxis)
    {
        var totalFeatures = node.Inputs.Sum(inputName => variableTensors[inputName].features);
        Output(node, features: totalFeatures);
    }
});
Add("Split", (net, node) =>
{
    int splitAxis = node.AxisOptional(0);
    int[] splitSizes;
    try
    {
        splitSizes = node.GetRequiredIntArray("split");
    }
    catch (OnnxLayerImportException)
    {
        // Equal splitting (no "split" attribute) is not supported.
        throw new OnnxLayerImportException($"Unsupported default attribute `split` for node {node.Name} of type Split. Value is required.");
    }
    Assert.IsTrue(splitSizes.Length == node.Outputs.Length);
    splitAxis = ONNXLayout.ConvertAxisToBarracuda(splitAxis, onnxRank: node.Input0Rank, onnxLayout: "NCHW");
    // Each output becomes one StridedSlice over consecutive ranges of splitAxis;
    // on every other axis begin=0/end=0/stride=1 selects the full range.
    int sliceBegin = 0;
    for (int outputIndex = 0; outputIndex < splitSizes.Length; outputIndex++)
    {
        int sliceLength = splitSizes[outputIndex];
        var sliceStarts = new int[8];
        var sliceEnds = new int[8];
        var sliceStrides = Enumerable.Repeat(1, 8).ToArray();
        sliceStarts[splitAxis] = sliceBegin;
        sliceEnds[splitAxis] = sliceBegin + sliceLength;
        net.StridedSlice(node.Outputs[outputIndex], node.Input0, sliceStarts, sliceEnds, sliceStrides);
        sliceBegin += sliceLength;
    }
});
// Slice: collects starts/ends/axes/steps (from inputs for Slice-10+, from attributes for Slice-1), scatters them into
// full-rank ONNX arrays, then either slices a constant at import time or emits a StridedSlice layer.
Add("Slice", (net, node) => {
int[] starts, ends, axes, steps;
if (node.InputCount > 1) // Slice-10
{
// All of starts/ends/axes/steps must be constants; axes default to 0..n-1 and steps to 1.
var constStarts = node.Input1Constant(onnxLayout:"C", name:"starts");
var constEnds = node.Input2Constant(onnxLayout:"C", name:"ends");
var defaultAxes = new Tensor(constStarts.shape, Enumerable.Range(0, constStarts.length).Select(v => (float)v).ToArray());
var constAxes = node.Input3ConstantOptional(defaultAxes, onnxLayout:"C", name:"axes");
var constSteps = node.Input4ConstantOptional(constStarts.shape, 1.0f, onnxLayout:"C", name:"steps");
starts = constStarts.AsInts();
ends = constEnds.AsInts();
axes = constAxes.AsInts();
steps = constSteps.AsInts();
}
else // Slice-1
{
starts = node.Starts;
ends = node.Ends;
axes = node.AxesOptional(Enumerable.Range(0, starts.Length).ToArray());
steps = Enumerable.Repeat(1, starts.Length).ToArray();
}
Assert.IsTrue(starts.Length == ends.Length);
var onnxRank = node.Input0Rank;
var onnxStarts = Enumerable.Repeat(0, onnxRank).ToArray();
var onnxEnds = Enumerable.Repeat(int.MaxValue, onnxRank).ToArray(); // by default copy the whole axis till the end
var onnxSteps = Enumerable.Repeat(1, onnxRank).ToArray();
// NOTE: begin=0, end=0, stride=1 <= full range from existing axis
// begin=0, end=inf,stride=1 <= full range from existing axis
// begin=0, end=X, stride=1 <= full range from existing axis, if X==last element on this axis
// begin=0, end=0, stride=0 <= new axis OR shrink axis to single 1st element
// begin=N, end=N, stride=0 <= shrink axis to single Nth element
// These notes are copied from TensorExtensions.ApplyStridedSlice(...)
for (int i = 0; i < axes.Length; ++i)
{
var axis = axes[i];
if (axis < 0)
axis += onnxRank;
// NOTE(review): the upper clamp bound is onnxRank, which is not a valid index, so an axis >= onnxRank
// still fails with IndexOutOfRangeException; an axis < -onnxRank is silently mapped to axis 0.
// Also, axes.Length > starts.Length is not checked before indexing starts/ends/steps.
axis = Math.Min(Math.Max(axis, 0), onnxRank);
onnxStarts[axis] = starts[i];
onnxEnds[axis] = ends[i];
onnxSteps[axis] = steps[i];
}
if (node.IsInput0Const)
{
var slicedTensor = constantTensors[node.Input0].Slice(starts:onnxStarts, ends:onnxEnds, steps:onnxSteps);
Const(node, slicedTensor);
}
else
{
// Permute each ONNX (NCHW-ordered) array into Barracuda order, filling missing axes with the given default.
net.StridedSlice(node.Name, node.Input0,
starts:ONNXLayout.PermuteToBarracuda(onnxStarts, onnxLayout:"NCHW",0),
ends:ONNXLayout.PermuteToBarracuda(onnxEnds, onnxLayout:"NCHW",int.MaxValue),
strides:ONNXLayout.PermuteToBarracuda(onnxSteps, onnxLayout:"NCHW",1));
}
});
// Tile: requires constant "repeats"; the runtime layer only receives four of the converted repeat entries.
Add("Tile", (net, node) =>
{
// TODO: Implement Tile in ONNXTensor for const
var onnxRepeats = node.Input1Constant(onnxLayout: "C", name: "repeats").AsInts();
var repeats = ONNXLayout.ConvertSymbolicShapeToBarracuda(onnxRepeats, onnxLayout: "NCHW");
var features = node.Input0Features;
// NOTE(review): features are scaled by repeats[1], while the last repeat forwarded to net.Tile below is
// repeats[7]; confirm which index holds the channel repeat in ConvertSymbolicShapeToBarracuda's output.
features *= repeats[1];
Output(node.Name, rank: node.Input0Rank, features: features);
// only 4D Tile support for now
net.Tile(node.Name, node.Input0, new[] { repeats[2], repeats[5], repeats[6], repeats[7] });
});
// Gather: folds to a constant when both data and indices are constant, except for Gather-on-Shape of a dimension
// that is unknown at import time, which is rewritten into a runtime Shape layer. Otherwise emits a Gather layer.
Add("Gather", (net, node) =>
{
int axis = node.AxisOptional(0);
if (node.IsInput0Const && node.IsInput1Const)
{
var indices = node.Input1Constant(onnxLayout:"C", name:"indices").AsInts();
// If the previous node was a shape and we're gathering an inferred value, then don't treat the shape as a constant
// (the Shape importer stores 0 for dimensions it cannot know at import time, and records productOfShape).
if (node.Input0.IndexOf("shape", StringComparison.OrdinalIgnoreCase) >= 0
&& indices.Length == 1 && indices[0] > 0
&& constantTensors[node.Input0].ToBarracuda("C")[indices[0]] == 0 // Must resolve at runtime
&& variableTensors.TryGetValue(node.Input0, out VariableTensor input0Tensor)
&& variableTensors.TryGetValue(input0Tensor.productOfShape, out VariableTensor shapeInputTensor))
{
// Query the single requested dimension of the original tensor at runtime.
axis = ONNXLayout.ConvertAxisToBarracuda(indices[0], shapeInputTensor.rank, "NCHW");
net.Shape(node.Name, input0Tensor.productOfShape, axis);
D.Log($"Re-writing {node.Name} to a Shape+Axis layer (results in a scalar)");
}
else
{
ONNXTensor gatheredTensor = constantTensors[node.Input0].Gather(axis, indices);
Const(node, gatheredTensor);
}
}
else
{
int input1Rank = node.Input1Rank;
if (node.IsInput1Const)
{
// The original rank was cached above since our constant tensor requires a shape of rank 1 and original may have been a scalar
// Re-register the indices under the same name as a flat rank-1 constant for the runtime Gather layer to read.
var indices = node.Input1Constant(onnxLayout: "C", name: "indices").AsFloats();
var constTensor = new ONNXTensor(new Tensor(new [] { indices.Length, 1, 1, 1, 1, 1, 1, 1 }, indices), new [] { indices.Length });
Const(node.Input1, constTensor);
}
axis = ONNXLayout.ConvertAxisToBarracuda(axis, onnxRank:node.Input0Rank, onnxLayout:"NCHW");
net.Gather(node.Name, node.Input0, node.Input1, axis, true);
// ONNX Gather output rank: rank(indices) + rank(data) - 1.
Output(node.Name, rank: input1Rank + node.Input0Rank - 1);
}
});
// NonMaxSuppression: inputs 2..4 (max_output_boxes_per_class, iou_threshold, score_threshold) are passed as `object`:
// a float when read as a constant at import time, otherwise whatever InputNOptional returns (presumably the input name).
Add("NonMaxSuppression", (net, node) =>
{
int centerPointBox = node.GetOptionalInt("center_point_box", 0);
var boxes = node.GetRequiredInput(0);
var scores = node.GetRequiredInput(1);
object maxOutputBoxesPerClass = 0f;
object iouThreshold = 0f;
object scoreThreshold = 0f;
// NOTE(review): the first two clauses both imply the last one, so this condition reduces to
// (InputCount > 2 && IsInput2Const): the constant path is taken even if inputs 3/4 are runtime tensors.
// Confirm Input3ConstantOptional/Input4ConstantOptional handle that case.
if (node.InputCount > 4 && node.IsInput2Const && node.IsInput3Const && node.IsInput4Const
|| node.InputCount > 3 && node.IsInput2Const && node.IsInput3Const
|| node.InputCount > 2 && node.IsInput2Const)
{
// Use constant version (possibly with defaults)
maxOutputBoxesPerClass = node.Input2ConstantOptional((float)maxOutputBoxesPerClass, "C", nameof(maxOutputBoxesPerClass))[0];
iouThreshold = node.Input3ConstantOptional((float)iouThreshold, "C", nameof(iouThreshold))[0];
scoreThreshold = node.Input4ConstantOptional((float)scoreThreshold, "C", nameof(scoreThreshold))[0];
}
else
{
// Use dynamic tensor version
maxOutputBoxesPerClass = node.Input2Optional;
iouThreshold = node.Input3Optional;
scoreThreshold = node.Input4Optional;
}
net.NonMaxSuppression(node.Name, boxes, scores, maxOutputBoxesPerClass, iouThreshold, scoreThreshold, centerPointBox);
Output(node, rank: 2);
});
Add("OneHot", (net, node) =>
{
    // Only the default axis (-1, i.e. a new trailing axis) is accepted.
    node.UnsupportedAttribute("axis", -1);
    // Fallback [off, on] = [0, 1] as a 1x2 (N=1, C=2) tensor, matching the "C" layout used to read "values".
    // A (N=2, C=0) tensor would have zero elements, making offon[1] unreadable.
    var defaultOffOn = new Tensor(1, 2, new float[] { 0, 1 });
    var depth = (int)node.Input1Constant(onnxLayout: "C", name: "depth")[0];
    var offon = node.Input2ConstantOptional(defaultOffOn, onnxLayout: "C", name: "values");
    net.OneHot(node.Name, node.Input0, depth, (int)offon[1], (int)offon[0]);
    Output(node, features: depth, rank: node.Input0Rank + 1);
});
Add("TopK", (net, node) =>
{
    int axis = node.AxisOptional(-1);
    axis = ONNXLayout.ConvertAxisToBarracuda(axis, onnxRank: node.Input0Rank, onnxLayout: "NCHW");
    // TopK-11 introduced these options
    bool largest = node.GetOptionalInt("largest", 1) == 1;
    // If sorted = false, then the output is undefined
    bool sorted = node.GetOptionalInt("sorted", 1) == 1;
    string kName;
    if (node.InputCount > 1) // TopK-10 introduced K as an input tensor
    {
        kName = node.Input1;
    }
    else
    {
        // TopK-1: K is an attribute, so materialize it as a constant tensor registered under its own,
        // node-specific name. Const(node, ...) (as used by the Constant importer) would instead define this
        // node's own output as that constant, and a fixed name would collide across several TopK-1 nodes.
        int k = node.GetRequiredInt("k");
        kName = $"{node.Name}_K";
        var kTensor = new ONNXTensor(
            data: new Tensor(new[] { 1, 1, 1, 1 }, new[] { (float)k }, kName),
            onnxShape: new[] { 1 });
        Const(kName, kTensor);
    }
    // Outputs[1] = indices, Outputs[0] = values gathered with those indices.
    Layer indices = net.TopKIndices(node.Outputs[1], node.Input0, kName, axis, largest, sorted);
    Output(node.Outputs[1], rank: node.Input0Rank);
    net.TopKValues(node.Outputs[0], node.Input0, indices, axis);
    Output(node.Outputs[0], rank: node.Input0Rank);
});
Add("NonZero", (net, node) =>
{
    if (!node.IsInput0Const)
    {
        // Runtime input: emit a NonZero layer with a rank-2 output.
        net.NonZero(node.Name, node.Input0);
        Output(node.Outputs[0], rank: 2);
        return;
    }
    // Constant input: evaluate NonZero at import time.
    Const(node, constantTensors[node.Input0].NonZero());
});
// LSTM
// - it = f(Xt*Wi + Ht_1*Ri + Wbi + Rbi)
// - ft = f(Xt*Wf + Ht_1*Rf + Wbf + Rbf)
// - ct = g(Xt*Wc + Ht_1*Rc + Wbc + Rbc), c means j in our formula
// - Ct = ft . Ct_ + it . ct
// - ot = f(Xt*Wo + Ht_1*Ro + Wbo + Rbo)
// - Ht = ot . h(Ct)
// The importer builds the graph for a single time step (there is no loop over the sequence); the previous h/c
// are read from Memory inputs that are wired back to this step's results at the end.
Add("LSTM", (net, node) =>
{
// W, R and B must all be constants.
var W = node.Input1Constant(onnxLayout: "RKC", name: "W");
var R = node.Input2Constant(onnxLayout: "RKC", name: "R");
var B = node.Input3Constant(onnxLayout: "RC", name: "B");
// gate order [iofj]
// W and R hold the four gates concatenated along channels: each is cut into quarters.
// B holds the four W-biases followed by the four R-biases: it is cut into eighths.
var ops = new ReferenceCPUOps();
var w_i = ops.StridedSlice(W, new[] {0,0,0,0}, new[] {W.batch,1,1,W.channels/4 }, new[] {1, 1, 1, 1});
var w_o = ops.StridedSlice(W, new[] {0,0,0,W.channels/4}, new[] {W.batch,1,1,2*W.channels/4 }, new[] {1, 1, 1, 1});
var w_f = ops.StridedSlice(W, new[] {0,0,0,2*W.channels/4}, new[] {W.batch,1,1,3*W.channels/4 }, new[] {1, 1, 1, 1});
var w_j = ops.StridedSlice(W, new[] {0,0,0,3*W.channels/4}, new[] {W.batch,1,1,4*W.channels/4 }, new[] {1, 1, 1, 1});
var r_i = ops.StridedSlice(R, new[] {0,0,0,0}, new[] {R.batch,1,1,R.channels/4 }, new[] {1, 1, 1, 1});
var r_o = ops.StridedSlice(R, new[] {0,0,0,R.channels/4}, new[] {R.batch,1,1,2*R.channels/4 }, new[] {1, 1, 1, 1});
var r_f = ops.StridedSlice(R, new[] {0,0,0,2*R.channels/4}, new[] {R.batch,1,1,3*R.channels/4 }, new[] {1, 1, 1, 1});
var r_j = ops.StridedSlice(R, new[] {0,0,0,3*R.channels/4}, new[] {R.batch,1,1,4*R.channels/4 }, new[] {1, 1, 1, 1});
var wb_i = ops.StridedSlice(B, new[] {0,0,0,0}, new[] {1,1,1,B.channels/8 }, new[] {1, 1, 1, 1});
var wb_o = ops.StridedSlice(B, new[] {0,0,0,B.channels/8}, new[] {1,1,1,2*B.channels/8 }, new[] {1, 1, 1, 1});
var wb_f = ops.StridedSlice(B, new[] {0,0,0,2*B.channels/8}, new[] {1,1,1,3*B.channels/8 }, new[] {1, 1, 1, 1});
var wb_j = ops.StridedSlice(B, new[] {0,0,0,3*B.channels/8}, new[] {1,1,1,4*B.channels/8 }, new[] {1, 1, 1, 1});
var rb_i = ops.StridedSlice(B, new[] {0,0,0,4*B.channels/8}, new[] {1,1,1,5*B.channels/8 }, new[] {1, 1, 1, 1});
var rb_o = ops.StridedSlice(B, new[] {0,0,0,5*B.channels/8}, new[] {1,1,1,6*B.channels/8 }, new[] {1, 1, 1, 1});
var rb_f = ops.StridedSlice(B, new[] {0,0,0,6*B.channels/8}, new[] {1,1,1,7*B.channels/8 }, new[] {1, 1, 1, 1});
var rb_j = ops.StridedSlice(B, new[] {0,0,0,7*B.channels/8}, new[] {1,1,1,8*B.channels/8 }, new[] {1, 1, 1, 1});
// Used as the channel size of the h/c Memory tensors (presumably the hidden size — TODO confirm).
var memSize = r_i.flatHeight;
var baseLSTMName = ResolveLstmInputName(node);
var initial_h = $"{baseLSTMName}_h";
var initial_c = $"{baseLSTMName}_c";
var baseLSTMOutputName = ResolveLstmOutputName(node);
var output_h = $"{baseLSTMOutputName}_h";
var output_c = $"{baseLSTMOutputName}_c";
// Per gate: Dense(x, W_gate, Wb_gate) + Dense(h_prev, R_gate, Rb_gate).
var i_mad_w = net.Dense($"{node.Name}_bc_i_mad_w", node.Input0, w_i, wb_i);
var i_mad_r = net.Dense($"{node.Name}_bc_i_mad_r", initial_h, r_i, rb_i);
var i_mad = net.Add($"{node.Name}_bc_i_mad", new [] {i_mad_w, i_mad_r});
var j_mad_w = net.Dense($"{node.Name}_bc_j_mad_w", node.Input0, w_j, wb_j);
var j_mad_r = net.Dense($"{node.Name}_bc_j_mad_r", initial_h, r_j, rb_j);
var j_mad = net.Add($"{node.Name}_bc_j_mad", new [] {j_mad_w, j_mad_r});
var f_mad_w = net.Dense($"{node.Name}_bc_f_mad_w", node.Input0, w_f, wb_f);
var f_mad_r = net.Dense($"{node.Name}_bc_f_mad_r", initial_h, r_f, rb_f);
var f_mad = net.Add($"{node.Name}_bc_f_mad", new [] {f_mad_w, f_mad_r});
var o_mad_w = net.Dense($"{node.Name}_bc_o_mad_w", node.Input0, w_o, wb_o);
var o_mad_r = net.Dense($"{node.Name}_bc_o_mad_r", initial_h, r_o, rb_o);
var o_mad = net.Add($"{node.Name}_bc_o_mad", new [] {o_mad_w, o_mad_r});
// Activations are fixed to sigmoid (f) and tanh (g, h).
var i = net.Sigmoid($"{node.Name}_bc_i_sigmoid", i_mad);
var j = net.Tanh($"{node.Name}_bc_j_tanh", j_mad);
var f = net.Sigmoid($"{node.Name}_bc_f_sigmoid", f_mad);
var o = net.Sigmoid($"{node.Name}_bc_o_sigmoid", o_mad);
// Ct = ft . C(t-1) + it . jt ; Ht = ot . tanh(Ct)
var state_c_mul = net.Mul($"{node.Name}_bc_state_c_mul", new[] {initial_c, f.name});
var i_j_mul = net.Mul($"{node.Name}_bc_i_j_mul", new[] {i, j});
var state_c = net.Add(output_c, new[] {state_c_mul, i_j_mul});
var state_c_tanh = net.Tanh($"{node.Name}_bc_state_c_tanh", state_c);
var state_h = net.Mul(output_h, new[] {o, state_c_tanh});
// ONNX outputs: Y and Y_h both alias the new hidden state, Y_c the new cell state.
net.Identity(node.Outputs[0], state_h);
net.Identity(node.Outputs[1], state_h);
net.Identity(node.Outputs[2], state_c);
// Wire the new states back to the initial_c / initial_h Memory inputs.
net.Memory(initial_c, state_c, new TensorShape(-1,1,1,memSize));
net.Memory(initial_h, state_h, new TensorShape(-1,1,1,memSize));
Output(node.Outputs[0], features:wb_o.channels, rank:2);
Output(node.Outputs[1], features:wb_o.channels, rank:2);
Output(node.Outputs[2], features:wb_o.channels, rank:2);
});
// Activation ops
Add("Relu", (net, node) => { net.Relu(node.Name, node.Input0); });
Add("Softmax", (net, node) =>
{
const int defaultAxis = 1;
int axis = node.AxisOptional(defaultAxis); // Leave in NCHW form and transpose instead
if (axis < 0)
axis = node.Input0Rank + axis;
string input = node.Input0;
string output = node.Name;
int rank = node.Input0Rank;
if(rank == 2)
{
axis = axis == 0 ? 0 : 3; // NC => N__C
}
else if (rank == 3)
{
axis = axis == 0 ? 0 : (axis == 1 ? 3 : axis); // NCW => N_WC
}
else
{
axis = axis == 0 ? 0 : (axis == 1 ? 3 : axis-1); // NCHW => NHWC
}
Layer layer = net.Softmax(output, input, axis);
});
Add("Tanh", (net, node) => { net.Tanh(node.Name, node.Input0); });
Add("Sqrt", (net, node) => { net.Sqrt(node.Name, node.Input0); });
Add("Sigmoid", (net, node) => { net.Sigmoid(node.Name, node.Input0); });
Add("Elu", (net, node) => { net.Elu(node.Name, node.Input0, node.AlphaOptional(1f)); });
Add("LeakyRelu",(net, node) => { net.LeakyRelu(node.Name, node.Input0, node.AlphaOptional(0.01f)); });
Add("Selu", (net, node) => { net.Selu(node.Name, node.Input0, node.AlphaOptional(1.67326f), node.GammaOptional(1.0507f)); });
Add("Swish", (net, node) => { net.Swish(node.Name, node.Input0); });
Add("PRelu", (net, node) => { net.PRelu(node.Name, node.Input0, node.Input1); });
Add("LogSoftmax", (net, node) => { net.LogSoftmax(node.Name, node.Input0); node.UnsupportedAttribute("axis", 1); });
// TODO: Add("Hardmax", (net, node) => { net.Hardmax(node.Name, node.Input0); node.UnsupportedAttribute("axis", 1); });
Add("Softplus", (net, node) => { net.Softplus(node.Name, node.Input0); });
// TODO: Add("Softsign", (net, node) => { net.Softsign(node.Name, node.Input0); });
// TODO: Add("HardSigmoid", (net, node) => { net.HardSigmoid(node.Name, node.Input0, node.AlphaOptional(0.2f), node.BetaOptional(0.5f)); });
Add("Exp", (net, node) => { net.Exp(node.Name, node.Input0); });
Add("Log", (net, node) => { net.Log(node.Name, node.Input0); });
Add("Reciprocal", (net, node) => { net.Reciprocal(node.Name, node.Input0); });
Add("Abs", (net, node) => { net.Abs(node.Name, node.Input0); });
Add("Neg", (net, node) => { net.Neg(node.Name, node.Input0); });
Add("Ceil", (net, node) => { net.Ceil(node.Name, node.Input0); });
Add("Floor", (net, node) => { net.Floor(node.Name, node.Input0); });
Add("Round", (net, node) => { net.Round(node.Name, node.Input0); });
Add("Clip", (net, node) => {
float minValue = float.MinValue;
float maxValue = float.MaxValue;
if (node.InputCount > 1) // Clip-11
{
minValue = node.Input1ConstantOptional(minValue, onnxLayout:"C", name:"min")[0];
maxValue = node.Input2ConstantOptional(maxValue, onnxLayout:"C", name:"max")[0];
}
else
{
minValue = node.MinOptional(minValue);
maxValue = node.MaxOptional(maxValue);
}
net.Clip(node.Name, node.Input0, minValue, maxValue);
});
Add("Acos", (net, node) => { net.Acos(node.Name, node.Input0); });
Add("Acosh", (net, node) => { net.Acosh(node.Name, node.Input0); });
Add("Asin", (net, node) => { net.Asin(node.Name, node.Input0); });
Add("Asinh", (net, node) => { net.Asinh(node.Name, node.Input0); });
Add("Atan", (net, node) => { net.Atan(node.Name, node.Input0); });
Add("Atanh", (net, node) => { net.Atanh(node.Name, node.Input0); });
Add("Cos", (net, node) => { net.Cos(node.Name, node.Input0); });
Add("Cosh", (net, node) => { net.Cosh(node.Name, node.Input0); });
Add("Sin", (net, node) => { net.Sin(node.Name, node.Input0); });
Add("Sinh", (net, node) => { net.Sinh(node.Name, node.Input0); });
Add("Tan", (net, node) => { net.Tan(node.Name, node.Input0); });
// Returns a copy of the node's input names in which a constant second operand is replaced by a
// freshly baked "Const_<name>" layer. For a rank-3 first input combined with a rank-1 constant,
// the constant is reshaped so that it broadcasts along the expected axis.
// NOTE(review): when "Const_<name>" is already registered, the original input name is returned
// unchanged rather than the previously baked constant — confirm this is intended for shared constants.
string[] GetCorrectedConstants(ONNXNodeWrapper node, ModelBuilder net)
{
    // Work on a copy so the node's own input array is never modified.
    var correctedInputs = new string[node.Inputs.Length];
    Array.Copy(node.Inputs, correctedInputs, correctedInputs.Length);

    if (!node.IsInput1Const)
        return correctedInputs;

    // Vectors are read as a single channel dimension; everything else as NCHW.
    string layout = node.Input1Rank == 1 ? "C" : "NCHW";
    string bakedName = $"Const_{node.Input1}";
    if (constantTensors.ContainsKey(bakedName))
        return correctedInputs;

    Tensor baked = node.Input1Constant(layout, node.Input1);
    bool vectorAgainstRank3 = node.Input0Rank == 3 && node.Input1Rank == 1;
    if (vectorAgainstRank3)
    {
        // 1,1,1,C -> 1,1,C,1
        baked = baked.Reshape(new int[] { 1, 1, baked.channels, 1 });
    }

    correctedInputs[1] = net.Const(bakedName, baked).name;
    Const(bakedName, new ONNXTensor(baked, baked.shape.ToArray()));
    return correctedInputs;
}
// Broadcast ops
Add("Add", (net, node) => { net.Add(node.Name, GetCorrectedConstants(node, net)); });
Add("Sum", (net, node) => { net.Add(node.Name, GetCorrectedConstants(node, net)); }); // Sum is implemented via Add
Add("Sub", (net, node) => { net.Sub(node.Name, GetCorrectedConstants(node, net)); });
Add("Mul", (net, node) => { net.Mul(node.Name, GetCorrectedConstants(node, net)); });
Add("Div", (net, node) => { net.Div(node.Name, GetCorrectedConstants(node, net)); });
Add("Pow", (net, node) => { net.Pow(node.Name, node.Inputs); });
Add("Min", (net, node) => { net.Min(node.Name, node.Inputs); });
Add("Max", (net, node) => { net.Max(node.Name, node.Inputs); });
Add("Mean", (net, node) => { net.Mean(node.Name, node.Inputs); });
// Logical ops
Add("Greater", (net, node) => { net.Greater(node.Name, node.Input0, node.Input1); });
Add("Less", (net, node) => { net.Less(node.Name, node.Input0, node.Input1); });
Add("LessOrEqual", (net, node) => { net.LessEqual(node.Name, node.Input0, node.Input1); });
Add("Equal", (net, node) => { net.Equal(node.Name, node.Input0, node.Input1); });
Add("Or", (net, node) => { net.LogicalOr(node.Name, node.Input0, node.Input1); });
Add("And", (net, node) => { net.LogicalAnd(node.Name, node.Input0, node.Input1); });
Add("Not", (net, node) => { net.LogicalNot(node.Name, node.Input0); });
Add("Xor", (net, node) => { net.LogicalXor(node.Name, node.Input0, node.Input1); });
Add("Where", (net, node) => { net.Where(node.Name, node.Input0, node.Input1, node.Input2); });
// Padding ops
Add("Pad", (net, node) =>
{
// TODO refactor pad handling to truncate only in NCHWToNHWCPass
var mode = node.ModeOptional("constant");
var pads = node.Pads;
switch (mode)
{
case "constant":
var value = node.GetOptionalFloat("value", 0.0f);
if (pads.Length > 4)
net.Border3D(node.Name, node.Input0, pads, value);
else
net.Border2D(node.Name, node.Input0, pads, value);
break;
case "reflect": net.Pad2DReflect(node.Name, node.Input0, pads); break;
case "edge": net.Pad2DEdge(node.Name, node.Input0, pads); break;
}
});
// Pooling ops
Add("AveragePool", (net, node) => {
node.UnsupportedAttribute("ceil_mode", 0);
node.UnsupportedAttribute("count_include_pad", 0);
net.AvgPool2D(node.Name, node.Input0, node.KernelShape, node.Strides, node.Pads);
});
Add("MaxPool", (net, node) => {
node.UnsupportedAttribute("ceil_mode", 0);
node.UnsupportedAttribute("dilations", new[] {1, 1});
node.UnsupportedAttribute("storage_order", 0);
int[] strides = node.Strides;
int[] pads = node.Pads;
if (strides.Length == 1)
strides = new[] { 1, strides[0] };
Assert.IsTrue(strides.Length == 2);
int[] kernenShape = node.KernelShape;
if (kernenShape.Length == 1)
kernenShape = new[] { kernenShape[0], 1 };
net.MaxPool2D(node.Name, node.Input0, kernenShape, strides, pads);
});
Add("GlobalAveragePool", (net, node) => { net.GlobalAvgPool2D(node.Name, node.Input0); });
Add("GlobalMaxPool", (net, node) => { net.GlobalMaxPool2D(node.Name, node.Input0); });
Add("Upsample", (net, node) => {
// @TODO: the same for Resize node
string mode = node.ModeOptional("nearest");
if (node.InputCount == 2 && !node.IsInput1Const)
if (node.Input0Rank <= 4)
net.Upsample2D(node.Name, node.Input0, node.Input1, IsModeBilinear(net, node, mode));
else
net.Upsample3D(node.Name, node.Input0, node.Input1, IsModeBilinear(net, node, mode));
else
Resample(net, node, node.Name, node.Input0, node.Scales, mode);
});
Add("Resize", (net, node) => {
if (node.InputCount > 2) // Resize-11
{
node.UnsupportedAttribute("coordinate_transformation_mode", "half_pixel");
node.UnsupportedAttribute("cubic_coeff_a", -0.75f);
node.UnsupportedAttribute("exclude_outside", 0);
node.UnsupportedAttribute("extrapolation_value", 0f);
node.UnsupportedAttribute("nearest_mode", "round_prefer_floor");
// Inputs (3 - 4)
// X : T1
// roi : T2, It only takes effect when coordinate_transformation_mode is "tf_crop_and_resize"
// scales : tensor(float)
// sizes (optional) : tensor(int64)
// TODO: cropping via roi input
// TODO: support sizes
}
if (node.InputCount > 3)
{
var mode = node.ModeOptional("nearest");
var bilinear = IsModeBilinear(net, node, mode);
net.Resample2D(node.Name, node.Input0, node.Sizes, bilinear);
}
else
{
Resample(net, node, node.Name, node.Input0, node.Scales, node.ModeOptional("nearest"));
}
});
Add("Transpose", (net, node) =>
{
    // From https://github.com/onnx/onnx/blob/master/docs/Operators.md#transpose
    // By default, reverse the dimensions, otherwise permute the axes according to the values given.
    if (node.IsInput0Const)
    {
        // Constant input: bake the permuted tensor at import time (single dictionary lookup).
        var inputTensor = constantTensors[node.Input0];
        int inputTensorRank = inputTensor.rank;
        var defaultPermutations = new int[inputTensorRank];
        for (int i = 0; i < inputTensorRank; ++i)
            defaultPermutations[i] = inputTensorRank - 1 - i;
        var permutations = node.GetOptionalIntArray("perm", defaultPermutations);
        var transposedTensor = inputTensor.Permute(permutations);
        Const(node, transposedTensor);
    }
    else
    {
        var defaultPermutations = new[] {5, 4, 3, 2, 1, 0};
        var permutations = node.GetOptionalIntArray("perm", defaultPermutations);
        // Report the rank (array length); interpolating the array itself printed "System.Int32[]".
        if (permutations.Length > 6)
            throw new OnnxLayerImportException($"Transpose support up to 6 dimensions but got a permutations of rank {permutations.Length}.");
        if (Enumerable.SequenceEqual(permutations, new[] { 0, 3, 1, 2 }) || // NHWC -> NCHW
            Enumerable.SequenceEqual(permutations, new[] { 0, 2, 3, 1 })) // NCHW -> NHWC
        {
            // @TODO: reorder upstream nodes and global input dimensions accordingly from NHWC -> NCHW
            net.Identity(node.Name, node.Input0);
            if (permutations[1] == 3) // NHWC -> NCHW
                Output(node, layout: VariableTensor.Layout.ChannelsFirst);
            else if (permutations[1] == 2) // NCHW -> NHWC
            {
                Output(node, layout: VariableTensor.Layout.ChannelsLast);
                layerRequiringUpstreamPatch.Add(node.Name);
            }
            else
                Assert.IsTrue("Reached unexpected branch" == "");
        }
        else if (Enumerable.SequenceEqual(permutations, new[] { 0, 2, 1 })) // NWC <-> NCW
        {
            // @TODO: reorder upstream nodes and global input dimensions accordingly from NHWC -> NCHW
            if (m_FixTf2OnnxExportIssues)
            {
                Warn(net, node, $"Use '--inputs-as-nchw' flag when exporting model from Tensorflow with tf2onnx");
                net.Identity(node.Name, node.Input0);
                // flip layout
                if (node.Input0Layout == VariableTensor.Layout.ChannelsLast)
                    Output(node, layout: VariableTensor.Layout.ChannelsFirst);
                else
                {
                    Output(node, layout: VariableTensor.Layout.ChannelsLast);
                    layerRequiringUpstreamPatch.Add(node.Name);
                }
            }
            else
            {
                int[] barracudaPermutation = { 0, 1, 3, 2 };
                net.Transpose(node.Name, node.Input0, barracudaPermutation);
            }
        }
        else if (Enumerable.SequenceEqual(permutations, new[] { 1, 0, 2 })) // batch <-> seq_length
        {
            // LSTM layout is problematic as it's usually flanked by a few Transposed if exported from TF
            // @TODO investigate if better solution
            net.Identity(node.Name, node.Input0);
        }
        else
        {
            //Here we assume `Channels` are represented by only one dimensions and it that it is the 2nd one.
            //however in some case (example: shufflenet, sub-pixel-cnn) reshape-transpose-reshape pattern lead
            //to channels being represented by two dimension this is handled in
            //FixReshapeTransposePatternWhenChannelsAreSplitIntoMultipleDimensions()
            //Expand received permutation to 6D adding padding between Channels and other feature dimensions.
            int numDimensionDimensionsThatWerePaddedAtCenterOfTranspose = 0;
            var permutationsNCTDHW = ONNXLayout.ExpandONNXPermutationToNCTDHW(permutations, out numDimensionDimensionsThatWerePaddedAtCenterOfTranspose);
            //From channel first to channel last.
            var permutationsNTDHWC = ONNXLayout.ConvertPermutationToLayout(permutationsNCTDHW, "NCTDHW", "NTDHWC");
            //6d to 8d
            int[] permuteSRNTDHWC = new int[TensorShape.MaxRank];
            permuteSRNTDHWC[0] = 0;
            permuteSRNTDHWC[1] = 1;
            for (int i = 0; i < 6; ++i)
                permuteSRNTDHWC[i+2] = 2+permutationsNTDHWC[i];
            var layer = net.Transpose(node.Name, node.Input0, permuteSRNTDHWC);
            layer.axis = numDimensionDimensionsThatWerePaddedAtCenterOfTranspose;
        }
    }
});
Add("DepthToSpace", (net, node) => {
net.DepthToSpace(node.Name, node.Input0, node.BlockSize, node.ModeOptional("DCR"));
});
Add("SpaceToDepth", (net, node) => {
net.SpaceToDepth(node.Name, node.Input0, node.BlockSize);
});
// Tensor ops
Add("Gemm", (net, node) => {
node.UnsupportedAttribute("alpha", 1.0f);
node.UnsupportedAttribute("beta", 1.0f);
node.UnsupportedAttribute("transA", 0);
var onnxLayout = node.TransBOptional() ? "KC" : "CK";
var weights = node.Input1Constant(onnxLayout, name:"B");
var biases = node.Input2ConstantOptional(Bias(weights.shape), 0.0f, onnxLayout:"C", name:"C");
// Change data layout from "channels first" to "channels last"
weights = SwapSpatialDimensionsAndFeaturesInMatMulWeights(weights, weights.flatHeight, node.Input0Layout);
net.Dense(node.Name, node.Input0, weights, biases);
Output(node, features:weights.channels, rank:2); // Gemm forces flatten of the input to rank 2
});
Add("MatMul", (net, node) => {
if (node.InputCount == 2 && !node.IsInput1Const || node.Input0Rank != 2 || node.Input1Rank != 2)
{
// if inputs are const, need to transpose them
if(node.IsInput1Const)
{
var Y = constantTensors[node.Input1].ToBarracuda("NCTDHW");
net.Const(node.Input1, Y);
}
net.MatMul(node.Name, node.Input0, node.Input1);
Output(node, features: node.Input0Features, rank: Math.Max(node.Input0Rank, node.Input1Rank));
}
else
{
var weights = node.Input1Constant(onnxLayout: "CK", name: "B");
var biases = node.DefaultTensor(Bias(weights.shape), 0.0f);
// Change data layout from "channels first" to "channels last"
weights = SwapSpatialDimensionsAndFeaturesInMatMulWeights(weights, node.Input0Features, node.Input0Layout);
net.Dense(node.Name, node.Input0, weights, biases);
Output(node, features: weights.channels, rank: 2); // MatMul forces flatten of the input to rank 2
}
});
Add("Conv", (net, node) => {
int[] dilationsDHW = new[] { 1, 1, 1 }; // @TODO trap on wrong values
int[] strides = node.Strides;
int[] pads = node.Pads;
node.IgnoredAttribute("kernel_shape", "Kernel shape is derived from K tensor weights instead");
var kernels = node.Input1Constant(onnxLayout: "KCHW", name: "W");
var kernelRank = node.Input1Rank;
if (kernelRank == 3) // Conv1D
{
dilationsDHW = node.DilatationsOptional(new[] { 1 }); // @TODO trap on wrong values
Assert.IsTrue(dilationsDHW.Length == 1);
dilationsDHW = new[] { 1, 1, dilationsDHW[0] };
if (strides.Length == 1)
strides = new[] { strides[0], 1 };
if (pads.Length == 2)
pads = new[] { pads[0], 0, pads[1], 0 };
}
else if (kernelRank == 4) // Conv2D
{
dilationsDHW = node.DilatationsOptional(new[] { 1, 1 });
Assert.IsTrue(dilationsDHW.Length == 2);
dilationsDHW = new[] { 1, dilationsDHW[0], dilationsDHW[1] };
}
else if (kernelRank == 5) // Conv3D
{
//TODO specific error message for DepthwiseConv3D (or support it).
node.UnsupportedAttribute("group", 1);
dilationsDHW = node.DilatationsOptional(new[] { 1, 1, 1 });
Assert.IsTrue(dilationsDHW.Length == 3);
pads = node.Pads3D;
strides = node.Strides3D;
}
else
{
Warn(net, node, $"Unsuported Conv kernel rank. Conv1D/2D/3 assumes rank 3/4/5 respectively, but got {kernelRank}.");
}
Assert.IsTrue(dilationsDHW.Length == 3);
if (dilationsDHW[0] != 1 || dilationsDHW[1] != 1 || dilationsDHW[2] != 1)
kernels = DilateKernel(kernels, dilationsDHW); // @TODO inefficient method. Support dilatation in kernel code properly
var biases = node.Input2ConstantOptional(Bias(kernels.shape), 0.0f, onnxLayout: "C", name: "B");
if (node.GroupOptional() > 1)
net.DepthwiseConv2D(node.Name, node.Input0, strides, pads, kernels, biases);
else
{
if (kernelRank < 5)
net.Conv2D(node.Name, node.Input0, strides, pads, kernels, biases);
else
net.Conv3D(node.Name, node.Input0, strides, pads, kernels, biases);
}
Output(node, features: kernels.channels);
});
Add("ConvTranspose", (net, node) => {
node.UnsupportedAttribute("dilations", new[] {1, 1});
node.UnsupportedAttribute("group", 1);
node.UnsupportedAttribute("output_shape", new int[0]);
node.IgnoredAttribute("kernel_shape", "Kernel shape is derived from K tensor weights instead");
var kernels = node.Input1Constant(onnxLayout:"CKHW", name:"W");
var biases = node.Input2ConstantOptional(Bias(kernels.shape), 0.0f, onnxLayout:"C", name:"B");
net.Conv2DTrans(node.Name, node.Input0, node.Strides, node.Pads, node.OutputPadding, kernels, biases);
Output(node, features:kernels.channels);
});
Add("BatchNormalization", (net, node) => {
var variance = node.Input4Constant(onnxLayout:"C", name:"var");
var scale = node.Input1ConstantOptional(variance.shape, 1.0f, onnxLayout:"C", name:"scale");
var bias = node.Input2ConstantOptional(variance.shape, 0.0f, onnxLayout:"C", name:"B");
var mean = node.Input3ConstantOptional(variance.shape, 0.0f, onnxLayout:"C", name:"mean");
if (variance.length != scale.length || scale.length != bias.length || bias.length != mean.length)
Warn(net, node, $"Number of elements in all parameters for BatchNorm must be the same." +
$"Parameter shapes are: {scale.shape}, {bias.shape}, {mean.shape}, {variance.shape}");
if (variance.channels != node.Input0Features && node.Input0Features > 0)
Warn(net, node, $"Number of elements in BatchNorm must match features from the previous layer. Was expecting {node.Input0Features}, but got {variance.channels}.");
var fusedData = FuseBatchNormWeights(scale, bias, mean, variance, node.EpsilonOptional(), node.Input0Features);
net.ScaleBias(node.Name, node.Input0, fusedData.Item1, fusedData.Item2);
});
Add("ImageScaler", (net, node) =>
{
var attrBias = node.Bias;
var attrScale = node.ScaleOptional();
int maxElements = attrBias.Length;
Tensor scale = new Tensor(1, maxElements);
Tensor bias = new Tensor(1, maxElements);
for (int i = 0; i < maxElements; ++i)
{
scale[i] = attrScale;
bias[i] = attrBias[i];
}
net.ScaleBias(node.Name, node.Input0, scale, bias);
});
Add("InstanceNormalization", (net, node) => {
var scale = node.Input1Constant(onnxLayout:"C", name:"scale");
var bias = node.Input2ConstantOptional(scale.shape, 0.0f, onnxLayout:"C", name:"B");
if (scale.length != bias.length)
Warn(net, node, $"Number of elements in all parameters for InstanceNorm must be the same." +
$"Parameter shapes are: {scale.shape}, {bias.shape}");
if (scale.channels != node.Input0Features && node.Input0Features > 0)
{
Warn(net, node, $"Number of elements in InstanceNorm must match features from the previous layer. Was expecting {node.Input0Features}, but got {scale.channels}.");
var scaleArray = scale.ToReadOnlyArray();
Array.Resize(ref scaleArray, node.Input0Features);
var biasArray = bias.ToReadOnlyArray();
Array.Resize(ref biasArray, node.Input0Features);
scale = new Tensor(1, node.Input0Features, scaleArray);
bias = new Tensor(1, node.Input0Features, biasArray);
}
net.Normalization(node.Name, node.Input0, scale, bias, node.EpsilonOptional());
});
Add("LRN", (net, node) => {
float bias = node.GetOptionalFloat("bias", 1.0f);
int size = node.GetRequiredInt("size");
net.LRN(node.Name, node.Input0, node.AlphaOptional(0.0001f), node.BetaOptional(0.75f), bias, size);
});
// random ops
Add("RandomNormal", (net, node) => {
var shape = ONNXLayout.ConvertShapeToBarracuda(onnxShape:node.Shape, onnxLayout:"NCHW");
net.RandomNormal(node.Name, shape, node.MeanOptional(), node.ScaleOptional(), node.Seed);
Output(node, rank:node.Shape.Length);
});
Add("RandomNormalLike", (net, node) => {
net.RandomNormal(node.Name, node.Input0, node.MeanOptional(), node.ScaleOptional(), node.Seed);
});
Add("RandomUniform", (net, node) => {
float high = node.GetOptionalFloat("high", 1.0f);
float low = node.GetOptionalFloat("low", 0.0f);
var shape = ONNXLayout.ConvertShapeToBarracuda(onnxShape:node.Shape, onnxLayout:"NCHW");
net.RandomUniform(node.Name, shape, low, high, node.Seed);
Output(node, rank:node.Shape.Length);
});
Add("RandomUniformLike", (net, node) => {
float high = node.GetOptionalFloat("high", 1.0f);
float low = node.GetOptionalFloat("low", 0.0f);
net.RandomUniform(node.Name, node.Input0, low, high, node.Seed);
});
Add("Multinomial", (net, node) => {
int samples = node.GetOptionalInt("sample_size", 1);
net.Multinomial(node.Name, node.Input0, samples, node.Seed);
});
// Reduce ops
Add("ReduceMax", (net, node) => {
Reduce(net, node, Layer.Type.ReduceMax);
});
Add("ReduceMean", (net, node) => {
Reduce(net, node, Layer.Type.ReduceMean);
});
Add("ReduceMin", (net, node) => {
Reduce(net, node, Layer.Type.ReduceMin);
});
Add("ReduceProd", (net, node) => {
Reduce(net, node, Layer.Type.ReduceProd);
});
Add("ReduceSum", (net, node) => {
Reduce(net, node, Layer.Type.ReduceSum);
});
Add("ArgMax", (net, node) => {
node.UnsupportedAttribute("select_last_index");
Reduce(net, node, Layer.Type.ArgMax);
});
Add("ArgMin", (net, node) => {
node.UnsupportedAttribute("select_last_index");
Reduce(net, node, Layer.Type.ArgMin);
});
// Ignore, noop during inference
Add("Identity", (net, node) => { net.Identity(node.Name, node.Input0); });
Add("Cast", (net, node) => { net.Identity(node.Name, node.Input0); });
Add("Dropout", (net, node) => { net.Identity(node.Name, node.Input0); });
}
/// <summary>
/// Returns the base name for the recurrent output memories (<c>_h</c>/<c>_c</c>) of an LSTM node.
/// Uses the name recorded in <c>lstmOutputs</c> for this node when present; otherwise
/// <c>recurrent_out_&lt;node name&gt;</c>.
/// </summary>
/// <param name="node">LSTM node being imported.</param>
/// <returns>Base name without a trailing <c>:0</c>, <c>_h</c> or <c>_c</c> suffix.</returns>
private string ResolveLstmOutputName(ONNXNodeWrapper node)
{
    return lstmOutputs.TryGetValue(node.Name, out var actualName)
        ? StripLstmStateNameSuffixes(actualName)
        : $"recurrent_out_{node.Name}";
}

/// <summary>
/// Returns the base name for the recurrent input memories (<c>_h</c>/<c>_c</c>) of an LSTM node.
/// Uses the name recorded in <c>lstmInputs</c> for this node when present; otherwise
/// <c>recurrent_in_&lt;node name&gt;</c>.
/// </summary>
/// <param name="node">LSTM node being imported.</param>
/// <returns>Base name without a trailing <c>:0</c>, <c>_h</c> or <c>_c</c> suffix.</returns>
private string ResolveLstmInputName(ONNXNodeWrapper node)
{
    return lstmInputs.TryGetValue(node.Name, out var actualName)
        ? StripLstmStateNameSuffixes(actualName)
        : $"recurrent_in_{node.Name}";
}

// Removes a trailing TensorFlow ":0" port suffix, then a trailing "_h"/"_c" state suffix.
// Identifiers are compared ordinally: culture-sensitive matching is wrong for tensor names.
private static string StripLstmStateNameSuffixes(string name)
{
    const int suffixLength = 2; // ":0", "_h" and "_c" are all two characters long
    if (name.EndsWith(":0", StringComparison.Ordinal))
        name = name.Substring(0, name.Length - suffixLength);
    if (name.EndsWith("_h", StringComparison.Ordinal) || name.EndsWith("_c", StringComparison.Ordinal))
        name = name.Substring(0, name.Length - suffixLength);
    return name;
}
/// <summary>
/// Fuses training-time BatchNorm tensors into an equivalent per-channel scale and bias:
/// <c>scale = gamma / sqrt(variance + epsilon)</c> and
/// <c>bias = beta - gamma * mean / sqrt(variance + epsilon)</c>.
/// </summary>
/// <param name="gamma">Per-channel scale; must be 1D (channels == length).</param>
/// <param name="beta">Per-channel offset; same shape as <paramref name="gamma"/>.</param>
/// <param name="mean">Running mean; same shape as <paramref name="gamma"/>.</param>
/// <param name="variance">Running variance; same shape as <paramref name="gamma"/>.</param>
/// <param name="epsilon">Value added to the variance before the square root.</param>
/// <param name="maxElements">Upper bound on the number of channels produced; values &lt;= 0 or larger
/// than <paramref name="gamma"/>'s length mean "use every channel of gamma".</param>
/// <returns>Tuple of (scale, bias) tensors, each of shape (1, channelCount).</returns>
internal static Tuple<Tensor, Tensor> FuseBatchNormWeights(Tensor gamma, Tensor beta, Tensor mean, Tensor variance, float epsilon, int maxElements = -1)
{
    // Reference implementation: https://github.com/Tencent/ncnn/blob/master/src/layer/batchnorm.cpp
    Assert.IsTrue(gamma.channels == gamma.length); // parameters must be a 1D tensor
    Assert.IsTrue(gamma.shape == beta.shape);
    Assert.IsTrue(gamma.shape == mean.shape);
    Assert.IsTrue(gamma.shape == variance.shape);

    // Keep the caller's limit only when it is a valid, non-larger channel count.
    int channelCount = (maxElements > 0 && maxElements <= gamma.length) ? maxElements : gamma.length;

    var scale = new Tensor(1, channelCount);
    var bias = new Tensor(1, channelCount);
    for (int ch = 0; ch < channelCount; ++ch)
    {
        float stdDev = Mathf.Sqrt(variance[ch] + epsilon);
        scale[ch] = gamma[ch] / stdDev;
        bias[ch] = beta[ch] - gamma[ch] * mean[ch] / stdDev;
    }
    return Tuple.Create(scale, bias);
}
/// <summary>
/// Reorders the rows of a MatMul/Gemm weight tensor whose input was flattened from a
/// channels-first tensor, so that the rows follow channels-last order instead.
/// Weights are returned unchanged when the feature count is unknown (0), when it already equals
/// the flattened height, or when the input layout is channels-last.
/// </summary>
/// <param name="weights">Weight tensor; rows (flatHeight) are inputs, channels are outputs.</param>
/// <param name="featureCount">Number of input features (channels) before flattening; 0 means unknown.</param>
/// <param name="layout">Layout of the tensor feeding the MatMul/Gemm.</param>
/// <returns>Weights with spatial and feature row blocks swapped, in the original shape.</returns>
internal static Tensor SwapSpatialDimensionsAndFeaturesInMatMulWeights(Tensor weights, int featureCount, VariableTensor.Layout layout)
{
    // TODO move that in custom pass if need be
    // Wild card feature: after Reduce, runtime correct weights. TODO: remove when full dims are known
    if (featureCount == 0)
        return weights;

    Assert.IsTrue(featureCount <= weights.flatHeight);

    bool inputIsChannelsLast = layout == VariableTensor.Layout.ChannelsLast;
    bool noSpatialDimensions = featureCount == weights.flatHeight;
    if (noSpatialDimensions || inputIsChannelsLast)
        return weights;

    var originalShape = weights.shape;
    var spatialSize = originalShape.flatHeight / featureCount;
    Assert.IsTrue(originalShape.flatHeight % featureCount == 0);

    // reshape: __C____K -> __C__HWK
    var split = weights.Reshape(new TensorShape(featureCount, spatialSize, 1, originalShape.channels));
    // permute: __C__HWK -> __H__WCK
    var nhwcToSwap = new int[] {1, 0, 2, 3};
    var permutation = TensorExtensions.Get8DPermutationsForNHWCPermutationsAndShape(split.shape, nhwcToSwap);
    var permuted = ONNXTensor.Permute(split, permutation);
    // reshape: __H__WCK -> __C____K
    return permuted.Reshape(originalShape);
}
/// <summary>
/// Reverts the channels-first assumption made at import time for everything upstream of the given layers.
/// For each upstream Concat, Gather or TopKValues layer, the axis is swapped between 6 and
/// <c>TensorShape.C</c>, once per layer. The shape of every model input feeding those upstream layers
/// is permuted, once per input.
/// </summary>
/// <param name="model">Model to patch; its layers and inputs are modified in place.</param>
/// <param name="layerRequiringUpstreamPatch">Names of the layers whose upstream graph must be patched.</param>
/// <returns>The same <paramref name="model"/> instance.</returns>
internal static Model PatchFromIncorrectlyAssumedChannelsFirstToChannelsLastLayoutUpstream(Model model, List<string> layerRequiringUpstreamPatch)
{
    // NOTE: although original input had NHWC layout
    // (most probably exported from Tensorflow without '--inputs-as-nchw' flag)
    // earlier when parsing input and axis we made incorrect assumption that they were NCHW
    // now we need to revert that assumption!
    var alreadyPatchedInputs = new HashSet<int>();
    var alreadyPatchedAxisLayers = new HashSet<string>();

    var inputIndexByName = new Dictionary<string, int>();
    for (var idx = 0; idx < model.inputs.Count; ++idx)
        inputIndexByName.Add(model.inputs[idx].name, idx);

    foreach (var root in layerRequiringUpstreamPatch)
    {
        foreach (var upstreamLayer in ModelAnalyzer.FindUpstreamLayers(model, new[] {root}))
        {
            // patch axis; TODO handle ReduceXX and StridedSlice
            bool carriesAxis = upstreamLayer.type == Layer.Type.Concat ||
                               upstreamLayer.type == Layer.Type.Gather ||
                               upstreamLayer.type == Layer.Type.TopKValues;
            // HashSet.Add returns false when the layer was already patched.
            if (carriesAxis && alreadyPatchedAxisLayers.Add(upstreamLayer.name))
            {
                if (upstreamLayer.axis == 6) upstreamLayer.axis = TensorShape.C;
                else if (upstreamLayer.axis == TensorShape.C) upstreamLayer.axis = 6;
            }

            // patch inputs
            foreach (var inputName in upstreamLayer.inputs)
            {
                if (!inputIndexByName.TryGetValue(inputName, out int inputIndex))
                    continue;
                if (!alreadyPatchedInputs.Add(inputIndex))
                    continue;

                // example (NCHW): -1,2,2,16 -> (incorrect) -1,2,16,2 -> (fix) -1,2,2,16
                // example (NCW): -1,2,16 -> (incorrect) -1,1,16,2 -> (fix) -1,1,2,16
                var inputDesc = model.inputs[inputIndex];
                inputDesc.shape = ONNXLayout.Permute(inputDesc.shape, new[] {-1, -1, 2, -1, -1, 7, 5, 6});
                model.inputs[inputIndex] = inputDesc;
            }
            // @TODO: figure out, if there is any case where we would have to propagate fixed layout assumption downstream?
        }
    }
    return model;
}
// TODO: use Burst for this
/// <summary>
/// Builds a larger kernel in which (dilation - 1) zero taps are inserted between neighbouring taps of
/// <paramref name="kernel"/> along the spatial depth, height and width axes. A regular convolution with
/// the result therefore matches a dilated convolution with the original kernel.
/// </summary>
/// <param name="kernel">Kernel to expand. Its spatial depth/height/width, kernelDepth and kernelCount are read from <c>kernel.shape</c>.</param>
/// <param name="dilationsDHW">Dilation factors in depth, height, width order; exactly 3 values, each &gt; 0.</param>
/// <returns>A new tensor whose spatial extent along each axis is <c>k + (k - 1) * (dilation - 1)</c>.
/// kernelDepth and kernelCount are unchanged.</returns>
internal static Tensor DilateKernel(Tensor kernel, int[] dilationsDHW)
{
    //TODO: slow path in C# consider refactoring in Burst
    Assert.IsTrue(dilationsDHW.Length == 3);
    Assert.IsTrue(dilationsDHW[0] > 0);
    Assert.IsTrue(dilationsDHW[1] > 0);
    Assert.IsTrue(dilationsDHW[2] > 0);
    // https://arxiv.org/pdf/1603.07285.pdf
    // Each spatial extent k grows to k + (k - 1) * (dilation - 1).
    // NOTE(review): the 8 shape arguments appear to follow Barracuda's 8D order, with spatial depth,
    // height and width at positions 1, 2 and 5. This matches the 8D indexer used below — confirm against TensorShape.
    Tensor dilatedKernel = new Tensor(new TensorShape(1,
        kernel.shape.kernelSpatialDepth + (kernel.shape.kernelSpatialDepth - 1) * (dilationsDHW[0] - 1),
        kernel.shape.kernelHeight + (kernel.shape.kernelHeight - 1) * (dilationsDHW[1] - 1),
        1,
        1,
        kernel.shape.kernelWidth + (kernel.shape.kernelWidth - 1) * (dilationsDHW[2] - 1),
        kernel.shape.kernelDepth,
        kernel.shape.kernelCount));
    for (int c = 0; c < dilatedKernel.kernelDepth; ++c)
    for (int k = 0; k < dilatedKernel.kernelCount; ++k)
    {
        for (int d = 0; d < kernel.shape.kernelSpatialDepth; ++d)
        for (int y = 0; y < kernel.shape.kernelHeight; ++y)
        for (int x = 0; x < kernel.shape.kernelWidth; ++x)
        {
            // Position of original tap (d, y, x) inside the dilated kernel.
            int od = d * dilationsDHW[0];
            int oy = y * dilationsDHW[1];
            int ox = x * dilationsDHW[2];
            // Size of the cell block owned by this tap: the tap plus the gap cells after it.
            // The last tap on an axis has no trailing gap, so its block is 1 wide.
            int strideD = d == (kernel.shape.kernelSpatialDepth - 1) ? 1 : dilationsDHW[0];
            int strideY = y == (kernel.shape.kernelHeight - 1) ? 1 : dilationsDHW[1];
            int strideX = x == (kernel.shape.kernelWidth - 1) ? 1 : dilationsDHW[2];
            // Explicitly zero the tap's block, gap cells included.
            // NOTE(review): presumably defensive in case new Tensor storage is not zero-initialised — confirm.
            for (int dx = 0; dx < strideX; dx++)
            for (int dy = 0; dy < strideY; dy++)
            for (int dd = 0; dd < strideD; dd++)
            {
                dilatedKernel[ 0, od +dd, oy + dy, 0, 0, ox + dx, c, k] = 0.0f;
            }
            // Copy the original tap value to its dilated position (the block's first cell).
            float v = kernel[ 0, d, y, 0, 0, x, c, k];
            dilatedKernel[0, od, oy, 0, 0, ox, c, k] = v;
        }
    }
    return dilatedKernel;
}
/// <summary>
/// Shape of a bias vector matching the channel count of <paramref name="shape"/>: (1, 1, 1, channels).
/// </summary>
/// <param name="shape">Shape whose <c>channels</c> determines the bias length.</param>
/// <returns>A 4D shape with only the channel dimension non-unit.</returns>
internal static TensorShape Bias(TensorShape shape) => new TensorShape(1, 1, 1, shape.channels);
/// <summary>
/// Maps an ONNX resampling mode onto Barracuda's bilinear flag.
/// "linear" and "bilinear" map to bilinear, and "nearest" maps to nearest.
/// Any other mode emits a warning and falls back to nearest.
/// </summary>
/// <param name="net">Model builder, used for reporting warnings.</param>
/// <param name="node">Node being imported, used for reporting warnings.</param>
/// <param name="mode">ONNX <c>mode</c> attribute value.</param>
/// <returns><c>true</c> for bilinear interpolation, <c>false</c> otherwise.</returns>
internal static bool IsModeBilinear(ModelBuilder net, ONNXNodeWrapper node, string mode)
{
    switch (mode)
    {
        case "linear":
        case "bilinear":
            return true;
        case "nearest":
            return false;
        default:
            Warn(net, node, $"Mode `{mode}` is not supported for type {node.OperatorType}.");
            return false;
    }
}
/// <summary>
/// Imports an Upsample/Resize node while still in intermediate NCHW form.
/// Scales supplied by a non-constant input become a dynamic Upsample2D/Upsample3D layer
/// (3D when the input rank exceeds 4). Otherwise the constant scales are lowered via
/// <see cref="UpsampleFromConstNCHW"/>.
/// </summary>
/// <param name="net">Model builder receiving the new layer.</param>
/// <param name="node">Upsample/Resize node being imported.</param>
/// <param name="scaleInputIndex">Index of the input holding the scales; 0 means "no scales input".</param>
/// <returns>The created layer.</returns>
internal static Layer UpsampleNCHW(ModelBuilder net, ONNXNodeWrapper node, int scaleInputIndex)
{
    string mode = node.ModeOptional("nearest");

    bool hasDynamicScales = scaleInputIndex != 0 && node.InputCount > scaleInputIndex && !node.IsInputConst(scaleInputIndex);
    if (!hasDynamicScales)
        // UpsampleFromConstNCHW resolves (and warns about) the mode itself. Resolving it here
        // as well reported an unsupported mode twice.
        return UpsampleFromConstNCHW(net, node, node.Name, node.Input0, node.ConvertScales(), mode);

    // NOTE: Intermediate NCHW -- op is implemented expecting NHWC by default, so this is non-runnable as-is
    // TODO: Input1 may be rank 1, which means that this would require a swizzle in the actual data
    var bilinear = IsModeBilinear(net, node, mode);
    var scalesInput = node.GetRequiredInput(scaleInputIndex);
    if (node.Input0Rank <= 4)
        return net.Upsample2D(node.Name, node.Input0, scalesInput, bilinear);
    return net.Upsample3D(node.Name, node.Input0, scalesInput, bilinear);
}
/// <summary>
/// Lowers an Upsample/Resize with constant, NCHW-ordered scales to a Barracuda layer.
/// Bilinear downsampling of H and W only (4 scales, N and C equal to 1) becomes an average pool.
/// Every other case becomes Upsample3D (5 scales) or Upsample2D with the scales rounded to integers.
/// </summary>
/// <param name="net">Model builder receiving the new layer.</param>
/// <param name="node">Node being imported, used for reporting warnings.</param>
/// <param name="name">Name of the layer to create.</param>
/// <param name="input">Input of the layer to create.</param>
/// <param name="scales">Per-axis scale factors in NCHW (or NCDHW) order; expected to be positive.</param>
/// <param name="mode">ONNX interpolation mode.</param>
/// <returns>The created layer.</returns>
internal static Layer UpsampleFromConstNCHW(ModelBuilder net, ONNXNodeWrapper node, string name, object input, float[] scales, string mode)
{
    if (!scales.All(x => x > 0.0f))
        Warn(net, node, $"Only positive scale values are supported.");

    // Resolve the mode once. IsModeBilinear warns on unsupported modes, and evaluating it
    // in both the condition and the return statement reported the same warning twice.
    bool bilinear = IsModeBilinear(net, node, mode);

    if (bilinear &&
        scales.Length == 4 &&
        scales[0] == 1.0f &&
        scales[1] == 1.0f &&
        scales[2] < 1.0f &&
        scales[3] < 1.0f)
    {
        // Spatial-only bilinear downsampling: average pool with kernel = stride = rounded 1/scale.
        // Materialised once because it is enumerated twice below.
        var scales2D = scales.Skip(2).ToArray();
        if (!scales2D.All(x => Mathf.Approximately(1f / x, Mathf.Round(1f / x))))
            Warn(net, node, $"Only inverse of scale values which produce integer are currently supported. Inverse of scale value will be rounded to closest integer.");
        var noPad = new[] { 0, 0, 0, 0 };
        var inverseScalesRoundedToInt = scales2D.Select(x => (int)Mathf.Round(1f / x)).ToArray();
        return net.AvgPool2D(name, input, inverseScalesRoundedToInt, inverseScalesRoundedToInt, noPad);
    }

    if (!scales.All(x => Mathf.Approximately(x, Mathf.Round(x))))
        Warn(net, node, $"Only integer scale values are currently supported. Scale value will be rounded to closest integer value.");
    var scalesRoundedToInt = scales.Select(x => (int)Mathf.Round(x)).ToArray();
    if (scales.Length > 5)
        Warn(net, node, ">3D upsampling are not supported yet!");
    if (scales.Length == 5)
        return net.Upsample3D(name, input, scalesRoundedToInt, bilinear);
    return net.Upsample2D(name, input, scalesRoundedToInt, bilinear);
}
/// <summary>
/// Lowers a resampling with constant scales to a Barracuda layer.
/// When every scale is below 1, it becomes an average pool whose kernel and stride are the rounded inverse scales.
/// Otherwise it becomes Upsample3D (more than 2 scales) or Upsample2D with the scales rounded to integers.
/// </summary>
/// <param name="net">Model builder receiving the new layer.</param>
/// <param name="node">Node being imported, used for reporting warnings.</param>
/// <param name="name">Name of the layer to create.</param>
/// <param name="input">Input of the layer to create.</param>
/// <param name="scales">Per-axis scale factors; expected to be positive.</param>
/// <param name="mode">ONNX interpolation mode (used only for upsampling).</param>
/// <returns>The created layer.</returns>
internal static Layer Resample(ModelBuilder net, ONNXNodeWrapper node, string name, object input, float[] scales, string mode)
{
    if (!scales.All(s => s > 0.0f))
        Warn(net, node, $"Only positive scale values are supported.");

    bool isDownsampling = scales.All(s => s < 1.0f);
    if (isDownsampling)
    {
        if (!scales.All(s => Mathf.Approximately(1f/s, Mathf.Round(1f/s))))
            Warn(net, node, $"Only inverse of scale values which produce integer are currently supported. Inverse of scale value will be rounded to closest integer.");
        var zeroPadding = new[] {0, 0, 0, 0};
        var poolSize = scales.Select(s => (int)Mathf.Round(1f/s)).ToArray();
        // @TODO: nearest, actually this is bilinear downsampling
        if (scales.Length > 2)
            Warn(net, node, ">2D downsampling are not supported yet!");
        return net.AvgPool2D(name, input, poolSize, poolSize, zeroPadding);
    }

    if (!scales.All(s => Mathf.Approximately(s, Mathf.Round(s))))
        Warn(net, node, $"Only integer scale values are currently supported. Scale value will be rounded to closest integer value.");
    var upsampleFactors = scales.Select(s => (int)Mathf.Round(s)).ToArray();
    if (scales.Length > 3)
        Warn(net, node, ">3D upsampling are not supported yet!");
    bool bilinear = IsModeBilinear(net, node, mode);
    return scales.Length > 2
        ? net.Upsample3D(name, input, upsampleFactors, bilinear)
        : net.Upsample2D(name, input, upsampleFactors, bilinear);
}
/// <summary>
/// Computes the NHWC transpose that turns the output of a Barracuda Reduce (which keeps every dimension)
/// into the layout Barracuda assigns to the lower-rank tensor an ONNX Reduce produces with keepdims=0.
/// </summary>
/// <param name="droppedONNXAxis">ONNX axes removed by the reduction; negative values are offset by <paramref name="rank"/>.</param>
/// <param name="rank">ONNX rank of the reduce input; only 1 to 4 is supported.</param>
/// <returns>A 4-element permutation whose entry i is the NHWC index placed at position i.</returns>
/// <exception cref="OnnxLayerImportException">Thrown when <paramref name="rank"/> is not in 1..4.</exception>
private static int[] GetPermutationToMatchReduceWithDroppedDimensionsFromONNX(int[] droppedONNXAxis, int rank)
{
Assert.IsTrue(droppedONNXAxis.Length>0);
//Barracuda always have all dimensions, however in ONNX it is not the case one can drop dimensions,
//Here we handle the case of ReduceXXX ops when they do so.
//An example:
//ONNX -> NCHW
//Reduce on C with keepDims=False.
//ONNX -> NHW
//However ONNX tensor semantic are deducted by position to be mapped to Barracuda in the following way:
//ONNX 1D -> N -> Barracuda N,1,1,1
//ONNX 2D -> NC -> Barracuda N,1,1,C
//ONNX 3D -> NCW -> Barracuda N,1,W,C
//ONNX 4D -> NCHW -> Barracuda N,H,W,C
//Thus the output tensor above (NHW) will be mapped to N,1,W,C in Barracuda
//while Reduce in Barracuda would rather output N,H,W,1 if keepDim would be true.
//Here we find the transpose needed in Barracuda to match the ONNX behavior as seen by Barracuda.
//ie the transpose from N,H,W,1 to N,1,W,C in this case aka 0,3,2,1.
//ONNX input Layout from rank
string onnxLayout;
switch (rank)
{
case 1: onnxLayout = "N";
break;
case 2: onnxLayout = "NC";
break;
case 3: onnxLayout = "NCW";
break;
case 4: onnxLayout = "NCHW";
break;
default:
//TODO support 8D
throw new OnnxLayerImportException($"Reduce ops support up to 4D at the moment, however received an input of rank {rank}.");
}
//ONNX Layout once dimensions are dropped (example: NHW if C was dropped)
string onnxLayoutDimensionsDropped = onnxLayout;
foreach (var axis in droppedONNXAxis)
{
// Normalize a negative axis against the input rank, then remove that axis' semantic letter.
// NOTE(review): an axis outside [-rank, rank) indexes outside onnxLayout and throws IndexOutOfRangeException.
var onnxAxis = axis;
if (onnxAxis < 0)
onnxAxis = rank + axis;
string semanticToRemove = onnxLayout[onnxAxis].ToString();
onnxLayoutDimensionsDropped = onnxLayoutDimensionsDropped.Replace(semanticToRemove, string.Empty);
}
Assert.IsTrue(onnxLayoutDimensionsDropped.Length>0);
//Find all missing dimensions that will be unitary in Barracuda
var missingDimensions = new List<char>();
foreach (var dim in "NHWC")
{
if (!onnxLayoutDimensionsDropped.Contains(dim))
missingDimensions.Add(dim);
}
//Find semantic of onnx layout with dropped dimension in Barracuda
// Only remaining lengths 1..3 are handled: rank <= 4 and at least one letter is removed above.
// Any unfilled entry would stay '\0' and map to -1 in the permutation below.
var barracudaSemanticLayoutFromONNXReduce = new char[4];
switch (onnxLayoutDimensionsDropped.Length)
{
case 1:
//ONNX 1D -> N -> Barracuda N,1,1,1
barracudaSemanticLayoutFromONNXReduce[0] = onnxLayoutDimensionsDropped[0];
barracudaSemanticLayoutFromONNXReduce[1] = missingDimensions[0];
barracudaSemanticLayoutFromONNXReduce[2] = missingDimensions[1];
barracudaSemanticLayoutFromONNXReduce[3] = missingDimensions[2];
break;
case 2:
//ONNX 2D -> NC -> Barracuda N,1,1,C
barracudaSemanticLayoutFromONNXReduce[0] = onnxLayoutDimensionsDropped[0];
barracudaSemanticLayoutFromONNXReduce[1] = missingDimensions[0];
barracudaSemanticLayoutFromONNXReduce[2] = missingDimensions[1];
barracudaSemanticLayoutFromONNXReduce[3] = onnxLayoutDimensionsDropped[1];
break;
case 3:
//3D -> NCW -> Barracuda N,1,W,C
barracudaSemanticLayoutFromONNXReduce[0] = onnxLayoutDimensionsDropped[0];
barracudaSemanticLayoutFromONNXReduce[1] = missingDimensions[0];
barracudaSemanticLayoutFromONNXReduce[2] = onnxLayoutDimensionsDropped[2];
barracudaSemanticLayoutFromONNXReduce[3] = onnxLayoutDimensionsDropped[1];
break;
}
//Find permutation from NHWC Barracuda layout when mapped from ONNX with dropped dimensions.
// Entry i is the NHWC index of the semantic letter that must end up at position i.
var permutation = new int[4];
for(int idTarget = 0; idTarget<permutation.Length; ++idTarget)
{
char semantic = barracudaSemanticLayoutFromONNXReduce[idTarget];
permutation[idTarget] = "NHWC".IndexOf(semantic);;
}
return permutation;
}
/// <summary>
/// Standard (NCHW) import of an ONNX Reduce* node.
/// </summary>
/// <remarks>
/// Emits one Barracuda Reduce layer per axis, chained one after the other, then an Identity layer named
/// after the node. Axes come from the "axes" attribute, or else from the "axis" attribute (default 0).
/// When a second input is present, that constant input supplies the axes instead.
/// With keepdims == 0 every Reduce removes a dimension, so later axis indices must account for the
/// dimensions already removed. A non-negative axis moves down by the number of earlier removed axes in
/// front of it. A negative axis moves up by the number of earlier removed axes behind it.
/// Negative axes are compared against non-negative ones using the input rank when it is known (greater
/// than 0). An unknown rank is assumed to be reported as a non-positive value — TODO confirm.
/// </remarks>
/// <param name="net">Builder receiving the layers.</param>
/// <param name="node">ONNX node being imported.</param>
/// <param name="reduceType">Barracuda reduce layer type to emit.</param>
internal void ReduceNCHW(ModelBuilder net, ONNXNodeWrapper node, Layer.Type reduceType)
{
    var keepdims = node.GetOptionalInt("keepdims", 1);
    var features = node.Input0Features;
    var rank = node.Input0Rank;
    object input = node.Input0;
    var axes = node.HasAttribute("axes") ? node.AxesOptional(new[] { 0 }) : new[] {node.AxisOptional(0)};
    if (node.InputCount >= 2)
        axes = node.Input1Constant(onnxLayout: "ONNX", name: "axes").AsInts();

    // Rank before any reduction, used to bring negative axes into the same index space as positive ones.
    var inputRank = rank;
    // Normalized indices (in the original input) of the axes already removed by previous Reduce layers.
    var removedAxes = new List<int>();

    foreach (var onnxAxis in axes)
    {
        //TODO support 8D inputs
        var normalizedAxis = (onnxAxis < 0 && inputRank > 0) ? onnxAxis + inputRank : onnxAxis;

        var axis = onnxAxis;
        if (keepdims == 0)
        {
            // Removing a dimension in front of a non-negative axis shifts it down by one.
            // Removing a dimension behind a negative axis shifts it up by one.
            if (onnxAxis >= 0)
                axis -= removedAxes.Count(removed => removed < normalizedAxis);
            else
                axis += removedAxes.Count(removed => removed > normalizedAxis);
        }

        var nameR = $"{node.Name}__axis{onnxAxis}";
        input = net.Reduce(reduceType, nameR, input, axis, true, keepdims);
        Output(nameR, features: features, rank: rank);

        // Without keepdims, every iteration reduces the rank of the running tensor
        if (keepdims == 0)
        {
            rank--;
            removedAxes.Add(normalizedAxis);
        }
    }

    net.Identity(node.Name, input);
}
/// <summary>
/// Legacy import of an ONNX Reduce* node, assuming an NCHW ONNX input.
/// </summary>
/// <remarks>
/// Emits one Barracuda Reduce layer per axis, chained one after the other. Each ONNX axis is converted to a
/// Barracuda axis, or to an 8D axis when input 0 is a channels-last 4D tensor.
/// When keepdims is not 1, the rank is above 1 and input 0 is not channels-last, a Transpose is appended.
/// It repacks the kept dimensions where Barracuda expects the lower-rank ONNX result.
/// The final tensor is exposed under the node name via an Identity layer.
/// </remarks>
/// <param name="net">Builder receiving the layers.</param>
/// <param name="node">ONNX node being imported.</param>
/// <param name="reduceType">Barracuda reduce layer type to emit.</param>
internal void Reduce(ModelBuilder net, ONNXNodeWrapper node, Layer.Type reduceType)
{
    var keepdims = node.GetOptionalInt("keepdims", 1);
    var features = node.Input0Features;
    var rank = node.Input0Rank;
    object current = node.Input0;
    var axes = node.HasAttribute("axes")
        ? node.AxesOptional(new[] { 0 })
        : new[] { node.AxisOptional(0) };

    for (var a = 0; a < axes.Length; a++)
    {
        var onnxAxis = axes[a];

        //TODO support 8D inputs
        var barracudaAxis = ONNXLayout.ConvertAxisToBarracuda(onnxAxis, onnxRank: rank, onnxLayout: "NCHW");
        bool channelsLast4D = node.Input0Layout == VariableTensor.Layout.ChannelsLast && node.Input0Rank == 4;
        if (channelsLast4D)
            barracudaAxis = TensorExtensions.Convert4DTo8DAxis(onnxAxis);

        var reduceName = $"{node.Name}__axis{barracudaAxis}";
        current = net.Reduce(reduceType, reduceName, current, barracudaAxis, true, keepdims);
        if (barracudaAxis == TensorShape.C)
            features = 1; // reducing over channels leaves a single feature
        Output(reduceName, features: features, rank: rank);
    }

    // keepdims removes dimensions in the context of onnx thus we need to repack/transpose to match behavior.
    if (keepdims != 1 && rank > 1 && node.Input0Layout != VariableTensor.Layout.ChannelsLast)
    {
        var transposeName = $"{node.Name}__transpose";
        current = net.Transpose(transposeName, current, GetPermutationToMatchReduceWithDroppedDimensionsFromONNX(axes, rank));
        rank -= axes.Length;
        //TODO: features count is wrong and should potentially be deduced from input + transpose
        Output(transposeName, features: 0, rank: rank);
    }

    net.Identity(node.Name, current);
    //TODO: features count is wrong and should potentially be deduced from input
    Output(node.Name, features: 0, rank: rank);
}
// Registry of the constant and variable tensors seen during conversion. It is filled through the Const/Output
// helpers and is also handed to every ONNXNodeWrapper created in ConvertOnnxModel.
private ONNXModelTensors m_ModelTensors = new ONNXModelTensors();
// ONNX operator type -> importer, invoked as m_NodeImporters[opType](modelBuilder, node) for each node.
// Populated elsewhere in this class (not visible here).
private readonly Dictionary<string, Action<ModelBuilder, ONNXNodeWrapper>> m_NodeImporters =
new Dictionary<string, Action<ModelBuilder, ONNXNodeWrapper>>();
// NOTE: It's questionable whether we should be doing this since the ONNX specification requires the graph to be
// topologically sorted, but at least one network encountered that was exported from keras2onnx v1.7.0 produced
// an incorrectly sorted graph. related example: https://github.com/onnx/keras-onnx/issues/184
/// <summary>
/// Appends the nodes of <paramref name="onnxModel"/> to <paramref name="sortedGraph"/> so that every node
/// comes after the producers of its inputs.
/// </summary>
/// <remarks>
/// An input counts as available when it is empty (an omitted optional input), a graph input, an
/// initializer, or an output of a node already in <paramref name="sortedGraph"/>. Nodes whose inputs are not
/// yet available are retried after the next node is successfully placed.
/// Lookups use hash sets, so the cost is linear in the total number of node inputs and outputs instead of
/// rescanning the graph for every input.
/// </remarks>
/// <param name="onnxModel">Model whose graph nodes are sorted.</param>
/// <param name="sortedGraph">Destination list; sorted nodes are appended to it.</param>
/// <exception cref="OnnxImportException">Thrown when no remaining node can have all of its inputs satisfied.</exception>
void SortTopologically(ModelProto onnxModel, List<NodeProto> sortedGraph)
{
    GraphProto onnxGraph = onnxModel.Graph;

    // Every tensor name a node may consume without waiting on another node.
    var availableTensors = new HashSet<string>();
    foreach (ValueInfoProto graphInput in onnxGraph.Input)
        availableTensors.Add(graphInput.Name);
    foreach (TensorProto initializer in onnxGraph.Initializer)
        availableTensors.Add(initializer.Name);
    foreach (NodeProto sortedNode in sortedGraph)
        foreach (string output in sortedNode.Output)
            availableTensors.Add(output);

    // Set mirrors of sortedGraph / requeueNodes for O(1) membership tests. They use the same default equality
    // comparer as List.Contains / Queue.Contains, so duplicate detection behaves identically.
    var alreadySorted = new HashSet<NodeProto>(sortedGraph);
    var nodesToSort = new Queue<NodeProto>(onnxGraph.Node);
    var requeueNodes = new Queue<NodeProto>();
    var requeued = new HashSet<NodeProto>();

    while (nodesToSort.Count > 0)
    {
        NodeProto node = nodesToSort.Dequeue();
        bool allInputsExist = node.Input.All(input => string.IsNullOrEmpty(input) || availableTensors.Contains(input));

        if (!allInputsExist)
        {
            if (nodesToSort.Count != 0)
            {
                // Mark for re-processing again when (potentially) all inputs have been processed
                // We use a separate list, so we don't continually spin on nodes that are missing inputs
                if (requeued.Add(node))
                    requeueNodes.Enqueue(node);
                continue;
            }

            // Something must've gone wrong
            throw new OnnxImportException($"Missing inputs to node {node.Name}, but there are no nodes to process.");
        }

        if (alreadySorted.Add(node))
        {
            sortedGraph.Add(node);
            foreach (string output in node.Output)
                availableTensors.Add(output);
        }

        // Now that we have at least processed a single new node, let's requeue
        while (requeueNodes.Count > 0)
            nodesToSort.Enqueue(requeueNodes.Dequeue());
        requeued.Clear();
    }
}
/// <summary>
/// Converts a parsed ONNX model into a Barracuda <see cref="Model"/>.
/// </summary>
/// <remarks>
/// The steps are, in order:
/// optional metadata import; graph inputs/outputs; initializers registered as constants; then every node
/// converted through its registered importer.
/// Standard import sorts nodes topologically first. Legacy import keeps file order, skips nodes listed by
/// BuildNodeSkipList, and bakes nodes whose inputs are all constant (unless listed in m_ShouldNotBeBaked).
/// Nodes that fail to import, or have no importer, are reported through Err and replaced by an Identity
/// layer or a constant placeholder.
/// Constants still referenced afterwards (plus unreferenced ones, which may hold ML-Agents metadata) are
/// inserted at the start of the model.
/// When m_ImportMode is exactly ImportMode.Legacy the model is then optimized, patched and validated.
/// </remarks>
/// <param name="onnxModel">Parsed ONNX protobuf model.</param>
/// <returns>The converted model, including producer and IR version information.</returns>
/// <exception cref="OnnxImportException">
/// Raised via Err when a node cannot be imported and errors are not treated as warnings, or by
/// SortTopologically when the graph has unsatisfiable inputs.
/// </exception>
private Model ConvertOnnxModel(ModelProto onnxModel)
{
    var model = new Model();
    bool standardImport = m_ImportMode.HasFlag(ImportMode.Standard);
    model.layout = standardImport ? "iNCHW" : "NHWC";
    var modelBuilder = new ModelBuilder(model);

    // Builds list of nodes that should not be included into the final Barracuda Model, mostly for LSTMs
    var nodesToSkip = standardImport ? new HashSet<string>() : BuildNodeSkipList(onnxModel.Graph);

    // Import any (optional) metadata properties
    if (!m_ImportMode.HasFlag(ImportMode.SkipMetadataImport))
    {
        RepeatedField<StringStringEntryProto> metadataProps = onnxModel.MetadataProps;
        Dictionary<string, string> metadata = model.Metadata;
        for (int p = 0; p < metadataProps.Count; p++)
        {
            StringStringEntryProto prop = metadataProps[p];
            // NOTE(review): Dictionary.Add throws ArgumentException if a metadata key is repeated.
            metadata.Add(prop.Key, prop.Value);
        }
    }

    // Convert graph inputs & outputs
    var initializersByName = onnxModel.Graph.Initializer.ToDictionary(i => i.Name, i => true);
    foreach (ValueInfoProto i in onnxModel.Graph.Input)
    {
        // skip input tensors that have initializer data, they are constant tensors not global inputs
        // also skip nodes that should be trimmed
        if (initializersByName.ContainsKey(i.Name) || (!standardImport && nodesToSkip.Contains(i.Name)))
            continue;

        // Legacy import replaces the inputs listed in m_OverrideGlobalInputs by fixed constants
        if (!standardImport && m_OverrideGlobalInputs.ContainsKey(i.Name))
        {
            Const(i.Name, m_OverrideGlobalInputs[i.Name]);
            continue;
        }

        int[] onnxShape = i.Type.TensorType.Shape.AsInts();
        modelBuilder.Input(i.Name, ONNXLayout.ConvertSymbolicShapeToBarracuda(onnxShape, onnxLayout:standardImport ? "ONNX" : "NCHW"), onnxShape.Length);
        var shapeValues = i.Type.TensorType.Shape.Dim.Select(d => d.DimValue).ToArray();
        Output(i.Name, onnxShape: shapeValues, onnxLayout:"NCHW");
    }
    foreach (ValueInfoProto o in onnxModel.Graph.Output)
        modelBuilder.Output(o.Name);

    // Read constants from initializer list
    foreach (TensorProto initializer in onnxModel.Graph.Initializer)
        Const(initializer.Name, new ONNXTensor(initializer));

    // Nodes are supposed to be sorted, but this isn't always the case
    var sortedGraph = new List<NodeProto>();
    if (standardImport)
    {
        SortTopologically(onnxModel, sortedGraph);
    }
    else
    {
        // for the legacy import pipeline, let's keep it as it was
        sortedGraph.AddRange(onnxModel.Graph.Node);
    }

    // Convert graph nodes
    foreach (NodeProto onnxNode in sortedGraph)
    {
        if (!standardImport && nodesToSkip.Contains(ONNXNodeWrapper.GetName(onnxNode)))
            continue;

        var node = new ONNXNodeWrapper(onnxNode, m_ModelTensors, model.Warnings);
        var nodeId = node.Name;
        var opType = node.OperatorType;

        Output(node);

        bool injectDummy = false;
        if (m_NodeImporters.ContainsKey(opType))
        {
            try
            {
                if (!standardImport && node.AreAllInputsConst && !m_ShouldNotBeBaked.Contains(opType))
                {
                    Profiler.BeginSample($"Bake {opType} {node.Name}");
                    try
                    {
                        var bakedTensor = BakeNodeIntoConstant(opType, node);
                        Const(node.Name, bakedTensor);
                        var printTensor = bakedTensor.ToBarracuda("NCHW");
                        D.Log($"Baked node {nodeId} into constant of shape {printTensor.shape} and values: {printTensor.DataToString()}");
                    }
                    finally
                    {
                        // Keep BeginSample/EndSample balanced even when baking throws; the exception is
                        // still handled by the catch below.
                        Profiler.EndSample();
                    }
                }
                else
                {
                    Profiler.BeginSample($"Import {opType} {node.Name}");
                    try
                    {
                        m_NodeImporters[opType](modelBuilder, node);
                    }
                    finally
                    {
                        // Keep BeginSample/EndSample balanced even when the importer throws.
                        Profiler.EndSample();
                    }
                }
            }
            catch (Exception e)
            {
                // We support the layer but something went wrong while importing it
                // We log the problem and insert an identity layer
                string message = $"Unexpected error while parsing layer {nodeId} of type {opType}.";
                Err(model, nodeId, message,
                    extendedMessage:"Will replace it by an Identity layer.",
                    debugMessage:$"{e.Message}\n\nJson: {onnxNode}\n{e.StackTrace}\n");
                injectDummy = true;
            }
        }
        else
        {
            // We don't support this type of layer
            // We log the problem and insert an identity layer
            string message = $"Unknown type {opType} encountered while parsing layer {nodeId}.";
            Err(model, nodeId, message, extendedMessage:"Will replace it by an Identity layer.");
            injectDummy = true;
        }

        if (injectDummy)
        {
            var originalLayerHadInputs = (node.InputCount > 0);
            if (originalLayerHadInputs)
            {
                var originalLayerHadConstantInput = node.IsInput0Const;
                if (originalLayerHadConstantInput)
                    Const(nodeId, constantTensors[node.Input0]); // copy constant
                else
                    modelBuilder.Identity(nodeId, node.Input0);
            }
            else // if errorneous layer had no inputs, inject dummy constant which does not require any inputs
                modelBuilder.Const(nodeId, new Tensor());
        }

        m_ModelTensors.CompleteUninitializedFields(node);
    }

    // Convert constant tensors
    var requiredConstants = new HashSet<string>(ModelAnalyzer.FindBrokenLinks(model));
    // ML-Agents metadata is stored in otherwise unreferenced constants
    var unreferencedConstantsContainMLAgentsMetadata = UnreferencedNodes(onnxModel.Graph);
    requiredConstants.UnionWith(unreferencedConstantsContainMLAgentsMetadata); // keep ML-Agents metadata
    int insertionIndex = 0; // insert constants at the beginning of the model
    foreach(var entry in constantTensors)
    {
        if (requiredConstants.Contains(entry.Key)) // skip if constant is unused
        {
            modelBuilder.Const(entry.Key, entry.Value.ToBarracuda(standardImport ? "ONNX" :
                GetONNXLayoutForConstant(model, entry.Key)),
                insertionIndex++, rank: entry.Value.rank);
        }
    }

    if (m_ImportMode == ImportMode.Legacy)
    {
        foreach (Layer l in model.layers)
        {
            if (requiredConstants.Contains(l.name))
                l.flags |= Layer.Flags.Preserve;
        }

        model = ModelOptimizer.Optimize(model, allowFusing: m_OptimizeModel, keepLayers:requiredConstants); // keep ML-Agents metadata
        model = FixReshapeTransposePatternWhenChannelsAreSplitIntoMultipleDimensions(model);

        if (!m_FixTf2OnnxExportIssues)
            model = PatchFromIncorrectlyAssumedChannelsFirstToChannelsLastLayoutUpstream(model, layerRequiringUpstreamPatch);
    }

    // strip :0 at the end of string name for TF import
    if (m_FixTf2OnnxExportIssues)
        model = TrimTensorflowNames(model);

    if (m_ImportMode == ImportMode.Legacy)
        Validate(model);

    // Parse meta data
    var irVersion = onnxModel.IrVersion; // legacy
    if (onnxModel.OpsetImport?.Count > 0)
        irVersion = onnxModel.OpsetImport[0].Version;
    model.ProducerName = $"{onnxModel.ProducerName} v{onnxModel.ProducerVersion}";
    model.IrSource = "ONNX";
    model.IrVersion = $"{irVersion}";
    return model;
}
/// <summary>
/// Tells whether input slot <paramref name="index"/> of a layer of type <paramref name="opType"/> depends on
/// channel order.
/// </summary>
/// <remarks>
/// Slot 0 always does. For the element-wise/comparison/logical ops, Mean, Where and Concat, every slot does.
/// </remarks>
/// <param name="opType">Type of the consuming layer.</param>
/// <param name="index">Input slot of that layer.</param>
/// <returns>True when the input must follow the channel-order dependant layout.</returns>
private bool IsLayerInputChannelDependant(Layer.Type opType, int index)
{
    // First input is usually channel order dependant
    if (index == 0)
        return true;

    // however some operators have all inputs channel dependant
    switch (opType)
    {
        case Layer.Type.Add:
        case Layer.Type.Sub:
        case Layer.Type.Mul:
        case Layer.Type.Div:
        case Layer.Type.Pow:
        case Layer.Type.Min:
        case Layer.Type.Max:
        case Layer.Type.Mean:
        case Layer.Type.Greater:
        case Layer.Type.GreaterEqual:
        case Layer.Type.Less:
        case Layer.Type.LessEqual:
        case Layer.Type.Equal:
        case Layer.Type.LogicalOr:
        case Layer.Type.LogicalAnd:
        case Layer.Type.LogicalXor:
        case Layer.Type.Where:
        case Layer.Type.Concat:
            return true;
        default:
            return false;
    }
}
/// <summary>
/// Chooses the ONNX layout used to convert a constant, from how the layers of <paramref name="model"/>
/// consume it.
/// </summary>
/// <param name="model">Model whose layers are scanned.</param>
/// <param name="nodeName">Name of the constant.</param>
/// <returns>
/// "NCTDHW" when more consuming input slots are channel-order dependant than not; otherwise "CONST"
/// (including when nothing consumes the constant).
/// </returns>
/// <remarks>Mixed usage is reported through Err, which throws unless errors are treated as warnings.</remarks>
private string GetONNXLayoutForConstant(Model model, string nodeName)
{
    var channelOrderUses = 0;
    var plainUses = 0;

    // Visit every input slot of every layer that reads this constant.
    foreach (var layer in model.layers)
    {
        var layerInputs = layer.inputs;
        for (var slot = 0; slot < layerInputs.Length; ++slot)
        {
            if (layerInputs[slot] != nodeName)
                continue;

            if (IsLayerInputChannelDependant(layer.type, slot))
                ++channelOrderUses;
            else
                ++plainUses;
        }
    }

    if (channelOrderUses > 0 && plainUses > 0)
        Err(model, nodeName, $"{nodeName} is both used as channel order dependant constant and a plain constant, this is not supported at the moment.");

    return channelOrderUses > plainUses ? "NCTDHW" : "CONST";
}
/// <summary>
/// Legacy import: evaluates a node whose inputs are all constants and returns its output as an ONNX tensor.
/// </summary>
/// <remarks>
/// A throw-away model is built. Each input is added as a Const: input 0 (and every input of op types in
/// m_AllInputsChannelFirst) is converted from "NCTDHW", and other inputs from "CONST". The node itself is
/// added next, and the model is executed on a CPU worker.
/// </remarks>
/// <param name="opType">ONNX operator type of the node.</param>
/// <param name="node">Node to evaluate; all of its inputs must be constant.</param>
/// <returns>The baked output converted back to a channels-first layout, with unit dimensions squeezed.</returns>
private ONNXTensor BakeNodeIntoConstant(string opType, ONNXNodeWrapper node)
{
    var model = new Model();
    var net = new ModelBuilder(model);

    // add all inputs as constants
    Assert.IsTrue(node.AreAllInputsConst);
    for (var i = 0; i < node.InputCount; ++i)
    {
        var assumeOnnxLayout = (m_AllInputsChannelFirst.Contains(opType) || i == 0) ? "NCTDHW" : "CONST";
        var input = node.Inputs[i];
        net.Const(input,
            constantTensors[input].ToBarracuda(assumeOnnxLayout));
    }

    // add node that we are going to bake into the constant
    m_NodeImporters[opType](net, node);

    // bake
    var useCPUforBaking = WorkerFactory.Device.CPU;
    using (var worker = WorkerFactory.CreateWorker(model, useCPUforBaking))
    {
        var bakedConstant = worker.Execute().PeekOutput();

        // convert from Barracuda back into ONNX layout:
        // S,R,N,T,D,H,W,C (channels last) -> S,R,N,C,T,D,H,W (channels first),
        // assuming Permute puts source axis permutation[i] at output position i.
        Tensor onnxData = ONNXTensor.Permute(bakedConstant, new int[] {0,1,2,7,3,4,5,6});
        var onnxShape = onnxData.shape.ToArray();
        return new ONNXTensor(onnxData, onnxShape).SqueezeAll();
    }
}
/// <summary>
/// Final sanity check of a converted model: asserts there are no broken links, and records a warning listing
/// them if any remain.
/// </summary>
/// <param name="model">Model to check.</param>
private static void Validate(Model model)
{
    // Model should not contain any broken links in the end
    var brokenLinks = ModelAnalyzer.FindBrokenLinks(model);
    Assert.IsTrue(brokenLinks.Length == 0);
    if (brokenLinks.Length == 0)
        return;

    Warn(model, "", $"Broken links: {string.Join(", ", brokenLinks)}");
}
/// <summary>
/// Returns the names of graph nodes that no node consumes as an input and that are not graph outputs.
/// </summary>
/// <param name="graph">ONNX graph to scan.</param>
/// <returns>Set of unreferenced node names.</returns>
private HashSet<string> UnreferencedNodes(GraphProto graph)
{
    var consumedNames = new HashSet<string>(graph.Node.SelectMany(n => n.Input));
    var unreferenced = new HashSet<string>(graph.Node.Select(n => ONNXNodeWrapper.GetName(n)));

    // Global outputs are referenced by the graph itself
    foreach (ValueInfoProto graphOutput in graph.Output)
        unreferenced.Remove(graphOutput.Name);

    unreferenced.ExceptWith(consumedNames);
    return unreferenced;
}
/// <summary>
/// Walks backwards from <paramref name="startingNodes"/>, following only the first input of each node.
/// </summary>
/// <remarks>
/// <paramref name="regularNodeCallback"/> runs for every visited node. When a visited node's first input is
/// not produced by any node in <paramref name="nameToNode"/>, <paramref name="inputNodeCallback"/> also runs
/// for that node and the walk stops on that branch.
/// </remarks>
/// <param name="nameToNode">Node lookup by name.</param>
/// <param name="startingNodes">Nodes to start from.</param>
/// <param name="regularNodeCallback">Invoked for each visited node.</param>
/// <param name="inputNodeCallback">Invoked for visited nodes whose first input has no producer node.</param>
private void BacktraceNodeInputs(Dictionary<string, NodeProto> nameToNode,
    NodeProto[] startingNodes,
    Action<NodeProto> regularNodeCallback,
    Action<NodeProto> inputNodeCallback)
{
    var frontier = new HashSet<NodeProto>(startingNodes);
    while (frontier.Count != 0)
    {
        NodeProto current = frontier.First();
        regularNodeCallback(current);
        frontier.Remove(current);

        if (current.Input.Count == 0)
            continue;

        NodeProto producer;
        if (nameToNode.TryGetValue(current.Input[0], out producer))
            frontier.Add(producer); // regular node
        else
            inputNodeCallback(current);
    }
}
// TODO: Remove along with legacy importer in Barracuda 2.0
/// <summary>
/// Legacy import only: collects the names of nodes (and graph inputs) that feed the initial_h / initial_c
/// inputs of every LSTM node, so they can be left out of the converted model.
/// </summary>
/// <remarks>
/// Also records, per LSTM node, the graph input found at the end of that chain (lstmInputs) and the graph
/// output reached from its _h/_c outputs (lstmOutputs).
/// initial_h/initial_c (inputs 5 and 6) and the _h/_c outputs (outputs 1 and 2) are optional in ONNX, and
/// trailing optional entries may be omitted, so their presence is checked before indexing.
/// </remarks>
/// <param name="graph">ONNX graph to scan.</param>
/// <returns>Names to skip during conversion.</returns>
private HashSet<string> BuildNodeSkipList(GraphProto graph)
{
    var res = new HashSet<string>();
    var nameToNode = graph.Node.ToDictionary(i => ONNXNodeWrapper.GetName(i), i => i);
    var outputToLSTMNode = new Dictionary<string, string>();

    // Skip all LSTM _h & _c inputs as they will be accessible directly via Model.memories
    foreach (NodeProto onnxNode in graph.Node)
    {
        if (onnxNode.OpType != "LSTM")
            continue;

        var lstmNodeName = ONNXNodeWrapper.GetName(onnxNode);
        var initial_h = onnxNode.Input.Count > 5 ? onnxNode.Input[5] : string.Empty;
        var initial_c = onnxNode.Input.Count > 6 ? onnxNode.Input[6] : string.Empty;

        var startingNodes = new List<NodeProto>();
        NodeProto startingNode;
        if (nameToNode.TryGetValue(initial_h, out startingNode))
            startingNodes.Add(startingNode);
        if (nameToNode.TryGetValue(initial_c, out startingNode))
            startingNodes.Add(startingNode);

        BacktraceNodeInputs(
            nameToNode,
            startingNodes.ToArray(),
            el => { res.Add(ONNXNodeWrapper.GetName(el)); },
            el => { lstmInputs[lstmNodeName] = el.Input[0]; res.Add(el.Input[0]); }
        );

        if (onnxNode.Output.Count > 1)
            outputToLSTMNode[onnxNode.Output[1]] = lstmNodeName; // _h
        if (onnxNode.Output.Count > 2)
            outputToLSTMNode[onnxNode.Output[2]] = lstmNodeName; // _c
    }

    // Also trace from outputs to LSTM nodes to figure out names of the output _h and _c nodes
    foreach (var output in graph.Output)
    {
        NodeProto outputNode;
        if (!nameToNode.TryGetValue(output.Name, out outputNode))
            continue;

        // As LSTM has 3 outputs and backtracing is done only via output[0]
        // then output[1] and output[2] will be treated as leaf input nodes
        BacktraceNodeInputs(
            nameToNode,
            new[] { outputNode },
            el => { },
            el =>
            {
                string lstmNodeName;
                if (outputToLSTMNode.TryGetValue(el.Input[0], out lstmNodeName))
                    lstmOutputs[lstmNodeName] = output.Name;
            }
        );
    }

    return res;
}
/// <summary>
/// Reorders the characters of <paramref name="layout"/>: output position i takes
/// <c>layout[permutation[i]]</c>.
/// </summary>
/// <param name="layout">Layout string, one character per dimension.</param>
/// <param name="permutation">Source index for each output position; expected to match the layout length.</param>
/// <returns>The permuted layout.</returns>
private static string ApplyPermutationToLayout(string layout, int[] permutation)
{
    Assert.IsTrue(layout.Length == permutation.Length);
    var permuted = Enumerable.Range(0, layout.Length)
        .Select(position => layout[permutation[position]])
        .ToArray();
    return new string(permuted);
}
/// <summary>
/// Finds the permutation that turns <paramref name="layout"/> into <paramref name="permutedLayout"/>.
/// Entry i is the index in <paramref name="layout"/> of <c>permutedLayout[i]</c> (-1 when absent).
/// </summary>
/// <param name="layout">Source layout.</param>
/// <param name="permutedLayout">Target layout; expected to have the same length.</param>
/// <returns>The permutation, usable with ApplyPermutationToLayout.</returns>
private static int[] FindPermutationFromLayouts(string layout, string permutedLayout)
{
    Assert.IsTrue(layout.Length == permutedLayout.Length);
    return Enumerable.Range(0, layout.Length)
        .Select(position => layout.IndexOf(permutedLayout[position]))
        .ToArray();
}
/// <summary>
/// Legacy import fix-up: rewrites the permutation of each Transpose layer whose first input is a Reshape that
/// split the channels into several dimensions.
/// </summary>
/// <remarks>
/// The Reshape's <c>axis</c> field is read as the number of channel dimensions it produced. That value is set
/// by the Reshape importer, which is not visible here — TODO confirm.
/// Only values 2 and 3 are handled; values up to 1 are skipped. Permutations that are not of length 8 are
/// left unchanged and reported with a warning.
/// The new permutation is chosen heuristically for three known networks (super-resolution, yolov3,
/// shufflenet) and does not generalise.
/// </remarks>
/// <param name="model">Model whose Transpose layers are patched in place.</param>
/// <returns>The same <paramref name="model"/> instance.</returns>
static private Model FixReshapeTransposePatternWhenChannelsAreSplitIntoMultipleDimensions(Model model)
{
var transposes = model.layers.Where(l => l.type == Layer.Type.Transpose).ToList();
foreach (var transposeLayer in transposes)
{
// Only Transposes fed directly by a Reshape are candidates.
var previousLayer = model.layers.Find(l => l.name == transposeLayer.inputs[0]);
if (previousLayer == null)
continue;
if (previousLayer.type != Layer.Type.Reshape)
continue;
var numChannelDimensionBeforeTranspose = previousLayer.axis;
if (numChannelDimensionBeforeTranspose <= 1)
continue;
int centerPaddingThatWasAddedInPermutation = transposeLayer.axis;
Assert.IsTrue(centerPaddingThatWasAddedInPermutation <= 1);
Assert.IsTrue(centerPaddingThatWasAddedInPermutation >= 0);
//NOTE: See also ConvertReshapeToBarracuda() for mode detail on the problem.
//In some network like shufflenet, superresolution_cnn and yolov3 a reshape is used
//before a transpose to split the channels resulting in a tensor with
//multiple dimension used for channels, this is a problem when importing to
//barracuda as the semantic of the dimensions are changed and this change the
//way channel first to channel last conversion should happen. The code below
//is a limited to support for that.
Assert.IsTrue(numChannelDimensionBeforeTranspose == 2 || numChannelDimensionBeforeTranspose == 3);
var permutationSRNTDHWC = transposeLayer.pool;
if (permutationSRNTDHWC.Length != 8)
{
Warn(model, transposeLayer.name,
$"Expecting a permutation of rank 8 after Reshape '{previousLayer.name}' itself outputting more than one channel dimension. Permutation can't be patched to account for the extra channel dimensions.");
continue;
}
//Find layouts before transpose in both channel order
// Digits 1..3 label the channel dimensions; letters label every other dimension.
string layoutBeforeTranspose_ChannelFirst = (numChannelDimensionBeforeTranspose == 3) ? "SRN123HW" : "SRN1T2HW";
string layoutBeforeTranspose_ChannelLast = (numChannelDimensionBeforeTranspose == 3) ? "SRNHW123" : "SRNTHW12";
//Find layout after transpose in channel first
int[] permutation_ChannelFirst = ONNXLayout.ConvertPermutationToLayout(permutationSRNTDHWC, "SRNTDHWC","SRNCTDHW");
string layoutAfterTranspose_ChannelFirst = ApplyPermutationToLayout(layoutBeforeTranspose_ChannelFirst, permutation_ChannelFirst);
//Find layout after transpose in channel last
//TODO/HEURISTIC: We differentiate the various case by knowing if channels and features are interleaved during permutations.
//This is a work around to create the right permutation for the shufflenet/super-resolution and yolov3, it does not generalise well however.
//In next version of the importer we might need to introduce transposes in channel last mode to generalise fully.
int[] channelFirstToLastPermutation = null;
if (numChannelDimensionBeforeTranspose == 3)
{
//super resolution -> final reshape will pick only 1 dimension as channel -> regular channel first to last transposition.
channelFirstToLastPermutation = FindPermutationFromLayouts("SRN1TDHW", "SRNTDHW1");
}
else if (IsPermutationMixingChannelsAndOtherFeatures(layoutBeforeTranspose_ChannelFirst, permutation_ChannelFirst))
{
//yolov3 -> final reshape does not pick any dimension as channel -> no transposition.
channelFirstToLastPermutation = FindPermutationFromLayouts("SRNTUDHW", "SRNTUDHW");
}
else
{
//shufflenet -> final reshape take 2 dimension and merge them so both need to be affected by channel first to last transposition
channelFirstToLastPermutation = FindPermutationFromLayouts("SRN1T2HW", "SRNTHW12");
}
string layoutAfterTranspose_ChannelLast = ApplyPermutationToLayout(layoutAfterTranspose_ChannelFirst, channelFirstToLastPermutation);
//Finally compute and return permutation in channel last
int[] permutation_ChannelLast = FindPermutationFromLayouts(layoutBeforeTranspose_ChannelLast, layoutAfterTranspose_ChannelLast);
// The Transpose layer is patched in place.
transposeLayer.pool = permutation_ChannelLast;
}
return model;
}
/// <summary>
/// Tells whether <paramref name="permutation"/> moves a channel dimension into a non-channel position or the
/// reverse.
/// </summary>
/// <param name="layout">Layout where channel dimensions are written as digits and all others as letters.</param>
/// <param name="permutation">Source index for each position; expected to match the layout length.</param>
/// <returns>True when some position changes between channel and non-channel.</returns>
private static bool IsPermutationMixingChannelsAndOtherFeatures(string layout, int[] permutation)
{
    Assert.IsTrue(layout.Length == permutation.Length);
    return Enumerable.Range(0, permutation.Length)
        .Any(position => Char.IsNumber(layout[position]) != Char.IsNumber(layout[permutation[position]]));
}
/// <summary>
/// Applies TrimTensorflowName to every name stored in <paramref name="model"/>.
/// </summary>
/// <remarks>
/// Covers input names, outputs, memory input/output names, and each layer's name, dataset names, inputs and
/// (when present) outputs. The input, output, memory and layer lists are replaced by new lists; layer
/// objects are updated in place.
/// </remarks>
/// <param name="model">Model to update.</param>
/// <returns>The same <paramref name="model"/> instance.</returns>
private static Model TrimTensorflowNames(Model model)
{
    model.inputs = model.inputs.ConvertAll(input =>
    {
        input.name = TrimTensorflowName(input.name);
        return input;
    });

    model.outputs = model.outputs.ConvertAll(output => TrimTensorflowName(output));

    model.memories = model.memories.ConvertAll(memory =>
    {
        memory.input = TrimTensorflowName(memory.input);
        memory.output = TrimTensorflowName(memory.output);
        return memory;
    });

    model.layers = model.layers.ConvertAll(layer =>
    {
        layer.name = TrimTensorflowName(layer.name);

        for (var d = 0; d < layer.datasets.Length; ++d)
            layer.datasets[d].name = TrimTensorflowName(layer.datasets[d].name);

        for (var k = 0; k < layer.inputs.Length; ++k)
            layer.inputs[k] = TrimTensorflowName(layer.inputs[k]);

        if (layer.outputs != null)
        {
            for (var k = 0; k < layer.outputs.Length; ++k)
                layer.outputs[k] = TrimTensorflowName(layer.outputs[k]);
        }

        return layer;
    });

    return model;
}
/// <summary>
/// Strips the ":0" output-index suffix that TensorFlow exports append to tensor names.
/// </summary>
/// <param name="name">Tensor or layer name; may be null.</param>
/// <returns>
/// <paramref name="name"/> without a trailing ":0"; the name unchanged when it has no such suffix
/// (null stays null).
/// </returns>
private static string TrimTensorflowName(string name)
{
    // Ordinal comparison: this is an identifier suffix, not linguistic text, so the culture-sensitive default
    // overload of EndsWith is not appropriate here.
    if (name != null && name.EndsWith(":0", StringComparison.Ordinal))
        return name.Substring(0, name.Length - 2);
    return name;
}
// Helpers to keep track of model tensors

/// <summary>Registers <paramref name="onnxTensor"/> as the constant named after <paramref name="node"/>.</summary>
private void Const(ONNXNodeWrapper node, ONNXTensor onnxTensor) => Const(node.Name, onnxTensor);

/// <summary>Registers <paramref name="onnxTensor"/> as the constant called <paramref name="name"/>.</summary>
private void Const(string name, ONNXTensor onnxTensor) => m_ModelTensors.AddConstant(name, onnxTensor);
/// <summary>Registers the variable tensor produced by <paramref name="node"/> (keyed by the node name).</summary>
private void Output(ONNXNodeWrapper node, int features = -1, int rank = -1,
    VariableTensor.Layout layout = VariableTensor.Layout.Unknown)
    => Output(node.Name, features, rank, layout);

/// <summary>Registers a variable tensor from its feature count, rank and layout (-1 / Unknown when not known).</summary>
private void Output(string name, int features = -1, int rank = -1,
    VariableTensor.Layout layout = VariableTensor.Layout.Unknown)
    => m_ModelTensors.AddVariable(name, features, rank, layout);

/// <summary>Registers a variable tensor described by an ONNX tensor.</summary>
private void Output(string name, ONNXTensor onnxTensor)
    => m_ModelTensors.AddVariable(name, onnxTensor);

/// <summary>Registers a variable tensor from an ONNX shape expressed in <paramref name="onnxLayout"/>.</summary>
private void Output(string name, long[] onnxShape, string onnxLayout)
    => m_ModelTensors.AddVariable(name, onnxShape, onnxLayout);

/// <summary>Registers the variable tensor produced by <paramref name="node"/> from a feature count and a shape-product description.</summary>
private void Output(ONNXNodeWrapper node, int features, string productOfShape)
    => m_ModelTensors.AddVariable(node.Name, features, productOfShape);
// Logging helpers

/// <summary>Records an importer warning for <paramref name="node"/> on the builder's model and logs it.</summary>
private static void Warn(ModelBuilder builder, ONNXNodeWrapper node, string message)
    => Warn(builder.model, node.Name, message);

/// <summary>Appends an importer warning for <paramref name="layerName"/> to the model and logs it through Unity.</summary>
private static void Warn(Model model, string layerName, string message)
{
    var warning = new Model.ImporterWarning(layerName, message);
    model.Warnings.Add(warning);
    Debug.LogWarning(message);
}
/// <summary>
/// Reports an import error.
/// </summary>
/// <remarks>
/// When errors are not treated as warnings, throws <see cref="OnnxImportException"/> carrying the message
/// and debug message. Otherwise records a model warning (message plus extended message) and logs it together
/// with the debug message.
/// </remarks>
/// <exception cref="OnnxImportException">Thrown unless m_TreatErrorsAsWarnings is set.</exception>
private void Err(Model model, string layerName, string message, string extendedMessage = "", string debugMessage = "")
{
    if (!m_TreatErrorsAsWarnings)
        throw new OnnxImportException($"{message}\n{debugMessage}");

    model.Warnings.Add(new Model.ImporterWarning(layerName, $"{message} {extendedMessage}"));
    Debug.LogWarning($"{message} {extendedMessage}\n{debugMessage}");
}
}
/// <summary>
/// ONNX import exception
/// </summary>
public class OnnxImportException : Exception
{
    /// <summary>
    /// Create `OnnxImportException` with a default message
    /// </summary>
    public OnnxImportException() { }

    /// <summary>
    /// Create `OnnxImportException`
    /// </summary>
    /// <param name="message">message</param>
    public OnnxImportException(string message) : base(message) { }

    /// <summary>
    /// Create `OnnxImportException` that keeps the exception which caused the import failure
    /// </summary>
    /// <param name="message">message</param>
    /// <param name="innerException">exception that caused this one</param>
    public OnnxImportException(string message, Exception innerException) : base(message, innerException) { }
}
/// <summary>
/// ONNX layer import exception
/// </summary>
public class OnnxLayerImportException : Exception
{
    /// <summary>
    /// Create `OnnxLayerImportException` with a default message
    /// </summary>
    public OnnxLayerImportException() { }

    /// <summary>
    /// Create `OnnxLayerImportException`
    /// </summary>
    /// <param name="message">message</param>
    public OnnxLayerImportException(string message) : base(message) { }

    /// <summary>
    /// Create `OnnxLayerImportException` that keeps the exception which caused the layer import failure
    /// </summary>
    /// <param name="message">message</param>
    /// <param name="innerException">exception that caused this one</param>
    public OnnxLayerImportException(string message, Exception innerException) : base(message, innerException) { }
}
}