Skip to content

Commit c23ec23

Browse files
authored
Merge pull request #16171 from gita-omr/mask_reduction_coerced
Handle mask reduction intrinsic in VectorAPIExpansion
2 parents 78db7ef + f806581 commit c23ec23

File tree

4 files changed

+107
-20
lines changed

4 files changed

+107
-20
lines changed

runtime/compiler/codegen/J9RecognizedMethodsEnum.hpp

+1
Original file line numberDiff line numberDiff line change
@@ -454,6 +454,7 @@
454454
jdk_internal_vm_vector_VectorSupport_blend,
455455
jdk_internal_vm_vector_VectorSupport_compare,
456456
jdk_internal_vm_vector_VectorSupport_fromBitsCoerced,
457+
jdk_internal_vm_vector_VectorSupport_maskReductionCoerced,
457458
jdk_internal_vm_vector_VectorSupport_reductionCoerced,
458459
jdk_internal_vm_vector_VectorSupport_ternaryOp,
459460
jdk_internal_vm_vector_VectorSupport_test,

runtime/compiler/env/j9method.cpp

+1
Original file line numberDiff line numberDiff line change
@@ -3003,6 +3003,7 @@ void TR_ResolvedJ9Method::construct()
30033003
{x(TR::jdk_internal_vm_vector_VectorSupport_blend, "blend", "(Ljava/lang/Class;Ljava/lang/Class;Ljava/lang/Class;ILjdk/internal/vm/vector/VectorSupport$Vector;Ljdk/internal/vm/vector/VectorSupport$Vector;Ljdk/internal/vm/vector/VectorSupport$VectorMask;Ljdk/internal/vm/vector/VectorSupport$VectorBlendOp;)Ljdk/internal/vm/vector/VectorSupport$Vector;")},
30043004
{x(TR::jdk_internal_vm_vector_VectorSupport_compare, "compare", "(ILjava/lang/Class;Ljava/lang/Class;Ljava/lang/Class;ILjdk/internal/vm/vector/VectorSupport$Vector;Ljdk/internal/vm/vector/VectorSupport$Vector;Ljdk/internal/vm/vector/VectorSupport$VectorMask;Ljdk/internal/vm/vector/VectorSupport$VectorCompareOp;)Ljdk/internal/vm/vector/VectorSupport$VectorMask;")},
30053005
{x(TR::jdk_internal_vm_vector_VectorSupport_fromBitsCoerced, "fromBitsCoerced", "(Ljava/lang/Class;Ljava/lang/Class;IJILjdk/internal/vm/vector/VectorSupport$VectorSpecies;Ljdk/internal/vm/vector/VectorSupport$FromBitsCoercedOperation;)Ljdk/internal/vm/vector/VectorSupport$VectorPayload;")},
3006+
{x(TR::jdk_internal_vm_vector_VectorSupport_maskReductionCoerced, "maskReductionCoerced", "(ILjava/lang/Class;Ljava/lang/Class;ILjdk/internal/vm/vector/VectorSupport$VectorMask;Ljdk/internal/vm/vector/VectorSupport$VectorMaskOp;)J")},
30063007
{x(TR::jdk_internal_vm_vector_VectorSupport_reductionCoerced, "reductionCoerced", "(ILjava/lang/Class;Ljava/lang/Class;Ljava/lang/Class;ILjdk/internal/vm/vector/VectorSupport$Vector;Ljdk/internal/vm/vector/VectorSupport$VectorMask;Ljdk/internal/vm/vector/VectorSupport$ReductionOperation;)J")},
30073008
{x(TR::jdk_internal_vm_vector_VectorSupport_ternaryOp, "ternaryOp", "(ILjava/lang/Class;Ljava/lang/Class;Ljava/lang/Class;ILjdk/internal/vm/vector/VectorSupport$Vector;Ljdk/internal/vm/vector/VectorSupport$Vector;Ljdk/internal/vm/vector/VectorSupport$Vector;Ljdk/internal/vm/vector/VectorSupport$VectorMask;Ljdk/internal/vm/vector/VectorSupport$TernaryOperation;)Ljdk/internal/vm/vector/VectorSupport$Vector;")},
30083009
{x(TR::jdk_internal_vm_vector_VectorSupport_test, "test", "(ILjava/lang/Class;Ljava/lang/Class;ILjdk/internal/vm/vector/VectorSupport$VectorMask;Ljdk/internal/vm/vector/VectorSupport$VectorMask;Ljava/util/function/BiFunction;)Z")},

runtime/compiler/optimizer/VectorAPIExpansion.cpp

+60-19
Original file line numberDiff line numberDiff line change
@@ -682,8 +682,8 @@ TR_VectorAPIExpansion::validateSymRef(int32_t id, int32_t i, vec_sz_t &classLeng
682682
methodType != classType)
683683
{
684684
if (_trace)
685-
traceMsg(comp(), "%s invalidating6 class #%d due to symref #%d method type %d, seen type %d\n",
686-
OPT_DETAILS_VECTOR, id, i, (int)methodType, (int)classType);
685+
traceMsg(comp(), "%s invalidating6 class #%d due to symref #%d method type %s, seen type %s\n",
686+
OPT_DETAILS_VECTOR, id, i, TR::DataType::getName(methodType), TR::DataType::getName(classType));
687687
return false;
688688
}
689689
}
@@ -1602,6 +1602,14 @@ TR::Node *TR_VectorAPIExpansion::binaryIntrinsicHandler(TR_VectorAPIExpansion *o
16021602
return naryIntrinsicHandler(opt, treeTop, node, elementType, vectorLength, numLanes, mode, 2, Other);
16031603
}
16041604

1605+
TR::Node *TR_VectorAPIExpansion::maskReductionCoercedIntrinsicHandler(TR_VectorAPIExpansion *opt, TR::TreeTop *treeTop, TR::Node *node,
1606+
TR::DataType elementType, TR::VectorLength vectorLength, int32_t numLanes,
1607+
handlerMode mode)
1608+
{
1609+
return naryIntrinsicHandler(opt, treeTop, node, elementType, vectorLength, numLanes, mode, 1, MaskReduction);
1610+
}
1611+
1612+
16051613
TR::Node *TR_VectorAPIExpansion::reductionCoercedIntrinsicHandler(TR_VectorAPIExpansion *opt, TR::TreeTop *treeTop, TR::Node *node,
16061614
TR::DataType elementType, TR::VectorLength vectorLength, int32_t numLanes,
16071615
handlerMode mode)
@@ -1633,11 +1641,16 @@ TR::Node *TR_VectorAPIExpansion::naryIntrinsicHandler(TR_VectorAPIExpansion *opt
16331641
TR::Node *opcodeNode = node->getFirstChild();
16341642
int firstOperand = 5;
16351643

1636-
if (opCodeType == Test)
1644+
if (opCodeType == Test || opCodeType == MaskReduction)
16371645
firstOperand = 4;
16381646

1639-
TR::Node *maskNode = node->getChild(firstOperand + numChildren); // each intrinsic has a mask argument
1640-
bool withMask = !maskNode->isConstZeroValue();
1647+
bool withMask = false;
1648+
1649+
if (opCodeType != MaskReduction)
1650+
{
1651+
TR::Node *maskNode = node->getChild(firstOperand + numChildren); // each intrinsic has a mask argument
1652+
withMask = !maskNode->isConstZeroValue();
1653+
}
16411654

16421655
if (withMask) numChildren++;
16431656

@@ -1758,12 +1771,27 @@ TR::Node *TR_VectorAPIExpansion::fromBitsCoercedIntrinsicHandler(TR_VectorAPIExp
17581771
{
17591772
TR::Compilation *comp = opt->comp();
17601773

1774+
TR::Node *broadcastTypeNode = node->getChild(4);
1775+
1776+
if (!broadcastTypeNode->getOpCode().isLoadConst())
1777+
{
1778+
if (opt->_trace) traceMsg(comp, "Unknown broadcast type in node %p\n", node);
1779+
return NULL;
1780+
}
1781+
1782+
int32_t broadcastType = broadcastTypeNode->get32bitIntegralValue();
1783+
1784+
TR_ASSERT_FATAL(broadcastType == MODE_BROADCAST || broadcastType == MODE_BITS_COERCED_LONG_TO_MASK,
1785+
"Unexpected broadcast type in node %p\n", node);
1786+
1787+
bool mask = (broadcastType == MODE_BITS_COERCED_LONG_TO_MASK);
1788+
17611789
if (mode == checkScalarization)
1762-
return node;
1790+
return mask ? NULL : node;
17631791

17641792
if (mode == checkVectorization)
17651793
{
1766-
TR::ILOpCodes splatsOpCode = TR::ILOpCode::createVectorOpCode(TR::vsplats,
1794+
TR::ILOpCodes splatsOpCode = TR::ILOpCode::createVectorOpCode(mask ? TR::mLongBitsToMask : TR::vsplats,
17671795
TR::DataType::createVectorType(elementType, vectorLength));
17681796

17691797
if (!comp->cg()->getSupportsOpCodeForAutoSIMD(splatsOpCode))
@@ -1829,7 +1857,7 @@ TR::Node *TR_VectorAPIExpansion::fromBitsCoercedIntrinsicHandler(TR_VectorAPIExp
18291857
{
18301858
node->setAndIncChild(0, newNode);
18311859
node->setNumChildren(1);
1832-
TR::ILOpCodes splatsOpCode = TR::ILOpCode::createVectorOpCode(TR::vsplats,
1860+
TR::ILOpCodes splatsOpCode = TR::ILOpCode::createVectorOpCode(mask ? TR::mLongBitsToMask : TR::vsplats,
18331861
TR::DataType::createVectorType(elementType, vectorLength));
18341862

18351863
TR::Node::recreate(node, splatsOpCode);
@@ -1926,7 +1954,7 @@ TR::ILOpCodes TR_VectorAPIExpansion::ILOpcodeFromVectorAPIOpcode(int32_t vectorA
19261954
case VECTOR_OP_MIN: return scalar ? TR::BadILOp : TR::ILOpCode::createVectorOpCode(TR::vreductionMin, vectorType);
19271955
case VECTOR_OP_MAX: return scalar ? TR::BadILOp : TR::ILOpCode::createVectorOpCode(TR::vreductionMax, vectorType);
19281956
case VECTOR_OP_AND: return scalar ? TR::BadILOp : TR::ILOpCode::createVectorOpCode(TR::vreductionAnd, vectorType);
1929-
case VECTOR_OP_OR: return scalar ? TR::BadILOp : TR::ILOpCode::createVectorOpCode(TR::vreductionOr, vectorType);
1957+
case VECTOR_OP_OR: return scalar ? TR::BadILOp : TR::ILOpCode::createVectorOpCode(TR::vreductionOr, vectorType);
19301958
case VECTOR_OP_XOR: return scalar ? TR::BadILOp : TR::ILOpCode::createVectorOpCode(TR::vreductionXor, vectorType);
19311959
// These don't seem to be generated by the library:
19321960
// vreductionOrUnchecked
@@ -1935,6 +1963,18 @@ TR::ILOpCodes TR_VectorAPIExpansion::ILOpcodeFromVectorAPIOpcode(int32_t vectorA
19351963
return TR::BadILOp;
19361964
}
19371965
}
1966+
else if (opCodeType == MaskReduction)
1967+
{
1968+
switch (vectorAPIOpCode)
1969+
{
1970+
case VECTOR_OP_MASK_TRUECOUNT: return scalar ? TR::BadILOp : TR::ILOpCode::createVectorOpCode(TR::mTrueCount, vectorType);
1971+
case VECTOR_OP_MASK_FIRSTTRUE: return scalar ? TR::BadILOp : TR::ILOpCode::createVectorOpCode(TR::mFirstTrue, vectorType);
1972+
case VECTOR_OP_MASK_LASTTRUE: return scalar ? TR::BadILOp : TR::ILOpCode::createVectorOpCode(TR::mLastTrue, vectorType);
1973+
case VECTOR_OP_MASK_TOLONG: return scalar ? TR::BadILOp : TR::ILOpCode::createVectorOpCode(TR::mToLongBits, vectorType);
1974+
default:
1975+
return TR::BadILOp;
1976+
}
1977+
}
19381978
else if (withMask)
19391979
{
19401980
switch (vectorAPIOpCode)
@@ -2113,16 +2153,17 @@ TR::Node *TR_VectorAPIExpansion::transformNary(TR_VectorAPIExpansion *opt, TR::T
21132153
TR_VectorAPIExpansion::methodTableEntry
21142154
TR_VectorAPIExpansion::methodTable[] =
21152155
{
2116-
{loadIntrinsicHandler, Unknown, {Unknown, ElementType, NumLanes}}, // jdk_internal_vm_vector_VectorSupport_load
2117-
{storeIntrinsicHandler, Unknown, {Unknown, ElementType, NumLanes, Unknown, Unknown, Vector}}, // jdk_internal_vm_vector_VectorSupport_store
2118-
{binaryIntrinsicHandler, Vector, {Unknown, Unknown, Unknown, ElementType, NumLanes, Vector, Vector, Mask}}, // jdk_internal_vm_vector_VectorSupport_binaryOp
2119-
{blendIntrinsicHandler, Vector, {Unknown, Unknown, ElementType, NumLanes, Vector, Vector, Vector, Unknown}}, // jdk_internal_vm_vector_VectorSupport_blend
2120-
{compareIntrinsicHandler, Mask, {Unknown, Unknown, Unknown, ElementType, NumLanes, Vector, Vector, Mask}}, // jdk_internal_vm_vector_VectorSupport_compare
2121-
{fromBitsCoercedIntrinsicHandler, Vector, {Unknown, ElementType, NumLanes, Unknown, Unknown, Unknown}}, // jdk_internal_vm_vector_VectorSupport_fromBitsCoerced
2122-
{reductionCoercedIntrinsicHandler, Scalar, {Unknown, Unknown, Unknown, ElementType, NumLanes, Vector, Mask}}, // jdk_internal_vm_vector_VectorSupport_reductionCoerced
2123-
{ternaryIntrinsicHandler, Vector, {Unknown, Unknown, Unknown, ElementType, NumLanes, Vector, Vector, Vector, Mask}}, // jdk_internal_vm_vector_VectorSupport_ternaryOp
2124-
{testIntrinsicHandler, Scalar, {Unknown, Unknown, ElementType, NumLanes, Mask, Mask, Unknown}}, // jdk_internal_vm_vector_VectorSupport_test
2125-
{unaryIntrinsicHandler, Vector, {Unknown, Unknown, Unknown, ElementType, NumLanes, Vector, Mask}}, // jdk_internal_vm_vector_VectorSupport_unaryOp
2156+
{loadIntrinsicHandler, Unknown, {Unknown, ElementType, NumLanes}}, // jdk_internal_vm_vector_VectorSupport_load
2157+
{storeIntrinsicHandler, Unknown, {Unknown, ElementType, NumLanes, Unknown, Unknown, Vector}}, // jdk_internal_vm_vector_VectorSupport_store
2158+
{binaryIntrinsicHandler, Vector, {Unknown, Unknown, Unknown, ElementType, NumLanes, Vector, Vector, Mask}}, // jdk_internal_vm_vector_VectorSupport_binaryOp
2159+
{blendIntrinsicHandler, Vector, {Unknown, Unknown, ElementType, NumLanes, Vector, Vector, Vector, Unknown}}, // jdk_internal_vm_vector_VectorSupport_blend
2160+
{compareIntrinsicHandler, Mask, {Unknown, Unknown, Unknown, ElementType, NumLanes, Vector, Vector, Mask}}, // jdk_internal_vm_vector_VectorSupport_compare
2161+
{fromBitsCoercedIntrinsicHandler, Vector, {Unknown, ElementType, NumLanes, Unknown, Unknown, Unknown}}, // jdk_internal_vm_vector_VectorSupport_fromBitsCoerced
2162+
{maskReductionCoercedIntrinsicHandler, Scalar, {Unknown, Unknown, ElementType, NumLanes, Mask}}, // jdk_internal_vm_vector_VectorSupport_maskReductionCoerced
2163+
{reductionCoercedIntrinsicHandler, Scalar, {Unknown, Unknown, Unknown, ElementType, NumLanes, Vector, Mask}}, // jdk_internal_vm_vector_VectorSupport_reductionCoerced
2164+
{ternaryIntrinsicHandler, Vector, {Unknown, Unknown, Unknown, ElementType, NumLanes, Vector, Vector, Vector, Mask}}, // jdk_internal_vm_vector_VectorSupport_ternaryOp
2165+
{testIntrinsicHandler, Scalar, {Unknown, Unknown, ElementType, NumLanes, Mask, Mask, Unknown}}, // jdk_internal_vm_vector_VectorSupport_test
2166+
{unaryIntrinsicHandler, Vector, {Unknown, Unknown, Unknown, ElementType, NumLanes, Vector, Mask}}, // jdk_internal_vm_vector_VectorSupport_unaryOp
21262167
};
21272168

21282169

runtime/compiler/optimizer/VectorAPIExpansion.hpp

+45-1
Original file line numberDiff line numberDiff line change
@@ -120,7 +120,14 @@ class TR_VectorAPIExpansion : public TR::Optimization
120120
static int32_t const VECTOR_OP_URSHIFT = 16;
121121

122122
static int32_t const VECTOR_OP_CAST = 17;
123-
static int32_t const VECTOR_OP_REINTERPRET = 18;
123+
static int32_t const VECTOR_OP_UCAST = 18;
124+
static int32_t const VECTOR_OP_REINTERPRET = 19;
125+
126+
// Mask manipulation operations
127+
static int32_t const VECTOR_OP_MASK_TRUECOUNT = 20;
128+
static int32_t const VECTOR_OP_MASK_FIRSTTRUE = 21;
129+
static int32_t const VECTOR_OP_MASK_LASTTRUE = 22;
130+
static int32_t const VECTOR_OP_MASK_TOLONG = 23;
124131

125132
// Compare
126133
static int32_t const BT_eq = 0;
@@ -137,6 +144,10 @@ class TR_VectorAPIExpansion : public TR::Optimization
137144
static int32_t const BT_ult = BT_lt | BT_unsigned_compare;
138145
static int32_t const BT_ugt = BT_gt | BT_unsigned_compare;
139146

147+
// Various broadcasting modes.
148+
static int32_t const MODE_BROADCAST = 0;
149+
static int32_t const MODE_BITS_COERCED_LONG_TO_MASK = 1;
150+
140151
/** \brief
141152
* Is passed to methods handlers during analysis and transforamtion phases
142153
*
@@ -175,6 +186,7 @@ class TR_VectorAPIExpansion : public TR::Optimization
175186
enum vapiOpCodeType
176187
{
177188
Compare,
189+
MaskReduction,
178190
Reduction,
179191
Test,
180192
Other
@@ -800,6 +812,38 @@ class TR_VectorAPIExpansion : public TR::Optimization
800812
*/
801813
static TR::Node *binaryIntrinsicHandler(TR_VectorAPIExpansion *opt, TR::TreeTop *treeTop, TR::Node *node, TR::DataType elementType, TR::VectorLength vectorLength, int32_t numLanes, handlerMode mode);
802814

815+
/** \brief
816+
* Scalarizes or vectorizes a node that is a call to \c VectorSupport.maskReductionCoerced() intrinsic.
817+
* In both cases, the node is modified in place.
818+
* In the case of scalarization, extra nodes are created(number of lanes minus one)
819+
*
820+
* \param opt
821+
* This optimization object
822+
*
823+
* \param treeTop
824+
* Tree top of the \c node
825+
*
826+
* \param node
827+
* Node to transform
828+
*
829+
* \param elementType
830+
* Element type
831+
*
832+
* \param vectorLength
833+
* Vector length
834+
*
835+
* \param numLanes
836+
* Number of elements
837+
*
838+
* \param mode
839+
* Handler mode
840+
*
841+
* \return
842+
* Transformed node
843+
*
844+
*/
845+
static TR::Node *maskReductionCoercedIntrinsicHandler(TR_VectorAPIExpansion *opt, TR::TreeTop *treeTop, TR::Node *node, TR::DataType elementType, TR::VectorLength vectorLength, int32_t numLanes, handlerMode mode);
846+
803847
/** \brief
804848
* Scalarizes or vectorizes a node that is a call to \c VectorSupport.reductionCoerced() intrinsic.
805849
* In both cases, the node is modified in place.

0 commit comments

Comments
 (0)