Skip to content

Commit 580acaf

Browse files
committed
Use power Hessian example
1 parent 02cc90b commit 580acaf

File tree

3 files changed

+68
-29
lines changed

3 files changed

+68
-29
lines changed

README.md

+24-9
Original file line numberDiff line numberDiff line change
@@ -33,21 +33,36 @@ bun add rose
3333

3434
## Usage
3535

36-
This example computes the output and gradient of a simple function that
37-
multiplies together the two components of a vector:
36+
This example defines custom gradients for the builtin JavaScript logarithm and
37+
power functions, then computes the output, gradient, and Hessian for the power
38+
function applied with base 2 and exponent 3:
3839

3940
```js
40-
import { Real, Vec, fn, interp, mul, vjp } from "rose";
41+
import { Dual, Real, Vec, add, compile, div, fn, mul, opaque, vjp } from "rose";
4142

42-
const f = fn([Vec(2, Real)], Real, (v) => mul(v[0], v[1]));
43+
const log = opaque([Real], Real, Math.log);
44+
log.jvp = fn([Dual], Dual, ({ re: x, du: dx }) => {
45+
return { re: log(x), du: div(dx, x) };
46+
});
47+
48+
const pow = opaque([Real, Real], Real, Math.pow);
49+
pow.jvp = fn([Dual, Dual], Dual, ({ re: x, du: dx }, { re: y, du: dy }) => {
50+
const z = pow(x, y);
51+
return { re: z, du: mul(add(mul(dx, div(y, x)), mul(dy, log(x))), z) };
52+
});
53+
54+
const Vec2 = Vec(2, Real);
55+
const Mat2 = Vec(2, Vec2);
4356

44-
const g = fn([Real, Real], Vec(3, Real), (x, y) => {
45-
const { ret, grad } = vjp(f)([x, y]);
46-
const v = grad(1);
47-
return [ret, v[0], v[1]];
57+
const f = fn([Vec2], Real, (v) => pow(v[0], v[1]));
58+
const g = fn([Vec2], Vec2, (v) => vjp(f)(v).grad(1));
59+
const h = fn([Vec2], Mat2, (v) => {
60+
const { grad } = vjp(g)(v);
61+
return [grad([1, 0]), grad([0, 1])];
4862
});
4963

50-
console.log(interp(g)(2, 3)); // [6, 3, 2]
64+
const funcs = await Promise.all([compile(f), compile(g), compile(h)]);
65+
console.log(funcs.map((func) => func([2, 3])));
5166
```
5267

5368
### With Vite

packages/site/index.html

+17-19
Original file line numberDiff line numberDiff line change
@@ -31,33 +31,31 @@
3131
</div>
3232
<div class="example">
3333
<pre><code class="language-javascript">import * as rose from "rose";
34-
import { Real, Vec, fn, pow } from "rose";
34+
import { Dual, Real, Vec, add, compile, div, fn, mul, opaque, vjp } from "rose";
3535

36-
const Vec2 = Vec(Real, 2);
37-
const Mat2 = Vec(Vec2, 2);
38-
39-
const f = fn([Vec2], Real, (v) => {
40-
return pow(v[0], v[1]);
36+
const log = opaque([Real], Real, Math.log);
37+
log.jvp = fn([Dual], Dual, ({ re: x, du: dx }) => {
38+
return { re: log(x), du: div(dx, x) };
4139
});
4240

43-
const g = fn([Vec2], Vec2, (v) => {
44-
return rose.vjp(f)(v).vjp(1);
41+
const pow = opaque([Real, Real], Real, Math.pow);
42+
pow.jvp = fn([Dual, Dual], Dual, ({ re: x, du: dx }, { re: y, du: dy }) => {
43+
const z = pow(x, y);
44+
return { re: z, du: mul(add(mul(dx, div(y, x)), mul(dy, log(x))), z) };
4545
});
4646

47+
const Vec2 = Vec(2, Real);
48+
const Mat2 = Vec(2, Vec2);
49+
50+
const f = fn([Vec2], Real, (v) => pow(v[0], v[1]));
51+
const g = fn([Vec2], Vec2, (v) => vjp(f)(v).grad(1));
4752
const h = fn([Vec2], Mat2, (v) => {
48-
const x = rose.vjp(g)(v);
49-
return [x.vjp([1, 0]), x.vjp([0, 1])];
53+
const { grad } = vjp(g)(v);
54+
return [grad([1, 0]), grad([0, 1])];
5055
});
5156

52-
const l = await Promise.all([
53-
rose.compile(f),
54-
rose.compile(g),
55-
rose.compile(h),
56-
]);
57-
58-
export default (x, y) => {
59-
return l.map((func) => func([x, y]));
60-
};</code></pre>
57+
const funcs = await Promise.all([compile(f), compile(g), compile(h)]);
58+
console.log(funcs.map((func) => func([2, 3])));</code></pre>
6159
</div>
6260
<div class="bottom">
6361
<a href="https://github.com/rose-lang/rose-icons">icons</a> by

packages/site/src/main.ts

+27-1
Original file line numberDiff line numberDiff line change
@@ -5,4 +5,30 @@ import "highlight.js/styles/base16/helios.css";
55
hljs.registerLanguage("javascript", javascript);
66
hljs.highlightAll();
77

8-
console.log(await import("rose"));
8+
import("rose").then(
9+
async ({ Dual, Real, Vec, add, compile, div, fn, mul, opaque, vjp }) => {
10+
const log = opaque([Real], Real, Math.log);
11+
log.jvp = fn([Dual], Dual, ({ re: x, du: dx }) => {
12+
return { re: log(x), du: div(dx, x) };
13+
});
14+
15+
const pow = opaque([Real, Real], Real, Math.pow);
16+
pow.jvp = fn([Dual, Dual], Dual, ({ re: x, du: dx }, { re: y, du: dy }) => {
17+
const z = pow(x, y);
18+
return { re: z, du: mul(add(mul(dx, div(y, x)), mul(dy, log(x))), z) };
19+
});
20+
21+
const Vec2 = Vec(2, Real);
22+
const Mat2 = Vec(2, Vec2);
23+
24+
const f = fn([Vec2], Real, (v) => pow(v[0], v[1]));
25+
const g = fn([Vec2], Vec2, (v) => vjp(f)(v).grad(1));
26+
const h = fn([Vec2], Mat2, (v) => {
27+
const { grad } = vjp(g)(v);
28+
return [grad([1, 0] as any), grad([0, 1] as any)];
29+
});
30+
31+
const funcs = await Promise.all([compile(f), compile(g), compile(h)]);
32+
console.log(funcs.map((func) => func([2, 3])));
33+
},
34+
);

0 commit comments

Comments
 (0)