42 #ifndef THYRA_TPETRA_MULTIVECTOR_HPP
43 #define THYRA_TPETRA_MULTIVECTOR_HPP
45 #include "Thyra_TpetraMultiVector_decl.hpp"
46 #include "Thyra_TpetraVectorSpace.hpp"
47 #include "Thyra_TpetraVector.hpp"
48 #include "Teuchos_Assert.hpp"
57 template <
class Scalar,
class LocalOrdinal,
class GlobalOrdinal,
class Node>
62 template <
class Scalar,
class LocalOrdinal,
class GlobalOrdinal,
class Node>
66 const RCP<Tpetra::MultiVector<Scalar,LocalOrdinal,GlobalOrdinal,Node> > &tpetraMultiVector
69 initializeImpl(tpetraVectorSpace, domainSpace, tpetraMultiVector);
73 template <
class Scalar,
class LocalOrdinal,
class GlobalOrdinal,
class Node>
77 const RCP<
const Tpetra::MultiVector<Scalar,LocalOrdinal,GlobalOrdinal,Node> > &tpetraMultiVector
80 initializeImpl(tpetraVectorSpace, domainSpace, tpetraMultiVector);
84 template <
class Scalar,
class LocalOrdinal,
class GlobalOrdinal,
class Node>
88 return tpetraMultiVector_.getNonconstObj();
92 template <
class Scalar,
class LocalOrdinal,
class GlobalOrdinal,
class Node>
96 return tpetraMultiVector_;
103 template <
class Scalar,
class LocalOrdinal,
class GlobalOrdinal,
class Node>
114 template <
class Scalar,
class LocalOrdinal,
class GlobalOrdinal,
class Node>
118 tpetraMultiVector_.getNonconstObj()->putScalar(alpha);
122 template <
class Scalar,
class LocalOrdinal,
class GlobalOrdinal,
class Node>
126 auto tmv = this->getConstTpetraMultiVector(Teuchos::rcpFromRef(mv));
131 tpetraMultiVector_.getNonconstObj()->assign(*tmv);
134 tpetraMultiVector_.getNonconstObj()->sync_host ();
135 tpetraMultiVector_.getNonconstObj()->modify_host ();
141 template <
class Scalar,
class LocalOrdinal,
class GlobalOrdinal,
class Node>
145 tpetraMultiVector_.getNonconstObj()->scale(alpha);
149 template <
class Scalar,
class LocalOrdinal,
class GlobalOrdinal,
class Node>
155 auto tmv = this->getConstTpetraMultiVector(Teuchos::rcpFromRef(mv));
161 tpetraMultiVector_.getNonconstObj()->update(alpha, *tmv, ST::one());
164 tpetraMultiVector_.getNonconstObj()->sync_host ();
165 tpetraMultiVector_.getNonconstObj()->modify_host ();
171 template <
class Scalar,
class LocalOrdinal,
class GlobalOrdinal,
class Node>
183 typedef Tpetra::MultiVector<Scalar,LocalOrdinal,GlobalOrdinal,Node> TMV;
186 bool allCastsSuccessful =
true;
188 auto mvIter = mv.begin();
189 auto tmvIter = tmvs.
begin();
190 for (; mvIter != mv.end(); ++mvIter, ++tmvIter) {
191 tmv = this->getConstTpetraMultiVector(Teuchos::rcpFromPtr(*mvIter));
195 allCastsSuccessful =
false;
203 auto len = tmvs.
size();
205 tpetraMultiVector_.getNonconstObj()->scale(beta);
206 }
else if (len == 1 && allCastsSuccessful) {
207 tpetraMultiVector_.getNonconstObj()->update(alpha[0], *tmvs[0], beta);
208 }
else if (len == 2 && allCastsSuccessful) {
209 tpetraMultiVector_.getNonconstObj()->update(alpha[0], *tmvs[0], alpha[1], *tmvs[1], beta);
210 }
else if (allCastsSuccessful) {
212 auto tmvIter = tmvs.
begin();
213 auto alphaIter = alpha.
begin();
218 for (; tmvIter != tmvs.
end(); ++tmvIter) {
219 if (tmvIter->getRawPtr() == tpetraMultiVector_.getConstObj().getRawPtr()) {
221 tmv = Teuchos::rcp(
new TMV(*tpetraMultiVector_.getConstObj(), Teuchos::Copy));
226 tmvIter = tmvs.
begin();
230 if ((tmvs.
size() % 2) == 0) {
231 tpetraMultiVector_.getNonconstObj()->scale(beta);
233 tpetraMultiVector_.getNonconstObj()->update(*alphaIter, *(*tmvIter), beta);
237 for (; tmvIter != tmvs.
end(); tmvIter+=2, alphaIter+=2) {
238 tpetraMultiVector_.getNonconstObj()->update(
239 *alphaIter, *(*tmvIter), *(alphaIter+1), *(*(tmvIter+1)), ST::one());
243 tpetraMultiVector_.getNonconstObj()->sync_host ();
244 tpetraMultiVector_.getNonconstObj()->modify_host ();
250 template <
class Scalar,
class LocalOrdinal,
class GlobalOrdinal,
class Node>
256 auto tmv = this->getConstTpetraMultiVector(Teuchos::rcpFromRef(mv));
261 tpetraMultiVector_.getConstObj()->dot(*tmv, prods);
264 tpetraMultiVector_.getNonconstObj()->sync_host ();
265 tpetraMultiVector_.getNonconstObj()->modify_host ();
271 template <
class Scalar,
class LocalOrdinal,
class GlobalOrdinal,
class Node>
276 tpetraMultiVector_.getConstObj()->norm1(norms);
280 template <
class Scalar,
class LocalOrdinal,
class GlobalOrdinal,
class Node>
285 tpetraMultiVector_.getConstObj()->norm2(norms);
289 template <
class Scalar,
class LocalOrdinal,
class GlobalOrdinal,
class Node>
294 tpetraMultiVector_.getConstObj()->normInf(norms);
298 template <
class Scalar,
class LocalOrdinal,
class GlobalOrdinal,
class Node>
305 return constTpetraVector<Scalar>(
307 tpetraMultiVector_->getVector(j)
312 template <
class Scalar,
class LocalOrdinal,
class GlobalOrdinal,
class Node>
319 return tpetraVector<Scalar>(
321 tpetraMultiVector_.getNonconstObj()->getVectorNonConst(j)
326 template <
class Scalar,
class LocalOrdinal,
class GlobalOrdinal,
class Node>
332 #ifdef THYRA_DEFAULT_SPMD_MULTI_VECTOR_VERBOSE_TO_ERROR_OUT
333 std::cerr <<
"\nTpetraMultiVector::subView(Range1D) const called!\n";
335 const Range1D colRng = this->validateColRange(col_rng_in);
338 this->getConstTpetraMultiVector()->subView(colRng);
341 tpetraVectorSpace<Scalar>(
342 Tpetra::createLocalMapWithNode<LocalOrdinal,GlobalOrdinal,Node>(
343 tpetraView->getNumVectors(),
344 tpetraView->getMap()->getComm()
348 return constTpetraMultiVector(
356 template <
class Scalar,
class LocalOrdinal,
class GlobalOrdinal,
class Node>
362 #ifdef THYRA_DEFAULT_SPMD_MULTI_VECTOR_VERBOSE_TO_ERROR_OUT
363 std::cerr <<
"\nTpetraMultiVector::subView(Range1D) called!\n";
365 const Range1D colRng = this->validateColRange(col_rng_in);
368 this->getTpetraMultiVector()->subViewNonConst(colRng);
371 tpetraVectorSpace<Scalar>(
372 Tpetra::createLocalMapWithNode<LocalOrdinal,GlobalOrdinal,Node>(
373 tpetraView->getNumVectors(),
374 tpetraView->getMap()->getComm()
378 return tpetraMultiVector(
386 template <
class Scalar,
class LocalOrdinal,
class GlobalOrdinal,
class Node>
392 #ifdef THYRA_DEFAULT_SPMD_MULTI_VECTOR_VERBOSE_TO_ERROR_OUT
393 std::cerr <<
"\nTpetraMultiVector::subView(ArrayView) const called!\n";
398 cols[i] =
static_cast<std::size_t
>(cols_in[i]);
401 this->getConstTpetraMultiVector()->subView(cols());
404 tpetraVectorSpace<Scalar>(
405 Tpetra::createLocalMapWithNode<LocalOrdinal,GlobalOrdinal,Node>(
406 tpetraView->getNumVectors(),
407 tpetraView->getMap()->getComm()
411 return constTpetraMultiVector(
419 template <
class Scalar,
class LocalOrdinal,
class GlobalOrdinal,
class Node>
425 #ifdef THYRA_DEFAULT_SPMD_MULTI_VECTOR_VERBOSE_TO_ERROR_OUT
426 std::cerr <<
"\nTpetraMultiVector::subView(ArrayView) called!\n";
431 cols[i] =
static_cast<std::size_t
>(cols_in[i]);
434 this->getTpetraMultiVector()->subViewNonConst(cols());
437 tpetraVectorSpace<Scalar>(
438 Tpetra::createLocalMapWithNode<LocalOrdinal,GlobalOrdinal,Node>(
439 tpetraView->getNumVectors(),
440 tpetraView->getMap()->getComm()
444 return tpetraMultiVector(
452 template <
class Scalar,
class LocalOrdinal,
class GlobalOrdinal,
class Node>
459 const Ordinal primary_global_offset
465 for (
auto itr = multi_vecs.begin(); itr != multi_vecs.end(); ++itr) {
468 Teuchos::rcp_const_cast<Tpetra::MultiVector<Scalar,LocalOrdinal,GlobalOrdinal,Node> >(
469 tmv->getConstTpetraMultiVector())-> sync_host ();
474 for (
auto itr = targ_multi_vecs.begin(); itr != targ_multi_vecs.end(); ++itr) {
475 Ptr<TMV> tmv = Teuchos::ptr_dynamic_cast<TMV>(*itr);
477 tmv->getTpetraMultiVector()->sync_host ();
478 tmv->getTpetraMultiVector()->modify_host ();
483 primary_op, multi_vecs, targ_multi_vecs, reduct_objs, primary_global_offset);
487 template <
class Scalar,
class LocalOrdinal,
class GlobalOrdinal,
class Node>
496 typedef typename Tpetra::MultiVector<Scalar,LocalOrdinal,GlobalOrdinal,Node> TMV;
497 Teuchos::rcp_const_cast<TMV>(
498 tpetraMultiVector_.getConstObj())->sync_host ();
505 template <
class Scalar,
class LocalOrdinal,
class GlobalOrdinal,
class Node>
514 tpetraMultiVector_.getNonconstObj()->sync_host ();
515 tpetraMultiVector_.getNonconstObj()->modify_host ();
522 template <
class Scalar,
class LocalOrdinal,
class GlobalOrdinal,
class Node>
532 typedef typename Tpetra::MultiVector<
533 Scalar,LocalOrdinal,GlobalOrdinal,Node>::execution_space execution_space;
534 tpetraMultiVector_.getNonconstObj()->template sync<execution_space>();
585 template <
class Scalar,
class LocalOrdinal,
class GlobalOrdinal,
class Node>
589 return tpetraVectorSpace_;
593 template <
class Scalar,
class LocalOrdinal,
class GlobalOrdinal,
class Node>
598 *localValues = tpetraMultiVector_.getNonconstObj()->get1dViewNonConst();
599 *leadingDim = tpetraMultiVector_->getStride();
603 template <
class Scalar,
class LocalOrdinal,
class GlobalOrdinal,
class Node>
608 *localValues = tpetraMultiVector_->get1dView();
609 *leadingDim = tpetraMultiVector_->getStride();
613 template <
class Scalar,
class LocalOrdinal,
class GlobalOrdinal,
class Node>
623 typedef Tpetra::MultiVector<Scalar,LocalOrdinal,GlobalOrdinal,Node> TMV;
631 typedef typename TMV::execution_space execution_space;
632 Teuchos::rcp_const_cast<TMV>(X_tpetra)->template sync<execution_space>();
633 Y_tpetra->template sync<execution_space>();
634 Teuchos::rcp_const_cast<TMV>(
635 tpetraMultiVector_.getConstObj())->template sync<execution_space>();
640 "Error, conjugation without transposition is not allowed for complex scalar types!");
645 trans = Teuchos::NO_TRANS;
648 trans = Teuchos::NO_TRANS;
654 trans = Teuchos::CONJ_TRANS;
658 Y_tpetra->template modify<execution_space>();
659 Y_tpetra->multiply(trans, Teuchos::NO_TRANS, alpha, *tpetraMultiVector_.getConstObj(), *X_tpetra, beta);
661 Teuchos::rcp_const_cast<TMV>(
662 tpetraMultiVector_.getConstObj())->sync_host ();
671 template <
class Scalar,
class LocalOrdinal,
class GlobalOrdinal,
class Node>
672 template<
class TpetraMultiVector_t>
686 tpetraVectorSpace_ = tpetraVectorSpace;
687 domainSpace_ = domainSpace;
688 tpetraMultiVector_.initialize(tpetraMultiVector);
689 this->updateSpmdSpace();
693 template <
class Scalar,
class LocalOrdinal,
class GlobalOrdinal,
class Node>
694 RCP<Tpetra::MultiVector<Scalar,LocalOrdinal,GlobalOrdinal,Node> >
698 using Teuchos::rcp_dynamic_cast;
702 RCP<TMV> tmv = rcp_dynamic_cast<TMV>(mv);
704 return tmv->getTpetraMultiVector();
707 RCP<TV> tv = rcp_dynamic_cast<TV>(mv);
709 return tv->getTpetraVector();
712 return Teuchos::null;
715 template <
class Scalar,
class LocalOrdinal,
class GlobalOrdinal,
class Node>
716 RCP<const Tpetra::MultiVector<Scalar,LocalOrdinal,GlobalOrdinal,Node> >
720 using Teuchos::rcp_dynamic_cast;
724 RCP<const TMV> tmv = rcp_dynamic_cast<const TMV>(mv);
726 return tmv->getConstTpetraMultiVector();
729 RCP<const TV> tv = rcp_dynamic_cast<const TV>(mv);
731 return tv->getConstTpetraVector();
734 return Teuchos::null;
741 #endif // THYRA_TPETRA_MULTIVECTOR_HPP