Reference documentation for deal.II version 9.1.0-pre
ad_number_traits.h
1 // ---------------------------------------------------------------------
2 //
3 // Copyright (C) 2016 - 2017 by the deal.II authors
4 //
5 // This file is part of the deal.II library.
6 //
7 // The deal.II library is free software; you can use it, redistribute
8 // it, and/or modify it under the terms of the GNU Lesser General
9 // Public License as published by the Free Software Foundation; either
10 // version 2.1 of the License, or (at your option) any later version.
11 // The full text of the license can be found in the file LICENSE.md at
12 // the top level directory of deal.II.
13 //
14 // ---------------------------------------------------------------------
15 
16 #ifndef dealii_differentiation_ad_ad_number_traits_h
17 #define dealii_differentiation_ad_ad_number_traits_h
18 
19 #include <deal.II/base/exceptions.h>
20 
21 #include <deal.II/differentiation/ad/ad_number_types.h>
22 
23 #include <boost/type_traits.hpp>
24 
25 #include <complex>
26 #include <type_traits>
27 
28 DEAL_II_NAMESPACE_OPEN
29 
30 namespace Differentiation
31 {
32  namespace AD
33  {
47  template <typename ScalarType,
48  enum NumberTypes ADNumberTypeCode,
49  typename T = void>
50  struct NumberTraits;
51 
52 
53 
68  template <typename ADNumberType, typename T = void>
70 
71 
76  namespace internal
77  {
78  // The following three classes, namely ADNumberInfoFromEnum, Marking, and
79  // ExtractData, are those that need to be implemented for each new
80  // auto-differentiable number type. This information is then used by
81  // NumberTraits and ADNumberTraits to provide a uniform interface, as used
82  // by our drivers, to the underlying number types.
83 
84 
110  template <typename ScalarType,
111  enum NumberTypes ADNumberTypeCode,
112  typename = void>
114 
115 
147  template <typename ADNumberType, typename T = void>
148  struct Marking;
149 
150 
182  template <typename ADNumberType, typename T = void>
183  struct ExtractData;
184 
185 
200  template <typename ADNumberTrait, typename T = void>
202 
203 
220  template <typename T>
221  struct NumberType;
222 
223 
230  template <typename Number>
232 
233  } // namespace internal
234 
235 
244  template <typename NumberType>
245  struct is_ad_number;
246 
247 
256  template <typename NumberType, typename = void>
258 
259 
268  template <typename NumberType, typename = void>
270 
271 
280  template <typename NumberType, typename = void>
282 
283 
292  template <typename NumberType, typename = void>
294 
295  } // namespace AD
296 } // namespace Differentiation
297 
298 
299 /* ----------- inline and template functions and specializations ----------- */
300 
301 
302 #ifndef DOXYGEN
303 
304 
305 namespace Differentiation
306 {
307  namespace AD
308  {
309  namespace internal
310  {
// Primary template: unless the detection below succeeds, a traits class
// is taken NOT to provide the information required of an AD number.
template <typename ADNumberTrait, typename>
struct HasRequiredADInfo : std::integral_constant<bool, false>
{};


/*
 * Detection specialization. It is viable only when every operand of the
 * comma expression inside decltype() is well formed, i.e. when
 * ADNumberTrait names the members type_code and is_taped as well as the
 * nested types real_type and derivative_type. If any of them is missing,
 * substitution fails and the primary template (false) is used instead.
 *
 * Each operand is converted to void so that no user-defined operator,()
 * can take part, and the trailing void() makes the decltype() yield
 * exactly void, which is the default second template argument.
 */
template <typename ADNumberTrait>
struct HasRequiredADInfo<
  ADNumberTrait,
  decltype(static_cast<void>(ADNumberTrait::type_code),
           static_cast<void>(ADNumberTrait::is_taped),
           static_cast<void>(
             std::declval<typename ADNumberTrait::real_type>()),
           static_cast<void>(
             std::declval<typename ADNumberTrait::derivative_type>()),
           void())> : std::integral_constant<bool, true>
{};
335 
336 
342  template <typename ScalarType>
343  struct Marking<
344  ScalarType,
345  typename std::enable_if<std::is_arithmetic<ScalarType>::value>::type>
346  {
354  template <typename ADNumberType>
355  static void
356  independent_variable(const ScalarType &in,
357  const unsigned int,
358  const unsigned int,
359  ADNumberType &out)
360  {
361  out = in;
362  }
363 
364  /*
365  * Initialize the state of a dependent variable.
366  */
367  template <typename ADNumberType>
368  static void
369  dependent_variable(ADNumberType &, const ScalarType &)
370  {
371  AssertThrow(
372  false,
373  ExcMessage(
374  "Floating point numbers cannot be marked as dependent variables."));
375  }
376  };
377 
378 
382  template <typename ADNumberType>
383  struct Marking<
384  ADNumberType,
385  typename std::enable_if<boost::is_complex<ADNumberType>::value>::type>
386  {
387  /*
388  * Initialize the state of an independent variable.
389  */
390  template <typename ScalarType>
391  static void
392  independent_variable(const ScalarType &in,
393  const unsigned int,
394  const unsigned int,
395  ADNumberType &out)
396  {
397  AssertThrow(
398  false,
399  ExcMessage(
400  "Marking for complex numbers has not yet been implemented."));
401  out = in;
402  }
403 
404  /*
405  * Initialize the state of a dependent variable.
406  */
407  template <typename ScalarType>
408  static void
409  dependent_variable(ADNumberType &, const ScalarType &)
410  {
411  AssertThrow(
412  false,
413  ExcMessage(
414  "Marking for complex numbers has not yet been implemented."));
415  }
416  };
417 
418  } // namespace internal
419 
420 
// Primary templates of the four classification traits. By default no type
// is classified as a taped, tapeless, real-valued or complex-valued AD
// number; types opt in through partial specializations on the second
// (std::enable_if) template parameter.
template <typename NumberType, typename>
struct is_taped_ad_number : std::integral_constant<bool, false>
{};


template <typename NumberType, typename>
struct is_tapeless_ad_number : std::integral_constant<bool, false>
{};


template <typename NumberType, typename>
struct is_real_valued_ad_number : std::integral_constant<bool, false>
{};


template <typename NumberType, typename>
struct is_complex_valued_ad_number : std::integral_constant<bool, false>
{};
439 
440 
446  template <typename NumberType>
447  struct is_ad_number
449  ADNumberTraits<typename std::decay<NumberType>::type>>
450  {};
451 
452 
457  template <typename NumberType>
458  struct is_taped_ad_number<
459  NumberType,
460  typename std::enable_if<
461  ADNumberTraits<typename std::decay<NumberType>::type>::is_taped>::type>
462  : std::true_type
463  {};
464 
465 
470  template <typename NumberType>
471  struct is_tapeless_ad_number<
472  NumberType,
473  typename std::enable_if<ADNumberTraits<
474  typename std::decay<NumberType>::type>::is_tapeless>::type>
475  : std::true_type
476  {};
477 
478 
484  template <typename NumberType>
486  NumberType,
487  typename std::enable_if<ADNumberTraits<
488  typename std::decay<NumberType>::type>::is_real_valued>::type>
489  : std::true_type
490  {};
491 
492 
498  template <typename NumberType>
500  NumberType,
501  typename std::enable_if<ADNumberTraits<
502  typename std::decay<NumberType>::type>::is_complex_valued>::type>
503  : std::true_type
504  {};
505 
506 
507  namespace internal
508  {
/*
 * Base case: a type that is not a std::complex is left unchanged.
 */
template <typename Number>
struct RemoveComplexWrapper
{
  using type = Number;
};


/*
 * Peel off one std::complex layer and defer to the trait for the wrapped
 * type, so that nested wrappers such as std::complex<std::complex<float>>
 * are removed completely. The nested alias "type" is inherited from that
 * base.
 */
template <typename Number>
struct RemoveComplexWrapper<std::complex<Number>>
  : RemoveComplexWrapper<Number>
{};
530 
531 
537  template <typename NumberType>
538  struct ExtractData<
539  NumberType,
540  typename std::enable_if<std::is_arithmetic<NumberType>::value>::type>
541  {
545  static const NumberType &
546  value(const NumberType &x)
547  {
548  return x;
549  }
550 
551 
555  static unsigned int
556  n_directional_derivatives(const NumberType &)
557  {
558  return 0;
559  }
560 
561 
565  static NumberType
566  directional_derivative(const NumberType &, const unsigned int)
567  {
568  return 0.0;
569  }
570  };
571 
572 
573 
578  template <typename ADNumberType>
579  struct ExtractData<std::complex<ADNumberType>>
580  {
582  "Expected an auto-differentiable number.");
583 
584 
588  static std::complex<typename ADNumberTraits<ADNumberType>::scalar_type>
589  value(const std::complex<ADNumberType> &x)
590  {
591  return std::complex<
593  ExtractData<ADNumberType>::value(x.real()),
594  ExtractData<ADNumberType>::value(x.imag()));
595  }
596 
597 
601  static unsigned int
602  n_directional_derivatives(const std::complex<ADNumberType> &x)
603  {
604  return ExtractData<ADNumberType>::n_directional_derivatives(x.real());
605  }
606 
607 
611  static std::complex<
613  directional_derivative(const std::complex<ADNumberType> &x,
614  const unsigned int direction)
615  {
616  return std::complex<
618  ExtractData<ADNumberType>::directional_derivative(x.real(),
619  direction),
620  ExtractData<ADNumberType>::directional_derivative(x.imag(),
621  direction));
622  }
623  };
624 
625 
/*
 * Conversion helper that extends ::internal::NumberType<T> so that
 * auto-differentiable numbers can also be converted to the target type
 * @p T. The three value() overloads are mutually exclusive; each is enabled
 * through std::enable_if on its defaulted second (pointer) argument:
 *  - F is not an AD number: forward to ::internal::NumberType<T>;
 *  - F is an AD number and T is arithmetic: extract the stored value and
 *    convert it recursively, which unwraps nested AD numbers;
 *  - F and T are both AD numbers: construct T directly from f.
 * If F is an AD number and T is neither arithmetic nor an AD number, none
 * of these overloads is viable (std::complex targets are handled by the
 * specialization for std::complex<T>).
 */
