Please, help us to better know about our user community by answering the following short survey: https://forms.gle/wpyrxWi18ox9Z5ae9
 
Loading...
Searching...
No Matches
TensorContractionThreadPool.h
1// This file is part of Eigen, a lightweight C++ template library
2// for linear algebra.
3//
4// Copyright (C) 2014 Benoit Steiner <benoit.steiner.goog@gmail.com>
5//
6// This Source Code Form is subject to the terms of the Mozilla
7// Public License v. 2.0. If a copy of the MPL was not distributed
8// with this file, You can obtain one at http://mozilla.org/MPL/2.0/.
9
10#ifndef EIGEN_CXX11_TENSOR_TENSOR_CONTRACTION_THREAD_POOL_H
11#define EIGEN_CXX11_TENSOR_TENSOR_CONTRACTION_THREAD_POOL_H
12
13// evaluator for thread pool device
14#ifdef EIGEN_USE_THREADS
15
16namespace Eigen {
17
18template<typename Indices, typename LeftArgType, typename RightArgType, typename OutputKernelType>
19struct TensorEvaluator<const TensorContractionOp<Indices, LeftArgType, RightArgType, OutputKernelType>, ThreadPoolDevice> :
20 public TensorContractionEvaluatorBase<TensorEvaluator<const TensorContractionOp<Indices, LeftArgType, RightArgType, OutputKernelType>, ThreadPoolDevice> > {
21
22 typedef ThreadPoolDevice Device;
23
24 typedef TensorEvaluator<const TensorContractionOp<Indices, LeftArgType, RightArgType, OutputKernelType>, Device> Self;
25 typedef TensorContractionEvaluatorBase<Self> Base;
26
27 typedef TensorContractionOp<Indices, LeftArgType, RightArgType, OutputKernelType> XprType;
28 typedef typename internal::remove_const<typename XprType::Scalar>::type Scalar;
29 typedef typename XprType::Index Index;
30 typedef typename XprType::CoeffReturnType CoeffReturnType;
31 typedef typename PacketType<CoeffReturnType, Device>::type PacketReturnType;
32
33 enum {
34 Layout = TensorEvaluator<LeftArgType, Device>::Layout,
35 };
36
37 // Most of the code is assuming that both input tensors are ColMajor. If the
38 // inputs are RowMajor, we will "cheat" by swapping the LHS and RHS:
39 // If we want to compute A * B = C, where A is LHS and B is RHS, the code
40 // will pretend B is LHS and A is RHS.
41 typedef typename internal::conditional<
42 static_cast<int>(Layout) == static_cast<int>(ColMajor), LeftArgType, RightArgType>::type EvalLeftArgType;
43 typedef typename internal::conditional<
44 static_cast<int>(Layout) == static_cast<int>(ColMajor), RightArgType, LeftArgType>::type EvalRightArgType;
45
46 static const int LDims =
47 internal::array_size<typename TensorEvaluator<EvalLeftArgType, Device>::Dimensions>::value;
48 static const int RDims =
49 internal::array_size<typename TensorEvaluator<EvalRightArgType, Device>::Dimensions>::value;
50 static const int ContractDims = internal::array_size<Indices>::value;
51
52 typedef array<Index, LDims> left_dim_mapper_t;
53 typedef array<Index, RDims> right_dim_mapper_t;
54
55 typedef array<Index, ContractDims> contract_t;
56 typedef array<Index, LDims - ContractDims> left_nocontract_t;
57 typedef array<Index, RDims - ContractDims> right_nocontract_t;
58
59 static const int NumDims = LDims + RDims - 2 * ContractDims;
60
61 typedef DSizes<Index, NumDims> Dimensions;
62
63 // typedefs needed in evalTo
64 typedef typename internal::remove_const<typename EvalLeftArgType::Scalar>::type LhsScalar;
65 typedef typename internal::remove_const<typename EvalRightArgType::Scalar>::type RhsScalar;
66 typedef typename internal::gebp_traits<LhsScalar, RhsScalar> Traits;
67
68 typedef TensorEvaluator<EvalLeftArgType, Device> LeftEvaluator;
69 typedef TensorEvaluator<EvalRightArgType, Device> RightEvaluator;
70
71 TensorEvaluator(const XprType& op, const Device& device) :
72 Base(op, device) {}
73
74 template <int Alignment>
75 void evalProduct(Scalar* buffer) const {
76 evalProductImpl<NoCallback, Alignment>(buffer, NoCallback());
77 }
78
79 template <typename EvalToCallback, int Alignment>
80 void evalProductAsync(Scalar* buffer, EvalToCallback done) const {
81 evalProductImpl<EvalToCallback, Alignment>(buffer, std::move(done));
82 }
83
84 template <typename DoneCallback, int Alignment>
85 void evalProductImpl(Scalar* buffer, DoneCallback done) const {
86 // This function computes a lot of heuristics in multiple steps, and it
87 // also has multiple exit points. To keep it sane, readable and all in one
88 // place, sync/async execution decision is made at runtime at the very end.
89 //
90 // (1) In sync mode we allocate Context on the stack, submit computations
91 // to the device thread pool, and block on a barrier until it is
92 // completed.
93 //
94 // (2) In async mode we allocate Context on the heap, and after all tasks
95 // are finished, we call provided the done callback, and delete a
96 // context from the heap.
97 //
98 // (*) EvalParallelContext & EvalShardedByInnerDimContext owns all the state
99 // and temporary buffers, requried for executing the tensor contraction.
100 // They are responsible for cleaning it up after contraction is done.
101 static const bool IsEvalInSyncMode =
102 std::is_same<DoneCallback, NoCallback>::value;
103
104 const Index m = this->m_i_size;
105 const Index n = this->m_j_size;
106 const Index k = this->m_k_size;
107 if (m == 0 || n == 0 || k == 0) return;
108
109 // Compute a set of algorithm parameters:
110 // - kernel block sizes (bm, bn, bk)
111 // - task grain sizes (number of kernels executed per task: gm, gn)
112 // - number of threads
113 // - sharding by row/column
114 // - parallel packing or first lhs then rhs
115 // and some derived parameters:
116 // - number of tasks (nm, nn, nk)
117 // - number of kernels (nm0, nn0)
118 // Unfortunately, all these parameters are tightly interdependent.
119 // So in some cases we first compute approximate values, then compute other
120 // values based on these approximations and then refine the approximations.
121
122 // There are lots of heuristics here. There is some reasoning behind them,
123 // but ultimately they are just tuned on contraction benchmarks for
124 // different input configurations, thread counts and instruction sets.
125 // So feel free to question any of them.
126
127 // Compute whether we want to shard by row or by column.
128 // This is a first approximation, it will be refined later. Since we don't
129 // know number of threads yet we use 2, because what's we are most
130 // interested in at this point is whether it makes sense to use
131 // parallelization at all or not.
132 bool shard_by_col = shardByCol(m, n, 2);
133
134 // First approximation of kernel blocking sizes.
135 // Again, we don't know number of threads yet, so we use 2.
136 Index bm, bn, bk;
137 if (shard_by_col) {
138 internal::TensorContractionBlocking<Scalar, LhsScalar, RhsScalar, Index,
139 internal::ShardByCol>
140 blocking(k, m, n, 2);
141 bm = blocking.mc();
142 bn = blocking.nc();
143 bk = blocking.kc();
144 } else {
145 internal::TensorContractionBlocking<Scalar, LhsScalar, RhsScalar, Index,
146 internal::ShardByRow>
147 blocking(k, m, n, 2);
148 bm = blocking.mc();
149 bn = blocking.nc();
150 bk = blocking.kc();
151 }
152
153 // Compute optimal number of threads.
154 // Note: we use bk instead of k here because we are interested in amount of
155 // _parallelizable_ computations, and computations are not parallelizable
156 // across k dimension.
157 const TensorOpCost cost =
158 contractionCost(m, n, bm, bn, bk, shard_by_col, false);
159 int num_threads = TensorCostModel<ThreadPoolDevice>::numThreads(
160 static_cast<double>(n) * m, cost, this->m_device.numThreads());
161 int num_threads_by_k = numThreadsInnerDim(m, n, k);
162 if (shardByInnerDim(m, n, k, num_threads, num_threads_by_k)) {
163 // We are in the scenario where it is more effective to shard by the
164 // inner dimension.
165 if (IsEvalInSyncMode) {
166 EvalShardedByInnerDimContext<DoneCallback> ctx(
167 this, num_threads_by_k, buffer, m, n, k, std::move(done));
168 ctx.template run<Alignment>();
169 } else {
170 auto* ctx = new EvalShardedByInnerDimContext<DoneCallback>(
171 this, num_threads_by_k, buffer, m, n, k, std::move(done));
172 ctx->template runAsync<Alignment>();
173 }
174
175 return;
176 }
177
178 // TODO(dvyukov): this is a stop-gap to prevent regressions while the cost
179 // model is not tuned. Remove this when the cost model is tuned.
180 if (n == 1) num_threads = 1;
181
182 if (num_threads == 1) {
183 TENSOR_CONTRACTION_DISPATCH(this->template evalProductSequential,
184 Unaligned, (buffer));
185 if (!IsEvalInSyncMode) done();
186 return;
187 }
188
189 // Now that we know number of threads, recalculate sharding and blocking.
190 shard_by_col = shardByCol(m, n, num_threads);
191 if (shard_by_col) {
192 internal::TensorContractionBlocking<Scalar, LhsScalar, RhsScalar, Index,
193 internal::ShardByCol>
194 blocking(k, m, n, num_threads);
195 bm = blocking.mc();
196 bn = blocking.nc();
197 bk = blocking.kc();
198 } else {
199 internal::TensorContractionBlocking<Scalar, LhsScalar, RhsScalar, Index,
200 internal::ShardByRow>
201 blocking(k, m, n, num_threads);
202 bm = blocking.mc();
203 bn = blocking.nc();
204 bk = blocking.kc();
205 }
206
207 // Number of kernels for each dimension.
208 Index nm0 = divup(m, bm);
209 Index nn0 = divup(n, bn);
210 Index nk = divup(k, bk);
211
212 // Calculate task grain size (number of kernels executed per task).
213 // This task size coarsening serves two purposes:
214 // 1. It reduces per-task overheads including synchronization overheads.
215 // 2. It allows to use caches better (reuse the same packed rhs in several
216 // consecutive kernels).
217 Index gm = 1;
218 Index gn = 1;
219 // If we are sharding by column, then we prefer to reduce rows first.
220 if (shard_by_col) {
221 gm = coarsenM(m, n, bm, bn, bk, gn, num_threads, shard_by_col);
222 gn = coarsenN(m, n, bm, bn, bk, gm, num_threads, shard_by_col);
223 } else {
224 gn = coarsenN(m, n, bm, bn, bk, gm, num_threads, shard_by_col);
225 gm = coarsenM(m, n, bm, bn, bk, gn, num_threads, shard_by_col);
226 }
227 // Number of tasks in each dimension.
228 Index nm = divup(nm0, gm);
229 Index nn = divup(nn0, gn);
230
231 // If there is enough concurrency in the sharding dimension, we choose not
232 // to paralellize by the other dimension, and execute all kernels in sync
233 // mode. This reduces parallelism from the nm x nn down to nn
234 // (shard_by_col==true) or nm (shard_by_col==false).
235 const Index sharding_dim_tasks = shard_by_col ? nn : nm;
236 const int num_worker_threads = this->m_device.numThreadsInPool();
237
238 // With small number of threads we want to make sure that we do not reduce
239 // parallelism too much. With large number of threads we trade maximum
240 // parallelism for better memory locality.
241 const float oversharding_factor =
242 num_worker_threads <= 4 ? 8.0 :
243 num_worker_threads <= 8 ? 4.0 :
244 num_worker_threads <= 16 ? 2.0 :
245 num_worker_threads <= 32 ? 1.0 :
246 num_worker_threads <= 64 ? 0.8 : /* num_worker_threads > 64 */ 0.6;
247
248 const bool parallelize_by_sharding_dim_only =
249 sharding_dim_tasks >= oversharding_factor * num_worker_threads;
250
251 // Last by not least, decide whether we want to issue both lhs and rhs
252 // packing in parallel; or issue lhs packing first, and then issue rhs
253 // packing when lhs packing completes (for !shard_by_col lhs and rhs are
254 // swapped). Parallel packing allows more parallelism (for both packing and
255 // kernels), while sequential packing provides better locality (once
256 // a thread finishes rhs packing it proceed to kernels with that rhs).
257 // First, we are interested in parallel packing if there are few tasks.
258 bool parallel_pack = num_threads >= nm * nn;
259 // Also do parallel packing if all data fits into L2$.
260 if (m * bk * Index(sizeof(LhsScalar)) + n * bk * Index(sizeof(RhsScalar)) <=
261 l2CacheSize() * num_threads)
262 parallel_pack = true;
263 // But don't do it if we will use each rhs only once. Locality seems to be
264 // more important in this case.
265 if ((shard_by_col ? nm : nn) == 1) parallel_pack = false;
266 // Also don't get in the way of parallelize_by_sharding_dim_only
267 // optimization.
268 if (parallelize_by_sharding_dim_only) parallel_pack = false;
269
270 // TODO(ezhulnev): With if contexpr we don't need SyncEvalParallelContext.
271 if (IsEvalInSyncMode) {
272#define CONTEXT_ARGS \
273 (this, num_threads, buffer, m, n, k, bm, bn, bk, nm, nn, nk, gm, gn, nm0, \
274 nn0, shard_by_col, parallel_pack, parallelize_by_sharding_dim_only, \
275 NoCallback()) \
276 .run()
277 TENSOR_CONTRACTION_DISPATCH(SyncEvalParallelContext, Alignment,
278 CONTEXT_ARGS);
279#undef CONTEXT_ARGS
280
281 } else {
282#define CONTEXT_ARGS \
283 (this, num_threads, buffer, m, n, k, bm, bn, bk, nm, nn, nk, gm, gn, nm0, \
284 nn0, shard_by_col, parallel_pack, parallelize_by_sharding_dim_only, \
285 std::move(done))
286 TENSOR_CONTRACTION_ASYNC_DISPATCH(EvalParallelContext, DoneCallback,
287 Alignment, CONTEXT_ARGS, run());
288#undef CONTEXT_ARGS
289 }
290 }
291
292 // ------------------------------------------------------------------------ //
293
294 // Dummy struct to represent an empty DoneCallback.
295
296 struct NoCallback {
297 void operator()() {
298 eigen_assert(false && "NoCallback should never be called");
299 }
300 };
301
302 // ------------------------------------------------------------------------ //
303
304 template <typename DoneCallback, typename Context>
305 class EvalParallelNotification;
306
307 // Synchronous evaluation notification that blocks caller thread in Wait().
308 template <typename Context>
309 class EvalParallelNotification<NoCallback, Context> {
310 public:
311 EvalParallelNotification(Context*, NoCallback) {}
312 void Notify() { done_.Notify(); }
313 void Wait() { done_.Wait(); }
314 private:
315 Eigen::Notification done_;
316 };
317
// Asynchronous flavor: Wait() returns immediately. Notify() deletes the
// context and then runs the completion callback.
template <typename DoneCallback, typename Context>
class EvalParallelNotification {
  Context* context_;
  DoneCallback callback_;

 public:
  EvalParallelNotification(Context* ctx, DoneCallback done)
      : context_(ctx), callback_(std::move(done)) {}

  // Non-blocking by design.
  void Wait() {}

  void Notify() {
    // This notification is a data member of the context, so deleting the
    // context destroys `callback_` as well. Move the callback onto the stack
    // first so it outlives the context.
    DoneCallback detached = std::move(callback_);

    // Tear down the parallel evaluation context.
    delete context_;

    // The detached copy is still valid and can be invoked safely.
    detached();
  }
};
344
345 // Context orchestrates sync/async parallel contraction evaluation. When it is
346 // executed in asynchronous mode, it owns all the shared state that might be
347 // accessible by block packing and kernel tasks.
348
349 template <typename DoneCallback, bool lhs_inner_dim_contiguous,
350 bool rhs_inner_dim_contiguous, bool rhs_inner_dim_reordered,
351 int Alignment>
352 class EvalParallelContext {
353 public:
354 typedef internal::TensorContractionInputMapper<
355 LhsScalar, Index, internal::Lhs, LeftEvaluator, left_nocontract_t,
356 contract_t, internal::packet_traits<LhsScalar>::size,
357 lhs_inner_dim_contiguous, false, Unaligned>
358 LhsMapper;
359 typedef internal::TensorContractionInputMapper<
360 RhsScalar, Index, internal::Rhs, RightEvaluator, right_nocontract_t,
361 contract_t, internal::packet_traits<RhsScalar>::size,
362 rhs_inner_dim_contiguous, rhs_inner_dim_reordered, Unaligned>
363 RhsMapper;
364
365 typedef internal::blas_data_mapper<Scalar, Index, ColMajor> OutputMapper;
366
367 typedef internal::TensorContractionKernel<
368 Scalar, LhsScalar, RhsScalar, Index, OutputMapper, LhsMapper, RhsMapper>
369 TensorContractionKernel;
370
371 typedef typename TensorContractionKernel::LhsBlock LhsBlock;
372 typedef typename TensorContractionKernel::RhsBlock RhsBlock;
373 typedef typename TensorContractionKernel::BlockMemHandle BlockMemHandle;
374
375 EvalParallelContext(const Self* self, int num_threads, Scalar* buffer,
376 Index tm, Index tn, Index tk, Index bm, Index bn,
377 Index bk, Index nm, Index nn, Index nk, Index gm,
378 Index gn, Index nm0, Index nn0, bool shard_by_col,
379 bool parallel_pack,
380 bool parallelize_by_sharding_dim_only,
381 DoneCallback done)
382 : created_by_thread_id_(std::this_thread::get_id()),
383 done_(this, std::move(done)),
384 device_(self->m_device),
385 lhs_(self->m_leftImpl, self->m_left_nocontract_strides,
386 self->m_i_strides, self->m_left_contracting_strides,
387 self->m_k_strides),
388 rhs_(self->m_rightImpl, self->m_right_nocontract_strides,
389 self->m_j_strides, self->m_right_contracting_strides,
390 self->m_k_strides),
391 buffer_(buffer),
392 output_(buffer, tm),
393 output_kernel_(self->m_output_kernel),
394 tensor_contraction_params_(self->m_tensor_contraction_params),
395 num_threads_(num_threads),
396 shard_by_col_(shard_by_col),
397 parallel_pack_(parallel_pack),
398 parallelize_by_sharding_dim_only_(parallelize_by_sharding_dim_only),
399 m_(tm),
400 n_(tn),
401 k_(tk),
402 bm_(bm),
403 bn_(bn),
404 bk_(bk),
405 nm_(nm),
406 nn_(nn),
407 nk_(nk),
408 gm_(gm),
409 gn_(gn),
410 nm0_(nm0),
411 nn0_(nn0),
412 kernel_(m_, k_, n_, bm_, bk_, bn_),
413 num_thread_local_allocations_(0),
414 // We reserve 2X more capacity for a thread local values, than the
415 // number of threads in the pool to efficiently handle task stealing
416 // by threads that are not managed by the pool.
417 thread_local_capacity(2 * (parallelize_by_sharding_dim_only_
418 ? device_.numThreadsInPool()
419 : 0)),
420 // We will use only one of the Lhs/Rhs thread local storage depending
421 // on the shard_by_col value and we parallelize by sharding dim ONLY.
422 lhs_thread_local_blocks_(shard_by_col_ ? 0 : thread_local_capacity,
423 {*this}, {*this}),
424 rhs_thread_local_blocks_(shard_by_col_ ? thread_local_capacity : 0,
425 {*this}, {*this}) {
426 // These two options are mutually exclusive.
427 eigen_assert(!(parallel_pack && parallelize_by_sharding_dim_only));
428
429 for (Index x = 0; x < P; x++) {
430 // Normal number of notifications for k slice switch is
431 // nm_ + nn_ + nm_ * nn_. However, first P - 1 slices will receive only
432 // nm_ + nn_ notifications, because they will not receive notifications
433 // from preceding kernels.
434 state_switch_[x] =
435 x == 0
436 ? 1
437 : (parallel_pack_ ? nn_ + nm_ : (shard_by_col_ ? nn_ : nm_)) +
438 (x == P - 1 ? nm_ * nn_ : 0);
439 state_packing_ready_[x] =
440 parallel_pack_ ? 0 : (shard_by_col_ ? nm_ : nn_);
441 state_kernel_[x] = new std::atomic<uint8_t>*[nm_];
442 for (Index m = 0; m < nm_; m++) {
443 state_kernel_[x][m] = new std::atomic<uint8_t>[nn_];
444 // Kernels generally receive 3 notifications (previous kernel + 2
445 // packing), but the first slice won't get notifications from previous
446 // kernels.
447 for (Index n = 0; n < nn_; n++)
448 state_kernel_[x][m][n].store(
449 (x == 0 ? 0 : 1) + (parallel_pack_ ? 2 : 1),
450 std::memory_order_relaxed);
451 }
452 }
453
454 // Allocate memory for packed rhs/lhs matrices.
455 packed_mem_ = kernel_.allocateSlices( //
456 device_, //
457 /*num_lhs=*/nm0_, //
458 /*num_rhs=*/nn0_, //
459 /*num_slices=*/std::min<Index>(nk_, P - 1), //
460 packed_lhs_, packed_rhs_);
461
462 if (parallelize_by_sharding_dim_only_) {
463 const int num_worker_threads = device_.numThreadsInPool();
464
465 if (shard_by_col) {
466 can_use_thread_local_packed_ = new std::atomic<bool>[nn_];
467 for (int i = 0; i < nn_; ++i)
468 can_use_thread_local_packed_[i].store(true,
469 std::memory_order_relaxed);
470
471 Index num_blocks = num_worker_threads * gn_;
472 thread_local_pre_alocated_mem_ = kernel_.allocateSlices( //
473 device_, //
474 /*num_lhs=*/0, //
475 /*num_rhs=*/num_blocks, //
476 /*num_slices=*/1, //
477 /*lhs_blocks=*/nullptr, &rhs_thread_local_pre_allocated_);
478
479 } else {
480 can_use_thread_local_packed_ = new std::atomic<bool>[nm_];
481 for (int i = 0; i < nm_; ++i)
482 can_use_thread_local_packed_[i].store(true,
483 std::memory_order_relaxed);
484
485 Index num_blocks = num_worker_threads * gm_;
486 thread_local_pre_alocated_mem_ = kernel_.allocateSlices( //
487 device_, //
488 /*num_lhs=*/num_blocks, //
489 /*num_rhs=*/0, //
490 /*num_slices=*/1, &lhs_thread_local_pre_allocated_, //
491 /*rhs_blocks=*/nullptr);
492 }
493 }
494 }
495
496 ~EvalParallelContext() {
497 for (Index x = 0; x < P; x++) {
498 for (Index m = 0; m < nm_; m++) delete[] state_kernel_[x][m];
499 delete[] state_kernel_[x];
500 }
501 kernel_.deallocate(device_, packed_mem_);
502 if (parallelize_by_sharding_dim_only_) {
503 kernel_.deallocate(device_, thread_local_pre_alocated_mem_);
504 delete[] can_use_thread_local_packed_;
505 }
506 }
507
508 void run() {
509 // Kick off packing of the first slice.
510 signal_switch(0, 1);
511
512 // Wait for overall completion.
513 //
514 // If parallel evaluation is executed in async mode, this is a no-op, and
515 // Wait() will return immediately. In synchronous mode it will block the
516 // caller thread until it will receive notification from last task.
517 //
518 // In async mode, last task when completed will call done callback from
519 // the same thread, and will delete this context.
520 //
521 // TODO(dvyukov): This wait can lead to deadlock if contraction is
522 // evaluated in synchronous mode. If nthreads contractions are
523 // concurrently submitted from worker threads, this wait will block all
524 // worker threads and the system will deadlock.
525 done_.Wait();
526 }
527
528 private:
529 std::thread::id created_by_thread_id_;
530
531 // This notification is specialized on the type of DoneCallback and can be
532 // blocking or non-blocking.
533 EvalParallelNotification<DoneCallback, EvalParallelContext> done_;
534
535 const Device& device_;
536 LhsMapper lhs_;
537 RhsMapper rhs_;
538 Scalar* const buffer_;
539 OutputMapper output_;
540 OutputKernelType output_kernel_;
541 TensorContractionParams tensor_contraction_params_;
542 const int num_threads_;
543 const bool shard_by_col_;
544 const bool parallel_pack_;
545 const bool parallelize_by_sharding_dim_only_;
546 // Matrix sizes.
547 const Index m_;
548 const Index n_;
549 const Index k_;
550 // Block sizes.
551 const Index bm_;
552 const Index bn_;
553 const Index bk_;
554 // Number of tasks.
555 const Index nm_;
556 const Index nn_;
557 const Index nk_;
558 // Task grain sizes (number of kernels executed per task).
559 const Index gm_;
560 const Index gn_;
// Number of blocks (this is different from nm_/nn_ because of task size
// coarsening).
563 const Index nm0_;
564 const Index nn0_;
565 // Tensor contraction kernel.
566 TensorContractionKernel kernel_;
567
568 // Parallelization strategy.
569 //
570 // Blocks related to the same k block can run in parallel because they write
571 // to different output blocks. So we parallelize within k slices, this
572 // gives us parallelism level of m x n. Before we can start any kernels
573 // related to k-th slice, we need to issue m lhs packing tasks and n rhs
574 // packing tasks.
575 //
576 // However, there is a bottleneck when we are finishing kernels for k-th
577 // slice (at the very end there is only 1 runnable kernel). To mitigate this
578 // bottleneck we allow kernels from k-th and k+1-th slices to run in
579 // parallel. Note that (m, n, k) and (m, n, k+1) kernels write to the same
580 // output block, so they must not run in parallel.
581 //
582 // This gives us the following dependency graph.
// On each k slice we have m x n kernel tasks, m lhs packing tasks and n rhs
// packing tasks.
585 // Kernel (m, n, k) can start when:
586 // - kernel (m, n, k-1) has finished
587 // - lhs packing (m, k) has finished
588 // - rhs packing (n, k) has finished
589 // Lhs/rhs packing can start when:
590 // - all k-1 packing has finished (artificially imposed to limit amount of
591 // parallel packing)
592 //
593 // On top of that we limit runnable tasks to two consecutive k slices.
594 // This is done to limit amount of memory we need for packed lhs/rhs
595 // (for each k slice we need m*bk + n*bk memory in packed_lhs_/packed_rhs_).
596 //
// state_switch_ tracks when we are ready to switch to the next k slice.
// state_kernel_[m][n] tracks when we are ready to kick off kernel (m, n).
// These variables roll over 3 consecutive k slices: the first two are the
// ones we are actively executing, plus one to track completion of kernels in
// the second slice.
602 static const Index P = 3;
603
604 // Handle to the allocated temporary storage for Lhs/Rhs blocks.
605 BlockMemHandle packed_mem_;
606 std::vector<LhsBlock> packed_lhs_[P - 1];
607 std::vector<RhsBlock> packed_rhs_[P - 1];
608
// If we choose to parallelize only by the sharding dimension, each thread
// will have its own "thread local" (not a C++ thread local storage) memory
// for packed_lhs or packed_rhs (shard_by_col = false or true, respectively).
// This memory can't be passed to a kernel that might execute on a different
// thread.
613 //
614 // In practice when we are ready to pack memory for the sharding dimension
615 // (rhs if shard_by_col==true) of the K-th slice, all kernels for K-1 slice
616 // already computed (99% of the time), and we can pack data into the thread
617 // local storage, and guarantee that all the kernels will be executed
618 // immediately in the same thread. This significantly increases L1 cache hit
619 // ratio and reduces pressure on the memory bus.
620 //
621 // It's still possible that kernel for the K-th slice will be ready before
622 // completion of the K-1 kernel, so we have to allocate "global" packed_lhs_
623 // and packed_rhs_ to allow kernels to be executed later on a thread
624 // different from the thread that was used for packing.
625
626 // Handle for pre-allocated thread local memory buffers.
627 BlockMemHandle thread_local_pre_alocated_mem_;
628
629 // Only one of these will be initialized depending on shard_by_col value
630 // (the size will be `num_worker_threads * num_grains_in_the_sharding_dim`).
631 std::vector<LhsBlock> lhs_thread_local_pre_allocated_;
632 std::vector<RhsBlock> rhs_thread_local_pre_allocated_;
633
634 // How many thread local blocks were already allocated.
635 std::atomic<int> num_thread_local_allocations_;
636 const int thread_local_capacity;
637
638 // We will use pre-allocated Lhs/Rhs blocks defined above, if the number of
639 // unique threads in a system is below or equal to the number of threads in
640 // a thread pool. We will fallback on dynamic memory allocation after that.
641
642 // ThreadLocalBlocks is a container for Lhs or Rhs thread local buffers. Its
643 // size is equal to the grain size in Lhs/Rhs sharding dimension.
644 template <typename BlockType>
645 class ThreadLocalBlocks {
646 public:
647 ThreadLocalBlocks() = default;
648
649 ThreadLocalBlocks(BlockType* base, size_t grain_size)
650 : is_pre_allocated_(true),
651 thread_local_pre_allocated_base_(base),
652 grain_size_(grain_size) {}
653
654 ThreadLocalBlocks(BlockMemHandle mem_handle,
655 std::vector<BlockType> blocks)
656 : is_pre_allocated_(false),
657 mem_handle_(std::move(mem_handle)),
658 blocks_(std::move(blocks)) {}
659
660 BlockType& block(int grain_index) {
661 eigen_assert(grain_index >= 0);
662 eigen_assert(static_cast<size_t>(grain_index) < size());
663 return is_pre_allocated_ ? thread_local_pre_allocated_base_[grain_index]
664 : blocks_[grain_index];
665 }
666
667 void Release(EvalParallelContext& ctx) const {
668 if (!is_pre_allocated_) {
669 ctx.kernel_.deallocate(ctx.device_, mem_handle_);
670 }
671 }
672
673 size_t size() const {
674 return is_pre_allocated_ ? grain_size_ : blocks_.size();
675 }
676
677 private:
678 bool is_pre_allocated_;
679
680 // Reuse pre-allocated thread local buffers.
681 BlockType* thread_local_pre_allocated_base_ = nullptr;
682 size_t grain_size_ = 0;
683
684 // These will be initialized only if `is_pre_allocated == false`.
685 BlockMemHandle mem_handle_{};
686 std::vector<BlockType> blocks_;
687 };
688
689 // ThreadLocalBlocksInitialize callable does custom thread local blocks
690 // initialization, and will reuse pre-allocated buffers if possible, or will
691 // dynamically allocate new memory.
692 //
693 // Lhs/Rhs blocks might be of the same type, so we have to pass explicitly
694 // for what side do we plan to do block allocation.
695 template <typename BlockType, bool is_rhs>
696 class ThreadLocalBlocksInitialize {
697 static constexpr bool kIsLhs =
698 !is_rhs && std::is_same<BlockType, LhsBlock>::value;
699 static const bool kIsRhs =
700 is_rhs && std::is_same<BlockType, RhsBlock>::value;
701 static_assert(kIsLhs || kIsRhs, "Unkown block type");
702
703 using Blocks = ThreadLocalBlocks<BlockType>;
704
705 public:
706 ThreadLocalBlocksInitialize(EvalParallelContext& ctx)
707 : ctx_(ctx),
708 num_worker_threads_(ctx_.device_.numThreadsInPool()) {}
709
710 void operator()(Blocks& blocks) {
711 const int n = ctx_.num_thread_local_allocations_.fetch_add(
712 1, std::memory_order_relaxed);
713
714 if (n >= num_worker_threads_) {
715 ThreadLocalBlocksAllocator<is_rhs>::allocate(ctx_, blocks);
716 } else {
717 ThreadLocalBlocksAllocator<is_rhs>::reuse(ctx_, n, blocks);
718 }
719 }
720
721 private:
722 // NOTE(ezhulenev): Without 'if constexpr' we have to put calls to
723 // TensorContractionKernel::allocateSlices into template specializations.
724 // Also explicit specializations are not allowed at class scope in C++03,
725 // EvalCtx type parameter is just a workaround for that limitation.
726 template <bool pack_rhs, typename EvalCtx = EvalParallelContext>
727 struct ThreadLocalBlocksAllocator;
728
729 template <typename EvalCtx>
730 struct ThreadLocalBlocksAllocator</*pack_rhs=*/true, EvalCtx> {
731 static void allocate(EvalCtx& ctx, Blocks& blocks) {
732 std::vector<RhsBlock> rhs_blocks;
733 BlockMemHandle mem_handle = ctx.kernel_.allocateSlices(
734 ctx.device_,
735 /*num_lhs=*/0,
736 /*num_rhs=*/ctx.gn_,
737 /*num_slices=*/1,
738 /*lhs_blocks=*/nullptr, /*rhs_blocks=*/&rhs_blocks);
739
740 blocks = ThreadLocalBlocks<RhsBlock>(std::move(mem_handle),
741 std::move(rhs_blocks));
742 }
743
744 static void reuse(EvalCtx& ctx, int index, Blocks& blocks) {
745 RhsBlock* ptr = &ctx.rhs_thread_local_pre_allocated_[ctx.gn_ * index];
746 blocks = ThreadLocalBlocks<RhsBlock>(ptr, ctx.gn_);
747 }
748 };
749
750 template <typename EvalCtx>
751 struct ThreadLocalBlocksAllocator</*pack_rhs=*/false, EvalCtx> {
752 static void allocate(EvalCtx& ctx, Blocks& blocks) {
753 std::vector<LhsBlock> lhs_blocks;
754 BlockMemHandle mem_handle = ctx.kernel_.allocateSlices(
755 ctx.device_,
756 /*num_lhs=*/ctx.gm_,
757 /*num_rhs=*/0,
758 /*num_slices=*/1,
759 /*lhs_blocks=*/&lhs_blocks, /*rhs_blocks=*/nullptr);
760
761 blocks = ThreadLocalBlocks<LhsBlock>(std::move(mem_handle),
762 std::move(lhs_blocks));
763 }
764
765 static void reuse(EvalCtx& ctx, int index, Blocks& blocks) {
766 LhsBlock* ptr = &ctx.lhs_thread_local_pre_allocated_[ctx.gm_ * index];
767 blocks = ThreadLocalBlocks<LhsBlock>(ptr, ctx.gm_);
768 }
769 };
770
771 EvalParallelContext& ctx_;
772 const int num_worker_threads_;
773 };
774
775 template <typename BlockType>
776 class ThreadLocalBlocksRelease {
777 public:
778 using Blocks = ThreadLocalBlocks<BlockType>;
779 ThreadLocalBlocksRelease(EvalParallelContext& ctx) : ctx_(ctx) {}
780 void operator()(Blocks& blocks) { blocks.Release(ctx_); }
781
782 private:
783 EvalParallelContext& ctx_;
784 };
785
786 // ThreadLocalBlocks initialization callables.
787 using ThreadLocalLhsInit =
788 ThreadLocalBlocksInitialize<LhsBlock, /*is_rhs=*/false>;
789 using ThreadLocalRhsInit =
790 ThreadLocalBlocksInitialize<RhsBlock, /*is_rhs=*/true>;
791
792 // ThreadLocalBlocks release callables.
793 using ThreadLocalLhsRelease = ThreadLocalBlocksRelease<LhsBlock>;
794 using ThreadLocalRhsRelease = ThreadLocalBlocksRelease<RhsBlock>;
795
796 // Thread local containers for Lhs/Rhs block packs. In practice only one of
797 // them will be used, depending on the shard_by_col value.
798 Eigen::ThreadLocal<ThreadLocalBlocks<LhsBlock>, ThreadLocalLhsInit,
799 ThreadLocalLhsRelease>
800 lhs_thread_local_blocks_;
801 Eigen::ThreadLocal<ThreadLocalBlocks<RhsBlock>, ThreadLocalRhsInit,
802 ThreadLocalRhsRelease>
803 rhs_thread_local_blocks_;
804
805 // After a particular shard for Kth slice missed thread local execution
806 // opportunity (K-1 slice didn't complete kernels execution), we can no
807 // longer schedule K+1 and following slices in thread local mode, because
808 // there is no more guarantee that previous kernels were executed
809 // sequentially in the same thread (size is nn_ or nm_).
810 std::atomic<bool>* can_use_thread_local_packed_;
811
812 std::atomic<uint8_t>** state_kernel_[P];
813 // state_switch_ is frequently modified by worker threads, while other
814 // fields are read-only after constructor. Let's move it to a separate cache
815 // line to reduce cache-coherency traffic.
816 char pad_[128];
817 std::atomic<Index> state_packing_ready_[P];
818 std::atomic<Index> state_switch_[P];
819
820 LhsBlock& packed_lhs(Index m, Index k, Index m1, bool use_thread_local) {
821 if (use_thread_local) {
822 eigen_assert(!shard_by_col_);
823 ThreadLocalBlocks<LhsBlock>& blocks = lhs_thread_local_blocks_.local();
824
825 Index grain_index = m1 - m * gm_;
826 return blocks.block(internal::convert_index<int>(grain_index)); // FIXME better make ThreadLocalBlocks use Eigen::Index?
827 } else {
828 return packed_lhs_[k % (P - 1)][m1];
829 }
830 }
831
832 RhsBlock& packed_rhs(Index n, Index k, Index n1, bool use_thread_local) {
833 if (use_thread_local) {
834 eigen_assert(shard_by_col_);
835 ThreadLocalBlocks<RhsBlock>& blocks = rhs_thread_local_blocks_.local();
836
837 Index grain_index = n1 - n * gn_;
838 return blocks.block(internal::convert_index<int>(grain_index)); // FIXME better make ThreadLocalBlocks use Eigen::Index?
839 } else {
840 return packed_rhs_[k % (P - 1)][n1];
841 }
842 }
843
844 // In following two methods (pack_lhs and pack_rhs), if we know for sure
845 // that we'll be able to immediately call a kernel with packed data, and do
846 // not submit it to the thread pool, we can use thread local memory for
847 // packed data.
848 //
849 // We can only reliably check it if we are running all kernels in sync mode
850 // (parallelize only by sharding dim). If kernel for m==0 (n==0) is ready to
851 // run, it's guaranteed that all kernels with larger values of m (n) are
852 // also ready, because we execute them in the same order for all K slices.
853
854 void pack_lhs(Index m, Index k) {
855 bool use_thread_local = false;
856
857 if (parallelize_by_sharding_dim_only_ && !shard_by_col_ &&
858 can_use_thread_local_packed_[m].load(std::memory_order_relaxed)) {
859 if (state_kernel_[k % P][m][0].load(std::memory_order_relaxed) == 1) {
860 use_thread_local = true;
861 } else {
862 // If we can't guarantee that all kernels in `k` slice will be
863 // executed sequentially in current thread, it's no longer safe to use
864 // thread local memory in following slices along the k dimensions.
865 eigen_assert(k > 0);
866 can_use_thread_local_packed_[m].store(false,
867 std::memory_order_relaxed);
868 }
869 }
870
871 const Index mend = m * gm_ + gm(m);
872 for (Index m1 = m * gm_; m1 < mend; m1++)
873 kernel_.packLhs(&packed_lhs(m, k, m1, use_thread_local),
874 lhs_.getSubMapper(m1 * bm_, k * bk_), bk(k), bm(m1));
875
876 if (!parallel_pack_ && shard_by_col_) {
877 assert(!use_thread_local);
878 signal_packing(k);
879 } else {
880 signal_switch(k + 1);
881 for (Index n = nn_ - 1; n >= 0; n--) {
882 bool sync = parallelize_by_sharding_dim_only_ || n == 0;
883 signal_kernel(m, n, k, sync, use_thread_local);
884 }
885 }
886 }
887
888 void pack_rhs(Index n, Index k) {
889 bool use_thread_local = false;
890
891 if (parallelize_by_sharding_dim_only_ && shard_by_col_ &&
892 can_use_thread_local_packed_[n].load(std::memory_order_relaxed)) {
893 if (state_kernel_[k % P][0][n].load(std::memory_order_relaxed) == 1) {
894 use_thread_local = true;
895 } else {
896 // If we can't guarantee that all kernels in `k` slice will be
897 // executed sequentially in current thread, it's no longer safe to use
898 // thread local memory in followig slices along the k dimensions.
899 eigen_assert(k > 0);
900 can_use_thread_local_packed_[n].store(false,
901 std::memory_order_relaxed);
902 }
903 }
904
905 const Index nend = n * gn_ + gn(n);
906 for (Index n1 = n * gn_; n1 < nend; n1++) {
907 if (!TensorContractionKernel::HasBeta && k == 0) {
908 // Zero the output memory in parallel, only if contraction kernel does
909 // not support `beta`. Otherwise we will pass beta 0.0 to the first
910 // call to the `TensorContractionKernel::invoke()`.
911 //
912 // On 10000x2x10000 mm zeroing can easily take half of time. Zero (bn
913 // x m) row. Safe to do here because all kernels that will write to
914 // this memory depend on completion of this task. Note: don't call
915 // device_.memset() here. device_.memset() blocks on thread pool
916 // worker thread, which can lead to underutilization and deadlocks.
917 memset(buffer_ + n1 * bn_ * m_, 0, bn(n1) * m_ * sizeof(Scalar));
918 }
919 kernel_.packRhs(&packed_rhs(n, k, n1, use_thread_local),
920 rhs_.getSubMapper(k * bk_, n1 * bn_), bk(k), bn(n1));
921 }
922
923 if (parallel_pack_ || shard_by_col_) {
924 signal_switch(k + 1);
925 for (Index m = nm_ - 1; m >= 0; m--) {
926 bool sync = parallelize_by_sharding_dim_only_ || m == 0;
927 signal_kernel(m, n, k, sync, use_thread_local);
928 }
929 } else {
930 assert(!use_thread_local);
931 signal_packing(k);
932 }
933 }
934
935 void kernel(Index m, Index n, Index k, bool use_thread_local) {
936 // Note: order of iteration matters here. Iteration over m is innermost
937 // because we want to reuse the same packed rhs in consecutive tasks
938 // (rhs fits into L2$ while lhs only into L3$).
939 const Index nend = n * gn_ + gn(n);
940 const Index mend = m * gm_ + gm(m);
941
942 // NOTE: output = alpha * LHS * RHS + beta * output.
943 const Scalar alpha = Scalar(1);
944 const Scalar beta =
945 (TensorContractionKernel::HasBeta && k == 0) ? Scalar(0) : Scalar(1);
946
947 if (shard_by_col_) {
948 for (Index n1 = n * gn_; n1 < nend; n1++) {
949 for (Index m1 = m * gm_; m1 < mend; m1++) {
950 const auto output_mapper = output_.getSubMapper(m1 * bm_, n1 * bn_);
951 kernel_.invoke(
952 output_mapper,
953 packed_lhs(m, k, m1, !shard_by_col_ && use_thread_local),
954 packed_rhs(n, k, n1, shard_by_col_ && use_thread_local), bm(m1),
955 bk(k), bn(n1), alpha, beta);
956
957 // We are done with the last task for the [m1, n1] block.
958 if (k + 1 == nk_) {
959 output_kernel_(output_mapper, tensor_contraction_params_,
960 m1 * bm_, n1 * bn_, bm(m1), bn(n1));
961 }
962 }
963 }
964 } else {
965 for (Index m1 = m * gm_; m1 < mend; m1++)
966 for (Index n1 = n * gn_; n1 < nend; n1++) {
967 const auto output_mapper = output_.getSubMapper(m1 * bm_, n1 * bn_);
968 kernel_.invoke(
969 output_mapper,
970 packed_lhs(m, k, m1, !shard_by_col_ && use_thread_local),
971 packed_rhs(n, k, n1, shard_by_col_ && use_thread_local), bm(m1),
972 bk(k), bn(n1), alpha, beta);
973
974 // We are done with the last task for the [m1, n1] block.
975 if (k + 1 == nk_) {
976 output_kernel_(output_mapper, tensor_contraction_params_,
977 m1 * bm_, n1 * bn_, bm(m1), bn(n1));
978 }
979 }
980 }
981 signal_kernel(m, n, k + 1, /*sync=*/false, /*use_thread_local=*/false);
982 signal_switch(k + 2);
983 }
984
985 void signal_packing(Index k) {
986 eigen_assert(!parallel_pack_);
987 Index s = state_packing_ready_[k % P].fetch_sub(1);
988 eigen_assert(s > 0);
989 if (s != 1) return;
990 state_packing_ready_[k % P] = shard_by_col_ ? nm_ : nn_;
991 enqueue_packing(k, shard_by_col_);
992 }
993
994 void signal_kernel(Index m, Index n, Index k, bool sync,
995 bool use_thread_local) {
996 std::atomic<uint8_t>* state = &state_kernel_[k % P][m][n];
997 Index s = state->load();
998 eigen_assert(s > 0);
999 if (s != 1 && state->fetch_sub(1) != 1) {
1000 eigen_assert(!use_thread_local);
1001 return;
1002 }
1003 state->store(parallel_pack_ ? 3 : 2, std::memory_order_relaxed);
1004 if (sync) {
1005 kernel(m, n, k, use_thread_local);
1006 } else {
1007 eigen_assert(!use_thread_local);
1008 device_.enqueueNoNotification(
1009 [=]() { kernel(m, n, k, use_thread_local); });
1010 }
1011 }
1012
1013 void signal_switch(Index k, Index v = 1) {
1014 Index s = state_switch_[k % P].fetch_sub(v);
1015 eigen_assert(s >= v);
1016 if (s != v) return;
1017
1018 // Ready to switch to the next k slice.
1019 // Reset counter for the next iteration.
1020 state_switch_[k % P] =
1021 (parallel_pack_ ? nm_ + nn_ : (shard_by_col_ ? nn_ : nm_)) +
1022 nm_ * nn_;
1023 if (k < nk_) {
1024 // Issue lhs/rhs packing. Their completion will in turn kick off
1025 // kernels.
1026 if (parallel_pack_) {
1027 enqueue_packing(k, !shard_by_col_);
1028 enqueue_packing(k, shard_by_col_);
1029 } else if (shard_by_col_) {
1030 enqueue_packing(k, false);
1031 } else {
1032 enqueue_packing(k, true);
1033 }
1034
1035 // Termination handling.
1036 // Because kernel completion signals k + 2 switch, we need to finish nk
1037 // + 2 slices without issuing any tasks on nk + 1 slice. So here we
1038 // pretend that all nk + 1 packing tasks just finish instantly; so that
1039 // nk + 2 switch only waits for completion of nk kernels.
1040 } else if (k == nk_) {
1041 signal_switch(k + 1,
1042 parallel_pack_ ? nm_ + nn_ : (shard_by_col_ ? nn_ : nm_));
1043 } else {
1044 done_.Notify();
1045 }
1046 }
1047
1048 // Enqueue all rhs/lhs packing for k-th slice.
1049 void enqueue_packing(Index k, bool rhs) {
1050 enqueue_packing_helper(0, rhs ? nn_ : nm_, k, rhs);
1051 }
1052
1053 void enqueue_packing_helper(Index start, Index end, Index k, bool rhs) {
1054 if (end - start == 1) {
1055 if (rhs)
1056 pack_rhs(start, k);
1057 else
1058 pack_lhs(start, k);
1059 } else {
1060 while (end - start > 1) {
1061 Index mid = (start + end) / 2;
1062 device_.enqueueNoNotification(
1063 [=]() { enqueue_packing_helper(mid, end, k, rhs); });
1064 end = mid;
1065 }
1066
1067 // Decide if we want to run first packing task (start == 0) in
1068 // async mode if we parallelize only by sharding dim:
1069 // (1) pack_lhs and pack_rhs call signal_switch before completing
1070 // all calls to signal_kernel, which in sync mode might lead
1071 // to the execution of the first kernel of the k+1 slice, before
1072 // completing a call to the last kernel of the k slice.
1073 // (2) all pack tasks for sharded dim must be executed in a thread
1074 // pool to get pre-allocated thead local buffers.
1075 bool pack_async =
1076 (start == 0) &&
1077 (parallelize_by_sharding_dim_only_&& shard_by_col_ == rhs) &&
1078 (k > 0 || std::this_thread::get_id() == created_by_thread_id_);
1079
1080 if (pack_async) {
1081 device_.enqueueNoNotification(
1082 [=]() { enqueue_packing_helper(start, end, k, rhs); });
1083 } else {
1084 enqueue_packing_helper(start, end, k, rhs);
1085 }
1086 }
1087 }
1088
1089 // Block sizes with accounting for potentially incomplete last block.
1090 Index bm(Index m) const { return m + 1 < nm0_ ? bm_ : m_ + bm_ - bm_ * nm0_; }
1091 Index bn(Index n) const { return n + 1 < nn0_ ? bn_ : n_ + bn_ - bn_ * nn0_; }
1092 Index bk(Index k) const { return k + 1 < nk_ ? bk_ : k_ + bk_ - bk_ * nk_; }
1093 // Task grain sizes accounting for potentially incomplete last task.
1094 Index gm(Index m) const { return m + 1 < nm_ ? gm_ : nm0_ + gm_ - gm_ * nm_; }
1095 Index gn(Index n) const { return n + 1 < nn_ ? gn_ : nn0_ + gn_ - gn_ * nn_; }
1096
1097 EvalParallelContext(const EvalParallelContext&) = delete;
1098 void operator=(const EvalParallelContext&) = delete;
1099 };
1100
1101 template <bool lhs_inner_dim_contiguous, bool rhs_inner_dim_contiguous,
1102 bool rhs_inner_dim_reordered, int Alignment>
1103 using SyncEvalParallelContext =
1104 EvalParallelContext<NoCallback, lhs_inner_dim_contiguous,
1105 rhs_inner_dim_contiguous, rhs_inner_dim_reordered,
1106 Alignment>;
1107
1108 // ------------------------------------------------------------------------ //
1109
1110 // EvalShardedByInnerDimContext orchestrates sync/async contraction
1111 // evaluation, when we shard by inner dimension. When it is executed in
1112 // asynchronous mode, it owns all the shared state that might be accessible by
1113 // block processing tasks.
1114
1115 template <typename DoneCallback>
1116 struct EvalShardedByInnerDimContext {
1117 EvalShardedByInnerDimContext(const Self* self, int num_threads,
1118 Scalar* result_buffer,
1119 Index m_size, Index n_size, Index k_size,
1120 DoneCallback done_callback)
1121 : evaluator(self),
1122 m_lhs_inner_dim_contiguous(evaluator->m_lhs_inner_dim_contiguous),
1123 m_rhs_inner_dim_contiguous(evaluator->m_rhs_inner_dim_contiguous),
1124 m_rhs_inner_dim_reordered(evaluator->m_rhs_inner_dim_reordered),
1125 result(result_buffer),
1126 m(m_size),
1127 n(n_size),
1128 k(k_size),
1129 done(std::move(done_callback)),
1130 buffer_size_bytes(m * n * sizeof(Scalar)),
1131 block_size(blockSize(k, num_threads)),
1132 num_blocks(divup<Index>(k, block_size)),
1133 num_pending_blocks(internal::convert_index<int>(num_blocks)),
1134 l0_ranges(divup<Index>(num_blocks, l0_size)),
1135 l0_state(l0_ranges),
1136 block_buffers(num_blocks) {
1137 // Keep count of pending gemm tasks for each l0 range.
1138 for (int i = 0; i < l0_ranges; ++i) {
1139 const Index num_pending_tasks = actualRangeSize(l0_ranges, l0_size, i);
1140 l0_state.emplace_back(internal::convert_index<int>(num_pending_tasks));
1141 }
1142
1143 // Allocate temporary buffers for each block.
1144 for (Index block_idx = 0; block_idx < num_blocks; ++block_idx) {
1145 Scalar* buf = block_idx == 0
1146 ? result
1147 : static_cast<Scalar*>(evaluator->m_device.allocate(
1148 buffer_size_bytes));
1149 block_buffers.emplace_back(buf);
1150 }
1151 }
1152
1153 ~EvalShardedByInnerDimContext() {
1154 for (Index i = 1; i < num_blocks; ++i) {
1155 evaluator->m_device.deallocate(block_buffers[i]);
1156 }
1157 }
1158
1159 template <int Alignment>
1160 void run() {
1161 Barrier barrier(internal::convert_index<int>(num_blocks));
1162 eval<Alignment>(barrier, 0, num_blocks);
1163 barrier.Wait();
1164
1165 // Aggregate partial sums from l0 ranges.
1166 aggregateL0Blocks<Alignment>();
1167
1168 // Apply output kernel.
1169 applyOutputKernel();
1170 }
1171
1172 template <int Alignment>
1173 void runAsync() {
1174 evalAsync<Alignment>(0, num_blocks);
1175 }
1176
1177 private:
1178 // The underlying GEMM kernel assumes that k is a multiple of
1179 // the packet size and subtle breakage occurs if this is violated.
1180 static const Index packet_size = internal::packet_traits<RhsScalar>::size;
1181
1182 const Self* evaluator; // TensorContraction evaluator
1183
1184 // These fields required fromTENSOR_CONTRACTION_DISPATCH macro.
1185 bool m_lhs_inner_dim_contiguous;
1186 bool m_rhs_inner_dim_contiguous;
1187 bool m_rhs_inner_dim_reordered;
1188
1189 Scalar* result;
1190
1191 Index m;
1192 Index n;
1193 Index k;
1194
1195 DoneCallback done;
1196
1197 // ----------------------------------------------------------------------//
1198 // Algorithm parameters.
1199
1200 // We will compute partial results into the buffers of this size.
1201 Index buffer_size_bytes;
1202
1203 Index block_size;
1204 Index num_blocks;
1205
1206 // Keep track of pending tasks when evaluate in async mode.
1207 std::atomic<int> num_pending_blocks;
1208
1209 // We compute partial gemm results in parallel, and to get the final result
1210 // we need to add them all together. For the large number of threads (>= 48)
1211 // this adds a very expensive sequential step at the end.
1212 //
1213 // We split the [0, num_blocks) into small ranges, and when a task for the
1214 // block finishes its partial gemm computation, it checks if it was the last
1215 // gemm in the range, and if so, it will add all blocks of the range.
1216 //
1217 // After all tasks done, we need to add only these pre-aggregated blocks.
1218
1219 // For now we use just a single level of ranges to compute pre-aggregated
1220 // partial sums, but in general we can use more layers to compute tree
1221 // aggregation in parallel and reduce the size of the sequential step.
1222 //
1223 // TODO(ezhulenev): Add multilevel tree aggregation? Probably will make
1224 // sense only if number of threads >= ~128?
1225 static const Index l0_size = 4;
1226 Index l0_ranges;
1227
1228 // Keep count of pending gemm tasks for each l0 range.
1229 MaxSizeVector<std::atomic<int>> l0_state; // [0, l0_ranges)
1230
1231 // Buffers allocated for each temporary block computation.
1232 MaxSizeVector<Scalar*> block_buffers; // [0, num_blocks)
1233
1234 template <int Alignment>
1235 void processBlock(Index block_idx, Index begin, Index end) {
1236 Scalar* buf = block_buffers[block_idx];
1237
1238 TENSOR_CONTRACTION_DISPATCH(
1239 evaluator->template evalGemmPartialWithoutOutputKernel, Alignment,
1240 (buf, begin, end,
1241 /*num_threads=*/internal::convert_index<int>(num_blocks)));
1242
1243 // Check if it was the last task in l0 range.
1244 const Index l0_index = block_idx / l0_size;
1245 const int v = l0_state[l0_index].fetch_sub(1);
1246 eigen_assert(v >= 1);
1247
1248 // If we processed the last block of the range, we can aggregate all
1249 // partial results into the first block of the range.
1250 if (v == 1) {
1251 const Index rng_size = actualRangeSize(l0_ranges, l0_size, l0_index);
1252 const Index dst_block_idx = l0_index * l0_size;
1253
1254 if (rng_size == l0_size) {
1255 addAllToBuffer<Alignment>(
1256 m * n,
1257 /*src_buf0=*/block_buffers[dst_block_idx + 1],
1258 /*src_buf1=*/block_buffers[dst_block_idx + 2],
1259 /*src_buf2=*/block_buffers[dst_block_idx + 3],
1260 /*dst_buf= */ block_buffers[dst_block_idx]);
1261 } else {
1262 // Aggregate blocks of potentially incomplete last range.
1263 for (int i = 1; i < rng_size; ++i) {
1264 addToBuffer<Alignment>(m * n,
1265 /*src_buf=*/block_buffers[dst_block_idx + i],
1266 /*dst_buf=*/block_buffers[dst_block_idx]);
1267 }
1268 }
1269 }
1270 }
1271
1272 // Aggregate partial sums from l0 ranges.
1273 template <int Alignment>
1274 void aggregateL0Blocks() const {
1275 Index l0_index = 1;
1276
1277 for (; l0_index + 2 < l0_ranges; l0_index += 3) {
1278 addAllToBuffer<Alignment>(
1279 m * n,
1280 /*src_buf0=*/block_buffers[(l0_index + 0) * l0_size],
1281 /*src_buf1=*/block_buffers[(l0_index + 1) * l0_size],
1282 /*src_buf2=*/block_buffers[(l0_index + 2) * l0_size],
1283 /*dst_buf= */ block_buffers[0]);
1284 }
1285
1286 for (; l0_index < l0_ranges; ++l0_index) {
1287 addToBuffer<Alignment>(m * n, block_buffers[l0_index * l0_size],
1288 block_buffers[0]);
1289 }
1290 }
1291
1292 void applyOutputKernel() const {
1293 typedef internal::blas_data_mapper<Scalar, Index, ColMajor> OutputMapper;
1294 evaluator->m_output_kernel(
1295 OutputMapper(result, m), evaluator->m_tensor_contraction_params,
1296 static_cast<Eigen::Index>(0), static_cast<Eigen::Index>(0), m, n);
1297 }
1298
1299 // Compute block size with accounting for potentially incomplete last block.
1300 Index actualBlockSize(Index block_idx) const {
1301 return block_idx + 1 < num_blocks
1302 ? block_size
1303 : k + block_size - block_size * num_blocks;
1304 };
1305
1306 // Compute range size with accounting for potentially incomplete last range.
1307 Index actualRangeSize(Index num_ranges, Index range_size,
1308 Index range_idx) const {
1309 eigen_assert(range_idx < num_ranges);
1310 return range_idx + 1 < num_ranges
1311 ? range_size
1312 : num_blocks + range_size - range_size * num_ranges;
1313 };
1314
1315 template <int Alignment>
1316 EIGEN_STRONG_INLINE static void addToBuffer(size_t n, const Scalar* src_buf,
1317 Scalar* tgt_buf) {
1318 const int output_packet_size =
1319 internal::unpacket_traits<PacketReturnType>::size;
1320 size_t i = 0;
1321 const size_t num_packets = n / output_packet_size;
1322 for (; i < output_packet_size * num_packets; i += output_packet_size) {
1323 const PacketReturnType src_val =
1324 internal::pload<PacketReturnType>(src_buf + i);
1325 const PacketReturnType tgt_val =
1326 internal::ploadt<PacketReturnType, Alignment>(tgt_buf + i);
1327 const PacketReturnType sum = internal::padd(src_val, tgt_val);
1328 internal::pstoret<Scalar, PacketReturnType, Alignment>(tgt_buf + i,
1329 sum);
1330 }
1331 for (; i < n; ++i) {
1332 tgt_buf[i] += src_buf[i];
1333 }
1334 }
1335
1336 template <int Alignment>
1337 EIGEN_STRONG_INLINE static void addAllToBuffer(size_t n,
1338 const Scalar* src_buf0,
1339 const Scalar* src_buf1,
1340 const Scalar* src_buf2,
1341 Scalar* dst_buf) {
1342 using ::Eigen::internal::padd;
1343 using ::Eigen::internal::pload;
1344 using ::Eigen::internal::ploadt;
1345 using ::Eigen::internal::pstoret;
1346
1347 const int output_packet_size =
1348 internal::unpacket_traits<PacketReturnType>::size;
1349
1350 size_t i = 0;
1351 const size_t num_packets = n / output_packet_size;
1352 for (; i < output_packet_size * num_packets; i += output_packet_size) {
1353 const auto src_val0 = pload<PacketReturnType>(src_buf0 + i);
1354 const auto src_val1 = pload<PacketReturnType>(src_buf1 + i);
1355 const auto src_val2 = pload<PacketReturnType>(src_buf2 + i);
1356
1357 const auto dst_val = ploadt<PacketReturnType, Alignment>(dst_buf + i);
1358 const auto sum =
1359 padd(padd(dst_val, src_val0), padd(src_val1, src_val2));
1360
1361 pstoret<Scalar, PacketReturnType, Alignment>(dst_buf + i, sum);
1362 }
1363 for (; i < n; ++i) {
1364 dst_buf[i] += src_buf0[i] + src_buf1[i] + src_buf2[i];
1365 }
1366 }
1367
1368 template <int Alignment>
1369 void eval(Barrier& barrier, Index start_block_idx, Index end_block_idx) {
1370 while (end_block_idx - start_block_idx > 1) {
1371 Index mid_block_idx = (start_block_idx + end_block_idx) / 2;
1372 evaluator->m_device.enqueueNoNotification(
1373 [this, &barrier, mid_block_idx, end_block_idx]() {
1374 eval<Alignment>(barrier, mid_block_idx, end_block_idx);
1375 });
1376 end_block_idx = mid_block_idx;
1377 }
1378
1379 Index block_idx = start_block_idx;
1380 Index block_start = block_idx * block_size;
1381 Index block_end = block_start + actualBlockSize(block_idx);
1382
1383 processBlock<Alignment>(block_idx, block_start, block_end);
1384 barrier.Notify();
1385 }
1386
1387 template <int Alignment>
1388 void evalAsync(Index start_block_idx, Index end_block_idx) {
1389 while (end_block_idx - start_block_idx > 1) {
1390 Index mid_block_idx = (start_block_idx + end_block_idx) / 2;
1391 evaluator->m_device.enqueueNoNotification(
1392 [this, mid_block_idx, end_block_idx]() {
1393 evalAsync<Alignment>(mid_block_idx, end_block_idx);
1394 });
1395 end_block_idx = mid_block_idx;
1396 }
1397
1398 Index block_idx = start_block_idx;
1399
1400 Index block_start = block_idx * block_size;
1401 Index block_end = block_start + actualBlockSize(block_idx);
1402
1403 processBlock<Alignment>(block_idx, block_start, block_end);
1404
1405 int v = num_pending_blocks.fetch_sub(1);
1406 eigen_assert(v >= 1);
1407
1408 if (v == 1) {
1409 // Aggregate partial sums from l0 ranges.
1410 aggregateL0Blocks<Alignment>();
1411
1412 // Apply output kernel.
1413 applyOutputKernel();
1414
1415 // NOTE: If we call `done` callback before deleting this (context),
1416 // it might deallocate Self* pointer captured by context, and we'll
1417 // fail in destructor trying to deallocate temporary buffers.
1418
1419 // Move done call back from context before it will be destructed.
1420 DoneCallback done_copy = std::move(done);
1421
1422 // We are confident that we are the last one who touches context.
1423 delete this;
1424
1425 // Now safely call the done callback.
1426 done_copy();
1427 }
1428 }
1429
1430 // Cost model doesn't capture well the cost associated with constructing
1431 // tensor contraction mappers and computing loop bounds in gemm_pack_lhs
1432 // and gemm_pack_rhs, so we specify minimum desired block size.
1433 static Index blockSize(Index k, int num_threads) {
1434 const auto round_up = [=](Index index) -> Index {
1435 const Index kmultiple = packet_size <= 8 ? 8 : packet_size;
1436 return divup<Index>(index, kmultiple) * kmultiple;
1437 };
1438
1439 const Index target_block_size = round_up(divup<Index>(k, num_threads));
1440 const Index desired_min_block_size = 12 * packet_size;
1441
1442 return numext::mini<Index>(
1443 k, numext::maxi<Index>(desired_min_block_size, target_block_size));
1444 }
1445
1446 EvalShardedByInnerDimContext(const EvalShardedByInnerDimContext&) = delete;
1447 void operator=(const EvalShardedByInnerDimContext&) = delete;
1448 };
1449
1450 // ------------------------------------------------------------------------ //
1451
1452 // Below are the functions used by evalProductImpl heuristics, trying to select
1453 // optimal parameters for parallelization algorithm.
1454
1455 // Decide whether we want to shard m x n contraction by columns or by rows.
1456 static bool shardByCol(Index m, Index n, Index num_threads) {
1457 // Note: we are comparing both n and m against Traits::nr, it is not
1458 // a mistake. We are trying to figure out how both n and m will fit into
1459 // the main sharding dimension.
1460
1461 // Sharding by column is the default
1462 // ... unless there is enough data for vectorization over rows
1463 if (m / num_threads >= Traits::nr &&
1464 // and not enough data for vectorization over columns
1465 (n / num_threads < Traits::nr ||
1466 // ... or barely enough data for vectorization over columns,
1467 // but it is not evenly dividable across threads
1468 (n / num_threads < 4 * Traits::nr &&
1469 (n % (num_threads * Traits::nr)) != 0 &&
1470 // ... and it is evenly dividable across threads for rows
1471 ((m % (num_threads * Traits::nr)) == 0 ||
1472 // .. or it is not evenly dividable for both dimensions but
1473 // there is much more data over rows so that corner effects are
1474 // mitigated.
1475 (m / n >= 6)))))
1476 return false;
1477 // Wait, or if matrices are just substantially prolonged over the other
1478 // dimension.
1479 if (n / num_threads < 16 * Traits::nr && m > n * 32) return false;
1480 return true;
1481 }
1482
1483 Index coarsenM(Index m, Index n, Index bm, Index bn, Index bk, Index gn,
1484 int num_threads, bool shard_by_col) const {
1485 Index gm = 1;
1486 Index gm1 = 1;
1487 Index nm0 = divup(m, bm);
1488 Index nm1 = nm0;
1489 for (;;) {
1490 // Find the next candidate for m grain size. It needs to result in
1491 // different number of blocks. E.g. if we have 10 kernels, we want to try
1492 // 5 and 10, but not 6, 7, 8 and 9.
1493 while (gm1 <= nm0 && nm1 == divup(nm0, gm1)) gm1++;
1494 if (gm1 > nm0) break;
1495 // Check the candidate.
1496 int res = checkGrain(m, n, bm, bn, bk, gm1, gn, gm, gn, num_threads,
1497 shard_by_col);
1498 if (res < 0) break;
1499 nm1 = divup(nm0, gm1);
1500 if (res == 0) continue;
1501 // Commit new grain size.
1502 gm = gm1;
1503 }
1504 return gm;
1505 }
1506
1507 Index coarsenN(Index m, Index n, Index bm, Index bn, Index bk, Index gm,
1508 int num_threads, bool shard_by_col) const {
1509 Index gn = 1;
1510 Index gn1 = 1;
1511 Index nn0 = divup(n, bn);
1512 Index nn1 = nn0;
1513 for (;;) {
1514 while (gn1 <= nn0 && nn1 == divup(nn0, gn1)) gn1++;
1515 if (gn1 > nn0) break;
1516 int res = checkGrain(m, n, bm, bn, bk, gm, gn1, gm, gn, num_threads,
1517 shard_by_col);
1518 if (res < 0) break;
1519 nn1 = divup(nn0, gn1);
1520 if (res == 0) continue;
1521 gn = gn1;
1522 }
1523 return gn;
1524 }
1525
1526 // checkGrain checks whether grain (gm, gn) is suitable and is better than
1527 // (oldgm, oldgn).
1528 int checkGrain(Index m, Index n, Index bm, Index bn, Index bk, Index gm,
1529 Index gn, Index oldgm, Index oldgn, int num_threads,
1530 bool shard_by_col) const {
1531 const TensorOpCost cost =
1532 contractionCost(bm * gm, bn * gn, bm, bn, bk, shard_by_col, true);
1533 double taskSize = TensorCostModel<ThreadPoolDevice>::taskSize(
1534 static_cast<double>(bm) * gm * bn * gn, cost);
1535 // If the task is too small, then we agree on it regardless of anything
1536 // else. Otherwise synchronization overheads will dominate.
1537 if (taskSize < 1) return 1;
1538 // If it is too large, then we reject it and all larger tasks.
1539 if (taskSize > 2) return -1;
1540 // Now we are in presumably good task size range.
1541 // The main deciding factor here is parallelism. Consider that we have 12
1542 // kernels and 4 threads. Grains of 2, 3 and 4 all yield good task sizes.
1543 // But 2/4 yield 6/3 tasks, which gives us parallelism of 0.75 (at most 3/4
1544 // of cores will be busy). While grain size 3 gives us 4 tasks, which gives
1545 // us parallelism of 1 (we can load all cores).
1546 Index nm0 = divup(m, bm);
1547 Index nn0 = divup(n, bn);
1548 Index new_tasks = divup(nm0, gm) * divup(nn0, gn);
1549 double new_parallelism = static_cast<double>(new_tasks) /
1550 (divup<int>(new_tasks, num_threads) * num_threads);
1551 Index old_tasks = divup(nm0, oldgm) * divup(nn0, oldgn);
1552 double old_parallelism = static_cast<double>(old_tasks) /
1553 (divup<int>(old_tasks, num_threads) * num_threads);
1554 if (new_parallelism > old_parallelism || new_parallelism == 1) return 1;
1555 return 0;
1556 }
1557
1558 TensorOpCost contractionCost(Index m, Index n, Index bm, Index bn, Index bk,
1559 bool shard_by_col, bool prepacked) const {
1560 const int packed_size = std::min<int>(PacketType<LhsScalar, Device>::size,
1561 PacketType<RhsScalar, Device>::size);
1562 const int output_packet_size = internal::unpacket_traits<PacketReturnType>::size;
1563 const double kd = static_cast<double>(bk);
1564 double compute_bandwidth = computeBandwidth(false, bm, bn, bk);
1565 // Computations.
1566 TensorOpCost cost = TensorOpCost(0, 0, kd * compute_bandwidth, true, packed_size);
1567 // Output stores.
1568 cost += TensorOpCost(0, sizeof(CoeffReturnType), 0, true, output_packet_size);
1569 if (prepacked) {
1570 // Packing and kernels are executed in different tasks. When we calculate
1571 // task grain size we look only at kernel cost assuming that kernel
1572 // is more expensive than packing.
1573 return cost;
1574 }
1575 // Lhs/rhs loads + computations.
1576 TensorOpCost lhsCost = this->m_leftImpl.costPerCoeff(true) * (kd / n);
1577 TensorOpCost rhsCost = this->m_rightImpl.costPerCoeff(true) * (kd / m);
1578 // Lhs packing memory cost does not contribute considerably to overall
1579 // execution time because lhs is prefetched early and accessed sequentially.
1580 if (shard_by_col)
1581 lhsCost.dropMemoryCost();
1582 else
1583 rhsCost.dropMemoryCost();
1584 return cost + lhsCost + rhsCost;
1585 }
1586
1587 // Decide whether we want to shard m x k x n contraction over the inner
1588 // (contraction) dimension (k).
1589 static bool shardByInnerDim(Index m, Index n, Index k, int num_threads,
1590 int num_threads_by_k) {
1591 std::ptrdiff_t bufsize = m * n * sizeof(Scalar);
1592 bool shard_by_k = false;
1593 if (n == 1 || // If mat*vec or...
1594 num_threads_by_k < 2 || // running single threaded or...
1595 num_threads_by_k <
1596 num_threads || // sharding by k gives less parallelism or...
1597 bufsize > l3CacheSize() / num_threads_by_k || // need more buffer space
1598 // than L3 cache or...
1599 k / num_threads_by_k < 2 * Traits::nr) { // k per thread is tiny.
1600 shard_by_k = false;
1601 } else if (numext::maxi(m, n) / num_threads <
1602 Traits::nr || // both other dimensions are tiny or...
1603 // k per thread is not small and...
1604 (k / num_threads_by_k > 8 * Traits::nr &&
1605 // one of the outer dimensions is tiny or sharding by k offers
1606 // more parallelism.
1607 (numext::mini(m, n) < 2 * Traits::nr ||
1608 num_threads_by_k > num_threads))) {
1609 shard_by_k = true;
1610 }
1611 return shard_by_k;
1612 }
1613
1614 TensorOpCost contractionCostPerInnerDim(Index m, Index n, Index k) const {
1615 // Compute cost.
1616 const int output_packet_size = internal::unpacket_traits<PacketReturnType>::size;
1617 TensorOpCost cost(0, 0, (computeBandwidth(true, m, n, k) * m) * n, true, output_packet_size);
1618 // Output stores.
1619 cost += TensorOpCost(0, sizeof(CoeffReturnType), 0, true, output_packet_size);
1620 TensorOpCost lhsCost = this->m_leftImpl.costPerCoeff(true) * m;
1621 TensorOpCost rhsCost = this->m_rightImpl.costPerCoeff(true) * n;
1622 // Since the inner gemm kernel is always sharded by column, the lhs
1623 // load cost is negligible.
1624 lhsCost.dropMemoryCost();
1625 return cost + lhsCost + rhsCost;
1626 }
1627
1628 int numThreadsInnerDim(Index m, Index n, Index k) const {
1629 const int output_packet_size = internal::unpacket_traits<PacketReturnType>::size;
1630 TensorOpCost cost = contractionCostPerInnerDim(m, n, k);
1631 double total_parallel_cost =
1632 TensorCostModel<ThreadPoolDevice>::totalCost(k, cost);
1633 // Cost of reduction step accumulating the m*n per-thread buffers into the
1634 // result.
1635 double reduction_cost = TensorCostModel<ThreadPoolDevice>::totalCost(
1636 m * n, TensorOpCost(2, 1, 1, true, output_packet_size));
1637 int num_threads = 1;
1638 double min_cost = total_parallel_cost;
1639 double kPerThreadOverHead = 3000;
1640 double kFixedOverHead = 100000;
1641 for (int nt = 2; nt <= this->m_device.numThreads(); nt += 2) {
1642 double sequential_cost =
1643 kFixedOverHead + nt * (reduction_cost + kPerThreadOverHead);
1644 double parallel_cost = total_parallel_cost / nt + sequential_cost;
1645 if (parallel_cost < min_cost) {
1646 num_threads = nt;
1647 min_cost = parallel_cost;
1648 }
1649 }
1650 return num_threads;
1651 }
1652
1653 double computeBandwidth(bool shard_by_col, Index bm, Index bn,
1654 Index bk) const {
1655 // Peak VFMA bandwidth is 0.5. However if we have not enough data for
1656 // vectorization bandwidth drops. The 4.0 and 2.0 bandwidth is determined
1657 // experimentally.
1658 double computeBandwidth =
1659 bk == 1 ? 4.0
1660 : (shard_by_col ? bn : bm) < Traits::nr ||
1661 (shard_by_col ? bm : bn) < Traits::mr
1662 ? 2.0
1663 : 0.5;
1664#ifndef EIGEN_VECTORIZE_FMA
1665 // Bandwidth of all of VFMA/MULPS/ADDPS is 0.5 on latest Intel processors.
1666 // However for MULPS/ADDPS we have dependent sequence of 2 such
1667 // instructions,
1668 // so overall bandwidth is 1.0.
1669 if (computeBandwidth == 0.5) computeBandwidth = 1.0;
1670#endif
1671 return computeBandwidth;
1672 }
1673
1674};
1675
1676} // end namespace Eigen
1677
1678#endif // EIGEN_USE_THREADS
1679#endif // EIGEN_CXX11_TENSOR_TENSOR_CONTRACTION_THREAD_POOL_H
Namespace containing all symbols from the Eigen library.
EIGEN_DEFAULT_DENSE_INDEX_TYPE Index