package main

import (
	"database/sql"
	"errors"
	"os"
	"path/filepath"
	"reflect"
	"testing"
)

// MockSqlDB is a test double passed to Migration.Apply. It records every
// query it is asked to execute and can be configured to fail.
type MockSqlDB struct {
	// queries holds every SQL string passed to Execute, in call order.
	queries []string
	// mockedError, when non-nil, is returned from every Execute call.
	mockedError error
	// mockedResult is not consulted by Execute; kept so existing composite
	// literals elsewhere keep compiling.
	mockedResult sql.Result
}

// MockResult embeds sql.Result to provide a stub result value. It is not
// referenced by the tests in this file.
type MockResult struct {
	sql.Result
}

// Execute records query and returns the configured mockedError, which is
// nil unless a test injected a failure.
func (m *MockSqlDB) Execute(query string) error {
	m.queries = append(m.queries, query)
	return m.mockedError
}

// writeTestFile writes contents to root/name and returns the full path. A
// failed write aborts the test immediately, so later assertions never run
// against a half-built fixture directory.
func writeTestFile(t *testing.T, root, name, contents string) string {
	t.Helper()
	path := filepath.Join(root, name)
	if err := os.WriteFile(path, []byte(contents), 0o644); err != nil {
		t.Fatalf("writing fixture %q: %v", path, err)
	}
	return path
}

// TestGetMigrationsGathersSqlFiles checks that all .sql files in a directory
// are collected, together with the requirement named in each file's
// "--requires:" header.
func TestGetMigrationsGathersSqlFiles(t *testing.T) {
	root := t.TempDir()
	writeTestFile(t, root, "0001.sql", "SELECT 1 from table;")
	writeTestFile(t, root, "0002.sql", "--requires:0001.sql\nSELECT 1 from table;")
	writeTestFile(t, root, "0003.sql", "--requires:0002.sql\nSELECT 1 from table;")

	migrationGraph, err := NewMigrationGraphFromDirectory(root)
	if err != nil {
		t.Fatalf("Expected no error returned, got %#v", err)
	}
	migrations, err := migrationGraph.GetLinearHistory()
	if err != nil {
		t.Fatalf("Expected no error from GetLinearHistory, got %#v", err)
	}
	if len(migrations) != 3 {
		t.Errorf("Expected three migrations collected, got %d instead.", len(migrations))
	}

	expected := []Migration{
		NewMigration(filepath.Join(root, "0001.sql"), "0001.sql", ""),
		NewMigration(filepath.Join(root, "0002.sql"), "0002.sql", "0001.sql"),
		NewMigration(filepath.Join(root, "0003.sql"), "0003.sql", "0002.sql"),
	}
	collected := make([]Migration, 0, len(migrations))
	for _, migration := range migrations {
		collected = append(collected, migration)
	}
	if !reflect.DeepEqual(collected, expected) {
		t.Errorf("Expected collected migrations to equal %#v, got %#v instead.", expected, collected)
	}
}

// TestGetMigrationsIgnoresNonMigrationFiles checks that files without a .sql
// extension are not treated as migrations.
func TestGetMigrationsIgnoresNonMigrationFiles(t *testing.T) {
	root := t.TempDir()
	writeTestFile(t, root, "0001.sql", "SELECT 1 from table;")
	writeTestFile(t, root, "0002.txt", "SELECT 1 from table;")

	migrationGraph, err := NewMigrationGraphFromDirectory(root)
	if err != nil {
		t.Fatalf("Expected no error returned, got %#v", err)
	}
	migrations, err := migrationGraph.GetLinearHistory()
	if err != nil {
		t.Fatalf("Expected no error from GetLinearHistory, got %#v", err)
	}
	if len(migrations) != 1 {
		t.Errorf("Expected one migration collected, got %d instead.", len(migrations))
	}

	expected := []Migration{
		NewMigration(filepath.Join(root, "0001.sql"), "0001.sql", ""),
	}
	collected := make([]Migration, 0, len(migrations))
	for _, migration := range migrations {
		collected = append(collected, migration)
	}
	if !reflect.DeepEqual(collected, expected) {
		t.Errorf("Expected collected migrations to equal %#v, got %#v instead.", expected, collected)
	}
}

// TestGetMigrationsReturnsErrorIfInvalidRootDirectory checks that reading
// history from a nonexistent directory yields an error and no migrations.
func TestGetMigrationsReturnsErrorIfInvalidRootDirectory(t *testing.T) {
	// A child of a fresh temp dir is guaranteed not to exist, unlike a bare
	// relative path that depends on the test's working directory.
	missingRoot := filepath.Join(t.TempDir(), "not-root")

	// The constructor's error is deliberately not checked: this test pins
	// that the failure surfaces from GetLinearHistory.
	migrationGraph, _ := NewMigrationGraphFromDirectory(missingRoot)
	migrations, err := migrationGraph.GetLinearHistory()
	if len(migrations) != 0 {
		t.Errorf("Did not expect any migrations returned, got %#v", migrations)
	}
	if err == nil {
		t.Errorf("Expected error, got nil.")
	}
}

