Skip to content

Instantly share code, notes, and snippets.

@slashvar
Last active July 30, 2024 18:40
Show Gist options
  • Save slashvar/8874c52d88895a922398289f81cd7a08 to your computer and use it in GitHub Desktop.
Save slashvar/8874c52d88895a922398289f81cd7a08 to your computer and use it in GitHub Desktop.
package main
import (
"context"
"fmt"
"log"
"os"
"time"
_ "github.com/joho/godotenv/autoload"
"<path-to-generated-proto-code>/proto/go/pytorch_serve"
"github.com/kelseyhightower/envconfig"
"google.golang.org/grpc"
"google.golang.org/grpc/credentials/insecure"
)
// Config holds the command's settings, populated from the environment by
// envconfig (and from a .env file via the godotenv autoload import).
type Config struct {
	// Addr is the gRPC endpoint to dial, read from GRPC_ADDR.
	Addr string `envconfig:"GRPC_ADDR" default:"localhost:7070"`
	// Timeout is the per-request deadline in milliseconds, read from TIMEOUT.
	Timeout int `envconfig:"TIMEOUT" default:"1000"`
}
// buildData reads the file at path and wraps its raw bytes in a
// PredictionsRequest addressed to the model called name. The file contents
// are sent under the single input key "data". Any read error is returned
// unchanged, together with a nil request.
func buildData(name, path string) (*pytorch_serve.PredictionsRequest, error) {
	payload, readErr := os.ReadFile(path)
	if readErr != nil {
		return nil, readErr
	}
	req := &pytorch_serve.PredictionsRequest{
		ModelName: name,
		Input:     map[string][]byte{"data": payload},
	}
	return req, nil
}
// predictions sends r to the InferenceAPIsService over connection c and
// returns the server's response.
//
// The call is bounded by ctx. The response is returned as any, so callers
// that need its fields must type-assert it. When the RPC fails, the returned
// value is a literal nil. Returning the typed nil pointer would box it into a
// non-nil interface, and a caller checking res != nil would be misled.
func predictions(ctx context.Context, c *grpc.ClientConn, r *pytorch_serve.PredictionsRequest) (any, error) {
	ic := pytorch_serve.NewInferenceAPIsServiceClient(c)
	res, err := ic.Predictions(ctx, r)
	if err != nil {
		return nil, err
	}
	return res, nil
}
// dial opens a plaintext (no TLS) gRPC client connection to addr.
//
// NOTE(review): grpc.Dial is marked deprecated in recent grpc-go releases in
// favour of grpc.NewClient, which uses a different default name resolver.
// Confirm the module's grpc-go version before switching.
func dial(addr string) (*grpc.ClientConn, error) {
	opts := []grpc.DialOption{
		grpc.WithTransportCredentials(insecure.NewCredentials()),
	}
	return grpc.Dial(addr, opts...)
}
// main validates the command line, then delegates to run. Every failure is
// reported through log.Fatal from this single place.
//
// Usage: <program> <model_name> <path>
func main() {
	if len(os.Args) < 3 {
		log.Fatalf("Usage: %s <model_name> <path>\n", os.Args[0])
	}
	// log.Fatal calls os.Exit, which skips deferred calls. Keeping the work in
	// run ensures its defers (context cancel, connection close) execute first.
	if err := run(os.Args[1], os.Args[2]); err != nil {
		log.Fatal(err)
	}
}

// run loads the configuration from the environment and builds a prediction
// request for modelName from the file at path. It sends the request to the
// configured gRPC address within the configured timeout and prints the
// response to stdout.
func run(modelName, path string) error {
	var config Config
	if err := envconfig.Process("", &config); err != nil {
		return err
	}
	// A zero or negative timeout yields a context that has already expired,
	// so every call would fail with DeadlineExceeded. Reject it up front.
	if config.Timeout <= 0 {
		return fmt.Errorf("invalid TIMEOUT %d: must be a positive number of milliseconds", config.Timeout)
	}

	// Read the input before dialing, so a bad path never opens a connection.
	rq, err := buildData(modelName, path)
	if err != nil {
		return err
	}

	c, err := dial(config.Addr)
	if err != nil {
		return err
	}
	// Best-effort close on exit; a close error cannot change the outcome here.
	defer c.Close()

	fmt.Printf("Timeout: %v\n", config.Timeout)
	ctx, cancel := context.WithTimeout(context.Background(), time.Duration(config.Timeout)*time.Millisecond)
	defer cancel()

	res, err := predictions(ctx, c, rq)
	if err != nil {
		return err
	}
	fmt.Printf("%v\n", res)
	return nil
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment