Skip to content

Commit

Permalink
Merge pull request #18607 from Akira1Saitoh/aarch64StringIndexof
Browse files Browse the repository at this point in the history
AArch64: Add vectorized implementation of intrinsicIndexOf helper functions
  • Loading branch information
knn-k authored Dec 15, 2023
2 parents db584a0 + 567cb0d commit a81e6b3
Show file tree
Hide file tree
Showing 2 changed files with 294 additions and 0 deletions.
4 changes: 4 additions & 0 deletions runtime/compiler/aarch64/codegen/J9CodeGenerator.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -86,6 +86,10 @@ J9::ARM64::CodeGenerator::initialize()
{
cg->setSupportsInlineStringHashCode();
}
if ((!TR::Compiler->om.canGenerateArraylets()) && (!comp->getOption(TR_DisableFastStringIndexOf)))
{
cg->setSupportsInlineStringIndexOf();
}
if (comp->fej9()->hasFixedFrameC_CallingConvention())
cg->setHasFixedFrameC_CallingConvention();
}
Expand Down
290 changes: 290 additions & 0 deletions runtime/compiler/aarch64/codegen/J9TreeEvaluator.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -6214,6 +6214,280 @@ static TR::Register *inlineStringHashCode(TR::Node *node, bool isCompressed, TR:
return resultReg;
}

