215 lines
5.3 KiB
Go
215 lines
5.3 KiB
Go
// Command predictor-cli is a small HTTP client for stratoflights-predictor.
//
// It is intended for operations and development; production callers should
// use the REST API directly.
|
|
package main
|
|
|
|
import (
|
|
"bytes"
|
|
"encoding/json"
|
|
"flag"
|
|
"fmt"
|
|
"io"
|
|
"net/http"
|
|
"net/url"
|
|
"os"
|
|
"strings"
|
|
"time"
|
|
)
|
|
|
|
// usage is the help text installed as the top-level flag set's Usage output.
// It is written to stderr for -h/--help, for an unrecognised global flag, and
// when no command is given.
const usage = `predictor-cli — HTTP client for stratoflights-predictor

USAGE
  predictor-cli [--server URL] <command> [args...]

COMMANDS
  ready                       Check service health
  predict <KEY=VAL>...        Run a Tawhiri-compat prediction (key=value pairs)
  datasets list               List stored dataset epochs
  datasets download [--latest|--epoch RFC3339]
                              Trigger a dataset download
  datasets delete <epoch>     Delete a stored dataset
  jobs list                   List download jobs
  jobs get <id>               Show one job
  jobs cancel <id>            Cancel a running job

ENVIRONMENT
  PREDICTOR_SERVER            Default --server (overridden by the flag;
                              built-in default http://localhost:8080)
`
|
|
|
|
func main() {
|
|
fs := flag.NewFlagSet("predictor-cli", flag.ContinueOnError)
|
|
fs.Usage = func() { fmt.Fprint(os.Stderr, usage) }
|
|
server := fs.String("server", envDefault("PREDICTOR_SERVER", "http://localhost:8080"), "predictor server URL")
|
|
if err := fs.Parse(os.Args[1:]); err != nil {
|
|
os.Exit(2)
|
|
}
|
|
args := fs.Args()
|
|
if len(args) == 0 {
|
|
fs.Usage()
|
|
os.Exit(2)
|
|
}
|
|
|
|
c := &client{base: strings.TrimRight(*server, "/"), http: &http.Client{Timeout: 30 * time.Second}}
|
|
if err := dispatch(c, args); err != nil {
|
|
fmt.Fprintln(os.Stderr, "error:", err)
|
|
os.Exit(1)
|
|
}
|
|
}
|
|
|
|
// envDefault returns the value of the environment variable name, or fallback
// when the variable is unset or set to the empty string.
func envDefault(name, fallback string) string {
	value := os.Getenv(name)
	if value == "" {
		value = fallback
	}
	return value
}
|
|
|
|
func dispatch(c *client, args []string) error {
|
|
switch args[0] {
|
|
case "ready":
|
|
return c.ready()
|
|
case "predict":
|
|
return c.predict(args[1:])
|
|
case "datasets":
|
|
if len(args) < 2 {
|
|
return fmt.Errorf("usage: datasets {list|download|delete}")
|
|
}
|
|
switch args[1] {
|
|
case "list":
|
|
return c.datasetsList()
|
|
case "download":
|
|
return c.datasetsDownload(args[2:])
|
|
case "delete":
|
|
if len(args) < 3 {
|
|
return fmt.Errorf("usage: datasets delete <epoch>")
|
|
}
|
|
return c.datasetsDelete(args[2])
|
|
}
|
|
case "jobs":
|
|
if len(args) < 2 {
|
|
return fmt.Errorf("usage: jobs {list|get|cancel}")
|
|
}
|
|
switch args[1] {
|
|
case "list":
|
|
return c.jobsList()
|
|
case "get":
|
|
if len(args) < 3 {
|
|
return fmt.Errorf("usage: jobs get <id>")
|
|
}
|
|
return c.jobsGet(args[2])
|
|
case "cancel":
|
|
if len(args) < 3 {
|
|
return fmt.Errorf("usage: jobs cancel <id>")
|
|
}
|
|
return c.jobsCancel(args[2])
|
|
}
|
|
}
|
|
return fmt.Errorf("unknown command %q", args[0])
|
|
}
|
|
|
|
// client sends requests to a single predictor server and prints the replies.
type client struct {
	// http performs every request; main builds it with a 30-second timeout.
	http *http.Client
	// base is the server root URL with trailing slashes removed; request
	// paths such as "/ready" are appended to it verbatim.
	base string
}
|
|
|
|
func (c *client) ready() error {
|
|
return c.getPrint("/ready")
|
|
}
|
|
|
|
func (c *client) predict(kv []string) error {
|
|
q := url.Values{}
|
|
for _, p := range kv {
|
|
idx := strings.IndexByte(p, '=')
|
|
if idx <= 0 {
|
|
return fmt.Errorf("expected key=value, got %q", p)
|
|
}
|
|
q.Set(p[:idx], p[idx+1:])
|
|
}
|
|
return c.getPrint("/api/v1/prediction?" + q.Encode())
|
|
}
|
|
|
|
func (c *client) datasetsList() error {
|
|
return c.getPrint("/api/v1/admin/datasets")
|
|
}
|
|
|
|
func (c *client) datasetsDownload(args []string) error {
|
|
fs := flag.NewFlagSet("datasets download", flag.ContinueOnError)
|
|
latest := fs.Bool("latest", false, "download the latest available run")
|
|
epoch := fs.String("epoch", "", "RFC3339 epoch to download")
|
|
if err := fs.Parse(args); err != nil {
|
|
return err
|
|
}
|
|
body := map[string]any{}
|
|
if *latest {
|
|
body["latest"] = true
|
|
}
|
|
if *epoch != "" {
|
|
body["epoch"] = *epoch
|
|
}
|
|
return c.postPrint("/api/v1/admin/datasets", body)
|
|
}
|
|
|
|
func (c *client) datasetsDelete(epoch string) error {
|
|
return c.deletePrint("/api/v1/admin/datasets/" + url.PathEscape(epoch))
|
|
}
|
|
|
|
// jobsList prints every download job the server knows about.
func (c *client) jobsList() error {
	jobsPath := "/api/v1/admin/jobs"
	return c.getPrint(jobsPath)
}
|
|
func (c *client) jobsGet(id string) error {
|
|
return c.getPrint("/api/v1/admin/jobs/" + url.PathEscape(id))
|
|
}
|
|
func (c *client) jobsCancel(id string) error {
|
|
return c.deletePrint("/api/v1/admin/jobs/" + url.PathEscape(id))
|
|
}
|
|
|
|
func (c *client) getPrint(path string) error {
|
|
resp, err := c.http.Get(c.base + path)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
return printResp(resp)
|
|
}
|
|
|
|
func (c *client) postPrint(path string, body any) error {
|
|
buf, err := json.Marshal(body)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
resp, err := c.http.Post(c.base+path, "application/json", bytes.NewReader(buf))
|
|
if err != nil {
|
|
return err
|
|
}
|
|
return printResp(resp)
|
|
}
|
|
|
|
func (c *client) deletePrint(path string) error {
|
|
req, err := http.NewRequest(http.MethodDelete, c.base+path, nil)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
resp, err := c.http.Do(req)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
return printResp(resp)
|
|
}
|
|
|
|
// printResp consumes and closes resp.Body, turns HTTP error statuses into Go
// errors, and writes a successful body to stdout.
//
// Behaviour:
//   - Status >= 400 returns "HTTP <code>: <trimmed body>". When the body is
//     empty, the standard status text is used instead.
//   - A failure while reading the body is returned as an error rather than
//     treated as success with a truncated body.
//   - If Content-Type mentions "json", the body is re-indented with
//     json.Indent. Unlike decoding into any and re-marshalling, this keeps
//     object key order and exact number tokens (integers above 2^53 are not
//     rounded) and does not HTML-escape <, > or &.
//   - Anything else, including JSON that fails to parse, is printed verbatim
//     with surrounding whitespace trimmed. An empty body prints nothing.
func printResp(resp *http.Response) error {
	defer resp.Body.Close()
	body, readErr := io.ReadAll(resp.Body)
	body = bytes.TrimSpace(body)

	if resp.StatusCode >= 400 {
		msg := string(body)
		if msg == "" {
			msg = http.StatusText(resp.StatusCode)
		}
		if readErr != nil {
			return fmt.Errorf("HTTP %d: %s (reading body: %v)", resp.StatusCode, msg, readErr)
		}
		return fmt.Errorf("HTTP %d: %s", resp.StatusCode, msg)
	}
	if readErr != nil {
		return fmt.Errorf("reading response body: %w", readErr)
	}
	if len(body) == 0 {
		return nil
	}

	if strings.Contains(resp.Header.Get("Content-Type"), "json") {
		var pretty bytes.Buffer
		// On invalid JSON, fall through and print the raw bytes instead.
		if err := json.Indent(&pretty, body, "", "  "); err == nil {
			fmt.Println(pretty.String())
			return nil
		}
	}
	fmt.Println(string(body))
	return nil
}
|