predictor/internal/service/predictor_test.go

492 lines
13 KiB
Go

package service
import (
"context"
"testing"
"time"
"git.intra.yksa.space/gsn/predictor/internal/pkg/ds"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/mock"
)
// MockGrib is a testify-based test double for the service's grib dependency.
// It embeds mock.Mock, so expectations are registered with On(...).Return(...)
// and every method call is recorded through Called.
type MockGrib struct {
	mock.Mock
}
// Update records the call on the mock and returns the error configured as
// the first return value of the matching expectation.
func (m *MockGrib) Update(ctx context.Context) error {
	return m.Called(ctx).Error(0)
}
// Extract records the call on the mock and returns the configured wind vector
// and error.
//
// An expectation may supply nil as its first return value (the usual shape of
// an error-path setup such as Return(nil, err)); in that case the zero vector
// is returned alongside the configured error. A non-nil first return value of
// any type other than [2]float64 still panics on the type assertion, so a
// misconfigured expectation is not silently accepted.
func (m *MockGrib) Extract(ctx context.Context, lat, lon, alt float64, t time.Time) ([2]float64, error) {
	args := m.Called(ctx, lat, lon, alt, t)

	raw := args.Get(0)
	if raw == nil {
		// No vector configured: avoid asserting on a nil interface value.
		return [2]float64{}, args.Error(1)
	}
	return raw.([2]float64), args.Error(1)
}
// Close records the call on the mock and returns the error configured as the
// first return value of the matching expectation.
func (m *MockGrib) Close() error {
	return m.Called().Error(0)
}
// createTestService builds a Service whose grib field is a MockGrib and
// returns both, so a test can adjust the mock's expectations afterwards.
//
// By default the mock answers every Extract call with a constant wind
// (5 m/s east, 3 m/s north) and a nil error.
func createTestService() (*Service, *MockGrib) {
	gribMock := &MockGrib{}
	gribMock.
		On("Extract", mock.Anything, mock.Anything, mock.Anything, mock.Anything, mock.Anything).
		Return([2]float64{5.0, 3.0}, nil)

	return &Service{grib: gribMock}, gribMock
}
// createBasicParams returns the baseline prediction parameters shared by the
// tests: a "standard_profile" launch at latitude 40, longitude -105 and
// altitude 1000 on 2024-01-01 12:00 UTC, with ascent and descent rates of 5
// and a burst altitude of 10000. Every call allocates fresh values, so a test
// may mutate the result freely.
func createBasicParams() ds.PredictionParameters {
	var (
		launchLat  = 40.0
		launchLon  = -105.0
		launchAlt  = 1000.0
		launchedAt = time.Date(2024, 1, 1, 12, 0, 0, 0, time.UTC)
		profile    = "standard_profile"
		ascent     = 5.0
		burst      = 10000.0
		descent    = 5.0
	)

	var params ds.PredictionParameters
	params.LaunchLatitude = &launchLat
	params.LaunchLongitude = &launchLon
	params.LaunchAltitude = &launchAlt
	params.LaunchDatetime = &launchedAt
	params.Profile = &profile
	params.AscentRate = &ascent
	params.BurstAltitude = &burst
	params.DescentRate = &descent
	return params
}
// TestRestrictedPrediction_OnlyAscent checks that restricting SimulateStages
// to "ascent" yields a track whose altitude never decreases and whose final
// altitude is within 500 of the configured burst altitude.
func TestRestrictedPrediction_OnlyAscent(t *testing.T) {
	svc, _ := createTestService()
	params := createBasicParams()

	// Restrict to ascent only
	params.SimulateStages = []string{"ascent"}

	results, err := svc.PerformPrediction(context.Background(), params)
	assert.NoError(t, err)
	if !assert.NotEmpty(t, results) {
		// assert does not abort the test; stop here so the last-element
		// lookup below cannot panic on an empty slice.
		return
	}

	// Verify all results are during ascent phase (altitude increasing)
	for i := 1; i < len(results); i++ {
		assert.GreaterOrEqual(t, *results[i].Altitude, *results[i-1].Altitude,
			"Altitude should be increasing or equal during ascent")
	}

	// Last altitude should be near burst altitude
	lastAlt := *results[len(results)-1].Altitude
	burstAlt := *params.BurstAltitude
	assert.InDelta(t, burstAlt, lastAlt, 500.0, "Last altitude should be near burst altitude")
}
// TestRestrictedPrediction_OnlyDescent checks that restricting SimulateStages
// to "descent" yields a track that starts exactly at the burst altitude,
// never gains altitude, and ends below 1000.
func TestRestrictedPrediction_OnlyDescent(t *testing.T) {
	svc, _ := createTestService()
	params := createBasicParams()

	// Restrict to descent only
	params.SimulateStages = []string{"descent"}

	results, err := svc.PerformPrediction(context.Background(), params)
	assert.NoError(t, err)
	if !assert.NotEmpty(t, results) {
		// assert does not abort the test; stop here so the first/last
		// element lookups below cannot panic on an empty slice.
		return
	}

	// First result should be at burst altitude (since ascent was skipped)
	firstAlt := *results[0].Altitude
	burstAlt := *params.BurstAltitude
	assert.Equal(t, burstAlt, firstAlt, "Should start at burst altitude when ascent is skipped")

	// Verify all results are during descent phase (altitude decreasing)
	for i := 1; i < len(results); i++ {
		assert.LessOrEqual(t, *results[i].Altitude, *results[i-1].Altitude,
			"Altitude should be decreasing or equal during descent")
	}

	// Last altitude should be near ground
	lastAlt := *results[len(results)-1].Altitude
	assert.Less(t, lastAlt, 1000.0, "Last altitude should be near ground")
}
// TestRestrictedPrediction_AscentAndDescent requests both "ascent" and
// "descent" and checks the shape of the track around its highest sample:
// altitude must not drop up to the peak and must not rise after it.
func TestRestrictedPrediction_AscentAndDescent(t *testing.T) {
	svc, _ := createTestService()
	params := createBasicParams()
	params.SimulateStages = []string{"ascent", "descent"}

	track, err := svc.PerformPrediction(context.Background(), params)
	assert.NoError(t, err)
	assert.NotEmpty(t, track)

	// Locate the first sample holding the greatest altitude; it marks the
	// transition between the two phases.
	peakIdx, peakAlt := 0, 0.0
	for idx := range track {
		if alt := *track[idx].Altitude; alt > peakAlt {
			peakIdx, peakAlt = idx, alt
		}
	}

	// Walk consecutive pairs once: pairs ending at or before the peak belong
	// to the ascent, everything after it to the descent.
	for idx := 1; idx < len(track); idx++ {
		if idx <= peakIdx {
			assert.GreaterOrEqual(t, *track[idx].Altitude, *track[idx-1].Altitude,
				"Altitude should increase during ascent phase")
			continue
		}
		assert.LessOrEqual(t, *track[idx].Altitude, *track[idx-1].Altitude,
			"Altitude should decrease during descent phase")
	}
}
// TestRestrictedPrediction_FloatProfile_OnlyFloat runs "float_profile" with
// SimulateStages restricted to "float" and checks that every sample sits
// exactly at the requested float altitude while the latitude changes between
// the first and last sample.
func TestRestrictedPrediction_FloatProfile_OnlyFloat(t *testing.T) {
	svc, _ := createTestService()
	params := createBasicParams()

	profile := "float_profile"
	floatAlt := 15000.0
	params.Profile = &profile
	params.FloatAltitude = &floatAlt

	// Restrict to float only
	params.SimulateStages = []string{"float"}

	results, err := svc.PerformPrediction(context.Background(), params)
	assert.NoError(t, err)
	if !assert.NotEmpty(t, results) {
		// assert does not abort the test; stop here so the first/last
		// element lookups below cannot panic on an empty slice.
		return
	}

	// All results should be at the float altitude
	for _, result := range results {
		assert.Equal(t, floatAlt, *result.Altitude,
			"Altitude should remain constant at float altitude")
	}

	// Verify horizontal movement (lat/lon changes due to wind)
	firstLat := *results[0].Latitude
	lastLat := *results[len(results)-1].Latitude
	assert.NotEqual(t, firstLat, lastLat, "Latitude should change during float due to wind")
}
// TestRestrictedPrediction_FloatProfile_AllStages runs "float_profile" with
// all three stages enabled and checks that the track contains a climbing
// step, a level step above 10000, and a falling step, and that its highest
// altitude is within 1000 of the float altitude.
func TestRestrictedPrediction_FloatProfile_AllStages(t *testing.T) {
	svc, _ := createTestService()
	params := createBasicParams()

	floatAlt := 15000.0
	profile := "float_profile"
	params.FloatAltitude = &floatAlt
	params.Profile = &profile
	params.SimulateStages = []string{"ascent", "float", "descent"}

	track, err := svc.PerformPrediction(context.Background(), params)
	assert.NoError(t, err)
	assert.NotEmpty(t, track)

	// A step counts as climbing or falling only when it exceeds this
	// tolerance; smaller steps are treated as level.
	const altTolerance = 50.0

	var sawAscent, sawFloat, sawDescent bool
	for i := 1; i < len(track); i++ {
		delta := *track[i].Altitude - *track[i-1].Altitude
		switch {
		case delta > altTolerance:
			sawAscent = true
		case delta < -altTolerance:
			sawDescent = true
		case *track[i].Altitude > 10000: // Float happens at high altitude
			sawFloat = true
		}
	}
	assert.True(t, sawAscent, "Should have ascent phase")
	assert.True(t, sawFloat, "Should have float phase")
	assert.True(t, sawDescent, "Should have descent phase")

	// The highest sample should sit near the float altitude.
	peak := 0.0
	for i := range track {
		if alt := *track[i].Altitude; alt > peak {
			peak = alt
		}
	}
	assert.InDelta(t, floatAlt, peak, 1000.0, "Max altitude should be near float altitude")
}
// TestRestrictedPrediction_ReverseProfile_OnlyFloat runs "reverse_profile"
// with SimulateStages restricted to "float" and checks that every sample
// stays within 10 of the requested float altitude.
func TestRestrictedPrediction_ReverseProfile_OnlyFloat(t *testing.T) {
	svc, _ := createTestService()

	floatAlt := 5000.0
	profile := "reverse_profile"

	params := createBasicParams()
	params.FloatAltitude = &floatAlt
	params.Profile = &profile
	params.SimulateStages = []string{"float"}

	track, err := svc.PerformPrediction(context.Background(), params)
	assert.NoError(t, err)
	assert.NotEmpty(t, track)

	// Every sample must hover around the float altitude.
	for i := 0; i < len(track); i++ {
		assert.InDelta(t, floatAlt, *track[i].Altitude, 10.0,
			"Altitude should remain near float altitude")
	}
}
// TestRestrictedPrediction_EmptyStages_SimulatesAll checks that an empty
// (non-nil) SimulateStages slice does not restrict the simulation: the track
// must contain at least one rising step and at least one falling step.
func TestRestrictedPrediction_EmptyStages_SimulatesAll(t *testing.T) {
	svc, _ := createTestService()
	params := createBasicParams()

	// Empty SimulateStages should simulate all stages
	params.SimulateStages = []string{}

	results, err := svc.PerformPrediction(context.Background(), params)
	assert.NoError(t, err)
	assert.NotEmpty(t, results)

	// Should have both ascent and descent: look for any strictly rising and
	// any strictly falling pair of consecutive samples.
	hasAscent := false
	hasDescent := false
	for i := 1; i < len(results); i++ {
		prev, cur := *results[i-1].Altitude, *results[i].Altitude
		if cur > prev {
			hasAscent = true
		}
		if cur < prev {
			hasDescent = true
		}
	}
	assert.True(t, hasAscent, "Should have ascent phase")
	assert.True(t, hasDescent, "Should have descent phase")
}
// TestRestrictedPrediction_NilStages_SimulatesAll checks that a nil
// SimulateStages does not restrict the simulation: the track must climb above
// 5000 and, from its highest sample onwards, reach an altitude below that
// maximum.
func TestRestrictedPrediction_NilStages_SimulatesAll(t *testing.T) {
	svc, _ := createTestService()
	params := createBasicParams()
	params.SimulateStages = nil

	track, err := svc.PerformPrediction(context.Background(), params)
	assert.NoError(t, err)
	assert.NotEmpty(t, track)

	// Highest altitude anywhere in the track.
	peak := 0.0
	for i := range track {
		if alt := *track[i].Altitude; alt > peak {
			peak = alt
		}
	}

	// Index of the first sample sitting exactly at the peak, or -1 if none.
	firstPeak := -1
	for i := range track {
		if *track[i].Altitude == peak {
			firstPeak = i
			break
		}
	}

	// Lowest altitude from the peak sample to the end of the track.
	lowestAfterPeak := 1000000.0
	if firstPeak >= 0 {
		for i := firstPeak; i < len(track); i++ {
			if alt := *track[i].Altitude; alt < lowestAfterPeak {
				lowestAfterPeak = alt
			}
		}
	}

	assert.Greater(t, peak, 5000.0, "Should reach high altitude")
	assert.Less(t, lowestAfterPeak, peak, "Should descend after reaching max altitude")
}
// TestRestrictedPrediction_InvalidStage_IgnoresInvalid passes an unknown stage
// name alongside "ascent" and "descent" and checks that the prediction still
// succeeds and that the track contains both a rising and a falling step, i.e.
// the unknown name did not suppress the valid stages.
func TestRestrictedPrediction_InvalidStage_IgnoresInvalid(t *testing.T) {
	svc, _ := createTestService()
	params := createBasicParams()

	// Include invalid stage name (should be ignored)
	params.SimulateStages = []string{"ascent", "invalid_stage", "descent"}

	results, err := svc.PerformPrediction(context.Background(), params)
	assert.NoError(t, err)
	assert.NotEmpty(t, results)

	// Should still simulate ascent and descent, ignoring the invalid stage.
	// Asserting on the track shape makes the test verify what its name claims
	// rather than only that some output was produced.
	hasAscent := false
	hasDescent := false
	for i := 1; i < len(results); i++ {
		prev, cur := *results[i-1].Altitude, *results[i].Altitude
		if cur > prev {
			hasAscent = true
		}
		if cur < prev {
			hasDescent = true
		}
	}
	assert.True(t, hasAscent, "Should have ascent phase")
	assert.True(t, hasDescent, "Should have descent phase")
}
// TestRestrictedPrediction_WindImpact replaces the default Extract
// expectation with a wind of [2]float64{20.0, 0.0} and checks that an
// ascent-only prediction ends at a greater longitude than it started, and
// that any sample carrying WindU also carries WindV.
func TestRestrictedPrediction_WindImpact(t *testing.T) {
	svc, mockGrib := createTestService()

	// Override mock to return strong eastward wind. The expectation list is
	// cleared first so the default registered by createTestService cannot
	// match ahead of the new one.
	mockGrib.ExpectedCalls = nil
	mockGrib.On("Extract", mock.Anything, mock.Anything, mock.Anything, mock.Anything, mock.Anything).
		Return([2]float64{20.0, 0.0}, nil) // Strong eastward wind

	params := createBasicParams()
	params.SimulateStages = []string{"ascent"}

	results, err := svc.PerformPrediction(context.Background(), params)
	assert.NoError(t, err)
	if !assert.NotEmpty(t, results) {
		// assert does not abort the test; stop here so the first/last
		// element lookups below cannot panic on an empty slice.
		return
	}

	// Longitude should increase due to eastward wind
	firstLon := *results[0].Longitude
	lastLon := *results[len(results)-1].Longitude
	assert.Greater(t, lastLon, firstLon, "Longitude should increase with eastward wind")

	// Verify wind values are captured in results: the two components must be
	// reported together.
	for _, result := range results {
		if result.WindU != nil {
			assert.NotNil(t, result.WindV, "WindV should be present if WindU is present")
		}
	}
}
// TestRestrictedPrediction_MissingRequiredParams checks that PerformPrediction
// rejects parameter sets lacking one of the launch latitude, longitude,
// altitude or datetime: it must return ErrInvalidParameters and nil results.
func TestRestrictedPrediction_MissingRequiredParams(t *testing.T) {
	svc, _ := createTestService()

	testCases := []struct {
		name   string
		params ds.PredictionParameters
	}{
		{
			name: "Missing latitude",
			params: ds.PredictionParameters{
				LaunchLongitude: floatPtr(-105.0),
				LaunchAltitude:  floatPtr(1000.0),
				LaunchDatetime:  timePtr(time.Now()),
			},
		},
		{
			name: "Missing longitude",
			params: ds.PredictionParameters{
				LaunchLatitude: floatPtr(40.0),
				LaunchAltitude: floatPtr(1000.0),
				LaunchDatetime: timePtr(time.Now()),
			},
		},
		{
			name: "Missing altitude",
			params: ds.PredictionParameters{
				LaunchLatitude:  floatPtr(40.0),
				LaunchLongitude: floatPtr(-105.0),
				LaunchDatetime:  timePtr(time.Now()),
			},
		},
		{
			name: "Missing datetime",
			params: ds.PredictionParameters{
				LaunchLatitude:  floatPtr(40.0),
				LaunchLongitude: floatPtr(-105.0),
				LaunchAltitude:  floatPtr(1000.0),
			},
		},
	}

	for _, tc := range testCases {
		t.Run(tc.name, func(t *testing.T) {
			// Work on a copy so the table entry itself is left untouched.
			params := tc.params
			params.SimulateStages = []string{"ascent"}

			results, err := svc.PerformPrediction(context.Background(), params)
			assert.Error(t, err)
			// Match the sentinel through errors.Is semantics so the check
			// keeps passing if the service ever wraps the error with context.
			assert.ErrorIs(t, err, ErrInvalidParameters)
			assert.Nil(t, results)
		})
	}
}
func TestShouldSimulateStage(t *testing.T) {
testCases := []struct {
name string
stages []string
queryStage string
shouldSimulate bool
}{
{
name: "Empty filter simulates all",
stages: []string{},
queryStage: "ascent",
shouldSimulate: true,
},
{
name: "Nil filter simulates all",
stages: nil,
queryStage: "descent",
shouldSimulate: true,
},
{
name: "Stage in filter",
stages: []string{"ascent", "descent"},
queryStage: "ascent",
shouldSimulate: true,
},
{
name: "Stage not in filter",
stages: []string{"ascent"},
queryStage: "descent",
shouldSimulate: false,
},
{
name: "Float stage in filter",
stages: []string{"float"},
queryStage: "float",
shouldSimulate: true,
},
{
name: "Multiple stages excluding one",
stages: []string{"ascent", "float"},
queryStage: "descent",
shouldSimulate: false,
},
}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
params := ds.PredictionParameters{
SimulateStages: tc.stages,
}
result := shouldSimulateStage(params, tc.queryStage)
assert.Equal(t, tc.shouldSimulate, result)
})
}
}
// Helper functions

// floatPtr returns a pointer to a freshly allocated float64 holding f, for
// filling optional pointer fields from a literal.
func floatPtr(f float64) *float64 {
	p := new(float64)
	*p = f
	return p
}
// timePtr returns a pointer to a freshly allocated copy of t, for filling
// optional pointer fields from a value.
func timePtr(t time.Time) *time.Time {
	p := new(time.Time)
	*p = t
	return p
}