package service

import (
	"context"
	"testing"
	"time"

	"git.intra.yksa.space/gsn/predictor/internal/pkg/ds"
	"github.com/stretchr/testify/assert"
	"github.com/stretchr/testify/mock"
)

// MockGrib is a testify mock implementation of the Grib interface used by
// Service. Every method records its call and returns whatever the test
// configured with On(...).Return(...).
type MockGrib struct {
	mock.Mock
}

// Update records the call and returns the configured error.
func (m *MockGrib) Update(ctx context.Context) error {
	args := m.Called(ctx)
	return args.Error(0)
}

// Extract records the call and returns the configured two-component wind
// vector together with the configured error.
func (m *MockGrib) Extract(ctx context.Context, lat, lon, alt float64, t time.Time) ([2]float64, error) {
	args := m.Called(ctx, lat, lon, alt, t)
	return args.Get(0).([2]float64), args.Error(1)
}

// Close records the call and returns the configured error.
func (m *MockGrib) Close() error {
	args := m.Called()
	return args.Error(0)
}

// createTestService builds a Service backed by a MockGrib. By default the
// mock's Extract returns a constant wind (5 m/s east, 3 m/s north) for any
// arguments. The mock is returned so individual tests can override that.
func createTestService() (*Service, *MockGrib) {
	mockGrib := new(MockGrib)
	mockGrib.On("Extract", mock.Anything, mock.Anything, mock.Anything, mock.Anything, mock.Anything).
		Return([2]float64{5.0, 3.0}, nil)

	service := &Service{
		grib: mockGrib,
	}
	return service, mockGrib
}

// createBasicParams returns a fully populated "standard_profile" request:
// launch at 40N/105W, 1000 m, 2024-01-01 12:00 UTC, ascending at 5 m/s to a
// 10000 m burst and descending at 5 m/s. SimulateStages is left unset.
func createBasicParams() ds.PredictionParameters {
	lat := 40.0
	lon := -105.0
	alt := 1000.0
	launchTime := time.Date(2024, 1, 1, 12, 0, 0, 0, time.UTC)
	profile := "standard_profile"
	ascentRate := 5.0
	burstAltitude := 10000.0
	descentRate := 5.0

	return ds.PredictionParameters{
		LaunchLatitude:  &lat,
		LaunchLongitude: &lon,
		LaunchAltitude:  &alt,
		LaunchDatetime:  &launchTime,
		Profile:         &profile,
		AscentRate:      &ascentRate,
		BurstAltitude:   &burstAltitude,
		DescentRate:     &descentRate,
	}
}

func TestRestrictedPrediction_OnlyAscent(t *testing.T) {
	service, _ := createTestService()
	params := createBasicParams()

	// Restrict to ascent only
	params.SimulateStages = []string{"ascent"}

	results, err := service.PerformPrediction(context.Background(), params)
	assert.NoError(t, err)
	// Bail out when there is nothing to inspect: the indexing below would
	// otherwise panic and stop the remaining tests from running.
	if !assert.NotEmpty(t, results) {
		return
	}

	// Verify all results are during ascent phase (altitude increasing)
	for i := 1; i < len(results); i++ {
		assert.GreaterOrEqual(t, *results[i].Altitude, *results[i-1].Altitude,
			"Altitude should be increasing or equal during ascent")
	}

	// Last altitude should be near burst altitude
	lastAlt := *results[len(results)-1].Altitude
	burstAlt := *params.BurstAltitude
	assert.InDelta(t, burstAlt, lastAlt, 500.0, "Last altitude should be near burst altitude")
}

