feat: Add Rivulet Shard Implementation #453
Conversation
"cell_type": "code", | ||
"source": [ | ||
"dataset = dc.Dataset.from_parquet(\n", | ||
" name=\"data\",\n", | ||
" file_uri=parquet_file_path,\n", | ||
" metadata_uri=\".\",\n", | ||
" merge_keys=\"id\"\n", | ||
")\n", | ||
"print(\"Loaded dataset from Parquet file.\")\n", |
Any reason to put the metadata outside of the directory the parquet file is in? (which is the default?)
I believe there's a slight issue with how we derive the metadata_uri
metadata_uri = metadata_uri or os.path.join(file_uri, "riv-meta")
The metadata_uri is correctly constructed only if file_uri is a directory and not a single file. I'll fix this when I implement the FileSpec for rivulet, and will also make the from_x methods pluggable.
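For illustration, a minimal sketch of the fix being described; derive_metadata_uri is a hypothetical helper, and "riv-meta" is the directory name from the snippet above:

import os

def derive_metadata_uri(file_uri: str, metadata_uri: str = None) -> str:
    """Default the metadata location to a riv-meta directory next to the
    data, anchoring on the containing directory when file_uri points at
    a single file rather than a directory."""
    if metadata_uri:
        return metadata_uri
    base = file_uri if os.path.isdir(file_uri) else os.path.dirname(file_uri)
    return os.path.join(base, "riv-meta")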
else:
    raise ValueError(f"Unsupported sharding strategy type: {self}")


class ShardingStrategy(ABC):
Consider making this a Protocol instead of an ABC. In general, prefer protocols to ABCs, and in this specific case you don't have any shared state/code that benefits from an ABC. You will have to refactor shards to accept num_shards as a parameter.
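A minimal sketch of the Protocol variant being suggested, assuming the Shard and DatasetMetastore types from this PR; num_shards becomes an explicit parameter rather than constructor state:

from typing import Iterable, Protocol

class ShardingStrategy(Protocol):
    """Any class with a matching shards() method satisfies this
    protocol; no inheritance or shared state is required."""

    def shards(
        self, num_shards: int, metastore: "DatasetMetastore"
    ) -> Iterable["Shard"]:
        ...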
@@ -105,6 +107,7 @@ def deserialize_records(

:param records: Input data (generated by generate_records method)
:param output_type: Type to deserialize into
:param fields: Only include these fields in scan result (currently only for to_tensor)
Fields is exposed in scan, right? So you can use this even if you aren't using the pytorch connector?
Yes, it's exposed in scan as well. Not limited to tensor, but only the pytorch connector is making use of it as of now. I can definitely update the other connectors to do something similar as well. The main reason for adding this is because I wanted to be able to limit the pytorch connector fields to numerical values.
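For example, a hypothetical usage sketch (assuming the notebook's dataset object and the scan signature added in this PR) showing field selection outside the pytorch path:

# Field selection through a plain scan; no pytorch connector involved.
rows = dataset.scan(fields=["id", "age"]).to_pydict()

# The pytorch path that motivated the parameter: restrict the tensor
# to numerical fields only.
tensors = dataset.scan(fields=["age"]).to_tensor()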
-        return self.dataset_reader.scan(self.dataset_schema, Dict, self.query)
+        return self.dataset_reader.scan(self.dataset_schema, Dict, self.query, self.shard, self.fields)

    def to_tensor(self) -> Generator[
So, I actually had implemented pytorch support in a somewhat different way. It looks like this was never ported to deltacat - see CR-165567134
Let's discuss whether we want a pytorch wrapper class like in the CR or to treat tensors as another memory format
That looks great! I'll separate this PR so that we can take a look at the pytorch integration independently.
# TODO: is there a better way to filter the scan and apply different shards?
# For now keeping it shard-specific; ideally will need to restructure/rewrite dataset_reader.
if isinstance(shard, RangeShard):
    query = query.intersect_with(shard.query)
This is a little funky. It would be nice to keep separation between the user query and the subset of data specified by the shard
Hmm, would it be better to adapt __load_sst_rows to take in multiple queries and chain them? I'm not sure what that would look like; it might require refactoring the query expression itself.
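One possible shape for that, sketched under the assumption that QueryExpression.intersect_with behaves as in the diff above: the reader receives both expressions and folds them together itself, so the user query is never mutated by the shard.

from functools import reduce

def effective_query(user_query, shard_query):
    """Fold an optional user query and an optional shard range into the
    single expression the reader actually evaluates (hypothetical helper)."""
    queries = [q for q in (user_query, shard_query) if q is not None]
    if not queries:
        return None  # unbounded scan
    return reduce(lambda a, b: a.intersect_with(b), queries)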
""" | ||
def __init__(self, metastore: DatasetMetastore, start: T, end: T): | ||
super().__init__(metastore) | ||
self.query = QueryExpression[T]().with_range(start, end) |
If query is the way that Shard represents the data it is filtering, then this should be part of the Shard ABC/protocol. I don't think we should represent this as a Query.
I can see that; it might be better to not mix queries and shards. Does having a defined start and end key work?
I think adding it to a shard protocol should work. Realistically, if we want a uniform distribution across manifests, the merge key will need to be involved, and most likely it will be in the form of a range. I'm okay with doing this and adapting if we introduce new ShardingStrategies.
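A minimal sketch of what that protocol could look like, assuming a generic key type T; a shard is just a keyspace interval with no query object attached:

from typing import Optional, Protocol, TypeVar

T = TypeVar("T")

class Shard(Protocol[T]):
    """A shard is an interval over the merge key; how it is produced
    (range splitting, size-based, ...) is a strategy concern."""

    @property
    def min_key(self) -> Optional[T]: ...

    @property
    def max_key(self) -> Optional[T]: ...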
    return []

if global_min == global_max:
    return [RangeShard(metastore, global_min, global_max)]
Discussion topic to have with the broader rivulet channel: for correctness, does the data returned by each shard need to be strictly mutually exclusive? In that case, you would want to return a single shard from global min to global max, and N-1 other shards that are empty.
Good start, needs work! It's too complicated for what it's trying to accomplish; make better use of your base classes and remove all unnecessary abstraction.
"source": [ | ||
"@ray.remote\n", | ||
"def process_shard(shard: Generator[torch.Tensor, None, None], fields: List[str]) -> Tuple[float, int]:\n", | ||
" tensor_generator = dataset.scan(shard=shard, fields=fields).to_tensor()\n", | ||
" total_age = 0.0\n", | ||
" count = 0\n", | ||
"\n", | ||
" # calculate total\n", |
Typedefs in args are weird here: you're saying a shard is a Generator[torch.Tensor, None, None], but dataset.scan takes a Shard. I think you're letting Python get away with some weird type conversions, or at the very least it's not clear what you're passing around.
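For comparison, a sketch of the tighter annotation, assuming the Shard type and the notebook's dataset global from this PR:

import ray
from typing import List, Tuple

@ray.remote
def process_shard(shard: "Shard", fields: List[str]) -> Tuple[float, int]:
    """The argument is the Shard handed out by dataset.shards(); the
    tensor generator only exists after .to_tensor() is called."""
    total_age, count = 0.0, 0
    for tensor in dataset.scan(shard=shard, fields=fields).to_tensor():
        total_age += float(tensor.sum())
        count += int(tensor.numel())
    return total_age, count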
def shards(self, num_shards: int, strategy: ShardingStrategyType = ShardingStrategyType.RANGE) -> Iterable[Shard]:
    """Create a set of shards for this dataset.

    :param num_shards: The number of shards to create.
    :param strategy: Sharding strategy used to create shards.
    :return Iterable[Shard]: A set of shards for this dataset.
    """
    return strategy.to_class(num_shards).shards(self._metastore)

def scan(
I don't think this interface will hold up:
- num_shards only applies to a RANGE strategy; there may be other strategies that have a variable number of shards, like a SIZE strategy
- ShardingStrategyType is non-pythonic, because it's a self-describing type named with "Type"; ShardingStrategy.RANGE is more pythonic (simple/clear/English)
- it returns an iterable of Shard, so now everything has to know what a Shard is, while the shard class itself has no methods, accessors, identifiers, etc. (although I think you're saying repr is going to be the identifier, which is weird because repr is just supposed to be the printable representation?)
That said, eh, whatever, let's try it out for a bit and iterate on it.
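To make the naming point concrete, a sketch of the enum spelling being suggested (hypothetical; nothing here is in the PR as-is):

from enum import Enum

class ShardingStrategy(Enum):
    """ShardingStrategy.RANGE reads as plain English at the call site,
    versus ShardingStrategyType.RANGE."""
    RANGE = "range"

# hypothetical call site:
# dataset.shards(num_shards=4, strategy=ShardingStrategy.RANGE)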
-        return DataScan(self.schemas[schema_name], query, dataset_reader)
+        return DataScan(self.schemas[schema_name], query, dataset_reader, shard=shard, fields=fields)
I think you just broke the contract for DataScan by putting in a schema and fields?
@staticmethod
def __join_records_as_tensor(records: List[RecordBatchRowIndex], fields: Optional[List[str]]) -> Tensor:
    """
    Deserialize records into a PyTorch Tensor.

    :param records: input record data
    :param fields: optional subset of fields to include
    :returns: A PyTorch Tensor representing the joined records.
    """
    tensors = []
What's up w/ these weird types, RecordBatchRowIndex? Why not RecordBatch?
def test_range_shard_initialization(self):
    mock_metastore = MagicMock(spec=DatasetMetastore)
    shard = RangeShard(mock_metastore, 0, 10)
    self.assertEqual(shard.query.min_key, 0)
    self.assertEqual(shard.query.max_key, 10)

def test_range_shard_repr(self):
    mock_metastore = MagicMock(spec=DatasetMetastore)
    shard = RangeShard(mock_metastore, 5, 15)
    self.assertEqual(repr(shard), "Shard(type=range, start=5, end=15)")


class RangeShardingStrategyTests(unittest.TestCase):

    def test_shards_generation(self):
        mock_metastore = MagicMock(spec=DatasetMetastore)
        mock_metastore.generate_manifests.return_value = [
            MagicMock(generate_sstables=lambda: [
                MagicMock(min_key=0, max_key=50),
                MagicMock(min_key=60, max_key=100)
            ])
        ]
        strategy = RangeShardingStrategy(3)
        shards = list(strategy.shards(mock_metastore))
I don't like the code smell of having to use a mock for the DatasetMetastore; it smells like some inversion of responsibility gone wrong. It's just weird to have the range shard need to know anything about the metastore at all. Gut feel: the scan operation should utilize the defined range against whatever metastore the dataset has. Why does the shard have a metastore when the dataset already has a metastore? Can't the scan code that receives the range get the metastore from self, instead of doing this weird passthrough from the dataset to the metastore to the shard and back to the scan operation? How is the shard ever going to be serializable when it has the metastore inside of it?
class RangeShard(Shard, Generic[T]):
    """Represents a range-based shard with start and end keys.

    :param query: A QueryExpression object defining the range of the shard.
    """
    def __init__(self, metastore: DatasetMetastore, start: T, end: T):
        super().__init__(metastore)
        self.query = QueryExpression[T]().with_range(start, end)

    def __repr__(self) -> str:
        return f"Shard(type=range, start={self.query.min_key}, end={self.query.max_key})"
I don't think I like all this fancy bullshit abstraction. Why isn't a RangeShard just a data class with a start, end, and key_name that implements a get_records returning a Generator[RecordBatch]?
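Roughly what that shape could look like, as a sketch; get_records and its signature are hypothetical:

from dataclasses import dataclass
from typing import Any, Generator

@dataclass(frozen=True)
class RangeShard:
    """Plain data: no metastore handle, no query object, trivially
    serializable for shipping to Ray workers."""
    key_name: str
    start: Any
    end: Any

    def get_records(self, metastore: "DatasetMetastore") -> Generator["RecordBatch", None, None]:
        """The metastore is supplied at read time by the scan code
        rather than captured inside the shard (hypothetical)."""
        ...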
def __get_max_interval(self, manifests: Set[ManifestAccessor]) -> Tuple[Optional[T], Optional[T]]:
    """Computes the global minimum and maximum keys from the dataset manifests.

    :param manifests: A set of ManifestAccessor objects representing the dataset.
    :returns: A tuple containing the global minimum and maximum keys, or None if no data is found.
    """
    global_min, global_max = None, None

    for manifest in manifests:
        for table in manifest.generate_sstables():
            if global_min is None or table.min_key < global_min:
                global_min = table.min_key
            if global_max is None or table.max_key > global_max:
                global_max = table.max_key

    print(f"Global min: {global_min}, Global max: {global_max}")
    return global_min, global_max
This belongs in the manifests class.
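A sketch of that relocation, assuming the generate_manifests / generate_sstables accessors used elsewhere in this PR; the name get_key_range is hypothetical:

from typing import Optional, Tuple

def get_key_range(metastore) -> Tuple[Optional[object], Optional[object]]:
    """Global (min_key, max_key) across all sstables, so sharding
    strategies never walk manifests themselves. Sketched as a free
    function; it would live on DatasetMetastore or the manifest layer."""
    lo, hi = None, None
    for manifest in metastore.generate_manifests():
        for table in manifest.generate_sstables():
            lo = table.min_key if lo is None else min(lo, table.min_key)
            hi = table.max_key if hi is None else max(hi, table.max_key)
    return lo, hi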
# determine range generation type
range_generator = RangeGeneratorType.from_key(key=global_min)

shards = []
local_min = global_min
for idx in range(self.num_shards):
    # Compute the max key for the current shard
    is_last_shard = idx == self.num_shards - 1
    local_max = global_max if is_last_shard else range_generator.interpolate(global_min, global_max, idx + 1, self.num_shards)
What's with these extra layers of abstraction to have RangeGeneratorType? In Python you use your builtins: just calculate the range, and make sure whatever class you're calculating the range from has the appropriate function implemented.
Like this whole function should be three lines:
min = metastore.get_min_value("field")
max = metastore.get_max_value("field")
shards: List[Shard] = RangeShard.split(min, max)
return shards
metastore should have min/max value functions, range shard should have its own functions for common operations, and stop passing state all over god's green earth.
I think your issue is you're trying to do all this without modifying or adding functionality to the base classes you are working with. Instead, delete everything you have, write the absolute bare minimum human-understandable pseudocode you need to calculate what you are trying to do, which is to split an interval into a set of smaller intervals, then add the functionality to do that to the most obvious class where it belongs, without adding any additional classes. Every time you add a function, ask yourself, "is this the responsibility of this class?" and if it's not, move it to where it belongs. Only when you find that there's functionality that clearly does not belong in any class do you add a new class.
The goal should be to add the fewest number of new abstractions, fewest lines, and most easily understood functions you can to get the job done.
I think your issue is you're trying to do all this without modifying or adding functionality to the base classes you are working with.
I think this is why I ended up having to create these abstractions. I tried to have shard do all of the work but maybe the goal should have been to implement some of these features in the existing classes. I'll refactor it to make it simpler.
I wanted to treat RangeShard as a single unit, so I didn't want to include the generation logic there. I thought it would be better to have a range generator that returns a list of ranges, since that could be easier to expand in the future if we wanted to support additional key types. But yes, having the splitting logic in the Shard itself would reduce a lot of the additional code.
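For concreteness, a sketch of what that RangeShard.split could look like for integer keys (hypothetical; the get_min_value/get_max_value metastore methods above are also proposed, not existing). Note that adjacent shards share their boundary keys, which ties back to the mutual-exclusivity question raised earlier:

from typing import List

class RangeShard:
    def __init__(self, start: int, end: int):
        self.start, self.end = start, end

    @classmethod
    def split(cls, global_min: int, global_max: int, num_shards: int) -> List["RangeShard"]:
        """Divide [global_min, global_max] into num_shards contiguous
        intervals; degenerate inputs collapse to a single shard."""
        if num_shards <= 1 or global_min == global_max:
            return [cls(global_min, global_max)]
        step = (global_max - global_min) / num_shards
        cuts = [global_min + round(step * i) for i in range(1, num_shards)]
        bounds = [global_min, *cuts, global_max]
        return [cls(bounds[i], bounds[i + 1]) for i in range(num_shards)]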
Shards
- Reference-Based Shards
- Dataset-Based Shards

Implementation Note: The current implementation uses a reference-based approach to minimize overhead and simplify shard updates. However, transitioning to a dataset-based approach could improve usability and is worth considering based on future requirements.
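Putting the pieces together, a hypothetical end-to-end use of the reference-based shards, following the names in shard_demo.ipynb (dataset, process_shard) and assuming ray is initialized:

# Hand out lightweight shard references, fan them out to Ray workers,
# then combine the partial aggregates.
shards = dataset.shards(num_shards=4)
futures = [process_shard.remote(shard, ["age"]) for shard in shards]
results = ray.get(futures)

total = sum(age for age, _ in results)
count = sum(n for _, n in results)
print(f"mean age: {total / count:.2f}")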
Changes
This pull request introduces a new example for processing and analyzing data with DeltaCat, and includes several updates to the sharding and scanning functionality in the DeltaCat storage module. The most important changes are the addition of a new Jupyter notebook example, the implementation of an abstract Shard class and sharding strategies, and modifications to the dataset and block scanner classes to support sharding and field selection.

New Example:
- deltacat/examples/rivulet/shard_demo.ipynb: Added a new Jupyter notebook demonstrating how to generate, load, shard, process, and aggregate data using DeltaCat, PyTorch, and Ray.

Sharding Implementation:
- deltacat/storage/model/shard.py: Introduced an abstract Shard class and a ShardingStrategyType enum to define different sharding strategies. Added a ShardingStrategy abstract base class for creating shards.
- deltacat/storage/rivulet/dataset.py: Added a shards method to the Dataset class to create shards using a specified sharding strategy. Modified the scan method to accept an optional shard parameter.

Scanning Enhancements:
- deltacat/storage/rivulet/reader/block_scanner.py: Updated the BlockScanner class to accept an optional fields parameter in various methods to support field selection during scans.
- deltacat/storage/rivulet/reader/data_reader.py: Modified the DataReader protocol to include an optional fields parameter in the deserialize_records and join_deserialize_records methods.

Dataset Reader Updates:
- deltacat/storage/rivulet/reader/data_scan.py: Added shard and fields parameters to the DataScan class and updated the to_arrow, to_pydict, and to_tensor methods to use these parameters.
- deltacat/storage/rivulet/reader/dataset_reader.py: Updated the DatasetReader class to handle sharding and field selection in the scan method.

Testing
make test

Checklist
- [x] Unit tests covering the changes have been added
- [ ] E2E testing has been performed