From daa19167a76530b1fea95b0510435431ee82d627 Mon Sep 17 00:00:00 2001 From: Marc Cataford Date: Fri, 4 Oct 2024 00:03:12 -0400 Subject: [PATCH] feat: start cmd for remote hosts --- cli/start_service.go | 45 ++++++++++++++++++++---- cli/start_service_test.go | 38 ++++++++++++++++++-- webclient/client_test.go | 73 +++++++++++++++++++++++++++++++++++++++ webclient/main.go | 45 ++++++++++++++++++++++++ 4 files changed, 193 insertions(+), 8 deletions(-) create mode 100644 webclient/client_test.go create mode 100644 webclient/main.go diff --git a/cli/start_service.go b/cli/start_service.go index df47157..d9d9281 100644 --- a/cli/start_service.go +++ b/cli/start_service.go @@ -5,36 +5,69 @@ package cli import ( + "context" "fmt" "github.com/spf13/cobra" service "spud/service" service_definition "spud/service_definition" + webclient "spud/webclient" ) func getStartCommand() *cobra.Command { + type ParsedFlags struct { + definitionPath string + daemonHost string + daemonPort int + } start := &cobra.Command{ Use: "start", Short: "Creates or updates a service based on the provided definition.", - RunE: func(cmd *cobra.Command, args []string) error { - pathProvided, err := cmd.PersistentFlags().GetString("definition") + PersistentPreRunE: func(cmd *cobra.Command, args []string) error { + var pathProvided, host string + var port int + var err error - if err != nil { - return fmt.Errorf("%+v", err) + if pathProvided, err = cmd.PersistentFlags().GetString("definition"); err != nil { + return err } - def, err := service_definition.GetServiceDefinitionFromFile(pathProvided) + if host, err = cmd.PersistentFlags().GetString("host"); err != nil { + return err + } + + if port, err = cmd.PersistentFlags().GetInt("port"); err != nil { + return err + } + + if host != "" && port == 0 || host == "" && port != 0 { + return fmt.Errorf("Invalid flags: host and port must be defined together or not at all.") + } + + cmd.SetContext(context.WithValue(cmd.Context(), "flags", ParsedFlags{definitionPath: pathProvided, daemonHost: host, daemonPort: port})) + return nil + }, + RunE: func(cmd *cobra.Command, args []string) error { + ctx := cmd.Context() + flags := ctx.Value("flags").(ParsedFlags) + def, err := service_definition.GetServiceDefinitionFromFile(flags.definitionPath) if err != nil { return fmt.Errorf("Failed to read service definition from file: %+v", err) } - service.CreateService(def) + if flags.daemonHost != "" && flags.daemonPort != 0 { + webclient.NewWebClient(flags.daemonHost, flags.daemonPort).CreateService(def) + return nil + } + service.CreateService(def) return nil }, } start.PersistentFlags().StringP("definition", "d", "./service.yml", "Path to the service definition to use.") + start.PersistentFlags().StringP("host", "H", "", "If specified, host where the daemon lives.") + start.PersistentFlags().IntP("port", "p", 0, "Port on the daemon host.") return start } diff --git a/cli/start_service_test.go b/cli/start_service_test.go index 5d2f6aa..5d05d66 100644 --- a/cli/start_service_test.go +++ b/cli/start_service_test.go @@ -1,11 +1,12 @@ package cli import ( + "fmt" "github.com/spf13/cobra" "testing" ) -func TestCliStartServiceDefinitionPathMustExist(t *testing.T) { +func Test_StartServiceCli_StartServiceDefinitionPathMustExist(t *testing.T) { cli := GetCli() cli.SetArgs([]string{"start", "-d", "./not-a-file.yml"}) @@ -17,7 +18,7 @@ func TestCliStartServiceDefinitionPathMustExist(t *testing.T) { } } -func TestCliStartDefaultsServiceDefinitionPath(t *testing.T) { +func Test_StartServiceCli_StartDefaultsServiceDefinitionPath(t *testing.T) { cli := GetCli() cli.SetArgs([]string{"start"}) @@ -35,3 +36,36 @@ func TestCliStartDefaultsServiceDefinitionPath(t *testing.T) { cli.Execute() } + +func Test_StartServiceCli_ErrorIfHostAndPortNotProvidedTogether(t *testing.T) { + inputs := [][]string{ + {"start", "-H", "host"}, + {"start", "-p", "9999"}, + } + + for _, input := range inputs { + t.Run(fmt.Sprintf("%+v", input), func(t *testing.T) { + cli := GetCli() + cli.SetArgs(input) + + startCommand, _, _ := cli.Find([]string{"start"}) + + previousPreRun := startCommand.PersistentPreRunE + + startCommand.PersistentPreRunE = func(cmd *cobra.Command, args []string) error { + if err := previousPreRun(cmd, args); err == nil { + t.Errorf("Expected error, got nil.") + } + + return nil + } + + startCommand.RunE = func(cmd *cobra.Command, args []string) error { + return nil + } + + cli.Execute() + }) + + } +} diff --git a/webclient/client_test.go b/webclient/client_test.go new file mode 100644 index 0000000..450af35 --- /dev/null +++ b/webclient/client_test.go @@ -0,0 +1,73 @@ +package webclient + +import ( + "bytes" + "encoding/json" + "io" + "net/http" + daemon "spud/daemon" + service_definition "spud/service_definition" + "testing" +) + +type RecordedRequest struct { + method string + url string + contentType string + data io.Reader +} + +type DummyHttpClient struct { + requests []RecordedRequest +} + +func (d *DummyHttpClient) Post(url string, contentType string, data io.Reader) (*http.Response, error) { + d.requests = append(d.requests, RecordedRequest{method: http.MethodPost, url: url, contentType: contentType, data: data}) + return nil, nil +} + +func Test_WebClient_GetBaseUrlGetsUrlFromHostPort(t *testing.T) { + c := NewWebClient("http://host", 9999) + + actual := c.getBaseUrl() + expected := "http://host:9999" + if actual != expected { + t.Errorf("Expected %s, got %s.", expected, actual) + } +} + +func Test_WebClient_CreateServicePostsToDaemon(t *testing.T) { + c := NewWebClient("http://host", 9999) + httpClient := &DummyHttpClient{} + c.httpClient = httpClient + + def := service_definition.ServiceDefinition{Name: "test-service"} + + c.CreateService(def) + + if len(httpClient.requests) != 1 || httpClient.requests[0].method != http.MethodPost { + t.Errorf("Expected one POST requests, got none.") + } +} + +func Test_WebClient_CreateServiceSendsDefinition(t *testing.T) { + c := NewWebClient("http://host", 9999) + httpClient := &DummyHttpClient{} + c.httpClient = httpClient + + payload := daemon.ServiceListPayload{ + Definition: service_definition.ServiceDefinition{Name: "test-service"}, + } + + c.CreateService(payload.Definition) + + req := httpClient.requests[0] + + actualDef := bytes.NewBuffer([]byte{}) + actualDef.ReadFrom(req.data) + expectedDef, _ := json.Marshal(payload) + + if actualDef.String() != string(expectedDef) { + t.Errorf("Unexpected data: %s != %s", actualDef.String(), string(expectedDef)) + } +} diff --git a/webclient/main.go b/webclient/main.go new file mode 100644 index 0000000..e955eab --- /dev/null +++ b/webclient/main.go @@ -0,0 +1,45 @@ +package webclient + +import ( + "bytes" + "encoding/json" + "fmt" + "io" + "net/http" + daemon "spud/daemon" + service_definition "spud/service_definition" +) + +type HttpClient interface { + Post(string, string, io.Reader) (*http.Response, error) +} + +type WebClient struct { + httpClient HttpClient + Host string + Port int +} + +func NewWebClient(host string, port int) *WebClient { + return &WebClient{ + httpClient: &http.Client{}, + Host: host, + Port: port, + } +} + +func (c WebClient) getBaseUrl() string { + return fmt.Sprintf("%s:%d", c.Host, c.Port) +} + +func (c WebClient) CreateService(def service_definition.ServiceDefinition) error { + payload := daemon.ServiceListPayload{ + Definition: def, + } + + serializedPayload, _ := json.Marshal(payload) + + _, e := c.httpClient.Post(c.getBaseUrl()+"/service/", "application/json", bytes.NewBuffer(serializedPayload)) + + return e +}