func TestRestrictedPrediction_OnlyDescent(t *testing.T) {
	service, _ := createTestService()
	params := createBasicParams()

	// Restrict to descent only
	params.SimulateStages = []string{"descent"}

	results, err := service.PerformPrediction(context.Background(), params)
	assert.NoError(t, err)
	if !assert.NotEmpty(t, results) {
		return
	}

	// First result should be at burst altitude (since ascent was skipped)
	firstAlt := *results[0].Altitude
	burstAlt := *params.BurstAltitude
	assert.Equal(t, burstAlt, firstAlt, "Should start at burst altitude when ascent is skipped")

	// Verify all results are during descent phase (altitude decreasing)
	for i := 1; i < len(results); i++ {
		assert.LessOrEqual(t, *results[i].Altitude, *results[i-1].Altitude,
			"Altitude should be decreasing or equal during descent")
	}

	// Last altitude should be near ground
	lastAlt := *results[len(results)-1].Altitude
	assert.Less(t, lastAlt, 1000.0, "Last altitude should be near ground")
}

func TestRestrictedPrediction_AscentAndDescent(t *testing.T) {
	service, _ := createTestService()
	params := createBasicParams()

	// Include both ascent and descent
	params.SimulateStages = []string{"ascent", "descent"}

	results, err := service.PerformPrediction(context.Background(), params)
	assert.NoError(t, err)
	assert.NotEmpty(t, results)

	// Find the peak altitude (transition point)
	maxAlt := 0.0
	maxIdx := 0
	for i, result := range results {
		if *result.Altitude > maxAlt {
			maxAlt = *result.Altitude
			maxIdx = i
		}
	}

	// Verify ascent phase
	for i := 1; i <= maxIdx; i++ {
		assert.GreaterOrEqual(t, *results[i].Altitude, *results[i-1].Altitude,
			"Altitude should increase during ascent phase")
	}

	// Verify descent phase
	for i := maxIdx + 1; i < len(results); i++ {
		assert.LessOrEqual(t, *results[i].Altitude, *results[i-1].Altitude,
			"Altitude should decrease during descent phase")
	}
}

func TestRestrictedPrediction_FloatProfile_OnlyFloat(t *testing.T) {
	service, _ := createTestService()
	params := createBasicParams()

	profile := "float_profile"
	floatAlt := 15000.0
	params.Profile = &profile
	params.FloatAltitude = &floatAlt

	// Restrict to float only
	params.SimulateStages = []string{"float"}

	results, err := service.PerformPrediction(context.Background(), params)
	assert.NoError(t, err)
	if !assert.NotEmpty(t, results) {
		return
	}

	// All results should be at the float altitude
	for _, result := range results {
		assert.Equal(t, floatAlt, *result.Altitude, "Altitude should remain constant at float altitude")
	}

	// Verify horizontal movement (lat/lon changes due to wind)
	firstLat := *results[0].Latitude
	lastLat := *results[len(results)-1].Latitude
	assert.NotEqual(t, firstLat, lastLat, "Latitude should change during float due to wind")
}

func TestRestrictedPrediction_FloatProfile_AllStages(t *testing.T) {
	service, _ := createTestService()
	params := createBasicParams()

	profile := "float_profile"
	floatAlt := 15000.0
	params.Profile = &profile
	params.FloatAltitude = &floatAlt

	// Include all stages
	params.SimulateStages = []string{"ascent", "float", "descent"}

	results, err := service.PerformPrediction(context.Background(), params)
	assert.NoError(t, err)
	assert.NotEmpty(t, results)

	// Verify we have ascending, constant, and descending altitude patterns
	hasAscent := false
	hasFloat := false
	hasDescent := false

	const altTolerance = 50.0 // Tolerance for altitude comparison

	for i := 1; i < len(results); i++ {
		altDiff := *results[i].Altitude - *results[i-1].Altitude

		switch {
		case altDiff > altTolerance:
			hasAscent = true
		case altDiff < -altTolerance:
			hasDescent = true
		case *results[i].Altitude > 10000:
			// A near-flat step well above the launch site is the float phase.
			hasFloat = true
		}
	}

	assert.True(t, hasAscent, "Should have ascent phase")
	assert.True(t, hasFloat, "Should have float phase")
	assert.True(t, hasDescent, "Should have descent phase")

	// Verify maximum altitude is near float altitude
	maxAlt := 0.0
	for _, result := range results {
		if *result.Altitude > maxAlt {
			maxAlt = *result.Altitude
		}
	}
	assert.InDelta(t, floatAlt, maxAlt, 1000.0, "Max altitude should be near float altitude")
}

