Reference documentation for deal.II version 9.1.0-pre
cuda_vector.h
1 // ---------------------------------------------------------------------
2 //
3 // Copyright (C) 2016 - 2018 by the deal.II authors
4 //
5 // This file is part of the deal.II library.
6 //
7 // The deal.II library is free software; you can use it, redistribute
8 // it, and/or modify it under the terms of the GNU Lesser General
9 // Public License as published by the Free Software Foundation; either
10 // version 2.1 of the License, or (at your option) any later version.
11 // The full text of the license can be found in the file LICENSE.md at
12 // the top level directory of deal.II.
13 //
14 // ---------------------------------------------------------------------
15 
16 #ifndef dealii_cuda_vector_h
17 #define dealii_cuda_vector_h
18 
19 #include <deal.II/base/config.h>
20 
21 #include <deal.II/base/exceptions.h>
22 #include <deal.II/base/index_set.h>
23 
24 #include <deal.II/lac/vector_operation.h>
25 #include <deal.II/lac/vector_space_vector.h>
26 
27 #ifdef DEAL_II_WITH_CUDA
28 
29 DEAL_II_NAMESPACE_OPEN
30 
31 class CommunicationPatternBase;
32 template <typename Number>
33 class ReadWriteVector;
34 
35 namespace LinearAlgebra
36 {
40  namespace CUDAWrappers
41  {
52  template <typename Number>
53  class Vector : public VectorSpaceVector<Number>
54  {
55  public:
56  using value_type = typename VectorSpaceVector<Number>::value_type;
57  using size_type = typename VectorSpaceVector<Number>::size_type;
58  using real_type = typename VectorSpaceVector<Number>::real_type;
59 
63  Vector();
64 
68  Vector(const Vector<Number> &V);
69 
80  explicit Vector(const size_type n);
81 
85  ~Vector();
86 
92  void
93  reinit(const size_type n, const bool omit_zeroing_entries = false);
94 
99  virtual void
101  const bool omit_zeroing_entries = false) override;
102 
111  virtual void
112  import(
113  const ReadWriteVector<Number> & V,
114  VectorOperation::values operation,
115  std::shared_ptr<const CommunicationPatternBase> communication_pattern =
116  std::shared_ptr<const CommunicationPatternBase>()) override;
117 
122  virtual Vector<Number> &
123  operator=(const Number s) override;
124 
128  virtual Vector<Number> &
129  operator*=(const Number factor) override;
130 
134  virtual Vector<Number> &
135  operator/=(const Number factor) override;
136 
140  virtual Vector<Number> &
141  operator+=(const VectorSpaceVector<Number> &V) override;
142 
146  virtual Vector<Number> &
147  operator-=(const VectorSpaceVector<Number> &V) override;
148 
152  virtual Number
153  operator*(const VectorSpaceVector<Number> &V) const override;
154 
158  virtual void
159  add(const Number a) override;
160 
164  virtual void
165  add(const Number a, const VectorSpaceVector<Number> &V) override;
166 
170  virtual void
171  add(const Number a,
172  const VectorSpaceVector<Number> &V,
173  const Number b,
174  const VectorSpaceVector<Number> &W) override;
175 
180  virtual void
181  sadd(const Number s,
182  const Number a,
183  const VectorSpaceVector<Number> &V) override;
184 
190  virtual void
191  scale(const VectorSpaceVector<Number> &scaling_factors) override;
192 
196  virtual void
197  equ(const Number a, const VectorSpaceVector<Number> &V) override;
198 
202  virtual bool
203  all_zero() const override;
204 
208  virtual value_type
209  mean_value() const override;
210 
215  virtual real_type
216  l1_norm() const override;
217 
222  virtual real_type
223  l2_norm() const override;
224 
229  virtual real_type
230  linfty_norm() const override;
231 
251  virtual Number
252  add_and_dot(const Number a,
253  const VectorSpaceVector<Number> &V,
254  const VectorSpaceVector<Number> &W) override;
255 
259  Number *
260  get_values() const;
261 
265  virtual size_type
266  size() const override;
267 
272  virtual ::IndexSet
273  locally_owned_elements() const override;
274 
278  virtual void
279  print(std::ostream & out,
280  const unsigned int precision = 2,
281  const bool scientific = true,
282  const bool across = true) const override;
283 
287  virtual std::size_t
288  memory_consumption() const override;
289 
296 
297  private:
301  Number *val;
302 
306  size_type n_elements;
307  };
308 
309 
310 
311  // ---------------------------- Inline functions --------------------------
312  template <typename Number>
313  inline Number *
315  {
316  return val;
317  }
318 
319 
320 
321  template <typename Number>
322  inline typename Vector<Number>::size_type
324  {
325  return n_elements;
326  }
327 
328 
329  template <typename Number>
330  inline IndexSet
332  {
333  return complete_index_set(n_elements);
334  }
335  } // namespace CUDAWrappers
336 } // namespace LinearAlgebra
337 
338 DEAL_II_NAMESPACE_CLOSE
339 
340 #endif
341 
342 #endif
virtual Number add_and_dot(const Number a, const VectorSpaceVector< Number > &V, const VectorSpaceVector< Number > &W) override
virtual void equ(const Number a, const VectorSpaceVector< Number > &V) override
void reinit(const size_type n, const bool omit_zeroing_entries=false)
virtual Vector< Number > & operator-=(const VectorSpaceVector< Number > &V) override
virtual real_type l1_norm() const override
virtual value_type mean_value() const override
virtual ::IndexSet locally_owned_elements() const override
Definition: cuda_vector.h:331
virtual void add(const Number a) override
#define DeclException0(Exception0)
Definition: exceptions.h:385
static::ExceptionBase & ExcVectorTypeNotCompatible()
virtual Vector< Number > & operator*=(const Number factor) override
virtual bool all_zero() const override
virtual Number operator*(const VectorSpaceVector< Number > &V) const override
virtual Vector< Number > & operator/=(const Number factor) override
virtual std::size_t memory_consumption() const override
virtual real_type l2_norm() const override
virtual Vector< Number > & operator+=(const VectorSpaceVector< Number > &V) override
virtual void scale(const VectorSpaceVector< Number > &scaling_factors) override
virtual Vector< Number > & operator=(const Number s) override
virtual void sadd(const Number s, const Number a, const VectorSpaceVector< Number > &V) override
virtual void print(std::ostream &out, const unsigned int precision=2, const bool scientific=true, const bool across=true) const override
virtual size_type size() const override
Definition: cuda_vector.h:323
virtual real_type linfty_norm() const override