template <typename T>
struct NumberType
{
  /*
   * Non-AD input: delegate to the generic conversion helper. The return
   * type is whatever that helper returns.
   */
  template <typename F>
  static auto
  value(const F &f,
        typename std::enable_if<!is_ad_number<F>::value>::type * =
          nullptr) -> decltype(::internal::NumberType<T>::value(f))
  {
    // We call the other function defined in the numbers
    // header to take care of all of the usual cases.
    return ::internal::NumberType<T>::value(f);
  }

  /*
   * AD input converted to an arithmetic type: strip one AD layer with
   * ExtractData<F>::value() and convert what remains.
   */
  template <typename F>
  static T
  value(const F &f,
        typename std::enable_if<is_ad_number<F>::value &&
                                std::is_arithmetic<T>::value>::type * =
          nullptr)
  {
    // We recursively call this function in case the AD number is a
    // nested one. The recursion ends when the extracted value is
    // a floating point number.
    return NumberType<T>::value(ExtractData<F>::value(f));
  }

  /*
   * AD input converted to another AD type: rely on T being constructible
   * from F.
   */
  template <typename F>
  static T
  value(const F &f,
        typename std::enable_if<is_ad_number<F>::value &&
                                is_ad_number<T>::value>::type * = nullptr)
  {
    return T(f);
  }
};
677 
/*
 * Specialization of the AD-aware conversion helper for a std::complex
 * target type.
 */
