Resolve WES-90 "Integrate signpredictor in courses"
This commit is contained in:
committed by
Jerome Coudron
parent
1a75791d62
commit
746906294b
@@ -0,0 +1,530 @@
|
||||
using System;
|
||||
using System.Collections.Generic;
|
||||
using System.IO;
|
||||
|
||||
using UnityEngine;
|
||||
using UnityEngine.Assertions;
|
||||
using System.IO.Compression;
|
||||
|
||||
|
||||
namespace Unity.Barracuda {
|
||||
|
||||
/// <summary>
|
||||
/// Test set loading utility
|
||||
/// </summary>
|
||||
public class TestSet
|
||||
{
|
||||
private RawTestSet rawTestSet;
|
||||
private JSONTestSet jsonTestSet;
|
||||
|
||||
/// <summary>
|
||||
/// Create with raw test set
|
||||
/// </summary>
|
||||
/// <param name="rawTestSet">raw test set</param>
|
||||
public TestSet(RawTestSet rawTestSet)
|
||||
{
|
||||
this.rawTestSet = rawTestSet;
|
||||
}
|
||||
|
||||
/// <summary>
|
||||
/// Create with JSON test set
|
||||
/// </summary>
|
||||
/// <param name="jsonTestSet">JSON test set</param>
|
||||
public TestSet(JSONTestSet jsonTestSet)
|
||||
{
|
||||
this.jsonTestSet = jsonTestSet;
|
||||
}
|
||||
|
||||
/// <summary>
|
||||
/// Create `TestSet`
|
||||
/// </summary>
|
||||
public TestSet()
|
||||
{
|
||||
}
|
||||
|
||||
/// <summary>
|
||||
/// Check if test set supports named tensors
|
||||
/// </summary>
|
||||
/// <returns>`true` if named tensors are supported</returns>
|
||||
public bool SupportsNames()
|
||||
{
|
||||
if (rawTestSet != null)
|
||||
return false;
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
/// <summary>
|
||||
/// Get output tensor count
|
||||
/// </summary>
|
||||
/// <returns></returns>
|
||||
public int GetOutputCount()
|
||||
{
|
||||
if (rawTestSet != null)
|
||||
return 1;
|
||||
|
||||
return jsonTestSet.outputs.Length;
|
||||
}
|
||||
|
||||
/// <summary>
|
||||
/// Get output tensor data
|
||||
/// </summary>
|
||||
/// <param name="idx">tensor index</param>
|
||||
/// <returns>tensor data</returns>
|
||||
public float[] GetOutputData(int idx = 0)
|
||||
{
|
||||
if (rawTestSet != null)
|
||||
return rawTestSet.labels;
|
||||
|
||||
return jsonTestSet.outputs[idx].data;
|
||||
}
|
||||
|
||||
/// <summary>
|
||||
/// Get output tensor name
|
||||
/// </summary>
|
||||
/// <param name="idx">tensor index</param>
|
||||
/// <returns>tensor name</returns>
|
||||
public string GetOutputName(int idx = 0)
|
||||
{
|
||||
if (rawTestSet != null)
|
||||
return null;
|
||||
|
||||
string name = jsonTestSet.outputs[idx].name;
|
||||
return name.EndsWith(":0") ? name.Remove(name.Length - 2) : name;
|
||||
}
|
||||
|
||||
/// <summary>
|
||||
/// Get input tensor count
|
||||
/// </summary>
|
||||
/// <returns></returns>
|
||||
public int GetInputCount()
|
||||
{
|
||||
if (rawTestSet != null)
|
||||
return 1;
|
||||
|
||||
return jsonTestSet.inputs.Length;
|
||||
}
|
||||
|
||||
/// <summary>
|
||||
/// Get input tensor name
|
||||
/// </summary>
|
||||
/// <param name="idx">input tensor index</param>
|
||||
/// <returns>tensor name</returns>
|
||||
public string GetInputName(int idx = 0)
|
||||
{
|
||||
if (rawTestSet != null)
|
||||
return "";
|
||||
|
||||
string name = jsonTestSet.inputs[idx].name;
|
||||
return name.EndsWith(":0") ? name.Remove(name.Length - 2) : name;
|
||||
}
|
||||
|
||||
/// <summary>
|
||||
/// Get input tensor data
|
||||
/// </summary>
|
||||
/// <param name="idx">input tensor index</param>
|
||||
/// <returns>tensor data</returns>
|
||||
public float[] GetInputData(int idx = 0)
|
||||
{
|
||||
if (rawTestSet != null)
|
||||
return rawTestSet.input;
|
||||
|
||||
return jsonTestSet.inputs[idx].data;
|
||||
}
|
||||
|
||||
/// <summary>
|
||||
/// Get input shape
|
||||
/// </summary>
|
||||
/// <param name="idx">input tensor index</param>
|
||||
/// <returns>input shape</returns>
|
||||
public TensorShape GetInputShape(int idx = 0)
|
||||
{
|
||||
if (rawTestSet != null)
|
||||
return new TensorShape(1,rawTestSet.input.Length);
|
||||
|
||||
return new TensorShape(jsonTestSet.inputs[idx].shape.sequenceLength,
|
||||
jsonTestSet.inputs[idx].shape.numberOfDirections,
|
||||
jsonTestSet.inputs[idx].shape.batch,
|
||||
jsonTestSet.inputs[idx].shape.extraDimension,
|
||||
jsonTestSet.inputs[idx].shape.depth,
|
||||
jsonTestSet.inputs[idx].shape.height,
|
||||
jsonTestSet.inputs[idx].shape.width,
|
||||
jsonTestSet.inputs[idx].shape.channels);
|
||||
}
|
||||
|
||||
/// <summary>
|
||||
/// Get output tensor shape
|
||||
/// </summary>
|
||||
/// <param name="idx">output tensor index</param>
|
||||
/// <returns>tensor shape</returns>
|
||||
public TensorShape GetOutputShape(int idx = 0)
|
||||
{
|
||||
if (rawTestSet != null)
|
||||
return new TensorShape(1,rawTestSet.labels.Length);
|
||||
|
||||
return new TensorShape(jsonTestSet.outputs[idx].shape.sequenceLength,
|
||||
jsonTestSet.outputs[idx].shape.numberOfDirections,
|
||||
jsonTestSet.outputs[idx].shape.batch,
|
||||
jsonTestSet.outputs[idx].shape.extraDimension,
|
||||
jsonTestSet.outputs[idx].shape.depth,
|
||||
jsonTestSet.outputs[idx].shape.height,
|
||||
jsonTestSet.outputs[idx].shape.width,
|
||||
jsonTestSet.outputs[idx].shape.channels);
|
||||
}
|
||||
|
||||
/// <summary>
|
||||
/// Get inputs as `Tensor` dictionary
|
||||
/// </summary>
|
||||
/// <param name="inputs">dictionary to store results</param>
|
||||
/// <param name="batchCount">max batch count</param>
|
||||
/// <param name="fromBatch">start from batch</param>
|
||||
/// <returns>dictionary with input tensors</returns>
|
||||
/// <exception cref="Exception">thrown if called on raw test set (only JSON test set is supported)</exception>
|
||||
public Dictionary<string, Tensor> GetInputsAsTensorDictionary(Dictionary<string, Tensor> inputs = null, int batchCount = -1, int fromBatch = 0)
|
||||
{
|
||||
if (rawTestSet != null)
|
||||
throw new Exception("GetInputsAsTensorDictionary is not supported for RAW test suites");
|
||||
|
||||
if (inputs == null)
|
||||
inputs = new Dictionary<string, Tensor>();
|
||||
|
||||
for (var i = 0; i < GetInputCount(); i++)
|
||||
inputs[GetInputName(i)] = GetInputAsTensor(i, batchCount, fromBatch);
|
||||
|
||||
return inputs;
|
||||
}
|
||||
|
||||
/// <summary>
|
||||
/// Get outputs as `Tensor` dictionary
|
||||
/// </summary>
|
||||
/// <param name="outputs">dictionary to store results</param>
|
||||
/// <param name="batchCount">max batch count</param>
|
||||
/// <param name="fromBatch">start from batch</param>
|
||||
/// <returns>dictionary with input tensors</returns>
|
||||
/// <exception cref="Exception">thrown if called on raw test set (only JSON test set is supported)</exception>
|
||||
public Dictionary<string, Tensor> GetOutputsAsTensorDictionary(Dictionary<string, Tensor> outputs = null, int batchCount = -1, int fromBatch = 0)
|
||||
{
|
||||
if (rawTestSet != null)
|
||||
throw new Exception("GetOutputsAsTensorDictionary is not supported for RAW test suites");
|
||||
|
||||
if (outputs == null)
|
||||
outputs = new Dictionary<string, Tensor>();
|
||||
|
||||
for (var i = 0; i < GetOutputCount(); i++)
|
||||
outputs[GetOutputName(i)] = GetOutputAsTensor(i, batchCount, fromBatch);
|
||||
|
||||
return outputs;
|
||||
}
|
||||
|
||||
/// <summary>
|
||||
/// Get input as `Tensor`
|
||||
/// </summary>
|
||||
/// <param name="idx">input index</param>
|
||||
/// <param name="batchCount">max batch count</param>
|
||||
/// <param name="fromBatch">start from batch</param>
|
||||
/// <returns>`Tensor`</returns>
|
||||
/// <exception cref="Exception">thrown if called on raw test set (only JSON test set is supported)</exception>
|
||||
public Tensor GetInputAsTensor(int idx = 0, int batchCount = -1, int fromBatch = 0)
|
||||
{
|
||||
if (rawTestSet != null)
|
||||
throw new Exception("GetInputAsTensor is not supported for RAW test suites");
|
||||
|
||||
TensorShape shape = GetInputShape(idx);
|
||||
Assert.IsTrue(shape.sequenceLength==1 && shape.numberOfDirections==1);
|
||||
var array = GetInputData(idx);
|
||||
var maxBatchCount = array.Length / shape.flatWidth;
|
||||
|
||||
fromBatch = Math.Min(fromBatch, maxBatchCount - 1);
|
||||
if (batchCount < 0)
|
||||
batchCount = maxBatchCount - fromBatch;
|
||||
|
||||
// pad data with 0s, if test-set doesn't have enough batches
|
||||
var shapeArray = shape.ToArray();
|
||||
shapeArray[TensorShape.DataBatch] = batchCount;
|
||||
var tensorShape = new TensorShape(shapeArray);
|
||||
var managedBufferStartIndex = fromBatch * tensorShape.flatWidth;
|
||||
var count = Math.Min(batchCount, maxBatchCount - fromBatch) * tensorShape.flatWidth;
|
||||
float[] dataToUpload = new float[tensorShape.length];
|
||||
Array.Copy(array, managedBufferStartIndex, dataToUpload, 0, count);
|
||||
|
||||
var data = new ArrayTensorData(tensorShape.length);
|
||||
data.Upload(dataToUpload, tensorShape, 0);
|
||||
|
||||
var res = new Tensor(tensorShape, data);
|
||||
res.name = GetInputName(idx);
|
||||
res.name = res.name.EndsWith(":0") ? res.name.Remove(res.name.Length - 2) : res.name;
|
||||
|
||||
return res;
|
||||
}
|
||||
|
||||
/// <summary>
|
||||
/// Get output as `Tensor`
|
||||
/// </summary>
|
||||
/// <param name="idx">output index</param>
|
||||
/// <param name="batchCount">max batch count</param>
|
||||
/// <param name="fromBatch">start from batch</param>
|
||||
/// <returns>`Tensor`</returns>
|
||||
/// <exception cref="Exception">thrown if called on raw test set (only JSON test set is supported)</exception>
|
||||
public Tensor GetOutputAsTensor(int idx = 0, int batchCount = -1, int fromBatch = 0)
|
||||
{
|
||||
if (rawTestSet != null)
|
||||
throw new Exception("GetOutputAsTensor is not supported for RAW test suites");
|
||||
|
||||
TensorShape shape = GetOutputShape(idx);
|
||||
Assert.IsTrue(shape.sequenceLength==1 && shape.numberOfDirections==1);
|
||||
var barracudaArray = new BarracudaArrayFromManagedArray(GetOutputData(idx));
|
||||
|
||||
var maxBatchCount = barracudaArray.Length / shape.flatWidth;
|
||||
|
||||
fromBatch = Math.Min(fromBatch, maxBatchCount - 1);
|
||||
if (batchCount < 0)
|
||||
batchCount = maxBatchCount - fromBatch;
|
||||
batchCount = Math.Min(batchCount, maxBatchCount - fromBatch);
|
||||
|
||||
var shapeArray = shape.ToArray();
|
||||
shapeArray[TensorShape.DataBatch] = batchCount;
|
||||
var tensorShape = new TensorShape(shapeArray);
|
||||
|
||||
var offset = fromBatch * tensorShape.flatWidth;
|
||||
var res = new Tensor(tensorShape, new SharedArrayTensorData(barracudaArray, tensorShape, offset));
|
||||
res.name = GetOutputName(idx);
|
||||
res.name = res.name.EndsWith(":0") ? res.name.Remove(res.name.Length - 2) : res.name;
|
||||
|
||||
return res;
|
||||
}
|
||||
}
|
||||
|
||||
/// <summary>
|
||||
/// Raw test structure
|
||||
/// </summary>
|
||||
public class RawTestSet
|
||||
{
|
||||
/// <summary>
|
||||
/// Input data
|
||||
/// </summary>
|
||||
public float[] input;
|
||||
|
||||
/// <summary>
|
||||
/// Output data
|
||||
/// </summary>
|
||||
public float[] labels;
|
||||
}
|
||||
|
||||
/// <summary>
|
||||
/// JSON test structure
|
||||
/// </summary>
|
||||
[Serializable]
|
||||
public class JSONTestSet
|
||||
{
|
||||
/// <summary>
|
||||
/// Inputs
|
||||
/// </summary>
|
||||
public JSONTensor[] inputs;
|
||||
|
||||
/// <summary>
|
||||
/// Outputs
|
||||
/// </summary>
|
||||
public JSONTensor[] outputs;
|
||||
}
|
||||
|
||||
/// <summary>
|
||||
/// JSON tensor shape
|
||||
/// </summary>
|
||||
[Serializable]
|
||||
public class JSONTensorShape
|
||||
{
|
||||
/// <summary>
|
||||
/// Sequence length
|
||||
/// </summary>
|
||||
public int sequenceLength;
|
||||
|
||||
/// <summary>
|
||||
/// Number of directions
|
||||
/// </summary>
|
||||
public int numberOfDirections;
|
||||
|
||||
/// <summary>
|
||||
/// Batch
|
||||
/// </summary>
|
||||
public int batch;
|
||||
|
||||
/// <summary>
|
||||
/// Extra dimension
|
||||
/// </summary>
|
||||
public int extraDimension;
|
||||
|
||||
/// <summary>
|
||||
/// Depth
|
||||
/// </summary>
|
||||
public int depth;
|
||||
|
||||
/// <summary>
|
||||
/// Height
|
||||
/// </summary>
|
||||
public int height;
|
||||
|
||||
/// <summary>
|
||||
/// Width
|
||||
/// </summary>
|
||||
public int width;
|
||||
|
||||
/// <summary>
|
||||
/// Channels
|
||||
/// </summary>
|
||||
public int channels;
|
||||
}
|
||||
|
||||
/// <summary>
|
||||
/// JSON tensor
|
||||
/// </summary>
|
||||
[Serializable]
|
||||
public class JSONTensor
|
||||
{
|
||||
/// <summary>
|
||||
/// Name
|
||||
/// </summary>
|
||||
public string name;
|
||||
|
||||
/// <summary>
|
||||
/// Shape
|
||||
/// </summary>
|
||||
public JSONTensorShape shape;
|
||||
|
||||
/// <summary>
|
||||
/// Tensor type
|
||||
/// </summary>
|
||||
public string type;
|
||||
|
||||
/// <summary>
|
||||
/// Tensor data
|
||||
/// </summary>
|
||||
public float[] data;
|
||||
}
|
||||
|
||||
/// <summary>
|
||||
/// Test set loader
|
||||
/// </summary>
|
||||
public class TestSetLoader
|
||||
{
|
||||
/// <summary>
|
||||
/// Load test set from file
|
||||
/// </summary>
|
||||
/// <param name="filename">file name</param>
|
||||
/// <returns>`TestSet`</returns>
|
||||
public static TestSet Load(string filename)
|
||||
{
|
||||
if (filename.ToLower().EndsWith(".raw"))
|
||||
return LoadRaw(filename);
|
||||
else if (filename.ToLower().EndsWith(".gz"))
|
||||
return LoadGZ(filename);
|
||||
|
||||
return LoadJSON(filename);
|
||||
}
|
||||
|
||||
/// <summary>
|
||||
/// Load GZ
|
||||
/// </summary>
|
||||
/// <param name="filename">file name</param>
|
||||
/// <returns>`TestSet`</returns>
|
||||
public static TestSet LoadGZ(string filename)
|
||||
{
|
||||
var jsonFileName = filename.Substring(0, filename.Length - 3);
|
||||
var sourceArchiveFileName = Path.Combine(Application.streamingAssetsPath, "TestSet", filename);
|
||||
var destinationDirectoryName = sourceArchiveFileName.Substring(0, sourceArchiveFileName.Length - 3);
|
||||
|
||||
FileInfo fileToDecompress = new FileInfo(sourceArchiveFileName);
|
||||
using (FileStream originalFileStream = fileToDecompress.OpenRead())
|
||||
{
|
||||
using (FileStream decompressedFileStream = File.Create(destinationDirectoryName))
|
||||
{
|
||||
using (GZipStream decompressionStream = new GZipStream(originalFileStream, CompressionMode.Decompress))
|
||||
{
|
||||
decompressionStream.CopyTo(decompressedFileStream);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return LoadJSON(jsonFileName);
|
||||
}
|
||||
|
||||
/// <summary>
|
||||
/// Load JSON
|
||||
/// </summary>
|
||||
/// <param name="filename">file name</param>
|
||||
/// <returns>`TestSet`</returns>
|
||||
public static TestSet LoadJSON(string filename)
|
||||
{
|
||||
string json = "";
|
||||
|
||||
if (filename.EndsWith(".json"))
|
||||
json = File.ReadAllText(Path.Combine(Application.streamingAssetsPath, "TestSet", filename));
|
||||
else
|
||||
json = Resources.Load<TextAsset>($"TestSet/{filename}").text;
|
||||
|
||||
TestSet result = new TestSet(JsonUtility.FromJson<JSONTestSet>(json));
|
||||
return result;
|
||||
}
|
||||
|
||||
/// <summary>
|
||||
/// Load raw test set
|
||||
/// </summary>
|
||||
/// <param name="filename">file name</param>
|
||||
/// <returns>`TestSet`</returns>
|
||||
public static TestSet LoadRaw(string filename)
|
||||
{
|
||||
string fullpath = Path.Combine(Application.streamingAssetsPath, "TestSet", filename);
|
||||
|
||||
using(BinaryReader file = Open(fullpath))
|
||||
{
|
||||
|
||||
var rawTestSet = new RawTestSet();
|
||||
rawTestSet.input = LoadFloatArray(file);
|
||||
rawTestSet.labels = LoadFloatArray(file);
|
||||
return new TestSet(rawTestSet);
|
||||
}
|
||||
}
|
||||
|
||||
/// <summary>
|
||||
/// Load image
|
||||
/// </summary>
|
||||
/// <param name="filename">file name</param>
|
||||
/// <returns>`Texture`</returns>
|
||||
public static Texture LoadImage(string filename)
|
||||
{
|
||||
string fullpath = Path.Combine(Application.streamingAssetsPath, "TestSet", filename);
|
||||
|
||||
var bytes = File.ReadAllBytes(fullpath);
|
||||
var tex = new Texture2D(2, 2);
|
||||
ImageConversion.LoadImage(tex, bytes, false); // LoadImage will auto-resize the texture dimensions
|
||||
tex.wrapMode = TextureWrapMode.Clamp;
|
||||
return tex;
|
||||
}
|
||||
|
||||
/// <summary>
|
||||
/// Load float array
|
||||
/// </summary>
|
||||
/// <param name="file">binary file reader</param>
|
||||
/// <returns>float array</returns>
|
||||
public static float[] LoadFloatArray(BinaryReader file)
|
||||
{
|
||||
Int64 dataLength = file.ReadInt64();
|
||||
float[] array = new float[dataLength];
|
||||
byte[] bytes = file.ReadBytes(Convert.ToInt32(dataLength * sizeof(float))); // @TODO: support larger than MaxInt32 data blocks
|
||||
Buffer.BlockCopy(bytes, 0, array, 0, bytes.Length);
|
||||
|
||||
return array;
|
||||
}
|
||||
|
||||
/// <summary>
|
||||
/// Open file with binary reader
|
||||
/// </summary>
|
||||
/// <param name="filename">file name</param>
|
||||
/// <returns>`BinaryReader`</returns>
|
||||
static BinaryReader Open(string filename)
|
||||
{
|
||||
return new BinaryReader(new FileStream(filename, FileMode.Open, FileAccess.Read));
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
} // namespace Unity.Barracuda
|
||||
Reference in New Issue
Block a user