10#ifndef EIGEN_CXX11_TENSOR_TENSOR_CONTRACTION_BLOCKING_H
11#define EIGEN_CXX11_TENSOR_TENSOR_CONTRACTION_BLOCKING_H
24template<
typename ResScalar,
typename LhsScalar,
typename RhsScalar,
typename StorageIndex,
int ShardingType = ShardByCol>
25class TensorContractionBlocking {
42 #if !defined(EIGEN_HIPCC)
45 TensorContractionBlocking(StorageIndex k, StorageIndex m, StorageIndex n, StorageIndex num_threads = 1) :
46 kc_(k), mc_(m), nc_(n)
48 if (ShardingType == ShardByCol) {
49 computeProductBlockingSizes<LhsScalar, RhsScalar, 1>(kc_, mc_, nc_, num_threads);
52 computeProductBlockingSizes<LhsScalar, RhsScalar, 1>(kc_, nc_, mc_, num_threads);
55 const int rhs_packet_size = internal::packet_traits<RhsScalar>::size;
56 kc_ = (rhs_packet_size <= 8 || kc_ <= rhs_packet_size) ?
57 kc_ : (kc_ / rhs_packet_size) * rhs_packet_size;
60 EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE StorageIndex kc()
const {
return kc_; }
61 EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE StorageIndex mc()
const {
return mc_; }
62 EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE StorageIndex nc()
const {
return nc_; }
Namespace containing all symbols from the Eigen library.