diff --git a/cobble.go b/cobble.go index d0c0c3b..298324a 100644 --- a/cobble.go +++ b/cobble.go @@ -23,49 +23,57 @@ func SetupDatabaseConnection(cmd *cobra.Command, args []string) { cmd.SetContext(context.WithValue(cmd.Context(), "db", db)) } -var cli = &cobra.Command{ - Use: "cobble", - Short: "Cobble is a simple SQL migration utility.", -} +func NewCli() *cobra.Command { + cli := &cobra.Command{ + Use: "cobble", + Short: "Cobble is a simple SQL migration utility.", + } -var up = &cobra.Command{ - Use: "up", - Short: "Applies migrations", - PreRun: SetupDatabaseConnection, - Run: func(cmd *cobra.Command, args []string) { - db := cmd.Context().Value("db").(DB) - migrationRoot := args[0] - migrationGraph, _ := NewMigrationGraphFromDirectory(migrationRoot) - migrationHistory, _ := migrationGraph.GetLinearHistory() - for _, migration := range migrationHistory { - log.Printf("%s", migration.Name) - if err := migration.Apply(db); err != nil { - os.Exit(1) + up := &cobra.Command{ + Use: "up", + Short: "Applies migrations", + PreRun: SetupDatabaseConnection, + Run: func(cmd *cobra.Command, args []string) { + db := cmd.Context().Value("db").(DB) + migrationRoot := args[0] + migrationGraph, _ := NewMigrationGraphFromDirectory(migrationRoot) + migrationHistory, _ := migrationGraph.GetLinearHistory() + for _, migration := range migrationHistory { + log.Printf("%s", migration.Name) + if err := migration.Apply(db); err != nil { + os.Exit(1) + } } - } - }, -} + }, + } -var inspect = &cobra.Command{ - Use: "inspect", - Short: "Prints the nth migration in the history", - Run: func(cmd *cobra.Command, args []string) { - migrationRoot := args[0] - migrationIndex, _ := strconv.Atoi(args[1]) + inspect := &cobra.Command{ + Use: "inspect", + Short: "Prints the nth migration in the history", + Run: func(cmd *cobra.Command, args []string) { + migrationRoot, _ := cmd.Flags().GetString("root") + migrationIndex, _ := cmd.Flags().GetInt("index") - migrationGraph, _ := NewMigrationGraphFromDirectory(migrationRoot) - migrationHistory, _ := migrationGraph.GetLinearHistory() - migration := migrationHistory[migrationIndex] - sql, _ := migration.Sql() - fmt.Printf("%s:\n%s", migration.Name, sql) - }, -} - -func main() { + migrationGraph, _ := NewMigrationGraphFromDirectory(migrationRoot) + migrationHistory, _ := migrationGraph.GetLinearHistory() + migration := migrationHistory[migrationIndex] + sql, _ := migration.Sql() + fmt.Printf("%s:\n%s", migration.Name, sql) + }, + } cli.AddCommand(up) cli.AddCommand(inspect) - if err := cli.Execute(); err != nil { + cli.PersistentFlags().StringP("root", "r", "./migrations", "Root directory where migration files live.") + inspect.PersistentFlags().IntP("index", "i", 0, "Zero-based index of the migration to target.") + + inspect.MarkFlagRequired("index") + + return cli +} + +func main() { + if err := NewCli().Execute(); err != nil { log.Fatal(err) os.Exit(1) }