This commit is contained in:
Anatoly Antonov 2026-05-18 03:17:17 +09:00
parent 7a8d5d13fa
commit 9e663db9dc
68 changed files with 5647 additions and 2958 deletions

11
internal/numerics/doc.go Normal file
View file

@ -0,0 +1,11 @@
// Package numerics provides the numerical primitives used by the trajectory
// engine: regular-grid multilinear interpolation, monotone bisection, and
// a generic explicit Runge-Kutta-4 integrator with binary-search refinement
// of a termination point.
//
// The package has no dependencies on any domain type. State and derivative
// types are generic, and all coordinate-wrap or unit-conversion semantics
// live in the caller.
//
// All algorithms are documented in docs/numerics.tex.
package numerics

86
internal/numerics/grid.go Normal file
View file

@ -0,0 +1,86 @@
package numerics
import "fmt"
// Axis describes a regularly-spaced grid axis with N grid points,
// values left, left+step, left+2*step, ..., left+(N-1)*step.
//
// If Wrap is true, the axis is periodic with period N*step (e.g. longitude).
// A query value at left+N*step wraps to the value at left+0*step. Locate
// returns Hi = 0 in that case.
type Axis struct {
Left float64
Step float64
N int
Wrap bool
Name string
}
// AxisError is returned by Axis.Locate when value lies outside a non-wrapping axis.
type AxisError struct {
Axis string
Value float64
}
func (e *AxisError) Error() string {
return fmt.Sprintf("%s=%v out of range", e.Axis, e.Value)
}
// Bracket holds the two surrounding grid indices and the fractional position
// of a value within an axis. The weight at Lo is (1 - Frac); the weight at Hi
// is Frac. Frac lies in [0, 1).
type Bracket struct {
Lo, Hi int
Frac float64
}
// Locate returns the bracket containing value within the axis.
// For a non-wrapping axis, value must lie in [Left, Left + (N-1)*Step);
// for a wrapping axis, value must lie in [Left, Left + N*Step).
func (a Axis) Locate(value float64) (Bracket, error) {
pos := (value - a.Left) / a.Step
lo := int(pos) // truncates toward zero; pos is non-negative for valid inputs
maxLo := a.N - 2
if a.Wrap {
maxLo = a.N - 1
}
if lo < 0 || lo > maxLo {
return Bracket{}, &AxisError{Axis: a.Name, Value: value}
}
hi := lo + 1
if a.Wrap && hi == a.N {
hi = 0
}
return Bracket{Lo: lo, Hi: hi, Frac: pos - float64(lo)}, nil
}
// EvalTrilinear samples a 3D field via f at the eight corners defined by b3
// and returns the trilinearly interpolated value.
//
// The corners are visited in the order (axis0 outer, axis2 inner), matching
// the Cython reference. With f(i,j,k) = a*i + b*j + c*k + d this returns
// a*pos0 + b*pos1 + c*pos2 + d exactly, modulo floating-point rounding.
func EvalTrilinear(b3 [3]Bracket, f func(i, j, k int) float64) float64 {
wa0, wa1 := 1-b3[0].Frac, b3[0].Frac
wb0, wb1 := 1-b3[1].Frac, b3[1].Frac
wc0, wc1 := 1-b3[2].Frac, b3[2].Frac
a0, a1 := b3[0].Lo, b3[0].Hi
bb0, bb1 := b3[1].Lo, b3[1].Hi
c0, c1 := b3[2].Lo, b3[2].Hi
return wa0*wb0*wc0*f(a0, bb0, c0) +
wa0*wb0*wc1*f(a0, bb0, c1) +
wa0*wb1*wc0*f(a0, bb1, c0) +
wa0*wb1*wc1*f(a0, bb1, c1) +
wa1*wb0*wc0*f(a1, bb0, c0) +
wa1*wb0*wc1*f(a1, bb0, c1) +
wa1*wb1*wc0*f(a1, bb1, c0) +
wa1*wb1*wc1*f(a1, bb1, c1)
}
// Lerp returns (1-l)*a + l*b.
func Lerp(a, b, l float64) float64 {
return (1-l)*a + l*b
}

View file

