Skip to content

Commit

Permalink
feat(protocol/message): add inventory message (#161)
Browse files Browse the repository at this point in the history
  • Loading branch information
ybensacq authored Oct 7, 2024
1 parent cc73c54 commit 0c66e2c
Show file tree
Hide file tree
Showing 4 changed files with 199 additions and 12 deletions.
18 changes: 6 additions & 12 deletions src/network/protocol/messages/getdata.zig
Original file line number Diff line number Diff line change
Expand Up @@ -2,13 +2,16 @@ const std = @import("std");
const CompactSizeUint = @import("bitcoin-primitives").types.CompatSizeUint;
const message = @import("./lib.zig");
const genericChecksum = @import("lib.zig").genericChecksum;
const genericDeserializeSlice = @import("lib.zig").genericDeserializeSlice;
const genericSerialize = @import("lib.zig").genericSerialize;

const Sha256 = std.crypto.hash.sha2.Sha256;

const protocol = @import("../lib.zig");

pub const GetdataMessage = struct {
inventory: []const protocol.InventoryItem,
const Self = @This();

pub inline fn name() *const [12]u8 {
return protocol.CommandNames.GETDATA ++ [_]u8{0} ** 5;
Expand Down Expand Up @@ -39,14 +42,7 @@ pub const GetdataMessage = struct {
}

pub fn serialize(self: *const GetdataMessage, allocator: std.mem.Allocator) ![]u8 {
const serialized_len = self.hintSerializedLen();

const ret = try allocator.alloc(u8, serialized_len);
errdefer allocator.free(ret);

try self.serializeToSlice(ret);

return ret;
return genericSerialize(self, allocator);
}

/// Serialize a message as bytes and write them to the buffer.
Expand Down Expand Up @@ -94,10 +90,8 @@ pub const GetdataMessage = struct {
}

/// Deserialize bytes into a `GetdataMessage`
pub fn deserializeSlice(allocator: std.mem.Allocator, bytes: []const u8) !GetdataMessage {
var fbs = std.io.fixedBufferStream(bytes);
const reader = fbs.reader();
return try GetdataMessage.deserializeReader(allocator, reader);
pub fn deserializeSlice(allocator: std.mem.Allocator, bytes: []const u8) !Self {
return genericDeserializeSlice(Self, allocator, bytes);
}


Expand Down
146 changes: 146 additions & 0 deletions src/network/protocol/messages/inv.zig
Original file line number Diff line number Diff line change
@@ -0,0 +1,146 @@
const std = @import("std");
const CompactSizeUint = @import("bitcoin-primitives").types.CompatSizeUint;
const message = @import("./lib.zig");
const genericChecksum = @import("lib.zig").genericChecksum;
const genericDeserializeSlice = @import("lib.zig").genericDeserializeSlice;

const Sha256 = std.crypto.hash.sha2.Sha256;

const protocol = @import("../lib.zig");

pub const InvMessage = struct {
inventory: []const protocol.InventoryItem,
const Self = @This();

pub inline fn name() *const [12]u8 {
return protocol.CommandNames.INV ++ [_]u8{0} ** 5;
}

/// Returns the message checksum
///
/// Computed as `Sha256(Sha256(self.serialize()))[0..4]`
pub fn checksum(self: *const InvMessage) [4]u8 {
return genericChecksum(self);
}

/// Free the `inventory`
pub fn deinit(self: InvMessage, allocator: std.mem.Allocator) void {
allocator.free(self.inventory);
}

/// Serialize the message as bytes and write them to the Writer.
///
/// `w` should be a valid `Writer`.
pub fn serializeToWriter(self: *const InvMessage, w: anytype) !void {
const count = CompactSizeUint.new(self.inventory.len);
try count.encodeToWriter(w);

for (self.inventory) |item| {
try item.encodeToWriter(w);
}
}

pub fn serialize(self: *const InvMessage, allocator: std.mem.Allocator) ![]u8 {
const serialized_len = self.hintSerializedLen();

const ret = try allocator.alloc(u8, serialized_len);
errdefer allocator.free(ret);

try self.serializeToSlice(ret);

return ret;
}

/// Serialize a message as bytes and write them to the buffer.
///
/// buffer.len must be >= than self.hintSerializedLen()
pub fn serializeToSlice(self: *const InvMessage, buffer: []u8) !void {
var fbs = std.io.fixedBufferStream(buffer);
const writer = fbs.writer();
try self.serializeToWriter(writer);
}

pub fn hintSerializedLen(self: *const InvMessage) usize {
var length: usize = 0;

// Adding the length of CompactSizeUint for the count
const count = CompactSizeUint.new(self.inventory.len);
length += count.hint_encoded_len();

// Adding the length of each inventory item
length += self.inventory.len * (4 + 32); // Type (4 bytes) + Hash (32 bytes)

return length;
}

pub fn deserializeReader(allocator: std.mem.Allocator, r: anytype) !InvMessage {

const compact_count = try CompactSizeUint.decodeReader(r);
const count = compact_count.value();
if (count == 0) {
return InvMessage{
.inventory = &[_]protocol.InventoryItem{},
};
}

const inventory = try allocator.alloc(protocol.InventoryItem, count);
errdefer allocator.free(inventory);

for (inventory) |*item| {
item.* = try protocol.InventoryItem.decodeReader(r);
}

return InvMessage{
.inventory = inventory,
};
}

/// Deserialize bytes into a `InvMessage`
pub fn deserializeSlice(allocator: std.mem.Allocator, bytes: []const u8) !Self {
return genericDeserializeSlice(Self, allocator, bytes);
}


pub fn eql(self: *const InvMessage, other: *const InvMessage) bool {
if (self.inventory.len != other.inventory.len) return false;

for (0..self.inventory.len) |i| {
const item_self = self.inventory[i];
const item_other = other.inventory[i];
if (!item_self.eql(&item_other)) {
return false;
}
}

return true;
}
};


// TESTS
test "ok_full_flow_inv_message" {
const allocator = std.testing.allocator;

// With some inventory items
{
const inventory_items = [_]protocol.InventoryItem{
.{ .type = 1, .hash = [_]u8{0xab} ** 32 },
.{ .type = 2, .hash = [_]u8{0xcd} ** 32 },
.{ .type = 2, .hash = [_]u8{0xef} ** 32 },
};

const gd = InvMessage{
.inventory = inventory_items[0..],
};

const payload = try gd.serialize(allocator);
defer allocator.free(payload);

const deserialized_gd = try InvMessage.deserializeSlice(allocator, payload);

try std.testing.expect(gd.eql(&deserialized_gd));

// Free allocated memory for deserialized inventory
defer allocator.free(deserialized_gd.inventory);
}
}
7 changes: 7 additions & 0 deletions src/network/protocol/messages/lib.zig
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ pub const FilterLoadMessage = @import("filterload.zig").FilterLoadMessage;
pub const GetBlockTxnMessage = @import("getblocktxn.zig").GetBlockTxnMessage;
pub const HeadersMessage = @import("headers.zig").HeadersMessage;
pub const CmpctBlockMessage = @import("cmpctblock.zig").CmpctBlockMessage;
pub const InvMessage = @import("inv.zig").InvMessage;

pub const MessageTypes = enum {
version,
Expand All @@ -45,6 +46,7 @@ pub const MessageTypes = enum {
getdata,
headers,
cmpctblock,
inv
};


Expand All @@ -70,6 +72,7 @@ pub const Message = union(MessageTypes) {
getdata: GetdataMessage,
headers: HeadersMessage,
cmpctblock: CmpctBlockMessage,
inv: InvMessage,

pub fn name(self: Message) *const [12]u8 {
return switch (self) {
Expand All @@ -94,6 +97,7 @@ pub const Message = union(MessageTypes) {
.getdata => |m| @TypeOf(m).name(),
.headers => |m| @TypeOf(m).name(),
.cmpctblock => |m| @TypeOf(m).name(),
.inv => |m| @TypeOf(m).name(),
};
}

Expand All @@ -113,6 +117,7 @@ pub const Message = union(MessageTypes) {
.filterload => {},
.getblocktxn => |*m| m.deinit(allocator),
.headers => |*m| m.deinit(allocator),
.inv => |*m| m.deinit(allocator),
else => {}
}
}
Expand Down Expand Up @@ -140,6 +145,7 @@ pub const Message = union(MessageTypes) {
.getdata => |*m| m.checksum(),
.headers => |*m| m.checksum(),
.cmpctblock => |*m| m.checksum(),
.inv => |*m| m.checksum(),
};
}

Expand All @@ -166,6 +172,7 @@ pub const Message = union(MessageTypes) {
.getdata => |m| m.hintSerializedLen(),
.headers => |*m| m.hintSerializedLen(),
.cmpctblock => |*m| m.hintSerializedLen(),
.inv => |*m| m.hintSerializedLen(),
};
}
};
Expand Down
40 changes: 40 additions & 0 deletions src/network/wire/lib.zig
Original file line number Diff line number Diff line change
Expand Up @@ -150,6 +150,8 @@ pub fn receiveMessage(
protocol.messages.Message{ .getdata = try protocol.messages.GetdataMessage.deserializeReader(allocator, r) }
else if (std.mem.eql(u8, &command, protocol.messages.CmpctBlockMessage.name()))
protocol.messages.Message{ .cmpctblock = try protocol.messages.CmpctBlockMessage.deserializeReader(allocator, r) }
else if (std.mem.eql(u8, &command, protocol.messages.InvMessage.name()))
protocol.messages.Message{ .inv = try protocol.messages.InvMessage.deserializeReader(allocator, r) }
else {
try r.skipBytes(payload_len, .{}); // Purge the wire
return error.UnknownMessage;
Expand Down Expand Up @@ -787,3 +789,41 @@ test "ok_send_cmpctblock_message" {
try std.testing.expect(hint_len > 0);
try std.testing.expect(hint_len == serialized.len);
}

test "ok_send_inv_message" {
const Config = @import("../../config/config.zig").Config;
const ArrayList = std.ArrayList;
const test_allocator = std.testing.allocator;
const InvMessage = protocol.messages.InvMessage;

var list: std.ArrayListAligned(u8, null) = ArrayList(u8).init(test_allocator);
defer list.deinit();

const inventory = try test_allocator.alloc(protocol.InventoryItem, 5);
defer test_allocator.free(inventory);

for (inventory) |*item| {
item.type = 1;
for (&item.hash) |*byte| {
byte.* = 0xab;
}
}

const message = InvMessage{
.inventory = inventory,
};

var received_message = try write_and_read_message(
test_allocator,
&list,
Config.BitcoinNetworkId.MAINNET,
Config.PROTOCOL_VERSION,
message,
) orelse unreachable;
defer received_message.deinit(test_allocator);

switch (received_message) {
.inv => |rm| try std.testing.expect(message.eql(&rm)),
else => unreachable,
}
}

0 comments on commit 0c66e2c

Please sign in to comment.