33#ifndef GKO_PUBLIC_CORE_SOLVER_BATCH_SOLVER_BASE_HPP_
34#define GKO_PUBLIC_CORE_SOLVER_BATCH_SOLVER_BASE_HPP_
37#include <ginkgo/core/base/abstract_factory.hpp>
38#include <ginkgo/core/base/batch_lin_op.hpp>
39#include <ginkgo/core/base/batch_multi_vector.hpp>
40#include <ginkgo/core/base/utils_helper.hpp>
41#include <ginkgo/core/log/batch_logger.hpp>
42#include <ginkgo/core/matrix/batch_identity.hpp>
43#include <ginkgo/core/stop/batch_stop_enum.hpp>
66 return this->system_matrix_;
76 return this->preconditioner_;
95 GKO_INVALID_STATE(
"Tolerance cannot be negative!");
115 if (max_iterations < 0) {
116 GKO_INVALID_STATE(
"Max iterations cannot be negative!");
118 this->max_iterations_ = max_iterations;
128 return this->tol_type_;
138 if (
tol_type == ::gko::batch::stop::tolerance_type::absolute ||
139 tol_type == ::gko::batch::stop::tolerance_type::relative) {
142 GKO_INVALID_STATE(
"Invalid tolerance type specified!");
149 BatchSolver(std::shared_ptr<const BatchLinOp> system_matrix,
151 const double res_tol,
const int max_iterations,
152 const ::gko::batch::stop::tolerance_type
tol_type)
153 : system_matrix_{std::
move(system_matrix)},
156 max_iterations_{max_iterations},
161 void set_system_matrix_base(std::shared_ptr<const BatchLinOp> system_matrix)
163 this->system_matrix_ = std::move(system_matrix);
166 void set_preconditioner_base(std::shared_ptr<const BatchLinOp> precond)
168 this->preconditioner_ = std::move(precond);
171 std::shared_ptr<const BatchLinOp> system_matrix_{};
172 std::shared_ptr<const BatchLinOp> preconditioner_{};
173 double residual_tol_{};
174 int max_iterations_{};
175 ::gko::batch::stop::tolerance_type tol_type_{};
176 mutable array<unsigned char> workspace_{};
180template <
typename Parameters,
typename Factory>
210 std::shared_ptr<const BatchLinOpFactory> GKO_DEFERRED_FACTORY_PARAMETER(
241 this->validate_application_parameters(b.get(), x.get());
253 this->validate_application_parameters(alpha.get(), b.get(), beta.get(),
266 this->validate_application_parameters(b.get(), x.get());
278 this->validate_application_parameters(alpha.get(), b.get(), beta.get(),
295 template <
typename FactoryParameters>
297 std::shared_ptr<const BatchLinOp> system_matrix,
304 GKO_ASSERT_BATCH_HAS_SQUARE_DIMENSIONS(system_matrix_);
306 using value_type =
typename ConcreteSolver::value_type;
310 if (
params.generated_preconditioner) {
311 GKO_ASSERT_BATCH_EQUAL_DIMENSIONS(
params.generated_preconditioner,
313 preconditioner_ = std::move(
params.generated_preconditioner);
314 }
else if (
params.preconditioner) {
315 preconditioner_ =
params.preconditioner->generate(system_matrix_);
317 auto id = Identity::create(exec, system_matrix->get_size());
318 preconditioner_ = std::move(
id);
321 (
sizeof(real_type) +
sizeof(
int));
322 workspace_.set_executor(exec);
328 auto exec = self()->get_executor();
339 void set_preconditioner(std::shared_ptr<const BatchLinOp>
new_precond)
341 auto exec = self()->get_executor();
343 GKO_ASSERT_BATCH_EQUAL_DIMENSIONS(self(),
new_precond);
344 GKO_ASSERT_BATCH_HAS_SQUARE_DIMENSIONS(
new_precond);
354 if (&
other !=
this) {
355 this->set_size(
other.get_size());
356 this->set_system_matrix(
other.get_system_matrix());
357 this->set_preconditioner(
other.get_preconditioner());
367 if (&
other !=
this) {
368 this->set_size(
other.get_size());
369 this->set_system_matrix(
other.get_system_matrix());
370 this->set_preconditioner(
other.get_preconditioner());
374 other.set_system_matrix(
nullptr);
375 other.set_preconditioner(
nullptr);
382 other.self()->get_executor(),
other.self()->get_size())
389 other.self()->get_executor(),
other.self()->get_size())
391 *
this = std::move(
other);
402 auto log_data_ = std::make_unique<log::detail::log_data<real_type>>(
405 this->solver_apply(b, x,
log_data_.get());
424 log::detail::log_data<real_type>*
info)
const = 0;
std::shared_ptr< const Executor > get_executor() const noexcept
Returns the Executor of the object.
Definition polymorphic_object.hpp:263
Definition batch_lin_op.hpp:88
The EnableBatchLinOp mixin can be used to provide sensible default implementations of the majority of...
Definition batch_lin_op.hpp:281
MultiVector stores multiple vectors in a batched fashion and is useful for batched operations.
Definition batch_multi_vector.hpp:85
void scale(ptr_param< const MultiVector< ValueType > > alpha)
Scales the vector with a scalar (aka: BLAS scal).
dim< 2 > get_common_size() const
Returns the common size of the batch items.
Definition batch_multi_vector.hpp:157
size_type get_num_batch_items() const
Returns the number of batch items.
Definition batch_multi_vector.hpp:147
void add_scaled(ptr_param< const MultiVector< ValueType > > alpha, ptr_param< const MultiVector< ValueType > > b)
Adds b scaled by alpha to the vector (aka: BLAS axpy).
The batch Identity matrix, which represents a batch of Identity matrices.
Definition batch_identity.hpp:61
The BatchSolver is a base class for all batched solvers and provides the common getters and setter fo...
Definition batch_solver_base.hpp:57
std::shared_ptr< const BatchLinOp > get_system_matrix() const
Returns the system operator (matrix) of the linear system.
Definition batch_solver_base.hpp:64
void reset_max_iterations(int max_iterations)
Set the maximum number of iterations for the solver to use, independent of the factory that created i...
Definition batch_solver_base.hpp:113
double get_tolerance() const
Get the residual tolerance used by the solver.
Definition batch_solver_base.hpp:84
int get_max_iterations() const
Get the maximum number of iterations set on the solver.
Definition batch_solver_base.hpp:105
void reset_tolerance(double res_tol)
Update the residual tolerance to be used by the solver.
Definition batch_solver_base.hpp:92
::gko::batch::stop::tolerance_type get_tolerance_type() const
Get the tolerance type.
Definition batch_solver_base.hpp:126
void reset_tolerance_type(::gko::batch::stop::tolerance_type tol_type)
Set the type of tolerance check to use inside the solver.
Definition batch_solver_base.hpp:136
std::shared_ptr< const BatchLinOp > get_preconditioner() const
Returns the generated preconditioner.
Definition batch_solver_base.hpp:74
This mixin provides apply and common iterative solver functionality to all the batched solvers.
Definition batch_solver_base.hpp:234
The enable_parameters_type mixin is used to create a base implementation of the factory parameters st...
Definition abstract_factory.hpp:239
This class is used for function parameters in the place of raw pointers.
Definition utils_helper.hpp:71
#define GKO_FACTORY_PARAMETER_SCALAR(_name, _default)
Creates a scalar factory parameter in the factory parameters structure.
Definition abstract_factory.hpp:473
The Ginkgo namespace.
Definition abstract_factory.hpp:48
constexpr T one()
Returns the multiplicative identity for T.
Definition math.hpp:803
typename detail::remove_complex_s< T >::type remove_complex
Obtain the type which removed the complex of complex/scalar type or the template parameter of class b...
Definition math.hpp:354
std::size_t size_type
Integral type used for allocation quantities.
Definition types.hpp:120
detail::cloned_type< Pointer > clone(const Pointer &p)
Creates a unique clone of the object pointed to by p.
Definition utils_helper.hpp:203
batch_dim< 2, DimensionType > transpose(const batch_dim< 2, DimensionType > &input)
Returns a batch_dim object with its dimensions swapped for batched operators.
Definition batch_dim.hpp:148
detail::temporary_clone< detail::pointee< Ptr > > make_temporary_clone(std::shared_ptr< const Executor > exec, Ptr &&ptr)
Creates a temporary_clone.
Definition temporary_clone.hpp:207
Definition batch_solver_base.hpp:182
int max_iterations
Default maximum number iterations allowed.
Definition batch_solver_base.hpp:189
double tolerance
Default residual tolerance.
Definition batch_solver_base.hpp:197
std::shared_ptr< const BatchLinOpFactory > preconditioner
The preconditioner to be used by the iterative solver.
Definition batch_solver_base.hpp:211
::gko::batch::stop::tolerance_type tolerance_type
To specify which type of tolerance check is to be considered, absolute or relative (to the rhs l2 nor...
Definition batch_solver_base.hpp:204
std::shared_ptr< const BatchLinOp > generated_preconditioner
Already generated preconditioner.
Definition batch_solver_base.hpp:218