Sacado Package Browser (Single Doxygen Collection) Version of the Day
Loading...
Searching...
No Matches
trad_example.cpp
Go to the documentation of this file.
1// $Id$
2// $Source$
3// @HEADER
4// ***********************************************************************
5//
6// Sacado Package
7// Copyright (2006) Sandia Corporation
8//
9// Under the terms of Contract DE-AC04-94AL85000 with Sandia Corporation,
10// the U.S. Government retains certain rights in this software.
11//
12// This library is free software; you can redistribute it and/or modify
13// it under the terms of the GNU Lesser General Public License as
14// published by the Free Software Foundation; either version 2.1 of the
15// License, or (at your option) any later version.
16//
17// This library is distributed in the hope that it will be useful, but
18// WITHOUT ANY WARRANTY; without even the implied warranty of
19// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
20// Lesser General Public License for more details.
21//
22// You should have received a copy of the GNU Lesser General Public
23// License along with this library; if not, write to the Free Software
24// Foundation, Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301
25// USA
26// Questions? Contact David M. Gay (dmgay@sandia.gov) or Eric T. Phipps
27// (etphipp@sandia.gov).
28//
29// ***********************************************************************
30// @HEADER
31
// trad_example
//
// usage:
//   trad_example
//
// output:
//   prints the results of differentiating a simple function with reverse-mode
//   AD using the Sacado::Rad::ADvar class, alongside the analytic derivatives
//   and the difference between the two.
40
41#include <iostream>
42#include <iomanip>
43
44#include "Sacado_No_Kokkos.hpp"
45
// Function differentiated by this example:
//   r(a, b, c) = c * log(b + 1) / sin(a)
// Templated on the scalar type so the same expression is evaluated with plain
// doubles (reference value) and with Sacado::Rad::ADvar<double> (reverse-mode AD).
template <typename ScalarT>
ScalarT func(const ScalarT& a, const ScalarT& b, const ScalarT& c) {
  // Same evaluation order as c*log(b+1)/sin(a): multiply first, then divide.
  ScalarT scaled_log = c * std::log(b + 1.);
  return scaled_log / std::sin(a);
}
53
// Analytic partial derivatives of func(a, b, c) = c*log(b+1)/sin(a)
// with respect to a and b:
//   dr/da = -c*log(b+1)*cos(a) / sin(a)^2
//   dr/db =  c / ((b+1)*sin(a))
// These serve as the reference values that the reverse-mode AD results are
// checked against. Results are written to the out-parameters drda and drdb.
void func_deriv(double a, double b, double c, double& drda, double& drdb)
{
  // sin(a) appears in both partials: evaluate it once, and square it with a
  // plain multiply instead of std::pow(..., 2.), which is a general
  // transcendental call.
  const double sin_a = std::sin(a);
  drda = -(c*std::log(b+1.)/(sin_a*sin_a))*std::cos(a);
  drdb = c / ((b+1.)*sin_a);
}
60
61int main(int argc, char **argv)
62{
63 double pi = std::atan(1.0)*4.0;
64
65 // Values of function arguments
66 double a = pi/4;
67 double b = 2.0;
68 double c = 3.0;
69
70 // Rad objects
73 Sacado::Rad::ADvar<double> crad = c; // Passive variable
74 Sacado::Rad::ADvar<double> rrad; // Result
75
76 // Compute function
77 double r = func(a, b, c);
78
79 // Compute derivative analytically
80 double drda, drdb;
81 func_deriv(a, b, c, drda, drdb);
82
83 // Compute function and derivative with AD
84 rrad = func(arad, brad, crad);
85
87
88 // Extract value and derivatives
89 double r_ad = rrad.val(); // r
90 double drda_ad = arad.adj(); // dr/da
91 double drdb_ad = brad.adj(); // dr/db
92
93 // Free Rad's memory to avoid memory leaks
95
96 // Print the results
97 int p = 4;
98 int w = p+7;
99 std::cout.setf(std::ios::scientific);
100 std::cout.precision(p);
101 std::cout << " r = " << r << " (original) == " << std::setw(w) << r_ad
102 << " (AD) Error = " << std::setw(w) << r - r_ad << std::endl
103 << "dr/da = " << std::setw(w) << drda << " (analytic) == "
104 << std::setw(w) << drda_ad << " (AD) Error = " << std::setw(w)
105 << drda - drda_ad << std::endl
106 << "dr/db = " << std::setw(w) << drdb << " (analytic) == "
107 << std::setw(w) << drdb_ad << " (AD) Error = " << std::setw(w)
108 << drdb - drdb_ad << std::endl;
109
110 double tol = 1.0e-14;
111
112 if (std::fabs(r - r_ad) < tol &&
113 std::fabs(drda - drda_ad) < tol &&
114 std::fabs(drdb - drdb_ad) < tol) {
115 std::cout << "\nExample passed!" << std::endl;
116 return 0;
117 }
118 else {
119 std::cout <<"\nSomething is wrong, example failed!" << std::endl;
120 return 1;
121 }
122}
expr expr1 expr1 expr1 c expr2 expr1 expr2 expr1 expr2 expr1 expr1 expr1 expr1 c expr2 expr1 expr2 expr1 expr2 expr1 expr1 expr1 expr1 c *expr2 expr1 expr2 expr1 expr2 expr1 expr1 expr1 expr1 c expr2 expr1 expr2 expr1 expr2 expr1 expr1 expr1 expr2 expr1 expr2 expr1 expr1 expr1 expr2 expr1 expr2 expr1 expr1 expr1 c
int main()
static void Gradcomp()
const char * p
void func_deriv(double a, double b, double c, double &drda, double &drdb)
ScalarT func(const ScalarT &a, const ScalarT &b, const ScalarT &c)
const double tol