Skip to content
This repository has been archived by the owner on May 6, 2024. It is now read-only.

Commit

Permalink
fix(snapshot) Changes the type provided to the producer to fix the sn…
Browse files Browse the repository at this point in the history
…apshot process. (#27)

This seemingly complex change is actually a bug fix.
Currently, the snapshot process simply does not work. Months ago we (OK, I broke it) changed the Producer to embed a header containing the table name in the Kafka messages produced by the CDC producer. To do that, I made the stream/Producer take a ReplicationMessage object instead of a bytes payload.
This was fine for normal CDC operation, but it broke the snapshot, because the snapshot controller uses the producer to send a simple bytes payload to Kafka. The tests still passed because they mocked the logic too far from Kafka.

To fix that I introduced a StreamMessage to replace ReplicationMessage in the stream/producer interface.

StreamMessage is a slightly abstract representation of a message for the stream, containing a payload (bytes) and a metadata mapping, which becomes the headers on Kafka. We could make it even more abstract, but that would require more calls and more object instantiations (an intermediate representation between stream and Kafka) for every message processed. I would not accept this additional performance cost now, considering how critical message processing time is here.
Both control messages (for the snapshot) and ReplicationMessage (the default CDC message) have a to_stream method that produces the StreamMessage described above.
The rest is just reshuffling the code to avoid circular dependencies and to fix the tests.
  • Loading branch information
fpacifici authored Mar 12, 2021
1 parent 4bf6dbe commit 1418595
Show file tree
Hide file tree
Showing 16 changed files with 209 additions and 181 deletions.
49 changes: 22 additions & 27 deletions cdc/__main__.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,13 @@
import atexit
import logging
import logging.config
import signal
from typing import Any

import click
import jsonschema # type: ignore
import logging, logging.config
import signal
import yaml
import sentry_sdk

from typing import Any
import yaml
from pkg_resources import cleanup_resources, resource_filename
from sentry_sdk.integrations.logging import LoggingIntegration

Expand Down Expand Up @@ -64,7 +65,7 @@ def main(ctx, configuration_file, log_level):
def producer(ctx):
from cdc.producer import Producer
from cdc.sources import source_factory
from cdc.streams import producer_factory
from cdc.streams.producer import producer_factory
from cdc.utils.stats import Stats

configuration = ctx.obj
Expand Down Expand Up @@ -102,13 +103,14 @@ def consumer(ctx):
)
@click.pass_context
def snapshot(ctx, snapshot_config):
from cdc.snapshots.snapshot_coordinator import SnapshotCoordinator
from cdc.snapshots.sources import registry as source_registry
from cdc.snapshots.destinations import registry as destination_registry
from cdc.snapshots.snapshot_control import SnapshotControl
from cdc.streams import producer_factory
from cdc.snapshots.snapshot_coordinator import SnapshotCoordinator
from cdc.snapshots.sources import registry as source_registry
from cdc.streams.producer import producer_factory

configuration = ctx.obj

snapshot_config = yaml.load(snapshot_config, Loader=yaml.SafeLoader)
if configuration["version"] != 1:
raise Exception("Invalid snapshot configuration file version")
Expand All @@ -118,7 +120,7 @@ def snapshot(ctx, snapshot_config):
{
"type": "object",
"properties": {
#TODO: make product more restrictive once we have a better idea on how to use it
# TODO: make product more restrictive once we have a better idea on how to use it
"product": {"type": "string"},
"destination": {"type": "object"},
"tables": {
Expand All @@ -127,22 +129,18 @@ def snapshot(ctx, snapshot_config):
"type": "object",
"properties": {
"table": {"type": "string"},
"columns": {
"type": "array",
"items": {"type": "string"},
}
"columns": {"type": "array", "items": {"type": "string"}},
},
"required": ["table"],
}
}
},
},
},
"required": ["product", "destination", "tables"],
},
)

tables_config = [
TableConfig(t['table'], t.get('columns'))
for t in snapshot_config['tables']
TableConfig(t["table"], t.get("columns")) for t in snapshot_config["tables"]
]