template <typename T>
struct NumberType<std::complex<T>>
{
  /*
   * Non-AD input: delegate to the generic conversion helper for
   * std::complex<T>.
   */
  template <typename F>
  static auto
  value(
    const F &f,
    typename std::enable_if<!is_ad_number<F>::value>::type * = nullptr)
    -> decltype(::internal::NumberType<std::complex<T>>::value(f))
  {
    // We call the other function defined in the numbers
    // header to take care of all of the usual cases.
    return ::internal::NumberType<std::complex<T>>::value(f);
  }


  /*
   * AD input with an arithmetic component type T: the extracted value
   * becomes the real part, and the imaginary part is value-initialized
   * (zero) by the single-argument std::complex constructor.
   */
  template <typename F>
  static std::complex<T>
  value(const F &f,
        typename std::enable_if<is_ad_number<F>::value &&
                                std::is_arithmetic<T>::value>::type * =
          nullptr)
  {
    // We recursively call this function in case the AD number is a
    // nested one. The recursion ends when the extracted value is
    // a floating point number.
    return std::complex<T>(
      NumberType<T>::value(ExtractData<F>::value(f)));
  }

  /*
   * std::complex input: convert the real and imaginary parts separately.
   * Because std::complex<F> is more specialized than the generic
   * const F & parameter of the overloads above, partial ordering prefers
   * this overload whenever the argument is a std::complex.
   */
  template <typename F>
  static std::complex<T>
  value(const std::complex<F> &f)
  {
    // Deal with the two parts of the input complex
    // number individually.
    return std::complex<T>(NumberType<T>::value(f.real()),
                           NumberType<T>::value(f.imag()));
  }
};
725 
726  } // namespace internal
727 
728 
729 
750  template <typename ScalarType, enum NumberTypes ADNumberTypeCode>
751  struct NumberTraits<
752  ScalarType,
753  ADNumberTypeCode,
754  typename std::enable_if<
755  std::is_floating_point<ScalarType>::value ||
756  (boost::is_complex<ScalarType>::value &&
757  std::is_floating_point<typename internal::RemoveComplexWrapper<
758  ScalarType>::type>::value)>::type>
759  {
763  static constexpr enum NumberTypes type_code = ADNumberTypeCode;
764 
765  // The clang compiler does not seem to like these
766  // variables being defined as constant expressions
767  // (the tests <adolc|sacado>/ad_number_traits_02 will
768  // fail with linking errors). However, GCC complains
769  // about the use of non-constant expressions in
770  // std::conditional.
771 # ifdef __clang__
772 
777  static const bool is_taped;
778 
779 
784  static const bool is_tapeless;
785 
786 
791  static const bool is_real_valued;
792 
793 
798  static const bool is_complex_valued;
799 
800 
805  static const unsigned int n_supported_derivative_levels;
806 
807 # else
808 
813  static constexpr bool is_taped = internal::ADNumberInfoFromEnum<
815  ADNumberTypeCode>::is_taped;
816 
817 
822  static constexpr bool is_tapeless =
824 
825 
830  static constexpr bool is_real_valued =
831  (!boost::is_complex<ScalarType>::value);
832 
833 
838  static constexpr bool is_complex_valued =
840 
841 
846  static constexpr unsigned int n_supported_derivative_levels =
848  typename internal::RemoveComplexWrapper<ScalarType>::type,
849  ADNumberTypeCode>::n_supported_derivative_levels;
850 
851 # endif
852 
853 
858  using scalar_type = ScalarType;
859 
860 
864  using real_type = typename internal::ADNumberInfoFromEnum<
865  typename internal::RemoveComplexWrapper<ScalarType>::type,
866  ADNumberTypeCode>::real_type;
867 
868 
872  using complex_type = std::complex<real_type>;
873 
874 
878  using ad_type = typename std::
879  conditional<is_real_valued, real_type, complex_type>::type;
880 
884  using derivative_type = typename std::conditional<
885  is_real_valued,
887  typename internal::RemoveComplexWrapper<ScalarType>::type,
888  ADNumberTypeCode>::derivative_type,
889  std::complex<typename internal::ADNumberInfoFromEnum<
890  typename internal::RemoveComplexWrapper<ScalarType>::type,
891  ADNumberTypeCode>::derivative_type>>::type;
892 
893 
897  static scalar_type get_scalar_value(const ad_type &x)
898  {
899  // Some tricky conversion cases to consider here:
900  // - Nested AD numbers
901  // - std::complex<double> --> std::complex<float>
902  // e.g. when ScalarType = float and ADNumberTypeCode = adolc_taped
903  // Therefore, we use the internal casting mechanism
904  // provided by the internal::NumberType struct.
907  }
908 
909 
913  static derivative_type get_directional_derivative(
914  const ad_type &x, const unsigned int direction)
915  {
917  x, direction);
918  }
919 
920 
925  static unsigned int n_directional_derivatives(const ad_type &x)
926  {
928  }
929 
930 
931  static_assert((is_real_valued == true ?
932  std::is_same<ad_type, real_type>::value :
933  std::is_same<ad_type, complex_type>::value),
934  "Incorrect template type selected for ad_type");
935 
936  static_assert((is_complex_valued == true ?
937  boost::is_complex<scalar_type>::value :
938  true),
939  "Expected a complex float_type");
940 
941  static_assert((is_complex_valued == true ?
942  boost::is_complex<ad_type>::value :
943  true),
944  "Expected a complex ad_type");
945  };
946 
947 # ifdef __clang__
948 
949  template <typename ScalarType, enum NumberTypes ADNumberTypeCode>
950  const bool NumberTraits<
951  ScalarType,
952  ADNumberTypeCode,
953  typename std::enable_if<
954  std::is_floating_point<ScalarType>::value ||
955  (boost::is_complex<ScalarType>::value &&
956  std::is_floating_point<typename internal::RemoveComplexWrapper<
957  ScalarType>::type>::value)>::type>::is_taped =
959  typename internal::RemoveComplexWrapper<ScalarType>::type,
960  ADNumberTypeCode>::is_taped;
961 
962 
963  template <typename ScalarType, enum NumberTypes ADNumberTypeCode>
964  const bool NumberTraits<
965  ScalarType,
966  ADNumberTypeCode,
967  typename std::enable_if<
968  std::is_floating_point<ScalarType>::value ||
969  (boost::is_complex<ScalarType>::value &&
970  std::is_floating_point<typename internal::RemoveComplexWrapper<
971  ScalarType>::type>::value)>::type>::is_tapeless =
973 
974 
975  template <typename ScalarType, enum NumberTypes ADNumberTypeCode>
976  const bool NumberTraits<
977  ScalarType,
978  ADNumberTypeCode,
979  typename std::enable_if<
980  std::is_floating_point<ScalarType>::value ||
981  (boost::is_complex<ScalarType>::value &&
982  std::is_floating_point<typename internal::RemoveComplexWrapper<
983  ScalarType>::type>::value)>::type>::is_real_valued =
984  (!boost::is_complex<ScalarType>::value);
985 
986 
987  template <typename ScalarType, enum NumberTypes ADNumberTypeCode>
988  const bool NumberTraits<
989  ScalarType,
990  ADNumberTypeCode,
991  typename std::enable_if<
992  std::is_floating_point<ScalarType>::value ||
993  (boost::is_complex<ScalarType>::value &&
994  std::is_floating_point<typename internal::RemoveComplexWrapper<
995  ScalarType>::type>::value)>::type>::is_complex_valued =
997 
998 
999  template <typename ScalarType, enum NumberTypes ADNumberTypeCode>
1000  const unsigned int NumberTraits<
1001  ScalarType,
1002  ADNumberTypeCode,
1003  typename std::enable_if<
1004  std::is_floating_point<ScalarType>::value ||
1005  (boost::is_complex<ScalarType>::value &&
1006  std::is_floating_point<typename internal::RemoveComplexWrapper<
1007  ScalarType>::type>::value)>::type>::n_supported_derivative_levels =
1009  typename internal::RemoveComplexWrapper<ScalarType>::type,
1010  ADNumberTypeCode>::n_supported_derivative_levels;
1011 
1012 # endif
1013 
1014 
1031  template <typename ScalarType>
1032  struct ADNumberTraits<
1033  ScalarType,
1034  typename std::enable_if<std::is_arithmetic<ScalarType>::value>::type>
1035  {
1040  using scalar_type = ScalarType;
1041 
1042  static ScalarType
1043  get_directional_derivative(const ScalarType & /*x*/,
1044  const unsigned int /*direction*/)
1045  {
1046  // If the AD drivers are correctly implemented then we should not get
1047  // here. This is essentially a dummy for when the ADNumberTypeCode for
1048  // the original AD number (from which one is getting a derivative >= 2)
1049  // is one that specified Adol-C taped and tapeless numbers, or a
1050  // non-nested Sacado number.
1051  AssertThrow(
1052  false,
1053  ExcMessage(
1054  "Floating point numbers have no directional derivatives."));
1055  return 0.0;
1056  }
1057  };
1058 
1059  } // namespace AD
1060 } // namespace Differentiation
1061 
1062 #endif // DOXYGEN
1063 
1064 
1065 namespace numbers
1066 {
1067  template <typename ADNumberType>
1068  bool
1069  is_nan(const typename std::enable_if<
1071  ADNumberType>::type &x)
1072  {
1073  return is_nan(
1075  }
1076 
1077 } // namespace numbers
1078 
1079 
1080 DEAL_II_NAMESPACE_CLOSE
1081 
1082 #endif
STL namespace.
#define AssertThrow(cond, exc)
Definition: exceptions.h:1329
static ::ExceptionBase &ExcMessage(std::string arg1)