diff --git a/mysql/mariadb_gtid.go b/mysql/mariadb_gtid.go
index 8bfffd092..8faf0f5ef 100644
--- a/mysql/mariadb_gtid.go
+++ b/mysql/mariadb_gtid.go
@@ -103,13 +103,13 @@ func (gtid *MariadbGTID) forward(newer *MariadbGTID) error {
 
 // MariadbGTIDSet is a set of mariadb gtid
 type MariadbGTIDSet struct {
-	Sets map[uint32]*MariadbGTID
+	Sets map[uint32]map[uint32]*MariadbGTID
 }
 
 // ParseMariadbGTIDSet parses str into mariadb gtid sets
 func ParseMariadbGTIDSet(str string) (GTIDSet, error) {
 	s := new(MariadbGTIDSet)
-	s.Sets = make(map[uint32]*MariadbGTID)
+	s.Sets = make(map[uint32]map[uint32]*MariadbGTID)
 	if str == "" {
 		return s, nil
 	}
@@ -126,14 +126,17 @@ func (s *MariadbGTIDSet) AddSet(gtid *MariadbGTID) error {
 		return nil
 	}
 
-	o, ok := s.Sets[gtid.DomainID]
-	if ok {
+	if serverSets, ok := s.Sets[gtid.DomainID]; !ok {
+		s.Sets[gtid.DomainID] = map[uint32]*MariadbGTID{
+			gtid.ServerID: gtid,
+		}
+	} else if o, ok := serverSets[gtid.ServerID]; !ok {
+		serverSets[gtid.ServerID] = gtid
+	} else {
 		err := o.forward(gtid)
 		if err != nil {
 			return errors.Trace(err)
 		}
-	} else {
-		s.Sets[gtid.DomainID] = gtid
 	}
 
 	return nil
@@ -159,7 +162,9 @@ func (s *MariadbGTIDSet) Update(GTIDStr string) error {
 func (s *MariadbGTIDSet) String() string {
 	sets := make([]string, 0, len(s.Sets))
 	for _, set := range s.Sets {
-		sets = append(sets, set.String())
+		for _, gtid := range set {
+			sets = append(sets, gtid.String())
+		}
 	}
 	sort.Strings(sets)
 
@@ -170,10 +175,12 @@ func (s *MariadbGTIDSet) String() string {
 func (s *MariadbGTIDSet) Encode() []byte {
 	var buf bytes.Buffer
 	sep := ""
-	for _, gtid := range s.Sets {
-		buf.WriteString(sep)
-		buf.WriteString(gtid.String())
-		sep = ","
+	for _, set := range s.Sets {
+		for _, gtid := range set {
+			buf.WriteString(sep)
+			buf.WriteString(gtid.String())
+			sep = ","
+		}
 	}
 
 	return buf.Bytes()
@@ -182,10 +189,13 @@ func (s *MariadbGTIDSet) Encode() []byte {
 // Clone clones a mariadb gtid set
 func (s *MariadbGTIDSet) Clone() GTIDSet {
 	clone := &MariadbGTIDSet{
-		Sets: make(map[uint32]*MariadbGTID),
+		Sets: make(map[uint32]map[uint32]*MariadbGTID),
 	}
-	for domainID, gtid := range s.Sets {
-		clone.Sets[domainID] = gtid.Clone()
+	for domainID, set := range s.Sets {
+		clone.Sets[domainID] = make(map[uint32]*MariadbGTID)
+		for serverID, gtid := range set {
+			clone.Sets[domainID][serverID] = gtid.Clone()
+		}
 	}
 
 	return clone
@@ -202,14 +212,17 @@ func (s *MariadbGTIDSet) Equal(o GTIDSet) bool {
 		return false
 	}
 
-	for domainID, gtid := range other.Sets {
-		o, ok := s.Sets[domainID]
+	for domainID, set := range other.Sets {
+		serverSet, ok := s.Sets[domainID]
 		if !ok {
 			return false
 		}
-
-		if *gtid != *o {
-			return false
+		for serverID, gtid := range set {
+			if o, ok := serverSet[serverID]; !ok {
+				return false
+			} else if *gtid != *o {
+				return false
+			}
 		}
 	}
 
@@ -223,14 +236,17 @@ func (s *MariadbGTIDSet) Contain(o GTIDSet) bool {
 		return false
 	}
 
-	for doaminID, gtid := range other.Sets {
-		o, ok := s.Sets[doaminID]
+	for doaminID, set := range other.Sets {
+		serverSet, ok := s.Sets[doaminID]
 		if !ok {
 			return false
 		}
-
-		if !o.Contain(gtid) {
-			return false
+		for serverID, gtid := range set {
+			if o, ok := serverSet[serverID]; !ok {
+				return false
+			} else if !o.Contain(gtid) {
+				return false
+			}
 		}
 	}
 
diff --git a/mysql/mariadb_gtid_test.go b/mysql/mariadb_gtid_test.go
index 989c95949..0acb8ce5e 100644
--- a/mysql/mariadb_gtid_test.go
+++ b/mysql/mariadb_gtid_test.go
@@ -91,13 +91,13 @@ func TestMariaDBForward(t *testing.T) {
 func TestParseMariaDBGTIDSet(t *testing.T) {
 	cases := []struct {
 		gtidStr     string
-		subGTIDs    map[uint32]string //domain ID => gtid string
-		expectedStr []string          // test String()
+		subGTIDs    map[uint32]map[uint32]string //domain ID => gtid string
+		expectedStr []string                     // test String()
 		hasError    bool
 	}{
-		{"0-1-1", map[uint32]string{0: "0-1-1"}, []string{"0-1-1"}, false},
+		{"0-1-1", map[uint32]map[uint32]string{0: {1: "0-1-1"}}, []string{"0-1-1"}, false},
 		{"", nil, []string{""}, false},
-		{"0-1-1,1-2-3", map[uint32]string{0: "0-1-1", 1: "1-2-3"}, []string{"0-1-1,1-2-3", "1-2-3,0-1-1"}, false},
+		{"0-1-1,1-2-3", map[uint32]map[uint32]string{0: {1: "0-1-1"}, 1: {2: "1-2-3"}}, []string{"0-1-1,1-2-3", "1-2-3,0-1-1"}, false},
 		{"0-1--1", nil, nil, true},
 	}
 
@@ -112,9 +112,12 @@ func TestParseMariaDBGTIDSet(t *testing.T) {
 
 			// check sub gtid
 			require.Len(t, mariadbGTIDSet.Sets, len(cs.subGTIDs))
-			for domainID, gtid := range mariadbGTIDSet.Sets {
+			for domainID, set := range mariadbGTIDSet.Sets {
 				require.Contains(t, mariadbGTIDSet.Sets, domainID)
-				require.Equal(t, cs.subGTIDs[domainID], gtid.String())
+				for serverID, gtid := range set {
+					require.Contains(t, cs.subGTIDs, domainID)
+					require.Equal(t, cs.subGTIDs[domainID][serverID], gtid.String())
+				}
 			}
 
 			// check String() function
@@ -135,13 +138,13 @@ func TestMariaDBGTIDSetUpdate(t *testing.T) {
 	cases := []struct {
 		isNilGTID bool
 		gtidStr   string
-		subGTIDs  map[uint32]string
+		subGTIDs  map[uint32]map[uint32]string
 	}{
-		{true, "", map[uint32]string{1: "1-1-1", 2: "2-2-2"}},
-		{false, "1-2-2", map[uint32]string{1: "1-2-2", 2: "2-2-2"}},
-		{false, "1-2-1", map[uint32]string{1: "1-2-1", 2: "2-2-2"}},
-		{false, "3-2-1", map[uint32]string{1: "1-1-1", 2: "2-2-2", 3: "3-2-1"}},
-		{false, "3-2-1,4-2-1", map[uint32]string{1: "1-1-1", 2: "2-2-2", 3: "3-2-1", 4: "4-2-1"}},
+		{true, "", map[uint32]map[uint32]string{1: {1: "1-1-1"}, 2: {2: "2-2-2"}}},
+		{false, "1-2-2", map[uint32]map[uint32]string{1: {1: "1-1-1", 2: "1-2-2"}, 2: {2: "2-2-2"}}},
+		{false, "1-2-1", map[uint32]map[uint32]string{1: {1: "1-1-1", 2: "1-2-1"}, 2: {2: "2-2-2"}}},
+		{false, "3-2-1", map[uint32]map[uint32]string{1: {1: "1-1-1"}, 2: {2: "2-2-2"}, 3: {2: "3-2-1"}}},
+		{false, "3-2-1,4-2-1", map[uint32]map[uint32]string{1: {1: "1-1-1"}, 2: {2: "2-2-2"}, 3: {2: "3-2-1"}, 4: {2: "4-2-1"}}},
 	}
 
 	for _, cs := range cases {
@@ -158,9 +161,12 @@ func TestMariaDBGTIDSetUpdate(t *testing.T) {
 		}
 		// check sub gtid
 		require.Len(t, mariadbGTIDSet.Sets, len(cs.subGTIDs))
-		for domainID, gtid := range mariadbGTIDSet.Sets {
+		for domainID, set := range mariadbGTIDSet.Sets {
 			require.Contains(t, mariadbGTIDSet.Sets, domainID)
-			require.Equal(t, cs.subGTIDs[domainID], gtid.String())
+			for serverID, gtid := range set {
+				require.Contains(t, cs.subGTIDs, domainID)
+				require.Equal(t, cs.subGTIDs[domainID][serverID], gtid.String())
+			}
 		}
 	}
 }