package main import ( "context" "fmt" "github.com/spf13/cobra" "log" "os" "strconv" ) func SetupDatabaseConnection(cmd *cobra.Command, args []string) { dbPort, _ := strconv.Atoi(os.Getenv("DB_PORT")) db := NewDatabase(os.Getenv("DB_HOST"), dbPort) err := db.Connect(os.Getenv("DB_USER"), os.Getenv("DB_PASSWORD"), os.Getenv("DB_NAME")) if err != nil { log.Printf("Failed to connect to the database: %s", err) os.Exit(1) } cmd.SetContext(context.WithValue(cmd.Context(), "db", db)) } func NewCli() *cobra.Command { cli := &cobra.Command{ Use: "cobble", Short: "Cobble is a simple SQL migration utility.", } up := &cobra.Command{ Use: "up", Short: "Applies migrations", PreRun: SetupDatabaseConnection, Run: func(cmd *cobra.Command, args []string) { db := cmd.Context().Value("db").(DB) migrationRoot := args[0] migrationGraph, _ := NewMigrationGraphFromDirectory(migrationRoot) migrationHistory, _ := migrationGraph.GetLinearHistory() for _, migration := range migrationHistory { log.Printf("%s", migration.Name) if err := migration.Apply(db); err != nil { os.Exit(1) } } }, } inspect := &cobra.Command{ Use: "inspect", Short: "Prints the nth migration in the history", Run: func(cmd *cobra.Command, args []string) { migrationRoot, _ := cmd.Flags().GetString("root") migrationIndex, _ := cmd.Flags().GetInt("index") migrationGraph, _ := NewMigrationGraphFromDirectory(migrationRoot) migrationHistory, _ := migrationGraph.GetLinearHistory() migration := migrationHistory[migrationIndex] sql, _ := migration.Sql() fmt.Printf("%s:\n%s", migration.Name, sql) }, } cli.AddCommand(up) cli.AddCommand(inspect) cli.PersistentFlags().StringP("root", "r", "./migrations", "Root directory where migration files live.") inspect.PersistentFlags().IntP("index", "i", 0, "Zero-based index of the migration to target.") inspect.MarkFlagRequired("index") return cli } func main() { if err := NewCli().Execute(); err != nil { log.Fatal(err) os.Exit(1) } }