cobble/migrate_test.go

211 lines
5.9 KiB
Go

package main
import (
"database/sql"
"errors"
"os"
"path/filepath"
"reflect"
"testing"
)
// MockSqlDB is a test double for the database handle passed to
// ApplyMigration. It records every statement handed to Execute and can be
// configured to fail.
type MockSqlDB struct {
	// queries holds every statement passed to Execute, in call order.
	queries []string
	// mockedError, when non-nil, is returned by every Execute call.
	mockedError error
	// mockedResult is a canned sql.Result. Execute does not consult it
	// because Execute only reports an error.
	mockedResult sql.Result
}

// MockResult embeds sql.Result, which makes it satisfy that interface.
// Any method it does not override panics if called while the embedded
// value is nil.
type MockResult struct {
	sql.Result
}

// Execute records query and returns the configured mockedError (nil by
// default). The statement is recorded even when an error is returned, so
// tests can assert on what was attempted.
func (m *MockSqlDB) Execute(query string) error {
	m.queries = append(m.queries, query)
	return m.mockedError
}
// TestGetMigrationsGathersSqlFiles checks that every .sql file in a
// directory is collected, and that "--requires:" headers chain the files
// into a linear history in dependency order.
func TestGetMigrationsGathersSqlFiles(t *testing.T) {
	root := t.TempDir()
	write := func(name, body string) {
		t.Helper()
		if err := os.WriteFile(filepath.Join(root, name), []byte(body), 0o644); err != nil {
			t.Fatalf("writing fixture %q: %v", name, err)
		}
	}
	write("0001.sql", "SELECT 1 from table;")
	write("0002.sql", "--requires:0001.sql\nSELECT 1 from table;")
	write("0003.sql", "--requires:0002.sql\nSELECT 1 from table;")

	migrationGraph, err := NewMigrationGraphFromDirectory(root)
	// Stop before touching the graph: it may be unusable when construction fails.
	if err != nil {
		t.Fatalf("Expected no error returned, got %#v", err)
	}
	migrations, err := migrationGraph.GetLinearHistory()
	if err != nil {
		t.Fatalf("Expected no error from GetLinearHistory, got %#v", err)
	}
	if len(migrations) != 3 {
		t.Errorf("Expected three migrations collected, got %d instead.", len(migrations))
	}

	expected := []Migration{
		NewMigration(filepath.Join(root, "0001.sql"), "0001.sql", ""),
		NewMigration(filepath.Join(root, "0002.sql"), "0002.sql", "0001.sql"),
		NewMigration(filepath.Join(root, "0003.sql"), "0003.sql", "0002.sql"),
	}
	// Normalize into a plain, non-nil []Migration so DeepEqual compares like
	// types with expected.
	collected := []Migration{}
	for _, migration := range migrations {
		collected = append(collected, migration)
	}
	if !reflect.DeepEqual(collected, expected) {
		t.Errorf("Expected collected migrations to equal %#v, got %#v instead.", expected, collected)
	}
}
// TestGetMigrationsIgnoresNonMigrationFiles checks that only .sql files are
// treated as migrations. A .txt file with SQL content must be skipped.
func TestGetMigrationsIgnoresNonMigrationFiles(t *testing.T) {
	root := t.TempDir()
	write := func(name, body string) {
		t.Helper()
		if err := os.WriteFile(filepath.Join(root, name), []byte(body), 0o644); err != nil {
			t.Fatalf("writing fixture %q: %v", name, err)
		}
	}
	write("0001.sql", "SELECT 1 from table;")
	write("0002.txt", "SELECT 1 from table;")

	migrationGraph, err := NewMigrationGraphFromDirectory(root)
	// Stop before touching the graph: it may be unusable when construction fails.
	if err != nil {
		t.Fatalf("Expected no error returned, got %#v", err)
	}
	migrations, err := migrationGraph.GetLinearHistory()
	if err != nil {
		t.Fatalf("Expected no error from GetLinearHistory, got %#v", err)
	}
	if len(migrations) != 1 {
		t.Errorf("Expected one migration collected, got %d instead.", len(migrations))
	}

	expected := []Migration{
		NewMigration(filepath.Join(root, "0001.sql"), "0001.sql", ""),
	}
	// Normalize into a plain, non-nil []Migration so DeepEqual compares like
	// types with expected.
	collected := []Migration{}
	for _, migration := range migrations {
		collected = append(collected, migration)
	}
	if !reflect.DeepEqual(collected, expected) {
		t.Errorf("Expected collected migrations to equal %#v, got %#v instead.", expected, collected)
	}
}
// TestGetMigrationsReturnsErrorIfInvalidRootDirectory checks that, for a
// root directory that does not exist, GetLinearHistory returns an error and
// no migrations.
func TestGetMigrationsReturnsErrorIfInvalidRootDirectory(t *testing.T) {
	// A child of a fresh temp dir is guaranteed not to exist, unlike a bare
	// relative name whose existence depends on the working directory.
	missing := filepath.Join(t.TempDir(), "not-root")
	// The constructor's error is deliberately ignored: this test pins the
	// behaviour that GetLinearHistory reports the failure.
	migrationGraph, _ := NewMigrationGraphFromDirectory(missing)
	migrations, err := migrationGraph.GetLinearHistory()
	if len(migrations) != 0 {
		t.Errorf("Did not expect any migrations returned, got %#v", migrations)
	}
	if err == nil {
		t.Error("Expected error, got nil.")
	}
}
// TestGetMigrationsParsesRequirementsInCommentHeaders checks that a
// "-- requires: X" header, with spaces after "--" and after the colon, is
// parsed as a dependency on X.
func TestGetMigrationsParsesRequirementsInCommentHeaders(t *testing.T) {
	root := t.TempDir()
	write := func(name, body string) {
		t.Helper()
		if err := os.WriteFile(filepath.Join(root, name), []byte(body), 0o644); err != nil {
			t.Fatalf("writing fixture %q: %v", name, err)
		}
	}
	write("0000.sql", "SELECT 1 from table;")
	write("0001.sql", "-- requires: 0000.sql\nSELECT 1 from table;")

	migrationGraph, err := NewMigrationGraphFromDirectory(root)
	// Stop before touching the graph: it may be unusable when construction fails.
	if err != nil {
		t.Fatalf("Expected no error from NewMigrationGraphFromDirectory, got %#v", err)
	}
	migrations, err := migrationGraph.GetLinearHistory()
	if err != nil {
		t.Errorf("Expected no error returned, got %#v", err)
	}

	// Normalize into a plain, non-nil []Migration so DeepEqual compares like
	// types with expected.
	collected := []Migration{}
	for _, migration := range migrations {
		collected = append(collected, migration)
	}
	expected := []Migration{
		NewMigration(filepath.Join(root, "0000.sql"), "0000.sql", ""),
		NewMigration(filepath.Join(root, "0001.sql"), "0001.sql", "0000.sql"),
	}
	if !reflect.DeepEqual(collected, expected) {
		t.Errorf("Expected collected migrations to equal %#v, got %#v instead.", expected, collected)
	}
}
// TestGetMigrationsIgnoresUnrecognizedCommentHeaders checks that a leading
// SQL comment which is not a "requires" header does not produce a
// dependency.
func TestGetMigrationsIgnoresUnrecognizedCommentHeaders(t *testing.T) {
	root := t.TempDir()
	if err := os.WriteFile(filepath.Join(root, "0001.sql"), []byte("-- randomheader\nSELECT 1 from table;"), 0o644); err != nil {
		t.Fatalf("writing fixture: %v", err)
	}

	migrationGraph, err := NewMigrationGraphFromDirectory(root)
	// Stop before touching the graph: it may be unusable when construction fails.
	if err != nil {
		t.Fatalf("Expected no error from NewMigrationGraphFromDirectory, got %#v", err)
	}
	migrations, err := migrationGraph.GetLinearHistory()
	if err != nil {
		t.Errorf("Expected no error returned, got %#v", err)
	}

	expected := []Migration{
		NewMigration(filepath.Join(root, "0001.sql"), "0001.sql", ""),
	}
	// Normalize into a plain, non-nil []Migration so DeepEqual compares like
	// types with expected.
	collected := []Migration{}
	for _, migration := range migrations {
		collected = append(collected, migration)
	}
	if !reflect.DeepEqual(collected, expected) {
		t.Errorf("Expected collected migrations to equal %#v, got %#v instead.", expected, collected)
	}
}
// TestMigrateRunsSqlOnDBConnection checks that ApplyMigration sends the
// migration file's contents to the database handle as exactly one query.
func TestMigrateRunsSqlOnDBConnection(t *testing.T) {
	mockDB := &MockSqlDB{}
	root := t.TempDir()
	migrationPath := filepath.Join(root, "0001.sql")
	migrationSQL := "SELECT 1 FROM table;"
	if err := os.WriteFile(migrationPath, []byte(migrationSQL), 0o644); err != nil {
		t.Fatalf("writing fixture: %v", err)
	}

	migrationGraph, err := NewMigrationGraphFromDirectory(root)
	if err != nil {
		t.Fatalf("Expected no error from NewMigrationGraphFromDirectory, got %#v", err)
	}
	migrations, err := migrationGraph.GetLinearHistory()
	if err != nil {
		t.Fatalf("Expected no error from GetLinearHistory, got %#v", err)
	}
	// Guard the index below so a discovery bug fails cleanly instead of panicking.
	if len(migrations) != 1 {
		t.Fatalf("Expected one migration collected, got %d instead.", len(migrations))
	}

	err = ApplyMigration(mockDB, migrations[0])
	if err != nil {
		t.Errorf("Expected no error returned, got %#v", err)
	}
	if !reflect.DeepEqual(mockDB.queries, []string{migrationSQL}) {
		t.Errorf("Expected queries %#v to be run, got %#v instead.", []string{migrationSQL}, mockDB.queries)
	}
}
func TestMigrateReturnsErrorOnFailToReadMigrationFile(t *testing.T) {
mockDb := &MockSqlDB{}
root := t.TempDir()
// Does not exist.
migrationPath := filepath.Join(root, "0001.sql")
migration := NewMigration(migrationPath, "0001.sql", "")
err := ApplyMigration(mockDb, migration)
if err == nil {
t.Errorf("Expected error returned, got nil")
}
}
// TestMigrateReturnsErrorOnFailToRunSQL checks that ApplyMigration returns
// an error when the database handle fails to execute the migration.
func TestMigrateReturnsErrorOnFailToRunSQL(t *testing.T) {
	mockDB := &MockSqlDB{mockedError: errors.New("Test error!")}
	root := t.TempDir()
	migrationSQL := "SELECT 1 FROM table;"
	migrationPath := filepath.Join(root, "0001.sql")
	if err := os.WriteFile(migrationPath, []byte(migrationSQL), 0o644); err != nil {
		t.Fatalf("writing fixture: %v", err)
	}

	migrationGraph, err := NewMigrationGraphFromDirectory(root)
	if err != nil {
		t.Fatalf("Expected no error from NewMigrationGraphFromDirectory, got %#v", err)
	}
	migrations, err := migrationGraph.GetLinearHistory()
	if err != nil {
		t.Fatalf("Expected no error from GetLinearHistory, got %#v", err)
	}
	// Guard the index below so a discovery bug fails cleanly instead of panicking.
	if len(migrations) != 1 {
		t.Fatalf("Expected one migration collected, got %d instead.", len(migrations))
	}

	err = ApplyMigration(mockDB, migrations[0])
	if err == nil {
		t.Errorf("Expected error returned, got nil")
	}
}