using System; using System.Collections.Generic; using System.IO; using System.Linq; using System.Runtime.CompilerServices; using Google.Protobuf; using Google.Protobuf.Collections; using Onnx; using Unity.Barracuda.Compiler.Passes; using UnityEngine; using UnityEngine.Assertions; using UnityEngine.Profiling; [assembly: InternalsVisibleTo("Unity.Barracuda.Tests")] namespace Unity.Barracuda.ONNX { /// <summary> /// ONNX model converter to Barracuda format. /// </summary> public class ONNXModelConverter { [Flags] internal enum ImportMode { Legacy = 0, // No flags == legacy Standard = 1 << 0, // Additional options KeepAsNCHW = 1 << 16, SkipMetadataImport = 1 << 17, } [Flags] internal enum DataTypeMode { Default = 0, ForceHalf = 1, ForceFloat = 2 } // Configuration bool m_TreatErrorsAsWarnings; bool m_OptimizeModel = true; bool m_ForceArbitraryBatchSize; ImportMode m_ImportMode; // TF2ONNX known issues (as of 1.5.4): // - Conv ops are framed with Transposes as long as the NCHW flag is not set // (note: this seems likely to be fixed, see https://github.com/onnx/tensorflow-onnx/pull/796) // - Tensorflow appends :0 to all node names bool m_FixTf2OnnxExportIssues; /// <summary> /// Model imported event /// </summary> public static event Action<object, Model> ModelImported; private readonly Dictionary<string, ONNXTensor> m_OverrideGlobalInputs = new Dictionary<string, ONNXTensor>() { { "sequence_length:0", new ONNXTensor(new Tensor(1, 1, new[] { 1f }), new [] { 1 }) }, { "sequence_length", new ONNXTensor(new Tensor(1, 1, new[] { 1f }), new [] { 1 }) } }; private readonly HashSet<string> m_ShouldNotBeBaked = new HashSet<string>() { // the following nodes handle constant inputs in a custom manner and should not be baked: "Constant", "Reshape", "Shape", "Slice", "Gather", "Transpose", "Squeeze", "Unsqueeze", "NonZero", "ConstantOfShape", // the following nodes are dynamic in nature and can not be baked even when all inputs are constant: "RandomNormal", "RandomNormalLike", "RandomUniform", "RandomUniformLike" }; private readonly HashSet<string> m_AllInputsChannelFirst = new HashSet<string>() { // the following ONNX nodes have all of their inputs in channels-first layout "Concat", "Add", "Sum", "Sub", "Mul", "Div", "Pow", "Min", "Max", "Mean", "Greater", "Less", "Equal", "Or", "And", "Xor", "Where" }; // Shortcuts private Dictionary<string, ONNXTensor> constantTensors { get { return m_ModelTensors.constants; } } private Dictionary<string, VariableTensor> variableTensors { get { return m_ModelTensors.variables; } } private Dictionary<string, string> lstmInputs = new Dictionary<string, string>(); private Dictionary<string, string> lstmOutputs = new Dictionary<string, string>(); private List<string> layerRequiringUpstreamPatch = new List<string>(); private void Add(string opType, Action<ModelBuilder, ONNXNodeWrapper> opImportAction) { m_NodeImporters.Add(opType, opImportAction); } /// <summary> /// Convert ONNX model and return Barracuda Model object. /// </summary> /// <param name="filePath">Location of the input ONNX model.</param> /// <returns>Barracuda Model object.</returns> public Model Convert(string filePath) { using (var readStream = new FileStream(filePath, FileMode.Open, FileAccess.Read)) using (var inputStream = new CodedInputStream(readStream)) return Convert(inputStream); } /// <summary> /// Convert ONNX model and return Barracuda Model object. /// </summary> /// <param name="buffer">Memory buffer containing ONNX model.</param> /// <returns>Barracuda Model object.</returns>
public Model Convert(byte[] buffer) { using (var inputStream = new CodedInputStream(buffer)) return Convert(inputStream); } // The legacy LSTM importer automagically split input nodes and added output nodes when they didn't exist in the // network, which is no longer supported bool IsLegacyMLAgentsLSTMNetwork(ModelProto onnxModel) { GraphProto graph = onnxModel.Graph; // Hallway-lstm.onnx - legacy importer splits recurrent_in into recurrent_in_c and recurrent_in_h and // adds output nodes recurrent_out_c and recurrent_out_h if (onnxModel.ProducerName == "tf2onnx" && graph.Input.Any(i => i.Name.Contains("recurrent_in")) && graph.Output.Any(o => o.Name.Contains("recurrent_out"))) return true; // Hallway.onnx / Hallway-no-workaround.onnx - legacy importer splits memories into memories_c and memories_h; // adds output nodes recurrent_out__c and recurrent_out__h NodeProto lstmNode = graph.Node.FirstOrDefault(n => n.OpType == "LSTM"); if (onnxModel.ProducerName == "pytorch" && graph.Input.Any(i => i.Name.Contains("memories")) && lstmNode != null && lstmNode.Output.Count == 3 && !graph.Node.Any(n => n.Name == lstmNode.Output[1]) // missing output cell and hidden nodes && !graph.Node.Any(n => n.Name == lstmNode.Output[2])) return true; // Hallway_1_9.onnx - This was supposed to be the candidate for ML-Agents 2.0, but did not have transposes // in the network, so we have to import it using the legacy importer and support it during the 1.x ML-Agents // lifecycle since it already shipped. lstmNode = graph.Node.FirstOrDefault(n => n.OpType == "LSTM"); if (onnxModel.ProducerName == "pytorch" && graph.Input.Any(i => i.Name.Contains("recurrent_in")) && graph.Output.Any(i => i.Name.Contains("recurrent_out")) // Input to LSTM node is incorrectly coming directly from a Slice w/o a Transpose && lstmNode != null && lstmNode.Input.Any(i => { var inputNode = graph.Node.FirstOrDefault(n => n.Output.FirstOrDefault() == i); return inputNode != null && inputNode.Input.Contains("recurrent_in") && inputNode.OpType == "Slice"; })) return true; return false; } internal Model Convert(CodedInputStream inputStream) { var onnxModel = new ModelProto(); onnxModel.MergeFrom(inputStream); m_FixTf2OnnxExportIssues = onnxModel.ProducerName == "tf2onnx"; bool legacyMLAgentsLSTMNetwork = IsLegacyMLAgentsLSTMNetwork(onnxModel); if (legacyMLAgentsLSTMNetwork) m_ImportMode = ImportMode.Legacy; if (m_ImportMode.HasFlag(ImportMode.Standard)) UseStandardImporter(); else UseLegacyImporter(); var model = ConvertOnnxModel(onnxModel); if (m_ImportMode.HasFlag(ImportMode.Standard)) { var preserveLayersPass = new PreserveLayersPass(); preserveLayersPass.Run(ref model); if (m_ImportMode.HasFlag(ImportMode.KeepAsNCHW)) { // Since our model is non-runnable due to NHWC-native ops this pass is always required var runnableNCHWPass = new IntermediateToRunnableNCHWPass(); runnableNCHWPass.Run(ref model); } else { var runnableNHWCPass = new IntermediateToRunnableNHWCPass() { Optimize = m_OptimizeModel }; runnableNHWCPass.Run(ref model); } } if (legacyMLAgentsLSTMNetwork) model.Warnings.Add(new Model.ImporterWarning("model", "Using legacy importer since legacy LSTM network was detected; support will be removed in Barracuda v2.0")); ModelImported?.Invoke(onnxModel, model); return model; } /// <summary> /// Constructs ONNX model converter /// </summary> /// <param name="optimizeModel">Enable/disable various model optimizations while importing model from ONNX format.</param> /// <param name="treatErrorsAsWarnings">Treat import errors as warnings.</param> /// <param name="forceArbitraryBatchSize">Repair model input batch size. Sometimes needed for ONNX models coming from PyTorch.</param>
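/// <example>
/// A minimal usage sketch (the file name here is a placeholder, not a shipped asset):
/// <code>
/// var converter = new ONNXModelConverter(optimizeModel: true);
/// Model model = converter.Convert("path/to/model.onnx");
/// </code>
/// </example>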
public ONNXModelConverter(bool optimizeModel, bool treatErrorsAsWarnings = false, bool forceArbitraryBatchSize = true) : this(optimizeModel, treatErrorsAsWarnings, forceArbitraryBatchSize, ImportMode.Standard) { } // Internal constructor to allow setting import mode internal ONNXModelConverter(bool optimizeModel, bool treatErrorsAsWarnings, bool forceArbitraryBatchSize, ImportMode importMode) { m_OptimizeModel = optimizeModel; m_TreatErrorsAsWarnings = treatErrorsAsWarnings; m_ForceArbitraryBatchSize = forceArbitraryBatchSize; m_ImportMode = importMode; } void UseStandardImporter() { m_NodeImporters.Clear(); var defaultZeroTensor = new ONNXTensor(new Tensor(1, 1, new[] { 0f }), new[] { 1 }); Add("Constant", (net, node) => { node.UnsupportedAttribute("sparse_value"); Const(node, node.ValueAsTensor); }); Add("ConstantOfShape", (net, node) => { UnityEngine.Debug.Assert(node.InputCount > 0); ONNXTensor valueTensor = node.GetOptionalTensor("value", defaultZeroTensor); var value = valueTensor.ToBarracuda("ONNX").AsFloats()[0]; if (node.IsInput0Const) { var onnxShape = node.Input0Constant("ONNX").AsInts(); int onnxRank = onnxShape.Length; onnxShape = ONNXLayout.ConvertSymbolicShapeToBarracuda(onnxShape, "ONNX"); var tensor = new Tensor(onnxShape); tensor.Fill(value); net.Const(node.Name, tensor, -1, onnxRank); } else { net.ConstantOfShape(node.Name, node.Input0, value); } }); Add("Reshape", (net, node) => { int[] onnxShape; if (node.InputCount == 1) { onnxShape = node.Shape; if (node.IsInput0Const) { // reshape constant source tensor and store it as the new constant var reshapedTensor = constantTensors[node.Input0].Reshape(onnxShape); Const(node, reshapedTensor); } else { net.Reshape(node.Name, node.Input0, onnxShape); Output(node, rank:onnxShape.Length); } } else { if (node.IsInput1Const) { onnxShape = node.Input1Constant(onnxLayout: "ONNX", name: "shape").AsInts(); if (node.IsInput0Const) { // reshape constant source tensor and store it as the new constant var reshapedTensor = constantTensors[node.Input0].Reshape(onnxShape); Const(node, reshapedTensor); } else { net.Reshape(node.Name, node.Input0, onnxShape); Output(node, rank:onnxShape.Length); } } else { net.Reshape(node.Name, node.Input0, node.Input1); } } }); Add("Expand", (net, node) => { if (node.IsInput1Const) { var onnxShape = node.Input1Constant(onnxLayout: "C", name: "shape").AsInts(); net.Expand(node.Name, node.Input0, onnxShape); Output(node, rank: onnxShape.Length); } else { net.Expand(node.Name, node.Input0, node.Input1); } }); Add("Shape", (net, node) => { if (node.IsInput0Const) { // the input shape is fully known: bake it as a rank-1 constant, mirroring the legacy importer's Shape handler var shapeValuesAsFloats = constantTensors[node.Input0].shape.Select(x => (float)x).ToArray(); var shapeLength = shapeValuesAsFloats.Length; var shape = new int[8]; shape[0] = shapeLength; Const(node, new ONNXTensor(data: new Tensor(shape, shapeValuesAsFloats), onnxShape: new[] { shapeLength })); } else { net.Shape(node.Name, node.Input0); } }); Add("Unsqueeze", (net, node) => { int[] constAxes = null; if (node.InputCount >= 2 && node.IsInput1Const) constAxes = node.Input1Constant(onnxLayout: "ONNX", name: "axes").AsInts(); else constAxes = node.Axes; if (node.IsInput0Const && constAxes != null) { var unsqueezed = constantTensors[node.Input0].Unsqueeze(constAxes); Const(node, unsqueezed); } else if (node.InputCount == 1) { net.Unsqueeze(node.Name, node.Input0, node.Axes); } else { net.Unsqueeze(node.Name, node.Input0, node.Input1); } }); Add("Squeeze", (net, node) => { int[] constAxes = null; if (node.InputCount >= 2 && node.IsInput1Const) constAxes = node.Input1Constant(onnxLayout: "ONNX", name: "axes").AsInts(); else constAxes = node.Axes; if (node.IsInput0Const && constAxes != null) { var squeezed =
constantTensors[node.Input0].Squeeze(constAxes); Const(node, squeezed); } else if (node.InputCount == 1) { net.Squeeze(node.Name, node.Input0, node.Axes); } else { net.Squeeze(node.Name, node.Input0, node.Input1); } }); Add("Tile", (net, node) => { // only 4D Tile support for now net.Tile(node.Name, node.Input0, node.Input1); }); Add("Flatten", (net, node) => { node.UnsupportedAttribute("axis", 1); // TODO we can support it, insert transposes or if dimensions are ok, == reshape net.Flatten(node.Name, node.Input0); Output(node, rank:2); }); Add("Concat", (net, node) => { int axis = node.AxisOptional(0); if (node.Inputs.Length == 1) net.Identity(node.Name, node.Input0); else { net.Concat(node.Name, node.Inputs, axis, true); } }); Add("Split", (net, node) => { int axis = node.AxisOptional(0); int[] splits; try { splits = node.GetRequiredIntArray("split"); } catch (OnnxLayerImportException) { throw new OnnxLayerImportException($"Unsupported default attribute `split` for node {node.Name} of type Split. Value is required."); } Assert.IsTrue(splits.Length == node.Outputs.Length); int currentSliceStartIndex = 0; // Convert `Split` into multiple `StridedSlice` operations. for (int i = 0; i < splits.Length; ++i) { var starts = currentSliceStartIndex; var ends = starts + splits[i]; var strides = 1; net.StridedSlice(node.Outputs[i], node.Input0, new[] { starts }, new[] { ends }, new[] { strides }, new[] { axis }); currentSliceStartIndex += splits[i]; } }); Add("Slice", (net, node) => { int[] starts, ends, axes, steps; if (node.InputCount > 1) // Slice-10 { if (!node.IsInput1Const || !node.IsInput2Const) { if(node.InputCount == 5) net.StridedSlice(node.Name, node.Input0, starts: node.Input1, ends: node.Input2, strides: node.Input4, axes: node.Input3); else if (node.InputCount == 3) net.StridedSlice(node.Name, node.Input0, starts: node.Input1, ends: node.Input2, strides: null, axes: null); } else { var constStarts = node.Input1Constant(onnxLayout: "ONNX", name: "starts"); var constEnds = node.Input2Constant(onnxLayout: "ONNX", name: "ends"); var defaultAxes = new Tensor(constStarts.shape, Enumerable.Range(0, constStarts.length).Select(v => (float)v).ToArray()); var constAxes = node.Input3ConstantOptional(defaultAxes, onnxLayout: "ONNX", name: "axes"); var constSteps = node.Input4ConstantOptional(constStarts.shape, 1.0f, onnxLayout: "ONNX", name: "steps"); starts = constStarts.AsInts(); ends = constEnds.AsInts(); axes = constAxes.AsInts(); steps = constSteps.AsInts(); net.StridedSlice(node.Name, node.Input0, starts: starts, ends: ends, strides: steps, axes: axes); } } else // Slice-1 { starts = node.Starts; ends = node.Ends; axes = node.AxesOptional(Enumerable.Range(0, starts.Length).ToArray()); steps = Enumerable.Repeat(1, starts.Length).ToArray(); net.StridedSlice(node.Name, node.Input0, starts: starts, ends: ends, strides: steps, axes: axes); } }); Add("Gather", (net, node) => { int axis = node.AxisOptional(0); if (node.IsInput0Const && node.IsInput1Const) { var indices = node.Input1Constant(onnxLayout:"ONNX", name:"indices").AsInts(); ONNXTensor gatheredTensor = constantTensors[node.Input0].Gather(axis, indices); Const(node, gatheredTensor); } else { int input1Rank = node.Input1Rank; if (node.IsInput1Const) { bool isIndicesIntValue = !node.IsInput1Array("indices"); // The original rank was cached above since our constant tensor requires a shape of rank 1 and original may have been a scalar var indices = node.Input1Constant(onnxLayout: "ONNX", name: "indices").AsFloats(); var shape = isIndicesIntValue 
? new int[] { } : new[] { indices.Length }; var constTensor = new ONNXTensor(new Tensor(new [] { indices.Length, 1, 1, 1, 1, 1, 1, 1 }, indices), shape); Const(node.Input1, constTensor); } // for import convenience, Gather with a single int value (rather than int[]) is implemented as an int[] Gather followed by a Squeeze if (node.Input1Rank == 0) { var gatherLayer = net.Gather(node.Name + "_Squeezed", node.Input0, node.Input1, axis, true); net.Squeeze(node.Name, gatherLayer, new[] { axis }); } else { net.Gather(node.Name, node.Input0, node.Input1, axis, true); } Output(node.Name, rank: input1Rank + node.Input0Rank - 1); } }); Add("ScatterND", (net, node) => { string reduction = node.GetOptionalString("reduction", "none"); Layer.ScatterNDReductionMode reductionType = Layer.ScatterNDReductionMode.None; if (reduction == "add") reductionType = Layer.ScatterNDReductionMode.Add; else if (reduction == "mul") reductionType = Layer.ScatterNDReductionMode.Mul; net.ScatterND(node.Name, node.Input0, node.Input1, node.Input2, reductionType); }); Add("NonMaxSuppression", (net, node) => { int centerPointBox = node.GetOptionalInt("center_point_box", 0); var boxes = node.GetRequiredInput(0); var scores = node.GetRequiredInput(1); object maxOutputBoxesPerClass = 0f; object iouThreshold = 0f; object scoreThreshold = 0f; if (node.InputCount > 4 && node.IsInput2Const && node.IsInput3Const && node.IsInput4Const || node.InputCount > 3 && node.IsInput2Const && node.IsInput3Const || node.InputCount > 2 && node.IsInput2Const) { // Use constant version (possibly with defaults) maxOutputBoxesPerClass = node.Input2ConstantOptional((float)maxOutputBoxesPerClass, "ONNX", nameof(maxOutputBoxesPerClass))[0]; iouThreshold = node.Input3ConstantOptional((float)iouThreshold, "ONNX", nameof(iouThreshold))[0]; scoreThreshold = node.Input4ConstantOptional((float)scoreThreshold, "ONNX", nameof(scoreThreshold))[0]; } else { // Use dynamic tensor version maxOutputBoxesPerClass = node.Input2Optional; iouThreshold = node.Input3Optional; scoreThreshold = node.Input4Optional; } // NOTE: Intermediate NCHW -- op is implemented expecting NHWC by default, so this is non-runnable as-is net.NonMaxSuppression(node.Name, boxes, scores, maxOutputBoxesPerClass, iouThreshold, scoreThreshold, centerPointBox); Output(node, rank: 2); }); Add("OneHot", (net, node) => { node.UnsupportedAttribute("axis", -1); var defaultOffOn = new Tensor(2, 0, new float[] {0, 1}); var depth = (int)node.Input1Constant(onnxLayout:"C", name:"depth")[0]; var offon = node.Input2ConstantOptional(defaultOffOn, onnxLayout:"C", name:"values"); net.OneHot(node.Name, node.Input0, depth, (int)offon[1], (int)offon[0]); Output(node, features:depth, rank: node.Input0Rank + 1); }); Add("RoiAlign", (net, node) => { node.UnsupportedAttribute("mode"); // TODO support int output_height = node.GetOptionalInt("output_height", 1); int output_width = node.GetOptionalInt("output_width", 1); int sampling_ratio = node.GetOptionalInt("sampling_ratio", 0); float spatial_scale = node.GetOptionalFloat("spatial_scale", 1.0f); net.RoiAlign(node.Name, node.Input0, node.Input1, node.Input2, output_height, output_width, sampling_ratio, spatial_scale); }); Add("TopK", (net, node) => { int axis = node.AxisOptional(-1); // TopK-11 introduced these options bool largest = node.GetOptionalInt("largest", 1) == 1; // If sorted = false, then the output is undefined bool sorted = node.GetOptionalInt("sorted", 1) == 1; string kName; if (node.InputCount > 1) // TopK-10 introduced K as an input tensor { kName = node.Input1; } else { //
TopK-1 int k = node.GetRequiredInt("k"); kName = "Const_TopK"; var kTensor = new ONNXTensor( data:new Tensor(new[] { 1, 1, 1, 1 }, new[] { (float)k }, kName), onnxShape:new [] { 1 }); Const(node, kTensor); } Layer indices = net.TopKIndices(node.Outputs[1], node.Input0, kName, axis, largest, sorted); Output(node.Outputs[1], rank: node.Input0Rank); net.TopKValues(node.Outputs[0], node.Input0, indices, axis); Output(node.Outputs[0], rank: node.Input0Rank); }); Add("NonZero", (net, node) => { if (node.IsInput0Const) { var nonZeroTensor = constantTensors[node.Input0].NonZero(); Const(node, nonZeroTensor); } else { net.NonZero(node.Name, node.Input0); Output(node.Outputs[0], rank: 2); } }); Add("LSTM", (net, node) => { node.UnsupportedAttribute("activation_alpha"); node.UnsupportedAttribute("activation_beta"); node.UnsupportedAttribute("activations", new[] { "Sigmoid", "Tanh", "Tanh" }); // Only Sigmoid is supported for now node.UnsupportedAttribute("clip"); node.UnsupportedAttribute("direction", "forward"); // Only forward direction supported node.UnsupportedAttribute("input_forget"); node.UnsupportedAttribute("layout"); // alternate layout not supported int hiddenSize = node.GetRequiredInt("hidden_size"); string[] nodeInputs = node.Inputs; int inputCount = nodeInputs.Length; object W = node.Input1; if (node.IsInput1Const) W = node.Input1Constant(onnxLayout: "RKC", name: "W"); object R = node.Input2; if (node.IsInput2Const) R = node.Input2Constant(onnxLayout: "RKC", name: "R"); object B = node.Input3Optional; if (inputCount > 3 && node.IsInput3Const) { B = node.Input3Constant(onnxLayout: "RC", name: "B"); } else if (string.IsNullOrEmpty((string)B)) { var tensor = new Tensor(new TensorShape(1, 8 * hiddenSize)); tensor.Fill(0); B = net.Const($"Const_{node.Name}_B", tensor, rank: 2); } int outputCount = node.Outputs.Length; string[] outputs = { node.Outputs[0], outputCount > 1 ? node.Outputs[1] : null, outputCount > 2 ? node.Outputs[2] : null }; string initialHidden = inputCount > 5 && !string.IsNullOrEmpty(nodeInputs[5]) ? node.Input5Optional : null; string initialCell = inputCount > 6 && !string.IsNullOrEmpty(nodeInputs[6]) ? node.Input6Optional : null; net.LSTM(node.Name, node.Input0, outputs, W, R, B, hiddenSize, initialHidden, initialCell); Output(node.Outputs[0], rank:2); // Actually rank 4, but needs to be 2 for how we handle this layer (re-evaluate?) if (outputCount > 1) Output(node.Outputs[1], rank:2); // Actually rank 3, but needs to be 2 for how we handle this layer (re-evaluate?) if (outputCount > 2) Output(node.Outputs[2], rank:2); // Actually rank 3, but needs to be 2 for how we handle this layer (re-evaluate?) 
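// For reference, ONNX packs the LSTM constants (forward, num_directions == 1) as:
//   W: [1, 4*hidden_size, input_size]   R: [1, 4*hidden_size, hidden_size]
//   B: [1, 8*hidden_size]  (W biases for the 4 gates followed by R biases)
// which is why the zero-filled default bias above is allocated as 8 * hiddenSize.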
}); // Activation ops Add("Relu", (net, node) => { net.Relu(node.Name, node.Input0); }); Add("Softmax", (net, node) => { const int defaultAxis = 1; int axis = node.AxisOptional(defaultAxis); net.Softmax(node.Name, node.Input0, axis, axisIs8D: true); // keep axis as is }); Add("Tanh", (net, node) => { net.Tanh(node.Name, node.Input0); }); Add("Sqrt", (net, node) => { net.Sqrt(node.Name, node.Input0); }); Add("Sigmoid", (net, node) => { net.Sigmoid(node.Name, node.Input0); }); Add("Elu", (net, node) => { net.Elu(node.Name, node.Input0, node.AlphaOptional(1f)); }); Add("LeakyRelu",(net, node) => { net.LeakyRelu(node.Name, node.Input0, node.AlphaOptional(0.01f)); }); Add("Selu", (net, node) => { net.Selu(node.Name, node.Input0, node.AlphaOptional(1.67326f), node.GammaOptional(1.0507f)); }); Add("Swish", (net, node) => { net.Swish(node.Name, node.Input0); }); Add("PRelu", (net, node) => { net.PRelu(node.Name, node.Input0, node.Input1); }); Add("LogSoftmax", (net, node) => { const int defaultAxis = 1; int axis = node.AxisOptional(defaultAxis); net.LogSoftmax(node.Name, node.Input0, axis, axisIs8D: true); // keep axis as is }); // TODO: Add("Hardmax", (net, node) => { net.Hardmax(node.Name, node.Input0); node.UnsupportedAttribute("axis", 1); }); Add("Softplus", (net, node) => { net.Softplus(node.Name, node.Input0); }); // TODO: Add("Softsign", (net, node) => { net.Softsign(node.Name, node.Input0); }); Add("HardSigmoid", (net, node) => { net.HardSigmoid(node.Name, node.Input0, node.AlphaOptional(0.2f), node.BetaOptional(0.5f)); }); Add("Exp", (net, node) => { net.Exp(node.Name, node.Input0); }); Add("Log", (net, node) => { net.Log(node.Name, node.Input0); }); Add("Reciprocal", (net, node) => { net.Reciprocal(node.Name, node.Input0); }); Add("Abs", (net, node) => { net.Abs(node.Name, node.Input0); }); Add("Neg", (net, node) => { net.Neg(node.Name, node.Input0); }); Add("Ceil", (net, node) => { net.Ceil(node.Name, node.Input0); }); Add("Floor", (net, node) => { net.Floor(node.Name, node.Input0); }); Add("Round", (net, node) => { net.Round(node.Name, node.Input0); }); Add("Clip", (net, node) => { float minValue = float.MinValue; float maxValue = float.MaxValue; if (node.InputCount > 1) // Clip-11 { minValue = node.Input1ConstantOptional(minValue, onnxLayout:"C", name:"min")[0]; maxValue = node.Input2ConstantOptional(maxValue, onnxLayout:"C", name:"max")[0]; } else { minValue = node.MinOptional(minValue); maxValue = node.MaxOptional(maxValue); } net.Clip(node.Name, node.Input0, minValue, maxValue); }); Add("Acos", (net, node) => { net.Acos(node.Name, node.Input0); }); Add("Acosh", (net, node) => { net.Acosh(node.Name, node.Input0); }); Add("Asin", (net, node) => { net.Asin(node.Name, node.Input0); }); Add("Asinh", (net, node) => { net.Asinh(node.Name, node.Input0); }); Add("Atan", (net, node) => { net.Atan(node.Name, node.Input0); }); Add("Atanh", (net, node) => { net.Atanh(node.Name, node.Input0); }); Add("Cos", (net, node) => { net.Cos(node.Name, node.Input0); }); Add("Cosh", (net, node) => { net.Cosh(node.Name, node.Input0); }); Add("Sin", (net, node) => { net.Sin(node.Name, node.Input0); }); Add("Sinh", (net, node) => { net.Sinh(node.Name, node.Input0); }); Add("Tan", (net, node) => { net.Tan(node.Name, node.Input0); }); Add("Erf", (net, node) => { net.Erf(node.Name, node.Input0); }); string[] GetArithmeticOpInputs(ONNXNodeWrapper node, ModelBuilder net) { string[] inputs = new string[node.Inputs.Length]; Array.Copy(node.Inputs, inputs, inputs.Length); if (node.IsInput1Const) { string onnxLayout = 
"ONNX"; string constName = $"Const_{node.Input1}"; if (!constantTensors.ContainsKey(constName)) { Tensor tensorData = node.Input1Constant(onnxLayout, node.Input1); Layer layer = net.Const(constName, tensorData, rank: node.Input1Rank); inputs[1] = layer.name; Const(constName, new ONNXTensor(tensorData, tensorData.shape.ToArray())); } } return inputs; } // Broadcast ops Add("Add", (net, node) => { net.Add(node.Name, GetArithmeticOpInputs(node, net)); }); Add("Sum", (net, node) => { net.Add(node.Name, GetArithmeticOpInputs(node, net)); }); // Sum is implemented via Add Add("Sub", (net, node) => { net.Sub(node.Name, GetArithmeticOpInputs(node, net)); }); Add("Mul", (net, node) => { net.Mul(node.Name, GetArithmeticOpInputs(node, net)); }); Add("Div", (net, node) => { net.Div(node.Name, GetArithmeticOpInputs(node, net)); }); Add("Pow", (net, node) => { net.Pow(node.Name, node.Inputs); }); Add("Min", (net, node) => { net.Min(node.Name, node.Inputs); }); Add("Max", (net, node) => { net.Max(node.Name, node.Inputs); }); Add("Mean", (net, node) => { net.Mean(node.Name, node.Inputs); }); // Logical ops Add("Greater", (net, node) => { net.Greater(node.Name, node.Input0, node.Input1); }); Add("Less", (net, node) => { net.Less(node.Name, node.Input0, node.Input1); }); Add("LessOrEqual", (net, node) => { net.LessEqual(node.Name, node.Input0, node.Input1); }); Add("Equal", (net, node) => { net.Equal(node.Name, node.Input0, node.Input1); }); Add("Or", (net, node) => { net.LogicalOr(node.Name, node.Input0, node.Input1); }); Add("And", (net, node) => { net.LogicalAnd(node.Name, node.Input0, node.Input1); }); Add("Not", (net, node) => { net.LogicalNot(node.Name, node.Input0); }); Add("Sign", (net, node) => { net.Sign(node.Name, node.Input0); }); Add("Xor", (net, node) => { net.LogicalXor(node.Name, node.Input0, node.Input1); }); Add("Where", (net, node) => { net.Where(node.Name, node.Input0, node.Input1, node.Input2); }); // Padding ops Add("MirrorPad", (net, node) => { //Note: MirrorPad is not in onnx spec, it is a custom op from tensorflow implementing there own padding (aka symmetric). 
node.UnsupportedAttribute("mode", "symmetric"); var value = node.GetOptionalFloat("value", 0.0f); var autoPad = node.AutoPadMode(); // NOTE: Intermediate NCHW -- op is implemented expecting NHWC by default, so this is non-runnable as-is if (node.InputCount == 1) { var pads = node.GetRequiredIntArray("pads"); net.Pad(node.Name, node.Input0, pads, value, Layer.PadMode.Symetric, Layer.AutoPad.NotSet); } else net.Pad(node.Name, node.Input0, node.Input1, node.Input2Optional, Layer.PadMode.Symetric, Layer.AutoPad.NotSet); }); Add("Pad", (net, node) => { var value = node.GetOptionalFloat("value", 0.0f); var modeType = node.PadMode(); var autoPadType = node.AutoPadMode(); // NOTE: Intermediate NCHW -- op is implemented expecting NHWC by default, so this is non-runnable as-is if (node.InputCount == 1) { var pads = node.GetRequiredIntArray("pads"); net.Pad(node.Name, node.Input0, pads, value, modeType, autoPadType); } else net.Pad(node.Name, node.Input0, node.Input1, node.Input2Optional, modeType, autoPadType); }); // Pooling ops Add("AveragePool", (net, node) => { node.UnsupportedAttribute("ceil_mode", 0); node.UnsupportedAttribute("count_include_pad", 0); net.AvgPool2D(node.Name, node.Input0, node.KernelShape, node.Strides, node.Pads); }); Add("MaxPool", (net, node) => { node.UnsupportedAttribute("ceil_mode", 0); node.UnsupportedAttribute("dilations", new[] {1, 1}); node.UnsupportedAttribute("storage_order", 0); int[] strides = node.Strides; int[] pads = node.Pads; if (strides.Length == 1) strides = new[] { 1, strides[0] }; UnityEngine.Debug.Assert(strides.Length == 2); int[] kernelShape = node.KernelShape; if (kernelShape.Length == 1) kernelShape = new[] { kernelShape[0], 1 }; net.MaxPool2D(node.Name, node.Input0, kernelShape, strides, pads); }); Add("GlobalAveragePool", (net, node) => { // NOTE: Intermediate NCHW -- op is implemented expecting NHWC by default, so this is non-runnable as-is net.GlobalAvgPool2D(node.Name, node.Input0); }); Add("GlobalMaxPool", (net, node) => { // NOTE: Intermediate NCHW -- op is implemented expecting NHWC by default, so this is non-runnable as-is net.GlobalMaxPool2D(node.Name, node.Input0); }); Add("Upsample", (net, node) => { UpsampleNCHW(net, node, 1); }); Add("Resize", (net, node) => { var mode = node.ModeOptional("nearest"); var bilinear = IsModeBilinear(net, node, mode); if (node.InputCount > 2) // Resize-11/13 { node.UnsupportedAttribute("coordinate_transformation_mode", "half_pixel"); node.UnsupportedAttribute("cubic_coeff_a", -0.75f); node.UnsupportedAttribute("exclude_outside", 0); node.UnsupportedAttribute("extrapolation_value", 0f); node.UnsupportedAttribute("nearest_mode", "round_prefer_floor"); // Inputs (3 - 4) // X : T1 // roi : T2, It only takes effect when coordinate_transformation_mode is "tf_crop_and_resize" // scales : tensor(float) // sizes (optional) : tensor(int64) // TODO: cropping via roi input } // NOTE: Intermediate NCHW -- op is implemented expecting NHWC by default and size as constants, so this is non-runnable as-is if (node.InputCount == 4) { //Resize-11/13 using target size net.Resample2D(node.Name, node.Input0, node.Input3, bilinear); } else { //Resize using scales UpsampleNCHW(net, node, node.InputCount-1); } }); Add("Transpose", (net, node) => { // From https://github.com/onnx/onnx/blob/master/docs/Operators.md#transpose // By default, reverse the dimensions, otherwise permute the axes according to the values given. 
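// e.g. for a rank-4 input the default permutation is [3, 2, 1, 0] (a full reverse),
// while an explicit perm = [0, 2, 3, 1] maps ONNX NCHW to NHWC.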
if (node.IsInput0Const) { int inputTensorRank = constantTensors[node.Input0].rank; var defaultPermutations = new int[inputTensorRank]; for (int i = 0; i < inputTensorRank; ++i) defaultPermutations[i] = inputTensorRank - 1 - i; var permutations = node.GetOptionalIntArray("perm", defaultPermutations); var transposedTensor = constantTensors[node.Input0].Permute(permutations); Const(node, transposedTensor); } else { var defaultPermutations = new[] { 0, 1, 2, 3, 4, 5 }; var permutations = node.GetOptionalIntArray("perm", defaultPermutations); if (permutations.Length > 6) throw new OnnxLayerImportException($"Transpose supports up to 6 dimensions, but got a permutation of rank {permutations.Length}."); net.Transpose(node.Name, node.Input0, permutations); } }); Add("DepthToSpace", (net, node) => { net.DepthToSpace(node.Name, node.Input0, node.BlockSize, node.ModeOptional("DCR")); }); Add("SpaceToDepth", (net, node) => { net.SpaceToDepth(node.Name, node.Input0, node.BlockSize); }); // Tensor ops Add("Gemm", (net, node) => { node.UnsupportedAttribute("alpha", 1.0f); node.UnsupportedAttribute("beta", 1.0f); if (node.IsInput1Const && node.IsInput2Const) { var weights = node.Input1Constant(node.TransBOptional() ? "KC" : "CK", name: "B"); var biases = node.Input2ConstantOptional(Bias(weights.shape), 0.0f, "C", name: "C"); var input0 = node.Input0; int transposeA = node.GetOptionalInt("transA", 0); if (transposeA == 1) { input0 = input0 + "_transpose"; net.Transpose(input0, node.Input0, new[] { 1, 0 }); } net.Dense(node.Name, input0, weights, biases); Output(node, features: weights.channels, rank: 2); // Gemm forces flatten of the input to rank 2 } else { int transposeA = node.GetOptionalInt("transA", 0); int transposeB = node.GetOptionalInt("transB", 0); var input0 = node.Input0; var input1 = node.Input1; if (transposeA == 1) { input0 = input0 + "_transpose"; net.Transpose(input0, node.Input0, new[] { 1, 0 }); } if (transposeB == 1) { input1 = input1 + "_transpose"; net.Transpose(input1, node.Input1, new[] { 1, 0 }); } net.MatMul(node.Name, input0, input1); if (node.InputCount == 3) { net.Add(node.Name + "_bias", new[] { node.Name, node.Input2 }); } } }); Add("MatMul", (net, node) => { net.MatMul(node.Name, node.Input0, node.Input1); Output(node, features: node.Input0Features, rank: Math.Max(node.Input0Rank, node.Input1Rank)); }); Add("Conv", (net, node) => { int[] dilationsDHW = new[] { 1, 1, 1 }; // @TODO trap on wrong values int[] strides = node.Strides; int[] pads = node.Pads; node.IgnoredAttribute("kernel_shape", "Kernel shape is derived from K tensor weights instead"); // Ideally, we'd import kernels/biases in native ONNX layout, but we already have to transpose input since the op doesn't work natively in NCHW var kernels = node.Input1Constant(onnxLayout: "KCHW", name: "W"); var kernelRank = node.Input1Rank; if (kernelRank == 3) // Conv1D { dilationsDHW = node.DilatationsOptional(new[] { 1 }); // @TODO trap on wrong values UnityEngine.Debug.Assert(dilationsDHW.Length == 1); dilationsDHW = new[] { 1, 1, dilationsDHW[0] }; if (strides.Length == 1) strides = new[] { strides[0], 1 }; if (pads.Length == 2) pads = new[] { pads[0], 0, pads[1], 0 }; } else if (kernelRank == 4) // Conv2D { dilationsDHW = node.DilatationsOptional(new[] { 1, 1 }); UnityEngine.Debug.Assert(dilationsDHW.Length == 2); dilationsDHW = new[] { 1, dilationsDHW[0], dilationsDHW[1] }; } else if (kernelRank == 5) // Conv3D { //TODO specific error message for DepthwiseConv3D (or support it).
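// Note on the DilateKernel fallback below: baking dilation into the kernel makes a
// kernel of size k with dilation d act like a dense kernel of size k + (k-1)*(d-1),
// e.g. a 3x3 kernel with dilation 2 behaves like a (mostly zero) 5x5 kernel.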
node.UnsupportedAttribute("group", 1); dilationsDHW = node.DilatationsOptional(new[] { 1, 1, 1 }); UnityEngine.Debug.Assert(dilationsDHW.Length == 3); pads = node.Pads3D; strides = node.Strides3D; } else { Warn(net, node, $"Unsupported Conv kernel rank. Conv1D/2D/3D assumes rank 3/4/5 respectively, but got {kernelRank}."); } UnityEngine.Debug.Assert(dilationsDHW.Length == 3); if (dilationsDHW[0] != 1 || dilationsDHW[1] != 1 || dilationsDHW[2] != 1) kernels = DilateKernel(kernels, dilationsDHW); // @TODO inefficient method. Support dilation in kernel code properly var biases = node.Input2ConstantOptional(Bias(kernels.shape), 0.0f, onnxLayout: "C", name: "B"); // NOTE: Intermediate NCHW -- op is implemented expecting NHWC by default, so this is non-runnable as-is // TODO assert correctly: with group == 2 or group != in_channel we don't support it if (node.GroupOptional() > 1) net.DepthwiseConv2D(node.Name, node.Input0, strides, pads, kernels, biases); else { if (kernelRank < 5) net.Conv2D(node.Name, node.Input0, strides, pads, kernels, biases); else net.Conv3D(node.Name, node.Input0, strides, pads, kernels, biases); } Output(node, features: kernels.channels); }); Add("ConvTranspose", (net, node) => { node.UnsupportedAttribute("group", 1); node.UnsupportedAttribute("output_shape", new int[0]); node.IgnoredAttribute("kernel_shape", "Kernel shape is derived from K tensor weights instead"); int[] strides = node.Strides; int[] pads = node.Pads; int[] outputPadding = node.OutputPadding; var kernelRank = node.Input1Rank; if (kernelRank == 3) // ConvTranspose1D { node.UnsupportedAttribute("dilations", new[] {1}); if (strides.Length == 1) strides = new[] { strides[0], 1 }; if (pads.Length == 2) pads = new[] { pads[0], 0, pads[1], 0 }; if (outputPadding.Length == 1) outputPadding = new[] { outputPadding[0], 0 }; } else if (kernelRank == 4) // ConvTranspose2D { node.UnsupportedAttribute("dilations", new[] {1, 1}); } else { Warn(net, node, $"Unsupported ConvTranspose kernel rank. ConvTranspose1D/2D assumes rank 3/4 respectively, but got {kernelRank}."); } // Ideally, we'd import kernels/biases in native ONNX layout, but we already have to transpose input since the op doesn't work natively in NCHW var kernels = node.Input1Constant(onnxLayout:"CKHW", name:"W"); var biases = node.Input2ConstantOptional(Bias(kernels.shape), 0.0f, onnxLayout:"C", name:"B"); // NOTE: Intermediate NCHW -- op is implemented expecting NHWC by default, so this is non-runnable as-is net.Conv2DTrans(node.Name, node.Input0, strides, pads, outputPadding, kernels, biases); Output(node, features:kernels.channels); }); Add("BatchNormalization", (net, node) => { // Ideally, we'd import variances/scales/biases/means in native ONNX layout, but we already have to transpose input since the op doesn't work natively in NCHW var variance = node.Input4Constant(onnxLayout:"C", name:"var"); var scale = node.Input1ConstantOptional(variance.shape, 1.0f, onnxLayout:"C", name:"scale"); var bias = node.Input2ConstantOptional(variance.shape, 0.0f, onnxLayout:"C", name:"B"); var mean = node.Input3ConstantOptional(variance.shape, 0.0f, onnxLayout:"C", name:"mean"); if (variance.length != scale.length || scale.length != bias.length || bias.length != mean.length) Warn(net, node, $"Number of elements in all parameters for BatchNorm must be the same. " + $"Parameter shapes are: {scale.shape}, {bias.shape}, {mean.shape}, {variance.shape}"); // TODO: Jeremy has one non-valid ONNX model with #channels > input_channels; see if we want to support his model?
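// FuseBatchNormWeights folds the batch-norm statistics into a single scale/bias pair;
// a sketch of the usual algebra (the helper itself is defined elsewhere in this file):
//   y = scale * (x - mean) / sqrt(var + epsilon) + bias
//     = x * fusedScale + fusedBias
// with fusedScale = scale / sqrt(var + epsilon) and fusedBias = bias - mean * fusedScale.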
var fusedData = FuseBatchNormWeights(scale, bias, mean, variance, node.EpsilonOptional(), variance.shape.channels); // NOTE: Intermediate NCHW -- op is implemented expecting NHWC by default, so this is non-runnable as-is net.ScaleBias(node.Name, node.Input0, fusedData.Item1, fusedData.Item2); }); Add("ImageScaler", (net, node) => { var attrBias = node.Bias; var attrScale = node.ScaleOptional(); int maxElements = attrBias.Length; Tensor scale = new Tensor(1, maxElements); Tensor bias = new Tensor(1, maxElements); for (int i = 0; i < maxElements; ++i) { scale[i] = attrScale; bias[i] = attrBias[i]; } net.ScaleBias(node.Name, node.Input0, scale, bias); }); Add("InstanceNormalization", (net, node) => { // Ideally, we'd import scales/biases in native ONNX layout, but we already have to transpose input since the op doesn't work natively in NCHW var scale = node.Input1Constant(onnxLayout:"C", name:"scale"); var bias = node.Input2ConstantOptional(scale.shape, 0.0f, onnxLayout:"C", name:"B"); if (scale.length != bias.length) Warn(net, node, $"Number of elements in all parameters for InstanceNorm must be the same. " + $"Parameter shapes are: {scale.shape}, {bias.shape}"); if (scale.channels != node.Input0Features && node.Input0Features > 0) { Warn(net, node, $"Number of elements in InstanceNorm must match features from the previous layer. Was expecting {node.Input0Features}, but got {scale.channels}."); var scaleArray = scale.ToReadOnlyArray(); Array.Resize(ref scaleArray, node.Input0Features); var biasArray = bias.ToReadOnlyArray(); Array.Resize(ref biasArray, node.Input0Features); scale = new Tensor(1, node.Input0Features, scaleArray); bias = new Tensor(1, node.Input0Features, biasArray); } // NOTE: Intermediate NCHW -- op is implemented expecting NHWC by default, so this is non-runnable as-is net.Normalization(node.Name, node.Input0, scale, bias, node.EpsilonOptional()); }); Add("LRN", (net, node) => { float bias = node.GetOptionalFloat("bias", 1.0f); int size = node.GetRequiredInt("size"); net.LRN(node.Name, node.Input0, node.AlphaOptional(0.0001f), node.BetaOptional(0.75f), bias, size); }); // random ops Add("RandomNormal", (net, node) => { var shape = ONNXLayout.ConvertShapeToBarracuda(onnxShape:node.Shape, onnxLayout:"ONNX"); net.RandomNormal(node.Name, shape, node.MeanOptional(), node.ScaleOptional(), node.Seed); Output(node, rank:node.Shape.Length); }); Add("RandomNormalLike", (net, node) => { net.RandomNormal(node.Name, node.Input0, node.MeanOptional(), node.ScaleOptional(), node.Seed); }); Add("RandomUniform", (net, node) => { float high = node.GetOptionalFloat("high", 1.0f); float low = node.GetOptionalFloat("low", 0.0f); var shape = ONNXLayout.ConvertShapeToBarracuda(onnxShape:node.Shape, onnxLayout:"ONNX"); net.RandomUniform(node.Name, shape, low, high, node.Seed); Output(node, rank:node.Shape.Length); }); Add("RandomUniformLike", (net, node) => { float high = node.GetOptionalFloat("high", 1.0f); float low = node.GetOptionalFloat("low", 0.0f); net.RandomUniform(node.Name, node.Input0, low, high, node.Seed); }); Add("Multinomial", (net, node) => { int samples = node.GetOptionalInt("sample_size", 1); net.Multinomial(node.Name, node.Input0, samples, node.Seed); }); Add("Range", (net, node) => { if (node.IsInput0Const && node.IsInput1Const && node.IsInput2Const) { var startTensor = node.GetRequiredInputAsConstant(node.Input0, "N", "start"); var limitTensor = node.GetRequiredInputAsConstant(node.Input1, "N", "limit"); var deltaTensor = node.GetRequiredInputAsConstant(node.Input2, "N",
"start"); Assert.AreEqual(startTensor.length, 1); Assert.AreEqual(limitTensor.length, 1); Assert.AreEqual(deltaTensor.length, 1); float start = startTensor[0]; float limit = limitTensor[0]; float delta = deltaTensor[0]; var range = ONNXTensor.Range(start, limit, delta); Const(node, range); } else { net.Range(node.Name, node.Input0, node.Input1, node.Input2); } }); // Reduce ops Add("ReduceMax", (net, node) => { ReduceNCHW(net, node, Layer.Type.ReduceMax); }); Add("ReduceMean", (net, node) => { ReduceNCHW(net, node, Layer.Type.ReduceMean); }); Add("ReduceMin", (net, node) => { ReduceNCHW(net, node, Layer.Type.ReduceMin); }); Add("ReduceProd", (net, node) => { ReduceNCHW(net, node, Layer.Type.ReduceProd); }); Add("ReduceSum", (net, node) => { ReduceNCHW(net, node, Layer.Type.ReduceSum); }); Add("ArgMax", (net, node) => { node.UnsupportedAttribute("select_last_index"); ReduceNCHW(net, node, Layer.Type.ArgMax); }); Add("ArgMin", (net, node) => { node.UnsupportedAttribute("select_last_index"); ReduceNCHW(net, node, Layer.Type.ArgMin); }); // Ignore, noop during inference Add("Identity", (net, node) => { net.Identity(node.Name, node.Input0); }); Add("Cast", (net, node) => { net.Identity(node.Name, node.Input0); }); Add("Dropout", (net, node) => { net.Identity(node.Name, node.Input0); }); } void UseLegacyImporter() { m_NodeImporters.Clear(); var defaultZeroTensor = new ONNXTensor(new Tensor(1, 1, new[] { 0f }), new[] { 1 }); var defaultOneTensor = new ONNXTensor(new Tensor(1, 1, new[] { 1f }), new[] { 1 }); var toNCHW = new [] { 0, 3, 1, 2 }; var toNHWC = new [] { 0, 2, 3, 1 }; var fromN1WCtoNCH = new [] { 0, 3, 2, 1 }; var fromNCHtoN1WC = new [] { 0, 3, 2, 1 }; // TODO: setup m_NodeImporters via initializer list // TODO: simplify code to avoid passing node.Name over and over again Add("Constant", (net, node) => { node.UnsupportedAttribute("sparse_value"); Const(node, node.ValueAsTensor); }); Add("ConstantOfShape", (net, node) => { Assert.IsTrue(node.InputCount > 0); var valueTensor = node.GetOptionalTensor("value", defaultZeroTensor); var onnxShape = node.Input0ConstantONNXShape(name: "input"); var dataShape = ONNXLayout.ConvertShapeToBarracuda(onnxShape, onnxLayout:"?"); var tensorData = new Tensor(dataShape); tensorData.Fill(valueTensor[0]); var constantOfShape = new ONNXTensor(tensorData, onnxShape); Const(node, constantOfShape); }); Add("Reshape", (net, node) => { int[] onnxShape; if (node.InputCount > 1) // Reshape-5 { if (node.IsInput1Const) { onnxShape = node.Input1Constant(onnxLayout: "C", name: "shape").AsInts(); } else { int input0Rank = node.Input0Rank; if (input0Rank <= 4 && variableTensors.TryGetValue(node.Input0, out VariableTensor previousOutput) && previousOutput.layout != VariableTensor.Layout.ChannelsLast) { int outputRank = 4; Model.Input input1 = net.model.inputs.Where(i => i.name == node.Input1).FirstOrDefault(); if (!input1.Equals(default)) { if (input1.rank == 1) // shape is in the tensor outputRank = input1.shape[TensorShape.DataBatch]; } // For handling all reshapes correctly with dynamic shapes (e.g. rank 3) perform in NCHW layout Layer nchwTranspose = net.Transpose($"Transpose_{node.Input0}_For_{node.Name}", node.Input0, input0Rank == 3 ? fromN1WCtoNCH : toNCHW); Layer reshape = net.Reshape($"{node.Name}_NCHW", nchwTranspose, node.Input1); net.Transpose(node.Name, reshape, outputRank == 3 ? 
fromNCHtoN1WC : toNHWC); Output(node, rank:4); } else { net.Reshape(node.Name, node.Input0, node.Input1); } return; } } else // Reshape-1 onnxShape = node.Shape; if (node.IsInput0Const) { // reshape constant source tensor and store it as the new constant var reshapedTensor = constantTensors[node.Input0].Reshape(onnxShape); Const(node, reshapedTensor); } else { Layer reshapeLayer = null; int numDimensionContainingChannelsInformationAfterReshape = 1; var symbolicShape = ONNXLayout.ConvertReshapeToBarracuda(onnxShape, node.Input0Rank, out numDimensionContainingChannelsInformationAfterReshape); int variableDimension = Array.IndexOf(symbolicShape, -1); bool containsNoVariableDimensions = variableDimension == -1; // special case handling with inferable reshapes // TODO: remove this once we have full shape inference // onnx: NCW -> N1CW // N is unknown and H,W are inferable if (node.Input0Rank == 3 && onnxShape.Length == 4 && onnxShape[0] == 0 && onnxShape[1] == 1 && onnxShape[2] == 0 && onnxShape[3] == 0) { // onnx: NCW -> N1CW // barracuda: N_WC -> NCW1 net.Transpose(node.Name, node.Input0, new[] { 0, 3, 2, 1 }); Output(node, features: 1, rank: onnxShape.Length); return; } if (containsNoVariableDimensions) { if (m_ForceArbitraryBatchSize) symbolicShape[0] = -1; // force arbitrary batch size // Creating any of the spatial dimensions requires running the reshape in NCHW and transposing to NHWC afterwards to match NCHW behavior. if (onnxShape.Length > 2 && node.Input0Rank <= 2) { int[] onnxPaddedShape = onnxShape; if (onnxShape.Length == 3) // correct NCH to NCW onnxPaddedShape = new[] {onnxShape[0], onnxShape[1], 1, onnxShape[2]}; reshapeLayer = net.Reshape($"{node.Name}_NCHW", node.Input0, onnxPaddedShape); reshapeLayer = net.Transpose(node.Name, reshapeLayer, toNHWC); } } else if (onnxShape.Length <= 4 && node.Input0Rank <= 4 && (onnxShape.Length == 2 || variableDimension != TensorShape.C) && variableTensors.TryGetValue(node.Input0, out VariableTensor previousOutput) && previousOutput.layout != VariableTensor.Layout.ChannelsLast) { // Collapsing any of the spatial dimensions requires a reshape in NCHW layout int[] onnxPaddedShape; if (onnxShape.Length == 3) // correct NCH to NCW onnxPaddedShape = new[] { onnxShape[0], onnxShape[1], 1, onnxShape[2] }; else onnxPaddedShape = onnxShape.Concat(Enumerable.Repeat(1, 4 - onnxShape.Length)).ToArray(); Layer nchwTranspose = net.Transpose($"Transpose_{node.Input0}_For_{node.Name}", node.Input0, toNCHW); reshapeLayer = net.Reshape($"{node.Name}_NCHW", nchwTranspose, onnxPaddedShape); reshapeLayer = net.Transpose(node.Name, reshapeLayer, toNHWC); } if (reshapeLayer == null) reshapeLayer = net.Reshape(node.Name, node.Input0, symbolicShape); reshapeLayer.axis = numDimensionContainingChannelsInformationAfterReshape; var features = onnxShape.Length > 1 ?
onnxShape[1] : -1; Output(node, features: features, rank:onnxShape.Length); } }); Add("Expand", (net, node) => { var onnxShape = node.Input1Constant(onnxLayout: "C", name: "shape").AsInts(); var symbolicShape = ONNXLayout.ConvertSymbolicShapeToBarracuda(onnxShape, "NCHW"); bool containsNoVariableDimensions = Array.IndexOf(symbolicShape, -1) == -1; if (containsNoVariableDimensions && m_ForceArbitraryBatchSize) symbolicShape[0] = -1; // force arbitrary batch size net.Expand(node.Name, node.Input0, symbolicShape); Output(node, rank:symbolicShape.Length); }); Add("Shape", (net, node) => { float[] shapeValuesAsFloats; if (node.IsInput0Const) { shapeValuesAsFloats = constantTensors[node.Input0].shape.Select(x => (float)x).ToArray(); } else { switch (node.Input0Rank) { default: case 4: // NCHW case 3: // NCW case 2: // NC // @TODO: dynamic implementation that would return real shape during execution of the model // // meanwhile at import time we assume 0 (taken from input tensor) for the spatial dimensions // NOTE: this assumption works for common Upsample opset=9 case: // Upsample.scales = (shape.hw * constant) / shape.hw // however this would not work for potential (opset=10) cases like: // Resize.size = shape.hw + constant // stored in ONNX layout var shapeWithChannelsFirst = new[] { 0f, node.Input0Features }; // NC var fillSpatialDimensionsWithUnknown = 0f; var numberOfSpatialDimensions = node.Input0Rank - 2; var shapeFollowedWithSpatialDimensions = Enumerable.Repeat(fillSpatialDimensionsWithUnknown, numberOfSpatialDimensions); shapeValuesAsFloats = shapeWithChannelsFirst.Concat(shapeFollowedWithSpatialDimensions).ToArray(); break; case 1: // C shapeValuesAsFloats = new[] {(float)node.Input0Features}; break; case 0: // scalar shapeValuesAsFloats = new[] {0f}; break; } } var shapeLength = shapeValuesAsFloats.Length; Assert.IsTrue(shapeLength == node.Input0Rank); var shape = new int[8]; shape[0] = shapeLength; var shapeTensor = new ONNXTensor( // NOTE: stored in single rank ONNX layout // with data in the 1st dimension // thus `shapeLength` specifies the length of the 1st dimension data:new Tensor(shape, shapeValuesAsFloats), onnxShape:new [] { shapeLength }); Const(node, shapeTensor); Output(node, features:shapeLength, productOfShape:node.Input0); }); Add("Unsqueeze", (net, node) => { if (node.IsInput0Const) { var unsqueezed = constantTensors[node.Input0].Unsqueeze(node.Axes); Const(node, unsqueezed); } else { // NOTE: axis=0 or 1 will require Transpose between channels and other spatial dimensions when converted to Barracuda layout. // As we have different layouts between ONNX and Barracuda, Unsqueeze might require actual Transpose not just Reshape! var features = node.Input0Features; var inputRank = node.Input0Rank; var outputRank = inputRank + 1; Output(node.Name, features: features, rank: outputRank); // ONNX pseudocode here: // a = Tensor [2, 10] # NC -> barracuda N11C // b = Unsqueeze(a, axis=0) // # b is now Tensor [1, 2, 10] # NCHW -> barracuda NHWC // Because ONNX is NCHW, but generally hell knows what goes where and Barracuda is strict NHWC. We end up with: // `a` would be [2, 1, 1, 10], but `b` would have to be [1, 10, 1, 2]. Note the actual data swap in channels!
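// Worked example of the axis normalization below: for an input of rank 3, ONNX allows
// Unsqueeze axis in [-4, 3] (output rank is 4); axis = -1 normalizes to -1 + 3 + 1 = 3,
// i.e. the new unit dimension is appended last.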
int axis = node.Axes[0]; if (axis < 0) axis += node.Input0Rank + 1; // normalize: ONNX Unsqueeze allows axis in [-(rank+1), rank] var transpose = ONNXLayout.UnSqueezeAxisPermutationForMappingONNXLayoutToBarracuda(inputRank, axis, "NCHW"); net.Transpose(node.Name, node.Input0, transpose); } }); Add("Squeeze", (net, node) => { if (node.IsInput0Const) { var squeezed = constantTensors[node.Input0].Squeeze(node.Axes); Const(node, squeezed); } else { var features = node.Input0Features; var inputRank = node.Input0Rank; var outputRank = inputRank - 1; Output(node.Name, features: features, rank: outputRank); // See Unsqueeze above for explanation int axis = node.Axes[0]; if (axis < 0) axis += node.Input0Rank; // normalize: ONNX Squeeze allows axis in [-rank, rank-1] var transpose = ONNXLayout.SqueezeAxisPermutationForMappingONNXLayoutToBarracuda(inputRank, axis, "NCHW"); net.Transpose(node.Name, node.Input0, transpose); } }); Add("Flatten", (net, node) => { node.UnsupportedAttribute("axis", 1); if (variableTensors.TryGetValue(node.Input0, out var inputTensor) && inputTensor.layout == VariableTensor.Layout.ChannelsLast) net.Flatten(node.Name, node.Input0); else { Layer nchwTranspose = net.Transpose($"Transpose_{node.Input0}_For_{node.Name}", node.Input0, node.Input0Rank == 3 ? fromN1WCtoNCH : toNCHW); net.Flatten(node.Name, nchwTranspose); // No need to transpose back b/c final shape is always NC (rank 2) } Output(node, rank:2); }); Add("Concat", (net, node) => { int axis = node.AxisOptional(0); if (node.Inputs.Length == 1) net.Identity(node.Name, node.Input0); else { // TODO: write dedicated ONNXTensor.Concat() so that output shape is exact to ONNX // if (node.AreAllInputsConst) // Const(node, ONNXTensor.Concat(node.Inputs.Select(i => constantTensors[i]).ToArray(), axis)); axis = ONNXLayout.ConvertAxisToBarracuda(axis, onnxRank: node.Input0Rank, onnxLayout: "NCHW"); net.Concat(node.Name, node.Inputs, axis, true); bool lastAxis = (axis == -1 || axis == TensorShape.C || axis == node.Input0Rank - 1); // last axis in Barracuda is feature axis if (lastAxis) { var featuresConcatenated = node.Inputs.Sum(i => variableTensors[i].features); Output(node, features: featuresConcatenated); } } }); Add("Split", (net, node) => { int axis = node.AxisOptional(0); int[] splits; try { splits = node.GetRequiredIntArray("split"); } catch (OnnxLayerImportException) { throw new OnnxLayerImportException($"Unsupported default attribute `split` for node {node.Name} of type Split. Value is required."); } Assert.IsTrue(splits.Length == node.Outputs.Length); axis = ONNXLayout.ConvertAxisToBarracuda(axis, onnxRank:node.Input0Rank, onnxLayout:"NCHW"); int currentSliceStartIndex = 0; // Convert `Split` into multiple `StridedSlice` operations.
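// e.g. splitting a length-5 axis with split = [2, 3] emits two slices covering
// [0, 2) and [2, 5) along that axis; currentSliceStartIndex tracks the running offset.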
for (int i = 0; i < splits.Length; ++i) { var starts = new int[] {0, 0, 0, 0, 0, 0, 0, 0}; var ends = new int[] {0, 0, 0, 0, 0, 0, 0, 0}; var strides = new int[] {1, 1, 1, 1, 1, 1, 1, 1}; starts[axis] = currentSliceStartIndex; ends[axis] = starts[axis] + splits[i]; net.StridedSlice(node.Outputs[i], node.Input0,starts,ends,strides); currentSliceStartIndex += splits[i]; } }); Add("Slice", (net, node) => { int[] starts, ends, axes, steps; if (node.InputCount > 1) // Slice-10 { var constStarts = node.Input1Constant(onnxLayout:"C", name:"starts"); var constEnds = node.Input2Constant(onnxLayout:"C", name:"ends"); var defaultAxes = new Tensor(constStarts.shape, Enumerable.Range(0, constStarts.length).Select(v => (float)v).ToArray()); var constAxes = node.Input3ConstantOptional(defaultAxes, onnxLayout:"C", name:"axes"); var constSteps = node.Input4ConstantOptional(constStarts.shape, 1.0f, onnxLayout:"C", name:"steps"); starts = constStarts.AsInts(); ends = constEnds.AsInts(); axes = constAxes.AsInts(); steps = constSteps.AsInts(); } else // Slice-1 { starts = node.Starts; ends = node.Ends; axes = node.AxesOptional(Enumerable.Range(0, starts.Length).ToArray()); steps = Enumerable.Repeat(1, starts.Length).ToArray(); } Assert.IsTrue(starts.Length == ends.Length); var onnxRank = node.Input0Rank; var onnxStarts = Enumerable.Repeat(0, onnxRank).ToArray(); var onnxEnds = Enumerable.Repeat(int.MaxValue, onnxRank).ToArray(); // by default copy the whole axis till the end var onnxSteps = Enumerable.Repeat(1, onnxRank).ToArray(); // NOTE: begin=0, end=0, stride=1 <= full range from existing axis // begin=0, end=inf,stride=1 <= full range from existing axis // begin=0, end=X, stride=1 <= full range from existing axis, if X==last element on this axis // begin=0, end=0, stride=0 <= new axis OR shrink axis to single 1st element // begin=N, end=N, stride=0 <= shrink axis to single Nth element // These notes are copied from TensorExtensions.ApplyStridedSlice(...) 
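// e.g. with onnxRank = 4, starts = [1], ends = [3], axes = [2], the loop below scatters
// onnxStarts = [0, 0, 1, 0], onnxEnds = [max, max, 3, max], onnxSteps = [1, 1, 1, 1]:
// a full copy on every axis except axis 2, which is sliced to [1, 3).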
for (int i = 0; i < axes.Length; ++i) { var axis = axes[i]; if (axis < 0) axis += onnxRank; axis = Math.Min(Math.Max(axis, 0), onnxRank); onnxStarts[axis] = starts[i]; onnxEnds[axis] = ends[i]; onnxSteps[axis] = steps[i]; } if (node.IsInput0Const) { var slicedTensor = constantTensors[node.Input0].Slice(starts:onnxStarts, ends:onnxEnds, steps:onnxSteps); Const(node, slicedTensor); } else { net.StridedSlice(node.Name, node.Input0, starts:ONNXLayout.PermuteToBarracuda(onnxStarts, onnxLayout:"NCHW",0), ends:ONNXLayout.PermuteToBarracuda(onnxEnds, onnxLayout:"NCHW",int.MaxValue), strides:ONNXLayout.PermuteToBarracuda(onnxSteps, onnxLayout:"NCHW",1)); } }); Add("Tile", (net, node) => { // TODO: Implement Tile in ONNXTensor for const var onnxRepeats = node.Input1Constant(onnxLayout: "C", name: "repeats").AsInts(); var repeats = ONNXLayout.ConvertSymbolicShapeToBarracuda(onnxRepeats, onnxLayout: "NCHW"); var features = node.Input0Features; features *= repeats[1]; Output(node.Name, rank: node.Input0Rank, features: features); // only 4D Tile support for now net.Tile(node.Name, node.Input0, new[] { repeats[2], repeats[5], repeats[6], repeats[7] }); }); Add("Gather", (net, node) => { int axis = node.AxisOptional(0); if (node.IsInput0Const && node.IsInput1Const) { var indices = node.Input1Constant(onnxLayout:"C", name:"indices").AsInts(); // If the previous node was a shape and we're gathering an inferred value, then don't treat the shape as a constant if (node.Input0.IndexOf("shape", StringComparison.OrdinalIgnoreCase) >= 0 && indices.Length == 1 && indices[0] > 0 && constantTensors[node.Input0].ToBarracuda("C")[indices[0]] == 0 // Must resolve at runtime && variableTensors.TryGetValue(node.Input0, out VariableTensor input0Tensor) && variableTensors.TryGetValue(input0Tensor.productOfShape, out VariableTensor shapeInputTensor)) { axis = ONNXLayout.ConvertAxisToBarracuda(indices[0], shapeInputTensor.rank, "NCHW"); net.Shape(node.Name, input0Tensor.productOfShape, axis); D.Log($"Re-writing {node.Name} to a Shape+Axis layer (results in a scalar)"); } else { ONNXTensor gatheredTensor = constantTensors[node.Input0].Gather(axis, indices); Const(node, gatheredTensor); } } else { int input1Rank = node.Input1Rank; if (node.IsInput1Const) { // The original rank was cached above since our constant tensor requires a shape of rank 1 and original may have been a scalar var indices = node.Input1Constant(onnxLayout: "C", name: "indices").AsFloats(); var constTensor = new ONNXTensor(new Tensor(new [] { indices.Length, 1, 1, 1, 1, 1, 1, 1 }, indices), new [] { indices.Length }); Const(node.Input1, constTensor); } axis = ONNXLayout.ConvertAxisToBarracuda(axis, onnxRank:node.Input0Rank, onnxLayout:"NCHW"); net.Gather(node.Name, node.Input0, node.Input1, axis, true); Output(node.Name, rank: input1Rank + node.Input0Rank - 1); } }); Add("NonMaxSuppression", (net, node) => { int centerPointBox = node.GetOptionalInt("center_point_box", 0); var boxes = node.GetRequiredInput(0); var scores = node.GetRequiredInput(1); object maxOutputBoxesPerClass = 0f; object iouThreshold = 0f; object scoreThreshold = 0f; if (node.InputCount > 4 && node.IsInput2Const && node.IsInput3Const && node.IsInput4Const || node.InputCount > 3 && node.IsInput2Const && node.IsInput3Const || node.InputCount > 2 && node.IsInput2Const) { // Use constant version (possibly with defaults) maxOutputBoxesPerClass = node.Input2ConstantOptional((float)maxOutputBoxesPerClass, "C", nameof(maxOutputBoxesPerClass))[0]; iouThreshold = 
node.Input3ConstantOptional((float)iouThreshold, "C", nameof(iouThreshold))[0]; scoreThreshold = node.Input4ConstantOptional((float)scoreThreshold, "C", nameof(scoreThreshold))[0]; } else { // Use dynamic tensor version maxOutputBoxesPerClass = node.Input2Optional; iouThreshold = node.Input3Optional; scoreThreshold = node.Input4Optional; } net.NonMaxSuppression(node.Name, boxes, scores, maxOutputBoxesPerClass, iouThreshold, scoreThreshold, centerPointBox); Output(node, rank: 2); }); Add("OneHot", (net, node) => { node.UnsupportedAttribute("axis", -1); var defaultOffOn = new Tensor(2, 0, new float[] {0, 1}); var depth = (int)node.Input1Constant(onnxLayout:"C", name:"depth")[0]; var offon = node.Input2ConstantOptional(defaultOffOn, onnxLayout:"C", name:"values"); net.OneHot(node.Name, node.Input0, depth, (int)offon[1], (int)offon[0]); Output(node, features: depth, rank: node.Input0Rank + 1); }); Add("TopK", (net, node) => { int axis = node.AxisOptional(-1); axis = ONNXLayout.ConvertAxisToBarracuda(axis, onnxRank:node.Input0Rank, onnxLayout:"NCHW"); // TopK-11 introduced these options bool largest = node.GetOptionalInt("largest", 1) == 1; // If sorted = false, then the output is undefined bool sorted = node.GetOptionalInt("sorted", 1) == 1; string kName; if (node.InputCount > 1) // TopK-10 introduced K as an input tensor { kName = node.Input1; } else { // TopK-1 int k = node.GetRequiredInt("k"); kName = "Const_TopK"; var kTensor = new ONNXTensor( data:new Tensor(new[] { 1, 1, 1, 1 }, new[] { (float)k }, kName), onnxShape:new [] { 1 }); Const(node, kTensor); } Layer indices = net.TopKIndices(node.Outputs[1], node.Input0, kName, axis, largest, sorted); Output(node.Outputs[1], rank: node.Input0Rank); net.TopKValues(node.Outputs[0], node.Input0, indices, axis); Output(node.Outputs[0], rank: node.Input0Rank); }); Add("NonZero", (net, node) => { if (node.IsInput0Const) { var nonZeroTensor = constantTensors[node.Input0].NonZero(); Const(node, nonZeroTensor); } else { net.NonZero(node.Name, node.Input0); Output(node.Outputs[0], rank: 2); } }); // LSTM // - it = f(Xt*Wi + Ht_1*Ri + Wbi + Rbi) // - ft = f(Xt*Wf + Ht_1*Rf + Wbf + Rbf) // - ct = g(Xt*Wc + Ht_1*Rc + Wbc + Rbc), c means j in our formula // - Ct = ft . Ct_ + it . ct // - ot = f(Xt*Wo + Ht_1*Ro + Wbo + Rbo) // - Ht = ot . 
h(Ct) Add("LSTM", (net, node) => { var W = node.Input1Constant(onnxLayout: "RKC", name: "W"); var R = node.Input2Constant(onnxLayout: "RKC", name: "R"); var B = node.Input3Constant(onnxLayout: "RC", name: "B"); // gate order [iofj] var ops = new ReferenceCPUOps(); var w_i = ops.StridedSlice(W, new[] {0,0,0,0}, new[] {W.batch,1,1,W.channels/4 }, new[] {1, 1, 1, 1}); var w_o = ops.StridedSlice(W, new[] {0,0,0,W.channels/4}, new[] {W.batch,1,1,2*W.channels/4 }, new[] {1, 1, 1, 1}); var w_f = ops.StridedSlice(W, new[] {0,0,0,2*W.channels/4}, new[] {W.batch,1,1,3*W.channels/4 }, new[] {1, 1, 1, 1}); var w_j = ops.StridedSlice(W, new[] {0,0,0,3*W.channels/4}, new[] {W.batch,1,1,4*W.channels/4 }, new[] {1, 1, 1, 1}); var r_i = ops.StridedSlice(R, new[] {0,0,0,0}, new[] {R.batch,1,1,R.channels/4 }, new[] {1, 1, 1, 1}); var r_o = ops.StridedSlice(R, new[] {0,0,0,R.channels/4}, new[] {R.batch,1,1,2*R.channels/4 }, new[] {1, 1, 1, 1}); var r_f = ops.StridedSlice(R, new[] {0,0,0,2*R.channels/4}, new[] {R.batch,1,1,3*R.channels/4 }, new[] {1, 1, 1, 1}); var r_j = ops.StridedSlice(R, new[] {0,0,0,3*R.channels/4}, new[] {R.batch,1,1,4*R.channels/4 }, new[] {1, 1, 1, 1}); var wb_i = ops.StridedSlice(B, new[] {0,0,0,0}, new[] {1,1,1,B.channels/8 }, new[] {1, 1, 1, 1}); var wb_o = ops.StridedSlice(B, new[] {0,0,0,B.channels/8}, new[] {1,1,1,2*B.channels/8 }, new[] {1, 1, 1, 1}); var wb_f = ops.StridedSlice(B, new[] {0,0,0,2*B.channels/8}, new[] {1,1,1,3*B.channels/8 }, new[] {1, 1, 1, 1}); var wb_j = ops.StridedSlice(B, new[] {0,0,0,3*B.channels/8}, new[] {1,1,1,4*B.channels/8 }, new[] {1, 1, 1, 1}); var rb_i = ops.StridedSlice(B, new[] {0,0,0,4*B.channels/8}, new[] {1,1,1,5*B.channels/8 }, new[] {1, 1, 1, 1}); var rb_o = ops.StridedSlice(B, new[] {0,0,0,5*B.channels/8}, new[] {1,1,1,6*B.channels/8 }, new[] {1, 1, 1, 1}); var rb_f = ops.StridedSlice(B, new[] {0,0,0,6*B.channels/8}, new[] {1,1,1,7*B.channels/8 }, new[] {1, 1, 1, 1}); var rb_j = ops.StridedSlice(B, new[] {0,0,0,7*B.channels/8}, new[] {1,1,1,8*B.channels/8 }, new[] {1, 1, 1, 1}); var memSize = r_i.flatHeight; var baseLSTMName = ResolveLstmInputName(node); var initial_h = $"{baseLSTMName}_h"; var initial_c = $"{baseLSTMName}_c"; var baseLSTMOutputName = ResolveLstmOutputName(node); var output_h = $"{baseLSTMOutputName}_h"; var output_c = $"{baseLSTMOutputName}_c"; var i_mad_w = net.Dense($"{node.Name}_bc_i_mad_w", node.Input0, w_i, wb_i); var i_mad_r = net.Dense($"{node.Name}_bc_i_mad_r", initial_h, r_i, rb_i); var i_mad = net.Add($"{node.Name}_bc_i_mad", new [] {i_mad_w, i_mad_r}); var j_mad_w = net.Dense($"{node.Name}_bc_j_mad_w", node.Input0, w_j, wb_j); var j_mad_r = net.Dense($"{node.Name}_bc_j_mad_r", initial_h, r_j, rb_j); var j_mad = net.Add($"{node.Name}_bc_j_mad", new [] {j_mad_w, j_mad_r}); var f_mad_w = net.Dense($"{node.Name}_bc_f_mad_w", node.Input0, w_f, wb_f); var f_mad_r = net.Dense($"{node.Name}_bc_f_mad_r", initial_h, r_f, rb_f); var f_mad = net.Add($"{node.Name}_bc_f_mad", new [] {f_mad_w, f_mad_r}); var o_mad_w = net.Dense($"{node.Name}_bc_o_mad_w", node.Input0, w_o, wb_o); var o_mad_r = net.Dense($"{node.Name}_bc_o_mad_r", initial_h, r_o, rb_o); var o_mad = net.Add($"{node.Name}_bc_o_mad", new [] {o_mad_w, o_mad_r}); var i = net.Sigmoid($"{node.Name}_bc_i_sigmoid", i_mad); var j = net.Tanh($"{node.Name}_bc_j_tanh", j_mad); var f = net.Sigmoid($"{node.Name}_bc_f_sigmoid", f_mad); var o = net.Sigmoid($"{node.Name}_bc_o_sigmoid", o_mad); var state_c_mul = net.Mul($"{node.Name}_bc_state_c_mul", new[] {initial_c, f.name}); var 
i_j_mul = net.Mul($"{node.Name}_bc_i_j_mul", new[] {i, j}); var state_c = net.Add(output_c, new[] {state_c_mul, i_j_mul}); var state_c_tanh = net.Tanh($"{node.Name}_bc_state_c_tanh", state_c); var state_h = net.Mul(output_h, new[] {o, state_c_tanh}); net.Identity(node.Outputs[0], state_h); net.Identity(node.Outputs[1], state_h); net.Identity(node.Outputs[2], state_c); net.Memory(initial_c, state_c, new TensorShape(-1,1,1,memSize)); net.Memory(initial_h, state_h, new TensorShape(-1,1,1,memSize)); Output(node.Outputs[0], features:wb_o.channels, rank:2); Output(node.Outputs[1], features:wb_o.channels, rank:2); Output(node.Outputs[2], features:wb_o.channels, rank:2); }); // Activation ops Add("Relu", (net, node) => { net.Relu(node.Name, node.Input0); }); Add("Softmax", (net, node) => { const int defaultAxis = 1; int axis = node.AxisOptional(defaultAxis); // Leave in NCHW form and transpose instead if (axis < 0) axis = node.Input0Rank + axis; string input = node.Input0; string output = node.Name; int rank = node.Input0Rank; if(rank == 2) { axis = axis == 0 ? 0 : 3; // NC => N__C } else if (rank == 3) { axis = axis == 0 ? 0 : (axis == 1 ? 3 : axis); // NCW => N_WC } else { axis = axis == 0 ? 0 : (axis == 1 ? 3 : axis-1); // NCHW => NHWC } Layer layer = net.Softmax(output, input, axis); }); Add("Tanh", (net, node) => { net.Tanh(node.Name, node.Input0); }); Add("Sqrt", (net, node) => { net.Sqrt(node.Name, node.Input0); }); Add("Sigmoid", (net, node) => { net.Sigmoid(node.Name, node.Input0); }); Add("Elu", (net, node) => { net.Elu(node.Name, node.Input0, node.AlphaOptional(1f)); }); Add("LeakyRelu",(net, node) => { net.LeakyRelu(node.Name, node.Input0, node.AlphaOptional(0.01f)); }); Add("Selu", (net, node) => { net.Selu(node.Name, node.Input0, node.AlphaOptional(1.67326f), node.GammaOptional(1.0507f)); }); Add("Swish", (net, node) => { net.Swish(node.Name, node.Input0); }); Add("PRelu", (net, node) => { net.PRelu(node.Name, node.Input0, node.Input1); }); Add("LogSoftmax", (net, node) => { net.LogSoftmax(node.Name, node.Input0); node.UnsupportedAttribute("axis", 1); }); // TODO: Add("Hardmax", (net, node) => { net.Hardmax(node.Name, node.Input0); node.UnsupportedAttribute("axis", 1); }); Add("Softplus", (net, node) => { net.Softplus(node.Name, node.Input0); }); // TODO: Add("Softsign", (net, node) => { net.Softsign(node.Name, node.Input0); }); // TODO: Add("HardSigmoid", (net, node) => { net.HardSigmoid(node.Name, node.Input0, node.AlphaOptional(0.2f), node.BetaOptional(0.5f)); }); Add("Exp", (net, node) => { net.Exp(node.Name, node.Input0); }); Add("Log", (net, node) => { net.Log(node.Name, node.Input0); }); Add("Reciprocal", (net, node) => { net.Reciprocal(node.Name, node.Input0); }); Add("Abs", (net, node) => { net.Abs(node.Name, node.Input0); }); Add("Neg", (net, node) => { net.Neg(node.Name, node.Input0); }); Add("Ceil", (net, node) => { net.Ceil(node.Name, node.Input0); }); Add("Floor", (net, node) => { net.Floor(node.Name, node.Input0); }); Add("Round", (net, node) => { net.Round(node.Name, node.Input0); }); Add("Clip", (net, node) => { float minValue = float.MinValue; float maxValue = float.MaxValue; if (node.InputCount > 1) // Clip-11 { minValue = node.Input1ConstantOptional(minValue, onnxLayout:"C", name:"min")[0]; maxValue = node.Input2ConstantOptional(maxValue, onnxLayout:"C", name:"max")[0]; } else { minValue = node.MinOptional(minValue); maxValue = node.MaxOptional(maxValue); } net.Clip(node.Name, node.Input0, minValue, maxValue); }); Add("Acos", (net, node) => { net.Acos(node.Name, 
node.Input0); }); Add("Acosh", (net, node) => { net.Acosh(node.Name, node.Input0); }); Add("Asin", (net, node) => { net.Asin(node.Name, node.Input0); }); Add("Asinh", (net, node) => { net.Asinh(node.Name, node.Input0); }); Add("Atan", (net, node) => { net.Atan(node.Name, node.Input0); }); Add("Atanh", (net, node) => { net.Atanh(node.Name, node.Input0); }); Add("Cos", (net, node) => { net.Cos(node.Name, node.Input0); }); Add("Cosh", (net, node) => { net.Cosh(node.Name, node.Input0); }); Add("Sin", (net, node) => { net.Sin(node.Name, node.Input0); }); Add("Sinh", (net, node) => { net.Sinh(node.Name, node.Input0); }); Add("Tan", (net, node) => { net.Tan(node.Name, node.Input0); }); string[] GetCorrectedConstants(ONNXNodeWrapper node, ModelBuilder net) { string[] inputs = new string[node.Inputs.Length]; Array.Copy(node.Inputs, inputs, inputs.Length); if (node.IsInput1Const) { string onnxLayout; switch (node.Input1Rank) { case 1: onnxLayout = "C"; break; default: onnxLayout = "NCHW"; break; } string constName = $"Const_{node.Input1}"; if (!constantTensors.ContainsKey(constName)) { Tensor tensorData = node.Input1Constant(onnxLayout, node.Input1); if(node.Input0Rank == 3 && node.Input1Rank == 1) { // 1,1,1,C -> 1,1,C,1 tensorData = tensorData.Reshape(new int[] { 1, 1, tensorData.channels, 1 }); } Layer layer = net.Const(constName, tensorData); inputs[1] = layer.name; Const(constName, new ONNXTensor(tensorData, tensorData.shape.ToArray())); } } return inputs; } // Broadcast ops Add("Add", (net, node) => { net.Add(node.Name, GetCorrectedConstants(node, net)); }); Add("Sum", (net, node) => { net.Add(node.Name, GetCorrectedConstants(node, net)); }); // Sum is implemented via Add Add("Sub", (net, node) => { net.Sub(node.Name, GetCorrectedConstants(node, net)); }); Add("Mul", (net, node) => { net.Mul(node.Name, GetCorrectedConstants(node, net)); }); Add("Div", (net, node) => { net.Div(node.Name, GetCorrectedConstants(node, net)); }); Add("Pow", (net, node) => { net.Pow(node.Name, node.Inputs); }); Add("Min", (net, node) => { net.Min(node.Name, node.Inputs); }); Add("Max", (net, node) => { net.Max(node.Name, node.Inputs); }); Add("Mean", (net, node) => { net.Mean(node.Name, node.Inputs); }); // Logical ops Add("Greater", (net, node) => { net.Greater(node.Name, node.Input0, node.Input1); }); Add("Less", (net, node) => { net.Less(node.Name, node.Input0, node.Input1); }); Add("LessOrEqual", (net, node) => { net.LessEqual(node.Name, node.Input0, node.Input1); }); Add("Equal", (net, node) => { net.Equal(node.Name, node.Input0, node.Input1); }); Add("Or", (net, node) => { net.LogicalOr(node.Name, node.Input0, node.Input1); }); Add("And", (net, node) => { net.LogicalAnd(node.Name, node.Input0, node.Input1); }); Add("Not", (net, node) => { net.LogicalNot(node.Name, node.Input0); }); Add("Xor", (net, node) => { net.LogicalXor(node.Name, node.Input0, node.Input1); }); Add("Where", (net, node) => { net.Where(node.Name, node.Input0, node.Input1, node.Input2); }); // Padding ops Add("Pad", (net, node) => { // TODO refactor pad handling to truncate only in NCHWToNHWCPass var mode = node.ModeOptional("constant"); var pads = node.Pads; switch (mode) { case "constant": var value = node.GetOptionalFloat("value", 0.0f); if (pads.Length > 4) net.Border3D(node.Name, node.Input0, pads, value); else net.Border2D(node.Name, node.Input0, pads, value); break; case "reflect": net.Pad2DReflect(node.Name, node.Input0, pads); break; case "edge": net.Pad2DEdge(node.Name, node.Input0, pads); break; } }); // Pooling ops 
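// Note on the 1D pooling case handled below (illustrative): a MaxPool with kernel_shape=[4] and strides=[2] is promoted to a 2D pool with kernelShape={4,1} and strides={1,2} so it can be lowered to MaxPool2D.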
Add("AveragePool", (net, node) => { node.UnsupportedAttribute("ceil_mode", 0); node.UnsupportedAttribute("count_include_pad", 0); net.AvgPool2D(node.Name, node.Input0, node.KernelShape, node.Strides, node.Pads); }); Add("MaxPool", (net, node) => { node.UnsupportedAttribute("ceil_mode", 0); node.UnsupportedAttribute("dilations", new[] {1, 1}); node.UnsupportedAttribute("storage_order", 0); int[] strides = node.Strides; int[] pads = node.Pads; if (strides.Length == 1) strides = new[] { 1, strides[0] }; Assert.IsTrue(strides.Length == 2); int[] kernenShape = node.KernelShape; if (kernenShape.Length == 1) kernenShape = new[] { kernenShape[0], 1 }; net.MaxPool2D(node.Name, node.Input0, kernenShape, strides, pads); }); Add("GlobalAveragePool", (net, node) => { net.GlobalAvgPool2D(node.Name, node.Input0); }); Add("GlobalMaxPool", (net, node) => { net.GlobalMaxPool2D(node.Name, node.Input0); }); Add("Upsample", (net, node) => { // @TODO: the same for Resize node string mode = node.ModeOptional("nearest"); if (node.InputCount == 2 && !node.IsInput1Const) if (node.Input0Rank <= 4) net.Upsample2D(node.Name, node.Input0, node.Input1, IsModeBilinear(net, node, mode)); else net.Upsample3D(node.Name, node.Input0, node.Input1, IsModeBilinear(net, node, mode)); else Resample(net, node, node.Name, node.Input0, node.Scales, mode); }); Add("Resize", (net, node) => { if (node.InputCount > 2) // Resize-11 { node.UnsupportedAttribute("coordinate_transformation_mode", "half_pixel"); node.UnsupportedAttribute("cubic_coeff_a", -0.75f); node.UnsupportedAttribute("exclude_outside", 0); node.UnsupportedAttribute("extrapolation_value", 0f); node.UnsupportedAttribute("nearest_mode", "round_prefer_floor"); // Inputs (3 - 4) // X : T1 // roi : T2, It only takes effect when coordinate_transformation_mode is "tf_crop_and_resize" // scales : tensor(float) // sizes (optional) : tensor(int64) // TODO: cropping via roi input // TODO: support sizes } if (node.InputCount > 3) { var mode = node.ModeOptional("nearest"); var bilinear = IsModeBilinear(net, node, mode); net.Resample2D(node.Name, node.Input0, node.Sizes, bilinear); } else { Resample(net, node, node.Name, node.Input0, node.Scales, node.ModeOptional("nearest")); } }); Add("Transpose", (net, node) => { // From https://github.com/onnx/onnx/blob/master/docs/Operators.md#transpose // By default, reverse the dimensions, otherwise permute the axes according to the values given. 
if (node.IsInput0Const) { int inputTensorRank = constantTensors[node.Input0].rank; var defaultPermutations = new int[inputTensorRank]; for (int i = 0; i < inputTensorRank; ++i) defaultPermutations[i] = inputTensorRank - 1 - i; var permutations = node.GetOptionalIntArray("perm", defaultPermutations); var transposedTensor = constantTensors[node.Input0].Permute(permutations); Const(node, transposedTensor); } else { var defaultPermutations = new[] {5, 4, 3, 2, 1, 0}; var permutations = node.GetOptionalIntArray("perm", defaultPermutations); if (permutations.Length > 6) throw new OnnxLayerImportException($"Transpose supports up to 6 dimensions but got a permutation of rank {permutations.Length}."); if (Enumerable.SequenceEqual(permutations, new[] { 0, 3, 1, 2 }) || // NHWC -> NCHW Enumerable.SequenceEqual(permutations, new[] { 0, 2, 3, 1 })) // NCHW -> NHWC { // @TODO: reorder upstream nodes and global input dimensions accordingly from NHWC -> NCHW net.Identity(node.Name, node.Input0); if (permutations[1] == 3) // NHWC -> NCHW Output(node, layout: VariableTensor.Layout.ChannelsFirst); else if (permutations[1] == 2) // NCHW -> NHWC { Output(node, layout: VariableTensor.Layout.ChannelsLast); layerRequiringUpstreamPatch.Add(node.Name); } else Assert.IsTrue("Reached unexpected branch" == ""); } else if (Enumerable.SequenceEqual(permutations, new[] { 0, 2, 1 })) // NWC <-> NCW { // @TODO: reorder upstream nodes and global input dimensions accordingly from NHWC -> NCHW if (m_FixTf2OnnxExportIssues) { Warn(net, node, $"Use '--inputs-as-nchw' flag when exporting model from Tensorflow with tf2onnx"); net.Identity(node.Name, node.Input0); // flip layout if (node.Input0Layout == VariableTensor.Layout.ChannelsLast) Output(node, layout: VariableTensor.Layout.ChannelsFirst); else { Output(node, layout: VariableTensor.Layout.ChannelsLast); layerRequiringUpstreamPatch.Add(node.Name); } } else { int[] barracudaPermutation = { 0, 1, 3, 2 }; net.Transpose(node.Name, node.Input0, barracudaPermutation); } } else if (Enumerable.SequenceEqual(permutations, new[] { 1, 0, 2 })) // batch <-> seq_length { // LSTM layout is problematic as it's usually flanked by a few Transposes if exported from TF // @TODO investigate if better solution net.Identity(node.Name, node.Input0); } else { //Here we assume `Channels` are represented by only one dimension and that it is the 2nd one. //However in some cases (example: shufflenet, sub-pixel-cnn) a reshape-transpose-reshape pattern leads //to channels being represented by two dimensions; this is handled in //FixReshapeTransposePatternWhenChannelsAreSplitIntoMultipleDimensions() //Expand received permutation to 6D adding padding between Channels and other feature dimensions. int numDimensionsThatWerePaddedAtCenterOfTranspose = 0; var permutationsNCTDHW = ONNXLayout.ExpandONNXPermutationToNCTDHW(permutations, out numDimensionsThatWerePaddedAtCenterOfTranspose); //From channel first to channel last.
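// The 6D NTDHWC permutation is then lifted to Barracuda's full 8D SRNTDHWC layout below by pinning the leading S and R dimensions (indices 0 and 1) and shifting the remaining six axes by two; e.g. (illustrative) a 6D permutation {0,1,2,3,4,5} becomes {0,1,2,3,4,5,6,7}.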
var permutationsNTDHWC = ONNXLayout.ConvertPermutationToLayout(permutationsNCTDHW, "NCTDHW", "NTDHWC"); //6d to 8d int[] permuteSRNTDHWC = new int[TensorShape.MaxRank]; permuteSRNTDHWC[0] = 0; permuteSRNTDHWC[1] = 1; for (int i = 0; i < 6; ++i) permuteSRNTDHWC[i+2] = 2+permutationsNTDHWC[i]; var layer = net.Transpose(node.Name, node.Input0, permuteSRNTDHWC); layer.axis = numDimensionsThatWerePaddedAtCenterOfTranspose; } } }); Add("DepthToSpace", (net, node) => { net.DepthToSpace(node.Name, node.Input0, node.BlockSize, node.ModeOptional("DCR")); }); Add("SpaceToDepth", (net, node) => { net.SpaceToDepth(node.Name, node.Input0, node.BlockSize); }); // Tensor ops Add("Gemm", (net, node) => { node.UnsupportedAttribute("alpha", 1.0f); node.UnsupportedAttribute("beta", 1.0f); node.UnsupportedAttribute("transA", 0); var onnxLayout = node.TransBOptional() ? "KC" : "CK"; var weights = node.Input1Constant(onnxLayout, name:"B"); var biases = node.Input2ConstantOptional(Bias(weights.shape), 0.0f, onnxLayout:"C", name:"C"); // Change data layout from "channels first" to "channels last" weights = SwapSpatialDimensionsAndFeaturesInMatMulWeights(weights, weights.flatHeight, node.Input0Layout); net.Dense(node.Name, node.Input0, weights, biases); Output(node, features:weights.channels, rank:2); // Gemm forces flatten of the input to rank 2 }); Add("MatMul", (net, node) => { if (node.InputCount == 2 && !node.IsInput1Const || node.Input0Rank != 2 || node.Input1Rank != 2) { // if inputs are const, need to transpose them if(node.IsInput1Const) { var Y = constantTensors[node.Input1].ToBarracuda("NCTDHW"); net.Const(node.Input1, Y); } net.MatMul(node.Name, node.Input0, node.Input1); Output(node, features: node.Input0Features, rank: Math.Max(node.Input0Rank, node.Input1Rank)); } else { var weights = node.Input1Constant(onnxLayout: "CK", name: "B"); var biases = node.DefaultTensor(Bias(weights.shape), 0.0f); // Change data layout from "channels first" to "channels last" weights = SwapSpatialDimensionsAndFeaturesInMatMulWeights(weights, node.Input0Features, node.Input0Layout); net.Dense(node.Name, node.Input0, weights, biases); Output(node, features: weights.channels, rank: 2); // MatMul forces flatten of the input to rank 2 } }); Add("Conv", (net, node) => { int[] dilationsDHW = new[] { 1, 1, 1 }; // @TODO trap on wrong values int[] strides = node.Strides; int[] pads = node.Pads; node.IgnoredAttribute("kernel_shape", "Kernel shape is derived from K tensor weights instead"); var kernels = node.Input1Constant(onnxLayout: "KCHW", name: "W"); var kernelRank = node.Input1Rank; if (kernelRank == 3) // Conv1D { dilationsDHW = node.DilatationsOptional(new[] { 1 }); // @TODO trap on wrong values Assert.IsTrue(dilationsDHW.Length == 1); dilationsDHW = new[] { 1, 1, dilationsDHW[0] }; if (strides.Length == 1) strides = new[] { strides[0], 1 }; if (pads.Length == 2) pads = new[] { pads[0], 0, pads[1], 0 }; } else if (kernelRank == 4) // Conv2D { dilationsDHW = node.DilatationsOptional(new[] { 1, 1 }); Assert.IsTrue(dilationsDHW.Length == 2); dilationsDHW = new[] { 1, dilationsDHW[0], dilationsDHW[1] }; } else if (kernelRank == 5) // Conv3D { //TODO specific error message for DepthwiseConv3D (or support it). node.UnsupportedAttribute("group", 1); dilationsDHW = node.DilatationsOptional(new[] { 1, 1, 1 }); Assert.IsTrue(dilationsDHW.Length == 3); pads = node.Pads3D; strides = node.Strides3D; } else { Warn(net, node, $"Unsupported Conv kernel rank. Conv1D/2D/3D assume rank 3/4/5 respectively, but got {kernelRank}."); } Assert.IsTrue(dilationsDHW.Length == 3); if (dilationsDHW[0] != 1 || dilationsDHW[1] != 1 || dilationsDHW[2] != 1) kernels = DilateKernel(kernels, dilationsDHW); // @TODO inefficient method. Support dilation in kernel code properly var biases = node.Input2ConstantOptional(Bias(kernels.shape), 0.0f, onnxLayout: "C", name: "B"); if (node.GroupOptional() > 1) net.DepthwiseConv2D(node.Name, node.Input0, strides, pads, kernels, biases); else { if (kernelRank < 5) net.Conv2D(node.Name, node.Input0, strides, pads, kernels, biases); else net.Conv3D(node.Name, node.Input0, strides, pads, kernels, biases); } Output(node, features: kernels.channels); }); Add("ConvTranspose", (net, node) => { node.UnsupportedAttribute("dilations", new[] {1, 1}); node.UnsupportedAttribute("group", 1); node.UnsupportedAttribute("output_shape", new int[0]); node.IgnoredAttribute("kernel_shape", "Kernel shape is derived from K tensor weights instead"); var kernels = node.Input1Constant(onnxLayout:"CKHW", name:"W"); var biases = node.Input2ConstantOptional(Bias(kernels.shape), 0.0f, onnxLayout:"C", name:"B"); net.Conv2DTrans(node.Name, node.Input0, node.Strides, node.Pads, node.OutputPadding, kernels, biases); Output(node, features:kernels.channels); }); Add("BatchNormalization", (net, node) => { var variance = node.Input4Constant(onnxLayout:"C", name:"var"); var scale = node.Input1ConstantOptional(variance.shape, 1.0f, onnxLayout:"C", name:"scale"); var bias = node.Input2ConstantOptional(variance.shape, 0.0f, onnxLayout:"C", name:"B"); var mean = node.Input3ConstantOptional(variance.shape, 0.0f, onnxLayout:"C", name:"mean"); if (variance.length != scale.length || scale.length != bias.length || bias.length != mean.length) Warn(net, node, $"Number of elements in all parameters for BatchNorm must be the same. " + $"Parameter shapes are: {scale.shape}, {bias.shape}, {mean.shape}, {variance.shape}"); if (variance.channels != node.Input0Features && node.Input0Features > 0) Warn(net, node, $"Number of elements in BatchNorm must match features from the previous layer. Was expecting {node.Input0Features}, but got {variance.channels}."); var fusedData = FuseBatchNormWeights(scale, bias, mean, variance, node.EpsilonOptional(), node.Input0Features); net.ScaleBias(node.Name, node.Input0, fusedData.Item1, fusedData.Item2); }); Add("ImageScaler", (net, node) => { var attrBias = node.Bias; var attrScale = node.ScaleOptional(); int maxElements = attrBias.Length; Tensor scale = new Tensor(1, maxElements); Tensor bias = new Tensor(1, maxElements); for (int i = 0; i < maxElements; ++i) { scale[i] = attrScale; bias[i] = attrBias[i]; } net.ScaleBias(node.Name, node.Input0, scale, bias); }); Add("InstanceNormalization", (net, node) => { var scale = node.Input1Constant(onnxLayout:"C", name:"scale"); var bias = node.Input2ConstantOptional(scale.shape, 0.0f, onnxLayout:"C", name:"B"); if (scale.length != bias.length) Warn(net, node, $"Number of elements in all parameters for InstanceNorm must be the same. " + $"Parameter shapes are: {scale.shape}, {bias.shape}"); if (scale.channels != node.Input0Features && node.Input0Features > 0) { Warn(net, node, $"Number of elements in InstanceNorm must match features from the previous layer. 
Was expecting {node.Input0Features}, but got {scale.channels}."); var scaleArray = scale.ToReadOnlyArray(); Array.Resize(ref scaleArray, node.Input0Features); var biasArray = bias.ToReadOnlyArray(); Array.Resize(ref biasArray, node.Input0Features); scale = new Tensor(1, node.Input0Features, scaleArray); bias = new Tensor(1, node.Input0Features, biasArray); } net.Normalization(node.Name, node.Input0, scale, bias, node.EpsilonOptional()); }); Add("LRN", (net, node) => { float bias = node.GetOptionalFloat("bias", 1.0f); int size = node.GetRequiredInt("size"); net.LRN(node.Name, node.Input0, node.AlphaOptional(0.0001f), node.BetaOptional(0.75f), bias, size); }); // random ops Add("RandomNormal", (net, node) => { var shape = ONNXLayout.ConvertShapeToBarracuda(onnxShape:node.Shape, onnxLayout:"NCHW"); net.RandomNormal(node.Name, shape, node.MeanOptional(), node.ScaleOptional(), node.Seed); Output(node, rank:node.Shape.Length); }); Add("RandomNormalLike", (net, node) => { net.RandomNormal(node.Name, node.Input0, node.MeanOptional(), node.ScaleOptional(), node.Seed); }); Add("RandomUniform", (net, node) => { float high = node.GetOptionalFloat("high", 1.0f); float low = node.GetOptionalFloat("low", 0.0f); var shape = ONNXLayout.ConvertShapeToBarracuda(onnxShape:node.Shape, onnxLayout:"NCHW"); net.RandomUniform(node.Name, shape, low, high, node.Seed); Output(node, rank:node.Shape.Length); }); Add("RandomUniformLike", (net, node) => { float high = node.GetOptionalFloat("high", 1.0f); float low = node.GetOptionalFloat("low", 0.0f); net.RandomUniform(node.Name, node.Input0, low, high, node.Seed); }); Add("Multinomial", (net, node) => { int samples = node.GetOptionalInt("sample_size", 1); net.Multinomial(node.Name, node.Input0, samples, node.Seed); }); // Reduce ops Add("ReduceMax", (net, node) => { Reduce(net, node, Layer.Type.ReduceMax); }); Add("ReduceMean", (net, node) => { Reduce(net, node, Layer.Type.ReduceMean); }); Add("ReduceMin", (net, node) => { Reduce(net, node, Layer.Type.ReduceMin); }); Add("ReduceProd", (net, node) => { Reduce(net, node, Layer.Type.ReduceProd); }); Add("ReduceSum", (net, node) => { Reduce(net, node, Layer.Type.ReduceSum); }); Add("ArgMax", (net, node) => { node.UnsupportedAttribute("select_last_index"); Reduce(net, node, Layer.Type.ArgMax); }); Add("ArgMin", (net, node) => { node.UnsupportedAttribute("select_last_index"); Reduce(net, node, Layer.Type.ArgMin); }); // Ignore, noop during inference Add("Identity", (net, node) => { net.Identity(node.Name, node.Input0); }); Add("Cast", (net, node) => { net.Identity(node.Name, node.Input0); }); Add("Dropout", (net, node) => { net.Identity(node.Name, node.Input0); }); } private string ResolveLstmOutputName(ONNXNodeWrapper node) { var baseLSTMOutputName = $"recurrent_out_{node.Name}"; if (lstmOutputs.ContainsKey(node.Name)) { var actualName = lstmOutputs[node.Name]; if (actualName.EndsWith(":0")) actualName = actualName.Substring(0, actualName.Length - 2); if (actualName.EndsWith("_h") || actualName.EndsWith("_c")) baseLSTMOutputName = actualName.Substring(0, actualName.Length - 2); else baseLSTMOutputName = actualName; } return baseLSTMOutputName; } private string ResolveLstmInputName(ONNXNodeWrapper node) { var baseLSTMName = $"recurrent_in_{node.Name}"; if (lstmInputs.ContainsKey(node.Name)) { var actualName = lstmInputs[node.Name]; if (actualName.EndsWith(":0")) actualName = actualName.Substring(0, actualName.Length - 2); if (actualName.EndsWith("_h") || actualName.EndsWith("_c")) baseLSTMName = actualName.Substring(0, 
actualName.Length - 2); else baseLSTMName = actualName; } return baseLSTMName; } // Fuse training time BatchNorm tensors into Scale & Bias internal static Tuple FuseBatchNormWeights(Tensor gamma, Tensor beta, Tensor mean, Tensor variance, float epsilon, int maxElements = -1) { // https://github.com/Tencent/ncnn/blob/master/src/layer/batchnorm.cpp // float sqrt_var = sqrt(var_data[i]); // a_data[i] = bias_data[i] - slope_data[i] * mean_data[i] / sqrt_var; // b_data[i] = slope_data[i] / sqrt_var; // ... // ptr[i] = b * ptr[i] + a; Assert.IsTrue(gamma.channels == gamma.length); // assert 1d tensor Assert.IsTrue(gamma.shape == beta.shape); Assert.IsTrue(gamma.shape == mean.shape); Assert.IsTrue(gamma.shape == variance.shape); if (maxElements <= 0 || gamma.length < maxElements) // clip to the smallest valid number of channels maxElements = gamma.length; Tensor scale = new Tensor(1, maxElements); Tensor bias = new Tensor(1, maxElements); for (int i = 0; i < maxElements; ++i) { scale[i] = gamma[i] / Mathf.Sqrt(variance[i] + epsilon); bias[i] = beta[i] - gamma[i] * mean[i] / Mathf.Sqrt(variance[i] + epsilon); } return Tuple.Create(scale, bias); } // TODO move that in custom pass if need be // Transpose channels first to channels last data in MatMul/GEMM weight tensor internal static Tensor SwapSpatialDimensionsAndFeaturesInMatMulWeights(Tensor weights, int featureCount, VariableTensor.Layout layout) { if (featureCount == 0) // wild card feature: after Reduce, runtime correct weights. TODO: remove when full dims are known return weights; Assert.IsTrue(featureCount <= weights.flatHeight); var weightsAssumeChannelsFirstLayout = (layout != VariableTensor.Layout.ChannelsLast); if (featureCount != weights.flatHeight && weightsAssumeChannelsFirstLayout) { var shape = weights.shape; var implicitSpatialDimensionsInWeights = shape.flatHeight / featureCount; Assert.IsTrue(shape.flatHeight % featureCount == 0); // reshape: __C____K -> __C__HWK weights = weights.Reshape( new TensorShape(featureCount, implicitSpatialDimensionsInWeights, 1, shape.channels)); // permute: __C__HWK -> __H__WCK var permutations = TensorExtensions.Get8DPermutationsForNHWCPermutationsAndShape(weights.shape, new int[] {1, 0, 2, 3}); weights = ONNXTensor.Permute(weights, permutations); // reshape: __H__WCK -> __C____K weights = weights.Reshape(shape); } return weights; } internal static Model PatchFromIncorrectlyAssumedChannelsFirstToChannelsLastLayoutUpstream(Model model, List layerRequiringUpstreamPatch) { HashSet patchedInputIndices = new HashSet(); HashSet patchedLayerAxis = new HashSet(); var inputIndexByName = new Dictionary(); for (var i = 0; i < model.inputs.Count; ++i) inputIndexByName.Add(model.inputs[i].name, i); // NOTE: although original input had NHWC layout // (most probably exported from Tensorflow without '--inputs-as-nchw' flag) // earlier when parsing input and axis we made incorrect assumption that they were NCHW // now we need to revert that assumption! 
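// Concretely, the loop below swaps a cached axis value of 6 with TensorShape.C (and vice versa) on Concat/Gather/TopKValues layers and re-permutes the declared model input shapes; per the example given below, an NCHW input that was incorrectly recorded as -1,2,16,2 is restored to -1,2,2,16.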
foreach (var rootNodeForPatch in layerRequiringUpstreamPatch) { int inputIndex = -1; var upstream = ModelAnalyzer.FindUpstreamLayers(model, new[] {rootNodeForPatch}); foreach (var layer in upstream) { // patch axis if (!patchedLayerAxis.Contains(layer.name) && ( layer.type == Layer.Type.Concat || layer.type == Layer.Type.Gather || layer.type == Layer.Type.TopKValues))//TODO handle ReduceXX and StridedSlice { patchedLayerAxis.Add(layer.name); if (layer.axis == 6) layer.axis = TensorShape.C; else if (layer.axis == TensorShape.C) layer.axis = 6; } //patch inputs foreach (var inputName in layer.inputs) { if (inputIndexByName.TryGetValue(inputName, out inputIndex) && !patchedInputIndices.Contains(inputIndex)) { // example (NCHW): -1,2,2,16 -> (incorrect) -1,2,16,2 -> (fix) -1,2,2,16 // example (NCW): -1,2,16 -> (incorrect) -1,1,16,2 -> (fix) -1,1,2,16 patchedInputIndices.Add(inputIndex); var inputDesc = model.inputs[inputIndex]; inputDesc.shape = ONNXLayout.Permute(inputDesc.shape, new[] {-1, -1, 2, -1, -1, 7, 5, 6}); model.inputs[inputIndex] = inputDesc; } } // @TODO: figure out, if there is any case where we would have to propagate fixed layout assumption downstream? } } return model; } // TODO: use Burst for this internal static Tensor DilateKernel(Tensor kernel, int[] dilationsDHW) { //TODO: slow path in C# consider refactoring in Burst Assert.IsTrue(dilationsDHW.Length == 3); Assert.IsTrue(dilationsDHW[0] > 0); Assert.IsTrue(dilationsDHW[1] > 0); Assert.IsTrue(dilationsDHW[2] > 0); // https://arxiv.org/pdf/1603.07285.pdf Tensor dilatedKernel = new Tensor(new TensorShape(1, kernel.shape.kernelSpatialDepth + (kernel.shape.kernelSpatialDepth - 1) * (dilationsDHW[0] - 1), kernel.shape.kernelHeight + (kernel.shape.kernelHeight - 1) * (dilationsDHW[1] - 1), 1, 1, kernel.shape.kernelWidth + (kernel.shape.kernelWidth - 1) * (dilationsDHW[2] - 1), kernel.shape.kernelDepth, kernel.shape.kernelCount)); for (int c = 0; c < dilatedKernel.kernelDepth; ++c) for (int k = 0; k < dilatedKernel.kernelCount; ++k) { for (int d = 0; d < kernel.shape.kernelSpatialDepth; ++d) for (int y = 0; y < kernel.shape.kernelHeight; ++y) for (int x = 0; x < kernel.shape.kernelWidth; ++x) { int od = d * dilationsDHW[0]; int oy = y * dilationsDHW[1]; int ox = x * dilationsDHW[2]; int strideD = d == (kernel.shape.kernelSpatialDepth - 1) ? 1 : dilationsDHW[0]; int strideY = y == (kernel.shape.kernelHeight - 1) ? 1 : dilationsDHW[1]; int strideX = x == (kernel.shape.kernelWidth - 1) ? 
1 : dilationsDHW[2]; for (int dx = 0; dx < strideX; dx++) for (int dy = 0; dy < strideY; dy++) for (int dd = 0; dd < strideD; dd++) { dilatedKernel[0, od + dd, oy + dy, 0, 0, ox + dx, c, k] = 0.0f; } float v = kernel[0, d, y, 0, 0, x, c, k]; dilatedKernel[0, od, oy, 0, 0, ox, c, k] = v; } } return dilatedKernel; } internal static TensorShape Bias(TensorShape shape) { return new TensorShape(1, 1, 1, shape.channels); } internal static bool IsModeBilinear(ModelBuilder net, ONNXNodeWrapper node, string mode) { bool bilinear = false; if (mode == "linear" || mode == "bilinear") bilinear = true; else if (mode != "nearest") Warn(net, node, $"Mode `{mode}` is not supported for type {node.OperatorType}."); return bilinear; } internal static Layer UpsampleNCHW(ModelBuilder net, ONNXNodeWrapper node, int scaleInputIndex) { string mode = node.ModeOptional("nearest"); var bilinear = IsModeBilinear(net, node, mode); // NOTE: Intermediate NCHW -- op is implemented expecting NHWC by default, so this is non-runnable as-is if (scaleInputIndex != 0 && node.InputCount > scaleInputIndex && !node.IsInputConst(scaleInputIndex)) { // TODO: Input1 may be rank 1, which means that this would require a swizzle in the actual data if (node.Input0Rank <= 4) return net.Upsample2D(node.Name, node.Input0, node.GetRequiredInput(scaleInputIndex), bilinear); else return net.Upsample3D(node.Name, node.Input0, node.GetRequiredInput(scaleInputIndex), bilinear); } else return UpsampleFromConstNCHW(net, node, node.Name, node.Input0, node.ConvertScales(), mode); } internal static Layer UpsampleFromConstNCHW(ModelBuilder net, ONNXNodeWrapper node, string name, object input, float[] scales, string mode) { if (!scales.All(x => x > 0.0f)) Warn(net, node, $"Only positive scale values are supported."); if (scales.Length == 4 && scales[0] == 1.0f && scales[1] == 1.0f && scales[2] < 1.0f && scales[3] < 1.0f && IsModeBilinear(net, node, mode)) { var scales2D = scales.Skip(2); if (!scales2D.All(x => Mathf.Approximately(1f / x, Mathf.Round(1f / x)))) Warn(net, node, $"Only scale values whose inverse is an integer are currently supported. The inverse scale value will be rounded to the closest integer."); var noPad = new[] { 0, 0, 0, 0 }; var inverseScalesRoundedToInt = scales2D.Select(x => (int)Mathf.Round(1f / x)).ToArray(); return net.AvgPool2D(name, input, inverseScalesRoundedToInt, inverseScalesRoundedToInt, noPad); } else { if (!scales.All(x => Mathf.Approximately(x, Mathf.Round(x)))) Warn(net, node, $"Only integer scale values are currently supported. Scale values will be rounded to the closest integer."); var scalesRoundedToInt = scales.Select(x => (int)Mathf.Round(x)).ToArray(); if (scales.Length > 5) Warn(net, node, ">3D upsampling is not supported yet!"); if (scales.Length == 5) return net.Upsample3D(name, input, scalesRoundedToInt, IsModeBilinear(net, node, mode)); else return net.Upsample2D(name, input, scalesRoundedToInt, IsModeBilinear(net, node, mode)); } } internal static Layer Resample(ModelBuilder net, ONNXNodeWrapper node, string name, object input, float[] scales, string mode) { if (!scales.All(x => x > 0.0f)) Warn(net, node, $"Only positive scale values are supported."); if (scales.All(x => x < 1.0f)) { if (!scales.All(x => Mathf.Approximately(1f/x, Mathf.Round(1f/x)))) Warn(net, node, $"Only scale values whose inverse is an integer are currently supported. The inverse scale value will be rounded to the closest integer."); var noPad = new[] {0, 0, 0, 0}; var inverseScalesRoundedToInt = scales.Select(x => (int)Mathf.Round(1f/x)).ToArray(); // @TODO: nearest, actually this is bilinear downsampling if (scales.Length > 2) Warn(net, node, ">2D downsampling is not supported yet!"); return net.AvgPool2D(name, input, inverseScalesRoundedToInt, inverseScalesRoundedToInt, noPad); } else { if (!scales.All(x => Mathf.Approximately(x, Mathf.Round(x)))) Warn(net, node, $"Only integer scale values are currently supported. Scale values will be rounded to the closest integer."); var scalesRoundedToInt = scales.Select(x => (int)Mathf.Round(x)).ToArray(); if (scales.Length > 3) Warn(net, node, ">3D upsampling is not supported yet!"); if (scales.Length > 2) return net.Upsample3D(name, input, scalesRoundedToInt, IsModeBilinear(net, node, mode)); else return net.Upsample2D(name, input, scalesRoundedToInt, IsModeBilinear(net, node, mode)); } } private static int[] GetPermutationToMatchReduceWithDroppedDimensionsFromONNX(int[] droppedONNXAxis, int rank) { Assert.IsTrue(droppedONNXAxis.Length>0); //Barracuda always has all dimensions; in ONNX however dimensions can be dropped. //Here we handle the case of ReduceXXX ops when they do so. //An example: //ONNX -> NCHW //Reduce on C with keepDims=False. //ONNX -> NHW //However ONNX tensor semantics are deduced by position to be mapped to Barracuda in the following way: //ONNX 1D -> N -> Barracuda N,1,1,1 //ONNX 2D -> NC -> Barracuda N,1,1,C //ONNX 3D -> NCW -> Barracuda N,1,W,C //ONNX 4D -> NCHW -> Barracuda N,H,W,C //Thus the output tensor above (NHW) will be mapped to N,1,W,C in Barracuda //while Reduce in Barracuda would rather output N,H,W,1 if keepDims were true. //Here we find the transpose needed in Barracuda to match the ONNX behavior as seen by Barracuda. //i.e. the transpose from N,H,W,1 to N,1,W,C in this case aka 0,3,2,1.
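// Another worked case (illustrative): reducing W of an NCHW input with keepDims=false leaves NCH, which positionally reads as NCW and thus maps to Barracuda N,1,H,C, while the keepdims reduce outputs N,H,1,C; the matching transpose computed below is 0,2,1,3 (swap H and W).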
//ONNX input Layout from rank string onnxLayout; switch (rank) { case 1: onnxLayout = "N"; break; case 2: onnxLayout = "NC"; break; case 3: onnxLayout = "NCW"; break; case 4: onnxLayout = "NCHW"; break; default: //TODO support 8D throw new OnnxLayerImportException($"Reduce ops support up to 4D at the moment, however received an input of rank {rank}."); } //ONNX Layout once dimensions are dropped (example: NHW if C was dropped) string onnxLayoutDimensionsDropped = onnxLayout; foreach (var axis in droppedONNXAxis) { var onnxAxis = axis; if (onnxAxis < 0) onnxAxis = rank + axis; string semanticToRemove = onnxLayout[onnxAxis].ToString(); onnxLayoutDimensionsDropped = onnxLayoutDimensionsDropped.Replace(semanticToRemove, string.Empty); } Assert.IsTrue(onnxLayoutDimensionsDropped.Length>0); //Find all missing dimensions that will be unitary in Barracuda var missingDimensions = new List(); foreach (var dim in "NHWC") { if (!onnxLayoutDimensionsDropped.Contains(dim)) missingDimensions.Add(dim); } //Find semantic of onnx layout with dropped dimension in Barracuda var barracudaSemanticLayoutFromONNXReduce = new char[4]; switch (onnxLayoutDimensionsDropped.Length) { case 1: //ONNX 1D -> N -> Barracuda N,1,1,1 barracudaSemanticLayoutFromONNXReduce[0] = onnxLayoutDimensionsDropped[0]; barracudaSemanticLayoutFromONNXReduce[1] = missingDimensions[0]; barracudaSemanticLayoutFromONNXReduce[2] = missingDimensions[1]; barracudaSemanticLayoutFromONNXReduce[3] = missingDimensions[2]; break; case 2: //ONNX 2D -> NC -> Barracuda N,1,1,C barracudaSemanticLayoutFromONNXReduce[0] = onnxLayoutDimensionsDropped[0]; barracudaSemanticLayoutFromONNXReduce[1] = missingDimensions[0]; barracudaSemanticLayoutFromONNXReduce[2] = missingDimensions[1]; barracudaSemanticLayoutFromONNXReduce[3] = onnxLayoutDimensionsDropped[1]; break; case 3: //3D -> NCW -> Barracuda N,1,W,C barracudaSemanticLayoutFromONNXReduce[0] = onnxLayoutDimensionsDropped[0]; barracudaSemanticLayoutFromONNXReduce[1] = missingDimensions[0]; barracudaSemanticLayoutFromONNXReduce[2] = onnxLayoutDimensionsDropped[2]; barracudaSemanticLayoutFromONNXReduce[3] = onnxLayoutDimensionsDropped[1]; break; } //Find permutation from NHWC Barracuda layout when mapped from ONNX with dropped dimensions. var permutation = new int[4]; for(int idTarget = 0; idTarget < permutation.Length; ++idTarget) permutation[idTarget] = "NHWC".IndexOf(barracudaSemanticLayoutFromONNXReduce[idTarget]); return permutation; } internal void ReduceNCHW(ModelBuilder net, ONNXNodeWrapper node, Layer.Type reduceType) { var keepdims = node.GetOptionalInt("keepdims", 1); var features = node.Input0Features; var rank = node.Input0Rank; object input = node.Input0; var axes = node.AxesOptional(new[] { 0 }); if (node.InputCount >= 2) axes = node.Input1Constant(onnxLayout: "ONNX", name: "axes").AsInts(); // Sort high to low since we are reducing rank in each iteration // var axes = node.AxesOptional(new[] { 0 }).OrderByDescending(a => a).ToArray(); int reducedDim = 0; foreach (var onnxAxis in axes) { //TODO support 8D inputs //var axis = ONNXLayout.ConvertAxisToBarracuda(onnxAxis, onnxRank: rank, onnxLayout: "ONNX"); var axis = onnxAxis; if (reducedDim != 0) axis--; var nameR = $"{node.Name}__axis{onnxAxis}"; input = net.Reduce(reduceType, nameR, input, axis, true, keepdims); //if (axis == TensorShape.C) // This is actually W // features = 1; // this operation collapses all features to 1 Output(nameR, features: features, rank: rank); // Without keepdims, we will be reducing rank every axis iteration if (keepdims == 0) { rank--; reducedDim++; } } net.Identity(node.Name, input); } internal void Reduce(ModelBuilder net, ONNXNodeWrapper node, Layer.Type reduceType) { var keepdims = node.GetOptionalInt("keepdims", 1); var features = node.Input0Features; var rank = node.Input0Rank; object input = node.Input0; var axes = node.HasAttribute("axes") ?
node.AxesOptional(new[] { 0 }) : new[] {node.AxisOptional(0)}; foreach (var onnxAxis in axes) { //TODO support 8D inputs var axis = ONNXLayout.ConvertAxisToBarracuda(onnxAxis, onnxRank: rank, onnxLayout: "NCHW"); if (node.Input0Layout == VariableTensor.Layout.ChannelsLast && node.Input0Rank == 4) axis = TensorExtensions.Convert4DTo8DAxis(onnxAxis); var nameR = $"{node.Name}__axis{axis}"; input = net.Reduce(reduceType, nameR, input, axis, true, keepdims); if (axis == TensorShape.C) features = 1; // this operation collapse all features to 1 Output(nameR, features: features, rank: rank); } if (keepdims != 1 && rank > 1 && (node.Input0Layout != VariableTensor.Layout.ChannelsLast)) // keepdims removes dimensions in the context of onnx thus we need to repack/transpose to match behavior. { var nameT = $"{node.Name}__transpose"; var transpose = GetPermutationToMatchReduceWithDroppedDimensionsFromONNX(axes, rank); input = net.Transpose(nameT, input, transpose); rank = rank - axes.Length; //TODO: features count is wrong and should potentially be deduced from input + transpose Output(nameT, features: 0, rank: rank); } net.Identity(node.Name, input); //TODO: features count is wrong and should potentially be deduced from input Output(node.Name, features: 0, rank: rank); } private ONNXModelTensors m_ModelTensors = new ONNXModelTensors(); private readonly Dictionary> m_NodeImporters = new Dictionary>(); // NOTE: It's questionable whether we should be doing this since the ONNX specification requires the graph to be // topologically sorted, but at least one network encountered that was exported from keras2onnx v1.7.0 produced // an incorrectly sorted graph. related example: https://github.com/onnx/keras-onnx/issues/184 void SortTopologically(ModelProto onnxModel, List sortedGraph) { var nodesToSort = new Queue(); GraphProto onnxGraph = onnxModel.Graph; foreach (NodeProto node in onnxGraph.Node) { nodesToSort.Enqueue(node); } var requeueNodes = new Queue(); while (nodesToSort.Count > 0) { NodeProto node = nodesToSort.Dequeue(); var allInputsExist = true; foreach (string input in node.Input) { if (string.IsNullOrEmpty(input)) continue; if (!sortedGraph.Exists(n => n.Output.Any(o => o == input)) && !onnxGraph.Input.Any(i => i.Name == input) && !onnxGraph.Initializer.Any(i => i.Name == input)) { allInputsExist = false; break; } } if (!allInputsExist) { if (nodesToSort.Count != 0) { // Mark for re-processing again when (potentially) all inputs have been processed // We use a separate list, so we don't continually spin on nodes that are missing inputs if (!requeueNodes.Contains(node)) requeueNodes.Enqueue(node); continue; } // Something must've gone wrong throw new OnnxImportException($"Missing inputs to node {node.Name}, but there are no nodes to process."); } if (!sortedGraph.Contains(node)) sortedGraph.Add(node); // Now that we have at least processed a single new node, let's requeue while (requeueNodes.Count > 0) nodesToSort.Enqueue(requeueNodes.Dequeue()); } } private Model ConvertOnnxModel(ModelProto onnxModel) { var model = new Model(); bool standardImport = m_ImportMode.HasFlag(ImportMode.Standard); model.layout = standardImport ? "iNCHW" : "NHWC"; var modelBuilder = new ModelBuilder(model); // Builds list of nodes that should not be included into the final Barracuda Model, mostly for LSTMs var nodesToSkip = standardImport ? 
new HashSet() : BuildNodeSkipList(onnxModel.Graph); // Import any (optional) metadata properties if (!m_ImportMode.HasFlag(ImportMode.SkipMetadataImport)) { RepeatedField metadataProps = onnxModel.MetadataProps; Dictionary metadata = model.Metadata; for (int p = 0; p < metadataProps.Count; p++) { StringStringEntryProto prop = metadataProps[p]; metadata.Add(prop.Key, prop.Value); } } // Convert graph inputs & outputs var initializersByName = onnxModel.Graph.Initializer.ToDictionary(i => i.Name, i => true); foreach (ValueInfoProto i in onnxModel.Graph.Input) { // skip input tensors that have initializer data, they are constant tensors not global inputs // also skip nodes that should be trimmed if (initializersByName.ContainsKey(i.Name) || (!standardImport && nodesToSkip.Contains(i.Name))) continue; if (!standardImport && m_OverrideGlobalInputs.ContainsKey(i.Name)) { Const(i.Name, m_OverrideGlobalInputs[i.Name]); continue; } int[] onnxShape = i.Type.TensorType.Shape.AsInts(); modelBuilder.Input(i.Name, ONNXLayout.ConvertSymbolicShapeToBarracuda(onnxShape, onnxLayout:standardImport ? "ONNX" : "NCHW"), onnxShape.Length); var shapeValues = i.Type.TensorType.Shape.Dim.Select(d => d.DimValue).ToArray(); Output(i.Name, onnxShape: shapeValues, onnxLayout:"NCHW"); } foreach (ValueInfoProto o in onnxModel.Graph.Output) modelBuilder.Output(o.Name); // Read constants from initializer list foreach (TensorProto initializer in onnxModel.Graph.Initializer) Const(initializer.Name, new ONNXTensor(initializer)); // Nodes are supposed to be sorted, but this isn't always the case var sortedGraph = new List(); if (standardImport) { SortTopologically(onnxModel, sortedGraph); } else { // for the legacy import pipeline, let's keep it as it was sortedGraph.AddRange(onnxModel.Graph.Node); } // Convert graph nodes foreach (NodeProto onnxNode in sortedGraph) { if (!standardImport && nodesToSkip.Contains(ONNXNodeWrapper.GetName(onnxNode))) continue; var node = new ONNXNodeWrapper(onnxNode, m_ModelTensors, model.Warnings); var nodeId = node.Name; var opType = node.OperatorType; Output(node); bool injectDummy = false; if (m_NodeImporters.ContainsKey(opType)) { try { if (!standardImport && node.AreAllInputsConst && !m_ShouldNotBeBaked.Contains(opType)) { Profiler.BeginSample($"Bake {opType} {node.Name}"); var bakedTensor = BakeNodeIntoConstant(opType, node); Const(node.Name, bakedTensor); var printTensor = bakedTensor.ToBarracuda("NCHW"); D.Log($"Baked node {nodeId} into constant of shape {printTensor.shape} and values: {printTensor.DataToString()}"); Profiler.EndSample(); } else { Profiler.BeginSample($"Import {opType} {node.Name}"); m_NodeImporters[opType](modelBuilder, node); Profiler.EndSample(); } } catch (Exception e) { // We support the layer but something went wrong while importing it // We log the problem and insert an identity layer string message = $"Unexpected error while parsing layer {nodeId} of type {opType}."; Err(model, nodeId, message, extendedMessage:"Will replace it by an Identity layer.", debugMessage:$"{e.Message}\n\nJson: {onnxNode}\n{e.StackTrace}\n"); injectDummy = true; } } else { // We don't support this type of layer // We log the problem and insert an identity layer string message = $"Unknown type {opType} encountered while parsing layer {nodeId}."; Err(model, nodeId, message, extendedMessage:"Will replace it by an Identity layer."); injectDummy = true; } if (injectDummy) { var originalLayerHadInputs = (node.InputCount > 0); if (originalLayerHadInputs) { var originalLayerHadConstantInput = 
node.IsInput0Const; if (originalLayerHadConstantInput) Const(nodeId, constantTensors[node.Input0]); // copy constant else modelBuilder.Identity(nodeId, node.Input0); } else // if erroneous layer had no inputs, inject dummy constant which does not require any inputs modelBuilder.Const(nodeId, new Tensor()); } m_ModelTensors.CompleteUninitializedFields(node); } // Convert constant tensors var requiredConstants = new HashSet(ModelAnalyzer.FindBrokenLinks(model)); // ML-Agents metadata is stored in otherwise unreferenced constants var unreferencedConstantsContainMLAgentsMetadata = UnreferencedNodes(onnxModel.Graph); requiredConstants.UnionWith(unreferencedConstantsContainMLAgentsMetadata); // keep ML-Agents metadata int insertionIndex = 0; // insert constants at the beginning of the model foreach(var entry in constantTensors) { if (requiredConstants.Contains(entry.Key)) // skip if constant is unused { modelBuilder.Const(entry.Key, entry.Value.ToBarracuda(standardImport ? "ONNX" : GetONNXLayoutForConstant(model, entry.Key)), insertionIndex++, rank: entry.Value.rank); } } if (m_ImportMode == ImportMode.Legacy) { foreach (Layer l in model.layers) { if (requiredConstants.Contains(l.name)) l.flags |= Layer.Flags.Preserve; } model = ModelOptimizer.Optimize(model, allowFusing: m_OptimizeModel, keepLayers:requiredConstants); // keep ML-Agents metadata model = FixReshapeTransposePatternWhenChannelsAreSplitIntoMultipleDimensions(model); if (!m_FixTf2OnnxExportIssues) model = PatchFromIncorrectlyAssumedChannelsFirstToChannelsLastLayoutUpstream(model, layerRequiringUpstreamPatch); } // strip :0 at the end of string name for TF import if (m_FixTf2OnnxExportIssues) model = TrimTensorflowNames(model); if (m_ImportMode == ImportMode.Legacy) Validate(model); // Parse meta data var irVersion = onnxModel.IrVersion; // legacy if (onnxModel.OpsetImport?.Count > 0) irVersion = onnxModel.OpsetImport[0].Version; model.ProducerName = $"{onnxModel.ProducerName} v{onnxModel.ProducerVersion}"; model.IrSource = "ONNX"; model.IrVersion = $"{irVersion}"; return model; } private bool IsLayerInputChannelDependant(Layer.Type opType, int index) { return index == 0 || //First input is usually channel order dependent opType == Layer.Type.Add || //however some operators have all inputs channel order dependent opType == Layer.Type.Sub || opType == Layer.Type.Mul || opType == Layer.Type.Div || opType == Layer.Type.Pow || opType == Layer.Type.Min || opType == Layer.Type.Max || opType == Layer.Type.Mean || opType == Layer.Type.Greater || opType == Layer.Type.GreaterEqual || opType == Layer.Type.Less || opType == Layer.Type.LessEqual || opType == Layer.Type.Equal || opType == Layer.Type.LogicalOr || opType == Layer.Type.LogicalAnd || opType == Layer.Type.LogicalXor || opType == Layer.Type.Where || opType == Layer.Type.Concat; } private string GetONNXLayoutForConstant(Model model, string nodeName) { int constLayoutRequestCount = 0; int nctdhwRequestCount = 0; //find all layers using that constant as an input.
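// Example (illustrative): a constant feeding any input of an Add is channel-order dependent and votes for NCTDHW, while one feeding a non-first input of a layout-agnostic op such as Reshape votes for CONST; mixed votes on the same constant are reported as an error below.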
foreach (var l in model.layers) { for (int i = 0; i < l.inputs.Length; ++i) { if (l.inputs[i] == nodeName) { if (IsLayerInputChannelDependant(l.type, i)) ++nctdhwRequestCount; else ++constLayoutRequestCount; } } } if (nctdhwRequestCount != 0 && constLayoutRequestCount != 0) { Err(model, nodeName, $"{nodeName} is both used as a channel order dependent constant and a plain constant, which is not supported at the moment."); } return nctdhwRequestCount > constLayoutRequestCount ? "NCTDHW" : "CONST"; } private ONNXTensor BakeNodeIntoConstant(string opType, ONNXNodeWrapper node) { var model = new Model(); var net = new ModelBuilder(model); // add all inputs as constants Assert.IsTrue(node.AreAllInputsConst); for (var i = 0; i < node.InputCount; ++i) { var assumeOnnxLayout = (m_AllInputsChannelFirst.Contains(opType) || i == 0) ? "NCTDHW" : "CONST"; var input = node.Inputs[i]; net.Const(input, constantTensors[input].ToBarracuda(assumeOnnxLayout)); } // add node that we are going to bake into the constant m_NodeImporters[opType](net, node); // bake var useCPUforBaking = WorkerFactory.Device.CPU; using (var worker = WorkerFactory.CreateWorker(model, useCPUforBaking)) { var bakedConstant = worker.Execute().PeekOutput(); // convert from Barracuda back into ONNX layout Tensor onnxData = bakedConstant; onnxData = ONNXTensor.Permute(bakedConstant, new int[] {0,1,2,7,3,4,5,6}); // S,R,N,T,D,H,W,C (channel last) -> S,R,N,C,T,D,H,W (channel first) var onnxShape = onnxData.shape.ToArray(); return new ONNXTensor(onnxData, onnxShape).SqueezeAll(); } } static private void Validate(Model model) { // Model should not contain any broken links in the end var unconnectedInputs = ModelAnalyzer.FindBrokenLinks(model); Assert.IsTrue(unconnectedInputs.Length == 0); if (unconnectedInputs.Length > 0) { var message = $"Broken links: {string.Join(", ", unconnectedInputs)}"; Warn(model, "", message); } } private HashSet UnreferencedNodes(GraphProto graph) { var allNodes = new HashSet(); var allInputs = new HashSet(); foreach (var node in graph.Node) { allNodes.Add(ONNXNodeWrapper.GetName(node)); foreach (var input in node.Input) allInputs.Add(input); } // Remove all global output nodes foreach (ValueInfoProto o in graph.Output) allNodes.Remove(o.Name); // Remove all nodes that are referenced by Inputs to get the set of unreferenced ones var unreferencedNodes = allNodes; unreferencedNodes.ExceptWith(allInputs); return unreferencedNodes; } private void BacktraceNodeInputs(Dictionary nameToNode, NodeProto[] startingNodes, Action regularNodeCallback, Action inputNodeCallback) { HashSet nodesToCheck = new HashSet(startingNodes); while (nodesToCheck.Count > 0) { var el = nodesToCheck.First(); regularNodeCallback(el); nodesToCheck.Remove(el); if (el.Input.Count > 0) { if (nameToNode.ContainsKey(el.Input[0])) nodesToCheck.Add(nameToNode[el.Input[0]]); // regular node else inputNodeCallback(el); } } } // TODO: Remove along with legacy importer in Barracuda 2.0 private HashSet BuildNodeSkipList(GraphProto graph) { var res = new HashSet(); var nameToNode = graph.Node.ToDictionary(i => ONNXNodeWrapper.GetName(i), i => i); var outputToLSTMNode = new Dictionary(); // Skip all LSTM _h & _c inputs as they will be accessible directly via Model.memories foreach (NodeProto onnxNode in graph.Node) { if (onnxNode.OpType == "LSTM") { var lstmNodeName = ONNXNodeWrapper.GetName(onnxNode); var initial_h = onnxNode.Input[5]; var initial_c = onnxNode.Input[6]; List startingNodes = new List(); if (nameToNode.ContainsKey(initial_h))
startingNodes.Add(nameToNode[initial_h]); if (nameToNode.ContainsKey(initial_c)) startingNodes.Add(nameToNode[initial_c]); BacktraceNodeInputs( nameToNode, startingNodes.ToArray(), el => { res.Add(ONNXNodeWrapper.GetName(el)); }, el => { lstmInputs[lstmNodeName] = el.Input[0]; res.Add(el.Input[0]);} ); outputToLSTMNode[onnxNode.Output[1]] = lstmNodeName; // _h outputToLSTMNode[onnxNode.Output[2]] = lstmNodeName; // _c } } // Also trace from outputs to LSTM nodes to figure out names of the output _h and _c nodes foreach (var output in graph.Output) { if (!nameToNode.ContainsKey(output.Name)) continue; // As LSTM has 3 outputs and backtracing is done only via output[0] // then output[1] and output[2] will be treated as leaf input nodes BacktraceNodeInputs( nameToNode, new[] {nameToNode[output.Name]}, el => { }, el => { var inputName = el.Input[0]; if (outputToLSTMNode.ContainsKey(inputName)) { lstmOutputs[outputToLSTMNode[inputName]] = output.Name; } } ); } return res; } static private string ApplyPermutationToLayout(string layout, int[] permutation) { Assert.IsTrue(layout.Length == permutation.Length); char[] permutedLayout = new char[layout.Length]; for (int i = 0; i < layout.Length; ++i) { permutedLayout[i] = layout[permutation[i]]; } return new string(permutedLayout); } static private int[] FindPermutationFromLayouts(string layout, string permutedLayout) { Assert.IsTrue(layout.Length == permutedLayout.Length); int[] permutation = new int[layout.Length]; for (int i = 0; i < layout.Length; ++i) { permutation[i] = layout.IndexOf(permutedLayout[i]); } return permutation; } static private Model FixReshapeTransposePatternWhenChannelsAreSplitIntoMultipleDimensions(Model model) { var transposes = model.layers.Where(l => l.type == Layer.Type.Transpose).ToList(); foreach (var transposeLayer in transposes) { var previousLayer = model.layers.Find(l => l.name == transposeLayer.inputs[0]); if (previousLayer == null) continue; if (previousLayer.type != Layer.Type.Reshape) continue; var numChannelDimensionBeforeTranspose = previousLayer.axis; if (numChannelDimensionBeforeTranspose <= 1) continue; int centerPaddingThatWasAddedInPermutation = transposeLayer.axis; Assert.IsTrue(centerPaddingThatWasAddedInPermutation <= 1); Assert.IsTrue(centerPaddingThatWasAddedInPermutation >= 0); //NOTE: See also ConvertReshapeToBarracuda() for more detail on the problem. //In some networks like shufflenet, superresolution_cnn and yolov3 a reshape is used //before a transpose to split the channels, resulting in a tensor with //multiple dimensions used for channels; this is a problem when importing to //barracuda as the semantics of the dimensions are changed, and this changes the //way channel first to channel last conversion should happen. The code below //provides limited support for that. Assert.IsTrue(numChannelDimensionBeforeTranspose == 2 || numChannelDimensionBeforeTranspose == 3); var permutationSRNTDHWC = transposeLayer.pool; if (permutationSRNTDHWC.Length != 8) { Warn(model, transposeLayer.name, $"Expecting a permutation of rank 8 after Reshape '{previousLayer.name}' itself outputting more than one channel dimension. Permutation can't be patched to account for the extra channel dimensions."); continue; } //Find layouts before transpose in both channel order string layoutBeforeTranspose_ChannelFirst = (numChannelDimensionBeforeTranspose == 3) ? "SRN123HW" : "SRN1T2HW"; string layoutBeforeTranspose_ChannelLast = (numChannelDimensionBeforeTranspose == 3) ?
"SRNHW123" : "SRNTHW12"; //Find layout after transpose in channel first int[] permutation_ChannelFirst = ONNXLayout.ConvertPermutationToLayout(permutationSRNTDHWC, "SRNTDHWC","SRNCTDHW"); string layoutAfterTranspose_ChannelFirst = ApplyPermutationToLayout(layoutBeforeTranspose_ChannelFirst, permutation_ChannelFirst); //Find layout after transpose in channel last //TODO/HEURISTIC: We differentiate the various case by knowing if channels and features are interleaved during permutations. //This is a work around to create the right permutation for the shufflenet/super-resolution and yolov3, it does not generalise well however. //In next version of the importer we might need to introduce transposes in channel last mode to generalise fully. int[] channelFirstToLastPermutation = null; if (numChannelDimensionBeforeTranspose == 3) { //super resolution -> final reshape will pick only 1 dimension as channel -> regular channel first to last transposition. channelFirstToLastPermutation = FindPermutationFromLayouts("SRN1TDHW", "SRNTDHW1"); } else if (IsPermutationMixingChannelsAndOtherFeatures(layoutBeforeTranspose_ChannelFirst, permutation_ChannelFirst)) { //yolov3 -> final reshape does not pick any dimension as channel -> no transposition. channelFirstToLastPermutation = FindPermutationFromLayouts("SRNTUDHW", "SRNTUDHW"); } else { //shufflenet -> final reshape take 2 dimension and merge them so both need to be affected by channel first to last transposition channelFirstToLastPermutation = FindPermutationFromLayouts("SRN1T2HW", "SRNTHW12"); } string layoutAfterTranspose_ChannelLast = ApplyPermutationToLayout(layoutAfterTranspose_ChannelFirst, channelFirstToLastPermutation); //Finally compute and return permutation in channel last int[] permutation_ChannelLast = FindPermutationFromLayouts(layoutBeforeTranspose_ChannelLast, layoutAfterTranspose_ChannelLast); transposeLayer.pool = permutation_ChannelLast; } return model; } static private bool IsPermutationMixingChannelsAndOtherFeatures(string layout, int[] permutation) { //Convention here is that channels are described as numbers, while other features by letters. 
        static private bool IsPermutationMixingChannelsAndOtherFeatures(string layout, int[] permutation)
        {
            //Convention here is that channels are described by numbers, while other features by letters.
            Assert.IsTrue(layout.Length == permutation.Length);
            for (int i = 0; i < permutation.Length; ++i)
            {
                bool sourceIsAChannel = Char.IsNumber(layout[i]);
                bool targetIsAChannel = Char.IsNumber(layout[permutation[i]]);
                if (sourceIsAChannel != targetIsAChannel)
                    return true;
            }
            return false;
        }

        static private Model TrimTensorflowNames(Model model)
        {
            model.inputs = model.inputs.Select(i => {
                i.name = TrimTensorflowName(i.name);
                return i;
            }).ToList();

            model.outputs = model.outputs.Select(o => TrimTensorflowName(o)).ToList();

            model.memories = model.memories.Select(m => {
                m.input = TrimTensorflowName(m.input);
                m.output = TrimTensorflowName(m.output);
                return m;
            }).ToList();

            model.layers = model.layers.Select(l => {
                l.name = TrimTensorflowName(l.name);
                for (int i = 0; i < l.datasets.Length; i++)
                    l.datasets[i].name = TrimTensorflowName(l.datasets[i].name);
                for (int i = 0; i < l.inputs.Length; i++)
                    l.inputs[i] = TrimTensorflowName(l.inputs[i]);
                if (l.outputs != null)
                {
                    for (int i = 0; i < l.outputs.Length; i++)
                        l.outputs[i] = TrimTensorflowName(l.outputs[i]);
                }
                return l;
            }).ToList();

            return model;
        }

        static private string TrimTensorflowName(string name)
        {
            // tf2onnx appends ":0" to all node names, e.g. "dense/BiasAdd:0" becomes "dense/BiasAdd"
            if (name.EndsWith(":0"))
                return name.Remove(name.Length - 2);
            return name;
        }

        // Helpers to keep track of model tensors
        private void Const(ONNXNodeWrapper node, ONNXTensor onnxTensor)
        {
            m_ModelTensors.AddConstant(node.Name, onnxTensor);
        }

        private void Const(string name, ONNXTensor onnxTensor)
        {
            m_ModelTensors.AddConstant(name, onnxTensor);
        }

        private void Output(ONNXNodeWrapper node, int features = -1, int rank = -1,
            VariableTensor.Layout layout = VariableTensor.Layout.Unknown)
        {
            Output(node.Name, features, rank, layout);
        }

        private void Output(string name, int features = -1, int rank = -1,
            VariableTensor.Layout layout = VariableTensor.Layout.Unknown)
        {
            m_ModelTensors.AddVariable(name, features, rank, layout);
        }

        private void Output(string name, ONNXTensor onnxTensor)
        {
            m_ModelTensors.AddVariable(name, onnxTensor);
        }

        private void Output(string name, long[] onnxShape, string onnxLayout)
        {
            m_ModelTensors.AddVariable(name, onnxShape, onnxLayout);
        }

        private void Output(ONNXNodeWrapper node, int features, string productOfShape)
        {
            m_ModelTensors.AddVariable(node.Name, features, productOfShape);
        }

        // Logging helpers
        private static void Warn(ModelBuilder builder, ONNXNodeWrapper node, string message)
        {
            Warn(builder.model, node.Name, message);
        }

        private static void Warn(Model model, string layerName, string message)
        {
            model.Warnings.Add(new Model.ImporterWarning(layerName, message));
            Debug.LogWarning(message);
        }

        private void Err(Model model, string layerName, string message, string extendedMessage = "", string debugMessage = "")
        {
            if (m_TreatErrorsAsWarnings)
            {
                model.Warnings.Add(new Model.ImporterWarning(layerName, $"{message} {extendedMessage}"));
                Debug.LogWarning($"{message} {extendedMessage}\n{debugMessage}");
            }
            else
                throw new OnnxImportException($"{message}\n{debugMessage}");
        }
    }

    /// <summary>
    /// ONNX import exception
    /// </summary>
    public class OnnxImportException : Exception
    {
        /// <summary>
        /// Create `OnnxImportException`
        /// </summary>
        /// <param name="message">message</param>
        public OnnxImportException(string message) : base(message) { }
    }

    /// <summary>
    /// ONNX layer import exception
    /// </summary>
    public class OnnxLayerImportException : Exception
    {
        /// <summary>
        /// Create `OnnxLayerImportException`
        /// </summary>
        /// <param name="message">message</param>
        public OnnxLayerImportException(string message) : base(message) { }
    }
}
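// Usage sketch (illustrative; assumes the constructor parameter names documented above):
//
//   var converter = new ONNXModelConverter(optimizeModel: true);
//   Model model = converter.Convert("path/to/model.onnx"); // or Convert(byte[] buffer) for in-memory data
//   Debug.Log($"Imported {model.layers.Count} layer(s) with {model.Warnings.Count} warning(s).");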