16 #ifndef dealii_differentiation_ad_ad_number_traits_h 17 #define dealii_differentiation_ad_ad_number_traits_h 19 #include <deal.II/base/exceptions.h> 21 #include <deal.II/differentiation/ad/ad_number_types.h> 23 #include <boost/type_traits.hpp> 26 #include <type_traits> 28 DEAL_II_NAMESPACE_OPEN
47 template <
typename ScalarType,
68 template <
typename ADNumberType,
typename T =
void>
110 template <
typename ScalarType,
147 template <
typename ADNumberType,
typename T =
void>
182 template <
typename ADNumberType,
typename T =
void>
200 template <
typename ADNumberTrait,
typename T =
void>
220 template <
typename T>
230 template <
typename Number>
244 template <
typename NumberType>
256 template <
typename NumberType,
typename =
void>
268 template <
typename NumberType,
typename =
void>
280 template <
typename NumberType,
typename =
void>
292 template <
typename NumberType,
typename =
void>
311 template <
typename ADNumberTrait,
typename>
312 struct HasRequiredADInfo : std::false_type
326 template <
typename ADNumberTrait>
327 struct HasRequiredADInfo<
329 decltype((void)ADNumberTrait::type_code,
330 (void)ADNumberTrait::is_taped,
331 (void)std::declval<typename ADNumberTrait::real_type>(),
332 (void)std::declval<typename ADNumberTrait::derivative_type>(),
333 void())> : std::true_type
342 template <
typename ScalarType>
345 typename
std::enable_if<std::is_arithmetic<ScalarType>::value>::type>
354 template <
typename ADNumberType>
356 independent_variable(
const ScalarType &in,
367 template <
typename ADNumberType>
369 dependent_variable(ADNumberType &,
const ScalarType &)
374 "Floating point numbers cannot be marked as dependent variables."));
382 template <
typename ADNumberType>
385 typename
std::enable_if<boost::is_complex<ADNumberType>::value>::type>
390 template <
typename ScalarType>
392 independent_variable(
const ScalarType &in,
400 "Marking for complex numbers has not yet been implemented."));
407 template <
typename ScalarType>
409 dependent_variable(ADNumberType &,
const ScalarType &)
414 "Marking for complex numbers has not yet been implemented."));
421 template <
typename NumberType,
typename>
426 template <
typename NumberType,
typename>
431 template <
typename NumberType,
typename>
436 template <
typename NumberType,
typename>
446 template <
typename NumberType>
449 ADNumberTraits<typename std::decay<NumberType>::type>>
457 template <
typename NumberType>
460 typename
std::enable_if<
461 ADNumberTraits<typename std::decay<NumberType>::type>::is_taped>::type>
470 template <
typename NumberType>
473 typename std::enable_if<ADNumberTraits<
474 typename std::decay<NumberType>::type>::is_tapeless>::type>
484 template <
typename NumberType>
487 typename std::enable_if<ADNumberTraits<
488 typename std::decay<NumberType>::type>::is_real_valued>::type>
498 template <
typename NumberType>
501 typename std::enable_if<ADNumberTraits<
502 typename std::decay<NumberType>::type>::is_complex_valued>::type>
513 template <
typename Number>
514 struct RemoveComplexWrapper
525 template <
typename Number>
526 struct RemoveComplexWrapper<std::complex<Number>>
528 using type =
typename RemoveComplexWrapper<Number>::type;
537 template <
typename NumberType>
540 typename std::enable_if<std::is_arithmetic<NumberType>::value>::type>
545 static const NumberType &
546 value(
const NumberType &x)
556 n_directional_derivatives(
const NumberType &)
566 directional_derivative(
const NumberType &,
const unsigned int)
578 template <
typename ADNumberType>
579 struct ExtractData<std::complex<ADNumberType>>
582 "Expected an auto-differentiable number.");
588 static std::complex<typename ADNumberTraits<ADNumberType>::scalar_type>
589 value(
const std::complex<ADNumberType> &x)
593 ExtractData<ADNumberType>::value(x.real()),
594 ExtractData<ADNumberType>::value(x.imag()));
602 n_directional_derivatives(
const std::complex<ADNumberType> &x)
604 return ExtractData<ADNumberType>::n_directional_derivatives(x.real());
613 directional_derivative(
const std::complex<ADNumberType> &x,
614 const unsigned int direction)
618 ExtractData<ADNumberType>::directional_derivative(x.real(),
620 ExtractData<ADNumberType>::directional_derivative(x.imag(),
626 template <
typename T>
632 template <
typename F>
640 return ::internal::NumberType<T>::value(f);
649 template <
typename F>
653 std::is_arithmetic<T>::value>::type * =
659 return NumberType<T>::value(ExtractData<F>::value(f));
668 template <
typename F>
678 template <
typename T>
679 struct NumberType<std::complex<T>>
684 template <
typename F>
693 return ::internal::NumberType<std::complex<T>>::value(f);
701 template <
typename F>
702 static std::complex<T>
705 std::is_arithmetic<T>::value>::type * =
711 return std::complex<T>(
712 NumberType<T>::value(ExtractData<F>::value(f)));
715 template <
typename F>
716 static std::complex<T>
717 value(
const std::complex<F> &f)
721 return std::complex<T>(NumberType<T>::value(f.real()),
722 NumberType<T>::value(f.imag()));
750 template <
typename ScalarType, enum NumberTypes ADNumberTypeCode>
754 typename std::enable_if<
755 std::is_floating_point<ScalarType>::value ||
756 (boost::is_complex<ScalarType>::value &&
757 std::is_floating_point<typename internal::RemoveComplexWrapper<
758 ScalarType>::type>::value)>::type>
763 static constexpr
enum NumberTypes type_code = ADNumberTypeCode;
777 static const bool is_taped;
784 static const bool is_tapeless;
791 static const bool is_real_valued;
798 static const bool is_complex_valued;
805 static const unsigned int n_supported_derivative_levels;
815 ADNumberTypeCode>::is_taped;
822 static constexpr
bool is_tapeless =
830 static constexpr
bool is_real_valued =
831 (!boost::is_complex<ScalarType>::value);
838 static constexpr
bool is_complex_valued =
846 static constexpr
unsigned int n_supported_derivative_levels =
848 typename internal::RemoveComplexWrapper<ScalarType>::type,
849 ADNumberTypeCode>::n_supported_derivative_levels;
858 using scalar_type = ScalarType;
865 typename internal::RemoveComplexWrapper<ScalarType>::type,
866 ADNumberTypeCode>::real_type;
872 using complex_type = std::complex<real_type>;
878 using ad_type =
typename std::
879 conditional<is_real_valued, real_type, complex_type>::type;
884 using derivative_type =
typename std::conditional<
887 typename internal::RemoveComplexWrapper<ScalarType>::type,
888 ADNumberTypeCode>::derivative_type,
890 typename internal::RemoveComplexWrapper<ScalarType>::type,
891 ADNumberTypeCode>::derivative_type>>::type;
897 static scalar_type get_scalar_value(
const ad_type &x)
913 static derivative_type get_directional_derivative(
914 const ad_type &x,
const unsigned int direction)
925 static unsigned int n_directional_derivatives(
const ad_type &x)
931 static_assert((is_real_valued ==
true ?
932 std::is_same<ad_type, real_type>::value :
933 std::is_same<ad_type, complex_type>::value),
934 "Incorrect template type selected for ad_type");
936 static_assert((is_complex_valued ==
true ?
937 boost::is_complex<scalar_type>::value :
939 "Expected a complex float_type");
941 static_assert((is_complex_valued ==
true ?
942 boost::is_complex<ad_type>::value :
944 "Expected a complex ad_type");
949 template <
typename ScalarType, enum NumberTypes ADNumberTypeCode>
953 typename std::enable_if<
954 std::is_floating_point<ScalarType>::value ||
955 (boost::is_complex<ScalarType>::value &&
957 ScalarType>::type>::value)>::type>::is_taped =
959 typename internal::RemoveComplexWrapper<ScalarType>::type,
960 ADNumberTypeCode>::is_taped;
963 template <
typename ScalarType, enum NumberTypes ADNumberTypeCode>
967 typename std::enable_if<
968 std::is_floating_point<ScalarType>::value ||
969 (boost::is_complex<ScalarType>::value &&
971 ScalarType>::type>::value)>::type>::is_tapeless =
975 template <
typename ScalarType, enum NumberTypes ADNumberTypeCode>
979 typename std::enable_if<
980 std::is_floating_point<ScalarType>::value ||
981 (boost::is_complex<ScalarType>::value &&
983 ScalarType>::type>::value)>::type>::is_real_valued =
984 (!boost::is_complex<ScalarType>::value);
987 template <
typename ScalarType, enum NumberTypes ADNumberTypeCode>
991 typename std::enable_if<
992 std::is_floating_point<ScalarType>::value ||
993 (boost::is_complex<ScalarType>::value &&
995 ScalarType>::type>::value)>::type>::is_complex_valued =
999 template <
typename ScalarType, enum NumberTypes ADNumberTypeCode>
1003 typename std::enable_if<
1004 std::is_floating_point<ScalarType>::value ||
1005 (boost::is_complex<ScalarType>::value &&
1007 ScalarType>::type>::value)>::type>::n_supported_derivative_levels =
1009 typename internal::RemoveComplexWrapper<ScalarType>::type,
1010 ADNumberTypeCode>::n_supported_derivative_levels;
1031 template <
typename ScalarType>
1034 typename std::enable_if<std::is_arithmetic<ScalarType>::value>::type>
1040 using scalar_type = ScalarType;
1043 get_directional_derivative(
const ScalarType & ,
1044 const unsigned int )
1054 "Floating point numbers have no directional derivatives."));
1067 template <
typename ADNumberType>
1069 is_nan(
const typename std::enable_if<
1071 ADNumberType>::type &x)
1080 DEAL_II_NAMESPACE_CLOSE
#define AssertThrow(cond, exc)
static::ExceptionBase & ExcMessage(std::string arg1)