diff --git a/cobble.go b/cobble.go index 7ce477b..cc61d56 100644 --- a/cobble.go +++ b/cobble.go @@ -1,37 +1,44 @@ package main import ( + "context" "github.com/spf13/cobra" "log" "os" "strconv" ) +func SetupDatabaseConnection(cmd *cobra.Command, args []string) { + dbPort, _ := strconv.Atoi(os.Getenv("DB_PORT")) + + db := NewDatabase(os.Getenv("DB_HOST"), dbPort) + err := db.Connect(os.Getenv("DB_APP_USER"), os.Getenv("DB_APP_PASSWORD"), os.Getenv("DB_NAME")) + + if err != nil { + log.Printf("Failed to connect to the database: %s", err) + os.Exit(1) + } + + cmd.SetContext(context.WithValue(cmd.Context(), "db", db)) +} + var cli = &cobra.Command{ Use: "cobble", Short: "Cobble is a simple SQL migration utility.", } var up = &cobra.Command{ - Use: "up", - Short: "Applies migrations", + Use: "up", + Short: "Applies migrations", + PreRun: SetupDatabaseConnection, Run: func(cmd *cobra.Command, args []string) { - dbPort, _ := strconv.Atoi(os.Getenv("DB_PORT")) - - db := NewDatabase(os.Getenv("DB_HOST"), dbPort) - err := db.Connect(os.Getenv("DB_APP_USER"), os.Getenv("DB_APP_PASSWORD"), os.Getenv("DB_NAME")) - - if err != nil { - log.Printf("Failed to connect to the database: %s", err) - os.Exit(1) - } - + db := cmd.Context().Value("db").(DB) migrationRoot := args[0] migrationGraph, _ := NewMigrationGraphFromDirectory(migrationRoot) migrationHistory, _ := migrationGraph.GetLinearHistory() for _, migration := range migrationHistory { log.Printf("%s", migration.Name) - if err = migration.ApplyMigration(db); err != nil { + if err := migration.ApplyMigration(db); err != nil { os.Exit(1) } }