Author: Simon Pilgrim
Date: 2021-01-14T11:05:19Z
New Revision: af8d27a7a8266b89916b5e4db2b2fd97eb7d84e5
URL: https://github.com/llvm/llvm-project/commit/af8d27a7a8266b89916b5e4db2b2fd97eb7d84e5 DIFF: https://github.com/llvm/llvm-project/commit/af8d27a7a8266b89916b5e4db2b2fd97eb7d84e5.diff LOG: [DAG] visitVECTOR_SHUFFLE - pull out shuffle merging code into lambda helper. NFCI. Make it easier to reuse in a future patch. Added: Modified: llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp Removed: ################################################################################ diff --git a/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp b/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp index 24bc7fe7e0ad..f4c9b814b806 100644 --- a/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp +++ b/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp @@ -20823,30 +20823,19 @@ SDValue DAGCombiner::visitVECTOR_SHUFFLE(SDNode *N) { return DAG.getCommutedVectorShuffle(*SVN); } - // Try to fold according to rules: - // shuffle(shuffle(A, B, M0), C, M1) -> shuffle(A, B, M2) - // shuffle(shuffle(A, B, M0), C, M1) -> shuffle(A, C, M2) - // shuffle(shuffle(A, B, M0), C, M1) -> shuffle(B, C, M2) - // Don't try to fold shuffles with illegal type. - // Only fold if this shuffle is the only user of the other shuffle. - if (N0.getOpcode() == ISD::VECTOR_SHUFFLE && N->isOnlyUserOf(N0.getNode()) && - Level < AfterLegalizeDAG && TLI.isTypeLegal(VT)) { - ShuffleVectorSDNode *OtherSV = cast<ShuffleVectorSDNode>(N0); - + // Compute the combined shuffle mask for a shuffle with SV0 as the first + // operand, and SV1 as the second operand. + // i.e. Merge SVN(OtherSVN, N1) -> shuffle(SV0, SV1, Mask). + auto MergeInnerShuffle = [NumElts](ShuffleVectorSDNode *SVN, + ShuffleVectorSDNode *OtherSVN, SDValue N1, + SDValue &SV0, SDValue &SV1, + SmallVectorImpl<int> &Mask) -> bool { // Don't try to fold splats; they're likely to simplify somehow, or they // might be free. - if (OtherSV->isSplat()) - return SDValue(); - - // The incoming shuffle must be of the same type as the result of the - // current shuffle. 
- assert(OtherSV->getOperand(0).getValueType() == VT && - "Shuffle types don't match"); + if (OtherSVN->isSplat()) + return false; - SDValue SV0, SV1; - SmallVector<int, 4> Mask; - // Compute the combined shuffle mask for a shuffle with SV0 as the first - // operand, and SV1 as the second operand. + Mask.clear(); for (unsigned i = 0; i != NumElts; ++i) { int Idx = SVN->getMaskElt(i); if (Idx < 0) { @@ -20859,15 +20848,14 @@ SDValue DAGCombiner::visitVECTOR_SHUFFLE(SDNode *N) { if (Idx < (int)NumElts) { // This shuffle index refers to the inner shuffle N0. Lookup the inner // shuffle mask to identify which vector is actually referenced. - Idx = OtherSV->getMaskElt(Idx); + Idx = OtherSVN->getMaskElt(Idx); if (Idx < 0) { // Propagate Undef. Mask.push_back(Idx); continue; } - - CurrentVec = (Idx < (int) NumElts) ? OtherSV->getOperand(0) - : OtherSV->getOperand(1); + CurrentVec = (Idx < (int)NumElts) ? OtherSVN->getOperand(0) + : OtherSVN->getOperand(1); } else { // This shuffle index references an element within N1. CurrentVec = N1; @@ -20892,31 +20880,52 @@ SDValue DAGCombiner::visitVECTOR_SHUFFLE(SDNode *N) { // Bail out if we cannot convert the shuffle pair into a single shuffle. if (SV1.getNode() && SV1 != CurrentVec) - return SDValue(); + return false; // Ok. CurrentVec is the right hand side. // Update the mask accordingly. SV1 = CurrentVec; Mask.push_back(Idx + NumElts); } + return true; + }; - // Check if all indices in Mask are Undef. In case, propagate Undef. - if (llvm::all_of(Mask, [](int M) { return M < 0; })) - return DAG.getUNDEF(VT); + // Try to fold according to rules: + // shuffle(shuffle(A, B, M0), C, M1) -> shuffle(A, B, M2) + // shuffle(shuffle(A, B, M0), C, M1) -> shuffle(A, C, M2) + // shuffle(shuffle(A, B, M0), C, M1) -> shuffle(B, C, M2) + // Don't try to fold shuffles with illegal type. + // Only fold if this shuffle is the only user of the other shuffle. 
+ if (N0.getOpcode() == ISD::VECTOR_SHUFFLE && N->isOnlyUserOf(N0.getNode()) && + Level < AfterLegalizeDAG && TLI.isTypeLegal(VT)) { + ShuffleVectorSDNode *OtherSV = cast<ShuffleVectorSDNode>(N0); + + // The incoming shuffle must be of the same type as the result of the + // current shuffle. + assert(OtherSV->getOperand(0).getValueType() == VT && + "Shuffle types don't match"); - if (!SV0.getNode()) - SV0 = DAG.getUNDEF(VT); - if (!SV1.getNode()) - SV1 = DAG.getUNDEF(VT); - - // Avoid introducing shuffles with illegal mask. - // shuffle(shuffle(A, B, M0), C, M1) -> shuffle(A, B, M2) - // shuffle(shuffle(A, B, M0), C, M1) -> shuffle(A, C, M2) - // shuffle(shuffle(A, B, M0), C, M1) -> shuffle(B, C, M2) - // shuffle(shuffle(A, B, M0), C, M1) -> shuffle(B, A, M2) - // shuffle(shuffle(A, B, M0), C, M1) -> shuffle(C, A, M2) - // shuffle(shuffle(A, B, M0), C, M1) -> shuffle(C, B, M2) - return TLI.buildLegalVectorShuffle(VT, SDLoc(N), SV0, SV1, Mask, DAG); + SDValue SV0, SV1; + SmallVector<int, 4> Mask; + if (MergeInnerShuffle(SVN, OtherSV, N1, SV0, SV1, Mask)) { + // Check if all indices in Mask are Undef. In case, propagate Undef. + if (llvm::all_of(Mask, [](int M) { return M < 0; })) + return DAG.getUNDEF(VT); + + if (!SV0.getNode()) + SV0 = DAG.getUNDEF(VT); + if (!SV1.getNode()) + SV1 = DAG.getUNDEF(VT); + + // Avoid introducing shuffles with illegal mask. 
+ // shuffle(shuffle(A, B, M0), C, M1) -> shuffle(A, B, M2) + // shuffle(shuffle(A, B, M0), C, M1) -> shuffle(A, C, M2) + // shuffle(shuffle(A, B, M0), C, M1) -> shuffle(B, C, M2) + // shuffle(shuffle(A, B, M0), C, M1) -> shuffle(B, A, M2) + // shuffle(shuffle(A, B, M0), C, M1) -> shuffle(C, A, M2) + // shuffle(shuffle(A, B, M0), C, M1) -> shuffle(C, B, M2) + return TLI.buildLegalVectorShuffle(VT, SDLoc(N), SV0, SV1, Mask, DAG); + } } if (SDValue V = foldShuffleOfConcatUndefs(SVN, DAG)) _______________________________________________ llvm-branch-commits mailing list llvm-branch-commits@lists.llvm.org https://lists.llvm.org/cgi-bin/mailman/listinfo/llvm-branch-commits