Reference documentation for deal.II version 9.1.0-pre
sacado_number_types.h
1 // ---------------------------------------------------------------------
2 //
3 // Copyright (C) 2016 - 2017 by the deal.II authors
4 //
5 // This file is part of the deal.II library.
6 //
7 // The deal.II library is free software; you can use it, redistribute
8 // it, and/or modify it under the terms of the GNU Lesser General
9 // Public License as published by the Free Software Foundation; either
10 // version 2.1 of the License, or (at your option) any later version.
11 // The full text of the license can be found in the file LICENSE.md at
12 // the top level directory of deal.II.
13 //
14 // ---------------------------------------------------------------------
15 
16 #ifndef dealii_differentiation_ad_sacado_number_types_h
17 #define dealii_differentiation_ad_sacado_number_types_h
18 
19 #include <deal.II/base/config.h>
20 
21 #include <type_traits>
22 
23 
24 DEAL_II_NAMESPACE_OPEN
25 
26 
27 namespace Differentiation
28 {
29  namespace AD
30  {
/**
 * Trait reporting whether @p NumberType is a Sacado number.
 *
 * This primary template inherits from std::false_type, so the answer is
 * "no" unless a specialization is selected through the second, defaulted
 * template parameter.
 */
template <class NumberType, class Enable = void>
struct is_sacado_number : public std::false_type
{
};
41 
42 
/**
 * Trait reporting whether @p NumberType is a Sacado DFad number.
 *
 * This primary template inherits from std::false_type; a positive answer
 * can only come from a specialization chosen via the second, defaulted
 * template parameter.
 */
template <class NumberType, class Enable = void>
struct is_sacado_dfad_number : public std::false_type
{
};
53 
54 
/**
 * Trait reporting whether @p NumberType is a Sacado Rad number.
 *
 * This primary template inherits from std::false_type; a positive answer
 * can only come from a specialization chosen via the second, defaulted
 * template parameter.
 */
template <class NumberType, class Enable = void>
struct is_sacado_rad_number : public std::false_type
{
};
65 
66  } // namespace AD
67 } // namespace Differentiation
68 
69 
70 DEAL_II_NAMESPACE_CLOSE
71 
72 
73 
74 #ifdef DEAL_II_TRILINOS_WITH_SACADO
75 
76 DEAL_II_DISABLE_EXTRA_DIAGNOSTICS
77 # include <Sacado.hpp>
78 // It appears that some versions of Trilinos do not directly or indirectly
79 // include all the headers for all forward and reverse Sacado AD types.
80 // So we directly include these both here as a precaution.
81 // Standard forward AD classes (templated)
82 # include <Sacado_Fad_DFad.hpp>
83 // Reverse AD classes (templated)
84 # include <Sacado_trad.hpp>
85 DEAL_II_ENABLE_EXTRA_DIAGNOSTICS
86 
87 # include <deal.II/base/exceptions.h>
88 # include <deal.II/base/numbers.h>
89 
90 # include <deal.II/differentiation/ad/ad_number_traits.h>
91 # include <deal.II/differentiation/ad/ad_number_types.h>
92 
93 # include <complex>
94 # include <type_traits>
95 
96 DEAL_II_NAMESPACE_OPEN
97 
98 
99 namespace Differentiation
100 {
101  namespace AD
102  {
103  namespace internal
104  {
/**
 * Forward declaration of the trait class that describes a Sacado number
 * type. No generic definition is given here: the partial specializations
 * further down in this header supply the members. The second template
 * parameter is only a hook for std::enable_if-based selection of those
 * specializations.
 */
template <typename SacadoNumber, typename = void>
struct SacadoNumberInfo;

116  } // namespace internal
117 
118 
119 
120  } // namespace AD
121 } // namespace Differentiation
122 
123 
124 /* ----------- inline and template functions and specializations ----------- */
125 
126 
127 # ifndef DOXYGEN
128 
129 namespace Differentiation
130 {
131  namespace AD
132  {
133  namespace internal
134  {
// The documentation on Sacado numbers is pretty sparse and/or hard to
// navigate. As a point of reference, see
// https://trilinos.org/docs/dev/packages/sacado/doc/html/classSacado_1_1Fad_1_1SimpleFad.html
// for semi-applicable documentation for the Sacado::Fad::DFad class,
// and the examples in
// https://github.com/trilinos/Trilinos/tree/master/packages/sacado/example
//
// If one dares to venture there, the relevant files for the classes
// supported here are:
//
// Forward-mode auto-differentiable types:
// https://github.com/trilinos/Trilinos/blob/master/packages/sacado/src/Sacado_Fad_DFad.hpp
// https://github.com/trilinos/Trilinos/blob/master/packages/sacado/src/Sacado_Fad_GeneralFad.hpp
//
// Reverse-mode auto-differentiable types:
// https://github.com/trilinos/Trilinos/blob/master/packages/sacado/src/Sacado_trad.hpp
151 
152 
156  template <typename SacadoNumber>
157  struct SacadoNumberInfo<
158  SacadoNumber,
159  typename std::enable_if<std::is_same<
160  SacadoNumber,
161  Sacado::Fad::DFad<typename SacadoNumber::value_type>>::value>::type>
162  {
163  using ad_type = SacadoNumber;
164  using scalar_type = typename ad_type::scalar_type;
165  using value_type = typename ad_type::value_type;
166  using derivative_type = typename ad_type::value_type;
167 
168  static const unsigned int n_supported_derivative_levels =
170  };
171 
172 
176  template <typename SacadoNumber>
177  struct SacadoNumberInfo<
178  SacadoNumber,
179  typename std::enable_if<std::is_same<
180  SacadoNumber,
181  Sacado::Rad::ADvar<typename SacadoNumber::value_type>>::value>::type>
182  {
183  using ad_type = SacadoNumber;
184  using scalar_type = typename ad_type::ADVari::scalar_type;
185  using value_type = typename ad_type::ADVari::value_type;
186  using derivative_type = typename ad_type::ADVari::value_type;
187 
188  static const unsigned int n_supported_derivative_levels =
190  };
191 
192 
199  template <typename Number>
200  struct SacadoNumberInfo<
201  Number,
202  typename std::enable_if<
203  std::is_arithmetic<typename std::decay<Number>::type>::value>::type>
204  {
205  static const unsigned int n_supported_derivative_levels = 0;
206  };
207 
208 
213  template <typename ScalarType>
214  struct ADNumberInfoFromEnum<
215  ScalarType,
217  typename std::enable_if<
218  std::is_floating_point<ScalarType>::value>::type>
219  {
220  static const bool is_taped = false;
221  using real_type = Sacado::Fad::DFad<ScalarType>;
222  using derivative_type =
224  static const unsigned int n_supported_derivative_levels =
226  };
227 
228 
233  template <typename ScalarType>
234  struct ADNumberInfoFromEnum<
235  ScalarType,
237  typename std::enable_if<
238  std::is_floating_point<ScalarType>::value>::type>
239  {
240  static const bool is_taped = false;
241  using real_type = Sacado::Fad::DFad<Sacado::Fad::DFad<ScalarType>>;
242  using derivative_type =
244  static const unsigned int n_supported_derivative_levels =
246  };
247 
248 
253  template <typename ScalarType>
254  struct ADNumberInfoFromEnum<
255  ScalarType,
257  typename std::enable_if<
258  std::is_floating_point<ScalarType>::value>::type>
259  {
260  static const bool is_taped = false;
261  using real_type = Sacado::Rad::ADvar<ScalarType>;
262  using derivative_type =
264  static const unsigned int n_supported_derivative_levels =
266  };
267 
268 
273  template <typename ScalarType>
274  struct ADNumberInfoFromEnum<
275  ScalarType,
277  typename std::enable_if<
278  std::is_floating_point<ScalarType>::value>::type>
279  {
280  static const bool is_taped = false;
281  using real_type = Sacado::Rad::ADvar<Sacado::Fad::DFad<ScalarType>>;
282  using derivative_type =
284  static const unsigned int n_supported_derivative_levels =
286  };
287 
288 
293  template <typename NumberType>
294  struct Marking<Sacado::Fad::DFad<NumberType>>
295  {
296  using ad_type =
298  using derivative_type = typename SacadoNumberInfo<
299  Sacado::Fad::DFad<NumberType>>::derivative_type;
300  using scalar_type =
301  typename SacadoNumberInfo<Sacado::Fad::DFad<NumberType>>::scalar_type;
302 
303  /*
304  * Initialize the state of an independent variable.
305  */
306  static void
307  independent_variable(const scalar_type &in,
308  const unsigned int index,
309  const unsigned int n_independent_variables,
310  ad_type & out)
311  {
312  // It is required that we first initialise the outer number before
313  // any of the nested ones.
314  out = ad_type(n_independent_variables, index, in);
315 
316  // Initialize potential nested directional derivatives
318  in, index, n_independent_variables, out.val());
319  }
320 
321  /*
322  * Initialize the state of a dependent variable.
323  */
324  static void
325  dependent_variable(ad_type &out, const ad_type &func)
326  {
327  out = func;
328  }
329  };
330 
331 
336  template <typename NumberType>
337  struct Marking<Sacado::Rad::ADvar<NumberType>>
338  {
339  using ad_type =
341  using derivative_type = typename SacadoNumberInfo<
342  Sacado::Rad::ADvar<NumberType>>::derivative_type;
343  using scalar_type = typename SacadoNumberInfo<
344  Sacado::Rad::ADvar<NumberType>>::scalar_type;
345 
346  /*
347  * Initialize the state of an independent variable.
348  */
349  static void
350  independent_variable(const scalar_type &in,
351  const unsigned int index,
352  const unsigned int n_independent_variables,
353  ad_type & out)
354  {
355  // For Sacado::Rad::ADvar numbers, we have to initialize the
356  // ADNumber with an already fully-configured value. This means
357  // that if this nests another ADNumber then the nested number
358  // must already be setup and ready for use.
359 
360  // Initialize potential nested directional derivatives
361  derivative_type derivative_initializer;
363  in, index, n_independent_variables, derivative_initializer);
364 
365  // Initialize the outer ad_type
366  out = derivative_initializer;
367  }
368 
369  /*
370  * Initialize the state of a dependent variable.
371  */
372  static void
373  dependent_variable(ad_type &out, const ad_type &func)
374  {
375  out = func;
376  }
377  };
378 
379 
387  template <typename NumberType>
388  struct ExtractData<Sacado::Fad::DFad<NumberType>>
389  {
390  using derivative_type = typename SacadoNumberInfo<
391  Sacado::Fad::DFad<NumberType>>::derivative_type;
392  using scalar_type =
393  typename SacadoNumberInfo<Sacado::Fad::DFad<NumberType>>::scalar_type;
394  using value_type =
396 
400  static scalar_type
401  value(const Sacado::Fad::DFad<NumberType> &x)
402  {
403  return ExtractData<value_type>::value(x.val());
404  }
405 
406 
410  static unsigned int
411  n_directional_derivatives(const Sacado::Fad::DFad<NumberType> &x)
412  {
413  return x.size();
414  }
415 
416 
420  static derivative_type
421  directional_derivative(const Sacado::Fad::DFad<NumberType> &x,
422  const unsigned int direction)
423  {
424  if (x.hasFastAccess())
425  return x.fastAccessDx(direction);
426  else
427  return x.dx(direction);
428  }
429  };
430 
431 
439  template <typename NumberType>
440  struct ExtractData<Sacado::Rad::ADvar<NumberType>>
441  {
442  using derivative_type = typename SacadoNumberInfo<
443  Sacado::Rad::ADvar<NumberType>>::derivative_type;
444  using scalar_type = typename SacadoNumberInfo<
445  Sacado::Rad::ADvar<NumberType>>::scalar_type;
446  using value_type =
448 
452  static scalar_type
453  value(const Sacado::Rad::ADvar<NumberType> &x)
454  {
455  return ExtractData<value_type>::value(x.val());
456  }
457 
458 
462  static unsigned int
463  n_directional_derivatives(const Sacado::Rad::ADvar<NumberType> &)
464  {
465  // There are as many directional derivatives as there are
466  // independent variables, but each independent variable can
467  // only return one directional derivative.
468  return 1;
469  }
470 
471 
479  static derivative_type
480  directional_derivative(const Sacado::Rad::ADvar<NumberType> &x,
481  const unsigned int)
482  {
483  return x.adj();
484  }
485  };
486 
487  } // namespace internal
488 
489 
490  /* -------------- NumberTypes::sacado_dfad -------------- */
491 
492 
499  template <typename ADNumberType>
500  struct ADNumberTraits<
501  ADNumberType,
502  typename std::enable_if<std::is_same<
503  ADNumberType,
504  Sacado::Fad::DFad<typename ADNumberType::scalar_type>>::value>::type>
505  : NumberTraits<typename ADNumberType::scalar_type,
506  NumberTypes::sacado_dfad>
507  {};
508 
509 
516  template <typename ADNumberType>
517  struct ADNumberTraits<
518  ADNumberType,
519  typename std::enable_if<std::is_same<
520  ADNumberType,
521  std::complex<Sacado::Fad::DFad<
522  typename ADNumberType::value_type::scalar_type>>>::value>::type>
523  : NumberTraits<
524  std::complex<typename ADNumberType::value_type::scalar_type>,
525  NumberTypes::sacado_dfad>
526  {};
527 
528 
533  template <>
534  struct NumberTraits<Sacado::Fad::DFad<float>, NumberTypes::sacado_dfad>
535  : NumberTraits<
536  typename ADNumberTraits<Sacado::Fad::DFad<float>>::scalar_type,
537  NumberTypes::sacado_dfad>
538  {};
539 
540 
545  template <>
546  struct NumberTraits<std::complex<Sacado::Fad::DFad<float>>,
548  : NumberTraits<typename ADNumberTraits<
549  std::complex<Sacado::Fad::DFad<float>>>::scalar_type,
550  NumberTypes::sacado_dfad>
551  {};
552 
553 
558  template <>
559  struct NumberTraits<Sacado::Fad::DFad<double>, NumberTypes::sacado_dfad>
560  : NumberTraits<
561  typename ADNumberTraits<Sacado::Fad::DFad<double>>::scalar_type,
562  NumberTypes::sacado_dfad>
563  {};
564 
565 
570  template <>
571  struct NumberTraits<std::complex<Sacado::Fad::DFad<double>>,
573  : NumberTraits<typename ADNumberTraits<
574  std::complex<Sacado::Fad::DFad<double>>>::scalar_type,
575  NumberTypes::sacado_dfad>
576  {};
577 
578 
579  /* -------------- NumberTypes::sacado_rad -------------- */
580 
581 
588  template <typename ADNumberType>
589  struct ADNumberTraits<
590  ADNumberType,
591  typename std::enable_if<std::is_same<
592  ADNumberType,
593  Sacado::Rad::ADvar<typename ADNumberType::ADVari::scalar_type>>::
594  value>::type>
595  : NumberTraits<typename ADNumberType::ADVari::scalar_type,
596  NumberTypes::sacado_rad>
597  {};
598 
599 
604  template <>
605  struct NumberTraits<Sacado::Rad::ADvar<float>, NumberTypes::sacado_rad>
606  : NumberTraits<
607  typename ADNumberTraits<Sacado::Rad::ADvar<float>>::scalar_type,
608  NumberTypes::sacado_rad>
609  {};
610 
611 
616  template <>
617  struct NumberTraits<Sacado::Rad::ADvar<double>, NumberTypes::sacado_rad>
618  : NumberTraits<
619  typename ADNumberTraits<Sacado::Rad::ADvar<double>>::scalar_type,
620  NumberTypes::sacado_rad>
621  {};
622 
623 
624 # ifdef DEAL_II_TRILINOS_CXX_SUPPORTS_SACADO_COMPLEX_RAD
625 
626 
633  template <typename ADNumberType>
634  struct ADNumberTraits<
635  ADNumberType,
636  typename std::enable_if<std::is_same<
637  ADNumberType,
638  std::complex<Sacado::Rad::ADvar<typename ADNumberType::value_type::
639  ADVari::scalar_type>>>::value>::type>
640  : NumberTraits<
641  std::complex<typename ADNumberType::value_type::ADVari::scalar_type>,
642  NumberTypes::sacado_rad>
643  {};
644 
645 
650  template <>
651  struct NumberTraits<std::complex<Sacado::Rad::ADvar<float>>,
653  : NumberTraits<typename ADNumberTraits<
654  std::complex<Sacado::Rad::ADvar<float>>>::scalar_type,
655  NumberTypes::sacado_rad>
656  {};
657 
658 
663  template <>
664  struct NumberTraits<std::complex<Sacado::Rad::ADvar<double>>,
666  : NumberTraits<typename ADNumberTraits<
667  std::complex<Sacado::Rad::ADvar<double>>>::scalar_type,
668  NumberTypes::sacado_rad>
669  {};
670 
671 
672 # endif
673 
674 
675  /* -------------- NumberTypes::sacado_dfad_dfad -------------- */
676 
684  template <typename ADNumberType>
685  struct ADNumberTraits<
686  ADNumberType,
687  typename std::enable_if<
688  std::is_same<ADNumberType,
689  Sacado::Fad::DFad<Sacado::Fad::DFad<
690  typename ADNumberType::scalar_type>>>::value>::type>
691  : NumberTraits<typename ADNumberType::scalar_type,
692  NumberTypes::sacado_dfad_dfad>
693  {};
694 
695 
703  template <typename ADNumberType>
704  struct ADNumberTraits<
705  ADNumberType,
706  typename std::enable_if<std::is_same<
707  ADNumberType,
708  std::complex<Sacado::Fad::DFad<Sacado::Fad::DFad<
709  typename ADNumberType::value_type::scalar_type>>>>::value>::type>
710  : NumberTraits<
711  std::complex<typename ADNumberType::value_type::scalar_type>,
712  NumberTypes::sacado_dfad_dfad>
713  {};
714 
715 
720  template <>
721  struct NumberTraits<Sacado::Fad::DFad<Sacado::Fad::DFad<float>>,
723  : NumberTraits<typename ADNumberTraits<Sacado::Fad::DFad<
724  Sacado::Fad::DFad<float>>>::scalar_type,
725  NumberTypes::sacado_dfad_dfad>
726  {};
727 
728 
733  template <>
734  struct NumberTraits<
735  std::complex<Sacado::Fad::DFad<Sacado::Fad::DFad<float>>>,
737  : NumberTraits<typename ADNumberTraits<std::complex<Sacado::Fad::DFad<
738  Sacado::Fad::DFad<float>>>>::scalar_type,
739  NumberTypes::sacado_dfad_dfad>
740  {};
741 
742 
747  template <>
748  struct NumberTraits<Sacado::Fad::DFad<Sacado::Fad::DFad<double>>,
750  : NumberTraits<typename ADNumberTraits<Sacado::Fad::DFad<
751  Sacado::Fad::DFad<double>>>::scalar_type,
752  NumberTypes::sacado_dfad_dfad>
753  {};
754 
755 
760  template <>
761  struct NumberTraits<
762  std::complex<Sacado::Fad::DFad<Sacado::Fad::DFad<double>>>,
764  : NumberTraits<typename ADNumberTraits<std::complex<Sacado::Fad::DFad<
765  Sacado::Fad::DFad<double>>>>::scalar_type,
766  NumberTypes::sacado_dfad_dfad>
767  {};
768 
769 
770  /* -------------- NumberTypes::sacado_rad_dfad -------------- */
771 
779  template <typename ADNumberType>
780  struct ADNumberTraits<
781  ADNumberType,
782  typename std::enable_if<std::is_same<
783  ADNumberType,
784  Sacado::Rad::ADvar<Sacado::Fad::DFad<
785  typename ADNumberType::ADVari::scalar_type>>>::value>::type>
786  : NumberTraits<typename ADNumberType::ADVari::scalar_type,
787  NumberTypes::sacado_rad_dfad>
788  {};
789 
790 
795  template <>
796  struct NumberTraits<Sacado::Rad::ADvar<Sacado::Fad::DFad<float>>,
798  : NumberTraits<typename ADNumberTraits<Sacado::Rad::ADvar<
799  Sacado::Fad::DFad<float>>>::scalar_type,
800  NumberTypes::sacado_rad_dfad>
801  {};
802 
803 
808  template <>
809  struct NumberTraits<Sacado::Rad::ADvar<Sacado::Fad::DFad<double>>,
811  : NumberTraits<typename ADNumberTraits<Sacado::Rad::ADvar<
812  Sacado::Fad::DFad<double>>>::scalar_type,
813  NumberTypes::sacado_rad_dfad>
814  {};
815 
816 
817 # ifdef DEAL_II_TRILINOS_CXX_SUPPORTS_SACADO_COMPLEX_RAD
818 
819 
827  template <typename ADNumberType>
828  struct ADNumberTraits<
829  ADNumberType,
830  typename std::enable_if<std::is_same<
831  ADNumberType,
832  std::complex<Sacado::Rad::ADvar<Sacado::Fad::DFad<
833  typename ADNumberType::value_type::ADVari::scalar_type>>>>::value>::
834  type>
835  : NumberTraits<
836  std::complex<typename ADNumberType::value_type::ADVari::scalar_type>,
837  NumberTypes::sacado_rad_dfad>
838  {};
839 
840 
845  template <>
846  struct NumberTraits<
847  std::complex<Sacado::Rad::ADvar<Sacado::Fad::DFad<float>>>,
849  : NumberTraits<typename ADNumberTraits<std::complex<Sacado::Rad::ADvar<
850  Sacado::Fad::DFad<float>>>>::scalar_type,
851  NumberTypes::sacado_rad_dfad>
852  {};
853 
854 
859  template <>
860  struct NumberTraits<
861  std::complex<Sacado::Rad::ADvar<Sacado::Fad::DFad<double>>>,
863  : NumberTraits<typename ADNumberTraits<std::complex<Sacado::Rad::ADvar<
864  Sacado::Fad::DFad<double>>>>::scalar_type,
865  NumberTypes::sacado_rad_dfad>
866  {};
867 
868 
869 # endif
870 
871 
872  /* -------------- Additional type traits -------------- */
873 
874 
875  template <typename NumberType>
876  struct is_sacado_dfad_number<
877  NumberType,
878  typename std::enable_if<
879  ADNumberTraits<typename std::decay<NumberType>::type>::type_code ==
880  NumberTypes::sacado_dfad ||
881  ADNumberTraits<typename std::decay<NumberType>::type>::type_code ==
882  NumberTypes::sacado_dfad_dfad>::type> : std::true_type
883  {};
884 
885 
886  template <typename NumberType>
887  struct is_sacado_dfad_number<
888  NumberType,
889  typename std::enable_if<std::is_same<
890  NumberType,
891  Sacado::Fad::Expr<typename NumberType::value_type>>::value>::type>
892  : std::true_type
893  {};
894 
895 
896  template <typename NumberType>
897  struct is_sacado_rad_number<
898  NumberType,
899  typename std::enable_if<
900  ADNumberTraits<typename std::decay<NumberType>::type>::type_code ==
901  NumberTypes::sacado_rad ||
902  ADNumberTraits<typename std::decay<NumberType>::type>::type_code ==
903  NumberTypes::sacado_rad_dfad>::type> : std::true_type
904  {};
905 
906 
907  template <typename NumberType>
908  struct is_sacado_rad_number<
909  NumberType,
910  typename std::enable_if<std::is_same<
911  NumberType,
912  Sacado::Rad::ADvari<Sacado::Fad::DFad<
913  typename NumberType::ADVari::scalar_type>>>::value>::type>
914  : std::true_type
915  {};
916 
917 
918  template <typename NumberType>
919  struct is_sacado_number<
920  NumberType,
921  typename std::enable_if<is_sacado_dfad_number<NumberType>::value ||
922  is_sacado_rad_number<NumberType>::value>::type>
923  : std::true_type
924  {};
925 
926  } // namespace AD
927 } // namespace Differentiation
928 
929 # endif // DOXYGEN
930 
931 
932 
933 DEAL_II_NAMESPACE_CLOSE
934 
935 
936 #endif // DEAL_II_TRILINOS_WITH_SACADO
937 
938 #endif
STL namespace.