// Constructor body fragment (signature and initializer list are not visible in this
// excerpt). It appears to belong to a TensorShape constructor taking numDimensions and
// initDimensionsSpecificity.
// Validates numDimensions, then resets the first m_NumDimensions dimension sizes to 0.
// It also sets each of those dimensions' specificity flag to initDimensionsSpecificity.
// NOTE(review): the fills are bounded by m_NumDimensions, not by the numDimensions argument.
// This assumes the hidden initializer list sets m_NumDimensions = numDimensions — confirm.
33 CheckValidNumDimensions(numDimensions);
35 std::fill(m_Dimensions.begin(), m_Dimensions.begin() + m_NumDimensions, 0);
36 std::fill(m_DimensionsSpecificity.begin(), m_DimensionsSpecificity.begin() + m_NumDimensions,
37 initDimensionsSpecificity);
// Constructor body fragment (signature not visible in this excerpt). It appears to be the
// TensorShape constructor that takes numDimensions and a dimensionSizes pointer.
// Validates numDimensions and tests dimensionSizes against nullptr; what the null branch
// does is not visible here.
// It then copies numDimensions sizes from dimensionSizes into m_Dimensions and marks the
// first m_NumDimensions dimensions as specified.
// NOTE(review): the copy length uses numDimensions while the fill uses m_NumDimensions.
// These are presumably equal once the hidden initializer list runs — confirm.
43 CheckValidNumDimensions(numDimensions);
45 if (dimensionSizes ==
nullptr)
// Every dimension size comes from the caller's array, so every dimension is flagged as known.
50 std::copy(dimensionSizes, dimensionSizes + numDimensions, m_Dimensions.begin());
51 std::fill(m_DimensionsSpecificity.begin(), m_DimensionsSpecificity.begin() + m_NumDimensions,
true);
// Fragment of the TensorShape constructor that also takes a per-dimension specificity array.
// The leading parameters and the start of the signature are not visible in this excerpt.
// dimensionSizes / dimensionsSpecificity: input arrays. Each is read for numDimensions
// elements by the std::copy calls below.
60 const unsigned int*
const dimensionSizes,
61 const bool*
const dimensionsSpecificity)
// Validates numDimensions and tests both arrays against nullptr; the handling of a null
// pointer is not visible here.
// It then copies numDimensions entries from each array into the corresponding member.
64 CheckValidNumDimensions(numDimensions);
66 if (dimensionSizes ==
nullptr)
71 if (dimensionsSpecificity ==
nullptr)
76 std::copy(dimensionSizes, dimensionSizes + numDimensions, m_Dimensions.begin());
77 std::copy(dimensionsSpecificity, dimensionsSpecificity + numDimensions, m_DimensionsSpecificity.begin());
// Fragment of the TensorShape constructor taking two initializer lists. The first parameter,
// dimensionSizeList, is used in the body but its declaration is not visible here.
81 std::initializer_list<bool> dimensionsSpecificityList)
// The number of dimensions is taken from the length of the size list.
83 auto numDimensions =
static_cast<unsigned int>(dimensionSizeList.size());
// Compares the lengths of the two lists. What happens on a mismatch is not visible here
// (presumably an exception).
84 if (dimensionsSpecificityList.size() != numDimensions)
// Builds a temporary from the lists' begin() pointers with the pointer-based constructor,
// then assigns it to *this. Validation of numDimensions therefore happens in that
// constructor (presumably the three-argument overload that calls CheckValidNumDimensions).
89 *
this =
TensorShape(numDimensions, dimensionSizeList.begin(), dimensionsSpecificityList.begin());
// Fragment of a TensorShape constructor taking a Dimensionality. The signature and most of
// the switch body are not visible in this excerpt.
93 : m_Dimensionality(dimensionality)
// Dispatches on dimensionality. A visible message fragment mentions tensors "that have an
// unknown number of dimensions or that are scalar"; its surrounding statement is not visible.
// Two other visible statements assign the specificity array. Which case performs which
// assignment is not visible here.
95 switch (dimensionality)
99 "for tensors that have an unknown number of dimensions or that are scalar");
// NOTE(review): m_DimensionsSpecificity may be a std::array<bool, N> (its type is not visible).
// If so, brace-assignment sets element 0 to the listed value and value-initializes the rest
// to false. `= { true }` would then mark only dimension 0 as specified — confirm this is intended.
104 m_DimensionsSpecificity = {
false};
109 m_DimensionsSpecificity = {
true};
// Copy-constructor fragment (signature not visible in this excerpt).
// Copies the dimension count and dimensionality, then only the first other.m_NumDimensions
// sizes and specificity flags.
// Elements past that count are left to whatever default member initialization the class
// declares (not visible here).
117 : m_NumDimensions(other.m_NumDimensions), m_Dimensionality(other.m_Dimensionality)
119 std::copy(other.m_Dimensions.cbegin(), other.m_Dimensions.cbegin() + other.m_NumDimensions, m_Dimensions.begin());
120 std::copy(other.m_DimensionsSpecificity.cbegin(), other.m_DimensionsSpecificity.cbegin() + other.m_NumDimensions,
121 m_DimensionsSpecificity.begin());
// Copy-assignment fragment (signature and return statement not visible in this excerpt).
// Mirrors the copy constructor: copies the count, the dimensionality and the first
// other.m_NumDimensions entries of both arrays.
// Entries beyond that count keep their previous values. This does not affect the
// operator== in this file, which compares only the first m_NumDimensions entries.
// NOTE(review): no self-assignment guard is visible. For `s = s`, each std::copy gets a
// destination equal to the start of its source range. That violates std::copy's
// precondition that d_first must not lie in [first, last).
// Consider an early `if (this == &other) { return *this; }`.
126 m_NumDimensions = other.m_NumDimensions;
127 m_Dimensionality = other.m_Dimensionality;
128 std::copy(other.m_Dimensions.cbegin(), other.m_Dimensions.cbegin() + other.m_NumDimensions, m_Dimensions.begin());
129 std::copy(other.m_DimensionsSpecificity.cbegin(), other.m_DimensionsSpecificity.cbegin() + other.m_NumDimensions,
130 m_DimensionsSpecificity.begin());
// Element-access fragment, presumably the const operator[] (signature not visible here).
// Before returning the size of dimension i, it runs three guards in order:
//   - the number of dimensions must be known;
//   - i must be a valid index;
//   - dimension i must be specified.
// NOTE(review): if m_Dimensions is a std container, .at(i) repeats the range check
// already done by CheckDimensionIndex(i).
137 CheckUnspecifiedNumDimensions();
138 CheckDimensionIndex(i);
139 CheckDimensionSpecified(i);
141 return m_Dimensions.at(i);
// Fragment, presumably of the non-const operator[] (signature not visible here).
// It first prepares an error message saying a Scalar-dimensionality shape must be const to
// use operator[]. The condition guarding that message and the throw are not visible here.
149 std::stringstream errorMessage;
150 errorMessage <<
"TensorShape with Dimensionality::Scalar must be const to use operator[]";
// Runs the known-count, index-range and dimension-specified checks, then returns
// m_Dimensions.at(i).
153 CheckUnspecifiedNumDimensions();
154 CheckDimensionIndex(i);
155 CheckDimensionSpecified(i);
157 return m_Dimensions.at(i);
// Equality fragment (signature not visible here).
// Two shapes are equal when they have the same dimension count and the same dimensionality.
// Their first m_NumDimensions sizes and specificity flags must also match.
// Array entries beyond the dimension count are deliberately ignored.
162 return ((m_NumDimensions == other.m_NumDimensions) &&
163 (m_Dimensionality == other.m_Dimensionality) &&
164 std::equal(m_Dimensions.cbegin(), m_Dimensions.cbegin() + m_NumDimensions, other.m_Dimensions.cbegin()) &&
165 std::equal(m_DimensionsSpecificity.cbegin(), m_DimensionsSpecificity.cbegin() + m_NumDimensions,
166 other.m_DimensionsSpecificity.cbegin()));
// Inequality fragment (signature not visible here): the negation of operator==.
171 return !(*
this == other);
// Fragment, presumably of GetNumDimensions (signature not visible here).
// Returns the dimension count after CheckUnspecifiedNumDimensions guards against an
// unknown count.
176 CheckUnspecifiedNumDimensions();
178 return m_NumDimensions;
// Fragment, presumably of GetNumElements. The signature, the zero-dimension branch body
// and the final return statements are not visible in this excerpt.
183 CheckUnspecifiedNumDimensions();
185 if (m_NumDimensions == 0)
// Multiplies the dimension sizes into count and records whether any dimension was
// specified. The multiplication appears to apply only to specified dimensions, although
// the braces delimiting the if are not visible.
// NOTE(review): count is an unsigned int, so the product silently wraps (modulo 2^N) for
// very large shapes.
190 unsigned int count = 1;
191 bool atLeastOneDimensionSpecified =
false;
192 for (
unsigned int i = 0; i < m_NumDimensions; ++i)
194 if (m_DimensionsSpecificity[i])
196 atLeastOneDimensionSpecified =
true;
197 count *= m_Dimensions[i];
// What is returned in each case of this branch is not visible here.
201 if (atLeastOneDimensionSpecified)
// Fragment, presumably of GetDimensionSpecificity (signature not visible here).
// Returns whether dimension i is specified, after checking that the count is known and i
// is in range.
// Unlike operator[], it does not require the dimension to be specified.
213 CheckUnspecifiedNumDimensions();
214 CheckDimensionIndex(i);
216 return m_DimensionsSpecificity[i];
// Fragment, presumably of SetNumDimensions (signature and some statements not visible here).
// It calls CheckSpecifiedNumDimensions first; judging by that helper's message, this
// rejects the call when the count is already known.
// It then validates the new count and stores it. Finally it resets the first numDimensions
// sizes to 0 and their specificity flags to initDimensionsSpecificity.
222 CheckSpecifiedNumDimensions();
223 CheckValidNumDimensions(numDimensions);
225 m_NumDimensions = numDimensions;
227 std::fill(m_Dimensions.begin(), m_Dimensions.begin() + m_NumDimensions, 0);
228 std::fill(m_DimensionsSpecificity.begin(), m_DimensionsSpecificity.begin() + m_NumDimensions,
229 initDimensionsSpecificity);
// Fragment, presumably of SetDimensionSize (signature not visible here).
// Checks that i is a valid index, stores dimensionSize and marks dimension i as specified.
235 CheckDimensionIndex(i);
237 m_Dimensions[i] = dimensionSize;
238 m_DimensionsSpecificity[i] =
true;
// Fragment, presumably of AreAllDimensionsSpecified (signature not visible here).
// Returns true only if every one of the first m_NumDimensions specificity flags is set.
// It is vacuously true when m_NumDimensions is 0: the flag starts true and the loop does
// not run.
243 CheckUnspecifiedNumDimensions();
245 bool areAllDimensionsSpecified =
true;
246 for (
unsigned int i = 0; i < m_NumDimensions; ++i)
248 if (!m_DimensionsSpecificity[i])
250 areAllDimensionsSpecified =
false;
254 return areAllDimensionsSpecified;
// Fragment, presumably of IsAtLeastOneDimensionSpecified (signature not visible here).
// Returns true if any of the first m_NumDimensions specificity flags is set.
// It returns false when m_NumDimensions is 0 or no flag is set.
259 CheckUnspecifiedNumDimensions();
261 bool isAtLeastOneDimensionSpecified =
false;
262 for (
unsigned int i = 0; i < m_NumDimensions; ++i)
264 if (m_DimensionsSpecificity[i])
266 isAtLeastOneDimensionSpecified =
true;
270 return isAtLeastOneDimensionSpecified;
// Validates a dimension index against the current number of dimensions.
// When i >= m_NumDimensions it builds a message naming the index and the dimension count.
// The statement that raises it is not visible in this excerpt (presumably an
// InvalidArgumentException, as in the other Check* helpers).
273 void TensorShape::CheckDimensionIndex(
unsigned int i)
const
275 if (i >= m_NumDimensions)
277 std::stringstream errorMessage;
278 errorMessage <<
"Invalid dimension index: " << i <<
" (number of dimensions is " << m_NumDimensions <<
")";
// Validates a requested number of dimensions.
// It throws InvalidArgumentException when numDimensions is below 1.
// A second throw carries the upper-bound message. Its condition is not visible here
// (presumably numDimensions > MaxNumOfTensorDimensions), and neither is the rest of that
// statement.
283 void TensorShape::CheckValidNumDimensions(
unsigned int numDimensions)
285 if (numDimensions < 1)
287 throw InvalidArgumentException(
"Tensor numDimensions must be greater than 0",
CHECK_LOCATION());
292 throw InvalidArgumentException(
"Tensor numDimensions must be less than or equal to MaxNumOfTensorDimensions"
// Throws InvalidArgumentException when dimension i has not been specified, i.e. the shape
// has not been inferred yet.
// NOTE(review): this indexes m_DimensionsSpecificity without validating i itself.
// The callers visible in this file (the operator[] bodies) call CheckDimensionIndex(i) first.
297 void TensorShape::CheckDimensionSpecified(
unsigned int i)
const
299 if (!m_DimensionsSpecificity[i])
301 std::stringstream errorMessage;
302 errorMessage <<
"Dimension index: " << i <<
" not specified. Tensor shape not inferred yet.";
303 throw InvalidArgumentException(errorMessage.str(),
CHECK_LOCATION());
// Throws InvalidArgumentException for an action that is not allowed on a scalar shape.
// The guarding condition is not visible in this excerpt (presumably a check for
// Dimensionality::Scalar).
307 void TensorShape::CheckScalar()
const
311 std::stringstream errorMessage;
312 errorMessage <<
"Invalid action on a tensor shape that holds a scalar value.";
313 throw InvalidArgumentException(errorMessage.str(),
CHECK_LOCATION());
// Throws InvalidArgumentException for an action that requires the number of dimensions
// to be known.
// The guarding condition is not visible in this excerpt (presumably a check for an
// unknown-dimensionality state).
317 void TensorShape::CheckUnspecifiedNumDimensions()
const
321 std::stringstream errorMessage;
322 errorMessage <<
"Invalid action on a tensor shape that has unknown number of dimensions.";
323 throw InvalidArgumentException(errorMessage.str(),
CHECK_LOCATION());
// Throws InvalidArgumentException for an action that is only allowed while the number of
// dimensions is still unknown.
// The guarding condition is not visible in this excerpt.
327 void TensorShape::CheckSpecifiedNumDimensions()
const
331 std::stringstream errorMessage;
332 errorMessage <<
"Invalid action on a tensor shape that has known number of dimensions.";
333 throw InvalidArgumentException(errorMessage.str(),
CHECK_LOCATION());
// Fragments of four constructors of a class holding m_Shape / m_DataType / m_IsConstant
// (presumably TensorInfo). The class name, most of each signature and the constructor
// bodies are not visible in this excerpt.
// (1) Overload taking a single quantizationScale and quantizationOffset; its first
//     initializer is not visible.
348 float quantizationScale,
349 int32_t quantizationOffset,
352 , m_DataType(dataType)
353 , m_IsConstant(isConstant)
// (2) Overload that builds m_Shape from numDimensions and a dimensionSizes array, with a
//     single scale and offset.
360 const unsigned int* dimensionSizes,
362 float quantizationScale,
363 int32_t quantizationOffset,
365 : m_Shape(numDimensions, dimensionSizes), m_DataType(dataType), m_IsConstant(isConstant)
// (3) Overload taking a vector of quantizationScales plus a quantizationDim; its first
//     initializer is not visible.
373 const std::vector<float>& quantizationScales,
374 unsigned int quantizationDim,
377 , m_DataType(dataType)
378 , m_IsConstant(isConstant)
// (4) Overload that builds m_Shape from numDimensions/dimensionSizes, with a vector of
//     scales and a quantizationDim.
385 const unsigned int* dimensionSizes,
387 const std::vector<float>& quantizationScales,
388 unsigned int quantizationDim,
390 : m_Shape(numDimensions, dimensionSizes)
391 , m_DataType(dataType)
392 , m_IsConstant(isConstant)
// Copy-constructor fragment (signature not visible here): a member-wise copy of shape,
// data type, constness flag and quantization parameters.
399 : m_Shape(other.m_Shape)
400 , m_DataType(other.m_DataType)
401 , m_IsConstant(other.m_IsConstant)
402 , m_Quantization(other.m_Quantization)
// Copy-assignment fragment (signature and return not visible here): the same member-wise copy.
// NOTE(review): if these four are the only data members, both operations match the
// compiler-generated defaults and could be `= default` — confirm against the class definition.
407 m_Shape = other.m_Shape;
408 m_DataType = other.m_DataType;
409 m_Quantization = other.m_Quantization;
410 m_IsConstant = other.m_IsConstant;
// Equality fragment (signature not visible here).
// Equal when shape, data type, quantization parameters and constness flag all match.
416 return ((m_Shape == other.m_Shape) &&
417 (m_DataType == other.m_DataType) &&
418 (m_Quantization == other.m_Quantization) &&
419 (m_IsConstant == other.m_IsConstant));
// Inequality fragment: the negation of operator==.
424 return !(*
this == other);
// Fragment of a comparison routine that accumulates into `match`. The enclosing function
// is not visible; it is presumably a type-space match check. This step requires equal
// data types.
436 match &= m_DataType == other.m_DataType;
// Quantization accessor fragments (signatures not visible in this excerpt).
// Getter: returns the stored scales vector.
453 return m_Quantization.m_Scales;
// Setter: replaces all scales with the given vector.
458 m_Quantization.m_Scales = scales;
// Single-scale getter: tests for an empty scales vector (that branch is not visible here),
// then returns the first scale.
463 if (m_Quantization.m_Scales.empty())
470 return m_Quantization.m_Scales[0];
// Single-scale setter: replaces the scales with a one-element list.
475 m_Quantization.m_Scales = { scale };
// Offset getter: tests whether an offset is present (that branch is not visible here),
// then returns the stored offset.
480 if (!m_Quantization.m_Offset.has_value())
486 return m_Quantization.m_Offset.value();
// Offset setter: stores the offset wrapped in an optional.
491 m_Quantization.m_Offset = MakeOptional<int32_t>(offset);
// Quantization-dimension getter and setter.
496 return m_Quantization.m_QuantizationDim;
501 m_Quantization.m_QuantizationDim = quantizationDim;
// Member-definition fragments of a class template parameterized on MemoryType
// (presumably BaseTensor). The signatures are not visible in this excerpt.
// Default-constructor fragment: starts with a null memory area.
523 template<
typename MemoryType>
525 : m_MemoryArea(nullptr)
// Constructor fragment: stores the given memoryArea. Whether ownership is taken is not
// visible here.
529 template<
typename MemoryType>
531 : m_MemoryArea(memoryArea)
// Copy-constructor fragment: copies the memory-area value and the tensor info.
// m_MemoryArea appears to be pointer-like (it is initialized from nullptr), so this copies
// the pointer rather than the memory it refers to.
536 template<
typename MemoryType>
538 : m_MemoryArea(other.m_MemoryArea)
539 , m_Info(other.GetInfo())
// Copy-assignment fragment: copies m_Info; any other assignments are not visible here.
// NOTE(review): the copy constructor reads other.GetInfo() while this reads other.m_Info.
// They are presumably equivalent, but one form should be used consistently.
543 template<
typename MemoryType>
546 m_Info = other.m_Info;