@ -0,0 +1,94 @@
package numerics
import (
"math"
"testing"
)
func TestAxisLocate(t *testing.T) {
a := Axis{Left: -90, Step: 0.5, N: 361, Name: "lat"}
b, err := a.Locate(-90)
if err != nil || b.Lo != 0 || b.Hi != 1 || b.Frac != 0 {
t.Errorf("Locate(-90) = %+v, %v; want {0 1 0}, nil", b, err)
}
b, err = a.Locate(0)
if err != nil || b.Lo != 180 || b.Hi != 181 || b.Frac != 0 {
t.Errorf("Locate(0) = %+v, %v; want {180 181 0}, nil", b, err)
}
b, err = a.Locate(-89.75)
if err != nil || b.Lo != 0 || b.Hi != 1 || math.Abs(b.Frac-0.5) > 1e-12 {
t.Errorf("Locate(-89.75) = %+v, %v; want frac=0.5", b, err)
}
// 90 is exactly on the upper boundary — there's no Hi above it
if _, err := a.Locate(90); err == nil {
t.Errorf("Locate(90) should error, got nil")
}
if _, err := a.Locate(-91); err == nil {
t.Errorf("Locate(-91) should error, got nil")
}
}
func TestAxisLocateWrap(t *testing.T) {
a := Axis{Left: 0, Step: 0.5, N: 720, Wrap: true, Name: "lng"}
b, err := a.Locate(0)
if err != nil || b.Lo != 0 || b.Hi != 1 || b.Frac != 0 {
t.Errorf("Locate(0) = %+v, %v", b, err)
}
// Right up against the wrap boundary
b, err = a.Locate(359.75)
if err != nil || b.Lo != 719 || b.Hi != 0 || math.Abs(b.Frac-0.5) > 1e-12 {
t.Errorf("Locate(359.75) = %+v, %v; want {719 0 0.5}", b, err)
}
// 360 is outside the half-open interval
if _, err := a.Locate(360); err == nil {
t.Errorf("Locate(360) should error, got nil")
}
}
func TestEvalTrilinear(t *testing.T) {
// Field f(i,j,k) = 100*i + 10*j + k.
f := func(i, j, k int) float64 { return 100*float64(i) + 10*float64(j) + float64(k) }
// At all fractions = 0.5, expected value is the mean of the 8 corners.
bs := [3]Bracket{{Lo: 0, Hi: 1, Frac: 0.5}, {Lo: 0, Hi: 1, Frac: 0.5}, {Lo: 0, Hi: 1, Frac: 0.5}}
got := EvalTrilinear(bs, f)
want := (0 + 1 + 10 + 11 + 100 + 101 + 110 + 111) / 8.0
if math.Abs(got-want) > 1e-12 {
t.Errorf("EvalTrilinear at center = %v, want %v", got, want)
}
// At all fractions = 0, expected value is f(lo, lo, lo) = 0.
bs = [3]Bracket{{Lo: 0, Hi: 1, Frac: 0}, {Lo: 0, Hi: 1, Frac: 0}, {Lo: 0, Hi: 1, Frac: 0}}
got = EvalTrilinear(bs, f)
if got != 0 {
t.Errorf("EvalTrilinear at (lo,lo,lo) = %v, want 0", got)
}
// Asymmetric: linear field f(i,j,k) = i should give frac of axis 0 exactly.
f2 := func(i, _, _ int) float64 { return float64(i) }
bs = [3]Bracket{{Lo: 0, Hi: 1, Frac: 0.3}, {Lo: 0, Hi: 1, Frac: 0.7}, {Lo: 0, Hi: 1, Frac: 0.9}}
got = EvalTrilinear(bs, f2)
if math.Abs(got-0.3) > 1e-12 {
t.Errorf("EvalTrilinear of i-field = %v, want 0.3", got)
}
}
func TestLerp(t *testing.T) {
if Lerp(10, 20, 0) != 10 {
t.Errorf("Lerp(10, 20, 0) != 10")
}
if Lerp(10, 20, 1) != 20 {
t.Errorf("Lerp(10, 20, 1) != 20")
}
if math.Abs(Lerp(10, 20, 0.25)-12.5) > 1e-12 {
t.Errorf("Lerp(10, 20, 0.25) != 12.5")
}
}

61
internal/numerics/ode.go Normal file
View file

@ -0,0 +1,61 @@
package numerics
// VecAdd computes y + k*dy on the domain state type S.
// Any coordinate-wrap or other domain-specific operation lives here.
type VecAdd[S any] func(y S, k float64, dy S) S
// VecLerp computes (1-l)*a + l*b on the domain state type S.
type VecLerp[S any] func(a, b S, l float64) S
// Deriv computes the time derivative of state.
type Deriv[S any] func(t float64, y S) S
// Trigger reports whether a termination condition holds at (t, y).
type Trigger[S any] func(t float64, y S) bool
// RK4Step performs one classical Runge-Kutta-4 step from (t, y) with step dt.
// dt may be negative to integrate backwards in time.
func RK4Step[S any](t float64, y S, dt float64, deriv Deriv[S], add VecAdd[S]) S {
k1 := deriv(t, y)
k2 := deriv(t+dt/2, add(y, dt/2, k1))
k3 := deriv(t+dt/2, add(y, dt/2, k2))
k4 := deriv(t+dt, add(y, dt, k3))
y2 := y
y2 = add(y2, dt/6, k1)
y2 = add(y2, dt/3, k2)
y2 = add(y2, dt/3, k3)
y2 = add(y2, dt/6, k4)
return y2
}
// RefineTrigger locates the trigger point between (t1, y1) (trigger not fired)
// and (t2, y2) (trigger fired) via binary search in the linear-interpolation
// parameter space, stopping when the parameter interval is narrower than tol.
//
// Returns the final midpoint sampled, matching the behavior of Tawhiri's
// solver.pyx (the returned point is *not* guaranteed to satisfy the trigger;
// for tol << 1 the difference is at most one tolerance-width either side).
func RefineTrigger[S any](
t1 float64, y1 S,
t2 float64, y2 S,
trigger Trigger[S],
lerp VecLerp[S],
tol float64,
) (float64, S) {
left, right := 0.0, 1.0
t3 := t2
y3 := y2
for right-left > tol {
mid := (left + right) / 2
t3 = Lerp(t1, t2, mid)
y3 = lerp(y1, y2, mid)
if trigger(t3, y3) {
right = mid
} else {
left = mid
}
}
return t3, y3
}

