using Onnx; using UnityEngine; using UnityEditor; using System; using System.Linq; using System.Collections.Generic; using System.Runtime.CompilerServices; using UnityEngine.Assertions; [assembly: InternalsVisibleToAttribute("Barracuda.EditorTests")] namespace Unity.Barracuda.ONNX { internal class ONNXNodeWrapper { // Layer identification (name and op) public static string GetName(NodeProto node) { // prefer node.output over the node.name return node.Output.Count > 0 ? node.Output[0] : node.Name; } public string Name { get { return GetName(m_ONNXNode); } } public string OperatorType { get { return m_ONNXNode.OpType; } } public bool IsConstant { get { return OperatorType == "Constant"; } } public bool IsTerminatorForProductOfShape { get { return OperatorType == "Reshape"; } } // Outputs public string[] Outputs { get { return m_ONNXNode.Output.ToArray(); }} // Inputs public int InputCount { get { return m_ONNXNode.Input.Count; } } public string[] Inputs { get { return m_ONNXNode.Input.ToArray(); } } public string Input0 { get { return GetRequiredInput(0); } } public string Input1 { get { return GetRequiredInput(1); } } public string Input2 { get { return GetRequiredInput(2); } } public string Input3 { get { return GetRequiredInput(3); } } public string Input4 { get { return GetRequiredInput(4); } } public string Input5 { get { return GetRequiredInput(5); } } public string Input6 { get { return GetRequiredInput(6); } } public string Input0Optional { get { return InputCount > 0 ? GetRequiredInput(0) : ""; } } public string Input1Optional { get { return InputCount > 1 ? GetRequiredInput(1) : ""; } } public string Input2Optional { get { return InputCount > 2 ? GetRequiredInput(2) : ""; } } public string Input3Optional { get { return InputCount > 3 ? GetRequiredInput(3) : ""; } } public string Input4Optional { get { return InputCount > 4 ? GetRequiredInput(4) : ""; } } public string Input5Optional { get { return InputCount > 5 ? GetRequiredInput(5) : ""; } } public string Input6Optional { get { return InputCount > 6 ? GetRequiredInput(6) : ""; } } public bool IsInput0Const { get { return IsInputConst(0); } } public bool IsInput1Const { get { return IsInputConst(1); } } public bool IsInput2Const { get { return IsInputConst(2); } } public bool IsInput3Const { get { return IsInputConst(3); } } public bool IsInput4Const { get { return IsInputConst(4); } } public bool IsInput5Const { get { return IsInputConst(5); } } public bool IsInput6Const { get { return IsInputConst(6); } } public bool AreAllInputsConst { get { for (var i = 0; i < InputCount; ++i) if (!IsInputConst(i)) return false; return true; } } public int Input0Features { get { return m_ONNXModelTensors.variables[Input0].features; } } public int Input1Features { get { return m_ONNXModelTensors.variables[Input1].features; } } public int Input2Features { get { return m_ONNXModelTensors.variables[Input2].features; } } public int Input3Features { get { return m_ONNXModelTensors.variables[Input3].features; } } public int Input4Features { get { return m_ONNXModelTensors.variables[Input4].features; } } public int Input5Features { get { return m_ONNXModelTensors.variables[Input5].features; } } public int Input6Features { get { return m_ONNXModelTensors.variables[Input6].features; } } public int Input0Rank { get { return m_ONNXModelTensors.variables[Input0].rank; } } public int Input1Rank { get { return m_ONNXModelTensors.variables[Input1].rank; } } public VariableTensor.Layout Input0Layout { get { return m_ONNXModelTensors.variables[Input0].layout; } } public Tensor Input0Constant(string onnxLayout, string name = "X") { return GetRequiredInputAsConstant(Input0, onnxLayout, name); } public int[] Input0ConstantONNXShape(string name) { return GetRequiredInputConstantONNXShape(Input0, name); } public Tensor Input1Constant(string onnxLayout, string name) { return GetRequiredInputAsConstant(Input1, onnxLayout, name); } public Tensor Input2Constant(string onnxLayout, string name) { return GetRequiredInputAsConstant(Input2, onnxLayout, name); } public Tensor Input3Constant(string onnxLayout, string name) { return GetRequiredInputAsConstant(Input3, onnxLayout, name); } public Tensor Input4Constant(string onnxLayout, string name) { return GetRequiredInputAsConstant(Input4, onnxLayout, name); } public Tensor Input5Constant(string onnxLayout, string name) { return GetRequiredInputAsConstant(Input5, onnxLayout, name); } public Tensor Input6Constant(string onnxLayout, string name) { return GetRequiredInputAsConstant(Input6, onnxLayout, name); } public Tensor Input1ConstantOptional(Tensor defaultValue, string onnxLayout, string name) { try { return GetRequiredInputAsConstant(Input1, onnxLayout, name); } catch (Exception) { return defaultValue; } } public Tensor Input2ConstantOptional(Tensor defaultValue, string onnxLayout, string name) { try { return GetRequiredInputAsConstant(Input2, onnxLayout, name); } catch (Exception) { return defaultValue; } } public Tensor Input3ConstantOptional(Tensor defaultValue, string onnxLayout, string name) { try { return GetRequiredInputAsConstant(Input3, onnxLayout, name); } catch (Exception) { return defaultValue; } } public Tensor Input4ConstantOptional(Tensor defaultValue, string onnxLayout, string name) { try { return GetRequiredInputAsConstant(Input4, onnxLayout, name); } catch (Exception) { return defaultValue; } } public Tensor Input1ConstantOptional(TensorShape shape, float defaultValue, string onnxLayout, string name) { try { return GetRequiredInputAsConstant(Input1, onnxLayout, name); } catch (Exception) { return DefaultTensor(shape, defaultValue); } } public Tensor Input2ConstantOptional(TensorShape shape, float defaultValue, string onnxLayout, string name) { try { return GetRequiredInputAsConstant(Input2, onnxLayout, name); } catch (Exception) { return DefaultTensor(shape, defaultValue); } } public Tensor Input3ConstantOptional(TensorShape shape, float defaultValue, string onnxLayout, string name) { try { return GetRequiredInputAsConstant(Input3, onnxLayout, name); } catch (Exception) { return DefaultTensor(shape, defaultValue); } } public Tensor Input4ConstantOptional(TensorShape shape, float defaultValue, string onnxLayout, string name) { try { return GetRequiredInputAsConstant(Input4, onnxLayout, name); } catch (Exception) { return DefaultTensor(shape, defaultValue); } } public Tensor Input1ConstantOptional(float defaultValue, string onnxLayout, string name) { return Input1ConstantOptional(new TensorShape(1, 1), defaultValue, onnxLayout, name); } public Tensor Input2ConstantOptional(float defaultValue, string onnxLayout, string name) { return Input2ConstantOptional(new TensorShape(1, 1), defaultValue, onnxLayout, name); } public Tensor Input3ConstantOptional(float defaultValue, string onnxLayout, string name) { return Input3ConstantOptional(new TensorShape(1, 1), defaultValue, onnxLayout, name); } public Tensor Input4ConstantOptional(float defaultValue, string onnxLayout, string name) { return Input4ConstantOptional(new TensorShape(1, 1), defaultValue, onnxLayout, name); } // Attributes public float Alpha { get { return GetRequiredFloat("alpha"); } } public float Beta { get { return GetRequiredFloat("beta"); } } public float Gamma { get { return GetRequiredFloat("gamma"); } } public float Epsilon { get { return GetRequiredFloat("epsilon"); } } public float Mean { get { return GetRequiredFloat("mean"); } } public float Scale { get { return GetRequiredFloat("scale"); } } public float Seed { get { return GetOptionalFloat("seed", 1337f); } } // seed is always optional and defaults to 'auto generated' public ONNXTensor ValueAsTensor { get { return GetRequiredTensor("value"); } } public int Axis { get { return GetRequiredInt("axis"); } } public int BlockSize { get { return GetRequiredInt("blocksize"); } } public int Group { get { return GetRequiredInt("group"); } } public int[] Shape { get { return GetRequiredIntArray("shape"); } } public int[] Starts { get { return GetRequiredIntArray("starts"); } } public int[] Ends { get { return GetRequiredIntArray("ends"); } } public int[] Axes { get { return GetRequiredIntArray("axes"); } } public float[] Bias { get { return GetRequiredFloatArray("bias"); } } public int[] KernelShape { get { return GetRequiredIntArray("kernel_shape"); } } public int[] Strides { get { return GetOptionalIntArray("strides", new[] {1,1}); } } public int[] Strides3D { get { return GetOptionalIntArray("strides", new[] {1,1,1}); } } public int[] OutputPadding { get { return GetOptionalIntArray("output_padding", new[] {0,0}); } } internal bool SupportsAutoPad { get { return OperatorType != "Pad"; } } internal bool SupportsSpatialOnlyPads { get { return OperatorType != "Pad"; } } public int[] Pads { get { return ConvertPadsToBarracuda(); } } public int[] Pads3D { get { return ConvertPadsToBarracuda(new int[] {0,0,0,0,0,0}); } } public float[] Scales { get { return ConvertScalesToBarracuda(); } } public int[] Sizes { get { return ConvertSizesToBarracuda(); } } public float AlphaOptional(float defaultValue) { return GetOptionalFloat("alpha", defaultValue); } public float BetaOptional(float defaultValue) { return GetOptionalFloat("beta", defaultValue); } public float GammaOptional(float defaultValue) { return GetOptionalFloat("gamma", defaultValue); } public float EpsilonOptional(float defaultValue=1e-5f) { return GetOptionalFloat("epsilon", defaultValue); } public float MeanOptional(float defaultValue=0f) { return GetOptionalFloat("mean", defaultValue); } public float ScaleOptional(float defaultValue=1f) { return GetOptionalFloat("scale", defaultValue); } public bool TransAOptional(bool defaultValue=false) { return GetOptionalInt("transA", defaultValue?1:0) != 0;} public bool TransBOptional(bool defaultValue=false) { return GetOptionalInt("transB", defaultValue?1:0) != 0;} public int AxisOptional(int defaultValue) { return GetOptionalInt("axis", defaultValue); } public int GroupOptional(int defaultValue=1) { return GetOptionalInt("group", defaultValue); } public int[] KernelShapeOptional(int[] defaultValue) { return GetOptionalIntArray("kernel_shape", defaultValue); } public int[] AxesOptional(int[] defaultValue) { return GetOptionalIntArray("axes", defaultValue); } public float MinOptional(float defaultValue) { return GetOptionalFloat("min", defaultValue); } public float MaxOptional(float defaultValue) { return GetOptionalFloat("max", defaultValue); } public string ModeOptional(string defaultValue) { return GetOptionalString("mode", defaultValue); } public int[] DilatationsOptional(int[] defaultValue) { return GetOptionalIntArray("dilations", defaultValue); } // --------------------------------------------------------------------------------- // Implementation private NodeProto m_ONNXNode; private ONNXModelTensors m_ONNXModelTensors; private List m_ImporterWarnings; public ONNXNodeWrapper(NodeProto ONNXNode, ONNXModelTensors ONNXModelTensors, List importerWarnings) { m_ONNXNode = ONNXNode; m_ONNXModelTensors = ONNXModelTensors; m_ImporterWarnings = importerWarnings; } // Logging helpers public void Warn(string message) { m_ImporterWarnings.Add(new Model.ImporterWarning(Name, message)); Debug.LogWarning(message); } public bool HasAttribute(string name) { AttributeProto attr; return TryFindAttribute(name, out attr); } public void UnsupportedAttribute(string name) { AttributeProto attr; if (TryFindAttribute(name, out attr)) Warn($"Unsupported attribute {name}, node {Name} of type {OperatorType}. Value will be ignored."); } public void UnsupportedAttribute(string name, int defaultValue) { if (GetOptionalInt(name, defaultValue) != defaultValue) Warn($"Unsupported attribute {name}, node {Name} of type {OperatorType}. Value will be ignored and defaulted to {defaultValue}."); } public void UnsupportedAttribute(string name, float defaultValue) { if (GetOptionalFloat(name, defaultValue) != defaultValue) Warn($"Unsupported attribute {name}, node {Name} of type {OperatorType}. Value will be ignored and defaulted to {defaultValue}."); } public void UnsupportedAttribute(string name, string defaultValue) { if (GetOptionalString(name, defaultValue) != defaultValue) Warn($"Unsupported attribute {name}, node {Name} of type {OperatorType}. Value will be ignored and defaulted to {defaultValue}."); } public void UnsupportedAttribute(string name, int[] defaultValue) { var valueArray = GetOptionalIntArray(name, defaultValue); if (!Enumerable.SequenceEqual(valueArray, defaultValue)) Warn($"Unsupported attribute {name}, node {Name} of type {OperatorType}. Value will be ignored and defaulted to [{string.Join(", ", defaultValue)}]."); } public void UnsupportedAttribute(string name, string[] defaultValue) { var stringArray = GetOptionalStringArray(name, defaultValue); if (!Enumerable.SequenceEqual(stringArray, defaultValue)) Warn($"Unsupported attribute {name}, node {Name} of type {OperatorType}. Value will be ignored and defaulted to [{string.Join(", ", defaultValue)}]."); } public void UnsupportedAttribute(string name, Func predicate, int[] defaultValue) { var valueArray = GetOptionalIntArray(name, defaultValue); if (!Enumerable.All(valueArray, predicate)) Warn($"Unsupported attribute {name}, node {Name} of type {OperatorType}. Value will be ignored and defaulted to [{string.Join(", ", defaultValue)}]."); } public void IgnoredAttribute(string name, string reasonToIgnore) { } // Input helpers internal string GetRequiredInput(int inputIndex) { if ((inputIndex >= m_ONNXNode.Input.Count) || (m_ONNXNode.Input[inputIndex] == "")) throw new OnnxLayerImportException($"required Input {inputIndex} was not found."); return m_ONNXNode.Input[inputIndex]; } internal bool IsInput1Array(string name) { if (Input1 == "") throw new OnnxLayerImportException("Input value is marked as required, but it is missing in the model."); ONNXTensor onnxTensor; if (!m_ONNXModelTensors.constants.TryGetValue(Input1, out onnxTensor)) throw new OnnxLayerImportException( $"Currently only constant tensors are supported for `{name}` input in node of type {OperatorType}. Instead {Name}.{name} is pointing to non constant node {Input1}."); return onnxTensor.rank != 0; } internal Tensor GetRequiredInputAsConstant(string input, string onnxLayout, string onnxName) { if (input == "") throw new OnnxLayerImportException("Input value is marked as required, but it is missing in the model."); ONNXTensor onnxTensor; if (!m_ONNXModelTensors.constants.TryGetValue(input, out onnxTensor)) throw new OnnxLayerImportException( $"Currently only constant tensors are supported for `{onnxName}` input in node of type {OperatorType}. Instead {Name}.{onnxName} is pointing to non constant node {input}."); return onnxTensor.ToBarracuda(onnxLayout); } internal int[] GetRequiredInputConstantONNXShape(string input, string onnxName) { if (input == "") throw new OnnxLayerImportException("Input value is marked as required, but it is missing in the model."); ONNXTensor onnxTensor; if (!m_ONNXModelTensors.constants.TryGetValue(input, out onnxTensor)) throw new OnnxLayerImportException( $"Currently only constant tensors are supported for `{onnxName}` input in node of type {OperatorType}. Instead {Name}.{onnxName} is pointing to non constant node {input}."); return onnxTensor.shape; } internal bool IsInputConst(int inputIndex) { var input = GetRequiredInput(inputIndex); return m_ONNXModelTensors.constants.ContainsKey(input); } // Attribute helpers internal bool TryFindAttribute(string name, out AttributeProto attr) { return TryFindAttribute(name, AttributeProto.Types.AttributeType.Undefined, out attr); } internal bool TryFindAttribute(string name, AttributeProto.Types.AttributeType type, out AttributeProto attr) { const AttributeProto.Types.AttributeType undefined = AttributeProto.Types.AttributeType.Undefined; var attributes = m_ONNXNode.Attribute; for (var i = 0; i < attributes.Count; ++i) { attr = attributes[i]; if (attr.Name == name && (attr.Type == type || attr.Type == undefined || type == undefined)) return true; } attr = null; return false; } internal AttributeProto FindAttribute(string name, AttributeProto.Types.AttributeType type = AttributeProto.Types.AttributeType.Undefined) { AttributeProto attr = null; if (TryFindAttribute(name, type, out attr)) return attr; throw new OnnxLayerImportException($"Couldn't find attribute {name} of type {type}"); } public float GetOptionalFloat(string name, float defaultValue) { try { return GetRequiredFloat(name); } catch (OnnxLayerImportException) { return defaultValue; } } public float GetRequiredFloat(string name) { return FindAttribute(name, AttributeProto.Types.AttributeType.Float).F; } public float[] GetOptionalFloatArray(string name, float[] defaultValue) { try { return GetRequiredFloatArray(name); } catch (OnnxLayerImportException) { return defaultValue; } } public float[] GetRequiredFloatArray(string name) { var attribute = FindAttribute(name,AttributeProto.Types.AttributeType.Floats); return attribute.Floats.ToArray(); } public ONNXTensor GetOptionalTensor(string name, ONNXTensor defaultValue) { try { return GetRequiredTensor(name); } catch (OnnxLayerImportException) { return defaultValue; } } public ONNXTensor GetRequiredTensor(string name) { var tensorProto = FindAttribute(name, AttributeProto.Types.AttributeType.Tensor).T; return new ONNXTensor(tensorProto); } public int GetOptionalInt(string name, int defaultValue) { try { return GetRequiredInt(name); } catch (OnnxLayerImportException) { return defaultValue; } } public int GetRequiredInt(string name) { long v = FindAttribute(name, AttributeProto.Types.AttributeType.Int).I; return v < int.MinValue ? int.MinValue : v > int.MaxValue ? int.MaxValue : (int)v; } public int[] GetOptionalIntArray(string name, int[] defaultValue) { try { return GetRequiredIntArray(name); } catch (OnnxLayerImportException) { return defaultValue; } } public int[] GetRequiredIntArray(string name) { var attribute = FindAttribute(name,AttributeProto.Types.AttributeType.Ints); return attribute.Ints.Select(v => v < int.MinValue ? int.MinValue : v > int.MaxValue ? int.MaxValue : (int)v).ToArray(); } public string GetOptionalString(string name, string defaultValue) { try { return GetRequiredString(name); } catch (OnnxLayerImportException) { return defaultValue; } } public string GetRequiredString(string name) { var raw = FindAttribute(name, AttributeProto.Types.AttributeType.String).S; return raw.ToStringUtf8(); } public string[] GetOptionalStringArray(string name, string[] defaultValue) { try { return GetRequiredStringArray(name); } catch (OnnxLayerImportException) { return defaultValue; } } public string[] GetRequiredStringArray(string name) { var attribute = FindAttribute(name,AttributeProto.Types.AttributeType.Strings); return attribute.Strings.Select(s => s.ToStringUtf8()).ToArray(); } public Layer.AutoPad AutoPadMode() { var autoPad = GetOptionalString("auto_pad", "NOTSET"); Layer.AutoPad autoPadType = Layer.AutoPad.NotSet; if (autoPad == "VALID") autoPadType = Layer.AutoPad.Valid; else if (autoPad == "SAME_UPPER") autoPadType = Layer.AutoPad.SameUpper; else if (autoPad == "SAME_LOWER") autoPadType = Layer.AutoPad.SameLower; return autoPadType; } public Layer.PadMode PadMode() { var mode = ModeOptional("constant"); var modeType = Layer.PadMode.Constant; switch (mode) { case "constant": modeType = Layer.PadMode.Constant; break; case "reflect": modeType = Layer.PadMode.Reflect; break; case "edge": modeType = Layer.PadMode.Edge; break; } return modeType; } // Complex attribute helpers private int[] ConvertPadsToBarracuda(int[] defaultValues = null) { var noPadding = defaultValues??new[] {0,0,0,0}; if (SupportsAutoPad) { // known_paddings = { // 'VALID' : [0,0,0,0], // 'SAME_UPPER' : [-1], // 'SAME_LOWER' : [-2], // } var autoPad = GetOptionalString("auto_pad", "NOTSET"); if (autoPad == "VALID") return noPadding; else if (autoPad == "SAME_UPPER") return new[] { -1 }; else if (autoPad == "SAME_LOWER") return new[] { -2 }; else {} // TODO: Assert NOTSET } var pads = GetOptionalIntArray("pads", noPadding); if (pads.Length % 2 != 0) throw new OnnxLayerImportException( $"Attribute pads of unsupported length {pads.Length} in {Name} ot fype {OperatorType}."); var starts = pads.Take(pads.Length / 2).ToArray(); var ends = pads.Skip(pads.Length / 2).ToArray(); if (SupportsSpatialOnlyPads) { // See: https://github.com/onnx/onnx/blob/master/docs/Operators.md#AveragePool // Padding for the beginning and ending along each spatial axis, it can take any value greater than or equal to 0. // The value represent the number of pixels added to the beginning and end part of the corresponding axis. } else { // Padding containts non-spatial dimensions including N and C // See: https://github.com/onnx/onnx/blob/master/docs/Operators.md#Pad // `pads` should be a 1D tensor of shape [2 * input_rank]. Assert.IsTrue(starts.Length == ends.Length); bool[] dimHavePadding = new bool[starts.Length]; for (int i = 0; i < starts.Length; ++i) { dimHavePadding[i] = starts[i] != 0 && ends[i] != 0; } if (dimHavePadding.SequenceEqual(new bool []{ false, true, true, false })) { // Look like this padding operator is defined over NHWC layout // We skip first and last dimension thus starts = starts.Skip(1).Take(2).ToArray(); ends = ends.Skip(1).Take(2).ToArray(); } else { if ((starts.Length < 2) || (starts[0] != 0) || (starts[1] != 0) || // N (ends[0] != 0) || (ends[1] != 0)) // C Warn("Only spatial (H and W) padding is currently supported." + " Non spatial padding (N and C) will be ignored and default to 0."); // Skip non-spatial dimensions N, C (NCHW layout) starts = starts.Skip(2).ToArray(); ends = ends.Skip(2).ToArray(); } } // See: https://github.com/onnx/onnx/blob/master/docs/Operators.md#Pad // ONNX `pads` format should be as follow [x1_begin, x2_begin...x1_end, x2_end,...], // where xi_begin the number of pixels added at the beginning of axis `i` and xi_end, // the number of pixels added at the end of axis `i`. // Convert ONNX pad layout of [z, y, x ..., z', y', x'] to Barracuda layout [x, y, z ..., x', y', z'] // where x is x1_begin, y is x2_begin ... // x' is x1_end, y' is x2_end ... Assert.IsTrue(starts.Length == ends.Length); switch (starts.Length) { case 0: return new [] { 0, 0, 0, 0 }; case 1: return new [] { starts[0], 0, ends[0], 0 }; // 1D W => W_ case 2: return new [] { starts[1], starts[0], ends[1], ends[0] }; // 2D HW => WH case 3: return new [] { starts[2], starts[1], starts[0], ends[2], ends[1], ends[0] };// 3D DHW => WHD default: throw new OnnxLayerImportException( $"Attribute pads of unsupported length {pads.Length} in {Name} ot type {OperatorType}."); } } internal float[] ConvertScales() { float[] scales; if (InputCount > 2) // Resize-11 { Assert.IsTrue(OperatorType == "Resize"); scales = Input2Constant(onnxLayout: "C", name: "scales").AsFloats(); } else if (InputCount > 1) // Resize-10, Upsample-9 { scales = Input1Constant(onnxLayout: "C", name: "scales").AsFloats(); } else { Assert.IsTrue(OperatorType == "Upsample"); scales = GetOptionalFloatArray("scales", new float[0]); // Upsample-7 if (scales?.Length == 0) // Upsample-1 { scales = new[] { 1, // N 1, // C GetRequiredFloat("height_scale"), GetRequiredFloat("width_scale") }; } } Assert.IsTrue(scales != null); return scales; } internal int[] ConvertSizes() { int[] sizes = null; Assert.IsTrue(OperatorType == "Resize"); Assert.IsTrue(InputCount == 4); if (IsInput3Const) { sizes = Input3Constant(onnxLayout: "C", name: "sizes").AsInts(); Assert.IsTrue(sizes != null); Assert.IsTrue(sizes.Length == 4); if ((sizes[0] != 1) || (sizes[1] != 1)) Warn("Only spatial (H and W) resizing is currently supported." + " Non spatial sizes (N and C) will be ignored and default to identity."); } else throw new OnnxLayerImportException( $"Only constant size values are currently supported in {Name} ot type {OperatorType}."); return sizes; } private float[] ConvertScalesToBarracuda() { float[] scales; if (InputCount > 2) // Resize-11 { Assert.IsTrue(OperatorType == "Resize"); scales = Input2Constant(onnxLayout:"C", name:"scales").AsFloats(); } else if (InputCount > 1) // Resize-10, Upsample-9 { scales = Input1Constant(onnxLayout:"C", name:"scales").AsFloats(); } else { Assert.IsTrue(OperatorType == "Upsample"); scales = GetOptionalFloatArray("scales", new float[0]); // Upsample-7 if (scales?.Length == 0) // Upsample-1 { scales = new[] { 1, // N 1, // C GetRequiredFloat("height_scale"), GetRequiredFloat("width_scale") }; } } Assert.IsTrue(scales != null); if ((scales.Length < 2) || (scales[0] != 1) || (scales[1] != 1)) Warn("Only spatial (H and W) padding is currently supported." + " Non spatial scales (N and C) will be ignored and default to 1."); // Skip non-spatial dimensions N, C (NCHW layout) scales = scales.Skip(2).ToArray(); switch (scales.Length) { case 0: return new [] { 1f, 1f }; case 1: return new [] { scales[0], 1 }; // 1D W => W_ case 2: return new [] { scales[1], scales[0] }; // 2D HW => WH case 3: return new [] { scales[2], scales[1], scales[0] }; // 3D DHW => WHD default: throw new OnnxLayerImportException( $"Attribute pads of unsupported length {scales.Length} in {Name} ot type {OperatorType}."); } } private int[] ConvertSizesToBarracuda() { int[] sizes = null; Assert.IsTrue(OperatorType == "Resize"); Assert.IsTrue(InputCount == 4); if (IsInput3Const) { sizes = Input3Constant(onnxLayout: "C", name: "sizes").AsInts(); Assert.IsTrue(sizes != null); Assert.IsTrue(sizes.Length == 4); if ((sizes[0] != 1) || (sizes[1] != 1)) Warn("Only spatial (H and W) resizing is currently supported." + " Non spatial sizes (N and C) will be ignored and default to identity."); // Skip non-spatial dimensions N, C, return WH (NCHW layout) sizes = sizes.Skip(2).Reverse().ToArray(); } else throw new OnnxLayerImportException( $"Only constant size values are currently supported in {Name} ot type {OperatorType}."); return sizes; } public Tensor DefaultTensor(TensorShape tensorShape, float defaultValue) { var shape = tensorShape; var data = Enumerable.Repeat(defaultValue, tensorShape.length).ToArray(); return new Tensor(shape, data); } } }