Skip to content

Commit 22d769e

Browse files
Consolidate fetching of MySQL server info
1 parent 7320fda commit 22d769e

14 files changed

+240
-118
lines changed

go/base/context.go

+2-10
Original file line numberDiff line numberDiff line change
@@ -163,18 +163,15 @@ type MigrationContext struct {
163163

164164
Hostname string
165165
AssumeMasterHostname string
166-
ApplierTimeZone string
167166
TableEngine string
168167
RowsEstimate int64
169168
RowsDeltaEstimate int64
170169
UsedRowsEstimateMethod RowsEstimateMethod
171170
HasSuperPrivilege bool
172-
OriginalBinlogFormat string
173-
OriginalBinlogRowImage string
174171
InspectorConnectionConfig *mysql.ConnectionConfig
175-
InspectorMySQLVersion string
172+
InspectorServerInfo *mysql.ServerInfo
176173
ApplierConnectionConfig *mysql.ConnectionConfig
177-
ApplierMySQLVersion string
174+
ApplierServerInfo *mysql.ServerInfo
178175
StartTime time.Time
179176
RowCopyStartTime time.Time
180177
RowCopyEndTime time.Time
@@ -357,11 +354,6 @@ func (this *MigrationContext) GetVoluntaryLockName() string {
357354
return fmt.Sprintf("%s.%s.lock", this.DatabaseName, this.OriginalTableName)
358355
}
359356

360-
// RequiresBinlogFormatChange is `true` when the original binlog format isn't `ROW`
361-
func (this *MigrationContext) RequiresBinlogFormatChange() bool {
362-
return this.OriginalBinlogFormat != "ROW"
363-
}
364-
365357
// GetApplierHostname is a safe access method to the applier hostname
366358
func (this *MigrationContext) GetApplierHostname() string {
367359
if this.ApplierConnectionConfig == nil {

go/base/utils.go

+12-27
Original file line numberDiff line numberDiff line change
@@ -12,8 +12,6 @@ import (
1212
"strings"
1313
"time"
1414

15-
gosql "database/sql"
16-
1715
"github.com/github/gh-ost/go/mysql"
1816
)
1917

@@ -61,35 +59,22 @@ func StringContainsAll(s string, substrings ...string) bool {
6159
return nonEmptyStringsFound
6260
}
6361

64-
func ValidateConnection(db *gosql.DB, connectionConfig *mysql.ConnectionConfig, migrationContext *MigrationContext, name string) (string, error) {
65-
versionQuery := `select @@global.version`
66-
var port, extraPort int
67-
var version string
68-
if err := db.QueryRow(versionQuery).Scan(&version); err != nil {
69-
return "", err
70-
}
71-
extraPortQuery := `select @@global.extra_port`
72-
if err := db.QueryRow(extraPortQuery).Scan(&extraPort); err != nil { // nolint:staticcheck
73-
// swallow this error. not all servers support extra_port
74-
}
62+
// ValidateConnection confirms the database server info matches the provided connection config.
63+
func ValidateConnection(serverInfo *mysql.ServerInfo, connectionConfig *mysql.ConnectionConfig, migrationContext *MigrationContext, name string) error {
7564
// AliyunRDS set users port to "NULL", replace it by gh-ost param
7665
// GCP set users port to "NULL", replace it by gh-ost param
77-
// Azure MySQL set users port to a different value by design, replace it by gh-ost para
66+
// Azure MySQL set users port to a different value by design, replace it by gh-ost param
7867
if migrationContext.AliyunRDS || migrationContext.GoogleCloudPlatform || migrationContext.AzureMySQL {
79-
port = connectionConfig.Key.Port
80-
} else {
81-
portQuery := `select @@global.port`
82-
if err := db.QueryRow(portQuery).Scan(&port); err != nil {
83-
return "", err
84-
}
68+
serverInfo.Port.Int64 = connectionConfig.Key.Port
69+
serverInfo.Port.Valid = connectionConfig.Key.Port > 0
8570
}
8671

87-
if connectionConfig.Key.Port == port || (extraPort > 0 && connectionConfig.Key.Port == extraPort) {
88-
migrationContext.Log.Infof("%s connection validated on %+v", name, connectionConfig.Key)
89-
return version, nil
90-
} else if extraPort == 0 {
91-
return "", fmt.Errorf("Unexpected database port reported: %+v", port)
92-
} else {
93-
return "", fmt.Errorf("Unexpected database port reported: %+v / extra_port: %+v", port, extraPort)
72+
if !serverInfo.Port.Valid && !serverInfo.ExtraPort.Valid {
73+
return fmt.Errorf("Unexpected database port reported: %+v", serverInfo.Port.Int64)
74+
} else if connectionConfig.Key.Port != serverInfo.Port.Int64 && connectionConfig.Key.Port != serverInfo.ExtraPort.Int64 {
75+
return fmt.Errorf("Unexpected database port reported: %+v / extra_port: %+v", serverInfo.Port.Int64, serverInfo.ExtraPort.Int64)
9476
}
77+
78+
migrationContext.Log.Infof("%s connection validated on %+v", name, connectionConfig.Key)
79+
return nil
9580
}

go/base/utils_test.go

+80-1
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,15 @@
11
/*
2-
Copyright 2016 GitHub Inc.
2+
Copyright 2022 GitHub Inc.
33
See https://github.com/github/gh-ost/blob/master/LICENSE
44
*/
55

66
package base
77

88
import (
9+
gosql "database/sql"
910
"testing"
1011

12+
"github.com/github/gh-ost/go/mysql"
1113
"github.com/openark/golib/log"
1214
test "github.com/openark/golib/tests"
1315
)
@@ -27,3 +29,80 @@ func TestStringContainsAll(t *testing.T) {
2729
test.S(t).ExpectTrue(StringContainsAll(s, "insert", ""))
2830
test.S(t).ExpectTrue(StringContainsAll(s, "insert", "update", "delete"))
2931
}
32+
33+
func TestValidateConnection(t *testing.T) {
34+
connectionConfig := &mysql.ConnectionConfig{
35+
Key: mysql.InstanceKey{
36+
Hostname: t.Name(),
37+
Port: mysql.DefaultInstancePort,
38+
},
39+
}
40+
41+
// check valid port matching connectionConfig validates
42+
{
43+
migrationContext := &MigrationContext{Log: NewDefaultLogger()}
44+
serverInfo := &mysql.ServerInfo{
45+
Port: gosql.NullInt64{Int64: mysql.DefaultInstancePort, Valid: true},
46+
ExtraPort: gosql.NullInt64{Int64: mysql.DefaultInstancePort + 1, Valid: true},
47+
}
48+
test.S(t).ExpectNil(ValidateConnection(serverInfo, connectionConfig, migrationContext, "test"))
49+
}
50+
// check NULL port validates when AliyunRDS=true
51+
{
52+
migrationContext := &MigrationContext{
53+
Log: NewDefaultLogger(),
54+
AliyunRDS: true,
55+
}
56+
serverInfo := &mysql.ServerInfo{}
57+
test.S(t).ExpectNil(ValidateConnection(serverInfo, connectionConfig, migrationContext, "test"))
58+
}
59+
// check NULL port validates when AzureMySQL=true
60+
{
61+
migrationContext := &MigrationContext{
62+
Log: NewDefaultLogger(),
63+
AzureMySQL: true,
64+
}
65+
serverInfo := &mysql.ServerInfo{}
66+
test.S(t).ExpectNil(ValidateConnection(serverInfo, connectionConfig, migrationContext, "test"))
67+
}
68+
// check NULL port validates when GoogleCloudPlatform=true
69+
{
70+
migrationContext := &MigrationContext{
71+
Log: NewDefaultLogger(),
72+
GoogleCloudPlatform: true,
73+
}
74+
serverInfo := &mysql.ServerInfo{}
75+
test.S(t).ExpectNil(ValidateConnection(serverInfo, connectionConfig, migrationContext, "test"))
76+
}
77+
// check extra_port validates when port=NULL
78+
{
79+
migrationContext := &MigrationContext{Log: NewDefaultLogger()}
80+
serverInfo := &mysql.ServerInfo{
81+
ExtraPort: gosql.NullInt64{Int64: mysql.DefaultInstancePort, Valid: true},
82+
}
83+
test.S(t).ExpectNil(ValidateConnection(serverInfo, connectionConfig, migrationContext, "test"))
84+
}
85+
// check extra_port validates when port does not match but extra_port does
86+
{
87+
migrationContext := &MigrationContext{Log: NewDefaultLogger()}
88+
serverInfo := &mysql.ServerInfo{
89+
Port: gosql.NullInt64{Int64: 12345, Valid: true},
90+
ExtraPort: gosql.NullInt64{Int64: mysql.DefaultInstancePort, Valid: true},
91+
}
92+
test.S(t).ExpectNil(ValidateConnection(serverInfo, connectionConfig, migrationContext, "test"))
93+
}
94+
// check validation fails when valid port does not match connectionConfig
95+
{
96+
migrationContext := &MigrationContext{Log: NewDefaultLogger()}
97+
serverInfo := &mysql.ServerInfo{
98+
Port: gosql.NullInt64{Int64: 9999, Valid: true},
99+
}
100+
test.S(t).ExpectNotNil(ValidateConnection(serverInfo, connectionConfig, migrationContext, "test"))
101+
}
102+
// check validation fails when port and extra_port are invalid
103+
{
104+
migrationContext := &MigrationContext{Log: NewDefaultLogger()}
105+
serverInfo := &mysql.ServerInfo{}
106+
test.S(t).ExpectNotNil(ValidateConnection(serverInfo, connectionConfig, migrationContext, "test"))
107+
}
108+
}

go/cmd/gh-ost/main.go

+1-1
Original file line numberDiff line numberDiff line change
@@ -49,7 +49,7 @@ func main() {
4949
migrationContext := base.NewMigrationContext()
5050
flag.StringVar(&migrationContext.InspectorConnectionConfig.Key.Hostname, "host", "127.0.0.1", "MySQL hostname (preferably a replica, not the master)")
5151
flag.StringVar(&migrationContext.AssumeMasterHostname, "assume-master-host", "", "(optional) explicitly tell gh-ost the identity of the master. Format: some.host.com[:port] This is useful in master-master setups where you wish to pick an explicit master, or in a tungsten-replicator where gh-ost is unable to determine the master")
52-
flag.IntVar(&migrationContext.InspectorConnectionConfig.Key.Port, "port", 3306, "MySQL port (preferably a replica, not the master)")
52+
flag.Int64Var(&migrationContext.InspectorConnectionConfig.Key.Port, "port", 3306, "MySQL port (preferably a replica, not the master)")
5353
flag.Float64Var(&migrationContext.InspectorConnectionConfig.Timeout, "mysql-timeout", 0.0, "Connect, read and write timeout for MySQL")
5454
flag.StringVar(&migrationContext.CliUser, "user", "", "MySQL user")
5555
flag.StringVar(&migrationContext.CliPassword, "password", "", "MySQL password")

go/logic/applier.go

+13-24
Original file line numberDiff line numberDiff line change
@@ -72,25 +72,24 @@ func NewApplier(migrationContext *base.MigrationContext) *Applier {
7272
}
7373
}
7474

75+
func (this *Applier) ServerInfo() *mysql.ServerInfo {
76+
return this.migrationContext.ApplierServerInfo
77+
}
78+
7579
func (this *Applier) InitDBConnections() (err error) {
7680
applierUri := this.connectionConfig.GetDBUri(this.migrationContext.DatabaseName)
7781
if this.db, _, err = mysql.GetDB(this.migrationContext.Uuid, applierUri); err != nil {
7882
return err
7983
}
84+
if this.migrationContext.ApplierServerInfo, err = mysql.GetServerInfo(this.db); err != nil {
85+
return err
86+
}
8087
singletonApplierUri := fmt.Sprintf("%s&timeout=0", applierUri)
8188
if this.singletonDB, _, err = mysql.GetDB(this.migrationContext.Uuid, singletonApplierUri); err != nil {
8289
return err
8390
}
8491
this.singletonDB.SetMaxOpenConns(1)
85-
version, err := base.ValidateConnection(this.db, this.connectionConfig, this.migrationContext, this.name)
86-
if err != nil {
87-
return err
88-
}
89-
if _, err := base.ValidateConnection(this.singletonDB, this.connectionConfig, this.migrationContext, this.name); err != nil {
90-
return err
91-
}
92-
this.migrationContext.ApplierMySQLVersion = version
93-
if err := this.validateAndReadTimeZone(); err != nil {
92+
if err = base.ValidateConnection(this.ServerInfo(), this.connectionConfig, this.migrationContext, this.name); err != nil {
9493
return err
9594
}
9695
if !this.migrationContext.AliyunRDS && !this.migrationContext.GoogleCloudPlatform && !this.migrationContext.AzureMySQL {
@@ -103,18 +102,8 @@ func (this *Applier) InitDBConnections() (err error) {
103102
if err := this.readTableColumns(); err != nil {
104103
return err
105104
}
106-
this.migrationContext.Log.Infof("Applier initiated on %+v, version %+v", this.connectionConfig.ImpliedKey, this.migrationContext.ApplierMySQLVersion)
107-
return nil
108-
}
109-
110-
// validateAndReadTimeZone potentially reads server time-zone
111-
func (this *Applier) validateAndReadTimeZone() error {
112-
query := `select @@global.time_zone`
113-
if err := this.db.QueryRow(query).Scan(&this.migrationContext.ApplierTimeZone); err != nil {
114-
return err
115-
}
116-
117-
this.migrationContext.Log.Infof("will use time_zone='%s' on applier", this.migrationContext.ApplierTimeZone)
105+
this.migrationContext.Log.Infof("Applier initiated on %+v, version %+v (%+v)", this.connectionConfig.ImpliedKey,
106+
this.ServerInfo().Version, this.ServerInfo().VersionComment)
118107
return nil
119108
}
120109

@@ -239,7 +228,7 @@ func (this *Applier) CreateGhostTable() error {
239228
}
240229
defer tx.Rollback()
241230

242-
sessionQuery := fmt.Sprintf(`SET SESSION time_zone = '%s'`, this.migrationContext.ApplierTimeZone)
231+
sessionQuery := fmt.Sprintf(`SET SESSION time_zone = '%s'`, this.ServerInfo().TimeZone)
243232
sessionQuery = fmt.Sprintf("%s, %s", sessionQuery, this.generateSqlModeQuery())
244233

245234
if _, err := tx.Exec(sessionQuery); err != nil {
@@ -280,7 +269,7 @@ func (this *Applier) AlterGhost() error {
280269
}
281270
defer tx.Rollback()
282271

283-
sessionQuery := fmt.Sprintf(`SET SESSION time_zone = '%s'`, this.migrationContext.ApplierTimeZone)
272+
sessionQuery := fmt.Sprintf(`SET SESSION time_zone = '%s'`, this.ServerInfo().TimeZone)
284273
sessionQuery = fmt.Sprintf("%s, %s", sessionQuery, this.generateSqlModeQuery())
285274

286275
if _, err := tx.Exec(sessionQuery); err != nil {
@@ -640,7 +629,7 @@ func (this *Applier) ApplyIterationInsertQuery() (chunkSize int64, rowsAffected
640629
}
641630
defer tx.Rollback()
642631

643-
sessionQuery := fmt.Sprintf(`SET SESSION time_zone = '%s'`, this.migrationContext.ApplierTimeZone)
632+
sessionQuery := fmt.Sprintf(`SET SESSION time_zone = '%s'`, this.ServerInfo().TimeZone)
644633
sessionQuery = fmt.Sprintf("%s, %s", sessionQuery, this.generateSqlModeQuery())
645634

646635
if _, err := tx.Exec(sessionQuery); err != nil {

0 commit comments

Comments
 (0)