View file

@ -0,0 +1,61 @@
package numerics
import (
"math"
"testing"
)
// scalarAdd / scalarLerp let us drive RK4 on a plain float64.
func scalarAdd(y float64, k float64, dy float64) float64 { return y + k*dy }
func scalarLerpF(a, b float64, l float64) float64 { return Lerp(a, b, l) }
func TestRK4ExponentialDecay(t *testing.T) {
// dy/dt = -y → exact: y(t) = y0 * exp(-t).
deriv := func(_ float64, y float64) float64 { return -y }
y := 1.0
tnow := 0.0
dt := 0.01
for range 100 {
y = RK4Step(tnow, y, dt, deriv, scalarAdd)
tnow += dt
}
want := math.Exp(-1.0)
if math.Abs(y-want) > 1e-8 {
t.Errorf("RK4 exp decay at t=1: got %v, want %v (diff %v)", y, want, y-want)
}
}
func TestRK4ReverseTime(t *testing.T) {
// dy/dt = y → exact: y(t) = y0 * exp(t).
// Integrating from t=1 backwards with dt=-0.01 over 100 steps should give y0.
deriv := func(_ float64, y float64) float64 { return y }
y := math.E
tnow := 1.0
dt := -0.01
for range 100 {
y = RK4Step(tnow, y, dt, deriv, scalarAdd)
tnow += dt
}
if math.Abs(y-1.0) > 1e-8 {
t.Errorf("RK4 reverse: got %v, want 1.0 (diff %v)", y, y-1.0)
}
}
func TestRefineTrigger(t *testing.T) {
// y crosses 0 at l=0.4 between y1=1 and y2=-1.5.
y1, y2 := 1.0, -1.5
t1, t2 := 0.0, 1.0
trig := func(_ float64, y float64) bool { return y <= 0 }
tr, yr := RefineTrigger(t1, y1, t2, y2, trig, scalarLerpF, 0.001)
// The exact crossing is at l = 1/(1+1.5) = 0.4 → t = 0.4, y = 0.
if math.Abs(tr-0.4) > 0.01 {
t.Errorf("Refined t = %v, want ~0.4", tr)
}
if math.Abs(yr) > 0.01 {
t.Errorf("Refined y = %v, want ~0", yr)
}
}

View file

@ -0,0 +1,19 @@
package numerics
// Bisect returns the largest index i in [imin, imax] such that f(i) < target,
// assuming f is monotonically nondecreasing on that range.
//
// If target <= f(imin), returns imin. If target > f(imax), returns imax.
// Performs O(log(imax-imin)) evaluations of f.
func Bisect(imin, imax int, target float64, f func(i int) float64) int {
lo, hi := imin, imax
for lo < hi {
mid := (lo + hi + 1) / 2
if target <= f(mid) {
hi = mid - 1
} else {
lo = mid
}
}
return lo
}

View file

@ -0,0 +1,28 @@
package numerics
import "testing"
func TestBisect(t *testing.T) {
// f(i) = 10*i, monotone increasing.
f := func(i int) float64 { return 10 * float64(i) }
// target = 25 → largest i with 10i < 25 is i=2
if got := Bisect(0, 10, 25, f); got != 2 {
t.Errorf("Bisect target=25 = %d, want 2", got)
}
// target on boundary: target = 30, condition is target <= f(mid) so f(3)=30 → not less; want 2
if got := Bisect(0, 10, 30, f); got != 2 {
t.Errorf("Bisect target=30 = %d, want 2", got)
}
// target below all values
if got := Bisect(0, 10, -5, f); got != 0 {
t.Errorf("Bisect target=-5 = %d, want 0", got)
}
// target above all values
if got := Bisect(0, 10, 1000, f); got != 10 {
t.Errorf("Bisect target=1000 = %d, want 10", got)
}
}