func TestRestrictedPrediction_ReverseProfile_OnlyFloat(t *testing.T) {
	service, _ := createTestService()
	params := createBasicParams()

	profile := "reverse_profile"
	floatAlt := 5000.0
	params.Profile = &profile
	params.FloatAltitude = &floatAlt

	// Restrict to float only
	params.SimulateStages = []string{"float"}

	results, err := service.PerformPrediction(context.Background(), params)
	assert.NoError(t, err)
	assert.NotEmpty(t, results)

	// All results should be at the float altitude
	for _, result := range results {
		assert.InDelta(t, floatAlt, *result.Altitude, 10.0, "Altitude should remain near float altitude")
	}
}

func TestRestrictedPrediction_EmptyStages_SimulatesAll(t *testing.T) {
	service, _ := createTestService()
	params := createBasicParams()

	// Empty SimulateStages should simulate all stages
	params.SimulateStages = []string{}

	results, err := service.PerformPrediction(context.Background(), params)
	assert.NoError(t, err)
	assert.NotEmpty(t, results)

	// Should have both ascent and descent
	hasAscent := false
	hasDescent := false
	for i := 1; i < len(results); i++ {
		if *results[i].Altitude > *results[i-1].Altitude {
			hasAscent = true
		}
		if *results[i].Altitude < *results[i-1].Altitude {
			hasDescent = true
		}
	}

	assert.True(t, hasAscent, "Should have ascent phase")
	assert.True(t, hasDescent, "Should have descent phase")
}

func TestRestrictedPrediction_NilStages_SimulatesAll(t *testing.T) {
	service, _ := createTestService()
	params := createBasicParams()

	// Nil SimulateStages should simulate all stages
	params.SimulateStages = nil

	results, err := service.PerformPrediction(context.Background(), params)
	assert.NoError(t, err)
	assert.NotEmpty(t, results)

	// Should have both ascent and descent
	maxAlt := 0.0
	minAltAfterMax := 1000000.0

	for _, result := range results {
		if *result.Altitude > maxAlt {
			maxAlt = *result.Altitude
		}
	}

	// Track the lowest altitude seen from the peak onwards.
	foundMax := false
	for _, result := range results {
		if *result.Altitude == maxAlt {
			foundMax = true
		}
		if foundMax && *result.Altitude < minAltAfterMax {
			minAltAfterMax = *result.Altitude
		}
	}

	// Should reach high altitude and come back down
	assert.Greater(t, maxAlt, 5000.0, "Should reach high altitude")
	assert.Less(t, minAltAfterMax, maxAlt, "Should descend after reaching max altitude")
}

func TestRestrictedPrediction_InvalidStage_IgnoresInvalid(t *testing.T) {
	service, _ := createTestService()
	params := createBasicParams()

	// Include invalid stage name (should be ignored)
	params.SimulateStages = []string{"ascent", "invalid_stage", "descent"}

	results, err := service.PerformPrediction(context.Background(), params)
	assert.NoError(t, err)
	assert.NotEmpty(t, results)

	// Should still simulate ascent and descent, ignoring the invalid stage
}

