Skip to content

Commit d87f323

Browse files
committed
Define and export a sum function
1 parent 49a893e commit d87f323

File tree

4 files changed

+30
-16
lines changed

4 files changed

+30
-16
lines changed

crates/web/src/lib.rs

+13
Original file line numberDiff line numberDiff line change
@@ -1586,6 +1586,19 @@ impl Block {
15861586
self.instr(f, id::ty(t), expr)
15871587
}
15881588

1589+
/// Return the variable ID for a new instruction accumulating `addend` into `accum`.
1590+
///
1591+
/// Assumes `accum` and `addend` are defined and in scope.
1592+
#[wasm_bindgen(js_name = "addTo")]
1593+
pub fn add_to(&mut self, f: &mut FuncBuilder, accum: usize, addend: usize) -> usize {
1594+
let t = id::ty(f.ty_unit());
1595+
let expr = rose::Expr::Add {
1596+
accum: id::var(accum),
1597+
addend: id::var(addend),
1598+
};
1599+
self.instr(f, t, expr)
1600+
}
1601+
15891602
/// Return the variable ID for a new instruction resolving the given accumulator `var`.
15901603
///
15911604
/// Assumes `var` is defined and in scope, and that `t` is the inner type of the reference type

packages/core/src/impl.ts

+13
Original file line numberDiff line numberDiff line change
@@ -1160,6 +1160,19 @@ export const vec = <const I, const T>(
11601160
return idVal(ctx, t, id) as Vec<Symbolic<T>>;
11611161
};
11621162

1163+
/** Return the sum after computing each number via `f`. */
1164+
export const sum = <const I>(index: I, f: (i: Symbolic<I>) => Real): Real => {
1165+
const ctx = getCtx();
1166+
const reals = ctx.func.tyF64();
1167+
const acc = ctx.block.accum(ctx.func, ctx.func.tyRef(reals), realId(ctx, 0));
1168+
vec(index, Null, (i) => {
1169+
const x = realId(ctx, f(i));
1170+
const t = ctx.func.tyUnit();
1171+
return idVal(ctx, t, ctx.block.addTo(ctx.func, acc, x)) as Null;
1172+
});
1173+
return idVal(ctx, reals, ctx.block.resolve(ctx.func, reals, acc)) as Real;
1174+
};
1175+
11631176
/** Return the variable ID for the abstract number or tangent `x`. */
11641177
const numId = (ctx: Context, x: Real | Tan): number => {
11651178
if (typeof x === "object") return (x as any)[variable];

packages/core/src/index.test.ts

+3-16
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,7 @@ import {
3535
sqrt,
3636
struct,
3737
sub,
38+
sum,
3839
trunc,
3940
vec,
4041
vjp,
@@ -233,12 +234,7 @@ describe("valid", () => {
233234

234235
test("dot product", () => {
235236
const R3 = Vec(3, Real);
236-
const dot = fn([R3, R3], Real, (u, v) => {
237-
const x = mul(u[0], v[0]);
238-
const y = mul(u[1], v[1]);
239-
const z = mul(u[2], v[2]);
240-
return add(add(x, y), z);
241-
});
237+
const dot = fn([R3, R3], Real, (u, v) => sum(3, (i) => mul(u[i], v[i])));
242238
const f = interp(dot);
243239
expect(f([1, 3, -5], [4, -2, -1])).toBe(3);
244240
});
@@ -280,16 +276,7 @@ describe("valid", () => {
280276

281277
const Rn = Vec(n, Real);
282278

283-
const dot = fn([Rn, Rn], Real, (u, v) => {
284-
const w = vec(n, Real, (i) => mul(u[i], v[i]));
285-
let s = w[0];
286-
s = add(s, w[1]);
287-
s = add(s, w[2]);
288-
s = add(s, w[3]);
289-
s = add(s, w[4]);
290-
s = add(s, w[5]);
291-
return s;
292-
});
279+
const dot = fn([Rn, Rn], Real, (u, v) => sum(n, (i) => mul(u[i], v[i])));
293280

294281
const m = 5;
295282
const p = 7;

packages/core/src/index.ts

+1
Original file line numberDiff line numberDiff line change
@@ -50,6 +50,7 @@ export {
5050
sqrt,
5151
struct,
5252
sub,
53+
sum,
5354
trunc,
5455
vec,
5556
vjp,

0 commit comments

Comments
 (0)