step one
This commit is contained in:
parent
7a8d5d13fa
commit
9e663db9dc
68 changed files with 5647 additions and 2958 deletions
216
cmd/predictor-cli/main.go
Normal file
216
cmd/predictor-cli/main.go
Normal file
|
|
@ -0,0 +1,216 @@
|
|||
// Command predictor-cli is a small HTTP client for stratoflights-predictor.
|
||||
//
|
||||
// It is intended for operations and development; production callers should
|
||||
// use the REST API directly.
|
||||
package main
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/json"
|
||||
"flag"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"os"
|
||||
"strings"
|
||||
"time"
|
||||
)
|
||||
|
||||
const usage = `predictor-cli — HTTP client for stratoflights-predictor
|
||||
|
||||
USAGE
|
||||
predictor-cli [--server URL] <command> [args...]
|
||||
|
||||
COMMANDS
|
||||
ready Check service health
|
||||
predict <KEY=VAL>... Run a Tawhiri-compat prediction (key=value pairs)
|
||||
datasets list List stored dataset epochs
|
||||
datasets download [--latest|--epoch RFC3339]
|
||||
Trigger a dataset download
|
||||
datasets delete <epoch> Delete a stored dataset
|
||||
jobs list List download jobs
|
||||
jobs get <id> Show one job
|
||||
jobs cancel <id> Cancel a running job
|
||||
|
||||
ENVIRONMENT
|
||||
PREDICTOR_SERVER Default --server (overridden by the flag)
|
||||
`
|
||||
|
||||
func main() {
|
||||
fs := flag.NewFlagSet("predictor-cli", flag.ContinueOnError)
|
||||
fs.Usage = func() { fmt.Fprint(os.Stderr, usage) }
|
||||
server := fs.String("server", envDefault("PREDICTOR_SERVER", "http://localhost:8080"), "predictor server URL")
|
||||
if err := fs.Parse(os.Args[1:]); err != nil {
|
||||
os.Exit(2)
|
||||
}
|
||||
args := fs.Args()
|
||||
if len(args) == 0 {
|
||||
fs.Usage()
|
||||
os.Exit(2)
|
||||
}
|
||||
|
||||
c := &client{base: strings.TrimRight(*server, "/"), http: &http.Client{Timeout: 30 * time.Second}}
|
||||
if err := dispatch(c, args); err != nil {
|
||||
fmt.Fprintln(os.Stderr, "error:", err)
|
||||
os.Exit(1)
|
||||
}
|
||||
}
|
||||
|
||||
func envDefault(name, fallback string) string {
|
||||
if v := os.Getenv(name); v != "" {
|
||||
return v
|
||||
}
|
||||
return fallback
|
||||
}
|
||||
|
||||
func dispatch(c *client, args []string) error {
|
||||
switch args[0] {
|
||||
case "ready":
|
||||
return c.ready()
|
||||
case "predict":
|
||||
return c.predict(args[1:])
|
||||
case "datasets":
|
||||
if len(args) < 2 {
|
||||
return fmt.Errorf("usage: datasets {list|download|delete}")
|
||||
}
|
||||
switch args[1] {
|
||||
case "list":
|
||||
return c.datasetsList()
|
||||
case "download":
|
||||
return c.datasetsDownload(args[2:])
|
||||
case "delete":
|
||||
if len(args) < 3 {
|
||||
return fmt.Errorf("usage: datasets delete <epoch>")
|
||||
}
|
||||
return c.datasetsDelete(args[2])
|
||||
}
|
||||
case "jobs":
|
||||
if len(args) < 2 {
|
||||
return fmt.Errorf("usage: jobs {list|get|cancel}")
|
||||
}
|
||||
switch args[1] {
|
||||
case "list":
|
||||
return c.jobsList()
|
||||
case "get":
|
||||
if len(args) < 3 {
|
||||
return fmt.Errorf("usage: jobs get <id>")
|
||||
}
|
||||
return c.jobsGet(args[2])
|
||||
case "cancel":
|
||||
if len(args) < 3 {
|
||||
return fmt.Errorf("usage: jobs cancel <id>")
|
||||
}
|
||||
return c.jobsCancel(args[2])
|
||||
}
|
||||
}
|
||||
return fmt.Errorf("unknown command %q", args[0])
|
||||
}
|
||||
|
||||
type client struct {
|
||||
base string
|
||||
http *http.Client
|
||||
}
|
||||
|
||||
func (c *client) ready() error {
|
||||
return c.getPrint("/ready")
|
||||
}
|
||||
|
||||
func (c *client) predict(kv []string) error {
|
||||
q := url.Values{}
|
||||
for _, p := range kv {
|
||||
idx := strings.IndexByte(p, '=')
|
||||
if idx <= 0 {
|
||||
return fmt.Errorf("expected key=value, got %q", p)
|
||||
}
|
||||
q.Set(p[:idx], p[idx+1:])
|
||||
}
|
||||
return c.getPrint("/api/v1/prediction?" + q.Encode())
|
||||
}
|
||||
|
||||
func (c *client) datasetsList() error {
|
||||
return c.getPrint("/api/v1/admin/datasets")
|
||||
}
|
||||
|
||||
func (c *client) datasetsDownload(args []string) error {
|
||||
fs := flag.NewFlagSet("datasets download", flag.ContinueOnError)
|
||||
latest := fs.Bool("latest", false, "download the latest available run")
|
||||
epoch := fs.String("epoch", "", "RFC3339 epoch to download")
|
||||
if err := fs.Parse(args); err != nil {
|
||||
return err
|
||||
}
|
||||
body := map[string]any{}
|
||||
if *latest {
|
||||
body["latest"] = true
|
||||
}
|
||||
if *epoch != "" {
|
||||
body["epoch"] = *epoch
|
||||
}
|
||||
return c.postPrint("/api/v1/admin/datasets", body)
|
||||
}
|
||||
|
||||
func (c *client) datasetsDelete(epoch string) error {
|
||||
return c.deletePrint("/api/v1/admin/datasets/" + url.PathEscape(epoch))
|
||||
}
|
||||
|
||||
func (c *client) jobsList() error { return c.getPrint("/api/v1/admin/jobs") }
|
||||
func (c *client) jobsGet(id string) error {
|
||||
return c.getPrint("/api/v1/admin/jobs/" + url.PathEscape(id))
|
||||
}
|
||||
func (c *client) jobsCancel(id string) error {
|
||||
return c.deletePrint("/api/v1/admin/jobs/" + url.PathEscape(id))
|
||||
}
|
||||
|
||||
func (c *client) getPrint(path string) error {
|
||||
resp, err := c.http.Get(c.base + path)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return printResp(resp)
|
||||
}
|
||||
|
||||
func (c *client) postPrint(path string, body any) error {
|
||||
buf, err := json.Marshal(body)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
resp, err := c.http.Post(c.base+path, "application/json", bytes.NewReader(buf))
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return printResp(resp)
|
||||
}
|
||||
|
||||
func (c *client) deletePrint(path string) error {
|
||||
req, err := http.NewRequest(http.MethodDelete, c.base+path, nil)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
resp, err := c.http.Do(req)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return printResp(resp)
|
||||
}
|
||||
|
||||
func printResp(resp *http.Response) error {
|
||||
defer resp.Body.Close()
|
||||
body, _ := io.ReadAll(resp.Body)
|
||||
if resp.StatusCode >= 400 {
|
||||
return fmt.Errorf("HTTP %d: %s", resp.StatusCode, strings.TrimSpace(string(body)))
|
||||
}
|
||||
// Pretty-print JSON when possible; raw bytes otherwise.
|
||||
if strings.Contains(resp.Header.Get("Content-Type"), "json") && len(body) > 0 {
|
||||
var any any
|
||||
if err := json.Unmarshal(body, &any); err == nil {
|
||||
pretty, _ := json.MarshalIndent(any, "", " ")
|
||||
fmt.Println(string(pretty))
|
||||
return nil
|
||||
}
|
||||
}
|
||||
if len(body) > 0 {
|
||||
fmt.Println(strings.TrimSpace(string(body)))
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
Loading…
Add table
Add a link
Reference in a new issue