// TestGetMigrationsParsesRequirementsInCommentHeaders checks that a
// "-- requires: <file>" header (with spaces) is recognized.
func TestGetMigrationsParsesRequirementsInCommentHeaders(t *testing.T) {
	root := t.TempDir()
	writeTestFile(t, root, "0000.sql", "SELECT 1 from table;")
	writeTestFile(t, root, "0001.sql", "-- requires: 0000.sql\nSELECT 1 from table;")

	migrationGraph, err := NewMigrationGraphFromDirectory(root)
	if err != nil {
		t.Fatalf("Expected no error returned, got %#v", err)
	}
	migrations, err := migrationGraph.GetLinearHistory()
	if err != nil {
		t.Errorf("Expected no error returned, got %#v", err)
	}

	collected := make([]Migration, 0, len(migrations))
	for _, migration := range migrations {
		collected = append(collected, migration)
	}
	expected := []Migration{
		NewMigration(filepath.Join(root, "0000.sql"), "0000.sql", ""),
		NewMigration(filepath.Join(root, "0001.sql"), "0001.sql", "0000.sql"),
	}
	if !reflect.DeepEqual(collected, expected) {
		t.Errorf("Expected collected migrations to equal %#v, got %#v instead.", expected, collected)
	}
}

// TestGetMigrationsIgnoresUnrecognizedCommentHeaders checks that an unknown
// comment header does not produce a requirement.
func TestGetMigrationsIgnoresUnrecognizedCommentHeaders(t *testing.T) {
	root := t.TempDir()
	writeTestFile(t, root, "0001.sql", "-- randomheader\nSELECT 1 from table;")

	migrationGraph, err := NewMigrationGraphFromDirectory(root)
	if err != nil {
		t.Fatalf("Expected no error returned, got %#v", err)
	}
	migrations, err := migrationGraph.GetLinearHistory()
	if err != nil {
		t.Errorf("Expected no error returned, got %#v", err)
	}

	expected := []Migration{
		NewMigration(filepath.Join(root, "0001.sql"), "0001.sql", ""),
	}
	collected := make([]Migration, 0, len(migrations))
	for _, migration := range migrations {
		collected = append(collected, migration)
	}
	if !reflect.DeepEqual(collected, expected) {
		t.Errorf("Expected collected migrations to equal %#v, got %#v instead.", expected, collected)
	}
}

// TestMigrateRunsSqlOnDBConnection checks that Apply executes the migration
// file's contents verbatim on the database handle.
func TestMigrateRunsSqlOnDBConnection(t *testing.T) {
	mockDb := &MockSqlDB{}
	root := t.TempDir()
	migrationSql := "SELECT 1 FROM table;"
	writeTestFile(t, root, "0001.sql", migrationSql)

	migrationGraph, err := NewMigrationGraphFromDirectory(root)
	if err != nil {
		t.Fatalf("Expected no error returned, got %#v", err)
	}
	migrations, err := migrationGraph.GetLinearHistory()
	if err != nil {
		t.Fatalf("Expected no error from GetLinearHistory, got %#v", err)
	}
	// Guard the index below so a setup problem fails cleanly, not by panic.
	if len(migrations) != 1 {
		t.Fatalf("Expected one migration collected, got %d instead.", len(migrations))
	}

	if err := migrations[0].Apply(mockDb); err != nil {
		t.Errorf("Expected no error returned, got %#v", err)
	}
	if !reflect.DeepEqual(mockDb.queries, []string{migrationSql}) {
		t.Errorf("Expected queries %#v to be run, got %#v instead.", []string{migrationSql}, mockDb.queries)
	}
}

// TestMigrateReturnsErrorOnFailToReadMigrationFile checks that Apply reports
// an error when the migration's file does not exist.
func TestMigrateReturnsErrorOnFailToReadMigrationFile(t *testing.T) {
	mockDb := &MockSqlDB{}
	root := t.TempDir()
	// Deliberately never written to disk.
	migrationPath := filepath.Join(root, "0001.sql")
	migration := NewMigration(migrationPath, "0001.sql", "")

	if err := migration.Apply(mockDb); err == nil {
		t.Errorf("Expected error returned, got nil")
	}
}

// TestMigrateReturnsErrorOnFailToRunSQL checks that an error from the
// database handle's Execute is propagated by Apply.
func TestMigrateReturnsErrorOnFailToRunSQL(t *testing.T) {
	mockDb := &MockSqlDB{mockedError: errors.New("Test error!")}
	root := t.TempDir()
	writeTestFile(t, root, "0001.sql", "SELECT 1 FROM table;")

	migrationGraph, err := NewMigrationGraphFromDirectory(root)
	if err != nil {
		t.Fatalf("Expected no error returned, got %#v", err)
	}
	migrations, err := migrationGraph.GetLinearHistory()
	if err != nil {
		t.Fatalf("Expected no error from GetLinearHistory, got %#v", err)
	}
	// Guard the index below so a setup problem fails cleanly, not by panic.
	if len(migrations) != 1 {
		t.Fatalf("Expected one migration collected, got %d instead.", len(migrations))
	}

	if err := migrations[0].Apply(mockDb); err == nil {
		t.Errorf("Expected error returned, got nil")
	}
}