coordinator = SnapshotCoordinator(
Expand All @@ -165,20 +163,17 @@ def snapshot(ctx, snapshot_config):
coordinator.start_process()


@main.command(
help="Aborts a snapshot by sending the message on the control topic"
)
@click.argument(
"snapshot_id",
type=click.STRING,
)
@main.command(help="Aborts a snapshot by sending the message on the control topic")
@click.argument("snapshot_id", type=click.STRING)
@click.pass_context
def snapshot_abort(ctx, snapshot_id):
from uuid import UUID

from cdc.snapshots.snapshot_control import SnapshotControl
from cdc.streams import producer_factory

configuration = ctx.obj

if configuration["version"] != 1:
raise Exception("Invalid snapshot configuration file version")

Expand Down
9 changes: 3 additions & 6 deletions cdc/producer.py
Original file line number Diff line number Diff line change
@@ -1,16 +1,13 @@
import functools
import logging
import time

from datetime import datetime

from cdc.sources import Source, CdcMessage
from cdc.streams import Producer as StreamProducer

from cdc.sources import CdcMessage, Source
from cdc.streams.producer import Producer as StreamProducer
from cdc.utils.logging import LoggerAdapter
from cdc.utils.stats import Stats


logger = LoggerAdapter(logging.getLogger(__name__))


Expand Down Expand Up @@ -79,7 +76,7 @@ def run(self) -> None:
logger.trace("Trying to write message to %r...", self.producer)
try:
self.producer.write(
message.payload,
message.payload.to_stream(),
callback=functools.partial(
self.__produce_callback, message, time.time()
),
Expand Down
23 changes: 13 additions & 10 deletions cdc/snapshots/control_protocol.py
Original file line number Diff line number Diff line change
@@ -1,18 +1,22 @@
from abc import ABC, abstractmethod
from cdc.types import Payload
from cdc.streams.types import StreamMessage
from dataclasses import dataclass, asdict
from typing import Any, Mapping, Sequence
import json # type: ignore

from cdc.snapshots.snapshot_types import Xid, SnapshotId, SnapshotDescriptor

from cdc.snapshots.snapshot_types import (
Xid,
SnapshotId,
SnapshotDescriptor
)

class ControlMessage(ABC):
@abstractmethod
def to_dict(self) -> Mapping[str, Any]:
raise NotImplementedError

def to_stream(self) -> StreamMessage:
json_string = json.dumps(self.to_dict())
return StreamMessage(payload=Payload(json_string.encode("utf-8")))


@dataclass(frozen=True)
class SnapshotInit(ControlMessage):
Expand All @@ -25,18 +29,17 @@ def to_dict(self) -> Mapping[str, Any]:
"event": "snapshot-init",
"tables": self.tables,
"snapshot-id": self.snapshot_id,
"product": self.product
"product": self.product,
}


@dataclass(frozen=True)
class SnapshotAbort(ControlMessage):
snapshot_id: SnapshotId

def to_dict(self) -> Mapping[str, Any]:
return {
"event": "snapshot-abort",
"snapshot-id": self.snapshot_id,
}
return {"event": "snapshot-abort", "snapshot-id": self.snapshot_id}


@dataclass(frozen=True)
class SnapshotLoaded(ControlMessage):
Expand Down
9 changes: 3 additions & 6 deletions cdc/snapshots/snapshot_control.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,14 @@
import json # type: ignore
import jsonschema # type: ignore
import logging

from functools import partial
from typing import Optional, Sequence
from uuid import UUID

import jsonschema # type: ignore
from cdc.snapshots.control_protocol import ControlMessage, SnapshotAbort, SnapshotInit
from cdc.snapshots.snapshot_types import SnapshotId
from cdc.sources.types import Payload
from cdc.streams import Producer as StreamProducer
from cdc.streams.producer import Producer as StreamProducer
from cdc.utils.logging import LoggerAdapter
from cdc.utils.registry import Configuration

Expand Down Expand Up @@ -59,10 +58,8 @@ def __msg_sent(self, msg: ControlMessage) -> None:
logger.debug("Message sent %r", msg)

def __write_msg(self, message: ControlMessage) -> None:
json_string = json.dumps(message.to_dict())
self.__producer.write(
payload=Payload(json_string.encode("utf-8")),
callback=partial(self.__msg_sent, message),
message.to_stream(), callback=partial(self.__msg_sent, message)
)

def init_snapshot(
Expand Down
25 changes: 11 additions & 14 deletions cdc/snapshots/snapshot_coordinator.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,12 +9,13 @@
from cdc.snapshots.sources import SnapshotSource
from cdc.snapshots.snapshot_control import SnapshotControl
from cdc.snapshots.snapshot_types import SnapshotId, TableConfig
from cdc.streams import Producer as StreamProducer
from cdc.streams.producer import Producer as StreamProducer
from cdc.utils.logging import LoggerAdapter
from cdc.utils.registry import Configuration

logger = LoggerAdapter(logging.getLogger(__name__))


class SnapshotCoordinator(ABC):
"""
Coordinates the process of taking a snapshot from the source database
Expand All @@ -26,38 +27,34 @@ class SnapshotCoordinator(ABC):
- communicate the details of the snapshot to all the listeners TODO
"""

def __init__(self,
def __init__(
self,
source: SnapshotSource,
destination: DestinationContext,
control: SnapshotControl,
product: str,
tables: Sequence[TableConfig]) -> None:
tables: Sequence[TableConfig],
) -> None:
self.__source = source
self.__destination = destination
self.__product = product
self.__tables = tables
self.__control = control


def start_process(self) -> None:
logger.debug("Starting snapshot process for product %s", self.__product)
snapshot_id = uuid.uuid1()
logger.info("Starting snapshot ID %s", snapshot_id)
table_names = [t.table for t in self.__tables]
self.__control.init_snapshot(
snapshot_id=snapshot_id,
tables=table_names,
product=self.__product,
snapshot_id=snapshot_id, tables=table_names, product=self.__product
)
with self.__destination.open(
SnapshotId(str(snapshot_id)),
self.__product) as snapshot_out:
SnapshotId(str(snapshot_id)), self.__product
) as snapshot_out:

logger.info("Snapshot ouput: %s", snapshot_out.get_name())
snapshot_desc = self.__source.dump(
snapshot_out,
self.__tables,
)
snapshot_desc = self.__source.dump(snapshot_out, self.__tables)
logger.info("Snapshot taken: %r", snapshot_desc)

self.__control.wait_messages_sent()
4 changes: 1 addition & 3 deletions cdc/sources/backends/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,9 +30,7 @@ def poll(self, timeout: float) -> None:

@abstractmethod
def commit_positions(
self,
write_position: Optional[Position],
flush_position: Optional[Position],
self, write_position: Optional[Position], flush_position: Optional[Position]
) -> None:
raise NotImplementedError

Expand Down
10 changes: 8 additions & 2 deletions cdc/sources/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,11 +3,11 @@
from abc import ABC
from dataclasses import dataclass
from typing import NamedTuple, NewType

from cdc.types import Payload
from cdc.streams.types import StreamMessage

Id = NewType("Id", int)
Position = NewType("Position", int)
Payload = NewType("Payload", bytes)


class CdcMessage(NamedTuple):
Expand Down Expand Up @@ -40,6 +40,9 @@ class ReplicationEvent(ABC):
# any source specific processing.
payload: Payload

def to_stream(self) -> StreamMessage:
return StreamMessage(payload=self.payload)


@dataclass(frozen=True)
class BeginMessage(ReplicationEvent):
Expand All @@ -59,6 +62,9 @@ class ChangeMessage(ReplicationEvent):

table: str

def to_stream(self) -> StreamMessage:
return StreamMessage(payload=self.payload, metadata={"table": self.table})


@dataclass(frozen=True)
class GenericMessage(ReplicationEvent):
Expand Down
Loading

0 comments on commit 1418595

Please sign in to comment.