/**
* @brief Generates inlined instructions equivalent to com/ibm/jit/JITHelpers.intrinsicIndexOfLatin1 or com/ibm/jit/JITHelpers.intrinsicIndexOfUTF16
*
* @param node: node
* @param cg: Code Generator
* @param isLatin1: true when the string is Latin1, false when the string is UTF16
* @returns register
*/
static TR::Register* inlineIntrinsicIndexOf(TR::Node* node, TR::CodeGenerator* cg, bool isLatin1)
{
/*
* add dataAddrReg, addressReg, #headerSize ; get the starting address of byte array data
* add endReg, dataAddrReg, lengthReg(, lsl #1 if utf16) ; get the address of the end of the byte array
* subs lengthReg, lengthReg, offsetReg
* b.eq Ldone
* add tmp0Reg, dataAddrReg, #16 ; save starting address + 16 to tmp0Reg
* add dataAddrReg, dataAddrReg, offsetReg(, lsl #1 if utf16)
*
* if isLatin1
* dup vtmp0Reg.16b, charReg
* else
* dup vtmp0Reg.8h, charReg
* subs lengthReg, lengthReg (#8 if isLatin1 else #4)
* ; If the length < 8bytes, then it is not guaranteed that we can read the last 16bytes because the minimum size of the array header is 8bytes.
* b.lt Lessthan8
* subs lengthReg, lengthReg (#8 if isLatin1 else #4)
* b.lt Lresidual
* Loop:
* subs lengthReg, lengthReg, (#16 if isLatin1 else #8)
* ldr vtmp1Reg, [dataAddrReg], #16
* if isLatin1
* cmeq vtmp1Reg.16b, vtmp0Reg.16b, vtmp1Reg.16b
* shrn vtmp1Reg.8b, vtmp1Reg.8h, #4
* else
* cmeq vtmp1Reg.8h, vtmp0Reg.8h, vtmp1Reg.8h
* xtn vtmp1Reg.8b, vtmp1Reg.8h
* umov tmp1Reg, vtmp1Reg.d[0]
* ccmp tmp1Reg, #0, #0, cs ; if lengthReg < 0, clear all condition flags
* b.eq Loop
* cbnz tmp1Reg, Lfound
* cmn lengthReg, (#16 if isLatin1 else #8)
* b.eq Ldone
* Lresidual:
* ldurq vtmp1Reg, [endReg, #-16] ; Read the last 16bytes. It would not a problem if original length >= 8 because we have a array header space.
* if isLatin1
* cmeq vtmp1Reg.16b, vtmp0Reg.16b, vtmp1Reg.16b
* else
* cmeq vtmp1Reg.8h, vtmp0Reg.8h, vtmp1Reg.8h
* b Lmerge
* LessThan8:
* ldurd vtmp1Reg, [endReg, #-8]
* if isLatin1
* cmeq vtmp1Reg.8b, vtmp0Reg.8b, vtmp1Reg.8b
* else
* cmeq vtmp1Reg.4h, vtmp0Reg.4h, vtmp1Reg.4h
* Lmerge:
* add dataAddrReg, dataAddrReg, #16
* neg lengthReg, lengthReg, lsl (#2 if isLatin1 else #3) ; 64 - (4 or 8) * remaining length
* if isLatin1
* shrn vtmp1Reg.8b, vtmp1Reg.8h, #4
* else
* xtn vtmp1Reg.8b, vtmp1Reg.8h, #4
* umov tmp1Reg, vtmp1Reg.d[0]
* lsr tmp1Reg, tmp1Reg, lengthReg ; Move the comparison result of remaining data to LSB
* cmp tmp1Reg, #0
* b.eq Ldone
* Lfound:
* rbit tmp1Reg, tmp1Reg
* sub dataAddrReg, dataAddrReg, tmp0Reg
* clz tmp1Reg, tmp1Reg
* add resultReg, dataAddrReg, tmp1Reg, lsr #2
* if !isLatin
* lsr resultReg, resultReg, #1
* Ldone:
* csinv resultReg, resultReg, xzr, ne
*
*/

/*
* We omit to evaluate the first child (receiver) as it is not used.
*/
TR::Node *arrayNode = node->getSecondChild();
TR::Node *charNode = node->getThirdChild();
TR::Node *offsetNode = node->getChild(3);
TR::Node *lengthNode = node->getChild(4);
TR::Register *arrayReg = cg->evaluate(arrayNode);
TR::Register *charReg = cg->evaluate(charNode);
const bool isOffsetConstZero = offsetNode->isConstZeroValue();
TR::Register *offsetReg = isOffsetConstZero ? NULL : cg->evaluate(offsetNode);
TR::Register *savedLengthReg = cg->evaluate(lengthNode);
TR_ARM64ScratchRegisterManager *srm = cg->generateScratchRegisterManager();
TR::Register *dataAddrReg = (arrayNode->getReferenceCount() > 1) ? srm->findOrCreateScratchRegister() : arrayReg;
TR::Register *lengthReg = (lengthNode->getReferenceCount() > 1) ? srm->findOrCreateScratchRegister() : savedLengthReg;
TR::Register *tmp0Reg = srm->findOrCreateScratchRegister();
/* The live range of tmp1Reg and offsetReg does not overwrap if offsetNode's reference count is 1, so we use the same register in that case. */
TR::Register *tmp1Reg = ((offsetNode->getReferenceCount() > 1) || isOffsetConstZero) ? srm->findOrCreateScratchRegister() : offsetReg;
/* The live range of endReg and resultReg does not overwrap, so we use the same register. */
TR::Register *endReg = cg->allocateRegister();
TR::Register *resultReg = endReg;
TR_Debug *debugObj = cg->getDebug();

generateTrg1Src1ImmInstruction(cg, TR::InstOpCode::addimmx, node, dataAddrReg, arrayReg, TR::Compiler->om.contiguousArrayHeaderSizeInBytes());
if (isLatin1)
{
generateTrg1Src2Instruction(cg, TR::InstOpCode::addx, node, endReg, dataAddrReg, savedLengthReg);
}
else
{
generateTrg1Src2ShiftedInstruction(cg, TR::InstOpCode::addx, node, endReg, dataAddrReg, savedLengthReg, TR::SH_LSL, 1);
}
TR::Compilation *comp = cg->comp();
if (comp->getOptions()->enableDebugCounters())
{
cg->generateDebugCounter(TR::DebugCounter::debugCounterName(comp, "cg.StringIndexOf/(%s)/%s/%s",
comp->signature(),
(isLatin1 ? "compressed" : "decompressed"),
comp->getHotnessName()), *srm);
}
TR::LabelSymbol *doneLabel = generateLabelSymbol(cg);
if (!isOffsetConstZero)
{
generateTrg1Src2Instruction(cg, TR::InstOpCode::subsx, node, lengthReg, savedLengthReg, offsetReg);
generateConditionalBranchInstruction(cg, TR::InstOpCode::b_cond, node, doneLabel, TR::CC_EQ);
}
else
{
generateMovInstruction(cg, node, lengthReg, savedLengthReg);
}
generateTrg1Src1ImmInstruction(cg, TR::InstOpCode::addimmx, node, tmp0Reg, dataAddrReg, 16);
if (!isOffsetConstZero)
{
if (isLatin1)
{
generateTrg1Src2Instruction(cg, TR::InstOpCode::addx, node, dataAddrReg, dataAddrReg, offsetReg);
}
else
{
generateTrg1Src2ShiftedInstruction(cg, TR::InstOpCode::addx, node, dataAddrReg, dataAddrReg, offsetReg, TR::SH_LSL, 1);
}
}

TR::Register *vtmp0Reg = srm->findOrCreateScratchRegister(TR_VRF);

generateTrg1Src1Instruction(cg, isLatin1 ? TR::InstOpCode::vdup16b : TR::InstOpCode::vdup8h, node, vtmp0Reg, charReg);
generateTrg1Src1ImmInstruction(cg, TR::InstOpCode::subsimmx, node, lengthReg, lengthReg, isLatin1 ? 8 : 4);

TR::LabelSymbol *lessThan8Label = generateLabelSymbol(cg);
auto branchToLessThan8LabelInstr = generateConditionalBranchInstruction(cg, TR::InstOpCode::b_cond, node, lessThan8Label, TR::CC_LT);

generateTrg1Src1ImmInstruction(cg, TR::InstOpCode::subsimmx, node, lengthReg, lengthReg, isLatin1 ? 8 : 4);

TR::LabelSymbol *residualLabel = generateLabelSymbol(cg);
auto branchToResidualLabelInstr = generateConditionalBranchInstruction(cg, TR::InstOpCode::b_cond, node, residualLabel, TR::CC_LT);
if (comp->getOptions()->enableDebugCounters())
{
cg->generateDebugCounter(TR::DebugCounter::debugCounterName(comp, "cg.StringIndexOf/(%s)/%s/%s:long",
comp->signature(),
(isLatin1 ? "compressed" : "decompressed"),
comp->getHotnessName()), *srm);
}
/*
* Main loop: 16 bytes are processed in 1 iteration of the loop.
*/
TR::LabelSymbol *loopLabel = generateLabelSymbol(cg);
auto loopLabelInstr = generateLabelInstruction(cg, TR::InstOpCode::label, node, loopLabel);
if (debugObj)
{
debugObj->addInstructionComment(branchToLessThan8LabelInstr, "Branch to lessThan8 if remaining length < (isLatin1 ? 8 : 4)");
debugObj->addInstructionComment(branchToResidualLabelInstr, "Branch to residualLabel if remaining length < (isLatin1 ? 16 : 8)");
debugObj->addInstructionComment(loopLabelInstr, "loopLabel");
}
generateTrg1Src1ImmInstruction(cg, TR::InstOpCode::subsimmx, node, lengthReg, lengthReg, isLatin1 ? 16 : 8);

TR::Register *vtmp1Reg = srm->findOrCreateScratchRegister(TR_VRF);

generateTrg1MemInstruction(cg, TR::InstOpCode::vldrpostq, node, vtmp1Reg, TR::MemoryReference::createWithDisplacement(cg, dataAddrReg, 16));
generateTrg1Src2Instruction(cg, (isLatin1 ? TR::InstOpCode::vcmeq16b : TR::InstOpCode::vcmeq8h), node, vtmp1Reg, vtmp0Reg, vtmp1Reg);
if (isLatin1)
{
generateVectorShiftImmediateInstruction(cg, TR::InstOpCode::vshrn_8b, node, vtmp1Reg, vtmp1Reg, 4);
}
else
{
generateTrg1Src1Instruction(cg, TR::InstOpCode::vxtn_8b, node, vtmp1Reg, vtmp1Reg);
}

generateMovVectorElementToGPRInstruction(cg, TR::InstOpCode::umovxd, node, tmp1Reg, vtmp1Reg, 0);
generateConditionalCompareImmInstruction(cg, node, tmp1Reg, 0, 0, TR::CC_CS, true);
auto branchBackToLoopLabelInstr = generateConditionalBranchInstruction(cg, TR::InstOpCode::b_cond, node, loopLabel, TR::CC_EQ);
TR::LabelSymbol *foundLabel = generateLabelSymbol(cg);
auto branchToFoundLabelInstr = generateCompareBranchInstruction(cg, TR::InstOpCode::cbnzx, node, tmp1Reg, foundLabel);
generateCompareImmInstruction(cg, node, lengthReg, (isLatin1 ? (-16) : (-8)), true);
auto branchToDoneLabelInstr = generateConditionalBranchInstruction(cg, TR::InstOpCode::b_cond, node, doneLabel, TR::CC_EQ);
auto residualLabelInstr = generateLabelInstruction(cg, TR::InstOpCode::label, node, residualLabel);
generateTrg1MemInstruction(cg, TR::InstOpCode::vldurq, node, vtmp1Reg, TR::MemoryReference::createWithDisplacement(cg, endReg, -16));
generateTrg1Src2Instruction(cg, (isLatin1 ? TR::InstOpCode::vcmeq16b : TR::InstOpCode::vcmeq8h), node, vtmp1Reg, vtmp0Reg, vtmp1Reg);
TR::LabelSymbol *mergeLabel = generateLabelSymbol(cg);
auto branchToMergeLabelInstr = generateLabelInstruction(cg, TR::InstOpCode::b, node, mergeLabel);
auto lessThan8LabelInstr = generateLabelInstruction(cg, TR::InstOpCode::label, node, lessThan8Label);
generateTrg1MemInstruction(cg, TR::InstOpCode::vldurd, node, vtmp1Reg, TR::MemoryReference::createWithDisplacement(cg, endReg, -8));
generateTrg1Src2Instruction(cg, (isLatin1 ? TR::InstOpCode::vcmeq8b : TR::InstOpCode::vcmeq4h), node, vtmp1Reg, vtmp0Reg, vtmp1Reg);
auto mergeLabelInstr = generateLabelInstruction(cg, TR::InstOpCode::label, node, mergeLabel);
if (debugObj)
{
debugObj->addInstructionComment(branchBackToLoopLabelInstr, "Jump back to loopLabel if the character is not found and the remaining data is >= 16 bytes");
debugObj->addInstructionComment(branchToFoundLabelInstr, "Branch to foundLabel if the character found");
debugObj->addInstructionComment(branchToDoneLabelInstr, "Branch to doneLabel if no data is left");
debugObj->addInstructionComment(residualLabelInstr, "residualLabel");
debugObj->addInstructionComment(lessThan8LabelInstr, "lessThan8Label");
debugObj->addInstructionComment(branchToMergeLabelInstr, "Branch to mergeLabel");
debugObj->addInstructionComment(mergeLabelInstr, "mergeLabel");
}
generateTrg1Src1ImmInstruction(cg, TR::InstOpCode::addimmx, node, dataAddrReg, dataAddrReg, 16);

TR::Register *zeroReg = cg->allocateRegister();
generateTrg1Src2ShiftedInstruction(cg, TR::InstOpCode::subx, node, lengthReg, zeroReg, lengthReg, TR::SH_LSL, (isLatin1 ? 2 : 3));
if (isLatin1)
{
generateVectorShiftImmediateInstruction(cg, TR::InstOpCode::vshrn_8b, node, vtmp1Reg, vtmp1Reg, 4);
}
else
{
generateTrg1Src1Instruction(cg, TR::InstOpCode::vxtn_8b, node, vtmp1Reg, vtmp1Reg);
}
generateMovVectorElementToGPRInstruction(cg, TR::InstOpCode::umovxd, node, tmp1Reg, vtmp1Reg, 0);
generateTrg1Src2Instruction(cg, TR::InstOpCode::lsrvx, node, tmp1Reg, tmp1Reg, lengthReg);
generateCompareImmInstruction(cg, node, tmp1Reg, 0, true);
auto branchToDoneLabelInstr2 = generateConditionalBranchInstruction(cg, TR::InstOpCode::b_cond, node, doneLabel, TR::CC_EQ);
auto foundLabelInstr = generateLabelInstruction(cg, TR::InstOpCode::label, node, foundLabel);
generateTrg1Src1Instruction(cg, TR::InstOpCode::rbitx, node, tmp1Reg, tmp1Reg);
generateTrg1Src2Instruction(cg, TR::InstOpCode::subx, node, dataAddrReg, dataAddrReg, tmp0Reg);
generateTrg1Src1Instruction(cg, TR::InstOpCode::clzx, node, tmp1Reg, tmp1Reg);

generateTrg1Src2ShiftedInstruction(cg, TR::InstOpCode::addx, node, resultReg, dataAddrReg, tmp1Reg, TR::SH_LSR, 2);
if (!isLatin1)
{
generateLogicalShiftRightImmInstruction(cg, node, resultReg, resultReg, 1);
}

auto doneLabelInstr = generateLabelInstruction(cg, TR::InstOpCode::label, node, doneLabel);
if (debugObj)
{
debugObj->addInstructionComment(branchToDoneLabelInstr2, "Branch to doneLabel if the character is not found");
debugObj->addInstructionComment(foundLabelInstr, "foundLabel");
debugObj->addInstructionComment(doneLabelInstr, "doneLabel");
}
generateCondTrg1Src2Instruction(cg, TR::InstOpCode::csinvx, node, resultReg, resultReg, zeroReg, TR::CC_NE);

TR::RegisterDependencyConditions *conditions = new (cg->trHeapMemory()) TR::RegisterDependencyConditions(0, (isOffsetConstZero ? 5 : 6) + srm->numAvailableRegisters(), cg->trMemory());
conditions->addPostCondition(arrayReg, TR::RealRegister::NoReg);
conditions->addPostCondition(charReg, TR::RealRegister::NoReg);
conditions->addPostCondition(savedLengthReg, TR::RealRegister::NoReg);
if (!isOffsetConstZero)
{
conditions->addPostCondition(offsetReg, TR::RealRegister::NoReg);
}
conditions->addPostCondition(zeroReg, TR::RealRegister::xzr);
conditions->addPostCondition(resultReg, TR::RealRegister::NoReg);
srm->addScratchRegistersToDependencyList(conditions);

generateLabelInstruction(cg, TR::InstOpCode::label, node, generateLabelSymbol(cg), conditions);

node->setRegister(resultReg);
cg->stopUsingRegister(zeroReg);
srm->stopUsingRegisters();
cg->recursivelyDecReferenceCount(node->getFirstChild());
cg->decReferenceCount(arrayNode);
cg->decReferenceCount(charNode);
cg->decReferenceCount(offsetNode);
cg->decReferenceCount(lengthNode);

return resultReg;
}

bool
J9::ARM64::CodeGenerator::inlineDirectCall(TR::Node *node, TR::Register *&resultReg)
{
Expand All @@ -6229,6 +6503,22 @@ J9::ARM64::CodeGenerator::inlineDirectCall(TR::Node *node, TR::Register *&result
{
switch (methodSymbol->getMandatoryRecognizedMethod())
{
case TR::com_ibm_jit_JITHelpers_intrinsicIndexOfLatin1:
if (cg->getSupportsInlineStringIndexOf())
{
resultReg = inlineIntrinsicIndexOf(node, cg, true);
return true;
}
break;

case TR::com_ibm_jit_JITHelpers_intrinsicIndexOfUTF16:
if (cg->getSupportsInlineStringIndexOf())
{
resultReg = inlineIntrinsicIndexOf(node, cg, false);
return true;
}
break;

case TR::java_lang_String_hashCodeImplDecompressed:
if (cg->getSupportsInlineStringHashCode())
{
Expand Down

0 comments on commit a81e6b3

Please sign in to comment.