Skip to content

Replacement: update RRIP #153

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 2 commits into from
May 15, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions src/main/scala/coupledL2/Common.scala
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@ class ReplacerInfo(implicit p: Parameters) extends L2Bundle {
val channel = UInt(3.W)
val opcode = UInt(3.W)
val reqSource = UInt(MemReqSource.reqSourceBits.W)
val refill_prefetch = Bool()
}

trait HasChannelBits { this: Bundle =>
Expand Down
68 changes: 56 additions & 12 deletions src/main/scala/coupledL2/Directory.scala
Original file line number Diff line number Diff line change
Expand Up @@ -263,19 +263,47 @@ class Directory(implicit p: Parameters) extends L2Module {
io.replResp.bits.retry := refillRetry

/* ====== Update ====== */
// update replacer only when A hit or refill, at stage 3
val updateHit = reqValid_s3 && hit_s3 && req_s3.replacerInfo.channel(0) &&
// PLRU: update replacer only when A hit or refill, at stage 3
// RRIP: update replacer when A/C hit or refill
val updateHit = if(cacheParams.replacement == "drrip" || cacheParams.replacement == "srrip"){
reqValid_s3 && hit_s3 &&
((req_s3.replacerInfo.channel(0) && (req_s3.replacerInfo.opcode === AcquirePerm || req_s3.replacerInfo.opcode === AcquireBlock || req_s3.replacerInfo.opcode === Hint)) ||
(req_s3.replacerInfo.channel(2) && (req_s3.replacerInfo.opcode === Release || req_s3.replacerInfo.opcode === ReleaseData)))
} else {
reqValid_s3 && hit_s3 && req_s3.replacerInfo.channel(0) &&
(req_s3.replacerInfo.opcode === AcquirePerm || req_s3.replacerInfo.opcode === AcquireBlock)
}
val updateRefill = refillReqValid_s3 && !refillRetry
// update replacer when A/C hit or refill
replacerWen := updateHit || updateRefill

// !!![TODO]!!! check this @CLS
// hit-Promotion, miss-Insertion for RRIP, so refill should hit = false.B
val touch_way_s3 = Mux(refillReqValid_s3, replaceWay, way_s3)
val rrip_hit_s3 = Mux(refillReqValid_s3, false.B, hit_s3)
// origin-bit marks whether the data_block is reused
val origin_bit_opt = if(random_repl) None else
Some(Module(new SRAMTemplate(Bool(), sets, ways, singlePort = true)))
val origin_bits_r = origin_bit_opt.get.io.r(io.read.fire(), io.read.bits.set).resp.data
val origin_bits_hold = Wire(Vec(ways, Bool()))
origin_bits_hold := HoldUnless(origin_bits_r, RegNext(io.read.fire(), false.B))
origin_bit_opt.get.io.w(
replacerWen,
rrip_hit_s3,
req_s3.set,
UIntToOH(touch_way_s3)
)

if(cacheParams.replacement == "srrip"){
val next_state_s3 = repl.get_next_state(repl_state_s3, touch_way_s3, rrip_hit_s3)
// req_type[3]: 0-firstuse, 1-reuse; req_type[2]: 0-acquire, 1-release;
// req_type[1]: 0-non-prefetch, 1-prefetch; req_type[0]: 0-not-refill, 1-refill
val req_type = WireInit(0.U(4.W))
req_type := Cat(origin_bits_hold(touch_way_s3),
req_s3.replacerInfo.channel(2),
(req_s3.replacerInfo.channel(0) && req_s3.replacerInfo.opcode === Hint) || (req_s3.replacerInfo.channel(2) && metaAll_s3(touch_way_s3).prefetch.getOrElse(false.B)) || req_s3.replacerInfo.refill_prefetch,
req_s3.refill
)

val next_state_s3 = repl.get_next_state(repl_state_s3, touch_way_s3, rrip_hit_s3, req_type)
val repl_init = Wire(Vec(ways, UInt(2.W)))
repl_init.foreach(_ := 2.U(2.W))
replacer_sram_opt.get.io.w(
Expand All @@ -284,24 +312,40 @@ class Directory(implicit p: Parameters) extends L2Module {
Mux(resetFinish, set_s3, resetIdx),
1.U
)

} else if(cacheParams.replacement == "drrip"){
//Set Dueling
// req_type[3]: 0-firstuse, 1-reuse; req_type[2]: 0-acquire, 1-release;
// req_type[1]: 0-non-prefetch, 1-prefetch; req_type[0]: 0-not-refill, 1-refill
val req_type = WireInit(0.U(4.W))
req_type := Cat(origin_bits_hold(touch_way_s3),
req_s3.replacerInfo.channel(2),
(req_s3.replacerInfo.channel(0) && req_s3.replacerInfo.opcode === Hint) || (req_s3.replacerInfo.channel(2) && metaAll_s3(touch_way_s3).prefetch.getOrElse(false.B)) || req_s3.replacerInfo.refill_prefetch,
req_s3.refill
)

// Set Dueling
val PSEL = RegInit(512.U(10.W)) //32-monitor sets, 10-bits psel
// track monitor sets' hit rate for each policy: srrip-0,128...3968;brrip-64,192...4032
when(refillReqValid_s3 && (set_s3(6,0)===0.U) && !rrip_hit_s3){ //SDMs_srrip miss
// track monitor sets' hit rate for each policy
// basic SDMs complement-selection policy: srrip--set_idx[group-:]==set_idx[group_offset-:]; brrip--set_idx[group-:]==!set_idx[group_offset-:]
val setBits = log2Ceil(sets)
val half_setBits = setBits >> 1
val match_a = set_s3(setBits-1,setBits-half_setBits-1)===set_s3(setBits-half_setBits-1,0) // 512 sets [8:4][4:0]
val match_b = set_s3(setBits-1,setBits-half_setBits-1)===(~set_s3(setBits-half_setBits-1,0))
when(refillReqValid_s3 && match_a && !rrip_hit_s3 && (PSEL=/=1023.U)){ //SDMs_srrip miss
PSEL := PSEL + 1.U
} .elsewhen(refillReqValid_s3 && (set_s3(6,0)===64.U) && !rrip_hit_s3){ //SDMs_brrip miss
} .elsewhen(refillReqValid_s3 && match_b && !rrip_hit_s3 && (PSEL=/=0.U)){ //SDMs_brrip miss
PSEL := PSEL - 1.U
}
// decide use which policy by policy selection counter, for insertion
/* if set -> SDMs: use fix policy
else if PSEL(MSB)==0: use srrip
else if PSEL(MSB)==1: use brrip */
val repl_type = WireInit(false.B)
repl_type := Mux(set_s3(6,0)===0.U, false.B,
Mux(set_s3(6,0)===64.U, true.B,
Mux(PSEL(9)===0.U, false.B, true.B))) // false.B - srrip, true.B - brrip
val next_state_s3 = repl.get_next_state(repl_state_s3, touch_way_s3, rrip_hit_s3, repl_type)
repl_type := Mux(match_a, false.B,
Mux(match_b, true.B,
Mux(PSEL(9)===0.U, false.B, true.B))) // false.B - srrip, true.B - brrip

val next_state_s3 = repl.get_next_state(repl_state_s3, touch_way_s3, rrip_hit_s3, repl_type, req_type)

val repl_init = Wire(Vec(ways, UInt(2.W)))
repl_init.foreach(_ := 2.U(2.W))
Expand Down
2 changes: 1 addition & 1 deletion src/main/scala/coupledL2/L2Param.scala
Original file line number Diff line number Diff line change
Expand Up @@ -64,7 +64,7 @@ case class L2Param
pageBytes: Int = 4096,
channelBytes: TLChannelBeatBytes = TLChannelBeatBytes(32),
clientCaches: Seq[L1Param] = Nil,
replacement: String = "plru",
replacement: String = "drrip",
mshrs: Int = 16,
releaseData: Int = 3,
/* 0 for dirty alone
Expand Down
1 change: 1 addition & 0 deletions src/main/scala/coupledL2/RequestArb.scala
Original file line number Diff line number Diff line change
Expand Up @@ -146,6 +146,7 @@ class RequestArb(implicit p: Parameters) extends L2Module {
io.dirRead_s1.bits.replacerInfo.opcode := task_s1.bits.opcode
io.dirRead_s1.bits.replacerInfo.channel := task_s1.bits.channel
io.dirRead_s1.bits.replacerInfo.reqSource := task_s1.bits.reqSource
io.dirRead_s1.bits.replacerInfo.refill_prefetch := s1_needs_replRead && (mshr_task_s1.bits.opcode === HintAck && mshr_task_s1.bits.dsWen)
io.dirRead_s1.bits.refill := s1_needs_replRead
io.dirRead_s1.bits.mshrId := task_s1.bits.mshrId

Expand Down
43 changes: 31 additions & 12 deletions src/main/scala/coupledL2/utils/Replacer.scala
Original file line number Diff line number Diff line change
Expand Up @@ -38,8 +38,8 @@ abstract class ReplacementPolicy {
def get_next_state(state: UInt, touch_ways: Seq[Valid[UInt]]): UInt = {
touch_ways.foldLeft(state)((prev, touch_way) => Mux(touch_way.valid, get_next_state(prev, touch_way.bits), prev))
}
def get_next_state(state: UInt, touch_way: UInt, hit: Bool): UInt = {0.U}
def get_next_state(state: UInt, touch_way: UInt, hit: Bool, chosen_type: Bool): UInt = {0.U}
def get_next_state(state: UInt, touch_way: UInt, hit: Bool, req_type: UInt): UInt = {0.U}
def get_next_state(state: UInt, touch_way: UInt, hit: Bool, chosen_type: Bool, req_type: UInt): UInt = {0.U}

def get_replace_way(state: UInt): UInt
}
Expand Down Expand Up @@ -310,7 +310,6 @@ class SetAssocReplacer(n_sets: Int, n_ways: Int, policy: String) extends SetAsso
def way(set: UInt) = logic.get_replace_way(state_vec(set))
}


// 2-bit static Re-Reference Interval Prediction
class StaticRRIP(n_ways: Int) extends ReplacementPolicy {
def nBits = 2 * n_ways
Expand All @@ -323,17 +322,24 @@ class StaticRRIP(n_ways: Int) extends ReplacementPolicy {
def access(touch_ways: Seq[Valid[UInt]]) = {}
def get_next_state(state: UInt, touch_way: UInt) = 0.U //DontCare

override def get_next_state(state: UInt, touch_way: UInt, hit: Bool): UInt = {
override def get_next_state(state: UInt, touch_way: UInt, hit: Bool, req_type: UInt): UInt = {
val State = Wire(Vec(n_ways, UInt(2.W)))
val nextState = Wire(Vec(n_ways, UInt(2.W)))
State.zipWithIndex.map { case (e, i) =>
e := state(2*i+1,2*i)
}
// hit-Promotion, miss-Insertion & Aging
val increcement = 3.U(2.W) - State(touch_way)
// req_type[3]: 0-firstuse, 1-reuse; req_type[2]: 0-acquire, 1-release;
// req_type[1]: 0-non-prefetch, 1-prefetch; req_type[0]: 0-not-refill, 1-refill
// rrpv: non-pref_hit/non-pref_refill(miss)/non-pref_release_reuse = 0;
// pref_hit do nothing; pref_refill = 1; non-pref_release_firstuse/pref_release = 2;
nextState.zipWithIndex.map { case (e, i) =>
e := Mux(i.U === touch_way,
Mux(hit, 0.U(2.W), 2.U(2.W)),
Mux((req_type(2,0) === 0.U && hit) || req_type(2,0) === 1.U || req_type === 12.U, 0.U,
Mux(req_type(2,0) === 3.U, 1.U,
Mux(req_type === 4.U || req_type(2,0) === 6.U, 2.U, State(i)))),
//Mux(hit, 0.U(2.W), 2.U(2.W)),
Mux(hit, State(i), State(i)+increcement)
)
}
Expand Down Expand Up @@ -375,7 +381,7 @@ class BRRIP(n_ways: Int) extends ReplacementPolicy {
def access(touch_ways: Seq[Valid[UInt]]) = {}
def get_next_state(state: UInt, touch_way: UInt) = 0.U //DontCare

override def get_next_state(state: UInt, touch_way: UInt, hit: Bool): UInt = {
override def get_next_state(state: UInt, touch_way: UInt, hit: Bool, req_type: UInt): UInt = {
val State = Wire(Vec(n_ways, UInt(2.W)))
val nextState = Wire(Vec(n_ways, UInt(2.W)))
State.zipWithIndex.map { case (e, i) =>
Expand All @@ -384,13 +390,26 @@ class BRRIP(n_ways: Int) extends ReplacementPolicy {

// hit-Promotion, miss-Insertion & Aging
val increcement = 3.U(2.W) - State(touch_way)
val random = (rand.nextInt(32)).U
// req_type[3]: 0-firstuse, 1-reuse; req_type[2]: 0-acquire, 1-release;
// req_type[1]: 0-non-prefetch, 1-prefetch; req_type[0]: 0-not-refill, 1-refill
// rrpv: non-pref_hit/non-pref_refill(miss)/non-pref_release_reuse = 0;
// pref_hit do nothing; pref_refill = 1; non-pref_release_firstuse/pref_release = 3;
nextState.zipWithIndex.map { case (e, i) =>
e := Mux(i.U === touch_way,
Mux(hit, 0.U(2.W), Mux(random === 0.U, 2.U(2.W), 3.U(2.W))), //TODO: touch_way=3 for most insertions, touch_rrpv=2 with certain probability
e := Mux(i.U === touch_way,
Mux((req_type(2,0) === 0.U && hit) || req_type(2,0) === 1.U || req_type === 12.U, 0.U,
Mux(req_type(2,0) === 3.U, 1.U,
Mux(req_type === 4.U || req_type(2,0) === 6.U, 3.U, State(i)))),
//Mux(hit, 0.U(2.W), 3.U(2.W)),
Mux(hit, State(i), State(i)+increcement)
)
}
/* val random = (rand.nextInt(32)).U
nextState.zipWithIndex.map { case (e, i) =>
e := Mux(i.U === touch_way,
Mux(hit, 0.U(2.W), Mux(random === 0.U, 2.U(2.W), 3.U(2.W))), //touch_way=3 for most insertions, touch_rrpv=2 with certain probability
Mux(hit, State(i), State(i)+increcement)
)
} */
Cat(nextState.map(x=>x).reverse)
}

Expand Down Expand Up @@ -432,8 +451,8 @@ class DRRIP(n_ways: Int) extends ReplacementPolicy {
def hit = {}

def get_next_state(state: UInt, touch_way: UInt) = 0.U //DontCare
override def get_next_state(state: UInt, touch_way: UInt, hit: Bool, chosen_type: Bool): UInt = {
Mux(chosen_type, repl_BRRIP.get_next_state(state, touch_way, hit), repl_SRRIP.get_next_state(state, touch_way, hit))
override def get_next_state(state: UInt, touch_way: UInt, hit: Bool, chosen_type: Bool, req_type: UInt): UInt = {
Mux(chosen_type, repl_BRRIP.get_next_state(state, touch_way, hit, req_type), repl_SRRIP.get_next_state(state, touch_way, hit, req_type))
}
def get_replace_way(state: UInt): UInt = {
val RRPVVec = Wire(Vec(n_ways, UInt(2.W)))
Expand All @@ -452,4 +471,4 @@ class DRRIP(n_ways: Int) extends ReplacementPolicy {
PriorityEncoder(lrrWayVec)
}

}
}