diff --git a/replication/event.go b/replication/event.go index dacd49b02..18dbf6958 100644 --- a/replication/event.go +++ b/replication/event.go @@ -12,6 +12,7 @@ import ( "unicode" "github.com/go-mysql-org/go-mysql/mysql" + "github.com/go-mysql-org/go-mysql/serialization" "github.com/google/uuid" "github.com/pingcap/errors" ) @@ -420,6 +421,7 @@ func (e *QueryEvent) Dump(w io.Writer) { type GTIDEvent struct { CommitFlag uint8 SID []byte + Tag string GNO int64 LastCommitted int64 SequenceNumber int64 @@ -512,7 +514,11 @@ func (e *GTIDEvent) Dump(w io.Writer) { fmt.Fprintf(w, "Commit flag: %d\n", e.CommitFlag) u, _ := uuid.FromBytes(e.SID) - fmt.Fprintf(w, "GTID_NEXT: %s:%d\n", u.String(), e.GNO) + if e.Tag != "" { + fmt.Fprintf(w, "GTID_NEXT: %s:%s:%d\n", u.String(), e.Tag, e.GNO) + } else { + fmt.Fprintf(w, "GTID_NEXT: %s:%d\n", u.String(), e.GNO) + } fmt.Fprintf(w, "LAST_COMMITTED: %d\n", e.LastCommitted) fmt.Fprintf(w, "SEQUENCE_NUMBER: %d\n", e.SequenceNumber) fmt.Fprintf(w, "Immediate commmit timestamp: %d (%s)\n", e.ImmediateCommitTimestamp, fmtTime(e.ImmediateCommitTime())) @@ -543,6 +549,202 @@ func (e *GTIDEvent) OriginalCommitTime() time.Time { return microSecTimestampToTime(e.OriginalCommitTimestamp) } +// GtidTaggedLogEvent is for a GTID event with a tag. +// This is similar to GTIDEvent, but it has a tag and uses a different serialization format. +type GtidTaggedLogEvent struct { + GTIDEvent +} + +func (e *GtidTaggedLogEvent) Decode(data []byte) error { + msg := serialization.Message{ + Format: serialization.Format{ + Fields: []serialization.Field{ + { + Name: "gtid_flags", + Type: &serialization.FieldIntFixed{ + Length: 1, + }, + }, + { + Name: "uuid", + Type: &serialization.FieldIntFixed{ + Length: 16, + }, + }, + { + Name: "gno", + Type: &serialization.FieldIntVar{}, + }, + { + Name: "tag", + Type: &serialization.FieldString{}, + }, + { + Name: "last_committed", + Type: &serialization.FieldIntVar{}, + }, + { + Name: "sequence_number", + Type: &serialization.FieldIntVar{}, + }, + { + Name: "immediate_commit_timestamp", + Type: &serialization.FieldUintVar{}, + }, + { + Name: "original_commit_timestamp", + Type: &serialization.FieldUintVar{}, + Optional: true, + }, + { + Name: "transaction_length", + Type: &serialization.FieldUintVar{}, + }, + { + Name: "immediate_server_version", + Type: &serialization.FieldUintVar{}, + }, + { + Name: "original_server_version", + Type: &serialization.FieldUintVar{}, + Optional: true, + }, + { + Name: "commit_group_ticket", + Optional: true, + }, + }, + }, + } + + err := serialization.Unmarshal(data, &msg) + if err != nil { + return err + } + + f, err := msg.GetFieldByName("gtid_flags") + if err != nil { + return err + } + if v, ok := f.Type.(*serialization.FieldIntFixed); ok { + e.CommitFlag = v.Value[0] + } else { + return errors.New("failed to get gtid_flags field") + } + + f, err = msg.GetFieldByName("uuid") + if err != nil { + return err + } + if v, ok := f.Type.(*serialization.FieldIntFixed); ok { + e.SID = v.Value + } else { + return errors.New("failed to get uuid field") + } + + f, err = msg.GetFieldByName("gno") + if err != nil { + return err + } + if v, ok := f.Type.(*serialization.FieldIntVar); ok { + e.GNO = v.Value + } else { + return errors.New("failed to get gno field") + } + + f, err = msg.GetFieldByName("tag") + if err != nil { + return err + } + if v, ok := f.Type.(*serialization.FieldString); ok { + e.Tag = v.Value + } else { + return errors.New("failed to get tag field") + } + + f, err = msg.GetFieldByName("last_committed") + if err != nil { + return err + } + if v, ok := f.Type.(*serialization.FieldIntVar); ok { + e.LastCommitted = v.Value + } else { + return errors.New("failed to get last_committed field") + } + + f, err = msg.GetFieldByName("sequence_number") + if err != nil { + return err + } + if v, ok := f.Type.(*serialization.FieldIntVar); ok { + e.SequenceNumber = v.Value + } else { + return errors.New("failed to get sequence_number field") + } + + f, err = msg.GetFieldByName("immediate_commit_timestamp") + if err != nil { + return err + } + if v, ok := f.Type.(*serialization.FieldUintVar); ok { + e.ImmediateCommitTimestamp = v.Value + } else { + return errors.New("failed to get immediate_commit_timestamp field") + } + + f, err = msg.GetFieldByName("original_commit_timestamp") + if err != nil { + return err + } + if v, ok := f.Type.(*serialization.FieldUintVar); ok { + if f.Skipped { + e.OriginalCommitTimestamp = e.ImmediateCommitTimestamp + } else { + e.OriginalCommitTimestamp = v.Value + } + } else { + return errors.New("failed to get original_commit_timestamp field") + } + + f, err = msg.GetFieldByName("immediate_server_version") + if err != nil { + return err + } + if v, ok := f.Type.(*serialization.FieldUintVar); ok { + e.ImmediateServerVersion = uint32(v.Value) + } else { + return errors.New("failed to get immediate_server_version field") + } + + f, err = msg.GetFieldByName("original_server_version") + if err != nil { + return err + } + if v, ok := f.Type.(*serialization.FieldUintVar); ok { + if f.Skipped { + e.OriginalServerVersion = e.ImmediateServerVersion + } else { + e.OriginalServerVersion = uint32(v.Value) + } + } else { + return errors.New("failed to get original_server_version field") + } + + f, err = msg.GetFieldByName("transaction_length") + if err != nil { + return err + } + if v, ok := f.Type.(*serialization.FieldUintVar); ok { + e.TransactionLength = v.Value + } else { + return errors.New("failed to get transaction_length field") + } + + // TODO: add and test commit_group_ticket + + return nil +} + type BeginLoadQueryEvent struct { FileID uint32 BlockData []byte diff --git a/replication/parser.go b/replication/parser.go index 2d5908e78..eedc205ae 100644 --- a/replication/parser.go +++ b/replication/parser.go @@ -291,6 +291,8 @@ func (p *BinlogParser) parseEvent(h *EventHeader, data []byte, rawData []byte) ( e = >IDEvent{} case ANONYMOUS_GTID_EVENT: e = >IDEvent{} + case GTID_TAGGED_LOG_EVENT: + e = &GtidTaggedLogEvent{} case BEGIN_LOAD_QUERY_EVENT: e = &BeginLoadQueryEvent{} case EXECUTE_LOAD_QUERY_EVENT: diff --git a/serialization/serialization.go b/serialization/serialization.go new file mode 100644 index 000000000..59fb38865 --- /dev/null +++ b/serialization/serialization.go @@ -0,0 +1,265 @@ +// Package serialization is for working with the mysql::serialization format +// +// mysql::serialization is a serialization format introduced with tagged GTIDs +// +// https://dev.mysql.com/doc/dev/mysql-server/latest/PageLibsMysqlSerialization.html +package serialization + +import ( + "bytes" + "encoding/binary" + "errors" + "fmt" + "math/bits" + "strings" +) + +// Message is a mysql::serialization message +type Message struct { + Version uint8 // >= 0 + Format Format + fieldIndex map[string]uint8 +} + +func (m *Message) String() string { + parts := []string{fmt.Sprintf("Message (version: %d)", m.Version)} + parts = append(parts, m.Format.stringParts()...) + return strings.Join(parts, "\n ") +} + +// GetFieldByName returns a field if the name matches and an error if there is no match +func (m *Message) GetFieldByName(name string) (Field, error) { + if idx, ok := m.fieldIndex[name]; ok { + return m.Format.Fields[idx], nil + } + return Field{}, fmt.Errorf("field not found: %s", name) +} + +// Format is describing a `message_format` +type Format struct { + Size uint8 + LastNonIgnorableField uint8 + Fields []Field +} + +func (f *Format) String() (text string) { + return strings.Join(f.stringParts(), "\n") +} + +func (f *Format) stringParts() (parts []string) { + parts = make([]string, 0, len(f.Fields)*2+1) + parts = append(parts, fmt.Sprintf("Format (Size: %d, LastNonIgnorableField: %d)", + f.Size, f.LastNonIgnorableField)) + + for _, f := range f.Fields { + parts = append(parts, fmt.Sprintf("Field %02d (Name: %s, Skipped: %t, Type: %T)", + f.ID, f.Name, f.Skipped, f.Type)) + if f.Type != nil { + parts = append(parts, fmt.Sprintf(" Value: %s", f.Type.String())) + } + } + return +} + +// Field represents a `message_field` +type Field struct { + ID uint8 + Type FieldType + Optional bool + Name string + Skipped bool +} + +// FieldType represents a `type_field` +type FieldType interface { + fmt.Stringer + decode(data []byte, pos uint64) (uint64, error) +} + +// FieldIntFixed is for values with a fixed length. +// This is also known as the 'fixlen_integer_format'. +// The encoded value can vary be between 1 and 2 times +// of that of the value before encoding. +type FieldIntFixed struct { + Length int // Length of value before encoding, encoded value can be more + Value []byte +} + +func (f FieldIntFixed) String() string { + if f.Value == nil { + return "" + } + return fmt.Sprintf("0x%x", f.Value) +} + +func (f *FieldIntFixed) decode(data []byte, pos uint64) (uint64, error) { + var b bytes.Buffer + b.Grow(f.Length * 2) // output is between 1 and 2 times that of the input + + for { + if len(data) < int(pos)+1 { + return pos, errors.New("data truncated") + } + if data[pos]%2 == 0 { + b.WriteByte(data[pos] >> 1) + } else { + if len(data) < int(pos)+2 { + return pos, errors.New("data truncated") + } + switch data[pos+1] { + case 0x2: + b.WriteByte((data[pos] >> 2) + 0x80) + case 0x3: + b.WriteByte((data[pos] >> 2) + 0xc0) + default: + return pos, fmt.Errorf("unknown decoding for %v", data[pos]) + } + pos++ + } + pos++ + if b.Len() == f.Length { + break + } + } + f.Value = b.Bytes() + return pos, nil +} + +// FieldIntVar is using the signed integer variant of the 'varlen_integer_format' +// and encodes a value as a byte sequence of 1-9 bytes depending on the value. +type FieldIntVar struct { + Value int64 +} + +func (f FieldIntVar) String() string { + return fmt.Sprintf("%d", f.Value) +} + +func (f *FieldIntVar) decode(data []byte, pos uint64) (uint64, error) { + var val interface{} + val, pos, err := decodeVar(data, pos, false) + if err != nil { + return pos, err + } + if intval, ok := val.(int64); ok { + f.Value = intval + } else { + return pos, errors.New("unexpected type, expecting int64") + } + return pos, nil +} + +// FieldUintVar is using the unsigned integer variant of the 'varlen_integer_format' +// and encodes a value as a byte sequence of 1-9 bytes depending on the value. +type FieldUintVar struct { + Value uint64 +} + +func (f FieldUintVar) String() string { + return fmt.Sprintf("%d", f.Value) +} + +func (f *FieldUintVar) decode(data []byte, pos uint64) (uint64, error) { + var val interface{} + val, pos, err := decodeVar(data, pos, true) + if err != nil { + return pos, err + } + if uintval, ok := val.(uint64); ok { + f.Value = uintval + } else { + return pos, errors.New("unexpected type, expecting uint64") + } + return pos, nil +} + +// FieldString is a 'string_format' field +type FieldString struct { + Value string +} + +func (f *FieldString) decode(data []byte, pos uint64) (uint64, error) { + if len(data) < int(pos)+1 { + return pos, errors.New("string truncated, expected at least one byte") + } + strLen := int(data[pos] >> 1) + pos++ + if len(data) < int(pos)+strLen { + return pos, fmt.Errorf("string truncated, expected length: %d", strLen) + } + f.Value = string(data[pos : pos+uint64(strLen)]) + return pos + uint64(strLen), nil +} + +func (f FieldString) String() string { + return f.Value +} + +func Unmarshal(data []byte, v interface{}) error { + switch m := v.(type) { + case *Message: + m.Version = data[0] >> 1 + err := Unmarshal(data[1:], &m.Format) + if err != nil { + return err + } + if m.fieldIndex == nil { + m.fieldIndex = make(map[string]uint8, len(m.Format.Fields)) + } + for _, field := range m.Format.Fields { + m.fieldIndex[field.Name] = field.ID + } + case *Format: + pos := uint64(0) + m.Size = data[pos] >> 1 + pos++ + m.LastNonIgnorableField = data[pos] >> 1 + pos++ + + for i := 0; i < len(m.Fields); i++ { + if int(pos)+1 > len(data) || int(data[pos]>>1) != i { + // The field number we got doesn't match what we expect, + // so a field was skipped. + m.Fields[i].ID = uint8(i) + m.Fields[i].Skipped = true + continue + } + m.Fields[i].ID = data[pos] >> 1 + pos++ + n, err := m.Fields[i].Type.decode(data, pos) + if err != nil { + return err + } + pos = n + } + + default: + return fmt.Errorf("unsupported type: %T", v) + } + return nil +} + +func decodeVar(data []byte, pos uint64, unsigned bool) (interface{}, uint64, error) { + if len(data) < int(pos)+1 { + return 0, pos, errors.New("data truncated") + } + flen := trailingOneBitCount(data[pos]) + 1 + if len(data) < int(pos)+flen { + return 0, pos, fmt.Errorf("truncated data, expected length: %d", flen) + } + var tNumBytes [8]byte + copy(tNumBytes[:], data[pos:int(pos)+flen]) + tNum := binary.LittleEndian.Uint64(tNumBytes[:]) + pos += uint64(flen) + if unsigned { + return tNum >> flen, pos, nil + } + if positive := (tNum>>flen)&1 == 0; positive { + return int64(tNum >> (flen + 1)), pos, nil + } + return int64(-(1 + (tNum >> (flen + 1)))), pos, nil +} + +func trailingOneBitCount(b byte) int { + return bits.TrailingZeros8(^b) +} diff --git a/serialization/serialization_test.go b/serialization/serialization_test.go new file mode 100644 index 000000000..d8e5a7015 --- /dev/null +++ b/serialization/serialization_test.go @@ -0,0 +1,417 @@ +package serialization + +import ( + "testing" + + "github.com/stretchr/testify/require" +) + +func TestTrailingOneBitCount(t *testing.T) { + testcases := []struct { + input byte + result int + }{ + {0b00000000, 0}, + {0b00000001, 1}, + {0b00000011, 2}, + {0b00000111, 3}, + {0b00001111, 4}, + {0b00011111, 5}, + {0b00111111, 6}, + {0b01111111, 7}, + {0b11111111, 8}, + {0b10000000, 0}, + {0b11111101, 1}, + } + + for _, tc := range testcases { + actual := trailingOneBitCount(tc.input) + require.Equal(t, tc.result, actual) + } +} + +func TestDecodeFixed(t *testing.T) { + testcases := []struct { + input []byte + len int + result []byte + err string + }{ + { + []byte{ + 0xee, 0x81, 0x02, 0xc1, 0x02, 0x01, 0x03, 0x41, 0x03, 0x81, 0x03, 0xc1, 0x03, 0xc5, 0x03, 0x22, + 0x22, 0xee, 0xfd, 0x03, 0xee, 0xfd, 0x03, 0xee, 0xfd, 0x03, + }, + 16, + []byte{0x77, 0xa0, 0xb0, 0xc0, 0xd0, 0xe0, 0xf0, 0xf1, 0x11, 0x11, 0x77, 0xff, 0x77, 0xff, 0x77, 0xff}, + "", + }, + { + []byte{0xee, 0x81}, + 16, + []byte{}, + "data truncated", + }, + { + []byte{}, + 16, + []byte{}, + "data truncated", + }, + { + []byte{ + 0xee, 0x81, 0x04, 0xc1, 0x02, 0x01, 0x03, 0x41, 0x03, 0x81, 0x03, 0xc1, 0x03, 0xc5, 0x03, 0x22, + 0x22, 0xee, 0xfd, 0x03, 0xee, 0xfd, 0x03, 0xee, 0xfd, 0x03, + }, + 16, + []byte{}, + "unknown decoding for", + }, + } + + for _, tc := range testcases { + f := FieldIntFixed{ + Length: tc.len, + } + _, err := f.decode(tc.input, 0) + if tc.err == "" { + require.NoError(t, err) + require.Equal(t, tc.result, f.Value) + require.Equal(t, tc.len, len(f.Value)) + } else { + require.ErrorContains(t, err, tc.err) + } + } +} + +func TestDecodeString(t *testing.T) { + testcases := []struct { + input []byte + result string + err string + }{ + { + []byte{0x18, 0x61, 0x62, 0x63, 0x64, 0x65, 0x66, 0x67, 0x68, 0x69, 0x6a, 0x6b, 0x6c}, + "abcdefghijkl", + "", + }, + { + []byte{0x18, 0x61, 0x62, 0x63, 0x64, 0x65, 0x66, 0x67}, + "", + "string truncated", + }, + { + []byte{}, + "", + "string truncated, expected at least one byte", + }, + { + []byte{0x18}, + "", + "string truncated, expected length", + }, + } + + for _, tc := range testcases { + f := FieldString{} + _, err := f.decode(tc.input, 0) + if tc.err == "" { + require.NoError(t, err) + require.Equal(t, tc.result, f.Value) + } else { + require.ErrorContains(t, err, tc.err) + } + } +} + +func TestDecodeVar(t *testing.T) { + testcases := []struct { + input []byte + unsigned bool + result interface{} + err string + }{ + { + []byte{}, + false, + 0, + "data truncated", + }, + { + []byte{0xd9}, + false, + 0, + "truncated data", + }, + { + []byte{0x4}, + false, + int64(1), + "", + }, + { + []byte{0xd9, 0x03}, + false, + int64(123), + "", + }, + { + []byte{0xc3, 0o2, 0x0b}, + true, + uint64(90200), + "", + }, + { + // From the example on https://dev.mysql.com/doc/dev/mysql-server/latest/PageLibsMysqlhtml + // But converted to LE + // unsigned integer, 65535 + []byte{0b11111011, 0b11111111, 0b00000111}, + true, + uint64(65535), + "", + }, + { + // From the example on https://dev.mysql.com/doc/dev/mysql-server/latest/PageLibsMysqlhtml + // But converted to LE + // signed integer, 65535 + []byte{0b11110011, 0b11111111, 0b00001111}, + false, + int64(65535), + "", + }, + { + // From the example on https://dev.mysql.com/doc/dev/mysql-server/latest/PageLibsMysqlhtml + // But converted to LE + // signed integer, -65535 + []byte{0b11101011, 0b11111111, 0b00001111}, + false, + int64(-65535), + "", + }, + { + // From the example on https://dev.mysql.com/doc/dev/mysql-server/latest/PageLibsMysqlhtml + // But converted to LE + // signed integer, 65536 + []byte{0b11111011, 0b11111111, 0b00001111}, + false, + int64(-65536), + "", + }, + { + []byte{0x5d, 0x03}, + true, + uint64(215), + "", + }, + { + []byte{0x7f, 0x39, 0x7d, 0x89, 0x70, 0xdb, 0x2d, 0x06}, + true, + uint64(1739270369410361), + "", + }, + } + + for _, tc := range testcases { + r, _, err := decodeVar(tc.input, 0, tc.unsigned) + if tc.err == "" { + require.NoError(t, err) + require.Equal(t, tc.result, r, tc.result) + } else { + require.ErrorContains(t, err, tc.err) + } + } +} + +func TestUmarshal_event1(t *testing.T) { + data := []byte{ + 0x2, 0x76, 0x0, 0x0, 0x2, 0x2, 0x25, 0x2, 0xdc, 0xf0, 0x9, 0x2, 0x30, 0xf9, 0x3, 0x22, 0xbd, 0x3, + 0xad, 0x2, 0x21, 0x2, 0x44, 0x44, 0x5a, 0x68, 0x51, 0x3, 0x22, 0x4, 0x4, 0x6, 0xc, 0x66, 0x6f, 0x6f, 0x62, + 0x61, 0x7a, 0x8, 0x0, 0xa, 0x4, 0xc, 0x7f, 0x15, 0x83, 0x22, 0x2d, 0x5c, 0x2e, 0x6, 0x10, 0x49, 0x3, 0x12, + 0xc3, 0x2, 0xb, + } + + msg := Message{ + Format: Format{ + Fields: []Field{ + { + Name: "gtid_flags", + Type: &FieldIntFixed{ + Length: 1, + }, + }, + { + Name: "uuid", + Type: &FieldIntFixed{ + Length: 16, + }, + }, + { + Name: "gno", + Type: &FieldIntVar{}, + }, + { + Name: "tag", + Type: &FieldString{}, + }, + { + Name: "last_committed", + Type: &FieldIntVar{}, + }, + { + Name: "sequence_number", + Type: &FieldIntVar{}, + }, + { + Name: "immediate_commit_timestamp", + Type: &FieldUintVar{}, + }, + { + Name: "original_commit_timestamp", + Type: &FieldUintVar{}, + Optional: true, + }, + { + Name: "transaction_length", + Type: &FieldUintVar{}, + }, + { + Name: "immediate_server_version", + Type: &FieldUintVar{}, + }, + { + Name: "original_server_version", + Type: &FieldUintVar{}, + Optional: true, + }, + { + Name: "commit_group_ticket", + Optional: true, + }, + }, + }, + } + + expected := Message{ + Version: 1, + Format: Format{ + Size: 59, + Fields: []Field{ + { + Name: "gtid_flags", + ID: 0, + Type: &FieldIntFixed{ + Length: 1, + Value: []uint8{0o1}, + }, + }, + { + Name: "uuid", + ID: 1, + Type: &FieldIntFixed{ + Length: 16, + Value: []uint8{ + 0x89, 0x6e, 0x78, 0x82, 0x18, 0xfe, 0x11, 0xef, 0xab, + 0x88, 0x22, 0x22, 0x2d, 0x34, 0xd4, 0x11, + }, + }, + }, + { + Name: "gno", + ID: 2, + Type: &FieldIntVar{ + Value: 1, + }, + }, + { + Name: "tag", + ID: 3, + Type: &FieldString{ + Value: "foobaz", + }, + }, + { + Name: "last_committed", + ID: 4, + Type: &FieldIntVar{ + Value: 0, + }, + }, + { + Name: "sequence_number", + ID: 5, + Type: &FieldIntVar{ + Value: 1, + }, + }, + { + Name: "immediate_commit_timestamp", + ID: 6, + Type: &FieldUintVar{ + Value: 1739823289369365, + }, + }, + { + Name: "original_commit_timestamp", + ID: 7, + Type: &FieldUintVar{}, + Optional: true, + Skipped: true, + }, + { + Name: "transaction_length", + ID: 8, + Type: &FieldUintVar{ + Value: 210, + }, + }, + { + Name: "immediate_server_version", + ID: 9, + Type: &FieldUintVar{ + Value: 90200, + }, + }, + { + Name: "original_server_version", + ID: 10, + Type: &FieldUintVar{}, + Optional: true, + Skipped: true, + }, + { + Name: "commit_group_ticket", + ID: 11, + Optional: true, + Skipped: true, + }, + }, + }, + fieldIndex: map[string]uint8{ + "gtid_flags": 0, + "uuid": 1, + "gno": 2, + "tag": 3, + "last_committed": 4, + "sequence_number": 5, + "immediate_commit_timestamp": 6, + "original_commit_timestamp": 7, + "transaction_length": 8, + "immediate_server_version": 9, + "original_server_version": 10, + "commit_group_ticket": 11, + }, + } + + err := Unmarshal(data, &msg) + require.NoError(t, err) + + for i, f := range msg.Format.Fields { + require.Equal(t, expected.Format.Fields[i], f) + } + + require.Equal(t, expected, msg) + + sv, err := msg.GetFieldByName("immediate_server_version") + require.NoError(t, err) + require.Equal(t, uint8(9), sv.ID) +}