Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

AArch64: Add vectorized implementation of intrinsicIndexOf helper functions #18607

Merged
merged 1 commit into from
Dec 15, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions runtime/compiler/aarch64/codegen/J9CodeGenerator.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -86,6 +86,10 @@ J9::ARM64::CodeGenerator::initialize()
{
cg->setSupportsInlineStringHashCode();
}
if ((!TR::Compiler->om.canGenerateArraylets()) && (!comp->getOption(TR_DisableFastStringIndexOf)))
{
cg->setSupportsInlineStringIndexOf();
}
if (comp->fej9()->hasFixedFrameC_CallingConvention())
cg->setHasFixedFrameC_CallingConvention();
}
Expand Down
290 changes: 290 additions & 0 deletions runtime/compiler/aarch64/codegen/J9TreeEvaluator.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -6214,6 +6214,280 @@ static TR::Register *inlineStringHashCode(TR::Node *node, bool isCompressed, TR:
return resultReg;
}

/**
* @brief Generates inlined instructions equivalent to com/ibm/jit/JITHelpers.intrinsicIndexOfLatin1 or com/ibm/jit/JITHelpers.intrinsicIndexOfUTF16
*
* The children of the call node are used as follows: child 0 (receiver) is not evaluated,
* child 1 is the array, child 2 is the character to search for, child 3 is the start offset
* and child 4 is the length. Offset and length are counted in elements (they are scaled by 2
* for UTF16 when converted to byte addresses).
* The generated code leaves the element index of the first match in the result register,
* or -1 when no match is found.
*
* @param node: the call node for the intrinsicIndexOf helper
* @param cg: Code Generator
* @param isLatin1: true when the string is Latin1, false when the string is UTF16
* @returns the register holding the result; it is also set as the register of the node
*/
static TR::Register* inlineIntrinsicIndexOf(TR::Node* node, TR::CodeGenerator* cg, bool isLatin1)
{
/*
* Generated instruction sequence.
* When the offset node is a constant zero, the `subs`/`b.eq Ldone` pair is replaced by a plain
* register move of the length and the `add dataAddrReg, dataAddrReg, offsetReg` is omitted.
*
* add dataAddrReg, addressReg, #headerSize ; get the starting address of byte array data
* add endReg, dataAddrReg, lengthReg(, lsl #1 if utf16) ; get the address of the end of the byte array
* subs lengthReg, lengthReg, offsetReg
* b.eq Ldone
* add tmp0Reg, dataAddrReg, #16 ; save starting address + 16 to tmp0Reg
* add dataAddrReg, dataAddrReg, offsetReg(, lsl #1 if utf16)
*
* if isLatin1
* dup vtmp0Reg.16b, charReg
* else
* dup vtmp0Reg.8h, charReg
* subs lengthReg, lengthReg (#8 if isLatin1 else #4)
* ; If the length < 8 bytes, then it is not guaranteed that we can read the last 16 bytes because the minimum size of the array header is 8 bytes.
* b.lt Lessthan8
* subs lengthReg, lengthReg (#8 if isLatin1 else #4)
* b.lt Lresidual
* Loop:
* subs lengthReg, lengthReg, (#16 if isLatin1 else #8)
* ldr vtmp1Reg, [dataAddrReg], #16
* if isLatin1
* cmeq vtmp1Reg.16b, vtmp0Reg.16b, vtmp1Reg.16b
* shrn vtmp1Reg.8b, vtmp1Reg.8h, #4
* else
* cmeq vtmp1Reg.8h, vtmp0Reg.8h, vtmp1Reg.8h
* xtn vtmp1Reg.8b, vtmp1Reg.8h
* umov tmp1Reg, vtmp1Reg.d[0]
* ccmp tmp1Reg, #0, #0, cs ; if lengthReg < 0, clear all condition flags
* b.eq Loop
* cbnz tmp1Reg, Lfound
* cmn lengthReg, (#16 if isLatin1 else #8)
* b.eq Ldone
* Lresidual:
* ldurq vtmp1Reg, [endReg, #-16] ; Read the last 16 bytes. This is not a problem if the original length >= 8 bytes because the array header space precedes the data.
* if isLatin1
* cmeq vtmp1Reg.16b, vtmp0Reg.16b, vtmp1Reg.16b
* else
* cmeq vtmp1Reg.8h, vtmp0Reg.8h, vtmp1Reg.8h
* b Lmerge
* LessThan8:
* ldurd vtmp1Reg, [endReg, #-8]
* if isLatin1
* cmeq vtmp1Reg.8b, vtmp0Reg.8b, vtmp1Reg.8b
* else
* cmeq vtmp1Reg.4h, vtmp0Reg.4h, vtmp1Reg.4h
* Lmerge:
* add dataAddrReg, dataAddrReg, #16
* neg lengthReg, lengthReg, lsl (#2 if isLatin1 else #3) ; 64 - (4 or 8) * remaining length when coming from Lresidual;
*                                                        ; 32 - (4 or 8) * remaining length when coming from LessThan8 (only 32 result bits are produced there)
* if isLatin1
* shrn vtmp1Reg.8b, vtmp1Reg.8h, #4
* else
* xtn vtmp1Reg.8b, vtmp1Reg.8h
* umov tmp1Reg, vtmp1Reg.d[0]
* lsr tmp1Reg, tmp1Reg, lengthReg ; Move the comparison result of remaining data to LSB
* cmp tmp1Reg, #0
* b.eq Ldone
* Lfound:
* rbit tmp1Reg, tmp1Reg
* sub dataAddrReg, dataAddrReg, tmp0Reg
* clz tmp1Reg, tmp1Reg
* add resultReg, dataAddrReg, tmp1Reg, lsr #2
* if !isLatin1
* lsr resultReg, resultReg, #1
* Ldone:
* csinv resultReg, resultReg, xzr, ne ; every branch to Ldone is taken on EQ, which yields ~xzr = -1
*
*/

/*
* We omit to evaluate the first child (receiver) as it is not used.
*/
TR::Node *arrayNode = node->getSecondChild();
TR::Node *charNode = node->getThirdChild();
TR::Node *offsetNode = node->getChild(3);
TR::Node *lengthNode = node->getChild(4);
TR::Register *arrayReg = cg->evaluate(arrayNode);
TR::Register *charReg = cg->evaluate(charNode);
/* A constant zero offset is not evaluated at all; the instructions that would use it are skipped below. */
const bool isOffsetConstZero = offsetNode->isConstZeroValue();
TR::Register *offsetReg = isOffsetConstZero ? NULL : cg->evaluate(offsetNode);
TR::Register *savedLengthReg = cg->evaluate(lengthNode);
TR_ARM64ScratchRegisterManager *srm = cg->generateScratchRegisterManager();
/*
* dataAddrReg and lengthReg are overwritten by the generated code. If the corresponding node has
* other references (reference count > 1), a scratch register is used so that the evaluated register
* keeps its value; otherwise the evaluated register is reused in place.
*/
TR::Register *dataAddrReg = (arrayNode->getReferenceCount() > 1) ? srm->findOrCreateScratchRegister() : arrayReg;
TR::Register *lengthReg = (lengthNode->getReferenceCount() > 1) ? srm->findOrCreateScratchRegister() : savedLengthReg;
TR::Register *tmp0Reg = srm->findOrCreateScratchRegister();
/* The live ranges of tmp1Reg and offsetReg do not overlap if offsetNode's reference count is 1, so we use the same register in that case. */
TR::Register *tmp1Reg = ((offsetNode->getReferenceCount() > 1) || isOffsetConstZero) ? srm->findOrCreateScratchRegister() : offsetReg;
/* The live ranges of endReg and resultReg do not overlap, so we use the same register. */
TR::Register *endReg = cg->allocateRegister();
TR::Register *resultReg = endReg;
TR_Debug *debugObj = cg->getDebug();

/* dataAddrReg = address of the first array element; endReg = address just past the last element. */
generateTrg1Src1ImmInstruction(cg, TR::InstOpCode::addimmx, node, dataAddrReg, arrayReg, TR::Compiler->om.contiguousArrayHeaderSizeInBytes());
if (isLatin1)
{
generateTrg1Src2Instruction(cg, TR::InstOpCode::addx, node, endReg, dataAddrReg, savedLengthReg);
}
else
{
/* UTF16 elements are 2 bytes wide, so the length is shifted left by 1 to get a byte count. */
generateTrg1Src2ShiftedInstruction(cg, TR::InstOpCode::addx, node, endReg, dataAddrReg, savedLengthReg, TR::SH_LSL, 1);
}
TR::Compilation *comp = cg->comp();
if (comp->getOptions()->enableDebugCounters())
{
cg->generateDebugCounter(TR::DebugCounter::debugCounterName(comp, "cg.StringIndexOf/(%s)/%s/%s",
comp->signature(),
(isLatin1 ? "compressed" : "decompressed"),
comp->getHotnessName()), *srm);
}
TR::LabelSymbol *doneLabel = generateLabelSymbol(cg);
if (!isOffsetConstZero)
{
/* lengthReg = number of elements left to search; if it is zero, branch to doneLabel with EQ set so that the result becomes -1. */
generateTrg1Src2Instruction(cg, TR::InstOpCode::subsx, node, lengthReg, savedLengthReg, offsetReg);
generateConditionalBranchInstruction(cg, TR::InstOpCode::b_cond, node, doneLabel, TR::CC_EQ);
}
else
{
/* With a zero offset the whole length remains; no early exit is emitted here, so a zero length goes through the lessThan8 path. */
generateMovInstruction(cg, node, lengthReg, savedLengthReg);
}
/* tmp0Reg = data start + 16. It is subtracted from the (16-byte advanced) dataAddrReg at foundLabel to obtain the byte offset of the examined chunk. */
generateTrg1Src1ImmInstruction(cg, TR::InstOpCode::addimmx, node, tmp0Reg, dataAddrReg, 16);
if (!isOffsetConstZero)
{
/* Advance dataAddrReg to the element at the start offset. */
if (isLatin1)
{
generateTrg1Src2Instruction(cg, TR::InstOpCode::addx, node, dataAddrReg, dataAddrReg, offsetReg);
}
else
{
generateTrg1Src2ShiftedInstruction(cg, TR::InstOpCode::addx, node, dataAddrReg, dataAddrReg, offsetReg, TR::SH_LSL, 1);
}
}

TR::Register *vtmp0Reg = srm->findOrCreateScratchRegister(TR_VRF);

/* Broadcast the search character to every lane of vtmp0Reg (16 byte lanes for Latin1, 8 halfword lanes for UTF16). */
generateTrg1Src1Instruction(cg, isLatin1 ? TR::InstOpCode::vdup16b : TR::InstOpCode::vdup8h, node, vtmp0Reg, charReg);
/* First subtraction of 8 bytes worth of elements: a negative result means fewer than 8 bytes remain. */
generateTrg1Src1ImmInstruction(cg, TR::InstOpCode::subsimmx, node, lengthReg, lengthReg, isLatin1 ? 8 : 4);

TR::LabelSymbol *lessThan8Label = generateLabelSymbol(cg);
auto branchToLessThan8LabelInstr = generateConditionalBranchInstruction(cg, TR::InstOpCode::b_cond, node, lessThan8Label, TR::CC_LT);

/* Second subtraction of 8 bytes worth of elements: a negative result means fewer than 16 bytes remain. */
generateTrg1Src1ImmInstruction(cg, TR::InstOpCode::subsimmx, node, lengthReg, lengthReg, isLatin1 ? 8 : 4);

TR::LabelSymbol *residualLabel = generateLabelSymbol(cg);
auto branchToResidualLabelInstr = generateConditionalBranchInstruction(cg, TR::InstOpCode::b_cond, node, residualLabel, TR::CC_LT);
if (comp->getOptions()->enableDebugCounters())
{
cg->generateDebugCounter(TR::DebugCounter::debugCounterName(comp, "cg.StringIndexOf/(%s)/%s/%s:long",
comp->signature(),
(isLatin1 ? "compressed" : "decompressed"),
comp->getHotnessName()), *srm);
}
/*
* Main loop: 16 bytes are processed in 1 iteration of the loop.
*/
TR::LabelSymbol *loopLabel = generateLabelSymbol(cg);
auto loopLabelInstr = generateLabelInstruction(cg, TR::InstOpCode::label, node, loopLabel);
if (debugObj)
{
debugObj->addInstructionComment(branchToLessThan8LabelInstr, "Branch to lessThan8 if remaining length < (isLatin1 ? 8 : 4)");
debugObj->addInstructionComment(branchToResidualLabelInstr, "Branch to residualLabel if remaining length < (isLatin1 ? 16 : 8)");
debugObj->addInstructionComment(loopLabelInstr, "loopLabel");
}
generateTrg1Src1ImmInstruction(cg, TR::InstOpCode::subsimmx, node, lengthReg, lengthReg, isLatin1 ? 16 : 8);

TR::Register *vtmp1Reg = srm->findOrCreateScratchRegister(TR_VRF);

/* Load 16 bytes with post-increment, compare every lane with the search character, then narrow the lane masks into 64 bits. */
generateTrg1MemInstruction(cg, TR::InstOpCode::vldrpostq, node, vtmp1Reg, TR::MemoryReference::createWithDisplacement(cg, dataAddrReg, 16));
generateTrg1Src2Instruction(cg, (isLatin1 ? TR::InstOpCode::vcmeq16b : TR::InstOpCode::vcmeq8h), node, vtmp1Reg, vtmp0Reg, vtmp1Reg);
if (isLatin1)
{
generateVectorShiftImmediateInstruction(cg, TR::InstOpCode::vshrn_8b, node, vtmp1Reg, vtmp1Reg, 4);
}
else
{
generateTrg1Src1Instruction(cg, TR::InstOpCode::vxtn_8b, node, vtmp1Reg, vtmp1Reg);
}

generateMovVectorElementToGPRInstruction(cg, TR::InstOpCode::umovxd, node, tmp1Reg, vtmp1Reg, 0);
/* Corresponds to the ccmp in the pseudo-code above: EQ is set only when the preceding subs left the carry set and tmp1Reg is zero. */
generateConditionalCompareImmInstruction(cg, node, tmp1Reg, 0, 0, TR::CC_CS, true);
auto branchBackToLoopLabelInstr = generateConditionalBranchInstruction(cg, TR::InstOpCode::b_cond, node, loopLabel, TR::CC_EQ);
TR::LabelSymbol *foundLabel = generateLabelSymbol(cg);
auto branchToFoundLabelInstr = generateCompareBranchInstruction(cg, TR::InstOpCode::cbnzx, node, tmp1Reg, foundLabel);
/* Corresponds to the cmn in the pseudo-code above: lengthReg == -(16 or 8) means that nothing is left after the last full chunk. */
generateCompareImmInstruction(cg, node, lengthReg, (isLatin1 ? (-16) : (-8)), true);
auto branchToDoneLabelInstr = generateConditionalBranchInstruction(cg, TR::InstOpCode::b_cond, node, doneLabel, TR::CC_EQ);
/* Residual path: 8 to 15 bytes remain, so the last 16 bytes before endReg are loaded and compared. */
auto residualLabelInstr = generateLabelInstruction(cg, TR::InstOpCode::label, node, residualLabel);
generateTrg1MemInstruction(cg, TR::InstOpCode::vldurq, node, vtmp1Reg, TR::MemoryReference::createWithDisplacement(cg, endReg, -16));
generateTrg1Src2Instruction(cg, (isLatin1 ? TR::InstOpCode::vcmeq16b : TR::InstOpCode::vcmeq8h), node, vtmp1Reg, vtmp0Reg, vtmp1Reg);
TR::LabelSymbol *mergeLabel = generateLabelSymbol(cg);
auto branchToMergeLabelInstr = generateLabelInstruction(cg, TR::InstOpCode::b, node, mergeLabel);
/* Less-than-8 path: only the last 8 bytes before endReg are loaded and compared. */
auto lessThan8LabelInstr = generateLabelInstruction(cg, TR::InstOpCode::label, node, lessThan8Label);
generateTrg1MemInstruction(cg, TR::InstOpCode::vldurd, node, vtmp1Reg, TR::MemoryReference::createWithDisplacement(cg, endReg, -8));
generateTrg1Src2Instruction(cg, (isLatin1 ? TR::InstOpCode::vcmeq8b : TR::InstOpCode::vcmeq4h), node, vtmp1Reg, vtmp0Reg, vtmp1Reg);
auto mergeLabelInstr = generateLabelInstruction(cg, TR::InstOpCode::label, node, mergeLabel);
if (debugObj)
{
debugObj->addInstructionComment(branchBackToLoopLabelInstr, "Jump back to loopLabel if the character is not found and the remaining data is >= 16 bytes");
debugObj->addInstructionComment(branchToFoundLabelInstr, "Branch to foundLabel if the character found");
debugObj->addInstructionComment(branchToDoneLabelInstr, "Branch to doneLabel if no data is left");
debugObj->addInstructionComment(residualLabelInstr, "residualLabel");
debugObj->addInstructionComment(lessThan8LabelInstr, "lessThan8Label");
debugObj->addInstructionComment(branchToMergeLabelInstr, "Branch to mergeLabel");
debugObj->addInstructionComment(mergeLabelInstr, "mergeLabel");
}
/* Advance dataAddrReg by 16 so that foundLabel can use the same address arithmetic as after the post-incremented load in the loop. */
generateTrg1Src1ImmInstruction(cg, TR::InstOpCode::addimmx, node, dataAddrReg, dataAddrReg, 16);

/* zeroReg is bound to xzr by the post condition added below, so this subtraction is the neg in the pseudo-code above. */
TR::Register *zeroReg = cg->allocateRegister();
generateTrg1Src2ShiftedInstruction(cg, TR::InstOpCode::subx, node, lengthReg, zeroReg, lengthReg, TR::SH_LSL, (isLatin1 ? 2 : 3));
if (isLatin1)
{
generateVectorShiftImmediateInstruction(cg, TR::InstOpCode::vshrn_8b, node, vtmp1Reg, vtmp1Reg, 4);
}
else
{
generateTrg1Src1Instruction(cg, TR::InstOpCode::vxtn_8b, node, vtmp1Reg, vtmp1Reg);
}
generateMovVectorElementToGPRInstruction(cg, TR::InstOpCode::umovxd, node, tmp1Reg, vtmp1Reg, 0);
/* Shift out the comparison results of the elements that precede the remaining data. */
generateTrg1Src2Instruction(cg, TR::InstOpCode::lsrvx, node, tmp1Reg, tmp1Reg, lengthReg);
generateCompareImmInstruction(cg, node, tmp1Reg, 0, true);
auto branchToDoneLabelInstr2 = generateConditionalBranchInstruction(cg, TR::InstOpCode::b_cond, node, doneLabel, TR::CC_EQ);
auto foundLabelInstr = generateLabelInstruction(cg, TR::InstOpCode::label, node, foundLabel);
/* rbit followed by clz counts the trailing zero bits of tmp1Reg, i.e. the bit position of the first match. */
generateTrg1Src1Instruction(cg, TR::InstOpCode::rbitx, node, tmp1Reg, tmp1Reg);
generateTrg1Src2Instruction(cg, TR::InstOpCode::subx, node, dataAddrReg, dataAddrReg, tmp0Reg);
generateTrg1Src1Instruction(cg, TR::InstOpCode::clzx, node, tmp1Reg, tmp1Reg);

/* Each byte of data is represented by 4 bits in tmp1Reg, so the bit position shifted right by 2 is the byte offset within the chunk. */
generateTrg1Src2ShiftedInstruction(cg, TR::InstOpCode::addx, node, resultReg, dataAddrReg, tmp1Reg, TR::SH_LSR, 2);
if (!isLatin1)
{
/* Convert the byte offset to a UTF16 element index. */
generateLogicalShiftRightImmInstruction(cg, node, resultReg, resultReg, 1);
}

auto doneLabelInstr = generateLabelInstruction(cg, TR::InstOpCode::label, node, doneLabel);
if (debugObj)
{
debugObj->addInstructionComment(branchToDoneLabelInstr2, "Branch to doneLabel if the character is not found");
debugObj->addInstructionComment(foundLabelInstr, "foundLabel");
debugObj->addInstructionComment(doneLabelInstr, "doneLabel");
}
/* Keep resultReg on NE (match found); otherwise produce the bitwise inverse of zeroReg (xzr), which is -1. */
generateCondTrg1Src2Instruction(cg, TR::InstOpCode::csinvx, node, resultReg, resultReg, zeroReg, TR::CC_NE);

/* 5 explicit post conditions (array, char, length, zero, result), plus 1 for the offset when it is evaluated, plus the scratch registers. */
TR::RegisterDependencyConditions *conditions = new (cg->trHeapMemory()) TR::RegisterDependencyConditions(0, (isOffsetConstZero ? 5 : 6) + srm->numAvailableRegisters(), cg->trMemory());
conditions->addPostCondition(arrayReg, TR::RealRegister::NoReg);
conditions->addPostCondition(charReg, TR::RealRegister::NoReg);
conditions->addPostCondition(savedLengthReg, TR::RealRegister::NoReg);
if (!isOffsetConstZero)
{
conditions->addPostCondition(offsetReg, TR::RealRegister::NoReg);
}
conditions->addPostCondition(zeroReg, TR::RealRegister::xzr);
conditions->addPostCondition(resultReg, TR::RealRegister::NoReg);
srm->addScratchRegistersToDependencyList(conditions);

/* Attach the register dependencies to a label placed at the end of the generated sequence. */
generateLabelInstruction(cg, TR::InstOpCode::label, node, generateLabelSymbol(cg), conditions);

node->setRegister(resultReg);
cg->stopUsingRegister(zeroReg);
srm->stopUsingRegisters();
/* The receiver (first child) was never evaluated, so its reference count is decremented recursively. */
cg->recursivelyDecReferenceCount(node->getFirstChild());
cg->decReferenceCount(arrayNode);
cg->decReferenceCount(charNode);
cg->decReferenceCount(offsetNode);
cg->decReferenceCount(lengthNode);

return resultReg;
}

bool
J9::ARM64::CodeGenerator::inlineDirectCall(TR::Node *node, TR::Register *&resultReg)
{
Expand All @@ -6229,6 +6503,22 @@ J9::ARM64::CodeGenerator::inlineDirectCall(TR::Node *node, TR::Register *&result
{
switch (methodSymbol->getMandatoryRecognizedMethod())
{
case TR::com_ibm_jit_JITHelpers_intrinsicIndexOfLatin1:
if (cg->getSupportsInlineStringIndexOf())
{
resultReg = inlineIntrinsicIndexOf(node, cg, true);
return true;
}
break;

case TR::com_ibm_jit_JITHelpers_intrinsicIndexOfUTF16:
if (cg->getSupportsInlineStringIndexOf())
{
resultReg = inlineIntrinsicIndexOf(node, cg, false);
return true;
}
break;

case TR::java_lang_String_hashCodeImplDecompressed:
if (cg->getSupportsInlineStringHashCode())
{
Expand Down