
feat: Add Rivulet Shard Implementation #453

Closed
wants to merge 1 commit

Conversation


@anshumankomawar anshumankomawar commented Jan 17, 2025

Shards

Reference-Based Shards

  • Pros:
    • Minimal overhead when creating new shards.
    • Easy to update shards by regenerating pointers instead of creating new datasets.
    • Simplifies storage and avoids duplicating metadata/dataset attributes.
  • Cons:
    • Limited functionality, as shards only act as pointers and do not provide direct access to dataset operations. We will need to pass shards into interfaces like the DatasetReader (refer to demo for existing syntax).

Dataset-Based Shards

  • Pros:
    • Might be more intuitive for users, as shards inherit all dataset attributes and functions.
    • Users can directly operate on shards without referencing the parent dataset.
  • Cons:
    • Higher overhead when creating shards, as they replicate dataset functionality and attributes.
    • Updating shards to reflect changes in the original dataset may require significant effort, including regenerating datasets.

Implementation Note:
The current implementation uses a reference-based approach to minimize overhead and simplify shard updates. However, transitioning to a dataset-based approach could improve usability and is worth considering based on future requirements.
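The trade-off between the two approaches can be sketched in miniature. This is a hypothetical illustration (RefShard, TinyDataset, and DatasetShard are invented names); the real deltacat classes differ:

```python
from dataclasses import dataclass
from typing import Dict, List

@dataclass(frozen=True)
class RefShard:
    """Reference-based: a cheap pointer into the parent's key space."""
    start: int
    end: int

class TinyDataset:
    def __init__(self, rows: List[Dict]):
        self.rows = rows

    def read(self, shard: RefShard) -> List[Dict]:
        # The reader interprets the shard; the shard itself has no read logic.
        return [r for r in self.rows if shard.start <= r["id"] <= shard.end]

class DatasetShard(TinyDataset):
    """Dataset-based: the shard is itself a dataset, inheriting its
    operations at the cost of materializing state up front."""
    def __init__(self, parent: TinyDataset, start: int, end: int):
        super().__init__([r for r in parent.rows if start <= r["id"] <= end])
```

With RefShard, updates to the parent are visible for free; with DatasetShard, the copy must be regenerated whenever the parent changes.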

Changes

This pull request introduces a new example for processing and analyzing data with DeltaCat, and includes several updates to the sharding and scanning functionalities in the DeltaCat storage module. The most important changes include the addition of a new Jupyter notebook example, the implementation of an abstract Shard class and sharding strategies, and modifications to the dataset and block scanner classes to support sharding and field selection.

New Example:

  • deltacat/examples/rivulet/shard_demo.ipynb: Added a new Jupyter notebook demonstrating how to generate, load, shard, process, and aggregate data using DeltaCat, PyTorch, and Ray.

Sharding Implementation:

  • deltacat/storage/model/shard.py: Introduced an abstract Shard class and a ShardingStrategyType enum to define different sharding strategies. Added a ShardingStrategy abstract base class for creating shards.
  • deltacat/storage/rivulet/dataset.py: Added a shards method to the Dataset class to create shards using a specified sharding strategy. Modified the scan method to accept an optional shard parameter.

Scanning Enhancements:

  • deltacat/storage/rivulet/reader/block_scanner.py: Updated the BlockScanner class to accept an optional fields parameter in various methods to support field selection during scans.
  • deltacat/storage/rivulet/reader/data_reader.py: Modified the DataReader protocol to include an optional fields parameter in the deserialize_records and join_deserialize_records methods.

Dataset Reader Updates:

  • deltacat/storage/rivulet/reader/data_scan.py: Added shard and fields parameters to the DataScan class and updated the to_arrow, to_pydict, and to_tensor methods to use these parameters.
  • deltacat/storage/rivulet/reader/dataset_reader.py: Updated the DatasetReader class to handle sharding and field selection in the scan method.

Testing

make test

Checklist

  • [x] Unit tests covering the changes have been added

    • If this is a bugfix, regression tests have been added
  • E2E testing has been performed

@anshumankomawar anshumankomawar marked this pull request as ready for review January 17, 2025 02:44
Comment on lines +76 to +84
"cell_type": "code",
"source": [
"dataset = dc.Dataset.from_parquet(\n",
" name=\"data\",\n",
" file_uri=parquet_file_path,\n",
" metadata_uri=\".\",\n",
" merge_keys=\"id\"\n",
")\n",
"print(\"Loaded dataset from Parquet file.\")\n",
Collaborator

Any reason to put the metadata outside of the dir the parquet file is in? (which is the default?)

Collaborator Author

I believe there's a slight issue with how we derive the metadata_uri

    metadata_uri = metadata_uri or os.path.join(file_uri, "riv-meta")

The metadata_uri is correctly constructed only if file_uri is a directory and not a singular file. I'll fix this when I implement the FileSpec for rivulet; I'll also make the from_x methods pluggable.
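A sketch of the fix described above might look like the following (hypothetical helper name; assumes local filesystem paths, so object-store URIs would need more care):

```python
import os
from typing import Optional

def derive_metadata_uri(file_uri: str, metadata_uri: Optional[str] = None) -> str:
    """Anchor the metadata next to the file when file_uri points at a
    single file rather than a directory."""
    if metadata_uri:
        return metadata_uri
    base = file_uri if os.path.isdir(file_uri) else os.path.dirname(file_uri)
    return os.path.join(base, "riv-meta")
```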

else:
raise ValueError(f"Unsupported sharding strategy type: {self}")

class ShardingStrategy(ABC):
Collaborator

Consider making this a Protocol instead of an ABC. In general, prefer protocols to ABCs; in this specific case you don't have any shared state/code that benefits from an ABC. You will have to refactor shards to accept num_shards as a parameter.
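The Protocol version of this suggestion might look like the following sketch (hypothetical class bodies; only the shape of the interface is the point):

```python
from dataclasses import dataclass
from typing import Iterable, Protocol, runtime_checkable

@dataclass(frozen=True)
class Shard:
    start: int
    end: int

@runtime_checkable
class ShardingStrategy(Protocol):
    """Structural typing: implementers need no inheritance, and num_shards
    moves from constructor state into the call, as suggested."""
    def shards(self, num_shards: int) -> Iterable[Shard]: ...

class RangeShardingStrategy:
    # Satisfies the protocol without subclassing it.
    def shards(self, num_shards: int) -> Iterable[Shard]:
        return [Shard(i, i) for i in range(num_shards)]
```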

@@ -105,6 +107,7 @@ def deserialize_records(

:param records: Input data (generated by generate_records method)
:param output_type: Type to deserialize into
:param fields: Only include these fields in scan result (currently only for to_tensor)
Collaborator

Fields is exposed in scan, right? So you can use this even if you aren't using the pytorch connector?

Collaborator Author

Yes, it's exposed in scan as well. It's not limited to tensor, but only the pytorch connector is making use of it as of now. I can definitely update the other connectors to do something similar as well. The main reason for adding this is because I wanted to be able to limit the pytorch connector fields to numerical values.

return self.dataset_reader.scan(self.dataset_schema, Dict, self.query)
return self.dataset_reader.scan(self.dataset_schema, Dict, self.query, self.shard, self.fields)

def to_tensor(self) -> Generator[
Collaborator

So, I actually had implemented pytorch support in a somewhat different way. It looks like this was never ported to deltacat - see CR-165567134

Let's discuss whether we want a pytorch wrapper class like in the CR or to treat tensors as another memory format

Collaborator Author

That looks great! I'll separate this PR so that we can take a look at the pytorch integration independently.

# TODO: is there a better way to filter the scan and apply different shards?
# For now keeping it shard-specific; ideally we will need to restructure/rewrite dataset_reader.
if isinstance(shard, RangeShard):
query = query.intersect_with(shard.query)
Collaborator

This is a little funky. It would be nice to keep separation between the user query and the subset of data specified by the shard

Collaborator Author

Hmm, would it be better to adapt __load_sst_rows to take in multiple queries and chain them? I'm not sure what that would look like; it might require refactoring the query expression itself.
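One way to keep the user query and the shard bound separate while still producing a single effective range is to fold ranges together pairwise. A minimal sketch (KeyRange is a hypothetical stand-in for the range held by QueryExpression; bounds assumed inclusive):

```python
from dataclasses import dataclass
from typing import Optional

@dataclass(frozen=True)
class KeyRange:
    lo: int
    hi: int

    def intersect_with(self, other: "KeyRange") -> Optional["KeyRange"]:
        # The overlap of two inclusive ranges; None if they are disjoint.
        lo, hi = max(self.lo, other.lo), min(self.hi, other.hi)
        return KeyRange(lo, hi) if lo <= hi else None
```

Chaining multiple queries then reduces to repeated intersect_with calls, so the reader never has to treat the shard bound as part of the user's query.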

"""
def __init__(self, metastore: DatasetMetastore, start: T, end: T):
super().__init__(metastore)
self.query = QueryExpression[T]().with_range(start, end)
Collaborator

If query is the way that Shard represents the data it is filtering, then this should be part of the Shard ABC/protocol.

I don't think we should represent this as a Query

Collaborator Author

I can see that; it might be better not to mix queries and shards. Does having a defined start and end key work?

I think adding it to a shard protocol should work. Realistically, if we want a uniform distribution across manifests, the merge key will need to be involved, and most likely it will be in the form of a range. I'm okay with doing this and adapting if we introduce new ShardingStrategies.

return []

if global_min == global_max:
return [RangeShard(metastore, global_min, global_max)]
Collaborator

Discussion topic to have with broader rivulet channel: for correctness, does the data returned by each shard need to be strictly mutually exclusive? In this case, you would want to return a single shard from global min to global max, and N-1 other shards that are empty
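The "one covering shard plus N-1 empty shards" option the reviewer raises could be sketched like this (hypothetical helper; None denotes an empty shard, and only the degenerate case is handled here):

```python
from typing import List, Optional, Tuple

Range = Optional[Tuple[int, int]]  # None denotes an empty shard

def degenerate_shards(global_min: int, global_max: int, num_shards: int) -> List[Range]:
    """When the key space collapses to a point, still return exactly
    num_shards shards so callers can rely on the count."""
    if global_min == global_max:
        return [(global_min, global_max)] + [None] * (num_shards - 1)
    raise NotImplementedError("non-degenerate split handled elsewhere")
```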

Collaborator

@flliver flliver left a comment

Good start, needs work! It's too complicated for what it's trying to accomplish; make better use of your base classes and remove all unnecessary abstraction.

Comment on lines +106 to +113
"source": [
"@ray.remote\n",
"def process_shard(shard: Generator[torch.Tensor, None, None], fields: List[str]) -> Tuple[float, int]:\n",
" tensor_generator = dataset.scan(shard=shard, fields=fields).to_tensor()\n",
" total_age = 0.0\n",
" count = 0\n",
"\n",
" # calculate total\n",
Collaborator

Type hints in args are weird here: you're saying a shard is a Generator[torch.Tensor, None, None], but dataset.scan takes a Shard. I think you're letting Python get away with some weird type conversions, or at the very least it's not clear what you're passing around.

Comment on lines +509 to 518
def shards(self, num_shards: int, strategy: ShardingStrategyType = ShardingStrategyType.RANGE) -> Iterable[Shard]:
"""Create a set of shards for this dataset.

:param num_shards: The number of shards to create.
:param strategy: Sharding strategy used to create shards.
:return Iterable[Shard]: A set of shards for this dataset.
"""
return strategy.to_class(num_shards).shards(self._metastore)

def scan(
Collaborator

I don't think this interface will hold up:

  1. num_shards only applies to a RANGE strategy; there may be other strategies that have a variable # of shards, like a SIZE strategy
  2. ShardingStrategyType is non-pythonic, because it's a self-describing type with "Type" in its name. ShardingStrategy.RANGE is more pythonic (simple/clear/english)
  3. it returns an iterable of Shard, so now everything has to know what a Shard is, while the shard class itself has no methods, accessors, identifiers, etc. (although I think you're saying repr is going to be the identifier, which is weird because repr is just supposed to be the printable representation?)

Collaborator

that said, eh, whatever, lets try it out for a bit and iterate on it.

Comment on lines -512 to +523
return DataScan(self.schemas[schema_name], query, dataset_reader)
return DataScan(self.schemas[schema_name], query, dataset_reader, shard=shard, fields=fields)
Collaborator

I think you just broke the contract for DataScan by putting in a schema and fields?

Comment on lines +120 to +128
@staticmethod
def __join_records_as_tensor(records: List[RecordBatchRowIndex], fields: Optional[List[str]]) -> Tensor:
"""
Deserialize records into a PyTorch Tensor.

:param records: input record data
:param fields: Only include these fields in the resulting tensor
:returns: A PyTorch Tensor representing the joined records.
"""
tensors = []
Collaborator

What's up w/ these weird types, RecordBatchRowIndex? Why not RecordBatch?

Comment on lines +11 to +34
def test_range_shard_initialization(self):
mock_metastore = MagicMock(spec=DatasetMetastore)
shard = RangeShard(mock_metastore, 0, 10)
self.assertEqual(shard.query.min_key, 0)
self.assertEqual(shard.query.max_key, 10)

def test_range_shard_repr(self):
mock_metastore = MagicMock(spec=DatasetMetastore)
shard = RangeShard(mock_metastore, 5, 15)
self.assertEqual(repr(shard), "Shard(type=range, start=5, end=15)")

class RangeShardingStrategyTests(unittest.TestCase):

def test_shards_generation(self):
mock_metastore = MagicMock(spec=DatasetMetastore)
mock_metastore.generate_manifests.return_value = [
MagicMock(generate_sstables=lambda: [
MagicMock(min_key=0, max_key=50),
MagicMock(min_key=60, max_key=100)
])
]
strategy = RangeShardingStrategy(3)
shards = list(strategy.shards(mock_metastore))

Collaborator

I don't like the code smell of having to use a mock for the DatasetMetastore; it smells like some inversion of responsibility gone wrong. It's just weird to have the range shard need to know anything about the metastore at all. Gut feel: the scan operation should utilize the defined range against whatever metastore the dataset has. Why does the shard have a metastore when the dataset already has a metastore? Can't the scan code that receives the range get the metastore from self instead of doing this weird passthrough from the dataset to the metastore to the shard and back to the scan operation? How is the shard ever going to be serializable when it has the metastore inside of it?

Comment on lines +9 to +21

class RangeShard(Shard, Generic[T]):
""" Represents a range-based shard with start and end keys.

param: query: A QueryExpression object defining the range of the shard.
"""
def __init__(self, metastore: DatasetMetastore, start: T, end: T):
super().__init__(metastore)
self.query = QueryExpression[T]().with_range(start, end)

def __repr__(self) -> str:
return f"Shard(type=range, start={self.query.min_key}, end={self.query.max_key})"

Collaborator

I don't think I like all this fancy bullshit abstraction. Why isn't a RangeShard just a data class with a 'start, end' and key_name, and implements a 'get_records' that returns a Generator[RecordBatch].

Comment on lines +62 to +79
def __get_max_interval(self, manifests: Set[ManifestAccessor]) -> Tuple[Optional[T], Optional[T]]:
""" Computes the global minimum and maximum keys from the dataset manifests.

param: manifests: A set of ManifestAccessor objects representing the dataset.
returns: A tuple containing the global minimum and maximum keys, or None if no data is found.
"""
global_min, global_max = None, None

for manifest in manifests:
for table in manifest.generate_sstables():
if global_min is None or table.min_key < global_min:
global_min = table.min_key
if global_max is None or table.max_key > global_max:
global_max = table.max_key

print(f"Global min: {global_min}, Global max: {global_max}")
return global_min, global_max

Collaborator

This belongs in the manifests class.

Comment on lines +44 to +53
# determine range generation type
range_generator = RangeGeneratorType.from_key(key=global_min)

shards = []
local_min = global_min
for idx in range(self.num_shards):
# Compute the max key for the current shard
is_last_shard = idx == self.num_shards - 1
local_max = global_max if is_last_shard else range_generator.interpolate(global_min, global_max, idx + 1, self.num_shards)

Collaborator

What's with these extra layers of abstraction to have RangeGeneratorType? In Python you use your builtins: just calculate the range, and make sure whatever class you're calculating the range from has the appropriate function implemented.

Like this whole function should be three lines:

min = metastore.get_min_value("field")
max = metastore.get_max_value("field")
shards: List[Shard] = RangeShard.split(min, max)
return shards

metastore should have min/max value functions, range shard should have its own functions for common operations, and stop passing state all over god's green earth.

I think your issue is you're trying to do all this without modifying or adding functionality to the base classes you are working with. Instead, delete everything you have, write the absolute bare minimum human-understandable pseudo code you need to calculate what you are trying to do (which is to split an interval into a set of smaller intervals), then add the functionality to do that to the most obvious class where it belongs, without adding any additional classes. Every time you add a function, ask yourself, 'is this the responsibility of this class?' and if it's not, move it to where it belongs. Only when you find that there's functionality that clearly does not belong in any class do you add a new class.

The goal should be to add the fewest number of new abstractions, fewest lines, and most easily understood functions you can to get the job done.
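A minimal version of the interval split being described might look like the following sketch (hypothetical: integer keys are assumed, RangeShard.split is not an existing deltacat method, and num_shards is assumed to be no larger than the key span):

```python
from dataclasses import dataclass
from typing import List

@dataclass(frozen=True)
class RangeShard:
    start: int
    end: int

    @staticmethod
    def split(global_min: int, global_max: int, num_shards: int) -> List["RangeShard"]:
        """Split [global_min, global_max] into contiguous, mutually
        exclusive inclusive ranges."""
        if global_min == global_max:
            return [RangeShard(global_min, global_max)]
        shards, lo = [], global_min
        for i in range(1, num_shards + 1):
            # Linear interpolation of the i-th boundary across the key span.
            hi = global_max if i == num_shards else (
                global_min + (global_max - global_min) * i // num_shards)
            shards.append(RangeShard(lo, hi))
            lo = hi + 1
        return shards
```

The shard is then just a serializable value object; the code that owns a metastore can look up min/max keys itself and pass them in.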

Collaborator Author

I think your issue is you're trying to do all this without modifying or adding functionality to the base classes you are working with.

I think this is why I ended up having to create these abstractions. I tried to have the shard do all of the work, but maybe the goal should have been to implement some of these features in the existing classes. I'll refactor it to make it simpler.

Collaborator Author

I wanted to treat RangeShard as a single unit, so I didn't want to include the generation logic there. I thought it would be better to have a range generator that returns a list of ranges, and that it could be easier to expand in the future if we wanted to support additional types. But yes, having the splitting logic in the Shard itself would reduce a lot of the additional code.
