15 struct TensorInfoFixture
19 unsigned int sizes[] = {6,7,8,9};
20 m_TensorInfo =
TensorInfo(4, sizes, DataType::Float32);
22 ~TensorInfoFixture() {};
30 CHECK(listInitializedShape == m_TensorInfo.
GetShape());
36 CHECK(m_TensorInfo.
GetShape()[0] == 6);
37 CHECK(m_TensorInfo.
GetShape()[1] == 7);
38 CHECK(m_TensorInfo.
GetShape()[2] == 8);
39 CHECK(m_TensorInfo.
GetShape()[3] == 9);
45 CHECK(copyConstructed.GetNumDimensions() == 4);
46 CHECK(copyConstructed.GetShape()[0] == 6);
47 CHECK(copyConstructed.GetShape()[1] == 7);
48 CHECK(copyConstructed.GetShape()[2] == 8);
49 CHECK(copyConstructed.GetShape()[3] == 9);
55 CHECK(copyConstructed == m_TensorInfo);
61 unsigned int sizes[] = {2,3,4,5};
62 other =
TensorInfo(4, sizes, DataType::Float32);
64 CHECK(other != m_TensorInfo);
71 CHECK(copy == m_TensorInfo);
74 TEST_CASE(
"CopyNoQuantizationTensorInfo")
93 CHECK(infoA != infoB);
95 CHECK(infoA == infoB);
104 TEST_CASE(
"CopyDifferentQuantizationTensorInfo")
121 CHECK((infoA.
GetDataType() == DataType::QAsymmU8));
126 CHECK(infoA != infoB);
128 CHECK(infoA == infoB);
131 CHECK((infoA.
GetDataType() == DataType::QAsymmU8));
142 TEST_CASE(
"TensorVsConstTensor")
144 int mutableDatum = 2;
145 const int immutableDatum = 3;
150 uninitializedTensor2 = uninitializedTensor;
160 TEST_CASE(
"ModifyTensorInfo")
173 TEST_CASE(
"TensorShapeOperatorBrackets")
179 CHECK(shape[2] == 2);
181 CHECK(shape[2] == 20);
184 CHECK(constShape[2] == 2);
187 TEST_CASE(
"TensorInfoPerAxisQuantization")
190 TensorInfo tensorInfo0({ 1, 1 }, DataType::Float32, 2.0f, 1);
191 CHECK(!tensorInfo0.HasMultipleQuantizationScales());
192 CHECK(tensorInfo0.GetQuantizationScale() == 2.0f);
193 CHECK(tensorInfo0.GetQuantizationOffset() == 1);
194 CHECK(tensorInfo0.GetQuantizationScales()[0] == 2.0f);
195 CHECK(!tensorInfo0.GetQuantizationDim().has_value());
198 std::vector<float> perAxisScales{ 3.0f, 4.0f };
200 CHECK(tensorInfo0.HasMultipleQuantizationScales());
201 CHECK(tensorInfo0.GetQuantizationScales() == perAxisScales);
204 tensorInfo0.SetQuantizationScale(5.0f);
205 CHECK(!tensorInfo0.HasMultipleQuantizationScales());
206 CHECK(tensorInfo0.GetQuantizationScales()[0] == 5.0f);
210 CHECK(tensorInfo0.GetQuantizationDim().value() == 1);
213 perAxisScales = { 6.0f, 7.0f };
214 TensorInfo tensorInfo1({ 1, 1 }, DataType::Float32, perAxisScales, 1);
215 CHECK(tensorInfo1.HasMultipleQuantizationScales());
216 CHECK(tensorInfo1.GetQuantizationOffset() == 0);
217 CHECK(tensorInfo1.GetQuantizationScales() == perAxisScales);
218 CHECK(tensorInfo1.GetQuantizationDim().value() == 1);
221 TEST_CASE(
"TensorShape_scalar")
223 float mutableDatum = 3.1416f;
230 float scalarValue = *
reinterpret_cast<float*
>(tensor.GetMemoryArea());
231 CHECK_MESSAGE(mutableDatum == scalarValue,
"Scalar value is " << scalarValue);
236 CHECK(shape_equal == shape);
237 CHECK(shape_different != shape);
238 CHECK_MESSAGE(1 == shape.GetNumElements(),
"Number of elements is " << shape.GetNumElements());
239 CHECK_MESSAGE(1 == shape.GetNumDimensions(),
"Number of dimensions is " << shape.GetNumDimensions());
240 CHECK(
true == shape.GetDimensionSpecificity(0));
241 CHECK(shape.AreAllDimensionsSpecified());
242 CHECK(shape.IsAtLeastOneDimensionSpecified());
244 CHECK(1 == shape[0]);
245 CHECK(1 == tensor.GetShape()[0]);
246 CHECK(1 == tensor.GetInfo().GetShape()[0]);
249 float newMutableDatum = 42.f;
250 std::memcpy(tensor.GetMemoryArea(), &newMutableDatum,
sizeof(float));
251 scalarValue = *
reinterpret_cast<float*
>(tensor.GetMemoryArea());
252 CHECK_MESSAGE(newMutableDatum == scalarValue,
"Scalar value is " << scalarValue);
255 TEST_CASE(
"TensorShape_DynamicTensorType1_unknownNumberDimensions")
257 float mutableDatum = 3.1416f;
271 CHECK(shape_equal == shape);
272 CHECK(shape_different != shape);
275 TEST_CASE(
"TensorShape_DynamicTensorType1_unknownAllDimensionsSizes")
277 float mutableDatum = 3.1416f;
284 CHECK_MESSAGE(0 == shape.GetNumElements(),
"Number of elements is " << shape.GetNumElements());
285 CHECK_MESSAGE(3 == shape.GetNumDimensions(),
"Number of dimensions is " << shape.GetNumDimensions());
286 CHECK(
false == shape.GetDimensionSpecificity(0));
287 CHECK(
false == shape.GetDimensionSpecificity(1));
288 CHECK(
false == shape.GetDimensionSpecificity(2));
289 CHECK(!shape.AreAllDimensionsSpecified());
290 CHECK(!shape.IsAtLeastOneDimensionSpecified());
295 CHECK(shape_equal == shape);
296 CHECK(shape_different != shape);
299 TEST_CASE(
"TensorShape_DynamicTensorType1_unknownSomeDimensionsSizes")
301 std::vector<float> mutableDatum { 42.f, 42.f, 42.f,
309 CHECK_MESSAGE(6 == shape.GetNumElements(),
"Number of elements is " << shape.GetNumElements());
310 CHECK_MESSAGE(3 == shape.GetNumDimensions(),
"Number of dimensions is " << shape.GetNumDimensions());
311 CHECK(
true == shape.GetDimensionSpecificity(0));
312 CHECK(
false == shape.GetDimensionSpecificity(1));
313 CHECK(
true == shape.GetDimensionSpecificity(2));
314 CHECK(!shape.AreAllDimensionsSpecified());
315 CHECK(shape.IsAtLeastOneDimensionSpecified());
321 CHECK(2 == shape[0]);
322 CHECK(2 == tensor.GetShape()[0]);
323 CHECK(2 == tensor.GetInfo().GetShape()[0]);
326 CHECK(3 == shape[2]);
327 CHECK(3 == tensor.GetShape()[2]);
328 CHECK(3 == tensor.GetInfo().GetShape()[2]);
333 CHECK(shape_equal == shape);
334 CHECK(shape_different != shape);
337 TEST_CASE(
"TensorShape_DynamicTensorType1_transitionFromUnknownToKnownDimensionsSizes")
339 std::vector<float> mutableDatum { 42.f, 42.f, 42.f,
347 shape.SetNumDimensions(3);
349 CHECK_MESSAGE(3 == shape.GetNumDimensions(),
"Number of dimensions is " << shape.GetNumDimensions());
350 CHECK(
false == shape.GetDimensionSpecificity(0));
351 CHECK(
false == shape.GetDimensionSpecificity(1));
352 CHECK(
false == shape.GetDimensionSpecificity(2));
353 CHECK(!shape.AreAllDimensionsSpecified());
354 CHECK(!shape.IsAtLeastOneDimensionSpecified());
357 shape.SetDimensionSize(0, 2);
358 shape.SetDimensionSize(2, 3);
359 CHECK_MESSAGE(3 == shape.GetNumDimensions(),
"Number of dimensions is " << shape.GetNumDimensions());
360 CHECK_MESSAGE(6 == shape.GetNumElements(),
"Number of elements is " << shape.GetNumElements());
361 CHECK(
true == shape.GetDimensionSpecificity(0));
362 CHECK(
false == shape.GetDimensionSpecificity(1));
363 CHECK(
true == shape.GetDimensionSpecificity(2));
364 CHECK(!shape.AreAllDimensionsSpecified());
365 CHECK(shape.IsAtLeastOneDimensionSpecified());
369 CHECK(2 == shape[0]);
370 CHECK(2 == tensor2.GetShape()[0]);
371 CHECK(2 == tensor2.GetInfo().GetShape()[0]);
377 CHECK(3 == shape[2]);
378 CHECK(3 == tensor2.GetShape()[2]);
379 CHECK(3 == tensor2.GetInfo().GetShape()[2]);
384 CHECK(shape_equal == shape);
385 CHECK(shape_different != shape);
388 shape.SetDimensionSize(1, 5);
389 CHECK_MESSAGE(3 == shape.GetNumDimensions(),
"Number of dimensions is " << shape.GetNumDimensions());
390 CHECK_MESSAGE(30 == shape.GetNumElements(),
"Number of elements is " << shape.GetNumElements());
391 CHECK(
true == shape.GetDimensionSpecificity(0));
392 CHECK(
true == shape.GetDimensionSpecificity(1));
393 CHECK(
true == shape.GetDimensionSpecificity(2));
394 CHECK(shape.AreAllDimensionsSpecified());
395 CHECK(shape.IsAtLeastOneDimensionSpecified());
398 TEST_CASE(
"Tensor_emptyConstructors")
401 CHECK_MESSAGE( 0 == shape.GetNumDimensions(),
"Number of dimensions is " << shape.GetNumDimensions());
402 CHECK_MESSAGE( 0 == shape.GetNumElements(),
"Number of elements is " << shape.GetNumElements());
404 CHECK( shape.AreAllDimensionsSpecified());
408 CHECK_MESSAGE( 0 == tensor.GetNumDimensions(),
"Number of dimensions is " << tensor.GetNumDimensions());
409 CHECK_MESSAGE( 0 == tensor.GetNumElements(),
"Number of elements is " << tensor.GetNumElements());
410 CHECK_MESSAGE( 0 == tensor.GetShape().GetNumDimensions(),
"Number of dimensions is " <<
411 tensor.GetShape().GetNumDimensions());
// This check is on the element count, so the failure message names elements (not dimensions).
412 CHECK_MESSAGE( 0 == tensor.GetShape().GetNumElements(),
"Number of elements is " <<
413 tensor.GetShape().GetNumElements());
415 CHECK( tensor.GetShape().AreAllDimensionsSpecified());
const TensorShape & GetShape() const
Optional< unsigned int > GetQuantizationDim() const
void SetShape(const TensorShape &newShape)
A tensor defined by a TensorInfo (shape and data type) and a mutable backing store.
TEST_CASE_FIXTURE(ClContextControlFixture, "CopyBetweenNeonAndGpu")
int32_t GetQuantizationOffset() const
float GetQuantizationScale() const
DataType GetDataType() const
bool has_value() const noexcept
A tensor defined by a TensorInfo (shape and data type) and an immutable backing store.
void SetQuantizationScale(float scale)
const TensorInfo & GetInfo() const
void SetDataType(DataType type)
void SetQuantizationDim(const Optional< unsigned int > &quantizationDim)
void SetQuantizationOffset(int32_t offset)
void SetQuantizationScales(const std::vector< float > &scales)
unsigned int GetNumDimensions() const