predictor/cmd/predictor-cli/main.go

215 lines
5.3 KiB
Go

// Command predictor-cli is a small HTTP client for stratoflights-predictor.
//
// It is intended for operations and development; production callers should
// use the REST API directly.
package main
import (
"bytes"
"encoding/json"
"flag"
"fmt"
"io"
"net/http"
"net/url"
"os"
"strings"
"time"
)
// usage is the help text installed as the top-level flag set's Usage
// function in main; it is written verbatim to stderr whenever the flag
// package asks for usage (e.g. -h or a flag parse error) or no command is
// given.
const usage = `predictor-cli — HTTP client for stratoflights-predictor
USAGE
predictor-cli [--server URL] <command> [args...]
COMMANDS
ready Check service health
predict <KEY=VAL>... Run a Tawhiri-compat prediction (key=value pairs)
datasets list List stored dataset epochs
datasets download [--latest|--epoch RFC3339]
Trigger a dataset download
datasets delete <epoch> Delete a stored dataset
jobs list List download jobs
jobs get <id> Show one job
jobs cancel <id> Cancel a running job
ENVIRONMENT
PREDICTOR_SERVER Default --server (overridden by the flag)
`
// main parses the global --server flag, builds the HTTP client and runs the
// command named by the remaining arguments.
//
// Exit status:
//
//	0  the command succeeded, or -h/--help was requested
//	1  the command returned an error (printed as "error: ..." on stderr)
//	2  a flag could not be parsed, or no command was given
func main() {
	fs := flag.NewFlagSet("predictor-cli", flag.ContinueOnError)
	fs.Usage = func() { fmt.Fprint(os.Stderr, usage) }
	server := fs.String("server", envDefault("PREDICTOR_SERVER", "http://localhost:8080"), "predictor server URL")
	if err := fs.Parse(os.Args[1:]); err != nil {
		// With ContinueOnError the flag set has already printed fs.Usage (and,
		// for a real parse error, the error itself) to stderr. An explicit
		// help request is not a failure, mirroring flag.ExitOnError.
		// FlagSet.Parse returns ErrHelp unwrapped, so == is sufficient.
		if err == flag.ErrHelp {
			os.Exit(0)
		}
		os.Exit(2)
	}
	args := fs.Args()
	if len(args) == 0 {
		fs.Usage()
		os.Exit(2)
	}
	// Request paths all start with "/", so drop any trailing slash from the
	// base URL to avoid "//" in the final URL. The client timeout bounds each
	// whole exchange, including reading the response body.
	c := &client{base: strings.TrimRight(*server, "/"), http: &http.Client{Timeout: 30 * time.Second}}
	if err := dispatch(c, args); err != nil {
		fmt.Fprintln(os.Stderr, "error:", err)
		os.Exit(1)
	}
}
// envDefault returns the value of the environment variable name, or
// fallback when that variable is unset or set to the empty string.
func envDefault(name, fallback string) string {
	value := os.Getenv(name)
	if value == "" {
		return fallback
	}
	return value
}
// dispatch routes the positional arguments left after global flag parsing
// to the matching client method.
//
// args[0] is the command; for the "datasets" and "jobs" groups args[1] is
// the subcommand. Missing arguments yield a "usage: ..." error, and an
// unrecognised command or subcommand yields an error naming the offending
// word (rather than blaming the valid parent command).
func dispatch(c *client, args []string) error {
	if len(args) == 0 {
		return fmt.Errorf("missing command")
	}
	switch args[0] {
	case "ready":
		return c.ready()
	case "predict":
		return c.predict(args[1:])
	case "datasets":
		if len(args) < 2 {
			return fmt.Errorf("usage: datasets {list|download|delete}")
		}
		switch args[1] {
		case "list":
			return c.datasetsList()
		case "download":
			return c.datasetsDownload(args[2:])
		case "delete":
			if len(args) < 3 {
				return fmt.Errorf("usage: datasets delete <epoch>")
			}
			return c.datasetsDelete(args[2])
		}
		return fmt.Errorf("unknown datasets subcommand %q (want list|download|delete)", args[1])
	case "jobs":
		if len(args) < 2 {
			return fmt.Errorf("usage: jobs {list|get|cancel}")
		}
		switch args[1] {
		case "list":
			return c.jobsList()
		case "get":
			if len(args) < 3 {
				return fmt.Errorf("usage: jobs get <id>")
			}
			return c.jobsGet(args[2])
		case "cancel":
			if len(args) < 3 {
				return fmt.Errorf("usage: jobs cancel <id>")
			}
			return c.jobsCancel(args[2])
		}
		return fmt.Errorf("unknown jobs subcommand %q (want list|get|cancel)", args[1])
	}
	return fmt.Errorf("unknown command %q", args[0])
}
// client issues requests against a single predictor server and prints the
// responses.
type client struct {
	// http performs every request.
	http *http.Client
	// base is the server root URL; request paths beginning with "/" are
	// appended to it verbatim.
	base string
}
// ready fetches the server's /ready endpoint and prints the response.
func (c *client) ready() error {
	const readyPath = "/ready"
	return c.getPrint(readyPath)
}
// predict turns "key=value" arguments into query parameters for
// /api/v1/prediction and prints the response.
//
// Each argument is split at its first '='; the value may itself contain '='.
// An argument with no '=' or an empty key is rejected. A key given more than
// once keeps only its last value.
func (c *client) predict(kv []string) error {
	params := make(url.Values, len(kv))
	for _, pair := range kv {
		key, val, found := strings.Cut(pair, "=")
		if !found || key == "" {
			return fmt.Errorf("expected key=value, got %q", pair)
		}
		params.Set(key, val)
	}
	return c.getPrint("/api/v1/prediction?" + params.Encode())
}
// datasetsList prints the server's list of stored datasets.
func (c *client) datasetsList() error {
	const datasetsPath = "/api/v1/admin/datasets"
	return c.getPrint(datasetsPath)
}
// datasetsDownload POSTs a download request to /api/v1/admin/datasets.
//
// Flags parsed from args:
//
//	--latest          ask for the latest available run
//	--epoch RFC3339   ask for a specific run (checked locally as RFC 3339)
//
// The two flags are mutually exclusive, as the usage text documents. With
// neither flag an empty JSON object is posted; NOTE(review): what the server
// does with an empty body is not visible here — confirm. Positional
// arguments are rejected so that a forgotten "--epoch" does not silently
// fall back to that default request. -h/--help prints the flag set's usage
// and returns nil.
func (c *client) datasetsDownload(args []string) error {
	fs := flag.NewFlagSet("datasets download", flag.ContinueOnError)
	latest := fs.Bool("latest", false, "download the latest available run")
	epoch := fs.String("epoch", "", "RFC3339 epoch to download")
	if err := fs.Parse(args); err != nil {
		// FlagSet.Parse returns ErrHelp unwrapped after printing usage.
		if err == flag.ErrHelp {
			return nil
		}
		return err
	}
	if fs.NArg() > 0 {
		return fmt.Errorf("datasets download: unexpected argument %q (use --epoch to select a run)", fs.Arg(0))
	}
	if *latest && *epoch != "" {
		return fmt.Errorf("datasets download: --latest and --epoch are mutually exclusive")
	}
	body := map[string]any{}
	if *latest {
		body["latest"] = true
	}
	if *epoch != "" {
		// Fail fast on a malformed epoch instead of a server round trip; the
		// string is still sent exactly as the user typed it.
		if _, err := time.Parse(time.RFC3339, *epoch); err != nil {
			return fmt.Errorf("datasets download: --epoch %q is not RFC3339: %w", *epoch, err)
		}
		body["epoch"] = *epoch
	}
	return c.postPrint("/api/v1/admin/datasets", body)
}
// datasetsDelete sends DELETE for the stored dataset identified by epoch.
// The epoch is path-escaped so characters such as '/' cannot change which
// resource is addressed.
func (c *client) datasetsDelete(epoch string) error {
	target := "/api/v1/admin/datasets/" + url.PathEscape(epoch)
	return c.deletePrint(target)
}
// jobsList prints the server's list of download jobs.
func (c *client) jobsList() error {
	const jobsPath = "/api/v1/admin/jobs"
	return c.getPrint(jobsPath)
}
// jobsGet prints a single job, addressed by its path-escaped id.
func (c *client) jobsGet(id string) error {
	target := "/api/v1/admin/jobs/" + url.PathEscape(id)
	return c.getPrint(target)
}
// jobsCancel cancels a job by sending DELETE to its path-escaped id.
func (c *client) jobsCancel(id string) error {
	target := "/api/v1/admin/jobs/" + url.PathEscape(id)
	return c.deletePrint(target)
}
// getPrint sends GET c.base+path and hands the response to printResp.
// Transport errors are returned unchanged.
func (c *client) getPrint(path string) error {
	req, err := http.NewRequest(http.MethodGet, c.base+path, nil)
	if err != nil {
		return err
	}
	resp, err := c.http.Do(req)
	if err != nil {
		return err
	}
	return printResp(resp)
}
// postPrint JSON-encodes body, POSTs it to c.base+path with Content-Type
// application/json, and hands the response to printResp. Encoding and
// transport errors are returned unchanged.
func (c *client) postPrint(path string, body any) error {
	payload, err := json.Marshal(body)
	if err != nil {
		return err
	}
	req, err := http.NewRequest(http.MethodPost, c.base+path, bytes.NewReader(payload))
	if err != nil {
		return err
	}
	req.Header.Set("Content-Type", "application/json")
	resp, err := c.http.Do(req)
	if err != nil {
		return err
	}
	return printResp(resp)
}
// deletePrint sends DELETE c.base+path and hands the response to printResp.
// Request-construction and transport errors are returned unchanged.
func (c *client) deletePrint(path string) error {
	req, err := http.NewRequest(http.MethodDelete, c.base+path, nil)
	var resp *http.Response
	if err == nil {
		resp, err = c.http.Do(req)
	}
	if err != nil {
		return err
	}
	return printResp(resp)
}
// printResp consumes and closes resp, then reports it to the user.
//
// A status code of 400 or above becomes an error carrying the trimmed
// response body. Otherwise, a failure while reading the body is returned as
// an error, so a truncated body is never printed as if it were complete.
// A successful body is written to stdout:
//   - empty or whitespace-only bodies print nothing;
//   - when Content-Type contains "json" and the body is valid JSON, it is
//     re-indented;
//   - anything else, including invalid JSON, is printed trimmed.
//
// Re-indentation runs json.Indent on the raw bytes rather than decoding into
// an interface and re-encoding. This keeps integers above 2^53 exact, keeps
// the server's object-key order, and leaves characters such as '<' unescaped.
func printResp(resp *http.Response) error {
	defer resp.Body.Close()
	body, readErr := io.ReadAll(resp.Body)
	if resp.StatusCode >= 400 {
		// The HTTP status is the more useful error; include whatever body
		// was read, even if the read stopped early.
		return fmt.Errorf("HTTP %d: %s", resp.StatusCode, strings.TrimSpace(string(body)))
	}
	if readErr != nil {
		return fmt.Errorf("reading response body: %w", readErr)
	}
	// json.Indent copies trailing whitespace through, so trim first to avoid
	// a blank line after fmt.Println's own newline.
	body = bytes.TrimSpace(body)
	if len(body) == 0 {
		return nil
	}
	if strings.Contains(resp.Header.Get("Content-Type"), "json") {
		var pretty bytes.Buffer
		if err := json.Indent(&pretty, body, "", " "); err == nil {
			fmt.Println(pretty.String())
			return nil
		}
		// Not valid JSON despite the header: fall through to raw output.
	}
	fmt.Println(string(body))
	return nil
}