Skip to content

Commit 289711a

Browse files
committedJan 13, 2021
CHGE: Change ADLift structure (#22)
1 parent 7f620ee commit 289711a

File tree

3 files changed

+40
-31
lines changed

3 files changed

+40
-31
lines changed
 

‎src/numerical/root.rs

+13-12
Original file line numberDiff line numberDiff line change
@@ -84,9 +84,10 @@
8484
//!
8585
//! * Walter Gautschi, *Numerical Analysis*, Springer (2012)
8686
87-
use crate::structure::ad::{AD, AD::AD1, ADLift};
87+
use crate::structure::ad::{AD, AD::AD1, ADFn};
8888
use std::fmt::Display;
8989
use RootState::{I, P};
90+
use crate::traits::stable::StableFn;
9091
//use std::collections::HashMap;
9192
//use std::marker::PhantomData;
9293

@@ -264,11 +265,11 @@ impl<F: Fn(AD) -> AD> RootFinder<F> {
264265
match self.method {
265266
RootFind::Bisection => match self.curr {
266267
I(a, b) => {
267-
let lift = ADLift::new(|x| self.f(x));
268+
let f_ad = ADFn::new(|x| self.f(x));
268269
let x = 0.5 * (a + b);
269-
let fa = lift.f_0(a);
270-
let fx = lift.f_0(x);
271-
let fb = lift.f_0(b);
270+
let fa = f_ad.call_stable(a);
271+
let fx = f_ad.call_stable(x);
272+
let fb = f_ad.call_stable(b);
272273
if (a - b).abs() <= self.tol {
273274
self.find = RootBool::Find;
274275
self.root = x;
@@ -289,11 +290,11 @@ impl<F: Fn(AD) -> AD> RootFinder<F> {
289290
},
290291
RootFind::FalsePosition => match self.curr {
291292
I(a, b) => {
292-
let lift = ADLift::new(|x| self.f(x));
293-
let fa = lift.f_0(a);
294-
let fb = lift.f_0(b);
293+
let f_ad = ADFn::new(|x| self.f(x));
294+
let fa = f_ad.call_stable(a);
295+
let fb = f_ad.call_stable(b);
295296
let x = (a * fb - b * fa) / (fb - fa);
296-
let fx = lift.f_0(x);
297+
let fx = f_ad.call_stable(x);
297298
if (a - b).abs() <= self.tol || fx.abs() <= self.tol {
298299
self.find = RootBool::Find;
299300
self.root = x;
@@ -327,9 +328,9 @@ impl<F: Fn(AD) -> AD> RootFinder<F> {
327328
},
328329
RootFind::Secant => match self.curr {
329330
I(xn_1, xn) => {
330-
let lift = ADLift::new(|x| self.f(x));
331-
let fxn_1 = lift.f_0(xn_1);
332-
let fxn = lift.f_0(xn);
331+
let f_ad = ADFn::new(|x| self.f(x));
332+
let fxn_1 = f_ad.call_stable(xn_1);
333+
let fxn = f_ad.call_stable(xn);
333334
let x = xn - (xn - xn_1) / (fxn - fxn_1) * fxn;
334335
if (x - xn).abs() <= self.tol {
335336
self.find = RootBool::Find;

‎src/structure/ad.rs

+25-17
Original file line numberDiff line numberDiff line change
@@ -1055,45 +1055,53 @@ impl Div<AD> for f64 {
10551055
/// let ad0 = 2f64;
10561056
/// let ad1 = AD1(2f64, 1f64);
10571057
///
1058-
/// let lift = ADLift::new(f_ad);
1058+
/// let f_ad = ADFn::new(f);
10591059
///
10601060
/// let ans_ad0 = ad0.powi(2);
1061-
/// let ans_ad1 = ad1.powi(2);
1061+
/// let ans_ad1 = ad1.powi(2).dx();
10621062
///
1063-
/// // f_0: f64 -> f64
1064-
/// // f: AD -> AD
1065-
/// assert_eq!(ans_ad0, lift.f_0(ad0));
1066-
/// assert_eq!(ans_ad1, lift.f(ad1));
1063+
/// assert_eq!(ans_ad0, f_ad.call_stable(ad0));
1064+
///
1065+
/// let f_grad = f_ad.grad();
1066+
/// assert_eq!(ans_ad1, f_grad.call_stable(ad0));
10671067
/// }
10681068
///
1069-
/// fn f_ad(x: AD) -> AD {
1069+
/// fn f(x: AD) -> AD {
10701070
/// x.powi(2)
10711071
/// }
10721072
/// ```
1073-
pub struct ADLift<F> {
1073+
pub struct ADFn<F> {
10741074
f: Box<F>,
1075+
grad_level: usize,
10751076
}
10761077

1077-
impl<F: Fn(AD) -> AD> ADLift<F> {
1078+
impl<F: Fn(AD) -> AD + Clone> ADFn<F> {
10781079
pub fn new(f: F) -> Self {
10791080
Self {
10801081
f: Box::new(f),
1082+
grad_level: 0usize,
10811083
}
10821084
}
10831085

1084-
pub fn f(&self, t: AD) -> AD {
1085-
(self.f)(t)
1086-
}
1087-
1088-
pub fn f_0(&self, t: f64) -> f64 {
1089-
(self.f)(AD::from(t)).into()
1086+
/// Gradient
1087+
pub fn grad(&self) -> Self {
1088+
assert!(self.grad_level < 2, "Higher order AD is not allowed");
1089+
ADFn {
1090+
f: (self.f).clone(),
1091+
grad_level: self.grad_level + 1,
1092+
}
10901093
}
10911094
}
10921095

1093-
impl<F: Fn(AD) -> AD> StableFn<f64> for ADLift<F> {
1096+
impl<F: Fn(AD) -> AD> StableFn<f64> for ADFn<F> {
10941097
type Output = f64;
10951098
fn call_stable(&self, target: f64) -> Self::Output {
1096-
self.f(AD::from(target)).into()
1099+
match self.grad_level {
1100+
0 => (self.f)(AD::from(target)).into(),
1101+
1 => (self.f)(AD1(target, 1f64)).dx(),
1102+
2 => (self.f)(AD2(target, 1f64, 0f64)).ddx(),
1103+
_ => panic!("Higher order AD is not allowed"),
1104+
}
10971105
}
10981106
}
10991107

‎tests/ad_lift.rs

+2-2
Original file line numberDiff line numberDiff line change
@@ -20,8 +20,8 @@ fn test_lift_ad2() {
2020
#[test]
2121
fn test_lift_f64() {
2222
let f = 2f64;
23-
let lift = ADLift::new(f_ad);
24-
let x = lift.f_0(f);
23+
let lift = ADFn::new(f_ad);
24+
let x = lift.call_stable(f);
2525
x.print();
2626
assert_eq!(x, f * 2f64);
2727
}

0 commit comments

Comments
 (0)
Please sign in to comment.