From 567cb0d48fb886e161591d294a1b63b34eba92f9 Mon Sep 17 00:00:00 2001 From: Akira Saitoh Date: Thu, 30 Nov 2023 14:54:55 +0900 Subject: [PATCH] AArch64: Add vectorized implementation of intrinsicIndexOf helper functions This commit implements a vectorized version of JITHelpers.intrinsicIndexOfLatin1 and intrinsicIndexOfUTF16 on AArch64 codegen. Signed-off-by: Akira Saitoh --- .../aarch64/codegen/J9CodeGenerator.cpp | 4 + .../aarch64/codegen/J9TreeEvaluator.cpp | 290 ++++++++++++++++++ 2 files changed, 294 insertions(+) diff --git a/runtime/compiler/aarch64/codegen/J9CodeGenerator.cpp b/runtime/compiler/aarch64/codegen/J9CodeGenerator.cpp index 3c565deaf70..a07110c683f 100644 --- a/runtime/compiler/aarch64/codegen/J9CodeGenerator.cpp +++ b/runtime/compiler/aarch64/codegen/J9CodeGenerator.cpp @@ -86,6 +86,10 @@ J9::ARM64::CodeGenerator::initialize() { cg->setSupportsInlineStringHashCode(); } + if ((!TR::Compiler->om.canGenerateArraylets()) && (!comp->getOption(TR_DisableFastStringIndexOf))) + { + cg->setSupportsInlineStringIndexOf(); + } if (comp->fej9()->hasFixedFrameC_CallingConvention()) cg->setHasFixedFrameC_CallingConvention(); } diff --git a/runtime/compiler/aarch64/codegen/J9TreeEvaluator.cpp b/runtime/compiler/aarch64/codegen/J9TreeEvaluator.cpp index e5ef66846c6..a6d6cb192f1 100644 --- a/runtime/compiler/aarch64/codegen/J9TreeEvaluator.cpp +++ b/runtime/compiler/aarch64/codegen/J9TreeEvaluator.cpp @@ -6214,6 +6214,280 @@ static TR::Register *inlineStringHashCode(TR::Node *node, bool isCompressed, TR: return resultReg; } +/** + * @brief Generates inlined instructions equivalent to com/ibm/jit/JITHelpers.intrinsicIndexOfLatin1 or com/ibm/jit/JITHelpers.intrinsicIndexOfUTF16 + * + * @param node: the call node; its children are, in order: receiver (not evaluated), array, character, offset, length + * @param cg: Code Generator + * @param isLatin1: true when the string is Latin1, false when the string is UTF16 + * @returns the register set on the node as its result: the index of the matching character, or -1 when no match is found + */ +static TR::Register* inlineIntrinsicIndexOf(TR::Node* node, TR::CodeGenerator* cg, bool isLatin1) + { + /* + * add 
dataAddrReg, addressReg, #headerSize ; get the starting address of byte array data + * add endReg, dataAddrReg, lengthReg(, lsl #1 if utf16) ; get the address of the end of the byte array + * subs lengthReg, lengthReg, offsetReg + * b.eq Ldone + * add tmp0Reg, dataAddrReg, #16 ; save starting address + 16 to tmp0Reg + * add dataAddrReg, dataAddrReg, offsetReg(, lsl #1 if utf16) + * + * if isLatin1 + * dup vtmp0Reg.16b, charReg + * else + * dup vtmp0Reg.8h, charReg + * subs lengthReg, lengthReg (#8 if isLatin1 else #4) + * ; If the length < 8 bytes, then it is not guaranteed that we can read the last 16 bytes because the minimum size of the array header is 8 bytes. + * b.lt LessThan8 + * subs lengthReg, lengthReg (#8 if isLatin1 else #4) + * b.lt Lresidual + * Loop: + * subs lengthReg, lengthReg, (#16 if isLatin1 else #8) + * ldr vtmp1Reg, [dataAddrReg], #16 + * if isLatin1 + * cmeq vtmp1Reg.16b, vtmp0Reg.16b, vtmp1Reg.16b + * shrn vtmp1Reg.8b, vtmp1Reg.8h, #4 + * else + * cmeq vtmp1Reg.8h, vtmp0Reg.8h, vtmp1Reg.8h + * xtn vtmp1Reg.8b, vtmp1Reg.8h + * umov tmp1Reg, vtmp1Reg.d[0] + * ccmp tmp1Reg, #0, #0, cs ; if lengthReg < 0, clear all condition flags + * b.eq Loop + * cbnz tmp1Reg, Lfound + * cmn lengthReg, (#16 if isLatin1 else #8) + * b.eq Ldone + * Lresidual: + * ldurq vtmp1Reg, [endReg, #-16] ; Read the last 16 bytes. It is not a problem if the original length >= 8 because we have the array header space in front of the data. 
+ * if isLatin1 + * cmeq vtmp1Reg.16b, vtmp0Reg.16b, vtmp1Reg.16b + * else + * cmeq vtmp1Reg.8h, vtmp0Reg.8h, vtmp1Reg.8h + * b Lmerge + * LessThan8: + * ldurd vtmp1Reg, [endReg, #-8] + * if isLatin1 + * cmeq vtmp1Reg.8b, vtmp0Reg.8b, vtmp1Reg.8b + * else + * cmeq vtmp1Reg.4h, vtmp0Reg.4h, vtmp1Reg.4h + * Lmerge: + * add dataAddrReg, dataAddrReg, #16 + * neg lengthReg, lengthReg, lsl (#2 if isLatin1 else #3) ; 64 - (4 or 8) * remaining length + * if isLatin1 + * shrn vtmp1Reg.8b, vtmp1Reg.8h, #4 + * else + * xtn vtmp1Reg.8b, vtmp1Reg.8h + * umov tmp1Reg, vtmp1Reg.d[0] + * lsr tmp1Reg, tmp1Reg, lengthReg ; Move the comparison result of remaining data to LSB + * cmp tmp1Reg, #0 + * b.eq Ldone + * Lfound: + * rbit tmp1Reg, tmp1Reg + * sub dataAddrReg, dataAddrReg, tmp0Reg + * clz tmp1Reg, tmp1Reg + * add resultReg, dataAddrReg, tmp1Reg, lsr #2 + * if !isLatin1 + * lsr resultReg, resultReg, #1 + * Ldone: + * csinv resultReg, resultReg, xzr, ne + * + */ + + /* + * The first child (receiver) is not evaluated because it is not used; its reference count is released at the end of this function. + */ + TR::Node *arrayNode = node->getSecondChild(); + TR::Node *charNode = node->getThirdChild(); + TR::Node *offsetNode = node->getChild(3); + TR::Node *lengthNode = node->getChild(4); + TR::Register *arrayReg = cg->evaluate(arrayNode); + TR::Register *charReg = cg->evaluate(charNode); + const bool isOffsetConstZero = offsetNode->isConstZeroValue(); + TR::Register *offsetReg = isOffsetConstZero ? NULL : cg->evaluate(offsetNode); + TR::Register *savedLengthReg = cg->evaluate(lengthNode); + TR_ARM64ScratchRegisterManager *srm = cg->generateScratchRegisterManager(); + TR::Register *dataAddrReg = (arrayNode->getReferenceCount() > 1) ? srm->findOrCreateScratchRegister() : arrayReg; + TR::Register *lengthReg = (lengthNode->getReferenceCount() > 1) ? 
srm->findOrCreateScratchRegister() : savedLengthReg; + TR::Register *tmp0Reg = srm->findOrCreateScratchRegister(); + /* The live ranges of tmp1Reg and offsetReg do not overlap if offsetNode's reference count is 1, so we use the same register in that case. */ + TR::Register *tmp1Reg = ((offsetNode->getReferenceCount() > 1) || isOffsetConstZero) ? srm->findOrCreateScratchRegister() : offsetReg; + /* The live ranges of endReg and resultReg do not overlap, so we use the same register. */ + TR::Register *endReg = cg->allocateRegister(); + TR::Register *resultReg = endReg; + TR_Debug *debugObj = cg->getDebug(); + + generateTrg1Src1ImmInstruction(cg, TR::InstOpCode::addimmx, node, dataAddrReg, arrayReg, TR::Compiler->om.contiguousArrayHeaderSizeInBytes()); + if (isLatin1) + { + generateTrg1Src2Instruction(cg, TR::InstOpCode::addx, node, endReg, dataAddrReg, savedLengthReg); + } + else + { + generateTrg1Src2ShiftedInstruction(cg, TR::InstOpCode::addx, node, endReg, dataAddrReg, savedLengthReg, TR::SH_LSL, 1); + } + TR::Compilation *comp = cg->comp(); + if (comp->getOptions()->enableDebugCounters()) + { + cg->generateDebugCounter(TR::DebugCounter::debugCounterName(comp, "cg.StringIndexOf/(%s)/%s/%s", + comp->signature(), + (isLatin1 ? 
"compressed" : "decompressed"), + comp->getHotnessName()), *srm); + } + TR::LabelSymbol *doneLabel = generateLabelSymbol(cg); + if (!isOffsetConstZero) + { + generateTrg1Src2Instruction(cg, TR::InstOpCode::subsx, node, lengthReg, savedLengthReg, offsetReg); + generateConditionalBranchInstruction(cg, TR::InstOpCode::b_cond, node, doneLabel, TR::CC_EQ); + } + else + { + generateMovInstruction(cg, node, lengthReg, savedLengthReg); + } + generateTrg1Src1ImmInstruction(cg, TR::InstOpCode::addimmx, node, tmp0Reg, dataAddrReg, 16); + if (!isOffsetConstZero) + { + if (isLatin1) + { + generateTrg1Src2Instruction(cg, TR::InstOpCode::addx, node, dataAddrReg, dataAddrReg, offsetReg); + } + else + { + generateTrg1Src2ShiftedInstruction(cg, TR::InstOpCode::addx, node, dataAddrReg, dataAddrReg, offsetReg, TR::SH_LSL, 1); + } + } + + TR::Register *vtmp0Reg = srm->findOrCreateScratchRegister(TR_VRF); + + generateTrg1Src1Instruction(cg, isLatin1 ? TR::InstOpCode::vdup16b : TR::InstOpCode::vdup8h, node, vtmp0Reg, charReg); + generateTrg1Src1ImmInstruction(cg, TR::InstOpCode::subsimmx, node, lengthReg, lengthReg, isLatin1 ? 8 : 4); + + TR::LabelSymbol *lessThan8Label = generateLabelSymbol(cg); + auto branchToLessThan8LabelInstr = generateConditionalBranchInstruction(cg, TR::InstOpCode::b_cond, node, lessThan8Label, TR::CC_LT); + + generateTrg1Src1ImmInstruction(cg, TR::InstOpCode::subsimmx, node, lengthReg, lengthReg, isLatin1 ? 8 : 4); + + TR::LabelSymbol *residualLabel = generateLabelSymbol(cg); + auto branchToResidualLabelInstr = generateConditionalBranchInstruction(cg, TR::InstOpCode::b_cond, node, residualLabel, TR::CC_LT); + if (comp->getOptions()->enableDebugCounters()) + { + cg->generateDebugCounter(TR::DebugCounter::debugCounterName(comp, "cg.StringIndexOf/(%s)/%s/%s:long", + comp->signature(), + (isLatin1 ? "compressed" : "decompressed"), + comp->getHotnessName()), *srm); + } + /* + * Main loop: 16 bytes are processed in 1 iteration of the loop. 
+ */ + TR::LabelSymbol *loopLabel = generateLabelSymbol(cg); + auto loopLabelInstr = generateLabelInstruction(cg, TR::InstOpCode::label, node, loopLabel); + if (debugObj) + { + debugObj->addInstructionComment(branchToLessThan8LabelInstr, "Branch to lessThan8 if remaining length < (isLatin1 ? 8 : 4)"); + debugObj->addInstructionComment(branchToResidualLabelInstr, "Branch to residualLabel if remaining length < (isLatin1 ? 16 : 8)"); + debugObj->addInstructionComment(loopLabelInstr, "loopLabel"); + } + generateTrg1Src1ImmInstruction(cg, TR::InstOpCode::subsimmx, node, lengthReg, lengthReg, isLatin1 ? 16 : 8); + + TR::Register *vtmp1Reg = srm->findOrCreateScratchRegister(TR_VRF); + + generateTrg1MemInstruction(cg, TR::InstOpCode::vldrpostq, node, vtmp1Reg, TR::MemoryReference::createWithDisplacement(cg, dataAddrReg, 16)); + generateTrg1Src2Instruction(cg, (isLatin1 ? TR::InstOpCode::vcmeq16b : TR::InstOpCode::vcmeq8h), node, vtmp1Reg, vtmp0Reg, vtmp1Reg); + if (isLatin1) + { + generateVectorShiftImmediateInstruction(cg, TR::InstOpCode::vshrn_8b, node, vtmp1Reg, vtmp1Reg, 4); + } + else + { + generateTrg1Src1Instruction(cg, TR::InstOpCode::vxtn_8b, node, vtmp1Reg, vtmp1Reg); + } + + generateMovVectorElementToGPRInstruction(cg, TR::InstOpCode::umovxd, node, tmp1Reg, vtmp1Reg, 0); + generateConditionalCompareImmInstruction(cg, node, tmp1Reg, 0, 0, TR::CC_CS, true); + auto branchBackToLoopLabelInstr = generateConditionalBranchInstruction(cg, TR::InstOpCode::b_cond, node, loopLabel, TR::CC_EQ); + TR::LabelSymbol *foundLabel = generateLabelSymbol(cg); + auto branchToFoundLabelInstr = generateCompareBranchInstruction(cg, TR::InstOpCode::cbnzx, node, tmp1Reg, foundLabel); + generateCompareImmInstruction(cg, node, lengthReg, (isLatin1 ? 
(-16) : (-8)), true); + auto branchToDoneLabelInstr = generateConditionalBranchInstruction(cg, TR::InstOpCode::b_cond, node, doneLabel, TR::CC_EQ); + auto residualLabelInstr = generateLabelInstruction(cg, TR::InstOpCode::label, node, residualLabel); + generateTrg1MemInstruction(cg, TR::InstOpCode::vldurq, node, vtmp1Reg, TR::MemoryReference::createWithDisplacement(cg, endReg, -16)); + generateTrg1Src2Instruction(cg, (isLatin1 ? TR::InstOpCode::vcmeq16b : TR::InstOpCode::vcmeq8h), node, vtmp1Reg, vtmp0Reg, vtmp1Reg); + TR::LabelSymbol *mergeLabel = generateLabelSymbol(cg); + auto branchToMergeLabelInstr = generateLabelInstruction(cg, TR::InstOpCode::b, node, mergeLabel); + auto lessThan8LabelInstr = generateLabelInstruction(cg, TR::InstOpCode::label, node, lessThan8Label); + generateTrg1MemInstruction(cg, TR::InstOpCode::vldurd, node, vtmp1Reg, TR::MemoryReference::createWithDisplacement(cg, endReg, -8)); + generateTrg1Src2Instruction(cg, (isLatin1 ? TR::InstOpCode::vcmeq8b : TR::InstOpCode::vcmeq4h), node, vtmp1Reg, vtmp0Reg, vtmp1Reg); + auto mergeLabelInstr = generateLabelInstruction(cg, TR::InstOpCode::label, node, mergeLabel); + if (debugObj) + { + debugObj->addInstructionComment(branchBackToLoopLabelInstr, "Jump back to loopLabel if the character is not found and the remaining data is >= 16 bytes"); + debugObj->addInstructionComment(branchToFoundLabelInstr, "Branch to foundLabel if the character found"); + debugObj->addInstructionComment(branchToDoneLabelInstr, "Branch to doneLabel if no data is left"); + debugObj->addInstructionComment(residualLabelInstr, "residualLabel"); + debugObj->addInstructionComment(lessThan8LabelInstr, "lessThan8Label"); + debugObj->addInstructionComment(branchToMergeLabelInstr, "Branch to mergeLabel"); + debugObj->addInstructionComment(mergeLabelInstr, "mergeLabel"); + } + generateTrg1Src1ImmInstruction(cg, TR::InstOpCode::addimmx, node, dataAddrReg, dataAddrReg, 16); + + TR::Register *zeroReg = cg->allocateRegister(); /* bound to xzr by the post-dependency added below, so the subx that follows acts as the neg shown in the pseudo-code */ + 
generateTrg1Src2ShiftedInstruction(cg, TR::InstOpCode::subx, node, lengthReg, zeroReg, lengthReg, TR::SH_LSL, (isLatin1 ? 2 : 3)); + if (isLatin1) + { + generateVectorShiftImmediateInstruction(cg, TR::InstOpCode::vshrn_8b, node, vtmp1Reg, vtmp1Reg, 4); + } + else + { + generateTrg1Src1Instruction(cg, TR::InstOpCode::vxtn_8b, node, vtmp1Reg, vtmp1Reg); + } + generateMovVectorElementToGPRInstruction(cg, TR::InstOpCode::umovxd, node, tmp1Reg, vtmp1Reg, 0); + generateTrg1Src2Instruction(cg, TR::InstOpCode::lsrvx, node, tmp1Reg, tmp1Reg, lengthReg); + generateCompareImmInstruction(cg, node, tmp1Reg, 0, true); + auto branchToDoneLabelInstr2 = generateConditionalBranchInstruction(cg, TR::InstOpCode::b_cond, node, doneLabel, TR::CC_EQ); + auto foundLabelInstr = generateLabelInstruction(cg, TR::InstOpCode::label, node, foundLabel); + generateTrg1Src1Instruction(cg, TR::InstOpCode::rbitx, node, tmp1Reg, tmp1Reg); + generateTrg1Src2Instruction(cg, TR::InstOpCode::subx, node, dataAddrReg, dataAddrReg, tmp0Reg); + generateTrg1Src1Instruction(cg, TR::InstOpCode::clzx, node, tmp1Reg, tmp1Reg); + + generateTrg1Src2ShiftedInstruction(cg, TR::InstOpCode::addx, node, resultReg, dataAddrReg, tmp1Reg, TR::SH_LSR, 2); + if (!isLatin1) + { + generateLogicalShiftRightImmInstruction(cg, node, resultReg, resultReg, 1); + } + + auto doneLabelInstr = generateLabelInstruction(cg, TR::InstOpCode::label, node, doneLabel); + if (debugObj) + { + debugObj->addInstructionComment(branchToDoneLabelInstr2, "Branch to doneLabel if the character is not found"); + debugObj->addInstructionComment(foundLabelInstr, "foundLabel"); + debugObj->addInstructionComment(doneLabelInstr, "doneLabel"); + } + generateCondTrg1Src2Instruction(cg, TR::InstOpCode::csinvx, node, resultReg, resultReg, zeroReg, TR::CC_NE); /* every branch to doneLabel is taken on EQ, so EQ here selects ~zeroReg (-1); on NE the computed resultReg is kept */ + + TR::RegisterDependencyConditions *conditions = new (cg->trHeapMemory()) TR::RegisterDependencyConditions(0, (isOffsetConstZero ? 
5 : 6) + srm->numAvailableRegisters(), cg->trMemory()); /* 5 = arrayReg, charReg, savedLengthReg, zeroReg, resultReg; one more for offsetReg when the offset was evaluated */ + conditions->addPostCondition(arrayReg, TR::RealRegister::NoReg); + conditions->addPostCondition(charReg, TR::RealRegister::NoReg); + conditions->addPostCondition(savedLengthReg, TR::RealRegister::NoReg); + if (!isOffsetConstZero) + { + conditions->addPostCondition(offsetReg, TR::RealRegister::NoReg); + } + conditions->addPostCondition(zeroReg, TR::RealRegister::xzr); + conditions->addPostCondition(resultReg, TR::RealRegister::NoReg); + srm->addScratchRegistersToDependencyList(conditions); + + generateLabelInstruction(cg, TR::InstOpCode::label, node, generateLabelSymbol(cg), conditions); + + node->setRegister(resultReg); + cg->stopUsingRegister(zeroReg); + srm->stopUsingRegisters(); + cg->recursivelyDecReferenceCount(node->getFirstChild()); + cg->decReferenceCount(arrayNode); + cg->decReferenceCount(charNode); + cg->decReferenceCount(offsetNode); + cg->decReferenceCount(lengthNode); + + return resultReg; + } + bool J9::ARM64::CodeGenerator::inlineDirectCall(TR::Node *node, TR::Register *&resultReg) { @@ -6229,6 +6503,22 @@ J9::ARM64::CodeGenerator::inlineDirectCall(TR::Node *node, TR::Register *&result { switch (methodSymbol->getMandatoryRecognizedMethod()) { + case TR::com_ibm_jit_JITHelpers_intrinsicIndexOfLatin1: + if (cg->getSupportsInlineStringIndexOf()) + { + resultReg = inlineIntrinsicIndexOf(node, cg, true); + return true; + } + break; + + case TR::com_ibm_jit_JITHelpers_intrinsicIndexOfUTF16: + if (cg->getSupportsInlineStringIndexOf()) + { + resultReg = inlineIntrinsicIndexOf(node, cg, false); + return true; + } + break; + case TR::java_lang_String_hashCodeImplDecompressed: if (cg->getSupportsInlineStringHashCode()) {