
Introduce TaggedAbstractTensor #2933

Merged: 12 commits merged Sep 18, 2024

jacobhinkle (Collaborator) commented Sep 11, 2024

This generalizes `AbstractTensor` to the templated struct `AbstractTensorWithInfo<Info>` and introduces a special-case subclass, `TaggedAbstractTensor<Tag>`. It is used by passing an enum class for `Tag`, and it holds an `unordered_set<Tag>` for each dimension. Merging and swizzling union these sets, and splitting duplicates the set.
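A minimal sketch of the per-dimension tag-set semantics just described. It assumes a hypothetical `Role` enum and hypothetical static `merge`, `split`, and `swizzle` helpers on a simplified `TagSetInfo<Tag>`; it models only the set arithmetic, not the real class layout. The description does not say exactly how a swizzle's two outputs receive the union, so giving both outputs the full union is an assumption.

```cpp
#include <unordered_set>
#include <utility>

// Hypothetical tag enum, standing in for something like MatmulDimRole.
enum class Role { M, N, K };

// Hypothetical sketch of per-dimension tag-set info. The static member
// names merge/split/swizzle are assumptions for illustration only.
template <typename Tag>
struct TagSetInfo {
  std::unordered_set<Tag> tags;

  // Merging two dimensions: the result carries the union of both sets.
  static TagSetInfo merge(const TagSetInfo& a, const TagSetInfo& b) {
    TagSetInfo out = a;
    out.tags.insert(b.tags.begin(), b.tags.end());
    return out;
  }

  // Splitting one dimension: each output gets its own copy of the set.
  static std::pair<TagSetInfo, TagSetInfo> split(const TagSetInfo& a) {
    return {a, a};
  }

  // Swizzling two dimensions: both outputs get the union of the two sets
  // (assumed; the description only says the sets are unioned).
  static std::pair<TagSetInfo, TagSetInfo> swizzle(
      const TagSetInfo& a,
      const TagSetInfo& b) {
    TagSetInfo u = merge(a, b);
    return {u, u};
  }
};
```

Using an unordered set per dimension means a merged dimension can report every role it came from, while a split leaves both halves with the same roles as their parent.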

Note that a lot of code had to be moved out of the .cpp file into the header because of templatization. However, there are no changes to the `Dispatch*` classes. The `AbstractTensorWithInfo` methods such as `split`, `merge`, and `swizzle` are only changed to add calls to `Info::merge`.
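The sketch below illustrates the "add calls to `Info::merge`" pattern under hypothetical simplifications: dimensions are plain strings rather than `AbstractId`, there is no `Dispatch*` machinery, and `EmptyInfo` is an invented info type that carries no data. It shows only that the info vector is updated in lockstep with `domain`, while the `Info` type decides how two entries combine.

```cpp
#include <cstddef>
#include <string>
#include <vector>

// Hypothetical, simplified sketch of the pattern described above. Only the
// lockstep update of `domain` and `info` is modeled.
template <typename Info>
struct AbstractTensorWithInfo {
  std::vector<std::string> domain;
  std::vector<Info> info; // one Info per entry of `domain`

  // Merge dimension `axis` with dimension `axis + 1`.
  void merge(std::size_t axis) {
    const auto next = static_cast<std::ptrdiff_t>(axis) + 1;
    domain[axis] = domain[axis] + "*" + domain[axis + 1];
    // The only Info-specific step: combine the two per-dimension infos.
    info[axis] = Info::merge(info[axis], info[axis + 1]);
    domain.erase(domain.begin() + next);
    info.erase(info.begin() + next);
  }
};

// A hypothetical Info type carrying no data, which would behave like an
// untagged AbstractTensor.
struct EmptyInfo {
  static EmptyInfo merge(const EmptyInfo&, const EmptyInfo&) {
    return {};
  }
};
```

With `EmptyInfo`, merging axis 0 of `{"i0", "i1", "i2"}` leaves the domain `{"i0*i1", "i2"}` and two info entries.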

Related to #2913, which specializes this as `using AbstractMatmulTensor = TaggedAbstractTensor<MatmulDimRole>`.

jacobhinkle marked this pull request as ready for review September 11, 2024 18:37
jacobhinkle (Collaborator Author) commented:

!build

csrc/abstract_tensor.h: 9 resolved review threads (6 outdated)
```cpp
TaggedAbstractTensor(
    std::initializer_list<AbstractId> domain,
    std::initializer_list<std::initializer_list<Tag>> tag_sets)
    : AbstractTensorWithInfo<TagSetInfo<Tag>>(
```

Collaborator:

nit:

Suggested change
```diff
-    : AbstractTensorWithInfo<TagSetInfo<Tag>>(
+    : AbstractTensorWithInfo(
```

Collaborator Author:

I thought I could remove these too, but clang complains if I even remove `Tag` here. 🤷‍♂️
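The clang error is consistent with standard C++ lookup rules: the base class depends on `Tag`, so unqualified lookup inside `TaggedAbstractTensor` does not search it, and the base's injected-class-name is not available. A minimal sketch with simplified, hypothetical stand-ins for the real classes; the `Base` alias is a common workaround, not something this PR uses.

```cpp
#include <cstddef>
#include <utility>
#include <vector>

// Simplified, hypothetical stand-ins for the classes in
// csrc/abstract_tensor.h; only the base-class naming issue is modeled.
template <typename Info>
struct AbstractTensorWithInfo {
  std::vector<Info> info;
  explicit AbstractTensorWithInfo(std::vector<Info> i) : info(std::move(i)) {}
  std::size_t size() const {
    return info.size();
  }
};

template <typename Tag>
struct TagSetInfo {
  Tag tag; // one tag instead of a set, to keep the sketch short
};

template <typename Tag>
struct TaggedAbstractTensor : AbstractTensorWithInfo<TagSetInfo<Tag>> {
  // The base depends on Tag, so it is not searched by unqualified lookup.
  // A bare `AbstractTensorWithInfo(...)` would name the class template
  // without arguments, and `AbstractTensorWithInfo<TagSetInfo>(...)` would
  // pass a template where a type is expected; both are rejected.
  // A hypothetical `Base` alias spells the full type once and reuses it.
  using Base = AbstractTensorWithInfo<TagSetInfo<Tag>>;

  explicit TaggedAbstractTensor(std::vector<TagSetInfo<Tag>> infos)
      : Base(std::move(infos)) {}
};
```

The alias keeps the mem-initializer short without relying on lookup into a dependent base.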

csrc/abstract_tensor.h: 2 resolved review threads (outdated)
```cpp
    std::initializer_list<std::initializer_list<Tag>> tag_sets)
    : AbstractTensorWithInfo<TagSetInfo<Tag>>(
          domain,
          {tag_sets.begin(), tag_sets.end()}) {}
```

Collaborator:

nit:

Suggested change
```diff
-          {tag_sets.begin(), tag_sets.end()}) {}
+          tag_sets) {}
```

Collaborator Author:

This didn't work for me: although `TagSetInfo` has a constructor taking an `unordered_set<Tag>`, `vector<TagSetInfo>` has no constructor taking a `vector<unordered_set<Tag>>`.
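A minimal sketch of why the iterator-range form works where a whole-container conversion does not. It assumes a hypothetical `Tag` enum and a simplified `TagSetInfo` with an implicit constructor from `unordered_set<Tag>`; `makeInfos` is an invented helper. `std::vector`'s range constructor builds each element from `*it`, so every `initializer_list<Tag>` is converted to an `unordered_set<Tag>` and then to a `TagSetInfo` one element at a time.

```cpp
#include <initializer_list>
#include <unordered_set>
#include <utility>
#include <vector>

// Hypothetical tag enum for illustration.
enum class Tag { M, N, K };

// Simplified stand-in with an implicit converting constructor.
struct TagSetInfo {
  std::unordered_set<Tag> tags;
  TagSetInfo(std::unordered_set<Tag> t) : tags(std::move(t)) {}
};

// The range constructor converts element by element:
// initializer_list<Tag> -> unordered_set<Tag> -> TagSetInfo.
std::vector<TagSetInfo> makeInfos(
    std::initializer_list<std::initializer_list<Tag>> tag_sets) {
  return {tag_sets.begin(), tag_sets.end()};
}

// By contrast, no constructor converts a whole container:
//   std::vector<std::unordered_set<Tag>> sets;
//   std::vector<TagSetInfo> infos = sets;  // does not compile
```

Per-element conversion is why the original `{tag_sets.begin(), tag_sets.end()}` spelling compiles while passing a container of a different element type directly does not.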

tests/cpp/test_abstract_tensor.cpp: 1 resolved review thread (outdated)
jacobhinkle (Collaborator Author) commented:

!build

Fixes a test failure in `AbstractTensorTest.Reorder` caused by `operator==`
returning false even though the given `vector<AbstractId>` exactly matched
`domain`.

Previously these were automatically converted to `AbstractTensor`, but I
guess this no longer works given the new constructors.
csrc/abstract_tensor.h: 2 resolved review threads (outdated)
tests/cpp/test_abstract_tensor.cpp: 1 resolved review thread (outdated)
jacobhinkle merged commit 94cff06 into main Sep 18, 2024
5 checks passed
jacobhinkle deleted the tagged_abstract_tensor branch September 18, 2024 13:55
jacobhinkle added a commit that referenced this pull request Sep 19, 2024
`AbstractTensor` now has many methods, and as of #2933 it is subclassed
into `TaggedAbstractTensor`, which has additional methods. Also, there is
no longer just a single vector attribute `domain` but also `info`.
Instead of requiring the user to manage those vectors and keep them the
same length, this PR makes them `protected` members of a class and uses
the existing accessors to adjust them, as suggested in
#2963 (comment).
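A minimal sketch of the encapsulation this commit describes, assuming plain strings for dimensions. The accessor names `pushBack`, `size`, `id`, and `infoAt` are hypothetical and are not the existing accessors the commit refers to; the point is that with `domain_` and `info_` protected, every mutation outside the class hierarchy goes through a member function that keeps the two vectors the same length.

```cpp
#include <cstddef>
#include <string>
#include <utility>
#include <vector>

// Hypothetical sketch: domain_ and info_ are protected, so callers can only
// change them through member functions that update both vectors together.
template <typename Info>
class AbstractTensorWithInfo {
 public:
  void pushBack(std::string new_id, Info new_info) {
    domain_.push_back(std::move(new_id));
    info_.push_back(std::move(new_info));
  }

  std::size_t size() const {
    return domain_.size();
  }

  const std::string& id(std::size_t i) const {
    return domain_.at(i);
  }

  const Info& infoAt(std::size_t i) const {
    return info_.at(i);
  }

 protected:
  // Invariant kept by every mutator: domain_.size() == info_.size().
  std::vector<std::string> domain_;
  std::vector<Info> info_;
};
```

Subclasses such as a tagged variant can still reach the protected vectors, but outside callers cannot break the length invariant.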
2 participants