func TestRestrictedPrediction_WindImpact(t *testing.T) {
	service, mockGrib := createTestService()

	// Drop the default expectation installed by createTestService so the
	// stronger wind below is the only one that matches.
	mockGrib.ExpectedCalls = nil
	mockGrib.On("Extract", mock.Anything, mock.Anything, mock.Anything, mock.Anything, mock.Anything).
		Return([2]float64{20.0, 0.0}, nil) // Strong eastward wind

	params := createBasicParams()
	params.SimulateStages = []string{"ascent"}

	results, err := service.PerformPrediction(context.Background(), params)
	assert.NoError(t, err)
	if !assert.NotEmpty(t, results) {
		return
	}

	// Longitude should increase significantly due to eastward wind
	firstLon := *results[0].Longitude
	lastLon := *results[len(results)-1].Longitude
	assert.Greater(t, lastLon, firstLon, "Longitude should increase with eastward wind")

	// Verify wind values are captured in results
	for _, result := range results {
		if result.WindU != nil {
			// Wind values should be present in results
			assert.NotNil(t, result.WindV, "WindV should be present if WindU is present")
		}
	}
}

func TestRestrictedPrediction_MissingRequiredParams(t *testing.T) {
	service, _ := createTestService()

	testCases := []struct {
		name   string
		params ds.PredictionParameters
	}{
		{
			name: "Missing latitude",
			params: ds.PredictionParameters{
				LaunchLongitude: floatPtr(-105.0),
				LaunchAltitude:  floatPtr(1000.0),
				LaunchDatetime:  timePtr(time.Now()),
			},
		},
		{
			name: "Missing longitude",
			params: ds.PredictionParameters{
				LaunchLatitude: floatPtr(40.0),
				LaunchAltitude: floatPtr(1000.0),
				LaunchDatetime: timePtr(time.Now()),
			},
		},
		{
			name: "Missing altitude",
			params: ds.PredictionParameters{
				LaunchLatitude:  floatPtr(40.0),
				LaunchLongitude: floatPtr(-105.0),
				LaunchDatetime:  timePtr(time.Now()),
			},
		},
		{
			name: "Missing datetime",
			params: ds.PredictionParameters{
				LaunchLatitude:  floatPtr(40.0),
				LaunchLongitude: floatPtr(-105.0),
				LaunchAltitude:  floatPtr(1000.0),
			},
		},
	}

	for _, tc := range testCases {
		t.Run(tc.name, func(t *testing.T) {
			// tc is a per-iteration copy, so mutating it does not leak across cases.
			tc.params.SimulateStages = []string{"ascent"}

			results, err := service.PerformPrediction(context.Background(), tc.params)
			// ErrorIs also accepts the sentinel when it is wrapped with context.
			assert.ErrorIs(t, err, ErrInvalidParameters)
			assert.Nil(t, results)
		})
	}
}

func TestShouldSimulateStage(t *testing.T) {
	testCases := []struct {
		name           string
		stages         []string
		queryStage     string
		shouldSimulate bool
	}{
		{
			name:           "Empty filter simulates all",
			stages:         []string{},
			queryStage:     "ascent",
			shouldSimulate: true,
		},
		{
			name:           "Nil filter simulates all",
			stages:         nil,
			queryStage:     "descent",
			shouldSimulate: true,
		},
		{
			name:           "Stage in filter",
			stages:         []string{"ascent", "descent"},
			queryStage:     "ascent",
			shouldSimulate: true,
		},
		{
			name:           "Stage not in filter",
			stages:         []string{"ascent"},
			queryStage:     "descent",
			shouldSimulate: false,
		},
		{
			name:           "Float stage in filter",
			stages:         []string{"float"},
			queryStage:     "float",
			shouldSimulate: true,
		},
		{
			name:           "Multiple stages excluding one",
			stages:         []string{"ascent", "float"},
			queryStage:     "descent",
			shouldSimulate: false,
		},
	}

	for _, tc := range testCases {
		t.Run(tc.name, func(t *testing.T) {
			params := ds.PredictionParameters{
				SimulateStages: tc.stages,
			}
			result := shouldSimulateStage(params, tc.queryStage)
			assert.Equal(t, tc.shouldSimulate, result)
		})
	}
}

// floatPtr returns a pointer to a copy of f, for populating optional fields.
func floatPtr(f float64) *float64 {
	return &f
}

// timePtr returns a pointer to a copy of t, for populating optional fields.
func timePtr(t time.Time) *time.Time {
	return &t
}