feat: add inspect by index command

This commit is contained in:
Marc 2024-07-11 21:15:37 -04:00
parent e5e96c3de7
commit 0dc457ec9c
Signed by: marc
GPG key ID: 048E042F22B5DC79
3 changed files with 62 additions and 10 deletions

View file

@ -2,6 +2,7 @@ package main
import ( import (
"context" "context"
"fmt"
"github.com/spf13/cobra" "github.com/spf13/cobra"
"log" "log"
"os" "os"
@ -45,8 +46,24 @@ var up = &cobra.Command{
}, },
} }
var inspect = &cobra.Command{
Use: "inspect",
Short: "Prints the nth migration in the history",
Run: func(cmd *cobra.Command, args []string) {
migrationRoot := args[0]
migrationIndex, _ := strconv.Atoi(args[1])
migrationGraph, _ := NewMigrationGraphFromDirectory(migrationRoot)
migrationHistory, _ := migrationGraph.GetLinearHistory()
migration := migrationHistory[migrationIndex]
sql, _ := migration.Sql()
fmt.Printf("%s:\n%s", migration.Name, sql)
},
}
func main() { func main() {
cli.AddCommand(up) cli.AddCommand(up)
cli.AddCommand(inspect)
if err := cli.Execute(); err != nil { if err := cli.Execute(); err != nil {
log.Fatal(err) log.Fatal(err)

View file

@ -21,20 +21,38 @@ func NewMigration(path string, name string, requires string) Migration {
} }
} }
func (m *Migration) Bytes() ([]byte, error) {
migrationPath := m.Path
migrationBytes, err := os.ReadFile(migrationPath)
if err != nil {
return []byte{}, err
}
return migrationBytes, nil
}
func (m *Migration) Text() (string, error) {
migrationBytes, err := m.Bytes()
return string(migrationBytes), err
}
func (m *Migration) Sql() (string, error) {
migrationBytes, err := m.Bytes()
return string(StripComments(migrationBytes)), err
}
// Applies a migration to the given database connection. // Applies a migration to the given database connection.
// //
// If an error is returned while trying to read the migration file // If an error is returned while trying to read the migration file
// or execute the SQL it contains, the error is returned. // or execute the SQL it contains, the error is returned.
func (m *Migration) ApplyMigration(db DB) error { func (m *Migration) ApplyMigration(db DB) error {
migrationPath := m.Path migrationSql, err := m.Sql()
migrationBytes, err := os.ReadFile(migrationPath)
if err != nil { if err != nil {
return err return err
} }
migrationSql := string(migrationBytes)
log.Printf("SQL: %s", migrationSql) log.Printf("SQL: %s", migrationSql)
err = db.Execute(migrationSql) err = db.Execute(migrationSql)

View file

@ -1,8 +1,3 @@
// Migration tooling
//
// Runs migrations in-order based on files found in the migration
// root directory provided as argument.
package main package main
import ( import (
@ -24,6 +19,28 @@ func cutPrefixAndTrim(b []byte, prefix []byte) []byte {
return bytes.TrimSpace(withoutPrefix) return bytes.TrimSpace(withoutPrefix)
} }
func StripComments(sql []byte) []byte {
scanner := bufio.NewScanner(bytes.NewReader(sql))
sqlWithoutComments := []byte{}
for scanner.Scan() {
currentBytes := scanner.Bytes()
if !bytes.HasPrefix(currentBytes, commentPrefix) {
sqlWithoutComments = append(sqlWithoutComments, currentBytes...)
sqlWithoutComments = append(sqlWithoutComments, []byte("\n")...)
}
}
totalLength := len(sqlWithoutComments)
if totalLength > 0 {
return sqlWithoutComments[:totalLength-1]
}
return sqlWithoutComments
}
// Extracts any migration headers present in the migration source file. // Extracts any migration headers present in the migration source file.
// Headers are left in comments of the form: // Headers are left in comments of the form:
// -- <header-name>: ... // -- <header-name>: ...