4189{
4190 const std::string descriptorName{"BatchMatMulDescriptor"};
4191
4192 ValidateNumInputs(workloadInfo, descriptorName, 2);
4193 ValidateNumOutputs(workloadInfo, descriptorName, 1);
4194
4195
4196
4197
4198
4202
4203
4204 std::vector<DataType> supportedTypes =
4205 {
4212 };
4213
4214 ValidateDataTypes(inputXInfoBeforeParams, supportedTypes, descriptorName);
4215 ValidateDataTypes(inputYInfoBeforeParams, supportedTypes, descriptorName);
4216 ValidateDataTypes(outputInfo, supportedTypes, descriptorName);
4217
4218 if ((inputXInfoBeforeParams.GetNumDimensions() < 2) ||
4219 (inputYInfoBeforeParams.GetNumDimensions() < 2))
4220 {
4221 throw InvalidArgumentException(descriptorName + ": Input tensors are not 2D or greater.");
4222 }
4223
4224 TensorInfo inputXInfoAfterParams;
4225 TensorInfo inputYInfoAfterParams;
4226
4229 {
4230 throw InvalidArgumentException(descriptorName +
4231 ": Invalid descriptor parameters - Transpose and Adjoint "
4232 "cannot both be true for a given input tensor.");
4233 }
4235 {
4239 inputXInfoBeforeParams.GetShape()));
4240 }
4242 {
4244 inputXInfoBeforeParams.GetShape());
4245 if(inputXInfoBeforeParams.GetShape()[axesToMul.first] !=
4246 inputXInfoBeforeParams.GetShape()[axesToMul.second])
4247 {
4248 throw InvalidArgumentException(descriptorName +
4249 ": Adjoint is set to true for input tensor X, but the axes to be adjointed are not square." );
4250 }
4251
4252 inputXInfoAfterParams = inputXInfoBeforeParams;
4253 }
4254 else
4255 {
4256 inputXInfoAfterParams = inputXInfoBeforeParams;
4257 }
4258
4260 {
4264 inputYInfoBeforeParams.GetShape()));
4265 }
4267 {
4269 inputYInfoBeforeParams.GetShape());
4270 if(inputYInfoBeforeParams.GetShape()[axesToMul.first] !=
4271 inputYInfoBeforeParams.GetShape()[axesToMul.second])
4272 {
4273 throw InvalidArgumentException(descriptorName +
4274 ": Adjoint is set to true for input tensor Y, but the axes to be adjointed are not square." );
4275 }
4276
4277 inputYInfoAfterParams = inputYInfoBeforeParams;
4278 }
4279 else
4280 {
4281 inputYInfoAfterParams = inputYInfoBeforeParams;
4282 }
4283
4285 {
4289 {
4290 throw InvalidArgumentException(descriptorName +
4291 ": Input tensor X does not have the correct "
4292 "number of dimensions for the Data Layout that it has been assigned.");
4293 }
4294 break;
4297 default:
4298 break;
4299 }
4300
4302 {
4306 {
4307 throw InvalidArgumentException(descriptorName +
4308 ": Input tensor Y does not have the correct "
4309 "number of dimensions for the Data Layout that it has been assigned.");
4310 }
4311 break;
4314 default:
4315 break;
4316 }
4317
4321 inputYInfoBeforeParams.GetShape());
4322
4323 if(inputXInfoAfterParams.
GetShape()[axesXToMul.second]
4324 != inputYInfoAfterParams.
GetShape()[axesYToMul.first])
4325 {
4326 throw InvalidArgumentException(descriptorName +
4327 ": The final axis of input tensor X must be the same size as "
4328 "the second last axis of input tensor Y.");
4329 }
4330
4331 {
4332
4335
4337 {
4339 {
4340 throw InvalidArgumentException(descriptorName +
4341 ": Invalid input tensor data layout combination.");
4342 }
4343 }
4345 {
4347 {
4348 throw InvalidArgumentException(descriptorName +
4349 ": Invalid input tensor data layout combination.");
4350 }
4351 }
4352 }
4353
4354
4355 unsigned int outputTensorDimSize = std::max(inputXInfoAfterParams.
GetNumDimensions(),
4357 if(outputTensorDimSize-2 > 0)
4358 {
4359 TensorInfo tiXNotMul = TensorInfo(TensorShape(outputTensorDimSize-2),
4361 TensorInfo tiYNotMul = TensorInfo(TensorShape(outputTensorDimSize-2),
4363 TensorInfo tiOutNotMul = TensorInfo(TensorShape(outputTensorDimSize-2),
4365
4366 auto doAxisExtension = [&](std::vector<unsigned int> axisIndices, TensorInfo& ti)
4367 {
4368 auto sizeDiff = (outputTensorDimSize-2) - axisIndices.size();
4369
4370 for(unsigned int i = 0; i < sizeDiff; i++)
4371 {
4372 axisIndices.insert(axisIndices.begin(), 1);
4373 }
4374
4375 for(unsigned int i = 0; i < ti.GetNumDimensions(); i++)
4376 {
4377 ti.GetShape()[i] = inputXInfoAfterParams.
GetShape()[i];
4378 }
4379 };
4380
4385
4386 doAxisExtension(axesXNotMul, tiXNotMul);
4387 doAxisExtension(axesYNotMul, tiYNotMul);
4388
4390 {
4393 }
4394
4395 ValidateBroadcastTensorShapesMatch(tiXNotMul,
4396 tiYNotMul,
4397 tiOutNotMul,
4398 descriptorName,
4399 "input_X",
4400 "input_Y");
4401 }
4402}
const TensorShape & GetShape() const
unsigned int GetNumDimensions() const
armnn::TensorShape Permuted(const armnn::TensorShape &srcShape, const armnn::PermutationVector &mappings)
bool m_AdjointX
Adjoint the slices of each input tensor Transpose and Adjoint can not both be set to true for the sam...
static std::pair< unsigned int, unsigned int > GetAxesToMul(DataLayout dataLayout, const TensorShape &tensorShape)
Static helper to get the two axes (for each input) for multiplication.
static PermutationVector GetPermuteVec(DataLayout dataLayout, const TensorShape &tensorShape)
Static helper to get the axes which will be transposed.
static std::vector< unsigned int > GetAxesNotMul(DataLayout dataLayout, const TensorShape &tensorShape)
Static helper to get the axes (for each input) that will not be multiplied together.
bool m_TransposeX
Transpose the slices of each input tensor Transpose and Adjoint can not both be set to true for the s...
DataLayout m_DataLayoutX
Data layout of each input tensor, such as NHWC/NDHWC (leave as default for arbitrary layout)
BatchMatMulDescriptor m_Parameters
std::vector< TensorInfo > m_OutputTensorInfos
std::vector< TensorInfo > m_InputTensorInfos