feat: waku sync shard matching check #3259

Merged · 4 commits · Jan 30, 2025
tests/waku_store_sync/sync_utils.nim (4 additions, 0 deletions)
@@ -25,10 +25,14 @@ proc newTestWakuRecon*(
idsRx: AsyncQueue[SyncID],
wantsTx: AsyncQueue[(PeerId, Fingerprint)],
needsTx: AsyncQueue[(PeerId, Fingerprint)],
cluster: uint16 = 1,
shards: seq[uint16] = @[0, 1, 2, 3, 4, 5, 6, 7],
): Future[SyncReconciliation] {.async.} =
let peerManager = PeerManager.new(switch)

let res = await SyncReconciliation.new(
cluster = cluster,
shards = shards,
peerManager = peerManager,
wakuArchive = nil,
relayJitter = 0.seconds,
tests/waku_store_sync/test_protocol.nim (34 additions, 0 deletions)
@@ -157,6 +157,40 @@ suite "Waku Sync: reconciliation":
localWants.contains((clientPeerInfo.peerId, hash3)) == true
localWants.contains((serverPeerInfo.peerId, hash2)) == true

asyncTest "sync 2 nodes different shards":
let
msg1 = fakeWakuMessage(ts = now(), contentTopic = DefaultContentTopic)
msg2 = fakeWakuMessage(ts = now() + 1, contentTopic = DefaultContentTopic)
msg3 = fakeWakuMessage(ts = now() + 2, contentTopic = DefaultContentTopic)
hash1 = computeMessageHash(DefaultPubsubTopic, msg1)
hash2 = computeMessageHash(DefaultPubsubTopic, msg2)
hash3 = computeMessageHash(DefaultPubsubTopic, msg3)

server.messageIngress(hash1, msg1)
server.messageIngress(hash2, msg2)
client.messageIngress(hash1, msg1)
client.messageIngress(hash3, msg3)

check:
remoteNeeds.contains((serverPeerInfo.peerId, hash3)) == false
remoteNeeds.contains((clientPeerInfo.peerId, hash2)) == false
localWants.contains((clientPeerInfo.peerId, hash3)) == false
localWants.contains((serverPeerInfo.peerId, hash2)) == false

server = await newTestWakuRecon(
serverSwitch, idsChannel, localWants, remoteNeeds, shards = @[0.uint16, 1, 2, 3]
)
client = await newTestWakuRecon(
clientSwitch, idsChannel, localWants, remoteNeeds, shards = @[4.uint16, 5, 6, 7]
)

var syncRes = await client.storeSynchronization(some(serverPeerInfo))
assert syncRes.isOk(), $syncRes.error

check:
remoteNeeds.len == 0
localWants.len == 0

asyncTest "sync 2 nodes same hashes":
let
msg1 = fakeWakuMessage(ts = now(), contentTopic = DefaultContentTopic)
waku/node/waku_node.nim (11 additions, 1 deletion)
@@ -216,9 +216,19 @@ proc mountStoreSync*(
let wantsChannel = newAsyncQueue[(PeerId, WakuMessageHash)](100)
let needsChannel = newAsyncQueue[(PeerId, WakuMessageHash)](100)

var cluster: uint16
var shards: seq[uint16]
let enrRes = node.enr.toTyped()
if enrRes.isOk():
let shardingRes = enrRes.get().relaySharding()
if shardingRes.isSome():
let relayShard = shardingRes.get()
cluster = relayShard.clusterID
shards = relayShard.shardIds

let recon =
?await SyncReconciliation.new(
-  node.peerManager, node.wakuArchive, storeSyncRange.seconds,
+  cluster, shards, node.peerManager, node.wakuArchive, storeSyncRange.seconds,
storeSyncInterval.seconds, storeSyncRelayJitter.seconds, idsChannel, wantsChannel,
needsChannel,
)
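
For reference, a self-contained sketch of the defaulting behaviour this gives mountStoreSync: when the node's ENR carries no relay-sharding entry, cluster stays 0 and the shard list stays empty. ShardInfo below is a hypothetical stand-in for the real ENR sharding type, used only to illustrate the Option handling.

import std/options

type ShardInfo = object
  clusterId: uint16
  shardIds: seq[uint16]

proc extractShardInfo(sharding: Option[ShardInfo]): (uint16, seq[uint16]) =
  # Mirrors the fallback above: unset ENR sharding yields cluster 0 and no shards.
  var cluster: uint16
  var shards: seq[uint16]
  if sharding.isSome():
    cluster = sharding.get().clusterId
    shards = sharding.get().shardIds
  (cluster, shards)

doAssert extractShardInfo(none(ShardInfo)) == (0'u16, newSeq[uint16]())
doAssert extractShardInfo(some(ShardInfo(clusterId: 1, shardIds: @[0'u16, 4]))) == (1'u16, @[0'u16, 4])
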
waku/waku_store_sync/codec.nim (48 additions, 1 deletion)
@@ -52,6 +52,18 @@ proc deltaEncode*(value: RangesData): seq[byte] =
i = 0
j = 0

# encode cluster
buf = uint64(value.cluster).toBytes(Leb128)
output &= @buf

# encode shards
buf = uint64(value.shards.len).toBytes(Leb128)
output &= @buf

for shard in value.shards:
buf = uint64(shard).toBytes(Leb128)
output &= @buf

# the first range is implicit but must be explicit when encoded
let (bound, _) = value.ranges[0]

@@ -209,6 +221,38 @@ proc getReconciled(idx: var int, buffer: seq[byte]): Result[bool, string] =

return ok(recon)

proc getCluster(idx: var int, buffer: seq[byte]): Result[uint16, string] =
if idx + VarIntLen > buffer.len:
return err("Cannot decode cluster")

let slice = buffer[idx ..< idx + VarIntLen]
let (val, len) = uint64.fromBytes(slice, Leb128)
idx += len

return ok(uint16(val))

proc getShards(idx: var int, buffer: seq[byte]): Result[seq[uint16], string] =
if idx + VarIntLen > buffer.len:
return err("Cannot decode shards count")

let slice = buffer[idx ..< idx + VarIntLen]
let (val, len) = uint64.fromBytes(slice, Leb128)
idx += len
let shardsLen = val

var shards: seq[uint16]
for i in 0 ..< shardsLen:
if idx + VarIntLen > buffer.len:
return err("Cannot decode shard value. idx: " & $i)

let slice = buffer[idx ..< idx + VarIntLen]
let (val, len) = uint64.fromBytes(slice, Leb128)
idx += len

shards.add(uint16(val))

return ok(shards)

proc deltaDecode*(
itemSet: var ItemSet, buffer: seq[byte], setLength: int
): Result[int, string] =
@@ -242,14 +286,17 @@ proc getItemSet(
return ok(itemSet)

proc deltaDecode*(T: type RangesData, buffer: seq[byte]): Result[T, string] =
-  if buffer.len == 1:
+  if buffer.len <= 1:
return ok(RangesData())

var
payload = RangesData()
lastTime = Timestamp(0)
idx = 0

payload.cluster = ?getCluster(idx, buffer)
payload.shards = ?getShards(idx, buffer)

lastTime = ?getTimestamp(idx, buffer)

# implicit first hash is always 0
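
As a reference for the wire format, here is a minimal, self-contained sketch of just the new cluster/shards prefix that deltaEncode now writes and getCluster/getShards read back. It assumes the same stew/leb128 helpers the codec already uses (import path assumed); the cluster and shard values are made up.

import stew/leb128

let
  cluster = 1'u16
  shards = @[0'u16, 1, 4]

var output: seq[byte]

let clusterBuf = uint64(cluster).toBytes(Leb128)   # cluster id
output &= @clusterBuf
let countBuf = uint64(shards.len).toBytes(Leb128)  # shard count
output &= @countBuf
for shard in shards:                               # shard ids
  let shardBuf = uint64(shard).toBytes(Leb128)
  output &= @shardBuf

# Every value here fits in a single LEB128 byte, so the prefix is 5 bytes,
# followed on the wire by the pre-existing range encoding.
doAssert output == @[byte 1, 3, 0, 1, 4]

# Decoding mirrors getCluster/getShards: read one varint, advance the index.
var idx = 0
let (decodedCluster, used) = uint64.fromBytes(output[idx ..< output.len], Leb128)
idx += used
doAssert uint16(decodedCluster) == cluster
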
waku/waku_store_sync/common.nim (3 additions, 0 deletions)
@@ -26,6 +26,9 @@ type
ItemSet = 2

RangesData* = object
cluster*: uint16
shards*: seq[uint16]
Review comment (Contributor): I think we'd need the network ID too to fully qualify the shard.


ranges*: seq[(Slice[SyncID], RangeType)]
fingerprints*: seq[Fingerprint] # Range type fingerprint stored here in order
itemSets*: seq[ItemSet] # Range type itemset stored here in order
waku/waku_store_sync/reconciliation.nim (24 additions, 7 deletions)
@@ -1,7 +1,7 @@
{.push raises: [].}

import
-  std/[sequtils, options],
+  std/[sequtils, options, packedsets],
stew/byteutils,
results,
chronicles,
@@ -37,6 +37,9 @@ logScope:
const DefaultStorageCap = 50_000

type SyncReconciliation* = ref object of LPProtocol
cluster: uint16
shards: PackedSet[uint16]

peerManager: PeerManager

wakuArchive: WakuArchive
@@ -114,16 +117,24 @@ proc processRequest(
   var
     hashToRecv: seq[WakuMessageHash]
     hashToSend: seq[WakuMessageHash]
+    sendPayload: RangesData
+    rawPayload: seq[byte]

-  let sendPayload = self.storage.processPayload(recvPayload, hashToSend, hashToRecv)
+  # Only process the ranges IF the shards and cluster matches
+  if self.cluster == recvPayload.cluster and
+      recvPayload.shards.toPackedSet() == self.shards:
+    sendPayload = self.storage.processPayload(recvPayload, hashToSend, hashToRecv)
+    sendPayload.cluster = self.cluster
+    sendPayload.shards = self.shards.toSeq()

-  for hash in hashToSend:
-    await self.remoteNeedsTx.addLast((conn.peerId, hash))
+    for hash in hashToSend:
+      await self.remoteNeedsTx.addLast((conn.peerId, hash))

-  for hash in hashToRecv:
-    await self.localWantstx.addLast((conn.peerId, hash))
+    for hash in hashToRecv:
+      await self.localWantstx.addLast((conn.peerId, hash))

-  let rawPayload = sendPayload.deltaEncode()
+  rawPayload = sendPayload.deltaEncode()

   total_bytes_exchanged.observe(
     rawPayload.len, labelValues = [Reconciliation, Sending]
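
To make the semantics of the new guard explicit: the cluster must be equal and the shard sets must be exactly equal (order does not matter, but a mere overlap or subset is not enough); otherwise the ranges are not processed and nothing is queued for sending or receiving, which is what the new test asserts. A self-contained sketch of that rule, with SyncScope standing in for the relevant RangesData fields:

import std/packedsets

type SyncScope = object
  cluster: uint16
  shards: seq[uint16]

proc matches(localCluster: uint16, localShards: PackedSet[uint16], remote: SyncScope): bool =
  # Same predicate as the guard above: equal cluster and identical shard set.
  localCluster == remote.cluster and remote.shards.toPackedSet() == localShards

let localShards = [0'u16, 1, 2, 3].toPackedSet()
doAssert matches(1, localShards, SyncScope(cluster: 1, shards: @[3'u16, 2, 1, 0]))      # order does not matter
doAssert not matches(1, localShards, SyncScope(cluster: 1, shards: @[0'u16, 1]))        # subset is not enough
doAssert not matches(2, localShards, SyncScope(cluster: 1, shards: @[0'u16, 1, 2, 3]))  # cluster must match
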
@@ -162,6 +173,8 @@ proc initiate(

fingerprint = self.storage.computeFingerprint(bounds)
initPayload = RangesData(
cluster: self.cluster,
shards: self.shards.toSeq(),
ranges: @[(bounds, RangeType.Fingerprint)],
fingerprints: @[fingerprint],
itemSets: @[],
@@ -261,6 +274,8 @@ proc initFillStorage(

proc new*(
T: type SyncReconciliation,
cluster: uint16,
shards: seq[uint16],
peerManager: PeerManager,
wakuArchive: WakuArchive,
syncRange: timer.Duration = DefaultSyncRange,
@@ -279,6 +294,8 @@
SeqStorage.new(res.get())

var sync = SyncReconciliation(
cluster: cluster,
shards: shards.toPackedSet(),
peerManager: peerManager,
storage: storage,
syncRange: syncRange,
Expand Down